diff --git a/backend/consts/openai.go b/backend/consts/openai.go index 68a1815..935f147 100644 --- a/backend/consts/openai.go +++ b/backend/consts/openai.go @@ -12,4 +12,5 @@ type ChatRole string const ( ChatRoleUser ChatRole = "user" ChatRoleAssistant ChatRole = "assistant" + ChatRoleSystem ChatRole = "system" ) diff --git a/backend/consts/proxy.go b/backend/consts/proxy.go new file mode 100644 index 0000000..f4c72eb --- /dev/null +++ b/backend/consts/proxy.go @@ -0,0 +1,9 @@ +package consts + +type ReportAction string + +const ( + ReportActionAccept ReportAction = "accept" + ReportActionSuggest ReportAction = "suggest" + ReportActionFileWritten ReportAction = "file_written" +) diff --git a/backend/db/migrate/schema.go b/backend/db/migrate/schema.go index 35f1d1b..0f58067 100644 --- a/backend/db/migrate/schema.go +++ b/backend/db/migrate/schema.go @@ -270,6 +270,7 @@ var ( {Name: "code_lines", Type: field.TypeInt64, Nullable: true}, {Name: "input_tokens", Type: field.TypeInt64, Nullable: true}, {Name: "output_tokens", Type: field.TypeInt64, Nullable: true}, + {Name: "is_suggested", Type: field.TypeBool, Default: false}, {Name: "created_at", Type: field.TypeTime}, {Name: "updated_at", Type: field.TypeTime}, {Name: "model_id", Type: field.TypeUUID, Nullable: true}, @@ -283,13 +284,13 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "tasks_models_tasks", - Columns: []*schema.Column{TasksColumns[13]}, + Columns: []*schema.Column{TasksColumns[14]}, RefColumns: []*schema.Column{ModelsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "tasks_users_tasks", - Columns: []*schema.Column{TasksColumns[14]}, + Columns: []*schema.Column{TasksColumns[15]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.SetNull, }, @@ -302,6 +303,8 @@ var ( {Name: "role", Type: field.TypeString}, {Name: "completion", Type: field.TypeString}, {Name: "output_tokens", Type: field.TypeInt64}, + {Name: "code_lines", Type: field.TypeInt64}, + {Name: "code", Type: field.TypeString, Nullable: true}, {Name: "created_at", Type: field.TypeTime}, {Name: "updated_at", Type: field.TypeTime}, {Name: "task_id", Type: field.TypeUUID, Nullable: true}, @@ -314,7 +317,7 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "task_records_tasks_task_records", - Columns: []*schema.Column{TaskRecordsColumns[7]}, + Columns: []*schema.Column{TaskRecordsColumns[9]}, RefColumns: []*schema.Column{TasksColumns[0]}, OnDelete: schema.SetNull, }, diff --git a/backend/db/mutation.go b/backend/db/mutation.go index bc77a7e..9cfdc87 100644 --- a/backend/db/mutation.go +++ b/backend/db/mutation.go @@ -9876,6 +9876,7 @@ type TaskMutation struct { addinput_tokens *int64 output_tokens *int64 addoutput_tokens *int64 + is_suggested *bool created_at *time.Time updated_at *time.Time clearedFields map[string]struct{} @@ -10607,6 +10608,42 @@ func (m *TaskMutation) ResetOutputTokens() { delete(m.clearedFields, task.FieldOutputTokens) } +// SetIsSuggested sets the "is_suggested" field. +func (m *TaskMutation) SetIsSuggested(b bool) { + m.is_suggested = &b +} + +// IsSuggested returns the value of the "is_suggested" field in the mutation. +func (m *TaskMutation) IsSuggested() (r bool, exists bool) { + v := m.is_suggested + if v == nil { + return + } + return *v, true +} + +// OldIsSuggested returns the old "is_suggested" field's value of the Task entity. +// If the Task object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TaskMutation) OldIsSuggested(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIsSuggested is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIsSuggested requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIsSuggested: %w", err) + } + return oldValue.IsSuggested, nil +} + +// ResetIsSuggested resets all changes to the "is_suggested" field. +func (m *TaskMutation) ResetIsSuggested() { + m.is_suggested = nil +} + // SetCreatedAt sets the "created_at" field. func (m *TaskMutation) SetCreatedAt(t time.Time) { m.created_at = &t @@ -10821,7 +10858,7 @@ func (m *TaskMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *TaskMutation) Fields() []string { - fields := make([]string, 0, 14) + fields := make([]string, 0, 15) if m.task_id != nil { fields = append(fields, task.FieldTaskID) } @@ -10858,6 +10895,9 @@ func (m *TaskMutation) Fields() []string { if m.output_tokens != nil { fields = append(fields, task.FieldOutputTokens) } + if m.is_suggested != nil { + fields = append(fields, task.FieldIsSuggested) + } if m.created_at != nil { fields = append(fields, task.FieldCreatedAt) } @@ -10896,6 +10936,8 @@ func (m *TaskMutation) Field(name string) (ent.Value, bool) { return m.InputTokens() case task.FieldOutputTokens: return m.OutputTokens() + case task.FieldIsSuggested: + return m.IsSuggested() case task.FieldCreatedAt: return m.CreatedAt() case task.FieldUpdatedAt: @@ -10933,6 +10975,8 @@ func (m *TaskMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldInputTokens(ctx) case task.FieldOutputTokens: return m.OldOutputTokens(ctx) + case task.FieldIsSuggested: + return m.OldIsSuggested(ctx) case task.FieldCreatedAt: return m.OldCreatedAt(ctx) case task.FieldUpdatedAt: @@ -11030,6 +11074,13 @@ func (m *TaskMutation) SetField(name string, value ent.Value) error { } m.SetOutputTokens(v) return nil + case task.FieldIsSuggested: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIsSuggested(v) + return nil case task.FieldCreatedAt: v, ok := value.(time.Time) if !ok { @@ -11225,6 +11276,9 @@ func (m *TaskMutation) ResetField(name string) error { case task.FieldOutputTokens: m.ResetOutputTokens() return nil + case task.FieldIsSuggested: + m.ResetIsSuggested() + return nil case task.FieldCreatedAt: m.ResetCreatedAt() return nil @@ -11366,6 +11420,9 @@ type TaskRecordMutation struct { completion *string output_tokens *int64 addoutput_tokens *int64 + code_lines *int64 + addcode_lines *int64 + code *string created_at *time.Time updated_at *time.Time clearedFields map[string]struct{} @@ -11706,6 +11763,111 @@ func (m *TaskRecordMutation) ResetOutputTokens() { m.addoutput_tokens = nil } +// SetCodeLines sets the "code_lines" field. +func (m *TaskRecordMutation) SetCodeLines(i int64) { + m.code_lines = &i + m.addcode_lines = nil +} + +// CodeLines returns the value of the "code_lines" field in the mutation. +func (m *TaskRecordMutation) CodeLines() (r int64, exists bool) { + v := m.code_lines + if v == nil { + return + } + return *v, true +} + +// OldCodeLines returns the old "code_lines" field's value of the TaskRecord entity. +// If the TaskRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TaskRecordMutation) OldCodeLines(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCodeLines is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCodeLines requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCodeLines: %w", err) + } + return oldValue.CodeLines, nil +} + +// AddCodeLines adds i to the "code_lines" field. +func (m *TaskRecordMutation) AddCodeLines(i int64) { + if m.addcode_lines != nil { + *m.addcode_lines += i + } else { + m.addcode_lines = &i + } +} + +// AddedCodeLines returns the value that was added to the "code_lines" field in this mutation. +func (m *TaskRecordMutation) AddedCodeLines() (r int64, exists bool) { + v := m.addcode_lines + if v == nil { + return + } + return *v, true +} + +// ResetCodeLines resets all changes to the "code_lines" field. +func (m *TaskRecordMutation) ResetCodeLines() { + m.code_lines = nil + m.addcode_lines = nil +} + +// SetCode sets the "code" field. +func (m *TaskRecordMutation) SetCode(s string) { + m.code = &s +} + +// Code returns the value of the "code" field in the mutation. +func (m *TaskRecordMutation) Code() (r string, exists bool) { + v := m.code + if v == nil { + return + } + return *v, true +} + +// OldCode returns the old "code" field's value of the TaskRecord entity. +// If the TaskRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TaskRecordMutation) OldCode(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCode is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCode requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCode: %w", err) + } + return oldValue.Code, nil +} + +// ClearCode clears the value of the "code" field. +func (m *TaskRecordMutation) ClearCode() { + m.code = nil + m.clearedFields[taskrecord.FieldCode] = struct{}{} +} + +// CodeCleared returns if the "code" field was cleared in this mutation. +func (m *TaskRecordMutation) CodeCleared() bool { + _, ok := m.clearedFields[taskrecord.FieldCode] + return ok +} + +// ResetCode resets all changes to the "code" field. +func (m *TaskRecordMutation) ResetCode() { + m.code = nil + delete(m.clearedFields, taskrecord.FieldCode) +} + // SetCreatedAt sets the "created_at" field. func (m *TaskRecordMutation) SetCreatedAt(t time.Time) { m.created_at = &t @@ -11839,7 +12001,7 @@ func (m *TaskRecordMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *TaskRecordMutation) Fields() []string { - fields := make([]string, 0, 7) + fields := make([]string, 0, 9) if m.task != nil { fields = append(fields, taskrecord.FieldTaskID) } @@ -11855,6 +12017,12 @@ func (m *TaskRecordMutation) Fields() []string { if m.output_tokens != nil { fields = append(fields, taskrecord.FieldOutputTokens) } + if m.code_lines != nil { + fields = append(fields, taskrecord.FieldCodeLines) + } + if m.code != nil { + fields = append(fields, taskrecord.FieldCode) + } if m.created_at != nil { fields = append(fields, taskrecord.FieldCreatedAt) } @@ -11879,6 +12047,10 @@ func (m *TaskRecordMutation) Field(name string) (ent.Value, bool) { return m.Completion() case taskrecord.FieldOutputTokens: return m.OutputTokens() + case taskrecord.FieldCodeLines: + return m.CodeLines() + case taskrecord.FieldCode: + return m.Code() case taskrecord.FieldCreatedAt: return m.CreatedAt() case taskrecord.FieldUpdatedAt: @@ -11902,6 +12074,10 @@ func (m *TaskRecordMutation) OldField(ctx context.Context, name string) (ent.Val return m.OldCompletion(ctx) case taskrecord.FieldOutputTokens: return m.OldOutputTokens(ctx) + case taskrecord.FieldCodeLines: + return m.OldCodeLines(ctx) + case taskrecord.FieldCode: + return m.OldCode(ctx) case taskrecord.FieldCreatedAt: return m.OldCreatedAt(ctx) case taskrecord.FieldUpdatedAt: @@ -11950,6 +12126,20 @@ func (m *TaskRecordMutation) SetField(name string, value ent.Value) error { } m.SetOutputTokens(v) return nil + case taskrecord.FieldCodeLines: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCodeLines(v) + return nil + case taskrecord.FieldCode: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCode(v) + return nil case taskrecord.FieldCreatedAt: v, ok := value.(time.Time) if !ok { @@ -11975,6 +12165,9 @@ func (m *TaskRecordMutation) AddedFields() []string { if m.addoutput_tokens != nil { fields = append(fields, taskrecord.FieldOutputTokens) } + if m.addcode_lines != nil { + fields = append(fields, taskrecord.FieldCodeLines) + } return fields } @@ -11985,6 +12178,8 @@ func (m *TaskRecordMutation) AddedField(name string) (ent.Value, bool) { switch name { case taskrecord.FieldOutputTokens: return m.AddedOutputTokens() + case taskrecord.FieldCodeLines: + return m.AddedCodeLines() } return nil, false } @@ -12001,6 +12196,13 @@ func (m *TaskRecordMutation) AddField(name string, value ent.Value) error { } m.AddOutputTokens(v) return nil + case taskrecord.FieldCodeLines: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddCodeLines(v) + return nil } return fmt.Errorf("unknown TaskRecord numeric field %s", name) } @@ -12015,6 +12217,9 @@ func (m *TaskRecordMutation) ClearedFields() []string { if m.FieldCleared(taskrecord.FieldPrompt) { fields = append(fields, taskrecord.FieldPrompt) } + if m.FieldCleared(taskrecord.FieldCode) { + fields = append(fields, taskrecord.FieldCode) + } return fields } @@ -12035,6 +12240,9 @@ func (m *TaskRecordMutation) ClearField(name string) error { case taskrecord.FieldPrompt: m.ClearPrompt() return nil + case taskrecord.FieldCode: + m.ClearCode() + return nil } return fmt.Errorf("unknown TaskRecord nullable field %s", name) } @@ -12058,6 +12266,12 @@ func (m *TaskRecordMutation) ResetField(name string) error { case taskrecord.FieldOutputTokens: m.ResetOutputTokens() return nil + case taskrecord.FieldCodeLines: + m.ResetCodeLines() + return nil + case taskrecord.FieldCode: + m.ResetCode() + return nil case taskrecord.FieldCreatedAt: m.ResetCreatedAt() return nil diff --git a/backend/db/runtime/runtime.go b/backend/db/runtime/runtime.go index 3d2d22a..8131d59 100644 --- a/backend/db/runtime/runtime.go +++ b/backend/db/runtime/runtime.go @@ -256,12 +256,16 @@ func init() { taskDescIsAccept := taskFields[6].Descriptor() // task.DefaultIsAccept holds the default value on creation for the is_accept field. task.DefaultIsAccept = taskDescIsAccept.Default.(bool) + // taskDescIsSuggested is the schema descriptor for is_suggested field. + taskDescIsSuggested := taskFields[13].Descriptor() + // task.DefaultIsSuggested holds the default value on creation for the is_suggested field. + task.DefaultIsSuggested = taskDescIsSuggested.Default.(bool) // taskDescCreatedAt is the schema descriptor for created_at field. - taskDescCreatedAt := taskFields[13].Descriptor() + taskDescCreatedAt := taskFields[14].Descriptor() // task.DefaultCreatedAt holds the default value on creation for the created_at field. task.DefaultCreatedAt = taskDescCreatedAt.Default.(func() time.Time) // taskDescUpdatedAt is the schema descriptor for updated_at field. - taskDescUpdatedAt := taskFields[14].Descriptor() + taskDescUpdatedAt := taskFields[15].Descriptor() // task.DefaultUpdatedAt holds the default value on creation for the updated_at field. task.DefaultUpdatedAt = taskDescUpdatedAt.Default.(func() time.Time) // task.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. @@ -269,11 +273,11 @@ func init() { taskrecordFields := schema.TaskRecord{}.Fields() _ = taskrecordFields // taskrecordDescCreatedAt is the schema descriptor for created_at field. - taskrecordDescCreatedAt := taskrecordFields[6].Descriptor() + taskrecordDescCreatedAt := taskrecordFields[8].Descriptor() // taskrecord.DefaultCreatedAt holds the default value on creation for the created_at field. taskrecord.DefaultCreatedAt = taskrecordDescCreatedAt.Default.(func() time.Time) // taskrecordDescUpdatedAt is the schema descriptor for updated_at field. - taskrecordDescUpdatedAt := taskrecordFields[7].Descriptor() + taskrecordDescUpdatedAt := taskrecordFields[9].Descriptor() // taskrecord.DefaultUpdatedAt holds the default value on creation for the updated_at field. taskrecord.DefaultUpdatedAt = taskrecordDescUpdatedAt.Default.(func() time.Time) // taskrecord.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. diff --git a/backend/db/task.go b/backend/db/task.go index 043c9c3..d9902b5 100644 --- a/backend/db/task.go +++ b/backend/db/task.go @@ -45,6 +45,8 @@ type Task struct { InputTokens int64 `json:"input_tokens,omitempty"` // OutputTokens holds the value of the "output_tokens" field. OutputTokens int64 `json:"output_tokens,omitempty"` + // IsSuggested holds the value of the "is_suggested" field. + IsSuggested bool `json:"is_suggested,omitempty"` // CreatedAt holds the value of the "created_at" field. CreatedAt time.Time `json:"created_at,omitempty"` // UpdatedAt holds the value of the "updated_at" field. @@ -104,7 +106,7 @@ func (*Task) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case task.FieldIsAccept: + case task.FieldIsAccept, task.FieldIsSuggested: values[i] = new(sql.NullBool) case task.FieldCodeLines, task.FieldInputTokens, task.FieldOutputTokens: values[i] = new(sql.NullInt64) @@ -207,6 +209,12 @@ func (t *Task) assignValues(columns []string, values []any) error { } else if value.Valid { t.OutputTokens = value.Int64 } + case task.FieldIsSuggested: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field is_suggested", values[i]) + } else if value.Valid { + t.IsSuggested = value.Bool + } case task.FieldCreatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) @@ -306,6 +314,9 @@ func (t *Task) String() string { builder.WriteString("output_tokens=") builder.WriteString(fmt.Sprintf("%v", t.OutputTokens)) builder.WriteString(", ") + builder.WriteString("is_suggested=") + builder.WriteString(fmt.Sprintf("%v", t.IsSuggested)) + builder.WriteString(", ") builder.WriteString("created_at=") builder.WriteString(t.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") diff --git a/backend/db/task/task.go b/backend/db/task/task.go index fed0b73..cba7158 100644 --- a/backend/db/task/task.go +++ b/backend/db/task/task.go @@ -38,6 +38,8 @@ const ( FieldInputTokens = "input_tokens" // FieldOutputTokens holds the string denoting the output_tokens field in the database. FieldOutputTokens = "output_tokens" + // FieldIsSuggested holds the string denoting the is_suggested field in the database. + FieldIsSuggested = "is_suggested" // FieldCreatedAt holds the string denoting the created_at field in the database. FieldCreatedAt = "created_at" // FieldUpdatedAt holds the string denoting the updated_at field in the database. @@ -88,6 +90,7 @@ var Columns = []string{ FieldCodeLines, FieldInputTokens, FieldOutputTokens, + FieldIsSuggested, FieldCreatedAt, FieldUpdatedAt, } @@ -105,6 +108,8 @@ func ValidColumn(column string) bool { var ( // DefaultIsAccept holds the default value on creation for the "is_accept" field. DefaultIsAccept bool + // DefaultIsSuggested holds the default value on creation for the "is_suggested" field. + DefaultIsSuggested bool // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. @@ -181,6 +186,11 @@ func ByOutputTokens(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldOutputTokens, opts...).ToFunc() } +// ByIsSuggested orders the results by the is_suggested field. +func ByIsSuggested(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIsSuggested, opts...).ToFunc() +} + // ByCreatedAt orders the results by the created_at field. func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() diff --git a/backend/db/task/where.go b/backend/db/task/where.go index 5321211..5f37f73 100644 --- a/backend/db/task/where.go +++ b/backend/db/task/where.go @@ -118,6 +118,11 @@ func OutputTokens(v int64) predicate.Task { return predicate.Task(sql.FieldEQ(FieldOutputTokens, v)) } +// IsSuggested applies equality check predicate on the "is_suggested" field. It's identical to IsSuggestedEQ. +func IsSuggested(v bool) predicate.Task { + return predicate.Task(sql.FieldEQ(FieldIsSuggested, v)) +} + // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.Task { return predicate.Task(sql.FieldEQ(FieldCreatedAt, v)) @@ -797,6 +802,16 @@ func OutputTokensNotNil() predicate.Task { return predicate.Task(sql.FieldNotNull(FieldOutputTokens)) } +// IsSuggestedEQ applies the EQ predicate on the "is_suggested" field. +func IsSuggestedEQ(v bool) predicate.Task { + return predicate.Task(sql.FieldEQ(FieldIsSuggested, v)) +} + +// IsSuggestedNEQ applies the NEQ predicate on the "is_suggested" field. +func IsSuggestedNEQ(v bool) predicate.Task { + return predicate.Task(sql.FieldNEQ(FieldIsSuggested, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Task { return predicate.Task(sql.FieldEQ(FieldCreatedAt, v)) diff --git a/backend/db/task_create.go b/backend/db/task_create.go index adcfb3d..50d4335 100644 --- a/backend/db/task_create.go +++ b/backend/db/task_create.go @@ -180,6 +180,20 @@ func (tc *TaskCreate) SetNillableOutputTokens(i *int64) *TaskCreate { return tc } +// SetIsSuggested sets the "is_suggested" field. +func (tc *TaskCreate) SetIsSuggested(b bool) *TaskCreate { + tc.mutation.SetIsSuggested(b) + return tc +} + +// SetNillableIsSuggested sets the "is_suggested" field if the given value is not nil. +func (tc *TaskCreate) SetNillableIsSuggested(b *bool) *TaskCreate { + if b != nil { + tc.SetIsSuggested(*b) + } + return tc +} + // SetCreatedAt sets the "created_at" field. func (tc *TaskCreate) SetCreatedAt(t time.Time) *TaskCreate { tc.mutation.SetCreatedAt(t) @@ -278,6 +292,10 @@ func (tc *TaskCreate) defaults() { v := task.DefaultIsAccept tc.mutation.SetIsAccept(v) } + if _, ok := tc.mutation.IsSuggested(); !ok { + v := task.DefaultIsSuggested + tc.mutation.SetIsSuggested(v) + } if _, ok := tc.mutation.CreatedAt(); !ok { v := task.DefaultCreatedAt() tc.mutation.SetCreatedAt(v) @@ -299,6 +317,9 @@ func (tc *TaskCreate) check() error { if _, ok := tc.mutation.IsAccept(); !ok { return &ValidationError{Name: "is_accept", err: errors.New(`db: missing required field "Task.is_accept"`)} } + if _, ok := tc.mutation.IsSuggested(); !ok { + return &ValidationError{Name: "is_suggested", err: errors.New(`db: missing required field "Task.is_suggested"`)} + } if _, ok := tc.mutation.CreatedAt(); !ok { return &ValidationError{Name: "created_at", err: errors.New(`db: missing required field "Task.created_at"`)} } @@ -381,6 +402,10 @@ func (tc *TaskCreate) createSpec() (*Task, *sqlgraph.CreateSpec) { _spec.SetField(task.FieldOutputTokens, field.TypeInt64, value) _node.OutputTokens = value } + if value, ok := tc.mutation.IsSuggested(); ok { + _spec.SetField(task.FieldIsSuggested, field.TypeBool, value) + _node.IsSuggested = value + } if value, ok := tc.mutation.CreatedAt(); ok { _spec.SetField(task.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = value @@ -707,6 +732,18 @@ func (u *TaskUpsert) ClearOutputTokens() *TaskUpsert { return u } +// SetIsSuggested sets the "is_suggested" field. +func (u *TaskUpsert) SetIsSuggested(v bool) *TaskUpsert { + u.Set(task.FieldIsSuggested, v) + return u +} + +// UpdateIsSuggested sets the "is_suggested" field to the value that was provided on create. +func (u *TaskUpsert) UpdateIsSuggested() *TaskUpsert { + u.SetExcluded(task.FieldIsSuggested) + return u +} + // SetCreatedAt sets the "created_at" field. func (u *TaskUpsert) SetCreatedAt(v time.Time) *TaskUpsert { u.Set(task.FieldCreatedAt, v) @@ -1031,6 +1068,20 @@ func (u *TaskUpsertOne) ClearOutputTokens() *TaskUpsertOne { }) } +// SetIsSuggested sets the "is_suggested" field. +func (u *TaskUpsertOne) SetIsSuggested(v bool) *TaskUpsertOne { + return u.Update(func(s *TaskUpsert) { + s.SetIsSuggested(v) + }) +} + +// UpdateIsSuggested sets the "is_suggested" field to the value that was provided on create. +func (u *TaskUpsertOne) UpdateIsSuggested() *TaskUpsertOne { + return u.Update(func(s *TaskUpsert) { + s.UpdateIsSuggested() + }) +} + // SetCreatedAt sets the "created_at" field. func (u *TaskUpsertOne) SetCreatedAt(v time.Time) *TaskUpsertOne { return u.Update(func(s *TaskUpsert) { @@ -1526,6 +1577,20 @@ func (u *TaskUpsertBulk) ClearOutputTokens() *TaskUpsertBulk { }) } +// SetIsSuggested sets the "is_suggested" field. +func (u *TaskUpsertBulk) SetIsSuggested(v bool) *TaskUpsertBulk { + return u.Update(func(s *TaskUpsert) { + s.SetIsSuggested(v) + }) +} + +// UpdateIsSuggested sets the "is_suggested" field to the value that was provided on create. +func (u *TaskUpsertBulk) UpdateIsSuggested() *TaskUpsertBulk { + return u.Update(func(s *TaskUpsert) { + s.UpdateIsSuggested() + }) +} + // SetCreatedAt sets the "created_at" field. func (u *TaskUpsertBulk) SetCreatedAt(v time.Time) *TaskUpsertBulk { return u.Update(func(s *TaskUpsert) { diff --git a/backend/db/task_update.go b/backend/db/task_update.go index dfddaa9..d202867 100644 --- a/backend/db/task_update.go +++ b/backend/db/task_update.go @@ -277,6 +277,20 @@ func (tu *TaskUpdate) ClearOutputTokens() *TaskUpdate { return tu } +// SetIsSuggested sets the "is_suggested" field. +func (tu *TaskUpdate) SetIsSuggested(b bool) *TaskUpdate { + tu.mutation.SetIsSuggested(b) + return tu +} + +// SetNillableIsSuggested sets the "is_suggested" field if the given value is not nil. +func (tu *TaskUpdate) SetNillableIsSuggested(b *bool) *TaskUpdate { + if b != nil { + tu.SetIsSuggested(*b) + } + return tu +} + // SetCreatedAt sets the "created_at" field. func (tu *TaskUpdate) SetCreatedAt(t time.Time) *TaskUpdate { tu.mutation.SetCreatedAt(t) @@ -471,6 +485,9 @@ func (tu *TaskUpdate) sqlSave(ctx context.Context) (n int, err error) { if tu.mutation.OutputTokensCleared() { _spec.ClearField(task.FieldOutputTokens, field.TypeInt64) } + if value, ok := tu.mutation.IsSuggested(); ok { + _spec.SetField(task.FieldIsSuggested, field.TypeBool, value) + } if value, ok := tu.mutation.CreatedAt(); ok { _spec.SetField(task.FieldCreatedAt, field.TypeTime, value) } @@ -845,6 +862,20 @@ func (tuo *TaskUpdateOne) ClearOutputTokens() *TaskUpdateOne { return tuo } +// SetIsSuggested sets the "is_suggested" field. +func (tuo *TaskUpdateOne) SetIsSuggested(b bool) *TaskUpdateOne { + tuo.mutation.SetIsSuggested(b) + return tuo +} + +// SetNillableIsSuggested sets the "is_suggested" field if the given value is not nil. +func (tuo *TaskUpdateOne) SetNillableIsSuggested(b *bool) *TaskUpdateOne { + if b != nil { + tuo.SetIsSuggested(*b) + } + return tuo +} + // SetCreatedAt sets the "created_at" field. func (tuo *TaskUpdateOne) SetCreatedAt(t time.Time) *TaskUpdateOne { tuo.mutation.SetCreatedAt(t) @@ -1069,6 +1100,9 @@ func (tuo *TaskUpdateOne) sqlSave(ctx context.Context) (_node *Task, err error) if tuo.mutation.OutputTokensCleared() { _spec.ClearField(task.FieldOutputTokens, field.TypeInt64) } + if value, ok := tuo.mutation.IsSuggested(); ok { + _spec.SetField(task.FieldIsSuggested, field.TypeBool, value) + } if value, ok := tuo.mutation.CreatedAt(); ok { _spec.SetField(task.FieldCreatedAt, field.TypeTime, value) } diff --git a/backend/db/taskrecord.go b/backend/db/taskrecord.go index 88efe18..6a270bd 100644 --- a/backend/db/taskrecord.go +++ b/backend/db/taskrecord.go @@ -30,6 +30,10 @@ type TaskRecord struct { Completion string `json:"completion,omitempty"` // OutputTokens holds the value of the "output_tokens" field. OutputTokens int64 `json:"output_tokens,omitempty"` + // CodeLines holds the value of the "code_lines" field. + CodeLines int64 `json:"code_lines,omitempty"` + // Code holds the value of the "code" field. + Code string `json:"code,omitempty"` // CreatedAt holds the value of the "created_at" field. CreatedAt time.Time `json:"created_at,omitempty"` // UpdatedAt holds the value of the "updated_at" field. @@ -65,9 +69,9 @@ func (*TaskRecord) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case taskrecord.FieldOutputTokens: + case taskrecord.FieldOutputTokens, taskrecord.FieldCodeLines: values[i] = new(sql.NullInt64) - case taskrecord.FieldPrompt, taskrecord.FieldRole, taskrecord.FieldCompletion: + case taskrecord.FieldPrompt, taskrecord.FieldRole, taskrecord.FieldCompletion, taskrecord.FieldCode: values[i] = new(sql.NullString) case taskrecord.FieldCreatedAt, taskrecord.FieldUpdatedAt: values[i] = new(sql.NullTime) @@ -124,6 +128,18 @@ func (tr *TaskRecord) assignValues(columns []string, values []any) error { } else if value.Valid { tr.OutputTokens = value.Int64 } + case taskrecord.FieldCodeLines: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field code_lines", values[i]) + } else if value.Valid { + tr.CodeLines = value.Int64 + } + case taskrecord.FieldCode: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field code", values[i]) + } else if value.Valid { + tr.Code = value.String + } case taskrecord.FieldCreatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) @@ -192,6 +208,12 @@ func (tr *TaskRecord) String() string { builder.WriteString("output_tokens=") builder.WriteString(fmt.Sprintf("%v", tr.OutputTokens)) builder.WriteString(", ") + builder.WriteString("code_lines=") + builder.WriteString(fmt.Sprintf("%v", tr.CodeLines)) + builder.WriteString(", ") + builder.WriteString("code=") + builder.WriteString(tr.Code) + builder.WriteString(", ") builder.WriteString("created_at=") builder.WriteString(tr.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") diff --git a/backend/db/taskrecord/taskrecord.go b/backend/db/taskrecord/taskrecord.go index bca05dd..e290321 100644 --- a/backend/db/taskrecord/taskrecord.go +++ b/backend/db/taskrecord/taskrecord.go @@ -24,6 +24,10 @@ const ( FieldCompletion = "completion" // FieldOutputTokens holds the string denoting the output_tokens field in the database. FieldOutputTokens = "output_tokens" + // FieldCodeLines holds the string denoting the code_lines field in the database. + FieldCodeLines = "code_lines" + // FieldCode holds the string denoting the code field in the database. + FieldCode = "code" // FieldCreatedAt holds the string denoting the created_at field in the database. FieldCreatedAt = "created_at" // FieldUpdatedAt holds the string denoting the updated_at field in the database. @@ -49,6 +53,8 @@ var Columns = []string{ FieldRole, FieldCompletion, FieldOutputTokens, + FieldCodeLines, + FieldCode, FieldCreatedAt, FieldUpdatedAt, } @@ -105,6 +111,16 @@ func ByOutputTokens(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldOutputTokens, opts...).ToFunc() } +// ByCodeLines orders the results by the code_lines field. +func ByCodeLines(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCodeLines, opts...).ToFunc() +} + +// ByCode orders the results by the code field. +func ByCode(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCode, opts...).ToFunc() +} + // ByCreatedAt orders the results by the created_at field. func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() diff --git a/backend/db/taskrecord/where.go b/backend/db/taskrecord/where.go index 38a4a74..a8736f8 100644 --- a/backend/db/taskrecord/where.go +++ b/backend/db/taskrecord/where.go @@ -83,6 +83,16 @@ func OutputTokens(v int64) predicate.TaskRecord { return predicate.TaskRecord(sql.FieldEQ(FieldOutputTokens, v)) } +// CodeLines applies equality check predicate on the "code_lines" field. It's identical to CodeLinesEQ. +func CodeLines(v int64) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldEQ(FieldCodeLines, v)) +} + +// Code applies equality check predicate on the "code" field. It's identical to CodeEQ. +func Code(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldEQ(FieldCode, v)) +} + // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.TaskRecord { return predicate.TaskRecord(sql.FieldEQ(FieldCreatedAt, v)) @@ -387,6 +397,121 @@ func OutputTokensLTE(v int64) predicate.TaskRecord { return predicate.TaskRecord(sql.FieldLTE(FieldOutputTokens, v)) } +// CodeLinesEQ applies the EQ predicate on the "code_lines" field. +func CodeLinesEQ(v int64) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldEQ(FieldCodeLines, v)) +} + +// CodeLinesNEQ applies the NEQ predicate on the "code_lines" field. +func CodeLinesNEQ(v int64) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldNEQ(FieldCodeLines, v)) +} + +// CodeLinesIn applies the In predicate on the "code_lines" field. +func CodeLinesIn(vs ...int64) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldIn(FieldCodeLines, vs...)) +} + +// CodeLinesNotIn applies the NotIn predicate on the "code_lines" field. +func CodeLinesNotIn(vs ...int64) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldNotIn(FieldCodeLines, vs...)) +} + +// CodeLinesGT applies the GT predicate on the "code_lines" field. +func CodeLinesGT(v int64) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldGT(FieldCodeLines, v)) +} + +// CodeLinesGTE applies the GTE predicate on the "code_lines" field. +func CodeLinesGTE(v int64) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldGTE(FieldCodeLines, v)) +} + +// CodeLinesLT applies the LT predicate on the "code_lines" field. +func CodeLinesLT(v int64) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldLT(FieldCodeLines, v)) +} + +// CodeLinesLTE applies the LTE predicate on the "code_lines" field. +func CodeLinesLTE(v int64) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldLTE(FieldCodeLines, v)) +} + +// CodeEQ applies the EQ predicate on the "code" field. +func CodeEQ(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldEQ(FieldCode, v)) +} + +// CodeNEQ applies the NEQ predicate on the "code" field. +func CodeNEQ(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldNEQ(FieldCode, v)) +} + +// CodeIn applies the In predicate on the "code" field. +func CodeIn(vs ...string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldIn(FieldCode, vs...)) +} + +// CodeNotIn applies the NotIn predicate on the "code" field. +func CodeNotIn(vs ...string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldNotIn(FieldCode, vs...)) +} + +// CodeGT applies the GT predicate on the "code" field. +func CodeGT(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldGT(FieldCode, v)) +} + +// CodeGTE applies the GTE predicate on the "code" field. +func CodeGTE(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldGTE(FieldCode, v)) +} + +// CodeLT applies the LT predicate on the "code" field. +func CodeLT(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldLT(FieldCode, v)) +} + +// CodeLTE applies the LTE predicate on the "code" field. +func CodeLTE(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldLTE(FieldCode, v)) +} + +// CodeContains applies the Contains predicate on the "code" field. +func CodeContains(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldContains(FieldCode, v)) +} + +// CodeHasPrefix applies the HasPrefix predicate on the "code" field. +func CodeHasPrefix(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldHasPrefix(FieldCode, v)) +} + +// CodeHasSuffix applies the HasSuffix predicate on the "code" field. +func CodeHasSuffix(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldHasSuffix(FieldCode, v)) +} + +// CodeIsNil applies the IsNil predicate on the "code" field. +func CodeIsNil() predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldIsNull(FieldCode)) +} + +// CodeNotNil applies the NotNil predicate on the "code" field. +func CodeNotNil() predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldNotNull(FieldCode)) +} + +// CodeEqualFold applies the EqualFold predicate on the "code" field. +func CodeEqualFold(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldEqualFold(FieldCode, v)) +} + +// CodeContainsFold applies the ContainsFold predicate on the "code" field. +func CodeContainsFold(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldContainsFold(FieldCode, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.TaskRecord { return predicate.TaskRecord(sql.FieldEQ(FieldCreatedAt, v)) diff --git a/backend/db/taskrecord_create.go b/backend/db/taskrecord_create.go index aa78db4..e7b89bc 100644 --- a/backend/db/taskrecord_create.go +++ b/backend/db/taskrecord_create.go @@ -72,6 +72,26 @@ func (trc *TaskRecordCreate) SetOutputTokens(i int64) *TaskRecordCreate { return trc } +// SetCodeLines sets the "code_lines" field. +func (trc *TaskRecordCreate) SetCodeLines(i int64) *TaskRecordCreate { + trc.mutation.SetCodeLines(i) + return trc +} + +// SetCode sets the "code" field. +func (trc *TaskRecordCreate) SetCode(s string) *TaskRecordCreate { + trc.mutation.SetCode(s) + return trc +} + +// SetNillableCode sets the "code" field if the given value is not nil. +func (trc *TaskRecordCreate) SetNillableCode(s *string) *TaskRecordCreate { + if s != nil { + trc.SetCode(*s) + } + return trc +} + // SetCreatedAt sets the "created_at" field. func (trc *TaskRecordCreate) SetCreatedAt(t time.Time) *TaskRecordCreate { trc.mutation.SetCreatedAt(t) @@ -167,6 +187,9 @@ func (trc *TaskRecordCreate) check() error { if _, ok := trc.mutation.OutputTokens(); !ok { return &ValidationError{Name: "output_tokens", err: errors.New(`db: missing required field "TaskRecord.output_tokens"`)} } + if _, ok := trc.mutation.CodeLines(); !ok { + return &ValidationError{Name: "code_lines", err: errors.New(`db: missing required field "TaskRecord.code_lines"`)} + } if _, ok := trc.mutation.CreatedAt(); !ok { return &ValidationError{Name: "created_at", err: errors.New(`db: missing required field "TaskRecord.created_at"`)} } @@ -225,6 +248,14 @@ func (trc *TaskRecordCreate) createSpec() (*TaskRecord, *sqlgraph.CreateSpec) { _spec.SetField(taskrecord.FieldOutputTokens, field.TypeInt64, value) _node.OutputTokens = value } + if value, ok := trc.mutation.CodeLines(); ok { + _spec.SetField(taskrecord.FieldCodeLines, field.TypeInt64, value) + _node.CodeLines = value + } + if value, ok := trc.mutation.Code(); ok { + _spec.SetField(taskrecord.FieldCode, field.TypeString, value) + _node.Code = value + } if value, ok := trc.mutation.CreatedAt(); ok { _spec.SetField(taskrecord.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = value @@ -380,6 +411,42 @@ func (u *TaskRecordUpsert) AddOutputTokens(v int64) *TaskRecordUpsert { return u } +// SetCodeLines sets the "code_lines" field. +func (u *TaskRecordUpsert) SetCodeLines(v int64) *TaskRecordUpsert { + u.Set(taskrecord.FieldCodeLines, v) + return u +} + +// UpdateCodeLines sets the "code_lines" field to the value that was provided on create. +func (u *TaskRecordUpsert) UpdateCodeLines() *TaskRecordUpsert { + u.SetExcluded(taskrecord.FieldCodeLines) + return u +} + +// AddCodeLines adds v to the "code_lines" field. +func (u *TaskRecordUpsert) AddCodeLines(v int64) *TaskRecordUpsert { + u.Add(taskrecord.FieldCodeLines, v) + return u +} + +// SetCode sets the "code" field. +func (u *TaskRecordUpsert) SetCode(v string) *TaskRecordUpsert { + u.Set(taskrecord.FieldCode, v) + return u +} + +// UpdateCode sets the "code" field to the value that was provided on create. +func (u *TaskRecordUpsert) UpdateCode() *TaskRecordUpsert { + u.SetExcluded(taskrecord.FieldCode) + return u +} + +// ClearCode clears the value of the "code" field. +func (u *TaskRecordUpsert) ClearCode() *TaskRecordUpsert { + u.SetNull(taskrecord.FieldCode) + return u +} + // SetCreatedAt sets the "created_at" field. func (u *TaskRecordUpsert) SetCreatedAt(v time.Time) *TaskRecordUpsert { u.Set(taskrecord.FieldCreatedAt, v) @@ -543,6 +610,48 @@ func (u *TaskRecordUpsertOne) UpdateOutputTokens() *TaskRecordUpsertOne { }) } +// SetCodeLines sets the "code_lines" field. +func (u *TaskRecordUpsertOne) SetCodeLines(v int64) *TaskRecordUpsertOne { + return u.Update(func(s *TaskRecordUpsert) { + s.SetCodeLines(v) + }) +} + +// AddCodeLines adds v to the "code_lines" field. +func (u *TaskRecordUpsertOne) AddCodeLines(v int64) *TaskRecordUpsertOne { + return u.Update(func(s *TaskRecordUpsert) { + s.AddCodeLines(v) + }) +} + +// UpdateCodeLines sets the "code_lines" field to the value that was provided on create. +func (u *TaskRecordUpsertOne) UpdateCodeLines() *TaskRecordUpsertOne { + return u.Update(func(s *TaskRecordUpsert) { + s.UpdateCodeLines() + }) +} + +// SetCode sets the "code" field. +func (u *TaskRecordUpsertOne) SetCode(v string) *TaskRecordUpsertOne { + return u.Update(func(s *TaskRecordUpsert) { + s.SetCode(v) + }) +} + +// UpdateCode sets the "code" field to the value that was provided on create. +func (u *TaskRecordUpsertOne) UpdateCode() *TaskRecordUpsertOne { + return u.Update(func(s *TaskRecordUpsert) { + s.UpdateCode() + }) +} + +// ClearCode clears the value of the "code" field. +func (u *TaskRecordUpsertOne) ClearCode() *TaskRecordUpsertOne { + return u.Update(func(s *TaskRecordUpsert) { + s.ClearCode() + }) +} + // SetCreatedAt sets the "created_at" field. func (u *TaskRecordUpsertOne) SetCreatedAt(v time.Time) *TaskRecordUpsertOne { return u.Update(func(s *TaskRecordUpsert) { @@ -877,6 +986,48 @@ func (u *TaskRecordUpsertBulk) UpdateOutputTokens() *TaskRecordUpsertBulk { }) } +// SetCodeLines sets the "code_lines" field. +func (u *TaskRecordUpsertBulk) SetCodeLines(v int64) *TaskRecordUpsertBulk { + return u.Update(func(s *TaskRecordUpsert) { + s.SetCodeLines(v) + }) +} + +// AddCodeLines adds v to the "code_lines" field. +func (u *TaskRecordUpsertBulk) AddCodeLines(v int64) *TaskRecordUpsertBulk { + return u.Update(func(s *TaskRecordUpsert) { + s.AddCodeLines(v) + }) +} + +// UpdateCodeLines sets the "code_lines" field to the value that was provided on create. +func (u *TaskRecordUpsertBulk) UpdateCodeLines() *TaskRecordUpsertBulk { + return u.Update(func(s *TaskRecordUpsert) { + s.UpdateCodeLines() + }) +} + +// SetCode sets the "code" field. +func (u *TaskRecordUpsertBulk) SetCode(v string) *TaskRecordUpsertBulk { + return u.Update(func(s *TaskRecordUpsert) { + s.SetCode(v) + }) +} + +// UpdateCode sets the "code" field to the value that was provided on create. +func (u *TaskRecordUpsertBulk) UpdateCode() *TaskRecordUpsertBulk { + return u.Update(func(s *TaskRecordUpsert) { + s.UpdateCode() + }) +} + +// ClearCode clears the value of the "code" field. +func (u *TaskRecordUpsertBulk) ClearCode() *TaskRecordUpsertBulk { + return u.Update(func(s *TaskRecordUpsert) { + s.ClearCode() + }) +} + // SetCreatedAt sets the "created_at" field. func (u *TaskRecordUpsertBulk) SetCreatedAt(v time.Time) *TaskRecordUpsertBulk { return u.Update(func(s *TaskRecordUpsert) { diff --git a/backend/db/taskrecord_update.go b/backend/db/taskrecord_update.go index 4f16d34..1c5be23 100644 --- a/backend/db/taskrecord_update.go +++ b/backend/db/taskrecord_update.go @@ -121,6 +121,47 @@ func (tru *TaskRecordUpdate) AddOutputTokens(i int64) *TaskRecordUpdate { return tru } +// SetCodeLines sets the "code_lines" field. +func (tru *TaskRecordUpdate) SetCodeLines(i int64) *TaskRecordUpdate { + tru.mutation.ResetCodeLines() + tru.mutation.SetCodeLines(i) + return tru +} + +// SetNillableCodeLines sets the "code_lines" field if the given value is not nil. +func (tru *TaskRecordUpdate) SetNillableCodeLines(i *int64) *TaskRecordUpdate { + if i != nil { + tru.SetCodeLines(*i) + } + return tru +} + +// AddCodeLines adds i to the "code_lines" field. +func (tru *TaskRecordUpdate) AddCodeLines(i int64) *TaskRecordUpdate { + tru.mutation.AddCodeLines(i) + return tru +} + +// SetCode sets the "code" field. +func (tru *TaskRecordUpdate) SetCode(s string) *TaskRecordUpdate { + tru.mutation.SetCode(s) + return tru +} + +// SetNillableCode sets the "code" field if the given value is not nil. +func (tru *TaskRecordUpdate) SetNillableCode(s *string) *TaskRecordUpdate { + if s != nil { + tru.SetCode(*s) + } + return tru +} + +// ClearCode clears the value of the "code" field. +func (tru *TaskRecordUpdate) ClearCode() *TaskRecordUpdate { + tru.mutation.ClearCode() + return tru +} + // SetCreatedAt sets the "created_at" field. func (tru *TaskRecordUpdate) SetCreatedAt(t time.Time) *TaskRecordUpdate { tru.mutation.SetCreatedAt(t) @@ -226,6 +267,18 @@ func (tru *TaskRecordUpdate) sqlSave(ctx context.Context) (n int, err error) { if value, ok := tru.mutation.AddedOutputTokens(); ok { _spec.AddField(taskrecord.FieldOutputTokens, field.TypeInt64, value) } + if value, ok := tru.mutation.CodeLines(); ok { + _spec.SetField(taskrecord.FieldCodeLines, field.TypeInt64, value) + } + if value, ok := tru.mutation.AddedCodeLines(); ok { + _spec.AddField(taskrecord.FieldCodeLines, field.TypeInt64, value) + } + if value, ok := tru.mutation.Code(); ok { + _spec.SetField(taskrecord.FieldCode, field.TypeString, value) + } + if tru.mutation.CodeCleared() { + _spec.ClearField(taskrecord.FieldCode, field.TypeString) + } if value, ok := tru.mutation.CreatedAt(); ok { _spec.SetField(taskrecord.FieldCreatedAt, field.TypeTime, value) } @@ -372,6 +425,47 @@ func (truo *TaskRecordUpdateOne) AddOutputTokens(i int64) *TaskRecordUpdateOne { return truo } +// SetCodeLines sets the "code_lines" field. +func (truo *TaskRecordUpdateOne) SetCodeLines(i int64) *TaskRecordUpdateOne { + truo.mutation.ResetCodeLines() + truo.mutation.SetCodeLines(i) + return truo +} + +// SetNillableCodeLines sets the "code_lines" field if the given value is not nil. +func (truo *TaskRecordUpdateOne) SetNillableCodeLines(i *int64) *TaskRecordUpdateOne { + if i != nil { + truo.SetCodeLines(*i) + } + return truo +} + +// AddCodeLines adds i to the "code_lines" field. +func (truo *TaskRecordUpdateOne) AddCodeLines(i int64) *TaskRecordUpdateOne { + truo.mutation.AddCodeLines(i) + return truo +} + +// SetCode sets the "code" field. +func (truo *TaskRecordUpdateOne) SetCode(s string) *TaskRecordUpdateOne { + truo.mutation.SetCode(s) + return truo +} + +// SetNillableCode sets the "code" field if the given value is not nil. +func (truo *TaskRecordUpdateOne) SetNillableCode(s *string) *TaskRecordUpdateOne { + if s != nil { + truo.SetCode(*s) + } + return truo +} + +// ClearCode clears the value of the "code" field. +func (truo *TaskRecordUpdateOne) ClearCode() *TaskRecordUpdateOne { + truo.mutation.ClearCode() + return truo +} + // SetCreatedAt sets the "created_at" field. func (truo *TaskRecordUpdateOne) SetCreatedAt(t time.Time) *TaskRecordUpdateOne { truo.mutation.SetCreatedAt(t) @@ -507,6 +601,18 @@ func (truo *TaskRecordUpdateOne) sqlSave(ctx context.Context) (_node *TaskRecord if value, ok := truo.mutation.AddedOutputTokens(); ok { _spec.AddField(taskrecord.FieldOutputTokens, field.TypeInt64, value) } + if value, ok := truo.mutation.CodeLines(); ok { + _spec.SetField(taskrecord.FieldCodeLines, field.TypeInt64, value) + } + if value, ok := truo.mutation.AddedCodeLines(); ok { + _spec.AddField(taskrecord.FieldCodeLines, field.TypeInt64, value) + } + if value, ok := truo.mutation.Code(); ok { + _spec.SetField(taskrecord.FieldCode, field.TypeString, value) + } + if truo.mutation.CodeCleared() { + _spec.ClearField(taskrecord.FieldCode, field.TypeString) + } if value, ok := truo.mutation.CreatedAt(); ok { _spec.SetField(taskrecord.FieldCreatedAt, field.TypeTime, value) } diff --git a/backend/docs/swagger.json b/backend/docs/swagger.json index a68ece6..eac0ad2 100644 --- a/backend/docs/swagger.json +++ b/backend/docs/swagger.json @@ -2045,6 +2045,41 @@ } } } + }, + "/v1/report": { + "post": { + "description": "报告", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "OpenAIV1" + ], + "summary": "报告", + "operationId": "report", + "parameters": [ + { + "description": "报告请求", + "name": "param", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/domain.ReportReq" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/web.Resp" + } + } + } + } } }, "definitions": { @@ -2063,11 +2098,13 @@ "type": "string", "enum": [ "user", - "assistant" + "assistant", + "system" ], "x-enum-varnames": [ "ChatRoleUser", - "ChatRoleAssistant" + "ChatRoleAssistant", + "ChatRoleSystem" ] }, "consts.ModelProvider": { @@ -2125,6 +2162,19 @@ "ModelTypeReranker" ] }, + "consts.ReportAction": { + "type": "string", + "enum": [ + "accept", + "suggest", + "file_written" + ], + "x-enum-varnames": [ + "ReportActionAccept", + "ReportActionSuggest", + "ReportActionFileWritten" + ] + }, "consts.UserPlatform": { "type": "string", "enum": [ @@ -3125,6 +3175,26 @@ } } }, + "domain.ReportReq": { + "type": "object", + "properties": { + "action": { + "$ref": "#/definitions/consts.ReportAction" + }, + "content": { + "description": "内容", + "type": "string" + }, + "id": { + "description": "task_id or resp_id", + "type": "string" + }, + "tool": { + "description": "工具", + "type": "string" + } + } + }, "domain.Setting": { "type": "object", "properties": { @@ -3486,6 +3556,14 @@ "description": "代码行数", "type": "integer" }, + "user": { + "description": "用户信息", + "allOf": [ + { + "$ref": "#/definitions/domain.User" + } + ] + }, "username": { "description": "用户名", "type": "string" diff --git a/backend/domain/dashboard.go b/backend/domain/dashboard.go index 0a19360..99e1358 100644 --- a/backend/domain/dashboard.go +++ b/backend/domain/dashboard.go @@ -5,6 +5,7 @@ import ( "time" "github.com/chaitin/MonkeyCode/backend/db" + "github.com/chaitin/MonkeyCode/backend/pkg/cvt" ) type DashboardUsecase interface { @@ -62,6 +63,7 @@ type UserHeatmap struct { type UserCodeRank struct { Username string `json:"username"` // 用户名 Lines int64 `json:"lines"` // 代码行数 + User *User `json:"user"` // 用户信息 } func (u *UserCodeRank) From(d *db.Task) *UserCodeRank { @@ -70,6 +72,7 @@ func (u *UserCodeRank) From(d *db.Task) *UserCodeRank { } u.Username = d.Edges.User.Username u.Lines = d.CodeLines + u.User = cvt.From(d.Edges.User, &User{}) return u } diff --git a/backend/domain/proxy.go b/backend/domain/proxy.go index d344a98..24a5eb6 100644 --- a/backend/domain/proxy.go +++ b/backend/domain/proxy.go @@ -22,12 +22,14 @@ type ProxyUsecase interface { Record(ctx context.Context, record *RecordParam) error ValidateApiKey(ctx context.Context, key string) (*ApiKey, error) AcceptCompletion(ctx context.Context, req *AcceptCompletionReq) error + Report(ctx context.Context, req *ReportReq) error } type ProxyRepo interface { Record(ctx context.Context, record *RecordParam) error UpdateByTaskID(ctx context.Context, taskID string, fn func(*db.TaskUpdateOne)) error AcceptCompletion(ctx context.Context, req *AcceptCompletionReq) error + Report(ctx context.Context, req *ReportReq) error SelectModelWithLoadBalancing(modelName string, modelType consts.ModelType) (*db.Model, error) ValidateApiKey(ctx context.Context, key string) (*db.ApiKey, error) } @@ -42,6 +44,13 @@ type AcceptCompletionReq struct { Completion string `json:"completion"` // 补全内容 } +type ReportReq struct { + Action consts.ReportAction `json:"action"` + ID string `json:"id"` // task_id or resp_id + Content string `json:"content"` // 内容 + Tool string `json:"tool"` // 工具 +} + type RecordParam struct { RequestID string TaskID string @@ -57,6 +66,7 @@ type RecordParam struct { Completion string WorkMode string CodeLines int64 + Code string } func (r *RecordParam) Clone() *RecordParam { diff --git a/backend/ent/schema/task.go b/backend/ent/schema/task.go index 265385a..ae58d8f 100644 --- a/backend/ent/schema/task.go +++ b/backend/ent/schema/task.go @@ -41,6 +41,7 @@ func (Task) Fields() []ent.Field { field.Int64("code_lines").Optional(), field.Int64("input_tokens").Optional(), field.Int64("output_tokens").Optional(), + field.Bool("is_suggested").Default(false), field.Time("created_at").Default(time.Now), field.Time("updated_at").Default(time.Now).UpdateDefault(time.Now), } diff --git a/backend/ent/schema/taskrecord.go b/backend/ent/schema/taskrecord.go index edb1f3e..206fac6 100644 --- a/backend/ent/schema/taskrecord.go +++ b/backend/ent/schema/taskrecord.go @@ -35,6 +35,8 @@ func (TaskRecord) Fields() []ent.Field { field.String("role").GoType(consts.ChatRole("")), field.String("completion"), field.Int64("output_tokens"), + field.Int64("code_lines"), + field.String("code").Optional(), field.Time("created_at").Default(time.Now), field.Time("updated_at").Default(time.Now).UpdateDefault(time.Now), } diff --git a/backend/internal/billing/repo/billing.go b/backend/internal/billing/repo/billing.go index 0215e8c..0ea936f 100644 --- a/backend/internal/billing/repo/billing.go +++ b/backend/internal/billing/repo/billing.go @@ -28,6 +28,7 @@ func (b *BillingRepo) ChatInfo(ctx context.Context, id string) (*domain.ChatInfo record, err := b.db.Task.Query(). WithTaskRecords(func(trq *db.TaskRecordQuery) { trq.Order(taskrecord.ByCreatedAt(sql.OrderAsc())) + trq.Where(taskrecord.RoleNEQ(consts.ChatRoleSystem)) }). Where(task.TaskID(id)). First(ctx) @@ -114,6 +115,7 @@ func (b *BillingRepo) ListCompletionRecord(ctx context.Context, req domain.ListR trq.Order(taskrecord.ByCreatedAt(sql.OrderAsc())) }). Where(task.ModelType(consts.ModelTypeCoder)). + Where(task.IsSuggested(true)). Order(task.ByCreatedAt(sql.OrderDesc())) filterTask(q, req) diff --git a/backend/internal/dashboard/repo/dashboard.go b/backend/internal/dashboard/repo/dashboard.go index 823fcef..6aaf5b1 100644 --- a/backend/internal/dashboard/repo/dashboard.go +++ b/backend/internal/dashboard/repo/dashboard.go @@ -224,6 +224,7 @@ func (d *DashboardRepo) UserCodeRank(ctx context.Context, req domain.StatisticsF return &domain.UserCodeRank{ Username: m[v.UserID].Username, Lines: v.CodeLines, + User: cvt.From(m[v.UserID], &domain.User{}), } }), nil } diff --git a/backend/internal/openai/handler/v1/v1.go b/backend/internal/openai/handler/v1/v1.go index b1aaefc..46208bf 100644 --- a/backend/internal/openai/handler/v1/v1.go +++ b/backend/internal/openai/handler/v1/v1.go @@ -50,6 +50,7 @@ func NewV1Handler( g := w.Group("/v1", middleware.Auth()) g.GET("/models", web.BaseHandler(h.ModelList)) g.POST("/completion/accept", web.BindHandler(h.AcceptCompletion), active.Active("user")) + g.POST("/report", web.BindHandler(h.Report), active.Active("user")) g.POST("/chat/completions", web.BaseHandler(h.ChatCompletion), active.Active("user")) g.POST("/completions", web.BaseHandler(h.Completions), active.Active("user")) g.POST("/embeddings", web.BaseHandler(h.Embeddings), active.Active("user")) @@ -96,6 +97,25 @@ func (h *V1Handler) AcceptCompletion(c *web.Context, req domain.AcceptCompletion return nil } +// Report 报告 +// +// @Tags OpenAIV1 +// @Summary 报告 +// @Description 报告 +// @ID report +// @Accept json +// @Produce json +// @Param param body domain.ReportReq true "报告请求" +// @Success 200 {object} web.Resp{} +// @Router /v1/report [post] +func (h *V1Handler) Report(c *web.Context, req domain.ReportReq) error { + h.logger.DebugContext(c.Request().Context(), "Report", slog.Any("req", req)) + if err := h.proxyUse.Report(c.Request().Context(), &req); err != nil { + return err + } + return c.Success(nil) +} + // ModelList 模型列表 // // @Tags OpenAIV1 diff --git a/backend/internal/proxy/recorder.go b/backend/internal/proxy/recorder.go index 55e3e6a..8023835 100644 --- a/backend/internal/proxy/recorder.go +++ b/backend/internal/proxy/recorder.go @@ -16,6 +16,7 @@ import ( "github.com/chaitin/MonkeyCode/backend/config" "github.com/chaitin/MonkeyCode/backend/consts" "github.com/chaitin/MonkeyCode/backend/domain" + "github.com/chaitin/MonkeyCode/backend/pkg/diff" "github.com/chaitin/MonkeyCode/backend/pkg/promptparser" ) @@ -66,7 +67,7 @@ func (r *Recorder) handleShadow() { } var ( - taskID, mode, prompt, language string + taskID, mode, prompt, language, tool, code string ) switch r.ctx.Model.ModelType { @@ -76,9 +77,11 @@ func (r *Recorder) handleShadow() { r.logger.WarnContext(r.ctx.ctx, "unmarshal chat completion request failed", "error", err) return } - prompt = r.getPrompt(r.ctx.ctx, &req) + prompt = req.Metadata["prompt"] taskID = req.Metadata["task_id"] mode = req.Metadata["mode"] + tool = req.Metadata["tool"] + code = req.Metadata["code"] case consts.ModelTypeCoder: var req domain.CompletionRequest @@ -108,19 +111,22 @@ func (r *Recorder) handleShadow() { WorkMode: mode, Prompt: prompt, ProgramLanguage: language, - Role: consts.ChatRoleUser, + Role: consts.ChatRoleAssistant, + } + + switch tool { + case "appliedDiff", "editedExistingFile": + lines := diff.ParseConflictsAndCountLines(code) + for _, line := range lines { + rc.CodeLines += int64(line) + } + case "newFileCreated": + rc.CodeLines = int64(strings.Count(code, "\n")) } - var assistantRc *domain.RecordParam ct := r.ctx.RespHeader.Get("Content-Type") if strings.Contains(ct, "stream") { r.handleStream(rc) - if r.ctx.Model.ModelType == consts.ModelTypeLLM { - assistantRc = rc.Clone() - assistantRc.Role = consts.ChatRoleAssistant - rc.Completion = "" - rc.OutputTokens = 0 - } } else { r.handleJson(rc) } @@ -130,15 +136,20 @@ func (r *Recorder) handleShadow() { With("resp_header", formatHeader(r.ctx.RespHeader)). DebugContext(r.ctx.ctx, "handle shadow", "rc", rc) - if err := r.usecase.Record(context.Background(), rc); err != nil { - r.logger.WarnContext(r.ctx.ctx, "记录请求失败", "error", err) - } - - if assistantRc != nil { - if err := r.usecase.Record(context.Background(), assistantRc); err != nil { + // 记录用户的提问 + if r.ctx.Model.ModelType == consts.ModelTypeLLM && prompt != "" { + tmp := rc.Clone() + tmp.Role = consts.ChatRoleUser + tmp.Completion = "" + tmp.OutputTokens = 0 + if err := r.usecase.Record(context.Background(), tmp); err != nil { r.logger.WarnContext(r.ctx.ctx, "记录请求失败", "error", err) } } + + if err := r.usecase.Record(context.Background(), rc); err != nil { + r.logger.WarnContext(r.ctx.ctx, "记录请求失败", "error", err) + } } func (r *Recorder) writeMeta(body []byte) { diff --git a/backend/internal/proxy/repo/proxy.go b/backend/internal/proxy/repo/proxy.go index de51136..9970d8f 100644 --- a/backend/internal/proxy/repo/proxy.go +++ b/backend/internal/proxy/repo/proxy.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "strings" "time" "github.com/google/uuid" @@ -16,6 +17,7 @@ import ( "github.com/chaitin/MonkeyCode/backend/db/task" "github.com/chaitin/MonkeyCode/backend/db/taskrecord" "github.com/chaitin/MonkeyCode/backend/domain" + "github.com/chaitin/MonkeyCode/backend/pkg/diff" "github.com/chaitin/MonkeyCode/backend/pkg/entx" ) @@ -113,6 +115,9 @@ func (r *ProxyRepo) Record(ctx context.Context, record *domain.RecordParam) erro if t.InputTokens == 0 && record.InputTokens > 0 { up.SetInputTokens(record.InputTokens) } + if t.CodeLines > 0 { + up.AddCodeLines(record.CodeLines) + } if t.RequestID != record.RequestID { up.SetRequestID(record.RequestID) up.AddInputTokens(record.InputTokens) @@ -144,6 +149,8 @@ func (r *ProxyRepo) Record(ctx context.Context, record *domain.RecordParam) erro SetPrompt(record.Prompt). SetCompletion(record.Completion). SetOutputTokens(record.OutputTokens). + SetCodeLines(record.CodeLines). + SetCode(record.Code). Save(ctx) return err @@ -182,3 +189,88 @@ func (r *ProxyRepo) AcceptCompletion(ctx context.Context, req *domain.AcceptComp SetCompletion(req.Completion).Exec(ctx) }) } + +func (r *ProxyRepo) Report(ctx context.Context, req *domain.ReportReq) error { + return entx.WithTx(ctx, r.db, func(tx *db.Tx) error { + rc, err := tx.Task.Query().Where(task.TaskID(req.ID)).Only(ctx) + if err != nil { + return err + } + + switch req.Action { + case consts.ReportActionAccept: + if err := tx.Task.UpdateOneID(rc.ID). + SetIsAccept(true). + SetCompletion(req.Content). + Exec(ctx); err != nil { + return err + } + + return tx.TaskRecord.Update(). + Where(taskrecord.TaskID(rc.ID)). + SetCompletion(req.Content).Exec(ctx) + + case consts.ReportActionSuggest: + if err := tx.Task.UpdateOneID(rc.ID). + SetIsSuggested(true). + SetCompletion(req.Content). + Exec(ctx); err != nil { + return err + } + + return tx.TaskRecord.Update(). + Where(taskrecord.TaskID(rc.ID)). + SetCompletion(req.Content).Exec(ctx) + + case consts.ReportActionFileWritten: + if err := r.handleFileWritten(ctx, tx, rc, req); err != nil { + return err + } + } + + return nil + }) +} + +func (r *ProxyRepo) handleFileWritten(ctx context.Context, tx *db.Tx, rc *db.Task, req *domain.ReportReq) error { + lines := 0 + switch req.Tool { + case "appliedDiff", "editedExistingFile", "insertContent": + if strings.Contains(req.Content, "<<<<<<<") { + lines := diff.ParseConflictsAndCountLines(req.Content) + for _, line := range lines { + rc.CodeLines += int64(line) + } + } else { + rc.CodeLines = int64(strings.Count(req.Content, "\n")) + } + case "newFileCreated": + rc.CodeLines = int64(strings.Count(req.Content, "\n")) + } + + if lines > 0 { + if err := tx.Task. + UpdateOneID(rc.ID). + AddCodeLines(int64(lines)). + SetIsAccept(true). + Exec(ctx); err != nil { + return err + } + } + + if req.Content != "" { + if _, err := tx.TaskRecord.Create(). + SetTaskID(rc.ID). + SetRole(consts.ChatRoleSystem). + SetPrompt("写入文件"). + SetCompletion(""). + SetCodeLines(int64(lines)). + SetCode(req.Content). + SetOutputTokens(0). + Save(ctx); err != nil { + return err + } + } + + return nil +} diff --git a/backend/internal/proxy/usecase/proxy.go b/backend/internal/proxy/usecase/proxy.go index 065e506..65dcc7b 100644 --- a/backend/internal/proxy/usecase/proxy.go +++ b/backend/internal/proxy/usecase/proxy.go @@ -42,3 +42,7 @@ func (p *ProxyUsecase) ValidateApiKey(ctx context.Context, key string) (*domain. func (p *ProxyUsecase) AcceptCompletion(ctx context.Context, req *domain.AcceptCompletionReq) error { return p.repo.AcceptCompletion(ctx, req) } + +func (p *ProxyUsecase) Report(ctx context.Context, req *domain.ReportReq) error { + return p.repo.Report(ctx, req) +} diff --git a/backend/migration/000009_alter_task_table.down.sql b/backend/migration/000009_alter_task_table.down.sql new file mode 100644 index 0000000..e69de29 diff --git a/backend/migration/000009_alter_task_table.up.sql b/backend/migration/000009_alter_task_table.up.sql new file mode 100644 index 0000000..9ad72e9 --- /dev/null +++ b/backend/migration/000009_alter_task_table.up.sql @@ -0,0 +1,3 @@ +ALTER TABLE tasks ADD column is_suggested boolean default false; +ALTER TABLE task_records ADD column code_lines int default 0; +ALTER TABLE task_records ADD column code text; \ No newline at end of file diff --git a/backend/pkg/diff/diff.go b/backend/pkg/diff/diff.go new file mode 100644 index 0000000..769615a --- /dev/null +++ b/backend/pkg/diff/diff.go @@ -0,0 +1,157 @@ +package diff + +import ( + "strings" +) + +type ConflictBlock struct { + OursContent []string + TheirsContent []string + StartLine int + EndLine int +} + +type ConflictParser struct { + lines []string + currentLine int + conflicts []ConflictBlock +} + +func NewConflictParser(text string) *ConflictParser { + return &ConflictParser{ + lines: strings.Split(text, "\n"), + } +} + +func (p *ConflictParser) ParseConflicts() []ConflictBlock { + for p.currentLine < len(p.lines) { + line := p.lines[p.currentLine] + + if p.isConflictStart(line) { + conflict := p.parseConflictBlock() + if conflict != nil { + p.conflicts = append(p.conflicts, *conflict) + } + } else { + p.currentLine++ + } + } + return p.conflicts +} + +func (p *ConflictParser) isConflictStart(line string) bool { + trimmed := strings.TrimSpace(line) + if len(trimmed) < 7 { + return false + } + + for i := range 7 { + if trimmed[i] != '<' { + return false + } + } + return true +} + +func (p *ConflictParser) isConflictSeparator(line string) bool { + trimmed := strings.TrimSpace(line) + if len(trimmed) != 7 { + return false + } + + for i := range 7 { + if trimmed[i] != '=' { + return false + } + } + return true +} + +func (p *ConflictParser) isConflictEnd(line string) bool { + trimmed := strings.TrimSpace(line) + if len(trimmed) < 7 { + return false + } + + for i := range 7 { + if trimmed[i] != '>' { + return false + } + } + return true +} + +func (p *ConflictParser) parseConflictBlock() *ConflictBlock { + startLine := p.currentLine + p.currentLine++ + + conflict := &ConflictBlock{ + StartLine: startLine, + } + + for p.currentLine < len(p.lines) { + line := p.lines[p.currentLine] + + if p.isConflictSeparator(line) { + p.currentLine++ + break + } + + conflict.OursContent = append(conflict.OursContent, line) + p.currentLine++ + } + + for p.currentLine < len(p.lines) { + line := p.lines[p.currentLine] + + if p.isConflictEnd(line) { + conflict.EndLine = p.currentLine + p.currentLine++ + return conflict + } + + conflict.TheirsContent = append(conflict.TheirsContent, line) + p.currentLine++ + } + + return nil +} + +func CountAddedLines(text string) int { + lines := strings.Split(text, "\n") + addedLines := 0 + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if trimmed != "" && !strings.HasPrefix(trimmed, "//") && !strings.HasPrefix(trimmed, "#") && !strings.HasPrefix(trimmed, "/*") { + addedLines++ + } + } + + return addedLines +} + +func (cb *ConflictBlock) CountAddedLinesInConflict() int { + theirsText := strings.Join(cb.TheirsContent, "\n") + return CountAddedLines(theirsText) +} + +func (cb *ConflictBlock) CountNetAddedLines() int { + oursText := strings.Join(cb.OursContent, "\n") + theirsText := strings.Join(cb.TheirsContent, "\n") + oursLines := CountAddedLines(oursText) + theirsLines := CountAddedLines(theirsText) + return theirsLines - oursLines +} + +func ParseConflictsAndCountLines(text string) []int { + parser := NewConflictParser(text) + conflicts := parser.ParseConflicts() + + var addedLines []int + for _, conflict := range conflicts { + addedLines = append(addedLines, conflict.CountAddedLinesInConflict()) + } + + return addedLines +} diff --git a/backend/pkg/diff/diff_test.go b/backend/pkg/diff/diff_test.go new file mode 100644 index 0000000..d4ef4a7 --- /dev/null +++ b/backend/pkg/diff/diff_test.go @@ -0,0 +1,25 @@ +package diff + +import ( + "fmt" + "testing" +) + +func TestParseConflictsAndCountLines(t *testing.T) { + conflictText := `<<<<<<< HEAD +old line 1 +======= +new line 1 +new line 2 +>>>>>>> branch1 +normal line +<<<<<<< HEAD +old line 2 +old line 3 +======= +new line 3 +>>>>>>> branch2` + + addedLines := ParseConflictsAndCountLines(conflictText) + fmt.Println(addedLines) +}