diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index b2e0624..ab0016a 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -68,20 +68,20 @@ func newServer() (*Server, error) { modelUsecase := usecase3.NewModelUsecase(slogLogger, modelRepo, configConfig) sessionSession := session.NewSession(configConfig) authMiddleware := middleware.NewAuthMiddleware(sessionSession, slogLogger) - modelHandler := v1_2.NewModelHandler(web, modelUsecase, authMiddleware, slogLogger) + modelHandler := v1_2.NewModelHandler(web, modelUsecase, authMiddleware, activeMiddleware, slogLogger) ipdbIPDB, err := ipdb.NewIPDB(slogLogger) if err != nil { return nil, err } - userRepo := repo5.NewUserRepo(client, ipdbIPDB) + userRepo := repo5.NewUserRepo(client, ipdbIPDB, redisClient) userUsecase := usecase4.NewUserUsecase(configConfig, redisClient, userRepo, slogLogger) - userHandler := v1_3.NewUserHandler(web, userUsecase, extensionUsecase, authMiddleware, sessionSession, slogLogger, configConfig) + userHandler := v1_3.NewUserHandler(web, userUsecase, extensionUsecase, authMiddleware, activeMiddleware, sessionSession, slogLogger, configConfig) dashboardRepo := repo6.NewDashboardRepo(client) dashboardUsecase := usecase5.NewDashboardUsecase(dashboardRepo) - dashboardHandler := v1_4.NewDashboardHandler(web, dashboardUsecase, authMiddleware) + dashboardHandler := v1_4.NewDashboardHandler(web, dashboardUsecase, authMiddleware, activeMiddleware) billingRepo := repo7.NewBillingRepo(client) billingUsecase := usecase6.NewBillingUsecase(billingRepo) - billingHandler := v1_5.NewBillingHandler(web, billingUsecase, authMiddleware) + billingHandler := v1_5.NewBillingHandler(web, billingUsecase, authMiddleware, activeMiddleware) server := &Server{ config: configConfig, web: web, diff --git a/backend/consts/user.go b/backend/consts/user.go index 64cc308..8e6bb3a 100644 --- a/backend/consts/user.go +++ b/backend/consts/user.go @@ -1,7 +1,8 @@ package consts const ( - UserActiveKeyFmt = "user:active:%s" + UserActiveKeyFmt = "user:active:%s" + AdminActiveKeyFmt = "admin:active:%s" ) type UserStatus string diff --git a/backend/domain/user.go b/backend/domain/user.go index 9fa2806..0972a13 100644 --- a/backend/domain/user.go +++ b/backend/domain/user.go @@ -251,7 +251,6 @@ func (a *AdminUser) From(e *db.Admin) *AdminUser { a.ID = e.ID.String() a.Username = e.Username - a.LastActiveAt = e.LastActiveAt.Unix() a.Status = e.Status a.CreatedAt = e.CreatedAt.Unix() diff --git a/backend/internal/billing/handler/http/v1/billing.go b/backend/internal/billing/handler/http/v1/billing.go index 0b430e4..1ab76ae 100644 --- a/backend/internal/billing/handler/http/v1/billing.go +++ b/backend/internal/billing/handler/http/v1/billing.go @@ -15,13 +15,14 @@ func NewBillingHandler( w *web.Web, usecase domain.BillingUsecase, auth *middleware.AuthMiddleware, + active *middleware.ActiveMiddleware, ) *BillingHandler { b := &BillingHandler{ usecase: usecase, } g := w.Group("/api/v1/billing") - g.Use(auth.Auth()) + g.Use(auth.Auth(), active.Active("admin")) g.GET("/chat/record", web.BindHandler(b.ListChatRecord, web.WithPage())) g.GET("/completion/record", web.BindHandler(b.ListCompletionRecord, web.WithPage())) diff --git a/backend/internal/dashboard/handler/v1/dashboard.go b/backend/internal/dashboard/handler/v1/dashboard.go index 6022dec..afbf876 100644 --- a/backend/internal/dashboard/handler/v1/dashboard.go +++ b/backend/internal/dashboard/handler/v1/dashboard.go @@ -15,11 +15,12 @@ func NewDashboardHandler( w *web.Web, usecase domain.DashboardUsecase, auth *middleware.AuthMiddleware, + active *middleware.ActiveMiddleware, ) *DashboardHandler { h := &DashboardHandler{usecase: usecase} g := w.Group("/api/v1/dashboard") - g.Use(auth.Auth()) + g.Use(auth.Auth(), active.Active("admin")) g.GET("/statistics", web.BaseHandler(h.Statistics)) g.GET("/category-stat", web.BindHandler(h.CategoryStat)) g.GET("/time-stat", web.BindHandler(h.TimeStat)) diff --git a/backend/internal/middleware/active.go b/backend/internal/middleware/active.go index d6d04c5..e8365df 100644 --- a/backend/internal/middleware/active.go +++ b/backend/internal/middleware/active.go @@ -6,9 +6,10 @@ import ( "log/slog" "time" - "github.com/chaitin/MonkeyCode/backend/consts" "github.com/labstack/echo/v4" "github.com/redis/go-redis/v9" + + "github.com/chaitin/MonkeyCode/backend/consts" ) type ActiveMiddleware struct { @@ -23,14 +24,24 @@ func NewActiveMiddleware(redis *redis.Client, logger *slog.Logger) *ActiveMiddle } } -func (a *ActiveMiddleware) Active() echo.MiddlewareFunc { +func (a *ActiveMiddleware) Active(scope string) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - if apikey := GetApiKey(c); apikey != nil { - if err := a.redis.Set(context.Background(), fmt.Sprintf(consts.UserActiveKeyFmt, apikey.UserID), time.Now().Unix(), 0).Err(); err != nil { - a.logger.With("error", err).ErrorContext(c.Request().Context(), "failed to set user active status in Redis") + switch scope { + case "admin": + if user := GetUser(c); user != nil { + if err := a.redis.Set(context.Background(), fmt.Sprintf(consts.AdminActiveKeyFmt, user.ID), time.Now().Unix(), 0).Err(); err != nil { + a.logger.With("error", err).ErrorContext(c.Request().Context(), "failed to set admin active status in Redis") + } + } + case "user": + if apikey := GetApiKey(c); apikey != nil { + if err := a.redis.Set(context.Background(), fmt.Sprintf(consts.UserActiveKeyFmt, apikey.UserID), time.Now().Unix(), 0).Err(); err != nil { + a.logger.With("error", err).ErrorContext(c.Request().Context(), "failed to set user active status in Redis") + } } } + return next(c) } } diff --git a/backend/internal/model/handler/http/v1/model.go b/backend/internal/model/handler/http/v1/model.go index fb10a3b..ee8c284 100644 --- a/backend/internal/model/handler/http/v1/model.go +++ b/backend/internal/model/handler/http/v1/model.go @@ -19,12 +19,13 @@ func NewModelHandler( w *web.Web, usecase domain.ModelUsecase, auth *middleware.AuthMiddleware, + active *middleware.ActiveMiddleware, logger *slog.Logger, ) *ModelHandler { m := &ModelHandler{usecase: usecase, logger: logger.With("handler", "model")} g := w.Group("/api/v1/model") - g.Use(auth.Auth()) + g.Use(auth.Auth(), active.Active("admin")) g.POST("/check", web.BindHandler(m.Check)) g.GET("", web.BaseHandler(m.List)) diff --git a/backend/internal/model/repo/model.go b/backend/internal/model/repo/model.go index 0316eff..a90cf8c 100644 --- a/backend/internal/model/repo/model.go +++ b/backend/internal/model/repo/model.go @@ -107,12 +107,7 @@ func (r *ModelRepo) Update(ctx context.Context, id string, fn func(tx *db.Tx, ol } func (r *ModelRepo) MyModelList(ctx context.Context, req *domain.MyModelListReq) ([]*db.Model, error) { - userID, err := uuid.Parse(req.UserID) - if err != nil { - return nil, err - } q := r.db.Model.Query(). - Where(model.UserID(userID)). Where(model.ModelType(req.ModelType)). Order(model.ByCreatedAt(sql.OrderAsc())) return q.All(ctx) diff --git a/backend/internal/openai/handler/v1/v1.go b/backend/internal/openai/handler/v1/v1.go index aa98a91..b1aaefc 100644 --- a/backend/internal/openai/handler/v1/v1.go +++ b/backend/internal/openai/handler/v1/v1.go @@ -49,10 +49,10 @@ func NewV1Handler( g := w.Group("/v1", middleware.Auth()) g.GET("/models", web.BaseHandler(h.ModelList)) - g.POST("/completion/accept", web.BindHandler(h.AcceptCompletion), active.Active()) - g.POST("/chat/completions", web.BaseHandler(h.ChatCompletion), active.Active()) - g.POST("/completions", web.BaseHandler(h.Completions), active.Active()) - g.POST("/embeddings", web.BaseHandler(h.Embeddings), active.Active()) + g.POST("/completion/accept", web.BindHandler(h.AcceptCompletion), active.Active("user")) + g.POST("/chat/completions", web.BaseHandler(h.ChatCompletion), active.Active("user")) + g.POST("/completions", web.BaseHandler(h.Completions), active.Active("user")) + g.POST("/embeddings", web.BaseHandler(h.Embeddings), active.Active("user")) return h } diff --git a/backend/internal/user/handler/v1/user.go b/backend/internal/user/handler/v1/user.go index 7292ace..804a903 100644 --- a/backend/internal/user/handler/v1/user.go +++ b/backend/internal/user/handler/v1/user.go @@ -43,6 +43,7 @@ func NewUserHandler( usecase domain.UserUsecase, euse domain.ExtensionUsecase, auth *middleware.AuthMiddleware, + active *middleware.ActiveMiddleware, session *session.Session, logger *slog.Logger, cfg *config.Config, @@ -66,7 +67,7 @@ func NewUserHandler( admin.POST("/login", web.BindHandler(u.AdminLogin)) admin.GET("/setting", web.BaseHandler(u.GetSetting)) - admin.Use(auth.Auth()) + admin.Use(auth.Auth(), active.Active("admin")) admin.PUT("/setting", web.BindHandler(u.UpdateSetting)) admin.POST("/create", web.BindHandler(u.CreateAdmin)) admin.GET("/list", web.BaseHandler(u.AdminList, web.WithPage())) @@ -80,7 +81,7 @@ func NewUserHandler( g.POST("/register", web.BindHandler(u.Register)) g.POST("/login", web.BindHandler(u.Login)) - g.Use(auth.Auth()) + g.Use(auth.Auth(), active.Active("admin")) g.PUT("/update", web.BindHandler(u.Update)) g.DELETE("/delete", web.BaseHandler(u.Delete)) diff --git a/backend/internal/user/repo/user.go b/backend/internal/user/repo/user.go index 4937bfc..eadc4f9 100644 --- a/backend/internal/user/repo/user.go +++ b/backend/internal/user/repo/user.go @@ -8,6 +8,7 @@ import ( "entgo.io/ent/dialect/sql" "github.com/google/uuid" + "github.com/redis/go-redis/v9" "github.com/GoYoko/web" @@ -27,12 +28,13 @@ import ( ) type UserRepo struct { - db *db.Client - ipdb *ipdb.IPDB + db *db.Client + ipdb *ipdb.IPDB + redis *redis.Client } -func NewUserRepo(db *db.Client, ipdb *ipdb.IPDB) domain.UserRepo { - return &UserRepo{db: db, ipdb: ipdb} +func NewUserRepo(db *db.Client, ipdb *ipdb.IPDB, redis *redis.Client) domain.UserRepo { + return &UserRepo{db: db, ipdb: ipdb, redis: redis} } func (r *UserRepo) InitAdmin(ctx context.Context, username, password string) error { @@ -251,10 +253,21 @@ func (r *UserRepo) Delete(ctx context.Context, id string) error { if err != nil { return err } - if _, err := tx.ApiKey.Delete().Where(apikey.UserID(user.ID)).Exec(ctx); err != nil { + + keys, err := tx.ApiKey.Query().Where(apikey.UserID(user.ID)).All(ctx) + if err != nil { return err } + for _, v := range keys { + if _, err := tx.ApiKey.Delete().Where(apikey.ID(v.ID)).Exec(ctx); err != nil { + return err + } + if err := r.redis.Del(ctx, fmt.Sprintf("sk-%s", v.Key)).Err(); err != nil { + return err + } + } + for _, v := range user.Edges.Identities { if _, err := tx.UserIdentity.Delete().Where(useridentity.ID(v.ID)).Exec(ctx); err != nil { return err diff --git a/backend/internal/user/usecase/user.go b/backend/internal/user/usecase/user.go index 77b79dc..e9a6de2 100644 --- a/backend/internal/user/usecase/user.go +++ b/backend/internal/user/usecase/user.go @@ -3,6 +3,7 @@ package usecase import ( "context" "encoding/json" + "errors" "fmt" "log/slog" "net/url" @@ -81,7 +82,7 @@ func (u *UserUsecase) getUserActive(ctx context.Context, ids []string) (map[stri m := make(map[string]int64) for _, id := range ids { key := fmt.Sprintf(consts.UserActiveKeyFmt, id) - if t, err := u.redis.Get(ctx, key).Int64(); err != nil { + if t, err := u.redis.Get(ctx, key).Int64(); err != nil && !errors.Is(err, redis.Nil) { u.logger.With("key", key).With("error", err).Warn("get user active time failed") } else { m[id] = t @@ -98,14 +99,36 @@ func (u *UserUsecase) AdminList(ctx context.Context, page *web.Pagination) (*dom return nil, err } + ids := cvt.Iter(admins, func(_ int, u *db.Admin) string { return u.ID.String() }) + m, err := u.getAdminActive(ctx, ids) + if err != nil { + return nil, err + } + return &domain.ListAdminUserResp{ PageInfo: p, Users: cvt.Iter(admins, func(_ int, e *db.Admin) *domain.AdminUser { - return cvt.From(e, &domain.AdminUser{}).From(e) + return cvt.From(e, &domain.AdminUser{ + LastActiveAt: m[e.ID.String()], + }) }), }, nil } +func (u *UserUsecase) getAdminActive(ctx context.Context, ids []string) (map[string]int64, error) { + m := make(map[string]int64) + for _, id := range ids { + key := fmt.Sprintf(consts.AdminActiveKeyFmt, id) + if t, err := u.redis.Get(ctx, key).Int64(); err != nil && !errors.Is(err, redis.Nil) { + u.logger.With("key", key).With("error", err).Warn("get admin active time failed") + } else { + m[id] = t + } + } + + return m, nil +} + // AdminLoginHistory implements domain.UserUsecase. func (u *UserUsecase) AdminLoginHistory(ctx context.Context, page *web.Pagination) (*domain.ListAdminLoginHistoryResp, error) { histories, p, err := u.repo.AdminLoginHistory(ctx, page)