// Mirror of https://github.com/yuanyuanxiang/SimpleRemoter.git
// (synced 2026-01-21 23:13:08 +08:00; 269 lines, 7.4 KiB, Go).
|
|
package main
|
||
|
|
|
||
|
|
import (
|
||
|
|
"flag"
|
||
|
|
"fmt"
|
||
|
|
"os"
|
||
|
|
"os/signal"
|
||
|
|
"strconv"
|
||
|
|
"strings"
|
||
|
|
"syscall"
|
||
|
|
|
||
|
|
"github.com/yuanyuanxiang/SimpleRemoter/server/go/auth"
|
||
|
|
"github.com/yuanyuanxiang/SimpleRemoter/server/go/connection"
|
||
|
|
"github.com/yuanyuanxiang/SimpleRemoter/server/go/logger"
|
||
|
|
"github.com/yuanyuanxiang/SimpleRemoter/server/go/protocol"
|
||
|
|
"github.com/yuanyuanxiang/SimpleRemoter/server/go/server"
|
||
|
|
)
|
||
|
|
|
||
|
|
// MyHandler implements the server.Handler interface
|
||
|
|
type MyHandler struct {
|
||
|
|
log *logger.Logger
|
||
|
|
auth *auth.Authenticator
|
||
|
|
srv *server.Server
|
||
|
|
}
|
||
|
|
|
||
|
|
// OnConnect is called when a client connects
|
||
|
|
func (h *MyHandler) OnConnect(ctx *connection.Context) {
|
||
|
|
// Only log connection established, detailed info logged on login
|
||
|
|
}
|
||
|
|
|
||
|
|
// OnDisconnect is called when a client disconnects
|
||
|
|
func (h *MyHandler) OnDisconnect(ctx *connection.Context) {
|
||
|
|
info := ctx.GetInfo()
|
||
|
|
if info.ClientID != "" {
|
||
|
|
h.log.ClientEvent("offline", ctx.ID, ctx.GetPeerIP(),
|
||
|
|
"clientID", info.ClientID,
|
||
|
|
"computer", info.ComputerName,
|
||
|
|
)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// OnReceive is called when data is received from a client
|
||
|
|
func (h *MyHandler) OnReceive(ctx *connection.Context, data []byte) {
|
||
|
|
if len(data) == 0 {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
cmd := data[0]
|
||
|
|
// Handle commands
|
||
|
|
switch cmd {
|
||
|
|
case protocol.TokenLogin:
|
||
|
|
h.handleLogin(ctx, data)
|
||
|
|
case protocol.TokenAuth:
|
||
|
|
h.handleAuth(ctx, data)
|
||
|
|
case protocol.TokenHeartbeat:
|
||
|
|
h.handleHeartbeat(ctx, data)
|
||
|
|
default:
|
||
|
|
// Other commands are not implemented yet
|
||
|
|
h.log.Info("Unhandled command %d from client %d", cmd, ctx.ID)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// handleLogin handles client login (TOKEN_LOGIN = 102)
|
||
|
|
func (h *MyHandler) handleLogin(ctx *connection.Context, data []byte) {
|
||
|
|
info, err := protocol.ParseLoginInfo(data)
|
||
|
|
if err != nil {
|
||
|
|
h.log.Error("Failed to parse login info from client %d: %v", ctx.ID, err)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
// Use MasterID from login request as ClientID for logging
|
||
|
|
clientID := info.MasterID
|
||
|
|
if clientID == "" {
|
||
|
|
clientID = fmt.Sprintf("conn-%d", ctx.ID)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Store client info
|
||
|
|
reserved := info.ParseReserved()
|
||
|
|
clientInfo := connection.ClientInfo{
|
||
|
|
ClientID: clientID,
|
||
|
|
ComputerName: info.PCName,
|
||
|
|
OS: info.OsVerInfo,
|
||
|
|
Version: info.ModuleVersion,
|
||
|
|
HasCamera: info.WebCamExist,
|
||
|
|
InstallTime: info.StartTime,
|
||
|
|
}
|
||
|
|
|
||
|
|
// Parse additional info from reserved field
|
||
|
|
if len(reserved) > 0 {
|
||
|
|
clientInfo.ClientType = info.GetReservedField(0)
|
||
|
|
}
|
||
|
|
if len(reserved) > 2 {
|
||
|
|
clientInfo.CPU = info.GetReservedField(2)
|
||
|
|
}
|
||
|
|
if len(reserved) > 4 {
|
||
|
|
clientInfo.FilePath = info.GetReservedField(4)
|
||
|
|
}
|
||
|
|
if len(reserved) > 11 {
|
||
|
|
clientInfo.IP = info.GetReservedField(11) // Public IP
|
||
|
|
}
|
||
|
|
|
||
|
|
ctx.SetInfo(clientInfo)
|
||
|
|
ctx.IsLoggedIn.Store(true)
|
||
|
|
|
||
|
|
h.log.ClientEvent("online", ctx.ID, ctx.GetPeerIP(),
|
||
|
|
"clientID", clientID,
|
||
|
|
"computer", info.PCName,
|
||
|
|
"os", info.OsVerInfo,
|
||
|
|
"version", info.ModuleVersion,
|
||
|
|
"path", clientInfo.FilePath,
|
||
|
|
)
|
||
|
|
}
|
||
|
|
|
||
|
|
// handleAuth handles authorization request (TOKEN_AUTH = 100)
|
||
|
|
func (h *MyHandler) handleAuth(ctx *connection.Context, data []byte) {
|
||
|
|
result := h.auth.Authenticate(data)
|
||
|
|
info := ctx.GetInfo()
|
||
|
|
|
||
|
|
if result.Valid {
|
||
|
|
h.log.Info("Auth success: clientID=%s computer=%s ip=%s sn=%s passcode=%s",
|
||
|
|
info.ClientID, info.ComputerName, ctx.GetPeerIP(), result.SN, result.Passcode)
|
||
|
|
} else {
|
||
|
|
h.log.Warn("Auth failed: clientID=%s computer=%s ip=%s sn=%s passcode=%s",
|
||
|
|
info.ClientID, info.ComputerName, ctx.GetPeerIP(), result.SN, result.Passcode)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Build and send response
|
||
|
|
resp := h.auth.BuildResponse(result)
|
||
|
|
if err := h.srv.Send(ctx, resp); err != nil {
|
||
|
|
h.log.Error("Failed to send auth response to client %d: %v", ctx.ID, err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// handleHeartbeat handles heartbeat from client (TOKEN_HEARTBEAT = 101)
|
||
|
|
func (h *MyHandler) handleHeartbeat(ctx *connection.Context, data []byte) {
|
||
|
|
|
||
|
|
// Parse Time from heartbeat request (offset 1, 8 bytes)
|
||
|
|
var hbTime uint64
|
||
|
|
if len(data) >= 9 {
|
||
|
|
hbTime = uint64(data[1]) | uint64(data[2])<<8 | uint64(data[3])<<16 | uint64(data[4])<<24 |
|
||
|
|
uint64(data[5])<<32 | uint64(data[6])<<40 | uint64(data[7])<<48 | uint64(data[8])<<56
|
||
|
|
}
|
||
|
|
|
||
|
|
// Build HeartbeatACK response: CMD_HEARTBEAT_ACK(1) + HeartbeatACK(32)
|
||
|
|
resp := make([]byte, 33)
|
||
|
|
resp[0] = protocol.CommandHeartbeat // CMD_HEARTBEAT_ACK = 216
|
||
|
|
// Time at offset 1 (8 bytes, little-endian)
|
||
|
|
resp[1] = byte(hbTime)
|
||
|
|
resp[2] = byte(hbTime >> 8)
|
||
|
|
resp[3] = byte(hbTime >> 16)
|
||
|
|
resp[4] = byte(hbTime >> 24)
|
||
|
|
resp[5] = byte(hbTime >> 32)
|
||
|
|
resp[6] = byte(hbTime >> 40)
|
||
|
|
resp[7] = byte(hbTime >> 48)
|
||
|
|
resp[8] = byte(hbTime >> 56)
|
||
|
|
// Reserved[24] at offset 9 is already zero
|
||
|
|
|
||
|
|
if err := h.srv.Send(ctx, resp); err != nil {
|
||
|
|
h.log.Error("Failed to send heartbeat ACK to client %d: %v", ctx.ID, err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// parsePorts parses a semicolon-separated list of port numbers such as
// "6543;6544;6545".
//
// Whitespace around each entry is trimmed and empty entries are skipped. The
// ports are returned in the order they were given. An error (with a nil slice)
// is returned when an entry is not a decimal integer, when it lies outside
// 1-65535, or when no port is left after skipping empty entries. A conversion
// failure wraps the underlying strconv error, so callers can inspect it with
// errors.Is / errors.As.
func parsePorts(portStr string) ([]int, error) {
	parts := strings.Split(portStr, ";")
	ports := make([]int, 0, len(parts))
	for _, p := range parts {
		p = strings.TrimSpace(p)
		if p == "" {
			// Tolerate stray separators such as "6543;;6544" or a trailing ";".
			continue
		}
		port, err := strconv.Atoi(p)
		if err != nil {
			// %w keeps the strconv error in the chain; the text is unchanged.
			return nil, fmt.Errorf("invalid port %q: %w", p, err)
		}
		if port < 1 || port > 65535 {
			return nil, fmt.Errorf("port %d out of range (1-65535)", port)
		}
		ports = append(ports, port)
	}
	if len(ports) == 0 {
		return nil, fmt.Errorf("no valid ports specified")
	}
	return ports, nil
}
|
||
|
|
|
||
|
|
func main() {
|
||
|
|
// Parse command line flags
|
||
|
|
portStr := flag.String("port", "6543", "Server listen ports (semicolon-separated, e.g. 6543;6544;6545)")
|
||
|
|
flag.StringVar(portStr, "p", "6543", "Server listen ports (shorthand)")
|
||
|
|
noConsole := flag.Bool("no-console", false, "Disable console output (for daemon mode)")
|
||
|
|
flag.Parse()
|
||
|
|
|
||
|
|
// Parse ports
|
||
|
|
ports, err := parsePorts(*portStr)
|
||
|
|
if err != nil {
|
||
|
|
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||
|
|
os.Exit(1)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Create logger with file output
|
||
|
|
logCfg := logger.DefaultConfig()
|
||
|
|
logCfg.Level = logger.LevelDebug
|
||
|
|
logCfg.Console = !*noConsole
|
||
|
|
logCfg.File = "logs/server.log"
|
||
|
|
logCfg.MaxSize = 100 // 100 MB
|
||
|
|
logCfg.MaxBackups = 10 // keep 10 old files
|
||
|
|
logCfg.MaxAge = 30 // 30 days
|
||
|
|
logCfg.Compress = true
|
||
|
|
|
||
|
|
log := logger.New(logCfg)
|
||
|
|
|
||
|
|
// Create auth config
|
||
|
|
authCfg := auth.DefaultConfig()
|
||
|
|
// PwdHash can be set from environment or config file
|
||
|
|
authCfg.PwdHash = os.Getenv("YAMA_PWDHASH")
|
||
|
|
if authCfg.PwdHash == "" {
|
||
|
|
// Default placeholder - should be configured in production
|
||
|
|
authCfg.PwdHash = "61f04dd637a74ee34493fc1025de2c131022536da751c29e3ff4e9024d8eec43"
|
||
|
|
}
|
||
|
|
authCfg.SuperPass = os.Getenv("YAMA_PWD")
|
||
|
|
|
||
|
|
// Create authenticator (shared by all servers)
|
||
|
|
authenticator := auth.New(authCfg)
|
||
|
|
|
||
|
|
// Create servers for each port
|
||
|
|
var servers []*server.Server
|
||
|
|
for _, port := range ports {
|
||
|
|
config := server.DefaultConfig()
|
||
|
|
config.Port = port
|
||
|
|
config.MaxConnections = 9999
|
||
|
|
|
||
|
|
srv := server.New(config)
|
||
|
|
srv.SetLogger(log.WithPrefix(fmt.Sprintf("Server:%d", port)))
|
||
|
|
|
||
|
|
// Create handler for this server
|
||
|
|
handler := &MyHandler{
|
||
|
|
log: log.WithPrefix(fmt.Sprintf("Handler:%d", port)),
|
||
|
|
auth: authenticator,
|
||
|
|
srv: srv,
|
||
|
|
}
|
||
|
|
srv.SetHandler(handler)
|
||
|
|
|
||
|
|
servers = append(servers, srv)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Start all servers
|
||
|
|
for _, srv := range servers {
|
||
|
|
if err := srv.Start(); err != nil {
|
||
|
|
log.Fatal("Failed to start server: %v", err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
fmt.Printf("Server started on port(s): %v\n", ports)
|
||
|
|
fmt.Println("Logs are written to: logs/server.log")
|
||
|
|
fmt.Println("Press Ctrl+C to stop...")
|
||
|
|
|
||
|
|
// Wait for interrupt signal
|
||
|
|
sigChan := make(chan os.Signal, 1)
|
||
|
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||
|
|
<-sigChan
|
||
|
|
|
||
|
|
fmt.Println("\nShutting down...")
|
||
|
|
for _, srv := range servers {
|
||
|
|
srv.Stop()
|
||
|
|
}
|
||
|
|
fmt.Println("Server stopped")
|
||
|
|
}
|