Files
MonkeyCode/backend/internal/middleware/proxy.go
2025-08-29 18:29:50 +08:00

108 lines
2.9 KiB
Go

package middleware
import (
	"context"
	"encoding/json"
	"errors"
	"log/slog"
	"net/http"
	"strings"

	"github.com/labstack/echo/v4"
	"github.com/redis/go-redis/v9"

	"github.com/chaitin/MonkeyCode/backend/domain"
	"github.com/chaitin/MonkeyCode/backend/ent/rule"
	"github.com/chaitin/MonkeyCode/backend/pkg/logger"
)
// ApiContextKey is the echo.Context key under which Auth stores the
// authenticated *domain.ApiKey; GetApiKey reads it back.
const ApiContextKey = "session:apikey"

// proxyModelKey is the unexported context.Context key type under which Auth
// attaches the Redis-resolved *domain.Model. A private struct type cannot be
// constructed outside this package, so it cannot collide with other keys.
type proxyModelKey struct{}
// ProxyMiddleware carries the dependencies used by the proxy authentication
// middleware returned from Auth.
type ProxyMiddleware struct {
	// usecase validates API keys that contain no '.' separator.
	usecase domain.ProxyUsecase
	// redis holds JSON-encoded models keyed by dotted API keys.
	redis *redis.Client
	// logger is tagged with module=ProxyMiddleware by NewProxyMiddleware.
	logger *slog.Logger
}
// NewProxyMiddleware wires a ProxyMiddleware from its dependencies. The
// supplied logger is narrowed with module=ProxyMiddleware so that every record
// emitted by the middleware carries that attribute.
func NewProxyMiddleware(usecase domain.ProxyUsecase, redis *redis.Client, logger *slog.Logger) *ProxyMiddleware {
	scoped := logger.With("module", "ProxyMiddleware")
	return &ProxyMiddleware{usecase: usecase, redis: redis, logger: scoped}
}
// Auth returns an echo middleware that authenticates proxy requests by API key.
//
// The key is read from the X-API-Key header. When that is empty, it falls back
// to the Authorization header, stripping an optional "Bearer " scheme prefix
// (matched case-insensitively) and surrounding whitespace. Two key forms are
// accepted:
//
//   - "<userID>.<secret>" (exactly one '.', both halves non-empty): the whole
//     key is looked up in Redis, whose value must be a JSON-encoded
//     *domain.Model. The model is attached to the request context (readable
//     via GetProxyModel), and a *domain.ApiKey{UserID, Key} is stored under
//     ApiContextKey.
//   - any key without a '.': validated through usecase.ValidateApiKey, and the
//     returned key is stored under ApiContextKey.
//
// On success the user ID is placed in the context under logger.UserIDKey{},
// the context is wrapped with rule.SkipPermission (presumably bypassing ent
// privacy rules; confirm in package rule), and the next handler runs. Every
// failure responds 401 {"error":"Unauthorized"}. The raw key is never logged
// because it is a credential.
func (p *ProxyMiddleware) Auth() echo.MiddlewareFunc {
	const bearerPrefix = "Bearer "
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			unauthorized := func() error {
				return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
			}

			req := c.Request()
			apiKey := req.Header.Get("X-API-Key")
			if apiKey == "" {
				apiKey = req.Header.Get("Authorization")
				// Auth schemes are case-insensitive (RFC 7235 §2.1).
				if len(apiKey) >= len(bearerPrefix) && strings.EqualFold(apiKey[:len(bearerPrefix)], bearerPrefix) {
					apiKey = apiKey[len(bearerPrefix):]
				}
				apiKey = strings.TrimSpace(apiKey)
			}
			if apiKey == "" {
				return unauthorized()
			}

			ctx := req.Context()
			log := p.logger.With("fn", "Auth")

			if userID, secret, dotted := strings.Cut(apiKey, "."); dotted {
				log.DebugContext(ctx, "v1 auth", "keyType", "dotted")
				// Validate the shape before spending a Redis round trip on it:
				// exactly one '.' with non-empty halves.
				if userID == "" || secret == "" || strings.Contains(secret, ".") {
					log.WarnContext(ctx, "invalid api key")
					return unauthorized()
				}

				s, err := p.redis.Get(ctx, apiKey).Result()
				if err != nil {
					if errors.Is(err, redis.Nil) {
						// Unknown or expired key: an ordinary auth failure, not a server fault.
						log.DebugContext(ctx, "api key not found in redis")
					} else {
						log.ErrorContext(ctx, "failed to get api key from redis", "error", err)
					}
					return unauthorized()
				}

				var model *domain.Model
				if err := json.Unmarshal([]byte(s), &model); err != nil {
					log.ErrorContext(ctx, "failed to unmarshal model from redis", "error", err)
					return unauthorized()
				}
				// A stored JSON "null" decodes without error but leaves model nil.
				if model == nil {
					log.ErrorContext(ctx, "empty model stored in redis for api key")
					return unauthorized()
				}

				ctx = context.WithValue(ctx, proxyModelKey{}, model)
				ctx = context.WithValue(ctx, logger.UserIDKey{}, userID)
				c.Set(ApiContextKey, &domain.ApiKey{
					UserID: userID,
					Key:    apiKey,
				})
			} else {
				log.DebugContext(ctx, "v1 auth", "keyType", "plain")
				key, err := p.usecase.ValidateApiKey(ctx, apiKey)
				if err != nil {
					return unauthorized()
				}
				ctx = context.WithValue(ctx, logger.UserIDKey{}, key.UserID)
				c.Set(ApiContextKey, key)
			}

			ctx = rule.SkipPermission(ctx)
			c.SetRequest(req.WithContext(ctx))
			return next(c)
		}
	}
}
// GetProxyModel returns the *domain.Model that Auth attached to ctx for a
// dotted, Redis-backed API key, or nil when ctx carries no such model.
func GetProxyModel(ctx context.Context) *domain.Model {
	if v := ctx.Value(proxyModelKey{}); v != nil {
		return v.(*domain.Model)
	}
	return nil
}
// GetApiKey returns the *domain.ApiKey that Auth stored on the echo context
// under ApiContextKey, or nil when no key is present.
//
// A comma-ok assertion is used because ApiContextKey is an exported string
// key. If other code stores a value of a different type under it, this
// returns nil instead of panicking inside the request handler.
func GetApiKey(c echo.Context) *domain.ApiKey {
	key, _ := c.Get(ApiContextKey).(*domain.ApiKey)
	return key
}