mirror of
https://github.com/chaitin/MonkeyCode.git
synced 2026-02-01 22:33:30 +08:00
108 lines
2.9 KiB
Go
108 lines
2.9 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"log/slog"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/labstack/echo/v4"
|
|
"github.com/redis/go-redis/v9"
|
|
|
|
"github.com/chaitin/MonkeyCode/backend/domain"
|
|
"github.com/chaitin/MonkeyCode/backend/ent/rule"
|
|
"github.com/chaitin/MonkeyCode/backend/pkg/logger"
|
|
)
|
|
|
|
// ApiContextKey is the echo.Context key under which ProxyMiddleware.Auth
// stores the authenticated *domain.ApiKey. Read it back with GetApiKey.
const ApiContextKey = "session:apikey"
|
|
|
|
// proxyModelKey is the request-context key for the *domain.Model resolved by
// Auth for session-style API keys. Because the type is unexported, only this
// package can set or read the value directly; other packages use GetProxyModel.
type proxyModelKey struct{}
|
|
|
|
type ProxyMiddleware struct {
|
|
usecase domain.ProxyUsecase
|
|
redis *redis.Client
|
|
logger *slog.Logger
|
|
}
|
|
|
|
func NewProxyMiddleware(
|
|
usecase domain.ProxyUsecase,
|
|
redis *redis.Client,
|
|
logger *slog.Logger,
|
|
) *ProxyMiddleware {
|
|
return &ProxyMiddleware{
|
|
usecase: usecase,
|
|
redis: redis,
|
|
logger: logger.With("module", "ProxyMiddleware"),
|
|
}
|
|
}
|
|
|
|
func (p *ProxyMiddleware) Auth() echo.MiddlewareFunc {
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
apiKey := c.Request().Header.Get("X-API-Key")
|
|
if apiKey == "" {
|
|
apiKey = strings.TrimPrefix(c.Request().Header.Get("Authorization"), "Bearer ")
|
|
}
|
|
if apiKey == "" {
|
|
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
|
|
}
|
|
|
|
ctx := c.Request().Context()
|
|
p.logger.With("apiKey", apiKey).DebugContext(ctx, "v1 auth")
|
|
if strings.Contains(apiKey, ".") {
|
|
s, err := p.redis.Get(ctx, apiKey).Result()
|
|
if err != nil {
|
|
p.logger.With("fn", "Auth").With("error", err).ErrorContext(ctx, "failed to get api key from redis")
|
|
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
|
|
}
|
|
var model *domain.Model
|
|
if err := json.Unmarshal([]byte(s), &model); err != nil {
|
|
p.logger.With("fn", "Auth").With("error", err).ErrorContext(ctx, "failed to unmarshal model from redis")
|
|
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
|
|
}
|
|
parts := strings.Split(apiKey, ".")
|
|
if len(parts) != 2 {
|
|
p.logger.With("fn", "Auth").With("apiKey", apiKey).ErrorContext(ctx, "invalid api key")
|
|
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
|
|
}
|
|
ctx = context.WithValue(ctx, proxyModelKey{}, model)
|
|
ctx = context.WithValue(ctx, logger.UserIDKey{}, parts[0])
|
|
c.Set(ApiContextKey, &domain.ApiKey{
|
|
UserID: parts[0],
|
|
Key: apiKey,
|
|
})
|
|
} else {
|
|
key, err := p.usecase.ValidateApiKey(ctx, apiKey)
|
|
if err != nil {
|
|
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
|
|
}
|
|
ctx = context.WithValue(ctx, logger.UserIDKey{}, key.UserID)
|
|
c.Set(ApiContextKey, key)
|
|
}
|
|
|
|
ctx = rule.SkipPermission(ctx)
|
|
c.SetRequest(c.Request().WithContext(ctx))
|
|
return next(c)
|
|
}
|
|
}
|
|
}
|
|
|
|
func GetProxyModel(ctx context.Context) *domain.Model {
|
|
m := ctx.Value(proxyModelKey{})
|
|
if m == nil {
|
|
return nil
|
|
}
|
|
return m.(*domain.Model)
|
|
}
|
|
|
|
func GetApiKey(c echo.Context) *domain.ApiKey {
|
|
i := c.Get(ApiContextKey)
|
|
if i == nil {
|
|
return nil
|
|
}
|
|
return i.(*domain.ApiKey)
|
|
}
|