mirror of
https://github.com/chaitin/MonkeyCode.git
synced 2026-02-03 23:33:37 +08:00
Merge pull request #70 from yokowu/feat-optimize-proxy
feat(proxy): 优化流式代理, 隔离代理与记录逻辑
This commit is contained in:
@@ -49,10 +49,11 @@ type Config struct {
|
||||
} `mapstructure:"redis"`
|
||||
|
||||
LLMProxy struct {
|
||||
Timeout string `mapstructure:"timeout"`
|
||||
KeepAlive string `mapstructure:"keep_alive"`
|
||||
ClientPoolSize int `mapstructure:"client_pool_size"`
|
||||
RequestLogPath string `mapstructure:"request_log_path"`
|
||||
Timeout string `mapstructure:"timeout"`
|
||||
KeepAlive string `mapstructure:"keep_alive"`
|
||||
ClientPoolSize int `mapstructure:"client_pool_size"`
|
||||
StreamClientPoolSize int `mapstructure:"stream_client_pool_size"`
|
||||
RequestLogPath string `mapstructure:"request_log_path"`
|
||||
} `mapstructure:"llm_proxy"`
|
||||
|
||||
InitModel struct {
|
||||
@@ -92,6 +93,7 @@ func Init() (*Config, error) {
|
||||
v.SetDefault("llm_proxy.timeout", "30s")
|
||||
v.SetDefault("llm_proxy.keep_alive", "60s")
|
||||
v.SetDefault("llm_proxy.client_pool_size", 100)
|
||||
v.SetDefault("llm_proxy.stream_client_pool_size", 5000)
|
||||
v.SetDefault("llm_proxy.request_log_path", "/app/request/logs")
|
||||
v.SetDefault("init_model.name", "qwen2.5-coder-3b-instruct")
|
||||
v.SetDefault("init_model.key", "")
|
||||
|
||||
@@ -15,7 +15,7 @@ func RequestID() echo.MiddlewareFunc {
|
||||
return func(c echo.Context) error {
|
||||
ctx := c.Request().Context()
|
||||
requestID := uuid.New().String()
|
||||
ctx = context.WithValue(ctx, logger.RequestIDKey, requestID)
|
||||
ctx = context.WithValue(ctx, logger.RequestIDKey{}, requestID)
|
||||
c.SetRequest(c.Request().WithContext(ctx))
|
||||
return next(c)
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ func (p *ProxyMiddleware) Auth() echo.MiddlewareFunc {
|
||||
}
|
||||
|
||||
ctx := c.Request().Context()
|
||||
ctx = context.WithValue(ctx, logger.UserIDKey, key.UserID)
|
||||
ctx = context.WithValue(ctx, logger.UserIDKey{}, key.UserID)
|
||||
c.SetRequest(c.Request().WithContext(ctx))
|
||||
c.Set(ApiContextKey, key)
|
||||
return next(c)
|
||||
|
||||
@@ -59,6 +59,7 @@ type LLMProxy struct {
|
||||
usecase domain.ProxyUsecase
|
||||
cfg *config.Config
|
||||
client *http.Client
|
||||
streamClient *http.Client
|
||||
logger *slog.Logger
|
||||
requestLogPath string // 请求日志保存路径
|
||||
}
|
||||
@@ -83,7 +84,6 @@ func NewLLMProxy(
|
||||
logger.Warn("解析保持连接时间失败, 使用默认值 60s", "error", err)
|
||||
}
|
||||
|
||||
// 创建HTTP客户端
|
||||
client := &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: &http.Transport{
|
||||
@@ -98,6 +98,18 @@ func NewLLMProxy(
|
||||
},
|
||||
}
|
||||
|
||||
streamClient := &http.Client{
|
||||
Timeout: 60 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: cfg.LLMProxy.StreamClientPoolSize,
|
||||
MaxConnsPerHost: cfg.LLMProxy.StreamClientPoolSize,
|
||||
MaxIdleConnsPerHost: cfg.LLMProxy.StreamClientPoolSize,
|
||||
IdleConnTimeout: 24 * time.Hour,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
// 获取日志配置
|
||||
requestLogPath := ""
|
||||
if cfg.LLMProxy.RequestLogPath != "" {
|
||||
@@ -111,6 +123,7 @@ func NewLLMProxy(
|
||||
return &LLMProxy{
|
||||
usecase: usecase,
|
||||
client: client,
|
||||
streamClient: streamClient,
|
||||
cfg: cfg,
|
||||
requestLogPath: requestLogPath,
|
||||
logger: logger,
|
||||
@@ -174,12 +187,12 @@ type Ctx struct {
|
||||
func (p *LLMProxy) handle(ctx context.Context, fn func(ctx *Ctx, log *RequestResponseLog) error) {
|
||||
// 获取用户ID
|
||||
userID := "unknown"
|
||||
if id, ok := ctx.Value(logger.UserIDKey).(string); ok {
|
||||
if id, ok := ctx.Value(logger.UserIDKey{}).(string); ok {
|
||||
userID = id
|
||||
}
|
||||
|
||||
requestID := "unknown"
|
||||
if id, ok := ctx.Value(logger.RequestIDKey).(string); ok {
|
||||
if id, ok := ctx.Value(logger.RequestIDKey{}).(string); ok {
|
||||
requestID = id
|
||||
}
|
||||
|
||||
@@ -203,11 +216,11 @@ func (p *LLMProxy) handle(ctx context.Context, fn func(ctx *Ctx, log *RequestRes
|
||||
}
|
||||
|
||||
if err := fn(c, l); err != nil {
|
||||
p.logger.With("userID", userID, "requestID", requestID, "sourceip", sourceip).ErrorContext(ctx, "处理请求失败", "error", err)
|
||||
p.logger.With("source_ip", sourceip).ErrorContext(ctx, "处理请求失败", "error", err)
|
||||
l.Error = err.Error()
|
||||
}
|
||||
|
||||
p.saveRequestResponseLog(l)
|
||||
go p.saveRequestResponseLog(l)
|
||||
}
|
||||
|
||||
func (p *LLMProxy) HandleCompletion(ctx context.Context, w http.ResponseWriter, req domain.CompletionRequest) {
|
||||
@@ -585,10 +598,6 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon
|
||||
return err
|
||||
}
|
||||
|
||||
prompt := p.getPrompt(ctx, req)
|
||||
mode := req.Metadata["mode"]
|
||||
taskID := req.Metadata["task_id"]
|
||||
|
||||
upstream := m.APIBase + endpoint
|
||||
log.UpstreamURL = upstream
|
||||
|
||||
@@ -606,9 +615,7 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon
|
||||
|
||||
newReq.Header.Set("Content-Type", "application/json")
|
||||
newReq.Header.Set("Accept", "text/event-stream")
|
||||
if m.APIKey != "" && m.APIKey != "none" {
|
||||
newReq.Header.Set("Authorization", "Bearer "+m.APIKey)
|
||||
}
|
||||
newReq.Header.Set("Authorization", "Bearer "+m.APIKey)
|
||||
|
||||
// 保存请求头(去除敏感信息)
|
||||
requestHeaders := make(map[string][]string)
|
||||
@@ -622,22 +629,26 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon
|
||||
}
|
||||
log.RequestHeader = requestHeaders
|
||||
|
||||
p.logger.With(
|
||||
logger := p.logger.With(
|
||||
"request_id", c.RequestID,
|
||||
"source_ip", c.SourceIP,
|
||||
"upstreamURL", upstream,
|
||||
"modelName", m.ModelName,
|
||||
"modelType", consts.ModelTypeLLM,
|
||||
"apiBase", m.APIBase,
|
||||
"work_mode", mode,
|
||||
)
|
||||
|
||||
logger.With(
|
||||
"upstreamURL", upstream,
|
||||
"requestHeader", newReq.Header,
|
||||
"requestBody", req,
|
||||
"taskID", taskID,
|
||||
"messages", cvt.Filter(req.Messages, func(i int, v openai.ChatCompletionMessage) (openai.ChatCompletionMessage, bool) {
|
||||
return v, v.Role != "system"
|
||||
}),
|
||||
).DebugContext(ctx, "转发流式请求到上游API")
|
||||
|
||||
// 发送请求
|
||||
resp, err := p.client.Do(newReq)
|
||||
resp, err := p.streamClient.Do(newReq)
|
||||
if err != nil {
|
||||
p.logger.With("upstreamURL", upstream).WarnContext(ctx, "发送上游流式请求失败", "error", err)
|
||||
return fmt.Errorf("发送上游请求失败: %w", err)
|
||||
@@ -655,7 +666,7 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon
|
||||
log.Latency = time.Since(startTime).Milliseconds()
|
||||
|
||||
// 在debug级别记录错误的流式响应内容
|
||||
p.logger.With(
|
||||
logger.With(
|
||||
"statusCode", resp.StatusCode,
|
||||
"responseHeader", resp.Header,
|
||||
"responseBody", string(responseBody),
|
||||
@@ -663,9 +674,8 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon
|
||||
|
||||
var errorResp ErrResp
|
||||
if err := json.Unmarshal(responseBody, &errorResp); err == nil {
|
||||
p.logger.With(
|
||||
logger.With(
|
||||
"endpoint", endpoint,
|
||||
"upstreamURL", upstream,
|
||||
"requestBody", newReq,
|
||||
"statusCode", resp.StatusCode,
|
||||
"errorType", errorResp.Error.Type,
|
||||
@@ -677,9 +687,8 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon
|
||||
return fmt.Errorf("上游API返回错误: %s", errorResp.Error.Message)
|
||||
}
|
||||
|
||||
p.logger.With(
|
||||
logger.With(
|
||||
"endpoint", endpoint,
|
||||
"upstreamURL", upstream,
|
||||
"requestBody", newReq,
|
||||
"statusCode", resp.StatusCode,
|
||||
"responseBody", string(responseBody),
|
||||
@@ -688,12 +697,10 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon
|
||||
return fmt.Errorf("上游API返回非200状态码: %d, 响应: %s", resp.StatusCode, string(responseBody))
|
||||
}
|
||||
|
||||
// 更新日志信息
|
||||
log.StatusCode = resp.StatusCode
|
||||
log.ResponseHeader = resp.Header
|
||||
|
||||
// 在debug级别记录流式响应头信息
|
||||
p.logger.With(
|
||||
logger.With(
|
||||
"statusCode", resp.StatusCode,
|
||||
"responseHeader", resp.Header,
|
||||
).DebugContext(ctx, "上游流式响应头信息")
|
||||
@@ -705,78 +712,18 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
|
||||
rc := &domain.RecordParam{
|
||||
RequestID: c.RequestID,
|
||||
TaskID: taskID,
|
||||
UserID: c.UserID,
|
||||
ModelID: m.ID,
|
||||
ModelType: consts.ModelTypeLLM,
|
||||
WorkMode: mode,
|
||||
Prompt: prompt,
|
||||
Role: consts.ChatRoleAssistant,
|
||||
}
|
||||
|
||||
ch := make(chan []byte, 1024)
|
||||
defer close(ch)
|
||||
|
||||
go func(rc *domain.RecordParam) {
|
||||
if rc.Prompt != "" {
|
||||
urc := rc.Clone()
|
||||
urc.Role = consts.ChatRoleUser
|
||||
urc.Completion = urc.Prompt
|
||||
if err := p.usecase.Record(context.Background(), urc); err != nil {
|
||||
p.logger.With("modelID", m.ID, "modelName", m.ModelName, "modelType", consts.ModelTypeLLM).
|
||||
WarnContext(ctx, "插入流式记录失败", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
for line := range ch {
|
||||
if bytes.HasPrefix(line, []byte("data:")) {
|
||||
line = bytes.TrimPrefix(line, []byte("data: "))
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.Equal(line, []byte("[DONE]")) {
|
||||
break
|
||||
}
|
||||
|
||||
var t openai.ChatCompletionStreamResponse
|
||||
if err := json.Unmarshal(line, &t); err != nil {
|
||||
p.logger.With("line", string(line)).WarnContext(ctx, "解析流式数据失败", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
p.logger.With("response", t).DebugContext(ctx, "流式响应数据")
|
||||
if len(t.Choices) > 0 {
|
||||
rc.Completion += t.Choices[0].Delta.Content
|
||||
}
|
||||
if t.Usage != nil {
|
||||
rc.InputTokens = int64(t.Usage.PromptTokens)
|
||||
rc.OutputTokens = int64(t.Usage.CompletionTokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p.logger.With("record", rc).DebugContext(ctx, "流式记录")
|
||||
if err := p.usecase.Record(context.Background(), rc); err != nil {
|
||||
p.logger.With("modelID", m.ID, "modelName", m.ModelName, "modelType", consts.ModelTypeLLM).
|
||||
WarnContext(ctx, "插入流式记录失败", "error", err)
|
||||
}
|
||||
}(rc)
|
||||
|
||||
err = streamRead(ctx, resp.Body, func(line []byte) error {
|
||||
ch <- line
|
||||
if _, err := w.Write(line); err != nil {
|
||||
return fmt.Errorf("写入响应失败: %w", err)
|
||||
}
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
recorder := NewChatRecorder(
|
||||
ctx,
|
||||
c,
|
||||
p.usecase,
|
||||
m,
|
||||
req,
|
||||
resp.Body,
|
||||
w,
|
||||
p.logger.With("module", "ChatRecorder"),
|
||||
)
|
||||
defer recorder.Close()
|
||||
return recorder.Stream()
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
223
backend/internal/proxy/recorder.go
Normal file
223
backend/internal/proxy/recorder.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/rokku-c/go-openai"
|
||||
|
||||
"github.com/chaitin/MonkeyCode/backend/consts"
|
||||
"github.com/chaitin/MonkeyCode/backend/domain"
|
||||
"github.com/chaitin/MonkeyCode/backend/pkg/promptparser"
|
||||
"github.com/chaitin/MonkeyCode/backend/pkg/tee"
|
||||
)
|
||||
|
||||
type ChatRecorder struct {
|
||||
*tee.Tee
|
||||
ctx context.Context
|
||||
cx *Ctx
|
||||
usecase domain.ProxyUsecase
|
||||
req *openai.ChatCompletionRequest
|
||||
logger *slog.Logger
|
||||
model *domain.Model
|
||||
completion strings.Builder // 累积完整的响应内容
|
||||
usage *openai.Usage // 最终的使用统计
|
||||
buffer strings.Builder // 缓存不完整的行
|
||||
recorded bool // 标记是否已记录完整对话
|
||||
}
|
||||
|
||||
func NewChatRecorder(
|
||||
ctx context.Context,
|
||||
cx *Ctx,
|
||||
usecase domain.ProxyUsecase,
|
||||
model *domain.Model,
|
||||
req *openai.ChatCompletionRequest,
|
||||
r io.Reader,
|
||||
w io.Writer,
|
||||
logger *slog.Logger,
|
||||
) *ChatRecorder {
|
||||
c := &ChatRecorder{
|
||||
ctx: ctx,
|
||||
cx: cx,
|
||||
usecase: usecase,
|
||||
model: model,
|
||||
req: req,
|
||||
logger: logger,
|
||||
}
|
||||
c.Tee = tee.NewTee(ctx, logger, r, w, c.handle)
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *ChatRecorder) handle(ctx context.Context, data []byte) error {
|
||||
c.buffer.Write(data)
|
||||
bufferContent := c.buffer.String()
|
||||
|
||||
lines := strings.Split(bufferContent, "\n")
|
||||
if len(lines) > 0 {
|
||||
lastLine := lines[len(lines)-1]
|
||||
lines = lines[:len(lines)-1]
|
||||
c.buffer.Reset()
|
||||
c.buffer.WriteString(lastLine)
|
||||
}
|
||||
|
||||
for _, line := range lines {
|
||||
if err := c.processSSELine(ctx, line); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ChatRecorder) processSSELine(ctx context.Context, line string) error {
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
if line == "" || !strings.HasPrefix(line, "data:") {
|
||||
return nil
|
||||
}
|
||||
|
||||
dataContent := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
if dataContent == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if dataContent == "[DONE]" {
|
||||
c.processCompletedChat(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
var resp openai.ChatCompletionStreamResponse
|
||||
if err := json.Unmarshal([]byte(dataContent), &resp); err != nil {
|
||||
c.logger.With("data", dataContent).WarnContext(ctx, "解析流式响应失败", "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
prompt := c.getPrompt(ctx, c.req)
|
||||
mode := c.req.Metadata["mode"]
|
||||
taskID := c.req.Metadata["task_id"]
|
||||
|
||||
rc := &domain.RecordParam{
|
||||
RequestID: c.cx.RequestID,
|
||||
TaskID: taskID,
|
||||
UserID: c.cx.UserID,
|
||||
ModelID: c.model.ID,
|
||||
ModelType: c.model.ModelType,
|
||||
WorkMode: mode,
|
||||
Prompt: prompt,
|
||||
Role: consts.ChatRoleAssistant,
|
||||
}
|
||||
if resp.Usage != nil {
|
||||
rc.InputTokens = int64(resp.Usage.PromptTokens)
|
||||
}
|
||||
|
||||
if rc.Prompt != "" {
|
||||
urc := rc.Clone()
|
||||
urc.Role = consts.ChatRoleUser
|
||||
urc.Completion = urc.Prompt
|
||||
if err := c.usecase.Record(context.Background(), urc); err != nil {
|
||||
c.logger.With("modelID", c.model.ID, "modelName", c.model.ModelName, "modelType", consts.ModelTypeLLM).
|
||||
WarnContext(ctx, "插入流式记录失败", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(resp.Choices) > 0 {
|
||||
content := resp.Choices[0].Delta.Content
|
||||
if content != "" {
|
||||
c.completion.WriteString(content)
|
||||
}
|
||||
}
|
||||
|
||||
if resp.Usage != nil {
|
||||
c.usage = resp.Usage
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ChatRecorder) processCompletedChat(ctx context.Context) {
|
||||
if c.recorded {
|
||||
return // 避免重复记录
|
||||
}
|
||||
|
||||
mode := c.req.Metadata["mode"]
|
||||
taskID := c.req.Metadata["task_id"]
|
||||
|
||||
rc := &domain.RecordParam{
|
||||
RequestID: c.cx.RequestID,
|
||||
TaskID: taskID,
|
||||
UserID: c.cx.UserID,
|
||||
ModelID: c.model.ID,
|
||||
ModelType: c.model.ModelType,
|
||||
WorkMode: mode,
|
||||
Role: consts.ChatRoleAssistant,
|
||||
Completion: c.completion.String(),
|
||||
InputTokens: int64(c.usage.PromptTokens),
|
||||
OutputTokens: int64(c.usage.CompletionTokens),
|
||||
}
|
||||
|
||||
if err := c.usecase.Record(context.Background(), rc); err != nil {
|
||||
c.logger.With("modelID", c.model.ID, "modelName", c.model.ModelName, "modelType", consts.ModelTypeLLM).
|
||||
WarnContext(ctx, "插入流式记录失败", "error", err)
|
||||
} else {
|
||||
c.recorded = true
|
||||
c.logger.With("requestID", c.cx.RequestID, "completion_length", len(c.completion.String())).
|
||||
InfoContext(ctx, "流式对话记录已保存")
|
||||
}
|
||||
}
|
||||
|
||||
// Close 关闭 recorder 并确保数据被保存
|
||||
func (c *ChatRecorder) Close() {
|
||||
// 如果有累积的内容但还没有记录,强制保存
|
||||
if !c.recorded && c.completion.Len() > 0 {
|
||||
c.logger.With("requestID", c.cx.RequestID).
|
||||
WarnContext(c.ctx, "数据流异常中断,强制保存已累积的内容")
|
||||
c.processCompletedChat(c.ctx)
|
||||
}
|
||||
|
||||
if c.Tee != nil {
|
||||
c.Tee.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ChatRecorder) Reset() {
|
||||
c.completion.Reset()
|
||||
c.buffer.Reset()
|
||||
c.usage = nil
|
||||
c.recorded = false
|
||||
}
|
||||
|
||||
func (c *ChatRecorder) getPrompt(ctx context.Context, req *openai.ChatCompletionRequest) string {
|
||||
prompt := ""
|
||||
parse := promptparser.New(promptparser.KindTask)
|
||||
for _, message := range req.Messages {
|
||||
if message.Role == "system" {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.Contains(message.Content, "<task>") ||
|
||||
strings.Contains(message.Content, "<feedback>") ||
|
||||
strings.Contains(message.Content, "<user_message>") {
|
||||
if info, err := parse.Parse(message.Content); err == nil {
|
||||
prompt = info.Prompt
|
||||
} else {
|
||||
c.logger.With("message", message.Content).WarnContext(ctx, "解析Prompt失败", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, m := range message.MultiContent {
|
||||
if strings.Contains(m.Text, "<task>") ||
|
||||
strings.Contains(m.Text, "<feedback>") ||
|
||||
strings.Contains(m.Text, "<user_message>") {
|
||||
if info, err := parse.Parse(m.Text); err == nil {
|
||||
prompt = info.Prompt
|
||||
} else {
|
||||
c.logger.With("message", m.Text).WarnContext(ctx, "解析Prompt失败", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return prompt
|
||||
}
|
||||
@@ -5,12 +5,8 @@ import (
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
RequestIDKey contextKey = "request_id"
|
||||
UserIDKey contextKey = "user_id"
|
||||
)
|
||||
type RequestIDKey struct{}
|
||||
type UserIDKey struct{}
|
||||
|
||||
type ContextLogger struct {
|
||||
slog.Handler
|
||||
@@ -31,11 +27,11 @@ func (c *ContextLogger) WithGroup(name string) slog.Handler {
|
||||
func (c *ContextLogger) Handle(ctx context.Context, r slog.Record) error {
|
||||
newRecord := r.Clone()
|
||||
|
||||
if i, ok := ctx.Value(RequestIDKey).(string); ok {
|
||||
if i, ok := ctx.Value(RequestIDKey{}).(string); ok {
|
||||
newRecord.AddAttrs(slog.String("request_id", i))
|
||||
}
|
||||
|
||||
if i, ok := ctx.Value(UserIDKey).(string); ok {
|
||||
if i, ok := ctx.Value(UserIDKey{}).(string); ok {
|
||||
newRecord.AddAttrs(slog.String("user_id", i))
|
||||
}
|
||||
|
||||
|
||||
103
backend/pkg/tee/tee.go
Normal file
103
backend/pkg/tee/tee.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package tee
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type TeeHandleFunc func(ctx context.Context, data []byte) error
|
||||
|
||||
var bufferPool = sync.Pool{
|
||||
New: func() any {
|
||||
buf := make([]byte, 4096)
|
||||
return &buf
|
||||
},
|
||||
}
|
||||
|
||||
type Tee struct {
|
||||
ctx context.Context
|
||||
logger *slog.Logger
|
||||
Reader io.Reader
|
||||
Writer io.Writer
|
||||
ch chan []byte
|
||||
handle TeeHandleFunc
|
||||
}
|
||||
|
||||
func NewTee(
|
||||
ctx context.Context,
|
||||
logger *slog.Logger,
|
||||
reader io.Reader,
|
||||
writer io.Writer,
|
||||
handle TeeHandleFunc,
|
||||
) *Tee {
|
||||
t := &Tee{
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
Reader: reader,
|
||||
Writer: writer,
|
||||
handle: handle,
|
||||
ch: make(chan []byte, 32*1024),
|
||||
}
|
||||
go t.Handle()
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *Tee) Close() {
|
||||
select {
|
||||
case <-t.ch:
|
||||
// channel 已经关闭
|
||||
default:
|
||||
close(t.ch)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tee) Handle() {
|
||||
for {
|
||||
select {
|
||||
case data, ok := <-t.ch:
|
||||
if !ok {
|
||||
t.logger.DebugContext(t.ctx, "Tee Handle closed")
|
||||
return
|
||||
}
|
||||
err := t.handle(t.ctx, data)
|
||||
if err != nil {
|
||||
t.logger.With("data", string(data)).With("error", err).ErrorContext(t.ctx, "Tee Handle error")
|
||||
return
|
||||
}
|
||||
case <-t.ctx.Done():
|
||||
t.logger.DebugContext(t.ctx, "Tee Handle ctx done")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tee) Stream() error {
|
||||
bufPtr := bufferPool.Get().(*[]byte)
|
||||
buf := *bufPtr
|
||||
defer bufferPool.Put(bufPtr)
|
||||
|
||||
for {
|
||||
n, err := t.Reader.Read(buf)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
if n > 0 {
|
||||
_, err = t.Writer.Write(buf[:n])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if flusher, ok := t.Writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
data := make([]byte, n)
|
||||
copy(data, buf[:n])
|
||||
t.ch <- data
|
||||
}
|
||||
}
|
||||
}
|
||||
376
backend/pkg/tee/tee_test.go
Normal file
376
backend/pkg/tee/tee_test.go
Normal file
@@ -0,0 +1,376 @@
|
||||
package tee
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// mockWriter 模拟 Writer 接口
|
||||
type mockWriter struct {
|
||||
buf *bytes.Buffer
|
||||
delay time.Duration
|
||||
errorOn int // 在第几次写入时返回错误
|
||||
count int
|
||||
}
|
||||
|
||||
func newMockWriter() *mockWriter {
|
||||
return &mockWriter{
|
||||
buf: &bytes.Buffer{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockWriter) Write(p []byte) (n int, err error) {
|
||||
m.count++
|
||||
if m.errorOn > 0 && m.count >= m.errorOn {
|
||||
return 0, errors.New("mock write error")
|
||||
}
|
||||
if m.delay > 0 {
|
||||
time.Sleep(m.delay)
|
||||
}
|
||||
return m.buf.Write(p)
|
||||
}
|
||||
|
||||
func (m *mockWriter) String() string {
|
||||
return m.buf.String()
|
||||
}
|
||||
|
||||
// mockReader 模拟 Reader 接口
|
||||
type mockReader struct {
|
||||
data []byte
|
||||
pos int
|
||||
chunk int // 每次读取的字节数
|
||||
errorOn int // 在第几次读取时返回错误
|
||||
count int
|
||||
}
|
||||
|
||||
func newMockReader(data string, chunk int) *mockReader {
|
||||
return &mockReader{
|
||||
data: []byte(data),
|
||||
chunk: chunk,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockReader) Read(p []byte) (n int, err error) {
|
||||
m.count++
|
||||
if m.errorOn > 0 && m.count >= m.errorOn {
|
||||
return 0, errors.New("mock read error")
|
||||
}
|
||||
|
||||
if m.pos >= len(m.data) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
readSize := m.chunk
|
||||
if readSize <= 0 || readSize > len(p) {
|
||||
readSize = len(p)
|
||||
}
|
||||
|
||||
remaining := len(m.data) - m.pos
|
||||
if readSize > remaining {
|
||||
readSize = remaining
|
||||
}
|
||||
|
||||
copy(p, m.data[m.pos:m.pos+readSize])
|
||||
m.pos += readSize
|
||||
return readSize, nil
|
||||
}
|
||||
|
||||
// TestTeeBasicFunctionality 测试基本功能
|
||||
func TestTeeBasicFunctionality(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
|
||||
testData := "Hello, World! This is a test message."
|
||||
reader := newMockReader(testData, 10) // 每次读取10字节
|
||||
writer := newMockWriter()
|
||||
|
||||
var handledData [][]byte
|
||||
var mu sync.Mutex
|
||||
|
||||
handle := func(ctx context.Context, data []byte) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
// 复制数据,因为原始数据可能被重用
|
||||
dataCopy := make([]byte, len(data))
|
||||
copy(dataCopy, data)
|
||||
handledData = append(handledData, dataCopy)
|
||||
return nil
|
||||
}
|
||||
|
||||
tee := NewTee(ctx, logger, reader, writer, handle)
|
||||
defer tee.Close()
|
||||
|
||||
err := tee.Stream()
|
||||
if err != nil {
|
||||
t.Fatalf("Stream() failed: %v", err)
|
||||
}
|
||||
|
||||
// 等待处理完成
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 验证写入的数据
|
||||
if writer.String() != testData {
|
||||
t.Errorf("Expected writer data %q, got %q", testData, writer.String())
|
||||
}
|
||||
|
||||
// 验证处理的数据
|
||||
mu.Lock()
|
||||
var totalHandled []byte
|
||||
for _, chunk := range handledData {
|
||||
totalHandled = append(totalHandled, chunk...)
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
if string(totalHandled) != testData {
|
||||
t.Errorf("Expected handled data %q, got %q", testData, string(totalHandled))
|
||||
}
|
||||
}
|
||||
|
||||
// TestTeeWithErrors 测试错误处理
|
||||
func TestTeeWithErrors(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
|
||||
t.Run("ReaderError", func(t *testing.T) {
|
||||
reader := newMockReader("test data", 5)
|
||||
reader.errorOn = 2 // 第二次读取时出错
|
||||
writer := newMockWriter()
|
||||
|
||||
handle := func(ctx context.Context, data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
tee := NewTee(ctx, logger, reader, writer, handle)
|
||||
defer tee.Close()
|
||||
|
||||
err := tee.Stream()
|
||||
if err == nil {
|
||||
t.Error("Expected error from reader, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("WriterError", func(t *testing.T) {
|
||||
reader := newMockReader("test data", 5)
|
||||
writer := newMockWriter()
|
||||
writer.errorOn = 2 // 第二次写入时出错
|
||||
|
||||
handle := func(ctx context.Context, data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
tee := NewTee(ctx, logger, reader, writer, handle)
|
||||
defer tee.Close()
|
||||
|
||||
err := tee.Stream()
|
||||
if err == nil {
|
||||
t.Error("Expected error from writer, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HandleError", func(t *testing.T) {
|
||||
reader := newMockReader("test data", 5)
|
||||
writer := newMockWriter()
|
||||
|
||||
handle := func(ctx context.Context, data []byte) error {
|
||||
return errors.New("handle error")
|
||||
}
|
||||
|
||||
tee := NewTee(ctx, logger, reader, writer, handle)
|
||||
defer tee.Close()
|
||||
|
||||
// 启动 Stream 在单独的 goroutine 中
|
||||
go func() {
|
||||
tee.Stream()
|
||||
}()
|
||||
|
||||
// 等待一段时间让处理器有机会处理数据并出错
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
})
|
||||
}
|
||||
|
||||
// TestTeeContextCancellation 测试上下文取消
|
||||
func TestTeeContextCancellation(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
|
||||
// 创建一个会持续产生数据的 reader
|
||||
reader := strings.NewReader(strings.Repeat("test data ", 1000))
|
||||
writer := newMockWriter()
|
||||
|
||||
handle := func(ctx context.Context, data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
tee := NewTee(ctx, logger, reader, writer, handle)
|
||||
defer tee.Close()
|
||||
|
||||
// 在单独的 goroutine 中启动 Stream
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- tee.Stream()
|
||||
}()
|
||||
|
||||
// 等待一段时间后取消上下文
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
// 等待 Stream 完成
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil && err != io.EOF {
|
||||
t.Logf("Stream completed with error: %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("Stream did not complete within timeout")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTeeConcurrentSafety 测试并发安全性
|
||||
func TestTeeConcurrentSafety(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
|
||||
testData := strings.Repeat("concurrent test data ", 100)
|
||||
reader := strings.NewReader(testData)
|
||||
writer := newMockWriter()
|
||||
|
||||
var processedCount int64
|
||||
var mu sync.Mutex
|
||||
|
||||
handle := func(ctx context.Context, data []byte) error {
|
||||
mu.Lock()
|
||||
processedCount++
|
||||
mu.Unlock()
|
||||
// 模拟一些处理时间
|
||||
time.Sleep(time.Microsecond)
|
||||
return nil
|
||||
}
|
||||
|
||||
tee := NewTee(ctx, logger, reader, writer, handle)
|
||||
defer tee.Close()
|
||||
|
||||
err := tee.Stream()
|
||||
if err != nil {
|
||||
t.Fatalf("Stream() failed: %v", err)
|
||||
}
|
||||
|
||||
// 等待所有数据处理完成
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
finalCount := processedCount
|
||||
mu.Unlock()
|
||||
|
||||
if finalCount == 0 {
|
||||
t.Error("No data was processed")
|
||||
}
|
||||
|
||||
t.Logf("Processed %d chunks of data", finalCount)
|
||||
}
|
||||
|
||||
// TestBufferPoolEfficiency 测试缓冲区池的效率
|
||||
func TestBufferPoolEfficiency(t *testing.T) {
|
||||
// 这个测试验证缓冲区池是否正常工作
|
||||
// 通过多次获取和归还缓冲区来测试
|
||||
|
||||
var buffers []*[]byte
|
||||
|
||||
// 获取多个缓冲区
|
||||
for i := 0; i < 10; i++ {
|
||||
bufPtr := bufferPool.Get().(*[]byte)
|
||||
buffers = append(buffers, bufPtr)
|
||||
|
||||
// 验证缓冲区大小
|
||||
if len(*bufPtr) != 4096 {
|
||||
t.Errorf("Expected buffer size 4096, got %d", len(*bufPtr))
|
||||
}
|
||||
}
|
||||
|
||||
// 归还所有缓冲区
|
||||
for _, bufPtr := range buffers {
|
||||
bufferPool.Put(bufPtr)
|
||||
}
|
||||
|
||||
// 再次获取缓冲区,应该重用之前的缓冲区
|
||||
for i := 0; i < 5; i++ {
|
||||
bufPtr := bufferPool.Get().(*[]byte)
|
||||
if len(*bufPtr) != 4096 {
|
||||
t.Errorf("Expected reused buffer size 4096, got %d", len(*bufPtr))
|
||||
}
|
||||
bufferPool.Put(bufPtr)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkTeeStream 基准测试
|
||||
func BenchmarkTeeStream(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
|
||||
testData := strings.Repeat("benchmark test data ", 1000)
|
||||
|
||||
handle := func(ctx context.Context, data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
reader := strings.NewReader(testData)
|
||||
writer := io.Discard
|
||||
|
||||
tee := NewTee(ctx, logger, reader, writer, handle)
|
||||
err := tee.Stream()
|
||||
if err != nil {
|
||||
b.Fatalf("Stream() failed: %v", err)
|
||||
}
|
||||
tee.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkBufferPool 缓冲区池基准测试
|
||||
func BenchmarkBufferPool(b *testing.B) {
|
||||
b.Run("WithPool", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
bufPtr := bufferPool.Get().(*[]byte)
|
||||
// 模拟使用缓冲区
|
||||
_ = *bufPtr
|
||||
bufferPool.Put(bufPtr)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("WithoutPool", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
buf := make([]byte, 4096)
|
||||
// 模拟使用缓冲区
|
||||
_ = buf
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestTeeClose 测试关闭功能
|
||||
func TestTeeClose(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
|
||||
reader := strings.NewReader("test data")
|
||||
writer := newMockWriter()
|
||||
|
||||
handle := func(ctx context.Context, data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
tee := NewTee(ctx, logger, reader, writer, handle)
|
||||
|
||||
// 测试多次关闭不会 panic
|
||||
tee.Close()
|
||||
tee.Close()
|
||||
tee.Close()
|
||||
}
|
||||
Reference in New Issue
Block a user