feat: Enhance workspace management with new methods and ensure workspace existence

This commit is contained in:
Haoxin Li
2025-07-25 19:02:52 +08:00
parent 3dab9bf7dd
commit 0d8673a54d
4 changed files with 421 additions and 30 deletions

View File

@@ -11,19 +11,20 @@ import (
"github.com/chaitin/MonkeyCode/backend/config"
"github.com/chaitin/MonkeyCode/backend/db"
"github.com/chaitin/MonkeyCode/backend/domain"
"github.com/chaitin/MonkeyCode/backend/pkg/cli"
socketio "github.com/doquangtan/socket.io/v4"
)
type FileUpdateData struct {
ID string `json:"id"`
FilePath string `json:"filePath"`
Hash string `json:"hash"`
Event string `json:"event"`
Content string `json:"content,omitempty"`
PreviousHash string `json:"previousHash,omitempty"`
Timestamp int64 `json:"timestamp"`
ApiKey string `json:"apiKey,omitempty"`
WorkspaceID string `json:"workspaceId,omitempty"`
ID string `json:"id"`
FilePath string `json:"filePath"`
Hash string `json:"hash"`
Event string `json:"event"`
Content string `json:"content,omitempty"`
PreviousHash string `json:"previousHash,omitempty"`
Timestamp int64 `json:"timestamp"`
ApiKey string `json:"apiKey,omitempty"`
WorkspacePath string `json:"workspacePath,omitempty"`
}
type AckResponse struct {
@@ -48,12 +49,15 @@ type SocketHandler struct {
config *config.Config
logger *slog.Logger
workspaceService domain.WorkspaceFileUsecase
workspaceUsecase domain.WorkspaceUsecase
userService domain.UserUsecase
io *socketio.Io
mu sync.Mutex
workspaceCache map[string]*domain.Workspace
cacheMutex sync.RWMutex
}
func NewSocketHandler(config *config.Config, logger *slog.Logger, workspaceService domain.WorkspaceFileUsecase, userService domain.UserUsecase) (*SocketHandler, error) {
func NewSocketHandler(config *config.Config, logger *slog.Logger, workspaceService domain.WorkspaceFileUsecase, workspaceUsecase domain.WorkspaceUsecase, userService domain.UserUsecase) (*SocketHandler, error) {
// 创建Socket.IO服务器
io := socketio.New()
@@ -61,9 +65,12 @@ func NewSocketHandler(config *config.Config, logger *slog.Logger, workspaceServi
config: config,
logger: logger,
workspaceService: workspaceService,
workspaceUsecase: workspaceUsecase,
userService: userService,
io: io,
mu: sync.Mutex{}, // 初始化互斥锁
workspaceCache: make(map[string]*domain.Workspace),
cacheMutex: sync.RWMutex{},
}
// 设置事件处理器
@@ -263,8 +270,11 @@ func (h *SocketHandler) handleFileUpdateFromObject(socket *socketio.Socket, data
if apiKey, ok := dataMap["apiKey"].(string); ok {
updateData.ApiKey = apiKey
}
if workspaceID, ok := dataMap["workspaceId"].(string); ok {
updateData.WorkspaceID = workspaceID
if workspacePath, ok := dataMap["workspacePath"].(string); ok {
updateData.WorkspacePath = workspacePath
h.logger.Debug("Extracted workspacePath from dataMap", "workspacePath", workspacePath)
} else {
h.logger.Debug("Failed to extract workspacePath from dataMap", "workspacePathType", fmt.Sprintf("%T", dataMap["workspacePath"]), "workspacePathValue", dataMap["workspacePath"])
}
h.logger.Info("Processing file update",
@@ -272,7 +282,7 @@ func (h *SocketHandler) handleFileUpdateFromObject(socket *socketio.Socket, data
"event", updateData.Event,
"file", updateData.FilePath,
"apiKey", updateData.ApiKey,
"workspaceId", updateData.WorkspaceID)
"workspacePath", updateData.WorkspacePath)
// 立即返回确认收到
immediateAck := AckResponse{
@@ -304,9 +314,21 @@ func (h *SocketHandler) processFileUpdateAsync(socket *socketio.Socket, updateDa
userID := user.ID.String()
// 确保workspace存在
workspaceID, err := h.ensureWorkspace(ctx, userID, updateData.WorkspacePath, updateData.FilePath)
if err != nil {
finalStatus = "error"
message = fmt.Sprintf("Failed to ensure workspace: %v", err)
h.logger.Error("Failed to ensure workspace", "error", err)
h.sendFinalResult(socket, updateData, finalStatus, message)
return
}
h.logger.Debug("Workspace ID obtained", "workspaceID", workspaceID, "filePath", updateData.FilePath)
switch updateData.Event {
case "initial_scan", "added":
existingFile, err := h.workspaceService.GetByPath(ctx, userID, updateData.WorkspaceID, updateData.FilePath)
existingFile, err := h.workspaceService.GetByPath(ctx, userID, workspaceID, updateData.FilePath)
if err != nil {
// "Not Found",文件不存在,执行创建逻辑
@@ -316,7 +338,7 @@ func (h *SocketHandler) processFileUpdateAsync(socket *socketio.Socket, updateDa
Content: updateData.Content,
Hash: updateData.Hash,
UserID: userID,
WorkspaceID: updateData.WorkspaceID,
WorkspaceID: workspaceID,
}
_, createErr := h.workspaceService.Create(ctx, createReq)
if createErr != nil {
@@ -361,7 +383,7 @@ func (h *SocketHandler) processFileUpdateAsync(socket *socketio.Socket, updateDa
case "modified":
// First, get the file by path to find its ID
file, err := h.workspaceService.GetByPath(ctx, userID, updateData.WorkspaceID, updateData.FilePath)
file, err := h.workspaceService.GetByPath(ctx, userID, workspaceID, updateData.FilePath)
if err != nil {
finalStatus = "error"
message = fmt.Sprintf("Failed to find file for update: %v", err)
@@ -387,7 +409,7 @@ func (h *SocketHandler) processFileUpdateAsync(socket *socketio.Socket, updateDa
case "deleted":
// First, get the file by path to find its ID
file, err := h.workspaceService.GetByPath(ctx, userID, updateData.WorkspaceID, updateData.FilePath)
file, err := h.workspaceService.GetByPath(ctx, userID, workspaceID, updateData.FilePath)
if err != nil {
finalStatus = "error"
message = fmt.Sprintf("Failed to find file for deletion: %v", err)
@@ -415,6 +437,27 @@ func (h *SocketHandler) processFileUpdateAsync(socket *socketio.Socket, updateDa
h.sendFinalResult(socket, updateData, finalStatus, message)
}
// ensureWorkspace ensures that a workspace exists for the given workspacePath
func (h *SocketHandler) ensureWorkspace(ctx context.Context, userID, workspacePath, filePath string) (string, error) {
h.logger.Debug("ensureWorkspace called", "userID", userID, "workspacePath", workspacePath, "filePath", filePath, "workspacePathLength", len(workspacePath))
if workspacePath != "" {
h.logger.Debug("Ensuring workspace for path", "path", workspacePath)
// Use EnsureWorkspace to create or update workspace based on path
workspace, err := h.workspaceUsecase.EnsureWorkspace(ctx, userID, workspacePath, "")
if err != nil {
h.logger.Error("Error ensuring workspace", "path", workspacePath, "error", err)
return "", fmt.Errorf("failed to ensure workspace: %w", err)
}
h.logger.Debug("Using existing or created workspace", "workspaceID", workspace.ID, "path", workspacePath)
return workspace.ID, nil
}
// If no workspacePath provided, return an error
h.logger.Debug("No workspace path provided, returning error")
return "", fmt.Errorf("no workspace path provided")
}
func (h *SocketHandler) handleTestPing(socket *socketio.Socket, data string) {
var pingData TestPingData
if err := json.Unmarshal([]byte(data), &pingData); err != nil {
@@ -533,3 +576,43 @@ func (h *SocketHandler) sendFinalResult(socket *socketio.Socket, updateData File
socket.Emit("file:update:ack", finalResponse)
h.mu.Unlock()
}
// generateAST 生成文件的AST信息
func (h *SocketHandler) generateAST(filePath, content string) string {
// 只对支持的编程语言生成AST
supportedLanguages := map[string]bool{
"go": true, "typescript": true, "javascript": true, "python": true,
}
// 简单判断文件扩展名
ext := ""
if len(filePath) > 0 {
for i := len(filePath) - 1; i >= 0; i-- {
if filePath[i] == '.' {
ext = filePath[i+1:]
break
}
}
}
// 如果不是支持的语言,返回空字符串
if !supportedLanguages[ext] {
return ""
}
// 创建临时文件来调用ctcode-cli
// 注意:这里是一个简化版本,实际使用时可能需要更复杂的临时文件处理
// 为了验证功能我们直接调用cli假设它能处理内容
results, err := cli.RunParseCLI("parse", "--successOnly", filePath)
if err != nil {
h.logger.Error("Failed to generate AST", "filePath", filePath, "error", err)
return ""
}
// 如果解析成功返回第一个结果的definition
if len(results) > 0 && results[0].Success {
return results[0].Definition
}
return ""
}

View File

@@ -104,7 +104,7 @@ func (h *WorkspaceFileHandler) GetAndSave(ctx *web.Context, req *domain.GetAndSa
h.logger.Error("failed to get and save workspace files", "error", err, "count", len(req.CodeFiles.Files))
return err
}
return ctx.Success(nil)
return ctx.Success(nil)
}
// Update 更新工作区文件

View File

@@ -8,19 +8,141 @@ import (
"github.com/google/uuid"
"github.com/chaitin/MonkeyCode/backend/db"
"github.com/chaitin/MonkeyCode/backend/db/workspace"
"github.com/chaitin/MonkeyCode/backend/db/workspacefile"
"github.com/chaitin/MonkeyCode/backend/domain"
"github.com/chaitin/MonkeyCode/backend/pkg/entx"
)
type WorkspaceRepo struct {
db *db.Client
}
type WorkspaceFileRepo struct {
db *db.Client
}
func NewWorkspaceRepo(db *db.Client) domain.WorkspaceRepo {
return &WorkspaceRepo{db: db}
}
func NewWorkspaceFileRepo(db *db.Client) domain.WorkspaceFileRepo {
return &WorkspaceFileRepo{db: db}
}
// WorkspaceRepo methods
func (r *WorkspaceRepo) Create(ctx context.Context, req *domain.CreateWorkspaceReq) (*db.Workspace, error) {
userID, err := uuid.Parse(req.UserID)
if err != nil {
return nil, fmt.Errorf("invalid user ID: %w", err)
}
return r.db.Workspace.Create().
SetUserID(userID).
SetName(req.Name).
SetDescription(req.Description).
SetRootPath(req.RootPath).
SetSettings(req.Settings).
Save(ctx)
}
func (r *WorkspaceRepo) Update(ctx context.Context, id string, fn func(*db.WorkspaceUpdateOne) error) (*db.Workspace, error) {
workspaceID, err := uuid.Parse(id)
if err != nil {
return nil, fmt.Errorf("invalid workspace ID: %w", err)
}
var workspace *db.Workspace
err = entx.WithTx(ctx, r.db, func(tx *db.Tx) error {
old, err := tx.Workspace.Get(ctx, workspaceID)
if err != nil {
return err
}
up := tx.Workspace.UpdateOneID(old.ID)
if err := fn(up); err != nil {
return err
}
if updated, err := up.Save(ctx); err != nil {
return err
} else {
workspace = updated
}
return nil
})
return workspace, err
}
func (r *WorkspaceRepo) Delete(ctx context.Context, id string) error {
workspaceID, err := uuid.Parse(id)
if err != nil {
return fmt.Errorf("invalid workspace ID: %w", err)
}
return r.db.Workspace.DeleteOneID(workspaceID).Exec(ctx)
}
func (r *WorkspaceRepo) GetByID(ctx context.Context, id string) (*db.Workspace, error) {
workspaceID, err := uuid.Parse(id)
if err != nil {
return nil, fmt.Errorf("invalid workspace ID: %w", err)
}
return r.db.Workspace.Query().
Where(workspace.ID(workspaceID)).
Only(ctx)
}
func (r *WorkspaceRepo) GetByUserAndPath(ctx context.Context, userID, rootPath string) (*db.Workspace, error) {
userUUID, err := uuid.Parse(userID)
if err != nil {
return nil, fmt.Errorf("invalid user ID: %w", err)
}
return r.db.Workspace.Query().
Where(
workspace.UserID(userUUID),
workspace.RootPath(rootPath),
).
Only(ctx)
}
func (r *WorkspaceRepo) List(ctx context.Context, req *domain.ListWorkspaceReq) ([]*db.Workspace, *db.PageInfo, error) {
q := r.db.Workspace.Query()
// 添加筛选条件
if req.UserID != "" {
userID, err := uuid.Parse(req.UserID)
if err != nil {
return nil, nil, fmt.Errorf("invalid user ID: %w", err)
}
q = q.Where(workspace.UserID(userID))
}
if req.Search != "" {
q = q.Where(
workspace.Or(
workspace.NameContains(req.Search),
workspace.DescriptionContains(req.Search),
),
)
}
if req.RootPath != "" {
q = q.Where(workspace.RootPath(req.RootPath))
}
// 排序
q = q.Order(workspace.ByLastAccessedAt(sql.OrderDesc()))
// 分页查询
return q.Page(ctx, req.Page, req.Size)
}
// WorkspaceFileRepo methods
func (r *WorkspaceFileRepo) Create(ctx context.Context, req *domain.CreateWorkspaceFileReq) (*db.WorkspaceFile, error) {
userID, err := uuid.Parse(req.UserID)
if err != nil {

View File

@@ -8,6 +8,7 @@ import (
"fmt"
"log/slog"
"strings"
"time"
"github.com/chaitin/MonkeyCode/backend/config"
"github.com/chaitin/MonkeyCode/backend/db"
@@ -16,24 +17,171 @@ import (
"github.com/chaitin/MonkeyCode/backend/pkg/cvt"
)
type WorkspaceFileUsecase struct {
repo domain.WorkspaceFileRepo
type WorkspaceUsecase struct {
repo domain.WorkspaceRepo
config *config.Config
logger *slog.Logger
}
type WorkspaceFileUsecase struct {
repo domain.WorkspaceFileRepo
workspaceSvc domain.WorkspaceUsecase
config *config.Config
logger *slog.Logger
}
func NewWorkspaceUsecase(
repo domain.WorkspaceRepo,
config *config.Config,
logger *slog.Logger,
) domain.WorkspaceUsecase {
return &WorkspaceUsecase{
repo: repo,
config: config,
logger: logger.With("usecase", "workspace"),
}
}
func NewWorkspaceFileUsecase(
repo domain.WorkspaceFileRepo,
workspaceSvc domain.WorkspaceUsecase,
config *config.Config,
logger *slog.Logger,
) domain.WorkspaceFileUsecase {
return &WorkspaceFileUsecase{
repo: repo,
config: config,
logger: logger.With("usecase", "workspace_file"),
repo: repo,
workspaceSvc: workspaceSvc,
config: config,
logger: logger.With("usecase", "workspace_file"),
}
}
// WorkspaceUsecase methods
func (u *WorkspaceUsecase) Create(ctx context.Context, req *domain.CreateWorkspaceReq) (*domain.Workspace, error) {
workspace, err := u.repo.Create(ctx, req)
if err != nil {
u.logger.Error("failed to create workspace", "error", err, "name", req.Name, "root_path", req.RootPath)
return nil, fmt.Errorf("failed to create workspace: %w", err)
}
u.logger.Info("workspace created", "id", workspace.ID, "name", req.Name, "root_path", req.RootPath)
return cvt.From(workspace, &domain.Workspace{}), nil
}
func (u *WorkspaceUsecase) GetByID(ctx context.Context, id string) (*domain.Workspace, error) {
workspace, err := u.repo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("failed to get workspace: %w", err)
}
return cvt.From(workspace, &domain.Workspace{}), nil
}
func (u *WorkspaceUsecase) GetByUserAndPath(ctx context.Context, userID, rootPath string) (*domain.Workspace, error) {
workspace, err := u.repo.GetByUserAndPath(ctx, userID, rootPath)
if err != nil {
return nil, fmt.Errorf("failed to get workspace by user and path: %w", err)
}
return cvt.From(workspace, &domain.Workspace{}), nil
}
func (u *WorkspaceUsecase) List(ctx context.Context, req *domain.ListWorkspaceReq) (*domain.ListWorkspaceResp, error) {
workspaces, pageInfo, err := u.repo.List(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to list workspaces: %w", err)
}
return &domain.ListWorkspaceResp{
PageInfo: pageInfo,
Workspaces: domain.FromWorkspaces(workspaces),
}, nil
}
func (u *WorkspaceUsecase) Update(ctx context.Context, req *domain.UpdateWorkspaceReq) (*domain.Workspace, error) {
workspace, err := u.repo.Update(ctx, req.ID, func(up *db.WorkspaceUpdateOne) error {
if req.Name != nil {
up.SetName(*req.Name)
}
if req.Description != nil {
up.SetDescription(*req.Description)
}
if req.Settings != nil {
up.SetSettings(req.Settings)
}
return nil
})
if err != nil {
u.logger.Error("failed to update workspace", "error", err, "id", req.ID)
return nil, fmt.Errorf("failed to update workspace: %w", err)
}
u.logger.Info("workspace updated", "id", req.ID)
return cvt.From(workspace, &domain.Workspace{}), nil
}
func (u *WorkspaceUsecase) Delete(ctx context.Context, id string) error {
err := u.repo.Delete(ctx, id)
if err != nil {
u.logger.Error("failed to delete workspace", "error", err, "id", id)
return fmt.Errorf("failed to delete workspace: %w", err)
}
u.logger.Info("workspace deleted", "id", id)
return nil
}
func (u *WorkspaceUsecase) EnsureWorkspace(ctx context.Context, userID, rootPath, name string) (*domain.Workspace, error) {
// 首先尝试获取已存在的工作区
workspace, err := u.repo.GetByUserAndPath(ctx, userID, rootPath)
if err == nil {
// 工作区已存在,更新最后访问时间
updated, err := u.repo.Update(ctx, workspace.ID.String(), func(up *db.WorkspaceUpdateOne) error {
up.SetLastAccessedAt(time.Now())
return nil
})
if err != nil {
u.logger.Warn("failed to update workspace last accessed time", "error", err, "id", workspace.ID)
}
return cvt.From(updated, &domain.Workspace{}), nil
}
// 如果工作区不存在,创建新的工作区
if !db.IsNotFound(err) {
return nil, fmt.Errorf("failed to check workspace existence: %w", err)
}
// 自动生成工作区名称(如果未提供)
if name == "" {
name = u.generateWorkspaceName(rootPath)
}
createReq := &domain.CreateWorkspaceReq{
UserID: userID,
Name: name,
Description: fmt.Sprintf("Auto-created workspace for %s", rootPath),
RootPath: rootPath,
Settings: map[string]interface{}{},
}
return u.Create(ctx, createReq)
}
func (u *WorkspaceUsecase) generateWorkspaceName(rootPath string) string {
// 从路径中提取最后一个目录名作为工作区名称
parts := strings.Split(rootPath, "/")
if len(parts) > 0 {
name := parts[len(parts)-1]
if name != "" {
return name
}
}
return "Untitled Workspace"
}
// WorkspaceFileUsecase methods
func (u *WorkspaceFileUsecase) Create(ctx context.Context, req *domain.CreateWorkspaceFileReq) (*domain.WorkspaceFile, error) {
// 验证和计算哈希
if req.Hash == "" {
@@ -52,6 +200,18 @@ func (u *WorkspaceFileUsecase) Create(ctx context.Context, req *domain.CreateWor
req.Language = u.inferLanguage(req.Path)
}
// 确保工作区存在
// 首先通过workspace ID获取workspace信息然后使用其root path来确保workspace存在
workspace, err := u.workspaceSvc.GetByID(ctx, req.WorkspaceID)
if err != nil {
return nil, fmt.Errorf("failed to get workspace by ID: %w", err)
}
_, err = u.workspaceSvc.EnsureWorkspace(ctx, req.UserID, workspace.RootPath, "")
if err != nil {
return nil, fmt.Errorf("failed to ensure workspace exists: %w", err)
}
file, err := u.repo.Create(ctx, req)
if err != nil {
u.logger.Error("failed to create workspace file", "error", err, "path", req.Path)
@@ -92,6 +252,8 @@ func (u *WorkspaceFileUsecase) Update(ctx context.Context, req *domain.UpdateWor
up.SetSize(*req.Size)
}
// AST field has been removed from the domain model
return nil
})
if err != nil {
@@ -126,13 +288,13 @@ func (u *WorkspaceFileUsecase) GetByID(ctx context.Context, id string) (*domain.
func (u *WorkspaceFileUsecase) GetAndSave(ctx context.Context, req *domain.GetAndSaveReq) (error) {
results, err := cli.RunCli("index", "", req.CodeFiles)
if err != nil {
return err
}
return err
}
for _, res := range results {
file, err := u.repo.GetByPath(ctx, req.UserID, req.ProjectID, res.FilePath)
file, err := u.repo.GetByPath(ctx, req.UserID, req.ProjectID, res.FilePath)
if err != nil {
return err
}
return err
}
resString, err := json.Marshal(res)
if err!= nil {
@@ -142,10 +304,10 @@ func (u *WorkspaceFileUsecase) GetAndSave(ctx context.Context, req *domain.GetAn
return up.SetContent(string(resString)).Exec(ctx)
})
if err != nil {
return err
return err
}
}
return nil
return nil
}
func (u *WorkspaceFileUsecase) GetByPath(ctx context.Context, userID, workspaceID, path string) (*domain.WorkspaceFile, error) {
@@ -170,6 +332,18 @@ func (u *WorkspaceFileUsecase) List(ctx context.Context, req *domain.ListWorkspa
}
func (u *WorkspaceFileUsecase) BatchCreate(ctx context.Context, req *domain.BatchCreateWorkspaceFileReq) ([]*domain.WorkspaceFile, error) {
// 确保工作区存在
// 首先通过workspace ID获取workspace信息然后使用其root path来确保workspace存在
workspace, err := u.workspaceSvc.GetByID(ctx, req.WorkspaceID)
if err != nil {
return nil, fmt.Errorf("failed to get workspace by ID: %w", err)
}
_, err = u.workspaceSvc.EnsureWorkspace(ctx, req.UserID, workspace.RootPath, "")
if err != nil {
return nil, fmt.Errorf("failed to ensure workspace exists: %w", err)
}
// 验证和预处理文件
for _, file := range req.Files {
if file.Hash == "" {
@@ -217,6 +391,18 @@ func (u *WorkspaceFileUsecase) BatchUpdate(ctx context.Context, req *domain.Batc
}
func (u *WorkspaceFileUsecase) Sync(ctx context.Context, req *domain.SyncWorkspaceFileReq) (*domain.SyncWorkspaceFileResp, error) {
// 确保工作区存在
// 首先通过workspace ID获取workspace信息然后使用其root path来确保workspace存在
workspace, err := u.workspaceSvc.GetByID(ctx, req.WorkspaceID)
if err != nil {
return nil, fmt.Errorf("failed to get workspace by ID: %w", err)
}
_, err = u.workspaceSvc.EnsureWorkspace(ctx, req.UserID, workspace.RootPath, "")
if err != nil {
return nil, fmt.Errorf("failed to ensure workspace exists: %w", err)
}
// 获取要同步的文件哈希列表
var hashes []string
fileMap := make(map[string]*domain.CreateWorkspaceFileReq)