Files
MonkeyCode/backend/internal/workspace/usecase/workspace.go

648 lines
19 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package usecase
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"log/slog"
"math"
"strings"
"sync"
"time"
"github.com/chaitin/MonkeyCode/backend/config"
"github.com/chaitin/MonkeyCode/backend/db"
"github.com/chaitin/MonkeyCode/backend/domain"
"github.com/chaitin/MonkeyCode/backend/pkg/cli"
)
type WorkspaceUsecase struct {
repo domain.WorkspaceRepo
config *config.Config
logger *slog.Logger
ensureLocks sync.Map // map[string]*sync.Mutex
}
type WorkspaceFileUsecase struct {
repo domain.WorkspaceFileRepo
workspaceSvc domain.WorkspaceUsecase
codeSnippetSvc domain.CodeSnippetUsecase
config *config.Config
logger *slog.Logger
}
func NewWorkspaceUsecase(
repo domain.WorkspaceRepo,
config *config.Config,
logger *slog.Logger,
) domain.WorkspaceUsecase {
return &WorkspaceUsecase{
repo: repo,
config: config,
logger: logger.With("usecase", "workspace"),
}
}
func NewWorkspaceFileUsecase(
repo domain.WorkspaceFileRepo,
workspaceSvc domain.WorkspaceUsecase,
codeSnippetSvc domain.CodeSnippetUsecase,
config *config.Config,
logger *slog.Logger,
) domain.WorkspaceFileUsecase {
return &WorkspaceFileUsecase{
repo: repo,
workspaceSvc: workspaceSvc,
codeSnippetSvc: codeSnippetSvc,
config: config,
logger: logger.With("usecase", "workspace_file"),
}
}
// WorkspaceUsecase methods
func (u *WorkspaceUsecase) Create(ctx context.Context, req *domain.CreateWorkspaceReq) (*domain.Workspace, error) {
workspace, err := u.repo.Create(ctx, req)
if err != nil {
u.logger.Error("failed to create workspace", "error", err, "name", req.Name, "root_path", req.RootPath)
return nil, fmt.Errorf("failed to create workspace: %w", err)
}
u.logger.Info("workspace created", "id", workspace.ID, "name", req.Name, "root_path", req.RootPath)
return (&domain.Workspace{}).From(workspace), nil
}
func (u *WorkspaceUsecase) GetByID(ctx context.Context, id string) (*domain.Workspace, error) {
workspace, err := u.repo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("failed to get workspace: %w", err)
}
return (&domain.Workspace{}).From(workspace), nil
}
func (u *WorkspaceUsecase) GetByUserAndPath(ctx context.Context, userID, rootPath string) (*domain.Workspace, error) {
workspace, err := u.repo.GetByUserAndPath(ctx, userID, rootPath)
if err != nil {
return nil, fmt.Errorf("failed to get workspace by user and path: %w", err)
}
return (&domain.Workspace{}).From(workspace), nil
}
func (u *WorkspaceUsecase) List(ctx context.Context, req *domain.ListWorkspaceReq) (*domain.ListWorkspaceResp, error) {
workspaces, pageInfo, err := u.repo.List(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to list workspaces: %w", err)
}
return &domain.ListWorkspaceResp{
PageInfo: pageInfo,
Workspaces: domain.FromWorkspaces(workspaces),
}, nil
}
func (u *WorkspaceUsecase) Update(ctx context.Context, req *domain.UpdateWorkspaceReq) (*domain.Workspace, error) {
workspace, err := u.repo.Update(ctx, req.ID, func(up *db.WorkspaceUpdateOne) error {
if req.Name != nil {
up.SetName(*req.Name)
}
if req.Description != nil {
up.SetDescription(*req.Description)
}
if req.Settings != nil {
up.SetSettings(req.Settings)
}
return nil
})
if err != nil {
u.logger.Error("failed to update workspace", "error", err, "id", req.ID)
return nil, fmt.Errorf("failed to update workspace: %w", err)
}
u.logger.Info("workspace updated", "id", req.ID)
return (&domain.Workspace{}).From(workspace), nil
}
func (u *WorkspaceUsecase) Delete(ctx context.Context, id string) error {
err := u.repo.Delete(ctx, id)
if err != nil {
u.logger.Error("failed to delete workspace", "error", err, "id", id)
return fmt.Errorf("failed to delete workspace: %w", err)
}
u.logger.Info("workspace deleted", "id", id)
return nil
}
func (u *WorkspaceUsecase) EnsureWorkspace(ctx context.Context, userID, rootPath, name string) (*domain.Workspace, error) {
// 创建锁的唯一键
lockKey := fmt.Sprintf("%s:%s", userID, rootPath)
// 获取或创建针对这个 userID+rootPath 的锁
lockValue, _ := u.ensureLocks.LoadOrStore(lockKey, &sync.Mutex{})
lock := lockValue.(*sync.Mutex)
// 加锁,防止并发创建
lock.Lock()
defer lock.Unlock()
// 自动生成工作区名称(如果未提供)
if name == "" {
name = u.generateWorkspaceName(rootPath)
}
// 首先尝试获取已存在的工作区
workspace, err := u.repo.GetByUserAndPath(ctx, userID, rootPath)
if err == nil {
// 工作区已存在,更新最后访问时间
updated, err := u.repo.Update(ctx, workspace.ID.String(), func(up *db.WorkspaceUpdateOne) error {
up.SetLastAccessedAt(time.Now())
return nil
})
if err != nil {
u.logger.Warn("failed to update workspace last accessed time", "error", err, "id", workspace.ID)
// 即使更新访问时间失败,也返回工作区
return (&domain.Workspace{}).From(workspace), nil
}
return (&domain.Workspace{}).From(updated), nil
}
// 如果工作区不存在,创建新的工作区
if !db.IsNotFound(err) {
return nil, fmt.Errorf("failed to check workspace existence: %w", err)
}
// 使用改进的重试机制来处理并发创建的情况
maxRetries := 5
for i := range maxRetries {
createReq := &domain.CreateWorkspaceReq{
UserID: userID,
Name: name,
Description: fmt.Sprintf("Auto-created workspace for %s", rootPath),
RootPath: rootPath,
Settings: map[string]any{},
}
workspace, err := u.Create(ctx, createReq)
if err == nil {
u.logger.Info("workspace created successfully", "userID", userID, "rootPath", rootPath, "retry", i)
return workspace, nil
}
// 如果是唯一约束错误,说明工作区已经被其他请求创建了
if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
u.logger.Debug("workspace creation conflict, retrying", "userID", userID, "rootPath", rootPath, "retry", i, "error", err)
// 使用指数退避等待
waitTime := time.Duration(math.Pow(2, float64(i))) * 25 * time.Millisecond
if waitTime > 500*time.Millisecond {
waitTime = 500 * time.Millisecond
}
time.Sleep(waitTime)
// 尝试获取已创建的工作区
existing, err := u.repo.GetByUserAndPath(ctx, userID, rootPath)
if err == nil {
u.logger.Info("found existing workspace after conflict", "userID", userID, "rootPath", rootPath, "retry", i)
// 更新最后访问时间
updated, err := u.repo.Update(ctx, existing.ID.String(), func(up *db.WorkspaceUpdateOne) error {
up.SetLastAccessedAt(time.Now())
return nil
})
if err != nil {
u.logger.Warn("failed to update workspace last accessed time", "error", err, "id", existing.ID)
// 即使更新访问时间失败,也返回工作区
return (&domain.Workspace{}).From(existing), nil
}
return (&domain.Workspace{}).From(updated), nil
} else {
u.logger.Warn("failed to get workspace after conflict", "userID", userID, "rootPath", rootPath, "retry", i, "error", err)
}
// 如果是最后一次重试,返回错误
if i == maxRetries-1 {
u.logger.Error("failed to resolve workspace creation conflict after all retries", "userID", userID, "rootPath", rootPath, "maxRetries", maxRetries)
return nil, fmt.Errorf("workspace creation conflict persists after %d retries: %w", maxRetries, err)
}
continue
}
// 如果不是唯一约束错误,直接返回错误
u.logger.Error("workspace creation failed with non-conflict error", "userID", userID, "rootPath", rootPath, "retry", i, "error", err)
return nil, fmt.Errorf("failed to create workspace: %w", err)
}
return nil, fmt.Errorf("failed to create workspace after %d retries", maxRetries)
}
func (u *WorkspaceUsecase) generateWorkspaceName(rootPath string) string {
// 从路径中提取最后一个目录名作为工作区名称
parts := strings.Split(rootPath, "/")
if len(parts) > 0 {
name := parts[len(parts)-1]
if name != "" {
return name
}
}
return "Untitled Workspace"
}
// WorkspaceFileUsecase methods
func (u *WorkspaceFileUsecase) Create(ctx context.Context, req *domain.CreateWorkspaceFileReq) (*domain.WorkspaceFile, error) {
// 验证和计算哈希
if req.Hash == "" {
req.Hash = u.calculateHash(req.Content)
} else if !u.verifyHash(req.Content, req.Hash) {
return nil, fmt.Errorf("provided hash does not match content")
}
// 计算文件大小
if req.Size == 0 {
req.Size = int64(len(req.Content))
}
// 推断编程语言
if req.Language == "" {
req.Language = u.inferLanguage(req.Path)
}
// 确保工作区存在
// 首先通过workspace ID获取workspace信息然后使用其root path来确保workspace存在
workspace, err := u.workspaceSvc.GetByID(ctx, req.WorkspaceID)
if err != nil {
return nil, fmt.Errorf("failed to get workspace by ID: %w", err)
}
_, err = u.workspaceSvc.EnsureWorkspace(ctx, req.UserID, workspace.RootPath, "")
if err != nil {
return nil, fmt.Errorf("failed to ensure workspace exists: %w", err)
}
file, err := u.repo.Create(ctx, req)
if err != nil {
u.logger.Error("failed to create workspace file", "error", err, "path", req.Path)
return nil, fmt.Errorf("failed to create file: %w", err)
}
u.logger.Info("workspace file created", "id", file.ID, "path", req.Path)
return (&domain.WorkspaceFile{}).From(file), nil
}
func (u *WorkspaceFileUsecase) Update(ctx context.Context, req *domain.UpdateWorkspaceFileReq) (*domain.WorkspaceFile, error) {
file, err := u.repo.Update(ctx, req.ID, func(up *db.WorkspaceFileUpdateOne) error {
if req.Content != nil {
up.SetContent(*req.Content)
// 更新内容时重新计算哈希和大小
if req.Hash != nil {
if !u.verifyHash(*req.Content, *req.Hash) {
return fmt.Errorf("provided hash does not match content")
}
up.SetHash(*req.Hash)
} else {
up.SetHash(u.calculateHash(*req.Content))
}
if req.Size != nil {
up.SetSize(*req.Size)
} else {
up.SetSize(int64(len(*req.Content)))
}
} else if req.Hash != nil {
up.SetHash(*req.Hash)
}
if req.Language != nil {
up.SetLanguage(*req.Language)
}
if req.Size != nil && req.Content == nil {
up.SetSize(*req.Size)
}
// AST field has been removed from the domain model
return nil
})
if err != nil {
u.logger.Error("failed to update workspace file", "error", err, "id", req.ID)
return nil, fmt.Errorf("failed to update file: %w", err)
}
u.logger.Info("workspace file updated", "id", req.ID)
return (&domain.WorkspaceFile{}).From(file), nil
}
func (u *WorkspaceFileUsecase) Delete(ctx context.Context, id string) error {
err := u.repo.Delete(ctx, id)
if err != nil {
u.logger.Error("failed to delete workspace file", "error", err, "id", id)
return fmt.Errorf("failed to delete file: %w", err)
}
u.logger.Info("workspace file deleted", "id", id)
return nil
}
func (u *WorkspaceFileUsecase) GetByID(ctx context.Context, id string) (*domain.WorkspaceFile, error) {
file, err := u.repo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("failed to get file: %w", err)
}
return (&domain.WorkspaceFile{}).From(file), nil
}
func (u *WorkspaceFileUsecase) GetAndSave(ctx context.Context, req *domain.GetAndSaveReq) error {
// 获取workspace信息以获取RootPath
workspace, err := u.workspaceSvc.GetByID(ctx, req.WorkspaceID)
if err != nil {
u.logger.Error("failed to get workspace by ID", "error", err, "workspaceID", req.WorkspaceID)
return fmt.Errorf("failed to get workspace: %w", err)
}
results, err := cli.RunCli("index", "", req.FileMetas)
if err != nil {
return err
}
for _, res := range results {
file, err := u.repo.GetByPath(ctx, req.UserID, req.WorkspaceID, res.FilePath)
if err != nil {
return err
}
// 先删除与该文件关联的所有旧代码片段
existingSnippets, err := u.codeSnippetSvc.ListByWorkspaceFile(ctx, file.ID.String())
if err != nil {
u.logger.Error("failed to list existing code snippets", "error", err, "fileID", file.ID)
// 继续处理,不因错误而中断整个流程
} else {
for _, snippet := range existingSnippets {
// 检查snippet ID是否为空
if snippet.ID == "" {
u.logger.Warn("skipping deletion of code snippet with empty ID", "fileID", file.ID)
continue
}
err := u.codeSnippetSvc.Delete(ctx, snippet.ID)
if err != nil {
u.logger.Error("failed to delete existing code snippet", "error", err, "snippetID", snippet.ID)
// 继续处理其他片段,不因单个错误而中断整个流程
}
}
}
// 创建新的CodeSnippet传递workspacePath
_, err = u.codeSnippetSvc.CreateFromIndexResult(ctx, file.ID.String(), &res, workspace.RootPath)
if err != nil {
u.logger.Error("failed to create code snippet from index result", "error", err, "filePath", res.FilePath)
// 继续处理其他结果,不因单个错误而中断整个流程
}
}
return nil
}
func (u *WorkspaceFileUsecase) GetByPath(ctx context.Context, userID, workspaceID, path string) (*domain.WorkspaceFile, error) {
file, err := u.repo.GetByPath(ctx, userID, workspaceID, path)
if err != nil {
return nil, fmt.Errorf("failed to get file by path: %w", err)
}
return (&domain.WorkspaceFile{}).From(file), nil
}
func (u *WorkspaceFileUsecase) List(ctx context.Context, req *domain.ListWorkspaceFileReq) (*domain.ListWorkspaceFileResp, error) {
files, pageInfo, err := u.repo.List(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to list files: %w", err)
}
return &domain.ListWorkspaceFileResp{
PageInfo: pageInfo,
Files: domain.FromWorkspaceFiles(files),
}, nil
}
func (u *WorkspaceFileUsecase) BatchCreate(ctx context.Context, req *domain.BatchCreateWorkspaceFileReq) ([]*domain.WorkspaceFile, error) {
// 确保工作区存在
// 首先通过workspace ID获取workspace信息然后使用其root path来确保workspace存在
workspace, err := u.workspaceSvc.GetByID(ctx, req.WorkspaceID)
if err != nil {
return nil, fmt.Errorf("failed to get workspace by ID: %w", err)
}
_, err = u.workspaceSvc.EnsureWorkspace(ctx, req.UserID, workspace.RootPath, "")
if err != nil {
return nil, fmt.Errorf("failed to ensure workspace exists: %w", err)
}
// 验证和预处理文件
for _, file := range req.Files {
if file.Hash == "" {
file.Hash = u.calculateHash(file.Content)
} else if !u.verifyHash(file.Content, file.Hash) {
return nil, fmt.Errorf("hash mismatch for file %s", file.Path)
}
if file.Size == 0 {
file.Size = int64(len(file.Content))
}
if file.Language == "" {
file.Language = u.inferLanguage(file.Path)
}
// 设置用户ID和工作区ID
file.UserID = req.UserID
file.WorkspaceID = req.WorkspaceID
}
files, err := u.repo.BatchCreate(ctx, req.Files)
if err != nil {
u.logger.Error("failed to batch create workspace files", "error", err, "count", len(req.Files))
return nil, fmt.Errorf("failed to batch create files: %w", err)
}
u.logger.Info("workspace files batch created", "count", len(files))
return domain.FromWorkspaceFiles(files), nil
}
func (u *WorkspaceFileUsecase) BatchUpdate(ctx context.Context, req *domain.BatchUpdateWorkspaceFileReq) ([]*domain.WorkspaceFile, error) {
var results []*domain.WorkspaceFile
for _, updateReq := range req.Files {
file, err := u.Update(ctx, updateReq)
if err != nil {
return nil, fmt.Errorf("failed to update file %s: %w", updateReq.ID, err)
}
results = append(results, file)
}
u.logger.Info("workspace files batch updated", "count", len(results))
return results, nil
}
func (u *WorkspaceFileUsecase) Sync(ctx context.Context, req *domain.SyncWorkspaceFileReq) (*domain.SyncWorkspaceFileResp, error) {
// 确保工作区存在
// 首先通过workspace ID获取workspace信息然后使用其root path来确保workspace存在
workspace, err := u.workspaceSvc.GetByID(ctx, req.WorkspaceID)
if err != nil {
return nil, fmt.Errorf("failed to get workspace by ID: %w", err)
}
_, err = u.workspaceSvc.EnsureWorkspace(ctx, req.UserID, workspace.RootPath, "")
if err != nil {
return nil, fmt.Errorf("failed to ensure workspace exists: %w", err)
}
// 获取要同步的文件哈希列表
var hashes []string
fileMap := make(map[string]*domain.CreateWorkspaceFileReq)
for _, file := range req.Files {
if file.Hash == "" {
file.Hash = u.calculateHash(file.Content)
}
hashes = append(hashes, file.Hash)
fileMap[file.Hash] = file
}
// 查找工作区中已存在的文件
existing, err := u.repo.GetByHashes(ctx, req.WorkspaceID, hashes)
if err != nil {
return nil, fmt.Errorf("failed to get existing files: %w", err)
}
var toCreate []*domain.CreateWorkspaceFileReq
var toUpdate []*domain.UpdateWorkspaceFileReq
// 分类处理:创建新文件或更新现有文件
for hash, file := range fileMap {
file.UserID = req.UserID
file.WorkspaceID = req.WorkspaceID
if existingFile, exists := existing[hash]; exists {
// 文件已存在,检查是否需要更新
if existingFile.Path != file.Path || existingFile.Language != file.Language {
updateReq := &domain.UpdateWorkspaceFileReq{
ID: existingFile.ID.String(),
}
if existingFile.Path != file.Path {
updateReq.Content = &file.Content
}
if existingFile.Language != file.Language {
updateReq.Language = &file.Language
}
toUpdate = append(toUpdate, updateReq)
}
} else {
// 新文件,需要创建
if file.Language == "" {
file.Language = u.inferLanguage(file.Path)
}
if file.Size == 0 {
file.Size = int64(len(file.Content))
}
toCreate = append(toCreate, file)
}
}
resp := &domain.SyncWorkspaceFileResp{}
// 批量创建新文件
if len(toCreate) > 0 {
created, err := u.repo.BatchCreate(ctx, toCreate)
if err != nil {
return nil, fmt.Errorf("failed to create new files: %w", err)
}
resp.Created = domain.FromWorkspaceFiles(created)
}
// 批量更新现有文件
if len(toUpdate) > 0 {
updated, err := u.BatchUpdate(ctx, &domain.BatchUpdateWorkspaceFileReq{Files: toUpdate})
if err != nil {
return nil, fmt.Errorf("failed to update existing files: %w", err)
}
resp.Updated = updated
}
resp.Total = len(resp.Created) + len(resp.Updated)
u.logger.Info("workspace files synced",
"created", len(resp.Created),
"updated", len(resp.Updated),
"total", resp.Total)
return resp, nil
}
// 辅助方法
func (u *WorkspaceFileUsecase) calculateHash(content string) string {
hash := sha256.Sum256([]byte(content))
return hex.EncodeToString(hash[:])
}
func (u *WorkspaceFileUsecase) verifyHash(content, expectedHash string) bool {
actualHash := u.calculateHash(content)
return actualHash == expectedHash
}
func (u *WorkspaceFileUsecase) inferLanguage(path string) string {
// 简单的文件扩展名到语言的映射
if idx := strings.LastIndex(path, "."); idx != -1 {
ext := strings.ToLower(path[idx+1:])
switch ext {
case "go":
return "go"
case "js", "mjs":
return "javascript"
case "ts":
return "typescript"
case "py":
return "python"
case "java":
return "java"
case "cpp", "cc", "cxx":
return "cpp"
case "c":
return "c"
case "rs":
return "rust"
case "php":
return "php"
case "rb":
return "ruby"
case "swift":
return "swift"
case "kt":
return "kotlin"
case "cs":
return "csharp"
case "sh", "bash":
return "shell"
case "sql":
return "sql"
case "yaml", "yml":
return "yaml"
case "json":
return "json"
case "xml":
return "xml"
case "html":
return "html"
case "css":
return "css"
case "md":
return "markdown"
case "toml":
return "toml"
case "ini":
return "ini"
default:
return "text"
}
}
return "text"
}