// MonkeyCode/backend/internal/proxy/proxy.go
package proxy

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/rokku-c/go-openai"

	"github.com/chaitin/MonkeyCode/backend/config"
	"github.com/chaitin/MonkeyCode/backend/consts"
	"github.com/chaitin/MonkeyCode/backend/domain"
	"github.com/chaitin/MonkeyCode/backend/pkg/cvt"
	"github.com/chaitin/MonkeyCode/backend/pkg/logger"
	"github.com/chaitin/MonkeyCode/backend/pkg/promptparser"
	"github.com/chaitin/MonkeyCode/backend/pkg/request"
)

// ErrResp is the JSON error response body returned to clients.
type ErrResp struct {
	Error struct {
		Message string `json:"message"`
		Type    string `json:"type"`
		Code    string `json:"code,omitempty"`
	} `json:"error"`
}

// RequestResponseLog is the on-disk log record for a proxied request/response pair.
type RequestResponseLog struct {
	Timestamp      time.Time           `json:"timestamp"`
	RequestID      string              `json:"request_id"`
	Endpoint       string              `json:"endpoint"`
	ModelName      string              `json:"model_name"`
	ModelType      consts.ModelType    `json:"model_type"`
	UpstreamURL    string              `json:"upstream_url"`
	RequestBody    any                 `json:"request_body"`
	RequestHeader  map[string][]string `json:"request_header"`
	StatusCode     int                 `json:"status_code"`
	ResponseBody   any                 `json:"response_body"`
	ResponseHeader map[string][]string `json:"response_header"`
	Latency        int64               `json:"latency_ms"`
	Error          string              `json:"error,omitempty"`
	IsStream       bool                `json:"is_stream"`
	SourceIP       string              `json:"source_ip"` // client source IP address
}

// LLMProxy implements domain.Proxy for the LLM API.
type LLMProxy struct {
	usecase        domain.ProxyUsecase
	cfg            *config.Config
	client         *http.Client
	logger         *slog.Logger
	requestLogPath string // directory where request logs are saved
}

// NewLLMProxy creates an LLM API proxy instance.
func NewLLMProxy(
	usecase domain.ProxyUsecase,
	cfg *config.Config,
	logger *slog.Logger,
) domain.Proxy {
	// Parse the request timeout.
	timeout, err := time.ParseDuration(cfg.LLMProxy.Timeout)
	if err != nil {
		timeout = 30 * time.Second
		logger.Warn("failed to parse timeout, falling back to 30s", "error", err)
	}
	// Parse the keep-alive duration.
	keepAlive, err := time.ParseDuration(cfg.LLMProxy.KeepAlive)
	if err != nil {
		keepAlive = 60 * time.Second
		logger.Warn("failed to parse keep-alive duration, falling back to 60s", "error", err)
	}
	// Build the HTTP client.
	client := &http.Client{
		Timeout: timeout,
		Transport: &http.Transport{
			MaxIdleConns:          cfg.LLMProxy.ClientPoolSize,
			MaxConnsPerHost:       cfg.LLMProxy.ClientPoolSize,
			MaxIdleConnsPerHost:   cfg.LLMProxy.ClientPoolSize,
			IdleConnTimeout:       keepAlive,
			TLSHandshakeTimeout:   10 * time.Second,
			ExpectContinueTimeout: 1 * time.Second,
			DisableCompression:    false,
			ForceAttemptHTTP2:     true,
		},
	}
	// Resolve the request-log directory.
	requestLogPath := ""
	if cfg.LLMProxy.RequestLogPath != "" {
		requestLogPath = cfg.LLMProxy.RequestLogPath
		// Make sure the directory exists.
		if err := os.MkdirAll(requestLogPath, 0755); err != nil {
			logger.Warn("failed to create request log directory", "error", err, "path", requestLogPath)
		}
	}
	return &LLMProxy{
		usecase:        usecase,
		client:         client,
		cfg:            cfg,
		requestLogPath: requestLogPath,
		logger:         logger,
	}
}

// saveRequestResponseLog persists a request/response log record to a JSON file.
func (p *LLMProxy) saveRequestResponseLog(log *RequestResponseLog) {
	if p.requestLogPath == "" {
		return
	}
	// Build the file name: YYYYMMDD_HHMMSS_mmm_<requestID>.json
	timestamp := log.Timestamp.Format("20060102_150405") + fmt.Sprintf("_%03d", log.Timestamp.Nanosecond()/1e6)
	filename := fmt.Sprintf("%s_%s.json", timestamp, log.RequestID)
	// Named logPath to avoid shadowing the path/filepath package.
	logPath := filepath.Join(p.requestLogPath, filename)
	// Serialize the log record to JSON.
	logData, err := json.MarshalIndent(log, "", " ")
	if err != nil {
		p.logger.Error("failed to serialize request log", "error", err)
		return
	}
	// Write it to disk.
	if err := os.WriteFile(logPath, logData, 0644); err != nil {
		p.logger.Error("failed to write request log file", "error", err, "path", logPath)
		return
	}
	p.logger.Debug("request/response log saved", "path", logPath)
}
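
// AcceptCompletion reports an accepted completion to the usecase layer.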
func (p *LLMProxy) AcceptCompletion(ctx context.Context, req *domain.AcceptCompletionReq) error {
	return p.usecase.AcceptCompletion(ctx, req)
}
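
// writeErrResp writes an OpenAI-style JSON error response with the given status code.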
func writeErrResp(w http.ResponseWriter, code int, message, errorType string) {
	resp := ErrResp{
		Error: struct {
			Message string `json:"message"`
			Type    string `json:"type"`
			Code    string `json:"code,omitempty"`
		}{
			Message: message,
			Type:    errorType,
		},
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	b, _ := json.Marshal(resp)
	w.Write(b)
}
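
// Ctx carries per-request identity extracted from the request context.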
type Ctx struct {
	UserID    string
	RequestID string
	SourceIP  string
}
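
// handle wraps a proxy handler: it extracts user/request identity from the
// context, runs fn, records any error, and persists the request/response log.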
func (p *LLMProxy) handle(ctx context.Context, fn func(ctx *Ctx, log *RequestResponseLog) error) {
	// Extract the user ID.
	userID := "unknown"
	if id, ok := ctx.Value(logger.UserIDKey).(string); ok {
		userID = id
	}
	requestID := "unknown"
	if id, ok := ctx.Value(logger.RequestIDKey).(string); ok {
		requestID = id
	}
	sourceip := "unknown"
	if ip, ok := ctx.Value("remote_addr").(string); ok {
		sourceip = ip
	}
	// Build the request log record.
	l := &RequestResponseLog{
		Timestamp: time.Now(),
		RequestID: requestID,
		ModelType: consts.ModelTypeCoder,
		SourceIP:  sourceip,
	}
	c := &Ctx{
		UserID:    userID,
		RequestID: requestID,
		SourceIP:  sourceip,
	}
	if err := fn(c, l); err != nil {
		p.logger.With("userID", userID, "requestID", requestID, "sourceip", sourceip).ErrorContext(ctx, "failed to handle request", "error", err)
		l.Error = err.Error()
	}
	p.saveRequestResponseLog(l)
}
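
// HandleCompletion dispatches a code-completion request to the streaming or
// non-streaming handler based on the request's stream flag.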
func (p *LLMProxy) HandleCompletion(ctx context.Context, w http.ResponseWriter, req domain.CompletionRequest) {
	if req.Stream {
		p.handleCompletionStream(ctx, w, req)
	} else {
		p.handleCompletion(ctx, w, req)
	}
}
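
// handleCompletionStream proxies a streaming /completions request. The
// upstream SSE stream is relayed to the client unchanged, while a copy of
// each event is fed through a channel to a goroutine that aggregates the
// completion text and token usage for usage recording.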
func (p *LLMProxy) handleCompletionStream(ctx context.Context, w http.ResponseWriter, req domain.CompletionRequest) {
	endpoint := "/completions"
	p.handle(ctx, func(c *Ctx, log *RequestResponseLog) error {
		// Pick a model with the load-balancing algorithm.
		m, err := p.usecase.SelectModelWithLoadBalancing(req.Model, consts.ModelTypeCoder)
		if err != nil {
			p.logger.With("modelName", req.Model, "modelType", consts.ModelTypeCoder).WarnContext(ctx, "model selection failed", "error", err)
			writeErrResp(w, http.StatusNotFound, "model not found", "proxy_error")
			return err
		}
		// Build the upstream API URL.
		upstream := m.APIBase + endpoint
		log.UpstreamURL = upstream
		startTime := time.Now()
		// Build the upstream request.
		body, err := json.Marshal(req)
		if err != nil {
			p.logger.ErrorContext(ctx, "failed to serialize request body", "error", err)
			return fmt.Errorf("failed to serialize request body: %w", err)
		}
		upreq, err := http.NewRequestWithContext(ctx, http.MethodPost, upstream, bytes.NewReader(body))
		if err != nil {
			p.logger.With("upstream", upstream).WarnContext(ctx, "failed to create upstream streaming request", "error", err)
			return fmt.Errorf("failed to create upstream request: %w", err)
		}
		// Set request headers.
		upreq.Header.Set("Content-Type", "application/json")
		upreq.Header.Set("Accept", "text/event-stream")
		if m.APIKey != "" && m.APIKey != "none" {
			upreq.Header.Set("Authorization", "Bearer "+m.APIKey)
		}
		// Capture request headers with credentials redacted.
		requestHeaders := make(map[string][]string)
		for k, v := range upreq.Header {
			if k != "Authorization" {
				requestHeaders[k] = v
			} else {
				// Redact the bearer token.
				requestHeaders[k] = []string{"Bearer ***"}
			}
		}
		log.RequestHeader = requestHeaders
		p.logger.With(
			"upstreamURL", upstream,
			"modelName", m.ModelName,
			"modelType", consts.ModelTypeCoder,
			"apiBase", m.APIBase,
			"requestHeader", upreq.Header,
			"requestBody", req,
		).DebugContext(ctx, "forwarding streaming request to upstream API")
		// Send the request.
		resp, err := p.client.Do(upreq)
		if err != nil {
			p.logger.With("upstreamURL", upstream).WarnContext(ctx, "failed to send upstream streaming request", "error", err)
			return fmt.Errorf("failed to send upstream request: %w", err)
		}
		defer resp.Body.Close()
		// Check the response status.
		if resp.StatusCode != http.StatusOK {
			responseBody, _ := io.ReadAll(resp.Body)
			// Record the error in the request log.
			log.StatusCode = resp.StatusCode
			log.ResponseHeader = resp.Header
			log.ResponseBody = string(responseBody)
			log.Latency = time.Since(startTime).Milliseconds()
			// Log the raw error body at debug level.
			p.logger.With(
				"statusCode", resp.StatusCode,
				"responseHeader", resp.Header,
				"responseBody", string(responseBody),
			).DebugContext(ctx, "raw upstream streaming error response")
			var errorResp ErrResp
			if err := json.Unmarshal(responseBody, &errorResp); err == nil {
				p.logger.With(
					"endpoint", endpoint,
					"upstreamURL", upstream,
					"requestBody", req,
					"statusCode", resp.StatusCode,
					"errorType", errorResp.Error.Type,
					"errorCode", errorResp.Error.Code,
					"errorMessage", errorResp.Error.Message,
					"latency", time.Since(startTime),
				).WarnContext(ctx, "upstream API streaming request failed")
				return fmt.Errorf("upstream API returned an error: %s", errorResp.Error.Message)
			}
			p.logger.With(
				"endpoint", endpoint,
				"upstreamURL", upstream,
				"requestBody", req,
				"statusCode", resp.StatusCode,
				"responseBody", string(responseBody),
			).WarnContext(ctx, "upstream API streaming request failed")
			return fmt.Errorf("upstream API returned non-200 status code: %d, response: %s", resp.StatusCode, string(responseBody))
		}
		// Record the response metadata.
		log.StatusCode = resp.StatusCode
		log.ResponseHeader = resp.Header
		// Log the streaming response headers at debug level.
		p.logger.With(
			"statusCode", resp.StatusCode,
			"responseHeader", resp.Header,
		).DebugContext(ctx, "upstream streaming response headers")
		// Set SSE response headers.
		w.Header().Set("Content-Type", "text/event-stream")
		w.Header().Set("Cache-Control", "no-cache")
		w.Header().Set("Connection", "keep-alive")
		w.Header().Set("Transfer-Encoding", "chunked")
		// Prompt may be a string or a token array; only strings are recorded.
		prompt, _ := req.Prompt.(string)
		rc := &domain.RecordParam{
			UserID:    c.UserID,
			ModelID:   m.ID,
			ModelType: consts.ModelTypeCoder,
			Prompt:    prompt,
			Role:      consts.ChatRoleAssistant,
		}
		buf := bufio.NewWriterSize(w, 32*1024)
		defer buf.Flush()
		ch := make(chan []byte, 1024)
		defer close(ch)
		go func(rc *domain.RecordParam) {
			// Aggregate completion text and token usage from the SSE copies.
			for line := range ch {
				if bytes.HasPrefix(line, []byte("data:")) {
					line = bytes.TrimPrefix(line, []byte("data: "))
					line = bytes.TrimSpace(line)
					if len(line) == 0 {
						continue
					}
					if bytes.Equal(line, []byte("[DONE]")) {
						break
					}
					var t openai.CompletionResponse
					if err := json.Unmarshal(line, &t); err != nil {
						p.logger.With("line", string(line)).WarnContext(ctx, "failed to parse streaming data", "error", err)
						continue
					}
					p.logger.With("response", t).DebugContext(ctx, "streaming response data")
					if len(t.Choices) > 0 {
						rc.Completion += t.Choices[0].Text
					}
					rc.InputTokens = int64(t.Usage.PromptTokens)
					rc.OutputTokens += int64(t.Usage.CompletionTokens)
				}
			}
			if rc.OutputTokens == 0 {
				return
			}
			p.logger.With("record", rc).DebugContext(ctx, "streaming record")
			if err := p.usecase.Record(context.Background(), rc); err != nil {
				p.logger.With("modelID", m.ID, "modelName", m.ModelName, "modelType", consts.ModelTypeCoder).
					WarnContext(ctx, "failed to insert streaming record", "error", err)
			}
		}(rc)
		// Relay the stream: each SSE event goes to both the recorder and the client.
		err = streamRead(ctx, resp.Body, func(line []byte) error {
			ch <- line
			if _, err := buf.Write(line); err != nil {
				return fmt.Errorf("failed to write response: %w", err)
			}
			return buf.Flush()
		})
		return err
	})
}
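
// handleCompletion proxies a non-streaming /completions request and records
// the completion usage asynchronously.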
func (p *LLMProxy) handleCompletion(ctx context.Context, w http.ResponseWriter, req domain.CompletionRequest) {
	endpoint := "/completions"
	p.handle(ctx, func(c *Ctx, log *RequestResponseLog) error {
		// Pick a model with the load-balancing algorithm.
		m, err := p.usecase.SelectModelWithLoadBalancing(req.Model, consts.ModelTypeCoder)
		if err != nil {
			p.logger.With("modelName", req.Model, "modelType", consts.ModelTypeCoder).WarnContext(ctx, "model selection failed", "error", err)
			writeErrResp(w, http.StatusNotFound, "model not found", "proxy_error")
			return err
		}
		// Build the upstream API URL.
		upstream := m.APIBase + endpoint
		log.UpstreamURL = upstream
		u, err := url.Parse(upstream)
		if err != nil {
			p.logger.With("upstreamURL", upstream).WarnContext(ctx, "failed to parse upstream URL", "error", err)
			writeErrResp(w, http.StatusInternalServerError, "invalid upstream URL", "proxy_error")
			return err
		}
		startTime := time.Now()
		client := request.NewClient(u.Scheme, u.Host, 30*time.Second)
		client.SetDebug(p.cfg.Debug)
		resp, err := request.Post[openai.CompletionResponse](client, u.Path, req, request.WithHeader(request.Header{
			"Authorization": "Bearer " + m.APIKey,
		}))
		if err != nil {
			p.logger.With("upstreamURL", upstream).WarnContext(ctx, "failed to send upstream request", "error", err)
			writeErrResp(w, http.StatusInternalServerError, "failed to send upstream request", "proxy_error")
			log.Latency = time.Since(startTime).Milliseconds()
			return err
		}
		latency := time.Since(startTime)
		// Log the forwarded request.
		p.logger.With(
			"statusCode", http.StatusOK,
			"upstreamURL", upstream,
			"modelName", req.Model,
			"modelType", consts.ModelTypeCoder,
			"apiBase", m.APIBase,
			"requestBody", req,
			"resp", resp,
			"latency", latency.String(),
		).DebugContext(ctx, "forwarded request to upstream API")
		go p.recordCompletion(c, m.ID, req, resp)
		// Update the request log.
		log.StatusCode = http.StatusOK
		log.ResponseBody = resp
		log.ResponseHeader = resp.Header()
		log.Latency = latency.Milliseconds()
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		b, err := json.Marshal(resp)
		if err != nil {
			p.logger.With("response", resp).WarnContext(ctx, "failed to serialize response", "error", err)
			return err
		}
		w.Write(b)
		return nil
	})
}
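
// recordCompletion records token usage and generated code lines for a
// non-streaming completion response.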
func (p *LLMProxy) recordCompletion(c *Ctx, modelID string, req domain.CompletionRequest, resp *openai.CompletionResponse) {
	if resp.Usage.CompletionTokens == 0 {
		return
	}
	ctx := context.Background()
	// Prompt may be a string or a token array; only strings are recorded.
	prompt, _ := req.Prompt.(string)
	rc := &domain.RecordParam{
		TaskID:          resp.ID,
		UserID:          c.UserID,
		ModelID:         modelID,
		ModelType:       consts.ModelTypeCoder,
		Prompt:          prompt,
		ProgramLanguage: req.Metadata["program_language"],
		InputTokens:     int64(resp.Usage.PromptTokens),
		OutputTokens:    int64(resp.Usage.CompletionTokens),
	}
	for _, choice := range resp.Choices {
		rc.Completion += choice.Text
	}
	lines := strings.Count(rc.Completion, "\n") + 1
	rc.CodeLines = int64(lines)
	if err := p.usecase.Record(ctx, rc); err != nil {
		p.logger.With("modelID", modelID, "modelName", req.Model, "modelType", consts.ModelTypeCoder).WarnContext(ctx, "failed to record request", "error", err)
	}
}
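
// HandleChatCompletion dispatches a chat-completion request to the streaming
// or non-streaming handler based on the request's stream flag.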
func (p *LLMProxy) HandleChatCompletion(ctx context.Context, w http.ResponseWriter, req *openai.ChatCompletionRequest) {
	if req.Stream {
		p.handleChatCompletionStream(ctx, w, req)
	} else {
		p.handleChatCompletion(ctx, w, req)
	}
}
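
// streamRead reads an SSE stream line by line, skips blank separator lines,
// and passes each event (e.g. "data: {...}") to fn with the "\n\n" event
// terminator restored. It returns nil on EOF, and an error if the context is
// canceled or a read or callback fails.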
func streamRead(ctx context.Context, r io.Reader, fn func([]byte) error) error {
	reader := bufio.NewReaderSize(r, 32*1024)
	for {
		select {
		case <-ctx.Done():
			return fmt.Errorf("streaming request canceled: %w", ctx.Err())
		default:
			line, err := reader.ReadBytes('\n')
			if err == io.EOF {
				return nil
			}
			if err != nil {
				return fmt.Errorf("failed to read streaming data: %w", err)
			}
			line = bytes.TrimSpace(line)
			if len(line) == 0 {
				continue
			}
			// Restore the SSE event separator.
			line = append(line, '\n')
			line = append(line, '\n')
			if err := fn(line); err != nil {
				return err
			}
		}
	}
}
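
// getPrompt extracts the user prompt from non-system chat messages whose
// content carries <task>, <feedback>, or <user_message> markers, using the
// task prompt parser. The last matching message wins.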
func (p *LLMProxy) getPrompt(ctx context.Context, req *openai.ChatCompletionRequest) string {
	prompt := ""
	parse := promptparser.New(promptparser.KindTask)
	for _, message := range req.Messages {
		if message.Role == "system" {
			continue
		}
		if strings.Contains(message.Content, "<task>") ||
			strings.Contains(message.Content, "<feedback>") ||
			strings.Contains(message.Content, "<user_message>") {
			if info, err := parse.Parse(message.Content); err == nil {
				prompt = info.Prompt
			} else {
				p.logger.With("message", message.Content).WarnContext(ctx, "failed to parse prompt", "error", err)
			}
		}
		for _, m := range message.MultiContent {
			if strings.Contains(m.Text, "<task>") ||
				strings.Contains(m.Text, "<feedback>") ||
				strings.Contains(m.Text, "<user_message>") {
				if info, err := parse.Parse(m.Text); err == nil {
					prompt = info.Prompt
				} else {
					p.logger.With("message", m.Text).WarnContext(ctx, "failed to parse prompt", "error", err)
				}
			}
		}
	}
	return prompt
}
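
// handleChatCompletionStream proxies a streaming /chat/completions request.
// Like handleCompletionStream, it relays the upstream SSE stream to the client
// while a recording goroutine aggregates the assistant reply and token usage;
// it additionally records the parsed user prompt as a separate user-role entry.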
func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.ResponseWriter, req *openai.ChatCompletionRequest) {
	endpoint := "/chat/completions"
	p.handle(ctx, func(c *Ctx, log *RequestResponseLog) error {
		startTime := time.Now()
		m, err := p.usecase.SelectModelWithLoadBalancing(req.Model, consts.ModelTypeLLM)
		if err != nil {
			p.logger.With("modelName", req.Model, "modelType", consts.ModelTypeLLM).WarnContext(ctx, "model selection for streaming request failed", "error", err)
			writeErrResp(w, http.StatusNotFound, "model not found", "proxy_error")
			return err
		}
		prompt := p.getPrompt(ctx, req)
		mode := req.Metadata["mode"]
		taskID := req.Metadata["task_id"]
		upstream := m.APIBase + endpoint
		log.UpstreamURL = upstream
		body, err := json.Marshal(req)
		if err != nil {
			p.logger.ErrorContext(ctx, "failed to serialize request body", "error", err)
			return fmt.Errorf("failed to serialize request body: %w", err)
		}
		newReq, err := http.NewRequestWithContext(ctx, http.MethodPost, upstream, bytes.NewReader(body))
		if err != nil {
			p.logger.With("upstream", upstream).WarnContext(ctx, "failed to create upstream streaming request", "error", err)
			return fmt.Errorf("failed to create upstream request: %w", err)
		}
		newReq.Header.Set("Content-Type", "application/json")
		newReq.Header.Set("Accept", "text/event-stream")
		if m.APIKey != "" && m.APIKey != "none" {
			newReq.Header.Set("Authorization", "Bearer "+m.APIKey)
		}
		// Capture request headers with credentials redacted.
		requestHeaders := make(map[string][]string)
		for k, v := range newReq.Header {
			if k != "Authorization" {
				requestHeaders[k] = v
			} else {
				// Redact the bearer token.
				requestHeaders[k] = []string{"Bearer ***"}
			}
		}
		log.RequestHeader = requestHeaders
		p.logger.With(
			"upstreamURL", upstream,
			"modelName", m.ModelName,
			"modelType", consts.ModelTypeLLM,
			"apiBase", m.APIBase,
			"work_mode", mode,
			"requestHeader", newReq.Header,
			"requestBody", req,
			"taskID", taskID,
			"messages", cvt.Filter(req.Messages, func(i int, v openai.ChatCompletionMessage) (openai.ChatCompletionMessage, bool) {
				return v, v.Role != "system"
			}),
		).DebugContext(ctx, "forwarding streaming request to upstream API")
		// Send the request.
		resp, err := p.client.Do(newReq)
		if err != nil {
			p.logger.With("upstreamURL", upstream).WarnContext(ctx, "failed to send upstream streaming request", "error", err)
			return fmt.Errorf("failed to send upstream request: %w", err)
		}
		defer resp.Body.Close()
		// Check the response status.
		if resp.StatusCode != http.StatusOK {
			responseBody, _ := io.ReadAll(resp.Body)
			// Record the error in the request log.
			log.StatusCode = resp.StatusCode
			log.ResponseHeader = resp.Header
			log.ResponseBody = string(responseBody)
			log.Latency = time.Since(startTime).Milliseconds()
			// Log the raw error body at debug level.
			p.logger.With(
				"statusCode", resp.StatusCode,
				"responseHeader", resp.Header,
				"responseBody", string(responseBody),
			).DebugContext(ctx, "raw upstream streaming error response")
			var errorResp ErrResp
			if err := json.Unmarshal(responseBody, &errorResp); err == nil {
				p.logger.With(
					"endpoint", endpoint,
					"upstreamURL", upstream,
					"requestBody", req,
					"statusCode", resp.StatusCode,
					"errorType", errorResp.Error.Type,
					"errorCode", errorResp.Error.Code,
					"errorMessage", errorResp.Error.Message,
					"latency", time.Since(startTime),
				).WarnContext(ctx, "upstream API streaming request failed")
				return fmt.Errorf("upstream API returned an error: %s", errorResp.Error.Message)
			}
			p.logger.With(
				"endpoint", endpoint,
				"upstreamURL", upstream,
				"requestBody", req,
				"statusCode", resp.StatusCode,
				"responseBody", string(responseBody),
			).WarnContext(ctx, "upstream API streaming request failed")
			return fmt.Errorf("upstream API returned non-200 status code: %d, response: %s", resp.StatusCode, string(responseBody))
		}
		// Record the response metadata.
		log.StatusCode = resp.StatusCode
		log.ResponseHeader = resp.Header
		// Log the streaming response headers at debug level.
		p.logger.With(
			"statusCode", resp.StatusCode,
			"responseHeader", resp.Header,
		).DebugContext(ctx, "upstream streaming response headers")
		// Set SSE response headers.
		w.Header().Set("Content-Type", "text/event-stream")
		w.Header().Set("Cache-Control", "no-cache")
		w.Header().Set("Connection", "keep-alive")
		w.Header().Set("Transfer-Encoding", "chunked")
		rc := &domain.RecordParam{
			RequestID: c.RequestID,
			TaskID:    taskID,
			UserID:    c.UserID,
			ModelID:   m.ID,
			ModelType: consts.ModelTypeLLM,
			WorkMode:  mode,
			Prompt:    prompt,
			Role:      consts.ChatRoleAssistant,
		}
		buf := bufio.NewWriterSize(w, 32*1024)
		defer buf.Flush()
		ch := make(chan []byte, 1024)
		defer close(ch)
		go func(rc *domain.RecordParam) {
			// Record the user prompt as its own entry first.
			if rc.Prompt != "" {
				urc := rc.Clone()
				urc.Role = consts.ChatRoleUser
				urc.Completion = urc.Prompt
				if err := p.usecase.Record(context.Background(), urc); err != nil {
					p.logger.With("modelID", m.ID, "modelName", m.ModelName, "modelType", consts.ModelTypeLLM).
						WarnContext(ctx, "failed to insert streaming record", "error", err)
				}
			}
			// Aggregate the assistant reply and token usage from the SSE copies.
			for line := range ch {
				if bytes.HasPrefix(line, []byte("data:")) {
					line = bytes.TrimPrefix(line, []byte("data: "))
					line = bytes.TrimSpace(line)
					if len(line) == 0 {
						continue
					}
					if bytes.Equal(line, []byte("[DONE]")) {
						break
					}
					var t openai.ChatCompletionStreamResponse
					if err := json.Unmarshal(line, &t); err != nil {
						p.logger.With("line", string(line)).WarnContext(ctx, "failed to parse streaming data", "error", err)
						continue
					}
					p.logger.With("response", t).DebugContext(ctx, "streaming response data")
					if len(t.Choices) > 0 {
						rc.Completion += t.Choices[0].Delta.Content
					}
					if t.Usage != nil {
						rc.InputTokens = int64(t.Usage.PromptTokens)
						rc.OutputTokens = int64(t.Usage.CompletionTokens)
					}
				}
			}
			p.logger.With("record", rc).DebugContext(ctx, "streaming record")
			if err := p.usecase.Record(context.Background(), rc); err != nil {
				p.logger.With("modelID", m.ID, "modelName", m.ModelName, "modelType", consts.ModelTypeLLM).
					WarnContext(ctx, "failed to insert streaming record", "error", err)
			}
		}(rc)
		// Relay the stream: each SSE event goes to both the recorder and the client.
		err = streamRead(ctx, resp.Body, func(line []byte) error {
			ch <- line
			if _, err := buf.Write(line); err != nil {
				return fmt.Errorf("failed to write response: %w", err)
			}
			return buf.Flush()
		})
		return err
	})
}
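
// handleChatCompletion proxies a non-streaming /chat/completions request and
// records the chat usage asynchronously.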
func (p *LLMProxy) handleChatCompletion(ctx context.Context, w http.ResponseWriter, req *openai.ChatCompletionRequest) {
	endpoint := "/chat/completions"
	p.handle(ctx, func(c *Ctx, log *RequestResponseLog) error {
		// Pick a model with the load-balancing algorithm.
		m, err := p.usecase.SelectModelWithLoadBalancing(req.Model, consts.ModelTypeLLM)
		if err != nil {
			p.logger.With("modelName", req.Model, "modelType", consts.ModelTypeLLM).WarnContext(ctx, "model selection failed", "error", err)
			writeErrResp(w, http.StatusNotFound, "model not found", "proxy_error")
			return err
		}
		// Build the upstream API URL.
		upstream := m.APIBase + endpoint
		log.UpstreamURL = upstream
		u, err := url.Parse(upstream)
		if err != nil {
			p.logger.With("upstreamURL", upstream).WarnContext(ctx, "failed to parse upstream URL", "error", err)
			writeErrResp(w, http.StatusInternalServerError, "invalid upstream URL", "proxy_error")
			return err
		}
		startTime := time.Now()
		prompt := p.getPrompt(ctx, req)
		mode := req.Metadata["mode"]
		taskID := req.Metadata["task_id"]
		client := request.NewClient(u.Scheme, u.Host, 30*time.Second)
		resp, err := request.Post[openai.ChatCompletionResponse](client, u.Path, req, request.WithHeader(request.Header{
			"Authorization": "Bearer " + m.APIKey,
		}))
		if err != nil {
			p.logger.With("upstreamURL", upstream).WarnContext(ctx, "failed to send upstream request", "error", err)
			writeErrResp(w, http.StatusInternalServerError, "failed to send upstream request", "proxy_error")
			log.Latency = time.Since(startTime).Milliseconds()
			return err
		}
		// Log the forwarded request.
		p.logger.With(
			"upstreamURL", upstream,
			"modelName", req.Model,
			"modelType", consts.ModelTypeLLM,
			"apiBase", m.APIBase,
			"requestBody", req,
		).DebugContext(ctx, "forwarded request to upstream API")
		go func() {
			rc := &domain.RecordParam{
				TaskID:          taskID,
				UserID:          c.UserID,
				Prompt:          prompt,
				WorkMode:        mode,
				ModelID:         m.ID,
				ModelType:       m.ModelType,
				InputTokens:     int64(resp.Usage.PromptTokens),
				OutputTokens:    int64(resp.Usage.CompletionTokens),
				ProgramLanguage: req.Metadata["program_language"],
			}
			for _, choice := range resp.Choices {
				rc.Completion += choice.Message.Content + "\n\n"
			}
			p.logger.With("record", rc).DebugContext(ctx, "record")
			// Use a background context: the request context may be canceled
			// before this goroutine finishes.
			if err := p.usecase.Record(context.Background(), rc); err != nil {
				p.logger.With("modelID", m.ID, "modelName", req.Model, "modelType", consts.ModelTypeLLM).WarnContext(ctx, "failed to record request", "error", err)
			}
		}()
		// Compute the request latency.
		latency := time.Since(startTime)
		// Update the request log.
		log.StatusCode = http.StatusOK
		log.ResponseBody = resp
		log.ResponseHeader = resp.Header()
		log.Latency = latency.Milliseconds()
		// Log the upstream response.
		p.logger.With(
			"statusCode", http.StatusOK,
			"responseHeader", resp.Header(),
			"responseBody", resp,
			"latency", latency.String(),
		).DebugContext(ctx, "upstream API response")
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		b, err := json.Marshal(resp)
		if err != nil {
			p.logger.With("response", resp).WarnContext(ctx, "failed to serialize response", "error", err)
			return err
		}
		w.Write(b)
		return nil
	})
}
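
// HandleEmbeddings handles /embeddings requests; it is currently a no-op stub.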
func (p *LLMProxy) HandleEmbeddings(ctx context.Context, w http.ResponseWriter, req *openai.EmbeddingRequest) {
}