Files

660 lines
19 KiB
Go
Raw Permalink Normal View History

package handler
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"sync"
"time"
"github.com/chaitin/MonkeyCode/backend/config"
"github.com/chaitin/MonkeyCode/backend/db"
"github.com/chaitin/MonkeyCode/backend/domain"
socketio "github.com/doquangtan/socket.io/v4"
)
// FileUpdateData is the payload of a "file:update" event, either decoded from
// a JSON string or assembled field by field from an object payload.
type FileUpdateData struct {
	// ID is echoed back in every acknowledgement sent for this update.
	ID string `json:"id"`
	// FilePath identifies the file inside the workspace.
	FilePath string `json:"filePath"`
	// Hash is compared with the stored hash to detect unchanged files.
	Hash string `json:"hash"`
	// Event selects the operation: "initial_scan", "added", "modified" or "deleted".
	Event string `json:"event"`
	// Content is the file content stored on create/update.
	Content      string `json:"content,omitempty"`
	PreviousHash string `json:"previousHash,omitempty"`
	Timestamp    int64  `json:"timestamp"`
	// ApiKey is used to look up the user the update belongs to.
	ApiKey string `json:"apiKey,omitempty"`
	// WorkspacePath is the path used to find or create the workspace.
	WorkspacePath string `json:"workspacePath,omitempty"`
}
// AckResponse is the acknowledgement returned to the client for a file update.
type AckResponse struct {
	// ID mirrors the ID of the update being acknowledged.
	ID string `json:"id"`
	// Status is a short state string such as "received" or "error".
	Status string `json:"status"`
	// Message is an optional human-readable explanation.
	Message string `json:"message,omitempty"`
}
// TestPingData is the JSON body of a "test:ping" event; it is echoed back to
// the client inside the "test:pong" reply.
type TestPingData struct {
	Timestamp int64  `json:"timestamp"`
	Message   string `json:"message"`
	SocketID  string `json:"socketId"`
}
// HeartbeatData is the payload of a "heartbeat" event. Only ClientID is
// reflected back in the heartbeat response.
type HeartbeatData struct {
	Type      string `json:"type"`
	Timestamp int64  `json:"timestamp"`
	ClientID  string `json:"clientId"`
}
// SocketHandler owns the Socket.IO server and translates client events
// (file updates, pings, heartbeats) into calls on the domain use cases.
type SocketHandler struct {
	config *config.Config
	logger *slog.Logger

	// workspaceService performs the per-file operations (lookup, create,
	// update, delete, GetAndSave).
	workspaceService domain.WorkspaceFileUsecase
	// workspaceUsecase resolves a workspace path to a workspace.
	workspaceUsecase domain.WorkspaceUsecase
	// userService resolves an API key to a user.
	userService domain.UserUsecase

	io *socketio.Io

	// mu is held around the Emit/Ack calls made by handleTestPing,
	// sendACKWithLock and sendFinalResult.
	mu sync.Mutex

	// workspaceCache and cacheMutex are initialised by NewSocketHandler but
	// are not read or written by the methods shown in this file.
	workspaceCache map[string]*domain.Workspace
	cacheMutex     sync.RWMutex

	// workspaceProcessing holds one "userID:workspacePath" marker per
	// workspace currently being ensured.
	workspaceProcessing sync.Map
}
// NewSocketHandler creates a SocketHandler backed by a new Socket.IO server
// and installs the connection callback on it. The returned error is always
// nil in the current implementation.
func NewSocketHandler(config *config.Config, logger *slog.Logger, workspaceService domain.WorkspaceFileUsecase, workspaceUsecase domain.WorkspaceUsecase, userService domain.UserUsecase) (*SocketHandler, error) {
	h := &SocketHandler{
		config:           config,
		logger:           logger,
		workspaceService: workspaceService,
		workspaceUsecase: workspaceUsecase,
		userService:      userService,
		io:               socketio.New(),
		// The mutexes and the sync.Map are ready to use as zero values;
		// only the plain map needs explicit allocation.
		workspaceCache: make(map[string]*domain.Workspace),
	}

	h.setupEventHandlers()
	return h, nil
}
// setupEventHandlers installs handleConnection as the server's connection
// callback.
func (h *SocketHandler) setupEventHandlers() {
	onConnect := h.handleConnection
	h.io.OnConnection(onConnect)
}
// handleConnection is invoked for every new client socket. It registers all
// per-socket event handlers and then tells the client the server is ready.
//
// The handlers are registered before "server:status" is emitted so that a
// client which starts sending events as soon as it sees "ready" cannot hit a
// socket that has no handlers yet.
func (h *SocketHandler) handleConnection(socket *socketio.Socket) {
	h.logger.Debug("Client connected", "socketId", socket.Id)

	// Register event handlers.
	h.registerDisconnectHandler(socket)
	h.registerFileUpdateHandler(socket)
	h.registerTestPingHandler(socket)
	h.registerHeartbeatHandler(socket)
	h.registerWorkspaceStatsHandler(socket)

	h.sendServerStatus(socket, "ready", "Server is ready to receive updates")
}
// registerDisconnectHandler logs client disconnects. The reason is taken from
// the first event argument when that argument is a string, and is "unknown"
// otherwise.
func (h *SocketHandler) registerDisconnectHandler(socket *socketio.Socket) {
	socket.On("disconnect", func(payload *socketio.EventPayload) {
		why := "unknown"
		if args := payload.Data; len(args) > 0 {
			if text, isString := args[0].(string); isString {
				why = text
			}
		}
		h.logger.Debug("Client disconnected", "socketId", socket.Id, "reason", why)
	})
}
// registerFileUpdateHandler subscribes to "file:update". Events without any
// argument are answered with an error acknowledgement; all others are handed
// to processFileUpdateData.
func (h *SocketHandler) registerFileUpdateHandler(socket *socketio.Socket) {
	socket.On("file:update", func(payload *socketio.EventPayload) {
		if len(payload.Data) > 0 {
			h.processFileUpdateData(socket, payload)
			return
		}
		h.sendErrorACK(payload, "No data provided")
	})
}
// processFileUpdateData dispatches a "file:update" event on the dynamic type
// of its first argument: a JSON string or an already-decoded object. Any other
// type is logged and answered with an error acknowledgement.
//
// The caller must have verified that data.Data is non-empty.
func (h *SocketHandler) processFileUpdateData(socket *socketio.Socket, data *socketio.EventPayload) {
	var ack interface{}

	switch first := data.Data[0].(type) {
	case string:
		ack = h.handleFileUpdate(socket, first)
	case map[string]interface{}:
		ack = h.handleFileUpdateFromObject(socket, *data)
	default:
		h.logger.Error("Data is neither string nor object",
			"socketId", socket.Id,
			"dataType", fmt.Sprintf("%T", first))
		h.sendErrorACK(data, "Invalid data format - expected string or object")
		return
	}

	h.sendACKWithLock(data, ack)
}
// registerTestPingHandler subscribes to "test:ping". Only events whose first
// argument is a string are processed; everything else is silently ignored.
func (h *SocketHandler) registerTestPingHandler(socket *socketio.Socket) {
	socket.On("test:ping", func(payload *socketio.EventPayload) {
		if len(payload.Data) == 0 {
			return
		}
		raw, isString := payload.Data[0].(string)
		if !isString {
			return
		}
		h.handleTestPing(socket, raw)
	})
}
// registerHeartbeatHandler subscribes to "heartbeat". The first event argument
// (string or object) is passed to handleHeartbeat and the result is returned
// through the acknowledgement callback when the client supplied one.
func (h *SocketHandler) registerHeartbeatHandler(socket *socketio.Socket) {
	socket.On("heartbeat", func(payload *socketio.EventPayload) {
		if len(payload.Data) == 0 {
			h.sendErrorACK(payload, "No heartbeat data")
			return
		}

		reply := h.handleHeartbeat(socket, payload.Data[0])
		if payload.Ack == nil {
			return
		}
		payload.Ack(reply)
	})
}
// registerWorkspaceStatsHandler subscribes to "workspace:stats".
//
// Workspace statistics are not part of the current use-case interface, so the
// handler only acknowledges with a fixed "not_implemented" placeholder until
// the feature is implemented or the event is removed.
func (h *SocketHandler) registerWorkspaceStatsHandler(socket *socketio.Socket) {
	socket.On("workspace:stats", func(payload *socketio.EventPayload) {
		if payload.Ack == nil {
			return
		}
		payload.Ack(map[string]interface{}{
			"status":  "not_implemented",
			"message": "Workspace stats functionality is not available.",
		})
	})
}
// handleFileUpdate decodes a "file:update" payload delivered as a JSON string.
//
// On success it starts processFileUpdateAsync in a new goroutine and returns
// an AckResponse with status "received"; the final outcome is emitted later
// by that goroutine. On a decode failure it returns an error map and starts
// no processing.
func (h *SocketHandler) handleFileUpdate(socket *socketio.Socket, data string) interface{} {
	var updateData FileUpdateData
	if err := json.Unmarshal([]byte(data), &updateData); err != nil {
		// The raw payload may carry the caller's API key and full file
		// content, so only its size is logged.
		h.logger.Error("Failed to parse file update data", "error", err, "dataLen", len(data))
		return map[string]interface{}{
			"status":  "error",
			"message": "Invalid data format",
		}
	}

	// Acknowledge receipt immediately; the file operation runs asynchronously.
	immediateAck := AckResponse{
		ID:      updateData.ID,
		Status:  "received",
		Message: "File update received, processing...",
	}

	go h.processFileUpdateAsync(socket, updateData)
	return immediateAck
}
// handleFileUpdateFromObject handles a "file:update" payload whose first
// argument is an already-decoded object (map[string]interface{}).
//
// Fields that are absent or of the wrong type are left at their zero value.
// All fields of FileUpdateData are extracted, including previousHash, so the
// object form decodes to the same data as the JSON-string form.
//
// On success it starts processFileUpdateAsync in a new goroutine and returns
// an AckResponse with status "received". If the payload is empty or its first
// argument is not an object, it returns an AckResponse with status "error".
func (h *SocketHandler) handleFileUpdateFromObject(socket *socketio.Socket, data socketio.EventPayload) interface{} {
	if len(data.Data) == 0 {
		h.logger.Error("No data provided for file update")
		return AckResponse{
			Status:  "error",
			Message: "No data provided",
		}
	}

	fields, ok := data.Data[0].(map[string]interface{})
	if !ok {
		h.logger.Error("Invalid data format for file update", "type", fmt.Sprintf("%T", data.Data[0]))
		return AckResponse{
			Status:  "error",
			Message: "Invalid data format",
		}
	}

	// str returns fields[key] when it holds a string and "" otherwise.
	str := func(key string) string {
		s, _ := fields[key].(string)
		return s
	}

	updateData := FileUpdateData{
		ID:            str("id"),
		FilePath:      str("filePath"),
		Hash:          str("hash"),
		Event:         str("event"),
		Content:       str("content"),
		PreviousHash:  str("previousHash"),
		ApiKey:        str("apiKey"),
		WorkspacePath: str("workspacePath"),
	}
	// Numbers in a decoded object arrive as float64.
	if ts, isNumber := fields["timestamp"].(float64); isNumber {
		updateData.Timestamp = int64(ts)
	}

	// Acknowledge receipt immediately; the file operation runs asynchronously.
	immediateAck := AckResponse{
		ID:      updateData.ID,
		Status:  "received",
		Message: "File update received, processing...",
	}

	go h.processFileUpdateAsync(socket, updateData)
	return immediateAck
}
// processFileUpdateAsync performs the work behind a "file:update" event and
// emits exactly one "file:update:ack" with the final status.
//
// Steps: resolve the user from updateData.ApiKey, ensure the workspace for
// updateData.WorkspacePath exists, then apply the operation selected by
// updateData.Event ("initial_scan"/"added", "modified", "deleted"). Any other
// event yields an error result.
func (h *SocketHandler) processFileUpdateAsync(socket *socketio.Socket, updateData FileUpdateData) {
	ctx := context.Background()

	user, err := h.userService.GetUserByApiKey(ctx, updateData.ApiKey)
	if err != nil {
		// The API key is a credential and is deliberately kept out of the log.
		h.logger.Error("Failed to get user by API key", "error", err)
		h.sendFinalResult(socket, updateData, "error", fmt.Sprintf("Invalid API key: %v", err))
		return
	}
	userID := user.ID.String()

	workspaceID, err := h.ensureWorkspace(ctx, userID, updateData.WorkspacePath)
	if err != nil {
		h.logger.Error("Failed to ensure workspace", "error", err)
		h.sendFinalResult(socket, updateData, "error", fmt.Sprintf("Failed to ensure workspace: %v", err))
		return
	}

	var finalStatus, message string
	switch updateData.Event {
	case "initial_scan", "added":
		finalStatus, message = h.applyFileUpsert(ctx, userID, workspaceID, updateData)
	case "modified":
		finalStatus, message = h.applyFileModify(ctx, userID, workspaceID, updateData)
	case "deleted":
		finalStatus, message = h.applyFileDelete(ctx, userID, workspaceID, updateData)
	default:
		finalStatus = "error"
		message = fmt.Sprintf("Unknown event type: %s", updateData.Event)
	}

	h.sendFinalResult(socket, updateData, finalStatus, message)
}

// applyFileUpsert handles "initial_scan" and "added": it creates the file when
// it does not exist yet, updates it when the stored hash differs, and reports
// success without writing when the hash already matches. It returns the final
// status and message for the acknowledgement.
func (h *SocketHandler) applyFileUpsert(ctx context.Context, userID, workspaceID string, updateData FileUpdateData) (string, string) {
	existingFile, err := h.workspaceService.GetByPath(ctx, userID, workspaceID, updateData.FilePath)
	switch {
	case err == nil:
		// The file exists; only write when the content hash changed.
		if existingFile.Hash == updateData.Hash {
			return "success", "File is already up-to-date"
		}
		updateReq := &domain.UpdateWorkspaceFileReq{
			ID:      existingFile.ID,
			Content: &updateData.Content,
			Hash:    &updateData.Hash,
		}
		if _, updateErr := h.workspaceService.Update(ctx, updateReq); updateErr != nil {
			h.logger.Error("Failed to update existing file", "path", updateData.FilePath, "error", updateErr)
			return "error", fmt.Sprintf("Failed to update existing file: %v", updateErr)
		}
		return "success", "File updated successfully"

	case db.IsNotFound(err):
		createReq := &domain.CreateWorkspaceFileReq{
			Path:        updateData.FilePath,
			Content:     updateData.Content,
			Hash:        updateData.Hash,
			UserID:      userID,
			WorkspaceID: workspaceID,
		}
		if _, createErr := h.workspaceService.Create(ctx, createReq); createErr != nil {
			h.logger.Error("Failed to create file", "path", updateData.FilePath, "error", createErr)
			return "error", fmt.Sprintf("Failed to create file: %v", createErr)
		}
		h.indexWorkspaceFile(ctx, userID, workspaceID, updateData)
		return "success", "File created successfully"

	default:
		h.logger.Error("Error checking for existing file", "path", updateData.FilePath, "error", err)
		return "error", fmt.Sprintf("Error checking for existing file: %v", err)
	}
}

// applyFileModify handles "modified": the file must already exist; its content
// and hash are replaced and the file is then passed to GetAndSave. It returns
// the final status and message for the acknowledgement.
func (h *SocketHandler) applyFileModify(ctx context.Context, userID, workspaceID string, updateData FileUpdateData) (string, string) {
	// The update request is keyed by file ID, so look the file up by path first.
	file, err := h.workspaceService.GetByPath(ctx, userID, workspaceID, updateData.FilePath)
	if err != nil {
		h.logger.Error("Failed to find file for update", "path", updateData.FilePath, "error", err)
		return "error", fmt.Sprintf("Failed to find file for update: %v", err)
	}

	req := &domain.UpdateWorkspaceFileReq{
		ID:      file.ID,
		Content: &updateData.Content,
		Hash:    &updateData.Hash,
	}
	if _, err = h.workspaceService.Update(ctx, req); err != nil {
		h.logger.Error("Failed to update file", "path", updateData.FilePath, "error", err)
		return "error", fmt.Sprintf("Failed to update file: %v", err)
	}

	h.indexWorkspaceFile(ctx, userID, workspaceID, updateData)
	return "success", "File updated successfully"
}

// applyFileDelete handles "deleted": it looks the file up by path and deletes
// it by ID. It returns the final status and message for the acknowledgement.
func (h *SocketHandler) applyFileDelete(ctx context.Context, userID, workspaceID string, updateData FileUpdateData) (string, string) {
	file, err := h.workspaceService.GetByPath(ctx, userID, workspaceID, updateData.FilePath)
	if err != nil {
		h.logger.Error("Failed to find file for deletion", "path", updateData.FilePath, "error", err)
		return "error", fmt.Sprintf("Failed to find file for deletion: %v", err)
	}

	if err = h.workspaceService.Delete(ctx, file.ID); err != nil {
		h.logger.Error("Failed to delete file", "path", updateData.FilePath, "error", err)
		return "error", fmt.Sprintf("Failed to delete file: %v", err)
	}
	return "success", "File deleted successfully"
}

// indexWorkspaceFile passes a freshly created or modified file to GetAndSave.
// This step is best effort: a failure is logged at debug level and does not
// change the result reported to the client.
func (h *SocketHandler) indexWorkspaceFile(ctx context.Context, userID, workspaceID string, updateData FileUpdateData) {
	fileExtension := h.getFileExtension(updateData.FilePath)
	codeFiles := domain.CodeFiles{
		Files: []domain.FileMeta{
			{
				FilePath: updateData.FilePath,
				Language: h.getFileLanguage(fileExtension),
				Content:  updateData.Content,
			},
		},
	}
	getAndSaveReq := &domain.GetAndSaveReq{
		UserID:      userID,
		WorkspaceID: workspaceID,
		FileMetas:   codeFiles.Files,
	}
	if err := h.workspaceService.GetAndSave(ctx, getAndSaveReq); err != nil {
		h.logger.Debug("Failed to process file with GetAndSave", "path", updateData.FilePath, "error", err)
	}
}
// ensureWorkspace returns the ID of the workspace for workspacePath, creating
// or updating it through workspaceUsecase.EnsureWorkspace.
//
// To limit concurrent creation of the same workspace, the first caller for a
// given "userID:workspacePath" stores a marker in workspaceProcessing and
// removes it when it returns. Callers that find the marker already present
// wait (up to 10 polls of 50ms) for it to disappear and then call
// EnsureWorkspace themselves; they never remove a marker they did not store,
// so the wait stays effective for later arrivals.
//
// It returns an error when workspacePath is empty, when ctx is cancelled
// during the wait, or when EnsureWorkspace fails.
func (h *SocketHandler) ensureWorkspace(ctx context.Context, userID, workspacePath string) (string, error) {
	if workspacePath == "" {
		return "", fmt.Errorf("no workspace path provided")
	}

	const (
		maxWaitRetries = 10
		waitInterval   = 50 * time.Millisecond
	)

	processingKey := fmt.Sprintf("%s:%s", userID, workspacePath)

	if _, processing := h.workspaceProcessing.LoadOrStore(processingKey, true); !processing {
		// This call stored the marker, so it is the one that clears it.
		defer h.workspaceProcessing.Delete(processingKey)
	} else {
		h.logger.Debug("workspace already being processed, waiting", "userID", userID, "workspacePath", workspacePath)
		for i := 0; i < maxWaitRetries; i++ {
			select {
			case <-ctx.Done():
				return "", fmt.Errorf("waiting for workspace: %w", ctx.Err())
			case <-time.After(waitInterval):
			}
			if _, stillProcessing := h.workspaceProcessing.Load(processingKey); !stillProcessing {
				break
			}
		}
		// Even if the marker is still present, fall through: EnsureWorkspace
		// is expected to return the existing workspace at this point.
		h.logger.Debug("proceeding with workspace creation after wait", "userID", userID, "workspacePath", workspacePath)
	}

	workspace, err := h.workspaceUsecase.EnsureWorkspace(ctx, userID, workspacePath, "")
	if err != nil {
		h.logger.Error("Error ensuring workspace", "userID", userID, "path", workspacePath, "error", err)
		return "", fmt.Errorf("failed to ensure workspace: %w", err)
	}

	h.logger.Debug("workspace ensured successfully", "userID", userID, "workspacePath", workspacePath, "workspaceID", workspace.ID)
	return workspace.ID, nil
}
// handleTestPing answers a "test:ping" whose payload is a JSON string by
// emitting "test:pong" with the decoded ping, the socket ID and the server
// time. A payload that fails to decode is logged and otherwise ignored.
func (h *SocketHandler) handleTestPing(socket *socketio.Socket, data string) {
	var pingData TestPingData
	if err := json.Unmarshal([]byte(data), &pingData); err != nil {
		h.logger.Error("Failed to parse test ping data", "error", err)
		return
	}

	// Read the clock once so "timestamp" and "serverTime" describe the same
	// instant.
	now := time.Now()
	pongData := map[string]interface{}{
		"timestamp":    now.UnixMilli(),
		"serverTime":   now.Format(time.RFC3339),
		"message":      "Pong from MonkeyCode server",
		"receivedPing": pingData,
		"socketId":     socket.Id,
		"serverStatus": "ok",
	}

	h.mu.Lock()
	socket.Emit("test:pong", pongData)
	h.mu.Unlock()
}
// handleHeartbeat builds the reply to a "heartbeat" event.
//
// data may be a JSON string or an already-decoded object. A string that does
// not decode, or any other type, yields an error map. Otherwise the reply
// carries status "ok", the server time in milliseconds, the socket ID and the
// client ID taken from the payload (empty when absent).
func (h *SocketHandler) handleHeartbeat(socket *socketio.Socket, data interface{}) interface{} {
	var beat HeartbeatData

	switch payload := data.(type) {
	case map[string]interface{}:
		// Pick the known fields out of the object; wrong-typed or missing
		// entries leave the zero value in place.
		if id, isString := payload["clientId"].(string); isString {
			beat.ClientID = id
		}
		if ts, isNumber := payload["timestamp"].(float64); isNumber {
			beat.Timestamp = int64(ts)
		}
		if kind, isString := payload["type"].(string); isString {
			beat.Type = kind
		}

	case string:
		if err := json.Unmarshal([]byte(payload), &beat); err != nil {
			h.logger.Error("Failed to parse heartbeat data from string", "error", err)
			return map[string]interface{}{
				"status":  "error",
				"message": "Invalid heartbeat data format",
			}
		}

	default:
		h.logger.Error("Unexpected heartbeat data type", "type", fmt.Sprintf("%T", data))
		return map[string]interface{}{
			"status":  "error",
			"message": "Invalid heartbeat data type",
		}
	}

	return map[string]interface{}{
		"status":     "ok",
		"serverTime": time.Now().UnixMilli(),
		"socketId":   socket.Id,
		"clientId":   beat.ClientID,
	}
}
// sendServerStatus emits a "server:status" event carrying status and message
// to a single socket.
func (h *SocketHandler) sendServerStatus(socket *socketio.Socket, status, message string) {
	socket.Emit("server:status", map[string]string{
		"status":  status,
		"message": message,
	})
}
// GetServer returns the Socket.IO server instance owned by the handler.
func (h *SocketHandler) GetServer() *socketio.Io {
	server := h.io
	return server
}
// BroadcastServerStatus emits a "server:status" event carrying status and
// message through the server-level Emit, i.e. to all connected clients.
func (h *SocketHandler) BroadcastServerStatus(status, message string) {
	h.io.Emit("server:status", map[string]interface{}{
		"status":  status,
		"message": message,
	})
}
// GetConnectedClients returns the number of sockets the server currently
// reports as connected.
func (h *SocketHandler) GetConnectedClients() int {
	return len(h.io.Sockets())
}
// sendErrorACK acknowledges an event with an error status and the given
// message. It does nothing when the client did not request an acknowledgement.
func (h *SocketHandler) sendErrorACK(data *socketio.EventPayload, message string) {
	if data.Ack == nil {
		return
	}
	data.Ack(map[string]interface{}{
		"status":  "error",
		"message": message,
	})
}
// sendACKWithLock acknowledges an event with response while holding h.mu. It
// does nothing when the client did not request an acknowledgement.
func (h *SocketHandler) sendACKWithLock(data *socketio.EventPayload, response interface{}) {
	if data.Ack == nil {
		return
	}

	h.mu.Lock()
	data.Ack(response)
	h.mu.Unlock()
}
// sendFinalResult emits the final outcome of a file update to the client as a
// "file:update:ack" event containing the update ID, status, message and file
// path.
func (h *SocketHandler) sendFinalResult(socket *socketio.Socket, updateData FileUpdateData, status, message string) {
	result := make(map[string]interface{}, 4)
	result["id"] = updateData.ID
	result["status"] = status
	result["message"] = message
	result["file"] = updateData.FilePath

	// The socket write is guarded by the handler mutex.
	h.mu.Lock()
	socket.Emit("file:update:ack", result)
	h.mu.Unlock()
}
// getFileExtension returns the extension of filePath without the leading dot,
// or "" when the final path element contains no dot.
//
// The scan runs backwards from the end and stops at the first path separator,
// so a dot in a directory name ("pkg.v2/Makefile") is not mistaken for an
// extension. Both '/' and '\' are treated as separators because paths are
// supplied by clients and may come from any platform.
func (h *SocketHandler) getFileExtension(filePath string) string {
	for i := len(filePath) - 1; i >= 0; i-- {
		switch filePath[i] {
		case '.':
			return filePath[i+1:]
		case '/', '\\':
			// Reached the directory part: the file name has no dot.
			return ""
		}
	}
	return ""
}
// getFileLanguage 根据文件扩展名获取编程语言类型
func (h *SocketHandler) getFileLanguage(fileExtension string) domain.CodeLanguageType {
switch fileExtension {
case "go":
return domain.CodeLanguageTypeGo
case "py":
return domain.CodeLanguageTypePython
case "java":
return domain.CodeLanguageTypeJava
case "js":
return domain.CodeLanguageTypeJavaScript
case "ts":
return domain.CodeLanguageTypeTypeScript
case "jsx":
return domain.CodeLanguageTypeJSX
case "tsx":
return domain.CodeLanguageTypeTSX
case "html":
return domain.CodeLanguageTypeHTML
case "css":
return domain.CodeLanguageTypeCSS
case "php":
return domain.CodeLanguageTypePHP
case "rs":
return domain.CodeLanguageTypeRust
case "swift":
return domain.CodeLanguageTypeSwift
case "kt":
return domain.CodeLanguageTypeKotlin
case "c":
return domain.CodeLanguageTypeC
case "cpp", "cc", "cxx":
return domain.CodeLanguageTypeCpp
default:
return ""
}
}