Feature: Add Go TCP server framework

This commit is contained in:
Shaun
2025-12-19 14:12:29 +01:00
committed by yuanyuanxiang
parent 7d2cf647ec
commit 5a33628b92
16 changed files with 2807 additions and 0 deletions

30
server/go/.vscode/launch.json vendored Normal file
View File

@@ -0,0 +1,30 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "Launch Server",
"type": "go",
"request": "launch",
"mode": "auto",
"program": "${workspaceFolder}/cmd",
"cwd": "${workspaceFolder}",
"args": [],
"env": {},
"console": "integratedTerminal"
},
{
"name": "Debug Server",
"type": "go",
"request": "launch",
"mode": "debug",
"program": "${workspaceFolder}/cmd",
"cwd": "${workspaceFolder}",
"args": [
"-port=9090"
],
"env": {},
"console": "integratedTerminal",
"buildFlags": "-gcflags='all=-N -l'"
}
]
}

333
server/go/README.md Normal file
View File

@@ -0,0 +1,333 @@
# SimpleRemoter Go TCP Server Framework
基于 Go 语言实现的高性能 TCP 服务端框架,用于替代原有的 C++ IOCP 服务端。
## 项目结构
```
server/go/
├── go.mod # Go 模块定义
├── buffer/
│ └── buffer.go # 线程安全的动态缓冲区
├── connection/
│ ├── context.go # 连接上下文
│ ├── errors.go # 错误定义
│ └── manager.go # 连接管理器
├── protocol/
│ ├── parser.go # 协议解析器
│ ├── codec.go # 编解码和压缩 (ZSTD)
│ ├── header.go # 协议头解密 (8种加密方式)
│ └── commands.go # 命令常量和LOGIN_INFOR解析
├── server/
│ ├── server.go # TCP 服务器核心
│ └── pool.go # Goroutine 工作池
├── logger/
│ └── logger.go # 日志模块 (基于 zerolog)
└── cmd/
└── main.go # 程序入口
```
## 核心特性
- **高并发**: 基于 Goroutine 池管理并发连接
- **协议兼容**: 支持原有 C++ 客户端的多种协议标识 (Hell/Hello/Shine/Fuck)
- **协议头解密**: 支持8种协议头加密方式 (V0-V6 + Default)
- **XOR编码**: 支持 XOREncoder16 数据编码/解码
- **ZSTD 压缩**: 使用高效的 ZSTD 算法进行数据压缩
- **GBK编码**: 自动将 Windows 客户端的 GBK 编码转换为 UTF-8
- **线程安全**: Buffer、连接管理器和 LastActive 均为线程安全设计
- **优雅关闭**: 支持信号处理和优雅停机,自动释放资源
- **可配置**: 支持自定义端口、最大连接数、超时时间等
- **日志系统**: 基于 zerolog支持文件输出、日志轮转、客户端上下线记录
## 支持的命令
当前已实现以下命令处理:
| 命令 | 值 | 说明 |
|------|-----|------|
| TOKEN_AUTH | 100 | 授权请求 |
| TOKEN_HEARTBEAT | 101 | 心跳包 |
| TOKEN_LOGIN | 102 | 客户端登录 |
其他命令会被记录为 Debug 日志,可按需扩展。
## 快速开始
### 安装依赖
```bash
cd server/go
go mod tidy
```
### 编译
```bash
go build -o simpleremoter-server ./cmd
```
### 运行
```bash
./simpleremoter-server
```
服务器默认监听 6543 端口,日志输出到 `logs/server.log`
## 使用示例
```go
package main
import (
"os"
"os/signal"
"syscall"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/connection"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/logger"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/protocol"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/server"
)
// 实现 Handler 接口
type MyHandler struct {
log *logger.Logger
}
func (h *MyHandler) OnConnect(ctx *connection.Context) {
h.log.ClientEvent("online", ctx.ID, ctx.GetPeerIP())
}
func (h *MyHandler) OnDisconnect(ctx *connection.Context) {
h.log.ClientEvent("offline", ctx.ID, ctx.GetPeerIP())
}
func (h *MyHandler) OnReceive(ctx *connection.Context, data []byte) {
if len(data) == 0 {
return
}
cmd := data[0]
switch cmd {
case protocol.TokenLogin:
info, _ := protocol.ParseLoginInfo(data)
h.log.Info("Client login: %s (%s)", info.PCName, info.OsVerInfo)
case protocol.TokenHeartbeat:
h.log.Debug("Heartbeat from client %d", ctx.ID)
}
}
func main() {
// 配置日志 (控制台 + 文件)
logCfg := logger.DefaultConfig()
logCfg.File = "logs/server.log"
log := logger.New(logCfg)
// 配置服务器
config := server.DefaultConfig()
config.Port = 6543
// 创建并启动服务器
srv := server.New(config)
srv.SetLogger(log.WithPrefix("Server"))
srv.SetHandler(&MyHandler{log: log})
if err := srv.Start(); err != nil {
log.Fatal("启动失败: %v", err)
}
// 等待退出信号
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
<-sigChan
srv.Stop()
}
```
## 配置选项
| 配置项 | 默认值 | 说明 |
|--------|--------|------|
| Port | 8080 | 监听端口 |
| MaxConnections | 10000 | 最大连接数 |
| MinWorkers | 4 | 最小工作协程数 |
| MaxWorkers | 100 | 最大工作协程数 |
| ReadBufferSize | 8192 | 读缓冲区大小 |
| WriteBufferSize | 8192 | 写缓冲区大小 |
| KeepAliveTime | 5min | 连接保活时间 |
| ReadTimeout | 2min | 读超时时间 |
| WriteTimeout | 30s | 写超时时间 |
## 日志配置
| 配置项 | 默认值 | 说明 |
|--------|--------|------|
| Level | Info | 日志级别 (Debug/Info/Warn/Error/Fatal) |
| Console | true | 是否输出到控制台 |
| File | "" | 日志文件路径 (空则不写文件) |
| MaxSize | 100 | 单个日志文件最大 MB |
| MaxBackups | 3 | 保留的旧日志文件数量 |
| MaxAge | 30 | 旧日志保留天数 |
| Compress | true | 是否压缩轮转的日志 |
日志示例输出:
```json
{"level":"info","module":"Server","time":"2025-12-19T13:17:32+01:00","message":"Server started on port 6543"}
{"level":"info","module":"Handler","event":"login","client_id":1,"ip":"192.168.0.92","computer":"DESKTOP-BI6RGEJ","os":"Windows 10","version":"Dec 19 2025","time":"2025-12-19T13:17:32+01:00"}
{"level":"debug","module":"Handler","time":"2025-12-19T13:17:47+01:00","message":"Heartbeat from client 1 (DESKTOP-BI6RGEJ)"}
```
## 协议格式
数据包格式与 C++ 版本兼容:
```
+----------+------------+------------+------------------+
| Flag | TotalLen | OrigLen | Compressed Data |
| (N bytes)| (4 bytes) | (4 bytes) | (variable) |
+----------+------------+------------+------------------+
```
### 协议标识
| 标识 | Flag长度 | 压缩方式 | 说明 |
|------|----------|----------|------|
| HELL | 8 bytes | ZSTD | 主要协议 |
| Hello? | 8 bytes | None | 无压缩协议 |
| Shine | 5 bytes | ZSTD | 备用协议 |
| <<FUCK>> | 11 bytes | ZSTD | 备用协议 |
### 协议头加密
支持8种加密方式服务端自动检测并解密
- V0 (Default): 动态密钥4种操作
- V1: 交替加减
- V2: 带旋转的异或
- V3: 带位置的动态密钥
- V4: 对称的伪随机异或
- V5: 带位移的动态密钥
- V6: 带位置的伪随机
- V7: 纯异或
### LOGIN_INFOR 结构
客户端登录信息结构体 (考虑 C++ 内存对齐)
| 字段 | 偏移 | 大小 | 说明 |
|------|------|------|------|
| bToken | 0 | 1 | 命令标识 (102) |
| OsVerInfoEx | 1 | 156 | 操作系统版本 |
| (padding) | 157 | 3 | 对齐填充 |
| dwCPUMHz | 160 | 4 | CPU 频率 |
| moduleVersion | 164 | 24 | 模块版本 |
| szPCName | 188 | 240 | 计算机名 |
| szMasterID | 428 | 20 | 主控 ID |
| bWebCamExist | 448 | 4 | 是否有摄像头 |
| dwSpeed | 452 | 4 | 网速 |
| szStartTime | 456 | 20 | 启动时间 |
| szReserved | 476 | 512 | 扩展字段 (用`|`分隔) |
## API 参考
### Server
```go
// 创建服务器
srv := server.New(config)
// 设置日志
srv.SetLogger(log)
// 设置事件处理器
srv.SetHandler(handler)
// 启动服务器
srv.Start()
// 停止服务器
srv.Stop()
// 发送数据到指定连接
srv.Send(ctx, data)
// 广播数据到所有连接
srv.Broadcast(data)
// 获取当前连接数
count := srv.ConnectionCount()
```
### Connection Context
```go
// 发送数据
ctx.Send(data)
// 关闭连接
ctx.Close()
// 获取客户端 IP
ip := ctx.GetPeerIP()
// 检查连接状态
closed := ctx.IsClosed()
// 获取/更新最后活跃时间 (线程安全)
lastActive := ctx.LastActive()
ctx.UpdateLastActive()
duration := ctx.TimeSinceLastActive()
// 设置/获取客户端信息
ctx.SetInfo(clientInfo)
info := ctx.GetInfo()
// 设置/获取用户数据
ctx.SetUserData(myData)
data := ctx.GetUserData()
```
### Protocol
```go
// 解析登录信息
info, err := protocol.ParseLoginInfo(data)
if err == nil {
fmt.Println(info.PCName) // 计算机名
fmt.Println(info.OsVerInfo) // 操作系统
fmt.Println(info.ModuleVersion) // 版本
fmt.Println(info.WebCamExist) // 是否有摄像头
}
// 获取扩展字段
reserved := info.ParseReserved() // 返回 []string
clientType := info.GetReservedField(0) // 客户端类型
cpuCores := info.GetReservedField(2) // CPU 核数
filePath := info.GetReservedField(4) // 文件路径
publicIP := info.GetReservedField(11) // 公网 IP
```
## 与 C++ 版本对比
| 特性 | C++ (IOCP) | Go |
|------|------------|-----|
| 并发模型 | IOCP + 线程池 | Goroutine 池 |
| 压缩算法 | ZSTD | ZSTD |
| 跨平台 | Windows | 全平台 |
| 内存管理 | 手动 | GC |
| 代码复杂度 | 高 | 低 |
| 协议头解密 | 8种方式 | 8种方式 |
| XOR编码 | XOREncoder16 | XOREncoder16 |
| 字符编码 | GBK | GBK -> UTF-8 |
## 依赖
- [github.com/klauspost/compress/zstd](https://github.com/klauspost/compress) - ZSTD 压缩
- [github.com/rs/zerolog](https://github.com/rs/zerolog) - 高性能日志
- [gopkg.in/natefinch/lumberjack.v2](https://github.com/natefinch/lumberjack) - 日志轮转
- [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) - GBK 编码转换
## License
MIT License

208
server/go/auth/auth.go Normal file
View File

@@ -0,0 +1,208 @@
package auth
import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"os"
"strings"
"golang.org/x/text/encoding/simplifiedchinese"
"golang.org/x/text/transform"
)
// Config holds authentication configuration
type Config struct {
PwdHash string // SHA256 hash of the password (64 hex chars)
SuperPass string // Super admin password for HMAC verification
}
// DefaultConfig returns default auth configuration
func DefaultConfig() *Config {
return &Config{
PwdHash: "", // Must be configured
SuperPass: "", // Can be set via YAMA_PWD env var
}
}
// Authenticator handles token authentication
type Authenticator struct {
config *Config
}
// New creates a new Authenticator
func New(config *Config) *Authenticator {
return &Authenticator{config: config}
}
// AuthResult contains the result of authentication
type AuthResult struct {
Valid bool
Message string
SN string
Passcode string
}
// Authenticate validates a TOKEN_AUTH request
// Data format:
// - offset 0: TOKEN_AUTH command byte
// - offset 1-19: SN (serial number, 19 bytes)
// - offset 20-62: Passcode (42 bytes)
// - offset 62-70: HMAC signature (uint64, 8 bytes) if len > 64
func (a *Authenticator) Authenticate(data []byte) *AuthResult {
result := &AuthResult{
Valid: false,
Message: "未获授权或消息哈希校验失败",
}
// Minimum length check: 1 (token) + 19 (sn) + 1 (at least some passcode)
if len(data) <= 20 {
return result
}
// Extract SN (bytes 1-19)
sn := string(data[1:20])
result.SN = sn
// Extract passcode (bytes 20-62, or until end if shorter)
passcodeEnd := 62
if len(data) < passcodeEnd {
passcodeEnd = len(data)
}
passcode := string(data[20:passcodeEnd])
result.Passcode = passcode
// Extract HMAC if present (bytes 62-70)
var hmacSig uint64
if len(data) >= 70 {
hmacSig = binary.LittleEndian.Uint64(data[62:70])
} else if len(data) > 62 {
// Partial HMAC data - safely handle incomplete bytes
hmacBytes := make([]byte, 8)
copy(hmacBytes, data[62:])
hmacSig = binary.LittleEndian.Uint64(hmacBytes)
}
// Split passcode by '-'
parts := strings.Split(passcode, "-")
if len(parts) != 6 && len(parts) != 7 {
return result
}
// Get last 4 parts as subvector
subvector := parts[len(parts)-4:]
// Build password string: v[0] + " - " + v[1] + ": " + PwdHash + (optional: ": " + v[2])
password := parts[0] + " - " + parts[1] + ": " + a.config.PwdHash
if len(parts) == 7 {
password += ": " + parts[2]
}
// Derive key from password and SN
finalKey := DeriveKey(password, sn)
// Get fixed length ID
hash256 := strings.Join(subvector, "-")
fixedKey := GetFixedLengthID(finalKey)
// Debug output (can be removed in production)
// fmt.Printf("DEBUG: password=%q sn=%q finalKey=%s fixedKey=%s hash256=%s\n", password, sn, finalKey, fixedKey, hash256)
// Compare
if hash256 != fixedKey {
return result
}
// Passcode validation successful, now verify HMAC
superPass := os.Getenv("YAMA_PWD")
if superPass == "" {
superPass = a.config.SuperPass
}
if superPass != "" && hmacSig != 0 {
verified := VerifyMessage(superPass, []byte(passcode), hmacSig)
if verified {
result.Valid = true
result.Message = "此程序已获授权,请遵守授权协议,感谢合作"
}
// If HMAC verification fails, valid remains false
} else if hmacSig == 0 {
// No HMAC provided but passcode is valid - could be older client
// Keep as invalid for security
}
return result
}
// utf8ToGBK converts UTF-8 string to GBK encoded bytes
func utf8ToGBK(s string) []byte {
reader := transform.NewReader(bytes.NewReader([]byte(s)), simplifiedchinese.GBK.NewEncoder())
buf := new(bytes.Buffer)
buf.ReadFrom(reader)
return buf.Bytes()
}
// BuildResponse builds the 100-byte response for TOKEN_AUTH
func (a *Authenticator) BuildResponse(result *AuthResult) []byte {
resp := make([]byte, 100)
if result.Valid {
resp[0] = 1
}
// Message starts at offset 4, convert UTF-8 to GBK for Windows client
gbkMsg := utf8ToGBK(result.Message)
copy(resp[4:], gbkMsg)
return resp
}
// HashSHA256 computes SHA256 hash and returns hex string
func HashSHA256(data string) string {
h := sha256.New()
h.Write([]byte(data))
return hex.EncodeToString(h.Sum(nil))
}
// DeriveKey derives a key from password and hardware ID
// Format: SHA256(password + " + " + hardwareID)
func DeriveKey(password, hardwareID string) string {
return HashSHA256(password + " + " + hardwareID)
}
// GetFixedLengthID formats a hash into fixed length ID
// Format: xxxx-xxxx-xxxx-xxxx (first 16 chars split by -)
func GetFixedLengthID(hash string) string {
if len(hash) < 16 {
return hash
}
return hash[0:4] + "-" + hash[4:8] + "-" + hash[8:12] + "-" + hash[12:16]
}
// SignMessage computes HMAC-SHA256 and returns first 8 bytes as uint64
func SignMessage(pwd string, msg []byte) uint64 {
h := hmac.New(sha256.New, []byte(pwd))
h.Write(msg)
hash := h.Sum(nil)
return binary.LittleEndian.Uint64(hash[:8])
}
// VerifyMessage verifies HMAC signature
func VerifyMessage(pwd string, msg []byte, signature uint64) bool {
computed := SignMessage(pwd, msg)
return computed == signature
}
// GenHMAC generates HMAC for password verification
// This matches the C++ genHMAC function
func GenHMAC(pwdHash, superPass string) string {
key := HashSHA256(superPass)
list := []string{"g", "h", "o", "s", "t"}
for _, item := range list {
key = HashSHA256(key + " - " + item)
}
result := HashSHA256(pwdHash + " - " + key)
if len(result) >= 16 {
return result[:16]
}
return result
}

185
server/go/buffer/buffer.go Normal file
View File

@@ -0,0 +1,185 @@
package buffer
import (
"encoding/binary"
"sync"
)
// Buffer is a thread-safe dynamic buffer for network I/O
type Buffer struct {
data []byte
mu sync.RWMutex
offset int // read offset for lazy compaction
}
// New creates a new buffer with optional initial capacity
func New(capacity ...int) *Buffer {
cap := 4096
if len(capacity) > 0 && capacity[0] > 0 {
cap = capacity[0]
}
return &Buffer{
data: make([]byte, 0, cap),
}
}
// Write appends data to the buffer
func (b *Buffer) Write(p []byte) (int, error) {
b.mu.Lock()
defer b.mu.Unlock()
b.data = append(b.data, p...)
return len(p), nil
}
// WriteUint32 writes a uint32 in little-endian format
func (b *Buffer) WriteUint32(v uint32) {
buf := make([]byte, 4)
binary.LittleEndian.PutUint32(buf, v)
b.Write(buf)
}
// Read reads and removes data from the buffer
func (b *Buffer) Read(n int) []byte {
b.mu.Lock()
defer b.mu.Unlock()
available := len(b.data) - b.offset
if n > available {
n = available
}
if n <= 0 {
return nil
}
result := make([]byte, n)
copy(result, b.data[b.offset:b.offset+n])
b.offset += n
// Compact when offset is large enough
if b.offset > len(b.data)/2 && b.offset > 1024 {
b.compact()
}
return result
}
// Peek returns data without removing it
func (b *Buffer) Peek(n int) []byte {
b.mu.RLock()
defer b.mu.RUnlock()
available := len(b.data) - b.offset
if n > available {
n = available
}
if n <= 0 {
return nil
}
result := make([]byte, n)
copy(result, b.data[b.offset:b.offset+n])
return result
}
// PeekAt returns data at a specific offset without removing it
func (b *Buffer) PeekAt(offset, n int) []byte {
b.mu.RLock()
defer b.mu.RUnlock()
start := b.offset + offset
if start >= len(b.data) {
return nil
}
end := start + n
if end > len(b.data) {
end = len(b.data)
}
result := make([]byte, end-start)
copy(result, b.data[start:end])
return result
}
// Skip removes n bytes from the beginning
func (b *Buffer) Skip(n int) int {
b.mu.Lock()
defer b.mu.Unlock()
available := len(b.data) - b.offset
if n > available {
n = available
}
b.offset += n
// Compact when offset is large enough
if b.offset > len(b.data)/2 && b.offset > 1024 {
b.compact()
}
return n
}
// Len returns the length of unread data
func (b *Buffer) Len() int {
b.mu.RLock()
defer b.mu.RUnlock()
return len(b.data) - b.offset
}
// Bytes returns all unread data without removing it
func (b *Buffer) Bytes() []byte {
b.mu.RLock()
defer b.mu.RUnlock()
n := len(b.data) - b.offset
if n <= 0 {
return nil
}
result := make([]byte, n)
copy(result, b.data[b.offset:])
return result
}
// Clear removes all data from the buffer
func (b *Buffer) Clear() {
b.mu.Lock()
defer b.mu.Unlock()
b.data = b.data[:0]
b.offset = 0
}
// compact moves remaining data to the beginning
func (b *Buffer) compact() {
if b.offset > 0 {
remaining := len(b.data) - b.offset
copy(b.data[:remaining], b.data[b.offset:])
b.data = b.data[:remaining]
b.offset = 0
}
}
// GetByte returns a single byte at offset
func (b *Buffer) GetByte(offset int) byte {
b.mu.RLock()
defer b.mu.RUnlock()
idx := b.offset + offset
if idx >= len(b.data) {
return 0
}
return b.data[idx]
}
// GetUint32 returns a uint32 at offset in little-endian format
func (b *Buffer) GetUint32(offset int) uint32 {
b.mu.RLock()
defer b.mu.RUnlock()
idx := b.offset + offset
if idx+4 > len(b.data) {
return 0
}
return binary.LittleEndian.Uint32(b.data[idx : idx+4])
}

268
server/go/cmd/main.go Normal file
View File

@@ -0,0 +1,268 @@
package main
import (
"flag"
"fmt"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/auth"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/connection"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/logger"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/protocol"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/server"
)
// MyHandler implements the server.Handler interface
type MyHandler struct {
log *logger.Logger
auth *auth.Authenticator
srv *server.Server
}
// OnConnect is called when a client connects
func (h *MyHandler) OnConnect(ctx *connection.Context) {
// Only log connection established, detailed info logged on login
}
// OnDisconnect is called when a client disconnects
func (h *MyHandler) OnDisconnect(ctx *connection.Context) {
info := ctx.GetInfo()
if info.ClientID != "" {
h.log.ClientEvent("offline", ctx.ID, ctx.GetPeerIP(),
"clientID", info.ClientID,
"computer", info.ComputerName,
)
}
}
// OnReceive is called when data is received from a client
func (h *MyHandler) OnReceive(ctx *connection.Context, data []byte) {
if len(data) == 0 {
return
}
cmd := data[0]
// Handle commands
switch cmd {
case protocol.TokenLogin:
h.handleLogin(ctx, data)
case protocol.TokenAuth:
h.handleAuth(ctx, data)
case protocol.TokenHeartbeat:
h.handleHeartbeat(ctx, data)
default:
// Other commands are not implemented yet
h.log.Info("Unhandled command %d from client %d", cmd, ctx.ID)
}
}
// handleLogin handles client login (TOKEN_LOGIN = 102)
func (h *MyHandler) handleLogin(ctx *connection.Context, data []byte) {
info, err := protocol.ParseLoginInfo(data)
if err != nil {
h.log.Error("Failed to parse login info from client %d: %v", ctx.ID, err)
return
}
// Use MasterID from login request as ClientID for logging
clientID := info.MasterID
if clientID == "" {
clientID = fmt.Sprintf("conn-%d", ctx.ID)
}
// Store client info
reserved := info.ParseReserved()
clientInfo := connection.ClientInfo{
ClientID: clientID,
ComputerName: info.PCName,
OS: info.OsVerInfo,
Version: info.ModuleVersion,
HasCamera: info.WebCamExist,
InstallTime: info.StartTime,
}
// Parse additional info from reserved field
if len(reserved) > 0 {
clientInfo.ClientType = info.GetReservedField(0)
}
if len(reserved) > 2 {
clientInfo.CPU = info.GetReservedField(2)
}
if len(reserved) > 4 {
clientInfo.FilePath = info.GetReservedField(4)
}
if len(reserved) > 11 {
clientInfo.IP = info.GetReservedField(11) // Public IP
}
ctx.SetInfo(clientInfo)
ctx.IsLoggedIn.Store(true)
h.log.ClientEvent("online", ctx.ID, ctx.GetPeerIP(),
"clientID", clientID,
"computer", info.PCName,
"os", info.OsVerInfo,
"version", info.ModuleVersion,
"path", clientInfo.FilePath,
)
}
// handleAuth handles authorization request (TOKEN_AUTH = 100)
func (h *MyHandler) handleAuth(ctx *connection.Context, data []byte) {
result := h.auth.Authenticate(data)
info := ctx.GetInfo()
if result.Valid {
h.log.Info("Auth success: clientID=%s computer=%s ip=%s sn=%s passcode=%s",
info.ClientID, info.ComputerName, ctx.GetPeerIP(), result.SN, result.Passcode)
} else {
h.log.Warn("Auth failed: clientID=%s computer=%s ip=%s sn=%s passcode=%s",
info.ClientID, info.ComputerName, ctx.GetPeerIP(), result.SN, result.Passcode)
}
// Build and send response
resp := h.auth.BuildResponse(result)
if err := h.srv.Send(ctx, resp); err != nil {
h.log.Error("Failed to send auth response to client %d: %v", ctx.ID, err)
}
}
// handleHeartbeat handles heartbeat from client (TOKEN_HEARTBEAT = 101)
func (h *MyHandler) handleHeartbeat(ctx *connection.Context, data []byte) {
// Parse Time from heartbeat request (offset 1, 8 bytes)
var hbTime uint64
if len(data) >= 9 {
hbTime = uint64(data[1]) | uint64(data[2])<<8 | uint64(data[3])<<16 | uint64(data[4])<<24 |
uint64(data[5])<<32 | uint64(data[6])<<40 | uint64(data[7])<<48 | uint64(data[8])<<56
}
// Build HeartbeatACK response: CMD_HEARTBEAT_ACK(1) + HeartbeatACK(32)
resp := make([]byte, 33)
resp[0] = protocol.CommandHeartbeat // CMD_HEARTBEAT_ACK = 216
// Time at offset 1 (8 bytes, little-endian)
resp[1] = byte(hbTime)
resp[2] = byte(hbTime >> 8)
resp[3] = byte(hbTime >> 16)
resp[4] = byte(hbTime >> 24)
resp[5] = byte(hbTime >> 32)
resp[6] = byte(hbTime >> 40)
resp[7] = byte(hbTime >> 48)
resp[8] = byte(hbTime >> 56)
// Reserved[24] at offset 9 is already zero
if err := h.srv.Send(ctx, resp); err != nil {
h.log.Error("Failed to send heartbeat ACK to client %d: %v", ctx.ID, err)
}
}
// parsePorts parses a semicolon-separated port string and returns port numbers
func parsePorts(portStr string) ([]int, error) {
var ports []int
parts := strings.Split(portStr, ";")
for _, p := range parts {
p = strings.TrimSpace(p)
if p == "" {
continue
}
port, err := strconv.Atoi(p)
if err != nil {
return nil, fmt.Errorf("invalid port %q: %v", p, err)
}
if port < 1 || port > 65535 {
return nil, fmt.Errorf("port %d out of range (1-65535)", port)
}
ports = append(ports, port)
}
if len(ports) == 0 {
return nil, fmt.Errorf("no valid ports specified")
}
return ports, nil
}
func main() {
// Parse command line flags
portStr := flag.String("port", "6543", "Server listen ports (semicolon-separated, e.g. 6543;6544;6545)")
flag.StringVar(portStr, "p", "6543", "Server listen ports (shorthand)")
noConsole := flag.Bool("no-console", false, "Disable console output (for daemon mode)")
flag.Parse()
// Parse ports
ports, err := parsePorts(*portStr)
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
// Create logger with file output
logCfg := logger.DefaultConfig()
logCfg.Level = logger.LevelDebug
logCfg.Console = !*noConsole
logCfg.File = "logs/server.log"
logCfg.MaxSize = 100 // 100 MB
logCfg.MaxBackups = 10 // keep 10 old files
logCfg.MaxAge = 30 // 30 days
logCfg.Compress = true
log := logger.New(logCfg)
// Create auth config
authCfg := auth.DefaultConfig()
// PwdHash can be set from environment or config file
authCfg.PwdHash = os.Getenv("YAMA_PWDHASH")
if authCfg.PwdHash == "" {
// Default placeholder - should be configured in production
authCfg.PwdHash = "61f04dd637a74ee34493fc1025de2c131022536da751c29e3ff4e9024d8eec43"
}
authCfg.SuperPass = os.Getenv("YAMA_PWD")
// Create authenticator (shared by all servers)
authenticator := auth.New(authCfg)
// Create servers for each port
var servers []*server.Server
for _, port := range ports {
config := server.DefaultConfig()
config.Port = port
config.MaxConnections = 9999
srv := server.New(config)
srv.SetLogger(log.WithPrefix(fmt.Sprintf("Server:%d", port)))
// Create handler for this server
handler := &MyHandler{
log: log.WithPrefix(fmt.Sprintf("Handler:%d", port)),
auth: authenticator,
srv: srv,
}
srv.SetHandler(handler)
servers = append(servers, srv)
}
// Start all servers
for _, srv := range servers {
if err := srv.Start(); err != nil {
log.Fatal("Failed to start server: %v", err)
}
}
fmt.Printf("Server started on port(s): %v\n", ports)
fmt.Println("Logs are written to: logs/server.log")
fmt.Println("Press Ctrl+C to stop...")
// Wait for interrupt signal
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
<-sigChan
fmt.Println("\nShutting down...")
for _, srv := range servers {
srv.Stop()
}
fmt.Println("Server stopped")
}

View File

@@ -0,0 +1,197 @@
package connection
import (
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/buffer"
)
// ClientInfo stores client metadata
type ClientInfo struct {
ClientID string // Client ID from login (MasterID)
IP string
ComputerName string
OS string
CPU string
HasCamera bool
Version string
InstallTime string
LoginTime time.Time
ClientType string
FilePath string
GroupName string
}
// Context represents a client connection context
type Context struct {
ID uint64
Conn net.Conn
RemoteAddr string
// Buffers
InBuffer *buffer.Buffer // Received compressed data
OutBuffer *buffer.Buffer // Decompressed data for processing
// Client info
Info ClientInfo
IsLoggedIn atomic.Bool
// Connection state
OnlineTime time.Time
lastActiveNs atomic.Int64 // Unix nanoseconds for thread-safe access
// Protocol state
CompressMethod int
FlagType FlagType
HeaderLen int
FlagLen int
HeaderEncType int // Header encryption type (0-7)
HeaderParams []byte // Header parameters for decoding (flag bytes)
// User data - for storing dialog/handler references
UserData interface{}
// Internal
mu sync.RWMutex
closed atomic.Bool
sendLock sync.Mutex
server *Manager
}
// FlagType represents the protocol flag type
type FlagType int
const (
FlagUnknown FlagType = iota
FlagShine
FlagFuck
FlagHello
FlagHell
FlagWinOS
)
// Compression methods
const (
CompressUnknown = -2
CompressZlib = -1
CompressZstd = 0
CompressNone = 1
)
// NewContext creates a new connection context
func NewContext(conn net.Conn, mgr *Manager) *Context {
now := time.Now()
ctx := &Context{
Conn: conn,
RemoteAddr: conn.RemoteAddr().String(),
InBuffer: buffer.New(8192),
OutBuffer: buffer.New(8192),
OnlineTime: now,
CompressMethod: CompressZstd,
FlagType: FlagUnknown,
server: mgr,
}
ctx.lastActiveNs.Store(now.UnixNano())
return ctx
}
// Send sends data to the client (thread-safe)
func (c *Context) Send(data []byte) error {
if c.closed.Load() {
return ErrConnectionClosed
}
c.sendLock.Lock()
defer c.sendLock.Unlock()
_, err := c.Conn.Write(data)
if err != nil {
return err
}
c.UpdateLastActive()
return nil
}
// UpdateLastActive updates the last active time (thread-safe)
func (c *Context) UpdateLastActive() {
c.lastActiveNs.Store(time.Now().UnixNano())
}
// LastActive returns the last active time (thread-safe)
func (c *Context) LastActive() time.Time {
return time.Unix(0, c.lastActiveNs.Load())
}
// TimeSinceLastActive returns duration since last activity (thread-safe)
func (c *Context) TimeSinceLastActive() time.Duration {
return time.Since(c.LastActive())
}
// Close closes the connection
func (c *Context) Close() error {
if c.closed.Swap(true) {
return nil // Already closed
}
return c.Conn.Close()
}
// IsClosed returns whether the connection is closed
func (c *Context) IsClosed() bool {
return c.closed.Load()
}
// GetPeerIP returns the peer IP address
func (c *Context) GetPeerIP() string {
if host, _, err := net.SplitHostPort(c.RemoteAddr); err == nil {
return host
}
return c.RemoteAddr
}
// AliveTime returns how long the connection has been alive
func (c *Context) AliveTime() time.Duration {
return time.Since(c.OnlineTime)
}
// SetInfo sets the client info
func (c *Context) SetInfo(info ClientInfo) {
c.mu.Lock()
defer c.mu.Unlock()
c.Info = info
}
// GetInfo returns the client info
func (c *Context) GetInfo() ClientInfo {
c.mu.RLock()
defer c.mu.RUnlock()
return c.Info
}
// SetUserData stores user-defined data
func (c *Context) SetUserData(data interface{}) {
c.mu.Lock()
defer c.mu.Unlock()
c.UserData = data
}
// GetUserData retrieves user-defined data
func (c *Context) GetUserData() interface{} {
c.mu.RLock()
defer c.mu.RUnlock()
return c.UserData
}
// GetClientID returns the client ID for logging
// If ClientID is set (from login), returns it; otherwise returns connection ID as fallback
func (c *Context) GetClientID() string {
c.mu.RLock()
defer c.mu.RUnlock()
if c.Info.ClientID != "" {
return c.Info.ClientID
}
return fmt.Sprintf("conn-%d", c.ID)
}

View File

@@ -0,0 +1,23 @@
package connection
import "errors"
var (
// ErrConnectionClosed indicates the connection is closed
ErrConnectionClosed = errors.New("connection closed")
// ErrServerClosed indicates the server is shut down
ErrServerClosed = errors.New("server closed")
// ErrMaxConnections indicates max connections reached
ErrMaxConnections = errors.New("max connections reached")
// ErrInvalidPacket indicates an invalid packet
ErrInvalidPacket = errors.New("invalid packet")
// ErrUnsupportedProtocol indicates unsupported protocol
ErrUnsupportedProtocol = errors.New("unsupported protocol")
// ErrDecompressFailed indicates decompression failure
ErrDecompressFailed = errors.New("decompression failed")
)

View File

@@ -0,0 +1,115 @@
package connection
import (
"sync"
"sync/atomic"
)
// Manager manages all client connections
type Manager struct {
connections sync.Map // map[uint64]*Context
count atomic.Int64
maxConns int
idCounter atomic.Uint64
// Callbacks
onConnect func(*Context)
onDisconnect func(*Context)
onReceive func(*Context, []byte)
}
// NewManager creates a new connection manager
func NewManager(maxConns int) *Manager {
if maxConns <= 0 {
maxConns = 10000
}
return &Manager{
maxConns: maxConns,
}
}
// SetCallbacks sets the callback functions
func (m *Manager) SetCallbacks(onConnect, onDisconnect func(*Context), onReceive func(*Context, []byte)) {
m.onConnect = onConnect
m.onDisconnect = onDisconnect
m.onReceive = onReceive
}
// Add adds a new connection
func (m *Manager) Add(ctx *Context) error {
if int(m.count.Load()) >= m.maxConns {
return ErrMaxConnections
}
ctx.ID = m.idCounter.Add(1)
m.connections.Store(ctx.ID, ctx)
m.count.Add(1)
if m.onConnect != nil {
m.onConnect(ctx)
}
return nil
}
// Remove removes a connection
func (m *Manager) Remove(ctx *Context) {
if _, ok := m.connections.LoadAndDelete(ctx.ID); ok {
m.count.Add(-1)
if m.onDisconnect != nil {
m.onDisconnect(ctx)
}
}
}
// Get retrieves a connection by ID
func (m *Manager) Get(id uint64) *Context {
if v, ok := m.connections.Load(id); ok {
return v.(*Context)
}
return nil
}
// Count returns the current connection count
func (m *Manager) Count() int {
return int(m.count.Load())
}
// Range iterates over all connections
func (m *Manager) Range(fn func(*Context) bool) {
m.connections.Range(func(key, value interface{}) bool {
return fn(value.(*Context))
})
}
// Broadcast sends data to all connections
func (m *Manager) Broadcast(data []byte) {
m.connections.Range(func(key, value interface{}) bool {
ctx := value.(*Context)
if !ctx.IsClosed() {
_ = ctx.Send(data)
}
return true
})
}
// CloseAll closes all connections
func (m *Manager) CloseAll() {
m.connections.Range(func(key, value interface{}) bool {
ctx := value.(*Context)
_ = ctx.Close()
return true
})
}
// OnReceive calls the receive callback
func (m *Manager) OnReceive(ctx *Context, data []byte) {
if m.onReceive != nil {
m.onReceive(ctx, data)
}
}
// UpdateMaxConnections updates the maximum connections limit
func (m *Manager) UpdateMaxConnections(max int) {
m.maxConns = max
}

16
server/go/go.mod Normal file
View File

@@ -0,0 +1,16 @@
module github.com/yuanyuanxiang/SimpleRemoter/server/go
go 1.24.5
require (
github.com/klauspost/compress v1.18.2
github.com/rs/zerolog v1.34.0
golang.org/x/text v0.32.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
)
require (
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
golang.org/x/sys v0.12.0 // indirect
)

21
server/go/go.sum Normal file
View File

@@ -0,0 +1,21 @@
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=

204
server/go/logger/logger.go Normal file
View File

@@ -0,0 +1,204 @@
package logger
import (
"io"
"os"
"path/filepath"
"time"
"github.com/rs/zerolog"
"gopkg.in/natefinch/lumberjack.v2"
)
// Level represents log level
type Level = zerolog.Level
const (
LevelDebug = zerolog.DebugLevel
LevelInfo = zerolog.InfoLevel
LevelWarn = zerolog.WarnLevel
LevelError = zerolog.ErrorLevel
LevelFatal = zerolog.FatalLevel
)
// Logger wraps zerolog.Logger
type Logger struct {
zl zerolog.Logger
}
// Config holds logger configuration
type Config struct {
// Level is the minimum log level
Level Level
// Console enables console output
Console bool
// File is the log file path (empty to disable file logging)
File string
// MaxSize is the max size in MB before rotation
MaxSize int
// MaxBackups is the max number of old log files to keep
MaxBackups int
// MaxAge is the max days to keep old log files
MaxAge int
// Compress enables gzip compression for rotated files
Compress bool
}
// DefaultConfig returns default configuration
func DefaultConfig() Config {
return Config{
Level: LevelInfo,
Console: true,
File: "",
MaxSize: 100,
MaxBackups: 3,
MaxAge: 30,
Compress: true,
}
}
// New creates a new logger with config
func New(cfg Config) *Logger {
var writers []io.Writer
// Console output with color
if cfg.Console {
consoleWriter := zerolog.ConsoleWriter{
Out: os.Stdout,
TimeFormat: "2006-01-02 15:04:05",
}
writers = append(writers, consoleWriter)
}
// File output with rotation
if cfg.File != "" {
// Ensure directory exists
dir := filepath.Dir(cfg.File)
if dir != "" && dir != "." {
_ = os.MkdirAll(dir, 0755)
}
fileWriter := &lumberjack.Logger{
Filename: cfg.File,
MaxSize: cfg.MaxSize,
MaxBackups: cfg.MaxBackups,
MaxAge: cfg.MaxAge,
Compress: cfg.Compress,
LocalTime: true,
}
writers = append(writers, fileWriter)
}
// Combine writers
var writer io.Writer
if len(writers) == 0 {
writer = os.Stdout
} else if len(writers) == 1 {
writer = writers[0]
} else {
writer = zerolog.MultiLevelWriter(writers...)
}
// Create logger
zl := zerolog.New(writer).
Level(cfg.Level).
With().
Timestamp().
Logger()
return &Logger{zl: zl}
}
// WithPrefix returns a new logger with a prefix field
func (l *Logger) WithPrefix(prefix string) *Logger {
return &Logger{
zl: l.zl.With().Str("module", prefix).Logger(),
}
}
// Debug logs a debug message
func (l *Logger) Debug(format string, args ...interface{}) {
l.zl.Debug().Msgf(format, args...)
}
// Info logs an info message
func (l *Logger) Info(format string, args ...interface{}) {
l.zl.Info().Msgf(format, args...)
}
// Warn logs a warning message
func (l *Logger) Warn(format string, args ...interface{}) {
l.zl.Warn().Msgf(format, args...)
}
// Error logs an error message
func (l *Logger) Error(format string, args ...interface{}) {
l.zl.Error().Msgf(format, args...)
}
// Fatal logs a fatal message and exits
func (l *Logger) Fatal(format string, args ...interface{}) {
l.zl.Fatal().Msgf(format, args...)
}
// SetLevel sets the log level
func (l *Logger) SetLevel(level Level) {
l.zl = l.zl.Level(level)
}
// GetLevel returns the current log level
func (l *Logger) GetLevel() Level {
return l.zl.GetLevel()
}
// ClientEvent logs client online/offline events
func (l *Logger) ClientEvent(event string, clientID uint64, ip string, extra ...string) {
e := l.zl.Info().
Str("event", event).
Uint64("client_id", clientID).
Str("ip", ip).
Time("time", time.Now())
if len(extra) >= 2 {
for i := 0; i+1 < len(extra); i += 2 {
e = e.Str(extra[i], extra[i+1])
}
}
e.Msg("")
}
// default global logger
var defaultLogger = New(DefaultConfig())
// SetDefault sets the default global logger
func SetDefault(l *Logger) {
defaultLogger = l
}
// Default returns the default global logger
func Default() *Logger {
return defaultLogger
}
// Package-level convenience functions
func Debug(format string, args ...interface{}) {
defaultLogger.Debug(format, args...)
}
func Info(format string, args ...interface{}) {
defaultLogger.Info(format, args...)
}
func Warn(format string, args ...interface{}) {
defaultLogger.Warn(format, args...)
}
func Error(format string, args ...interface{}) {
defaultLogger.Error(format, args...)
}
func Fatal(format string, args ...interface{}) {
defaultLogger.Fatal(format, args...)
}

161
server/go/protocol/codec.go Normal file
View File

@@ -0,0 +1,161 @@
package protocol
import (
"github.com/klauspost/compress/zstd"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/connection"
)
// Codec handles encoding/decoding and compression
type Codec struct {
encoder *zstd.Encoder
decoder *zstd.Decoder
}
// NewCodec creates a new codec
func NewCodec() *Codec {
encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.SpeedDefault))
if err != nil {
panic("failed to create zstd encoder: " + err.Error())
}
decoder, err := zstd.NewReader(nil)
if err != nil {
panic("failed to create zstd decoder: " + err.Error())
}
return &Codec{
encoder: encoder,
decoder: decoder,
}
}
// Compress compresses data using the appropriate method
func (c *Codec) Compress(ctx *connection.Context, data []byte) ([]byte, error) {
switch ctx.CompressMethod {
case connection.CompressNone:
return data, nil
case connection.CompressZstd:
return c.encoder.EncodeAll(data, nil), nil
default:
// Default to zstd
return c.encoder.EncodeAll(data, nil), nil
}
}
// Decompress decompresses data using the appropriate method
func (c *Codec) Decompress(ctx *connection.Context, data []byte, origLen uint32) ([]byte, error) {
switch ctx.CompressMethod {
case connection.CompressNone:
// No compression, return as-is
result := make([]byte, len(data))
copy(result, data)
return result, nil
case connection.CompressZstd:
result := make([]byte, 0, origLen)
return c.decoder.DecodeAll(data, result)
default:
// Try zstd by default
result := make([]byte, 0, origLen)
return c.decoder.DecodeAll(data, result)
}
}
// Encode encodes data after compression (before sending) - Encoder2
func (c *Codec) Encode(ctx *connection.Context, data []byte) {
// This is Encoder2 - applied after compression
switch ctx.FlagType {
case connection.FlagHello, connection.FlagHell:
// XOREncoder16 - needs param from header
// For now, skip encoding on send since we need the header params
case connection.FlagFuck:
// No encoding after compression for FUCK
}
}
// Decode decodes data before decompression (after receiving) - Encoder2
func (c *Codec) Decode(ctx *connection.Context, data []byte) {
// This is Encoder2 - applied before decompression
// XOREncoder16 for HELL/HELLO protocols
if ctx.FlagType == connection.FlagHell || ctx.FlagType == connection.FlagHello {
// Get k1, k2 from stored header params
if len(ctx.HeaderParams) >= 8 {
k1 := ctx.HeaderParams[6]
k2 := ctx.HeaderParams[7]
if k1 != 0 || k2 != 0 {
xorDecoder16(data, k1, k2)
}
}
}
}
// EncodeData encodes data before compression - Encoder
func (c *Codec) EncodeData(ctx *connection.Context, data []byte) {
// This is Encoder - applied before compression
// Default encoder does nothing
}
// DecodeData decodes data after decompression - Encoder
func (c *Codec) DecodeData(ctx *connection.Context, data []byte) {
// This is Encoder - applied after decompression
// Default encoder does nothing
}
// xorDecoder16 implements XOREncoder16.decrypt_internal
func xorDecoder16(data []byte, k1, k2 byte) {
if len(data) == 0 {
return
}
key := (uint16(k1) << 8) | uint16(k2)
dataLen := len(data)
// Reverse two rounds of pseudo-random swaps
for round := 1; round >= 0; round-- {
for i := dataLen - 1; i >= 0; i-- {
j := int(pseudoRandom(key, i+round*100)) % dataLen
data[i], data[j] = data[j], data[i]
}
}
// XOR decode
for i := 0; i < dataLen; i++ {
data[i] ^= (k1 + byte(i*13)) ^ (k2 ^ byte(i<<1))
}
}
// xorEncoder16 implements XOREncoder16.encrypt_internal
func xorEncoder16(data []byte, k1, k2 byte) {
if len(data) == 0 {
return
}
key := (uint16(k1) << 8) | uint16(k2)
dataLen := len(data)
// XOR encode
for i := 0; i < dataLen; i++ {
data[i] ^= (k1 + byte(i*13)) ^ (k2 ^ byte(i<<1))
}
// Two rounds of pseudo-random swaps
for round := 0; round < 2; round++ {
for i := 0; i < dataLen; i++ {
j := int(pseudoRandom(key, i+round*100)) % dataLen
data[i], data[j] = data[j], data[i]
}
}
}
// pseudoRandom matches the C++ pseudo_random function
func pseudoRandom(seed uint16, index int) uint16 {
return ((seed ^ uint16(index*251+97)) * 733) ^ (seed >> 3)
}
// Close releases resources
func (c *Codec) Close() {
if c.encoder != nil {
c.encoder.Close()
}
if c.decoder != nil {
c.decoder.Close()
}
}

View File

@@ -0,0 +1,202 @@
package protocol
import (
"bytes"
"encoding/binary"
"strings"
"golang.org/x/text/encoding/simplifiedchinese"
"golang.org/x/text/transform"
)
// gbkToUTF8 converts GBK encoded bytes to UTF-8 string
func gbkToUTF8(data []byte) string {
// Find the first null byte and truncate there
if idx := bytes.IndexByte(data, 0); idx >= 0 {
data = data[:idx]
}
if len(data) == 0 {
return ""
}
// Try to decode as GBK
reader := transform.NewReader(bytes.NewReader(data), simplifiedchinese.GBK.NewDecoder())
buf := new(bytes.Buffer)
_, err := buf.ReadFrom(reader)
if err != nil {
// If GBK decoding fails, try treating as UTF-8 or ASCII
return cleanString(string(data))
}
return cleanString(buf.String())
}
// cleanString removes non-printable characters except common whitespace
func cleanString(s string) string {
var result strings.Builder
for _, r := range s {
if r >= 32 || r == '\t' || r == '\n' || r == '\r' {
result.WriteRune(r)
}
}
return strings.TrimSpace(result.String())
}
// Command tokens - matching the C++ definitions
const (
// Server -> Client commands
CommandActived byte = 0 // COMMAND_ACTIVED
CommandBye byte = 204 // COMMAND_BYE - disconnect
CommandHeartbeat byte = 216 // CMD_HEARTBEAT_ACK
// Client -> Server tokens
TokenAuth byte = 100 // TOKEN_AUTH - authorization required
TokenHeartbeat byte = 101 // TOKEN_HEARTBEAT
TokenLogin byte = 102 // TOKEN_LOGIN - login packet
)
// LOGIN_INFOR structure size and offsets (matching C++ struct with default alignment)
// Note: C++ struct uses default alignment (4-byte for uint32/int)
const (
LoginInfoSize = 980 // Total size of LOGIN_INFOR struct (with alignment padding)
// Field offsets (with alignment padding)
OffsetToken = 0 // 1 byte (unsigned char)
OffsetOsVerInfoEx = 1 // 156 bytes (char[156])
// 3 bytes padding here to align dwCPUMHz to 4-byte boundary
OffsetCPUMHz = 160 // 4 bytes (unsigned int) - aligned to 4
OffsetModuleVersion = 164 // 24 bytes (char[24])
OffsetPCName = 188 // 240 bytes (char[240])
OffsetMasterID = 428 // 20 bytes (char[20])
OffsetWebCamExist = 448 // 4 bytes (int) - aligned to 4
OffsetSpeed = 452 // 4 bytes (unsigned int)
OffsetStartTime = 456 // 20 bytes (char[20])
OffsetReserved = 476 // 512 bytes (char[512])
)
// LoginInfo represents client login information
type LoginInfo struct {
Token byte
OsVerInfo string // OS version info
CPUMHz uint32
ModuleVersion string
PCName string // Computer name
MasterID string
WebCamExist bool
Speed uint32
StartTime string
Reserved string // Contains additional info separated by |
}
// ParseLoginInfo parses LOGIN_INFOR from data
func ParseLoginInfo(data []byte) (*LoginInfo, error) {
if len(data) < 100 { // Minimum size check
return nil, ErrInvalidData
}
info := &LoginInfo{
Token: data[0],
}
// Parse OS version info (offset 1, 156 bytes)
// The C++ client fills this with a readable string like "Windows 10" via getSystemName()
if len(data) >= OffsetOsVerInfoEx+156 {
info.OsVerInfo = parseOsVersionInfo(data[OffsetOsVerInfoEx : OffsetOsVerInfoEx+156])
}
// Parse CPU MHz (offset 160, 4 bytes)
if len(data) >= OffsetCPUMHz+4 {
info.CPUMHz = binary.LittleEndian.Uint32(data[OffsetCPUMHz:])
}
// Parse module version (offset 164, 24 bytes)
// This contains date string like "Dec 19 2025"
if len(data) >= OffsetModuleVersion+24 {
info.ModuleVersion = gbkToUTF8(data[OffsetModuleVersion : OffsetModuleVersion+24])
}
// Parse PC name (offset 188, 240 bytes)
if len(data) >= OffsetPCName+240 {
info.PCName = gbkToUTF8(data[OffsetPCName : OffsetPCName+240])
}
// Parse Master ID (offset 428, 20 bytes)
if len(data) >= OffsetMasterID+20 {
info.MasterID = gbkToUTF8(data[OffsetMasterID : OffsetMasterID+20])
}
// Parse WebCam exist (offset 448, 4 bytes)
if len(data) >= OffsetWebCamExist+4 {
info.WebCamExist = binary.LittleEndian.Uint32(data[OffsetWebCamExist:]) != 0
}
// Parse Speed (offset 452, 4 bytes)
if len(data) >= OffsetSpeed+4 {
info.Speed = binary.LittleEndian.Uint32(data[OffsetSpeed:])
}
// Parse Start time (offset 456, 20 bytes)
if len(data) >= OffsetStartTime+20 {
info.StartTime = gbkToUTF8(data[OffsetStartTime : OffsetStartTime+20])
}
// Parse Reserved (offset 476, 512 bytes) - contains additional info
if len(data) >= OffsetReserved+512 {
info.Reserved = gbkToUTF8(data[OffsetReserved : OffsetReserved+512])
} else if len(data) > OffsetReserved {
info.Reserved = gbkToUTF8(data[OffsetReserved:])
}
return info, nil
}
// parseOsVersionInfo parses the OS version info field
// The C++ client fills this with a readable string like "Windows 10" via getSystemName()
func parseOsVersionInfo(data []byte) string {
return gbkToUTF8(data)
}
// ParseReserved parses the reserved field into a slice of strings
func (info *LoginInfo) ParseReserved() []string {
if info.Reserved == "" {
return nil
}
return strings.Split(info.Reserved, "|")
}
// GetReservedField returns a specific field from reserved data by index
// Fields: ClientType(0), SystemBits(1), CPU(2), Memory(3), FilePath(4),
// Reserved(5), InstallTime(6), InstallInfo(7), ProgramBits(8), ExpiredDate(9),
// ClientLoc(10), ClientPubIP(11), ExeVersion(12), Username(13), IsAdmin(14)
func (info *LoginInfo) GetReservedField(index int) string {
fields := info.ParseReserved()
if index >= 0 && index < len(fields) {
return fields[index]
}
return ""
}
// Validation structure for TOKEN_AUTH
type Validation struct {
From string // Start date
To string // End date
Admin string // Admin address
Port int // Admin port
Checksum string // Reserved field
}
// BuildValidation creates a validation response
func BuildValidation(days float64, admin string, port int) []byte {
// This would build the validation structure
// For now, return a simple structure
data := make([]byte, 160) // Size of Validation struct
data[0] = TokenAuth
// Fill in fields...
// From: 20 bytes
// To: 20 bytes
// Admin: 100 bytes
// Port: 4 bytes
// Checksum: 16 bytes
return data
}

View File

@@ -0,0 +1,267 @@
package protocol
// Header encoding/decoding functions
// Ported from common/header.h and common/encfuncs.h
const (
MsgHeader = "HELL"
FlagCompLen = 4
FlagLength = 8
HdrLength = FlagLength + 8 // FLAG_LENGTH + 2 * sizeof(uint32)
MinComLen = 12
)
// HeaderEncType represents the encryption method used for header
type HeaderEncType int
const (
HeaderEncUnknown HeaderEncType = -1
HeaderEncNone HeaderEncType = 0
HeaderEncV0 HeaderEncType = 1
HeaderEncV1 HeaderEncType = 2
HeaderEncV2 HeaderEncType = 3
HeaderEncV3 HeaderEncType = 4
HeaderEncV4 HeaderEncType = 5
HeaderEncV5 HeaderEncType = 6
HeaderEncV6 HeaderEncType = 7
)
// DecryptFunc is the function signature for header decryption
type DecryptFunc func(data []byte, key byte)
// defaultDecrypt does nothing (no encryption)
func defaultDecrypt(data []byte, key byte) {
// No-op
}
// decrypt is the default encryption method (V0)
func decrypt(data []byte, key byte) {
if key == 0 {
return
}
for i := 0; i < len(data); i++ {
k := key ^ byte(i*31)
value := int(data[i])
switch i % 4 {
case 0:
value -= int(k)
case 1:
value = value ^ int(k)
case 2:
value += int(k)
case 3:
value = ^value ^ int(k)
}
data[i] = byte(value & 0xFF)
}
}
// decryptV1 - alternating add/subtract
func decryptV1(data []byte, key byte) {
for i := 0; i < len(data); i++ {
if i%2 == 0 {
data[i] = data[i] - key
} else {
data[i] = data[i] + key
}
}
}
// decryptV2 - XOR with rotation
func decryptV2(data []byte, key byte) {
for i := 0; i < len(data); i++ {
if i%2 == 0 {
data[i] = (data[i] >> 1) | (data[i] << 7) // rotate right
} else {
data[i] = (data[i] << 1) | (data[i] >> 7) // rotate left
}
data[i] ^= key
}
}
// decryptV3 - dynamic key with position
func decryptV3(data []byte, key byte) {
for i := 0; i < len(data); i++ {
dynamicKey := key + byte(i%8)
switch i % 3 {
case 0:
data[i] = (data[i] ^ dynamicKey) - dynamicKey
case 1:
data[i] = (data[i] + dynamicKey) ^ dynamicKey
case 2:
data[i] = ^data[i] - dynamicKey
}
}
}
// decryptV4 - pseudo-random XOR (symmetric)
func decryptV4(data []byte, key byte) {
rand := uint16(key)
for i := 0; i < len(data); i++ {
rand = (rand*13 + 17) % 256
data[i] ^= byte(rand)
}
}
// decryptV5 - dynamic key with bit shift
func decryptV5(data []byte, key byte) {
for i := 0; i < len(data); i++ {
dynamicKey := (key + byte(i)) ^ 0x55
data[i] = ((data[i] - byte(i%7)) ^ (dynamicKey << 3)) - dynamicKey
}
}
// decryptV6 - pseudo-random with position (symmetric)
func decryptV6(data []byte, key byte) {
rand := uint16(key)
for i := 0; i < len(data); i++ {
rand = (rand*31 + 17) % 256
data[i] ^= byte(rand) + byte(i)
}
}
// All decrypt methods
var decryptMethods = []DecryptFunc{
defaultDecrypt,
decrypt,
decryptV1,
decryptV2,
decryptV3,
decryptV4,
decryptV5,
decryptV6,
}
// compare decrypts and compares the flag with magic
func compare(flag []byte, magic string, length int, dec DecryptFunc, key byte) bool {
if len(flag) < MinComLen {
return false
}
buf := make([]byte, MinComLen)
copy(buf, flag[:MinComLen])
dec(buf[:length], key)
return string(buf[:length]) == magic
}
// CheckHead tries all decryption methods to identify the protocol
func CheckHead(flag []byte) (flagType FlagType, encType HeaderEncType, decrypted []byte) {
if len(flag) < MinComLen {
return FlagUnknown, HeaderEncUnknown, nil
}
for i, method := range decryptMethods {
buf := make([]byte, MinComLen)
copy(buf, flag[:MinComLen])
ft := checkHeadWithMethod(buf, method)
if ft != FlagUnknown {
return ft, HeaderEncType(i), buf
}
}
return FlagUnknown, HeaderEncUnknown, nil
}
// checkHeadWithMethod checks the flag with a specific decrypt method
func checkHeadWithMethod(flag []byte, dec DecryptFunc) FlagType {
// Try HELL (FLAG_HELL)
if len(flag) >= FlagLength {
buf := make([]byte, MinComLen)
copy(buf, flag)
key := buf[6]
dec(buf[:FlagCompLen], key)
if string(buf[:4]) == MsgHeader {
copy(flag, buf)
return FlagHell
}
}
// Try Shine (FLAG_SHINE)
if len(flag) >= 5 {
buf := make([]byte, MinComLen)
copy(buf, flag)
dec(buf[:5], 0)
if string(buf[:5]) == "Shine" {
copy(flag, buf)
return FlagShine
}
}
// Try <<FUCK>> (FLAG_FUCK)
if len(flag) >= 10 {
buf := make([]byte, MinComLen)
copy(buf, flag)
key := buf[9]
dec(buf[:8], key)
if string(buf[:8]) == "<<FUCK>>" {
copy(flag, buf)
return FlagFuck
}
}
// Try Hello? (FLAG_HELLO)
if len(flag) >= 7 {
buf := make([]byte, MinComLen)
copy(buf, flag)
key := buf[6]
dec(buf[:6], key)
if string(buf[:6]) == "Hello?" {
copy(flag, buf)
return FlagHello
}
}
return FlagUnknown
}
// FlagType represents the protocol type
type FlagType int
const (
FlagWinOS FlagType = -1
FlagUnknown FlagType = 0
FlagShine FlagType = 1
FlagFuck FlagType = 2
FlagHello FlagType = 3
FlagHell FlagType = 4
)
func (f FlagType) String() string {
switch f {
case FlagWinOS:
return "WinOS"
case FlagShine:
return "Shine"
case FlagFuck:
return "Fuck"
case FlagHello:
return "Hello"
case FlagHell:
return "Hell"
default:
return "Unknown"
}
}
// GetFlagLength returns the flag length for a given flag type
func GetFlagLength(ft FlagType) int {
switch ft {
case FlagShine:
return 5
case FlagFuck:
return 11 // 8 + 3
case FlagHello:
return 8
case FlagHell:
return FlagLength
default:
return 0
}
}
// GetHeaderLength returns the full header length for a given flag type
func GetHeaderLength(ft FlagType) int {
return GetFlagLength(ft) + 8
}

View File

@@ -0,0 +1,261 @@
package protocol
import (
"encoding/binary"
"errors"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/connection"
)
// Errors
var (
ErrNeedMore = errors.New("need more data")
ErrInvalidData = errors.New("invalid data")
ErrUnsupported = errors.New("unsupported protocol")
ErrDecompress = errors.New("decompression failed")
)
// Parser handles protocol parsing
type Parser struct {
codec *Codec
}
// NewParser creates a new parser
func NewParser() *Parser {
return &Parser{
codec: NewCodec(),
}
}
// Close releases resources held by the parser
func (p *Parser) Close() {
if p.codec != nil {
p.codec.Close()
}
}
// Parse attempts to parse a complete packet from the buffer
func (p *Parser) Parse(ctx *connection.Context) ([]byte, error) {
buf := ctx.InBuffer
// Need at least minimum bytes to determine protocol
if buf.Len() < MinComLen {
return nil, ErrNeedMore
}
// Check if header is already parsed
if ctx.FlagType == connection.FlagUnknown {
// Try to parse header
if err := p.parseHeader(ctx); err != nil {
return nil, err
}
}
// Now parse the packet
return p.parsePacket(ctx)
}
// parseHeader parses the protocol header with obfuscation handling
func (p *Parser) parseHeader(ctx *connection.Context) error {
buf := ctx.InBuffer
header := buf.Peek(MinComLen)
if header == nil || len(header) < MinComLen {
return ErrNeedMore
}
// Try to decode the header using all encryption methods
flagType, encType, decrypted := CheckHead(header)
if flagType == FlagUnknown {
return ErrUnsupported
}
// Store decrypted header params for later use (for XOREncoder16)
ctx.HeaderParams = make([]byte, MinComLen)
if decrypted != nil {
copy(ctx.HeaderParams, decrypted)
} else {
copy(ctx.HeaderParams, header)
}
// Map protocol FlagType to connection FlagType and set compression method
switch flagType {
case FlagHell:
ctx.FlagType = connection.FlagHell
ctx.FlagLen = FlagLength
ctx.HeaderLen = ctx.FlagLen + 8
ctx.CompressMethod = connection.CompressZstd // HELL uses ZSTD
case FlagHello:
ctx.FlagType = connection.FlagHello
ctx.FlagLen = 8
ctx.HeaderLen = ctx.FlagLen + 8
ctx.CompressMethod = connection.CompressNone // HELLO uses no compression
case FlagShine:
ctx.FlagType = connection.FlagShine
ctx.FlagLen = 5
ctx.HeaderLen = ctx.FlagLen + 8
ctx.CompressMethod = connection.CompressZstd // SHINE uses ZSTD
case FlagFuck:
ctx.FlagType = connection.FlagFuck
ctx.FlagLen = 11
ctx.HeaderLen = ctx.FlagLen + 8
ctx.CompressMethod = connection.CompressZstd // FUCK uses ZSTD
default:
return ErrUnsupported
}
// Store encryption type for later use
ctx.HeaderEncType = int(encType)
return nil
}
// parsePacket parses a complete packet
func (p *Parser) parsePacket(ctx *connection.Context) ([]byte, error) {
buf := ctx.InBuffer
// Check if we have enough data for header
if buf.Len() < ctx.HeaderLen {
return nil, ErrNeedMore
}
// Peek the header to get total length
headerData := buf.Peek(ctx.HeaderLen)
if headerData == nil {
return nil, ErrNeedMore
}
// Decrypt the header first
decryptedHeader := make([]byte, len(headerData))
copy(decryptedHeader, headerData)
// Get the encryption key (usually at position 6)
var key byte
if len(headerData) > 6 {
key = headerData[6] // Use original key before decryption
}
// Decrypt flag portion
if ctx.HeaderEncType >= 0 && ctx.HeaderEncType < len(decryptMethods) {
decryptMethods[ctx.HeaderEncType](decryptedHeader[:FlagCompLen], key)
}
// Read the total length field (after flag)
totalLen := binary.LittleEndian.Uint32(decryptedHeader[ctx.FlagLen:])
// Validate length
if totalLen < uint32(ctx.HeaderLen) || totalLen > 10*1024*1024 {
return nil, ErrInvalidData
}
// Check if we have the complete packet
if buf.Len() < int(totalLen) {
return nil, ErrNeedMore
}
// Read the complete packet
packet := buf.Read(int(totalLen))
if packet == nil {
return nil, ErrInvalidData
}
// Decrypt header portion of packet
if ctx.HeaderEncType >= 0 && ctx.HeaderEncType < len(decryptMethods) {
decryptMethods[ctx.HeaderEncType](packet[:FlagCompLen], key)
}
// Update HeaderParams with this packet's header (for XOREncoder16 k1, k2)
if len(packet) >= ctx.FlagLen {
ctx.HeaderParams = make([]byte, ctx.HeaderLen)
copy(ctx.HeaderParams, packet[:ctx.HeaderLen])
}
// Extract data after header
dataStart := ctx.HeaderLen
data := packet[dataStart:]
// Get original length (before compression)
var origLen uint32
if ctx.FlagType != connection.FlagWinOS {
origLen = binary.LittleEndian.Uint32(packet[ctx.FlagLen+4 : ctx.FlagLen+8])
}
// Decode (XOR, etc.) before decompression - this is Encoder2
p.codec.Decode(ctx, data)
// Decompress
decompressed, err := p.codec.Decompress(ctx, data, origLen)
if err != nil {
return nil, err
}
// Decode after decompression - this is Encoder
p.codec.DecodeData(ctx, decompressed)
return decompressed, nil
}
// Encode encodes data for sending
func (p *Parser) Encode(ctx *connection.Context, data []byte) ([]byte, error) {
// Encode before compression
encoded := make([]byte, len(data))
copy(encoded, data)
p.codec.EncodeData(ctx, encoded)
// Compress
compressed, err := p.codec.Compress(ctx, encoded)
if err != nil {
return nil, err
}
// Build packet
packet := p.buildPacket(ctx, compressed, uint32(len(data)))
return packet, nil
}
// buildPacket builds a complete packet with header
func (p *Parser) buildPacket(ctx *connection.Context, data []byte, origLen uint32) []byte {
totalLen := ctx.HeaderLen + len(data)
packet := make([]byte, totalLen)
// Write flag
flag := p.getFlag(ctx)
copy(packet[:ctx.FlagLen], flag)
// Write total length
binary.LittleEndian.PutUint32(packet[ctx.FlagLen:], uint32(totalLen))
// Write original length
binary.LittleEndian.PutUint32(packet[ctx.FlagLen+4:], origLen)
// Write data
copy(packet[ctx.HeaderLen:], data)
// Encode after building
p.codec.Encode(ctx, packet[ctx.HeaderLen:])
return packet
}
// getFlag returns the protocol flag bytes
func (p *Parser) getFlag(ctx *connection.Context) []byte {
switch ctx.FlagType {
case connection.FlagHell:
flag := make([]byte, FlagLength)
copy(flag, []byte(MsgHeader))
return flag
case connection.FlagHello:
flag := make([]byte, 8)
copy(flag, []byte("Hello?"))
return flag
case connection.FlagShine:
return []byte("Shine")
case connection.FlagFuck:
flag := make([]byte, 11)
copy(flag, []byte("<<FUCK>>"))
return flag
default:
return make([]byte, ctx.FlagLen)
}
}

316
server/go/server/server.go Normal file
View File

@@ -0,0 +1,316 @@
package server
import (
"context"
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/connection"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/logger"
"github.com/yuanyuanxiang/SimpleRemoter/server/go/protocol"
)
// Config holds server configuration
type Config struct {
Port int
MaxConnections int
ReadBufferSize int
WriteBufferSize int
KeepAliveTime time.Duration
ReadTimeout time.Duration
WriteTimeout time.Duration
}
// DefaultConfig returns default configuration
func DefaultConfig() Config {
return Config{
Port: 6543,
MaxConnections: 9999,
ReadBufferSize: 8192,
WriteBufferSize: 8192,
KeepAliveTime: time.Minute * 5,
ReadTimeout: time.Minute * 2,
WriteTimeout: time.Second * 30,
}
}
// Handler defines the interface for handling client events
type Handler interface {
OnConnect(ctx *connection.Context)
OnDisconnect(ctx *connection.Context)
OnReceive(ctx *connection.Context, data []byte)
}
// Server is the TCP server
type Server struct {
config Config
listener net.Listener
manager *connection.Manager
handler Handler
parser *protocol.Parser
running atomic.Bool
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
// Logger
logger *logger.Logger
}
// New creates a new server
func New(config Config) *Server {
ctx, cancel := context.WithCancel(context.Background())
s := &Server{
config: config,
manager: connection.NewManager(config.MaxConnections),
parser: protocol.NewParser(),
ctx: ctx,
cancel: cancel,
logger: logger.New(logger.DefaultConfig()).WithPrefix("Server"),
}
return s
}
// SetHandler sets the event handler
func (s *Server) SetHandler(h Handler) {
s.handler = h
s.manager.SetCallbacks(
func(ctx *connection.Context) {
if s.handler != nil {
s.handler.OnConnect(ctx)
}
},
func(ctx *connection.Context) {
if s.handler != nil {
s.handler.OnDisconnect(ctx)
}
},
func(ctx *connection.Context, data []byte) {
if s.handler != nil {
s.handler.OnReceive(ctx, data)
}
},
)
}
// SetLogger sets the logger
func (s *Server) SetLogger(l *logger.Logger) {
s.logger = l
}
// Start starts the server
func (s *Server) Start() error {
addr := net.JoinHostPort("0.0.0.0", itoa(s.config.Port))
listener, err := net.Listen("tcp", addr)
if err != nil {
return err
}
s.listener = listener
s.running.Store(true)
s.logger.Info("Server started on port %d", s.config.Port)
s.wg.Add(1)
go s.acceptLoop()
return nil
}
func itoa(i int) string {
if i == 0 {
return "0"
}
var b [20]byte
pos := len(b)
neg := i < 0
if neg {
// Handle math.MinInt safely by using unsigned
u := uint(-i)
for u > 0 {
pos--
b[pos] = byte('0' + u%10)
u /= 10
}
pos--
b[pos] = '-'
} else {
for i > 0 {
pos--
b[pos] = byte('0' + i%10)
i /= 10
}
}
return string(b[pos:])
}
// Stop stops the server gracefully
func (s *Server) Stop() {
if !s.running.Swap(false) {
return
}
s.cancel()
if s.listener != nil {
_ = s.listener.Close()
}
// Close all connections
s.manager.CloseAll()
// Close parser resources
if s.parser != nil {
s.parser.Close()
}
s.wg.Wait()
s.logger.Info("Server stopped")
}
// acceptLoop accepts incoming connections
func (s *Server) acceptLoop() {
defer s.wg.Done()
for s.running.Load() {
conn, err := s.listener.Accept()
if err != nil {
if s.running.Load() {
s.logger.Error("Accept error: %v", err)
}
continue
}
// Check connection limit before spawning goroutine
if s.manager.Count() >= s.config.MaxConnections {
s.logger.Warn("Max connections reached, rejecting new connection from %s", conn.RemoteAddr())
_ = conn.Close()
continue
}
// Handle each connection in its own goroutine
go s.handleConnection(conn)
}
}
// handleConnection handles a single connection
func (s *Server) handleConnection(conn net.Conn) {
// Create context
ctx := connection.NewContext(conn, nil)
// Add to manager
if err := s.manager.Add(ctx); err != nil {
s.logger.Warn("Failed to add connection: %v", err)
_ = conn.Close()
return
}
defer func() {
s.manager.Remove(ctx)
_ = ctx.Close()
}()
// Read loop
buf := make([]byte, s.config.ReadBufferSize)
for !ctx.IsClosed() && s.running.Load() {
// Set read deadline
if s.config.ReadTimeout > 0 {
_ = conn.SetReadDeadline(time.Now().Add(s.config.ReadTimeout))
}
n, err := conn.Read(buf)
if err != nil {
if err != io.EOF && s.running.Load() {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
// Timeout - check keepalive
if s.config.KeepAliveTime > 0 && ctx.TimeSinceLastActive() > s.config.KeepAliveTime {
s.logger.Info("Connection %d timed out", ctx.ID)
break
}
continue
}
}
break
}
if n > 0 {
ctx.UpdateLastActive()
// Write to input buffer
_, _ = ctx.InBuffer.Write(buf[:n])
// Process received data
s.processData(ctx)
}
}
}
// processData processes received data and calls handler
func (s *Server) processData(ctx *connection.Context) {
for ctx.InBuffer.Len() > 0 {
// Try to parse a complete packet
data, err := s.parser.Parse(ctx)
if err != nil {
if err == protocol.ErrNeedMore {
return
}
s.logger.Error("Parse error for connection %d: %v", ctx.ID, err)
_ = ctx.Close()
return
}
if data != nil {
// Call handler
s.manager.OnReceive(ctx, data)
}
}
}
// Send sends data to a specific connection
func (s *Server) Send(ctx *connection.Context, data []byte) error {
if ctx == nil || ctx.IsClosed() {
return connection.ErrConnectionClosed
}
// Encode and compress data
encoded, err := s.parser.Encode(ctx, data)
if err != nil {
return err
}
return ctx.Send(encoded)
}
// Broadcast sends data to all connections
func (s *Server) Broadcast(data []byte) {
s.manager.Range(func(ctx *connection.Context) bool {
_ = s.Send(ctx, data)
return true
})
}
// GetConnection returns a connection by ID
func (s *Server) GetConnection(id uint64) *connection.Context {
return s.manager.Get(id)
}
// ConnectionCount returns the current connection count
func (s *Server) ConnectionCount() int {
return s.manager.Count()
}
// Port returns the server port
func (s *Server) Port() int {
return s.config.Port
}
// UpdateMaxConnections updates the max connections limit
func (s *Server) UpdateMaxConnections(max int) {
s.manager.UpdateMaxConnections(max)
}