mirror of
https://github.com/chaitin/MonkeyCode.git
synced 2026-02-02 23:03:57 +08:00
559 lines
16 KiB
Go
559 lines
16 KiB
Go
package usecase
|
||
|
||
import (
|
||
"context"
|
||
"crypto/sha256"
|
||
"encoding/hex"
|
||
"fmt"
|
||
"log/slog"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/chaitin/MonkeyCode/backend/config"
|
||
"github.com/chaitin/MonkeyCode/backend/db"
|
||
"github.com/chaitin/MonkeyCode/backend/domain"
|
||
"github.com/chaitin/MonkeyCode/backend/pkg/cli"
|
||
)
|
||
|
||
type WorkspaceUsecase struct {
|
||
repo domain.WorkspaceRepo
|
||
config *config.Config
|
||
logger *slog.Logger
|
||
}
|
||
|
||
type WorkspaceFileUsecase struct {
|
||
repo domain.WorkspaceFileRepo
|
||
workspaceSvc domain.WorkspaceUsecase
|
||
codeSnippetSvc domain.CodeSnippetUsecase
|
||
config *config.Config
|
||
logger *slog.Logger
|
||
}
|
||
|
||
func NewWorkspaceUsecase(
|
||
repo domain.WorkspaceRepo,
|
||
config *config.Config,
|
||
logger *slog.Logger,
|
||
) domain.WorkspaceUsecase {
|
||
return &WorkspaceUsecase{
|
||
repo: repo,
|
||
config: config,
|
||
logger: logger.With("usecase", "workspace"),
|
||
}
|
||
}
|
||
|
||
func NewWorkspaceFileUsecase(
|
||
repo domain.WorkspaceFileRepo,
|
||
workspaceSvc domain.WorkspaceUsecase,
|
||
codeSnippetSvc domain.CodeSnippetUsecase,
|
||
config *config.Config,
|
||
logger *slog.Logger,
|
||
) domain.WorkspaceFileUsecase {
|
||
return &WorkspaceFileUsecase{
|
||
repo: repo,
|
||
workspaceSvc: workspaceSvc,
|
||
codeSnippetSvc: codeSnippetSvc,
|
||
config: config,
|
||
logger: logger.With("usecase", "workspace_file"),
|
||
}
|
||
}
|
||
|
||
// WorkspaceUsecase methods
|
||
|
||
func (u *WorkspaceUsecase) Create(ctx context.Context, req *domain.CreateWorkspaceReq) (*domain.Workspace, error) {
|
||
workspace, err := u.repo.Create(ctx, req)
|
||
if err != nil {
|
||
u.logger.Error("failed to create workspace", "error", err, "name", req.Name, "root_path", req.RootPath)
|
||
return nil, fmt.Errorf("failed to create workspace: %w", err)
|
||
}
|
||
|
||
u.logger.Info("workspace created", "id", workspace.ID, "name", req.Name, "root_path", req.RootPath)
|
||
return (&domain.Workspace{}).From(workspace), nil
|
||
}
|
||
|
||
func (u *WorkspaceUsecase) GetByID(ctx context.Context, id string) (*domain.Workspace, error) {
|
||
workspace, err := u.repo.GetByID(ctx, id)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get workspace: %w", err)
|
||
}
|
||
|
||
return (&domain.Workspace{}).From(workspace), nil
|
||
}
|
||
|
||
func (u *WorkspaceUsecase) GetByUserAndPath(ctx context.Context, userID, rootPath string) (*domain.Workspace, error) {
|
||
workspace, err := u.repo.GetByUserAndPath(ctx, userID, rootPath)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get workspace by user and path: %w", err)
|
||
}
|
||
|
||
return (&domain.Workspace{}).From(workspace), nil
|
||
}
|
||
|
||
func (u *WorkspaceUsecase) List(ctx context.Context, req *domain.ListWorkspaceReq) (*domain.ListWorkspaceResp, error) {
|
||
workspaces, pageInfo, err := u.repo.List(ctx, req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to list workspaces: %w", err)
|
||
}
|
||
|
||
return &domain.ListWorkspaceResp{
|
||
PageInfo: pageInfo,
|
||
Workspaces: domain.FromWorkspaces(workspaces),
|
||
}, nil
|
||
}
|
||
|
||
func (u *WorkspaceUsecase) Update(ctx context.Context, req *domain.UpdateWorkspaceReq) (*domain.Workspace, error) {
|
||
workspace, err := u.repo.Update(ctx, req.ID, func(up *db.WorkspaceUpdateOne) error {
|
||
if req.Name != nil {
|
||
up.SetName(*req.Name)
|
||
}
|
||
if req.Description != nil {
|
||
up.SetDescription(*req.Description)
|
||
}
|
||
if req.Settings != nil {
|
||
up.SetSettings(req.Settings)
|
||
}
|
||
return nil
|
||
})
|
||
if err != nil {
|
||
u.logger.Error("failed to update workspace", "error", err, "id", req.ID)
|
||
return nil, fmt.Errorf("failed to update workspace: %w", err)
|
||
}
|
||
|
||
u.logger.Info("workspace updated", "id", req.ID)
|
||
return (&domain.Workspace{}).From(workspace), nil
|
||
}
|
||
|
||
func (u *WorkspaceUsecase) Delete(ctx context.Context, id string) error {
|
||
err := u.repo.Delete(ctx, id)
|
||
if err != nil {
|
||
u.logger.Error("failed to delete workspace", "error", err, "id", id)
|
||
return fmt.Errorf("failed to delete workspace: %w", err)
|
||
}
|
||
|
||
u.logger.Info("workspace deleted", "id", id)
|
||
return nil
|
||
}
|
||
|
||
func (u *WorkspaceUsecase) EnsureWorkspace(ctx context.Context, userID, rootPath, name string) (*domain.Workspace, error) {
|
||
// 首先尝试获取已存在的工作区
|
||
workspace, err := u.repo.GetByUserAndPath(ctx, userID, rootPath)
|
||
if err == nil {
|
||
// 工作区已存在,更新最后访问时间
|
||
updated, err := u.repo.Update(ctx, workspace.ID.String(), func(up *db.WorkspaceUpdateOne) error {
|
||
up.SetLastAccessedAt(time.Now())
|
||
return nil
|
||
})
|
||
if err != nil {
|
||
u.logger.Warn("failed to update workspace last accessed time", "error", err, "id", workspace.ID)
|
||
}
|
||
return (&domain.Workspace{}).From(updated), nil
|
||
}
|
||
|
||
// 如果工作区不存在,创建新的工作区
|
||
if !db.IsNotFound(err) {
|
||
return nil, fmt.Errorf("failed to check workspace existence: %w", err)
|
||
}
|
||
|
||
// 自动生成工作区名称(如果未提供)
|
||
if name == "" {
|
||
name = u.generateWorkspaceName(rootPath)
|
||
}
|
||
|
||
createReq := &domain.CreateWorkspaceReq{
|
||
UserID: userID,
|
||
Name: name,
|
||
Description: fmt.Sprintf("Auto-created workspace for %s", rootPath),
|
||
RootPath: rootPath,
|
||
Settings: map[string]any{},
|
||
}
|
||
|
||
return u.Create(ctx, createReq)
|
||
}
|
||
|
||
func (u *WorkspaceUsecase) generateWorkspaceName(rootPath string) string {
|
||
// 从路径中提取最后一个目录名作为工作区名称
|
||
parts := strings.Split(rootPath, "/")
|
||
if len(parts) > 0 {
|
||
name := parts[len(parts)-1]
|
||
if name != "" {
|
||
return name
|
||
}
|
||
}
|
||
return "Untitled Workspace"
|
||
}
|
||
|
||
// WorkspaceFileUsecase methods
|
||
|
||
func (u *WorkspaceFileUsecase) Create(ctx context.Context, req *domain.CreateWorkspaceFileReq) (*domain.WorkspaceFile, error) {
|
||
// 验证和计算哈希
|
||
if req.Hash == "" {
|
||
req.Hash = u.calculateHash(req.Content)
|
||
} else if !u.verifyHash(req.Content, req.Hash) {
|
||
return nil, fmt.Errorf("provided hash does not match content")
|
||
}
|
||
|
||
// 计算文件大小
|
||
if req.Size == 0 {
|
||
req.Size = int64(len(req.Content))
|
||
}
|
||
|
||
// 推断编程语言
|
||
if req.Language == "" {
|
||
req.Language = u.inferLanguage(req.Path)
|
||
}
|
||
|
||
// 确保工作区存在
|
||
// 首先通过workspace ID获取workspace信息,然后使用其root path来确保workspace存在
|
||
workspace, err := u.workspaceSvc.GetByID(ctx, req.WorkspaceID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get workspace by ID: %w", err)
|
||
}
|
||
|
||
_, err = u.workspaceSvc.EnsureWorkspace(ctx, req.UserID, workspace.RootPath, "")
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to ensure workspace exists: %w", err)
|
||
}
|
||
|
||
file, err := u.repo.Create(ctx, req)
|
||
if err != nil {
|
||
u.logger.Error("failed to create workspace file", "error", err, "path", req.Path)
|
||
return nil, fmt.Errorf("failed to create file: %w", err)
|
||
}
|
||
|
||
u.logger.Info("workspace file created", "id", file.ID, "path", req.Path)
|
||
return (&domain.WorkspaceFile{}).From(file), nil
|
||
}
|
||
|
||
func (u *WorkspaceFileUsecase) Update(ctx context.Context, req *domain.UpdateWorkspaceFileReq) (*domain.WorkspaceFile, error) {
|
||
file, err := u.repo.Update(ctx, req.ID, func(up *db.WorkspaceFileUpdateOne) error {
|
||
if req.Content != nil {
|
||
up.SetContent(*req.Content)
|
||
// 更新内容时重新计算哈希和大小
|
||
if req.Hash != nil {
|
||
if !u.verifyHash(*req.Content, *req.Hash) {
|
||
return fmt.Errorf("provided hash does not match content")
|
||
}
|
||
up.SetHash(*req.Hash)
|
||
} else {
|
||
up.SetHash(u.calculateHash(*req.Content))
|
||
}
|
||
if req.Size != nil {
|
||
up.SetSize(*req.Size)
|
||
} else {
|
||
up.SetSize(int64(len(*req.Content)))
|
||
}
|
||
} else if req.Hash != nil {
|
||
up.SetHash(*req.Hash)
|
||
}
|
||
|
||
if req.Language != nil {
|
||
up.SetLanguage(*req.Language)
|
||
}
|
||
|
||
if req.Size != nil && req.Content == nil {
|
||
up.SetSize(*req.Size)
|
||
}
|
||
|
||
// AST field has been removed from the domain model
|
||
|
||
return nil
|
||
})
|
||
if err != nil {
|
||
u.logger.Error("failed to update workspace file", "error", err, "id", req.ID)
|
||
return nil, fmt.Errorf("failed to update file: %w", err)
|
||
}
|
||
|
||
u.logger.Info("workspace file updated", "id", req.ID)
|
||
return (&domain.WorkspaceFile{}).From(file), nil
|
||
}
|
||
|
||
func (u *WorkspaceFileUsecase) Delete(ctx context.Context, id string) error {
|
||
err := u.repo.Delete(ctx, id)
|
||
if err != nil {
|
||
u.logger.Error("failed to delete workspace file", "error", err, "id", id)
|
||
return fmt.Errorf("failed to delete file: %w", err)
|
||
}
|
||
|
||
u.logger.Info("workspace file deleted", "id", id)
|
||
return nil
|
||
}
|
||
|
||
func (u *WorkspaceFileUsecase) GetByID(ctx context.Context, id string) (*domain.WorkspaceFile, error) {
|
||
file, err := u.repo.GetByID(ctx, id)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get file: %w", err)
|
||
}
|
||
|
||
return (&domain.WorkspaceFile{}).From(file), nil
|
||
}
|
||
|
||
func (u *WorkspaceFileUsecase) GetAndSave(ctx context.Context, req *domain.GetAndSaveReq) error {
|
||
results, err := cli.RunCli("index", "", req.FileMetas)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
for _, res := range results {
|
||
file, err := u.repo.GetByPath(ctx, req.UserID, req.WorkspaceID, res.FilePath)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 判断是否已经保存了相同的CodeSnippet
|
||
// 目前不需要更新已有的CodeSnippet
|
||
if _, err := u.codeSnippetSvc.GetByID(ctx, ""); err != nil {
|
||
u.logger.Info("code snippet already exists, updating content")
|
||
} else {
|
||
// 创建新的CodeSnippet
|
||
_, err = u.codeSnippetSvc.CreateFromIndexResult(ctx, file.ID.String(), &res)
|
||
if err != nil {
|
||
u.logger.Error("failed to create code snippet from index result", "error", err, "filePath", res.FilePath)
|
||
// 继续处理其他结果,不因单个错误而中断整个流程
|
||
}
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (u *WorkspaceFileUsecase) GetByPath(ctx context.Context, userID, workspaceID, path string) (*domain.WorkspaceFile, error) {
|
||
file, err := u.repo.GetByPath(ctx, userID, workspaceID, path)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get file by path: %w", err)
|
||
}
|
||
|
||
return (&domain.WorkspaceFile{}).From(file), nil
|
||
}
|
||
|
||
func (u *WorkspaceFileUsecase) List(ctx context.Context, req *domain.ListWorkspaceFileReq) (*domain.ListWorkspaceFileResp, error) {
|
||
files, pageInfo, err := u.repo.List(ctx, req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to list files: %w", err)
|
||
}
|
||
|
||
return &domain.ListWorkspaceFileResp{
|
||
PageInfo: pageInfo,
|
||
Files: domain.FromWorkspaceFiles(files),
|
||
}, nil
|
||
}
|
||
|
||
func (u *WorkspaceFileUsecase) BatchCreate(ctx context.Context, req *domain.BatchCreateWorkspaceFileReq) ([]*domain.WorkspaceFile, error) {
|
||
// 确保工作区存在
|
||
// 首先通过workspace ID获取workspace信息,然后使用其root path来确保workspace存在
|
||
workspace, err := u.workspaceSvc.GetByID(ctx, req.WorkspaceID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get workspace by ID: %w", err)
|
||
}
|
||
|
||
_, err = u.workspaceSvc.EnsureWorkspace(ctx, req.UserID, workspace.RootPath, "")
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to ensure workspace exists: %w", err)
|
||
}
|
||
|
||
// 验证和预处理文件
|
||
for _, file := range req.Files {
|
||
if file.Hash == "" {
|
||
file.Hash = u.calculateHash(file.Content)
|
||
} else if !u.verifyHash(file.Content, file.Hash) {
|
||
return nil, fmt.Errorf("hash mismatch for file %s", file.Path)
|
||
}
|
||
|
||
if file.Size == 0 {
|
||
file.Size = int64(len(file.Content))
|
||
}
|
||
|
||
if file.Language == "" {
|
||
file.Language = u.inferLanguage(file.Path)
|
||
}
|
||
|
||
// 设置用户ID和工作区ID
|
||
file.UserID = req.UserID
|
||
file.WorkspaceID = req.WorkspaceID
|
||
}
|
||
|
||
files, err := u.repo.BatchCreate(ctx, req.Files)
|
||
if err != nil {
|
||
u.logger.Error("failed to batch create workspace files", "error", err, "count", len(req.Files))
|
||
return nil, fmt.Errorf("failed to batch create files: %w", err)
|
||
}
|
||
|
||
u.logger.Info("workspace files batch created", "count", len(files))
|
||
return domain.FromWorkspaceFiles(files), nil
|
||
}
|
||
|
||
func (u *WorkspaceFileUsecase) BatchUpdate(ctx context.Context, req *domain.BatchUpdateWorkspaceFileReq) ([]*domain.WorkspaceFile, error) {
|
||
var results []*domain.WorkspaceFile
|
||
|
||
for _, updateReq := range req.Files {
|
||
file, err := u.Update(ctx, updateReq)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to update file %s: %w", updateReq.ID, err)
|
||
}
|
||
results = append(results, file)
|
||
}
|
||
|
||
u.logger.Info("workspace files batch updated", "count", len(results))
|
||
return results, nil
|
||
}
|
||
|
||
func (u *WorkspaceFileUsecase) Sync(ctx context.Context, req *domain.SyncWorkspaceFileReq) (*domain.SyncWorkspaceFileResp, error) {
|
||
// 确保工作区存在
|
||
// 首先通过workspace ID获取workspace信息,然后使用其root path来确保workspace存在
|
||
workspace, err := u.workspaceSvc.GetByID(ctx, req.WorkspaceID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get workspace by ID: %w", err)
|
||
}
|
||
|
||
_, err = u.workspaceSvc.EnsureWorkspace(ctx, req.UserID, workspace.RootPath, "")
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to ensure workspace exists: %w", err)
|
||
}
|
||
|
||
// 获取要同步的文件哈希列表
|
||
var hashes []string
|
||
fileMap := make(map[string]*domain.CreateWorkspaceFileReq)
|
||
for _, file := range req.Files {
|
||
if file.Hash == "" {
|
||
file.Hash = u.calculateHash(file.Content)
|
||
}
|
||
hashes = append(hashes, file.Hash)
|
||
fileMap[file.Hash] = file
|
||
}
|
||
|
||
// 查找工作区中已存在的文件
|
||
existing, err := u.repo.GetByHashes(ctx, req.WorkspaceID, hashes)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get existing files: %w", err)
|
||
}
|
||
|
||
var toCreate []*domain.CreateWorkspaceFileReq
|
||
var toUpdate []*domain.UpdateWorkspaceFileReq
|
||
|
||
// 分类处理:创建新文件或更新现有文件
|
||
for hash, file := range fileMap {
|
||
file.UserID = req.UserID
|
||
file.WorkspaceID = req.WorkspaceID
|
||
|
||
if existingFile, exists := existing[hash]; exists {
|
||
// 文件已存在,检查是否需要更新
|
||
if existingFile.Path != file.Path || existingFile.Language != file.Language {
|
||
updateReq := &domain.UpdateWorkspaceFileReq{
|
||
ID: existingFile.ID.String(),
|
||
}
|
||
if existingFile.Path != file.Path {
|
||
updateReq.Content = &file.Content
|
||
}
|
||
if existingFile.Language != file.Language {
|
||
updateReq.Language = &file.Language
|
||
}
|
||
toUpdate = append(toUpdate, updateReq)
|
||
}
|
||
} else {
|
||
// 新文件,需要创建
|
||
if file.Language == "" {
|
||
file.Language = u.inferLanguage(file.Path)
|
||
}
|
||
if file.Size == 0 {
|
||
file.Size = int64(len(file.Content))
|
||
}
|
||
toCreate = append(toCreate, file)
|
||
}
|
||
}
|
||
|
||
resp := &domain.SyncWorkspaceFileResp{}
|
||
|
||
// 批量创建新文件
|
||
if len(toCreate) > 0 {
|
||
created, err := u.repo.BatchCreate(ctx, toCreate)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to create new files: %w", err)
|
||
}
|
||
resp.Created = domain.FromWorkspaceFiles(created)
|
||
}
|
||
|
||
// 批量更新现有文件
|
||
if len(toUpdate) > 0 {
|
||
updated, err := u.BatchUpdate(ctx, &domain.BatchUpdateWorkspaceFileReq{Files: toUpdate})
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to update existing files: %w", err)
|
||
}
|
||
resp.Updated = updated
|
||
}
|
||
|
||
resp.Total = len(resp.Created) + len(resp.Updated)
|
||
|
||
u.logger.Info("workspace files synced",
|
||
"created", len(resp.Created),
|
||
"updated", len(resp.Updated),
|
||
"total", resp.Total)
|
||
|
||
return resp, nil
|
||
}
|
||
|
||
// 辅助方法
|
||
|
||
func (u *WorkspaceFileUsecase) calculateHash(content string) string {
|
||
hash := sha256.Sum256([]byte(content))
|
||
return hex.EncodeToString(hash[:])
|
||
}
|
||
|
||
func (u *WorkspaceFileUsecase) verifyHash(content, expectedHash string) bool {
|
||
actualHash := u.calculateHash(content)
|
||
return actualHash == expectedHash
|
||
}
|
||
|
||
func (u *WorkspaceFileUsecase) inferLanguage(path string) string {
|
||
// 简单的文件扩展名到语言的映射
|
||
if idx := strings.LastIndex(path, "."); idx != -1 {
|
||
ext := strings.ToLower(path[idx+1:])
|
||
switch ext {
|
||
case "go":
|
||
return "go"
|
||
case "js", "mjs":
|
||
return "javascript"
|
||
case "ts":
|
||
return "typescript"
|
||
case "py":
|
||
return "python"
|
||
case "java":
|
||
return "java"
|
||
case "cpp", "cc", "cxx":
|
||
return "cpp"
|
||
case "c":
|
||
return "c"
|
||
case "rs":
|
||
return "rust"
|
||
case "php":
|
||
return "php"
|
||
case "rb":
|
||
return "ruby"
|
||
case "swift":
|
||
return "swift"
|
||
case "kt":
|
||
return "kotlin"
|
||
case "cs":
|
||
return "csharp"
|
||
case "sh", "bash":
|
||
return "shell"
|
||
case "sql":
|
||
return "sql"
|
||
case "yaml", "yml":
|
||
return "yaml"
|
||
case "json":
|
||
return "json"
|
||
case "xml":
|
||
return "xml"
|
||
case "html":
|
||
return "html"
|
||
case "css":
|
||
return "css"
|
||
case "md":
|
||
return "markdown"
|
||
case "toml":
|
||
return "toml"
|
||
case "ini":
|
||
return "ini"
|
||
default:
|
||
return "text"
|
||
}
|
||
}
|
||
return "text"
|
||
}
|