mirror of
https://github.com/chaitin/MonkeyCode.git
synced 2026-02-05 00:04:50 +08:00
Merge pull request #141 from yokowu/feat-task-report
feat(tasks): 优化任务上报
This commit is contained in:
@@ -61,7 +61,7 @@ func newServer() (*Server, error) {
|
||||
redisClient := store.NewRedisCli(configConfig)
|
||||
proxyRepo := repo.NewProxyRepo(client, redisClient)
|
||||
modelRepo := repo2.NewModelRepo(client)
|
||||
proxyUsecase := usecase.NewProxyUsecase(proxyRepo, modelRepo)
|
||||
proxyUsecase := usecase.NewProxyUsecase(proxyRepo, modelRepo, slogLogger)
|
||||
llmProxy := proxy.NewLLMProxy(slogLogger, configConfig, proxyUsecase)
|
||||
openAIRepo := repo3.NewOpenAIRepo(client)
|
||||
openAIUsecase := openai.NewOpenAIUsecase(configConfig, openAIRepo, slogLogger)
|
||||
|
||||
@@ -3,8 +3,11 @@ package consts
|
||||
type ReportAction string
|
||||
|
||||
const (
|
||||
ReportActionAccept ReportAction = "accept"
|
||||
ReportActionSuggest ReportAction = "suggest"
|
||||
ReportActionFileWritten ReportAction = "file_written"
|
||||
ReportActionReject ReportAction = "reject"
|
||||
ReportActionAccept ReportAction = "accept"
|
||||
ReportActionSuggest ReportAction = "suggest"
|
||||
ReportActionFileWritten ReportAction = "file_written"
|
||||
ReportActionReject ReportAction = "reject"
|
||||
ReportActionNewTask ReportAction = "new_task"
|
||||
ReportActionFeedbackTask ReportAction = "feedback_task"
|
||||
ReportActionAbortTask ReportAction = "abort_task"
|
||||
)
|
||||
|
||||
@@ -305,9 +305,9 @@ var (
|
||||
{Name: "id", Type: field.TypeUUID, Unique: true},
|
||||
{Name: "prompt", Type: field.TypeString, Nullable: true},
|
||||
{Name: "role", Type: field.TypeString},
|
||||
{Name: "completion", Type: field.TypeString},
|
||||
{Name: "output_tokens", Type: field.TypeInt64},
|
||||
{Name: "code_lines", Type: field.TypeInt64},
|
||||
{Name: "completion", Type: field.TypeString, Nullable: true},
|
||||
{Name: "output_tokens", Type: field.TypeInt64, Default: 0},
|
||||
{Name: "code_lines", Type: field.TypeInt64, Default: 0},
|
||||
{Name: "code", Type: field.TypeString, Nullable: true},
|
||||
{Name: "created_at", Type: field.TypeTime},
|
||||
{Name: "updated_at", Type: field.TypeTime},
|
||||
|
||||
@@ -11994,9 +11994,22 @@ func (m *TaskRecordMutation) OldCompletion(ctx context.Context) (v string, err e
|
||||
return oldValue.Completion, nil
|
||||
}
|
||||
|
||||
// ClearCompletion clears the value of the "completion" field.
|
||||
func (m *TaskRecordMutation) ClearCompletion() {
|
||||
m.completion = nil
|
||||
m.clearedFields[taskrecord.FieldCompletion] = struct{}{}
|
||||
}
|
||||
|
||||
// CompletionCleared returns if the "completion" field was cleared in this mutation.
|
||||
func (m *TaskRecordMutation) CompletionCleared() bool {
|
||||
_, ok := m.clearedFields[taskrecord.FieldCompletion]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetCompletion resets all changes to the "completion" field.
|
||||
func (m *TaskRecordMutation) ResetCompletion() {
|
||||
m.completion = nil
|
||||
delete(m.clearedFields, taskrecord.FieldCompletion)
|
||||
}
|
||||
|
||||
// SetOutputTokens sets the "output_tokens" field.
|
||||
@@ -12509,6 +12522,9 @@ func (m *TaskRecordMutation) ClearedFields() []string {
|
||||
if m.FieldCleared(taskrecord.FieldPrompt) {
|
||||
fields = append(fields, taskrecord.FieldPrompt)
|
||||
}
|
||||
if m.FieldCleared(taskrecord.FieldCompletion) {
|
||||
fields = append(fields, taskrecord.FieldCompletion)
|
||||
}
|
||||
if m.FieldCleared(taskrecord.FieldCode) {
|
||||
fields = append(fields, taskrecord.FieldCode)
|
||||
}
|
||||
@@ -12532,6 +12548,9 @@ func (m *TaskRecordMutation) ClearField(name string) error {
|
||||
case taskrecord.FieldPrompt:
|
||||
m.ClearPrompt()
|
||||
return nil
|
||||
case taskrecord.FieldCompletion:
|
||||
m.ClearCompletion()
|
||||
return nil
|
||||
case taskrecord.FieldCode:
|
||||
m.ClearCode()
|
||||
return nil
|
||||
|
||||
@@ -272,6 +272,14 @@ func init() {
|
||||
task.UpdateDefaultUpdatedAt = taskDescUpdatedAt.UpdateDefault.(func() time.Time)
|
||||
taskrecordFields := schema.TaskRecord{}.Fields()
|
||||
_ = taskrecordFields
|
||||
// taskrecordDescOutputTokens is the schema descriptor for output_tokens field.
|
||||
taskrecordDescOutputTokens := taskrecordFields[5].Descriptor()
|
||||
// taskrecord.DefaultOutputTokens holds the default value on creation for the output_tokens field.
|
||||
taskrecord.DefaultOutputTokens = taskrecordDescOutputTokens.Default.(int64)
|
||||
// taskrecordDescCodeLines is the schema descriptor for code_lines field.
|
||||
taskrecordDescCodeLines := taskrecordFields[6].Descriptor()
|
||||
// taskrecord.DefaultCodeLines holds the default value on creation for the code_lines field.
|
||||
taskrecord.DefaultCodeLines = taskrecordDescCodeLines.Default.(int64)
|
||||
// taskrecordDescCreatedAt is the schema descriptor for created_at field.
|
||||
taskrecordDescCreatedAt := taskrecordFields[8].Descriptor()
|
||||
// taskrecord.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||
|
||||
@@ -70,6 +70,10 @@ func ValidColumn(column string) bool {
|
||||
}
|
||||
|
||||
var (
|
||||
// DefaultOutputTokens holds the default value on creation for the "output_tokens" field.
|
||||
DefaultOutputTokens int64
|
||||
// DefaultCodeLines holds the default value on creation for the "code_lines" field.
|
||||
DefaultCodeLines int64
|
||||
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
|
||||
DefaultCreatedAt func() time.Time
|
||||
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
|
||||
|
||||
@@ -347,6 +347,16 @@ func CompletionHasSuffix(v string) predicate.TaskRecord {
|
||||
return predicate.TaskRecord(sql.FieldHasSuffix(FieldCompletion, v))
|
||||
}
|
||||
|
||||
// CompletionIsNil applies the IsNil predicate on the "completion" field.
|
||||
func CompletionIsNil() predicate.TaskRecord {
|
||||
return predicate.TaskRecord(sql.FieldIsNull(FieldCompletion))
|
||||
}
|
||||
|
||||
// CompletionNotNil applies the NotNil predicate on the "completion" field.
|
||||
func CompletionNotNil() predicate.TaskRecord {
|
||||
return predicate.TaskRecord(sql.FieldNotNull(FieldCompletion))
|
||||
}
|
||||
|
||||
// CompletionEqualFold applies the EqualFold predicate on the "completion" field.
|
||||
func CompletionEqualFold(v string) predicate.TaskRecord {
|
||||
return predicate.TaskRecord(sql.FieldEqualFold(FieldCompletion, v))
|
||||
|
||||
@@ -66,18 +66,42 @@ func (trc *TaskRecordCreate) SetCompletion(s string) *TaskRecordCreate {
|
||||
return trc
|
||||
}
|
||||
|
||||
// SetNillableCompletion sets the "completion" field if the given value is not nil.
|
||||
func (trc *TaskRecordCreate) SetNillableCompletion(s *string) *TaskRecordCreate {
|
||||
if s != nil {
|
||||
trc.SetCompletion(*s)
|
||||
}
|
||||
return trc
|
||||
}
|
||||
|
||||
// SetOutputTokens sets the "output_tokens" field.
|
||||
func (trc *TaskRecordCreate) SetOutputTokens(i int64) *TaskRecordCreate {
|
||||
trc.mutation.SetOutputTokens(i)
|
||||
return trc
|
||||
}
|
||||
|
||||
// SetNillableOutputTokens sets the "output_tokens" field if the given value is not nil.
|
||||
func (trc *TaskRecordCreate) SetNillableOutputTokens(i *int64) *TaskRecordCreate {
|
||||
if i != nil {
|
||||
trc.SetOutputTokens(*i)
|
||||
}
|
||||
return trc
|
||||
}
|
||||
|
||||
// SetCodeLines sets the "code_lines" field.
|
||||
func (trc *TaskRecordCreate) SetCodeLines(i int64) *TaskRecordCreate {
|
||||
trc.mutation.SetCodeLines(i)
|
||||
return trc
|
||||
}
|
||||
|
||||
// SetNillableCodeLines sets the "code_lines" field if the given value is not nil.
|
||||
func (trc *TaskRecordCreate) SetNillableCodeLines(i *int64) *TaskRecordCreate {
|
||||
if i != nil {
|
||||
trc.SetCodeLines(*i)
|
||||
}
|
||||
return trc
|
||||
}
|
||||
|
||||
// SetCode sets the "code" field.
|
||||
func (trc *TaskRecordCreate) SetCode(s string) *TaskRecordCreate {
|
||||
trc.mutation.SetCode(s)
|
||||
@@ -166,6 +190,14 @@ func (trc *TaskRecordCreate) ExecX(ctx context.Context) {
|
||||
|
||||
// defaults sets the default values of the builder before save.
|
||||
func (trc *TaskRecordCreate) defaults() {
|
||||
if _, ok := trc.mutation.OutputTokens(); !ok {
|
||||
v := taskrecord.DefaultOutputTokens
|
||||
trc.mutation.SetOutputTokens(v)
|
||||
}
|
||||
if _, ok := trc.mutation.CodeLines(); !ok {
|
||||
v := taskrecord.DefaultCodeLines
|
||||
trc.mutation.SetCodeLines(v)
|
||||
}
|
||||
if _, ok := trc.mutation.CreatedAt(); !ok {
|
||||
v := taskrecord.DefaultCreatedAt()
|
||||
trc.mutation.SetCreatedAt(v)
|
||||
@@ -181,9 +213,6 @@ func (trc *TaskRecordCreate) check() error {
|
||||
if _, ok := trc.mutation.Role(); !ok {
|
||||
return &ValidationError{Name: "role", err: errors.New(`db: missing required field "TaskRecord.role"`)}
|
||||
}
|
||||
if _, ok := trc.mutation.Completion(); !ok {
|
||||
return &ValidationError{Name: "completion", err: errors.New(`db: missing required field "TaskRecord.completion"`)}
|
||||
}
|
||||
if _, ok := trc.mutation.OutputTokens(); !ok {
|
||||
return &ValidationError{Name: "output_tokens", err: errors.New(`db: missing required field "TaskRecord.output_tokens"`)}
|
||||
}
|
||||
@@ -393,6 +422,12 @@ func (u *TaskRecordUpsert) UpdateCompletion() *TaskRecordUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearCompletion clears the value of the "completion" field.
|
||||
func (u *TaskRecordUpsert) ClearCompletion() *TaskRecordUpsert {
|
||||
u.SetNull(taskrecord.FieldCompletion)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetOutputTokens sets the "output_tokens" field.
|
||||
func (u *TaskRecordUpsert) SetOutputTokens(v int64) *TaskRecordUpsert {
|
||||
u.Set(taskrecord.FieldOutputTokens, v)
|
||||
@@ -589,6 +624,13 @@ func (u *TaskRecordUpsertOne) UpdateCompletion() *TaskRecordUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// ClearCompletion clears the value of the "completion" field.
|
||||
func (u *TaskRecordUpsertOne) ClearCompletion() *TaskRecordUpsertOne {
|
||||
return u.Update(func(s *TaskRecordUpsert) {
|
||||
s.ClearCompletion()
|
||||
})
|
||||
}
|
||||
|
||||
// SetOutputTokens sets the "output_tokens" field.
|
||||
func (u *TaskRecordUpsertOne) SetOutputTokens(v int64) *TaskRecordUpsertOne {
|
||||
return u.Update(func(s *TaskRecordUpsert) {
|
||||
@@ -965,6 +1007,13 @@ func (u *TaskRecordUpsertBulk) UpdateCompletion() *TaskRecordUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// ClearCompletion clears the value of the "completion" field.
|
||||
func (u *TaskRecordUpsertBulk) ClearCompletion() *TaskRecordUpsertBulk {
|
||||
return u.Update(func(s *TaskRecordUpsert) {
|
||||
s.ClearCompletion()
|
||||
})
|
||||
}
|
||||
|
||||
// SetOutputTokens sets the "output_tokens" field.
|
||||
func (u *TaskRecordUpsertBulk) SetOutputTokens(v int64) *TaskRecordUpsertBulk {
|
||||
return u.Update(func(s *TaskRecordUpsert) {
|
||||
|
||||
@@ -100,6 +100,12 @@ func (tru *TaskRecordUpdate) SetNillableCompletion(s *string) *TaskRecordUpdate
|
||||
return tru
|
||||
}
|
||||
|
||||
// ClearCompletion clears the value of the "completion" field.
|
||||
func (tru *TaskRecordUpdate) ClearCompletion() *TaskRecordUpdate {
|
||||
tru.mutation.ClearCompletion()
|
||||
return tru
|
||||
}
|
||||
|
||||
// SetOutputTokens sets the "output_tokens" field.
|
||||
func (tru *TaskRecordUpdate) SetOutputTokens(i int64) *TaskRecordUpdate {
|
||||
tru.mutation.ResetOutputTokens()
|
||||
@@ -261,6 +267,9 @@ func (tru *TaskRecordUpdate) sqlSave(ctx context.Context) (n int, err error) {
|
||||
if value, ok := tru.mutation.Completion(); ok {
|
||||
_spec.SetField(taskrecord.FieldCompletion, field.TypeString, value)
|
||||
}
|
||||
if tru.mutation.CompletionCleared() {
|
||||
_spec.ClearField(taskrecord.FieldCompletion, field.TypeString)
|
||||
}
|
||||
if value, ok := tru.mutation.OutputTokens(); ok {
|
||||
_spec.SetField(taskrecord.FieldOutputTokens, field.TypeInt64, value)
|
||||
}
|
||||
@@ -404,6 +413,12 @@ func (truo *TaskRecordUpdateOne) SetNillableCompletion(s *string) *TaskRecordUpd
|
||||
return truo
|
||||
}
|
||||
|
||||
// ClearCompletion clears the value of the "completion" field.
|
||||
func (truo *TaskRecordUpdateOne) ClearCompletion() *TaskRecordUpdateOne {
|
||||
truo.mutation.ClearCompletion()
|
||||
return truo
|
||||
}
|
||||
|
||||
// SetOutputTokens sets the "output_tokens" field.
|
||||
func (truo *TaskRecordUpdateOne) SetOutputTokens(i int64) *TaskRecordUpdateOne {
|
||||
truo.mutation.ResetOutputTokens()
|
||||
@@ -595,6 +610,9 @@ func (truo *TaskRecordUpdateOne) sqlSave(ctx context.Context) (_node *TaskRecord
|
||||
if value, ok := truo.mutation.Completion(); ok {
|
||||
_spec.SetField(taskrecord.FieldCompletion, field.TypeString, value)
|
||||
}
|
||||
if truo.mutation.CompletionCleared() {
|
||||
_spec.ClearField(taskrecord.FieldCompletion, field.TypeString)
|
||||
}
|
||||
if value, ok := truo.mutation.OutputTokens(); ok {
|
||||
_spec.SetField(taskrecord.FieldOutputTokens, field.TypeInt64, value)
|
||||
}
|
||||
|
||||
@@ -118,7 +118,7 @@ func (c *CompletionInfo) From(e *db.Task) *CompletionInfo {
|
||||
}
|
||||
|
||||
type ChatContent struct {
|
||||
Role consts.ChatRole `json:"role"` // 角色,如user: 用户的提问 assistant: 机器人回复
|
||||
Role consts.ChatRole `json:"role"` // 角色,如user: 用户的提问 assistant: 机器人回复 system: 系统消息
|
||||
Content string `json:"content"` // 内容
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
}
|
||||
@@ -133,6 +133,8 @@ func (c *ChatContent) From(e *db.TaskRecord) *ChatContent {
|
||||
c.Content = e.Prompt
|
||||
case consts.ChatRoleAssistant:
|
||||
c.Content = e.Completion
|
||||
case consts.ChatRoleSystem:
|
||||
c.Content = e.Completion
|
||||
}
|
||||
c.CreatedAt = e.CreatedAt.Unix()
|
||||
return c
|
||||
@@ -151,6 +153,9 @@ func (c *ChatInfo) From(e *db.Task) *ChatInfo {
|
||||
c.Contents = cvt.Iter(e.Edges.TaskRecords, func(_ int, r *db.TaskRecord) *ChatContent {
|
||||
return cvt.From(r, &ChatContent{})
|
||||
})
|
||||
c.Contents = cvt.Filter(c.Contents, func(_ int, r *ChatContent) (*ChatContent, bool) {
|
||||
return r, r.Content != ""
|
||||
})
|
||||
return c
|
||||
}
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ type ProxyRepo interface {
|
||||
Record(ctx context.Context, record *RecordParam) error
|
||||
UpdateByTaskID(ctx context.Context, taskID string, fn func(*db.TaskUpdateOne)) error
|
||||
AcceptCompletion(ctx context.Context, req *AcceptCompletionReq) error
|
||||
Report(ctx context.Context, req *ReportReq) error
|
||||
Report(ctx context.Context, model *db.Model, req *ReportReq) error
|
||||
SelectModelWithLoadBalancing(modelName string, modelType consts.ModelType) (*db.Model, error)
|
||||
ValidateApiKey(ctx context.Context, key string) (*db.ApiKey, error)
|
||||
}
|
||||
@@ -52,6 +52,8 @@ type ReportReq struct {
|
||||
UserInput string `json:"user_input"` // 用户输入的新文本(用于reject action)
|
||||
SourceCode string `json:"source_code"` // 当前文件的原文(用于reject action)
|
||||
CursorPosition map[string]any `json:"cursor_position"` // 光标位置(用于reject action)
|
||||
Mode string `json:"mode"` // 模式
|
||||
UserID string `json:"-"`
|
||||
}
|
||||
|
||||
type RecordParam struct {
|
||||
|
||||
@@ -33,9 +33,9 @@ func (TaskRecord) Fields() []ent.Field {
|
||||
field.UUID("task_id", uuid.UUID{}).Optional(),
|
||||
field.String("prompt").Optional(),
|
||||
field.String("role").GoType(consts.ChatRole("")),
|
||||
field.String("completion"),
|
||||
field.Int64("output_tokens"),
|
||||
field.Int64("code_lines"),
|
||||
field.String("completion").Optional(),
|
||||
field.Int64("output_tokens").Default(0),
|
||||
field.Int64("code_lines").Default(0),
|
||||
field.String("code").Optional(),
|
||||
field.Time("created_at").Default(time.Now),
|
||||
field.Time("updated_at").Default(time.Now).UpdateDefault(time.Now),
|
||||
|
||||
@@ -29,7 +29,6 @@ func (b *BillingRepo) ChatInfo(ctx context.Context, id, userID string) (*domain.
|
||||
q := b.db.Task.Query().
|
||||
WithTaskRecords(func(trq *db.TaskRecordQuery) {
|
||||
trq.Order(taskrecord.ByCreatedAt(sql.OrderAsc()))
|
||||
trq.Where(taskrecord.RoleNEQ(consts.ChatRoleSystem))
|
||||
}).
|
||||
Where(task.TaskID(id))
|
||||
if userID != "" {
|
||||
|
||||
@@ -110,6 +110,7 @@ func (h *V1Handler) AcceptCompletion(c *web.Context, req domain.AcceptCompletion
|
||||
// @Router /v1/report [post]
|
||||
func (h *V1Handler) Report(c *web.Context, req domain.ReportReq) error {
|
||||
h.logger.DebugContext(c.Request().Context(), "Report", slog.Any("req", req))
|
||||
req.UserID = middleware.GetApiKey(c).UserID
|
||||
if err := h.proxyUse.Report(c.Request().Context(), &req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -153,19 +153,8 @@ func (r *Recorder) handleShadow() {
|
||||
With("resp_header", formatHeader(r.ctx.RespHeader)).
|
||||
DebugContext(r.ctx.ctx, "handle shadow", "rc", rc)
|
||||
|
||||
// 记录用户的提问
|
||||
if r.ctx.Model.ModelType == consts.ModelTypeLLM && prompt != "" {
|
||||
tmp := rc.Clone()
|
||||
tmp.Role = consts.ChatRoleUser
|
||||
tmp.Completion = ""
|
||||
tmp.OutputTokens = 0
|
||||
if err := r.usecase.Record(context.Background(), tmp); err != nil {
|
||||
r.logger.WarnContext(r.ctx.ctx, "记录请求失败", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.usecase.Record(context.Background(), rc); err != nil {
|
||||
r.logger.WarnContext(r.ctx.ctx, "记录请求失败", "error", err)
|
||||
r.logger.With("record", rc).WarnContext(r.ctx.ctx, "记录请求失败", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -76,6 +77,9 @@ func (r *ProxyRepo) ValidateApiKey(ctx context.Context, key string) (*db.ApiKey,
|
||||
}
|
||||
|
||||
func (r *ProxyRepo) Record(ctx context.Context, record *domain.RecordParam) error {
|
||||
if record.TaskID == "" {
|
||||
return fmt.Errorf("task_id is empty")
|
||||
}
|
||||
userID, err := uuid.Parse(record.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -212,14 +216,52 @@ func abs(x int64) int64 {
|
||||
return x
|
||||
}
|
||||
|
||||
func (r *ProxyRepo) Report(ctx context.Context, req *domain.ReportReq) error {
|
||||
func (r *ProxyRepo) Report(ctx context.Context, model *db.Model, req *domain.ReportReq) error {
|
||||
return entx.WithTx(ctx, r.db, func(tx *db.Tx) error {
|
||||
rc, err := tx.Task.Query().Where(task.TaskID(req.ID)).Only(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
if req.Action == consts.ReportActionNewTask && db.IsNotFound(err) {
|
||||
uid, err := uuid.Parse(req.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newTask, err := tx.Task.Create().
|
||||
SetTaskID(req.ID).
|
||||
SetRequestID(uuid.NewString()).
|
||||
SetUserID(uid).
|
||||
SetModelID(model.ID).
|
||||
SetModelType(model.ModelType).
|
||||
SetWorkMode(req.Mode).
|
||||
SetPrompt(req.Content).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rc = newTask
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
switch req.Action {
|
||||
case consts.ReportActionNewTask, consts.ReportActionFeedbackTask:
|
||||
if err := tx.TaskRecord.Create().
|
||||
SetTaskID(rc.ID).
|
||||
SetRole(consts.ChatRoleUser).
|
||||
SetPrompt(req.Content).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case consts.ReportActionAbortTask:
|
||||
if err := tx.TaskRecord.Create().
|
||||
SetTaskID(rc.ID).
|
||||
SetRole(consts.ChatRoleSystem).
|
||||
SetCompletion(req.Content).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case consts.ReportActionAccept:
|
||||
if err := tx.Task.UpdateOneID(rc.ID).
|
||||
SetIsAccept(true).
|
||||
|
||||
@@ -2,7 +2,9 @@ package usecase
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
|
||||
"github.com/chaitin/MonkeyCode/backend/db"
|
||||
"github.com/chaitin/MonkeyCode/backend/pkg/cvt"
|
||||
|
||||
"github.com/chaitin/MonkeyCode/backend/consts"
|
||||
@@ -12,10 +14,19 @@ import (
|
||||
type ProxyUsecase struct {
|
||||
repo domain.ProxyRepo
|
||||
modelRepo domain.ModelRepo
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewProxyUsecase(repo domain.ProxyRepo, modelRepo domain.ModelRepo) domain.ProxyUsecase {
|
||||
return &ProxyUsecase{repo: repo, modelRepo: modelRepo}
|
||||
func NewProxyUsecase(
|
||||
repo domain.ProxyRepo,
|
||||
modelRepo domain.ModelRepo,
|
||||
logger *slog.Logger,
|
||||
) domain.ProxyUsecase {
|
||||
return &ProxyUsecase{
|
||||
repo: repo,
|
||||
modelRepo: modelRepo,
|
||||
logger: logger.With("module", "ProxyUsecase"),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProxyUsecase) Record(ctx context.Context, record *domain.RecordParam) error {
|
||||
@@ -44,5 +55,14 @@ func (p *ProxyUsecase) AcceptCompletion(ctx context.Context, req *domain.AcceptC
|
||||
}
|
||||
|
||||
func (p *ProxyUsecase) Report(ctx context.Context, req *domain.ReportReq) error {
|
||||
return p.repo.Report(ctx, req)
|
||||
var model *db.Model
|
||||
var err error
|
||||
if req.Action == consts.ReportActionNewTask {
|
||||
model, err = p.modelRepo.GetWithCache(context.Background(), consts.ModelTypeLLM)
|
||||
if err != nil {
|
||||
p.logger.With("fn", "Report").With("error", err).ErrorContext(ctx, "failed to get model")
|
||||
return err
|
||||
}
|
||||
}
|
||||
return p.repo.Report(ctx, model, req)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user