diff --git a/backend/internal/proxy/repo/proxy.go b/backend/internal/proxy/repo/proxy.go index 532e91d..e146a6d 100644 --- a/backend/internal/proxy/repo/proxy.go +++ b/backend/internal/proxy/repo/proxy.go @@ -203,6 +203,14 @@ func (r *ProxyRepo) AcceptCompletion(ctx context.Context, req *domain.AcceptComp }) } +// abs 返回整数的绝对值 +func abs(x int64) int64 { + if x < 0 { + return -x + } + return x +} + func (r *ProxyRepo) Report(ctx context.Context, req *domain.ReportReq) error { return entx.WithTx(ctx, r.db, func(tx *db.Tx) error { rc, err := tx.Task.Query().Where(task.TaskID(req.ID)).Only(ctx) @@ -235,6 +243,11 @@ func (r *ProxyRepo) Report(ctx context.Context, req *domain.ReportReq) error { Where(taskrecord.TaskID(rc.ID)). SetCompletion(req.Content).Exec(ctx) + case consts.ReportActionReject: + if err := r.handleRejectCompletion(ctx, tx, rc, req); err != nil { + return err + } + case consts.ReportActionFileWritten: if err := r.handleFileWritten(ctx, tx, rc, req); err != nil { return err @@ -245,6 +258,41 @@ func (r *ProxyRepo) Report(ctx context.Context, req *domain.ReportReq) error { }) } +func (r *ProxyRepo) handleRejectCompletion(ctx context.Context, tx *db.Tx, rc *db.Task, req *domain.ReportReq) error { + // 检测用户是否在同样的光标位置输入了新代码 + // 1. 检查光标位置是否匹配 + // 2. 检查源代码是否发生了变化 + // 3. 只有在检测到变化时才记录用户输入 + + shouldRecord := false + + // 检查光标位置是否匹配(允许小的误差) + if req.CursorPosition > 0 && rc.CursorPosition > 0 { + positionDiff := abs(req.CursorPosition - rc.CursorPosition) + if positionDiff <= 10 { // 允许10个字符的误差 + shouldRecord = true + } + } + + // 检查源代码是否发生了变化 + if req.SourceCode != "" && rc.SourceCode != "" { + if req.SourceCode != rc.SourceCode { + shouldRecord = true + } + } + + // 如果检测到变化,记录用户输入 + if shouldRecord && req.UserInput != "" { + if err := tx.Task.UpdateOneID(rc.ID). + SetUserInput(req.UserInput). + Exec(ctx); err != nil { + return err + } + } + + return nil +} + func (r *ProxyRepo) handleFileWritten(ctx context.Context, tx *db.Tx, rc *db.Task, req *domain.ReportReq) error { lineCount := 0 switch req.Tool {