mirror of
https://github.com/yuanyuanxiang/SimpleRemoter.git
synced 2026-01-21 23:13:08 +08:00
Feature: Add Go TCP server framework
This commit is contained in:
30
server/go/.vscode/launch.json
vendored
Normal file
30
server/go/.vscode/launch.json
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
{
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Launch Server",
|
||||
"type": "go",
|
||||
"request": "launch",
|
||||
"mode": "auto",
|
||||
"program": "${workspaceFolder}/cmd",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [],
|
||||
"env": {},
|
||||
"console": "integratedTerminal"
|
||||
},
|
||||
{
|
||||
"name": "Debug Server",
|
||||
"type": "go",
|
||||
"request": "launch",
|
||||
"mode": "debug",
|
||||
"program": "${workspaceFolder}/cmd",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"args": [
|
||||
"-port=9090"
|
||||
],
|
||||
"env": {},
|
||||
"console": "integratedTerminal",
|
||||
"buildFlags": "-gcflags='all=-N -l'"
|
||||
}
|
||||
]
|
||||
}
|
||||
333
server/go/README.md
Normal file
333
server/go/README.md
Normal file
@@ -0,0 +1,333 @@
|
||||
# SimpleRemoter Go TCP Server Framework
|
||||
|
||||
基于 Go 语言实现的高性能 TCP 服务端框架,用于替代原有的 C++ IOCP 服务端。
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
server/go/
|
||||
├── go.mod # Go 模块定义
|
||||
├── buffer/
|
||||
│ └── buffer.go # 线程安全的动态缓冲区
|
||||
├── connection/
|
||||
│ ├── context.go # 连接上下文
|
||||
│ ├── errors.go # 错误定义
|
||||
│ └── manager.go # 连接管理器
|
||||
├── protocol/
|
||||
│ ├── parser.go # 协议解析器
|
||||
│ ├── codec.go # 编解码和压缩 (ZSTD)
|
||||
│ ├── header.go # 协议头解密 (8种加密方式)
|
||||
│ └── commands.go # 命令常量和LOGIN_INFOR解析
|
||||
├── server/
|
||||
│ ├── server.go # TCP 服务器核心
|
||||
│ └── pool.go # Goroutine 工作池
|
||||
├── logger/
|
||||
│ └── logger.go # 日志模块 (基于 zerolog)
|
||||
└── cmd/
|
||||
└── main.go # 程序入口
|
||||
```
|
||||
|
||||
## 核心特性
|
||||
|
||||
- **高并发**: 基于 Goroutine 池管理并发连接
|
||||
- **协议兼容**: 支持原有 C++ 客户端的多种协议标识 (Hell/Hello/Shine/Fuck)
|
||||
- **协议头解密**: 支持8种协议头加密方式 (V0-V6 + Default)
|
||||
- **XOR编码**: 支持 XOREncoder16 数据编码/解码
|
||||
- **ZSTD 压缩**: 使用高效的 ZSTD 算法进行数据压缩
|
||||
- **GBK编码**: 自动将 Windows 客户端的 GBK 编码转换为 UTF-8
|
||||
- **线程安全**: Buffer、连接管理器和 LastActive 均为线程安全设计
|
||||
- **优雅关闭**: 支持信号处理和优雅停机,自动释放资源
|
||||
- **可配置**: 支持自定义端口、最大连接数、超时时间等
|
||||
- **日志系统**: 基于 zerolog,支持文件输出、日志轮转、客户端上下线记录
|
||||
|
||||
## 支持的命令
|
||||
|
||||
当前已实现以下命令处理:
|
||||
|
||||
| 命令 | 值 | 说明 |
|
||||
|------|-----|------|
|
||||
| TOKEN_AUTH | 100 | 授权请求 |
|
||||
| TOKEN_HEARTBEAT | 101 | 心跳包 |
|
||||
| TOKEN_LOGIN | 102 | 客户端登录 |
|
||||
|
||||
其他命令会被记录为 Debug 日志,可按需扩展。
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 安装依赖
|
||||
|
||||
```bash
|
||||
cd server/go
|
||||
go mod tidy
|
||||
```
|
||||
|
||||
### 编译
|
||||
|
||||
```bash
|
||||
go build -o simpleremoter-server ./cmd
|
||||
```
|
||||
|
||||
### 运行
|
||||
|
||||
```bash
|
||||
./simpleremoter-server
|
||||
```
|
||||
|
||||
服务器默认监听 6543 端口,日志输出到 `logs/server.log`。
|
||||
|
||||
## 使用示例
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/yuanyuanxiang/SimpleRemoter/server/go/connection"
|
||||
"github.com/yuanyuanxiang/SimpleRemoter/server/go/logger"
|
||||
"github.com/yuanyuanxiang/SimpleRemoter/server/go/protocol"
|
||||
"github.com/yuanyuanxiang/SimpleRemoter/server/go/server"
|
||||
)
|
||||
|
||||
// 实现 Handler 接口
|
||||
type MyHandler struct {
|
||||
log *logger.Logger
|
||||
}
|
||||
|
||||
func (h *MyHandler) OnConnect(ctx *connection.Context) {
|
||||
h.log.ClientEvent("online", ctx.ID, ctx.GetPeerIP())
|
||||
}
|
||||
|
||||
func (h *MyHandler) OnDisconnect(ctx *connection.Context) {
|
||||
h.log.ClientEvent("offline", ctx.ID, ctx.GetPeerIP())
|
||||
}
|
||||
|
||||
func (h *MyHandler) OnReceive(ctx *connection.Context, data []byte) {
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
cmd := data[0]
|
||||
switch cmd {
|
||||
case protocol.TokenLogin:
|
||||
info, _ := protocol.ParseLoginInfo(data)
|
||||
h.log.Info("Client login: %s (%s)", info.PCName, info.OsVerInfo)
|
||||
case protocol.TokenHeartbeat:
|
||||
h.log.Debug("Heartbeat from client %d", ctx.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
// 配置日志 (控制台 + 文件)
|
||||
logCfg := logger.DefaultConfig()
|
||||
logCfg.File = "logs/server.log"
|
||||
log := logger.New(logCfg)
|
||||
|
||||
// 配置服务器
|
||||
config := server.DefaultConfig()
|
||||
config.Port = 6543
|
||||
|
||||
// 创建并启动服务器
|
||||
srv := server.New(config)
|
||||
srv.SetLogger(log.WithPrefix("Server"))
|
||||
srv.SetHandler(&MyHandler{log: log})
|
||||
|
||||
if err := srv.Start(); err != nil {
|
||||
log.Fatal("启动失败: %v", err)
|
||||
}
|
||||
|
||||
// 等待退出信号
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-sigChan
|
||||
|
||||
srv.Stop()
|
||||
}
|
||||
```
|
||||
|
||||
## 配置选项
|
||||
|
||||
| 配置项 | 默认值 | 说明 |
|
||||
|--------|--------|------|
|
||||
| Port | 8080 | 监听端口 |
|
||||
| MaxConnections | 10000 | 最大连接数 |
|
||||
| MinWorkers | 4 | 最小工作协程数 |
|
||||
| MaxWorkers | 100 | 最大工作协程数 |
|
||||
| ReadBufferSize | 8192 | 读缓冲区大小 |
|
||||
| WriteBufferSize | 8192 | 写缓冲区大小 |
|
||||
| KeepAliveTime | 5min | 连接保活时间 |
|
||||
| ReadTimeout | 2min | 读超时时间 |
|
||||
| WriteTimeout | 30s | 写超时时间 |
|
||||
|
||||
## 日志配置
|
||||
|
||||
| 配置项 | 默认值 | 说明 |
|
||||
|--------|--------|------|
|
||||
| Level | Info | 日志级别 (Debug/Info/Warn/Error/Fatal) |
|
||||
| Console | true | 是否输出到控制台 |
|
||||
| File | "" | 日志文件路径 (空则不写文件) |
|
||||
| MaxSize | 100 | 单个日志文件最大 MB |
|
||||
| MaxBackups | 3 | 保留的旧日志文件数量 |
|
||||
| MaxAge | 30 | 旧日志保留天数 |
|
||||
| Compress | true | 是否压缩轮转的日志 |
|
||||
|
||||
日志示例输出:
|
||||
```json
|
||||
{"level":"info","module":"Server","time":"2025-12-19T13:17:32+01:00","message":"Server started on port 6543"}
|
||||
{"level":"info","module":"Handler","event":"login","client_id":1,"ip":"192.168.0.92","computer":"DESKTOP-BI6RGEJ","os":"Windows 10","version":"Dec 19 2025","time":"2025-12-19T13:17:32+01:00"}
|
||||
{"level":"debug","module":"Handler","time":"2025-12-19T13:17:47+01:00","message":"Heartbeat from client 1 (DESKTOP-BI6RGEJ)"}
|
||||
```
|
||||
|
||||
## 协议格式
|
||||
|
||||
数据包格式与 C++ 版本兼容:
|
||||
|
||||
```
|
||||
+----------+------------+------------+------------------+
|
||||
| Flag | TotalLen | OrigLen | Compressed Data |
|
||||
| (N bytes)| (4 bytes) | (4 bytes) | (variable) |
|
||||
+----------+------------+------------+------------------+
|
||||
```
|
||||
|
||||
### 协议标识
|
||||
|
||||
| 标识 | Flag长度 | 压缩方式 | 说明 |
|
||||
|------|----------|----------|------|
|
||||
| HELL | 8 bytes | ZSTD | 主要协议 |
|
||||
| Hello? | 8 bytes | None | 无压缩协议 |
|
||||
| Shine | 5 bytes | ZSTD | 备用协议 |
|
||||
| <<FUCK>> | 11 bytes | ZSTD | 备用协议 |
|
||||
|
||||
### 协议头加密
|
||||
|
||||
支持8种加密方式,服务端自动检测并解密:
|
||||
- V0 (Default): 动态密钥,4种操作
|
||||
- V1: 交替加减
|
||||
- V2: 带旋转的异或
|
||||
- V3: 带位置的动态密钥
|
||||
- V4: 对称的伪随机异或
|
||||
- V5: 带位移的动态密钥
|
||||
- V6: 带位置的伪随机
|
||||
- V7: 纯异或
|
||||
|
||||
### LOGIN_INFOR 结构
|
||||
|
||||
客户端登录信息结构体 (考虑 C++ 内存对齐):
|
||||
|
||||
| 字段 | 偏移 | 大小 | 说明 |
|
||||
|------|------|------|------|
|
||||
| bToken | 0 | 1 | 命令标识 (102) |
|
||||
| OsVerInfoEx | 1 | 156 | 操作系统版本 |
|
||||
| (padding) | 157 | 3 | 对齐填充 |
|
||||
| dwCPUMHz | 160 | 4 | CPU 频率 |
|
||||
| moduleVersion | 164 | 24 | 模块版本 |
|
||||
| szPCName | 188 | 240 | 计算机名 |
|
||||
| szMasterID | 428 | 20 | 主控 ID |
|
||||
| bWebCamExist | 448 | 4 | 是否有摄像头 |
|
||||
| dwSpeed | 452 | 4 | 网速 |
|
||||
| szStartTime | 456 | 20 | 启动时间 |
|
||||
| szReserved | 476 | 512 | 扩展字段 (用`|`分隔) |
|
||||
|
||||
## API 参考
|
||||
|
||||
### Server
|
||||
|
||||
```go
|
||||
// 创建服务器
|
||||
srv := server.New(config)
|
||||
|
||||
// 设置日志
|
||||
srv.SetLogger(log)
|
||||
|
||||
// 设置事件处理器
|
||||
srv.SetHandler(handler)
|
||||
|
||||
// 启动服务器
|
||||
srv.Start()
|
||||
|
||||
// 停止服务器
|
||||
srv.Stop()
|
||||
|
||||
// 发送数据到指定连接
|
||||
srv.Send(ctx, data)
|
||||
|
||||
// 广播数据到所有连接
|
||||
srv.Broadcast(data)
|
||||
|
||||
// 获取当前连接数
|
||||
count := srv.ConnectionCount()
|
||||
```
|
||||
|
||||
### Connection Context
|
||||
|
||||
```go
|
||||
// 发送数据
|
||||
ctx.Send(data)
|
||||
|
||||
// 关闭连接
|
||||
ctx.Close()
|
||||
|
||||
// 获取客户端 IP
|
||||
ip := ctx.GetPeerIP()
|
||||
|
||||
// 检查连接状态
|
||||
closed := ctx.IsClosed()
|
||||
|
||||
// 获取/更新最后活跃时间 (线程安全)
|
||||
lastActive := ctx.LastActive()
|
||||
ctx.UpdateLastActive()
|
||||
duration := ctx.TimeSinceLastActive()
|
||||
|
||||
// 设置/获取客户端信息
|
||||
ctx.SetInfo(clientInfo)
|
||||
info := ctx.GetInfo()
|
||||
|
||||
// 设置/获取用户数据
|
||||
ctx.SetUserData(myData)
|
||||
data := ctx.GetUserData()
|
||||
```
|
||||
|
||||
### Protocol
|
||||
|
||||
```go
|
||||
// 解析登录信息
|
||||
info, err := protocol.ParseLoginInfo(data)
|
||||
if err == nil {
|
||||
fmt.Println(info.PCName) // 计算机名
|
||||
fmt.Println(info.OsVerInfo) // 操作系统
|
||||
fmt.Println(info.ModuleVersion) // 版本
|
||||
fmt.Println(info.WebCamExist) // 是否有摄像头
|
||||
}
|
||||
|
||||
// 获取扩展字段
|
||||
reserved := info.ParseReserved() // 返回 []string
|
||||
clientType := info.GetReservedField(0) // 客户端类型
|
||||
cpuCores := info.GetReservedField(2) // CPU 核数
|
||||
filePath := info.GetReservedField(4) // 文件路径
|
||||
publicIP := info.GetReservedField(11) // 公网 IP
|
||||
```
|
||||
|
||||
## 与 C++ 版本对比
|
||||
|
||||
| 特性 | C++ (IOCP) | Go |
|
||||
|------|------------|-----|
|
||||
| 并发模型 | IOCP + 线程池 | Goroutine 池 |
|
||||
| 压缩算法 | ZSTD | ZSTD |
|
||||
| 跨平台 | Windows | 全平台 |
|
||||
| 内存管理 | 手动 | GC |
|
||||
| 代码复杂度 | 高 | 低 |
|
||||
| 协议头解密 | 8种方式 | 8种方式 |
|
||||
| XOR编码 | XOREncoder16 | XOREncoder16 |
|
||||
| 字符编码 | GBK | GBK -> UTF-8 |
|
||||
|
||||
## 依赖
|
||||
|
||||
- [github.com/klauspost/compress/zstd](https://github.com/klauspost/compress) - ZSTD 压缩
|
||||
- [github.com/rs/zerolog](https://github.com/rs/zerolog) - 高性能日志
|
||||
- [gopkg.in/natefinch/lumberjack.v2](https://github.com/natefinch/lumberjack) - 日志轮转
|
||||
- [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) - GBK 编码转换
|
||||
|
||||
## License
|
||||
|
||||
MIT License
|
||||
208
server/go/auth/auth.go
Normal file
208
server/go/auth/auth.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/text/encoding/simplifiedchinese"
|
||||
"golang.org/x/text/transform"
|
||||
)
|
||||
|
||||
// Config holds authentication configuration
|
||||
type Config struct {
|
||||
PwdHash string // SHA256 hash of the password (64 hex chars)
|
||||
SuperPass string // Super admin password for HMAC verification
|
||||
}
|
||||
|
||||
// DefaultConfig returns default auth configuration
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
PwdHash: "", // Must be configured
|
||||
SuperPass: "", // Can be set via YAMA_PWD env var
|
||||
}
|
||||
}
|
||||
|
||||
// Authenticator handles token authentication
|
||||
type Authenticator struct {
|
||||
config *Config
|
||||
}
|
||||
|
||||
// New creates a new Authenticator
|
||||
func New(config *Config) *Authenticator {
|
||||
return &Authenticator{config: config}
|
||||
}
|
||||
|
||||
// AuthResult contains the result of authentication
|
||||
type AuthResult struct {
|
||||
Valid bool
|
||||
Message string
|
||||
SN string
|
||||
Passcode string
|
||||
}
|
||||
|
||||
// Authenticate validates a TOKEN_AUTH request
|
||||
// Data format:
|
||||
// - offset 0: TOKEN_AUTH command byte
|
||||
// - offset 1-19: SN (serial number, 19 bytes)
|
||||
// - offset 20-62: Passcode (42 bytes)
|
||||
// - offset 62-70: HMAC signature (uint64, 8 bytes) if len > 64
|
||||
func (a *Authenticator) Authenticate(data []byte) *AuthResult {
|
||||
result := &AuthResult{
|
||||
Valid: false,
|
||||
Message: "未获授权或消息哈希校验失败",
|
||||
}
|
||||
|
||||
// Minimum length check: 1 (token) + 19 (sn) + 1 (at least some passcode)
|
||||
if len(data) <= 20 {
|
||||
return result
|
||||
}
|
||||
|
||||
// Extract SN (bytes 1-19)
|
||||
sn := string(data[1:20])
|
||||
result.SN = sn
|
||||
|
||||
// Extract passcode (bytes 20-62, or until end if shorter)
|
||||
passcodeEnd := 62
|
||||
if len(data) < passcodeEnd {
|
||||
passcodeEnd = len(data)
|
||||
}
|
||||
passcode := string(data[20:passcodeEnd])
|
||||
result.Passcode = passcode
|
||||
|
||||
// Extract HMAC if present (bytes 62-70)
|
||||
var hmacSig uint64
|
||||
if len(data) >= 70 {
|
||||
hmacSig = binary.LittleEndian.Uint64(data[62:70])
|
||||
} else if len(data) > 62 {
|
||||
// Partial HMAC data - safely handle incomplete bytes
|
||||
hmacBytes := make([]byte, 8)
|
||||
copy(hmacBytes, data[62:])
|
||||
hmacSig = binary.LittleEndian.Uint64(hmacBytes)
|
||||
}
|
||||
|
||||
// Split passcode by '-'
|
||||
parts := strings.Split(passcode, "-")
|
||||
if len(parts) != 6 && len(parts) != 7 {
|
||||
return result
|
||||
}
|
||||
|
||||
// Get last 4 parts as subvector
|
||||
subvector := parts[len(parts)-4:]
|
||||
|
||||
// Build password string: v[0] + " - " + v[1] + ": " + PwdHash + (optional: ": " + v[2])
|
||||
password := parts[0] + " - " + parts[1] + ": " + a.config.PwdHash
|
||||
if len(parts) == 7 {
|
||||
password += ": " + parts[2]
|
||||
}
|
||||
|
||||
// Derive key from password and SN
|
||||
finalKey := DeriveKey(password, sn)
|
||||
|
||||
// Get fixed length ID
|
||||
hash256 := strings.Join(subvector, "-")
|
||||
fixedKey := GetFixedLengthID(finalKey)
|
||||
|
||||
// Debug output (can be removed in production)
|
||||
// fmt.Printf("DEBUG: password=%q sn=%q finalKey=%s fixedKey=%s hash256=%s\n", password, sn, finalKey, fixedKey, hash256)
|
||||
|
||||
// Compare
|
||||
if hash256 != fixedKey {
|
||||
return result
|
||||
}
|
||||
|
||||
// Passcode validation successful, now verify HMAC
|
||||
superPass := os.Getenv("YAMA_PWD")
|
||||
if superPass == "" {
|
||||
superPass = a.config.SuperPass
|
||||
}
|
||||
|
||||
if superPass != "" && hmacSig != 0 {
|
||||
verified := VerifyMessage(superPass, []byte(passcode), hmacSig)
|
||||
if verified {
|
||||
result.Valid = true
|
||||
result.Message = "此程序已获授权,请遵守授权协议,感谢合作"
|
||||
}
|
||||
// If HMAC verification fails, valid remains false
|
||||
} else if hmacSig == 0 {
|
||||
// No HMAC provided but passcode is valid - could be older client
|
||||
// Keep as invalid for security
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// utf8ToGBK converts UTF-8 string to GBK encoded bytes
|
||||
func utf8ToGBK(s string) []byte {
|
||||
reader := transform.NewReader(bytes.NewReader([]byte(s)), simplifiedchinese.GBK.NewEncoder())
|
||||
buf := new(bytes.Buffer)
|
||||
buf.ReadFrom(reader)
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// BuildResponse builds the 100-byte response for TOKEN_AUTH
|
||||
func (a *Authenticator) BuildResponse(result *AuthResult) []byte {
|
||||
resp := make([]byte, 100)
|
||||
if result.Valid {
|
||||
resp[0] = 1
|
||||
}
|
||||
// Message starts at offset 4, convert UTF-8 to GBK for Windows client
|
||||
gbkMsg := utf8ToGBK(result.Message)
|
||||
copy(resp[4:], gbkMsg)
|
||||
return resp
|
||||
}
|
||||
|
||||
// HashSHA256 computes SHA256 hash and returns hex string
|
||||
func HashSHA256(data string) string {
|
||||
h := sha256.New()
|
||||
h.Write([]byte(data))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// DeriveKey derives a key from password and hardware ID
|
||||
// Format: SHA256(password + " + " + hardwareID)
|
||||
func DeriveKey(password, hardwareID string) string {
|
||||
return HashSHA256(password + " + " + hardwareID)
|
||||
}
|
||||
|
||||
// GetFixedLengthID formats a hash into fixed length ID
|
||||
// Format: xxxx-xxxx-xxxx-xxxx (first 16 chars split by -)
|
||||
func GetFixedLengthID(hash string) string {
|
||||
if len(hash) < 16 {
|
||||
return hash
|
||||
}
|
||||
return hash[0:4] + "-" + hash[4:8] + "-" + hash[8:12] + "-" + hash[12:16]
|
||||
}
|
||||
|
||||
// SignMessage computes HMAC-SHA256 and returns first 8 bytes as uint64
|
||||
func SignMessage(pwd string, msg []byte) uint64 {
|
||||
h := hmac.New(sha256.New, []byte(pwd))
|
||||
h.Write(msg)
|
||||
hash := h.Sum(nil)
|
||||
return binary.LittleEndian.Uint64(hash[:8])
|
||||
}
|
||||
|
||||
// VerifyMessage verifies HMAC signature
|
||||
func VerifyMessage(pwd string, msg []byte, signature uint64) bool {
|
||||
computed := SignMessage(pwd, msg)
|
||||
return computed == signature
|
||||
}
|
||||
|
||||
// GenHMAC generates HMAC for password verification
|
||||
// This matches the C++ genHMAC function
|
||||
func GenHMAC(pwdHash, superPass string) string {
|
||||
key := HashSHA256(superPass)
|
||||
list := []string{"g", "h", "o", "s", "t"}
|
||||
for _, item := range list {
|
||||
key = HashSHA256(key + " - " + item)
|
||||
}
|
||||
result := HashSHA256(pwdHash + " - " + key)
|
||||
if len(result) >= 16 {
|
||||
return result[:16]
|
||||
}
|
||||
return result
|
||||
}
|
||||
185
server/go/buffer/buffer.go
Normal file
185
server/go/buffer/buffer.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package buffer
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Buffer is a thread-safe dynamic buffer for network I/O
|
||||
type Buffer struct {
|
||||
data []byte
|
||||
mu sync.RWMutex
|
||||
offset int // read offset for lazy compaction
|
||||
}
|
||||
|
||||
// New creates a new buffer with optional initial capacity
|
||||
func New(capacity ...int) *Buffer {
|
||||
cap := 4096
|
||||
if len(capacity) > 0 && capacity[0] > 0 {
|
||||
cap = capacity[0]
|
||||
}
|
||||
return &Buffer{
|
||||
data: make([]byte, 0, cap),
|
||||
}
|
||||
}
|
||||
|
||||
// Write appends data to the buffer
|
||||
func (b *Buffer) Write(p []byte) (int, error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
b.data = append(b.data, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// WriteUint32 writes a uint32 in little-endian format
|
||||
func (b *Buffer) WriteUint32(v uint32) {
|
||||
buf := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(buf, v)
|
||||
b.Write(buf)
|
||||
}
|
||||
|
||||
// Read reads and removes data from the buffer
|
||||
func (b *Buffer) Read(n int) []byte {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
available := len(b.data) - b.offset
|
||||
if n > available {
|
||||
n = available
|
||||
}
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make([]byte, n)
|
||||
copy(result, b.data[b.offset:b.offset+n])
|
||||
b.offset += n
|
||||
|
||||
// Compact when offset is large enough
|
||||
if b.offset > len(b.data)/2 && b.offset > 1024 {
|
||||
b.compact()
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Peek returns data without removing it
|
||||
func (b *Buffer) Peek(n int) []byte {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
|
||||
available := len(b.data) - b.offset
|
||||
if n > available {
|
||||
n = available
|
||||
}
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make([]byte, n)
|
||||
copy(result, b.data[b.offset:b.offset+n])
|
||||
return result
|
||||
}
|
||||
|
||||
// PeekAt returns data at a specific offset without removing it
|
||||
func (b *Buffer) PeekAt(offset, n int) []byte {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
|
||||
start := b.offset + offset
|
||||
if start >= len(b.data) {
|
||||
return nil
|
||||
}
|
||||
|
||||
end := start + n
|
||||
if end > len(b.data) {
|
||||
end = len(b.data)
|
||||
}
|
||||
|
||||
result := make([]byte, end-start)
|
||||
copy(result, b.data[start:end])
|
||||
return result
|
||||
}
|
||||
|
||||
// Skip removes n bytes from the beginning
|
||||
func (b *Buffer) Skip(n int) int {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
available := len(b.data) - b.offset
|
||||
if n > available {
|
||||
n = available
|
||||
}
|
||||
b.offset += n
|
||||
|
||||
// Compact when offset is large enough
|
||||
if b.offset > len(b.data)/2 && b.offset > 1024 {
|
||||
b.compact()
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// Len returns the length of unread data
|
||||
func (b *Buffer) Len() int {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
return len(b.data) - b.offset
|
||||
}
|
||||
|
||||
// Bytes returns all unread data without removing it
|
||||
func (b *Buffer) Bytes() []byte {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
|
||||
n := len(b.data) - b.offset
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make([]byte, n)
|
||||
copy(result, b.data[b.offset:])
|
||||
return result
|
||||
}
|
||||
|
||||
// Clear removes all data from the buffer
|
||||
func (b *Buffer) Clear() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
b.data = b.data[:0]
|
||||
b.offset = 0
|
||||
}
|
||||
|
||||
// compact moves remaining data to the beginning
|
||||
func (b *Buffer) compact() {
|
||||
if b.offset > 0 {
|
||||
remaining := len(b.data) - b.offset
|
||||
copy(b.data[:remaining], b.data[b.offset:])
|
||||
b.data = b.data[:remaining]
|
||||
b.offset = 0
|
||||
}
|
||||
}
|
||||
|
||||
// GetByte returns a single byte at offset
|
||||
func (b *Buffer) GetByte(offset int) byte {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
|
||||
idx := b.offset + offset
|
||||
if idx >= len(b.data) {
|
||||
return 0
|
||||
}
|
||||
return b.data[idx]
|
||||
}
|
||||
|
||||
// GetUint32 returns a uint32 at offset in little-endian format
|
||||
func (b *Buffer) GetUint32(offset int) uint32 {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
|
||||
idx := b.offset + offset
|
||||
if idx+4 > len(b.data) {
|
||||
return 0
|
||||
}
|
||||
return binary.LittleEndian.Uint32(b.data[idx : idx+4])
|
||||
}
|
||||
268
server/go/cmd/main.go
Normal file
268
server/go/cmd/main.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/yuanyuanxiang/SimpleRemoter/server/go/auth"
|
||||
"github.com/yuanyuanxiang/SimpleRemoter/server/go/connection"
|
||||
"github.com/yuanyuanxiang/SimpleRemoter/server/go/logger"
|
||||
"github.com/yuanyuanxiang/SimpleRemoter/server/go/protocol"
|
||||
"github.com/yuanyuanxiang/SimpleRemoter/server/go/server"
|
||||
)
|
||||
|
||||
// MyHandler implements the server.Handler interface
|
||||
type MyHandler struct {
|
||||
log *logger.Logger
|
||||
auth *auth.Authenticator
|
||||
srv *server.Server
|
||||
}
|
||||
|
||||
// OnConnect is called when a client connects
|
||||
func (h *MyHandler) OnConnect(ctx *connection.Context) {
|
||||
// Only log connection established, detailed info logged on login
|
||||
}
|
||||
|
||||
// OnDisconnect is called when a client disconnects
|
||||
func (h *MyHandler) OnDisconnect(ctx *connection.Context) {
|
||||
info := ctx.GetInfo()
|
||||
if info.ClientID != "" {
|
||||
h.log.ClientEvent("offline", ctx.ID, ctx.GetPeerIP(),
|
||||
"clientID", info.ClientID,
|
||||
"computer", info.ComputerName,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// OnReceive is called when data is received from a client
|
||||
func (h *MyHandler) OnReceive(ctx *connection.Context, data []byte) {
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
cmd := data[0]
|
||||
// Handle commands
|
||||
switch cmd {
|
||||
case protocol.TokenLogin:
|
||||
h.handleLogin(ctx, data)
|
||||
case protocol.TokenAuth:
|
||||
h.handleAuth(ctx, data)
|
||||
case protocol.TokenHeartbeat:
|
||||
h.handleHeartbeat(ctx, data)
|
||||
default:
|
||||
// Other commands are not implemented yet
|
||||
h.log.Info("Unhandled command %d from client %d", cmd, ctx.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// handleLogin handles client login (TOKEN_LOGIN = 102)
|
||||
func (h *MyHandler) handleLogin(ctx *connection.Context, data []byte) {
|
||||
info, err := protocol.ParseLoginInfo(data)
|
||||
if err != nil {
|
||||
h.log.Error("Failed to parse login info from client %d: %v", ctx.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Use MasterID from login request as ClientID for logging
|
||||
clientID := info.MasterID
|
||||
if clientID == "" {
|
||||
clientID = fmt.Sprintf("conn-%d", ctx.ID)
|
||||
}
|
||||
|
||||
// Store client info
|
||||
reserved := info.ParseReserved()
|
||||
clientInfo := connection.ClientInfo{
|
||||
ClientID: clientID,
|
||||
ComputerName: info.PCName,
|
||||
OS: info.OsVerInfo,
|
||||
Version: info.ModuleVersion,
|
||||
HasCamera: info.WebCamExist,
|
||||
InstallTime: info.StartTime,
|
||||
}
|
||||
|
||||
// Parse additional info from reserved field
|
||||
if len(reserved) > 0 {
|
||||
clientInfo.ClientType = info.GetReservedField(0)
|
||||
}
|
||||
if len(reserved) > 2 {
|
||||
clientInfo.CPU = info.GetReservedField(2)
|
||||
}
|
||||
if len(reserved) > 4 {
|
||||
clientInfo.FilePath = info.GetReservedField(4)
|
||||
}
|
||||
if len(reserved) > 11 {
|
||||
clientInfo.IP = info.GetReservedField(11) // Public IP
|
||||
}
|
||||
|
||||
ctx.SetInfo(clientInfo)
|
||||
ctx.IsLoggedIn.Store(true)
|
||||
|
||||
h.log.ClientEvent("online", ctx.ID, ctx.GetPeerIP(),
|
||||
"clientID", clientID,
|
||||
"computer", info.PCName,
|
||||
"os", info.OsVerInfo,
|
||||
"version", info.ModuleVersion,
|
||||
"path", clientInfo.FilePath,
|
||||
)
|
||||
}
|
||||
|
||||
// handleAuth handles authorization request (TOKEN_AUTH = 100)
|
||||
func (h *MyHandler) handleAuth(ctx *connection.Context, data []byte) {
|
||||
result := h.auth.Authenticate(data)
|
||||
info := ctx.GetInfo()
|
||||
|
||||
if result.Valid {
|
||||
h.log.Info("Auth success: clientID=%s computer=%s ip=%s sn=%s passcode=%s",
|
||||
info.ClientID, info.ComputerName, ctx.GetPeerIP(), result.SN, result.Passcode)
|
||||
} else {
|
||||
h.log.Warn("Auth failed: clientID=%s computer=%s ip=%s sn=%s passcode=%s",
|
||||
info.ClientID, info.ComputerName, ctx.GetPeerIP(), result.SN, result.Passcode)
|
||||
}
|
||||
|
||||
// Build and send response
|
||||
resp := h.auth.BuildResponse(result)
|
||||
if err := h.srv.Send(ctx, resp); err != nil {
|
||||
h.log.Error("Failed to send auth response to client %d: %v", ctx.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleHeartbeat handles heartbeat from client (TOKEN_HEARTBEAT = 101)
|
||||
func (h *MyHandler) handleHeartbeat(ctx *connection.Context, data []byte) {
|
||||
|
||||
// Parse Time from heartbeat request (offset 1, 8 bytes)
|
||||
var hbTime uint64
|
||||
if len(data) >= 9 {
|
||||
hbTime = uint64(data[1]) | uint64(data[2])<<8 | uint64(data[3])<<16 | uint64(data[4])<<24 |
|
||||
uint64(data[5])<<32 | uint64(data[6])<<40 | uint64(data[7])<<48 | uint64(data[8])<<56
|
||||
}
|
||||
|
||||
// Build HeartbeatACK response: CMD_HEARTBEAT_ACK(1) + HeartbeatACK(32)
|
||||
resp := make([]byte, 33)
|
||||
resp[0] = protocol.CommandHeartbeat // CMD_HEARTBEAT_ACK = 216
|
||||
// Time at offset 1 (8 bytes, little-endian)
|
||||
resp[1] = byte(hbTime)
|
||||
resp[2] = byte(hbTime >> 8)
|
||||
resp[3] = byte(hbTime >> 16)
|
||||
resp[4] = byte(hbTime >> 24)
|
||||
resp[5] = byte(hbTime >> 32)
|
||||
resp[6] = byte(hbTime >> 40)
|
||||
resp[7] = byte(hbTime >> 48)
|
||||
resp[8] = byte(hbTime >> 56)
|
||||
// Reserved[24] at offset 9 is already zero
|
||||
|
||||
if err := h.srv.Send(ctx, resp); err != nil {
|
||||
h.log.Error("Failed to send heartbeat ACK to client %d: %v", ctx.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// parsePorts parses a semicolon-separated port string and returns port numbers
|
||||
func parsePorts(portStr string) ([]int, error) {
|
||||
var ports []int
|
||||
parts := strings.Split(portStr, ";")
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
port, err := strconv.Atoi(p)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid port %q: %v", p, err)
|
||||
}
|
||||
if port < 1 || port > 65535 {
|
||||
return nil, fmt.Errorf("port %d out of range (1-65535)", port)
|
||||
}
|
||||
ports = append(ports, port)
|
||||
}
|
||||
if len(ports) == 0 {
|
||||
return nil, fmt.Errorf("no valid ports specified")
|
||||
}
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Parse command line flags
|
||||
portStr := flag.String("port", "6543", "Server listen ports (semicolon-separated, e.g. 6543;6544;6545)")
|
||||
flag.StringVar(portStr, "p", "6543", "Server listen ports (shorthand)")
|
||||
noConsole := flag.Bool("no-console", false, "Disable console output (for daemon mode)")
|
||||
flag.Parse()
|
||||
|
||||
// Parse ports
|
||||
ports, err := parsePorts(*portStr)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Create logger with file output
|
||||
logCfg := logger.DefaultConfig()
|
||||
logCfg.Level = logger.LevelDebug
|
||||
logCfg.Console = !*noConsole
|
||||
logCfg.File = "logs/server.log"
|
||||
logCfg.MaxSize = 100 // 100 MB
|
||||
logCfg.MaxBackups = 10 // keep 10 old files
|
||||
logCfg.MaxAge = 30 // 30 days
|
||||
logCfg.Compress = true
|
||||
|
||||
log := logger.New(logCfg)
|
||||
|
||||
// Create auth config
|
||||
authCfg := auth.DefaultConfig()
|
||||
// PwdHash can be set from environment or config file
|
||||
authCfg.PwdHash = os.Getenv("YAMA_PWDHASH")
|
||||
if authCfg.PwdHash == "" {
|
||||
// Default placeholder - should be configured in production
|
||||
authCfg.PwdHash = "61f04dd637a74ee34493fc1025de2c131022536da751c29e3ff4e9024d8eec43"
|
||||
}
|
||||
authCfg.SuperPass = os.Getenv("YAMA_PWD")
|
||||
|
||||
// Create authenticator (shared by all servers)
|
||||
authenticator := auth.New(authCfg)
|
||||
|
||||
// Create servers for each port
|
||||
var servers []*server.Server
|
||||
for _, port := range ports {
|
||||
config := server.DefaultConfig()
|
||||
config.Port = port
|
||||
config.MaxConnections = 9999
|
||||
|
||||
srv := server.New(config)
|
||||
srv.SetLogger(log.WithPrefix(fmt.Sprintf("Server:%d", port)))
|
||||
|
||||
// Create handler for this server
|
||||
handler := &MyHandler{
|
||||
log: log.WithPrefix(fmt.Sprintf("Handler:%d", port)),
|
||||
auth: authenticator,
|
||||
srv: srv,
|
||||
}
|
||||
srv.SetHandler(handler)
|
||||
|
||||
servers = append(servers, srv)
|
||||
}
|
||||
|
||||
// Start all servers
|
||||
for _, srv := range servers {
|
||||
if err := srv.Start(); err != nil {
|
||||
log.Fatal("Failed to start server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("Server started on port(s): %v\n", ports)
|
||||
fmt.Println("Logs are written to: logs/server.log")
|
||||
fmt.Println("Press Ctrl+C to stop...")
|
||||
|
||||
// Wait for interrupt signal
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-sigChan
|
||||
|
||||
fmt.Println("\nShutting down...")
|
||||
for _, srv := range servers {
|
||||
srv.Stop()
|
||||
}
|
||||
fmt.Println("Server stopped")
|
||||
}
|
||||
197
server/go/connection/context.go
Normal file
197
server/go/connection/context.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/yuanyuanxiang/SimpleRemoter/server/go/buffer"
|
||||
)
|
||||
|
||||
// ClientInfo stores client metadata
|
||||
type ClientInfo struct {
|
||||
ClientID string // Client ID from login (MasterID)
|
||||
IP string
|
||||
ComputerName string
|
||||
OS string
|
||||
CPU string
|
||||
HasCamera bool
|
||||
Version string
|
||||
InstallTime string
|
||||
LoginTime time.Time
|
||||
ClientType string
|
||||
FilePath string
|
||||
GroupName string
|
||||
}
|
||||
|
||||
// Context represents a client connection context
|
||||
type Context struct {
|
||||
ID uint64
|
||||
Conn net.Conn
|
||||
RemoteAddr string
|
||||
|
||||
// Buffers
|
||||
InBuffer *buffer.Buffer // Received compressed data
|
||||
OutBuffer *buffer.Buffer // Decompressed data for processing
|
||||
|
||||
// Client info
|
||||
Info ClientInfo
|
||||
IsLoggedIn atomic.Bool
|
||||
|
||||
// Connection state
|
||||
OnlineTime time.Time
|
||||
lastActiveNs atomic.Int64 // Unix nanoseconds for thread-safe access
|
||||
|
||||
// Protocol state
|
||||
CompressMethod int
|
||||
FlagType FlagType
|
||||
HeaderLen int
|
||||
FlagLen int
|
||||
HeaderEncType int // Header encryption type (0-7)
|
||||
HeaderParams []byte // Header parameters for decoding (flag bytes)
|
||||
|
||||
// User data - for storing dialog/handler references
|
||||
UserData interface{}
|
||||
|
||||
// Internal
|
||||
mu sync.RWMutex
|
||||
closed atomic.Bool
|
||||
sendLock sync.Mutex
|
||||
server *Manager
|
||||
}
|
||||
|
||||
// FlagType represents the protocol flag type
|
||||
type FlagType int
|
||||
|
||||
const (
|
||||
FlagUnknown FlagType = iota
|
||||
FlagShine
|
||||
FlagFuck
|
||||
FlagHello
|
||||
FlagHell
|
||||
FlagWinOS
|
||||
)
|
||||
|
||||
// Compression methods
|
||||
const (
|
||||
CompressUnknown = -2
|
||||
CompressZlib = -1
|
||||
CompressZstd = 0
|
||||
CompressNone = 1
|
||||
)
|
||||
|
||||
// NewContext creates a new connection context
|
||||
func NewContext(conn net.Conn, mgr *Manager) *Context {
|
||||
now := time.Now()
|
||||
ctx := &Context{
|
||||
Conn: conn,
|
||||
RemoteAddr: conn.RemoteAddr().String(),
|
||||
InBuffer: buffer.New(8192),
|
||||
OutBuffer: buffer.New(8192),
|
||||
OnlineTime: now,
|
||||
CompressMethod: CompressZstd,
|
||||
FlagType: FlagUnknown,
|
||||
server: mgr,
|
||||
}
|
||||
ctx.lastActiveNs.Store(now.UnixNano())
|
||||
return ctx
|
||||
}
|
||||
|
||||
// Send sends data to the client (thread-safe)
|
||||
func (c *Context) Send(data []byte) error {
|
||||
if c.closed.Load() {
|
||||
return ErrConnectionClosed
|
||||
}
|
||||
|
||||
c.sendLock.Lock()
|
||||
defer c.sendLock.Unlock()
|
||||
|
||||
_, err := c.Conn.Write(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.UpdateLastActive()
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateLastActive updates the last active time (thread-safe)
|
||||
func (c *Context) UpdateLastActive() {
|
||||
c.lastActiveNs.Store(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// LastActive returns the last active time (thread-safe)
|
||||
func (c *Context) LastActive() time.Time {
|
||||
return time.Unix(0, c.lastActiveNs.Load())
|
||||
}
|
||||
|
||||
// TimeSinceLastActive returns duration since last activity (thread-safe)
|
||||
func (c *Context) TimeSinceLastActive() time.Duration {
|
||||
return time.Since(c.LastActive())
|
||||
}
|
||||
|
||||
// Close closes the connection
|
||||
func (c *Context) Close() error {
|
||||
if c.closed.Swap(true) {
|
||||
return nil // Already closed
|
||||
}
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
// IsClosed returns whether the connection is closed
|
||||
func (c *Context) IsClosed() bool {
|
||||
return c.closed.Load()
|
||||
}
|
||||
|
||||
// GetPeerIP returns the peer IP address
|
||||
func (c *Context) GetPeerIP() string {
|
||||
if host, _, err := net.SplitHostPort(c.RemoteAddr); err == nil {
|
||||
return host
|
||||
}
|
||||
return c.RemoteAddr
|
||||
}
|
||||
|
||||
// AliveTime returns how long the connection has been alive
|
||||
func (c *Context) AliveTime() time.Duration {
|
||||
return time.Since(c.OnlineTime)
|
||||
}
|
||||
|
||||
// SetInfo sets the client info
|
||||
func (c *Context) SetInfo(info ClientInfo) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.Info = info
|
||||
}
|
||||
|
||||
// GetInfo returns the client info
|
||||
func (c *Context) GetInfo() ClientInfo {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.Info
|
||||
}
|
||||
|
||||
// SetUserData stores user-defined data
|
||||
func (c *Context) SetUserData(data interface{}) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.UserData = data
|
||||
}
|
||||
|
||||
// GetUserData retrieves user-defined data
|
||||
func (c *Context) GetUserData() interface{} {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.UserData
|
||||
}
|
||||
|
||||
// GetClientID returns the client ID for logging
|
||||
// If ClientID is set (from login), returns it; otherwise returns connection ID as fallback
|
||||
func (c *Context) GetClientID() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
if c.Info.ClientID != "" {
|
||||
return c.Info.ClientID
|
||||
}
|
||||
return fmt.Sprintf("conn-%d", c.ID)
|
||||
}
|
||||
23
server/go/connection/errors.go
Normal file
23
server/go/connection/errors.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package connection
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrConnectionClosed indicates the connection is closed
|
||||
ErrConnectionClosed = errors.New("connection closed")
|
||||
|
||||
// ErrServerClosed indicates the server is shut down
|
||||
ErrServerClosed = errors.New("server closed")
|
||||
|
||||
// ErrMaxConnections indicates max connections reached
|
||||
ErrMaxConnections = errors.New("max connections reached")
|
||||
|
||||
// ErrInvalidPacket indicates an invalid packet
|
||||
ErrInvalidPacket = errors.New("invalid packet")
|
||||
|
||||
// ErrUnsupportedProtocol indicates unsupported protocol
|
||||
ErrUnsupportedProtocol = errors.New("unsupported protocol")
|
||||
|
||||
// ErrDecompressFailed indicates decompression failure
|
||||
ErrDecompressFailed = errors.New("decompression failed")
|
||||
)
|
||||
115
server/go/connection/manager.go
Normal file
115
server/go/connection/manager.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Manager manages all client connections
|
||||
type Manager struct {
|
||||
connections sync.Map // map[uint64]*Context
|
||||
count atomic.Int64
|
||||
maxConns int
|
||||
idCounter atomic.Uint64
|
||||
|
||||
// Callbacks
|
||||
onConnect func(*Context)
|
||||
onDisconnect func(*Context)
|
||||
onReceive func(*Context, []byte)
|
||||
}
|
||||
|
||||
// NewManager creates a new connection manager
|
||||
func NewManager(maxConns int) *Manager {
|
||||
if maxConns <= 0 {
|
||||
maxConns = 10000
|
||||
}
|
||||
return &Manager{
|
||||
maxConns: maxConns,
|
||||
}
|
||||
}
|
||||
|
||||
// SetCallbacks sets the callback functions
|
||||
func (m *Manager) SetCallbacks(onConnect, onDisconnect func(*Context), onReceive func(*Context, []byte)) {
|
||||
m.onConnect = onConnect
|
||||
m.onDisconnect = onDisconnect
|
||||
m.onReceive = onReceive
|
||||
}
|
||||
|
||||
// Add adds a new connection
|
||||
func (m *Manager) Add(ctx *Context) error {
|
||||
if int(m.count.Load()) >= m.maxConns {
|
||||
return ErrMaxConnections
|
||||
}
|
||||
|
||||
ctx.ID = m.idCounter.Add(1)
|
||||
m.connections.Store(ctx.ID, ctx)
|
||||
m.count.Add(1)
|
||||
|
||||
if m.onConnect != nil {
|
||||
m.onConnect(ctx)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove removes a connection
|
||||
func (m *Manager) Remove(ctx *Context) {
|
||||
if _, ok := m.connections.LoadAndDelete(ctx.ID); ok {
|
||||
m.count.Add(-1)
|
||||
if m.onDisconnect != nil {
|
||||
m.onDisconnect(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a connection by ID
|
||||
func (m *Manager) Get(id uint64) *Context {
|
||||
if v, ok := m.connections.Load(id); ok {
|
||||
return v.(*Context)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Count returns the current connection count
|
||||
func (m *Manager) Count() int {
|
||||
return int(m.count.Load())
|
||||
}
|
||||
|
||||
// Range iterates over all connections
|
||||
func (m *Manager) Range(fn func(*Context) bool) {
|
||||
m.connections.Range(func(key, value interface{}) bool {
|
||||
return fn(value.(*Context))
|
||||
})
|
||||
}
|
||||
|
||||
// Broadcast sends data to all connections
|
||||
func (m *Manager) Broadcast(data []byte) {
|
||||
m.connections.Range(func(key, value interface{}) bool {
|
||||
ctx := value.(*Context)
|
||||
if !ctx.IsClosed() {
|
||||
_ = ctx.Send(data)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// CloseAll closes all connections
|
||||
func (m *Manager) CloseAll() {
|
||||
m.connections.Range(func(key, value interface{}) bool {
|
||||
ctx := value.(*Context)
|
||||
_ = ctx.Close()
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// OnReceive calls the receive callback
|
||||
func (m *Manager) OnReceive(ctx *Context, data []byte) {
|
||||
if m.onReceive != nil {
|
||||
m.onReceive(ctx, data)
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateMaxConnections updates the maximum connections limit
|
||||
func (m *Manager) UpdateMaxConnections(max int) {
|
||||
m.maxConns = max
|
||||
}
|
||||
16
server/go/go.mod
Normal file
16
server/go/go.mod
Normal file
@@ -0,0 +1,16 @@
|
||||
module github.com/yuanyuanxiang/SimpleRemoter/server/go
|
||||
|
||||
go 1.24.5
|
||||
|
||||
require (
|
||||
github.com/klauspost/compress v1.18.2
|
||||
github.com/rs/zerolog v1.34.0
|
||||
golang.org/x/text v0.32.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||
golang.org/x/sys v0.12.0 // indirect
|
||||
)
|
||||
21
server/go/go.sum
Normal file
21
server/go/go.sum
Normal file
@@ -0,0 +1,21 @@
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
|
||||
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
||||
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
|
||||
204
server/go/logger/logger.go
Normal file
204
server/go/logger/logger.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
)
|
||||
|
||||
// Level represents log level
|
||||
type Level = zerolog.Level
|
||||
|
||||
const (
|
||||
LevelDebug = zerolog.DebugLevel
|
||||
LevelInfo = zerolog.InfoLevel
|
||||
LevelWarn = zerolog.WarnLevel
|
||||
LevelError = zerolog.ErrorLevel
|
||||
LevelFatal = zerolog.FatalLevel
|
||||
)
|
||||
|
||||
// Logger wraps zerolog.Logger
|
||||
type Logger struct {
|
||||
zl zerolog.Logger
|
||||
}
|
||||
|
||||
// Config holds logger configuration
|
||||
type Config struct {
|
||||
// Level is the minimum log level
|
||||
Level Level
|
||||
// Console enables console output
|
||||
Console bool
|
||||
// File is the log file path (empty to disable file logging)
|
||||
File string
|
||||
// MaxSize is the max size in MB before rotation
|
||||
MaxSize int
|
||||
// MaxBackups is the max number of old log files to keep
|
||||
MaxBackups int
|
||||
// MaxAge is the max days to keep old log files
|
||||
MaxAge int
|
||||
// Compress enables gzip compression for rotated files
|
||||
Compress bool
|
||||
}
|
||||
|
||||
// DefaultConfig returns default configuration
|
||||
func DefaultConfig() Config {
|
||||
return Config{
|
||||
Level: LevelInfo,
|
||||
Console: true,
|
||||
File: "",
|
||||
MaxSize: 100,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 30,
|
||||
Compress: true,
|
||||
}
|
||||
}
|
||||
|
||||
// New creates a new logger with config
|
||||
func New(cfg Config) *Logger {
|
||||
var writers []io.Writer
|
||||
|
||||
// Console output with color
|
||||
if cfg.Console {
|
||||
consoleWriter := zerolog.ConsoleWriter{
|
||||
Out: os.Stdout,
|
||||
TimeFormat: "2006-01-02 15:04:05",
|
||||
}
|
||||
writers = append(writers, consoleWriter)
|
||||
}
|
||||
|
||||
// File output with rotation
|
||||
if cfg.File != "" {
|
||||
// Ensure directory exists
|
||||
dir := filepath.Dir(cfg.File)
|
||||
if dir != "" && dir != "." {
|
||||
_ = os.MkdirAll(dir, 0755)
|
||||
}
|
||||
|
||||
fileWriter := &lumberjack.Logger{
|
||||
Filename: cfg.File,
|
||||
MaxSize: cfg.MaxSize,
|
||||
MaxBackups: cfg.MaxBackups,
|
||||
MaxAge: cfg.MaxAge,
|
||||
Compress: cfg.Compress,
|
||||
LocalTime: true,
|
||||
}
|
||||
writers = append(writers, fileWriter)
|
||||
}
|
||||
|
||||
// Combine writers
|
||||
var writer io.Writer
|
||||
if len(writers) == 0 {
|
||||
writer = os.Stdout
|
||||
} else if len(writers) == 1 {
|
||||
writer = writers[0]
|
||||
} else {
|
||||
writer = zerolog.MultiLevelWriter(writers...)
|
||||
}
|
||||
|
||||
// Create logger
|
||||
zl := zerolog.New(writer).
|
||||
Level(cfg.Level).
|
||||
With().
|
||||
Timestamp().
|
||||
Logger()
|
||||
|
||||
return &Logger{zl: zl}
|
||||
}
|
||||
|
||||
// WithPrefix returns a new logger with a prefix field
|
||||
func (l *Logger) WithPrefix(prefix string) *Logger {
|
||||
return &Logger{
|
||||
zl: l.zl.With().Str("module", prefix).Logger(),
|
||||
}
|
||||
}
|
||||
|
||||
// Debug logs a debug message
|
||||
func (l *Logger) Debug(format string, args ...interface{}) {
|
||||
l.zl.Debug().Msgf(format, args...)
|
||||
}
|
||||
|
||||
// Info logs an info message
|
||||
func (l *Logger) Info(format string, args ...interface{}) {
|
||||
l.zl.Info().Msgf(format, args...)
|
||||
}
|
||||
|
||||
// Warn logs a warning message
|
||||
func (l *Logger) Warn(format string, args ...interface{}) {
|
||||
l.zl.Warn().Msgf(format, args...)
|
||||
}
|
||||
|
||||
// Error logs an error message
|
||||
func (l *Logger) Error(format string, args ...interface{}) {
|
||||
l.zl.Error().Msgf(format, args...)
|
||||
}
|
||||
|
||||
// Fatal logs a fatal message and exits
|
||||
func (l *Logger) Fatal(format string, args ...interface{}) {
|
||||
l.zl.Fatal().Msgf(format, args...)
|
||||
}
|
||||
|
||||
// SetLevel sets the log level
|
||||
func (l *Logger) SetLevel(level Level) {
|
||||
l.zl = l.zl.Level(level)
|
||||
}
|
||||
|
||||
// GetLevel returns the current log level
|
||||
func (l *Logger) GetLevel() Level {
|
||||
return l.zl.GetLevel()
|
||||
}
|
||||
|
||||
// ClientEvent logs client online/offline events
|
||||
func (l *Logger) ClientEvent(event string, clientID uint64, ip string, extra ...string) {
|
||||
e := l.zl.Info().
|
||||
Str("event", event).
|
||||
Uint64("client_id", clientID).
|
||||
Str("ip", ip).
|
||||
Time("time", time.Now())
|
||||
|
||||
if len(extra) >= 2 {
|
||||
for i := 0; i+1 < len(extra); i += 2 {
|
||||
e = e.Str(extra[i], extra[i+1])
|
||||
}
|
||||
}
|
||||
|
||||
e.Msg("")
|
||||
}
|
||||
|
||||
// default global logger
|
||||
var defaultLogger = New(DefaultConfig())
|
||||
|
||||
// SetDefault sets the default global logger
|
||||
func SetDefault(l *Logger) {
|
||||
defaultLogger = l
|
||||
}
|
||||
|
||||
// Default returns the default global logger
|
||||
func Default() *Logger {
|
||||
return defaultLogger
|
||||
}
|
||||
|
||||
// Package-level convenience functions
|
||||
|
||||
func Debug(format string, args ...interface{}) {
|
||||
defaultLogger.Debug(format, args...)
|
||||
}
|
||||
|
||||
func Info(format string, args ...interface{}) {
|
||||
defaultLogger.Info(format, args...)
|
||||
}
|
||||
|
||||
func Warn(format string, args ...interface{}) {
|
||||
defaultLogger.Warn(format, args...)
|
||||
}
|
||||
|
||||
func Error(format string, args ...interface{}) {
|
||||
defaultLogger.Error(format, args...)
|
||||
}
|
||||
|
||||
func Fatal(format string, args ...interface{}) {
|
||||
defaultLogger.Fatal(format, args...)
|
||||
}
|
||||
161
server/go/protocol/codec.go
Normal file
161
server/go/protocol/codec.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/yuanyuanxiang/SimpleRemoter/server/go/connection"
|
||||
)
|
||||
|
||||
// Codec handles encoding/decoding and compression
|
||||
type Codec struct {
|
||||
encoder *zstd.Encoder
|
||||
decoder *zstd.Decoder
|
||||
}
|
||||
|
||||
// NewCodec creates a new codec
|
||||
func NewCodec() *Codec {
|
||||
encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.SpeedDefault))
|
||||
if err != nil {
|
||||
panic("failed to create zstd encoder: " + err.Error())
|
||||
}
|
||||
decoder, err := zstd.NewReader(nil)
|
||||
if err != nil {
|
||||
panic("failed to create zstd decoder: " + err.Error())
|
||||
}
|
||||
|
||||
return &Codec{
|
||||
encoder: encoder,
|
||||
decoder: decoder,
|
||||
}
|
||||
}
|
||||
|
||||
// Compress compresses data using the appropriate method
|
||||
func (c *Codec) Compress(ctx *connection.Context, data []byte) ([]byte, error) {
|
||||
switch ctx.CompressMethod {
|
||||
case connection.CompressNone:
|
||||
return data, nil
|
||||
case connection.CompressZstd:
|
||||
return c.encoder.EncodeAll(data, nil), nil
|
||||
default:
|
||||
// Default to zstd
|
||||
return c.encoder.EncodeAll(data, nil), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Decompress decompresses data using the appropriate method
|
||||
func (c *Codec) Decompress(ctx *connection.Context, data []byte, origLen uint32) ([]byte, error) {
|
||||
switch ctx.CompressMethod {
|
||||
case connection.CompressNone:
|
||||
// No compression, return as-is
|
||||
result := make([]byte, len(data))
|
||||
copy(result, data)
|
||||
return result, nil
|
||||
case connection.CompressZstd:
|
||||
result := make([]byte, 0, origLen)
|
||||
return c.decoder.DecodeAll(data, result)
|
||||
default:
|
||||
// Try zstd by default
|
||||
result := make([]byte, 0, origLen)
|
||||
return c.decoder.DecodeAll(data, result)
|
||||
}
|
||||
}
|
||||
|
||||
// Encode encodes data after compression (before sending) - Encoder2
|
||||
func (c *Codec) Encode(ctx *connection.Context, data []byte) {
|
||||
// This is Encoder2 - applied after compression
|
||||
switch ctx.FlagType {
|
||||
case connection.FlagHello, connection.FlagHell:
|
||||
// XOREncoder16 - needs param from header
|
||||
// For now, skip encoding on send since we need the header params
|
||||
case connection.FlagFuck:
|
||||
// No encoding after compression for FUCK
|
||||
}
|
||||
}
|
||||
|
||||
// Decode decodes data before decompression (after receiving) - Encoder2
|
||||
func (c *Codec) Decode(ctx *connection.Context, data []byte) {
|
||||
// This is Encoder2 - applied before decompression
|
||||
// XOREncoder16 for HELL/HELLO protocols
|
||||
if ctx.FlagType == connection.FlagHell || ctx.FlagType == connection.FlagHello {
|
||||
// Get k1, k2 from stored header params
|
||||
if len(ctx.HeaderParams) >= 8 {
|
||||
k1 := ctx.HeaderParams[6]
|
||||
k2 := ctx.HeaderParams[7]
|
||||
if k1 != 0 || k2 != 0 {
|
||||
xorDecoder16(data, k1, k2)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// EncodeData encodes data before compression - Encoder
|
||||
func (c *Codec) EncodeData(ctx *connection.Context, data []byte) {
|
||||
// This is Encoder - applied before compression
|
||||
// Default encoder does nothing
|
||||
}
|
||||
|
||||
// DecodeData decodes data after decompression - Encoder
|
||||
func (c *Codec) DecodeData(ctx *connection.Context, data []byte) {
|
||||
// This is Encoder - applied after decompression
|
||||
// Default encoder does nothing
|
||||
}
|
||||
|
||||
// xorDecoder16 implements XOREncoder16.decrypt_internal
|
||||
func xorDecoder16(data []byte, k1, k2 byte) {
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
key := (uint16(k1) << 8) | uint16(k2)
|
||||
dataLen := len(data)
|
||||
|
||||
// Reverse two rounds of pseudo-random swaps
|
||||
for round := 1; round >= 0; round-- {
|
||||
for i := dataLen - 1; i >= 0; i-- {
|
||||
j := int(pseudoRandom(key, i+round*100)) % dataLen
|
||||
data[i], data[j] = data[j], data[i]
|
||||
}
|
||||
}
|
||||
|
||||
// XOR decode
|
||||
for i := 0; i < dataLen; i++ {
|
||||
data[i] ^= (k1 + byte(i*13)) ^ (k2 ^ byte(i<<1))
|
||||
}
|
||||
}
|
||||
|
||||
// xorEncoder16 implements XOREncoder16.encrypt_internal
|
||||
func xorEncoder16(data []byte, k1, k2 byte) {
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
key := (uint16(k1) << 8) | uint16(k2)
|
||||
dataLen := len(data)
|
||||
|
||||
// XOR encode
|
||||
for i := 0; i < dataLen; i++ {
|
||||
data[i] ^= (k1 + byte(i*13)) ^ (k2 ^ byte(i<<1))
|
||||
}
|
||||
|
||||
// Two rounds of pseudo-random swaps
|
||||
for round := 0; round < 2; round++ {
|
||||
for i := 0; i < dataLen; i++ {
|
||||
j := int(pseudoRandom(key, i+round*100)) % dataLen
|
||||
data[i], data[j] = data[j], data[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// pseudoRandom matches the C++ pseudo_random function
|
||||
func pseudoRandom(seed uint16, index int) uint16 {
|
||||
return ((seed ^ uint16(index*251+97)) * 733) ^ (seed >> 3)
|
||||
}
|
||||
|
||||
// Close releases resources
|
||||
func (c *Codec) Close() {
|
||||
if c.encoder != nil {
|
||||
c.encoder.Close()
|
||||
}
|
||||
if c.decoder != nil {
|
||||
c.decoder.Close()
|
||||
}
|
||||
}
|
||||
202
server/go/protocol/commands.go
Normal file
202
server/go/protocol/commands.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/text/encoding/simplifiedchinese"
|
||||
"golang.org/x/text/transform"
|
||||
)
|
||||
|
||||
// gbkToUTF8 converts GBK encoded bytes to UTF-8 string
|
||||
func gbkToUTF8(data []byte) string {
|
||||
// Find the first null byte and truncate there
|
||||
if idx := bytes.IndexByte(data, 0); idx >= 0 {
|
||||
data = data[:idx]
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Try to decode as GBK
|
||||
reader := transform.NewReader(bytes.NewReader(data), simplifiedchinese.GBK.NewDecoder())
|
||||
buf := new(bytes.Buffer)
|
||||
_, err := buf.ReadFrom(reader)
|
||||
if err != nil {
|
||||
// If GBK decoding fails, try treating as UTF-8 or ASCII
|
||||
return cleanString(string(data))
|
||||
}
|
||||
return cleanString(buf.String())
|
||||
}
|
||||
|
||||
// cleanString removes non-printable characters except common whitespace
|
||||
func cleanString(s string) string {
|
||||
var result strings.Builder
|
||||
for _, r := range s {
|
||||
if r >= 32 || r == '\t' || r == '\n' || r == '\r' {
|
||||
result.WriteRune(r)
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(result.String())
|
||||
}
|
||||
|
||||
// Command tokens - matching the C++ definitions
|
||||
const (
|
||||
// Server -> Client commands
|
||||
CommandActived byte = 0 // COMMAND_ACTIVED
|
||||
CommandBye byte = 204 // COMMAND_BYE - disconnect
|
||||
CommandHeartbeat byte = 216 // CMD_HEARTBEAT_ACK
|
||||
|
||||
// Client -> Server tokens
|
||||
TokenAuth byte = 100 // TOKEN_AUTH - authorization required
|
||||
TokenHeartbeat byte = 101 // TOKEN_HEARTBEAT
|
||||
TokenLogin byte = 102 // TOKEN_LOGIN - login packet
|
||||
)
|
||||
|
||||
// LOGIN_INFOR structure size and offsets (matching C++ struct with default alignment)
|
||||
// Note: C++ struct uses default alignment (4-byte for uint32/int)
|
||||
const (
|
||||
LoginInfoSize = 980 // Total size of LOGIN_INFOR struct (with alignment padding)
|
||||
|
||||
// Field offsets (with alignment padding)
|
||||
OffsetToken = 0 // 1 byte (unsigned char)
|
||||
OffsetOsVerInfoEx = 1 // 156 bytes (char[156])
|
||||
// 3 bytes padding here to align dwCPUMHz to 4-byte boundary
|
||||
OffsetCPUMHz = 160 // 4 bytes (unsigned int) - aligned to 4
|
||||
OffsetModuleVersion = 164 // 24 bytes (char[24])
|
||||
OffsetPCName = 188 // 240 bytes (char[240])
|
||||
OffsetMasterID = 428 // 20 bytes (char[20])
|
||||
OffsetWebCamExist = 448 // 4 bytes (int) - aligned to 4
|
||||
OffsetSpeed = 452 // 4 bytes (unsigned int)
|
||||
OffsetStartTime = 456 // 20 bytes (char[20])
|
||||
OffsetReserved = 476 // 512 bytes (char[512])
|
||||
)
|
||||
|
||||
// LoginInfo represents client login information
|
||||
type LoginInfo struct {
|
||||
Token byte
|
||||
OsVerInfo string // OS version info
|
||||
CPUMHz uint32
|
||||
ModuleVersion string
|
||||
PCName string // Computer name
|
||||
MasterID string
|
||||
WebCamExist bool
|
||||
Speed uint32
|
||||
StartTime string
|
||||
Reserved string // Contains additional info separated by |
|
||||
}
|
||||
|
||||
// ParseLoginInfo parses LOGIN_INFOR from data
|
||||
func ParseLoginInfo(data []byte) (*LoginInfo, error) {
|
||||
if len(data) < 100 { // Minimum size check
|
||||
return nil, ErrInvalidData
|
||||
}
|
||||
|
||||
info := &LoginInfo{
|
||||
Token: data[0],
|
||||
}
|
||||
|
||||
// Parse OS version info (offset 1, 156 bytes)
|
||||
// The C++ client fills this with a readable string like "Windows 10" via getSystemName()
|
||||
if len(data) >= OffsetOsVerInfoEx+156 {
|
||||
info.OsVerInfo = parseOsVersionInfo(data[OffsetOsVerInfoEx : OffsetOsVerInfoEx+156])
|
||||
}
|
||||
|
||||
// Parse CPU MHz (offset 160, 4 bytes)
|
||||
if len(data) >= OffsetCPUMHz+4 {
|
||||
info.CPUMHz = binary.LittleEndian.Uint32(data[OffsetCPUMHz:])
|
||||
}
|
||||
|
||||
// Parse module version (offset 164, 24 bytes)
|
||||
// This contains date string like "Dec 19 2025"
|
||||
if len(data) >= OffsetModuleVersion+24 {
|
||||
info.ModuleVersion = gbkToUTF8(data[OffsetModuleVersion : OffsetModuleVersion+24])
|
||||
}
|
||||
|
||||
// Parse PC name (offset 188, 240 bytes)
|
||||
if len(data) >= OffsetPCName+240 {
|
||||
info.PCName = gbkToUTF8(data[OffsetPCName : OffsetPCName+240])
|
||||
}
|
||||
|
||||
// Parse Master ID (offset 428, 20 bytes)
|
||||
if len(data) >= OffsetMasterID+20 {
|
||||
info.MasterID = gbkToUTF8(data[OffsetMasterID : OffsetMasterID+20])
|
||||
}
|
||||
|
||||
// Parse WebCam exist (offset 448, 4 bytes)
|
||||
if len(data) >= OffsetWebCamExist+4 {
|
||||
info.WebCamExist = binary.LittleEndian.Uint32(data[OffsetWebCamExist:]) != 0
|
||||
}
|
||||
|
||||
// Parse Speed (offset 452, 4 bytes)
|
||||
if len(data) >= OffsetSpeed+4 {
|
||||
info.Speed = binary.LittleEndian.Uint32(data[OffsetSpeed:])
|
||||
}
|
||||
|
||||
// Parse Start time (offset 456, 20 bytes)
|
||||
if len(data) >= OffsetStartTime+20 {
|
||||
info.StartTime = gbkToUTF8(data[OffsetStartTime : OffsetStartTime+20])
|
||||
}
|
||||
|
||||
// Parse Reserved (offset 476, 512 bytes) - contains additional info
|
||||
if len(data) >= OffsetReserved+512 {
|
||||
info.Reserved = gbkToUTF8(data[OffsetReserved : OffsetReserved+512])
|
||||
} else if len(data) > OffsetReserved {
|
||||
info.Reserved = gbkToUTF8(data[OffsetReserved:])
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// parseOsVersionInfo parses the OS version info field
|
||||
// The C++ client fills this with a readable string like "Windows 10" via getSystemName()
|
||||
func parseOsVersionInfo(data []byte) string {
|
||||
return gbkToUTF8(data)
|
||||
}
|
||||
|
||||
// ParseReserved parses the reserved field into a slice of strings
|
||||
func (info *LoginInfo) ParseReserved() []string {
|
||||
if info.Reserved == "" {
|
||||
return nil
|
||||
}
|
||||
return strings.Split(info.Reserved, "|")
|
||||
}
|
||||
|
||||
// GetReservedField returns a specific field from reserved data by index
|
||||
// Fields: ClientType(0), SystemBits(1), CPU(2), Memory(3), FilePath(4),
|
||||
// Reserved(5), InstallTime(6), InstallInfo(7), ProgramBits(8), ExpiredDate(9),
|
||||
// ClientLoc(10), ClientPubIP(11), ExeVersion(12), Username(13), IsAdmin(14)
|
||||
func (info *LoginInfo) GetReservedField(index int) string {
|
||||
fields := info.ParseReserved()
|
||||
if index >= 0 && index < len(fields) {
|
||||
return fields[index]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Validation structure for TOKEN_AUTH
|
||||
type Validation struct {
|
||||
From string // Start date
|
||||
To string // End date
|
||||
Admin string // Admin address
|
||||
Port int // Admin port
|
||||
Checksum string // Reserved field
|
||||
}
|
||||
|
||||
// BuildValidation creates a validation response
|
||||
func BuildValidation(days float64, admin string, port int) []byte {
|
||||
// This would build the validation structure
|
||||
// For now, return a simple structure
|
||||
data := make([]byte, 160) // Size of Validation struct
|
||||
data[0] = TokenAuth
|
||||
|
||||
// Fill in fields...
|
||||
// From: 20 bytes
|
||||
// To: 20 bytes
|
||||
// Admin: 100 bytes
|
||||
// Port: 4 bytes
|
||||
// Checksum: 16 bytes
|
||||
|
||||
return data
|
||||
}
|
||||
267
server/go/protocol/header.go
Normal file
267
server/go/protocol/header.go
Normal file
@@ -0,0 +1,267 @@
|
||||
package protocol
|
||||
|
||||
// Header encoding/decoding functions
|
||||
// Ported from common/header.h and common/encfuncs.h
|
||||
|
||||
const (
|
||||
MsgHeader = "HELL"
|
||||
FlagCompLen = 4
|
||||
FlagLength = 8
|
||||
HdrLength = FlagLength + 8 // FLAG_LENGTH + 2 * sizeof(uint32)
|
||||
MinComLen = 12
|
||||
)
|
||||
|
||||
// HeaderEncType represents the encryption method used for header
|
||||
type HeaderEncType int
|
||||
|
||||
const (
|
||||
HeaderEncUnknown HeaderEncType = -1
|
||||
HeaderEncNone HeaderEncType = 0
|
||||
HeaderEncV0 HeaderEncType = 1
|
||||
HeaderEncV1 HeaderEncType = 2
|
||||
HeaderEncV2 HeaderEncType = 3
|
||||
HeaderEncV3 HeaderEncType = 4
|
||||
HeaderEncV4 HeaderEncType = 5
|
||||
HeaderEncV5 HeaderEncType = 6
|
||||
HeaderEncV6 HeaderEncType = 7
|
||||
)
|
||||
|
||||
// DecryptFunc is the function signature for header decryption
|
||||
type DecryptFunc func(data []byte, key byte)
|
||||
|
||||
// defaultDecrypt does nothing (no encryption)
|
||||
func defaultDecrypt(data []byte, key byte) {
|
||||
// No-op
|
||||
}
|
||||
|
||||
// decrypt is the default encryption method (V0)
|
||||
func decrypt(data []byte, key byte) {
|
||||
if key == 0 {
|
||||
return
|
||||
}
|
||||
for i := 0; i < len(data); i++ {
|
||||
k := key ^ byte(i*31)
|
||||
value := int(data[i])
|
||||
switch i % 4 {
|
||||
case 0:
|
||||
value -= int(k)
|
||||
case 1:
|
||||
value = value ^ int(k)
|
||||
case 2:
|
||||
value += int(k)
|
||||
case 3:
|
||||
value = ^value ^ int(k)
|
||||
}
|
||||
data[i] = byte(value & 0xFF)
|
||||
}
|
||||
}
|
||||
|
||||
// decryptV1 - alternating add/subtract
|
||||
func decryptV1(data []byte, key byte) {
|
||||
for i := 0; i < len(data); i++ {
|
||||
if i%2 == 0 {
|
||||
data[i] = data[i] - key
|
||||
} else {
|
||||
data[i] = data[i] + key
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// decryptV2 - XOR with rotation
|
||||
func decryptV2(data []byte, key byte) {
|
||||
for i := 0; i < len(data); i++ {
|
||||
if i%2 == 0 {
|
||||
data[i] = (data[i] >> 1) | (data[i] << 7) // rotate right
|
||||
} else {
|
||||
data[i] = (data[i] << 1) | (data[i] >> 7) // rotate left
|
||||
}
|
||||
data[i] ^= key
|
||||
}
|
||||
}
|
||||
|
||||
// decryptV3 - dynamic key with position
|
||||
func decryptV3(data []byte, key byte) {
|
||||
for i := 0; i < len(data); i++ {
|
||||
dynamicKey := key + byte(i%8)
|
||||
switch i % 3 {
|
||||
case 0:
|
||||
data[i] = (data[i] ^ dynamicKey) - dynamicKey
|
||||
case 1:
|
||||
data[i] = (data[i] + dynamicKey) ^ dynamicKey
|
||||
case 2:
|
||||
data[i] = ^data[i] - dynamicKey
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// decryptV4 - pseudo-random XOR (symmetric)
|
||||
func decryptV4(data []byte, key byte) {
|
||||
rand := uint16(key)
|
||||
for i := 0; i < len(data); i++ {
|
||||
rand = (rand*13 + 17) % 256
|
||||
data[i] ^= byte(rand)
|
||||
}
|
||||
}
|
||||
|
||||
// decryptV5 - dynamic key with bit shift
|
||||
func decryptV5(data []byte, key byte) {
|
||||
for i := 0; i < len(data); i++ {
|
||||
dynamicKey := (key + byte(i)) ^ 0x55
|
||||
data[i] = ((data[i] - byte(i%7)) ^ (dynamicKey << 3)) - dynamicKey
|
||||
}
|
||||
}
|
||||
|
||||
// decryptV6 - pseudo-random with position (symmetric)
|
||||
func decryptV6(data []byte, key byte) {
|
||||
rand := uint16(key)
|
||||
for i := 0; i < len(data); i++ {
|
||||
rand = (rand*31 + 17) % 256
|
||||
data[i] ^= byte(rand) + byte(i)
|
||||
}
|
||||
}
|
||||
|
||||
// All decrypt methods
|
||||
var decryptMethods = []DecryptFunc{
|
||||
defaultDecrypt,
|
||||
decrypt,
|
||||
decryptV1,
|
||||
decryptV2,
|
||||
decryptV3,
|
||||
decryptV4,
|
||||
decryptV5,
|
||||
decryptV6,
|
||||
}
|
||||
|
||||
// compare decrypts and compares the flag with magic
|
||||
func compare(flag []byte, magic string, length int, dec DecryptFunc, key byte) bool {
|
||||
if len(flag) < MinComLen {
|
||||
return false
|
||||
}
|
||||
|
||||
buf := make([]byte, MinComLen)
|
||||
copy(buf, flag[:MinComLen])
|
||||
dec(buf[:length], key)
|
||||
|
||||
return string(buf[:length]) == magic
|
||||
}
|
||||
|
||||
// CheckHead tries all decryption methods to identify the protocol
|
||||
func CheckHead(flag []byte) (flagType FlagType, encType HeaderEncType, decrypted []byte) {
|
||||
if len(flag) < MinComLen {
|
||||
return FlagUnknown, HeaderEncUnknown, nil
|
||||
}
|
||||
|
||||
for i, method := range decryptMethods {
|
||||
buf := make([]byte, MinComLen)
|
||||
copy(buf, flag[:MinComLen])
|
||||
|
||||
ft := checkHeadWithMethod(buf, method)
|
||||
if ft != FlagUnknown {
|
||||
return ft, HeaderEncType(i), buf
|
||||
}
|
||||
}
|
||||
|
||||
return FlagUnknown, HeaderEncUnknown, nil
|
||||
}
|
||||
|
||||
// checkHeadWithMethod checks the flag with a specific decrypt method
|
||||
func checkHeadWithMethod(flag []byte, dec DecryptFunc) FlagType {
|
||||
// Try HELL (FLAG_HELL)
|
||||
if len(flag) >= FlagLength {
|
||||
buf := make([]byte, MinComLen)
|
||||
copy(buf, flag)
|
||||
key := buf[6]
|
||||
dec(buf[:FlagCompLen], key)
|
||||
if string(buf[:4]) == MsgHeader {
|
||||
copy(flag, buf)
|
||||
return FlagHell
|
||||
}
|
||||
}
|
||||
|
||||
// Try Shine (FLAG_SHINE)
|
||||
if len(flag) >= 5 {
|
||||
buf := make([]byte, MinComLen)
|
||||
copy(buf, flag)
|
||||
dec(buf[:5], 0)
|
||||
if string(buf[:5]) == "Shine" {
|
||||
copy(flag, buf)
|
||||
return FlagShine
|
||||
}
|
||||
}
|
||||
|
||||
// Try <<FUCK>> (FLAG_FUCK)
|
||||
if len(flag) >= 10 {
|
||||
buf := make([]byte, MinComLen)
|
||||
copy(buf, flag)
|
||||
key := buf[9]
|
||||
dec(buf[:8], key)
|
||||
if string(buf[:8]) == "<<FUCK>>" {
|
||||
copy(flag, buf)
|
||||
return FlagFuck
|
||||
}
|
||||
}
|
||||
|
||||
// Try Hello? (FLAG_HELLO)
|
||||
if len(flag) >= 7 {
|
||||
buf := make([]byte, MinComLen)
|
||||
copy(buf, flag)
|
||||
key := buf[6]
|
||||
dec(buf[:6], key)
|
||||
if string(buf[:6]) == "Hello?" {
|
||||
copy(flag, buf)
|
||||
return FlagHello
|
||||
}
|
||||
}
|
||||
|
||||
return FlagUnknown
|
||||
}
|
||||
|
||||
// FlagType represents the protocol type
|
||||
type FlagType int
|
||||
|
||||
const (
|
||||
FlagWinOS FlagType = -1
|
||||
FlagUnknown FlagType = 0
|
||||
FlagShine FlagType = 1
|
||||
FlagFuck FlagType = 2
|
||||
FlagHello FlagType = 3
|
||||
FlagHell FlagType = 4
|
||||
)
|
||||
|
||||
func (f FlagType) String() string {
|
||||
switch f {
|
||||
case FlagWinOS:
|
||||
return "WinOS"
|
||||
case FlagShine:
|
||||
return "Shine"
|
||||
case FlagFuck:
|
||||
return "Fuck"
|
||||
case FlagHello:
|
||||
return "Hello"
|
||||
case FlagHell:
|
||||
return "Hell"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// GetFlagLength returns the flag length for a given flag type
|
||||
func GetFlagLength(ft FlagType) int {
|
||||
switch ft {
|
||||
case FlagShine:
|
||||
return 5
|
||||
case FlagFuck:
|
||||
return 11 // 8 + 3
|
||||
case FlagHello:
|
||||
return 8
|
||||
case FlagHell:
|
||||
return FlagLength
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// GetHeaderLength returns the full header length for a given flag type
|
||||
func GetHeaderLength(ft FlagType) int {
|
||||
return GetFlagLength(ft) + 8
|
||||
}
|
||||
261
server/go/protocol/parser.go
Normal file
261
server/go/protocol/parser.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
|
||||
"github.com/yuanyuanxiang/SimpleRemoter/server/go/connection"
|
||||
)
|
||||
|
||||
// Errors
|
||||
var (
|
||||
ErrNeedMore = errors.New("need more data")
|
||||
ErrInvalidData = errors.New("invalid data")
|
||||
ErrUnsupported = errors.New("unsupported protocol")
|
||||
ErrDecompress = errors.New("decompression failed")
|
||||
)
|
||||
|
||||
// Parser handles protocol parsing
|
||||
type Parser struct {
|
||||
codec *Codec
|
||||
}
|
||||
|
||||
// NewParser creates a new parser
|
||||
func NewParser() *Parser {
|
||||
return &Parser{
|
||||
codec: NewCodec(),
|
||||
}
|
||||
}
|
||||
|
||||
// Close releases resources held by the parser
|
||||
func (p *Parser) Close() {
|
||||
if p.codec != nil {
|
||||
p.codec.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Parse attempts to parse a complete packet from the buffer
|
||||
func (p *Parser) Parse(ctx *connection.Context) ([]byte, error) {
|
||||
buf := ctx.InBuffer
|
||||
|
||||
// Need at least minimum bytes to determine protocol
|
||||
if buf.Len() < MinComLen {
|
||||
return nil, ErrNeedMore
|
||||
}
|
||||
|
||||
// Check if header is already parsed
|
||||
if ctx.FlagType == connection.FlagUnknown {
|
||||
// Try to parse header
|
||||
if err := p.parseHeader(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Now parse the packet
|
||||
return p.parsePacket(ctx)
|
||||
}
|
||||
|
||||
// parseHeader parses the protocol header with obfuscation handling
|
||||
func (p *Parser) parseHeader(ctx *connection.Context) error {
|
||||
buf := ctx.InBuffer
|
||||
header := buf.Peek(MinComLen)
|
||||
if header == nil || len(header) < MinComLen {
|
||||
return ErrNeedMore
|
||||
}
|
||||
|
||||
// Try to decode the header using all encryption methods
|
||||
flagType, encType, decrypted := CheckHead(header)
|
||||
|
||||
if flagType == FlagUnknown {
|
||||
return ErrUnsupported
|
||||
}
|
||||
|
||||
// Store decrypted header params for later use (for XOREncoder16)
|
||||
ctx.HeaderParams = make([]byte, MinComLen)
|
||||
if decrypted != nil {
|
||||
copy(ctx.HeaderParams, decrypted)
|
||||
} else {
|
||||
copy(ctx.HeaderParams, header)
|
||||
}
|
||||
|
||||
// Map protocol FlagType to connection FlagType and set compression method
|
||||
switch flagType {
|
||||
case FlagHell:
|
||||
ctx.FlagType = connection.FlagHell
|
||||
ctx.FlagLen = FlagLength
|
||||
ctx.HeaderLen = ctx.FlagLen + 8
|
||||
ctx.CompressMethod = connection.CompressZstd // HELL uses ZSTD
|
||||
case FlagHello:
|
||||
ctx.FlagType = connection.FlagHello
|
||||
ctx.FlagLen = 8
|
||||
ctx.HeaderLen = ctx.FlagLen + 8
|
||||
ctx.CompressMethod = connection.CompressNone // HELLO uses no compression
|
||||
case FlagShine:
|
||||
ctx.FlagType = connection.FlagShine
|
||||
ctx.FlagLen = 5
|
||||
ctx.HeaderLen = ctx.FlagLen + 8
|
||||
ctx.CompressMethod = connection.CompressZstd // SHINE uses ZSTD
|
||||
case FlagFuck:
|
||||
ctx.FlagType = connection.FlagFuck
|
||||
ctx.FlagLen = 11
|
||||
ctx.HeaderLen = ctx.FlagLen + 8
|
||||
ctx.CompressMethod = connection.CompressZstd // FUCK uses ZSTD
|
||||
default:
|
||||
return ErrUnsupported
|
||||
}
|
||||
|
||||
// Store encryption type for later use
|
||||
ctx.HeaderEncType = int(encType)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parsePacket parses a complete packet
|
||||
func (p *Parser) parsePacket(ctx *connection.Context) ([]byte, error) {
|
||||
buf := ctx.InBuffer
|
||||
|
||||
// Check if we have enough data for header
|
||||
if buf.Len() < ctx.HeaderLen {
|
||||
return nil, ErrNeedMore
|
||||
}
|
||||
|
||||
// Peek the header to get total length
|
||||
headerData := buf.Peek(ctx.HeaderLen)
|
||||
if headerData == nil {
|
||||
return nil, ErrNeedMore
|
||||
}
|
||||
|
||||
// Decrypt the header first
|
||||
decryptedHeader := make([]byte, len(headerData))
|
||||
copy(decryptedHeader, headerData)
|
||||
|
||||
// Get the encryption key (usually at position 6)
|
||||
var key byte
|
||||
if len(headerData) > 6 {
|
||||
key = headerData[6] // Use original key before decryption
|
||||
}
|
||||
|
||||
// Decrypt flag portion
|
||||
if ctx.HeaderEncType >= 0 && ctx.HeaderEncType < len(decryptMethods) {
|
||||
decryptMethods[ctx.HeaderEncType](decryptedHeader[:FlagCompLen], key)
|
||||
}
|
||||
|
||||
// Read the total length field (after flag)
|
||||
totalLen := binary.LittleEndian.Uint32(decryptedHeader[ctx.FlagLen:])
|
||||
|
||||
// Validate length
|
||||
if totalLen < uint32(ctx.HeaderLen) || totalLen > 10*1024*1024 {
|
||||
return nil, ErrInvalidData
|
||||
}
|
||||
|
||||
// Check if we have the complete packet
|
||||
if buf.Len() < int(totalLen) {
|
||||
return nil, ErrNeedMore
|
||||
}
|
||||
|
||||
// Read the complete packet
|
||||
packet := buf.Read(int(totalLen))
|
||||
if packet == nil {
|
||||
return nil, ErrInvalidData
|
||||
}
|
||||
|
||||
// Decrypt header portion of packet
|
||||
if ctx.HeaderEncType >= 0 && ctx.HeaderEncType < len(decryptMethods) {
|
||||
decryptMethods[ctx.HeaderEncType](packet[:FlagCompLen], key)
|
||||
}
|
||||
|
||||
// Update HeaderParams with this packet's header (for XOREncoder16 k1, k2)
|
||||
if len(packet) >= ctx.FlagLen {
|
||||
ctx.HeaderParams = make([]byte, ctx.HeaderLen)
|
||||
copy(ctx.HeaderParams, packet[:ctx.HeaderLen])
|
||||
}
|
||||
|
||||
// Extract data after header
|
||||
dataStart := ctx.HeaderLen
|
||||
data := packet[dataStart:]
|
||||
|
||||
// Get original length (before compression)
|
||||
var origLen uint32
|
||||
if ctx.FlagType != connection.FlagWinOS {
|
||||
origLen = binary.LittleEndian.Uint32(packet[ctx.FlagLen+4 : ctx.FlagLen+8])
|
||||
}
|
||||
|
||||
// Decode (XOR, etc.) before decompression - this is Encoder2
|
||||
p.codec.Decode(ctx, data)
|
||||
|
||||
// Decompress
|
||||
decompressed, err := p.codec.Decompress(ctx, data, origLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Decode after decompression - this is Encoder
|
||||
p.codec.DecodeData(ctx, decompressed)
|
||||
|
||||
return decompressed, nil
|
||||
}
|
||||
|
||||
// Encode encodes data for sending
|
||||
func (p *Parser) Encode(ctx *connection.Context, data []byte) ([]byte, error) {
|
||||
// Encode before compression
|
||||
encoded := make([]byte, len(data))
|
||||
copy(encoded, data)
|
||||
p.codec.EncodeData(ctx, encoded)
|
||||
|
||||
// Compress
|
||||
compressed, err := p.codec.Compress(ctx, encoded)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build packet
|
||||
packet := p.buildPacket(ctx, compressed, uint32(len(data)))
|
||||
|
||||
return packet, nil
|
||||
}
|
||||
|
||||
// buildPacket builds a complete packet with header
|
||||
func (p *Parser) buildPacket(ctx *connection.Context, data []byte, origLen uint32) []byte {
|
||||
totalLen := ctx.HeaderLen + len(data)
|
||||
packet := make([]byte, totalLen)
|
||||
|
||||
// Write flag
|
||||
flag := p.getFlag(ctx)
|
||||
copy(packet[:ctx.FlagLen], flag)
|
||||
|
||||
// Write total length
|
||||
binary.LittleEndian.PutUint32(packet[ctx.FlagLen:], uint32(totalLen))
|
||||
|
||||
// Write original length
|
||||
binary.LittleEndian.PutUint32(packet[ctx.FlagLen+4:], origLen)
|
||||
|
||||
// Write data
|
||||
copy(packet[ctx.HeaderLen:], data)
|
||||
|
||||
// Encode after building
|
||||
p.codec.Encode(ctx, packet[ctx.HeaderLen:])
|
||||
|
||||
return packet
|
||||
}
|
||||
|
||||
// getFlag returns the protocol flag bytes
|
||||
func (p *Parser) getFlag(ctx *connection.Context) []byte {
|
||||
switch ctx.FlagType {
|
||||
case connection.FlagHell:
|
||||
flag := make([]byte, FlagLength)
|
||||
copy(flag, []byte(MsgHeader))
|
||||
return flag
|
||||
case connection.FlagHello:
|
||||
flag := make([]byte, 8)
|
||||
copy(flag, []byte("Hello?"))
|
||||
return flag
|
||||
case connection.FlagShine:
|
||||
return []byte("Shine")
|
||||
case connection.FlagFuck:
|
||||
flag := make([]byte, 11)
|
||||
copy(flag, []byte("<<FUCK>>"))
|
||||
return flag
|
||||
default:
|
||||
return make([]byte, ctx.FlagLen)
|
||||
}
|
||||
}
|
||||
316
server/go/server/server.go
Normal file
316
server/go/server/server.go
Normal file
@@ -0,0 +1,316 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/yuanyuanxiang/SimpleRemoter/server/go/connection"
|
||||
"github.com/yuanyuanxiang/SimpleRemoter/server/go/logger"
|
||||
"github.com/yuanyuanxiang/SimpleRemoter/server/go/protocol"
|
||||
)
|
||||
|
||||
// Config holds server configuration
|
||||
type Config struct {
|
||||
Port int
|
||||
MaxConnections int
|
||||
ReadBufferSize int
|
||||
WriteBufferSize int
|
||||
KeepAliveTime time.Duration
|
||||
ReadTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
}
|
||||
|
||||
// DefaultConfig returns default configuration
|
||||
func DefaultConfig() Config {
|
||||
return Config{
|
||||
Port: 6543,
|
||||
MaxConnections: 9999,
|
||||
ReadBufferSize: 8192,
|
||||
WriteBufferSize: 8192,
|
||||
KeepAliveTime: time.Minute * 5,
|
||||
ReadTimeout: time.Minute * 2,
|
||||
WriteTimeout: time.Second * 30,
|
||||
}
|
||||
}
|
||||
|
||||
// Handler defines the interface for handling client events
|
||||
type Handler interface {
|
||||
OnConnect(ctx *connection.Context)
|
||||
OnDisconnect(ctx *connection.Context)
|
||||
OnReceive(ctx *connection.Context, data []byte)
|
||||
}
|
||||
|
||||
// Server is the TCP server
|
||||
type Server struct {
|
||||
config Config
|
||||
listener net.Listener
|
||||
manager *connection.Manager
|
||||
handler Handler
|
||||
parser *protocol.Parser
|
||||
|
||||
running atomic.Bool
|
||||
wg sync.WaitGroup
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// Logger
|
||||
logger *logger.Logger
|
||||
}
|
||||
|
||||
// New creates a new server
|
||||
func New(config Config) *Server {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
s := &Server{
|
||||
config: config,
|
||||
manager: connection.NewManager(config.MaxConnections),
|
||||
parser: protocol.NewParser(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: logger.New(logger.DefaultConfig()).WithPrefix("Server"),
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// SetHandler sets the event handler
|
||||
func (s *Server) SetHandler(h Handler) {
|
||||
s.handler = h
|
||||
s.manager.SetCallbacks(
|
||||
func(ctx *connection.Context) {
|
||||
if s.handler != nil {
|
||||
s.handler.OnConnect(ctx)
|
||||
}
|
||||
},
|
||||
func(ctx *connection.Context) {
|
||||
if s.handler != nil {
|
||||
s.handler.OnDisconnect(ctx)
|
||||
}
|
||||
},
|
||||
func(ctx *connection.Context, data []byte) {
|
||||
if s.handler != nil {
|
||||
s.handler.OnReceive(ctx, data)
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// SetLogger sets the logger
|
||||
func (s *Server) SetLogger(l *logger.Logger) {
|
||||
s.logger = l
|
||||
}
|
||||
|
||||
// Start starts the server
|
||||
func (s *Server) Start() error {
|
||||
addr := net.JoinHostPort("0.0.0.0", itoa(s.config.Port))
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.listener = listener
|
||||
s.running.Store(true)
|
||||
|
||||
s.logger.Info("Server started on port %d", s.config.Port)
|
||||
|
||||
s.wg.Add(1)
|
||||
go s.acceptLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func itoa(i int) string {
|
||||
if i == 0 {
|
||||
return "0"
|
||||
}
|
||||
var b [20]byte
|
||||
pos := len(b)
|
||||
neg := i < 0
|
||||
if neg {
|
||||
// Handle math.MinInt safely by using unsigned
|
||||
u := uint(-i)
|
||||
for u > 0 {
|
||||
pos--
|
||||
b[pos] = byte('0' + u%10)
|
||||
u /= 10
|
||||
}
|
||||
pos--
|
||||
b[pos] = '-'
|
||||
} else {
|
||||
for i > 0 {
|
||||
pos--
|
||||
b[pos] = byte('0' + i%10)
|
||||
i /= 10
|
||||
}
|
||||
}
|
||||
return string(b[pos:])
|
||||
}
|
||||
|
||||
// Stop stops the server gracefully
|
||||
func (s *Server) Stop() {
|
||||
if !s.running.Swap(false) {
|
||||
return
|
||||
}
|
||||
|
||||
s.cancel()
|
||||
|
||||
if s.listener != nil {
|
||||
_ = s.listener.Close()
|
||||
}
|
||||
|
||||
// Close all connections
|
||||
s.manager.CloseAll()
|
||||
|
||||
// Close parser resources
|
||||
if s.parser != nil {
|
||||
s.parser.Close()
|
||||
}
|
||||
|
||||
s.wg.Wait()
|
||||
s.logger.Info("Server stopped")
|
||||
}
|
||||
|
||||
// acceptLoop accepts incoming connections
|
||||
func (s *Server) acceptLoop() {
|
||||
defer s.wg.Done()
|
||||
|
||||
for s.running.Load() {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
if s.running.Load() {
|
||||
s.logger.Error("Accept error: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Check connection limit before spawning goroutine
|
||||
if s.manager.Count() >= s.config.MaxConnections {
|
||||
s.logger.Warn("Max connections reached, rejecting new connection from %s", conn.RemoteAddr())
|
||||
_ = conn.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle each connection in its own goroutine
|
||||
go s.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnection handles a single connection
|
||||
func (s *Server) handleConnection(conn net.Conn) {
|
||||
// Create context
|
||||
ctx := connection.NewContext(conn, nil)
|
||||
|
||||
// Add to manager
|
||||
if err := s.manager.Add(ctx); err != nil {
|
||||
s.logger.Warn("Failed to add connection: %v", err)
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
s.manager.Remove(ctx)
|
||||
_ = ctx.Close()
|
||||
}()
|
||||
|
||||
// Read loop
|
||||
buf := make([]byte, s.config.ReadBufferSize)
|
||||
for !ctx.IsClosed() && s.running.Load() {
|
||||
// Set read deadline
|
||||
if s.config.ReadTimeout > 0 {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(s.config.ReadTimeout))
|
||||
}
|
||||
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
if err != io.EOF && s.running.Load() {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
// Timeout - check keepalive
|
||||
if s.config.KeepAliveTime > 0 && ctx.TimeSinceLastActive() > s.config.KeepAliveTime {
|
||||
s.logger.Info("Connection %d timed out", ctx.ID)
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
ctx.UpdateLastActive()
|
||||
|
||||
// Write to input buffer
|
||||
_, _ = ctx.InBuffer.Write(buf[:n])
|
||||
|
||||
// Process received data
|
||||
s.processData(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processData processes received data and calls handler
|
||||
func (s *Server) processData(ctx *connection.Context) {
|
||||
for ctx.InBuffer.Len() > 0 {
|
||||
// Try to parse a complete packet
|
||||
data, err := s.parser.Parse(ctx)
|
||||
if err != nil {
|
||||
if err == protocol.ErrNeedMore {
|
||||
return
|
||||
}
|
||||
s.logger.Error("Parse error for connection %d: %v", ctx.ID, err)
|
||||
_ = ctx.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if data != nil {
|
||||
// Call handler
|
||||
s.manager.OnReceive(ctx, data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send sends data to a specific connection
|
||||
func (s *Server) Send(ctx *connection.Context, data []byte) error {
|
||||
if ctx == nil || ctx.IsClosed() {
|
||||
return connection.ErrConnectionClosed
|
||||
}
|
||||
|
||||
// Encode and compress data
|
||||
encoded, err := s.parser.Encode(ctx, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ctx.Send(encoded)
|
||||
}
|
||||
|
||||
// Broadcast sends data to all connections
|
||||
func (s *Server) Broadcast(data []byte) {
|
||||
s.manager.Range(func(ctx *connection.Context) bool {
|
||||
_ = s.Send(ctx, data)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// GetConnection returns a connection by ID
|
||||
func (s *Server) GetConnection(id uint64) *connection.Context {
|
||||
return s.manager.Get(id)
|
||||
}
|
||||
|
||||
// ConnectionCount returns the current connection count
|
||||
func (s *Server) ConnectionCount() int {
|
||||
return s.manager.Count()
|
||||
}
|
||||
|
||||
// Port returns the server port
|
||||
func (s *Server) Port() int {
|
||||
return s.config.Port
|
||||
}
|
||||
|
||||
// UpdateMaxConnections updates the max connections limit
|
||||
func (s *Server) UpdateMaxConnections(max int) {
|
||||
s.manager.UpdateMaxConnections(max)
|
||||
}
|
||||
Reference in New Issue
Block a user