Merge pull request #141 from yokowu/feat-task-report

feat(tasks): 优化任务上报
This commit is contained in:
Yoko
2025-07-24 18:52:50 +08:00
committed by GitHub
17 changed files with 203 additions and 34 deletions

View File

@@ -61,7 +61,7 @@ func newServer() (*Server, error) {
redisClient := store.NewRedisCli(configConfig)
proxyRepo := repo.NewProxyRepo(client, redisClient)
modelRepo := repo2.NewModelRepo(client)
proxyUsecase := usecase.NewProxyUsecase(proxyRepo, modelRepo)
proxyUsecase := usecase.NewProxyUsecase(proxyRepo, modelRepo, slogLogger)
llmProxy := proxy.NewLLMProxy(slogLogger, configConfig, proxyUsecase)
openAIRepo := repo3.NewOpenAIRepo(client)
openAIUsecase := openai.NewOpenAIUsecase(configConfig, openAIRepo, slogLogger)

View File

@@ -3,8 +3,11 @@ package consts
type ReportAction string
const (
ReportActionAccept ReportAction = "accept"
ReportActionSuggest ReportAction = "suggest"
ReportActionFileWritten ReportAction = "file_written"
ReportActionReject ReportAction = "reject"
ReportActionAccept ReportAction = "accept"
ReportActionSuggest ReportAction = "suggest"
ReportActionFileWritten ReportAction = "file_written"
ReportActionReject ReportAction = "reject"
ReportActionNewTask ReportAction = "new_task"
ReportActionFeedbackTask ReportAction = "feedback_task"
ReportActionAbortTask ReportAction = "abort_task"
)

View File

@@ -305,9 +305,9 @@ var (
{Name: "id", Type: field.TypeUUID, Unique: true},
{Name: "prompt", Type: field.TypeString, Nullable: true},
{Name: "role", Type: field.TypeString},
{Name: "completion", Type: field.TypeString},
{Name: "output_tokens", Type: field.TypeInt64},
{Name: "code_lines", Type: field.TypeInt64},
{Name: "completion", Type: field.TypeString, Nullable: true},
{Name: "output_tokens", Type: field.TypeInt64, Default: 0},
{Name: "code_lines", Type: field.TypeInt64, Default: 0},
{Name: "code", Type: field.TypeString, Nullable: true},
{Name: "created_at", Type: field.TypeTime},
{Name: "updated_at", Type: field.TypeTime},

View File

@@ -11994,9 +11994,22 @@ func (m *TaskRecordMutation) OldCompletion(ctx context.Context) (v string, err e
return oldValue.Completion, nil
}
// ClearCompletion clears the value of the "completion" field.
func (m *TaskRecordMutation) ClearCompletion() {
m.completion = nil
m.clearedFields[taskrecord.FieldCompletion] = struct{}{}
}
// CompletionCleared returns if the "completion" field was cleared in this mutation.
func (m *TaskRecordMutation) CompletionCleared() bool {
_, ok := m.clearedFields[taskrecord.FieldCompletion]
return ok
}
// ResetCompletion resets all changes to the "completion" field.
func (m *TaskRecordMutation) ResetCompletion() {
m.completion = nil
delete(m.clearedFields, taskrecord.FieldCompletion)
}
// SetOutputTokens sets the "output_tokens" field.
@@ -12509,6 +12522,9 @@ func (m *TaskRecordMutation) ClearedFields() []string {
if m.FieldCleared(taskrecord.FieldPrompt) {
fields = append(fields, taskrecord.FieldPrompt)
}
if m.FieldCleared(taskrecord.FieldCompletion) {
fields = append(fields, taskrecord.FieldCompletion)
}
if m.FieldCleared(taskrecord.FieldCode) {
fields = append(fields, taskrecord.FieldCode)
}
@@ -12532,6 +12548,9 @@ func (m *TaskRecordMutation) ClearField(name string) error {
case taskrecord.FieldPrompt:
m.ClearPrompt()
return nil
case taskrecord.FieldCompletion:
m.ClearCompletion()
return nil
case taskrecord.FieldCode:
m.ClearCode()
return nil

View File

@@ -272,6 +272,14 @@ func init() {
task.UpdateDefaultUpdatedAt = taskDescUpdatedAt.UpdateDefault.(func() time.Time)
taskrecordFields := schema.TaskRecord{}.Fields()
_ = taskrecordFields
// taskrecordDescOutputTokens is the schema descriptor for output_tokens field.
taskrecordDescOutputTokens := taskrecordFields[5].Descriptor()
// taskrecord.DefaultOutputTokens holds the default value on creation for the output_tokens field.
taskrecord.DefaultOutputTokens = taskrecordDescOutputTokens.Default.(int64)
// taskrecordDescCodeLines is the schema descriptor for code_lines field.
taskrecordDescCodeLines := taskrecordFields[6].Descriptor()
// taskrecord.DefaultCodeLines holds the default value on creation for the code_lines field.
taskrecord.DefaultCodeLines = taskrecordDescCodeLines.Default.(int64)
// taskrecordDescCreatedAt is the schema descriptor for created_at field.
taskrecordDescCreatedAt := taskrecordFields[8].Descriptor()
// taskrecord.DefaultCreatedAt holds the default value on creation for the created_at field.

View File

@@ -70,6 +70,10 @@ func ValidColumn(column string) bool {
}
var (
// DefaultOutputTokens holds the default value on creation for the "output_tokens" field.
DefaultOutputTokens int64
// DefaultCodeLines holds the default value on creation for the "code_lines" field.
DefaultCodeLines int64
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.

View File

@@ -347,6 +347,16 @@ func CompletionHasSuffix(v string) predicate.TaskRecord {
return predicate.TaskRecord(sql.FieldHasSuffix(FieldCompletion, v))
}
// CompletionIsNil applies the IsNil predicate on the "completion" field.
func CompletionIsNil() predicate.TaskRecord {
return predicate.TaskRecord(sql.FieldIsNull(FieldCompletion))
}
// CompletionNotNil applies the NotNil predicate on the "completion" field.
func CompletionNotNil() predicate.TaskRecord {
return predicate.TaskRecord(sql.FieldNotNull(FieldCompletion))
}
// CompletionEqualFold applies the EqualFold predicate on the "completion" field.
func CompletionEqualFold(v string) predicate.TaskRecord {
return predicate.TaskRecord(sql.FieldEqualFold(FieldCompletion, v))

View File

@@ -66,18 +66,42 @@ func (trc *TaskRecordCreate) SetCompletion(s string) *TaskRecordCreate {
return trc
}
// SetNillableCompletion sets the "completion" field if the given value is not nil.
func (trc *TaskRecordCreate) SetNillableCompletion(s *string) *TaskRecordCreate {
if s != nil {
trc.SetCompletion(*s)
}
return trc
}
// SetOutputTokens sets the "output_tokens" field.
func (trc *TaskRecordCreate) SetOutputTokens(i int64) *TaskRecordCreate {
trc.mutation.SetOutputTokens(i)
return trc
}
// SetNillableOutputTokens sets the "output_tokens" field if the given value is not nil.
func (trc *TaskRecordCreate) SetNillableOutputTokens(i *int64) *TaskRecordCreate {
if i != nil {
trc.SetOutputTokens(*i)
}
return trc
}
// SetCodeLines sets the "code_lines" field.
func (trc *TaskRecordCreate) SetCodeLines(i int64) *TaskRecordCreate {
trc.mutation.SetCodeLines(i)
return trc
}
// SetNillableCodeLines sets the "code_lines" field if the given value is not nil.
func (trc *TaskRecordCreate) SetNillableCodeLines(i *int64) *TaskRecordCreate {
if i != nil {
trc.SetCodeLines(*i)
}
return trc
}
// SetCode sets the "code" field.
func (trc *TaskRecordCreate) SetCode(s string) *TaskRecordCreate {
trc.mutation.SetCode(s)
@@ -166,6 +190,14 @@ func (trc *TaskRecordCreate) ExecX(ctx context.Context) {
// defaults sets the default values of the builder before save.
func (trc *TaskRecordCreate) defaults() {
if _, ok := trc.mutation.OutputTokens(); !ok {
v := taskrecord.DefaultOutputTokens
trc.mutation.SetOutputTokens(v)
}
if _, ok := trc.mutation.CodeLines(); !ok {
v := taskrecord.DefaultCodeLines
trc.mutation.SetCodeLines(v)
}
if _, ok := trc.mutation.CreatedAt(); !ok {
v := taskrecord.DefaultCreatedAt()
trc.mutation.SetCreatedAt(v)
@@ -181,9 +213,6 @@ func (trc *TaskRecordCreate) check() error {
if _, ok := trc.mutation.Role(); !ok {
return &ValidationError{Name: "role", err: errors.New(`db: missing required field "TaskRecord.role"`)}
}
if _, ok := trc.mutation.Completion(); !ok {
return &ValidationError{Name: "completion", err: errors.New(`db: missing required field "TaskRecord.completion"`)}
}
if _, ok := trc.mutation.OutputTokens(); !ok {
return &ValidationError{Name: "output_tokens", err: errors.New(`db: missing required field "TaskRecord.output_tokens"`)}
}
@@ -393,6 +422,12 @@ func (u *TaskRecordUpsert) UpdateCompletion() *TaskRecordUpsert {
return u
}
// ClearCompletion clears the value of the "completion" field.
func (u *TaskRecordUpsert) ClearCompletion() *TaskRecordUpsert {
u.SetNull(taskrecord.FieldCompletion)
return u
}
// SetOutputTokens sets the "output_tokens" field.
func (u *TaskRecordUpsert) SetOutputTokens(v int64) *TaskRecordUpsert {
u.Set(taskrecord.FieldOutputTokens, v)
@@ -589,6 +624,13 @@ func (u *TaskRecordUpsertOne) UpdateCompletion() *TaskRecordUpsertOne {
})
}
// ClearCompletion clears the value of the "completion" field.
func (u *TaskRecordUpsertOne) ClearCompletion() *TaskRecordUpsertOne {
return u.Update(func(s *TaskRecordUpsert) {
s.ClearCompletion()
})
}
// SetOutputTokens sets the "output_tokens" field.
func (u *TaskRecordUpsertOne) SetOutputTokens(v int64) *TaskRecordUpsertOne {
return u.Update(func(s *TaskRecordUpsert) {
@@ -965,6 +1007,13 @@ func (u *TaskRecordUpsertBulk) UpdateCompletion() *TaskRecordUpsertBulk {
})
}
// ClearCompletion clears the value of the "completion" field.
func (u *TaskRecordUpsertBulk) ClearCompletion() *TaskRecordUpsertBulk {
return u.Update(func(s *TaskRecordUpsert) {
s.ClearCompletion()
})
}
// SetOutputTokens sets the "output_tokens" field.
func (u *TaskRecordUpsertBulk) SetOutputTokens(v int64) *TaskRecordUpsertBulk {
return u.Update(func(s *TaskRecordUpsert) {

View File

@@ -100,6 +100,12 @@ func (tru *TaskRecordUpdate) SetNillableCompletion(s *string) *TaskRecordUpdate
return tru
}
// ClearCompletion clears the value of the "completion" field.
func (tru *TaskRecordUpdate) ClearCompletion() *TaskRecordUpdate {
tru.mutation.ClearCompletion()
return tru
}
// SetOutputTokens sets the "output_tokens" field.
func (tru *TaskRecordUpdate) SetOutputTokens(i int64) *TaskRecordUpdate {
tru.mutation.ResetOutputTokens()
@@ -261,6 +267,9 @@ func (tru *TaskRecordUpdate) sqlSave(ctx context.Context) (n int, err error) {
if value, ok := tru.mutation.Completion(); ok {
_spec.SetField(taskrecord.FieldCompletion, field.TypeString, value)
}
if tru.mutation.CompletionCleared() {
_spec.ClearField(taskrecord.FieldCompletion, field.TypeString)
}
if value, ok := tru.mutation.OutputTokens(); ok {
_spec.SetField(taskrecord.FieldOutputTokens, field.TypeInt64, value)
}
@@ -404,6 +413,12 @@ func (truo *TaskRecordUpdateOne) SetNillableCompletion(s *string) *TaskRecordUpd
return truo
}
// ClearCompletion clears the value of the "completion" field.
func (truo *TaskRecordUpdateOne) ClearCompletion() *TaskRecordUpdateOne {
truo.mutation.ClearCompletion()
return truo
}
// SetOutputTokens sets the "output_tokens" field.
func (truo *TaskRecordUpdateOne) SetOutputTokens(i int64) *TaskRecordUpdateOne {
truo.mutation.ResetOutputTokens()
@@ -595,6 +610,9 @@ func (truo *TaskRecordUpdateOne) sqlSave(ctx context.Context) (_node *TaskRecord
if value, ok := truo.mutation.Completion(); ok {
_spec.SetField(taskrecord.FieldCompletion, field.TypeString, value)
}
if truo.mutation.CompletionCleared() {
_spec.ClearField(taskrecord.FieldCompletion, field.TypeString)
}
if value, ok := truo.mutation.OutputTokens(); ok {
_spec.SetField(taskrecord.FieldOutputTokens, field.TypeInt64, value)
}

View File

@@ -118,7 +118,7 @@ func (c *CompletionInfo) From(e *db.Task) *CompletionInfo {
}
type ChatContent struct {
Role consts.ChatRole `json:"role"` // 角色如user: 用户的提问 assistant: 机器人回复
Role consts.ChatRole `json:"role"` // 角色如user: 用户的提问 assistant: 机器人回复 system: 系统消息
Content string `json:"content"` // 内容
CreatedAt int64 `json:"created_at"`
}
@@ -133,6 +133,8 @@ func (c *ChatContent) From(e *db.TaskRecord) *ChatContent {
c.Content = e.Prompt
case consts.ChatRoleAssistant:
c.Content = e.Completion
case consts.ChatRoleSystem:
c.Content = e.Completion
}
c.CreatedAt = e.CreatedAt.Unix()
return c
@@ -151,6 +153,9 @@ func (c *ChatInfo) From(e *db.Task) *ChatInfo {
c.Contents = cvt.Iter(e.Edges.TaskRecords, func(_ int, r *db.TaskRecord) *ChatContent {
return cvt.From(r, &ChatContent{})
})
c.Contents = cvt.Filter(c.Contents, func(_ int, r *ChatContent) (*ChatContent, bool) {
return r, r.Content != ""
})
return c
}

View File

@@ -29,7 +29,7 @@ type ProxyRepo interface {
Record(ctx context.Context, record *RecordParam) error
UpdateByTaskID(ctx context.Context, taskID string, fn func(*db.TaskUpdateOne)) error
AcceptCompletion(ctx context.Context, req *AcceptCompletionReq) error
Report(ctx context.Context, req *ReportReq) error
Report(ctx context.Context, model *db.Model, req *ReportReq) error
SelectModelWithLoadBalancing(modelName string, modelType consts.ModelType) (*db.Model, error)
ValidateApiKey(ctx context.Context, key string) (*db.ApiKey, error)
}
@@ -52,6 +52,8 @@ type ReportReq struct {
UserInput string `json:"user_input"` // 用户输入的新文本用于reject action
SourceCode string `json:"source_code"` // 当前文件的原文用于reject action
CursorPosition map[string]any `json:"cursor_position"` // 光标位置用于reject action
Mode string `json:"mode"` // 模式
UserID string `json:"-"`
}
type RecordParam struct {

View File

@@ -33,9 +33,9 @@ func (TaskRecord) Fields() []ent.Field {
field.UUID("task_id", uuid.UUID{}).Optional(),
field.String("prompt").Optional(),
field.String("role").GoType(consts.ChatRole("")),
field.String("completion"),
field.Int64("output_tokens"),
field.Int64("code_lines"),
field.String("completion").Optional(),
field.Int64("output_tokens").Default(0),
field.Int64("code_lines").Default(0),
field.String("code").Optional(),
field.Time("created_at").Default(time.Now),
field.Time("updated_at").Default(time.Now).UpdateDefault(time.Now),

View File

@@ -29,7 +29,6 @@ func (b *BillingRepo) ChatInfo(ctx context.Context, id, userID string) (*domain.
q := b.db.Task.Query().
WithTaskRecords(func(trq *db.TaskRecordQuery) {
trq.Order(taskrecord.ByCreatedAt(sql.OrderAsc()))
trq.Where(taskrecord.RoleNEQ(consts.ChatRoleSystem))
}).
Where(task.TaskID(id))
if userID != "" {

View File

@@ -110,6 +110,7 @@ func (h *V1Handler) AcceptCompletion(c *web.Context, req domain.AcceptCompletion
// @Router /v1/report [post]
func (h *V1Handler) Report(c *web.Context, req domain.ReportReq) error {
h.logger.DebugContext(c.Request().Context(), "Report", slog.Any("req", req))
req.UserID = middleware.GetApiKey(c).UserID
if err := h.proxyUse.Report(c.Request().Context(), &req); err != nil {
return err
}

View File

@@ -153,19 +153,8 @@ func (r *Recorder) handleShadow() {
With("resp_header", formatHeader(r.ctx.RespHeader)).
DebugContext(r.ctx.ctx, "handle shadow", "rc", rc)
// 记录用户的提问
if r.ctx.Model.ModelType == consts.ModelTypeLLM && prompt != "" {
tmp := rc.Clone()
tmp.Role = consts.ChatRoleUser
tmp.Completion = ""
tmp.OutputTokens = 0
if err := r.usecase.Record(context.Background(), tmp); err != nil {
r.logger.WarnContext(r.ctx.ctx, "记录请求失败", "error", err)
}
}
if err := r.usecase.Record(context.Background(), rc); err != nil {
r.logger.WarnContext(r.ctx.ctx, "记录请求失败", "error", err)
r.logger.With("record", rc).WarnContext(r.ctx.ctx, "记录请求失败", "error", err)
}
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
@@ -76,6 +77,9 @@ func (r *ProxyRepo) ValidateApiKey(ctx context.Context, key string) (*db.ApiKey,
}
func (r *ProxyRepo) Record(ctx context.Context, record *domain.RecordParam) error {
if record.TaskID == "" {
return fmt.Errorf("task_id is empty")
}
userID, err := uuid.Parse(record.UserID)
if err != nil {
return err
@@ -212,14 +216,52 @@ func abs(x int64) int64 {
return x
}
func (r *ProxyRepo) Report(ctx context.Context, req *domain.ReportReq) error {
func (r *ProxyRepo) Report(ctx context.Context, model *db.Model, req *domain.ReportReq) error {
return entx.WithTx(ctx, r.db, func(tx *db.Tx) error {
rc, err := tx.Task.Query().Where(task.TaskID(req.ID)).Only(ctx)
if err != nil {
return err
if req.Action == consts.ReportActionNewTask && db.IsNotFound(err) {
uid, err := uuid.Parse(req.UserID)
if err != nil {
return err
}
newTask, err := tx.Task.Create().
SetTaskID(req.ID).
SetRequestID(uuid.NewString()).
SetUserID(uid).
SetModelID(model.ID).
SetModelType(model.ModelType).
SetWorkMode(req.Mode).
SetPrompt(req.Content).
Save(ctx)
if err != nil {
return err
}
rc = newTask
} else {
return err
}
}
switch req.Action {
case consts.ReportActionNewTask, consts.ReportActionFeedbackTask:
if err := tx.TaskRecord.Create().
SetTaskID(rc.ID).
SetRole(consts.ChatRoleUser).
SetPrompt(req.Content).
Exec(ctx); err != nil {
return err
}
case consts.ReportActionAbortTask:
if err := tx.TaskRecord.Create().
SetTaskID(rc.ID).
SetRole(consts.ChatRoleSystem).
SetCompletion(req.Content).
Exec(ctx); err != nil {
return err
}
case consts.ReportActionAccept:
if err := tx.Task.UpdateOneID(rc.ID).
SetIsAccept(true).

View File

@@ -2,7 +2,9 @@ package usecase
import (
"context"
"log/slog"
"github.com/chaitin/MonkeyCode/backend/db"
"github.com/chaitin/MonkeyCode/backend/pkg/cvt"
"github.com/chaitin/MonkeyCode/backend/consts"
@@ -12,10 +14,19 @@ import (
type ProxyUsecase struct {
repo domain.ProxyRepo
modelRepo domain.ModelRepo
logger *slog.Logger
}
func NewProxyUsecase(repo domain.ProxyRepo, modelRepo domain.ModelRepo) domain.ProxyUsecase {
return &ProxyUsecase{repo: repo, modelRepo: modelRepo}
func NewProxyUsecase(
repo domain.ProxyRepo,
modelRepo domain.ModelRepo,
logger *slog.Logger,
) domain.ProxyUsecase {
return &ProxyUsecase{
repo: repo,
modelRepo: modelRepo,
logger: logger.With("module", "ProxyUsecase"),
}
}
func (p *ProxyUsecase) Record(ctx context.Context, record *domain.RecordParam) error {
@@ -44,5 +55,14 @@ func (p *ProxyUsecase) AcceptCompletion(ctx context.Context, req *domain.AcceptC
}
func (p *ProxyUsecase) Report(ctx context.Context, req *domain.ReportReq) error {
return p.repo.Report(ctx, req)
var model *db.Model
var err error
if req.Action == consts.ReportActionNewTask {
model, err = p.modelRepo.GetWithCache(context.Background(), consts.ModelTypeLLM)
if err != nil {
p.logger.With("fn", "Report").With("error", err).ErrorContext(ctx, "failed to get model")
return err
}
}
return p.repo.Report(ctx, model, req)
}