diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index 6af21f5..723c958 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -12,7 +12,6 @@ import ( "github.com/chaitin/MonkeyCode/backend/config" "github.com/chaitin/MonkeyCode/backend/db" - "github.com/chaitin/MonkeyCode/backend/domain" billingv1 "github.com/chaitin/MonkeyCode/backend/internal/billing/handler/http/v1" dashv1 "github.com/chaitin/MonkeyCode/backend/internal/dashboard/handler/v1" v1 "github.com/chaitin/MonkeyCode/backend/internal/model/handler/http/v1" @@ -25,7 +24,6 @@ type Server struct { web *web.Web ent *db.Client logger *slog.Logger - proxy domain.Proxy openaiV1 *openaiV1.V1Handler modelV1 *v1.ModelHandler userV1 *userV1.UserHandler diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index c5c2997..d757e09 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -10,7 +10,6 @@ import ( "github.com/GoYoko/web" "github.com/chaitin/MonkeyCode/backend/config" "github.com/chaitin/MonkeyCode/backend/db" - "github.com/chaitin/MonkeyCode/backend/domain" v1_5 "github.com/chaitin/MonkeyCode/backend/internal/billing/handler/http/v1" repo7 "github.com/chaitin/MonkeyCode/backend/internal/billing/repo" usecase6 "github.com/chaitin/MonkeyCode/backend/internal/billing/usecase" @@ -56,7 +55,7 @@ func newServer() (*Server, error) { proxyRepo := repo.NewProxyRepo(client) modelRepo := repo2.NewModelRepo(client) proxyUsecase := usecase.NewProxyUsecase(proxyRepo, modelRepo) - domainProxy := proxy.NewLLMProxy(proxyUsecase, configConfig, slogLogger) + llmProxy := proxy.NewLLMProxy(slogLogger, configConfig, proxyUsecase) openAIRepo := repo3.NewOpenAIRepo(client) openAIUsecase := openai.NewOpenAIUsecase(configConfig, openAIRepo, slogLogger) extensionRepo := repo4.NewExtensionRepo(client) @@ -64,7 +63,7 @@ func newServer() (*Server, error) { proxyMiddleware := middleware.NewProxyMiddleware(proxyUsecase) redisClient := store.NewRedisCli(configConfig) activeMiddleware := middleware.NewActiveMiddleware(redisClient, slogLogger) - v1Handler := v1.NewV1Handler(slogLogger, web, domainProxy, openAIUsecase, extensionUsecase, proxyMiddleware, activeMiddleware, configConfig) + v1Handler := v1.NewV1Handler(slogLogger, web, llmProxy, proxyUsecase, openAIUsecase, extensionUsecase, proxyMiddleware, activeMiddleware, configConfig) modelUsecase := usecase3.NewModelUsecase(slogLogger, modelRepo, configConfig) sessionSession := session.NewSession(configConfig) authMiddleware := middleware.NewAuthMiddleware(sessionSession, slogLogger) @@ -83,7 +82,6 @@ func newServer() (*Server, error) { web: web, ent: client, logger: slogLogger, - proxy: domainProxy, openaiV1: v1Handler, modelV1: modelHandler, userV1: userHandler, @@ -100,7 +98,6 @@ type Server struct { web *web.Web ent *db.Client logger *slog.Logger - proxy domain.Proxy openaiV1 *v1.V1Handler modelV1 *v1_2.ModelHandler userV1 *v1_3.UserHandler diff --git a/backend/config/config.go b/backend/config/config.go index 5f721ac..9022a05 100644 --- a/backend/config/config.go +++ b/backend/config/config.go @@ -64,6 +64,7 @@ type Config struct { Extension struct { Baseurl string `mapstructure:"baseurl"` + Limit int `mapstructure:"limit"` } `mapstructure:"extension"` } @@ -99,6 +100,7 @@ func Init() (*Config, error) { v.SetDefault("init_model.key", "") v.SetDefault("init_model.url", "https://model-square.app.baizhi.cloud/v1") v.SetDefault("extension.baseurl", "https://release.baizhi.cloud") + v.SetDefault("extension.limit", 10) c := Config{} if err := v.Unmarshal(&c); err != nil { diff --git a/backend/domain/billing.go b/backend/domain/billing.go index efea1ad..81de700 100644 --- a/backend/domain/billing.go +++ b/backend/domain/billing.go @@ -126,7 +126,12 @@ func (c *ChatContent) From(e *db.TaskRecord) *ChatContent { return c } c.Role = e.Role - c.Content = e.Completion + switch e.Role { + case consts.ChatRoleUser: + c.Content = e.Prompt + case consts.ChatRoleAssistant: + c.Content = e.Completion + } c.CreatedAt = e.CreatedAt.Unix() return c } diff --git a/backend/domain/openai.go b/backend/domain/openai.go index 0849bc8..6398c22 100644 --- a/backend/domain/openai.go +++ b/backend/domain/openai.go @@ -3,10 +3,9 @@ package domain import ( "context" - "github.com/rokku-c/go-openai" - "github.com/chaitin/MonkeyCode/backend/consts" "github.com/chaitin/MonkeyCode/backend/db" + "github.com/rokku-c/go-openai" ) type OpenAIUsecase interface { @@ -21,7 +20,6 @@ type OpenAIRepo interface { type CompletionRequest struct { openai.CompletionRequest - Metadata map[string]string `json:"metadata"` } diff --git a/backend/internal/openai/handler/v1/v1.go b/backend/internal/openai/handler/v1/v1.go index 86684ab..aa98a91 100644 --- a/backend/internal/openai/handler/v1/v1.go +++ b/backend/internal/openai/handler/v1/v1.go @@ -6,27 +6,29 @@ import ( "net/http" "github.com/labstack/echo/v4" - "github.com/rokku-c/go-openai" "github.com/GoYoko/web" "github.com/chaitin/MonkeyCode/backend/config" "github.com/chaitin/MonkeyCode/backend/domain" "github.com/chaitin/MonkeyCode/backend/internal/middleware" + "github.com/chaitin/MonkeyCode/backend/internal/proxy" ) type V1Handler struct { - logger *slog.Logger - proxy domain.Proxy - usecase domain.OpenAIUsecase - euse domain.ExtensionUsecase - config *config.Config + logger *slog.Logger + proxy *proxy.LLMProxy + proxyUse domain.ProxyUsecase + usecase domain.OpenAIUsecase + euse domain.ExtensionUsecase + config *config.Config } func NewV1Handler( logger *slog.Logger, w *web.Web, - proxy domain.Proxy, + proxy *proxy.LLMProxy, + proxyUse domain.ProxyUsecase, usecase domain.OpenAIUsecase, euse domain.ExtensionUsecase, middleware *middleware.ProxyMiddleware, @@ -34,21 +36,23 @@ func NewV1Handler( config *config.Config, ) *V1Handler { h := &V1Handler{ - logger: logger.With(slog.String("handler", "openai")), - proxy: proxy, - usecase: usecase, - euse: euse, - config: config, + logger: logger.With(slog.String("handler", "openai")), + proxy: proxy, + proxyUse: proxyUse, + usecase: usecase, + euse: euse, + config: config, } + w.GET("/api/config", web.BindHandler(h.GetConfig), middleware.Auth()) w.GET("/v1/version", web.BaseHandler(h.Version), middleware.Auth()) g := w.Group("/v1", middleware.Auth()) g.GET("/models", web.BaseHandler(h.ModelList)) g.POST("/completion/accept", web.BindHandler(h.AcceptCompletion), active.Active()) - g.POST("/chat/completions", web.BindHandler(h.ChatCompletion), active.Active()) - g.POST("/completions", web.BindHandler(h.Completions), active.Active()) - g.POST("/embeddings", web.BindHandler(h.Embeddings), active.Active()) + g.POST("/chat/completions", web.BaseHandler(h.ChatCompletion), active.Active()) + g.POST("/completions", web.BaseHandler(h.Completions), active.Active()) + g.POST("/embeddings", web.BaseHandler(h.Embeddings), active.Active()) return h } @@ -86,7 +90,7 @@ func (h *V1Handler) Version(c *web.Context) error { // @Success 200 {object} web.Resp{} // @Router /v1/completion/accept [post] func (h *V1Handler) AcceptCompletion(c *web.Context, req domain.AcceptCompletionReq) error { - if err := h.proxy.AcceptCompletion(c.Request().Context(), &req); err != nil { + if err := h.proxyUse.AcceptCompletion(c.Request().Context(), &req); err != nil { return BadRequest(c, err.Error()) } return nil @@ -120,19 +124,8 @@ func (h *V1Handler) ModelList(c *web.Context) error { // @Produce json // @Success 200 {object} web.Resp{} // @Router /v1/chat/completions [post] -func (h *V1Handler) ChatCompletion(c *web.Context, req openai.ChatCompletionRequest) error { - // TODO: 记录请求到文件 - if req.Model == "" { - return BadRequest(c, "模型不能为空") - } - - // if len(req.Tools) > 0 && req.Model != "qwen-max" { - // if h.toolsCall(c, req, req.Stream, req.Model) { - // return nil - // } - // } - - h.proxy.HandleChatCompletion(c.Request().Context(), c.Response(), &req) +func (h *V1Handler) ChatCompletion(c *web.Context) error { + h.proxy.ServeHTTP(c.Response(), c.Request()) return nil } @@ -146,13 +139,8 @@ func (h *V1Handler) ChatCompletion(c *web.Context, req openai.ChatCompletionRequ // @Produce json // @Success 200 {object} web.Resp{} // @Router /v1/completions [post] -func (h *V1Handler) Completions(c *web.Context, req domain.CompletionRequest) error { - // TODO: 记录请求到文件 - if req.Model == "" { - return BadRequest(c, "模型不能为空") - } - h.logger.With("request", req).DebugContext(c.Request().Context(), "处理文本补全请求") - h.proxy.HandleCompletion(c.Request().Context(), c.Response(), req) +func (h *V1Handler) Completions(c *web.Context) error { + h.proxy.ServeHTTP(c.Response(), c.Request()) return nil } @@ -166,12 +154,8 @@ func (h *V1Handler) Completions(c *web.Context, req domain.CompletionRequest) er // @Produce json // @Success 200 {object} web.Resp{} // @Router /v1/embeddings [post] -func (h *V1Handler) Embeddings(c *web.Context, req openai.EmbeddingRequest) error { - if req.Model == "" { - return BadRequest(c, "模型不能为空") - } - - h.proxy.HandleEmbeddings(c.Request().Context(), c.Response(), &req) +func (h *V1Handler) Embeddings(c *web.Context) error { + h.proxy.ServeHTTP(c.Response(), c.Request()) return nil } diff --git a/backend/internal/proxy/proxy.go b/backend/internal/proxy/proxy.go index 0e61c3b..29cb16f 100644 --- a/backend/internal/proxy/proxy.go +++ b/backend/internal/proxy/proxy.go @@ -1,840 +1,141 @@ package proxy import ( - "bufio" - "bytes" "context" - "encoding/json" - "fmt" - "io" "log/slog" + "net" "net/http" + "net/http/httputil" "net/url" - "os" - "path/filepath" - "strings" "time" - "github.com/rokku-c/go-openai" - "github.com/chaitin/MonkeyCode/backend/config" "github.com/chaitin/MonkeyCode/backend/consts" "github.com/chaitin/MonkeyCode/backend/domain" - "github.com/chaitin/MonkeyCode/backend/pkg/cvt" "github.com/chaitin/MonkeyCode/backend/pkg/logger" - "github.com/chaitin/MonkeyCode/backend/pkg/promptparser" - "github.com/chaitin/MonkeyCode/backend/pkg/request" + "github.com/chaitin/MonkeyCode/backend/pkg/tee" ) -// ErrResp 错误响应 -type ErrResp struct { - Error struct { - Message string `json:"message"` - Type string `json:"type"` - Code string `json:"code,omitempty"` - } `json:"error"` +type CtxKey struct{} + +type ProxyCtx struct { + ctx context.Context + Path string + Model *domain.Model + Header http.Header + RespHeader http.Header + ReqTee *tee.ReqTee + RequestID string + UserID string } -// RequestResponseLog 请求响应日志结构 -type RequestResponseLog struct { - Timestamp time.Time `json:"timestamp"` - RequestID string `json:"request_id"` - Endpoint string `json:"endpoint"` - ModelName string `json:"model_name"` - ModelType consts.ModelType `json:"model_type"` - UpstreamURL string `json:"upstream_url"` - RequestBody any `json:"request_body"` - RequestHeader map[string][]string `json:"request_header"` - StatusCode int `json:"status_code"` - ResponseBody any `json:"response_body"` - ResponseHeader map[string][]string `json:"response_header"` - Latency int64 `json:"latency_ms"` - Error string `json:"error,omitempty"` - IsStream bool `json:"is_stream"` - SourceIP string `json:"source_ip"` // 请求来源IP地址 -} - -// LLMProxy LLM API代理实现 type LLMProxy struct { - usecase domain.ProxyUsecase - cfg *config.Config - client *http.Client - streamClient *http.Client - logger *slog.Logger - requestLogPath string // 请求日志保存路径 + logger *slog.Logger + cfg *config.Config + usecase domain.ProxyUsecase + transport *http.Transport + proxy *httputil.ReverseProxy } -// NewLLMProxy 创建LLM API代理实例 func NewLLMProxy( - usecase domain.ProxyUsecase, - cfg *config.Config, logger *slog.Logger, -) domain.Proxy { - // 解析超时时间 - timeout, err := time.ParseDuration(cfg.LLMProxy.Timeout) - if err != nil { - timeout = 30 * time.Second - logger.Warn("解析超时时间失败, 使用默认值 30s", "error", err) + cfg *config.Config, + usecase domain.ProxyUsecase, +) *LLMProxy { + l := &LLMProxy{ + logger: logger, + cfg: cfg, + usecase: usecase, } - // 解析保持连接时间 - keepAlive, err := time.ParseDuration(cfg.LLMProxy.KeepAlive) - if err != nil { - keepAlive = 60 * time.Second - logger.Warn("解析保持连接时间失败, 使用默认值 60s", "error", err) + l.transport = &http.Transport{ + MaxIdleConns: cfg.LLMProxy.ClientPoolSize, + MaxConnsPerHost: cfg.LLMProxy.ClientPoolSize, + MaxIdleConnsPerHost: cfg.LLMProxy.ClientPoolSize, + IdleConnTimeout: 24 * time.Hour, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 24 * time.Hour, + }).DialContext, } - - client := &http.Client{ - Timeout: timeout, - Transport: &http.Transport{ - MaxIdleConns: cfg.LLMProxy.ClientPoolSize, - MaxConnsPerHost: cfg.LLMProxy.ClientPoolSize, - MaxIdleConnsPerHost: cfg.LLMProxy.ClientPoolSize, - IdleConnTimeout: keepAlive, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - DisableCompression: false, - ForceAttemptHTTP2: true, - }, - } - - streamClient := &http.Client{ - Timeout: 60 * time.Second, - Transport: &http.Transport{ - MaxIdleConns: cfg.LLMProxy.StreamClientPoolSize, - MaxConnsPerHost: cfg.LLMProxy.StreamClientPoolSize, - MaxIdleConnsPerHost: cfg.LLMProxy.StreamClientPoolSize, - IdleConnTimeout: 24 * time.Hour, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - }, - } - - // 获取日志配置 - requestLogPath := "" - if cfg.LLMProxy.RequestLogPath != "" { - requestLogPath = cfg.LLMProxy.RequestLogPath - // 确保目录存在 - if err := os.MkdirAll(requestLogPath, 0755); err != nil { - logger.Warn("创建请求日志目录失败", "error", err, "path", requestLogPath) - } - } - - return &LLMProxy{ - usecase: usecase, - client: client, - streamClient: streamClient, - cfg: cfg, - requestLogPath: requestLogPath, - logger: logger, + l.proxy = &httputil.ReverseProxy{ + Transport: l.transport, + Rewrite: l.rewrite, + ModifyResponse: l.modifyResponse, + ErrorHandler: l.errorHandler, + FlushInterval: 100 * time.Millisecond, } + return l } -// saveRequestResponseLog 保存请求响应日志到文件 -func (p *LLMProxy) saveRequestResponseLog(log *RequestResponseLog) { - if p.requestLogPath == "" { +func (l *LLMProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + l.proxy.ServeHTTP(w, r) +} + +func (l *LLMProxy) Close() error { + l.transport.CloseIdleConnections() + return nil +} + +var modelType = map[string]consts.ModelType{ + "/v1/chat/completions": consts.ModelTypeLLM, + "/v1/completions": consts.ModelTypeCoder, +} + +func (l *LLMProxy) rewrite(r *httputil.ProxyRequest) { + l.logger.DebugContext(r.In.Context(), "rewrite request", slog.String("path", r.In.URL.Path)) + mt, ok := modelType[r.In.URL.Path] + if !ok { + l.logger.Error("model type not found", slog.String("path", r.In.URL.Path)) return } - // 创建文件名,格式:YYYYMMDD_HHMMSS_请求ID.json - timestamp := log.Timestamp.Format("20060102_150405") + fmt.Sprintf("_%03d", log.Timestamp.Nanosecond()/1e6) - filename := fmt.Sprintf("%s_%s.json", timestamp, log.RequestID) - filepath := filepath.Join(p.requestLogPath, filename) - - // 将日志序列化为JSON - logData, err := json.MarshalIndent(log, "", " ") + m, err := l.usecase.SelectModelWithLoadBalancing("", mt) if err != nil { - p.logger.Error("序列化请求日志失败", "error", err) + l.logger.Error("select model with load balancing failed", slog.String("path", r.In.URL.Path), slog.Any("err", err)) return } - // 写入文件 - if err := os.WriteFile(filepath, logData, 0644); err != nil { - p.logger.Error("写入请求日志文件失败", "error", err, "path", filepath) - return - } - - p.logger.Debug("请求响应日志已保存", "path", filepath) -} - -func (p *LLMProxy) AcceptCompletion(ctx context.Context, req *domain.AcceptCompletionReq) error { - return p.usecase.AcceptCompletion(ctx, req) -} - -func writeErrResp(w http.ResponseWriter, code int, message, errorType string) { - resp := ErrResp{ - Error: struct { - Message string `json:"message"` - Type string `json:"type"` - Code string `json:"code,omitempty"` - }{ - Message: message, - Type: errorType, - }, - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(code) - b, _ := json.Marshal(resp) - w.Write(b) -} - -type Ctx struct { - UserID string - RequestID string - SourceIP string -} - -func (p *LLMProxy) handle(ctx context.Context, fn func(ctx *Ctx, log *RequestResponseLog) error) { - // 获取用户ID - userID := "unknown" - if id, ok := ctx.Value(logger.UserIDKey{}).(string); ok { - userID = id - } - - requestID := "unknown" - if id, ok := ctx.Value(logger.RequestIDKey{}).(string); ok { - requestID = id - } - - sourceip := "unknown" - if ip, ok := ctx.Value("remote_addr").(string); ok { - sourceip = ip - } - - // 创建请求日志结构 - l := &RequestResponseLog{ - Timestamp: time.Now(), - RequestID: requestID, - ModelType: consts.ModelTypeCoder, - SourceIP: sourceip, - } - - c := &Ctx{ - UserID: userID, - RequestID: requestID, - SourceIP: sourceip, - } - - if err := fn(c, l); err != nil { - p.logger.With("source_ip", sourceip).ErrorContext(ctx, "处理请求失败", "error", err) - l.Error = err.Error() - } - - go p.saveRequestResponseLog(l) -} - -func (p *LLMProxy) HandleCompletion(ctx context.Context, w http.ResponseWriter, req domain.CompletionRequest) { - if req.Stream { - p.handleCompletionStream(ctx, w, req) - } else { - p.handleCompletion(ctx, w, req) - } -} - -func (p *LLMProxy) handleCompletionStream(ctx context.Context, w http.ResponseWriter, req domain.CompletionRequest) { - endpoint := "/completions" - p.handle(ctx, func(c *Ctx, log *RequestResponseLog) error { - // 使用负载均衡算法选择模型 - m, err := p.usecase.SelectModelWithLoadBalancing(req.Model, consts.ModelTypeCoder) - if err != nil { - p.logger.With("modelName", req.Model, "modelType", consts.ModelTypeCoder).WarnContext(ctx, "模型选择失败", "error", err) - writeErrResp(w, http.StatusNotFound, "模型未找到", "proxy_error") - return err - } - - // 构造上游API URL - upstream := m.APIBase + endpoint - log.UpstreamURL = upstream - startTime := time.Now() - - // 创建上游请求 - body, err := json.Marshal(req) - if err != nil { - p.logger.ErrorContext(ctx, "序列化请求体失败", "error", err) - return fmt.Errorf("序列化请求体失败: %w", err) - } - - upreq, err := http.NewRequestWithContext(ctx, http.MethodPost, upstream, bytes.NewReader(body)) - if err != nil { - p.logger.With("upstream", upstream).WarnContext(ctx, "创建上游流式请求失败", "error", err) - return fmt.Errorf("创建上游请求失败: %w", err) - } - - // 设置请求头 - upreq.Header.Set("Content-Type", "application/json") - upreq.Header.Set("Accept", "text/event-stream") - if m.APIKey != "" && m.APIKey != "none" { - upreq.Header.Set("Authorization", "Bearer "+m.APIKey) - } - - // 保存请求头(去除敏感信息) - requestHeaders := make(map[string][]string) - for k, v := range upreq.Header { - if k != "Authorization" { - requestHeaders[k] = v - } else { - // 敏感信息脱敏 - requestHeaders[k] = []string{"Bearer ***"} - } - } - log.RequestHeader = requestHeaders - - p.logger.With( - "upstreamURL", upstream, - "modelName", m.ModelName, - "modelType", consts.ModelTypeLLM, - "apiBase", m.APIBase, - "requestHeader", upreq.Header, - "requestBody", upreq, - ).DebugContext(ctx, "转发流式请求到上游API") - - // 发送请求 - resp, err := p.client.Do(upreq) - if err != nil { - p.logger.With("upstreamURL", upstream).WarnContext(ctx, "发送上游流式请求失败", "error", err) - return fmt.Errorf("发送上游请求失败: %w", err) - } - defer resp.Body.Close() - - // 检查响应状态 - if resp.StatusCode != http.StatusOK { - responseBody, _ := io.ReadAll(resp.Body) - - // 更新日志错误信息 - log.StatusCode = resp.StatusCode - log.ResponseHeader = resp.Header - log.ResponseBody = string(responseBody) - log.Latency = time.Since(startTime).Milliseconds() - - // 在debug级别记录错误的流式响应内容 - p.logger.With( - "statusCode", resp.StatusCode, - "responseHeader", resp.Header, - "responseBody", string(responseBody), - ).DebugContext(ctx, "上游流式响应错误原始内容") - - var errorResp ErrResp - if err := json.Unmarshal(responseBody, &errorResp); err == nil { - p.logger.With( - "endpoint", endpoint, - "upstreamURL", upstream, - "requestBody", upreq, - "statusCode", resp.StatusCode, - "errorType", errorResp.Error.Type, - "errorCode", errorResp.Error.Code, - "errorMessage", errorResp.Error.Message, - "latency", time.Since(startTime), - ).WarnContext(ctx, "上游API流式请求异常详情") - - return fmt.Errorf("上游API返回错误: %s", errorResp.Error.Message) - } - - p.logger.With( - "endpoint", endpoint, - "upstreamURL", upstream, - "requestBody", upreq, - "statusCode", resp.StatusCode, - "responseBody", string(responseBody), - ).WarnContext(ctx, "上游API流式请求异常详情") - - return fmt.Errorf("上游API返回非200状态码: %d, 响应: %s", resp.StatusCode, string(responseBody)) - } - - // 更新日志信息 - log.StatusCode = resp.StatusCode - log.ResponseHeader = resp.Header - - // 在debug级别记录流式响应头信息 - p.logger.With( - "statusCode", resp.StatusCode, - "responseHeader", resp.Header, - ).DebugContext(ctx, "上游流式响应头信息") - - // 设置响应头 - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Transfer-Encoding", "chunked") - - rc := &domain.RecordParam{ - UserID: c.UserID, - ModelID: m.ID, - ModelType: consts.ModelTypeLLM, - Prompt: req.Prompt.(string), - Role: consts.ChatRoleAssistant, - } - buf := bufio.NewWriterSize(w, 32*1024) - defer buf.Flush() - - ch := make(chan []byte, 1024) - defer close(ch) - - go func(rc *domain.RecordParam) { - for line := range ch { - if bytes.HasPrefix(line, []byte("data:")) { - line = bytes.TrimPrefix(line, []byte("data: ")) - line = bytes.TrimSpace(line) - if len(line) == 0 { - continue - } - - if bytes.Equal(line, []byte("[DONE]")) { - break - } - - var t openai.CompletionResponse - if err := json.Unmarshal(line, &t); err != nil { - p.logger.With("line", string(line)).WarnContext(ctx, "解析流式数据失败", "error", err) - continue - } - - p.logger.With("response", t).DebugContext(ctx, "流式响应数据") - if len(t.Choices) > 0 { - rc.Completion += t.Choices[0].Text - } - rc.InputTokens = int64(t.Usage.PromptTokens) - rc.OutputTokens += int64(t.Usage.CompletionTokens) - } - } - - if rc.OutputTokens == 0 { - return - } - - p.logger.With("record", rc).DebugContext(ctx, "流式记录") - if err := p.usecase.Record(context.Background(), rc); err != nil { - p.logger.With("modelID", m.ID, "modelName", m.ModelName, "modelType", consts.ModelTypeLLM). - WarnContext(ctx, "插入流式记录失败", "error", err) - } - }(rc) - - err = streamRead(ctx, resp.Body, func(line []byte) error { - ch <- line - if _, err := buf.Write(line); err != nil { - return fmt.Errorf("写入响应失败: %w", err) - } - return buf.Flush() + if r.In.ContentLength > 0 { + tee := tee.NewReqTeeWithMaxSize(r.In.Body, 10*1024*1024) + r.Out.Body = tee + ctx := context.WithValue(r.In.Context(), CtxKey{}, &ProxyCtx{ + ctx: r.In.Context(), + Path: r.In.URL.Path, + Model: m, + ReqTee: tee, + RequestID: r.In.Context().Value(logger.RequestIDKey{}).(string), + UserID: r.In.Context().Value(logger.UserIDKey{}).(string), + Header: r.In.Header, }) - return err - }) -} + r.Out = r.Out.WithContext(ctx) + } -func (p *LLMProxy) handleCompletion(ctx context.Context, w http.ResponseWriter, req domain.CompletionRequest) { - endpoint := "/completions" - p.handle(ctx, func(c *Ctx, log *RequestResponseLog) error { - // 使用负载均衡算法选择模型 - m, err := p.usecase.SelectModelWithLoadBalancing(req.Model, consts.ModelTypeCoder) - if err != nil { - p.logger.With("modelName", req.Model, "modelType", consts.ModelTypeCoder).WarnContext(ctx, "模型选择失败", "error", err) - writeErrResp(w, http.StatusNotFound, "模型未找到", "proxy_error") - return err - } - - // 构造上游API URL - upstream := m.APIBase + endpoint - log.UpstreamURL = upstream - u, err := url.Parse(upstream) - if err != nil { - p.logger.With("upstreamURL", upstream).WarnContext(ctx, "解析上游URL失败", "error", err) - writeErrResp(w, http.StatusInternalServerError, "无效的上游URL", "proxy_error") - return err - } - - startTime := time.Now() - - client := request.NewClient( - u.Scheme, - u.Host, - 30*time.Second, - request.WithClient(p.client), - ) - client.SetDebug(p.cfg.Debug) - resp, err := request.Post[openai.CompletionResponse](client, u.Path, req, request.WithHeader(request.Header{ - "Authorization": "Bearer " + m.APIKey, - })) - if err != nil { - p.logger.With("upstreamURL", upstream).WarnContext(ctx, "发送上游请求失败", "error", err) - writeErrResp(w, http.StatusInternalServerError, "发送上游请求失败", "proxy_error") - log.Latency = time.Since(startTime).Milliseconds() - return err - } - - latency := time.Since(startTime) - - // 记录请求信息 - p.logger.With( - "statusCode", http.StatusOK, - "upstreamURL", upstream, - "modelName", req.Model, - "modelType", consts.ModelTypeCoder, - "apiBase", m.APIBase, - "requestBody", req, - "resp", resp, - "latency", latency.String(), - ).DebugContext(ctx, "转发请求到上游API") - - go p.recordCompletion(c, m.ID, req, resp) - - // 更新请求日志 - log.StatusCode = http.StatusOK - log.ResponseBody = resp - log.ResponseHeader = resp.Header() - log.Latency = latency.Milliseconds() - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - b, err := json.Marshal(resp) - if err != nil { - p.logger.With("response", resp).WarnContext(ctx, "序列化响应失败", "error", err) - return err - } - w.Write(b) - return nil - }) -} - -func (p *LLMProxy) recordCompletion(c *Ctx, modelID string, req domain.CompletionRequest, resp *openai.CompletionResponse) { - if resp.Usage.CompletionTokens == 0 { + u, err := url.Parse(m.APIBase) + if err != nil { + l.logger.ErrorContext(r.In.Context(), "parse model api base failed", slog.String("path", r.In.URL.Path), slog.Any("err", err)) return } - ctx := context.Background() - prompt := req.Prompt.(string) - rc := &domain.RecordParam{ - TaskID: resp.ID, - UserID: c.UserID, - ModelID: modelID, - ModelType: consts.ModelTypeCoder, - Prompt: prompt, - ProgramLanguage: req.Metadata["program_language"], - InputTokens: int64(resp.Usage.PromptTokens), - OutputTokens: int64(resp.Usage.CompletionTokens), - } - - for _, choice := range resp.Choices { - rc.Completion += choice.Text - } - lines := strings.Count(rc.Completion, "\n") + 1 - rc.CodeLines = int64(lines) - - if err := p.usecase.Record(ctx, rc); err != nil { - p.logger.With("modelID", modelID, "modelName", req.Model, "modelType", consts.ModelTypeCoder).WarnContext(ctx, "记录请求失败", "error", err) + r.Out.URL.Scheme = u.Scheme + r.Out.URL.Host = u.Host + r.Out.URL.Path = r.In.URL.Path + r.Out.Header.Set("Authorization", "Bearer "+m.APIKey) + r.SetXForwarded() + r.Out.Host = u.Host + l.logger.With("in", r.In.URL.Path, "out", r.Out.URL.Path).DebugContext(r.In.Context(), "rewrite request") +} + +func (l *LLMProxy) modifyResponse(resp *http.Response) error { + ctx := resp.Request.Context() + if pctx, ok := ctx.Value(CtxKey{}).(*ProxyCtx); ok { + pctx.ctx = ctx + pctx.RespHeader = resp.Header + resp.Body = NewRecorder(l.cfg, pctx, resp.Body, l.logger, l.usecase) } + return nil } -func (p *LLMProxy) HandleChatCompletion(ctx context.Context, w http.ResponseWriter, req *openai.ChatCompletionRequest) { - if req.Stream { - p.handleChatCompletionStream(ctx, w, req) - } else { - p.handleChatCompletion(ctx, w, req) - } -} - -func streamRead(ctx context.Context, r io.Reader, fn func([]byte) error) error { - reader := bufio.NewReaderSize(r, 32*1024) - for { - select { - case <-ctx.Done(): - return fmt.Errorf("流式请求被取消: %w", ctx.Err()) - default: - line, err := reader.ReadBytes('\n') - if err == io.EOF { - return nil - } - if err != nil { - return fmt.Errorf("读取流式数据失败: %w", err) - } - - line = bytes.TrimSpace(line) - if len(line) == 0 { - continue - } - - line = append(line, '\n') - line = append(line, '\n') - - fn(line) - } - } -} - -func (p *LLMProxy) getPrompt(ctx context.Context, req *openai.ChatCompletionRequest) string { - prompt := "" - parse := promptparser.New(promptparser.KindTask) - for _, message := range req.Messages { - if message.Role == "system" { - continue - } - - if strings.Contains(message.Content, "") || - strings.Contains(message.Content, "") || - strings.Contains(message.Content, "") { - if info, err := parse.Parse(message.Content); err == nil { - prompt = info.Prompt - } else { - p.logger.With("message", message.Content).WarnContext(ctx, "解析Prompt失败", "error", err) - } - } - - for _, m := range message.MultiContent { - if strings.Contains(m.Text, "") || - strings.Contains(m.Text, "") || - strings.Contains(m.Text, "") { - if info, err := parse.Parse(m.Text); err == nil { - prompt = info.Prompt - } else { - p.logger.With("message", m.Text).WarnContext(ctx, "解析Prompt失败", "error", err) - } - } - } - } - return prompt -} - -func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.ResponseWriter, req *openai.ChatCompletionRequest) { - endpoint := "/chat/completions" - p.handle(ctx, func(c *Ctx, log *RequestResponseLog) error { - startTime := time.Now() - - m, err := p.usecase.SelectModelWithLoadBalancing(req.Model, consts.ModelTypeLLM) - if err != nil { - p.logger.With("modelName", req.Model, "modelType", consts.ModelTypeLLM).WarnContext(ctx, "流式请求模型选择失败", "error", err) - writeErrResp(w, http.StatusNotFound, "模型未找到", "proxy_error") - return err - } - - upstream := m.APIBase + endpoint - log.UpstreamURL = upstream - - body, err := json.Marshal(req) - if err != nil { - p.logger.ErrorContext(ctx, "序列化请求体失败", "error", err) - return fmt.Errorf("序列化请求体失败: %w", err) - } - - newReq, err := http.NewRequestWithContext(ctx, http.MethodPost, upstream, bytes.NewReader(body)) - if err != nil { - p.logger.With("upstream", upstream).WarnContext(ctx, "创建上游流式请求失败", "error", err) - return fmt.Errorf("创建上游请求失败: %w", err) - } - - newReq.Header.Set("Content-Type", "application/json") - newReq.Header.Set("Accept", "text/event-stream") - newReq.Header.Set("Authorization", "Bearer "+m.APIKey) - - // 保存请求头(去除敏感信息) - requestHeaders := make(map[string][]string) - for k, v := range newReq.Header { - if k != "Authorization" { - requestHeaders[k] = v - } else { - // 敏感信息脱敏 - requestHeaders[k] = []string{"Bearer ***"} - } - } - log.RequestHeader = requestHeaders - - logger := p.logger.With( - "request_id", c.RequestID, - "source_ip", c.SourceIP, - "upstreamURL", upstream, - "modelName", m.ModelName, - "modelType", consts.ModelTypeLLM, - "apiBase", m.APIBase, - ) - - logger.With( - "upstreamURL", upstream, - "requestHeader", newReq.Header, - "requestBody", req, - "messages", cvt.Filter(req.Messages, func(i int, v openai.ChatCompletionMessage) (openai.ChatCompletionMessage, bool) { - return v, v.Role != "system" - }), - ).DebugContext(ctx, "转发流式请求到上游API") - - // 发送请求 - resp, err := p.streamClient.Do(newReq) - if err != nil { - p.logger.With("upstreamURL", upstream).WarnContext(ctx, "发送上游流式请求失败", "error", err) - return fmt.Errorf("发送上游请求失败: %w", err) - } - defer resp.Body.Close() - - // 检查响应状态 - if resp.StatusCode != http.StatusOK { - responseBody, _ := io.ReadAll(resp.Body) - - // 更新日志错误信息 - log.StatusCode = resp.StatusCode - log.ResponseHeader = resp.Header - log.ResponseBody = string(responseBody) - log.Latency = time.Since(startTime).Milliseconds() - - // 在debug级别记录错误的流式响应内容 - logger.With( - "statusCode", resp.StatusCode, - "responseHeader", resp.Header, - "responseBody", string(responseBody), - ).DebugContext(ctx, "上游流式响应错误原始内容") - - var errorResp ErrResp - if err := json.Unmarshal(responseBody, &errorResp); err == nil { - logger.With( - "endpoint", endpoint, - "requestBody", newReq, - "statusCode", resp.StatusCode, - "errorType", errorResp.Error.Type, - "errorCode", errorResp.Error.Code, - "errorMessage", errorResp.Error.Message, - "latency", time.Since(startTime), - ).WarnContext(ctx, "上游API流式请求异常详情") - - return fmt.Errorf("上游API返回错误: %s", errorResp.Error.Message) - } - - logger.With( - "endpoint", endpoint, - "requestBody", newReq, - "statusCode", resp.StatusCode, - "responseBody", string(responseBody), - ).WarnContext(ctx, "上游API流式请求异常详情") - - return fmt.Errorf("上游API返回非200状态码: %d, 响应: %s", resp.StatusCode, string(responseBody)) - } - - log.StatusCode = resp.StatusCode - log.ResponseHeader = resp.Header - - logger.With( - "statusCode", resp.StatusCode, - "responseHeader", resp.Header, - ).DebugContext(ctx, "上游流式响应头信息") - - // 设置响应头 - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Transfer-Encoding", "chunked") - w.Header().Set("X-Accel-Buffering", "no") - - recorder := NewChatRecorder( - ctx, - c, - p.usecase, - m, - req, - resp.Body, - w, - p.logger.With("module", "ChatRecorder"), - ) - defer recorder.Close() - return recorder.Stream() - }) -} - -func (p *LLMProxy) handleChatCompletion(ctx context.Context, w http.ResponseWriter, req *openai.ChatCompletionRequest) { - endpoint := "/chat/completions" - p.handle(ctx, func(c *Ctx, log *RequestResponseLog) error { - // 使用负载均衡算法选择模型 - m, err := p.usecase.SelectModelWithLoadBalancing(req.Model, consts.ModelTypeCoder) - if err != nil { - p.logger.With("modelName", req.Model, "modelType", consts.ModelTypeCoder).WarnContext(ctx, "模型选择失败", "error", err) - writeErrResp(w, http.StatusNotFound, "模型未找到", "proxy_error") - return err - } - - // 构造上游API URL - upstream := m.APIBase + endpoint - log.UpstreamURL = upstream - u, err := url.Parse(upstream) - if err != nil { - p.logger.With("upstreamURL", upstream).WarnContext(ctx, "解析上游URL失败", "error", err) - writeErrResp(w, http.StatusInternalServerError, "无效的上游URL", "proxy_error") - return err - } - - startTime := time.Now() - prompt := p.getPrompt(ctx, req) - mode := req.Metadata["mode"] - taskID := req.Metadata["task_id"] - - client := request.NewClient( - u.Scheme, - u.Host, - 30*time.Second, - request.WithClient(p.client), - ) - resp, err := request.Post[openai.ChatCompletionResponse](client, u.Path, req, request.WithHeader(request.Header{ - "Authorization": "Bearer " + m.APIKey, - })) - if err != nil { - p.logger.With("upstreamURL", upstream).WarnContext(ctx, "发送上游请求失败", "error", err) - writeErrResp(w, http.StatusInternalServerError, "发送上游请求失败", "proxy_error") - log.Latency = time.Since(startTime).Milliseconds() - return err - } - - // 记录请求信息 - p.logger.With( - "upstreamURL", upstream, - "modelName", req.Model, - "modelType", consts.ModelTypeCoder, - "apiBase", m.APIBase, - "requestBody", req, - ).DebugContext(ctx, "转发请求到上游API") - - go func() { - rc := &domain.RecordParam{ - TaskID: taskID, - UserID: c.UserID, - Prompt: prompt, - WorkMode: mode, - ModelID: m.ID, - ModelType: m.ModelType, - InputTokens: int64(resp.Usage.PromptTokens), - OutputTokens: int64(resp.Usage.CompletionTokens), - ProgramLanguage: req.Metadata["program_language"], - } - - for _, choice := range resp.Choices { - rc.Completion += choice.Message.Content + "\n\n" - } - - p.logger.With("record", rc).DebugContext(ctx, "记录") - - if err := p.usecase.Record(ctx, rc); err != nil { - p.logger.With("modelID", m.ID, "modelName", req.Model, "modelType", consts.ModelTypeCoder).WarnContext(ctx, "记录请求失败", "error", err) - } - }() - - // 计算请求耗时 - latency := time.Since(startTime) - - // 更新请求日志 - log.StatusCode = http.StatusOK - log.ResponseBody = resp - log.ResponseHeader = resp.Header() - log.Latency = latency.Milliseconds() - - // 记录响应状态 - p.logger.With( - "statusCode", http.StatusOK, - "responseHeader", resp.Header(), - "responseBody", resp, - "latency", latency.String(), - ).DebugContext(ctx, "上游API响应") - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - b, err := json.Marshal(resp) - if err != nil { - p.logger.With("response", resp).WarnContext(ctx, "序列化响应失败", "error", err) - return err - } - w.Write(b) - return nil - }) -} - -func (p *LLMProxy) HandleEmbeddings(ctx context.Context, w http.ResponseWriter, req *openai.EmbeddingRequest) { +func (l *LLMProxy) errorHandler(w http.ResponseWriter, r *http.Request, err error) { + l.logger.ErrorContext(r.Context(), "error handler", slog.String("path", r.URL.Path), slog.Any("err", err)) } diff --git a/backend/internal/proxy/recorder.go b/backend/internal/proxy/recorder.go index 9e94a33..55e3e6a 100644 --- a/backend/internal/proxy/recorder.go +++ b/backend/internal/proxy/recorder.go @@ -5,191 +5,331 @@ import ( "encoding/json" "io" "log/slog" + "net/http" + "os" + "path/filepath" "strings" + "time" "github.com/rokku-c/go-openai" + "github.com/chaitin/MonkeyCode/backend/config" "github.com/chaitin/MonkeyCode/backend/consts" "github.com/chaitin/MonkeyCode/backend/domain" "github.com/chaitin/MonkeyCode/backend/pkg/promptparser" - "github.com/chaitin/MonkeyCode/backend/pkg/tee" ) -type ChatRecorder struct { - *tee.Tee - ctx context.Context - cx *Ctx - usecase domain.ProxyUsecase - req *openai.ChatCompletionRequest - logger *slog.Logger - model *domain.Model - completion strings.Builder // 累积完整的响应内容 - usage *openai.Usage // 最终的使用统计 - buffer strings.Builder // 缓存不完整的行 - recorded bool // 标记是否已记录完整对话 +type Recorder struct { + cfg *config.Config + usecase domain.ProxyUsecase + shadown chan []byte + src io.ReadCloser + ctx *ProxyCtx + logger *slog.Logger + logFile *os.File } -func NewChatRecorder( - ctx context.Context, - cx *Ctx, - usecase domain.ProxyUsecase, - model *domain.Model, - req *openai.ChatCompletionRequest, - r io.Reader, - w io.Writer, +var _ io.ReadCloser = &Recorder{} + +func NewRecorder( + cfg *config.Config, + ctx *ProxyCtx, + src io.ReadCloser, logger *slog.Logger, -) *ChatRecorder { - c := &ChatRecorder{ - ctx: ctx, - cx: cx, + usecase domain.ProxyUsecase, +) *Recorder { + r := &Recorder{ + cfg: cfg, usecase: usecase, - model: model, - req: req, + shadown: make(chan []byte, 128*1024), + src: src, + ctx: ctx, logger: logger, } - c.Tee = tee.NewTee(ctx, logger, r, w, c.handle) - return c + go r.handleShadow() + return r } -func (c *ChatRecorder) handle(ctx context.Context, data []byte) error { - c.buffer.Write(data) - bufferContent := c.buffer.String() +func formatHeader(header http.Header) map[string]string { + headerMap := make(map[string]string) + for key, values := range header { + headerMap[key] = strings.Join(values, ",") + } + return headerMap +} - lines := strings.Split(bufferContent, "\n") - if len(lines) > 0 { - lastLine := lines[len(lines)-1] - lines = lines[:len(lines)-1] - c.buffer.Reset() - c.buffer.WriteString(lastLine) +func (r *Recorder) handleShadow() { + body, err := r.ctx.ReqTee.GetBody() + if err != nil { + r.logger.WarnContext(r.ctx.ctx, "get req tee body failed", "error", err) + return } - for _, line := range lines { - if err := c.processSSELine(ctx, line); err != nil { - return err + var ( + taskID, mode, prompt, language string + ) + + switch r.ctx.Model.ModelType { + case consts.ModelTypeLLM: + var req openai.ChatCompletionRequest + if err := json.Unmarshal(body, &req); err != nil { + r.logger.WarnContext(r.ctx.ctx, "unmarshal chat completion request failed", "error", err) + return + } + prompt = r.getPrompt(r.ctx.ctx, &req) + taskID = req.Metadata["task_id"] + mode = req.Metadata["mode"] + + case consts.ModelTypeCoder: + var req domain.CompletionRequest + if err := json.Unmarshal(body, &req); err != nil { + r.logger.WarnContext(r.ctx.ctx, "unmarshal completion request failed", "error", err) + return + } + prompt = req.Prompt.(string) + taskID = req.Metadata["task_id"] + mode = req.Metadata["mode"] + language = req.Metadata["program_language"] + + default: + r.logger.WarnContext(r.ctx.ctx, "skip handle shadow, model type not support", "modelType", r.ctx.Model.ModelType) + return + } + + r.createFile(taskID) + r.writeMeta(body) + + rc := &domain.RecordParam{ + RequestID: r.ctx.RequestID, + TaskID: taskID, + UserID: r.ctx.UserID, + ModelID: r.ctx.Model.ID, + ModelType: r.ctx.Model.ModelType, + WorkMode: mode, + Prompt: prompt, + ProgramLanguage: language, + Role: consts.ChatRoleUser, + } + + var assistantRc *domain.RecordParam + ct := r.ctx.RespHeader.Get("Content-Type") + if strings.Contains(ct, "stream") { + r.handleStream(rc) + if r.ctx.Model.ModelType == consts.ModelTypeLLM { + assistantRc = rc.Clone() + assistantRc.Role = consts.ChatRoleAssistant + rc.Completion = "" + rc.OutputTokens = 0 + } + } else { + r.handleJson(rc) + } + + r.logger. + With("header", formatHeader(r.ctx.Header)). + With("resp_header", formatHeader(r.ctx.RespHeader)). + DebugContext(r.ctx.ctx, "handle shadow", "rc", rc) + + if err := r.usecase.Record(context.Background(), rc); err != nil { + r.logger.WarnContext(r.ctx.ctx, "记录请求失败", "error", err) + } + + if assistantRc != nil { + if err := r.usecase.Record(context.Background(), assistantRc); err != nil { + r.logger.WarnContext(r.ctx.ctx, "记录请求失败", "error", err) + } + } +} + +func (r *Recorder) writeMeta(body []byte) { + if r.logFile == nil { + return + } + + r.logFile.WriteString("------------------ Request Header ------------------\n") + for key, value := range formatHeader(r.ctx.Header) { + r.logFile.WriteString(key + ": " + value + "\n") + } + r.logFile.WriteString("------------------ Request Body ------------------\n") + r.logFile.WriteString(string(body)) + r.logFile.WriteString("\n") + r.logFile.WriteString("------------------ Response Header ------------------\n") + for key, value := range formatHeader(r.ctx.RespHeader) { + r.logFile.WriteString(key + ": " + value + "\n") + } + r.logFile.WriteString("------------------ Response Body ------------------\n") +} + +func (r *Recorder) createFile(taskID string) { + if r.cfg.LLMProxy.RequestLogPath == "" { + return + } + + dir := filepath.Join(r.cfg.LLMProxy.RequestLogPath, time.Now().Format("2006010215")) + if err := os.MkdirAll(dir, 0755); err != nil { + r.logger.WarnContext(r.ctx.ctx, "create dir failed", "error", err) + return + } + + id := r.ctx.RequestID + if taskID != "" { + id = taskID + } + filename := filepath.Join(dir, id+".log") + f, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + r.logger.WarnContext(r.ctx.ctx, "create file failed", "error", err) + return + } + r.logFile = f +} + +func (r *Recorder) handleJson(rc *domain.RecordParam) { + buffer := strings.Builder{} + for data := range r.shadown { + buffer.Write(data) + if r.logFile != nil { + r.logFile.Write(data) } } - return nil + switch rc.ModelType { + case consts.ModelTypeLLM: + var resp openai.ChatCompletionResponse + if err := json.Unmarshal([]byte(buffer.String()), &resp); err != nil { + r.logger.WarnContext(r.ctx.ctx, "unmarshal chat completion response failed", "error", err) + return + } + if len(resp.Choices) > 0 { + rc.Completion = resp.Choices[0].Message.Content + rc.InputTokens = int64(resp.Usage.PromptTokens) + rc.OutputTokens = int64(resp.Usage.CompletionTokens) + } + + case consts.ModelTypeCoder: + var resp openai.CompletionResponse + if err := json.Unmarshal([]byte(buffer.String()), &resp); err != nil { + r.logger.WarnContext(r.ctx.ctx, "unmarshal completion response failed", "error", err) + return + } + if rc.TaskID == "" { + rc.TaskID = resp.ID + } + rc.InputTokens = int64(resp.Usage.PromptTokens) + rc.OutputTokens = int64(resp.Usage.CompletionTokens) + if len(resp.Choices) > 0 { + rc.Completion = resp.Choices[0].Text + rc.CodeLines = int64(strings.Count(resp.Choices[0].Text, "\n")) + } + } } -func (c *ChatRecorder) processSSELine(ctx context.Context, line string) error { - line = strings.TrimSpace(line) +func (r *Recorder) handleStream(rc *domain.RecordParam) { + buffer := strings.Builder{} + for data := range r.shadown { + buffer.Write(data) + cnt := buffer.String() + if r.logFile != nil { + r.logFile.Write(data) + } + + lines := strings.Split(cnt, "\n") + if len(lines) > 0 { + lastLine := lines[len(lines)-1] + lines = lines[:len(lines)-1] + buffer.Reset() + buffer.WriteString(lastLine) + } + + for _, line := range lines { + if err := r.processSSELine(r.ctx.ctx, line, rc); err != nil { + r.logger.WarnContext(r.ctx.ctx, "处理SSE行失败", "error", err) + } + } + } +} + +func (r *Recorder) processSSELine(ctx context.Context, line string, rc *domain.RecordParam) error { + line = strings.TrimSpace(line) if line == "" || !strings.HasPrefix(line, "data:") { return nil } - dataContent := strings.TrimSpace(strings.TrimPrefix(line, "data:")) - if dataContent == "" { + data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if data == "" { return nil } - if dataContent == "[DONE]" { - c.processCompletedChat(ctx) + if data == "[DONE]" { return nil } - var resp openai.ChatCompletionStreamResponse - if err := json.Unmarshal([]byte(dataContent), &resp); err != nil { - c.logger.With("data", dataContent).WarnContext(ctx, "解析流式响应失败", "error", err) - return nil - } + switch r.ctx.Model.ModelType { + case consts.ModelTypeLLM: + var resp openai.ChatCompletionStreamResponse + if err := json.Unmarshal([]byte(data), &resp); err != nil { + r.logger.With("model_type", r.ctx.Model.ModelType).With("data", data).WarnContext(ctx, "解析SSE行失败", "error", err) + return nil + } + if resp.Usage != nil { + rc.InputTokens = int64(resp.Usage.PromptTokens) + rc.OutputTokens += int64(resp.Usage.CompletionTokens) + } + if len(resp.Choices) > 0 { + content := resp.Choices[0].Delta.Content + if content != "" { + rc.Completion += content + } + } - prompt := c.getPrompt(ctx, c.req) - mode := c.req.Metadata["mode"] - taskID := c.req.Metadata["task_id"] + case consts.ModelTypeCoder: + var resp openai.CompletionResponse + if err := json.Unmarshal([]byte(data), &resp); err != nil { + r.logger.With("model_type", r.ctx.Model.ModelType).With("data", data).WarnContext(ctx, "解析SSE行失败", "error", err) + return nil + } - rc := &domain.RecordParam{ - RequestID: c.cx.RequestID, - TaskID: taskID, - UserID: c.cx.UserID, - ModelID: c.model.ID, - ModelType: c.model.ModelType, - WorkMode: mode, - Prompt: prompt, - Role: consts.ChatRoleAssistant, - } - if resp.Usage != nil { + if rc.TaskID == "" { + rc.TaskID = resp.ID + } rc.InputTokens = int64(resp.Usage.PromptTokens) - } - - if rc.Prompt != "" { - urc := rc.Clone() - urc.Role = consts.ChatRoleUser - urc.Completion = urc.Prompt - if err := c.usecase.Record(context.Background(), urc); err != nil { - c.logger.With("modelID", c.model.ID, "modelName", c.model.ModelName, "modelType", consts.ModelTypeLLM). - WarnContext(ctx, "插入流式记录失败", "error", err) + rc.OutputTokens += int64(resp.Usage.CompletionTokens) + if len(resp.Choices) > 0 { + rc.Completion += resp.Choices[0].Text + rc.CodeLines += int64(strings.Count(resp.Choices[0].Text, "\n")) } } - if len(resp.Choices) > 0 { - content := resp.Choices[0].Delta.Content - if content != "" { - c.completion.WriteString(content) - } - } - - if resp.Usage != nil { - c.usage = resp.Usage - } - return nil } -func (c *ChatRecorder) processCompletedChat(ctx context.Context) { - if c.recorded { - return // 避免重复记录 +// Close implements io.ReadCloser. +func (r *Recorder) Close() error { + if r.shadown != nil { + close(r.shadown) } - - mode := c.req.Metadata["mode"] - taskID := c.req.Metadata["task_id"] - - rc := &domain.RecordParam{ - RequestID: c.cx.RequestID, - TaskID: taskID, - UserID: c.cx.UserID, - ModelID: c.model.ID, - ModelType: c.model.ModelType, - WorkMode: mode, - Role: consts.ChatRoleAssistant, - Completion: c.completion.String(), - InputTokens: int64(c.usage.PromptTokens), - OutputTokens: int64(c.usage.CompletionTokens), - } - - if err := c.usecase.Record(context.Background(), rc); err != nil { - c.logger.With("modelID", c.model.ID, "modelName", c.model.ModelName, "modelType", consts.ModelTypeLLM). - WarnContext(ctx, "插入流式记录失败", "error", err) - } else { - c.recorded = true - c.logger.With("requestID", c.cx.RequestID, "completion_length", len(c.completion.String())). - InfoContext(ctx, "流式对话记录已保存") + if r.logFile != nil { + r.logFile.Close() } + return r.src.Close() } -// Close 关闭 recorder 并确保数据被保存 -func (c *ChatRecorder) Close() { - // 如果有累积的内容但还没有记录,强制保存 - if !c.recorded && c.completion.Len() > 0 { - c.logger.With("requestID", c.cx.RequestID). - WarnContext(c.ctx, "数据流异常中断,强制保存已累积的内容") - c.processCompletedChat(c.ctx) +// Read implements io.ReadCloser. +func (r *Recorder) Read(p []byte) (n int, err error) { + n, err = r.src.Read(p) + if n > 0 { + data := make([]byte, n) + copy(data, p[:n]) + r.shadown <- data } - - if c.Tee != nil { - c.Tee.Close() + if err != nil { + return } + return } -func (c *ChatRecorder) Reset() { - c.completion.Reset() - c.buffer.Reset() - c.usage = nil - c.recorded = false -} - -func (c *ChatRecorder) getPrompt(ctx context.Context, req *openai.ChatCompletionRequest) string { +func (r *Recorder) getPrompt(ctx context.Context, req *openai.ChatCompletionRequest) string { prompt := "" parse := promptparser.New(promptparser.KindTask) for _, message := range req.Messages { @@ -203,7 +343,7 @@ func (c *ChatRecorder) getPrompt(ctx context.Context, req *openai.ChatCompletion if info, err := parse.Parse(message.Content); err == nil { prompt = info.Prompt } else { - c.logger.With("message", message.Content).WarnContext(ctx, "解析Prompt失败", "error", err) + r.logger.With("message", message.Content).WarnContext(ctx, "解析Prompt失败", "error", err) } } @@ -214,7 +354,7 @@ func (c *ChatRecorder) getPrompt(ctx context.Context, req *openai.ChatCompletion if info, err := parse.Parse(m.Text); err == nil { prompt = info.Prompt } else { - c.logger.With("message", m.Text).WarnContext(ctx, "解析Prompt失败", "error", err) + r.logger.With("message", m.Text).WarnContext(ctx, "解析Prompt失败", "error", err) } } } diff --git a/backend/internal/user/handler/v1/user.go b/backend/internal/user/handler/v1/user.go index e532b0c..e8df701 100644 --- a/backend/internal/user/handler/v1/user.go +++ b/backend/internal/user/handler/v1/user.go @@ -23,6 +23,7 @@ type UserHandler struct { session *session.Session logger *slog.Logger cfg *config.Config + limitCh chan struct{} } func NewUserHandler( @@ -40,6 +41,7 @@ func NewUserHandler( logger: logger, cfg: cfg, euse: euse, + limitCh: make(chan struct{}, cfg.Extension.Limit), } w.GET("/api/v1/static/vsix/:version", web.BaseHandler(u.VSIXDownload)) @@ -94,6 +96,11 @@ func (h *UserHandler) VSCodeAuthInit(c *web.Context, req domain.VSCodeAuthInitRe // @Produce octet-stream // @Router /api/v1/static/vsix [get] func (h *UserHandler) VSIXDownload(c *web.Context) error { + h.limitCh <- struct{}{} + defer func() { + <-h.limitCh + }() + v, err := h.euse.GetByVersion(c.Request().Context(), c.Param("version")) if err != nil { return err diff --git a/backend/pkg/tee/reqtee.go b/backend/pkg/tee/reqtee.go new file mode 100644 index 0000000..2faef03 --- /dev/null +++ b/backend/pkg/tee/reqtee.go @@ -0,0 +1,81 @@ +package tee + +import ( + "bytes" + "fmt" + "io" + "sync" + "time" +) + +type ReqTee struct { + src io.ReadCloser + buf bytes.Buffer + done chan struct{} + mu sync.RWMutex + closed bool + maxSize int64 // 最大缓冲区大小,0表示无限制 +} + +var _ io.ReadCloser = &ReqTee{} + +func NewReqTee(src io.ReadCloser) *ReqTee { + return NewReqTeeWithMaxSize(src, 0) +} + +// NewReqTeeWithMaxSize 创建带最大缓冲区大小限制的ReqTee +func NewReqTeeWithMaxSize(src io.ReadCloser, maxSize int64) *ReqTee { + return &ReqTee{ + src: src, + buf: bytes.Buffer{}, + done: make(chan struct{}), + maxSize: maxSize, + } +} + +func (r *ReqTee) GetBody() ([]byte, error) { + return r.GetBodyWithTimeout(30 * time.Second) +} + +// GetBodyWithTimeout 获取缓冲的数据,支持自定义超时时间 +func (r *ReqTee) GetBodyWithTimeout(timeout time.Duration) ([]byte, error) { + select { + case <-r.done: + r.mu.RLock() + defer r.mu.RUnlock() + return r.buf.Bytes(), nil + case <-time.After(timeout): + return nil, fmt.Errorf("timeout waiting for data") + } +} + +// Close implements io.ReadCloser. +func (r *ReqTee) Close() error { + r.mu.Lock() + defer r.mu.Unlock() + + if r.closed { + return nil + } + r.closed = true + + close(r.done) + return r.src.Close() +} + +// Read implements io.ReadCloser. +func (r *ReqTee) Read(p []byte) (n int, err error) { + n, err = r.src.Read(p) + if n > 0 { + r.mu.Lock() + // 检查缓冲区大小限制 + if r.maxSize > 0 && int64(r.buf.Len()+n) > r.maxSize { + r.mu.Unlock() + return n, fmt.Errorf("buffer size limit exceeded: %d bytes", r.maxSize) + } + // 直接写入缓冲区,避免额外的内存分配和复制 + r.buf.Write(p[:n]) + r.mu.Unlock() + } + return n, err +} diff --git a/backend/pkg/tee/tee.go b/backend/pkg/tee/tee.go deleted file mode 100644 index 577e924..0000000 --- a/backend/pkg/tee/tee.go +++ /dev/null @@ -1,103 +0,0 @@ -package tee - -import ( - "context" - "io" - "log/slog" - "net/http" - "sync" -) - -type TeeHandleFunc func(ctx context.Context, data []byte) error - -var bufferPool = sync.Pool{ - New: func() any { - buf := make([]byte, 4096) - return &buf - }, -} - -type Tee struct { - ctx context.Context - logger *slog.Logger - Reader io.Reader - Writer io.Writer - ch chan []byte - handle TeeHandleFunc -} - -func NewTee( - ctx context.Context, - logger *slog.Logger, - reader io.Reader, - writer io.Writer, - handle TeeHandleFunc, -) *Tee { - t := &Tee{ - ctx: ctx, - logger: logger, - Reader: reader, - Writer: writer, - handle: handle, - ch: make(chan []byte, 32*1024), - } - go t.Handle() - return t -} - -func (t *Tee) Close() { - select { - case <-t.ch: - // channel 已经关闭 - default: - close(t.ch) - } -} - -func (t *Tee) Handle() { - for { - select { - case data, ok := <-t.ch: - if !ok { - t.logger.DebugContext(t.ctx, "Tee Handle closed") - return - } - err := t.handle(t.ctx, data) - if err != nil { - t.logger.With("data", string(data)).With("error", err).ErrorContext(t.ctx, "Tee Handle error") - return - } - case <-t.ctx.Done(): - t.logger.DebugContext(t.ctx, "Tee Handle ctx done") - return - } - } -} - -func (t *Tee) Stream() error { - bufPtr := bufferPool.Get().(*[]byte) - buf := *bufPtr - defer bufferPool.Put(bufPtr) - - for { - n, err := t.Reader.Read(buf) - if err != nil { - if err == io.EOF { - return nil - } - return err - } - if n > 0 { - _, err = t.Writer.Write(buf[:n]) - if err != nil { - return err - } - if flusher, ok := t.Writer.(http.Flusher); ok { - flusher.Flush() - } - data := make([]byte, n) - copy(data, buf[:n]) - t.ch <- data - } - } -} diff --git a/backend/pkg/tee/tee_test.go b/backend/pkg/tee/tee_test.go deleted file mode 100644 index 5f91a4c..0000000 --- a/backend/pkg/tee/tee_test.go +++ /dev/null @@ -1,376 +0,0 @@ -package tee - -import ( - "bytes" - "context" - "errors" - "io" - "log/slog" - "os" - "strings" - "sync" - "testing" - "time" -) - -// mockWriter 模拟 Writer 接口 -type mockWriter struct { - buf *bytes.Buffer - delay time.Duration - errorOn int // 在第几次写入时返回错误 - count int -} - -func newMockWriter() *mockWriter { - return &mockWriter{ - buf: &bytes.Buffer{}, - } -} - -func (m *mockWriter) Write(p []byte) (n int, err error) { - m.count++ - if m.errorOn > 0 && m.count >= m.errorOn { - return 0, errors.New("mock write error") - } - if m.delay > 0 { - time.Sleep(m.delay) - } - return m.buf.Write(p) -} - -func (m *mockWriter) String() string { - return m.buf.String() -} - -// mockReader 模拟 Reader 接口 -type mockReader struct { - data []byte - pos int - chunk int // 每次读取的字节数 - errorOn int // 在第几次读取时返回错误 - count int -} - -func newMockReader(data string, chunk int) *mockReader { - return &mockReader{ - data: []byte(data), - chunk: chunk, - } -} - -func (m *mockReader) Read(p []byte) (n int, err error) { - m.count++ - if m.errorOn > 0 && m.count >= m.errorOn { - return 0, errors.New("mock read error") - } - - if m.pos >= len(m.data) { - return 0, io.EOF - } - - readSize := m.chunk - if readSize <= 0 || readSize > len(p) { - readSize = len(p) - } - - remaining := len(m.data) - m.pos - if readSize > remaining { - readSize = remaining - } - - copy(p, m.data[m.pos:m.pos+readSize]) - m.pos += readSize - return readSize, nil -} - -// TestTeeBasicFunctionality 测试基本功能 -func TestTeeBasicFunctionality(t *testing.T) { - ctx := context.Background() - logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError})) - - testData := "Hello, World! This is a test message." - reader := newMockReader(testData, 10) // 每次读取10字节 - writer := newMockWriter() - - var handledData [][]byte - var mu sync.Mutex - - handle := func(ctx context.Context, data []byte) error { - mu.Lock() - defer mu.Unlock() - // 复制数据,因为原始数据可能被重用 - dataCopy := make([]byte, len(data)) - copy(dataCopy, data) - handledData = append(handledData, dataCopy) - return nil - } - - tee := NewTee(ctx, logger, reader, writer, handle) - defer tee.Close() - - err := tee.Stream() - if err != nil { - t.Fatalf("Stream() failed: %v", err) - } - - // 等待处理完成 - time.Sleep(100 * time.Millisecond) - - // 验证写入的数据 - if writer.String() != testData { - t.Errorf("Expected writer data %q, got %q", testData, writer.String()) - } - - // 验证处理的数据 - mu.Lock() - var totalHandled []byte - for _, chunk := range handledData { - totalHandled = append(totalHandled, chunk...) - } - mu.Unlock() - - if string(totalHandled) != testData { - t.Errorf("Expected handled data %q, got %q", testData, string(totalHandled)) - } -} - -// TestTeeWithErrors 测试错误处理 -func TestTeeWithErrors(t *testing.T) { - ctx := context.Background() - logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError})) - - t.Run("ReaderError", func(t *testing.T) { - reader := newMockReader("test data", 5) - reader.errorOn = 2 // 第二次读取时出错 - writer := newMockWriter() - - handle := func(ctx context.Context, data []byte) error { - return nil - } - - tee := NewTee(ctx, logger, reader, writer, handle) - defer tee.Close() - - err := tee.Stream() - if err == nil { - t.Error("Expected error from reader, got nil") - } - }) - - t.Run("WriterError", func(t *testing.T) { - reader := newMockReader("test data", 5) - writer := newMockWriter() - writer.errorOn = 2 // 第二次写入时出错 - - handle := func(ctx context.Context, data []byte) error { - return nil - } - - tee := NewTee(ctx, logger, reader, writer, handle) - defer tee.Close() - - err := tee.Stream() - if err == nil { - t.Error("Expected error from writer, got nil") - } - }) - - t.Run("HandleError", func(t *testing.T) { - reader := newMockReader("test data", 5) - writer := newMockWriter() - - handle := func(ctx context.Context, data []byte) error { - return errors.New("handle error") - } - - tee := NewTee(ctx, logger, reader, writer, handle) - defer tee.Close() - - // 启动 Stream 在单独的 goroutine 中 - go func() { - tee.Stream() - }() - - // 等待一段时间让处理器有机会处理数据并出错 - time.Sleep(200 * time.Millisecond) - }) -} - -// TestTeeContextCancellation 测试上下文取消 -func TestTeeContextCancellation(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError})) - - // 创建一个会持续产生数据的 reader - reader := strings.NewReader(strings.Repeat("test data ", 1000)) - writer := newMockWriter() - - handle := func(ctx context.Context, data []byte) error { - return nil - } - - tee := NewTee(ctx, logger, reader, writer, handle) - defer tee.Close() - - // 在单独的 goroutine 中启动 Stream - done := make(chan error, 1) - go func() { - done <- tee.Stream() - }() - - // 等待一段时间后取消上下文 - time.Sleep(50 * time.Millisecond) - cancel() - - // 等待 Stream 完成 - select { - case err := <-done: - if err != nil && err != io.EOF { - t.Logf("Stream completed with error: %v", err) - } - case <-time.After(2 * time.Second): - t.Error("Stream did not complete within timeout") - } -} - -// TestTeeConcurrentSafety 测试并发安全性 -func TestTeeConcurrentSafety(t *testing.T) { - ctx := context.Background() - logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError})) - - testData := strings.Repeat("concurrent test data ", 100) - reader := strings.NewReader(testData) - writer := newMockWriter() - - var processedCount int64 - var mu sync.Mutex - - handle := func(ctx context.Context, data []byte) error { - mu.Lock() - processedCount++ - mu.Unlock() - // 模拟一些处理时间 - time.Sleep(time.Microsecond) - return nil - } - - tee := NewTee(ctx, logger, reader, writer, handle) - defer tee.Close() - - err := tee.Stream() - if err != nil { - t.Fatalf("Stream() failed: %v", err) - } - - // 等待所有数据处理完成 - time.Sleep(500 * time.Millisecond) - - mu.Lock() - finalCount := processedCount - mu.Unlock() - - if finalCount == 0 { - t.Error("No data was processed") - } - - t.Logf("Processed %d chunks of data", finalCount) -} - -// TestBufferPoolEfficiency 测试缓冲区池的效率 -func TestBufferPoolEfficiency(t *testing.T) { - // 这个测试验证缓冲区池是否正常工作 - // 通过多次获取和归还缓冲区来测试 - - var buffers []*[]byte - - // 获取多个缓冲区 - for i := 0; i < 10; i++ { - bufPtr := bufferPool.Get().(*[]byte) - buffers = append(buffers, bufPtr) - - // 验证缓冲区大小 - if len(*bufPtr) != 4096 { - t.Errorf("Expected buffer size 4096, got %d", len(*bufPtr)) - } - } - - // 归还所有缓冲区 - for _, bufPtr := range buffers { - bufferPool.Put(bufPtr) - } - - // 再次获取缓冲区,应该重用之前的缓冲区 - for i := 0; i < 5; i++ { - bufPtr := bufferPool.Get().(*[]byte) - if len(*bufPtr) != 4096 { - t.Errorf("Expected reused buffer size 4096, got %d", len(*bufPtr)) - } - bufferPool.Put(bufPtr) - } -} - -// BenchmarkTeeStream 基准测试 -func BenchmarkTeeStream(b *testing.B) { - ctx := context.Background() - logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError})) - - testData := strings.Repeat("benchmark test data ", 1000) - - handle := func(ctx context.Context, data []byte) error { - return nil - } - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - reader := strings.NewReader(testData) - writer := io.Discard - - tee := NewTee(ctx, logger, reader, writer, handle) - err := tee.Stream() - if err != nil { - b.Fatalf("Stream() failed: %v", err) - } - tee.Close() - } -} - -// BenchmarkBufferPool 缓冲区池基准测试 -func BenchmarkBufferPool(b *testing.B) { - b.Run("WithPool", func(b *testing.B) { - for i := 0; i < b.N; i++ { - bufPtr := bufferPool.Get().(*[]byte) - // 模拟使用缓冲区 - _ = *bufPtr - bufferPool.Put(bufPtr) - } - }) - - b.Run("WithoutPool", func(b *testing.B) { - for i := 0; i < b.N; i++ { - buf := make([]byte, 4096) - // 模拟使用缓冲区 - _ = buf - } - }) -} - -// TestTeeClose 测试关闭功能 -func TestTeeClose(t *testing.T) { - ctx := context.Background() - logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError})) - - reader := strings.NewReader("test data") - writer := newMockWriter() - - handle := func(ctx context.Context, data []byte) error { - return nil - } - - tee := NewTee(ctx, logger, reader, writer, handle) - - // 测试多次关闭不会 panic - tee.Close() - tee.Close() - tee.Close() -}