diff --git a/backend/config/config.go b/backend/config/config.go index 49f4e5f..5f721ac 100644 --- a/backend/config/config.go +++ b/backend/config/config.go @@ -49,10 +49,11 @@ type Config struct { } `mapstructure:"redis"` LLMProxy struct { - Timeout string `mapstructure:"timeout"` - KeepAlive string `mapstructure:"keep_alive"` - ClientPoolSize int `mapstructure:"client_pool_size"` - RequestLogPath string `mapstructure:"request_log_path"` + Timeout string `mapstructure:"timeout"` + KeepAlive string `mapstructure:"keep_alive"` + ClientPoolSize int `mapstructure:"client_pool_size"` + StreamClientPoolSize int `mapstructure:"stream_client_pool_size"` + RequestLogPath string `mapstructure:"request_log_path"` } `mapstructure:"llm_proxy"` InitModel struct { @@ -92,6 +93,7 @@ func Init() (*Config, error) { v.SetDefault("llm_proxy.timeout", "30s") v.SetDefault("llm_proxy.keep_alive", "60s") v.SetDefault("llm_proxy.client_pool_size", 100) + v.SetDefault("llm_proxy.stream_client_pool_size", 5000) v.SetDefault("llm_proxy.request_log_path", "/app/request/logs") v.SetDefault("init_model.name", "qwen2.5-coder-3b-instruct") v.SetDefault("init_model.key", "") diff --git a/backend/internal/middleware/logger.go b/backend/internal/middleware/logger.go index 494ccff..c2c03bf 100644 --- a/backend/internal/middleware/logger.go +++ b/backend/internal/middleware/logger.go @@ -15,7 +15,7 @@ func RequestID() echo.MiddlewareFunc { return func(c echo.Context) error { ctx := c.Request().Context() requestID := uuid.New().String() - ctx = context.WithValue(ctx, logger.RequestIDKey, requestID) + ctx = context.WithValue(ctx, logger.RequestIDKey{}, requestID) c.SetRequest(c.Request().WithContext(ctx)) return next(c) } diff --git a/backend/internal/middleware/proxy.go b/backend/internal/middleware/proxy.go index ac1269b..f8a42a0 100644 --- a/backend/internal/middleware/proxy.go +++ b/backend/internal/middleware/proxy.go @@ -44,7 +44,7 @@ func (p *ProxyMiddleware) Auth() echo.MiddlewareFunc { } ctx := c.Request().Context() - ctx = context.WithValue(ctx, logger.UserIDKey, key.UserID) + ctx = context.WithValue(ctx, logger.UserIDKey{}, key.UserID) c.SetRequest(c.Request().WithContext(ctx)) c.Set(ApiContextKey, key) return next(c) diff --git a/backend/internal/proxy/proxy.go b/backend/internal/proxy/proxy.go index b3f6c35..c0a24e5 100644 --- a/backend/internal/proxy/proxy.go +++ b/backend/internal/proxy/proxy.go @@ -59,6 +59,7 @@ type LLMProxy struct { usecase domain.ProxyUsecase cfg *config.Config client *http.Client + streamClient *http.Client logger *slog.Logger requestLogPath string // 请求日志保存路径 } @@ -83,7 +84,6 @@ func NewLLMProxy( logger.Warn("解析保持连接时间失败, 使用默认值 60s", "error", err) } - // 创建HTTP客户端 client := &http.Client{ Timeout: timeout, Transport: &http.Transport{ @@ -98,6 +98,18 @@ func NewLLMProxy( }, } + streamClient := &http.Client{ + Timeout: 60 * time.Second, + Transport: &http.Transport{ + MaxIdleConns: cfg.LLMProxy.StreamClientPoolSize, + MaxConnsPerHost: cfg.LLMProxy.StreamClientPoolSize, + MaxIdleConnsPerHost: cfg.LLMProxy.StreamClientPoolSize, + IdleConnTimeout: 24 * time.Hour, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + }, + } + // 获取日志配置 requestLogPath := "" if cfg.LLMProxy.RequestLogPath != "" { @@ -111,6 +123,7 @@ func NewLLMProxy( return &LLMProxy{ usecase: usecase, client: client, + streamClient: streamClient, cfg: cfg, requestLogPath: requestLogPath, logger: logger, @@ -174,12 +187,12 @@ type Ctx struct { func (p *LLMProxy) handle(ctx context.Context, fn func(ctx *Ctx, log *RequestResponseLog) error) { // 获取用户ID userID := "unknown" - if id, ok := ctx.Value(logger.UserIDKey).(string); ok { + if id, ok := ctx.Value(logger.UserIDKey{}).(string); ok { userID = id } requestID := "unknown" - if id, ok := ctx.Value(logger.RequestIDKey).(string); ok { + if id, ok := ctx.Value(logger.RequestIDKey{}).(string); ok { requestID = id } @@ -203,11 +216,11 @@ func (p *LLMProxy) handle(ctx context.Context, fn func(ctx *Ctx, log *RequestRes } if err := fn(c, l); err != nil { - p.logger.With("userID", userID, "requestID", requestID, "sourceip", sourceip).ErrorContext(ctx, "处理请求失败", "error", err) + p.logger.With("source_ip", sourceip).ErrorContext(ctx, "处理请求失败", "error", err) l.Error = err.Error() } - p.saveRequestResponseLog(l) + go p.saveRequestResponseLog(l) } func (p *LLMProxy) HandleCompletion(ctx context.Context, w http.ResponseWriter, req domain.CompletionRequest) { @@ -585,10 +598,6 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon return err } - prompt := p.getPrompt(ctx, req) - mode := req.Metadata["mode"] - taskID := req.Metadata["task_id"] - upstream := m.APIBase + endpoint log.UpstreamURL = upstream @@ -606,9 +615,7 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon newReq.Header.Set("Content-Type", "application/json") newReq.Header.Set("Accept", "text/event-stream") - if m.APIKey != "" && m.APIKey != "none" { - newReq.Header.Set("Authorization", "Bearer "+m.APIKey) - } + newReq.Header.Set("Authorization", "Bearer "+m.APIKey) // 保存请求头(去除敏感信息) requestHeaders := make(map[string][]string) @@ -622,22 +629,26 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon } log.RequestHeader = requestHeaders - p.logger.With( + logger := p.logger.With( + "request_id", c.RequestID, + "source_ip", c.SourceIP, "upstreamURL", upstream, "modelName", m.ModelName, "modelType", consts.ModelTypeLLM, "apiBase", m.APIBase, - "work_mode", mode, + ) + + logger.With( + "upstreamURL", upstream, "requestHeader", newReq.Header, "requestBody", req, - "taskID", taskID, "messages", cvt.Filter(req.Messages, func(i int, v openai.ChatCompletionMessage) (openai.ChatCompletionMessage, bool) { return v, v.Role != "system" }), ).DebugContext(ctx, "转发流式请求到上游API") // 发送请求 - resp, err := p.client.Do(newReq) + resp, err := p.streamClient.Do(newReq) if err != nil { p.logger.With("upstreamURL", upstream).WarnContext(ctx, "发送上游流式请求失败", "error", err) return fmt.Errorf("发送上游请求失败: %w", err) @@ -655,7 +666,7 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon log.Latency = time.Since(startTime).Milliseconds() // 在debug级别记录错误的流式响应内容 - p.logger.With( + logger.With( "statusCode", resp.StatusCode, "responseHeader", resp.Header, "responseBody", string(responseBody), @@ -663,9 +674,8 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon var errorResp ErrResp if err := json.Unmarshal(responseBody, &errorResp); err == nil { - p.logger.With( + logger.With( "endpoint", endpoint, - "upstreamURL", upstream, "requestBody", newReq, "statusCode", resp.StatusCode, "errorType", errorResp.Error.Type, @@ -677,9 +687,8 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon return fmt.Errorf("上游API返回错误: %s", errorResp.Error.Message) } - p.logger.With( + logger.With( "endpoint", endpoint, - "upstreamURL", upstream, "requestBody", newReq, "statusCode", resp.StatusCode, "responseBody", string(responseBody), @@ -688,12 +697,10 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon return fmt.Errorf("上游API返回非200状态码: %d, 响应: %s", resp.StatusCode, string(responseBody)) } - // 更新日志信息 log.StatusCode = resp.StatusCode log.ResponseHeader = resp.Header - // 在debug级别记录流式响应头信息 - p.logger.With( + logger.With( "statusCode", resp.StatusCode, "responseHeader", resp.Header, ).DebugContext(ctx, "上游流式响应头信息") @@ -705,78 +712,18 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon w.Header().Set("Transfer-Encoding", "chunked") w.Header().Set("X-Accel-Buffering", "no") - rc := &domain.RecordParam{ - RequestID: c.RequestID, - TaskID: taskID, - UserID: c.UserID, - ModelID: m.ID, - ModelType: consts.ModelTypeLLM, - WorkMode: mode, - Prompt: prompt, - Role: consts.ChatRoleAssistant, - } - - ch := make(chan []byte, 1024) - defer close(ch) - - go func(rc *domain.RecordParam) { - if rc.Prompt != "" { - urc := rc.Clone() - urc.Role = consts.ChatRoleUser - urc.Completion = urc.Prompt - if err := p.usecase.Record(context.Background(), urc); err != nil { - p.logger.With("modelID", m.ID, "modelName", m.ModelName, "modelType", consts.ModelTypeLLM). - WarnContext(ctx, "插入流式记录失败", "error", err) - } - } - - for line := range ch { - if bytes.HasPrefix(line, []byte("data:")) { - line = bytes.TrimPrefix(line, []byte("data: ")) - line = bytes.TrimSpace(line) - if len(line) == 0 { - continue - } - - if bytes.Equal(line, []byte("[DONE]")) { - break - } - - var t openai.ChatCompletionStreamResponse - if err := json.Unmarshal(line, &t); err != nil { - p.logger.With("line", string(line)).WarnContext(ctx, "解析流式数据失败", "error", err) - continue - } - - p.logger.With("response", t).DebugContext(ctx, "流式响应数据") - if len(t.Choices) > 0 { - rc.Completion += t.Choices[0].Delta.Content - } - if t.Usage != nil { - rc.InputTokens = int64(t.Usage.PromptTokens) - rc.OutputTokens = int64(t.Usage.CompletionTokens) - } - } - } - - p.logger.With("record", rc).DebugContext(ctx, "流式记录") - if err := p.usecase.Record(context.Background(), rc); err != nil { - p.logger.With("modelID", m.ID, "modelName", m.ModelName, "modelType", consts.ModelTypeLLM). - WarnContext(ctx, "插入流式记录失败", "error", err) - } - }(rc) - - err = streamRead(ctx, resp.Body, func(line []byte) error { - ch <- line - if _, err := w.Write(line); err != nil { - return fmt.Errorf("写入响应失败: %w", err) - } - if f, ok := w.(http.Flusher); ok { - f.Flush() - } - return nil - }) - return err + recorder := NewChatRecorder( + ctx, + c, + p.usecase, + m, + req, + resp.Body, + w, + p.logger.With("module", "ChatRecorder"), + ) + defer recorder.Close() + return recorder.Stream() }) } diff --git a/backend/internal/proxy/recorder.go b/backend/internal/proxy/recorder.go new file mode 100644 index 0000000..9e94a33 --- /dev/null +++ b/backend/internal/proxy/recorder.go @@ -0,0 +1,223 @@ +package proxy + +import ( + "context" + "encoding/json" + "io" + "log/slog" + "strings" + + "github.com/rokku-c/go-openai" + + "github.com/chaitin/MonkeyCode/backend/consts" + "github.com/chaitin/MonkeyCode/backend/domain" + "github.com/chaitin/MonkeyCode/backend/pkg/promptparser" + "github.com/chaitin/MonkeyCode/backend/pkg/tee" +) + +type ChatRecorder struct { + *tee.Tee + ctx context.Context + cx *Ctx + usecase domain.ProxyUsecase + req *openai.ChatCompletionRequest + logger *slog.Logger + model *domain.Model + completion strings.Builder // 累积完整的响应内容 + usage *openai.Usage // 最终的使用统计 + buffer strings.Builder // 缓存不完整的行 + recorded bool // 标记是否已记录完整对话 +} + +func NewChatRecorder( + ctx context.Context, + cx *Ctx, + usecase domain.ProxyUsecase, + model *domain.Model, + req *openai.ChatCompletionRequest, + r io.Reader, + w io.Writer, + logger *slog.Logger, +) *ChatRecorder { + c := &ChatRecorder{ + ctx: ctx, + cx: cx, + usecase: usecase, + model: model, + req: req, + logger: logger, + } + c.Tee = tee.NewTee(ctx, logger, r, w, c.handle) + return c +} + +func (c *ChatRecorder) handle(ctx context.Context, data []byte) error { + c.buffer.Write(data) + bufferContent := c.buffer.String() + + lines := strings.Split(bufferContent, "\n") + if len(lines) > 0 { + lastLine := lines[len(lines)-1] + lines = lines[:len(lines)-1] + c.buffer.Reset() + c.buffer.WriteString(lastLine) + } + + for _, line := range lines { + if err := c.processSSELine(ctx, line); err != nil { + return err + } + } + + return nil +} + +func (c *ChatRecorder) processSSELine(ctx context.Context, line string) error { + line = strings.TrimSpace(line) + + if line == "" || !strings.HasPrefix(line, "data:") { + return nil + } + + dataContent := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if dataContent == "" { + return nil + } + + if dataContent == "[DONE]" { + c.processCompletedChat(ctx) + return nil + } + + var resp openai.ChatCompletionStreamResponse + if err := json.Unmarshal([]byte(dataContent), &resp); err != nil { + c.logger.With("data", dataContent).WarnContext(ctx, "解析流式响应失败", "error", err) + return nil + } + + prompt := c.getPrompt(ctx, c.req) + mode := c.req.Metadata["mode"] + taskID := c.req.Metadata["task_id"] + + rc := &domain.RecordParam{ + RequestID: c.cx.RequestID, + TaskID: taskID, + UserID: c.cx.UserID, + ModelID: c.model.ID, + ModelType: c.model.ModelType, + WorkMode: mode, + Prompt: prompt, + Role: consts.ChatRoleAssistant, + } + if resp.Usage != nil { + rc.InputTokens = int64(resp.Usage.PromptTokens) + } + + if rc.Prompt != "" { + urc := rc.Clone() + urc.Role = consts.ChatRoleUser + urc.Completion = urc.Prompt + if err := c.usecase.Record(context.Background(), urc); err != nil { + c.logger.With("modelID", c.model.ID, "modelName", c.model.ModelName, "modelType", consts.ModelTypeLLM). + WarnContext(ctx, "插入流式记录失败", "error", err) + } + } + + if len(resp.Choices) > 0 { + content := resp.Choices[0].Delta.Content + if content != "" { + c.completion.WriteString(content) + } + } + + if resp.Usage != nil { + c.usage = resp.Usage + } + + return nil +} + +func (c *ChatRecorder) processCompletedChat(ctx context.Context) { + if c.recorded { + return // 避免重复记录 + } + + mode := c.req.Metadata["mode"] + taskID := c.req.Metadata["task_id"] + + rc := &domain.RecordParam{ + RequestID: c.cx.RequestID, + TaskID: taskID, + UserID: c.cx.UserID, + ModelID: c.model.ID, + ModelType: c.model.ModelType, + WorkMode: mode, + Role: consts.ChatRoleAssistant, + Completion: c.completion.String(), + InputTokens: int64(c.usage.PromptTokens), + OutputTokens: int64(c.usage.CompletionTokens), + } + + if err := c.usecase.Record(context.Background(), rc); err != nil { + c.logger.With("modelID", c.model.ID, "modelName", c.model.ModelName, "modelType", consts.ModelTypeLLM). + WarnContext(ctx, "插入流式记录失败", "error", err) + } else { + c.recorded = true + c.logger.With("requestID", c.cx.RequestID, "completion_length", len(c.completion.String())). + InfoContext(ctx, "流式对话记录已保存") + } +} + +// Close 关闭 recorder 并确保数据被保存 +func (c *ChatRecorder) Close() { + // 如果有累积的内容但还没有记录,强制保存 + if !c.recorded && c.completion.Len() > 0 { + c.logger.With("requestID", c.cx.RequestID). + WarnContext(c.ctx, "数据流异常中断,强制保存已累积的内容") + c.processCompletedChat(c.ctx) + } + + if c.Tee != nil { + c.Tee.Close() + } +} + +func (c *ChatRecorder) Reset() { + c.completion.Reset() + c.buffer.Reset() + c.usage = nil + c.recorded = false +} + +func (c *ChatRecorder) getPrompt(ctx context.Context, req *openai.ChatCompletionRequest) string { + prompt := "" + parse := promptparser.New(promptparser.KindTask) + for _, message := range req.Messages { + if message.Role == "system" { + continue + } + + if strings.Contains(message.Content, "") || + strings.Contains(message.Content, "") || + strings.Contains(message.Content, "") { + if info, err := parse.Parse(message.Content); err == nil { + prompt = info.Prompt + } else { + c.logger.With("message", message.Content).WarnContext(ctx, "解析Prompt失败", "error", err) + } + } + + for _, m := range message.MultiContent { + if strings.Contains(m.Text, "") || + strings.Contains(m.Text, "") || + strings.Contains(m.Text, "") { + if info, err := parse.Parse(m.Text); err == nil { + prompt = info.Prompt + } else { + c.logger.With("message", m.Text).WarnContext(ctx, "解析Prompt失败", "error", err) + } + } + } + } + return prompt +} diff --git a/backend/pkg/logger/context.go b/backend/pkg/logger/context.go index 871047a..f2eddab 100644 --- a/backend/pkg/logger/context.go +++ b/backend/pkg/logger/context.go @@ -5,12 +5,8 @@ import ( "log/slog" ) -type contextKey string - -const ( - RequestIDKey contextKey = "request_id" - UserIDKey contextKey = "user_id" -) +type RequestIDKey struct{} +type UserIDKey struct{} type ContextLogger struct { slog.Handler @@ -31,11 +27,11 @@ func (c *ContextLogger) WithGroup(name string) slog.Handler { func (c *ContextLogger) Handle(ctx context.Context, r slog.Record) error { newRecord := r.Clone() - if i, ok := ctx.Value(RequestIDKey).(string); ok { + if i, ok := ctx.Value(RequestIDKey{}).(string); ok { newRecord.AddAttrs(slog.String("request_id", i)) } - if i, ok := ctx.Value(UserIDKey).(string); ok { + if i, ok := ctx.Value(UserIDKey{}).(string); ok { newRecord.AddAttrs(slog.String("user_id", i)) } diff --git a/backend/pkg/tee/tee.go b/backend/pkg/tee/tee.go new file mode 100644 index 0000000..577e924 --- /dev/null +++ b/backend/pkg/tee/tee.go @@ -0,0 +1,103 @@ +package tee + +import ( + "context" + "io" + "log/slog" + "net/http" + "sync" +) + +type TeeHandleFunc func(ctx context.Context, data []byte) error + +var bufferPool = sync.Pool{ + New: func() any { + buf := make([]byte, 4096) + return &buf + }, +} + +type Tee struct { + ctx context.Context + logger *slog.Logger + Reader io.Reader + Writer io.Writer + ch chan []byte + handle TeeHandleFunc +} + +func NewTee( + ctx context.Context, + logger *slog.Logger, + reader io.Reader, + writer io.Writer, + handle TeeHandleFunc, +) *Tee { + t := &Tee{ + ctx: ctx, + logger: logger, + Reader: reader, + Writer: writer, + handle: handle, + ch: make(chan []byte, 32*1024), + } + go t.Handle() + return t +} + +func (t *Tee) Close() { + select { + case <-t.ch: + // channel 已经关闭 + default: + close(t.ch) + } +} + +func (t *Tee) Handle() { + for { + select { + case data, ok := <-t.ch: + if !ok { + t.logger.DebugContext(t.ctx, "Tee Handle closed") + return + } + err := t.handle(t.ctx, data) + if err != nil { + t.logger.With("data", string(data)).With("error", err).ErrorContext(t.ctx, "Tee Handle error") + return + } + case <-t.ctx.Done(): + t.logger.DebugContext(t.ctx, "Tee Handle ctx done") + return + } + } +} + +func (t *Tee) Stream() error { + bufPtr := bufferPool.Get().(*[]byte) + buf := *bufPtr + defer bufferPool.Put(bufPtr) + + for { + n, err := t.Reader.Read(buf) + if err != nil { + if err == io.EOF { + return nil + } + return err + } + if n > 0 { + _, err = t.Writer.Write(buf[:n]) + if err != nil { + return err + } + if flusher, ok := t.Writer.(http.Flusher); ok { + flusher.Flush() + } + data := make([]byte, n) + copy(data, buf[:n]) + t.ch <- data + } + } +} diff --git a/backend/pkg/tee/tee_test.go b/backend/pkg/tee/tee_test.go new file mode 100644 index 0000000..5f91a4c --- /dev/null +++ b/backend/pkg/tee/tee_test.go @@ -0,0 +1,376 @@ +package tee + +import ( + "bytes" + "context" + "errors" + "io" + "log/slog" + "os" + "strings" + "sync" + "testing" + "time" +) + +// mockWriter 模拟 Writer 接口 +type mockWriter struct { + buf *bytes.Buffer + delay time.Duration + errorOn int // 在第几次写入时返回错误 + count int +} + +func newMockWriter() *mockWriter { + return &mockWriter{ + buf: &bytes.Buffer{}, + } +} + +func (m *mockWriter) Write(p []byte) (n int, err error) { + m.count++ + if m.errorOn > 0 && m.count >= m.errorOn { + return 0, errors.New("mock write error") + } + if m.delay > 0 { + time.Sleep(m.delay) + } + return m.buf.Write(p) +} + +func (m *mockWriter) String() string { + return m.buf.String() +} + +// mockReader 模拟 Reader 接口 +type mockReader struct { + data []byte + pos int + chunk int // 每次读取的字节数 + errorOn int // 在第几次读取时返回错误 + count int +} + +func newMockReader(data string, chunk int) *mockReader { + return &mockReader{ + data: []byte(data), + chunk: chunk, + } +} + +func (m *mockReader) Read(p []byte) (n int, err error) { + m.count++ + if m.errorOn > 0 && m.count >= m.errorOn { + return 0, errors.New("mock read error") + } + + if m.pos >= len(m.data) { + return 0, io.EOF + } + + readSize := m.chunk + if readSize <= 0 || readSize > len(p) { + readSize = len(p) + } + + remaining := len(m.data) - m.pos + if readSize > remaining { + readSize = remaining + } + + copy(p, m.data[m.pos:m.pos+readSize]) + m.pos += readSize + return readSize, nil +} + +// TestTeeBasicFunctionality 测试基本功能 +func TestTeeBasicFunctionality(t *testing.T) { + ctx := context.Background() + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError})) + + testData := "Hello, World! This is a test message." + reader := newMockReader(testData, 10) // 每次读取10字节 + writer := newMockWriter() + + var handledData [][]byte + var mu sync.Mutex + + handle := func(ctx context.Context, data []byte) error { + mu.Lock() + defer mu.Unlock() + // 复制数据,因为原始数据可能被重用 + dataCopy := make([]byte, len(data)) + copy(dataCopy, data) + handledData = append(handledData, dataCopy) + return nil + } + + tee := NewTee(ctx, logger, reader, writer, handle) + defer tee.Close() + + err := tee.Stream() + if err != nil { + t.Fatalf("Stream() failed: %v", err) + } + + // 等待处理完成 + time.Sleep(100 * time.Millisecond) + + // 验证写入的数据 + if writer.String() != testData { + t.Errorf("Expected writer data %q, got %q", testData, writer.String()) + } + + // 验证处理的数据 + mu.Lock() + var totalHandled []byte + for _, chunk := range handledData { + totalHandled = append(totalHandled, chunk...) + } + mu.Unlock() + + if string(totalHandled) != testData { + t.Errorf("Expected handled data %q, got %q", testData, string(totalHandled)) + } +} + +// TestTeeWithErrors 测试错误处理 +func TestTeeWithErrors(t *testing.T) { + ctx := context.Background() + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError})) + + t.Run("ReaderError", func(t *testing.T) { + reader := newMockReader("test data", 5) + reader.errorOn = 2 // 第二次读取时出错 + writer := newMockWriter() + + handle := func(ctx context.Context, data []byte) error { + return nil + } + + tee := NewTee(ctx, logger, reader, writer, handle) + defer tee.Close() + + err := tee.Stream() + if err == nil { + t.Error("Expected error from reader, got nil") + } + }) + + t.Run("WriterError", func(t *testing.T) { + reader := newMockReader("test data", 5) + writer := newMockWriter() + writer.errorOn = 2 // 第二次写入时出错 + + handle := func(ctx context.Context, data []byte) error { + return nil + } + + tee := NewTee(ctx, logger, reader, writer, handle) + defer tee.Close() + + err := tee.Stream() + if err == nil { + t.Error("Expected error from writer, got nil") + } + }) + + t.Run("HandleError", func(t *testing.T) { + reader := newMockReader("test data", 5) + writer := newMockWriter() + + handle := func(ctx context.Context, data []byte) error { + return errors.New("handle error") + } + + tee := NewTee(ctx, logger, reader, writer, handle) + defer tee.Close() + + // 启动 Stream 在单独的 goroutine 中 + go func() { + tee.Stream() + }() + + // 等待一段时间让处理器有机会处理数据并出错 + time.Sleep(200 * time.Millisecond) + }) +} + +// TestTeeContextCancellation 测试上下文取消 +func TestTeeContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError})) + + // 创建一个会持续产生数据的 reader + reader := strings.NewReader(strings.Repeat("test data ", 1000)) + writer := newMockWriter() + + handle := func(ctx context.Context, data []byte) error { + return nil + } + + tee := NewTee(ctx, logger, reader, writer, handle) + defer tee.Close() + + // 在单独的 goroutine 中启动 Stream + done := make(chan error, 1) + go func() { + done <- tee.Stream() + }() + + // 等待一段时间后取消上下文 + time.Sleep(50 * time.Millisecond) + cancel() + + // 等待 Stream 完成 + select { + case err := <-done: + if err != nil && err != io.EOF { + t.Logf("Stream completed with error: %v", err) + } + case <-time.After(2 * time.Second): + t.Error("Stream did not complete within timeout") + } +} + +// TestTeeConcurrentSafety 测试并发安全性 +func TestTeeConcurrentSafety(t *testing.T) { + ctx := context.Background() + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError})) + + testData := strings.Repeat("concurrent test data ", 100) + reader := strings.NewReader(testData) + writer := newMockWriter() + + var processedCount int64 + var mu sync.Mutex + + handle := func(ctx context.Context, data []byte) error { + mu.Lock() + processedCount++ + mu.Unlock() + // 模拟一些处理时间 + time.Sleep(time.Microsecond) + return nil + } + + tee := NewTee(ctx, logger, reader, writer, handle) + defer tee.Close() + + err := tee.Stream() + if err != nil { + t.Fatalf("Stream() failed: %v", err) + } + + // 等待所有数据处理完成 + time.Sleep(500 * time.Millisecond) + + mu.Lock() + finalCount := processedCount + mu.Unlock() + + if finalCount == 0 { + t.Error("No data was processed") + } + + t.Logf("Processed %d chunks of data", finalCount) +} + +// TestBufferPoolEfficiency 测试缓冲区池的效率 +func TestBufferPoolEfficiency(t *testing.T) { + // 这个测试验证缓冲区池是否正常工作 + // 通过多次获取和归还缓冲区来测试 + + var buffers []*[]byte + + // 获取多个缓冲区 + for i := 0; i < 10; i++ { + bufPtr := bufferPool.Get().(*[]byte) + buffers = append(buffers, bufPtr) + + // 验证缓冲区大小 + if len(*bufPtr) != 4096 { + t.Errorf("Expected buffer size 4096, got %d", len(*bufPtr)) + } + } + + // 归还所有缓冲区 + for _, bufPtr := range buffers { + bufferPool.Put(bufPtr) + } + + // 再次获取缓冲区,应该重用之前的缓冲区 + for i := 0; i < 5; i++ { + bufPtr := bufferPool.Get().(*[]byte) + if len(*bufPtr) != 4096 { + t.Errorf("Expected reused buffer size 4096, got %d", len(*bufPtr)) + } + bufferPool.Put(bufPtr) + } +} + +// BenchmarkTeeStream 基准测试 +func BenchmarkTeeStream(b *testing.B) { + ctx := context.Background() + logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError})) + + testData := strings.Repeat("benchmark test data ", 1000) + + handle := func(ctx context.Context, data []byte) error { + return nil + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + reader := strings.NewReader(testData) + writer := io.Discard + + tee := NewTee(ctx, logger, reader, writer, handle) + err := tee.Stream() + if err != nil { + b.Fatalf("Stream() failed: %v", err) + } + tee.Close() + } +} + +// BenchmarkBufferPool 缓冲区池基准测试 +func BenchmarkBufferPool(b *testing.B) { + b.Run("WithPool", func(b *testing.B) { + for i := 0; i < b.N; i++ { + bufPtr := bufferPool.Get().(*[]byte) + // 模拟使用缓冲区 + _ = *bufPtr + bufferPool.Put(bufPtr) + } + }) + + b.Run("WithoutPool", func(b *testing.B) { + for i := 0; i < b.N; i++ { + buf := make([]byte, 4096) + // 模拟使用缓冲区 + _ = buf + } + }) +} + +// TestTeeClose 测试关闭功能 +func TestTeeClose(t *testing.T) { + ctx := context.Background() + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError})) + + reader := strings.NewReader("test data") + writer := newMockWriter() + + handle := func(ctx context.Context, data []byte) error { + return nil + } + + tee := NewTee(ctx, logger, reader, writer, handle) + + // 测试多次关闭不会 panic + tee.Close() + tee.Close() + tee.Close() +}