Compare commits

..

6 Commits

Author SHA1 Message Date
yyhuni
e8a5e0cea8 Update README.md 2026-01-17 17:09:49 +08:00
yyhuni
3308908d7a Update README.md 2026-01-16 21:43:08 +08:00
yyhuni
a8402cfffa Update project status in README 2026-01-16 17:56:44 +08:00
yyhuni
dce4e12667 Update README.md 2026-01-16 17:56:26 +08:00
github-actions[bot]
bd1dd2c0d5 chore: bump version to v1.5.8 2026-01-11 19:34:26 +08:00
yyhuni
0b6560ac17 Update README.md 2026-01-10 16:28:30 +08:00
1027 changed files with 10897 additions and 198987 deletions

View File

@@ -1,45 +0,0 @@
name: Check Generated Files
on:
workflow_call: # Runs only when called by another workflow
permissions:
contents: read
jobs:
check:
runs-on: ubuntu-22.04 # Pin the runner version so runner updates cannot change CI behavior
steps:
- uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: '1.24' # Keep in sync with go.mod
- name: Generate files for all workflows
working-directory: worker
run: make generate
- name: Check for differences
run: |
if ! git diff --exit-code; then
echo "❌ Generated files are out of date!"
echo "Please run: cd worker && make generate"
echo ""
echo "Changed files:"
git status --porcelain
echo ""
echo "Diff:"
git diff
exit 1
fi
echo "✅ Generated files are up to date"
- name: Run metadata consistency tests
working-directory: worker
run: make test-metadata
- name: Run all tests
working-directory: worker
run: make test

View File

@@ -1,13 +0,0 @@
name: CI
on:
push:
branches: [main, develop]
pull_request:
permissions:
contents: read
jobs:
check-generated:
uses: ./.github/workflows/check-generated-files.yml

View File

@@ -19,13 +19,8 @@ permissions:
contents: write
jobs:
# Check generated files before building
check:
uses: ./.github/workflows/check-generated-files.yml
# AMD64 build (native x64 runner)
build-amd64:
needs: check # Requires the check job to pass
runs-on: ubuntu-latest
strategy:
matrix:
@@ -101,7 +96,6 @@ jobs:
# ARM64 build (native ARM64 runner)
build-arm64:
needs: check # Requires the check job to pass
runs-on: ubuntu-22.04-arm
strategy:
matrix:

175
.gitignore vendored
View File

@@ -1,60 +1,137 @@
# Go
*.exe
*.exe~
*.dll
*.so
*.dylib
*.test
*.out
vendor/
go.work
# ============================
# OS-specific files
# ============================
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
# Build artifacts
dist/
build/
bin/
# ============================
# Frontend (Next.js/Node.js)
# ============================
# Dependency directories
front-back/node_modules/
front-back/.pnpm-store/
# IDE
# Next.js build output
front-back/.next/
front-back/out/
front-back/dist/
# Environment variable files
front-back/.env
front-back/.env.local
front-back/.env.development.local
front-back/.env.test.local
front-back/.env.production.local
# Runtime and caches
front-back/.turbo/
front-back/.swc/
front-back/.eslintcache
front-back/.tsbuildinfo
# ============================
# Backend (Python/Django)
# ============================
# Python virtual environments
.venv/
venv/
env/
ENV/
# Python bytecode
*.pyc
*.pyo
*.pyd
__pycache__/
*.py[cod]
*$py.class
# Django
backend/db.sqlite3
backend/db.sqlite3-journal
backend/media/
backend/staticfiles/
backend/.env
backend/.env.local
# Python testing and coverage
.pytest_cache/
.coverage
htmlcov/
*.cover
.hypothesis/
# ============================
# Backend (Go)
# ============================
# Build artifacts
backend/bin/
backend/dist/
backend/*.exe
backend/*.exe~
backend/*.dll
backend/*.so
backend/*.dylib
# Test artifacts
backend/*.test
backend/*.out
backend/*.prof
# Go workspace files
backend/go.work
backend/go.work.sum
# Go dependency management
backend/vendor/
# ============================
# IDEs and editors
# ============================
.vscode/
.idea/
.cursor/
.claude/
.kiro/
.playwright-mcp/
*.swp
*.swo
*~
.DS_Store
# Environment
.env
.env.local
.env.*.local
**/.env
**/.env.local
**/.env.*.local
*.log
.venv/
# ============================
# Docker
# ============================
docker/.env
docker/.env.local
# Testing
coverage.txt
*.coverprofile
.hypothesis/
# Temporary files
*.tmp
tmp/
temp/
.kiro/
.claude/
.specify/
# AI Assistant directories
codex/
openspec/
specs/
AGENTS.md
WARP.md
.opencode/
# SSL certificates
# SSL certificates and private keys (must not be committed)
docker/nginx/ssl/*.pem
docker/nginx/ssl/*.key
docker/nginx/ssl/*.crt
docker/nginx/ssl/*.crt
# ============================
# Log files and scan results
# ============================
*.log
logs/
results/
# Dev script runtime files (PIDs and startup logs)
backend/scripts/dev/.pids/
# ============================
# Temporary files
# ============================
tmp/
temp/
.cache/
HGETALL
KEYS
vuln_scan/input_endpoints.txt
open-in-v0

333
README.md Normal file
View File

@@ -0,0 +1,333 @@
<h1 align="center">XingRin - 星环</h1>
<p align="center">
<b>Attack Surface Management (ASM) Platform | Automated Asset Discovery and Vulnerability Scanning</b>
</p>
<p align="center">
<a href="https://github.com/yyhuni/xingrin/stargazers"><img src="https://img.shields.io/github/stars/yyhuni/xingrin?style=flat-square&logo=github" alt="GitHub stars"></a>
<a href="https://github.com/yyhuni/xingrin/network/members"><img src="https://img.shields.io/github/forks/yyhuni/xingrin?style=flat-square&logo=github" alt="GitHub forks"></a>
<a href="https://github.com/yyhuni/xingrin/issues"><img src="https://img.shields.io/github/issues/yyhuni/xingrin?style=flat-square&logo=github" alt="GitHub issues"></a>
<a href="https://github.com/yyhuni/xingrin/blob/main/LICENSE"><img src="https://img.shields.io/badge/license-PolyForm%20NC-blue?style=flat-square" alt="License"></a>
</p>
<p align="center">
<a href="#功能特性">功能特性</a> •
<a href="#全局资产搜索">资产搜索</a> •
<a href="#快速开始">快速开始</a> •
<a href="#文档">文档</a> •
<a href="#反馈与贡献">反馈与贡献</a>
</p>
<p align="center">
<sub>Keywords: ASM | Attack Surface Management | Vulnerability Scanning | Asset Discovery | Asset Search | Bug Bounty | Penetration Testing | Nuclei | Subdomain Enumeration | EASM</sub>
</p>
---
## Live Demo
**[https://xingrin.vercel.app/](https://xingrin.vercel.app/)**
> UI showcase only; not connected to a backend database
---
<p align="center">
<b>Modern UI</b>
</p>
<p align="center">
<img src="docs/screenshots/light.png" alt="Light Mode" width="24%">
<img src="docs/screenshots/bubblegum.png" alt="Bubblegum" width="24%">
<img src="docs/screenshots/cosmic-night.png" alt="Cosmic Night" width="24%">
<img src="docs/screenshots/quantum-rose.png" alt="Quantum Rose" width="24%">
</p>
## Documentation
- [Technical Docs](./docs/README.md) - Technical documentation index (work in progress)
- [Quick Start](./docs/quick-start.md) - One-command installation and deployment guide
- [Version Management](./docs/version-management.md) - Git-tag-driven automated versioning system
- [Nuclei Template Architecture](./docs/nuclei-template-architecture.md) - Storage and synchronization of the template repository
- [Wordlist Architecture](./docs/wordlist-architecture.md) - Storage and synchronization of wordlist files
- [Scan Flow Architecture](./docs/scan-flow-architecture.md) - Full scan pipeline and tool orchestration
---
## Features
### Scanning Capabilities
| Feature | Status | Tools | Notes |
|------|------|------|------|
| Subdomain scanning | Done | Subfinder, Amass, PureDNS | Passive collection + active brute force, aggregating 50+ data sources |
| Port scanning | Done | Naabu | Custom port ranges |
| Site discovery | Done | HTTPX | HTTP probing; automatically captures title, status code, and tech stack |
| Fingerprinting | Done | XingFinger | 27,000+ fingerprint rules from multiple fingerprint sources |
| URL collection | Done | Waymore, Katana | Historical data + active crawling |
| Directory scanning | Done | FFUF | High-speed brute forcing with smart wordlists |
| Vulnerability scanning | Done | Nuclei, Dalfox | 9,000+ POC templates, XSS detection |
| Site screenshots | Done | Playwright | Stored as highly compressed WebP |
### Platform Capabilities
| Feature | Status | Notes |
|------|------|------|
| Target management | Done | Multi-level organization; supports domain and IP targets |
| Asset snapshots | Done | Compare scan results and track asset changes |
| Blacklist filtering | Done | Global and per-target; supports wildcards/CIDR (see the examples below) |
| Scheduled tasks | Done | Cron expressions for automated periodic scans |
| Distributed scanning | Done | Multiple Worker nodes with load-aware scheduling |
| Global search | Done | Expression syntax with multi-field combined queries |
| Notifications | Done | WeCom, Telegram, Discord |
| API key management | Done | Visual configuration of API keys for each data source |
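As an illustration of the blacklist and scheduling formats, the values below are hypothetical examples of wildcard/CIDR blacklist entries and a standard cron expression; they are not taken from the project documentation:
```bash
# Blacklist entries may be wildcards or CIDR ranges, for example:
#   *.internal.example.com   # exclude every subdomain of internal.example.com
#   10.0.0.0/8               # exclude the whole 10.x.x.x private range
# Scheduled scans use standard cron expressions, for example:
#   0 2 * * 0                # every Sunday at 02:00
```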
### Scan Flow Architecture
The full scan flow covers subdomain discovery, port scanning, site discovery, fingerprinting, URL collection, directory scanning, and vulnerability scanning.
```mermaid
flowchart LR
START["开始扫描"]
subgraph STAGE1["阶段 1: 资产发现"]
direction TB
SUB["子域名发现<br/>subfinder, amass, puredns"]
PORT["端口扫描<br/>naabu"]
SITE["站点识别<br/>httpx"]
FINGER["指纹识别<br/>xingfinger"]
SUB --> PORT --> SITE --> FINGER
end
subgraph STAGE2["阶段 2: 深度分析"]
direction TB
URL["URL 收集<br/>waymore, katana"]
DIR["目录扫描<br/>ffuf"]
SCREENSHOT["站点截图<br/>playwright"]
end
subgraph STAGE3["阶段 3: 漏洞检测"]
VULN["漏洞扫描<br/>nuclei, dalfox"]
end
FINISH["扫描完成"]
START --> STAGE1
FINGER --> STAGE2
STAGE2 --> STAGE3
STAGE3 --> FINISH
style START fill:#34495e,stroke:#2c3e50,stroke-width:2px,color:#fff
style FINISH fill:#27ae60,stroke:#229954,stroke-width:2px,color:#fff
style STAGE1 fill:#3498db,stroke:#2980b9,stroke-width:2px,color:#fff
style STAGE2 fill:#9b59b6,stroke:#8e44ad,stroke-width:2px,color:#fff
style STAGE3 fill:#e67e22,stroke:#d35400,stroke-width:2px,color:#fff
style SUB fill:#5dade2,stroke:#3498db,stroke-width:1px,color:#fff
style PORT fill:#5dade2,stroke:#3498db,stroke-width:1px,color:#fff
style SITE fill:#5dade2,stroke:#3498db,stroke-width:1px,color:#fff
style FINGER fill:#5dade2,stroke:#3498db,stroke-width:1px,color:#fff
style URL fill:#bb8fce,stroke:#9b59b6,stroke-width:1px,color:#fff
style DIR fill:#bb8fce,stroke:#9b59b6,stroke-width:1px,color:#fff
style SCREENSHOT fill:#bb8fce,stroke:#9b59b6,stroke-width:1px,color:#fff
style VULN fill:#f0b27a,stroke:#e67e22,stroke-width:1px,color:#fff
```
See the [scan flow architecture doc](./docs/scan-flow-architecture.md) for details.
### Distributed Architecture
- **Multi-node scanning** - Deploy multiple Worker nodes to scale scanning capacity horizontally
- **Local node** - Zero configuration; installation automatically registers a local Docker Worker
- **Remote nodes** - One-click SSH deployment of a remote VPS as a scanning node
- **Load-aware scheduling** - Tracks node load in real time and dispatches each task to the best node (a sketch follows the diagram below)
- **Node monitoring** - Real-time heartbeat checks with CPU/memory/disk status monitoring
- **Reconnection** - Node outages are detected automatically, and nodes rejoin once they recover
```mermaid
flowchart TB
subgraph MASTER["主服务器 (Master Server)"]
direction TB
REDIS["Redis 负载缓存"]
subgraph SCHEDULER["任务调度器 (Task Distributor)"]
direction TB
SUBMIT["接收扫描任务"]
SELECT["负载感知选择"]
DISPATCH["智能分发"]
SUBMIT --> SELECT
SELECT --> DISPATCH
end
REDIS -.load data.-> SELECT
end
subgraph WORKERS["Worker 节点集群"]
direction TB
W1["Worker 1 (本地)<br/>CPU: 45% | MEM: 60%"]
W2["Worker 2 (远程)<br/>CPU: 30% | MEM: 40%"]
W3["Worker N (远程)<br/>CPU: 90% | MEM: 85%"]
end
DISPATCH -->|dispatch task| W1
DISPATCH -->|dispatch task| W2
DISPATCH -->|high load skipped| W3
W1 -.heartbeat.-> REDIS
W2 -.heartbeat.-> REDIS
W3 -.heartbeat.-> REDIS
```
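To make the load-aware selection step concrete, here is a minimal, self-contained Go sketch of the idea: skip nodes that are over any threshold or at their task limit, then pick the least-loaded remaining node. It is illustrative only; the `workerLoad` type, field names, and threshold values are invented for this example and are not the project's actual scheduler code.

```go
// Hypothetical sketch of load-aware worker selection, not the real scheduler.
package main

import "fmt"

// workerLoad mirrors the kind of metrics each node reports in its heartbeat.
type workerLoad struct {
	name            string
	cpu, mem, disk  float64 // usage percentages
	tasks, maxTasks int
}

// pickWorker returns the index of the best candidate, or -1 when every node
// is over a threshold or at its task limit (the "high load skipped" case).
func pickWorker(ws []workerLoad, cpuMax, memMax, diskMax float64) int {
	best := -1
	for i, w := range ws {
		if w.cpu >= cpuMax || w.mem >= memMax || w.disk >= diskMax || w.tasks >= w.maxTasks {
			continue // saturated node: skip it
		}
		if best == -1 || w.cpu+w.mem < ws[best].cpu+ws[best].mem {
			best = i // keep the least-loaded candidate so far
		}
	}
	return best
}

func main() {
	nodes := []workerLoad{
		{name: "worker-1-local", cpu: 45, mem: 60, disk: 30, tasks: 2, maxTasks: 5},
		{name: "worker-2-remote", cpu: 30, mem: 40, disk: 20, tasks: 1, maxTasks: 5},
		{name: "worker-n-remote", cpu: 90, mem: 85, disk: 70, tasks: 5, maxTasks: 5},
	}
	if i := pickWorker(nodes, 85, 85, 90); i >= 0 {
		fmt.Println("dispatch task to", nodes[i].name) // worker-2-remote
	}
}
```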
### Global Asset Search
- **Multiple asset types** - Supports both Website and Endpoint assets
- **Expression syntax** - Supports the `=` (fuzzy), `==` (exact), and `!=` (not equal) operators
- **Logical combinations** - Supports `&&` (AND) and `||` (OR)
- **Multi-field queries** - Supports the host, url, title, tech, status, body, and header fields
- **CSV export** - Streams the full result set with no row limit
#### Search Syntax Examples
```bash
# Basic searches
host="api" # host contains "api"
status=="200" # status code is exactly 200
tech="nginx" # tech stack contains nginx
# Combined searches
host="api" && status=="200" # host contains api and status code is 200
tech="vue" || tech="react" # tech stack contains vue or react
# Complex queries
host="admin" && tech="php" && status=="200"
url="/api/v1" && status!="404"
```
### Visual Interface
- **Statistics** - Asset and vulnerability dashboards
- **Real-time notifications** - WebSocket message push
- **Notification delivery** - Real-time message delivery to WeCom, Telegram, and Discord
---
## Quick Start
### Requirements
- **OS**: Ubuntu 20.04+ / Debian 11+
- **Architecture**: AMD64 (x86_64) / ARM64 (aarch64)
- **Hardware**: 2 CPU cores and 4 GB RAM minimum, 20 GB+ of disk space
### One-Command Install
```bash
# Clone the project
git clone https://github.com/yyhuni/xingrin.git
cd xingrin
# Install and start (production mode)
sudo ./install.sh
# Users in mainland China can enable mirror acceleration (third-party mirrors may stop working and are not guaranteed long term)
sudo ./install.sh --mirror
```
> **About the --mirror flag**
> - Automatically configures Docker registry mirrors (mainland China mirror sources); a generic example follows below
> - Speeds up Git repository cloning (Nuclei templates, etc.)
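For reference, Docker registry mirrors are normally configured in `/etc/docker/daemon.json`. The snippet below is only a generic sketch of that mechanism: the mirror URL is a placeholder, and the actual addresses written by `install.sh --mirror` may differ.

```bash
# Generic example of a Docker registry mirror configuration.
# The mirror URL is a placeholder, not the one used by install.sh.
sudo tee /etc/docker/daemon.json > /dev/null <<'EOF'
{
  "registry-mirrors": ["https://mirror.example.com"]
}
EOF
sudo systemctl restart docker   # reload the daemon so the mirror takes effect
```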
### Access the Service
- **Web UI**: `https://ip:8083`
- **Default credentials**: admin / admin (change the password after first login)
### Common Commands
```bash
# Start services
sudo ./start.sh
# Stop services
sudo ./stop.sh
# Restart services
sudo ./restart.sh
# Uninstall
sudo ./uninstall.sh
```
## Feedback and Contributing
- **Found a bug, or have ideas for the UI or features?** Please open an [Issue](https://github.com/yyhuni/xingrin/issues) or send a private message via the WeChat official account
## Contact
- WeChat official account: **塔罗安全学苑**
- WeChat group: open the menu at the bottom of the official account and tap the group chat entry; if the invite link has expired, message me and I'll add you
<img src="docs/wechat-qrcode.png" alt="WeChat official account" width="200">
### Follow the Official Account to Get Free Fingerprint Libraries
| Fingerprint library | Entries |
|--------|------|
| ehole.json | 21,977 |
| ARL.yaml | 9,264 |
| goby.json | 7,086 |
| FingerprintHub.json | 3,147 |
> Follow the official account and reply "指纹" (fingerprint) to receive them
## Sponsorship
If this project helps you, please consider treating me to a cup of Mixue. Your stars and sponsorship are what keep the free updates coming!
<p>
<img src="docs/wx_pay.jpg" alt="微信支付" width="200">
<img src="docs/zfb_pay.jpg" alt="支付宝" width="200">
</p>
## Disclaimer
**Important: read carefully before use**
1. This tool is intended solely for **authorized security testing** and **security research**
2. Users must ensure they have obtained **legal authorization** for the target systems
3. Using this tool for unauthorized penetration testing or attacks is **strictly prohibited**
4. Scanning other people's systems without authorization is **illegal** and may carry legal liability
5. The developers are **not responsible for any misuse**
By using this tool, you agree to:
- Use it only within the scope of legal authorization
- Comply with the laws and regulations of your jurisdiction
- Bear all consequences arising from misuse
## Star History
If this project helps you, please give it a Star!
[![Star History Chart](https://api.star-history.com/svg?repos=yyhuni/xingrin&type=Date)](https://star-history.com/#yyhuni/xingrin&Date)
## License
This project is licensed under the [GNU General Public License v3.0](LICENSE).
### Permitted Uses
- Personal learning and research
- Commercial and non-commercial use
- Modification and distribution
- Patent use
- Private use
### Obligations and Restrictions
- **Source disclosure**: source code must be provided when distributing
- **Same license**: derivative works must use the same license
- **Copyright notice**: the original copyright and license notices must be retained
- **No warranty**: provided without any warranty
### Prohibited Uses
- Unauthorized penetration testing
- Any illegal activity

View File

@@ -1 +1 @@
v1.5.12-dev
v1.5.8

View File

@@ -1,13 +0,0 @@
root = "."
tmp_dir = "tmp"
[build]
cmd = "go build -o ./tmp/agent ./cmd/agent"
bin = "./tmp/agent"
delay = 1000
include_ext = ["go", "tpl", "tmpl", "html"]
exclude_dir = ["tmp", "vendor", ".git"]
exclude_regex = ["_test\\.go"]
[log]
time = true

View File

@@ -1,41 +0,0 @@
# syntax=docker/dockerfile:1
# ============================================
# Go Agent - build
# ============================================
FROM golang:1.25.6 AS builder
ARG GO111MODULE=on
ARG GOPROXY=https://goproxy.cn,direct
ENV GO111MODULE=$GO111MODULE
ENV GOPROXY=$GOPROXY
WORKDIR /src
# Cache dependencies
COPY agent/go.mod agent/go.sum ./
RUN go mod download
# Copy source
COPY agent ./agent
WORKDIR /src/agent
# Build (static where possible)
RUN CGO_ENABLED=0 go build -o /out/agent ./cmd/agent
# ============================================
# Go Agent - runtime
# ============================================
FROM debian:bookworm-20260112-slim
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /app
COPY --from=builder /out/agent /usr/local/bin/agent
CMD ["agent"]

View File

@@ -1,37 +0,0 @@
package main
import (
"context"
"fmt"
"os"
"os/signal"
"syscall"
"github.com/yyhuni/lunafox/agent/internal/app"
"github.com/yyhuni/lunafox/agent/internal/config"
"github.com/yyhuni/lunafox/agent/internal/logger"
"go.uber.org/zap"
)
func main() {
if err := logger.Init(os.Getenv("LOG_LEVEL")); err != nil {
fmt.Fprintf(os.Stderr, "logger init failed: %v\n", err)
}
defer logger.Sync()
cfg, err := config.Load(os.Args[1:])
if err != nil {
logger.Log.Fatal("failed to load config", zap.Error(err))
}
wsURL, err := config.BuildWebSocketURL(cfg.ServerURL)
if err != nil {
logger.Log.Fatal("invalid server URL", zap.Error(err))
}
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
if err := app.Run(ctx, *cfg, wsURL); err != nil {
logger.Log.Fatal("agent stopped", zap.Error(err))
}
}

View File

@@ -1,48 +0,0 @@
module github.com/yyhuni/lunafox/agent
go 1.24.5
require (
github.com/docker/docker v28.5.2+incompatible
github.com/gorilla/websocket v1.5.3
github.com/opencontainers/image-spec v1.1.1
github.com/shirou/gopsutil/v3 v3.24.5
go.uber.org/zap v1.27.0
)
require (
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/containerd/errdefs v1.0.0 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/go-connections v0.6.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/sys/atomicwriter v0.1.0 // indirect
github.com/moby/term v0.5.2 // indirect
github.com/morikuni/aec v1.1.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 // indirect
go.opentelemetry.io/otel v1.39.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.39.0 // indirect
go.opentelemetry.io/otel/metric v1.39.0 // indirect
go.opentelemetry.io/otel/trace v1.39.0 // indirect
go.uber.org/multierr v1.10.0 // indirect
golang.org/x/sys v0.39.0 // indirect
golang.org/x/time v0.14.0 // indirect
gotest.tools/v3 v3.5.2 // indirect
)

View File

@@ -1,131 +0,0 @@
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg=
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/docker v28.5.2+incompatible h1:DBX0Y0zAjZbSrm1uzOkdr1onVghKaftjlSWt4AFexzM=
github.com/docker/docker v28.5.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw=
github.com/moby/sys/atomicwriter v0.1.0/go.mod h1:Ul8oqv2ZMNHOceF643P6FKPXeCmYtlQMvpizfsSoaWs=
github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU=
github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko=
github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ=
github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc=
github.com/morikuni/aec v1.1.0 h1:vBBl0pUnvi/Je71dsRrhMBtreIqNMYErSAbEeb8jrXQ=
github.com/morikuni/aec v1.1.0/go.mod h1:xDRgiq/iw5l+zkao76YTKzKttOp2cwPEne25HDkJnBw=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI=
github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk=
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 h1:ssfIgGNANqpVFCndZvcuyKbl0g+UAVcbBcqGkG28H0Y=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0/go.mod h1:GQ/474YrbE4Jx8gZ4q5I4hrhUzM6UPzyrqJYV2AqPoQ=
go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48=
go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 h1:f0cb2XPmrqn4XMy9PNliTgRKJgS5WcL/u0/WRYGz4t0=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0/go.mod h1:vnakAaFckOMiMtOIhFI2MNH4FYrZzXCYxmb1LlhoGz8=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.39.0 h1:Ckwye2FpXkYgiHX7fyVrN1uA/UYd9ounqqTuSNAv0k4=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.39.0/go.mod h1:teIFJh5pW2y+AN7riv6IBPX2DuesS3HgP39mwOspKwU=
go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0=
go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs=
go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18=
go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE=
go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8=
go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew=
go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI=
go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA=
go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A=
go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 h1:fCvbg86sFXwdrl5LgVcTEvNC+2txB5mgROGmRL5mrls=
google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM=
google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHhrNig=
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=

View File

@@ -1,139 +0,0 @@
package app
import (
"context"
"errors"
"os"
"strconv"
"time"
"github.com/yyhuni/lunafox/agent/internal/config"
"github.com/yyhuni/lunafox/agent/internal/docker"
"github.com/yyhuni/lunafox/agent/internal/domain"
"github.com/yyhuni/lunafox/agent/internal/health"
"github.com/yyhuni/lunafox/agent/internal/logger"
"github.com/yyhuni/lunafox/agent/internal/metrics"
"github.com/yyhuni/lunafox/agent/internal/protocol"
"github.com/yyhuni/lunafox/agent/internal/task"
"github.com/yyhuni/lunafox/agent/internal/update"
agentws "github.com/yyhuni/lunafox/agent/internal/websocket"
"go.uber.org/zap"
)
func Run(ctx context.Context, cfg config.Config, wsURL string) error {
configUpdater := config.NewUpdater(cfg)
version := cfg.AgentVersion
hostname := os.Getenv("AGENT_HOSTNAME")
if hostname == "" {
var err error
hostname, err = os.Hostname()
if err != nil || hostname == "" {
hostname = "unknown"
}
}
logger.Log.Info("agent starting",
zap.String("version", version),
zap.String("hostname", hostname),
zap.String("server", cfg.ServerURL),
zap.String("ws", wsURL),
zap.Int("maxTasks", cfg.MaxTasks),
zap.Int("cpuThreshold", cfg.CPUThreshold),
zap.Int("memThreshold", cfg.MemThreshold),
zap.Int("diskThreshold", cfg.DiskThreshold),
)
client := agentws.NewClient(wsURL, cfg.APIKey)
collector := metrics.NewCollector()
healthManager := health.NewManager()
taskCounter := &task.Counter{}
heartbeat := agentws.NewHeartbeatSender(client, collector, healthManager, version, hostname, taskCounter.Count)
taskClient := task.NewClient(cfg.ServerURL, cfg.APIKey)
puller := task.NewPuller(taskClient, collector, taskCounter, cfg.MaxTasks, cfg.CPUThreshold, cfg.MemThreshold, cfg.DiskThreshold)
taskQueue := make(chan *domain.Task, cfg.MaxTasks)
puller.SetOnTask(func(t *domain.Task) {
logger.Log.Info("task received",
zap.Int("taskId", t.ID),
zap.Int("scanId", t.ScanID),
zap.String("workflow", t.WorkflowName),
zap.Int("stage", t.Stage),
zap.String("target", t.TargetName),
)
taskQueue <- t
})
dockerClient, err := docker.NewClient()
if err != nil {
logger.Log.Warn("docker client unavailable", zap.Error(err))
} else {
logger.Log.Info("docker client ready")
}
workerToken := os.Getenv("WORKER_TOKEN")
if workerToken == "" {
return errors.New("WORKER_TOKEN environment variable is required")
}
logger.Log.Info("worker token loaded")
executor := task.NewExecutor(dockerClient, taskClient, taskCounter, cfg.ServerURL, workerToken, version)
defer func() {
shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := executor.Shutdown(shutdownCtx); err != nil && !errors.Is(err, context.DeadlineExceeded) {
logger.Log.Error("executor shutdown error", zap.Error(err))
}
}()
updater := update.NewUpdater(dockerClient, healthManager, puller, executor, configUpdater, cfg.APIKey, workerToken)
handler := agentws.NewHandler()
handler.OnTaskAvailable(puller.NotifyTaskAvailable)
handler.OnTaskCancel(func(taskID int) {
logger.Log.Info("task cancel requested", zap.Int("taskId", taskID))
executor.MarkCancelled(taskID)
executor.CancelTask(taskID)
})
handler.OnConfigUpdate(func(payload protocol.ConfigUpdatePayload) {
logger.Log.Info("config update received",
zap.String("maxTasks", formatOptionalInt(payload.MaxTasks)),
zap.String("cpuThreshold", formatOptionalInt(payload.CPUThreshold)),
zap.String("memThreshold", formatOptionalInt(payload.MemThreshold)),
zap.String("diskThreshold", formatOptionalInt(payload.DiskThreshold)),
)
cfgUpdate := config.Update{
MaxTasks: payload.MaxTasks,
CPUThreshold: payload.CPUThreshold,
MemThreshold: payload.MemThreshold,
DiskThreshold: payload.DiskThreshold,
}
configUpdater.Apply(cfgUpdate)
puller.UpdateConfig(cfgUpdate.MaxTasks, cfgUpdate.CPUThreshold, cfgUpdate.MemThreshold, cfgUpdate.DiskThreshold)
})
handler.OnUpdateRequired(updater.HandleUpdateRequired)
client.SetOnMessage(handler.Handle)
logger.Log.Info("starting heartbeat sender")
go heartbeat.Start(ctx)
logger.Log.Info("starting task puller")
go func() {
_ = puller.Run(ctx)
}()
logger.Log.Info("starting task executor")
go executor.Start(ctx, taskQueue)
logger.Log.Info("connecting to server websocket")
if err := client.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
return err
}
return nil
}
func formatOptionalInt(value *int) string {
if value == nil {
return "nil"
}
return strconv.Itoa(*value)
}

View File

@@ -1,53 +0,0 @@
package config
import (
"errors"
"fmt"
)
// Config represents runtime settings for the agent.
type Config struct {
ServerURL string
APIKey string
AgentVersion string
MaxTasks int
CPUThreshold int
MemThreshold int
DiskThreshold int
}
// Validate ensures config values are usable.
func (c *Config) Validate() error {
if c.ServerURL == "" {
return errors.New("server URL is required")
}
if c.APIKey == "" {
return errors.New("api key is required")
}
if c.AgentVersion == "" {
return errors.New("AGENT_VERSION environment variable is required")
}
if c.MaxTasks < 1 {
return errors.New("max tasks must be at least 1")
}
if err := validatePercent("cpu threshold", c.CPUThreshold); err != nil {
return err
}
if err := validatePercent("mem threshold", c.MemThreshold); err != nil {
return err
}
if err := validatePercent("disk threshold", c.DiskThreshold); err != nil {
return err
}
if _, err := BuildWebSocketURL(c.ServerURL); err != nil {
return err
}
return nil
}
func validatePercent(name string, value int) error {
if value < 1 || value > 100 {
return fmt.Errorf("%s must be between 1 and 100", name)
}
return nil
}

View File

@@ -1,87 +0,0 @@
package config
import (
"flag"
"fmt"
"os"
"strconv"
"strings"
)
const (
defaultMaxTasks = 5
defaultCPUThreshold = 85
defaultMemThreshold = 85
defaultDiskThreshold = 90
)
// Load parses configuration from environment variables and CLI flags.
func Load(args []string) (*Config, error) {
maxTasks, err := readEnvInt("MAX_TASKS", defaultMaxTasks)
if err != nil {
return nil, err
}
cpuThreshold, err := readEnvInt("CPU_THRESHOLD", defaultCPUThreshold)
if err != nil {
return nil, err
}
memThreshold, err := readEnvInt("MEM_THRESHOLD", defaultMemThreshold)
if err != nil {
return nil, err
}
diskThreshold, err := readEnvInt("DISK_THRESHOLD", defaultDiskThreshold)
if err != nil {
return nil, err
}
cfg := &Config{
ServerURL: strings.TrimSpace(os.Getenv("SERVER_URL")),
APIKey: strings.TrimSpace(os.Getenv("API_KEY")),
AgentVersion: strings.TrimSpace(os.Getenv("AGENT_VERSION")),
MaxTasks: maxTasks,
CPUThreshold: cpuThreshold,
MemThreshold: memThreshold,
DiskThreshold: diskThreshold,
}
fs := flag.NewFlagSet("agent", flag.ContinueOnError)
serverURL := fs.String("server-url", cfg.ServerURL, "Server base URL (e.g. https://1.1.1.1:8080)")
apiKey := fs.String("api-key", cfg.APIKey, "Agent API key")
maxTasksFlag := fs.Int("max-tasks", cfg.MaxTasks, "Maximum concurrent tasks")
cpuThresholdFlag := fs.Int("cpu-threshold", cfg.CPUThreshold, "CPU threshold percentage")
memThresholdFlag := fs.Int("mem-threshold", cfg.MemThreshold, "Memory threshold percentage")
diskThresholdFlag := fs.Int("disk-threshold", cfg.DiskThreshold, "Disk threshold percentage")
if err := fs.Parse(args); err != nil {
return nil, err
}
cfg.ServerURL = strings.TrimSpace(*serverURL)
cfg.APIKey = strings.TrimSpace(*apiKey)
cfg.MaxTasks = *maxTasksFlag
cfg.CPUThreshold = *cpuThresholdFlag
cfg.MemThreshold = *memThresholdFlag
cfg.DiskThreshold = *diskThresholdFlag
if err := cfg.Validate(); err != nil {
return nil, err
}
return cfg, nil
}
func readEnvInt(key string, fallback int) (int, error) {
val, ok := os.LookupEnv(key)
if !ok {
return fallback, nil
}
val = strings.TrimSpace(val)
if val == "" {
return fallback, nil
}
parsed, err := strconv.Atoi(val)
if err != nil {
return 0, fmt.Errorf("invalid %s: %w", key, err)
}
return parsed, nil
}

View File

@@ -1,75 +0,0 @@
package config
import (
"testing"
)
func TestLoadConfigFromEnvAndFlags(t *testing.T) {
t.Setenv("SERVER_URL", "https://example.com")
t.Setenv("API_KEY", "abc12345")
t.Setenv("AGENT_VERSION", "v1.2.3")
t.Setenv("MAX_TASKS", "5")
t.Setenv("CPU_THRESHOLD", "80")
t.Setenv("MEM_THRESHOLD", "81")
t.Setenv("DISK_THRESHOLD", "82")
cfg, err := Load([]string{})
if err != nil {
t.Fatalf("load failed: %v", err)
}
if cfg.ServerURL != "https://example.com" {
t.Fatalf("expected server url from env")
}
if cfg.MaxTasks != 5 {
t.Fatalf("expected max tasks from env")
}
args := []string{
"--server-url=https://override.example.com",
"--api-key=deadbeef",
"--max-tasks=9",
"--cpu-threshold=70",
"--mem-threshold=71",
"--disk-threshold=72",
}
cfg, err = Load(args)
if err != nil {
t.Fatalf("load failed: %v", err)
}
if cfg.ServerURL != "https://override.example.com" {
t.Fatalf("expected server url from args")
}
if cfg.APIKey != "deadbeef" {
t.Fatalf("expected api key from args")
}
if cfg.MaxTasks != 9 {
t.Fatalf("expected max tasks from args")
}
if cfg.CPUThreshold != 70 || cfg.MemThreshold != 71 || cfg.DiskThreshold != 72 {
t.Fatalf("expected thresholds from args")
}
}
func TestLoadConfigMissingRequired(t *testing.T) {
t.Setenv("SERVER_URL", "")
t.Setenv("API_KEY", "")
t.Setenv("AGENT_VERSION", "v1.2.3")
_, err := Load([]string{})
if err == nil {
t.Fatalf("expected error when required values missing")
}
}
func TestLoadConfigInvalidEnvValue(t *testing.T) {
t.Setenv("SERVER_URL", "https://example.com")
t.Setenv("API_KEY", "abc")
t.Setenv("AGENT_VERSION", "v1.2.3")
t.Setenv("MAX_TASKS", "nope")
_, err := Load([]string{})
if err == nil {
t.Fatalf("expected error for invalid MAX_TASKS")
}
}

View File

@@ -1,49 +0,0 @@
package config
import (
"sync"
"github.com/yyhuni/lunafox/agent/internal/domain"
)
// Update holds optional configuration updates.
type Update = domain.ConfigUpdate
// Updater manages runtime configuration changes.
type Updater struct {
mu sync.RWMutex
cfg Config
}
// NewUpdater creates an updater with initial config.
func NewUpdater(cfg Config) *Updater {
return &Updater{cfg: cfg}
}
// Apply updates the configuration and returns the new snapshot.
func (u *Updater) Apply(update Update) Config {
u.mu.Lock()
defer u.mu.Unlock()
if update.MaxTasks != nil && *update.MaxTasks > 0 {
u.cfg.MaxTasks = *update.MaxTasks
}
if update.CPUThreshold != nil && *update.CPUThreshold > 0 {
u.cfg.CPUThreshold = *update.CPUThreshold
}
if update.MemThreshold != nil && *update.MemThreshold > 0 {
u.cfg.MemThreshold = *update.MemThreshold
}
if update.DiskThreshold != nil && *update.DiskThreshold > 0 {
u.cfg.DiskThreshold = *update.DiskThreshold
}
return u.cfg
}
// Snapshot returns a copy of current config.
func (u *Updater) Snapshot() Config {
u.mu.RLock()
defer u.mu.RUnlock()
return u.cfg
}

View File

@@ -1,39 +0,0 @@
package config
import "testing"
func TestUpdaterApplyAndSnapshot(t *testing.T) {
cfg := Config{
ServerURL: "https://example.com",
APIKey: "key",
MaxTasks: 2,
CPUThreshold: 70,
MemThreshold: 80,
DiskThreshold: 90,
}
updater := NewUpdater(cfg)
snapshot := updater.Snapshot()
if snapshot.MaxTasks != 2 || snapshot.CPUThreshold != 70 {
t.Fatalf("unexpected snapshot values")
}
invalid := 0
update := Update{MaxTasks: &invalid, CPUThreshold: &invalid}
snapshot = updater.Apply(update)
if snapshot.MaxTasks != 2 || snapshot.CPUThreshold != 70 {
t.Fatalf("expected invalid update to be ignored")
}
maxTasks := 5
cpu := 85
mem := 60
snapshot = updater.Apply(Update{
MaxTasks: &maxTasks,
CPUThreshold: &cpu,
MemThreshold: &mem,
})
if snapshot.MaxTasks != 5 || snapshot.CPUThreshold != 85 || snapshot.MemThreshold != 60 {
t.Fatalf("unexpected applied update")
}
}

View File

@@ -1,50 +0,0 @@
package config
import (
"errors"
"fmt"
"net/url"
"strings"
)
// BuildWebSocketURL derives the agent WebSocket endpoint from the server URL.
func BuildWebSocketURL(serverURL string) (string, error) {
trimmed := strings.TrimSpace(serverURL)
if trimmed == "" {
return "", errors.New("server URL is required")
}
parsed, err := url.Parse(trimmed)
if err != nil {
return "", err
}
switch strings.ToLower(parsed.Scheme) {
case "http":
parsed.Scheme = "ws"
case "https":
parsed.Scheme = "wss"
case "ws", "wss":
default:
if parsed.Scheme == "" {
return "", errors.New("server URL scheme is required")
}
return "", fmt.Errorf("unsupported server URL scheme: %s", parsed.Scheme)
}
parsed.Path = buildWSPath(parsed.Path)
parsed.RawQuery = ""
parsed.Fragment = ""
return parsed.String(), nil
}
func buildWSPath(path string) string {
trimmed := strings.TrimRight(path, "/")
if trimmed == "" {
return "/api/agents/ws"
}
if strings.HasSuffix(trimmed, "/api") {
return trimmed + "/agents/ws"
}
return trimmed + "/api/agents/ws"
}

View File

@@ -1,38 +0,0 @@
package config
import "testing"
func TestBuildWebSocketURL(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"https://example.com", "wss://example.com/api/agents/ws"},
{"http://example.com", "ws://example.com/api/agents/ws"},
{"https://example.com/api", "wss://example.com/api/agents/ws"},
{"https://example.com/base", "wss://example.com/base/api/agents/ws"},
{"wss://example.com", "wss://example.com/api/agents/ws"},
}
for _, tt := range tests {
got, err := BuildWebSocketURL(tt.input)
if err != nil {
t.Fatalf("unexpected error for %s: %v", tt.input, err)
}
if got != tt.expected {
t.Fatalf("input %s expected %s got %s", tt.input, tt.expected, got)
}
}
}
func TestBuildWebSocketURLInvalid(t *testing.T) {
if _, err := BuildWebSocketURL("example.com"); err == nil {
t.Fatalf("expected error for missing scheme")
}
if _, err := BuildWebSocketURL(" "); err == nil {
t.Fatalf("expected error for empty url")
}
if _, err := BuildWebSocketURL("ftp://example.com"); err == nil {
t.Fatalf("expected error for unsupported scheme")
}
}

View File

@@ -1,23 +0,0 @@
package docker
import (
"context"
"github.com/docker/docker/api/types/container"
)
// Remove removes the container.
func (c *Client) Remove(ctx context.Context, containerID string) error {
return c.cli.ContainerRemove(ctx, containerID, container.RemoveOptions{
Force: true,
RemoveVolumes: true,
})
}
// Stop stops a running container with a timeout.
func (c *Client) Stop(ctx context.Context, containerID string) error {
timeout := 10
return c.cli.ContainerStop(ctx, containerID, container.StopOptions{
Timeout: &timeout,
})
}

View File

@@ -1,46 +0,0 @@
package docker
import (
"context"
"io"
"github.com/docker/docker/api/types/container"
imagetypes "github.com/docker/docker/api/types/image"
"github.com/docker/docker/api/types/network"
"github.com/docker/docker/client"
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
)
// Client wraps the Docker SDK client.
type Client struct {
cli *client.Client
}
// NewClient creates a Docker client using environment configuration.
func NewClient() (*Client, error) {
cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
if err != nil {
return nil, err
}
return &Client{cli: cli}, nil
}
// Close closes the Docker client.
func (c *Client) Close() error {
return c.cli.Close()
}
// ImagePull pulls an image from the registry.
func (c *Client) ImagePull(ctx context.Context, imageRef string) (io.ReadCloser, error) {
return c.cli.ImagePull(ctx, imageRef, imagetypes.PullOptions{})
}
// ContainerCreate creates a container.
func (c *Client) ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, name string) (container.CreateResponse, error) {
return c.cli.ContainerCreate(ctx, config, hostConfig, networkingConfig, platform, name)
}
// ContainerStart starts a container.
func (c *Client) ContainerStart(ctx context.Context, containerID string, opts container.StartOptions) error {
return c.cli.ContainerStart(ctx, containerID, opts)
}

View File

@@ -1,49 +0,0 @@
package docker
import (
"bytes"
"context"
"io"
"strconv"
"strings"
"github.com/docker/docker/api/types/container"
)
const (
maxErrorBytes = 4096
)
// TailLogs returns the last N lines of container logs, truncated to 4KB.
func (c *Client) TailLogs(ctx context.Context, containerID string, lines int) (string, error) {
reader, err := c.cli.ContainerLogs(ctx, containerID, container.LogsOptions{
ShowStdout: true,
ShowStderr: true,
Timestamps: false,
Tail: strconv.Itoa(lines),
})
if err != nil {
return "", err
}
defer reader.Close()
var buf bytes.Buffer
if _, err := io.Copy(&buf, reader); err != nil {
return "", err
}
out := buf.String()
out = strings.TrimSpace(out)
if len(out) > maxErrorBytes {
out = out[len(out)-maxErrorBytes:]
}
return out, nil
}
// TruncateErrorMessage clamps message length to 4KB.
func TruncateErrorMessage(message string) string {
if len(message) <= maxErrorBytes {
return message
}
return message[:maxErrorBytes]
}

View File

@@ -1,22 +0,0 @@
package docker
import (
"strings"
"testing"
)
func TestTruncateErrorMessage(t *testing.T) {
short := "short message"
if got := TruncateErrorMessage(short); got != short {
t.Fatalf("expected message to stay unchanged")
}
long := strings.Repeat("x", maxErrorBytes+10)
got := TruncateErrorMessage(long)
if len(got) != maxErrorBytes {
t.Fatalf("expected length %d, got %d", maxErrorBytes, len(got))
}
if got != long[:maxErrorBytes] {
t.Fatalf("unexpected truncation result")
}
}

View File

@@ -1,20 +0,0 @@
package docker
import (
"context"
"github.com/docker/docker/api/types/container"
)
// Wait waits for a container to stop and returns the exit code.
func (c *Client) Wait(ctx context.Context, containerID string) (int64, error) {
statusCh, errCh := c.cli.ContainerWait(ctx, containerID, container.WaitConditionNotRunning)
select {
case status := <-statusCh:
return status.StatusCode, nil
case err := <-errCh:
return 0, err
case <-ctx.Done():
return 0, ctx.Err()
}
}

View File

@@ -1,76 +0,0 @@
package docker
import (
"context"
"fmt"
"os"
"strings"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/network"
"github.com/docker/docker/api/types/strslice"
"github.com/yyhuni/lunafox/agent/internal/domain"
)
const workerImagePrefix = "yyhuni/lunafox-worker:"
// StartWorker starts a worker container for a task and returns the container ID.
func (c *Client) StartWorker(ctx context.Context, t *domain.Task, serverURL, serverToken, agentVersion string) (string, error) {
if t == nil {
return "", fmt.Errorf("task is nil")
}
if err := os.MkdirAll(t.WorkspaceDir, 0755); err != nil {
return "", fmt.Errorf("prepare workspace: %w", err)
}
image, err := resolveWorkerImage(agentVersion)
if err != nil {
return "", err
}
env := buildWorkerEnv(t, serverURL, serverToken)
config := &container.Config{
Image: image,
Env: env,
Cmd: strslice.StrSlice{},
}
hostConfig := &container.HostConfig{
Binds: []string{"/opt/lunafox:/opt/lunafox"},
AutoRemove: false,
OomScoreAdj: 500,
}
resp, err := c.cli.ContainerCreate(ctx, config, hostConfig, &network.NetworkingConfig{}, nil, "")
if err != nil {
return "", err
}
if err := c.cli.ContainerStart(ctx, resp.ID, container.StartOptions{}); err != nil {
return "", err
}
return resp.ID, nil
}
func resolveWorkerImage(version string) (string, error) {
version = strings.TrimSpace(version)
if version == "" {
return "", fmt.Errorf("worker version is required")
}
return workerImagePrefix + version, nil
}
func buildWorkerEnv(t *domain.Task, serverURL, serverToken string) []string {
return []string{
fmt.Sprintf("SERVER_URL=%s", serverURL),
fmt.Sprintf("SERVER_TOKEN=%s", serverToken),
fmt.Sprintf("SCAN_ID=%d", t.ScanID),
fmt.Sprintf("TARGET_ID=%d", t.TargetID),
fmt.Sprintf("TARGET_NAME=%s", t.TargetName),
fmt.Sprintf("TARGET_TYPE=%s", t.TargetType),
fmt.Sprintf("WORKFLOW_NAME=%s", t.WorkflowName),
fmt.Sprintf("WORKSPACE_DIR=%s", t.WorkspaceDir),
fmt.Sprintf("CONFIG=%s", t.Config),
}
}

View File

@@ -1,50 +0,0 @@
package docker
import (
"testing"
"github.com/yyhuni/lunafox/agent/internal/domain"
)
func TestResolveWorkerImage(t *testing.T) {
if _, err := resolveWorkerImage(""); err == nil {
t.Fatalf("expected error for empty version")
}
if got, err := resolveWorkerImage("v1.2.3"); err != nil || got != workerImagePrefix+"v1.2.3" {
t.Fatalf("expected version image, got %s, err: %v", got, err)
}
}
func TestBuildWorkerEnv(t *testing.T) {
spec := &domain.Task{
ScanID: 1,
TargetID: 2,
TargetName: "example.com",
TargetType: "domain",
WorkflowName: "subdomain_discovery",
WorkspaceDir: "/opt/lunafox/results",
Config: "config-yaml",
}
env := buildWorkerEnv(spec, "https://server", "token")
expected := []string{
"SERVER_URL=https://server",
"SERVER_TOKEN=token",
"SCAN_ID=1",
"TARGET_ID=2",
"TARGET_NAME=example.com",
"TARGET_TYPE=domain",
"WORKFLOW_NAME=subdomain_discovery",
"WORKSPACE_DIR=/opt/lunafox/results",
"CONFIG=config-yaml",
}
if len(env) != len(expected) {
t.Fatalf("expected %d env entries, got %d", len(expected), len(env))
}
for i, item := range expected {
if env[i] != item {
t.Fatalf("expected env[%d]=%s got %s", i, item, env[i])
}
}
}

View File

@@ -1,8 +0,0 @@
package domain
type ConfigUpdate struct {
MaxTasks *int `json:"maxTasks"`
CPUThreshold *int `json:"cpuThreshold"`
MemThreshold *int `json:"memThreshold"`
DiskThreshold *int `json:"diskThreshold"`
}

View File

@@ -1,10 +0,0 @@
package domain
import "time"
type HealthStatus struct {
State string `json:"state"`
Reason string `json:"reason,omitempty"`
Message string `json:"message,omitempty"`
Since *time.Time `json:"since,omitempty"`
}

View File

@@ -1,13 +0,0 @@
package domain
type Task struct {
ID int `json:"taskId"`
ScanID int `json:"scanId"`
Stage int `json:"stage"`
WorkflowName string `json:"workflowName"`
TargetID int `json:"targetId"`
TargetName string `json:"targetName"`
TargetType string `json:"targetType"`
WorkspaceDir string `json:"workspaceDir"`
Config string `json:"config"`
}

View File

@@ -1,6 +0,0 @@
package domain
type UpdateRequiredPayload struct {
Version string `json:"version"`
Image string `json:"image"`
}

View File

@@ -1,51 +0,0 @@
package health
import (
"sync"
"time"
"github.com/yyhuni/lunafox/agent/internal/domain"
)
// Status represents the agent health state reported in heartbeats.
type Status = domain.HealthStatus
// Manager stores current health status.
type Manager struct {
mu sync.RWMutex
status Status
}
// NewManager initializes the manager with ok status.
func NewManager() *Manager {
return &Manager{
status: Status{State: "ok"},
}
}
// Get returns a snapshot of current status.
func (m *Manager) Get() Status {
m.mu.RLock()
defer m.mu.RUnlock()
return m.status
}
// Set updates health status and timestamps transitions.
func (m *Manager) Set(state, reason, message string) {
m.mu.Lock()
defer m.mu.Unlock()
if m.status.State != state {
now := time.Now().UTC()
m.status.Since = &now
}
m.status.State = state
m.status.Reason = reason
m.status.Message = message
if state == "ok" {
m.status.Since = nil
m.status.Reason = ""
m.status.Message = ""
}
}

View File

@@ -1,33 +0,0 @@
package health
import "testing"
func TestManagerSetTransitions(t *testing.T) {
mgr := NewManager()
initial := mgr.Get()
if initial.State != "ok" || initial.Since != nil {
t.Fatalf("expected initial ok status")
}
mgr.Set("paused", "update", "waiting")
status := mgr.Get()
if status.State != "paused" || status.Since == nil {
t.Fatalf("expected paused state with timestamp")
}
prevSince := status.Since
mgr.Set("paused", "still", "waiting more")
status = mgr.Get()
if status.Since == nil || !status.Since.Equal(*prevSince) {
t.Fatalf("expected unchanged since on same state")
}
if status.Reason != "still" || status.Message != "waiting more" {
t.Fatalf("expected updated reason/message")
}
mgr.Set("ok", "ignored", "ignored")
status = mgr.Get()
if status.State != "ok" || status.Since != nil || status.Reason != "" || status.Message != "" {
t.Fatalf("expected ok reset to clear fields")
}
}

View File

@@ -1,50 +0,0 @@
package logger
import (
"os"
"strings"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// Log is the shared agent logger. Defaults to a no-op logger until initialized.
var Log = zap.NewNop()
// Init configures the logger using the provided level and ENV.
func Init(level string) error {
level = strings.TrimSpace(level)
if level == "" {
level = "info"
}
var zapLevel zapcore.Level
if err := zapLevel.UnmarshalText([]byte(level)); err != nil {
zapLevel = zapcore.InfoLevel
}
isDev := strings.EqualFold(os.Getenv("ENV"), "development")
var config zap.Config
if isDev {
config = zap.NewDevelopmentConfig()
config.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
} else {
config = zap.NewProductionConfig()
}
config.Level = zap.NewAtomicLevelAt(zapLevel)
logger, err := config.Build()
if err != nil {
Log = zap.NewNop()
return err
}
Log = logger
return nil
}
// Sync flushes any buffered log entries.
func Sync() {
if Log != nil {
_ = Log.Sync()
}
}

View File

@@ -1,58 +0,0 @@
package metrics
import (
"github.com/shirou/gopsutil/v3/cpu"
"github.com/shirou/gopsutil/v3/disk"
"github.com/shirou/gopsutil/v3/mem"
"github.com/yyhuni/lunafox/agent/internal/logger"
"go.uber.org/zap"
)
// Collector gathers system metrics.
type Collector struct{}
// NewCollector creates a new Collector.
func NewCollector() *Collector {
return &Collector{}
}
// Sample returns CPU, memory, and disk usage percentages.
func (c *Collector) Sample() (float64, float64, float64) {
cpuPercent, err := cpuUsagePercent()
if err != nil {
logger.Log.Warn("metrics: cpu percent error", zap.Error(err))
}
memPercent, err := memUsagePercent()
if err != nil {
logger.Log.Warn("metrics: mem percent error", zap.Error(err))
}
diskPercent, err := diskUsagePercent("/")
if err != nil {
logger.Log.Warn("metrics: disk percent error", zap.Error(err))
}
return cpuPercent, memPercent, diskPercent
}
func cpuUsagePercent() (float64, error) {
values, err := cpu.Percent(0, false)
if err != nil || len(values) == 0 {
return 0, err
}
return values[0], nil
}
func memUsagePercent() (float64, error) {
info, err := mem.VirtualMemory()
if err != nil {
return 0, err
}
return info.UsedPercent, nil
}
func diskUsagePercent(path string) (float64, error) {
info, err := disk.Usage(path)
if err != nil {
return 0, err
}
return info.UsedPercent, nil
}

View File

@@ -1,11 +0,0 @@
package metrics
import "testing"
func TestCollectorSample(t *testing.T) {
c := NewCollector()
cpu, mem, disk := c.Sample()
if cpu < 0 || mem < 0 || disk < 0 {
t.Fatalf("expected non-negative metrics")
}
}

View File

@@ -1,42 +0,0 @@
package protocol
import (
"time"
"github.com/yyhuni/lunafox/agent/internal/domain"
)
const (
MessageTypeHeartbeat = "heartbeat"
MessageTypeTaskAvailable = "task_available"
MessageTypeTaskCancel = "task_cancel"
MessageTypeConfigUpdate = "config_update"
MessageTypeUpdateRequired = "update_required"
)
type Message struct {
Type string `json:"type"`
Payload interface{} `json:"payload"`
Timestamp time.Time `json:"timestamp"`
}
type HealthStatus = domain.HealthStatus
type HeartbeatPayload struct {
CPU float64 `json:"cpu"`
Mem float64 `json:"mem"`
Disk float64 `json:"disk"`
Tasks int `json:"tasks"`
Version string `json:"version"`
Hostname string `json:"hostname"`
Uptime int64 `json:"uptime"`
Health HealthStatus `json:"health"`
}
type ConfigUpdatePayload = domain.ConfigUpdate
type UpdateRequiredPayload = domain.UpdateRequiredPayload
type TaskCancelPayload struct {
TaskID int `json:"taskId"`
}

View File

@@ -1,118 +0,0 @@
package task
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"github.com/yyhuni/lunafox/agent/internal/domain"
)
// Client handles HTTP API requests to the server.
type Client struct {
baseURL string
apiKey string
http *http.Client
}
// NewClient creates a new task client.
func NewClient(serverURL, apiKey string) *Client {
transport := http.DefaultTransport
if base, ok := transport.(*http.Transport); ok {
clone := base.Clone()
clone.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
transport = clone
}
return &Client{
baseURL: strings.TrimRight(serverURL, "/"),
apiKey: apiKey,
http: &http.Client{
Timeout: 15 * time.Second,
Transport: transport,
},
}
}
// PullTask requests a task from the server. Returns nil when no task available.
func (c *Client) PullTask(ctx context.Context) (*domain.Task, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/api/agent/tasks/pull", nil)
if err != nil {
return nil, err
}
req.Header.Set("X-Agent-Key", c.apiKey)
resp, err := c.http.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNoContent {
return nil, nil
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("pull task failed: status %d", resp.StatusCode)
}
var task domain.Task
if err := json.NewDecoder(resp.Body).Decode(&task); err != nil {
return nil, err
}
return &task, nil
}
// UpdateStatus reports task status to the server with retry.
func (c *Client) UpdateStatus(ctx context.Context, taskID int, status, errorMessage string) error {
payload := map[string]string{
"status": status,
}
if errorMessage != "" {
payload["errorMessage"] = errorMessage
}
body, err := json.Marshal(payload)
if err != nil {
return err
}
var lastErr error
for attempt := 0; attempt < 3; attempt++ {
if attempt > 0 {
backoff := time.Duration(5<<attempt) * time.Second // 10s before the 2nd attempt, 20s before the 3rd
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(backoff):
}
}
req, err := http.NewRequestWithContext(ctx, http.MethodPatch, fmt.Sprintf("%s/api/agent/tasks/%d/status", c.baseURL, taskID), bytes.NewReader(body))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Agent-Key", c.apiKey)
resp, err := c.http.Do(req)
if err != nil {
lastErr = err
continue
}
resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return nil
}
lastErr = fmt.Errorf("update status failed: status %d", resp.StatusCode)
// Don't retry 4xx client errors (except 429)
if resp.StatusCode >= 400 && resp.StatusCode < 500 && resp.StatusCode != 429 {
return lastErr
}
}
return lastErr
}
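
A minimal usage sketch of the client above; the server URL, API key, and the assumption that the pulled task already ran to completion are placeholders rather than the agent's real call sites.

package main

import (
    "context"
    "log"
    "time"

    "github.com/yyhuni/lunafox/agent/internal/task"
)

func main() {
    client := task.NewClient("https://server.example.com", "agent-api-key") // placeholder credentials
    ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
    defer cancel()

    t, err := client.PullTask(ctx)
    if err != nil {
        log.Fatalf("pull failed: %v", err)
    }
    if t == nil {
        log.Println("no task available")
        return
    }
    // ... run the task, then report the outcome:
    if err := client.UpdateStatus(ctx, t.ID, "completed", ""); err != nil {
        log.Printf("status update failed: %v", err)
    }
}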

View File

@@ -1,187 +0,0 @@
package task
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"strings"
"testing"
"github.com/yyhuni/lunafox/agent/internal/domain"
)
func TestClientPullTaskNoContent(t *testing.T) {
client := &Client{
baseURL: "http://example",
apiKey: "key",
http: &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.URL.Path != "/api/agent/tasks/pull" {
t.Fatalf("unexpected path %s", r.URL.Path)
}
return &http.Response{
StatusCode: http.StatusNoContent,
Body: io.NopCloser(strings.NewReader("")),
Header: http.Header{},
}, nil
}),
},
}
task, err := client.PullTask(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if task != nil {
t.Fatalf("expected nil task")
}
}
func TestClientPullTaskOK(t *testing.T) {
client := &Client{
baseURL: "http://example",
apiKey: "key",
http: &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.Header.Get("X-Agent-Key") == "" {
t.Fatalf("missing api key header")
}
body, _ := json.Marshal(domain.Task{ID: 1})
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader(body)),
Header: http.Header{},
}, nil
}),
},
}
task, err := client.PullTask(context.Background())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if task == nil || task.ID != 1 {
t.Fatalf("unexpected task")
}
}
func TestClientUpdateStatus(t *testing.T) {
client := &Client{
baseURL: "http://example",
apiKey: "key",
http: &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
if r.Method != http.MethodPatch {
t.Fatalf("expected PATCH")
}
if r.Header.Get("X-Agent-Key") == "" {
t.Fatalf("missing api key header")
}
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("")),
Header: http.Header{},
}, nil
}),
},
}
if err := client.UpdateStatus(context.Background(), 1, "completed", ""); err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
func TestClientPullTaskErrorStatus(t *testing.T) {
client := &Client{
baseURL: "http://example",
apiKey: "key",
http: &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusBadRequest,
Body: io.NopCloser(strings.NewReader("bad")),
Header: http.Header{},
}, nil
}),
},
}
if _, err := client.PullTask(context.Background()); err == nil {
t.Fatalf("expected error for non-200 status")
}
}
func TestClientPullTaskBadJSON(t *testing.T) {
client := &Client{
baseURL: "http://example",
apiKey: "key",
http: &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("{bad json")),
Header: http.Header{},
}, nil
}),
},
}
if _, err := client.PullTask(context.Background()); err == nil {
t.Fatalf("expected error for invalid json")
}
}
func TestClientUpdateStatusIncludesErrorMessage(t *testing.T) {
client := &Client{
baseURL: "http://example",
apiKey: "key",
http: &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
body, err := io.ReadAll(r.Body)
if err != nil {
t.Fatalf("read body: %v", err)
}
var payload map[string]string
if err := json.Unmarshal(body, &payload); err != nil {
t.Fatalf("unmarshal body: %v", err)
}
if payload["status"] != "failed" {
t.Fatalf("expected status failed")
}
if payload["errorMessage"] != "boom" {
t.Fatalf("expected error message")
}
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("")),
Header: http.Header{},
}, nil
}),
},
}
if err := client.UpdateStatus(context.Background(), 1, "failed", "boom"); err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
func TestClientUpdateStatusErrorStatus(t *testing.T) {
client := &Client{
baseURL: "http://example",
apiKey: "key",
http: &http.Client{
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusInternalServerError,
Body: io.NopCloser(strings.NewReader("")),
Header: http.Header{},
}, nil
}),
},
}
if err := client.UpdateStatus(context.Background(), 1, "completed", ""); err == nil {
t.Fatalf("expected error for non-200 status")
}
}
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return f(r)
}

View File

@@ -1,23 +0,0 @@
package task
import "sync/atomic"
// Counter tracks running task count.
type Counter struct {
value int64
}
// Inc increments the counter.
func (c *Counter) Inc() {
atomic.AddInt64(&c.value, 1)
}
// Dec decrements the counter.
func (c *Counter) Dec() {
atomic.AddInt64(&c.value, -1)
}
// Count returns current count.
func (c *Counter) Count() int {
return int(atomic.LoadInt64(&c.value))
}

View File

@@ -1,18 +0,0 @@
package task
import "testing"
func TestCounterIncDec(t *testing.T) {
var counter Counter
counter.Inc()
counter.Inc()
if got := counter.Count(); got != 2 {
t.Fatalf("expected count 2, got %d", got)
}
counter.Dec()
if got := counter.Count(); got != 1 {
t.Fatalf("expected count 1, got %d", got)
}
}

View File

@@ -1,258 +0,0 @@
package task
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/yyhuni/lunafox/agent/internal/docker"
"github.com/yyhuni/lunafox/agent/internal/domain"
)
const defaultMaxRuntime = 7 * 24 * time.Hour
// Executor runs tasks inside worker containers.
type Executor struct {
docker DockerRunner
client statusReporter
counter *Counter
serverURL string
workerToken string
agentVersion string
maxRuntime time.Duration
mu sync.Mutex
running map[int]context.CancelFunc
cancelMu sync.Mutex
cancelled map[int]struct{}
wg sync.WaitGroup
stopping atomic.Bool
}
type statusReporter interface {
UpdateStatus(ctx context.Context, taskID int, status, errorMessage string) error
}
type DockerRunner interface {
StartWorker(ctx context.Context, t *domain.Task, serverURL, serverToken, agentVersion string) (string, error)
Wait(ctx context.Context, containerID string) (int64, error)
Stop(ctx context.Context, containerID string) error
Remove(ctx context.Context, containerID string) error
TailLogs(ctx context.Context, containerID string, lines int) (string, error)
}
// NewExecutor creates an Executor.
func NewExecutor(dockerClient DockerRunner, taskClient statusReporter, counter *Counter, serverURL, workerToken, agentVersion string) *Executor {
return &Executor{
docker: dockerClient,
client: taskClient,
counter: counter,
serverURL: serverURL,
workerToken: workerToken,
agentVersion: agentVersion,
maxRuntime: defaultMaxRuntime,
running: map[int]context.CancelFunc{},
cancelled: map[int]struct{}{},
}
}
// Start processes tasks from the queue.
func (e *Executor) Start(ctx context.Context, tasks <-chan *domain.Task) {
for {
select {
case <-ctx.Done():
return
case t, ok := <-tasks:
if !ok {
return
}
if t == nil {
continue
}
if e.stopping.Load() {
// During shutdown/update: drain the queue but don't start new work.
continue
}
if e.isCancelled(t.ID) {
e.reportStatus(ctx, t.ID, "cancelled", "")
e.clearCancelled(t.ID)
continue
}
go e.execute(ctx, t)
}
}
}
// CancelTask requests cancellation of a running task.
func (e *Executor) CancelTask(taskID int) {
e.mu.Lock()
cancel := e.running[taskID]
e.mu.Unlock()
if cancel != nil {
cancel()
}
}
// MarkCancelled records a task as cancelled to prevent execution.
func (e *Executor) MarkCancelled(taskID int) {
e.cancelMu.Lock()
e.cancelled[taskID] = struct{}{}
e.cancelMu.Unlock()
}
func (e *Executor) reportStatus(ctx context.Context, taskID int, status, errorMessage string) {
if e.client == nil {
return
}
statusCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 30*time.Second)
defer cancel()
_ = e.client.UpdateStatus(statusCtx, taskID, status, errorMessage)
}
func (e *Executor) execute(ctx context.Context, t *domain.Task) {
e.wg.Add(1)
defer e.wg.Done()
defer e.clearCancelled(t.ID)
if e.counter != nil {
e.counter.Inc()
defer e.counter.Dec()
}
if e.workerToken == "" {
e.reportStatus(ctx, t.ID, "failed", "missing worker token")
return
}
if e.docker == nil {
e.reportStatus(ctx, t.ID, "failed", "docker client unavailable")
return
}
runCtx, cancel := context.WithTimeout(ctx, e.maxRuntime)
defer cancel()
containerID, err := e.docker.StartWorker(runCtx, t, e.serverURL, e.workerToken, e.agentVersion)
if err != nil {
message := docker.TruncateErrorMessage(err.Error())
e.reportStatus(ctx, t.ID, "failed", message)
return
}
defer func() {
_ = e.docker.Remove(context.Background(), containerID)
}()
e.trackCancel(t.ID, cancel)
defer e.clearCancel(t.ID)
exitCode, waitErr := e.docker.Wait(runCtx, containerID)
if waitErr != nil {
if errors.Is(waitErr, context.DeadlineExceeded) || errors.Is(runCtx.Err(), context.DeadlineExceeded) {
e.handleTimeout(ctx, t, containerID)
return
}
if errors.Is(waitErr, context.Canceled) || errors.Is(runCtx.Err(), context.Canceled) {
e.handleCancel(ctx, t, containerID)
return
}
message := docker.TruncateErrorMessage(waitErr.Error())
e.reportStatus(ctx, t.ID, "failed", message)
return
}
if runCtx.Err() != nil {
if errors.Is(runCtx.Err(), context.DeadlineExceeded) {
e.handleTimeout(ctx, t, containerID)
return
}
if errors.Is(runCtx.Err(), context.Canceled) {
e.handleCancel(ctx, t, containerID)
return
}
}
if exitCode == 0 {
e.reportStatus(ctx, t.ID, "completed", "")
return
}
logs, _ := e.docker.TailLogs(context.Background(), containerID, 100)
message := logs
if message == "" {
message = fmt.Sprintf("container exited with code %d", exitCode)
}
message = docker.TruncateErrorMessage(message)
e.reportStatus(ctx, t.ID, "failed", message)
}
func (e *Executor) handleCancel(ctx context.Context, t *domain.Task, containerID string) {
_ = e.docker.Stop(context.Background(), containerID)
e.reportStatus(ctx, t.ID, "cancelled", "")
}
func (e *Executor) handleTimeout(ctx context.Context, t *domain.Task, containerID string) {
_ = e.docker.Stop(context.Background(), containerID)
message := docker.TruncateErrorMessage("task timed out")
e.reportStatus(ctx, t.ID, "failed", message)
}
func (e *Executor) trackCancel(taskID int, cancel context.CancelFunc) {
e.mu.Lock()
defer e.mu.Unlock()
e.running[taskID] = cancel
}
func (e *Executor) clearCancel(taskID int) {
e.mu.Lock()
defer e.mu.Unlock()
delete(e.running, taskID)
}
func (e *Executor) isCancelled(taskID int) bool {
e.cancelMu.Lock()
defer e.cancelMu.Unlock()
_, ok := e.cancelled[taskID]
return ok
}
func (e *Executor) clearCancelled(taskID int) {
e.cancelMu.Lock()
delete(e.cancelled, taskID)
e.cancelMu.Unlock()
}
// CancelAll requests cancellation for all running tasks.
func (e *Executor) CancelAll() {
e.mu.Lock()
cancels := make([]context.CancelFunc, 0, len(e.running))
for _, cancel := range e.running {
cancels = append(cancels, cancel)
}
e.mu.Unlock()
for _, cancel := range cancels {
cancel()
}
}
// Shutdown cancels running tasks and waits for completion.
func (e *Executor) Shutdown(ctx context.Context) error {
e.stopping.Store(true)
e.CancelAll()
done := make(chan struct{})
go func() {
e.wg.Wait()
close(done)
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-done:
return nil
}
}
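
The executor exposes two cancellation paths: MarkCancelled records a cancel for a task that is still queued (Start then reports it as cancelled and skips it), while CancelTask cancels the context of a task that is already running so its container gets stopped. A sketch of applying a cancel notification, where exec and taskID are assumed to exist:

// On receiving a task_cancel message for taskID (hypothetical wiring):
exec.MarkCancelled(taskID) // covers the task if it has not started yet
exec.CancelTask(taskID)    // covers the task if it is already running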

View File

@@ -1,107 +0,0 @@
package task
import (
"context"
"testing"
"time"
"github.com/yyhuni/lunafox/agent/internal/domain"
)
type fakeReporter struct {
status string
msg string
}
func (f *fakeReporter) UpdateStatus(ctx context.Context, taskID int, status, errorMessage string) error {
f.status = status
f.msg = errorMessage
return nil
}
func TestExecutorMissingWorkerToken(t *testing.T) {
reporter := &fakeReporter{}
exec := &Executor{
client: reporter,
serverURL: "https://server",
workerToken: "",
}
exec.execute(context.Background(), &domain.Task{ID: 1})
if reporter.status != "failed" {
t.Fatalf("expected failed status, got %s", reporter.status)
}
if reporter.msg == "" {
t.Fatalf("expected error message")
}
}
func TestExecutorDockerUnavailable(t *testing.T) {
reporter := &fakeReporter{}
exec := &Executor{
client: reporter,
serverURL: "https://server",
workerToken: "token",
}
exec.execute(context.Background(), &domain.Task{ID: 2})
if reporter.status != "failed" {
t.Fatalf("expected failed status, got %s", reporter.status)
}
if reporter.msg == "" {
t.Fatalf("expected error message")
}
}
func TestExecutorCancelAll(t *testing.T) {
exec := &Executor{
running: map[int]context.CancelFunc{},
}
calls := 0
exec.running[1] = func() { calls++ }
exec.running[2] = func() { calls++ }
exec.CancelAll()
if calls != 2 {
t.Fatalf("expected cancel calls, got %d", calls)
}
}
func TestExecutorShutdownWaits(t *testing.T) {
exec := &Executor{
running: map[int]context.CancelFunc{},
}
calls := 0
exec.running[1] = func() { calls++ }
exec.wg.Add(1)
go func() {
time.Sleep(10 * time.Millisecond)
exec.wg.Done()
}()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := exec.Shutdown(ctx); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if calls != 1 {
t.Fatalf("expected cancel call")
}
}
func TestExecutorShutdownTimeout(t *testing.T) {
exec := &Executor{
running: map[int]context.CancelFunc{},
}
exec.wg.Add(1)
defer exec.wg.Done()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
if err := exec.Shutdown(ctx); err == nil {
t.Fatalf("expected timeout error")
}
}

View File

@@ -1,252 +0,0 @@
package task
import (
"context"
"errors"
"math"
"math/rand"
"sync"
"sync/atomic"
"time"
"github.com/yyhuni/lunafox/agent/internal/domain"
)
// Puller coordinates task pulling with load gating and backoff.
type Puller struct {
client TaskPuller
collector MetricsSampler
counter *Counter
maxTasks int
cpuThreshold int
memThreshold int
diskThreshold int
onTask func(*domain.Task)
notifyCh chan struct{}
emptyBackoff []time.Duration
emptyIdx int
errorBackoff time.Duration
errorMax time.Duration
randSrc *rand.Rand
mu sync.RWMutex
paused atomic.Bool
}
type MetricsSampler interface {
Sample() (float64, float64, float64)
}
type TaskPuller interface {
PullTask(ctx context.Context) (*domain.Task, error)
}
// NewPuller creates a new Puller.
func NewPuller(client TaskPuller, collector MetricsSampler, counter *Counter, maxTasks, cpuThreshold, memThreshold, diskThreshold int) *Puller {
return &Puller{
client: client,
collector: collector,
counter: counter,
maxTasks: maxTasks,
cpuThreshold: cpuThreshold,
memThreshold: memThreshold,
diskThreshold: diskThreshold,
notifyCh: make(chan struct{}, 1),
emptyBackoff: []time.Duration{5 * time.Second, 10 * time.Second, 30 * time.Second, 60 * time.Second},
errorBackoff: 1 * time.Second,
errorMax: 60 * time.Second,
randSrc: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}
// SetOnTask registers a callback invoked when a task is assigned.
func (p *Puller) SetOnTask(fn func(*domain.Task)) {
p.onTask = fn
}
// NotifyTaskAvailable triggers an immediate pull attempt.
func (p *Puller) NotifyTaskAvailable() {
select {
case p.notifyCh <- struct{}{}:
default:
}
}
// Run starts the pull loop.
func (p *Puller) Run(ctx context.Context) error {
for {
if ctx.Err() != nil {
return ctx.Err()
}
if p.paused.Load() {
if !p.waitUntilCanceled(ctx) {
return ctx.Err()
}
continue
}
loadInterval := p.loadInterval()
if !p.canPull() {
if !p.wait(ctx, loadInterval) {
return ctx.Err()
}
continue
}
task, err := p.client.PullTask(ctx)
if err != nil {
delay := p.nextErrorBackoff()
if !p.wait(ctx, delay) {
return ctx.Err()
}
continue
}
p.resetErrorBackoff()
if task == nil {
delay := p.nextEmptyDelay(loadInterval)
if !p.waitOrNotify(ctx, delay) {
return ctx.Err()
}
continue
}
p.resetEmptyBackoff()
if p.onTask != nil {
p.onTask(task)
}
}
}
func (p *Puller) canPull() bool {
maxTasks, cpuThreshold, memThreshold, diskThreshold := p.currentConfig()
if p.counter != nil && p.counter.Count() >= maxTasks {
return false
}
cpu, mem, disk := p.collector.Sample()
return cpu < float64(cpuThreshold) &&
mem < float64(memThreshold) &&
disk < float64(diskThreshold)
}
func (p *Puller) loadInterval() time.Duration {
cpu, mem, disk := p.collector.Sample()
load := math.Max(cpu, math.Max(mem, disk))
switch {
case load < 50:
return 1 * time.Second
case load < 80:
return 3 * time.Second
default:
return 10 * time.Second
}
}
func (p *Puller) nextEmptyDelay(loadInterval time.Duration) time.Duration {
var empty time.Duration
if p.emptyIdx < len(p.emptyBackoff) {
empty = p.emptyBackoff[p.emptyIdx]
p.emptyIdx++
} else {
empty = p.emptyBackoff[len(p.emptyBackoff)-1]
}
if empty < loadInterval {
return loadInterval
}
return empty
}
func (p *Puller) resetEmptyBackoff() {
p.emptyIdx = 0
}
func (p *Puller) nextErrorBackoff() time.Duration {
delay := p.errorBackoff
next := delay * 2
if next > p.errorMax {
next = p.errorMax
}
p.errorBackoff = next
return withJitter(delay, p.randSrc)
}
func (p *Puller) resetErrorBackoff() {
p.errorBackoff = 1 * time.Second
}
func (p *Puller) wait(ctx context.Context, delay time.Duration) bool {
timer := time.NewTimer(delay)
defer timer.Stop()
select {
case <-ctx.Done():
return false
case <-timer.C:
return true
}
}
func (p *Puller) waitOrNotify(ctx context.Context, delay time.Duration) bool {
timer := time.NewTimer(delay)
defer timer.Stop()
select {
case <-ctx.Done():
return false
case <-p.notifyCh:
return true
case <-timer.C:
return true
}
}
func withJitter(delay time.Duration, src *rand.Rand) time.Duration {
if delay <= 0 || src == nil {
return delay
}
jitter := src.Float64() * 0.2
return delay + time.Duration(float64(delay)*jitter)
}
func (p *Puller) EnsureTaskHandler() error {
if p.onTask == nil {
return errors.New("task handler is required")
}
return nil
}
// Pause stops pulling. Once paused, only context cancellation exits the loop.
func (p *Puller) Pause() {
p.paused.Store(true)
}
// UpdateConfig updates puller thresholds and max tasks.
func (p *Puller) UpdateConfig(maxTasks, cpuThreshold, memThreshold, diskThreshold *int) {
p.mu.Lock()
defer p.mu.Unlock()
if maxTasks != nil && *maxTasks > 0 {
p.maxTasks = *maxTasks
}
if cpuThreshold != nil && *cpuThreshold > 0 {
p.cpuThreshold = *cpuThreshold
}
if memThreshold != nil && *memThreshold > 0 {
p.memThreshold = *memThreshold
}
if diskThreshold != nil && *diskThreshold > 0 {
p.diskThreshold = *diskThreshold
}
}
func (p *Puller) currentConfig() (int, int, int, int) {
p.mu.RLock()
defer p.mu.RUnlock()
return p.maxTasks, p.cpuThreshold, p.memThreshold, p.diskThreshold
}
func (p *Puller) waitUntilCanceled(ctx context.Context) bool {
<-ctx.Done()
return false
}
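
A rough, self-contained sketch of how the puller, executor, counter, HTTP client, and metrics collector from this diff could be wired together. The agent's real main package is not part of this diff, so the noopRunner stub, environment-variable configuration, thresholds, queue size, and version string are all assumptions.

package main

import (
    "context"
    "os"
    "time"

    "github.com/yyhuni/lunafox/agent/internal/domain"
    "github.com/yyhuni/lunafox/agent/internal/metrics"
    "github.com/yyhuni/lunafox/agent/internal/task"
)

// noopRunner is a placeholder DockerRunner so the sketch compiles; the real
// agent would pass its Docker client here instead.
type noopRunner struct{}

func (noopRunner) StartWorker(ctx context.Context, t *domain.Task, serverURL, serverToken, agentVersion string) (string, error) {
    return "container-id", nil
}
func (noopRunner) Wait(ctx context.Context, containerID string) (int64, error) { return 0, nil }
func (noopRunner) Stop(ctx context.Context, containerID string) error          { return nil }
func (noopRunner) Remove(ctx context.Context, containerID string) error        { return nil }
func (noopRunner) TailLogs(ctx context.Context, containerID string, lines int) (string, error) {
    return "", nil
}

func main() {
    serverURL := os.Getenv("SERVER_URL") // placeholder configuration
    apiKey := os.Getenv("API_KEY")

    client := task.NewClient(serverURL, apiKey)
    collector := metrics.NewCollector()
    counter := &task.Counter{}

    exec := task.NewExecutor(noopRunner{}, client, counter, serverURL, "worker-token", "v1.5.8")
    puller := task.NewPuller(client, collector, counter, 5, 85, 85, 85)

    queue := make(chan *domain.Task, 16)
    puller.SetOnTask(func(t *domain.Task) { queue <- t })

    ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
    defer cancel()

    go exec.Start(ctx, queue)
    _ = puller.Run(ctx) // blocks until the context expires or is cancelled
}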

View File

@@ -1,101 +0,0 @@
package task
import (
"math/rand"
"testing"
"time"
"github.com/yyhuni/lunafox/agent/internal/domain"
)
func TestPullerUpdateConfig(t *testing.T) {
p := NewPuller(nil, nil, nil, 5, 85, 86, 87)
max, cpu, mem, disk := p.currentConfig()
if max != 5 || cpu != 85 || mem != 86 || disk != 87 {
t.Fatalf("unexpected initial config")
}
maxUpdate := 8
cpuUpdate := 70
p.UpdateConfig(&maxUpdate, &cpuUpdate, nil, nil)
max, cpu, mem, disk = p.currentConfig()
if max != 8 || cpu != 70 || mem != 86 || disk != 87 {
t.Fatalf("unexpected updated config")
}
}
func TestPullerPause(t *testing.T) {
p := NewPuller(nil, nil, nil, 1, 1, 1, 1)
p.Pause()
if !p.paused.Load() {
t.Fatalf("expected paused")
}
}
func TestPullerEnsureTaskHandler(t *testing.T) {
p := NewPuller(nil, nil, nil, 1, 1, 1, 1)
if err := p.EnsureTaskHandler(); err == nil {
t.Fatalf("expected error when handler missing")
}
p.SetOnTask(func(*domain.Task) {})
if err := p.EnsureTaskHandler(); err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
func TestPullerNextEmptyDelay(t *testing.T) {
p := NewPuller(nil, nil, nil, 1, 1, 1, 1)
p.emptyBackoff = []time.Duration{5 * time.Second, 10 * time.Second}
if delay := p.nextEmptyDelay(8 * time.Second); delay != 8*time.Second {
t.Fatalf("expected delay to honor load interval, got %v", delay)
}
if delay := p.nextEmptyDelay(1 * time.Second); delay != 10*time.Second {
t.Fatalf("expected backoff delay, got %v", delay)
}
if p.emptyIdx != 2 {
t.Fatalf("expected empty index to advance")
}
p.resetEmptyBackoff()
if p.emptyIdx != 0 {
t.Fatalf("expected empty index reset")
}
}
func TestPullerErrorBackoff(t *testing.T) {
p := NewPuller(nil, nil, nil, 1, 1, 1, 1)
p.randSrc = rand.New(rand.NewSource(1))
first := p.nextErrorBackoff()
if first < time.Second || first > time.Second+(time.Second/5) {
t.Fatalf("unexpected backoff %v", first)
}
if p.errorBackoff != 2*time.Second {
t.Fatalf("expected backoff to double")
}
second := p.nextErrorBackoff()
if second < 2*time.Second || second > 2*time.Second+(2*time.Second/5) {
t.Fatalf("unexpected backoff %v", second)
}
if p.errorBackoff != 4*time.Second {
t.Fatalf("expected backoff to double")
}
p.resetErrorBackoff()
if p.errorBackoff != time.Second {
t.Fatalf("expected error backoff reset")
}
}
func TestWithJitterRange(t *testing.T) {
rng := rand.New(rand.NewSource(1))
delay := 10 * time.Second
got := withJitter(delay, rng)
if got < delay {
t.Fatalf("expected jitter >= delay")
}
if got > delay+(delay/5) {
t.Fatalf("expected jitter <= 20%%")
}
}

View File

@@ -1,279 +0,0 @@
package update
import (
"context"
"fmt"
"io"
"math/rand"
"os"
"strings"
"sync"
"time"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/network"
"github.com/docker/docker/api/types/strslice"
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/yyhuni/lunafox/agent/internal/config"
"github.com/yyhuni/lunafox/agent/internal/domain"
"github.com/yyhuni/lunafox/agent/internal/logger"
"go.uber.org/zap"
)
// Updater handles agent self-update.
type Updater struct {
docker dockerClient
health healthSetter
puller pullerController
executor executorController
cfg configSnapshot
apiKey string
token string
mu sync.Mutex
updating bool
randSrc *rand.Rand
backoff time.Duration
maxBackoff time.Duration
}
type dockerClient interface {
ImagePull(ctx context.Context, imageRef string) (io.ReadCloser, error)
ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, name string) (container.CreateResponse, error)
ContainerStart(ctx context.Context, containerID string, opts container.StartOptions) error
}
type healthSetter interface {
Set(state, reason, message string)
}
type pullerController interface {
Pause()
}
type executorController interface {
Shutdown(ctx context.Context) error
}
type configSnapshot interface {
Snapshot() config.Config
}
// NewUpdater creates a new updater.
func NewUpdater(dockerClient dockerClient, healthManager healthSetter, puller pullerController, executor executorController, cfg configSnapshot, apiKey, token string) *Updater {
return &Updater{
docker: dockerClient,
health: healthManager,
puller: puller,
executor: executor,
cfg: cfg,
apiKey: apiKey,
token: token,
randSrc: rand.New(rand.NewSource(time.Now().UnixNano())),
backoff: 30 * time.Second,
maxBackoff: 10 * time.Minute,
}
}
// HandleUpdateRequired triggers the update flow.
func (u *Updater) HandleUpdateRequired(payload domain.UpdateRequiredPayload) {
u.mu.Lock()
if u.updating {
u.mu.Unlock()
return
}
u.updating = true
u.mu.Unlock()
go u.run(payload)
}
func (u *Updater) run(payload domain.UpdateRequiredPayload) {
defer func() {
if r := recover(); r != nil {
logger.Log.Error("agent update panic", zap.Any("panic", r))
u.health.Set("paused", "update_panic", fmt.Sprintf("%v", r))
}
u.mu.Lock()
u.updating = false
u.mu.Unlock()
}()
u.puller.Pause()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
_ = u.executor.Shutdown(ctx)
cancel()
for {
if err := u.updateOnce(payload); err == nil {
u.health.Set("ok", "", "")
os.Exit(0)
} else {
u.health.Set("paused", "update_failed", err.Error())
}
delay := withJitter(u.backoff, u.randSrc)
if u.backoff < u.maxBackoff {
u.backoff *= 2
if u.backoff > u.maxBackoff {
u.backoff = u.maxBackoff
}
}
time.Sleep(delay)
}
}
func (u *Updater) updateOnce(payload domain.UpdateRequiredPayload) error {
if u.docker == nil {
return fmt.Errorf("docker client unavailable")
}
image := strings.TrimSpace(payload.Image)
version := strings.TrimSpace(payload.Version)
if image == "" || version == "" {
return fmt.Errorf("invalid update payload")
}
// Strict validation: reject invalid data from server
if err := validateImageName(image); err != nil {
logger.Log.Warn("invalid image name from server", zap.String("image", image), zap.Error(err))
return fmt.Errorf("invalid image name from server: %w", err)
}
if err := validateVersion(version); err != nil {
logger.Log.Warn("invalid version from server", zap.String("version", version), zap.Error(err))
return fmt.Errorf("invalid version from server: %w", err)
}
fullImage := fmt.Sprintf("%s:%s", image, version)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
reader, err := u.docker.ImagePull(ctx, fullImage)
if err != nil {
return err
}
_, _ = io.Copy(io.Discard, reader)
_ = reader.Close()
if err := u.startNewContainer(ctx, image, version); err != nil {
return err
}
return nil
}
func (u *Updater) startNewContainer(ctx context.Context, image, version string) error {
snapshot := u.cfg.Snapshot()
env := []string{
fmt.Sprintf("SERVER_URL=%s", snapshot.ServerURL),
fmt.Sprintf("API_KEY=%s", u.apiKey),
fmt.Sprintf("MAX_TASKS=%d", snapshot.MaxTasks),
fmt.Sprintf("CPU_THRESHOLD=%d", snapshot.CPUThreshold),
fmt.Sprintf("MEM_THRESHOLD=%d", snapshot.MemThreshold),
fmt.Sprintf("DISK_THRESHOLD=%d", snapshot.DiskThreshold),
fmt.Sprintf("AGENT_VERSION=%s", version),
}
if u.token != "" {
env = append(env, fmt.Sprintf("WORKER_TOKEN=%s", u.token))
}
cfg := &container.Config{
Image: fmt.Sprintf("%s:%s", image, version),
Env: env,
Cmd: strslice.StrSlice{},
}
hostConfig := &container.HostConfig{
Binds: []string{
"/var/run/docker.sock:/var/run/docker.sock",
"/opt/lunafox:/opt/lunafox",
},
RestartPolicy: container.RestartPolicy{Name: "unless-stopped"},
OomScoreAdj: -500,
}
// Version is already validated, just normalize to lowercase for container name
name := fmt.Sprintf("lunafox-agent-%s", strings.ToLower(version))
resp, err := u.docker.ContainerCreate(ctx, cfg, hostConfig, &network.NetworkingConfig{}, nil, name)
if err != nil {
return err
}
if err := u.docker.ContainerStart(ctx, resp.ID, container.StartOptions{}); err != nil {
return err
}
logger.Log.Info("agent update started new container", zap.String("containerId", resp.ID))
return nil
}
func withJitter(delay time.Duration, src *rand.Rand) time.Duration {
if delay <= 0 || src == nil {
return delay
}
jitter := src.Float64() * 0.2
return delay + time.Duration(float64(delay)*jitter)
}
// validateImageName validates that the image name contains only safe characters.
// Returns error if validation fails.
func validateImageName(image string) error {
if len(image) == 0 {
return fmt.Errorf("image name cannot be empty")
}
if len(image) > 255 {
return fmt.Errorf("image name too long: %d characters", len(image))
}
// Allow: alphanumeric, dots, hyphens, underscores, slashes (for registry paths)
for i, r := range image {
if !((r >= 'a' && r <= 'z') ||
(r >= 'A' && r <= 'Z') ||
(r >= '0' && r <= '9') ||
r == '.' || r == '-' || r == '_' || r == '/') {
return fmt.Errorf("invalid character at position %d: %c", i, r)
}
}
// Must not start or end with special characters
first := rune(image[0])
last := rune(image[len(image)-1])
if first == '.' || first == '-' || first == '/' {
return fmt.Errorf("image name cannot start with special character: %c", first)
}
if last == '.' || last == '-' || last == '/' {
return fmt.Errorf("image name cannot end with special character: %c", last)
}
return nil
}
// validateVersion validates that the version string contains only safe characters.
// Returns error if validation fails.
func validateVersion(version string) error {
if len(version) == 0 {
return fmt.Errorf("version cannot be empty")
}
if len(version) > 128 {
return fmt.Errorf("version too long: %d characters", len(version))
}
// Allow: alphanumeric, dots, hyphens, underscores
for i, r := range version {
if !((r >= 'a' && r <= 'z') ||
(r >= 'A' && r <= 'Z') ||
(r >= '0' && r <= '9') ||
r == '.' || r == '-' || r == '_') {
return fmt.Errorf("invalid character at position %d: %c", i, r)
}
}
// Must not start or end with special characters
first := rune(version[0])
last := rune(version[len(version)-1])
if first == '.' || first == '-' || first == '_' {
return fmt.Errorf("version cannot start with special character: %c", first)
}
if last == '.' || last == '-' || last == '_' {
return fmt.Errorf("version cannot end with special character: %c", last)
}
return nil
}

View File

@@ -1,45 +0,0 @@
package update
import (
"math/rand"
"strings"
"testing"
"time"
"github.com/yyhuni/lunafox/agent/internal/domain"
)
func TestValidateRejectsInvalidInput(t *testing.T) {
if err := validateVersion("v1.0.0+TEST"); err == nil {
t.Fatalf("expected '+' in version to be rejected")
}
if err := validateVersion("v1.0.0"); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if err := validateImageName("yyhuni/lunafox-agent:latest"); err == nil {
t.Fatalf("expected ':' in image name to be rejected")
}
}
func TestWithJitterRange(t *testing.T) {
rng := rand.New(rand.NewSource(1))
delay := 10 * time.Second
got := withJitter(delay, rng)
if got < delay {
t.Fatalf("expected jitter >= delay")
}
if got > delay+(delay/5) {
t.Fatalf("expected jitter <= 20%%")
}
}
func TestUpdateOnceDockerUnavailable(t *testing.T) {
updater := &Updater{}
payload := domain.UpdateRequiredPayload{Version: "v1.0.0", Image: "yyhuni/lunafox-agent"}
err := updater.updateOnce(payload)
if err == nil {
t.Fatalf("expected error when docker client is nil")
}
if !strings.Contains(err.Error(), "docker client unavailable") {
t.Fatalf("unexpected error: %v", err)
}
}

View File

@@ -1,37 +0,0 @@
package websocket
import "time"
// Backoff implements exponential backoff with a maximum cap.
type Backoff struct {
base time.Duration
max time.Duration
current time.Duration
}
// NewBackoff creates a backoff with the given base and max delay.
func NewBackoff(base, max time.Duration) Backoff {
return Backoff{
base: base,
max: max,
}
}
// Next returns the next backoff duration.
func (b *Backoff) Next() time.Duration {
if b.current <= 0 {
b.current = b.base
return b.current
}
next := b.current * 2
if next > b.max {
next = b.max
}
b.current = next
return b.current
}
// Reset clears the backoff to start over.
func (b *Backoff) Reset() {
b.current = 0
}

View File

@@ -1,32 +0,0 @@
package websocket
import (
"testing"
"time"
)
func TestBackoffSequence(t *testing.T) {
b := NewBackoff(time.Second, 60*time.Second)
expected := []time.Duration{
time.Second,
2 * time.Second,
4 * time.Second,
8 * time.Second,
16 * time.Second,
32 * time.Second,
60 * time.Second,
60 * time.Second,
}
for i, exp := range expected {
if got := b.Next(); got != exp {
t.Fatalf("step %d: expected %v, got %v", i, exp, got)
}
}
b.Reset()
if got := b.Next(); got != time.Second {
t.Fatalf("after reset expected %v, got %v", time.Second, got)
}
}

View File

@@ -1,177 +0,0 @@
package websocket
import (
"context"
"crypto/tls"
"net/http"
"time"
"github.com/gorilla/websocket"
"github.com/yyhuni/lunafox/agent/internal/logger"
"go.uber.org/zap"
)
const (
defaultPingInterval = 30 * time.Second
defaultPongWait = 60 * time.Second
defaultWriteWait = 10 * time.Second
)
// Client maintains a WebSocket connection to the server.
type Client struct {
wsURL string
apiKey string
dialer *websocket.Dialer
send chan []byte
onMessage func([]byte)
backoff Backoff
pingInterval time.Duration
pongWait time.Duration
writeWait time.Duration
}
// NewClient creates a WebSocket client for the agent.
func NewClient(wsURL, apiKey string) *Client {
dialer := *websocket.DefaultDialer
dialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
return &Client{
wsURL: wsURL,
apiKey: apiKey,
dialer: &dialer,
send: make(chan []byte, 256),
backoff: NewBackoff(1*time.Second, 60*time.Second),
pingInterval: defaultPingInterval,
pongWait: defaultPongWait,
writeWait: defaultWriteWait,
}
}
// SetOnMessage registers a callback for incoming messages.
func (c *Client) SetOnMessage(fn func([]byte)) {
c.onMessage = fn
}
// Send queues a message for sending. It returns false if the buffer is full.
func (c *Client) Send(payload []byte) bool {
select {
case c.send <- payload:
return true
default:
return false
}
}
// Run keeps the connection alive with reconnect backoff and keepalive pings.
func (c *Client) Run(ctx context.Context) error {
for {
if ctx.Err() != nil {
return ctx.Err()
}
logger.Log.Info("websocket connect attempt", zap.String("url", c.wsURL))
conn, err := c.connect(ctx)
if err != nil {
logger.Log.Warn("websocket connect failed", zap.Error(err))
if !sleepWithContext(ctx, c.backoff.Next()) {
return ctx.Err()
}
continue
}
c.backoff.Reset()
logger.Log.Info("websocket connected")
err = c.runConn(ctx, conn)
if err != nil && ctx.Err() == nil {
logger.Log.Warn("websocket connection closed", zap.Error(err))
}
if ctx.Err() != nil {
return ctx.Err()
}
if !sleepWithContext(ctx, c.backoff.Next()) {
return ctx.Err()
}
}
}
func (c *Client) connect(ctx context.Context) (*websocket.Conn, error) {
header := http.Header{}
if c.apiKey != "" {
header.Set("X-Agent-Key", c.apiKey)
}
conn, _, err := c.dialer.DialContext(ctx, c.wsURL, header)
return conn, err
}
func (c *Client) runConn(ctx context.Context, conn *websocket.Conn) error {
defer conn.Close()
conn.SetReadDeadline(time.Now().Add(c.pongWait))
conn.SetPongHandler(func(string) error {
conn.SetReadDeadline(time.Now().Add(c.pongWait))
return nil
})
errCh := make(chan error, 2)
go c.readLoop(conn, errCh)
go c.writeLoop(ctx, conn, errCh)
select {
case <-ctx.Done():
return ctx.Err()
case err := <-errCh:
return err
}
}
func (c *Client) readLoop(conn *websocket.Conn, errCh chan<- error) {
for {
_, message, err := conn.ReadMessage()
if err != nil {
errCh <- err
return
}
if c.onMessage != nil {
c.onMessage(message)
}
}
}
func (c *Client) writeLoop(ctx context.Context, conn *websocket.Conn, errCh chan<- error) {
ticker := time.NewTicker(c.pingInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
errCh <- ctx.Err()
return
case payload := <-c.send:
if err := c.writeMessage(conn, websocket.TextMessage, payload); err != nil {
errCh <- err
return
}
case <-ticker.C:
if err := c.writeMessage(conn, websocket.PingMessage, nil); err != nil {
errCh <- err
return
}
}
}
}
func (c *Client) writeMessage(conn *websocket.Conn, msgType int, payload []byte) error {
_ = conn.SetWriteDeadline(time.Now().Add(c.writeWait))
return conn.WriteMessage(msgType, payload)
}
func sleepWithContext(ctx context.Context, delay time.Duration) bool {
timer := time.NewTimer(delay)
defer timer.Stop()
select {
case <-ctx.Done():
return false
case <-timer.C:
return true
}
}

View File

@@ -1,32 +0,0 @@
package websocket
import (
"context"
"testing"
"time"
)
func TestClientSendBufferFull(t *testing.T) {
client := &Client{send: make(chan []byte, 1)}
if !client.Send([]byte("first")) {
t.Fatalf("expected first send to succeed")
}
if client.Send([]byte("second")) {
t.Fatalf("expected second send to fail when buffer is full")
}
}
func TestSleepWithContextCancelled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
if sleepWithContext(ctx, 50*time.Millisecond) {
t.Fatalf("expected sleepWithContext to return false when canceled")
}
}
func TestSleepWithContextElapsed(t *testing.T) {
if !sleepWithContext(context.Background(), 5*time.Millisecond) {
t.Fatalf("expected sleepWithContext to return true after delay")
}
}

View File

@@ -1,90 +0,0 @@
package websocket
import (
"encoding/json"
"github.com/yyhuni/lunafox/agent/internal/protocol"
)
// Handler routes incoming WebSocket messages.
type Handler struct {
onTaskAvailable func()
onTaskCancel func(int)
onConfigUpdate func(protocol.ConfigUpdatePayload)
onUpdateReq func(protocol.UpdateRequiredPayload)
}
// NewHandler creates a message handler.
func NewHandler() *Handler {
return &Handler{}
}
// OnTaskAvailable registers a callback for task_available messages.
func (h *Handler) OnTaskAvailable(fn func()) {
h.onTaskAvailable = fn
}
// OnTaskCancel registers a callback for task_cancel messages.
func (h *Handler) OnTaskCancel(fn func(int)) {
h.onTaskCancel = fn
}
// OnConfigUpdate registers a callback for config_update messages.
func (h *Handler) OnConfigUpdate(fn func(protocol.ConfigUpdatePayload)) {
h.onConfigUpdate = fn
}
// OnUpdateRequired registers a callback for update_required messages.
func (h *Handler) OnUpdateRequired(fn func(protocol.UpdateRequiredPayload)) {
h.onUpdateReq = fn
}
// Handle processes a raw message.
func (h *Handler) Handle(raw []byte) {
var msg struct {
Type string `json:"type"`
Data json.RawMessage `json:"payload"`
}
if err := json.Unmarshal(raw, &msg); err != nil {
return
}
switch msg.Type {
case protocol.MessageTypeTaskAvailable:
if h.onTaskAvailable != nil {
h.onTaskAvailable()
}
case protocol.MessageTypeTaskCancel:
if h.onTaskCancel == nil {
return
}
var payload protocol.TaskCancelPayload
if err := json.Unmarshal(msg.Data, &payload); err != nil {
return
}
if payload.TaskID > 0 {
h.onTaskCancel(payload.TaskID)
}
case protocol.MessageTypeConfigUpdate:
if h.onConfigUpdate == nil {
return
}
var payload protocol.ConfigUpdatePayload
if err := json.Unmarshal(msg.Data, &payload); err != nil {
return
}
h.onConfigUpdate(payload)
case protocol.MessageTypeUpdateRequired:
if h.onUpdateReq == nil {
return
}
var payload protocol.UpdateRequiredPayload
if err := json.Unmarshal(msg.Data, &payload); err != nil {
return
}
if payload.Version == "" || payload.Image == "" {
return
}
h.onUpdateReq(payload)
}
}
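
A sketch of how the handler callbacks could be connected to the components shown earlier and attached to the WebSocket client. The wrapper function, the WebSocket URL format, and forwarding only MaxTasks from the config payload are assumptions (the remaining ConfigUpdatePayload field names are not visible in this diff).

package agentwire // hypothetical package name

import (
    "context"

    "github.com/yyhuni/lunafox/agent/internal/protocol"
    "github.com/yyhuni/lunafox/agent/internal/task"
    "github.com/yyhuni/lunafox/agent/internal/update"
    "github.com/yyhuni/lunafox/agent/internal/websocket"
)

// wireWebSocket connects server push messages to the puller, executor and updater,
// then runs the reconnecting WebSocket loop until the context is cancelled.
func wireWebSocket(ctx context.Context, puller *task.Puller, executor *task.Executor, updater *update.Updater, wsURL, apiKey string) error {
    client := websocket.NewClient(wsURL, apiKey)

    handler := websocket.NewHandler()
    handler.OnTaskAvailable(puller.NotifyTaskAvailable)
    handler.OnTaskCancel(func(taskID int) {
        executor.MarkCancelled(taskID)
        executor.CancelTask(taskID)
    })
    handler.OnConfigUpdate(func(p protocol.ConfigUpdatePayload) {
        // Only MaxTasks is confirmed by this diff; threshold fields would be
        // forwarded to UpdateConfig in the same way.
        puller.UpdateConfig(p.MaxTasks, nil, nil, nil)
    })
    handler.OnUpdateRequired(updater.HandleUpdateRequired)

    client.SetOnMessage(handler.Handle)
    return client.Run(ctx)
}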

View File

@@ -1,85 +0,0 @@
package websocket
import (
"fmt"
"testing"
"github.com/yyhuni/lunafox/agent/internal/protocol"
)
func TestHandlersTaskAvailable(t *testing.T) {
h := NewHandler()
called := 0
h.OnTaskAvailable(func() { called++ })
message := fmt.Sprintf(`{"type":"%s","payload":{},"timestamp":"2026-01-01T00:00:00Z"}`, protocol.MessageTypeTaskAvailable)
h.Handle([]byte(message))
if called != 1 {
t.Fatalf("expected callback to be called")
}
}
func TestHandlersTaskCancel(t *testing.T) {
h := NewHandler()
var got int
h.OnTaskCancel(func(id int) { got = id })
message := fmt.Sprintf(`{"type":"%s","payload":{"taskId":123},"timestamp":"2026-01-01T00:00:00Z"}`, protocol.MessageTypeTaskCancel)
h.Handle([]byte(message))
if got != 123 {
t.Fatalf("expected taskId 123")
}
}
func TestHandlersConfigUpdate(t *testing.T) {
h := NewHandler()
var maxTasks int
h.OnConfigUpdate(func(payload protocol.ConfigUpdatePayload) {
if payload.MaxTasks != nil {
maxTasks = *payload.MaxTasks
}
})
message := fmt.Sprintf(`{"type":"%s","payload":{"maxTasks":8},"timestamp":"2026-01-01T00:00:00Z"}`, protocol.MessageTypeConfigUpdate)
h.Handle([]byte(message))
if maxTasks != 8 {
t.Fatalf("expected maxTasks 8")
}
}
func TestHandlersUpdateRequired(t *testing.T) {
h := NewHandler()
var version string
h.OnUpdateRequired(func(payload protocol.UpdateRequiredPayload) { version = payload.Version })
message := fmt.Sprintf(`{"type":"%s","payload":{"version":"v1.0.1","image":"yyhuni/lunafox-agent"},"timestamp":"2026-01-01T00:00:00Z"}`, protocol.MessageTypeUpdateRequired)
h.Handle([]byte(message))
if version != "v1.0.1" {
t.Fatalf("expected version")
}
}
func TestHandlersIgnoreInvalidJSON(t *testing.T) {
h := NewHandler()
called := 0
h.OnTaskAvailable(func() { called++ })
h.Handle([]byte("{bad json"))
if called != 0 {
t.Fatalf("expected no callbacks on invalid json")
}
}
func TestHandlersUpdateRequiredMissingFields(t *testing.T) {
h := NewHandler()
called := 0
h.OnUpdateRequired(func(payload protocol.UpdateRequiredPayload) { called++ })
message := fmt.Sprintf(`{"type":"%s","payload":{"version":"","image":"yyhuni/lunafox-agent"}}`, protocol.MessageTypeUpdateRequired)
h.Handle([]byte(message))
message = fmt.Sprintf(`{"type":"%s","payload":{"version":"v1.2.3","image":""}}`, protocol.MessageTypeUpdateRequired)
h.Handle([]byte(message))
if called != 0 {
t.Fatalf("expected no callbacks for invalid payload")
}
}

View File

@@ -1,97 +0,0 @@
package websocket
import (
"context"
"encoding/json"
"time"
"github.com/yyhuni/lunafox/agent/internal/health"
"github.com/yyhuni/lunafox/agent/internal/logger"
"github.com/yyhuni/lunafox/agent/internal/metrics"
"github.com/yyhuni/lunafox/agent/internal/protocol"
"go.uber.org/zap"
)
// HeartbeatSender sends periodic heartbeat messages over WebSocket.
type HeartbeatSender struct {
client *Client
collector *metrics.Collector
health *health.Manager
version string
hostname string
startedAt time.Time
taskCount func() int
interval time.Duration
lastSentAt time.Time
}
// NewHeartbeatSender creates a heartbeat sender.
func NewHeartbeatSender(client *Client, collector *metrics.Collector, healthManager *health.Manager, version, hostname string, taskCount func() int) *HeartbeatSender {
return &HeartbeatSender{
client: client,
collector: collector,
health: healthManager,
version: version,
hostname: hostname,
startedAt: time.Now(),
taskCount: taskCount,
interval: 5 * time.Second,
}
}
// Start begins sending heartbeats until context is canceled.
func (h *HeartbeatSender) Start(ctx context.Context) {
ticker := time.NewTicker(h.interval)
defer ticker.Stop()
h.sendOnce()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
h.sendOnce()
}
}
}
func (h *HeartbeatSender) sendOnce() {
cpu, mem, disk := h.collector.Sample()
uptime := int64(time.Since(h.startedAt).Seconds())
tasks := 0
if h.taskCount != nil {
tasks = h.taskCount()
}
status := h.health.Get()
payload := protocol.HeartbeatPayload{
CPU: cpu,
Mem: mem,
Disk: disk,
Tasks: tasks,
Version: h.version,
Hostname: h.hostname,
Uptime: uptime,
Health: protocol.HealthStatus{
State: status.State,
Reason: status.Reason,
Message: status.Message,
Since: status.Since,
},
}
msg := protocol.Message{
Type: protocol.MessageTypeHeartbeat,
Payload: payload,
Timestamp: time.Now().UTC(),
}
data, err := json.Marshal(msg)
if err != nil {
logger.Log.Warn("failed to marshal heartbeat message", zap.Error(err))
return
}
if !h.client.Send(data) {
logger.Log.Warn("failed to send heartbeat: client not connected")
}
}

View File

@@ -1,57 +0,0 @@
package websocket
import (
"encoding/json"
"testing"
"time"
"github.com/yyhuni/lunafox/agent/internal/health"
"github.com/yyhuni/lunafox/agent/internal/metrics"
"github.com/yyhuni/lunafox/agent/internal/protocol"
)
func TestHeartbeatSenderSendOnce(t *testing.T) {
client := &Client{send: make(chan []byte, 1)}
collector := metrics.NewCollector()
healthManager := health.NewManager()
healthManager.Set("paused", "maintenance", "waiting")
sender := NewHeartbeatSender(client, collector, healthManager, "v1.0.0", "agent-host", func() int { return 3 })
sender.sendOnce()
select {
case payload := <-client.send:
var msg struct {
Type string `json:"type"`
Payload map[string]interface{} `json:"payload"`
Timestamp time.Time `json:"timestamp"`
}
if err := json.Unmarshal(payload, &msg); err != nil {
t.Fatalf("unmarshal heartbeat: %v", err)
}
if msg.Type != protocol.MessageTypeHeartbeat {
t.Fatalf("expected heartbeat type, got %s", msg.Type)
}
if msg.Timestamp.IsZero() {
t.Fatalf("expected timestamp")
}
if msg.Payload["version"] != "v1.0.0" {
t.Fatalf("expected version in payload")
}
if msg.Payload["hostname"] != "agent-host" {
t.Fatalf("expected hostname in payload")
}
if tasks, ok := msg.Payload["tasks"].(float64); !ok || int(tasks) != 3 {
t.Fatalf("expected tasks=3")
}
healthPayload, ok := msg.Payload["health"].(map[string]interface{})
if !ok {
t.Fatalf("expected health payload")
}
if healthPayload["state"] != "paused" {
t.Fatalf("expected health state paused")
}
default:
t.Fatalf("expected heartbeat message")
}
}

View File

@@ -1,13 +0,0 @@
package integration
import (
"os"
"testing"
)
func TestTaskExecutionFlow(t *testing.T) {
if os.Getenv("AGENT_INTEGRATION") == "" {
t.Skip("set AGENT_INTEGRATION=1 to run integration tests")
}
// TODO: wire up real server + docker environment for end-to-end validation.
}

View File

@@ -195,32 +195,3 @@ class DjangoHostPortMappingSnapshotRepository:
        for row in qs.iterator(chunk_size=batch_size):
            yield row

    def iter_unique_host_ports_by_scan(
        self,
        scan_id: int,
        batch_size: int = 1000
    ) -> Iterator[dict]:
        """
        Stream the unique host:port combinations for a scan (deduplicated).

        Used to avoid duplicates when generating URLs: the same host:port may
        map to multiple IPs, but only one URL is needed per combination.

        Args:
            scan_id: scan ID
            batch_size: number of rows per batch

        Yields:
            {'host': 'example.com', 'port': 80}
        """
        qs = (
            HostPortMappingSnapshot.objects
            .filter(scan_id=scan_id)
            .values('host', 'port')
            .distinct()
            .order_by('host', 'port')
        )
        for row in qs.iterator(chunk_size=batch_size):
            yield row

Some files were not shown because too many files have changed in this diff.