From 7426de7ba6f7171fd0e40dcc3d8a12e7f1f1559d Mon Sep 17 00:00:00 2001 From: yokowu <18836617@qq.com> Date: Thu, 3 Jul 2025 10:15:35 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E8=AE=B0=E5=BD=95=E4=BB=BB=E5=8A=A1?= =?UTF-8?q?=E7=9A=84=E6=89=80=E6=9C=89=E9=97=AE=E7=AD=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/consts/openai.go | 7 + backend/db/client.go | 6 +- backend/db/migrate/schema.go | 10 +- backend/db/mutation.go | 279 +++++++++++++----- backend/db/runtime/runtime.go | 15 +- backend/db/task.go | 13 +- backend/db/task/task.go | 8 - backend/db/task/where.go | 80 ----- backend/db/task_create.go | 78 ----- backend/db/task_update.go | 52 ---- backend/db/taskrecord.go | 25 +- backend/db/taskrecord/taskrecord.go | 16 + backend/db/taskrecord/where.go | 171 +++++++++++ backend/db/taskrecord_create.go | 132 +++++++++ backend/db/taskrecord_update.go | 87 ++++++ backend/db/user.go | 13 +- backend/db/user/user.go | 16 + backend/db/user/where.go | 55 ++++ backend/db/user_create.go | 97 +++++- backend/db/user_query.go | 8 +- backend/db/user_update.go | 52 ++++ backend/docs/swagger.json | 34 ++- backend/domain/billing.go | 35 ++- backend/domain/proxy.go | 20 ++ backend/ent/schema/task.go | 1 - backend/ent/schema/taskrecord.go | 4 + backend/ent/schema/user.go | 7 + backend/internal/billing/repo/billing.go | 7 +- backend/internal/openai/handler/v1/v1.go | 1 - backend/internal/proxy/proxy.go | 72 +++-- backend/internal/proxy/repo/proxy.go | 20 +- .../migration/000002_create_core_table.up.sql | 6 +- backend/pkg/promptparser/promptparse.go | 23 +- backend/pkg/promptparser/promptparse_test.go | 17 ++ 34 files changed, 1095 insertions(+), 372 deletions(-) create mode 100644 backend/pkg/promptparser/promptparse_test.go diff --git a/backend/consts/openai.go b/backend/consts/openai.go index 5320032..68a1815 100644 --- a/backend/consts/openai.go +++ b/backend/consts/openai.go @@ -6,3 +6,10 @@ const ( ConfigTypeContinue ConfigType = "continue" ConfigTypeRooCode ConfigType = "roo-code" ) + +type ChatRole string + +const ( + ChatRoleUser ChatRole = "user" + ChatRoleAssistant ChatRole = "assistant" +) diff --git a/backend/db/client.go b/backend/db/client.go index 5a7ff86..820cd92 100644 --- a/backend/db/client.go +++ b/backend/db/client.go @@ -2542,12 +2542,14 @@ func (c *UserClient) QueryIdentities(u *User) *UserIdentityQuery { // Hooks returns the client hooks. func (c *UserClient) Hooks() []Hook { - return c.hooks.User + hooks := c.hooks.User + return append(hooks[:len(hooks):len(hooks)], user.Hooks[:]...) } // Interceptors returns the client interceptors. func (c *UserClient) Interceptors() []Interceptor { - return c.inters.User + inters := c.inters.User + return append(inters[:len(inters):len(inters)], user.Interceptors[:]...) } func (c *UserClient) mutate(ctx context.Context, m *UserMutation) (Value, error) { diff --git a/backend/db/migrate/schema.go b/backend/db/migrate/schema.go index b530578..a071d72 100644 --- a/backend/db/migrate/schema.go +++ b/backend/db/migrate/schema.go @@ -246,7 +246,6 @@ var ( {Name: "task_id", Type: field.TypeString, Unique: true}, {Name: "request_id", Type: field.TypeString, Nullable: true}, {Name: "model_type", Type: field.TypeString}, - {Name: "prompt", Type: field.TypeString, Nullable: true}, {Name: "is_accept", Type: field.TypeBool, Default: false}, {Name: "program_language", Type: field.TypeString, Nullable: true}, {Name: "work_mode", Type: field.TypeString, Nullable: true}, @@ -267,13 +266,13 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "tasks_models_tasks", - Columns: []*schema.Column{TasksColumns[14]}, + Columns: []*schema.Column{TasksColumns[13]}, RefColumns: []*schema.Column{ModelsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "tasks_users_tasks", - Columns: []*schema.Column{TasksColumns[15]}, + Columns: []*schema.Column{TasksColumns[14]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.SetNull, }, @@ -282,6 +281,8 @@ var ( // TaskRecordsColumns holds the columns for the "task_records" table. TaskRecordsColumns = []*schema.Column{ {Name: "id", Type: field.TypeUUID, Unique: true}, + {Name: "prompt", Type: field.TypeString, Nullable: true}, + {Name: "role", Type: field.TypeString}, {Name: "completion", Type: field.TypeString}, {Name: "output_tokens", Type: field.TypeInt64}, {Name: "created_at", Type: field.TypeTime}, @@ -296,7 +297,7 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "task_records_tasks_task_records", - Columns: []*schema.Column{TaskRecordsColumns[5]}, + Columns: []*schema.Column{TaskRecordsColumns[7]}, RefColumns: []*schema.Column{TasksColumns[0]}, OnDelete: schema.SetNull, }, @@ -305,6 +306,7 @@ var ( // UsersColumns holds the columns for the "users" table. UsersColumns = []*schema.Column{ {Name: "id", Type: field.TypeUUID}, + {Name: "deleted_at", Type: field.TypeTime, Nullable: true}, {Name: "username", Type: field.TypeString, Nullable: true}, {Name: "password", Type: field.TypeString, Nullable: true}, {Name: "email", Type: field.TypeString, Nullable: true}, diff --git a/backend/db/mutation.go b/backend/db/mutation.go index 601dc95..32c1a29 100644 --- a/backend/db/mutation.go +++ b/backend/db/mutation.go @@ -9017,7 +9017,6 @@ type TaskMutation struct { task_id *string request_id *string model_type *consts.ModelType - prompt *string is_accept *bool program_language *string work_mode *string @@ -9366,55 +9365,6 @@ func (m *TaskMutation) ResetModelType() { m.model_type = nil } -// SetPrompt sets the "prompt" field. -func (m *TaskMutation) SetPrompt(s string) { - m.prompt = &s -} - -// Prompt returns the value of the "prompt" field in the mutation. -func (m *TaskMutation) Prompt() (r string, exists bool) { - v := m.prompt - if v == nil { - return - } - return *v, true -} - -// OldPrompt returns the old "prompt" field's value of the Task entity. -// If the Task object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *TaskMutation) OldPrompt(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPrompt is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPrompt requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldPrompt: %w", err) - } - return oldValue.Prompt, nil -} - -// ClearPrompt clears the value of the "prompt" field. -func (m *TaskMutation) ClearPrompt() { - m.prompt = nil - m.clearedFields[task.FieldPrompt] = struct{}{} -} - -// PromptCleared returns if the "prompt" field was cleared in this mutation. -func (m *TaskMutation) PromptCleared() bool { - _, ok := m.clearedFields[task.FieldPrompt] - return ok -} - -// ResetPrompt resets all changes to the "prompt" field. -func (m *TaskMutation) ResetPrompt() { - m.prompt = nil - delete(m.clearedFields, task.FieldPrompt) -} - // SetIsAccept sets the "is_accept" field. func (m *TaskMutation) SetIsAccept(b bool) { m.is_accept = &b @@ -10022,7 +9972,7 @@ func (m *TaskMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *TaskMutation) Fields() []string { - fields := make([]string, 0, 15) + fields := make([]string, 0, 14) if m.task_id != nil { fields = append(fields, task.FieldTaskID) } @@ -10038,9 +9988,6 @@ func (m *TaskMutation) Fields() []string { if m.model_type != nil { fields = append(fields, task.FieldModelType) } - if m.prompt != nil { - fields = append(fields, task.FieldPrompt) - } if m.is_accept != nil { fields = append(fields, task.FieldIsAccept) } @@ -10086,8 +10033,6 @@ func (m *TaskMutation) Field(name string) (ent.Value, bool) { return m.RequestID() case task.FieldModelType: return m.ModelType() - case task.FieldPrompt: - return m.Prompt() case task.FieldIsAccept: return m.IsAccept() case task.FieldProgramLanguage: @@ -10125,8 +10070,6 @@ func (m *TaskMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldRequestID(ctx) case task.FieldModelType: return m.OldModelType(ctx) - case task.FieldPrompt: - return m.OldPrompt(ctx) case task.FieldIsAccept: return m.OldIsAccept(ctx) case task.FieldProgramLanguage: @@ -10189,13 +10132,6 @@ func (m *TaskMutation) SetField(name string, value ent.Value) error { } m.SetModelType(v) return nil - case task.FieldPrompt: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPrompt(v) - return nil case task.FieldIsAccept: v, ok := value.(bool) if !ok { @@ -10337,9 +10273,6 @@ func (m *TaskMutation) ClearedFields() []string { if m.FieldCleared(task.FieldRequestID) { fields = append(fields, task.FieldRequestID) } - if m.FieldCleared(task.FieldPrompt) { - fields = append(fields, task.FieldPrompt) - } if m.FieldCleared(task.FieldProgramLanguage) { fields = append(fields, task.FieldProgramLanguage) } @@ -10381,9 +10314,6 @@ func (m *TaskMutation) ClearField(name string) error { case task.FieldRequestID: m.ClearRequestID() return nil - case task.FieldPrompt: - m.ClearPrompt() - return nil case task.FieldProgramLanguage: m.ClearProgramLanguage() return nil @@ -10425,9 +10355,6 @@ func (m *TaskMutation) ResetField(name string) error { case task.FieldModelType: m.ResetModelType() return nil - case task.FieldPrompt: - m.ResetPrompt() - return nil case task.FieldIsAccept: m.ResetIsAccept() return nil @@ -10585,6 +10512,8 @@ type TaskRecordMutation struct { op Op typ string id *uuid.UUID + prompt *string + role *consts.ChatRole completion *string output_tokens *int64 addoutput_tokens *int64 @@ -10751,6 +10680,91 @@ func (m *TaskRecordMutation) ResetTaskID() { delete(m.clearedFields, taskrecord.FieldTaskID) } +// SetPrompt sets the "prompt" field. +func (m *TaskRecordMutation) SetPrompt(s string) { + m.prompt = &s +} + +// Prompt returns the value of the "prompt" field in the mutation. +func (m *TaskRecordMutation) Prompt() (r string, exists bool) { + v := m.prompt + if v == nil { + return + } + return *v, true +} + +// OldPrompt returns the old "prompt" field's value of the TaskRecord entity. +// If the TaskRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TaskRecordMutation) OldPrompt(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPrompt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPrompt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPrompt: %w", err) + } + return oldValue.Prompt, nil +} + +// ClearPrompt clears the value of the "prompt" field. +func (m *TaskRecordMutation) ClearPrompt() { + m.prompt = nil + m.clearedFields[taskrecord.FieldPrompt] = struct{}{} +} + +// PromptCleared returns if the "prompt" field was cleared in this mutation. +func (m *TaskRecordMutation) PromptCleared() bool { + _, ok := m.clearedFields[taskrecord.FieldPrompt] + return ok +} + +// ResetPrompt resets all changes to the "prompt" field. +func (m *TaskRecordMutation) ResetPrompt() { + m.prompt = nil + delete(m.clearedFields, taskrecord.FieldPrompt) +} + +// SetRole sets the "role" field. +func (m *TaskRecordMutation) SetRole(cr consts.ChatRole) { + m.role = &cr +} + +// Role returns the value of the "role" field in the mutation. +func (m *TaskRecordMutation) Role() (r consts.ChatRole, exists bool) { + v := m.role + if v == nil { + return + } + return *v, true +} + +// OldRole returns the old "role" field's value of the TaskRecord entity. +// If the TaskRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TaskRecordMutation) OldRole(ctx context.Context) (v consts.ChatRole, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRole is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRole requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRole: %w", err) + } + return oldValue.Role, nil +} + +// ResetRole resets all changes to the "role" field. +func (m *TaskRecordMutation) ResetRole() { + m.role = nil +} + // SetCompletion sets the "completion" field. func (m *TaskRecordMutation) SetCompletion(s string) { m.completion = &s @@ -10976,10 +10990,16 @@ func (m *TaskRecordMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *TaskRecordMutation) Fields() []string { - fields := make([]string, 0, 5) + fields := make([]string, 0, 7) if m.task != nil { fields = append(fields, taskrecord.FieldTaskID) } + if m.prompt != nil { + fields = append(fields, taskrecord.FieldPrompt) + } + if m.role != nil { + fields = append(fields, taskrecord.FieldRole) + } if m.completion != nil { fields = append(fields, taskrecord.FieldCompletion) } @@ -11002,6 +11022,10 @@ func (m *TaskRecordMutation) Field(name string) (ent.Value, bool) { switch name { case taskrecord.FieldTaskID: return m.TaskID() + case taskrecord.FieldPrompt: + return m.Prompt() + case taskrecord.FieldRole: + return m.Role() case taskrecord.FieldCompletion: return m.Completion() case taskrecord.FieldOutputTokens: @@ -11021,6 +11045,10 @@ func (m *TaskRecordMutation) OldField(ctx context.Context, name string) (ent.Val switch name { case taskrecord.FieldTaskID: return m.OldTaskID(ctx) + case taskrecord.FieldPrompt: + return m.OldPrompt(ctx) + case taskrecord.FieldRole: + return m.OldRole(ctx) case taskrecord.FieldCompletion: return m.OldCompletion(ctx) case taskrecord.FieldOutputTokens: @@ -11045,6 +11073,20 @@ func (m *TaskRecordMutation) SetField(name string, value ent.Value) error { } m.SetTaskID(v) return nil + case taskrecord.FieldPrompt: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPrompt(v) + return nil + case taskrecord.FieldRole: + v, ok := value.(consts.ChatRole) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRole(v) + return nil case taskrecord.FieldCompletion: v, ok := value.(string) if !ok { @@ -11121,6 +11163,9 @@ func (m *TaskRecordMutation) ClearedFields() []string { if m.FieldCleared(taskrecord.FieldTaskID) { fields = append(fields, taskrecord.FieldTaskID) } + if m.FieldCleared(taskrecord.FieldPrompt) { + fields = append(fields, taskrecord.FieldPrompt) + } return fields } @@ -11138,6 +11183,9 @@ func (m *TaskRecordMutation) ClearField(name string) error { case taskrecord.FieldTaskID: m.ClearTaskID() return nil + case taskrecord.FieldPrompt: + m.ClearPrompt() + return nil } return fmt.Errorf("unknown TaskRecord nullable field %s", name) } @@ -11149,6 +11197,12 @@ func (m *TaskRecordMutation) ResetField(name string) error { case taskrecord.FieldTaskID: m.ResetTaskID() return nil + case taskrecord.FieldPrompt: + m.ResetPrompt() + return nil + case taskrecord.FieldRole: + m.ResetRole() + return nil case taskrecord.FieldCompletion: m.ResetCompletion() return nil @@ -11245,6 +11299,7 @@ type UserMutation struct { op Op typ string id *uuid.UUID + deleted_at *time.Time username *string password *string email *string @@ -11375,6 +11430,55 @@ func (m *UserMutation) IDs(ctx context.Context) ([]uuid.UUID, error) { } } +// SetDeletedAt sets the "deleted_at" field. +func (m *UserMutation) SetDeletedAt(t time.Time) { + m.deleted_at = &t +} + +// DeletedAt returns the value of the "deleted_at" field in the mutation. +func (m *UserMutation) DeletedAt() (r time.Time, exists bool) { + v := m.deleted_at + if v == nil { + return + } + return *v, true +} + +// OldDeletedAt returns the old "deleted_at" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldDeletedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDeletedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err) + } + return oldValue.DeletedAt, nil +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (m *UserMutation) ClearDeletedAt() { + m.deleted_at = nil + m.clearedFields[user.FieldDeletedAt] = struct{}{} +} + +// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation. +func (m *UserMutation) DeletedAtCleared() bool { + _, ok := m.clearedFields[user.FieldDeletedAt] + return ok +} + +// ResetDeletedAt resets all changes to the "deleted_at" field. +func (m *UserMutation) ResetDeletedAt() { + m.deleted_at = nil + delete(m.clearedFields, user.FieldDeletedAt) +} + // SetUsername sets the "username" field. func (m *UserMutation) SetUsername(s string) { m.username = &s @@ -11965,7 +12069,10 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 8) + fields := make([]string, 0, 9) + if m.deleted_at != nil { + fields = append(fields, user.FieldDeletedAt) + } if m.username != nil { fields = append(fields, user.FieldUsername) } @@ -11998,6 +12105,8 @@ func (m *UserMutation) Fields() []string { // schema. func (m *UserMutation) Field(name string) (ent.Value, bool) { switch name { + case user.FieldDeletedAt: + return m.DeletedAt() case user.FieldUsername: return m.Username() case user.FieldPassword: @@ -12023,6 +12132,8 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { // database failed. func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, error) { switch name { + case user.FieldDeletedAt: + return m.OldDeletedAt(ctx) case user.FieldUsername: return m.OldUsername(ctx) case user.FieldPassword: @@ -12048,6 +12159,13 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er // type. func (m *UserMutation) SetField(name string, value ent.Value) error { switch name { + case user.FieldDeletedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDeletedAt(v) + return nil case user.FieldUsername: v, ok := value.(string) if !ok { @@ -12134,6 +12252,9 @@ func (m *UserMutation) AddField(name string, value ent.Value) error { // mutation. func (m *UserMutation) ClearedFields() []string { var fields []string + if m.FieldCleared(user.FieldDeletedAt) { + fields = append(fields, user.FieldDeletedAt) + } if m.FieldCleared(user.FieldUsername) { fields = append(fields, user.FieldUsername) } @@ -12160,6 +12281,9 @@ func (m *UserMutation) FieldCleared(name string) bool { // error if the field is not defined in the schema. func (m *UserMutation) ClearField(name string) error { switch name { + case user.FieldDeletedAt: + m.ClearDeletedAt() + return nil case user.FieldUsername: m.ClearUsername() return nil @@ -12180,6 +12304,9 @@ func (m *UserMutation) ClearField(name string) error { // It returns an error if the field is not defined in the schema. func (m *UserMutation) ResetField(name string) error { switch name { + case user.FieldDeletedAt: + m.ResetDeletedAt() + return nil case user.FieldUsername: m.ResetUsername() return nil diff --git a/backend/db/runtime/runtime.go b/backend/db/runtime/runtime.go index 2d50541..27478ec 100644 --- a/backend/db/runtime/runtime.go +++ b/backend/db/runtime/runtime.go @@ -238,15 +238,15 @@ func init() { taskFields := schema.Task{}.Fields() _ = taskFields // taskDescIsAccept is the schema descriptor for is_accept field. - taskDescIsAccept := taskFields[7].Descriptor() + taskDescIsAccept := taskFields[6].Descriptor() // task.DefaultIsAccept holds the default value on creation for the is_accept field. task.DefaultIsAccept = taskDescIsAccept.Default.(bool) // taskDescCreatedAt is the schema descriptor for created_at field. - taskDescCreatedAt := taskFields[14].Descriptor() + taskDescCreatedAt := taskFields[13].Descriptor() // task.DefaultCreatedAt holds the default value on creation for the created_at field. task.DefaultCreatedAt = taskDescCreatedAt.Default.(func() time.Time) // taskDescUpdatedAt is the schema descriptor for updated_at field. - taskDescUpdatedAt := taskFields[15].Descriptor() + taskDescUpdatedAt := taskFields[14].Descriptor() // task.DefaultUpdatedAt holds the default value on creation for the updated_at field. task.DefaultUpdatedAt = taskDescUpdatedAt.Default.(func() time.Time) // task.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. @@ -254,15 +254,20 @@ func init() { taskrecordFields := schema.TaskRecord{}.Fields() _ = taskrecordFields // taskrecordDescCreatedAt is the schema descriptor for created_at field. - taskrecordDescCreatedAt := taskrecordFields[4].Descriptor() + taskrecordDescCreatedAt := taskrecordFields[6].Descriptor() // taskrecord.DefaultCreatedAt holds the default value on creation for the created_at field. taskrecord.DefaultCreatedAt = taskrecordDescCreatedAt.Default.(func() time.Time) // taskrecordDescUpdatedAt is the schema descriptor for updated_at field. - taskrecordDescUpdatedAt := taskrecordFields[5].Descriptor() + taskrecordDescUpdatedAt := taskrecordFields[7].Descriptor() // taskrecord.DefaultUpdatedAt holds the default value on creation for the updated_at field. taskrecord.DefaultUpdatedAt = taskrecordDescUpdatedAt.Default.(func() time.Time) // taskrecord.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. taskrecord.UpdateDefaultUpdatedAt = taskrecordDescUpdatedAt.UpdateDefault.(func() time.Time) + userMixin := schema.User{}.Mixin() + userMixinHooks0 := userMixin[0].Hooks() + user.Hooks[0] = userMixinHooks0[0] + userMixinInters0 := userMixin[0].Interceptors() + user.Interceptors[0] = userMixinInters0[0] userFields := schema.User{}.Fields() _ = userFields // userDescPlatform is the schema descriptor for platform field. diff --git a/backend/db/task.go b/backend/db/task.go index 3c86a41..043c9c3 100644 --- a/backend/db/task.go +++ b/backend/db/task.go @@ -31,8 +31,6 @@ type Task struct { RequestID string `json:"request_id,omitempty"` // ModelType holds the value of the "model_type" field. ModelType consts.ModelType `json:"model_type,omitempty"` - // Prompt holds the value of the "prompt" field. - Prompt string `json:"prompt,omitempty"` // IsAccept holds the value of the "is_accept" field. IsAccept bool `json:"is_accept,omitempty"` // ProgramLanguage holds the value of the "program_language" field. @@ -110,7 +108,7 @@ func (*Task) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case task.FieldCodeLines, task.FieldInputTokens, task.FieldOutputTokens: values[i] = new(sql.NullInt64) - case task.FieldTaskID, task.FieldRequestID, task.FieldModelType, task.FieldPrompt, task.FieldProgramLanguage, task.FieldWorkMode, task.FieldCompletion: + case task.FieldTaskID, task.FieldRequestID, task.FieldModelType, task.FieldProgramLanguage, task.FieldWorkMode, task.FieldCompletion: values[i] = new(sql.NullString) case task.FieldCreatedAt, task.FieldUpdatedAt: values[i] = new(sql.NullTime) @@ -167,12 +165,6 @@ func (t *Task) assignValues(columns []string, values []any) error { } else if value.Valid { t.ModelType = consts.ModelType(value.String) } - case task.FieldPrompt: - if value, ok := values[i].(*sql.NullString); !ok { - return fmt.Errorf("unexpected type %T for field prompt", values[i]) - } else if value.Valid { - t.Prompt = value.String - } case task.FieldIsAccept: if value, ok := values[i].(*sql.NullBool); !ok { return fmt.Errorf("unexpected type %T for field is_accept", values[i]) @@ -293,9 +285,6 @@ func (t *Task) String() string { builder.WriteString("model_type=") builder.WriteString(fmt.Sprintf("%v", t.ModelType)) builder.WriteString(", ") - builder.WriteString("prompt=") - builder.WriteString(t.Prompt) - builder.WriteString(", ") builder.WriteString("is_accept=") builder.WriteString(fmt.Sprintf("%v", t.IsAccept)) builder.WriteString(", ") diff --git a/backend/db/task/task.go b/backend/db/task/task.go index ef9b130..fed0b73 100644 --- a/backend/db/task/task.go +++ b/backend/db/task/task.go @@ -24,8 +24,6 @@ const ( FieldRequestID = "request_id" // FieldModelType holds the string denoting the model_type field in the database. FieldModelType = "model_type" - // FieldPrompt holds the string denoting the prompt field in the database. - FieldPrompt = "prompt" // FieldIsAccept holds the string denoting the is_accept field in the database. FieldIsAccept = "is_accept" // FieldProgramLanguage holds the string denoting the program_language field in the database. @@ -83,7 +81,6 @@ var Columns = []string{ FieldModelID, FieldRequestID, FieldModelType, - FieldPrompt, FieldIsAccept, FieldProgramLanguage, FieldWorkMode, @@ -149,11 +146,6 @@ func ByModelType(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldModelType, opts...).ToFunc() } -// ByPrompt orders the results by the prompt field. -func ByPrompt(opts ...sql.OrderTermOption) OrderOption { - return sql.OrderByField(FieldPrompt, opts...).ToFunc() -} - // ByIsAccept orders the results by the is_accept field. func ByIsAccept(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldIsAccept, opts...).ToFunc() diff --git a/backend/db/task/where.go b/backend/db/task/where.go index 6675562..5321211 100644 --- a/backend/db/task/where.go +++ b/backend/db/task/where.go @@ -83,11 +83,6 @@ func ModelType(v consts.ModelType) predicate.Task { return predicate.Task(sql.FieldEQ(FieldModelType, vc)) } -// Prompt applies equality check predicate on the "prompt" field. It's identical to PromptEQ. -func Prompt(v string) predicate.Task { - return predicate.Task(sql.FieldEQ(FieldPrompt, v)) -} - // IsAccept applies equality check predicate on the "is_accept" field. It's identical to IsAcceptEQ. func IsAccept(v bool) predicate.Task { return predicate.Task(sql.FieldEQ(FieldIsAccept, v)) @@ -417,81 +412,6 @@ func ModelTypeContainsFold(v consts.ModelType) predicate.Task { return predicate.Task(sql.FieldContainsFold(FieldModelType, vc)) } -// PromptEQ applies the EQ predicate on the "prompt" field. -func PromptEQ(v string) predicate.Task { - return predicate.Task(sql.FieldEQ(FieldPrompt, v)) -} - -// PromptNEQ applies the NEQ predicate on the "prompt" field. -func PromptNEQ(v string) predicate.Task { - return predicate.Task(sql.FieldNEQ(FieldPrompt, v)) -} - -// PromptIn applies the In predicate on the "prompt" field. -func PromptIn(vs ...string) predicate.Task { - return predicate.Task(sql.FieldIn(FieldPrompt, vs...)) -} - -// PromptNotIn applies the NotIn predicate on the "prompt" field. -func PromptNotIn(vs ...string) predicate.Task { - return predicate.Task(sql.FieldNotIn(FieldPrompt, vs...)) -} - -// PromptGT applies the GT predicate on the "prompt" field. -func PromptGT(v string) predicate.Task { - return predicate.Task(sql.FieldGT(FieldPrompt, v)) -} - -// PromptGTE applies the GTE predicate on the "prompt" field. -func PromptGTE(v string) predicate.Task { - return predicate.Task(sql.FieldGTE(FieldPrompt, v)) -} - -// PromptLT applies the LT predicate on the "prompt" field. -func PromptLT(v string) predicate.Task { - return predicate.Task(sql.FieldLT(FieldPrompt, v)) -} - -// PromptLTE applies the LTE predicate on the "prompt" field. -func PromptLTE(v string) predicate.Task { - return predicate.Task(sql.FieldLTE(FieldPrompt, v)) -} - -// PromptContains applies the Contains predicate on the "prompt" field. -func PromptContains(v string) predicate.Task { - return predicate.Task(sql.FieldContains(FieldPrompt, v)) -} - -// PromptHasPrefix applies the HasPrefix predicate on the "prompt" field. -func PromptHasPrefix(v string) predicate.Task { - return predicate.Task(sql.FieldHasPrefix(FieldPrompt, v)) -} - -// PromptHasSuffix applies the HasSuffix predicate on the "prompt" field. -func PromptHasSuffix(v string) predicate.Task { - return predicate.Task(sql.FieldHasSuffix(FieldPrompt, v)) -} - -// PromptIsNil applies the IsNil predicate on the "prompt" field. -func PromptIsNil() predicate.Task { - return predicate.Task(sql.FieldIsNull(FieldPrompt)) -} - -// PromptNotNil applies the NotNil predicate on the "prompt" field. -func PromptNotNil() predicate.Task { - return predicate.Task(sql.FieldNotNull(FieldPrompt)) -} - -// PromptEqualFold applies the EqualFold predicate on the "prompt" field. -func PromptEqualFold(v string) predicate.Task { - return predicate.Task(sql.FieldEqualFold(FieldPrompt, v)) -} - -// PromptContainsFold applies the ContainsFold predicate on the "prompt" field. -func PromptContainsFold(v string) predicate.Task { - return predicate.Task(sql.FieldContainsFold(FieldPrompt, v)) -} - // IsAcceptEQ applies the EQ predicate on the "is_accept" field. func IsAcceptEQ(v bool) predicate.Task { return predicate.Task(sql.FieldEQ(FieldIsAccept, v)) diff --git a/backend/db/task_create.go b/backend/db/task_create.go index f6ada06..adcfb3d 100644 --- a/backend/db/task_create.go +++ b/backend/db/task_create.go @@ -82,20 +82,6 @@ func (tc *TaskCreate) SetModelType(ct consts.ModelType) *TaskCreate { return tc } -// SetPrompt sets the "prompt" field. -func (tc *TaskCreate) SetPrompt(s string) *TaskCreate { - tc.mutation.SetPrompt(s) - return tc -} - -// SetNillablePrompt sets the "prompt" field if the given value is not nil. -func (tc *TaskCreate) SetNillablePrompt(s *string) *TaskCreate { - if s != nil { - tc.SetPrompt(*s) - } - return tc -} - // SetIsAccept sets the "is_accept" field. func (tc *TaskCreate) SetIsAccept(b bool) *TaskCreate { tc.mutation.SetIsAccept(b) @@ -367,10 +353,6 @@ func (tc *TaskCreate) createSpec() (*Task, *sqlgraph.CreateSpec) { _spec.SetField(task.FieldModelType, field.TypeString, value) _node.ModelType = value } - if value, ok := tc.mutation.Prompt(); ok { - _spec.SetField(task.FieldPrompt, field.TypeString, value) - _node.Prompt = value - } if value, ok := tc.mutation.IsAccept(); ok { _spec.SetField(task.FieldIsAccept, field.TypeBool, value) _node.IsAccept = value @@ -587,24 +569,6 @@ func (u *TaskUpsert) UpdateModelType() *TaskUpsert { return u } -// SetPrompt sets the "prompt" field. -func (u *TaskUpsert) SetPrompt(v string) *TaskUpsert { - u.Set(task.FieldPrompt, v) - return u -} - -// UpdatePrompt sets the "prompt" field to the value that was provided on create. -func (u *TaskUpsert) UpdatePrompt() *TaskUpsert { - u.SetExcluded(task.FieldPrompt) - return u -} - -// ClearPrompt clears the value of the "prompt" field. -func (u *TaskUpsert) ClearPrompt() *TaskUpsert { - u.SetNull(task.FieldPrompt) - return u -} - // SetIsAccept sets the "is_accept" field. func (u *TaskUpsert) SetIsAccept(v bool) *TaskUpsert { u.Set(task.FieldIsAccept, v) @@ -906,27 +870,6 @@ func (u *TaskUpsertOne) UpdateModelType() *TaskUpsertOne { }) } -// SetPrompt sets the "prompt" field. -func (u *TaskUpsertOne) SetPrompt(v string) *TaskUpsertOne { - return u.Update(func(s *TaskUpsert) { - s.SetPrompt(v) - }) -} - -// UpdatePrompt sets the "prompt" field to the value that was provided on create. -func (u *TaskUpsertOne) UpdatePrompt() *TaskUpsertOne { - return u.Update(func(s *TaskUpsert) { - s.UpdatePrompt() - }) -} - -// ClearPrompt clears the value of the "prompt" field. -func (u *TaskUpsertOne) ClearPrompt() *TaskUpsertOne { - return u.Update(func(s *TaskUpsert) { - s.ClearPrompt() - }) -} - // SetIsAccept sets the "is_accept" field. func (u *TaskUpsertOne) SetIsAccept(v bool) *TaskUpsertOne { return u.Update(func(s *TaskUpsert) { @@ -1422,27 +1365,6 @@ func (u *TaskUpsertBulk) UpdateModelType() *TaskUpsertBulk { }) } -// SetPrompt sets the "prompt" field. -func (u *TaskUpsertBulk) SetPrompt(v string) *TaskUpsertBulk { - return u.Update(func(s *TaskUpsert) { - s.SetPrompt(v) - }) -} - -// UpdatePrompt sets the "prompt" field to the value that was provided on create. -func (u *TaskUpsertBulk) UpdatePrompt() *TaskUpsertBulk { - return u.Update(func(s *TaskUpsert) { - s.UpdatePrompt() - }) -} - -// ClearPrompt clears the value of the "prompt" field. -func (u *TaskUpsertBulk) ClearPrompt() *TaskUpsertBulk { - return u.Update(func(s *TaskUpsert) { - s.ClearPrompt() - }) -} - // SetIsAccept sets the "is_accept" field. func (u *TaskUpsertBulk) SetIsAccept(v bool) *TaskUpsertBulk { return u.Update(func(s *TaskUpsert) { diff --git a/backend/db/task_update.go b/backend/db/task_update.go index 87a1453..dfddaa9 100644 --- a/backend/db/task_update.go +++ b/backend/db/task_update.go @@ -122,26 +122,6 @@ func (tu *TaskUpdate) SetNillableModelType(ct *consts.ModelType) *TaskUpdate { return tu } -// SetPrompt sets the "prompt" field. -func (tu *TaskUpdate) SetPrompt(s string) *TaskUpdate { - tu.mutation.SetPrompt(s) - return tu -} - -// SetNillablePrompt sets the "prompt" field if the given value is not nil. -func (tu *TaskUpdate) SetNillablePrompt(s *string) *TaskUpdate { - if s != nil { - tu.SetPrompt(*s) - } - return tu -} - -// ClearPrompt clears the value of the "prompt" field. -func (tu *TaskUpdate) ClearPrompt() *TaskUpdate { - tu.mutation.ClearPrompt() - return tu -} - // SetIsAccept sets the "is_accept" field. func (tu *TaskUpdate) SetIsAccept(b bool) *TaskUpdate { tu.mutation.SetIsAccept(b) @@ -443,12 +423,6 @@ func (tu *TaskUpdate) sqlSave(ctx context.Context) (n int, err error) { if value, ok := tu.mutation.ModelType(); ok { _spec.SetField(task.FieldModelType, field.TypeString, value) } - if value, ok := tu.mutation.Prompt(); ok { - _spec.SetField(task.FieldPrompt, field.TypeString, value) - } - if tu.mutation.PromptCleared() { - _spec.ClearField(task.FieldPrompt, field.TypeString) - } if value, ok := tu.mutation.IsAccept(); ok { _spec.SetField(task.FieldIsAccept, field.TypeBool, value) } @@ -716,26 +690,6 @@ func (tuo *TaskUpdateOne) SetNillableModelType(ct *consts.ModelType) *TaskUpdate return tuo } -// SetPrompt sets the "prompt" field. -func (tuo *TaskUpdateOne) SetPrompt(s string) *TaskUpdateOne { - tuo.mutation.SetPrompt(s) - return tuo -} - -// SetNillablePrompt sets the "prompt" field if the given value is not nil. -func (tuo *TaskUpdateOne) SetNillablePrompt(s *string) *TaskUpdateOne { - if s != nil { - tuo.SetPrompt(*s) - } - return tuo -} - -// ClearPrompt clears the value of the "prompt" field. -func (tuo *TaskUpdateOne) ClearPrompt() *TaskUpdateOne { - tuo.mutation.ClearPrompt() - return tuo -} - // SetIsAccept sets the "is_accept" field. func (tuo *TaskUpdateOne) SetIsAccept(b bool) *TaskUpdateOne { tuo.mutation.SetIsAccept(b) @@ -1067,12 +1021,6 @@ func (tuo *TaskUpdateOne) sqlSave(ctx context.Context) (_node *Task, err error) if value, ok := tuo.mutation.ModelType(); ok { _spec.SetField(task.FieldModelType, field.TypeString, value) } - if value, ok := tuo.mutation.Prompt(); ok { - _spec.SetField(task.FieldPrompt, field.TypeString, value) - } - if tuo.mutation.PromptCleared() { - _spec.ClearField(task.FieldPrompt, field.TypeString) - } if value, ok := tuo.mutation.IsAccept(); ok { _spec.SetField(task.FieldIsAccept, field.TypeBool, value) } diff --git a/backend/db/taskrecord.go b/backend/db/taskrecord.go index 44f5271..88efe18 100644 --- a/backend/db/taskrecord.go +++ b/backend/db/taskrecord.go @@ -9,6 +9,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" + "github.com/chaitin/MonkeyCode/backend/consts" "github.com/chaitin/MonkeyCode/backend/db/task" "github.com/chaitin/MonkeyCode/backend/db/taskrecord" "github.com/google/uuid" @@ -21,6 +22,10 @@ type TaskRecord struct { ID uuid.UUID `json:"id,omitempty"` // TaskID holds the value of the "task_id" field. TaskID uuid.UUID `json:"task_id,omitempty"` + // Prompt holds the value of the "prompt" field. + Prompt string `json:"prompt,omitempty"` + // Role holds the value of the "role" field. + Role consts.ChatRole `json:"role,omitempty"` // Completion holds the value of the "completion" field. Completion string `json:"completion,omitempty"` // OutputTokens holds the value of the "output_tokens" field. @@ -62,7 +67,7 @@ func (*TaskRecord) scanValues(columns []string) ([]any, error) { switch columns[i] { case taskrecord.FieldOutputTokens: values[i] = new(sql.NullInt64) - case taskrecord.FieldCompletion: + case taskrecord.FieldPrompt, taskrecord.FieldRole, taskrecord.FieldCompletion: values[i] = new(sql.NullString) case taskrecord.FieldCreatedAt, taskrecord.FieldUpdatedAt: values[i] = new(sql.NullTime) @@ -95,6 +100,18 @@ func (tr *TaskRecord) assignValues(columns []string, values []any) error { } else if value != nil { tr.TaskID = *value } + case taskrecord.FieldPrompt: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field prompt", values[i]) + } else if value.Valid { + tr.Prompt = value.String + } + case taskrecord.FieldRole: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field role", values[i]) + } else if value.Valid { + tr.Role = consts.ChatRole(value.String) + } case taskrecord.FieldCompletion: if value, ok := values[i].(*sql.NullString); !ok { return fmt.Errorf("unexpected type %T for field completion", values[i]) @@ -163,6 +180,12 @@ func (tr *TaskRecord) String() string { builder.WriteString("task_id=") builder.WriteString(fmt.Sprintf("%v", tr.TaskID)) builder.WriteString(", ") + builder.WriteString("prompt=") + builder.WriteString(tr.Prompt) + builder.WriteString(", ") + builder.WriteString("role=") + builder.WriteString(fmt.Sprintf("%v", tr.Role)) + builder.WriteString(", ") builder.WriteString("completion=") builder.WriteString(tr.Completion) builder.WriteString(", ") diff --git a/backend/db/taskrecord/taskrecord.go b/backend/db/taskrecord/taskrecord.go index bfb5ce2..bca05dd 100644 --- a/backend/db/taskrecord/taskrecord.go +++ b/backend/db/taskrecord/taskrecord.go @@ -16,6 +16,10 @@ const ( FieldID = "id" // FieldTaskID holds the string denoting the task_id field in the database. FieldTaskID = "task_id" + // FieldPrompt holds the string denoting the prompt field in the database. + FieldPrompt = "prompt" + // FieldRole holds the string denoting the role field in the database. + FieldRole = "role" // FieldCompletion holds the string denoting the completion field in the database. FieldCompletion = "completion" // FieldOutputTokens holds the string denoting the output_tokens field in the database. @@ -41,6 +45,8 @@ const ( var Columns = []string{ FieldID, FieldTaskID, + FieldPrompt, + FieldRole, FieldCompletion, FieldOutputTokens, FieldCreatedAt, @@ -79,6 +85,16 @@ func ByTaskID(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldTaskID, opts...).ToFunc() } +// ByPrompt orders the results by the prompt field. +func ByPrompt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPrompt, opts...).ToFunc() +} + +// ByRole orders the results by the role field. +func ByRole(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRole, opts...).ToFunc() +} + // ByCompletion orders the results by the completion field. func ByCompletion(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCompletion, opts...).ToFunc() diff --git a/backend/db/taskrecord/where.go b/backend/db/taskrecord/where.go index a53a3e2..38a4a74 100644 --- a/backend/db/taskrecord/where.go +++ b/backend/db/taskrecord/where.go @@ -7,6 +7,7 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/chaitin/MonkeyCode/backend/consts" "github.com/chaitin/MonkeyCode/backend/db/predicate" "github.com/google/uuid" ) @@ -61,6 +62,17 @@ func TaskID(v uuid.UUID) predicate.TaskRecord { return predicate.TaskRecord(sql.FieldEQ(FieldTaskID, v)) } +// Prompt applies equality check predicate on the "prompt" field. It's identical to PromptEQ. +func Prompt(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldEQ(FieldPrompt, v)) +} + +// Role applies equality check predicate on the "role" field. It's identical to RoleEQ. +func Role(v consts.ChatRole) predicate.TaskRecord { + vc := string(v) + return predicate.TaskRecord(sql.FieldEQ(FieldRole, vc)) +} + // Completion applies equality check predicate on the "completion" field. It's identical to CompletionEQ. func Completion(v string) predicate.TaskRecord { return predicate.TaskRecord(sql.FieldEQ(FieldCompletion, v)) @@ -111,6 +123,165 @@ func TaskIDNotNil() predicate.TaskRecord { return predicate.TaskRecord(sql.FieldNotNull(FieldTaskID)) } +// PromptEQ applies the EQ predicate on the "prompt" field. +func PromptEQ(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldEQ(FieldPrompt, v)) +} + +// PromptNEQ applies the NEQ predicate on the "prompt" field. +func PromptNEQ(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldNEQ(FieldPrompt, v)) +} + +// PromptIn applies the In predicate on the "prompt" field. +func PromptIn(vs ...string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldIn(FieldPrompt, vs...)) +} + +// PromptNotIn applies the NotIn predicate on the "prompt" field. +func PromptNotIn(vs ...string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldNotIn(FieldPrompt, vs...)) +} + +// PromptGT applies the GT predicate on the "prompt" field. +func PromptGT(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldGT(FieldPrompt, v)) +} + +// PromptGTE applies the GTE predicate on the "prompt" field. +func PromptGTE(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldGTE(FieldPrompt, v)) +} + +// PromptLT applies the LT predicate on the "prompt" field. +func PromptLT(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldLT(FieldPrompt, v)) +} + +// PromptLTE applies the LTE predicate on the "prompt" field. +func PromptLTE(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldLTE(FieldPrompt, v)) +} + +// PromptContains applies the Contains predicate on the "prompt" field. +func PromptContains(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldContains(FieldPrompt, v)) +} + +// PromptHasPrefix applies the HasPrefix predicate on the "prompt" field. +func PromptHasPrefix(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldHasPrefix(FieldPrompt, v)) +} + +// PromptHasSuffix applies the HasSuffix predicate on the "prompt" field. +func PromptHasSuffix(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldHasSuffix(FieldPrompt, v)) +} + +// PromptIsNil applies the IsNil predicate on the "prompt" field. +func PromptIsNil() predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldIsNull(FieldPrompt)) +} + +// PromptNotNil applies the NotNil predicate on the "prompt" field. +func PromptNotNil() predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldNotNull(FieldPrompt)) +} + +// PromptEqualFold applies the EqualFold predicate on the "prompt" field. +func PromptEqualFold(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldEqualFold(FieldPrompt, v)) +} + +// PromptContainsFold applies the ContainsFold predicate on the "prompt" field. +func PromptContainsFold(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldContainsFold(FieldPrompt, v)) +} + +// RoleEQ applies the EQ predicate on the "role" field. +func RoleEQ(v consts.ChatRole) predicate.TaskRecord { + vc := string(v) + return predicate.TaskRecord(sql.FieldEQ(FieldRole, vc)) +} + +// RoleNEQ applies the NEQ predicate on the "role" field. +func RoleNEQ(v consts.ChatRole) predicate.TaskRecord { + vc := string(v) + return predicate.TaskRecord(sql.FieldNEQ(FieldRole, vc)) +} + +// RoleIn applies the In predicate on the "role" field. +func RoleIn(vs ...consts.ChatRole) predicate.TaskRecord { + v := make([]any, len(vs)) + for i := range v { + v[i] = string(vs[i]) + } + return predicate.TaskRecord(sql.FieldIn(FieldRole, v...)) +} + +// RoleNotIn applies the NotIn predicate on the "role" field. +func RoleNotIn(vs ...consts.ChatRole) predicate.TaskRecord { + v := make([]any, len(vs)) + for i := range v { + v[i] = string(vs[i]) + } + return predicate.TaskRecord(sql.FieldNotIn(FieldRole, v...)) +} + +// RoleGT applies the GT predicate on the "role" field. +func RoleGT(v consts.ChatRole) predicate.TaskRecord { + vc := string(v) + return predicate.TaskRecord(sql.FieldGT(FieldRole, vc)) +} + +// RoleGTE applies the GTE predicate on the "role" field. +func RoleGTE(v consts.ChatRole) predicate.TaskRecord { + vc := string(v) + return predicate.TaskRecord(sql.FieldGTE(FieldRole, vc)) +} + +// RoleLT applies the LT predicate on the "role" field. +func RoleLT(v consts.ChatRole) predicate.TaskRecord { + vc := string(v) + return predicate.TaskRecord(sql.FieldLT(FieldRole, vc)) +} + +// RoleLTE applies the LTE predicate on the "role" field. +func RoleLTE(v consts.ChatRole) predicate.TaskRecord { + vc := string(v) + return predicate.TaskRecord(sql.FieldLTE(FieldRole, vc)) +} + +// RoleContains applies the Contains predicate on the "role" field. +func RoleContains(v consts.ChatRole) predicate.TaskRecord { + vc := string(v) + return predicate.TaskRecord(sql.FieldContains(FieldRole, vc)) +} + +// RoleHasPrefix applies the HasPrefix predicate on the "role" field. +func RoleHasPrefix(v consts.ChatRole) predicate.TaskRecord { + vc := string(v) + return predicate.TaskRecord(sql.FieldHasPrefix(FieldRole, vc)) +} + +// RoleHasSuffix applies the HasSuffix predicate on the "role" field. +func RoleHasSuffix(v consts.ChatRole) predicate.TaskRecord { + vc := string(v) + return predicate.TaskRecord(sql.FieldHasSuffix(FieldRole, vc)) +} + +// RoleEqualFold applies the EqualFold predicate on the "role" field. +func RoleEqualFold(v consts.ChatRole) predicate.TaskRecord { + vc := string(v) + return predicate.TaskRecord(sql.FieldEqualFold(FieldRole, vc)) +} + +// RoleContainsFold applies the ContainsFold predicate on the "role" field. +func RoleContainsFold(v consts.ChatRole) predicate.TaskRecord { + vc := string(v) + return predicate.TaskRecord(sql.FieldContainsFold(FieldRole, vc)) +} + // CompletionEQ applies the EQ predicate on the "completion" field. func CompletionEQ(v string) predicate.TaskRecord { return predicate.TaskRecord(sql.FieldEQ(FieldCompletion, v)) diff --git a/backend/db/taskrecord_create.go b/backend/db/taskrecord_create.go index 078eab5..aa78db4 100644 --- a/backend/db/taskrecord_create.go +++ b/backend/db/taskrecord_create.go @@ -12,6 +12,7 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" + "github.com/chaitin/MonkeyCode/backend/consts" "github.com/chaitin/MonkeyCode/backend/db/task" "github.com/chaitin/MonkeyCode/backend/db/taskrecord" "github.com/google/uuid" @@ -39,6 +40,26 @@ func (trc *TaskRecordCreate) SetNillableTaskID(u *uuid.UUID) *TaskRecordCreate { return trc } +// SetPrompt sets the "prompt" field. +func (trc *TaskRecordCreate) SetPrompt(s string) *TaskRecordCreate { + trc.mutation.SetPrompt(s) + return trc +} + +// SetNillablePrompt sets the "prompt" field if the given value is not nil. +func (trc *TaskRecordCreate) SetNillablePrompt(s *string) *TaskRecordCreate { + if s != nil { + trc.SetPrompt(*s) + } + return trc +} + +// SetRole sets the "role" field. +func (trc *TaskRecordCreate) SetRole(cr consts.ChatRole) *TaskRecordCreate { + trc.mutation.SetRole(cr) + return trc +} + // SetCompletion sets the "completion" field. func (trc *TaskRecordCreate) SetCompletion(s string) *TaskRecordCreate { trc.mutation.SetCompletion(s) @@ -137,6 +158,9 @@ func (trc *TaskRecordCreate) defaults() { // check runs all checks and user-defined validators on the builder. func (trc *TaskRecordCreate) check() error { + if _, ok := trc.mutation.Role(); !ok { + return &ValidationError{Name: "role", err: errors.New(`db: missing required field "TaskRecord.role"`)} + } if _, ok := trc.mutation.Completion(); !ok { return &ValidationError{Name: "completion", err: errors.New(`db: missing required field "TaskRecord.completion"`)} } @@ -185,6 +209,14 @@ func (trc *TaskRecordCreate) createSpec() (*TaskRecord, *sqlgraph.CreateSpec) { _node.ID = id _spec.ID.Value = &id } + if value, ok := trc.mutation.Prompt(); ok { + _spec.SetField(taskrecord.FieldPrompt, field.TypeString, value) + _node.Prompt = value + } + if value, ok := trc.mutation.Role(); ok { + _spec.SetField(taskrecord.FieldRole, field.TypeString, value) + _node.Role = value + } if value, ok := trc.mutation.Completion(); ok { _spec.SetField(taskrecord.FieldCompletion, field.TypeString, value) _node.Completion = value @@ -288,6 +320,36 @@ func (u *TaskRecordUpsert) ClearTaskID() *TaskRecordUpsert { return u } +// SetPrompt sets the "prompt" field. +func (u *TaskRecordUpsert) SetPrompt(v string) *TaskRecordUpsert { + u.Set(taskrecord.FieldPrompt, v) + return u +} + +// UpdatePrompt sets the "prompt" field to the value that was provided on create. +func (u *TaskRecordUpsert) UpdatePrompt() *TaskRecordUpsert { + u.SetExcluded(taskrecord.FieldPrompt) + return u +} + +// ClearPrompt clears the value of the "prompt" field. +func (u *TaskRecordUpsert) ClearPrompt() *TaskRecordUpsert { + u.SetNull(taskrecord.FieldPrompt) + return u +} + +// SetRole sets the "role" field. +func (u *TaskRecordUpsert) SetRole(v consts.ChatRole) *TaskRecordUpsert { + u.Set(taskrecord.FieldRole, v) + return u +} + +// UpdateRole sets the "role" field to the value that was provided on create. +func (u *TaskRecordUpsert) UpdateRole() *TaskRecordUpsert { + u.SetExcluded(taskrecord.FieldRole) + return u +} + // SetCompletion sets the "completion" field. func (u *TaskRecordUpsert) SetCompletion(v string) *TaskRecordUpsert { u.Set(taskrecord.FieldCompletion, v) @@ -411,6 +473,41 @@ func (u *TaskRecordUpsertOne) ClearTaskID() *TaskRecordUpsertOne { }) } +// SetPrompt sets the "prompt" field. +func (u *TaskRecordUpsertOne) SetPrompt(v string) *TaskRecordUpsertOne { + return u.Update(func(s *TaskRecordUpsert) { + s.SetPrompt(v) + }) +} + +// UpdatePrompt sets the "prompt" field to the value that was provided on create. +func (u *TaskRecordUpsertOne) UpdatePrompt() *TaskRecordUpsertOne { + return u.Update(func(s *TaskRecordUpsert) { + s.UpdatePrompt() + }) +} + +// ClearPrompt clears the value of the "prompt" field. +func (u *TaskRecordUpsertOne) ClearPrompt() *TaskRecordUpsertOne { + return u.Update(func(s *TaskRecordUpsert) { + s.ClearPrompt() + }) +} + +// SetRole sets the "role" field. +func (u *TaskRecordUpsertOne) SetRole(v consts.ChatRole) *TaskRecordUpsertOne { + return u.Update(func(s *TaskRecordUpsert) { + s.SetRole(v) + }) +} + +// UpdateRole sets the "role" field to the value that was provided on create. +func (u *TaskRecordUpsertOne) UpdateRole() *TaskRecordUpsertOne { + return u.Update(func(s *TaskRecordUpsert) { + s.UpdateRole() + }) +} + // SetCompletion sets the "completion" field. func (u *TaskRecordUpsertOne) SetCompletion(v string) *TaskRecordUpsertOne { return u.Update(func(s *TaskRecordUpsert) { @@ -710,6 +807,41 @@ func (u *TaskRecordUpsertBulk) ClearTaskID() *TaskRecordUpsertBulk { }) } +// SetPrompt sets the "prompt" field. +func (u *TaskRecordUpsertBulk) SetPrompt(v string) *TaskRecordUpsertBulk { + return u.Update(func(s *TaskRecordUpsert) { + s.SetPrompt(v) + }) +} + +// UpdatePrompt sets the "prompt" field to the value that was provided on create. +func (u *TaskRecordUpsertBulk) UpdatePrompt() *TaskRecordUpsertBulk { + return u.Update(func(s *TaskRecordUpsert) { + s.UpdatePrompt() + }) +} + +// ClearPrompt clears the value of the "prompt" field. +func (u *TaskRecordUpsertBulk) ClearPrompt() *TaskRecordUpsertBulk { + return u.Update(func(s *TaskRecordUpsert) { + s.ClearPrompt() + }) +} + +// SetRole sets the "role" field. +func (u *TaskRecordUpsertBulk) SetRole(v consts.ChatRole) *TaskRecordUpsertBulk { + return u.Update(func(s *TaskRecordUpsert) { + s.SetRole(v) + }) +} + +// UpdateRole sets the "role" field to the value that was provided on create. +func (u *TaskRecordUpsertBulk) UpdateRole() *TaskRecordUpsertBulk { + return u.Update(func(s *TaskRecordUpsert) { + s.UpdateRole() + }) +} + // SetCompletion sets the "completion" field. func (u *TaskRecordUpsertBulk) SetCompletion(v string) *TaskRecordUpsertBulk { return u.Update(func(s *TaskRecordUpsert) { diff --git a/backend/db/taskrecord_update.go b/backend/db/taskrecord_update.go index 2576fc9..4f16d34 100644 --- a/backend/db/taskrecord_update.go +++ b/backend/db/taskrecord_update.go @@ -11,6 +11,7 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" + "github.com/chaitin/MonkeyCode/backend/consts" "github.com/chaitin/MonkeyCode/backend/db/predicate" "github.com/chaitin/MonkeyCode/backend/db/task" "github.com/chaitin/MonkeyCode/backend/db/taskrecord" @@ -51,6 +52,40 @@ func (tru *TaskRecordUpdate) ClearTaskID() *TaskRecordUpdate { return tru } +// SetPrompt sets the "prompt" field. +func (tru *TaskRecordUpdate) SetPrompt(s string) *TaskRecordUpdate { + tru.mutation.SetPrompt(s) + return tru +} + +// SetNillablePrompt sets the "prompt" field if the given value is not nil. +func (tru *TaskRecordUpdate) SetNillablePrompt(s *string) *TaskRecordUpdate { + if s != nil { + tru.SetPrompt(*s) + } + return tru +} + +// ClearPrompt clears the value of the "prompt" field. +func (tru *TaskRecordUpdate) ClearPrompt() *TaskRecordUpdate { + tru.mutation.ClearPrompt() + return tru +} + +// SetRole sets the "role" field. +func (tru *TaskRecordUpdate) SetRole(cr consts.ChatRole) *TaskRecordUpdate { + tru.mutation.SetRole(cr) + return tru +} + +// SetNillableRole sets the "role" field if the given value is not nil. +func (tru *TaskRecordUpdate) SetNillableRole(cr *consts.ChatRole) *TaskRecordUpdate { + if cr != nil { + tru.SetRole(*cr) + } + return tru +} + // SetCompletion sets the "completion" field. func (tru *TaskRecordUpdate) SetCompletion(s string) *TaskRecordUpdate { tru.mutation.SetCompletion(s) @@ -173,6 +208,15 @@ func (tru *TaskRecordUpdate) sqlSave(ctx context.Context) (n int, err error) { } } } + if value, ok := tru.mutation.Prompt(); ok { + _spec.SetField(taskrecord.FieldPrompt, field.TypeString, value) + } + if tru.mutation.PromptCleared() { + _spec.ClearField(taskrecord.FieldPrompt, field.TypeString) + } + if value, ok := tru.mutation.Role(); ok { + _spec.SetField(taskrecord.FieldRole, field.TypeString, value) + } if value, ok := tru.mutation.Completion(); ok { _spec.SetField(taskrecord.FieldCompletion, field.TypeString, value) } @@ -259,6 +303,40 @@ func (truo *TaskRecordUpdateOne) ClearTaskID() *TaskRecordUpdateOne { return truo } +// SetPrompt sets the "prompt" field. +func (truo *TaskRecordUpdateOne) SetPrompt(s string) *TaskRecordUpdateOne { + truo.mutation.SetPrompt(s) + return truo +} + +// SetNillablePrompt sets the "prompt" field if the given value is not nil. +func (truo *TaskRecordUpdateOne) SetNillablePrompt(s *string) *TaskRecordUpdateOne { + if s != nil { + truo.SetPrompt(*s) + } + return truo +} + +// ClearPrompt clears the value of the "prompt" field. +func (truo *TaskRecordUpdateOne) ClearPrompt() *TaskRecordUpdateOne { + truo.mutation.ClearPrompt() + return truo +} + +// SetRole sets the "role" field. +func (truo *TaskRecordUpdateOne) SetRole(cr consts.ChatRole) *TaskRecordUpdateOne { + truo.mutation.SetRole(cr) + return truo +} + +// SetNillableRole sets the "role" field if the given value is not nil. +func (truo *TaskRecordUpdateOne) SetNillableRole(cr *consts.ChatRole) *TaskRecordUpdateOne { + if cr != nil { + truo.SetRole(*cr) + } + return truo +} + // SetCompletion sets the "completion" field. func (truo *TaskRecordUpdateOne) SetCompletion(s string) *TaskRecordUpdateOne { truo.mutation.SetCompletion(s) @@ -411,6 +489,15 @@ func (truo *TaskRecordUpdateOne) sqlSave(ctx context.Context) (_node *TaskRecord } } } + if value, ok := truo.mutation.Prompt(); ok { + _spec.SetField(taskrecord.FieldPrompt, field.TypeString, value) + } + if truo.mutation.PromptCleared() { + _spec.ClearField(taskrecord.FieldPrompt, field.TypeString) + } + if value, ok := truo.mutation.Role(); ok { + _spec.SetField(taskrecord.FieldRole, field.TypeString, value) + } if value, ok := truo.mutation.Completion(); ok { _spec.SetField(taskrecord.FieldCompletion, field.TypeString, value) } diff --git a/backend/db/user.go b/backend/db/user.go index d62628b..6f4ec07 100644 --- a/backend/db/user.go +++ b/backend/db/user.go @@ -19,6 +19,8 @@ type User struct { config `json:"-"` // ID of the ent. ID uuid.UUID `json:"id,omitempty"` + // DeletedAt holds the value of the "deleted_at" field. + DeletedAt time.Time `json:"deleted_at,omitempty"` // Username holds the value of the "username" field. Username string `json:"username,omitempty"` // Password holds the value of the "password" field. @@ -99,7 +101,7 @@ func (*User) scanValues(columns []string) ([]any, error) { switch columns[i] { case user.FieldUsername, user.FieldPassword, user.FieldEmail, user.FieldAvatarURL, user.FieldPlatform, user.FieldStatus: values[i] = new(sql.NullString) - case user.FieldCreatedAt, user.FieldUpdatedAt: + case user.FieldDeletedAt, user.FieldCreatedAt, user.FieldUpdatedAt: values[i] = new(sql.NullTime) case user.FieldID: values[i] = new(uuid.UUID) @@ -124,6 +126,12 @@ func (u *User) assignValues(columns []string, values []any) error { } else if value != nil { u.ID = *value } + case user.FieldDeletedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field deleted_at", values[i]) + } else if value.Valid { + u.DeletedAt = value.Time + } case user.FieldUsername: if value, ok := values[i].(*sql.NullString); !ok { return fmt.Errorf("unexpected type %T for field username", values[i]) @@ -228,6 +236,9 @@ func (u *User) String() string { var builder strings.Builder builder.WriteString("User(") builder.WriteString(fmt.Sprintf("id=%v, ", u.ID)) + builder.WriteString("deleted_at=") + builder.WriteString(u.DeletedAt.Format(time.ANSIC)) + builder.WriteString(", ") builder.WriteString("username=") builder.WriteString(u.Username) builder.WriteString(", ") diff --git a/backend/db/user/user.go b/backend/db/user/user.go index 29f0043..79a65cc 100644 --- a/backend/db/user/user.go +++ b/backend/db/user/user.go @@ -5,6 +5,7 @@ package user import ( "time" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "github.com/chaitin/MonkeyCode/backend/consts" @@ -15,6 +16,8 @@ const ( Label = "user" // FieldID holds the string denoting the id field in the database. FieldID = "id" + // FieldDeletedAt holds the string denoting the deleted_at field in the database. + FieldDeletedAt = "deleted_at" // FieldUsername holds the string denoting the username field in the database. FieldUsername = "username" // FieldPassword holds the string denoting the password field in the database. @@ -74,6 +77,7 @@ const ( // Columns holds all SQL columns for user fields. var Columns = []string{ FieldID, + FieldDeletedAt, FieldUsername, FieldPassword, FieldEmail, @@ -94,7 +98,14 @@ func ValidColumn(column string) bool { return false } +// Note that the variables below are initialized by the runtime +// package on the initialization of the application. Therefore, +// it should be imported in the main as follows: +// +// import _ "github.com/chaitin/MonkeyCode/backend/db/runtime" var ( + Hooks [1]ent.Hook + Interceptors [1]ent.Interceptor // DefaultPlatform holds the default value on creation for the "platform" field. DefaultPlatform consts.UserPlatform // DefaultStatus holds the default value on creation for the "status" field. @@ -113,6 +124,11 @@ func ByID(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldID, opts...).ToFunc() } +// ByDeletedAt orders the results by the deleted_at field. +func ByDeletedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDeletedAt, opts...).ToFunc() +} + // ByUsername orders the results by the username field. func ByUsername(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldUsername, opts...).ToFunc() diff --git a/backend/db/user/where.go b/backend/db/user/where.go index 92b7ff4..b18fecf 100644 --- a/backend/db/user/where.go +++ b/backend/db/user/where.go @@ -57,6 +57,11 @@ func IDLTE(id uuid.UUID) predicate.User { return predicate.User(sql.FieldLTE(FieldID, id)) } +// DeletedAt applies equality check predicate on the "deleted_at" field. It's identical to DeletedAtEQ. +func DeletedAt(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldDeletedAt, v)) +} + // Username applies equality check predicate on the "username" field. It's identical to UsernameEQ. func Username(v string) predicate.User { return predicate.User(sql.FieldEQ(FieldUsername, v)) @@ -99,6 +104,56 @@ func UpdatedAt(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldUpdatedAt, v)) } +// DeletedAtEQ applies the EQ predicate on the "deleted_at" field. +func DeletedAtEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldDeletedAt, v)) +} + +// DeletedAtNEQ applies the NEQ predicate on the "deleted_at" field. +func DeletedAtNEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldNEQ(FieldDeletedAt, v)) +} + +// DeletedAtIn applies the In predicate on the "deleted_at" field. +func DeletedAtIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldIn(FieldDeletedAt, vs...)) +} + +// DeletedAtNotIn applies the NotIn predicate on the "deleted_at" field. +func DeletedAtNotIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldNotIn(FieldDeletedAt, vs...)) +} + +// DeletedAtGT applies the GT predicate on the "deleted_at" field. +func DeletedAtGT(v time.Time) predicate.User { + return predicate.User(sql.FieldGT(FieldDeletedAt, v)) +} + +// DeletedAtGTE applies the GTE predicate on the "deleted_at" field. +func DeletedAtGTE(v time.Time) predicate.User { + return predicate.User(sql.FieldGTE(FieldDeletedAt, v)) +} + +// DeletedAtLT applies the LT predicate on the "deleted_at" field. +func DeletedAtLT(v time.Time) predicate.User { + return predicate.User(sql.FieldLT(FieldDeletedAt, v)) +} + +// DeletedAtLTE applies the LTE predicate on the "deleted_at" field. +func DeletedAtLTE(v time.Time) predicate.User { + return predicate.User(sql.FieldLTE(FieldDeletedAt, v)) +} + +// DeletedAtIsNil applies the IsNil predicate on the "deleted_at" field. +func DeletedAtIsNil() predicate.User { + return predicate.User(sql.FieldIsNull(FieldDeletedAt)) +} + +// DeletedAtNotNil applies the NotNil predicate on the "deleted_at" field. +func DeletedAtNotNil() predicate.User { + return predicate.User(sql.FieldNotNull(FieldDeletedAt)) +} + // UsernameEQ applies the EQ predicate on the "username" field. func UsernameEQ(v string) predicate.User { return predicate.User(sql.FieldEQ(FieldUsername, v)) diff --git a/backend/db/user_create.go b/backend/db/user_create.go index 21c3646..fdfc5ee 100644 --- a/backend/db/user_create.go +++ b/backend/db/user_create.go @@ -29,6 +29,20 @@ type UserCreate struct { conflict []sql.ConflictOption } +// SetDeletedAt sets the "deleted_at" field. +func (uc *UserCreate) SetDeletedAt(t time.Time) *UserCreate { + uc.mutation.SetDeletedAt(t) + return uc +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (uc *UserCreate) SetNillableDeletedAt(t *time.Time) *UserCreate { + if t != nil { + uc.SetDeletedAt(*t) + } + return uc +} + // SetUsername sets the "username" field. func (uc *UserCreate) SetUsername(s string) *UserCreate { uc.mutation.SetUsername(s) @@ -214,7 +228,9 @@ func (uc *UserCreate) Mutation() *UserMutation { // Save creates the User in the database. func (uc *UserCreate) Save(ctx context.Context) (*User, error) { - uc.defaults() + if err := uc.defaults(); err != nil { + return nil, err + } return withHooks(ctx, uc.sqlSave, uc.mutation, uc.hooks) } @@ -241,7 +257,7 @@ func (uc *UserCreate) ExecX(ctx context.Context) { } // defaults sets the default values of the builder before save. -func (uc *UserCreate) defaults() { +func (uc *UserCreate) defaults() error { if _, ok := uc.mutation.Platform(); !ok { v := user.DefaultPlatform uc.mutation.SetPlatform(v) @@ -251,13 +267,20 @@ func (uc *UserCreate) defaults() { uc.mutation.SetStatus(v) } if _, ok := uc.mutation.CreatedAt(); !ok { + if user.DefaultCreatedAt == nil { + return fmt.Errorf("db: uninitialized user.DefaultCreatedAt (forgotten import db/runtime?)") + } v := user.DefaultCreatedAt() uc.mutation.SetCreatedAt(v) } if _, ok := uc.mutation.UpdatedAt(); !ok { + if user.DefaultUpdatedAt == nil { + return fmt.Errorf("db: uninitialized user.DefaultUpdatedAt (forgotten import db/runtime?)") + } v := user.DefaultUpdatedAt() uc.mutation.SetUpdatedAt(v) } + return nil } // check runs all checks and user-defined validators on the builder. @@ -310,6 +333,10 @@ func (uc *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { _node.ID = id _spec.ID.Value = &id } + if value, ok := uc.mutation.DeletedAt(); ok { + _spec.SetField(user.FieldDeletedAt, field.TypeTime, value) + _node.DeletedAt = value + } if value, ok := uc.mutation.Username(); ok { _spec.SetField(user.FieldUsername, field.TypeString, value) _node.Username = value @@ -413,7 +440,7 @@ func (uc *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { // of the `INSERT` statement. For example: // // client.User.Create(). -// SetUsername(v). +// SetDeletedAt(v). // OnConflict( // // Update the row with the new values // // the was proposed for insertion. @@ -422,7 +449,7 @@ func (uc *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { // // Override some of the fields with custom // // update values. // Update(func(u *ent.UserUpsert) { -// SetUsername(v+v). +// SetDeletedAt(v+v). // }). // Exec(ctx) func (uc *UserCreate) OnConflict(opts ...sql.ConflictOption) *UserUpsertOne { @@ -458,6 +485,24 @@ type ( } ) +// SetDeletedAt sets the "deleted_at" field. +func (u *UserUpsert) SetDeletedAt(v time.Time) *UserUpsert { + u.Set(user.FieldDeletedAt, v) + return u +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *UserUpsert) UpdateDeletedAt() *UserUpsert { + u.SetExcluded(user.FieldDeletedAt) + return u +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *UserUpsert) ClearDeletedAt() *UserUpsert { + u.SetNull(user.FieldDeletedAt) + return u +} + // SetUsername sets the "username" field. func (u *UserUpsert) SetUsername(v string) *UserUpsert { u.Set(user.FieldUsername, v) @@ -626,6 +671,27 @@ func (u *UserUpsertOne) Update(set func(*UserUpsert)) *UserUpsertOne { return u } +// SetDeletedAt sets the "deleted_at" field. +func (u *UserUpsertOne) SetDeletedAt(v time.Time) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateDeletedAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *UserUpsertOne) ClearDeletedAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.ClearDeletedAt() + }) +} + // SetUsername sets the "username" field. func (u *UserUpsertOne) SetUsername(v string) *UserUpsertOne { return u.Update(func(s *UserUpsert) { @@ -902,7 +968,7 @@ func (ucb *UserCreateBulk) ExecX(ctx context.Context) { // // Override some of the fields with custom // // update values. // Update(func(u *ent.UserUpsert) { -// SetUsername(v+v). +// SetDeletedAt(v+v). // }). // Exec(ctx) func (ucb *UserCreateBulk) OnConflict(opts ...sql.ConflictOption) *UserUpsertBulk { @@ -981,6 +1047,27 @@ func (u *UserUpsertBulk) Update(set func(*UserUpsert)) *UserUpsertBulk { return u } +// SetDeletedAt sets the "deleted_at" field. +func (u *UserUpsertBulk) SetDeletedAt(v time.Time) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetDeletedAt(v) + }) +} + +// UpdateDeletedAt sets the "deleted_at" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateDeletedAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateDeletedAt() + }) +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (u *UserUpsertBulk) ClearDeletedAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.ClearDeletedAt() + }) +} + // SetUsername sets the "username" field. func (u *UserUpsertBulk) SetUsername(v string) *UserUpsertBulk { return u.Update(func(s *UserUpsert) { diff --git a/backend/db/user_query.go b/backend/db/user_query.go index 089a910..2512f73 100644 --- a/backend/db/user_query.go +++ b/backend/db/user_query.go @@ -411,12 +411,12 @@ func (uq *UserQuery) WithIdentities(opts ...func(*UserIdentityQuery)) *UserQuery // Example: // // var v []struct { -// Username string `json:"username,omitempty"` +// DeletedAt time.Time `json:"deleted_at,omitempty"` // Count int `json:"count,omitempty"` // } // // client.User.Query(). -// GroupBy(user.FieldUsername). +// GroupBy(user.FieldDeletedAt). // Aggregate(db.Count()). // Scan(ctx, &v) func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { @@ -434,11 +434,11 @@ func (uq *UserQuery) GroupBy(field string, fields ...string) *UserGroupBy { // Example: // // var v []struct { -// Username string `json:"username,omitempty"` +// DeletedAt time.Time `json:"deleted_at,omitempty"` // } // // client.User.Query(). -// Select(user.FieldUsername). +// Select(user.FieldDeletedAt). // Scan(ctx, &v) func (uq *UserQuery) Select(fields ...string) *UserSelect { uq.ctx.Fields = append(uq.ctx.Fields, fields...) diff --git a/backend/db/user_update.go b/backend/db/user_update.go index 8fc4a43..ccd1c3b 100644 --- a/backend/db/user_update.go +++ b/backend/db/user_update.go @@ -35,6 +35,26 @@ func (uu *UserUpdate) Where(ps ...predicate.User) *UserUpdate { return uu } +// SetDeletedAt sets the "deleted_at" field. +func (uu *UserUpdate) SetDeletedAt(t time.Time) *UserUpdate { + uu.mutation.SetDeletedAt(t) + return uu +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (uu *UserUpdate) SetNillableDeletedAt(t *time.Time) *UserUpdate { + if t != nil { + uu.SetDeletedAt(*t) + } + return uu +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (uu *UserUpdate) ClearDeletedAt() *UserUpdate { + uu.mutation.ClearDeletedAt() + return uu +} + // SetUsername sets the "username" field. func (uu *UserUpdate) SetUsername(s string) *UserUpdate { uu.mutation.SetUsername(s) @@ -362,6 +382,12 @@ func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) { } } } + if value, ok := uu.mutation.DeletedAt(); ok { + _spec.SetField(user.FieldDeletedAt, field.TypeTime, value) + } + if uu.mutation.DeletedAtCleared() { + _spec.ClearField(user.FieldDeletedAt, field.TypeTime) + } if value, ok := uu.mutation.Username(); ok { _spec.SetField(user.FieldUsername, field.TypeString, value) } @@ -600,6 +626,26 @@ type UserUpdateOne struct { modifiers []func(*sql.UpdateBuilder) } +// SetDeletedAt sets the "deleted_at" field. +func (uuo *UserUpdateOne) SetDeletedAt(t time.Time) *UserUpdateOne { + uuo.mutation.SetDeletedAt(t) + return uuo +} + +// SetNillableDeletedAt sets the "deleted_at" field if the given value is not nil. +func (uuo *UserUpdateOne) SetNillableDeletedAt(t *time.Time) *UserUpdateOne { + if t != nil { + uuo.SetDeletedAt(*t) + } + return uuo +} + +// ClearDeletedAt clears the value of the "deleted_at" field. +func (uuo *UserUpdateOne) ClearDeletedAt() *UserUpdateOne { + uuo.mutation.ClearDeletedAt() + return uuo +} + // SetUsername sets the "username" field. func (uuo *UserUpdateOne) SetUsername(s string) *UserUpdateOne { uuo.mutation.SetUsername(s) @@ -957,6 +1003,12 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) } } } + if value, ok := uuo.mutation.DeletedAt(); ok { + _spec.SetField(user.FieldDeletedAt, field.TypeTime, value) + } + if uuo.mutation.DeletedAtCleared() { + _spec.ClearField(user.FieldDeletedAt, field.TypeTime) + } if value, ok := uuo.mutation.Username(); ok { _spec.SetField(user.FieldUsername, field.TypeString, value) } diff --git a/backend/docs/swagger.json b/backend/docs/swagger.json index 5c329e4..635cbcf 100644 --- a/backend/docs/swagger.json +++ b/backend/docs/swagger.json @@ -1776,6 +1776,17 @@ "AdminStatusInactive" ] }, + "consts.ChatRole": { + "type": "string", + "enum": [ + "user", + "assistant" + ], + "x-enum-varnames": [ + "ChatRoleUser", + "ChatRoleAssistant" + ] + }, "consts.ModelStatus": { "type": "string", "enum": [ @@ -1947,15 +1958,36 @@ } } }, - "domain.ChatInfo": { + "domain.ChatContent": { "type": "object", "properties": { "content": { + "description": "内容", "type": "string" }, "created_at": { "type": "integer" }, + "role": { + "description": "角色,如user: 用户的提问 assistant: 机器人回复", + "allOf": [ + { + "$ref": "#/definitions/consts.ChatRole" + } + ] + } + } + }, + "domain.ChatInfo": { + "type": "object", + "properties": { + "contents": { + "description": "消息内容", + "type": "array", + "items": { + "$ref": "#/definitions/domain.ChatContent" + } + }, "id": { "type": "string" } diff --git a/backend/domain/billing.go b/backend/domain/billing.go index 290910f..efea1ad 100644 --- a/backend/domain/billing.go +++ b/backend/domain/billing.go @@ -5,6 +5,7 @@ import ( "github.com/GoYoko/web" + "github.com/chaitin/MonkeyCode/backend/consts" "github.com/chaitin/MonkeyCode/backend/db" "github.com/chaitin/MonkeyCode/backend/pkg/cvt" ) @@ -58,7 +59,9 @@ func (c *ChatRecord) From(e *db.Task) *ChatRecord { return c } c.ID = e.TaskID - c.Question = e.Prompt + if len(e.Edges.TaskRecords) > 0 { + c.Question = e.Edges.TaskRecords[0].Prompt + } c.User = cvt.From(e.Edges.User, &User{}) c.Model = cvt.From(e.Edges.Model, &Model{}) c.WorkMode = e.WorkMode @@ -104,18 +107,33 @@ func (c *CompletionInfo) From(e *db.Task) *CompletionInfo { return c } c.ID = e.TaskID - c.Prompt = e.Prompt if len(e.Edges.TaskRecords) > 0 { + c.Prompt = e.Edges.TaskRecords[0].Prompt c.Content = e.Edges.TaskRecords[0].Completion } c.CreatedAt = e.CreatedAt.Unix() return c } +type ChatContent struct { + Role consts.ChatRole `json:"role"` // 角色,如user: 用户的提问 assistant: 机器人回复 + Content string `json:"content"` // 内容 + CreatedAt int64 `json:"created_at"` +} + +func (c *ChatContent) From(e *db.TaskRecord) *ChatContent { + if e == nil { + return c + } + c.Role = e.Role + c.Content = e.Completion + c.CreatedAt = e.CreatedAt.Unix() + return c +} + type ChatInfo struct { - ID string `json:"id"` - Content string `json:"content"` - CreatedAt int64 `json:"created_at"` + ID string `json:"id"` + Contents []*ChatContent `json:"contents"` // 消息内容 } func (c *ChatInfo) From(e *db.Task) *ChatInfo { @@ -123,10 +141,9 @@ func (c *ChatInfo) From(e *db.Task) *ChatInfo { return c } c.ID = e.TaskID - for _, tr := range e.Edges.TaskRecords { - c.Content += tr.Completion + "\n" - } - c.CreatedAt = e.CreatedAt.Unix() + c.Contents = cvt.Iter(e.Edges.TaskRecords, func(_ int, r *db.TaskRecord) *ChatContent { + return cvt.From(r, &ChatContent{}) + }) return c } diff --git a/backend/domain/proxy.go b/backend/domain/proxy.go index 554a880..2b0bdbb 100644 --- a/backend/domain/proxy.go +++ b/backend/domain/proxy.go @@ -43,6 +43,7 @@ type RecordParam struct { UserID string ModelID string ModelType consts.ModelType + Role consts.ChatRole Prompt string ProgramLanguage string InputTokens int64 @@ -52,3 +53,22 @@ type RecordParam struct { WorkMode string CodeLines int64 } + +func (r *RecordParam) Clone() *RecordParam { + return &RecordParam{ + RequestID: r.RequestID, + TaskID: r.TaskID, + UserID: r.UserID, + ModelID: r.ModelID, + ModelType: r.ModelType, + Role: r.Role, + Prompt: r.Prompt, + ProgramLanguage: r.ProgramLanguage, + InputTokens: r.InputTokens, + OutputTokens: r.OutputTokens, + IsAccept: r.IsAccept, + Completion: r.Completion, + WorkMode: r.WorkMode, + CodeLines: r.CodeLines, + } +} diff --git a/backend/ent/schema/task.go b/backend/ent/schema/task.go index 33d7047..265385a 100644 --- a/backend/ent/schema/task.go +++ b/backend/ent/schema/task.go @@ -34,7 +34,6 @@ func (Task) Fields() []ent.Field { field.UUID("model_id", uuid.UUID{}).Optional(), field.String("request_id").Optional(), field.String("model_type").GoType(consts.ModelType("")), - field.String("prompt").Optional(), field.Bool("is_accept").Default(false), field.String("program_language").Optional(), field.String("work_mode").Optional(), diff --git a/backend/ent/schema/taskrecord.go b/backend/ent/schema/taskrecord.go index b333189..edb1f3e 100644 --- a/backend/ent/schema/taskrecord.go +++ b/backend/ent/schema/taskrecord.go @@ -9,6 +9,8 @@ import ( "entgo.io/ent/schema/edge" "entgo.io/ent/schema/field" "github.com/google/uuid" + + "github.com/chaitin/MonkeyCode/backend/consts" ) // TaskRecord holds the schema definition for the TaskRecord entity. @@ -29,6 +31,8 @@ func (TaskRecord) Fields() []ent.Field { return []ent.Field{ field.UUID("id", uuid.UUID{}).Unique(), field.UUID("task_id", uuid.UUID{}).Optional(), + field.String("prompt").Optional(), + field.String("role").GoType(consts.ChatRole("")), field.String("completion"), field.Int64("output_tokens"), field.Time("created_at").Default(time.Now), diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index b4a2bbc..5ca39ab 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -11,6 +11,7 @@ import ( "github.com/google/uuid" "github.com/chaitin/MonkeyCode/backend/consts" + "github.com/chaitin/MonkeyCode/backend/pkg/entx" ) // User holds the schema definition for the User entity. @@ -26,6 +27,12 @@ func (User) Annotations() []schema.Annotation { } } +func (User) Mixin() []ent.Mixin { + return []ent.Mixin{ + entx.SoftDeleteMixin{}, + } +} + // Fields of the User. func (User) Fields() []ent.Field { return []ent.Field{ diff --git a/backend/internal/billing/repo/billing.go b/backend/internal/billing/repo/billing.go index fa55349..4bf15bd 100644 --- a/backend/internal/billing/repo/billing.go +++ b/backend/internal/billing/repo/billing.go @@ -57,7 +57,9 @@ func (b *BillingRepo) ListChatRecord(ctx context.Context, req domain.ListRecordR q := b.db.Task.Query(). WithUser(). WithModel(). - WithTaskRecords(). + WithTaskRecords(func(trq *db.TaskRecordQuery) { + trq.Order(taskrecord.ByCreatedAt(sql.OrderAsc())) + }). Where(task.ModelType(consts.ModelTypeLLM)). Order(task.ByCreatedAt(sql.OrderDesc())) @@ -101,6 +103,9 @@ func (b *BillingRepo) ListCompletionRecord(ctx context.Context, req domain.ListR q := b.db.Task.Query(). WithUser(). WithModel(). + WithTaskRecords(func(trq *db.TaskRecordQuery) { + trq.Order(taskrecord.ByCreatedAt(sql.OrderAsc())) + }). Where(task.ModelType(consts.ModelTypeCoder)). Order(task.ByCreatedAt(sql.OrderDesc())) diff --git a/backend/internal/openai/handler/v1/v1.go b/backend/internal/openai/handler/v1/v1.go index 423d14d..24226c2 100644 --- a/backend/internal/openai/handler/v1/v1.go +++ b/backend/internal/openai/handler/v1/v1.go @@ -104,7 +104,6 @@ func (h *V1Handler) ChatCompletion(c *web.Context, req openai.ChatCompletionRequ return BadRequest(c, "模型不能为空") } - h.logger.With("request", req).DebugContext(c.Request().Context(), "处理聊天补全请求") // if len(req.Tools) > 0 && req.Model != "qwen-max" { // if h.toolsCall(c, req, req.Stream, req.Model) { // return nil diff --git a/backend/internal/proxy/proxy.go b/backend/internal/proxy/proxy.go index a8b32c6..cfabd44 100644 --- a/backend/internal/proxy/proxy.go +++ b/backend/internal/proxy/proxy.go @@ -20,7 +20,9 @@ import ( "github.com/chaitin/MonkeyCode/backend/config" "github.com/chaitin/MonkeyCode/backend/consts" "github.com/chaitin/MonkeyCode/backend/domain" + "github.com/chaitin/MonkeyCode/backend/pkg/cvt" "github.com/chaitin/MonkeyCode/backend/pkg/logger" + "github.com/chaitin/MonkeyCode/backend/pkg/promptparser" "github.com/chaitin/MonkeyCode/backend/pkg/request" ) @@ -346,6 +348,7 @@ func (p *LLMProxy) handleCompletionStream(ctx context.Context, w http.ResponseWr ModelID: m.ID, ModelType: consts.ModelTypeLLM, Prompt: req.Prompt.(string), + Role: consts.ChatRoleAssistant, } buf := bufio.NewWriterSize(w, 32*1024) defer buf.Flush() @@ -537,32 +540,44 @@ func streamRead(ctx context.Context, r io.Reader, fn func([]byte) error) error { } } -func getPrompt(req *openai.ChatCompletionRequest) string { +func (p *LLMProxy) getPrompt(ctx context.Context, req *openai.ChatCompletionRequest) string { + prompt := "" + parse := promptparser.New(promptparser.KindTask) for _, message := range req.Messages { if message.Role == "system" { continue } - if strings.Contains(message.Content, "") { - return message.Content + if strings.Contains(message.Content, "") || + strings.Contains(message.Content, "") || + strings.Contains(message.Content, "") { + if info, err := parse.Parse(message.Content); err == nil { + prompt = info.Prompt + } else { + p.logger.With("message", message.Content).WarnContext(ctx, "解析Prompt失败", "error", err) + } } for _, m := range message.MultiContent { - if strings.Contains(m.Text, "") { - return m.Text + if strings.Contains(m.Text, "") || + strings.Contains(m.Text, "") || + strings.Contains(m.Text, "") { + if info, err := parse.Parse(m.Text); err == nil { + prompt = info.Prompt + } else { + p.logger.With("message", m.Text).WarnContext(ctx, "解析Prompt失败", "error", err) + } } } } - return "" + return prompt } func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.ResponseWriter, req *openai.ChatCompletionRequest) { endpoint := "/chat/completions" p.handle(ctx, func(c *Ctx, log *RequestResponseLog) error { - // 记录开始时间用于性能监控 startTime := time.Now() - // 使用负载均衡算法选择模型 m, err := p.usecase.SelectModelWithLoadBalancing(req.Model, consts.ModelTypeLLM) if err != nil { p.logger.With("modelName", req.Model, "modelType", consts.ModelTypeLLM).WarnContext(ctx, "流式请求模型选择失败", "error", err) @@ -570,37 +585,34 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon return err } - prompt := getPrompt(req) + prompt := p.getPrompt(ctx, req) mode := req.Metadata["mode"] taskID := req.Metadata["task_id"] - // 构造上游API URL upstream := m.APIBase + endpoint log.UpstreamURL = upstream - // 创建上游请求 body, err := json.Marshal(req) if err != nil { p.logger.ErrorContext(ctx, "序列化请求体失败", "error", err) return fmt.Errorf("序列化请求体失败: %w", err) } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstream, bytes.NewReader(body)) + newReq, err := http.NewRequestWithContext(ctx, http.MethodPost, upstream, bytes.NewReader(body)) if err != nil { p.logger.With("upstream", upstream).WarnContext(ctx, "创建上游流式请求失败", "error", err) return fmt.Errorf("创建上游请求失败: %w", err) } - // 设置请求头 - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "text/event-stream") + newReq.Header.Set("Content-Type", "application/json") + newReq.Header.Set("Accept", "text/event-stream") if m.APIKey != "" && m.APIKey != "none" { - req.Header.Set("Authorization", "Bearer "+m.APIKey) + newReq.Header.Set("Authorization", "Bearer "+m.APIKey) } // 保存请求头(去除敏感信息) requestHeaders := make(map[string][]string) - for k, v := range req.Header { + for k, v := range newReq.Header { if k != "Authorization" { requestHeaders[k] = v } else { @@ -616,13 +628,16 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon "modelType", consts.ModelTypeLLM, "apiBase", m.APIBase, "work_mode", mode, - "requestHeader", req.Header, - "requestBody", req, + "requestHeader", newReq.Header, + "requestBody", newReq, "taskID", taskID, + "messages", cvt.Filter(req.Messages, func(i int, v openai.ChatCompletionMessage) (openai.ChatCompletionMessage, bool) { + return v, v.Role != "system" + }), ).DebugContext(ctx, "转发流式请求到上游API") // 发送请求 - resp, err := p.client.Do(req) + resp, err := p.client.Do(newReq) if err != nil { p.logger.With("upstreamURL", upstream).WarnContext(ctx, "发送上游流式请求失败", "error", err) return fmt.Errorf("发送上游请求失败: %w", err) @@ -651,7 +666,7 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon p.logger.With( "endpoint", endpoint, "upstreamURL", upstream, - "requestBody", req, + "requestBody", newReq, "statusCode", resp.StatusCode, "errorType", errorResp.Error.Type, "errorCode", errorResp.Error.Code, @@ -665,7 +680,7 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon p.logger.With( "endpoint", endpoint, "upstreamURL", upstream, - "requestBody", req, + "requestBody", newReq, "statusCode", resp.StatusCode, "responseBody", string(responseBody), ).WarnContext(ctx, "上游API流式请求异常详情") @@ -697,6 +712,7 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon ModelType: consts.ModelTypeLLM, WorkMode: mode, Prompt: prompt, + Role: consts.ChatRoleAssistant, } buf := bufio.NewWriterSize(w, 32*1024) defer buf.Flush() @@ -705,6 +721,16 @@ func (p *LLMProxy) handleChatCompletionStream(ctx context.Context, w http.Respon defer close(ch) go func(rc *domain.RecordParam) { + if rc.Prompt != "" { + urc := rc.Clone() + urc.Role = consts.ChatRoleUser + urc.Completion = urc.Prompt + if err := p.usecase.Record(context.Background(), urc); err != nil { + p.logger.With("modelID", m.ID, "modelName", m.ModelName, "modelType", consts.ModelTypeLLM). + WarnContext(ctx, "插入流式记录失败", "error", err) + } + } + for line := range ch { if bytes.HasPrefix(line, []byte("data:")) { line = bytes.TrimPrefix(line, []byte("data: ")) @@ -774,7 +800,7 @@ func (p *LLMProxy) handleChatCompletion(ctx context.Context, w http.ResponseWrit } startTime := time.Now() - prompt := getPrompt(req) + prompt := p.getPrompt(ctx, req) mode := req.Metadata["mode"] taskID := req.Metadata["task_id"] diff --git a/backend/internal/proxy/repo/proxy.go b/backend/internal/proxy/repo/proxy.go index edfeca8..4b90d14 100644 --- a/backend/internal/proxy/repo/proxy.go +++ b/backend/internal/proxy/repo/proxy.go @@ -10,6 +10,7 @@ import ( "github.com/chaitin/MonkeyCode/backend/db/apikey" "github.com/chaitin/MonkeyCode/backend/db/model" "github.com/chaitin/MonkeyCode/backend/db/task" + "github.com/chaitin/MonkeyCode/backend/db/taskrecord" "github.com/chaitin/MonkeyCode/backend/domain" "github.com/chaitin/MonkeyCode/backend/pkg/entx" ) @@ -62,7 +63,6 @@ func (r *ProxyRepo) Record(ctx context.Context, record *domain.RecordParam) erro SetRequestID(record.RequestID). SetUserID(userID). SetModelID(modelID). - SetPrompt(record.Prompt). SetProgramLanguage(record.ProgramLanguage). SetInputTokens(record.InputTokens). SetOutputTokens(record.OutputTokens). @@ -93,8 +93,26 @@ func (r *ProxyRepo) Record(ctx context.Context, record *domain.RecordParam) erro } } + if record.Role == consts.ChatRoleUser { + count, err := tx.TaskRecord.Query(). + Where( + taskrecord.TaskID(t.ID), + taskrecord.Role(consts.ChatRoleUser), + taskrecord.Prompt(record.Prompt), + ). + Count(ctx) + if err != nil { + return err + } + if count > 0 { + return nil + } + } + _, err = tx.TaskRecord.Create(). SetTaskID(t.ID). + SetRole(record.Role). + SetPrompt(record.Prompt). SetCompletion(record.Completion). SetOutputTokens(record.OutputTokens). Save(ctx) diff --git a/backend/migration/000002_create_core_table.up.sql b/backend/migration/000002_create_core_table.up.sql index f93378a..0fd19d0 100644 --- a/backend/migration/000002_create_core_table.up.sql +++ b/backend/migration/000002_create_core_table.up.sql @@ -27,7 +27,8 @@ CREATE TABLE IF NOT EXISTS users ( status VARCHAR(20) DEFAULT 'active', platform VARCHAR(12), created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP + updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, + deleted_at TIMESTAMPTZ ); CREATE UNIQUE INDEX IF NOT EXISTS unique_idx_users_username ON users (username) WHERE username IS NOT NULL; @@ -122,7 +123,6 @@ CREATE TABLE IF NOT EXISTS tasks ( model_id UUID NOT NULL, request_id VARCHAR(255), model_type VARCHAR(255) NOT NULL, - prompt TEXT, completion TEXT, is_accept BOOLEAN DEFAULT FALSE, program_language VARCHAR(255), @@ -146,6 +146,8 @@ CREATE INDEX IF NOT EXISTS idx_tasks_updated_at ON tasks (updated_at); CREATE TABLE IF NOT EXISTS task_records ( id UUID PRIMARY KEY DEFAULT uuid_generate_v1(), task_id UUID NOT NULL, + role VARCHAR(255), + prompt TEXT, completion TEXT, output_tokens BIGINT, created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, diff --git a/backend/pkg/promptparser/promptparse.go b/backend/pkg/promptparser/promptparse.go index fcc10b9..2eb80f6 100644 --- a/backend/pkg/promptparser/promptparse.go +++ b/backend/pkg/promptparser/promptparse.go @@ -1,8 +1,11 @@ package promptparser import ( + "encoding/xml" "fmt" "regexp" + + "github.com/chaitin/MonkeyCode/backend/pkg/cvt" ) type PromptParser interface { @@ -51,15 +54,25 @@ func (n *NormalParser) Parse(prompt string) (*Info, error) { } type TaskParse struct { + Task string `xml:"task"` + Feedback string `xml:"feedback"` + UserMessage string `xml:"user_message"` } func (m *TaskParse) Parse(prompt string) (*Info, error) { - re := regexp.MustCompile(`(.*)(.*)(.*)`) - match := re.FindStringSubmatch(prompt) - if len(match) < 5 { - return nil, fmt.Errorf("invalid prompt") + var tp TaskParse + prompt = "" + prompt + "" + if err := xml.Unmarshal([]byte(prompt), &tp); err != nil { + return nil, err } + return &Info{ - Prompt: match[1], + Prompt: cvt.CanditionVar(func() (string, bool) { + return tp.Task, tp.Task != "" + }, func() (string, bool) { + return tp.Feedback, tp.Feedback != "" + }, func() (string, bool) { + return tp.UserMessage, tp.UserMessage != "" + }), }, nil } diff --git a/backend/pkg/promptparser/promptparse_test.go b/backend/pkg/promptparser/promptparse_test.go new file mode 100644 index 0000000..1f1605f --- /dev/null +++ b/backend/pkg/promptparser/promptparse_test.go @@ -0,0 +1,17 @@ +package promptparser + +import ( + "fmt" + "testing" +) + +func TestTaskParse(t *testing.T) { + prompt := "The user denied this operation and provided the following feedback:\n\u003cfeedback\u003e\n直接写到 'src/notion/mod.rs' (see below for file content) 这个文件\n\u003c/feedback\u003e\n\u003cfiles\u003e\n\u003cfile\u003e\u003cpath\u003eCargo.toml\u003c/path\u003e\u003cstatus\u003eDenied by user\u003c/status\u003e\u003c/file\u003e\n\u003c/files\u003e\n\n\u003cfile_content path=\"src/notion/mod.rs\"\u003e\n\n\u003c/file_content\u003e" + + tp := TaskParse{} + info, err := tp.Parse(prompt) + if err != nil { + t.Fatal(err) + } + fmt.Printf("%+v\n", info) +}