Files
MonkeyCode/backend/internal/proxy/proxy.go
2025-08-29 18:29:50 +08:00

198 lines
5.4 KiB
Go

package proxy
import (
"bytes"
"context"
"encoding/json"
"io"
"log/slog"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"
"github.com/chaitin/MonkeyCode/backend/config"
"github.com/chaitin/MonkeyCode/backend/consts"
"github.com/chaitin/MonkeyCode/backend/domain"
"github.com/chaitin/MonkeyCode/backend/internal/middleware"
"github.com/chaitin/MonkeyCode/backend/pkg/logger"
"github.com/chaitin/MonkeyCode/backend/pkg/tee"
)
// CtxKey is the context key under which rewrite stores the per-request
// *ProxyCtx on the outbound request; modifyResponse looks it up with the
// same key.
type CtxKey struct{}
// ProxyCtx carries per-request state from rewrite to modifyResponse, which
// hands it to NewRecorder together with the upstream response body.
type ProxyCtx struct {
// ctx is set to the inbound request context in rewrite and replaced with
// the upstream request's context in modifyResponse.
ctx context.Context
// Path is the rewritten upstream path (API base path + inbound path
// with "/v1" removed).
Path string
// Model is the model the request is forwarded to.
Model *domain.Model
// Header is the inbound request's header.
Header http.Header
// RespHeader is the upstream response header; filled in by modifyResponse.
RespHeader http.Header
// ReqTee wraps the request body sent upstream. It is created with a
// 10*1024*1024 size argument — presumably a capture limit in bytes;
// confirm against pkg/tee.
ReqTee *tee.ReqTee
// RequestID is read from the request context under logger.RequestIDKey{}.
RequestID string
// UserID is read from the request context under logger.UserIDKey{}.
UserID string
// Metadata holds the "metadata" object taken out of the request body.
// It is only populated for the ZhiPu provider; otherwise it is empty.
Metadata map[string]string
}
// LLMProxy is an http.Handler that forwards requests to a model's API base
// through an httputil.ReverseProxy.
type LLMProxy struct {
logger *slog.Logger
// cfg is passed through to NewRecorder in modifyResponse.
cfg *config.Config
// usecase selects a model when none is attached to the request context,
// and is passed through to NewRecorder.
usecase domain.ProxyUsecase
// transport is the HTTP transport used by proxy; Close closes its idle
// connections.
transport *http.Transport
// proxy is wired to the rewrite, modifyResponse and errorHandler methods.
proxy *httputil.ReverseProxy
}
// NewLLMProxy builds an LLMProxy whose reverse proxy shares a single pooled
// HTTP transport.
//
// The transport's idle/per-host connection limits all come from
// cfg.LLMProxy.ClientPoolSize, and outbound connections honour the proxy
// settings found in the environment. The reverse proxy is wired to the
// LLMProxy's own rewrite, modifyResponse and errorHandler methods and
// flushes buffered response data every 100ms.
func NewLLMProxy(
	logger *slog.Logger,
	cfg *config.Config,
	usecase domain.ProxyUsecase,
) *LLMProxy {
	poolSize := cfg.LLMProxy.ClientPoolSize

	dialer := &net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 24 * time.Hour,
	}
	transport := &http.Transport{
		Proxy:               http.ProxyFromEnvironment,
		DialContext:         dialer.DialContext,
		MaxIdleConns:        poolSize,
		MaxIdleConnsPerHost: poolSize,
		MaxConnsPerHost:     poolSize,
		IdleConnTimeout:     24 * time.Hour,
	}

	p := &LLMProxy{
		logger:    logger,
		cfg:       cfg,
		usecase:   usecase,
		transport: transport,
	}
	// The hooks are method values bound to p, so p must exist before the
	// reverse proxy is constructed.
	p.proxy = &httputil.ReverseProxy{
		Transport:      transport,
		Rewrite:        p.rewrite,
		ModifyResponse: p.modifyResponse,
		ErrorHandler:   p.errorHandler,
		FlushInterval:  100 * time.Millisecond,
	}
	return p
}
// ServeHTTP implements http.Handler by delegating to the underlying reverse
// proxy.
func (l *LLMProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
l.proxy.ServeHTTP(w, r)
}
// Close closes the transport's idle connections. It always returns nil.
func (l *LLMProxy) Close() error {
l.transport.CloseIdleConnections()
return nil
}
// modelType maps an inbound request path to the model type used when
// selecting a model. rewrite logs an error and returns early for any path
// that is not listed here.
var modelType = map[string]consts.ModelType{
"/v1/chat/completions": consts.ModelTypeLLM,
"/v1/completions": consts.ModelTypeCoder,
}
// rewrite is the ReverseProxy.Rewrite hook. It picks the upstream model for
// the inbound path, points the outbound request at that model's API base,
// replaces the Authorization header with the model's API key, and (for
// requests with a known positive Content-Length) attaches a ProxyCtx to the
// outbound request so modifyResponse can record the exchange.
//
// Rewrite has no way to return an error, so every failure is logged and the
// function returns early without having set an upstream target on r.Out.
func (l *LLMProxy) rewrite(r *httputil.ProxyRequest) {
	ctx := r.In.Context()
	inPath := r.In.URL.Path
	l.logger.DebugContext(ctx, "rewrite request", slog.String("path", inPath))

	mt, ok := modelType[inPath]
	if !ok {
		l.logger.ErrorContext(ctx, "model type not found", slog.String("path", inPath))
		return
	}

	// Prefer a model pinned on the request context; otherwise let the usecase
	// choose one for this model type.
	m := middleware.GetProxyModel(ctx)
	if m == nil {
		var err error
		m, err = l.usecase.SelectModelWithLoadBalancing("", mt)
		// Guard against a (nil, nil) result as well: m is dereferenced below.
		if err != nil || m == nil {
			l.logger.ErrorContext(ctx, "select model with load balancing failed", slog.String("path", inPath), slog.Any("err", err))
			return
		}
	}

	ul, err := url.Parse(m.APIBase)
	if err != nil {
		l.logger.ErrorContext(ctx, "parse model api base failed", slog.String("path", inPath), slog.Any("err", err))
		return
	}

	metadata := make(map[string]string)
	if m.Provider == consts.ModelProviderZhiPu {
		// For ZhiPu the top-level "metadata" object is removed from the JSON
		// body before forwarding and kept locally in ProxyCtx.Metadata.
		body, err := io.ReadAll(r.In.Body)
		if err != nil {
			l.logger.ErrorContext(ctx, "read request body failed", slog.String("path", inPath), slog.Any("err", err))
			return
		}
		req := make(map[string]any)
		if err = json.Unmarshal(body, &req); err != nil {
			l.logger.ErrorContext(ctx, "unmarshal request body failed", slog.String("path", inPath), slog.Any("err", err))
			return
		}
		if md, ok := req["metadata"].(map[string]any); ok {
			for k, v := range md {
				// The body is client-controlled JSON, so a value may be a
				// number, bool, object, ...; an unchecked assertion would
				// panic. Non-string values are skipped.
				s, isString := v.(string)
				if !isString {
					l.logger.WarnContext(ctx, "skip non-string metadata value", slog.String("path", inPath), slog.String("key", k))
					continue
				}
				metadata[k] = s
			}
		}
		delete(req, "metadata")
		body, err = json.Marshal(req)
		if err != nil {
			l.logger.ErrorContext(ctx, "marshal request body failed", slog.String("path", inPath), slog.Any("err", err))
			return
		}
		// NOTE(review): the httputil docs say Rewrite must not modify In. The
		// inbound body has already been consumed above, so In is re-pointed at
		// the rebuilt body to keep the existing behaviour; confirm nothing
		// downstream depends on this before removing it.
		r.In.Body = io.NopCloser(bytes.NewReader(body))
		r.In.ContentLength = int64(len(body))
		r.Out.Body = io.NopCloser(bytes.NewReader(body))
		r.Out.ContentLength = int64(len(body))
	}

	// Only the two paths in modelType get this far, so stripping the "/v1"
	// prefix is all that is needed. A trailing slash on the API base path is
	// dropped so the join never produces "//".
	path := strings.TrimSuffix(ul.Path, "/") + strings.TrimPrefix(inPath, "/v1")

	if r.In.ContentLength > 0 {
		reqTee := tee.NewReqTeeWithMaxSize(r.In.Body, 10*1024*1024)
		r.Out.Body = reqTee

		// The IDs are optional context values; fall back to "" rather than
		// panicking when they are absent or of another type.
		requestID, _ := ctx.Value(logger.RequestIDKey{}).(string)
		userID, _ := ctx.Value(logger.UserIDKey{}).(string)

		outCtx := context.WithValue(ctx, CtxKey{}, &ProxyCtx{
			ctx:       ctx,
			Path:      path,
			Model:     m,
			ReqTee:    reqTee,
			RequestID: requestID,
			UserID:    userID,
			Header:    r.In.Header,
			Metadata:  metadata,
		})
		r.Out = r.Out.WithContext(outCtx)
	}

	r.Out.URL.Scheme = ul.Scheme
	r.Out.URL.Host = ul.Host
	r.Out.URL.Path = path
	r.Out.Header.Set("Authorization", "Bearer "+m.APIKey)
	r.SetXForwarded()
	// Set after SetXForwarded, as in the original ordering, so the upstream
	// sees its own host name in the Host header.
	r.Out.Host = ul.Host

	l.logger.With(
		"in", r.In.URL.Path,
		"out", r.Out.URL.Path,
		"metadata", metadata,
	).DebugContext(ctx, "rewrite request")
}
// modifyResponse is the ReverseProxy.ModifyResponse hook.
//
// For a non-200 upstream response it reads the whole body, logs it together
// with the status code, and hands the client a replayable copy. For a 200
// response that carries a ProxyCtx it wraps the body in a recorder created by
// NewRecorder. It always returns nil, so the upstream response is forwarded
// to the client in either case.
func (l *LLMProxy) modifyResponse(resp *http.Response) error {
	ctx := resp.Request.Context()

	if resp.StatusCode != http.StatusOK {
		body, err := io.ReadAll(resp.Body)
		// The original body is replaced below, so nothing else will close it;
		// close it here. The close error is not actionable at this point.
		_ = resp.Body.Close()
		if err != nil {
			// Best effort: keep forwarding whatever was read, but do not hide
			// the fact that the upstream body was truncated.
			l.logger.ErrorContext(ctx, "read upstream response body failed", slog.Any("err", err))
		}
		l.logger.ErrorContext(ctx, "modify response failed",
			slog.Int("status", resp.StatusCode),
			slog.String("body", string(body)),
		)
		resp.Body = io.NopCloser(bytes.NewReader(body))
		return nil
	}

	if pctx, ok := ctx.Value(CtxKey{}).(*ProxyCtx); ok {
		pctx.ctx = ctx
		pctx.RespHeader = resp.Header
		resp.Body = NewRecorder(l.cfg, pctx, resp.Body, l.logger, l.usecase)
	}
	return nil
}
// errorHandler is the ReverseProxy.ErrorHandler hook. It logs the failure and
// answers the client with 502 Bad Gateway.
//
// Installing a custom ErrorHandler replaces the library default (log + 502).
// A handler that only logs writes nothing to w, so the HTTP server would
// finish the response as an implicit 200 OK with an empty body; the status is
// therefore written explicitly here.
func (l *LLMProxy) errorHandler(w http.ResponseWriter, r *http.Request, err error) {
	l.logger.ErrorContext(r.Context(), "error handler", slog.String("path", r.URL.Path), slog.Any("err", err))
	w.WriteHeader(http.StatusBadGateway)
}