Merge pull request #92 from yokowu/feat-proxyv2

feat(proxy): implement the proxy with ReverseProxy, further separating proxying from analysis logic
Authored by Yoko on 2025-07-16 16:04:58 +08:00, committed by GitHub
12 changed files with 496 additions and 1462 deletions
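
For orientation before the per-file diffs: the refactor replaces the hand-rolled HTTP client plumbing with net/http/httputil's ReverseProxy, whose Rewrite / ModifyResponse / ErrorHandler hooks are wired up in backend/internal/proxy. A minimal standalone sketch of that pattern (the upstream URL and wiring below are placeholders, not this repository's code):

package main

import (
    "log"
    "net/http"
    "net/http/httputil"
    "net/url"
    "time"
)

func main() {
    upstream, _ := url.Parse("https://api.example.com") // placeholder upstream
    rp := &httputil.ReverseProxy{
        Rewrite: func(r *httputil.ProxyRequest) {
            // Point the outgoing request at the chosen upstream; the real code
            // also selects a model by load balancing and injects its API key here.
            r.Out.URL.Scheme = upstream.Scheme
            r.Out.URL.Host = upstream.Host
            r.Out.Host = upstream.Host
            r.SetXForwarded()
        },
        ModifyResponse: func(resp *http.Response) error {
            // The real code wraps resp.Body here so recording/analysis happens
            // off the hot path while bytes stream back to the client.
            return nil
        },
        ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
            http.Error(w, err.Error(), http.StatusBadGateway)
        },
        FlushInterval: 100 * time.Millisecond, // flush SSE chunks promptly
    }
    log.Fatal(http.ListenAndServe(":8080", rp))
}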

View File

@@ -12,7 +12,6 @@ import (
"github.com/chaitin/MonkeyCode/backend/config"
"github.com/chaitin/MonkeyCode/backend/db"
"github.com/chaitin/MonkeyCode/backend/domain"
billingv1 "github.com/chaitin/MonkeyCode/backend/internal/billing/handler/http/v1"
dashv1 "github.com/chaitin/MonkeyCode/backend/internal/dashboard/handler/v1"
v1 "github.com/chaitin/MonkeyCode/backend/internal/model/handler/http/v1"
@@ -25,7 +24,6 @@ type Server struct {
web *web.Web
ent *db.Client
logger *slog.Logger
proxy domain.Proxy
openaiV1 *openaiV1.V1Handler
modelV1 *v1.ModelHandler
userV1 *userV1.UserHandler

View File

@@ -10,7 +10,6 @@ import (
"github.com/GoYoko/web"
"github.com/chaitin/MonkeyCode/backend/config"
"github.com/chaitin/MonkeyCode/backend/db"
"github.com/chaitin/MonkeyCode/backend/domain"
v1_5 "github.com/chaitin/MonkeyCode/backend/internal/billing/handler/http/v1"
repo7 "github.com/chaitin/MonkeyCode/backend/internal/billing/repo"
usecase6 "github.com/chaitin/MonkeyCode/backend/internal/billing/usecase"
@@ -56,7 +55,7 @@ func newServer() (*Server, error) {
proxyRepo := repo.NewProxyRepo(client)
modelRepo := repo2.NewModelRepo(client)
proxyUsecase := usecase.NewProxyUsecase(proxyRepo, modelRepo)
domainProxy := proxy.NewLLMProxy(proxyUsecase, configConfig, slogLogger)
llmProxy := proxy.NewLLMProxy(slogLogger, configConfig, proxyUsecase)
openAIRepo := repo3.NewOpenAIRepo(client)
openAIUsecase := openai.NewOpenAIUsecase(configConfig, openAIRepo, slogLogger)
extensionRepo := repo4.NewExtensionRepo(client)
@@ -64,7 +63,7 @@ func newServer() (*Server, error) {
proxyMiddleware := middleware.NewProxyMiddleware(proxyUsecase)
redisClient := store.NewRedisCli(configConfig)
activeMiddleware := middleware.NewActiveMiddleware(redisClient, slogLogger)
v1Handler := v1.NewV1Handler(slogLogger, web, domainProxy, openAIUsecase, extensionUsecase, proxyMiddleware, activeMiddleware, configConfig)
v1Handler := v1.NewV1Handler(slogLogger, web, llmProxy, proxyUsecase, openAIUsecase, extensionUsecase, proxyMiddleware, activeMiddleware, configConfig)
modelUsecase := usecase3.NewModelUsecase(slogLogger, modelRepo, configConfig)
sessionSession := session.NewSession(configConfig)
authMiddleware := middleware.NewAuthMiddleware(sessionSession, slogLogger)
@@ -83,7 +82,6 @@ func newServer() (*Server, error) {
web: web,
ent: client,
logger: slogLogger,
proxy: domainProxy,
openaiV1: v1Handler,
modelV1: modelHandler,
userV1: userHandler,
@@ -100,7 +98,6 @@ type Server struct {
web *web.Web
ent *db.Client
logger *slog.Logger
proxy domain.Proxy
openaiV1 *v1.V1Handler
modelV1 *v1_2.ModelHandler
userV1 *v1_3.UserHandler

View File

@@ -64,6 +64,7 @@ type Config struct {
Extension struct {
Baseurl string `mapstructure:"baseurl"`
Limit int `mapstructure:"limit"`
} `mapstructure:"extension"`
}
@@ -99,6 +100,7 @@ func Init() (*Config, error) {
v.SetDefault("init_model.key", "")
v.SetDefault("init_model.url", "https://model-square.app.baizhi.cloud/v1")
v.SetDefault("extension.baseurl", "https://release.baizhi.cloud")
v.SetDefault("extension.limit", 10)
c := Config{}
if err := v.Unmarshal(&c); err != nil {
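
The new extension.limit value follows the same path as the existing defaults: a viper SetDefault that Unmarshal maps onto the struct through its mapstructure tag. A compilable sketch of just that slice of the config (trimmed to the Extension block, not the full Config):

package main

import (
    "fmt"

    "github.com/spf13/viper"
)

type Config struct {
    Extension struct {
        Baseurl string `mapstructure:"baseurl"`
        Limit   int    `mapstructure:"limit"`
    } `mapstructure:"extension"`
}

func main() {
    v := viper.New()
    v.SetDefault("extension.baseurl", "https://release.baizhi.cloud")
    v.SetDefault("extension.limit", 10)

    var c Config
    if err := v.Unmarshal(&c); err != nil {
        panic(err)
    }
    fmt.Println(c.Extension.Limit) // 10 unless overridden by a config file or env
}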

View File

@@ -126,7 +126,12 @@ func (c *ChatContent) From(e *db.TaskRecord) *ChatContent {
return c
}
c.Role = e.Role
c.Content = e.Completion
switch e.Role {
case consts.ChatRoleUser:
c.Content = e.Prompt
case consts.ChatRoleAssistant:
c.Content = e.Completion
}
c.CreatedAt = e.CreatedAt.Unix()
return c
}

View File

@@ -3,10 +3,9 @@ package domain
import (
"context"
"github.com/rokku-c/go-openai"
"github.com/chaitin/MonkeyCode/backend/consts"
"github.com/chaitin/MonkeyCode/backend/db"
"github.com/rokku-c/go-openai"
)
type OpenAIUsecase interface {
@@ -21,7 +20,6 @@ type OpenAIRepo interface {
type CompletionRequest struct {
openai.CompletionRequest
Metadata map[string]string `json:"metadata"`
}

View File

@@ -6,27 +6,29 @@ import (
"net/http"
"github.com/labstack/echo/v4"
"github.com/rokku-c/go-openai"
"github.com/GoYoko/web"
"github.com/chaitin/MonkeyCode/backend/config"
"github.com/chaitin/MonkeyCode/backend/domain"
"github.com/chaitin/MonkeyCode/backend/internal/middleware"
"github.com/chaitin/MonkeyCode/backend/internal/proxy"
)
type V1Handler struct {
logger *slog.Logger
proxy domain.Proxy
usecase domain.OpenAIUsecase
euse domain.ExtensionUsecase
config *config.Config
logger *slog.Logger
proxy *proxy.LLMProxy
proxyUse domain.ProxyUsecase
usecase domain.OpenAIUsecase
euse domain.ExtensionUsecase
config *config.Config
}
func NewV1Handler(
logger *slog.Logger,
w *web.Web,
proxy domain.Proxy,
proxy *proxy.LLMProxy,
proxyUse domain.ProxyUsecase,
usecase domain.OpenAIUsecase,
euse domain.ExtensionUsecase,
middleware *middleware.ProxyMiddleware,
@@ -34,21 +36,23 @@ func NewV1Handler(
config *config.Config,
) *V1Handler {
h := &V1Handler{
logger: logger.With(slog.String("handler", "openai")),
proxy: proxy,
usecase: usecase,
euse: euse,
config: config,
logger: logger.With(slog.String("handler", "openai")),
proxy: proxy,
proxyUse: proxyUse,
usecase: usecase,
euse: euse,
config: config,
}
w.GET("/api/config", web.BindHandler(h.GetConfig), middleware.Auth())
w.GET("/v1/version", web.BaseHandler(h.Version), middleware.Auth())
g := w.Group("/v1", middleware.Auth())
g.GET("/models", web.BaseHandler(h.ModelList))
g.POST("/completion/accept", web.BindHandler(h.AcceptCompletion), active.Active())
g.POST("/chat/completions", web.BindHandler(h.ChatCompletion), active.Active())
g.POST("/completions", web.BindHandler(h.Completions), active.Active())
g.POST("/embeddings", web.BindHandler(h.Embeddings), active.Active())
g.POST("/chat/completions", web.BaseHandler(h.ChatCompletion), active.Active())
g.POST("/completions", web.BaseHandler(h.Completions), active.Active())
g.POST("/embeddings", web.BaseHandler(h.Embeddings), active.Active())
return h
}
@@ -86,7 +90,7 @@ func (h *V1Handler) Version(c *web.Context) error {
// @Success 200 {object} web.Resp{}
// @Router /v1/completion/accept [post]
func (h *V1Handler) AcceptCompletion(c *web.Context, req domain.AcceptCompletionReq) error {
if err := h.proxy.AcceptCompletion(c.Request().Context(), &req); err != nil {
if err := h.proxyUse.AcceptCompletion(c.Request().Context(), &req); err != nil {
return BadRequest(c, err.Error())
}
return nil
@@ -120,19 +124,8 @@ func (h *V1Handler) ModelList(c *web.Context) error {
// @Produce json
// @Success 200 {object} web.Resp{}
// @Router /v1/chat/completions [post]
func (h *V1Handler) ChatCompletion(c *web.Context, req openai.ChatCompletionRequest) error {
// TODO: 记录请求到文件
if req.Model == "" {
return BadRequest(c, "模型不能为空")
}
// if len(req.Tools) > 0 && req.Model != "qwen-max" {
// if h.toolsCall(c, req, req.Stream, req.Model) {
// return nil
// }
// }
h.proxy.HandleChatCompletion(c.Request().Context(), c.Response(), &req)
func (h *V1Handler) ChatCompletion(c *web.Context) error {
h.proxy.ServeHTTP(c.Response(), c.Request())
return nil
}
@@ -146,13 +139,8 @@ func (h *V1Handler) ChatCompletion(c *web.Context, req openai.ChatCompletionRequ
// @Produce json
// @Success 200 {object} web.Resp{}
// @Router /v1/completions [post]
func (h *V1Handler) Completions(c *web.Context, req domain.CompletionRequest) error {
// TODO: 记录请求到文件
if req.Model == "" {
return BadRequest(c, "模型不能为空")
}
h.logger.With("request", req).DebugContext(c.Request().Context(), "处理文本补全请求")
h.proxy.HandleCompletion(c.Request().Context(), c.Response(), req)
func (h *V1Handler) Completions(c *web.Context) error {
h.proxy.ServeHTTP(c.Response(), c.Request())
return nil
}
@@ -166,12 +154,8 @@ func (h *V1Handler) Completions(c *web.Context, req domain.CompletionRequest) er
// @Produce json
// @Success 200 {object} web.Resp{}
// @Router /v1/embeddings [post]
func (h *V1Handler) Embeddings(c *web.Context, req openai.EmbeddingRequest) error {
if req.Model == "" {
return BadRequest(c, "模型不能为空")
}
h.proxy.HandleEmbeddings(c.Request().Context(), c.Response(), &req)
func (h *V1Handler) Embeddings(c *web.Context) error {
h.proxy.ServeHTTP(c.Response(), c.Request())
return nil
}
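
The route changes above (BindHandler to BaseHandler for the three proxied endpoints) follow from the new design: once the ReverseProxy owns the request, the handler must not decode the body first, or the upstream would receive an already-consumed stream. A hedged sketch of the idea in plain net/http terms (the real handlers go through the project's web.Context wrappers):

package v1sketch

import (
    "net/http"
    "net/http/httputil"
)

// passthrough hands the untouched request straight to the ReverseProxy.
// Decoding the body first (as the old BindHandler routes did) would consume
// it; model selection and validation now happen in the proxy's Rewrite hook
// rather than in the handler.
func passthrough(p *httputil.ReverseProxy) http.HandlerFunc {
    return func(w http.ResponseWriter, r *http.Request) {
        p.ServeHTTP(w, r)
    }
}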

View File

@@ -1,840 +1,141 @@
package proxy
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"github.com/rokku-c/go-openai"
"github.com/chaitin/MonkeyCode/backend/config"
"github.com/chaitin/MonkeyCode/backend/consts"
"github.com/chaitin/MonkeyCode/backend/domain"
"github.com/chaitin/MonkeyCode/backend/pkg/cvt"
"github.com/chaitin/MonkeyCode/backend/pkg/logger"
"github.com/chaitin/MonkeyCode/backend/pkg/promptparser"
"github.com/chaitin/MonkeyCode/backend/pkg/request"
"github.com/chaitin/MonkeyCode/backend/pkg/tee"
)
// ErrResp 错误响应
type ErrResp struct {
Error struct {
Message string `json:"message"`
Type string `json:"type"`
Code string `json:"code,omitempty"`
} `json:"error"`
type CtxKey struct{}
type ProxyCtx struct {
ctx context.Context
Path string
Model *domain.Model
Header http.Header
RespHeader http.Header
ReqTee *tee.ReqTee
RequestID string
UserID string
}
// RequestResponseLog 请求响应日志结构
type RequestResponseLog struct {
Timestamp time.Time `json:"timestamp"`
RequestID string `json:"request_id"`
Endpoint string `json:"endpoint"`
ModelName string `json:"model_name"`
ModelType consts.ModelType `json:"model_type"`
UpstreamURL string `json:"upstream_url"`
RequestBody any `json:"request_body"`
RequestHeader map[string][]string `json:"request_header"`
StatusCode int `json:"status_code"`
ResponseBody any `json:"response_body"`
ResponseHeader map[string][]string `json:"response_header"`
Latency int64 `json:"latency_ms"`
Error string `json:"error,omitempty"`
IsStream bool `json:"is_stream"`
SourceIP string `json:"source_ip"` // 请求来源IP地址
}
// LLMProxy LLM API代理实现
type LLMProxy struct {
usecase domain.ProxyUsecase
cfg *config.Config
client *http.Client
streamClient *http.Client
logger *slog.Logger
requestLogPath string // 请求日志保存路径
logger *slog.Logger
cfg *config.Config
usecase domain.ProxyUsecase
transport *http.Transport
proxy *httputil.ReverseProxy
}
// NewLLMProxy 创建LLM API代理实例
func NewLLMProxy(
usecase domain.ProxyUsecase,
cfg *config.Config,
logger *slog.Logger,
) domain.Proxy {
// 解析超时时间
timeout, err := time.ParseDuration(cfg.LLMProxy.Timeout)
if err != nil {
timeout = 30 * time.Second
logger.Warn("解析超时时间失败, 使用默认值 30s", "error", err)
cfg *config.Config,
usecase domain.ProxyUsecase,
) *LLMProxy {
l := &LLMProxy{
logger: logger,
cfg: cfg,
usecase: usecase,
}
// 解析保持连接时间
keepAlive, err := time.ParseDuration(cfg.LLMProxy.KeepAlive)
if err != nil {
keepAlive = 60 * time.Second
logger.Warn("解析保持连接时间失败, 使用默认值 60s", "error", err)
l.transport = &http.Transport{
MaxIdleConns: cfg.LLMProxy.ClientPoolSize,
MaxConnsPerHost: cfg.LLMProxy.ClientPoolSize,
MaxIdleConnsPerHost: cfg.LLMProxy.ClientPoolSize,
IdleConnTimeout: 24 * time.Hour,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 24 * time.Hour,
}).DialContext,
}
client := &http.Client{
Timeout: timeout,
Transport: &http.Transport{
MaxIdleConns: cfg.LLMProxy.ClientPoolSize,
MaxConnsPerHost: cfg.LLMProxy.ClientPoolSize,
MaxIdleConnsPerHost: cfg.LLMProxy.ClientPoolSize,
IdleConnTimeout: keepAlive,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
DisableCompression: false,
ForceAttemptHTTP2: true,
},
}
streamClient := &http.Client{
Timeout: 60 * time.Second,
Transport: &http.Transport{
MaxIdleConns: cfg.LLMProxy.StreamClientPoolSize,
MaxConnsPerHost: cfg.LLMProxy.StreamClientPoolSize,
MaxIdleConnsPerHost: cfg.LLMProxy.StreamClientPoolSize,
IdleConnTimeout: 24 * time.Hour,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
},
}
// 获取日志配置
requestLogPath := ""
if cfg.LLMProxy.RequestLogPath != "" {
requestLogPath = cfg.LLMProxy.RequestLogPath
// 确保目录存在
if err := os.MkdirAll(requestLogPath, 0755); err != nil {
logger.Warn("创建请求日志目录失败", "error", err, "path", requestLogPath)
}
}
return &LLMProxy{
usecase: usecase,
client: client,
streamClient: streamClient,
cfg: cfg,
requestLogPath: requestLogPath,
logger: logger,
l.proxy = &httputil.ReverseProxy{
Transport: l.transport,
Rewrite: l.rewrite,
ModifyResponse: l.modifyResponse,
ErrorHandler: l.errorHandler,
FlushInterval: 100 * time.Millisecond,
}
return l
}
// saveRequestResponseLog 保存请求响应日志到文件
func (p *LLMProxy) saveRequestResponseLog(log *RequestResponseLog) {
if p.requestLogPath == "" {
func (l *LLMProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
l.proxy.ServeHTTP(w, r)
}
func (l *LLMProxy) Close() error {
l.transport.CloseIdleConnections()
return nil
}
var modelType = map[string]consts.ModelType{
"/v1/chat/completions": consts.ModelTypeLLM,
"/v1/completions": consts.ModelTypeCoder,
}
func (l *LLMProxy) rewrite(r *httputil.ProxyRequest) {
l.logger.DebugContext(r.In.Context(), "rewrite request", slog.String("path", r.In.URL.Path))
mt, ok := modelType[r.In.URL.Path]
if !ok {
l.logger.Error("model type not found", slog.String("path", r.In.URL.Path))
return
}
// 创建文件名格式YYYYMMDD_HHMMSS_请求ID.json
timestamp := log.Timestamp.Format("20060102_150405") + fmt.Sprintf("_%03d", log.Timestamp.Nanosecond()/1e6)
filename := fmt.Sprintf("%s_%s.json", timestamp, log.RequestID)
filepath := filepath.Join(p.requestLogPath, filename)
// 将日志序列化为JSON
logData, err := json.MarshalIndent(log, "", " ")
m, err := l.usecase.SelectModelWithLoadBalancing("", mt)
if err != nil {
p.logger.Error("序列化请求日志失败", "error", err)
l.logger.Error("select model with load balancing failed", slog.String("path", r.In.URL.Path), slog.Any("err", err))
return
}
// 写入文件
if err := os.WriteFile(filepath, logData, 0644); err != nil {
p.logger.Error("写入请求日志文件失败", "error", err, "path", filepath)
return
}
p.logger.Debug("请求响应日志已保存", "path", filepath)
}
func (p *LLMProxy) AcceptCompletion(ctx context.Context, req *domain.AcceptCompletionReq) error {
return p.usecase.AcceptCompletion(ctx, req)
}
func writeErrResp(w http.ResponseWriter, code int, message, errorType string) {
resp := ErrResp{
Error: struct {
Message string `json:"message"`
Type string `json:"type"`
Code string `json:"code,omitempty"`
}{
Message: message,
Type: errorType,
},
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
b, _ := json.Marshal(resp)
w.Write(b)
}
type Ctx struct {
UserID string
RequestID string
SourceIP string
}
func (p *LLMProxy) handle(ctx context.Context, fn func(ctx *Ctx, log *RequestResponseLog) error) {
// 获取用户ID
userID := "unknown"
if id, ok := ctx.Value(logger.UserIDKey{}).(string); ok {
userID = id
}
requestID := "unknown"
if id, ok := ctx.Value(logger.RequestIDKey{}).(string); ok {
requestID = id
}
sourceip := "unknown"
if ip, ok := ctx.Value("remote_addr").(string); ok {
sourceip = ip
}
// 创建请求日志结构
l := &RequestResponseLog{
Timestamp: time.Now(),
RequestID: requestID,
ModelType: consts.ModelTypeCoder,
SourceIP: sourceip,
}
c := &Ctx{
UserID: userID,
RequestID: requestID,
SourceIP: sourceip,
}
if err := fn(c, l); err != nil {
p.logger.With("source_ip", sourceip).ErrorContext(ctx, "处理请求失败", "error", err)
l.Error = err.Error()
}
go p.saveRequestResponseLog(l)
}
func (p *LLMProxy) HandleCompletion(ctx context.Context, w http.ResponseWriter, req domain.CompletionRequest) {
if req.Stream {
p.handleCompletionStream(ctx, w, req)
} else {
p.handleCompletion(ctx, w, req)
}
}
func (p *LLMProxy) handleCompletionStream(ctx context.Context, w http.ResponseWriter, req domain.CompletionRequest) {
endpoint := "/completions"
p.handle(ctx, func(c *Ctx, log *RequestResponseLog) error {
// 使用负载均衡算法选择模型
m, err := p.usecase.SelectModelWithLoadBalancing(req.Model, consts.ModelTypeCoder)
if err != nil {
p.logger.With("modelName", req.Model, "modelType", consts.ModelTypeCoder).WarnContext(ctx, "模型选择失败", "error", err)
writeErrResp(w, http.StatusNotFound, "模型未找到", "proxy_error")
return err
}
// 构造上游API URL
upstream := m.APIBase + endpoint
log.UpstreamURL = upstream
startTime := time.Now()
// 创建上游请求
body, err := json.Marshal(req)
if err != nil {
p.logger.ErrorContext(ctx, "序列化请求体失败", "error", err)
return fmt.Errorf("序列化请求体失败: %w", err)
}
upreq, err := http.NewRequestWithContext(ctx, http.MethodPost, upstream, bytes.NewReader(body))
if err != nil {
p.logger.With("upstream", upstream).WarnContext(ctx, "创建上游流式请求失败", "error", err)
return fmt.Errorf("创建上游请求失败: %w", err)
}
// 设置请求头
upreq.Header.Set("Content-Type", "application/json")
upreq.Header.Set("Accept", "text/event-stream")
if m.APIKey != "" && m.APIKey != "none" {
upreq.Header.Set("Authorization", "Bearer "+m.APIKey)
}
// 保存请求头(去除敏感信息)
requestHeaders := make(map[string][]string)
for k, v := range upreq.Header {
if k != "Authorization" {
requestHeaders[k] = v
} else {
// 敏感信息脱敏
requestHeaders[k] = []string{"Bearer ***"}
}
}
log.RequestHeader = requestHeaders
p.logger.With(
"upstreamURL", upstream,
"modelName", m.ModelName,
"modelType", consts.ModelTypeLLM,
"apiBase", m.APIBase,
"requestHeader", upreq.Header,
"requestBody", upreq,
).DebugContext(ctx, "转发流式请求到上游API")
// 发送请求
resp, err := p.client.Do(upreq)
if err != nil {
p.logger.With("upstreamURL", upstream).WarnContext(ctx, "发送上游流式请求失败", "error", err)
return fmt.Errorf("发送上游请求失败: %w", err)
}
defer resp.Body.Close()
// 检查响应状态
if resp.StatusCode != http.StatusOK {
responseBody, _ := io.ReadAll(resp.Body)
// 更新日志错误信息
log.StatusCode = resp.StatusCode
log.ResponseHeader = resp.Header
log.ResponseBody = string(responseBody)
log.Latency = time.Since(startTime).Milliseconds()
// 在debug级别记录错误的流式响应内容
p.logger.With(
"statusCode", resp.StatusCode,
"responseHeader", resp.Header,
"responseBody", string(responseBody),
).DebugContext(ctx, "上游流式响应错误原始内容")
var errorResp ErrResp
if err := json.Unmarshal(responseBody, &errorResp); err == nil {
p.logger.With(
"endpoint", endpoint,
"upstreamURL", upstream,
"requestBody", upreq,
"statusCode", resp.StatusCode,
"errorType", errorResp.Error.Type,
"errorCode", errorResp.Error.Code,
"errorMessage", errorResp.Error.Message,
"latency", time.Since(startTime),
).WarnContext(ctx, "上游API流式请求异常详情")
return fmt.Errorf("上游API返回错误: %s", errorResp.Error.Message)
}
p.logger.With(
"endpoint", endpoint,
"upstreamURL", upstream,
"requestBody", upreq,
"statusCode", resp.StatusCode,
"responseBody", string(responseBody),
).WarnContext(ctx, "上游API流式请求异常详情")
return fmt.Errorf("上游API返回非200状态码: %d, 响应: %s", resp.StatusCode, string(responseBody))
}
// 更新日志信息
log.StatusCode = resp.StatusCode
log.ResponseHeader = resp.Header
// 在debug级别记录流式响应头信息
p.logger.With(
"statusCode", resp.StatusCode,
"responseHeader", resp.Header,
).DebugContext(ctx, "上游流式响应头信息")
// 设置响应头
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Transfer-Encoding", "chunked")
rc := &domain.RecordParam{
UserID: c.UserID,
ModelID: m.ID,
ModelType: consts.ModelTypeLLM,
Prompt: req.Prompt.(string),
Role: consts.ChatRoleAssistant,
}
buf := bufio.NewWriterSize(w, 32*1024)
defer buf.Flush()
ch := make(chan []byte, 1024)
defer close(ch)
go func(rc *domain.RecordParam) {
for line := range ch {
if bytes.HasPrefix(line, []byte("data:")) {
line = bytes.TrimPrefix(line, []byte("data: "))
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
if bytes.Equal(line, []byte("[DONE]")) {
break
}
var t openai.CompletionResponse
if err := json.Unmarshal(line, &t); err != nil {
p.logger.With("line", string(line)).WarnContext(ctx, "解析流式数据失败", "error", err)
continue
}
p.logger.With("response", t).DebugContext(ctx, "流式响应数据")
if len(t.Choices) > 0 {
rc.Completion += t.Choices[0].Text
}
rc.InputTokens = int64(t.Usage.PromptTokens)
rc.OutputTokens += int64(t.Usage.CompletionTokens)
}
}
if rc.OutputTokens == 0 {
return
}
p.logger.With("record", rc).DebugContext(ctx, "流式记录")
if err := p.usecase.Record(context.Background(), rc); err != nil {
p.logger.With("modelID", m.ID, "modelName", m.ModelName, "modelType", consts.ModelTypeLLM).
WarnContext(ctx, "插入流式记录失败", "error", err)
}
}(rc)
err = streamRead(ctx, resp.Body, func(line []byte) error {
ch <- line
if _, err := buf.Write(line); err != nil {
return fmt.Errorf("写入响应失败: %w", err)
}
return buf.Flush()
if r.In.ContentLength > 0 {
tee := tee.NewReqTeeWithMaxSize(r.In.Body, 10*1024*1024)
r.Out.Body = tee
ctx := context.WithValue(r.In.Context(), CtxKey{}, &ProxyCtx{
ctx: r.In.Context(),
Path: r.In.URL.Path,
Model: m,
ReqTee: tee,
RequestID: r.In.Context().Value(logger.RequestIDKey{}).(string),
UserID: r.In.Context().Value(logger.UserIDKey{}).(string),
Header: r.In.Header,
})
return err
})
}
r.Out = r.Out.WithContext(ctx)
}
func (p *LLMProxy) handleCompletion(ctx context.Context, w http.ResponseWriter, req domain.CompletionRequest) {
endpoint := "/completions"
p.handle(ctx, func(c *Ctx, log *RequestResponseLog) error {
// 使用负载均衡算法选择模型
m, err := p.usecase.SelectModelWithLoadBalancing(req.Model, consts.ModelTypeCoder)
if err != nil {
p.logger.With("modelName", req.Model, "modelType", consts.ModelTypeCoder).WarnContext(ctx, "模型选择失败", "error", err)
writeErrResp(w, http.StatusNotFound, "模型未找到", "proxy_error")
return err
}
// 构造上游API URL
upstream := m.APIBase + endpoint
log.UpstreamURL = upstream
u, err := url.Parse(upstream)
if err != nil {
p.logger.With("upstreamURL", upstream).WarnContext(ctx, "解析上游URL失败", "error", err)
writeErrResp(w, http.StatusInternalServerError, "无效的上游URL", "proxy_error")
return err
}
startTime := time.Now()
client := request.NewClient(
u.Scheme,
u.Host,
30*time.Second,
request.WithClient(p.client),
)
client.SetDebug(p.cfg.Debug)
resp, err := request.Post[openai.CompletionResponse](client, u.Path, req, request.WithHeader(request.Header{
"Authorization": "Bearer " + m.APIKey,
}))
if err != nil {
p.logger.With("upstreamURL", upstream).WarnContext(ctx, "发送上游请求失败", "error", err)
writeErrResp(w, http.StatusInternalServerError, "发送上游请求失败", "proxy_error")
log.Latency = time.Since(startTime).Milliseconds()
return err
}
latency := time.Since(startTime)
// 记录请求信息
p.logger.With(
"statusCode", http.StatusOK,
"upstreamURL", upstream,
"modelName", req.Model,
"modelType", consts.ModelTypeCoder,
"apiBase", m.APIBase,
"requestBody", req,
"resp", resp,
"latency", latency.String(),
).DebugContext(ctx, "转发请求到上游API")
go p.recordCompletion(c, m.ID, req, resp)
// 更新请求日志
log.StatusCode = http.StatusOK
log.ResponseBody = resp
log.ResponseHeader = resp.Header()
log.Latency = latency.Milliseconds()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
b, err := json.Marshal(resp)
if err != nil {
p.logger.With("response", resp).WarnContext(ctx, "序列化响应失败", "error", err)
return err
}
w.Write(b)
return nil
})
}
func (p *LLMProxy) recordCompletion(c *Ctx, modelID string, req domain.CompletionRequest, resp *openai.CompletionResponse) {
if resp.Usage.CompletionTokens == 0 {
u, err := url.Parse(m.APIBase)
if err != nil {
l.logger.ErrorContext(r.In.Context(), "parse model api base failed", slog.String("path", r.In.URL.Path), slog.Any("err", err))
return
}
ctx := context.Background()
prompt := req.Prompt.(string)
rc := &domain.RecordParam{
TaskID: resp.ID,
UserID: c.UserID,
ModelID: modelID,
ModelType: consts.ModelTypeCoder,
Prompt: prompt,
ProgramLanguage: req.Metadata["program_language"],
InputTokens: int64(resp.Usage.PromptTokens),
OutputTokens: int64(resp.Usage.CompletionTokens),
}
for _, choice := range resp.Choices {
rc.Completion += choice.Text
}
lines := strings.Count(rc.Completion, "\n") + 1
rc.CodeLines = int64(lines)
if err := p.usecase.Record(ctx, rc); err != nil {
p.logger.With("modelID", modelID, "modelName", req.Model, "modelType", consts.ModelTypeCoder).WarnContext(ctx, "记录请求失败", "error", err)
r.Out.URL.Scheme = u.Scheme
r.Out.URL.Host = u.Host
r.Out.URL.Path = r.In.URL.Path
r.Out.Header.Set("Authorization", "Bearer "+m.APIKey)
r.SetXForwarded()
r.Out.Host = u.Host
l.logger.With("in", r.In.URL.Path, "out", r.Out.URL.Path).DebugContext(r.In.Context(), "rewrite request")
}
func (l *LLMProxy) modifyResponse(resp *http.Response) error {
ctx := resp.Request.Context()
if pctx, ok := ctx.Value(CtxKey{}).(*ProxyCtx); ok {
pctx.ctx = ctx
pctx.RespHeader = resp.Header
resp.Body = NewRecorder(l.cfg, pctx, resp.Body, l.logger, l.usecase)
}
return nil
}
func (p *LLMProxy) HandleChatCompletion(ctx context.Context, w http.ResponseWriter, req *openai.ChatCompletionRequest) {
if req.Stream {
p.handleChatCompletionStream(ctx, w, req)
} else {
p.handleChatCompletion(ctx, w, req)
}
}
func streamRead(ctx context.Context, r io.Reader, fn func([]byte) error) error {
reader := bufio.NewReaderSize(r, 32*1024)
for {
select {
case <-ctx.Done():
return fmt.Errorf("流式请求被取消: %w", ctx.Err())
default:
line, err := reader.ReadBytes('\n')
if err == io.EOF {
return nil
}
if err != nil {
return fmt.Errorf("读取流式数据失败: %w", err)
}
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
line = append(line, '\n')
line = append(line, '\n')
fn(line)
}
}
}
func (p *LLMProxy) getPrompt(ctx context.Context, req *openai.ChatCompletionRequest) string {
prompt := ""
parse := promptparser.New(promptparser.KindTask)
for _, message := range req.Messages {
if message.Role == "system" {
continue
}
if strings.Contains(message.Content, "<task>") ||
strings.Contains(message.Content, "<feedback>") ||
strings.Contains(message.Content, "<user_message>") {
if info, err := parse.Parse(message.Content); err == nil {
prompt = info.Prompt
} else {
p.logger.With("message", message.Content).WarnContext(ctx, "解析Prompt失败", "error", err)
}
}
for _, m := range message.MultiContent {
if strings.Contains(m.Text, "<task>") ||
strings.Contains(m.Text, "<feedback>") ||
strings.Contains(m.Text, "<user_message>") {
if info, err := parse.Parse(m.Text); err == nil {
prompt = info.Prompt
} else {
p.logger.With("message", m.Text).WarnContext(ctx, "解析Prompt失败", "error", err)
}
}
}
}
return prompt
}
func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.ResponseWriter, req *openai.ChatCompletionRequest) {
endpoint := "/chat/completions"
p.handle(ctx, func(c *Ctx, log *RequestResponseLog) error {
startTime := time.Now()
m, err := p.usecase.SelectModelWithLoadBalancing(req.Model, consts.ModelTypeLLM)
if err != nil {
p.logger.With("modelName", req.Model, "modelType", consts.ModelTypeLLM).WarnContext(ctx, "流式请求模型选择失败", "error", err)
writeErrResp(w, http.StatusNotFound, "模型未找到", "proxy_error")
return err
}
upstream := m.APIBase + endpoint
log.UpstreamURL = upstream
body, err := json.Marshal(req)
if err != nil {
p.logger.ErrorContext(ctx, "序列化请求体失败", "error", err)
return fmt.Errorf("序列化请求体失败: %w", err)
}
newReq, err := http.NewRequestWithContext(ctx, http.MethodPost, upstream, bytes.NewReader(body))
if err != nil {
p.logger.With("upstream", upstream).WarnContext(ctx, "创建上游流式请求失败", "error", err)
return fmt.Errorf("创建上游请求失败: %w", err)
}
newReq.Header.Set("Content-Type", "application/json")
newReq.Header.Set("Accept", "text/event-stream")
newReq.Header.Set("Authorization", "Bearer "+m.APIKey)
// 保存请求头(去除敏感信息)
requestHeaders := make(map[string][]string)
for k, v := range newReq.Header {
if k != "Authorization" {
requestHeaders[k] = v
} else {
// 敏感信息脱敏
requestHeaders[k] = []string{"Bearer ***"}
}
}
log.RequestHeader = requestHeaders
logger := p.logger.With(
"request_id", c.RequestID,
"source_ip", c.SourceIP,
"upstreamURL", upstream,
"modelName", m.ModelName,
"modelType", consts.ModelTypeLLM,
"apiBase", m.APIBase,
)
logger.With(
"upstreamURL", upstream,
"requestHeader", newReq.Header,
"requestBody", req,
"messages", cvt.Filter(req.Messages, func(i int, v openai.ChatCompletionMessage) (openai.ChatCompletionMessage, bool) {
return v, v.Role != "system"
}),
).DebugContext(ctx, "转发流式请求到上游API")
// 发送请求
resp, err := p.streamClient.Do(newReq)
if err != nil {
p.logger.With("upstreamURL", upstream).WarnContext(ctx, "发送上游流式请求失败", "error", err)
return fmt.Errorf("发送上游请求失败: %w", err)
}
defer resp.Body.Close()
// 检查响应状态
if resp.StatusCode != http.StatusOK {
responseBody, _ := io.ReadAll(resp.Body)
// 更新日志错误信息
log.StatusCode = resp.StatusCode
log.ResponseHeader = resp.Header
log.ResponseBody = string(responseBody)
log.Latency = time.Since(startTime).Milliseconds()
// 在debug级别记录错误的流式响应内容
logger.With(
"statusCode", resp.StatusCode,
"responseHeader", resp.Header,
"responseBody", string(responseBody),
).DebugContext(ctx, "上游流式响应错误原始内容")
var errorResp ErrResp
if err := json.Unmarshal(responseBody, &errorResp); err == nil {
logger.With(
"endpoint", endpoint,
"requestBody", newReq,
"statusCode", resp.StatusCode,
"errorType", errorResp.Error.Type,
"errorCode", errorResp.Error.Code,
"errorMessage", errorResp.Error.Message,
"latency", time.Since(startTime),
).WarnContext(ctx, "上游API流式请求异常详情")
return fmt.Errorf("上游API返回错误: %s", errorResp.Error.Message)
}
logger.With(
"endpoint", endpoint,
"requestBody", newReq,
"statusCode", resp.StatusCode,
"responseBody", string(responseBody),
).WarnContext(ctx, "上游API流式请求异常详情")
return fmt.Errorf("上游API返回非200状态码: %d, 响应: %s", resp.StatusCode, string(responseBody))
}
log.StatusCode = resp.StatusCode
log.ResponseHeader = resp.Header
logger.With(
"statusCode", resp.StatusCode,
"responseHeader", resp.Header,
).DebugContext(ctx, "上游流式响应头信息")
// 设置响应头
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Transfer-Encoding", "chunked")
w.Header().Set("X-Accel-Buffering", "no")
recorder := NewChatRecorder(
ctx,
c,
p.usecase,
m,
req,
resp.Body,
w,
p.logger.With("module", "ChatRecorder"),
)
defer recorder.Close()
return recorder.Stream()
})
}
func (p *LLMProxy) handleChatCompletion(ctx context.Context, w http.ResponseWriter, req *openai.ChatCompletionRequest) {
endpoint := "/chat/completions"
p.handle(ctx, func(c *Ctx, log *RequestResponseLog) error {
// 使用负载均衡算法选择模型
m, err := p.usecase.SelectModelWithLoadBalancing(req.Model, consts.ModelTypeCoder)
if err != nil {
p.logger.With("modelName", req.Model, "modelType", consts.ModelTypeCoder).WarnContext(ctx, "模型选择失败", "error", err)
writeErrResp(w, http.StatusNotFound, "模型未找到", "proxy_error")
return err
}
// 构造上游API URL
upstream := m.APIBase + endpoint
log.UpstreamURL = upstream
u, err := url.Parse(upstream)
if err != nil {
p.logger.With("upstreamURL", upstream).WarnContext(ctx, "解析上游URL失败", "error", err)
writeErrResp(w, http.StatusInternalServerError, "无效的上游URL", "proxy_error")
return err
}
startTime := time.Now()
prompt := p.getPrompt(ctx, req)
mode := req.Metadata["mode"]
taskID := req.Metadata["task_id"]
client := request.NewClient(
u.Scheme,
u.Host,
30*time.Second,
request.WithClient(p.client),
)
resp, err := request.Post[openai.ChatCompletionResponse](client, u.Path, req, request.WithHeader(request.Header{
"Authorization": "Bearer " + m.APIKey,
}))
if err != nil {
p.logger.With("upstreamURL", upstream).WarnContext(ctx, "发送上游请求失败", "error", err)
writeErrResp(w, http.StatusInternalServerError, "发送上游请求失败", "proxy_error")
log.Latency = time.Since(startTime).Milliseconds()
return err
}
// 记录请求信息
p.logger.With(
"upstreamURL", upstream,
"modelName", req.Model,
"modelType", consts.ModelTypeCoder,
"apiBase", m.APIBase,
"requestBody", req,
).DebugContext(ctx, "转发请求到上游API")
go func() {
rc := &domain.RecordParam{
TaskID: taskID,
UserID: c.UserID,
Prompt: prompt,
WorkMode: mode,
ModelID: m.ID,
ModelType: m.ModelType,
InputTokens: int64(resp.Usage.PromptTokens),
OutputTokens: int64(resp.Usage.CompletionTokens),
ProgramLanguage: req.Metadata["program_language"],
}
for _, choice := range resp.Choices {
rc.Completion += choice.Message.Content + "\n\n"
}
p.logger.With("record", rc).DebugContext(ctx, "记录")
if err := p.usecase.Record(ctx, rc); err != nil {
p.logger.With("modelID", m.ID, "modelName", req.Model, "modelType", consts.ModelTypeCoder).WarnContext(ctx, "记录请求失败", "error", err)
}
}()
// 计算请求耗时
latency := time.Since(startTime)
// 更新请求日志
log.StatusCode = http.StatusOK
log.ResponseBody = resp
log.ResponseHeader = resp.Header()
log.Latency = latency.Milliseconds()
// 记录响应状态
p.logger.With(
"statusCode", http.StatusOK,
"responseHeader", resp.Header(),
"responseBody", resp,
"latency", latency.String(),
).DebugContext(ctx, "上游API响应")
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
b, err := json.Marshal(resp)
if err != nil {
p.logger.With("response", resp).WarnContext(ctx, "序列化响应失败", "error", err)
return err
}
w.Write(b)
return nil
})
}
func (p *LLMProxy) HandleEmbeddings(ctx context.Context, w http.ResponseWriter, req *openai.EmbeddingRequest) {
func (l *LLMProxy) errorHandler(w http.ResponseWriter, r *http.Request, err error) {
l.logger.ErrorContext(r.Context(), "error handler", slog.String("path", r.URL.Path), slog.Any("err", err))
}
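
One detail worth spelling out from the new proxy.go: per-request state (the selected model, the request tee) is attached to the outgoing request's context in rewrite and picked up again in modifyResponse via resp.Request.Context(). A reduced sketch of that hand-off with placeholder types:

package proxysketch

import (
    "context"
    "net/http"
    "net/http/httputil"
)

type ctxKey struct{}

type proxyCtx struct {
    Path  string
    Model string // placeholder for the selected *domain.Model
}

func newProxy() *httputil.ReverseProxy {
    return &httputil.ReverseProxy{
        Rewrite: func(r *httputil.ProxyRequest) {
            pc := &proxyCtx{Path: r.In.URL.Path, Model: "selected-by-load-balancer"}
            // State stored on r.Out resurfaces in ModifyResponse as
            // resp.Request.Context(), because resp.Request is the outgoing request.
            r.Out = r.Out.WithContext(context.WithValue(r.In.Context(), ctxKey{}, pc))
        },
        ModifyResponse: func(resp *http.Response) error {
            if pc, ok := resp.Request.Context().Value(ctxKey{}).(*proxyCtx); ok {
                _ = pc // the real code wraps resp.Body in a Recorder here
            }
            return nil
        },
    }
}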

View File

@@ -5,191 +5,331 @@ import (
"encoding/json"
"io"
"log/slog"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/rokku-c/go-openai"
"github.com/chaitin/MonkeyCode/backend/config"
"github.com/chaitin/MonkeyCode/backend/consts"
"github.com/chaitin/MonkeyCode/backend/domain"
"github.com/chaitin/MonkeyCode/backend/pkg/promptparser"
"github.com/chaitin/MonkeyCode/backend/pkg/tee"
)
type ChatRecorder struct {
*tee.Tee
ctx context.Context
cx *Ctx
usecase domain.ProxyUsecase
req *openai.ChatCompletionRequest
logger *slog.Logger
model *domain.Model
completion strings.Builder // 累积完整的响应内容
usage *openai.Usage // 最终的使用统计
buffer strings.Builder // 缓存不完整的行
recorded bool // 标记是否已记录完整对话
type Recorder struct {
cfg *config.Config
usecase domain.ProxyUsecase
shadown chan []byte
src io.ReadCloser
ctx *ProxyCtx
logger *slog.Logger
logFile *os.File
}
func NewChatRecorder(
ctx context.Context,
cx *Ctx,
usecase domain.ProxyUsecase,
model *domain.Model,
req *openai.ChatCompletionRequest,
r io.Reader,
w io.Writer,
var _ io.ReadCloser = &Recorder{}
func NewRecorder(
cfg *config.Config,
ctx *ProxyCtx,
src io.ReadCloser,
logger *slog.Logger,
) *ChatRecorder {
c := &ChatRecorder{
ctx: ctx,
cx: cx,
usecase domain.ProxyUsecase,
) *Recorder {
r := &Recorder{
cfg: cfg,
usecase: usecase,
model: model,
req: req,
shadown: make(chan []byte, 128*1024),
src: src,
ctx: ctx,
logger: logger,
}
c.Tee = tee.NewTee(ctx, logger, r, w, c.handle)
return c
go r.handleShadow()
return r
}
func (c *ChatRecorder) handle(ctx context.Context, data []byte) error {
c.buffer.Write(data)
bufferContent := c.buffer.String()
func formatHeader(header http.Header) map[string]string {
headerMap := make(map[string]string)
for key, values := range header {
headerMap[key] = strings.Join(values, ",")
}
return headerMap
}
lines := strings.Split(bufferContent, "\n")
if len(lines) > 0 {
lastLine := lines[len(lines)-1]
lines = lines[:len(lines)-1]
c.buffer.Reset()
c.buffer.WriteString(lastLine)
func (r *Recorder) handleShadow() {
body, err := r.ctx.ReqTee.GetBody()
if err != nil {
r.logger.WarnContext(r.ctx.ctx, "get req tee body failed", "error", err)
return
}
for _, line := range lines {
if err := c.processSSELine(ctx, line); err != nil {
return err
var (
taskID, mode, prompt, language string
)
switch r.ctx.Model.ModelType {
case consts.ModelTypeLLM:
var req openai.ChatCompletionRequest
if err := json.Unmarshal(body, &req); err != nil {
r.logger.WarnContext(r.ctx.ctx, "unmarshal chat completion request failed", "error", err)
return
}
prompt = r.getPrompt(r.ctx.ctx, &req)
taskID = req.Metadata["task_id"]
mode = req.Metadata["mode"]
case consts.ModelTypeCoder:
var req domain.CompletionRequest
if err := json.Unmarshal(body, &req); err != nil {
r.logger.WarnContext(r.ctx.ctx, "unmarshal completion request failed", "error", err)
return
}
prompt = req.Prompt.(string)
taskID = req.Metadata["task_id"]
mode = req.Metadata["mode"]
language = req.Metadata["program_language"]
default:
r.logger.WarnContext(r.ctx.ctx, "skip handle shadow, model type not support", "modelType", r.ctx.Model.ModelType)
return
}
r.createFile(taskID)
r.writeMeta(body)
rc := &domain.RecordParam{
RequestID: r.ctx.RequestID,
TaskID: taskID,
UserID: r.ctx.UserID,
ModelID: r.ctx.Model.ID,
ModelType: r.ctx.Model.ModelType,
WorkMode: mode,
Prompt: prompt,
ProgramLanguage: language,
Role: consts.ChatRoleUser,
}
var assistantRc *domain.RecordParam
ct := r.ctx.RespHeader.Get("Content-Type")
if strings.Contains(ct, "stream") {
r.handleStream(rc)
if r.ctx.Model.ModelType == consts.ModelTypeLLM {
assistantRc = rc.Clone()
assistantRc.Role = consts.ChatRoleAssistant
rc.Completion = ""
rc.OutputTokens = 0
}
} else {
r.handleJson(rc)
}
r.logger.
With("header", formatHeader(r.ctx.Header)).
With("resp_header", formatHeader(r.ctx.RespHeader)).
DebugContext(r.ctx.ctx, "handle shadow", "rc", rc)
if err := r.usecase.Record(context.Background(), rc); err != nil {
r.logger.WarnContext(r.ctx.ctx, "记录请求失败", "error", err)
}
if assistantRc != nil {
if err := r.usecase.Record(context.Background(), assistantRc); err != nil {
r.logger.WarnContext(r.ctx.ctx, "记录请求失败", "error", err)
}
}
}
func (r *Recorder) writeMeta(body []byte) {
if r.logFile == nil {
return
}
r.logFile.WriteString("------------------ Request Header ------------------\n")
for key, value := range formatHeader(r.ctx.Header) {
r.logFile.WriteString(key + ": " + value + "\n")
}
r.logFile.WriteString("------------------ Request Body ------------------\n")
r.logFile.WriteString(string(body))
r.logFile.WriteString("\n")
r.logFile.WriteString("------------------ Response Header ------------------\n")
for key, value := range formatHeader(r.ctx.RespHeader) {
r.logFile.WriteString(key + ": " + value + "\n")
}
r.logFile.WriteString("------------------ Response Body ------------------\n")
}
func (r *Recorder) createFile(taskID string) {
if r.cfg.LLMProxy.RequestLogPath == "" {
return
}
dir := filepath.Join(r.cfg.LLMProxy.RequestLogPath, time.Now().Format("2006010215"))
if err := os.MkdirAll(dir, 0755); err != nil {
r.logger.WarnContext(r.ctx.ctx, "create dir failed", "error", err)
return
}
id := r.ctx.RequestID
if taskID != "" {
id = taskID
}
filename := filepath.Join(dir, id+".log")
f, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
r.logger.WarnContext(r.ctx.ctx, "create file failed", "error", err)
return
}
r.logFile = f
}
func (r *Recorder) handleJson(rc *domain.RecordParam) {
buffer := strings.Builder{}
for data := range r.shadown {
buffer.Write(data)
if r.logFile != nil {
r.logFile.Write(data)
}
}
return nil
switch rc.ModelType {
case consts.ModelTypeLLM:
var resp openai.ChatCompletionResponse
if err := json.Unmarshal([]byte(buffer.String()), &resp); err != nil {
r.logger.WarnContext(r.ctx.ctx, "unmarshal chat completion response failed", "error", err)
return
}
if len(resp.Choices) > 0 {
rc.Completion = resp.Choices[0].Message.Content
rc.InputTokens = int64(resp.Usage.PromptTokens)
rc.OutputTokens = int64(resp.Usage.CompletionTokens)
}
case consts.ModelTypeCoder:
var resp openai.CompletionResponse
if err := json.Unmarshal([]byte(buffer.String()), &resp); err != nil {
r.logger.WarnContext(r.ctx.ctx, "unmarshal completion response failed", "error", err)
return
}
if rc.TaskID == "" {
rc.TaskID = resp.ID
}
rc.InputTokens = int64(resp.Usage.PromptTokens)
rc.OutputTokens = int64(resp.Usage.CompletionTokens)
if len(resp.Choices) > 0 {
rc.Completion = resp.Choices[0].Text
rc.CodeLines = int64(strings.Count(resp.Choices[0].Text, "\n"))
}
}
}
func (c *ChatRecorder) processSSELine(ctx context.Context, line string) error {
line = strings.TrimSpace(line)
func (r *Recorder) handleStream(rc *domain.RecordParam) {
buffer := strings.Builder{}
for data := range r.shadown {
buffer.Write(data)
cnt := buffer.String()
if r.logFile != nil {
r.logFile.Write(data)
}
lines := strings.Split(cnt, "\n")
if len(lines) > 0 {
lastLine := lines[len(lines)-1]
lines = lines[:len(lines)-1]
buffer.Reset()
buffer.WriteString(lastLine)
}
for _, line := range lines {
if err := r.processSSELine(r.ctx.ctx, line, rc); err != nil {
r.logger.WarnContext(r.ctx.ctx, "处理SSE行失败", "error", err)
}
}
}
}
func (r *Recorder) processSSELine(ctx context.Context, line string, rc *domain.RecordParam) error {
line = strings.TrimSpace(line)
if line == "" || !strings.HasPrefix(line, "data:") {
return nil
}
dataContent := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
if dataContent == "" {
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
if data == "" {
return nil
}
if dataContent == "[DONE]" {
c.processCompletedChat(ctx)
if data == "[DONE]" {
return nil
}
var resp openai.ChatCompletionStreamResponse
if err := json.Unmarshal([]byte(dataContent), &resp); err != nil {
c.logger.With("data", dataContent).WarnContext(ctx, "解析流式响应失败", "error", err)
return nil
}
switch r.ctx.Model.ModelType {
case consts.ModelTypeLLM:
var resp openai.ChatCompletionStreamResponse
if err := json.Unmarshal([]byte(data), &resp); err != nil {
r.logger.With("model_type", r.ctx.Model.ModelType).With("data", data).WarnContext(ctx, "解析SSE行失败", "error", err)
return nil
}
if resp.Usage != nil {
rc.InputTokens = int64(resp.Usage.PromptTokens)
rc.OutputTokens += int64(resp.Usage.CompletionTokens)
}
if len(resp.Choices) > 0 {
content := resp.Choices[0].Delta.Content
if content != "" {
rc.Completion += content
}
}
prompt := c.getPrompt(ctx, c.req)
mode := c.req.Metadata["mode"]
taskID := c.req.Metadata["task_id"]
case consts.ModelTypeCoder:
var resp openai.CompletionResponse
if err := json.Unmarshal([]byte(data), &resp); err != nil {
r.logger.With("model_type", r.ctx.Model.ModelType).With("data", data).WarnContext(ctx, "解析SSE行失败", "error", err)
return nil
}
rc := &domain.RecordParam{
RequestID: c.cx.RequestID,
TaskID: taskID,
UserID: c.cx.UserID,
ModelID: c.model.ID,
ModelType: c.model.ModelType,
WorkMode: mode,
Prompt: prompt,
Role: consts.ChatRoleAssistant,
}
if resp.Usage != nil {
if rc.TaskID == "" {
rc.TaskID = resp.ID
}
rc.InputTokens = int64(resp.Usage.PromptTokens)
}
if rc.Prompt != "" {
urc := rc.Clone()
urc.Role = consts.ChatRoleUser
urc.Completion = urc.Prompt
if err := c.usecase.Record(context.Background(), urc); err != nil {
c.logger.With("modelID", c.model.ID, "modelName", c.model.ModelName, "modelType", consts.ModelTypeLLM).
WarnContext(ctx, "插入流式记录失败", "error", err)
rc.OutputTokens += int64(resp.Usage.CompletionTokens)
if len(resp.Choices) > 0 {
rc.Completion += resp.Choices[0].Text
rc.CodeLines += int64(strings.Count(resp.Choices[0].Text, "\n"))
}
}
if len(resp.Choices) > 0 {
content := resp.Choices[0].Delta.Content
if content != "" {
c.completion.WriteString(content)
}
}
if resp.Usage != nil {
c.usage = resp.Usage
}
return nil
}
func (c *ChatRecorder) processCompletedChat(ctx context.Context) {
if c.recorded {
return // 避免重复记录
// Close implements io.ReadCloser.
func (r *Recorder) Close() error {
if r.shadown != nil {
close(r.shadown)
}
mode := c.req.Metadata["mode"]
taskID := c.req.Metadata["task_id"]
rc := &domain.RecordParam{
RequestID: c.cx.RequestID,
TaskID: taskID,
UserID: c.cx.UserID,
ModelID: c.model.ID,
ModelType: c.model.ModelType,
WorkMode: mode,
Role: consts.ChatRoleAssistant,
Completion: c.completion.String(),
InputTokens: int64(c.usage.PromptTokens),
OutputTokens: int64(c.usage.CompletionTokens),
}
if err := c.usecase.Record(context.Background(), rc); err != nil {
c.logger.With("modelID", c.model.ID, "modelName", c.model.ModelName, "modelType", consts.ModelTypeLLM).
WarnContext(ctx, "插入流式记录失败", "error", err)
} else {
c.recorded = true
c.logger.With("requestID", c.cx.RequestID, "completion_length", len(c.completion.String())).
InfoContext(ctx, "流式对话记录已保存")
if r.logFile != nil {
r.logFile.Close()
}
return r.src.Close()
}
// Close 关闭 recorder 并确保数据被保存
func (c *ChatRecorder) Close() {
// 如果有累积的内容但还没有记录,强制保存
if !c.recorded && c.completion.Len() > 0 {
c.logger.With("requestID", c.cx.RequestID).
WarnContext(c.ctx, "数据流异常中断,强制保存已累积的内容")
c.processCompletedChat(c.ctx)
// Read implements io.ReadCloser.
func (r *Recorder) Read(p []byte) (n int, err error) {
n, err = r.src.Read(p)
if n > 0 {
data := make([]byte, n)
copy(data, p[:n])
r.shadown <- data
}
if c.Tee != nil {
c.Tee.Close()
if err != nil {
return
}
return
}
func (c *ChatRecorder) Reset() {
c.completion.Reset()
c.buffer.Reset()
c.usage = nil
c.recorded = false
}
func (c *ChatRecorder) getPrompt(ctx context.Context, req *openai.ChatCompletionRequest) string {
func (r *Recorder) getPrompt(ctx context.Context, req *openai.ChatCompletionRequest) string {
prompt := ""
parse := promptparser.New(promptparser.KindTask)
for _, message := range req.Messages {
@@ -203,7 +343,7 @@ func (c *ChatRecorder) getPrompt(ctx context.Context, req *openai.ChatCompletion
if info, err := parse.Parse(message.Content); err == nil {
prompt = info.Prompt
} else {
c.logger.With("message", message.Content).WarnContext(ctx, "解析Prompt失败", "error", err)
r.logger.With("message", message.Content).WarnContext(ctx, "解析Prompt失败", "error", err)
}
}
@@ -214,7 +354,7 @@ func (c *ChatRecorder) getPrompt(ctx context.Context, req *openai.ChatCompletion
if info, err := parse.Parse(m.Text); err == nil {
prompt = info.Prompt
} else {
c.logger.With("message", m.Text).WarnContext(ctx, "解析Prompt失败", "error", err)
r.logger.With("message", m.Text).WarnContext(ctx, "解析Prompt失败", "error", err)
}
}
}
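
The Recorder above is essentially an io.ReadCloser that mirrors the response body into a channel so a background goroutine can rebuild the SSE or JSON payload and persist usage records. A stripped-down sketch of that shape (field and type names here are illustrative):

package recordersketch

import "io"

// teeBody mirrors everything the proxy copies to the client into ch, where a
// goroutine can parse it (SSE lines or a single JSON object) and record usage.
type teeBody struct {
    src io.ReadCloser
    ch  chan []byte
}

func (t *teeBody) Read(p []byte) (int, error) {
    n, err := t.src.Read(p)
    if n > 0 {
        d := make([]byte, n) // copy, since p is reused by the caller
        copy(d, p[:n])
        t.ch <- d
    }
    return n, err
}

func (t *teeBody) Close() error {
    close(t.ch) // lets the consumer's range loop finish
    return t.src.Close()
}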

View File

@@ -23,6 +23,7 @@ type UserHandler struct {
session *session.Session
logger *slog.Logger
cfg *config.Config
limitCh chan struct{}
}
func NewUserHandler(
@@ -40,6 +41,7 @@ func NewUserHandler(
logger: logger,
cfg: cfg,
euse: euse,
limitCh: make(chan struct{}, cfg.Extension.Limit),
}
w.GET("/api/v1/static/vsix/:version", web.BaseHandler(u.VSIXDownload))
@@ -94,6 +96,11 @@ func (h *UserHandler) VSCodeAuthInit(c *web.Context, req domain.VSCodeAuthInitRe
// @Produce octet-stream
// @Router /api/v1/static/vsix [get]
func (h *UserHandler) VSIXDownload(c *web.Context) error {
h.limitCh <- struct{}{}
defer func() {
<-h.limitCh
}()
v, err := h.euse.GetByVersion(c.Request().Context(), c.Param("version"))
if err != nil {
return err
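
The new limitCh field is a plain buffered-channel semaphore: VSIXDownload blocks while cfg.Extension.Limit downloads are already in flight and releases its slot on return. A runnable sketch of the same pattern outside the handler:

package main

import (
    "fmt"
    "sync"
)

func main() {
    limit := make(chan struct{}, 2) // plays the role of cfg.Extension.Limit
    var wg sync.WaitGroup
    for i := 0; i < 5; i++ {
        wg.Add(1)
        go func(i int) {
            defer wg.Done()
            limit <- struct{}{}        // acquire a slot; blocks while 2 are busy
            defer func() { <-limit }() // release the slot on return
            fmt.Println("serving vsix download", i)
        }(i)
    }
    wg.Wait()
}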

backend/pkg/tee/reqtee.go (new file, 81 lines)
View File

@@ -0,0 +1,81 @@
package tee
import (
"bytes"
"fmt"
"io"
"sync"
"time"
)
type ReqTee struct {
src io.ReadCloser
buf bytes.Buffer
done chan struct{}
mu sync.RWMutex
closed bool
maxSize int64 // 最大缓冲区大小0表示无限制
}
var _ io.ReadCloser = &ReqTee{}
func NewReqTee(src io.ReadCloser) *ReqTee {
return NewReqTeeWithMaxSize(src, 0)
}
// NewReqTeeWithMaxSize 创建带最大缓冲区大小限制的ReqTee
func NewReqTeeWithMaxSize(src io.ReadCloser, maxSize int64) *ReqTee {
return &ReqTee{
src: src,
buf: bytes.Buffer{},
done: make(chan struct{}),
maxSize: maxSize,
}
}
func (r *ReqTee) GetBody() ([]byte, error) {
return r.GetBodyWithTimeout(30 * time.Second)
}
// GetBodyWithTimeout 获取缓冲的数据,支持自定义超时时间
func (r *ReqTee) GetBodyWithTimeout(timeout time.Duration) ([]byte, error) {
select {
case <-r.done:
r.mu.RLock()
defer r.mu.RUnlock()
return r.buf.Bytes(), nil
case <-time.After(timeout):
return nil, fmt.Errorf("timeout waiting for data")
}
}
// Close implements io.ReadCloser.
func (r *ReqTee) Close() error {
r.mu.Lock()
defer r.mu.Unlock()
if r.closed {
return nil
}
r.closed = true
close(r.done)
return r.src.Close()
}
// Read implements io.ReadCloser.
func (r *ReqTee) Read(p []byte) (n int, err error) {
n, err = r.src.Read(p)
if n > 0 {
r.mu.Lock()
// 检查缓冲区大小限制
if r.maxSize > 0 && int64(r.buf.Len()+n) > r.maxSize {
r.mu.Unlock()
return n, fmt.Errorf("buffer size limit exceeded: %d bytes", r.maxSize)
}
// 直接写入缓冲区,避免额外的内存分配和复制
r.buf.Write(p[:n])
r.mu.Unlock()
}
return n, err
}
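
A usage sketch for the new ReqTee, assuming only what the file above shows (wrap the body before the proxy reads it, then call GetBody once the stream has been closed):

package main

import (
    "fmt"
    "io"
    "strings"

    "github.com/chaitin/MonkeyCode/backend/pkg/tee"
)

func main() {
    src := io.NopCloser(strings.NewReader(`{"model":"","prompt":"hi"}`))
    t := tee.NewReqTeeWithMaxSize(src, 10*1024*1024) // same 10 MiB cap the proxy uses

    _, _ = io.Copy(io.Discard, t) // stands in for the transport sending the body upstream
    _ = t.Close()                 // Close signals GetBody that the buffer is complete

    body, err := t.GetBody() // waits (default 30s timeout) for Close before returning
    fmt.Println(string(body), err)
}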

View File

@@ -1,103 +0,0 @@
package tee
import (
"context"
"io"
"log/slog"
"net/http"
"sync"
)
type TeeHandleFunc func(ctx context.Context, data []byte) error
var bufferPool = sync.Pool{
New: func() any {
buf := make([]byte, 4096)
return &buf
},
}
type Tee struct {
ctx context.Context
logger *slog.Logger
Reader io.Reader
Writer io.Writer
ch chan []byte
handle TeeHandleFunc
}
func NewTee(
ctx context.Context,
logger *slog.Logger,
reader io.Reader,
writer io.Writer,
handle TeeHandleFunc,
) *Tee {
t := &Tee{
ctx: ctx,
logger: logger,
Reader: reader,
Writer: writer,
handle: handle,
ch: make(chan []byte, 32*1024),
}
go t.Handle()
return t
}
func (t *Tee) Close() {
select {
case <-t.ch:
// channel 已经关闭
default:
close(t.ch)
}
}
func (t *Tee) Handle() {
for {
select {
case data, ok := <-t.ch:
if !ok {
t.logger.DebugContext(t.ctx, "Tee Handle closed")
return
}
err := t.handle(t.ctx, data)
if err != nil {
t.logger.With("data", string(data)).With("error", err).ErrorContext(t.ctx, "Tee Handle error")
return
}
case <-t.ctx.Done():
t.logger.DebugContext(t.ctx, "Tee Handle ctx done")
return
}
}
}
func (t *Tee) Stream() error {
bufPtr := bufferPool.Get().(*[]byte)
buf := *bufPtr
defer bufferPool.Put(bufPtr)
for {
n, err := t.Reader.Read(buf)
if err != nil {
if err == io.EOF {
return nil
}
return err
}
if n > 0 {
_, err = t.Writer.Write(buf[:n])
if err != nil {
return err
}
if flusher, ok := t.Writer.(http.Flusher); ok {
flusher.Flush()
}
data := make([]byte, n)
copy(data, buf[:n])
t.ch <- data
}
}
}

View File

@@ -1,376 +0,0 @@
package tee
import (
"bytes"
"context"
"errors"
"io"
"log/slog"
"os"
"strings"
"sync"
"testing"
"time"
)
// mockWriter 模拟 Writer 接口
type mockWriter struct {
buf *bytes.Buffer
delay time.Duration
errorOn int // 在第几次写入时返回错误
count int
}
func newMockWriter() *mockWriter {
return &mockWriter{
buf: &bytes.Buffer{},
}
}
func (m *mockWriter) Write(p []byte) (n int, err error) {
m.count++
if m.errorOn > 0 && m.count >= m.errorOn {
return 0, errors.New("mock write error")
}
if m.delay > 0 {
time.Sleep(m.delay)
}
return m.buf.Write(p)
}
func (m *mockWriter) String() string {
return m.buf.String()
}
// mockReader 模拟 Reader 接口
type mockReader struct {
data []byte
pos int
chunk int // 每次读取的字节数
errorOn int // 在第几次读取时返回错误
count int
}
func newMockReader(data string, chunk int) *mockReader {
return &mockReader{
data: []byte(data),
chunk: chunk,
}
}
func (m *mockReader) Read(p []byte) (n int, err error) {
m.count++
if m.errorOn > 0 && m.count >= m.errorOn {
return 0, errors.New("mock read error")
}
if m.pos >= len(m.data) {
return 0, io.EOF
}
readSize := m.chunk
if readSize <= 0 || readSize > len(p) {
readSize = len(p)
}
remaining := len(m.data) - m.pos
if readSize > remaining {
readSize = remaining
}
copy(p, m.data[m.pos:m.pos+readSize])
m.pos += readSize
return readSize, nil
}
// TestTeeBasicFunctionality 测试基本功能
func TestTeeBasicFunctionality(t *testing.T) {
ctx := context.Background()
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))
testData := "Hello, World! This is a test message."
reader := newMockReader(testData, 10) // 每次读取10字节
writer := newMockWriter()
var handledData [][]byte
var mu sync.Mutex
handle := func(ctx context.Context, data []byte) error {
mu.Lock()
defer mu.Unlock()
// 复制数据,因为原始数据可能被重用
dataCopy := make([]byte, len(data))
copy(dataCopy, data)
handledData = append(handledData, dataCopy)
return nil
}
tee := NewTee(ctx, logger, reader, writer, handle)
defer tee.Close()
err := tee.Stream()
if err != nil {
t.Fatalf("Stream() failed: %v", err)
}
// 等待处理完成
time.Sleep(100 * time.Millisecond)
// 验证写入的数据
if writer.String() != testData {
t.Errorf("Expected writer data %q, got %q", testData, writer.String())
}
// 验证处理的数据
mu.Lock()
var totalHandled []byte
for _, chunk := range handledData {
totalHandled = append(totalHandled, chunk...)
}
mu.Unlock()
if string(totalHandled) != testData {
t.Errorf("Expected handled data %q, got %q", testData, string(totalHandled))
}
}
// TestTeeWithErrors 测试错误处理
func TestTeeWithErrors(t *testing.T) {
ctx := context.Background()
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))
t.Run("ReaderError", func(t *testing.T) {
reader := newMockReader("test data", 5)
reader.errorOn = 2 // 第二次读取时出错
writer := newMockWriter()
handle := func(ctx context.Context, data []byte) error {
return nil
}
tee := NewTee(ctx, logger, reader, writer, handle)
defer tee.Close()
err := tee.Stream()
if err == nil {
t.Error("Expected error from reader, got nil")
}
})
t.Run("WriterError", func(t *testing.T) {
reader := newMockReader("test data", 5)
writer := newMockWriter()
writer.errorOn = 2 // 第二次写入时出错
handle := func(ctx context.Context, data []byte) error {
return nil
}
tee := NewTee(ctx, logger, reader, writer, handle)
defer tee.Close()
err := tee.Stream()
if err == nil {
t.Error("Expected error from writer, got nil")
}
})
t.Run("HandleError", func(t *testing.T) {
reader := newMockReader("test data", 5)
writer := newMockWriter()
handle := func(ctx context.Context, data []byte) error {
return errors.New("handle error")
}
tee := NewTee(ctx, logger, reader, writer, handle)
defer tee.Close()
// 启动 Stream 在单独的 goroutine 中
go func() {
tee.Stream()
}()
// 等待一段时间让处理器有机会处理数据并出错
time.Sleep(200 * time.Millisecond)
})
}
// TestTeeContextCancellation 测试上下文取消
func TestTeeContextCancellation(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))
// 创建一个会持续产生数据的 reader
reader := strings.NewReader(strings.Repeat("test data ", 1000))
writer := newMockWriter()
handle := func(ctx context.Context, data []byte) error {
return nil
}
tee := NewTee(ctx, logger, reader, writer, handle)
defer tee.Close()
// 在单独的 goroutine 中启动 Stream
done := make(chan error, 1)
go func() {
done <- tee.Stream()
}()
// 等待一段时间后取消上下文
time.Sleep(50 * time.Millisecond)
cancel()
// 等待 Stream 完成
select {
case err := <-done:
if err != nil && err != io.EOF {
t.Logf("Stream completed with error: %v", err)
}
case <-time.After(2 * time.Second):
t.Error("Stream did not complete within timeout")
}
}
// TestTeeConcurrentSafety 测试并发安全性
func TestTeeConcurrentSafety(t *testing.T) {
ctx := context.Background()
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))
testData := strings.Repeat("concurrent test data ", 100)
reader := strings.NewReader(testData)
writer := newMockWriter()
var processedCount int64
var mu sync.Mutex
handle := func(ctx context.Context, data []byte) error {
mu.Lock()
processedCount++
mu.Unlock()
// 模拟一些处理时间
time.Sleep(time.Microsecond)
return nil
}
tee := NewTee(ctx, logger, reader, writer, handle)
defer tee.Close()
err := tee.Stream()
if err != nil {
t.Fatalf("Stream() failed: %v", err)
}
// 等待所有数据处理完成
time.Sleep(500 * time.Millisecond)
mu.Lock()
finalCount := processedCount
mu.Unlock()
if finalCount == 0 {
t.Error("No data was processed")
}
t.Logf("Processed %d chunks of data", finalCount)
}
// TestBufferPoolEfficiency 测试缓冲区池的效率
func TestBufferPoolEfficiency(t *testing.T) {
// 这个测试验证缓冲区池是否正常工作
// 通过多次获取和归还缓冲区来测试
var buffers []*[]byte
// 获取多个缓冲区
for i := 0; i < 10; i++ {
bufPtr := bufferPool.Get().(*[]byte)
buffers = append(buffers, bufPtr)
// 验证缓冲区大小
if len(*bufPtr) != 4096 {
t.Errorf("Expected buffer size 4096, got %d", len(*bufPtr))
}
}
// 归还所有缓冲区
for _, bufPtr := range buffers {
bufferPool.Put(bufPtr)
}
// 再次获取缓冲区,应该重用之前的缓冲区
for i := 0; i < 5; i++ {
bufPtr := bufferPool.Get().(*[]byte)
if len(*bufPtr) != 4096 {
t.Errorf("Expected reused buffer size 4096, got %d", len(*bufPtr))
}
bufferPool.Put(bufPtr)
}
}
// BenchmarkTeeStream 基准测试
func BenchmarkTeeStream(b *testing.B) {
ctx := context.Background()
logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError}))
testData := strings.Repeat("benchmark test data ", 1000)
handle := func(ctx context.Context, data []byte) error {
return nil
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
reader := strings.NewReader(testData)
writer := io.Discard
tee := NewTee(ctx, logger, reader, writer, handle)
err := tee.Stream()
if err != nil {
b.Fatalf("Stream() failed: %v", err)
}
tee.Close()
}
}
// BenchmarkBufferPool 缓冲区池基准测试
func BenchmarkBufferPool(b *testing.B) {
b.Run("WithPool", func(b *testing.B) {
for i := 0; i < b.N; i++ {
bufPtr := bufferPool.Get().(*[]byte)
// 模拟使用缓冲区
_ = *bufPtr
bufferPool.Put(bufPtr)
}
})
b.Run("WithoutPool", func(b *testing.B) {
for i := 0; i < b.N; i++ {
buf := make([]byte, 4096)
// 模拟使用缓冲区
_ = buf
}
})
}
// TestTeeClose 测试关闭功能
func TestTeeClose(t *testing.T) {
ctx := context.Background()
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))
reader := strings.NewReader("test data")
writer := newMockWriter()
handle := func(ctx context.Context, data []byte) error {
return nil
}
tee := NewTee(ctx, logger, reader, writer, handle)
// 测试多次关闭不会 panic
tee.Close()
tee.Close()
tee.Close()
}