mirror of https://github.com/yyhuni/xingrin.git
synced 2026-01-31 11:46:16 +08:00

Compare commits (30 commits):

e8a5e0cea8 · 3308908d7a · a8402cfffa · dce4e12667 · bd1dd2c0d5 · 0b6560ac17 · 943a4cb960 · eb2d853b76 · 1184c18b74 · 8a6f1b6f24 · 255d505aba · d06a9bab1f · 6d5c776bf7 · bf058dd67b · 0532d7c8b8 · 2ee9b5ffa2 · 648a1888d4 · 2508268a45 · c60383940c · 47298c294a · eba394e14e · 592a1958c4 · 38e2856c08 · f5ad8e68e9 · d5f91a236c · 24ae8b5aeb · 86f43f94a0 · 53ba03d1e5 · 89c44ebd05 · e0e3419edb
.gitignore (vendored, 1 change)

@@ -64,6 +64,7 @@ backend/.env.local
 .coverage
 htmlcov/
 *.cover
+.hypothesis/

 # ============================
 # Backend (Go)
README.md (143 changes)

@@ -1,7 +1,7 @@
 <h1 align="center">XingRin - 星环</h1>

 <p align="center">
-  <b>🛡️ Attack Surface Management (ASM) platform | automated asset discovery and vulnerability scanning</b>
+  <b>Attack Surface Management (ASM) platform | automated asset discovery and vulnerability scanning</b>
 </p>

 <p align="center">

@@ -12,29 +12,28 @@
 </p>

 <p align="center">
-  <a href="#-功能特性">Features</a> •
-  <a href="#-全局资产搜索">Asset Search</a> •
-  <a href="#-快速开始">Quick Start</a> •
-  <a href="#-文档">Docs</a> •
-  <a href="#-反馈与贡献">Feedback & Contributing</a>
+  <a href="#功能特性">Features</a> •
+  <a href="#全局资产搜索">Asset Search</a> •
+  <a href="#快速开始">Quick Start</a> •
+  <a href="#文档">Docs</a> •
+  <a href="#反馈与贡献">Feedback & Contributing</a>
 </p>

 <p align="center">
-  <sub>🔍 Keywords: ASM | attack surface management | vulnerability scanning | asset discovery | asset search | Bug Bounty | penetration testing | Nuclei | subdomain enumeration | EASM</sub>
+  <sub>Keywords: ASM | attack surface management | vulnerability scanning | asset discovery | asset search | Bug Bounty | penetration testing | Nuclei | subdomain enumeration | EASM</sub>
 </p>

 ---

-## 🌐 Live Demo
+## Live Demo

 **[https://xingrin.vercel.app/](https://xingrin.vercel.app/)**

-> ⚠️ UI showcase only; no backend database is attached
+> UI showcase only; no backend database is attached

 ---

 <p align="center">
-  <b>🎨 Modern UI </b>
+  <b>Modern UI</b>
 </p>

 <p align="center">

@@ -44,47 +43,47 @@
   <img src="docs/screenshots/quantum-rose.png" alt="Quantum Rose" width="24%">
 </p>

-## 📚 Docs
+## Docs

-- [📖 Technical docs](./docs/README.md) - documentation index (🚧 work in progress)
-- [🚀 Quick start](./docs/quick-start.md) - one-command install and deployment guide
-- [🔄 Version management](./docs/version-management.md) - Git-tag-driven automated version management
-- [📦 Nuclei template architecture](./docs/nuclei-template-architecture.md) - template repository storage and sync
-- [📖 Wordlist architecture](./docs/wordlist-architecture.md) - wordlist storage and sync
-- [🔍 Scan flow architecture](./docs/scan-flow-architecture.md) - full scan pipeline and tool orchestration
+- [Technical docs](./docs/README.md) - documentation index (work in progress)
+- [Quick start](./docs/quick-start.md) - one-command install and deployment guide
+- [Version management](./docs/version-management.md) - Git-tag-driven automated version management
+- [Nuclei template architecture](./docs/nuclei-template-architecture.md) - template repository storage and sync
+- [Wordlist architecture](./docs/wordlist-architecture.md) - wordlist storage and sync
+- [Scan flow architecture](./docs/scan-flow-architecture.md) - full scan pipeline and tool orchestration

 ---

-## ✨ Features
+## Features

-### 🎯 Target & asset management
-- **Organization management** - multi-level target organizations with flexible grouping
-- **Target management** - domain and IP target types
-- **Asset discovery** - automatic discovery of subdomains, websites, endpoints, and directories
-- **Asset snapshots** - snapshot diffing of scan results to track asset changes
+### Scanning capabilities

-### 🔍 Vulnerability scanning
-- **Multiple engines** - integrates Nuclei and other mainstream scan engines
-- **Custom pipelines** - YAML-configured scan flows with flexible orchestration
-- **Scheduled scans** - cron-expression-driven automated periodic scanning
+| Feature | Status | Tools | Notes |
+|------|------|------|------|
+| Subdomain scanning | Done | Subfinder, Amass, PureDNS | passive collection + active brute force, 50+ aggregated data sources |
+| Port scanning | Done | Naabu | custom port ranges |
+| Site discovery | Done | HTTPX | HTTP probing with automatic title, status code, and tech stack capture |
+| Fingerprinting | Done | XingFinger | 27,000+ fingerprint rules from multiple sources |
+| URL collection | Done | Waymore, Katana | historical data + active crawling |
+| Directory scanning | Done | FFUF | high-speed brute force with smart wordlists |
+| Vulnerability scanning | Done | Nuclei, Dalfox | 9,000+ POC templates, XSS detection |
+| Site screenshots | Done | Playwright | high-compression WebP storage |

-### 🚫 Blacklist filtering
-- **Two blacklist layers** - global blacklist + per-target blacklist for flexible scope control
-- **Smart rule matching** - automatic recognition of domain wildcards (`*.gov`), IPs, and CIDR ranges
+### Platform capabilities

-### 🔖 Fingerprinting
-- **Multi-source fingerprint library** - ships EHole, Goby, Wappalyzer, Fingers, FingerPrintHub, ARL, and more, 27,000+ rules in total
-- **Automatic detection** - runs inside the scan flow to identify web application tech stacks
-- **Fingerprint management** - query, import, and export fingerprint rules
+| Feature | Status | Notes |
+|------|------|------|
+| Target management | Done | multi-level organizations, domain/IP targets |
+| Asset snapshots | Done | scan result diffing to track asset changes |
+| Blacklist filtering | Done | global + per-target, wildcard/CIDR support |
+| Scheduled tasks | Done | cron expressions, automated periodic scanning |
+| Distributed scanning | Done | multiple Worker nodes, load-aware scheduling |
+| Global search | Done | expression syntax, multi-field compound queries |
+| Notifications | Done | WeCom, Telegram, Discord |
+| API key management | Done | visual configuration of each data source's API key |

-### 📸 Site screenshots
-- **Automatic screenshots** - Playwright captures every discovered website
-- **WebP format** - high compression; a 500 KB image stores in a few tens of KB
-- **Multiple sources** - screenshots for URLs from Websites, Endpoints, and other sources
-- **Asset linkage** - screenshots sync to the asset tables for easy review

-#### Scan flow architecture
+### Scan flow architecture

 The complete scan pipeline covers subdomain discovery, port scanning, site discovery, fingerprinting, URL collection, directory scanning, and vulnerability scanning

@@ -136,7 +135,7 @@ flowchart LR

 See the [scan flow architecture docs](./docs/scan-flow-architecture.md) for details

-### 🖥️ Distributed architecture
+### Distributed architecture
 - **Multi-node scanning** - deploy multiple Worker nodes to scale scanning horizontally
 - **Local node** - zero configuration; installation auto-registers a local Docker Worker
 - **Remote nodes** - one-command SSH deployment of remote VPSes as scan nodes

@@ -181,7 +180,7 @@ flowchart TB
     W3 -.heartbeat.-> REDIS
 ```

-### 🔎 Global asset search
+### Global asset search
 - **Multiple asset types** - search across both Website and Endpoint assets
 - **Expression syntax** - `=` (fuzzy), `==` (exact), and `!=` (not-equal) operators
 - **Logical composition** - `&&` (AND) and `||` (OR) combinations

@@ -205,14 +204,14 @@ host="admin" && tech="php" && status=="200"
 url="/api/v1" && status!="404"
 ```
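To make the operator semantics concrete, here is a minimal, hypothetical sketch (not XingRin's actual implementation) of how one `field="value"` clause and an `&&` chain could be matched against an asset record:

```python
import re

OPS = {
    '==': lambda field, value: field == value,   # exact match
    '!=': lambda field, value: field != value,   # not equal
    '=':  lambda field, value: value in field,   # fuzzy (substring) match
}

def match_clause(clause: str, record: dict) -> bool:
    """Evaluate a single clause like host="admin" against a record."""
    m = re.fullmatch(r'(\w+)\s*(==|!=|=)\s*"([^"]*)"', clause.strip())
    if not m:
        raise ValueError(f"bad clause: {clause}")
    field, op, value = m.groups()
    return OPS[op](str(record.get(field, '')), value)

record = {'host': 'admin.example.com', 'tech': 'php', 'status': '200'}
query = 'host="admin" && tech="php" && status=="200"'
print(all(match_clause(c, record) for c in query.split('&&')))  # True
```

This sketch handles `&&` only; `||` grouping would need a real expression parser.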
-### 📊 Visual interface
+### Visual interface
 - **Statistics** - asset/vulnerability dashboards
 - **Real-time notifications** - WebSocket message push
 - **Notification delivery** - real-time WeCom, Telegram, and Discord push

 ---

-## 📦 Quick Start
+## Quick Start

 ### Requirements

@@ -230,11 +229,11 @@ cd xingrin
 # Install and start (production mode)
 sudo ./install.sh

-# 🇨🇳 Users in mainland China: use mirror acceleration (third-party mirrors may go stale; long-term availability is not guaranteed)
+# Users in mainland China: use mirror acceleration (third-party mirrors may go stale; long-term availability is not guaranteed)
 sudo ./install.sh --mirror
 ```

-> **💡 About --mirror**
+> **About --mirror**
 > - Automatically configures Docker registry mirrors (mainland China sources)
 > - Speeds up Git repository cloning (Nuclei templates, etc.)

@@ -259,17 +258,17 @@ sudo ./restart.sh
 sudo ./uninstall.sh
 ```

-## 🤝 Feedback & Contributing
+## Feedback & Contributing

-- 💡 **Found a bug, or have ideas for the UI or feature design?** Submit them via an [Issue](https://github.com/yyhuni/xingrin/issues) or a direct message on the official WeChat account
+- **Found a bug, or have ideas for the UI or feature design?** Submit them via an [Issue](https://github.com/yyhuni/xingrin/issues) or a direct message on the official WeChat account

-## 📧 Contact
+## Contact
 - WeChat official account: **塔罗安全学苑**
 - The WeChat group link is in the account's bottom menu; if it has expired, send a direct message for an invite

 <img src="docs/wechat-qrcode.png" alt="WeChat official account" width="200">

-### 🎁 Follow the account for a free fingerprint library
+### Follow the account for a free fingerprint library

 | Fingerprint library | Rules |
 |--------|------|

@@ -278,9 +277,9 @@ sudo ./uninstall.sh
 | goby.json | 7,086 |
 | FingerprintHub.json | 3,147 |

-> 💡 Reply "指纹" (fingerprint) to the account to get it
+> Reply "指纹" (fingerprint) to the account to get it

-## ☕ Sponsorship
+## Sponsorship

 If this project helps you, consider buying me a cup of Mixue; your stars and sponsorship keep the free updates coming

@@ -289,14 +288,9 @@ sudo ./uninstall.sh
   <img src="docs/zfb_pay.jpg" alt="Alipay" width="200">
 </p>

-### 🙏 Sponsors
-
-| Name | Amount |
-|------|------|
-| X(闭关中) | ¥88 |

-## ⚠️ Disclaimer
+## Disclaimer

 **Important: read carefully before use**

@@ -311,30 +305,29 @@ sudo ./uninstall.sh
 - comply with the laws and regulations of your jurisdiction
 - bear all consequences arising from misuse

-## 🌟 Star History
+## Star History

-If this project helps you, please give it a ⭐ Star!
+If this project helps you, please give it a Star!

 [](https://star-history.com/#yyhuni/xingrin&Date)

-## 📄 License
+## License

 This project is licensed under the [GNU General Public License v3.0](LICENSE).

 ### Permitted uses

-- ✅ Personal learning and research
-- ✅ Commercial and non-commercial use
-- ✅ Modification and distribution
-- ✅ Patent use
-- ✅ Private use
+- Personal learning and research
+- Commercial and non-commercial use
+- Modification and distribution
+- Patent use
+- Private use

 ### Obligations and restrictions

-- 📋 **Source obligation**: source code must be provided when distributing
-- 📋 **Same license**: derivative works must use the same license
-- 📋 **Copyright notice**: the original copyright and license notices must be kept
-- ❌ **No warranty**: provided without any warranty
-- ❌ Unauthorized penetration testing
-- ❌ Any illegal activity
+- **Source obligation**: source code must be provided when distributing
+- **Same license**: derivative works must use the same license
+- **Copyright notice**: the original copyright and license notices must be kept
+- **No warranty**: provided without any warranty
+- Unauthorized penetration testing
+- Any illegal activity
backend/.gitignore (vendored, 1 change)

@@ -7,6 +7,7 @@ __pycache__/
 *.egg-info/
 dist/
 build/
+.hypothesis/  # Hypothesis property-testing cache

 # Virtual environments
 venv/
@@ -14,6 +14,7 @@ from .views import (
     LoginView, LogoutView, MeView, ChangePasswordView,
     SystemLogsView, SystemLogFilesView, HealthCheckView,
     GlobalBlacklistView,
+    VersionView, CheckUpdateView,
 )

 urlpatterns = [

@@ -29,6 +30,8 @@ urlpatterns = [
     # System management
     path('system/logs/', SystemLogsView.as_view(), name='system-logs'),
     path('system/logs/files/', SystemLogFilesView.as_view(), name='system-log-files'),
+    path('system/version/', VersionView.as_view(), name='system-version'),
+    path('system/check-update/', CheckUpdateView.as_view(), name='system-check-update'),

     # Blacklist management (PUT full-replacement mode)
     path('blacklist/rules/', GlobalBlacklistView.as_view(), name='blacklist-rules'),
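For reference, the two new routes can be exercised like this. A sketch only: the `/api` prefix and the response envelope produced by `success_response` are assumptions about how this urls module is mounted, not confirmed by the diff.

```python
import requests

BASE = "http://localhost:8000/api"  # assumed mount point for this urls module

# Current version of the running system
resp = requests.get(f"{BASE}/system/version/", timeout=5)
print(resp.json())  # envelope shape depends on success_response

# Update check against the GitHub releases API
resp = requests.get(f"{BASE}/system/check-update/", timeout=5)
print(resp.json())  # current_version / latest_version / has_update / release_url ...
```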
@@ -6,16 +6,19 @@
 - Auth views: login, logout, user info, password change
 - System log views: live log viewing
 - Blacklist views: global blacklist rule management
+- Version views: system version and update checks
 """

 from .health_views import HealthCheckView
 from .auth_views import LoginView, LogoutView, MeView, ChangePasswordView
 from .system_log_views import SystemLogsView, SystemLogFilesView
 from .blacklist_views import GlobalBlacklistView
+from .version_views import VersionView, CheckUpdateView

 __all__ = [
     'HealthCheckView',
     'LoginView', 'LogoutView', 'MeView', 'ChangePasswordView',
     'SystemLogsView', 'SystemLogFilesView',
     'GlobalBlacklistView',
+    'VersionView', 'CheckUpdateView',
 ]
backend/apps/common/views/version_views.py (new file, 136 lines)

```python
"""
System version views
"""

import logging
from pathlib import Path

import requests
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.views import APIView

from apps.common.error_codes import ErrorCodes
from apps.common.response_helpers import error_response, success_response

logger = logging.getLogger(__name__)

# GitHub repository info
GITHUB_REPO = "yyhuni/xingrin"
GITHUB_API_URL = f"https://api.github.com/repos/{GITHUB_REPO}/releases/latest"
GITHUB_RELEASES_URL = f"https://github.com/{GITHUB_REPO}/releases"


def get_current_version() -> str:
    """Read the current version string"""
    import os

    # Option 1: environment variable (preferred inside Docker containers)
    version = os.environ.get('IMAGE_TAG', '')
    if version:
        return version

    # Option 2: VERSION file (development environments)
    possible_paths = [
        Path('/app/VERSION'),
        Path(__file__).parent.parent.parent.parent.parent / 'VERSION',
    ]

    for path in possible_paths:
        try:
            return path.read_text(encoding='utf-8').strip()
        except (FileNotFoundError, OSError):
            continue

    return "unknown"


def compare_versions(current: str, latest: str) -> bool:
    """
    Compare version strings and decide whether an update is available

    Returns:
        True if an update is available
    """
    def parse_version(v: str) -> tuple:
        v = v.lstrip('v')
        parts = v.split('.')
        result = []
        for part in parts:
            if '-' in part:
                num, _ = part.split('-', 1)
                result.append(int(num))
            else:
                result.append(int(part))
        return tuple(result)

    try:
        return parse_version(latest) > parse_version(current)
    except (ValueError, AttributeError):
        return False


class VersionView(APIView):
    """Return the current system version"""

    def get(self, _request: Request) -> Response:
        """Return current version info"""
        return success_response(data={
            'version': get_current_version(),
            'github_repo': GITHUB_REPO,
        })


class CheckUpdateView(APIView):
    """Check for system updates"""

    def get(self, _request: Request) -> Response:
        """
        Check whether a newer version exists

        Returns:
            - current_version: the running version
            - latest_version: the newest release
            - has_update: whether an update is available
            - release_url: release page URL
            - release_notes: release notes (if any)
        """
        current_version = get_current_version()

        try:
            response = requests.get(
                GITHUB_API_URL,
                headers={'Accept': 'application/vnd.github.v3+json'},
                timeout=10
            )

            if response.status_code == 404:
                return success_response(data={
                    'current_version': current_version,
                    'latest_version': current_version,
                    'has_update': False,
                    'release_url': GITHUB_RELEASES_URL,
                    'release_notes': None,
                })

            response.raise_for_status()
            release_data = response.json()

            latest_version = release_data.get('tag_name', current_version)
            has_update = compare_versions(current_version, latest_version)

            return success_response(data={
                'current_version': current_version,
                'latest_version': latest_version,
                'has_update': has_update,
                'release_url': release_data.get('html_url', GITHUB_RELEASES_URL),
                'release_notes': release_data.get('body'),
                'published_at': release_data.get('published_at'),
            })

        except requests.RequestException as e:
            logger.warning("Update check failed: %s", e)
            return error_response(
                code=ErrorCodes.SERVER_ERROR,
                message="Could not reach GitHub; please try again later",
            )
```
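The version comparison strips a leading `v` and drops any `-suffix` before comparing numeric tuples. A quick check of the behavior, mirroring the parsing logic in `compare_versions` above:

```python
def parse_version(v: str) -> tuple:
    # Same logic as compare_versions above: strip 'v', drop '-suffix', compare numerically
    return tuple(int(part.split('-', 1)[0]) for part in v.lstrip('v').split('.'))

assert parse_version("v1.10.0") > parse_version("v1.9.3")  # numeric, not lexicographic
assert parse_version("v2.0.0-beta") == (2, 0, 0)           # pre-release suffix is ignored
# Consequence: "v2.0.0-beta" compares equal to "v2.0.0", so a beta tag on GitHub
# would not be reported as an update over the final release of the same number.
```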
@@ -2,8 +2,9 @@
 Initialize the default scan engines

 Usage:
-    python manage.py init_default_engine          # only create missing engines (no overwrite)
-    python manage.py init_default_engine --force  # force-overwrite every engine config
+    python manage.py init_default_engine              # only create missing engines (no overwrite)
+    python manage.py init_default_engine --force      # force-overwrite every engine config
+    python manage.py init_default_engine --force-sub  # overwrite only sub-engines, keep full scan

     cd /root/my-vulun-scan/docker
     docker compose exec server python backend/manage.py init_default_engine --force

@@ -12,6 +13,7 @@
 - Reads engine_config_example.yaml as the default configuration
 - Creates full scan (the default engine) plus one sub-engine per scan type
 - Does not overwrite existing configs unless --force is given
+- With --force-sub, only sub-engine configs are overwritten; a user-customized full scan is kept
 """

 from django.core.management.base import BaseCommand

@@ -30,11 +32,18 @@ class Command(BaseCommand):
         parser.add_argument(
             '--force',
             action='store_true',
-            help='Force-overwrite existing engine configs',
+            help='Force-overwrite existing engine configs (full scan and sub-engines)',
         )
+        parser.add_argument(
+            '--force-sub',
+            action='store_true',
+            help='Overwrite only sub-engine configs, keep full scan (for upgrades)',
+        )

     def handle(self, *args, **options):
         force = options.get('force', False)
+        force_sub = options.get('force_sub', False)

         # Read the default configuration file
         config_path = Path(__file__).resolve().parent.parent.parent.parent / 'scan' / 'configs' / 'engine_config_example.yaml'

@@ -99,15 +108,22 @@ class Command(BaseCommand):
         engine_name = f"{scan_type}"
         sub_engine = ScanEngine.objects.filter(name=engine_name).first()
         if sub_engine:
-            if force:
+            # Both --force and --force-sub overwrite sub-engines
+            if force or force_sub:
                 sub_engine.configuration = single_yaml
                 sub_engine.save()
-                self.stdout.write(self.style.SUCCESS(f' ✓ Sub-engine {engine_name} config updated (ID: {sub_engine.id})'))
+                self.stdout.write(self.style.SUCCESS(
+                    f' ✓ Sub-engine {engine_name} config updated (ID: {sub_engine.id})'
+                ))
             else:
-                self.stdout.write(self.style.WARNING(f' ⊘ {engine_name} already exists, skipping (use --force to overwrite)'))
+                self.stdout.write(self.style.WARNING(
+                    f' ⊘ {engine_name} already exists, skipping (use --force to overwrite)'
+                ))
         else:
             sub_engine = ScanEngine.objects.create(
                 name=engine_name,
                 configuration=single_yaml,
             )
-            self.stdout.write(self.style.SUCCESS(f' ✓ Sub-engine {engine_name} created (ID: {sub_engine.id})'))
+            self.stdout.write(self.style.SUCCESS(
+                f' ✓ Sub-engine {engine_name} created (ID: {sub_engine.id})'
+            ))
@@ -312,7 +312,11 @@ class TaskDistributor:
         # - Local workers: install.sh pre-pulls the image, so the local copy is used
         # - Remote workers: the image is pre-pulled at deploy time, so the local copy is used
         # - Avoids hitting Docker Hub on every task, improving performance and stability
+        # OOM priority: --oom-score-adj=1000 makes the Worker the first process killed under memory pressure
+        # - Range is -1000 to 1000; higher values are more likely to be picked by the OOM killer
+        # - Protects core services (server/nginx/frontend) so the web UI stays available
         cmd = f'''docker run --rm -d --pull=missing {network_arg} \\
+            --oom-score-adj=1000 \\
             {' '.join(env_vars)} \\
             {' '.join(volumes)} \\
             {self.docker_image} \\
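A container can confirm the setting at runtime by reading the procfs knob that the Docker flag maps to. A minimal sketch:

```python
from pathlib import Path

# /proc/self/oom_score_adj is the per-process knob that
# `docker run --oom-score-adj=1000` sets for the container's processes.
# Higher values (up to 1000) make the OOM killer prefer this process.
adj = int(Path("/proc/self/oom_score_adj").read_text().strip())
print(f"oom_score_adj = {adj}")  # expected: 1000 inside a worker container
```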
@@ -24,18 +24,6 @@ SUBDOMAIN_DISCOVERY_COMMANDS = {
         }
     },

-    'amass_passive': {
-        # Run passive enumeration first (results land in amass's internal database), then export plain names to output_file
-        # -silent suppresses the progress bar and other output
-        'base': "amass enum -passive -silent -d {domain} && amass subs -names -d {domain} > '{output_file}'"
-    },
-
-    'amass_active': {
-        # Run active enumeration + brute force first (results land in amass's internal database), then export plain names to output_file
-        # -silent suppresses the progress bar and other output
-        'base': "amass enum -active -silent -d {domain} -brute && amass subs -names -d {domain} > '{output_file}'"
-    },

     'sublist3r': {
         'base': "python3 '/usr/local/share/Sublist3r/sublist3r.py' -d {domain} -o '{output_file}'",
         'optional': {
@@ -17,14 +17,6 @@ subdomain_discovery:
     timeout: 3600  # 1 hour
     # threads: 10  # concurrent goroutines

-  amass_passive:
-    enabled: true
-    timeout: 3600
-
-  amass_active:
-    enabled: true  # active enumeration + brute force
-    timeout: 3600

   sublist3r:
     enabled: true
     timeout: 3600

@@ -62,7 +54,7 @@ port_scan:
     threads: 200  # concurrent connections (default 5)
     # ports: 1-65535  # port range to scan (default 1-65535)
     top-ports: 100  # scan the nmap top 100 ports
-    rate: 10  # scan rate (default 10)
+    rate: 50  # scan rate

   naabu_passive:
     enabled: true
@@ -10,30 +10,30 @@
 - Configuration parsed from YAML
 """

 # Django environment initialization (takes effect on import)
 from apps.common.prefect_django_setup import setup_django_for_prefect

-from prefect import flow
-from prefect.task_runners import ThreadPoolTaskRunner
-
 import hashlib
 import logging
 import os
 import subprocess
 from datetime import datetime
 from pathlib import Path
 from typing import List, Tuple

-from apps.scan.tasks.directory_scan import (
-    export_sites_task,
-    run_and_stream_save_directories_task
-)
+from prefect import flow
+
 from apps.scan.handlers.scan_flow_handlers import (
-    on_scan_flow_running,
     on_scan_flow_completed,
     on_scan_flow_failed,
+    on_scan_flow_running,
 )
-from apps.scan.utils import config_parser, build_scan_command, ensure_wordlist_local, user_log
+from apps.scan.tasks.directory_scan import (
+    export_sites_task,
+    run_and_stream_save_directories_task,
+)
+from apps.scan.utils import (
+    build_scan_command,
+    ensure_wordlist_local,
+    user_log,
+    wait_for_system_load,
+)

 logger = logging.getLogger(__name__)

@@ -45,517 +45,343 @@ def calculate_directory_scan_timeout(
     tool_config: dict,
     base_per_word: float = 1.0,
     min_timeout: int = 60,
     max_timeout: int = 7200
 ) -> int:
     """
     Compute the directory-scan timeout from the wordlist line count

     Formula: timeout = wordlist lines × base time per word
-    Range: 60 seconds to 2 hours (7200 seconds)
+    Range: at least 60 seconds, no upper bound

     Args:
         tool_config: tool configuration dict containing the wordlist path
         base_per_word: base time per word in seconds (default 1.0)
         min_timeout: minimum timeout in seconds (default 60)
         max_timeout: maximum timeout in seconds (default 7200, i.e. 2 hours)

     Returns:
-        int: computed timeout in seconds, in the range 60-7200
-
-    Example:
-        # 1000-line wordlist × 1.0 s = 1000 s (within the 7200 s cap)
-        # 10000-line wordlist × 1.0 s = 10000 s, capped at 7200 s
-        timeout = calculate_directory_scan_timeout(
-            tool_config={'wordlist': '/path/to/wordlist.txt'}
-        )
+        int: computed timeout in seconds
     """
     import os

+    wordlist_path = tool_config.get('wordlist')
+    if not wordlist_path:
+        logger.warning("No wordlist in tool config; using default timeout: %d s", min_timeout)
+        return min_timeout
+
+    wordlist_path = os.path.expanduser(wordlist_path)
+
+    if not os.path.exists(wordlist_path):
+        logger.warning("Wordlist file missing: %s; using default timeout: %d s", wordlist_path, min_timeout)
+        return min_timeout
+
     try:
-        # Get the wordlist path from tool_config
-        wordlist_path = tool_config.get('wordlist')
-        if not wordlist_path:
-            logger.warning("No wordlist in tool config; using default timeout: %d s", min_timeout)
-            return min_timeout
-
-        # Expand the user directory (~)
-        wordlist_path = os.path.expanduser(wordlist_path)
-
-        # Check that the file exists
-        if not os.path.exists(wordlist_path):
-            logger.warning("Wordlist file missing: %s; using default timeout: %d s", wordlist_path, min_timeout)
-            return min_timeout
-
         # Count wordlist lines quickly with wc -l
         result = subprocess.run(
             ['wc', '-l', wordlist_path],
             capture_output=True,
             text=True,
             check=True
         )
         # wc -l output format: line count + space + filename
         line_count = int(result.stdout.strip().split()[0])

-        # Compute the timeout
-        timeout = int(line_count * base_per_word)
-
-        # Apply a sensible lower bound (no upper bound anymore)
-        timeout = max(min_timeout, timeout)
+        timeout = max(min_timeout, int(line_count * base_per_word))

         logger.info(
             "Directory scan timeout - wordlist: %s, lines: %d, base: %.3f s/word, timeout: %d s",
             wordlist_path, line_count, base_per_word, timeout
         )

         return timeout

-    except subprocess.CalledProcessError as e:
-        logger.error("Failed to count wordlist lines: %s", e)
-        # Return the default timeout on failure
-        return min_timeout
-    except (ValueError, IndexError) as e:
-        logger.error("Failed to parse wordlist line count: %s", e)
-        return min_timeout
-    except Exception as e:
-        logger.error("Unexpected error computing timeout: %s", e)
+    except (subprocess.CalledProcessError, ValueError, IndexError) as e:
+        logger.error("Failed to compute timeout: %s", e)
         return min_timeout
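A quick sanity check of the formula with hypothetical wordlist sizes, using the defaults above (1.0 s per word, 60 s floor):

```python
for line_count in (30, 1_000, 200_000):
    timeout = max(60, int(line_count * 1.0))
    print(line_count, "->", timeout)
# 30 -> 60            (clamped to the floor)
# 1000 -> 1000
# 200000 -> 200000    (no cap since this change)
```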
 def _get_max_workers(tool_config: dict, default: int = DEFAULT_MAX_WORKERS) -> int:
-    """
-    Get max_workers from a single tool's configuration
-
-    Args:
-        tool_config: a single tool's config dict, e.g. {'max_workers': 10, 'threads': 5, ...}
-        default: default value, 5
-
-    Returns:
-        int: the max_workers value
-    """
+    """Get max_workers from a single tool's configuration"""
     if not isinstance(tool_config, dict):
         return default

     # Accept both max_workers and max-workers (YAML dashes get converted)
     max_workers = tool_config.get('max_workers') or tool_config.get('max-workers')
-    if max_workers is not None and isinstance(max_workers, int) and max_workers > 0:
+    if isinstance(max_workers, int) and max_workers > 0:
         return max_workers
     return default


-def _export_site_urls(target_id: int, target_name: str, directory_scan_dir: Path) -> tuple[str, int]:
+def _export_site_urls(
+    target_id: int,
+    directory_scan_dir: Path
+) -> Tuple[str, int]:
     """
-    Export all site URLs under a target to a file (with lazy loading)
+    Export all site URLs under a target to a file

     Args:
         target_id: target ID
-        target_name: target name (used to lazily create a default site)
         directory_scan_dir: directory-scan working directory

     Returns:
         tuple: (sites_file, site_count)
-
-    Raises:
-        ValueError: if the site count is 0
     """
     logger.info("Step 1: export all site URLs for the target")

     sites_file = str(directory_scan_dir / 'sites.txt')
     export_result = export_sites_task(
         target_id=target_id,
         output_file=sites_file,
-        batch_size=1000  # read 1000 rows at a time to limit memory use
+        batch_size=1000
     )

     site_count = export_result['total_count']

     logger.info(
         "✓ Site URL export finished - file: %s, count: %d",
         export_result['output_file'],
         site_count
     )

-    if site_count == 0:
-        logger.warning("Target has no sites; directory scan cannot run")
-        # Do not raise; let the caller decide how to handle it
-        # raise ValueError("Target has no sites; directory scan cannot run")

     return export_result['output_file'], site_count
-def _run_scans_sequentially(
-    enabled_tools: dict,
-    sites_file: str,
-    directory_scan_dir: Path,
-    scan_id: int,
-    target_id: int,
-    site_count: int,
-    target_name: str
-) -> tuple[int, int, list]:
-    """
-    Run directory scans sequentially (multi-tool) - deprecated, kept for compatibility
-
-    Args:
-        enabled_tools: dict of enabled tool configs
-        sites_file: path to the sites file
-        directory_scan_dir: directory-scan working directory
-        scan_id: scan ID
-        target_id: target ID
-        site_count: number of sites
-        target_name: target name (for error logs)
-
-    Returns:
-        tuple: (total_directories, processed_sites, failed_sites)
-    """
-    # Read the site list
-    sites = []
-    with open(sites_file, 'r', encoding='utf-8') as f:
-        for line in f:
-            site_url = line.strip()
-            if site_url:
-                sites.append(site_url)
-
-    logger.info("Scanning %d sites with tools: %s", len(sites), ', '.join(enabled_tools.keys()))
-
-    total_directories = 0
-    processed_sites_set = set()  # use a set to avoid double counting
-    failed_sites = []
-
-    # Iterate over tools
-    for tool_name, tool_config in enabled_tools.items():
-        logger.info("="*60)
-        logger.info("Using tool: %s", tool_name)
-        logger.info("="*60)
-
-        # If wordlist_name is configured, make sure the wordlist exists locally first (with hash check)
-        wordlist_name = tool_config.get('wordlist_name')
-        if wordlist_name:
-            try:
-                local_wordlist_path = ensure_wordlist_local(wordlist_name)
-                tool_config['wordlist'] = local_wordlist_path
-            except Exception as exc:
-                logger.error("Failed to prepare wordlist for %s: %s", tool_name, exc)
-                # This tool cannot run; mark all sites failed and move on to the next tool
-                failed_sites.extend(sites)
-                continue
-
-        # Scan sites one by one
-        for idx, site_url in enumerate(sites, 1):
-            logger.info(
-                "[%d/%d] Scanning site: %s (tool: %s)",
-                idx, len(sites), site_url, tool_name
-            )
-
-            # Use the unified command builder
-            try:
-                command = build_scan_command(
-                    tool_name=tool_name,
-                    scan_type='directory_scan',
-                    command_params={
-                        'url': site_url
-                    },
-                    tool_config=tool_config
-                )
-            except Exception as e:
-                logger.error(
-                    "✗ [%d/%d] Failed to build %s command: %s - site: %s",
-                    idx, len(sites), tool_name, e, site_url
-                )
-                failed_sites.append(site_url)
-                continue
-
-            # Per-site timeout from config ('auto' means computed dynamically)
-            # ffuf scans sites one at a time, so timeout applies to a single site
-            site_timeout = tool_config.get('timeout', 300)
-            if site_timeout == 'auto':
-                # Compute the timeout dynamically from the wordlist size
-                site_timeout = calculate_directory_scan_timeout(tool_config)
-                logger.info(f"✓ Tool {tool_name} computed timeout: {site_timeout} s")
-
-            # Build the log file path
-            from datetime import datetime
-            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
-            log_file = directory_scan_dir / f"{tool_name}_{timestamp}_{idx}.log"
-
-            try:
-                # Call the task directly (sequential execution)
-                result = run_and_stream_save_directories_task(
-                    cmd=command,
-                    tool_name=tool_name,  # new: tool name
-                    scan_id=scan_id,
-                    target_id=target_id,
-                    site_url=site_url,
-                    cwd=str(directory_scan_dir),
-                    shell=True,
-                    batch_size=1000,
-                    timeout=site_timeout,
-                    log_file=str(log_file)  # new: log file path
-                )
-
-                total_directories += result.get('created_directories', 0)
-                processed_sites_set.add(site_url)  # track successes in a set
-
-                logger.info(
-                    "✓ [%d/%d] Site scan finished: %s - found %d directories",
-                    idx, len(sites), site_url,
-                    result.get('created_directories', 0)
-                )
-
-            except subprocess.TimeoutExpired as exc:
-                # Handle timeouts separately
-                failed_sites.append(site_url)
-                logger.warning(
-                    "⚠️ [%d/%d] Site scan timed out: %s - configured timeout: %d s\n"
-                    "Note: directories parsed before the timeout were saved, but the scan did not finish.",
-                    idx, len(sites), site_url, site_timeout
-                )
-            except Exception as exc:
-                # Other errors
-                failed_sites.append(site_url)
-                logger.error(
-                    "✗ [%d/%d] Site scan failed: %s - error: %s",
-                    idx, len(sites), site_url, exc
-                )
-
-            # Log progress every 10 sites
-            if idx % 10 == 0:
-                logger.info(
-                    "Progress: %d/%d (%.1f%%) - %d directories found",
-                    idx, len(sites), idx/len(sites)*100, total_directories
-                )
-
-    # Count successes and failures
-    processed_count = len(processed_sites_set)
-
-    if failed_sites:
-        logger.warning(
-            "Some site scans failed: %d/%d",
-            len(failed_sites), len(sites)
-        )
-
-    logger.info(
-        "✓ Sequential directory scan finished - success: %d/%d, failed: %d, total directories: %d",
-        processed_count, len(sites), len(failed_sites), total_directories
-    )
-
-    return total_directories, processed_count, failed_sites
-def _generate_log_filename(tool_name: str, site_url: str, directory_scan_dir: Path) -> Path:
-    """
-    Generate a unique log file name
-
-    Uses a hash of the URL so concurrent runs do not collide
-
-    Args:
-        tool_name: tool name
-        site_url: site URL
-        directory_scan_dir: directory-scan working directory
-
-    Returns:
-        Path: log file path
-    """
-    url_hash = hashlib.md5(site_url.encode()).hexdigest()[:8]
+def _generate_log_filename(
+    tool_name: str,
+    site_url: str,
+    directory_scan_dir: Path
+) -> Path:
+    """Generate a unique log file name (the URL hash avoids collisions under concurrency)"""
+    url_hash = hashlib.md5(
+        site_url.encode(),
+        usedforsecurity=False
+    ).hexdigest()[:8]
     timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
     return directory_scan_dir / f"{tool_name}_{url_hash}_{timestamp}.log"


+def _prepare_tool_wordlist(tool_name: str, tool_config: dict) -> bool:
+    """Prepare the tool's wordlist; returns True on success"""
+    wordlist_name = tool_config.get('wordlist_name')
+    if not wordlist_name:
+        return True
+
+    try:
+        local_wordlist_path = ensure_wordlist_local(wordlist_name)
+        tool_config['wordlist'] = local_wordlist_path
+        return True
+    except Exception as exc:
+        logger.error("Failed to prepare wordlist for %s: %s", tool_name, exc)
+        return False
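The `usedforsecurity=False` flag added here (available since Python 3.9) declares the digest non-cryptographic, which keeps MD5 usable on FIPS-restricted builds where plain MD5 may be blocked. For example:

```python
import hashlib

url = "https://example.com"
# Non-cryptographic use: just a stable, short tag for file naming
tag = hashlib.md5(url.encode(), usedforsecurity=False).hexdigest()[:8]
print(tag)  # the same URL always maps to the same 8-character log prefix
```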
+def _build_scan_params(
+    tool_name: str,
+    tool_config: dict,
+    sites: List[str],
+    directory_scan_dir: Path,
+    site_timeout: int
+) -> Tuple[List[dict], List[str]]:
+    """Build scan parameters for every site; returns (scan_params_list, failed_sites)"""
+    scan_params_list = []
+    failed_sites = []
+
+    for idx, site_url in enumerate(sites, 1):
+        try:
+            command = build_scan_command(
+                tool_name=tool_name,
+                scan_type='directory_scan',
+                command_params={'url': site_url},
+                tool_config=tool_config
+            )
+            log_file = _generate_log_filename(tool_name, site_url, directory_scan_dir)
+            scan_params_list.append({
+                'idx': idx,
+                'site_url': site_url,
+                'command': command,
+                'log_file': str(log_file),
+                'timeout': site_timeout
+            })
+        except Exception as e:
+            logger.error(
+                "✗ [%d/%d] Failed to build %s command: %s - site: %s",
+                idx, len(sites), tool_name, e, site_url
+            )
+            failed_sites.append(site_url)
+
+    return scan_params_list, failed_sites
+
+
+def _execute_batch(
+    batch_params: List[dict],
+    tool_name: str,
+    scan_id: int,
+    target_id: int,
+    directory_scan_dir: Path,
+    total_sites: int
+) -> Tuple[int, List[str]]:
+    """Run one batch of scan tasks; returns (directories_found, failed_sites)"""
+    directories_found = 0
+    failed_sites = []
+
+    # Submit tasks
+    futures = []
+    for params in batch_params:
+        future = run_and_stream_save_directories_task.submit(
+            cmd=params['command'],
+            tool_name=tool_name,
+            scan_id=scan_id,
+            target_id=target_id,
+            site_url=params['site_url'],
+            cwd=str(directory_scan_dir),
+            shell=True,
+            batch_size=1000,
+            timeout=params['timeout'],
+            log_file=params['log_file']
+        )
+        futures.append((params['idx'], params['site_url'], future))
+
+    # Wait for results
+    for idx, site_url, future in futures:
+        try:
+            result = future.result()
+            dirs_count = result.get('created_directories', 0)
+            directories_found += dirs_count
+            logger.info(
+                "✓ [%d/%d] Site scan finished: %s - found %d directories",
+                idx, total_sites, site_url, dirs_count
+            )
+        except Exception as exc:
+            failed_sites.append(site_url)
+            if 'timeout' in str(exc).lower():
+                logger.warning(
+                    "⚠️ [%d/%d] Site scan timed out: %s - error: %s",
+                    idx, total_sites, site_url, exc
+                )
+            else:
+                logger.error(
+                    "✗ [%d/%d] Site scan failed: %s - error: %s",
+                    idx, total_sites, site_url, exc
+                )
+
+    return directories_found, failed_sites
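`_execute_batch` submits one bounded slice of tasks and then blocks on every future before the caller moves to the next slice, which caps the number of concurrently running ffuf processes at `max_workers`. The same submit-then-drain shape with the standard library instead of Prefect (a sketch; `scan()` is a stand-in for the real task):

```python
from concurrent.futures import ThreadPoolExecutor

def scan(url: str) -> int:
    """Stand-in for run_and_stream_save_directories_task."""
    return 1  # pretend one directory was found

def run_in_batches(urls: list[str], max_workers: int = 5) -> int:
    found = 0
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        for start in range(0, len(urls), max_workers):
            batch = urls[start:start + max_workers]
            futures = [pool.submit(scan, u) for u in batch]
            for fut in futures:        # drain this batch before starting the next
                found += fut.result()
    return found

print(run_in_batches([f"https://site{i}" for i in range(12)], max_workers=5))  # 12
```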
 def _run_scans_concurrently(
     enabled_tools: dict,
     sites_file: str,
     directory_scan_dir: Path,
     scan_id: int,
     target_id: int,
     site_count: int,
     target_name: str
 ) -> Tuple[int, int, List[str]]:
     """
-    Run directory scans concurrently (using ThreadPoolTaskRunner)
-
-    Args:
-        enabled_tools: dict of enabled tool configs
-        sites_file: path to the sites file
-        directory_scan_dir: directory-scan working directory
-        scan_id: scan ID
-        target_id: target ID
-        site_count: number of sites
-        target_name: target name (for error logs)
+    Run directory scans concurrently

     Returns:
         tuple: (total_directories, processed_sites, failed_sites)
     """
     # Read the site list
-    sites: List[str] = []
     with open(sites_file, 'r', encoding='utf-8') as f:
-        for line in f:
-            site_url = line.strip()
-            if site_url:
-                sites.append(site_url)
+        sites = [line.strip() for line in f if line.strip()]

     if not sites:
         logger.warning("Site list is empty")
         return 0, 0, []

     logger.info(
         "Preparing to scan %d sites concurrently with tools: %s",
         len(sites), ', '.join(enabled_tools.keys())
     )

     total_directories = 0
     processed_sites_count = 0
     failed_sites: List[str] = []

     # Iterate over tools
     for tool_name, tool_config in enabled_tools.items():
         # Each tool gets its own max_workers setting
         max_workers = _get_max_workers(tool_config)

-        logger.info("="*60)
+        logger.info("=" * 60)
         logger.info("Using tool: %s (concurrent mode, max_workers=%d)", tool_name, max_workers)
-        logger.info("="*60)
+        logger.info("=" * 60)
+        user_log(scan_id, "directory_scan", f"Running {tool_name}")

-        # If wordlist_name is configured, make sure the wordlist exists locally first (with hash check)
-        wordlist_name = tool_config.get('wordlist_name')
-        if wordlist_name:
-            try:
-                local_wordlist_path = ensure_wordlist_local(wordlist_name)
-                tool_config['wordlist'] = local_wordlist_path
-            except Exception as exc:
-                logger.error("Failed to prepare wordlist for %s: %s", tool_name, exc)
-                # This tool cannot run; mark all sites failed and move on to the next tool
-                failed_sites.extend(sites)
-                continue
-
-        # Compute the timeout (shared by all sites)
+        # Prepare the wordlist
+        if not _prepare_tool_wordlist(tool_name, tool_config):
+            failed_sites.extend(sites)
+            continue
+
+        # Compute the timeout
         site_timeout = tool_config.get('timeout', 300)
         if site_timeout == 'auto':
             site_timeout = calculate_directory_scan_timeout(tool_config)
-            logger.info(f"✓ Tool {tool_name} computed timeout: {site_timeout} s")
+            logger.info("✓ Tool %s computed timeout: %d s", tool_name, site_timeout)

-        # Build scan params for every site
-        scan_params_list = []
-        for idx, site_url in enumerate(sites, 1):
-            try:
-                command = build_scan_command(
-                    tool_name=tool_name,
-                    scan_type='directory_scan',
-                    command_params={'url': site_url},
-                    tool_config=tool_config
-                )
-                log_file = _generate_log_filename(tool_name, site_url, directory_scan_dir)
-                scan_params_list.append({
-                    'idx': idx,
-                    'site_url': site_url,
-                    'command': command,
-                    'log_file': str(log_file),
-                    'timeout': site_timeout
-                })
-            except Exception as e:
-                logger.error(
-                    "✗ [%d/%d] Failed to build %s command: %s - site: %s",
-                    idx, len(sites), tool_name, e, site_url
-                )
-                failed_sites.append(site_url)
+        # Build the scan parameters
+        scan_params_list, build_failed = _build_scan_params(
+            tool_name, tool_config, sites, directory_scan_dir, site_timeout
+        )
+        failed_sites.extend(build_failed)

         if not scan_params_list:
             logger.warning("No valid scan tasks")
             continue

-        # ============================================================
-        # Batched execution: cap the number of concurrently running ffuf processes
-        # ============================================================
+        # Batched execution
         total_tasks = len(scan_params_list)
         logger.info("Executing %d scan tasks in batches of %d...", total_tasks, max_workers)

         # Progress milestone tracking
         last_progress_percent = 0
         tool_directories = 0
         tool_processed = 0

-        batch_num = 0
-
         for batch_start in range(0, total_tasks, max_workers):
             batch_end = min(batch_start + max_workers, total_tasks)
             batch_params = scan_params_list[batch_start:batch_end]
-            batch_num += 1
-
-            logger.info("Running batch %d (%d-%d/%d)...", batch_num, batch_start + 1, batch_end, total_tasks)
-
-            # Submit this batch (non-blocking; futures return immediately)
-            futures = []
-            for params in batch_params:
-                future = run_and_stream_save_directories_task.submit(
-                    cmd=params['command'],
-                    tool_name=tool_name,
-                    scan_id=scan_id,
-                    target_id=target_id,
-                    site_url=params['site_url'],
-                    cwd=str(directory_scan_dir),
-                    shell=True,
-                    batch_size=1000,
-                    timeout=params['timeout'],
-                    log_file=params['log_file']
-                )
-                futures.append((params['idx'], params['site_url'], future))
-
-            # Wait for every task in this batch (blocking, so the next batch only starts afterwards)
-            for idx, site_url, future in futures:
-                try:
-                    result = future.result()  # block until this task finishes
-                    directories_found = result.get('created_directories', 0)
-                    total_directories += directories_found
-                    tool_directories += directories_found
-                    processed_sites_count += 1
-                    tool_processed += 1
-
-                    logger.info(
-                        "✓ [%d/%d] Site scan finished: %s - found %d directories",
-                        idx, len(sites), site_url, directories_found
-                    )
-
-                except Exception as exc:
-                    failed_sites.append(site_url)
-                    if 'timeout' in str(exc).lower() or isinstance(exc, subprocess.TimeoutExpired):
-                        logger.warning(
-                            "⚠️ [%d/%d] Site scan timed out: %s - error: %s",
-                            idx, len(sites), site_url, exc
-                        )
-                    else:
-                        logger.error(
-                            "✗ [%d/%d] Site scan failed: %s - error: %s",
-                            idx, len(sites), site_url, exc
-                        )
+            batch_num = batch_start // max_workers + 1
+
+            logger.info(
+                "Running batch %d (%d-%d/%d)...",
+                batch_num, batch_start + 1, batch_end, total_tasks
+            )
+
+            dirs_found, batch_failed = _execute_batch(
+                batch_params, tool_name, scan_id, target_id,
+                directory_scan_dir, len(sites)
+            )
+
+            total_directories += dirs_found
+            tool_directories += dirs_found
+            tool_processed += len(batch_params) - len(batch_failed)
+            processed_sites_count += len(batch_params) - len(batch_failed)
+            failed_sites.extend(batch_failed)

             # Progress milestones: log every 20%
             current_progress = int((batch_end / total_tasks) * 100)
             if current_progress >= last_progress_percent + 20:
-                user_log(scan_id, "directory_scan", f"Progress: {batch_end}/{total_tasks} sites scanned")
+                user_log(
+                    scan_id, "directory_scan",
+                    f"Progress: {batch_end}/{total_tasks} sites scanned"
+                )
                 last_progress_percent = (current_progress // 20) * 20

         # Tool-finished logs (developer log + user log)
         logger.info(
             "✓ Tool %s finished - sites processed: %d/%d, directories found: %d",
             tool_name, tool_processed, total_tasks, tool_directories
         )
-        user_log(scan_id, "directory_scan", f"{tool_name} completed: found {tool_directories} directories")
+        user_log(
+            scan_id, "directory_scan",
+            f"{tool_name} completed: found {tool_directories} directories"
+        )

-    # Summary output
     if failed_sites:
-        logger.warning(
-            "Some site scans failed: %d/%d",
-            len(failed_sites), len(sites)
-        )
+        logger.warning("Some site scans failed: %d/%d", len(failed_sites), len(sites))

     logger.info(
         "✓ Concurrent directory scan finished - success: %d/%d, failed: %d, total directories: %d",
         processed_sites_count, len(sites), len(failed_sites), total_directories
     )

     return total_directories, processed_sites_count, failed_sites
 @flow(
     name="directory_scan",
     log_prints=True,
     on_running=[on_scan_flow_running],
     on_completion=[on_scan_flow_completed],

@@ -570,64 +396,31 @@ def directory_scan_flow(
 ) -> dict:
     """
     Directory scan flow
-
-    Main responsibilities:
-    1. Fetch all site URLs for the target
-    2. Run a directory scan against each site URL (ffuf and other tools)
-    3. Stream scan results into the Directory table
-
-    Workflow:
-        Step 0: create the working directory
-        Step 1: export the site URL list to a file (consumed by the scan tools)
-        Step 2: validate the tool configuration
-        Step 3: run the scan tools concurrently and save results live (ThreadPoolTaskRunner)
-
-    ffuf output fields:
-        - url: discovered directory/file URL
-        - length: response body length
-        - status: HTTP status code
-        - words: response word count
-        - lines: response line count
-        - content_type: content type
-        - duration: request duration

     Args:
         scan_id: scan ID
         target_name: target name
         target_id: target ID
         scan_workspace_dir: scan workspace directory
         enabled_tools: dict of enabled tool configs

     Returns:
-        dict: {
-            'success': bool,
-            'scan_id': int,
-            'target': str,
-            'scan_workspace_dir': str,
-            'sites_file': str,
-            'site_count': int,
-            'total_directories': int,   # total directories found
-            'processed_sites': int,     # sites processed successfully
-            'failed_sites_count': int,  # failed sites
-            'executed_tasks': list
-        }
-
-    Raises:
-        ValueError: invalid arguments
-        RuntimeError: execution failure
+        dict: scan results
     """
     try:
+        wait_for_system_load(context="directory_scan_flow")
+
         logger.info(
-            "="*60 + "\n" +
-            "Starting directory scan\n" +
-            f"  Scan ID: {scan_id}\n" +
-            f"  Target: {target_name}\n" +
-            f"  Workspace: {scan_workspace_dir}\n" +
-            "="*60
+            "Starting directory scan - Scan ID: %s, Target: %s, Workspace: %s",
+            scan_id, target_name, scan_workspace_dir
         )

         user_log(scan_id, "directory_scan", "Starting directory scan")

         # Argument validation
         if scan_id is None:
             raise ValueError("scan_id must not be empty")

@@ -639,14 +432,14 @@ def directory_scan_flow(
             raise ValueError("scan_workspace_dir must not be empty")
         if not enabled_tools:
             raise ValueError("enabled_tools must not be empty")

         # Step 0: create the working directory
         from apps.scan.utils import setup_scan_directory
         directory_scan_dir = setup_scan_directory(scan_workspace_dir, 'directory_scan')

-        # Step 1: export site URLs (with lazy loading)
-        sites_file, site_count = _export_site_urls(target_id, target_name, directory_scan_dir)
+        # Step 1: export site URLs
+        sites_file, site_count = _export_site_urls(target_id, directory_scan_dir)

         if site_count == 0:
             logger.warning("Skipping directory scan: no sites to scan - Scan ID: %s", scan_id)
             user_log(scan_id, "directory_scan", "Skipped: no sites to scan", "warning")

@@ -662,16 +455,16 @@ def directory_scan_flow(
             'failed_sites_count': 0,
             'executed_tasks': ['export_sites']
         }

         # Step 2: tool configuration
         logger.info("Step 2: tool configuration")
-        tool_info = []
-        for tool_name, tool_config in enabled_tools.items():
-            mw = _get_max_workers(tool_config)
-            tool_info.append(f"{tool_name}(max_workers={mw})")
+        tool_info = [
+            f"{name}(max_workers={_get_max_workers(cfg)})"
+            for name, cfg in enabled_tools.items()
+        ]
         logger.info("✓ Enabled tools: %s", ', '.join(tool_info))

-        # Step 3: run the scan tools concurrently and save results live
+        # Step 3: run the scans concurrently
         logger.info("Step 3: run the scan tools concurrently and save results live")
         total_directories, processed_sites, failed_sites = _run_scans_concurrently(
             enabled_tools=enabled_tools,

@@ -679,19 +472,20 @@ def directory_scan_flow(
             directory_scan_dir=directory_scan_dir,
             scan_id=scan_id,
             target_id=target_id,
             site_count=site_count,
             target_name=target_name
         )

         # Check whether every site failed
         if processed_sites == 0 and site_count > 0:
-            logger.warning("All site scans failed - total sites: %d, failures: %d", site_count, len(failed_sites))
-            # Do not raise; let the scan continue
-
-        # Log flow completion
+            logger.warning(
+                "All site scans failed - total sites: %d, failures: %d",
+                site_count, len(failed_sites)
+            )

         logger.info("✓ Directory scan finished - directories found: %d", total_directories)
-        user_log(scan_id, "directory_scan", f"directory_scan completed: found {total_directories} directories")
+        user_log(
+            scan_id, "directory_scan",
+            f"directory_scan completed: found {total_directories} directories"
+        )

         return {
             'success': True,
             'scan_id': scan_id,

@@ -704,7 +498,7 @@ def directory_scan_flow(
             'failed_sites_count': len(failed_sites),
             'executed_tasks': ['export_sites', 'run_and_stream_save_directories']
         }

     except Exception as e:
         logger.exception("Directory scan failed: %s", e)
         raise
@@ -10,26 +10,22 @@
 - Stream-process output, batch-update the database
 """

 # Django environment initialization (takes effect on import)
 from apps.common.prefect_django_setup import setup_django_for_prefect

 import logging
 import os
 from datetime import datetime
 from pathlib import Path

 from prefect import flow

 from apps.scan.handlers.scan_flow_handlers import (
-    on_scan_flow_running,
     on_scan_flow_completed,
     on_scan_flow_failed,
+    on_scan_flow_running,
 )
 from apps.scan.tasks.fingerprint_detect import (
     export_urls_for_fingerprint_task,
     run_xingfinger_and_stream_update_tech_task,
 )
-from apps.scan.utils import build_scan_command, user_log
+from apps.scan.utils import build_scan_command, user_log, wait_for_system_load
 from apps.scan.utils.fingerprint_helpers import get_fingerprint_paths

 logger = logging.getLogger(__name__)

@@ -42,22 +38,19 @@ def calculate_fingerprint_detect_timeout(
 ) -> int:
     """
     Compute the timeout from the URL count

-    Formula: timeout = URL count × base time per URL
-    Minimum: 300 seconds
-    No upper bound
+    Formula: timeout = URL count × base time per URL
+    Minimum: 300 seconds, no upper bound

     Args:
         url_count: number of URLs
         base_per_url: base time per URL in seconds (default 10)
         min_timeout: minimum timeout in seconds (default 300)

     Returns:
         int: computed timeout in seconds
     """
-    timeout = int(url_count * base_per_url)
-    return max(min_timeout, timeout)
+    return max(min_timeout, int(url_count * base_per_url))
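With the defaults (10 s per URL, 300 s floor), the clamp only matters for small targets:

```python
# max(min_timeout, url_count * base_per_url) with the defaults above:
assert max(300, 20 * 10) == 300     # 20 URLs: the floor wins
assert max(300, 500 * 10) == 5000   # 500 URLs: 5000 s, unbounded above
```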
@@ -70,17 +63,17 @@ def _export_urls(
 ) -> tuple[str, int]:
     """
     Export URLs to a file

     Args:
         target_id: target ID
         fingerprint_dir: fingerprinting working directory
         source: data source type

     Returns:
         tuple: (urls_file, total_count)
     """
     logger.info("Step 1: export the URL list (source=%s)", source)

     urls_file = str(fingerprint_dir / 'urls.txt')
     export_result = export_urls_for_fingerprint_task(
         target_id=target_id,

@@ -88,15 +81,14 @@ def _export_urls(
         source=source,
         batch_size=1000
     )

     total_count = export_result['total_count']

     logger.info(
         "✓ URL export finished - file: %s, count: %d",
         export_result['output_file'],
         total_count
     )

     return export_result['output_file'], total_count

@@ -111,7 +103,7 @@ def _run_fingerprint_detect(
 ) -> tuple[dict, list]:
     """
     Run the fingerprint detection tasks

     Args:
         enabled_tools: dict of enabled tool configs
         urls_file: URL file path

@@ -120,56 +112,54 @@ def _run_fingerprint_detect(
         scan_id: scan ID
         target_id: target ID
         source: data source type

     Returns:
         tuple: (tool_stats, failed_tools)
     """
     tool_stats = {}
     failed_tools = []

     for tool_name, tool_config in enabled_tools.items():
         # 1. Resolve the fingerprint library paths
         lib_names = tool_config.get('fingerprint_libs', ['ehole'])
         fingerprint_paths = get_fingerprint_paths(lib_names)

         if not fingerprint_paths:
             reason = f"No usable fingerprint libraries: {lib_names}"
             logger.warning(reason)
             failed_tools.append({'tool': tool_name, 'reason': reason})
             continue

         # 2. Merge the library paths into tool_config (for command building)
         tool_config_with_paths = {**tool_config, **fingerprint_paths}

         # 3. Build the command
         try:
             command = build_scan_command(
                 tool_name=tool_name,
                 scan_type='fingerprint_detect',
-                command_params={
-                    'urls_file': urls_file
-                },
+                command_params={'urls_file': urls_file},
                 tool_config=tool_config_with_paths
             )
         except Exception as e:
-            reason = f"Command build failed: {str(e)}"
+            reason = f"Command build failed: {e}"
             logger.error("Failed to build %s command: %s", tool_name, e)
             failed_tools.append({'tool': tool_name, 'reason': reason})
             continue

         # 4. Compute the timeout
         timeout = calculate_fingerprint_detect_timeout(url_count)

         # 5. Build the log file path
         timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
         log_file = fingerprint_dir / f"{tool_name}_{timestamp}.log"

         logger.info(
             "Starting %s fingerprinting - URLs: %d, timeout: %d s, libraries: %s",
             tool_name, url_count, timeout, list(fingerprint_paths.keys())
         )
         user_log(scan_id, "fingerprint_detect", f"Running {tool_name}: {command}")

         # 6. Run the scan task
         try:
             result = run_xingfinger_and_stream_update_tech_task(

@@ -183,14 +173,14 @@ def _run_fingerprint_detect(
                 log_file=str(log_file),
                 batch_size=100
             )

             tool_stats[tool_name] = {
                 'command': command,
                 'result': result,
                 'timeout': timeout,
                 'fingerprint_libs': list(fingerprint_paths.keys())
             }

             tool_updated = result.get('updated_count', 0)
             logger.info(
                 "✓ Tool %s finished - records processed: %d, updated: %d, not found: %d",

@@ -199,20 +189,23 @@ def _run_fingerprint_detect(
                 tool_updated,
                 result.get('not_found_count', 0)
             )
-            user_log(scan_id, "fingerprint_detect", f"{tool_name} completed: identified {tool_updated} fingerprints")
+            user_log(
+                scan_id, "fingerprint_detect",
+                f"{tool_name} completed: identified {tool_updated} fingerprints"
+            )

         except Exception as exc:
             reason = str(exc)
             failed_tools.append({'tool': tool_name, 'reason': reason})
             logger.error("Tool %s failed: %s", tool_name, exc, exc_info=True)
             user_log(scan_id, "fingerprint_detect", f"{tool_name} failed: {reason}", "error")

     if failed_tools:
         logger.warning(
             "The following fingerprinting tools failed: %s",
             ', '.join([f['tool'] for f in failed_tools])
         )

     return tool_stats, failed_tools
@@ -232,53 +225,38 @@ def fingerprint_detect_flow(
 ) -> dict:
     """
     Fingerprint detection flow
-
-    Main responsibilities:
-    1. Export all WebSite URLs for the target from the database
-    2. Identify tech stacks with xingfinger
-    3. Parse the results and update WebSite.tech (merged and de-duplicated)
-
-    Workflow:
-        Step 0: create the working directory
-        Step 1: export the URL list
-        Step 2: parse the configuration and collect enabled tools
-        Step 3: run xingfinger and parse the results

     Args:
         scan_id: scan ID
         target_name: target name
         target_id: target ID
         scan_workspace_dir: scan workspace directory
         enabled_tools: enabled tool configuration (xingfinger)

     Returns:
-        dict: {
-            'success': bool,
-            'scan_id': int,
-            'target': str,
-            'scan_workspace_dir': str,
-            'urls_file': str,
-            'url_count': int,
-            'processed_records': int,
-            'updated_count': int,
-            'created_count': int,
-            'snapshot_count': int,
-            'executed_tasks': list,
-            'tool_stats': dict
-        }
+        dict: scan results
     """
     try:
+        # Load check: wait until system resources are sufficient
+        wait_for_system_load(context="fingerprint_detect_flow")
+
         logger.info(
-            "="*60 + "\n" +
-            "Starting fingerprint detection\n" +
-            f"  Scan ID: {scan_id}\n" +
-            f"  Target: {target_name}\n" +
-            f"  Workspace: {scan_workspace_dir}\n" +
-            "="*60
+            "Starting fingerprint detection - Scan ID: %s, Target: %s, Workspace: %s",
+            scan_id, target_name, scan_workspace_dir
         )

         user_log(scan_id, "fingerprint_detect", "Starting fingerprint detection")

         # Argument validation
         if scan_id is None:
             raise ValueError("scan_id must not be empty")

@@ -288,46 +266,26 @@ def fingerprint_detect_flow(
             raise ValueError("target_id must not be empty")
         if not scan_workspace_dir:
             raise ValueError("scan_workspace_dir must not be empty")

         # Data source type (only website is supported for now)
         source = 'website'

         # Step 0: create the working directory
         from apps.scan.utils import setup_scan_directory
         fingerprint_dir = setup_scan_directory(scan_workspace_dir, 'fingerprint_detect')

         # Step 1: export URLs (with lazy loading)
         urls_file, url_count = _export_urls(target_id, fingerprint_dir, source)

         if url_count == 0:
             logger.warning("Skipping fingerprinting: no URLs to scan - Scan ID: %s", scan_id)
             user_log(scan_id, "fingerprint_detect", "Skipped: no URLs to scan", "warning")
-            return {
-                'success': True,
-                'scan_id': scan_id,
-                'target': target_name,
-                'scan_workspace_dir': scan_workspace_dir,
-                'urls_file': urls_file,
-                'url_count': 0,
-                'processed_records': 0,
-                'updated_count': 0,
-                'created_count': 0,
-                'snapshot_count': 0,
-                'executed_tasks': ['export_urls_for_fingerprint'],
-                'tool_stats': {
-                    'total': 0,
-                    'successful': 0,
-                    'failed': 0,
-                    'successful_tools': [],
-                    'failed_tools': [],
-                    'details': {}
-                }
-            }
+            return _build_empty_result(scan_id, target_name, scan_workspace_dir, urls_file)

         # Step 2: tool configuration
         logger.info("Step 2: tool configuration")
         logger.info("✓ Enabled tools: %s", ', '.join(enabled_tools.keys()))

         # Step 3: run fingerprint detection
         logger.info("Step 3: run fingerprint detection")
         tool_stats, failed_tools = _run_fingerprint_detect(

@@ -339,24 +297,37 @@ def fingerprint_detect_flow(
             target_id=target_id,
             source=source
         )

         # Build the executed-task list dynamically
         executed_tasks = ['export_urls_for_fingerprint']
-        executed_tasks.extend([f'run_xingfinger ({tool})' for tool in tool_stats.keys()])
+        executed_tasks.extend([f'run_xingfinger ({tool})' for tool in tool_stats])

         # Aggregate results across all tools
-        total_processed = sum(stats['result'].get('processed_records', 0) for stats in tool_stats.values())
-        total_updated = sum(stats['result'].get('updated_count', 0) for stats in tool_stats.values())
-        total_created = sum(stats['result'].get('created_count', 0) for stats in tool_stats.values())
-        total_snapshots = sum(stats['result'].get('snapshot_count', 0) for stats in tool_stats.values())
+        total_processed = sum(
+            stats['result'].get('processed_records', 0) for stats in tool_stats.values()
+        )
+        total_updated = sum(
+            stats['result'].get('updated_count', 0) for stats in tool_stats.values()
+        )
+        total_created = sum(
+            stats['result'].get('created_count', 0) for stats in tool_stats.values()
+        )
+        total_snapshots = sum(
+            stats['result'].get('snapshot_count', 0) for stats in tool_stats.values()
+        )

         # Log flow completion
         logger.info("✓ Fingerprinting finished - fingerprints identified: %d", total_updated)
-        user_log(scan_id, "fingerprint_detect", f"fingerprint_detect completed: identified {total_updated} fingerprints")
-
-        successful_tools = [name for name in enabled_tools.keys()
-                            if name not in [f['tool'] for f in failed_tools]]
+        user_log(
+            scan_id, "fingerprint_detect",
+            f"fingerprint_detect completed: identified {total_updated} fingerprints"
+        )
+
+        successful_tools = [
+            name for name in enabled_tools
+            if name not in [f['tool'] for f in failed_tools]
+        ]

         return {
             'success': True,
             'scan_id': scan_id,

@@ -378,7 +349,7 @@ def fingerprint_detect_flow(
             'details': tool_stats
         }
         }

     except ValueError as e:
         logger.error("Configuration error: %s", e)
         raise

@@ -388,3 +359,33 @@ def fingerprint_detect_flow(
     except Exception as e:
         logger.exception("Fingerprinting failed: %s", e)
         raise
+
+
+def _build_empty_result(
+    scan_id: int,
+    target_name: str,
+    scan_workspace_dir: str,
+    urls_file: str
+) -> dict:
+    """Build an empty result (used when there are no URLs to scan)"""
+    return {
+        'success': True,
+        'scan_id': scan_id,
+        'target': target_name,
+        'scan_workspace_dir': scan_workspace_dir,
+        'urls_file': urls_file,
+        'url_count': 0,
+        'processed_records': 0,
+        'updated_count': 0,
+        'created_count': 0,
+        'snapshot_count': 0,
+        'executed_tasks': ['export_urls_for_fingerprint'],
+        'tool_stats': {
+            'total': 0,
+            'successful': 0,
+            'failed': 0,
+            'successful_tools': [],
+            'failed_tools': [],
+            'details': {}
+        }
+    }
@@ -1,4 +1,4 @@
"""
"""
Port scan Flow

Orchestrates the complete port scan workflow
@@ -10,25 +10,23 @@
- Configuration is parsed from YAML
"""

# Django environment initialization (takes effect on import)
from apps.common.prefect_django_setup import setup_django_for_prefect

import logging
import os
import subprocess
from datetime import datetime
from pathlib import Path
from typing import Callable

from prefect import flow
from apps.scan.tasks.port_scan import (
    export_hosts_task,
    run_and_stream_save_ports_task
)

from apps.scan.handlers.scan_flow_handlers import (
    on_scan_flow_running,
    on_scan_flow_completed,
    on_scan_flow_failed,
    on_scan_flow_running,
)
from apps.scan.utils import config_parser, build_scan_command, user_log
from apps.scan.tasks.port_scan import (
    export_hosts_task,
    run_and_stream_save_ports_task,
)
from apps.scan.utils import build_scan_command, user_log, wait_for_system_load

logger = logging.getLogger(__name__)

@@ -40,28 +38,19 @@ def calculate_port_scan_timeout(
) -> int:
    """
    Calculate the timeout from the number of targets and ports

    Formula: timeout = target count × port count × base_per_pair
    Timeout range: 60 seconds ~ 2 days (172800 seconds)
    Timeout range: 60 seconds ~ no upper bound

    Args:
        tool_config: tool configuration dict containing the port settings (ports, top-ports, etc.)
        file_path: target file path (list of domains/IPs)
        base_per_pair: base time in seconds per "port-target pair", default 0.5s

    Returns:
        int: computed timeout in seconds, range: 60 ~ 172800

    Example:
        # 100 targets × 100 ports × 0.5s = 5000s
        # 10 targets × 1000 ports × 0.5s = 5000s
        timeout = calculate_port_scan_timeout(
            tool_config={'top-ports': 100},
            file_path='/path/to/domains.txt'
        )
        int: computed timeout in seconds, 60 seconds at minimum
    """
    try:
        # 1. Count the targets
        result = subprocess.run(
            ['wc', '-l', file_path],
            capture_output=True,
@@ -69,88 +58,74 @@ def calculate_port_scan_timeout(
            check=True
        )
        target_count = int(result.stdout.strip().split()[0])

        # 2. Parse the port count
        port_count = _parse_port_count(tool_config)

        # 3. Compute the timeout
        # total work = target count × port count
        total_work = target_count * port_count
        timeout = int(total_work * base_per_pair)

        # 4. Apply a sensible lower bound (no upper bound anymore)
        min_timeout = 60  # at least 60 seconds
        timeout = max(min_timeout, timeout)

        timeout = max(60, int(total_work * base_per_pair))

        logger.info(
            f"Computed port scan timeout - "
            f"targets: {target_count}, "
            f"ports: {port_count}, "
            f"total work: {total_work}, "
            f"timeout: {timeout}s"
            "Computed port scan timeout - targets: %d, ports: %d, total work: %d, timeout: %ds",
            target_count, port_count, total_work, timeout
        )
        return timeout

    except Exception as e:
        logger.warning(f"Failed to compute timeout: {e}, falling back to the default of 600s")
        logger.warning("Failed to compute timeout: %s, falling back to the default of 600s", e)
        return 600
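
# Illustrative sketch (not the production implementation; assumes the default
# base_per_pair of 0.5s): the 'auto' timeout above reduces to
# max(60, int(targets * ports * base_per_pair)).
def _timeout_formula_example(target_count: int, port_count: int, base_per_pair: float = 0.5) -> int:
    """Hypothetical helper mirroring the core formula of calculate_port_scan_timeout."""
    return max(60, int(target_count * port_count * base_per_pair))

assert _timeout_formula_example(500, 100) == 25000  # 500 hosts x top-100 ports
assert _timeout_formula_example(3, 1) == 60         # small jobs hit the 60s floor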


def _parse_port_count(tool_config: dict) -> int:
    """
    Parse the port count from the tool configuration

    Priority:
    1. top-ports: N → returns N
    2. ports: "80,443,8080" → returns the number of comma-separated entries
    3. ports: "1-1000" → returns the size of the range
    4. ports: "1-65535" → returns 65535
    5. default → returns 100 (naabu scans the top 100 ports by default)

    Args:
        tool_config: tool configuration dict

    Returns:
        int: port count
    """
    # 1. Check the top-ports setting
    # Check the top-ports setting
    if 'top-ports' in tool_config:
        top_ports = tool_config['top-ports']
        if isinstance(top_ports, int) and top_ports > 0:
            return top_ports
        logger.warning(f"Invalid top-ports setting: {top_ports}, using the default")

    # 2. Check the ports setting
        logger.warning("Invalid top-ports setting: %s, using the default", top_ports)

    # Check the ports setting
    if 'ports' in tool_config:
        ports_str = str(tool_config['ports']).strip()

        # 2.1 Comma-separated port list: 80,443,8080

        # Comma-separated port list: 80,443,8080
        if ',' in ports_str:
            port_list = [p.strip() for p in ports_str.split(',') if p.strip()]
            return len(port_list)

        # 2.2 Port range: 1-1000
            return len([p.strip() for p in ports_str.split(',') if p.strip()])

        # Port range: 1-1000
        if '-' in ports_str:
            try:
                start, end = ports_str.split('-', 1)
                start_port = int(start.strip())
                end_port = int(end.strip())

                if 1 <= start_port <= end_port <= 65535:
                    return end_port - start_port + 1
                logger.warning(f"Invalid port range: {ports_str}, using the default")
                logger.warning("Invalid port range: %s, using the default", ports_str)
            except ValueError:
                logger.warning(f"Failed to parse port range: {ports_str}, using the default")

        # 2.3 Single port
                logger.warning("Failed to parse port range: %s, using the default", ports_str)

        # Single port
        try:
            port = int(ports_str)
            if 1 <= port <= 65535:
                return 1
        except ValueError:
            logger.warning(f"Failed to parse port setting: {ports_str}, using the default")

    # 3. Default: naabu scans the top 100 ports by default
            logger.warning("Failed to parse port setting: %s, using the default", ports_str)

    # Default: naabu scans the top 100 ports by default
    return 100
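
# Worked example of the parsing priority above (hypothetical configs; each
# assertion mirrors one branch of _parse_port_count):
assert _parse_port_count({'top-ports': 1000}) == 1000      # top-ports wins
assert _parse_port_count({'ports': '80,443,8080'}) == 3    # comma-separated list
assert _parse_port_count({'ports': '1-1000'}) == 1000      # inclusive range
assert _parse_port_count({'ports': '22'}) == 1             # single port
assert _parse_port_count({}) == 100                        # naabu's default top 100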


@@ -160,41 +135,38 @@ def _parse_port_count(tool_config: dict) -> int:
def _export_hosts(target_id: int, port_scan_dir: Path) -> tuple[str, int, str]:
    """
    Export the host list to a file

    Decides what to export based on the Target type:
    - DOMAIN: export subdomains from the Subdomain table
    - IP: write target.name directly
    - CIDR: expand every IP in the CIDR range

    Args:
        target_id: target ID
        port_scan_dir: port scan directory

    Returns:
        tuple: (hosts_file, host_count, target_type)
    """
    logger.info("Step 1: exporting the host list")

    hosts_file = str(port_scan_dir / 'hosts.txt')
    export_result = export_hosts_task(
        target_id=target_id,
        output_file=hosts_file,
        batch_size=1000  # read 1000 records per batch to limit memory usage
    )

    host_count = export_result['total_count']
    target_type = export_result.get('target_type', 'unknown')

    logger.info(
        "✓ Host list exported - type: %s, file: %s, count: %d",
        target_type,
        export_result['output_file'],
        host_count
        target_type, export_result['output_file'], host_count
    )

    if host_count == 0:
        logger.warning("The target has no hosts to scan; the port scan cannot run")

    return export_result['output_file'], host_count, target_type
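
# Illustrative sketch of the CIDR branch described in the docstring above (the
# real expansion happens inside export_hosts_task; this uses the stdlib
# ipaddress module purely as a worked example):
import ipaddress

def _expand_cidr_example(cidr: str) -> list[str]:
    """Expand a CIDR range into its individual host IPs."""
    return [str(ip) for ip in ipaddress.ip_network(cidr, strict=False).hosts()]

assert _expand_cidr_example('10.0.0.0/30') == ['10.0.0.1', '10.0.0.2']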


@@ -208,7 +180,7 @@ def _run_scans_sequentially(
) -> tuple[dict, int, list, list]:
    """
    Run the port scan tasks sequentially

    Args:
        enabled_tools: dict of enabled tool configurations
        domains_file: path to the domains file
@@ -216,72 +188,56 @@
        scan_id: scan task ID
        target_id: target ID
        target_name: target name (for error logs)

    Returns:
        tuple: (tool_stats, processed_records, successful_tool_names, failed_tools)
        Note: the port scan streams its output and produces no result file

    Raises:
        RuntimeError: all tools failed
    """
    # ==================== Build commands and run sequentially ====================

    tool_stats = {}
    processed_records = 0
    failed_tools = []  # failed tools (with reasons)

    # Run each enabled port scan tool sequentially, in order
    failed_tools = []

    for tool_name, tool_config in enabled_tools.items():
        # 1. Build the full command (variable substitution)
        # Build the command
        try:
            command = build_scan_command(
                tool_name=tool_name,
                scan_type='port_scan',
                command_params={
                    'domains_file': domains_file  # maps to {domains_file}
                },
                tool_config=tool_config  # the tool config from YAML
                command_params={'domains_file': domains_file},
                tool_config=tool_config
            )
        except Exception as e:
            reason = f"Command build failed: {str(e)}"
            logger.error(f"Failed to build the {tool_name} command: {e}")
            reason = f"Command build failed: {e}"
            logger.error("Failed to build the %s command: %s", tool_name, e)
            failed_tools.append({'tool': tool_name, 'reason': reason})
            continue

        # 2. Resolve the timeout (supports dynamic 'auto' calculation)

        # Resolve the timeout
        config_timeout = tool_config['timeout']
        if config_timeout == 'auto':
            # Compute the timeout dynamically
            config_timeout = calculate_port_scan_timeout(
                tool_config=tool_config,
                file_path=str(domains_file)
            )
            logger.info(f"✓ Tool {tool_name} dynamically computed timeout: {config_timeout}s")

        # 2.1 Build the log file path
        from datetime import datetime
            config_timeout = calculate_port_scan_timeout(tool_config, str(domains_file))
            logger.info("✓ Tool %s dynamically computed timeout: %ds", tool_name, config_timeout)

        # Build the log file path
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_file = port_scan_dir / f"{tool_name}_{timestamp}.log"

        # 3. Run the scan task

        logger.info("Running the %s scan (timeout: %ds)...", tool_name, config_timeout)
        user_log(scan_id, "port_scan", f"Running {tool_name}: {command}")

        # Run the scan task
        try:
            # Call the task directly (sequential execution)
            # Note: the port scan streams to stdout and does not use output_file
            result = run_and_stream_save_ports_task(
                cmd=command,
                tool_name=tool_name,  # tool name
                tool_name=tool_name,
                scan_id=scan_id,
                target_id=target_id,
                cwd=str(port_scan_dir),
                shell=True,
                batch_size=1000,
                timeout=config_timeout,
                log_file=str(log_file)  # new: log file path
                log_file=str(log_file)
            )

            tool_stats[tool_name] = {
                'command': command,
                'result': result,
@@ -289,15 +245,10 @@ def _run_scans_sequentially(
            }
            tool_records = result.get('processed_records', 0)
            processed_records += tool_records
            logger.info(
                "✓ Tool %s finished streaming - records: %d",
                tool_name, tool_records
            )
            logger.info("✓ Tool %s finished streaming - records: %d", tool_name, tool_records)
            user_log(scan_id, "port_scan", f"{tool_name} completed: found {tool_records} ports")

        except subprocess.TimeoutExpired as exc:
            # Handle timeouts separately
            # Note: when a streaming task times out, the data parsed so far is already in the database

        except subprocess.TimeoutExpired:
            reason = f"timeout after {config_timeout}s"
            failed_tools.append({'tool': tool_name, 'reason': reason})
            logger.warning(
@@ -307,40 +258,39 @@
            )
            user_log(scan_id, "port_scan", f"{tool_name} failed: {reason}", "error")
        except Exception as exc:
            # Other exceptions
            reason = str(exc)
            failed_tools.append({'tool': tool_name, 'reason': reason})
            logger.error("Tool %s failed: %s", tool_name, exc, exc_info=True)
            user_log(scan_id, "port_scan", f"{tool_name} failed: {reason}", "error")

    if failed_tools:
        logger.warning(
            "The following scan tools failed: %s",
            ', '.join([f['tool'] for f in failed_tools])
        )

    if not tool_stats:
        error_details = "; ".join([f"{f['tool']}: {f['reason']}" for f in failed_tools])
        logger.warning("All port scan tools failed - target: %s, failed tools: %s", target_name, error_details)
        # Return an empty result instead of raising, so the scan can continue
        return {}, 0, [], failed_tools

    # Derive the list of successful tools
    successful_tool_names = [name for name in enabled_tools.keys()
                             if name not in [f['tool'] for f in failed_tools]]

    successful_tool_names = [
        name for name in enabled_tools
        if name not in [f['tool'] for f in failed_tools]
    ]

    logger.info(
        "✓ Sequential port scan finished - success: %d/%d (succeeded: %s, failed: %s)",
        len(tool_stats), len(enabled_tools),
        ', '.join(successful_tool_names) if successful_tool_names else 'none',
        ', '.join([f['tool'] for f in failed_tools]) if failed_tools else 'none'
    )

    return tool_stats, processed_records, successful_tool_names, failed_tools

@flow(
    name="port_scan",
    name="port_scan",
    log_prints=True,
    on_running=[on_scan_flow_running],
    on_completion=[on_scan_flow_completed],
@@ -355,19 +305,19 @@ def port_scan_flow(
) -> dict:
    """
    Port scan Flow

    Main responsibilities:
    1. Scan the open ports of the target domains/IPs
    2. Save the host + ip + port triple mapping to the HostPortMapping table

    Output assets:
    - HostPortMapping: host-port mapping (host + ip + port triples)

    Workflow:
        Step 0: create the working directory
        Step 1: export the domain list to a file (for the scan tools)
        Step 2: parse the configuration and get the enabled tools
        Step 3: run the scan tools sequentially, parsing output into the database in real time (→ HostPortMapping)
        Step 3: run the scan tools sequentially, parsing output into the database in real time

    Args:
        scan_id: scan task ID
@@ -377,35 +327,15 @@ def port_scan_flow(
        enabled_tools: dict of enabled tool configurations

    Returns:
        dict: {
            'success': bool,
            'scan_id': int,
            'target': str,
            'scan_workspace_dir': str,
            'hosts_file': str,
            'host_count': int,
            'processed_records': int,
            'executed_tasks': list,
            'tool_stats': {
                'total': int,  # total number of tools
                'successful': int,  # number of successful tools
                'failed': int,  # number of failed tools
                'successful_tools': list[str],  # successful tools, e.g. ['naabu_active']
                'failed_tools': list[dict],  # failed tools, e.g. [{'tool': 'naabu_passive', 'reason': 'timeout'}]
                'details': dict  # detailed results (kept for backward compatibility)
            }
        }
        dict: scan result

    Raises:
        ValueError: configuration error
        RuntimeError: execution failure

    Note:
        Port scan tools (such as naabu) resolve domains to IPs and emit host + ip + port triples.
        One host may map to several IPs (CDN, load balancing), hence the triple mapping table.
    """
    try:
        # Parameter validation
        wait_for_system_load(context="port_scan_flow")

        if scan_id is None:
            raise ValueError("scan_id must not be empty")
        if not target_name:
@@ -416,25 +346,20 @@
            raise ValueError("scan_workspace_dir must not be empty")
        if not enabled_tools:
            raise ValueError("enabled_tools must not be empty")

        logger.info(
            "="*60 + "\n" +
            "Starting port scan\n" +
            f" Scan ID: {scan_id}\n" +
            f" Target: {target_name}\n" +
            f" Workspace: {scan_workspace_dir}\n" +
            "="*60
            "Starting port scan - Scan ID: %s, Target: %s, Workspace: %s",
            scan_id, target_name, scan_workspace_dir
        )

        user_log(scan_id, "port_scan", "Starting port scan")

        # Step 0: create the working directory
        from apps.scan.utils import setup_scan_directory
        port_scan_dir = setup_scan_directory(scan_workspace_dir, 'port_scan')

        # Step 1: export the host list to a file (content depends on the Target type)

        # Step 1: export the host list
        hosts_file, host_count, target_type = _export_hosts(target_id, port_scan_dir)

        if host_count == 0:
            logger.warning("Skipping port scan: no hosts to scan - Scan ID: %s", scan_id)
            user_log(scan_id, "port_scan", "Skipped: no hosts to scan", "warning")
@@ -457,14 +382,11 @@ def port_scan_flow(
                    'details': {}
                }
            }

        # Step 2: tool configuration info
        logger.info("Step 2: tool configuration info")
        logger.info(
            "✓ Enabled tools: %s",
            ', '.join(enabled_tools.keys())
        )

        logger.info("✓ Enabled tools: %s", ', '.join(enabled_tools.keys()))

        # Step 3: run the scan tools sequentially
        logger.info("Step 3: running the scan tools sequentially")
        tool_stats, processed_records, successful_tool_names, failed_tools = _run_scans_sequentially(
@@ -475,15 +397,13 @@ def port_scan_flow(
            target_id=target_id,
            target_name=target_name
        )

        # Record Flow completion

        logger.info("✓ Port scan completed - ports found: %d", processed_records)
        user_log(scan_id, "port_scan", f"port_scan completed: found {processed_records} ports")

        # Dynamically build the list of executed tasks

        executed_tasks = ['export_hosts', 'parse_config']
        executed_tasks.extend([f'run_and_stream_save_ports ({tool})' for tool in tool_stats.keys()])

        executed_tasks.extend([f'run_and_stream_save_ports ({tool})' for tool in tool_stats])

        return {
            'success': True,
            'scan_id': scan_id,

@@ -5,43 +5,39 @@
1. Fetch the URL list from the database (websites and/or endpoints)
2. Capture screenshots in batches and save snapshots
3. Sync to the asset tables

Two modes are supported:
1. Legacy mode (backward compatible): fetch URLs from the database via target_id
2. Provider mode: fetch URLs from any data source via a TargetProvider
"""

# Django environment initialization
from apps.common.prefect_django_setup import setup_django_for_prefect

import logging
from pathlib import Path
from typing import Optional

from prefect import flow

from apps.scan.tasks.screenshot import capture_screenshots_task
from apps.scan.handlers.scan_flow_handlers import (
    on_scan_flow_running,
    on_scan_flow_completed,
    on_scan_flow_failed,
    on_scan_flow_running,
)
from apps.scan.utils import user_log
from apps.scan.services.target_export_service import (
    get_urls_with_fallback,
    DataSource,
)
from apps.scan.providers import TargetProvider
from apps.scan.services.target_export_service import DataSource, get_urls_with_fallback
from apps.scan.tasks.screenshot import capture_screenshots_task
from apps.scan.utils import user_log, wait_for_system_load

logger = logging.getLogger(__name__)

# Mapping from URL source names to DataSource constants
_SOURCE_MAPPING = {
    'websites': DataSource.WEBSITE,
    'endpoints': DataSource.ENDPOINT,
}


def _parse_screenshot_config(enabled_tools: dict) -> dict:
    """
    Parse the screenshot configuration

    Args:
        enabled_tools: enabled tool configuration

    Returns:
        screenshot configuration dict
    """
    # Read the playwright configuration from enabled_tools
    """Parse the screenshot configuration"""
    playwright_config = enabled_tools.get('playwright', {})

    return {
        'concurrency': playwright_config.get('concurrency', 5),
        'url_sources': playwright_config.get('url_sources', ['websites'])
@@ -49,33 +45,54 @@ def _parse_screenshot_config(enabled_tools: dict) -> dict:


def _map_url_sources_to_data_sources(url_sources: list[str]) -> list[str]:
    """
    Map the url_sources from the configuration to DataSource constants

    Args:
        url_sources: source list from the configuration, e.g. ['websites', 'endpoints']

    Returns:
        list of DataSource constants
    """
    source_mapping = {
        'websites': DataSource.WEBSITE,
        'endpoints': DataSource.ENDPOINT,
    }

    """Map the url_sources from the configuration to DataSource constants"""
    sources = []
    for source in url_sources:
        if source in source_mapping:
            sources.append(source_mapping[source])
        if source in _SOURCE_MAPPING:
            sources.append(_SOURCE_MAPPING[source])
        else:
            logger.warning("Unknown URL source: %s, skipping", source)

    # Append the default fallback (constructed from subdomains)
    sources.append(DataSource.DEFAULT)

    return sources
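
# Worked example of the mapping above (hypothetical inputs; unknown sources are
# logged and skipped, and DataSource.DEFAULT is always appended as the
# subdomain-based fallback):
assert _map_url_sources_to_data_sources(['websites']) == [DataSource.WEBSITE, DataSource.DEFAULT]
assert _map_url_sources_to_data_sources(['unknown']) == [DataSource.DEFAULT]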


def _collect_urls_from_provider(provider: TargetProvider) -> tuple[list[str], str, list[str]]:
    """Collect URLs from a Provider"""
    logger.info("Fetching URLs in Provider mode - Provider: %s", type(provider).__name__)
    urls = list(provider.iter_urls())

    blacklist_filter = provider.get_blacklist_filter()
    if blacklist_filter:
        urls = [url for url in urls if blacklist_filter.is_allowed(url)]

    return urls, 'provider', ['provider']
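
# Minimal sketch of a TargetProvider as consumed above: it only needs iter_urls()
# and get_blacklist_filter() (returning None disables filtering). An assumed
# in-memory implementation for illustration, not a shipped class:
class _StaticListProvider:
    def __init__(self, urls: list[str]):
        self._urls = urls

    def iter_urls(self):
        yield from self._urls

    def get_blacklist_filter(self):
        return None  # no blacklist filtering in this sketch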


def _collect_urls_from_database(
    target_id: int,
    url_sources: list[str]
) -> tuple[list[str], str, list[str]]:
    """Collect URLs from the database (with blacklist filtering and fallback)"""
    data_sources = _map_url_sources_to_data_sources(url_sources)
    result = get_urls_with_fallback(target_id, sources=data_sources)
    return result['urls'], result['source'], result['tried_sources']


def _build_empty_result(scan_id: int, target_name: str) -> dict:
    """Build an empty result"""
    return {
        'success': True,
        'scan_id': scan_id,
        'target': target_name,
        'total_urls': 0,
        'successful': 0,
        'failed': 0,
        'synced': 0
    }


@flow(
    name="screenshot",
    log_prints=True,
@@ -88,115 +105,104 @@ def screenshot_flow(
    target_name: str,
    target_id: int,
    scan_workspace_dir: str,
    enabled_tools: dict
    enabled_tools: dict,
    provider: Optional[TargetProvider] = None
) -> dict:
    """
    Screenshot Flow

    Workflow:
        Step 1: parse the configuration
        Step 2: collect the URL list
        Step 3: capture screenshots in batches and save snapshots
        Step 4: sync to the asset tables

    Two modes are supported:
    1. Legacy mode (backward compatible): fetch URLs from the database via target_id
    2. Provider mode: fetch URLs from any data source via a TargetProvider

    Args:
        scan_id: scan task ID
        target_name: target name
        target_id: target ID
        scan_workspace_dir: scan workspace directory
        enabled_tools: enabled tool configuration

        provider: TargetProvider instance (new mode, optional)

    Returns:
        dict: {
            'success': bool,
            'scan_id': int,
            'target': str,
            'total_urls': int,
            'successful': int,
            'failed': int,
            'synced': int
        }
        screenshot result dict
    """
    try:
        # Load check: wait until system resources are sufficient
        wait_for_system_load(context="screenshot_flow")

        mode = 'Provider' if provider else 'Legacy'
        logger.info(
            "="*60 + "\n" +
            "Starting screenshot scan\n" +
            f" Scan ID: {scan_id}\n" +
            f" Target: {target_name}\n" +
            f" Workspace: {scan_workspace_dir}\n" +
            "="*60
            "Starting screenshot scan - Scan ID: %s, Target: %s, Mode: %s",
            scan_id, target_name, mode
        )

        user_log(scan_id, "screenshot", "Starting screenshot capture")

        # Step 1: parse the configuration
        config = _parse_screenshot_config(enabled_tools)
        concurrency = config['concurrency']
        url_sources = config['url_sources']

        logger.info("Screenshot config - concurrency: %d, URL sources: %s", concurrency, url_sources)

        # Step 2: collect URLs via the unified service (with blacklist filtering and fallback)
        data_sources = _map_url_sources_to_data_sources(url_sources)
        result = get_urls_with_fallback(target_id, sources=data_sources)

        urls = result['urls']
        logger.info("Screenshot config - concurrency: %d, URL sources: %s", concurrency, config['url_sources'])

        # Step 2: collect the URL list
        if provider is not None:
            urls, source_info, tried_sources = _collect_urls_from_provider(provider)
        else:
            urls, source_info, tried_sources = _collect_urls_from_database(
                target_id, config['url_sources']
            )

        logger.info(
            "URL collection finished - source: %s, count: %d, tried: %s",
            result['source'], result['total_count'], result['tried_sources']
            source_info, len(urls), tried_sources
        )

        if not urls:
            logger.warning("No URLs to capture, skipping the screenshot task")
            user_log(scan_id, "screenshot", "Skipped: no URLs to capture", "warning")
            return {
                'success': True,
                'scan_id': scan_id,
                'target': target_name,
                'total_urls': 0,
                'successful': 0,
                'failed': 0,
                'synced': 0
            }

        user_log(scan_id, "screenshot", f"Found {len(urls)} URLs to capture (source: {result['source']})")

            return _build_empty_result(scan_id, target_name)

        user_log(
            scan_id, "screenshot",
            f"Found {len(urls)} URLs to capture (source: {source_info})"
        )

        # Step 3: capture screenshots in batches
        logger.info("Step 3: batch screenshots - %d URLs", len(urls))

        logger.info("Batch screenshots - %d URLs", len(urls))
        capture_result = capture_screenshots_task(
            urls=urls,
            scan_id=scan_id,
            target_id=target_id,
            config={'concurrency': concurrency}
        )

        # Step 4: sync to the asset tables
        logger.info("Step 4: syncing screenshots to the asset tables")
        logger.info("Syncing screenshots to the asset tables")
        from apps.asset.services.screenshot_service import ScreenshotService
        screenshot_service = ScreenshotService()
        synced = screenshot_service.sync_screenshots_to_asset(scan_id, target_id)

        synced = ScreenshotService().sync_screenshots_to_asset(scan_id, target_id)

        total = capture_result['total']
        successful = capture_result['successful']
        failed = capture_result['failed']

        logger.info(
            "✓ Screenshots completed - total: %d, successful: %d, failed: %d, synced: %d",
            capture_result['total'], capture_result['successful'], capture_result['failed'], synced
            total, successful, failed, synced
        )
        user_log(
            scan_id, "screenshot",
            f"Screenshot completed: {capture_result['successful']}/{capture_result['total']} captured, {synced} synced"
            f"Screenshot completed: {successful}/{total} captured, {synced} synced"
        )

        return {
            'success': True,
            'scan_id': scan_id,
            'target': target_name,
            'total_urls': capture_result['total'],
            'successful': capture_result['successful'],
            'failed': capture_result['failed'],
            'total_urls': total,
            'successful': successful,
            'failed': failed,
            'synced': synced
        }

    except Exception as e:
        logger.exception("Screenshot Flow failed: %s", e)
        user_log(scan_id, "screenshot", f"Screenshot failed: {e}", "error")

    except Exception:
        logger.exception("Screenshot Flow failed")
        user_log(scan_id, "screenshot", "Screenshot failed", "error")
        raise

@@ -1,4 +1,3 @@

"""
Site scan Flow

@@ -11,303 +10,319 @@
- Configuration is parsed from YAML
"""

# Django environment initialization (takes effect on import)
from apps.common.prefect_django_setup import setup_django_for_prefect

import logging
import os
import subprocess
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Callable
from typing import Optional

from prefect import flow
from apps.scan.tasks.site_scan import export_site_urls_task, run_and_stream_save_websites_task

# Django environment initialization (takes effect on import)
from apps.common.prefect_django_setup import setup_django_for_prefect  # noqa: F401
from apps.scan.handlers.scan_flow_handlers import (
    on_scan_flow_running,
    on_scan_flow_completed,
    on_scan_flow_failed,
    on_scan_flow_running,
)
from apps.scan.utils import config_parser, build_scan_command, user_log
from apps.scan.tasks.site_scan import (
    export_site_urls_task,
    run_and_stream_save_websites_task,
)
from apps.scan.utils import build_scan_command, user_log, wait_for_system_load

logger = logging.getLogger(__name__)


def calculate_timeout_by_line_count(
    tool_config: dict,
    file_path: str,
    base_per_time: int = 1,
    min_timeout: int = 60
) -> int:
    """
    Compute the timeout from the line count of a file

    Counts the lines with wc -l and derives the timeout from the line count
    and the base time per line

    Args:
        tool_config: tool configuration dict (unused here, kept for interface consistency)
        file_path: file whose lines are counted
        base_per_time: base time in seconds per line, default 1s
        min_timeout: minimum timeout in seconds, default 60s

    Returns:
        int: computed timeout in seconds, never below min_timeout

    Example:
        timeout = calculate_timeout_by_line_count(
            tool_config={},
            file_path='/path/to/urls.txt',
            base_per_time=2
        )
    """
@dataclass
class ScanContext:
    """Scan context bundling the scan parameters"""
    scan_id: int
    target_id: int
    target_name: str
    site_scan_dir: Path
    urls_file: str
    total_urls: int


def _count_file_lines(file_path: str) -> int:
    """Count the lines of a file with wc -l"""
    try:
        # Use wc -l for a fast line count
        result = subprocess.run(
            ['wc', '-l', file_path],
            capture_output=True,
            text=True,
            check=True
        )
        # wc -l output format: line count + space + file name
        line_count = int(result.stdout.strip().split()[0])

        # Compute the timeout: line count × base time per line, never below the minimum
        timeout = max(line_count * base_per_time, min_timeout)

        logger.info(
            f"Auto-computed timeout: file={file_path}, "
            f"lines={line_count}, per-line={base_per_time}s, minimum={min_timeout}s, timeout={timeout}s"
        )

        return timeout

    except Exception as e:
        # Fall back to the default when wc -l fails
        logger.warning(f"wc -l line count failed: {e}, using the default timeout: {min_timeout}s")
        return min_timeout
        return int(result.stdout.strip().split()[0])
    except (subprocess.CalledProcessError, ValueError, IndexError) as e:
        logger.warning("wc -l line count failed: %s, returning 0", e)
        return 0


def _calculate_timeout_by_line_count(
    file_path: str,
    base_per_time: int = 1,
    min_timeout: int = 60
) -> int:
    """
    Compute the timeout from the line count of a file

    Args:
        file_path: file whose lines are counted
        base_per_time: base time in seconds per line, default 1s
        min_timeout: minimum timeout in seconds, default 60s

    Returns:
        int: computed timeout in seconds, never below min_timeout
    """
    line_count = _count_file_lines(file_path)
    timeout = max(line_count * base_per_time, min_timeout)

    logger.info(
        "Auto-computed timeout: file=%s, lines=%d, per-line=%ds, timeout=%ds",
        file_path, line_count, base_per_time, timeout
    )
    return timeout


def _export_site_urls(target_id: int, site_scan_dir: Path, target_name: str = None) -> tuple[str, int, int]:
def _export_site_urls(
    target_id: int,
    site_scan_dir: Path
) -> tuple[str, int, int]:
    """
    Export site URLs to a file

    Args:
        target_id: target ID
        site_scan_dir: site scan directory
        target_name: target name (used to write a default value when lazy-loading)

    Returns:
        tuple: (urls_file, total_urls, association_count)

    Raises:
        ValueError: the URL count is 0
    """
    logger.info("Step 1: exporting the site URL list")

    urls_file = str(site_scan_dir / 'site_urls.txt')
    export_result = export_site_urls_task(
        target_id=target_id,
        output_file=urls_file,
        batch_size=1000  # process 1000 subdomains per batch
        batch_size=1000
    )

    total_urls = export_result['total_urls']
    association_count = export_result['association_count']  # host-port association count

    association_count = export_result['association_count']

    logger.info(
        "✓ Site URLs exported - file: %s, URL count: %d, associations: %d",
        export_result['output_file'],
        total_urls,
        association_count
        export_result['output_file'], total_urls, association_count
    )

    if total_urls == 0:
        logger.warning("The target has no usable site URLs; the site scan cannot run")
        # Do not raise; let the caller decide how to handle it
        # raise ValueError("The target has no usable site URLs; the site scan cannot run")

    return export_result['output_file'], total_urls, association_count


def _get_tool_timeout(tool_config: dict, urls_file: str) -> int:
    """Resolve the tool timeout (supports dynamic 'auto' calculation)"""
    config_timeout = tool_config.get('timeout', 300)

    if config_timeout == 'auto':
        return _calculate_timeout_by_line_count(urls_file, base_per_time=1)

    dynamic_timeout = _calculate_timeout_by_line_count(urls_file, base_per_time=1)
    return max(dynamic_timeout, config_timeout)
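
# Worked example of the rule above (hypothetical file sizes): with a numeric
# timeout the effective value is the larger of the configured and the dynamic
# one, so a 10000-line URL file lifts a configured 300s to max(10000, 300) = 10000s,
# while a 50-line file keeps max(max(50, 60), 300) = 300s.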


def _execute_single_tool(
    tool_name: str,
    tool_config: dict,
    ctx: ScanContext
) -> Optional[dict]:
    """
    Run a single scan tool

    Returns:
        the result dict on success, None on failure
    """
    # Build the command
    try:
        command = build_scan_command(
            tool_name=tool_name,
            scan_type='site_scan',
            command_params={'url_file': ctx.urls_file},
            tool_config=tool_config
        )
    except (ValueError, KeyError) as e:
        logger.error("Failed to build the %s command: %s", tool_name, e)
        return None

    timeout = _get_tool_timeout(tool_config, ctx.urls_file)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = ctx.site_scan_dir / f"{tool_name}_{timestamp}.log"

    logger.info(
        "Running the %s site scan - URLs: %d, timeout: %ds",
        tool_name, ctx.total_urls, timeout
    )
    user_log(ctx.scan_id, "site_scan", f"Running {tool_name}: {command}")

    try:
        result = run_and_stream_save_websites_task(
            cmd=command,
            tool_name=tool_name,
            scan_id=ctx.scan_id,
            target_id=ctx.target_id,
            cwd=str(ctx.site_scan_dir),
            shell=True,
            timeout=timeout,
            log_file=str(log_file)
        )

        tool_created = result.get('created_websites', 0)
        skipped = result.get('skipped_no_subdomain', 0) + result.get('skipped_failed', 0)

        logger.info(
            "✓ Tool %s finished - processed: %d, created: %d, skipped: %d",
            tool_name, result.get('processed_records', 0), tool_created, skipped
        )
        user_log(
            ctx.scan_id, "site_scan",
            f"{tool_name} completed: found {tool_created} websites"
        )

        return {'command': command, 'result': result, 'timeout': timeout}

    except subprocess.TimeoutExpired:
        logger.warning(
            "⚠️ Tool %s timed out - configured timeout: %ds (data parsed before the timeout is already saved)",
            tool_name, timeout
        )
        user_log(
            ctx.scan_id, "site_scan",
            f"{tool_name} failed: timeout after {timeout}s", "error"
        )
    except (OSError, RuntimeError) as exc:
        logger.error("Tool %s failed: %s", tool_name, exc, exc_info=True)
        user_log(ctx.scan_id, "site_scan", f"{tool_name} failed: {exc}", "error")

    return None


def _run_scans_sequentially(
    enabled_tools: dict,
    urls_file: str,
    total_urls: int,
    site_scan_dir: Path,
    scan_id: int,
    target_id: int,
    target_name: str
    ctx: ScanContext
) -> tuple[dict, int, list, list]:
    """
    Run the site scan tasks sequentially

    Args:
        enabled_tools: dict of enabled tool configurations
        urls_file: URL file path
        total_urls: total number of URLs
        site_scan_dir: site scan directory
        scan_id: scan task ID
        target_id: target ID
        target_name: target name (for error logs)

    Returns:
        tuple: (tool_stats, processed_records, successful_tool_names, failed_tools)

    Raises:
        RuntimeError: all tools failed
        tuple: (tool_stats, processed_records, successful_tools, failed_tools)
    """
    tool_stats = {}
    processed_records = 0
    failed_tools = []

    for tool_name, tool_config in enabled_tools.items():
        # 1. Build the full command (variable substitution)
        try:
            command_params = {'url_file': urls_file}

            command = build_scan_command(
                tool_name=tool_name,
                scan_type='site_scan',
                command_params=command_params,
                tool_config=tool_config
            )
        except Exception as e:
            reason = f"Command build failed: {str(e)}"
            logger.error(f"Failed to build the {tool_name} command: {e}")
            failed_tools.append({'tool': tool_name, 'reason': reason})
            continue

        # 2. Resolve the timeout (supports dynamic 'auto' calculation)
        config_timeout = tool_config.get('timeout', 300)
        if config_timeout == 'auto':
            # Compute the timeout dynamically
            timeout = calculate_timeout_by_line_count(tool_config, urls_file, base_per_time=1)
            logger.info(f"✓ Tool {tool_name} dynamically computed timeout: {timeout}s")
        result = _execute_single_tool(tool_name, tool_config, ctx)

        if result:
            tool_stats[tool_name] = result
            processed_records += result['result'].get('processed_records', 0)
        else:
            # Use the larger of the configured and dynamically computed timeouts
            dynamic_timeout = calculate_timeout_by_line_count(tool_config, urls_file, base_per_time=1)
            timeout = max(dynamic_timeout, config_timeout)

        # 2.1 Build the log file path (like the port scan)
        from datetime import datetime
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_file = site_scan_dir / f"{tool_name}_{timestamp}.log"

        logger.info(
            "Running the %s site scan - URLs: %d, effective timeout: %ds",
            tool_name, total_urls, timeout
        )
        user_log(scan_id, "site_scan", f"Running {tool_name}: {command}")

        # 3. Run the scan task
        try:
            # Stream the scan and save results in real time
            result = run_and_stream_save_websites_task(
                cmd=command,
                tool_name=tool_name,
                scan_id=scan_id,
                target_id=target_id,
                cwd=str(site_scan_dir),
                shell=True,
                timeout=timeout,
                log_file=str(log_file)
            )

            tool_stats[tool_name] = {
                'command': command,
                'result': result,
                'timeout': timeout
            }
            tool_records = result.get('processed_records', 0)
            tool_created = result.get('created_websites', 0)
            processed_records += tool_records

            logger.info(
                "✓ Tool %s finished streaming - processed: %d, websites created: %d, skipped: %d",
                tool_name,
                tool_records,
                tool_created,
                result.get('skipped_no_subdomain', 0) + result.get('skipped_failed', 0)
            )
            user_log(scan_id, "site_scan", f"{tool_name} completed: found {tool_created} websites")

        except subprocess.TimeoutExpired as exc:
            # Handle timeouts separately
            reason = f"timeout after {timeout}s"
            failed_tools.append({'tool': tool_name, 'reason': reason})
            logger.warning(
                "⚠️ Tool %s timed out - configured timeout: %ds\n"
                "Note: site data parsed before the timeout is already in the database, but the scan did not finish.",
                tool_name, timeout
            )
            user_log(scan_id, "site_scan", f"{tool_name} failed: {reason}", "error")
        except Exception as exc:
            # Other exceptions
            reason = str(exc)
            failed_tools.append({'tool': tool_name, 'reason': reason})
            logger.error("Tool %s failed: %s", tool_name, exc, exc_info=True)
            user_log(scan_id, "site_scan", f"{tool_name} failed: {reason}", "error")

            failed_tools.append({'tool': tool_name, 'reason': 'execution failed'})

    if failed_tools:
        logger.warning(
            "The following scan tools failed: %s",
            ', '.join([f['tool'] for f in failed_tools])
            ', '.join(f['tool'] for f in failed_tools)
        )

    if not tool_stats:
        error_details = "; ".join([f"{f['tool']}: {f['reason']}" for f in failed_tools])
        logger.warning("All site scan tools failed - target: %s, failed tools: %s", target_name, error_details)
        # Return an empty result instead of raising, so the scan can continue
        logger.warning(
            "All site scan tools failed - target: %s", ctx.target_name
        )
        return {}, 0, [], failed_tools

    # Derive the list of successful tools
    successful_tool_names = [name for name in enabled_tools.keys()
                             if name not in [f['tool'] for f in failed_tools]]

    successful_tools = [
        name for name in enabled_tools
        if name not in {f['tool'] for f in failed_tools}
    ]

    logger.info(
        "✓ Sequential site scan finished - success: %d/%d (succeeded: %s, failed: %s)",
        len(tool_stats), len(enabled_tools),
        ', '.join(successful_tool_names) if successful_tool_names else 'none',
        ', '.join([f['tool'] for f in failed_tools]) if failed_tools else 'none'
        "✓ Site scan finished - success: %d/%d",
        len(tool_stats), len(enabled_tools)
    )

    return tool_stats, processed_records, successful_tool_names, failed_tools

    return tool_stats, processed_records, successful_tools, failed_tools


def calculate_timeout(url_count: int, base: int = 600, per_url: int = 1) -> int:
    """
    Compute the scan timeout dynamically from the URL count
def _build_empty_result(
    scan_id: int,
    target_name: str,
    scan_workspace_dir: str,
    urls_file: str,
    association_count: int
) -> dict:
    """Build an empty result (used when there are no URLs to scan)"""
    return {
        'success': True,
        'scan_id': scan_id,
        'target': target_name,
        'scan_workspace_dir': scan_workspace_dir,
        'urls_file': urls_file,
        'total_urls': 0,
        'association_count': association_count,
        'processed_records': 0,
        'created_websites': 0,
        'skipped_no_subdomain': 0,
        'skipped_failed': 0,
        'executed_tasks': ['export_site_urls'],
        'tool_stats': {
            'total': 0,
            'successful': 0,
            'failed': 0,
            'successful_tools': [],
            'failed_tools': [],
            'details': {}
        }
    }

    Rules:
    - base time: 600 seconds (10 minutes) by default
    - extra time per URL: 1 second by default

    Args:
        url_count: URL count, must be a positive integer
        base: base timeout in seconds, default 600
        per_url: extra seconds per URL, default 1
def _aggregate_tool_results(tool_stats: dict) -> tuple[int, int, int]:
    """Aggregate the tool results"""
    total_created = sum(
        s['result'].get('created_websites', 0) for s in tool_stats.values()
    )
    total_skipped_no_subdomain = sum(
        s['result'].get('skipped_no_subdomain', 0) for s in tool_stats.values()
    )
    total_skipped_failed = sum(
        s['result'].get('skipped_failed', 0) for s in tool_stats.values()
    )
    return total_created, total_skipped_no_subdomain, total_skipped_failed

    Returns:
        int: computed timeout in seconds, capped at max_timeout

    Raises:
        ValueError: raised when url_count is negative or 0
    """
    if url_count < 0:
        raise ValueError(f"The URL count must not be negative: {url_count}")
    if url_count == 0:
        raise ValueError("The URL count must not be 0")

    timeout = base + int(url_count * per_url)

    # No upper bound; the caller controls it as needed
    return timeout
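
# Worked example of the formula above (defaults base=600, per_url=1):
assert calculate_timeout(2400) == 600 + 2400  # 3000 seconds, i.e. 50 minutes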
def _validate_flow_params(
    scan_id: int,
    target_name: str,
    target_id: int,
    scan_workspace_dir: str
) -> None:
    """Validate the Flow parameters"""
    if scan_id is None:
        raise ValueError("scan_id must not be empty")
    if not target_name:
        raise ValueError("target_name must not be empty")
    if target_id is None:
        raise ValueError("target_id must not be empty")
    if not scan_workspace_dir:
        raise ValueError("scan_workspace_dir must not be empty")


@flow(
    name="site_scan",
    name="site_scan",
    log_prints=True,
    on_running=[on_scan_flow_running],
    on_completion=[on_scan_flow_completed],
@@ -322,140 +337,83 @@ def site_scan_flow(
) -> dict:
    """
    Site scan Flow

    Main responsibilities:
    1. Fetch all subdomains and their ports from the target, assemble URLs, and write them to a file
    2. Batch-request them with httpx and save to the database in real time (streaming)

    Workflow:
        Step 0: create the working directory
        Step 1: export the site URL list
        Step 2: parse the configuration and get the enabled tools
        Step 3: run the scan tools sequentially and save results in real time

    Args:
        scan_id: scan task ID
        target_name: target name
        target_id: target ID
        scan_workspace_dir: scan workspace directory
        enabled_tools: dict of enabled tool configurations

    Returns:
        dict: {
            'success': bool,
            'scan_id': int,
            'target': str,
            'scan_workspace_dir': str,
            'urls_file': str,
            'total_urls': int,
            'association_count': int,
            'processed_records': int,
            'created_websites': int,
            'skipped_no_subdomain': int,
            'skipped_failed': int,
            'executed_tasks': list,
            'tool_stats': {
                'total': int,
                'successful': int,
                'failed': int,
                'successful_tools': list[str],
                'failed_tools': list[dict]
            }
        }

        dict: scan result

    Raises:
        ValueError: configuration error
        RuntimeError: execution failure
    """
    try:
        wait_for_system_load(context="site_scan_flow")

        logger.info(
            "="*60 + "\n" +
            "Starting site scan\n" +
            f" Scan ID: {scan_id}\n" +
            f" Target: {target_name}\n" +
            f" Workspace: {scan_workspace_dir}\n" +
            "="*60
            "Starting site scan - Scan ID: %s, Target: %s, Workspace: %s",
            scan_id, target_name, scan_workspace_dir
        )

        # Parameter validation
        if scan_id is None:
            raise ValueError("scan_id must not be empty")
        if not target_name:
            raise ValueError("target_name must not be empty")
        if target_id is None:
            raise ValueError("target_id must not be empty")
        if not scan_workspace_dir:
            raise ValueError("scan_workspace_dir must not be empty")

        _validate_flow_params(scan_id, target_name, target_id, scan_workspace_dir)
        user_log(scan_id, "site_scan", "Starting site scan")

        # Step 0: create the working directory
        from apps.scan.utils import setup_scan_directory
        site_scan_dir = setup_scan_directory(scan_workspace_dir, 'site_scan')

        # Step 1: export the site URLs
        urls_file, total_urls, association_count = _export_site_urls(
            target_id, site_scan_dir, target_name
            target_id, site_scan_dir
        )

        if total_urls == 0:
            logger.warning("Skipping site scan: no site URLs to scan - Scan ID: %s", scan_id)
            user_log(scan_id, "site_scan", "Skipped: no site URLs to scan", "warning")
            return {
                'success': True,
                'scan_id': scan_id,
                'target': target_name,
                'scan_workspace_dir': scan_workspace_dir,
                'urls_file': urls_file,
                'total_urls': 0,
                'association_count': association_count,
                'processed_records': 0,
                'created_websites': 0,
                'skipped_no_subdomain': 0,
                'skipped_failed': 0,
                'executed_tasks': ['export_site_urls'],
                'tool_stats': {
                    'total': 0,
                    'successful': 0,
                    'failed': 0,
                    'successful_tools': [],
                    'failed_tools': [],
                    'details': {}
                }
            }

            return _build_empty_result(
                scan_id, target_name, scan_workspace_dir, urls_file, association_count
            )

        # Step 2: tool configuration info
        logger.info("Step 2: tool configuration info")
        logger.info(
            "✓ Enabled tools: %s",
            ', '.join(enabled_tools.keys())
        )

        logger.info("✓ Enabled tools: %s", ', '.join(enabled_tools))

        # Step 3: run the scan tools sequentially
        logger.info("Step 3: running the scan tools sequentially and saving results in real time")
        tool_stats, processed_records, successful_tool_names, failed_tools = _run_scans_sequentially(
            enabled_tools=enabled_tools,
            urls_file=urls_file,
            total_urls=total_urls,
            site_scan_dir=site_scan_dir,
        ctx = ScanContext(
            scan_id=scan_id,
            target_id=target_id,
            target_name=target_name
            target_name=target_name,
            site_scan_dir=site_scan_dir,
            urls_file=urls_file,
            total_urls=total_urls
        )

        # Dynamically build the list of executed tasks

        tool_stats, processed_records, successful_tools, failed_tools = \
            _run_scans_sequentially(enabled_tools, ctx)

        # Aggregate the results
        executed_tasks = ['export_site_urls', 'parse_config']
        executed_tasks.extend([f'run_and_stream_save_websites ({tool})' for tool in tool_stats.keys()])

        # Aggregate the results of all tools
        total_created = sum(stats['result'].get('created_websites', 0) for stats in tool_stats.values())
        total_skipped_no_subdomain = sum(stats['result'].get('skipped_no_subdomain', 0) for stats in tool_stats.values())
        total_skipped_failed = sum(stats['result'].get('skipped_failed', 0) for stats in tool_stats.values())

        # Record Flow completion
        executed_tasks.extend(
            f'run_and_stream_save_websites ({tool})' for tool in tool_stats
        )

        total_created, total_skipped_no_sub, total_skipped_failed = \
            _aggregate_tool_results(tool_stats)

        logger.info("✓ Site scan completed - websites created: %d", total_created)
        user_log(scan_id, "site_scan", f"site_scan completed: found {total_created} websites")

        user_log(
            scan_id, "site_scan",
            f"site_scan completed: found {total_created} websites"
        )

        return {
            'success': True,
            'scan_id': scan_id,
@@ -466,25 +424,20 @@ def site_scan_flow(
            'association_count': association_count,
            'processed_records': processed_records,
            'created_websites': total_created,
            'skipped_no_subdomain': total_skipped_no_subdomain,
            'skipped_no_subdomain': total_skipped_no_sub,
            'skipped_failed': total_skipped_failed,
            'executed_tasks': executed_tasks,
            'tool_stats': {
                'total': len(enabled_tools),
                'successful': len(successful_tool_names),
                'successful': len(successful_tools),
                'failed': len(failed_tools),
                'successful_tools': successful_tool_names,
                'successful_tools': successful_tools,
                'failed_tools': failed_tools,
                'details': tool_stats
            }
        }

    except ValueError as e:
        logger.error("Configuration error: %s", e)

    except ValueError:
        raise
    except RuntimeError as e:
        logger.error("Runtime error: %s", e)
    except RuntimeError:
        raise
    except Exception as e:
        logger.exception("Site scan failed: %s", e)
        raise
File diff suppressed because it is too large
@@ -10,22 +10,18 @@ URL Fetch main Flow
- Run unified httpx validation (if enabled)
"""

# Django environment initialization
from apps.common.prefect_django_setup import setup_django_for_prefect

import logging
import os
from pathlib import Path
from datetime import datetime
from pathlib import Path

from prefect import flow

from apps.scan.handlers.scan_flow_handlers import (
    on_scan_flow_running,
    on_scan_flow_completed,
    on_scan_flow_failed,
    on_scan_flow_running,
)
from apps.scan.utils import user_log
from apps.scan.utils import user_log, wait_for_system_load

from .domain_name_url_fetch_flow import domain_name_url_fetch_flow
from .sites_url_fetch_flow import sites_url_fetch_flow
@@ -43,13 +39,10 @@ SITES_FILE_TOOLS = {'katana'}
POST_PROCESS_TOOLS = {'uro', 'httpx'}


def _classify_tools(enabled_tools: dict) -> tuple[dict, dict, dict, dict]:
    """
    Classify the enabled tools by input type

    Returns:
        tuple: (domain_name_tools, sites_file_tools, uro_config, httpx_config)
    """
@@ -76,23 +69,23 @@ def _classify_tools(enabled_tools: dict) -> tuple[dict, dict, dict, dict]:
def _merge_and_deduplicate_urls(result_files: list, url_fetch_dir: Path) -> tuple[str, int]:
    """Merge and deduplicate URLs"""
    from apps.scan.tasks.url_fetch import merge_and_deduplicate_urls_task

    merged_file = merge_and_deduplicate_urls_task(
        result_files=result_files,
        result_dir=str(url_fetch_dir)
    )

    # Count the unique URLs
    unique_url_count = 0
    if Path(merged_file).exists():
        with open(merged_file, 'r') as f:
        with open(merged_file, 'r', encoding='utf-8') as f:
            unique_url_count = sum(1 for line in f if line.strip())

    logger.info(
        "✓ URL merge and dedup finished - merged file: %s, unique URLs: %d",
        merged_file, unique_url_count
    )

    return merged_file, unique_url_count


@@ -103,12 +96,12 @@ def _clean_urls_with_uro(
) -> tuple[str, int, int]:
    """Clean the merged URL list with uro"""
    from apps.scan.tasks.url_fetch import clean_urls_task

    raw_timeout = uro_config.get('timeout', 60)
    whitelist = uro_config.get('whitelist')
    blacklist = uro_config.get('blacklist')
    filters = uro_config.get('filters')

    # Compute the timeout
    if isinstance(raw_timeout, str) and raw_timeout == 'auto':
        timeout = calculate_timeout_by_line_count(
@@ -124,7 +117,7 @@ def _clean_urls_with_uro(
        except (TypeError, ValueError):
            logger.warning("Invalid uro timeout setting (%s), using the default of 60 seconds", raw_timeout)
            timeout = 60

    result = clean_urls_task(
        input_file=merged_file,
        output_dir=str(url_fetch_dir),
@@ -133,12 +126,12 @@ def _clean_urls_with_uro(
        blacklist=blacklist,
        filters=filters
    )

    if result['success']:
        return result['output_file'], result['output_count'], result['removed_count']
    else:
        logger.warning("uro cleanup failed: %s, using the raw merged file", result.get('error', 'unknown error'))
        return merged_file, result['input_count'], 0

    logger.warning("uro cleanup failed: %s, using the raw merged file", result.get('error', 'unknown error'))
    return merged_file, result['input_count'], 0


def _validate_and_stream_save_urls(
@@ -151,25 +144,25 @@ def _validate_and_stream_save_urls(
    """Validate URL liveness with httpx and stream-save to the database"""
    from apps.scan.utils import build_scan_command
    from apps.scan.tasks.url_fetch import run_and_stream_save_urls_task

    logger.info("Validating URL liveness with httpx...")

    # Count the URLs to validate
    try:
        with open(merged_file, 'r') as f:
        with open(merged_file, 'r', encoding='utf-8') as f:
            url_count = sum(1 for _ in f)
        logger.info("URLs to validate: %d", url_count)
    except Exception as e:
    except OSError as e:
        logger.error("Failed to read the URL file: %s", e)
        return 0

    if url_count == 0:
        logger.warning("No URLs need validation")
        return 0

    # Build the httpx command
    command_params = {'url_file': merged_file}

    try:
        command = build_scan_command(
            tool_name='httpx',
@@ -177,21 +170,19 @@ def _validate_and_stream_save_urls(
            command_params=command_params,
            tool_config=httpx_config
        )
    except Exception as e:
    except (ValueError, KeyError) as e:
        logger.error("Failed to build the httpx command: %s", e)
        logger.warning("Degrading: all URLs will be saved directly (without liveness validation)")
        return _save_urls_to_database(merged_file, scan_id, target_id)

    # Compute the timeout
    raw_timeout = httpx_config.get('timeout', 'auto')
    timeout = 3600
    if isinstance(raw_timeout, str) and raw_timeout == 'auto':
        # Compute the timeout from the URL line count: 3 seconds per line, 60 seconds minimum
        timeout = max(60, url_count * 3)
        logger.info(
            "Auto-computed httpx timeout (per line, 3s each, 60s minimum): url_count=%d, timeout=%ds",
            url_count,
            timeout,
            url_count, timeout
        )
    else:
        try:
@@ -199,49 +190,44 @@
        except (TypeError, ValueError):
            timeout = 3600
        logger.info("Using the configured httpx timeout: %d seconds", timeout)

    # Build the log file path
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = url_fetch_dir / f"httpx_validation_{timestamp}.log"

    # Stream the execution
    try:
        result = run_and_stream_save_urls_task(
            cmd=command,
            tool_name='httpx',
            scan_id=scan_id,
            target_id=target_id,
            cwd=str(url_fetch_dir),
            shell=True,
            timeout=timeout,
            log_file=str(log_file)
        )

        saved = result.get('saved_urls', 0)
        logger.info(
            "✓ httpx validation finished - live URLs: %d (%.1f%%)",
            saved, (saved / url_count * 100) if url_count > 0 else 0
        )
        return saved

    except Exception as e:
        logger.error("httpx streaming validation failed: %s", e, exc_info=True)
        raise
    result = run_and_stream_save_urls_task(
        cmd=command,
        tool_name='httpx',
        scan_id=scan_id,
        target_id=target_id,
        cwd=str(url_fetch_dir),
        shell=True,
        timeout=timeout,
        log_file=str(log_file)
    )

    saved = result.get('saved_urls', 0)
    logger.info(
        "✓ httpx validation finished - live URLs: %d (%.1f%%)",
        saved, (saved / url_count * 100) if url_count > 0 else 0
    )
    return saved
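
# Worked example of the 'auto' timeout rule above (hypothetical URL counts):
# each URL is budgeted 3 seconds with a 60-second floor, so
#   10 URLs   -> max(60, 10 * 3)   = 60 s
#   5000 URLs -> max(60, 5000 * 3) = 15000 s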
|
||||
|
||||
|
||||
def _save_urls_to_database(merged_file: str, scan_id: int, target_id: int) -> int:
    """保存 URL 到数据库(不验证存活)"""
    from apps.scan.tasks.url_fetch import save_urls_task

    result = save_urls_task(
        urls_file=merged_file,
        scan_id=scan_id,
        target_id=target_id
    )

    saved_count = result.get('saved_urls', 0)
    logger.info("✓ URL 保存完成 - 保存数量: %d", saved_count)

    return saved_count


@@ -261,7 +247,7 @@ def url_fetch_flow(
) -> dict:
    """
    URL 获取主 Flow

    执行流程:
    1. 准备工作目录
    2. 按输入类型分类工具(domain_name / sites_file / 后处理)
@@ -271,36 +257,32 @@ def url_fetch_flow(
    4. 合并所有子 Flow 的结果并去重
    5. uro 去重(如果启用)
    6. httpx 验证(如果启用)

    Args:
        scan_id: 扫描 ID
        target_name: 目标名称
        target_id: 目标 ID
        scan_workspace_dir: 扫描工作目录
        enabled_tools: 启用的工具配置

    Returns:
        dict: 扫描结果
    """
    try:
        # 负载检查:等待系统资源充足
        wait_for_system_load(context="url_fetch_flow")

        logger.info(
-           "="*60 + "\n" +
-           "开始 URL 获取扫描\n" +
-           f" Scan ID: {scan_id}\n" +
-           f" Target: {target_name}\n" +
-           f" Workspace: {scan_workspace_dir}\n" +
-           "="*60
+           "开始 URL 获取扫描 - Scan ID: %s, Target: %s, Workspace: %s",
+           scan_id, target_name, scan_workspace_dir
        )

        user_log(scan_id, "url_fetch", "Starting URL fetch")

        # Step 1: 准备工作目录
        logger.info("Step 1: 准备工作目录")
        from apps.scan.utils import setup_scan_directory
        url_fetch_dir = setup_scan_directory(scan_workspace_dir, 'url_fetch')

        # Step 2: 分类工具(按输入类型)
        logger.info("Step 2: 分类工具")
        domain_name_tools, sites_file_tools, uro_config, httpx_config = _classify_tools(enabled_tools)

        logger.info(
@@ -317,15 +299,14 @@ def url_fetch_flow(
                "URL Fetch 流程需要至少启用一个 URL 获取工具(如 waymore, katana)。"
                "httpx 和 uro 仅用于后处理,不能单独使用。"
            )

-       # Step 3: 并行执行子 Flow
+       # Step 3: 执行子 Flow
        all_result_files = []
        all_failed_tools = []
        all_successful_tools = []

-       # 3a: 基于 domain_name(target_name) 的 URL 被动收集(如 waymore)
+       # 3a: 基于 domain_name 的 URL 被动收集(如 waymore)
        if domain_name_tools:
            logger.info("Step 3a: 执行基于 domain_name 的 URL 被动收集子 Flow")
            tn_result = domain_name_url_fetch_flow(
                scan_id=scan_id,
                target_id=target_id,
@@ -336,10 +317,9 @@ def url_fetch_flow(
            all_result_files.extend(tn_result.get('result_files', []))
            all_failed_tools.extend(tn_result.get('failed_tools', []))
            all_successful_tools.extend(tn_result.get('successful_tools', []))

        # 3b: 爬虫(以 sites_file 为输入)
        if sites_file_tools:
            logger.info("Step 3b: 执行爬虫子 Flow")
            crawl_result = sites_url_fetch_flow(
                scan_id=scan_id,
                target_id=target_id,
@@ -350,12 +330,13 @@ def url_fetch_flow(
            all_result_files.extend(crawl_result.get('result_files', []))
            all_failed_tools.extend(crawl_result.get('failed_tools', []))
            all_successful_tools.extend(crawl_result.get('successful_tools', []))

        # 检查是否有成功的工具
        if not all_result_files:
-           error_details = "; ".join([f"{f['tool']}: {f['reason']}" for f in all_failed_tools])
+           error_details = "; ".join([
+               "%s: %s" % (f['tool'], f['reason']) for f in all_failed_tools
+           ])
            logger.warning("所有 URL 获取工具均失败 - 目标: %s, 失败详情: %s", target_name, error_details)
            # 返回空结果,不抛出异常,让扫描继续
            return {
                'success': True,
                'scan_id': scan_id,
@@ -366,31 +347,24 @@ def url_fetch_flow(
                'successful_tools': [],
                'message': '所有 URL 获取工具均无结果'
            }

        # Step 4: 合并并去重 URL
        logger.info("Step 4: 合并并去重 URL")
-       merged_file, unique_url_count = _merge_and_deduplicate_urls(
+       merged_file, _ = _merge_and_deduplicate_urls(
            result_files=all_result_files,
            url_fetch_dir=url_fetch_dir
        )

        # Step 5: 使用 uro 清理 URL(如果启用)
        url_file_for_validation = merged_file
        uro_removed_count = 0

        if uro_config and uro_config.get('enabled', False):
            logger.info("Step 5: 使用 uro 清理 URL")
-           url_file_for_validation, cleaned_count, uro_removed_count = _clean_urls_with_uro(
+           url_file_for_validation, _, _ = _clean_urls_with_uro(
                merged_file=merged_file,
                uro_config=uro_config,
                url_fetch_dir=url_fetch_dir
            )
        else:
            logger.info("Step 5: 跳过 uro 清理(未启用)")

        # Step 6: 使用 httpx 验证存活并保存(如果启用)
        if httpx_config and httpx_config.get('enabled', False):
            logger.info("Step 6: 使用 httpx 验证 URL 存活并流式保存")
            saved_count = _validate_and_stream_save_urls(
                merged_file=url_file_for_validation,
                httpx_config=httpx_config,
@@ -399,17 +373,16 @@ def url_fetch_flow(
                target_id=target_id
            )
        else:
            logger.info("Step 6: 保存到数据库(未启用 httpx 验证)")
            saved_count = _save_urls_to_database(
                merged_file=url_file_for_validation,
                scan_id=scan_id,
                target_id=target_id
            )

        # 记录 Flow 完成
        logger.info("✓ URL 获取完成 - 保存 endpoints: %d", saved_count)
-       user_log(scan_id, "url_fetch", f"url_fetch completed: found {saved_count} endpoints")
+       user_log(scan_id, "url_fetch", "url_fetch completed: found %d endpoints" % saved_count)

        # 构建已执行的任务列表
        executed_tasks = ['setup_directory', 'classify_tools']
        if domain_name_tools:
@@ -423,7 +396,7 @@ def url_fetch_flow(
            executed_tasks.append('httpx_validation_and_save')
        else:
            executed_tasks.append('save_urls')

        return {
            'success': True,
            'scan_id': scan_id,
@@ -439,7 +412,7 @@ def url_fetch_flow(
                'failed_tools': [f['tool'] for f in all_failed_tools]
            }
        }

    except Exception as e:
        logger.error("URL 获取扫描失败: %s", e, exc_info=True)
        raise
@@ -1,5 +1,6 @@
from apps.common.prefect_django_setup import setup_django_for_prefect

"""
漏洞扫描主 Flow
"""
import logging
from typing import Dict, Tuple

@@ -11,7 +12,7 @@ from apps.scan.handlers.scan_flow_handlers import (
    on_scan_flow_failed,
)
from apps.scan.configs.command_templates import get_command_template
-from apps.scan.utils import user_log
+from apps.scan.utils import user_log, wait_for_system_load
from .endpoints_vuln_scan_flow import endpoints_vuln_scan_flow

@@ -62,6 +63,9 @@ def vuln_scan_flow(
    - nuclei: 通用漏洞扫描(流式保存,支持模板 commit hash 同步)
    """
    try:
+       # 负载检查:等待系统资源充足
+       wait_for_system_load(context="vuln_scan_flow")
+
        if scan_id is None:
            raise ValueError("scan_id 不能为空")
        if not target_name:
backend/apps/scan/migrations/0003_add_wecom_fields.py (new file, 23 lines)
@@ -0,0 +1,23 @@
# Generated manually for WeCom notification support

from django.db import migrations, models


class Migration(migrations.Migration):

    dependencies = [
        ('scan', '0002_add_cached_screenshots_count'),
    ]

    operations = [
        migrations.AddField(
            model_name='notificationsettings',
            name='wecom_enabled',
            field=models.BooleanField(default=False, help_text='是否启用企业微信通知'),
        ),
        migrations.AddField(
            model_name='notificationsettings',
            name='wecom_webhook_url',
            field=models.URLField(blank=True, default='', help_text='企业微信机器人 Webhook URL'),
        ),
    ]
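The migration only adds the two WeCom fields, so it goes through Django's standard tooling. A minimal sketch, assuming `DJANGO_SETTINGS_MODULE` is configured for the backend:

```python
from django.core.management import call_command

call_command("migrate", "scan")          # applies 0003_add_wecom_fields
call_command("showmigrations", "scan")   # verify it is listed as applied
```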
@@ -1,8 +1,14 @@
"""通知系统数据模型"""

-from django.db import models
+import logging
+from datetime import timedelta

-from .types import NotificationLevel, NotificationCategory
+from django.db import models
+from django.utils import timezone
+
+from .types import NotificationCategory, NotificationLevel
+
+logger = logging.getLogger(__name__)


class NotificationSettings(models.Model):
@@ -10,31 +16,34 @@ class NotificationSettings(models.Model):
    通知设置(单例模型)
    存储 Discord webhook 配置和各分类的通知开关
    """

    # Discord 配置
    discord_enabled = models.BooleanField(default=False, help_text='是否启用 Discord 通知')
    discord_webhook_url = models.URLField(blank=True, default='', help_text='Discord Webhook URL')

+   # 企业微信配置
+   wecom_enabled = models.BooleanField(default=False, help_text='是否启用企业微信通知')
+   wecom_webhook_url = models.URLField(blank=True, default='', help_text='企业微信机器人 Webhook URL')
+
    # 分类开关(使用 JSONField 存储)
    categories = models.JSONField(
        default=dict,
        help_text='各分类通知开关,如 {"scan": true, "vulnerability": true, "asset": true, "system": false}'
    )

    # 时间信息
    created_at = models.DateTimeField(auto_now_add=True)
    updated_at = models.DateTimeField(auto_now=True)

    class Meta:
        db_table = 'notification_settings'
        verbose_name = '通知设置'
        verbose_name_plural = '通知设置'

    def save(self, *args, **kwargs):
-       # 单例模式:强制只有一条记录
-       self.pk = 1
+       self.pk = 1  # 单例模式
        super().save(*args, **kwargs)

    @classmethod
    def get_instance(cls) -> 'NotificationSettings':
        """获取或创建单例实例"""
@@ -52,7 +61,7 @@ class NotificationSettings(models.Model):
            }
        )
        return obj

    def is_category_enabled(self, category: str) -> bool:
        """检查指定分类是否启用通知"""
        return self.categories.get(category, False)
@@ -60,10 +69,9 @@

class Notification(models.Model):
    """通知模型"""

    id = models.AutoField(primary_key=True)

-   # 通知分类
    category = models.CharField(
        max_length=20,
        choices=NotificationCategory.choices,
@@ -71,8 +79,7 @@ class Notification(models.Model):
        db_index=True,
        help_text='通知分类'
    )

-   # 通知级别
    level = models.CharField(
        max_length=20,
        choices=NotificationLevel.choices,
@@ -80,16 +87,15 @@ class Notification(models.Model):
        db_index=True,
        help_text='通知级别'
    )

    title = models.CharField(max_length=200, help_text='通知标题')
    message = models.CharField(max_length=2000, help_text='通知内容')

-   # 时间信息
    created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')

    is_read = models.BooleanField(default=False, help_text='是否已读')
    read_at = models.DateTimeField(null=True, blank=True, help_text='阅读时间')

    class Meta:
        db_table = 'notification'
        verbose_name = '通知'
@@ -101,44 +107,26 @@ class Notification(models.Model):
            models.Index(fields=['level', '-created_at']),
            models.Index(fields=['is_read', '-created_at']),
        ]

    def __str__(self):
        return f"{self.get_level_display()} - {self.title}"

    @classmethod
-   def cleanup_old_notifications(cls):
-       """
-       清理超过15天的旧通知(硬编码)
-
-       Returns:
-           int: 删除的通知数量
-       """
-       from datetime import timedelta
-       from django.utils import timezone
-
-       # 硬编码:只保留最近15天的通知
+   def cleanup_old_notifications(cls) -> int:
+       """清理超过15天的旧通知"""
        cutoff_date = timezone.now() - timedelta(days=15)
-       delete_result = cls.objects.filter(created_at__lt=cutoff_date).delete()
-
-       return delete_result[0] if delete_result[0] else 0
+       deleted_count, _ = cls.objects.filter(created_at__lt=cutoff_date).delete()
+       return deleted_count or 0

    def save(self, *args, **kwargs):
-       """
-       重写save方法,在创建新通知时自动清理旧通知
-       """
+       """重写save方法,在创建新通知时自动清理旧通知"""
        is_new = self.pk is None
        super().save(*args, **kwargs)

-       # 只在创建新通知时执行清理(自动清理超过15天的通知)
        if is_new:
            try:
                deleted_count = self.__class__.cleanup_old_notifications()
                if deleted_count > 0:
-                   import logging
-                   logger = logging.getLogger(__name__)
-                   logger.info(f"自动清理了 {deleted_count} 条超过15天的旧通知")
-           except Exception as e:
+                   logger.info("自动清理了 %d 条超过15天的旧通知", deleted_count)
+           except Exception:
                # 清理失败不应影响通知创建
-               import logging
-               logger = logging.getLogger(__name__)
-               logger.warning(f"通知自动清理失败: {e}")
+               logger.warning("通知自动清理失败", exc_info=True)
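The rewritten `cleanup_old_notifications` leans on Django's documented contract that `QuerySet.delete()` returns a 2-tuple: the total number of rows removed and a per-model breakdown. That is why the tuple unpacking above is safe; a minimal illustration:

```python
# QuerySet.delete() -> (total_rows_deleted, {"<app_label>.<ModelName>": rows, ...})
total, per_model = Notification.objects.filter(is_read=True).delete()
# `total` is exactly what cleanup_old_notifications() returns;
# the trailing `or 0` guards the case where the filter matched nothing.
```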
@@ -1,52 +1,70 @@
"""通知系统仓储层模块"""

import logging
-from typing import TypedDict
+from dataclasses import dataclass
+from typing import Optional

+from django.db.models import QuerySet
from django.utils import timezone

from apps.common.decorators import auto_ensure_db_connection
-from .models import Notification, NotificationSettings

+from .models import Notification, NotificationSettings

logger = logging.getLogger(__name__)


-class NotificationSettingsData(TypedDict):
-   """通知设置数据结构"""
+@dataclass
+class NotificationSettingsData:
+   """通知设置更新数据"""

    discord_enabled: bool
    discord_webhook_url: str
    categories: dict[str, bool]
+   wecom_enabled: bool = False
+   wecom_webhook_url: str = ''


@auto_ensure_db_connection
class NotificationSettingsRepository:
    """通知设置仓储层"""

    def get_settings(self) -> NotificationSettings:
        """获取通知设置单例"""
        return NotificationSettings.get_instance()

-   def update_settings(
-       self,
-       discord_enabled: bool,
-       discord_webhook_url: str,
-       categories: dict[str, bool]
-   ) -> NotificationSettings:
+   def update_settings(self, data: NotificationSettingsData) -> NotificationSettings:
        """更新通知设置"""
        settings = NotificationSettings.get_instance()
-       settings.discord_enabled = discord_enabled
-       settings.discord_webhook_url = discord_webhook_url
-       settings.categories = categories
+       settings.discord_enabled = data.discord_enabled
+       settings.discord_webhook_url = data.discord_webhook_url
+       settings.wecom_enabled = data.wecom_enabled
+       settings.wecom_webhook_url = data.wecom_webhook_url
+       settings.categories = data.categories
        settings.save()
        return settings

    def is_category_enabled(self, category: str) -> bool:
        """检查指定分类是否启用"""
-       settings = self.get_settings()
-       return settings.is_category_enabled(category)
+       return self.get_settings().is_category_enabled(category)
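With the `TypedDict` replaced by a dataclass, callers now construct a value object instead of passing loose keyword arguments. A usage sketch (both webhook URLs are placeholders; the WeCom bot endpoint shape is from the official API):

```python
data = NotificationSettingsData(
    discord_enabled=True,
    discord_webhook_url="https://discord.com/api/webhooks/...",
    categories={"scan": True, "vulnerability": True, "asset": True, "system": False},
    wecom_enabled=True,
    wecom_webhook_url="https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=...",
)
settings = NotificationSettingsRepository().update_settings(data)
```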
@auto_ensure_db_connection
class DjangoNotificationRepository:
-   def get_filtered(self, level: str | None = None, unread: bool | None = None):
+   """通知数据仓储层"""

+   def get_filtered(
+       self,
+       level: Optional[str] = None,
+       unread: Optional[bool] = None
+   ) -> QuerySet[Notification]:
        """
        获取过滤后的通知列表

        Args:
            level: 通知级别过滤
            unread: 已读状态过滤 (True=未读, False=已读, None=全部)
        """
        queryset = Notification.objects.all()

        if level:
@@ -60,16 +78,24 @@ class DjangoNotificationRepository:
        return queryset.order_by("-created_at")

    def get_unread_count(self) -> int:
        """获取未读通知数量"""
        return Notification.objects.filter(is_read=False).count()

    def mark_all_as_read(self) -> int:
-       updated = Notification.objects.filter(is_read=False).update(
+       """标记所有通知为已读,返回更新数量"""
+       return Notification.objects.filter(is_read=False).update(
            is_read=True,
            read_at=timezone.now(),
        )
-       return updated

-   def create(self, title: str, message: str, level: str, category: str = 'system') -> Notification:
+   def create(
+       self,
+       title: str,
+       message: str,
+       level: str,
+       category: str = 'system'
+   ) -> Notification:
        """创建新通知"""
        return Notification.objects.create(
            category=category,
            level=level,
@@ -60,13 +60,12 @@ def push_to_external_channels(notification: Notification) -> None:
    except Exception as e:
        logger.warning(f"Discord 推送失败: {e}")

-   # 未来扩展:Slack
-   # if settings.slack_enabled and settings.slack_webhook_url:
-   #     _send_slack(notification, settings.slack_webhook_url)
-
-   # 未来扩展:Telegram
-   # if settings.telegram_enabled and settings.telegram_bot_token:
-   #     _send_telegram(notification, settings.telegram_chat_id)
+   # 企业微信渠道
+   if settings.wecom_enabled and settings.wecom_webhook_url:
+       try:
+           _send_wecom(notification, settings.wecom_webhook_url)
+       except Exception as e:
+           logger.warning(f"企业微信推送失败: {e}")


def _send_discord(notification: Notification, webhook_url: str) -> bool:
@@ -103,6 +102,41 @@ def _send_discord(notification: Notification, webhook_url: str) -> bool:
    return False


def _send_wecom(notification: Notification, webhook_url: str) -> bool:
    """发送到企业微信机器人 Webhook"""
    try:
        emoji = CATEGORY_EMOJI.get(notification.category, '📢')

        # 企业微信 Markdown 格式
        content = f"""**{emoji} {notification.title}**
> 级别:{notification.get_level_display()}
> 分类:{notification.get_category_display()}

{notification.message}"""

        payload = {
            'msgtype': 'markdown',
            'markdown': {'content': content}
        }

        response = requests.post(webhook_url, json=payload, timeout=10)

        if response.status_code == 200:
            result = response.json()
            if result.get('errcode') == 0:
                logger.info(f"企业微信通知发送成功 - {notification.title}")
                return True
            logger.warning(f"企业微信发送失败 - errcode: {result.get('errcode')}, errmsg: {result.get('errmsg')}")
            return False

        logger.warning(f"企业微信发送失败 - 状态码: {response.status_code}")
        return False

    except requests.RequestException as e:
        logger.error(f"企业微信网络错误: {e}")
        return False
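For a quick manual check of the webhook contract `_send_wecom` relies on (the `qyapi.weixin.qq.com` endpoint and the `errcode == 0` success convention are WeCom's documented bot API; the key is a placeholder):

```python
import requests

webhook = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=<your-bot-key>"
resp = requests.post(
    webhook,
    json={"msgtype": "markdown", "markdown": {"content": "**测试** 通知渠道连通性"}},
    timeout=10,
)
resp.raise_for_status()
assert resp.json().get("errcode") == 0, resp.json().get("errmsg")
```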
# ============================================================
# 设置服务
# ============================================================
@@ -121,31 +155,43 @@ class NotificationSettingsService:
                'enabled': settings.discord_enabled,
                'webhookUrl': settings.discord_webhook_url,
            },
+           'wecom': {
+               'enabled': settings.wecom_enabled,
+               'webhookUrl': settings.wecom_webhook_url,
+           },
            'categories': settings.categories,
        }

    def update_settings(self, data: dict) -> dict:
        """更新通知设置

        注意:DRF CamelCaseJSONParser 会将前端的 webhookUrl 转换为 webhook_url
        """
        discord_data = data.get('discord', {})
+       wecom_data = data.get('wecom', {})
        categories = data.get('categories', {})

        # CamelCaseJSONParser 转换后的字段名是 webhook_url
-       webhook_url = discord_data.get('webhook_url', '')
+       discord_webhook_url = discord_data.get('webhook_url', '')
+       wecom_webhook_url = wecom_data.get('webhook_url', '')

        settings = self.repo.update_settings(
            discord_enabled=discord_data.get('enabled', False),
-           discord_webhook_url=webhook_url,
+           discord_webhook_url=discord_webhook_url,
+           wecom_enabled=wecom_data.get('enabled', False),
+           wecom_webhook_url=wecom_webhook_url,
            categories=categories,
        )

        return {
            'discord': {
                'enabled': settings.discord_enabled,
                'webhookUrl': settings.discord_webhook_url,
            },
            'wecom': {
                'enabled': settings.wecom_enabled,
                'webhookUrl': settings.wecom_webhook_url,
            },
            'categories': settings.categories,
        }
backend/apps/scan/providers/__init__.py (new file, 56 lines)
@@ -0,0 +1,56 @@
"""
扫描目标提供者模块

提供统一的目标获取接口,支持多种数据源:
- DatabaseTargetProvider: 从数据库查询(完整扫描)
- ListTargetProvider: 使用内存列表(快速扫描阶段1)
- SnapshotTargetProvider: 从快照表读取(快速扫描阶段2+)
- PipelineTargetProvider: 使用管道输出(Phase 2)

使用方式:
    from apps.scan.providers import (
        DatabaseTargetProvider,
        ListTargetProvider,
        SnapshotTargetProvider,
        ProviderContext
    )

    # 数据库模式(完整扫描)
    provider = DatabaseTargetProvider(target_id=123)

    # 列表模式(快速扫描阶段1)
    context = ProviderContext(target_id=1, scan_id=100)
    provider = ListTargetProvider(
        targets=["a.test.com"],
        context=context
    )

    # 快照模式(快速扫描阶段2+)
    context = ProviderContext(target_id=1, scan_id=100)
    provider = SnapshotTargetProvider(
        scan_id=100,
        snapshot_type="subdomain",
        context=context
    )

    # 使用 Provider
    for host in provider.iter_hosts():
        scan(host)
"""

from .base import TargetProvider, ProviderContext
from .list_provider import ListTargetProvider
from .database_provider import DatabaseTargetProvider
from .snapshot_provider import SnapshotTargetProvider, SnapshotType
from .pipeline_provider import PipelineTargetProvider, StageOutput

__all__ = [
    'TargetProvider',
    'ProviderContext',
    'ListTargetProvider',
    'DatabaseTargetProvider',
    'SnapshotTargetProvider',
    'SnapshotType',
    'PipelineTargetProvider',
    'StageOutput',
]
backend/apps/scan/providers/base.py (new file, 115 lines)
@@ -0,0 +1,115 @@
"""
扫描目标提供者基础模块

定义 ProviderContext 数据类和 TargetProvider 抽象基类。
"""

import ipaddress
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Iterator, Optional

if TYPE_CHECKING:
    from apps.common.utils import BlacklistFilter

logger = logging.getLogger(__name__)


@dataclass
class ProviderContext:
    """
    Provider 上下文,携带元数据

    Attributes:
        target_id: 关联的 Target ID(用于结果保存),None 表示临时扫描(不保存)
        scan_id: 扫描任务 ID
    """
    target_id: Optional[int] = None
    scan_id: Optional[int] = None


class TargetProvider(ABC):
    """
    扫描目标提供者抽象基类

    职责:
    - 提供扫描目标(域名、IP、URL 等)的迭代器
    - 提供黑名单过滤器
    - 携带上下文信息(target_id, scan_id 等)
    - 自动展开 CIDR(子类无需关心)

    使用方式:
        provider = create_target_provider(target_id=123)
        for host in provider.iter_hosts():
            print(host)
    """

    def __init__(self, context: Optional[ProviderContext] = None):
        self._context = context or ProviderContext()

    @property
    def context(self) -> ProviderContext:
        """返回 Provider 上下文"""
        return self._context

    @staticmethod
    def _expand_host(host: str) -> Iterator[str]:
        """
        展开主机(如果是 CIDR 则展开为多个 IP,否则直接返回)

        示例:
            "192.168.1.0/30" → "192.168.1.1", "192.168.1.2"
            "192.168.1.1" → "192.168.1.1"
            "example.com" → "example.com"
        """
        from apps.common.validators import detect_target_type
        from apps.targets.models import Target

        host = host.strip()
        if not host:
            return

        try:
            target_type = detect_target_type(host)

            if target_type == Target.TargetType.CIDR:
                network = ipaddress.ip_network(host, strict=False)
                if network.num_addresses == 1:
                    yield str(network.network_address)
                else:
                    yield from (str(ip) for ip in network.hosts())
            elif target_type in (Target.TargetType.IP, Target.TargetType.DOMAIN):
                yield host
        except ValueError as e:
            logger.warning("跳过无效的主机格式 '%s': %s", host, str(e))

    def iter_hosts(self) -> Iterator[str]:
        """迭代主机列表(域名/IP),自动展开 CIDR"""
        for host in self._iter_raw_hosts():
            yield from self._expand_host(host)

    @abstractmethod
    def _iter_raw_hosts(self) -> Iterator[str]:
        """迭代原始主机列表(可能包含 CIDR),子类实现"""
        pass

    @abstractmethod
    def iter_urls(self) -> Iterator[str]:
        """迭代 URL 列表"""
        pass

    @abstractmethod
    def get_blacklist_filter(self) -> Optional['BlacklistFilter']:
        """获取黑名单过滤器,返回 None 表示不过滤"""
        pass

    @property
    def target_id(self) -> Optional[int]:
        """返回关联的 target_id,临时扫描返回 None"""
        return self._context.target_id

    @property
    def scan_id(self) -> Optional[int]:
        """返回关联的 scan_id"""
        return self._context.scan_id
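The `_expand_host` examples in the docstring match the standard library's behavior: `ip_network(...).hosts()` excludes the network and broadcast addresses for /30 and larger networks, and the single-address case is handled explicitly. This is verifiable with `ipaddress` alone:

```python
import ipaddress

net = ipaddress.ip_network("192.168.1.0/30", strict=False)
assert [str(ip) for ip in net.hosts()] == ["192.168.1.1", "192.168.1.2"]

# A /32 network contains exactly one address, which _expand_host yields directly
single = ipaddress.ip_network("10.0.0.5/32")
assert single.num_addresses == 1
assert str(single.network_address) == "10.0.0.5"
```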
backend/apps/scan/providers/database_provider.py (new file, 93 lines)
@@ -0,0 +1,93 @@
"""
数据库目标提供者模块

提供基于数据库查询的目标提供者实现。
"""

import logging
from typing import TYPE_CHECKING, Iterator, Optional

from .base import ProviderContext, TargetProvider

if TYPE_CHECKING:
    from apps.common.utils import BlacklistFilter

logger = logging.getLogger(__name__)


class DatabaseTargetProvider(TargetProvider):
    """
    数据库目标提供者 - 从 Target 表及关联资产表查询

    数据来源:
    - iter_hosts(): 根据 Target 类型返回域名/IP
    - iter_urls(): WebSite/Endpoint 表,带回退链

    使用方式:
        provider = DatabaseTargetProvider(target_id=123)
        for host in provider.iter_hosts():
            scan(host)
    """

    def __init__(self, target_id: int, context: Optional[ProviderContext] = None):
        ctx = context or ProviderContext()
        ctx.target_id = target_id
        super().__init__(ctx)
        self._blacklist_filter: Optional['BlacklistFilter'] = None

    def iter_hosts(self) -> Iterator[str]:
        """从数据库查询主机列表,自动展开 CIDR 并应用黑名单过滤"""
        blacklist = self.get_blacklist_filter()

        for host in self._iter_raw_hosts():
            for expanded_host in self._expand_host(host):
                if not blacklist or blacklist.is_allowed(expanded_host):
                    yield expanded_host

    def _iter_raw_hosts(self) -> Iterator[str]:
        """从数据库查询原始主机列表(可能包含 CIDR)"""
        from apps.asset.services.asset.subdomain_service import SubdomainService
        from apps.targets.models import Target
        from apps.targets.services import TargetService

        target = TargetService().get_target(self.target_id)
        if not target:
            logger.warning("Target ID %d 不存在", self.target_id)
            return

        if target.type == Target.TargetType.DOMAIN:
            yield target.name
            for domain in SubdomainService().iter_subdomain_names_by_target(
                target_id=self.target_id,
                chunk_size=1000
            ):
                if domain != target.name:
                    yield domain

        elif target.type in (Target.TargetType.IP, Target.TargetType.CIDR):
            yield target.name

    def iter_urls(self) -> Iterator[str]:
        """从数据库查询 URL 列表,使用回退链:Endpoint → WebSite → Default"""
        from apps.scan.services.target_export_service import (
            DataSource,
            _iter_urls_with_fallback,
        )

        blacklist = self.get_blacklist_filter()

        for url, _ in _iter_urls_with_fallback(
            target_id=self.target_id,
            sources=[DataSource.ENDPOINT, DataSource.WEBSITE, DataSource.DEFAULT],
            blacklist_filter=blacklist
        ):
            yield url

    def get_blacklist_filter(self) -> Optional['BlacklistFilter']:
        """获取黑名单过滤器(延迟加载)"""
        if self._blacklist_filter is None:
            from apps.common.services import BlacklistService
            from apps.common.utils import BlacklistFilter
            rules = BlacklistService().get_rules(self.target_id)
            self._blacklist_filter = BlacklistFilter(rules)
        return self._blacklist_filter
backend/apps/scan/providers/list_provider.py (new file, 84 lines)
@@ -0,0 +1,84 @@
"""
列表目标提供者模块

提供基于内存列表的目标提供者实现。
"""

from typing import Iterator, Optional, List

from .base import TargetProvider, ProviderContext


class ListTargetProvider(TargetProvider):
    """
    列表目标提供者 - 直接使用内存中的列表

    用于快速扫描、临时扫描等场景,只扫描用户指定的目标。

    特点:
    - 不查询数据库
    - 不应用黑名单过滤(用户明确指定的目标)
    - 不关联 target_id(由调用方负责创建 Target)
    - 自动检测输入类型(URL/域名/IP/CIDR)
    - 自动展开 CIDR

    使用方式:
        # 快速扫描:用户提供目标,自动识别类型
        provider = ListTargetProvider(targets=[
            "example.com",              # 域名
            "192.168.1.0/24",           # CIDR(自动展开)
            "https://api.example.com"   # URL
        ])
        for host in provider.iter_hosts():
            scan(host)
    """

    def __init__(
        self,
        targets: Optional[List[str]] = None,
        context: Optional[ProviderContext] = None
    ):
        """
        初始化列表目标提供者

        Args:
            targets: 目标列表(自动识别类型:URL/域名/IP/CIDR)
            context: Provider 上下文
        """
        from apps.common.validators import detect_input_type

        ctx = context or ProviderContext()
        super().__init__(ctx)

        # 自动分类目标
        self._hosts = []
        self._urls = []

        if targets:
            for target in targets:
                target = target.strip()
                if not target:
                    continue

                try:
                    input_type = detect_input_type(target)
                    if input_type == 'url':
                        self._urls.append(target)
                    else:
                        # domain/ip/cidr 都作为 host
                        self._hosts.append(target)
                except ValueError:
                    # 无法识别类型,默认作为 host
                    self._hosts.append(target)

    def _iter_raw_hosts(self) -> Iterator[str]:
        """迭代原始主机列表(可能包含 CIDR)"""
        yield from self._hosts

    def iter_urls(self) -> Iterator[str]:
        """迭代 URL 列表"""
        yield from self._urls

    def get_blacklist_filter(self) -> None:
        """列表模式不使用黑名单过滤"""
        return None
backend/apps/scan/providers/pipeline_provider.py (new file, 91 lines)
@@ -0,0 +1,91 @@
"""
管道目标提供者模块

提供基于管道阶段输出的目标提供者实现。
用于 Phase 2 管道模式的阶段间数据传递。
"""

from dataclasses import dataclass, field
from typing import Iterator, Optional, List, Dict, Any

from .base import TargetProvider, ProviderContext


@dataclass
class StageOutput:
    """
    阶段输出数据

    用于在管道阶段之间传递数据。

    Attributes:
        hosts: 主机列表(域名/IP)
        urls: URL 列表
        new_targets: 新发现的目标列表
        stats: 统计信息
        success: 是否成功
        error: 错误信息
    """
    hosts: List[str] = field(default_factory=list)
    urls: List[str] = field(default_factory=list)
    new_targets: List[str] = field(default_factory=list)
    stats: Dict[str, Any] = field(default_factory=dict)
    success: bool = True
    error: Optional[str] = None


class PipelineTargetProvider(TargetProvider):
    """
    管道目标提供者 - 使用上一阶段的输出

    用于 Phase 2 管道模式的阶段间数据传递。

    特点:
    - 不查询数据库
    - 不应用黑名单过滤(数据已在上一阶段过滤)
    - 直接使用 StageOutput 中的数据

    使用方式(Phase 2):
        stage1_output = stage1.run(input)
        provider = PipelineTargetProvider(
            previous_output=stage1_output,
            target_id=123
        )
        for host in provider.iter_hosts():
            stage2.scan(host)
    """

    def __init__(
        self,
        previous_output: StageOutput,
        target_id: Optional[int] = None,
        context: Optional[ProviderContext] = None
    ):
        """
        初始化管道目标提供者

        Args:
            previous_output: 上一阶段的输出
            target_id: 可选,关联到某个 Target(用于保存结果)
            context: Provider 上下文
        """
        ctx = context or ProviderContext(target_id=target_id)
        super().__init__(ctx)
        self._previous_output = previous_output

    def _iter_raw_hosts(self) -> Iterator[str]:
        """迭代上一阶段输出的原始主机(可能包含 CIDR)"""
        yield from self._previous_output.hosts

    def iter_urls(self) -> Iterator[str]:
        """迭代上一阶段输出的 URL"""
        yield from self._previous_output.urls

    def get_blacklist_filter(self) -> None:
        """管道传递的数据已经过滤过了"""
        return None

    @property
    def previous_output(self) -> StageOutput:
        """返回上一阶段的输出"""
        return self._previous_output
backend/apps/scan/providers/snapshot_provider.py (new file, 175 lines)
@@ -0,0 +1,175 @@
"""
快照目标提供者模块

提供基于快照表的目标提供者实现。
用于快速扫描的阶段间数据传递。
"""

import logging
from typing import Iterator, Optional, Literal

from .base import TargetProvider, ProviderContext

logger = logging.getLogger(__name__)

# 快照类型定义
SnapshotType = Literal["subdomain", "website", "endpoint", "host_port"]


class SnapshotTargetProvider(TargetProvider):
    """
    快照目标提供者 - 从快照表读取本次扫描的数据

    用于快速扫描的阶段间数据传递,解决精确扫描控制问题。

    核心价值:
    - 只返回本次扫描(scan_id)发现的资产
    - 避免扫描历史数据(DatabaseTargetProvider 会扫描所有历史资产)

    特点:
    - 通过 scan_id 过滤快照表
    - 不应用黑名单过滤(数据已在上一阶段过滤)
    - 支持多种快照类型(subdomain/website/endpoint/host_port)

    使用场景:
        # 快速扫描流程
        用户输入: a.test.com
        创建 Target: test.com (id=1)
        创建 Scan: scan_id=100

        # 阶段1: 子域名发现
        provider = ListTargetProvider(
            targets=["a.test.com"],
            context=ProviderContext(target_id=1, scan_id=100)
        )
        # 发现: b.a.test.com, c.a.test.com
        # 保存: SubdomainSnapshot(scan_id=100) + Subdomain(target_id=1)

        # 阶段2: 端口扫描
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="subdomain",
            context=ProviderContext(target_id=1, scan_id=100)
        )
        # 只返回: b.a.test.com, c.a.test.com(本次扫描发现的)
        # 不返回: www.test.com, api.test.com(历史数据)

        # 阶段3: 网站扫描
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="host_port",
            context=ProviderContext(target_id=1, scan_id=100)
        )
        # 只返回本次扫描发现的 IP:Port
    """

    def __init__(
        self,
        scan_id: int,
        snapshot_type: SnapshotType,
        context: Optional[ProviderContext] = None
    ):
        """
        初始化快照目标提供者

        Args:
            scan_id: 扫描任务 ID(必需)
            snapshot_type: 快照类型
                - "subdomain": 子域名快照(SubdomainSnapshot)
                - "website": 网站快照(WebsiteSnapshot)
                - "endpoint": 端点快照(EndpointSnapshot)
                - "host_port": 主机端口映射快照(HostPortMappingSnapshot)
            context: Provider 上下文
        """
        ctx = context or ProviderContext()
        ctx.scan_id = scan_id
        super().__init__(ctx)
        self._scan_id = scan_id
        self._snapshot_type = snapshot_type

    def _iter_raw_hosts(self) -> Iterator[str]:
        """
        从快照表迭代主机列表

        根据 snapshot_type 选择不同的快照表:
        - subdomain: SubdomainSnapshot.name
        - host_port: HostPortMappingSnapshot.host (返回 host:port 格式,不经过验证)
        """
        if self._snapshot_type == "subdomain":
            from apps.asset.services.snapshot import SubdomainSnapshotsService
            service = SubdomainSnapshotsService()
            yield from service.iter_subdomain_names_by_scan(
                scan_id=self._scan_id,
                chunk_size=1000
            )

        elif self._snapshot_type == "host_port":
            # host_port 类型不使用 _iter_raw_hosts,直接在 iter_hosts 中处理
            # 这里返回空,避免被基类的 iter_hosts 调用
            return

        else:
            # 其他类型暂不支持 iter_hosts
            logger.warning(
                "快照类型 '%s' 不支持 iter_hosts,返回空迭代器",
                self._snapshot_type
            )
            return

    def iter_hosts(self) -> Iterator[str]:
        """
        迭代主机列表

        对于 host_port 类型,返回 host:port 格式,不经过 CIDR 展开验证
        """
        if self._snapshot_type == "host_port":
            # host_port 类型直接返回 host:port,不经过 _expand_host 验证
            from apps.asset.services.snapshot import HostPortMappingSnapshotsService
            service = HostPortMappingSnapshotsService()
            queryset = service.get_by_scan(scan_id=self._scan_id)
            for mapping in queryset.iterator(chunk_size=1000):
                yield f"{mapping.host}:{mapping.port}"
        else:
            # 其他类型使用基类的 iter_hosts(会调用 _iter_raw_hosts 并展开 CIDR)
            yield from super().iter_hosts()

    def iter_urls(self) -> Iterator[str]:
        """
        从快照表迭代 URL 列表

        根据 snapshot_type 选择不同的快照表:
        - website: WebsiteSnapshot.url
        - endpoint: EndpointSnapshot.url
        """
        if self._snapshot_type == "website":
            from apps.asset.services.snapshot import WebsiteSnapshotsService
            service = WebsiteSnapshotsService()
            yield from service.iter_website_urls_by_scan(
                scan_id=self._scan_id,
                chunk_size=1000
            )

        elif self._snapshot_type == "endpoint":
            from apps.asset.services.snapshot import EndpointSnapshotsService
            service = EndpointSnapshotsService()
            # 从快照表获取端点 URL
            queryset = service.get_by_scan(scan_id=self._scan_id)
            for endpoint in queryset.iterator(chunk_size=1000):
                yield endpoint.url

        else:
            # 其他类型暂不支持 iter_urls
            logger.warning(
                "快照类型 '%s' 不支持 iter_urls,返回空迭代器",
                self._snapshot_type
            )
            return

    def get_blacklist_filter(self) -> None:
        """快照数据已在上一阶段过滤过了"""
        return None

    @property
    def snapshot_type(self) -> SnapshotType:
        """返回快照类型"""
        return self._snapshot_type
backend/apps/scan/providers/tests/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
"""
扫描目标提供者测试模块
"""
backend/apps/scan/providers/tests/test_common_properties.py (new file, 256 lines)
@@ -0,0 +1,256 @@
"""
通用属性测试

包含跨多个 Provider 的通用属性测试:
- Property 4: Context Propagation
- Property 5: Non-Database Provider Blacklist Filter
- Property 7: CIDR Expansion Consistency
"""

import pytest
from hypothesis import given, strategies as st, settings
from ipaddress import IPv4Network

from apps.scan.providers import (
    ProviderContext,
    ListTargetProvider,
    DatabaseTargetProvider,
    PipelineTargetProvider,
    SnapshotTargetProvider
)
from apps.scan.providers.pipeline_provider import StageOutput


class TestContextPropagation:
    """
    Property 4: Context Propagation

    *For any* ProviderContext,传入 Provider 构造函数后,
    Provider 的 target_id 和 scan_id 属性应该与 context 中的值一致。

    **Validates: Requirements 1.3, 1.5, 7.4, 7.5**
    """

    @given(
        target_id=st.integers(min_value=1, max_value=10000),
        scan_id=st.integers(min_value=1, max_value=10000)
    )
    @settings(max_examples=100)
    def test_property_4_list_provider_context_propagation(self, target_id, scan_id):
        """
        Property 4: Context Propagation (ListTargetProvider)

        Feature: scan-target-provider, Property 4: Context Propagation
        **Validates: Requirements 1.3, 1.5, 7.4, 7.5**
        """
        ctx = ProviderContext(target_id=target_id, scan_id=scan_id)
        provider = ListTargetProvider(targets=["example.com"], context=ctx)

        assert provider.target_id == target_id
        assert provider.scan_id == scan_id
        assert provider.context.target_id == target_id
        assert provider.context.scan_id == scan_id

    @given(
        target_id=st.integers(min_value=1, max_value=10000),
        scan_id=st.integers(min_value=1, max_value=10000)
    )
    @settings(max_examples=100)
    def test_property_4_database_provider_context_propagation(self, target_id, scan_id):
        """
        Property 4: Context Propagation (DatabaseTargetProvider)

        Feature: scan-target-provider, Property 4: Context Propagation
        **Validates: Requirements 1.3, 1.5, 7.4, 7.5**
        """
        ctx = ProviderContext(target_id=999, scan_id=scan_id)
        # DatabaseTargetProvider 会覆盖 context 中的 target_id
        provider = DatabaseTargetProvider(target_id=target_id, context=ctx)

        assert provider.target_id == target_id  # 使用构造函数参数
        assert provider.scan_id == scan_id      # 使用 context 中的值
        assert provider.context.target_id == target_id
        assert provider.context.scan_id == scan_id

    @given(
        target_id=st.integers(min_value=1, max_value=10000),
        scan_id=st.integers(min_value=1, max_value=10000)
    )
    @settings(max_examples=100)
    def test_property_4_pipeline_provider_context_propagation(self, target_id, scan_id):
        """
        Property 4: Context Propagation (PipelineTargetProvider)

        Feature: scan-target-provider, Property 4: Context Propagation
        **Validates: Requirements 1.3, 1.5, 7.4, 7.5**
        """
        ctx = ProviderContext(target_id=target_id, scan_id=scan_id)
        stage_output = StageOutput(hosts=["example.com"])
        provider = PipelineTargetProvider(previous_output=stage_output, context=ctx)

        assert provider.target_id == target_id
        assert provider.scan_id == scan_id
        assert provider.context.target_id == target_id
        assert provider.context.scan_id == scan_id

    @given(
        target_id=st.integers(min_value=1, max_value=10000),
        scan_id=st.integers(min_value=1, max_value=10000)
    )
    @settings(max_examples=100)
    def test_property_4_snapshot_provider_context_propagation(self, target_id, scan_id):
        """
        Property 4: Context Propagation (SnapshotTargetProvider)

        Feature: scan-target-provider, Property 4: Context Propagation
        **Validates: Requirements 1.3, 1.5, 7.4, 7.5**
        """
        ctx = ProviderContext(target_id=target_id, scan_id=999)
        # SnapshotTargetProvider 会覆盖 context 中的 scan_id
        provider = SnapshotTargetProvider(
            scan_id=scan_id,
            snapshot_type="subdomain",
            context=ctx
        )

        assert provider.target_id == target_id  # 使用 context 中的值
        assert provider.scan_id == scan_id      # 使用构造函数参数
        assert provider.context.target_id == target_id
        assert provider.context.scan_id == scan_id


class TestNonDatabaseProviderBlacklistFilter:
    """
    Property 5: Non-Database Provider Blacklist Filter

    *For any* ListTargetProvider 或 PipelineTargetProvider 实例,
    get_blacklist_filter() 方法应该返回 None。

    **Validates: Requirements 3.4, 9.4, 9.5**
    """

    @given(targets=st.lists(st.text(min_size=1, max_size=20), max_size=10))
    @settings(max_examples=100)
    def test_property_5_list_provider_no_blacklist(self, targets):
        """
        Property 5: Non-Database Provider Blacklist Filter (ListTargetProvider)

        Feature: scan-target-provider, Property 5: Non-Database Provider Blacklist Filter
        **Validates: Requirements 3.4, 9.4, 9.5**
        """
        provider = ListTargetProvider(targets=targets)
        assert provider.get_blacklist_filter() is None

    @given(hosts=st.lists(st.text(min_size=1, max_size=20), max_size=10))
    @settings(max_examples=100)
    def test_property_5_pipeline_provider_no_blacklist(self, hosts):
        """
        Property 5: Non-Database Provider Blacklist Filter (PipelineTargetProvider)

        Feature: scan-target-provider, Property 5: Non-Database Provider Blacklist Filter
        **Validates: Requirements 3.4, 9.4, 9.5**
        """
        stage_output = StageOutput(hosts=hosts)
        provider = PipelineTargetProvider(previous_output=stage_output)
        assert provider.get_blacklist_filter() is None

    def test_property_5_snapshot_provider_no_blacklist(self):
        """
        Property 5: Non-Database Provider Blacklist Filter (SnapshotTargetProvider)

        Feature: scan-target-provider, Property 5: Non-Database Provider Blacklist Filter
        **Validates: Requirements 3.4, 9.4, 9.5**
        """
        provider = SnapshotTargetProvider(scan_id=1, snapshot_type="subdomain")
        assert provider.get_blacklist_filter() is None


class TestCIDRExpansionConsistency:
    """
    Property 7: CIDR Expansion Consistency

    *For any* CIDR 字符串(如 "192.168.1.0/24"),所有 Provider 的 iter_hosts()
    方法应该将其展开为相同的单个 IP 地址列表。

    **Validates: Requirements 1.1, 3.6**
    """

    @given(
        # 生成小的 CIDR 范围以避免测试超时
        network_prefix=st.integers(min_value=1, max_value=254),
        cidr_suffix=st.integers(min_value=28, max_value=30)  # /28 = 16 IPs, /30 = 4 IPs
    )
    @settings(max_examples=50, deadline=None)
    def test_property_7_cidr_expansion_consistency(self, network_prefix, cidr_suffix):
        """
        Property 7: CIDR Expansion Consistency

        Feature: scan-target-provider, Property 7: CIDR Expansion Consistency
        **Validates: Requirements 1.1, 3.6**

        For any CIDR string, all Providers should expand it to the same IP list.
        """
        cidr = f"192.168.{network_prefix}.0/{cidr_suffix}"

        # 计算预期的 IP 列表
        network = IPv4Network(cidr, strict=False)
        # 排除网络地址和广播地址
        expected_ips = [str(ip) for ip in network.hosts()]

        # 如果 CIDR 太小(/31 或 /32),使用所有地址
        if not expected_ips:
            expected_ips = [str(ip) for ip in network]

        # ListTargetProvider
        list_provider = ListTargetProvider(targets=[cidr])
        list_result = list(list_provider.iter_hosts())

        # PipelineTargetProvider
        stage_output = StageOutput(hosts=[cidr])
        pipeline_provider = PipelineTargetProvider(previous_output=stage_output)
        pipeline_result = list(pipeline_provider.iter_hosts())

        # 验证:所有 Provider 展开的结果应该一致
        assert list_result == expected_ips, f"ListProvider CIDR expansion mismatch for {cidr}"
        assert pipeline_result == expected_ips, f"PipelineProvider CIDR expansion mismatch for {cidr}"
        assert list_result == pipeline_result, f"Providers produce different results for {cidr}"

    def test_cidr_expansion_with_multiple_cidrs(self):
        """测试多个 CIDR 的展开一致性"""
        cidrs = ["192.168.1.0/30", "10.0.0.0/30"]

        # 计算预期结果
        expected_ips = []
        for cidr in cidrs:
            network = IPv4Network(cidr, strict=False)
            expected_ips.extend([str(ip) for ip in network.hosts()])

        # ListTargetProvider
        list_provider = ListTargetProvider(targets=cidrs)
        list_result = list(list_provider.iter_hosts())

        # PipelineTargetProvider
        stage_output = StageOutput(hosts=cidrs)
        pipeline_provider = PipelineTargetProvider(previous_output=stage_output)
        pipeline_result = list(pipeline_provider.iter_hosts())

        # 验证
        assert list_result == expected_ips
        assert pipeline_result == expected_ips
        assert list_result == pipeline_result

    def test_mixed_hosts_and_cidrs(self):
        """测试混合主机和 CIDR 的处理"""
        targets = ["example.com", "192.168.1.0/30", "test.com"]

        # 计算预期结果
        network = IPv4Network("192.168.1.0/30", strict=False)
        cidr_ips = [str(ip) for ip in network.hosts()]
        expected = ["example.com"] + cidr_ips + ["test.com"]

        # ListTargetProvider
        list_provider = ListTargetProvider(targets=targets)
        list_result = list(list_provider.iter_hosts())

        # 验证
        assert list_result == expected
backend/apps/scan/providers/tests/test_database_provider.py (new file, 158 lines)
@@ -0,0 +1,158 @@
"""
DatabaseTargetProvider 属性测试

Property 7: DatabaseTargetProvider Blacklist Application
*For any* 带有黑名单规则的 target_id,DatabaseTargetProvider 的 iter_hosts() 和 iter_urls()
应该过滤掉匹配黑名单规则的目标。

**Validates: Requirements 2.3, 10.1, 10.2, 10.3**
"""

import pytest
from unittest.mock import patch, MagicMock
from hypothesis import given, strategies as st, settings

from apps.scan.providers.database_provider import DatabaseTargetProvider
from apps.scan.providers.base import ProviderContext


# 生成有效域名的策略
def valid_domain_strategy():
    """生成有效的域名"""
    label = st.text(
        alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
        min_size=2,
        max_size=10
    )
    return st.builds(
        lambda a, b, c: f"{a}.{b}.{c}",
        label, label, st.sampled_from(['com', 'net', 'org', 'io'])
    )


class MockBlacklistFilter:
    """模拟黑名单过滤器"""

    def __init__(self, blocked_patterns: list):
        self.blocked_patterns = blocked_patterns

    def is_allowed(self, target: str) -> bool:
        """检查目标是否被允许(不在黑名单中)"""
        for pattern in self.blocked_patterns:
            if pattern in target:
                return False
        return True


class TestDatabaseTargetProviderProperties:
    """DatabaseTargetProvider 属性测试类"""

    @given(
        hosts=st.lists(valid_domain_strategy(), min_size=1, max_size=20),
        blocked_keyword=st.text(
            alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
            min_size=2,
            max_size=5
        )
    )
    @settings(max_examples=100)
    def test_property_7_blacklist_filters_hosts(self, hosts, blocked_keyword):
        """
        Property 7: DatabaseTargetProvider Blacklist Application (hosts)

        Feature: scan-target-provider, Property 7: DatabaseTargetProvider Blacklist Application
        **Validates: Requirements 2.3, 10.1, 10.2, 10.3**

        For any set of hosts and a blacklist keyword, the provider should filter out
        all hosts containing the blocked keyword.
        """
        # 创建模拟的黑名单过滤器
        mock_filter = MockBlacklistFilter([blocked_keyword])

        # 创建 provider 并注入模拟的黑名单过滤器
        provider = DatabaseTargetProvider(target_id=1)
        provider._blacklist_filter = mock_filter

        # 模拟 Target 和 SubdomainService
        mock_target = MagicMock()
        mock_target.type = 'domain'
        mock_target.name = hosts[0] if hosts else 'example.com'

        with patch('apps.targets.services.TargetService') as mock_target_service, \
             patch('apps.asset.services.asset.subdomain_service.SubdomainService') as mock_subdomain_service:

            mock_target_service.return_value.get_target.return_value = mock_target
            mock_subdomain_service.return_value.iter_subdomain_names_by_target.return_value = iter(hosts[1:] if len(hosts) > 1 else [])

            # 获取结果
            result = list(provider.iter_hosts())

            # 验证:所有结果都不包含被阻止的关键词
            for host in result:
                assert blocked_keyword not in host, f"Host '{host}' should be filtered by blacklist keyword '{blocked_keyword}'"

            # 验证:所有不包含关键词的主机都应该在结果中
            if hosts:
                all_hosts = [hosts[0]] + [h for h in hosts[1:] if h != hosts[0]]
                expected_allowed = [h for h in all_hosts if blocked_keyword not in h]
            else:
                expected_allowed = []

            assert set(result) == set(expected_allowed)


class TestDatabaseTargetProviderUnit:
    """DatabaseTargetProvider 单元测试类"""

    def test_target_id_in_context(self):
        """测试 target_id 正确设置到上下文中"""
        provider = DatabaseTargetProvider(target_id=123)
        assert provider.target_id == 123
        assert provider.context.target_id == 123

    def test_context_propagation(self):
        """测试上下文传递"""
        ctx = ProviderContext(scan_id=789)
        provider = DatabaseTargetProvider(target_id=123, context=ctx)

        assert provider.target_id == 123  # target_id 被覆盖
        assert provider.scan_id == 789

    def test_blacklist_filter_lazy_loading(self):
        """测试黑名单过滤器延迟加载"""
        provider = DatabaseTargetProvider(target_id=123)

        # 初始时 _blacklist_filter 为 None
        assert provider._blacklist_filter is None

        # 模拟 BlacklistService
        with patch('apps.common.services.BlacklistService') as mock_service, \
             patch('apps.common.utils.BlacklistFilter') as mock_filter_class:

            mock_service.return_value.get_rules.return_value = []
            mock_filter_instance = MagicMock()
            mock_filter_class.return_value = mock_filter_instance

            # 第一次调用
            result1 = provider.get_blacklist_filter()
            assert result1 == mock_filter_instance

            # 第二次调用应该返回缓存的实例
            result2 = provider.get_blacklist_filter()
            assert result2 == mock_filter_instance

            # BlacklistService 只应该被调用一次
            mock_service.return_value.get_rules.assert_called_once_with(123)

    def test_nonexistent_target_returns_empty(self):
        """测试不存在的 target 返回空迭代器"""
        provider = DatabaseTargetProvider(target_id=99999)

        with patch('apps.targets.services.TargetService') as mock_service, \
             patch('apps.common.services.BlacklistService') as mock_blacklist_service:

            mock_service.return_value.get_target.return_value = None
            mock_blacklist_service.return_value.get_rules.return_value = []

            result = list(provider.iter_hosts())
            assert result == []
152 backend/apps/scan/providers/tests/test_list_provider.py Normal file
@@ -0,0 +1,152 @@
"""
ListTargetProvider property tests

Property 1: ListTargetProvider Round-Trip
*For any* host list and URL list, creating a ListTargetProvider and iterating
iter_hosts() and iter_urls() should return the same elements as the input, in
the same order.

**Validates: Requirements 3.1, 3.2**
"""

import pytest
from hypothesis import given, strategies as st, settings, assume

from apps.scan.providers.list_provider import ListTargetProvider
from apps.scan.providers.base import ProviderContext


# Strategy for generating valid domain names
def valid_domain_strategy():
    """Generate valid domain names"""
    # Simple domain format: subdomain.domain.tld
    label = st.text(
        alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
        min_size=2,
        max_size=10
    )
    return st.builds(
        lambda a, b, c: f"{a}.{b}.{c}",
        label, label, st.sampled_from(['com', 'net', 'org', 'io'])
    )


# Strategy for generating valid IP addresses
def valid_ip_strategy():
    """Generate valid IPv4 addresses"""
    octet = st.integers(min_value=1, max_value=254)
    return st.builds(
        lambda a, b, c, d: f"{a}.{b}.{c}.{d}",
        octet, octet, octet, octet
    )


# Combined strategy: domain or IP
host_strategy = st.one_of(valid_domain_strategy(), valid_ip_strategy())

# Strategy for generating valid URLs
def valid_url_strategy():
    """Generate valid URLs"""
    domain = valid_domain_strategy()
    return st.builds(
        lambda d, path: f"https://{d}/{path}" if path else f"https://{d}",
        domain,
        st.one_of(
            st.just(""),
            st.text(
                alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
                min_size=1,
                max_size=10
            )
        )
    )

url_strategy = valid_url_strategy()


class TestListTargetProviderProperties:
    """Property tests for ListTargetProvider"""

    @given(hosts=st.lists(host_strategy, max_size=50))
    @settings(max_examples=100)
    def test_property_1_hosts_round_trip(self, hosts):
        """
        Property 1: ListTargetProvider Round-Trip (hosts)

        Feature: scan-target-provider, Property 1: ListTargetProvider Round-Trip
        **Validates: Requirements 3.1, 3.2**

        For any host list, creating a ListTargetProvider and iterating iter_hosts()
        should return the same elements in the same order.
        """
        # ListTargetProvider takes a `targets` argument and classifies entries into hosts/urls
        provider = ListTargetProvider(targets=hosts)
        result = list(provider.iter_hosts())
        assert result == hosts

    @given(urls=st.lists(url_strategy, max_size=50))
    @settings(max_examples=100)
    def test_property_1_urls_round_trip(self, urls):
        """
        Property 1: ListTargetProvider Round-Trip (urls)

        Feature: scan-target-provider, Property 1: ListTargetProvider Round-Trip
        **Validates: Requirements 3.1, 3.2**

        For any URL list, creating a ListTargetProvider and iterating iter_urls()
        should return the same elements in the same order.
        """
        # ListTargetProvider takes a `targets` argument and classifies entries into hosts/urls
        provider = ListTargetProvider(targets=urls)
        result = list(provider.iter_urls())
        assert result == urls

    @given(
        hosts=st.lists(host_strategy, max_size=30),
        urls=st.lists(url_strategy, max_size=30)
    )
    @settings(max_examples=100)
    def test_property_1_combined_round_trip(self, hosts, urls):
        """
        Property 1: ListTargetProvider Round-Trip (combined)

        Feature: scan-target-provider, Property 1: ListTargetProvider Round-Trip
        **Validates: Requirements 3.1, 3.2**

        For any combination of hosts and URLs, both should round-trip correctly.
        """
        # Merge hosts and urls; ListTargetProvider classifies them automatically
        combined = hosts + urls
        provider = ListTargetProvider(targets=combined)

        hosts_result = list(provider.iter_hosts())
        urls_result = list(provider.iter_urls())

        assert hosts_result == hosts
        assert urls_result == urls


class TestListTargetProviderUnit:
    """Unit tests for ListTargetProvider"""

    def test_empty_lists(self):
        """Empty lists yield empty iterators - Requirements 3.5"""
        provider = ListTargetProvider()
        assert list(provider.iter_hosts()) == []
        assert list(provider.iter_urls()) == []

    def test_blacklist_filter_returns_none(self):
        """The blacklist filter is None - Requirements 3.4"""
        provider = ListTargetProvider(targets=["example.com"])
        assert provider.get_blacklist_filter() is None

    def test_target_id_association(self):
        """target_id association - Requirements 3.3"""
        ctx = ProviderContext(target_id=123)
        provider = ListTargetProvider(targets=["example.com"], context=ctx)
        assert provider.target_id == 123

    def test_context_propagation(self):
        """Context propagation"""
        ctx = ProviderContext(target_id=456, scan_id=789)
        provider = ListTargetProvider(targets=["example.com"], context=ctx)

        assert provider.target_id == 456
        assert provider.scan_id == 789
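A minimal usage sketch of the round-trip contract these tests pin down, using only the constructor and iterators exercised above (the host/URL classification rule is inferred from the tests, not shown in this diff):

```python
from apps.scan.providers import ListTargetProvider

# Mixed input: the provider classifies entries into hosts vs. URLs.
provider = ListTargetProvider(
    targets=["example.com", "10.0.0.1", "https://example.com/login"]
)

assert list(provider.iter_hosts()) == ["example.com", "10.0.0.1"]
assert list(provider.iter_urls()) == ["https://example.com/login"]
assert provider.get_blacklist_filter() is None  # list mode applies no blacklist
```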
180 backend/apps/scan/providers/tests/test_pipeline_provider.py Normal file
@@ -0,0 +1,180 @@
"""
PipelineTargetProvider property tests

Property 3: PipelineTargetProvider Round-Trip
*For any* StageOutput object, iter_hosts() and iter_urls() of PipelineTargetProvider
should return the same elements as the hosts and urls lists in the StageOutput.

**Validates: Requirements 5.1, 5.2**
"""

import pytest
from hypothesis import given, strategies as st, settings

from apps.scan.providers.pipeline_provider import PipelineTargetProvider, StageOutput
from apps.scan.providers.base import ProviderContext


# Strategy for generating valid domain names
def valid_domain_strategy():
    """Generate valid domain names"""
    label = st.text(
        alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
        min_size=2,
        max_size=10
    )
    return st.builds(
        lambda a, b, c: f"{a}.{b}.{c}",
        label, label, st.sampled_from(['com', 'net', 'org', 'io'])
    )


# Strategy for generating valid IP addresses
def valid_ip_strategy():
    """Generate valid IPv4 addresses"""
    octet = st.integers(min_value=1, max_value=254)
    return st.builds(
        lambda a, b, c, d: f"{a}.{b}.{c}.{d}",
        octet, octet, octet, octet
    )


# Combined strategy: domain or IP
host_strategy = st.one_of(valid_domain_strategy(), valid_ip_strategy())

# Strategy for generating valid URLs
def valid_url_strategy():
    """Generate valid URLs"""
    domain = valid_domain_strategy()
    return st.builds(
        lambda d, path: f"https://{d}/{path}" if path else f"https://{d}",
        domain,
        st.one_of(
            st.just(""),
            st.text(
                alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
                min_size=1,
                max_size=10
            )
        )
    )

url_strategy = valid_url_strategy()


class TestPipelineTargetProviderProperties:
    """Property tests for PipelineTargetProvider"""

    @given(hosts=st.lists(host_strategy, max_size=50))
    @settings(max_examples=100)
    def test_property_3_hosts_round_trip(self, hosts):
        """
        Property 3: PipelineTargetProvider Round-Trip (hosts)

        Feature: scan-target-provider, Property 3: PipelineTargetProvider Round-Trip
        **Validates: Requirements 5.1, 5.2**

        For any StageOutput with hosts, PipelineTargetProvider should return
        the same hosts in the same order.
        """
        stage_output = StageOutput(hosts=hosts)
        provider = PipelineTargetProvider(previous_output=stage_output)
        result = list(provider.iter_hosts())
        assert result == hosts

    @given(urls=st.lists(url_strategy, max_size=50))
    @settings(max_examples=100)
    def test_property_3_urls_round_trip(self, urls):
        """
        Property 3: PipelineTargetProvider Round-Trip (urls)

        Feature: scan-target-provider, Property 3: PipelineTargetProvider Round-Trip
        **Validates: Requirements 5.1, 5.2**

        For any StageOutput with urls, PipelineTargetProvider should return
        the same urls in the same order.
        """
        stage_output = StageOutput(urls=urls)
        provider = PipelineTargetProvider(previous_output=stage_output)
        result = list(provider.iter_urls())
        assert result == urls

    @given(
        hosts=st.lists(host_strategy, max_size=30),
        urls=st.lists(url_strategy, max_size=30)
    )
    @settings(max_examples=100)
    def test_property_3_combined_round_trip(self, hosts, urls):
        """
        Property 3: PipelineTargetProvider Round-Trip (combined)

        Feature: scan-target-provider, Property 3: PipelineTargetProvider Round-Trip
        **Validates: Requirements 5.1, 5.2**

        For any StageOutput with both hosts and urls, both should round-trip correctly.
        """
        stage_output = StageOutput(hosts=hosts, urls=urls)
        provider = PipelineTargetProvider(previous_output=stage_output)

        hosts_result = list(provider.iter_hosts())
        urls_result = list(provider.iter_urls())

        assert hosts_result == hosts
        assert urls_result == urls


class TestPipelineTargetProviderUnit:
    """Unit tests for PipelineTargetProvider"""

    def test_empty_stage_output(self):
        """An empty StageOutput yields empty iterators - Requirements 5.5"""
        stage_output = StageOutput()
        provider = PipelineTargetProvider(previous_output=stage_output)

        assert list(provider.iter_hosts()) == []
        assert list(provider.iter_urls()) == []

    def test_blacklist_filter_returns_none(self):
        """The blacklist filter is None - Requirements 5.3"""
        stage_output = StageOutput(hosts=["example.com"])
        provider = PipelineTargetProvider(previous_output=stage_output)
        assert provider.get_blacklist_filter() is None

    def test_target_id_association(self):
        """target_id association - Requirements 5.4"""
        stage_output = StageOutput(hosts=["example.com"])
        provider = PipelineTargetProvider(previous_output=stage_output, target_id=123)
        assert provider.target_id == 123

    def test_context_propagation(self):
        """Context propagation"""
        ctx = ProviderContext(target_id=456, scan_id=789)
        stage_output = StageOutput(hosts=["example.com"])
        provider = PipelineTargetProvider(previous_output=stage_output, context=ctx)

        assert provider.target_id == 456
        assert provider.scan_id == 789

    def test_previous_output_property(self):
        """The previous_output property"""
        stage_output = StageOutput(hosts=["example.com"], urls=["https://example.com"])
        provider = PipelineTargetProvider(previous_output=stage_output)

        assert provider.previous_output is stage_output
        assert provider.previous_output.hosts == ["example.com"]
        assert provider.previous_output.urls == ["https://example.com"]

    def test_stage_output_with_metadata(self):
        """A StageOutput carrying metadata"""
        stage_output = StageOutput(
            hosts=["example.com"],
            urls=["https://example.com"],
            new_targets=["new.example.com"],
            stats={"count": 1},
            success=True,
            error=None
        )
        provider = PipelineTargetProvider(previous_output=stage_output)

        assert list(provider.iter_hosts()) == ["example.com"]
        assert list(provider.iter_urls()) == ["https://example.com"]
        assert provider.previous_output.new_targets == ["new.example.com"]
        assert provider.previous_output.stats == {"count": 1}
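A sketch of how one stage's output feeds the next through this provider (names taken from the imports above; the surrounding pipeline wiring is illustrative):

```python
from apps.scan.providers.pipeline_provider import PipelineTargetProvider, StageOutput

# Stage 1 finished and produced hosts plus some stats.
stage1 = StageOutput(
    hosts=["a.example.com", "b.example.com"],
    stats={"count": 2},
    success=True,
)

# Stage 2 consumes that output through the provider interface, so it does
# not need to know where the targets came from (DB, list, or prior stage).
provider = PipelineTargetProvider(previous_output=stage1, target_id=123)
for host in provider.iter_hosts():
    print(host)
```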
191 backend/apps/scan/providers/tests/test_snapshot_provider.py Normal file
@@ -0,0 +1,191 @@
"""
SnapshotTargetProvider unit tests
"""

import pytest
from unittest.mock import Mock, patch

from apps.scan.providers import SnapshotTargetProvider, ProviderContext


class TestSnapshotTargetProvider:
    """Tests for SnapshotTargetProvider"""

    def test_init_with_scan_id_and_type(self):
        """Initialization"""
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="subdomain"
        )

        assert provider.scan_id == 100
        assert provider.snapshot_type == "subdomain"
        assert provider.target_id is None  # default context

    def test_init_with_context(self):
        """Initialization with a context"""
        ctx = ProviderContext(target_id=1, scan_id=100)
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="subdomain",
            context=ctx
        )

        assert provider.scan_id == 100
        assert provider.target_id == 1
        assert provider.snapshot_type == "subdomain"

    @patch('apps.asset.services.snapshot.SubdomainSnapshotsService')
    def test_iter_hosts_subdomain(self, mock_service_class):
        """Iterate hosts from the subdomain snapshot"""
        # Mock service
        mock_service = Mock()
        mock_service.iter_subdomain_names_by_scan.return_value = iter([
            "a.example.com",
            "b.example.com"
        ])
        mock_service_class.return_value = mock_service

        # Create the provider
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="subdomain"
        )

        # Iterate hosts
        hosts = list(provider.iter_hosts())

        assert hosts == ["a.example.com", "b.example.com"]
        mock_service.iter_subdomain_names_by_scan.assert_called_once_with(
            scan_id=100,
            chunk_size=1000
        )

    @patch('apps.asset.services.snapshot.HostPortMappingSnapshotsService')
    def test_iter_hosts_host_port(self, mock_service_class):
        """Iterate hosts from the host-port mapping snapshot"""
        # Mock queryset
        mock_mapping1 = Mock()
        mock_mapping1.host = "example.com"
        mock_mapping1.port = 80

        mock_mapping2 = Mock()
        mock_mapping2.host = "example.com"
        mock_mapping2.port = 443

        mock_queryset = Mock()
        mock_queryset.iterator.return_value = iter([mock_mapping1, mock_mapping2])

        # Mock service
        mock_service = Mock()
        mock_service.get_by_scan.return_value = mock_queryset
        mock_service_class.return_value = mock_service

        # Create the provider
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="host_port"
        )

        # Iterate hosts
        hosts = list(provider.iter_hosts())

        assert hosts == ["example.com:80", "example.com:443"]
        mock_service.get_by_scan.assert_called_once_with(scan_id=100)

    @patch('apps.asset.services.snapshot.WebsiteSnapshotsService')
    def test_iter_urls_website(self, mock_service_class):
        """Iterate URLs from the website snapshot"""
        # Mock service
        mock_service = Mock()
        mock_service.iter_website_urls_by_scan.return_value = iter([
            "http://example.com",
            "https://example.com"
        ])
        mock_service_class.return_value = mock_service

        # Create the provider
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="website"
        )

        # Iterate URLs
        urls = list(provider.iter_urls())

        assert urls == ["http://example.com", "https://example.com"]
        mock_service.iter_website_urls_by_scan.assert_called_once_with(
            scan_id=100,
            chunk_size=1000
        )

    @patch('apps.asset.services.snapshot.EndpointSnapshotsService')
    def test_iter_urls_endpoint(self, mock_service_class):
        """Iterate URLs from the endpoint snapshot"""
        # Mock queryset
        mock_endpoint1 = Mock()
        mock_endpoint1.url = "http://example.com/api/v1"

        mock_endpoint2 = Mock()
        mock_endpoint2.url = "http://example.com/api/v2"

        mock_queryset = Mock()
        mock_queryset.iterator.return_value = iter([mock_endpoint1, mock_endpoint2])

        # Mock service
        mock_service = Mock()
        mock_service.get_by_scan.return_value = mock_queryset
        mock_service_class.return_value = mock_service

        # Create the provider
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="endpoint"
        )

        # Iterate URLs
        urls = list(provider.iter_urls())

        assert urls == ["http://example.com/api/v1", "http://example.com/api/v2"]
        mock_service.get_by_scan.assert_called_once_with(scan_id=100)

    def test_iter_hosts_unsupported_type(self):
        """Unsupported snapshot type (iter_hosts)"""
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="website"  # website does not support iter_hosts
        )

        hosts = list(provider.iter_hosts())
        assert hosts == []

    def test_iter_urls_unsupported_type(self):
        """Unsupported snapshot type (iter_urls)"""
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="subdomain"  # subdomain does not support iter_urls
        )

        urls = list(provider.iter_urls())
        assert urls == []

    def test_get_blacklist_filter(self):
        """Blacklist filter (snapshot mode applies no blacklist)"""
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="subdomain"
        )

        assert provider.get_blacklist_filter() is None

    def test_context_propagation(self):
        """Context propagation"""
        ctx = ProviderContext(target_id=456, scan_id=789)
        provider = SnapshotTargetProvider(
            scan_id=100,  # the explicit argument wins over the context's scan_id
            snapshot_type="subdomain",
            context=ctx
        )

        assert provider.target_id == 456
        assert provider.scan_id == 100  # scan_id is set in __init__
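The snapshot provider routes each `snapshot_type` to a different snapshot service, as the patched service classes above show. A usage sketch, limited to the types these tests exercise:

```python
from apps.scan.providers import SnapshotTargetProvider

# Re-scan exactly what scan 100 discovered, without touching live asset tables.
subdomains = SnapshotTargetProvider(scan_id=100, snapshot_type="subdomain")
websites = SnapshotTargetProvider(scan_id=100, snapshot_type="website")

hosts = list(subdomains.iter_hosts())  # e.g. ["a.example.com", "b.example.com"]
urls = list(websites.iter_urls())      # e.g. ["http://example.com", ...]

# Unsupported combinations yield empty iterators rather than raising:
assert list(subdomains.iter_urls()) == []
```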
@@ -12,7 +12,7 @@
 import ipaddress
 import logging
 from pathlib import Path
-from typing import Dict, Any, Optional, List, Iterator, Tuple, Callable
+from typing import Dict, Any, Optional, List, Iterator, Tuple
 
 from django.db.models import QuerySet
 
@@ -485,8 +485,7 @@ class TargetExportService:
         """
         from apps.targets.services import TargetService
-        from apps.targets.models import Target
         from apps.asset.services.asset.subdomain_service import SubdomainService
 
         output_file = Path(output_path)
         output_file.parent.mkdir(parents=True, exist_ok=True)
 
@@ -1,36 +1,48 @@
 """
 Task that exports site URLs to a TXT file
 
-Uses the export_urls_with_fallback use-case function to handle the fallback chain
+Supports two modes:
+1. Legacy mode (backward compatible): export from the database via target_id
+2. Provider mode: export from any data source via a TargetProvider
 
 Data source: WebSite.url → Default
 """
 import logging
+from typing import Optional
+from pathlib import Path
 from prefect import task
 
 from apps.scan.services.target_export_service import (
     export_urls_with_fallback,
     DataSource,
 )
+from apps.scan.providers import TargetProvider
 
 logger = logging.getLogger(__name__)
 
 
 @task(name="export_sites")
 def export_sites_task(
-    target_id: int,
-    output_file: str,
+    target_id: Optional[int] = None,
+    output_file: str = "",
+    provider: Optional[TargetProvider] = None,
     batch_size: int = 1000,
 ) -> dict:
     """
     Export all site URLs of a target to a TXT file
 
-    Data source priority (fallback chain):
+    Supports two modes:
+    1. Legacy mode (backward compatible): pass target_id, export from the database
+    2. Provider mode: pass provider, export from any data source
+
+    Data source priority (fallback chain, legacy mode only):
     1. WebSite table - site-level URLs
    2. Default generation - build http(s)://target_name from the Target type
 
     Args:
-        target_id: target ID
+        target_id: target ID (legacy mode, backward compatible)
         output_file: output file path (absolute)
+        provider: TargetProvider instance (new mode)
         batch_size: batch size per read, default 1000
 
     Returns:
@@ -44,6 +56,17 @@ def export_sites_task(
         ValueError: invalid arguments
         IOError: file write failed
     """
+    # Parameter validation: at least one must be provided
+    if target_id is None and provider is None:
+        raise ValueError("Must provide either target_id or provider")
+
+    # Provider mode: export via the TargetProvider
+    if provider is not None:
+        logger.info("Using provider mode - Provider: %s", type(provider).__name__)
+        return _export_with_provider(output_file, provider)
+
+    # Legacy mode: use export_urls_with_fallback
+    logger.info("Using legacy mode - Target ID: %d", target_id)
     result = export_urls_with_fallback(
         target_id=target_id,
         output_file=output_file,
@@ -62,3 +85,32 @@ def export_sites_task(
         'output_file': result['output_file'],
         'total_count': result['total_count'],
     }
+
+
+def _export_with_provider(output_file: str, provider: TargetProvider) -> dict:
+    """Export URLs via the provider"""
+    output_path = Path(output_file)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    total_count = 0
+    blacklist_filter = provider.get_blacklist_filter()
+
+    with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
+        for url in provider.iter_urls():
+            # Apply blacklist filtering, if any
+            if blacklist_filter and not blacklist_filter.is_allowed(url):
+                continue
+
+            f.write(f"{url}\n")
+            total_count += 1
+
+            if total_count % 1000 == 0:
+                logger.info("Exported %d URLs...", total_count)
+
+    logger.info("✓ URL export finished - total: %d, file: %s", total_count, str(output_path))
+
+    return {
+        'success': True,
+        'output_file': str(output_path),
+        'total_count': total_count,
+    }
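These export tasks all share the same dispatch shape, so one sketch covers the family. The output path below is illustrative, and the import of `export_sites_task` is omitted because this diff view does not show the module path:

```python
from apps.scan.providers import ListTargetProvider
# (import of export_sites_task omitted: the diff does not show its module path)

# Legacy mode: target_id drives the database-backed fallback chain.
result = export_sites_task(target_id=42, output_file="/tmp/sites.txt")

# Provider mode: any TargetProvider supplies the URLs; target_id is not needed.
provider = ListTargetProvider(
    targets=["https://example.com", "https://test.example.com"]
)
result = export_sites_task(provider=provider, output_file="/tmp/sites.txt")
assert result["total_count"] == 2
```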
@@ -1,11 +1,16 @@
 """
 URL export task
 
+Supports two modes:
+1. Legacy mode (backward compatible): export from the database via target_id
+2. Provider mode: export from any data source via a TargetProvider
+
 Used to export a target's URLs to a file before fingerprinting
 Uses the export_urls_with_fallback use-case function to handle the fallback chain
 """
 
 import logging
+from typing import Optional
+from pathlib import Path
 
 from prefect import task
 
@@ -13,33 +18,51 @@ from apps.scan.services.target_export_service import (
     export_urls_with_fallback,
     DataSource,
 )
+from apps.scan.providers import TargetProvider, DatabaseTargetProvider
 
 logger = logging.getLogger(__name__)
 
 
 @task(name="export_urls_for_fingerprint")
 def export_urls_for_fingerprint_task(
-    target_id: int,
-    output_file: str,
+    target_id: Optional[int] = None,
+    output_file: str = "",
     source: str = 'website',  # retained for old callers (actual value decided by the fallback chain)
+    provider: Optional[TargetProvider] = None,
     batch_size: int = 1000
 ) -> dict:
     """
     Export a target's URLs to a file (for fingerprinting)
 
-    Data source priority (fallback chain):
+    Supports two modes:
+    1. Legacy mode (backward compatible): pass target_id, export from the database
+    2. Provider mode: pass provider, export from any data source
+
+    Data source priority (fallback chain, legacy mode only):
     1. WebSite table - site-level URLs
     2. Default generation - build http(s)://target_name from the Target type
 
     Args:
-        target_id: target ID
+        target_id: target ID (legacy mode, backward compatible)
         output_file: output file path
         source: data source type (retained for old callers; actual value decided by the fallback chain)
+        provider: TargetProvider instance (new mode)
         batch_size: batch read size
 
     Returns:
         dict: {'output_file': str, 'total_count': int, 'source': str}
     """
+    # Parameter validation: at least one must be provided
+    if target_id is None and provider is None:
+        raise ValueError("Must provide either target_id or provider")
+
+    # Provider mode: export via the TargetProvider
+    if provider is not None:
+        logger.info("Using provider mode - Provider: %s", type(provider).__name__)
+        return _export_with_provider(output_file, provider)
+
+    # Legacy mode: use export_urls_with_fallback
+    logger.info("Using legacy mode - Target ID: %d", target_id)
     result = export_urls_with_fallback(
         target_id=target_id,
         output_file=output_file,
@@ -58,3 +81,32 @@ def export_urls_for_fingerprint_task(
         'total_count': result['total_count'],
         'source': result['source'],
     }
+
+
+def _export_with_provider(output_file: str, provider: TargetProvider) -> dict:
+    """Export URLs via the provider"""
+    output_path = Path(output_file)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    total_count = 0
+    blacklist_filter = provider.get_blacklist_filter()
+
+    with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
+        for url in provider.iter_urls():
+            # Apply blacklist filtering, if any
+            if blacklist_filter and not blacklist_filter.is_allowed(url):
+                continue
+
+            f.write(f"{url}\n")
+            total_count += 1
+
+            if total_count % 1000 == 0:
+                logger.info("Exported %d URLs...", total_count)
+
+    logger.info("✓ URL export finished - total: %d, file: %s", total_count, str(output_path))
+
+    return {
+        'output_file': str(output_path),
+        'total_count': total_count,
+        'source': 'provider',
+    }
||||
@@ -1,7 +1,9 @@
|
||||
"""
|
||||
导出主机列表到 TXT 文件的 Task
|
||||
|
||||
使用 TargetExportService.export_hosts() 统一处理导出逻辑
|
||||
支持两种模式:
|
||||
1. 传统模式(向后兼容):使用 target_id 从数据库导出
|
||||
2. Provider 模式:使用 TargetProvider 从任意数据源导出
|
||||
|
||||
根据 Target 类型决定导出内容:
|
||||
- DOMAIN: 从 Subdomain 表导出子域名
|
||||
@@ -9,57 +11,89 @@
|
||||
- CIDR: 展开 CIDR 范围内的所有 IP
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from prefect import task
|
||||
|
||||
from apps.scan.services.target_export_service import create_export_service
|
||||
from apps.scan.providers import DatabaseTargetProvider, TargetProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@task(name="export_hosts")
|
||||
def export_hosts_task(
|
||||
target_id: int,
|
||||
output_file: str,
|
||||
batch_size: int = 1000
|
||||
target_id: Optional[int] = None,
|
||||
provider: Optional[TargetProvider] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
导出主机列表到 TXT 文件
|
||||
|
||||
|
||||
支持两种模式:
|
||||
1. 传统模式(向后兼容):传入 target_id,从数据库导出
|
||||
2. Provider 模式:传入 provider,从任意数据源导出
|
||||
|
||||
根据 Target 类型自动决定导出内容:
|
||||
- DOMAIN: 从 Subdomain 表导出子域名(流式处理,支持 10万+ 域名)
|
||||
- IP: 直接写入 target.name(单个 IP)
|
||||
- CIDR: 展开 CIDR 范围内的所有可用 IP
|
||||
|
||||
Args:
|
||||
target_id: 目标 ID
|
||||
output_file: 输出文件路径(绝对路径)
|
||||
batch_size: 每次读取的批次大小,默认 1000(仅对 DOMAIN 类型有效)
|
||||
target_id: 目标 ID(传统模式,向后兼容)
|
||||
provider: TargetProvider 实例(新模式)
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
'success': bool,
|
||||
'output_file': str,
|
||||
'total_count': int,
|
||||
'target_type': str
|
||||
'target_type': str # 仅传统模式返回
|
||||
}
|
||||
|
||||
Raises:
|
||||
ValueError: Target 不存在
|
||||
ValueError: 参数错误(target_id 和 provider 都未提供)
|
||||
IOError: 文件写入失败
|
||||
"""
|
||||
# 使用工厂函数创建导出服务
|
||||
export_service = create_export_service(target_id)
|
||||
|
||||
result = export_service.export_hosts(
|
||||
target_id=target_id,
|
||||
output_path=output_file,
|
||||
batch_size=batch_size
|
||||
)
|
||||
|
||||
# 保持返回值格式不变(向后兼容)
|
||||
return {
|
||||
'success': result['success'],
|
||||
'output_file': result['output_file'],
|
||||
'total_count': result['total_count'],
|
||||
'target_type': result['target_type']
|
||||
if target_id is None and provider is None:
|
||||
raise ValueError("必须提供 target_id 或 provider 参数之一")
|
||||
|
||||
# 向后兼容:如果没有提供 provider,使用 target_id 创建 DatabaseTargetProvider
|
||||
use_legacy_mode = provider is None
|
||||
if use_legacy_mode:
|
||||
logger.info("使用传统模式 - Target ID: %d", target_id)
|
||||
provider = DatabaseTargetProvider(target_id=target_id)
|
||||
else:
|
||||
logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
|
||||
|
||||
# 确保输出目录存在
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 使用 Provider 导出主机列表(iter_hosts 内部已处理黑名单过滤)
|
||||
total_count = 0
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
for host in provider.iter_hosts():
|
||||
f.write(f"{host}\n")
|
||||
total_count += 1
|
||||
|
||||
if total_count % 1000 == 0:
|
||||
logger.info("已导出 %d 个主机...", total_count)
|
||||
|
||||
logger.info("✓ 主机列表导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))
|
||||
|
||||
result = {
|
||||
'success': True,
|
||||
'output_file': str(output_path),
|
||||
'total_count': total_count,
|
||||
}
|
||||
|
||||
# 传统模式:保持返回值格式不变(向后兼容)
|
||||
if use_legacy_mode:
|
||||
from apps.targets.services import TargetService
|
||||
target = TargetService().get_target(target_id)
|
||||
result['target_type'] = target.type if target else 'unknown'
|
||||
|
||||
return result
|
||||
|
||||
@@ -1,8 +1,9 @@
 """
 Task that exports site URLs to a file
 
-Queries host+port combinations directly from the HostPortMapping table and writes them as URLs
-Uses TargetExportService.generate_default_urls() for the default-value fallback logic
+Supports two modes:
+1. Legacy mode (backward compatible): export from the database via target_id
+2. Provider mode: export from any data source via a TargetProvider
 
 Special rules:
 - port 80: generate an HTTP URL only (port omitted)
@@ -10,6 +11,7 @@
 - other ports: generate both HTTP and HTTPS URLs (with the port)
 """
 import logging
+from typing import Optional
 from pathlib import Path
 from prefect import task
 
@@ -17,6 +19,7 @@ from apps.asset.services import HostPortMappingService
 from apps.scan.services.target_export_service import create_export_service
 from apps.common.services import BlacklistService
 from apps.common.utils import BlacklistFilter
+from apps.scan.providers import TargetProvider, DatabaseTargetProvider, ProviderContext
 
 logger = logging.getLogger(__name__)
 
@@ -39,26 +42,30 @@ def _generate_urls_from_port(host: str, port: int) -> list[str]:
 
 @task(name="export_site_urls")
 def export_site_urls_task(
-    target_id: int,
     output_file: str,
+    target_id: Optional[int] = None,
+    provider: Optional[TargetProvider] = None,
     batch_size: int = 1000
 ) -> dict:
     """
-    Export all site URLs of a target to a file (based on the HostPortMapping table)
+    Export all site URLs of a target to a file
 
-    Data source: HostPortMapping (host + port) → Default
+    Supports two modes:
+    1. Legacy mode (backward compatible): pass target_id, export from the HostPortMapping table
+    2. Provider mode: pass provider, export from any data source
 
-    Special rules:
+    Legacy-mode special rules:
     - port 80: generate an HTTP URL only (port omitted)
     - port 443: generate an HTTPS URL only (port omitted)
     - other ports: generate both HTTP and HTTPS URLs (with the port)
 
-    Fallback logic:
+    Fallback logic (legacy mode only):
     - if HostPortMapping is empty, generate default URLs via generate_default_urls()
 
     Args:
-        target_id: target ID
         output_file: output file path (absolute)
+        target_id: target ID (legacy mode, backward compatible)
+        provider: TargetProvider instance (new mode)
         batch_size: batch size per pass
 
     Returns:
@@ -66,14 +73,62 @@ def export_site_urls_task(
         'success': bool,
         'output_file': str,
         'total_urls': int,
-        'association_count': int,  # number of host-port associations
-        'source': str,  # data source: "host_port" | "default"
+        'association_count': int,  # number of host-port associations (legacy mode only)
+        'source': str,  # data source: "host_port" | "default" | "provider"
     }
 
     Raises:
         ValueError: invalid arguments
         IOError: file write failed
     """
+    # Parameter validation: at least one must be provided
+    if target_id is None and provider is None:
+        raise ValueError("Must provide either target_id or provider")
+
+    # Backward compatibility: without a provider, use legacy mode
+    if provider is None:
+        logger.info("Using legacy mode - Target ID: %d, output file: %s", target_id, output_file)
+        return _export_site_urls_legacy(target_id, output_file, batch_size)
+
+    # Provider mode
+    logger.info("Using provider mode - Provider: %s, output file: %s", type(provider).__name__, output_file)
+
+    # Ensure the output directory exists
+    output_path = Path(output_file)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    # Export the URL list via the provider
+    total_urls = 0
+    blacklist_filter = provider.get_blacklist_filter()
+
+    with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
+        for url in provider.iter_urls():
+            # Apply blacklist filtering, if any
+            if blacklist_filter and not blacklist_filter.is_allowed(url):
+                continue
+
+            f.write(f"{url}\n")
+            total_urls += 1
+
+            if total_urls % 1000 == 0:
+                logger.info("Exported %d URLs...", total_urls)
+
+    logger.info("✓ URL export finished - total: %d, file: %s", total_urls, str(output_path))
+
+    return {
+        'success': True,
+        'output_file': str(output_path),
+        'total_urls': total_urls,
+        'source': 'provider',
+    }
+
+
+def _export_site_urls_legacy(target_id: int, output_file: str, batch_size: int) -> dict:
+    """
+    Legacy mode: export URLs from the HostPortMapping table
+
+    Keeps the original logic unchanged for backward compatibility
+    """
+    logger.info("Start collecting site URLs - Target ID: %d, output file: %s", target_id, output_file)
+
+    # Ensure the output directory exists
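The 80/443 special-casing documented above maps each host-port pair to one or two URLs. A sketch of that rule, mirroring `_generate_urls_from_port` as documented rather than its exact body:

```python
def generate_urls_from_port(host: str, port: int) -> list[str]:
    """Sketch: 80 -> plain HTTP, 443 -> plain HTTPS, anything else -> both with an explicit port."""
    if port == 80:
        return [f"http://{host}"]
    if port == 443:
        return [f"https://{host}"]
    return [f"http://{host}:{port}", f"https://{host}:{port}"]


assert generate_urls_from_port("example.com", 80) == ["http://example.com"]
assert generate_urls_from_port("example.com", 8443) == [
    "http://example.com:8443",
    "https://example.com:8443",
]
```

This is also why the backward-compatibility test below expects exactly two URLs from two associations on ports 80 and 443.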
@@ -20,63 +20,40 @@ Note:
 """
 
 import logging
+import uuid
 import subprocess
-from pathlib import Path
-import uuid
 from datetime import datetime
-from prefect import task
+from pathlib import Path
 from typing import List
+
+from prefect import task
 
 logger = logging.getLogger(__name__)
 
+# Note: implemented with plain system commands; no Python buffer tuning needed
+# Tool output (amass/subfinder) is already lowercase and normalized
 
-@task(
-    name='merge_and_deduplicate',
-    retries=1,
-    log_prints=True
-)
-def merge_and_validate_task(
-    result_files: List[str],
-    result_dir: str
-) -> str:
-    """
-    Merge scan results and deduplicate (high-performance streaming)
-
-    Steps:
-    1. Run LC_ALL=C sort -u directly on multiple files
-    2. Sorting and deduplication in one step
-    3. Return the deduplicated file path
-
-    Command: LC_ALL=C sort -u file1 file2 file3 -o output
-    Note: tool output is already normalized (lowercase, no blank lines), so no extra processing is needed
-
-    Args:
-        result_files: list of result file paths
-        result_dir: result directory
-
-    Returns:
-        str: path of the deduplicated domain file
-
-    Raises:
-        RuntimeError: processing failed
-
-    Performance:
-    - plain system commands (implemented in C), one minimal process
-    - LC_ALL=C: byte-order comparison
-    - sort -u: handles multiple files directly (no pipe overhead)
-
-    Design:
-    - one minimal command, no redundant processing
-    - single process, direct execution (no pipe/redirection overhead)
-    - memory is used only during sort (external sort, cannot OOM)
-    """
-    logger.info("Merging and deduplicating %d result files (system-command optimized)", len(result_files))
-
-    result_path = Path(result_dir)
-
-    # Validate file existence
+def _count_file_lines(file_path: str) -> int:
+    """Count file lines with `wc -l`; return 0 on failure"""
+    try:
+        result = subprocess.run(
+            ["wc", "-l", file_path],
+            check=True,
+            capture_output=True,
+            text=True,
+        )
+        return int(result.stdout.strip().split()[0])
+    except (subprocess.CalledProcessError, ValueError, IndexError):
+        return 0
+
+
+def _calculate_timeout(total_lines: int) -> int:
+    """Compute the timeout from the total line count (~0.1 s per line, at least 600 s)"""
+    if total_lines <= 0:
+        return 3600
+    return max(600, int(total_lines * 0.1))
+
+
+def _validate_input_files(result_files: List[str]) -> List[str]:
+    """Validate that input files exist; return the list of valid files"""
+    valid_files = []
+    for file_path_str in result_files:
+        file_path = Path(file_path_str)
@@ -84,112 +61,67 @@ def merge_and_validate_task(
             valid_files.append(str(file_path))
         else:
             logger.warning("Result file does not exist: %s", file_path)
 
+    return valid_files
+
+
+@task(name='merge_and_deduplicate', retries=1, log_prints=True)
+def merge_and_validate_task(result_files: List[str], result_dir: str) -> str:
+    """
+    Merge scan results and deduplicate (high-performance streaming)
+
+    Runs LC_ALL=C sort -u directly on multiple files; sorting and dedup in one step.
+
+    Args:
+        result_files: list of result file paths
+        result_dir: result directory
+
+    Returns:
+        Path of the deduplicated domain file
+
+    Raises:
+        RuntimeError: processing failed
+    """
+    logger.info("Merging and deduplicating %d result files", len(result_files))
+
+    valid_files = _validate_input_files(result_files)
     if not valid_files:
         raise RuntimeError("No result files exist")
 
     # Build the output file path
     timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
     short_uuid = uuid.uuid4().hex[:4]
-    merged_file = result_path / f"merged_{timestamp}_{short_uuid}.txt"
+    merged_file = Path(result_dir) / f"merged_{timestamp}_{short_uuid}.txt"
 
+    # Compute the timeout
+    total_lines = sum(_count_file_lines(f) for f in valid_files)
+    timeout = _calculate_timeout(total_lines)
+    logger.info("Merge/dedup: total input lines=%d, timeout=%d s", total_lines, timeout)
+
+    # Run the merge/dedup command
+    cmd = f"LC_ALL=C sort -u {' '.join(valid_files)} -o {merged_file}"
+    logger.debug("Running command: %s", cmd)
+
     try:
-        # ==================== One step with a system command: sort + dedup ====================
-        # LC_ALL=C: byte-order comparison (20-30% faster than locale-aware)
-        # sort -u: handles multiple files directly, sorting and deduplicating
-        # -o: safe output (more reliable than redirection)
-        cmd = f"LC_ALL=C sort -u {' '.join(valid_files)} -o {merged_file}"
-
-        logger.debug("Running command: %s", cmd)
+        subprocess.run(cmd, shell=True, check=True, timeout=timeout)
+    except subprocess.TimeoutExpired as exc:
+        raise RuntimeError("Merge/dedup timed out; check data volume or system resources") from exc
+    except subprocess.CalledProcessError as exc:
+        raise RuntimeError(f"System command failed: {exc.stderr or exc}") from exc
 
-        # Compute the timeout dynamically from the total input line count
-        total_lines = 0
-        for file_path in valid_files:
-            try:
-                line_count_proc = subprocess.run(
-                    ["wc", "-l", file_path],
-                    check=True,
-                    capture_output=True,
-                    text=True,
-                )
-                total_lines += int(line_count_proc.stdout.strip().split()[0])
-            except (subprocess.CalledProcessError, ValueError, IndexError):
-                continue
+    # Verify the output file
+    if not merged_file.exists():
+        raise RuntimeError("Merged file was not created")
 
-        timeout = 3600
-        if total_lines > 0:
-            # Linear in the line count: ~0.1 s per line
-            base_per_line = 0.1
-            est = int(total_lines * base_per_line)
-            timeout = max(600, est)
+    unique_count = _count_file_lines(str(merged_file))
+    if unique_count == 0:
+        # Fall back to counting in Python
+        with open(merged_file, 'r', encoding='utf-8') as f:
+            unique_count = sum(1 for _ in f)
 
-        logger.info(
-            "Subdomain merge/dedup timeout auto-computed: total input lines=%d, timeout=%d s",
-            total_lines,
-            timeout,
-        )
+    if unique_count == 0:
+        raise RuntimeError("No valid domains found")
 
-        result = subprocess.run(
-            cmd,
-            shell=True,
-            check=True,
-            timeout=timeout
-        )
-
-        logger.debug("✓ Merge/dedup finished")
-
-        # ==================== Collect statistics ====================
-        if not merged_file.exists():
-            raise RuntimeError("Merged file was not created")
-
-        # Count lines (system command for large-file performance)
-        try:
-            line_count_proc = subprocess.run(
-                ["wc", "-l", str(merged_file)],
-                check=True,
-                capture_output=True,
-                text=True
-            )
-            unique_count = int(line_count_proc.stdout.strip().split()[0])
-        except (subprocess.CalledProcessError, ValueError, IndexError) as e:
-            logger.warning(
-                "wc -l failed (file: %s); falling back to line-by-line counting in Python - error: %s",
-                merged_file, e
-            )
-            unique_count = 0
-            with open(merged_file, 'r', encoding='utf-8') as file_obj:
-                for _ in file_obj:
-                    unique_count += 1
-
-        if unique_count == 0:
-            raise RuntimeError("No valid domains found")
-
-        file_size = merged_file.stat().st_size
-
-        logger.info(
-            "✓ Merge/dedup finished - %d domains after dedup, file size: %.2f KB",
-            unique_count,
-            file_size / 1024
-        )
-
-        return str(merged_file)
-
-    except subprocess.TimeoutExpired:
-        error_msg = "Merge/dedup timed out (>60 minutes); check data volume or system resources"
-        logger.warning(error_msg)  # timeouts are expected here
-        raise RuntimeError(error_msg)
-
-    except subprocess.CalledProcessError as e:
-        error_msg = f"System command failed: {e.stderr if e.stderr else str(e)}"
-        logger.warning(error_msg)
-        raise RuntimeError(error_msg) from e
-
-    except IOError as e:
-        error_msg = f"File I/O failed: {e}"
-        logger.warning(error_msg)
-        raise RuntimeError(error_msg) from e
-
-    except Exception as e:
-        error_msg = f"Merge/dedup failed: {e}"
-        logger.error(error_msg, exc_info=True)
-        raise
+    file_size_kb = merged_file.stat().st_size / 1024
+    logger.info("✓ Merge/dedup finished - %d domains after dedup, file size: %.2f KB", unique_count, file_size_kb)
+
+    return str(merged_file)
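The core of the new implementation is a single coreutils invocation. A sketch of the equivalent call, with illustrative paths:

```python
import subprocess

files = ["/tmp/amass.txt", "/tmp/subfinder.txt"]  # illustrative inputs
merged = "/tmp/merged.txt"

# LC_ALL=C forces byte-wise comparison, which is faster than locale-aware
# collation; sort -u merges, sorts, and deduplicates all inputs in one pass,
# spilling to temporary files on large inputs instead of exhausting memory.
subprocess.run(
    f"LC_ALL=C sort -u {' '.join(files)} -o {merged}",
    shell=True, check=True, timeout=600,
)
```

The timeout passed here would come from `_calculate_timeout()`, so very large inputs get proportionally more time while small ones still keep the 600-second floor.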
@@ -1,7 +1,7 @@
 """
 Scan tool runner task
 
-Responsible for running a single subdomain scanning tool (amass, subfinder, etc.)
+Responsible for running a single subdomain scanning tool (subfinder, sublist3r, etc.)
 """
 
 import logging
@@ -58,7 +58,7 @@ def run_subdomain_discovery_task(
         timeout=timeout,
         log_file=log_file  # explicitly specify the log file path
     )
 
     # Verify that the output file was generated
     if not output_file_path.exists():
         logger.warning(
@@ -0,0 +1,240 @@
"""
Task backward-compatibility tests

Property 8: Task Backward Compatibility
*For any* task invocation with only the target_id parameter, the task should create
a DatabaseTargetProvider and use it for data access, matching pre-refactor behavior.

**Validates: Requirements 6.1, 6.2, 6.4, 6.5**
"""

import tempfile
import pytest
from pathlib import Path
from unittest.mock import patch, MagicMock
from hypothesis import given, strategies as st, settings

from apps.scan.tasks.port_scan.export_hosts_task import export_hosts_task
from apps.scan.tasks.site_scan.export_site_urls_task import export_site_urls_task
from apps.scan.providers import ListTargetProvider


# Strategy for generating valid domain names
def valid_domain_strategy():
    """Generate valid domain names"""
    label = st.text(
        alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
        min_size=2,
        max_size=10
    )
    return st.builds(
        lambda a, b, c: f"{a}.{b}.{c}",
        label, label, st.sampled_from(['com', 'net', 'org', 'io'])
    )


class TestExportHostsTaskBackwardCompatibility:
    """Backward-compatibility tests for export_hosts_task"""

    @given(
        target_id=st.integers(min_value=1, max_value=1000),
        hosts=st.lists(valid_domain_strategy(), min_size=1, max_size=10)
    )
    @settings(max_examples=50, deadline=None)
    def test_property_8_legacy_mode_creates_database_provider(self, target_id, hosts):
        """
        Property 8: Task Backward Compatibility (export_hosts_task)

        Feature: scan-target-provider, Property 8: Task Backward Compatibility
        **Validates: Requirements 6.1, 6.2, 6.4, 6.5**

        For any target_id, when calling export_hosts_task with only target_id,
        it should create a DatabaseTargetProvider and use it for data access.
        """
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
            output_file = f.name

        try:
            # Mock Target and SubdomainService
            mock_target = MagicMock()
            mock_target.type = 'domain'
            mock_target.name = hosts[0]

            with patch('apps.scan.tasks.port_scan.export_hosts_task.DatabaseTargetProvider') as mock_provider_class, \
                    patch('apps.targets.services.TargetService') as mock_target_service:

                # Build the mock provider instance
                mock_provider = MagicMock()
                mock_provider.iter_hosts.return_value = iter(hosts)
                mock_provider.get_blacklist_filter.return_value = None
                mock_provider_class.return_value = mock_provider

                # Mock TargetService
                mock_target_service.return_value.get_target.return_value = mock_target

                # Call the task (legacy mode: target_id only)
                result = export_hosts_task(
                    output_file=output_file,
                    target_id=target_id
                )

                # Verify: a DatabaseTargetProvider was created
                mock_provider_class.assert_called_once_with(target_id=target_id)

                # Verify: the return value contains the required fields
                assert result['success'] is True
                assert result['output_file'] == output_file
                assert result['total_count'] == len(hosts)
                assert 'target_type' in result  # legacy mode should return target_type

                # Verify: the file content is correct
                with open(output_file, 'r') as f:
                    lines = [line.strip() for line in f.readlines()]
                    assert lines == hosts

        finally:
            Path(output_file).unlink(missing_ok=True)

    def test_legacy_mode_with_provider_parameter(self):
        """When both target_id and provider are given, provider wins"""
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
            output_file = f.name

        try:
            hosts = ['example.com', 'test.com']
            provider = ListTargetProvider(targets=hosts)

            # Call the task with both target_id and provider
            result = export_hosts_task(
                output_file=output_file,
                target_id=123,  # should be ignored
                provider=provider
            )

            # Verify: the provider was used
            assert result['success'] is True
            assert result['total_count'] == len(hosts)
            assert 'target_type' not in result  # provider mode does not return target_type

            # Verify: the file content is correct
            with open(output_file, 'r') as f:
                lines = [line.strip() for line in f.readlines()]
                assert lines == hosts

        finally:
            Path(output_file).unlink(missing_ok=True)

    def test_error_when_no_parameters(self):
        """Raise when neither target_id nor provider is given"""
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
            output_file = f.name

        try:
            with pytest.raises(ValueError, match="Must provide either target_id or provider"):
                export_hosts_task(output_file=output_file)

        finally:
            Path(output_file).unlink(missing_ok=True)


class TestExportSiteUrlsTaskBackwardCompatibility:
    """Backward-compatibility tests for export_site_urls_task"""

    def test_property_8_legacy_mode_uses_traditional_logic(self):
        """
        Property 8: Task Backward Compatibility (export_site_urls_task)

        Feature: scan-target-provider, Property 8: Task Backward Compatibility
        **Validates: Requirements 6.1, 6.2, 6.4, 6.5**

        When calling export_site_urls_task with only target_id,
        it should use the traditional logic (_export_site_urls_legacy).
        """
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
            output_file = f.name

        try:
            target_id = 123

            # Mock HostPortMappingService data
            mock_associations = [
                {'host': 'example.com', 'port': 80},
                {'host': 'test.com', 'port': 443},
            ]

            with patch('apps.scan.tasks.site_scan.export_site_urls_task.HostPortMappingService') as mock_service_class, \
                    patch('apps.scan.tasks.site_scan.export_site_urls_task.BlacklistService') as mock_blacklist_service:

                # Mock HostPortMappingService
                mock_service = MagicMock()
                mock_service.iter_host_port_by_target.return_value = iter(mock_associations)
                mock_service_class.return_value = mock_service

                # Mock BlacklistService
                mock_blacklist = MagicMock()
                mock_blacklist.get_rules.return_value = []
                mock_blacklist_service.return_value = mock_blacklist

                # Call the task (legacy mode: target_id only)
                result = export_site_urls_task(
                    output_file=output_file,
                    target_id=target_id
                )

                # Verify: the return value carries the legacy-mode fields
                assert result['success'] is True
                assert result['output_file'] == output_file
                assert result['total_urls'] == 2  # port 80 yields 1 URL, port 443 yields 1 URL
                assert 'association_count' in result  # legacy mode should return association_count
                assert result['association_count'] == 2
                assert result['source'] == 'host_port'

                # Verify: the file content is correct
                with open(output_file, 'r') as f:
                    lines = [line.strip() for line in f.readlines()]
                    assert 'http://example.com' in lines
                    assert 'https://test.com' in lines

        finally:
            Path(output_file).unlink(missing_ok=True)

    def test_provider_mode_uses_provider_logic(self):
        """Provider mode is used when a provider is given"""
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
            output_file = f.name

        try:
            urls = ['https://example.com', 'https://test.com']
            provider = ListTargetProvider(targets=urls)

            # Call the task (provider mode)
            result = export_site_urls_task(
                output_file=output_file,
                provider=provider
            )

            # Verify: the provider was used
            assert result['success'] is True
            assert result['total_urls'] == len(urls)
            assert 'association_count' not in result  # provider mode does not return association_count
            assert result['source'] == 'provider'

            # Verify: the file content is correct
            with open(output_file, 'r') as f:
                lines = [line.strip() for line in f.readlines()]
                assert lines == urls

        finally:
            Path(output_file).unlink(missing_ok=True)

    def test_error_when_no_parameters(self):
        """Raise when neither target_id nor provider is given"""
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
            output_file = f.name

        try:
            with pytest.raises(ValueError, match="Must provide either target_id or provider"):
                export_site_urls_task(output_file=output_file)

        finally:
            Path(output_file).unlink(missing_ok=True)
@@ -1,17 +1,23 @@
 """
 Site URL list export task
 
-Uses the export_urls_with_fallback use-case function to handle the fallback chain
+Supports two modes:
+1. Legacy mode (backward compatible): export from the database via target_id
+2. Provider mode: export from any data source via a TargetProvider
+
 Data source: WebSite.url → Default (for crawlers such as katana)
 """
 
 import logging
+from typing import Optional
+from pathlib import Path
 from prefect import task
 
 from apps.scan.services.target_export_service import (
     export_urls_with_fallback,
     DataSource,
 )
+from apps.scan.providers import TargetProvider, DatabaseTargetProvider
 
 logger = logging.getLogger(__name__)
 
@@ -23,21 +29,27 @@ logger = logging.getLogger(__name__)
 )
 def export_sites_task(
     output_file: str,
-    target_id: int,
-    scan_id: int,
+    target_id: Optional[int] = None,
+    scan_id: Optional[int] = None,
+    provider: Optional[TargetProvider] = None,
     batch_size: int = 1000
 ) -> dict:
     """
     Export the site URL list to a file (for crawlers such as katana)
 
-    Data source priority (fallback chain):
+    Supports two modes:
+    1. Legacy mode (backward compatible): pass target_id, export from the database
+    2. Provider mode: pass provider, export from any data source
+
+    Data source priority (fallback chain, legacy mode only):
     1. WebSite table - site-level URLs
     2. Default generation - build http(s)://target_name from the Target type
 
     Args:
         output_file: output file path
-        target_id: target ID
+        target_id: target ID (legacy mode, backward compatible)
         scan_id: scan ID (retained for old callers)
+        provider: TargetProvider instance (new mode)
         batch_size: batch size (memory optimization)
 
     Returns:
@@ -50,6 +62,17 @@ def export_sites_task(
         ValueError: invalid arguments
         RuntimeError: execution failed
     """
+    # Parameter validation: at least one must be provided
+    if target_id is None and provider is None:
+        raise ValueError("Must provide either target_id or provider")
+
+    # Provider mode: export via the TargetProvider
+    if provider is not None:
+        logger.info("Using provider mode - Provider: %s", type(provider).__name__)
+        return _export_with_provider(output_file, provider)
+
+    # Legacy mode: use export_urls_with_fallback
+    logger.info("Using legacy mode - Target ID: %d", target_id)
    result = export_urls_with_fallback(
        target_id=target_id,
        output_file=output_file,
@@ -67,3 +90,31 @@ def export_sites_task(
         'output_file': result['output_file'],
         'asset_count': result['total_count'],
     }
+
+
+def _export_with_provider(output_file: str, provider: TargetProvider) -> dict:
+    """Export URLs via the provider"""
+    output_path = Path(output_file)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    total_count = 0
+    blacklist_filter = provider.get_blacklist_filter()
+
+    with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
+        for url in provider.iter_urls():
+            # Apply blacklist filtering, if any
+            if blacklist_filter and not blacklist_filter.is_allowed(url):
+                continue
+
+            f.write(f"{url}\n")
+            total_count += 1
+
+            if total_count % 1000 == 0:
+                logger.info("Exported %d URLs...", total_count)
+
+    logger.info("✓ URL export finished - total: %d, file: %s", total_count, str(output_path))
+
+    return {
+        'output_file': str(output_path),
+        'asset_count': total_count,
+    }
@@ -1,15 +1,18 @@
 """Task that exports Endpoint URLs to a file
 
-Uses the export_urls_with_fallback use-case function to handle the fallback chain
+Supports two modes:
+1. Legacy mode (backward compatible): export from the database via target_id
+2. Provider mode: export from any data source via a TargetProvider
 
-Data source priority (fallback chain):
+Data source priority (fallback chain, legacy mode only):
 1. Endpoint.url - the most fine-grained URLs (with paths, parameters, etc.)
 2. WebSite.url - site-level URLs
 3. Default generation - build http(s)://target_name from the Target type
 """
 
 import logging
-from typing import Dict
+from typing import Dict, Optional
 from pathlib import Path
 
 from prefect import task
 
@@ -17,26 +20,33 @@ from apps.scan.services.target_export_service import (
     export_urls_with_fallback,
     DataSource,
 )
+from apps.scan.providers import TargetProvider, DatabaseTargetProvider
 
 logger = logging.getLogger(__name__)
 
 
 @task(name="export_endpoints")
 def export_endpoints_task(
-    target_id: int,
-    output_file: str,
+    target_id: Optional[int] = None,
+    output_file: str = "",
+    provider: Optional[TargetProvider] = None,
     batch_size: int = 1000,
 ) -> Dict[str, object]:
     """Export all Endpoint URLs of a target to a text file.
 
-    Data source priority (fallback chain):
+    Supports two modes:
+    1. Legacy mode (backward compatible): pass target_id, export from the database
+    2. Provider mode: pass provider, export from any data source
+
+    Data source priority (fallback chain, legacy mode only):
     1. Endpoint table - the most fine-grained URLs (with paths, parameters, etc.)
     2. WebSite table - site-level URLs
     3. Default generation - build http(s)://target_name from the Target type
 
     Args:
-        target_id: target ID
+        target_id: target ID (legacy mode, backward compatible)
         output_file: output file path (absolute)
+        provider: TargetProvider instance (new mode)
         batch_size: batch size per database iteration
 
     Returns:
@@ -44,9 +54,20 @@ def export_endpoints_task(
         "success": bool,
         "output_file": str,
         "total_count": int,
-        "source": str,  # data source: "endpoint" | "website" | "default" | "none"
+        "source": str,  # data source: "endpoint" | "website" | "default" | "none" | "provider"
     }
     """
+    # Parameter validation: at least one must be provided
+    if target_id is None and provider is None:
+        raise ValueError("Must provide either target_id or provider")
+
+    # Provider mode: export via the TargetProvider
+    if provider is not None:
+        logger.info("Using provider mode - Provider: %s", type(provider).__name__)
+        return _export_with_provider(output_file, provider)
+
+    # Legacy mode: use export_urls_with_fallback
+    logger.info("Using legacy mode - Target ID: %d", target_id)
    result = export_urls_with_fallback(
        target_id=target_id,
        output_file=output_file,
@@ -65,3 +86,33 @@ def export_endpoints_task(
         "total_count": result['total_count'],
         "source": result['source'],
     }
+
+
+def _export_with_provider(output_file: str, provider: TargetProvider) -> Dict[str, object]:
+    """Export URLs via the provider"""
+    output_path = Path(output_file)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    total_count = 0
+    blacklist_filter = provider.get_blacklist_filter()
+
+    with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
+        for url in provider.iter_urls():
+            # Apply blacklist filtering, if any
+            if blacklist_filter and not blacklist_filter.is_allowed(url):
+                continue
+
+            f.write(f"{url}\n")
+            total_count += 1
+
+            if total_count % 1000 == 0:
+                logger.info("Exported %d URLs...", total_count)
+
+    logger.info("✓ URL export finished - total: %d, file: %s", total_count, str(output_path))
+
+    return {
+        "success": True,
+        "output_file": str(output_path),
+        "total_count": total_count,
+        "source": "provider",
+    }
||||
@@ -4,37 +4,40 @@
|
||||
提供扫描相关的工具函数。
|
||||
"""
|
||||
|
||||
from .directory_cleanup import remove_directory
|
||||
from . import config_parser
|
||||
from .command_builder import build_scan_command
|
||||
from .command_executor import execute_and_wait, execute_stream
|
||||
from .wordlist_helpers import ensure_wordlist_local
|
||||
from .directory_cleanup import remove_directory
|
||||
from .nuclei_helpers import ensure_nuclei_templates_local
|
||||
from .performance import FlowPerformanceTracker, CommandPerformanceTracker
|
||||
from .workspace_utils import setup_scan_workspace, setup_scan_directory
|
||||
from .performance import CommandPerformanceTracker, FlowPerformanceTracker
|
||||
from .system_load import check_system_load, wait_for_system_load
|
||||
from .user_logger import user_log
|
||||
from . import config_parser
|
||||
from .wordlist_helpers import ensure_wordlist_local
|
||||
from .workspace_utils import setup_scan_directory, setup_scan_workspace
|
||||
|
||||
__all__ = [
|
||||
# 目录清理
|
||||
'remove_directory',
|
||||
# 工作空间
|
||||
'setup_scan_workspace', # 创建 Scan 根工作空间
|
||||
'setup_scan_directory', # 创建扫描子目录
|
||||
'setup_scan_workspace',
|
||||
'setup_scan_directory',
|
||||
# 命令构建
|
||||
'build_scan_command', # 扫描工具命令构建(基于 f-string)
|
||||
'build_scan_command',
|
||||
# 命令执行
|
||||
'execute_and_wait', # 等待式执行(文件输出)
|
||||
'execute_stream', # 流式执行(实时处理)
|
||||
'execute_and_wait',
|
||||
'execute_stream',
|
||||
# 系统负载
|
||||
'wait_for_system_load',
|
||||
'check_system_load',
|
||||
# 字典文件
|
||||
'ensure_wordlist_local', # 确保本地字典文件(含 hash 校验)
|
||||
'ensure_wordlist_local',
|
||||
# Nuclei 模板
|
||||
'ensure_nuclei_templates_local', # 确保本地模板(含 commit hash 校验)
|
||||
'ensure_nuclei_templates_local',
|
||||
# 性能监控
|
||||
'FlowPerformanceTracker', # Flow 性能追踪器(含系统资源采样)
|
||||
'CommandPerformanceTracker', # 命令性能追踪器
|
||||
'FlowPerformanceTracker',
|
||||
'CommandPerformanceTracker',
|
||||
# 扫描日志
|
||||
'user_log', # 用户可见扫描日志记录
|
||||
'user_log',
|
||||
# 配置解析
|
||||
'config_parser',
|
||||
]
|
||||
|
||||
|
||||
@@ -12,16 +12,18 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from django.conf import settings
|
||||
import re
|
||||
import signal
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, Generator
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
try:
|
||||
# 可选依赖:用于根据 CPU / 内存负载做动态并发控制
|
||||
import psutil
|
||||
@@ -354,10 +356,13 @@ class CommandExecutor:
|
||||
if log_file_path:
|
||||
error_output = self._read_log_tail(log_file_path, max_lines=MAX_LOG_TAIL_LINES)
|
||||
logger.warning(
|
||||
"扫描工具 %s 返回非零状态码: %d (执行时间: %.2f秒)%s",
|
||||
tool_name, returncode, duration,
|
||||
f"\n错误输出:\n{error_output}" if error_output else ""
|
||||
"扫描工具 %s 返回非零状态码: %d (执行时间: %.2f秒)",
|
||||
tool_name, returncode, duration
|
||||
)
|
||||
if error_output:
|
||||
for line in error_output.strip().split('\n'):
|
||||
if line.strip():
|
||||
logger.warning("%s", line)
|
||||
else:
|
||||
logger.info("✓ 扫描工具 %s 执行完成 (执行时间: %.2f秒)", tool_name, duration)
|
||||
|
||||
@@ -666,33 +671,68 @@ class CommandExecutor:
|
||||
|
||||
def _read_log_tail(self, log_file: Path, max_lines: int = MAX_LOG_TAIL_LINES) -> str:
|
||||
"""
|
||||
读取日志文件的末尾部分
|
||||
|
||||
读取日志文件的末尾部分(常量内存实现)
|
||||
|
||||
使用 seek 从文件末尾往前读取,避免将整个文件加载到内存。
|
||||
|
||||
Args:
|
||||
log_file: 日志文件路径
|
||||
max_lines: 最大读取行数
|
||||
|
||||
|
||||
Returns:
|
||||
日志内容(字符串),读取失败返回错误提示
|
||||
"""
|
||||
if not log_file.exists():
|
||||
logger.debug("日志文件不存在: %s", log_file)
|
||||
return ""
|
||||
|
||||
if log_file.stat().st_size == 0:
|
||||
|
||||
file_size = log_file.stat().st_size
|
||||
if file_size == 0:
|
||||
logger.debug("日志文件为空: %s", log_file)
|
||||
return ""
|
||||
|
||||
|
||||
# 每次读取的块大小(8KB,足够容纳大多数日志行)
|
||||
chunk_size = 8192
|
||||
|
||||
def decode_line(line_bytes: bytes) -> str:
|
||||
"""解码单行:优先 UTF-8,失败则降级 latin-1"""
|
||||
try:
|
||||
return line_bytes.decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
return line_bytes.decode('latin-1', errors='replace')
|
||||
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
return ''.join(lines[-max_lines:] if len(lines) > max_lines else lines)
|
||||
except UnicodeDecodeError as e:
|
||||
logger.warning("日志文件编码错误 (%s): %s", log_file, e)
|
||||
return f"(无法读取日志文件: 编码错误 - {e})"
|
||||
with open(log_file, 'rb') as f:
|
||||
lines_found: deque[bytes] = deque()
|
||||
remaining = b''
|
||||
position = file_size
|
||||
|
||||
while position > 0 and len(lines_found) < max_lines:
|
||||
read_size = min(chunk_size, position)
|
||||
position -= read_size
|
||||
|
||||
f.seek(position)
|
||||
chunk = f.read(read_size) + remaining
|
||||
parts = chunk.split(b'\n')
|
||||
|
||||
# 最前面的部分可能不完整,留到下次处理
|
||||
remaining = parts[0]
|
||||
|
||||
# 其余部分是完整的行(从后往前收集,用 appendleft 保持顺序)
|
||||
for part in reversed(parts[1:]):
|
||||
if len(lines_found) >= max_lines:
|
||||
break
|
||||
lines_found.appendleft(part)
|
||||
|
||||
# 处理文件开头的行
|
||||
if remaining and len(lines_found) < max_lines:
|
||||
lines_found.appendleft(remaining)
|
||||
|
||||
return '\n'.join(decode_line(line) for line in lines_found)
|
||||
|
||||
except PermissionError as e:
|
||||
logger.warning("日志文件权限不足 (%s): %s", log_file, e)
|
||||
return f"(无法读取日志文件: 权限不足)"
|
||||
return "(无法读取日志文件: 权限不足)"
|
||||
except IOError as e:
|
||||
logger.warning("日志文件读取IO错误 (%s): %s", log_file, e)
|
||||
return f"(无法读取日志文件: IO错误 - {e})"
|
||||
|
||||
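The `_read_log_tail` rewrite above replaces a full `readlines()` pass with a constant-memory reverse read: seek backwards in fixed-size chunks, split on `\n`, and collect at most `max_lines` lines with `appendleft` so their order is preserved. A standalone sketch of the same technique (function name and defaults are illustrative, not the repository's exact API):

```python
# 常量内存读取文件末尾 N 行的独立示意(与上面 hunk 同一技巧,非项目代码)。
from collections import deque
from pathlib import Path


def read_tail(path: Path, max_lines: int = 50, chunk_size: int = 8192) -> str:
    file_size = path.stat().st_size
    if file_size == 0:
        return ""

    lines: deque[bytes] = deque()
    remaining = b""
    position = file_size

    with open(path, "rb") as f:
        while position > 0 and len(lines) < max_lines:
            read_size = min(chunk_size, position)
            position -= read_size
            f.seek(position)
            chunk = f.read(read_size) + remaining
            parts = chunk.split(b"\n")
            remaining = parts[0]  # 开头可能是半行,留给下一轮拼接
            for part in reversed(parts[1:]):
                if len(lines) >= max_lines:
                    break
                lines.appendleft(part)

    if remaining and len(lines) < max_lines:
        lines.appendleft(remaining)  # 文件第一行

    return "\n".join(p.decode("utf-8", errors="replace") for p in lines)
```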
77  backend/apps/scan/utils/system_load.py  Normal file
@@ -0,0 +1,77 @@
"""
系统负载检查工具

提供统一的系统负载检查功能,用于:
- Flow 入口处检查系统资源是否充足
- 防止在高负载时启动新的扫描任务
"""

import logging
import time

import psutil
from django.conf import settings

logger = logging.getLogger(__name__)

# 动态并发控制阈值(可在 Django settings 中覆盖)
SCAN_CPU_HIGH: float = getattr(settings, 'SCAN_CPU_HIGH', 90.0)
SCAN_MEM_HIGH: float = getattr(settings, 'SCAN_MEM_HIGH', 80.0)
SCAN_LOAD_CHECK_INTERVAL: int = getattr(settings, 'SCAN_LOAD_CHECK_INTERVAL', 180)


def _get_current_load() -> tuple[float, float]:
    """获取当前 CPU 和内存使用率"""
    return psutil.cpu_percent(interval=0.5), psutil.virtual_memory().percent


def wait_for_system_load(
    cpu_threshold: float = SCAN_CPU_HIGH,
    mem_threshold: float = SCAN_MEM_HIGH,
    check_interval: int = SCAN_LOAD_CHECK_INTERVAL,
    context: str = "task"
) -> None:
    """
    等待系统负载降到阈值以下

    在高负载时阻塞等待,直到 CPU 和内存都低于阈值。
    用于 Flow 入口处,防止在资源紧张时启动新任务。
    """
    while True:
        cpu, mem = _get_current_load()

        if cpu < cpu_threshold and mem < mem_threshold:
            logger.debug(
                "[%s] 系统负载正常: cpu=%.1f%%, mem=%.1f%%",
                context, cpu, mem
            )
            return

        logger.info(
            "[%s] 系统负载较高,等待资源释放: "
            "cpu=%.1f%% (阈值 %.1f%%), mem=%.1f%% (阈值 %.1f%%)",
            context, cpu, cpu_threshold, mem, mem_threshold
        )
        time.sleep(check_interval)


def check_system_load(
    cpu_threshold: float = SCAN_CPU_HIGH,
    mem_threshold: float = SCAN_MEM_HIGH
) -> dict:
    """
    检查当前系统负载(非阻塞)

    Returns:
        dict: cpu_percent, mem_percent, cpu_threshold, mem_threshold, is_overloaded
    """
    cpu, mem = _get_current_load()

    return {
        'cpu_percent': cpu,
        'mem_percent': mem,
        'cpu_threshold': cpu_threshold,
        'mem_threshold': mem_threshold,
        'is_overloaded': cpu >= cpu_threshold or mem >= mem_threshold,
    }

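Per the `__init__.py` hunk earlier, both helpers are re-exported from `apps.scan.utils`. A hedged usage sketch at a Flow entry point (the flow function itself is hypothetical):

```python
# 示意:在 Flow 入口处使用系统负载工具(run_scan_flow 为假设的函数名,非项目代码)。
from apps.scan.utils import check_system_load, wait_for_system_load


def run_scan_flow(target: str) -> None:
    # 阻塞等待,直到 CPU 和内存都低于阈值
    wait_for_system_load(context=f"scan:{target}")

    # 非阻塞检查:调用方自行决定如何处理
    load = check_system_load()
    if load["is_overloaded"]:
        print("系统仍处于高负载:", load)

    # ... 在这里启动各扫描阶段 ...
```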
@@ -12,16 +12,34 @@ load-plugins = "pylint_django"

[tool.pylint.messages_control]
disable = [
    "missing-docstring",
    "invalid-name",
    "too-few-public-methods",
    "no-member",
    "import-error",
    "no-name-in-module",
    "missing-docstring",
    "invalid-name",
    "too-few-public-methods",
    "no-member",
    "import-error",
    "no-name-in-module",
    "wrong-import-position",    # 允许函数内导入(防循环依赖)
    "import-outside-toplevel",  # 同上
    "too-many-arguments",       # Django 视图/服务方法参数常超过5个
    "too-many-locals",          # 复杂业务逻辑局部变量多
    "duplicate-code",           # 某些模式代码相似是正常的
]

[tool.pylint.format]
max-line-length = 120

[tool.pylint.basic]
good-names = ["i", "j", "k", "ex", "Run", "_", "id", "pk", "ip", "url", "db", "qs"]
good-names = [
    "i",
    "j",
    "k",
    "ex",
    "Run",
    "_",
    "id",
    "pk",
    "ip",
    "url",
    "db",
    "qs",
]

@@ -38,6 +38,7 @@ packaging>=21.0  # 版本比较
# 测试框架
pytest==8.0.0
pytest-django==4.7.0
hypothesis>=6.100.0  # 属性测试框架

# 工具库
python-dateutil==2.9.0

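requirements.txt now pins Hypothesis as the property-based testing framework. Under pytest, a minimal property test might look like the sketch below; `normalize_url` is a placeholder, not project code:

```python
# 最小的 Hypothesis 属性测试示意(normalize_url 为占位函数,非项目代码)。
from hypothesis import given, strategies as st


def normalize_url(url: str) -> str:
    return url.strip().lower()


@given(st.text(alphabet=st.characters(min_codepoint=32, max_codepoint=126)))
def test_normalize_is_idempotent(s: str) -> None:
    # 归一化应当是幂等的:做两次等于做一次
    assert normalize_url(normalize_url(s)) == normalize_url(s)
```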
@@ -2,7 +2,7 @@
# ============================================
# XingRin 远程节点安装脚本
# 用途:安装 Docker 环境 + 预拉取镜像
# 支持:Ubuntu / Debian
# 支持:Ubuntu / Debian / Kali
#
# 架构说明:
#   1. 安装 Docker 环境
@@ -101,8 +101,8 @@ detect_os() {
        exit 1
    fi

    if [[ "$OS" != "ubuntu" && "$OS" != "debian" ]]; then
        log_error "仅支持 Ubuntu/Debian 系统"
    if [[ "$OS" != "ubuntu" && "$OS" != "debian" && "$OS" != "kali" ]]; then
        log_error "仅支持 Ubuntu/Debian/Kali 系统"
        exit 1
    fi
}

@@ -44,6 +44,8 @@ services:
    restart: always
    env_file:
      - .env
    environment:
      - IMAGE_TAG=${IMAGE_TAG:-dev}
    ports:
      - "8888:8888"
    depends_on:
@@ -53,6 +55,8 @@ services:
      # 统一挂载数据目录
      - /opt/xingrin:/opt/xingrin
      - /var/run/docker.sock:/var/run/docker.sock
    # OOM 优先级:-500 保护核心服务
    oom_score_adj: -500
    healthcheck:
      # 使用专门的健康检查端点(无需认证)
      test: ["CMD", "curl", "-f", "http://localhost:8888/api/health/"]
@@ -88,6 +92,8 @@ services:
      args:
        IMAGE_TAG: ${IMAGE_TAG:-dev}
    restart: always
    # OOM 优先级:-500 保护 Web 界面
    oom_score_adj: -500
    depends_on:
      server:
        condition: service_healthy
@@ -97,6 +103,8 @@ services:
      context: ..
      dockerfile: docker/nginx/Dockerfile
    restart: always
    # OOM 优先级:-500 保护入口网关
    oom_score_adj: -500
    depends_on:
      server:
        condition: service_healthy

@@ -48,6 +48,8 @@ services:
    restart: always
    env_file:
      - .env
    environment:
      - IMAGE_TAG=${IMAGE_TAG}
    depends_on:
      redis:
        condition: service_healthy
@@ -56,6 +58,8 @@ services:
      - /opt/xingrin:/opt/xingrin
      # Docker Socket 挂载:允许 Django 服务器执行本地 docker 命令(用于本地 Worker 任务分发)
      - /var/run/docker.sock:/var/run/docker.sock
    # OOM 优先级:-500 降低被 OOM Killer 选中的概率,保护核心服务
    oom_score_adj: -500
    healthcheck:
      # 使用专门的健康检查端点(无需认证)
      test: ["CMD", "curl", "-f", "http://localhost:8888/api/health/"]
@@ -88,6 +92,8 @@ services:
  frontend:
    image: ${DOCKER_USER:-yyhuni}/xingrin-frontend:${IMAGE_TAG:?IMAGE_TAG is required}
    restart: always
    # OOM 优先级:-500 保护 Web 界面
    oom_score_adj: -500
    depends_on:
      server:
        condition: service_healthy
@@ -95,6 +101,8 @@ services:
  nginx:
    image: ${DOCKER_USER:-yyhuni}/xingrin-nginx:${IMAGE_TAG:?IMAGE_TAG is required}
    restart: always
    # OOM 优先级:-500 保护入口网关
    oom_score_adj: -500
    depends_on:
      server:
        condition: service_healthy

@@ -83,20 +83,20 @@
    print('未找到配置文件,跳过')
    exit(0)

new_config = yaml_path.read_text()

# 检查是否已有 full scan 引擎
engine = ScanEngine.objects.filter(name='full scan').first()
if engine:
    if not engine.configuration or not engine.configuration.strip():
        engine.configuration = yaml_path.read_text()
        engine.save(update_fields=['configuration'])
        print(f'已初始化引擎配置: {engine.name}')
    else:
        print(f'引擎已有配置,跳过')
    # 直接覆盖为最新配置
    engine.configuration = new_config
    engine.save(update_fields=['configuration'])
    print(f'已更新引擎配置: {engine.name}')
else:
    # 创建引擎
    engine = ScanEngine.objects.create(
        name='full scan',
        configuration=yaml_path.read_text(),
        configuration=new_config,
    )
    print(f'已创建引擎: {engine.name}')
"

@@ -10,7 +10,7 @@ python manage.py migrate --noinput
echo " ✓ 数据库迁移完成"

echo " [1.1/3] 初始化默认扫描引擎..."
python manage.py init_default_engine
python manage.py init_default_engine --force
echo " ✓ 默认扫描引擎已就绪"

echo " [1.2/3] 初始化默认目录字典..."

@@ -182,7 +182,7 @@ echo -e "${BOLD}${GREEN}══════════════════
echo ""
echo -e "${BOLD}访问地址${NC}"
if [ "$WITH_FRONTEND" = true ]; then
    echo -e " XingRin: ${CYAN}https://${ACCESS_HOST}/${NC}"
    echo -e " XingRin: ${CYAN}https://${ACCESS_HOST}:8083/${NC}"
    echo -e " ${YELLOW}(HTTP 会自动跳转到 HTTPS)${NC}"
else
    echo -e " API: ${CYAN}通过前端或 nginx 访问(后端未暴露 8888)${NC}"
@@ -191,8 +191,3 @@ else
    echo " cd frontend && pnpm dev"
fi
echo ""
echo -e "${BOLD}默认账号${NC}"
echo " 用户名: admin"
echo " 密码: admin"
echo -e " ${YELLOW}[!] 请首次登录后修改密码${NC}"
echo ""

@@ -29,9 +29,6 @@ RUN go install -v github.com/projectdiscovery/httpx/cmd/httpx@latest && \
    go install -v github.com/d3mondev/puredns/v2@latest && \
    go install -v github.com/yyhuni/xingfinger@latest

# 安装 Amass v5(禁用 CGO 以跳过 libpostal 依赖)
RUN CGO_ENABLED=0 go install -v github.com/owasp-amass/amass/v5/cmd/amass@main

# 安装漏洞扫描器
RUN go install github.com/hahwul/dalfox/v2@latest

@@ -45,7 +42,9 @@ ENV DEBIAN_FRONTEND=noninteractive
WORKDIR /app

# 1. 安装基础工具和 Python
RUN apt-get update && apt-get install -y \
# 注意:ARM64 使用 ports.ubuntu.com,可能存在镜像同步延迟,需要重试机制
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    python3 \
    python3-pip \
    python3-venv \
@@ -64,6 +63,9 @@ RUN apt-get update && apt-get install -y \
    libnss3 \
    libxss1 \
    libasound2t64 \
    || (rm -rf /var/lib/apt/lists/* && apt-get update && apt-get install -y --no-install-recommends \
    python3 python3-pip python3-venv pipx git curl wget unzip jq tmux nmap masscan libpcap-dev \
    ca-certificates fonts-liberation libnss3 libxss1 libasound2t64) \
    && rm -rf /var/lib/apt/lists/*

# 安装 Chromium(通过 Playwright 安装,支持 ARM64 和 AMD64)

@@ -29,6 +29,10 @@ export default function NotificationSettingsPage() {
    enabled: z.boolean(),
    webhookUrl: z.string().url(t("discord.urlInvalid")).or(z.literal('')),
  }),
  wecom: z.object({
    enabled: z.boolean(),
    webhookUrl: z.string().url(t("wecom.urlInvalid")).or(z.literal('')),
  }),
  categories: z.object({
    scan: z.boolean(),
    vulnerability: z.boolean(),
@@ -46,6 +50,15 @@ export default function NotificationSettingsPage() {
      })
    }
  }
  if (val.wecom.enabled) {
    if (!val.wecom.webhookUrl || val.wecom.webhookUrl.trim() === '') {
      ctx.addIssue({
        code: z.ZodIssueCode.custom,
        message: t("wecom.requiredError"),
        path: ['wecom', 'webhookUrl'],
      })
    }
  }
})

const NOTIFICATION_CATEGORIES = [
@@ -79,6 +92,7 @@ export default function NotificationSettingsPage() {
    resolver: zodResolver(schema),
    values: data ?? {
      discord: { enabled: false, webhookUrl: '' },
      wecom: { enabled: false, webhookUrl: '' },
      categories: {
        scan: true,
        vulnerability: true,
@@ -93,6 +107,7 @@ export default function NotificationSettingsPage() {
  }

  const discordEnabled = form.watch('discord.enabled')
  const wecomEnabled = form.watch('wecom.enabled')

  return (
    <div className="p-4 md:p-6 space-y-6">
@@ -187,25 +202,59 @@ export default function NotificationSettingsPage() {
        </CardHeader>
      </Card>

      {/* Feishu/DingTalk/WeCom - Coming soon */}
      <Card className="opacity-60">
      {/* 企业微信 */}
      <Card>
        <CardHeader className="pb-4">
          <div className="flex items-center justify-between">
            <div className="flex items-center gap-3">
              <div className="flex h-10 w-10 items-center justify-center rounded-lg bg-muted">
                <IconBrandSlack className="h-5 w-5 text-muted-foreground" />
              <div className="flex h-10 w-10 items-center justify-center rounded-lg bg-[#07C160]/10">
                <IconBrandSlack className="h-5 w-5 text-[#07C160]" />
              </div>
              <div>
                <div className="flex items-center gap-2">
                  <CardTitle className="text-base">{t("enterprise.title")}</CardTitle>
                  <Badge variant="secondary" className="text-xs">{t("emailChannel.comingSoon")}</Badge>
                </div>
                <CardDescription>{t("enterprise.description")}</CardDescription>
                <CardTitle className="text-base">{t("wecom.title")}</CardTitle>
                <CardDescription>{t("wecom.description")}</CardDescription>
              </div>
            </div>
            <Switch disabled />
            <FormField
              control={form.control}
              name="wecom.enabled"
              render={({ field }) => (
                <FormControl>
                  <Switch
                    checked={field.value}
                    onCheckedChange={field.onChange}
                    disabled={isLoading || updateMutation.isPending}
                  />
                </FormControl>
              )}
            />
          </div>
        </CardHeader>
        {wecomEnabled && (
          <CardContent className="pt-0">
            <Separator className="mb-4" />
            <FormField
              control={form.control}
              name="wecom.webhookUrl"
              render={({ field }) => (
                <FormItem>
                  <FormLabel>{t("wecom.webhookLabel")}</FormLabel>
                  <FormControl>
                    <Input
                      placeholder={t("wecom.webhookPlaceholder")}
                      {...field}
                      disabled={isLoading || updateMutation.isPending}
                    />
                  </FormControl>
                  <FormDescription>
                    {t("wecom.webhookHelp")}
                  </FormDescription>
                  <FormMessage />
                </FormItem>
              )}
            />
          </CardContent>
        )}
      </Card>
    </TabsContent>

189  frontend/components/about-dialog.tsx  Normal file
@@ -0,0 +1,189 @@
"use client"

import { useState } from 'react'
import { useTranslations } from 'next-intl'
import { useQueryClient } from '@tanstack/react-query'
import {
  IconRadar,
  IconRefresh,
  IconExternalLink,
  IconBrandGithub,
  IconMessageReport,
  IconBook,
  IconFileText,
  IconCheck,
  IconArrowUp,
} from '@tabler/icons-react'

import {
  Dialog,
  DialogContent,
  DialogHeader,
  DialogTitle,
  DialogTrigger,
} from '@/components/ui/dialog'
import { Button } from '@/components/ui/button'
import { Separator } from '@/components/ui/separator'
import { Badge } from '@/components/ui/badge'
import { useVersion } from '@/hooks/use-version'
import { VersionService } from '@/services/version.service'
import type { UpdateCheckResult } from '@/types/version.types'

interface AboutDialogProps {
  children: React.ReactNode
}

export function AboutDialog({ children }: AboutDialogProps) {
  const t = useTranslations('about')
  const { data: versionData } = useVersion()
  const queryClient = useQueryClient()

  const [isChecking, setIsChecking] = useState(false)
  const [updateResult, setUpdateResult] = useState<UpdateCheckResult | null>(null)
  const [checkError, setCheckError] = useState<string | null>(null)

  const handleCheckUpdate = async () => {
    setIsChecking(true)
    setCheckError(null)
    try {
      const result = await VersionService.checkUpdate()
      setUpdateResult(result)
      queryClient.setQueryData(['check-update'], result)
    } catch {
      setCheckError(t('checkFailed'))
    } finally {
      setIsChecking(false)
    }
  }

  const currentVersion = updateResult?.currentVersion || versionData?.version || '-'
  const latestVersion = updateResult?.latestVersion
  const hasUpdate = updateResult?.hasUpdate

  return (
    <Dialog>
      <DialogTrigger asChild>
        {children}
      </DialogTrigger>
      <DialogContent className="sm:max-w-md">
        <DialogHeader>
          <DialogTitle>{t('title')}</DialogTitle>
        </DialogHeader>

        <div className="space-y-6">
          {/* Logo and name */}
          <div className="flex flex-col items-center py-4">
            <div className="flex h-16 w-16 items-center justify-center rounded-2xl bg-primary/10 mb-3">
              <IconRadar className="h-8 w-8 text-primary" />
            </div>
            <h2 className="text-xl font-semibold">XingRin</h2>
            <p className="text-sm text-muted-foreground">{t('description')}</p>
          </div>

          {/* Version info */}
          <div className="rounded-lg border p-4 space-y-3">
            <div className="flex items-center justify-between">
              <span className="text-sm text-muted-foreground">{t('currentVersion')}</span>
              <span className="font-mono text-sm">{currentVersion}</span>
            </div>

            {updateResult && (
              <div className="flex items-center justify-between">
                <span className="text-sm text-muted-foreground">{t('latestVersion')}</span>
                <div className="flex items-center gap-2">
                  <span className="font-mono text-sm">{latestVersion}</span>
                  {hasUpdate ? (
                    <Badge variant="default" className="gap-1">
                      <IconArrowUp className="h-3 w-3" />
                      {t('updateAvailable')}
                    </Badge>
                  ) : (
                    <Badge variant="secondary" className="gap-1">
                      <IconCheck className="h-3 w-3" />
                      {t('upToDate')}
                    </Badge>
                  )}
                </div>
              </div>
            )}

            {checkError && (
              <p className="text-sm text-destructive">{checkError}</p>
            )}

            <div className="flex gap-2">
              <Button
                variant="outline"
                size="sm"
                className="flex-1"
                onClick={handleCheckUpdate}
                disabled={isChecking}
              >
                <IconRefresh className={`h-4 w-4 mr-2 ${isChecking ? 'animate-spin' : ''}`} />
                {isChecking ? t('checking') : t('checkUpdate')}
              </Button>

              {hasUpdate && updateResult?.releaseUrl && (
                <Button
                  variant="default"
                  size="sm"
                  className="flex-1"
                  asChild
                >
                  <a href={updateResult.releaseUrl} target="_blank" rel="noopener noreferrer">
                    <IconExternalLink className="h-4 w-4 mr-2" />
                    {t('viewRelease')}
                  </a>
                </Button>
              )}
            </div>

            {hasUpdate && (
              <div className="rounded-md bg-muted p-3 text-sm text-muted-foreground">
                <p>{t('updateHint')}</p>
                <code className="mt-1 block rounded bg-background px-2 py-1 font-mono text-xs">
                  sudo ./update.sh
                </code>
              </div>
            )}
          </div>

          <Separator />

          {/* Links */}
          <div className="grid grid-cols-2 gap-2">
            <Button variant="ghost" size="sm" className="justify-start" asChild>
              <a href="https://github.com/yyhuni/xingrin" target="_blank" rel="noopener noreferrer">
                <IconBrandGithub className="h-4 w-4 mr-2" />
                GitHub
              </a>
            </Button>
            <Button variant="ghost" size="sm" className="justify-start" asChild>
              <a href="https://github.com/yyhuni/xingrin/releases" target="_blank" rel="noopener noreferrer">
                <IconFileText className="h-4 w-4 mr-2" />
                {t('changelog')}
              </a>
            </Button>
            <Button variant="ghost" size="sm" className="justify-start" asChild>
              <a href="https://github.com/yyhuni/xingrin/issues" target="_blank" rel="noopener noreferrer">
                <IconMessageReport className="h-4 w-4 mr-2" />
                {t('feedback')}
              </a>
            </Button>
            <Button variant="ghost" size="sm" className="justify-start" asChild>
              <a href="https://github.com/yyhuni/xingrin#readme" target="_blank" rel="noopener noreferrer">
                <IconBook className="h-4 w-4 mr-2" />
                {t('docs')}
              </a>
            </Button>
          </div>

          {/* Footer */}
          <p className="text-center text-xs text-muted-foreground">
            © 2025 XingRin · MIT License
          </p>
        </div>
      </DialogContent>
    </Dialog>
  )
}

@@ -5,7 +5,6 @@ import type * as React from "react"
// Import various icons from Tabler Icons library
import {
  IconDashboard,  // Dashboard icon
  IconHelp,  // Help icon
  IconListDetails,  // List details icon
  IconSettings,  // Settings icon
  IconUsers,  // Users icon
@@ -15,10 +14,10 @@ import {
  IconServer,  // Server icon
  IconTerminal2,  // Terminal icon
  IconBug,  // Vulnerability icon
  IconMessageReport,  // Feedback icon
  IconSearch,  // Search icon
  IconKey,  // API Key icon
  IconBan,  // Blacklist icon
  IconInfoCircle,  // About icon
} from "@tabler/icons-react"
// Import internationalization hook
import { useTranslations } from 'next-intl'
@@ -27,8 +26,8 @@ import { Link, usePathname } from '@/i18n/navigation'

// Import custom navigation components
import { NavSystem } from "@/components/nav-system"
import { NavSecondary } from "@/components/nav-secondary"
import { NavUser } from "@/components/nav-user"
import { AboutDialog } from "@/components/about-dialog"
// Import sidebar UI components
import {
  Sidebar,
@@ -139,20 +138,6 @@ export function AppSidebar({ ...props }: React.ComponentProps<typeof Sidebar>) {
    },
  ]

  // Secondary navigation menu items
  const navSecondary = [
    {
      title: t('feedback'),
      url: "https://github.com/yyhuni/xingrin/issues",
      icon: IconMessageReport,
    },
    {
      title: t('help'),
      url: "https://github.com/yyhuni/xingrin",
      icon: IconHelp,
    },
  ]

  // System settings related menu items
  const documents = [
    {
@@ -271,8 +256,21 @@ export function AppSidebar({ ...props }: React.ComponentProps<typeof Sidebar>) {

        {/* System settings navigation menu */}
        <NavSystem items={documents} />
        {/* Secondary navigation menu, using mt-auto to push to bottom */}
        <NavSecondary items={navSecondary} className="mt-auto" />
        {/* About system button */}
        <SidebarGroup className="mt-auto">
          <SidebarGroupContent>
            <SidebarMenu>
              <SidebarMenuItem>
                <AboutDialog>
                  <SidebarMenuButton>
                    <IconInfoCircle />
                    <span>{t('about')}</span>
                  </SidebarMenuButton>
                </AboutDialog>
              </SidebarMenuItem>
            </SidebarMenu>
          </SidebarGroupContent>
        </SidebarGroup>
      </SidebarContent>

      {/* Sidebar footer */}

@@ -365,6 +365,7 @@ export function DashboardDataTable() {
  columns={scanColumns}
  getRowId={(row) => String(row.id)}
  enableRowSelection={false}
  enableAutoColumnSizing
  pagination={scanPagination}
  onPaginationChange={setScanPagination}
  paginationInfo={scanPaginationInfo}

@@ -58,14 +58,6 @@ subdomain_discovery:
    enabled: true
    timeout: 600  # 10 minutes (required)

  amass_passive:
    enabled: true
    timeout: 600  # 10 minutes (required)

  amass_active:
    enabled: true
    timeout: 1800  # 30 minutes (required)

  sublist3r:
    enabled: true
    timeout: 900  # 15 minutes (required)

@@ -99,6 +99,8 @@ export function ScanHistoryDataTable({
  hideToolbar={hideToolbar}
  // Empty state
  emptyMessage={t("noData")}
  // Auto column sizing
  enableAutoColumnSizing
  // Custom search box
  toolbarLeft={
    <div className="flex items-center space-x-2">

@@ -84,6 +84,15 @@ function formatStageDuration(seconds?: number): string | undefined {
  return secs > 0 ? `${minutes}m ${secs}s` : `${minutes}m`
}

// Status priority for sorting (lower = higher priority)
const STAGE_STATUS_PRIORITY: Record<StageStatus, number> = {
  running: 0,
  pending: 1,
  completed: 2,
  failed: 3,
  cancelled: 4,
}

export function ScanOverview({ scanId }: ScanOverviewProps) {
  const t = useTranslations("scan.history.overview")
  const tStatus = useTranslations("scan.history.status")
@@ -326,7 +335,16 @@ export function ScanOverview({ scanId }: ScanOverviewProps) {
  {scan.stageProgress && Object.keys(scan.stageProgress).length > 0 ? (
    <div className="space-y-1 flex-1 min-h-0 overflow-y-auto pr-1">
      {Object.entries(scan.stageProgress)
        .sort(([, a], [, b]) => ((a as any).order ?? 0) - ((b as any).order ?? 0))
        .sort(([, a], [, b]) => {
          const progressA = a as any
          const progressB = b as any
          const priorityA = STAGE_STATUS_PRIORITY[progressA.status as StageStatus] ?? 99
          const priorityB = STAGE_STATUS_PRIORITY[progressB.status as StageStatus] ?? 99
          if (priorityA !== priorityB) {
            return priorityA - priorityB
          }
          return (progressA.order ?? 0) - (progressB.order ?? 0)
        })
        .map(([stageName, progress]) => {
          const stageProgress = progress as any
          const isRunning = stageProgress.status === "running"
@@ -346,9 +364,6 @@ export function ScanOverview({ scanId }: ScanOverviewProps) {
  <span className={cn("truncate", isRunning && "font-medium text-foreground")}>
    {tProgress(`stages.${stageName}`)}
  </span>
  {isRunning && (
    <span className="text-[10px] text-[#d29922] shrink-0">←</span>
  )}
</div>
<span className="text-xs text-muted-foreground font-mono shrink-0 ml-2">
  {stageProgress.status === "completed" && stageProgress.duration

@@ -78,5 +78,9 @@ export function ScanLogList({ logs, loading }: ScanLogListProps) {
    )
  }

  return <AnsiLogViewer content={content} />
  return (
    <div className="h-full">
      <AnsiLogViewer content={content} />
    </div>
  )
}

@@ -280,7 +280,9 @@ export function ScanProgressDialog({
        </div>
      ) : (
        /* Log list */
        <ScanLogList logs={logs} loading={logsLoading} />
        <div className="h-[300px] overflow-hidden rounded-md">
          <ScanLogList logs={logs} loading={logsLoading} />
        </div>
      )}
    </DialogContent>
  </Dialog>
@@ -350,13 +352,29 @@ function getStageResultCount(stageName: string, summary: ScanRecord["summary"]):
 * Stage names come directly from engine_config keys, no mapping needed
 * Stage order follows the order field, consistent with Flow execution order
 */
// Status priority for sorting (lower = higher priority)
const STATUS_PRIORITY: Record<StageStatus, number> = {
  running: 0,
  pending: 1,
  completed: 2,
  failed: 3,
  cancelled: 4,
}

export function buildScanProgressData(scan: ScanRecord): ScanProgressData {
  const stages: StageDetail[] = []

  if (scan.stageProgress) {
    // Sort by order then iterate
    // Sort by status priority first, then by order
    const sortedEntries = Object.entries(scan.stageProgress)
      .sort(([, a], [, b]) => (a.order ?? 0) - (b.order ?? 0))
      .sort(([, a], [, b]) => {
        const priorityA = STATUS_PRIORITY[a.status] ?? 99
        const priorityB = STATUS_PRIORITY[b.status] ?? 99
        if (priorityA !== priorityB) {
          return priorityA - priorityB
        }
        return (a.order ?? 0) - (b.order ?? 0)
      })

    for (const [stageName, progress] of sortedEntries) {
      const resultCount = progress.status === "completed"

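The replacement comparator above orders stages by status priority and falls back to the `order` field on ties. The same ordering can be written as a single composite sort key, sketched here in Python with illustrative stage data:

```python
# 组合排序键示意:先按状态优先级,再按 order(数据为示例,非项目代码)。
STATUS_PRIORITY = {"running": 0, "pending": 1, "completed": 2, "failed": 3, "cancelled": 4}

stage_progress = {
    "port_scan": {"status": "completed", "order": 1},
    "nuclei_scan": {"status": "running", "order": 3},
    "subdomain_discovery": {"status": "pending", "order": 0},
}

sorted_entries = sorted(
    stage_progress.items(),
    key=lambda kv: (STATUS_PRIORITY.get(kv[1]["status"], 99), kv[1].get("order", 0)),
)
# 结果:nuclei_scan(running)→ subdomain_discovery(pending)→ port_scan(completed)
```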
@@ -55,9 +55,10 @@ function hasAnsiCodes(text: string): boolean {

// 解析纯文本日志内容,为日志级别添加颜色
function colorizeLogContent(content: string): string {
  // 匹配日志格式: [时间] [级别] [模块:行号] 消息
  // 例如: [2025-01-05 10:30:00] [INFO] [apps.scan:123] 消息内容
  const logLineRegex = /^(\[\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\]) (\[(DEBUG|INFO|WARNING|WARN|ERROR|CRITICAL)\]) (.*)$/
  // 匹配日志格式:
  // 1) 系统日志: [2026-01-10 09:51:52] [INFO] [apps.scan.xxx:123] ...
  // 2) 扫描日志: [09:50:37] [INFO] [subdomain_discovery] ...
  const logLineRegex = /^(\[(?:\d{4}-\d{2}-\d{2} )?\d{2}:\d{2}:\d{2}\]) (\[(DEBUG|INFO|WARNING|WARN|ERROR|CRITICAL)\]) (.*)$/i

  return content
    .split("\n")
@@ -66,14 +67,15 @@ function colorizeLogContent(content: string): string {

      if (match) {
        const [, timestamp, levelBracket, level, rest] = match
        const color = LOG_LEVEL_COLORS[level] || "#d4d4d4"
        const levelUpper = level.toUpperCase()
        const color = LOG_LEVEL_COLORS[levelUpper] || "#d4d4d4"
        // ansiConverter.toHtml 已经处理了 HTML 转义
        const escapedTimestamp = ansiConverter.toHtml(timestamp)
        const escapedLevelBracket = ansiConverter.toHtml(levelBracket)
        const escapedRest = ansiConverter.toHtml(rest)

        // 时间戳灰色,日志级别带颜色,其余默认色
        return `<span style="color:#808080">${escapedTimestamp}</span> <span style="color:${color};font-weight:${level === "CRITICAL" ? "bold" : "normal"}">${escapedLevelBracket}</span> ${escapedRest}`
        return `<span style="color:#808080">${escapedTimestamp}</span> <span style="color:${color};font-weight:${levelUpper === "CRITICAL" ? "bold" : "normal"}">${escapedLevelBracket}</span> ${escapedRest}`
      }

      // 非标准格式的行,也进行 HTML 转义
@@ -85,16 +87,24 @@ function colorizeLogContent(content: string): string {
// 高亮搜索关键词
function highlightSearch(html: string, query: string): string {
  if (!query.trim()) return html

  // `ansi-to-html` 在 `escapeXML: true` 时,会把非 ASCII 字符(如中文)转成实体:
  // 例如 "中文" => "&#x4E2D;&#x6587;"。
  // 因此这里需要用同样的转义规则来生成可匹配的搜索串。
  const escapedQueryForHtml = ansiConverter.toHtml(query)

  // 转义正则特殊字符
  const escapedQuery = query.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")
  const regex = new RegExp(`(${escapedQuery})`, "gi")
  const escapedQuery = escapedQueryForHtml.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")
  const regex = new RegExp(`(${escapedQuery})`, "giu")

  // 在标签外的文本中高亮关键词
  return html.replace(/(<[^>]+>)|([^<]+)/g, (match, tag, text) => {
    if (tag) return tag
    if (text) {
      return text.replace(regex, '<mark style="background:#fbbf24;color:#1e1e1e;border-radius:2px;padding:0 2px">$1</mark>')
      return text.replace(
        regex,
        '<mark style="background:#fbbf24;color:#1e1e1e;border-radius:2px;padding:0 2px">$1</mark>'
      )
    }
    return match
  })
@@ -104,6 +114,8 @@ function highlightSearch(html: string, query: string): string {
const LOG_LEVEL_PATTERNS = [
  // 标准格式: [2026-01-07 12:00:00] [INFO]
  /^\[\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\] \[(DEBUG|INFO|WARNING|WARN|ERROR|CRITICAL)\]/i,
  // 扫描日志格式: [09:50:37] [INFO] [stage]
  /^\[\d{2}:\d{2}:\d{2}\] \[(DEBUG|INFO|WARNING|WARN|ERROR|CRITICAL)\]/i,
  // Prefect 格式: 12:01:50.419 | WARNING | prefect
  /^[\d:.]+\s+\|\s+(DEBUG|INFO|WARNING|WARN|ERROR|CRITICAL)\s+\|/i,
  // 简单格式: [INFO] message 或 INFO: message

@@ -1,7 +1,7 @@
"use client"

import * as React from "react"
import { useTranslations } from "next-intl"
import { useTranslations, useLocale } from "next-intl"
import {
  ColumnFiltersState,
  ColumnSizingState,
@@ -17,6 +17,7 @@ import {
  VisibilityState,
  Updater,
} from "@tanstack/react-table"
import { calculateColumnWidths } from "@/lib/table-utils"
import {
  IconChevronDown,
  IconLayoutColumns,
@@ -145,8 +146,12 @@ export function UnifiedDataTable<TData>({
  // Styles
  className,
  tableClassName,

  // Auto column sizing
  enableAutoColumnSizing = false,
}: UnifiedDataTableProps<TData>) {
  const tActions = useTranslations("common.actions")
  const locale = useLocale()

  // Internal state
  const [internalRowSelection, setInternalRowSelection] = React.useState<Record<string, boolean>>({})
@@ -154,6 +159,7 @@ export function UnifiedDataTable<TData>({
  const [internalSorting, setInternalSorting] = React.useState<SortingState>(defaultSorting)
  const [columnFilters, setColumnFilters] = React.useState<ColumnFiltersState>([])
  const [columnSizing, setColumnSizing] = React.useState<ColumnSizingState>({})
  const [autoSizingCalculated, setAutoSizingCalculated] = React.useState(false)
  const [internalPagination, setInternalPagination] = React.useState<PaginationState>({
    pageIndex: 0,
    pageSize: 10,
@@ -232,6 +238,41 @@ export function UnifiedDataTable<TData>({
    return (data || []).filter(item => item && typeof getRowId(item) !== 'undefined')
  }, [data, getRowId])

  // Auto column sizing: calculate optimal widths based on content
  React.useEffect(() => {
    if (!enableAutoColumnSizing || autoSizingCalculated || validData.length === 0) {
      return
    }

    // Build header labels from column meta
    const headerLabels: Record<string, string> = {}
    for (const col of columns) {
      const colDef = col as { accessorKey?: string; id?: string; meta?: { title?: string } }
      const colId = colDef.accessorKey || colDef.id
      if (colId && colDef.meta?.title) {
        headerLabels[colId] = colDef.meta.title
      }
    }

    const calculatedWidths = calculateColumnWidths({
      data: validData as Record<string, unknown>[],
      columns: columns as Array<{
        accessorKey?: string
        id?: string
        size?: number
        minSize?: number
        maxSize?: number
      }>,
      headerLabels,
      locale,
    })

    if (Object.keys(calculatedWidths).length > 0) {
      setColumnSizing(calculatedWidths)
      setAutoSizingCalculated(true)
    }
  }, [enableAutoColumnSizing, autoSizingCalculated, validData, columns])

  // Create table instance
  const table = useReactTable({
    data: validData,

19  frontend/hooks/use-version.ts  Normal file
@@ -0,0 +1,19 @@
import { useQuery } from '@tanstack/react-query'
import { VersionService } from '@/services/version.service'

export function useVersion() {
  return useQuery({
    queryKey: ['version'],
    queryFn: () => VersionService.getVersion(),
    staleTime: Infinity,
  })
}

export function useCheckUpdate() {
  return useQuery({
    queryKey: ['check-update'],
    queryFn: () => VersionService.checkUpdate(),
    enabled: false,  // 手动触发
    staleTime: 5 * 60 * 1000,  // 5 分钟缓存
  })
}

179  frontend/lib/table-utils.ts  Normal file
@@ -0,0 +1,179 @@
/**
 * Table utility functions
 * Provides column width calculation and other table-related utilities
 */

// Cache for text measurement context
let measureContext: CanvasRenderingContext2D | null = null

/**
 * Get or create a canvas context for measuring text width
 */
function getMeasureContext(): CanvasRenderingContext2D {
  if (!measureContext) {
    const canvas = document.createElement('canvas')
    measureContext = canvas.getContext('2d')!
  }
  return measureContext
}

/**
 * Measure text width using canvas
 * @param text - Text to measure
 * @param font - CSS font string (e.g., "14px Inter, sans-serif")
 * @returns Text width in pixels
 */
export function measureTextWidth(text: string, font: string = '14px Inter, system-ui, sans-serif'): number {
  const ctx = getMeasureContext()
  ctx.font = font
  return ctx.measureText(text).width
}

/**
 * Options for calculating column widths
 */
export interface CalculateColumnWidthsOptions<TData> {
  /** Table data */
  data: TData[]
  /** Column definitions with accessorKey */
  columns: Array<{
    accessorKey?: string
    id?: string
    size?: number
    minSize?: number
    maxSize?: number
    /** If true, skip auto-sizing for this column */
    enableAutoSize?: boolean
  }>
  /** Font to use for measurement */
  font?: string
  /** Padding to add to each cell (in pixels) */
  cellPadding?: number
  /** Header font (usually slightly different from cell font) */
  headerFont?: string
  /** Header labels for columns (keyed by accessorKey or id) */
  headerLabels?: Record<string, string>
  /** Maximum number of rows to sample (for performance) */
  maxSampleRows?: number
  /** Locale for date formatting */
  locale?: string
}

/**
 * Calculate optimal column widths based on content
 * Returns a map of column id -> calculated width
 */
/**
 * Check if a string looks like an ISO date
 */
function isISODateString(value: string): boolean {
  // Match ISO 8601 format: 2024-01-09T12:00:00.000Z or 2024-01-09T12:00:00
  return /^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}/.test(value)
}

/**
 * Format date for display (matching the app's date format)
 */
function formatDateForMeasurement(dateString: string, locale: string): string {
  try {
    return new Date(dateString).toLocaleString(locale, {
      year: "numeric",
      month: "numeric",
      day: "numeric",
      hour: "2-digit",
      minute: "2-digit",
      second: "2-digit",
      hour12: false,
    })
  } catch {
    return dateString
  }
}

export function calculateColumnWidths<TData extends Record<string, unknown>>({
  data,
  columns,
  font = '14px Inter, system-ui, sans-serif',
  cellPadding = 32,  // Default padding for cell content
  headerFont = '500 14px Inter, system-ui, sans-serif',
  headerLabels = {},
  maxSampleRows = 100,
  locale = 'zh-CN',
}: CalculateColumnWidthsOptions<TData>): Record<string, number> {
  const widths: Record<string, number> = {}

  // Sample data for performance (don't measure all rows if there are too many)
  const sampleData = data.slice(0, maxSampleRows)

  for (const column of columns) {
    const columnId = column.accessorKey || column.id
    if (!columnId) continue

    // Skip columns that explicitly disable auto-sizing
    if (column.enableAutoSize === false) {
      if (column.size) {
        widths[columnId] = column.size
      }
      continue
    }

    // Start with header width
    const headerLabel = headerLabels[columnId] || columnId
    let maxWidth = measureTextWidth(headerLabel, headerFont) + cellPadding

    // Measure content width for each row
    for (const row of sampleData) {
      const value = row[columnId]
      if (value == null) continue

      // Convert value to string for measurement
      let textValue: string
      if (typeof value === 'string') {
        // Check if it's a date string and format it
        if (isISODateString(value)) {
          textValue = formatDateForMeasurement(value, locale)
        } else {
          textValue = value
        }
      } else if (typeof value === 'number') {
        textValue = String(value)
      } else if (Array.isArray(value)) {
        // For arrays, join with comma (rough estimate)
        textValue = value.join(', ')
      } else if (typeof value === 'object') {
        // Skip complex objects - they need custom renderers
        continue
      } else {
        textValue = String(value)
      }

      const contentWidth = measureTextWidth(textValue, font) + cellPadding
      maxWidth = Math.max(maxWidth, contentWidth)
    }

    // Apply min/max constraints
    if (column.minSize) {
      maxWidth = Math.max(maxWidth, column.minSize)
    }
    if (column.maxSize) {
      maxWidth = Math.min(maxWidth, column.maxSize)
    }

    widths[columnId] = Math.ceil(maxWidth)
  }

  return widths
}

/**
 * Hook-friendly version that returns initial column sizing state
 */
export function getInitialColumnSizing<TData extends Record<string, unknown>>(
  options: CalculateColumnWidthsOptions<TData>
): Record<string, number> {
  // Only run on client side
  if (typeof window === 'undefined') {
    return {}
  }
  return calculateColumnWidths(options)
}

@@ -325,8 +325,7 @@
  "notifications": "Notifications",
  "apiKeys": "API Keys",
  "globalBlacklist": "Global Blacklist",
  "help": "Get Help",
  "feedback": "Feedback"
  "about": "About"
},
"search": {
  "title": "Asset Search",
@@ -1486,6 +1485,15 @@
  "requiredError": "Webhook URL is required when Discord is enabled",
  "urlInvalid": "Please enter a valid Discord Webhook URL"
},
"wecom": {
  "title": "WeCom",
  "description": "Push notifications to WeCom group bot",
  "webhookLabel": "Webhook URL",
  "webhookPlaceholder": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=...",
  "webhookHelp": "Add a bot in WeCom group and copy the Webhook URL",
  "requiredError": "Webhook URL is required when WeCom is enabled",
  "urlInvalid": "Please enter a valid WeCom Webhook URL"
},
"emailChannel": {
  "title": "Email",
  "description": "Receive notifications via email",
@@ -2283,5 +2291,21 @@
  "conflict": "Resource conflict, please check and try again",
  "unauthorized": "Please login first",
  "rateLimited": "Too many requests, please try again later"
},
"about": {
  "title": "About XingRin",
  "description": "Attack Surface Management Platform",
  "currentVersion": "Current Version",
  "latestVersion": "Latest Version",
  "checkUpdate": "Check Update",
  "checking": "Checking...",
  "checkFailed": "Failed to check update, please try again later",
  "updateAvailable": "Update Available",
  "upToDate": "Up to Date",
  "viewRelease": "View Release",
  "updateHint": "Run the following command in project root to update:",
  "changelog": "Changelog",
  "feedback": "Feedback",
  "docs": "Documentation"
}
}

@@ -325,8 +325,7 @@
  "notifications": "通知设置",
  "apiKeys": "API 密钥",
  "globalBlacklist": "全局黑名单",
  "help": "获取帮助",
  "feedback": "反馈建议"
  "about": "关于系统"
},
"search": {
  "title": "资产搜索",
@@ -1486,6 +1485,15 @@
  "requiredError": "启用 Discord 时必须填写 Webhook URL",
  "urlInvalid": "请输入有效的 Discord Webhook URL"
},
"wecom": {
  "title": "企业微信",
  "description": "将通知推送到企业微信群机器人",
  "webhookLabel": "Webhook URL",
  "webhookPlaceholder": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=...",
  "webhookHelp": "在企业微信群中添加机器人,复制 Webhook 地址",
  "requiredError": "启用企业微信时必须填写 Webhook URL",
  "urlInvalid": "请输入有效的企业微信 Webhook URL"
},
"emailChannel": {
  "title": "邮件",
  "description": "通过邮件接收通知",
@@ -2283,5 +2291,21 @@
  "conflict": "资源冲突,请检查后重试",
  "unauthorized": "请先登录",
  "rateLimited": "请求过于频繁,请稍后重试"
},
"about": {
  "title": "关于 XingRin",
  "description": "攻击面管理平台",
  "currentVersion": "当前版本",
  "latestVersion": "最新版本",
  "checkUpdate": "检查更新",
  "checking": "检查中...",
  "checkFailed": "检查更新失败,请稍后重试",
  "updateAvailable": "有更新",
  "upToDate": "已是最新",
  "viewRelease": "查看发布",
  "updateHint": "在项目根目录运行以下命令更新:",
  "changelog": "更新日志",
  "feedback": "问题反馈",
  "docs": "使用文档"
}
}

@@ -9,6 +9,10 @@ export const mockNotificationSettings: NotificationSettings = {
    enabled: true,
    webhookUrl: 'https://discord.com/api/webhooks/1234567890/abcdefghijklmnop',
  },
  wecom: {
    enabled: false,
    webhookUrl: '',
  },
  categories: {
    scan: true,
    vulnerability: true,
@@ -30,6 +34,7 @@ export function updateMockNotificationSettings(
  return {
    message: 'Notification settings updated successfully',
    discord: mockNotificationSettings.discord,
    wecom: mockNotificationSettings.wecom,
    categories: mockNotificationSettings.categories,
  }
}

14  frontend/services/version.service.ts  Normal file
@@ -0,0 +1,14 @@
import { api } from '@/lib/api-client'
import type { VersionInfo, UpdateCheckResult } from '@/types/version.types'

export class VersionService {
  static async getVersion(): Promise<VersionInfo> {
    const res = await api.get<VersionInfo>('/system/version/')
    return res.data
  }

  static async checkUpdate(): Promise<UpdateCheckResult> {
    const res = await api.get<UpdateCheckResult>('/system/check-update/')
    return res.data
  }
}

@@ -136,6 +136,10 @@ export interface UnifiedDataTableProps<TData> {
  // Styling
  className?: string
  tableClassName?: string

  // Auto column sizing
  /** Enable automatic column width calculation based on content */
  enableAutoColumnSizing?: boolean
}

/**

@@ -3,6 +3,11 @@ export interface DiscordSettings {
  webhookUrl: string
}

export interface WeComSettings {
  enabled: boolean
  webhookUrl: string
}

/** Notification category - corresponds to backend NotificationCategory */
export type NotificationCategory = 'scan' | 'vulnerability' | 'asset' | 'system'

@@ -16,6 +21,7 @@ export interface NotificationCategories {

export interface NotificationSettings {
  discord: DiscordSettings
  wecom: WeComSettings
  categories: NotificationCategories
}

@@ -26,5 +32,6 @@ export type UpdateNotificationSettingsRequest = NotificationSettings
export interface UpdateNotificationSettingsResponse {
  message: string
  discord: DiscordSettings
  wecom: WeComSettings
  categories: NotificationCategories
}

13  frontend/types/version.types.ts  Normal file
@@ -0,0 +1,13 @@
export interface VersionInfo {
  version: string
  githubRepo: string
}

export interface UpdateCheckResult {
  currentVersion: string
  latestVersion: string
  hasUpdate: boolean
  releaseUrl: string
  releaseNotes: string | null
  publishedAt: string | null
}

192  update.sh  Executable file
@@ -0,0 +1,192 @@
#!/bin/bash
# ============================================
# XingRin 系统更新脚本
# 用途:更新代码 + 同步版本 + 重建镜像 + 重启服务
# ============================================
#
# 更新流程:
#   1. 停止服务
#   2. git pull 拉取最新代码
#   3. 合并 .env 新配置项 + 同步 VERSION
#   4. 构建/拉取镜像(开发模式构建,生产模式拉取)
#   5. 启动服务(server 启动时自动执行数据库迁移)
#
# 用法:
#   sudo ./update.sh                      生产模式更新(拉取 Docker Hub 镜像)
#   sudo ./update.sh --dev                开发模式更新(本地构建镜像)
#   sudo ./update.sh --no-frontend        更新后只启动后端
#   sudo ./update.sh --dev --no-frontend  开发环境更新后只启动后端

cd "$(dirname "$0")"

# 权限检查
if [ "$EUID" -ne 0 ]; then
    printf "\033[0;31m✗ 请使用 sudo 运行此脚本\033[0m\n"
    printf "  正确用法: \033[1msudo ./update.sh\033[0m\n"
    exit 1
fi

# 跨平台 sed -i(兼容 macOS 和 Linux)
sed_inplace() {
    if [[ "$OSTYPE" == "darwin"* ]]; then
        sed -i '' "$@"
    else
        sed -i "$@"
    fi
}

# 解析参数判断模式
DEV_MODE=false
for arg in "$@"; do
    case $arg in
        --dev) DEV_MODE=true ;;
    esac
done

# 颜色定义
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
RED='\033[0;31m'
BLUE='\033[0;34m'
CYAN='\033[0;36m'
DIM='\033[2m'
BOLD='\033[1m'
NC='\033[0m'

# 日志函数
log_step() { printf "${CYAN}▶${NC} %s\n" "$1"; }
log_ok() { printf " ${GREEN}✓${NC} %s\n" "$1"; }
log_info() { printf " ${DIM}→${NC} %s\n" "$1"; }
log_warn() { printf " ${YELLOW}!${NC} %s\n" "$1"; }
log_error() { printf "${RED}✗${NC} %s\n" "$1"; }

# 合并 .env 新配置项(保留用户已有值)
merge_env_config() {
    local example_file="docker/.env.example"
    local env_file="docker/.env"

    if [ ! -f "$example_file" ] || [ ! -f "$env_file" ]; then
        return
    fi

    local new_keys=0

    while IFS= read -r line || [ -n "$line" ]; do
        [[ -z "$line" || "$line" =~ ^# ]] && continue
        local key="${line%%=*}"
        [[ -z "$key" || "$key" == "$line" ]] && continue

        if ! grep -q "^${key}=" "$env_file"; then
            printf '%s\n' "$line" >> "$env_file"
            log_info "新增配置: $key"
            ((new_keys++))
        fi
    done < "$example_file"

    if [ $new_keys -gt 0 ]; then
        log_ok "已添加 $new_keys 个新配置项"
    else
        log_ok "配置已是最新"
    fi
}

# 显示标题
printf "\n"
printf "${BOLD}${BLUE}┌────────────────────────────────────────┐${NC}\n"
if [ "$DEV_MODE" = true ]; then
    printf "${BOLD}${BLUE}│${NC} ${BOLD}XingRin 系统更新${NC} ${BOLD}${BLUE}│${NC}\n"
    printf "${BOLD}${BLUE}│${NC} ${DIM}开发模式 · 本地构建${NC} ${BOLD}${BLUE}│${NC}\n"
else
    printf "${BOLD}${BLUE}│${NC} ${BOLD}XingRin 系统更新${NC} ${BOLD}${BLUE}│${NC}\n"
    printf "${BOLD}${BLUE}│${NC} ${DIM}生产模式 · Docker Hub${NC} ${BOLD}${BLUE}│${NC}\n"
fi
printf "${BOLD}${BLUE}└────────────────────────────────────────┘${NC}\n"
printf "\n"

# 警告提示
printf "${YELLOW}┌─ 注意事项 ─────────────────────────────┐${NC}\n"
printf "${YELLOW}│${NC} • 此功能为测试性功能,可能导致升级失败 ${YELLOW}│${NC}\n"
printf "${YELLOW}│${NC} • 升级会覆盖所有默认引擎配置 ${YELLOW}│${NC}\n"
printf "${YELLOW}│${NC} • 自定义配置请先备份或创建新引擎 ${YELLOW}│${NC}\n"
printf "${YELLOW}│${NC} • 推荐:卸载后重新安装以获得最佳体验 ${YELLOW}│${NC}\n"
printf "${YELLOW}└────────────────────────────────────────┘${NC}\n"
printf "\n"

printf "${YELLOW}是否继续更新?${NC} [y/N] "
read -r ans_continue
ans_continue=${ans_continue:-N}

if [[ ! $ans_continue =~ ^[Yy]$ ]]; then
    printf "\n${DIM}已取消更新${NC}\n"
    exit 0
fi
printf "\n"

# Step 1: 停止服务
log_step "停止服务..."
./stop.sh 2>&1 | sed 's/^/  /'
log_ok "服务已停止"

# Step 2: 拉取代码
printf "\n"
log_step "拉取最新代码..."
if git pull --rebase 2>&1 | sed 's/^/  /'; then
    log_ok "代码已更新"
else
    log_error "git pull 失败,请手动解决冲突后重试"
    exit 1
fi

# Step 3: 检查配置更新 + 版本同步
printf "\n"
log_step "同步配置..."
merge_env_config

# 版本同步:从 VERSION 文件更新 IMAGE_TAG
if [ -f "VERSION" ]; then
    NEW_VERSION=$(cat VERSION | tr -d '[:space:]')
    if [ -n "$NEW_VERSION" ]; then
        if grep -q "^IMAGE_TAG=" "docker/.env"; then
            sed_inplace "s/^IMAGE_TAG=.*/IMAGE_TAG=$NEW_VERSION/" "docker/.env"
        else
            printf '%s\n' "IMAGE_TAG=$NEW_VERSION" >> "docker/.env"
        fi
        log_ok "版本同步: $NEW_VERSION"
    fi
fi

# Step 4: 构建/拉取镜像
printf "\n"
log_step "更新镜像..."

if [ "$DEV_MODE" = true ]; then
    # 开发模式:本地构建所有镜像(包括 Worker)
    log_info "构建 Worker 镜像..."

    # 读取 IMAGE_TAG
    IMAGE_TAG=$(grep "^IMAGE_TAG=" "docker/.env" | cut -d'=' -f2)
    if [ -z "$IMAGE_TAG" ]; then
        IMAGE_TAG="dev"
    fi

    # 构建 Worker 镜像(Worker 是临时容器,不在 compose 中,需要单独构建)
    docker build -t docker-worker -f docker/worker/Dockerfile . 2>&1 | sed 's/^/  /'
    docker tag docker-worker docker-worker:${IMAGE_TAG} 2>&1 | sed 's/^/  /'
    log_ok "Worker 镜像: docker-worker:${IMAGE_TAG}"

    log_info "其他服务镜像将在启动时构建"
else
    log_info "镜像将在启动时从 Docker Hub 拉取"
fi

# Step 5: 启动服务
printf "\n"
log_step "启动服务..."
./start.sh "$@"

# 完成提示
printf "\n"
printf "${GREEN}┌────────────────────────────────────────┐${NC}\n"
printf "${GREEN}│${NC} ${BOLD}${GREEN}✓${NC} ${BOLD}更新完成${NC} ${GREEN}│${NC}\n"
printf "${GREEN}└────────────────────────────────────────┘${NC}\n"
printf "\n"
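For reference, the key-merge idea behind `merge_env_config` in update.sh, restated as a Python sketch (same file paths, illustrative function; not project code):

```python
# 示意:把 docker/.env.example 中新增的键追加到 docker/.env,保留用户已有值。
from pathlib import Path


def merge_env(example: Path, env: Path) -> int:
    existing = {
        line.split("=", 1)[0]
        for line in env.read_text().splitlines()
        if "=" in line and not line.lstrip().startswith("#")
    }
    added = 0
    with env.open("a", encoding="utf-8") as out:
        for line in example.read_text().splitlines():
            if not line or line.lstrip().startswith("#") or "=" not in line:
                continue
            key = line.split("=", 1)[0]
            if key not in existing:
                out.write(line + "\n")
                added += 1
    return added


# merge_env(Path("docker/.env.example"), Path("docker/.env"))
```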