diff --git a/backend/consts/proxy.go b/backend/consts/proxy.go index f4c72eb..aa3dabe 100644 --- a/backend/consts/proxy.go +++ b/backend/consts/proxy.go @@ -6,4 +6,5 @@ const ( ReportActionAccept ReportAction = "accept" ReportActionSuggest ReportAction = "suggest" ReportActionFileWritten ReportAction = "file_written" + ReportActionReject ReportAction = "reject" ) diff --git a/backend/domain/proxy.go b/backend/domain/proxy.go index 11fc83d..6142805 100644 --- a/backend/domain/proxy.go +++ b/backend/domain/proxy.go @@ -45,10 +45,13 @@ type AcceptCompletionReq struct { } type ReportReq struct { - Action consts.ReportAction `json:"action"` - ID string `json:"id"` // task_id or resp_id - Content string `json:"content"` // 内容 - Tool string `json:"tool"` // 工具 + Action consts.ReportAction `json:"action"` + ID string `json:"id"` // task_id or resp_id + Content string `json:"content"` // 内容 + Tool string `json:"tool"` // 工具 + UserInput string `json:"user_input"` // 用户输入的新文本(用于reject action) + SourceCode string `json:"source_code"` // 当前文件的原文(用于reject action) + CursorPosition int64 `json:"cursor_position"` // 光标位置(用于reject action) } type RecordParam struct { diff --git a/backend/internal/openai/handler/v1/v1.go b/backend/internal/openai/handler/v1/v1.go index 6a40cb1..df42871 100644 --- a/backend/internal/openai/handler/v1/v1.go +++ b/backend/internal/openai/handler/v1/v1.go @@ -101,7 +101,7 @@ func (h *V1Handler) AcceptCompletion(c *web.Context, req domain.AcceptCompletion // // @Tags OpenAIV1 // @Summary 报告 -// @Description 报告 +// @Description 报告,支持多种操作:accept(接受补全)、suggest(建议)、reject(拒绝补全并记录用户输入)、file_written(文件写入) // @ID report // @Accept json // @Produce json diff --git a/backend/internal/proxy/repo/proxy.go b/backend/internal/proxy/repo/proxy.go index 532e91d..e146a6d 100644 --- a/backend/internal/proxy/repo/proxy.go +++ b/backend/internal/proxy/repo/proxy.go @@ -203,6 +203,14 @@ func (r *ProxyRepo) AcceptCompletion(ctx context.Context, req *domain.AcceptComp }) } +// abs 返回整数的绝对值 +func abs(x int64) int64 { + if x < 0 { + return -x + } + return x +} + func (r *ProxyRepo) Report(ctx context.Context, req *domain.ReportReq) error { return entx.WithTx(ctx, r.db, func(tx *db.Tx) error { rc, err := tx.Task.Query().Where(task.TaskID(req.ID)).Only(ctx) @@ -235,6 +243,11 @@ func (r *ProxyRepo) Report(ctx context.Context, req *domain.ReportReq) error { Where(taskrecord.TaskID(rc.ID)). SetCompletion(req.Content).Exec(ctx) + case consts.ReportActionReject: + if err := r.handleRejectCompletion(ctx, tx, rc, req); err != nil { + return err + } + case consts.ReportActionFileWritten: if err := r.handleFileWritten(ctx, tx, rc, req); err != nil { return err @@ -245,6 +258,41 @@ func (r *ProxyRepo) Report(ctx context.Context, req *domain.ReportReq) error { }) } +func (r *ProxyRepo) handleRejectCompletion(ctx context.Context, tx *db.Tx, rc *db.Task, req *domain.ReportReq) error { + // 检测用户是否在同样的光标位置输入了新代码 + // 1. 检查光标位置是否匹配 + // 2. 检查源代码是否发生了变化 + // 3. 只有在检测到变化时才记录用户输入 + + shouldRecord := false + + // 检查光标位置是否匹配(允许小的误差) + if req.CursorPosition > 0 && rc.CursorPosition > 0 { + positionDiff := abs(req.CursorPosition - rc.CursorPosition) + if positionDiff <= 10 { // 允许10个字符的误差 + shouldRecord = true + } + } + + // 检查源代码是否发生了变化 + if req.SourceCode != "" && rc.SourceCode != "" { + if req.SourceCode != rc.SourceCode { + shouldRecord = true + } + } + + // 如果检测到变化,记录用户输入 + if shouldRecord && req.UserInput != "" { + if err := tx.Task.UpdateOneID(rc.ID). + SetUserInput(req.UserInput). + Exec(ctx); err != nil { + return err + } + } + + return nil +} + func (r *ProxyRepo) handleFileWritten(ctx context.Context, tx *db.Tx, rc *db.Task, req *domain.ReportReq) error { lineCount := 0 switch req.Tool {