Files

660 lines
19 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package handler
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"sync"
"time"
"github.com/chaitin/MonkeyCode/backend/config"
"github.com/chaitin/MonkeyCode/backend/db"
"github.com/chaitin/MonkeyCode/backend/domain"
socketio "github.com/doquangtan/socket.io/v4"
)
// FileUpdateData is the payload of a "file:update" event. It is populated
// either by JSON-decoding a string payload or by copying fields out of an
// already-decoded object.
type FileUpdateData struct {
// ID is echoed back in both the immediate and the final acknowledgement.
ID string `json:"id"`
// FilePath identifies the file inside the workspace.
FilePath string `json:"filePath"`
// Hash is compared with the stored hash to decide whether an existing file
// needs updating.
Hash string `json:"hash"`
// Event selects the operation: "initial_scan", "added", "modified" or
// "deleted"; anything else is reported as an unknown event type.
Event string `json:"event"`
// Content is the file body stored on create/update.
Content string `json:"content,omitempty"`
// PreviousHash is carried in the payload; it is not read in this file.
PreviousHash string `json:"previousHash,omitempty"`
// Timestamp is carried in the payload; it is not read in this file.
Timestamp int64 `json:"timestamp"`
// ApiKey is used to look up the user the update belongs to.
ApiKey string `json:"apiKey,omitempty"`
// WorkspacePath is the path used to find or create the workspace.
WorkspacePath string `json:"workspacePath,omitempty"`
}
// AckResponse is the immediate acknowledgement returned for a "file:update"
// event, before the update itself has been processed.
type AckResponse struct {
// ID mirrors FileUpdateData.ID (left empty when the payload could not be read).
ID string `json:"id"`
// Status is "received" for an accepted update or "error" for a rejected one.
Status string `json:"status"`
// Message is a human-readable explanation of Status.
Message string `json:"message,omitempty"`
}
// TestPingData is the JSON payload of a "test:ping" event. The decoded value
// is echoed back to the client under "receivedPing" in the "test:pong" reply.
type TestPingData struct {
Timestamp int64 `json:"timestamp"`
Message string `json:"message"`
SocketID string `json:"socketId"`
}
// HeartbeatData is the payload of a "heartbeat" event. Only ClientID is used
// when building the heartbeat response; Type and Timestamp are decoded but not
// read further in this file.
type HeartbeatData struct {
Type string `json:"type"`
Timestamp int64 `json:"timestamp"`
ClientID string `json:"clientId"`
}
// SocketHandler owns the Socket.IO server and translates its events
// ("file:update", "test:ping", "heartbeat", "workspace:stats") into calls on
// the workspace and user use cases.
type SocketHandler struct {
config *config.Config
logger *slog.Logger
// workspaceService creates, reads, updates and deletes workspace files.
workspaceService domain.WorkspaceFileUsecase
// workspaceUsecase resolves a workspace path to a workspace (EnsureWorkspace).
workspaceUsecase domain.WorkspaceUsecase
// userService resolves an API key to a user.
userService domain.UserUsecase
io *socketio.Io
// mu is held around socket emits and acknowledgements issued by this handler.
mu sync.Mutex
// workspaceCache and cacheMutex are initialised by the constructor but are
// not otherwise referenced in this file.
workspaceCache map[string]*domain.Workspace
cacheMutex sync.RWMutex
// workspaceProcessing holds "userID:workspacePath" keys for workspaces that
// are currently being ensured.
workspaceProcessing sync.Map
}
// NewSocketHandler creates a Socket.IO server, wraps it in a SocketHandler
// wired to the given dependencies, and installs the connection callback.
// The returned error is always nil in the current implementation.
func NewSocketHandler(config *config.Config, logger *slog.Logger, workspaceService domain.WorkspaceFileUsecase, workspaceUsecase domain.WorkspaceUsecase, userService domain.UserUsecase) (*SocketHandler, error) {
	// The mutex fields are left at their zero value, which is ready to use.
	h := &SocketHandler{
		io:               socketio.New(),
		config:           config,
		logger:           logger,
		userService:      userService,
		workspaceService: workspaceService,
		workspaceUsecase: workspaceUsecase,
		workspaceCache:   map[string]*domain.Workspace{},
	}

	h.setupEventHandlers()
	return h, nil
}
// setupEventHandlers installs the per-connection callback on the Socket.IO
// server; all event handlers are registered from there.
func (h *SocketHandler) setupEventHandlers() {
	h.io.OnConnection(func(conn *socketio.Socket) {
		h.handleConnection(conn)
	})
}
// handleConnection runs once per new client: it announces that the server is
// ready and then registers every event handler on the socket.
func (h *SocketHandler) handleConnection(socket *socketio.Socket) {
	h.logger.Debug("Client connected", "socketId", socket.Id)
	h.sendServerStatus(socket, "ready", "Server is ready to receive updates")

	registrars := []func(*socketio.Socket){
		h.registerDisconnectHandler,
		h.registerFileUpdateHandler,
		h.registerTestPingHandler,
		h.registerHeartbeatHandler,
		h.registerWorkspaceStatsHandler,
	}
	for _, register := range registrars {
		register(socket)
	}
}
// registerDisconnectHandler logs client disconnects. The reason is taken from
// the first payload element when it is a string and is "unknown" otherwise.
func (h *SocketHandler) registerDisconnectHandler(socket *socketio.Socket) {
	socket.On("disconnect", func(payload *socketio.EventPayload) {
		var (
			reason   string
			isString bool
		)
		if len(payload.Data) > 0 {
			reason, isString = payload.Data[0].(string)
		}
		if !isString {
			reason = "unknown"
		}
		h.logger.Debug("Client disconnected", "socketId", socket.Id, "reason", reason)
	})
}
// registerFileUpdateHandler routes "file:update" events to
// processFileUpdateData, rejecting events that carry no payload.
func (h *SocketHandler) registerFileUpdateHandler(socket *socketio.Socket) {
	socket.On("file:update", func(payload *socketio.EventPayload) {
		if len(payload.Data) > 0 {
			h.processFileUpdateData(socket, payload)
			return
		}
		h.sendErrorACK(payload, "No data provided")
	})
}
// processFileUpdateData dispatches a "file:update" payload according to the
// type of its first element (JSON string or decoded object) and acknowledges
// the event with the handler's response. Any other type is rejected with an
// error acknowledgement. The caller guarantees data.Data is non-empty.
func (h *SocketHandler) processFileUpdateData(socket *socketio.Socket, data *socketio.EventPayload) {
	var response any

	switch first := data.Data[0].(type) {
	case string:
		response = h.handleFileUpdate(socket, first)
	case map[string]any:
		response = h.handleFileUpdateFromObject(socket, *data)
	default:
		h.logger.Error("Data is neither string nor object",
			"socketId", socket.Id,
			"dataType", fmt.Sprintf("%T", first))
		h.sendErrorACK(data, "Invalid data format - expected string or object")
		return
	}

	h.sendACKWithLock(data, response)
}
// registerTestPingHandler forwards "test:ping" events whose first payload
// element is a string to handleTestPing; every other payload is ignored.
func (h *SocketHandler) registerTestPingHandler(socket *socketio.Socket) {
	socket.On("test:ping", func(payload *socketio.EventPayload) {
		if len(payload.Data) == 0 {
			return
		}
		raw, isString := payload.Data[0].(string)
		if !isString {
			return
		}
		h.handleTestPing(socket, raw)
	})
}
// registerHeartbeatHandler answers "heartbeat" events. The first payload
// element (object or JSON string) is passed to handleHeartbeat and its result
// is returned through the event acknowledgement, if the client asked for one.
func (h *SocketHandler) registerHeartbeatHandler(socket *socketio.Socket) {
	socket.On("heartbeat", func(data *socketio.EventPayload) {
		if len(data.Data) == 0 {
			h.sendErrorACK(data, "No heartbeat data")
			return
		}

		response := h.handleHeartbeat(socket, data.Data[0])

		// Heartbeats can arrive while asynchronous file-update results are being
		// emitted, so the ack goes through the same write lock as those emits.
		h.sendACKWithLock(data, response)
	})
}
// registerWorkspaceStatsHandler answers "workspace:stats" events with a fixed
// "not_implemented" placeholder.
//
// Note: GetWorkspaceStats is not in the new interface, so this handler needs
// to be either implemented or removed.
func (h *SocketHandler) registerWorkspaceStatsHandler(socket *socketio.Socket) {
	socket.On("workspace:stats", func(data *socketio.EventPayload) {
		response := map[string]interface{}{
			"status":  "not_implemented",
			"message": "Workspace stats functionality is not available.",
		}
		// Acknowledge under the write lock, like the other acks and emits that
		// may overlap with asynchronous file-update results.
		h.sendACKWithLock(data, response)
	})
}
// handleFileUpdate decodes a JSON-encoded file update, starts processing it in
// a separate goroutine, and returns the immediate "received" acknowledgement.
// A payload that fails to decode yields an error map and is not processed.
func (h *SocketHandler) handleFileUpdate(socket *socketio.Socket, data string) interface{} {
	update := FileUpdateData{}
	err := json.Unmarshal([]byte(data), &update)
	if err != nil {
		h.logger.Error("Failed to parse file update data", "error", err, "data", data)
		return map[string]any{
			"status":  "error",
			"message": "Invalid data format",
		}
	}

	// The actual file operation runs in the background; its outcome is reported
	// later by processFileUpdateAsync.
	go h.processFileUpdateAsync(socket, update)

	return AckResponse{
		ID:      update.ID,
		Status:  "received",
		Message: "File update received, processing...",
	}
}
// handleFileUpdateFromObject builds a FileUpdateData from an already-decoded
// object payload (the first element of data.Data), starts processing it in a
// separate goroutine, and returns the immediate "received" acknowledgement.
//
// Fields that are missing or have an unexpected type are left at their zero
// value. An empty payload or a first element that is not an object yields an
// "error" AckResponse and nothing is processed.
func (h *SocketHandler) handleFileUpdateFromObject(socket *socketio.Socket, data socketio.EventPayload) interface{} {
	if len(data.Data) == 0 {
		h.logger.Error("No data provided for file update")
		return AckResponse{
			Status:  "error",
			Message: "No data provided",
		}
	}

	dataMap, ok := data.Data[0].(map[string]interface{})
	if !ok {
		h.logger.Error("Invalid data format for file update", "type", fmt.Sprintf("%T", data.Data[0]))
		return AckResponse{
			Status:  "error",
			Message: "Invalid data format",
		}
	}

	// str returns the string stored under key, or "" when absent or not a string.
	str := func(key string) string {
		s, _ := dataMap[key].(string)
		return s
	}

	updateData := FileUpdateData{
		ID:       str("id"),
		FilePath: str("filePath"),
		Event:    str("event"),
		Hash:     str("hash"),
		Content:  str("content"),
		// previousHash is decoded on the JSON-string path through the struct tag;
		// copy it here too so both payload shapes produce the same value.
		PreviousHash:  str("previousHash"),
		ApiKey:        str("apiKey"),
		WorkspacePath: str("workspacePath"),
	}
	// Numbers in a decoded object are expected as float64 here.
	if timestamp, ok := dataMap["timestamp"].(float64); ok {
		updateData.Timestamp = int64(timestamp)
	}

	// Acknowledge receipt right away; the file operation itself runs in the
	// background and reports its outcome separately.
	immediateAck := AckResponse{
		ID:      updateData.ID,
		Status:  "received",
		Message: "File update received, processing...",
	}
	go h.processFileUpdateAsync(socket, updateData)
	return immediateAck
}
// processFileUpdateAsync performs the file operation described by updateData
// and reports the outcome to the client through sendFinalResult.
//
// It resolves the user from the API key, ensures the workspace exists, and
// then dispatches on updateData.Event: "initial_scan"/"added", "modified" and
// "deleted" are handled; any other value is reported as an error.
func (h *SocketHandler) processFileUpdateAsync(socket *socketio.Socket, updateData FileUpdateData) {
	ctx := context.Background()

	user, err := h.userService.GetUserByApiKey(ctx, updateData.ApiKey)
	if err != nil {
		// The API key is a credential, so it is deliberately kept out of the log.
		h.logger.Error("Failed to get user by API key", "error", err)
		h.sendFinalResult(socket, updateData, "error", fmt.Sprintf("Invalid API key: %v", err))
		return
	}
	userID := user.ID.String()

	workspaceID, err := h.ensureWorkspace(ctx, userID, updateData.WorkspacePath)
	if err != nil {
		h.logger.Error("Failed to ensure workspace", "error", err)
		h.sendFinalResult(socket, updateData, "error", fmt.Sprintf("Failed to ensure workspace: %v", err))
		return
	}

	var finalStatus, message string
	switch updateData.Event {
	case "initial_scan", "added":
		finalStatus, message = h.applyFileAdded(ctx, userID, workspaceID, updateData)
	case "modified":
		finalStatus, message = h.applyFileModified(ctx, userID, workspaceID, updateData)
	case "deleted":
		finalStatus, message = h.applyFileDeleted(ctx, userID, workspaceID, updateData)
	default:
		finalStatus = "error"
		message = fmt.Sprintf("Unknown event type: %s", updateData.Event)
	}

	h.sendFinalResult(socket, updateData, finalStatus, message)
}

// applyFileAdded creates the file when it does not exist yet, or updates it
// when it exists with a different hash. It returns the final status and message.
func (h *SocketHandler) applyFileAdded(ctx context.Context, userID, workspaceID string, updateData FileUpdateData) (string, string) {
	existingFile, err := h.workspaceService.GetByPath(ctx, userID, workspaceID, updateData.FilePath)
	switch {
	case err == nil:
		// The file already exists: only write when the content hash changed.
		if existingFile.Hash == updateData.Hash {
			return "success", "File is already up-to-date"
		}
		_, updateErr := h.workspaceService.Update(ctx, &domain.UpdateWorkspaceFileReq{
			ID:      existingFile.ID,
			Content: &updateData.Content,
			Hash:    &updateData.Hash,
		})
		if updateErr != nil {
			h.logger.Error("Failed to update existing file", "path", updateData.FilePath, "error", updateErr)
			return "error", fmt.Sprintf("Failed to update existing file: %v", updateErr)
		}
		return "success", "File updated successfully"

	case db.IsNotFound(err):
		// Not found: this is a new file.
		_, createErr := h.workspaceService.Create(ctx, &domain.CreateWorkspaceFileReq{
			Path:        updateData.FilePath,
			Content:     updateData.Content,
			Hash:        updateData.Hash,
			UserID:      userID,
			WorkspaceID: workspaceID,
		})
		if createErr != nil {
			h.logger.Error("Failed to create file", "path", updateData.FilePath, "error", createErr)
			return "error", fmt.Sprintf("Failed to create file: %v", createErr)
		}
		h.saveFileMeta(ctx, userID, workspaceID, updateData)
		return "success", "File created successfully"

	default:
		h.logger.Error("Error checking for existing file", "path", updateData.FilePath, "error", err)
		return "error", fmt.Sprintf("Error checking for existing file: %v", err)
	}
}

// applyFileModified looks the file up by path and stores its new content and
// hash. It returns the final status and message.
func (h *SocketHandler) applyFileModified(ctx context.Context, userID, workspaceID string, updateData FileUpdateData) (string, string) {
	// The update request is keyed by ID, so resolve the path first.
	file, err := h.workspaceService.GetByPath(ctx, userID, workspaceID, updateData.FilePath)
	if err != nil {
		h.logger.Error("Failed to find file for update", "path", updateData.FilePath, "error", err)
		return "error", fmt.Sprintf("Failed to find file for update: %v", err)
	}

	_, err = h.workspaceService.Update(ctx, &domain.UpdateWorkspaceFileReq{
		ID:      file.ID,
		Content: &updateData.Content,
		Hash:    &updateData.Hash,
	})
	if err != nil {
		h.logger.Error("Failed to update file", "path", updateData.FilePath, "error", err)
		return "error", fmt.Sprintf("Failed to update file: %v", err)
	}

	h.saveFileMeta(ctx, userID, workspaceID, updateData)
	return "success", "File updated successfully"
}

// applyFileDeleted looks the file up by path and deletes it. It returns the
// final status and message.
func (h *SocketHandler) applyFileDeleted(ctx context.Context, userID, workspaceID string, updateData FileUpdateData) (string, string) {
	// Deletion is keyed by ID, so resolve the path first.
	file, err := h.workspaceService.GetByPath(ctx, userID, workspaceID, updateData.FilePath)
	if err != nil {
		h.logger.Error("Failed to find file for deletion", "path", updateData.FilePath, "error", err)
		return "error", fmt.Sprintf("Failed to find file for deletion: %v", err)
	}

	if err := h.workspaceService.Delete(ctx, file.ID); err != nil {
		h.logger.Error("Failed to delete file", "path", updateData.FilePath, "error", err)
		return "error", fmt.Sprintf("Failed to delete file: %v", err)
	}
	return "success", "File deleted successfully"
}

// saveFileMeta passes the file's path, language and content to GetAndSave.
// It is best-effort: a failure is logged at debug level and does not change
// the result reported to the client.
func (h *SocketHandler) saveFileMeta(ctx context.Context, userID, workspaceID string, updateData FileUpdateData) {
	req := &domain.GetAndSaveReq{
		UserID:      userID,
		WorkspaceID: workspaceID,
		FileMetas: []domain.FileMeta{
			{
				FilePath: updateData.FilePath,
				Language: h.getFileLanguage(h.getFileExtension(updateData.FilePath)),
				Content:  updateData.Content,
			},
		},
	}
	if err := h.workspaceService.GetAndSave(ctx, req); err != nil {
		h.logger.Debug("Failed to process file with GetAndSave", "path", updateData.FilePath, "error", err)
	}
}
// ensureWorkspace returns the ID of the workspace for workspacePath, creating
// or updating it through the workspace use case.
//
// Concurrent calls for the same user and path are serialised on a best-effort
// basis with a marker in h.workspaceProcessing: a caller that finds the marker
// set retries claiming it every 50ms, up to 10 times. If the marker is still
// held after that, the caller proceeds without it. Only the caller that
// actually claimed the marker removes it.
//
// It returns an error when workspacePath is empty, when ctx is cancelled while
// waiting, or when EnsureWorkspace fails.
func (h *SocketHandler) ensureWorkspace(ctx context.Context, userID, workspacePath string) (string, error) {
	if workspacePath == "" {
		return "", fmt.Errorf("no workspace path provided")
	}

	const (
		maxWaitRetries = 10
		waitInterval   = 50 * time.Millisecond
	)
	processingKey := fmt.Sprintf("%s:%s", userID, workspacePath)

	owned := false
	for attempt := 0; ; attempt++ {
		if _, busy := h.workspaceProcessing.LoadOrStore(processingKey, true); !busy {
			owned = true
			break
		}
		if attempt == 0 {
			h.logger.Debug("workspace already being processed, waiting", "userID", userID, "workspacePath", workspacePath)
		}
		if attempt == maxWaitRetries {
			break
		}
		select {
		case <-ctx.Done():
			return "", ctx.Err()
		case <-time.After(waitInterval):
		}
	}

	if owned {
		// Release the marker only if this call set it; deleting a marker owned
		// by another goroutine would let a third caller run alongside it.
		defer h.workspaceProcessing.Delete(processingKey)
	} else {
		// Still busy after the wait; go ahead anyway and let EnsureWorkspace
		// resolve the existing workspace.
		h.logger.Debug("proceeding with workspace creation after wait", "userID", userID, "workspacePath", workspacePath)
	}

	workspace, err := h.workspaceUsecase.EnsureWorkspace(ctx, userID, workspacePath, "")
	if err != nil {
		h.logger.Error("Error ensuring workspace", "userID", userID, "path", workspacePath, "error", err)
		return "", fmt.Errorf("failed to ensure workspace: %w", err)
	}
	h.logger.Debug("workspace ensured successfully", "userID", userID, "workspacePath", workspacePath, "workspaceID", workspace.ID)
	return workspace.ID, nil
}
// handleTestPing decodes a JSON "test:ping" payload and replies with a
// "test:pong" event that echoes the ping together with server time and socket
// ID. A payload that fails to decode is logged and dropped without a reply.
func (h *SocketHandler) handleTestPing(socket *socketio.Socket, data string) {
	ping := TestPingData{}
	if err := json.Unmarshal([]byte(data), &ping); err != nil {
		h.logger.Error("Failed to parse test ping data", "error", err)
		return
	}

	nowMillis := time.Now().UnixMilli()
	nowText := time.Now().Format(time.RFC3339)
	pong := map[string]any{
		"timestamp":    nowMillis,
		"serverTime":   nowText,
		"message":      "Pong from MonkeyCode server",
		"receivedPing": ping,
		"socketId":     socket.Id,
		"serverStatus": "ok",
	}

	// Emit under the handler's write lock.
	h.mu.Lock()
	socket.Emit("test:pong", pong)
	h.mu.Unlock()
}
// handleHeartbeat builds the response to a heartbeat. data may be a JSON
// string or a decoded object; any other type, or a string that fails to
// decode, yields an error map. On success the response carries status "ok",
// the server time in milliseconds, the socket ID and the client's ID.
func (h *SocketHandler) handleHeartbeat(socket *socketio.Socket, data interface{}) interface{} {
	beat := HeartbeatData{}

	switch payload := data.(type) {
	case map[string]any:
		// Missing or mistyped fields stay at their zero value.
		beat.ClientID, _ = payload["clientId"].(string)
		if ts, isNumber := payload["timestamp"].(float64); isNumber {
			beat.Timestamp = int64(ts)
		}
		beat.Type, _ = payload["type"].(string)

	case string:
		err := json.Unmarshal([]byte(payload), &beat)
		if err != nil {
			h.logger.Error("Failed to parse heartbeat data from string", "error", err)
			return map[string]any{
				"status":  "error",
				"message": "Invalid heartbeat data format",
			}
		}

	default:
		h.logger.Error("Unexpected heartbeat data type", "type", fmt.Sprintf("%T", data))
		return map[string]any{
			"status":  "error",
			"message": "Invalid heartbeat data type",
		}
	}

	return map[string]any{
		"status":     "ok",
		"serverTime": time.Now().UnixMilli(),
		"socketId":   socket.Id,
		"clientId":   beat.ClientID,
	}
}
// sendServerStatus emits a "server:status" event carrying status and message
// to a single socket.
func (h *SocketHandler) sendServerStatus(socket *socketio.Socket, status, message string) {
	socket.Emit("server:status", map[string]string{
		"status":  status,
		"message": message,
	})
}
// GetServer returns the Socket.IO server instance owned by the handler.
func (h *SocketHandler) GetServer() *socketio.Io {
	return h.io
}
// BroadcastServerStatus emits a "server:status" event carrying status and
// message to every connected client.
func (h *SocketHandler) BroadcastServerStatus(status, message string) {
	statusData := map[string]interface{}{
		"status":  status,
		"message": message,
	}

	// A broadcast writes to every socket, so it takes the same write lock as
	// the per-socket emits and acknowledgements.
	h.mu.Lock()
	h.io.Emit("server:status", statusData)
	h.mu.Unlock()
}
// GetConnectedClients returns the number of sockets the server currently
// reports.
func (h *SocketHandler) GetConnectedClients() int {
	return len(h.io.Sockets())
}
// sendErrorACK acknowledges an event with an error status and the given
// message. It does nothing when the client did not request an acknowledgement.
func (h *SocketHandler) sendErrorACK(data *socketio.EventPayload, message string) {
	// Delegate so that error acks take the same write lock as every other ack.
	h.sendACKWithLock(data, map[string]interface{}{
		"status":  "error",
		"message": message,
	})
}
// sendACKWithLock sends response through the event's acknowledgement callback
// while holding the handler's write lock. It does nothing when the event has
// no acknowledgement callback.
func (h *SocketHandler) sendACKWithLock(data *socketio.EventPayload, response interface{}) {
	if data.Ack == nil {
		return
	}

	h.mu.Lock()
	data.Ack(response)
	h.mu.Unlock()
}
// sendFinalResult emits the outcome of a file update to the client as a
// "file:update:ack" event carrying the update ID, status, message and path.
func (h *SocketHandler) sendFinalResult(socket *socketio.Socket, updateData FileUpdateData, status, message string) {
	result := map[string]any{
		"id":      updateData.ID,
		"status":  status,
		"message": message,
		"file":    updateData.FilePath,
	}

	// The mutex protects the socket write.
	h.mu.Lock()
	socket.Emit("file:update:ack", result)
	h.mu.Unlock()
}
// getFileExtension returns the extension of the final element of filePath,
// without the leading dot. It returns "" when the final element contains no
// dot or ends with one.
//
// Both '/' and '\' are treated as path separators, so a dot in a directory
// name (for example "dir.v2/Makefile") is not mistaken for an extension.
func (h *SocketHandler) getFileExtension(filePath string) string {
	for i := len(filePath) - 1; i >= 0; i-- {
		switch filePath[i] {
		case '.':
			return filePath[i+1:]
		case '/', '\\':
			// Reached the start of the final element without seeing a dot.
			return ""
		}
	}
	return ""
}
// getFileLanguage 根据文件扩展名获取编程语言类型
func (h *SocketHandler) getFileLanguage(fileExtension string) domain.CodeLanguageType {
switch fileExtension {
case "go":
return domain.CodeLanguageTypeGo
case "py":
return domain.CodeLanguageTypePython
case "java":
return domain.CodeLanguageTypeJava
case "js":
return domain.CodeLanguageTypeJavaScript
case "ts":
return domain.CodeLanguageTypeTypeScript
case "jsx":
return domain.CodeLanguageTypeJSX
case "tsx":
return domain.CodeLanguageTypeTSX
case "html":
return domain.CodeLanguageTypeHTML
case "css":
return domain.CodeLanguageTypeCSS
case "php":
return domain.CodeLanguageTypePHP
case "rs":
return domain.CodeLanguageTypeRust
case "swift":
return domain.CodeLanguageTypeSwift
case "kt":
return domain.CodeLanguageTypeKotlin
case "c":
return domain.CodeLanguageTypeC
case "cpp", "cc", "cxx":
return domain.CodeLanguageTypeCpp
default:
return ""
}
}