mirror of
https://github.com/chaitin/MonkeyCode.git
synced 2026-02-01 22:33:30 +08:00
198 lines
5.4 KiB
Go
198 lines
5.4 KiB
Go
package proxy
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"io"
|
|
"log/slog"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/chaitin/MonkeyCode/backend/config"
|
|
"github.com/chaitin/MonkeyCode/backend/consts"
|
|
"github.com/chaitin/MonkeyCode/backend/domain"
|
|
"github.com/chaitin/MonkeyCode/backend/internal/middleware"
|
|
"github.com/chaitin/MonkeyCode/backend/pkg/logger"
|
|
"github.com/chaitin/MonkeyCode/backend/pkg/tee"
|
|
)
|
|
|
|
// CtxKey is the context key under which rewrite stores the *ProxyCtx on the
// outgoing request's context; modifyResponse reads it back with the same key.
type CtxKey struct{}
|
|
|
|
type ProxyCtx struct {
|
|
ctx context.Context
|
|
Path string
|
|
Model *domain.Model
|
|
Header http.Header
|
|
RespHeader http.Header
|
|
ReqTee *tee.ReqTee
|
|
RequestID string
|
|
UserID string
|
|
Metadata map[string]string
|
|
}
|
|
|
|
// LLMProxy is an http.Handler that forwards the request paths listed in
// modelType to an upstream model endpoint through an httputil.ReverseProxy.
// Construct it with NewLLMProxy; call Close to release idle upstream
// connections.
type LLMProxy struct {
	logger *slog.Logger
	cfg    *config.Config
	// usecase selects upstream models and is handed to the response recorder.
	usecase domain.ProxyUsecase
	// transport is the upstream connection pool used by proxy.
	transport *http.Transport
	proxy     *httputil.ReverseProxy
}
|
|
|
|
func NewLLMProxy(
|
|
logger *slog.Logger,
|
|
cfg *config.Config,
|
|
usecase domain.ProxyUsecase,
|
|
) *LLMProxy {
|
|
l := &LLMProxy{
|
|
logger: logger,
|
|
cfg: cfg,
|
|
usecase: usecase,
|
|
}
|
|
|
|
l.transport = &http.Transport{
|
|
MaxIdleConns: cfg.LLMProxy.ClientPoolSize,
|
|
MaxConnsPerHost: cfg.LLMProxy.ClientPoolSize,
|
|
MaxIdleConnsPerHost: cfg.LLMProxy.ClientPoolSize,
|
|
IdleConnTimeout: 24 * time.Hour,
|
|
Proxy: http.ProxyFromEnvironment,
|
|
DialContext: (&net.Dialer{
|
|
Timeout: 30 * time.Second,
|
|
KeepAlive: 24 * time.Hour,
|
|
}).DialContext,
|
|
}
|
|
l.proxy = &httputil.ReverseProxy{
|
|
Transport: l.transport,
|
|
Rewrite: l.rewrite,
|
|
ModifyResponse: l.modifyResponse,
|
|
ErrorHandler: l.errorHandler,
|
|
FlushInterval: 100 * time.Millisecond,
|
|
}
|
|
return l
|
|
}
|
|
|
|
func (l *LLMProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
l.proxy.ServeHTTP(w, r)
|
|
}
|
|
|
|
func (l *LLMProxy) Close() error {
|
|
l.transport.CloseIdleConnections()
|
|
return nil
|
|
}
|
|
|
|
var modelType = map[string]consts.ModelType{
|
|
"/v1/chat/completions": consts.ModelTypeLLM,
|
|
"/v1/completions": consts.ModelTypeCoder,
|
|
}
|
|
|
|
func (l *LLMProxy) rewrite(r *httputil.ProxyRequest) {
|
|
l.logger.DebugContext(r.In.Context(), "rewrite request", slog.String("path", r.In.URL.Path))
|
|
mt, ok := modelType[r.In.URL.Path]
|
|
if !ok {
|
|
l.logger.ErrorContext(r.In.Context(), "model type not found", slog.String("path", r.In.URL.Path))
|
|
return
|
|
}
|
|
|
|
var m *domain.Model
|
|
var err error
|
|
if m = middleware.GetProxyModel(r.In.Context()); m == nil {
|
|
m, err = l.usecase.SelectModelWithLoadBalancing("", mt)
|
|
if err != nil {
|
|
l.logger.ErrorContext(r.In.Context(), "select model with load balancing failed", slog.String("path", r.In.URL.Path), slog.Any("err", err))
|
|
return
|
|
}
|
|
}
|
|
|
|
ul, err := url.Parse(m.APIBase)
|
|
if err != nil {
|
|
l.logger.ErrorContext(r.In.Context(), "parse model api base failed", slog.String("path", r.In.URL.Path), slog.Any("err", err))
|
|
return
|
|
}
|
|
|
|
metadata := make(map[string]string)
|
|
if m.Provider == consts.ModelProviderZhiPu {
|
|
body, err := io.ReadAll(r.In.Body)
|
|
if err != nil {
|
|
l.logger.ErrorContext(r.In.Context(), "read request body failed", slog.String("path", r.In.URL.Path), slog.Any("err", err))
|
|
return
|
|
}
|
|
req := make(map[string]any)
|
|
err = json.Unmarshal(body, &req)
|
|
if err != nil {
|
|
l.logger.ErrorContext(r.In.Context(), "unmarshal request body failed", slog.String("path", r.In.URL.Path), slog.Any("err", err))
|
|
return
|
|
}
|
|
if md, ok := req["metadata"].(map[string]any); ok {
|
|
for k, v := range md {
|
|
metadata[k] = v.(string)
|
|
}
|
|
}
|
|
delete(req, "metadata")
|
|
body, err = json.Marshal(req)
|
|
if err != nil {
|
|
l.logger.ErrorContext(r.In.Context(), "marshal request body failed", slog.String("path", r.In.URL.Path), slog.Any("err", err))
|
|
return
|
|
}
|
|
r.In.Body = io.NopCloser(bytes.NewBuffer(body))
|
|
r.In.ContentLength = int64(len(body))
|
|
r.Out.Body = io.NopCloser(bytes.NewBuffer(body))
|
|
r.Out.ContentLength = int64(len(body))
|
|
}
|
|
|
|
path := r.In.URL.Path
|
|
path = strings.ReplaceAll(path, "/v1", "")
|
|
path = ul.Path + path
|
|
if r.In.ContentLength > 0 {
|
|
tee := tee.NewReqTeeWithMaxSize(r.In.Body, 10*1024*1024)
|
|
r.Out.Body = tee
|
|
ctx := context.WithValue(r.In.Context(), CtxKey{}, &ProxyCtx{
|
|
ctx: r.In.Context(),
|
|
Path: path,
|
|
Model: m,
|
|
ReqTee: tee,
|
|
RequestID: r.In.Context().Value(logger.RequestIDKey{}).(string),
|
|
UserID: r.In.Context().Value(logger.UserIDKey{}).(string),
|
|
Header: r.In.Header,
|
|
Metadata: metadata,
|
|
})
|
|
r.Out = r.Out.WithContext(ctx)
|
|
}
|
|
|
|
r.Out.URL.Scheme = ul.Scheme
|
|
r.Out.URL.Host = ul.Host
|
|
r.Out.URL.Path = path
|
|
r.Out.Header.Set("Authorization", "Bearer "+m.APIKey)
|
|
r.SetXForwarded()
|
|
r.Out.Host = ul.Host
|
|
l.logger.With(
|
|
"in", r.In.URL.Path,
|
|
"out", r.Out.URL.Path,
|
|
"metadata", metadata,
|
|
).DebugContext(r.In.Context(), "rewrite request")
|
|
}
|
|
|
|
func (l *LLMProxy) modifyResponse(resp *http.Response) error {
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
l.logger.ErrorContext(resp.Request.Context(), "modify response failed", slog.String("body", string(body)))
|
|
resp.Body = io.NopCloser(bytes.NewBuffer(body))
|
|
return nil
|
|
}
|
|
|
|
ctx := resp.Request.Context()
|
|
if pctx, ok := ctx.Value(CtxKey{}).(*ProxyCtx); ok {
|
|
pctx.ctx = ctx
|
|
pctx.RespHeader = resp.Header
|
|
resp.Body = NewRecorder(l.cfg, pctx, resp.Body, l.logger, l.usecase)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (l *LLMProxy) errorHandler(w http.ResponseWriter, r *http.Request, err error) {
|
|
l.logger.ErrorContext(r.Context(), "error handler", slog.String("path", r.URL.Path), slog.Any("err", err))
|
|
}
|