Merge pull request #70 from yokowu/feat-optimize-proxy

feat(proxy): 优化流式代理, 隔离代理与记录逻辑
This commit is contained in:
Yoko
2025-07-12 22:57:21 +08:00
committed by GitHub
8 changed files with 757 additions and 110 deletions

View File

@@ -49,10 +49,11 @@ type Config struct {
} `mapstructure:"redis"`
LLMProxy struct {
Timeout string `mapstructure:"timeout"`
KeepAlive string `mapstructure:"keep_alive"`
ClientPoolSize int `mapstructure:"client_pool_size"`
RequestLogPath string `mapstructure:"request_log_path"`
Timeout string `mapstructure:"timeout"`
KeepAlive string `mapstructure:"keep_alive"`
ClientPoolSize int `mapstructure:"client_pool_size"`
StreamClientPoolSize int `mapstructure:"stream_client_pool_size"`
RequestLogPath string `mapstructure:"request_log_path"`
} `mapstructure:"llm_proxy"`
InitModel struct {
@@ -92,6 +93,7 @@ func Init() (*Config, error) {
v.SetDefault("llm_proxy.timeout", "30s")
v.SetDefault("llm_proxy.keep_alive", "60s")
v.SetDefault("llm_proxy.client_pool_size", 100)
v.SetDefault("llm_proxy.stream_client_pool_size", 5000)
v.SetDefault("llm_proxy.request_log_path", "/app/request/logs")
v.SetDefault("init_model.name", "qwen2.5-coder-3b-instruct")
v.SetDefault("init_model.key", "")

View File

@@ -15,7 +15,7 @@ func RequestID() echo.MiddlewareFunc {
return func(c echo.Context) error {
ctx := c.Request().Context()
requestID := uuid.New().String()
ctx = context.WithValue(ctx, logger.RequestIDKey, requestID)
ctx = context.WithValue(ctx, logger.RequestIDKey{}, requestID)
c.SetRequest(c.Request().WithContext(ctx))
return next(c)
}

View File

@@ -44,7 +44,7 @@ func (p *ProxyMiddleware) Auth() echo.MiddlewareFunc {
}
ctx := c.Request().Context()
ctx = context.WithValue(ctx, logger.UserIDKey, key.UserID)
ctx = context.WithValue(ctx, logger.UserIDKey{}, key.UserID)
c.SetRequest(c.Request().WithContext(ctx))
c.Set(ApiContextKey, key)
return next(c)

View File

@@ -59,6 +59,7 @@ type LLMProxy struct {
usecase domain.ProxyUsecase
cfg *config.Config
client *http.Client
streamClient *http.Client
logger *slog.Logger
requestLogPath string // 请求日志保存路径
}
@@ -83,7 +84,6 @@ func NewLLMProxy(
logger.Warn("解析保持连接时间失败, 使用默认值 60s", "error", err)
}
// 创建HTTP客户端
client := &http.Client{
Timeout: timeout,
Transport: &http.Transport{
@@ -98,6 +98,18 @@ func NewLLMProxy(
},
}
streamClient := &http.Client{
Timeout: 60 * time.Second,
Transport: &http.Transport{
MaxIdleConns: cfg.LLMProxy.StreamClientPoolSize,
MaxConnsPerHost: cfg.LLMProxy.StreamClientPoolSize,
MaxIdleConnsPerHost: cfg.LLMProxy.StreamClientPoolSize,
IdleConnTimeout: 24 * time.Hour,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
},
}
// 获取日志配置
requestLogPath := ""
if cfg.LLMProxy.RequestLogPath != "" {
@@ -111,6 +123,7 @@ func NewLLMProxy(
return &LLMProxy{
usecase: usecase,
client: client,
streamClient: streamClient,
cfg: cfg,
requestLogPath: requestLogPath,
logger: logger,
@@ -174,12 +187,12 @@ type Ctx struct {
func (p *LLMProxy) handle(ctx context.Context, fn func(ctx *Ctx, log *RequestResponseLog) error) {
// 获取用户ID
userID := "unknown"
if id, ok := ctx.Value(logger.UserIDKey).(string); ok {
if id, ok := ctx.Value(logger.UserIDKey{}).(string); ok {
userID = id
}
requestID := "unknown"
if id, ok := ctx.Value(logger.RequestIDKey).(string); ok {
if id, ok := ctx.Value(logger.RequestIDKey{}).(string); ok {
requestID = id
}
@@ -203,11 +216,11 @@ func (p *LLMProxy) handle(ctx context.Context, fn func(ctx *Ctx, log *RequestRes
}
if err := fn(c, l); err != nil {
p.logger.With("userID", userID, "requestID", requestID, "sourceip", sourceip).ErrorContext(ctx, "处理请求失败", "error", err)
p.logger.With("source_ip", sourceip).ErrorContext(ctx, "处理请求失败", "error", err)
l.Error = err.Error()
}
p.saveRequestResponseLog(l)
go p.saveRequestResponseLog(l)
}
func (p *LLMProxy) HandleCompletion(ctx context.Context, w http.ResponseWriter, req domain.CompletionRequest) {
@@ -585,10 +598,6 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon
return err
}
prompt := p.getPrompt(ctx, req)
mode := req.Metadata["mode"]
taskID := req.Metadata["task_id"]
upstream := m.APIBase + endpoint
log.UpstreamURL = upstream
@@ -606,9 +615,7 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon
newReq.Header.Set("Content-Type", "application/json")
newReq.Header.Set("Accept", "text/event-stream")
if m.APIKey != "" && m.APIKey != "none" {
newReq.Header.Set("Authorization", "Bearer "+m.APIKey)
}
newReq.Header.Set("Authorization", "Bearer "+m.APIKey)
// 保存请求头(去除敏感信息)
requestHeaders := make(map[string][]string)
@@ -622,22 +629,26 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon
}
log.RequestHeader = requestHeaders
p.logger.With(
logger := p.logger.With(
"request_id", c.RequestID,
"source_ip", c.SourceIP,
"upstreamURL", upstream,
"modelName", m.ModelName,
"modelType", consts.ModelTypeLLM,
"apiBase", m.APIBase,
"work_mode", mode,
)
logger.With(
"upstreamURL", upstream,
"requestHeader", newReq.Header,
"requestBody", req,
"taskID", taskID,
"messages", cvt.Filter(req.Messages, func(i int, v openai.ChatCompletionMessage) (openai.ChatCompletionMessage, bool) {
return v, v.Role != "system"
}),
).DebugContext(ctx, "转发流式请求到上游API")
// 发送请求
resp, err := p.client.Do(newReq)
resp, err := p.streamClient.Do(newReq)
if err != nil {
p.logger.With("upstreamURL", upstream).WarnContext(ctx, "发送上游流式请求失败", "error", err)
return fmt.Errorf("发送上游请求失败: %w", err)
@@ -655,7 +666,7 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon
log.Latency = time.Since(startTime).Milliseconds()
// 在debug级别记录错误的流式响应内容
p.logger.With(
logger.With(
"statusCode", resp.StatusCode,
"responseHeader", resp.Header,
"responseBody", string(responseBody),
@@ -663,9 +674,8 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon
var errorResp ErrResp
if err := json.Unmarshal(responseBody, &errorResp); err == nil {
p.logger.With(
logger.With(
"endpoint", endpoint,
"upstreamURL", upstream,
"requestBody", newReq,
"statusCode", resp.StatusCode,
"errorType", errorResp.Error.Type,
@@ -677,9 +687,8 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon
return fmt.Errorf("上游API返回错误: %s", errorResp.Error.Message)
}
p.logger.With(
logger.With(
"endpoint", endpoint,
"upstreamURL", upstream,
"requestBody", newReq,
"statusCode", resp.StatusCode,
"responseBody", string(responseBody),
@@ -688,12 +697,10 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon
return fmt.Errorf("上游API返回非200状态码: %d, 响应: %s", resp.StatusCode, string(responseBody))
}
// 更新日志信息
log.StatusCode = resp.StatusCode
log.ResponseHeader = resp.Header
// 在debug级别记录流式响应头信息
p.logger.With(
logger.With(
"statusCode", resp.StatusCode,
"responseHeader", resp.Header,
).DebugContext(ctx, "上游流式响应头信息")
@@ -705,78 +712,18 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon
w.Header().Set("Transfer-Encoding", "chunked")
w.Header().Set("X-Accel-Buffering", "no")
rc := &domain.RecordParam{
RequestID: c.RequestID,
TaskID: taskID,
UserID: c.UserID,
ModelID: m.ID,
ModelType: consts.ModelTypeLLM,
WorkMode: mode,
Prompt: prompt,
Role: consts.ChatRoleAssistant,
}
ch := make(chan []byte, 1024)
defer close(ch)
go func(rc *domain.RecordParam) {
if rc.Prompt != "" {
urc := rc.Clone()
urc.Role = consts.ChatRoleUser
urc.Completion = urc.Prompt
if err := p.usecase.Record(context.Background(), urc); err != nil {
p.logger.With("modelID", m.ID, "modelName", m.ModelName, "modelType", consts.ModelTypeLLM).
WarnContext(ctx, "插入流式记录失败", "error", err)
}
}
for line := range ch {
if bytes.HasPrefix(line, []byte("data:")) {
line = bytes.TrimPrefix(line, []byte("data: "))
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
if bytes.Equal(line, []byte("[DONE]")) {
break
}
var t openai.ChatCompletionStreamResponse
if err := json.Unmarshal(line, &t); err != nil {
p.logger.With("line", string(line)).WarnContext(ctx, "解析流式数据失败", "error", err)
continue
}
p.logger.With("response", t).DebugContext(ctx, "流式响应数据")
if len(t.Choices) > 0 {
rc.Completion += t.Choices[0].Delta.Content
}
if t.Usage != nil {
rc.InputTokens = int64(t.Usage.PromptTokens)
rc.OutputTokens = int64(t.Usage.CompletionTokens)
}
}
}
p.logger.With("record", rc).DebugContext(ctx, "流式记录")
if err := p.usecase.Record(context.Background(), rc); err != nil {
p.logger.With("modelID", m.ID, "modelName", m.ModelName, "modelType", consts.ModelTypeLLM).
WarnContext(ctx, "插入流式记录失败", "error", err)
}
}(rc)
err = streamRead(ctx, resp.Body, func(line []byte) error {
ch <- line
if _, err := w.Write(line); err != nil {
return fmt.Errorf("写入响应失败: %w", err)
}
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
return nil
})
return err
recorder := NewChatRecorder(
ctx,
c,
p.usecase,
m,
req,
resp.Body,
w,
p.logger.With("module", "ChatRecorder"),
)
defer recorder.Close()
return recorder.Stream()
})
}

View File

@@ -0,0 +1,223 @@
package proxy
import (
"context"
"encoding/json"
"io"
"log/slog"
"strings"
"github.com/rokku-c/go-openai"
"github.com/chaitin/MonkeyCode/backend/consts"
"github.com/chaitin/MonkeyCode/backend/domain"
"github.com/chaitin/MonkeyCode/backend/pkg/promptparser"
"github.com/chaitin/MonkeyCode/backend/pkg/tee"
)
// ChatRecorder tees an upstream SSE chat stream to the client while
// accumulating the assistant response and usage statistics so the
// conversation can be persisted once the stream completes.
type ChatRecorder struct {
	*tee.Tee
	ctx     context.Context
	cx      *Ctx
	usecase domain.ProxyUsecase
	req     *openai.ChatCompletionRequest
	logger  *slog.Logger
	model   *domain.Model

	completion     strings.Builder // accumulated assistant response content
	usage          *openai.Usage   // final usage statistics; nil until the upstream reports them
	buffer         strings.Builder // trailing partial SSE line carried between chunks
	recorded       bool            // whether the completed conversation was already persisted
	promptRecorded bool            // whether the user-prompt record was already persisted
}

// NewChatRecorder wires a Tee that copies r to w and feeds every chunk
// through the recorder's SSE parser.
func NewChatRecorder(
	ctx context.Context,
	cx *Ctx,
	usecase domain.ProxyUsecase,
	model *domain.Model,
	req *openai.ChatCompletionRequest,
	r io.Reader,
	w io.Writer,
	logger *slog.Logger,
) *ChatRecorder {
	c := &ChatRecorder{
		ctx:     ctx,
		cx:      cx,
		usecase: usecase,
		model:   model,
		req:     req,
		logger:  logger,
	}
	c.Tee = tee.NewTee(ctx, logger, r, w, c.handle)
	return c
}

// handle buffers raw stream chunks and dispatches every complete
// newline-terminated SSE line; a trailing partial line stays buffered until
// the next chunk arrives.
func (c *ChatRecorder) handle(ctx context.Context, data []byte) error {
	c.buffer.Write(data)
	content := c.buffer.String()

	idx := strings.LastIndexByte(content, '\n')
	if idx < 0 {
		// No complete line yet; keep everything buffered.
		return nil
	}
	complete, partial := content[:idx], content[idx+1:]
	c.buffer.Reset()
	c.buffer.WriteString(partial)

	for _, line := range strings.Split(complete, "\n") {
		if err := c.processSSELine(ctx, line); err != nil {
			return err
		}
	}
	return nil
}

// processSSELine parses one SSE line. "data:" payloads are decoded as
// stream responses and accumulated; "[DONE]" triggers persistence of the
// completed conversation. Malformed payloads are logged and skipped so a
// single bad chunk cannot break the stream.
func (c *ChatRecorder) processSSELine(ctx context.Context, line string) error {
	line = strings.TrimSpace(line)
	if line == "" || !strings.HasPrefix(line, "data:") {
		return nil
	}
	dataContent := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
	if dataContent == "" {
		return nil
	}
	if dataContent == "[DONE]" {
		c.processCompletedChat(ctx)
		return nil
	}

	var resp openai.ChatCompletionStreamResponse
	if err := json.Unmarshal([]byte(dataContent), &resp); err != nil {
		c.logger.With("data", dataContent).WarnContext(ctx, "解析流式响应失败", "error", err)
		return nil
	}

	// Persist the user-side record once per stream, not once per chunk.
	c.recordUserPrompt(ctx, &resp)

	if len(resp.Choices) > 0 {
		if content := resp.Choices[0].Delta.Content; content != "" {
			c.completion.WriteString(content)
		}
	}
	if resp.Usage != nil {
		c.usage = resp.Usage
	}
	return nil
}

// recordUserPrompt inserts the user-prompt record exactly once per stream.
// The previous per-line insertion duplicated the record for every SSE chunk
// and re-ran the (non-trivial) prompt parsing each time.
func (c *ChatRecorder) recordUserPrompt(ctx context.Context, resp *openai.ChatCompletionStreamResponse) {
	if c.promptRecorded {
		return
	}
	c.promptRecorded = true

	prompt := c.getPrompt(ctx, c.req)
	if prompt == "" {
		return
	}
	rc := &domain.RecordParam{
		RequestID:  c.cx.RequestID,
		TaskID:     c.req.Metadata["task_id"],
		UserID:     c.cx.UserID,
		ModelID:    c.model.ID,
		ModelType:  c.model.ModelType,
		WorkMode:   c.req.Metadata["mode"],
		Prompt:     prompt,
		Completion: prompt,
		Role:       consts.ChatRoleUser,
	}
	if resp.Usage != nil {
		rc.InputTokens = int64(resp.Usage.PromptTokens)
	}
	if err := c.usecase.Record(context.Background(), rc); err != nil {
		c.logger.With("modelID", c.model.ID, "modelName", c.model.ModelName, "modelType", consts.ModelTypeLLM).
			WarnContext(ctx, "插入流式记录失败", "error", err)
	}
}

// processCompletedChat persists the accumulated assistant response. It is
// idempotent: after the first successful record, later calls are no-ops.
func (c *ChatRecorder) processCompletedChat(ctx context.Context) {
	if c.recorded {
		return // avoid duplicate records
	}
	rc := &domain.RecordParam{
		RequestID:  c.cx.RequestID,
		TaskID:     c.req.Metadata["task_id"],
		UserID:     c.cx.UserID,
		ModelID:    c.model.ID,
		ModelType:  c.model.ModelType,
		WorkMode:   c.req.Metadata["mode"],
		Role:       consts.ChatRoleAssistant,
		Completion: c.completion.String(),
	}
	// usage is nil when the upstream never reported statistics (e.g. the
	// stream aborted before [DONE]); guard against a nil-pointer panic.
	if c.usage != nil {
		rc.InputTokens = int64(c.usage.PromptTokens)
		rc.OutputTokens = int64(c.usage.CompletionTokens)
	}
	if err := c.usecase.Record(context.Background(), rc); err != nil {
		c.logger.With("modelID", c.model.ID, "modelName", c.model.ModelName, "modelType", consts.ModelTypeLLM).
			WarnContext(ctx, "插入流式记录失败", "error", err)
	} else {
		c.recorded = true
		c.logger.With("requestID", c.cx.RequestID, "completion_length", c.completion.Len()).
			InfoContext(ctx, "流式对话记录已保存")
	}
}

// Close flushes any partially accumulated content (when the stream was
// interrupted before [DONE]) and shuts down the underlying Tee.
func (c *ChatRecorder) Close() {
	if !c.recorded && c.completion.Len() > 0 {
		c.logger.With("requestID", c.cx.RequestID).
			WarnContext(c.ctx, "数据流异常中断,强制保存已累积的内容")
		c.processCompletedChat(c.ctx)
	}
	if c.Tee != nil {
		c.Tee.Close()
	}
}

// Reset clears all accumulated state so the recorder can be reused.
func (c *ChatRecorder) Reset() {
	c.completion.Reset()
	c.buffer.Reset()
	c.usage = nil
	c.recorded = false
	c.promptRecorded = false
}

// getPrompt extracts the user prompt from the request messages, skipping
// system messages and parsing <task>/<feedback>/<user_message> markers via
// the prompt parser. Returns "" when no marker is found.
func (c *ChatRecorder) getPrompt(ctx context.Context, req *openai.ChatCompletionRequest) string {
	prompt := ""
	parse := promptparser.New(promptparser.KindTask)
	containsMarker := func(s string) bool {
		return strings.Contains(s, "<task>") ||
			strings.Contains(s, "<feedback>") ||
			strings.Contains(s, "<user_message>")
	}
	for _, message := range req.Messages {
		if message.Role == "system" {
			continue
		}
		if containsMarker(message.Content) {
			if info, err := parse.Parse(message.Content); err == nil {
				prompt = info.Prompt
			} else {
				c.logger.With("message", message.Content).WarnContext(ctx, "解析Prompt失败", "error", err)
			}
		}
		for _, m := range message.MultiContent {
			if containsMarker(m.Text) {
				if info, err := parse.Parse(m.Text); err == nil {
					prompt = info.Prompt
				} else {
					c.logger.With("message", m.Text).WarnContext(ctx, "解析Prompt失败", "error", err)
				}
			}
		}
	}
	return prompt
}

View File

@@ -5,12 +5,8 @@ import (
"log/slog"
)
type contextKey string
const (
RequestIDKey contextKey = "request_id"
UserIDKey contextKey = "user_id"
)
type RequestIDKey struct{}
type UserIDKey struct{}
type ContextLogger struct {
slog.Handler
@@ -31,11 +27,11 @@ func (c *ContextLogger) WithGroup(name string) slog.Handler {
func (c *ContextLogger) Handle(ctx context.Context, r slog.Record) error {
newRecord := r.Clone()
if i, ok := ctx.Value(RequestIDKey).(string); ok {
if i, ok := ctx.Value(RequestIDKey{}).(string); ok {
newRecord.AddAttrs(slog.String("request_id", i))
}
if i, ok := ctx.Value(UserIDKey).(string); ok {
if i, ok := ctx.Value(UserIDKey{}).(string); ok {
newRecord.AddAttrs(slog.String("user_id", i))
}

103
backend/pkg/tee/tee.go Normal file
View File

@@ -0,0 +1,103 @@
package tee
import (
"context"
"io"
"log/slog"
"net/http"
"sync"
)
type TeeHandleFunc func(ctx context.Context, data []byte) error
var bufferPool = sync.Pool{
New: func() any {
buf := make([]byte, 4096)
return &buf
},
}
type Tee struct {
ctx context.Context
logger *slog.Logger
Reader io.Reader
Writer io.Writer
ch chan []byte
handle TeeHandleFunc
}
func NewTee(
ctx context.Context,
logger *slog.Logger,
reader io.Reader,
writer io.Writer,
handle TeeHandleFunc,
) *Tee {
t := &Tee{
ctx: ctx,
logger: logger,
Reader: reader,
Writer: writer,
handle: handle,
ch: make(chan []byte, 32*1024),
}
go t.Handle()
return t
}
func (t *Tee) Close() {
select {
case <-t.ch:
// channel 已经关闭
default:
close(t.ch)
}
}
func (t *Tee) Handle() {
for {
select {
case data, ok := <-t.ch:
if !ok {
t.logger.DebugContext(t.ctx, "Tee Handle closed")
return
}
err := t.handle(t.ctx, data)
if err != nil {
t.logger.With("data", string(data)).With("error", err).ErrorContext(t.ctx, "Tee Handle error")
return
}
case <-t.ctx.Done():
t.logger.DebugContext(t.ctx, "Tee Handle ctx done")
return
}
}
}
func (t *Tee) Stream() error {
bufPtr := bufferPool.Get().(*[]byte)
buf := *bufPtr
defer bufferPool.Put(bufPtr)
for {
n, err := t.Reader.Read(buf)
if err != nil {
if err == io.EOF {
return nil
}
return err
}
if n > 0 {
_, err = t.Writer.Write(buf[:n])
if err != nil {
return err
}
if flusher, ok := t.Writer.(http.Flusher); ok {
flusher.Flush()
}
data := make([]byte, n)
copy(data, buf[:n])
t.ch <- data
}
}
}

376
backend/pkg/tee/tee_test.go Normal file
View File

@@ -0,0 +1,376 @@
package tee
import (
"bytes"
"context"
"errors"
"io"
"log/slog"
"os"
"strings"
"sync"
"testing"
"time"
)
// mockWriter is an io.Writer that captures output and can simulate latency
// or a failure on the Nth write.
type mockWriter struct {
	buf     *bytes.Buffer
	delay   time.Duration
	errorOn int // return an error on the Nth write (0 disables)
	count   int
}

func newMockWriter() *mockWriter {
	return &mockWriter{buf: &bytes.Buffer{}}
}

// Write records the call, optionally fails or sleeps, then appends to buf.
func (w *mockWriter) Write(p []byte) (n int, err error) {
	w.count++
	if w.errorOn > 0 && w.count >= w.errorOn {
		return 0, errors.New("mock write error")
	}
	if w.delay > 0 {
		time.Sleep(w.delay)
	}
	return w.buf.Write(p)
}

// String returns everything written so far.
func (w *mockWriter) String() string {
	return w.buf.String()
}
// mockReader 模拟 Reader 接口
type mockReader struct {
data []byte
pos int
chunk int // 每次读取的字节数
errorOn int // 在第几次读取时返回错误
count int
}
func newMockReader(data string, chunk int) *mockReader {
return &mockReader{
data: []byte(data),
chunk: chunk,
}
}
func (m *mockReader) Read(p []byte) (n int, err error) {
m.count++
if m.errorOn > 0 && m.count >= m.errorOn {
return 0, errors.New("mock read error")
}
if m.pos >= len(m.data) {
return 0, io.EOF
}
readSize := m.chunk
if readSize <= 0 || readSize > len(p) {
readSize = len(p)
}
remaining := len(m.data) - m.pos
if readSize > remaining {
readSize = remaining
}
copy(p, m.data[m.pos:m.pos+readSize])
m.pos += readSize
return readSize, nil
}
// TestTeeBasicFunctionality verifies streamed data reaches both the writer
// and the handler intact.
func TestTeeBasicFunctionality(t *testing.T) {
	ctx := context.Background()
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))

	const payload = "Hello, World! This is a test message."
	src := newMockReader(payload, 10) // serve in 10-byte chunks
	dst := newMockWriter()

	var (
		mu     sync.Mutex
		chunks [][]byte
	)
	collect := func(ctx context.Context, data []byte) error {
		mu.Lock()
		defer mu.Unlock()
		// Copy before storing; the original bytes may be reused.
		chunks = append(chunks, append([]byte(nil), data...))
		return nil
	}

	tee := NewTee(ctx, logger, src, dst, collect)
	defer tee.Close()

	if err := tee.Stream(); err != nil {
		t.Fatalf("Stream() failed: %v", err)
	}
	time.Sleep(100 * time.Millisecond) // let the handler goroutine drain

	if got := dst.String(); got != payload {
		t.Errorf("Expected writer data %q, got %q", payload, got)
	}

	mu.Lock()
	var joined []byte
	for _, chunk := range chunks {
		joined = append(joined, chunk...)
	}
	mu.Unlock()
	if string(joined) != payload {
		t.Errorf("Expected handled data %q, got %q", payload, string(joined))
	}
}
// TestTeeWithErrors exercises failure propagation from the reader, the
// writer, and the handler.
func TestTeeWithErrors(t *testing.T) {
	ctx := context.Background()
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))

	noop := func(ctx context.Context, data []byte) error { return nil }

	t.Run("ReaderError", func(t *testing.T) {
		src := newMockReader("test data", 5)
		src.errorOn = 2 // fail on the second read
		tee := NewTee(ctx, logger, src, newMockWriter(), noop)
		defer tee.Close()
		if err := tee.Stream(); err == nil {
			t.Error("Expected error from reader, got nil")
		}
	})

	t.Run("WriterError", func(t *testing.T) {
		dst := newMockWriter()
		dst.errorOn = 2 // fail on the second write
		tee := NewTee(ctx, logger, newMockReader("test data", 5), dst, noop)
		defer tee.Close()
		if err := tee.Stream(); err == nil {
			t.Error("Expected error from writer, got nil")
		}
	})

	t.Run("HandleError", func(t *testing.T) {
		failing := func(ctx context.Context, data []byte) error {
			return errors.New("handle error")
		}
		tee := NewTee(ctx, logger, newMockReader("test data", 5), newMockWriter(), failing)
		defer tee.Close()
		go func() {
			tee.Stream()
		}()
		// Give the handler goroutine time to hit the error path.
		time.Sleep(200 * time.Millisecond)
	})
}
// TestTeeContextCancellation verifies Stream winds down after the context
// is cancelled.
func TestTeeContextCancellation(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))

	// A reader with plenty of data to keep Stream busy.
	src := strings.NewReader(strings.Repeat("test data ", 1000))
	noop := func(ctx context.Context, data []byte) error { return nil }

	tee := NewTee(ctx, logger, src, newMockWriter(), noop)
	defer tee.Close()

	done := make(chan error, 1)
	go func() {
		done <- tee.Stream()
	}()

	time.Sleep(50 * time.Millisecond)
	cancel()

	select {
	case err := <-done:
		if err != nil && err != io.EOF {
			t.Logf("Stream completed with error: %v", err)
		}
	case <-time.After(2 * time.Second):
		t.Error("Stream did not complete within timeout")
	}
}
// TestTeeConcurrentSafety verifies the handler is invoked for streamed data
// while Stream and the handler goroutine run concurrently (run with -race).
func TestTeeConcurrentSafety(t *testing.T) {
	ctx := context.Background()
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))

	payload := strings.Repeat("concurrent test data ", 100)

	var (
		mu        sync.Mutex
		processed int64
	)
	handler := func(ctx context.Context, data []byte) error {
		mu.Lock()
		processed++
		mu.Unlock()
		time.Sleep(time.Microsecond) // simulate some processing time
		return nil
	}

	tee := NewTee(ctx, logger, strings.NewReader(payload), newMockWriter(), handler)
	defer tee.Close()

	if err := tee.Stream(); err != nil {
		t.Fatalf("Stream() failed: %v", err)
	}
	time.Sleep(500 * time.Millisecond) // wait for the handler to drain

	mu.Lock()
	final := processed
	mu.Unlock()
	if final == 0 {
		t.Error("No data was processed")
	}
	t.Logf("Processed %d chunks of data", final)
}
// TestBufferPoolEfficiency checks that pooled buffers have the expected
// size and survive a get/put round trip.
func TestBufferPoolEfficiency(t *testing.T) {
	var held []*[]byte

	// Take several buffers and verify their size.
	for i := 0; i < 10; i++ {
		p := bufferPool.Get().(*[]byte)
		held = append(held, p)
		if len(*p) != 4096 {
			t.Errorf("Expected buffer size 4096, got %d", len(*p))
		}
	}

	// Return them all.
	for _, p := range held {
		bufferPool.Put(p)
	}

	// Buffers fetched after a Put should still have the canonical size.
	for i := 0; i < 5; i++ {
		p := bufferPool.Get().(*[]byte)
		if len(*p) != 4096 {
			t.Errorf("Expected reused buffer size 4096, got %d", len(*p))
		}
		bufferPool.Put(p)
	}
}
// BenchmarkTeeStream measures a full Stream pass over ~20KB of data.
func BenchmarkTeeStream(b *testing.B) {
	ctx := context.Background()
	logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError}))
	payload := strings.Repeat("benchmark test data ", 1000)
	noop := func(ctx context.Context, data []byte) error { return nil }

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		tee := NewTee(ctx, logger, strings.NewReader(payload), io.Discard, noop)
		if err := tee.Stream(); err != nil {
			b.Fatalf("Stream() failed: %v", err)
		}
		tee.Close()
	}
}
// BenchmarkBufferPool compares pooled versus fresh buffer allocation.
func BenchmarkBufferPool(b *testing.B) {
	b.Run("WithPool", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			p := bufferPool.Get().(*[]byte)
			_ = *p // simulate using the buffer
			bufferPool.Put(p)
		}
	})
	b.Run("WithoutPool", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			buf := make([]byte, 4096)
			_ = buf // simulate using the buffer
		}
	})
}
// TestTeeClose verifies that calling Close repeatedly does not panic.
func TestTeeClose(t *testing.T) {
	ctx := context.Background()
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))
	noop := func(ctx context.Context, data []byte) error { return nil }

	tee := NewTee(ctx, logger, strings.NewReader("test data"), newMockWriter(), noop)

	// Close three times; every call after the first must be a no-op.
	for i := 0; i < 3; i++ {
		tee.Close()
	}
}