Mirror of https://github.com/yyhuni/xingrin.git, synced 2026-02-02 04:33:10 +08:00.
Compare commits: v1.2.0-dev...002-server (265 commits)
Commit SHA1s: 7fd832ce22, e76ecaac15, 08e6c7fbe3, 5adb239547, 896ae7743d, d5c363294b, 4734f7a576, 46b1d5a1d1, 66fa60c415, 3d54d26c7e, b4a289b198, b727b2d001, b859fc9062, 49b5fbef28, 11112a68f6, 9049b096ba, ca6c0eb082, 64bcd9a6f5, 443e2172e4, c6dcfb0a5b, 25ae325c69, cab83d89cf, 0f8fff2dc4, 6e48b97dc2, ed757d6e14, 2aa1afbabf, 35ac64db57, b4bfab92e3, 72210c42d0, 91aaf7997f, 32e3179d58, 487f7c84b5, b2cc83f569, f854cf09be, 7e1c2c187a, 4abb259ca0, bbef6af000, ba0864ed16, f54827829a, 170021130c, b540f69152, d7f1e04855, 68ad18e6da, a7542d4a34, 6f02d9f3c5, 794846ca7a, 5eea7b2621, 069527a7f1, e542633ad3, e8a9606d3b, dc2e1e027d, b1847faa3a, e699842492, 08a4807bef, 191ff9837b, 679dff9037, ce4330b628, 4ce6b148f8, a89f775ee9, e3003f33f9, 3760684b64, bfd7e11d09, f758feb0d0, 8798eed337, bd1e25cfd5, d775055572, 00dfad60b8, a5c48fe4d4, 85c880731c, c6b6507412, af457dc44c, 9e01a6aa5e, ed80772e6f, a22af21dcb, 8de950a7a5, 9db84221e9, 0728f3c01d, 4aa7b3d68a, 3946a53337, c94fe1ec4b, 6dea525527, 5b0416972a, 5345a34cbd, 3ca56abc3e, 9703add22d, f5a489e2d6, d75a3f6882, 59e48e5b15, 2d2ec93626, ced9f811f4, aa99b26f50, 8342f196db, 1bd2a6ed88, 033ff89aee, 4284a0cd9a, 943a4cb960, eb2d853b76, 1184c18b74, 8a6f1b6f24, 255d505aba, d06a9bab1f, 6d5c776bf7, bf058dd67b, 0532d7c8b8, 2ee9b5ffa2, 648a1888d4, 2508268a45, c60383940c, 47298c294a, eba394e14e, 592a1958c4, 38e2856c08, f5ad8e68e9, d5f91a236c, 24ae8b5aeb, 86f43f94a0, 53ba03d1e5, 89c44ebd05, e0e3419edb, 52ee4684a7, ce8cebf11d, ec006d8f54, 48976a570f, 5da7229873, 8bb737a9fa, 2d018d33f3, 0c07cc8497, 225b039985, d1624627bc, 7bb15e4ae4, 8e8cc29669, d6d5338acb, c521bdb511, abf2d95f6f, ab58cf0d85, fb0111adf2, 161ee9a2b1, 0cf75585d5, 1d8d5f51d9, 3f8de07c8c, cd5c2b9f11, 54786c22dd, d468f975ab, a85a12b8ad, a8b0d97b7b, b8504921c2, ecfc1822fb, 81633642e6, d1ec9b7f27, 2a3d9b4446, 9b63203b5a, 6ff86e14ec, 4c1282e9bb, ba3a9b709d, 283b28b46a, 1269e5a314, 802e967906, e446326416, e0abb3ce7b, d418baaf79, f8da408580, 7cd4354d8f, 6bf35a760f, be9ecadffb, adb53c9f85, 7b7bbed634, 8dd3f0536e, 8a8062a12d, 55908a2da5, 22a7d4f091, f287f18134, de27230b7a, 15a6295189, 674acdac66, c59152bedf, b4037202dc, 4b4f9862bf, 1c42e4978f, 57bab63997, b1f0f18ac0, ccee5471b8, 0ccd362535, 7f2af7f7e2, 4bd0f9e8c1, 68cc996e3b, f1e79d638e, d484133e4c, fc977ae029, f328474404, 68e726a066, 77a6f45909, 49d1f1f1bb, db8ecb1644, 18cc016268, 23bc463283, 7b903b91b2, b3136d51b9, 08372588a4, 236c828041, fb13bb74d8, f076c682b6, 9eda2caceb, b1c9e202dd, 918669bc29, fd70b0544d, 0f2df7a5f3, 857ab737b5, ee2d99edda, db6ce16aca, ab800eca06, e8e5572339, d48d4bbcad, d1cca4c083, df0810c863, d33e54c440, 35a306fe8b, 724df82931, 8dfffdf802, b8cb85ce0b, da96d437a4, feaf8062e5, 4bab76f233, 09416b4615, bc1c5f6b0e, 2f2742e6fe, be3c346a74, 0c7a6fff12, 3b4f0e3147, 51212a2a0c, 58533bbaf6, 6ccca1602d, 6389b0f672, d7599b8599, 8eff298293, 3634101c5b, 163973a7df, 80ffecba3e, 3c21ac940c, 5c9f484d70, 7567f6c25b, 0599a0b298, f7557fe90c, 13571b9772, 8ee76eef69, 2a31e29aa2, 81abc59961, ffbfec6dd5, a0091636a8, 69490ab396, 7306964abf, cb6b0259e3, e1b4618e58, 556dcf5f62, 0628eef025, 38ed8bc642, 2f4d6a2168, c25cb9e06b, b14ab71c7f, 8b5060e2d3, 3c9335febf, 1b95e4f2c3, d20a600afc, c29b11fd37, 6caf707072, 2627b1fc40
.github/workflows/check-generated-files.yml (vendored, new file, 45 lines)
@@ -0,0 +1,45 @@
name: Check Generated Files

on:
  workflow_call: # only runs when called by another workflow

permissions:
  contents: read

jobs:
  check:
    runs-on: ubuntu-22.04 # pin the version so runner updates cannot change CI behavior
    steps:
      - uses: actions/checkout@v4

      - name: Set up Go
        uses: actions/setup-go@v5
        with:
          go-version: '1.24' # keep in sync with go.mod

      - name: Generate files for all workflows
        working-directory: worker
        run: make generate

      - name: Check for differences
        run: |
          if ! git diff --exit-code; then
            echo "❌ Generated files are out of date!"
            echo "Please run: cd worker && make generate"
            echo ""
            echo "Changed files:"
            git status --porcelain
            echo ""
            echo "Diff:"
            git diff
            exit 1
          fi
          echo "✅ Generated files are up to date"

      - name: Run metadata consistency tests
        working-directory: worker
        run: make test-metadata

      - name: Run all tests
        working-directory: worker
        run: make test
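The same check can be reproduced locally before pushing; a minimal sketch, assuming your checkout provides the `worker/Makefile` targets referenced above:

```bash
# Regenerate files and fail if the working tree differs from what is committed
cd worker
make generate
git diff --exit-code || { echo "generated files are out of date"; exit 1; }
```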
.github/workflows/ci.yml (vendored, new file, 13 lines)
@@ -0,0 +1,13 @@
name: CI

on:
  push:
    branches: [main, develop]
  pull_request:

permissions:
  contents: read

jobs:
  check-generated:
    uses: ./.github/workflows/check-generated-files.yml
.github/workflows/docker-build.yml (vendored, deleted, 139 lines)
@@ -1,139 +0,0 @@
name: Build and Push Docker Images

on:
  push:
    tags:
      - 'v*' # only trigger on tags starting with v (e.g. v1.0.0)
  workflow_dispatch: # manual trigger

# Concurrency control: keep only the latest build per ref and cancel in-progress runs
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

env:
  REGISTRY: docker.io
  IMAGE_PREFIX: yyhuni

permissions:
  contents: write

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        include:
          - image: xingrin-server
            dockerfile: docker/server/Dockerfile
            context: .
            platforms: linux/amd64,linux/arm64
          - image: xingrin-frontend
            dockerfile: docker/frontend/Dockerfile
            context: .
            platforms: linux/amd64 # Next.js crashes under QEMU when building for ARM64
          - image: xingrin-worker
            dockerfile: docker/worker/Dockerfile
            context: .
            platforms: linux/amd64,linux/arm64
          - image: xingrin-nginx
            dockerfile: docker/nginx/Dockerfile
            context: .
            platforms: linux/amd64,linux/arm64
          - image: xingrin-agent
            dockerfile: docker/agent/Dockerfile
            context: .
            platforms: linux/amd64,linux/arm64

    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Free disk space (for large builds like worker)
        run: |
          echo "=== Before cleanup ==="
          df -h
          sudo rm -rf /usr/share/dotnet
          sudo rm -rf /usr/local/lib/android
          sudo rm -rf /opt/ghc
          sudo rm -rf /opt/hostedtoolcache/CodeQL
          sudo docker image prune -af
          echo "=== After cleanup ==="
          df -h

      - name: Generate SSL certificates for nginx build
        if: matrix.image == 'xingrin-nginx'
        run: |
          mkdir -p docker/nginx/ssl
          openssl req -x509 -nodes -days 365 -newkey rsa:2048 \
            -keyout docker/nginx/ssl/privkey.pem \
            -out docker/nginx/ssl/fullchain.pem \
            -subj "/CN=localhost"
          echo "SSL certificates generated for CI build"

      - name: Set up QEMU
        uses: docker/setup-qemu-action@v3

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}

      - name: Get version from git tag
        id: version
        run: |
          if [[ $GITHUB_REF == refs/tags/* ]]; then
            echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
            echo "IS_RELEASE=true" >> $GITHUB_OUTPUT
          else
            echo "VERSION=dev-$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
            echo "IS_RELEASE=false" >> $GITHUB_OUTPUT
          fi

      - name: Build and push
        uses: docker/build-push-action@v5
        with:
          context: ${{ matrix.context }}
          file: ${{ matrix.dockerfile }}
          platforms: ${{ matrix.platforms }}
          push: true
          tags: |
            ${{ env.IMAGE_PREFIX }}/${{ matrix.image }}:${{ steps.version.outputs.VERSION }}
            ${{ steps.version.outputs.IS_RELEASE == 'true' && format('{0}/{1}:latest', env.IMAGE_PREFIX, matrix.image) || '' }}
          build-args: |
            IMAGE_TAG=${{ steps.version.outputs.VERSION }}
          cache-from: type=gha
          cache-to: type=gha,mode=max
          provenance: false
          sbom: false

  # After all images build successfully, update the VERSION file.
  # Only stable releases (no -dev, -alpha, -beta, -rc suffix) trigger the update.
  update-version:
    runs-on: ubuntu-latest
    needs: build
    if: startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-')
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          ref: main
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Update VERSION file
        run: |
          VERSION="${GITHUB_REF#refs/tags/}"
          echo "$VERSION" > VERSION
          echo "Updated VERSION to $VERSION"

      - name: Commit and push
        run: |
          git config user.name "github-actions[bot]"
          git config user.email "github-actions[bot]@users.noreply.github.com"
          git add VERSION
          git diff --staged --quiet || git commit -m "chore: bump version to ${GITHUB_REF#refs/tags/}"
          git push
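Since this workflow triggers only on `v*` tags (or manual dispatch), a release build is started by pushing a tag; an illustrative sketch (the tag name is an example):

```bash
# Pushing a v-prefixed tag triggers the multi-arch image build and push
git tag -a v1.2.0 -m "release v1.2.0"
git push origin v1.2.0
```

Per the `update-version` job's `if:` condition, tags containing a hyphen (e.g. `v1.2.0-rc1`) still build images but skip the VERSION file bump.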
.gitignore (vendored, 168 lines changed)
@@ -1,136 +1,60 @@
# ============================
# OS files
# ============================
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
# Go
*.exe
*.exe~
*.dll
*.so
*.dylib
*.test
*.out
vendor/
go.work

# ============================
# Frontend (Next.js/Node.js)
# ============================
# Dependency directories
front-back/node_modules/
front-back/.pnpm-store/
# Build artifacts
dist/
build/
bin/

# Next.js build output
front-back/.next/
front-back/out/
front-back/dist/

# Environment files
front-back/.env
front-back/.env.local
front-back/.env.development.local
front-back/.env.test.local
front-back/.env.production.local

# Runtime and caches
front-back/.turbo/
front-back/.swc/
front-back/.eslintcache
front-back/.tsbuildinfo

# ============================
# Backend (Python/Django)
# ============================
# Python virtual environments
.venv/
venv/
env/
ENV/

# Python bytecode
*.pyc
*.pyo
*.pyd
__pycache__/
*.py[cod]
*$py.class

# Django
backend/db.sqlite3
backend/db.sqlite3-journal
backend/media/
backend/staticfiles/
backend/.env
backend/.env.local

# Python tests and coverage
.pytest_cache/
.coverage
htmlcov/
*.cover

# ============================
# Backend (Go)
# ============================
# Build artifacts
backend/bin/
backend/dist/
backend/*.exe
backend/*.exe~
backend/*.dll
backend/*.so
backend/*.dylib

# Tests
backend/*.test
backend/*.out
backend/*.prof

# Go workspace files
backend/go.work
backend/go.work.sum

# Go dependency management
backend/vendor/

# ============================
# IDEs and editors
# ============================
# IDE
.vscode/
.idea/
.cursor/
.claude/
.kiro/
.playwright-mcp/
*.swp
*.swo
*~
.DS_Store

# ============================
# Docker
# ============================
docker/.env
docker/.env.local

# SSL certificates and private keys (must not be committed)
docker/nginx/ssl/*.pem
docker/nginx/ssl/*.key
docker/nginx/ssl/*.crt

# ============================
# Logs and scan results
# ============================
# Environment
.env
.env.local
.env.*.local
**/.env
**/.env.local
**/.env.*.local
*.log
logs/
results/
.venv/

# Dev script runtime files (process IDs and startup logs)
backend/scripts/dev/.pids/
# Testing
coverage.txt
*.coverprofile
.hypothesis/

# ============================
# Temporary files
# ============================
# Temporary files
*.tmp
tmp/
temp/
.cache/

HGETALL
KEYS
vuln_scan/input_endpoints.txt
open-in-v0
.kiro/
.claude/
.specify/

# AI Assistant directories
codex/
openspec/
specs/
AGENTS.md
WARP.md
.opencode/

# SSL certificates
docker/nginx/ssl/*.pem
docker/nginx/ssl/*.key
docker/nginx/ssl/*.crt
README.md (deleted, 266 lines)
@@ -1,266 +0,0 @@
<h1 align="center">XingRin - 星环</h1>

<p align="center">
<b>🛡️ Attack Surface Management (ASM) platform | Automated asset discovery and vulnerability scanning</b>
</p>

<p align="center">
<a href="https://github.com/yyhuni/xingrin/stargazers"><img src="https://img.shields.io/github/stars/yyhuni/xingrin?style=flat-square&logo=github" alt="GitHub stars"></a>
<a href="https://github.com/yyhuni/xingrin/network/members"><img src="https://img.shields.io/github/forks/yyhuni/xingrin?style=flat-square&logo=github" alt="GitHub forks"></a>
<a href="https://github.com/yyhuni/xingrin/issues"><img src="https://img.shields.io/github/issues/yyhuni/xingrin?style=flat-square&logo=github" alt="GitHub issues"></a>
<a href="https://github.com/yyhuni/xingrin/blob/main/LICENSE"><img src="https://img.shields.io/badge/license-PolyForm%20NC-blue?style=flat-square" alt="License"></a>
</p>

<p align="center">
<a href="#-功能特性">Features</a> •
<a href="#-快速开始">Quick Start</a> •
<a href="#-文档">Documentation</a> •
<a href="#-技术栈">Tech Stack</a> •
<a href="#-反馈与贡献">Feedback & Contributing</a>
</p>

<p align="center">
<sub>🔍 Keywords: ASM | Attack Surface Management | Vulnerability Scanning | Asset Discovery | Bug Bounty | Penetration Testing | Nuclei | Subdomain Enumeration | EASM</sub>
</p>

---

<p align="center">
<b>🎨 Modern UI</b>
</p>

<p align="center">
<img src="docs/screenshots/light.png" alt="Light Mode" width="24%">
<img src="docs/screenshots/bubblegum.png" alt="Bubblegum" width="24%">
<img src="docs/screenshots/cosmic-night.png" alt="Cosmic Night" width="24%">
<img src="docs/screenshots/quantum-rose.png" alt="Quantum Rose" width="24%">
</p>

## 📚 Documentation

- [📖 Technical docs](./docs/README.md) - documentation index (🚧 continuously improving)
- [🚀 Quick start](./docs/quick-start.md) - one-command install and deployment guide
- [🔄 Version management](./docs/version-management.md) - git-tag-driven automated version management
- [📦 Nuclei template architecture](./docs/nuclei-template-architecture.md) - template repository storage and sync
- [📖 Wordlist architecture](./docs/wordlist-architecture.md) - wordlist storage and sync
- [🔍 Scan flow architecture](./docs/scan-flow-architecture.md) - full scan pipeline and tool orchestration

---

## ✨ Features

### 🎯 Target and Asset Management
- **Organization management** - multi-level target organizations with flexible grouping
- **Target management** - supports domain and IP target types
- **Asset discovery** - automatic discovery of subdomains, websites, endpoints, and directories
- **Asset snapshots** - snapshot comparison of scan results to track asset changes

### 🔍 Vulnerability Scanning
- **Multiple engines** - integrates Nuclei and other mainstream scan engines
- **Custom workflows** - YAML-configured scan pipelines with flexible orchestration
- **Scheduled scans** - cron-expression configuration for automated periodic scanning

#### Scan Flow Architecture

The full scan flow covers subdomain discovery, port scanning, site discovery, URL collection, directory scanning, and vulnerability scanning:

```mermaid
flowchart LR
    START["Start scan"]

    subgraph STAGE1["Stage 1: Asset discovery"]
        direction TB
        SUB["Subdomain discovery<br/>subfinder, amass, puredns"]
        PORT["Port scanning<br/>naabu"]
        SITE["Site identification<br/>httpx"]
        SUB --> PORT --> SITE
    end

    subgraph STAGE2["Stage 2: Deep analysis"]
        direction TB
        URL["URL collection<br/>waymore, katana"]
        DIR["Directory scanning<br/>ffuf"]
    end

    subgraph STAGE3["Stage 3: Vulnerability detection"]
        VULN["Vulnerability scanning<br/>nuclei, dalfox"]
    end

    FINISH["Scan complete"]

    START --> STAGE1
    SITE --> STAGE2
    STAGE2 --> STAGE3
    STAGE3 --> FINISH

    style START fill:#34495e,stroke:#2c3e50,stroke-width:2px,color:#fff
    style FINISH fill:#27ae60,stroke:#229954,stroke-width:2px,color:#fff
    style STAGE1 fill:#3498db,stroke:#2980b9,stroke-width:2px,color:#fff
    style STAGE2 fill:#9b59b6,stroke:#8e44ad,stroke-width:2px,color:#fff
    style STAGE3 fill:#e67e22,stroke:#d35400,stroke-width:2px,color:#fff
    style SUB fill:#5dade2,stroke:#3498db,stroke-width:1px,color:#fff
    style PORT fill:#5dade2,stroke:#3498db,stroke-width:1px,color:#fff
    style SITE fill:#5dade2,stroke:#3498db,stroke-width:1px,color:#fff
    style URL fill:#bb8fce,stroke:#9b59b6,stroke-width:1px,color:#fff
    style DIR fill:#bb8fce,stroke:#9b59b6,stroke-width:1px,color:#fff
    style VULN fill:#f0b27a,stroke:#e67e22,stroke-width:1px,color:#fff
```

See the [scan flow architecture docs](./docs/scan-flow-architecture.md) for details.

### 🖥️ Distributed Architecture
- **Multi-node scanning** - deploy multiple Worker nodes to scale scanning horizontally
- **Local node** - zero configuration; installation auto-registers a local Docker Worker
- **Remote nodes** - one-command SSH deployment of a remote VPS as a scan node
- **Load-aware scheduling** - real-time node load awareness, dispatching tasks to the best node
- **Node monitoring** - real-time heartbeat detection with CPU/memory/disk status
- **Reconnection** - automatic offline detection, with nodes rejoining automatically on recovery

```mermaid
flowchart TB
    subgraph MASTER["Master Server"]
        direction TB

        REDIS["Redis load cache"]

        subgraph SCHEDULER["Task Distributor"]
            direction TB
            SUBMIT["Receive scan tasks"]
            SELECT["Load-aware selection"]
            DISPATCH["Smart dispatch"]

            SUBMIT --> SELECT
            SELECT --> DISPATCH
        end

        REDIS -.load data.-> SELECT
    end

    subgraph WORKERS["Worker node cluster"]
        direction TB

        W1["Worker 1 (local)<br/>CPU: 45% | MEM: 60%"]
        W2["Worker 2 (remote)<br/>CPU: 30% | MEM: 40%"]
        W3["Worker N (remote)<br/>CPU: 90% | MEM: 85%"]
    end

    DISPATCH -->|dispatch task| W1
    DISPATCH -->|dispatch task| W2
    DISPATCH -->|skipped: high load| W3

    W1 -.heartbeat.-> REDIS
    W2 -.heartbeat.-> REDIS
    W3 -.heartbeat.-> REDIS
```

### 📊 Visual Interface
- **Statistics** - asset/vulnerability dashboards
- **Real-time notifications** - WebSocket message push

---

## 📦 Quick Start

### Requirements

- **OS**: Ubuntu 20.04+ / Debian 11+ (recommended)
- **Hardware**: 2 cores / 4 GB RAM minimum, 20 GB+ disk space

### One-Command Install

```bash
# Clone the project
git clone https://github.com/yyhuni/xingrin.git
cd xingrin

# Install and start (production mode)
sudo ./install.sh

# 🇨🇳 Users in mainland China: mirror acceleration is recommended
sudo ./install.sh --mirror
```

> **💡 About the --mirror flag**
> - Automatically configures Docker registry mirrors (mainland China sources)
> - Accelerates git clones (Nuclei templates, etc.)
> - Greatly speeds up installation and avoids network timeouts

### Accessing the Service

- **Web UI**: `https://ip:8083`

### Common Commands

```bash
# Start services
sudo ./start.sh

# Stop services
sudo ./stop.sh

# Restart services
sudo ./restart.sh

# Uninstall
sudo ./uninstall.sh
```

## 🤝 Feedback & Contributing

- 🐛 **Found a bug?** Please file an [Issue](https://github.com/yyhuni/xingrin/issues)
- 💡 **Have ideas for UI or feature design?** Suggestions are welcome via [Issue](https://github.com/yyhuni/xingrin/issues)
- 🔧 **Want to contribute?** Reach me through my WeChat public account

## 📧 Contact
- The current version is mainly used by me personally, so there may be many edge-case issues
- For questions, suggestions, or anything else, please prefer an [Issue](https://github.com/yyhuni/xingrin/issues); you can also message my WeChat public account directly and I will reply

- WeChat public account: **洋洋的小黑屋**

<img src="docs/wechat-qrcode.png" alt="WeChat public account" width="200">


## ⚠️ Disclaimer

**Important: read carefully before use**

1. This tool is for **authorized security testing** and **security research** only
2. Users must ensure they have **legal authorization** for the target systems
3. Using this tool for unauthorized penetration testing or attacks is **strictly prohibited**
4. Scanning other people's systems without authorization is **illegal** and may carry legal liability
5. The developers are **not responsible for any misuse**

By using this tool you agree to:
- Use it only within the scope of legal authorization
- Comply with the laws and regulations of your jurisdiction
- Bear all consequences arising from misuse

## 🌟 Star History

If this project helps you, please support it with a ⭐ Star!

[](https://star-history.com/#yyhuni/xingrin&Date)

## 📄 License

This project is licensed under the [GNU General Public License v3.0](LICENSE).

### Permitted Uses

- ✅ Personal study and research
- ✅ Commercial and non-commercial use
- ✅ Modification and distribution
- ✅ Patent use
- ✅ Private use

### Obligations and Restrictions

- 📋 **Source obligation**: source code must be provided when distributing
- 📋 **Same license**: derivative works must use the same license
- 📋 **Copyright notice**: original copyright and license notices must be retained
- ❌ **No warranty**: provided without any warranty
- ❌ Unauthorized penetration testing
- ❌ Any illegal activity
agent/.air.toml (new file, 13 lines)
@@ -0,0 +1,13 @@
root = "."
tmp_dir = "tmp"

[build]
cmd = "go build -o ./tmp/agent ./cmd/agent"
bin = "./tmp/agent"
delay = 1000
include_ext = ["go", "tpl", "tmpl", "html"]
exclude_dir = ["tmp", "vendor", ".git"]
exclude_regex = ["_test\\.go"]

[log]
time = true
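This config matches the format used by the `air` live-reload tool (an assumption based on the file name and keys); assuming `air` is installed, hot-reload development would look like this sketch:

```bash
# Run from the agent/ directory: air rebuilds ./tmp/agent and restarts it
# whenever watched .go files change (tests are excluded via exclude_regex)
cd agent
air
```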
agent/Dockerfile (new file, 41 lines)
@@ -0,0 +1,41 @@
# syntax=docker/dockerfile:1

# ============================================
# Go Agent - build
# ============================================
FROM golang:1.25.6 AS builder

ARG GO111MODULE=on
ARG GOPROXY=https://goproxy.cn,direct

ENV GO111MODULE=$GO111MODULE
ENV GOPROXY=$GOPROXY

WORKDIR /src

# Cache dependencies
COPY agent/go.mod agent/go.sum ./
RUN go mod download

# Copy source
COPY agent ./agent

WORKDIR /src/agent

# Build (static where possible)
RUN CGO_ENABLED=0 go build -o /out/agent ./cmd/agent

# ============================================
# Go Agent - runtime
# ============================================
FROM debian:bookworm-20260112-slim

RUN apt-get update && apt-get install -y --no-install-recommends \
    ca-certificates \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

COPY --from=builder /out/agent /usr/local/bin/agent

CMD ["agent"]
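Because the builder stage copies `agent/go.mod` and the `agent/` tree by repo-relative paths, the build context must be the repository root; a minimal sketch (the image tag is illustrative):

```bash
# Build from the repo root so the COPY agent/... paths resolve
docker build -f agent/Dockerfile -t xingrin-agent:dev .
```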
agent/cmd/agent/main.go (new file, 37 lines)
@@ -0,0 +1,37 @@
package main

import (
	"context"
	"fmt"
	"os"
	"os/signal"
	"syscall"

	"github.com/yyhuni/lunafox/agent/internal/app"
	"github.com/yyhuni/lunafox/agent/internal/config"
	"github.com/yyhuni/lunafox/agent/internal/logger"
	"go.uber.org/zap"
)

func main() {
	if err := logger.Init(os.Getenv("LOG_LEVEL")); err != nil {
		fmt.Fprintf(os.Stderr, "logger init failed: %v\n", err)
	}
	defer logger.Sync()

	cfg, err := config.Load(os.Args[1:])
	if err != nil {
		logger.Log.Fatal("failed to load config", zap.Error(err))
	}
	wsURL, err := config.BuildWebSocketURL(cfg.ServerURL)
	if err != nil {
		logger.Log.Fatal("invalid server URL", zap.Error(err))
	}

	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	defer stop()

	if err := app.Run(ctx, *cfg, wsURL); err != nil {
		logger.Log.Fatal("agent stopped", zap.Error(err))
	}
}
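Putting the pieces together (configuration comes from environment variables with flags as overrides, and `app.Run` additionally requires `WORKER_TOKEN`), a local run might look like the following sketch; all values are illustrative:

```bash
SERVER_URL=https://127.0.0.1:8083 \
API_KEY=example-key \
AGENT_VERSION=v1.2.0-dev \
WORKER_TOKEN=example-token \
LOG_LEVEL=debug \
./tmp/agent --max-tasks=3 --cpu-threshold=80
```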
agent/go.mod (new file, 48 lines)
@@ -0,0 +1,48 @@
module github.com/yyhuni/lunafox/agent

go 1.24.5

require (
	github.com/docker/docker v28.5.2+incompatible
	github.com/gorilla/websocket v1.5.3
	github.com/opencontainers/image-spec v1.1.1
	github.com/shirou/gopsutil/v3 v3.24.5
	go.uber.org/zap v1.27.0
)

require (
	github.com/Microsoft/go-winio v0.6.2 // indirect
	github.com/cespare/xxhash/v2 v2.3.0 // indirect
	github.com/containerd/errdefs v1.0.0 // indirect
	github.com/containerd/errdefs/pkg v0.3.0 // indirect
	github.com/containerd/log v0.1.0 // indirect
	github.com/distribution/reference v0.6.0 // indirect
	github.com/docker/go-connections v0.6.0 // indirect
	github.com/docker/go-units v0.5.0 // indirect
	github.com/felixge/httpsnoop v1.0.4 // indirect
	github.com/go-logr/logr v1.4.3 // indirect
	github.com/go-logr/stdr v1.2.2 // indirect
	github.com/go-ole/go-ole v1.2.6 // indirect
	github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
	github.com/moby/docker-image-spec v1.3.1 // indirect
	github.com/moby/sys/atomicwriter v0.1.0 // indirect
	github.com/moby/term v0.5.2 // indirect
	github.com/morikuni/aec v1.1.0 // indirect
	github.com/opencontainers/go-digest v1.0.0 // indirect
	github.com/pkg/errors v0.9.1 // indirect
	github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
	github.com/shoenig/go-m1cpu v0.1.6 // indirect
	github.com/tklauser/go-sysconf v0.3.12 // indirect
	github.com/tklauser/numcpus v0.6.1 // indirect
	github.com/yusufpapurcu/wmi v1.2.4 // indirect
	go.opentelemetry.io/auto/sdk v1.2.1 // indirect
	go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 // indirect
	go.opentelemetry.io/otel v1.39.0 // indirect
	go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.39.0 // indirect
	go.opentelemetry.io/otel/metric v1.39.0 // indirect
	go.opentelemetry.io/otel/trace v1.39.0 // indirect
	go.uber.org/multierr v1.10.0 // indirect
	golang.org/x/sys v0.39.0 // indirect
	golang.org/x/time v0.14.0 // indirect
	gotest.tools/v3 v3.5.2 // indirect
)
agent/go.sum (new file, 131 lines)
@@ -0,0 +1,131 @@
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg=
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/docker v28.5.2+incompatible h1:DBX0Y0zAjZbSrm1uzOkdr1onVghKaftjlSWt4AFexzM=
github.com/docker/docker v28.5.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw=
github.com/moby/sys/atomicwriter v0.1.0/go.mod h1:Ul8oqv2ZMNHOceF643P6FKPXeCmYtlQMvpizfsSoaWs=
github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU=
github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko=
github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ=
github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc=
github.com/morikuni/aec v1.1.0 h1:vBBl0pUnvi/Je71dsRrhMBtreIqNMYErSAbEeb8jrXQ=
github.com/morikuni/aec v1.1.0/go.mod h1:xDRgiq/iw5l+zkao76YTKzKttOp2cwPEne25HDkJnBw=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI=
github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk=
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 h1:ssfIgGNANqpVFCndZvcuyKbl0g+UAVcbBcqGkG28H0Y=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0/go.mod h1:GQ/474YrbE4Jx8gZ4q5I4hrhUzM6UPzyrqJYV2AqPoQ=
go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48=
go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 h1:f0cb2XPmrqn4XMy9PNliTgRKJgS5WcL/u0/WRYGz4t0=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0/go.mod h1:vnakAaFckOMiMtOIhFI2MNH4FYrZzXCYxmb1LlhoGz8=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.39.0 h1:Ckwye2FpXkYgiHX7fyVrN1uA/UYd9ounqqTuSNAv0k4=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.39.0/go.mod h1:teIFJh5pW2y+AN7riv6IBPX2DuesS3HgP39mwOspKwU=
go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0=
go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs=
go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18=
go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE=
go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8=
go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew=
go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI=
go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA=
go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A=
go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 h1:fCvbg86sFXwdrl5LgVcTEvNC+2txB5mgROGmRL5mrls=
google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM=
google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHhrNig=
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
agent/internal/app/agent.go (new file, 139 lines)
@@ -0,0 +1,139 @@
package app

import (
	"context"
	"errors"
	"os"
	"strconv"
	"time"

	"github.com/yyhuni/lunafox/agent/internal/config"
	"github.com/yyhuni/lunafox/agent/internal/docker"
	"github.com/yyhuni/lunafox/agent/internal/domain"
	"github.com/yyhuni/lunafox/agent/internal/health"
	"github.com/yyhuni/lunafox/agent/internal/logger"
	"github.com/yyhuni/lunafox/agent/internal/metrics"
	"github.com/yyhuni/lunafox/agent/internal/protocol"
	"github.com/yyhuni/lunafox/agent/internal/task"
	"github.com/yyhuni/lunafox/agent/internal/update"
	agentws "github.com/yyhuni/lunafox/agent/internal/websocket"
	"go.uber.org/zap"
)

func Run(ctx context.Context, cfg config.Config, wsURL string) error {
	configUpdater := config.NewUpdater(cfg)

	version := cfg.AgentVersion
	hostname := os.Getenv("AGENT_HOSTNAME")
	if hostname == "" {
		var err error
		hostname, err = os.Hostname()
		if err != nil || hostname == "" {
			hostname = "unknown"
		}
	}

	logger.Log.Info("agent starting",
		zap.String("version", version),
		zap.String("hostname", hostname),
		zap.String("server", cfg.ServerURL),
		zap.String("ws", wsURL),
		zap.Int("maxTasks", cfg.MaxTasks),
		zap.Int("cpuThreshold", cfg.CPUThreshold),
		zap.Int("memThreshold", cfg.MemThreshold),
		zap.Int("diskThreshold", cfg.DiskThreshold),
	)

	client := agentws.NewClient(wsURL, cfg.APIKey)
	collector := metrics.NewCollector()
	healthManager := health.NewManager()
	taskCounter := &task.Counter{}
	heartbeat := agentws.NewHeartbeatSender(client, collector, healthManager, version, hostname, taskCounter.Count)

	taskClient := task.NewClient(cfg.ServerURL, cfg.APIKey)
	puller := task.NewPuller(taskClient, collector, taskCounter, cfg.MaxTasks, cfg.CPUThreshold, cfg.MemThreshold, cfg.DiskThreshold)

	taskQueue := make(chan *domain.Task, cfg.MaxTasks)
	puller.SetOnTask(func(t *domain.Task) {
		logger.Log.Info("task received",
			zap.Int("taskId", t.ID),
			zap.Int("scanId", t.ScanID),
			zap.String("workflow", t.WorkflowName),
			zap.Int("stage", t.Stage),
			zap.String("target", t.TargetName),
		)
		taskQueue <- t
	})

	dockerClient, err := docker.NewClient()
	if err != nil {
		logger.Log.Warn("docker client unavailable", zap.Error(err))
	} else {
		logger.Log.Info("docker client ready")
	}

	workerToken := os.Getenv("WORKER_TOKEN")
	if workerToken == "" {
		return errors.New("WORKER_TOKEN environment variable is required")
	}
	logger.Log.Info("worker token loaded")

	executor := task.NewExecutor(dockerClient, taskClient, taskCounter, cfg.ServerURL, workerToken, version)
	defer func() {
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		if err := executor.Shutdown(shutdownCtx); err != nil && !errors.Is(err, context.DeadlineExceeded) {
			logger.Log.Error("executor shutdown error", zap.Error(err))
		}
	}()

	updater := update.NewUpdater(dockerClient, healthManager, puller, executor, configUpdater, cfg.APIKey, workerToken)

	handler := agentws.NewHandler()
	handler.OnTaskAvailable(puller.NotifyTaskAvailable)
	handler.OnTaskCancel(func(taskID int) {
		logger.Log.Info("task cancel requested", zap.Int("taskId", taskID))
		executor.MarkCancelled(taskID)
		executor.CancelTask(taskID)
	})
	handler.OnConfigUpdate(func(payload protocol.ConfigUpdatePayload) {
		logger.Log.Info("config update received",
			zap.String("maxTasks", formatOptionalInt(payload.MaxTasks)),
			zap.String("cpuThreshold", formatOptionalInt(payload.CPUThreshold)),
			zap.String("memThreshold", formatOptionalInt(payload.MemThreshold)),
			zap.String("diskThreshold", formatOptionalInt(payload.DiskThreshold)),
		)
		cfgUpdate := config.Update{
			MaxTasks:      payload.MaxTasks,
			CPUThreshold:  payload.CPUThreshold,
			MemThreshold:  payload.MemThreshold,
			DiskThreshold: payload.DiskThreshold,
		}
		configUpdater.Apply(cfgUpdate)
		puller.UpdateConfig(cfgUpdate.MaxTasks, cfgUpdate.CPUThreshold, cfgUpdate.MemThreshold, cfgUpdate.DiskThreshold)
	})
	handler.OnUpdateRequired(updater.HandleUpdateRequired)
	client.SetOnMessage(handler.Handle)

	logger.Log.Info("starting heartbeat sender")
	go heartbeat.Start(ctx)
	logger.Log.Info("starting task puller")
	go func() {
		_ = puller.Run(ctx)
	}()
	logger.Log.Info("starting task executor")
	go executor.Start(ctx, taskQueue)

	logger.Log.Info("connecting to server websocket")
	if err := client.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
		return err
	}
	return nil
}

func formatOptionalInt(value *int) string {
	if value == nil {
		return "nil"
	}
	return strconv.Itoa(*value)
}
agent/internal/config/config.go (new file, 53 lines)
@@ -0,0 +1,53 @@
package config

import (
	"errors"
	"fmt"
)

// Config represents runtime settings for the agent.
type Config struct {
	ServerURL     string
	APIKey        string
	AgentVersion  string
	MaxTasks      int
	CPUThreshold  int
	MemThreshold  int
	DiskThreshold int
}

// Validate ensures config values are usable.
func (c *Config) Validate() error {
	if c.ServerURL == "" {
		return errors.New("server URL is required")
	}
	if c.APIKey == "" {
		return errors.New("api key is required")
	}
	if c.AgentVersion == "" {
		return errors.New("AGENT_VERSION environment variable is required")
	}
	if c.MaxTasks < 1 {
		return errors.New("max tasks must be at least 1")
	}
	if err := validatePercent("cpu threshold", c.CPUThreshold); err != nil {
		return err
	}
	if err := validatePercent("mem threshold", c.MemThreshold); err != nil {
		return err
	}
	if err := validatePercent("disk threshold", c.DiskThreshold); err != nil {
		return err
	}
	if _, err := BuildWebSocketURL(c.ServerURL); err != nil {
		return err
	}
	return nil
}

func validatePercent(name string, value int) error {
	if value < 1 || value > 100 {
		return fmt.Errorf("%s must be between 1 and 100", name)
	}
	return nil
}
agent/internal/config/loader.go (new file, 87 lines)
@@ -0,0 +1,87 @@
package config

import (
	"flag"
	"fmt"
	"os"
	"strconv"
	"strings"
)

const (
	defaultMaxTasks      = 5
	defaultCPUThreshold  = 85
	defaultMemThreshold  = 85
	defaultDiskThreshold = 90
)

// Load parses configuration from environment variables and CLI flags.
func Load(args []string) (*Config, error) {
	maxTasks, err := readEnvInt("MAX_TASKS", defaultMaxTasks)
	if err != nil {
		return nil, err
	}
	cpuThreshold, err := readEnvInt("CPU_THRESHOLD", defaultCPUThreshold)
	if err != nil {
		return nil, err
	}
	memThreshold, err := readEnvInt("MEM_THRESHOLD", defaultMemThreshold)
	if err != nil {
		return nil, err
	}
	diskThreshold, err := readEnvInt("DISK_THRESHOLD", defaultDiskThreshold)
	if err != nil {
		return nil, err
	}

	cfg := &Config{
		ServerURL:     strings.TrimSpace(os.Getenv("SERVER_URL")),
		APIKey:        strings.TrimSpace(os.Getenv("API_KEY")),
		AgentVersion:  strings.TrimSpace(os.Getenv("AGENT_VERSION")),
		MaxTasks:      maxTasks,
		CPUThreshold:  cpuThreshold,
		MemThreshold:  memThreshold,
		DiskThreshold: diskThreshold,
	}

	fs := flag.NewFlagSet("agent", flag.ContinueOnError)
	serverURL := fs.String("server-url", cfg.ServerURL, "Server base URL (e.g. https://1.1.1.1:8080)")
	apiKey := fs.String("api-key", cfg.APIKey, "Agent API key")
	maxTasksFlag := fs.Int("max-tasks", cfg.MaxTasks, "Maximum concurrent tasks")
	cpuThresholdFlag := fs.Int("cpu-threshold", cfg.CPUThreshold, "CPU threshold percentage")
	memThresholdFlag := fs.Int("mem-threshold", cfg.MemThreshold, "Memory threshold percentage")
	diskThresholdFlag := fs.Int("disk-threshold", cfg.DiskThreshold, "Disk threshold percentage")

	if err := fs.Parse(args); err != nil {
		return nil, err
	}

	cfg.ServerURL = strings.TrimSpace(*serverURL)
	cfg.APIKey = strings.TrimSpace(*apiKey)
	cfg.MaxTasks = *maxTasksFlag
	cfg.CPUThreshold = *cpuThresholdFlag
	cfg.MemThreshold = *memThresholdFlag
	cfg.DiskThreshold = *diskThresholdFlag

	if err := cfg.Validate(); err != nil {
		return nil, err
	}

	return cfg, nil
}

func readEnvInt(key string, fallback int) (int, error) {
	val, ok := os.LookupEnv(key)
	if !ok {
		return fallback, nil
	}
	val = strings.TrimSpace(val)
	if val == "" {
		return fallback, nil
	}
	parsed, err := strconv.Atoi(val)
	if err != nil {
		return 0, fmt.Errorf("invalid %s: %w", key, err)
	}
	return parsed, nil
}
agent/internal/config/loader_test.go (new file, 75 lines)
@@ -0,0 +1,75 @@
package config

import (
	"testing"
)

func TestLoadConfigFromEnvAndFlags(t *testing.T) {
	t.Setenv("SERVER_URL", "https://example.com")
	t.Setenv("API_KEY", "abc12345")
	t.Setenv("AGENT_VERSION", "v1.2.3")
	t.Setenv("MAX_TASKS", "5")
	t.Setenv("CPU_THRESHOLD", "80")
	t.Setenv("MEM_THRESHOLD", "81")
	t.Setenv("DISK_THRESHOLD", "82")

	cfg, err := Load([]string{})
	if err != nil {
		t.Fatalf("load failed: %v", err)
	}
	if cfg.ServerURL != "https://example.com" {
		t.Fatalf("expected server url from env")
	}
	if cfg.MaxTasks != 5 {
		t.Fatalf("expected max tasks from env")
	}

	args := []string{
		"--server-url=https://override.example.com",
		"--api-key=deadbeef",
		"--max-tasks=9",
		"--cpu-threshold=70",
		"--mem-threshold=71",
		"--disk-threshold=72",
	}
	cfg, err = Load(args)
	if err != nil {
		t.Fatalf("load failed: %v", err)
	}
	if cfg.ServerURL != "https://override.example.com" {
		t.Fatalf("expected server url from args")
	}
	if cfg.APIKey != "deadbeef" {
		t.Fatalf("expected api key from args")
	}
	if cfg.MaxTasks != 9 {
		t.Fatalf("expected max tasks from args")
	}
	if cfg.CPUThreshold != 70 || cfg.MemThreshold != 71 || cfg.DiskThreshold != 72 {
		t.Fatalf("expected thresholds from args")
	}
}

func TestLoadConfigMissingRequired(t *testing.T) {
	t.Setenv("SERVER_URL", "")
	t.Setenv("API_KEY", "")
	t.Setenv("AGENT_VERSION", "v1.2.3")

	_, err := Load([]string{})
	if err == nil {
		t.Fatalf("expected error when required values missing")
	}
}

func TestLoadConfigInvalidEnvValue(t *testing.T) {
	t.Setenv("SERVER_URL", "https://example.com")
	t.Setenv("API_KEY", "abc")
	t.Setenv("AGENT_VERSION", "v1.2.3")
	t.Setenv("MAX_TASKS", "nope")

	_, err := Load([]string{})
	if err == nil {
		t.Fatalf("expected error for invalid MAX_TASKS")
	}
}
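These tests depend only on `t.Setenv` and flag parsing, so they can be run in isolation; a sketch:

```bash
cd agent
go test ./internal/config/ -run TestLoadConfig -v
```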
agent/internal/config/updater.go (new file, 49 lines)
@@ -0,0 +1,49 @@
package config

import (
	"sync"

	"github.com/yyhuni/lunafox/agent/internal/domain"
)

// Update holds optional configuration updates.
type Update = domain.ConfigUpdate

// Updater manages runtime configuration changes.
type Updater struct {
	mu  sync.RWMutex
	cfg Config
}

// NewUpdater creates an updater with initial config.
func NewUpdater(cfg Config) *Updater {
	return &Updater{cfg: cfg}
}

// Apply updates the configuration and returns the new snapshot.
func (u *Updater) Apply(update Update) Config {
	u.mu.Lock()
	defer u.mu.Unlock()

	if update.MaxTasks != nil && *update.MaxTasks > 0 {
		u.cfg.MaxTasks = *update.MaxTasks
	}
	if update.CPUThreshold != nil && *update.CPUThreshold > 0 {
		u.cfg.CPUThreshold = *update.CPUThreshold
	}
	if update.MemThreshold != nil && *update.MemThreshold > 0 {
		u.cfg.MemThreshold = *update.MemThreshold
	}
	if update.DiskThreshold != nil && *update.DiskThreshold > 0 {
		u.cfg.DiskThreshold = *update.DiskThreshold
	}

	return u.cfg
}

// Snapshot returns a copy of current config.
func (u *Updater) Snapshot() Config {
	u.mu.RLock()
	defer u.mu.RUnlock()
	return u.cfg
}
agent/internal/config/updater_test.go (new file, 39 lines)
@@ -0,0 +1,39 @@
package config

import "testing"

func TestUpdaterApplyAndSnapshot(t *testing.T) {
	cfg := Config{
		ServerURL:     "https://example.com",
		APIKey:        "key",
		MaxTasks:      2,
		CPUThreshold:  70,
		MemThreshold:  80,
		DiskThreshold: 90,
	}

	updater := NewUpdater(cfg)
	snapshot := updater.Snapshot()
	if snapshot.MaxTasks != 2 || snapshot.CPUThreshold != 70 {
		t.Fatalf("unexpected snapshot values")
	}

	invalid := 0
	update := Update{MaxTasks: &invalid, CPUThreshold: &invalid}
	snapshot = updater.Apply(update)
	if snapshot.MaxTasks != 2 || snapshot.CPUThreshold != 70 {
		t.Fatalf("expected invalid update to be ignored")
	}

	maxTasks := 5
	cpu := 85
	mem := 60
	snapshot = updater.Apply(Update{
		MaxTasks:     &maxTasks,
		CPUThreshold: &cpu,
		MemThreshold: &mem,
	})
	if snapshot.MaxTasks != 5 || snapshot.CPUThreshold != 85 || snapshot.MemThreshold != 60 {
		t.Fatalf("unexpected applied update")
	}
}
50  agent/internal/config/url.go  Normal file
@@ -0,0 +1,50 @@
package config

import (
	"errors"
	"fmt"
	"net/url"
	"strings"
)

// BuildWebSocketURL derives the agent WebSocket endpoint from the server URL.
func BuildWebSocketURL(serverURL string) (string, error) {
	trimmed := strings.TrimSpace(serverURL)
	if trimmed == "" {
		return "", errors.New("server URL is required")
	}
	parsed, err := url.Parse(trimmed)
	if err != nil {
		return "", err
	}

	switch strings.ToLower(parsed.Scheme) {
	case "http":
		parsed.Scheme = "ws"
	case "https":
		parsed.Scheme = "wss"
	case "ws", "wss":
	default:
		if parsed.Scheme == "" {
			return "", errors.New("server URL scheme is required")
		}
		return "", fmt.Errorf("unsupported server URL scheme: %s", parsed.Scheme)
	}

	parsed.Path = buildWSPath(parsed.Path)
	parsed.RawQuery = ""
	parsed.Fragment = ""

	return parsed.String(), nil
}

func buildWSPath(path string) string {
	trimmed := strings.TrimRight(path, "/")
	if trimmed == "" {
		return "/api/agents/ws"
	}
	if strings.HasSuffix(trimmed, "/api") {
		return trimmed + "/agents/ws"
	}
	return trimmed + "/api/agents/ws"
}
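The scheme switch and buildWSPath combine into a small, fixed mapping; a sketch (again assuming agent-module code) makes the three path cases concrete:

package main

import (
	"fmt"

	"github.com/yyhuni/lunafox/agent/internal/config"
)

func main() {
	for _, in := range []string{"https://example.com", "http://example.com/api", "https://example.com/base"} {
		out, err := config.BuildWebSocketURL(in)
		if err != nil {
			fmt.Println(in, "->", err)
			continue
		}
		fmt.Println(in, "->", out)
		// https://example.com      -> wss://example.com/api/agents/ws
		// http://example.com/api   -> ws://example.com/api/agents/ws
		// https://example.com/base -> wss://example.com/base/api/agents/ws
	}
}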
38  agent/internal/config/url_test.go  Normal file
@@ -0,0 +1,38 @@
package config

import "testing"

func TestBuildWebSocketURL(t *testing.T) {
	tests := []struct {
		input    string
		expected string
	}{
		{"https://example.com", "wss://example.com/api/agents/ws"},
		{"http://example.com", "ws://example.com/api/agents/ws"},
		{"https://example.com/api", "wss://example.com/api/agents/ws"},
		{"https://example.com/base", "wss://example.com/base/api/agents/ws"},
		{"wss://example.com", "wss://example.com/api/agents/ws"},
	}

	for _, tt := range tests {
		got, err := BuildWebSocketURL(tt.input)
		if err != nil {
			t.Fatalf("unexpected error for %s: %v", tt.input, err)
		}
		if got != tt.expected {
			t.Fatalf("input %s expected %s got %s", tt.input, tt.expected, got)
		}
	}
}

func TestBuildWebSocketURLInvalid(t *testing.T) {
	if _, err := BuildWebSocketURL("example.com"); err == nil {
		t.Fatalf("expected error for missing scheme")
	}
	if _, err := BuildWebSocketURL(" "); err == nil {
		t.Fatalf("expected error for empty url")
	}
	if _, err := BuildWebSocketURL("ftp://example.com"); err == nil {
		t.Fatalf("expected error for unsupported scheme")
	}
}
23  agent/internal/docker/cleanup.go  Normal file
@@ -0,0 +1,23 @@
package docker

import (
	"context"

	"github.com/docker/docker/api/types/container"
)

// Remove removes the container.
func (c *Client) Remove(ctx context.Context, containerID string) error {
	return c.cli.ContainerRemove(ctx, containerID, container.RemoveOptions{
		Force:         true,
		RemoveVolumes: true,
	})
}

// Stop stops a running container with a timeout.
func (c *Client) Stop(ctx context.Context, containerID string) error {
	timeout := 10
	return c.cli.ContainerStop(ctx, containerID, container.StopOptions{
		Timeout: &timeout,
	})
}
46  agent/internal/docker/client.go  Normal file
@@ -0,0 +1,46 @@
package docker

import (
	"context"
	"io"

	"github.com/docker/docker/api/types/container"
	imagetypes "github.com/docker/docker/api/types/image"
	"github.com/docker/docker/api/types/network"
	"github.com/docker/docker/client"
	ocispec "github.com/opencontainers/image-spec/specs-go/v1"
)

// Client wraps the Docker SDK client.
type Client struct {
	cli *client.Client
}

// NewClient creates a Docker client using environment configuration.
func NewClient() (*Client, error) {
	cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
	if err != nil {
		return nil, err
	}
	return &Client{cli: cli}, nil
}

// Close closes the Docker client.
func (c *Client) Close() error {
	return c.cli.Close()
}

// ImagePull pulls an image from the registry.
func (c *Client) ImagePull(ctx context.Context, imageRef string) (io.ReadCloser, error) {
	return c.cli.ImagePull(ctx, imageRef, imagetypes.PullOptions{})
}

// ContainerCreate creates a container.
func (c *Client) ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, name string) (container.CreateResponse, error) {
	return c.cli.ContainerCreate(ctx, config, hostConfig, networkingConfig, platform, name)
}

// ContainerStart starts a container.
func (c *Client) ContainerStart(ctx context.Context, containerID string, opts container.StartOptions) error {
	return c.cli.ContainerStart(ctx, containerID, opts)
}
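A hedged lifecycle sketch for this wrapper. Note that the Docker SDK's ImagePull returns a stream that must be drained before the pull actually completes, which is why the update flow later in this diff copies it to io.Discard; the image tag below is arbitrary:

package main

import (
	"context"
	"io"
	"log"

	"github.com/yyhuni/lunafox/agent/internal/docker"
)

func main() {
	cli, err := docker.NewClient()
	if err != nil {
		log.Fatal(err) // e.g. no Docker socket reachable via environment config
	}
	defer cli.Close()

	reader, err := cli.ImagePull(context.Background(), "alpine:3.19")
	if err != nil {
		log.Fatal(err)
	}
	defer reader.Close()
	_, _ = io.Copy(io.Discard, reader) // drain: the pull only finishes once the stream is read
}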
49  agent/internal/docker/logs.go  Normal file
@@ -0,0 +1,49 @@
package docker

import (
	"bytes"
	"context"
	"io"
	"strconv"
	"strings"

	"github.com/docker/docker/api/types/container"
)

const (
	maxErrorBytes = 4096
)

// TailLogs returns the last N lines of container logs, truncated to 4KB.
func (c *Client) TailLogs(ctx context.Context, containerID string, lines int) (string, error) {
	reader, err := c.cli.ContainerLogs(ctx, containerID, container.LogsOptions{
		ShowStdout: true,
		ShowStderr: true,
		Timestamps: false,
		Tail:       strconv.Itoa(lines),
	})
	if err != nil {
		return "", err
	}
	defer reader.Close()

	var buf bytes.Buffer
	if _, err := io.Copy(&buf, reader); err != nil {
		return "", err
	}

	out := buf.String()
	out = strings.TrimSpace(out)
	if len(out) > maxErrorBytes {
		out = out[len(out)-maxErrorBytes:]
	}
	return out, nil
}

// TruncateErrorMessage clamps message length to 4KB.
func TruncateErrorMessage(message string) string {
	if len(message) <= maxErrorBytes {
		return message
	}
	return message[:maxErrorBytes]
}
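Worth noting: the two truncations cut from opposite ends. TailLogs keeps the last 4KB, since the most recent output is usually the diagnostic part, while TruncateErrorMessage keeps the first 4KB. A standalone sketch, with the constant duplicated because it is unexported:

package main

import (
	"fmt"
	"strings"
)

const maxErrorBytes = 4096 // mirrors the unexported constant in the docker package

func main() {
	long := strings.Repeat("a", maxErrorBytes) + "TAIL"

	// TailLogs-style truncation keeps the most recent bytes...
	tail := long[len(long)-maxErrorBytes:]
	fmt.Println(strings.HasSuffix(tail, "TAIL")) // true

	// ...while TruncateErrorMessage keeps the leading bytes.
	head := long[:maxErrorBytes]
	fmt.Println(strings.Contains(head, "TAIL")) // false
}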
22  agent/internal/docker/logs_test.go  Normal file
@@ -0,0 +1,22 @@
package docker

import (
	"strings"
	"testing"
)

func TestTruncateErrorMessage(t *testing.T) {
	short := "short message"
	if got := TruncateErrorMessage(short); got != short {
		t.Fatalf("expected message to stay unchanged")
	}

	long := strings.Repeat("x", maxErrorBytes+10)
	got := TruncateErrorMessage(long)
	if len(got) != maxErrorBytes {
		t.Fatalf("expected length %d, got %d", maxErrorBytes, len(got))
	}
	if got != long[:maxErrorBytes] {
		t.Fatalf("unexpected truncation result")
	}
}
20  agent/internal/docker/monitor.go  Normal file
@@ -0,0 +1,20 @@
package docker

import (
	"context"

	"github.com/docker/docker/api/types/container"
)

// Wait waits for a container to stop and returns the exit code.
func (c *Client) Wait(ctx context.Context, containerID string) (int64, error) {
	statusCh, errCh := c.cli.ContainerWait(ctx, containerID, container.WaitConditionNotRunning)
	select {
	case status := <-statusCh:
		return status.StatusCode, nil
	case err := <-errCh:
		return 0, err
	case <-ctx.Done():
		return 0, ctx.Err()
	}
}
76  agent/internal/docker/runner.go  Normal file
@@ -0,0 +1,76 @@
package docker

import (
	"context"
	"fmt"
	"os"
	"strings"

	"github.com/docker/docker/api/types/container"
	"github.com/docker/docker/api/types/network"
	"github.com/docker/docker/api/types/strslice"
	"github.com/yyhuni/lunafox/agent/internal/domain"
)

const workerImagePrefix = "yyhuni/lunafox-worker:"

// StartWorker starts a worker container for a task and returns the container ID.
func (c *Client) StartWorker(ctx context.Context, t *domain.Task, serverURL, serverToken, agentVersion string) (string, error) {
	if t == nil {
		return "", fmt.Errorf("task is nil")
	}
	if err := os.MkdirAll(t.WorkspaceDir, 0755); err != nil {
		return "", fmt.Errorf("prepare workspace: %w", err)
	}

	image, err := resolveWorkerImage(agentVersion)
	if err != nil {
		return "", err
	}
	env := buildWorkerEnv(t, serverURL, serverToken)

	config := &container.Config{
		Image: image,
		Env:   env,
		Cmd:   strslice.StrSlice{},
	}

	hostConfig := &container.HostConfig{
		Binds:       []string{"/opt/lunafox:/opt/lunafox"},
		AutoRemove:  false,
		OomScoreAdj: 500,
	}

	resp, err := c.cli.ContainerCreate(ctx, config, hostConfig, &network.NetworkingConfig{}, nil, "")
	if err != nil {
		return "", err
	}

	if err := c.cli.ContainerStart(ctx, resp.ID, container.StartOptions{}); err != nil {
		return "", err
	}

	return resp.ID, nil
}

func resolveWorkerImage(version string) (string, error) {
	version = strings.TrimSpace(version)
	if version == "" {
		return "", fmt.Errorf("worker version is required")
	}
	return workerImagePrefix + version, nil
}

func buildWorkerEnv(t *domain.Task, serverURL, serverToken string) []string {
	return []string{
		fmt.Sprintf("SERVER_URL=%s", serverURL),
		fmt.Sprintf("SERVER_TOKEN=%s", serverToken),
		fmt.Sprintf("SCAN_ID=%d", t.ScanID),
		fmt.Sprintf("TARGET_ID=%d", t.TargetID),
		fmt.Sprintf("TARGET_NAME=%s", t.TargetName),
		fmt.Sprintf("TARGET_TYPE=%s", t.TargetType),
		fmt.Sprintf("WORKFLOW_NAME=%s", t.WorkflowName),
		fmt.Sprintf("WORKSPACE_DIR=%s", t.WorkspaceDir),
		fmt.Sprintf("CONFIG=%s", t.Config),
	}
}
50  agent/internal/docker/runner_test.go  Normal file
@@ -0,0 +1,50 @@
package docker

import (
	"testing"

	"github.com/yyhuni/lunafox/agent/internal/domain"
)

func TestResolveWorkerImage(t *testing.T) {
	if _, err := resolveWorkerImage(""); err == nil {
		t.Fatalf("expected error for empty version")
	}
	if got, err := resolveWorkerImage("v1.2.3"); err != nil || got != workerImagePrefix+"v1.2.3" {
		t.Fatalf("expected version image, got %s, err: %v", got, err)
	}
}

func TestBuildWorkerEnv(t *testing.T) {
	spec := &domain.Task{
		ScanID:       1,
		TargetID:     2,
		TargetName:   "example.com",
		TargetType:   "domain",
		WorkflowName: "subdomain_discovery",
		WorkspaceDir: "/opt/lunafox/results",
		Config:       "config-yaml",
	}

	env := buildWorkerEnv(spec, "https://server", "token")
	expected := []string{
		"SERVER_URL=https://server",
		"SERVER_TOKEN=token",
		"SCAN_ID=1",
		"TARGET_ID=2",
		"TARGET_NAME=example.com",
		"TARGET_TYPE=domain",
		"WORKFLOW_NAME=subdomain_discovery",
		"WORKSPACE_DIR=/opt/lunafox/results",
		"CONFIG=config-yaml",
	}

	if len(env) != len(expected) {
		t.Fatalf("expected %d env entries, got %d", len(expected), len(env))
	}
	for i, item := range expected {
		if env[i] != item {
			t.Fatalf("expected env[%d]=%s got %s", i, item, env[i])
		}
	}
}
8  agent/internal/domain/config.go  Normal file
@@ -0,0 +1,8 @@
package domain

type ConfigUpdate struct {
	MaxTasks      *int `json:"maxTasks"`
	CPUThreshold  *int `json:"cpuThreshold"`
	MemThreshold  *int `json:"memThreshold"`
	DiskThreshold *int `json:"diskThreshold"`
}
10  agent/internal/domain/health.go  Normal file
@@ -0,0 +1,10 @@
package domain

import "time"

type HealthStatus struct {
	State   string     `json:"state"`
	Reason  string     `json:"reason,omitempty"`
	Message string     `json:"message,omitempty"`
	Since   *time.Time `json:"since,omitempty"`
}
13  agent/internal/domain/task.go  Normal file
@@ -0,0 +1,13 @@
package domain

type Task struct {
	ID           int    `json:"taskId"`
	ScanID       int    `json:"scanId"`
	Stage        int    `json:"stage"`
	WorkflowName string `json:"workflowName"`
	TargetID     int    `json:"targetId"`
	TargetName   string `json:"targetName"`
	TargetType   string `json:"targetType"`
	WorkspaceDir string `json:"workspaceDir"`
	Config       string `json:"config"`
}
6  agent/internal/domain/update.go  Normal file
@@ -0,0 +1,6 @@
package domain

type UpdateRequiredPayload struct {
	Version string `json:"version"`
	Image   string `json:"image"`
}
51  agent/internal/health/health.go  Normal file
@@ -0,0 +1,51 @@
package health

import (
	"sync"
	"time"

	"github.com/yyhuni/lunafox/agent/internal/domain"
)

// Status represents the agent health state reported in heartbeats.
type Status = domain.HealthStatus

// Manager stores current health status.
type Manager struct {
	mu     sync.RWMutex
	status Status
}

// NewManager initializes the manager with ok status.
func NewManager() *Manager {
	return &Manager{
		status: Status{State: "ok"},
	}
}

// Get returns a snapshot of current status.
func (m *Manager) Get() Status {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.status
}

// Set updates health status and timestamps transitions.
func (m *Manager) Set(state, reason, message string) {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.status.State != state {
		now := time.Now().UTC()
		m.status.Since = &now
	}

	m.status.State = state
	m.status.Reason = reason
	m.status.Message = message
	if state == "ok" {
		m.status.Since = nil
		m.status.Reason = ""
		m.status.Message = ""
	}
}
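A short sketch of the transition semantics, again assuming agent-module code: Since is stamped only when the state actually changes, and an "ok" reset clears everything:

package main

import (
	"fmt"

	"github.com/yyhuni/lunafox/agent/internal/health"
)

func main() {
	mgr := health.NewManager()
	mgr.Set("paused", "update_failed", "image pull error") // transition: Since is stamped
	mgr.Set("paused", "update_failed", "still retrying")   // same state: Since is preserved
	fmt.Println(mgr.Get().State, mgr.Get().Since != nil)   // paused true

	mgr.Set("ok", "", "") // recovery clears Since/Reason/Message
	fmt.Println(mgr.Get().State, mgr.Get().Since == nil) // ok true
}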
33  agent/internal/health/health_test.go  Normal file
@@ -0,0 +1,33 @@
package health

import "testing"

func TestManagerSetTransitions(t *testing.T) {
	mgr := NewManager()
	initial := mgr.Get()
	if initial.State != "ok" || initial.Since != nil {
		t.Fatalf("expected initial ok status")
	}

	mgr.Set("paused", "update", "waiting")
	status := mgr.Get()
	if status.State != "paused" || status.Since == nil {
		t.Fatalf("expected paused state with timestamp")
	}
	prevSince := status.Since

	mgr.Set("paused", "still", "waiting more")
	status = mgr.Get()
	if status.Since == nil || !status.Since.Equal(*prevSince) {
		t.Fatalf("expected unchanged since on same state")
	}
	if status.Reason != "still" || status.Message != "waiting more" {
		t.Fatalf("expected updated reason/message")
	}

	mgr.Set("ok", "ignored", "ignored")
	status = mgr.Get()
	if status.State != "ok" || status.Since != nil || status.Reason != "" || status.Message != "" {
		t.Fatalf("expected ok reset to clear fields")
	}
}
50  agent/internal/logger/logger.go  Normal file
@@ -0,0 +1,50 @@
package logger

import (
	"os"
	"strings"

	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
)

// Log is the shared agent logger. Defaults to a no-op logger until initialized.
var Log = zap.NewNop()

// Init configures the logger using the provided level and the ENV environment variable.
func Init(level string) error {
	level = strings.TrimSpace(level)
	if level == "" {
		level = "info"
	}

	var zapLevel zapcore.Level
	if err := zapLevel.UnmarshalText([]byte(level)); err != nil {
		zapLevel = zapcore.InfoLevel
	}

	isDev := strings.EqualFold(os.Getenv("ENV"), "development")
	var config zap.Config
	if isDev {
		config = zap.NewDevelopmentConfig()
		config.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
	} else {
		config = zap.NewProductionConfig()
	}
	config.Level = zap.NewAtomicLevelAt(zapLevel)

	logger, err := config.Build()
	if err != nil {
		Log = zap.NewNop()
		return err
	}
	Log = logger
	return nil
}

// Sync flushes any buffered log entries.
func Sync() {
	if Log != nil {
		_ = Log.Sync()
	}
}
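Typical initialization, as a sketch; the level string and version value are placeholders. An unparseable level falls back to info rather than erroring:

package main

import (
	"github.com/yyhuni/lunafox/agent/internal/logger"
	"go.uber.org/zap"
)

func main() {
	if err := logger.Init("debug"); err != nil {
		panic(err) // only config.Build failures surface here
	}
	defer logger.Sync()

	logger.Log.Info("agent starting", zap.String("version", "v0.0.0-example"))
}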
58  agent/internal/metrics/collector.go  Normal file
@@ -0,0 +1,58 @@
package metrics

import (
	"github.com/shirou/gopsutil/v3/cpu"
	"github.com/shirou/gopsutil/v3/disk"
	"github.com/shirou/gopsutil/v3/mem"
	"github.com/yyhuni/lunafox/agent/internal/logger"
	"go.uber.org/zap"
)

// Collector gathers system metrics.
type Collector struct{}

// NewCollector creates a new Collector.
func NewCollector() *Collector {
	return &Collector{}
}

// Sample returns CPU, memory, and disk usage percentages.
func (c *Collector) Sample() (float64, float64, float64) {
	cpuPercent, err := cpuUsagePercent()
	if err != nil {
		logger.Log.Warn("metrics: cpu percent error", zap.Error(err))
	}
	memPercent, err := memUsagePercent()
	if err != nil {
		logger.Log.Warn("metrics: mem percent error", zap.Error(err))
	}
	diskPercent, err := diskUsagePercent("/")
	if err != nil {
		logger.Log.Warn("metrics: disk percent error", zap.Error(err))
	}
	return cpuPercent, memPercent, diskPercent
}

func cpuUsagePercent() (float64, error) {
	values, err := cpu.Percent(0, false)
	if err != nil || len(values) == 0 {
		return 0, err
	}
	return values[0], nil
}

func memUsagePercent() (float64, error) {
	info, err := mem.VirtualMemory()
	if err != nil {
		return 0, err
	}
	return info.UsedPercent, nil
}

func diskUsagePercent(path string) (float64, error) {
	info, err := disk.Usage(path)
	if err != nil {
		return 0, err
	}
	return info.UsedPercent, nil
}
11  agent/internal/metrics/collector_test.go  Normal file
@@ -0,0 +1,11 @@
package metrics

import "testing"

func TestCollectorSample(t *testing.T) {
	c := NewCollector()
	cpu, mem, disk := c.Sample()
	if cpu < 0 || mem < 0 || disk < 0 {
		t.Fatalf("expected non-negative metrics")
	}
}
42  agent/internal/protocol/messages.go  Normal file
@@ -0,0 +1,42 @@
package protocol

import (
	"time"

	"github.com/yyhuni/lunafox/agent/internal/domain"
)

const (
	MessageTypeHeartbeat      = "heartbeat"
	MessageTypeTaskAvailable  = "task_available"
	MessageTypeTaskCancel     = "task_cancel"
	MessageTypeConfigUpdate   = "config_update"
	MessageTypeUpdateRequired = "update_required"
)

type Message struct {
	Type      string      `json:"type"`
	Payload   interface{} `json:"payload"`
	Timestamp time.Time   `json:"timestamp"`
}

type HealthStatus = domain.HealthStatus

type HeartbeatPayload struct {
	CPU      float64      `json:"cpu"`
	Mem      float64      `json:"mem"`
	Disk     float64      `json:"disk"`
	Tasks    int          `json:"tasks"`
	Version  string       `json:"version"`
	Hostname string       `json:"hostname"`
	Uptime   int64        `json:"uptime"`
	Health   HealthStatus `json:"health"`
}

type ConfigUpdatePayload = domain.ConfigUpdate

type UpdateRequiredPayload = domain.UpdateRequiredPayload

type TaskCancelPayload struct {
	TaskID int `json:"taskId"`
}
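Given these struct tags, a heartbeat serializes as sketched below; all values are illustrative:

package main

import (
	"encoding/json"
	"fmt"
	"time"

	"github.com/yyhuni/lunafox/agent/internal/protocol"
)

func main() {
	msg := protocol.Message{
		Type: protocol.MessageTypeHeartbeat,
		Payload: protocol.HeartbeatPayload{
			CPU: 12.5, Mem: 40, Disk: 55,
			Tasks: 1, Version: "v0.0.0-example", Hostname: "agent-01",
			Uptime: 3600,
			Health: protocol.HealthStatus{State: "ok"},
		},
		Timestamp: time.Now().UTC(),
	}
	out, _ := json.Marshal(msg)
	fmt.Println(string(out))
	// {"type":"heartbeat","payload":{"cpu":12.5,...,"health":{"state":"ok"}},"timestamp":"..."}
	// Empty Reason/Message/Since are dropped by their omitempty tags.
}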
118  agent/internal/task/client.go  Normal file
@@ -0,0 +1,118 @@
package task

import (
	"bytes"
	"context"
	"crypto/tls"
	"encoding/json"
	"fmt"
	"net/http"
	"strings"
	"time"

	"github.com/yyhuni/lunafox/agent/internal/domain"
)

// Client handles HTTP API requests to the server.
type Client struct {
	baseURL string
	apiKey  string
	http    *http.Client
}

// NewClient creates a new task client.
func NewClient(serverURL, apiKey string) *Client {
	transport := http.DefaultTransport
	if base, ok := transport.(*http.Transport); ok {
		clone := base.Clone()
		clone.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
		transport = clone
	}
	return &Client{
		baseURL: strings.TrimRight(serverURL, "/"),
		apiKey:  apiKey,
		http: &http.Client{
			Timeout:   15 * time.Second,
			Transport: transport,
		},
	}
}

// PullTask requests a task from the server. Returns nil when no task is available.
func (c *Client) PullTask(ctx context.Context) (*domain.Task, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/api/agent/tasks/pull", nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("X-Agent-Key", c.apiKey)

	resp, err := c.http.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusNoContent {
		return nil, nil
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("pull task failed: status %d", resp.StatusCode)
	}

	var task domain.Task
	if err := json.NewDecoder(resp.Body).Decode(&task); err != nil {
		return nil, err
	}
	return &task, nil
}

// UpdateStatus reports task status to the server with retry.
func (c *Client) UpdateStatus(ctx context.Context, taskID int, status, errorMessage string) error {
	payload := map[string]string{
		"status": status,
	}
	if errorMessage != "" {
		payload["errorMessage"] = errorMessage
	}
	body, err := json.Marshal(payload)
	if err != nil {
		return err
	}

	var lastErr error
	for attempt := 0; attempt < 3; attempt++ {
		if attempt > 0 {
			backoff := time.Duration(5<<attempt) * time.Second // 10s before the second attempt, 20s before the third
			select {
			case <-ctx.Done():
				return ctx.Err()
			case <-time.After(backoff):
			}
		}

		req, err := http.NewRequestWithContext(ctx, http.MethodPatch, fmt.Sprintf("%s/api/agent/tasks/%d/status", c.baseURL, taskID), bytes.NewReader(body))
		if err != nil {
			return err
		}
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("X-Agent-Key", c.apiKey)

		resp, err := c.http.Do(req)
		if err != nil {
			lastErr = err
			continue
		}
		resp.Body.Close()

		if resp.StatusCode == http.StatusOK {
			return nil
		}
		lastErr = fmt.Errorf("update status failed: status %d", resp.StatusCode)

		// Don't retry 4xx client errors (except 429)
		if resp.StatusCode >= 400 && resp.StatusCode < 500 && resp.StatusCode != 429 {
			return lastErr
		}
	}
	return lastErr
}
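The retry schedule in UpdateStatus follows from 5&lt;&lt;attempt: no wait before the first try, then 10s and 20s (the original inline comment claimed "5s, 10s, 20s", which the shift does not produce). A tiny sketch that just prints the schedule:

package main

import (
	"fmt"
	"time"
)

func main() {
	// Mirrors the UpdateStatus loop: waits only happen before retries.
	for attempt := 0; attempt < 3; attempt++ {
		if attempt > 0 {
			fmt.Println("retry", attempt, "after", time.Duration(5<<attempt)*time.Second)
		}
	}
	// retry 1 after 10s
	// retry 2 after 20s
}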
187  agent/internal/task/client_test.go  Normal file
@@ -0,0 +1,187 @@
package task

import (
	"bytes"
	"context"
	"encoding/json"
	"io"
	"net/http"
	"strings"
	"testing"

	"github.com/yyhuni/lunafox/agent/internal/domain"
)

func TestClientPullTaskNoContent(t *testing.T) {
	client := &Client{
		baseURL: "http://example",
		apiKey:  "key",
		http: &http.Client{
			Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
				if r.URL.Path != "/api/agent/tasks/pull" {
					t.Fatalf("unexpected path %s", r.URL.Path)
				}
				return &http.Response{
					StatusCode: http.StatusNoContent,
					Body:       io.NopCloser(strings.NewReader("")),
					Header:     http.Header{},
				}, nil
			}),
		},
	}
	task, err := client.PullTask(context.Background())
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if task != nil {
		t.Fatalf("expected nil task")
	}
}

func TestClientPullTaskOK(t *testing.T) {
	client := &Client{
		baseURL: "http://example",
		apiKey:  "key",
		http: &http.Client{
			Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
				if r.Header.Get("X-Agent-Key") == "" {
					t.Fatalf("missing api key header")
				}
				body, _ := json.Marshal(domain.Task{ID: 1})
				return &http.Response{
					StatusCode: http.StatusOK,
					Body:       io.NopCloser(bytes.NewReader(body)),
					Header:     http.Header{},
				}, nil
			}),
		},
	}
	task, err := client.PullTask(context.Background())
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if task == nil || task.ID != 1 {
		t.Fatalf("unexpected task")
	}
}

func TestClientUpdateStatus(t *testing.T) {
	client := &Client{
		baseURL: "http://example",
		apiKey:  "key",
		http: &http.Client{
			Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
				if r.Method != http.MethodPatch {
					t.Fatalf("expected PATCH")
				}
				if r.Header.Get("X-Agent-Key") == "" {
					t.Fatalf("missing api key header")
				}
				return &http.Response{
					StatusCode: http.StatusOK,
					Body:       io.NopCloser(strings.NewReader("")),
					Header:     http.Header{},
				}, nil
			}),
		},
	}
	if err := client.UpdateStatus(context.Background(), 1, "completed", ""); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
}

func TestClientPullTaskErrorStatus(t *testing.T) {
	client := &Client{
		baseURL: "http://example",
		apiKey:  "key",
		http: &http.Client{
			Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
				return &http.Response{
					StatusCode: http.StatusBadRequest,
					Body:       io.NopCloser(strings.NewReader("bad")),
					Header:     http.Header{},
				}, nil
			}),
		},
	}
	if _, err := client.PullTask(context.Background()); err == nil {
		t.Fatalf("expected error for non-200 status")
	}
}

func TestClientPullTaskBadJSON(t *testing.T) {
	client := &Client{
		baseURL: "http://example",
		apiKey:  "key",
		http: &http.Client{
			Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
				return &http.Response{
					StatusCode: http.StatusOK,
					Body:       io.NopCloser(strings.NewReader("{bad json")),
					Header:     http.Header{},
				}, nil
			}),
		},
	}
	if _, err := client.PullTask(context.Background()); err == nil {
		t.Fatalf("expected error for invalid json")
	}
}

func TestClientUpdateStatusIncludesErrorMessage(t *testing.T) {
	client := &Client{
		baseURL: "http://example",
		apiKey:  "key",
		http: &http.Client{
			Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
				body, err := io.ReadAll(r.Body)
				if err != nil {
					t.Fatalf("read body: %v", err)
				}
				var payload map[string]string
				if err := json.Unmarshal(body, &payload); err != nil {
					t.Fatalf("unmarshal body: %v", err)
				}
				if payload["status"] != "failed" {
					t.Fatalf("expected status failed")
				}
				if payload["errorMessage"] != "boom" {
					t.Fatalf("expected error message")
				}
				return &http.Response{
					StatusCode: http.StatusOK,
					Body:       io.NopCloser(strings.NewReader("")),
					Header:     http.Header{},
				}, nil
			}),
		},
	}
	if err := client.UpdateStatus(context.Background(), 1, "failed", "boom"); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
}

func TestClientUpdateStatusErrorStatus(t *testing.T) {
	client := &Client{
		baseURL: "http://example",
		apiKey:  "key",
		http: &http.Client{
			Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
				return &http.Response{
					StatusCode: http.StatusInternalServerError,
					Body:       io.NopCloser(strings.NewReader("")),
					Header:     http.Header{},
				}, nil
			}),
		},
	}
	if err := client.UpdateStatus(context.Background(), 1, "completed", ""); err == nil {
		t.Fatalf("expected error for non-200 status")
	}
}

type roundTripFunc func(*http.Request) (*http.Response, error)

func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	return f(r)
}
23  agent/internal/task/counter.go  Normal file
@@ -0,0 +1,23 @@
package task

import "sync/atomic"

// Counter tracks running task count.
type Counter struct {
	value int64
}

// Inc increments the counter.
func (c *Counter) Inc() {
	atomic.AddInt64(&c.value, 1)
}

// Dec decrements the counter.
func (c *Counter) Dec() {
	atomic.AddInt64(&c.value, -1)
}

// Count returns current count.
func (c *Counter) Count() int {
	return int(atomic.LoadInt64(&c.value))
}
18  agent/internal/task/counter_test.go  Normal file
@@ -0,0 +1,18 @@
package task

import "testing"

func TestCounterIncDec(t *testing.T) {
	var counter Counter

	counter.Inc()
	counter.Inc()
	if got := counter.Count(); got != 2 {
		t.Fatalf("expected count 2, got %d", got)
	}

	counter.Dec()
	if got := counter.Count(); got != 1 {
		t.Fatalf("expected count 1, got %d", got)
	}
}
258  agent/internal/task/executor.go  Normal file
@@ -0,0 +1,258 @@
package task

import (
	"context"
	"errors"
	"fmt"
	"sync"
	"sync/atomic"
	"time"

	"github.com/yyhuni/lunafox/agent/internal/docker"
	"github.com/yyhuni/lunafox/agent/internal/domain"
)

const defaultMaxRuntime = 7 * 24 * time.Hour

// Executor runs tasks inside worker containers.
type Executor struct {
	docker       DockerRunner
	client       statusReporter
	counter      *Counter
	serverURL    string
	workerToken  string
	agentVersion string
	maxRuntime   time.Duration

	mu        sync.Mutex
	running   map[int]context.CancelFunc
	cancelMu  sync.Mutex
	cancelled map[int]struct{}
	wg        sync.WaitGroup

	stopping atomic.Bool
}

type statusReporter interface {
	UpdateStatus(ctx context.Context, taskID int, status, errorMessage string) error
}

type DockerRunner interface {
	StartWorker(ctx context.Context, t *domain.Task, serverURL, serverToken, agentVersion string) (string, error)
	Wait(ctx context.Context, containerID string) (int64, error)
	Stop(ctx context.Context, containerID string) error
	Remove(ctx context.Context, containerID string) error
	TailLogs(ctx context.Context, containerID string, lines int) (string, error)
}

// NewExecutor creates an Executor.
func NewExecutor(dockerClient DockerRunner, taskClient statusReporter, counter *Counter, serverURL, workerToken, agentVersion string) *Executor {
	return &Executor{
		docker:       dockerClient,
		client:       taskClient,
		counter:      counter,
		serverURL:    serverURL,
		workerToken:  workerToken,
		agentVersion: agentVersion,
		maxRuntime:   defaultMaxRuntime,
		running:      map[int]context.CancelFunc{},
		cancelled:    map[int]struct{}{},
	}
}

// Start processes tasks from the queue.
func (e *Executor) Start(ctx context.Context, tasks <-chan *domain.Task) {
	for {
		select {
		case <-ctx.Done():
			return
		case t, ok := <-tasks:
			if !ok {
				return
			}
			if t == nil {
				continue
			}
			if e.stopping.Load() {
				// During shutdown/update: drain the queue but don't start new work.
				continue
			}
			if e.isCancelled(t.ID) {
				e.reportStatus(ctx, t.ID, "cancelled", "")
				e.clearCancelled(t.ID)
				continue
			}
			go e.execute(ctx, t)
		}
	}
}

// CancelTask requests cancellation of a running task.
func (e *Executor) CancelTask(taskID int) {
	e.mu.Lock()
	cancel := e.running[taskID]
	e.mu.Unlock()
	if cancel != nil {
		cancel()
	}
}

// MarkCancelled records a task as cancelled to prevent execution.
func (e *Executor) MarkCancelled(taskID int) {
	e.cancelMu.Lock()
	e.cancelled[taskID] = struct{}{}
	e.cancelMu.Unlock()
}

func (e *Executor) reportStatus(ctx context.Context, taskID int, status, errorMessage string) {
	if e.client == nil {
		return
	}
	statusCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 30*time.Second)
	defer cancel()
	_ = e.client.UpdateStatus(statusCtx, taskID, status, errorMessage)
}

func (e *Executor) execute(ctx context.Context, t *domain.Task) {
	e.wg.Add(1)
	defer e.wg.Done()
	defer e.clearCancelled(t.ID)

	if e.counter != nil {
		e.counter.Inc()
		defer e.counter.Dec()
	}

	if e.workerToken == "" {
		e.reportStatus(ctx, t.ID, "failed", "missing worker token")
		return
	}
	if e.docker == nil {
		e.reportStatus(ctx, t.ID, "failed", "docker client unavailable")
		return
	}

	runCtx, cancel := context.WithTimeout(ctx, e.maxRuntime)
	defer cancel()

	containerID, err := e.docker.StartWorker(runCtx, t, e.serverURL, e.workerToken, e.agentVersion)
	if err != nil {
		message := docker.TruncateErrorMessage(err.Error())
		e.reportStatus(ctx, t.ID, "failed", message)
		return
	}
	defer func() {
		_ = e.docker.Remove(context.Background(), containerID)
	}()

	e.trackCancel(t.ID, cancel)
	defer e.clearCancel(t.ID)

	exitCode, waitErr := e.docker.Wait(runCtx, containerID)
	if waitErr != nil {
		if errors.Is(waitErr, context.DeadlineExceeded) || errors.Is(runCtx.Err(), context.DeadlineExceeded) {
			e.handleTimeout(ctx, t, containerID)
			return
		}
		if errors.Is(waitErr, context.Canceled) || errors.Is(runCtx.Err(), context.Canceled) {
			e.handleCancel(ctx, t, containerID)
			return
		}
		message := docker.TruncateErrorMessage(waitErr.Error())
		e.reportStatus(ctx, t.ID, "failed", message)
		return
	}

	if runCtx.Err() != nil {
		if errors.Is(runCtx.Err(), context.DeadlineExceeded) {
			e.handleTimeout(ctx, t, containerID)
			return
		}
		if errors.Is(runCtx.Err(), context.Canceled) {
			e.handleCancel(ctx, t, containerID)
			return
		}
	}

	if exitCode == 0 {
		e.reportStatus(ctx, t.ID, "completed", "")
		return
	}

	logs, _ := e.docker.TailLogs(context.Background(), containerID, 100)
	message := logs
	if message == "" {
		message = fmt.Sprintf("container exited with code %d", exitCode)
	}
	message = docker.TruncateErrorMessage(message)
	e.reportStatus(ctx, t.ID, "failed", message)
}

func (e *Executor) handleCancel(ctx context.Context, t *domain.Task, containerID string) {
	_ = e.docker.Stop(context.Background(), containerID)
	e.reportStatus(ctx, t.ID, "cancelled", "")
}

func (e *Executor) handleTimeout(ctx context.Context, t *domain.Task, containerID string) {
	_ = e.docker.Stop(context.Background(), containerID)
	message := docker.TruncateErrorMessage("task timed out")
	e.reportStatus(ctx, t.ID, "failed", message)
}

func (e *Executor) trackCancel(taskID int, cancel context.CancelFunc) {
	e.mu.Lock()
	defer e.mu.Unlock()
	e.running[taskID] = cancel
}

func (e *Executor) clearCancel(taskID int) {
	e.mu.Lock()
	defer e.mu.Unlock()
	delete(e.running, taskID)
}

func (e *Executor) isCancelled(taskID int) bool {
	e.cancelMu.Lock()
	defer e.cancelMu.Unlock()
	_, ok := e.cancelled[taskID]
	return ok
}

func (e *Executor) clearCancelled(taskID int) {
	e.cancelMu.Lock()
	delete(e.cancelled, taskID)
	e.cancelMu.Unlock()
}

// CancelAll requests cancellation for all running tasks.
func (e *Executor) CancelAll() {
	e.mu.Lock()
	cancels := make([]context.CancelFunc, 0, len(e.running))
	for _, cancel := range e.running {
		cancels = append(cancels, cancel)
	}
	e.mu.Unlock()

	for _, cancel := range cancels {
		cancel()
	}
}

// Shutdown cancels running tasks and waits for completion.
func (e *Executor) Shutdown(ctx context.Context) error {
	e.stopping.Store(true)
	e.CancelAll()

	done := make(chan struct{})
	go func() {
		e.wg.Wait()
		close(done)
	}()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-done:
		return nil
	}
}
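A wiring sketch for the executor, with placeholder endpoint, key, token, and version; task.Client and docker.Client satisfy the unexported statusReporter and the DockerRunner interfaces respectively:

package main

import (
	"context"

	"github.com/yyhuni/lunafox/agent/internal/docker"
	"github.com/yyhuni/lunafox/agent/internal/domain"
	"github.com/yyhuni/lunafox/agent/internal/task"
)

func main() {
	dockerCli, err := docker.NewClient()
	if err != nil {
		panic(err)
	}
	defer dockerCli.Close()

	client := task.NewClient("https://server.example", "agent-api-key") // placeholders
	counter := &task.Counter{}
	exec := task.NewExecutor(dockerCli, client, counter,
		"https://server.example", "worker-token", "v0.0.0-example")

	tasks := make(chan *domain.Task)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	go exec.Start(ctx, tasks) // one goroutine per running task
	// A puller (next file) would feed the channel; a real agent blocks here.
}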
107  agent/internal/task/executor_test.go  Normal file
@@ -0,0 +1,107 @@
package task

import (
	"context"
	"testing"
	"time"

	"github.com/yyhuni/lunafox/agent/internal/domain"
)

type fakeReporter struct {
	status string
	msg    string
}

func (f *fakeReporter) UpdateStatus(ctx context.Context, taskID int, status, errorMessage string) error {
	f.status = status
	f.msg = errorMessage
	return nil
}

func TestExecutorMissingWorkerToken(t *testing.T) {
	reporter := &fakeReporter{}
	exec := &Executor{
		client:      reporter,
		serverURL:   "https://server",
		workerToken: "",
	}

	exec.execute(context.Background(), &domain.Task{ID: 1})
	if reporter.status != "failed" {
		t.Fatalf("expected failed status, got %s", reporter.status)
	}
	if reporter.msg == "" {
		t.Fatalf("expected error message")
	}
}

func TestExecutorDockerUnavailable(t *testing.T) {
	reporter := &fakeReporter{}
	exec := &Executor{
		client:      reporter,
		serverURL:   "https://server",
		workerToken: "token",
	}

	exec.execute(context.Background(), &domain.Task{ID: 2})
	if reporter.status != "failed" {
		t.Fatalf("expected failed status, got %s", reporter.status)
	}
	if reporter.msg == "" {
		t.Fatalf("expected error message")
	}
}

func TestExecutorCancelAll(t *testing.T) {
	exec := &Executor{
		running: map[int]context.CancelFunc{},
	}
	calls := 0
	exec.running[1] = func() { calls++ }
	exec.running[2] = func() { calls++ }

	exec.CancelAll()
	if calls != 2 {
		t.Fatalf("expected cancel calls, got %d", calls)
	}
}

func TestExecutorShutdownWaits(t *testing.T) {
	exec := &Executor{
		running: map[int]context.CancelFunc{},
	}
	calls := 0
	exec.running[1] = func() { calls++ }

	exec.wg.Add(1)
	go func() {
		time.Sleep(10 * time.Millisecond)
		exec.wg.Done()
	}()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	if err := exec.Shutdown(ctx); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if calls != 1 {
		t.Fatalf("expected cancel call")
	}
}

func TestExecutorShutdownTimeout(t *testing.T) {
	exec := &Executor{
		running: map[int]context.CancelFunc{},
	}
	exec.wg.Add(1)
	defer exec.wg.Done()

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
	defer cancel()

	if err := exec.Shutdown(ctx); err == nil {
		t.Fatalf("expected timeout error")
	}
}
252  agent/internal/task/puller.go  Normal file
@@ -0,0 +1,252 @@
package task

import (
	"context"
	"errors"
	"math"
	"math/rand"
	"sync"
	"sync/atomic"
	"time"

	"github.com/yyhuni/lunafox/agent/internal/domain"
)

// Puller coordinates task pulling with load gating and backoff.
type Puller struct {
	client        TaskPuller
	collector     MetricsSampler
	counter       *Counter
	maxTasks      int
	cpuThreshold  int
	memThreshold  int
	diskThreshold int

	onTask       func(*domain.Task)
	notifyCh     chan struct{}
	emptyBackoff []time.Duration
	emptyIdx     int
	errorBackoff time.Duration
	errorMax     time.Duration
	randSrc      *rand.Rand
	mu           sync.RWMutex
	paused       atomic.Bool
}

type MetricsSampler interface {
	Sample() (float64, float64, float64)
}

type TaskPuller interface {
	PullTask(ctx context.Context) (*domain.Task, error)
}

// NewPuller creates a new Puller.
func NewPuller(client TaskPuller, collector MetricsSampler, counter *Counter, maxTasks, cpuThreshold, memThreshold, diskThreshold int) *Puller {
	return &Puller{
		client:        client,
		collector:     collector,
		counter:       counter,
		maxTasks:      maxTasks,
		cpuThreshold:  cpuThreshold,
		memThreshold:  memThreshold,
		diskThreshold: diskThreshold,
		notifyCh:      make(chan struct{}, 1),
		emptyBackoff:  []time.Duration{5 * time.Second, 10 * time.Second, 30 * time.Second, 60 * time.Second},
		errorBackoff:  1 * time.Second,
		errorMax:      60 * time.Second,
		randSrc:       rand.New(rand.NewSource(time.Now().UnixNano())),
	}
}

// SetOnTask registers a callback invoked when a task is assigned.
func (p *Puller) SetOnTask(fn func(*domain.Task)) {
	p.onTask = fn
}

// NotifyTaskAvailable triggers an immediate pull attempt.
func (p *Puller) NotifyTaskAvailable() {
	select {
	case p.notifyCh <- struct{}{}:
	default:
	}
}

// Run starts the pull loop.
func (p *Puller) Run(ctx context.Context) error {
	for {
		if ctx.Err() != nil {
			return ctx.Err()
		}

		if p.paused.Load() {
			if !p.waitUntilCanceled(ctx) {
				return ctx.Err()
			}
			continue
		}

		loadInterval := p.loadInterval()
		if !p.canPull() {
			if !p.wait(ctx, loadInterval) {
				return ctx.Err()
			}
			continue
		}

		task, err := p.client.PullTask(ctx)
		if err != nil {
			delay := p.nextErrorBackoff()
			if !p.wait(ctx, delay) {
				return ctx.Err()
			}
			continue
		}

		p.resetErrorBackoff()
		if task == nil {
			delay := p.nextEmptyDelay(loadInterval)
			if !p.waitOrNotify(ctx, delay) {
				return ctx.Err()
			}
			continue
		}

		p.resetEmptyBackoff()
		if p.onTask != nil {
			p.onTask(task)
		}
	}
}

func (p *Puller) canPull() bool {
	maxTasks, cpuThreshold, memThreshold, diskThreshold := p.currentConfig()
	if p.counter != nil && p.counter.Count() >= maxTasks {
		return false
	}
	cpu, mem, disk := p.collector.Sample()
	return cpu < float64(cpuThreshold) &&
		mem < float64(memThreshold) &&
		disk < float64(diskThreshold)
}

func (p *Puller) loadInterval() time.Duration {
	cpu, mem, disk := p.collector.Sample()
	load := math.Max(cpu, math.Max(mem, disk))
	switch {
	case load < 50:
		return 1 * time.Second
	case load < 80:
		return 3 * time.Second
	default:
		return 10 * time.Second
	}
}

func (p *Puller) nextEmptyDelay(loadInterval time.Duration) time.Duration {
	var empty time.Duration
	if p.emptyIdx < len(p.emptyBackoff) {
		empty = p.emptyBackoff[p.emptyIdx]
		p.emptyIdx++
	} else {
		empty = p.emptyBackoff[len(p.emptyBackoff)-1]
	}
	if empty < loadInterval {
		return loadInterval
	}
	return empty
}

func (p *Puller) resetEmptyBackoff() {
	p.emptyIdx = 0
}

func (p *Puller) nextErrorBackoff() time.Duration {
	delay := p.errorBackoff
	next := delay * 2
	if next > p.errorMax {
		next = p.errorMax
	}
	p.errorBackoff = next
	return withJitter(delay, p.randSrc)
}

func (p *Puller) resetErrorBackoff() {
	p.errorBackoff = 1 * time.Second
}

func (p *Puller) wait(ctx context.Context, delay time.Duration) bool {
	timer := time.NewTimer(delay)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return false
	case <-timer.C:
		return true
	}
}

func (p *Puller) waitOrNotify(ctx context.Context, delay time.Duration) bool {
	timer := time.NewTimer(delay)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return false
	case <-p.notifyCh:
		return true
	case <-timer.C:
		return true
	}
}

func withJitter(delay time.Duration, src *rand.Rand) time.Duration {
	if delay <= 0 || src == nil {
		return delay
	}
	jitter := src.Float64() * 0.2
	return delay + time.Duration(float64(delay)*jitter)
}

func (p *Puller) EnsureTaskHandler() error {
	if p.onTask == nil {
		return errors.New("task handler is required")
	}
	return nil
}

// Pause stops pulling. Once paused, only context cancellation exits the loop.
func (p *Puller) Pause() {
	p.paused.Store(true)
}

// UpdateConfig updates puller thresholds and max tasks.
func (p *Puller) UpdateConfig(maxTasks, cpuThreshold, memThreshold, diskThreshold *int) {
	p.mu.Lock()
	defer p.mu.Unlock()

	if maxTasks != nil && *maxTasks > 0 {
		p.maxTasks = *maxTasks
	}
	if cpuThreshold != nil && *cpuThreshold > 0 {
		p.cpuThreshold = *cpuThreshold
	}
	if memThreshold != nil && *memThreshold > 0 {
		p.memThreshold = *memThreshold
	}
	if diskThreshold != nil && *diskThreshold > 0 {
		p.diskThreshold = *diskThreshold
	}
}

func (p *Puller) currentConfig() (int, int, int, int) {
	p.mu.RLock()
	defer p.mu.RUnlock()
	return p.maxTasks, p.cpuThreshold, p.memThreshold, p.diskThreshold
}

func (p *Puller) waitUntilCanceled(ctx context.Context) bool {
	<-ctx.Done()
	return false
}
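And the corresponding puller side, again with placeholder endpoint, key, and thresholds; Run blocks until the context is cancelled, gating each pull on the task counter and the sampled load:

package main

import (
	"context"
	"time"

	"github.com/yyhuni/lunafox/agent/internal/domain"
	"github.com/yyhuni/lunafox/agent/internal/metrics"
	"github.com/yyhuni/lunafox/agent/internal/task"
)

func main() {
	client := task.NewClient("https://server.example", "agent-api-key") // placeholders
	counter := &task.Counter{}
	puller := task.NewPuller(client, metrics.NewCollector(), counter, 2, 80, 85, 90)

	tasks := make(chan *domain.Task, 1)
	puller.SetOnTask(func(t *domain.Task) { tasks <- t })
	if err := puller.EnsureTaskHandler(); err != nil {
		panic(err)
	}

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	_ = puller.Run(ctx) // returns once the context is done
}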
101  agent/internal/task/puller_test.go  Normal file
@@ -0,0 +1,101 @@
package task

import (
	"math/rand"
	"testing"
	"time"

	"github.com/yyhuni/lunafox/agent/internal/domain"
)

func TestPullerUpdateConfig(t *testing.T) {
	p := NewPuller(nil, nil, nil, 5, 85, 86, 87)
	max, cpu, mem, disk := p.currentConfig()
	if max != 5 || cpu != 85 || mem != 86 || disk != 87 {
		t.Fatalf("unexpected initial config")
	}

	maxUpdate := 8
	cpuUpdate := 70
	p.UpdateConfig(&maxUpdate, &cpuUpdate, nil, nil)
	max, cpu, mem, disk = p.currentConfig()
	if max != 8 || cpu != 70 || mem != 86 || disk != 87 {
		t.Fatalf("unexpected updated config")
	}
}

func TestPullerPause(t *testing.T) {
	p := NewPuller(nil, nil, nil, 1, 1, 1, 1)
	p.Pause()
	if !p.paused.Load() {
		t.Fatalf("expected paused")
	}
}

func TestPullerEnsureTaskHandler(t *testing.T) {
	p := NewPuller(nil, nil, nil, 1, 1, 1, 1)
	if err := p.EnsureTaskHandler(); err == nil {
		t.Fatalf("expected error when handler missing")
	}
	p.SetOnTask(func(*domain.Task) {})
	if err := p.EnsureTaskHandler(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
}

func TestPullerNextEmptyDelay(t *testing.T) {
	p := NewPuller(nil, nil, nil, 1, 1, 1, 1)
	p.emptyBackoff = []time.Duration{5 * time.Second, 10 * time.Second}

	if delay := p.nextEmptyDelay(8 * time.Second); delay != 8*time.Second {
		t.Fatalf("expected delay to honor load interval, got %v", delay)
	}
	if delay := p.nextEmptyDelay(1 * time.Second); delay != 10*time.Second {
		t.Fatalf("expected backoff delay, got %v", delay)
	}
	if p.emptyIdx != 2 {
		t.Fatalf("expected empty index to advance")
	}
	p.resetEmptyBackoff()
	if p.emptyIdx != 0 {
		t.Fatalf("expected empty index reset")
	}
}

func TestPullerErrorBackoff(t *testing.T) {
	p := NewPuller(nil, nil, nil, 1, 1, 1, 1)
	p.randSrc = rand.New(rand.NewSource(1))

	first := p.nextErrorBackoff()
	if first < time.Second || first > time.Second+(time.Second/5) {
		t.Fatalf("unexpected backoff %v", first)
	}
	if p.errorBackoff != 2*time.Second {
		t.Fatalf("expected backoff to double")
	}

	second := p.nextErrorBackoff()
	if second < 2*time.Second || second > 2*time.Second+(2*time.Second/5) {
		t.Fatalf("unexpected backoff %v", second)
	}
	if p.errorBackoff != 4*time.Second {
		t.Fatalf("expected backoff to double")
	}

	p.resetErrorBackoff()
	if p.errorBackoff != time.Second {
		t.Fatalf("expected error backoff reset")
	}
}

func TestWithJitterRange(t *testing.T) {
	rng := rand.New(rand.NewSource(1))
	delay := 10 * time.Second
	got := withJitter(delay, rng)
	if got < delay {
		t.Fatalf("expected jitter >= delay")
	}
	if got > delay+(delay/5) {
		t.Fatalf("expected jitter <= 20%%")
	}
}
279
agent/internal/update/updater.go
Normal file
279
agent/internal/update/updater.go
Normal file
@@ -0,0 +1,279 @@
|
||||
package update
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/api/types/network"
|
||||
"github.com/docker/docker/api/types/strslice"
|
||||
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
|
||||
"github.com/yyhuni/lunafox/agent/internal/config"
|
||||
"github.com/yyhuni/lunafox/agent/internal/domain"
|
||||
"github.com/yyhuni/lunafox/agent/internal/logger"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Updater handles agent self-update.
|
||||
type Updater struct {
|
||||
docker dockerClient
|
||||
health healthSetter
|
||||
puller pullerController
|
||||
executor executorController
|
||||
cfg configSnapshot
|
||||
apiKey string
|
||||
token string
|
||||
mu sync.Mutex
|
||||
updating bool
|
||||
randSrc *rand.Rand
|
||||
backoff time.Duration
|
||||
maxBackoff time.Duration
|
||||
}
|
||||
|
||||
type dockerClient interface {
|
||||
ImagePull(ctx context.Context, imageRef string) (io.ReadCloser, error)
|
||||
ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *ocispec.Platform, name string) (container.CreateResponse, error)
|
||||
ContainerStart(ctx context.Context, containerID string, opts container.StartOptions) error
|
||||
}
|
||||
|
||||
type healthSetter interface {
|
||||
Set(state, reason, message string)
|
||||
}
|
||||
|
||||
type pullerController interface {
|
||||
Pause()
|
||||
}
|
||||
|
||||
type executorController interface {
|
||||
Shutdown(ctx context.Context) error
|
||||
}
|
||||
|
||||
type configSnapshot interface {
|
||||
Snapshot() config.Config
|
||||
}
|
||||
|
||||
// NewUpdater creates a new updater.
|
||||
func NewUpdater(dockerClient dockerClient, healthManager healthSetter, puller pullerController, executor executorController, cfg configSnapshot, apiKey, token string) *Updater {
|
||||
return &Updater{
|
||||
docker: dockerClient,
|
||||
health: healthManager,
|
||||
puller: puller,
|
||||
executor: executor,
|
||||
cfg: cfg,
|
||||
apiKey: apiKey,
|
||||
token: token,
|
||||
randSrc: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
backoff: 30 * time.Second,
|
||||
maxBackoff: 10 * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleUpdateRequired triggers the update flow.
|
||||
func (u *Updater) HandleUpdateRequired(payload domain.UpdateRequiredPayload) {
|
||||
u.mu.Lock()
|
||||
if u.updating {
|
||||
u.mu.Unlock()
|
||||
return
|
||||
}
|
||||
u.updating = true
|
||||
u.mu.Unlock()
|
||||
|
||||
go u.run(payload)
|
||||
}
|
||||
|
||||
func (u *Updater) run(payload domain.UpdateRequiredPayload) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Log.Error("agent update panic", zap.Any("panic", r))
|
||||
u.health.Set("paused", "update_panic", fmt.Sprintf("%v", r))
|
||||
}
|
||||
u.mu.Lock()
|
||||
u.updating = false
|
||||
u.mu.Unlock()
|
||||
}()
|
||||
u.puller.Pause()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
_ = u.executor.Shutdown(ctx)
|
||||
cancel()
|
||||
|
||||
for {
|
||||
if err := u.updateOnce(payload); err == nil {
|
||||
u.health.Set("ok", "", "")
|
||||
os.Exit(0)
|
||||
} else {
|
||||
u.health.Set("paused", "update_failed", err.Error())
|
||||
}
|
||||
|
||||
delay := withJitter(u.backoff, u.randSrc)
|
||||
if u.backoff < u.maxBackoff {
|
||||
u.backoff *= 2
|
||||
if u.backoff > u.maxBackoff {
|
||||
u.backoff = u.maxBackoff
|
||||
}
|
||||
}
|
||||
time.Sleep(delay)
|
||||
}
|
||||
}
|

func (u *Updater) updateOnce(payload domain.UpdateRequiredPayload) error {
    if u.docker == nil {
        return fmt.Errorf("docker client unavailable")
    }
    image := strings.TrimSpace(payload.Image)
    version := strings.TrimSpace(payload.Version)
    if image == "" || version == "" {
        return fmt.Errorf("invalid update payload")
    }

    // Strict validation: reject invalid data from server
    if err := validateImageName(image); err != nil {
        logger.Log.Warn("invalid image name from server", zap.String("image", image), zap.Error(err))
        return fmt.Errorf("invalid image name from server: %w", err)
    }
    if err := validateVersion(version); err != nil {
        logger.Log.Warn("invalid version from server", zap.String("version", version), zap.Error(err))
        return fmt.Errorf("invalid version from server: %w", err)
    }

    fullImage := fmt.Sprintf("%s:%s", image, version)
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
    defer cancel()

    reader, err := u.docker.ImagePull(ctx, fullImage)
    if err != nil {
        return err
    }
    _, _ = io.Copy(io.Discard, reader)
    _ = reader.Close()

    if err := u.startNewContainer(ctx, image, version); err != nil {
        return err
    }

    return nil
}

func (u *Updater) startNewContainer(ctx context.Context, image, version string) error {
    env := []string{
        fmt.Sprintf("SERVER_URL=%s", u.cfg.Snapshot().ServerURL),
        fmt.Sprintf("API_KEY=%s", u.apiKey),
        fmt.Sprintf("MAX_TASKS=%d", u.cfg.Snapshot().MaxTasks),
        fmt.Sprintf("CPU_THRESHOLD=%d", u.cfg.Snapshot().CPUThreshold),
        fmt.Sprintf("MEM_THRESHOLD=%d", u.cfg.Snapshot().MemThreshold),
        fmt.Sprintf("DISK_THRESHOLD=%d", u.cfg.Snapshot().DiskThreshold),
        fmt.Sprintf("AGENT_VERSION=%s", version),
    }
    if u.token != "" {
        env = append(env, fmt.Sprintf("WORKER_TOKEN=%s", u.token))
    }

    cfg := &container.Config{
        Image: fmt.Sprintf("%s:%s", image, version),
        Env:   env,
        Cmd:   strslice.StrSlice{},
    }

    hostConfig := &container.HostConfig{
        Binds: []string{
            "/var/run/docker.sock:/var/run/docker.sock",
            "/opt/lunafox:/opt/lunafox",
        },
        RestartPolicy: container.RestartPolicy{Name: "unless-stopped"},
        OomScoreAdj:   -500,
    }

    // Version is already validated; just normalize to lowercase for the container name.
    name := fmt.Sprintf("lunafox-agent-%s", strings.ToLower(version))
    resp, err := u.docker.ContainerCreate(ctx, cfg, hostConfig, &network.NetworkingConfig{}, nil, name)
    if err != nil {
        return err
    }

    if err := u.docker.ContainerStart(ctx, resp.ID, container.StartOptions{}); err != nil {
        return err
    }

    logger.Log.Info("agent update started new container", zap.String("containerId", resp.ID))
    return nil
}

func withJitter(delay time.Duration, src *rand.Rand) time.Duration {
    if delay <= 0 || src == nil {
        return delay
    }
    jitter := src.Float64() * 0.2
    return delay + time.Duration(float64(delay)*jitter)
}
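withJitter stretches a delay by up to 20%, so the retry loop in run starts near 30s and roughly doubles toward the 10-minute cap. A small illustrative sketch of the resulting schedule, using this package's withJitter (exact values depend on the random source):

// Illustrative only: prints the first few retry delays with jitter applied.
rng := rand.New(rand.NewSource(42))
backoff, maxBackoff := 30*time.Second, 10*time.Minute
for i := 0; i < 6; i++ {
    fmt.Println(withJitter(backoff, rng)) // e.g. 30s-36s, then 60s-72s, ...
    if backoff < maxBackoff {
        backoff *= 2
        if backoff > maxBackoff {
            backoff = maxBackoff
        }
    }
}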

// validateImageName validates that the image name contains only safe characters.
// Returns an error if validation fails.
func validateImageName(image string) error {
    if len(image) == 0 {
        return fmt.Errorf("image name cannot be empty")
    }
    if len(image) > 255 {
        return fmt.Errorf("image name too long: %d characters", len(image))
    }

    // Allow: alphanumeric, dots, hyphens, underscores, slashes (for registry paths)
    for i, r := range image {
        if !((r >= 'a' && r <= 'z') ||
            (r >= 'A' && r <= 'Z') ||
            (r >= '0' && r <= '9') ||
            r == '.' || r == '-' || r == '_' || r == '/') {
            return fmt.Errorf("invalid character at position %d: %c", i, r)
        }
    }

    // Must not start or end with special characters
    first := rune(image[0])
    last := rune(image[len(image)-1])
    if first == '.' || first == '-' || first == '/' {
        return fmt.Errorf("image name cannot start with special character: %c", first)
    }
    if last == '.' || last == '-' || last == '/' {
        return fmt.Errorf("image name cannot end with special character: %c", last)
    }

    return nil
}

// validateVersion validates that the version string contains only safe characters.
// Returns an error if validation fails.
func validateVersion(version string) error {
    if len(version) == 0 {
        return fmt.Errorf("version cannot be empty")
    }
    if len(version) > 128 {
        return fmt.Errorf("version too long: %d characters", len(version))
    }

    // Allow: alphanumeric, dots, hyphens, underscores
    for i, r := range version {
        if !((r >= 'a' && r <= 'z') ||
            (r >= 'A' && r <= 'Z') ||
            (r >= '0' && r <= '9') ||
            r == '.' || r == '-' || r == '_') {
            return fmt.Errorf("invalid character at position %d: %c", i, r)
        }
    }

    // Must not start or end with special characters
    first := rune(version[0])
    last := rune(version[len(version)-1])
    if first == '.' || first == '-' || first == '_' {
        return fmt.Errorf("version cannot start with special character: %c", first)
    }
    if last == '.' || last == '-' || last == '_' {
        return fmt.Errorf("version cannot end with special character: %c", last)
    }

    return nil
}
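For a quick sense of what the validators accept, an illustrative check (inputs are made up):

// Illustrative inputs only.
fmt.Println(validateImageName("yyhuni/lunafox-agent")) // <nil>
fmt.Println(validateImageName("evil$image"))           // invalid character at position 4: $
fmt.Println(validateVersion("v1.2.3"))                 // <nil>
fmt.Println(validateVersion("-v1.2.3"))                // version cannot start with special character: -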
agent/internal/update/updater_test.go (new file, 45 lines)
@@ -0,0 +1,45 @@
package update

import (
    "math/rand"
    "strings"
    "testing"
    "time"

    "github.com/yyhuni/lunafox/agent/internal/domain"
)

// The updater lowercases an already-validated version for the container name,
// so invalid characters such as '+' must be rejected by validateVersion.
func TestValidateVersionRejectsInvalidChars(t *testing.T) {
    if err := validateVersion("v1.0.0+TEST"); err == nil {
        t.Fatalf("expected error for version with invalid character")
    }
}

func TestWithJitterRange(t *testing.T) {
    rng := rand.New(rand.NewSource(1))
    delay := 10 * time.Second
    got := withJitter(delay, rng)
    if got < delay {
        t.Fatalf("expected jitter >= delay")
    }
    if got > delay+(delay/5) {
        t.Fatalf("expected jitter <= 20%%")
    }
}

func TestUpdateOnceDockerUnavailable(t *testing.T) {
    updater := &Updater{}
    payload := domain.UpdateRequiredPayload{Version: "v1.0.0", Image: "yyhuni/lunafox-agent"}

    err := updater.updateOnce(payload)
    if err == nil {
        t.Fatalf("expected error when docker client is nil")
    }
    if !strings.Contains(err.Error(), "docker client unavailable") {
        t.Fatalf("unexpected error: %v", err)
    }
}
agent/internal/websocket/backoff.go (new file, 37 lines)
@@ -0,0 +1,37 @@
package websocket

import "time"

// Backoff implements exponential backoff with a maximum cap.
type Backoff struct {
    base    time.Duration
    max     time.Duration
    current time.Duration
}

// NewBackoff creates a backoff with the given base and max delay.
func NewBackoff(base, max time.Duration) Backoff {
    return Backoff{
        base: base,
        max:  max,
    }
}

// Next returns the next backoff duration.
func (b *Backoff) Next() time.Duration {
    if b.current <= 0 {
        b.current = b.base
        return b.current
    }
    next := b.current * 2
    if next > b.max {
        next = b.max
    }
    b.current = next
    return b.current
}

// Reset clears the backoff to start over.
func (b *Backoff) Reset() {
    b.current = 0
}
agent/internal/websocket/backoff_test.go (new file, 32 lines)
@@ -0,0 +1,32 @@
package websocket

import (
    "testing"
    "time"
)

func TestBackoffSequence(t *testing.T) {
    b := NewBackoff(time.Second, 60*time.Second)

    expected := []time.Duration{
        time.Second,
        2 * time.Second,
        4 * time.Second,
        8 * time.Second,
        16 * time.Second,
        32 * time.Second,
        60 * time.Second,
        60 * time.Second,
    }

    for i, exp := range expected {
        if got := b.Next(); got != exp {
            t.Fatalf("step %d: expected %v, got %v", i, exp, got)
        }
    }

    b.Reset()
    if got := b.Next(); got != time.Second {
        t.Fatalf("after reset expected %v, got %v", time.Second, got)
    }
}
agent/internal/websocket/client.go (new file, 177 lines)
@@ -0,0 +1,177 @@
package websocket

import (
    "context"
    "crypto/tls"
    "net/http"
    "time"

    "github.com/gorilla/websocket"
    "github.com/yyhuni/lunafox/agent/internal/logger"
    "go.uber.org/zap"
)

const (
    defaultPingInterval = 30 * time.Second
    defaultPongWait     = 60 * time.Second
    defaultWriteWait    = 10 * time.Second
)

// Client maintains a WebSocket connection to the server.
type Client struct {
    wsURL        string
    apiKey       string
    dialer       *websocket.Dialer
    send         chan []byte
    onMessage    func([]byte)
    backoff      Backoff
    pingInterval time.Duration
    pongWait     time.Duration
    writeWait    time.Duration
}

// NewClient creates a WebSocket client for the agent.
func NewClient(wsURL, apiKey string) *Client {
    dialer := *websocket.DefaultDialer
    // NOTE: TLS certificate verification is disabled for this dialer.
    dialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
    return &Client{
        wsURL:        wsURL,
        apiKey:       apiKey,
        dialer:       &dialer,
        send:         make(chan []byte, 256),
        backoff:      NewBackoff(1*time.Second, 60*time.Second),
        pingInterval: defaultPingInterval,
        pongWait:     defaultPongWait,
        writeWait:    defaultWriteWait,
    }
}

// SetOnMessage registers a callback for incoming messages.
func (c *Client) SetOnMessage(fn func([]byte)) {
    c.onMessage = fn
}

// Send queues a message for sending. It returns false if the buffer is full.
func (c *Client) Send(payload []byte) bool {
    select {
    case c.send <- payload:
        return true
    default:
        return false
    }
}

// Run keeps the connection alive with reconnect backoff and keepalive pings.
func (c *Client) Run(ctx context.Context) error {
    for {
        if ctx.Err() != nil {
            return ctx.Err()
        }

        logger.Log.Info("websocket connect attempt", zap.String("url", c.wsURL))
        conn, err := c.connect(ctx)
        if err != nil {
            logger.Log.Warn("websocket connect failed", zap.Error(err))
            if !sleepWithContext(ctx, c.backoff.Next()) {
                return ctx.Err()
            }
            continue
        }

        c.backoff.Reset()
        logger.Log.Info("websocket connected")
        err = c.runConn(ctx, conn)
        if err != nil && ctx.Err() == nil {
            logger.Log.Warn("websocket connection closed", zap.Error(err))
        }
        if ctx.Err() != nil {
            return ctx.Err()
        }
        if !sleepWithContext(ctx, c.backoff.Next()) {
            return ctx.Err()
        }
    }
}

func (c *Client) connect(ctx context.Context) (*websocket.Conn, error) {
    header := http.Header{}
    if c.apiKey != "" {
        header.Set("X-Agent-Key", c.apiKey)
    }
    conn, _, err := c.dialer.DialContext(ctx, c.wsURL, header)
    return conn, err
}

func (c *Client) runConn(ctx context.Context, conn *websocket.Conn) error {
    defer conn.Close()

    conn.SetReadDeadline(time.Now().Add(c.pongWait))
    conn.SetPongHandler(func(string) error {
        conn.SetReadDeadline(time.Now().Add(c.pongWait))
        return nil
    })

    errCh := make(chan error, 2)
    go c.readLoop(conn, errCh)
    go c.writeLoop(ctx, conn, errCh)

    select {
    case <-ctx.Done():
        return ctx.Err()
    case err := <-errCh:
        return err
    }
}

func (c *Client) readLoop(conn *websocket.Conn, errCh chan<- error) {
    for {
        _, message, err := conn.ReadMessage()
        if err != nil {
            errCh <- err
            return
        }
        if c.onMessage != nil {
            c.onMessage(message)
        }
    }
}

func (c *Client) writeLoop(ctx context.Context, conn *websocket.Conn, errCh chan<- error) {
    ticker := time.NewTicker(c.pingInterval)
    defer ticker.Stop()

    for {
        select {
        case <-ctx.Done():
            errCh <- ctx.Err()
            return
        case payload := <-c.send:
            if err := c.writeMessage(conn, websocket.TextMessage, payload); err != nil {
                errCh <- err
                return
            }
        case <-ticker.C:
            if err := c.writeMessage(conn, websocket.PingMessage, nil); err != nil {
                errCh <- err
                return
            }
        }
    }
}

func (c *Client) writeMessage(conn *websocket.Conn, msgType int, payload []byte) error {
    _ = conn.SetWriteDeadline(time.Now().Add(c.writeWait))
    return conn.WriteMessage(msgType, payload)
}

func sleepWithContext(ctx context.Context, delay time.Duration) bool {
    timer := time.NewTimer(delay)
    defer timer.Stop()

    select {
    case <-ctx.Done():
        return false
    case <-timer.C:
        return true
    }
}
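A minimal sketch of how Client composes with the Handler defined below: register the handler as the message callback, then let Run own reconnects (the URL and key values are placeholders):

// Hypothetical wiring; wsURL and apiKey values are placeholders.
client := NewClient("wss://server.example/ws/agent", "api-key")
handler := NewHandler()
handler.OnTaskAvailable(func() { /* pull work from the server */ })
client.SetOnMessage(handler.Handle)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() { _ = client.Run(ctx) }()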
agent/internal/websocket/client_test.go (new file, 32 lines)
@@ -0,0 +1,32 @@
package websocket

import (
    "context"
    "testing"
    "time"
)

func TestClientSendBufferFull(t *testing.T) {
    client := &Client{send: make(chan []byte, 1)}
    if !client.Send([]byte("first")) {
        t.Fatalf("expected first send to succeed")
    }
    if client.Send([]byte("second")) {
        t.Fatalf("expected second send to fail when buffer is full")
    }
}

func TestSleepWithContextCancelled(t *testing.T) {
    ctx, cancel := context.WithCancel(context.Background())
    cancel()

    if sleepWithContext(ctx, 50*time.Millisecond) {
        t.Fatalf("expected sleepWithContext to return false when canceled")
    }
}

func TestSleepWithContextElapsed(t *testing.T) {
    if !sleepWithContext(context.Background(), 5*time.Millisecond) {
        t.Fatalf("expected sleepWithContext to return true after delay")
    }
}
agent/internal/websocket/handlers.go (new file, 90 lines)
@@ -0,0 +1,90 @@
package websocket

import (
    "encoding/json"

    "github.com/yyhuni/lunafox/agent/internal/protocol"
)

// Handler routes incoming WebSocket messages.
type Handler struct {
    onTaskAvailable func()
    onTaskCancel    func(int)
    onConfigUpdate  func(protocol.ConfigUpdatePayload)
    onUpdateReq     func(protocol.UpdateRequiredPayload)
}

// NewHandler creates a message handler.
func NewHandler() *Handler {
    return &Handler{}
}

// OnTaskAvailable registers a callback for task_available messages.
func (h *Handler) OnTaskAvailable(fn func()) {
    h.onTaskAvailable = fn
}

// OnTaskCancel registers a callback for task_cancel messages.
func (h *Handler) OnTaskCancel(fn func(int)) {
    h.onTaskCancel = fn
}

// OnConfigUpdate registers a callback for config_update messages.
func (h *Handler) OnConfigUpdate(fn func(protocol.ConfigUpdatePayload)) {
    h.onConfigUpdate = fn
}

// OnUpdateRequired registers a callback for update_required messages.
func (h *Handler) OnUpdateRequired(fn func(protocol.UpdateRequiredPayload)) {
    h.onUpdateReq = fn
}

// Handle processes a raw message.
func (h *Handler) Handle(raw []byte) {
    var msg struct {
        Type string          `json:"type"`
        Data json.RawMessage `json:"payload"`
    }
    if err := json.Unmarshal(raw, &msg); err != nil {
        return
    }

    switch msg.Type {
    case protocol.MessageTypeTaskAvailable:
        if h.onTaskAvailable != nil {
            h.onTaskAvailable()
        }
    case protocol.MessageTypeTaskCancel:
        if h.onTaskCancel == nil {
            return
        }
        var payload protocol.TaskCancelPayload
        if err := json.Unmarshal(msg.Data, &payload); err != nil {
            return
        }
        if payload.TaskID > 0 {
            h.onTaskCancel(payload.TaskID)
        }
    case protocol.MessageTypeConfigUpdate:
        if h.onConfigUpdate == nil {
            return
        }
        var payload protocol.ConfigUpdatePayload
        if err := json.Unmarshal(msg.Data, &payload); err != nil {
            return
        }
        h.onConfigUpdate(payload)
    case protocol.MessageTypeUpdateRequired:
        if h.onUpdateReq == nil {
            return
        }
        var payload protocol.UpdateRequiredPayload
        if err := json.Unmarshal(msg.Data, &payload); err != nil {
            return
        }
        if payload.Version == "" || payload.Image == "" {
            return
        }
        h.onUpdateReq(payload)
    }
}
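The handler expects a JSON envelope with a type field and a payload object; the taskId key matches the tests below. A sketch of producing one such message (illustrative only, assuming the protocol message-type constants are plain strings):

// Illustrative: builds the envelope Handle expects for a task_cancel message.
raw, _ := json.Marshal(map[string]any{
    "type":    protocol.MessageTypeTaskCancel,
    "payload": map[string]any{"taskId": 123},
})
handler.Handle(raw) // reaches the OnTaskCancel callback with task ID 123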
agent/internal/websocket/handlers_test.go (new file, 85 lines)
@@ -0,0 +1,85 @@
package websocket

import (
    "fmt"
    "testing"

    "github.com/yyhuni/lunafox/agent/internal/protocol"
)

func TestHandlersTaskAvailable(t *testing.T) {
    h := NewHandler()
    called := 0
    h.OnTaskAvailable(func() { called++ })

    message := fmt.Sprintf(`{"type":"%s","payload":{},"timestamp":"2026-01-01T00:00:00Z"}`, protocol.MessageTypeTaskAvailable)
    h.Handle([]byte(message))
    if called != 1 {
        t.Fatalf("expected callback to be called")
    }
}

func TestHandlersTaskCancel(t *testing.T) {
    h := NewHandler()
    var got int
    h.OnTaskCancel(func(id int) { got = id })

    message := fmt.Sprintf(`{"type":"%s","payload":{"taskId":123},"timestamp":"2026-01-01T00:00:00Z"}`, protocol.MessageTypeTaskCancel)
    h.Handle([]byte(message))
    if got != 123 {
        t.Fatalf("expected taskId 123")
    }
}

func TestHandlersConfigUpdate(t *testing.T) {
    h := NewHandler()
    var maxTasks int
    h.OnConfigUpdate(func(payload protocol.ConfigUpdatePayload) {
        if payload.MaxTasks != nil {
            maxTasks = *payload.MaxTasks
        }
    })

    message := fmt.Sprintf(`{"type":"%s","payload":{"maxTasks":8},"timestamp":"2026-01-01T00:00:00Z"}`, protocol.MessageTypeConfigUpdate)
    h.Handle([]byte(message))
    if maxTasks != 8 {
        t.Fatalf("expected maxTasks 8")
    }
}

func TestHandlersUpdateRequired(t *testing.T) {
    h := NewHandler()
    var version string
    h.OnUpdateRequired(func(payload protocol.UpdateRequiredPayload) { version = payload.Version })

    message := fmt.Sprintf(`{"type":"%s","payload":{"version":"v1.0.1","image":"yyhuni/lunafox-agent"},"timestamp":"2026-01-01T00:00:00Z"}`, protocol.MessageTypeUpdateRequired)
    h.Handle([]byte(message))
    if version != "v1.0.1" {
        t.Fatalf("expected version")
    }
}

func TestHandlersIgnoreInvalidJSON(t *testing.T) {
    h := NewHandler()
    called := 0
    h.OnTaskAvailable(func() { called++ })

    h.Handle([]byte("{bad json"))
    if called != 0 {
        t.Fatalf("expected no callbacks on invalid json")
    }
}

func TestHandlersUpdateRequiredMissingFields(t *testing.T) {
    h := NewHandler()
    called := 0
    h.OnUpdateRequired(func(payload protocol.UpdateRequiredPayload) { called++ })

    message := fmt.Sprintf(`{"type":"%s","payload":{"version":"","image":"yyhuni/lunafox-agent"}}`, protocol.MessageTypeUpdateRequired)
    h.Handle([]byte(message))
    message = fmt.Sprintf(`{"type":"%s","payload":{"version":"v1.2.3","image":""}}`, protocol.MessageTypeUpdateRequired)
    h.Handle([]byte(message))
    if called != 0 {
        t.Fatalf("expected no callbacks for invalid payload")
    }
}
agent/internal/websocket/heartbeat.go (new file, 97 lines)
@@ -0,0 +1,97 @@
package websocket

import (
    "context"
    "encoding/json"
    "time"

    "github.com/yyhuni/lunafox/agent/internal/health"
    "github.com/yyhuni/lunafox/agent/internal/logger"
    "github.com/yyhuni/lunafox/agent/internal/metrics"
    "github.com/yyhuni/lunafox/agent/internal/protocol"
    "go.uber.org/zap"
)

// HeartbeatSender sends periodic heartbeat messages over WebSocket.
type HeartbeatSender struct {
    client     *Client
    collector  *metrics.Collector
    health     *health.Manager
    version    string
    hostname   string
    startedAt  time.Time
    taskCount  func() int
    interval   time.Duration
    lastSentAt time.Time
}

// NewHeartbeatSender creates a heartbeat sender.
func NewHeartbeatSender(client *Client, collector *metrics.Collector, healthManager *health.Manager, version, hostname string, taskCount func() int) *HeartbeatSender {
    return &HeartbeatSender{
        client:    client,
        collector: collector,
        health:    healthManager,
        version:   version,
        hostname:  hostname,
        startedAt: time.Now(),
        taskCount: taskCount,
        interval:  5 * time.Second,
    }
}

// Start begins sending heartbeats until the context is canceled.
func (h *HeartbeatSender) Start(ctx context.Context) {
    ticker := time.NewTicker(h.interval)
    defer ticker.Stop()

    h.sendOnce()
    for {
        select {
        case <-ctx.Done():
            return
        case <-ticker.C:
            h.sendOnce()
        }
    }
}

func (h *HeartbeatSender) sendOnce() {
    cpu, mem, disk := h.collector.Sample()
    uptime := int64(time.Since(h.startedAt).Seconds())
    tasks := 0
    if h.taskCount != nil {
        tasks = h.taskCount()
    }

    status := h.health.Get()
    payload := protocol.HeartbeatPayload{
        CPU:      cpu,
        Mem:      mem,
        Disk:     disk,
        Tasks:    tasks,
        Version:  h.version,
        Hostname: h.hostname,
        Uptime:   uptime,
        Health: protocol.HealthStatus{
            State:   status.State,
            Reason:  status.Reason,
            Message: status.Message,
            Since:   status.Since,
        },
    }

    msg := protocol.Message{
        Type:      protocol.MessageTypeHeartbeat,
        Payload:   payload,
        Timestamp: time.Now().UTC(),
    }

    data, err := json.Marshal(msg)
    if err != nil {
        logger.Log.Warn("failed to marshal heartbeat message", zap.Error(err))
        return
    }
    if !h.client.Send(data) {
        logger.Log.Warn("failed to send heartbeat: client not connected")
    }
}
agent/internal/websocket/heartbeat_test.go (new file, 57 lines)
@@ -0,0 +1,57 @@
package websocket

import (
    "encoding/json"
    "testing"
    "time"

    "github.com/yyhuni/lunafox/agent/internal/health"
    "github.com/yyhuni/lunafox/agent/internal/metrics"
    "github.com/yyhuni/lunafox/agent/internal/protocol"
)

func TestHeartbeatSenderSendOnce(t *testing.T) {
    client := &Client{send: make(chan []byte, 1)}
    collector := metrics.NewCollector()
    healthManager := health.NewManager()
    healthManager.Set("paused", "maintenance", "waiting")

    sender := NewHeartbeatSender(client, collector, healthManager, "v1.0.0", "agent-host", func() int { return 3 })
    sender.sendOnce()

    select {
    case payload := <-client.send:
        var msg struct {
            Type      string                 `json:"type"`
            Payload   map[string]interface{} `json:"payload"`
            Timestamp time.Time              `json:"timestamp"`
        }
        if err := json.Unmarshal(payload, &msg); err != nil {
            t.Fatalf("unmarshal heartbeat: %v", err)
        }
        if msg.Type != protocol.MessageTypeHeartbeat {
            t.Fatalf("expected heartbeat type, got %s", msg.Type)
        }
        if msg.Timestamp.IsZero() {
            t.Fatalf("expected timestamp")
        }
        if msg.Payload["version"] != "v1.0.0" {
            t.Fatalf("expected version in payload")
        }
        if msg.Payload["hostname"] != "agent-host" {
            t.Fatalf("expected hostname in payload")
        }
        if tasks, ok := msg.Payload["tasks"].(float64); !ok || int(tasks) != 3 {
            t.Fatalf("expected tasks=3")
        }
        healthPayload, ok := msg.Payload["health"].(map[string]interface{})
        if !ok {
            t.Fatalf("expected health payload")
        }
        if healthPayload["state"] != "paused" {
            t.Fatalf("expected health state paused")
        }
    default:
        t.Fatalf("expected heartbeat message")
    }
}
agent/test/integration/task_test.go (new file, 13 lines)
@@ -0,0 +1,13 @@
package integration

import (
    "os"
    "testing"
)

func TestTaskExecutionFlow(t *testing.T) {
    if os.Getenv("AGENT_INTEGRATION") == "" {
        t.Skip("set AGENT_INTEGRATION=1 to run integration tests")
    }
    // TODO: wire up real server + docker environment for end-to-end validation.
}
@@ -1,69 +0,0 @@
"""
System log service module.

Provides read access to the system log, supporting:
- reading the log file from the log directory
- capping the number of returned lines to avoid excessive memory use
"""

import logging
import subprocess


logger = logging.getLogger(__name__)


class SystemLogService:
    """
    System log service.

    Reads the system log file, either from a path inside the container or from
    a host-mounted path.
    """

    def __init__(self):
        # Log file path (unified under /opt/xingrin/logs)
        self.log_file = "/opt/xingrin/logs/xingrin.log"
        self.default_lines = 200  # default number of lines returned
        self.max_lines = 10000  # upper bound on returned lines
        self.timeout_seconds = 3  # timeout for the tail command

    def get_logs_content(self, lines: int | None = None) -> str:
        """
        Return the system log content.

        Args:
            lines: Number of log lines to return; defaults to 200, capped at 10000.

        Returns:
            str: Log content, newline-separated, in original file order.
        """
        # Validate the argument and apply defaults
        if lines is None:
            lines = self.default_lines

        lines = int(lines)
        if lines < 1:
            lines = 1
        if lines > self.max_lines:
            lines = self.max_lines

        # Read the tail of the log file with the tail command
        cmd = ["tail", "-n", str(lines), self.log_file]

        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=self.timeout_seconds,
            check=False,
        )

        if result.returncode != 0:
            logger.warning(
                "tail command failed: returncode=%s stderr=%s",
                result.returncode,
                (result.stderr or "").strip(),
            )

        # Return the raw content, preserving file order
        return result.stdout or ""
@@ -1,21 +0,0 @@
"""
URL configuration for the common module.

Routes:
- /api/auth/*   authentication endpoints (login, logout, user info)
- /api/system/* system administration endpoints (log viewing, etc.)
"""

from django.urls import path
from .views import LoginView, LogoutView, MeView, ChangePasswordView, SystemLogsView

urlpatterns = [
    # Authentication
    path('auth/login/', LoginView.as_view(), name='auth-login'),
    path('auth/logout/', LogoutView.as_view(), name='auth-logout'),
    path('auth/me/', MeView.as_view(), name='auth-me'),
    path('auth/change-password/', ChangePasswordView.as_view(), name='auth-change-password'),

    # System administration
    path('system/logs/', SystemLogsView.as_view(), name='system-logs'),
]
@@ -1,116 +0,0 @@
"""CSV export utilities.

Provides streaming CSV generation, supporting:
- UTF-8 BOM (Excel compatibility)
- RFC 4180 escaping
- streaming generation (memory friendly)
"""

import csv
import io
from datetime import datetime
from typing import Iterator, Dict, Any, List, Callable, Optional

# UTF-8 BOM so Excel detects the encoding correctly
UTF8_BOM = '\ufeff'


def generate_csv_rows(
    data_iterator: Iterator[Dict[str, Any]],
    headers: List[str],
    field_formatters: Optional[Dict[str, Callable]] = None
) -> Iterator[str]:
    """
    Generate CSV rows as a stream.

    Args:
        data_iterator: Iterator of rows, each a dict
        headers: List of CSV header names
        field_formatters: Dict of formatter functions keyed by field name

    Yields:
        CSV row strings (including the line terminator)

    Example:
        >>> data = [{'ip': '192.168.1.1', 'hosts': ['a.com', 'b.com']}]
        >>> headers = ['ip', 'hosts']
        >>> formatters = {'hosts': format_list_field}
        >>> for row in generate_csv_rows(iter(data), headers, formatters):
        ...     print(row, end='')
    """
    # Emit BOM + header row
    output = io.StringIO()
    writer = csv.writer(output, quoting=csv.QUOTE_MINIMAL)
    writer.writerow(headers)
    yield UTF8_BOM + output.getvalue()

    # Emit data rows
    for row_data in data_iterator:
        output = io.StringIO()
        writer = csv.writer(output, quoting=csv.QUOTE_MINIMAL)

        row = []
        for header in headers:
            value = row_data.get(header, '')
            if field_formatters and header in field_formatters:
                value = field_formatters[header](value)
            row.append(value if value is not None else '')

        writer.writerow(row)
        yield output.getvalue()


def format_list_field(values: List, separator: str = ';') -> str:
    """
    Format a list field as a separator-joined string.

    Args:
        values: List of values
        separator: Separator, semicolon by default

    Returns:
        Separator-joined string

    Example:
        >>> format_list_field(['a.com', 'b.com'])
        'a.com;b.com'
        >>> format_list_field([80, 443])
        '80;443'
        >>> format_list_field([])
        ''
        >>> format_list_field(None)
        ''
    """
    if not values:
        return ''
    return separator.join(str(v) for v in values)


def format_datetime(dt: Optional[datetime]) -> str:
    """
    Format a datetime as a string (converted to the local timezone).

    Args:
        dt: datetime object or None

    Returns:
        Formatted datetime string, YYYY-MM-DD HH:MM:SS (local timezone)

    Example:
        >>> from datetime import datetime
        >>> format_datetime(datetime(2024, 1, 15, 10, 30, 0))
        '2024-01-15 10:30:00'
        >>> format_datetime(None)
        ''
    """
    if dt is None:
        return ''
    if isinstance(dt, str):
        return dt

    # Convert to the local timezone (from Django settings)
    from django.utils import timezone
    if timezone.is_aware(dt):
        dt = timezone.localtime(dt)

    return dt.strftime('%Y-%m-%d %H:%M:%S')
@@ -1,41 +0,0 @@
"""Xget proxy utilities for Git URL acceleration."""

import os
from urllib.parse import urlparse


def get_xget_proxy_url(original_url: str) -> str:
    """
    Convert a Git repository URL to Xget proxy format.

    Args:
        original_url: Original repository URL, e.g., https://github.com/user/repo.git

    Returns:
        Converted URL, e.g., https://xget.xi-xu.me/gh/https://github.com/user/repo.git
        If XGET_MIRROR is not set, returns the original URL unchanged.
    """
    xget_mirror = os.getenv("XGET_MIRROR", "").strip()
    if not xget_mirror:
        return original_url

    # Remove trailing slash from mirror URL if present
    xget_mirror = xget_mirror.rstrip("/")

    parsed = urlparse(original_url)
    host = parsed.netloc.lower()

    # Map domains to proxy prefixes
    prefix_map = {
        "github.com": "gh",
        "gitlab.com": "gl",
        "gitea.com": "gitea",
        "codeberg.org": "codeberg",
    }

    for domain, prefix in prefix_map.items():
        if domain in host:
            return f"{xget_mirror}/{prefix}/{original_url}"

    # Unknown domain, return original URL
    return original_url
@@ -1,12 +0,0 @@
"""
View exports for the common module.

Includes:
- authentication views: login, logout, user info, change password
- system log view: live log viewing
"""

from .auth_views import LoginView, LogoutView, MeView, ChangePasswordView
from .system_log_views import SystemLogsView

__all__ = ['LoginView', 'LogoutView', 'MeView', 'ChangePasswordView', 'SystemLogsView']
@@ -1,69 +0,0 @@
"""
System log views module.

REST API for system logs, letting the frontend view runtime logs in real time.
"""

import logging

from django.utils.decorators import method_decorator
from django.views.decorators.csrf import csrf_exempt
from rest_framework import status
from rest_framework.permissions import AllowAny
from rest_framework.response import Response
from rest_framework.views import APIView

from apps.common.services.system_log_service import SystemLogService


logger = logging.getLogger(__name__)


@method_decorator(csrf_exempt, name="dispatch")
class SystemLogsView(APIView):
    """
    System log API view.

    GET /api/system/logs/
        Fetch system log content.

    Query Parameters:
        lines (int, optional): Number of log lines to return, default 200, max 10000

    Response:
        {
            "content": "log content string..."
        }

    Note:
        - Anonymous access is temporarily allowed during development
        - Production should enforce admin permissions
    """

    # TODO: switch to IsAdminUser permission in production
    authentication_classes = []
    permission_classes = [AllowAny]

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = SystemLogService()

    def get(self, request):
        """
        Fetch the system log.

        The lines parameter controls the number of returned lines, for
        frontend pagination or live-refresh scenarios.
        """
        try:
            # Parse the lines parameter
            lines_raw = request.query_params.get("lines")
            lines = int(lines_raw) if lines_raw is not None else None

            # Fetch log content from the service
            content = self.service.get_logs_content(lines=lines)
            return Response({"content": content})
        except ValueError:
            return Response({"error": "lines parameter must be an integer"}, status=status.HTTP_400_BAD_REQUEST)
        except Exception:
            logger.exception("Failed to fetch system logs")
            return Response({"error": "Failed to fetch system logs"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@@ -1,684 +0,0 @@
"""
Directory scan flow.

Orchestrates the complete directory scan process.

Architecture:
- The flow orchestrates multiple atomic tasks
- Scan tools can run concurrently (using ThreadPoolTaskRunner)
- Each task can be retried independently
- Configuration is parsed from YAML
"""

# Django environment initialization (takes effect on import)
from apps.common.prefect_django_setup import setup_django_for_prefect

from prefect import flow
from prefect.task_runners import ThreadPoolTaskRunner

import hashlib
import logging
import os
import subprocess
from datetime import datetime
from pathlib import Path
from typing import List, Tuple

from apps.scan.tasks.directory_scan import (
    export_sites_task,
    run_and_stream_save_directories_task
)
from apps.scan.handlers.scan_flow_handlers import (
    on_scan_flow_running,
    on_scan_flow_completed,
    on_scan_flow_failed,
)
from apps.scan.utils import config_parser, build_scan_command, ensure_wordlist_local

logger = logging.getLogger(__name__)

# Default maximum concurrency
DEFAULT_MAX_WORKERS = 5


def calculate_directory_scan_timeout(
    tool_config: dict,
    base_per_word: float = 1.0,
    min_timeout: int = 60,
    max_timeout: int = 7200
) -> int:
    """
    Compute the directory scan timeout from the wordlist line count.

    Formula: timeout = wordlist line count x base time per word.
    A lower bound of min_timeout is enforced; max_timeout is kept in the
    signature for compatibility, but no upper bound is applied.

    Args:
        tool_config: Tool configuration dict containing the wordlist path
        base_per_word: Base time per word in seconds, default 1.0
        min_timeout: Minimum timeout in seconds, default 60
        max_timeout: Nominal maximum timeout in seconds, default 7200 (not enforced)

    Returns:
        int: Computed timeout in seconds, at least min_timeout

    Example:
        # 1000-line wordlist x 1.0s = 1000s
        # 10000-line wordlist x 1.0s = 10000s
        timeout = calculate_directory_scan_timeout(
            tool_config={'wordlist': '/path/to/wordlist.txt'}
        )
    """
    try:
        # Get the wordlist path from tool_config
        wordlist_path = tool_config.get('wordlist')
        if not wordlist_path:
            logger.warning("No wordlist specified in tool config, using default timeout: %ds", min_timeout)
            return min_timeout

        # Expand the user directory (~)
        wordlist_path = os.path.expanduser(wordlist_path)

        # Check that the file exists
        if not os.path.exists(wordlist_path):
            logger.warning("Wordlist file does not exist: %s, using default timeout: %ds", wordlist_path, min_timeout)
            return min_timeout

        # Count wordlist lines quickly with wc -l
        result = subprocess.run(
            ['wc', '-l', wordlist_path],
            capture_output=True,
            text=True,
            check=True
        )
        # wc -l output format: line count + space + file name
        line_count = int(result.stdout.strip().split()[0])

        # Compute the timeout
        timeout = int(line_count * base_per_word)

        # Apply a reasonable lower bound (no upper bound is applied anymore)
        timeout = max(min_timeout, timeout)

        logger.info(
            "Directory scan timeout - wordlist: %s, lines: %d, base time: %.3fs/word, computed timeout: %ds",
            wordlist_path, line_count, base_per_word, timeout
        )

        return timeout

    except subprocess.CalledProcessError as e:
        logger.error("Failed to count wordlist lines: %s", e)
        # Fall back to the default timeout on failure
        return min_timeout
    except (ValueError, IndexError) as e:
        logger.error("Failed to parse wordlist line count: %s", e)
        return min_timeout
    except Exception as e:
        logger.error("Unexpected error while computing timeout: %s", e)
        return min_timeout


def _get_max_workers(tool_config: dict, default: int = DEFAULT_MAX_WORKERS) -> int:
    """
    Read the max_workers parameter from a single tool's configuration.

    Args:
        tool_config: Configuration dict for one tool, e.g. {'max_workers': 10, 'threads': 5, ...}
        default: Default value, 5

    Returns:
        int: max_workers value
    """
    if not isinstance(tool_config, dict):
        return default

    # Accept both max_workers and max-workers (hyphens may be converted in YAML)
    max_workers = tool_config.get('max_workers') or tool_config.get('max-workers')
    if max_workers is not None and isinstance(max_workers, int) and max_workers > 0:
        return max_workers
    return default


def _export_site_urls(target_id: int, target_name: str, directory_scan_dir: Path) -> tuple[str, int]:
    """
    Export all site URLs under the target to a file (supports lazy loading).

    Args:
        target_id: Target ID
        target_name: Target name (used to lazily create a default site)
        directory_scan_dir: Directory scan working directory

    Returns:
        tuple: (sites_file, site_count); a zero site count is returned rather
        than raised, so the caller decides how to handle it
    """
    logger.info("Step 1: export all site URLs for the target")

    sites_file = str(directory_scan_dir / 'sites.txt')
    export_result = export_sites_task(
        target_id=target_id,
        output_file=sites_file,
        batch_size=1000  # read 1000 rows at a time to limit memory use
    )

    site_count = export_result['total_count']

    logger.info(
        "✓ Site URL export complete - file: %s, count: %d",
        export_result['output_file'],
        site_count
    )

    if site_count == 0:
        logger.warning("Target has no sites; directory scan cannot run")
        # Do not raise; let the caller decide how to handle it
        # raise ValueError("Target has no sites; directory scan cannot run")

    return export_result['output_file'], site_count


def _run_scans_sequentially(
    enabled_tools: dict,
    sites_file: str,
    directory_scan_dir: Path,
    scan_id: int,
    target_id: int,
    site_count: int,
    target_name: str
) -> tuple[int, int, list]:
    """
    Run directory scan tasks sequentially (multi-tool) - deprecated, kept for compatibility.

    Args:
        enabled_tools: Dict of enabled tool configurations
        sites_file: Path to the sites file
        directory_scan_dir: Directory scan working directory
        scan_id: Scan task ID
        target_id: Target ID
        site_count: Number of sites
        target_name: Target name (for error logging)

    Returns:
        tuple: (total_directories, processed_sites, failed_sites)
    """
    # Read the site list
    sites = []
    with open(sites_file, 'r', encoding='utf-8') as f:
        for line in f:
            site_url = line.strip()
            if site_url:
                sites.append(site_url)

    logger.info("Preparing to scan %d sites with tools: %s", len(sites), ', '.join(enabled_tools.keys()))

    total_directories = 0
    processed_sites_set = set()  # use a set to avoid double counting
    failed_sites = []

    # Iterate over each tool
    for tool_name, tool_config in enabled_tools.items():
        logger.info("="*60)
        logger.info("Using tool: %s", tool_name)
        logger.info("="*60)

        # If wordlist_name is configured, make sure the wordlist exists locally first (with hash check)
        wordlist_name = tool_config.get('wordlist_name')
        if wordlist_name:
            try:
                local_wordlist_path = ensure_wordlist_local(wordlist_name)
                tool_config['wordlist'] = local_wordlist_path
            except Exception as exc:
                logger.error("Failed to prepare wordlist for tool %s: %s", tool_name, exc)
                # This tool cannot run; mark all sites as failed and move on to the next tool
                failed_sites.extend(sites)
                continue

        # Scan each site in turn
        for idx, site_url in enumerate(sites, 1):
            logger.info(
                "[%d/%d] Starting site scan: %s (tool: %s)",
                idx, len(sites), site_url, tool_name
            )

            # Use the unified command builder
            try:
                command = build_scan_command(
                    tool_name=tool_name,
                    scan_type='directory_scan',
                    command_params={
                        'url': site_url
                    },
                    tool_config=tool_config
                )
            except Exception as e:
                logger.error(
                    "✗ [%d/%d] Failed to build %s command: %s - site: %s",
                    idx, len(sites), tool_name, e, site_url
                )
                failed_sites.append(site_url)
                continue

            # Per-site timeout: read from config (supports 'auto' for dynamic calculation)
            # ffuf scans sites one by one, so timeout is the per-site timeout
            site_timeout = tool_config.get('timeout', 300)
            if site_timeout == 'auto':
                # Compute the timeout dynamically (based on wordlist line count)
                site_timeout = calculate_directory_scan_timeout(tool_config)
                logger.info(f"✓ Tool {tool_name} computed timeout dynamically: {site_timeout}s")

            # Build the log file path
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            log_file = directory_scan_dir / f"{tool_name}_{timestamp}_{idx}.log"

            try:
                # Call the task directly (sequential execution)
                result = run_and_stream_save_directories_task(
                    cmd=command,
                    tool_name=tool_name,  # tool name
                    scan_id=scan_id,
                    target_id=target_id,
                    site_url=site_url,
                    cwd=str(directory_scan_dir),
                    shell=True,
                    batch_size=1000,
                    timeout=site_timeout,
                    log_file=str(log_file)  # log file path
                )

                total_directories += result.get('created_directories', 0)
                processed_sites_set.add(site_url)  # record successful sites in the set

                logger.info(
                    "✓ [%d/%d] Site scan complete: %s - found %d directories",
                    idx, len(sites), site_url,
                    result.get('created_directories', 0)
                )

            except subprocess.TimeoutExpired as exc:
                # Handle timeouts separately
                failed_sites.append(site_url)
                logger.warning(
                    "⚠️ [%d/%d] Site scan timed out: %s - configured timeout: %ds\n"
                    "Note: directories parsed before the timeout were saved to the database, but the scan did not finish.",
                    idx, len(sites), site_url, site_timeout
                )
            except Exception as exc:
                # Other exceptions
                failed_sites.append(site_url)
                logger.error(
                    "✗ [%d/%d] Site scan failed: %s - error: %s",
                    idx, len(sites), site_url, exc
                )

            # Log progress every 10 sites
            if idx % 10 == 0:
                logger.info(
                    "Progress: %d/%d (%.1f%%) - %d directories found so far",
                    idx, len(sites), idx/len(sites)*100, total_directories
                )

    # Count successful and failed sites
    processed_count = len(processed_sites_set)

    if failed_sites:
        logger.warning(
            "Some site scans failed: %d/%d",
            len(failed_sites), len(sites)
        )

    logger.info(
        "✓ Sequential directory scan complete - succeeded: %d/%d, failed: %d, total directories: %d",
        processed_count, len(sites), len(failed_sites), total_directories
    )

    return total_directories, processed_count, failed_sites


def _generate_log_filename(tool_name: str, site_url: str, directory_scan_dir: Path) -> Path:
    """
    Generate a unique log file name.

    Uses a hash of the URL so concurrent scans do not collide.

    Args:
        tool_name: Tool name
        site_url: Site URL
        directory_scan_dir: Directory scan working directory

    Returns:
        Path: Log file path
    """
    url_hash = hashlib.md5(site_url.encode()).hexdigest()[:8]
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
    return directory_scan_dir / f"{tool_name}_{url_hash}_{timestamp}.log"


def _run_scans_concurrently(
    enabled_tools: dict,
    sites_file: str,
    directory_scan_dir: Path,
    scan_id: int,
    target_id: int,
    site_count: int,
    target_name: str
) -> Tuple[int, int, List[str]]:
    """
    Run directory scan tasks concurrently (using ThreadPoolTaskRunner).

    Args:
        enabled_tools: Dict of enabled tool configurations
        sites_file: Path to the sites file
        directory_scan_dir: Directory scan working directory
        scan_id: Scan task ID
        target_id: Target ID
        site_count: Number of sites
        target_name: Target name (for error logging)

    Returns:
        tuple: (total_directories, processed_sites, failed_sites)
    """
    # Read the site list
    sites: List[str] = []
    with open(sites_file, 'r', encoding='utf-8') as f:
        for line in f:
            site_url = line.strip()
            if site_url:
                sites.append(site_url)

    if not sites:
        logger.warning("Site list is empty")
        return 0, 0, []

    logger.info(
        "Preparing to scan %d sites concurrently with tools: %s",
        len(sites), ', '.join(enabled_tools.keys())
    )

    total_directories = 0
    processed_sites_count = 0
    failed_sites: List[str] = []

    # Iterate over each tool
    for tool_name, tool_config in enabled_tools.items():
        # Each tool reads its own max_workers setting
        max_workers = _get_max_workers(tool_config)

        logger.info("="*60)
        logger.info("Using tool: %s (concurrent mode, max_workers=%d)", tool_name, max_workers)
        logger.info("="*60)

        # If wordlist_name is configured, make sure the wordlist exists locally first (with hash check)
        wordlist_name = tool_config.get('wordlist_name')
        if wordlist_name:
            try:
                local_wordlist_path = ensure_wordlist_local(wordlist_name)
                tool_config['wordlist'] = local_wordlist_path
            except Exception as exc:
                logger.error("Failed to prepare wordlist for tool %s: %s", tool_name, exc)
                # This tool cannot run; mark all sites as failed and move on to the next tool
                failed_sites.extend(sites)
                continue

        # Compute the timeout (shared by all sites)
        site_timeout = tool_config.get('timeout', 300)
        if site_timeout == 'auto':
            site_timeout = calculate_directory_scan_timeout(tool_config)
            logger.info(f"✓ Tool {tool_name} computed timeout dynamically: {site_timeout}s")

        # Prepare scan parameters for every site
        scan_params_list = []
        for idx, site_url in enumerate(sites, 1):
            try:
                command = build_scan_command(
                    tool_name=tool_name,
                    scan_type='directory_scan',
                    command_params={'url': site_url},
                    tool_config=tool_config
                )
                log_file = _generate_log_filename(tool_name, site_url, directory_scan_dir)
                scan_params_list.append({
                    'idx': idx,
                    'site_url': site_url,
                    'command': command,
                    'log_file': str(log_file),
                    'timeout': site_timeout
                })
            except Exception as e:
                logger.error(
                    "✗ [%d/%d] Failed to build %s command: %s - site: %s",
                    idx, len(sites), tool_name, e, site_url
                )
                failed_sites.append(site_url)

        if not scan_params_list:
            logger.warning("No valid scan tasks")
            continue

        # ============================================================
        # Batched execution strategy: cap the number of concurrent ffuf processes
        # ============================================================
        total_tasks = len(scan_params_list)
        logger.info("Starting batched execution of %d scan tasks (%d per batch)...", total_tasks, max_workers)

        batch_num = 0
        for batch_start in range(0, total_tasks, max_workers):
            batch_end = min(batch_start + max_workers, total_tasks)
            batch_params = scan_params_list[batch_start:batch_end]
            batch_num += 1

            logger.info("Running batch %d (%d-%d/%d)...", batch_num, batch_start + 1, batch_end, total_tasks)

            # Submit this batch's tasks (non-blocking, returns futures immediately)
            futures = []
            for params in batch_params:
                future = run_and_stream_save_directories_task.submit(
                    cmd=params['command'],
                    tool_name=tool_name,
                    scan_id=scan_id,
                    target_id=target_id,
                    site_url=params['site_url'],
                    cwd=str(directory_scan_dir),
                    shell=True,
                    batch_size=1000,
                    timeout=params['timeout'],
                    log_file=params['log_file']
                )
                futures.append((params['idx'], params['site_url'], future))

            # Wait for every task in this batch (blocking, so the next batch starts only after this one finishes)
            for idx, site_url, future in futures:
                try:
                    result = future.result()  # block until this task finishes
                    directories_found = result.get('created_directories', 0)
                    total_directories += directories_found
                    processed_sites_count += 1

                    logger.info(
                        "✓ [%d/%d] Site scan complete: %s - found %d directories",
                        idx, len(sites), site_url, directories_found
                    )

                except Exception as exc:
                    failed_sites.append(site_url)
                    if 'timeout' in str(exc).lower() or isinstance(exc, subprocess.TimeoutExpired):
                        logger.warning(
                            "⚠️ [%d/%d] Site scan timed out: %s - error: %s",
                            idx, len(sites), site_url, exc
                        )
                    else:
                        logger.error(
                            "✗ [%d/%d] Site scan failed: %s - error: %s",
                            idx, len(sites), site_url, exc
                        )

    # Summary
    if failed_sites:
        logger.warning(
            "Some site scans failed: %d/%d",
            len(failed_sites), len(sites)
        )

    logger.info(
        "✓ Concurrent directory scan complete - succeeded: %d/%d, failed: %d, total directories: %d",
        processed_sites_count, len(sites), len(failed_sites), total_directories
    )

    return total_directories, processed_sites_count, failed_sites


@flow(
    name="directory_scan",
    log_prints=True,
    on_running=[on_scan_flow_running],
    on_completion=[on_scan_flow_completed],
    on_failure=[on_scan_flow_failed],
)
def directory_scan_flow(
    scan_id: int,
    target_name: str,
    target_id: int,
    scan_workspace_dir: str,
    enabled_tools: dict
) -> dict:
    """
    Directory scan flow.

    Main responsibilities:
    1. Fetch all site URLs for the target
    2. Run a directory scan against each site URL (supports tools such as ffuf)
    3. Stream scan results into the Directory table

    Workflow:
        Step 0: create the working directory
        Step 1: export the site URL list to a file (consumed by the scan tools)
        Step 2: validate the tool configuration
        Step 3: run the scan tools concurrently and save results as they stream in (ThreadPoolTaskRunner)

    ffuf output fields:
        - url: discovered directory/file URL
        - length: response content length
        - status: HTTP status code
        - words: response word count
        - lines: response line count
        - content_type: content type
        - duration: request duration

    Args:
        scan_id: Scan task ID
        target_name: Target name
        target_id: Target ID
        scan_workspace_dir: Scan workspace directory
        enabled_tools: Dict of enabled tool configurations

    Returns:
        dict: {
            'success': bool,
            'scan_id': int,
            'target': str,
            'scan_workspace_dir': str,
            'sites_file': str,
            'site_count': int,
            'total_directories': int,   # total directories found
            'processed_sites': int,     # sites processed successfully
            'failed_sites_count': int,  # failed sites
            'executed_tasks': list
        }

    Raises:
        ValueError: Invalid arguments
        RuntimeError: Execution failure
    """
    try:
        logger.info(
            "="*60 + "\n" +
            "Starting directory scan\n" +
            f"  Scan ID: {scan_id}\n" +
            f"  Target: {target_name}\n" +
            f"  Workspace: {scan_workspace_dir}\n" +
            "="*60
        )

        # Argument validation
        if scan_id is None:
            raise ValueError("scan_id must not be empty")
        if not target_name:
            raise ValueError("target_name must not be empty")
        if target_id is None:
            raise ValueError("target_id must not be empty")
        if not scan_workspace_dir:
            raise ValueError("scan_workspace_dir must not be empty")
        if not enabled_tools:
            raise ValueError("enabled_tools must not be empty")

        # Step 0: create the working directory
        from apps.scan.utils import setup_scan_directory
        directory_scan_dir = setup_scan_directory(scan_workspace_dir, 'directory_scan')

        # Step 1: export site URLs (supports lazy loading)
        sites_file, site_count = _export_site_urls(target_id, target_name, directory_scan_dir)

        if site_count == 0:
            logger.warning("Target has no sites, skipping directory scan")
            return {
                'success': True,
                'scan_id': scan_id,
                'target': target_name,
                'scan_workspace_dir': scan_workspace_dir,
                'sites_file': sites_file,
                'site_count': 0,
                'total_directories': 0,
                'processed_sites': 0,
                'failed_sites_count': 0,
                'executed_tasks': ['export_sites']
            }

        # Step 2: tool configuration info
        logger.info("Step 2: tool configuration info")
        tool_info = []
        for tool_name, tool_config in enabled_tools.items():
            mw = _get_max_workers(tool_config)
            tool_info.append(f"{tool_name}(max_workers={mw})")
        logger.info("✓ Enabled tools: %s", ', '.join(tool_info))

        # Step 3: run scan tools concurrently and save results as they stream in
        logger.info("Step 3: run scan tools concurrently and save results as they stream in")
        total_directories, processed_sites, failed_sites = _run_scans_concurrently(
            enabled_tools=enabled_tools,
            sites_file=sites_file,
            directory_scan_dir=directory_scan_dir,
            scan_id=scan_id,
            target_id=target_id,
            site_count=site_count,
            target_name=target_name
        )

        # Check whether every site failed
        if processed_sites == 0 and site_count > 0:
            logger.warning("All site scans failed - total sites: %d, failures: %d", site_count, len(failed_sites))
            # Do not raise; let the scan continue

        logger.info("="*60 + "\n✓ Directory scan complete\n" + "="*60)

        return {
            'success': True,
            'scan_id': scan_id,
            'target': target_name,
            'scan_workspace_dir': scan_workspace_dir,
            'sites_file': sites_file,
            'site_count': site_count,
            'total_directories': total_directories,
            'processed_sites': processed_sites,
            'failed_sites_count': len(failed_sites),
            'executed_tasks': ['export_sites', 'run_and_stream_save_directories']
        }

    except Exception as e:
        logger.exception("Directory scan failed: %s", e)
        raise
@@ -1,380 +0,0 @@
"""
Fingerprint detection Flow

Orchestrates the complete fingerprint detection pipeline.

Architecture:
- The Flow orchestrates multiple atomic Tasks
- Runs serially after site_scan
- Uses the xingfinger tool to identify technology stacks
- Streams tool output and updates the database in batches
"""

# Django environment initialization (takes effect on import)
from apps.common.prefect_django_setup import setup_django_for_prefect

import logging
import os
from datetime import datetime
from pathlib import Path

from prefect import flow

from apps.scan.handlers.scan_flow_handlers import (
    on_scan_flow_running,
    on_scan_flow_completed,
    on_scan_flow_failed,
)
from apps.scan.tasks.fingerprint_detect import (
    export_urls_for_fingerprint_task,
    run_xingfinger_and_stream_update_tech_task,
)
from apps.scan.utils import build_scan_command
from apps.scan.utils.fingerprint_helpers import get_fingerprint_paths

logger = logging.getLogger(__name__)


def calculate_fingerprint_detect_timeout(
    url_count: int,
    base_per_url: float = 3.0,
    min_timeout: int = 60
) -> int:
    """
    Compute the timeout from the number of URLs.

    Formula: timeout = url_count × base_per_url
    Minimum: 60 seconds
    No upper bound

    Args:
        url_count: number of URLs
        base_per_url: base time per URL in seconds (default 3)
        min_timeout: minimum timeout in seconds (default 60)

    Returns:
        int: computed timeout in seconds

    Examples:
        100 URLs × 3s = 300s
        1000 URLs × 3s = 3000s (50 minutes)
        10000 URLs × 3s = 30000s (8.3 hours)
    """
    timeout = int(url_count * base_per_url)
    return max(min_timeout, timeout)
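

# Usage sketch (illustrative only, not part of the original file): the helper
# scales linearly with URL count and clamps small inputs to the 60-second floor.
if __name__ == "__main__":  # hypothetical quick check
    assert calculate_fingerprint_detect_timeout(10) == 60    # 30s raised to the floor
    assert calculate_fingerprint_detect_timeout(100) == 300  # 100 × 3s
    assert calculate_fingerprint_detect_timeout(1000, base_per_url=1.5) == 1500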

def _export_urls(
    target_id: int,
    fingerprint_dir: Path,
    source: str = 'website'
) -> tuple[str, int]:
    """
    Export URLs to a file.

    Args:
        target_id: target ID
        fingerprint_dir: fingerprint detection directory
        source: data source type

    Returns:
        tuple: (urls_file, total_count)
    """
    logger.info("Step 1: exporting URL list (source=%s)", source)

    urls_file = str(fingerprint_dir / 'urls.txt')
    export_result = export_urls_for_fingerprint_task(
        target_id=target_id,
        output_file=urls_file,
        source=source,
        batch_size=1000
    )

    total_count = export_result['total_count']

    logger.info(
        "✓ URL export completed - file: %s, count: %d",
        export_result['output_file'],
        total_count
    )

    return export_result['output_file'], total_count


def _run_fingerprint_detect(
    enabled_tools: dict,
    urls_file: str,
    url_count: int,
    fingerprint_dir: Path,
    scan_id: int,
    target_id: int,
    source: str
) -> tuple[dict, list]:
    """
    Run the fingerprint detection tasks.

    Args:
        enabled_tools: dict of enabled tool configurations
        urls_file: path to the URL file
        url_count: total number of URLs
        fingerprint_dir: fingerprint detection directory
        scan_id: scan task ID
        target_id: target ID
        source: data source type

    Returns:
        tuple: (tool_stats, failed_tools)
    """
    tool_stats = {}
    failed_tools = []

    for tool_name, tool_config in enabled_tools.items():
        # 1. Resolve fingerprint library paths
        lib_names = tool_config.get('fingerprint_libs', ['ehole'])
        fingerprint_paths = get_fingerprint_paths(lib_names)

        if not fingerprint_paths:
            reason = f"No usable fingerprint libraries: {lib_names}"
            logger.warning(reason)
            failed_tools.append({'tool': tool_name, 'reason': reason})
            continue

        # 2. Merge the library paths into tool_config (used for command building)
        tool_config_with_paths = {**tool_config, **fingerprint_paths}

        # 3. Build the command
        try:
            command = build_scan_command(
                tool_name=tool_name,
                scan_type='fingerprint_detect',
                command_params={
                    'urls_file': urls_file
                },
                tool_config=tool_config_with_paths
            )
        except Exception as e:
            reason = f"Command build failed: {str(e)}"
            logger.error("Failed to build %s command: %s", tool_name, e)
            failed_tools.append({'tool': tool_name, 'reason': reason})
            continue

        # 4. Compute the timeout
        timeout = calculate_fingerprint_detect_timeout(url_count)

        # 5. Generate the log file path
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_file = fingerprint_dir / f"{tool_name}_{timestamp}.log"

        logger.info(
            "Starting %s fingerprint detection - URLs: %d, timeout: %ds, libraries: %s",
            tool_name, url_count, timeout, list(fingerprint_paths.keys())
        )

        # 6. Run the scan task
        try:
            result = run_xingfinger_and_stream_update_tech_task(
                cmd=command,
                tool_name=tool_name,
                scan_id=scan_id,
                target_id=target_id,
                source=source,
                cwd=str(fingerprint_dir),
                timeout=timeout,
                log_file=str(log_file),
                batch_size=100
            )

            tool_stats[tool_name] = {
                'command': command,
                'result': result,
                'timeout': timeout,
                'fingerprint_libs': list(fingerprint_paths.keys())
            }

            logger.info(
                "✓ Tool %s finished - processed: %d, updated: %d, not found: %d",
                tool_name,
                result.get('processed_records', 0),
                result.get('updated_count', 0),
                result.get('not_found_count', 0)
            )

        except Exception as exc:
            failed_tools.append({'tool': tool_name, 'reason': str(exc)})
            logger.error("Tool %s failed: %s", tool_name, exc, exc_info=True)

    if failed_tools:
        logger.warning(
            "The following fingerprint detection tools failed: %s",
            ', '.join([f['tool'] for f in failed_tools])
        )

    return tool_stats, failed_tools
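

# Illustrative consumer sketch (assumption, not from the original file): the
# caller treats tool_stats as per-tool detail and failed_tools as a flat error
# list, which is how fingerprint_detect_flow aggregates them below.
#
#   tool_stats, failed_tools = _run_fingerprint_detect(...)
#   for name, stats in tool_stats.items():
#       print(name, stats['result'].get('updated_count', 0))
#   for failure in failed_tools:
#       print(failure['tool'], failure['reason'])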

@flow(
    name="fingerprint_detect",
    log_prints=True,
    on_running=[on_scan_flow_running],
    on_completion=[on_scan_flow_completed],
    on_failure=[on_scan_flow_failed],
)
def fingerprint_detect_flow(
    scan_id: int,
    target_name: str,
    target_id: int,
    scan_workspace_dir: str,
    enabled_tools: dict
) -> dict:
    """
    Fingerprint detection Flow.

    Main responsibilities:
    1. Export all WebSite URLs under the target from the database to a file
    2. Identify technology stacks with xingfinger
    3. Parse the results and update WebSite.tech (merged and deduplicated)

    Workflow:
        Step 0: create the working directory
        Step 1: export the URL list
        Step 2: parse the configuration to get the enabled tools
        Step 3: run xingfinger and parse the results

    Args:
        scan_id: scan task ID
        target_name: target name
        target_id: target ID
        scan_workspace_dir: scan workspace directory
        enabled_tools: enabled tool configuration (xingfinger)

    Returns:
        dict: {
            'success': bool,
            'scan_id': int,
            'target': str,
            'scan_workspace_dir': str,
            'urls_file': str,
            'url_count': int,
            'processed_records': int,
            'updated_count': int,
            'created_count': int,
            'executed_tasks': list,
            'tool_stats': dict
        }
    """
    try:
        logger.info(
            "="*60 + "\n" +
            "Starting fingerprint detection\n" +
            f"  Scan ID: {scan_id}\n" +
            f"  Target: {target_name}\n" +
            f"  Workspace: {scan_workspace_dir}\n" +
            "="*60
        )

        # Parameter validation
        if scan_id is None:
            raise ValueError("scan_id must not be empty")
        if not target_name:
            raise ValueError("target_name must not be empty")
        if target_id is None:
            raise ValueError("target_id must not be empty")
        if not scan_workspace_dir:
            raise ValueError("scan_workspace_dir must not be empty")

        # Data source type (only 'website' is supported for now)
        source = 'website'

        # Step 0: create the working directory
        from apps.scan.utils import setup_scan_directory
        fingerprint_dir = setup_scan_directory(scan_workspace_dir, 'fingerprint_detect')

        # Step 1: export URLs (supports lazy loading)
        urls_file, url_count = _export_urls(target_id, fingerprint_dir, source)

        if url_count == 0:
            logger.warning("Target has no usable URLs, skipping fingerprint detection")
            return {
                'success': True,
                'scan_id': scan_id,
                'target': target_name,
                'scan_workspace_dir': scan_workspace_dir,
                'urls_file': urls_file,
                'url_count': 0,
                'processed_records': 0,
                'updated_count': 0,
                'created_count': 0,
                'executed_tasks': ['export_urls_for_fingerprint'],
                'tool_stats': {
                    'total': 0,
                    'successful': 0,
                    'failed': 0,
                    'successful_tools': [],
                    'failed_tools': [],
                    'details': {}
                }
            }

        # Step 2: tool configuration
        logger.info("Step 2: tool configuration")
        logger.info("✓ Enabled tools: %s", ', '.join(enabled_tools.keys()))

        # Step 3: run fingerprint detection
        logger.info("Step 3: running fingerprint detection")
        tool_stats, failed_tools = _run_fingerprint_detect(
            enabled_tools=enabled_tools,
            urls_file=urls_file,
            url_count=url_count,
            fingerprint_dir=fingerprint_dir,
            scan_id=scan_id,
            target_id=target_id,
            source=source
        )

        logger.info("="*60 + "\n✓ Fingerprint detection completed\n" + "="*60)

        # Build the list of executed tasks dynamically
        executed_tasks = ['export_urls_for_fingerprint']
        executed_tasks.extend([f'run_xingfinger ({tool})' for tool in tool_stats.keys()])

        # Aggregate the results of all tools
        total_processed = sum(stats['result'].get('processed_records', 0) for stats in tool_stats.values())
        total_updated = sum(stats['result'].get('updated_count', 0) for stats in tool_stats.values())
        total_created = sum(stats['result'].get('created_count', 0) for stats in tool_stats.values())

        successful_tools = [name for name in enabled_tools.keys()
                            if name not in [f['tool'] for f in failed_tools]]

        return {
            'success': True,
            'scan_id': scan_id,
            'target': target_name,
            'scan_workspace_dir': scan_workspace_dir,
            'urls_file': urls_file,
            'url_count': url_count,
            'processed_records': total_processed,
            'updated_count': total_updated,
            'created_count': total_created,
            'executed_tasks': executed_tasks,
            'tool_stats': {
                'total': len(enabled_tools),
                'successful': len(successful_tools),
                'failed': len(failed_tools),
                'successful_tools': successful_tools,
                'failed_tools': failed_tools,
                'details': tool_stats
            }
        }

    except ValueError as e:
        logger.error("Configuration error: %s", e)
        raise
    except RuntimeError as e:
        logger.error("Runtime error: %s", e)
        raise
    except Exception as e:
        logger.exception("Fingerprint detection failed: %s", e)
        raise
@@ -1,279 +0,0 @@
"""
Scan initialization Flow

Orchestrates the initialization of a scan task.

Responsibilities:
- Parse the YAML configuration via FlowOrchestrator
- Run sub-Flows (subflows) inside the Prefect Flow
- Orchestrate the workflow in YAML order
- Contains no business logic (that lives in Tasks and FlowOrchestrator)

Architecture:
- Flow: Prefect orchestration layer (this file)
- FlowOrchestrator: config parsing and execution planning (apps/scan/services/)
- Tasks: execution layer (apps/scan/tasks/)
- Handlers: state management (apps/scan/handlers/)
"""

# Django environment initialization (takes effect on import)
# Note: dynamic scan containers should start via run_initiate_scan.py so the
# environment variables are set before this import
from apps.common.prefect_django_setup import setup_django_for_prefect

from prefect import flow, task
from pathlib import Path
import logging

from apps.scan.handlers import (
    on_initiate_scan_flow_running,
    on_initiate_scan_flow_completed,
    on_initiate_scan_flow_failed,
)
from prefect.futures import wait
from apps.scan.utils import setup_scan_workspace
from apps.scan.orchestrators import FlowOrchestrator

logger = logging.getLogger(__name__)


@task(name="run_subflow")
def _run_subflow_task(scan_type: str, flow_func, flow_kwargs: dict):
    """Task wrapper around a sub-Flow, used to run sub-Flows concurrently in the parallel stage."""
    logger.info("Starting sub-Flow: %s", scan_type)
    return flow_func(**flow_kwargs)
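

# Illustrative sketch (assumption, not from the original file): wrapping each
# sub-Flow call in a Task lets the parallel stage fan out with .submit() and
# join with wait(), exactly as the parallel branch below does.
#
#   futures = [
#       _run_subflow_task.submit(scan_type=st, flow_func=fn, flow_kwargs=kw)
#       for st, fn, kw in valid_flows
#   ]
#   wait(futures)
#   results = [f.result() for f in futures]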

@flow(
    name='initiate_scan',
    description='Scan task initialization flow',
    log_prints=True,
    on_running=[on_initiate_scan_flow_running],
    on_completion=[on_initiate_scan_flow_completed],
    on_failure=[on_initiate_scan_flow_failed],
)
def initiate_scan_flow(
    scan_id: int,
    target_name: str,
    target_id: int,
    scan_workspace_dir: str,
    engine_name: str,
    scheduled_scan_name: str | None = None,
) -> dict:
    """
    Initialize a scan task (dynamic workflow orchestration).

    Orchestrates the workflow dynamically from the YAML configuration:
    - Fetch engine_config (YAML) from the database
    - Detect the enabled scan types
    - Execute the defined stages in order:
        Stage 1: Discovery (sequential)
            - subdomain_discovery
            - port_scan
            - site_scan
        Stage 2: Analysis (parallel)
            - url_fetch
            - directory_scan

    Args:
        scan_id: scan task ID
        target_name: target name
        target_id: target ID
        scan_workspace_dir: scan workspace directory path
        engine_name: engine name (for display)
        scheduled_scan_name: scheduled scan name (optional, for notification display)

    Returns:
        dict: execution result summary

    Raises:
        ValueError: parameter validation failed or configuration invalid
        RuntimeError: execution failed
    """
    try:
        # ==================== Parameter validation ====================
        if not scan_id:
            raise ValueError("scan_id is required")
        if not scan_workspace_dir:
            raise ValueError("scan_workspace_dir is required")
        if not engine_name:
            raise ValueError("engine_name is required")

        logger.info(
            "="*60 + "\n" +
            "Starting scan task initialization\n" +
            f"  Scan ID: {scan_id}\n" +
            f"  Target: {target_name}\n" +
            f"  Engine: {engine_name}\n" +
            f"  Workspace: {scan_workspace_dir}\n" +
            "="*60
        )

        # ==================== Task 1: create the Scan workspace ====================
        scan_workspace_path = setup_scan_workspace(scan_workspace_dir)

        # ==================== Task 2: fetch the engine configuration ====================
        from apps.scan.models import Scan
        scan = Scan.objects.select_related('engine').get(id=scan_id)
        engine_config = scan.engine.configuration

        # ==================== Task 3: parse the config and build the execution plan ====================
        orchestrator = FlowOrchestrator(engine_config)

        # FlowOrchestrator has already parsed all tool configurations
        enabled_tools_by_type = orchestrator.enabled_tools_by_type

        logger.info(
            f"Execution plan generated:\n"
            f"  Scan types: {' → '.join(orchestrator.scan_types)}\n"
            f"  {len(orchestrator.scan_types)} Flows in total"
        )

        # ==================== Initialize stage progress ====================
        # Initialize right after parsing the config, when the full scan_types list is known
        from apps.scan.services import ScanService
        scan_service = ScanService()
        scan_service.init_stage_progress(scan_id, orchestrator.scan_types)
        logger.info(f"✓ Stage progress initialized - Stages: {orchestrator.scan_types}")

        # ==================== Update the Target's last-scanned timestamp ====================
        # Updated when the scan starts, i.e. "when the last scan began"
        from apps.targets.services import TargetService
        target_service = TargetService()
        target_service.update_last_scanned_at(target_id)
        logger.info(f"✓ Target last-scanned time updated - Target ID: {target_id}")

        # ==================== Task 4: run the Flows (dynamic stage execution) ====================
        # Note: per-stage status updates are handled automatically by
        # scan_flow_handlers.py (running/completed/failed)
        executed_flows = []
        results = {}

        # Common execution parameters
        flow_kwargs = {
            'scan_id': scan_id,
            'target_name': target_name,
            'target_id': target_id,
            'scan_workspace_dir': str(scan_workspace_path)
        }

        def record_flow_result(scan_type, result=None, error=None):
            """
            Unified result recording.

            Args:
                scan_type: scan type name
                result: execution result (on success)
                error: exception object (on failure)
            """
            if error:
                # Failure handling: log the error but do not raise, so later stages keep running
                error_msg = f"{scan_type} failed: {str(error)}"
                logger.warning(error_msg)
                executed_flows.append(f"{scan_type} (failed)")
                results[scan_type] = {'success': False, 'error': str(error)}
                # No longer raising; let the scan continue
            else:
                # Success handling
                executed_flows.append(scan_type)
                results[scan_type] = result
                logger.info(f"✓ {scan_type} succeeded")

        def get_valid_flows(flow_names):
            """
            Collect the valid Flow functions and prepare per-Flow parameters.

            Args:
                flow_names: list of scan type names

            Returns:
                list: [(scan_type, flow_func, flow_specific_kwargs), ...]
            """
            valid_flows = []
            for scan_type in flow_names:
                flow_func = orchestrator.get_flow_function(scan_type)
                if flow_func:
                    # Prepare per-Flow parameters (including the matching enabled_tools)
                    flow_specific_kwargs = dict(flow_kwargs)
                    flow_specific_kwargs['enabled_tools'] = enabled_tools_by_type.get(scan_type, {})
                    valid_flows.append((scan_type, flow_func, flow_specific_kwargs))
                else:
                    logger.warning(f"Skipping unimplemented Flow: {scan_type}")
            return valid_flows

        # ---------------------------------------------------------
        # Dynamic stage execution (driven by the FlowOrchestrator definition)
        # ---------------------------------------------------------
        for mode, enabled_flows in orchestrator.get_execution_stages():
            if mode == 'sequential':
                # Sequential execution
                logger.info(f"\n{'='*60}\nSequential stage: {', '.join(enabled_flows)}\n{'='*60}")
                for scan_type, flow_func, flow_specific_kwargs in get_valid_flows(enabled_flows):
                    logger.info(f"\n{'='*60}\nRunning Flow: {scan_type}\n{'='*60}")
                    try:
                        result = flow_func(**flow_specific_kwargs)
                        record_flow_result(scan_type, result=result)
                    except Exception as e:
                        record_flow_result(scan_type, error=e)

            elif mode == 'parallel':
                # Parallel stage: wrap sub-Flows in Tasks and run them concurrently via the Prefect TaskRunner
                logger.info(f"\n{'='*60}\nParallel stage: {', '.join(enabled_flows)}\n{'='*60}")
                futures = []

                # Submit all parallel sub-Flow tasks
                for scan_type, flow_func, flow_specific_kwargs in get_valid_flows(enabled_flows):
                    logger.info(f"\n{'='*60}\nSubmitting parallel sub-Flow: {scan_type}\n{'='*60}")
                    future = _run_subflow_task.submit(
                        scan_type=scan_type,
                        flow_func=flow_func,
                        flow_kwargs=flow_specific_kwargs,
                    )
                    futures.append((scan_type, future))

                # Wait for all parallel sub-Flows to finish
                if futures:
                    wait([f for _, f in futures])

                # Check the results (reusing the unified result handling)
                for scan_type, future in futures:
                    try:
                        result = future.result()
                        record_flow_result(scan_type, result=result)
                    except Exception as e:
                        record_flow_result(scan_type, error=e)

        # ==================== Done ====================
        logger.info(
            "="*60 + "\n" +
            "✓ Scan task initialization completed\n" +
            f"  Executed Flows: {', '.join(executed_flows)}\n" +
            "="*60
        )

        # ==================== Return the result ====================
        return {
            'success': True,
            'scan_id': scan_id,
            'target': target_name,
            'scan_workspace_dir': str(scan_workspace_path),
            'executed_flows': executed_flows,
            'results': results
        }

    except ValueError as e:
        # Parameter error
        logger.error("Parameter error: %s", e)
        raise
    except RuntimeError as e:
        # Execution failure
        logger.error("Runtime error: %s", e)
        raise
    except OSError as e:
        # Filesystem error (workspace creation failed)
        logger.error("Filesystem error: %s", e)
        raise
    except Exception as e:
        # Other unexpected errors
        logger.exception("Scan task initialization failed: %s", e)
        # Note: failure status updates are handled automatically by the Prefect state handlers
        raise
@@ -1,478 +0,0 @@
"""
Site scan Flow

Orchestrates the complete site scan pipeline.

Architecture:
- The Flow orchestrates multiple atomic Tasks
- Scan tools run serially (streaming processing)
- Each Task can retry independently
- Configuration is parsed from YAML
"""

# Django environment initialization (takes effect on import)
from apps.common.prefect_django_setup import setup_django_for_prefect

import logging
import os
import subprocess
from pathlib import Path
from typing import Callable
from prefect import flow
from apps.scan.tasks.site_scan import export_site_urls_task, run_and_stream_save_websites_task
from apps.scan.handlers.scan_flow_handlers import (
    on_scan_flow_running,
    on_scan_flow_completed,
    on_scan_flow_failed,
)
from apps.scan.utils import config_parser, build_scan_command

logger = logging.getLogger(__name__)


def calculate_timeout_by_line_count(
    tool_config: dict,
    file_path: str,
    base_per_time: int = 1,
    min_timeout: int = 60
) -> int:
    """
    Compute the timeout from the file's line count.

    Counts the file's lines with `wc -l` and derives the timeout from the
    line count and the base time per line.

    Args:
        tool_config: tool configuration dict (unused here, kept for interface consistency)
        file_path: path of the file whose lines are counted
        base_per_time: base time per line in seconds (default 1)
        min_timeout: minimum timeout in seconds (default 60)

    Returns:
        int: computed timeout in seconds, never below min_timeout

    Example:
        timeout = calculate_timeout_by_line_count(
            tool_config={},
            file_path='/path/to/urls.txt',
            base_per_time=2
        )
    """
    try:
        # Count lines quickly with wc -l
        result = subprocess.run(
            ['wc', '-l', file_path],
            capture_output=True,
            text=True,
            check=True
        )
        # wc -l output format: line count + space + file name
        line_count = int(result.stdout.strip().split()[0])

        # timeout = line count × base time per line, never below the minimum
        timeout = max(line_count * base_per_time, min_timeout)

        logger.info(
            f"Timeout auto-computed: file={file_path}, "
            f"lines={line_count}, per-line={base_per_time}s, minimum={min_timeout}s, timeout={timeout}s"
        )

        return timeout

    except Exception as e:
        # Fall back to the default if wc -l fails
        logger.warning(f"wc -l line count failed: {e}, using default timeout: {min_timeout}s")
        return min_timeout
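

# Portability note (assumption, not from the original file): wc -l requires a
# Unix userland. A pure-Python fallback with the same semantics would be:
#
#   def _count_lines_py(file_path: str) -> int:
#       with open(file_path, 'rb') as f:
#           return sum(1 for _ in f)
#
#   timeout = max(_count_lines_py('/path/to/urls.txt') * 1, 60)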

def _export_site_urls(target_id: int, site_scan_dir: Path, target_name: str = None) -> tuple[str, int, int]:
    """
    Export site URLs to a file.

    Args:
        target_id: target ID
        site_scan_dir: site scan directory
        target_name: target name (used to write a default value when lazy loading)

    Returns:
        tuple: (urls_file, total_urls, association_count)
    """
    logger.info("Step 1: exporting the site URL list")

    urls_file = str(site_scan_dir / 'site_urls.txt')
    export_result = export_site_urls_task(
        target_id=target_id,
        output_file=urls_file,
        batch_size=1000  # process 1000 subdomains per batch
    )

    total_urls = export_result['total_urls']
    association_count = export_result['association_count']  # host-port association count

    logger.info(
        "✓ Site URL export completed - file: %s, URLs: %d, associations: %d",
        export_result['output_file'],
        total_urls,
        association_count
    )

    if total_urls == 0:
        logger.warning("Target has no usable site URLs; the site scan cannot run")
        # Do not raise; the caller decides how to handle this
        # raise ValueError("Target has no usable site URLs; the site scan cannot run")

    return export_result['output_file'], total_urls, association_count


def _run_scans_sequentially(
    enabled_tools: dict,
    urls_file: str,
    total_urls: int,
    site_scan_dir: Path,
    scan_id: int,
    target_id: int,
    target_name: str
) -> tuple[dict, int, list, list]:
    """
    Run the site scan tasks serially.

    Args:
        enabled_tools: dict of enabled tool configurations
        urls_file: path to the URL file
        total_urls: total number of URLs
        site_scan_dir: site scan directory
        scan_id: scan task ID
        target_id: target ID
        target_name: target name (for error logs)

    Returns:
        tuple: (tool_stats, processed_records, successful_tool_names, failed_tools)
    """
    tool_stats = {}
    processed_records = 0
    failed_tools = []

    for tool_name, tool_config in enabled_tools.items():
        # 1. Build the full command (variable substitution)
        try:
            command = build_scan_command(
                tool_name=tool_name,
                scan_type='site_scan',
                command_params={
                    'url_file': urls_file
                },
                tool_config=tool_config
            )
        except Exception as e:
            reason = f"Command build failed: {str(e)}"
            logger.error(f"Failed to build {tool_name} command: {e}")
            failed_tools.append({'tool': tool_name, 'reason': reason})
            continue

        # 2. Resolve the timeout ('auto' means dynamic computation)
        config_timeout = tool_config.get('timeout', 300)
        if config_timeout == 'auto':
            # Compute the timeout dynamically
            timeout = calculate_timeout_by_line_count(tool_config, urls_file, base_per_time=1)
            logger.info(f"✓ Tool {tool_name} dynamically computed timeout: {timeout}s")
        else:
            # Use the larger of the configured and the dynamically computed timeout
            dynamic_timeout = calculate_timeout_by_line_count(tool_config, urls_file, base_per_time=1)
            timeout = max(dynamic_timeout, config_timeout)

        # 2.1 Generate the log file path (mirrors the port scan)
        from datetime import datetime
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_file = site_scan_dir / f"{tool_name}_{timestamp}.log"

        logger.info(
            "Starting %s site scan - URLs: %d, final timeout: %ds",
            tool_name, total_urls, timeout
        )

        # 3. Run the scan task
        try:
            # Stream the scan and save results as they arrive
            result = run_and_stream_save_websites_task(
                cmd=command,
                tool_name=tool_name,  # new: tool name
                scan_id=scan_id,
                target_id=target_id,
                cwd=str(site_scan_dir),
                shell=True,
                batch_size=1000,
                timeout=timeout,
                log_file=str(log_file)  # new: log file path
            )

            tool_stats[tool_name] = {
                'command': command,
                'result': result,
                'timeout': timeout
            }
            processed_records += result.get('processed_records', 0)

            logger.info(
                "✓ Tool %s streaming completed - processed: %d, websites created: %d, skipped: %d",
                tool_name,
                result.get('processed_records', 0),
                result.get('created_websites', 0),
                result.get('skipped_no_subdomain', 0) + result.get('skipped_failed', 0)
            )

        except subprocess.TimeoutExpired as exc:
            # Timeouts are handled separately
            reason = f"Timed out (configured: {timeout}s)"
            failed_tools.append({'tool': tool_name, 'reason': reason})
            logger.warning(
                "⚠️ Tool %s timed out - configured timeout: %ds\n"
                "Note: sites parsed before the timeout were saved to the database, but the scan did not fully complete.",
                tool_name, timeout
            )
        except Exception as exc:
            # Other exceptions
            failed_tools.append({'tool': tool_name, 'reason': str(exc)})
            logger.error("Tool %s failed: %s", tool_name, exc, exc_info=True)

    if failed_tools:
        logger.warning(
            "The following scan tools failed: %s",
            ', '.join([f['tool'] for f in failed_tools])
        )

    if not tool_stats:
        error_details = "; ".join([f"{f['tool']}: {f['reason']}" for f in failed_tools])
        logger.warning("All site scan tools failed - target: %s, failed tools: %s", target_name, error_details)
        # Return an empty result instead of raising, so the scan continues
        return {}, 0, [], failed_tools

    # Compute the list of successful tools dynamically
    successful_tool_names = [name for name in enabled_tools.keys()
                             if name not in [f['tool'] for f in failed_tools]]

    logger.info(
        "✓ Serial site scans completed - succeeded: %d/%d (succeeded: %s, failed: %s)",
        len(tool_stats), len(enabled_tools),
        ', '.join(successful_tool_names) if successful_tool_names else 'none',
        ', '.join([f['tool'] for f in failed_tools]) if failed_tools else 'none'
    )

    return tool_stats, processed_records, successful_tool_names, failed_tools


def calculate_timeout(url_count: int, base: int = 600, per_url: int = 1) -> int:
    """
    Compute the scan timeout dynamically from the URL count.

    Rules:
    - base time: 600 seconds (10 minutes) by default
    - extra time per URL: 1 second by default

    Args:
        url_count: URL count, must be a positive integer
        base: base timeout in seconds, default 600
        per_url: extra seconds per URL, default 1

    Returns:
        int: computed timeout in seconds (no upper bound)

    Raises:
        ValueError: raised when url_count is negative or 0
    """
    if url_count < 0:
        raise ValueError(f"URL count must not be negative: {url_count}")
    if url_count == 0:
        raise ValueError("URL count must not be 0")

    timeout = base + int(url_count * per_url)

    # No upper bound; the caller controls limits as needed
    return timeout
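

# Usage sketch (illustrative only, not part of the original file): the timeout
# grows linearly on top of the 600-second base and rejects empty input.
#
#   calculate_timeout(1)                          -> 601
#   calculate_timeout(3000)                       -> 3600
#   calculate_timeout(100, base=60, per_url=2)    -> 260
#   calculate_timeout(0)                          -> raises ValueError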

@flow(
    name="site_scan",
    log_prints=True,
    on_running=[on_scan_flow_running],
    on_completion=[on_scan_flow_completed],
    on_failure=[on_scan_flow_failed],
)
def site_scan_flow(
    scan_id: int,
    target_name: str,
    target_id: int,
    scan_workspace_dir: str,
    enabled_tools: dict
) -> dict:
    """
    Site scan Flow.

    Main responsibilities:
    1. Fetch every subdomain and its ports from the target, join them into URLs, and write them to a file
    2. Probe the URLs in bulk with httpx and save the results to the database in real time (streaming)

    Workflow:
        Step 0: create the working directory
        Step 1: export the site URL list
        Step 2: parse the configuration to get the enabled tools
        Step 3: run the scan tools serially with streaming saves

    Args:
        scan_id: scan task ID
        target_name: target name
        target_id: target ID
        scan_workspace_dir: scan workspace directory
        enabled_tools: dict of enabled tool configurations

    Returns:
        dict: {
            'success': bool,
            'scan_id': int,
            'target': str,
            'scan_workspace_dir': str,
            'urls_file': str,
            'total_urls': int,
            'association_count': int,
            'processed_records': int,
            'created_websites': int,
            'skipped_no_subdomain': int,
            'skipped_failed': int,
            'executed_tasks': list,
            'tool_stats': {
                'total': int,
                'successful': int,
                'failed': int,
                'successful_tools': list[str],
                'failed_tools': list[dict]
            }
        }

    Raises:
        ValueError: configuration error
        RuntimeError: execution failure
    """
    try:
        logger.info(
            "="*60 + "\n" +
            "Starting site scan\n" +
            f"  Scan ID: {scan_id}\n" +
            f"  Target: {target_name}\n" +
            f"  Workspace: {scan_workspace_dir}\n" +
            "="*60
        )

        # Parameter validation
        if scan_id is None:
            raise ValueError("scan_id must not be empty")
        if not target_name:
            raise ValueError("target_name must not be empty")
        if target_id is None:
            raise ValueError("target_id must not be empty")
        if not scan_workspace_dir:
            raise ValueError("scan_workspace_dir must not be empty")

        # Step 0: create the working directory
        from apps.scan.utils import setup_scan_directory
        site_scan_dir = setup_scan_directory(scan_workspace_dir, 'site_scan')

        # Step 1: export site URLs
        urls_file, total_urls, association_count = _export_site_urls(
            target_id, site_scan_dir, target_name
        )

        if total_urls == 0:
            logger.warning("Target has no usable site URLs, skipping site scan")
            return {
                'success': True,
                'scan_id': scan_id,
                'target': target_name,
                'scan_workspace_dir': scan_workspace_dir,
                'urls_file': urls_file,
                'total_urls': 0,
                'association_count': association_count,
                'processed_records': 0,
                'created_websites': 0,
                'skipped_no_subdomain': 0,
                'skipped_failed': 0,
                'executed_tasks': ['export_site_urls'],
                'tool_stats': {
                    'total': 0,
                    'successful': 0,
                    'failed': 0,
                    'successful_tools': [],
                    'failed_tools': [],
                    'details': {}
                }
            }

        # Step 2: tool configuration
        logger.info("Step 2: tool configuration")
        logger.info(
            "✓ Enabled tools: %s",
            ', '.join(enabled_tools.keys())
        )

        # Step 3: run the scan tools serially
        logger.info("Step 3: running scan tools serially with streaming saves")
        tool_stats, processed_records, successful_tool_names, failed_tools = _run_scans_sequentially(
            enabled_tools=enabled_tools,
            urls_file=urls_file,
            total_urls=total_urls,
            site_scan_dir=site_scan_dir,
            scan_id=scan_id,
            target_id=target_id,
            target_name=target_name
        )

        logger.info("="*60 + "\n✓ Site scan completed\n" + "="*60)

        # Build the list of executed tasks dynamically
        executed_tasks = ['export_site_urls', 'parse_config']
        executed_tasks.extend([f'run_and_stream_save_websites ({tool})' for tool in tool_stats.keys()])

        # Aggregate the results of all tools
        total_created = sum(stats['result'].get('created_websites', 0) for stats in tool_stats.values())
        total_skipped_no_subdomain = sum(stats['result'].get('skipped_no_subdomain', 0) for stats in tool_stats.values())
        total_skipped_failed = sum(stats['result'].get('skipped_failed', 0) for stats in tool_stats.values())

        return {
            'success': True,
            'scan_id': scan_id,
            'target': target_name,
            'scan_workspace_dir': scan_workspace_dir,
            'urls_file': urls_file,
            'total_urls': total_urls,
            'association_count': association_count,
            'processed_records': processed_records,
            'created_websites': total_created,
            'skipped_no_subdomain': total_skipped_no_subdomain,
            'skipped_failed': total_skipped_failed,
            'executed_tasks': executed_tasks,
            'tool_stats': {
                'total': len(enabled_tools),
                'successful': len(successful_tool_names),
                'failed': len(failed_tools),
                'successful_tools': successful_tool_names,
                'failed_tools': failed_tools,
                'details': tool_stats
            }
        }

    except ValueError as e:
        logger.error("Configuration error: %s", e)
        raise
    except RuntimeError as e:
        logger.error("Runtime error: %s", e)
        raise
    except Exception as e:
        logger.exception("Site scan failed: %s", e)
        raise
@@ -1,744 +0,0 @@
"""
Subdomain discovery scan Flow

Orchestrates the complete subdomain discovery pipeline.

Architecture:
- The Flow orchestrates multiple atomic Tasks
- Scan tools can run in parallel
- Each Task can retry independently
- Configuration is parsed from YAML

Enhanced pipeline (4 stages):
    Stage 1: passive collection (parallel) - required
    Stage 2: wordlist bruteforce (optional) - subdomain dictionary bruteforce
    Stage 3: permutation generation + validation (optional) - dnsgen + generic liveness validation
    Stage 4: DNS liveness validation (optional) - generic liveness validation

Each stage can be toggled independently; the final result depends on which stages actually ran.
"""

# Django environment initialization (takes effect on import)
from apps.common.prefect_django_setup import setup_django_for_prefect

from prefect import flow
from pathlib import Path
import logging
import os
from apps.scan.handlers.scan_flow_handlers import (
    on_scan_flow_running,
    on_scan_flow_completed,
    on_scan_flow_failed,
)
from apps.scan.utils import build_scan_command, ensure_wordlist_local
from apps.engine.services.wordlist_service import WordlistService
from apps.common.normalizer import normalize_domain
from apps.common.validators import validate_domain
from datetime import datetime
import uuid
import subprocess

logger = logging.getLogger(__name__)


def _validate_and_normalize_target(target_name: str) -> str:
    """
    Validate and normalize the target domain.

    Args:
        target_name: raw target domain

    Returns:
        str: normalized domain

    Raises:
        ValueError: raised when the domain is invalid

    Example:
        >>> _validate_and_normalize_target('EXAMPLE.COM')
        'example.com'
        >>> _validate_and_normalize_target('http://example.com')
        'example.com'
    """
    try:
        normalized_target = normalize_domain(target_name)
        validate_domain(normalized_target)
        logger.debug("Domain validation passed: %s -> %s", target_name, normalized_target)
        return normalized_target
    except ValueError as e:
        error_msg = f"Invalid target domain: {target_name} - {e}"
        logger.error(error_msg)
        raise ValueError(error_msg) from e


def _run_scans_parallel(
    enabled_tools: dict,
    domain_name: str,
    result_dir: Path
) -> tuple[list, list, list]:
    """
    Run all enabled subdomain scan tools in parallel.

    Args:
        enabled_tools: dict of enabled tool configurations {'tool_name': {'timeout': 600, ...}}
        domain_name: target domain
        result_dir: result output directory

    Returns:
        tuple: (result_files, failed_tools, successful_tool_names)
    """
    # Import the task function
    from apps.scan.tasks.subdomain_discovery import run_subdomain_discovery_task

    # Generate a timestamp shared by all tools
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    failures = []  # tools whose command could not be built
    futures = {}

    # 1. Build the commands and submit the parallel tasks
    for tool_name, tool_config in enabled_tools.items():
        # 1.1 Generate a unique output file path (absolute)
        short_uuid = uuid.uuid4().hex[:4]
        output_file = str(result_dir / f"{tool_name}_{timestamp}_{short_uuid}.txt")

        # 1.2 Build the full command (variable substitution)
        try:
            command = build_scan_command(
                tool_name=tool_name,
                scan_type='subdomain_discovery',
                command_params={
                    'domain': domain_name,      # maps to {domain}
                    'output_file': output_file  # maps to {output_file}
                },
                tool_config=tool_config
            )
        except Exception as e:
            failure_msg = f"{tool_name}: command build failed - {e}"
            failures.append(failure_msg)
            logger.error(f"Failed to build {tool_name} command: {e}")
            continue

        # 1.3 Resolve the timeout ('auto' means dynamic computation)
        timeout = tool_config['timeout']
        if timeout == 'auto':
            # Subdomain discovery tools usually run long; use the 600-second default
            timeout = 600
            logger.info(f"✓ Tool {tool_name} using the default timeout: {timeout}s")

        # 1.4 Submit the task
        logger.debug(
            f"Submitting task - tool: {tool_name}, timeout: {timeout}s, output: {output_file}"
        )

        future = run_subdomain_discovery_task.submit(
            tool=tool_name,
            command=command,
            timeout=timeout,
            output_file=output_file
        )
        futures[tool_name] = future

    # 2. Check whether any tool was submitted successfully
    if not futures:
        logger.warning(
            "No scan tool could start - target: %s, failure details: %s",
            domain_name, "; ".join(failures)
        )
        # Return an empty result instead of raising, so the scan continues
        return [], [{'tool': 'all', 'reason': 'no tool could start'}], []

    # 3. Wait for the parallel tasks and collect the results
    result_files = []
    failed_tools = []

    for tool_name, future in futures.items():
        try:
            result = future.result()  # returns a file path (string) or "" on failure
            if result:
                result_files.append(result)
                logger.info("✓ Scan tool %s succeeded: %s", tool_name, result)
            else:
                failure_msg = f"{tool_name}: no result file produced"
                failures.append(failure_msg)
                failed_tools.append({'tool': tool_name, 'reason': 'no result file produced'})
                logger.warning("⚠️ Scan tool %s produced no result file", tool_name)
        except Exception as e:
            failure_msg = f"{tool_name}: {str(e)}"
            failures.append(failure_msg)
            failed_tools.append({'tool': tool_name, 'reason': str(e)})
            logger.warning("⚠️ Scan tool %s failed: %s", tool_name, str(e))

    # 4. Check whether any tool succeeded
    if not result_files:
        logger.warning(
            "All scan tools failed - target: %s, failure details: %s",
            domain_name, "; ".join(failures)
        )
        # Return an empty result instead of raising, so the scan continues
        return [], failed_tools, []

    # 5. Compute the list of successful tools dynamically
    successful_tool_names = [name for name in futures.keys()
                             if name not in [f['tool'] for f in failed_tools]]

    logger.info(
        "✓ Parallel scan tools completed - succeeded: %d/%d (succeeded: %s, failed: %s)",
        len(result_files), len(futures),
        ', '.join(successful_tool_names) if successful_tool_names else 'none',
        ', '.join([f['tool'] for f in failed_tools]) if failed_tools else 'none'
    )

    return result_files, failed_tools, successful_tool_names


def _run_single_tool(
    tool_name: str,
    tool_config: dict,
    command_params: dict,
    result_dir: Path,
    scan_type: str = 'subdomain_discovery'
) -> str:
    """
    Run a single scan tool.

    Args:
        tool_name: tool name
        tool_config: tool configuration
        command_params: command parameters
        result_dir: result directory
        scan_type: scan type

    Returns:
        str: output file path, or an empty string on failure
    """
    from apps.scan.tasks.subdomain_discovery import run_subdomain_discovery_task

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    short_uuid = uuid.uuid4().hex[:4]
    output_file = str(result_dir / f"{tool_name}_{timestamp}_{short_uuid}.txt")

    # Add output_file to the parameters
    command_params['output_file'] = output_file

    try:
        command = build_scan_command(
            tool_name=tool_name,
            scan_type=scan_type,
            command_params=command_params,
            tool_config=tool_config
        )
    except Exception as e:
        logger.error(f"Failed to build {tool_name} command: {e}")
        return ""

    timeout = tool_config.get('timeout', 3600)
    if timeout == 'auto':
        timeout = 3600

    logger.info(f"Running {tool_name}: timeout={timeout}s")

    try:
        result = run_subdomain_discovery_task(
            tool=tool_name,
            command=command,
            timeout=timeout,
            output_file=output_file
        )
        return result if result else ""
    except Exception as e:
        logger.warning(f"{tool_name} failed: {e}")
        return ""


def _count_lines(file_path: str) -> int:
    """
    Count a file's non-empty lines.

    Args:
        file_path: file path

    Returns:
        int: number of non-empty lines
    """
    try:
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
            return sum(1 for line in f if line.strip())
    except Exception as e:
        logger.warning(f"Failed to count lines: {file_path} - {e}")
        return 0


def _merge_files(file_list: list, output_file: str) -> str:
    """
    Merge multiple files and deduplicate.

    Args:
        file_list: list of file paths
        output_file: output file path

    Returns:
        str: output file path
    """
    domains = set()
    for f in file_list:
        if f and Path(f).exists():
            with open(f, 'r', encoding='utf-8', errors='ignore') as fp:
                for line in fp:
                    line = line.strip()
                    if line:
                        domains.add(line)

    with open(output_file, 'w', encoding='utf-8') as fp:
        for domain in sorted(domains):
            fp.write(domain + '\n')

    logger.info(f"Merge completed: {len(domains)} domains -> {output_file}")
    return output_file
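

# Usage sketch (illustrative only; the file names below are hypothetical):
# merging the per-tool outputs yields one sorted, deduplicated candidate list.
#
#   merged = _merge_files(
#       ['subfinder_20250101_ab12.txt', 'amass_20250101_cd34.txt'],
#       '/tmp/subs_passive.txt'
#   )
#   print(_count_lines(merged))  # unique, non-empty domains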

@flow(
    name="subdomain_discovery",
    log_prints=True,
    on_running=[on_scan_flow_running],
    on_completion=[on_scan_flow_completed],
    on_failure=[on_scan_flow_failed],
)
def subdomain_discovery_flow(
    scan_id: int,
    target_name: str,
    target_id: int,
    scan_workspace_dir: str,
    enabled_tools: dict
) -> dict:
    """Subdomain discovery scan flow.

    Workflow (4 stages):
        Stage 1: passive collection (parallel) - required
        Stage 2: wordlist bruteforce (optional) - subdomain dictionary bruteforce
        Stage 3: permutation generation + validation (optional) - dnsgen + generic liveness validation
        Stage 4: DNS liveness validation (optional) - generic liveness validation
        Final: save to the database

    Notes:
        - Subdomain discovery only makes sense for DOMAIN targets
        - IP and CIDR targets are skipped automatically

    Args:
        scan_id: scan task ID
        target_name: target name (domain)
        target_id: target ID
        scan_workspace_dir: scan workspace directory (created by the service layer)
        enabled_tools: scan configuration dict:
            {
                'passive_tools': {...},
                'bruteforce': {...},
                'permutation': {...},
                'resolve': {...}
            }

    Returns:
        dict: scan result

    Raises:
        ValueError: configuration error
        RuntimeError: execution failure
    """
    try:
        # ==================== Parameter validation ====================
        if scan_id is None:
            raise ValueError("scan_id must not be empty")
        if target_id is None:
            raise ValueError("target_id must not be empty")
        if not scan_workspace_dir:
            raise ValueError("scan_workspace_dir must not be empty")
        if enabled_tools is None:
            raise ValueError("enabled_tools must not be empty")

        scan_config = enabled_tools

        # Skip the scan if no target domain was provided
        if not target_name:
            logger.warning("No target domain provided, skipping subdomain discovery")
            return _empty_result(scan_id, '', scan_workspace_dir)

        # ==================== Check the Target type ====================
        # Subdomain discovery only makes sense for DOMAIN targets; skip IP and CIDR
        from apps.targets.services import TargetService
        from apps.targets.models import Target

        target_service = TargetService()
        target = target_service.get_target(target_id)

        if target and target.type != Target.TargetType.DOMAIN:
            logger.info(
                "Skipping subdomain discovery: Target type is %s (ID=%d, Name=%s); subdomain discovery only applies to domain targets",
                target.type, target_id, target_name
            )
            return _empty_result(scan_id, target_name, scan_workspace_dir)

        # Import the task functions
        from apps.scan.tasks.subdomain_discovery import (
            run_subdomain_discovery_task,
            merge_and_validate_task,
            save_domains_task
        )

        # Step 0: preparation
        from apps.scan.utils import setup_scan_directory
        result_dir = setup_scan_directory(scan_workspace_dir, 'subdomain_discovery')

        # Validate and normalize the target domain
        try:
            domain_name = _validate_and_normalize_target(target_name)
        except ValueError as e:
            logger.warning("Invalid target domain, skipping subdomain discovery: %s", e)
            return _empty_result(scan_id, target_name, scan_workspace_dir)

        # Log after successful validation
        logger.info(
            "="*60 + "\n" +
            "Starting subdomain discovery scan\n" +
            f"  Scan ID: {scan_id}\n" +
            f"  Domain: {domain_name}\n" +
            f"  Workspace: {scan_workspace_dir}\n" +
            "="*60
        )

        # Parse the configuration
        passive_tools = scan_config.get('passive_tools', {})
        bruteforce_config = scan_config.get('bruteforce', {})
        permutation_config = scan_config.get('permutation', {})
        resolve_config = scan_config.get('resolve', {})

        # Filter the enabled passive tools
        enabled_passive_tools = {
            k: v for k, v in passive_tools.items()
            if v.get('enabled', True)
        }

        executed_tasks = []
        all_result_files = []
        failed_tools = []
        successful_tool_names = []

        # ==================== Stage 1: passive collection (parallel) ====================
        logger.info("=" * 40)
        logger.info("Stage 1: passive collection (parallel)")
        logger.info("=" * 40)

        if enabled_passive_tools:
            logger.info("Enabled tools: %s", ', '.join(enabled_passive_tools.keys()))
            result_files, stage1_failed, stage1_success = _run_scans_parallel(
                enabled_tools=enabled_passive_tools,
                domain_name=domain_name,
                result_dir=result_dir
            )
            all_result_files.extend(result_files)
            failed_tools.extend(stage1_failed)
            successful_tool_names.extend(stage1_success)
            executed_tasks.extend([f'passive ({tool})' for tool in stage1_success])
        else:
            logger.warning("No passive collection tool enabled")

        # Merge the Stage 1 results
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        current_result = str(result_dir / f"subs_passive_{timestamp}.txt")
        if all_result_files:
            current_result = _merge_files(all_result_files, current_result)
            executed_tasks.append('merge_passive')
        else:
            # Create an empty file
            Path(current_result).touch()
            logger.warning("Stage 1 produced no results, creating an empty file")

        # ==================== Stage 2: wordlist bruteforce (optional) ====================
        bruteforce_enabled = bruteforce_config.get('enabled', False)
        if bruteforce_enabled:
            logger.info("=" * 40)
            logger.info("Stage 2: wordlist bruteforce")
            logger.info("=" * 40)

            bruteforce_tool_config = bruteforce_config.get('subdomain_bruteforce', {})
            wordlist_name = bruteforce_tool_config.get('wordlist_name', 'dns_wordlist.txt')

            try:
                # Make sure the wordlist exists locally (with hash verification)
                local_wordlist_path = ensure_wordlist_local(wordlist_name)

                # Fetch the wordlist record to compute the timeout
                wordlist_service = WordlistService()
                wordlist = wordlist_service.get_wordlist_by_name(wordlist_name)

                timeout_value = bruteforce_tool_config.get('timeout', 3600)
                if timeout_value == 'auto' and wordlist:
                    line_count = getattr(wordlist, 'line_count', None)
                    if line_count is None:
                        try:
                            with open(local_wordlist_path, 'rb') as f:
                                line_count = sum(1 for _ in f)
                        except OSError:
                            line_count = 0

                    try:
                        line_count_int = int(line_count)
                    except (TypeError, ValueError):
                        line_count_int = 0

                    timeout_value = line_count_int * 3 if line_count_int > 0 else 3600
                    bruteforce_tool_config = {
                        **bruteforce_tool_config,
                        'timeout': timeout_value,
                    }
                    logger.info(
                        "subdomain_bruteforce using automatic timeout: %s s (wordlist lines=%s, 3s/line)",
                        timeout_value,
                        line_count_int,
                    )

                brute_output = str(result_dir / f"subs_brute_{timestamp}.txt")
                brute_result = _run_single_tool(
                    tool_name='subdomain_bruteforce',
                    tool_config=bruteforce_tool_config,
                    command_params={
                        'domain': domain_name,
                        'wordlist': local_wordlist_path,
                        'output_file': brute_output
                    },
                    result_dir=result_dir
                )

                if brute_result:
                    # Merge Stage 1 + Stage 2
                    current_result = _merge_files(
                        [current_result, brute_result],
                        str(result_dir / f"subs_merged_{timestamp}.txt")
                    )
                    successful_tool_names.append('subdomain_bruteforce')
                    executed_tasks.append('bruteforce')
                else:
                    failed_tools.append({'tool': 'subdomain_bruteforce', 'reason': 'execution failed'})
            except Exception as exc:
                logger.warning("Wordlist preparation failed, skipping bruteforce: %s", exc)
                failed_tools.append({'tool': 'subdomain_bruteforce', 'reason': str(exc)})

        # ==================== Stage 3: permutation generation + validation (optional) ====================
        permutation_enabled = permutation_config.get('enabled', False)
        if permutation_enabled:
            logger.info("=" * 40)
            logger.info("Stage 3: permutation generation + liveness validation (streaming pipeline)")
            logger.info("=" * 40)

            permutation_tool_config = permutation_config.get('subdomain_permutation_resolve', {})

            # === Step 3.1: wildcard-DNS sampling check ===
            # Generate a permutation sample 100× the input file and check whether
            # the resolved results exceed 50× the input
            before_count = _count_lines(current_result)

            # Configuration
            SAMPLE_MULTIPLIER = 100   # sample size = input × 100
            EXPANSION_THRESHOLD = 50  # expansion threshold = input × 50
            SAMPLE_TIMEOUT = 7200     # sampling timeout: 2 hours

            sample_size = before_count * SAMPLE_MULTIPLIER
            max_allowed = before_count * EXPANSION_THRESHOLD

            sample_output = str(result_dir / f"subs_permuted_sample_{timestamp}.txt")
            sample_cmd = (
                f"cat {current_result} | dnsgen - | head -n {sample_size} | "
                f"puredns resolve -r /app/backend/resources/resolvers.txt "
                f"--write {sample_output} --wildcard-tests 50 --wildcard-batch 1000000 --quiet"
            )

            logger.info(
                f"Wildcard sampling check: input {before_count} domains, "
                f"sampling {sample_size}, threshold {max_allowed}"
            )

            try:
                subprocess.run(
                    sample_cmd,
                    shell=True,
                    timeout=SAMPLE_TIMEOUT,
                    check=False,
                    capture_output=True
                )
                sample_result_count = _count_lines(sample_output) if Path(sample_output).exists() else 0

                logger.info(
                    f"Sampling result: {sample_result_count} domains alive "
                    f"(input: {before_count}, threshold: {max_allowed})"
                )

                if sample_result_count > max_allowed:
                    # The sample exceeded the threshold: wildcard DNS detected, skip the full permutation
                    ratio = sample_result_count / before_count if before_count > 0 else sample_result_count
                    logger.warning(
                        f"Skipping permutation: sampling detected wildcard DNS "
                        f"({sample_result_count} > {max_allowed}, expansion {ratio:.1f}x)"
                    )
                    failed_tools.append({
                        'tool': 'subdomain_permutation_resolve',
                        'reason': f"sampling detected wildcard DNS (expansion {ratio:.1f}x)"
                    })
                else:
                    # === Step 3.2: sampling passed, run the full permutation ===
                    logger.info("Sampling check passed, running the full permutation...")

                    permuted_output = str(result_dir / f"subs_permuted_{timestamp}.txt")

                    permuted_result = _run_single_tool(
                        tool_name='subdomain_permutation_resolve',
                        tool_config=permutation_tool_config,
                        command_params={
                            'input_file': current_result,
                            'output_file': permuted_output,
                        },
                        result_dir=result_dir
                    )

                    if permuted_result:
                        # Merge the original results with the validated permutations
                        current_result = _merge_files(
                            [current_result, permuted_result],
                            str(result_dir / f"subs_with_permuted_{timestamp}.txt")
                        )
                        successful_tool_names.append('subdomain_permutation_resolve')
                        executed_tasks.append('permutation')
                    else:
                        failed_tools.append({'tool': 'subdomain_permutation_resolve', 'reason': 'execution failed'})

            except subprocess.TimeoutExpired:
                logger.warning(f"Sampling check timed out ({SAMPLE_TIMEOUT}s), skipping permutation")
                failed_tools.append({'tool': 'subdomain_permutation_resolve', 'reason': 'sampling check timed out'})
            except Exception as e:
                logger.warning(f"Sampling check failed: {e}, skipping permutation")
                failed_tools.append({'tool': 'subdomain_permutation_resolve', 'reason': f'sampling check failed: {e}'})

        # ==================== Stage 4: DNS liveness validation (optional) ====================
        # Runs whenever resolve.enabled is true, regardless of Stage 3; validates
        # every current candidate subdomain against DNS in one pass
        resolve_enabled = resolve_config.get('enabled', False)
        if resolve_enabled:
            logger.info("=" * 40)
            logger.info("Stage 4: DNS liveness validation")
            logger.info("=" * 40)

            resolve_tool_config = resolve_config.get('subdomain_resolve', {})

            # Compute the timeout dynamically from the current candidate count (supports timeout: auto)
            timeout_value = resolve_tool_config.get('timeout', 3600)
            if timeout_value == 'auto':
                line_count = 0
                try:
                    with open(current_result, 'rb') as f:
                        line_count = sum(1 for _ in f)
                except OSError:
                    line_count = 0

                try:
                    line_count_int = int(line_count)
                except (TypeError, ValueError):
                    line_count_int = 0

                timeout_value = line_count_int * 3 if line_count_int > 0 else 3600
                resolve_tool_config = {
                    **resolve_tool_config,
                    'timeout': timeout_value,
                }
                logger.info(
                    "subdomain_resolve using automatic timeout: %s s (candidates=%s, 3s/domain)",
                    timeout_value,
                    line_count_int,
                )

            alive_output = str(result_dir / f"subs_alive_{timestamp}.txt")

            alive_result = _run_single_tool(
                tool_name='subdomain_resolve',
                tool_config=resolve_tool_config,
                command_params={
                    'input_file': current_result,
                    'output_file': alive_output,
                },
                result_dir=result_dir
            )

            if alive_result:
                current_result = alive_result
                successful_tool_names.append('subdomain_resolve')
                executed_tasks.append('resolve')
            else:
                failed_tools.append({'tool': 'subdomain_resolve', 'reason': 'execution failed'})

        # ==================== Final: save to the database ====================
        logger.info("=" * 40)
        logger.info("Final: saving to the database")
        logger.info("=" * 40)

        # Final validation and save
        final_file = merge_and_validate_task(
            result_files=[current_result],
            result_dir=str(result_dir)
        )

        save_result = save_domains_task(
            domains_file=final_file,
            scan_id=scan_id,
            target_id=target_id
        )
        processed_domains = save_result.get('processed_records', 0)
        executed_tasks.append('save_domains')

        logger.info("="*60 + "\n✓ Subdomain discovery scan completed\n" + "="*60)

        return {
            'success': True,
            'scan_id': scan_id,
            'target': domain_name,
            'scan_workspace_dir': scan_workspace_dir,
            'total': processed_domains,
            'executed_tasks': executed_tasks,
            'tool_stats': {
                'total': len(enabled_passive_tools) + (1 if bruteforce_enabled else 0) +
                         (1 if permutation_enabled else 0) + (1 if resolve_enabled else 0),
                'successful': len(successful_tool_names),
                'failed': len(failed_tools),
                'successful_tools': successful_tool_names,
                'failed_tools': failed_tools
            }
        }

    except ValueError as e:
        logger.error("Configuration error: %s", e)
        raise
    except RuntimeError as e:
        logger.error("Runtime error: %s", e)
        raise
    except Exception as e:
        logger.exception("Subdomain discovery scan failed: %s", e)
        raise


def _empty_result(scan_id: int, target: str, scan_workspace_dir: str) -> dict:
    """Return an empty result."""
    return {
        'success': True,
        'scan_id': scan_id,
        'target': target,
        'scan_workspace_dir': scan_workspace_dir,
        'total': 0,
        'executed_tasks': [],
        'tool_stats': {
            'total': 0,
            'successful': 0,
            'failed': 0,
            'successful_tools': [],
            'failed_tools': []
        }
    }
@@ -1,238 +0,0 @@
from apps.common.prefect_django_setup import setup_django_for_prefect

import logging
from datetime import datetime
from pathlib import Path
from typing import Dict

from prefect import flow

from apps.scan.handlers.scan_flow_handlers import (
    on_scan_flow_running,
    on_scan_flow_completed,
    on_scan_flow_failed,
)
from apps.scan.utils import build_scan_command, ensure_nuclei_templates_local
from apps.scan.tasks.vuln_scan import (
    export_endpoints_task,
    run_vuln_tool_task,
    run_and_stream_save_dalfox_vulns_task,
    run_and_stream_save_nuclei_vulns_task,
)
from .utils import calculate_timeout_by_line_count


logger = logging.getLogger(__name__)


@flow(
    name="endpoints_vuln_scan_flow",
    log_prints=True,
)
def endpoints_vuln_scan_flow(
    scan_id: int,
    target_name: str,
    target_id: int,
    scan_workspace_dir: str,
    enabled_tools: Dict[str, dict],
) -> dict:
    """Endpoint-based vulnerability scan flow (submits Dalfox/Nuclei tasks in parallel)."""
    try:
        if scan_id is None:
            raise ValueError("scan_id must not be empty")
        if not target_name:
            raise ValueError("target_name must not be empty")
        if target_id is None:
            raise ValueError("target_id must not be empty")
        if not scan_workspace_dir:
            raise ValueError("scan_workspace_dir must not be empty")
        if not enabled_tools:
            raise ValueError("enabled_tools must not be empty")

        from apps.scan.utils import setup_scan_directory
        vuln_scan_dir = setup_scan_directory(scan_workspace_dir, 'vuln_scan')
        endpoints_file = vuln_scan_dir / "input_endpoints.txt"

        # Step 1: export the endpoint URLs
        export_result = export_endpoints_task(
            target_id=target_id,
            output_file=str(endpoints_file),
        )
        total_endpoints = export_result.get("total_count", 0)

        if total_endpoints == 0 or not endpoints_file.exists() or endpoints_file.stat().st_size == 0:
            logger.warning("Target has no usable endpoints; skipping vulnerability scan")
            return {
                "success": True,
                "scan_id": scan_id,
                "target": target_name,
                "scan_workspace_dir": scan_workspace_dir,
                "endpoints_file": str(endpoints_file),
                "endpoint_count": 0,
                "executed_tools": [],
                "tool_results": {},
            }

        logger.info("Endpoint export finished: %d entries; starting vulnerability scan", total_endpoints)

        tool_results: Dict[str, dict] = {}

        # Step 2: run each vulnerability scan tool in parallel (currently mainly Dalfox)
        # 1) submit a Prefect task per tool so the worker can schedule them concurrently
        # 2) then collect each result and assemble tool_results
        tool_futures: Dict[str, dict] = {}

        for tool_name, tool_config in enabled_tools.items():
            # Nuclei needs its local templates to exist first (multiple template repos supported)
            template_args = ""
            if tool_name == "nuclei":
                repo_names = tool_config.get("template_repo_names")
                if not repo_names or not isinstance(repo_names, (list, tuple)):
                    logger.error("Nuclei config is missing template_repo_names (array); skipping")
                    continue
                template_paths = []
                try:
                    for repo_name in repo_names:
                        path = ensure_nuclei_templates_local(repo_name)
                        template_paths.append(path)
                        logger.info("Nuclei template path [%s]: %s", repo_name, path)
                except Exception as e:
                    logger.error("Failed to fetch Nuclei templates: %s; skipping nuclei scan", e)
                    continue
                template_args = " ".join(f"-t {p}" for p in template_paths)

            # Build the command parameters
            command_params = {"endpoints_file": str(endpoints_file)}
            if template_args:
                command_params["template_args"] = template_args

            command = build_scan_command(
                tool_name=tool_name,
                scan_type="vuln_scan",
                command_params=command_params,
                tool_config=tool_config,
            )

            raw_timeout = tool_config.get("timeout", 600)

            if isinstance(raw_timeout, str) and raw_timeout == "auto":
                # With timeout=auto, derive the timeout from the line count of endpoints_file.
                # Dalfox: 100 seconds per line; Nuclei: 30 seconds per line.
                base_per_time = 30 if tool_name == "nuclei" else 100
                timeout = calculate_timeout_by_line_count(
                    tool_config=tool_config,
                    file_path=str(endpoints_file),
                    base_per_time=base_per_time,
                )
            else:
                try:
                    timeout = int(raw_timeout)
                except (TypeError, ValueError) as e:
                    raise ValueError(
                        f"Invalid timeout config for tool {tool_name}: {raw_timeout!r}"
                    ) from e

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            log_file = vuln_scan_dir / f"{tool_name}_{timestamp}.log"

            # Dalfox XSS uses a streaming task: findings are parsed and saved as they arrive
            if tool_name == "dalfox_xss":
                logger.info("Submitted vulnerability scan tool %s (streaming result save)", tool_name)
                future = run_and_stream_save_dalfox_vulns_task.submit(
                    cmd=command,
                    tool_name=tool_name,
                    scan_id=scan_id,
                    target_id=target_id,
                    cwd=str(vuln_scan_dir),
                    shell=True,
                    batch_size=1,
                    timeout=timeout,
                    log_file=str(log_file),
                )

                tool_futures[tool_name] = {
                    "future": future,
                    "command": command,
                    "timeout": timeout,
                    "log_file": str(log_file),
                    "mode": "streaming",
                }
            elif tool_name == "nuclei":
                # Nuclei also uses a streaming task
                logger.info("Submitted vulnerability scan tool %s (streaming result save)", tool_name)
                future = run_and_stream_save_nuclei_vulns_task.submit(
                    cmd=command,
                    tool_name=tool_name,
                    scan_id=scan_id,
                    target_id=target_id,
                    cwd=str(vuln_scan_dir),
                    shell=True,
                    batch_size=1,
                    timeout=timeout,
                    log_file=str(log_file),
                )

                tool_futures[tool_name] = {
                    "future": future,
                    "command": command,
                    "timeout": timeout,
                    "log_file": str(log_file),
                    "mode": "streaming",
                }
            else:
                # Other tools still use the non-streaming execution path
                logger.info("Submitted vulnerability scan tool %s", tool_name)
                future = run_vuln_tool_task.submit(
                    tool_name=tool_name,
                    command=command,
                    timeout=timeout,
                    log_file=str(log_file),
                )

                tool_futures[tool_name] = {
                    "future": future,
                    "command": command,
                    "timeout": timeout,
                    "log_file": str(log_file),
                    "mode": "normal",
                }

        # Collect the results of all tools
        for tool_name, meta in tool_futures.items():
            future = meta["future"]
            result = future.result()

            if meta["mode"] == "streaming":
                tool_results[tool_name] = {
                    "command": meta["command"],
                    "timeout": meta["timeout"],
                    "processed_records": result.get("processed_records"),
                    "created_vulns": result.get("created_vulns"),
                    "command_log_file": meta["log_file"],
                }
            else:
                tool_results[tool_name] = {
                    "command": meta["command"],
                    "timeout": meta["timeout"],
                    "duration": result.get("duration"),
                    "returncode": result.get("returncode"),
                    "command_log_file": result.get("command_log_file"),
                }

        return {
            "success": True,
            "scan_id": scan_id,
            "target": target_name,
            "scan_workspace_dir": scan_workspace_dir,
            "endpoints_file": str(endpoints_file),
            "endpoint_count": total_endpoints,
            "executed_tools": list(enabled_tools.keys()),
            "tool_results": tool_results,
        }

    except Exception as e:
        logger.exception("Endpoint vulnerability scan failed: %s", e)
        raise
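The submit-then-collect pattern used above is the standard way to fan tasks out in Prefect: submit everything first, block on results only afterwards. A minimal self-contained sketch of the same idea (the task body and tool names here are illustrative, not part of this repo):

from prefect import flow, task

@task
def run_tool(name: str) -> dict:
    # Placeholder for a real tool invocation
    return {"tool": name, "returncode": 0}

@flow
def fan_out_demo(tools: list[str]) -> dict:
    # 1) submit every task up front so the worker can run them concurrently
    futures = {name: run_tool.submit(name) for name in tools}
    # 2) only then block on the results, one by one
    return {name: fut.result() for name, fut in futures.items()}

if __name__ == "__main__":
    print(fan_out_demo(["dalfox_xss", "nuclei"]))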
@@ -1,107 +0,0 @@
from apps.common.prefect_django_setup import setup_django_for_prefect

import logging
from typing import Dict, Tuple

from prefect import flow

from apps.scan.handlers.scan_flow_handlers import (
    on_scan_flow_running,
    on_scan_flow_completed,
    on_scan_flow_failed,
)
from apps.scan.configs.command_templates import get_command_template
from .endpoints_vuln_scan_flow import endpoints_vuln_scan_flow


logger = logging.getLogger(__name__)


def _classify_vuln_tools(enabled_tools: Dict[str, dict]) -> Tuple[Dict[str, dict], Dict[str, dict]]:
    """Classify vulnerability scan tools by the input_type in their command template.

    Currently supported:
    - endpoints_file: takes a file listing endpoints as input (e.g. Dalfox XSS)
    Reserved:
    - other input_type values are sorted into other_tools and ignored for now.
    """
    endpoints_tools: Dict[str, dict] = {}
    other_tools: Dict[str, dict] = {}

    for tool_name, tool_config in enabled_tools.items():
        template = get_command_template("vuln_scan", tool_name) or {}
        input_type = template.get("input_type", "endpoints_file")

        if input_type == "endpoints_file":
            endpoints_tools[tool_name] = tool_config
        else:
            other_tools[tool_name] = tool_config

    return endpoints_tools, other_tools


@flow(
    name="vuln_scan",
    log_prints=True,
    on_running=[on_scan_flow_running],
    on_completion=[on_scan_flow_completed],
    on_failure=[on_scan_flow_failed],
)
def vuln_scan_flow(
    scan_id: int,
    target_name: str,
    target_id: int,
    scan_workspace_dir: str,
    enabled_tools: Dict[str, dict],
) -> dict:
    """Main vulnerability scan flow: orchestrates the vuln scan sub-flows serially.

    Supported tools:
    - dalfox_xss: XSS scanning (streaming save)
    - nuclei: general vulnerability scanning (streaming save, supports template commit hash sync)
    """
    try:
        if scan_id is None:
            raise ValueError("scan_id must not be empty")
        if not target_name:
            raise ValueError("target_name must not be empty")
        if target_id is None:
            raise ValueError("target_id must not be empty")
        if not scan_workspace_dir:
            raise ValueError("scan_workspace_dir must not be empty")
        if not enabled_tools:
            raise ValueError("enabled_tools must not be empty")

        # Step 1: classify the tools
        endpoints_tools, other_tools = _classify_vuln_tools(enabled_tools)

        logger.info(
            "Vuln scan tool classification - endpoints_file: %s, other: %s",
            list(endpoints_tools.keys()) or "none",
            list(other_tools.keys()) or "none",
        )

        if other_tools:
            logger.warning(
                "Some vuln scan tools have unsupported input types and will be ignored: %s",
                list(other_tools.keys()),
            )

        if not endpoints_tools:
            raise ValueError(
                "Vulnerability scanning requires at least one enabled tool that takes "
                "endpoints_file as input (e.g. dalfox_xss, nuclei)."
            )

        # Step 2: run the endpoint vuln scan sub-flow (serially)
        endpoint_result = endpoints_vuln_scan_flow(
            scan_id=scan_id,
            target_name=target_name,
            target_id=target_id,
            scan_workspace_dir=scan_workspace_dir,
            enabled_tools=endpoints_tools,
        )

        # There is only one sub-flow for now; return its result directly
        return endpoint_result

    except Exception as e:
        logger.exception("Main vulnerability scan flow failed: %s", e)
        raise
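To make the classification concrete, here is roughly what _classify_vuln_tools yields for a mixed config. The sqlmap entry is hypothetical and assumes its command template declares an input_type other than endpoints_file; the actual split depends entirely on the templates in apps.scan.configs.command_templates:

enabled_tools = {
    "dalfox_xss": {"timeout": "auto"},
    "nuclei": {"timeout": "auto", "template_repo_names": ["nuclei-templates"]},
    "sqlmap": {"timeout": 1200},  # hypothetical tool with a different input_type
}

endpoints_tools, other_tools = _classify_vuln_tools(enabled_tools)
# endpoints_tools -> {"dalfox_xss": ..., "nuclei": ...}  (submitted to the sub-flow)
# other_tools     -> {"sqlmap": ...}                     (logged and ignored for now)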
@@ -1,182 +0,0 @@
"""
Scan flow handlers.

Handle state changes and notifications for scan flows (port scan, subdomain discovery, etc.)

Responsibilities:
- Update each stage's progress state (running/completed/failed)
- Send notifications for scan stages
- Record flow performance metrics
"""

import logging
from prefect import Flow
from prefect.client.schemas import FlowRun, State

from apps.scan.utils.performance import FlowPerformanceTracker


logger = logging.getLogger(__name__)

# Performance tracker for each flow_run
_flow_trackers: dict[str, FlowPerformanceTracker] = {}


def _get_stage_from_flow_name(flow_name: str) -> str | None:
    """
    Map a flow name to its stage.

    The flow name is used directly as the stage (matching the engine_config keys).
    The main flow (initiate_scan) is excluded.
    """
    # Exclude the main flow; it is not a stage flow
    if flow_name == 'initiate_scan':
        return None
    return flow_name


def on_scan_flow_running(flow: Flow, flow_run: FlowRun, state: State) -> None:
    """
    Callback fired when a scan flow starts running.

    Responsibilities:
    - Update the stage progress to running
    - Send a scan-started notification
    - Start performance tracking

    Args:
        flow: Prefect Flow object
        flow_run: the flow run instance
        state: current flow state
    """
    logger.info("🚀 Scan flow started - Flow: %s, Run ID: %s", flow.name, flow_run.id)

    # Extract the flow parameters
    flow_params = flow_run.parameters or {}
    scan_id = flow_params.get('scan_id')
    target_name = flow_params.get('target_name', 'unknown')
    target_id = flow_params.get('target_id')

    # Start performance tracking
    if scan_id:
        tracker = FlowPerformanceTracker(flow.name, scan_id)
        tracker.start(target_id=target_id, target_name=target_name)
        _flow_trackers[str(flow_run.id)] = tracker

    # Update the stage progress
    stage = _get_stage_from_flow_name(flow.name)
    if scan_id and stage:
        try:
            from apps.scan.services import ScanService
            service = ScanService()
            service.start_stage(scan_id, stage)
            logger.info(f"✓ Stage progress set to running - Scan ID: {scan_id}, Stage: {stage}")
        except Exception as e:
            logger.error(f"Failed to update stage progress - Scan ID: {scan_id}, Stage: {stage}: {e}")


def on_scan_flow_completed(flow: Flow, flow_run: FlowRun, state: State) -> None:
    """
    Callback fired when a scan flow completes.

    Responsibilities:
    - Update the stage progress to completed
    - Send a scan-completed notification (optional)
    - Record performance metrics

    Args:
        flow: Prefect Flow object
        flow_run: the flow run instance
        state: current flow state
    """
    logger.info("✅ Scan flow completed - Flow: %s, Run ID: %s", flow.name, flow_run.id)

    # Extract the flow parameters
    flow_params = flow_run.parameters or {}
    scan_id = flow_params.get('scan_id')

    # Fetch the flow result
    result = None
    try:
        result = state.result() if state.result else None
    except Exception:
        pass

    # Record performance metrics
    tracker = _flow_trackers.pop(str(flow_run.id), None)
    if tracker:
        tracker.finish(success=True)

    # Update the stage progress
    stage = _get_stage_from_flow_name(flow.name)
    if scan_id and stage:
        try:
            from apps.scan.services import ScanService
            service = ScanService()
            # Extract detail from the flow result (if any)
            detail = None
            if isinstance(result, dict):
                detail = result.get('detail')
            service.complete_stage(scan_id, stage, detail)
            logger.info(f"✓ Stage progress set to completed - Scan ID: {scan_id}, Stage: {stage}")
            # Refresh cached stats after each stage so the frontend sees increments in real time
            try:
                service.update_cached_stats(scan_id)
                logger.info("✓ Cached stats refreshed after stage completion - Scan ID: %s", scan_id)
            except Exception as e:
                logger.error("Failed to refresh cached stats after stage completion - Scan ID: %s, error: %s", scan_id, e)
        except Exception as e:
            logger.error(f"Failed to update stage progress - Scan ID: {scan_id}, Stage: {stage}: {e}")


def on_scan_flow_failed(flow: Flow, flow_run: FlowRun, state: State) -> None:
    """
    Callback fired when a scan flow fails.

    Responsibilities:
    - Update the stage progress to failed
    - Send a scan-failed notification
    - Record performance metrics (including the error message)

    Args:
        flow: Prefect Flow object
        flow_run: the flow run instance
        state: current flow state
    """
    logger.info("❌ Scan flow failed - Flow: %s, Run ID: %s", flow.name, flow_run.id)

    # Extract the flow parameters
    flow_params = flow_run.parameters or {}
    scan_id = flow_params.get('scan_id')
    target_name = flow_params.get('target_name', 'unknown')

    # Extract the error message
    error_message = str(state.message) if state.message else "unknown error"

    # Record performance metrics (failure case)
    tracker = _flow_trackers.pop(str(flow_run.id), None)
    if tracker:
        tracker.finish(success=False, error_message=error_message)

    # Update the stage progress
    stage = _get_stage_from_flow_name(flow.name)
    if scan_id and stage:
        try:
            from apps.scan.services import ScanService
            service = ScanService()
            service.fail_stage(scan_id, stage, error_message)
            logger.info(f"✓ Stage progress set to failed - Scan ID: {scan_id}, Stage: {stage}")
        except Exception as e:
            logger.error(f"Failed to update stage progress - Scan ID: {scan_id}, Stage: {stage}: {e}")

    # Send the notification
    try:
        from apps.scan.notifications import create_notification, NotificationLevel
        message = f"Task: {flow.name}\nStatus: failed\nError: {error_message}"
        create_notification(
            title=target_name,
            message=message,
            level=NotificationLevel.HIGH
        )
        logger.error(f"✓ Scan failure notification sent - Target: {target_name}, Flow: {flow.name}, Error: {error_message}")
    except Exception as e:
        logger.error(f"Failed to send scan failure notification - Flow: {flow.name}: {e}")
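Any stage flow opts into this progress tracking just by registering the three hooks; the flow name doubles as the stage key. A minimal sketch mirroring how vuln_scan_flow wires them up (the "port_scan" stage and return payload here are illustrative):

from prefect import flow

from apps.scan.handlers.scan_flow_handlers import (
    on_scan_flow_running,
    on_scan_flow_completed,
    on_scan_flow_failed,
)

@flow(
    name="port_scan",  # hypothetical stage name; must match an engine_config key
    on_running=[on_scan_flow_running],
    on_completion=[on_scan_flow_completed],
    on_failure=[on_scan_flow_failed],
)
def port_scan_flow(scan_id: int, target_name: str, target_id: int) -> dict:
    # The 'detail' key, if present, is picked up by on_scan_flow_completed
    return {"success": True, "detail": "0 open ports"}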
@@ -1,180 +0,0 @@
from django.db import models
from django.contrib.postgres.fields import ArrayField

from ..common.definitions import ScanStatus


class SoftDeleteManager(models.Manager):
    """Soft-delete manager: returns only non-deleted records by default."""

    def get_queryset(self):
        return super().get_queryset().filter(deleted_at__isnull=True)


class Scan(models.Model):
    """Scan task model."""

    id = models.AutoField(primary_key=True)

    target = models.ForeignKey('targets.Target', on_delete=models.CASCADE, related_name='scans', help_text='Scan target')

    engine = models.ForeignKey(
        'engine.ScanEngine',
        on_delete=models.CASCADE,
        related_name='scans',
        help_text='Scan engine used'
    )

    created_at = models.DateTimeField(auto_now_add=True, help_text='Task creation time')
    stopped_at = models.DateTimeField(null=True, blank=True, help_text='Scan end time')

    status = models.CharField(
        max_length=20,
        choices=ScanStatus.choices,
        default=ScanStatus.INITIATED,
        db_index=True,
        help_text='Task status'
    )

    results_dir = models.CharField(max_length=100, blank=True, default='', help_text='Results directory')

    container_ids = ArrayField(
        models.CharField(max_length=100),
        blank=True,
        default=list,
        help_text='Container ID list (Docker container IDs)'
    )

    worker = models.ForeignKey(
        'engine.WorkerNode',
        on_delete=models.SET_NULL,
        related_name='scans',
        null=True,
        blank=True,
        help_text='Worker node executing the scan'
    )

    error_message = models.CharField(max_length=2000, blank=True, default='', help_text='Error message')

    # ==================== Soft-delete field ====================
    deleted_at = models.DateTimeField(null=True, blank=True, db_index=True, help_text='Deletion time (NULL means not deleted)')

    # ==================== Managers ====================
    objects = SoftDeleteManager()  # default manager: only non-deleted records
    all_objects = models.Manager()  # full manager: includes deleted records (for hard deletes)

    # ==================== Progress tracking fields ====================
    progress = models.IntegerField(default=0, help_text='Scan progress 0-100')
    current_stage = models.CharField(max_length=50, blank=True, default='', help_text='Current scan stage')
    stage_progress = models.JSONField(default=dict, help_text='Per-stage progress details')

    # ==================== Cached statistics fields ====================
    cached_subdomains_count = models.IntegerField(default=0, help_text='Cached subdomain count')
    cached_websites_count = models.IntegerField(default=0, help_text='Cached website count')
    cached_endpoints_count = models.IntegerField(default=0, help_text='Cached endpoint count')
    cached_ips_count = models.IntegerField(default=0, help_text='Cached IP address count')
    cached_directories_count = models.IntegerField(default=0, help_text='Cached directory count')
    cached_vulns_total = models.IntegerField(default=0, help_text='Cached total vulnerability count')
    cached_vulns_critical = models.IntegerField(default=0, help_text='Cached critical vulnerability count')
    cached_vulns_high = models.IntegerField(default=0, help_text='Cached high-severity vulnerability count')
    cached_vulns_medium = models.IntegerField(default=0, help_text='Cached medium-severity vulnerability count')
    cached_vulns_low = models.IntegerField(default=0, help_text='Cached low-severity vulnerability count')
    stats_updated_at = models.DateTimeField(null=True, blank=True, help_text='When the stats were last updated')

    class Meta:
        db_table = 'scan'
        verbose_name = 'Scan task'
        verbose_name_plural = 'Scan tasks'
        ordering = ['-created_at']
        indexes = [
            models.Index(fields=['-created_at']),  # optimizes the default descending sort of list queries
            models.Index(fields=['target']),  # optimizes looking up scans by target
            models.Index(fields=['deleted_at', '-created_at']),  # soft delete + time index
        ]

    def __str__(self):
        return f"Scan #{self.id} - {self.target.name}"


class ScheduledScan(models.Model):
    """
    Scheduled scan task model.

    Scheduling mechanism:
    - APScheduler checks next_run_time every minute
    - Due tasks are dispatched to workers via task_distributor
    - Cron expressions are supported for flexible scheduling

    Scan modes (mutually exclusive):
    - Organization scan: set organization; all its targets are resolved at run time
    - Target scan: set target to scan a single target
    - organization takes precedence over target
    """

    id = models.AutoField(primary_key=True)

    # Basic info
    name = models.CharField(max_length=200, help_text='Task name')

    # Associated scan engine
    engine = models.ForeignKey(
        'engine.ScanEngine',
        on_delete=models.CASCADE,
        related_name='scheduled_scans',
        help_text='Scan engine used'
    )

    # Associated organization (organization mode: its targets are resolved at run time)
    organization = models.ForeignKey(
        'targets.Organization',
        on_delete=models.CASCADE,
        related_name='scheduled_scans',
        null=True,
        blank=True,
        help_text='Organization to scan (when set, all its targets are resolved at run time)'
    )

    # Associated target (target mode: scan a single target)
    target = models.ForeignKey(
        'targets.Target',
        on_delete=models.CASCADE,
        related_name='scheduled_scans',
        null=True,
        blank=True,
        help_text='Single target to scan (mutually exclusive with organization)'
    )

    # Schedule config - plain cron expression
    cron_expression = models.CharField(
        max_length=100,
        default='0 2 * * *',
        help_text='Cron expression, format: minute hour day month weekday'
    )

    # State
    is_enabled = models.BooleanField(default=True, db_index=True, help_text='Whether enabled')

    # Execution statistics
    run_count = models.IntegerField(default=0, help_text='Number of runs so far')
    last_run_time = models.DateTimeField(null=True, blank=True, help_text='Last run time')
    next_run_time = models.DateTimeField(null=True, blank=True, help_text='Next run time')

    # Timestamps
    created_at = models.DateTimeField(auto_now_add=True, help_text='Creation time')
    updated_at = models.DateTimeField(auto_now=True, help_text='Update time')

    class Meta:
        db_table = 'scheduled_scan'
        verbose_name = 'Scheduled scan task'
        verbose_name_plural = 'Scheduled scan tasks'
        ordering = ['-created_at']
        indexes = [
            models.Index(fields=['-created_at']),
            models.Index(fields=['is_enabled', '-created_at']),
            models.Index(fields=['name']),  # optimizes name search
        ]

    def __str__(self):
        return f"ScheduledScan #{self.id} - {self.name}"
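The scheduler has to turn cron_expression into next_run_time before APScheduler's minutely check can fire anything. A minimal sketch of that computation using the croniter library; croniter itself is an assumption here, the repo may derive next_run_time another way:

from datetime import datetime, timezone

from croniter import croniter  # assumed third-party dependency

def compute_next_run(cron_expression: str, base: datetime | None = None) -> datetime:
    """Return the next fire time after `base` for a 5-field cron expression."""
    base = base or datetime.now(timezone.utc)
    return croniter(cron_expression, base).get_next(datetime)

# The model default '0 2 * * *' fires daily at 02:00
print(compute_next_run('0 2 * * *'))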
@@ -1,189 +0,0 @@
#!/usr/bin/env python
"""
Scan task launcher script.

Executed when a dynamic scan container starts.
The configuration must be fetched and environment variables set before Django is imported.
"""
import argparse
import sys
import os
import traceback


def diagnose_prefect_environment():
    """Diagnose the Prefect runtime environment, printing details for troubleshooting."""
    print("\n" + "="*60)
    print("Prefect environment diagnostics")
    print("="*60)

    # 1. Check Prefect-related environment variables
    print("\n[diag] Prefect environment variables:")
    prefect_vars = [
        'PREFECT_HOME',
        'PREFECT_API_URL',
        'PREFECT_SERVER_EPHEMERAL_ENABLED',
        'PREFECT_SERVER_EPHEMERAL_STARTUP_TIMEOUT_SECONDS',
        'PREFECT_SERVER_DATABASE_CONNECTION_URL',
        'PREFECT_LOGGING_LEVEL',
        'PREFECT_DEBUG_MODE',
    ]
    for var in prefect_vars:
        value = os.environ.get(var, 'NOT SET')
        print(f"  {var}={value}")

    # 2. Check the PREFECT_HOME directory
    prefect_home = os.environ.get('PREFECT_HOME', os.path.expanduser('~/.prefect'))
    print(f"\n[diag] PREFECT_HOME directory: {prefect_home}")
    if os.path.exists(prefect_home):
        print(f"  ✓ directory exists")
        print(f"  writable: {os.access(prefect_home, os.W_OK)}")
        try:
            files = os.listdir(prefect_home)
            print(f"  files: {files[:10]}{'...' if len(files) > 10 else ''}")
        except Exception as e:
            print(f"  ✗ cannot list files: {e}")
    else:
        print(f"  directory missing, trying to create it...")
        try:
            os.makedirs(prefect_home, exist_ok=True)
            print(f"  ✓ created")
        except Exception as e:
            print(f"  ✗ creation failed: {e}")

    # 3. Check uvicorn availability
    print(f"\n[diag] uvicorn availability:")
    import shutil
    uvicorn_path = shutil.which('uvicorn')
    if uvicorn_path:
        print(f"  ✓ uvicorn path: {uvicorn_path}")
    else:
        print(f"  ✗ uvicorn not on PATH")
        print(f"  PATH: {os.environ.get('PATH', 'NOT SET')}")

    # 4. Check the Prefect version
    print(f"\n[diag] Prefect version:")
    try:
        import prefect
        print(f"  ✓ prefect=={prefect.__version__}")
    except Exception as e:
        print(f"  ✗ cannot import prefect: {e}")

    # 5. Check SQLite support
    print(f"\n[diag] SQLite support:")
    try:
        import sqlite3
        print(f"  ✓ sqlite3 version: {sqlite3.sqlite_version}")
        # Try creating a database
        test_db = os.path.join(prefect_home, 'test.db')
        conn = sqlite3.connect(test_db)
        conn.execute('CREATE TABLE IF NOT EXISTS test (id INTEGER)')
        conn.close()
        os.remove(test_db)
        print(f"  ✓ SQLite read/write test passed")
    except Exception as e:
        print(f"  ✗ SQLite test failed: {e}")

    # 6. Check port binding capability
    print(f"\n[diag] Port binding test:")
    try:
        import socket
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.bind(('127.0.0.1', 0))
        port = sock.getsockname()[1]
        sock.close()
        print(f"  ✓ can bind 127.0.0.1 (test port: {port})")
    except Exception as e:
        print(f"  ✗ port binding failed: {e}")

    # 7. Check memory
    print(f"\n[diag] System resources:")
    try:
        import psutil
        mem = psutil.virtual_memory()
        print(f"  total memory: {mem.total / 1024 / 1024:.0f} MB")
        print(f"  available memory: {mem.available / 1024 / 1024:.0f} MB")
        print(f"  memory usage: {mem.percent}%")
    except ImportError:
        print(f"  psutil not installed, skipping memory check")
    except Exception as e:
        print(f"  ✗ resource check failed: {e}")

    print("\n" + "="*60)
    print("Diagnostics finished")
    print("="*60 + "\n")


def main():
    print("="*60)
    print("run_initiate_scan.py starting")
    print(f"  Python: {sys.version}")
    print(f"  CWD: {os.getcwd()}")
    print(f"  SERVER_URL: {os.environ.get('SERVER_URL', 'NOT SET')}")
    print("="*60)

    # 1. Fetch the config from the config center and initialize Django (must happen before any Django import)
    print("[1/4] Fetching config from the config center...")
    try:
        from apps.common.container_bootstrap import fetch_config_and_setup_django
        fetch_config_and_setup_django()
        print("[1/4] ✓ config fetched")
    except Exception as e:
        print(f"[1/4] ✗ config fetch failed: {e}")
        traceback.print_exc()
        sys.exit(1)

    # 2. Parse command-line arguments
    print("[2/4] Parsing command-line arguments...")
    parser = argparse.ArgumentParser(description="Run the scan initialization flow")
    parser.add_argument("--scan_id", type=int, required=True, help="scan task ID")
    parser.add_argument("--target_name", type=str, required=True, help="target name")
    parser.add_argument("--target_id", type=int, required=True, help="target ID")
    parser.add_argument("--scan_workspace_dir", type=str, required=True, help="scan workspace directory")
    parser.add_argument("--engine_name", type=str, required=True, help="engine name")
    parser.add_argument("--scheduled_scan_name", type=str, default=None, help="scheduled scan task name (optional)")

    args = parser.parse_args()
    print(f"[2/4] ✓ arguments parsed:")
    print(f"  scan_id: {args.scan_id}")
    print(f"  target_name: {args.target_name}")
    print(f"  target_id: {args.target_id}")
    print(f"  scan_workspace_dir: {args.scan_workspace_dir}")
    print(f"  engine_name: {args.engine_name}")
    print(f"  scheduled_scan_name: {args.scheduled_scan_name}")

    # 2.5. Run the Prefect environment diagnostics (DEBUG mode only)
    if os.environ.get('DEBUG', '').lower() == 'true':
        diagnose_prefect_environment()

    # 3. It is now safe to import Django-dependent modules
    print("[3/4] Importing initiate_scan_flow...")
    try:
        from apps.scan.flows.initiate_scan_flow import initiate_scan_flow
        print("[3/4] ✓ import succeeded")
    except Exception as e:
        print(f"[3/4] ✗ import failed: {e}")
        traceback.print_exc()
        sys.exit(1)

    # 4. Run the flow
    print("[4/4] Running initiate_scan_flow...")
    try:
        result = initiate_scan_flow(
            scan_id=args.scan_id,
            target_name=args.target_name,
            target_id=args.target_id,
            scan_workspace_dir=args.scan_workspace_dir,
            engine_name=args.engine_name,
            scheduled_scan_name=args.scheduled_scan_name,
        )
        print("[4/4] ✓ flow finished")
        print(f"Result: {result}")
    except Exception as e:
        print(f"[4/4] ✗ flow failed: {e}")
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
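A worker presumably launches this script inside the scan container with these flags. A hedged sketch of such an invocation; the paths, values, and the SERVER_URL endpoint are illustrative, and the real launch code lives in the worker/container layer:

import os
import subprocess

cmd = [
    "python", "run_initiate_scan.py",
    "--scan_id", "42",
    "--target_name", "example.com",
    "--target_id", "7",
    "--scan_workspace_dir", "/data/scans/42",   # hypothetical path
    "--engine_name", "default",
]
# SERVER_URL must point at the config center before the script imports Django
subprocess.run(cmd, check=True,
               env={**os.environ, "SERVER_URL": "http://server:8000", "DEBUG": "false"})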
@@ -1,245 +0,0 @@
from rest_framework import serializers
from django.db.models import Count

from .models import Scan, ScheduledScan


class ScanSerializer(serializers.ModelSerializer):
    """Scan task serializer."""
    target_name = serializers.SerializerMethodField()
    engine_name = serializers.SerializerMethodField()

    class Meta:
        model = Scan
        fields = [
            'id', 'target', 'target_name', 'engine', 'engine_name',
            'created_at', 'stopped_at', 'status', 'results_dir',
            'container_ids', 'error_message'
        ]
        read_only_fields = [
            'id', 'created_at', 'stopped_at', 'results_dir',
            'container_ids', 'error_message', 'status'
        ]

    def get_target_name(self, obj):
        """Return the target name."""
        return obj.target.name if obj.target else None

    def get_engine_name(self, obj):
        """Return the engine name."""
        return obj.engine.name if obj.engine else None


class ScanHistorySerializer(serializers.ModelSerializer):
    """Serializer dedicated to the scan history list.

    Provides an optimized payload for the scan history page, including:
    - scan summary statistics (subdomain, endpoint, and vulnerability counts)
    - progress percentage and current stage
    """

    # Field mappings
    target_name = serializers.CharField(source='target.name', read_only=True)
    engine_name = serializers.CharField(source='engine.name', read_only=True)

    # Computed fields
    summary = serializers.SerializerMethodField()

    # Progress tracking fields (read straight from the model)
    progress = serializers.IntegerField(read_only=True)
    current_stage = serializers.CharField(read_only=True)
    stage_progress = serializers.JSONField(read_only=True)

    class Meta:
        model = Scan
        fields = [
            'id', 'target', 'target_name', 'engine', 'engine_name',
            'created_at', 'status', 'error_message', 'summary', 'progress',
            'current_stage', 'stage_progress'
        ]

    def get_summary(self, obj):
        """Build the scan summary.

        Design principles:
        - Subdomains/websites/endpoints/IPs/directories use cached fields (no live COUNT)
        - Vulnerability stats use cached fields on Scan, aggregated once when the scan finishes
        """
        # 1. Base stats from the cached fields (subdomains, websites, endpoints, IPs, directories)
        summary = {
            'subdomains': obj.cached_subdomains_count or 0,
            'websites': obj.cached_websites_count or 0,
            'endpoints': obj.cached_endpoints_count or 0,
            'ips': obj.cached_ips_count or 0,
            'directories': obj.cached_directories_count or 0,
        }

        # 2. Cached vulnerability stats on the Scan model (aggregated by severity)
        summary['vulnerabilities'] = {
            'total': obj.cached_vulns_total or 0,
            'critical': obj.cached_vulns_critical or 0,
            'high': obj.cached_vulns_high or 0,
            'medium': obj.cached_vulns_medium or 0,
            'low': obj.cached_vulns_low or 0,
        }

        return summary


class QuickScanSerializer(serializers.Serializer):
    """
    Quick scan serializer.

    Features:
    - Accepts a target list and an engine config
    - Creates/fetches the targets automatically
    - Starts the scan immediately
    """

    # Upper bound on batch creation
    MAX_BATCH_SIZE = 1000

    # Target list
    targets = serializers.ListField(
        child=serializers.DictField(),
        help_text='Target list; each target must contain a name field'
    )

    # Scan engine ID
    engine_id = serializers.IntegerField(
        required=True,
        help_text='Scan engine ID (required)'
    )

    def validate_targets(self, value):
        """Validate the target list."""
        if not value:
            raise serializers.ValidationError("Target list must not be empty")

        # Enforce the size limit to avoid overloading the server
        if len(value) > self.MAX_BATCH_SIZE:
            raise serializers.ValidationError(
                f"Quick scan supports at most {self.MAX_BATCH_SIZE} targets; {len(value)} were submitted"
            )

        # Validate each target's required field
        for idx, target in enumerate(value):
            if 'name' not in target:
                raise serializers.ValidationError(f"Target #{idx + 1} is missing the name field")
            if not target['name']:
                raise serializers.ValidationError(f"Target #{idx + 1} has an empty name")

        return value


# ==================== Scheduled scan serializers ====================

class ScheduledScanSerializer(serializers.ModelSerializer):
    """Scheduled scan task serializer (for list and detail views)."""

    # Related fields
    engine_name = serializers.CharField(source='engine.name', read_only=True)
    organization_id = serializers.IntegerField(source='organization.id', read_only=True, allow_null=True)
    organization_name = serializers.CharField(source='organization.name', read_only=True, allow_null=True)
    target_id = serializers.IntegerField(source='target.id', read_only=True, allow_null=True)
    target_name = serializers.CharField(source='target.name', read_only=True, allow_null=True)
    scan_mode = serializers.SerializerMethodField()

    class Meta:
        model = ScheduledScan
        fields = [
            'id', 'name',
            'engine', 'engine_name',
            'organization_id', 'organization_name',
            'target_id', 'target_name',
            'scan_mode',
            'cron_expression',
            'is_enabled',
            'run_count', 'last_run_time', 'next_run_time',
            'created_at', 'updated_at'
        ]
        read_only_fields = [
            'id', 'run_count',
            'last_run_time', 'next_run_time',
            'created_at', 'updated_at'
        ]

    def get_scan_mode(self, obj):
        """Return the scan mode: organization or target."""
        return 'organization' if obj.organization_id else 'target'


class CreateScheduledScanSerializer(serializers.Serializer):
    """Serializer for creating a scheduled scan task.

    Scan modes (mutually exclusive):
    - Organization scan: provide organization_id; all its targets are resolved at run time
    - Target scan: provide target_id to scan a single target
    """

    name = serializers.CharField(max_length=200, help_text='Task name')
    engine_id = serializers.IntegerField(help_text='Scan engine ID')

    # Organization mode
    organization_id = serializers.IntegerField(
        required=False,
        allow_null=True,
        help_text='Organization ID (organization mode: its targets are resolved at run time)'
    )

    # Target mode
    target_id = serializers.IntegerField(
        required=False,
        allow_null=True,
        help_text='Target ID (target mode: scan a single target)'
    )

    cron_expression = serializers.CharField(
        max_length=100,
        default='0 2 * * *',
        help_text='Cron expression, format: minute hour day month weekday'
    )
    is_enabled = serializers.BooleanField(default=True, help_text='Whether to enable immediately')

    def validate(self, data):
        """Validate that organization_id and target_id are mutually exclusive."""
        organization_id = data.get('organization_id')
        target_id = data.get('target_id')

        if not organization_id and not target_id:
            raise serializers.ValidationError('Either organization_id or target_id must be provided')

        if organization_id and target_id:
            raise serializers.ValidationError('Only one of organization_id and target_id may be provided')

        return data


class UpdateScheduledScanSerializer(serializers.Serializer):
    """Serializer for updating a scheduled scan task."""

    name = serializers.CharField(max_length=200, required=False, help_text='Task name')
    engine_id = serializers.IntegerField(required=False, help_text='Scan engine ID')

    # Organization mode
    organization_id = serializers.IntegerField(
        required=False,
        allow_null=True,
        help_text='Organization ID (setting it clears target_id)'
    )

    # Target mode
    target_id = serializers.IntegerField(
        required=False,
        allow_null=True,
        help_text='Target ID (setting it clears organization_id)'
    )

    cron_expression = serializers.CharField(max_length=100, required=False, help_text='Cron expression')
    is_enabled = serializers.BooleanField(required=False, help_text='Whether enabled')


class ToggleScheduledScanSerializer(serializers.Serializer):
    """Serializer for toggling a scheduled scan's enabled state."""

    is_enabled = serializers.BooleanField(help_text='Whether enabled')
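A quick sketch of how a view would run the quick-scan validation above; the payload values are illustrative:

payload = {
    "targets": [{"name": "example.com"}, {"name": "10.0.0.0/24"}],
    "engine_id": 1,
}

serializer = QuickScanSerializer(data=payload)
serializer.is_valid(raise_exception=True)   # enforces MAX_BATCH_SIZE and per-target name checks
print(serializer.validated_data["targets"])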
@@ -1,85 +0,0 @@
"""
Blacklist filtering service.

Filters sensitive domains (e.g. .gov, .edu, .mil).

This version uses default rules; loading rules configured from the frontend will come later.
"""

from typing import List, Optional
from django.db.models import QuerySet
import re
import logging

logger = logging.getLogger(__name__)


class BlacklistService:
    """
    Blacklist filtering service - filters sensitive domains.

    TODO: later versions will load blacklist rules from the frontend config
    - users configure blacklist URLs, domains, and IPs when starting a scan
    - blacklist rules are stored in the database, linked to a Scan or Engine
    """

    # Default blacklist regex rules
    DEFAULT_PATTERNS = [
        r'\.gov$',            # ends with .gov
        r'\.gov\.[a-z]{2}$',  # .gov.cn, .gov.uk, etc.
        r'\.edu$',            # ends with .edu
        r'\.edu\.[a-z]{2}$',  # .edu.cn, etc.
        r'\.mil$',            # ends with .mil
    ]

    def __init__(self, patterns: Optional[List[str]] = None):
        """
        Initialize the blacklist service.

        Args:
            patterns: list of regex patterns; None uses the default rules
        """
        self.patterns = patterns or self.DEFAULT_PATTERNS
        self._compiled_patterns = [re.compile(p) for p in self.patterns]

    def filter_queryset(
        self,
        queryset: QuerySet,
        url_field: str = 'url'
    ) -> QuerySet:
        """
        Filter a queryset at the database level.

        Uses PostgreSQL regular expressions to exclude blacklisted URLs.

        Args:
            queryset: the original queryset
            url_field: name of the URL field

        Returns:
            QuerySet: the filtered queryset
        """
        for pattern in self.patterns:
            queryset = queryset.exclude(**{f'{url_field}__regex': pattern})
        return queryset

    def filter_url(self, url: str) -> bool:
        """
        Check whether a single URL passes the blacklist filter.

        Args:
            url: the URL to check

        Returns:
            bool: True means it passes (not blacklisted); False means it is filtered out
        """
        for pattern in self._compiled_patterns:
            if pattern.search(url):
                return False
        return True

    # TODO: to be implemented in a later version
    # @classmethod
    # def from_scan(cls, scan_id: int) -> 'BlacklistService':
    #     """Load the blacklist rules configured for a scan from the database."""
    #     pass
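A few illustrative checks against the default rules. Note that the patterns are end-anchored ($), so they match values that end in the sensitive suffix (bare hosts, or URLs without a trailing path):

svc = BlacklistService()

assert svc.filter_url("https://example.com") is True          # no pattern matches: passes
assert svc.filter_url("https://portal.example.gov") is False  # matches \.gov$: filtered
assert svc.filter_url("https://www.moe.edu.cn") is False      # matches \.edu\.[a-z]{2}$: filtered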
@@ -1,295 +0,0 @@
"""
Quick scan service.

Parses user input (URL, domain, IP, CIDR) and creates the corresponding asset data.
"""

import logging
from dataclasses import dataclass
from typing import Optional, Literal, List, Dict, Any
from urllib.parse import urlparse

from django.db import transaction

from apps.common.validators import validate_url, detect_input_type, validate_domain, validate_ip, validate_cidr, is_valid_ip
from apps.targets.services.target_service import TargetService
from apps.targets.models import Target
from apps.asset.dtos import WebSiteDTO
from apps.asset.dtos.asset import EndpointDTO
from apps.asset.repositories.asset.website_repository import DjangoWebSiteRepository
from apps.asset.repositories.asset.endpoint_repository import DjangoEndpointRepository

logger = logging.getLogger(__name__)


@dataclass
class ParsedInputDTO:
    """
    Parsed-input DTO.

    Used only by the quick scan workflow.
    """
    original_input: str
    input_type: Literal['url', 'domain', 'ip', 'cidr']
    target_name: str  # host/domain/ip/cidr
    target_type: Literal['domain', 'ip', 'cidr']
    website_url: Optional[str] = None  # root URL (scheme://host[:port])
    endpoint_url: Optional[str] = None  # full URL (with path)
    is_valid: bool = True
    error: Optional[str] = None
    line_number: Optional[int] = None


class QuickScanService:
    """Quick scan service - parses input and creates assets."""

    def __init__(self):
        self.target_service = TargetService()
        self.website_repo = DjangoWebSiteRepository()
        self.endpoint_repo = DjangoEndpointRepository()

    def parse_inputs(self, inputs: List[str]) -> List[ParsedInputDTO]:
        """
        Parse multi-line input.

        Args:
            inputs: list of input strings (one per line)

        Returns:
            list of parse results (blank lines are skipped)
        """
        results = []
        for line_number, input_str in enumerate(inputs, start=1):
            input_str = input_str.strip()

            # Skip blank lines
            if not input_str:
                continue

            try:
                # Detect the input type
                input_type = detect_input_type(input_str)

                if input_type == 'url':
                    dto = self._parse_url_input(input_str, line_number)
                else:
                    dto = self._parse_target_input(input_str, input_type, line_number)

                results.append(dto)
            except ValueError as e:
                # Parsing failed; record the error
                results.append(ParsedInputDTO(
                    original_input=input_str,
                    input_type='domain',  # default type
                    target_name=input_str,
                    target_type='domain',
                    is_valid=False,
                    error=str(e),
                    line_number=line_number
                ))

        return results

    def _parse_url_input(self, url_str: str, line_number: int) -> ParsedInputDTO:
        """
        Parse a URL input.

        Args:
            url_str: the URL string
            line_number: line number

        Returns:
            ParsedInputDTO
        """
        # Validate the URL format
        validate_url(url_str)

        # Parse with the standard library
        parsed = urlparse(url_str)

        host = parsed.hostname  # without port
        has_path = parsed.path and parsed.path != '/'

        # Build root_url: scheme://host[:port]
        root_url = f"{parsed.scheme}://{parsed.netloc}"

        # Detect the host type (domain or ip)
        target_type = 'ip' if is_valid_ip(host) else 'domain'

        return ParsedInputDTO(
            original_input=url_str,
            input_type='url',
            target_name=host,
            target_type=target_type,
            website_url=root_url,
            endpoint_url=url_str if has_path else None,
            line_number=line_number
        )

    def _parse_target_input(
        self,
        input_str: str,
        input_type: str,
        line_number: int
    ) -> ParsedInputDTO:
        """
        Parse a non-URL input (domain/ip/cidr).

        Args:
            input_str: the input string
            input_type: the input type
            line_number: line number

        Returns:
            ParsedInputDTO
        """
        # Validate the format
        if input_type == 'domain':
            validate_domain(input_str)
            target_type = 'domain'
        elif input_type == 'ip':
            validate_ip(input_str)
            target_type = 'ip'
        elif input_type == 'cidr':
            validate_cidr(input_str)
            target_type = 'cidr'
        else:
            raise ValueError(f"Unknown input type: {input_type}")

        return ParsedInputDTO(
            original_input=input_str,
            input_type=input_type,
            target_name=input_str,
            target_type=target_type,
            website_url=None,
            endpoint_url=None,
            line_number=line_number
        )

    @transaction.atomic
    def process_quick_scan(
        self,
        inputs: List[str],
        engine_id: int
    ) -> Dict[str, Any]:
        """
        Handle a quick scan request.

        Args:
            inputs: list of input strings
            engine_id: scan engine ID

        Returns:
            result dict
        """
        # 1. Parse the input
        parsed_inputs = self.parse_inputs(inputs)

        # Split valid and invalid input
        valid_inputs = [p for p in parsed_inputs if p.is_valid]
        invalid_inputs = [p for p in parsed_inputs if not p.is_valid]

        if not valid_inputs:
            return {
                'targets': [],
                'target_stats': {'created': 0, 'reused': 0, 'failed': len(invalid_inputs)},
                'asset_stats': {'websites_created': 0, 'endpoints_created': 0},
                'errors': [
                    {'line_number': p.line_number, 'input': p.original_input, 'error': p.error}
                    for p in invalid_inputs
                ]
            }

        # 2. Create the assets
        asset_result = self.create_assets_from_parsed_inputs(valid_inputs)

        # 3. Return the result
        return {
            'targets': asset_result['targets'],
            'target_stats': asset_result['target_stats'],
            'asset_stats': asset_result['asset_stats'],
            'errors': [
                {'line_number': p.line_number, 'input': p.original_input, 'error': p.error}
                for p in invalid_inputs
            ]
        }

    def create_assets_from_parsed_inputs(
        self,
        parsed_inputs: List[ParsedInputDTO]
    ) -> Dict[str, Any]:
        """
        Create assets from parse results.

        Args:
            parsed_inputs: list of parse results (valid inputs only)

        Returns:
            creation result dict
        """
        # 1. Collect all target data (in memory, deduplicated)
        targets_data = {}
        for dto in parsed_inputs:
            if dto.target_name not in targets_data:
                targets_data[dto.target_name] = {'name': dto.target_name, 'type': dto.target_type}

        targets_list = list(targets_data.values())

        # 2. Bulk-create the targets (reusing the existing method)
        target_result = self.target_service.batch_create_targets(targets_list)

        # 3. Query the targets just created and build a name → id map
        target_names = [d['name'] for d in targets_list]
        targets = Target.objects.filter(name__in=target_names)
        target_id_map = {t.name: t.id for t in targets}

        # 4. Collect the Website DTOs (in memory, deduplicated)
        website_dtos = []
        seen_websites = set()
        for dto in parsed_inputs:
            if dto.website_url and dto.website_url not in seen_websites:
                seen_websites.add(dto.website_url)
                target_id = target_id_map.get(dto.target_name)
                if target_id:
                    website_dtos.append(WebSiteDTO(
                        target_id=target_id,
                        url=dto.website_url,
                        host=dto.target_name
                    ))

        # 5. Bulk-create the websites (skip existing)
        websites_created = 0
        if website_dtos:
            websites_created = self.website_repo.bulk_create_ignore_conflicts(website_dtos)

        # 6. Collect the Endpoint DTOs (in memory, deduplicated)
        endpoint_dtos = []
        seen_endpoints = set()
        for dto in parsed_inputs:
            if dto.endpoint_url and dto.endpoint_url not in seen_endpoints:
                seen_endpoints.add(dto.endpoint_url)
                target_id = target_id_map.get(dto.target_name)
                if target_id:
                    endpoint_dtos.append(EndpointDTO(
                        target_id=target_id,
                        url=dto.endpoint_url,
                        host=dto.target_name
                    ))

        # 7. Bulk-create the endpoints (skip existing)
        endpoints_created = 0
        if endpoint_dtos:
            endpoints_created = self.endpoint_repo.bulk_create_ignore_conflicts(endpoint_dtos)

        return {
            'targets': list(targets),
            'target_stats': {
                'created': target_result['created_count'],
                'reused': 0,  # bulk_create cannot distinguish new from reused
                'failed': target_result['failed_count']
            },
            'asset_stats': {
                'websites_created': websites_created,
                'endpoints_created': endpoints_created
            }
        }
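To make the parsing rules concrete, here is roughly what parse_inputs yields for a mixed batch; the field values shown follow directly from the code above:

svc = QuickScanService()
results = svc.parse_inputs([
    "https://app.example.com/login",  # url: website_url + endpoint_url (has a path)
    "example.org",                    # domain: target only
    "10.0.0.0/24",                    # cidr: target only
    "",                               # blank line: skipped
])

for r in results:
    print(r.input_type, r.target_name, r.website_url, r.endpoint_url)
# url    app.example.com  https://app.example.com  https://app.example.com/login
# domain example.org      None                     None
# cidr   10.0.0.0/24      None                     None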
@@ -1,238 +0,0 @@
"""
Scan task service.

Owns all business logic for the Scan model.
"""

from __future__ import annotations

import logging
import uuid
from typing import Dict, List, TYPE_CHECKING
from datetime import datetime
from pathlib import Path
from django.conf import settings
from django.db import transaction
from django.db.utils import DatabaseError, IntegrityError, OperationalError
from django.core.exceptions import ValidationError, ObjectDoesNotExist

from apps.scan.models import Scan
from apps.scan.repositories import DjangoScanRepository
from apps.targets.repositories import DjangoTargetRepository, DjangoOrganizationRepository
from apps.engine.repositories import DjangoEngineRepository
from apps.targets.models import Target
from apps.engine.models import ScanEngine
from apps.common.definitions import ScanStatus

logger = logging.getLogger(__name__)


class ScanService:
    """
    Scan task service (coordinator).

    Responsibilities:
    - Coordinate the sub-services
    - Provide a unified public interface
    - Preserve backward compatibility

    Note:
    - The concrete business logic has been split into sub-services
    - This class mostly delegates and coordinates
    """

    # Final statuses: once set, they should not be overwritten
    FINAL_STATUSES = {
        ScanStatus.COMPLETED,
        ScanStatus.FAILED,
        ScanStatus.CANCELLED
    }

    def __init__(self):
        """Initialize the service."""
        # Initialize the sub-services
        from apps.scan.services.scan_creation_service import ScanCreationService
        from apps.scan.services.scan_state_service import ScanStateService
        from apps.scan.services.scan_control_service import ScanControlService
        from apps.scan.services.scan_stats_service import ScanStatsService

        self.creation_service = ScanCreationService()
        self.state_service = ScanStateService()
        self.control_service = ScanControlService()
        self.stats_service = ScanStatsService()

        # Keep the ScanRepository (used by get_scan)
        self.scan_repo = DjangoScanRepository()

    def get_scan(self, scan_id: int, prefetch_relations: bool) -> Scan | None:
        """
        Fetch a scan task (with related objects).

        Preloads engine and target automatically to avoid N+1 queries.

        Args:
            scan_id: scan task ID

        Returns:
            Scan object (with engine and target) or None
        """
        return self.scan_repo.get_by_id(scan_id, prefetch_relations)

    def get_all_scans(self, prefetch_relations: bool = True):
        return self.scan_repo.get_all(prefetch_relations=prefetch_relations)

    def prepare_initiate_scan(
        self,
        organization_id: int | None = None,
        target_id: int | None = None,
        engine_id: int | None = None
    ) -> tuple[List[Target], ScanEngine]:
        """
        Prepare for creating scan tasks; returns the required targets and the scan engine.
        """
        return self.creation_service.prepare_initiate_scan(
            organization_id, target_id, engine_id
        )

    def create_scans(
        self,
        targets: List[Target],
        engine: ScanEngine,
        scheduled_scan_name: str | None = None
    ) -> List[Scan]:
        """Bulk-create scan tasks (delegates to ScanCreationService)."""
        return self.creation_service.create_scans(targets, engine, scheduled_scan_name)

    # ==================== State management (delegates to ScanStateService) ====================

    def update_status(
        self,
        scan_id: int,
        status: ScanStatus,
        error_message: str | None = None,
        stopped_at: datetime | None = None
    ) -> bool:
        """Update a scan's status (delegates to ScanStateService)."""
        return self.state_service.update_status(
            scan_id, status, error_message, stopped_at
        )

    def update_status_if_match(
        self,
        scan_id: int,
        current_status: ScanStatus,
        new_status: ScanStatus,
        stopped_at: datetime | None = None
    ) -> bool:
        """Conditionally update a scan's status (delegates to ScanStateService)."""
        return self.state_service.update_status_if_match(
            scan_id, current_status, new_status, stopped_at
        )

    def update_cached_stats(self, scan_id: int) -> dict | None:
        """Update the cached statistics (delegates to ScanStateService); returns the stats dict."""
        return self.state_service.update_cached_stats(scan_id)

    # ==================== Progress tracking (delegates to ScanStateService) ====================

    def init_stage_progress(self, scan_id: int, stages: list[str]) -> bool:
        """Initialize the stage progress (delegates to ScanStateService)."""
        return self.state_service.init_stage_progress(scan_id, stages)

    def start_stage(self, scan_id: int, stage: str) -> bool:
        """Start a stage (delegates to ScanStateService)."""
        return self.state_service.start_stage(scan_id, stage)

    def complete_stage(self, scan_id: int, stage: str, detail: str | None = None) -> bool:
        """Complete a stage (delegates to ScanStateService)."""
        return self.state_service.complete_stage(scan_id, stage, detail)

    def fail_stage(self, scan_id: int, stage: str, error: str | None = None) -> bool:
        """Mark a stage as failed (delegates to ScanStateService)."""
        return self.state_service.fail_stage(scan_id, stage, error)

    def cancel_running_stages(self, scan_id: int, final_status: str = "cancelled") -> bool:
        """Cancel all running stages (delegates to ScanStateService)."""
        return self.state_service.cancel_running_stages(scan_id, final_status)

    # TODO: not wired in yet
    def add_command_to_scan(self, scan_id: int, stage_name: str, tool_name: str, command: str) -> bool:
        """
        Incrementally append a command to a given scan stage.

        Args:
            scan_id: scan task ID
            stage_name: stage name (e.g. 'subdomain_discovery', 'port_scan')
            tool_name: tool name
            command: the executed command

        Returns:
            bool: whether the command was added
        """
        try:
            scan = self.get_scan(scan_id, prefetch_relations=False)
            if not scan:
                logger.error(f"Scan task does not exist: {scan_id}")
                return False

            stage_progress = scan.stage_progress or {}

            # Make sure the stage exists
            if stage_name not in stage_progress:
                stage_progress[stage_name] = {'status': 'running', 'commands': []}

            # Make sure the commands list exists
            if 'commands' not in stage_progress[stage_name]:
                stage_progress[stage_name]['commands'] = []

            # Append the command
            command_entry = f"{tool_name}: {command}"
            stage_progress[stage_name]['commands'].append(command_entry)

            scan.stage_progress = stage_progress
            scan.save(update_fields=['stage_progress'])

            command_count = len(stage_progress[stage_name]['commands'])
            logger.info(f"✓ Command recorded: {stage_name}.{tool_name} (total: {command_count})")
            return True

        except Exception as e:
            logger.error(f"Failed to record command: {e}")
            return False

    # ==================== Deletion and control (delegates to ScanControlService) ====================

    def delete_scans_two_phase(self, scan_ids: List[int]) -> dict:
        """Two-phase deletion of scan tasks (delegates to ScanControlService)."""
        return self.control_service.delete_scans_two_phase(scan_ids)

    def stop_scan(self, scan_id: int) -> tuple[bool, int]:
        """Stop a scan task (delegates to ScanControlService)."""
        return self.control_service.stop_scan(scan_id)

    def hard_delete_scans(self, scan_ids: List[int]) -> tuple[int, Dict[str, int]]:
        """
        Hard-delete scan tasks (actually remove the data).

        Run inside the worker container; removes already soft-deleted scans and their related data.

        Args:
            scan_ids: list of scan task IDs

        Returns:
            (number deleted, details dict)
        """
        return self.scan_repo.hard_delete_by_ids(scan_ids)

    # ==================== Statistics (delegates to ScanStatsService) ====================

    def get_statistics(self) -> dict:
        """Fetch scan statistics (delegates to ScanStatsService)."""
        return self.stats_service.get_statistics()


# Exported interface
__all__ = ['ScanService']
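update_status_if_match is a compare-and-set guard, which is what keeps FINAL_STATUSES from being overwritten by late writers. A sketch of how a caller might use it, assuming ScanStatus also defines a RUNNING member alongside the terminal states shown in the model:

from apps.common.definitions import ScanStatus

service = ScanService()

# Only flip RUNNING -> COMPLETED; if the scan was cancelled in the meantime,
# the guard fails and the terminal CANCELLED status is preserved.
updated = service.update_status_if_match(
    scan_id=42,
    current_status=ScanStatus.RUNNING,  # assumed member
    new_status=ScanStatus.COMPLETED,
)
if not updated:
    print("Scan already reached a final status; leaving it unchanged")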
@@ -1,364 +0,0 @@
"""
Target export service.

Provides unified target extraction and file export, supporting:
- URL export (streaming writes + default-value fallback)
- Domain/IP export (for port scanning)
- Blacklist filtering integration
"""

import ipaddress
import logging
from pathlib import Path
from typing import Dict, Any, Optional, Iterator

from django.db.models import QuerySet

from .blacklist_service import BlacklistService

logger = logging.getLogger(__name__)


class TargetExportService:
    """
    Target export service - unified target extraction and file export.

    Usage:
        # The task layer decides the data source
        queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)

        # Use the export service
        blacklist_service = BlacklistService()
        export_service = TargetExportService(blacklist_service=blacklist_service)
        result = export_service.export_urls(target_id, output_path, queryset)
    """

    def __init__(self, blacklist_service: Optional[BlacklistService] = None):
        """
        Initialize the export service.

        Args:
            blacklist_service: blacklist filtering service; None disables filtering
        """
        self.blacklist_service = blacklist_service

    def export_urls(
        self,
        target_id: int,
        output_path: str,
        queryset: QuerySet,
        url_field: str = 'url',
        batch_size: int = 1000
    ) -> Dict[str, Any]:
        """
        Unified URL export.

        Automatically handles both cases:
        - database has data: stream it into the file
        - database is empty: fall back to the default URL generator

        Args:
            target_id: target ID
            output_path: output file path
            queryset: the data source (built by the task layer; should be values_list flat=True)
            url_field: URL field name (for blacklist filtering)
            batch_size: batch size

        Returns:
            dict: {
                'success': bool,
                'output_file': str,
                'total_count': int
            }

        Raises:
            IOError: if the file write fails
        """
        output_file = Path(output_path)
        output_file.parent.mkdir(parents=True, exist_ok=True)

        logger.info("Starting URL export - target_id=%s, output=%s", target_id, output_path)

        # Apply blacklist filtering (database level)
        if self.blacklist_service:
            # Note: this would need the raw queryset, not a values_list;
            # since the task layer passes a values_list, filtering happens below in Python instead
            pass

        total_count = 0
        try:
            with open(output_file, 'w', encoding='utf-8', buffering=8192) as f:
                for url in queryset.iterator(chunk_size=batch_size):
                    if url:
                        # Python-level blacklist filtering
                        if self.blacklist_service and not self.blacklist_service.filter_url(url):
                            continue
                        f.write(f"{url}\n")
                        total_count += 1

                        if total_count % 10000 == 0:
                            logger.info("Exported %d URLs so far...", total_count)
        except IOError as e:
            logger.error("File write failed: %s - %s", output_path, e)
            raise

        # Default-value fallback mode
        if total_count == 0:
            total_count = self._generate_default_urls(target_id, output_file)

        logger.info("✓ URL export finished - count: %d, file: %s", total_count, output_path)

        return {
            'success': True,
            'output_file': str(output_file),
            'total_count': total_count
        }

    def _generate_default_urls(
        self,
        target_id: int,
        output_path: Path
    ) -> int:
        """
        Default URL generator (internal).

        Generates default URLs by target type:
        - DOMAIN: http(s)://domain
        - IP: http(s)://ip
        - CIDR: expanded to http(s)://ip for every host
        - URL: uses the target URL directly

        Args:
            target_id: target ID
            output_path: output file path

        Returns:
            int: total number of URLs written
        """
        from apps.targets.services import TargetService
        from apps.targets.models import Target

        target_service = TargetService()
        target = target_service.get_target(target_id)

        if not target:
            logger.warning("Target ID %d does not exist; cannot generate default URLs", target_id)
            return 0

        target_name = target.name
        target_type = target.type

        logger.info("Lazy-generation mode: target type=%s, name=%s", target_type, target_name)

        total_urls = 0

        with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
            if target_type == Target.TargetType.DOMAIN:
                urls = [f"http://{target_name}", f"https://{target_name}"]
                for url in urls:
                    if self._should_write_url(url):
                        f.write(f"{url}\n")
                        total_urls += 1

            elif target_type == Target.TargetType.IP:
                urls = [f"http://{target_name}", f"https://{target_name}"]
                for url in urls:
                    if self._should_write_url(url):
                        f.write(f"{url}\n")
                        total_urls += 1

            elif target_type == Target.TargetType.CIDR:
                try:
                    network = ipaddress.ip_network(target_name, strict=False)

                    for ip in network.hosts():
                        urls = [f"http://{ip}", f"https://{ip}"]
                        for url in urls:
                            if self._should_write_url(url):
                                f.write(f"{url}\n")
                                total_urls += 1

                        if total_urls % 10000 == 0:
                            logger.info("Generated %d URLs so far...", total_urls)

                    # Special-case /32 or /128 networks (hosts() yields nothing)
                    if total_urls == 0:
                        ip = str(network.network_address)
                        urls = [f"http://{ip}", f"https://{ip}"]
                        for url in urls:
                            if self._should_write_url(url):
                                f.write(f"{url}\n")
                                total_urls += 1

                except ValueError as e:
                    logger.error("CIDR parse failed: %s - %s", target_name, e)
                    raise ValueError(f"Invalid CIDR: {target_name}") from e

            elif target_type == Target.TargetType.URL:
                if self._should_write_url(target_name):
                    f.write(f"{target_name}\n")
                    total_urls = 1
            else:
                logger.warning("Unsupported target type: %s", target_type)

        logger.info("✓ Lazily generated default URLs - count: %d", total_urls)
        return total_urls

    def _should_write_url(self, url: str) -> bool:
        """Check whether a URL should be written (passes the blacklist filter)."""
        if self.blacklist_service:
            return self.blacklist_service.filter_url(url)
        return True

    def export_targets(
        self,
        target_id: int,
        output_path: str,
        batch_size: int = 1000
    ) -> Dict[str, Any]:
        """
        Domain/IP export (for port scanning).

        Chooses the export logic by target type:
        - DOMAIN: stream subdomains from the Subdomain table
        - IP: write the IP address directly
        - CIDR: expand to all host IPs

        Args:
            target_id: target ID
            output_path: output file path
            batch_size: batch size

        Returns:
            dict: {
                'success': bool,
                'output_file': str,
                'total_count': int,
                'target_type': str
            }
        """
        from apps.targets.services import TargetService
        from apps.targets.models import Target
        from apps.asset.services.asset.subdomain_service import SubdomainService

        output_file = Path(output_path)
        output_file.parent.mkdir(parents=True, exist_ok=True)

        # Fetch the target info
        target_service = TargetService()
        target = target_service.get_target(target_id)

        if not target:
            raise ValueError(f"Target ID {target_id} does not exist")

        target_type = target.type
        target_name = target.name

        logger.info(
            "Starting target export - Target ID: %d, Name: %s, Type: %s, output file: %s",
            target_id, target_name, target_type, output_path
        )

        total_count = 0

        if target_type == Target.TargetType.DOMAIN:
            total_count = self._export_domains(target_id, target_name, output_file, batch_size)
            type_desc = "domain"

        elif target_type == Target.TargetType.IP:
            total_count = self._export_ip(target_name, output_file)
|
||||
type_desc = "IP"
|
||||
|
||||
elif target_type == Target.TargetType.CIDR:
|
||||
total_count = self._export_cidr(target_name, output_file)
|
||||
type_desc = "CIDR IP"
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的目标类型: {target_type}")
|
||||
|
||||
logger.info(
|
||||
"✓ 扫描目标导出完成 - 类型: %s, 总数: %d, 文件: %s",
|
||||
type_desc, total_count, output_path
|
||||
)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'output_file': str(output_file),
|
||||
'total_count': total_count,
|
||||
'target_type': target_type
|
||||
}
|
||||
|
||||
def _export_domains(
|
||||
self,
|
||||
target_id: int,
|
||||
target_name: str,
|
||||
output_path: Path,
|
||||
batch_size: int
|
||||
) -> int:
|
||||
"""导出域名类型目标的子域名"""
|
||||
from apps.asset.services.asset.subdomain_service import SubdomainService
|
||||
|
||||
subdomain_service = SubdomainService()
|
||||
domain_iterator = subdomain_service.iter_subdomain_names_by_target(
|
||||
target_id=target_id,
|
||||
chunk_size=batch_size
|
||||
)
|
||||
|
||||
total_count = 0
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
for domain_name in domain_iterator:
|
||||
if self._should_write_target(domain_name):
|
||||
f.write(f"{domain_name}\n")
|
||||
total_count += 1
|
||||
|
||||
if total_count % 10000 == 0:
|
||||
logger.info("已导出 %d 个域名...", total_count)
|
||||
|
||||
# 默认值模式:如果没有子域名,使用根域名
|
||||
if total_count == 0:
|
||||
logger.info("采用默认域名:%s (target_id=%d)", target_name, target_id)
|
||||
if self._should_write_target(target_name):
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
f.write(f"{target_name}\n")
|
||||
total_count = 1
|
||||
|
||||
return total_count
|
||||
|
||||
def _export_ip(self, target_name: str, output_path: Path) -> int:
|
||||
"""导出 IP 类型目标"""
|
||||
if self._should_write_target(target_name):
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
f.write(f"{target_name}\n")
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def _export_cidr(self, target_name: str, output_path: Path) -> int:
|
||||
"""导出 CIDR 类型目标,展开为每个 IP"""
|
||||
network = ipaddress.ip_network(target_name, strict=False)
|
||||
total_count = 0
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
for ip in network.hosts():
|
||||
ip_str = str(ip)
|
||||
if self._should_write_target(ip_str):
|
||||
f.write(f"{ip_str}\n")
|
||||
total_count += 1
|
||||
|
||||
if total_count % 10000 == 0:
|
||||
logger.info("已导出 %d 个 IP...", total_count)
|
||||
|
||||
# /32 或 /128 特殊处理
|
||||
if total_count == 0:
|
||||
ip_str = str(network.network_address)
|
||||
if self._should_write_target(ip_str):
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
f.write(f"{ip_str}\n")
|
||||
total_count = 1
|
||||
|
||||
return total_count
|
||||
|
||||
def _should_write_target(self, target: str) -> bool:
|
||||
"""检查目标是否应该写入(通过黑名单过滤)"""
|
||||
if self.blacklist_service:
|
||||
return self.blacklist_service.filter_url(target)
|
||||
return True
|
||||
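A minimal standalone sketch of the CIDR-expansion behavior the two CIDR branches above rely on; the sample networks are illustrative:

import ipaddress

def expand_cidr(cidr: str) -> list[str]:
    """Expand a CIDR block into host IPs, with a single-address fallback."""
    network = ipaddress.ip_network(cidr, strict=False)
    ips = [str(ip) for ip in network.hosts()]
    # hosts() yields no addresses for /32 (or /128) networks on older Python
    # versions, so fall back to the network address itself - the same special
    # case handled in _generate_default_urls and _export_cidr above.
    return ips or [str(network.network_address)]

print(expand_cidr("10.0.0.0/30"))  # ['10.0.0.1', '10.0.0.2']
print(expand_cidr("10.0.0.5/32"))  # ['10.0.0.5']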
@@ -1,71 +0,0 @@
"""
Task that exports website URLs to a TXT file.

Uses TargetExportService for the shared export logic and default-value fallback.
Data source: WebSite.url
"""
import logging
from prefect import task

from apps.asset.models import WebSite
from apps.scan.services import TargetExportService, BlacklistService

logger = logging.getLogger(__name__)


@task(name="export_sites")
def export_sites_task(
    target_id: int,
    output_file: str,
    batch_size: int = 1000,
) -> dict:
    """
    Export all website URLs under a target to a TXT file.

    Data source: WebSite.url

    Lazy-load mode:
    - If the database is empty, generate default URLs by Target type
    - DOMAIN: http(s)://domain
    - IP: http(s)://ip
    - CIDR: expand to a URL for every IP

    Args:
        target_id: target ID
        output_file: output file path (absolute)
        batch_size: batch size per read, default 1000

    Returns:
        dict: {
            'success': bool,
            'output_file': str,
            'total_count': int
        }

    Raises:
        ValueError: invalid arguments
        IOError: file write failure
    """
    # Build the data-source queryset (the task layer decides the data source)
    queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)

    # Delegate the export to TargetExportService
    blacklist_service = BlacklistService()
    export_service = TargetExportService(blacklist_service=blacklist_service)

    result = export_service.export_urls(
        target_id=target_id,
        output_path=output_file,
        queryset=queryset,
        batch_size=batch_size
    )

    # Keep the return shape unchanged (backward compatible)
    return {
        'success': result['success'],
        'output_file': result['output_file'],
        'total_count': result['total_count']
    }
@@ -1,65 +0,0 @@
"""
URL export task.

Exports URLs under a target to a file before fingerprint identification.
Uses TargetExportService for the shared export logic and default-value fallback.
"""

import logging

from prefect import task

from apps.asset.models import WebSite
from apps.scan.services import TargetExportService, BlacklistService

logger = logging.getLogger(__name__)


@task(name="export_urls_for_fingerprint")
def export_urls_for_fingerprint_task(
    target_id: int,
    output_file: str,
    source: str = 'website',
    batch_size: int = 1000
) -> dict:
    """
    Export URLs under a target to a file (for fingerprint identification).

    Data source: WebSite.url

    Lazy-load mode:
    - If the database is empty, generate default URLs by Target type
    - DOMAIN: http(s)://domain
    - IP: http(s)://ip
    - CIDR: expand to a URL for every IP
    - URL: use the target URL directly

    Args:
        target_id: target ID
        output_file: output file path
        source: data-source type (kept for backward compatibility)
        batch_size: batch read size

    Returns:
        dict: {'output_file': str, 'total_count': int, 'source': str}
    """
    # Build the data-source queryset (the task layer decides the data source)
    queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)

    # Delegate the export to TargetExportService
    blacklist_service = BlacklistService()
    export_service = TargetExportService(blacklist_service=blacklist_service)

    result = export_service.export_urls(
        target_id=target_id,
        output_path=output_file,
        queryset=queryset,
        batch_size=batch_size
    )

    # Keep the return shape unchanged (backward compatible)
    return {
        'output_file': result['output_file'],
        'total_count': result['total_count'],
        'source': source
    }
@@ -1,300 +0,0 @@
"""
xingfinger execution task.

Streams the xingfinger command output and updates the tech field in real time.
"""

import importlib
import json
import logging
import subprocess
from typing import Optional, Generator
from urllib.parse import urlparse

from django.db import connection
from prefect import task

from apps.scan.utils import execute_stream

logger = logging.getLogger(__name__)


# Data-source mapping: source → (module_path, model_name, url_field)
SOURCE_MODEL_MAP = {
    'website': ('apps.asset.models', 'WebSite', 'url'),
    # Future extensions:
    # 'endpoint': ('apps.asset.models', 'Endpoint', 'url'),
    # 'directory': ('apps.asset.models', 'Directory', 'url'),
}


def _get_model_class(source: str):
    """Resolve the Model class for a data-source type."""
    if source not in SOURCE_MODEL_MAP:
        raise ValueError(f"Unsupported data source: {source}")

    module_path, model_name, _ = SOURCE_MODEL_MAP[source]
    module = importlib.import_module(module_path)
    return getattr(module, model_name)


def parse_xingfinger_line(line: str) -> tuple[str, list[str]] | None:
    """
    Parse a single line of xingfinger JSON output.

    xingfinger silent-mode output format:
    {"url": "https://example.com", "cms": "WordPress,PHP,nginx", ...}

    Returns:
        tuple: (url, tech_list), or None when parsing fails
    """
    try:
        item = json.loads(line)
        url = item.get('url', '').strip()
        cms = item.get('cms', '')

        if not url or not cms:
            return None

        # Split the cms field on commas and strip whitespace
        techs = [t.strip() for t in cms.split(',') if t.strip()]

        return (url, techs) if techs else None

    except json.JSONDecodeError:
        return None
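# --- Illustrative example (not part of the original file) ---
# How the parser above behaves on a sample silent-mode line; the sample
# JSON is invented for demonstration:
#
#   sample = '{"url": "https://example.com", "cms": "WordPress, PHP , nginx"}'
#   parse_xingfinger_line(sample)
#   # -> ('https://example.com', ['WordPress', 'PHP', 'nginx'])
#   parse_xingfinger_line("not json")                                    # -> None
#   parse_xingfinger_line('{"url": "https://example.com", "cms": ""}')   # -> None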
def bulk_merge_tech_field(
    source: str,
    url_techs_map: dict[str, list[str]],
    target_id: int
) -> dict:
    """
    Bulk-merge the tech array field (native PostgreSQL SQL).

    Uses native PostgreSQL SQL for an efficient merge-and-deduplicate
    operation on array columns. If no record exists for a URL, a new
    record is created automatically.

    Returns:
        dict: {'updated_count': int, 'created_count': int}
    """
    Model = _get_model_class(source)
    table_name = Model._meta.db_table

    updated_count = 0
    created_count = 0

    with connection.cursor() as cursor:
        for url, techs in url_techs_map.items():
            if not techs:
                continue

            # Try an update first (PostgreSQL array merge + dedup)
            sql = f"""
                UPDATE {table_name}
                SET tech = (
                    SELECT ARRAY(SELECT DISTINCT unnest(
                        COALESCE(tech, ARRAY[]::varchar[]) || %s::varchar[]
                    ))
                )
                WHERE url = %s AND target_id = %s
            """

            cursor.execute(sql, [techs, url, target_id])

            if cursor.rowcount > 0:
                updated_count += cursor.rowcount
            else:
                # No existing record; create one
                try:
                    # Extract the host from the URL
                    parsed = urlparse(url)
                    host = parsed.hostname or ''

                    # Insert a new record (with conflict handling)
                    insert_sql = f"""
                        INSERT INTO {table_name} (target_id, url, host, tech, created_at)
                        VALUES (%s, %s, %s, %s::varchar[], NOW())
                        ON CONFLICT (target_id, url) DO UPDATE SET
                            tech = (
                                SELECT ARRAY(SELECT DISTINCT unnest(
                                    COALESCE({table_name}.tech, ARRAY[]::varchar[]) || EXCLUDED.tech
                                ))
                            )
                    """
                    cursor.execute(insert_sql, [target_id, url, host, techs])
                    created_count += 1

                except Exception as e:
                    logger.warning("Failed to create %s record (url=%s): %s", source, url, e)

    return {
        'updated_count': updated_count,
        'created_count': created_count
    }
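# --- Illustrative example (not part of the original file) ---
# A pure-Python mirror of what the SQL above does to the tech column,
# assuming the same merge-and-deduplicate semantics:
#
#   def merge_tech(existing, new):
#       # COALESCE(tech, ARRAY[]) || new, then SELECT DISTINCT unnest(...);
#       # DISTINCT does not guarantee order, so sorting just makes the
#       # demonstration deterministic.
#       return sorted(set((existing or []) + new))
#
#   merge_tech(["nginx", "PHP"], ["PHP", "WordPress"])
#   # -> ['PHP', 'WordPress', 'nginx']
#   merge_tech(None, ["nginx"])  # -> ['nginx']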
def _parse_xingfinger_stream_output(
    cmd: str,
    tool_name: str,
    cwd: Optional[str] = None,
    timeout: Optional[int] = None,
    log_file: Optional[str] = None
) -> Generator[tuple[str, list[str]], None, None]:
    """
    Stream-parse the output of the xingfinger command.

    Processes the command's stdout in real time via execute_stream,
    converting each JSON line into the (url, tech_list) shape.
    """
    logger.info("Starting streamed parse of xingfinger output - command: %s", cmd)

    total_lines = 0
    valid_records = 0

    try:
        for line in execute_stream(cmd=cmd, tool_name=tool_name, cwd=cwd, shell=True, timeout=timeout, log_file=log_file):
            total_lines += 1

            # Parse a single JSON line
            result = parse_xingfinger_line(line)
            if result is None:
                continue

            valid_records += 1
            yield result

            # Log progress every 500 valid records
            if valid_records % 500 == 0:
                logger.info("Parsed %d valid records...", valid_records)

    except subprocess.TimeoutExpired as e:
        error_msg = f"xingfinger command timed out - exceeded {timeout} seconds"
        logger.warning(error_msg)
        raise RuntimeError(error_msg) from e
    except Exception as e:
        logger.error("Streamed parse of xingfinger output failed: %s", e, exc_info=True)
        raise

    logger.info("Streamed parse finished - total lines: %d, valid records: %d", total_lines, valid_records)


@task(name="run_xingfinger_and_stream_update_tech")
def run_xingfinger_and_stream_update_tech_task(
    cmd: str,
    tool_name: str,
    scan_id: int,
    target_id: int,
    source: str,
    cwd: str,
    timeout: int,
    log_file: str,
    batch_size: int = 100
) -> dict:
    """
    Run the xingfinger command and update the tech field in real time.

    Updates the tech field of the table selected by source:
    - website → WebSite.tech
    - endpoint → Endpoint.tech (future extension)

    Processing flow:
    1. Stream the xingfinger command output
    2. Parse the JSON output in real time
    3. Accumulate batch_size records, then bulk-update the database
    4. Merge and deduplicate arrays with native PostgreSQL SQL
    5. Create records automatically when they do not exist

    Returns:
        dict: {
            'processed_records': int,
            'updated_count': int,
            'created_count': int,
            'batch_count': int
        }
    """
    logger.info(
        "Starting xingfinger run with tech updates - target_id=%s, source=%s, timeout=%ss",
        target_id, source, timeout
    )

    data_generator = None

    try:
        # Initialize counters
        processed_records = 0
        updated_count = 0
        created_count = 0
        batch_count = 0

        # URL -> techs mapping for the current batch
        url_techs_map = {}

        # Stream processing
        data_generator = _parse_xingfinger_stream_output(
            cmd=cmd,
            tool_name=tool_name,
            cwd=cwd,
            timeout=timeout,
            log_file=log_file
        )

        for url, techs in data_generator:
            processed_records += 1

            # Accumulate into url_techs_map
            if url in url_techs_map:
                # Merge multiple detections for the same URL
                url_techs_map[url].extend(techs)
            else:
                url_techs_map[url] = techs

            # When the batch is full, run a bulk update
            if len(url_techs_map) >= batch_size:
                batch_count += 1
                result = bulk_merge_tech_field(source, url_techs_map, target_id)
                updated_count += result['updated_count']
                created_count += result.get('created_count', 0)

                logger.debug(
                    "Batch %d done - updated: %d, created: %d",
                    batch_count, result['updated_count'], result.get('created_count', 0)
                )

                # Reset the batch
                url_techs_map = {}

        # Process the final batch
        if url_techs_map:
            batch_count += 1
            result = bulk_merge_tech_field(source, url_techs_map, target_id)
            updated_count += result['updated_count']
            created_count += result.get('created_count', 0)

        logger.info(
            "✓ xingfinger run finished - records: %d, updated: %d, created: %d, batches: %d",
            processed_records, updated_count, created_count, batch_count
        )

        return {
            'processed_records': processed_records,
            'updated_count': updated_count,
            'created_count': created_count,
            'batch_count': batch_count
        }

    except subprocess.TimeoutExpired:
        logger.warning("⚠️ xingfinger run timed out - target_id=%s, timeout=%ss", target_id, timeout)
        raise
    except Exception as e:
        error_msg = f"xingfinger run failed: {e}"
        logger.error(error_msg, exc_info=True)
        raise RuntimeError(error_msg) from e
    finally:
        # Clean up resources
        if data_generator is not None:
            try:
                data_generator.close()
            except Exception as e:
                logger.debug("Error while closing the generator: %s", e)
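The accumulate-and-flush batching used by the task above can be isolated into a small generic sketch; the function name and the sample records are illustrative, not part of the original code:

from typing import Callable, Iterable

def flush_in_batches(
    records: Iterable[tuple[str, list[str]]],
    flush: Callable[[dict[str, list[str]]], None],
    batch_size: int = 100,
) -> int:
    """Accumulate url -> techs, merging duplicate URLs inside a batch, and
    flush whenever batch_size distinct URLs have been collected."""
    batch: dict[str, list[str]] = {}
    batches = 0
    for url, techs in records:
        batch.setdefault(url, []).extend(techs)
        if len(batch) >= batch_size:
            flush(batch)
            batches += 1
            batch = {}
    if batch:  # flush the final partial batch
        flush(batch)
        batches += 1
    return batches

# Two full batches: {'a', 'b'} and {'c', 'd'}; the duplicate 'a' records merge
n = flush_in_batches(
    [("a", ["x"]), ("a", ["y"]), ("b", ["x"]), ("c", ["z"]), ("d", ["x"])],
    flush=lambda b: print(sorted(b)),
    batch_size=2,
)
print(n)  # 2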
@@ -1,66 +0,0 @@
"""
Task that exports scan targets to a TXT file.

Uses TargetExportService.export_targets() for the shared export logic.

Export content depends on the Target type:
- DOMAIN: export subdomains from the Subdomain table
- IP: write target.name directly
- CIDR: expand every IP in the CIDR range
"""
import logging
from prefect import task

from apps.scan.services import TargetExportService, BlacklistService

logger = logging.getLogger(__name__)


@task(name="export_scan_targets")
def export_scan_targets_task(
    target_id: int,
    output_file: str,
    batch_size: int = 1000
) -> dict:
    """
    Export scan targets to a TXT file.

    Export content is chosen automatically by Target type:
    - DOMAIN: export subdomains from the Subdomain table (streamed; handles 100k+ domains)
    - IP: write target.name directly (a single IP)
    - CIDR: expand every usable IP in the CIDR range

    Args:
        target_id: target ID
        output_file: output file path (absolute)
        batch_size: batch size per read, default 1000 (DOMAIN type only)

    Returns:
        dict: {
            'success': bool,
            'output_file': str,
            'total_count': int,
            'target_type': str
        }

    Raises:
        ValueError: Target does not exist
        IOError: file write failure
    """
    # Delegate the export to TargetExportService
    blacklist_service = BlacklistService()
    export_service = TargetExportService(blacklist_service=blacklist_service)

    result = export_service.export_targets(
        target_id=target_id,
        output_path=output_file,
        batch_size=batch_size
    )

    # Keep the return shape unchanged (backward compatible)
    return {
        'success': result['success'],
        'output_file': result['output_file'],
        'total_count': result['total_count'],
        'target_type': result['target_type']
    }
@@ -1,127 +0,0 @@
"""
Task that exports website URLs to a file.

Queries the HostPortMapping table for host+port pairs, formats them as URLs,
and writes them to a file. Uses TargetExportService for the default-value
fallback logic.

Special port handling:
- Port 80: emit an HTTP URL only (port number omitted)
- Port 443: emit an HTTPS URL only (port number omitted)
- Other ports: emit both HTTP and HTTPS URLs (with the port number)
"""
import logging
from pathlib import Path
from prefect import task

from apps.asset.services import HostPortMappingService
from apps.scan.services import TargetExportService, BlacklistService

logger = logging.getLogger(__name__)


def _generate_urls_from_port(host: str, port: int) -> list[str]:
    """
    Generate the URL list for a host and port.

    - Port 80: HTTP URL only (port number omitted)
    - Port 443: HTTPS URL only (port number omitted)
    - Other ports: both HTTP and HTTPS URLs (with the port number)
    """
    if port == 80:
        return [f"http://{host}"]
    elif port == 443:
        return [f"https://{host}"]
    else:
        return [f"http://{host}:{port}", f"https://{host}:{port}"]
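# --- Illustrative example (not part of the original file) ---
# Expected outputs of the helper above; the hosts are made up:
#
#   _generate_urls_from_port("example.com", 80)
#   # -> ['http://example.com']
#   _generate_urls_from_port("example.com", 443)
#   # -> ['https://example.com']
#   _generate_urls_from_port("10.0.0.1", 8443)
#   # -> ['http://10.0.0.1:8443', 'https://10.0.0.1:8443']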
@task(name="export_site_urls")
|
||||
def export_site_urls_task(
|
||||
target_id: int,
|
||||
output_file: str,
|
||||
batch_size: int = 1000
|
||||
) -> dict:
|
||||
"""
|
||||
导出目标下的所有站点URL到文件(基于 HostPortMapping 表)
|
||||
|
||||
数据源: HostPortMapping (host + port)
|
||||
|
||||
特殊逻辑:
|
||||
- 80 端口:只生成 HTTP URL(省略端口号)
|
||||
- 443 端口:只生成 HTTPS URL(省略端口号)
|
||||
- 其他端口:生成 HTTP 和 HTTPS 两个URL(带端口号)
|
||||
|
||||
懒加载模式:
|
||||
- 如果数据库为空,根据 Target 类型生成默认 URL
|
||||
- DOMAIN: http(s)://domain
|
||||
- IP: http(s)://ip
|
||||
- CIDR: 展开为所有 IP 的 URL
|
||||
|
||||
Args:
|
||||
target_id: 目标ID
|
||||
output_file: 输出文件路径(绝对路径)
|
||||
batch_size: 每次处理的批次大小
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
'success': bool,
|
||||
'output_file': str,
|
||||
'total_urls': int,
|
||||
'association_count': int # 主机端口关联数量
|
||||
}
|
||||
|
||||
Raises:
|
||||
ValueError: 参数错误
|
||||
IOError: 文件写入失败
|
||||
"""
|
||||
logger.info("开始统计站点URL - Target ID: %d, 输出文件: %s", target_id, output_file)
|
||||
|
||||
# 确保输出目录存在
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 初始化黑名单服务
|
||||
blacklist_service = BlacklistService()
|
||||
|
||||
# 直接查询 HostPortMapping 表,按 host 排序
|
||||
service = HostPortMappingService()
|
||||
associations = service.iter_host_port_by_target(
|
||||
target_id=target_id,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
total_urls = 0
|
||||
association_count = 0
|
||||
|
||||
# 流式写入文件(特殊端口逻辑)
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
for assoc in associations:
|
||||
association_count += 1
|
||||
host = assoc['host']
|
||||
port = assoc['port']
|
||||
|
||||
# 根据端口号生成URL
|
||||
for url in _generate_urls_from_port(host, port):
|
||||
if blacklist_service.filter_url(url):
|
||||
f.write(f"{url}\n")
|
||||
total_urls += 1
|
||||
|
||||
if association_count % 1000 == 0:
|
||||
logger.info("已处理 %d 条关联,生成 %d 个URL...", association_count, total_urls)
|
||||
|
||||
logger.info(
|
||||
"✓ 站点URL导出完成 - 关联数: %d, 总URL数: %d, 文件: %s",
|
||||
association_count, total_urls, str(output_path)
|
||||
)
|
||||
|
||||
# 默认值回退模式:使用 TargetExportService
|
||||
if total_urls == 0:
|
||||
export_service = TargetExportService(blacklist_service=blacklist_service)
|
||||
total_urls = export_service._generate_default_urls(target_id, output_path)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'output_file': str(output_path),
|
||||
'total_urls': total_urls,
|
||||
'association_count': association_count
|
||||
}
|
||||
@@ -1,195 +0,0 @@
"""
Merge and deduplicate domains task.

Combines the merge + parse + validate steps into one for performance:
- Single-command implementation (LC_ALL=C sort -u)
- C-level performance in a single process
- No temp files, zero extra overhead
- Handles tens of millions of rows

Performance notes:
- LC_ALL=C byte-order comparison (20-30% faster than locale-aware sorting)
- One process handles multiple files directly (no pipe overhead)
- Constant memory footprint (~50MB for 500k domains)
- 500k domains processed in ~0.5s (roughly 67% faster than the Python version)

Note:
- Tool output (amass/subfinder) is already normalized (lowercase, no blank lines)
- sort -u handles deduplication and ordering in one pass
- No extra filtering needed; this is the fastest path
"""

import logging
import uuid
import subprocess
from pathlib import Path
from datetime import datetime
from prefect import task
from typing import List

logger = logging.getLogger(__name__)

# Note: pure system-command implementation; no Python buffer tuning needed
# Tool output (amass/subfinder) is already lowercase and normalized

@task(
    name='merge_and_deduplicate',
    retries=1,
    log_prints=True
)
def merge_and_validate_task(
    result_files: List[str],
    result_dir: str
) -> str:
    """
    Merge scan results and deduplicate (high-performance streaming).

    Flow:
    1. LC_ALL=C sort -u processes the files directly
    2. Sorting and deduplication happen in one step
    3. Return the path of the deduplicated file

    Command: LC_ALL=C sort -u file1 file2 file3 -o output
    Note: tool output is already normalized (lowercase, no blank lines), so no extra processing

    Args:
        result_files: list of result file paths
        result_dir: result directory

    Returns:
        str: path of the deduplicated domain file

    Raises:
        RuntimeError: processing failed

    Performance:
        - Pure system command (written in C), minimal single process
        - LC_ALL=C: byte-order comparison
        - sort -u: handles multiple files directly (no pipe overhead)

    Design:
        - Minimal single command, no redundant processing
        - Direct single-process execution (no pipe/redirect overhead)
        - Memory used only inside sort (external sort, will not OOM)
    """
    logger.info("Merging and deduplicating %d result files (system-command fast path)", len(result_files))

    result_path = Path(result_dir)

    # Validate that the files exist
    valid_files = []
    for file_path_str in result_files:
        file_path = Path(file_path_str)
        if file_path.exists():
            valid_files.append(str(file_path))
        else:
            logger.warning("Result file does not exist: %s", file_path)

    if not valid_files:
        raise RuntimeError("None of the result files exist")

    # Build the output file path
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    short_uuid = uuid.uuid4().hex[:4]
    merged_file = result_path / f"merged_{timestamp}_{short_uuid}.txt"

    try:
        # ==================== One system command: sort + dedup ====================
        # LC_ALL=C: byte-order comparison (20-30% faster than locale-aware)
        # sort -u: processes multiple files directly, sorting and deduplicating
        # -o: safe output (more reliable than shell redirection)
        cmd = f"LC_ALL=C sort -u {' '.join(valid_files)} -o {merged_file}"

        logger.debug("Running command: %s", cmd)

        # Compute the timeout dynamically from the total input line count
        total_lines = 0
        for file_path in valid_files:
            try:
                line_count_proc = subprocess.run(
                    ["wc", "-l", file_path],
                    check=True,
                    capture_output=True,
                    text=True,
                )
                total_lines += int(line_count_proc.stdout.strip().split()[0])
            except (subprocess.CalledProcessError, ValueError, IndexError):
                continue

        timeout = 3600
        if total_lines > 0:
            # Linear estimate: roughly 0.1 seconds per line
            base_per_line = 0.1
            est = int(total_lines * base_per_line)
            timeout = max(600, est)

        logger.info(
            "Subdomain merge/dedup timeout auto-computed: total input lines=%d, timeout=%ds",
            total_lines,
            timeout,
        )

        subprocess.run(
            cmd,
            shell=True,
            check=True,
            timeout=timeout
        )

        logger.debug("✓ Merge and dedup finished")

        # ==================== Collect statistics ====================
        if not merged_file.exists():
            raise RuntimeError("Merged file was not created")

        # Count lines (system command for large-file performance)
        try:
            line_count_proc = subprocess.run(
                ["wc", "-l", str(merged_file)],
                check=True,
                capture_output=True,
                text=True
            )
            unique_count = int(line_count_proc.stdout.strip().split()[0])
        except (subprocess.CalledProcessError, ValueError, IndexError) as e:
            logger.warning(
                "wc -l count failed (file: %s); falling back to Python line counting - error: %s",
                merged_file, e
            )
            unique_count = 0
            with open(merged_file, 'r', encoding='utf-8') as file_obj:
                for _ in file_obj:
                    unique_count += 1

        if unique_count == 0:
            raise RuntimeError("No valid domains found")

        file_size = merged_file.stat().st_size

        logger.info(
            "✓ Merge and dedup finished - unique domains: %d, file size: %.2f KB",
            unique_count,
            file_size / 1024
        )

        return str(merged_file)

    except subprocess.TimeoutExpired:
        error_msg = f"Merge/dedup timed out (> {timeout}s); check data volume or system resources"
        logger.warning(error_msg)  # timeouts are expected under heavy load
        raise RuntimeError(error_msg)

    except subprocess.CalledProcessError as e:
        error_msg = f"System command failed: {e.stderr if e.stderr else str(e)}"
        logger.warning(error_msg)
        raise RuntimeError(error_msg) from e

    except IOError as e:
        error_msg = f"File read/write failed: {e}"
        logger.warning(error_msg)
        raise RuntimeError(error_msg) from e

    except Exception as e:
        error_msg = f"Merge/dedup failed: {e}"
        logger.error(error_msg, exc_info=True)
        raise
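The same merge step can be sketched without the shell string, assuming GNU sort is available; the function name and file names are illustrative:

import os
import subprocess

def merge_unique(inputs: list[str], output: str, timeout: int = 600) -> None:
    # LC_ALL=C switches sort to byte-order comparison; -u sorts and
    # deduplicates every input file in a single pass; -o writes the result
    # without shell redirection.
    subprocess.run(
        ["sort", "-u", *inputs, "-o", output],
        env={**os.environ, "LC_ALL": "C"},
        check=True,
        timeout=timeout,
    )

merge_unique(["amass.txt", "subfinder.txt"], "merged.txt")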
@@ -1,74 +0,0 @@
"""
Website URL list export task.

Uses TargetExportService for the shared export logic and default-value fallback.
Data source: WebSite.url (for crawlers such as katana)
"""

import logging
from prefect import task

from apps.asset.models import WebSite
from apps.scan.services import TargetExportService, BlacklistService

logger = logging.getLogger(__name__)


@task(
    name='export_sites_for_url_fetch',
    retries=1,
    log_prints=True
)
def export_sites_task(
    output_file: str,
    target_id: int,
    scan_id: int,
    batch_size: int = 1000
) -> dict:
    """
    Export the website URL list to a file (for crawlers such as katana).

    Data source: WebSite.url

    Lazy-load mode:
    - If the database is empty, generate default URLs by Target type
    - DOMAIN: http(s)://domain
    - IP: http(s)://ip
    - CIDR: expand to a URL for every IP

    Args:
        output_file: output file path
        target_id: target ID
        scan_id: scan ID (kept for backward compatibility)
        batch_size: batch size (memory optimization)

    Returns:
        dict: {
            'output_file': str,  # output file path
            'asset_count': int,  # number of assets
        }

    Raises:
        ValueError: invalid arguments
        RuntimeError: execution failure
    """
    # Build the data-source queryset (the task layer decides the data source)
    queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)

    # Delegate the export to TargetExportService
    blacklist_service = BlacklistService()
    export_service = TargetExportService(blacklist_service=blacklist_service)

    result = export_service.export_urls(
        target_id=target_id,
        output_path=output_file,
        queryset=queryset,
        batch_size=batch_size
    )

    # Keep the return shape unchanged (backward compatible)
    return {
        'output_file': result['output_file'],
        'asset_count': result['total_count'],
    }
@@ -1,65 +0,0 @@
"""Task that exports Endpoint URLs to a file.

Uses TargetExportService for the shared export logic and default-value fallback.
Data source: Endpoint.url
"""

import logging
from typing import Dict

from prefect import task

from apps.asset.models import Endpoint
from apps.scan.services import TargetExportService, BlacklistService

logger = logging.getLogger(__name__)


@task(name="export_endpoints")
def export_endpoints_task(
    target_id: int,
    output_file: str,
    batch_size: int = 1000,
) -> Dict[str, object]:
    """Export all Endpoint URLs under a target to a text file.

    Data source: Endpoint.url

    Lazy-load mode:
    - If the database is empty, generate default URLs by Target type
    - DOMAIN: http(s)://domain
    - IP: http(s)://ip
    - CIDR: expand to a URL for every IP

    Args:
        target_id: target ID
        output_file: output file path (absolute)
        batch_size: batch size per database iteration

    Returns:
        dict: {
            "success": bool,
            "output_file": str,
            "total_count": int,
        }
    """
    # Build the data-source queryset (the task layer decides the data source)
    queryset = Endpoint.objects.filter(target_id=target_id).values_list('url', flat=True)

    # Delegate the export to TargetExportService
    blacklist_service = BlacklistService()
    export_service = TargetExportService(blacklist_service=blacklist_service)

    result = export_service.export_urls(
        target_id=target_id,
        output_path=output_file,
        queryset=queryset,
        batch_size=batch_size
    )

    # Keep the return shape unchanged (backward compatible)
    return {
        "success": result['success'],
        "output_file": result['output_file'],
        "total_count": result['total_count'],
    }
@@ -1,37 +0,0 @@
"""
Scan module utility package.

Provides scan-related helper functions.
"""

from .directory_cleanup import remove_directory
from .command_builder import build_scan_command
from .command_executor import execute_and_wait, execute_stream
from .wordlist_helpers import ensure_wordlist_local
from .nuclei_helpers import ensure_nuclei_templates_local
from .performance import FlowPerformanceTracker, CommandPerformanceTracker
from .workspace_utils import setup_scan_workspace, setup_scan_directory
from . import config_parser

__all__ = [
    # Directory cleanup
    'remove_directory',
    # Workspace
    'setup_scan_workspace',    # create the Scan root workspace
    'setup_scan_directory',    # create scan subdirectories
    # Command building
    'build_scan_command',      # scan-tool command builder (f-string based)
    # Command execution
    'execute_and_wait',        # blocking execution (file output)
    'execute_stream',          # streaming execution (real-time processing)
    # Wordlists
    'ensure_wordlist_local',   # ensure a local wordlist (with hash check)
    # Nuclei templates
    'ensure_nuclei_templates_local',  # ensure local templates (with commit hash check)
    # Performance monitoring
    'FlowPerformanceTracker',      # flow performance tracker (with system resource sampling)
    'CommandPerformanceTracker',   # command performance tracker
    # Config parsing
    'config_parser',
]
@@ -1,417 +0,0 @@
from rest_framework import viewsets, status
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.exceptions import NotFound, APIException
from rest_framework.filters import SearchFilter
from django.core.exceptions import ObjectDoesNotExist, ValidationError
from django.db.utils import DatabaseError, IntegrityError, OperationalError
import logging

logger = logging.getLogger(__name__)

from ..models import Scan, ScheduledScan
from ..serializers import (
    ScanSerializer, ScanHistorySerializer, QuickScanSerializer,
    ScheduledScanSerializer, CreateScheduledScanSerializer,
    UpdateScheduledScanSerializer, ToggleScheduledScanSerializer
)
from ..services.scan_service import ScanService
from ..services.scheduled_scan_service import ScheduledScanService
from ..repositories import ScheduledScanDTO
from apps.targets.services.target_service import TargetService
from apps.targets.services.organization_service import OrganizationService
from apps.engine.services.engine_service import EngineService
from apps.common.definitions import ScanStatus
from apps.common.pagination import BasePagination


class ScanViewSet(viewsets.ModelViewSet):
    """Scan task viewset."""
    serializer_class = ScanSerializer
    pagination_class = BasePagination
    filter_backends = [SearchFilter]
    search_fields = ['target__name']  # search by target name

    def get_queryset(self):
        """Optimize the queryset to improve API performance.

        Query optimization strategy:
        - select_related: preload target and engine (one-to-one/many-to-one, via JOIN)
        - drop prefetch_related: avoid loading large asset sets into memory
        - order_by: newest created first

        Rationale:
        - List page: uses cached statistics fields (cached_*_count) instead of live COUNT queries
        - Serializer: validates cached fields strictly for consistency
        - Pagination: only 10 records per page, so queries stay cheap
        - No big data loads: related asset data is no longer prefetched
        """
        # Keep only the necessary select_related; all prefetch_related removed
        scan_service = ScanService()
        queryset = scan_service.get_all_scans(prefetch_relations=True)

        return queryset

    def get_serializer_class(self):
        """Return a serializer per action.

        - list: ScanHistorySerializer (includes summary and progress)
        - retrieve: ScanHistorySerializer (includes summary and progress)
        - everything else: the standard ScanSerializer
        """
        if self.action in ['list', 'retrieve']:
            return ScanHistorySerializer
        return ScanSerializer

    def destroy(self, request, *args, **kwargs):
        """
        Delete a single scan task (two-phase deletion).

        1. Soft delete: immediately hidden from users
        2. Hard delete: performed asynchronously in the background
        """
        try:
            scan = self.get_object()
            scan_service = ScanService()
            result = scan_service.delete_scans_two_phase([scan.id])

            return Response({
                'message': f'Deleted scan task: Scan #{scan.id}',
                'scanId': scan.id,
                'deletedCount': result['soft_deleted_count'],
                'deletedScans': result['scan_names']
            }, status=status.HTTP_200_OK)

        except Scan.DoesNotExist:
            raise NotFound('Scan task does not exist')
        except ValueError as e:
            raise NotFound(str(e))
        except Exception as e:
            logger.exception("Error while deleting scan task")
            raise APIException('Server error, please try again later')

    @action(detail=False, methods=['post'])
    def quick(self, request):
        """
        Quick-scan endpoint.

        Behavior:
        1. Accepts a target list and engine configuration
        2. Parses the input automatically (URL, domain, IP, CIDR)
        3. Bulk-creates Target, Website, and Endpoint assets
        4. Immediately launches the batch scan

        Request body:
            {
                "targets": [{"name": "example.com"}, {"name": "https://example.com/api"}],
                "engine_id": 1
            }

        Supported input formats:
        - Domain: example.com
        - IP: 192.168.1.1
        - CIDR: 10.0.0.0/8
        - URL: https://example.com/api/v1
        """
        from ..services.quick_scan_service import QuickScanService

        serializer = QuickScanSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)

        targets_data = serializer.validated_data['targets']
        engine_id = serializer.validated_data.get('engine_id')

        try:
            # Extract the raw input strings
            inputs = [t['name'] for t in targets_data]

            # 1. Parse the input and create assets via QuickScanService
            quick_scan_service = QuickScanService()
            result = quick_scan_service.process_quick_scan(inputs, engine_id)

            targets = result['targets']

            if not targets:
                return Response({
                    'error': 'No valid targets to scan',
                    'errors': result.get('errors', [])
                }, status=status.HTTP_400_BAD_REQUEST)

            # 2. Fetch the scan engine
            engine_service = EngineService()
            engine = engine_service.get_engine(engine_id)
            if not engine:
                raise ValidationError(f'Scan engine ID {engine_id} does not exist')

            # 3. Launch the scans in bulk
            scan_service = ScanService()
            created_scans = scan_service.create_scans(
                targets=targets,
                engine=engine
            )

            # Serialize the result
            scan_serializer = ScanSerializer(created_scans, many=True)

            return Response({
                'message': f'Quick scan started: {len(created_scans)} tasks',
                'target_stats': result['target_stats'],
                'asset_stats': result['asset_stats'],
                'errors': result.get('errors', []),
                'scans': scan_serializer.data
            }, status=status.HTTP_201_CREATED)

        except ValidationError as e:
            return Response({'error': str(e)}, status=status.HTTP_400_BAD_REQUEST)
        except Exception as e:
            logger.exception("Failed to start quick scan")
            return Response(
                {'error': 'Internal server error, please try again later'},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR
            )

    @action(detail=False, methods=['post'])
    def initiate(self, request):
        """
        Launch scan tasks.

        Request parameters:
        - organization_id: organization ID (int, optional)
        - target_id: target ID (int, optional)
        - engine_id: scan engine ID (int, required)

        Note: provide either organization_id or target_id

        Returns:
        - scan task details (one or many)
        """
        # Read the request data
        organization_id = request.data.get('organization_id')
        target_id = request.data.get('target_id')
        engine_id = request.data.get('engine_id')

        try:
            # Step 1: prepare the scan (validate params, look up resources, return targets and engine)
            scan_service = ScanService()
            targets, engine = scan_service.prepare_initiate_scan(
                organization_id=organization_id,
                target_id=target_id,
                engine_id=engine_id
            )

            # Step 2: bulk-create scan records and dispatch the scan tasks
            created_scans = scan_service.create_scans(
                targets=targets,
                engine=engine
            )

            # Serialize the result
            scan_serializer = ScanSerializer(created_scans, many=True)

            return Response(
                {
                    'message': f'Successfully launched {len(created_scans)} scan tasks',
                    'count': len(created_scans),
                    'scans': scan_serializer.data
                },
                status=status.HTTP_201_CREATED
            )

        except ObjectDoesNotExist as e:
            # Missing-resource error (raised by the service layer)
            error_msg = str(e)
            return Response(
                {'error': error_msg},
                status=status.HTTP_404_NOT_FOUND
            )

        except ValidationError as e:
            # Parameter validation error (raised by the service layer)
            return Response(
                {'error': str(e)},
                status=status.HTTP_400_BAD_REQUEST
            )

        except (DatabaseError, IntegrityError, OperationalError):
            # Database error
            return Response(
                {'error': 'Database error, please try again later'},
                status=status.HTTP_503_SERVICE_UNAVAILABLE
            )

    # All snapshot-related actions and export have moved to the snapshot ViewSets in asset/views.py
    # GET /api/scans/{id}/subdomains/ -> SubdomainSnapshotViewSet
    # GET /api/scans/{id}/subdomains/export/ -> SubdomainSnapshotViewSet.export
    # GET /api/scans/{id}/websites/ -> WebsiteSnapshotViewSet
    # GET /api/scans/{id}/websites/export/ -> WebsiteSnapshotViewSet.export
    # GET /api/scans/{id}/directories/ -> DirectorySnapshotViewSet
    # GET /api/scans/{id}/directories/export/ -> DirectorySnapshotViewSet.export
    # GET /api/scans/{id}/endpoints/ -> EndpointSnapshotViewSet
    # GET /api/scans/{id}/endpoints/export/ -> EndpointSnapshotViewSet.export
    # GET /api/scans/{id}/ip-addresses/ -> HostPortMappingSnapshotViewSet
    # GET /api/scans/{id}/ip-addresses/export/ -> HostPortMappingSnapshotViewSet.export
    # GET /api/scans/{id}/vulnerabilities/ -> VulnerabilitySnapshotViewSet

    @action(detail=False, methods=['post', 'delete'], url_path='bulk-delete')
    def bulk_delete(self, request):
        """
        Bulk-delete scan records.

        Request parameters:
        - ids: list of scan IDs (list[int], required)

        Example request:
            POST /api/scans/bulk-delete/
            {
                "ids": [1, 2, 3]
            }

        Returns:
        - message: success message
        - deletedCount: number of records actually deleted

        Notes:
        - Deletes cascade to related subdomains, endpoints, and other data
        - Only existing records are deleted; unknown IDs are ignored
        """
        ids = request.data.get('ids', [])

        # Parameter validation
        if not ids:
            return Response(
                {'error': 'Missing required parameter: ids'},
                status=status.HTTP_400_BAD_REQUEST
            )

        if not isinstance(ids, list):
            return Response(
                {'error': 'ids must be an array'},
                status=status.HTTP_400_BAD_REQUEST
            )

        if not all(isinstance(i, int) for i in ids):
            return Response(
                {'error': 'All elements of ids must be integers'},
                status=status.HTTP_400_BAD_REQUEST
            )

        try:
            # Bulk delete via the service layer (two-phase deletion)
            scan_service = ScanService()
            result = scan_service.delete_scans_two_phase(ids)

            return Response({
                'message': f"Deleted {result['soft_deleted_count']} scan tasks",
                'deletedCount': result['soft_deleted_count'],
                'deletedScans': result['scan_names']
            }, status=status.HTTP_200_OK)

        except ValueError as e:
            # No records found
            raise NotFound(str(e))

        except Exception as e:
            logger.exception("Error during bulk deletion of scan tasks")
            raise APIException('Server error, please try again later')

    @action(detail=False, methods=['get'])
    def statistics(self, request):
        """
        Fetch scan statistics.

        Returns aggregate statistics over scan tasks for the dashboard and
        scan-history pages. Aggregates cached fields, so it is fast.

        Returns:
        - total: total number of scans
        - running: number of running scans
        - completed: number of completed scans
        - failed: number of failed scans
        - totalVulns: total vulnerabilities found
        - totalSubdomains: total subdomains found
        - totalEndpoints: total endpoints found
        - totalAssets: total assets
        """
        try:
            # Fetch the statistics via the service layer
            scan_service = ScanService()
            stats = scan_service.get_statistics()

            return Response({
                'total': stats['total'],
                'running': stats['running'],
                'completed': stats['completed'],
                'failed': stats['failed'],
                'totalVulns': stats['total_vulns'],
                'totalSubdomains': stats['total_subdomains'],
                'totalEndpoints': stats['total_endpoints'],
                'totalWebsites': stats['total_websites'],
                'totalAssets': stats['total_assets'],
            })

        except (DatabaseError, OperationalError):
            return Response(
                {'error': 'Database error, please try again later'},
                status=status.HTTP_503_SERVICE_UNAVAILABLE
            )

    @action(detail=True, methods=['post'])
    def stop(self, request, pk=None):  # pylint: disable=unused-argument
        """
        Stop a scan task.

        URL: POST /api/scans/{id}/stop/

        Behavior:
        - Terminates a running or initializing scan task
        - Updates the scan status to CANCELLED

        Status restrictions:
        - Only RUNNING or INITIATED scans can be stopped
        - Completed, failed, or cancelled scans cannot be stopped

        Returns:
        - message: success message
        - revokedTaskCount: number of cancelled Flow Runs
        """
        try:
            # Delegate the stop logic to the service layer
            scan_service = ScanService()
            success, revoked_count = scan_service.stop_scan(scan_id=pk)

            if not success:
                # Check whether the status forbids stopping
                scan = scan_service.get_scan(scan_id=pk, prefetch_relations=False)
                if scan and scan.status not in [ScanStatus.RUNNING, ScanStatus.INITIATED]:
                    return Response(
                        {
                            'error': f'Cannot stop scan: current status is {ScanStatus(scan.status).label}',
                            'detail': 'Only running or initializing scans can be stopped'
                        },
                        status=status.HTTP_400_BAD_REQUEST
                    )
                # Other failure reasons
                return Response(
                    {'error': 'Failed to stop the scan'},
                    status=status.HTTP_500_INTERNAL_SERVER_ERROR
                )

            return Response(
                {
                    'message': f'Scan stopped; revoked {revoked_count} tasks',
                    'revokedTaskCount': revoked_count
                },
                status=status.HTTP_200_OK
            )

        except ObjectDoesNotExist:
            return Response(
                {'error': f'Scan ID {pk} does not exist'},
                status=status.HTTP_404_NOT_FOUND
            )

        except (DatabaseError, IntegrityError, OperationalError):
            return Response(
                {'error': 'Database error, please try again later'},
                status=status.HTTP_503_SERVICE_UNAVAILABLE
            )
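A hypothetical client call against the bulk-delete endpoint above; the host, port, and auth header are assumptions, not part of the original code:

import requests

resp = requests.post(
    "http://localhost:8888/api/scans/bulk-delete/",
    json={"ids": [1, 2, 3]},
    headers={"Authorization": "Bearer <token>"},  # auth scheme assumed
    timeout=10,
)
print(resp.status_code)  # 200 on success, 400 for a malformed "ids" payload
print(resp.json())       # {'message': ..., 'deletedCount': ..., 'deletedScans': [...]}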
@@ -1,27 +0,0 @@
[tool.pytest.ini_options]
DJANGO_SETTINGS_MODULE = "config.settings"
python_files = ["test_*.py", "*_test.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
testpaths = ["apps"]
addopts = "-v --reuse-db"

[tool.pylint]
django-settings-module = "config.settings"
load-plugins = "pylint_django"

[tool.pylint.messages_control]
disable = [
    "missing-docstring",
    "invalid-name",
    "too-few-public-methods",
    "no-member",
    "import-error",
    "no-name-in-module",
]

[tool.pylint.format]
max-line-length = 120

[tool.pylint.basic]
good-names = ["i", "j", "k", "ex", "Run", "_", "id", "pk", "ip", "url", "db", "qs"]
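A hypothetical test module matching the discovery rules above (a file such as apps/example/test_example.py; the names are illustrative):

import pytest

class TestExample:
    def test_addition(self):
        assert 1 + 1 == 2

    @pytest.mark.django_db  # works with the --reuse-db flag configured above
    def test_db_access_placeholder(self):
        assert True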
@@ -1,57 +1,38 @@
# ==================== Database configuration (PostgreSQL) ====================
# DB_HOST decides between the local container and a remote database:
# - postgres / localhost / 127.0.0.1 → start the local PostgreSQL container
# - any other address (e.g. 192.168.1.100) → use the remote database; no local container
# ============================================
# Docker Image Configuration
# ============================================
IMAGE_TAG=dev

# ============================================
# Required: Security Configuration
# MUST change these in production!
# ============================================
JWT_SECRET=change-me-in-production-use-a-long-random-string
WORKER_TOKEN=change-me-worker-token

# ============================================
# Required: Docker Service Hosts
# ============================================
DB_HOST=postgres
DB_PORT=5432
DB_NAME=xingrin
DB_USER=postgres
DB_PASSWORD=123.com

# ==================== Redis configuration ====================
# Inside the Docker network the Redis service is named redis
DB_PASSWORD=postgres
REDIS_HOST=redis
REDIS_PORT=6379
REDIS_DB=0

# ==================== Service port configuration ====================
# SERVER_PORT is the Django / uvicorn port inside the container (proxied by nginx, not exposed publicly)
SERVER_PORT=8888
# ============================================
# Optional: Override defaults if needed
# ============================================
# PUBLIC_URL=https://your-domain.com:8083
# SERVER_PORT=8080
# GIN_MODE=release
# DB_PORT=5432
# DB_USER=postgres
# DB_NAME=lunafox
# DB_SSLMODE=disable
# DB_MAX_OPEN_CONNS=50
# DB_MAX_IDLE_CONNS=10
# REDIS_PORT=6379
# REDIS_PASSWORD=
# LOG_LEVEL=info
# LOG_FORMAT=json
# WORDLISTS_BASE_PATH=/opt/lunafox/wordlists

# ==================== Remote worker configuration ====================
# Address remote workers use to reach the main server:
# - local-only deployments: server (Docker-internal service name)
# - with remote workers: the server's public IP or domain (e.g. 192.168.1.100 or xingrin.example.com)
# Note: remote workers connect via https://{PUBLIC_HOST}:{PUBLIC_PORT} (nginx proxies to backend 8888)
PUBLIC_HOST=server
# Public HTTPS port
PUBLIC_PORT=8083

# ==================== Django core configuration ====================
# Replace with a strong random secret in production
DJANGO_SECRET_KEY=django-insecure-change-me-in-production
# Debug mode (keep False in production)
DEBUG=False
# Allowed frontend origins (CORS)
CORS_ALLOWED_ORIGINS=http://localhost:3000

# ==================== Path configuration (paths inside the container) ====================
# Scan result directory
SCAN_RESULTS_DIR=/opt/xingrin/results
# Django log directory
# Note: if left empty or removed, logs go only to the Docker console (stdout), not to files
LOG_DIR=/opt/xingrin/logs

# ==================== Log level configuration ====================
# Application log level: DEBUG / INFO / WARNING / ERROR
LOG_LEVEL=INFO
# Whether to log command execution (increases disk usage under heavy scanning)
ENABLE_COMMAND_LOGGING=true

# ==================== Docker Hub configuration (production mode) ====================
# Used when pulling images from Docker Hub in production mode
DOCKER_USER=yyhuni
# Image version tag (read automatically from the VERSION file at install time)
# The VERSION file is updated by CI and kept in sync with the Git tag
# Note: set automatically by install.sh; do not edit manually
IMAGE_TAG=__WILL_BE_SET_BY_INSTALLER__
@@ -1,120 +1,122 @@
services:
  # PostgreSQL (optional; not started when using a remote database)
  # Local mode: docker compose --profile local-db up -d
  # Remote mode: docker compose up -d (DB_HOST must point at the remote address)
  # Register and start agents via the install script (/api/agents/install.sh)
  postgres:
    profiles: ["local-db"]
    image: postgres:15
    restart: always
    image: postgres:16.3-alpine
    restart: "on-failure:3"
    environment:
      POSTGRES_DB: ${DB_NAME}
      POSTGRES_USER: ${DB_USER}
      POSTGRES_PASSWORD: ${DB_PASSWORD}
      POSTGRES_DB: ${DB_NAME:-lunafox}
      POSTGRES_USER: ${DB_USER:-postgres}
      POSTGRES_PASSWORD: ${DB_PASSWORD:-postgres}
    ports:
      - "${DB_PORT:-5432}:5432"
    volumes:
      - postgres_data:/var/lib/postgresql/data
      - ./postgres/init-user-db.sh:/docker-entrypoint-initdb.d/init-user-db.sh
    ports:
      - "${DB_PORT}:5432"
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U ${DB_USER}"]
      test: ["CMD-SHELL", "pg_isready -U ${DB_USER:-postgres}"]
      interval: 5s
      timeout: 5s
      retries: 5

  redis:
    image: redis:7-alpine
    restart: always
    image: redis:7.4.7-alpine
    restart: "on-failure:3"
    ports:
      - "${REDIS_PORT}:6379"
      - "${REDIS_PORT:-6379}:6379"
    healthcheck:
      test: ["CMD", "redis-cli", "ping"]
      test: [CMD, redis-cli, ping]
      interval: 5s
      timeout: 5s
      retries: 5

  server:
    build:
      context: ..
      dockerfile: docker/server/Dockerfile
    restart: always
    image: golang:1.25.6
    restart: "on-failure:3"
    env_file:
      - .env
    environment:
      - IMAGE_TAG=${IMAGE_TAG:-dev}
      - PUBLIC_URL=${PUBLIC_URL:-}
      - GOMODCACHE=/go/pkg/mod
      - GOCACHE=/root/.cache/go-build
      - GO111MODULE=${GO111MODULE:-on}
      - GOPROXY=${GOPROXY:-https://goproxy.cn,direct}
    ports:
      - "8888:8888"
      - "8080:8080"
    working_dir: /workspace/server
    command: sh -c "go install github.com/air-verse/air@latest && air -c .air.toml"
    depends_on:
      postgres:
        condition: service_healthy
      redis:
        condition: service_healthy
    volumes:
      # Shared data directory mount
      - /opt/xingrin:/opt/xingrin
      - /opt/lunafox:/opt/lunafox
      - /var/run/docker.sock:/var/run/docker.sock
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8888/api/"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 60s

  # Agent: heartbeat reporting + load monitoring + version checks
  agent:
    build:
      context: ..
      dockerfile: docker/agent/Dockerfile
      args:
        IMAGE_TAG: ${IMAGE_TAG:-dev}
    restart: always
    environment:
      - SERVER_URL=http://server:8888
      - WORKER_NAME=Local node
      - IS_LOCAL=true
      - IMAGE_TAG=${IMAGE_TAG:-dev}
    depends_on:
      server:
        condition: service_healthy
    volumes:
      - /proc:/host/proc:ro
      - ../server:/workspace/server
      - go-mod-cache:/go/pkg/mod
      - go-build-cache:/root/.cache/go-build

  frontend:
    build:
      context: ..
      dockerfile: docker/frontend/Dockerfile
      args:
        IMAGE_TAG: ${IMAGE_TAG:-dev}
    restart: always
    image: node:20.20.0-alpine
    restart: "on-failure:3"
    environment:
      - NODE_ENV=development
      - API_HOST=server
      - NEXT_PUBLIC_BACKEND_URL=${NEXT_PUBLIC_BACKEND_URL:-}
      - PORT=3000
      - HOSTNAME=0.0.0.0
    ports:
      - "3000:3000"
    working_dir: /app
    command: sh -c "corepack enable && corepack prepare pnpm@latest --activate && if [ ! -d node_modules/.pnpm ]; then pnpm install; fi && pnpm dev"
    depends_on:
      server:
        condition: service_healthy
        condition: service_started
    volumes:
      - ../frontend:/app
      - frontend_node_modules:/app/node_modules
      - frontend_pnpm_store:/root/.local/share/pnpm/store
    healthcheck:
      test: ["CMD", "node", "-e", "require('http').get('http://localhost:3000',res=>process.exit(res.statusCode<500?0:1)).on('error',()=>process.exit(1))"]
      interval: 5s
      timeout: 5s
      retries: 20
      start_period: 20s

  nginx:
    build:
      context: ..
      dockerfile: docker/nginx/Dockerfile
    restart: always
    image: yyhuni/lunafox-nginx:${IMAGE_TAG:-dev}
    restart: "on-failure:3"
    depends_on:
      server:
        condition: service_healthy
      frontend:
        condition: service_started
      frontend:
        condition: service_healthy
    ports:
      - "8083:8083"
    volumes:
      # SSL certificate mount (easy to update)
      - ./nginx/ssl:/etc/nginx/ssl:ro

  # Worker: scan-task execution container (built in dev mode)
  # Worker: build image for task execution (not run in dev by default).
  worker:
    build:
      context: ..
      dockerfile: docker/worker/Dockerfile
    image: docker-worker:${IMAGE_TAG:-latest}-dev
      context: ../worker
      dockerfile: Dockerfile
    image: yyhuni/lunafox-worker:${IMAGE_TAG:-dev}
    restart: "no"
    volumes:
      - /opt/xingrin:/opt/xingrin
      - /opt/lunafox:/opt/lunafox
    command: echo "Worker image built for development"

volumes:
  postgres_data:
  go-mod-cache:
  go-build-cache:
  frontend_node_modules:
  frontend_pnpm_store:

networks:
  default:
    name: xingrin_network # fixed network name, independent of directory name
    name: lunafox_network # Fixed network name, independent of directory name
@@ -1,16 +1,12 @@
# ============================================
# Production configuration - uses prebuilt Docker Hub images
# ============================================
# Usage: docker compose up -d
#
# For development use: docker compose -f docker-compose.dev.yml up -d
# ============================================

services:
  # PostgreSQL (optional; not started when using a remote database)
  postgres:
    profiles: ["local-db"]
    image: postgres:15
    image: postgres:16.3-alpine
    restart: always
    environment:
      POSTGRES_DB: ${DB_NAME}
@@ -18,7 +14,6 @@ services:
      POSTGRES_PASSWORD: ${DB_PASSWORD}
    volumes:
      - postgres_data:/var/lib/postgresql/data
      - ./postgres/init-user-db.sh:/docker-entrypoint-initdb.d/init-user-db.sh
    ports:
      - "${DB_PORT}:5432"
    healthcheck:
@@ -28,10 +23,8 @@ services:
      retries: 5

  redis:
    image: redis:7-alpine
    image: redis:7.4.7-alpine
    restart: always
    ports:
      - "${REDIS_PORT}:6379"
    healthcheck:
      test: ["CMD", "redis-cli", "ping"]
      interval: 5s
@@ -39,60 +32,47 @@ services:
      retries: 5

  server:
    image: ${DOCKER_USER:-yyhuni}/xingrin-server:${IMAGE_TAG:?IMAGE_TAG is required}
    image: yyhuni/lunafox-server:${IMAGE_TAG:?IMAGE_TAG is required}
    restart: always
    env_file:
      - .env
    environment:
      - IMAGE_TAG=${IMAGE_TAG}
      - PUBLIC_URL=${PUBLIC_URL:-}
    depends_on:
      redis:
        condition: service_healthy
    volumes:
      # Shared data directory mount
      - /opt/xingrin:/opt/xingrin
      # Docker socket mount: lets the Django server run local docker commands (for local worker task dispatch)
      - /opt/lunafox:/opt/lunafox
      - /var/run/docker.sock:/var/run/docker.sock
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8888/api/"]
      test: ["CMD", "curl", "-f", "http://localhost:8080/health"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 60s

  # ============================================
  # Agent: lightweight heartbeat reporting + load monitoring (~10MB)
  # Scan tasks are dispatched to dynamic containers via task_distributor
  # ============================================

  agent:
    image: ${DOCKER_USER:-yyhuni}/xingrin-agent:${IMAGE_TAG:?IMAGE_TAG is required}
    container_name: xingrin-agent
    restart: always
    environment:
      - SERVER_URL=http://server:8888
      - WORKER_NAME=Local node
      - IS_LOCAL=true
      - IMAGE_TAG=${IMAGE_TAG}
    depends_on:
      server:
        condition: service_healthy
    volumes:
      - /proc:/host/proc:ro

  frontend:
    image: ${DOCKER_USER:-yyhuni}/xingrin-frontend:${IMAGE_TAG:?IMAGE_TAG is required}
    image: yyhuni/lunafox-frontend:${IMAGE_TAG:?IMAGE_TAG is required}
    restart: always
    depends_on:
      server:
        condition: service_healthy
    healthcheck:
      test: ["CMD", "node", "-e", "require('http').get('http://localhost:3000',res=>process.exit(res.statusCode<500?0:1)).on('error',()=>process.exit(1))"]
      interval: 5s
      timeout: 5s
      retries: 20
      start_period: 20s

  nginx:
    image: ${DOCKER_USER:-yyhuni}/xingrin-nginx:${IMAGE_TAG:?IMAGE_TAG is required}
    image: yyhuni/lunafox-nginx:${IMAGE_TAG:?IMAGE_TAG is required}
    restart: always
    depends_on:
      server:
        condition: service_healthy
      frontend:
        condition: service_started
        condition: service_healthy
    ports:
      - "8083:8083"
    volumes:
@@ -103,4 +83,4 @@ volumes:

networks:
  default:
    name: xingrin_network # fixed network name, independent of directory name
    name: lunafox_network # fixed network name, independent of directory name
@@ -1,4 +1,4 @@
FROM nginx:1.27-alpine
FROM nginx:1.28.1-alpine

# Copy nginx configuration and certificates
COPY docker/nginx/nginx.conf /etc/nginx/nginx.conf
@@ -9,7 +9,7 @@ http {

    # Upstream services
    upstream backend {
        server server:8888;
        server server:8080;
    }

    upstream frontend {
@@ -31,18 +31,11 @@ http {
    # Auto-redirect HTTP requests that hit the HTTPS port
    error_page 497 =301 https://$host:$server_port$request_uri;

    # Fingerprint header - lets FOFA/Shodan and similar search engines identify the service
    add_header X-Powered-By "Xingrin ASM" always;
    # Fingerprint header
    add_header X-Powered-By "LunaFox ASM" always;

    location /api/ {
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_pass http://backend;
    }

    # WebSocket reverse proxy
    location /ws/ {
    # Agent WebSocket
    location /api/agents/ws {
        proxy_pass http://backend;
        proxy_http_version 1.1;
        proxy_set_header Upgrade $http_upgrade;
@@ -50,9 +43,52 @@ http {
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Proto $scheme;
        proxy_set_header X-Forwarded-Host $host;
        proxy_set_header X-Forwarded-Port $server_port;
        proxy_read_timeout 86400;  # 24 hours, prevents WebSocket timeouts
    }

    # Health check
    location /health {
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Proto $scheme;
        proxy_set_header X-Forwarded-Host $host;
        proxy_set_header X-Forwarded-Port $server_port;
        proxy_read_timeout 30s;
        proxy_send_timeout 30s;
        proxy_pass http://backend;
    }

    location /api/ {
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Proto $scheme;
        proxy_set_header X-Forwarded-Host $host;
        proxy_set_header X-Forwarded-Port $server_port;
        proxy_read_timeout 300s;  # 5 minutes, supports large data exports
        proxy_send_timeout 300s;
        proxy_pass http://backend;
    }

    # Next.js HMR (dev)
    location /_next/webpack-hmr {
        proxy_pass http://frontend;
        proxy_http_version 1.1;
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection "upgrade";
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Proto $scheme;
        proxy_set_header X-Forwarded-Host $host;
        proxy_set_header X-Forwarded-Port $server_port;
        proxy_read_timeout 86400;
    }

    # Frontend reverse proxy
    location / {
        proxy_set_header Host $host;
docs/redis-stream-queue-design.md (new file, 574 lines)
@@ -0,0 +1,574 @@
# Redis Stream Queue Design Document

## Overview

This document describes a design that uses Redis Stream as a message queue to optimize large-scale data writes.

## Background

### Current problems

When scanning a large number of Endpoints (hundreds of thousands), the current HTTP bulk-write approach has the following problems:

1. **Performance bottleneck**: 500k Endpoints (15 KB each) take 83-166 minutes
2. **Database I/O pressure**: 20 workers writing concurrently saturate database I/O
3. **Worker blocking risk**: with bulk writes plus backpressure, workers block while waiting

### Goals

- 10x performance improvement (83 minutes → 8 minutes)
- Workers never block (stable scan speed)
- No data loss (guaranteed by persistence)
- No new components to deploy (reuse the existing Redis)

## Architecture

### Overall architecture

```
Worker scan → Redis Stream → Server consumer → PostgreSQL
```

### Data flow

1. **Worker side**: scans an Endpoint → publishes it to a Redis Stream
2. **Redis Stream**: buffers messages (persisted to disk)
3. **Server side**: single-threaded consumer → bulk-writes to the database

### Key properties

- **Decoupling**: workers and the database are fully decoupled
- **Backpressure**: the server controls consumption speed, protecting the database
- **Persistence**: Redis AOF guarantees no data loss
- **Scalability**: supports many workers writing concurrently

## Redis Stream Configuration

### Enable AOF persistence

```conf
# redis.conf
appendonly yes
appendfsync everysec  # sync once per second (balances performance and safety)
```

**Effect**:
- Data is persisted to disk
- A Redis crash loses at most 1 second of data
- Small performance impact
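Since the whole reliability story depends on AOF actually being on, a startup sanity check is cheap insurance. A minimal sketch using go-redis v9; the `rdb` address and the fail-fast policy are assumptions, not part of the design above:

```go
package main

import (
	"context"
	"log"

	"github.com/redis/go-redis/v9"
)

func main() {
	ctx := context.Background()
	rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379"}) // assumed address

	// CONFIG GET appendonly returns a map like {"appendonly": "yes"}.
	cfg, err := rdb.ConfigGet(ctx, "appendonly").Result()
	if err != nil {
		log.Fatalf("config get: %v", err)
	}
	if cfg["appendonly"] != "yes" {
		// Fail fast: without AOF, a Redis crash can drop queued endpoints.
		log.Fatal("appendonly is disabled; enable AOF before using the stream queue")
	}
	log.Println("AOF persistence enabled")
}
```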

### Memory configuration

```conf
# redis.conf
maxmemory 2gb
maxmemory-policy allkeys-lru  # evict least-recently-used keys when memory runs low
```

## Implementation

### 1. Worker side: publish to Redis Stream

#### Code layout

```
worker/internal/queue/
├── redis_publisher.go  # Redis publisher
└── types.go            # data type definitions
```

#### Core implementation

```go
// worker/internal/queue/redis_publisher.go
package queue

import (
	"context"
	"encoding/json"
	"fmt"

	"github.com/redis/go-redis/v9"
)

type RedisPublisher struct {
	client *redis.Client
}

func NewRedisPublisher(redisURL string) (*RedisPublisher, error) {
	opt, err := redis.ParseURL(redisURL)
	if err != nil {
		return nil, err
	}

	client := redis.NewClient(opt)

	// Verify the connection
	if err := client.Ping(context.Background()).Err(); err != nil {
		return nil, err
	}

	return &RedisPublisher{client: client}, nil
}

// PublishEndpoint publishes an Endpoint to the Redis Stream
func (p *RedisPublisher) PublishEndpoint(ctx context.Context, scanID int, endpoint Endpoint) error {
	data, err := json.Marshal(endpoint)
	if err != nil {
		return err
	}

	streamName := fmt.Sprintf("endpoints:%d", scanID)

	return p.client.XAdd(ctx, &redis.XAddArgs{
		Stream: streamName,
		MaxLen: 1000000, // keep at most 1M messages (prevents unbounded memory growth)
		Approx: true,    // approximate trimming (better performance)
		Values: map[string]interface{}{
			"data": data,
		},
	}).Err()
}

// Close closes the connection
func (p *RedisPublisher) Close() error {
	return p.client.Close()
}
```

#### Usage example

```go
// Worker scan flow
func (w *Worker) ScanEndpoints(ctx context.Context, scanID int) error {
	// Initialize the Redis publisher
	publisher, err := queue.NewRedisPublisher(os.Getenv("REDIS_URL"))
	if err != nil {
		return err
	}
	defer publisher.Close()

	// Scan endpoints
	for endpoint := range w.scan() {
		// Publish to the Redis Stream (non-blocking, very fast)
		if err := publisher.PublishEndpoint(ctx, scanID, endpoint); err != nil {
			log.Printf("Failed to publish endpoint: %v", err)
			// Optionally retry or record the error
		}
	}

	return nil
}
```
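One `XAdd` per endpoint costs one network round trip each. If publish throughput ever becomes a limit, the calls can be batched with a pipeline. A hedged sketch; the `PublishBatch` helper and its batching policy are illustrative additions, not part of the design above:

```go
// PublishBatch publishes a slice of endpoints in a single pipelined round trip.
// Illustrative helper; batch sizing and retry are left to the caller.
func (p *RedisPublisher) PublishBatch(ctx context.Context, scanID int, endpoints []Endpoint) error {
	streamName := fmt.Sprintf("endpoints:%d", scanID)

	_, err := p.client.Pipelined(ctx, func(pipe redis.Pipeliner) error {
		for _, ep := range endpoints {
			data, err := json.Marshal(ep)
			if err != nil {
				return err
			}
			pipe.XAdd(ctx, &redis.XAddArgs{
				Stream: streamName,
				MaxLen: 1000000,
				Approx: true,
				Values: map[string]interface{}{"data": data},
			})
		}
		return nil
	})
	return err
}
```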

### 2. Server side: consume the Redis Stream

#### Code layout

```
server/internal/queue/
├── redis_consumer.go  # Redis consumer
├── batch_writer.go    # batch writer
└── types.go           # data type definitions
```

#### Core implementation

```go
// server/internal/queue/redis_consumer.go
package queue

import (
	"context"
	"encoding/json"
	"fmt"
	"time"

	"github.com/redis/go-redis/v9"
	"github.com/yyhuni/lunafox/server/internal/repository"
)

type EndpointConsumer struct {
	client     *redis.Client
	repository *repository.EndpointRepository
}

func NewEndpointConsumer(redisURL string, repo *repository.EndpointRepository) (*EndpointConsumer, error) {
	opt, err := redis.ParseURL(redisURL)
	if err != nil {
		return nil, err
	}

	client := redis.NewClient(opt)

	return &EndpointConsumer{
		client:     client,
		repository: repo,
	}, nil
}

// Start runs the consumer (single-threaded, which throttles the write rate)
func (c *EndpointConsumer) Start(ctx context.Context, scanID int) error {
	streamName := fmt.Sprintf("endpoints:%d", scanID)
	groupName := "endpoint-consumers"
	consumerName := fmt.Sprintf("server-%d", time.Now().Unix())

	// Create the consumer group if it does not exist
	// (the BUSYGROUP error returned when it already exists is deliberately ignored)
	c.client.XGroupCreateMkStream(ctx, streamName, groupName, "0")

	// Batch writer (flushes every 5000 entries)
	batchWriter := NewBatchWriter(c.repository, 5000)
	defer batchWriter.Flush()

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}

		// Read messages in batches
		streams, err := c.client.XReadGroup(ctx, &redis.XReadGroupArgs{
			Group:    groupName,
			Consumer: consumerName,
			Streams:  []string{streamName, ">"},
			Count:    100,         // read up to 100 messages at a time
			Block:    time.Second, // block for 1 second (Block is a time.Duration)
		}).Result()

		if err != nil {
			if err == redis.Nil {
				continue // no new messages
			}
			return err
		}

		// Process messages
		for _, stream := range streams {
			for _, message := range stream.Messages {
				// Parse the message
				var endpoint Endpoint
				if err := json.Unmarshal([]byte(message.Values["data"].(string)), &endpoint); err != nil {
					// Log the error and continue with the next message
					continue
				}

				// Hand it to the batch writer
				if err := batchWriter.Add(endpoint); err != nil {
					return err
				}

				// Acknowledge the message (ACK)
				c.client.XAck(ctx, streamName, groupName, message.ID)
			}
		}

		// Periodic flush
		if batchWriter.ShouldFlush() {
			if err := batchWriter.Flush(); err != nil {
				return err
			}
		}
	}
}

// Close closes the connection
func (c *EndpointConsumer) Close() error {
	return c.client.Close()
}
```
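Note an at-least-once caveat in `Start` above: each message is ACKed as soon as it enters the batch buffer, so a server crash between the ACK and the next `Flush` loses that buffered batch. If that window matters, collect the IDs and ACK only after the flush succeeds. A hedged sketch of the reordering; the `processBatch` helper and `pending` bookkeeping are assumptions layered on the code above:

```go
// processBatch drains one XReadGroup result into the batch writer and ACKs
// only after the batch has reached the database (sketch; would replace the
// inner message loop of Start above).
func (c *EndpointConsumer) processBatch(ctx context.Context, streamName, groupName string,
	streams []redis.XStream, batchWriter *BatchWriter) error {

	pending := make([]string, 0, 512) // message IDs awaiting acknowledgement

	for _, stream := range streams {
		for _, message := range stream.Messages {
			var endpoint Endpoint
			if err := json.Unmarshal([]byte(message.Values["data"].(string)), &endpoint); err != nil {
				continue // malformed message: skip (it stays pending for later inspection)
			}
			if err := batchWriter.Add(endpoint); err != nil {
				return err
			}
			pending = append(pending, message.ID)
		}
	}

	// Flush first, then ACK: a crash before this point re-delivers the batch.
	// Re-delivery is safe here because the write path is BulkUpsert (idempotent).
	if err := batchWriter.Flush(); err != nil {
		return err
	}
	if len(pending) > 0 {
		return c.client.XAck(ctx, streamName, groupName, pending...).Err()
	}
	return nil
}
```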

#### Batch writer

```go
// server/internal/queue/batch_writer.go
package queue

import (
	"sync"

	"github.com/yyhuni/lunafox/server/internal/model"
	"github.com/yyhuni/lunafox/server/internal/repository"
)

type BatchWriter struct {
	repository *repository.EndpointRepository
	buffer     []model.Endpoint
	batchSize  int
	mu         sync.Mutex
}

func NewBatchWriter(repo *repository.EndpointRepository, batchSize int) *BatchWriter {
	return &BatchWriter{
		repository: repo,
		batchSize:  batchSize,
		buffer:     make([]model.Endpoint, 0, batchSize),
	}
}

// Add appends to the buffer
func (w *BatchWriter) Add(endpoint model.Endpoint) error {
	w.mu.Lock()
	w.buffer = append(w.buffer, endpoint)
	shouldFlush := len(w.buffer) >= w.batchSize
	w.mu.Unlock()

	if shouldFlush {
		return w.Flush()
	}
	return nil
}

// ShouldFlush reports whether a flush is due
func (w *BatchWriter) ShouldFlush() bool {
	w.mu.Lock()
	defer w.mu.Unlock()
	return len(w.buffer) >= w.batchSize
}

// Flush bulk-writes the buffer to the database
func (w *BatchWriter) Flush() error {
	w.mu.Lock()
	if len(w.buffer) == 0 {
		w.mu.Unlock()
		return nil
	}

	// Copy the buffer
	toWrite := make([]model.Endpoint, len(w.buffer))
	copy(toWrite, w.buffer)
	w.buffer = w.buffer[:0]
	w.mu.Unlock()

	// Bulk write (reuses the existing BulkUpsert method)
	_, err := w.repository.BulkUpsert(toWrite)
	return err
}
```

### 3. Server: starting the consumer

```go
// server/internal/app/app.go
func Run(ctx context.Context, cfg config.Config) error {
	// ... existing code

	// Start the Redis consumer (runs in the background)
	consumer, err := queue.NewEndpointConsumer(cfg.RedisURL, endpointRepo)
	if err != nil {
		return err
	}

	go func() {
		// Consume all active scan tasks
		for {
			// Fetch the active scan tasks
			scans := scanRepo.GetActiveScans()
			for _, scan := range scans {
				go consumer.Start(ctx, scan.ID)
			}
			time.Sleep(10 * time.Second)
		}
	}()

	// ... existing code
}
```
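As written, the polling loop above starts a fresh consumer goroutine for the same scan every 10 seconds. A guard that remembers which scans already have a consumer avoids that. A minimal sketch; the `started` map is an assumption layered on the code above (only the polling goroutine touches it, so no extra locking is needed):

```go
go func() {
	started := make(map[int]bool) // scan IDs that already have a consumer

	for {
		scans := scanRepo.GetActiveScans()
		for _, scan := range scans {
			if started[scan.ID] {
				continue // consumer already running for this scan
			}
			started[scan.ID] = true
			go func(id int) {
				if err := consumer.Start(ctx, id); err != nil {
					log.Printf("consumer for scan %d stopped: %v", id, err)
				}
			}(scan.ID)
		}
		time.Sleep(10 * time.Second)
	}
}()
```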

## Performance Comparison

### 500k Endpoints (15 KB each)

| Approach | Write speed | Total time | Memory | Worker blocking |
|------|---------|--------|---------|-----------|
| **Current (HTTP bulk)** | 100 /s | 83 min | 1.5 MB | No |
| **Redis Stream** | 1000 /s | 8 min | 75 MB | No |

The arithmetic: 500,000 ÷ 100/s = 5,000 s ≈ 83 minutes, versus 500,000 ÷ 1,000/s = 500 s ≈ 8 minutes.

**Improvement**: **10x performance!**

## Resource Consumption

### Redis

| Item | Consumption |
|------|------|
| Memory | ~500 MB (buffering 1M messages) |
| CPU | ~10% (serialization/deserialization) |
| Disk | ~7.5 GB (AOF persistence) |
| Bandwidth | ~50 MB/s |

### Server

| Item | Consumption |
|------|------|
| Memory | 75 MB (batch-write buffer) |
| CPU | 30% (deserialization + database writes) |
| Database connections | 1 (single-threaded consumer) |

## Reliability Guarantees

### No data loss

1. **Redis AOF persistence**: synced to disk every second; at most 1 second of data is lost
2. **Message acknowledgement**: the server ACKs only after processing succeeds
3. **Automatic retry**: un-ACKed messages are re-delivered via the pending list (see the sketch below)
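The "automatic retry" above needs a nudge in practice: messages read by a consumer that then dies stay in the group's pending entries list until some consumer reclaims them. A hedged sketch of a reclaim pass using `XAutoClaim` from go-redis v9; the 5-minute idle threshold and the `reclaimStale`/`handle` names are assumptions, not part of the design above:

```go
// reclaimStale rescans the group's pending entries and claims messages that
// have sat unacknowledged for more than the idle threshold, so a replacement
// consumer can process them.
func (c *EndpointConsumer) reclaimStale(ctx context.Context, scanID int, consumerName string,
	handle func(redis.XMessage) error) error {

	streamName := fmt.Sprintf("endpoints:%d", scanID)
	start := "0-0" // cursor over the pending entries list

	for {
		msgs, next, err := c.client.XAutoClaim(ctx, &redis.XAutoClaimArgs{
			Stream:   streamName,
			Group:    "endpoint-consumers",
			Consumer: consumerName,
			MinIdle:  5 * time.Minute, // only steal messages idle at least this long
			Start:    start,
			Count:    100,
		}).Result()
		if err != nil {
			return err
		}

		for _, msg := range msgs {
			// handle should parse, batch, and ACK exactly like the main loop.
			if err := handle(msg); err != nil {
				return err
			}
		}

		if next == "0-0" || len(msgs) == 0 {
			return nil // pending list fully scanned
		}
		start = next
	}
}
```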

### Failure recovery

| Failure scenario | Recovery mechanism |
|---------|---------|
| Worker crash | Messages already sent to Redis are unaffected |
| Redis crash | AOF recovery; at most 1 second of data lost |
| Server crash | Un-ACKed messages are re-delivered |
| Database crash | Messages stay in Redis; consumption resumes after recovery |

## Scalability

### Multiple workers

- Redis Stream natively supports multiple producers
- No extra configuration needed

### Multiple server consumers

```go
// Start several consumers (load balancing)
for i := 0; i < 3; i++ {
	go consumer.Start(ctx, scanID)
}
```

Redis Stream consumer groups distribute messages across consumers automatically, giving load balancing. (Note that `Start` derives the consumer name from the current Unix second, so consumers launched within the same second would collide; give each one a distinct name in practice.)

## Monitoring and Operations

### Metrics

```go
// Queue length
func (c *EndpointConsumer) GetQueueLength(ctx context.Context, scanID int) (int64, error) {
	streamName := fmt.Sprintf("endpoints:%d", scanID)
	return c.client.XLen(ctx, streamName).Result()
}

// Consumer group info
func (c *EndpointConsumer) GetConsumerGroupInfo(ctx context.Context, scanID int) ([]redis.XInfoGroup, error) {
	streamName := fmt.Sprintf("endpoints:%d", scanID)
	return c.client.XInfoGroups(ctx, streamName).Result()
}
```
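Beyond raw stream length, the backlog that matters operationally is the pending count: messages delivered but not yet ACKed, which grows when consumers stall. A hedged sketch of one more metric; the `GetPendingCount` name is illustrative:

```go
// GetPendingCount returns how many delivered-but-unacknowledged messages the
// consumer group currently holds for a scan's stream.
func (c *EndpointConsumer) GetPendingCount(ctx context.Context, scanID int) (int64, error) {
	streamName := fmt.Sprintf("endpoints:%d", scanID)

	info, err := c.client.XPending(ctx, streamName, "endpoint-consumers").Result()
	if err != nil {
		return 0, err
	}
	return info.Count, nil
}
```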

### Cleanup strategy

```go
// Delete the stream once a scan completes
func (c *EndpointConsumer) CleanupStream(ctx context.Context, scanID int) error {
	streamName := fmt.Sprintf("endpoints:%d", scanID)
	return c.client.Del(ctx, streamName).Err()
}
```

## Configuration Recommendations

### Redis

```conf
# redis.conf

# Persistence
appendonly yes
appendfsync everysec

# Memory
maxmemory 2gb
maxmemory-policy allkeys-lru

# Performance
tcp-backlog 511
timeout 0
tcp-keepalive 300
```

### Environment variables

```bash
# Worker side
REDIS_URL=redis://localhost:6379/0

# Server side
REDIS_URL=redis://localhost:6379/0
```

## Migration Plan

### Phase 1: Preparation (1 day)

1. Enable Redis AOF persistence
2. Implement the worker-side Redis publisher
3. Implement the server-side Redis consumer

### Phase 2: Testing (2 days)

1. Unit tests
2. Integration tests
3. Performance tests (simulating 500k records)

### Phase 3: Gradual rollout (3 days)

1. 10% of traffic on Redis Stream
2. 50% of traffic on Redis Stream
3. 100% of traffic on Redis Stream
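The percentage rollout needs a switch in the worker's write path. A hedged sketch of one way to gate it; the `ROLLOUT_PERCENT` variable and `useStreamQueue` helper are assumptions, not part of the plan above:

```go
package queue

import (
	"os"
	"strconv"
)

// useStreamQueue decides, per scan, whether to publish to the Redis Stream or
// fall back to the old HTTP bulk write, based on a percentage rollout knob.
func useStreamQueue(scanID int) bool {
	pct, err := strconv.Atoi(os.Getenv("ROLLOUT_PERCENT")) // e.g. "10", "50", "100"
	if err != nil {
		return false // unset or malformed: stay on the old path
	}
	// Keying on the scan ID keeps a given scan on one path for its whole run.
	return scanID%100 < pct
}
```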

### Phase 4: Cleanup (1 day)

1. Remove the old HTTP bulk-write code
2. Update the documentation

## Risks and Mitigations

### Risk 1: Redis memory overflow

**Mitigation**:
- Set a `maxmemory` limit
- Cap stream length with `MaxLen`
- Monitor Redis memory usage

### Risk 2: Message backlog

**Mitigation**:
- Increase the number of server consumers
- Optimize database write performance
- Monitor queue length

### Risk 3: Data loss

**Mitigation**:
- Enable AOF persistence
- Use the message acknowledgement mechanism
- Back up Redis regularly

## Summary

### Advantages

- ✅ 10x performance improvement
- ✅ Workers never block
- ✅ No data loss (AOF persistence)
- ✅ No new components to deploy (reuses the existing Redis)
- ✅ Simple architecture, easy to maintain

### When to use it

- More than 100k records
- Redis already deployed
- High-throughput writes needed
- No complex message routing needed

### When not to use it

- Fewer than 100k records (the current approach is sufficient)
- Complex message routing needed (consider RabbitMQ)
- More than 10M records (consider Kafka)

## References

- [Redis Stream documentation](https://redis.io/docs/data-types/streams/)
- [Redis persistence](https://redis.io/docs/management/persistence/)
- [go-redis documentation](https://redis.uptrace.dev/)
frontend/.gitignore (vendored, 1 line changed)
@@ -9,6 +9,7 @@
!.yarn/plugins
!.yarn/releases
!.yarn/versions
.pnpm-store/

# testing
/coverage
frontend/Dockerfile (new file, 60 lines)
@@ -0,0 +1,60 @@
# Frontend Next.js Dockerfile
# Multi-stage build with BuildKit caching

# ==================== Dependencies stage ====================
FROM node:20.20.0-alpine AS deps
WORKDIR /app

# Install pnpm
RUN corepack enable && corepack prepare pnpm@latest --activate

# Copy dependency manifests
COPY frontend/package.json frontend/pnpm-lock.yaml ./

# Install dependencies (BuildKit cache)
RUN --mount=type=cache,target=/root/.local/share/pnpm/store \
    pnpm install --frozen-lockfile

# ==================== Build stage ====================
FROM node:20.20.0-alpine AS builder
WORKDIR /app

RUN corepack enable && corepack prepare pnpm@latest --activate

# Copy deps
COPY --from=deps /app/node_modules ./node_modules
COPY frontend/ ./

# Build-time env
ARG IMAGE_TAG=unknown
ENV NEXT_PUBLIC_IMAGE_TAG=${IMAGE_TAG}
# Use service name "server" inside Docker network
ENV API_HOST=server

# Build (BuildKit cache)
RUN --mount=type=cache,target=/app/.next/cache \
    pnpm build

# ==================== Runtime stage ====================
FROM node:20.20.0-alpine AS runner
WORKDIR /app

ENV NODE_ENV=production

# Create non-root user
RUN addgroup --system --gid 1001 nodejs
RUN adduser --system --uid 1001 nextjs

# Copy build output
COPY --from=builder /app/public ./public
COPY --from=builder --chown=nextjs:nodejs /app/.next/standalone ./
COPY --from=builder --chown=nextjs:nodejs /app/.next/static ./.next/static

USER nextjs

EXPOSE 3000

ENV PORT=3000
ENV HOSTNAME="0.0.0.0"

CMD ["node", "server.js"]
@@ -4,27 +4,27 @@ import { VulnSeverityChart } from "@/components/dashboard/vuln-severity-chart"
import { DashboardDataTable } from "@/components/dashboard/dashboard-data-table"

/**
 * 仪表板页面组件
 * 这是应用的主要仪表板页面,包含卡片、图表和数据表格
 * 布局结构已移至根布局组件中
 * Dashboard page component
 * This is the main dashboard page of the application, containing cards, charts and data tables
 * Layout structure has been moved to the root layout component
 */
export default function Page() {
  return (
    // 内容区域,包含卡片、图表和数据表格
    // Content area containing cards, charts and data tables
    <div className="flex flex-col gap-4 py-4 md:gap-6 md:py-6">
      {/* 顶部统计卡片 */}
      {/* Top statistics cards */}
      <DashboardStatCards />

      {/* 图表区域 - 趋势图 + 漏洞分布 */}
      {/* Chart area - Trend chart + Vulnerability distribution */}
      <div className="grid gap-4 px-4 lg:px-6 @xl/main:grid-cols-2">
        {/* 资产趋势折线图 */}
        {/* Asset trend line chart */}
        <AssetTrendChart />

        {/* 漏洞严重程度分布 */}
        {/* Vulnerability severity distribution */}
        <VulnSeverityChart />
      </div>

      {/* 漏洞 / 扫描历史 Tab */}
      {/* Vulnerabilities / Scan history tab */}
      <div className="px-4 lg:px-6">
        <DashboardDataTable />
      </div>
@@ -3,15 +3,17 @@ import type { Metadata } from "next"
import { NextIntlClientProvider } from 'next-intl'
import { getMessages, setRequestLocale, getTranslations } from 'next-intl/server'
import { notFound } from 'next/navigation'
import { cookies } from "next/headers"
import { locales, localeHtmlLang, type Locale } from '@/i18n/config'
import { COLOR_THEME_COOKIE_KEY, isColorThemeId, DEFAULT_COLOR_THEME_ID, isDarkColorTheme } from "@/lib/color-themes"

// 导入全局样式文件
// Import global style files
import "../globals.css"
// 导入思源黑体(Noto Sans SC)本地字体
// Import Noto Sans SC local font
import "@fontsource/noto-sans-sc/400.css"
import "@fontsource/noto-sans-sc/500.css"
import "@fontsource/noto-sans-sc/700.css"
// 导入颜色主题
// Import color themes
import "@/styles/themes/bubblegum.css"
import "@/styles/themes/quantum-rose.css"
import "@/styles/themes/clean-slate.css"
@@ -24,13 +26,15 @@ import { Suspense } from "react"
import Script from "next/script"
import { QueryProvider } from "@/components/providers/query-provider"
import { ThemeProvider } from "@/components/providers/theme-provider"
import { UiI18nProvider } from "@/components/providers/ui-i18n-provider"
import { ColorThemeInit } from "@/components/color-theme-init"

// 导入公共布局组件
// Import common layout components
import { RoutePrefetch } from "@/components/route-prefetch"
import { RouteProgress } from "@/components/route-progress"
import { AuthLayout } from "@/components/auth/auth-layout"

// 动态生成元数据
// Dynamically generate metadata
export async function generateMetadata({ params }: { params: Promise<{ locale: string }> }): Promise<Metadata> {
  const { locale } = await params
  const t = await getTranslations({ locale, namespace: 'metadata' })
@@ -39,8 +43,15 @@ export async function generateMetadata({ params }: { params: Promise<{ locale: s
    title: t('title'),
    description: t('description'),
    keywords: t('keywords').split(',').map(k => k.trim()),
    generator: "Xingrin ASM Platform",
    generator: "LunaFox ASM Platform",
    authors: [{ name: "yyhuni" }],
    icons: {
      icon: [
        { url: "/images/icon-64.png", sizes: "64x64", type: "image/png" },
        { url: "/images/icon-256.png", sizes: "256x256", type: "image/png" },
      ],
      apple: [{ url: "/images/icon-256.png", sizes: "256x256", type: "image/png" }],
    },
    openGraph: {
      title: t('ogTitle'),
      description: t('ogDescription'),
@@ -54,7 +65,7 @@ export async function generateMetadata({ params }: { params: Promise<{ locale: s
  }
}

// 使用思源黑体 + 系统字体回退,完全本地加载
// Use Noto Sans SC + system font fallback, fully loaded locally
const fontConfig = {
  className: "font-sans",
  style: {
@@ -62,7 +73,7 @@ const fontConfig = {
  }
}

// 生成静态参数,支持所有语言
// Generate static parameters, support all languages
export function generateStaticParams() {
  return locales.map((locale) => ({ locale }))
}
@@ -73,8 +84,8 @@ interface Props {
}

/**
 * 语言布局组件
 * 包装所有页面,提供国际化上下文
 * Language layout component
 * Wraps all pages, provides internationalization context
 */
export default async function LocaleLayout({
  children,
@@ -82,47 +93,61 @@ export default async function LocaleLayout({
}: Props) {
  const { locale } = await params

  // 验证 locale 有效性
  // Validate locale validity
  if (!locales.includes(locale as Locale)) {
    notFound()
  }

  // 启用静态渲染
  // Enable static rendering
  setRequestLocale(locale)

  // 加载翻译消息
  // Load translation messages
  const messages = await getMessages()

  const cookieStore = await cookies()
  const cookieTheme = cookieStore.get(COLOR_THEME_COOKIE_KEY)?.value
  const themeId = isColorThemeId(cookieTheme) ? cookieTheme : DEFAULT_COLOR_THEME_ID
  const isDark = isDarkColorTheme(themeId)

  return (
    <html lang={localeHtmlLang[locale as Locale]} suppressHydrationWarning>
    <html
      lang={localeHtmlLang[locale as Locale]}
      data-theme={themeId}
      className={isDark ? "dark" : undefined}
      suppressHydrationWarning
    >
      <body className={fontConfig.className} style={fontConfig.style}>
        {/* 加载外部脚本 */}
        <ColorThemeInit />
        {/* Load external scripts */}
        <Script
          src="https://tweakcn.com/live-preview.min.js"
          strategy="beforeInteractive"
          crossOrigin="anonymous"
        />
        {/* 路由加载进度条 */}
        {/* Route loading progress bar */}
        <Suspense fallback={null}>
          <RouteProgress />
        </Suspense>
        {/* ThemeProvider 提供主题切换功能 */}
        {/* ThemeProvider provides theme switching functionality */}
        <ThemeProvider
          attribute="class"
          defaultTheme="dark"
          defaultTheme={isDark ? "dark" : "light"}
          enableSystem
          disableTransitionOnChange
        >
          {/* NextIntlClientProvider 提供国际化上下文 */}
          {/* NextIntlClientProvider provides internationalization context */}
          <NextIntlClientProvider messages={messages}>
            {/* QueryProvider 提供 React Query 功能 */}
            {/* QueryProvider provides React Query functionality */}
            <QueryProvider>
              {/* 路由预加载 */}
              <RoutePrefetch />
              {/* AuthLayout 处理认证和侧边栏显示 */}
              <AuthLayout>
                {children}
              </AuthLayout>
              {/* UiI18nProvider provides UI component translations */}
              <UiI18nProvider>
                {/* Route prefetch */}
                <RoutePrefetch />
                {/* AuthLayout handles authentication and sidebar display */}
                <AuthLayout>
                  {children}
                </AuthLayout>
              </UiI18nProvider>
            </QueryProvider>
          </NextIntlClientProvider>
        </ThemeProvider>

@@ -16,13 +16,17 @@ export async function generateMetadata({ params }: Props): Promise<Metadata> {
}

/**
 * 登录页面布局
 * 不包含侧边栏和头部
 * Login page layout
 * Does not include sidebar and header
 */
export default function LoginLayout({
  children,
}: {
  children: React.ReactNode
}) {
  return children
  return (
    <>
      {children}
    </>
  )
}
@@ -2,124 +2,351 @@

import React from "react"
import { useRouter } from "next/navigation"
import { useTranslations } from "next-intl"
import Lottie from "lottie-react"
import securityAnimation from "@/public/animations/Security000-Purple.json"
import { Button } from "@/components/ui/button"
import { Input } from "@/components/ui/input"
import { Card, CardContent } from "@/components/ui/card"
import {
  Field,
  FieldGroup,
  FieldLabel,
} from "@/components/ui/field"
import { Spinner } from "@/components/ui/spinner"
import { useLocale, useTranslations } from "next-intl"
import { useQueryClient } from "@tanstack/react-query"
import { TerminalLogin } from "@/components/ui/terminal-login"
import { LoadingState } from "@/components/loading-spinner"
import { useLogin, useAuth } from "@/hooks/use-auth"
import { useRoutePrefetch } from "@/hooks/use-route-prefetch"
import { vulnerabilityKeys } from "@/hooks/use-vulnerabilities"
import { getAssetStatistics, getStatisticsHistory } from "@/services/dashboard.service"
import { getScans } from "@/services/scan.service"
import { VulnerabilityService } from "@/services/vulnerability.service"

export default function LoginPage() {
  // Prefetch all page components on the login page
  useRoutePrefetch()
  const router = useRouter()
  const queryClient = useQueryClient()
  const { data: auth, isLoading: authLoading } = useAuth()
  const { mutate: login, isPending } = useLogin()
  const t = useTranslations("auth")

  const [username, setUsername] = React.useState("")
  const [password, setPassword] = React.useState("")
  const { mutateAsync: login, isPending } = useLogin()
  const t = useTranslations("auth.terminal")
  const locale = useLocale()

  // If already logged in, redirect to the dashboard
  const loginStartedRef = React.useRef(false)
  const [loginReady, setLoginReady] = React.useState(false)

  const [isReady, setIsReady] = React.useState(false)
  const [loginProcessing, setLoginProcessing] = React.useState(false)
  const [isExiting, setIsExiting] = React.useState(false)
  const exitStartedRef = React.useRef(false)
  const showLoading = !isReady || loginProcessing
  const showExitOverlay = isExiting

  const withLocale = React.useCallback((path: string) => {
    if (path.startsWith(`/${locale}/`)) return path
    const normalized = path.startsWith("/") ? path : `/${path}`
    return `/${locale}${normalized}`
  }, [locale])

  // Hide the inline boot splash and show login content
  React.useEffect(() => {
    if (auth?.authenticated) {
      router.push("/dashboard/")
    let cancelled = false

    const waitForLoad = new Promise<void>((resolve) => {
      if (typeof document === "undefined") {
        resolve()
        return
      }
      if (document.readyState === "complete") {
        resolve()
        return
      }
      const handleLoad = () => resolve()
      window.addEventListener("load", handleLoad, { once: true })
    })

    const waitForPrefetch = new Promise<void>((resolve) => {
      if (typeof window === "undefined") {
        resolve()
        return
      }
      const w = window as Window & { __lunafoxRoutePrefetchDone?: boolean }
      if (w.__lunafoxRoutePrefetchDone) {
        resolve()
        return
      }
      const handlePrefetchDone = () => resolve()
      window.addEventListener("lunafox:route-prefetch-done", handlePrefetchDone, { once: true })
    })

    const waitForPrefetchOrTimeout = Promise.race([
      waitForPrefetch,
      new Promise<void>((resolve) => setTimeout(resolve, 3000)),
    ])

    Promise.all([waitForLoad, waitForPrefetchOrTimeout]).then(() => {
      if (cancelled) return
      setIsReady(true)
    })

    return () => {
      cancelled = true
    }
  }, [auth, router])
  }, [])

  const handleSubmit = (e: React.FormEvent) => {
    e.preventDefault()
    login({ username, password })
  }
  // Extract the prefetch logic into a reusable function
  const prefetchDashboardData = React.useCallback(async () => {
    const scansParams = { page: 1, pageSize: 10 }
    const vulnsParams = { page: 1, pageSize: 10 }

  // Show a spinner while loading
  if (authLoading) {
    return (
      <div className="flex min-h-svh w-full flex-col items-center justify-center gap-4 bg-background">
        <Spinner className="size-8 text-primary" />
        <p className="text-muted-foreground text-sm" suppressHydrationWarning>loading...</p>
      </div>
    )
  }
    return Promise.allSettled([
      queryClient.prefetchQuery({
        queryKey: ["asset", "statistics"],
        queryFn: getAssetStatistics,
      }),
      queryClient.prefetchQuery({
        queryKey: ["asset", "statistics", "history", 7],
        queryFn: () => getStatisticsHistory(7),
      }),
      queryClient.prefetchQuery({
        queryKey: ["scans", scansParams],
        queryFn: () => getScans(scansParams),
      }),
      queryClient.prefetchQuery({
        queryKey: vulnerabilityKeys.list(vulnsParams),
        queryFn: () => VulnerabilityService.getAllVulnerabilities(vulnsParams),
      }),
    ])
  }, [queryClient])

  // Do not show the login page when already logged in
  if (auth?.authenticated) {
    return null
  // Memoize translations object to avoid recreating on every render
  const translations = React.useMemo(() => ({
    title: t("title"),
    subtitle: t("subtitle"),
    usernamePrompt: t("usernamePrompt"),
    passwordPrompt: t("passwordPrompt"),
    authenticating: t("authenticating"),
    processing: t("processing"),
    accessGranted: t("accessGranted"),
    welcomeMessage: t("welcomeMessage"),
    authFailed: t("authFailed"),
    invalidCredentials: t("invalidCredentials"),
    shortcuts: t("shortcuts"),
    submit: t("submit"),
    cancel: t("cancel"),
    clear: t("clear"),
    startEnd: t("startEnd"),
  }), [t])

  // If already logged in, warm up the dashboard, then redirect.
  React.useEffect(() => {
    if (authLoading) return
    if (!auth?.authenticated) return
    if (loginStartedRef.current) return

    let cancelled = false
    let timer: number | undefined

    void (async () => {
      setLoginProcessing(true)
      await prefetchDashboardData()

      if (cancelled) return
      setLoginProcessing(false)
      if (!exitStartedRef.current) {
        exitStartedRef.current = true
        setIsExiting(true)
        timer = window.setTimeout(() => {
          router.replace(withLocale("/dashboard/"))
        }, 300)
      }
    })()

    return () => {
      cancelled = true
      if (timer) window.clearTimeout(timer)
    }
  }, [auth?.authenticated, authLoading, prefetchDashboardData, router, withLocale])

  React.useEffect(() => {
    if (!loginReady) return
    if (exitStartedRef.current) return
    exitStartedRef.current = true
    setIsExiting(true)
    const timer = window.setTimeout(() => {
      router.replace(withLocale("/dashboard/"))
    }, 300)
    return () => window.clearTimeout(timer)
  }, [loginReady, router, withLocale])

  const handleLogin = async (username: string, password: string) => {
    loginStartedRef.current = true
    setLoginReady(false)
    setLoginProcessing(true)

    // Run independent operations in parallel: login verification + prefetching the dashboard bundle
    const [loginRes] = await Promise.all([
      login({ username, password }),
      router.prefetch(withLocale("/dashboard/")),
    ])

    // Prefetch dashboard data
    await prefetchDashboardData()

    // Prime auth cache so AuthLayout doesn't flash a full-screen loading state.
    queryClient.setQueryData(["auth", "me"], {
      authenticated: true,
      user: loginRes.user,
    })

    setLoginProcessing(false)
    setLoginReady(true)
  }

  return (
    <div className="login-bg flex min-h-svh flex-col p-6 md:p-10">
      {/* Main content area */}
      <div className="flex-1 flex items-center justify-center">
        <div className="w-full max-w-sm md:max-w-4xl">
          <Card className="overflow-hidden p-0">
            <CardContent className="grid p-0 md:grid-cols-2">
              <form className="p-6 md:p-8" onSubmit={handleSubmit}>
                <FieldGroup>
                  {/* Fingerprint identifier - for FOFA/Shodan and other search engines */}
                  <meta name="generator" content="Xingrin ASM Platform" />
                  <div className="flex flex-col items-center gap-2 text-center">
                    <h1 className="text-2xl font-bold">{t("title")}</h1>
                    <p className="text-sm text-muted-foreground mt-1">
                      {t("subtitle")}
                    </p>
                  </div>
                  <Field>
                    <FieldLabel htmlFor="username">{t("username")}</FieldLabel>
                    <Input
                      id="username"
                      type="text"
                      placeholder={t("usernamePlaceholder")}
                      value={username}
                      onChange={(e) => setUsername(e.target.value)}
                      required
                      autoFocus
                    />
                  </Field>
                  <Field>
                    <FieldLabel htmlFor="password">{t("password")}</FieldLabel>
                    <Input
                      id="password"
                      type="password"
                      placeholder={t("passwordPlaceholder")}
                      value={password}
                      onChange={(e) => setPassword(e.target.value)}
                      required
                    />
                  </Field>
                  <Field>
                    <Button type="submit" className="w-full" disabled={isPending}>
                      {isPending ? t("loggingIn") : t("login")}
                    </Button>
                  </Field>
                </FieldGroup>
              </form>
              <div className="bg-primary/5 relative hidden md:flex md:items-center md:justify-center">
                <div className="text-center p-4">
                  <Lottie
                    animationData={securityAnimation}
                    loop={true}
                    className="w-96 h-96 mx-auto"
                  />
                </div>
              </div>
            </CardContent>
          </Card>
    <div className="relative flex min-h-svh flex-col bg-background text-foreground">
      {showLoading && !showExitOverlay ? (
        <LoadingState
          active
          message="loading..."
          className="fixed inset-0 z-50 bg-background"
        />
      ) : null}
      {showExitOverlay ? (
        <div className="fixed inset-0 z-50 bg-background" />
      ) : null}
      {/* Circuit Board Animation */}
      <div className={`fixed inset-0 z-0 transition-opacity duration-300 ${isReady ? "opacity-100" : "opacity-0"}`}>
        <div className="circuit-container">
          {/* Grid pattern */}
          <div className="circuit-grid" />

          {/* === Main backbone traces === */}
          {/* Horizontal main lines - 6 lines */}
          <div className="trace trace-h" style={{ top: '12%', left: 0, width: '100%' }}>
            <div className="trace-glow" style={{ animationDuration: '6s' }} />
          </div>
          <div className="trace trace-h" style={{ top: '28%', left: 0, width: '100%' }}>
            <div className="trace-glow" style={{ animationDelay: '1s', animationDuration: '5s' }} />
          </div>
          <div className="trace trace-h" style={{ top: '44%', left: 0, width: '100%' }}>
            <div className="trace-glow" style={{ animationDelay: '2s', animationDuration: '5.5s' }} />
          </div>
          <div className="trace trace-h" style={{ top: '60%', left: 0, width: '100%' }}>
            <div className="trace-glow" style={{ animationDelay: '3s', animationDuration: '4.5s' }} />
          </div>
          <div className="trace trace-h" style={{ top: '76%', left: 0, width: '100%' }}>
            <div className="trace-glow" style={{ animationDelay: '4s', animationDuration: '5s' }} />
          </div>
          <div className="trace trace-h" style={{ top: '92%', left: 0, width: '100%' }}>
            <div className="trace-glow" style={{ animationDelay: '5s', animationDuration: '6s' }} />
          </div>

          {/* Vertical main lines - 6 lines */}
          <div className="trace trace-v" style={{ left: '8%', top: 0, height: '100%' }}>
            <div className="trace-glow trace-glow-v" style={{ animationDelay: '0.5s', animationDuration: '7s' }} />
          </div>
          <div className="trace trace-v" style={{ left: '24%', top: 0, height: '100%' }}>
            <div className="trace-glow trace-glow-v" style={{ animationDelay: '1.5s', animationDuration: '6s' }} />
          </div>
          <div className="trace trace-v" style={{ left: '40%', top: 0, height: '100%' }}>
            <div className="trace-glow trace-glow-v" style={{ animationDelay: '2.5s', animationDuration: '5.5s' }} />
          </div>
          <div className="trace trace-v" style={{ left: '56%', top: 0, height: '100%' }}>
            <div className="trace-glow trace-glow-v" style={{ animationDelay: '3.5s', animationDuration: '6.5s' }} />
          </div>
          <div className="trace trace-v" style={{ left: '72%', top: 0, height: '100%' }}>
            <div className="trace-glow trace-glow-v" style={{ animationDelay: '4.5s', animationDuration: '5s' }} />
          </div>
          <div className="trace trace-v" style={{ left: '88%', top: 0, height: '100%' }}>
            <div className="trace-glow trace-glow-v" style={{ animationDelay: '5.5s', animationDuration: '6s' }} />
          </div>

        </div>

        <style jsx>{`
          .circuit-container {
            position: absolute;
            inset: 0;
            background: var(--background);
            overflow: hidden;
            --login-grid: color-mix(in oklch, var(--foreground) 6%, transparent);
            --login-trace: color-mix(in oklch, var(--foreground) 16%, transparent);
            --login-glow: color-mix(in oklch, var(--primary) 65%, transparent);
            --login-glow-muted: color-mix(in oklch, var(--foreground) 45%, transparent);
          }

          .circuit-grid {
            position: absolute;
            inset: 0;
            background-image:
              linear-gradient(var(--login-grid) 1px, transparent 1px),
              linear-gradient(90deg, var(--login-grid) 1px, transparent 1px);
            background-size: 40px 40px;
          }

          .trace {
            position: absolute;
            background: var(--login-trace);
            overflow: hidden;
          }

          .trace-h {
            height: 2px;
          }

          .trace-v {
            width: 2px;
          }

          .trace-glow {
            position: absolute;
            top: -2px;
            left: -20%;
            width: 30%;
            height: 6px;
            background: linear-gradient(90deg, transparent, var(--login-glow), var(--login-glow-muted), transparent);
            animation: traceFlow 3s linear infinite;
            filter: blur(2px);
          }

          .trace-glow-v {
            top: -20%;
            left: -2px;
            width: 6px;
            height: 30%;
            background: linear-gradient(180deg, transparent, var(--login-glow), var(--login-glow-muted), transparent);
            animation: traceFlowV 3s linear infinite;
          }

          @keyframes traceFlow {
            0% { left: -30%; }
            100% { left: 100%; }
          }

          @keyframes traceFlowV {
            0% { top: -30%; }
            100% { top: 100%; }
          }
        `}</style>
      </div>

      {/* Version number - fixed at the bottom of the page */}
      <div className="flex-shrink-0 text-center py-4">

      {/* Fingerprint identifier - for FOFA/Shodan and other search engines to identify */}
      <meta name="generator" content="LunaFox ASM Platform" />

      {/* Main content area */}
      <div
        className={`relative z-10 flex-1 flex items-center justify-center p-6 transition-[opacity,transform] duration-300 ${
          isReady ? "opacity-100 translate-y-0" : "opacity-0 translate-y-2"
        }`}
      >
        <TerminalLogin
          onLogin={handleLogin}
          authDone={loginReady}
          isPending={isPending}
          translations={translations}
          className={`transition-[opacity,transform] duration-300 ${
            isExiting ? "opacity-0 scale-[0.98]" : "opacity-100 scale-100"
          }`}
        />
      </div>

      {/* Version number - fixed at the bottom of the page */}
      <div
        className={`relative z-10 flex-shrink-0 text-center py-4 transition-opacity duration-300 ${
          isReady && !isExiting ? "opacity-100" : "opacity-0"
        }`}
      >
        <p className="text-xs text-muted-foreground">
          {process.env.NEXT_PUBLIC_VERSION || 'dev'}
          {process.env.NEXT_PUBLIC_IMAGE_TAG || "dev"}
        </p>
      </div>
    </div>
Some files were not shown because too many files have changed in this diff.