Merge pull request #119 from AlanAlanAlanXu/main

引入了用户拒绝补全时的记录请求
This commit is contained in:
Yoko
2025-07-22 10:40:25 +08:00
committed by GitHub
4 changed files with 57 additions and 5 deletions

View File

@@ -6,4 +6,5 @@ const (
ReportActionAccept ReportAction = "accept"
ReportActionSuggest ReportAction = "suggest"
ReportActionFileWritten ReportAction = "file_written"
ReportActionReject ReportAction = "reject"
)

View File

@@ -45,10 +45,13 @@ type AcceptCompletionReq struct {
}
type ReportReq struct {
Action consts.ReportAction `json:"action"`
ID string `json:"id"` // task_id or resp_id
Content string `json:"content"` // 内容
Tool string `json:"tool"` // 工具
Action consts.ReportAction `json:"action"`
ID string `json:"id"` // task_id or resp_id
Content string `json:"content"` // 内容
Tool string `json:"tool"` // 工具
UserInput string `json:"user_input"` // 用户输入的新文本用于reject action
SourceCode string `json:"source_code"` // 当前文件的原文用于reject action
CursorPosition int64 `json:"cursor_position"` // 光标位置用于reject action
}
type RecordParam struct {

View File

@@ -101,7 +101,7 @@ func (h *V1Handler) AcceptCompletion(c *web.Context, req domain.AcceptCompletion
//
// @Tags OpenAIV1
// @Summary 报告
// @Description 报告
// @Description 报告支持多种操作accept接受补全、suggest建议、reject拒绝补全并记录用户输入、file_written文件写入
// @ID report
// @Accept json
// @Produce json

View File

@@ -203,6 +203,14 @@ func (r *ProxyRepo) AcceptCompletion(ctx context.Context, req *domain.AcceptComp
})
}
// abs 返回整数的绝对值
func abs(x int64) int64 {
if x < 0 {
return -x
}
return x
}
func (r *ProxyRepo) Report(ctx context.Context, req *domain.ReportReq) error {
return entx.WithTx(ctx, r.db, func(tx *db.Tx) error {
rc, err := tx.Task.Query().Where(task.TaskID(req.ID)).Only(ctx)
@@ -235,6 +243,11 @@ func (r *ProxyRepo) Report(ctx context.Context, req *domain.ReportReq) error {
Where(taskrecord.TaskID(rc.ID)).
SetCompletion(req.Content).Exec(ctx)
case consts.ReportActionReject:
if err := r.handleRejectCompletion(ctx, tx, rc, req); err != nil {
return err
}
case consts.ReportActionFileWritten:
if err := r.handleFileWritten(ctx, tx, rc, req); err != nil {
return err
@@ -245,6 +258,41 @@ func (r *ProxyRepo) Report(ctx context.Context, req *domain.ReportReq) error {
})
}
func (r *ProxyRepo) handleRejectCompletion(ctx context.Context, tx *db.Tx, rc *db.Task, req *domain.ReportReq) error {
// 检测用户是否在同样的光标位置输入了新代码
// 1. 检查光标位置是否匹配
// 2. 检查源代码是否发生了变化
// 3. 只有在检测到变化时才记录用户输入
shouldRecord := false
// 检查光标位置是否匹配(允许小的误差)
if req.CursorPosition > 0 && rc.CursorPosition > 0 {
positionDiff := abs(req.CursorPosition - rc.CursorPosition)
if positionDiff <= 10 { // 允许10个字符的误差
shouldRecord = true
}
}
// 检查源代码是否发生了变化
if req.SourceCode != "" && rc.SourceCode != "" {
if req.SourceCode != rc.SourceCode {
shouldRecord = true
}
}
// 如果检测到变化,记录用户输入
if shouldRecord && req.UserInput != "" {
if err := tx.Task.UpdateOneID(rc.ID).
SetUserInput(req.UserInput).
Exec(ctx); err != nil {
return err
}
}
return nil
}
func (r *ProxyRepo) handleFileWritten(ctx context.Context, tx *db.Tx, rc *db.Task, req *domain.ReportReq) error {
lineCount := 0
switch req.Tool {