From 5a33628b92cb5dd609f73f3cf9845f2dcb0f527b Mon Sep 17 00:00:00 2001 From: Shaun Date: Fri, 19 Dec 2025 14:12:29 +0100 Subject: [PATCH] Feature: Add Go TCP server framework --- server/go/.vscode/launch.json | 30 +++ server/go/README.md | 333 ++++++++++++++++++++++++++++++++ server/go/auth/auth.go | 208 ++++++++++++++++++++ server/go/buffer/buffer.go | 185 ++++++++++++++++++ server/go/cmd/main.go | 268 +++++++++++++++++++++++++ server/go/connection/context.go | 197 +++++++++++++++++++ server/go/connection/errors.go | 23 +++ server/go/connection/manager.go | 115 +++++++++++ server/go/go.mod | 16 ++ server/go/go.sum | 21 ++ server/go/logger/logger.go | 204 +++++++++++++++++++ server/go/protocol/codec.go | 161 +++++++++++++++ server/go/protocol/commands.go | 202 +++++++++++++++++++ server/go/protocol/header.go | 267 +++++++++++++++++++++++++ server/go/protocol/parser.go | 261 +++++++++++++++++++++++++ server/go/server/server.go | 316 ++++++++++++++++++++++++++++++ 16 files changed, 2807 insertions(+) create mode 100644 server/go/.vscode/launch.json create mode 100644 server/go/README.md create mode 100644 server/go/auth/auth.go create mode 100644 server/go/buffer/buffer.go create mode 100644 server/go/cmd/main.go create mode 100644 server/go/connection/context.go create mode 100644 server/go/connection/errors.go create mode 100644 server/go/connection/manager.go create mode 100644 server/go/go.mod create mode 100644 server/go/go.sum create mode 100644 server/go/logger/logger.go create mode 100644 server/go/protocol/codec.go create mode 100644 server/go/protocol/commands.go create mode 100644 server/go/protocol/header.go create mode 100644 server/go/protocol/parser.go create mode 100644 server/go/server/server.go diff --git a/server/go/.vscode/launch.json b/server/go/.vscode/launch.json new file mode 100644 index 0000000..2b2dfed --- /dev/null +++ b/server/go/.vscode/launch.json @@ -0,0 +1,30 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Launch Server", + "type": "go", + "request": "launch", + "mode": "auto", + "program": "${workspaceFolder}/cmd", + "cwd": "${workspaceFolder}", + "args": [], + "env": {}, + "console": "integratedTerminal" + }, + { + "name": "Debug Server", + "type": "go", + "request": "launch", + "mode": "debug", + "program": "${workspaceFolder}/cmd", + "cwd": "${workspaceFolder}", + "args": [ + "-port=9090" + ], + "env": {}, + "console": "integratedTerminal", + "buildFlags": "-gcflags='all=-N -l'" + } + ] +} \ No newline at end of file diff --git a/server/go/README.md b/server/go/README.md new file mode 100644 index 0000000..3266762 --- /dev/null +++ b/server/go/README.md @@ -0,0 +1,333 @@ +# SimpleRemoter Go TCP Server Framework + +基于 Go 语言实现的高性能 TCP 服务端框架,用于替代原有的 C++ IOCP 服务端。 + +## 项目结构 + +``` +server/go/ +├── go.mod # Go 模块定义 +├── buffer/ +│ └── buffer.go # 线程安全的动态缓冲区 +├── connection/ +│ ├── context.go # 连接上下文 +│ ├── errors.go # 错误定义 +│ └── manager.go # 连接管理器 +├── protocol/ +│ ├── parser.go # 协议解析器 +│ ├── codec.go # 编解码和压缩 (ZSTD) +│ ├── header.go # 协议头解密 (8种加密方式) +│ └── commands.go # 命令常量和LOGIN_INFOR解析 +├── server/ +│ ├── server.go # TCP 服务器核心 +│ └── pool.go # Goroutine 工作池 +├── logger/ +│ └── logger.go # 日志模块 (基于 zerolog) +└── cmd/ + └── main.go # 程序入口 +``` + +## 核心特性 + +- **高并发**: 基于 Goroutine 池管理并发连接 +- **协议兼容**: 支持原有 C++ 客户端的多种协议标识 (Hell/Hello/Shine/Fuck) +- **协议头解密**: 支持8种协议头加密方式 (V0-V6 + Default) +- **XOR编码**: 支持 XOREncoder16 数据编码/解码 +- **ZSTD 压缩**: 使用高效的 ZSTD 算法进行数据压缩 +- **GBK编码**: 自动将 Windows 客户端的 GBK 编码转换为 UTF-8 +- **线程安全**: Buffer、连接管理器和 LastActive 均为线程安全设计 +- **优雅关闭**: 支持信号处理和优雅停机,自动释放资源 +- **可配置**: 支持自定义端口、最大连接数、超时时间等 +- **日志系统**: 基于 zerolog,支持文件输出、日志轮转、客户端上下线记录 + +## 支持的命令 + +当前已实现以下命令处理: + +| 命令 | 值 | 说明 | +|------|-----|------| +| TOKEN_AUTH | 100 | 授权请求 | +| TOKEN_HEARTBEAT | 101 | 心跳包 | +| TOKEN_LOGIN | 102 | 客户端登录 | + +其他命令会被记录为 Debug 日志,可按需扩展。 + +## 快速开始 + +### 安装依赖 + +```bash +cd server/go +go mod tidy +``` + +### 编译 + +```bash +go build -o simpleremoter-server ./cmd +``` + +### 运行 + +```bash +./simpleremoter-server +``` + +服务器默认监听 6543 端口,日志输出到 `logs/server.log`。 + +## 使用示例 + +```go +package main + +import ( + "os" + "os/signal" + "syscall" + + "github.com/yuanyuanxiang/SimpleRemoter/server/go/connection" + "github.com/yuanyuanxiang/SimpleRemoter/server/go/logger" + "github.com/yuanyuanxiang/SimpleRemoter/server/go/protocol" + "github.com/yuanyuanxiang/SimpleRemoter/server/go/server" +) + +// 实现 Handler 接口 +type MyHandler struct { + log *logger.Logger +} + +func (h *MyHandler) OnConnect(ctx *connection.Context) { + h.log.ClientEvent("online", ctx.ID, ctx.GetPeerIP()) +} + +func (h *MyHandler) OnDisconnect(ctx *connection.Context) { + h.log.ClientEvent("offline", ctx.ID, ctx.GetPeerIP()) +} + +func (h *MyHandler) OnReceive(ctx *connection.Context, data []byte) { + if len(data) == 0 { + return + } + cmd := data[0] + switch cmd { + case protocol.TokenLogin: + info, _ := protocol.ParseLoginInfo(data) + h.log.Info("Client login: %s (%s)", info.PCName, info.OsVerInfo) + case protocol.TokenHeartbeat: + h.log.Debug("Heartbeat from client %d", ctx.ID) + } +} + +func main() { + // 配置日志 (控制台 + 文件) + logCfg := logger.DefaultConfig() + logCfg.File = "logs/server.log" + log := logger.New(logCfg) + + // 配置服务器 + config := server.DefaultConfig() + config.Port = 6543 + + // 创建并启动服务器 + srv := server.New(config) + srv.SetLogger(log.WithPrefix("Server")) + srv.SetHandler(&MyHandler{log: log}) + + if err := srv.Start(); err != nil { + log.Fatal("启动失败: %v", err) + } + + // 等待退出信号 + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + <-sigChan + + srv.Stop() +} +``` + +## 配置选项 + +| 配置项 | 默认值 | 说明 | +|--------|--------|------| +| Port | 8080 | 监听端口 | +| MaxConnections | 10000 | 最大连接数 | +| MinWorkers | 4 | 最小工作协程数 | +| MaxWorkers | 100 | 最大工作协程数 | +| ReadBufferSize | 8192 | 读缓冲区大小 | +| WriteBufferSize | 8192 | 写缓冲区大小 | +| KeepAliveTime | 5min | 连接保活时间 | +| ReadTimeout | 2min | 读超时时间 | +| WriteTimeout | 30s | 写超时时间 | + +## 日志配置 + +| 配置项 | 默认值 | 说明 | +|--------|--------|------| +| Level | Info | 日志级别 (Debug/Info/Warn/Error/Fatal) | +| Console | true | 是否输出到控制台 | +| File | "" | 日志文件路径 (空则不写文件) | +| MaxSize | 100 | 单个日志文件最大 MB | +| MaxBackups | 3 | 保留的旧日志文件数量 | +| MaxAge | 30 | 旧日志保留天数 | +| Compress | true | 是否压缩轮转的日志 | + +日志示例输出: +```json +{"level":"info","module":"Server","time":"2025-12-19T13:17:32+01:00","message":"Server started on port 6543"} +{"level":"info","module":"Handler","event":"login","client_id":1,"ip":"192.168.0.92","computer":"DESKTOP-BI6RGEJ","os":"Windows 10","version":"Dec 19 2025","time":"2025-12-19T13:17:32+01:00"} +{"level":"debug","module":"Handler","time":"2025-12-19T13:17:47+01:00","message":"Heartbeat from client 1 (DESKTOP-BI6RGEJ)"} +``` + +## 协议格式 + +数据包格式与 C++ 版本兼容: + +``` ++----------+------------+------------+------------------+ +| Flag | TotalLen | OrigLen | Compressed Data | +| (N bytes)| (4 bytes) | (4 bytes) | (variable) | ++----------+------------+------------+------------------+ +``` + +### 协议标识 + +| 标识 | Flag长度 | 压缩方式 | 说明 | +|------|----------|----------|------| +| HELL | 8 bytes | ZSTD | 主要协议 | +| Hello? | 8 bytes | None | 无压缩协议 | +| Shine | 5 bytes | ZSTD | 备用协议 | +| <> | 11 bytes | ZSTD | 备用协议 | + +### 协议头加密 + +支持8种加密方式,服务端自动检测并解密: +- V0 (Default): 动态密钥,4种操作 +- V1: 交替加减 +- V2: 带旋转的异或 +- V3: 带位置的动态密钥 +- V4: 对称的伪随机异或 +- V5: 带位移的动态密钥 +- V6: 带位置的伪随机 +- V7: 纯异或 + +### LOGIN_INFOR 结构 + +客户端登录信息结构体 (考虑 C++ 内存对齐): + +| 字段 | 偏移 | 大小 | 说明 | +|------|------|------|------| +| bToken | 0 | 1 | 命令标识 (102) | +| OsVerInfoEx | 1 | 156 | 操作系统版本 | +| (padding) | 157 | 3 | 对齐填充 | +| dwCPUMHz | 160 | 4 | CPU 频率 | +| moduleVersion | 164 | 24 | 模块版本 | +| szPCName | 188 | 240 | 计算机名 | +| szMasterID | 428 | 20 | 主控 ID | +| bWebCamExist | 448 | 4 | 是否有摄像头 | +| dwSpeed | 452 | 4 | 网速 | +| szStartTime | 456 | 20 | 启动时间 | +| szReserved | 476 | 512 | 扩展字段 (用`|`分隔) | + +## API 参考 + +### Server + +```go +// 创建服务器 +srv := server.New(config) + +// 设置日志 +srv.SetLogger(log) + +// 设置事件处理器 +srv.SetHandler(handler) + +// 启动服务器 +srv.Start() + +// 停止服务器 +srv.Stop() + +// 发送数据到指定连接 +srv.Send(ctx, data) + +// 广播数据到所有连接 +srv.Broadcast(data) + +// 获取当前连接数 +count := srv.ConnectionCount() +``` + +### Connection Context + +```go +// 发送数据 +ctx.Send(data) + +// 关闭连接 +ctx.Close() + +// 获取客户端 IP +ip := ctx.GetPeerIP() + +// 检查连接状态 +closed := ctx.IsClosed() + +// 获取/更新最后活跃时间 (线程安全) +lastActive := ctx.LastActive() +ctx.UpdateLastActive() +duration := ctx.TimeSinceLastActive() + +// 设置/获取客户端信息 +ctx.SetInfo(clientInfo) +info := ctx.GetInfo() + +// 设置/获取用户数据 +ctx.SetUserData(myData) +data := ctx.GetUserData() +``` + +### Protocol + +```go +// 解析登录信息 +info, err := protocol.ParseLoginInfo(data) +if err == nil { + fmt.Println(info.PCName) // 计算机名 + fmt.Println(info.OsVerInfo) // 操作系统 + fmt.Println(info.ModuleVersion) // 版本 + fmt.Println(info.WebCamExist) // 是否有摄像头 +} + +// 获取扩展字段 +reserved := info.ParseReserved() // 返回 []string +clientType := info.GetReservedField(0) // 客户端类型 +cpuCores := info.GetReservedField(2) // CPU 核数 +filePath := info.GetReservedField(4) // 文件路径 +publicIP := info.GetReservedField(11) // 公网 IP +``` + +## 与 C++ 版本对比 + +| 特性 | C++ (IOCP) | Go | +|------|------------|-----| +| 并发模型 | IOCP + 线程池 | Goroutine 池 | +| 压缩算法 | ZSTD | ZSTD | +| 跨平台 | Windows | 全平台 | +| 内存管理 | 手动 | GC | +| 代码复杂度 | 高 | 低 | +| 协议头解密 | 8种方式 | 8种方式 | +| XOR编码 | XOREncoder16 | XOREncoder16 | +| 字符编码 | GBK | GBK -> UTF-8 | + +## 依赖 + +- [github.com/klauspost/compress/zstd](https://github.com/klauspost/compress) - ZSTD 压缩 +- [github.com/rs/zerolog](https://github.com/rs/zerolog) - 高性能日志 +- [gopkg.in/natefinch/lumberjack.v2](https://github.com/natefinch/lumberjack) - 日志轮转 +- [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) - GBK 编码转换 + +## License + +MIT License diff --git a/server/go/auth/auth.go b/server/go/auth/auth.go new file mode 100644 index 0000000..f798e79 --- /dev/null +++ b/server/go/auth/auth.go @@ -0,0 +1,208 @@ +package auth + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "os" + "strings" + + "golang.org/x/text/encoding/simplifiedchinese" + "golang.org/x/text/transform" +) + +// Config holds authentication configuration +type Config struct { + PwdHash string // SHA256 hash of the password (64 hex chars) + SuperPass string // Super admin password for HMAC verification +} + +// DefaultConfig returns default auth configuration +func DefaultConfig() *Config { + return &Config{ + PwdHash: "", // Must be configured + SuperPass: "", // Can be set via YAMA_PWD env var + } +} + +// Authenticator handles token authentication +type Authenticator struct { + config *Config +} + +// New creates a new Authenticator +func New(config *Config) *Authenticator { + return &Authenticator{config: config} +} + +// AuthResult contains the result of authentication +type AuthResult struct { + Valid bool + Message string + SN string + Passcode string +} + +// Authenticate validates a TOKEN_AUTH request +// Data format: +// - offset 0: TOKEN_AUTH command byte +// - offset 1-19: SN (serial number, 19 bytes) +// - offset 20-62: Passcode (42 bytes) +// - offset 62-70: HMAC signature (uint64, 8 bytes) if len > 64 +func (a *Authenticator) Authenticate(data []byte) *AuthResult { + result := &AuthResult{ + Valid: false, + Message: "未获授权或消息哈希校验失败", + } + + // Minimum length check: 1 (token) + 19 (sn) + 1 (at least some passcode) + if len(data) <= 20 { + return result + } + + // Extract SN (bytes 1-19) + sn := string(data[1:20]) + result.SN = sn + + // Extract passcode (bytes 20-62, or until end if shorter) + passcodeEnd := 62 + if len(data) < passcodeEnd { + passcodeEnd = len(data) + } + passcode := string(data[20:passcodeEnd]) + result.Passcode = passcode + + // Extract HMAC if present (bytes 62-70) + var hmacSig uint64 + if len(data) >= 70 { + hmacSig = binary.LittleEndian.Uint64(data[62:70]) + } else if len(data) > 62 { + // Partial HMAC data - safely handle incomplete bytes + hmacBytes := make([]byte, 8) + copy(hmacBytes, data[62:]) + hmacSig = binary.LittleEndian.Uint64(hmacBytes) + } + + // Split passcode by '-' + parts := strings.Split(passcode, "-") + if len(parts) != 6 && len(parts) != 7 { + return result + } + + // Get last 4 parts as subvector + subvector := parts[len(parts)-4:] + + // Build password string: v[0] + " - " + v[1] + ": " + PwdHash + (optional: ": " + v[2]) + password := parts[0] + " - " + parts[1] + ": " + a.config.PwdHash + if len(parts) == 7 { + password += ": " + parts[2] + } + + // Derive key from password and SN + finalKey := DeriveKey(password, sn) + + // Get fixed length ID + hash256 := strings.Join(subvector, "-") + fixedKey := GetFixedLengthID(finalKey) + + // Debug output (can be removed in production) + // fmt.Printf("DEBUG: password=%q sn=%q finalKey=%s fixedKey=%s hash256=%s\n", password, sn, finalKey, fixedKey, hash256) + + // Compare + if hash256 != fixedKey { + return result + } + + // Passcode validation successful, now verify HMAC + superPass := os.Getenv("YAMA_PWD") + if superPass == "" { + superPass = a.config.SuperPass + } + + if superPass != "" && hmacSig != 0 { + verified := VerifyMessage(superPass, []byte(passcode), hmacSig) + if verified { + result.Valid = true + result.Message = "此程序已获授权,请遵守授权协议,感谢合作" + } + // If HMAC verification fails, valid remains false + } else if hmacSig == 0 { + // No HMAC provided but passcode is valid - could be older client + // Keep as invalid for security + } + + return result +} + +// utf8ToGBK converts UTF-8 string to GBK encoded bytes +func utf8ToGBK(s string) []byte { + reader := transform.NewReader(bytes.NewReader([]byte(s)), simplifiedchinese.GBK.NewEncoder()) + buf := new(bytes.Buffer) + buf.ReadFrom(reader) + return buf.Bytes() +} + +// BuildResponse builds the 100-byte response for TOKEN_AUTH +func (a *Authenticator) BuildResponse(result *AuthResult) []byte { + resp := make([]byte, 100) + if result.Valid { + resp[0] = 1 + } + // Message starts at offset 4, convert UTF-8 to GBK for Windows client + gbkMsg := utf8ToGBK(result.Message) + copy(resp[4:], gbkMsg) + return resp +} + +// HashSHA256 computes SHA256 hash and returns hex string +func HashSHA256(data string) string { + h := sha256.New() + h.Write([]byte(data)) + return hex.EncodeToString(h.Sum(nil)) +} + +// DeriveKey derives a key from password and hardware ID +// Format: SHA256(password + " + " + hardwareID) +func DeriveKey(password, hardwareID string) string { + return HashSHA256(password + " + " + hardwareID) +} + +// GetFixedLengthID formats a hash into fixed length ID +// Format: xxxx-xxxx-xxxx-xxxx (first 16 chars split by -) +func GetFixedLengthID(hash string) string { + if len(hash) < 16 { + return hash + } + return hash[0:4] + "-" + hash[4:8] + "-" + hash[8:12] + "-" + hash[12:16] +} + +// SignMessage computes HMAC-SHA256 and returns first 8 bytes as uint64 +func SignMessage(pwd string, msg []byte) uint64 { + h := hmac.New(sha256.New, []byte(pwd)) + h.Write(msg) + hash := h.Sum(nil) + return binary.LittleEndian.Uint64(hash[:8]) +} + +// VerifyMessage verifies HMAC signature +func VerifyMessage(pwd string, msg []byte, signature uint64) bool { + computed := SignMessage(pwd, msg) + return computed == signature +} + +// GenHMAC generates HMAC for password verification +// This matches the C++ genHMAC function +func GenHMAC(pwdHash, superPass string) string { + key := HashSHA256(superPass) + list := []string{"g", "h", "o", "s", "t"} + for _, item := range list { + key = HashSHA256(key + " - " + item) + } + result := HashSHA256(pwdHash + " - " + key) + if len(result) >= 16 { + return result[:16] + } + return result +} diff --git a/server/go/buffer/buffer.go b/server/go/buffer/buffer.go new file mode 100644 index 0000000..5415ae5 --- /dev/null +++ b/server/go/buffer/buffer.go @@ -0,0 +1,185 @@ +package buffer + +import ( + "encoding/binary" + "sync" +) + +// Buffer is a thread-safe dynamic buffer for network I/O +type Buffer struct { + data []byte + mu sync.RWMutex + offset int // read offset for lazy compaction +} + +// New creates a new buffer with optional initial capacity +func New(capacity ...int) *Buffer { + cap := 4096 + if len(capacity) > 0 && capacity[0] > 0 { + cap = capacity[0] + } + return &Buffer{ + data: make([]byte, 0, cap), + } +} + +// Write appends data to the buffer +func (b *Buffer) Write(p []byte) (int, error) { + b.mu.Lock() + defer b.mu.Unlock() + b.data = append(b.data, p...) + return len(p), nil +} + +// WriteUint32 writes a uint32 in little-endian format +func (b *Buffer) WriteUint32(v uint32) { + buf := make([]byte, 4) + binary.LittleEndian.PutUint32(buf, v) + b.Write(buf) +} + +// Read reads and removes data from the buffer +func (b *Buffer) Read(n int) []byte { + b.mu.Lock() + defer b.mu.Unlock() + + available := len(b.data) - b.offset + if n > available { + n = available + } + if n <= 0 { + return nil + } + + result := make([]byte, n) + copy(result, b.data[b.offset:b.offset+n]) + b.offset += n + + // Compact when offset is large enough + if b.offset > len(b.data)/2 && b.offset > 1024 { + b.compact() + } + + return result +} + +// Peek returns data without removing it +func (b *Buffer) Peek(n int) []byte { + b.mu.RLock() + defer b.mu.RUnlock() + + available := len(b.data) - b.offset + if n > available { + n = available + } + if n <= 0 { + return nil + } + + result := make([]byte, n) + copy(result, b.data[b.offset:b.offset+n]) + return result +} + +// PeekAt returns data at a specific offset without removing it +func (b *Buffer) PeekAt(offset, n int) []byte { + b.mu.RLock() + defer b.mu.RUnlock() + + start := b.offset + offset + if start >= len(b.data) { + return nil + } + + end := start + n + if end > len(b.data) { + end = len(b.data) + } + + result := make([]byte, end-start) + copy(result, b.data[start:end]) + return result +} + +// Skip removes n bytes from the beginning +func (b *Buffer) Skip(n int) int { + b.mu.Lock() + defer b.mu.Unlock() + + available := len(b.data) - b.offset + if n > available { + n = available + } + b.offset += n + + // Compact when offset is large enough + if b.offset > len(b.data)/2 && b.offset > 1024 { + b.compact() + } + + return n +} + +// Len returns the length of unread data +func (b *Buffer) Len() int { + b.mu.RLock() + defer b.mu.RUnlock() + return len(b.data) - b.offset +} + +// Bytes returns all unread data without removing it +func (b *Buffer) Bytes() []byte { + b.mu.RLock() + defer b.mu.RUnlock() + + n := len(b.data) - b.offset + if n <= 0 { + return nil + } + + result := make([]byte, n) + copy(result, b.data[b.offset:]) + return result +} + +// Clear removes all data from the buffer +func (b *Buffer) Clear() { + b.mu.Lock() + defer b.mu.Unlock() + b.data = b.data[:0] + b.offset = 0 +} + +// compact moves remaining data to the beginning +func (b *Buffer) compact() { + if b.offset > 0 { + remaining := len(b.data) - b.offset + copy(b.data[:remaining], b.data[b.offset:]) + b.data = b.data[:remaining] + b.offset = 0 + } +} + +// GetByte returns a single byte at offset +func (b *Buffer) GetByte(offset int) byte { + b.mu.RLock() + defer b.mu.RUnlock() + + idx := b.offset + offset + if idx >= len(b.data) { + return 0 + } + return b.data[idx] +} + +// GetUint32 returns a uint32 at offset in little-endian format +func (b *Buffer) GetUint32(offset int) uint32 { + b.mu.RLock() + defer b.mu.RUnlock() + + idx := b.offset + offset + if idx+4 > len(b.data) { + return 0 + } + return binary.LittleEndian.Uint32(b.data[idx : idx+4]) +} diff --git a/server/go/cmd/main.go b/server/go/cmd/main.go new file mode 100644 index 0000000..1812325 --- /dev/null +++ b/server/go/cmd/main.go @@ -0,0 +1,268 @@ +package main + +import ( + "flag" + "fmt" + "os" + "os/signal" + "strconv" + "strings" + "syscall" + + "github.com/yuanyuanxiang/SimpleRemoter/server/go/auth" + "github.com/yuanyuanxiang/SimpleRemoter/server/go/connection" + "github.com/yuanyuanxiang/SimpleRemoter/server/go/logger" + "github.com/yuanyuanxiang/SimpleRemoter/server/go/protocol" + "github.com/yuanyuanxiang/SimpleRemoter/server/go/server" +) + +// MyHandler implements the server.Handler interface +type MyHandler struct { + log *logger.Logger + auth *auth.Authenticator + srv *server.Server +} + +// OnConnect is called when a client connects +func (h *MyHandler) OnConnect(ctx *connection.Context) { + // Only log connection established, detailed info logged on login +} + +// OnDisconnect is called when a client disconnects +func (h *MyHandler) OnDisconnect(ctx *connection.Context) { + info := ctx.GetInfo() + if info.ClientID != "" { + h.log.ClientEvent("offline", ctx.ID, ctx.GetPeerIP(), + "clientID", info.ClientID, + "computer", info.ComputerName, + ) + } +} + +// OnReceive is called when data is received from a client +func (h *MyHandler) OnReceive(ctx *connection.Context, data []byte) { + if len(data) == 0 { + return + } + + cmd := data[0] + // Handle commands + switch cmd { + case protocol.TokenLogin: + h.handleLogin(ctx, data) + case protocol.TokenAuth: + h.handleAuth(ctx, data) + case protocol.TokenHeartbeat: + h.handleHeartbeat(ctx, data) + default: + // Other commands are not implemented yet + h.log.Info("Unhandled command %d from client %d", cmd, ctx.ID) + } +} + +// handleLogin handles client login (TOKEN_LOGIN = 102) +func (h *MyHandler) handleLogin(ctx *connection.Context, data []byte) { + info, err := protocol.ParseLoginInfo(data) + if err != nil { + h.log.Error("Failed to parse login info from client %d: %v", ctx.ID, err) + return + } + + // Use MasterID from login request as ClientID for logging + clientID := info.MasterID + if clientID == "" { + clientID = fmt.Sprintf("conn-%d", ctx.ID) + } + + // Store client info + reserved := info.ParseReserved() + clientInfo := connection.ClientInfo{ + ClientID: clientID, + ComputerName: info.PCName, + OS: info.OsVerInfo, + Version: info.ModuleVersion, + HasCamera: info.WebCamExist, + InstallTime: info.StartTime, + } + + // Parse additional info from reserved field + if len(reserved) > 0 { + clientInfo.ClientType = info.GetReservedField(0) + } + if len(reserved) > 2 { + clientInfo.CPU = info.GetReservedField(2) + } + if len(reserved) > 4 { + clientInfo.FilePath = info.GetReservedField(4) + } + if len(reserved) > 11 { + clientInfo.IP = info.GetReservedField(11) // Public IP + } + + ctx.SetInfo(clientInfo) + ctx.IsLoggedIn.Store(true) + + h.log.ClientEvent("online", ctx.ID, ctx.GetPeerIP(), + "clientID", clientID, + "computer", info.PCName, + "os", info.OsVerInfo, + "version", info.ModuleVersion, + "path", clientInfo.FilePath, + ) +} + +// handleAuth handles authorization request (TOKEN_AUTH = 100) +func (h *MyHandler) handleAuth(ctx *connection.Context, data []byte) { + result := h.auth.Authenticate(data) + info := ctx.GetInfo() + + if result.Valid { + h.log.Info("Auth success: clientID=%s computer=%s ip=%s sn=%s passcode=%s", + info.ClientID, info.ComputerName, ctx.GetPeerIP(), result.SN, result.Passcode) + } else { + h.log.Warn("Auth failed: clientID=%s computer=%s ip=%s sn=%s passcode=%s", + info.ClientID, info.ComputerName, ctx.GetPeerIP(), result.SN, result.Passcode) + } + + // Build and send response + resp := h.auth.BuildResponse(result) + if err := h.srv.Send(ctx, resp); err != nil { + h.log.Error("Failed to send auth response to client %d: %v", ctx.ID, err) + } +} + +// handleHeartbeat handles heartbeat from client (TOKEN_HEARTBEAT = 101) +func (h *MyHandler) handleHeartbeat(ctx *connection.Context, data []byte) { + + // Parse Time from heartbeat request (offset 1, 8 bytes) + var hbTime uint64 + if len(data) >= 9 { + hbTime = uint64(data[1]) | uint64(data[2])<<8 | uint64(data[3])<<16 | uint64(data[4])<<24 | + uint64(data[5])<<32 | uint64(data[6])<<40 | uint64(data[7])<<48 | uint64(data[8])<<56 + } + + // Build HeartbeatACK response: CMD_HEARTBEAT_ACK(1) + HeartbeatACK(32) + resp := make([]byte, 33) + resp[0] = protocol.CommandHeartbeat // CMD_HEARTBEAT_ACK = 216 + // Time at offset 1 (8 bytes, little-endian) + resp[1] = byte(hbTime) + resp[2] = byte(hbTime >> 8) + resp[3] = byte(hbTime >> 16) + resp[4] = byte(hbTime >> 24) + resp[5] = byte(hbTime >> 32) + resp[6] = byte(hbTime >> 40) + resp[7] = byte(hbTime >> 48) + resp[8] = byte(hbTime >> 56) + // Reserved[24] at offset 9 is already zero + + if err := h.srv.Send(ctx, resp); err != nil { + h.log.Error("Failed to send heartbeat ACK to client %d: %v", ctx.ID, err) + } +} + +// parsePorts parses a semicolon-separated port string and returns port numbers +func parsePorts(portStr string) ([]int, error) { + var ports []int + parts := strings.Split(portStr, ";") + for _, p := range parts { + p = strings.TrimSpace(p) + if p == "" { + continue + } + port, err := strconv.Atoi(p) + if err != nil { + return nil, fmt.Errorf("invalid port %q: %v", p, err) + } + if port < 1 || port > 65535 { + return nil, fmt.Errorf("port %d out of range (1-65535)", port) + } + ports = append(ports, port) + } + if len(ports) == 0 { + return nil, fmt.Errorf("no valid ports specified") + } + return ports, nil +} + +func main() { + // Parse command line flags + portStr := flag.String("port", "6543", "Server listen ports (semicolon-separated, e.g. 6543;6544;6545)") + flag.StringVar(portStr, "p", "6543", "Server listen ports (shorthand)") + noConsole := flag.Bool("no-console", false, "Disable console output (for daemon mode)") + flag.Parse() + + // Parse ports + ports, err := parsePorts(*portStr) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + + // Create logger with file output + logCfg := logger.DefaultConfig() + logCfg.Level = logger.LevelDebug + logCfg.Console = !*noConsole + logCfg.File = "logs/server.log" + logCfg.MaxSize = 100 // 100 MB + logCfg.MaxBackups = 10 // keep 10 old files + logCfg.MaxAge = 30 // 30 days + logCfg.Compress = true + + log := logger.New(logCfg) + + // Create auth config + authCfg := auth.DefaultConfig() + // PwdHash can be set from environment or config file + authCfg.PwdHash = os.Getenv("YAMA_PWDHASH") + if authCfg.PwdHash == "" { + // Default placeholder - should be configured in production + authCfg.PwdHash = "61f04dd637a74ee34493fc1025de2c131022536da751c29e3ff4e9024d8eec43" + } + authCfg.SuperPass = os.Getenv("YAMA_PWD") + + // Create authenticator (shared by all servers) + authenticator := auth.New(authCfg) + + // Create servers for each port + var servers []*server.Server + for _, port := range ports { + config := server.DefaultConfig() + config.Port = port + config.MaxConnections = 9999 + + srv := server.New(config) + srv.SetLogger(log.WithPrefix(fmt.Sprintf("Server:%d", port))) + + // Create handler for this server + handler := &MyHandler{ + log: log.WithPrefix(fmt.Sprintf("Handler:%d", port)), + auth: authenticator, + srv: srv, + } + srv.SetHandler(handler) + + servers = append(servers, srv) + } + + // Start all servers + for _, srv := range servers { + if err := srv.Start(); err != nil { + log.Fatal("Failed to start server: %v", err) + } + } + + fmt.Printf("Server started on port(s): %v\n", ports) + fmt.Println("Logs are written to: logs/server.log") + fmt.Println("Press Ctrl+C to stop...") + + // Wait for interrupt signal + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + <-sigChan + + fmt.Println("\nShutting down...") + for _, srv := range servers { + srv.Stop() + } + fmt.Println("Server stopped") +} diff --git a/server/go/connection/context.go b/server/go/connection/context.go new file mode 100644 index 0000000..8ca3dce --- /dev/null +++ b/server/go/connection/context.go @@ -0,0 +1,197 @@ +package connection + +import ( + "fmt" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/yuanyuanxiang/SimpleRemoter/server/go/buffer" +) + +// ClientInfo stores client metadata +type ClientInfo struct { + ClientID string // Client ID from login (MasterID) + IP string + ComputerName string + OS string + CPU string + HasCamera bool + Version string + InstallTime string + LoginTime time.Time + ClientType string + FilePath string + GroupName string +} + +// Context represents a client connection context +type Context struct { + ID uint64 + Conn net.Conn + RemoteAddr string + + // Buffers + InBuffer *buffer.Buffer // Received compressed data + OutBuffer *buffer.Buffer // Decompressed data for processing + + // Client info + Info ClientInfo + IsLoggedIn atomic.Bool + + // Connection state + OnlineTime time.Time + lastActiveNs atomic.Int64 // Unix nanoseconds for thread-safe access + + // Protocol state + CompressMethod int + FlagType FlagType + HeaderLen int + FlagLen int + HeaderEncType int // Header encryption type (0-7) + HeaderParams []byte // Header parameters for decoding (flag bytes) + + // User data - for storing dialog/handler references + UserData interface{} + + // Internal + mu sync.RWMutex + closed atomic.Bool + sendLock sync.Mutex + server *Manager +} + +// FlagType represents the protocol flag type +type FlagType int + +const ( + FlagUnknown FlagType = iota + FlagShine + FlagFuck + FlagHello + FlagHell + FlagWinOS +) + +// Compression methods +const ( + CompressUnknown = -2 + CompressZlib = -1 + CompressZstd = 0 + CompressNone = 1 +) + +// NewContext creates a new connection context +func NewContext(conn net.Conn, mgr *Manager) *Context { + now := time.Now() + ctx := &Context{ + Conn: conn, + RemoteAddr: conn.RemoteAddr().String(), + InBuffer: buffer.New(8192), + OutBuffer: buffer.New(8192), + OnlineTime: now, + CompressMethod: CompressZstd, + FlagType: FlagUnknown, + server: mgr, + } + ctx.lastActiveNs.Store(now.UnixNano()) + return ctx +} + +// Send sends data to the client (thread-safe) +func (c *Context) Send(data []byte) error { + if c.closed.Load() { + return ErrConnectionClosed + } + + c.sendLock.Lock() + defer c.sendLock.Unlock() + + _, err := c.Conn.Write(data) + if err != nil { + return err + } + c.UpdateLastActive() + return nil +} + +// UpdateLastActive updates the last active time (thread-safe) +func (c *Context) UpdateLastActive() { + c.lastActiveNs.Store(time.Now().UnixNano()) +} + +// LastActive returns the last active time (thread-safe) +func (c *Context) LastActive() time.Time { + return time.Unix(0, c.lastActiveNs.Load()) +} + +// TimeSinceLastActive returns duration since last activity (thread-safe) +func (c *Context) TimeSinceLastActive() time.Duration { + return time.Since(c.LastActive()) +} + +// Close closes the connection +func (c *Context) Close() error { + if c.closed.Swap(true) { + return nil // Already closed + } + return c.Conn.Close() +} + +// IsClosed returns whether the connection is closed +func (c *Context) IsClosed() bool { + return c.closed.Load() +} + +// GetPeerIP returns the peer IP address +func (c *Context) GetPeerIP() string { + if host, _, err := net.SplitHostPort(c.RemoteAddr); err == nil { + return host + } + return c.RemoteAddr +} + +// AliveTime returns how long the connection has been alive +func (c *Context) AliveTime() time.Duration { + return time.Since(c.OnlineTime) +} + +// SetInfo sets the client info +func (c *Context) SetInfo(info ClientInfo) { + c.mu.Lock() + defer c.mu.Unlock() + c.Info = info +} + +// GetInfo returns the client info +func (c *Context) GetInfo() ClientInfo { + c.mu.RLock() + defer c.mu.RUnlock() + return c.Info +} + +// SetUserData stores user-defined data +func (c *Context) SetUserData(data interface{}) { + c.mu.Lock() + defer c.mu.Unlock() + c.UserData = data +} + +// GetUserData retrieves user-defined data +func (c *Context) GetUserData() interface{} { + c.mu.RLock() + defer c.mu.RUnlock() + return c.UserData +} + +// GetClientID returns the client ID for logging +// If ClientID is set (from login), returns it; otherwise returns connection ID as fallback +func (c *Context) GetClientID() string { + c.mu.RLock() + defer c.mu.RUnlock() + if c.Info.ClientID != "" { + return c.Info.ClientID + } + return fmt.Sprintf("conn-%d", c.ID) +} diff --git a/server/go/connection/errors.go b/server/go/connection/errors.go new file mode 100644 index 0000000..37de470 --- /dev/null +++ b/server/go/connection/errors.go @@ -0,0 +1,23 @@ +package connection + +import "errors" + +var ( + // ErrConnectionClosed indicates the connection is closed + ErrConnectionClosed = errors.New("connection closed") + + // ErrServerClosed indicates the server is shut down + ErrServerClosed = errors.New("server closed") + + // ErrMaxConnections indicates max connections reached + ErrMaxConnections = errors.New("max connections reached") + + // ErrInvalidPacket indicates an invalid packet + ErrInvalidPacket = errors.New("invalid packet") + + // ErrUnsupportedProtocol indicates unsupported protocol + ErrUnsupportedProtocol = errors.New("unsupported protocol") + + // ErrDecompressFailed indicates decompression failure + ErrDecompressFailed = errors.New("decompression failed") +) diff --git a/server/go/connection/manager.go b/server/go/connection/manager.go new file mode 100644 index 0000000..9ed9c01 --- /dev/null +++ b/server/go/connection/manager.go @@ -0,0 +1,115 @@ +package connection + +import ( + "sync" + "sync/atomic" +) + +// Manager manages all client connections +type Manager struct { + connections sync.Map // map[uint64]*Context + count atomic.Int64 + maxConns int + idCounter atomic.Uint64 + + // Callbacks + onConnect func(*Context) + onDisconnect func(*Context) + onReceive func(*Context, []byte) +} + +// NewManager creates a new connection manager +func NewManager(maxConns int) *Manager { + if maxConns <= 0 { + maxConns = 10000 + } + return &Manager{ + maxConns: maxConns, + } +} + +// SetCallbacks sets the callback functions +func (m *Manager) SetCallbacks(onConnect, onDisconnect func(*Context), onReceive func(*Context, []byte)) { + m.onConnect = onConnect + m.onDisconnect = onDisconnect + m.onReceive = onReceive +} + +// Add adds a new connection +func (m *Manager) Add(ctx *Context) error { + if int(m.count.Load()) >= m.maxConns { + return ErrMaxConnections + } + + ctx.ID = m.idCounter.Add(1) + m.connections.Store(ctx.ID, ctx) + m.count.Add(1) + + if m.onConnect != nil { + m.onConnect(ctx) + } + + return nil +} + +// Remove removes a connection +func (m *Manager) Remove(ctx *Context) { + if _, ok := m.connections.LoadAndDelete(ctx.ID); ok { + m.count.Add(-1) + if m.onDisconnect != nil { + m.onDisconnect(ctx) + } + } +} + +// Get retrieves a connection by ID +func (m *Manager) Get(id uint64) *Context { + if v, ok := m.connections.Load(id); ok { + return v.(*Context) + } + return nil +} + +// Count returns the current connection count +func (m *Manager) Count() int { + return int(m.count.Load()) +} + +// Range iterates over all connections +func (m *Manager) Range(fn func(*Context) bool) { + m.connections.Range(func(key, value interface{}) bool { + return fn(value.(*Context)) + }) +} + +// Broadcast sends data to all connections +func (m *Manager) Broadcast(data []byte) { + m.connections.Range(func(key, value interface{}) bool { + ctx := value.(*Context) + if !ctx.IsClosed() { + _ = ctx.Send(data) + } + return true + }) +} + +// CloseAll closes all connections +func (m *Manager) CloseAll() { + m.connections.Range(func(key, value interface{}) bool { + ctx := value.(*Context) + _ = ctx.Close() + return true + }) +} + +// OnReceive calls the receive callback +func (m *Manager) OnReceive(ctx *Context, data []byte) { + if m.onReceive != nil { + m.onReceive(ctx, data) + } +} + +// UpdateMaxConnections updates the maximum connections limit +func (m *Manager) UpdateMaxConnections(max int) { + m.maxConns = max +} diff --git a/server/go/go.mod b/server/go/go.mod new file mode 100644 index 0000000..a1d5f04 --- /dev/null +++ b/server/go/go.mod @@ -0,0 +1,16 @@ +module github.com/yuanyuanxiang/SimpleRemoter/server/go + +go 1.24.5 + +require ( + github.com/klauspost/compress v1.18.2 + github.com/rs/zerolog v1.34.0 + golang.org/x/text v0.32.0 + gopkg.in/natefinch/lumberjack.v2 v2.2.1 +) + +require ( + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + golang.org/x/sys v0.12.0 // indirect +) diff --git a/server/go/go.sum b/server/go/go.sum new file mode 100644 index 0000000..d2f3d60 --- /dev/null +++ b/server/go/go.sum @@ -0,0 +1,21 @@ +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= +github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= +golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= +gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= diff --git a/server/go/logger/logger.go b/server/go/logger/logger.go new file mode 100644 index 0000000..69e261c --- /dev/null +++ b/server/go/logger/logger.go @@ -0,0 +1,204 @@ +package logger + +import ( + "io" + "os" + "path/filepath" + "time" + + "github.com/rs/zerolog" + "gopkg.in/natefinch/lumberjack.v2" +) + +// Level represents log level +type Level = zerolog.Level + +const ( + LevelDebug = zerolog.DebugLevel + LevelInfo = zerolog.InfoLevel + LevelWarn = zerolog.WarnLevel + LevelError = zerolog.ErrorLevel + LevelFatal = zerolog.FatalLevel +) + +// Logger wraps zerolog.Logger +type Logger struct { + zl zerolog.Logger +} + +// Config holds logger configuration +type Config struct { + // Level is the minimum log level + Level Level + // Console enables console output + Console bool + // File is the log file path (empty to disable file logging) + File string + // MaxSize is the max size in MB before rotation + MaxSize int + // MaxBackups is the max number of old log files to keep + MaxBackups int + // MaxAge is the max days to keep old log files + MaxAge int + // Compress enables gzip compression for rotated files + Compress bool +} + +// DefaultConfig returns default configuration +func DefaultConfig() Config { + return Config{ + Level: LevelInfo, + Console: true, + File: "", + MaxSize: 100, + MaxBackups: 3, + MaxAge: 30, + Compress: true, + } +} + +// New creates a new logger with config +func New(cfg Config) *Logger { + var writers []io.Writer + + // Console output with color + if cfg.Console { + consoleWriter := zerolog.ConsoleWriter{ + Out: os.Stdout, + TimeFormat: "2006-01-02 15:04:05", + } + writers = append(writers, consoleWriter) + } + + // File output with rotation + if cfg.File != "" { + // Ensure directory exists + dir := filepath.Dir(cfg.File) + if dir != "" && dir != "." { + _ = os.MkdirAll(dir, 0755) + } + + fileWriter := &lumberjack.Logger{ + Filename: cfg.File, + MaxSize: cfg.MaxSize, + MaxBackups: cfg.MaxBackups, + MaxAge: cfg.MaxAge, + Compress: cfg.Compress, + LocalTime: true, + } + writers = append(writers, fileWriter) + } + + // Combine writers + var writer io.Writer + if len(writers) == 0 { + writer = os.Stdout + } else if len(writers) == 1 { + writer = writers[0] + } else { + writer = zerolog.MultiLevelWriter(writers...) + } + + // Create logger + zl := zerolog.New(writer). + Level(cfg.Level). + With(). + Timestamp(). + Logger() + + return &Logger{zl: zl} +} + +// WithPrefix returns a new logger with a prefix field +func (l *Logger) WithPrefix(prefix string) *Logger { + return &Logger{ + zl: l.zl.With().Str("module", prefix).Logger(), + } +} + +// Debug logs a debug message +func (l *Logger) Debug(format string, args ...interface{}) { + l.zl.Debug().Msgf(format, args...) +} + +// Info logs an info message +func (l *Logger) Info(format string, args ...interface{}) { + l.zl.Info().Msgf(format, args...) +} + +// Warn logs a warning message +func (l *Logger) Warn(format string, args ...interface{}) { + l.zl.Warn().Msgf(format, args...) +} + +// Error logs an error message +func (l *Logger) Error(format string, args ...interface{}) { + l.zl.Error().Msgf(format, args...) +} + +// Fatal logs a fatal message and exits +func (l *Logger) Fatal(format string, args ...interface{}) { + l.zl.Fatal().Msgf(format, args...) +} + +// SetLevel sets the log level +func (l *Logger) SetLevel(level Level) { + l.zl = l.zl.Level(level) +} + +// GetLevel returns the current log level +func (l *Logger) GetLevel() Level { + return l.zl.GetLevel() +} + +// ClientEvent logs client online/offline events +func (l *Logger) ClientEvent(event string, clientID uint64, ip string, extra ...string) { + e := l.zl.Info(). + Str("event", event). + Uint64("client_id", clientID). + Str("ip", ip). + Time("time", time.Now()) + + if len(extra) >= 2 { + for i := 0; i+1 < len(extra); i += 2 { + e = e.Str(extra[i], extra[i+1]) + } + } + + e.Msg("") +} + +// default global logger +var defaultLogger = New(DefaultConfig()) + +// SetDefault sets the default global logger +func SetDefault(l *Logger) { + defaultLogger = l +} + +// Default returns the default global logger +func Default() *Logger { + return defaultLogger +} + +// Package-level convenience functions + +func Debug(format string, args ...interface{}) { + defaultLogger.Debug(format, args...) +} + +func Info(format string, args ...interface{}) { + defaultLogger.Info(format, args...) +} + +func Warn(format string, args ...interface{}) { + defaultLogger.Warn(format, args...) +} + +func Error(format string, args ...interface{}) { + defaultLogger.Error(format, args...) +} + +func Fatal(format string, args ...interface{}) { + defaultLogger.Fatal(format, args...) +} diff --git a/server/go/protocol/codec.go b/server/go/protocol/codec.go new file mode 100644 index 0000000..a9b35a3 --- /dev/null +++ b/server/go/protocol/codec.go @@ -0,0 +1,161 @@ +package protocol + +import ( + "github.com/klauspost/compress/zstd" + "github.com/yuanyuanxiang/SimpleRemoter/server/go/connection" +) + +// Codec handles encoding/decoding and compression +type Codec struct { + encoder *zstd.Encoder + decoder *zstd.Decoder +} + +// NewCodec creates a new codec +func NewCodec() *Codec { + encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.SpeedDefault)) + if err != nil { + panic("failed to create zstd encoder: " + err.Error()) + } + decoder, err := zstd.NewReader(nil) + if err != nil { + panic("failed to create zstd decoder: " + err.Error()) + } + + return &Codec{ + encoder: encoder, + decoder: decoder, + } +} + +// Compress compresses data using the appropriate method +func (c *Codec) Compress(ctx *connection.Context, data []byte) ([]byte, error) { + switch ctx.CompressMethod { + case connection.CompressNone: + return data, nil + case connection.CompressZstd: + return c.encoder.EncodeAll(data, nil), nil + default: + // Default to zstd + return c.encoder.EncodeAll(data, nil), nil + } +} + +// Decompress decompresses data using the appropriate method +func (c *Codec) Decompress(ctx *connection.Context, data []byte, origLen uint32) ([]byte, error) { + switch ctx.CompressMethod { + case connection.CompressNone: + // No compression, return as-is + result := make([]byte, len(data)) + copy(result, data) + return result, nil + case connection.CompressZstd: + result := make([]byte, 0, origLen) + return c.decoder.DecodeAll(data, result) + default: + // Try zstd by default + result := make([]byte, 0, origLen) + return c.decoder.DecodeAll(data, result) + } +} + +// Encode encodes data after compression (before sending) - Encoder2 +func (c *Codec) Encode(ctx *connection.Context, data []byte) { + // This is Encoder2 - applied after compression + switch ctx.FlagType { + case connection.FlagHello, connection.FlagHell: + // XOREncoder16 - needs param from header + // For now, skip encoding on send since we need the header params + case connection.FlagFuck: + // No encoding after compression for FUCK + } +} + +// Decode decodes data before decompression (after receiving) - Encoder2 +func (c *Codec) Decode(ctx *connection.Context, data []byte) { + // This is Encoder2 - applied before decompression + // XOREncoder16 for HELL/HELLO protocols + if ctx.FlagType == connection.FlagHell || ctx.FlagType == connection.FlagHello { + // Get k1, k2 from stored header params + if len(ctx.HeaderParams) >= 8 { + k1 := ctx.HeaderParams[6] + k2 := ctx.HeaderParams[7] + if k1 != 0 || k2 != 0 { + xorDecoder16(data, k1, k2) + } + } + } +} + +// EncodeData encodes data before compression - Encoder +func (c *Codec) EncodeData(ctx *connection.Context, data []byte) { + // This is Encoder - applied before compression + // Default encoder does nothing +} + +// DecodeData decodes data after decompression - Encoder +func (c *Codec) DecodeData(ctx *connection.Context, data []byte) { + // This is Encoder - applied after decompression + // Default encoder does nothing +} + +// xorDecoder16 implements XOREncoder16.decrypt_internal +func xorDecoder16(data []byte, k1, k2 byte) { + if len(data) == 0 { + return + } + + key := (uint16(k1) << 8) | uint16(k2) + dataLen := len(data) + + // Reverse two rounds of pseudo-random swaps + for round := 1; round >= 0; round-- { + for i := dataLen - 1; i >= 0; i-- { + j := int(pseudoRandom(key, i+round*100)) % dataLen + data[i], data[j] = data[j], data[i] + } + } + + // XOR decode + for i := 0; i < dataLen; i++ { + data[i] ^= (k1 + byte(i*13)) ^ (k2 ^ byte(i<<1)) + } +} + +// xorEncoder16 implements XOREncoder16.encrypt_internal +func xorEncoder16(data []byte, k1, k2 byte) { + if len(data) == 0 { + return + } + + key := (uint16(k1) << 8) | uint16(k2) + dataLen := len(data) + + // XOR encode + for i := 0; i < dataLen; i++ { + data[i] ^= (k1 + byte(i*13)) ^ (k2 ^ byte(i<<1)) + } + + // Two rounds of pseudo-random swaps + for round := 0; round < 2; round++ { + for i := 0; i < dataLen; i++ { + j := int(pseudoRandom(key, i+round*100)) % dataLen + data[i], data[j] = data[j], data[i] + } + } +} + +// pseudoRandom matches the C++ pseudo_random function +func pseudoRandom(seed uint16, index int) uint16 { + return ((seed ^ uint16(index*251+97)) * 733) ^ (seed >> 3) +} + +// Close releases resources +func (c *Codec) Close() { + if c.encoder != nil { + c.encoder.Close() + } + if c.decoder != nil { + c.decoder.Close() + } +} diff --git a/server/go/protocol/commands.go b/server/go/protocol/commands.go new file mode 100644 index 0000000..5f16a83 --- /dev/null +++ b/server/go/protocol/commands.go @@ -0,0 +1,202 @@ +package protocol + +import ( + "bytes" + "encoding/binary" + "strings" + + "golang.org/x/text/encoding/simplifiedchinese" + "golang.org/x/text/transform" +) + +// gbkToUTF8 converts GBK encoded bytes to UTF-8 string +func gbkToUTF8(data []byte) string { + // Find the first null byte and truncate there + if idx := bytes.IndexByte(data, 0); idx >= 0 { + data = data[:idx] + } + if len(data) == 0 { + return "" + } + + // Try to decode as GBK + reader := transform.NewReader(bytes.NewReader(data), simplifiedchinese.GBK.NewDecoder()) + buf := new(bytes.Buffer) + _, err := buf.ReadFrom(reader) + if err != nil { + // If GBK decoding fails, try treating as UTF-8 or ASCII + return cleanString(string(data)) + } + return cleanString(buf.String()) +} + +// cleanString removes non-printable characters except common whitespace +func cleanString(s string) string { + var result strings.Builder + for _, r := range s { + if r >= 32 || r == '\t' || r == '\n' || r == '\r' { + result.WriteRune(r) + } + } + return strings.TrimSpace(result.String()) +} + +// Command tokens - matching the C++ definitions +const ( + // Server -> Client commands + CommandActived byte = 0 // COMMAND_ACTIVED + CommandBye byte = 204 // COMMAND_BYE - disconnect + CommandHeartbeat byte = 216 // CMD_HEARTBEAT_ACK + + // Client -> Server tokens + TokenAuth byte = 100 // TOKEN_AUTH - authorization required + TokenHeartbeat byte = 101 // TOKEN_HEARTBEAT + TokenLogin byte = 102 // TOKEN_LOGIN - login packet +) + +// LOGIN_INFOR structure size and offsets (matching C++ struct with default alignment) +// Note: C++ struct uses default alignment (4-byte for uint32/int) +const ( + LoginInfoSize = 980 // Total size of LOGIN_INFOR struct (with alignment padding) + + // Field offsets (with alignment padding) + OffsetToken = 0 // 1 byte (unsigned char) + OffsetOsVerInfoEx = 1 // 156 bytes (char[156]) + // 3 bytes padding here to align dwCPUMHz to 4-byte boundary + OffsetCPUMHz = 160 // 4 bytes (unsigned int) - aligned to 4 + OffsetModuleVersion = 164 // 24 bytes (char[24]) + OffsetPCName = 188 // 240 bytes (char[240]) + OffsetMasterID = 428 // 20 bytes (char[20]) + OffsetWebCamExist = 448 // 4 bytes (int) - aligned to 4 + OffsetSpeed = 452 // 4 bytes (unsigned int) + OffsetStartTime = 456 // 20 bytes (char[20]) + OffsetReserved = 476 // 512 bytes (char[512]) +) + +// LoginInfo represents client login information +type LoginInfo struct { + Token byte + OsVerInfo string // OS version info + CPUMHz uint32 + ModuleVersion string + PCName string // Computer name + MasterID string + WebCamExist bool + Speed uint32 + StartTime string + Reserved string // Contains additional info separated by | +} + +// ParseLoginInfo parses LOGIN_INFOR from data +func ParseLoginInfo(data []byte) (*LoginInfo, error) { + if len(data) < 100 { // Minimum size check + return nil, ErrInvalidData + } + + info := &LoginInfo{ + Token: data[0], + } + + // Parse OS version info (offset 1, 156 bytes) + // The C++ client fills this with a readable string like "Windows 10" via getSystemName() + if len(data) >= OffsetOsVerInfoEx+156 { + info.OsVerInfo = parseOsVersionInfo(data[OffsetOsVerInfoEx : OffsetOsVerInfoEx+156]) + } + + // Parse CPU MHz (offset 160, 4 bytes) + if len(data) >= OffsetCPUMHz+4 { + info.CPUMHz = binary.LittleEndian.Uint32(data[OffsetCPUMHz:]) + } + + // Parse module version (offset 164, 24 bytes) + // This contains date string like "Dec 19 2025" + if len(data) >= OffsetModuleVersion+24 { + info.ModuleVersion = gbkToUTF8(data[OffsetModuleVersion : OffsetModuleVersion+24]) + } + + // Parse PC name (offset 188, 240 bytes) + if len(data) >= OffsetPCName+240 { + info.PCName = gbkToUTF8(data[OffsetPCName : OffsetPCName+240]) + } + + // Parse Master ID (offset 428, 20 bytes) + if len(data) >= OffsetMasterID+20 { + info.MasterID = gbkToUTF8(data[OffsetMasterID : OffsetMasterID+20]) + } + + // Parse WebCam exist (offset 448, 4 bytes) + if len(data) >= OffsetWebCamExist+4 { + info.WebCamExist = binary.LittleEndian.Uint32(data[OffsetWebCamExist:]) != 0 + } + + // Parse Speed (offset 452, 4 bytes) + if len(data) >= OffsetSpeed+4 { + info.Speed = binary.LittleEndian.Uint32(data[OffsetSpeed:]) + } + + // Parse Start time (offset 456, 20 bytes) + if len(data) >= OffsetStartTime+20 { + info.StartTime = gbkToUTF8(data[OffsetStartTime : OffsetStartTime+20]) + } + + // Parse Reserved (offset 476, 512 bytes) - contains additional info + if len(data) >= OffsetReserved+512 { + info.Reserved = gbkToUTF8(data[OffsetReserved : OffsetReserved+512]) + } else if len(data) > OffsetReserved { + info.Reserved = gbkToUTF8(data[OffsetReserved:]) + } + + return info, nil +} + +// parseOsVersionInfo parses the OS version info field +// The C++ client fills this with a readable string like "Windows 10" via getSystemName() +func parseOsVersionInfo(data []byte) string { + return gbkToUTF8(data) +} + +// ParseReserved parses the reserved field into a slice of strings +func (info *LoginInfo) ParseReserved() []string { + if info.Reserved == "" { + return nil + } + return strings.Split(info.Reserved, "|") +} + +// GetReservedField returns a specific field from reserved data by index +// Fields: ClientType(0), SystemBits(1), CPU(2), Memory(3), FilePath(4), +// Reserved(5), InstallTime(6), InstallInfo(7), ProgramBits(8), ExpiredDate(9), +// ClientLoc(10), ClientPubIP(11), ExeVersion(12), Username(13), IsAdmin(14) +func (info *LoginInfo) GetReservedField(index int) string { + fields := info.ParseReserved() + if index >= 0 && index < len(fields) { + return fields[index] + } + return "" +} + +// Validation structure for TOKEN_AUTH +type Validation struct { + From string // Start date + To string // End date + Admin string // Admin address + Port int // Admin port + Checksum string // Reserved field +} + +// BuildValidation creates a validation response +func BuildValidation(days float64, admin string, port int) []byte { + // This would build the validation structure + // For now, return a simple structure + data := make([]byte, 160) // Size of Validation struct + data[0] = TokenAuth + + // Fill in fields... + // From: 20 bytes + // To: 20 bytes + // Admin: 100 bytes + // Port: 4 bytes + // Checksum: 16 bytes + + return data +} diff --git a/server/go/protocol/header.go b/server/go/protocol/header.go new file mode 100644 index 0000000..0702d67 --- /dev/null +++ b/server/go/protocol/header.go @@ -0,0 +1,267 @@ +package protocol + +// Header encoding/decoding functions +// Ported from common/header.h and common/encfuncs.h + +const ( + MsgHeader = "HELL" + FlagCompLen = 4 + FlagLength = 8 + HdrLength = FlagLength + 8 // FLAG_LENGTH + 2 * sizeof(uint32) + MinComLen = 12 +) + +// HeaderEncType represents the encryption method used for header +type HeaderEncType int + +const ( + HeaderEncUnknown HeaderEncType = -1 + HeaderEncNone HeaderEncType = 0 + HeaderEncV0 HeaderEncType = 1 + HeaderEncV1 HeaderEncType = 2 + HeaderEncV2 HeaderEncType = 3 + HeaderEncV3 HeaderEncType = 4 + HeaderEncV4 HeaderEncType = 5 + HeaderEncV5 HeaderEncType = 6 + HeaderEncV6 HeaderEncType = 7 +) + +// DecryptFunc is the function signature for header decryption +type DecryptFunc func(data []byte, key byte) + +// defaultDecrypt does nothing (no encryption) +func defaultDecrypt(data []byte, key byte) { + // No-op +} + +// decrypt is the default encryption method (V0) +func decrypt(data []byte, key byte) { + if key == 0 { + return + } + for i := 0; i < len(data); i++ { + k := key ^ byte(i*31) + value := int(data[i]) + switch i % 4 { + case 0: + value -= int(k) + case 1: + value = value ^ int(k) + case 2: + value += int(k) + case 3: + value = ^value ^ int(k) + } + data[i] = byte(value & 0xFF) + } +} + +// decryptV1 - alternating add/subtract +func decryptV1(data []byte, key byte) { + for i := 0; i < len(data); i++ { + if i%2 == 0 { + data[i] = data[i] - key + } else { + data[i] = data[i] + key + } + } +} + +// decryptV2 - XOR with rotation +func decryptV2(data []byte, key byte) { + for i := 0; i < len(data); i++ { + if i%2 == 0 { + data[i] = (data[i] >> 1) | (data[i] << 7) // rotate right + } else { + data[i] = (data[i] << 1) | (data[i] >> 7) // rotate left + } + data[i] ^= key + } +} + +// decryptV3 - dynamic key with position +func decryptV3(data []byte, key byte) { + for i := 0; i < len(data); i++ { + dynamicKey := key + byte(i%8) + switch i % 3 { + case 0: + data[i] = (data[i] ^ dynamicKey) - dynamicKey + case 1: + data[i] = (data[i] + dynamicKey) ^ dynamicKey + case 2: + data[i] = ^data[i] - dynamicKey + } + } +} + +// decryptV4 - pseudo-random XOR (symmetric) +func decryptV4(data []byte, key byte) { + rand := uint16(key) + for i := 0; i < len(data); i++ { + rand = (rand*13 + 17) % 256 + data[i] ^= byte(rand) + } +} + +// decryptV5 - dynamic key with bit shift +func decryptV5(data []byte, key byte) { + for i := 0; i < len(data); i++ { + dynamicKey := (key + byte(i)) ^ 0x55 + data[i] = ((data[i] - byte(i%7)) ^ (dynamicKey << 3)) - dynamicKey + } +} + +// decryptV6 - pseudo-random with position (symmetric) +func decryptV6(data []byte, key byte) { + rand := uint16(key) + for i := 0; i < len(data); i++ { + rand = (rand*31 + 17) % 256 + data[i] ^= byte(rand) + byte(i) + } +} + +// All decrypt methods +var decryptMethods = []DecryptFunc{ + defaultDecrypt, + decrypt, + decryptV1, + decryptV2, + decryptV3, + decryptV4, + decryptV5, + decryptV6, +} + +// compare decrypts and compares the flag with magic +func compare(flag []byte, magic string, length int, dec DecryptFunc, key byte) bool { + if len(flag) < MinComLen { + return false + } + + buf := make([]byte, MinComLen) + copy(buf, flag[:MinComLen]) + dec(buf[:length], key) + + return string(buf[:length]) == magic +} + +// CheckHead tries all decryption methods to identify the protocol +func CheckHead(flag []byte) (flagType FlagType, encType HeaderEncType, decrypted []byte) { + if len(flag) < MinComLen { + return FlagUnknown, HeaderEncUnknown, nil + } + + for i, method := range decryptMethods { + buf := make([]byte, MinComLen) + copy(buf, flag[:MinComLen]) + + ft := checkHeadWithMethod(buf, method) + if ft != FlagUnknown { + return ft, HeaderEncType(i), buf + } + } + + return FlagUnknown, HeaderEncUnknown, nil +} + +// checkHeadWithMethod checks the flag with a specific decrypt method +func checkHeadWithMethod(flag []byte, dec DecryptFunc) FlagType { + // Try HELL (FLAG_HELL) + if len(flag) >= FlagLength { + buf := make([]byte, MinComLen) + copy(buf, flag) + key := buf[6] + dec(buf[:FlagCompLen], key) + if string(buf[:4]) == MsgHeader { + copy(flag, buf) + return FlagHell + } + } + + // Try Shine (FLAG_SHINE) + if len(flag) >= 5 { + buf := make([]byte, MinComLen) + copy(buf, flag) + dec(buf[:5], 0) + if string(buf[:5]) == "Shine" { + copy(flag, buf) + return FlagShine + } + } + + // Try <> (FLAG_FUCK) + if len(flag) >= 10 { + buf := make([]byte, MinComLen) + copy(buf, flag) + key := buf[9] + dec(buf[:8], key) + if string(buf[:8]) == "<>" { + copy(flag, buf) + return FlagFuck + } + } + + // Try Hello? (FLAG_HELLO) + if len(flag) >= 7 { + buf := make([]byte, MinComLen) + copy(buf, flag) + key := buf[6] + dec(buf[:6], key) + if string(buf[:6]) == "Hello?" { + copy(flag, buf) + return FlagHello + } + } + + return FlagUnknown +} + +// FlagType represents the protocol type +type FlagType int + +const ( + FlagWinOS FlagType = -1 + FlagUnknown FlagType = 0 + FlagShine FlagType = 1 + FlagFuck FlagType = 2 + FlagHello FlagType = 3 + FlagHell FlagType = 4 +) + +func (f FlagType) String() string { + switch f { + case FlagWinOS: + return "WinOS" + case FlagShine: + return "Shine" + case FlagFuck: + return "Fuck" + case FlagHello: + return "Hello" + case FlagHell: + return "Hell" + default: + return "Unknown" + } +} + +// GetFlagLength returns the flag length for a given flag type +func GetFlagLength(ft FlagType) int { + switch ft { + case FlagShine: + return 5 + case FlagFuck: + return 11 // 8 + 3 + case FlagHello: + return 8 + case FlagHell: + return FlagLength + default: + return 0 + } +} + +// GetHeaderLength returns the full header length for a given flag type +func GetHeaderLength(ft FlagType) int { + return GetFlagLength(ft) + 8 +} diff --git a/server/go/protocol/parser.go b/server/go/protocol/parser.go new file mode 100644 index 0000000..6e38ef3 --- /dev/null +++ b/server/go/protocol/parser.go @@ -0,0 +1,261 @@ +package protocol + +import ( + "encoding/binary" + "errors" + + "github.com/yuanyuanxiang/SimpleRemoter/server/go/connection" +) + +// Errors +var ( + ErrNeedMore = errors.New("need more data") + ErrInvalidData = errors.New("invalid data") + ErrUnsupported = errors.New("unsupported protocol") + ErrDecompress = errors.New("decompression failed") +) + +// Parser handles protocol parsing +type Parser struct { + codec *Codec +} + +// NewParser creates a new parser +func NewParser() *Parser { + return &Parser{ + codec: NewCodec(), + } +} + +// Close releases resources held by the parser +func (p *Parser) Close() { + if p.codec != nil { + p.codec.Close() + } +} + +// Parse attempts to parse a complete packet from the buffer +func (p *Parser) Parse(ctx *connection.Context) ([]byte, error) { + buf := ctx.InBuffer + + // Need at least minimum bytes to determine protocol + if buf.Len() < MinComLen { + return nil, ErrNeedMore + } + + // Check if header is already parsed + if ctx.FlagType == connection.FlagUnknown { + // Try to parse header + if err := p.parseHeader(ctx); err != nil { + return nil, err + } + } + + // Now parse the packet + return p.parsePacket(ctx) +} + +// parseHeader parses the protocol header with obfuscation handling +func (p *Parser) parseHeader(ctx *connection.Context) error { + buf := ctx.InBuffer + header := buf.Peek(MinComLen) + if header == nil || len(header) < MinComLen { + return ErrNeedMore + } + + // Try to decode the header using all encryption methods + flagType, encType, decrypted := CheckHead(header) + + if flagType == FlagUnknown { + return ErrUnsupported + } + + // Store decrypted header params for later use (for XOREncoder16) + ctx.HeaderParams = make([]byte, MinComLen) + if decrypted != nil { + copy(ctx.HeaderParams, decrypted) + } else { + copy(ctx.HeaderParams, header) + } + + // Map protocol FlagType to connection FlagType and set compression method + switch flagType { + case FlagHell: + ctx.FlagType = connection.FlagHell + ctx.FlagLen = FlagLength + ctx.HeaderLen = ctx.FlagLen + 8 + ctx.CompressMethod = connection.CompressZstd // HELL uses ZSTD + case FlagHello: + ctx.FlagType = connection.FlagHello + ctx.FlagLen = 8 + ctx.HeaderLen = ctx.FlagLen + 8 + ctx.CompressMethod = connection.CompressNone // HELLO uses no compression + case FlagShine: + ctx.FlagType = connection.FlagShine + ctx.FlagLen = 5 + ctx.HeaderLen = ctx.FlagLen + 8 + ctx.CompressMethod = connection.CompressZstd // SHINE uses ZSTD + case FlagFuck: + ctx.FlagType = connection.FlagFuck + ctx.FlagLen = 11 + ctx.HeaderLen = ctx.FlagLen + 8 + ctx.CompressMethod = connection.CompressZstd // FUCK uses ZSTD + default: + return ErrUnsupported + } + + // Store encryption type for later use + ctx.HeaderEncType = int(encType) + + return nil +} + +// parsePacket parses a complete packet +func (p *Parser) parsePacket(ctx *connection.Context) ([]byte, error) { + buf := ctx.InBuffer + + // Check if we have enough data for header + if buf.Len() < ctx.HeaderLen { + return nil, ErrNeedMore + } + + // Peek the header to get total length + headerData := buf.Peek(ctx.HeaderLen) + if headerData == nil { + return nil, ErrNeedMore + } + + // Decrypt the header first + decryptedHeader := make([]byte, len(headerData)) + copy(decryptedHeader, headerData) + + // Get the encryption key (usually at position 6) + var key byte + if len(headerData) > 6 { + key = headerData[6] // Use original key before decryption + } + + // Decrypt flag portion + if ctx.HeaderEncType >= 0 && ctx.HeaderEncType < len(decryptMethods) { + decryptMethods[ctx.HeaderEncType](decryptedHeader[:FlagCompLen], key) + } + + // Read the total length field (after flag) + totalLen := binary.LittleEndian.Uint32(decryptedHeader[ctx.FlagLen:]) + + // Validate length + if totalLen < uint32(ctx.HeaderLen) || totalLen > 10*1024*1024 { + return nil, ErrInvalidData + } + + // Check if we have the complete packet + if buf.Len() < int(totalLen) { + return nil, ErrNeedMore + } + + // Read the complete packet + packet := buf.Read(int(totalLen)) + if packet == nil { + return nil, ErrInvalidData + } + + // Decrypt header portion of packet + if ctx.HeaderEncType >= 0 && ctx.HeaderEncType < len(decryptMethods) { + decryptMethods[ctx.HeaderEncType](packet[:FlagCompLen], key) + } + + // Update HeaderParams with this packet's header (for XOREncoder16 k1, k2) + if len(packet) >= ctx.FlagLen { + ctx.HeaderParams = make([]byte, ctx.HeaderLen) + copy(ctx.HeaderParams, packet[:ctx.HeaderLen]) + } + + // Extract data after header + dataStart := ctx.HeaderLen + data := packet[dataStart:] + + // Get original length (before compression) + var origLen uint32 + if ctx.FlagType != connection.FlagWinOS { + origLen = binary.LittleEndian.Uint32(packet[ctx.FlagLen+4 : ctx.FlagLen+8]) + } + + // Decode (XOR, etc.) before decompression - this is Encoder2 + p.codec.Decode(ctx, data) + + // Decompress + decompressed, err := p.codec.Decompress(ctx, data, origLen) + if err != nil { + return nil, err + } + + // Decode after decompression - this is Encoder + p.codec.DecodeData(ctx, decompressed) + + return decompressed, nil +} + +// Encode encodes data for sending +func (p *Parser) Encode(ctx *connection.Context, data []byte) ([]byte, error) { + // Encode before compression + encoded := make([]byte, len(data)) + copy(encoded, data) + p.codec.EncodeData(ctx, encoded) + + // Compress + compressed, err := p.codec.Compress(ctx, encoded) + if err != nil { + return nil, err + } + + // Build packet + packet := p.buildPacket(ctx, compressed, uint32(len(data))) + + return packet, nil +} + +// buildPacket builds a complete packet with header +func (p *Parser) buildPacket(ctx *connection.Context, data []byte, origLen uint32) []byte { + totalLen := ctx.HeaderLen + len(data) + packet := make([]byte, totalLen) + + // Write flag + flag := p.getFlag(ctx) + copy(packet[:ctx.FlagLen], flag) + + // Write total length + binary.LittleEndian.PutUint32(packet[ctx.FlagLen:], uint32(totalLen)) + + // Write original length + binary.LittleEndian.PutUint32(packet[ctx.FlagLen+4:], origLen) + + // Write data + copy(packet[ctx.HeaderLen:], data) + + // Encode after building + p.codec.Encode(ctx, packet[ctx.HeaderLen:]) + + return packet +} + +// getFlag returns the protocol flag bytes +func (p *Parser) getFlag(ctx *connection.Context) []byte { + switch ctx.FlagType { + case connection.FlagHell: + flag := make([]byte, FlagLength) + copy(flag, []byte(MsgHeader)) + return flag + case connection.FlagHello: + flag := make([]byte, 8) + copy(flag, []byte("Hello?")) + return flag + case connection.FlagShine: + return []byte("Shine") + case connection.FlagFuck: + flag := make([]byte, 11) + copy(flag, []byte("<>")) + return flag + default: + return make([]byte, ctx.FlagLen) + } +} diff --git a/server/go/server/server.go b/server/go/server/server.go new file mode 100644 index 0000000..f8e6c13 --- /dev/null +++ b/server/go/server/server.go @@ -0,0 +1,316 @@ +package server + +import ( + "context" + "io" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/yuanyuanxiang/SimpleRemoter/server/go/connection" + "github.com/yuanyuanxiang/SimpleRemoter/server/go/logger" + "github.com/yuanyuanxiang/SimpleRemoter/server/go/protocol" +) + +// Config holds server configuration +type Config struct { + Port int + MaxConnections int + ReadBufferSize int + WriteBufferSize int + KeepAliveTime time.Duration + ReadTimeout time.Duration + WriteTimeout time.Duration +} + +// DefaultConfig returns default configuration +func DefaultConfig() Config { + return Config{ + Port: 6543, + MaxConnections: 9999, + ReadBufferSize: 8192, + WriteBufferSize: 8192, + KeepAliveTime: time.Minute * 5, + ReadTimeout: time.Minute * 2, + WriteTimeout: time.Second * 30, + } +} + +// Handler defines the interface for handling client events +type Handler interface { + OnConnect(ctx *connection.Context) + OnDisconnect(ctx *connection.Context) + OnReceive(ctx *connection.Context, data []byte) +} + +// Server is the TCP server +type Server struct { + config Config + listener net.Listener + manager *connection.Manager + handler Handler + parser *protocol.Parser + + running atomic.Bool + wg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc + + // Logger + logger *logger.Logger +} + +// New creates a new server +func New(config Config) *Server { + ctx, cancel := context.WithCancel(context.Background()) + + s := &Server{ + config: config, + manager: connection.NewManager(config.MaxConnections), + parser: protocol.NewParser(), + ctx: ctx, + cancel: cancel, + logger: logger.New(logger.DefaultConfig()).WithPrefix("Server"), + } + + return s +} + +// SetHandler sets the event handler +func (s *Server) SetHandler(h Handler) { + s.handler = h + s.manager.SetCallbacks( + func(ctx *connection.Context) { + if s.handler != nil { + s.handler.OnConnect(ctx) + } + }, + func(ctx *connection.Context) { + if s.handler != nil { + s.handler.OnDisconnect(ctx) + } + }, + func(ctx *connection.Context, data []byte) { + if s.handler != nil { + s.handler.OnReceive(ctx, data) + } + }, + ) +} + +// SetLogger sets the logger +func (s *Server) SetLogger(l *logger.Logger) { + s.logger = l +} + +// Start starts the server +func (s *Server) Start() error { + addr := net.JoinHostPort("0.0.0.0", itoa(s.config.Port)) + listener, err := net.Listen("tcp", addr) + if err != nil { + return err + } + + s.listener = listener + s.running.Store(true) + + s.logger.Info("Server started on port %d", s.config.Port) + + s.wg.Add(1) + go s.acceptLoop() + + return nil +} + +func itoa(i int) string { + if i == 0 { + return "0" + } + var b [20]byte + pos := len(b) + neg := i < 0 + if neg { + // Handle math.MinInt safely by using unsigned + u := uint(-i) + for u > 0 { + pos-- + b[pos] = byte('0' + u%10) + u /= 10 + } + pos-- + b[pos] = '-' + } else { + for i > 0 { + pos-- + b[pos] = byte('0' + i%10) + i /= 10 + } + } + return string(b[pos:]) +} + +// Stop stops the server gracefully +func (s *Server) Stop() { + if !s.running.Swap(false) { + return + } + + s.cancel() + + if s.listener != nil { + _ = s.listener.Close() + } + + // Close all connections + s.manager.CloseAll() + + // Close parser resources + if s.parser != nil { + s.parser.Close() + } + + s.wg.Wait() + s.logger.Info("Server stopped") +} + +// acceptLoop accepts incoming connections +func (s *Server) acceptLoop() { + defer s.wg.Done() + + for s.running.Load() { + conn, err := s.listener.Accept() + if err != nil { + if s.running.Load() { + s.logger.Error("Accept error: %v", err) + } + continue + } + + // Check connection limit before spawning goroutine + if s.manager.Count() >= s.config.MaxConnections { + s.logger.Warn("Max connections reached, rejecting new connection from %s", conn.RemoteAddr()) + _ = conn.Close() + continue + } + + // Handle each connection in its own goroutine + go s.handleConnection(conn) + } +} + +// handleConnection handles a single connection +func (s *Server) handleConnection(conn net.Conn) { + // Create context + ctx := connection.NewContext(conn, nil) + + // Add to manager + if err := s.manager.Add(ctx); err != nil { + s.logger.Warn("Failed to add connection: %v", err) + _ = conn.Close() + return + } + + defer func() { + s.manager.Remove(ctx) + _ = ctx.Close() + }() + + // Read loop + buf := make([]byte, s.config.ReadBufferSize) + for !ctx.IsClosed() && s.running.Load() { + // Set read deadline + if s.config.ReadTimeout > 0 { + _ = conn.SetReadDeadline(time.Now().Add(s.config.ReadTimeout)) + } + + n, err := conn.Read(buf) + if err != nil { + if err != io.EOF && s.running.Load() { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + // Timeout - check keepalive + if s.config.KeepAliveTime > 0 && ctx.TimeSinceLastActive() > s.config.KeepAliveTime { + s.logger.Info("Connection %d timed out", ctx.ID) + break + } + continue + } + } + break + } + + if n > 0 { + ctx.UpdateLastActive() + + // Write to input buffer + _, _ = ctx.InBuffer.Write(buf[:n]) + + // Process received data + s.processData(ctx) + } + } +} + +// processData processes received data and calls handler +func (s *Server) processData(ctx *connection.Context) { + for ctx.InBuffer.Len() > 0 { + // Try to parse a complete packet + data, err := s.parser.Parse(ctx) + if err != nil { + if err == protocol.ErrNeedMore { + return + } + s.logger.Error("Parse error for connection %d: %v", ctx.ID, err) + _ = ctx.Close() + return + } + + if data != nil { + // Call handler + s.manager.OnReceive(ctx, data) + } + } +} + +// Send sends data to a specific connection +func (s *Server) Send(ctx *connection.Context, data []byte) error { + if ctx == nil || ctx.IsClosed() { + return connection.ErrConnectionClosed + } + + // Encode and compress data + encoded, err := s.parser.Encode(ctx, data) + if err != nil { + return err + } + + return ctx.Send(encoded) +} + +// Broadcast sends data to all connections +func (s *Server) Broadcast(data []byte) { + s.manager.Range(func(ctx *connection.Context) bool { + _ = s.Send(ctx, data) + return true + }) +} + +// GetConnection returns a connection by ID +func (s *Server) GetConnection(id uint64) *connection.Context { + return s.manager.Get(id) +} + +// ConnectionCount returns the current connection count +func (s *Server) ConnectionCount() int { + return s.manager.Count() +} + +// Port returns the server port +func (s *Server) Port() int { + return s.config.Port +} + +// UpdateMaxConnections updates the max connections limit +func (s *Server) UpdateMaxConnections(max int) { + s.manager.UpdateMaxConnections(max) +}