From 1d61853118d3ad0c9613a4910442948c48a5dc77 Mon Sep 17 00:00:00 2001 From: yokowu <18836617@qq.com> Date: Thu, 24 Jul 2025 16:46:19 +0800 Subject: [PATCH] =?UTF-8?q?feat(tasks):=20=E4=BC=98=E5=8C=96=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=E4=B8=8A=E6=8A=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/cmd/server/wire_gen.go | 2 +- backend/consts/proxy.go | 11 +++-- backend/db/migrate/schema.go | 6 +-- backend/db/mutation.go | 19 ++++++++ backend/db/runtime/runtime.go | 8 ++++ backend/db/taskrecord/taskrecord.go | 4 ++ backend/db/taskrecord/where.go | 10 +++++ backend/db/taskrecord_create.go | 55 ++++++++++++++++++++++-- backend/db/taskrecord_update.go | 18 ++++++++ backend/domain/billing.go | 7 ++- backend/domain/proxy.go | 4 +- backend/ent/schema/taskrecord.go | 6 +-- backend/internal/billing/repo/billing.go | 1 - backend/internal/openai/handler/v1/v1.go | 1 + backend/internal/proxy/recorder.go | 13 +----- backend/internal/proxy/repo/proxy.go | 46 +++++++++++++++++++- backend/internal/proxy/usecase/proxy.go | 26 +++++++++-- 17 files changed, 203 insertions(+), 34 deletions(-) diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index df12a93..e3ecd13 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -61,7 +61,7 @@ func newServer() (*Server, error) { redisClient := store.NewRedisCli(configConfig) proxyRepo := repo.NewProxyRepo(client, redisClient) modelRepo := repo2.NewModelRepo(client) - proxyUsecase := usecase.NewProxyUsecase(proxyRepo, modelRepo) + proxyUsecase := usecase.NewProxyUsecase(proxyRepo, modelRepo, slogLogger) llmProxy := proxy.NewLLMProxy(slogLogger, configConfig, proxyUsecase) openAIRepo := repo3.NewOpenAIRepo(client) openAIUsecase := openai.NewOpenAIUsecase(configConfig, openAIRepo, slogLogger) diff --git a/backend/consts/proxy.go b/backend/consts/proxy.go index aa3dabe..02f6a92 100644 --- a/backend/consts/proxy.go +++ b/backend/consts/proxy.go @@ -3,8 +3,11 @@ package consts type ReportAction string const ( - ReportActionAccept ReportAction = "accept" - ReportActionSuggest ReportAction = "suggest" - ReportActionFileWritten ReportAction = "file_written" - ReportActionReject ReportAction = "reject" + ReportActionAccept ReportAction = "accept" + ReportActionSuggest ReportAction = "suggest" + ReportActionFileWritten ReportAction = "file_written" + ReportActionReject ReportAction = "reject" + ReportActionNewTask ReportAction = "new_task" + ReportActionFeedbackTask ReportAction = "feedback_task" + ReportActionAbortTask ReportAction = "abort_task" ) diff --git a/backend/db/migrate/schema.go b/backend/db/migrate/schema.go index b9e953d..c6fea4c 100644 --- a/backend/db/migrate/schema.go +++ b/backend/db/migrate/schema.go @@ -305,9 +305,9 @@ var ( {Name: "id", Type: field.TypeUUID, Unique: true}, {Name: "prompt", Type: field.TypeString, Nullable: true}, {Name: "role", Type: field.TypeString}, - {Name: "completion", Type: field.TypeString}, - {Name: "output_tokens", Type: field.TypeInt64}, - {Name: "code_lines", Type: field.TypeInt64}, + {Name: "completion", Type: field.TypeString, Nullable: true}, + {Name: "output_tokens", Type: field.TypeInt64, Default: 0}, + {Name: "code_lines", Type: field.TypeInt64, Default: 0}, {Name: "code", Type: field.TypeString, Nullable: true}, {Name: "created_at", Type: field.TypeTime}, {Name: "updated_at", Type: field.TypeTime}, diff --git a/backend/db/mutation.go b/backend/db/mutation.go index c451827..c78690e 100644 --- a/backend/db/mutation.go +++ b/backend/db/mutation.go @@ -11994,9 +11994,22 @@ func (m *TaskRecordMutation) OldCompletion(ctx context.Context) (v string, err e return oldValue.Completion, nil } +// ClearCompletion clears the value of the "completion" field. +func (m *TaskRecordMutation) ClearCompletion() { + m.completion = nil + m.clearedFields[taskrecord.FieldCompletion] = struct{}{} +} + +// CompletionCleared returns if the "completion" field was cleared in this mutation. +func (m *TaskRecordMutation) CompletionCleared() bool { + _, ok := m.clearedFields[taskrecord.FieldCompletion] + return ok +} + // ResetCompletion resets all changes to the "completion" field. func (m *TaskRecordMutation) ResetCompletion() { m.completion = nil + delete(m.clearedFields, taskrecord.FieldCompletion) } // SetOutputTokens sets the "output_tokens" field. @@ -12509,6 +12522,9 @@ func (m *TaskRecordMutation) ClearedFields() []string { if m.FieldCleared(taskrecord.FieldPrompt) { fields = append(fields, taskrecord.FieldPrompt) } + if m.FieldCleared(taskrecord.FieldCompletion) { + fields = append(fields, taskrecord.FieldCompletion) + } if m.FieldCleared(taskrecord.FieldCode) { fields = append(fields, taskrecord.FieldCode) } @@ -12532,6 +12548,9 @@ func (m *TaskRecordMutation) ClearField(name string) error { case taskrecord.FieldPrompt: m.ClearPrompt() return nil + case taskrecord.FieldCompletion: + m.ClearCompletion() + return nil case taskrecord.FieldCode: m.ClearCode() return nil diff --git a/backend/db/runtime/runtime.go b/backend/db/runtime/runtime.go index af72064..134ae82 100644 --- a/backend/db/runtime/runtime.go +++ b/backend/db/runtime/runtime.go @@ -272,6 +272,14 @@ func init() { task.UpdateDefaultUpdatedAt = taskDescUpdatedAt.UpdateDefault.(func() time.Time) taskrecordFields := schema.TaskRecord{}.Fields() _ = taskrecordFields + // taskrecordDescOutputTokens is the schema descriptor for output_tokens field. + taskrecordDescOutputTokens := taskrecordFields[5].Descriptor() + // taskrecord.DefaultOutputTokens holds the default value on creation for the output_tokens field. + taskrecord.DefaultOutputTokens = taskrecordDescOutputTokens.Default.(int64) + // taskrecordDescCodeLines is the schema descriptor for code_lines field. + taskrecordDescCodeLines := taskrecordFields[6].Descriptor() + // taskrecord.DefaultCodeLines holds the default value on creation for the code_lines field. + taskrecord.DefaultCodeLines = taskrecordDescCodeLines.Default.(int64) // taskrecordDescCreatedAt is the schema descriptor for created_at field. taskrecordDescCreatedAt := taskrecordFields[8].Descriptor() // taskrecord.DefaultCreatedAt holds the default value on creation for the created_at field. diff --git a/backend/db/taskrecord/taskrecord.go b/backend/db/taskrecord/taskrecord.go index e290321..907b8f5 100644 --- a/backend/db/taskrecord/taskrecord.go +++ b/backend/db/taskrecord/taskrecord.go @@ -70,6 +70,10 @@ func ValidColumn(column string) bool { } var ( + // DefaultOutputTokens holds the default value on creation for the "output_tokens" field. + DefaultOutputTokens int64 + // DefaultCodeLines holds the default value on creation for the "code_lines" field. + DefaultCodeLines int64 // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. diff --git a/backend/db/taskrecord/where.go b/backend/db/taskrecord/where.go index a8736f8..7d1a575 100644 --- a/backend/db/taskrecord/where.go +++ b/backend/db/taskrecord/where.go @@ -347,6 +347,16 @@ func CompletionHasSuffix(v string) predicate.TaskRecord { return predicate.TaskRecord(sql.FieldHasSuffix(FieldCompletion, v)) } +// CompletionIsNil applies the IsNil predicate on the "completion" field. +func CompletionIsNil() predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldIsNull(FieldCompletion)) +} + +// CompletionNotNil applies the NotNil predicate on the "completion" field. +func CompletionNotNil() predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldNotNull(FieldCompletion)) +} + // CompletionEqualFold applies the EqualFold predicate on the "completion" field. func CompletionEqualFold(v string) predicate.TaskRecord { return predicate.TaskRecord(sql.FieldEqualFold(FieldCompletion, v)) diff --git a/backend/db/taskrecord_create.go b/backend/db/taskrecord_create.go index e7b89bc..28970ef 100644 --- a/backend/db/taskrecord_create.go +++ b/backend/db/taskrecord_create.go @@ -66,18 +66,42 @@ func (trc *TaskRecordCreate) SetCompletion(s string) *TaskRecordCreate { return trc } +// SetNillableCompletion sets the "completion" field if the given value is not nil. +func (trc *TaskRecordCreate) SetNillableCompletion(s *string) *TaskRecordCreate { + if s != nil { + trc.SetCompletion(*s) + } + return trc +} + // SetOutputTokens sets the "output_tokens" field. func (trc *TaskRecordCreate) SetOutputTokens(i int64) *TaskRecordCreate { trc.mutation.SetOutputTokens(i) return trc } +// SetNillableOutputTokens sets the "output_tokens" field if the given value is not nil. +func (trc *TaskRecordCreate) SetNillableOutputTokens(i *int64) *TaskRecordCreate { + if i != nil { + trc.SetOutputTokens(*i) + } + return trc +} + // SetCodeLines sets the "code_lines" field. func (trc *TaskRecordCreate) SetCodeLines(i int64) *TaskRecordCreate { trc.mutation.SetCodeLines(i) return trc } +// SetNillableCodeLines sets the "code_lines" field if the given value is not nil. +func (trc *TaskRecordCreate) SetNillableCodeLines(i *int64) *TaskRecordCreate { + if i != nil { + trc.SetCodeLines(*i) + } + return trc +} + // SetCode sets the "code" field. func (trc *TaskRecordCreate) SetCode(s string) *TaskRecordCreate { trc.mutation.SetCode(s) @@ -166,6 +190,14 @@ func (trc *TaskRecordCreate) ExecX(ctx context.Context) { // defaults sets the default values of the builder before save. func (trc *TaskRecordCreate) defaults() { + if _, ok := trc.mutation.OutputTokens(); !ok { + v := taskrecord.DefaultOutputTokens + trc.mutation.SetOutputTokens(v) + } + if _, ok := trc.mutation.CodeLines(); !ok { + v := taskrecord.DefaultCodeLines + trc.mutation.SetCodeLines(v) + } if _, ok := trc.mutation.CreatedAt(); !ok { v := taskrecord.DefaultCreatedAt() trc.mutation.SetCreatedAt(v) @@ -181,9 +213,6 @@ func (trc *TaskRecordCreate) check() error { if _, ok := trc.mutation.Role(); !ok { return &ValidationError{Name: "role", err: errors.New(`db: missing required field "TaskRecord.role"`)} } - if _, ok := trc.mutation.Completion(); !ok { - return &ValidationError{Name: "completion", err: errors.New(`db: missing required field "TaskRecord.completion"`)} - } if _, ok := trc.mutation.OutputTokens(); !ok { return &ValidationError{Name: "output_tokens", err: errors.New(`db: missing required field "TaskRecord.output_tokens"`)} } @@ -393,6 +422,12 @@ func (u *TaskRecordUpsert) UpdateCompletion() *TaskRecordUpsert { return u } +// ClearCompletion clears the value of the "completion" field. +func (u *TaskRecordUpsert) ClearCompletion() *TaskRecordUpsert { + u.SetNull(taskrecord.FieldCompletion) + return u +} + // SetOutputTokens sets the "output_tokens" field. func (u *TaskRecordUpsert) SetOutputTokens(v int64) *TaskRecordUpsert { u.Set(taskrecord.FieldOutputTokens, v) @@ -589,6 +624,13 @@ func (u *TaskRecordUpsertOne) UpdateCompletion() *TaskRecordUpsertOne { }) } +// ClearCompletion clears the value of the "completion" field. +func (u *TaskRecordUpsertOne) ClearCompletion() *TaskRecordUpsertOne { + return u.Update(func(s *TaskRecordUpsert) { + s.ClearCompletion() + }) +} + // SetOutputTokens sets the "output_tokens" field. func (u *TaskRecordUpsertOne) SetOutputTokens(v int64) *TaskRecordUpsertOne { return u.Update(func(s *TaskRecordUpsert) { @@ -965,6 +1007,13 @@ func (u *TaskRecordUpsertBulk) UpdateCompletion() *TaskRecordUpsertBulk { }) } +// ClearCompletion clears the value of the "completion" field. +func (u *TaskRecordUpsertBulk) ClearCompletion() *TaskRecordUpsertBulk { + return u.Update(func(s *TaskRecordUpsert) { + s.ClearCompletion() + }) +} + // SetOutputTokens sets the "output_tokens" field. func (u *TaskRecordUpsertBulk) SetOutputTokens(v int64) *TaskRecordUpsertBulk { return u.Update(func(s *TaskRecordUpsert) { diff --git a/backend/db/taskrecord_update.go b/backend/db/taskrecord_update.go index 1c5be23..6e85d4d 100644 --- a/backend/db/taskrecord_update.go +++ b/backend/db/taskrecord_update.go @@ -100,6 +100,12 @@ func (tru *TaskRecordUpdate) SetNillableCompletion(s *string) *TaskRecordUpdate return tru } +// ClearCompletion clears the value of the "completion" field. +func (tru *TaskRecordUpdate) ClearCompletion() *TaskRecordUpdate { + tru.mutation.ClearCompletion() + return tru +} + // SetOutputTokens sets the "output_tokens" field. func (tru *TaskRecordUpdate) SetOutputTokens(i int64) *TaskRecordUpdate { tru.mutation.ResetOutputTokens() @@ -261,6 +267,9 @@ func (tru *TaskRecordUpdate) sqlSave(ctx context.Context) (n int, err error) { if value, ok := tru.mutation.Completion(); ok { _spec.SetField(taskrecord.FieldCompletion, field.TypeString, value) } + if tru.mutation.CompletionCleared() { + _spec.ClearField(taskrecord.FieldCompletion, field.TypeString) + } if value, ok := tru.mutation.OutputTokens(); ok { _spec.SetField(taskrecord.FieldOutputTokens, field.TypeInt64, value) } @@ -404,6 +413,12 @@ func (truo *TaskRecordUpdateOne) SetNillableCompletion(s *string) *TaskRecordUpd return truo } +// ClearCompletion clears the value of the "completion" field. +func (truo *TaskRecordUpdateOne) ClearCompletion() *TaskRecordUpdateOne { + truo.mutation.ClearCompletion() + return truo +} + // SetOutputTokens sets the "output_tokens" field. func (truo *TaskRecordUpdateOne) SetOutputTokens(i int64) *TaskRecordUpdateOne { truo.mutation.ResetOutputTokens() @@ -595,6 +610,9 @@ func (truo *TaskRecordUpdateOne) sqlSave(ctx context.Context) (_node *TaskRecord if value, ok := truo.mutation.Completion(); ok { _spec.SetField(taskrecord.FieldCompletion, field.TypeString, value) } + if truo.mutation.CompletionCleared() { + _spec.ClearField(taskrecord.FieldCompletion, field.TypeString) + } if value, ok := truo.mutation.OutputTokens(); ok { _spec.SetField(taskrecord.FieldOutputTokens, field.TypeInt64, value) } diff --git a/backend/domain/billing.go b/backend/domain/billing.go index c412b1b..20d2b81 100644 --- a/backend/domain/billing.go +++ b/backend/domain/billing.go @@ -118,7 +118,7 @@ func (c *CompletionInfo) From(e *db.Task) *CompletionInfo { } type ChatContent struct { - Role consts.ChatRole `json:"role"` // 角色,如user: 用户的提问 assistant: 机器人回复 + Role consts.ChatRole `json:"role"` // 角色,如user: 用户的提问 assistant: 机器人回复 system: 系统消息 Content string `json:"content"` // 内容 CreatedAt int64 `json:"created_at"` } @@ -133,6 +133,8 @@ func (c *ChatContent) From(e *db.TaskRecord) *ChatContent { c.Content = e.Prompt case consts.ChatRoleAssistant: c.Content = e.Completion + case consts.ChatRoleSystem: + c.Content = e.Completion } c.CreatedAt = e.CreatedAt.Unix() return c @@ -151,6 +153,9 @@ func (c *ChatInfo) From(e *db.Task) *ChatInfo { c.Contents = cvt.Iter(e.Edges.TaskRecords, func(_ int, r *db.TaskRecord) *ChatContent { return cvt.From(r, &ChatContent{}) }) + c.Contents = cvt.Filter(c.Contents, func(_ int, r *ChatContent) (*ChatContent, bool) { + return r, r.Content != "" + }) return c } diff --git a/backend/domain/proxy.go b/backend/domain/proxy.go index 2d51a22..745eacb 100644 --- a/backend/domain/proxy.go +++ b/backend/domain/proxy.go @@ -29,7 +29,7 @@ type ProxyRepo interface { Record(ctx context.Context, record *RecordParam) error UpdateByTaskID(ctx context.Context, taskID string, fn func(*db.TaskUpdateOne)) error AcceptCompletion(ctx context.Context, req *AcceptCompletionReq) error - Report(ctx context.Context, req *ReportReq) error + Report(ctx context.Context, model *db.Model, req *ReportReq) error SelectModelWithLoadBalancing(modelName string, modelType consts.ModelType) (*db.Model, error) ValidateApiKey(ctx context.Context, key string) (*db.ApiKey, error) } @@ -52,6 +52,8 @@ type ReportReq struct { UserInput string `json:"user_input"` // 用户输入的新文本(用于reject action) SourceCode string `json:"source_code"` // 当前文件的原文(用于reject action) CursorPosition map[string]any `json:"cursor_position"` // 光标位置(用于reject action) + Mode string `json:"mode"` // 模式 + UserID string `json:"-"` } type RecordParam struct { diff --git a/backend/ent/schema/taskrecord.go b/backend/ent/schema/taskrecord.go index 206fac6..c363938 100644 --- a/backend/ent/schema/taskrecord.go +++ b/backend/ent/schema/taskrecord.go @@ -33,9 +33,9 @@ func (TaskRecord) Fields() []ent.Field { field.UUID("task_id", uuid.UUID{}).Optional(), field.String("prompt").Optional(), field.String("role").GoType(consts.ChatRole("")), - field.String("completion"), - field.Int64("output_tokens"), - field.Int64("code_lines"), + field.String("completion").Optional(), + field.Int64("output_tokens").Default(0), + field.Int64("code_lines").Default(0), field.String("code").Optional(), field.Time("created_at").Default(time.Now), field.Time("updated_at").Default(time.Now).UpdateDefault(time.Now), diff --git a/backend/internal/billing/repo/billing.go b/backend/internal/billing/repo/billing.go index 8e7c7c9..4beb61a 100644 --- a/backend/internal/billing/repo/billing.go +++ b/backend/internal/billing/repo/billing.go @@ -29,7 +29,6 @@ func (b *BillingRepo) ChatInfo(ctx context.Context, id, userID string) (*domain. q := b.db.Task.Query(). WithTaskRecords(func(trq *db.TaskRecordQuery) { trq.Order(taskrecord.ByCreatedAt(sql.OrderAsc())) - trq.Where(taskrecord.RoleNEQ(consts.ChatRoleSystem)) }). Where(task.TaskID(id)) if userID != "" { diff --git a/backend/internal/openai/handler/v1/v1.go b/backend/internal/openai/handler/v1/v1.go index 74a75f5..bf263a2 100644 --- a/backend/internal/openai/handler/v1/v1.go +++ b/backend/internal/openai/handler/v1/v1.go @@ -110,6 +110,7 @@ func (h *V1Handler) AcceptCompletion(c *web.Context, req domain.AcceptCompletion // @Router /v1/report [post] func (h *V1Handler) Report(c *web.Context, req domain.ReportReq) error { h.logger.DebugContext(c.Request().Context(), "Report", slog.Any("req", req)) + req.UserID = middleware.GetApiKey(c).UserID if err := h.proxyUse.Report(c.Request().Context(), &req); err != nil { return err } diff --git a/backend/internal/proxy/recorder.go b/backend/internal/proxy/recorder.go index 18492f7..fb6974d 100644 --- a/backend/internal/proxy/recorder.go +++ b/backend/internal/proxy/recorder.go @@ -153,19 +153,8 @@ func (r *Recorder) handleShadow() { With("resp_header", formatHeader(r.ctx.RespHeader)). DebugContext(r.ctx.ctx, "handle shadow", "rc", rc) - // 记录用户的提问 - if r.ctx.Model.ModelType == consts.ModelTypeLLM && prompt != "" { - tmp := rc.Clone() - tmp.Role = consts.ChatRoleUser - tmp.Completion = "" - tmp.OutputTokens = 0 - if err := r.usecase.Record(context.Background(), tmp); err != nil { - r.logger.WarnContext(r.ctx.ctx, "记录请求失败", "error", err) - } - } - if err := r.usecase.Record(context.Background(), rc); err != nil { - r.logger.WarnContext(r.ctx.ctx, "记录请求失败", "error", err) + r.logger.With("record", rc).WarnContext(r.ctx.ctx, "记录请求失败", "error", err) } } diff --git a/backend/internal/proxy/repo/proxy.go b/backend/internal/proxy/repo/proxy.go index 552dcfb..98799aa 100644 --- a/backend/internal/proxy/repo/proxy.go +++ b/backend/internal/proxy/repo/proxy.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "strings" "time" @@ -76,6 +77,9 @@ func (r *ProxyRepo) ValidateApiKey(ctx context.Context, key string) (*db.ApiKey, } func (r *ProxyRepo) Record(ctx context.Context, record *domain.RecordParam) error { + if record.TaskID == "" { + return fmt.Errorf("task_id is empty") + } userID, err := uuid.Parse(record.UserID) if err != nil { return err @@ -212,14 +216,52 @@ func abs(x int64) int64 { return x } -func (r *ProxyRepo) Report(ctx context.Context, req *domain.ReportReq) error { +func (r *ProxyRepo) Report(ctx context.Context, model *db.Model, req *domain.ReportReq) error { return entx.WithTx(ctx, r.db, func(tx *db.Tx) error { rc, err := tx.Task.Query().Where(task.TaskID(req.ID)).Only(ctx) if err != nil { - return err + if req.Action == consts.ReportActionNewTask && db.IsNotFound(err) { + uid, err := uuid.Parse(req.UserID) + if err != nil { + return err + } + newTask, err := tx.Task.Create(). + SetTaskID(req.ID). + SetRequestID(uuid.NewString()). + SetUserID(uid). + SetModelID(model.ID). + SetModelType(model.ModelType). + SetWorkMode(req.Mode). + SetPrompt(req.Content). + Save(ctx) + if err != nil { + return err + } + rc = newTask + } else { + return err + } } switch req.Action { + case consts.ReportActionNewTask, consts.ReportActionFeedbackTask: + if err := tx.TaskRecord.Create(). + SetTaskID(rc.ID). + SetRole(consts.ChatRoleUser). + SetPrompt(req.Content). + Exec(ctx); err != nil { + return err + } + + case consts.ReportActionAbortTask: + if err := tx.TaskRecord.Create(). + SetTaskID(rc.ID). + SetRole(consts.ChatRoleSystem). + SetCompletion(req.Content). + Exec(ctx); err != nil { + return err + } + case consts.ReportActionAccept: if err := tx.Task.UpdateOneID(rc.ID). SetIsAccept(true). diff --git a/backend/internal/proxy/usecase/proxy.go b/backend/internal/proxy/usecase/proxy.go index 65dcc7b..635c98b 100644 --- a/backend/internal/proxy/usecase/proxy.go +++ b/backend/internal/proxy/usecase/proxy.go @@ -2,7 +2,9 @@ package usecase import ( "context" + "log/slog" + "github.com/chaitin/MonkeyCode/backend/db" "github.com/chaitin/MonkeyCode/backend/pkg/cvt" "github.com/chaitin/MonkeyCode/backend/consts" @@ -12,10 +14,19 @@ import ( type ProxyUsecase struct { repo domain.ProxyRepo modelRepo domain.ModelRepo + logger *slog.Logger } -func NewProxyUsecase(repo domain.ProxyRepo, modelRepo domain.ModelRepo) domain.ProxyUsecase { - return &ProxyUsecase{repo: repo, modelRepo: modelRepo} +func NewProxyUsecase( + repo domain.ProxyRepo, + modelRepo domain.ModelRepo, + logger *slog.Logger, +) domain.ProxyUsecase { + return &ProxyUsecase{ + repo: repo, + modelRepo: modelRepo, + logger: logger.With("module", "ProxyUsecase"), + } } func (p *ProxyUsecase) Record(ctx context.Context, record *domain.RecordParam) error { @@ -44,5 +55,14 @@ func (p *ProxyUsecase) AcceptCompletion(ctx context.Context, req *domain.AcceptC } func (p *ProxyUsecase) Report(ctx context.Context, req *domain.ReportReq) error { - return p.repo.Report(ctx, req) + var model *db.Model + var err error + if req.Action == consts.ReportActionNewTask { + model, err = p.modelRepo.GetWithCache(context.Background(), consts.ModelTypeLLM) + if err != nil { + p.logger.With("fn", "Report").With("error", err).ErrorContext(ctx, "failed to get model") + return err + } + } + return p.repo.Report(ctx, model, req) }