diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index 65808d1..53f46f7 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -14,6 +14,7 @@ import ( "github.com/chaitin/MonkeyCode/backend/db" "github.com/chaitin/MonkeyCode/backend/domain" billingv1 "github.com/chaitin/MonkeyCode/backend/internal/billing/handler/http/v1" + codesnippetv1 "github.com/chaitin/MonkeyCode/backend/internal/codesnippet/handler/http/v1" dashv1 "github.com/chaitin/MonkeyCode/backend/internal/dashboard/handler/v1" v1 "github.com/chaitin/MonkeyCode/backend/internal/model/handler/http/v1" openaiV1 "github.com/chaitin/MonkeyCode/backend/internal/openai/handler/v1" @@ -24,20 +25,21 @@ import ( ) type Server struct { - config *config.Config - web *web.Web - ent *db.Client - logger *slog.Logger - openaiV1 *openaiV1.V1Handler - modelV1 *v1.ModelHandler - userV1 *userV1.UserHandler - dashboardV1 *dashv1.DashboardHandler - billingV1 *billingv1.BillingHandler - socketH *sockethandler.SocketHandler - version *version.VersionInfo - report *report.Reporter - reportuse domain.ReportUsecase - euse domain.ExtensionUsecase + config *config.Config + web *web.Web + ent *db.Client + logger *slog.Logger + openaiV1 *openaiV1.V1Handler + modelV1 *v1.ModelHandler + userV1 *userV1.UserHandler + dashboardV1 *dashv1.DashboardHandler + billingV1 *billingv1.BillingHandler + socketH *sockethandler.SocketHandler + version *version.VersionInfo + report *report.Reporter + reportuse domain.ReportUsecase + euse domain.ExtensionUsecase + codeSnippetV1 *codesnippetv1.CodeSnippetHandler } func newServer() (*Server, error) { diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 0a8fe9e..39e8595 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -14,6 +14,7 @@ import ( v1_5 "github.com/chaitin/MonkeyCode/backend/internal/billing/handler/http/v1" repo7 "github.com/chaitin/MonkeyCode/backend/internal/billing/repo" usecase6 "github.com/chaitin/MonkeyCode/backend/internal/billing/usecase" + v1_6 "github.com/chaitin/MonkeyCode/backend/internal/codesnippet/handler/http/v1" repo9 "github.com/chaitin/MonkeyCode/backend/internal/codesnippet/repo" usecase8 "github.com/chaitin/MonkeyCode/backend/internal/codesnippet/usecase" v1_4 "github.com/chaitin/MonkeyCode/backend/internal/dashboard/handler/v1" @@ -107,21 +108,23 @@ func newServer() (*Server, error) { reporter := report.NewReport(slogLogger, configConfig, versionInfo) reportRepo := repo10.NewReportRepo(client) reportUsecase := usecase9.NewReportUsecase(reportRepo, slogLogger, reporter, redisClient) + codeSnippetHandler := v1_6.NewCodeSnippetHandler(web, codeSnippetUsecase, authMiddleware, activeMiddleware, readOnlyMiddleware, proxyMiddleware, slogLogger) server := &Server{ - config: configConfig, - web: web, - ent: client, - logger: slogLogger, - openaiV1: v1Handler, - modelV1: modelHandler, - userV1: userHandler, - dashboardV1: dashboardHandler, - billingV1: billingHandler, - socketH: socketHandler, - version: versionInfo, - report: reporter, - reportuse: reportUsecase, - euse: extensionUsecase, + config: configConfig, + web: web, + ent: client, + logger: slogLogger, + openaiV1: v1Handler, + modelV1: modelHandler, + userV1: userHandler, + dashboardV1: dashboardHandler, + billingV1: billingHandler, + socketH: socketHandler, + version: versionInfo, + report: reporter, + reportuse: reportUsecase, + euse: extensionUsecase, + codeSnippetV1: codeSnippetHandler, } return server, nil } @@ -129,18 +132,19 @@ func newServer() (*Server, error) { // wire.go: type Server struct { - config *config.Config - web *web.Web - ent *db.Client - logger *slog.Logger - openaiV1 *v1.V1Handler - modelV1 *v1_2.ModelHandler - userV1 *v1_3.UserHandler - dashboardV1 *v1_4.DashboardHandler - billingV1 *v1_5.BillingHandler - socketH *handler.SocketHandler - version *version.VersionInfo - report *report.Reporter - reportuse domain.ReportUsecase - euse domain.ExtensionUsecase + config *config.Config + web *web.Web + ent *db.Client + logger *slog.Logger + openaiV1 *v1.V1Handler + modelV1 *v1_2.ModelHandler + userV1 *v1_3.UserHandler + dashboardV1 *v1_4.DashboardHandler + billingV1 *v1_5.BillingHandler + socketH *handler.SocketHandler + version *version.VersionInfo + report *report.Reporter + reportuse domain.ReportUsecase + euse domain.ExtensionUsecase + codeSnippetV1 *v1_6.CodeSnippetHandler } diff --git a/backend/docs/swagger.json b/backend/docs/swagger.json index 4423de1..fe8b9ad 100644 --- a/backend/docs/swagger.json +++ b/backend/docs/swagger.json @@ -1157,6 +1157,61 @@ } } }, + "/api/v1/ide/codesnippet/context": { + "post": { + "security": [ + { + "ApiKeyAuth": [] + } + ], + "description": "为IDE端提供代码片段上下文检索功能,使用API Key认证。支持单个查询和批量查询。", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "CodeSnippet" + ], + "summary": "IDE端上下文检索", + "operationId": "get-context", + "parameters": [ + { + "description": "检索请求参数", + "name": "request", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/internal_codesnippet_handler_http_v1.GetContextReq" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "allOf": [ + { + "$ref": "#/definitions/web.Resp" + }, + { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/definitions/domain.CodeSnippet" + } + } + } + } + ] + } + } + } + } + }, "/api/v1/model": { "get": { "description": "获取模型列表", @@ -3797,6 +3852,98 @@ "CodeLanguageTypeCpp" ] }, + "domain.CodeSnippet": { + "type": "object", + "properties": { + "container_name": { + "description": "容器名称", + "type": "string" + }, + "content": { + "description": "代码片段内容", + "type": "string" + }, + "definition_text": { + "description": "定义文本", + "type": "string" + }, + "dependencies": { + "description": "依赖项", + "type": "array", + "items": { + "type": "string" + } + }, + "end_column": { + "description": "结束列号", + "type": "integer" + }, + "end_line": { + "description": "结束行号", + "type": "integer" + }, + "hash": { + "description": "内容哈希", + "type": "string" + }, + "id": { + "description": "代码片段ID", + "type": "string" + }, + "language": { + "description": "编程语言", + "type": "string" + }, + "name": { + "description": "代码片段名称", + "type": "string" + }, + "namespace": { + "description": "命名空间", + "type": "string" + }, + "parameters": { + "description": "参数列表", + "type": "array", + "items": { + "type": "object", + "additionalProperties": {} + } + }, + "scope": { + "description": "作用域信息", + "type": "array", + "items": { + "type": "string" + } + }, + "signature": { + "description": "函数签名", + "type": "string" + }, + "snippet_type": { + "description": "代码片段类型", + "type": "string" + }, + "start_column": { + "description": "起始列号", + "type": "integer" + }, + "start_line": { + "description": "起始行号", + "type": "integer" + }, + "structured_info": { + "description": "结构化信息", + "type": "object", + "additionalProperties": {} + }, + "workspace_file_id": { + "description": "关联的workspace file ID", + "type": "string" + } + } + }, "domain.CompletionData": { "type": "object", "properties": { @@ -5555,6 +5702,51 @@ } } }, + "internal_codesnippet_handler_http_v1.GetContextReq": { + "type": "object", + "properties": { + "limit": { + "description": "返回结果数量限制,默认10", + "type": "integer" + }, + "queries": { + "description": "批量查询参数", + "type": "array", + "items": { + "$ref": "#/definitions/internal_codesnippet_handler_http_v1.Query" + } + }, + "query": { + "description": "单个查询参数", + "allOf": [ + { + "$ref": "#/definitions/internal_codesnippet_handler_http_v1.Query" + } + ] + }, + "workspace_path": { + "description": "工作区路径(必填)", + "type": "string" + } + } + }, + "internal_codesnippet_handler_http_v1.Query": { + "type": "object", + "properties": { + "language": { + "description": "编程语言(可选)", + "type": "string" + }, + "name": { + "description": "代码片段名称(可选)", + "type": "string" + }, + "type": { + "description": "代码片段类型(可选)", + "type": "string" + } + } + }, "web.Resp": { "type": "object", "properties": { diff --git a/backend/domain/codesnippet.go b/backend/domain/codesnippet.go index 57540e3..9868444 100644 --- a/backend/domain/codesnippet.go +++ b/backend/domain/codesnippet.go @@ -12,6 +12,8 @@ type CodeSnippetUsecase interface { ListByWorkspaceFile(ctx context.Context, workspaceFileID string) ([]*CodeSnippet, error) GetByID(ctx context.Context, id string) (*CodeSnippet, error) Delete(ctx context.Context, id string) error + Search(ctx context.Context, name, snippetType, language string) ([]*CodeSnippet, error) + SearchByWorkspace(ctx context.Context, userID, workspacePath, name, snippetType, language string) ([]*CodeSnippet, error) } // CodeSnippetRepo 定义 CodeSnippet 数据访问接口 @@ -20,6 +22,8 @@ type CodeSnippetRepo interface { ListByWorkspaceFile(ctx context.Context, workspaceFileID string) ([]*db.CodeSnippet, error) GetByID(ctx context.Context, id string) (*db.CodeSnippet, error) Delete(ctx context.Context, id string) error + Search(ctx context.Context, name, snippetType, language string) ([]*db.CodeSnippet, error) + SearchByWorkspace(ctx context.Context, userID, workspacePath, name, snippetType, language string) ([]*db.CodeSnippet, error) } // 请求结构体 diff --git a/backend/internal/codesnippet/handler/http/v1/codesnippet.go b/backend/internal/codesnippet/handler/http/v1/codesnippet.go new file mode 100644 index 0000000..c520e15 --- /dev/null +++ b/backend/internal/codesnippet/handler/http/v1/codesnippet.go @@ -0,0 +1,139 @@ +package v1 + +import ( + "fmt" + "log/slog" + + "github.com/GoYoko/web" + + "github.com/chaitin/MonkeyCode/backend/domain" + "github.com/chaitin/MonkeyCode/backend/internal/middleware" + "github.com/chaitin/MonkeyCode/backend/pkg/logger" +) + +type CodeSnippetHandler struct { + usecase domain.CodeSnippetUsecase + logger *slog.Logger +} + +func NewCodeSnippetHandler( + w *web.Web, + usecase domain.CodeSnippetUsecase, + auth *middleware.AuthMiddleware, + active *middleware.ActiveMiddleware, + readonly *middleware.ReadOnlyMiddleware, + proxy *middleware.ProxyMiddleware, + logger *slog.Logger, +) *CodeSnippetHandler { + h := &CodeSnippetHandler{ + usecase: usecase, + logger: logger.With("handler", "codesnippet"), + } + + // 设置路由 - 使用API Key认证的接口(IDE端使用) + ide := w.Group("/api/v1/ide/codesnippet") + ide.Use(proxy.Auth(), active.Active("apikey"), readonly.Guard()) + + // IDE端上下文检索接口 + ide.POST("/context", web.BindHandler(h.GetContext)) + + return h +} + +// GetContextReq IDE端上下文检索请求 +type GetContextReq struct { + // 批量查询参数 + Queries []Query `json:"queries,omitempty"` // 批量查询条件 + + // 单个查询参数 + Query Query `json:"query,omitempty"` // 单个查询条件 + + Limit int `json:"limit"` // 返回结果数量限制,默认10 + WorkspacePath string `json:"workspace_path"` // 工作区路径(必填) +} + +// Query 批量查询条件 +type Query struct { + Name string `json:"name,omitempty"` // 代码片段名称(可选) + Type string `json:"type,omitempty"` // 代码片段类型(可选) + Language string `json:"language,omitempty"` // 编程语言(可选) +} + +// GetContext IDE端上下文检索接口 +// +// @Tags CodeSnippet +// @Summary IDE端上下文检索 +// @Description 为IDE端提供代码片段上下文检索功能,使用API Key认证。支持单个查询和批量查询。 +// @ID get-context +// @Accept json +// @Produce json +// @Param request body GetContextReq true "检索请求参数" +// @Success 200 {object} web.Resp{data=[]domain.CodeSnippet} +// @Router /api/v1/ide/codesnippet/context [post] +// @Security ApiKeyAuth +func (h *CodeSnippetHandler) GetContext(c *web.Context, req GetContextReq) error { + // 设置默认限制 + if req.Limit <= 0 { + req.Limit = 10 + } + if req.Limit > 50 { + req.Limit = 50 // 最大限制50个结果 + } + + // 如果没有提供workspace_path,则返回空结果 + if req.WorkspacePath == "" { + return c.Success([]*domain.CodeSnippet{}) + } + + // 获取用户ID,主要使用API Key认证 + var userID string + if ctxUserID := c.Request().Context().Value(logger.UserIDKey{}); ctxUserID != nil { + userID = ctxUserID.(string) + } else { + h.logger.Error("API Key authentication required for IDE context retrieval") + return fmt.Errorf("API Key authentication required") + } + + var allSnippets []*domain.CodeSnippet + + // 如果提供了批量查询条件,则执行批量查询 + if len(req.Queries) > 0 { + // 执行批量查询 + for _, query := range req.Queries { + snippets, err := h.usecase.SearchByWorkspace(c.Request().Context(), userID, req.WorkspacePath, query.Name, query.Type, query.Language) + if err != nil { + h.logger.Error("failed to get context for IDE", "error", err) + return err + } + allSnippets = append(allSnippets, snippets...) + } + } else { + // 执行单个查询 + snippets, err := h.usecase.SearchByWorkspace(c.Request().Context(), userID, req.WorkspacePath, req.Query.Name, req.Query.Type, req.Query.Language) + if err != nil { + h.logger.Error("failed to get context for IDE", "error", err) + return err + } + allSnippets = snippets + } + + // 限制返回结果数量 + if len(allSnippets) > req.Limit { + allSnippets = allSnippets[:req.Limit] + } + + h.logger.Info("IDE context retrieval completed", + "userID", c.Request().Context().Value(logger.UserIDKey{}), + "resultCount", len(allSnippets), + "filters", map[string]interface{}{ + "singleQuery": map[string]string{ + "name": req.Query.Name, + "type": req.Query.Type, + "language": req.Query.Language, + }, + "batchQueryCount": len(req.Queries), + "workspace_path": req.WorkspacePath, + }) + + return c.Success(allSnippets) +} diff --git a/backend/internal/codesnippet/repo/codesnippet.go b/backend/internal/codesnippet/repo/codesnippet.go index 4845561..a1438e8 100644 --- a/backend/internal/codesnippet/repo/codesnippet.go +++ b/backend/internal/codesnippet/repo/codesnippet.go @@ -6,6 +6,10 @@ import ( "log/slog" "github.com/chaitin/MonkeyCode/backend/db" + "github.com/chaitin/MonkeyCode/backend/db/codesnippet" + "github.com/chaitin/MonkeyCode/backend/db/predicate" + "github.com/chaitin/MonkeyCode/backend/db/workspace" + "github.com/chaitin/MonkeyCode/backend/db/workspacefile" "github.com/chaitin/MonkeyCode/backend/domain" "github.com/google/uuid" ) @@ -60,22 +64,165 @@ func (r *CodeSnippetRepo) Create(ctx context.Context, req *domain.CreateCodeSnip } func (r *CodeSnippetRepo) ListByWorkspaceFile(ctx context.Context, workspaceFileID string) ([]*db.CodeSnippet, error) { - // 实现列出特定工作区文件的所有代码片段的逻辑 - // 这里需要将 workspaceFileID 字符串转换为 UUID - // 为简化起见,这里暂时返回空列表,实际实现需要根据需求完成 - return []*db.CodeSnippet{}, nil + // 将 workspaceFileID 字符串转换为 UUID + workspaceFileUUID, err := uuid.Parse(workspaceFileID) + if err != nil { + r.logger.Error("failed to parse workspace file ID", "error", err, "id", workspaceFileID) + return nil, fmt.Errorf("invalid workspace file ID: %w", err) + } + + // 查询特定工作区文件的所有代码片段 + snippets, err := r.client.CodeSnippet.Query(). + Where(codesnippet.WorkspaceFileID(workspaceFileUUID)). + All(ctx) + if err != nil { + r.logger.Error("failed to list code snippets by workspace file", "error", err, "workspaceFileID", workspaceFileID) + return nil, fmt.Errorf("failed to list code snippets: %w", err) + } + + return snippets, nil } func (r *CodeSnippetRepo) GetByID(ctx context.Context, id string) (*db.CodeSnippet, error) { - // 实现根据 ID 获取代码片段的逻辑 - // 这里需要将 id 字符串转换为 UUID - // 为简化起见,这里暂时返回 nil,实际实现需要根据需求完成 - return nil, nil + // 将 id 字符串转换为 UUID + uuid, err := uuid.Parse(id) + if err != nil { + r.logger.Error("failed to parse code snippet ID", "error", err, "id", id) + return nil, fmt.Errorf("invalid code snippet ID: %w", err) + } + + // 根据 ID 获取代码片段 + snippet, err := r.client.CodeSnippet.Get(ctx, uuid) + if err != nil { + r.logger.Error("failed to get code snippet by ID", "error", err, "id", id) + return nil, fmt.Errorf("failed to get code snippet: %w", err) + } + + return snippet, nil } func (r *CodeSnippetRepo) Delete(ctx context.Context, id string) error { - // 实现删除代码片段的逻辑 - // 这里需要将 id 字符串转换为 UUID - // 为简化起见,这里暂时返回 nil,实际实现需要根据需求完成 + // 将 id 字符串转换为 UUID + uuid, err := uuid.Parse(id) + if err != nil { + r.logger.Error("failed to parse code snippet ID", "error", err, "id", id) + return fmt.Errorf("invalid code snippet ID: %w", err) + } + + // 删除代码片段 + err = r.client.CodeSnippet.DeleteOneID(uuid).Exec(ctx) + if err != nil { + r.logger.Error("failed to delete code snippet", "error", err, "id", id) + return fmt.Errorf("failed to delete code snippet: %w", err) + } + return nil } + +// Search 根据名称、类型和语言搜索代码片段 +func (r *CodeSnippetRepo) Search(ctx context.Context, name, snippetType, language string) ([]*db.CodeSnippet, error) { + // 构建查询 + query := r.client.CodeSnippet.Query() + + // 创建一个切片来存储所有谓词 + var predicates []predicate.CodeSnippet + + // 如果提供了名称参数,则添加名称过滤条件 + if name != "" { + predicates = append(predicates, codesnippet.Name(name)) + } + + // 如果提供了类型参数,则添加类型过滤条件 + if snippetType != "" { + predicates = append(predicates, codesnippet.SnippetType(snippetType)) + } + + // 如果提供了语言参数,则添加语言过滤条件 + if language != "" { + predicates = append(predicates, codesnippet.Language(language)) + } + + // 如果有任何谓词,将它们添加到查询中 + if len(predicates) > 0 { + query = query.Where(codesnippet.And(predicates...)) + } + + // 执行查询并返回结果 + return query.All(ctx) +} + +// SearchByWorkspace 根据用户ID、工作区路径和搜索条件搜索代码片段 +// 只有在提供了至少一个搜索条件时才返回结果,否则返回空数组 +func (r *CodeSnippetRepo) SearchByWorkspace(ctx context.Context, userID, workspacePath, name, snippetType, language string) ([]*db.CodeSnippet, error) { + // 检查是否提供了至少一个搜索条件 + if name == "" && snippetType == "" && language == "" { + // 如果没有提供任何搜索条件,返回空结果 + return []*db.CodeSnippet{}, nil + } + + // 首先根据用户ID和工作区路径找到工作区 + userUUID, err := uuid.Parse(userID) + if err != nil { + r.logger.Error("failed to parse user ID", "error", err, "userID", userID) + return nil, fmt.Errorf("invalid user ID: %w", err) + } + + // 查询工作区 + workspace, err := r.client.Workspace.Query(). + Where( + workspace.UserID(userUUID), + workspace.RootPath(workspacePath), + ). + Only(ctx) + if err != nil { + r.logger.Error("failed to find workspace", "error", err, "userID", userID, "workspacePath", workspacePath) + return nil, fmt.Errorf("workspace not found: %w", err) + } + + // 查询该工作区下的所有文件 + workspaceFiles, err := r.client.WorkspaceFile.Query(). + Where(workspacefile.WorkspaceID(workspace.ID)). + All(ctx) + if err != nil { + r.logger.Error("failed to get workspace files", "error", err, "workspaceID", workspace.ID) + return nil, fmt.Errorf("failed to get workspace files: %w", err) + } + + if len(workspaceFiles) == 0 { + return []*db.CodeSnippet{}, nil + } + + // 提取文件ID列表 + var fileIDs []uuid.UUID + for _, file := range workspaceFiles { + fileIDs = append(fileIDs, file.ID) + } + + // 构建代码片段查询 + query := r.client.CodeSnippet.Query(). + Where(codesnippet.WorkspaceFileIDIn(fileIDs...)) + + // 创建一个切片来存储所有谓词 + var predicates []predicate.CodeSnippet + + // 如果提供了名称参数,则添加名称过滤条件 + if name != "" { + predicates = append(predicates, codesnippet.Name(name)) + } + + // 如果提供了类型参数,则添加类型过滤条件 + if snippetType != "" { + predicates = append(predicates, codesnippet.SnippetType(snippetType)) + } + + // 如果提供了语言参数,则添加语言过滤条件 + if language != "" { + predicates = append(predicates, codesnippet.Language(language)) + } + + // 添加谓词到查询中(这里总是会添加,因为我们已经检查了至少有一个条件) + query = query.Where(codesnippet.And(predicates...)) + + // 执行查询并返回结果 + return query.All(ctx) +} diff --git a/backend/internal/codesnippet/usecase/codesnippet.go b/backend/internal/codesnippet/usecase/codesnippet.go index e782e63..c9496d0 100644 --- a/backend/internal/codesnippet/usecase/codesnippet.go +++ b/backend/internal/codesnippet/usecase/codesnippet.go @@ -62,21 +62,82 @@ func (u *CodeSnippetUsecase) CreateFromIndexResult(ctx context.Context, workspac // ListByWorkspaceFile 列出特定工作区文件的所有代码片段 func (u *CodeSnippetUsecase) ListByWorkspaceFile(ctx context.Context, workspaceFileID string) ([]*domain.CodeSnippet, error) { - // 实现列出特定工作区文件的所有代码片段的逻辑 - // 为简化起见,这里暂时返回空列表,实际实现需要根据需求完成 - return []*domain.CodeSnippet{}, nil + // 调用 repository 层的方法 + dbSnippets, err := u.repo.ListByWorkspaceFile(ctx, workspaceFileID) + if err != nil { + u.logger.Error("failed to list code snippets by workspace file", "error", err, "workspaceFileID", workspaceFileID) + return nil, fmt.Errorf("failed to list code snippets: %w", err) + } + + // 将数据库模型转换为领域模型 + var snippets []*domain.CodeSnippet + for _, dbSnippet := range dbSnippets { + snippet := (&domain.CodeSnippet{}).From(dbSnippet) + snippets = append(snippets, snippet) + } + + return snippets, nil } // GetByID 根据 ID 获取代码片段 func (u *CodeSnippetUsecase) GetByID(ctx context.Context, id string) (*domain.CodeSnippet, error) { - // 实现根据 ID 获取代码片段的逻辑 - // 为简化起见,这里暂时返回 nil,实际实现需要根据需求完成 - return nil, nil + // 调用 repository 层的方法 + dbSnippet, err := u.repo.GetByID(ctx, id) + if err != nil { + u.logger.Error("failed to get code snippet by ID", "error", err, "id", id) + return nil, fmt.Errorf("failed to get code snippet: %w", err) + } + + // 将数据库模型转换为领域模型 + return (&domain.CodeSnippet{}).From(dbSnippet), nil } // Delete 删除代码片段 func (u *CodeSnippetUsecase) Delete(ctx context.Context, id string) error { - // 实现删除代码片段的逻辑 - // 为简化起见,这里暂时返回 nil,实际实现需要根据需求完成 + // 调用 repository 层的方法 + err := u.repo.Delete(ctx, id) + if err != nil { + u.logger.Error("failed to delete code snippet", "error", err, "id", id) + return fmt.Errorf("failed to delete code snippet: %w", err) + } + return nil } + +// Search 根据名称、类型和语言搜索代码片段 +func (u *CodeSnippetUsecase) Search(ctx context.Context, name, snippetType, language string) ([]*domain.CodeSnippet, error) { + // 调用 repository 层的 Search 方法 + dbSnippets, err := u.repo.Search(ctx, name, snippetType, language) + if err != nil { + u.logger.Error("failed to search code snippets", "error", err) + return nil, fmt.Errorf("failed to search code snippets: %w", err) + } + + // 将数据库模型转换为领域模型 + var snippets []*domain.CodeSnippet + for _, dbSnippet := range dbSnippets { + snippet := (&domain.CodeSnippet{}).From(dbSnippet) + snippets = append(snippets, snippet) + } + + return snippets, nil +} + +// SearchByWorkspace 根据用户ID、工作区路径和搜索条件搜索代码片段 +func (u *CodeSnippetUsecase) SearchByWorkspace(ctx context.Context, userID, workspacePath, name, snippetType, language string) ([]*domain.CodeSnippet, error) { + // 调用 repository 层的 SearchByWorkspace 方法 + dbSnippets, err := u.repo.SearchByWorkspace(ctx, userID, workspacePath, name, snippetType, language) + if err != nil { + u.logger.Error("failed to search code snippets by workspace", "error", err, "userID", userID, "workspacePath", workspacePath) + return nil, fmt.Errorf("failed to search code snippets by workspace: %w", err) + } + + // 将数据库模型转换为领域模型 + var snippets []*domain.CodeSnippet + for _, dbSnippet := range dbSnippets { + snippet := (&domain.CodeSnippet{}).From(dbSnippet) + snippets = append(snippets, snippet) + } + + return snippets, nil +} diff --git a/backend/internal/provider.go b/backend/internal/provider.go index d29d7ea..fc19ae9 100644 --- a/backend/internal/provider.go +++ b/backend/internal/provider.go @@ -6,6 +6,7 @@ import ( billingv1 "github.com/chaitin/MonkeyCode/backend/internal/billing/handler/http/v1" billingrepo "github.com/chaitin/MonkeyCode/backend/internal/billing/repo" billingusecase "github.com/chaitin/MonkeyCode/backend/internal/billing/usecase" + codesnippetv1 "github.com/chaitin/MonkeyCode/backend/internal/codesnippet/handler/http/v1" codesnippetrepo "github.com/chaitin/MonkeyCode/backend/internal/codesnippet/repo" codesnippetusecase "github.com/chaitin/MonkeyCode/backend/internal/codesnippet/usecase" dashv1 "github.com/chaitin/MonkeyCode/backend/internal/dashboard/handler/v1" @@ -71,4 +72,5 @@ var Provider = wire.NewSet( reportrepo.NewReportRepo, codesnippetrepo.NewCodeSnippetRepo, codesnippetusecase.NewCodeSnippetUsecase, + codesnippetv1.NewCodeSnippetHandler, )