mirror of
https://github.com/chaitin/MonkeyCode.git
synced 2026-02-02 06:43:23 +08:00
Merge pull request #92 from yokowu/feat-proxyv2
feat(proxy): implement the proxy with ReverseProxy, further separating proxy logic from analysis logic
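This PR replaces hand-rolled request forwarding with net/http/httputil's ReverseProxy, hooking analysis in through its Rewrite/ModifyResponse callbacks. As orientation for the diff below, here is a minimal, self-contained sketch of that wiring; the upstream URL and handler bodies are illustrative, not MonkeyCode's actual code:

    package main

    import (
    	"log"
    	"net/http"
    	"net/http/httputil"
    	"net/url"
    	"time"
    )

    func main() {
    	// Hypothetical upstream; the real code selects a model backend per request.
    	upstream, _ := url.Parse("https://api.example.com")
    	rp := &httputil.ReverseProxy{
    		Rewrite: func(r *httputil.ProxyRequest) {
    			r.SetURL(upstream) // route to the selected backend
    			r.SetXForwarded()  // preserve client addressing headers
    		},
    		ModifyResponse: func(resp *http.Response) error {
    			// Wrap resp.Body here to record/analyze without blocking the client.
    			return nil
    		},
    		ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
    			log.Printf("proxy error on %s: %v", r.URL.Path, err)
    			w.WriteHeader(http.StatusBadGateway)
    		},
    		FlushInterval: 100 * time.Millisecond, // flush SSE chunks promptly
    	}
    	log.Fatal(http.ListenAndServe(":8080", rp))
    }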
@@ -12,7 +12,6 @@ import (
"github.com/chaitin/MonkeyCode/backend/config"
"github.com/chaitin/MonkeyCode/backend/db"
"github.com/chaitin/MonkeyCode/backend/domain"
billingv1 "github.com/chaitin/MonkeyCode/backend/internal/billing/handler/http/v1"
dashv1 "github.com/chaitin/MonkeyCode/backend/internal/dashboard/handler/v1"
v1 "github.com/chaitin/MonkeyCode/backend/internal/model/handler/http/v1"
@@ -25,7 +24,6 @@ type Server struct {
web *web.Web
ent *db.Client
logger *slog.Logger
- proxy domain.Proxy
openaiV1 *openaiV1.V1Handler
modelV1 *v1.ModelHandler
userV1 *userV1.UserHandler
@@ -10,7 +10,6 @@ import (
"github.com/GoYoko/web"
"github.com/chaitin/MonkeyCode/backend/config"
"github.com/chaitin/MonkeyCode/backend/db"
"github.com/chaitin/MonkeyCode/backend/domain"
v1_5 "github.com/chaitin/MonkeyCode/backend/internal/billing/handler/http/v1"
repo7 "github.com/chaitin/MonkeyCode/backend/internal/billing/repo"
usecase6 "github.com/chaitin/MonkeyCode/backend/internal/billing/usecase"

@@ -56,7 +55,7 @@ func newServer() (*Server, error) {
proxyRepo := repo.NewProxyRepo(client)
modelRepo := repo2.NewModelRepo(client)
proxyUsecase := usecase.NewProxyUsecase(proxyRepo, modelRepo)
- domainProxy := proxy.NewLLMProxy(proxyUsecase, configConfig, slogLogger)
+ llmProxy := proxy.NewLLMProxy(slogLogger, configConfig, proxyUsecase)
openAIRepo := repo3.NewOpenAIRepo(client)
openAIUsecase := openai.NewOpenAIUsecase(configConfig, openAIRepo, slogLogger)
extensionRepo := repo4.NewExtensionRepo(client)

@@ -64,7 +63,7 @@ func newServer() (*Server, error) {
proxyMiddleware := middleware.NewProxyMiddleware(proxyUsecase)
redisClient := store.NewRedisCli(configConfig)
activeMiddleware := middleware.NewActiveMiddleware(redisClient, slogLogger)
- v1Handler := v1.NewV1Handler(slogLogger, web, domainProxy, openAIUsecase, extensionUsecase, proxyMiddleware, activeMiddleware, configConfig)
+ v1Handler := v1.NewV1Handler(slogLogger, web, llmProxy, proxyUsecase, openAIUsecase, extensionUsecase, proxyMiddleware, activeMiddleware, configConfig)
modelUsecase := usecase3.NewModelUsecase(slogLogger, modelRepo, configConfig)
sessionSession := session.NewSession(configConfig)
authMiddleware := middleware.NewAuthMiddleware(sessionSession, slogLogger)

@@ -83,7 +82,6 @@ func newServer() (*Server, error) {
web: web,
ent: client,
logger: slogLogger,
- proxy: domainProxy,
openaiV1: v1Handler,
modelV1: modelHandler,
userV1: userHandler,

@@ -100,7 +98,6 @@ type Server struct {
web *web.Web
ent *db.Client
logger *slog.Logger
- proxy domain.Proxy
openaiV1 *v1.V1Handler
modelV1 *v1_2.ModelHandler
userV1 *v1_3.UserHandler
@@ -64,6 +64,7 @@ type Config struct {

Extension struct {
Baseurl string `mapstructure:"baseurl"`
+ Limit int `mapstructure:"limit"`
} `mapstructure:"extension"`
}

@@ -99,6 +100,7 @@ func Init() (*Config, error) {
v.SetDefault("init_model.key", "")
v.SetDefault("init_model.url", "https://model-square.app.baizhi.cloud/v1")
v.SetDefault("extension.baseurl", "https://release.baizhi.cloud")
+ v.SetDefault("extension.limit", 10)

c := Config{}
if err := v.Unmarshal(&c); err != nil {
@@ -126,7 +126,12 @@ func (c *ChatContent) From(e *db.TaskRecord) *ChatContent {
return c
}
c.Role = e.Role
- c.Content = e.Completion
+ switch e.Role {
+ case consts.ChatRoleUser:
+ c.Content = e.Prompt
+ case consts.ChatRoleAssistant:
+ c.Content = e.Completion
+ }
c.CreatedAt = e.CreatedAt.Unix()
return c
}
@@ -3,10 +3,9 @@ package domain

import (
"context"

+ "github.com/rokku-c/go-openai"
+
"github.com/chaitin/MonkeyCode/backend/consts"
"github.com/chaitin/MonkeyCode/backend/db"
- "github.com/rokku-c/go-openai"
)

type OpenAIUsecase interface {

@@ -21,7 +20,6 @@ type OpenAIRepo interface {

type CompletionRequest struct {
openai.CompletionRequest

Metadata map[string]string `json:"metadata"`
}
@@ -6,27 +6,29 @@ import (
"net/http"

"github.com/labstack/echo/v4"
"github.com/rokku-c/go-openai"

"github.com/GoYoko/web"

"github.com/chaitin/MonkeyCode/backend/config"
"github.com/chaitin/MonkeyCode/backend/domain"
"github.com/chaitin/MonkeyCode/backend/internal/middleware"
+ "github.com/chaitin/MonkeyCode/backend/internal/proxy"
)

type V1Handler struct {
- logger *slog.Logger
- proxy domain.Proxy
- usecase domain.OpenAIUsecase
- euse domain.ExtensionUsecase
- config *config.Config
+ logger *slog.Logger
+ proxy *proxy.LLMProxy
+ proxyUse domain.ProxyUsecase
+ usecase domain.OpenAIUsecase
+ euse domain.ExtensionUsecase
+ config *config.Config
}

func NewV1Handler(
logger *slog.Logger,
w *web.Web,
- proxy domain.Proxy,
+ proxy *proxy.LLMProxy,
+ proxyUse domain.ProxyUsecase,
usecase domain.OpenAIUsecase,
euse domain.ExtensionUsecase,
middleware *middleware.ProxyMiddleware,
@@ -34,21 +36,23 @@ func NewV1Handler(
config *config.Config,
) *V1Handler {
h := &V1Handler{
- logger: logger.With(slog.String("handler", "openai")),
- proxy: proxy,
- usecase: usecase,
- euse: euse,
- config: config,
+ logger: logger.With(slog.String("handler", "openai")),
+ proxy: proxy,
+ proxyUse: proxyUse,
+ usecase: usecase,
+ euse: euse,
+ config: config,
}

w.GET("/api/config", web.BindHandler(h.GetConfig), middleware.Auth())
w.GET("/v1/version", web.BaseHandler(h.Version), middleware.Auth())

g := w.Group("/v1", middleware.Auth())
g.GET("/models", web.BaseHandler(h.ModelList))
g.POST("/completion/accept", web.BindHandler(h.AcceptCompletion), active.Active())
- g.POST("/chat/completions", web.BindHandler(h.ChatCompletion), active.Active())
- g.POST("/completions", web.BindHandler(h.Completions), active.Active())
- g.POST("/embeddings", web.BindHandler(h.Embeddings), active.Active())
+ g.POST("/chat/completions", web.BaseHandler(h.ChatCompletion), active.Active())
+ g.POST("/completions", web.BaseHandler(h.Completions), active.Active())
+ g.POST("/embeddings", web.BaseHandler(h.Embeddings), active.Active())
return h
}
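A plausible reading of the BindHandler-to-BaseHandler switch above: a binding handler decodes (and therefore drains) the JSON body before the ReverseProxy could forward it, so the proxied routes must stop binding. A sketch of why binding and proxying conflict, in a hypothetical non-main package (bindThenProxy is illustrative, not the GoYoko/web internals):

    package handlers

    import (
    	"bytes"
    	"encoding/json"
    	"io"
    	"net/http"
    	"net/http/httputil"
    )

    // bindThenProxy shows the problem the refactor avoids: reading the body
    // consumes it, so it must be restored before the proxy can forward it.
    func bindThenProxy(w http.ResponseWriter, r *http.Request, rp *httputil.ReverseProxy) {
    	body, err := io.ReadAll(r.Body) // binding drains r.Body
    	if err != nil {
    		http.Error(w, err.Error(), http.StatusBadRequest)
    		return
    	}
    	_ = json.Valid(body) // inspect/validate, as binding would
    	r.Body = io.NopCloser(bytes.NewReader(body)) // must restore before proxying
    	rp.ServeHTTP(w, r)
    }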
@@ -86,7 +90,7 @@ func (h *V1Handler) Version(c *web.Context) error {
// @Success 200 {object} web.Resp{}
// @Router /v1/completion/accept [post]
func (h *V1Handler) AcceptCompletion(c *web.Context, req domain.AcceptCompletionReq) error {
- if err := h.proxy.AcceptCompletion(c.Request().Context(), &req); err != nil {
+ if err := h.proxyUse.AcceptCompletion(c.Request().Context(), &req); err != nil {
return BadRequest(c, err.Error())
}
return nil
@@ -120,19 +124,8 @@ func (h *V1Handler) ModelList(c *web.Context) error {
// @Produce json
// @Success 200 {object} web.Resp{}
// @Router /v1/chat/completions [post]
- func (h *V1Handler) ChatCompletion(c *web.Context, req openai.ChatCompletionRequest) error {
- // TODO: log the request to a file
- if req.Model == "" {
- return BadRequest(c, "model must not be empty")
- }
-
- // if len(req.Tools) > 0 && req.Model != "qwen-max" {
- // if h.toolsCall(c, req, req.Stream, req.Model) {
- // return nil
- // }
- // }
-
- h.proxy.HandleChatCompletion(c.Request().Context(), c.Response(), &req)
+ func (h *V1Handler) ChatCompletion(c *web.Context) error {
+ h.proxy.ServeHTTP(c.Response(), c.Request())
return nil
}
@@ -146,13 +139,8 @@ func (h *V1Handler) ChatCompletion(c *web.Context, req openai.ChatCompletionRequ
// @Produce json
// @Success 200 {object} web.Resp{}
// @Router /v1/completions [post]
- func (h *V1Handler) Completions(c *web.Context, req domain.CompletionRequest) error {
- // TODO: log the request to a file
- if req.Model == "" {
- return BadRequest(c, "model must not be empty")
- }
- h.logger.With("request", req).DebugContext(c.Request().Context(), "handling text completion request")
- h.proxy.HandleCompletion(c.Request().Context(), c.Response(), req)
+ func (h *V1Handler) Completions(c *web.Context) error {
+ h.proxy.ServeHTTP(c.Response(), c.Request())
return nil
}
@@ -166,12 +154,8 @@ func (h *V1Handler) Completions(c *web.Context, req domain.CompletionRequest) er
// @Produce json
// @Success 200 {object} web.Resp{}
// @Router /v1/embeddings [post]
- func (h *V1Handler) Embeddings(c *web.Context, req openai.EmbeddingRequest) error {
- if req.Model == "" {
- return BadRequest(c, "model must not be empty")
- }
-
- h.proxy.HandleEmbeddings(c.Request().Context(), c.Response(), &req)
+ func (h *V1Handler) Embeddings(c *web.Context) error {
+ h.proxy.ServeHTTP(c.Response(), c.Request())
return nil
}

@@ -1,840 +1,141 @@
package proxy

import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"path/filepath"
"strings"
"time"

"github.com/rokku-c/go-openai"

"github.com/chaitin/MonkeyCode/backend/config"
"github.com/chaitin/MonkeyCode/backend/consts"
"github.com/chaitin/MonkeyCode/backend/domain"
"github.com/chaitin/MonkeyCode/backend/pkg/cvt"
"github.com/chaitin/MonkeyCode/backend/pkg/logger"
"github.com/chaitin/MonkeyCode/backend/pkg/promptparser"
"github.com/chaitin/MonkeyCode/backend/pkg/request"
"github.com/chaitin/MonkeyCode/backend/pkg/tee"
)
- // ErrResp is the error response
- type ErrResp struct {
- Error struct {
- Message string `json:"message"`
- Type string `json:"type"`
- Code string `json:"code,omitempty"`
- } `json:"error"`
- }
+ type CtxKey struct{}
+
+ type ProxyCtx struct {
+ ctx context.Context
+ Path string
+ Model *domain.Model
+ Header http.Header
+ RespHeader http.Header
+ ReqTee *tee.ReqTee
+ RequestID string
+ UserID string
+ }
- // RequestResponseLog is the request/response log structure
- type RequestResponseLog struct {
- Timestamp time.Time `json:"timestamp"`
- RequestID string `json:"request_id"`
- Endpoint string `json:"endpoint"`
- ModelName string `json:"model_name"`
- ModelType consts.ModelType `json:"model_type"`
- UpstreamURL string `json:"upstream_url"`
- RequestBody any `json:"request_body"`
- RequestHeader map[string][]string `json:"request_header"`
- StatusCode int `json:"status_code"`
- ResponseBody any `json:"response_body"`
- ResponseHeader map[string][]string `json:"response_header"`
- Latency int64 `json:"latency_ms"`
- Error string `json:"error,omitempty"`
- IsStream bool `json:"is_stream"`
- SourceIP string `json:"source_ip"` // source IP of the request
- }
// LLMProxy implements the LLM API proxy
type LLMProxy struct {
- usecase domain.ProxyUsecase
- cfg *config.Config
- client *http.Client
- streamClient *http.Client
- logger *slog.Logger
- requestLogPath string // path for saving request logs
+ logger *slog.Logger
+ cfg *config.Config
+ usecase domain.ProxyUsecase
+ transport *http.Transport
+ proxy *httputil.ReverseProxy
}
// NewLLMProxy creates an LLM API proxy instance
func NewLLMProxy(
- usecase domain.ProxyUsecase,
- cfg *config.Config,
logger *slog.Logger,
- ) domain.Proxy {
- // parse the timeout
- timeout, err := time.ParseDuration(cfg.LLMProxy.Timeout)
- if err != nil {
- timeout = 30 * time.Second
- logger.Warn("failed to parse timeout, using default of 30s", "error", err)
- }
+ cfg *config.Config,
+ usecase domain.ProxyUsecase,
+ ) *LLMProxy {
+ l := &LLMProxy{
+ logger: logger,
+ cfg: cfg,
+ usecase: usecase,
+ }

- // parse the keep-alive duration
- keepAlive, err := time.ParseDuration(cfg.LLMProxy.KeepAlive)
- if err != nil {
- keepAlive = 60 * time.Second
- logger.Warn("failed to parse keep-alive duration, using default of 60s", "error", err)
- }
+ l.transport = &http.Transport{
+ MaxIdleConns: cfg.LLMProxy.ClientPoolSize,
+ MaxConnsPerHost: cfg.LLMProxy.ClientPoolSize,
+ MaxIdleConnsPerHost: cfg.LLMProxy.ClientPoolSize,
+ IdleConnTimeout: 24 * time.Hour,
+ DialContext: (&net.Dialer{
+ Timeout: 30 * time.Second,
+ KeepAlive: 24 * time.Hour,
+ }).DialContext,
+ }

- client := &http.Client{
- Timeout: timeout,
- Transport: &http.Transport{
- MaxIdleConns: cfg.LLMProxy.ClientPoolSize,
- MaxConnsPerHost: cfg.LLMProxy.ClientPoolSize,
- MaxIdleConnsPerHost: cfg.LLMProxy.ClientPoolSize,
- IdleConnTimeout: keepAlive,
- TLSHandshakeTimeout: 10 * time.Second,
- ExpectContinueTimeout: 1 * time.Second,
- DisableCompression: false,
- ForceAttemptHTTP2: true,
- },
- }

- streamClient := &http.Client{
- Timeout: 60 * time.Second,
- Transport: &http.Transport{
- MaxIdleConns: cfg.LLMProxy.StreamClientPoolSize,
- MaxConnsPerHost: cfg.LLMProxy.StreamClientPoolSize,
- MaxIdleConnsPerHost: cfg.LLMProxy.StreamClientPoolSize,
- IdleConnTimeout: 24 * time.Hour,
- TLSHandshakeTimeout: 10 * time.Second,
- ExpectContinueTimeout: 1 * time.Second,
- },
- }

- // read logging config
- requestLogPath := ""
- if cfg.LLMProxy.RequestLogPath != "" {
- requestLogPath = cfg.LLMProxy.RequestLogPath
- // make sure the directory exists
- if err := os.MkdirAll(requestLogPath, 0755); err != nil {
- logger.Warn("failed to create request log directory", "error", err, "path", requestLogPath)
- }
- }

- return &LLMProxy{
- usecase: usecase,
- client: client,
- streamClient: streamClient,
- cfg: cfg,
- requestLogPath: requestLogPath,
- logger: logger,
- }
+ l.proxy = &httputil.ReverseProxy{
+ Transport: l.transport,
+ Rewrite: l.rewrite,
+ ModifyResponse: l.modifyResponse,
+ ErrorHandler: l.errorHandler,
+ FlushInterval: 100 * time.Millisecond,
+ }
+ return l
}
- // saveRequestResponseLog saves the request/response log to a file
- func (p *LLMProxy) saveRequestResponseLog(log *RequestResponseLog) {
- if p.requestLogPath == "" {
- return
- }
+ func (l *LLMProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ l.proxy.ServeHTTP(w, r)
+ }

+ func (l *LLMProxy) Close() error {
+ l.transport.CloseIdleConnections()
+ return nil
+ }

+ var modelType = map[string]consts.ModelType{
+ "/v1/chat/completions": consts.ModelTypeLLM,
+ "/v1/completions": consts.ModelTypeCoder,
+ }
+ func (l *LLMProxy) rewrite(r *httputil.ProxyRequest) {
+ l.logger.DebugContext(r.In.Context(), "rewrite request", slog.String("path", r.In.URL.Path))
+ mt, ok := modelType[r.In.URL.Path]
+ if !ok {
+ l.logger.Error("model type not found", slog.String("path", r.In.URL.Path))
+ return
+ }

- // build the filename: YYYYMMDD_HHMMSS_<requestID>.json
- timestamp := log.Timestamp.Format("20060102_150405") + fmt.Sprintf("_%03d", log.Timestamp.Nanosecond()/1e6)
- filename := fmt.Sprintf("%s_%s.json", timestamp, log.RequestID)
- filepath := filepath.Join(p.requestLogPath, filename)

- // serialize the log to JSON
- logData, err := json.MarshalIndent(log, "", "  ")
+ m, err := l.usecase.SelectModelWithLoadBalancing("", mt)
if err != nil {
- p.logger.Error("failed to serialize request log", "error", err)
+ l.logger.Error("select model with load balancing failed", slog.String("path", r.In.URL.Path), slog.Any("err", err))
return
}

- // write to file
- if err := os.WriteFile(filepath, logData, 0644); err != nil {
- p.logger.Error("failed to write request log file", "error", err, "path", filepath)
- return
- }

- p.logger.Debug("request/response log saved", "path", filepath)
- }
- func (p *LLMProxy) AcceptCompletion(ctx context.Context, req *domain.AcceptCompletionReq) error {
- return p.usecase.AcceptCompletion(ctx, req)
- }
- func writeErrResp(w http.ResponseWriter, code int, message, errorType string) {
- resp := ErrResp{
- Error: struct {
- Message string `json:"message"`
- Type string `json:"type"`
- Code string `json:"code,omitempty"`
- }{
- Message: message,
- Type: errorType,
- },
- }
- w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(code)
- b, _ := json.Marshal(resp)
- w.Write(b)
- }
- type Ctx struct {
- UserID string
- RequestID string
- SourceIP string
- }
- func (p *LLMProxy) handle(ctx context.Context, fn func(ctx *Ctx, log *RequestResponseLog) error) {
- // get the user ID
- userID := "unknown"
- if id, ok := ctx.Value(logger.UserIDKey{}).(string); ok {
- userID = id
- }
-
- requestID := "unknown"
- if id, ok := ctx.Value(logger.RequestIDKey{}).(string); ok {
- requestID = id
- }
-
- sourceip := "unknown"
- if ip, ok := ctx.Value("remote_addr").(string); ok {
- sourceip = ip
- }
-
- // build the request log structure
- l := &RequestResponseLog{
- Timestamp: time.Now(),
- RequestID: requestID,
- ModelType: consts.ModelTypeCoder,
- SourceIP: sourceip,
- }
-
- c := &Ctx{
- UserID: userID,
- RequestID: requestID,
- SourceIP: sourceip,
- }
-
- if err := fn(c, l); err != nil {
- p.logger.With("source_ip", sourceip).ErrorContext(ctx, "failed to handle request", "error", err)
- l.Error = err.Error()
- }
-
- go p.saveRequestResponseLog(l)
- }
- func (p *LLMProxy) HandleCompletion(ctx context.Context, w http.ResponseWriter, req domain.CompletionRequest) {
- if req.Stream {
- p.handleCompletionStream(ctx, w, req)
- } else {
- p.handleCompletion(ctx, w, req)
- }
- }
- func (p *LLMProxy) handleCompletionStream(ctx context.Context, w http.ResponseWriter, req domain.CompletionRequest) {
- endpoint := "/completions"
- p.handle(ctx, func(c *Ctx, log *RequestResponseLog) error {
- // pick a model via load balancing
- m, err := p.usecase.SelectModelWithLoadBalancing(req.Model, consts.ModelTypeCoder)
- if err != nil {
- p.logger.With("modelName", req.Model, "modelType", consts.ModelTypeCoder).WarnContext(ctx, "model selection failed", "error", err)
- writeErrResp(w, http.StatusNotFound, "model not found", "proxy_error")
- return err
- }
-
- // build the upstream API URL
- upstream := m.APIBase + endpoint
- log.UpstreamURL = upstream
- startTime := time.Now()
-
- // create the upstream request
- body, err := json.Marshal(req)
- if err != nil {
- p.logger.ErrorContext(ctx, "failed to serialize request body", "error", err)
- return fmt.Errorf("failed to serialize request body: %w", err)
- }
-
- upreq, err := http.NewRequestWithContext(ctx, http.MethodPost, upstream, bytes.NewReader(body))
- if err != nil {
- p.logger.With("upstream", upstream).WarnContext(ctx, "failed to create upstream streaming request", "error", err)
- return fmt.Errorf("failed to create upstream request: %w", err)
- }
-
- // set request headers
- upreq.Header.Set("Content-Type", "application/json")
- upreq.Header.Set("Accept", "text/event-stream")
- if m.APIKey != "" && m.APIKey != "none" {
- upreq.Header.Set("Authorization", "Bearer "+m.APIKey)
- }
-
- // save request headers (sensitive values stripped)
- requestHeaders := make(map[string][]string)
- for k, v := range upreq.Header {
- if k != "Authorization" {
- requestHeaders[k] = v
- } else {
- // redact sensitive values
- requestHeaders[k] = []string{"Bearer ***"}
- }
- }
- log.RequestHeader = requestHeaders
-
- p.logger.With(
- "upstreamURL", upstream,
- "modelName", m.ModelName,
- "modelType", consts.ModelTypeLLM,
- "apiBase", m.APIBase,
- "requestHeader", upreq.Header,
- "requestBody", upreq,
- ).DebugContext(ctx, "forwarding streaming request to upstream API")
-
- // send the request
- resp, err := p.client.Do(upreq)
- if err != nil {
- p.logger.With("upstreamURL", upstream).WarnContext(ctx, "failed to send upstream streaming request", "error", err)
- return fmt.Errorf("failed to send upstream request: %w", err)
- }
- defer resp.Body.Close()
-
- // check response status
- if resp.StatusCode != http.StatusOK {
- responseBody, _ := io.ReadAll(resp.Body)
-
- // record the error in the log
- log.StatusCode = resp.StatusCode
- log.ResponseHeader = resp.Header
- log.ResponseBody = string(responseBody)
- log.Latency = time.Since(startTime).Milliseconds()
-
- // log the raw failed streaming response at debug level
- p.logger.With(
- "statusCode", resp.StatusCode,
- "responseHeader", resp.Header,
- "responseBody", string(responseBody),
- ).DebugContext(ctx, "raw content of upstream streaming error response")
-
- var errorResp ErrResp
- if err := json.Unmarshal(responseBody, &errorResp); err == nil {
- p.logger.With(
- "endpoint", endpoint,
- "upstreamURL", upstream,
- "requestBody", upreq,
- "statusCode", resp.StatusCode,
- "errorType", errorResp.Error.Type,
- "errorCode", errorResp.Error.Code,
- "errorMessage", errorResp.Error.Message,
- "latency", time.Since(startTime),
- ).WarnContext(ctx, "details of upstream API streaming request failure")
-
- return fmt.Errorf("upstream API returned an error: %s", errorResp.Error.Message)
- }
-
- p.logger.With(
- "endpoint", endpoint,
- "upstreamURL", upstream,
- "requestBody", upreq,
- "statusCode", resp.StatusCode,
- "responseBody", string(responseBody),
- ).WarnContext(ctx, "details of upstream API streaming request failure")
-
- return fmt.Errorf("upstream API returned non-200 status: %d, response: %s", resp.StatusCode, string(responseBody))
- }
-
- // update the log
- log.StatusCode = resp.StatusCode
- log.ResponseHeader = resp.Header
-
- // log streaming response headers at debug level
- p.logger.With(
- "statusCode", resp.StatusCode,
- "responseHeader", resp.Header,
- ).DebugContext(ctx, "upstream streaming response headers")
-
- // set response headers
- w.Header().Set("Content-Type", "text/event-stream")
- w.Header().Set("Cache-Control", "no-cache")
- w.Header().Set("Connection", "keep-alive")
- w.Header().Set("Transfer-Encoding", "chunked")
-
- rc := &domain.RecordParam{
- UserID: c.UserID,
- ModelID: m.ID,
- ModelType: consts.ModelTypeLLM,
- Prompt: req.Prompt.(string),
- Role: consts.ChatRoleAssistant,
- }
- buf := bufio.NewWriterSize(w, 32*1024)
- defer buf.Flush()
-
- ch := make(chan []byte, 1024)
- defer close(ch)
-
- go func(rc *domain.RecordParam) {
- for line := range ch {
- if bytes.HasPrefix(line, []byte("data:")) {
- line = bytes.TrimPrefix(line, []byte("data: "))
- line = bytes.TrimSpace(line)
- if len(line) == 0 {
- continue
- }
-
- if bytes.Equal(line, []byte("[DONE]")) {
- break
- }
-
- var t openai.CompletionResponse
- if err := json.Unmarshal(line, &t); err != nil {
- p.logger.With("line", string(line)).WarnContext(ctx, "failed to parse streaming data", "error", err)
- continue
- }
-
- p.logger.With("response", t).DebugContext(ctx, "streaming response data")
- if len(t.Choices) > 0 {
- rc.Completion += t.Choices[0].Text
- }
- rc.InputTokens = int64(t.Usage.PromptTokens)
- rc.OutputTokens += int64(t.Usage.CompletionTokens)
- }
- }
-
- if rc.OutputTokens == 0 {
- return
- }
-
- p.logger.With("record", rc).DebugContext(ctx, "streaming record")
- if err := p.usecase.Record(context.Background(), rc); err != nil {
- p.logger.With("modelID", m.ID, "modelName", m.ModelName, "modelType", consts.ModelTypeLLM).
- WarnContext(ctx, "failed to insert streaming record", "error", err)
- }
- }(rc)
-
- err = streamRead(ctx, resp.Body, func(line []byte) error {
- ch <- line
- if _, err := buf.Write(line); err != nil {
- return fmt.Errorf("failed to write response: %w", err)
- }
- return buf.Flush()
+ if r.In.ContentLength > 0 {
+ tee := tee.NewReqTeeWithMaxSize(r.In.Body, 10*1024*1024)
+ r.Out.Body = tee
+ ctx := context.WithValue(r.In.Context(), CtxKey{}, &ProxyCtx{
+ ctx: r.In.Context(),
+ Path: r.In.URL.Path,
+ Model: m,
+ ReqTee: tee,
+ RequestID: r.In.Context().Value(logger.RequestIDKey{}).(string),
+ UserID: r.In.Context().Value(logger.UserIDKey{}).(string),
+ Header: r.In.Header,
+ })
- return err
- })
- }
+ r.Out = r.Out.WithContext(ctx)
+ }
- func (p *LLMProxy) handleCompletion(ctx context.Context, w http.ResponseWriter, req domain.CompletionRequest) {
- endpoint := "/completions"
- p.handle(ctx, func(c *Ctx, log *RequestResponseLog) error {
- // pick a model via load balancing
- m, err := p.usecase.SelectModelWithLoadBalancing(req.Model, consts.ModelTypeCoder)
- if err != nil {
- p.logger.With("modelName", req.Model, "modelType", consts.ModelTypeCoder).WarnContext(ctx, "model selection failed", "error", err)
- writeErrResp(w, http.StatusNotFound, "model not found", "proxy_error")
- return err
- }
-
- // build the upstream API URL
- upstream := m.APIBase + endpoint
- log.UpstreamURL = upstream
- u, err := url.Parse(upstream)
- if err != nil {
- p.logger.With("upstreamURL", upstream).WarnContext(ctx, "failed to parse upstream URL", "error", err)
- writeErrResp(w, http.StatusInternalServerError, "invalid upstream URL", "proxy_error")
- return err
- }
-
- startTime := time.Now()
-
- client := request.NewClient(
- u.Scheme,
- u.Host,
- 30*time.Second,
- request.WithClient(p.client),
- )
- client.SetDebug(p.cfg.Debug)
- resp, err := request.Post[openai.CompletionResponse](client, u.Path, req, request.WithHeader(request.Header{
- "Authorization": "Bearer " + m.APIKey,
- }))
- if err != nil {
- p.logger.With("upstreamURL", upstream).WarnContext(ctx, "failed to send upstream request", "error", err)
- writeErrResp(w, http.StatusInternalServerError, "failed to send upstream request", "proxy_error")
- log.Latency = time.Since(startTime).Milliseconds()
- return err
- }
-
- latency := time.Since(startTime)
-
- // log the request
- p.logger.With(
- "statusCode", http.StatusOK,
- "upstreamURL", upstream,
- "modelName", req.Model,
- "modelType", consts.ModelTypeCoder,
- "apiBase", m.APIBase,
- "requestBody", req,
- "resp", resp,
- "latency", latency.String(),
- ).DebugContext(ctx, "forwarding request to upstream API")
-
- go p.recordCompletion(c, m.ID, req, resp)
-
- // update the request log
- log.StatusCode = http.StatusOK
- log.ResponseBody = resp
- log.ResponseHeader = resp.Header()
- log.Latency = latency.Milliseconds()
-
- w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(http.StatusOK)
- b, err := json.Marshal(resp)
- if err != nil {
- p.logger.With("response", resp).WarnContext(ctx, "failed to serialize response", "error", err)
- return err
- }
- w.Write(b)
- return nil
- })
- }
- func (p *LLMProxy) recordCompletion(c *Ctx, modelID string, req domain.CompletionRequest, resp *openai.CompletionResponse) {
- if resp.Usage.CompletionTokens == 0 {
- return
- }
+ u, err := url.Parse(m.APIBase)
+ if err != nil {
+ l.logger.ErrorContext(r.In.Context(), "parse model api base failed", slog.String("path", r.In.URL.Path), slog.Any("err", err))
+ return
+ }

- ctx := context.Background()
- prompt := req.Prompt.(string)
- rc := &domain.RecordParam{
- TaskID: resp.ID,
- UserID: c.UserID,
- ModelID: modelID,
- ModelType: consts.ModelTypeCoder,
- Prompt: prompt,
- ProgramLanguage: req.Metadata["program_language"],
- InputTokens: int64(resp.Usage.PromptTokens),
- OutputTokens: int64(resp.Usage.CompletionTokens),
- }
-
- for _, choice := range resp.Choices {
- rc.Completion += choice.Text
- }
- lines := strings.Count(rc.Completion, "\n") + 1
- rc.CodeLines = int64(lines)
-
- if err := p.usecase.Record(ctx, rc); err != nil {
- p.logger.With("modelID", modelID, "modelName", req.Model, "modelType", consts.ModelTypeCoder).WarnContext(ctx, "failed to record request", "error", err)
- }
- }
+ r.Out.URL.Scheme = u.Scheme
+ r.Out.URL.Host = u.Host
+ r.Out.URL.Path = r.In.URL.Path
+ r.Out.Header.Set("Authorization", "Bearer "+m.APIKey)
+ r.SetXForwarded()
+ r.Out.Host = u.Host
+ l.logger.With("in", r.In.URL.Path, "out", r.Out.URL.Path).DebugContext(r.In.Context(), "rewrite request")
+ }
+ func (l *LLMProxy) modifyResponse(resp *http.Response) error {
+ ctx := resp.Request.Context()
+ if pctx, ok := ctx.Value(CtxKey{}).(*ProxyCtx); ok {
+ pctx.ctx = ctx
+ pctx.RespHeader = resp.Header
+ resp.Body = NewRecorder(l.cfg, pctx, resp.Body, l.logger, l.usecase)
+ }
+ return nil
+ }
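The modifyResponse hook above is the heart of the separation: it swaps resp.Body for a Recorder, so the client keeps streaming untouched while a side channel sees every byte. A minimal sketch of that wrap-the-body idea, independent of MonkeyCode's types (spyBody and the package name are illustrative):

    package shadow

    import (
    	"io"
    	"net/http"
    )

    // spyBody forwards reads to the real body while copying every chunk
    // to an observer channel; a goroutine can analyze the copy offline.
    type spyBody struct {
    	src io.ReadCloser
    	ch  chan []byte
    }

    func (s *spyBody) Read(p []byte) (int, error) {
    	n, err := s.src.Read(p)
    	if n > 0 {
    		chunk := make([]byte, n)
    		copy(chunk, p[:n])
    		s.ch <- chunk // hand a copy to the analysis side
    	}
    	return n, err
    }

    func (s *spyBody) Close() error {
    	close(s.ch) // lets the analysis goroutine finish
    	return s.src.Close()
    }

    func modifyResponse(resp *http.Response) error {
    	ch := make(chan []byte, 1024)
    	go func() { // drain and analyze in the background
    		for range ch {
    		}
    	}()
    	resp.Body = &spyBody{src: resp.Body, ch: ch}
    	return nil
    }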
- func (p *LLMProxy) HandleChatCompletion(ctx context.Context, w http.ResponseWriter, req *openai.ChatCompletionRequest) {
- if req.Stream {
- p.handleChatCompletionStream(ctx, w, req)
- } else {
- p.handleChatCompletion(ctx, w, req)
- }
- }
- func streamRead(ctx context.Context, r io.Reader, fn func([]byte) error) error {
- reader := bufio.NewReaderSize(r, 32*1024)
- for {
- select {
- case <-ctx.Done():
- return fmt.Errorf("streaming request canceled: %w", ctx.Err())
- default:
- line, err := reader.ReadBytes('\n')
- if err == io.EOF {
- return nil
- }
- if err != nil {
- return fmt.Errorf("failed to read streaming data: %w", err)
- }
-
- line = bytes.TrimSpace(line)
- if len(line) == 0 {
- continue
- }
-
- line = append(line, '\n')
- line = append(line, '\n')
-
- fn(line)
- }
- }
- }
- func (p *LLMProxy) getPrompt(ctx context.Context, req *openai.ChatCompletionRequest) string {
- prompt := ""
- parse := promptparser.New(promptparser.KindTask)
- for _, message := range req.Messages {
- if message.Role == "system" {
- continue
- }
-
- if strings.Contains(message.Content, "<task>") ||
- strings.Contains(message.Content, "<feedback>") ||
- strings.Contains(message.Content, "<user_message>") {
- if info, err := parse.Parse(message.Content); err == nil {
- prompt = info.Prompt
- } else {
- p.logger.With("message", message.Content).WarnContext(ctx, "failed to parse prompt", "error", err)
- }
- }
-
- for _, m := range message.MultiContent {
- if strings.Contains(m.Text, "<task>") ||
- strings.Contains(m.Text, "<feedback>") ||
- strings.Contains(m.Text, "<user_message>") {
- if info, err := parse.Parse(m.Text); err == nil {
- prompt = info.Prompt
- } else {
- p.logger.With("message", m.Text).WarnContext(ctx, "failed to parse prompt", "error", err)
- }
- }
- }
- }
- return prompt
- }
- func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.ResponseWriter, req *openai.ChatCompletionRequest) {
- endpoint := "/chat/completions"
- p.handle(ctx, func(c *Ctx, log *RequestResponseLog) error {
- startTime := time.Now()
-
- m, err := p.usecase.SelectModelWithLoadBalancing(req.Model, consts.ModelTypeLLM)
- if err != nil {
- p.logger.With("modelName", req.Model, "modelType", consts.ModelTypeLLM).WarnContext(ctx, "model selection for streaming request failed", "error", err)
- writeErrResp(w, http.StatusNotFound, "model not found", "proxy_error")
- return err
- }
-
- upstream := m.APIBase + endpoint
- log.UpstreamURL = upstream
-
- body, err := json.Marshal(req)
- if err != nil {
- p.logger.ErrorContext(ctx, "failed to serialize request body", "error", err)
- return fmt.Errorf("failed to serialize request body: %w", err)
- }
-
- newReq, err := http.NewRequestWithContext(ctx, http.MethodPost, upstream, bytes.NewReader(body))
- if err != nil {
- p.logger.With("upstream", upstream).WarnContext(ctx, "failed to create upstream streaming request", "error", err)
- return fmt.Errorf("failed to create upstream request: %w", err)
- }
-
- newReq.Header.Set("Content-Type", "application/json")
- newReq.Header.Set("Accept", "text/event-stream")
- newReq.Header.Set("Authorization", "Bearer "+m.APIKey)
-
- // save request headers (sensitive values stripped)
- requestHeaders := make(map[string][]string)
- for k, v := range newReq.Header {
- if k != "Authorization" {
- requestHeaders[k] = v
- } else {
- // redact sensitive values
- requestHeaders[k] = []string{"Bearer ***"}
- }
- }
- log.RequestHeader = requestHeaders
-
- logger := p.logger.With(
- "request_id", c.RequestID,
- "source_ip", c.SourceIP,
- "upstreamURL", upstream,
- "modelName", m.ModelName,
- "modelType", consts.ModelTypeLLM,
- "apiBase", m.APIBase,
- )
-
- logger.With(
- "upstreamURL", upstream,
- "requestHeader", newReq.Header,
- "requestBody", req,
- "messages", cvt.Filter(req.Messages, func(i int, v openai.ChatCompletionMessage) (openai.ChatCompletionMessage, bool) {
- return v, v.Role != "system"
- }),
- ).DebugContext(ctx, "forwarding streaming request to upstream API")
-
- // send the request
- resp, err := p.streamClient.Do(newReq)
- if err != nil {
- p.logger.With("upstreamURL", upstream).WarnContext(ctx, "failed to send upstream streaming request", "error", err)
- return fmt.Errorf("failed to send upstream request: %w", err)
- }
- defer resp.Body.Close()
-
- // check response status
- if resp.StatusCode != http.StatusOK {
- responseBody, _ := io.ReadAll(resp.Body)
-
- // record the error in the log
- log.StatusCode = resp.StatusCode
- log.ResponseHeader = resp.Header
- log.ResponseBody = string(responseBody)
- log.Latency = time.Since(startTime).Milliseconds()
-
- // log the raw failed streaming response at debug level
- logger.With(
- "statusCode", resp.StatusCode,
- "responseHeader", resp.Header,
- "responseBody", string(responseBody),
- ).DebugContext(ctx, "raw content of upstream streaming error response")
-
- var errorResp ErrResp
- if err := json.Unmarshal(responseBody, &errorResp); err == nil {
- logger.With(
- "endpoint", endpoint,
- "requestBody", newReq,
- "statusCode", resp.StatusCode,
- "errorType", errorResp.Error.Type,
- "errorCode", errorResp.Error.Code,
- "errorMessage", errorResp.Error.Message,
- "latency", time.Since(startTime),
- ).WarnContext(ctx, "details of upstream API streaming request failure")
-
- return fmt.Errorf("upstream API returned an error: %s", errorResp.Error.Message)
- }
-
- logger.With(
- "endpoint", endpoint,
- "requestBody", newReq,
- "statusCode", resp.StatusCode,
- "responseBody", string(responseBody),
- ).WarnContext(ctx, "details of upstream API streaming request failure")
-
- return fmt.Errorf("upstream API returned non-200 status: %d, response: %s", resp.StatusCode, string(responseBody))
- }
-
- log.StatusCode = resp.StatusCode
- log.ResponseHeader = resp.Header
-
- logger.With(
- "statusCode", resp.StatusCode,
- "responseHeader", resp.Header,
- ).DebugContext(ctx, "upstream streaming response headers")
-
- // set response headers
- w.Header().Set("Content-Type", "text/event-stream")
- w.Header().Set("Cache-Control", "no-cache")
- w.Header().Set("Connection", "keep-alive")
- w.Header().Set("Transfer-Encoding", "chunked")
- w.Header().Set("X-Accel-Buffering", "no")
-
- recorder := NewChatRecorder(
- ctx,
- c,
- p.usecase,
- m,
- req,
- resp.Body,
- w,
- p.logger.With("module", "ChatRecorder"),
- )
- defer recorder.Close()
- return recorder.Stream()
- })
- }
- func (p *LLMProxy) handleChatCompletion(ctx context.Context, w http.ResponseWriter, req *openai.ChatCompletionRequest) {
- endpoint := "/chat/completions"
- p.handle(ctx, func(c *Ctx, log *RequestResponseLog) error {
- // pick a model via load balancing
- m, err := p.usecase.SelectModelWithLoadBalancing(req.Model, consts.ModelTypeCoder)
- if err != nil {
- p.logger.With("modelName", req.Model, "modelType", consts.ModelTypeCoder).WarnContext(ctx, "model selection failed", "error", err)
- writeErrResp(w, http.StatusNotFound, "model not found", "proxy_error")
- return err
- }
-
- // build the upstream API URL
- upstream := m.APIBase + endpoint
- log.UpstreamURL = upstream
- u, err := url.Parse(upstream)
- if err != nil {
- p.logger.With("upstreamURL", upstream).WarnContext(ctx, "failed to parse upstream URL", "error", err)
- writeErrResp(w, http.StatusInternalServerError, "invalid upstream URL", "proxy_error")
- return err
- }
-
- startTime := time.Now()
- prompt := p.getPrompt(ctx, req)
- mode := req.Metadata["mode"]
- taskID := req.Metadata["task_id"]
-
- client := request.NewClient(
- u.Scheme,
- u.Host,
- 30*time.Second,
- request.WithClient(p.client),
- )
- resp, err := request.Post[openai.ChatCompletionResponse](client, u.Path, req, request.WithHeader(request.Header{
- "Authorization": "Bearer " + m.APIKey,
- }))
- if err != nil {
- p.logger.With("upstreamURL", upstream).WarnContext(ctx, "failed to send upstream request", "error", err)
- writeErrResp(w, http.StatusInternalServerError, "failed to send upstream request", "proxy_error")
- log.Latency = time.Since(startTime).Milliseconds()
- return err
- }
-
- // log the request
- p.logger.With(
- "upstreamURL", upstream,
- "modelName", req.Model,
- "modelType", consts.ModelTypeCoder,
- "apiBase", m.APIBase,
- "requestBody", req,
- ).DebugContext(ctx, "forwarding request to upstream API")
-
- go func() {
- rc := &domain.RecordParam{
- TaskID: taskID,
- UserID: c.UserID,
- Prompt: prompt,
- WorkMode: mode,
- ModelID: m.ID,
- ModelType: m.ModelType,
- InputTokens: int64(resp.Usage.PromptTokens),
- OutputTokens: int64(resp.Usage.CompletionTokens),
- ProgramLanguage: req.Metadata["program_language"],
- }
-
- for _, choice := range resp.Choices {
- rc.Completion += choice.Message.Content + "\n\n"
- }
-
- p.logger.With("record", rc).DebugContext(ctx, "record")
-
- if err := p.usecase.Record(ctx, rc); err != nil {
- p.logger.With("modelID", m.ID, "modelName", req.Model, "modelType", consts.ModelTypeCoder).WarnContext(ctx, "failed to record request", "error", err)
- }
- }()
-
- // compute request latency
- latency := time.Since(startTime)
-
- // update the request log
- log.StatusCode = http.StatusOK
- log.ResponseBody = resp
- log.ResponseHeader = resp.Header()
- log.Latency = latency.Milliseconds()
-
- // log the response status
- p.logger.With(
- "statusCode", http.StatusOK,
- "responseHeader", resp.Header(),
- "responseBody", resp,
- "latency", latency.String(),
- ).DebugContext(ctx, "upstream API response")
-
- w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(http.StatusOK)
- b, err := json.Marshal(resp)
- if err != nil {
- p.logger.With("response", resp).WarnContext(ctx, "failed to serialize response", "error", err)
- return err
- }
- w.Write(b)
- return nil
- })
- }
- func (p *LLMProxy) HandleEmbeddings(ctx context.Context, w http.ResponseWriter, req *openai.EmbeddingRequest) {
+ func (l *LLMProxy) errorHandler(w http.ResponseWriter, r *http.Request, err error) {
+ l.logger.ErrorContext(r.Context(), "error handler", slog.String("path", r.URL.Path), slog.Any("err", err))
+ }

@@ -5,191 +5,331 @@ import (
"encoding/json"
"io"
"log/slog"
"net/http"
"os"
"path/filepath"
"strings"
"time"

"github.com/rokku-c/go-openai"

"github.com/chaitin/MonkeyCode/backend/config"
"github.com/chaitin/MonkeyCode/backend/consts"
"github.com/chaitin/MonkeyCode/backend/domain"
"github.com/chaitin/MonkeyCode/backend/pkg/promptparser"
"github.com/chaitin/MonkeyCode/backend/pkg/tee"
)
- type ChatRecorder struct {
- *tee.Tee
- ctx context.Context
- cx *Ctx
- usecase domain.ProxyUsecase
- req *openai.ChatCompletionRequest
- logger *slog.Logger
- model *domain.Model
- completion strings.Builder // accumulates the full response content
- usage *openai.Usage // final usage stats
- buffer strings.Builder // buffers incomplete lines
- recorded bool // whether the full conversation has been recorded
- }
+ type Recorder struct {
+ cfg *config.Config
+ usecase domain.ProxyUsecase
+ shadown chan []byte
+ src io.ReadCloser
+ ctx *ProxyCtx
+ logger *slog.Logger
+ logFile *os.File
+ }
- func NewChatRecorder(
- ctx context.Context,
- cx *Ctx,
- usecase domain.ProxyUsecase,
- model *domain.Model,
- req *openai.ChatCompletionRequest,
- r io.Reader,
- w io.Writer,
+ var _ io.ReadCloser = &Recorder{}
+
+ func NewRecorder(
+ cfg *config.Config,
+ ctx *ProxyCtx,
+ src io.ReadCloser,
logger *slog.Logger,
- ) *ChatRecorder {
- c := &ChatRecorder{
- ctx: ctx,
- cx: cx,
+ usecase domain.ProxyUsecase,
+ ) *Recorder {
+ r := &Recorder{
+ cfg: cfg,
usecase: usecase,
- model: model,
- req: req,
+ shadown: make(chan []byte, 128*1024),
+ src: src,
+ ctx: ctx,
logger: logger,
}
- c.Tee = tee.NewTee(ctx, logger, r, w, c.handle)
- return c
+ go r.handleShadow()
+ return r
}
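NewRecorder above launches handleShadow before any bytes flow: the Recorder's Read copies each chunk onto the buffered shadown channel, and handleShadow consumes it asynchronously, so recording never stalls the client's stream. The spyBody sketch shown earlier after modifyResponse illustrates exactly this producer/consumer split.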
- func (c *ChatRecorder) handle(ctx context.Context, data []byte) error {
- c.buffer.Write(data)
- bufferContent := c.buffer.String()
+ func formatHeader(header http.Header) map[string]string {
+ headerMap := make(map[string]string)
+ for key, values := range header {
+ headerMap[key] = strings.Join(values, ",")
+ }
+ return headerMap
+ }

- lines := strings.Split(bufferContent, "\n")
- if len(lines) > 0 {
- lastLine := lines[len(lines)-1]
- lines = lines[:len(lines)-1]
- c.buffer.Reset()
- c.buffer.WriteString(lastLine)
+ func (r *Recorder) handleShadow() {
+ body, err := r.ctx.ReqTee.GetBody()
+ if err != nil {
+ r.logger.WarnContext(r.ctx.ctx, "get req tee body failed", "error", err)
+ return
+ }

- for _, line := range lines {
- if err := c.processSSELine(ctx, line); err != nil {
- return err
+ var (
+ taskID, mode, prompt, language string
+ )

+ switch r.ctx.Model.ModelType {
+ case consts.ModelTypeLLM:
+ var req openai.ChatCompletionRequest
+ if err := json.Unmarshal(body, &req); err != nil {
+ r.logger.WarnContext(r.ctx.ctx, "unmarshal chat completion request failed", "error", err)
+ return
+ }
+ prompt = r.getPrompt(r.ctx.ctx, &req)
+ taskID = req.Metadata["task_id"]
+ mode = req.Metadata["mode"]

+ case consts.ModelTypeCoder:
+ var req domain.CompletionRequest
+ if err := json.Unmarshal(body, &req); err != nil {
+ r.logger.WarnContext(r.ctx.ctx, "unmarshal completion request failed", "error", err)
+ return
+ }
+ prompt = req.Prompt.(string)
+ taskID = req.Metadata["task_id"]
+ mode = req.Metadata["mode"]
+ language = req.Metadata["program_language"]

+ default:
+ r.logger.WarnContext(r.ctx.ctx, "skip handle shadow, model type not support", "modelType", r.ctx.Model.ModelType)
+ return
+ }

+ r.createFile(taskID)
+ r.writeMeta(body)

+ rc := &domain.RecordParam{
+ RequestID: r.ctx.RequestID,
+ TaskID: taskID,
+ UserID: r.ctx.UserID,
+ ModelID: r.ctx.Model.ID,
+ ModelType: r.ctx.Model.ModelType,
+ WorkMode: mode,
+ Prompt: prompt,
+ ProgramLanguage: language,
+ Role: consts.ChatRoleUser,
+ }

+ var assistantRc *domain.RecordParam
+ ct := r.ctx.RespHeader.Get("Content-Type")
+ if strings.Contains(ct, "stream") {
+ r.handleStream(rc)
+ if r.ctx.Model.ModelType == consts.ModelTypeLLM {
+ assistantRc = rc.Clone()
+ assistantRc.Role = consts.ChatRoleAssistant
+ rc.Completion = ""
+ rc.OutputTokens = 0
+ }
+ } else {
+ r.handleJson(rc)
+ }

+ r.logger.
+ With("header", formatHeader(r.ctx.Header)).
+ With("resp_header", formatHeader(r.ctx.RespHeader)).
+ DebugContext(r.ctx.ctx, "handle shadow", "rc", rc)

+ if err := r.usecase.Record(context.Background(), rc); err != nil {
+ r.logger.WarnContext(r.ctx.ctx, "failed to record request", "error", err)
+ }

+ if assistantRc != nil {
+ if err := r.usecase.Record(context.Background(), assistantRc); err != nil {
+ r.logger.WarnContext(r.ctx.ctx, "failed to record request", "error", err)
+ }
+ }
+ }
+ func (r *Recorder) writeMeta(body []byte) {
+ if r.logFile == nil {
+ return
+ }

+ r.logFile.WriteString("------------------ Request Header ------------------\n")
+ for key, value := range formatHeader(r.ctx.Header) {
+ r.logFile.WriteString(key + ": " + value + "\n")
+ }
+ r.logFile.WriteString("------------------ Request Body ------------------\n")
+ r.logFile.WriteString(string(body))
+ r.logFile.WriteString("\n")
+ r.logFile.WriteString("------------------ Response Header ------------------\n")
+ for key, value := range formatHeader(r.ctx.RespHeader) {
+ r.logFile.WriteString(key + ": " + value + "\n")
+ }
+ r.logFile.WriteString("------------------ Response Body ------------------\n")
+ }
+ func (r *Recorder) createFile(taskID string) {
+ if r.cfg.LLMProxy.RequestLogPath == "" {
+ return
+ }

+ dir := filepath.Join(r.cfg.LLMProxy.RequestLogPath, time.Now().Format("2006010215"))
+ if err := os.MkdirAll(dir, 0755); err != nil {
+ r.logger.WarnContext(r.ctx.ctx, "create dir failed", "error", err)
+ return
+ }

+ id := r.ctx.RequestID
+ if taskID != "" {
+ id = taskID
+ }
+ filename := filepath.Join(dir, id+".log")
+ f, err := os.OpenFile(filename, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
+ if err != nil {
+ r.logger.WarnContext(r.ctx.ctx, "create file failed", "error", err)
+ return
+ }
+ r.logFile = f
+ }
+ func (r *Recorder) handleJson(rc *domain.RecordParam) {
+ buffer := strings.Builder{}
+ for data := range r.shadown {
+ buffer.Write(data)
+ if r.logFile != nil {
+ r.logFile.Write(data)
+ }
+ }

- return nil
+ switch rc.ModelType {
+ case consts.ModelTypeLLM:
+ var resp openai.ChatCompletionResponse
+ if err := json.Unmarshal([]byte(buffer.String()), &resp); err != nil {
+ r.logger.WarnContext(r.ctx.ctx, "unmarshal chat completion response failed", "error", err)
+ return
+ }
+ if len(resp.Choices) > 0 {
+ rc.Completion = resp.Choices[0].Message.Content
+ rc.InputTokens = int64(resp.Usage.PromptTokens)
+ rc.OutputTokens = int64(resp.Usage.CompletionTokens)
+ }

+ case consts.ModelTypeCoder:
+ var resp openai.CompletionResponse
+ if err := json.Unmarshal([]byte(buffer.String()), &resp); err != nil {
+ r.logger.WarnContext(r.ctx.ctx, "unmarshal completion response failed", "error", err)
+ return
+ }
+ if rc.TaskID == "" {
+ rc.TaskID = resp.ID
+ }
+ rc.InputTokens = int64(resp.Usage.PromptTokens)
+ rc.OutputTokens = int64(resp.Usage.CompletionTokens)
+ if len(resp.Choices) > 0 {
+ rc.Completion = resp.Choices[0].Text
+ rc.CodeLines = int64(strings.Count(resp.Choices[0].Text, "\n"))
+ }
+ }
+ }
- func (c *ChatRecorder) processSSELine(ctx context.Context, line string) error {
- line = strings.TrimSpace(line)
+ func (r *Recorder) handleStream(rc *domain.RecordParam) {
+ buffer := strings.Builder{}
+ for data := range r.shadown {
+ buffer.Write(data)
+ cnt := buffer.String()

+ if r.logFile != nil {
+ r.logFile.Write(data)
+ }

+ lines := strings.Split(cnt, "\n")
+ if len(lines) > 0 {
+ lastLine := lines[len(lines)-1]
+ lines = lines[:len(lines)-1]
+ buffer.Reset()
+ buffer.WriteString(lastLine)
+ }

+ for _, line := range lines {
+ if err := r.processSSELine(r.ctx.ctx, line, rc); err != nil {
+ r.logger.WarnContext(r.ctx.ctx, "failed to process SSE line", "error", err)
+ }
+ }
+ }
+ }
+ func (r *Recorder) processSSELine(ctx context.Context, line string, rc *domain.RecordParam) error {
line = strings.TrimSpace(line)
if line == "" || !strings.HasPrefix(line, "data:") {
return nil
}

- dataContent := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
- if dataContent == "" {
+ data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
+ if data == "" {
return nil
}

- if dataContent == "[DONE]" {
- c.processCompletedChat(ctx)
+ if data == "[DONE]" {
return nil
}

- var resp openai.ChatCompletionStreamResponse
- if err := json.Unmarshal([]byte(dataContent), &resp); err != nil {
- c.logger.With("data", dataContent).WarnContext(ctx, "failed to parse streaming response", "error", err)
- return nil
- }
+ switch r.ctx.Model.ModelType {
+ case consts.ModelTypeLLM:
+ var resp openai.ChatCompletionStreamResponse
+ if err := json.Unmarshal([]byte(data), &resp); err != nil {
+ r.logger.With("model_type", r.ctx.Model.ModelType).With("data", data).WarnContext(ctx, "failed to parse SSE line", "error", err)
+ return nil
+ }
+ if resp.Usage != nil {
+ rc.InputTokens = int64(resp.Usage.PromptTokens)
+ rc.OutputTokens += int64(resp.Usage.CompletionTokens)
+ }
+ if len(resp.Choices) > 0 {
+ content := resp.Choices[0].Delta.Content
+ if content != "" {
+ rc.Completion += content
+ }
+ }

- prompt := c.getPrompt(ctx, c.req)
- mode := c.req.Metadata["mode"]
- taskID := c.req.Metadata["task_id"]
+ case consts.ModelTypeCoder:
+ var resp openai.CompletionResponse
+ if err := json.Unmarshal([]byte(data), &resp); err != nil {
+ r.logger.With("model_type", r.ctx.Model.ModelType).With("data", data).WarnContext(ctx, "failed to parse SSE line", "error", err)
+ return nil
+ }

- rc := &domain.RecordParam{
- RequestID: c.cx.RequestID,
- TaskID: taskID,
- UserID: c.cx.UserID,
- ModelID: c.model.ID,
- ModelType: c.model.ModelType,
- WorkMode: mode,
- Prompt: prompt,
- Role: consts.ChatRoleAssistant,
- }
+ if resp.Usage != nil {
+ if rc.TaskID == "" {
+ rc.TaskID = resp.ID
+ }
+ rc.InputTokens = int64(resp.Usage.PromptTokens)
+ }

- if rc.Prompt != "" {
- urc := rc.Clone()
- urc.Role = consts.ChatRoleUser
- urc.Completion = urc.Prompt
- if err := c.usecase.Record(context.Background(), urc); err != nil {
- c.logger.With("modelID", c.model.ID, "modelName", c.model.ModelName, "modelType", consts.ModelTypeLLM).
- WarnContext(ctx, "failed to insert streaming record", "error", err)
+ rc.OutputTokens += int64(resp.Usage.CompletionTokens)
+ if len(resp.Choices) > 0 {
+ rc.Completion += resp.Choices[0].Text
+ rc.CodeLines += int64(strings.Count(resp.Choices[0].Text, "\n"))
+ }
+ }

- if len(resp.Choices) > 0 {
- content := resp.Choices[0].Delta.Content
- if content != "" {
- c.completion.WriteString(content)
- }
- }

- if resp.Usage != nil {
- c.usage = resp.Usage
- }

return nil
}

- func (c *ChatRecorder) processCompletedChat(ctx context.Context) {
- if c.recorded {
- return // avoid duplicate records
+ // Close implements io.ReadCloser.
+ func (r *Recorder) Close() error {
+ if r.shadown != nil {
+ close(r.shadown)
}

- mode := c.req.Metadata["mode"]
- taskID := c.req.Metadata["task_id"]

- rc := &domain.RecordParam{
- RequestID: c.cx.RequestID,
- TaskID: taskID,
- UserID: c.cx.UserID,
- ModelID: c.model.ID,
- ModelType: c.model.ModelType,
- WorkMode: mode,
- Role: consts.ChatRoleAssistant,
- Completion: c.completion.String(),
- InputTokens: int64(c.usage.PromptTokens),
- OutputTokens: int64(c.usage.CompletionTokens),
- }

- if err := c.usecase.Record(context.Background(), rc); err != nil {
- c.logger.With("modelID", c.model.ID, "modelName", c.model.ModelName, "modelType", consts.ModelTypeLLM).
- WarnContext(ctx, "failed to insert streaming record", "error", err)
- } else {
- c.recorded = true
- c.logger.With("requestID", c.cx.RequestID, "completion_length", len(c.completion.String())).
- InfoContext(ctx, "streaming conversation record saved")
+ if r.logFile != nil {
+ r.logFile.Close()
}
+ return r.src.Close()
}
- // Close shuts down the recorder and makes sure data is saved
- func (c *ChatRecorder) Close() {
- // if content accumulated but was never recorded, force-save it
- if !c.recorded && c.completion.Len() > 0 {
- c.logger.With("requestID", c.cx.RequestID).
- WarnContext(c.ctx, "data stream interrupted unexpectedly; force-saving accumulated content")
- c.processCompletedChat(c.ctx)
+ // Read implements io.ReadCloser.
+ func (r *Recorder) Read(p []byte) (n int, err error) {
+ n, err = r.src.Read(p)
+ if n > 0 {
+ data := make([]byte, n)
+ copy(data, p[:n])
+ r.shadown <- data
+ }

- if c.Tee != nil {
- c.Tee.Close()
+ if err != nil {
+ return
}
+ return
}

- func (c *ChatRecorder) Reset() {
- c.completion.Reset()
- c.buffer.Reset()
- c.usage = nil
- c.recorded = false
- }

- func (c *ChatRecorder) getPrompt(ctx context.Context, req *openai.ChatCompletionRequest) string {
+ func (r *Recorder) getPrompt(ctx context.Context, req *openai.ChatCompletionRequest) string {
prompt := ""
parse := promptparser.New(promptparser.KindTask)
for _, message := range req.Messages {
@@ -203,7 +343,7 @@ func (c *ChatRecorder) getPrompt(ctx context.Context, req *openai.ChatCompletion
if info, err := parse.Parse(message.Content); err == nil {
prompt = info.Prompt
} else {
- c.logger.With("message", message.Content).WarnContext(ctx, "failed to parse prompt", "error", err)
+ r.logger.With("message", message.Content).WarnContext(ctx, "failed to parse prompt", "error", err)
}
}

@@ -214,7 +354,7 @@ func (c *ChatRecorder) getPrompt(ctx context.Context, req *openai.ChatCompletion
if info, err := parse.Parse(m.Text); err == nil {
prompt = info.Prompt
} else {
- c.logger.With("message", m.Text).WarnContext(ctx, "failed to parse prompt", "error", err)
+ r.logger.With("message", m.Text).WarnContext(ctx, "failed to parse prompt", "error", err)
}
}
}
@@ -23,6 +23,7 @@ type UserHandler struct {
	session *session.Session
	logger  *slog.Logger
	cfg     *config.Config
	limitCh chan struct{}
}

func NewUserHandler(
@@ -40,6 +41,7 @@ func NewUserHandler(
		logger:  logger,
		cfg:     cfg,
		euse:    euse,
		limitCh: make(chan struct{}, cfg.Extension.Limit),
	}

	w.GET("/api/v1/static/vsix/:version", web.BaseHandler(u.VSIXDownload))
@@ -94,6 +96,11 @@ func (h *UserHandler) VSCodeAuthInit(c *web.Context, req domain.VSCodeAuthInitRe
// @Produce octet-stream
// @Router /api/v1/static/vsix [get]
func (h *UserHandler) VSIXDownload(c *web.Context) error {
	h.limitCh <- struct{}{}
	defer func() {
		<-h.limitCh
	}()

	v, err := h.euse.GetByVersion(c.Request().Context(), c.Param("version"))
	if err != nil {
		return err
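
The limitCh field added above is the standard buffered-channel semaphore: its capacity (cfg.Extension.Limit) caps how many VSIX downloads run at once, and each extra request blocks on the send until a running download releases its slot. A standalone sketch of the idiom (the download body is a stand-in, not this handler's code):

package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	const limit = 2 // stands in for cfg.Extension.Limit
	sem := make(chan struct{}, limit)

	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			sem <- struct{}{}        // acquire: blocks once all slots are taken
			defer func() { <-sem }() // release on return
			fmt.Printf("download %d running\n", id)
			time.Sleep(100 * time.Millisecond) // simulated file transfer
		}(i)
	}
	wg.Wait()
}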
81	backend/pkg/tee/reqtee.go	Normal file
@@ -0,0 +1,81 @@
package tee

import (
	"bytes"
	"fmt"
	"io"
	"sync"
	"time"
)

type ReqTee struct {
	src     io.ReadCloser
	buf     bytes.Buffer
	done    chan struct{}
	mu      sync.RWMutex
	closed  bool
	maxSize int64 // maximum buffer size; 0 means unlimited
}

var _ io.ReadCloser = &ReqTee{}

func NewReqTee(src io.ReadCloser) *ReqTee {
	return NewReqTeeWithMaxSize(src, 0)
}

// NewReqTeeWithMaxSize creates a ReqTee with a maximum buffer size limit.
func NewReqTeeWithMaxSize(src io.ReadCloser, maxSize int64) *ReqTee {
	return &ReqTee{
		src:     src,
		buf:     bytes.Buffer{},
		done:    make(chan struct{}),
		maxSize: maxSize,
	}
}

func (r *ReqTee) GetBody() ([]byte, error) {
	return r.GetBodyWithTimeout(30 * time.Second)
}

// GetBodyWithTimeout returns the buffered data, waiting up to the given timeout.
func (r *ReqTee) GetBodyWithTimeout(timeout time.Duration) ([]byte, error) {
	select {
	case <-r.done:
		r.mu.RLock()
		defer r.mu.RUnlock()
		return r.buf.Bytes(), nil
	case <-time.After(timeout):
		return nil, fmt.Errorf("timeout waiting for data")
	}
}

// Close implements io.ReadCloser.
func (r *ReqTee) Close() error {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.closed {
		return nil
	}
	r.closed = true

	close(r.done)
	return r.src.Close()
}

// Read implements io.ReadCloser.
func (r *ReqTee) Read(p []byte) (n int, err error) {
	n, err = r.src.Read(p)
	if n > 0 {
		r.mu.Lock()
		// enforce the buffer size limit
		if r.maxSize > 0 && int64(r.buf.Len()+n) > r.maxSize {
			r.mu.Unlock()
			return n, fmt.Errorf("buffer size limit exceeded: %d bytes", r.maxSize)
		}
		// write straight into the buffer to avoid an extra allocation and copy
		r.buf.Write(p[:n])
		r.mu.Unlock()
	}
	return n, err
}
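
ReqTee above exists so the proxy can forward a request body upstream while the analysis side later retrieves the complete copy through GetBody/GetBodyWithTimeout, gated on the done channel that Close fires. A rough usage sketch follows; the handler and its wiring are assumptions for illustration, not code from this PR:

package tee

import (
	"log"
	"net/http"
	"time"
)

// Handle is an assumed handler showing one way ReqTee could be wired in.
func Handle(w http.ResponseWriter, req *http.Request) {
	t := NewReqTeeWithMaxSize(req.Body, 10<<20) // cap the buffered copy at 10 MiB
	req.Body = t                                // the upstream request reads through the tee
	defer t.Close()                             // fires done so GetBodyWithTimeout can return

	go func() {
		// Blocks until Close fires the done channel or the timeout elapses.
		body, err := t.GetBodyWithTimeout(30 * time.Second)
		if err != nil {
			log.Printf("request analysis skipped: %v", err)
			return
		}
		log.Printf("captured %d request bytes for analysis", len(body))
	}()

	// ... forward req upstream here; reading the body fills the tee's buffer
	// as a side effect, and Close on return releases the goroutine above.
}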
@@ -1,103 +0,0 @@
package tee

import (
	"context"
	"io"
	"log/slog"
	"net/http"
	"sync"
)

type TeeHandleFunc func(ctx context.Context, data []byte) error

var bufferPool = sync.Pool{
	New: func() any {
		buf := make([]byte, 4096)
		return &buf
	},
}

type Tee struct {
	ctx    context.Context
	logger *slog.Logger
	Reader io.Reader
	Writer io.Writer
	ch     chan []byte
	handle TeeHandleFunc
}

func NewTee(
	ctx context.Context,
	logger *slog.Logger,
	reader io.Reader,
	writer io.Writer,
	handle TeeHandleFunc,
) *Tee {
	t := &Tee{
		ctx:    ctx,
		logger: logger,
		Reader: reader,
		Writer: writer,
		handle: handle,
		ch:     make(chan []byte, 32*1024),
	}
	go t.Handle()
	return t
}

func (t *Tee) Close() {
	select {
	case <-t.ch:
		// channel already closed
	default:
		close(t.ch)
	}
}

func (t *Tee) Handle() {
	for {
		select {
		case data, ok := <-t.ch:
			if !ok {
				t.logger.DebugContext(t.ctx, "Tee Handle closed")
				return
			}
			err := t.handle(t.ctx, data)
			if err != nil {
				t.logger.With("data", string(data)).With("error", err).ErrorContext(t.ctx, "Tee Handle error")
				return
			}
		case <-t.ctx.Done():
			t.logger.DebugContext(t.ctx, "Tee Handle ctx done")
			return
		}
	}
}

func (t *Tee) Stream() error {
	bufPtr := bufferPool.Get().(*[]byte)
	buf := *bufPtr
	defer bufferPool.Put(bufPtr)

	for {
		n, err := t.Reader.Read(buf)
		if err != nil {
			if err == io.EOF {
				return nil
			}
			return err
		}
		if n > 0 {
			_, err = t.Writer.Write(buf[:n])
			if err != nil {
				return err
			}
			if flusher, ok := t.Writer.(http.Flusher); ok {
				flusher.Flush()
			}
			data := make([]byte, n)
			copy(data, buf[:n])
			t.ch <- data
		}
	}
}
@@ -1,376 +0,0 @@
package tee

import (
	"bytes"
	"context"
	"errors"
	"io"
	"log/slog"
	"os"
	"strings"
	"sync"
	"testing"
	"time"
)

// mockWriter is a mock implementation of the Writer interface
type mockWriter struct {
	buf     *bytes.Buffer
	delay   time.Duration
	errorOn int // return an error on the Nth write
	count   int
}

func newMockWriter() *mockWriter {
	return &mockWriter{
		buf: &bytes.Buffer{},
	}
}

func (m *mockWriter) Write(p []byte) (n int, err error) {
	m.count++
	if m.errorOn > 0 && m.count >= m.errorOn {
		return 0, errors.New("mock write error")
	}
	if m.delay > 0 {
		time.Sleep(m.delay)
	}
	return m.buf.Write(p)
}

func (m *mockWriter) String() string {
	return m.buf.String()
}

// mockReader is a mock implementation of the Reader interface
type mockReader struct {
	data    []byte
	pos     int
	chunk   int // bytes returned per read
	errorOn int // return an error on the Nth read
	count   int
}

func newMockReader(data string, chunk int) *mockReader {
	return &mockReader{
		data:  []byte(data),
		chunk: chunk,
	}
}

func (m *mockReader) Read(p []byte) (n int, err error) {
	m.count++
	if m.errorOn > 0 && m.count >= m.errorOn {
		return 0, errors.New("mock read error")
	}

	if m.pos >= len(m.data) {
		return 0, io.EOF
	}

	readSize := m.chunk
	if readSize <= 0 || readSize > len(p) {
		readSize = len(p)
	}

	remaining := len(m.data) - m.pos
	if readSize > remaining {
		readSize = remaining
	}

	copy(p, m.data[m.pos:m.pos+readSize])
	m.pos += readSize
	return readSize, nil
}

// TestTeeBasicFunctionality exercises the basic path
func TestTeeBasicFunctionality(t *testing.T) {
	ctx := context.Background()
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))

	testData := "Hello, World! This is a test message."
	reader := newMockReader(testData, 10) // read 10 bytes at a time
	writer := newMockWriter()

	var handledData [][]byte
	var mu sync.Mutex

	handle := func(ctx context.Context, data []byte) error {
		mu.Lock()
		defer mu.Unlock()
		// copy the data, since the original slice may be reused
		dataCopy := make([]byte, len(data))
		copy(dataCopy, data)
		handledData = append(handledData, dataCopy)
		return nil
	}

	tee := NewTee(ctx, logger, reader, writer, handle)
	defer tee.Close()

	err := tee.Stream()
	if err != nil {
		t.Fatalf("Stream() failed: %v", err)
	}

	// wait for processing to finish
	time.Sleep(100 * time.Millisecond)

	// verify the written data
	if writer.String() != testData {
		t.Errorf("Expected writer data %q, got %q", testData, writer.String())
	}

	// verify the handled data
	mu.Lock()
	var totalHandled []byte
	for _, chunk := range handledData {
		totalHandled = append(totalHandled, chunk...)
	}
	mu.Unlock()

	if string(totalHandled) != testData {
		t.Errorf("Expected handled data %q, got %q", testData, string(totalHandled))
	}
}

// TestTeeWithErrors exercises error handling
func TestTeeWithErrors(t *testing.T) {
	ctx := context.Background()
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))

	t.Run("ReaderError", func(t *testing.T) {
		reader := newMockReader("test data", 5)
		reader.errorOn = 2 // fail on the second read
		writer := newMockWriter()

		handle := func(ctx context.Context, data []byte) error {
			return nil
		}

		tee := NewTee(ctx, logger, reader, writer, handle)
		defer tee.Close()

		err := tee.Stream()
		if err == nil {
			t.Error("Expected error from reader, got nil")
		}
	})

	t.Run("WriterError", func(t *testing.T) {
		reader := newMockReader("test data", 5)
		writer := newMockWriter()
		writer.errorOn = 2 // fail on the second write

		handle := func(ctx context.Context, data []byte) error {
			return nil
		}

		tee := NewTee(ctx, logger, reader, writer, handle)
		defer tee.Close()

		err := tee.Stream()
		if err == nil {
			t.Error("Expected error from writer, got nil")
		}
	})

	t.Run("HandleError", func(t *testing.T) {
		reader := newMockReader("test data", 5)
		writer := newMockWriter()

		handle := func(ctx context.Context, data []byte) error {
			return errors.New("handle error")
		}

		tee := NewTee(ctx, logger, reader, writer, handle)
		defer tee.Close()

		// run Stream in a separate goroutine
		go func() {
			tee.Stream()
		}()

		// give the handler time to process data and hit the error
		time.Sleep(200 * time.Millisecond)
	})
}

// TestTeeContextCancellation exercises context cancellation
func TestTeeContextCancellation(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))

	// a reader that keeps producing data
	reader := strings.NewReader(strings.Repeat("test data ", 1000))
	writer := newMockWriter()

	handle := func(ctx context.Context, data []byte) error {
		return nil
	}

	tee := NewTee(ctx, logger, reader, writer, handle)
	defer tee.Close()

	// start Stream in a separate goroutine
	done := make(chan error, 1)
	go func() {
		done <- tee.Stream()
	}()

	// cancel the context after a short delay
	time.Sleep(50 * time.Millisecond)
	cancel()

	// wait for Stream to finish
	select {
	case err := <-done:
		if err != nil && err != io.EOF {
			t.Logf("Stream completed with error: %v", err)
		}
	case <-time.After(2 * time.Second):
		t.Error("Stream did not complete within timeout")
	}
}

// TestTeeConcurrentSafety exercises concurrency safety
func TestTeeConcurrentSafety(t *testing.T) {
	ctx := context.Background()
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))

	testData := strings.Repeat("concurrent test data ", 100)
	reader := strings.NewReader(testData)
	writer := newMockWriter()

	var processedCount int64
	var mu sync.Mutex

	handle := func(ctx context.Context, data []byte) error {
		mu.Lock()
		processedCount++
		mu.Unlock()
		// simulate some processing time
		time.Sleep(time.Microsecond)
		return nil
	}

	tee := NewTee(ctx, logger, reader, writer, handle)
	defer tee.Close()

	err := tee.Stream()
	if err != nil {
		t.Fatalf("Stream() failed: %v", err)
	}

	// wait for all data to be processed
	time.Sleep(500 * time.Millisecond)

	mu.Lock()
	finalCount := processedCount
	mu.Unlock()

	if finalCount == 0 {
		t.Error("No data was processed")
	}

	t.Logf("Processed %d chunks of data", finalCount)
}

// TestBufferPoolEfficiency checks that the buffer pool works
func TestBufferPoolEfficiency(t *testing.T) {
	// verify the pool by repeatedly getting and returning buffers

	var buffers []*[]byte

	// grab several buffers
	for i := 0; i < 10; i++ {
		bufPtr := bufferPool.Get().(*[]byte)
		buffers = append(buffers, bufPtr)

		// check the buffer size
		if len(*bufPtr) != 4096 {
			t.Errorf("Expected buffer size 4096, got %d", len(*bufPtr))
		}
	}

	// return all buffers
	for _, bufPtr := range buffers {
		bufferPool.Put(bufPtr)
	}

	// get buffers again; the previous ones should be reused
	for i := 0; i < 5; i++ {
		bufPtr := bufferPool.Get().(*[]byte)
		if len(*bufPtr) != 4096 {
			t.Errorf("Expected reused buffer size 4096, got %d", len(*bufPtr))
		}
		bufferPool.Put(bufPtr)
	}
}

// BenchmarkTeeStream benchmarks Stream
func BenchmarkTeeStream(b *testing.B) {
	ctx := context.Background()
	logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError}))

	testData := strings.Repeat("benchmark test data ", 1000)

	handle := func(ctx context.Context, data []byte) error {
		return nil
	}

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		reader := strings.NewReader(testData)
		writer := io.Discard

		tee := NewTee(ctx, logger, reader, writer, handle)
		err := tee.Stream()
		if err != nil {
			b.Fatalf("Stream() failed: %v", err)
		}
		tee.Close()
	}
}

// BenchmarkBufferPool benchmarks the buffer pool
func BenchmarkBufferPool(b *testing.B) {
	b.Run("WithPool", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			bufPtr := bufferPool.Get().(*[]byte)
			// simulate using the buffer
			_ = *bufPtr
			bufferPool.Put(bufPtr)
		}
	})

	b.Run("WithoutPool", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			buf := make([]byte, 4096)
			// simulate using the buffer
			_ = buf
		}
	})
}

// TestTeeClose exercises the close behavior
func TestTeeClose(t *testing.T) {
	ctx := context.Background()
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))

	reader := strings.NewReader("test data")
	writer := newMockWriter()

	handle := func(ctx context.Context, data []byte) error {
		return nil
	}

	tee := NewTee(ctx, logger, reader, writer, handle)

	// calling Close multiple times must not panic
	tee.Close()
	tee.Close()
	tee.Close()
}
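
What replaces the deleted Tee machinery is the ReverseProxy flow named in the PR title: the proxy forwards traffic unchanged while the analysis logic consumes shadow copies off the request path. A compact sketch of that shape with net/http/httputil; the upstream URL, the recordingBody type, and the logging are assumptions, not the project's actual proxy code:

package main

import (
	"io"
	"log"
	"net/http"
	"net/http/httputil"
	"net/url"
)

// recordingBody wraps a response body and mirrors chunks into a channel,
// playing the role this PR gives to Recorder.
type recordingBody struct {
	io.ReadCloser
	copies chan []byte
}

func (b *recordingBody) Read(p []byte) (int, error) {
	n, err := b.ReadCloser.Read(p)
	if n > 0 {
		c := make([]byte, n)
		copy(c, p[:n])
		b.copies <- c
	}
	return n, err
}

func (b *recordingBody) Close() error {
	close(b.copies) // ends the analysis loop below
	return b.ReadCloser.Close()
}

func main() {
	target, _ := url.Parse("http://127.0.0.1:9000") // assumed upstream
	proxy := httputil.NewSingleHostReverseProxy(target)
	proxy.ModifyResponse = func(resp *http.Response) error {
		copies := make(chan []byte, 64)
		go func() { // analysis runs off the hot path
			total := 0
			for c := range copies {
				total += len(c)
			}
			log.Printf("recorded %d response bytes", total)
		}()
		// ReverseProxy streams resp.Body to the client and closes it,
		// which drives Read/Close on the wrapper.
		resp.Body = &recordingBody{ReadCloser: resp.Body, copies: copies}
		return nil
	}
	log.Fatal(http.ListenAndServe(":8080", proxy))
}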