mirror of
https://github.com/yuanyuanxiang/SimpleRemoter.git
synced 2026-01-22 07:14:15 +08:00
338 lines
7.0 KiB
Go
338 lines
7.0 KiB
Go
package server
|
|
|
|
import (
	"context"
	"io"
	"net"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"github.com/yuanyuanxiang/SimpleRemoter/server/go/connection"
	"github.com/yuanyuanxiang/SimpleRemoter/server/go/logger"
	"github.com/yuanyuanxiang/SimpleRemoter/server/go/protocol"
)
|
|
|
|
// Config holds the tunable settings for a Server.
type Config struct {
	// Port is the TCP port Start listens on (bound on 0.0.0.0).
	Port int
	// MaxConnections is passed to the connection manager in New and is also
	// compared against the live connection count in acceptLoop.
	MaxConnections int
	// ReadBufferSize is the size in bytes of the per-connection buffer
	// allocated in handleConnection for socket reads.
	ReadBufferSize int
	// WriteBufferSize is not referenced in this file; presumably consumed
	// elsewhere — TODO confirm.
	WriteBufferSize int
	// KeepAliveTime is the idle threshold checked when a read times out:
	// a connection idle for longer than this is dropped. A value <= 0
	// disables the check.
	KeepAliveTime time.Duration
	// ReadTimeout is the deadline applied before every socket read.
	// A value <= 0 means no read deadline is set.
	ReadTimeout time.Duration
	// WriteTimeout is not referenced in this file; presumably consumed
	// elsewhere — TODO confirm.
	WriteTimeout time.Duration
}
|
|
|
|
// DefaultConfig returns default configuration
|
|
func DefaultConfig() Config {
|
|
return Config{
|
|
Port: 6543,
|
|
MaxConnections: 9999,
|
|
ReadBufferSize: 8192,
|
|
WriteBufferSize: 8192,
|
|
KeepAliveTime: time.Minute * 5,
|
|
ReadTimeout: time.Minute * 2,
|
|
WriteTimeout: time.Second * 30,
|
|
}
|
|
}
|
|
|
|
// Handler defines the interface for handling client events.
//
// Server.SetHandler registers wrappers around these methods as the
// connection manager's callbacks; the exact moment the manager fires the
// connect/disconnect callbacks is decided by the connection package.
type Handler interface {
	// OnConnect is the connect callback; presumably fired when a connection
	// is added to the manager — confirm in connection.Manager.
	OnConnect(ctx *connection.Context)
	// OnDisconnect is the disconnect callback; presumably fired when a
	// connection is removed from the manager — confirm in connection.Manager.
	OnDisconnect(ctx *connection.Context)
	// OnReceive is the receive callback. processData hands every complete
	// packet returned by the parser to manager.OnReceive together with the
	// originating connection.
	OnReceive(ctx *connection.Context, data []byte)
}
|
|
|
|
// Server is the TCP server. Create one with New, optionally attach a Handler
// and Logger, then call Start and eventually Stop.
type Server struct {
	// config is the configuration value passed to New.
	config Config
	// listener is the listening socket; set by Start and closed by Stop.
	listener net.Listener
	// manager tracks the live connections and dispatches the callbacks
	// installed by SetHandler.
	manager *connection.Manager
	// handler is the user-supplied event handler; nil until SetHandler is called.
	handler Handler
	// parser decodes inbound packets (processData) and encodes outbound
	// ones (Send); it is closed by Stop.
	parser *protocol.Parser

	// running is set to true by Start and swapped to false by Stop; the
	// accept and read loops poll it to know when to exit.
	running atomic.Bool
	// wg tracks the accept loop goroutine so Stop can wait for it.
	wg sync.WaitGroup
	// ctx and cancel are created in New; cancel is invoked by Stop.
	ctx    context.Context
	cancel context.CancelFunc

	// Logger
	logger *logger.Logger
}
|
|
|
|
// New creates a new server
|
|
func New(config Config) *Server {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
s := &Server{
|
|
config: config,
|
|
manager: connection.NewManager(config.MaxConnections),
|
|
parser: protocol.NewParser(),
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
logger: logger.New(logger.DefaultConfig()).WithPrefix("Server"),
|
|
}
|
|
|
|
return s
|
|
}
|
|
|
|
// SetHandler sets the event handler
|
|
func (s *Server) SetHandler(h Handler) {
|
|
s.handler = h
|
|
s.manager.SetCallbacks(
|
|
func(ctx *connection.Context) {
|
|
if s.handler != nil {
|
|
s.handler.OnConnect(ctx)
|
|
}
|
|
},
|
|
func(ctx *connection.Context) {
|
|
if s.handler != nil {
|
|
s.handler.OnDisconnect(ctx)
|
|
}
|
|
},
|
|
func(ctx *connection.Context, data []byte) {
|
|
if s.handler != nil {
|
|
s.handler.OnReceive(ctx, data)
|
|
}
|
|
},
|
|
)
|
|
}
|
|
|
|
// SetLogger sets the logger
|
|
func (s *Server) SetLogger(l *logger.Logger) {
|
|
s.logger = l
|
|
}
|
|
|
|
// Start starts the server
|
|
func (s *Server) Start() error {
|
|
addr := net.JoinHostPort("0.0.0.0", itoa(s.config.Port))
|
|
listener, err := net.Listen("tcp", addr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
s.listener = listener
|
|
s.running.Store(true)
|
|
|
|
s.logger.Info("Server started on port %d", s.config.Port)
|
|
|
|
s.wg.Add(1)
|
|
go s.acceptLoop()
|
|
|
|
return nil
|
|
}
|
|
|
|
// itoa returns the base-10 string form of i, including a leading '-' for
// negative values. It delegates to strconv.Itoa, which handles zero and the
// minimum int value.
func itoa(i int) string {
	return strconv.Itoa(i)
}
|
|
|
|
// Stop stops the server gracefully
|
|
func (s *Server) Stop() {
|
|
if !s.running.Swap(false) {
|
|
return
|
|
}
|
|
|
|
s.cancel()
|
|
|
|
if s.listener != nil {
|
|
_ = s.listener.Close()
|
|
}
|
|
|
|
// Close all connections
|
|
s.manager.CloseAll()
|
|
|
|
// Close parser resources
|
|
if s.parser != nil {
|
|
s.parser.Close()
|
|
}
|
|
|
|
s.wg.Wait()
|
|
s.logger.Info("Server stopped")
|
|
}
|
|
|
|
// acceptLoop accepts incoming connections
|
|
func (s *Server) acceptLoop() {
|
|
defer s.wg.Done()
|
|
|
|
for s.running.Load() {
|
|
conn, err := s.listener.Accept()
|
|
if err != nil {
|
|
if s.running.Load() {
|
|
s.logger.Error("Accept error: %v", err)
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Check connection limit before spawning goroutine
|
|
if s.manager.Count() >= s.config.MaxConnections {
|
|
s.logger.Warn("Max connections reached, rejecting new connection from %s", conn.RemoteAddr())
|
|
_ = conn.Close()
|
|
continue
|
|
}
|
|
|
|
// Handle each connection in its own goroutine
|
|
go s.handleConnection(conn)
|
|
}
|
|
}
|
|
|
|
// handleConnection handles a single connection
|
|
func (s *Server) handleConnection(conn net.Conn) {
|
|
// Create context
|
|
ctx := connection.NewContext(conn, nil)
|
|
|
|
// Add to manager
|
|
if err := s.manager.Add(ctx); err != nil {
|
|
s.logger.Warn("Failed to add connection: %v", err)
|
|
_ = conn.Close()
|
|
return
|
|
}
|
|
|
|
defer func() {
|
|
s.manager.Remove(ctx)
|
|
_ = ctx.Close()
|
|
}()
|
|
|
|
// Read loop
|
|
buf := make([]byte, s.config.ReadBufferSize)
|
|
for !ctx.IsClosed() && s.running.Load() {
|
|
// Set read deadline
|
|
if s.config.ReadTimeout > 0 {
|
|
_ = conn.SetReadDeadline(time.Now().Add(s.config.ReadTimeout))
|
|
}
|
|
|
|
n, err := conn.Read(buf)
|
|
if err != nil {
|
|
if err != io.EOF && s.running.Load() {
|
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
|
// Timeout - check keepalive
|
|
if s.config.KeepAliveTime > 0 && ctx.TimeSinceLastActive() > s.config.KeepAliveTime {
|
|
s.logger.Info("Connection %d timed out", ctx.ID)
|
|
break
|
|
}
|
|
continue
|
|
}
|
|
}
|
|
break
|
|
}
|
|
|
|
if n > 0 {
|
|
ctx.UpdateLastActive()
|
|
|
|
// Write to input buffer
|
|
_, _ = ctx.InBuffer.Write(buf[:n])
|
|
|
|
// Process received data
|
|
s.processData(ctx)
|
|
}
|
|
}
|
|
}
|
|
|
|
// processData processes received data and calls handler
|
|
func (s *Server) processData(ctx *connection.Context) {
|
|
for ctx.InBuffer.Len() > 0 {
|
|
// Try to parse a complete packet
|
|
data, err := s.parser.Parse(ctx)
|
|
if err != nil {
|
|
if err == protocol.ErrNeedMore {
|
|
return
|
|
}
|
|
// Log more details for unsupported protocol errors
|
|
if err == protocol.ErrUnsupported {
|
|
// Get first 32 bytes for debugging
|
|
peekLen := 32
|
|
if ctx.InBuffer.Len() < peekLen {
|
|
peekLen = ctx.InBuffer.Len()
|
|
}
|
|
rawData := ctx.InBuffer.Peek(peekLen)
|
|
// Sanitize for safe logging (replace non-printable chars)
|
|
ascii := make([]byte, len(rawData))
|
|
for i, b := range rawData {
|
|
if b >= 32 && b < 127 {
|
|
ascii[i] = b
|
|
} else {
|
|
ascii[i] = '.'
|
|
}
|
|
}
|
|
s.logger.Warn("Unsupported protocol from ip=%s conn=%d raw=%x ascii=%s",
|
|
ctx.GetPeerIP(), ctx.ID, rawData, string(ascii))
|
|
} else {
|
|
s.logger.Error("Parse error for connection %d: %v", ctx.ID, err)
|
|
}
|
|
_ = ctx.Close()
|
|
return
|
|
}
|
|
|
|
if data != nil {
|
|
// Call handler
|
|
s.manager.OnReceive(ctx, data)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Send sends data to a specific connection
|
|
func (s *Server) Send(ctx *connection.Context, data []byte) error {
|
|
if ctx == nil || ctx.IsClosed() {
|
|
return connection.ErrConnectionClosed
|
|
}
|
|
|
|
// Encode and compress data
|
|
encoded, err := s.parser.Encode(ctx, data)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return ctx.Send(encoded)
|
|
}
|
|
|
|
// Broadcast sends data to all connections
|
|
func (s *Server) Broadcast(data []byte) {
|
|
s.manager.Range(func(ctx *connection.Context) bool {
|
|
_ = s.Send(ctx, data)
|
|
return true
|
|
})
|
|
}
|
|
|
|
// GetConnection returns a connection by ID
|
|
func (s *Server) GetConnection(id uint64) *connection.Context {
|
|
return s.manager.Get(id)
|
|
}
|
|
|
|
// ConnectionCount returns the current connection count
|
|
func (s *Server) ConnectionCount() int {
|
|
return s.manager.Count()
|
|
}
|
|
|
|
// Port returns the server port
|
|
func (s *Server) Port() int {
|
|
return s.config.Port
|
|
}
|
|
|
|
// UpdateMaxConnections updates the max connections limit
|
|
func (s *Server) UpdateMaxConnections(max int) {
|
|
s.manager.UpdateMaxConnections(max)
|
|
}
|