Merge pull request #98 from yokowu/fix-api-key

fix(api-key): 删除用户同时删除api key, 给api key加上缓存
This commit is contained in:
Yoko
2025-07-17 16:49:38 +08:00
committed by GitHub
13 changed files with 115 additions and 37 deletions

View File

@@ -53,7 +53,8 @@ func newServer() (*Server, error) {
if err != nil {
return nil, err
}
proxyRepo := repo.NewProxyRepo(client)
redisClient := store.NewRedisCli(configConfig)
proxyRepo := repo.NewProxyRepo(client, redisClient)
modelRepo := repo2.NewModelRepo(client)
proxyUsecase := usecase.NewProxyUsecase(proxyRepo, modelRepo)
llmProxy := proxy.NewLLMProxy(slogLogger, configConfig, proxyUsecase)
@@ -62,26 +63,25 @@ func newServer() (*Server, error) {
extensionRepo := repo4.NewExtensionRepo(client)
extensionUsecase := usecase2.NewExtensionUsecase(extensionRepo, configConfig, slogLogger)
proxyMiddleware := middleware.NewProxyMiddleware(proxyUsecase)
redisClient := store.NewRedisCli(configConfig)
activeMiddleware := middleware.NewActiveMiddleware(redisClient, slogLogger)
v1Handler := v1.NewV1Handler(slogLogger, web, llmProxy, proxyUsecase, openAIUsecase, extensionUsecase, proxyMiddleware, activeMiddleware, configConfig)
modelUsecase := usecase3.NewModelUsecase(slogLogger, modelRepo, configConfig)
sessionSession := session.NewSession(configConfig)
authMiddleware := middleware.NewAuthMiddleware(sessionSession, slogLogger)
modelHandler := v1_2.NewModelHandler(web, modelUsecase, authMiddleware, slogLogger)
modelHandler := v1_2.NewModelHandler(web, modelUsecase, authMiddleware, activeMiddleware, slogLogger)
ipdbIPDB, err := ipdb.NewIPDB(slogLogger)
if err != nil {
return nil, err
}
userRepo := repo5.NewUserRepo(client, ipdbIPDB)
userRepo := repo5.NewUserRepo(client, ipdbIPDB, redisClient)
userUsecase := usecase4.NewUserUsecase(configConfig, redisClient, userRepo, slogLogger)
userHandler := v1_3.NewUserHandler(web, userUsecase, extensionUsecase, authMiddleware, sessionSession, slogLogger, configConfig)
userHandler := v1_3.NewUserHandler(web, userUsecase, extensionUsecase, authMiddleware, activeMiddleware, sessionSession, slogLogger, configConfig)
dashboardRepo := repo6.NewDashboardRepo(client)
dashboardUsecase := usecase5.NewDashboardUsecase(dashboardRepo)
dashboardHandler := v1_4.NewDashboardHandler(web, dashboardUsecase, authMiddleware)
dashboardHandler := v1_4.NewDashboardHandler(web, dashboardUsecase, authMiddleware, activeMiddleware)
billingRepo := repo7.NewBillingRepo(client)
billingUsecase := usecase6.NewBillingUsecase(billingRepo)
billingHandler := v1_5.NewBillingHandler(web, billingUsecase, authMiddleware)
billingHandler := v1_5.NewBillingHandler(web, billingUsecase, authMiddleware, activeMiddleware)
server := &Server{
config: configConfig,
web: web,

View File

@@ -1,7 +1,8 @@
package consts
const (
UserActiveKeyFmt = "user:active:%s"
UserActiveKeyFmt = "user:active:%s"
AdminActiveKeyFmt = "admin:active:%s"
)
type UserStatus string

View File

@@ -251,7 +251,6 @@ func (a *AdminUser) From(e *db.Admin) *AdminUser {
a.ID = e.ID.String()
a.Username = e.Username
a.LastActiveAt = e.LastActiveAt.Unix()
a.Status = e.Status
a.CreatedAt = e.CreatedAt.Unix()

View File

@@ -15,13 +15,14 @@ func NewBillingHandler(
w *web.Web,
usecase domain.BillingUsecase,
auth *middleware.AuthMiddleware,
active *middleware.ActiveMiddleware,
) *BillingHandler {
b := &BillingHandler{
usecase: usecase,
}
g := w.Group("/api/v1/billing")
g.Use(auth.Auth())
g.Use(auth.Auth(), active.Active("admin"))
g.GET("/chat/record", web.BindHandler(b.ListChatRecord, web.WithPage()))
g.GET("/completion/record", web.BindHandler(b.ListCompletionRecord, web.WithPage()))

View File

@@ -15,11 +15,12 @@ func NewDashboardHandler(
w *web.Web,
usecase domain.DashboardUsecase,
auth *middleware.AuthMiddleware,
active *middleware.ActiveMiddleware,
) *DashboardHandler {
h := &DashboardHandler{usecase: usecase}
g := w.Group("/api/v1/dashboard")
g.Use(auth.Auth())
g.Use(auth.Auth(), active.Active("admin"))
g.GET("/statistics", web.BaseHandler(h.Statistics))
g.GET("/category-stat", web.BindHandler(h.CategoryStat))
g.GET("/time-stat", web.BindHandler(h.TimeStat))

View File

@@ -6,9 +6,10 @@ import (
"log/slog"
"time"
"github.com/chaitin/MonkeyCode/backend/consts"
"github.com/labstack/echo/v4"
"github.com/redis/go-redis/v9"
"github.com/chaitin/MonkeyCode/backend/consts"
)
type ActiveMiddleware struct {
@@ -23,14 +24,24 @@ func NewActiveMiddleware(redis *redis.Client, logger *slog.Logger) *ActiveMiddle
}
}
func (a *ActiveMiddleware) Active() echo.MiddlewareFunc {
func (a *ActiveMiddleware) Active(scope string) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if apikey := GetApiKey(c); apikey != nil {
if err := a.redis.Set(context.Background(), fmt.Sprintf(consts.UserActiveKeyFmt, apikey.UserID), time.Now().Unix(), 0).Err(); err != nil {
a.logger.With("error", err).ErrorContext(c.Request().Context(), "failed to set user active status in Redis")
switch scope {
case "admin":
if user := GetUser(c); user != nil {
if err := a.redis.Set(context.Background(), fmt.Sprintf(consts.AdminActiveKeyFmt, user.ID), time.Now().Unix(), 0).Err(); err != nil {
a.logger.With("error", err).ErrorContext(c.Request().Context(), "failed to set admin active status in Redis")
}
}
case "user":
if apikey := GetApiKey(c); apikey != nil {
if err := a.redis.Set(context.Background(), fmt.Sprintf(consts.UserActiveKeyFmt, apikey.UserID), time.Now().Unix(), 0).Err(); err != nil {
a.logger.With("error", err).ErrorContext(c.Request().Context(), "failed to set user active status in Redis")
}
}
}
return next(c)
}
}

View File

@@ -19,12 +19,13 @@ func NewModelHandler(
w *web.Web,
usecase domain.ModelUsecase,
auth *middleware.AuthMiddleware,
active *middleware.ActiveMiddleware,
logger *slog.Logger,
) *ModelHandler {
m := &ModelHandler{usecase: usecase, logger: logger.With("handler", "model")}
g := w.Group("/api/v1/model")
g.Use(auth.Auth())
g.Use(auth.Auth(), active.Active("admin"))
g.POST("/check", web.BindHandler(m.Check))
g.GET("", web.BaseHandler(m.List))

View File

@@ -107,12 +107,7 @@ func (r *ModelRepo) Update(ctx context.Context, id string, fn func(tx *db.Tx, ol
}
func (r *ModelRepo) MyModelList(ctx context.Context, req *domain.MyModelListReq) ([]*db.Model, error) {
userID, err := uuid.Parse(req.UserID)
if err != nil {
return nil, err
}
q := r.db.Model.Query().
Where(model.UserID(userID)).
Where(model.ModelType(req.ModelType)).
Order(model.ByCreatedAt(sql.OrderAsc()))
return q.All(ctx)

View File

@@ -49,10 +49,10 @@ func NewV1Handler(
g := w.Group("/v1", middleware.Auth())
g.GET("/models", web.BaseHandler(h.ModelList))
g.POST("/completion/accept", web.BindHandler(h.AcceptCompletion), active.Active())
g.POST("/chat/completions", web.BaseHandler(h.ChatCompletion), active.Active())
g.POST("/completions", web.BaseHandler(h.Completions), active.Active())
g.POST("/embeddings", web.BaseHandler(h.Embeddings), active.Active())
g.POST("/completion/accept", web.BindHandler(h.AcceptCompletion), active.Active("user"))
g.POST("/chat/completions", web.BaseHandler(h.ChatCompletion), active.Active("user"))
g.POST("/completions", web.BaseHandler(h.Completions), active.Active("user"))
g.POST("/embeddings", web.BaseHandler(h.Embeddings), active.Active("user"))
return h
}

View File

@@ -2,8 +2,12 @@ package repo
import (
"context"
"encoding/json"
"errors"
"time"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
"github.com/chaitin/MonkeyCode/backend/consts"
"github.com/chaitin/MonkeyCode/backend/db"
@@ -16,11 +20,12 @@ import (
)
type ProxyRepo struct {
db *db.Client
db *db.Client
redis *redis.Client
}
func NewProxyRepo(db *db.Client) domain.ProxyRepo {
return &ProxyRepo{db: db}
func NewProxyRepo(db *db.Client, redis *redis.Client) domain.ProxyRepo {
return &ProxyRepo{db: db, redis: redis}
}
func (r *ProxyRepo) SelectModelWithLoadBalancing(modelName string, modelType consts.ModelType) (*db.Model, error) {
@@ -35,12 +40,36 @@ func (r *ProxyRepo) SelectModelWithLoadBalancing(modelName string, modelType con
}
func (r *ProxyRepo) ValidateApiKey(ctx context.Context, key string) (*db.ApiKey, error) {
rkey := "sk-" + key
data, err := r.redis.Get(ctx, rkey).Result()
if err == nil {
key := db.ApiKey{}
if err := json.Unmarshal([]byte(data), &key); err != nil {
return nil, err
}
return &key, nil
}
if !errors.Is(err, redis.Nil) {
return nil, err
}
a, err := r.db.ApiKey.Query().
Where(apikey.Key(key), apikey.Status(consts.ApiKeyStatusActive)).
Only(ctx)
if err != nil {
return nil, err
}
b, err := json.Marshal(a)
if err != nil {
return nil, err
}
if err := r.redis.Set(ctx, rkey, string(b), 24*time.Hour).Err(); err != nil {
return nil, err
}
return a, nil
}

View File

@@ -43,6 +43,7 @@ func NewUserHandler(
usecase domain.UserUsecase,
euse domain.ExtensionUsecase,
auth *middleware.AuthMiddleware,
active *middleware.ActiveMiddleware,
session *session.Session,
logger *slog.Logger,
cfg *config.Config,
@@ -66,7 +67,7 @@ func NewUserHandler(
admin.POST("/login", web.BindHandler(u.AdminLogin))
admin.GET("/setting", web.BaseHandler(u.GetSetting))
admin.Use(auth.Auth())
admin.Use(auth.Auth(), active.Active("admin"))
admin.PUT("/setting", web.BindHandler(u.UpdateSetting))
admin.POST("/create", web.BindHandler(u.CreateAdmin))
admin.GET("/list", web.BaseHandler(u.AdminList, web.WithPage()))
@@ -80,7 +81,7 @@ func NewUserHandler(
g.POST("/register", web.BindHandler(u.Register))
g.POST("/login", web.BindHandler(u.Login))
g.Use(auth.Auth())
g.Use(auth.Auth(), active.Active("admin"))
g.PUT("/update", web.BindHandler(u.Update))
g.DELETE("/delete", web.BaseHandler(u.Delete))

View File

@@ -8,6 +8,7 @@ import (
"entgo.io/ent/dialect/sql"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
"github.com/GoYoko/web"
@@ -27,12 +28,13 @@ import (
)
type UserRepo struct {
db *db.Client
ipdb *ipdb.IPDB
db *db.Client
ipdb *ipdb.IPDB
redis *redis.Client
}
func NewUserRepo(db *db.Client, ipdb *ipdb.IPDB) domain.UserRepo {
return &UserRepo{db: db, ipdb: ipdb}
func NewUserRepo(db *db.Client, ipdb *ipdb.IPDB, redis *redis.Client) domain.UserRepo {
return &UserRepo{db: db, ipdb: ipdb, redis: redis}
}
func (r *UserRepo) InitAdmin(ctx context.Context, username, password string) error {
@@ -252,6 +254,20 @@ func (r *UserRepo) Delete(ctx context.Context, id string) error {
return err
}
keys, err := tx.ApiKey.Query().Where(apikey.UserID(user.ID)).All(ctx)
if err != nil {
return err
}
for _, v := range keys {
if _, err := tx.ApiKey.Delete().Where(apikey.ID(v.ID)).Exec(ctx); err != nil {
return err
}
if err := r.redis.Del(ctx, fmt.Sprintf("sk-%s", v.Key)).Err(); err != nil {
return err
}
}
for _, v := range user.Edges.Identities {
if _, err := tx.UserIdentity.Delete().Where(useridentity.ID(v.ID)).Exec(ctx); err != nil {
return err

View File

@@ -3,6 +3,7 @@ package usecase
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/url"
@@ -81,7 +82,7 @@ func (u *UserUsecase) getUserActive(ctx context.Context, ids []string) (map[stri
m := make(map[string]int64)
for _, id := range ids {
key := fmt.Sprintf(consts.UserActiveKeyFmt, id)
if t, err := u.redis.Get(ctx, key).Int64(); err != nil {
if t, err := u.redis.Get(ctx, key).Int64(); err != nil && !errors.Is(err, redis.Nil) {
u.logger.With("key", key).With("error", err).Warn("get user active time failed")
} else {
m[id] = t
@@ -98,14 +99,36 @@ func (u *UserUsecase) AdminList(ctx context.Context, page *web.Pagination) (*dom
return nil, err
}
ids := cvt.Iter(admins, func(_ int, u *db.Admin) string { return u.ID.String() })
m, err := u.getAdminActive(ctx, ids)
if err != nil {
return nil, err
}
return &domain.ListAdminUserResp{
PageInfo: p,
Users: cvt.Iter(admins, func(_ int, e *db.Admin) *domain.AdminUser {
return cvt.From(e, &domain.AdminUser{}).From(e)
return cvt.From(e, &domain.AdminUser{
LastActiveAt: m[e.ID.String()],
})
}),
}, nil
}
func (u *UserUsecase) getAdminActive(ctx context.Context, ids []string) (map[string]int64, error) {
m := make(map[string]int64)
for _, id := range ids {
key := fmt.Sprintf(consts.AdminActiveKeyFmt, id)
if t, err := u.redis.Get(ctx, key).Int64(); err != nil && !errors.Is(err, redis.Nil) {
u.logger.With("key", key).With("error", err).Warn("get admin active time failed")
} else {
m[id] = t
}
}
return m, nil
}
// AdminLoginHistory implements domain.UserUsecase.
func (u *UserUsecase) AdminLoginHistory(ctx context.Context, page *web.Pagination) (*domain.ListAdminLoginHistoryResp, error) {
histories, p, err := u.repo.AdminLoginHistory(ctx, page)