mirror of https://github.com/yyhuni/xingrin.git
synced 2026-02-01 04:03:23 +08:00

Compare commits: v1.5.5-dev ... v1.5.10-de (17 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 2d2ec93626 |  |
|  | ced9f811f4 |  |
|  | aa99b26f50 |  |
|  | 8342f196db |  |
|  | 1bd2a6ed88 |  |
|  | 033ff89aee |  |
|  | 4284a0cd9a |  |
|  | 943a4cb960 |  |
|  | eb2d853b76 |  |
|  | 1184c18b74 |  |
|  | 8a6f1b6f24 |  |
|  | 255d505aba |  |
|  | d06a9bab1f |  |
|  | 6d5c776bf7 |  |
|  | bf058dd67b |  |
|  | 0532d7c8b8 |  |
|  | 2ee9b5ffa2 |  |
README.md (129 lines changed)
@@ -1,7 +1,7 @@
 <h1 align="center">XingRin - 星环</h1>
 
 <p align="center">
-  <b>🛡️ Attack Surface Management (ASM) platform | Automated asset discovery and vulnerability scanning</b>
+  <b>Attack Surface Management (ASM) platform | Automated asset discovery and vulnerability scanning</b>
 </p>
 
 <p align="center">
@@ -12,29 +12,29 @@
 </p>
 
 <p align="center">
-  <a href="#-功能特性">Features</a> •
-  <a href="#-全局资产搜索">Asset Search</a> •
-  <a href="#-快速开始">Quick Start</a> •
-  <a href="#-文档">Docs</a> •
-  <a href="#-反馈与贡献">Feedback & Contributing</a>
+  <a href="#功能特性">Features</a> •
+  <a href="#全局资产搜索">Asset Search</a> •
+  <a href="#快速开始">Quick Start</a> •
+  <a href="#文档">Docs</a> •
+  <a href="#反馈与贡献">Feedback & Contributing</a>
 </p>
 
 <p align="center">
-  <sub>🔍 Keywords: ASM | Attack Surface Management | Vulnerability Scanning | Asset Discovery | Asset Search | Bug Bounty | Penetration Testing | Nuclei | Subdomain Enumeration | EASM</sub>
+  <sub>Keywords: ASM | Attack Surface Management | Vulnerability Scanning | Asset Discovery | Asset Search | Bug Bounty | Penetration Testing | Nuclei | Subdomain Enumeration | EASM</sub>
 </p>
 
 ---
 
-## 🌐 Live Demo
+## Live Demo
 
 **[https://xingrin.vercel.app/](https://xingrin.vercel.app/)**
 
-> ⚠️ For UI showcase only; not connected to a backend database
+> For UI showcase only; not connected to a backend database
 
 ---
 
 <p align="center">
-  <b>🎨 Modern UI</b>
+  <b>Modern UI</b>
 </p>
 
 <p align="center">
@@ -44,45 +44,45 @@
   <img src="docs/screenshots/quantum-rose.png" alt="Quantum Rose" width="24%">
 </p>
 
-## 📚 Documentation
+## Documentation
 
-- [📖 Technical Docs](./docs/README.md) - documentation index (🚧 work in progress)
-- [🚀 Quick Start](./docs/quick-start.md) - one-command install and deployment guide
-- [🔄 Version Management](./docs/version-management.md) - Git-tag-driven automated version management
-- [📦 Nuclei Template Architecture](./docs/nuclei-template-architecture.md) - template repository storage and sync
-- [📖 Wordlist Architecture](./docs/wordlist-architecture.md) - wordlist storage and sync
-- [🔍 Scan Flow Architecture](./docs/scan-flow-architecture.md) - full scan pipeline and tool orchestration
+- [Technical Docs](./docs/README.md) - documentation index (work in progress)
+- [Quick Start](./docs/quick-start.md) - one-command install and deployment guide
+- [Version Management](./docs/version-management.md) - Git-tag-driven automated version management
+- [Nuclei Template Architecture](./docs/nuclei-template-architecture.md) - template repository storage and sync
+- [Wordlist Architecture](./docs/wordlist-architecture.md) - wordlist storage and sync
+- [Scan Flow Architecture](./docs/scan-flow-architecture.md) - full scan pipeline and tool orchestration
 
 ---
 
-## ✨ Features
+## Features
 
 ### Scanning capabilities
 
 | Feature | Status | Tools | Notes |
 |------|------|------|------|
-| Subdomain scanning | ✅ | Subfinder, Amass, PureDNS | Passive collection + active brute forcing, 50+ aggregated data sources |
-| Port scanning | ✅ | Naabu | Custom port ranges |
-| Site discovery | ✅ | HTTPX | HTTP probing; auto-grabs title, status code, tech stack |
-| Fingerprinting | ✅ | XingFinger | 27,000+ fingerprint rules, multi-source fingerprint libraries |
-| URL collection | ✅ | Waymore, Katana | Historical data + active crawling |
-| Directory scanning | ✅ | FFUF | High-speed brute forcing, smart wordlists |
-| Vulnerability scanning | ✅ | Nuclei, Dalfox | 9,000+ POC templates, XSS detection |
-| Site screenshots | ✅ | Playwright | WebP high-compression storage |
+| Subdomain scanning | Done | Subfinder, Amass, PureDNS | Passive collection + active brute forcing, 50+ aggregated data sources |
+| Port scanning | Done | Naabu | Custom port ranges |
+| Site discovery | Done | HTTPX | HTTP probing; auto-grabs title, status code, tech stack |
+| Fingerprinting | Done | XingFinger | 27,000+ fingerprint rules, multi-source fingerprint libraries |
+| URL collection | Done | Waymore, Katana | Historical data + active crawling |
+| Directory scanning | Done | FFUF | High-speed brute forcing, smart wordlists |
+| Vulnerability scanning | Done | Nuclei, Dalfox | 9,000+ POC templates, XSS detection |
+| Site screenshots | Done | Playwright | WebP high-compression storage |
 
 ### Platform capabilities
 
 | Feature | Status | Notes |
 |------|------|------|
-| Target management | ✅ | Multi-level organization; domain and IP targets |
-| Asset snapshots | ✅ | Diff scan results to track asset changes |
-| Blacklist filtering | ✅ | Global + per-target; wildcard and CIDR support |
-| Scheduled tasks | ✅ | Cron expressions for automated periodic scans |
-| Distributed scanning | ✅ | Multiple Worker nodes with load-aware scheduling |
-| Global search | ✅ | Expression syntax, multi-field compound queries |
-| Notifications | ✅ | WeCom, Telegram, Discord |
-| API key management | ✅ | Visual configuration of each data source's API key |
+| Target management | Done | Multi-level organization; domain and IP targets |
+| Asset snapshots | Done | Diff scan results to track asset changes |
+| Blacklist filtering | Done | Global + per-target; wildcard and CIDR support |
+| Scheduled tasks | Done | Cron expressions for automated periodic scans |
+| Distributed scanning | Done | Multiple Worker nodes with load-aware scheduling |
+| Global search | Done | Expression syntax, multi-field compound queries |
+| Notifications | Done | WeCom, Telegram, Discord |
+| API key management | Done | Visual configuration of each data source's API key |
 
 ### Scan flow architecture
@@ -136,7 +136,7 @@ flowchart LR
 
 See the [scan flow architecture docs](./docs/scan-flow-architecture.md) for details
 
-### 🖥️ Distributed Architecture
+### Distributed Architecture
 - **Multi-node scanning** - deploy multiple Worker nodes to scale out scanning capacity
 - **Local node** - zero configuration; installation auto-registers a local Docker Worker
 - **Remote nodes** - one-command SSH deployment of a remote VPS as a scan node
@@ -181,7 +181,7 @@ flowchart TB
     W3 -.心跳上报.-> REDIS
 ```
 
-### 🔎 Global Asset Search
+### Global Asset Search
 - **Multi-type search** - covers both Website and Endpoint asset types
 - **Expression syntax** - `=` (fuzzy), `==` (exact) and `!=` (not equal) operators
 - **Logical composition** - combine clauses with `&&` (AND) and `||` (OR)
@@ -205,14 +205,14 @@ host="admin" && tech="php" && status=="200"
 url="/api/v1" && status!="404"
 ```
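The semantics of these operators (fuzzy `=`, exact `==`, `!=`, grouped by `&&`/`||`) can be sketched as a tiny in-memory evaluator. This is an illustration of the grammar only, not XingRin's actual parser; the sample record and its field names are assumptions:

```python
# Minimal sketch of the search-expression semantics; NOT the platform's real parser.
def match_clause(record: dict, clause: str) -> bool:
    """Evaluate one clause like host="admin" or status=="200"."""
    for op in ('==', '!=', '='):          # check longest operators first
        if op in clause:
            field, value = clause.split(op, 1)
            field, value = field.strip(), value.strip().strip('"')
            actual = str(record.get(field, ''))
            if op == '==':
                return actual == value     # exact match
            if op == '!=':
                return actual != value     # not equal
            return value in actual         # '=' is a fuzzy (substring) match
    return False

def match_expr(record: dict, expr: str) -> bool:
    """'&&' binds tighter than '||' in this sketch."""
    return any(
        all(match_clause(record, c.strip()) for c in group.split('&&'))
        for group in expr.split('||')
    )

site = {'host': 'admin.example.com', 'tech': 'php,nginx', 'status': '200'}
print(match_expr(site, 'host="admin" && tech="php" && status=="200"'))  # True
```

In the real platform these expressions are compiled into database filters rather than evaluated per record, but the operator precedence shown here matches the documented examples.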
 
-### 📊 Visual Interface
+### Visual Interface
 - **Dashboards** - asset and vulnerability statistics
 - **Real-time notifications** - WebSocket message push
 - **Push channels** - real-time WeCom, Telegram and Discord notification delivery
 
 ---
 
-## 📦 Quick Start
+## Quick Start
 
 ### Requirements
 
@@ -230,11 +230,11 @@ cd xingrin
 # Install and start (production mode)
 sudo ./install.sh
 
-# 🇨🇳 Mainland-China users: mirror acceleration is recommended (third-party mirrors may stop working; no long-term guarantee)
+# Mainland-China users: mirror acceleration is recommended (third-party mirrors may stop working; no long-term guarantee)
 sudo ./install.sh --mirror
 ```
 
-> **💡 Notes on --mirror**
+> **Notes on --mirror**
 > - Auto-configures Docker registry mirrors (mainland-China sources)
 > - Speeds up Git repository cloning (Nuclei templates, etc.)
@@ -259,17 +259,17 @@ sudo ./restart.sh
 sudo ./uninstall.sh
 ```
 
-## 🤝 Feedback & Contributing
+## Feedback & Contributing
 
-- 💡 **Found a bug or have ideas (UI design, feature design, etc.)?** Submit an [Issue](https://github.com/yyhuni/xingrin/issues) or DM the WeChat official account
+- **Found a bug or have ideas (UI design, feature design, etc.)?** Submit an [Issue](https://github.com/yyhuni/xingrin/issues) or DM the WeChat official account
 
-## 📧 Contact
+## Contact
 - WeChat official account: **塔罗安全学苑**
 - WeChat group: see the menu at the bottom of the official account; if the invite link has expired, DM me and I'll add you
 
 <img src="docs/wechat-qrcode.png" alt="微信公众号" width="200">
 
-### 🎁 Follow the official account for free fingerprint libraries
+### Follow the official account for free fingerprint libraries
 
 | Fingerprint library | Count |
 |--------|------|
@@ -278,9 +278,9 @@ sudo ./uninstall.sh
 | goby.json | 7,086 |
 | FingerprintHub.json | 3,147 |
 
-> 💡 Reply 「指纹」 to the official account to get them
+> Reply 「指纹」 to the official account to get them
 
-## ☕ Sponsorship
+## Sponsorship
 
 If this project helps you, consider buying me a Mixue drink; your stars and sponsorship are what keep the free updates coming
@@ -289,14 +289,8 @@ sudo ./uninstall.sh
   <img src="docs/zfb_pay.jpg" alt="支付宝" width="200">
 </p>
 
-### 🙏 Thanks to the following sponsors
-
-| Nickname | Amount |
-|------|------|
-| X(闭关中) | ¥88 |
-
-## ⚠️ Disclaimer
+## Disclaimer
 
 **Important: read carefully before use**
 
@@ -311,30 +305,29 @@ sudo ./uninstall.sh
 - Comply with the laws and regulations of your jurisdiction
 - Bear all consequences arising from misuse
 
-## 🌟 Star History
+## Star History
 
-If this project helps you, please give it a ⭐ Star!
+If this project helps you, please give it a Star!
 
 [](https://star-history.com/#yyhuni/xingrin&Date)
 
-## 📄 License
+## License
 
 This project is licensed under the [GNU General Public License v3.0](LICENSE).
 
 ### Permitted uses
 
-- ✅ Personal learning and research
-- ✅ Commercial and non-commercial use
-- ✅ Modification and distribution
-- ✅ Patent use
-- ✅ Private use
+- Personal learning and research
+- Commercial and non-commercial use
+- Modification and distribution
+- Patent use
+- Private use
 
 ### Obligations and restrictions
 
-- 📋 **Source obligation**: source code must be provided when distributing
-- 📋 **Same license**: derivative works must use the same license
-- 📋 **Copyright notice**: original copyright and license notices must be kept
-- ❌ **No warranty**: provided without any warranty
-- ❌ Unauthorized penetration testing
-- ❌ Any illegal activity
+- **Source obligation**: source code must be provided when distributing
+- **Same license**: derivative works must use the same license
+- **Copyright notice**: original copyright and license notices must be kept
+- **No warranty**: provided without any warranty
+- Unauthorized penetration testing
+- Any illegal activity
@@ -195,3 +195,32 @@ class DjangoHostPortMappingSnapshotRepository:
 
         for row in qs.iterator(chunk_size=batch_size):
             yield row
+
+    def iter_unique_host_ports_by_scan(
+        self,
+        scan_id: int,
+        batch_size: int = 1000
+    ) -> Iterator[dict]:
+        """
+        Stream the unique host:port combinations of a scan (deduplicated)
+
+        Used to avoid duplicates when generating URLs: the same host:port may
+        map to multiple IPs, but only one row is needed per URL.
+
+        Args:
+            scan_id: scan ID
+            batch_size: rows per batch
+
+        Yields:
+            {'host': 'example.com', 'port': 80}
+        """
+        qs = (
+            HostPortMappingSnapshot.objects
+            .filter(scan_id=scan_id)
+            .values('host', 'port')
+            .distinct()
+            .order_by('host', 'port')
+        )
+
+        for row in qs.iterator(chunk_size=batch_size):
+            yield row
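A consumer of `iter_unique_host_ports_by_scan` would typically turn the deduplicated `{'host', 'port'}` rows into candidate URLs. The sketch below shows one plausible shape; the scheme-by-port rule and the `build_urls` helper are illustrative assumptions, not the project's actual logic:

```python
from typing import Iterable, Iterator

def build_urls(rows: Iterable[dict]) -> Iterator[str]:
    """Turn deduplicated {'host','port'} rows into candidate URLs.
    Choosing https for 443/8443 is an illustrative assumption."""
    for row in rows:
        host, port = row['host'], row['port']
        scheme = 'https' if port in (443, 8443) else 'http'
        # Omit the default port of the chosen scheme to keep URLs canonical
        default = 443 if scheme == 'https' else 80
        suffix = '' if port == default else f':{port}'
        yield f'{scheme}://{host}{suffix}'

rows = [{'host': 'example.com', 'port': 80}, {'host': 'example.com', 'port': 8443}]
print(list(build_urls(rows)))  # ['http://example.com', 'https://example.com:8443']
```

Because both the repository method and this consumer are generators, the whole pipeline stays streaming: no full list of host:port pairs is ever materialized.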
@@ -146,7 +146,9 @@ class ScreenshotService:
         """
         from apps.asset.models import Screenshot, ScreenshotSnapshot
 
-        snapshots = ScreenshotSnapshot.objects.filter(scan_id=scan_id)
+        # Use iterator() so the QuerySet does not cache large BinaryField rows and spike memory
+        # chunk_size=50: load only 50 records at a time, releasing memory after each chunk
+        snapshots = ScreenshotSnapshot.objects.filter(scan_id=scan_id).iterator(chunk_size=50)
         count = 0
 
         for snapshot in snapshots:
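The memory win in this hunk comes from fetching rows in bounded chunks instead of caching the whole result set. A framework-free sketch of that pattern, with a fake paged "table" standing in for the snapshot rows (the helper names here are assumptions for the demo, not Django API):

```python
def fetch_in_chunks(fetch_page, chunk_size=50):
    """Yield rows page by page so only one chunk is resident at a time.
    Mimics the effect of Django's QuerySet.iterator(chunk_size=...)."""
    offset = 0
    while True:
        page = fetch_page(offset, chunk_size)
        if not page:
            return
        yield from page
        offset += len(page)

# Fake table standing in for ScreenshotSnapshot rows with binary payloads
TABLE = [{'id': i, 'blob': b'x' * 10} for i in range(120)]

def fake_page(offset, limit):
    return TABLE[offset:offset + limit]

count = sum(1 for _ in fetch_in_chunks(fake_page, chunk_size=50))
print(count)  # 120
```

With `chunk_size=50` the peak residency is one 50-row page regardless of how many screenshots a scan produced, which is exactly why the diff drops the plain queryset.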
@@ -1,72 +1,18 @@
 """Endpoint Snapshots Service - business logic layer"""
 
 import logging
-from typing import List, Iterator
+from typing import Iterator, List, Optional
 
+from apps.asset.dtos.snapshot import EndpointSnapshotDTO
 from apps.asset.repositories.snapshot import DjangoEndpointSnapshotRepository
 from apps.asset.services.asset import EndpointService
-from apps.asset.dtos.snapshot import EndpointSnapshotDTO
 
 logger = logging.getLogger(__name__)
 
 
 class EndpointSnapshotsService:
     """Endpoint snapshot service - unified management of snapshots and asset sync"""
 
-    def __init__(self):
-        self.snapshot_repo = DjangoEndpointSnapshotRepository()
-        self.asset_service = EndpointService()
-
-    def save_and_sync(self, items: List[EndpointSnapshotDTO]) -> None:
-        """
-        Save endpoint snapshots and sync them to the asset table (single entry point)
-
-        Flow:
-        1. Save to the snapshot table (full record)
-        2. Sync to the asset table (deduplicated)
-
-        Args:
-            items: list of endpoint snapshot DTOs (must carry target_id)
-
-        Raises:
-            ValueError: if an item's target_id is None
-            Exception: on database failure
-        """
-        if not items:
-            return
-
-        # Check the Scan still exists (guards against racy writes after deletion)
-        scan_id = items[0].scan_id
-        from apps.scan.repositories import DjangoScanRepository
-        if not DjangoScanRepository().exists(scan_id):
-            logger.warning("Scan deleted, skipping endpoint snapshot save - scan_id=%s, count=%d", scan_id, len(items))
-            return
-
-        try:
-            logger.debug("Saving endpoint snapshots and syncing to asset table - count: %d", len(items))
-
-            # Step 1: save to the snapshot table
-            logger.debug("Step 1: save to snapshot table")
-            self.snapshot_repo.save_snapshots(items)
-
-            # Step 2: convert to asset DTOs and save to the asset table
-            # Upsert semantics: insert new records, update existing ones
-            logger.debug("Step 2: sync to asset table (via the service layer)")
-            asset_items = [item.to_asset_dto() for item in items]
-
-            self.asset_service.bulk_upsert(asset_items)
-
-            logger.info("Endpoint snapshots and asset data saved - count: %d", len(items))
-
-        except Exception as e:
-            logger.error(
-                "Failed to save endpoint snapshots - count: %d, error: %s",
-                len(items),
-                str(e),
-                exc_info=True
-            )
-            raise
-
-
     # Smart-filter field mapping
     FILTER_FIELD_MAPPING = {
         'url': 'url',
@@ -76,26 +22,89 @@ class EndpointSnapshotsService:
         'webserver': 'webserver',
         'tech': 'tech',
     }
 
-    def get_by_scan(self, scan_id: int, filter_query: str = None):
+    def __init__(self):
+        self.snapshot_repo = DjangoEndpointSnapshotRepository()
+        self.asset_service = EndpointService()
+
+    def save_and_sync(self, items: List[EndpointSnapshotDTO]) -> None:
+        """
+        Save endpoint snapshots and sync them to the asset table (single entry point)
+
+        Flow:
+        1. Save to the snapshot table (full record)
+        2. Sync to the asset table (deduplicated)
+
+        Args:
+            items: list of endpoint snapshot DTOs (must carry target_id)
+
+        Raises:
+            ValueError: if an item's target_id is None
+            Exception: on database failure
+        """
+        if not items:
+            return
+
+        # Check the Scan still exists (guards against racy writes after deletion)
+        scan_id = items[0].scan_id
+        from apps.scan.repositories import DjangoScanRepository
+        if not DjangoScanRepository().exists(scan_id):
+            logger.warning("Scan deleted, skipping endpoint snapshot save - scan_id=%s, count=%d", scan_id, len(items))
+            return
+
+        try:
+            logger.debug("Saving endpoint snapshots and syncing to asset table - count: %d", len(items))
+
+            # Step 1: save to the snapshot table
+            self.snapshot_repo.save_snapshots(items)
+
+            # Step 2: convert to asset DTOs and save to the asset table (upsert)
+            asset_items = [item.to_asset_dto() for item in items]
+            self.asset_service.bulk_upsert(asset_items)
+
+            logger.info("Endpoint snapshots and asset data saved - count: %d", len(items))
+
+        except Exception as e:
+            logger.error("Failed to save endpoint snapshots - count: %d, error: %s", len(items), str(e), exc_info=True)
+            raise
+
+    def get_by_scan(self, scan_id: int, filter_query: Optional[str] = None):
+        """
+        Get the endpoint snapshots of a given scan
+
+        Args:
+            scan_id: scan ID
+            filter_query: filter query string
+
+        Returns:
+            QuerySet: endpoint snapshot queryset
+        """
         from apps.common.utils.filter_utils import apply_filters
 
         queryset = self.snapshot_repo.get_by_scan(scan_id)
         if filter_query:
             queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
         return queryset
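`FILTER_FIELD_MAPPING` translates user-facing filter fields into model columns before `apply_filters` builds the ORM query. A sketch of that translation step; the real `apply_filters` helper also handles `!=` and `||`, and the exact lookup suffixes used here (`__iexact`, `__icontains`) are assumptions:

```python
FILTER_FIELD_MAPPING = {'url': 'url', 'status': 'status_code', 'tech': 'tech'}  # illustrative subset

def to_orm_lookups(filter_query: str, mapping: dict) -> dict:
    """Translate clauses like status=="200" into ORM-style lookup kwargs.
    A sketch of the mapping idea, not the project's apply_filters implementation."""
    lookups = {}
    for clause in filter_query.split('&&'):
        clause = clause.strip()
        op = '==' if '==' in clause else '='
        field, value = clause.split(op, 1)
        column = mapping[field.strip()]
        value = value.strip().strip('"')
        # '=' maps to a case-insensitive contains lookup, '==' to an exact one
        lookups[f'{column}__iexact' if op == '==' else f'{column}__icontains'] = value
    return lookups

print(to_orm_lookups('url="/api" && status=="200"', FILTER_FIELD_MAPPING))
# {'url__icontains': '/api', 'status_code__iexact': '200'}
```

The mapping layer is what lets the public filter language stay stable even if column names like `status_code` change underneath.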
 
-    def get_all(self, filter_query: str = None):
-        """Get all endpoint snapshots"""
+    def get_all(self, filter_query: Optional[str] = None):
+        """
+        Get all endpoint snapshots
+
+        Args:
+            filter_query: filter query string
+
+        Returns:
+            QuerySet: endpoint snapshot queryset
+        """
         from apps.common.utils.filter_utils import apply_filters
 
         queryset = self.snapshot_repo.get_all()
         if filter_query:
             queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
         return queryset
 
     def iter_endpoint_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
-        """Stream all endpoint URLs of a scan."""
+        """Stream all endpoint URLs of a scan"""
         queryset = self.snapshot_repo.get_by_scan(scan_id)
         for snapshot in queryset.iterator(chunk_size=chunk_size):
             yield snapshot.url
@@ -103,10 +112,10 @@ class EndpointSnapshotsService:
     def iter_raw_data_for_csv_export(self, scan_id: int) -> Iterator[dict]:
         """
         Stream raw data for CSV export
 
         Args:
             scan_id: scan ID
 
         Yields:
             raw data dict
         """
@@ -91,3 +91,25 @@ class HostPortMappingSnapshotsService:
             raw data dict {ip, host, port, created_at}
         """
         return self.snapshot_repo.iter_raw_data_for_export(scan_id=scan_id)
+
+    def iter_unique_host_ports_by_scan(
+        self,
+        scan_id: int,
+        batch_size: int = 1000
+    ) -> Iterator[dict]:
+        """
+        Stream the unique host:port combinations of a scan (deduplicated)
+
+        Used to avoid duplicates when generating URLs.
+
+        Args:
+            scan_id: scan ID
+            batch_size: rows per batch
+
+        Yields:
+            {'host': 'example.com', 'port': 80}
+        """
+        return self.snapshot_repo.iter_unique_host_ports_by_scan(
+            scan_id=scan_id,
+            batch_size=batch_size
+        )
@@ -14,6 +14,7 @@ from .views import (
     LoginView, LogoutView, MeView, ChangePasswordView,
     SystemLogsView, SystemLogFilesView, HealthCheckView,
     GlobalBlacklistView,
+    VersionView, CheckUpdateView,
 )
 
 urlpatterns = [
@@ -29,6 +30,8 @@ urlpatterns = [
     # System administration
     path('system/logs/', SystemLogsView.as_view(), name='system-logs'),
     path('system/logs/files/', SystemLogFilesView.as_view(), name='system-log-files'),
+    path('system/version/', VersionView.as_view(), name='system-version'),
+    path('system/check-update/', CheckUpdateView.as_view(), name='system-check-update'),
 
     # Blacklist management (PUT full-replace mode)
     path('blacklist/rules/', GlobalBlacklistView.as_view(), name='blacklist-rules'),
@@ -6,16 +6,19 @@
 - Auth views: login, logout, user info, password change
 - System log views: real-time log viewing
 - Blacklist views: global blacklist rule management
+- Version views: system version and update checks
 """
 
 from .health_views import HealthCheckView
 from .auth_views import LoginView, LogoutView, MeView, ChangePasswordView
 from .system_log_views import SystemLogsView, SystemLogFilesView
 from .blacklist_views import GlobalBlacklistView
+from .version_views import VersionView, CheckUpdateView
 
 __all__ = [
     'HealthCheckView',
     'LoginView', 'LogoutView', 'MeView', 'ChangePasswordView',
     'SystemLogsView', 'SystemLogFilesView',
     'GlobalBlacklistView',
+    'VersionView', 'CheckUpdateView',
 ]

backend/apps/common/views/version_views.py (new file, 136 lines)

@@ -0,0 +1,136 @@
+"""
+System version views
+"""
+
+import logging
+from pathlib import Path
+
+import requests
+from rest_framework.request import Request
+from rest_framework.response import Response
+from rest_framework.views import APIView
+
+from apps.common.error_codes import ErrorCodes
+from apps.common.response_helpers import error_response, success_response
+
+logger = logging.getLogger(__name__)
+
+# GitHub repository info
+GITHUB_REPO = "yyhuni/xingrin"
+GITHUB_API_URL = f"https://api.github.com/repos/{GITHUB_REPO}/releases/latest"
+GITHUB_RELEASES_URL = f"https://github.com/{GITHUB_REPO}/releases"
+
+
+def get_current_version() -> str:
+    """Read the current version number"""
+    import os
+
+    # Option 1: from an environment variable (recommended inside Docker containers)
+    version = os.environ.get('IMAGE_TAG', '')
+    if version:
+        return version
+
+    # Option 2: from a file (development environments)
+    possible_paths = [
+        Path('/app/VERSION'),
+        Path(__file__).parent.parent.parent.parent.parent / 'VERSION',
+    ]
+
+    for path in possible_paths:
+        try:
+            return path.read_text(encoding='utf-8').strip()
+        except (FileNotFoundError, OSError):
+            continue
+
+    return "unknown"
+
+
+def compare_versions(current: str, latest: str) -> bool:
+    """
+    Compare version numbers and decide whether an update is available
+
+    Returns:
+        True if an update is available
+    """
+    def parse_version(v: str) -> tuple:
+        v = v.lstrip('v')
+        parts = v.split('.')
+        result = []
+        for part in parts:
+            if '-' in part:
+                num, _ = part.split('-', 1)
+                result.append(int(num))
+            else:
+                result.append(int(part))
+        return tuple(result)
+
+    try:
+        return parse_version(latest) > parse_version(current)
+    except (ValueError, AttributeError):
+        return False
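The comparison above strips a leading `v`, drops any `-suffix` per dotted part, and compares numeric tuples, falling back to `False` when parsing fails. Condensed here as a runnable copy to show the resulting ordering (for illustration only; the logic mirrors the diff above):

```python
def parse_version(v: str) -> tuple:
    """'v1.5.10-dev' -> (1, 5, 10): strip leading 'v', drop '-suffix' per part."""
    parts = []
    for part in v.lstrip('v').split('.'):
        parts.append(int(part.split('-', 1)[0]))
    return tuple(parts)

def compare_versions(current: str, latest: str) -> bool:
    """True when latest parses strictly greater than current; False on any parse error."""
    try:
        return parse_version(latest) > parse_version(current)
    except (ValueError, AttributeError):
        return False

print(compare_versions('v1.5.5-dev', 'v1.5.10-dev'))  # True: (1, 5, 10) > (1, 5, 5)
print(compare_versions('v1.5.5', 'not-a-version'))    # False: unparseable falls back
```

Tuple comparison is what makes `1.5.10` sort after `1.5.5`, which a plain string comparison would get wrong; pre-release suffixes are ignored entirely, so `v1.5.5` and `v1.5.5-dev` compare equal.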
+
+
+class VersionView(APIView):
+    """Return the current system version"""
+
+    def get(self, _request: Request) -> Response:
+        """Get current version info"""
+        return success_response(data={
+            'version': get_current_version(),
+            'github_repo': GITHUB_REPO,
+        })
+
+
+class CheckUpdateView(APIView):
+    """Check for system updates"""
+
+    def get(self, _request: Request) -> Response:
+        """
+        Check whether a newer version exists
+
+        Returns:
+            - current_version: current version
+            - latest_version: latest version
+            - has_update: whether an update is available
+            - release_url: release page URL
+            - release_notes: release notes (if any)
+        """
+        current_version = get_current_version()
+
+        try:
+            response = requests.get(
+                GITHUB_API_URL,
+                headers={'Accept': 'application/vnd.github.v3+json'},
+                timeout=10
+            )
+
+            if response.status_code == 404:
+                return success_response(data={
+                    'current_version': current_version,
+                    'latest_version': current_version,
+                    'has_update': False,
+                    'release_url': GITHUB_RELEASES_URL,
+                    'release_notes': None,
+                })
+
+            response.raise_for_status()
+            release_data = response.json()
+
+            latest_version = release_data.get('tag_name', current_version)
+            has_update = compare_versions(current_version, latest_version)
+
+            return success_response(data={
+                'current_version': current_version,
+                'latest_version': latest_version,
+                'has_update': has_update,
+                'release_url': release_data.get('html_url', GITHUB_RELEASES_URL),
+                'release_notes': release_data.get('body'),
+                'published_at': release_data.get('published_at'),
+            })
+
+        except requests.RequestException as e:
+            logger.warning("Update check failed: %s", e)
+            return error_response(
+                code=ErrorCodes.SERVER_ERROR,
+                message="Could not reach GitHub; please try again later",
+            )
@@ -2,8 +2,9 @@
 Initialize the default scan engines
 
 Usage:
-    python manage.py init_default_engine          # only create missing engines (no overwrite)
-    python manage.py init_default_engine --force  # force-overwrite all engine configs
+    python manage.py init_default_engine              # only create missing engines (no overwrite)
+    python manage.py init_default_engine --force      # force-overwrite all engine configs
+    python manage.py init_default_engine --force-sub  # only overwrite sub-engines, keep full scan
 
 cd /root/my-vulun-scan/docker
 docker compose exec server python backend/manage.py init_default_engine --force
@@ -12,6 +13,7 @@
 - Reads engine_config_example.yaml as the default configuration
 - Creates full scan (the default engine) plus one sub-engine per scan type
 - Existing configs are kept by default; pass --force to overwrite
+- Pass --force-sub to overwrite only sub-engine configs, keeping a user-customized full scan
 """
 
 from django.core.management.base import BaseCommand
@@ -30,11 +32,18 @@ class Command(BaseCommand):
         parser.add_argument(
             '--force',
             action='store_true',
-            help='Force-overwrite existing engine configs',
+            help='Force-overwrite existing engine configs (full scan and sub-engines)',
         )
+        parser.add_argument(
+            '--force-sub',
+            action='store_true',
+            help='Only overwrite sub-engine configs, keeping full scan (for upgrades)',
+        )
 
     def handle(self, *args, **options):
         force = options.get('force', False)
+        force_sub = options.get('force_sub', False)
 
         # Read the default configuration file
         config_path = Path(__file__).resolve().parent.parent.parent.parent / 'scan' / 'configs' / 'engine_config_example.yaml'
 
@@ -99,15 +108,22 @@ class Command(BaseCommand):
             engine_name = f"{scan_type}"
             sub_engine = ScanEngine.objects.filter(name=engine_name).first()
             if sub_engine:
-                if force:
+                # Both --force and --force-sub overwrite sub-engines
+                if force or force_sub:
                     sub_engine.configuration = single_yaml
                     sub_engine.save()
-                    self.stdout.write(self.style.SUCCESS(f'  ✓ Sub-engine {engine_name} config updated (ID: {sub_engine.id})'))
+                    self.stdout.write(self.style.SUCCESS(
+                        f'  ✓ Sub-engine {engine_name} config updated (ID: {sub_engine.id})'
+                    ))
                 else:
-                    self.stdout.write(self.style.WARNING(f'  ⊘ {engine_name} already exists, skipping (use --force to overwrite)'))
+                    self.stdout.write(self.style.WARNING(
+                        f'  ⊘ {engine_name} already exists, skipping (use --force to overwrite)'
+                    ))
             else:
                 sub_engine = ScanEngine.objects.create(
                     name=engine_name,
                     configuration=single_yaml,
                 )
-                self.stdout.write(self.style.SUCCESS(f'  ✓ Sub-engine {engine_name} created (ID: {sub_engine.id})'))
+                self.stdout.write(self.style.SUCCESS(
+                    f'  ✓ Sub-engine {engine_name} created (ID: {sub_engine.id})'
+                ))
@@ -449,34 +449,33 @@ class TaskDistributor:
     def execute_scan_flow(
         self,
         scan_id: int,
-        target_name: str,
         target_id: int,
+        target_name: str,
         scan_workspace_dir: str,
         engine_name: str,
        scheduled_scan_name: str | None = None,
     ) -> tuple[bool, str, Optional[str], Optional[int]]:
         """
         Run a scan Flow on a remote or local Worker
 
         Args:
             scan_id: scan task ID
-            target_name: target name
             target_id: target ID
+            target_name: target name
             scan_workspace_dir: scan workspace directory
             engine_name: engine name
             scheduled_scan_name: scheduled scan name (optional)
 
         Returns:
             (success, message, container_id, worker_id) tuple
 
         Note:
             engine_config is looked up from the database inside the Flow via scan_id
         """
         logger.info("="*60)
         logger.info("execute_scan_flow starting")
         logger.info("  scan_id: %s", scan_id)
-        logger.info("  target_name: %s", target_name)
         logger.info("  target_id: %s", target_id)
+        logger.info("  target_name: %s", target_name)
         logger.info("  scan_workspace_dir: %s", scan_workspace_dir)
         logger.info("  engine_name: %s", engine_name)
         logger.info("  docker_image: %s", self.docker_image)
@@ -495,23 +494,22 @@ class TaskDistributor:
         # 3. Build the docker run command
         script_args = {
             'scan_id': scan_id,
             'target_name': target_name,
             'target_id': target_id,
             'scan_workspace_dir': scan_workspace_dir,
             'engine_name': engine_name,
         }
         if scheduled_scan_name:
             script_args['scheduled_scan_name'] = scheduled_scan_name
 
         docker_cmd = self._build_docker_command(
             worker=worker,
             script_module='apps.scan.scripts.run_initiate_scan',
             script_args=script_args,
         )
 
         logger.info(
-            "Submitting scan task to Worker: %s - Scan ID: %d, Target: %s",
-            worker.name, scan_id, target_name
+            "Submitting scan task to Worker: %s - Scan ID: %d, Target: %s (ID: %d)",
+            worker.name, scan_id, target_name, target_id
         )
 
         # 4. Run docker run (directly for local Workers, via SSH for remote ones)
@@ -203,7 +203,7 @@ VULN_SCAN_COMMANDS = {
         # -silent: silent mode
         # -l: input URL list file
         # -t: template directory path (multiple repos supported; repeated -t is spliced in via template_args)
-        'base': "nuclei -j -silent -l '{endpoints_file}' {template_args}",
+        'base': "nuclei -j -silent -l '{input_file}' {template_args}",
         'optional': {
             'concurrency': '-c {concurrency}',  # concurrency (default 25)
             'rate_limit': '-rl {rate_limit}',   # requests-per-second limit
@@ -214,7 +214,12 @@ VULN_SCAN_COMMANDS = {
             'tags': '-tags {tags}',                  # filter tags
             'exclude_tags': '-etags {exclude_tags}', # exclude tags
         },
-        'input_type': 'endpoints_file',
+        # Multiple input types are supported; the user chooses via scan_endpoints/scan_websites
+        'input_types': ['endpoints_file', 'websites_file'],
+        'defaults': {
+            'scan_endpoints': False,  # endpoints scanning off by default
+            'scan_websites': True,    # websites scanning on by default
+        },
    },
 }
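The command spec above has a `base` template with placeholders plus `optional` flags that are appended only when configured. A simplified stand-in for the project's `build_scan_command` helper shows how such a spec turns into a shell command (the helper name and its exact merging rules are assumptions):

```python
NUCLEI_SPEC = {
    'base': "nuclei -j -silent -l '{input_file}' {template_args}",
    'optional': {'concurrency': '-c {concurrency}', 'rate_limit': '-rl {rate_limit}'},
}

def build_command(spec: dict, params: dict, config: dict) -> str:
    """Fill the base template, then append each optional flag present in config.
    Simplified stand-in for the project's build_scan_command helper."""
    cmd = spec['base'].format(**params)
    for key, flag in spec['optional'].items():
        if key in config:
            cmd += ' ' + flag.format(**{key: config[key]})
    return cmd

cmd = build_command(
    NUCLEI_SPEC,
    {'input_file': 'websites.txt', 'template_args': '-t nuclei-templates'},
    {'rate_limit': 150},
)
print(cmd)  # nuclei -j -silent -l 'websites.txt' -t nuclei-templates -rl 150
```

Renaming the placeholder from `{endpoints_file}` to `{input_file}` is what lets the same `base` string serve both input types listed in `input_types`.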
@@ -158,7 +158,9 @@ vuln_scan:
 
   nuclei:
     enabled: true
-    # timeout: auto  # computed automatically (from the endpoints line count)
+    # timeout: auto  # computed automatically (from the input URL line count)
+    scan-endpoints: false  # scan endpoints (off by default)
+    scan-websites: true    # scan websites (on by default)
     template-repo-names:  # template repo list; names match those under "Nuclei templates"
       - nuclei-templates
       # - nuclei-custom  # extra custom repos can be appended
@@ -107,7 +107,8 @@ def _get_max_workers(tool_config: dict, default: int = DEFAULT_MAX_WORKERS) -> i
 
 def _export_site_urls(
     target_id: int,
-    directory_scan_dir: Path
+    directory_scan_dir: Path,
+    provider,
 ) -> Tuple[str, int]:
     """
     Export all site URLs of the target to a file
@@ -115,6 +116,7 @@ def _export_site_urls(
     Args:
         target_id: target ID
         directory_scan_dir: directory-scan directory
+        provider: TargetProvider instance
 
     Returns:
         tuple: (sites_file, site_count)
@@ -123,9 +125,8 @@ def _export_site_urls(
 
     sites_file = str(directory_scan_dir / 'sites.txt')
     export_result = export_sites_task(
-        target_id=target_id,
         output_file=sites_file,
-        batch_size=1000
+        provider=provider,
     )
 
     site_count = export_result['total_count']
@@ -389,10 +390,10 @@ def _run_scans_concurrently(
 )
 def directory_scan_flow(
     scan_id: int,
-    target_name: str,
     target_id: int,
     scan_workspace_dir: str,
-    enabled_tools: dict
+    enabled_tools: dict,
+    provider,
 ) -> dict:
     """
     Directory-scan Flow
@@ -404,10 +405,10 @@ def directory_scan_flow(
 
     Args:
         scan_id: scan task ID
-        target_name: target name
         target_id: target ID
         scan_workspace_dir: scan workspace directory
         enabled_tools: enabled-tools configuration dict
+        provider: TargetProvider instance
 
     Returns:
         dict: scan results
@@ -415,6 +416,11 @@ def directory_scan_flow(
     try:
         wait_for_system_load(context="directory_scan_flow")
 
+        # Resolve target_name from the provider
+        target_name = provider.get_target_name()
+        if not target_name:
+            raise ValueError("Could not resolve the Target name")
+
         logger.info(
             "Starting directory scan - Scan ID: %s, Target: %s, Workspace: %s",
             scan_id, target_name, scan_workspace_dir
@@ -424,8 +430,6 @@ def directory_scan_flow(
         # Parameter validation
         if scan_id is None:
             raise ValueError("scan_id must not be empty")
-        if not target_name:
-            raise ValueError("target_name must not be empty")
         if target_id is None:
             raise ValueError("target_id must not be empty")
         if not scan_workspace_dir:
@@ -438,7 +442,9 @@ def directory_scan_flow(
         directory_scan_dir = setup_scan_directory(scan_workspace_dir, 'directory_scan')
 
         # Step 1: export site URLs
-        sites_file, site_count = _export_site_urls(target_id, directory_scan_dir)
+        sites_file, site_count = _export_site_urls(
+            target_id, directory_scan_dir, provider
+        )
 
         if site_count == 0:
             logger.warning("Skipping directory scan: no sites to scan - Scan ID: %s", scan_id)
@@ -11,8 +11,10 @@
 """
 
 import logging
+from dataclasses import dataclass
 from datetime import datetime
 from pathlib import Path
+from typing import Optional
 
 from prefect import flow
 
@@ -22,183 +24,147 @@ from apps.scan.handlers.scan_flow_handlers import (
     on_scan_flow_running,
 )
 from apps.scan.tasks.fingerprint_detect import (
-    export_urls_for_fingerprint_task,
+    export_site_urls_for_fingerprint_task,
     run_xingfinger_and_stream_update_tech_task,
 )
-from apps.scan.utils import build_scan_command, user_log, wait_for_system_load
+from apps.scan.utils import build_scan_command, setup_scan_directory, user_log, wait_for_system_load
 from apps.scan.utils.fingerprint_helpers import get_fingerprint_paths
 
 logger = logging.getLogger(__name__)
 
 
+@dataclass
+class FingerprintContext:
+    """Fingerprinting context passed between the helper functions"""
+    scan_id: int
+    target_id: int
+    target_name: str
+    scan_workspace_dir: str
+    fingerprint_dir: Optional[Path] = None
+    urls_file: str = ""
+    url_count: int = 0
+    source: str = "website"
+
+
 def calculate_fingerprint_detect_timeout(
     url_count: int,
     base_per_url: float = 10.0,
     min_timeout: int = 300
 ) -> int:
-    """
-    Compute the timeout from the URL count
-
-    Formula: timeout = URL count × base time per URL
-    Minimum 300 seconds, no upper bound
-
-    Args:
-        url_count: number of URLs
-        base_per_url: base seconds per URL (default 10)
-        min_timeout: minimum timeout in seconds (default 300)
-
-    Returns:
-        int: the computed timeout in seconds
-    """
+    """Compute the timeout from the URL count (minimum 300 seconds)"""
     return max(min_timeout, int(url_count * base_per_url))
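The timeout rule is linear in the URL count with a 300-second floor and no upper bound. Copied here as a runnable check of the boundary behavior (a copy of the function above, for illustration):

```python
def calculate_fingerprint_detect_timeout(url_count, base_per_url=10.0, min_timeout=300):
    """Timeout scales linearly with URL count, floored at 300 seconds."""
    return max(min_timeout, int(url_count * base_per_url))

# Below 30 URLs the floor dominates; above it, 10 seconds per URL applies
print(calculate_fingerprint_detect_timeout(10))   # 300
print(calculate_fingerprint_detect_timeout(500))  # 5000
```

The crossover sits at 30 URLs with the defaults (30 × 10 s = 300 s), so small scans always get at least five minutes while large scans scale without a cap.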
|
||||

def _export_urls(
    target_id: int,
    fingerprint_dir: Path,
    source: str = 'website'
) -> tuple[str, int]:
    """
    导出 URL 到文件

    Args:
        target_id: 目标 ID
        fingerprint_dir: 指纹识别目录
        source: 数据源类型

    Returns:
        tuple: (urls_file, total_count)
    """
    logger.info("Step 1: 导出 URL 列表 (source=%s)", source)
def _export_urls(fingerprint_dir: Path, provider) -> tuple[str, int]:
    """导出 URL 到文件,返回 (urls_file, total_count)"""
    logger.info("Step 1: 导出 URL 列表")

    urls_file = str(fingerprint_dir / 'urls.txt')
    export_result = export_urls_for_fingerprint_task(
        target_id=target_id,
    export_result = export_site_urls_for_fingerprint_task(
        output_file=urls_file,
        source=source,
        batch_size=1000
        provider=provider,
    )

    total_count = export_result['total_count']
    logger.info(
        "✓ URL 导出完成 - 文件: %s, 数量: %d",
        export_result['output_file'],
        total_count
    )
    logger.info("✓ URL 导出完成 - 文件: %s, 数量: %d", export_result['output_file'], total_count)

    return export_result['output_file'], total_count


def _run_fingerprint_detect(
    enabled_tools: dict,
    urls_file: str,
    url_count: int,
    fingerprint_dir: Path,
    scan_id: int,
    target_id: int,
    source: str
) -> tuple[dict, list]:
    """
    执行指纹识别任务
def _run_single_tool(
    tool_name: str,
    tool_config: dict,
    ctx: FingerprintContext
) -> tuple[Optional[dict], Optional[dict]]:
    """执行单个指纹识别工具,返回 (stats, failed_info)"""
    # 获取指纹库路径
    lib_names = tool_config.get('fingerprint_libs', ['ehole'])
    fingerprint_paths = get_fingerprint_paths(lib_names)

    Args:
        enabled_tools: 已启用的工具配置字典
        urls_file: URL 文件路径
        url_count: URL 总数
        fingerprint_dir: 指纹识别目录
        scan_id: 扫描任务 ID
        target_id: 目标 ID
        source: 数据源类型
    if not fingerprint_paths:
        reason = f"没有可用的指纹库: {lib_names}"
        logger.warning(reason)
        return None, {'tool': tool_name, 'reason': reason}

    Returns:
        tuple: (tool_stats, failed_tools)
    """
    # 构建命令
    tool_config_with_paths = {**tool_config, **fingerprint_paths}
    try:
        command = build_scan_command(
            tool_name=tool_name,
            scan_type='fingerprint_detect',
            command_params={'urls_file': ctx.urls_file},
            tool_config=tool_config_with_paths
        )
    except Exception as e:
        reason = f"命令构建失败: {e}"
        logger.error("构建 %s 命令失败: %s", tool_name, e)
        return None, {'tool': tool_name, 'reason': reason}

    # 计算超时时间和日志文件
    timeout = calculate_fingerprint_detect_timeout(ctx.url_count)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = ctx.fingerprint_dir / f"{tool_name}_{timestamp}.log"

    logger.info(
        "开始执行 %s 指纹识别 - URL数: %d, 超时: %ds, 指纹库: %s",
        tool_name, ctx.url_count, timeout, list(fingerprint_paths.keys())
    )
    user_log(ctx.scan_id, "fingerprint_detect", f"Running {tool_name}: {command}")

    # 执行扫描任务
    try:
        result = run_xingfinger_and_stream_update_tech_task(
            cmd=command,
            tool_name=tool_name,
            scan_id=ctx.scan_id,
            target_id=ctx.target_id,
            source=ctx.source,
            cwd=str(ctx.fingerprint_dir),
            timeout=timeout,
            log_file=str(log_file),
            batch_size=100
        )

        stats = {
            'command': command,
            'result': result,
            'timeout': timeout,
            'fingerprint_libs': list(fingerprint_paths.keys())
        }

        tool_updated = result.get('updated_count', 0)
        logger.info(
            "✓ 工具 %s 执行完成 - 处理记录: %d, 更新: %d, 未找到: %d",
            tool_name,
            result.get('processed_records', 0),
            tool_updated,
            result.get('not_found_count', 0)
        )
        user_log(
            ctx.scan_id, "fingerprint_detect",
            f"{tool_name} completed: identified {tool_updated} fingerprints"
        )
        return stats, None

    except Exception as exc:
        reason = str(exc)
        logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True)
        user_log(ctx.scan_id, "fingerprint_detect", f"{tool_name} failed: {reason}", "error")
        return None, {'tool': tool_name, 'reason': reason}


def _run_fingerprint_detect(enabled_tools: dict, ctx: FingerprintContext) -> tuple[dict, list]:
    """执行指纹识别任务,返回 (tool_stats, failed_tools)"""
    tool_stats = {}
    failed_tools = []

    for tool_name, tool_config in enabled_tools.items():
        # 1. 获取指纹库路径
        lib_names = tool_config.get('fingerprint_libs', ['ehole'])
        fingerprint_paths = get_fingerprint_paths(lib_names)

        if not fingerprint_paths:
            reason = f"没有可用的指纹库: {lib_names}"
            logger.warning(reason)
            failed_tools.append({'tool': tool_name, 'reason': reason})
            continue

        # 2. 将指纹库路径合并到 tool_config(用于命令构建)
        tool_config_with_paths = {**tool_config, **fingerprint_paths}

        # 3. 构建命令
        try:
            command = build_scan_command(
                tool_name=tool_name,
                scan_type='fingerprint_detect',
                command_params={'urls_file': urls_file},
                tool_config=tool_config_with_paths
            )
        except Exception as e:
            reason = f"命令构建失败: {e}"
            logger.error("构建 %s 命令失败: %s", tool_name, e)
            failed_tools.append({'tool': tool_name, 'reason': reason})
            continue

        # 4. 计算超时时间
        timeout = calculate_fingerprint_detect_timeout(url_count)

        # 5. 生成日志文件路径
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_file = fingerprint_dir / f"{tool_name}_{timestamp}.log"

        logger.info(
            "开始执行 %s 指纹识别 - URL数: %d, 超时: %ds, 指纹库: %s",
            tool_name, url_count, timeout, list(fingerprint_paths.keys())
        )
        user_log(scan_id, "fingerprint_detect", f"Running {tool_name}: {command}")

        # 6. 执行扫描任务
        try:
            result = run_xingfinger_and_stream_update_tech_task(
                cmd=command,
                tool_name=tool_name,
                scan_id=scan_id,
                target_id=target_id,
                source=source,
                cwd=str(fingerprint_dir),
                timeout=timeout,
                log_file=str(log_file),
                batch_size=100
            )

            tool_stats[tool_name] = {
                'command': command,
                'result': result,
                'timeout': timeout,
                'fingerprint_libs': list(fingerprint_paths.keys())
            }

            tool_updated = result.get('updated_count', 0)
            logger.info(
                "✓ 工具 %s 执行完成 - 处理记录: %d, 更新: %d, 未找到: %d",
                tool_name,
                result.get('processed_records', 0),
                tool_updated,
                result.get('not_found_count', 0)
            )
            user_log(
                scan_id, "fingerprint_detect",
                f"{tool_name} completed: identified {tool_updated} fingerprints"
            )

        except Exception as exc:
            reason = str(exc)
            failed_tools.append({'tool': tool_name, 'reason': reason})
            logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True)
            user_log(scan_id, "fingerprint_detect", f"{tool_name} failed: {reason}", "error")
        stats, failed_info = _run_single_tool(tool_name, tool_config, ctx)
        if stats:
            tool_stats[tool_name] = stats
        if failed_info:
            failed_tools.append(failed_info)

    if failed_tools:
        logger.warning(
@@ -209,6 +175,24 @@ def _run_fingerprint_detect(
    return tool_stats, failed_tools


def _aggregate_results(tool_stats: dict) -> dict:
    """汇总所有工具的结果"""
    return {
        'processed_records': sum(
            s['result'].get('processed_records', 0) for s in tool_stats.values()
        ),
        'updated_count': sum(
            s['result'].get('updated_count', 0) for s in tool_stats.values()
        ),
        'created_count': sum(
            s['result'].get('created_count', 0) for s in tool_stats.values()
        ),
        'snapshot_count': sum(
            s['result'].get('snapshot_count', 0) for s in tool_stats.values()
        ),
    }
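The new `_aggregate_results` helper reduces the per-tool result dicts into one summary. A minimal standalone sketch of the same reduction (tool names and numbers here are made up for illustration):

```python
# Hypothetical tool_stats shape: each entry wraps a 'result' dict with counters.
tool_stats = {
    'xingfinger': {'result': {'processed_records': 120, 'updated_count': 45,
                              'created_count': 3, 'snapshot_count': 120}},
    'ehole':      {'result': {'processed_records': 80, 'updated_count': 20,
                              'created_count': 1, 'snapshot_count': 80}},
}

def aggregate(stats: dict) -> dict:
    keys = ('processed_records', 'updated_count', 'created_count', 'snapshot_count')
    # Missing counters default to 0, matching the .get(..., 0) calls above.
    return {k: sum(s['result'].get(k, 0) for s in stats.values()) for k in keys}

print(aggregate(tool_stats))  # processed_records == 200, updated_count == 65
```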

@flow(
    name="fingerprint_detect",
    log_prints=True,
@@ -218,10 +202,10 @@ def _run_fingerprint_detect(
)
def fingerprint_detect_flow(
    scan_id: int,
    target_name: str,
    target_id: int,
    scan_workspace_dir: str,
    enabled_tools: dict
    enabled_tools: dict,
    provider,
) -> dict:
    """
    指纹识别 Flow
@@ -230,57 +214,45 @@ def fingerprint_detect_flow(
    1. 从数据库导出目标下所有 WebSite URL 到文件
    2. 使用 xingfinger 进行技术栈识别
    3. 解析结果并更新 WebSite.tech 字段(合并去重)

    工作流程:
    Step 0: 创建工作目录
    Step 1: 导出 URL 列表
    Step 2: 解析配置,获取启用的工具
    Step 3: 执行 xingfinger 并解析结果

    Args:
        scan_id: 扫描任务 ID
        target_name: 目标名称
        target_id: 目标 ID
        scan_workspace_dir: 扫描工作空间目录
        enabled_tools: 启用的工具配置(xingfinger)

    Returns:
        dict: 扫描结果
    """
    try:
        # 负载检查:等待系统资源充足
        wait_for_system_load(context="fingerprint_detect_flow")

        # 从 provider 获取 target_name
        target_name = provider.get_target_name()
        if not target_name:
            raise ValueError("无法获取 Target 名称")

        # 参数验证
        if scan_id is None:
            raise ValueError("scan_id 不能为空")
        if target_id is None:
            raise ValueError("target_id 不能为空")
        if not scan_workspace_dir:
            raise ValueError("scan_workspace_dir 不能为空")

        logger.info(
            "开始指纹识别 - Scan ID: %s, Target: %s, Workspace: %s",
            scan_id, target_name, scan_workspace_dir
        )
        user_log(scan_id, "fingerprint_detect", "Starting fingerprint detection")

        # 参数验证
        if scan_id is None:
            raise ValueError("scan_id 不能为空")
        if not target_name:
            raise ValueError("target_name 不能为空")
        if target_id is None:
            raise ValueError("target_id 不能为空")
        if not scan_workspace_dir:
            raise ValueError("scan_workspace_dir 不能为空")
        # 创建上下文
        ctx = FingerprintContext(
            scan_id=scan_id,
            target_id=target_id,
            target_name=target_name,
            scan_workspace_dir=scan_workspace_dir,
            fingerprint_dir=setup_scan_directory(scan_workspace_dir, 'fingerprint_detect')
        )

        # 数据源类型(当前只支持 website)
        source = 'website'
        # Step 1: 导出 URL
        ctx.urls_file, ctx.url_count = _export_urls(ctx.fingerprint_dir, provider)

        # Step 0: 创建工作目录
        from apps.scan.utils import setup_scan_directory
        fingerprint_dir = setup_scan_directory(scan_workspace_dir, 'fingerprint_detect')

        # Step 1: 导出 URL(支持懒加载)
        urls_file, url_count = _export_urls(target_id, fingerprint_dir, source)

        if url_count == 0:
        if ctx.url_count == 0:
            logger.warning("跳过指纹识别:没有 URL 可扫描 - Scan ID: %s", scan_id)
            user_log(scan_id, "fingerprint_detect", "Skipped: no URLs to scan", "warning")
            return _build_empty_result(scan_id, target_name, scan_workspace_dir, urls_file)
            return _build_empty_result(scan_id, target_name, scan_workspace_dir, ctx.urls_file)

        # Step 2: 工具配置信息
        logger.info("Step 2: 工具配置信息")
@@ -288,57 +260,30 @@ def fingerprint_detect_flow(

        # Step 3: 执行指纹识别
        logger.info("Step 3: 执行指纹识别")
        tool_stats, failed_tools = _run_fingerprint_detect(
            enabled_tools=enabled_tools,
            urls_file=urls_file,
            url_count=url_count,
            fingerprint_dir=fingerprint_dir,
            scan_id=scan_id,
            target_id=target_id,
            source=source
        )
        tool_stats, failed_tools = _run_fingerprint_detect(enabled_tools, ctx)

        # 动态生成已执行的任务列表
        executed_tasks = ['export_urls_for_fingerprint']
        executed_tasks.extend([f'run_xingfinger ({tool})' for tool in tool_stats])
        # 汇总结果
        totals = _aggregate_results(tool_stats)
        failed_tool_names = {f['tool'] for f in failed_tools}
        successful_tools = [name for name in enabled_tools if name not in failed_tool_names]

        # 汇总所有工具的结果
        total_processed = sum(
            stats['result'].get('processed_records', 0) for stats in tool_stats.values()
        )
        total_updated = sum(
            stats['result'].get('updated_count', 0) for stats in tool_stats.values()
        )
        total_created = sum(
            stats['result'].get('created_count', 0) for stats in tool_stats.values()
        )
        total_snapshots = sum(
            stats['result'].get('snapshot_count', 0) for stats in tool_stats.values()
        )

        # 记录 Flow 完成
        logger.info("✓ 指纹识别完成 - 识别指纹: %d", total_updated)
        logger.info("✓ 指纹识别完成 - 识别指纹: %d", totals['updated_count'])
        user_log(
            scan_id, "fingerprint_detect",
            f"fingerprint_detect completed: identified {total_updated} fingerprints"
            f"fingerprint_detect completed: identified {totals['updated_count']} fingerprints"
        )

        successful_tools = [
            name for name in enabled_tools
            if name not in [f['tool'] for f in failed_tools]
        ]
        executed_tasks = ['export_site_urls_for_fingerprint']
        executed_tasks.extend([f'run_xingfinger ({tool})' for tool in tool_stats])

        return {
            'success': True,
            'scan_id': scan_id,
            'target': target_name,
            'scan_workspace_dir': scan_workspace_dir,
            'urls_file': urls_file,
            'url_count': url_count,
            'processed_records': total_processed,
            'updated_count': total_updated,
            'created_count': total_created,
            'snapshot_count': total_snapshots,
            'urls_file': ctx.urls_file,
            'url_count': ctx.url_count,
            **totals,
            'executed_tasks': executed_tasks,
            'tool_stats': {
                'total': len(enabled_tools),
@@ -379,7 +324,7 @@ def _build_empty_result(
        'updated_count': 0,
        'created_count': 0,
        'snapshot_count': 0,
        'executed_tasks': ['export_urls_for_fingerprint'],
        'executed_tasks': ['export_site_urls_for_fingerprint'],
        'tool_stats': {
            'total': 0,
            'successful': 0,

@@ -7,6 +7,7 @@
- 使用 FlowOrchestrator 解析 YAML 配置
- 在 Prefect Flow 中执行子 Flow(Subflow)
- 按照 YAML 顺序编排工作流
- 根据 scan_mode 创建对应的 Provider
- 不包含具体业务逻辑(由 Tasks 和 FlowOrchestrator 实现)

架构:
@@ -18,20 +19,20 @@

# Django 环境初始化(导入即生效)
# 注意:动态扫描容器应使用 run_initiate_scan.py 启动,以便在导入前设置环境变量
from apps.common.prefect_django_setup import setup_django_for_prefect
import apps.common.prefect_django_setup  # noqa: F401

import logging

from prefect import flow, task
from pathlib import Path
import logging
from prefect.futures import wait

from apps.scan.handlers import (
    on_initiate_scan_flow_running,
    on_initiate_scan_flow_completed,
    on_initiate_scan_flow_failed,
)
from prefect.futures import wait
from apps.scan.utils import setup_scan_workspace
from apps.scan.orchestrators import FlowOrchestrator
from apps.scan.utils import setup_scan_workspace

logger = logging.getLogger(__name__)

@@ -43,6 +44,75 @@ def _run_subflow_task(scan_type: str, flow_func, flow_kwargs: dict):
    return flow_func(**flow_kwargs)


def _create_provider(scan, target_id: int, scan_id: int):
    """根据 scan_mode 创建对应的 Provider"""
    from apps.scan.models import Scan
    from apps.scan.providers import (
        DatabaseTargetProvider,
        SnapshotTargetProvider,
        ProviderContext,
    )

    provider_context = ProviderContext(target_id=target_id, scan_id=scan_id)

    if scan.scan_mode == Scan.ScanMode.QUICK:
        provider = SnapshotTargetProvider(scan_id=scan_id, context=provider_context)
        logger.info("✓ 快速扫描模式 - 创建 SnapshotTargetProvider")
    else:
        provider = DatabaseTargetProvider(target_id=target_id, context=provider_context)
        logger.info("✓ 完整扫描模式 - 使用 DatabaseTargetProvider")

    return provider


def _execute_sequential_flows(valid_flows: list, results: dict, executed_flows: list):
    """顺序执行 Flow 列表"""
    for scan_type, flow_func, flow_kwargs in valid_flows:
        logger.info("=" * 60)
        logger.info("执行 Flow: %s", scan_type)
        logger.info("=" * 60)
        try:
            result = flow_func(**flow_kwargs)
            executed_flows.append(scan_type)
            results[scan_type] = result
            logger.info("✓ %s 执行成功", scan_type)
        except Exception as e:
            logger.warning("%s 执行失败: %s", scan_type, e)
            executed_flows.append(f"{scan_type} (失败)")
            results[scan_type] = {'success': False, 'error': str(e)}


def _execute_parallel_flows(valid_flows: list, results: dict, executed_flows: list):
    """并行执行 Flow 列表"""
    futures = []
    for scan_type, flow_func, flow_kwargs in valid_flows:
        logger.info("=" * 60)
        logger.info("提交并行子 Flow 任务: %s", scan_type)
        logger.info("=" * 60)
        future = _run_subflow_task.submit(
            scan_type=scan_type,
            flow_func=flow_func,
            flow_kwargs=flow_kwargs,
        )
        futures.append((scan_type, future))

    if not futures:
        return

    wait([f for _, f in futures])

    for scan_type, future in futures:
        try:
            result = future.result()
            executed_flows.append(scan_type)
            results[scan_type] = result
            logger.info("✓ %s 执行成功", scan_type)
        except Exception as e:
            logger.warning("%s 执行失败: %s", scan_type, e)
            executed_flows.append(f"{scan_type} (失败)")
            results[scan_type] = {'success': False, 'error': str(e)}
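Both stage runners above share the same contract: a failed sub-flow is recorded but must not abort the scan. A minimal standalone sketch of that pattern (names here are illustrative, not the repository's):

```python
# Run callables in order; collect results, convert exceptions into
# {'success': False, ...} entries instead of propagating them.
def run_flows(flows):
    results, executed = {}, []
    for name, func in flows:
        try:
            results[name] = func()
            executed.append(name)
        except Exception as e:  # a failed stage must not stop later stages
            results[name] = {'success': False, 'error': str(e)}
            executed.append(f"{name} (failed)")
    return results, executed

def ok():
    return {'success': True}

def boom():
    raise RuntimeError("tool crashed")

results, executed = run_flows([('port_scan', ok), ('url_fetch', boom)])
print(executed)  # ['port_scan', 'url_fetch (failed)']
```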

@flow(
    name='initiate_scan',
    description='扫描任务初始化流程',
@@ -53,15 +123,14 @@ def _run_subflow_task(scan_type: str, flow_func, flow_kwargs: dict):
)
def initiate_scan_flow(
    scan_id: int,
    target_name: str,
    target_id: int,
    scan_workspace_dir: str,
    engine_name: str,
    scheduled_scan_name: str | None = None,
    scheduled_scan_name: str | None = None,  # noqa: ARG001
) -> dict:
    """
    初始化扫描任务(动态工作流编排)

    根据 YAML 配置动态编排工作流:
    - 从数据库获取 engine_config (YAML)
    - 检测启用的扫描类型
@@ -73,189 +142,112 @@ def initiate_scan_flow(
    Stage 2: Analysis (并行执行)
    - url_fetch
    - directory_scan

    Args:
        scan_id: 扫描任务 ID
        target_name: 目标名称
        target_id: 目标 ID
        scan_workspace_dir: Scan 工作空间目录路径
        engine_name: 引擎名称(用于显示)
        scheduled_scan_name: 定时扫描任务名称(可选,用于通知显示)

    Returns:
        dict: 执行结果摘要

    Raises:
        ValueError: 参数验证失败或配置无效
        RuntimeError: 执行失败
    """
    try:
        # ==================== 参数验证 ====================
        # 参数验证
        if not scan_id:
            raise ValueError("scan_id is required")
        if not scan_workspace_dir:
            raise ValueError("scan_workspace_dir is required")
        if not engine_name:
            raise ValueError("engine_name is required")

        logger.info("="*60)
        logger.info("开始初始化扫描任务")
        logger.info(f"Scan ID: {scan_id}")
        logger.info(f"Target: {target_name}")
        logger.info(f"Engine: {engine_name}")
        logger.info(f"Workspace: {scan_workspace_dir}")
        logger.info("="*60)

        # ==================== Task 1: 创建 Scan 工作空间 ====================

        # 创建工作空间
        scan_workspace_path = setup_scan_workspace(scan_workspace_dir)

        # ==================== Task 2: 获取引擎配置 ====================

        # 获取引擎配置
        from apps.scan.models import Scan
        scan = Scan.objects.get(id=scan_id)
        engine_config = scan.yaml_configuration

        # 使用 engine_names 进行显示
        display_engine_name = ', '.join(scan.engine_names) if scan.engine_names else engine_name

        # ==================== Task 3: 解析配置,生成执行计划 ====================

        # 创建 Provider
        provider = _create_provider(scan, target_id, scan_id)

        # 获取 target_name 用于日志显示
        target_name = provider.get_target_name()
        if not target_name:
            raise ValueError("无法获取 Target 名称")

        logger.info("=" * 60)
        logger.info("开始初始化扫描任务")
        logger.info("Scan ID: %s, Target: %s, Engine: %s", scan_id, target_name, engine_name)
        logger.info("Workspace: %s", scan_workspace_dir)
        logger.info("=" * 60)

        # 解析配置,生成执行计划
        orchestrator = FlowOrchestrator(engine_config)

        # FlowOrchestrator 已经解析了所有工具配置
        enabled_tools_by_type = orchestrator.enabled_tools_by_type

        logger.info("执行计划生成成功")
        logger.info(f"扫描类型: {' → '.join(orchestrator.scan_types)}")
        logger.info(f"总共 {len(orchestrator.scan_types)} 个 Flow")

        # ==================== 初始化阶段进度 ====================
        # 在解析完配置后立即初始化,此时已有完整的 scan_types 列表

        logger.info("执行计划: %s (共 %d 个 Flow)",
                    ' → '.join(orchestrator.scan_types), len(orchestrator.scan_types))

        # 初始化阶段进度
        from apps.scan.services import ScanService
        scan_service = ScanService()
        scan_service.init_stage_progress(scan_id, orchestrator.scan_types)
        logger.info(f"✓ 初始化阶段进度 - Stages: {orchestrator.scan_types}")

        # ==================== 更新 Target 最后扫描时间 ====================
        # 在开始扫描时更新,表示"最后一次扫描开始时间"
        ScanService().init_stage_progress(scan_id, orchestrator.scan_types)
        logger.info("✓ 初始化阶段进度 - Stages: %s", orchestrator.scan_types)

        # 更新 Target 最后扫描时间
        from apps.targets.services import TargetService
        target_service = TargetService()
        target_service.update_last_scanned_at(target_id)
        logger.info(f"✓ 更新 Target 最后扫描时间 - Target ID: {target_id}")

        # ==================== Task 3: 执行 Flow(动态阶段执行)====================
        # 注意:各阶段状态更新由 scan_flow_handlers.py 自动处理(running/completed/failed)
        TargetService().update_last_scanned_at(target_id)
        logger.info("✓ 更新 Target 最后扫描时间 - Target ID: %s", target_id)

        # 执行 Flow
        executed_flows = []
        results = {}

        # 通用执行参数
        flow_kwargs = {
        base_kwargs = {
            'scan_id': scan_id,
            'target_name': target_name,
            'target_id': target_id,
            'scan_workspace_dir': str(scan_workspace_path)
        }

        def record_flow_result(scan_type, result=None, error=None):
            """
            统一的结果记录函数

            Args:
                scan_type: 扫描类型名称
                result: 执行结果(成功时)
                error: 异常对象(失败时)
            """
            if error:
                # 失败处理:记录错误但不抛出异常,让扫描继续执行后续阶段
                error_msg = f"{scan_type} 执行失败: {str(error)}"
                logger.warning(error_msg)
                executed_flows.append(f"{scan_type} (失败)")
                results[scan_type] = {'success': False, 'error': str(error)}
                # 不再抛出异常,让扫描继续
            else:
                # 成功处理
                executed_flows.append(scan_type)
                results[scan_type] = result
                logger.info(f"✓ {scan_type} 执行成功")

        def get_valid_flows(flow_names):
            """
            获取有效的 Flow 函数列表,并为每个 Flow 准备专属参数

            Args:
                flow_names: 扫描类型名称列表

            Returns:
                list: [(scan_type, flow_func, flow_specific_kwargs), ...] 有效的函数列表
            """
            valid_flows = []
        def get_valid_flows(flow_names: list) -> list:
            """获取有效的 Flow 函数列表"""
            valid = []
            for scan_type in flow_names:
                flow_func = orchestrator.get_flow_function(scan_type)
                if flow_func:
                    # 为每个 Flow 准备专属的参数(包含对应的 enabled_tools)
                    flow_specific_kwargs = dict(flow_kwargs)
                    flow_specific_kwargs['enabled_tools'] = enabled_tools_by_type.get(scan_type, {})
                    valid_flows.append((scan_type, flow_func, flow_specific_kwargs))
                else:
                    logger.warning(f"跳过未实现的 Flow: {scan_type}")
            return valid_flows
                if not flow_func:
                    logger.warning("跳过未实现的 Flow: %s", scan_type)
                    continue
                kwargs = dict(base_kwargs)
                kwargs['enabled_tools'] = enabled_tools_by_type.get(scan_type, {})
                kwargs['provider'] = provider
                valid.append((scan_type, flow_func, kwargs))
            return valid

        # ---------------------------------------------------------
        # 动态阶段执行(基于 FlowOrchestrator 定义)
        # ---------------------------------------------------------
        # 动态阶段执行
        for mode, enabled_flows in orchestrator.get_execution_stages():
            valid_flows = get_valid_flows(enabled_flows)
            if not valid_flows:
                continue

            logger.info("=" * 60)
            logger.info("%s执行阶段: %s", "顺序" if mode == 'sequential' else "并行",
                        ', '.join(enabled_flows))
            logger.info("=" * 60)

            if mode == 'sequential':
                # 顺序执行
                logger.info("="*60)
                logger.info(f"顺序执行阶段: {', '.join(enabled_flows)}")
                logger.info("="*60)
                for scan_type, flow_func, flow_specific_kwargs in get_valid_flows(enabled_flows):
                    logger.info("="*60)
                    logger.info(f"执行 Flow: {scan_type}")
                    logger.info("="*60)
                    try:
                        result = flow_func(**flow_specific_kwargs)
                        record_flow_result(scan_type, result=result)
                    except Exception as e:
                        record_flow_result(scan_type, error=e)

            elif mode == 'parallel':
                # 并行执行阶段:通过 Task 包装子 Flow,并使用 Prefect TaskRunner 并发运行
                logger.info("="*60)
                logger.info(f"并行执行阶段: {', '.join(enabled_flows)}")
                logger.info("="*60)
                futures = []
                _execute_sequential_flows(valid_flows, results, executed_flows)
            else:
                _execute_parallel_flows(valid_flows, results, executed_flows)

                # 提交所有并行子 Flow 任务
                for scan_type, flow_func, flow_specific_kwargs in get_valid_flows(enabled_flows):
                    logger.info("="*60)
                    logger.info(f"提交并行子 Flow 任务: {scan_type}")
                    logger.info("="*60)
                    future = _run_subflow_task.submit(
                        scan_type=scan_type,
                        flow_func=flow_func,
                        flow_kwargs=flow_specific_kwargs,
                    )
                    futures.append((scan_type, future))
        logger.info("=" * 60)
        logger.info("✓ 扫描任务初始化完成 - 执行的 Flow: %s", ', '.join(executed_flows))
        logger.info("=" * 60)

                # 等待所有并行子 Flow 完成
                if futures:
                    wait([f for _, f in futures])

                # 检查结果(复用统一的结果处理逻辑)
                for scan_type, future in futures:
                    try:
                        result = future.result()
                        record_flow_result(scan_type, result=result)
                    except Exception as e:
                        record_flow_result(scan_type, error=e)

        # ==================== 完成 ====================
        logger.info("="*60)
        logger.info("✓ 扫描任务初始化完成")
        logger.info(f"执行的 Flow: {', '.join(executed_flows)}")
        logger.info("="*60)

        # ==================== 返回结果 ====================
        return {
            'success': True,
            'scan_id': scan_id,
@@ -264,21 +256,16 @@ def initiate_scan_flow(
            'executed_flows': executed_flows,
            'results': results
        }

    except ValueError as e:
        # 参数错误
        logger.error("参数错误: %s", e)
        raise
    except RuntimeError as e:
        # 执行失败
        logger.error("运行时错误: %s", e)
        raise
    except OSError as e:
        # 文件系统错误(工作空间创建失败)
        logger.error("文件系统错误: %s", e)
        raise
    except Exception as e:
        # 其他未预期错误
        logger.exception("初始化扫描任务失败: %s", e)
        # 注意:失败状态更新由 Prefect State Handlers 自动处理
        raise

@@ -132,42 +132,36 @@ def _parse_port_count(tool_config: dict) -> int:


def _export_hosts(target_id: int, port_scan_dir: Path) -> tuple[str, int, str]:
def _export_hosts(port_scan_dir: Path, provider) -> tuple[str, int]:
    """
    导出主机列表到文件

    根据 Target 类型自动决定导出内容:
    - DOMAIN: 从 Subdomain 表导出子域名
    - IP: 直接写入 target.name
    - CIDR: 展开 CIDR 范围内的所有 IP

    Args:
        target_id: 目标 ID
        port_scan_dir: 端口扫描目录
        provider: TargetProvider 实例

    Returns:
        tuple: (hosts_file, host_count, target_type)
        tuple: (hosts_file, host_count)
    """
    logger.info("Step 1: 导出主机列表")

    hosts_file = str(port_scan_dir / 'hosts.txt')
    export_result = export_hosts_task(
        target_id=target_id,
        output_file=hosts_file,
        provider=provider,
    )

    host_count = export_result['total_count']
    target_type = export_result.get('target_type', 'unknown')

    logger.info(
        "✓ 主机列表导出完成 - 类型: %s, 文件: %s, 数量: %d",
        target_type, export_result['output_file'], host_count
        "✓ 主机列表导出完成 - 文件: %s, 数量: %d",
        export_result['output_file'], host_count
    )

    if host_count == 0:
        logger.warning("目标下没有可扫描的主机,无法执行端口扫描")

    return export_result['output_file'], host_count, target_type
    return export_result['output_file'], host_count
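The DOMAIN/IP/CIDR behavior described in the `_export_hosts` docstring can be sketched with the standard library (an illustration only; the real export lives in `export_hosts_task`):

```python
import ipaddress

def expand_target(name: str) -> list:
    """Expand a CIDR into its host IPs; pass single IPs and domains through."""
    try:
        net = ipaddress.ip_network(name, strict=False)
        # hosts() excludes network/broadcast addresses for normal networks;
        # fall back to the address itself for degenerate /31 and /32 cases.
        return [str(h) for h in net.hosts()] or [str(net.network_address)]
    except ValueError:
        return [name]  # DOMAIN targets are written through as-is

print(expand_target("192.168.0.0/30"))  # ['192.168.0.1', '192.168.0.2']
print(expand_target("example.com"))     # ['example.com']
```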
|
||||
|
||||
def _run_scans_sequentially(
|
||||
@@ -176,7 +170,7 @@ def _run_scans_sequentially(
|
||||
port_scan_dir: Path,
|
||||
scan_id: int,
|
||||
target_id: int,
|
||||
target_name: str
|
||||
target_name: str,
|
||||
) -> tuple[dict, int, list, list]:
|
||||
"""
|
||||
串行执行端口扫描任务
|
||||
@@ -187,7 +181,7 @@ def _run_scans_sequentially(
|
||||
port_scan_dir: 端口扫描目录
|
||||
scan_id: 扫描任务 ID
|
||||
target_id: 目标 ID
|
||||
target_name: 目标名称(用于错误日志)
|
||||
target_name: 目标名称(用于日志显示)
|
||||
|
||||
Returns:
|
||||
tuple: (tool_stats, processed_records, successful_tool_names, failed_tools)
|
||||
@@ -271,7 +265,7 @@ def _run_scans_sequentially(
|
||||
|
||||
if not tool_stats:
|
||||
error_details = "; ".join([f"{f['tool']}: {f['reason']}" for f in failed_tools])
|
||||
logger.warning("所有端口扫描工具均失败 - 目标: %s, 失败工具: %s", target_name, error_details)
|
||||
logger.warning("所有端口扫描工具均失败 - Target: %s, 失败工具: %s", target_name, error_details)
|
||||
return {}, 0, [], failed_tools
|
||||
|
||||
successful_tool_names = [
|
||||
@@ -298,10 +292,10 @@ def _run_scans_sequentially(
|
||||
)
|
||||
def port_scan_flow(
|
||||
scan_id: int,
|
||||
target_name: str,
|
||||
target_id: int,
|
||||
scan_workspace_dir: str,
|
||||
enabled_tools: dict
|
||||
enabled_tools: dict,
|
||||
provider,
|
||||
) -> dict:
|
||||
"""
|
||||
端口扫描 Flow
|
||||
@@ -321,10 +315,10 @@ def port_scan_flow(
|
||||
|
||||
Args:
|
||||
scan_id: 扫描任务 ID
|
||||
target_name: 域名
|
||||
target_id: 目标 ID
|
||||
scan_workspace_dir: Scan 工作空间目录
|
||||
enabled_tools: 启用的工具配置字典
|
||||
provider: TargetProvider 实例
|
||||
|
||||
Returns:
|
||||
dict: 扫描结果
|
||||
@@ -336,10 +330,13 @@ def port_scan_flow(
|
||||
try:
|
||||
wait_for_system_load(context="port_scan_flow")
|
||||
|
||||
# 从 provider 获取 target_name
|
||||
target_name = provider.get_target_name()
|
||||
if not target_name:
|
||||
raise ValueError("无法获取 Target 名称")
|
||||
|
||||
if scan_id is None:
|
||||
raise ValueError("scan_id 不能为空")
|
||||
if not target_name:
|
||||
raise ValueError("target_name 不能为空")
|
||||
if target_id is None:
|
||||
raise ValueError("target_id 不能为空")
|
||||
if not scan_workspace_dir:
|
||||
@@ -358,7 +355,7 @@ def port_scan_flow(
|
||||
port_scan_dir = setup_scan_directory(scan_workspace_dir, 'port_scan')
|
||||
|
||||
# Step 1: 导出主机列表
|
||||
hosts_file, host_count, target_type = _export_hosts(target_id, port_scan_dir)
|
||||
hosts_file, host_count = _export_hosts(port_scan_dir, provider)
|
||||
|
||||
if host_count == 0:
|
||||
logger.warning("跳过端口扫描:没有主机可扫描 - Scan ID: %s", scan_id)
|
||||
@@ -370,7 +367,6 @@ def port_scan_flow(
|
||||
'scan_workspace_dir': scan_workspace_dir,
|
||||
'hosts_file': hosts_file,
|
||||
'host_count': 0,
|
||||
'target_type': target_type,
|
||||
'processed_records': 0,
|
||||
'executed_tasks': ['export_hosts'],
|
||||
'tool_stats': {
|
||||
@@ -395,7 +391,7 @@ def port_scan_flow(
|
||||
port_scan_dir=port_scan_dir,
|
||||
scan_id=scan_id,
|
||||
target_id=target_id,
|
||||
target_name=target_name
|
||||
target_name=target_name,
|
||||
)
|
||||
|
||||
logger.info("✓ 端口扫描完成 - 发现端口: %d", processed_records)
|
||||
@@ -411,7 +407,6 @@ def port_scan_flow(
|
||||
'scan_workspace_dir': scan_workspace_dir,
|
||||
'hosts_file': hosts_file,
|
||||
'host_count': host_count,
|
||||
'target_type': target_type,
|
||||
'processed_records': processed_records,
|
||||
'executed_tasks': executed_tasks,
|
||||
'tool_stats': {
|
||||
|
||||
@@ -2,17 +2,12 @@
截图 Flow

负责编排截图的完整流程:
-1. 从数据库获取 URL 列表(websites 和/或 endpoints)
+1. 从 Provider 获取 URL 列表
2. 批量截图并保存快照
3. 同步到资产表

-支持两种模式:
-1. 传统模式(向后兼容):使用 target_id 从数据库获取 URL
-2. Provider 模式:使用 TargetProvider 从任意数据源获取 URL
"""

import logging
-from typing import Optional

from prefect import flow

@@ -22,62 +17,49 @@ from apps.scan.handlers.scan_flow_handlers import (
    on_scan_flow_running,
)
+from apps.scan.providers import TargetProvider
-from apps.scan.services.target_export_service import DataSource, get_urls_with_fallback
from apps.scan.tasks.screenshot import capture_screenshots_task
from apps.scan.utils import user_log, wait_for_system_load

logger = logging.getLogger(__name__)

-# URL 来源到 DataSource 的映射
-_SOURCE_MAPPING = {
-    'websites': DataSource.WEBSITE,
-    'endpoints': DataSource.ENDPOINT,
-}


def _parse_screenshot_config(enabled_tools: dict) -> dict:
    """解析截图配置"""
    playwright_config = enabled_tools.get('playwright', {})
    return {
        'concurrency': playwright_config.get('concurrency', 5),
-        'url_sources': playwright_config.get('url_sources', ['websites'])
    }


-def _map_url_sources_to_data_sources(url_sources: list[str]) -> list[str]:
-    """将配置中的 url_sources 映射为 DataSource 常量"""
-    sources = []
-    for source in url_sources:
-        if source in _SOURCE_MAPPING:
-            sources.append(_SOURCE_MAPPING[source])
-        else:
-            logger.warning("未知的 URL 来源: %s,跳过", source)
+def _collect_urls_from_provider(provider: TargetProvider) -> tuple[list[str], str]:
+    """
+    从 Provider 收集网站 URL(带回退逻辑)
+
+    优先级:WebSite → HostPortMapping → Default URL
+
+    Returns:
+        tuple: (urls, source)
+        - urls: URL 列表
+        - source: 数据来源 ('website' | 'host_port' | 'default')
+    """
+    logger.info("从 Provider 获取网站 URL - Provider: %s", type(provider).__name__)

-    # 添加默认回退(从 subdomain 构造)
-    sources.append(DataSource.DEFAULT)
-    return sources
+    # 优先从 WebSite 获取
+    urls = list(provider.iter_websites())
+    if urls:
+        logger.info("使用 WebSite 数据源 - 数量: %d", len(urls))
+        return urls, "website"
+
+    # 回退到 HostPortMapping
+    urls = list(provider.iter_host_port_urls())
+    if urls:
+        logger.info("WebSite 为空,回退到 HostPortMapping - 数量: %d", len(urls))
+        return urls, "host_port"

-def _collect_urls_from_provider(provider: TargetProvider) -> tuple[list[str], str, list[str]]:
-    """从 Provider 收集 URL"""
-    logger.info("使用 Provider 模式获取 URL - Provider: %s", type(provider).__name__)
-    urls = list(provider.iter_urls())
-
-    blacklist_filter = provider.get_blacklist_filter()
-    if blacklist_filter:
-        urls = [url for url in urls if blacklist_filter.is_allowed(url)]
-
-    return urls, 'provider', ['provider']
-
-
-def _collect_urls_from_database(
-    target_id: int,
-    url_sources: list[str]
-) -> tuple[list[str], str, list[str]]:
-    """从数据库收集 URL(带黑名单过滤和回退)"""
-    data_sources = _map_url_sources_to_data_sources(url_sources)
-    result = get_urls_with_fallback(target_id, sources=data_sources)
-    return result['urls'], result['source'], result['tried_sources']
+    # 最终回退到默认 URL
+    urls = list(provider.iter_default_urls())
+    logger.info("HostPortMapping 为空,回退到默认 URL - 数量: %d", len(urls))
+    return urls, "default"
def _build_empty_result(scan_id: int, target_name: str) -> dict:
@@ -102,68 +84,53 @@ def _build_empty_result(scan_id: int, target_name: str) -> dict:
)
def screenshot_flow(
    scan_id: int,
-    target_name: str,
    target_id: int,
    scan_workspace_dir: str,
    enabled_tools: dict,
-    provider: Optional[TargetProvider] = None
+    provider: TargetProvider,
) -> dict:
    """
    截图 Flow

-    支持两种模式:
-    1. 传统模式(向后兼容):使用 target_id 从数据库获取 URL
-    2. Provider 模式:使用 TargetProvider 从任意数据源获取 URL

    Args:
        scan_id: 扫描任务 ID
-        target_name: 目标名称
        target_id: 目标 ID
        scan_workspace_dir: 扫描工作空间目录
        enabled_tools: 启用的工具配置
-        provider: TargetProvider 实例(新模式,可选)
+        provider: TargetProvider 实例

    Returns:
        截图结果字典
    """
    try:
        # 负载检查:等待系统资源充足
        wait_for_system_load(context="screenshot_flow")

-        mode = 'Provider' if provider else 'Legacy'
+        # 从 provider 获取 target_name
+        target_name = provider.get_target_name()
+        if not target_name:
+            raise ValueError("无法获取 Target 名称")

        logger.info(
-            "开始截图扫描 - Scan ID: %s, Target: %s, Mode: %s",
-            scan_id, target_name, mode
+            "开始截图扫描 - Scan ID: %s, Target: %s",
+            scan_id, target_name
        )
        user_log(scan_id, "screenshot", "Starting screenshot capture")

        # Step 1: 解析配置
        config = _parse_screenshot_config(enabled_tools)
        concurrency = config['concurrency']
-        logger.info("截图配置 - 并发: %d, URL来源: %s", concurrency, config['url_sources'])
+        logger.info("截图配置 - 并发: %d", concurrency)

-        # Step 2: 收集 URL 列表
-        if provider is not None:
-            urls, source_info, tried_sources = _collect_urls_from_provider(provider)
-        else:
-            urls, source_info, tried_sources = _collect_urls_from_database(
-                target_id, config['url_sources']
-            )
-
-        logger.info(
-            "URL 收集完成 - 来源: %s, 数量: %d, 尝试过: %s",
-            source_info, len(urls), tried_sources
-        )
+        # Step 2: 从 Provider 收集 URL 列表(带回退逻辑)
+        urls, source = _collect_urls_from_provider(provider)
+        logger.info("URL 收集完成 - 来源: %s, 数量: %d", source, len(urls))

        if not urls:
            logger.warning("没有可截图的 URL,跳过截图任务")
            user_log(scan_id, "screenshot", "Skipped: no URLs to capture", "warning")
            return _build_empty_result(scan_id, target_name)

-        user_log(
-            scan_id, "screenshot",
-            f"Found {len(urls)} URLs to capture (source: {source_info})"
-        )
+        user_log(scan_id, "screenshot", f"Found {len(urls)} URLs to capture")

        # Step 3: 批量截图
        logger.info("批量截图 - %d 个 URL", len(urls))
@@ -88,40 +88,38 @@ def _calculate_timeout_by_line_count(


def _export_site_urls(
-    target_id: int,
-    site_scan_dir: Path
-) -> tuple[str, int, int]:
+    site_scan_dir: Path,
+    provider,
+) -> tuple[str, int]:
    """
    导出站点 URL 到文件

    Args:
-        target_id: 目标 ID
        site_scan_dir: 站点扫描目录
+        provider: TargetProvider 实例

    Returns:
-        tuple: (urls_file, total_urls, association_count)
+        tuple: (urls_file, total_urls)
    """
    logger.info("Step 1: 导出站点URL列表")

    urls_file = str(site_scan_dir / 'site_urls.txt')
    export_result = export_site_urls_task(
-        target_id=target_id,
        output_file=urls_file,
-        batch_size=1000
+        provider=provider,
    )

    total_urls = export_result['total_urls']
-    association_count = export_result['association_count']

    logger.info(
-        "✓ 站点URL导出完成 - 文件: %s, URL数量: %d, 关联数: %d",
-        export_result['output_file'], total_urls, association_count
+        "✓ 站点URL导出完成 - 文件: %s, URL数量: %d",
+        export_result['output_file'], total_urls
    )

    if total_urls == 0:
        logger.warning("目标下没有可用的站点URL,无法执行站点扫描")

-    return export_result['output_file'], total_urls, association_count
+    return export_result['output_file'], total_urls


def _get_tool_timeout(tool_config: dict, urls_file: str) -> int:
@@ -263,7 +261,6 @@ def _build_empty_result(
    target_name: str,
    scan_workspace_dir: str,
    urls_file: str,
-    association_count: int
) -> dict:
    """构建空结果(无 URL 可扫描时)"""
    return {
@@ -273,7 +270,6 @@ def _build_empty_result(
        'scan_workspace_dir': scan_workspace_dir,
        'urls_file': urls_file,
        'total_urls': 0,
-        'association_count': association_count,
        'processed_records': 0,
        'created_websites': 0,
        'skipped_no_subdomain': 0,
@@ -306,15 +302,12 @@ def _aggregate_tool_results(tool_stats: dict) -> tuple[int, int, int]:

def _validate_flow_params(
    scan_id: int,
-    target_name: str,
    target_id: int,
    scan_workspace_dir: str
) -> None:
    """验证 Flow 参数"""
    if scan_id is None:
        raise ValueError("scan_id 不能为空")
-    if not target_name:
-        raise ValueError("target_name 不能为空")
    if target_id is None:
        raise ValueError("target_id 不能为空")
    if not scan_workspace_dir:
@@ -330,10 +323,10 @@ def _validate_flow_params(
)
def site_scan_flow(
    scan_id: int,
-    target_name: str,
    target_id: int,
    scan_workspace_dir: str,
-    enabled_tools: dict
+    enabled_tools: dict,
+    provider,
) -> dict:
    """
    站点扫描 Flow
@@ -344,10 +337,10 @@ def site_scan_flow(

    Args:
        scan_id: 扫描任务 ID
-        target_name: 目标名称
        target_id: 目标 ID
        scan_workspace_dir: 扫描工作空间目录
        enabled_tools: 启用的工具配置字典
+        provider: TargetProvider 实例

    Returns:
        dict: 扫描结果
@@ -359,12 +352,17 @@ def site_scan_flow(
    try:
        wait_for_system_load(context="site_scan_flow")

+        # 从 provider 获取 target_name
+        target_name = provider.get_target_name()
+        if not target_name:
+            raise ValueError("无法获取 Target 名称")
+
        logger.info(
            "开始站点扫描 - Scan ID: %s, Target: %s, Workspace: %s",
            scan_id, target_name, scan_workspace_dir
        )

-        _validate_flow_params(scan_id, target_name, target_id, scan_workspace_dir)
+        _validate_flow_params(scan_id, target_id, scan_workspace_dir)
        user_log(scan_id, "site_scan", "Starting site scan")

        # Step 0: 创建工作目录
@@ -372,15 +370,15 @@ def site_scan_flow(
        site_scan_dir = setup_scan_directory(scan_workspace_dir, 'site_scan')

        # Step 1: 导出站点 URL
-        urls_file, total_urls, association_count = _export_site_urls(
-            target_id, site_scan_dir
+        urls_file, total_urls = _export_site_urls(
+            site_scan_dir, provider
        )

        if total_urls == 0:
            logger.warning("跳过站点扫描:没有站点 URL 可扫描 - Scan ID: %s", scan_id)
            user_log(scan_id, "site_scan", "Skipped: no site URLs to scan", "warning")
            return _build_empty_result(
-                scan_id, target_name, scan_workspace_dir, urls_file, association_count
+                scan_id, target_name, scan_workspace_dir, urls_file
            )

        # Step 2: 工具配置信息
@@ -421,7 +419,6 @@ def site_scan_flow(
        'scan_workspace_dir': scan_workspace_dir,
        'urls_file': urls_file,
        'total_urls': total_urls,
-        'association_count': association_count,
        'processed_records': processed_records,
        'created_websites': total_created,
        'skipped_no_subdomain': total_skipped_no_sub,
@@ -540,10 +540,10 @@ def _empty_result(scan_id: int, target: str, scan_workspace_dir: str) -> dict:
)
def subdomain_discovery_flow(
    scan_id: int,
-    target_name: str,
    target_id: int,
    scan_workspace_dir: str,
-    enabled_tools: dict
+    enabled_tools: dict,
+    provider,
) -> dict:
    """子域名发现扫描流程

@@ -571,6 +571,8 @@ def subdomain_discovery_flow(
    if enabled_tools is None:
        raise ValueError("enabled_tools 不能为空")

+    # 从 provider 获取 target_name
+    target_name = provider.get_target_name()
    if not target_name:
        logger.warning("未提供目标域名,跳过子域名发现扫描")
        return _empty_result(scan_id, '', scan_workspace_dir)
@@ -34,9 +34,9 @@ logger = logging.getLogger(__name__)
def domain_name_url_fetch_flow(
    scan_id: int,
    target_id: int,
-    target_name: str,
    output_dir: str,
    domain_name_tools: Dict[str, dict],
+    provider,
) -> dict:
    """
    基于 Target 根域名执行 URL 被动收集(当前主要用于 waymore)
@@ -46,32 +46,35 @@ def domain_name_url_fetch_flow(
    2. 使用传入的工具列表对根域名执行被动收集
    3. 工具内部会自动查询该域名及其子域名的历史 URL
    4. 汇总结果文件列表

    Args:
        scan_id: 扫描 ID
        target_id: 目标 ID
-        target_name: Target 根域名(如 example.com),不是子域名列表
        output_dir: 输出目录
        domain_name_tools: 被动收集工具配置(如 waymore)
+        provider: TargetProvider 实例

    注意:
    - 此 Flow 只对 DOMAIN 类型 Target 有效
    - IP 和 CIDR 类型会自动跳过(waymore 等工具不支持)
    - 工具会自动收集 *.target_name 的所有历史 URL,无需遍历子域名
    """
    from apps.scan.utils import user_log

    try:
+        # 从 provider 获取 target_name
+        target_name = provider.get_target_name()
+
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # 检查 Target 类型,IP/CIDR 类型跳过
        from apps.targets.services import TargetService
        from apps.targets.models import Target

        target_service = TargetService()
        target = target_service.get_target(target_id)

        if target and target.type != Target.TargetType.DOMAIN:
            logger.info(
                "跳过 domain_name URL 获取: Target 类型为 %s (ID=%d, Name=%s),waymore 等工具仅适用于域名类型",
@@ -240,10 +240,10 @@ def _save_urls_to_database(merged_file: str, scan_id: int, target_id: int) -> in
)
def url_fetch_flow(
    scan_id: int,
-    target_name: str,
    target_id: int,
    scan_workspace_dir: str,
-    enabled_tools: dict
+    enabled_tools: dict,
+    provider,
) -> dict:
    """
    URL 获取主 Flow
@@ -252,7 +252,7 @@ def url_fetch_flow(
    1. 准备工作目录
    2. 按输入类型分类工具(domain_name / sites_file / 后处理)
    3. 并行执行子 Flow
-       - domain_name_url_fetch_flow: 基于 domain_name(来自 target_name)执行 URL 获取(如 waymore)
+       - domain_name_url_fetch_flow: 基于 domain_name(来自 provider)执行 URL 获取(如 waymore)
       - sites_url_fetch_flow: 基于 sites_file 执行爬虫(如 katana 等)
    4. 合并所有子 Flow 的结果并去重
    5. uro 去重(如果启用)
@@ -260,10 +260,10 @@ def url_fetch_flow(

    Args:
        scan_id: 扫描 ID
-        target_name: 目标名称
        target_id: 目标 ID
        scan_workspace_dir: 扫描工作目录
        enabled_tools: 启用的工具配置
+        provider: TargetProvider 实例

    Returns:
        dict: 扫描结果
@@ -272,6 +272,11 @@ def url_fetch_flow(
        # 负载检查:等待系统资源充足
        wait_for_system_load(context="url_fetch_flow")

+        # 从 provider 获取 target_name
+        target_name = provider.get_target_name()
+        if not target_name:
+            raise ValueError("无法获取 Target 名称")
+
        logger.info(
            "开始 URL 获取扫描 - Scan ID: %s, Target: %s, Workspace: %s",
            scan_id, target_name, scan_workspace_dir
@@ -310,9 +315,9 @@ def url_fetch_flow(
            tn_result = domain_name_url_fetch_flow(
                scan_id=scan_id,
                target_id=target_id,
-                target_name=target_name,
                output_dir=str(url_fetch_dir),
                domain_name_tools=domain_name_tools,
+                provider=provider,
            )
            all_result_files.extend(tn_result.get('result_files', []))
            all_failed_tools.extend(tn_result.get('failed_tools', []))
@@ -323,9 +328,9 @@ def url_fetch_flow(
            crawl_result = sites_url_fetch_flow(
                scan_id=scan_id,
                target_id=target_id,
-                target_name=target_name,
                output_dir=str(url_fetch_dir),
-                enabled_tools=sites_file_tools
+                enabled_tools=sites_file_tools,
+                provider=provider
            )
            all_result_files.extend(crawl_result.get('result_files', []))
            all_failed_tools.extend(crawl_result.get('failed_tools', []))
@@ -19,17 +19,16 @@ from .utils import run_tools_parallel
logger = logging.getLogger(__name__)


-def _export_sites_file(target_id: int, scan_id: int, target_name: str, output_dir: Path) -> tuple[str, int]:
+def _export_sites_file(
+    output_dir: Path,
+    provider,
+) -> tuple[str, int]:
    """
    导出站点 URL 列表到文件

    懒加载模式:如果 WebSite 表为空,根据 Target 类型生成默认 URL

    Args:
-        target_id: 目标 ID
-        scan_id: 扫描 ID
-        target_name: 目标名称(用于懒加载)
        output_dir: 输出目录
+        provider: TargetProvider 实例

    Returns:
        tuple: (file_path, count)
@@ -39,8 +38,7 @@ def _export_sites_file(target_id: int, scan_id: int, target_name: str, output_di
    output_file = str(output_dir / "sites.txt")
    result = export_sites_task(
        output_file=output_file,
-        target_id=target_id,
-        scan_id=scan_id
+        provider=provider
    )

    count = result['asset_count']
@@ -56,25 +54,25 @@ def _export_sites_file(
def sites_url_fetch_flow(
    scan_id: int,
    target_id: int,
-    target_name: str,
    output_dir: str,
-    enabled_tools: dict
+    enabled_tools: dict,
+    provider,
) -> dict:
    """
    URL 爬虫子 Flow

    执行流程:
    1. 导出站点 URL 列表(sites_file)
    2. 并行执行爬虫工具
    3. 返回结果文件列表

    Args:
        scan_id: 扫描 ID
        target_id: 目标 ID
-        target_name: 目标名称
        output_dir: 输出目录
        enabled_tools: 启用的爬虫工具配置
+        provider: TargetProvider 实例

    Returns:
        dict: {
            'success': bool,
@@ -85,19 +83,22 @@ def sites_url_fetch_flow(
        }
    """
    try:
+        # 从 provider 获取 target_name
+        target_name = provider.get_target_name()
+        if not target_name:
+            raise ValueError("无法获取 Target 名称")
+
        output_path = Path(output_dir)

        logger.info(
            "开始 URL 爬虫 - Target: %s, Tools: %s",
            target_name, ', '.join(enabled_tools.keys())
        )

        # Step 1: 导出站点 URL 列表
        sites_file, sites_count = _export_sites_file(
-            target_id=target_id,
-            scan_id=scan_id,
-            target_name=target_name,
-            output_dir=output_path
+            output_dir=output_path,
+            provider=provider
        )

        # 默认值模式下,即使原本没有站点,也会有默认 URL 作为输入
@@ -34,17 +34,20 @@ logger = logging.getLogger(__name__)
)
def endpoints_vuln_scan_flow(
    scan_id: int,
-    target_name: str,
    target_id: int,
    scan_workspace_dir: str,
    enabled_tools: Dict[str, dict],
+    provider,
) -> dict:
    """基于 Endpoint 的漏洞扫描 Flow(串行执行 Dalfox 等工具)。"""
    try:
+        # 从 provider 获取 target_name
+        target_name = provider.get_target_name()
+        if not target_name:
+            raise ValueError("无法获取 Target 名称")
+
        if scan_id is None:
            raise ValueError("scan_id 不能为空")
-        if not target_name:
-            raise ValueError("target_name 不能为空")
        if target_id is None:
            raise ValueError("target_id 不能为空")
        if not scan_workspace_dir:
@@ -58,8 +61,8 @@ def endpoints_vuln_scan_flow(

        # Step 1: 导出 Endpoint URL
        export_result = export_endpoints_task(
-            target_id=target_id,
            output_file=str(endpoints_file),
+            provider=provider,
        )
        total_endpoints = export_result.get("total_count", 0)

@@ -104,8 +107,11 @@ def endpoints_vuln_scan_flow(
                continue
            template_args = " ".join(f"-t {p}" for p in template_paths)

-            # 构建命令参数
-            command_params = {"endpoints_file": str(endpoints_file)}
+            # 构建命令参数(根据工具模板使用不同的参数名)
+            if tool_name == "nuclei":
+                command_params = {"input_file": str(endpoints_file)}
+            else:
+                command_params = {"endpoints_file": str(endpoints_file)}
            if template_args:
                command_params["template_args"] = template_args
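The last hunk above switches the command parameter name per tool: nuclei's command template expects an `input_file` placeholder, while older tools keep `endpoints_file`. Extracted as a standalone helper (the function name is hypothetical; the branching logic is copied from the hunk):

```python
def build_command_params(tool_name: str, endpoints_file: str, template_args: str = "") -> dict:
    """Pick the placeholder key each tool's command template expects."""
    # nuclei 的命令模板使用 input_file 占位符,其余工具沿用 endpoints_file
    if tool_name == "nuclei":
        params = {"input_file": endpoints_file}
    else:
        params = {"endpoints_file": endpoints_file}
    # 模板参数(-t repo_path ...)仅在非空时附加
    if template_args:
        params["template_args"] = template_args
    return params
```

Keeping the key selection in one place is what lets the same endpoints file feed both nuclei and Dalfox-style templates without changing the templates themselves.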
@@ -14,32 +14,48 @@ from apps.scan.handlers.scan_flow_handlers import (
from apps.scan.configs.command_templates import get_command_template
from apps.scan.utils import user_log, wait_for_system_load
from .endpoints_vuln_scan_flow import endpoints_vuln_scan_flow
+from .websites_vuln_scan_flow import websites_vuln_scan_flow

logger = logging.getLogger(__name__)


-def _classify_vuln_tools(enabled_tools: Dict[str, dict]) -> Tuple[Dict[str, dict], Dict[str, dict]]:
-    """根据命令模板中的 input_type 对漏洞扫描工具进行分类。
+def _classify_vuln_tools(
+    enabled_tools: Dict[str, dict]
+) -> Tuple[Dict[str, dict], Dict[str, dict], Dict[str, dict]]:
+    """根据用户配置分类漏洞扫描工具。

-    当前支持:
-    - endpoints_file: 以端点列表文件为输入(例如 Dalfox XSS)
-    预留:
-    - 其他 input_type 将被归类到 other_tools,暂不处理。
+    分类逻辑:
+    - 读取 scan_endpoints / scan_websites 配置
+    - 默认值从模板的 defaults 或 input_type 推断
+
+    Returns:
+        (endpoints_tools, websites_tools, other_tools) 三元组
    """
    endpoints_tools: Dict[str, dict] = {}
+    websites_tools: Dict[str, dict] = {}
    other_tools: Dict[str, dict] = {}

    for tool_name, tool_config in enabled_tools.items():
        template = get_command_template("vuln_scan", tool_name) or {}
-        input_type = template.get("input_type", "endpoints_file")
+        defaults = template.get("defaults", {})

-        if input_type == "endpoints_file":
+        # 根据 input_type 推断默认值(兼容老工具)
+        input_type = template.get("input_type")
+        default_endpoints = defaults.get("scan_endpoints", input_type == "endpoints_file")
+        default_websites = defaults.get("scan_websites", input_type == "websites_file")
+
+        scan_endpoints = tool_config.get("scan_endpoints", default_endpoints)
+        scan_websites = tool_config.get("scan_websites", default_websites)
+
+        if scan_endpoints:
            endpoints_tools[tool_name] = tool_config
-        else:
+        if scan_websites:
+            websites_tools[tool_name] = tool_config
+        if not scan_endpoints and not scan_websites:
            other_tools[tool_name] = tool_config

-    return endpoints_tools, other_tools
+    return endpoints_tools, websites_tools, other_tools


@flow(
@@ -51,25 +67,28 @@ def _classify_vuln_tools(enabled_tools: Dict[str, dict]) -> Tuple[Dict[str, dict
)
def vuln_scan_flow(
    scan_id: int,
-    target_name: str,
    target_id: int,
    scan_workspace_dir: str,
    enabled_tools: Dict[str, dict],
+    provider,
) -> dict:
    """漏洞扫描主 Flow:串行编排各类漏洞扫描子 Flow。

    支持工具:
    - dalfox_xss: XSS 漏洞扫描(流式保存)
-    - nuclei: 通用漏洞扫描(流式保存,支持模板 commit hash 同步)
+    - nuclei: 通用漏洞扫描(流式保存,支持 endpoints 和 websites 两种输入)
    """
    try:
        # 负载检查:等待系统资源充足
        wait_for_system_load(context="vuln_scan_flow")

+        # 从 provider 获取 target_name
+        target_name = provider.get_target_name()
+        if not target_name:
+            raise ValueError("无法获取 Target 名称")
+
        if scan_id is None:
            raise ValueError("scan_id 不能为空")
-        if not target_name:
-            raise ValueError("target_name 不能为空")
        if target_id is None:
            raise ValueError("target_id 不能为空")
        if not scan_workspace_dir:
@@ -81,11 +100,12 @@ def vuln_scan_flow(
        user_log(scan_id, "vuln_scan", "Starting vulnerability scan")

        # Step 1: 分类工具
-        endpoints_tools, other_tools = _classify_vuln_tools(enabled_tools)
+        endpoints_tools, websites_tools, other_tools = _classify_vuln_tools(enabled_tools)

        logger.info(
-            "漏洞扫描工具分类 - endpoints_file: %s, 其他: %s",
+            "漏洞扫描工具分类 - endpoints: %s, websites: %s, 其他: %s",
            list(endpoints_tools.keys()) or "无",
+            list(websites_tools.keys()) or "无",
            list(other_tools.keys()) or "无",
        )

@@ -95,28 +115,58 @@ def vuln_scan_flow(
                list(other_tools.keys()),
            )

-        if not endpoints_tools:
-            raise ValueError("漏洞扫描需要至少启用一个以 endpoints_file 为输入的工具(如 dalfox_xss、nuclei)。")
+        if not endpoints_tools and not websites_tools:
+            raise ValueError(
+                "漏洞扫描需要至少启用一个工具(endpoints 或 websites 模式)"
+            )

-        # Step 2: 执行 Endpoint 漏洞扫描子 Flow(串行)
-        endpoint_result = endpoints_vuln_scan_flow(
-            scan_id=scan_id,
-            target_name=target_name,
-            target_id=target_id,
-            scan_workspace_dir=scan_workspace_dir,
-            enabled_tools=endpoints_tools,
-        )
+        total_vulns = 0
+        results = {}
+
+        # Step 2: 执行 Endpoint 漏洞扫描子 Flow
+        if endpoints_tools:
+            logger.info("执行 Endpoint 漏洞扫描 - 工具: %s", list(endpoints_tools.keys()))
+            endpoint_result = endpoints_vuln_scan_flow(
+                scan_id=scan_id,
+                target_id=target_id,
+                scan_workspace_dir=scan_workspace_dir,
+                enabled_tools=endpoints_tools,
+                provider=provider,
+            )
+            results["endpoints"] = endpoint_result
+            total_vulns += sum(
+                r.get("created_vulns", 0)
+                for r in endpoint_result.get("tool_results", {}).values()
+            )
+
+        # Step 3: 执行 WebSite 漏洞扫描子 Flow
+        if websites_tools:
+            logger.info("执行 WebSite 漏洞扫描 - 工具: %s", list(websites_tools.keys()))
+            website_result = websites_vuln_scan_flow(
+                scan_id=scan_id,
+                target_id=target_id,
+                scan_workspace_dir=scan_workspace_dir,
+                enabled_tools=websites_tools,
+                provider=provider,
+            )
+            results["websites"] = website_result
+            total_vulns += sum(
+                r.get("created_vulns", 0)
+                for r in website_result.get("tool_results", {}).values()
+            )

        # 记录 Flow 完成
-        total_vulns = sum(
-            r.get("created_vulns", 0)
-            for r in endpoint_result.get("tool_results", {}).values()
-        )
        logger.info("✓ 漏洞扫描完成 - 新增漏洞: %d", total_vulns)
        user_log(scan_id, "vuln_scan", f"vuln_scan completed: found {total_vulns} vulnerabilities")

-        # 目前只有一个子 Flow,直接返回其结果
-        return endpoint_result
+        return {
+            "success": True,
+            "scan_id": scan_id,
+            "target": target_name,
+            "scan_workspace_dir": scan_workspace_dir,
+            "total_vulns": total_vulns,
+            "sub_flow_results": results,
+        }

    except Exception as e:
        logger.exception("漏洞扫描主 Flow 失败: %s", e)
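`_classify_vuln_tools` above now splits tools three ways based on `scan_endpoints` / `scan_websites` flags, with defaults inferred from each command template's `input_type`. A self-contained re-implementation of that decision table (templates are passed in as a plain dict here instead of going through `get_command_template`; as in the diff, a tool may land in both the endpoints and websites buckets):

```python
def classify_vuln_tools(enabled_tools: dict, templates: dict) -> tuple[dict, dict, dict]:
    """Split tools into (endpoints, websites, other), mirroring the diff's logic."""
    endpoints, websites, other = {}, {}, {}
    for name, cfg in enabled_tools.items():
        tpl = templates.get(name, {})
        defaults = tpl.get("defaults", {})
        # Defaults fall back to the template's input_type for older tools
        input_type = tpl.get("input_type")
        default_ep = defaults.get("scan_endpoints", input_type == "endpoints_file")
        default_ws = defaults.get("scan_websites", input_type == "websites_file")
        # User config overrides the inferred defaults
        scan_ep = cfg.get("scan_endpoints", default_ep)
        scan_ws = cfg.get("scan_websites", default_ws)
        if scan_ep:
            endpoints[name] = cfg
        if scan_ws:
            websites[name] = cfg
        if not scan_ep and not scan_ws:
            other[name] = cfg
    return endpoints, websites, other
```

The two-level default (explicit `defaults` first, then `input_type`) is what keeps existing tool configs working: a tool whose template only declares `input_type: endpoints_file` behaves exactly as before the websites mode was added.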
backend/apps/scan/flows/vuln_scan/websites_vuln_scan_flow.py (new file, 192 lines)
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
基于 WebSite 的漏洞扫描 Flow
|
||||
|
||||
与 endpoints_vuln_scan_flow 类似,但数据源是 WebSite 而不是 Endpoint。
|
||||
主要用于 nuclei 扫描已存活的网站。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Dict
|
||||
|
||||
from prefect import flow
|
||||
|
||||
from apps.scan.utils import build_scan_command, ensure_nuclei_templates_local, user_log
|
||||
from apps.scan.tasks.vuln_scan import run_and_stream_save_nuclei_vulns_task
|
||||
from apps.scan.tasks.vuln_scan.export_websites_task import export_websites_task
|
||||
from .utils import calculate_timeout_by_line_count
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@flow(
|
||||
name="websites_vuln_scan_flow",
|
||||
log_prints=True,
|
||||
)
|
||||
def websites_vuln_scan_flow(
|
||||
scan_id: int,
|
||||
target_id: int,
|
||||
scan_workspace_dir: str,
|
||||
enabled_tools: Dict[str, dict],
|
||||
provider,
|
||||
) -> dict:
|
||||
"""基于 WebSite 的漏洞扫描 Flow(主要用于 nuclei)。"""
|
||||
try:
|
||||
target_name = provider.get_target_name()
|
||||
if not target_name:
|
||||
raise ValueError("无法获取 Target 名称")
|
||||
|
||||
if scan_id is None:
|
||||
raise ValueError("scan_id 不能为空")
|
||||
if target_id is None:
|
||||
raise ValueError("target_id 不能为空")
|
||||
if not scan_workspace_dir:
|
||||
raise ValueError("scan_workspace_dir 不能为空")
|
||||
if not enabled_tools:
|
||||
raise ValueError("enabled_tools 不能为空")
|
||||
|
||||
from apps.scan.utils import setup_scan_directory
|
||||
vuln_scan_dir = setup_scan_directory(scan_workspace_dir, 'vuln_scan')
|
||||
websites_file = vuln_scan_dir / "input_websites.txt"
|
||||
|
||||
# Step 1: 导出 WebSite URL
|
||||
export_result = export_websites_task(
|
||||
output_file=str(websites_file),
|
||||
provider=provider,
|
||||
)
|
||||
total_websites = export_result.get("total_count", 0)
|
||||
|
||||
if total_websites == 0:
|
||||
logger.warning("目标下没有可用 WebSite,跳过漏洞扫描")
|
||||
return {
|
||||
"success": True,
|
||||
"scan_id": scan_id,
|
||||
"target": target_name,
|
||||
"scan_workspace_dir": scan_workspace_dir,
|
||||
"websites_file": str(websites_file),
|
||||
"website_count": 0,
|
||||
"executed_tools": [],
|
||||
"tool_results": {},
|
||||
}
|
||||
|
||||
logger.info("WebSite 导出完成,共 %d 条,开始执行漏洞扫描", total_websites)
|
||||
|
||||
tool_results: Dict[str, dict] = {}
|
||||
tool_futures: Dict[str, dict] = {}
|
||||
|
||||
# Step 2: 执行漏洞扫描工具
|
||||
for tool_name, tool_config in enabled_tools.items():
|
||||
# 目前只支持 nuclei
|
||||
if tool_name != "nuclei":
|
||||
logger.warning("websites_vuln_scan_flow 暂不支持工具: %s", tool_name)
|
||||
continue
|
||||
|
||||
# 确保 nuclei 模板存在
|
||||
repo_names = tool_config.get("template_repo_names")
|
||||
if not repo_names or not isinstance(repo_names, (list, tuple)):
|
||||
logger.error("Nuclei 配置缺少 template_repo_names(数组),跳过")
|
||||
continue
|
||||
|
||||
template_paths = []
|
||||
try:
|
||||
for repo_name in repo_names:
|
||||
path = ensure_nuclei_templates_local(repo_name)
|
||||
template_paths.append(path)
|
||||
logger.info("Nuclei 模板路径 [%s]: %s", repo_name, path)
|
||||
except Exception as e:
|
||||
logger.error("获取 Nuclei 模板失败: %s,跳过 nuclei 扫描", e)
|
||||
continue
|
||||
|
||||
template_args = " ".join(f"-t {p}" for p in template_paths)
|
||||
|
||||
# 构建命令(使用 websites_file 作为输入)
|
||||
command_params = {
|
||||
"input_file": str(websites_file),
|
||||
"template_args": template_args,
|
||||
}
|
||||
|
||||
command = build_scan_command(
|
||||
tool_name=tool_name,
|
||||
scan_type="vuln_scan",
|
                command_params=command_params,
                tool_config=tool_config,
            )

            # Compute the timeout
            raw_timeout = tool_config.get("timeout", 600)
            if isinstance(raw_timeout, str) and raw_timeout == "auto":
                timeout = calculate_timeout_by_line_count(
                    tool_config=tool_config,
                    file_path=str(websites_file),
                    base_per_time=30,
                )
            else:
                try:
                    timeout = int(raw_timeout)
                except (TypeError, ValueError) as e:
                    raise ValueError(
                        f"工具 {tool_name} 的 timeout 配置无效: {raw_timeout!r}"
                    ) from e

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            log_file = vuln_scan_dir / f"{tool_name}_websites_{timestamp}.log"

            logger.info("开始执行 %s 漏洞扫描(WebSite 模式)", tool_name)
            user_log(scan_id, "vuln_scan", f"Running {tool_name} (websites): {command}")

            future = run_and_stream_save_nuclei_vulns_task.submit(
                cmd=command,
                tool_name=tool_name,
                scan_id=scan_id,
                target_id=target_id,
                cwd=str(vuln_scan_dir),
                shell=True,
                batch_size=1,
                timeout=timeout,
                log_file=str(log_file),
            )

            tool_futures[tool_name] = {
                "future": future,
                "command": command,
                "timeout": timeout,
                "log_file": str(log_file),
            }

        # Collect results
        for tool_name, meta in tool_futures.items():
            future = meta["future"]
            try:
                result = future.result()
                created_vulns = result.get("created_vulns", 0)
                tool_results[tool_name] = {
                    "command": meta["command"],
                    "timeout": meta["timeout"],
                    "processed_records": result.get("processed_records"),
                    "created_vulns": created_vulns,
                    "command_log_file": meta["log_file"],
                }
                logger.info("✓ 工具 %s (websites) 执行完成 - 漏洞: %d", tool_name, created_vulns)
                user_log(
                    scan_id, "vuln_scan",
                    f"{tool_name} (websites) completed: found {created_vulns} vulnerabilities"
                )
            except Exception as e:
                reason = str(e)
                logger.error("工具 %s 执行失败: %s", tool_name, e, exc_info=True)
                user_log(scan_id, "vuln_scan", f"{tool_name} failed: {reason}", "error")

        return {
            "success": True,
            "scan_id": scan_id,
            "target": target_name,
            "scan_workspace_dir": scan_workspace_dir,
            "websites_file": str(websites_file),
            "website_count": total_websites,
            "executed_tools": list(enabled_tools.keys()),
            "tool_results": tool_results,
        }

    except Exception as e:
        logger.exception("WebSite 漏洞扫描失败: %s", e)
        raise
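The timeout handling above (the literal `"auto"` switches to a line-count-based estimate via `calculate_timeout_by_line_count`, anything else must parse as an integer) can be sketched standalone. The per-line heuristic and the `floor` default below are assumptions for illustration, not the project's actual formula:

```python
from pathlib import Path


def resolve_timeout(raw_timeout, file_path: str, base_per_time: int = 30,
                    floor: int = 600) -> int:
    """Resolve a tool timeout: 'auto' scales with input size, else int()."""
    if isinstance(raw_timeout, str) and raw_timeout == "auto":
        # Rough per-line budget with a minimum floor (assumed heuristic;
        # the real calculate_timeout_by_line_count may differ).
        with Path(file_path).open(encoding="utf-8") as f:
            line_count = sum(1 for _ in f)
        return max(floor, line_count * base_per_time)
    try:
        return int(raw_timeout)
    except (TypeError, ValueError) as e:
        raise ValueError(f"invalid timeout config: {raw_timeout!r}") from e
```

Failing fast on an unparseable value (rather than silently falling back to the default) matches the `raise ValueError(...) from e` behavior in the flow above.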
@@ -0,0 +1,35 @@
# Generated by Django 5.2.7 on 2026-01-10 03:51

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):

    dependencies = [
        ('scan', '0003_add_wecom_fields'),
    ]

    operations = [
        migrations.AddField(
            model_name='scan',
            name='scan_mode',
            field=models.CharField(choices=[('full', '完整扫描'), ('quick', '快速扫描')], default='full', help_text='扫描模式:full=完整扫描,quick=快速扫描', max_length=10),
        ),
        migrations.CreateModel(
            name='ScanInputTarget',
            fields=[
                ('id', models.AutoField(primary_key=True, serialize=False)),
                ('value', models.CharField(help_text='用户输入的原始值', max_length=2000)),
                ('input_type', models.CharField(choices=[('domain', '域名'), ('ip', 'IP地址'), ('cidr', 'CIDR'), ('url', 'URL')], help_text='输入类型', max_length=10)),
                ('created_at', models.DateTimeField(auto_now_add=True)),
                ('scan', models.ForeignKey(help_text='所属的扫描任务', on_delete=django.db.models.deletion.CASCADE, related_name='input_targets', to='scan.scan')),
            ],
            options={
                'verbose_name': '扫描输入目标',
                'verbose_name_plural': '扫描输入目标',
                'db_table': 'scan_input_target',
                'indexes': [models.Index(fields=['scan'], name='scan_input__scan_id_0a3227_idx'), models.Index(fields=['input_type'], name='scan_input__input_t_e3f681_idx')],
            },
        ),
    ]
@@ -4,6 +4,7 @@ from .scan_models import Scan, SoftDeleteManager
 from .scan_log_model import ScanLog
 from .scheduled_scan_model import ScheduledScan
 from .subfinder_provider_settings_model import SubfinderProviderSettings
+from .scan_input_target import ScanInputTarget

 # Backward-compatible alias (deprecated; use SubfinderProviderSettings)
 ProviderSettings = SubfinderProviderSettings
@@ -15,4 +16,5 @@ __all__ = [
     'SoftDeleteManager',
     'SubfinderProviderSettings',
     'ProviderSettings',  # backward-compatible alias
+    'ScanInputTarget',
 ]
backend/apps/scan/models/scan_input_target.py (new file, 47 lines)
@@ -0,0 +1,47 @@
"""
Scan input target model.

Stores the raw targets a user submits for a quick scan; supports chunked
iteration over large inputs (10k+ rows). Used by stage 1 of a quick scan.
"""

from django.db import models


class ScanInputTarget(models.Model):
    """Scan input target table."""

    class InputType(models.TextChoices):
        """Input type enum."""
        DOMAIN = 'domain', '域名'
        IP = 'ip', 'IP地址'
        CIDR = 'cidr', 'CIDR'
        URL = 'url', 'URL'

    id = models.AutoField(primary_key=True)
    scan = models.ForeignKey(
        'scan.Scan',
        on_delete=models.CASCADE,
        related_name='input_targets',
        help_text='所属的扫描任务'
    )
    value = models.CharField(max_length=2000, help_text='用户输入的原始值')
    input_type = models.CharField(
        max_length=10,
        choices=InputType.choices,
        help_text='输入类型'
    )
    created_at = models.DateTimeField(auto_now_add=True)

    class Meta:
        """Model metadata."""
        db_table = 'scan_input_target'
        verbose_name = '扫描输入目标'
        verbose_name_plural = '扫描输入目标'
        indexes = [
            models.Index(fields=['scan']),
            models.Index(fields=['input_type']),
        ]

    def __str__(self):
        return f"ScanInputTarget #{self.id} - {self.value} ({self.input_type})"
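Each stored row carries an `input_type`; the classification itself is delegated to a `detect_input_type` validator referenced elsewhere in this changeset but not shown in the diff. A minimal standalone classifier in the same spirit (purely illustrative, not the project's implementation):

```python
import ipaddress
from urllib.parse import urlparse


def detect_input_type(value: str) -> str:
    """Classify a raw user input as 'url', 'cidr', 'ip' or 'domain'."""
    value = value.strip()
    # http(s) scheme -> URL
    if urlparse(value).scheme in ("http", "https"):
        return "url"
    # A single address parses as an IP before the CIDR check,
    # so "1.2.3.4" is 'ip' rather than an implicit /32 network.
    try:
        ipaddress.ip_address(value)
        return "ip"
    except ValueError:
        pass
    try:
        ipaddress.ip_network(value, strict=False)
        return "cidr"
    except ValueError:
        pass
    # Fall back to treating dotted names as domains
    if "." in value and " " not in value:
        return "domain"
    raise ValueError(f"unrecognized target: {value!r}")
```

The four return values line up with the model's `InputType` choices (`domain` / `ip` / `cidr` / `url`).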
@@ -8,17 +8,28 @@ from apps.common.definitions import ScanStatus


 class SoftDeleteManager(models.Manager):
     """Soft-delete manager: returns only non-deleted records by default."""

     def get_queryset(self):
         """Return the queryset of non-deleted records."""
         return super().get_queryset().filter(deleted_at__isnull=True)


 class Scan(models.Model):
     """Scan task model."""

+    class ScanMode(models.TextChoices):
+        """Scan mode enum."""
+        FULL = 'full', '完整扫描'
+        QUICK = 'quick', '快速扫描'
+
     id = models.AutoField(primary_key=True)

-    target = models.ForeignKey('targets.Target', on_delete=models.CASCADE, related_name='scans', help_text='扫描目标')
+    target = models.ForeignKey(
+        'targets.Target',
+        on_delete=models.CASCADE,
+        related_name='scans',
+        help_text='扫描目标'
+    )

     # Multi-engine support fields
     engine_ids = ArrayField(
@@ -35,6 +46,14 @@ class Scan(models.Model):
         help_text='YAML 格式的扫描配置'
     )

+    # Scan mode
+    scan_mode = models.CharField(
+        max_length=10,
+        choices=ScanMode.choices,
+        default=ScanMode.FULL,
+        help_text='扫描模式:full=完整扫描,quick=快速扫描'
+    )
+
     created_at = models.DateTimeField(auto_now_add=True, help_text='任务创建时间')
     stopped_at = models.DateTimeField(null=True, blank=True, help_text='扫描结束时间')

@@ -46,7 +65,12 @@ class Scan(models.Model):
         help_text='任务状态'
     )

-    results_dir = models.CharField(max_length=100, blank=True, default='', help_text='结果存储目录')
+    results_dir = models.CharField(
+        max_length=100,
+        blank=True,
+        default='',
+        help_text='结果存储目录'
+    )

     container_ids = ArrayField(
         models.CharField(max_length=100),
@@ -54,7 +78,7 @@ class Scan(models.Model):
         default=list,
         help_text='容器 ID 列表(Docker Container ID)'
     )

     worker = models.ForeignKey(
         'engine.WorkerNode',
         on_delete=models.SET_NULL,
@@ -64,35 +88,46 @@ class Scan(models.Model):
         help_text='执行扫描的 Worker 节点'
     )

-    error_message = models.CharField(max_length=2000, blank=True, default='', help_text='错误信息')
+    error_message = models.CharField(
+        max_length=2000,
+        blank=True,
+        default='',
+        help_text='错误信息'
+    )

-    # ==================== Soft-delete field ====================
-    deleted_at = models.DateTimeField(null=True, blank=True, db_index=True, help_text='删除时间(NULL表示未删除)')
+    # Soft-delete field
+    deleted_at = models.DateTimeField(
+        null=True,
+        blank=True,
+        db_index=True,
+        help_text='删除时间(NULL表示未删除)'
+    )

-    # ==================== Managers ====================
-    objects = SoftDeleteManager()  # default manager: non-deleted records only
-    all_objects = models.Manager()  # full manager: includes deleted records (for hard deletes)
+    # Managers
+    objects = SoftDeleteManager()
+    all_objects = models.Manager()

-    # ==================== Progress tracking fields ====================
+    # Progress tracking fields
     progress = models.IntegerField(default=0, help_text='扫描进度 0-100')
     current_stage = models.CharField(max_length=50, blank=True, default='', help_text='当前扫描阶段')
     stage_progress = models.JSONField(default=dict, help_text='各阶段进度详情')

-    # ==================== Cached statistics fields ====================
-    cached_subdomains_count = models.IntegerField(default=0, help_text='缓存的子域名数量')
-    cached_websites_count = models.IntegerField(default=0, help_text='缓存的网站数量')
-    cached_endpoints_count = models.IntegerField(default=0, help_text='缓存的端点数量')
-    cached_ips_count = models.IntegerField(default=0, help_text='缓存的IP地址数量')
-    cached_directories_count = models.IntegerField(default=0, help_text='缓存的目录数量')
-    cached_screenshots_count = models.IntegerField(default=0, help_text='缓存的截图数量')
-    cached_vulns_total = models.IntegerField(default=0, help_text='缓存的漏洞总数')
-    cached_vulns_critical = models.IntegerField(default=0, help_text='缓存的严重漏洞数量')
-    cached_vulns_high = models.IntegerField(default=0, help_text='缓存的高危漏洞数量')
-    cached_vulns_medium = models.IntegerField(default=0, help_text='缓存的中危漏洞数量')
-    cached_vulns_low = models.IntegerField(default=0, help_text='缓存的低危漏洞数量')
+    # Cached statistics fields
+    cached_subdomains_count = models.IntegerField(default=0, help_text='子域名数量')
+    cached_websites_count = models.IntegerField(default=0, help_text='网站数量')
+    cached_endpoints_count = models.IntegerField(default=0, help_text='端点数量')
+    cached_ips_count = models.IntegerField(default=0, help_text='IP地址数量')
+    cached_directories_count = models.IntegerField(default=0, help_text='目录数量')
+    cached_screenshots_count = models.IntegerField(default=0, help_text='截图数量')
+    cached_vulns_total = models.IntegerField(default=0, help_text='漏洞总数')
+    cached_vulns_critical = models.IntegerField(default=0, help_text='严重漏洞数量')
+    cached_vulns_high = models.IntegerField(default=0, help_text='高危漏洞数量')
+    cached_vulns_medium = models.IntegerField(default=0, help_text='中危漏洞数量')
+    cached_vulns_low = models.IntegerField(default=0, help_text='低危漏洞数量')
     stats_updated_at = models.DateTimeField(null=True, blank=True, help_text='统计数据最后更新时间')

     class Meta:
         """Model metadata."""
         db_table = 'scan'
         verbose_name = '扫描任务'
         verbose_name_plural = '扫描任务'
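The manager pair on `Scan` means `Scan.objects` silently hides rows whose `deleted_at` is set, while `Scan.all_objects` still sees them (for hard deletes). The filtering contract (`deleted_at__isnull=True`) can be shown without Django as a toy list filter (illustration only, not the ORM):

```python
from dataclasses import dataclass
from datetime import datetime
from typing import List, Optional


@dataclass
class Row:
    id: int
    deleted_at: Optional[datetime] = None


def soft_delete_queryset(rows: List[Row]) -> List[Row]:
    # Mirrors SoftDeleteManager.get_queryset():
    # filter(deleted_at__isnull=True)
    return [r for r in rows if r.deleted_at is None]


rows = [Row(1), Row(2, deleted_at=datetime(2026, 1, 10)), Row(3)]
print([r.id for r in soft_delete_queryset(rows)])  # default-manager view: [1, 3]
print([r.id for r in rows])                        # all_objects view: [1, 2, 3]
```

The point of keeping both managers is that "delete" in the UI only stamps `deleted_at`, while cleanup jobs can still reach the row through the unfiltered manager.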
@@ -3,54 +3,49 @@

 Provides a unified target-acquisition interface with multiple data sources:
 - DatabaseTargetProvider: queries the database (full scan)
-- ListTargetProvider: uses an in-memory list (quick scan, stage 1)
-- SnapshotTargetProvider: reads from snapshot tables (quick scan, stage 2+)
-- PipelineTargetProvider: uses pipeline output (Phase 2)
+- SnapshotTargetProvider: reads from snapshot tables (quick scan)
+
+Provider methods:
+- get_target_name(): the Target name (root domain / IP / CIDR)
+- iter_subdomains(): subdomain list
+- iter_host_port_urls(): URLs generated from host:port (for site probing)
+- iter_websites(): confirmed-live website URLs (screenshots, fingerprinting, directory scans)
+- iter_endpoints(): endpoint URLs (vulnerability scanning)

 Usage:
-    from apps.scan.providers import (
-        DatabaseTargetProvider,
-        ListTargetProvider,
-        SnapshotTargetProvider,
-        ProviderContext
-    )
-
     # Database mode (full scan)
     provider = DatabaseTargetProvider(target_id=123)

-    # List mode (quick scan, stage 1)
-    context = ProviderContext(target_id=1, scan_id=100)
-    provider = ListTargetProvider(
-        targets=["a.test.com"],
-        context=context
-    )
-
-    # Snapshot mode (quick scan, stage 2+)
-    context = ProviderContext(target_id=1, scan_id=100)
-    provider = SnapshotTargetProvider(
-        scan_id=100,
-        snapshot_type="subdomain",
-        context=context
-    )
-
-    # Use the provider
-    for host in provider.iter_hosts():
-        scan(host)
+    # Port scan: explicitly combine target_name + subdomains
+    target_name = provider.get_target_name()
+    if target_name:
+        scan_port(target_name)  # CIDR expansion is the caller's job
+    for subdomain in provider.iter_subdomains():
+        scan_port(subdomain)
+
+    # Screenshots
+    for url in provider.iter_websites():
+        take_screenshot(url)
+
+    # Snapshot mode (quick scan)
+    provider = SnapshotTargetProvider(scan_id=100)
+    for url in provider.iter_websites():
+        take_screenshot(url)
 """

 from .base import TargetProvider, ProviderContext
-from .list_provider import ListTargetProvider
 from .database_provider import DatabaseTargetProvider
-from .snapshot_provider import SnapshotTargetProvider, SnapshotType
-from .pipeline_provider import PipelineTargetProvider, StageOutput
+from .snapshot_provider import SnapshotTargetProvider

 __all__ = [
     'TargetProvider',
     'ProviderContext',
-    'ListTargetProvider',
     'DatabaseTargetProvider',
     'SnapshotTargetProvider',
-    'SnapshotType',
-    'PipelineTargetProvider',
-    'StageOutput',
 ]
@@ -4,7 +4,6 @@
 Defines the ProviderContext dataclass and the TargetProvider abstract base class.
 """

-import ipaddress
 import logging
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
@@ -37,72 +36,184 @@ class TargetProvider(ABC):
     - Provides iterators over scan targets (domains, IPs, URLs, ...)
     - Provides the blacklist filter
     - Carries context (target_id, scan_id, ...)
-    - Expands CIDRs automatically (subclasses need not care)
+
+    Methods:
+    - get_target_name(): the Target name (root domain / IP / CIDR)
+    - iter_subdomains(): subdomain list
+    - iter_host_port_urls(): URLs generated from host:port (for site probing)
+    - iter_websites(): confirmed-live website URLs (screenshots, fingerprinting, directory scans)
+    - iter_endpoints(): endpoint URLs (vulnerability scanning)

     Usage:
-        provider = create_target_provider(target_id=123)
-        for host in provider.iter_hosts():
-            print(host)
+        provider = DatabaseTargetProvider(target_id=123)
+
+        # Port scan: explicitly combine target_name + subdomains
+        target_name = provider.get_target_name()
+        if target_name:
+            scan_port(target_name)  # CIDR expansion is the caller's job
+        for subdomain in provider.iter_subdomains():
+            scan_port(subdomain)
+
+        # Screenshots
+        for url in provider.iter_websites():
+            take_screenshot(url)
     """

     def __init__(self, context: Optional[ProviderContext] = None):
         self._context = context or ProviderContext()
+        self._target_name: Optional[str] = None  # cached target_name

     @property
     def context(self) -> ProviderContext:
         """Return the provider context."""
         return self._context

-    @staticmethod
-    def _expand_host(host: str) -> Iterator[str]:
+    def get_target_name(self) -> Optional[str]:
         """
-        Expand a host (a CIDR expands to multiple IPs; anything else passes through).
-
-        Examples:
-            "192.168.1.0/30" → "192.168.1.1", "192.168.1.2"
-            "192.168.1.1" → "192.168.1.1"
-            "example.com" → "example.com"
+        Get the Target name (root domain / IP / CIDR).
+
+        Returns:
+            The Target name, or None if it does not exist.
+
+        Note: CIDRs are not expanded automatically; the caller handles that.
         """
+        # Cache to avoid repeated queries
+        if self._target_name is not None:
+            return self._target_name
+
+        if not self.target_id:
+            logger.warning("target_id 未设置,无法获取 Target 名称")
+            return None
+
+        from apps.targets.services import TargetService
+
+        target = TargetService().get_target(self.target_id)
+        self._target_name = target.name if target else None
+        return self._target_name
+
+    def iter_target_hosts(self) -> Iterator[str]:
+        """
+        Iterate the hosts the Target expands to (blacklist already applied).
+
+        - DOMAIN/IP: returned as-is
+        - CIDR: expanded to all IPs
+
+        Returns:
+            Host iterator (domains or IPs).
+        """
+        import ipaddress
+
+        from apps.common.validators import detect_target_type
+        from apps.targets.models import Target
+
-        host = host.strip()
-        if not host:
+        target_name = self.get_target_name()
+        if not target_name:
             return

-        try:
-            target_type = detect_target_type(host)
-
-            if target_type == Target.TargetType.CIDR:
-                network = ipaddress.ip_network(host, strict=False)
-                if network.num_addresses == 1:
-                    yield str(network.network_address)
-                else:
-                    yield from (str(ip) for ip in network.hosts())
-            elif target_type in (Target.TargetType.IP, Target.TargetType.DOMAIN):
-                yield host
-        except ValueError as e:
-            logger.warning("跳过无效的主机格式 '%s': %s", host, str(e))
+        blacklist = self.get_blacklist_filter()
+        target_type = detect_target_type(target_name)
+
+        if target_type == Target.TargetType.CIDR:
+            # CIDR expansion
+            network = ipaddress.ip_network(target_name, strict=False)
+            if network.num_addresses == 1:
+                hosts = [str(network.network_address)]
+            else:
+                hosts = [str(ip) for ip in network.hosts()]
+        else:
+            # DOMAIN / IP pass through
+            hosts = [target_name]
+
+        for host in hosts:
+            if not blacklist or blacklist.is_allowed(host):
+                yield host

-    def iter_hosts(self) -> Iterator[str]:
-        """Iterate hosts (domains/IPs), expanding CIDRs automatically."""
-        for host in self._iter_raw_hosts():
-            yield from self._expand_host(host)
-
     @abstractmethod
-    def _iter_raw_hosts(self) -> Iterator[str]:
-        """Iterate the raw host list (may contain CIDRs); implemented by subclasses."""
-        pass
+    def iter_subdomains(self) -> Iterator[str]:
+        """Iterate the subdomain list; implemented by subclasses."""

     @abstractmethod
-    def iter_urls(self) -> Iterator[str]:
-        """Iterate the URL list."""
-        pass
+    def iter_host_port_urls(self) -> Iterator[str]:
+        """
+        Iterate URLs generated from host:port (to be probed).
+
+        Used by the site scan (httpx probing); URLs are generated from HostPortMapping.
+        Returned formats: http://host:port or https://host:port
+        """
+
+    @abstractmethod
+    def iter_websites(self) -> Iterator[str]:
+        """
+        Iterate confirmed-live website URLs.
+
+        Used for screenshots, fingerprinting, directory scans, URL crawling.
+        Data source: the WebSite table (confirmed-live websites).
+        """
+
+    @abstractmethod
+    def iter_endpoints(self) -> Iterator[str]:
+        """
+        Iterate endpoint URLs.
+
+        Used for vulnerability scanning.
+        Data source: the Endpoint table (URLs with parameters).
+        """
+
+    def iter_default_urls(self) -> Iterator[str]:
+        """
+        Generate default URLs from the Target itself.
+
+        Used when earlier stages are skipped and the scan runs directly.
+        Generated by Target type:
+        - DOMAIN: http(s)://domain
+        - IP: http(s)://ip
+        - CIDR: expanded to http(s)://ip for every IP
+        """
+        import ipaddress
+
+        from apps.targets.models import Target
+        from apps.targets.services import TargetService
+
+        if not self.target_id:
+            logger.warning("target_id 未设置,无法生成默认 URL")
+            return
+
+        target = TargetService().get_target(self.target_id)
+        if not target:
+            logger.warning("Target ID %d 不存在,无法生成默认 URL", self.target_id)
+            return
+
+        target_name = target.name
+        target_type = target.type
+        blacklist = self.get_blacklist_filter()
+
+        if target_type == Target.TargetType.DOMAIN:
+            urls = [f"http://{target_name}", f"https://{target_name}"]
+        elif target_type == Target.TargetType.IP:
+            urls = [f"http://{target_name}", f"https://{target_name}"]
+        elif target_type == Target.TargetType.CIDR:
+            try:
+                network = ipaddress.ip_network(target_name, strict=False)
+                urls = []
+                for ip in network.hosts():
+                    urls.extend([f"http://{ip}", f"https://{ip}"])
+                # Special-case /32 and /128
+                if not urls:
+                    ip = str(network.network_address)
+                    urls = [f"http://{ip}", f"https://{ip}"]
+            except ValueError as e:
+                logger.error("CIDR 解析失败: %s - %s", target_name, e)
+                return
+        else:
+            logger.warning("不支持的 Target 类型: %s", target_type)
+            return
+
+        for url in urls:
+            if not blacklist or blacklist.is_allowed(url):
+                yield url

     @abstractmethod
     def get_blacklist_filter(self) -> Optional['BlacklistFilter']:
         """Get the blacklist filter; None means no filtering."""
         pass

     @property
     def target_id(self) -> Optional[int]:
|
||||
数据库目标提供者模块
|
||||
|
||||
提供基于数据库查询的目标提供者实现。
|
||||
用于完整扫描模式,从 Target 关联的资产表查询数据。
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -19,14 +20,33 @@ class DatabaseTargetProvider(TargetProvider):
|
||||
"""
|
||||
数据库目标提供者 - 从 Target 表及关联资产表查询
|
||||
|
||||
用于完整扫描模式,查询目标下的所有历史资产。
|
||||
|
||||
数据来源:
|
||||
- iter_hosts(): 根据 Target 类型返回域名/IP
|
||||
- iter_urls(): WebSite/Endpoint 表,带回退链
|
||||
- iter_target_name(): Target 表(根域名/IP/CIDR)
|
||||
- iter_subdomains(): Subdomain 表
|
||||
- iter_host_port_urls(): HostPortMapping 表
|
||||
- iter_websites(): WebSite 表
|
||||
- iter_endpoints(): Endpoint 表
|
||||
- iter_default_urls(): 从 Target 本身生成默认 URL
|
||||
|
||||
回退逻辑由调用方(Task/Flow)决定,Provider 只负责单一数据源查询。
|
||||
|
||||
使用方式:
|
||||
provider = DatabaseTargetProvider(target_id=123)
|
||||
for host in provider.iter_hosts():
|
||||
scan(host)
|
||||
|
||||
# 端口扫描:显式组合
|
||||
for name in provider.iter_target_name():
|
||||
scan_port(name) # CIDR 需要调用方自己展开
|
||||
for subdomain in provider.iter_subdomains():
|
||||
scan_port(subdomain)
|
||||
|
||||
# 调用方控制回退
|
||||
urls = list(provider.iter_endpoints())
|
||||
if not urls:
|
||||
urls = list(provider.iter_websites())
|
||||
if not urls:
|
||||
urls = list(provider.iter_default_urls())
|
||||
"""
|
||||
|
||||
def __init__(self, target_id: int, context: Optional[ProviderContext] = None):
|
||||
@@ -35,53 +55,73 @@ class DatabaseTargetProvider(TargetProvider):
|
||||
super().__init__(ctx)
|
||||
self._blacklist_filter: Optional['BlacklistFilter'] = None
|
||||
|
||||
def iter_hosts(self) -> Iterator[str]:
|
||||
"""从数据库查询主机列表,自动展开 CIDR 并应用黑名单过滤"""
|
||||
blacklist = self.get_blacklist_filter()
|
||||
|
||||
for host in self._iter_raw_hosts():
|
||||
for expanded_host in self._expand_host(host):
|
||||
if not blacklist or blacklist.is_allowed(expanded_host):
|
||||
yield expanded_host
|
||||
|
||||
def _iter_raw_hosts(self) -> Iterator[str]:
|
||||
"""从数据库查询原始主机列表(可能包含 CIDR)"""
|
||||
def iter_subdomains(self) -> Iterator[str]:
|
||||
"""从 Subdomain 表查询子域名列表"""
|
||||
from apps.asset.services.asset.subdomain_service import SubdomainService
|
||||
from apps.targets.models import Target
|
||||
from apps.targets.services import TargetService
|
||||
|
||||
target = TargetService().get_target(self.target_id)
|
||||
if not target:
|
||||
logger.warning("Target ID %d 不存在", self.target_id)
|
||||
return
|
||||
|
||||
if target.type == Target.TargetType.DOMAIN:
|
||||
yield target.name
|
||||
for domain in SubdomainService().iter_subdomain_names_by_target(
|
||||
target_id=self.target_id,
|
||||
chunk_size=1000
|
||||
):
|
||||
if domain != target.name:
|
||||
yield domain
|
||||
|
||||
elif target.type in (Target.TargetType.IP, Target.TargetType.CIDR):
|
||||
yield target.name
|
||||
|
||||
def iter_urls(self) -> Iterator[str]:
|
||||
"""从数据库查询 URL 列表,使用回退链:Endpoint → WebSite → Default"""
|
||||
from apps.scan.services.target_export_service import (
|
||||
DataSource,
|
||||
_iter_urls_with_fallback,
|
||||
)
|
||||
|
||||
blacklist = self.get_blacklist_filter()
|
||||
|
||||
for url, _ in _iter_urls_with_fallback(
|
||||
for domain in SubdomainService().iter_subdomain_names_by_target(
|
||||
target_id=self.target_id,
|
||||
sources=[DataSource.ENDPOINT, DataSource.WEBSITE, DataSource.DEFAULT],
|
||||
blacklist_filter=blacklist
|
||||
chunk_size=1000
|
||||
):
|
||||
yield url
|
||||
if not blacklist or blacklist.is_allowed(domain):
|
||||
yield domain
|
||||
|
||||
def iter_host_port_urls(self) -> Iterator[str]:
|
||||
"""从 HostPortMapping 表生成待探测的 URL"""
|
||||
from apps.asset.models import HostPortMapping
|
||||
|
||||
blacklist = self.get_blacklist_filter()
|
||||
|
||||
queryset = HostPortMapping.objects.filter(
|
||||
target_id=self.target_id
|
||||
).values('host', 'port').distinct()
|
||||
|
||||
for mapping in queryset.iterator(chunk_size=1000):
|
||||
host = mapping['host']
|
||||
port = mapping['port']
|
||||
|
||||
if port == 80:
|
||||
urls = [f"http://{host}"]
|
||||
elif port == 443:
|
||||
urls = [f"https://{host}"]
|
||||
else:
|
||||
urls = [f"http://{host}:{port}", f"https://{host}:{port}"]
|
||||
|
||||
for url in urls:
|
||||
if not blacklist or blacklist.is_allowed(url):
|
||||
yield url
|
||||
|
||||
def iter_websites(self) -> Iterator[str]:
|
||||
"""从 WebSite 表查询已存活网站 URL"""
|
||||
from apps.asset.models import WebSite
|
||||
|
||||
blacklist = self.get_blacklist_filter()
|
||||
|
||||
queryset = WebSite.objects.filter(
|
||||
target_id=self.target_id
|
||||
).values_list('url', flat=True)
|
||||
|
||||
for url in queryset.iterator(chunk_size=1000):
|
||||
if url:
|
||||
if not blacklist or blacklist.is_allowed(url):
|
||||
yield url
|
||||
|
||||
def iter_endpoints(self) -> Iterator[str]:
|
||||
"""从 Endpoint 表查询端点 URL"""
|
||||
from apps.asset.models import Endpoint
|
||||
|
||||
blacklist = self.get_blacklist_filter()
|
||||
|
||||
queryset = Endpoint.objects.filter(
|
||||
target_id=self.target_id
|
||||
).values_list('url', flat=True)
|
||||
|
||||
for url in queryset.iterator(chunk_size=1000):
|
||||
if url:
|
||||
if not blacklist or blacklist.is_allowed(url):
|
||||
yield url
|
||||
|
||||
def get_blacklist_filter(self) -> Optional['BlacklistFilter']:
|
||||
"""获取黑名单过滤器(延迟加载)"""
|
||||
|
||||
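The port-to-scheme rule in `iter_host_port_urls` (port 80 probes plain http, 443 plain https, any other port is probed under both schemes) is self-contained enough to lift out; a standalone sketch:

```python
from typing import Iterable, Iterator, Tuple


def host_port_to_urls(pairs: Iterable[Tuple[str, int]]) -> Iterator[str]:
    """Map (host, port) pairs to candidate URLs the way the provider does:
    80 -> http only, 443 -> https only, other ports -> both schemes."""
    for host, port in pairs:
        if port == 80:
            yield f"http://{host}"
        elif port == 443:
            yield f"https://{host}"
        else:
            # Unknown service: let the prober try both
            yield f"http://{host}:{port}"
            yield f"https://{host}:{port}"
```

Special-casing 80 and 443 avoids generating URLs with redundant default ports and halves the probe volume for the two most common cases.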
@@ -1,84 +0,0 @@
"""
List target provider module.

Provides the in-memory-list-backed target provider implementation.
"""

from typing import Iterator, Optional, List

from .base import TargetProvider, ProviderContext


class ListTargetProvider(TargetProvider):
    """
    List target provider - consumes an in-memory list directly.

    For quick scans, ad-hoc scans, and similar cases that scan only
    user-specified targets.

    Characteristics:
    - No database queries
    - No blacklist filtering (the targets were explicitly specified by the user)
    - Not tied to a target_id (creating the Target is the caller's job)
    - Auto-detects the input type (URL / domain / IP / CIDR)
    - Auto-expands CIDRs

    Usage:
        # Quick scan: user-provided targets, types detected automatically
        provider = ListTargetProvider(targets=[
            "example.com",              # domain
            "192.168.1.0/24",           # CIDR (auto-expanded)
            "https://api.example.com"   # URL
        ])
        for host in provider.iter_hosts():
            scan(host)
    """

    def __init__(
        self,
        targets: Optional[List[str]] = None,
        context: Optional[ProviderContext] = None
    ):
        """
        Initialize the list target provider.

        Args:
            targets: target list (types auto-detected: URL / domain / IP / CIDR)
            context: provider context
        """
        from apps.common.validators import detect_input_type

        ctx = context or ProviderContext()
        super().__init__(ctx)

        # Classify the targets automatically
        self._hosts = []
        self._urls = []

        if targets:
            for target in targets:
                target = target.strip()
                if not target:
                    continue

                try:
                    input_type = detect_input_type(target)
                    if input_type == 'url':
                        self._urls.append(target)
                    else:
                        # domain / ip / cidr all count as hosts
                        self._hosts.append(target)
                except ValueError:
                    # Unrecognized type: default to host
                    self._hosts.append(target)

    def _iter_raw_hosts(self) -> Iterator[str]:
        """Iterate the raw host list (may contain CIDRs)."""
        yield from self._hosts

    def iter_urls(self) -> Iterator[str]:
        """Iterate the URL list."""
        yield from self._urls

    def get_blacklist_filter(self) -> None:
        """List mode applies no blacklist filtering."""
        return None
@@ -1,91 +0,0 @@
"""
Pipeline target provider module.

Provides the pipeline-stage-output-backed target provider implementation.
Used for stage-to-stage data passing in the Phase 2 pipeline mode.
"""

from dataclasses import dataclass, field
from typing import Iterator, Optional, List, Dict, Any

from .base import TargetProvider, ProviderContext


@dataclass
class StageOutput:
    """
    Stage output data.

    Used to pass data between pipeline stages.

    Attributes:
        hosts: host list (domains / IPs)
        urls: URL list
        new_targets: newly discovered targets
        stats: statistics
        success: whether the stage succeeded
        error: error message
    """
    hosts: List[str] = field(default_factory=list)
    urls: List[str] = field(default_factory=list)
    new_targets: List[str] = field(default_factory=list)
    stats: Dict[str, Any] = field(default_factory=dict)
    success: bool = True
    error: Optional[str] = None


class PipelineTargetProvider(TargetProvider):
    """
    Pipeline target provider - consumes the previous stage's output.

    Used for stage-to-stage data passing in the Phase 2 pipeline mode.

    Characteristics:
    - No database queries
    - No blacklist filtering (the data was already filtered in the previous stage)
    - Uses the data in StageOutput directly

    Usage (Phase 2):
        stage1_output = stage1.run(input)
        provider = PipelineTargetProvider(
            previous_output=stage1_output,
            target_id=123
        )
        for host in provider.iter_hosts():
            stage2.scan(host)
    """

    def __init__(
        self,
        previous_output: StageOutput,
        target_id: Optional[int] = None,
        context: Optional[ProviderContext] = None
    ):
        """
        Initialize the pipeline target provider.

        Args:
            previous_output: the previous stage's output
            target_id: optional Target to associate (for saving results)
            context: provider context
        """
        ctx = context or ProviderContext(target_id=target_id)
        super().__init__(ctx)
        self._previous_output = previous_output

    def _iter_raw_hosts(self) -> Iterator[str]:
        """Iterate the previous stage's raw hosts (may contain CIDRs)."""
        yield from self._previous_output.hosts

    def iter_urls(self) -> Iterator[str]:
        """Iterate the previous stage's URLs."""
        yield from self._previous_output.urls

    def get_blacklist_filter(self) -> None:
        """Pipeline-passed data has already been filtered."""
        return None

    @property
    def previous_output(self) -> StageOutput:
        """Return the previous stage's output."""
        return self._previous_output
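Several docstrings in this changeset describe the same caller-side fallback chain (endpoints, then websites, then default URLs). That convention reduces to picking the first non-empty source; a generic sketch (the function name is mine, not the project's):

```python
from typing import Callable, Iterable, List


def first_non_empty(*sources: Callable[[], Iterable[str]]) -> List[str]:
    """Try URL sources in priority order (e.g. iter_endpoints,
    iter_websites, iter_default_urls) and return the first batch
    that yields anything; empty list if none do."""
    for source in sources:
        urls = list(source())
        if urls:
            return urls
    return []
```

Passing the iterator *factories* (not pre-built lists) keeps lower-priority sources unqueried when a higher-priority one already has data, which matters when each source is a database scan.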
@@ -6,170 +6,106 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Iterator, Optional, Literal
|
||||
from typing import Iterator, Optional
|
||||
|
||||
from .base import TargetProvider, ProviderContext
|
||||
from .base import ProviderContext, TargetProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 快照类型定义
|
||||
SnapshotType = Literal["subdomain", "website", "endpoint", "host_port"]
|
||||
|
||||
|
||||
class SnapshotTargetProvider(TargetProvider):
|
||||
"""
|
||||
快照目标提供者 - 从快照表读取本次扫描的数据
|
||||
|
||||
|
||||
用于快速扫描的阶段间数据传递,解决精确扫描控制问题。
|
||||
|
||||
|
||||
核心价值:
|
||||
- 只返回本次扫描(scan_id)发现的资产
|
||||
- 避免扫描历史数据(DatabaseTargetProvider 会扫描所有历史资产)
|
||||
|
||||
|
||||
特点:
|
||||
- 通过 scan_id 过滤快照表
|
||||
- 不应用黑名单过滤(数据已在上一阶段过滤)
|
||||
- 支持多种快照类型(subdomain/website/endpoint/host_port)
|
||||
|
||||
- 每个 iter_* 方法只查对应的快照表(单一职责)
|
||||
- 回退逻辑由调用方(Task/Flow)决定
|
||||
|
||||
使用场景:
|
||||
# 快速扫描流程
|
||||
用户输入: a.test.com
|
||||
创建 Target: test.com (id=1)
|
||||
创建 Scan: scan_id=100
|
||||
|
||||
# 阶段1: 子域名发现
|
||||
provider = ListTargetProvider(
|
||||
targets=["a.test.com"],
|
||||
context=ProviderContext(target_id=1, scan_id=100)
|
||||
)
|
||||
# 发现: b.a.test.com, c.a.test.com
|
||||
# 保存: SubdomainSnapshot(scan_id=100) + Subdomain(target_id=1)
|
||||
|
||||
# 阶段2: 端口扫描
|
||||
provider = SnapshotTargetProvider(
|
||||
scan_id=100,
|
||||
snapshot_type="subdomain",
|
||||
context=ProviderContext(target_id=1, scan_id=100)
|
||||
)
|
||||
# 只返回: b.a.test.com, c.a.test.com(本次扫描发现的)
|
||||
# 不返回: www.test.com, api.test.com(历史数据)
|
||||
|
||||
# 阶段3: 网站扫描
|
||||
provider = SnapshotTargetProvider(
|
||||
scan_id=100,
|
||||
snapshot_type="host_port",
|
||||
context=ProviderContext(target_id=1, scan_id=100)
|
||||
)
|
||||
# 只返回本次扫描发现的 IP:Port
|
||||
provider = SnapshotTargetProvider(scan_id=100)
|
||||
|
||||
# 单一数据源
|
||||
for url in provider.iter_websites():
|
||||
take_screenshot(url)
|
||||
|
||||
# 调用方控制回退
|
||||
urls = list(provider.iter_endpoints())
|
||||
if not urls:
|
||||
urls = list(provider.iter_websites())
|
||||
if not urls:
|
||||
urls = list(provider.iter_default_urls())
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scan_id: int,
|
||||
snapshot_type: SnapshotType,
|
||||
context: Optional[ProviderContext] = None
|
||||
):
|
||||
"""
|
||||
初始化快照目标提供者
|
||||
|
||||
|
||||
Args:
|
||||
scan_id: 扫描任务 ID(必需)
|
||||
snapshot_type: 快照类型
|
||||
- "subdomain": 子域名快照(SubdomainSnapshot)
|
||||
- "website": 网站快照(WebsiteSnapshot)
|
||||
- "endpoint": 端点快照(EndpointSnapshot)
|
||||
- "host_port": 主机端口映射快照(HostPortMappingSnapshot)
|
||||
context: Provider 上下文
|
||||
"""
|
||||
ctx = context or ProviderContext()
|
||||
ctx.scan_id = scan_id
|
||||
super().__init__(ctx)
|
||||
self._scan_id = scan_id
|
||||
self._snapshot_type = snapshot_type
|
||||
|
||||
    def _iter_raw_hosts(self) -> Iterator[str]:
        """
        Iterate hosts from the snapshot tables.

        Selects the snapshot table based on snapshot_type:
        - subdomain: SubdomainSnapshot.name
        - host_port: HostPortMappingSnapshot.host (yields host:port, skips validation)
        """
        if self._snapshot_type == "subdomain":
            from apps.asset.services.snapshot import SubdomainSnapshotsService
            service = SubdomainSnapshotsService()
            yield from service.iter_subdomain_names_by_scan(
                scan_id=self._scan_id,
                chunk_size=1000
            )

        elif self._snapshot_type == "host_port":
            # host_port does not use _iter_raw_hosts; it is handled directly in
            # iter_hosts. Yield nothing so the base class's iter_hosts skips it.
            return

        else:
            # Other types do not support iter_hosts yet
            logger.warning(
                "Snapshot type '%s' does not support iter_hosts; returning an empty iterator",
                self._snapshot_type
            )
            return

    def iter_hosts(self) -> Iterator[str]:
        """
        Iterate the host list.

        For the host_port type, yields host:port strings without CIDR expansion.
        """
        if self._snapshot_type == "host_port":
            # host_port yields host:port directly, bypassing _expand_host validation
            from apps.asset.services.snapshot import HostPortMappingSnapshotsService
            service = HostPortMappingSnapshotsService()
            queryset = service.get_by_scan(scan_id=self._scan_id)
            for mapping in queryset.iterator(chunk_size=1000):
                yield f"{mapping.host}:{mapping.port}"
        else:
            # Other types use the base class iter_hosts (which calls
            # _iter_raw_hosts and expands CIDRs)
            yield from super().iter_hosts()

    def iter_urls(self) -> Iterator[str]:
        """
        Iterate URLs from the snapshot tables.

        Selects the snapshot table based on snapshot_type:
        - website: WebsiteSnapshot.url
        - endpoint: EndpointSnapshot.url
        """
        if self._snapshot_type == "website":
            from apps.asset.services.snapshot import WebsiteSnapshotsService
            service = WebsiteSnapshotsService()
            yield from service.iter_website_urls_by_scan(
                scan_id=self._scan_id,
                chunk_size=1000
            )

        elif self._snapshot_type == "endpoint":
            from apps.asset.services.snapshot import EndpointSnapshotsService
            service = EndpointSnapshotsService()
            # Fetch endpoint URLs from the snapshot table
            queryset = service.get_by_scan(scan_id=self._scan_id)
            for endpoint in queryset.iterator(chunk_size=1000):
                yield endpoint.url

        else:
            # Other types do not support iter_urls yet
            logger.warning(
                "Snapshot type '%s' does not support iter_urls; returning an empty iterator",
                self._snapshot_type
            )
            return

    def iter_subdomains(self) -> Iterator[str]:
        """Iterate subdomain names from SubdomainSnapshot."""
        from apps.asset.services.snapshot import SubdomainSnapshotsService
        service = SubdomainSnapshotsService()
        yield from service.iter_subdomain_names_by_scan(
            scan_id=self._scan_id,
            chunk_size=1000
        )

    def iter_host_port_urls(self) -> Iterator[str]:
        """Generate probe URLs from HostPortMappingSnapshot."""
        from apps.asset.services.snapshot import HostPortMappingSnapshotsService
        service = HostPortMappingSnapshotsService()

        for mapping in service.iter_unique_host_ports_by_scan(
            scan_id=self._scan_id,
            batch_size=1000
        ):
            host = mapping['host']
            port = mapping['port']
            if port == 80:
                yield f"http://{host}"
            elif port == 443:
                yield f"https://{host}"
            else:
                yield f"http://{host}:{port}"
                yield f"https://{host}:{port}"

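The port-to-scheme mapping above is the core convention: port 80 implies plain HTTP, 443 implies HTTPS, and any other port is worth probing under both schemes with an explicit port. A standalone sketch of that rule (the helper name `candidate_urls` is ours, not the project's):

```python
from typing import Iterator


def candidate_urls(host: str, port: int) -> Iterator[str]:
    """Map a host/port pair to probe URLs: 80 -> http, 443 -> https,
    anything else -> try both schemes with an explicit port."""
    if port == 80:
        yield f"http://{host}"
    elif port == 443:
        yield f"https://{host}"
    else:
        yield f"http://{host}:{port}"
        yield f"https://{host}:{port}"
```

The well-known ports yield one URL each, so a scan does not waste a request on `https://host:80`; unknown ports yield two candidates because either scheme may be listening.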
    def iter_websites(self) -> Iterator[str]:
        """Iterate website URLs from WebsiteSnapshot."""
        from apps.asset.services.snapshot import WebsiteSnapshotsService
        service = WebsiteSnapshotsService()
        yield from service.iter_website_urls_by_scan(
            scan_id=self._scan_id,
            chunk_size=1000
        )

    def iter_endpoints(self) -> Iterator[str]:
        """Iterate endpoint URLs from EndpointSnapshot."""
        from apps.asset.services.snapshot import EndpointSnapshotsService
        service = EndpointSnapshotsService()
        queryset = service.get_by_scan(scan_id=self._scan_id)
        for endpoint in queryset.iterator(chunk_size=1000):
            yield endpoint.url

    def get_blacklist_filter(self) -> None:
        """Snapshot data was already filtered in the previous stage."""
        return None

    @property
    def snapshot_type(self) -> SnapshotType:
        """Return the snapshot type."""
        return self._snapshot_type

@@ -1,256 +0,0 @@
"""
Common property tests.

Cross-provider property tests:
- Property 4: Context Propagation
- Property 5: Non-Database Provider Blacklist Filter
- Property 7: CIDR Expansion Consistency
"""

import pytest
from hypothesis import given, strategies as st, settings
from ipaddress import IPv4Network

from apps.scan.providers import (
    ProviderContext,
    ListTargetProvider,
    DatabaseTargetProvider,
    PipelineTargetProvider,
    SnapshotTargetProvider
)
from apps.scan.providers.pipeline_provider import StageOutput

class TestContextPropagation:
    """
    Property 4: Context Propagation

    *For any* ProviderContext passed to a Provider constructor, the Provider's
    target_id and scan_id attributes should match the values in the context.

    **Validates: Requirements 1.3, 1.5, 7.4, 7.5**
    """

    @given(
        target_id=st.integers(min_value=1, max_value=10000),
        scan_id=st.integers(min_value=1, max_value=10000)
    )
    @settings(max_examples=100)
    def test_property_4_list_provider_context_propagation(self, target_id, scan_id):
        """
        Property 4: Context Propagation (ListTargetProvider)

        Feature: scan-target-provider, Property 4: Context Propagation
        **Validates: Requirements 1.3, 1.5, 7.4, 7.5**
        """
        ctx = ProviderContext(target_id=target_id, scan_id=scan_id)
        provider = ListTargetProvider(targets=["example.com"], context=ctx)

        assert provider.target_id == target_id
        assert provider.scan_id == scan_id
        assert provider.context.target_id == target_id
        assert provider.context.scan_id == scan_id

    @given(
        target_id=st.integers(min_value=1, max_value=10000),
        scan_id=st.integers(min_value=1, max_value=10000)
    )
    @settings(max_examples=100)
    def test_property_4_database_provider_context_propagation(self, target_id, scan_id):
        """
        Property 4: Context Propagation (DatabaseTargetProvider)

        Feature: scan-target-provider, Property 4: Context Propagation
        **Validates: Requirements 1.3, 1.5, 7.4, 7.5**
        """
        ctx = ProviderContext(target_id=999, scan_id=scan_id)
        # DatabaseTargetProvider overrides the target_id from the context
        provider = DatabaseTargetProvider(target_id=target_id, context=ctx)

        assert provider.target_id == target_id  # uses the constructor argument
        assert provider.scan_id == scan_id  # uses the value from the context
        assert provider.context.target_id == target_id
        assert provider.context.scan_id == scan_id

    @given(
        target_id=st.integers(min_value=1, max_value=10000),
        scan_id=st.integers(min_value=1, max_value=10000)
    )
    @settings(max_examples=100)
    def test_property_4_pipeline_provider_context_propagation(self, target_id, scan_id):
        """
        Property 4: Context Propagation (PipelineTargetProvider)

        Feature: scan-target-provider, Property 4: Context Propagation
        **Validates: Requirements 1.3, 1.5, 7.4, 7.5**
        """
        ctx = ProviderContext(target_id=target_id, scan_id=scan_id)
        stage_output = StageOutput(hosts=["example.com"])
        provider = PipelineTargetProvider(previous_output=stage_output, context=ctx)

        assert provider.target_id == target_id
        assert provider.scan_id == scan_id
        assert provider.context.target_id == target_id
        assert provider.context.scan_id == scan_id

    @given(
        target_id=st.integers(min_value=1, max_value=10000),
        scan_id=st.integers(min_value=1, max_value=10000)
    )
    @settings(max_examples=100)
    def test_property_4_snapshot_provider_context_propagation(self, target_id, scan_id):
        """
        Property 4: Context Propagation (SnapshotTargetProvider)

        Feature: scan-target-provider, Property 4: Context Propagation
        **Validates: Requirements 1.3, 1.5, 7.4, 7.5**
        """
        ctx = ProviderContext(target_id=target_id, scan_id=999)
        # SnapshotTargetProvider overrides the scan_id from the context
        provider = SnapshotTargetProvider(
            scan_id=scan_id,
            snapshot_type="subdomain",
            context=ctx
        )

        assert provider.target_id == target_id  # uses the value from the context
        assert provider.scan_id == scan_id  # uses the constructor argument
        assert provider.context.target_id == target_id
        assert provider.context.scan_id == scan_id

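The property above reduces to a simple design rule: a provider does not store target_id/scan_id itself but delegates to its context, and constructor arguments win over context values when both are given. A minimal sketch of that rule under assumed shapes (`Context` and `BaseProvider` here are simplified stand-ins, not the project's actual classes):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Context:
    """Minimal stand-in for ProviderContext (assumed shape)."""
    target_id: Optional[int] = None
    scan_id: Optional[int] = None


class BaseProvider:
    """Providers expose target_id/scan_id by delegating to their context."""

    def __init__(self, context: Optional[Context] = None):
        self.context = context or Context()

    @property
    def target_id(self) -> Optional[int]:
        return self.context.target_id

    @property
    def scan_id(self) -> Optional[int]:
        return self.context.scan_id


class ScanScopedProvider(BaseProvider):
    """A provider whose constructor argument overrides the context's scan_id,
    mirroring how SnapshotTargetProvider is described above."""

    def __init__(self, scan_id: int, context: Optional[Context] = None):
        ctx = context or Context()
        ctx.scan_id = scan_id  # constructor argument wins
        super().__init__(ctx)
```

Because the properties read through to the context, mutating the context after construction is immediately visible on the provider, which is exactly what the propagation tests assert.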
class TestNonDatabaseProviderBlacklistFilter:
    """
    Property 5: Non-Database Provider Blacklist Filter

    *For any* ListTargetProvider or PipelineTargetProvider instance,
    the get_blacklist_filter() method should return None.

    **Validates: Requirements 3.4, 9.4, 9.5**
    """

    @given(targets=st.lists(st.text(min_size=1, max_size=20), max_size=10))
    @settings(max_examples=100)
    def test_property_5_list_provider_no_blacklist(self, targets):
        """
        Property 5: Non-Database Provider Blacklist Filter (ListTargetProvider)

        Feature: scan-target-provider, Property 5: Non-Database Provider Blacklist Filter
        **Validates: Requirements 3.4, 9.4, 9.5**
        """
        provider = ListTargetProvider(targets=targets)
        assert provider.get_blacklist_filter() is None

    @given(hosts=st.lists(st.text(min_size=1, max_size=20), max_size=10))
    @settings(max_examples=100)
    def test_property_5_pipeline_provider_no_blacklist(self, hosts):
        """
        Property 5: Non-Database Provider Blacklist Filter (PipelineTargetProvider)

        Feature: scan-target-provider, Property 5: Non-Database Provider Blacklist Filter
        **Validates: Requirements 3.4, 9.4, 9.5**
        """
        stage_output = StageOutput(hosts=hosts)
        provider = PipelineTargetProvider(previous_output=stage_output)
        assert provider.get_blacklist_filter() is None

    def test_property_5_snapshot_provider_no_blacklist(self):
        """
        Property 5: Non-Database Provider Blacklist Filter (SnapshotTargetProvider)

        Feature: scan-target-provider, Property 5: Non-Database Provider Blacklist Filter
        **Validates: Requirements 3.4, 9.4, 9.5**
        """
        provider = SnapshotTargetProvider(scan_id=1, snapshot_type="subdomain")
        assert provider.get_blacklist_filter() is None

class TestCIDRExpansionConsistency:
    """
    Property 7: CIDR Expansion Consistency

    *For any* CIDR string (e.g. "192.168.1.0/24"), the iter_hosts() method of
    every Provider should expand it into the same list of individual IP addresses.

    **Validates: Requirements 1.1, 3.6**
    """

    @given(
        # Generate small CIDR ranges to avoid test timeouts
        network_prefix=st.integers(min_value=1, max_value=254),
        cidr_suffix=st.integers(min_value=28, max_value=30)  # /28 = 16 IPs, /30 = 4 IPs
    )
    @settings(max_examples=50, deadline=None)
    def test_property_7_cidr_expansion_consistency(self, network_prefix, cidr_suffix):
        """
        Property 7: CIDR Expansion Consistency

        Feature: scan-target-provider, Property 7: CIDR Expansion Consistency
        **Validates: Requirements 1.1, 3.6**

        For any CIDR string, all Providers should expand it to the same IP list.
        """
        cidr = f"192.168.{network_prefix}.0/{cidr_suffix}"

        # Compute the expected IP list
        network = IPv4Network(cidr, strict=False)
        # Exclude the network and broadcast addresses
        expected_ips = [str(ip) for ip in network.hosts()]

        # If the CIDR is too small (/31 or /32), use all addresses
        if not expected_ips:
            expected_ips = [str(ip) for ip in network]

        # ListTargetProvider
        list_provider = ListTargetProvider(targets=[cidr])
        list_result = list(list_provider.iter_hosts())

        # PipelineTargetProvider
        stage_output = StageOutput(hosts=[cidr])
        pipeline_provider = PipelineTargetProvider(previous_output=stage_output)
        pipeline_result = list(pipeline_provider.iter_hosts())

        # Verify: every Provider expands to the same result
        assert list_result == expected_ips, f"ListProvider CIDR expansion mismatch for {cidr}"
        assert pipeline_result == expected_ips, f"PipelineProvider CIDR expansion mismatch for {cidr}"
        assert list_result == pipeline_result, f"Providers produce different results for {cidr}"

    def test_cidr_expansion_with_multiple_cidrs(self):
        """Expansion consistency across multiple CIDRs"""
        cidrs = ["192.168.1.0/30", "10.0.0.0/30"]

        # Compute the expected result
        expected_ips = []
        for cidr in cidrs:
            network = IPv4Network(cidr, strict=False)
            expected_ips.extend([str(ip) for ip in network.hosts()])

        # ListTargetProvider
        list_provider = ListTargetProvider(targets=cidrs)
        list_result = list(list_provider.iter_hosts())

        # PipelineTargetProvider
        stage_output = StageOutput(hosts=cidrs)
        pipeline_provider = PipelineTargetProvider(previous_output=stage_output)
        pipeline_result = list(pipeline_provider.iter_hosts())

        # Verify
        assert list_result == expected_ips
        assert pipeline_result == expected_ips
        assert list_result == pipeline_result

    def test_mixed_hosts_and_cidrs(self):
        """Handling of mixed hosts and CIDRs"""
        targets = ["example.com", "192.168.1.0/30", "test.com"]

        # Compute the expected result
        network = IPv4Network("192.168.1.0/30", strict=False)
        cidr_ips = [str(ip) for ip in network.hosts()]
        expected = ["example.com"] + cidr_ips + ["test.com"]

        # ListTargetProvider
        list_provider = ListTargetProvider(targets=targets)
        list_result = list(list_provider.iter_hosts())

        # Verify
        assert list_result == expected
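The expansion rule these tests encode can be isolated into a small helper: `IPv4Network.hosts()` yields usable host addresses (dropping the network and broadcast addresses for ordinary prefixes), and when that comes back empty the expansion falls back to every address in the network. A sketch of that shared logic (the helper name `expand_cidr` is ours):

```python
from ipaddress import IPv4Network
from typing import List


def expand_cidr(cidr: str) -> List[str]:
    """Expand a CIDR into usable host IPs. hosts() drops the network and
    broadcast addresses; if it yields nothing (tiny networks on some Python
    versions), fall back to every address in the network."""
    network = IPv4Network(cidr, strict=False)
    ips = [str(ip) for ip in network.hosts()]
    return ips or [str(ip) for ip in network]
```

For a /30 this yields the two middle addresses; for a /28 it yields 14 of the 16 addresses, which is exactly the `expected_ips` computation both property tests perform.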
@@ -2,7 +2,7 @@
DatabaseTargetProvider property tests

Property 7: DatabaseTargetProvider Blacklist Application
*For any* target_id with blacklist rules, DatabaseTargetProvider's iter_hosts() and iter_urls()
*For any* target_id with blacklist rules, DatabaseTargetProvider's iter_subdomains()
should filter out targets matching the blacklist rules.

**Validates: Requirements 2.3, 10.1, 10.2, 10.3**
@@ -48,7 +48,7 @@ class TestDatabaseTargetProviderProperties:
    """DatabaseTargetProvider property test class"""

    @given(
        hosts=st.lists(valid_domain_strategy(), min_size=1, max_size=20),
        subdomains=st.lists(valid_domain_strategy(), min_size=1, max_size=20),
        blocked_keyword=st.text(
            alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
            min_size=2,
@@ -56,15 +56,15 @@
        )
    )
    @settings(max_examples=100)
    def test_property_7_blacklist_filters_hosts(self, hosts, blocked_keyword):
    def test_property_7_blacklist_filters_subdomains(self, subdomains, blocked_keyword):
        """
        Property 7: DatabaseTargetProvider Blacklist Application (hosts)
        Property 7: DatabaseTargetProvider Blacklist Application (subdomains)

        Feature: scan-target-provider, Property 7: DatabaseTargetProvider Blacklist Application
        **Validates: Requirements 2.3, 10.1, 10.2, 10.3**

        For any set of hosts and a blacklist keyword, the provider should filter out
        all hosts containing the blocked keyword.
        For any set of subdomains and a blacklist keyword, the provider should filter out
        all subdomains containing the blocked keyword.
        """
        # Create a mock blacklist filter
        mock_filter = MockBlacklistFilter([blocked_keyword])
@@ -73,31 +73,18 @@
        provider = DatabaseTargetProvider(target_id=1)
        provider._blacklist_filter = mock_filter

        # Mock Target and SubdomainService
        mock_target = MagicMock()
        mock_target.type = 'domain'
        mock_target.name = hosts[0] if hosts else 'example.com'

        with patch('apps.targets.services.TargetService') as mock_target_service, \
             patch('apps.asset.services.asset.subdomain_service.SubdomainService') as mock_subdomain_service:

            mock_target_service.return_value.get_target.return_value = mock_target
            mock_subdomain_service.return_value.iter_subdomain_names_by_target.return_value = iter(hosts[1:] if len(hosts) > 1 else [])
        with patch('apps.asset.services.asset.subdomain_service.SubdomainService') as mock_subdomain_service:
            mock_subdomain_service.return_value.iter_subdomain_names_by_target.return_value = iter(subdomains)

            # Fetch the results
            result = list(provider.iter_hosts())
            result = list(provider.iter_subdomains())

            # Verify: no result contains the blocked keyword
            for host in result:
                assert blocked_keyword not in host, f"Host '{host}' should be filtered by blacklist keyword '{blocked_keyword}'"

            # Verify: every host without the keyword should be in the result
            if hosts:
                all_hosts = [hosts[0]] + [h for h in hosts[1:] if h != hosts[0]]
                expected_allowed = [h for h in all_hosts if blocked_keyword not in h]
            else:
                expected_allowed = []
            for subdomain in result:
                assert blocked_keyword not in subdomain, f"Subdomain '{subdomain}' should be filtered by blacklist keyword '{blocked_keyword}'"

            # Verify: every subdomain without the keyword should be in the result
            expected_allowed = [s for s in subdomains if blocked_keyword not in s]
            assert set(result) == set(expected_allowed)

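The property being tested is a plain keyword filter: drop any name containing a blacklisted keyword, keep everything else. A standalone sketch of what MockBlacklistFilter is assumed to do in these tests (`apply_blacklist` is our name; substring matching is an assumption inferred from the assertions):

```python
from typing import Iterable, Iterator, List


def apply_blacklist(names: Iterable[str], blocked_keywords: List[str]) -> Iterator[str]:
    """Yield only names that contain none of the blacklisted keywords
    (substring match)."""
    for name in names:
        if not any(keyword in name for keyword in blocked_keywords):
            yield name
```

The two assertions in the property test correspond directly: nothing yielded contains a keyword, and everything keyword-free is yielded.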
@@ -144,15 +131,38 @@ class TestDatabaseTargetProviderUnit:
        # BlacklistService should only be called once
        mock_service.return_value.get_rules.assert_called_once_with(123)

    def test_nonexistent_target_returns_empty(self):
        """A nonexistent target returns an empty iterator"""
    def test_get_target_name(self):
        """get_target_name returns the Target name"""
        provider = DatabaseTargetProvider(target_id=123)

        mock_target = MagicMock()
        mock_target.name = 'example.com'

        with patch('apps.targets.services.TargetService') as mock_service:
            mock_service.return_value.get_target.return_value = mock_target

            result = provider.get_target_name()
            assert result == 'example.com'

    def test_get_target_name_nonexistent(self):
        """A nonexistent target returns None"""
        provider = DatabaseTargetProvider(target_id=99999)

        with patch('apps.targets.services.TargetService') as mock_service, \
        with patch('apps.targets.services.TargetService') as mock_service:
            mock_service.return_value.get_target.return_value = None

            result = provider.get_target_name()
            assert result is None

    def test_iter_subdomains_empty(self):
        """Empty subdomain list"""
        provider = DatabaseTargetProvider(target_id=123)

        with patch('apps.asset.services.asset.subdomain_service.SubdomainService') as mock_service, \
             patch('apps.common.services.BlacklistService') as mock_blacklist_service:

            mock_service.return_value.get_target.return_value = None
            mock_service.return_value.iter_subdomain_names_by_target.return_value = iter([])
            mock_blacklist_service.return_value.get_rules.return_value = []

            result = list(provider.iter_hosts())
            result = list(provider.iter_subdomains())
            assert result == []

@@ -1,152 +0,0 @@
"""
ListTargetProvider property tests.

Property 1: ListTargetProvider Round-Trip
*For any* host list and URL list, creating a ListTargetProvider and iterating
iter_hosts() and iter_urls() should return the same elements as the input,
in the same order.

**Validates: Requirements 3.1, 3.2**
"""

import pytest
from hypothesis import given, strategies as st, settings, assume

from apps.scan.providers.list_provider import ListTargetProvider
from apps.scan.providers.base import ProviderContext


# Strategy for generating valid domain names
def valid_domain_strategy():
    """Generate valid domain names"""
    # Simple domain format: subdomain.domain.tld
    label = st.text(
        alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
        min_size=2,
        max_size=10
    )
    return st.builds(
        lambda a, b, c: f"{a}.{b}.{c}",
        label, label, st.sampled_from(['com', 'net', 'org', 'io'])
    )


# Strategy for generating valid IP addresses
def valid_ip_strategy():
    """Generate valid IPv4 addresses"""
    octet = st.integers(min_value=1, max_value=254)
    return st.builds(
        lambda a, b, c, d: f"{a}.{b}.{c}.{d}",
        octet, octet, octet, octet
    )


# Combined strategy: domain or IP
host_strategy = st.one_of(valid_domain_strategy(), valid_ip_strategy())


# Strategy for generating valid URLs
def valid_url_strategy():
    """Generate valid URLs"""
    domain = valid_domain_strategy()
    return st.builds(
        lambda d, path: f"https://{d}/{path}" if path else f"https://{d}",
        domain,
        st.one_of(
            st.just(""),
            st.text(
                alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
                min_size=1,
                max_size=10
            )
        )
    )


url_strategy = valid_url_strategy()

class TestListTargetProviderProperties:
    """ListTargetProvider property test class"""

    @given(hosts=st.lists(host_strategy, max_size=50))
    @settings(max_examples=100)
    def test_property_1_hosts_round_trip(self, hosts):
        """
        Property 1: ListTargetProvider Round-Trip (hosts)

        Feature: scan-target-provider, Property 1: ListTargetProvider Round-Trip
        **Validates: Requirements 3.1, 3.2**

        For any host list, creating a ListTargetProvider and iterating iter_hosts()
        should return the same elements in the same order.
        """
        # ListTargetProvider takes a targets argument and auto-classifies
        # each entry as a host or a URL
        provider = ListTargetProvider(targets=hosts)
        result = list(provider.iter_hosts())
        assert result == hosts

    @given(urls=st.lists(url_strategy, max_size=50))
    @settings(max_examples=100)
    def test_property_1_urls_round_trip(self, urls):
        """
        Property 1: ListTargetProvider Round-Trip (urls)

        Feature: scan-target-provider, Property 1: ListTargetProvider Round-Trip
        **Validates: Requirements 3.1, 3.2**

        For any URL list, creating a ListTargetProvider and iterating iter_urls()
        should return the same elements in the same order.
        """
        # ListTargetProvider takes a targets argument and auto-classifies
        # each entry as a host or a URL
        provider = ListTargetProvider(targets=urls)
        result = list(provider.iter_urls())
        assert result == urls

    @given(
        hosts=st.lists(host_strategy, max_size=30),
        urls=st.lists(url_strategy, max_size=30)
    )
    @settings(max_examples=100)
    def test_property_1_combined_round_trip(self, hosts, urls):
        """
        Property 1: ListTargetProvider Round-Trip (combined)

        Feature: scan-target-provider, Property 1: ListTargetProvider Round-Trip
        **Validates: Requirements 3.1, 3.2**

        For any combination of hosts and URLs, both should round-trip correctly.
        """
        # Merge hosts and urls; ListTargetProvider classifies them automatically
        combined = hosts + urls
        provider = ListTargetProvider(targets=combined)

        hosts_result = list(provider.iter_hosts())
        urls_result = list(provider.iter_urls())

        assert hosts_result == hosts
        assert urls_result == urls

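The combined round-trip test relies on the provider splitting a mixed `targets` list into hosts and URLs while preserving each sublist's order. One plausible reading of that classification, sketched as a pure function — the split-by-scheme rule and the name `classify_targets` are our assumptions, not ListTargetProvider's actual implementation:

```python
from typing import Iterable, List, Tuple


def classify_targets(targets: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Split mixed targets into (hosts, urls): entries with an explicit
    http(s) scheme are URLs, everything else is treated as a host/CIDR."""
    hosts: List[str] = []
    urls: List[str] = []
    for target in targets:
        bucket = urls if target.startswith(("http://", "https://")) else hosts
        bucket.append(target)
    return hosts, urls
```

Because each bucket is appended to in input order, `hosts + urls` round-trips exactly as the property test demands when the inputs are generated separately and concatenated.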
class TestListTargetProviderUnit:
    """ListTargetProvider unit test class"""

    def test_empty_lists(self):
        """Empty lists yield empty iterators - Requirements 3.5"""
        provider = ListTargetProvider()
        assert list(provider.iter_hosts()) == []
        assert list(provider.iter_urls()) == []

    def test_blacklist_filter_returns_none(self):
        """Blacklist filter returns None - Requirements 3.4"""
        provider = ListTargetProvider(targets=["example.com"])
        assert provider.get_blacklist_filter() is None

    def test_target_id_association(self):
        """target_id association - Requirements 3.3"""
        ctx = ProviderContext(target_id=123)
        provider = ListTargetProvider(targets=["example.com"], context=ctx)
        assert provider.target_id == 123

    def test_context_propagation(self):
        """Context propagation"""
        ctx = ProviderContext(target_id=456, scan_id=789)
        provider = ListTargetProvider(targets=["example.com"], context=ctx)

        assert provider.target_id == 456
        assert provider.scan_id == 789
@@ -1,180 +0,0 @@
"""
PipelineTargetProvider property tests.

Property 3: PipelineTargetProvider Round-Trip
*For any* StageOutput object, PipelineTargetProvider's iter_hosts() and iter_urls()
should return the same elements as the hosts and urls lists in the StageOutput.

**Validates: Requirements 5.1, 5.2**
"""

import pytest
from hypothesis import given, strategies as st, settings

from apps.scan.providers.pipeline_provider import PipelineTargetProvider, StageOutput
from apps.scan.providers.base import ProviderContext


# Strategy for generating valid domain names
def valid_domain_strategy():
    """Generate valid domain names"""
    label = st.text(
        alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
        min_size=2,
        max_size=10
    )
    return st.builds(
        lambda a, b, c: f"{a}.{b}.{c}",
        label, label, st.sampled_from(['com', 'net', 'org', 'io'])
    )


# Strategy for generating valid IP addresses
def valid_ip_strategy():
    """Generate valid IPv4 addresses"""
    octet = st.integers(min_value=1, max_value=254)
    return st.builds(
        lambda a, b, c, d: f"{a}.{b}.{c}.{d}",
        octet, octet, octet, octet
    )


# Combined strategy: domain or IP
host_strategy = st.one_of(valid_domain_strategy(), valid_ip_strategy())


# Strategy for generating valid URLs
def valid_url_strategy():
    """Generate valid URLs"""
    domain = valid_domain_strategy()
    return st.builds(
        lambda d, path: f"https://{d}/{path}" if path else f"https://{d}",
        domain,
        st.one_of(
            st.just(""),
            st.text(
                alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
                min_size=1,
                max_size=10
            )
        )
    )


url_strategy = valid_url_strategy()

class TestPipelineTargetProviderProperties:
    """PipelineTargetProvider property test class"""

    @given(hosts=st.lists(host_strategy, max_size=50))
    @settings(max_examples=100)
    def test_property_3_hosts_round_trip(self, hosts):
        """
        Property 3: PipelineTargetProvider Round-Trip (hosts)

        Feature: scan-target-provider, Property 3: PipelineTargetProvider Round-Trip
        **Validates: Requirements 5.1, 5.2**

        For any StageOutput with hosts, PipelineTargetProvider should return
        the same hosts in the same order.
        """
        stage_output = StageOutput(hosts=hosts)
        provider = PipelineTargetProvider(previous_output=stage_output)
        result = list(provider.iter_hosts())
        assert result == hosts

    @given(urls=st.lists(url_strategy, max_size=50))
    @settings(max_examples=100)
    def test_property_3_urls_round_trip(self, urls):
        """
        Property 3: PipelineTargetProvider Round-Trip (urls)

        Feature: scan-target-provider, Property 3: PipelineTargetProvider Round-Trip
        **Validates: Requirements 5.1, 5.2**

        For any StageOutput with urls, PipelineTargetProvider should return
        the same urls in the same order.
        """
        stage_output = StageOutput(urls=urls)
        provider = PipelineTargetProvider(previous_output=stage_output)
        result = list(provider.iter_urls())
        assert result == urls

    @given(
        hosts=st.lists(host_strategy, max_size=30),
        urls=st.lists(url_strategy, max_size=30)
    )
    @settings(max_examples=100)
    def test_property_3_combined_round_trip(self, hosts, urls):
        """
        Property 3: PipelineTargetProvider Round-Trip (combined)

        Feature: scan-target-provider, Property 3: PipelineTargetProvider Round-Trip
        **Validates: Requirements 5.1, 5.2**

        For any StageOutput with both hosts and urls, both should round-trip correctly.
        """
        stage_output = StageOutput(hosts=hosts, urls=urls)
        provider = PipelineTargetProvider(previous_output=stage_output)

        hosts_result = list(provider.iter_hosts())
        urls_result = list(provider.iter_urls())

        assert hosts_result == hosts
        assert urls_result == urls

class TestPipelineTargetProviderUnit:
    """PipelineTargetProvider unit test class"""

    def test_empty_stage_output(self):
        """An empty StageOutput yields empty iterators - Requirements 5.5"""
        stage_output = StageOutput()
        provider = PipelineTargetProvider(previous_output=stage_output)

        assert list(provider.iter_hosts()) == []
        assert list(provider.iter_urls()) == []

    def test_blacklist_filter_returns_none(self):
        """Blacklist filter returns None - Requirements 5.3"""
        stage_output = StageOutput(hosts=["example.com"])
        provider = PipelineTargetProvider(previous_output=stage_output)
        assert provider.get_blacklist_filter() is None

    def test_target_id_association(self):
        """target_id association - Requirements 5.4"""
        stage_output = StageOutput(hosts=["example.com"])
        provider = PipelineTargetProvider(previous_output=stage_output, target_id=123)
        assert provider.target_id == 123

    def test_context_propagation(self):
        """Context propagation"""
        ctx = ProviderContext(target_id=456, scan_id=789)
        stage_output = StageOutput(hosts=["example.com"])
        provider = PipelineTargetProvider(previous_output=stage_output, context=ctx)

        assert provider.target_id == 456
        assert provider.scan_id == 789

    def test_previous_output_property(self):
        """The previous_output property"""
        stage_output = StageOutput(hosts=["example.com"], urls=["https://example.com"])
        provider = PipelineTargetProvider(previous_output=stage_output)

        assert provider.previous_output is stage_output
        assert provider.previous_output.hosts == ["example.com"]
        assert provider.previous_output.urls == ["https://example.com"]

    def test_stage_output_with_metadata(self):
        """A StageOutput carrying metadata"""
        stage_output = StageOutput(
            hosts=["example.com"],
            urls=["https://example.com"],
            new_targets=["new.example.com"],
            stats={"count": 1},
            success=True,
            error=None
        )
        provider = PipelineTargetProvider(previous_output=stage_output)

        assert list(provider.iter_hosts()) == ["example.com"]
        assert list(provider.iter_urls()) == ["https://example.com"]
        assert provider.previous_output.new_targets == ["new.example.com"]
        assert provider.previous_output.stats == {"count": 1}
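The metadata test above exercises six StageOutput fields. Their shape can be sketched as a dataclass — this is our inference from the fields the tests touch, not the project's actual definition (`StageOutputSketch` is a stand-in name):

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class StageOutputSketch:
    """Assumed shape of StageOutput, inferred from the fields used in the
    tests: scan targets plus bookkeeping about the stage that produced them."""
    hosts: List[str] = field(default_factory=list)
    urls: List[str] = field(default_factory=list)
    new_targets: List[str] = field(default_factory=list)
    stats: Dict[str, Any] = field(default_factory=dict)
    success: bool = True
    error: Optional[str] = None
```

Mutable defaults go through `field(default_factory=...)` so two empty outputs never share a list, which is what lets `StageOutput()` safely mean "this stage produced nothing".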
@@ -10,182 +10,112 @@ from apps.scan.providers import SnapshotTargetProvider, ProviderContext

class TestSnapshotTargetProvider:
    """Test class for SnapshotTargetProvider"""

    def test_init_with_scan_id_and_type(self):
    def test_init_with_scan_id(self):
        """Test initialization"""
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="subdomain"
        )
        provider = SnapshotTargetProvider(scan_id=100)

        assert provider.scan_id == 100
        assert provider.snapshot_type == "subdomain"
        assert provider.target_id is None  # default context
        assert provider.target_id is None

    def test_init_with_context(self):
        """Test initialization with a context"""
        ctx = ProviderContext(target_id=1, scan_id=100)
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="subdomain",
            context=ctx
        )
        provider = SnapshotTargetProvider(scan_id=100, context=ctx)

        assert provider.scan_id == 100
        assert provider.target_id == 1
        assert provider.snapshot_type == "subdomain"

    @patch('apps.asset.services.snapshot.SubdomainSnapshotsService')
    def test_iter_hosts_subdomain(self, mock_service_class):
        """Test iterating hosts from subdomain snapshots"""
        # Mock service
    def test_iter_subdomains(self, mock_service_class):
        """Test iterating subdomains from subdomain snapshots"""
        mock_service = Mock()
        mock_service.iter_subdomain_names_by_scan.return_value = iter([
            "a.example.com",
            "b.example.com"
        ])
        mock_service_class.return_value = mock_service

        # Create the provider
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="subdomain"
        )

        # Iterate hosts
        hosts = list(provider.iter_hosts())

        assert hosts == ["a.example.com", "b.example.com"]

        provider = SnapshotTargetProvider(scan_id=100)
        subdomains = list(provider.iter_subdomains())

        assert subdomains == ["a.example.com", "b.example.com"]
        mock_service.iter_subdomain_names_by_scan.assert_called_once_with(
            scan_id=100,
            chunk_size=1000
        )

    @patch('apps.asset.services.snapshot.HostPortMappingSnapshotsService')
    def test_iter_hosts_host_port(self, mock_service_class):
        """Test iterating hosts from host/port mapping snapshots"""
        # Mock queryset
        mock_mapping1 = Mock()
        mock_mapping1.host = "example.com"
        mock_mapping1.port = 80

        mock_mapping2 = Mock()
        mock_mapping2.host = "example.com"
        mock_mapping2.port = 443

        mock_queryset = Mock()
        mock_queryset.iterator.return_value = iter([mock_mapping1, mock_mapping2])

        # Mock service
    def test_iter_host_port_urls(self, mock_service_class):
        """Test generating URLs from host/port mapping snapshots"""
        mock_service = Mock()
        mock_service.get_by_scan.return_value = mock_queryset
        mock_service.iter_unique_host_ports_by_scan.return_value = iter([
            {'host': 'example.com', 'port': 80},
            {'host': 'example.com', 'port': 443},
            {'host': 'example.com', 'port': 8080},
        ])
        mock_service_class.return_value = mock_service

        # Create the provider
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="host_port"
        )

        # Iterate hosts
        hosts = list(provider.iter_hosts())

        assert hosts == ["example.com:80", "example.com:443"]
        mock_service.get_by_scan.assert_called_once_with(scan_id=100)

        provider = SnapshotTargetProvider(scan_id=100)
        urls = list(provider.iter_host_port_urls())

        assert urls == [
            "http://example.com",
            "https://example.com",
            "http://example.com:8080",
            "https://example.com:8080",
        ]
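The expansion the new test expects (default ports collapse to a bare scheme URL, any other port yields both schemes with an explicit port) can be sketched in isolation; this is a standalone illustration of the expected mapping, not the project's actual `iter_host_port_urls` implementation:

```python
from typing import Dict, Iterable, Iterator, List


def host_port_to_urls(mapping: Dict[str, int]) -> List[str]:
    """Expand one host/port pair into probe URLs.

    Port 80 maps to a plain http URL, 443 to a plain https URL;
    any other port yields both schemes with the port made explicit.
    """
    host, port = mapping["host"], mapping["port"]
    if port == 80:
        return [f"http://{host}"]
    if port == 443:
        return [f"https://{host}"]
    return [f"http://{host}:{port}", f"https://{host}:{port}"]


def iter_host_port_urls(mappings: Iterable[Dict[str, int]]) -> Iterator[str]:
    for m in mappings:
        yield from host_port_to_urls(m)


urls = list(iter_host_port_urls([
    {"host": "example.com", "port": 80},
    {"host": "example.com", "port": 443},
    {"host": "example.com", "port": 8080},
]))
print(urls)
```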
    @patch('apps.asset.services.snapshot.WebsiteSnapshotsService')
    def test_iter_urls_website(self, mock_service_class):
    def test_iter_websites(self, mock_service_class):
        """Test iterating URLs from website snapshots"""
        # Mock service
        mock_service = Mock()
        mock_service.iter_website_urls_by_scan.return_value = iter([
            "http://example.com",
            "https://example.com"
        ])
        mock_service_class.return_value = mock_service

        # Create the provider
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="website"
        )

        # Iterate URLs
        urls = list(provider.iter_urls())

        provider = SnapshotTargetProvider(scan_id=100)
        urls = list(provider.iter_websites())

        assert urls == ["http://example.com", "https://example.com"]
        mock_service.iter_website_urls_by_scan.assert_called_once_with(
            scan_id=100,
            chunk_size=1000
        )

    @patch('apps.asset.services.snapshot.EndpointSnapshotsService')
    def test_iter_urls_endpoint(self, mock_service_class):
    def test_iter_endpoints(self, mock_service_class):
        """Test iterating URLs from endpoint snapshots"""
        # Mock queryset
        mock_endpoint1 = Mock()
        mock_endpoint1.url = "http://example.com/api/v1"

        mock_endpoint2 = Mock()
        mock_endpoint2.url = "http://example.com/api/v2"

        mock_queryset = Mock()
        mock_queryset.iterator.return_value = iter([mock_endpoint1, mock_endpoint2])

        # Mock service
        mock_service = Mock()
        mock_service.get_by_scan.return_value = mock_queryset
        mock_service_class.return_value = mock_service

        # Create the provider
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="endpoint"
        )

        # Iterate URLs
        urls = list(provider.iter_urls())

        provider = SnapshotTargetProvider(scan_id=100)
        urls = list(provider.iter_endpoints())

        assert urls == ["http://example.com/api/v1", "http://example.com/api/v2"]
        mock_service.get_by_scan.assert_called_once_with(scan_id=100)

    def test_iter_hosts_unsupported_type(self):
        """Test an unsupported snapshot type (iter_hosts)"""
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="website"  # website does not support iter_hosts
        )

        hosts = list(provider.iter_hosts())
        assert hosts == []

    def test_iter_urls_unsupported_type(self):
        """Test an unsupported snapshot type (iter_urls)"""
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="subdomain"  # subdomain does not support iter_urls
        )

        urls = list(provider.iter_urls())
        assert urls == []

    def test_get_blacklist_filter(self):
        """Test the blacklist filter (snapshot mode does not use a blacklist)"""
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="subdomain"
        )

        provider = SnapshotTargetProvider(scan_id=100)
        assert provider.get_blacklist_filter() is None

    def test_context_propagation(self):
        """Test context propagation"""
        ctx = ProviderContext(target_id=456, scan_id=789)
        provider = SnapshotTargetProvider(
            scan_id=100,  # overridden by the context
            snapshot_type="subdomain",
            context=ctx
        )

        provider = SnapshotTargetProvider(scan_id=100, context=ctx)

        assert provider.target_id == 456
        assert provider.scan_id == 100  # scan_id is set in __init__
        assert provider.scan_id == 100
@@ -137,16 +137,14 @@ def main():
    print("[2/4] Parsing command-line arguments...")
    parser = argparse.ArgumentParser(description="Run the scan initialization flow")
    parser.add_argument("--scan_id", type=int, required=True, help="Scan task ID")
    parser.add_argument("--target_name", type=str, required=True, help="Target name")
    parser.add_argument("--target_id", type=int, required=True, help="Target ID")
    parser.add_argument("--scan_workspace_dir", type=str, required=True, help="Scan workspace directory")
    parser.add_argument("--engine_name", type=str, required=True, help="Engine name")
    parser.add_argument("--scheduled_scan_name", type=str, default=None, help="Scheduled scan name (optional)")

    args = parser.parse_args()
    print(f"[2/4] ✓ Arguments parsed successfully:")
    print(f"    scan_id: {args.scan_id}")
    print(f"    target_name: {args.target_name}")
    print(f"    target_id: {args.target_id}")
    print(f"    scan_workspace_dir: {args.scan_workspace_dir}")
    print(f"    engine_name: {args.engine_name}")
@@ -171,7 +169,6 @@ def main():
    try:
        result = initiate_scan_flow(
            scan_id=args.scan_id,
            target_name=args.target_name,
            target_id=args.target_id,
            scan_workspace_dir=args.scan_workspace_dir,
            engine_name=args.engine_name,
@@ -15,11 +15,11 @@ class ScanSerializer(serializers.ModelSerializer):
        fields = [
            'id', 'target', 'target_name', 'engine_ids', 'engine_names',
            'created_at', 'stopped_at', 'status', 'results_dir',
            'container_ids', 'error_message'
            'container_ids', 'error_message', 'scan_mode'
        ]
        read_only_fields = [
            'id', 'created_at', 'stopped_at', 'results_dir',
            'container_ids', 'error_message', 'status'
            'container_ids', 'error_message', 'status', 'scan_mode'
        ]

    def get_target_name(self, obj):
@@ -39,9 +39,10 @@ class ScanHistorySerializer(serializers.ModelSerializer):
    class Meta:
        model = Scan
        fields = [
            'id', 'target', 'target_name', 'engine_ids', 'engine_names',
            'worker_name', 'created_at', 'status', 'error_message', 'summary',
            'progress', 'current_stage', 'stage_progress', 'yaml_configuration'
            'id', 'target', 'target_name', 'engine_ids', 'engine_names',
            'worker_name', 'created_at', 'status', 'error_message', 'summary',
            'progress', 'current_stage', 'stage_progress', 'yaml_configuration',
            'scan_mode'
        ]

    def get_summary(self, obj):
@@ -17,23 +17,15 @@ from .scan_state_service import ScanStateService
from .scan_control_service import ScanControlService
from .scan_stats_service import ScanStatsService
from .scheduled_scan_service import ScheduledScanService
from .target_export_service import (
    TargetExportService,
    create_export_service,
    export_urls_with_fallback,
    DataSource,
)
from .scan_input_target_service import ScanInputTargetService

__all__ = [
    'ScanService',  # main entry point (backward compatible)
    'ScanService',
    'ScanCreationService',
    'ScanStateService',
    'ScanControlService',
    'ScanStatsService',
    'ScheduledScanService',
    'TargetExportService',  # target export service
    'create_export_service',
    'export_urls_with_fallback',
    'DataSource',
    'ScanInputTargetService',
]
@@ -5,13 +5,16 @@
"""

import logging
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional, Literal, List, Dict, Any
from urllib.parse import urlparse

from django.db import transaction

from apps.common.validators import validate_url, detect_input_type, validate_domain, validate_ip, validate_cidr, is_valid_ip
from apps.common.validators import (
    validate_url, detect_input_type, validate_domain,
    validate_ip, validate_cidr, is_valid_ip
)
from apps.targets.services.target_service import TargetService
from apps.targets.models import Target
from apps.asset.dtos import WebSiteDTO
@@ -24,98 +27,72 @@ logger = logging.getLogger(__name__)

@dataclass
class ParsedInputDTO:
    """
    Parsed-input DTO

    Used only in the quick-scan flow
    """
    """Parsed-input DTO, used only in the quick-scan flow"""
    original_input: str
    input_type: Literal['url', 'domain', 'ip', 'cidr']
    target_name: str  # host/domain/ip/cidr
    target_name: str
    target_type: Literal['domain', 'ip', 'cidr']
    website_url: Optional[str] = None  # root URL (scheme://host[:port])
    endpoint_url: Optional[str] = None  # full URL (including path)
    is_valid: bool = True
    error: Optional[str] = None
    website_url: Optional[str] = None
    endpoint_url: Optional[str] = None
    line_number: Optional[int] = None
    # Keep validation state in a nested structure to reduce top-level attributes
    validation: Dict[str, Any] = field(default_factory=lambda: {
        'is_valid': True,
        'error': None
    })

    @property
    def is_valid(self) -> bool:
        return self.validation.get('is_valid', True)

    @property
    def error(self) -> Optional[str]:
        return self.validation.get('error')
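The refactor above moves the validation flags into a nested dict while keeping `is_valid` and `error` available as read-only properties, so existing call sites keep working. The pattern in isolation (`Parsed` is a hypothetical stand-in for `ParsedInputDTO`, reduced to just the validation fields):

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class Parsed:
    # Hypothetical stand-in for ParsedInputDTO, showing only the validation pattern
    value: str
    validation: Dict[str, Any] = field(
        default_factory=lambda: {'is_valid': True, 'error': None}
    )

    @property
    def is_valid(self) -> bool:
        # Read-only view onto the nested dict, defaulting to valid
        return self.validation.get('is_valid', True)

    @property
    def error(self) -> Optional[str]:
        return self.validation.get('error')


ok = Parsed(value="example.com")
bad = Parsed(value="???", validation={'is_valid': False, 'error': 'unrecognized input'})
print(ok.is_valid, bad.is_valid, bad.error)
```

Callers that read `dto.is_valid` or `dto.error` see no difference; only construction sites that previously passed `is_valid=`/`error=` keyword arguments need to pass `validation=` instead.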
class QuickScanService:
    """Quick-scan service - parses inputs and creates assets"""

    def __init__(self):
        self.target_service = TargetService()
        self.website_repo = DjangoWebSiteRepository()
        self.endpoint_repo = DjangoEndpointRepository()

    def parse_inputs(self, inputs: List[str]) -> List[ParsedInputDTO]:
        """
        Parse multi-line input

        Args:
            inputs: list of input strings (one per line)

        Returns:
            List of parse results (blank lines skipped)
        """
        """Parse multi-line input into a list of parse results (blank lines skipped)"""
        results = []
        for line_number, input_str in enumerate(inputs, start=1):
            input_str = input_str.strip()

            # Skip blank lines
            if not input_str:
                continue

            try:
                # Detect the input type
                input_type = detect_input_type(input_str)

                if input_type == 'url':
                    dto = self._parse_url_input(input_str, line_number)
                else:
                    dto = self._parse_target_input(input_str, input_type, line_number)

                results.append(dto)
            except ValueError as e:
                # Parsing failed; record the error
                results.append(ParsedInputDTO(
                    original_input=input_str,
                    input_type='domain',  # default type
                    input_type='domain',
                    target_name=input_str,
                    target_type='domain',
                    is_valid=False,
                    error=str(e),
                    line_number=line_number
                    line_number=line_number,
                    validation={'is_valid': False, 'error': str(e)}
                ))

        return results

    def _parse_url_input(self, url_str: str, line_number: int) -> ParsedInputDTO:
        """
        Parse a URL input

        Args:
            url_str: URL string
            line_number: line number

        Returns:
            ParsedInputDTO
        """
        # Validate the URL format
        """Parse a URL input"""
        validate_url(url_str)

        # Parse with the standard library
        parsed = urlparse(url_str)

        host = parsed.hostname  # without port
        host = parsed.hostname
        has_path = parsed.path and parsed.path != '/'

        # Build root_url: scheme://host[:port]
        root_url = f"{parsed.scheme}://{parsed.netloc}"

        # Detect the host type (domain or ip)
        target_type = 'ip' if is_valid_ip(host) else 'domain'

        return ParsedInputDTO(
            original_input=url_str,
            input_type='url',
@@ -125,167 +102,98 @@ class QuickScanService:
            endpoint_url=url_str if has_path else None,
            line_number=line_number
        )

    def _parse_target_input(
        self,
        input_str: str,
        input_type: str,
        self,
        input_str: str,
        input_type: str,
        line_number: int
    ) -> ParsedInputDTO:
        """
        Parse a non-URL input (domain/ip/cidr)

        Args:
            input_str: input string
            input_type: input type
            line_number: line number

        Returns:
            ParsedInputDTO
        """
        # Validate the format
        if input_type == 'domain':
            validate_domain(input_str)
            target_type = 'domain'
        elif input_type == 'ip':
            validate_ip(input_str)
            target_type = 'ip'
        elif input_type == 'cidr':
            validate_cidr(input_str)
            target_type = 'cidr'
        else:
        """Parse a non-URL input (domain/ip/cidr)"""
        validators = {
            'domain': (validate_domain, 'domain'),
            'ip': (validate_ip, 'ip'),
            'cidr': (validate_cidr, 'cidr'),
        }

        if input_type not in validators:
            raise ValueError(f"Unknown input type: {input_type}")

        validator, target_type = validators[input_type]
        validator(input_str)

        return ParsedInputDTO(
            original_input=input_str,
            input_type=input_type,
            target_name=input_str,
            target_type=target_type,
            website_url=None,
            endpoint_url=None,
            line_number=line_number
        )
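Replacing the if/elif chain with a lookup table keeps each validator and its target type together and makes the unknown-type error a single check. A self-contained sketch of that dispatch, with toy validators standing in for the real `validate_domain`/`validate_ip`/`validate_cidr` from `apps.common.validators`:

```python
import ipaddress
import re


# Toy validators standing in for apps.common.validators; each raises
# ValueError on malformed input and returns None otherwise.
def validate_domain(s: str) -> None:
    if not re.fullmatch(r"[a-z0-9.-]+\.[a-z]{2,}", s):
        raise ValueError(f"invalid domain: {s}")


def validate_ip(s: str) -> None:
    ipaddress.ip_address(s)  # raises ValueError on bad input


def validate_cidr(s: str) -> None:
    ipaddress.ip_network(s, strict=False)  # raises ValueError on bad input


VALIDATORS = {
    'domain': (validate_domain, 'domain'),
    'ip': (validate_ip, 'ip'),
    'cidr': (validate_cidr, 'cidr'),
}


def check(input_str: str, input_type: str) -> str:
    """Dispatch to the right validator and return the target type."""
    if input_type not in VALIDATORS:
        raise ValueError(f"unknown input type: {input_type}")
    validator, target_type = VALIDATORS[input_type]
    validator(input_str)  # raises ValueError when the value is malformed
    return target_type


print(check("example.com", "domain"), check("10.0.0.0/8", "cidr"))
```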
    @transaction.atomic
    def process_quick_scan(
        self,
        inputs: List[str],
        engine_id: int
    ) -> Dict[str, Any]:
        """
        Handle a quick-scan request

        Args:
            inputs: list of input strings
            engine_id: scan engine ID

        Returns:
            Result dict
        """
        # 1. Parse the inputs
    def process_quick_scan(self, inputs: List[str]) -> Dict[str, Any]:
        """Handle a quick-scan request"""
        parsed_inputs = self.parse_inputs(inputs)

        # Separate valid and invalid inputs
        valid_inputs = [p for p in parsed_inputs if p.is_valid]
        invalid_inputs = [p for p in parsed_inputs if not p.is_valid]

        errors = [
            {'line_number': p.line_number, 'input': p.original_input, 'error': p.error}
            for p in invalid_inputs
        ]

        if not valid_inputs:
            return {
                'targets': [],
                'target_stats': {'created': 0, 'reused': 0, 'failed': len(invalid_inputs)},
                'asset_stats': {'websites_created': 0, 'endpoints_created': 0},
                'errors': [
                    {'line_number': p.line_number, 'input': p.original_input, 'error': p.error}
                    for p in invalid_inputs
                ]
                'errors': errors
            }

        # 2. Create assets
        asset_result = self.create_assets_from_parsed_inputs(valid_inputs)

        # 3. Return the result

        # Build the target_name → inputs mapping
        target_inputs_map: Dict[str, List[str]] = {}
        for p in valid_inputs:
            target_inputs_map.setdefault(p.target_name, []).append(p.original_input)

        return {
            'targets': asset_result['targets'],
            'target_stats': asset_result['target_stats'],
            'asset_stats': asset_result['asset_stats'],
            'errors': [
                {'line_number': p.line_number, 'input': p.original_input, 'error': p.error}
                for p in invalid_inputs
            ]
            'target_inputs_map': target_inputs_map,
            'errors': errors
        }

    def create_assets_from_parsed_inputs(
        self,
        self,
        parsed_inputs: List[ParsedInputDTO]
    ) -> Dict[str, Any]:
        """
        Create assets from parse results

        Args:
            parsed_inputs: list of parse results (valid inputs only)

        Returns:
            Creation result dict
        """
        # 1. Collect all target data (in memory, de-duplicated)
        targets_data = {}
        for dto in parsed_inputs:
            if dto.target_name not in targets_data:
                targets_data[dto.target_name] = {'name': dto.target_name, 'type': dto.target_type}

        """Create assets from parse results (valid inputs only)"""
        # 1. Collect and de-duplicate target data
        targets_data = {
            dto.target_name: {'name': dto.target_name, 'type': dto.target_type}
            for dto in parsed_inputs
        }
        targets_list = list(targets_data.values())

        # 2. Bulk-create targets (reusing the existing method)

        # 2. Bulk-create targets
        target_result = self.target_service.batch_create_targets(targets_list)

        # 3. Query the just-created targets and build the name → id mapping

        # 3. Build the name → id mapping
        target_names = [d['name'] for d in targets_list]
        targets = Target.objects.filter(name__in=target_names)
        target_id_map = {t.name: t.id for t in targets}

        # 4. Collect Website DTOs (in memory, de-duplicated)
        website_dtos = []
        seen_websites = set()
        for dto in parsed_inputs:
            if dto.website_url and dto.website_url not in seen_websites:
                seen_websites.add(dto.website_url)
                target_id = target_id_map.get(dto.target_name)
                if target_id:
                    website_dtos.append(WebSiteDTO(
                        target_id=target_id,
                        url=dto.website_url,
                        host=dto.target_name
                    ))

        # 5. Bulk-create websites (skip existing)
        websites_created = 0
        if website_dtos:
            websites_created = self.website_repo.bulk_create_ignore_conflicts(website_dtos)

        # 6. Collect Endpoint DTOs (in memory, de-duplicated)
        endpoint_dtos = []
        seen_endpoints = set()
        for dto in parsed_inputs:
            if dto.endpoint_url and dto.endpoint_url not in seen_endpoints:
                seen_endpoints.add(dto.endpoint_url)
                target_id = target_id_map.get(dto.target_name)
                if target_id:
                    endpoint_dtos.append(EndpointDTO(
                        target_id=target_id,
                        url=dto.endpoint_url,
                        host=dto.target_name
                    ))

        # 7. Bulk-create endpoints (skip existing)
        endpoints_created = 0
        if endpoint_dtos:
            endpoints_created = self.endpoint_repo.bulk_create_ignore_conflicts(endpoint_dtos)

        # 4. Bulk-create websites and endpoints
        websites_created = self._bulk_create_websites(parsed_inputs, target_id_map)
        endpoints_created = self._bulk_create_endpoints(parsed_inputs, target_id_map)

        return {
            'targets': list(targets),
            'target_stats': {
                'created': target_result['created_count'],
                'reused': 0,  # bulk_create cannot distinguish new from reused
                'reused': 0,
                'failed': target_result['failed_count']
            },
            'asset_stats': {
@@ -293,3 +201,53 @@ class QuickScanService:
                'endpoints_created': endpoints_created
            }
        }

    def _bulk_create_websites(
        self,
        parsed_inputs: List[ParsedInputDTO],
        target_id_map: Dict[str, int]
    ) -> int:
        """Bulk-create websites (skip existing)"""
        website_dtos = []
        seen = set()

        for dto in parsed_inputs:
            if not dto.website_url or dto.website_url in seen:
                continue
            seen.add(dto.website_url)
            target_id = target_id_map.get(dto.target_name)
            if target_id:
                website_dtos.append(WebSiteDTO(
                    target_id=target_id,
                    url=dto.website_url,
                    host=dto.target_name
                ))

        if not website_dtos:
            return 0
        return self.website_repo.bulk_create_ignore_conflicts(website_dtos)

    def _bulk_create_endpoints(
        self,
        parsed_inputs: List[ParsedInputDTO],
        target_id_map: Dict[str, int]
    ) -> int:
        """Bulk-create endpoints (skip existing)"""
        endpoint_dtos = []
        seen = set()

        for dto in parsed_inputs:
            if not dto.endpoint_url or dto.endpoint_url in seen:
                continue
            seen.add(dto.endpoint_url)
            target_id = target_id_map.get(dto.target_name)
            if target_id:
                endpoint_dtos.append(EndpointDTO(
                    target_id=target_id,
                    url=dto.endpoint_url,
                    host=dto.target_name
                ))

        if not endpoint_dtos:
            return 0
        return self.endpoint_repo.bulk_create_ignore_conflicts(endpoint_dtos)
@@ -283,7 +283,8 @@ class ScanCreationService:
        engine_ids: List[int],
        engine_names: List[str],
        yaml_configuration: str,
        scheduled_scan_name: str | None = None
        scheduled_scan_name: str | None = None,
        scan_mode: str = 'full'
    ) -> List[Scan]:
        """
        Batch-create scan tasks for multiple targets and dispatch them to workers asynchronously
@@ -294,6 +295,7 @@ class ScanCreationService:
            engine_names: list of engine names
            yaml_configuration: scan configuration in YAML format
            scheduled_scan_name: scheduled scan name (optional, used for notification display)
            scan_mode: scan mode, 'full' or 'quick' (default 'full')

        Returns:
            List of created Scan objects (returned immediately, without waiting for dispatch to finish)
@@ -316,6 +318,7 @@ class ScanCreationService:
                    results_dir=scan_workspace_dir,
                    status=ScanStatus.INITIATED,
                    container_ids=[],
                    scan_mode=scan_mode,
                )
                scans_to_create.append(scan)
            except (ValidationError, ValueError) as e:
@@ -392,13 +395,13 @@ class ScanCreationService:
        for data in scan_data:
            scan_id = data['scan_id']
            logger.info("-"*40)
            logger.info("Preparing to dispatch scan task - Scan ID: %s, Target: %s", scan_id, data['target_name'])
            logger.info("Preparing to dispatch scan task - Scan ID: %s, Target ID: %s", scan_id, data['target_id'])
            try:
                logger.info("Calling distributor.execute_scan_flow...")
                success, message, container_id, worker_id = distributor.execute_scan_flow(
                    scan_id=scan_id,
                    target_name=data['target_name'],
                    target_id=data['target_id'],
                    target_name=data['target_name'],
                    scan_workspace_dir=data['results_dir'],
                    engine_name=data['engine_name'],
                    scheduled_scan_name=data.get('scheduled_scan_name'),
54  backend/apps/scan/services/scan_input_target_service.py  Normal file
@@ -0,0 +1,54 @@
"""
Scan input target service

Provides write operations for ScanInputTarget.
"""

import logging
from typing import List

from apps.common.validators import detect_input_type
from apps.scan.models import ScanInputTarget

logger = logging.getLogger(__name__)


class ScanInputTargetService:
    """Scan input target service, responsible for bulk write operations."""

    BATCH_SIZE = 1000

    def bulk_create(self, scan_id: int, inputs: List[str]) -> int:
        """
        Bulk-create scan input targets

        Args:
            scan_id: scan task ID
            inputs: list of input strings

        Returns:
            Number of records created
        """
        if not inputs:
            return 0

        records = []
        for raw_input in inputs:
            value = raw_input.strip()
            if not value:
                continue
            try:
                records.append(ScanInputTarget(
                    scan_id=scan_id,
                    value=value,
                    input_type=detect_input_type(value)
                ))
            except ValueError as e:
                logger.warning("Skipping invalid input '%s': %s", value, e)

        if not records:
            return 0

        ScanInputTarget.objects.bulk_create(records, batch_size=self.BATCH_SIZE)
        logger.info("Bulk-created %d scan input targets (scan_id=%d)", len(records), scan_id)
        return len(records)
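The strip-blanks, skip-invalid, classify-then-batch pattern of `bulk_create` can be exercised without Django by stubbing the classifier; here `detect_input_type` is a toy stand-in for `apps.common.validators.detect_input_type`, and tuples stand in for `ScanInputTarget` rows:

```python
import ipaddress
from typing import List, Tuple


def detect_input_type(value: str) -> str:
    # Toy classifier: URL, CIDR, IP, else domain; raises ValueError on junk
    if value.startswith(("http://", "https://")):
        return "url"
    if "/" in value:
        ipaddress.ip_network(value, strict=False)
        return "cidr"
    try:
        ipaddress.ip_address(value)
        return "ip"
    except ValueError:
        pass
    if "." not in value:
        raise ValueError(f"unrecognized input: {value}")
    return "domain"


def build_records(scan_id: int, inputs: List[str]) -> List[Tuple[int, str, str]]:
    """Strip each line, skip blanks, and drop lines the classifier rejects."""
    records = []
    for raw in inputs:
        value = raw.strip()
        if not value:
            continue
        try:
            records.append((scan_id, value, detect_input_type(value)))
        except ValueError:
            continue  # the real service logs a warning and skips the line
    return records


rows = build_records(7, ["example.com", "  ", "10.0.0.0/24", "???", "https://a.io/x"])
print(rows)
```

The real service then hands the surviving records to `ScanInputTarget.objects.bulk_create(records, batch_size=1000)` in one round trip.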
@@ -1,25 +1,17 @@
"""
Scan task service

Owns all business logic for the Scan model
Owns all business logic for the Scan model and coordinates the sub-services
"""

from __future__ import annotations

import logging
import uuid
from typing import Dict, List, TYPE_CHECKING
from datetime import datetime
from pathlib import Path
from django.conf import settings
from django.db import transaction
from django.db.utils import DatabaseError, IntegrityError, OperationalError
from django.core.exceptions import ValidationError, ObjectDoesNotExist
from typing import Dict, List

from apps.scan.models import Scan
from apps.scan.repositories import DjangoScanRepository
from apps.targets.repositories import DjangoTargetRepository, DjangoOrganizationRepository
from apps.engine.repositories import DjangoEngineRepository
from apps.targets.models import Target
from apps.engine.models import ScanEngine
from apps.common.definitions import ScanStatus
@@ -30,115 +22,84 @@ logger = logging.getLogger(__name__)

class ScanService:
    """
    Scan task service (coordinator)

    Responsibilities:
    - Coordinate the sub-services
    - Provide a unified public interface
    - Stay backward compatible

    Note:
    - The concrete business logic has been split into sub-services
    - This class mainly delegates and coordinates

    Responsibilities: coordinate the sub-services and provide a unified public interface
    """

    # Terminal statuses: once set, they should not be overwritten
    FINAL_STATUSES = {
        ScanStatus.COMPLETED,
        ScanStatus.FAILED,
        ScanStatus.CANCELLED
    }

    def __init__(self):
        """
        Initialize the service
        """
        # Initialize the sub-services
        from apps.scan.services.scan_creation_service import ScanCreationService
        from apps.scan.services.scan_state_service import ScanStateService
        from apps.scan.services.scan_control_service import ScanControlService
        from apps.scan.services.scan_stats_service import ScanStatsService

        self.creation_service = ScanCreationService()
        self.state_service = ScanStateService()
        self.control_service = ScanControlService()
        self.stats_service = ScanStatsService()

        # Keep the ScanRepository (used by the get_scan method)
        self.scan_repo = DjangoScanRepository()

    def get_scan(self, scan_id: int, prefetch_relations: bool) -> Scan | None:
        """
        Get a scan task (with related objects)

        Preloads engine and target automatically to avoid N+1 query problems

        Args:
            scan_id: scan task ID

        Returns:
            Scan object (with engine and target) or None
        """
        """Get a scan task (with related objects)"""
        return self.scan_repo.get_by_id(scan_id, prefetch_relations)

    def get_all_scans(self, prefetch_relations: bool = True):
        """Get all scan tasks"""
        return self.scan_repo.get_all(prefetch_relations=prefetch_relations)

    def prepare_initiate_scan(
        self,
        organization_id: int | None = None,
        target_id: int | None = None,
        engine_id: int | None = None
    ) -> tuple[List[Target], ScanEngine]:
        """
        Prepare for creating a scan task, returning the required targets and scan engine
        """
        """Prepare for creating a scan task, returning the targets and scan engine"""
        return self.creation_service.prepare_initiate_scan(
            organization_id, target_id, engine_id
        )

    def prepare_initiate_scan_multi_engine(
        self,
        organization_id: int | None = None,
        target_id: int | None = None,
        engine_ids: List[int] | None = None
    ) -> tuple[List[Target], str, List[str], List[int]]:
        """
        Prepare for creating a multi-engine scan task

        Returns:
            (targets, merged configuration, engine names, engine IDs)
        """
        """Prepare for creating a multi-engine scan task"""
        return self.creation_service.prepare_initiate_scan_multi_engine(
            organization_id, target_id, engine_ids
        )

    def create_scans(
        self,
        targets: List[Target],
        engine_ids: List[int],
        engine_names: List[str],
        yaml_configuration: str,
        scheduled_scan_name: str | None = None
        scheduled_scan_name: str | None = None,
        scan_mode: str = 'full'
    ) -> List[Scan]:
        """Batch-create scan tasks (delegates to ScanCreationService)"""
        """Batch-create scan tasks"""
        return self.creation_service.create_scans(
            targets, engine_ids, engine_names, yaml_configuration, scheduled_scan_name
            targets, engine_ids, engine_names, yaml_configuration, scheduled_scan_name, scan_mode
        )

    # ==================== State management methods (delegated to ScanStateService) ====================

    # ==================== State management methods ====================

    def update_status(
        self,
        scan_id: int,
        status: ScanStatus,
        self,
        scan_id: int,
        status: ScanStatus,
        error_message: str | None = None,
        stopped_at: datetime | None = None
    ) -> bool:
        """Update the Scan status (delegates to ScanStateService)"""
        return self.state_service.update_status(
            scan_id, status, error_message, stopped_at
        )

        """Update the Scan status"""
        return self.state_service.update_status(scan_id, status, error_message, stopped_at)

    def update_status_if_match(
        self,
        scan_id: int,
@@ -146,113 +107,56 @@ class ScanService:
        new_status: ScanStatus,
        stopped_at: datetime | None = None
    ) -> bool:
        """Conditionally update the Scan status (delegates to ScanStateService)"""
        """Conditionally update the Scan status"""
        return self.state_service.update_status_if_match(
            scan_id, current_status, new_status, stopped_at
        )
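`update_status_if_match` is a compare-and-set: the status changes only when the stored value still equals the caller's expectation, which (together with the `FINAL_STATUSES` set above) keeps a late worker from overwriting a terminal status. A sketch of the idea over a plain dict store instead of the Django ORM (with the ORM the core would be a filtered update such as `Scan.objects.filter(id=scan_id, status=current).update(status=new)`):

```python
# In-memory stand-in for the Scan table; keys are scan IDs.
SCANS = {1: {"status": "running"}}


def update_status_if_match(scan_id: int, current: str, new: str) -> bool:
    """Transition scan_id from `current` to `new`; no-op if the state moved on."""
    scan = SCANS.get(scan_id)
    # Only transition when the stored status still matches the expectation
    if scan is None or scan["status"] != current:
        return False
    scan["status"] = new
    return True


print(update_status_if_match(1, "running", "completed"))  # expectation holds
print(update_status_if_match(1, "running", "cancelled"))  # stale expectation, no-op
```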
    def update_cached_stats(self, scan_id: int) -> dict | None:
        """Update cached statistics (delegates to ScanStateService), returning the stats dict"""
        """Update cached statistics, returning the stats dict"""
        return self.state_service.update_cached_stats(scan_id)

    # ==================== Progress tracking methods (delegated to ScanStateService) ====================

    # ==================== Progress tracking methods ====================

    def init_stage_progress(self, scan_id: int, stages: list[str]) -> bool:
        """Initialize stage progress (delegates to ScanStateService)"""
        """Initialize stage progress"""
        return self.state_service.init_stage_progress(scan_id, stages)

    def start_stage(self, scan_id: int, stage: str) -> bool:
        """Start a stage (delegates to ScanStateService)"""
        """Start a stage"""
        return self.state_service.start_stage(scan_id, stage)

    def complete_stage(self, scan_id: int, stage: str, detail: str | None = None) -> bool:
        """Complete a stage (delegates to ScanStateService)"""
        """Complete a stage"""
        return self.state_service.complete_stage(scan_id, stage, detail)

    def fail_stage(self, scan_id: int, stage: str, error: str | None = None) -> bool:
        """Mark a stage as failed (delegates to ScanStateService)"""
        """Mark a stage as failed"""
        return self.state_service.fail_stage(scan_id, stage, error)

    def cancel_running_stages(self, scan_id: int, final_status: str = "cancelled") -> bool:
        """Cancel all running stages (delegates to ScanStateService)"""
        """Cancel all running stages"""
        return self.state_service.cancel_running_stages(scan_id, final_status)

    # TODO: not wired in yet
    def add_command_to_scan(self, scan_id: int, stage_name: str, tool_name: str, command: str) -> bool:
        """
        Incrementally append a command to the given scan stage

        Args:
            scan_id: scan task ID
            stage_name: stage name (e.g. 'subdomain_discovery', 'port_scan')
            tool_name: tool name
            command: command executed

        Returns:
            bool: whether the command was added successfully
        """
        try:
            scan = self.get_scan(scan_id, prefetch_relations=False)
            if not scan:
                logger.error(f"Scan task does not exist: {scan_id}")
                return False

            stage_progress = scan.stage_progress or {}

            # Make sure the given stage exists
            if stage_name not in stage_progress:
                stage_progress[stage_name] = {'status': 'running', 'commands': []}

            # Make sure the commands list exists
            if 'commands' not in stage_progress[stage_name]:
                stage_progress[stage_name]['commands'] = []

            # Append the command incrementally
            command_entry = f"{tool_name}: {command}"
            stage_progress[stage_name]['commands'].append(command_entry)

            scan.stage_progress = stage_progress
            scan.save(update_fields=['stage_progress'])

            command_count = len(stage_progress[stage_name]['commands'])
            logger.info(f"✓ Recorded command: {stage_name}.{tool_name} (total: {command_count})")
            return True

        except Exception as e:
            logger.error(f"Failed to record command: {e}")
            return False

    # ==================== Deletion and control methods (delegated to ScanControlService) ====================

    # ==================== Deletion and control methods ====================

    def delete_scans_two_phase(self, scan_ids: List[int]) -> dict:
        """Two-phase deletion of scan tasks (delegates to ScanControlService)"""
        """Two-phase deletion of scan tasks"""
        return self.control_service.delete_scans_two_phase(scan_ids)

    def stop_scan(self, scan_id: int) -> tuple[bool, int]:
        """Stop a scan task (delegates to ScanControlService)"""
        """Stop a scan task"""
        return self.control_service.stop_scan(scan_id)
|
||||
|
||||
|
||||
def hard_delete_scans(self, scan_ids: List[int]) -> tuple[int, Dict[str, int]]:
|
||||
"""
|
||||
硬删除扫描任务(真正删除数据)
|
||||
|
||||
用于 Worker 容器中执行,删除已软删除的扫描及其关联数据。
|
||||
|
||||
Args:
|
||||
scan_ids: 扫描任务 ID 列表
|
||||
|
||||
Returns:
|
||||
(删除数量, 详情字典)
|
||||
"""
|
||||
"""硬删除扫描任务(真正删除数据)"""
|
||||
return self.scan_repo.hard_delete_by_ids(scan_ids)
|
||||
|
||||
# ==================== 统计方法(委托给 ScanStatsService) ====================
|
||||
|
||||
|
||||
# ==================== 统计方法 ====================
|
||||
|
||||
def get_statistics(self) -> dict:
|
||||
"""获取扫描统计数据(委托给 ScanStatsService)"""
|
||||
"""获取扫描统计数据"""
|
||||
return self.stats_service.get_statistics()
|
||||
|
||||
|
||||
|
||||
# 导出接口
|
||||
__all__ = ['ScanService']
|
||||
|
||||
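The hunks above all follow one pattern: `ScanService` is a facade whose public methods are one-line delegations to focused sub-services (`state_service`, `control_service`, `stats_service`), and the commit trims the "(delegated to …)" notes from the docstrings. A minimal sketch of that facade shape, with illustrative class names that are not the project's actual API:

```python
class StateService:
    """Focused sub-service owning status state (illustrative)."""

    def __init__(self):
        self.status = {}

    def update_status(self, scan_id, status):
        self.status[scan_id] = status
        return True

    def update_status_if_match(self, scan_id, current_status, new_status):
        # Conditional update: only transition when the current status matches
        if self.status.get(scan_id) != current_status:
            return False
        self.status[scan_id] = new_status
        return True


class ScanFacade:
    """Facade: a stable public API whose methods delegate to sub-services."""

    def __init__(self):
        self.state_service = StateService()

    def update_status(self, scan_id, status):
        return self.state_service.update_status(scan_id, status)

    def update_status_if_match(self, scan_id, current_status, new_status):
        return self.state_service.update_status_if_match(scan_id, current_status, new_status)
```

Callers keep importing one class while the state logic stays testable in isolation, which is why the docstrings no longer need to spell out the delegation.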
@@ -1,613 +0,0 @@
"""
Target export service

Provides unified target extraction and file export, supporting:
- URL export (pure export, no implicit fallback)
- Default URL generation (standalone method)
- URL export with a fallback chain (orchestrated at the use-case layer)
- Domain/IP export (for port scanning)
- Blacklist filtering integration
"""

import ipaddress
import logging
from pathlib import Path
from typing import Dict, Any, Optional, List, Iterator, Tuple

from django.db.models import QuerySet

from apps.common.utils import BlacklistFilter

logger = logging.getLogger(__name__)


class DataSource:
    """Data source type constants"""
    ENDPOINT = "endpoint"
    WEBSITE = "website"
    HOST_PORT = "host_port"
    DEFAULT = "default"


def create_export_service(target_id: int) -> 'TargetExportService':
    """
    Factory function: create an export service with blacklist filtering

    Args:
        target_id: target ID, used to load blacklist rules

    Returns:
        TargetExportService: export service instance with the blacklist filter configured
    """
    from apps.common.services import BlacklistService

    rules = BlacklistService().get_rules(target_id)
    blacklist_filter = BlacklistFilter(rules)
    return TargetExportService(blacklist_filter=blacklist_filter)


def _iter_default_urls_from_target(
    target_id: int,
    blacklist_filter: Optional[BlacklistFilter] = None
) -> Iterator[str]:
    """
    Internal generator: produce default URLs from the Target itself

    URLs are generated by Target type:
    - DOMAIN: http(s)://domain
    - IP: http(s)://ip
    - CIDR: expanded to http(s)://ip for every IP
    - URL: the target URL as-is

    Args:
        target_id: target ID
        blacklist_filter: blacklist filter

    Yields:
        str: URL
    """
    from apps.targets.services import TargetService
    from apps.targets.models import Target

    target_service = TargetService()
    target = target_service.get_target(target_id)

    if not target:
        logger.warning("Target ID %d does not exist; cannot generate default URLs", target_id)
        return

    target_name = target.name
    target_type = target.type

    # Generate URLs by Target type
    if target_type == Target.TargetType.DOMAIN:
        urls = [f"http://{target_name}", f"https://{target_name}"]
    elif target_type == Target.TargetType.IP:
        urls = [f"http://{target_name}", f"https://{target_name}"]
    elif target_type == Target.TargetType.CIDR:
        try:
            network = ipaddress.ip_network(target_name, strict=False)
            urls = []
            for ip in network.hosts():
                urls.extend([f"http://{ip}", f"https://{ip}"])
            # Special case for /32 or /128
            if not urls:
                ip = str(network.network_address)
                urls = [f"http://{ip}", f"https://{ip}"]
        except ValueError as e:
            logger.error("Failed to parse CIDR: %s - %s", target_name, e)
            return
    elif target_type == Target.TargetType.URL:
        urls = [target_name]
    else:
        logger.warning("Unsupported Target type: %s", target_type)
        return

    # Filter and yield
    for url in urls:
        if blacklist_filter and not blacklist_filter.is_allowed(url):
            continue
        yield url


def _iter_urls_with_fallback(
    target_id: int,
    sources: List[str],
    blacklist_filter: Optional[BlacklistFilter] = None,
    batch_size: int = 1000,
    tried_sources: Optional[List[str]] = None
) -> Iterator[Tuple[str, str]]:
    """
    Internal generator: stream URLs (with a fallback chain)

    Tries each data source in `sources` order until one returns data.

    Fallback logic:
    - Source has data that passes the filter → yield URLs, stop falling back
    - Source has data but all of it is filtered → do not fall back, stop (avoids accidental exposure)
    - Source is empty → try the next one

    Args:
        target_id: target ID
        sources: data source priority list
        blacklist_filter: blacklist filter
        batch_size: batch size
        tried_sources: optional list recording tried sources (caller-supplied, mutated)

    Yields:
        Tuple[str, str]: (url, source) - the URL and its source tag
    """
    from apps.asset.models import Endpoint, WebSite

    for source in sources:
        if tried_sources is not None:
            tried_sources.append(source)

        has_output = False    # whether anything was yielded (passed the filter)
        has_raw_data = False  # whether there was raw data (before filtering)

        if source == DataSource.DEFAULT:
            # Default URL generation (built from the Target itself; reuses the shared generator)
            for url in _iter_default_urls_from_target(target_id, blacklist_filter):
                has_raw_data = True
                has_output = True
                yield url, source

            # Check for raw data separately (the generator may be empty after filtering)
            if not has_raw_data:
                # Re-check whether the Target exists
                from apps.targets.services import TargetService
                target = TargetService().get_target(target_id)
                has_raw_data = target is not None

            if has_raw_data:
                if not has_output:
                    logger.info("%s has data but everything was blacklisted; not falling back", source)
                return
            continue

        # Build the queryset for the source
        if source == DataSource.ENDPOINT:
            queryset = Endpoint.objects.filter(target_id=target_id).values_list('url', flat=True)
        elif source == DataSource.WEBSITE:
            queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)
        else:
            logger.warning("Unknown data source type: %s, skipping", source)
            continue

        for url in queryset.iterator(chunk_size=batch_size):
            if url:
                has_raw_data = True
                if blacklist_filter and not blacklist_filter.is_allowed(url):
                    continue
                has_output = True
                yield url, source

        # Stop once there is raw data (filtered or not)
        if has_raw_data:
            if not has_output:
                logger.info("%s has data but everything was blacklisted; not falling back", source)
            return

        logger.info("%s is empty; trying the next data source", source)


def get_urls_with_fallback(
    target_id: int,
    sources: List[str],
    batch_size: int = 1000
) -> Dict[str, Any]:
    """
    Use-case function: fetch URLs with a fallback chain (returns a list)

    Tries each data source in `sources` order until one returns data.

    Args:
        target_id: target ID
        sources: data source priority list, e.g. ["website", "endpoint", "default"]
        batch_size: batch size

    Returns:
        dict: {
            'success': bool,
            'urls': List[str],
            'total_count': int,
            'source': str,              # data source actually used
            'tried_sources': List[str], # data sources that were tried
        }
    """
    from apps.common.services import BlacklistService

    rules = BlacklistService().get_rules(target_id)
    blacklist_filter = BlacklistFilter(rules)

    urls = []
    actual_source = 'none'
    tried_sources = []

    for url, source in _iter_urls_with_fallback(target_id, sources, blacklist_filter, batch_size, tried_sources):
        urls.append(url)
        actual_source = source

    if urls:
        logger.info("Fetched %d URLs from %s", len(urls), actual_source)
    else:
        logger.warning("All data sources are empty; no URLs to fetch")

    return {
        'success': True,
        'urls': urls,
        'total_count': len(urls),
        'source': actual_source,
        'tried_sources': tried_sources,
    }


def export_urls_with_fallback(
    target_id: int,
    output_file: str,
    sources: List[str],
    batch_size: int = 1000
) -> Dict[str, Any]:
    """
    Use-case function: export URLs with a fallback chain (writes to a file)

    Tries each data source in `sources` order until one returns data.
    Streams to disk with O(1) memory usage.

    Args:
        target_id: target ID
        output_file: output file path
        sources: data source priority list, e.g. ["endpoint", "website", "default"]
        batch_size: batch size

    Returns:
        dict: {
            'success': bool,
            'output_file': str,
            'total_count': int,
            'source': str,              # data source actually used
            'tried_sources': List[str], # data sources that were tried
        }
    """
    from apps.common.services import BlacklistService

    rules = BlacklistService().get_rules(target_id)
    blacklist_filter = BlacklistFilter(rules)

    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    total_count = 0
    actual_source = 'none'
    tried_sources = []

    with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
        for url, source in _iter_urls_with_fallback(target_id, sources, blacklist_filter, batch_size, tried_sources):
            f.write(f"{url}\n")
            total_count += 1
            actual_source = source

            if total_count % 10000 == 0:
                logger.info("Exported %d URLs...", total_count)

    if total_count > 0:
        logger.info("Exported %d URLs from %s to %s", total_count, actual_source, output_file)
    else:
        logger.warning("All data sources are empty; no URLs to export")

    return {
        'success': True,
        'output_file': str(output_path),
        'total_count': total_count,
        'source': actual_source,
        'tried_sources': tried_sources,
    }


class TargetExportService:
    """
    Target export service - unified target extraction and file export

    Usage:
        # Option 1: use-case functions (recommended)
        from apps.scan.services.target_export_service import export_urls_with_fallback, DataSource

        result = export_urls_with_fallback(
            target_id=1,
            output_file='/path/to/output.txt',
            sources=[DataSource.ENDPOINT, DataSource.WEBSITE, DataSource.DEFAULT]
        )

        # Option 2: use the service directly (pure export, no fallback)
        export_service = create_export_service(target_id)
        result = export_service.export_urls(target_id, output_path, queryset)
    """

    def __init__(self, blacklist_filter: Optional[BlacklistFilter] = None):
        """
        Initialize the export service

        Args:
            blacklist_filter: blacklist filter; None disables filtering
        """
        self.blacklist_filter = blacklist_filter

    def export_urls(
        self,
        target_id: int,
        output_path: str,
        queryset: QuerySet,
        url_field: str = 'url',
        batch_size: int = 1000
    ) -> Dict[str, Any]:
        """
        Pure URL export - only writes queryset data to a file

        No implicit fallback or default URL generation.

        Args:
            target_id: target ID
            output_path: output file path
            queryset: source queryset (built by the caller; should be values_list flat=True)
            url_field: URL field name (used for blacklist filtering)
            batch_size: batch size

        Returns:
            dict: {
                'success': bool,
                'output_file': str,
                'total_count': int,     # rows actually written
                'queryset_count': int,  # raw row count (iteration count)
                'filtered_count': int,  # rows dropped by the blacklist
            }

        Raises:
            IOError: file write failed
        """
        output_file = Path(output_path)
        output_file.parent.mkdir(parents=True, exist_ok=True)

        logger.info("Starting URL export - target_id=%s, output=%s", target_id, output_path)

        total_count = 0
        filtered_count = 0
        queryset_count = 0

        try:
            with open(output_file, 'w', encoding='utf-8', buffering=8192) as f:
                for url in queryset.iterator(chunk_size=batch_size):
                    queryset_count += 1
                    if url:
                        # Blacklist filtering
                        if self.blacklist_filter and not self.blacklist_filter.is_allowed(url):
                            filtered_count += 1
                            continue
                        f.write(f"{url}\n")
                        total_count += 1

                        if total_count % 10000 == 0:
                            logger.info("Exported %d URLs...", total_count)
        except IOError as e:
            logger.error("File write failed: %s - %s", output_path, e)
            raise

        if filtered_count > 0:
            logger.info("Blacklist filtering: dropped %d URLs", filtered_count)

        logger.info(
            "✓ URL export finished - written: %d, raw: %d, filtered: %d, file: %s",
            total_count, queryset_count, filtered_count, output_path
        )

        return {
            'success': True,
            'output_file': str(output_file),
            'total_count': total_count,
            'queryset_count': queryset_count,
            'filtered_count': filtered_count,
        }

    def generate_default_urls(
        self,
        target_id: int,
        output_path: str
    ) -> Dict[str, Any]:
        """
        Default URL generator

        Generates default URLs by Target type:
        - DOMAIN: http(s)://domain
        - IP: http(s)://ip
        - CIDR: expanded to http(s)://ip for every IP
        - URL: the target URL as-is

        Args:
            target_id: target ID
            output_path: output file path

        Returns:
            dict: {
                'success': bool,
                'output_file': str,
                'total_count': int,
            }
        """
        output_file = Path(output_path)
        output_file.parent.mkdir(parents=True, exist_ok=True)

        logger.info("Generating default URLs - target_id=%d", target_id)

        total_urls = 0

        with open(output_file, 'w', encoding='utf-8', buffering=8192) as f:
            for url in _iter_default_urls_from_target(target_id, self.blacklist_filter):
                f.write(f"{url}\n")
                total_urls += 1

                if total_urls % 10000 == 0:
                    logger.info("Generated %d URLs...", total_urls)

        logger.info("✓ Default URL generation finished - count: %d", total_urls)

        return {
            'success': True,
            'output_file': str(output_file),
            'total_count': total_urls,
        }

    def export_hosts(
        self,
        target_id: int,
        output_path: str,
        batch_size: int = 1000
    ) -> Dict[str, Any]:
        """
        Host list export (for port scanning)

        Export logic is chosen by Target type:
        - DOMAIN: stream subdomains from the Subdomain table
        - IP: write the IP address directly
        - CIDR: expand to all host IPs

        Args:
            target_id: target ID
            output_path: output file path
            batch_size: batch size

        Returns:
            dict: {
                'success': bool,
                'output_file': str,
                'total_count': int,
                'target_type': str
            }
        """
        from apps.targets.services import TargetService
        from apps.targets.models import Target

        output_file = Path(output_path)
        output_file.parent.mkdir(parents=True, exist_ok=True)

        # Fetch Target info
        target_service = TargetService()
        target = target_service.get_target(target_id)

        if not target:
            raise ValueError(f"Target ID {target_id} does not exist")

        target_type = target.type
        target_name = target.name

        logger.info(
            "Starting host list export - Target ID: %d, Name: %s, Type: %s, output file: %s",
            target_id, target_name, target_type, output_path
        )

        total_count = 0

        if target_type == Target.TargetType.DOMAIN:
            total_count = self._export_domains(target_id, target_name, output_file, batch_size)
            type_desc = "domain"

        elif target_type == Target.TargetType.IP:
            total_count = self._export_ip(target_name, output_file)
            type_desc = "IP"

        elif target_type == Target.TargetType.CIDR:
            total_count = self._export_cidr(target_name, output_file)
            type_desc = "CIDR IP"

        else:
            raise ValueError(f"Unsupported target type: {target_type}")

        logger.info(
            "✓ Host list export finished - type: %s, total: %d, file: %s",
            type_desc, total_count, output_path
        )

        return {
            'success': True,
            'output_file': str(output_file),
            'total_count': total_count,
            'target_type': target_type
        }

    def _export_domains(
        self,
        target_id: int,
        target_name: str,
        output_path: Path,
        batch_size: int
    ) -> int:
        """Export the root domain plus subdomains for a domain-type target"""
        from apps.asset.services.asset.subdomain_service import SubdomainService

        subdomain_service = SubdomainService()
        domain_iterator = subdomain_service.iter_subdomain_names_by_target(
            target_id=target_id,
            chunk_size=batch_size
        )

        total_count = 0
        written_domains = set()  # dedupe (the subdomain table may already contain the root domain)

        with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
            # 1. Write the root domain first
            if self._should_write_target(target_name):
                f.write(f"{target_name}\n")
                written_domains.add(target_name)
                total_count += 1

            # 2. Then write subdomains (skipping an already-written root domain)
            for domain_name in domain_iterator:
                if domain_name in written_domains:
                    continue
                if self._should_write_target(domain_name):
                    f.write(f"{domain_name}\n")
                    written_domains.add(domain_name)
                    total_count += 1

                    if total_count % 10000 == 0:
                        logger.info("Exported %d domains...", total_count)

        return total_count

    def _export_ip(self, target_name: str, output_path: Path) -> int:
        """Export an IP-type target"""
        if self._should_write_target(target_name):
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(f"{target_name}\n")
            return 1
        return 0

    def _export_cidr(self, target_name: str, output_path: Path) -> int:
        """Export a CIDR-type target, expanded to individual IPs"""
        network = ipaddress.ip_network(target_name, strict=False)
        total_count = 0

        with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
            for ip in network.hosts():
                ip_str = str(ip)
                if self._should_write_target(ip_str):
                    f.write(f"{ip_str}\n")
                    total_count += 1

                    if total_count % 10000 == 0:
                        logger.info("Exported %d IPs...", total_count)

        # Special case for /32 or /128
        if total_count == 0:
            ip_str = str(network.network_address)
            if self._should_write_target(ip_str):
                with open(output_path, 'w', encoding='utf-8') as f:
                    f.write(f"{ip_str}\n")
                total_count = 1

        return total_count

    def _should_write_target(self, target: str) -> bool:
        """Check whether a target should be written (passes the blacklist filter)"""
        if self.blacklist_filter:
            return self.blacklist_filter.is_allowed(target)
        return True
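The fallback semantics of the deleted `_iter_urls_with_fallback` above are the subtle part: the chain falls back only past *empty* sources, never past a source whose data was entirely blacklisted. That rule condenses to a small sketch, where `fetch` and `is_allowed` stand in for the queryset iteration and `BlacklistFilter.is_allowed`:

```python
def iter_with_fallback(sources, fetch, is_allowed):
    """Yield allowed rows from the first source that has any raw data."""
    for name in sources:
        rows = list(fetch(name))
        if not rows:
            continue  # empty source: fall back to the next one
        for row in rows:
            if is_allowed(row):
                yield row
        return  # source had raw data: stop, even if everything was filtered
```

Stopping on filtered-but-nonempty sources is deliberate: falling through to a default generator would re-expose exactly the hosts the blacklist was meant to exclude.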
@@ -18,7 +18,7 @@ from .subdomain_discovery import (

# Fingerprint detection tasks
from .fingerprint_detect import (
    export_urls_for_fingerprint_task,
    export_site_urls_for_fingerprint_task,
    run_xingfinger_and_stream_update_tech_task,
)

@@ -35,6 +35,6 @@ __all__ = [
    'merge_and_validate_task',
    'save_domains_task',
    # Fingerprint detection tasks
    'export_urls_for_fingerprint_task',
    'export_site_urls_for_fingerprint_task',
    'run_xingfinger_and_stream_update_tech_task',
]
@@ -1,21 +1,14 @@
"""
Task: export site URLs to a TXT file

-Supports two modes:
-1. Legacy mode (backward compatible): export from the database via target_id
-2. Provider mode: export from any data source via TargetProvider
+Exports URLs from any data source via TargetProvider (used for directory scanning).

-Data sources: WebSite.url → Default
+Data source: WebSite, falling back to default URLs when empty
"""
import logging
-from typing import Optional
from pathlib import Path
from prefect import task

-from apps.scan.services.target_export_service import (
-    export_urls_with_fallback,
-    DataSource,
-)
from apps.scan.providers import TargetProvider

logger = logging.getLogger(__name__)
@@ -23,94 +16,61 @@ logger = logging.getLogger(__name__)

@task(name="export_sites")
def export_sites_task(
-    target_id: Optional[int] = None,
-    output_file: str = "",
-    provider: Optional[TargetProvider] = None,
-    batch_size: int = 1000,
+    output_file: str,
+    provider: TargetProvider,
) -> dict:
    """
    Export all site URLs under a target to a TXT file

-    Supports two modes:
-    1. Legacy mode (backward compatible): pass target_id, export from the database
-    2. Provider mode: pass provider, export from any data source
-
-    Data source priority (fallback chain, legacy mode only):
-    1. WebSite table - site-level URLs
-    2. Default generation - http(s)://target_name built from the Target type
+    Data source: WebSite, falling back to default URLs when empty

    Args:
-        target_id: target ID (legacy mode, backward compatible)
        output_file: output file path (absolute)
-        provider: TargetProvider instance (new mode)
-        batch_size: read batch size, default 1000
+        provider: TargetProvider instance

    Returns:
        dict: {
            'success': bool,
            'output_file': str,
-            'total_count': int
+            'total_count': int,
+            'source': str,  # website | default
        }

    Raises:
-        ValueError: invalid arguments
-        IOError: file write failed
+        ValueError: provider not supplied
    """
-    # Argument validation: at least one must be given
-    if target_id is None and provider is None:
-        raise ValueError("Either target_id or provider must be supplied")
-
-    # Provider mode: export via TargetProvider
-    if provider is not None:
-        logger.info("Using provider mode - Provider: %s", type(provider).__name__)
-        return _export_with_provider(output_file, provider)
-
-    # Legacy mode: use export_urls_with_fallback
-    logger.info("Using legacy mode - Target ID: %d", target_id)
-    result = export_urls_with_fallback(
-        target_id=target_id,
-        output_file=output_file,
-        sources=[DataSource.WEBSITE, DataSource.DEFAULT],
-        batch_size=batch_size,
-    )
-
-    logger.info(
-        "Site URL export finished - source=%s, count=%d",
-        result['source'], result['total_count']
-    )
-
-    # Keep the return shape unchanged (backward compatible)
-    return {
-        'success': result['success'],
-        'output_file': result['output_file'],
-        'total_count': result['total_count'],
-    }
+    if provider is None:
+        raise ValueError("provider must be supplied")
+
+    logger.info("Exporting URLs - Provider: %s", type(provider).__name__)

-def _export_with_provider(output_file: str, provider: TargetProvider) -> dict:
-    """Export URLs via the provider"""
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

+    # Pick the data source by priority
+    urls = list(provider.iter_websites())
+    source = "website"
+
+    if not urls:
+        logger.info("WebSite is empty; generating default URLs")
+        urls = list(provider.iter_default_urls())
+        source = "default"
+
+    # Write to file
    total_count = 0
-    blacklist_filter = provider.get_blacklist_filter()
-
    with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
-        for url in provider.iter_urls():
-            # Apply the blacklist filter (if any)
-            if blacklist_filter and not blacklist_filter.is_allowed(url):
-                continue
-
+        for url in urls:
            f.write(f"{url}\n")
            total_count += 1

-            if total_count % 1000 == 0:
-                logger.info("Exported %d URLs...", total_count)
-
-    logger.info("✓ URL export finished - total: %d, file: %s", total_count, str(output_path))
-
+    logger.info(
+        "✓ URL export finished - source: %s, total: %d, file: %s",
+        source, total_count, str(output_path)
+    )

    return {
        'success': True,
        'output_file': str(output_path),
        'total_count': total_count,
+        'source': source,
    }
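The removed branch above called `blacklist_filter.is_allowed(url)` per URL. The real filter lives in `apps.common.utils.BlacklistFilter` and its internals are not shown in this diff; a hypothetical host-pattern implementation that matches those call sites might look like:

```python
from urllib.parse import urlparse
import fnmatch

class BlacklistFilter:
    """Hypothetical sketch: allow a URL unless its host matches a blocked pattern.
    (The project's actual BlacklistFilter in apps.common.utils may differ.)"""

    def __init__(self, rules):
        self.rules = rules  # e.g. ["*.internal.example.com", "10.0.0.1"]

    def is_allowed(self, url):
        # Compare against the hostname; fall back to the raw value for bare hosts/IPs
        host = urlparse(url).hostname or url
        return not any(fnmatch.fnmatch(host, rule) for rule in self.rules)
```

With glob-style rules, `*.internal.example.com` blocks subdomains of the internal zone while leaving the apex and public hosts untouched.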
@@ -2,14 +2,14 @@
Fingerprint detection task module

Contains:
- export_urls_for_fingerprint_task: export URLs to a file
- export_site_urls_for_fingerprint_task: export site URLs to a file
- run_xingfinger_and_stream_update_tech_task: run xingfinger in streaming mode and update tech
"""

from .export_urls_task import export_urls_for_fingerprint_task
from .export_site_urls_task import export_site_urls_for_fingerprint_task
from .run_xingfinger_task import run_xingfinger_and_stream_update_tech_task

__all__ = [
    'export_urls_for_fingerprint_task',
    'export_site_urls_for_fingerprint_task',
    'run_xingfinger_and_stream_update_tech_task',
]
@@ -0,0 +1,73 @@
"""
Task: export site URLs

Exports site URLs from any data source via TargetProvider (for fingerprint detection).

Data source: WebSite, falling back to default URLs when empty
"""

import logging
from pathlib import Path

from prefect import task

from apps.scan.providers import TargetProvider

logger = logging.getLogger(__name__)


@task(name="export_site_urls_for_fingerprint")
def export_site_urls_for_fingerprint_task(
    output_file: str,
    provider: TargetProvider,
) -> dict:
    """
    Export URLs under a target to a file (for fingerprint detection)

    Data source: WebSite, falling back to default URLs when empty

    Args:
        output_file: output file path
        provider: TargetProvider instance

    Returns:
        dict: {
            'output_file': str,
            'total_count': int,
            'source': str,  # website | default
        }
    """
    if provider is None:
        raise ValueError("provider must be supplied")

    logger.info("Exporting URLs - Provider: %s", type(provider).__name__)

    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Pick the data source by priority
    urls = list(provider.iter_websites())
    source = "website"

    if not urls:
        logger.info("WebSite is empty; generating default URLs")
        urls = list(provider.iter_default_urls())
        source = "default"

    # Write to file
    total_count = 0
    with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
        for url in urls:
            f.write(f"{url}\n")
            total_count += 1

    logger.info(
        "✓ URL export finished - source: %s, total: %d, file: %s",
        source, total_count, str(output_path)
    )

    return {
        'output_file': str(output_path),
        'total_count': total_count,
        'source': source,
    }
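The source-selection step of the new task above (WebSite first, default URLs only when WebSite is empty) can be exercised in isolation; `ListProvider` here is a hypothetical in-memory stand-in for `TargetProvider`, not the project's class:

```python
class ListProvider:
    """Hypothetical in-memory stand-in for TargetProvider."""

    def __init__(self, websites, default_urls):
        self._websites = websites
        self._defaults = default_urls

    def iter_websites(self):
        return iter(self._websites)

    def iter_default_urls(self):
        return iter(self._defaults)


def pick_urls(provider):
    """WebSite URLs first; default URLs only when WebSite is empty."""
    urls = list(provider.iter_websites())
    if urls:
        return urls, "website"
    return list(provider.iter_default_urls()), "default"
```

Because the task depends only on the two iterator methods, any object with that shape can drive the export, which is the point of moving from `target_id` lookups to a provider argument.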
@@ -1,112 +0,0 @@
"""
Task: export URLs

Supports two modes:
1. Legacy mode (backward compatible): export from the database via target_id
2. Provider mode: export from any data source via TargetProvider

Exports a target's URLs to a file before fingerprint detection
"""

import logging
from typing import Optional
from pathlib import Path

from prefect import task

from apps.scan.services.target_export_service import (
    export_urls_with_fallback,
    DataSource,
)
from apps.scan.providers import TargetProvider, DatabaseTargetProvider

logger = logging.getLogger(__name__)


@task(name="export_urls_for_fingerprint")
def export_urls_for_fingerprint_task(
    target_id: Optional[int] = None,
    output_file: str = "",
    source: str = 'website',  # kept for old callers (actual value decided by the fallback chain)
    provider: Optional[TargetProvider] = None,
    batch_size: int = 1000
) -> dict:
    """
    Export a target's URLs to a file (for fingerprint detection)

    Supports two modes:
    1. Legacy mode (backward compatible): pass target_id, export from the database
    2. Provider mode: pass provider, export from any data source

    Data source priority (fallback chain, legacy mode only):
    1. WebSite table - site-level URLs
    2. Default generation - http(s)://target_name built from the Target type

    Args:
        target_id: target ID (legacy mode, backward compatible)
        output_file: output file path
        source: data source type (kept for old callers; actual value decided by the fallback chain)
        provider: TargetProvider instance (new mode)
        batch_size: read batch size

    Returns:
        dict: {'output_file': str, 'total_count': int, 'source': str}
    """
    # Argument validation: at least one must be given
    if target_id is None and provider is None:
        raise ValueError("Either target_id or provider must be supplied")

    # Provider mode: export via TargetProvider
    if provider is not None:
        logger.info("Using provider mode - Provider: %s", type(provider).__name__)
        return _export_with_provider(output_file, provider)

    # Legacy mode: use export_urls_with_fallback
    logger.info("Using legacy mode - Target ID: %d", target_id)
    result = export_urls_with_fallback(
        target_id=target_id,
        output_file=output_file,
        sources=[DataSource.WEBSITE, DataSource.DEFAULT],
        batch_size=batch_size,
    )

    logger.info(
        "Fingerprint URL export finished - source=%s, count=%d",
        result['source'], result['total_count']
    )

    # Return the data source actually used (no longer fixed to "website")
    return {
        'output_file': result['output_file'],
        'total_count': result['total_count'],
        'source': result['source'],
    }


def _export_with_provider(output_file: str, provider: TargetProvider) -> dict:
    """Export URLs via the provider"""
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    total_count = 0
    blacklist_filter = provider.get_blacklist_filter()

    with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
        for url in provider.iter_urls():
            # Apply the blacklist filter (if any)
            if blacklist_filter and not blacklist_filter.is_allowed(url):
                continue

            f.write(f"{url}\n")
            total_count += 1

            if total_count % 1000 == 0:
                logger.info("Exported %d URLs...", total_count)

    logger.info("✓ URL export finished - total: %d, file: %s", total_count, str(output_path))

    return {
        'output_file': str(output_path),
        'total_count': total_count,
        'source': 'provider',
    }
@@ -1,22 +1,14 @@
"""
Task for exporting the host list to a TXT file

Supports two modes:
1. Legacy mode (backward compatible): export from the database using target_id
2. Provider mode: export from any data source using a TargetProvider

What gets exported depends on the Target type:
- DOMAIN: export subdomains from the Subdomain table
- IP: write target.name directly
- CIDR: expand every IP in the CIDR range
Exports the host list from any data source via a TargetProvider.
"""
import logging
from pathlib import Path
from typing import Optional

from prefect import task

from apps.scan.providers import DatabaseTargetProvider, TargetProvider
from apps.scan.providers import TargetProvider

logger = logging.getLogger(__name__)

@@ -24,76 +16,56 @@ logger = logging.getLogger(__name__)
@task(name="export_hosts")
def export_hosts_task(
    output_file: str,
    target_id: Optional[int] = None,
    provider: Optional[TargetProvider] = None,
    provider: TargetProvider,
) -> dict:
    """
    Export the host list to a TXT file

    Supports two modes:
    1. Legacy mode (backward compatible): pass target_id and export from the database
    2. Provider mode: pass provider and export from any data source

    What gets exported is decided automatically by the Target type:
    - DOMAIN: export subdomains from the Subdomain table (streamed; handles 100k+ domains)
    - IP: write target.name directly (a single IP)
    - CIDR: expand every usable IP in the CIDR range
    Explicitly combines iter_target_hosts() + iter_subdomains().

    Args:
        output_file: output file path (absolute)
        target_id: target ID (legacy mode, backward compatible)
        provider: TargetProvider instance (new mode)
        provider: TargetProvider instance

    Returns:
        dict: {
            'success': bool,
            'output_file': str,
            'total_count': int,
            'target_type': str  # returned in legacy mode only
        }

    Raises:
        ValueError: bad arguments (neither target_id nor provider given)
        ValueError: provider not given
        IOError: file write failed
    """
    if target_id is None and provider is None:
        raise ValueError("Either target_id or provider must be provided")
    if provider is None:
        raise ValueError("The provider argument is required")

    # Backward compatibility: without a provider, build a DatabaseTargetProvider from target_id
    use_legacy_mode = provider is None
    if use_legacy_mode:
        logger.info("Legacy mode - Target ID: %d", target_id)
        provider = DatabaseTargetProvider(target_id=target_id)
    else:
        logger.info("Provider mode - Provider: %s", type(provider).__name__)
    logger.info("Exporting host list - Provider: %s", type(provider).__name__)

    # Make sure the output directory exists
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Export hosts via the provider (iter_hosts already applies blacklist filtering)
    total_count = 0

    with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
        for host in provider.iter_hosts():
        # 1. Export Target hosts (CIDR auto-expanded, blacklist already applied)
        for host in provider.iter_target_hosts():
            f.write(f"{host}\n")
            total_count += 1

        # 2. Export subdomains (the provider filters the blacklist internally)
        for subdomain in provider.iter_subdomains():
            f.write(f"{subdomain}\n")
            total_count += 1

            if total_count % 1000 == 0:
                logger.info("Exported %d hosts...", total_count)

    logger.info("✓ Host list export complete - total: %d, file: %s", total_count, str(output_path))

    result = {
    return {
        'success': True,
        'output_file': str(output_path),
        'total_count': total_count,
    }

    # Legacy mode: keep the return shape unchanged (backward compatible)
    if use_legacy_mode:
        from apps.targets.services import TargetService
        target = TargetService().get_target(target_id)
        result['target_type'] = target.type if target else 'unknown'

    return result
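The export_hosts change above replaces a single iterator call with an explicit combination of two iterators streamed into one file. A minimal runnable sketch of that pattern, using a hypothetical `StubProvider` in place of the real `TargetProvider` (the stub's data is invented for illustration):

```python
import tempfile
from pathlib import Path

class StubProvider:
    """Hypothetical stand-in for TargetProvider; not the project's real class."""
    def iter_target_hosts(self):
        yield from ["192.0.2.1", "192.0.2.2"]
    def iter_subdomains(self):
        yield from ["a.example.com", "b.example.com"]

def export_hosts(output_file: str, provider) -> int:
    """Stream both iterators into one file and count the lines written."""
    path = Path(output_file)
    path.parent.mkdir(parents=True, exist_ok=True)
    total = 0
    with open(path, "w", encoding="utf-8") as f:
        for host in provider.iter_target_hosts():   # 1. target hosts
            f.write(f"{host}\n")
            total += 1
        for sub in provider.iter_subdomains():      # 2. subdomains
            f.write(f"{sub}\n")
            total += 1
    return total

out = str(Path(tempfile.mkdtemp()) / "hosts.txt")
count = export_hosts(out, StubProvider())
print(count)  # → 4
```

Because both loops share one open file handle and one counter, the task keeps a single progress log and a single total, exactly as the merged task does.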
@@ -1,208 +1,76 @@
"""
Task for exporting site URLs to a file

Supports two modes:
1. Legacy mode (backward compatible): export from the database using target_id
2. Provider mode: export from any data source using a TargetProvider
Exports URLs from any data source via a TargetProvider (for httpx site probing).

Special-case logic:
- Port 80: only the HTTP URL (port number omitted)
- Port 443: only the HTTPS URL (port number omitted)
- Other ports: both HTTP and HTTPS URLs (with the port number)
Data source: HostPortMapping, falling back to default URLs when empty
"""
import logging
from typing import Optional
from pathlib import Path
from prefect import task

from apps.asset.services import HostPortMappingService
from apps.scan.services.target_export_service import create_export_service
from apps.common.services import BlacklistService
from apps.common.utils import BlacklistFilter
from apps.scan.providers import TargetProvider, DatabaseTargetProvider, ProviderContext
from apps.scan.providers import TargetProvider

logger = logging.getLogger(__name__)


def _generate_urls_from_port(host: str, port: int) -> list[str]:
    """
    Generate the URL list for a port

    - Port 80: only the HTTP URL (port number omitted)
    - Port 443: only the HTTPS URL (port number omitted)
    - Other ports: both HTTP and HTTPS URLs (with the port number)
    """
    if port == 80:
        return [f"http://{host}"]
    elif port == 443:
        return [f"https://{host}"]
    else:
        return [f"http://{host}:{port}", f"https://{host}:{port}"]


@task(name="export_site_urls")
def export_site_urls_task(
    output_file: str,
    target_id: Optional[int] = None,
    provider: Optional[TargetProvider] = None,
    batch_size: int = 1000
    provider: TargetProvider,
) -> dict:
    """
    Export every site URL under the target to a file

    Supports two modes:
    1. Legacy mode (backward compatible): pass target_id, export from the HostPortMapping table
    2. Provider mode: pass provider, export from any data source

    Legacy-mode special-case logic:
    - Port 80: only the HTTP URL (port number omitted)
    - Port 443: only the HTTPS URL (port number omitted)
    - Other ports: both HTTP and HTTPS URLs (with the port number)

    Fallback logic (legacy mode only):
    - When HostPortMapping is empty, use generate_default_urls() to build default URLs


    Data source: HostPortMapping, falling back to default URLs when empty

    Args:
        output_file: output file path (absolute)
        target_id: target ID (legacy mode, backward compatible)
        provider: TargetProvider instance (new mode)
        batch_size: batch size per round of processing

        provider: TargetProvider instance

    Returns:
        dict: {
            'success': bool,
            'output_file': str,
            'total_urls': int,
            'association_count': int,  # number of host-port associations (legacy mode only)
            'source': str,  # data source: "host_port" | "default" | "provider"
            'source': str,  # host_port | default
        }


    Raises:
        ValueError: bad arguments
        IOError: file write failed
        ValueError: provider not given
    """
    # Argument validation: at least one required
    if target_id is None and provider is None:
        raise ValueError("Either target_id or provider must be provided")

    # Backward compatibility: without a provider, use legacy mode
    if provider is None:
        logger.info("Legacy mode - Target ID: %d, output file: %s", target_id, output_file)
        return _export_site_urls_legacy(target_id, output_file, batch_size)

    # Provider mode
    logger.info("Provider mode - Provider: %s, output file: %s", type(provider).__name__, output_file)

    # Make sure the output directory exists
    raise ValueError("The provider argument is required")

    logger.info("Exporting URLs - Provider: %s", type(provider).__name__)

    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Export the URL list via the provider

    # Fetch data sources in priority order
    urls = list(provider.iter_host_port_urls())
    source = "host_port"

    if not urls:
        logger.info("HostPortMapping is empty, generating default URLs")
        urls = list(provider.iter_default_urls())
        source = "default"

    # Write to file
    total_urls = 0
    blacklist_filter = provider.get_blacklist_filter()

    with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
        for url in provider.iter_urls():
            # Apply the blacklist filter (if any)
            if blacklist_filter and not blacklist_filter.is_allowed(url):
                continue

        for url in urls:
            f.write(f"{url}\n")
            total_urls += 1

            if total_urls % 1000 == 0:
                logger.info("Exported %d URLs...", total_urls)

    logger.info("✓ URL export complete - total: %d, file: %s", total_urls, str(output_path))

    return {
        'success': True,
        'output_file': str(output_path),
        'total_urls': total_urls,
        'source': 'provider',
    }


def _export_site_urls_legacy(target_id: int, output_file: str, batch_size: int) -> dict:
    """
    Legacy mode: export URLs from the HostPortMapping table

    Keeps the original logic unchanged for backward compatibility
    """
    logger.info("Collecting site URLs - Target ID: %d, output file: %s", target_id, output_file)

    # Make sure the output directory exists
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Fetch the rules and build the filter
    blacklist_filter = BlacklistFilter(BlacklistService().get_rules(target_id))

    # Query the HostPortMapping table directly, sorted by host
    service = HostPortMappingService()
    associations = service.iter_host_port_by_target(
        target_id=target_id,
        batch_size=batch_size,
    )

    total_urls = 0
    association_count = 0
    filtered_count = 0

    # Stream to file (special-case port logic)
    with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
        for assoc in associations:
            association_count += 1
            host = assoc['host']
            port = assoc['port']

            # Check the host first; only generate URLs once it passes
            if not blacklist_filter.is_allowed(host):
                filtered_count += 1
                continue

            # Generate URLs from the port number
            for url in _generate_urls_from_port(host, port):
                f.write(f"{url}\n")
                total_urls += 1

            if association_count % 1000 == 0:
                logger.info("Processed %d associations, generated %d URLs...", association_count, total_urls)

    if filtered_count > 0:
        logger.info("Blacklist filtering: %d associations filtered", filtered_count)

    logger.info(
        "✓ Site URL export complete - associations: %d, total URLs: %d, file: %s",
        association_count, total_urls, str(output_path)
        "✓ URL export complete - source: %s, total: %d, file: %s",
        source, total_urls, str(output_path)
    )

    # Decide the data source
    source = "host_port"

    # Data exists but was entirely filtered: do not fall back
    if association_count > 0 and total_urls == 0:
        logger.info("HostPortMapping has %d rows but all were blacklist-filtered; not falling back", association_count)
        return {
            'success': True,
            'output_file': str(output_path),
            'total_urls': 0,
            'association_count': association_count,
            'source': source,
        }

    # Data source empty: fall back to default URL generation
    if total_urls == 0:
        logger.info("HostPortMapping is empty, using default URL generation")
        export_service = create_export_service(target_id)
        result = export_service.generate_default_urls(target_id, str(output_path))
        total_urls = result['total_count']
        source = "default"


    return {
        'success': True,
        'output_file': str(output_path),
        'total_urls': total_urls,
        'association_count': association_count,
        'source': source,
    }
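The port-to-URL rules in `_generate_urls_from_port` can be checked in isolation; this sketch copies the rule set from the diff above (80 → HTTP only, 443 → HTTPS only, everything else → both schemes with an explicit port):

```python
def generate_urls_from_port(host: str, port: int) -> list[str]:
    """Mirror of the task's port rules: well-known ports drop the port suffix."""
    if port == 80:
        return [f"http://{host}"]
    if port == 443:
        return [f"https://{host}"]
    return [f"http://{host}:{port}", f"https://{host}:{port}"]

print(generate_urls_from_port("example.com", 80))    # → ['http://example.com']
print(generate_urls_from_port("example.com", 443))   # → ['https://example.com']
print(generate_urls_from_port("example.com", 8443))  # → ['http://example.com:8443', 'https://example.com:8443']
```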
@@ -119,7 +119,8 @@ def merge_and_validate_task(result_files: List[str], result_dir: str) -> str:
        unique_count = sum(1 for _ in f)

    if unique_count == 0:
        raise RuntimeError("No valid domains found")
        logger.warning("No valid domains found, returning an empty file")
        # Don't raise; return the empty file so downstream steps proceed normally

    file_size_kb = merged_file.stat().st_size / 1024
    logger.info("✓ Merge and dedup complete - after dedup: %d domains, file size: %.2f KB", unique_count, file_size_kb)
@@ -1,23 +1,16 @@
"""
Task for exporting the site URL list

Supports two modes:
1. Legacy mode (backward compatible): export from the database using target_id
2. Provider mode: export from any data source using a TargetProvider
Exports URLs from any data source via a TargetProvider (for crawlers like katana).

Data source: WebSite.url → Default (for crawlers like katana)
Data source: WebSite, falling back to default URLs when empty
"""

import logging
from typing import Optional
from pathlib import Path
from prefect import task

from apps.scan.services.target_export_service import (
    export_urls_with_fallback,
    DataSource,
)
from apps.scan.providers import TargetProvider, DatabaseTargetProvider
from apps.scan.providers import TargetProvider

logger = logging.getLogger(__name__)

@@ -29,92 +22,58 @@ logger = logging.getLogger(__name__)
)
def export_sites_task(
    output_file: str,
    target_id: Optional[int] = None,
    scan_id: Optional[int] = None,
    provider: Optional[TargetProvider] = None,
    batch_size: int = 1000
    provider: TargetProvider,
) -> dict:
    """
    Export the site URL list to a file (for crawlers like katana)

    Supports two modes:
    1. Legacy mode (backward compatible): pass target_id, export from the database
    2. Provider mode: pass provider, export from any data source

    Data source priority (fallback chain, legacy mode only):
    1. WebSite table - site-level URLs
    2. Default generation - http(s)://target_name based on the Target type


    Data source: WebSite, falling back to default URLs when empty

    Args:
        output_file: output file path
        target_id: target ID (legacy mode, backward compatible)
        scan_id: scan ID (kept for old callers)
        provider: TargetProvider instance (new mode)
        batch_size: batch size (memory optimization)

        provider: TargetProvider instance

    Returns:
        dict: {
            'output_file': str,  # output file path
            'asset_count': int,  # asset count
            'output_file': str,
            'asset_count': int,
            'source': str,  # website | default
        }


    Raises:
        ValueError: bad arguments
        RuntimeError: execution failed
        ValueError: provider not given
    """
    # Argument validation: at least one required
    if target_id is None and provider is None:
        raise ValueError("Either target_id or provider must be provided")

    # Provider mode: export via TargetProvider
    if provider is not None:
        logger.info("Provider mode - Provider: %s", type(provider).__name__)
        return _export_with_provider(output_file, provider)

    # Legacy mode: use export_urls_with_fallback
    logger.info("Legacy mode - Target ID: %d", target_id)
    result = export_urls_with_fallback(
        target_id=target_id,
        output_file=output_file,
        sources=[DataSource.WEBSITE, DataSource.DEFAULT],
        batch_size=batch_size,
    )

    logger.info(
        "Site URL export complete - source=%s, count=%d",
        result['source'], result['total_count']
    )

    # Keep the return shape unchanged (backward compatible)
    return {
        'output_file': result['output_file'],
        'asset_count': result['total_count'],
    }
    if provider is None:
        raise ValueError("The provider argument is required")

    logger.info("Exporting URLs - Provider: %s", type(provider).__name__)


def _export_with_provider(output_file: str, provider: TargetProvider) -> dict:
    """Export URLs using the provider"""
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)


    # Fetch data sources in priority order
    urls = list(provider.iter_websites())
    source = "website"

    if not urls:
        logger.info("WebSite is empty, generating default URLs")
        urls = list(provider.iter_default_urls())
        source = "default"

    # Write to file
    total_count = 0
    blacklist_filter = provider.get_blacklist_filter()

    with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
        for url in provider.iter_urls():
            # Apply the blacklist filter (if any)
            if blacklist_filter and not blacklist_filter.is_allowed(url):
                continue

        for url in urls:
            f.write(f"{url}\n")
            total_count += 1

            if total_count % 1000 == 0:
                logger.info("Exported %d URLs...", total_count)

    logger.info("✓ URL export complete - total: %d, file: %s", total_count, str(output_path))


    logger.info(
        "✓ URL export complete - source: %s, total: %d, file: %s",
        source, total_count, str(output_path)
    )

    return {
        'output_file': str(output_path),
        'asset_count': total_count,
        'source': source,
    }
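Each export task now follows the same fallback-chain shape: materialize the primary source, and only when it yields nothing, fall back to the next one. A generic sketch of that pattern (the names and sample data here are illustrative, not the project's API):

```python
from typing import Callable, Iterable

def collect_with_fallback(
    sources: list[tuple[str, Callable[[], Iterable[str]]]],
) -> tuple[str, list[str]]:
    """Return (source_name, items) from the first source that yields anything."""
    for name, fetch in sources:
        items = list(fetch())
        if items:
            return name, items
    return "none", []

name, urls = collect_with_fallback([
    ("website", lambda: []),                      # primary source is empty → fall through
    ("default", lambda: ["http://example.com"]),  # fallback is used
])
print(name, urls)  # → default ['http://example.com']
```

Materializing each source with `list()` is what makes the emptiness check possible, which is why the refactored tasks trade streaming for a list at this step.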
@@ -2,18 +2,21 @@

Contains:
- export_endpoints_task: export Endpoint URLs to a file
- export_websites_task: export WebSite URLs to a file
- run_vuln_tool_task: run a vulnerability-scanning tool (non-streaming)
- run_and_stream_save_dalfox_vulns_task: run Dalfox streaming and save vulnerability results
- run_and_stream_save_nuclei_vulns_task: run Nuclei streaming and save vulnerability results
"""

from .export_endpoints_task import export_endpoints_task
from .export_websites_task import export_websites_task
from .run_vuln_tool_task import run_vuln_tool_task
from .run_and_stream_save_dalfox_vulns_task import run_and_stream_save_dalfox_vulns_task
from .run_and_stream_save_nuclei_vulns_task import run_and_stream_save_nuclei_vulns_task

__all__ = [
    "export_endpoints_task",
    "export_websites_task",
    "run_vuln_tool_task",
    "run_and_stream_save_dalfox_vulns_task",
    "run_and_stream_save_nuclei_vulns_task",
@@ -1,118 +1,74 @@
"""Task for exporting Endpoint URLs to a file

Supports two modes:
1. Legacy mode (backward compatible): export from the database using target_id
2. Provider mode: export from any data source using a TargetProvider
Exports URLs from any data source via a TargetProvider.

Data source priority (fallback chain, legacy mode only):
1. Endpoint.url - the most fine-grained URLs (with path, params, etc.)
2. WebSite.url - site-level URLs
3. Default generation - http(s)://target_name based on the Target type
Data source: Endpoint, falling back to default URLs when empty
"""

import logging
from typing import Dict, Optional
from typing import Dict
from pathlib import Path

from prefect import task

from apps.scan.services.target_export_service import (
    export_urls_with_fallback,
    DataSource,
)
from apps.scan.providers import TargetProvider, DatabaseTargetProvider
from apps.scan.providers import TargetProvider

logger = logging.getLogger(__name__)


@task(name="export_endpoints")
def export_endpoints_task(
    target_id: Optional[int] = None,
    output_file: str = "",
    provider: Optional[TargetProvider] = None,
    batch_size: int = 1000,
    output_file: str,
    provider: TargetProvider,
) -> Dict[str, object]:
    """Export every Endpoint URL under the target to a text file.

    Supports two modes:
    1. Legacy mode (backward compatible): pass target_id, export from the database
    2. Provider mode: pass provider, export from any data source

    Data source priority (fallback chain, legacy mode only):
    1. Endpoint table - the most fine-grained URLs (with path, params, etc.)
    2. WebSite table - site-level URLs
    3. Default generation - http(s)://target_name based on the Target type
    Data source priority: Endpoint → default generation

    Args:
        target_id: target ID (legacy mode, backward compatible)
        output_file: output file path (absolute)
        provider: TargetProvider instance (new mode)
        batch_size: batch size for each database iteration
        provider: TargetProvider instance

    Returns:
        dict: {
            "success": bool,
            "output_file": str,
            "total_count": int,
            "source": str,  # data source: "endpoint" | "website" | "default" | "none" | "provider"
            "source": str,  # endpoint | default
        }
    """
    # Argument validation: at least one required
    if target_id is None and provider is None:
        raise ValueError("Either target_id or provider must be provided")

    # Provider mode: export via TargetProvider
    if provider is not None:
        logger.info("Provider mode - Provider: %s", type(provider).__name__)
        return _export_with_provider(output_file, provider)

    # Legacy mode: use export_urls_with_fallback
    logger.info("Legacy mode - Target ID: %d", target_id)
    result = export_urls_with_fallback(
        target_id=target_id,
        output_file=output_file,
        sources=[DataSource.ENDPOINT, DataSource.WEBSITE, DataSource.DEFAULT],
        batch_size=batch_size,
    )

    logger.info(
        "URL export complete - source=%s, count=%d, tried=%s",
        result['source'], result['total_count'], result['tried_sources']
    )

    return {
        "success": result['success'],
        "output_file": result['output_file'],
        "total_count": result['total_count'],
        "source": result['source'],
    }
    if provider is None:
        raise ValueError("The provider argument is required")

    logger.info("Exporting URLs - Provider: %s", type(provider).__name__)


def _export_with_provider(output_file: str, provider: TargetProvider) -> Dict[str, object]:
    """Export URLs using the provider"""
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)


    # Fetch the data, falling back to default URLs when empty
    urls = list(provider.iter_endpoints())
    source = "endpoint"

    if not urls:
        logger.info("Endpoint is empty, generating default URLs")
        urls = list(provider.iter_default_urls())
        source = "default"

    # Write to file
    total_count = 0
    blacklist_filter = provider.get_blacklist_filter()

    with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
        for url in provider.iter_urls():
            # Apply the blacklist filter (if any)
            if blacklist_filter and not blacklist_filter.is_allowed(url):
                continue

        for url in urls:
            f.write(f"{url}\n")
            total_count += 1

            if total_count % 1000 == 0:
                logger.info("Exported %d URLs...", total_count)

    logger.info("✓ URL export complete - total: %d, file: %s", total_count, str(output_path))


    logger.info(
        "✓ URL export complete - source: %s, total: %d, file: %s",
        source, total_count, str(output_path)
    )

    return {
        "success": True,
        "output_file": str(output_path),
        "total_count": total_count,
        "source": "provider",
        "source": source,
    }
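Several of these tasks gate each write on `blacklist_filter.is_allowed()`. The real `BlacklistFilter` lives in `apps.common.utils` and its rule syntax is not shown in this diff; the sketch below is a hypothetical glob-based stand-in that only illustrates the allow/deny check:

```python
import fnmatch

class GlobBlacklistFilter:
    """Hypothetical stand-in for the project's BlacklistFilter.
    Assumes shell-style glob rules; the real rule syntax may differ."""
    def __init__(self, rules: list[str]):
        self.rules = rules

    def is_allowed(self, value: str) -> bool:
        # Denied when any rule matches; allowed otherwise
        return not any(fnmatch.fnmatch(value, rule) for rule in self.rules)

flt = GlobBlacklistFilter(["*.internal.example.com"])
print(flt.is_allowed("app.example.com"))          # → True
print(flt.is_allowed("db.internal.example.com"))  # → False
```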
backend/apps/scan/tasks/vuln_scan/export_websites_task.py (new file, 73 lines)
@@ -0,0 +1,73 @@
"""Task for exporting WebSite URLs to a file

Exports URLs from any data source via a TargetProvider.

Data source: WebSite, falling back to default URLs when empty
"""

import logging
from pathlib import Path

from prefect import task

from apps.scan.providers import TargetProvider

logger = logging.getLogger(__name__)


@task(name="export_websites_for_vuln_scan")
def export_websites_task(
    output_file: str,
    provider: TargetProvider,
) -> dict:
    """Export every WebSite URL under the target to a text file.

    Data source priority: WebSite → default generation

    Args:
        output_file: output file path (absolute)
        provider: TargetProvider instance

    Returns:
        dict: {
            "success": bool,
            "output_file": str,
            "total_count": int,
            "source": str,  # website | default
        }
    """
    if provider is None:
        raise ValueError("The provider argument is required")

    logger.info("Exporting URLs - Provider: %s", type(provider).__name__)

    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Fetch the data, falling back to default URLs when empty
    urls = list(provider.iter_websites())
    source = "website"

    if not urls:
        logger.info("WebSite is empty, generating default URLs")
        urls = list(provider.iter_default_urls())
        source = "default"

    # Write to file
    total_count = 0
    with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
        for url in urls:
            f.write(f"{url}\n")
            total_count += 1

    logger.info(
        "✓ URL export complete - source: %s, total: %d, file: %s",
        source, total_count, str(output_path)
    )

    return {
        "success": True,
        "output_file": str(output_path),
        "total_count": total_count,
        "source": source,
    }
@@ -410,6 +410,14 @@ class CommandExecutor:
            # Key fix: make sure the process tree is cleaned up
            if process:
                self._kill_process_tree(process)
                # Reap the child process to avoid leaving a zombie
                try:
                    process.wait(timeout=GRACEFUL_SHUTDOWN_TIMEOUT)
                except subprocess.TimeoutExpired:
                    # Still not exited after kill: don't block, continue cleaning up
                    pass
                except Exception:
                    pass

            # Close file handles
            if log_file_handle:
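The CommandExecutor hunk above adds a `wait()` after killing the process so the child is reaped instead of lingering as a zombie. A self-contained sketch of that kill-then-wait pattern (`GRACEFUL_SHUTDOWN_TIMEOUT` is the constant name from the diff; the value here is assumed):

```python
import subprocess
import sys

GRACEFUL_SHUTDOWN_TIMEOUT = 5  # seconds; assumed value, the real constant lives elsewhere

def reap(process: subprocess.Popen) -> None:
    """Kill the child, then wait() so the OS can release its process entry."""
    process.kill()
    try:
        # Without this wait(), a dead child stays a zombie until the parent exits
        process.wait(timeout=GRACEFUL_SHUTDOWN_TIMEOUT)
    except subprocess.TimeoutExpired:
        # Still not gone after kill: don't block the rest of the cleanup
        pass

# Spawn a long-running child, then kill and reap it
p = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(30)"])
reap(p)
print(p.returncode is not None)  # → True: the child has been reaped
```

Swallowing `TimeoutExpired` mirrors the diff's intent: cleanup of other resources (log handles, temp files) must not stall on a stuck child.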
@@ -1,94 +1,76 @@
|
||||
from rest_framework import viewsets, status
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.exceptions import NotFound, APIException
|
||||
from rest_framework.filters import SearchFilter
|
||||
from django_filters.rest_framework import DjangoFilterBackend
|
||||
from django.core.exceptions import ObjectDoesNotExist, ValidationError
|
||||
from django.db.utils import DatabaseError, IntegrityError, OperationalError
|
||||
"""扫描任务视图集"""
|
||||
|
||||
import logging
|
||||
|
||||
from apps.common.response_helpers import success_response, error_response
|
||||
from django.core.exceptions import ObjectDoesNotExist, ValidationError
|
||||
from django.db.utils import DatabaseError, IntegrityError, OperationalError
|
||||
from django_filters.rest_framework import DjangoFilterBackend
|
||||
from rest_framework import status, viewsets
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.filters import SearchFilter
|
||||
|
||||
from apps.common.definitions import ScanStatus
|
||||
from apps.common.error_codes import ErrorCodes
|
||||
from apps.scan.utils.config_merger import ConfigConflictError
|
||||
from apps.common.pagination import BasePagination
|
||||
from apps.common.response_helpers import error_response, success_response
|
||||
from apps.targets.repositories import DjangoOrganizationRepository, DjangoTargetRepository
|
||||
|
||||
from ..models import Scan
|
||||
from ..serializers import (
|
||||
InitiateScanSerializer,
|
||||
QuickScanSerializer,
|
||||
ScanHistorySerializer,
|
||||
ScanSerializer,
|
||||
)
|
||||
from ..services.quick_scan_service import QuickScanService
|
||||
from ..services.scan_input_target_service import ScanInputTargetService
|
||||
from ..services.scan_service import ScanService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from ..models import Scan, ScheduledScan
|
||||
from ..serializers import (
|
||||
ScanSerializer, ScanHistorySerializer, QuickScanSerializer,
|
||||
InitiateScanSerializer, ScheduledScanSerializer, CreateScheduledScanSerializer,
|
||||
UpdateScheduledScanSerializer, ToggleScheduledScanSerializer
|
||||
)
|
||||
from ..services.scan_service import ScanService
|
||||
from ..services.scheduled_scan_service import ScheduledScanService
|
||||
from ..repositories import ScheduledScanDTO
|
||||
from apps.targets.services.target_service import TargetService
|
||||
from apps.targets.services.organization_service import OrganizationService
|
||||
from apps.engine.services.engine_service import EngineService
|
||||
from apps.common.definitions import ScanStatus
|
||||
from apps.common.pagination import BasePagination
|
||||
|
||||
def _handle_database_error():
|
||||
"""处理数据库错误的通用响应"""
|
||||
return error_response(
|
||||
code=ErrorCodes.SERVER_ERROR,
|
||||
message='Database error',
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
)
|
||||
|
||||
|
||||
class ScanViewSet(viewsets.ModelViewSet):
|
||||
"""扫描任务视图集"""
|
||||
|
||||
serializer_class = ScanSerializer
|
||||
pagination_class = BasePagination
|
||||
filter_backends = [DjangoFilterBackend, SearchFilter]
|
||||
filterset_fields = ['target'] # 支持 ?target=123 过滤
|
||||
search_fields = ['target__name'] # 按目标名称搜索
|
||||
|
||||
filterset_fields = ['target']
|
||||
search_fields = ['target__name']
|
||||
|
||||
def get_queryset(self):
|
||||
"""优化查询集,提升API性能
|
||||
|
||||
查询优化策略:
|
||||
- select_related: 预加载 target 和 engine(一对一/多对一关系,使用 JOIN)
|
||||
- 移除 prefetch_related: 避免加载大量资产数据到内存
|
||||
- order_by: 按创建时间降序排列(最新创建的任务排在最前面)
|
||||
|
||||
性能优化原理:
|
||||
- 列表页:使用缓存统计字段(cached_*_count),避免实时 COUNT 查询
|
||||
- 序列化器:严格验证缓存字段,确保数据一致性
|
||||
- 分页场景:每页只显示10条记录,查询高效
|
||||
- 避免大数据加载:不再预加载所有关联的资产数据
|
||||
"""
|
||||
# 只保留必要的 select_related,移除所有 prefetch_related
|
||||
"""优化查询集,提升API性能"""
|
||||
scan_service = ScanService()
|
||||
queryset = scan_service.get_all_scans(prefetch_relations=True)
|
||||
|
||||
return queryset
|
||||
|
||||
return scan_service.get_all_scans(prefetch_relations=True)
|
||||
|
||||
def get_serializer_class(self):
|
||||
"""根据不同的 action 返回不同的序列化器
|
||||
|
||||
- list action: 使用 ScanHistorySerializer(包含 summary 和 progress)
|
||||
- retrieve action: 使用 ScanHistorySerializer(包含 summary 和 progress)
|
||||
- 其他 action: 使用标准的 ScanSerializer
|
||||
"""
|
||||
"""根据不同的 action 返回不同的序列化器"""
|
||||
if self.action in ['list', 'retrieve']:
|
||||
return ScanHistorySerializer
|
||||
return ScanSerializer
|
||||
|
||||
def destroy(self, request, *args, **kwargs):
|
||||
"""
|
||||
删除单个扫描任务(两阶段删除)
|
||||
|
||||
1. 软删除:立即对用户不可见
|
||||
2. 硬删除:后台异步执行
|
||||
"""
|
||||
"""删除单个扫描任务(两阶段删除)"""
|
||||
try:
|
||||
scan = self.get_object()
|
||||
scan_service = ScanService()
|
||||
result = scan_service.delete_scans_two_phase([scan.id])
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
'scanId': scan.id,
|
||||
'deletedCount': result['soft_deleted_count'],
|
||||
'deletedScans': result['scan_names']
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return success_response(data={
|
||||
'scanId': scan.id,
|
||||
'deletedCount': result['soft_deleted_count'],
|
||||
'deletedScans': result['scan_names']
|
||||
})
|
||||
|
||||
except Scan.DoesNotExist:
|
||||
return error_response(
|
||||
code=ErrorCodes.NOT_FOUND,
|
||||
@@ -100,80 +82,57 @@ class ScanViewSet(viewsets.ModelViewSet):
|
||||
message=str(e),
|
||||
status_code=status.HTTP_404_NOT_FOUND
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("删除扫描任务时发生错误")
|
||||
return error_response(
|
||||
code=ErrorCodes.SERVER_ERROR,
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
|
||||
@action(detail=False, methods=['post'])
|
||||
def quick(self, request):
|
||||
"""
|
||||
快速扫描接口
|
||||
|
||||
|
||||
功能:
|
||||
1. 接收目标列表和 YAML 配置
|
||||
2. 自动解析输入(支持 URL、域名、IP、CIDR)
|
||||
3. 批量创建 Target、Website、Endpoint 资产
|
||||
4. 立即发起批量扫描
|
||||
|
||||
请求参数:
|
||||
{
|
||||
"targets": [{"name": "example.com"}, {"name": "https://example.com/api"}],
|
||||
"configuration": "subdomain_discovery:\n enabled: true\n ...",
|
||||
"engine_ids": [1, 2], // 可选,用于记录
|
||||
"engine_names": ["引擎A", "引擎B"] // 可选,用于记录
|
||||
}
|
||||
|
||||
支持的输入格式:
|
||||
- 域名: example.com
|
||||
- IP: 192.168.1.1
|
||||
- CIDR: 10.0.0.0/8
|
||||
- URL: https://example.com/api/v1
|
||||
4. 立即发起批量扫描(scan_mode='quick')
|
||||
5. 将用户输入写入 ScanInputTarget 表
|
||||
"""
|
||||
from ..services.quick_scan_service import QuickScanService
|
||||
|
||||
serializer = QuickScanSerializer(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
|
||||
targets_data = serializer.validated_data['targets']
|
||||
configuration = serializer.validated_data['configuration']
|
||||
engine_ids = serializer.validated_data.get('engine_ids', [])
|
||||
engine_names = serializer.validated_data.get('engine_names', [])
|
||||
|
||||
data = serializer.validated_data
|
||||
|
||||
try:
|
||||
# 提取输入字符串列表
|
||||
inputs = [t['name'] for t in targets_data]
|
||||
|
||||
# 1. 使用 QuickScanService 解析输入并创建资产
|
||||
quick_scan_service = QuickScanService()
|
||||
result = quick_scan_service.process_quick_scan(inputs, engine_ids[0] if engine_ids else None)
|
||||
|
||||
targets = result['targets']
|
||||
|
||||
            if not targets:
            inputs = [t['name'] for t in data['targets']]

            # 1. Parse the inputs and create assets
            result = QuickScanService().process_quick_scan(inputs)

            if not result['targets']:
                return error_response(
                    code=ErrorCodes.VALIDATION_ERROR,
                    message='No valid targets for scanning',
                    details=result.get('errors', []),
                    status_code=status.HTTP_400_BAD_REQUEST
                )

            # 2. Create scans directly from the configuration passed by the frontend
            scan_service = ScanService()
            created_scans = scan_service.create_scans(
                targets=targets,
                engine_ids=engine_ids,
                engine_names=engine_names,
                yaml_configuration=configuration

            # 2. Create scans (scan_mode='quick')
            created_scans = ScanService().create_scans(
                targets=result['targets'],
                engine_ids=data.get('engine_ids', []),
                engine_names=data.get('engine_names', []),
                yaml_configuration=data['configuration'],
                scan_mode='quick'
            )

            # Check whether any scan tasks were actually created
            if not created_scans:
                return error_response(
                    code=ErrorCodes.VALIDATION_ERROR,
                    message='No scan tasks were created. All targets may already have active scans.',
                    message='No scan tasks were created. '
                            'All targets may already have active scans.',
                    details={
                        'targetStats': result['target_stats'],
                        'assetStats': result['asset_stats'],
@@ -181,317 +140,210 @@ class ScanViewSet(viewsets.ModelViewSet):
                    },
                    status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
                )

            # Serialize the result
            scan_serializer = ScanSerializer(created_scans, many=True)

            # 3. Write the user inputs into the ScanInputTarget table
            scan_input_service = ScanInputTargetService()
            target_inputs_map = result.get('target_inputs_map', {})
            for scan in created_scans:
                inputs_for_target = target_inputs_map.get(scan.target.name, [])
                if inputs_for_target:
                    scan_input_service.bulk_create(scan.id, inputs_for_target)

            return success_response(
                data={
                    'count': len(created_scans),
                    'targetStats': result['target_stats'],
                    'assetStats': result['asset_stats'],
                    'errors': result.get('errors', []),
                    'scans': scan_serializer.data
                    'scans': ScanSerializer(created_scans, many=True).data
                },
                status_code=status.HTTP_201_CREATED
            )

        except ValidationError as e:
            return error_response(
                code=ErrorCodes.VALIDATION_ERROR,
                message=str(e),
                status_code=status.HTTP_400_BAD_REQUEST
            )
        except Exception as e:
        except (DatabaseError, IntegrityError, OperationalError):
            logger.exception("Failed to start quick scan")
            return error_response(
                code=ErrorCodes.SERVER_ERROR,
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
            )
            return _handle_database_error()
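The quick-scan flow above first normalizes raw user input into targets before creating scans. A minimal, hypothetical sketch of that normalization step (the real `QuickScanService.process_quick_scan` is not shown in this diff, so the classification rules below are assumptions):

```python
import ipaddress


def classify_input(raw: str) -> str:
    """Classify a raw quick-scan input as 'url', 'ip', or 'domain' (assumed rules)."""
    value = raw.strip()
    if value.startswith(('http://', 'https://')):
        return 'url'
    try:
        ipaddress.ip_address(value)
        return 'ip'
    except ValueError:
        return 'domain'


def process_quick_scan_inputs(inputs: list[str]) -> dict:
    """Group inputs by kind, mirroring the parse-then-create-assets step."""
    targets: dict[str, list[str]] = {'url': [], 'ip': [], 'domain': []}
    for raw in inputs:
        targets[classify_input(raw)].append(raw.strip())
    return targets
```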
    @action(detail=False, methods=['post'])
    def initiate(self, request):
        """
        Initiate scan tasks

        Request parameters:
        - organization_id: organization ID (int, optional)
        - target_id: target ID (int, optional)
        - configuration: YAML configuration string (str, required)
        - engine_ids: list of scan engine IDs (list[int], required)
        - engine_names: list of engine names (list[str], required)

        Note: provide exactly one of organization_id and target_id

        Returns:
        - scan task details (one or more)
        """
        # Validate the request data with the serializer
        serializer = InitiateScanSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)

        # Extract the validated data
        organization_id = serializer.validated_data.get('organization_id')
        target_id = serializer.validated_data.get('target_id')
        configuration = serializer.validated_data['configuration']
        engine_ids = serializer.validated_data['engine_ids']
        engine_names = serializer.validated_data['engine_names']

        data = serializer.validated_data

        try:
            # Resolve the target list
            scan_service = ScanService()

            if organization_id:
                from apps.targets.repositories import DjangoOrganizationRepository
                org_repo = DjangoOrganizationRepository()
                organization = org_repo.get_by_id(organization_id)
                if not organization:
                    raise ObjectDoesNotExist(f'Organization ID {organization_id} does not exist')
                targets = org_repo.get_targets(organization_id)
                if not targets:
                    raise ValidationError(f'Organization ID {organization_id} has no targets')
            else:
                from apps.targets.repositories import DjangoTargetRepository
                target_repo = DjangoTargetRepository()
                target = target_repo.get_by_id(target_id)
                if not target:
                    raise ObjectDoesNotExist(f'Target ID {target_id} does not exist')
                targets = [target]

            # Create scans directly from the configuration passed by the frontend
            created_scans = scan_service.create_scans(
                targets=targets,
                engine_ids=engine_ids,
                engine_names=engine_names,
                yaml_configuration=configuration
            targets = self._get_targets_for_scan(
                data.get('organization_id'),
                data.get('target_id')
            )

            # Check whether any scan tasks were actually created
            created_scans = ScanService().create_scans(
                targets=targets,
                engine_ids=data['engine_ids'],
                engine_names=data['engine_names'],
                yaml_configuration=data['configuration']
            )

            if not created_scans:
                return error_response(
                    code=ErrorCodes.VALIDATION_ERROR,
                    message='No scan tasks were created. All targets may already have active scans.',
                    message='No scan tasks were created. '
                            'All targets may already have active scans.',
                    status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
                )

            # Serialize the result
            scan_serializer = ScanSerializer(created_scans, many=True)

            return success_response(
                data={
                    'count': len(created_scans),
                    'scans': scan_serializer.data
                    'scans': ScanSerializer(created_scans, many=True).data
                },
                status_code=status.HTTP_201_CREATED
            )

        except ObjectDoesNotExist as e:
            # Resource-not-found error (raised by the service layer)
            return error_response(
                code=ErrorCodes.NOT_FOUND,
                message=str(e),
                status_code=status.HTTP_404_NOT_FOUND
            )

        except ValidationError as e:
            # Parameter validation error (raised by the service layer)
            return error_response(
                code=ErrorCodes.VALIDATION_ERROR,
                message=str(e),
                status_code=status.HTTP_400_BAD_REQUEST
            )

        except (DatabaseError, IntegrityError, OperationalError):
            # Database error
            return error_response(
                code=ErrorCodes.SERVER_ERROR,
                message='Database error',
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE
            )
            return _handle_database_error()
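`initiate` requires exactly one of `organization_id` and `target_id`. A sketch of how that rule can be checked at the validation layer (the field names match the docstring above, but `InitiateScanSerializer`'s actual validation code is not shown in this diff):

```python
def validate_initiate_payload(data: dict) -> dict:
    """Require exactly one of organization_id / target_id (assumed rule)."""
    has_org = data.get('organization_id') is not None
    has_target = data.get('target_id') is not None
    if has_org == has_target:  # both set, or both missing
        raise ValueError('Provide exactly one of organization_id and target_id')
    return data
```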
    # All snapshot-related actions and exports have moved to the snapshot ViewSets in asset/views.py:
    # GET /api/scans/{id}/subdomains/ -> SubdomainSnapshotViewSet
    # GET /api/scans/{id}/subdomains/export/ -> SubdomainSnapshotViewSet.export
    # GET /api/scans/{id}/websites/ -> WebsiteSnapshotViewSet
    # GET /api/scans/{id}/websites/export/ -> WebsiteSnapshotViewSet.export
    # GET /api/scans/{id}/directories/ -> DirectorySnapshotViewSet
    # GET /api/scans/{id}/directories/export/ -> DirectorySnapshotViewSet.export
    # GET /api/scans/{id}/endpoints/ -> EndpointSnapshotViewSet
    # GET /api/scans/{id}/endpoints/export/ -> EndpointSnapshotViewSet.export
    # GET /api/scans/{id}/ip-addresses/ -> HostPortMappingSnapshotViewSet
    # GET /api/scans/{id}/ip-addresses/export/ -> HostPortMappingSnapshotViewSet.export
    # GET /api/scans/{id}/vulnerabilities/ -> VulnerabilitySnapshotViewSet
    def _get_targets_for_scan(self, organization_id, target_id):
        """Resolve the list of scan targets from an organization ID or a target ID."""
        if organization_id:
            org_repo = DjangoOrganizationRepository()
            organization = org_repo.get_by_id(organization_id)
            if not organization:
                raise ObjectDoesNotExist(f'Organization ID {organization_id} does not exist')
            targets = org_repo.get_targets(organization_id)
            if not targets:
                raise ValidationError(f'Organization ID {organization_id} has no targets')
            return targets

        target_repo = DjangoTargetRepository()
        target = target_repo.get_by_id(target_id)
        if not target:
            raise ObjectDoesNotExist(f'Target ID {target_id} does not exist')
        return [target]
    @action(detail=False, methods=['post', 'delete'], url_path='bulk-delete')
    def bulk_delete(self, request):
        """
        Bulk-delete scan records

        Request parameters:
        - ids: list of scan IDs (list[int], required)

        Example request:
        POST /api/scans/bulk-delete/
        {
            "ids": [1, 2, 3]
        }

        Returns:
        - message: success message
        - deletedCount: number of records actually deleted

        Notes:
        - Deletion cascades to related data such as subdomains and endpoints
        - Only existing records are deleted; unknown IDs are ignored
        """
        """Bulk-delete scan records"""
        ids = request.data.get('ids', [])

        # Parameter validation
        if not ids:
            return error_response(
                code=ErrorCodes.VALIDATION_ERROR,
                message='Missing required parameter: ids',
                status_code=status.HTTP_400_BAD_REQUEST
            )

        if not isinstance(ids, list):
            return error_response(
                code=ErrorCodes.VALIDATION_ERROR,
                message='ids must be an array',
                status_code=status.HTTP_400_BAD_REQUEST
            )

        if not all(isinstance(i, int) for i in ids):
            return error_response(
                code=ErrorCodes.VALIDATION_ERROR,
                message='All elements in ids array must be integers',
                status_code=status.HTTP_400_BAD_REQUEST
            )

        try:
            # Bulk delete via the service layer (two-phase deletion)
            scan_service = ScanService()
            result = scan_service.delete_scans_two_phase(ids)

            return success_response(
                data={
                    'deletedCount': result['soft_deleted_count'],
                    'deletedScans': result['scan_names']
                }
            )

            return success_response(data={
                'deletedCount': result['soft_deleted_count'],
                'deletedScans': result['scan_names']
            })

        except ValueError as e:
            # No matching records found
            return error_response(
                code=ErrorCodes.NOT_FOUND,
                message=str(e),
                status_code=status.HTTP_404_NOT_FOUND
            )

        except Exception as e:
        except (DatabaseError, IntegrityError, OperationalError):
            logger.exception("Error during bulk deletion of scan tasks")
            return error_response(
                code=ErrorCodes.SERVER_ERROR,
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
            )

            return _handle_database_error()
    @action(detail=False, methods=['get'])
    def statistics(self, request):
        """
        Get scan statistics

        Returns aggregate statistics about scan tasks, used by the dashboard
        and the scan-history page. Aggregates over cached counter fields, so
        the query is fast.

        Returns:
        - total: total number of scans
        - running: number of running scans
        - completed: number of completed scans
        - failed: number of failed scans
        - totalVulns: total vulnerabilities found
        - totalSubdomains: total subdomains found
        - totalEndpoints: total endpoints found
        - totalAssets: total assets
        """
    def statistics(self, request):  # pylint: disable=unused-argument
        """Get scan statistics"""
        try:
            # Fetch statistics via the service layer
            scan_service = ScanService()
            stats = scan_service.get_statistics()

            return success_response(
                data={
                    'total': stats['total'],
                    'running': stats['running'],
                    'completed': stats['completed'],
                    'failed': stats['failed'],
                    'totalVulns': stats['total_vulns'],
                    'totalSubdomains': stats['total_subdomains'],
                    'totalEndpoints': stats['total_endpoints'],
                    'totalWebsites': stats['total_websites'],
                    'totalAssets': stats['total_assets'],
                }
            )

            stats = ScanService().get_statistics()

            return success_response(data={
                'total': stats['total'],
                'running': stats['running'],
                'completed': stats['completed'],
                'failed': stats['failed'],
                'totalVulns': stats['total_vulns'],
                'totalSubdomains': stats['total_subdomains'],
                'totalEndpoints': stats['total_endpoints'],
                'totalWebsites': stats['total_websites'],
                'totalAssets': stats['total_assets'],
            })

        except (DatabaseError, OperationalError):
            return error_response(
                code=ErrorCodes.SERVER_ERROR,
                message='Database error',
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE
            )

            return _handle_database_error()
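`get_statistics` aggregates cached per-scan counter fields; a toy, in-memory sketch of the same aggregation (the real service presumably runs a database aggregate query, which is not shown in this diff):

```python
def aggregate_stats(scans: list[dict]) -> dict:
    """Aggregate cached counter fields across scan records (illustrative only)."""
    return {
        'total': len(scans),
        'running': sum(1 for s in scans if s['status'] == 'RUNNING'),
        'completed': sum(1 for s in scans if s['status'] == 'COMPLETED'),
        'failed': sum(1 for s in scans if s['status'] == 'FAILED'),
        'total_vulns': sum(s['vuln_count'] for s in scans),
        'total_subdomains': sum(s['subdomain_count'] for s in scans),
    }
```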
    @action(detail=True, methods=['post'])
    def stop(self, request, pk=None):  # pylint: disable=unused-argument
        """
        Stop a scan task

        URL: POST /api/scans/{id}/stop/

        Behavior:
        - Terminates a running or initializing scan task
        - Updates the scan status to CANCELLED

        Status constraints:
        - Only scans in RUNNING or INITIATED state can be stopped
        - Completed, failed, or cancelled scans cannot be stopped

        Returns:
        - message: success message
        - revokedTaskCount: number of cancelled Flow Runs
        """
        """Stop a scan task"""
        try:
            # Handle the stop logic via the service layer
            scan_service = ScanService()
            success, revoked_count = scan_service.stop_scan(scan_id=pk)

            if not success:
                # Check whether the current status disallows stopping
                scan = scan_service.get_scan(scan_id=pk, prefetch_relations=False)
                if scan and scan.status not in [ScanStatus.RUNNING, ScanStatus.INITIATED]:
                    return error_response(
                        code=ErrorCodes.BAD_REQUEST,
                        message=f'Cannot stop scan: current status is {ScanStatus(scan.status).label}',
                        message=f'Cannot stop scan: current status is '
                                f'{ScanStatus(scan.status).label}',
                        status_code=status.HTTP_400_BAD_REQUEST
                    )
                # Any other failure reason
                return error_response(
                    code=ErrorCodes.SERVER_ERROR,
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
                )

            return success_response(
                data={'revokedTaskCount': revoked_count}
            )

            return success_response(data={'revokedTaskCount': revoked_count})

        except ObjectDoesNotExist:
            return error_response(
                code=ErrorCodes.NOT_FOUND,
                message=f'Scan ID {pk} not found',
                status_code=status.HTTP_404_NOT_FOUND
            )

        except (DatabaseError, IntegrityError, OperationalError):
            return error_response(
                code=ErrorCodes.SERVER_ERROR,
                message='Database error',
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE
            )
            return _handle_database_error()
@@ -44,6 +44,8 @@ services:
    restart: always
    env_file:
      - .env
    environment:
      - IMAGE_TAG=${IMAGE_TAG:-dev}
    ports:
      - "8888:8888"
    depends_on:

@@ -48,6 +48,8 @@ services:
    restart: always
    env_file:
      - .env
    environment:
      - IMAGE_TAG=${IMAGE_TAG}
    depends_on:
      redis:
        condition: service_healthy
@@ -83,20 +83,20 @@ if not yaml_path.exists():
    print('Configuration file not found, skipping')
    exit(0)

new_config = yaml_path.read_text()

# Check whether a "full scan" engine already exists
engine = ScanEngine.objects.filter(name='full scan').first()
if engine:
    if not engine.configuration or not engine.configuration.strip():
        engine.configuration = yaml_path.read_text()
        engine.save(update_fields=['configuration'])
        print(f'Initialized engine configuration: {engine.name}')
    else:
        print('Engine already has a configuration, skipping')
    # Overwrite with the latest configuration
    engine.configuration = new_config
    engine.save(update_fields=['configuration'])
    print(f'Updated engine configuration: {engine.name}')
else:
    # Create the engine
    engine = ScanEngine.objects.create(
        name='full scan',
        configuration=yaml_path.read_text(),
        configuration=new_config,
    )
    print(f'Created engine: {engine.name}')
"
@@ -10,7 +10,7 @@ python manage.py migrate --noinput
echo " ✓ Database migration complete"

echo " [1.1/3] Initializing default scan engine..."
python manage.py init_default_engine
python manage.py init_default_engine --force
echo " ✓ Default scan engine ready"

echo " [1.2/3] Initializing default directory wordlist..."

@@ -182,7 +182,7 @@ echo -e "${BOLD}${GREEN}══════════════════
echo ""
echo -e "${BOLD}Access URLs${NC}"
if [ "$WITH_FRONTEND" = true ]; then
echo -e " XingRin: ${CYAN}https://${ACCESS_HOST}/${NC}"
echo -e " XingRin: ${CYAN}https://${ACCESS_HOST}:8083/${NC}"
echo -e " ${YELLOW}(HTTP redirects to HTTPS automatically)${NC}"
else
echo -e " API: ${CYAN}access via the frontend or nginx (backend does not expose 8888)${NC}"
@@ -191,8 +191,3 @@ else
echo " cd frontend && pnpm dev"
fi
echo ""
echo -e "${BOLD}Default account${NC}"
echo " Username: admin"
echo " Password: admin"
echo -e " ${YELLOW}[!] Please change the password after first login${NC}"
echo ""
189  frontend/components/about-dialog.tsx  (new file)
@@ -0,0 +1,189 @@
"use client"

import { useState } from 'react'
import { useTranslations } from 'next-intl'
import { useQueryClient } from '@tanstack/react-query'
import {
  IconRadar,
  IconRefresh,
  IconExternalLink,
  IconBrandGithub,
  IconMessageReport,
  IconBook,
  IconFileText,
  IconCheck,
  IconArrowUp,
} from '@tabler/icons-react'

import {
  Dialog,
  DialogContent,
  DialogHeader,
  DialogTitle,
  DialogTrigger,
} from '@/components/ui/dialog'
import { Button } from '@/components/ui/button'
import { Separator } from '@/components/ui/separator'
import { Badge } from '@/components/ui/badge'
import { useVersion } from '@/hooks/use-version'
import { VersionService } from '@/services/version.service'
import type { UpdateCheckResult } from '@/types/version.types'

interface AboutDialogProps {
  children: React.ReactNode
}

export function AboutDialog({ children }: AboutDialogProps) {
  const t = useTranslations('about')
  const { data: versionData } = useVersion()
  const queryClient = useQueryClient()

  const [isChecking, setIsChecking] = useState(false)
  const [updateResult, setUpdateResult] = useState<UpdateCheckResult | null>(null)
  const [checkError, setCheckError] = useState<string | null>(null)

  const handleCheckUpdate = async () => {
    setIsChecking(true)
    setCheckError(null)
    try {
      const result = await VersionService.checkUpdate()
      setUpdateResult(result)
      queryClient.setQueryData(['check-update'], result)
    } catch {
      setCheckError(t('checkFailed'))
    } finally {
      setIsChecking(false)
    }
  }

  const currentVersion = updateResult?.currentVersion || versionData?.version || '-'
  const latestVersion = updateResult?.latestVersion
  const hasUpdate = updateResult?.hasUpdate

  return (
    <Dialog>
      <DialogTrigger asChild>
        {children}
      </DialogTrigger>
      <DialogContent className="sm:max-w-md">
        <DialogHeader>
          <DialogTitle>{t('title')}</DialogTitle>
        </DialogHeader>

        <div className="space-y-6">
          {/* Logo and name */}
          <div className="flex flex-col items-center py-4">
            <div className="flex h-16 w-16 items-center justify-center rounded-2xl bg-primary/10 mb-3">
              <IconRadar className="h-8 w-8 text-primary" />
            </div>
            <h2 className="text-xl font-semibold">XingRin</h2>
            <p className="text-sm text-muted-foreground">{t('description')}</p>
          </div>

          {/* Version info */}
          <div className="rounded-lg border p-4 space-y-3">
            <div className="flex items-center justify-between">
              <span className="text-sm text-muted-foreground">{t('currentVersion')}</span>
              <span className="font-mono text-sm">{currentVersion}</span>
            </div>

            {updateResult && (
              <div className="flex items-center justify-between">
                <span className="text-sm text-muted-foreground">{t('latestVersion')}</span>
                <div className="flex items-center gap-2">
                  <span className="font-mono text-sm">{latestVersion}</span>
                  {hasUpdate ? (
                    <Badge variant="default" className="gap-1">
                      <IconArrowUp className="h-3 w-3" />
                      {t('updateAvailable')}
                    </Badge>
                  ) : (
                    <Badge variant="secondary" className="gap-1">
                      <IconCheck className="h-3 w-3" />
                      {t('upToDate')}
                    </Badge>
                  )}
                </div>
              </div>
            )}

            {checkError && (
              <p className="text-sm text-destructive">{checkError}</p>
            )}

            <div className="flex gap-2">
              <Button
                variant="outline"
                size="sm"
                className="flex-1"
                onClick={handleCheckUpdate}
                disabled={isChecking}
              >
                <IconRefresh className={`h-4 w-4 mr-2 ${isChecking ? 'animate-spin' : ''}`} />
                {isChecking ? t('checking') : t('checkUpdate')}
              </Button>

              {hasUpdate && updateResult?.releaseUrl && (
                <Button
                  variant="default"
                  size="sm"
                  className="flex-1"
                  asChild
                >
                  <a href={updateResult.releaseUrl} target="_blank" rel="noopener noreferrer">
                    <IconExternalLink className="h-4 w-4 mr-2" />
                    {t('viewRelease')}
                  </a>
                </Button>
              )}
            </div>

            {hasUpdate && (
              <div className="rounded-md bg-muted p-3 text-sm text-muted-foreground">
                <p>{t('updateHint')}</p>
                <code className="mt-1 block rounded bg-background px-2 py-1 font-mono text-xs">
                  sudo ./update.sh
                </code>
              </div>
            )}
          </div>

          <Separator />

          {/* Links */}
          <div className="grid grid-cols-2 gap-2">
            <Button variant="ghost" size="sm" className="justify-start" asChild>
              <a href="https://github.com/yyhuni/xingrin" target="_blank" rel="noopener noreferrer">
                <IconBrandGithub className="h-4 w-4 mr-2" />
                GitHub
              </a>
            </Button>
            <Button variant="ghost" size="sm" className="justify-start" asChild>
              <a href="https://github.com/yyhuni/xingrin/releases" target="_blank" rel="noopener noreferrer">
                <IconFileText className="h-4 w-4 mr-2" />
                {t('changelog')}
              </a>
            </Button>
            <Button variant="ghost" size="sm" className="justify-start" asChild>
              <a href="https://github.com/yyhuni/xingrin/issues" target="_blank" rel="noopener noreferrer">
                <IconMessageReport className="h-4 w-4 mr-2" />
                {t('feedback')}
              </a>
            </Button>
            <Button variant="ghost" size="sm" className="justify-start" asChild>
              <a href="https://github.com/yyhuni/xingrin#readme" target="_blank" rel="noopener noreferrer">
                <IconBook className="h-4 w-4 mr-2" />
                {t('docs')}
              </a>
            </Button>
          </div>

          {/* Footer */}
          <p className="text-center text-xs text-muted-foreground">
            © 2025 XingRin · MIT License
          </p>
        </div>
      </DialogContent>
    </Dialog>
  )
}
@@ -5,7 +5,6 @@ import type * as React from "react"
// Import various icons from Tabler Icons library
import {
  IconDashboard,      // Dashboard icon
  IconHelp,           // Help icon
  IconListDetails,    // List details icon
  IconSettings,       // Settings icon
  IconUsers,          // Users icon
@@ -15,10 +14,10 @@ import {
  IconServer,         // Server icon
  IconTerminal2,      // Terminal icon
  IconBug,            // Vulnerability icon
  IconMessageReport,  // Feedback icon
  IconSearch,         // Search icon
  IconKey,            // API Key icon
  IconBan,            // Blacklist icon
  IconInfoCircle,     // About icon
} from "@tabler/icons-react"
// Import internationalization hook
import { useTranslations } from 'next-intl'
@@ -27,8 +26,8 @@ import { Link, usePathname } from '@/i18n/navigation'

// Import custom navigation components
import { NavSystem } from "@/components/nav-system"
import { NavSecondary } from "@/components/nav-secondary"
import { NavUser } from "@/components/nav-user"
import { AboutDialog } from "@/components/about-dialog"
// Import sidebar UI components
import {
  Sidebar,
@@ -139,20 +138,6 @@ export function AppSidebar({ ...props }: React.ComponentProps<typeof Sidebar>) {
    },
  ]

  // Secondary navigation menu items
  const navSecondary = [
    {
      title: t('feedback'),
      url: "https://github.com/yyhuni/xingrin/issues",
      icon: IconMessageReport,
    },
    {
      title: t('help'),
      url: "https://github.com/yyhuni/xingrin",
      icon: IconHelp,
    },
  ]

  // System settings related menu items
  const documents = [
    {
@@ -271,8 +256,21 @@ export function AppSidebar({ ...props }: React.ComponentProps<typeof Sidebar>) {

        {/* System settings navigation menu */}
        <NavSystem items={documents} />
        {/* Secondary navigation menu, using mt-auto to push to bottom */}
        <NavSecondary items={navSecondary} className="mt-auto" />
        {/* About system button */}
        <SidebarGroup className="mt-auto">
          <SidebarGroupContent>
            <SidebarMenu>
              <SidebarMenuItem>
                <AboutDialog>
                  <SidebarMenuButton>
                    <IconInfoCircle />
                    <span>{t('about')}</span>
                  </SidebarMenuButton>
                </AboutDialog>
              </SidebarMenuItem>
            </SidebarMenu>
          </SidebarGroupContent>
        </SidebarGroup>
      </SidebarContent>

      {/* Sidebar footer */}
@@ -54,7 +54,7 @@ export function EnginePresetSelector({

  engines.forEach(e => {
    const caps = parseEngineCapabilities(e.configuration || "")
    const hasRecon = caps.includes("subdomain_discovery") || caps.includes("port_scan") || caps.includes("site_scan") || caps.includes("directory_scan") || caps.includes("url_fetch")
    const hasRecon = caps.includes("subdomain_discovery") || caps.includes("port_scan") || caps.includes("site_scan") || caps.includes("fingerprint_detect") || caps.includes("directory_scan") || caps.includes("url_fetch") || caps.includes("screenshot")
    const hasVuln = caps.includes("vuln_scan")

    if (hasRecon && hasVuln) {

@@ -58,14 +58,6 @@ subdomain_discovery:
  enabled: true
  timeout: 600  # 10 minutes (required)

amass_passive:
  enabled: true
  timeout: 600  # 10 minutes (required)

amass_active:
  enabled: true
  timeout: 1800  # 30 minutes (required)

sublist3r:
  enabled: true
  timeout: 900  # 15 minutes (required)
19  frontend/hooks/use-version.ts  (new file)
@@ -0,0 +1,19 @@
import { useQuery } from '@tanstack/react-query'
import { VersionService } from '@/services/version.service'

export function useVersion() {
  return useQuery({
    queryKey: ['version'],
    queryFn: () => VersionService.getVersion(),
    staleTime: Infinity,
  })
}

export function useCheckUpdate() {
  return useQuery({
    queryKey: ['check-update'],
    queryFn: () => VersionService.checkUpdate(),
    enabled: false, // triggered manually
    staleTime: 5 * 60 * 1000, // cache for 5 minutes
  })
}
@@ -325,8 +325,7 @@
    "notifications": "Notifications",
    "apiKeys": "API Keys",
    "globalBlacklist": "Global Blacklist",
    "help": "Get Help",
    "feedback": "Feedback"
    "about": "About"
  },
  "search": {
    "title": "Asset Search",
@@ -2292,5 +2291,21 @@
    "conflict": "Resource conflict, please check and try again",
    "unauthorized": "Please login first",
    "rateLimited": "Too many requests, please try again later"
  },
  "about": {
    "title": "About XingRin",
    "description": "Attack Surface Management Platform",
    "currentVersion": "Current Version",
    "latestVersion": "Latest Version",
    "checkUpdate": "Check Update",
    "checking": "Checking...",
    "checkFailed": "Failed to check update, please try again later",
    "updateAvailable": "Update Available",
    "upToDate": "Up to Date",
    "viewRelease": "View Release",
    "updateHint": "Run the following command in project root to update:",
    "changelog": "Changelog",
    "feedback": "Feedback",
    "docs": "Documentation"
  }
}

@@ -325,8 +325,7 @@
    "notifications": "通知设置",
    "apiKeys": "API 密钥",
    "globalBlacklist": "全局黑名单",
    "help": "获取帮助",
    "feedback": "反馈建议"
    "about": "关于系统"
  },
  "search": {
    "title": "资产搜索",
@@ -2292,5 +2291,21 @@
    "conflict": "资源冲突,请检查后重试",
    "unauthorized": "请先登录",
    "rateLimited": "请求过于频繁,请稍后重试"
  },
  "about": {
    "title": "关于 XingRin",
    "description": "攻击面管理平台",
    "currentVersion": "当前版本",
    "latestVersion": "最新版本",
    "checkUpdate": "检查更新",
    "checking": "检查中...",
    "checkFailed": "检查更新失败,请稍后重试",
    "updateAvailable": "有更新",
    "upToDate": "已是最新",
    "viewRelease": "查看发布",
    "updateHint": "在项目根目录运行以下命令更新:",
    "changelog": "更新日志",
    "feedback": "问题反馈",
    "docs": "使用文档"
  }
}
@@ -9,6 +9,10 @@ export const mockNotificationSettings: NotificationSettings = {
    enabled: true,
    webhookUrl: 'https://discord.com/api/webhooks/1234567890/abcdefghijklmnop',
  },
  wecom: {
    enabled: false,
    webhookUrl: '',
  },
  categories: {
    scan: true,
    vulnerability: true,
@@ -30,6 +34,7 @@ export function updateMockNotificationSettings(
  return {
    message: 'Notification settings updated successfully',
    discord: mockNotificationSettings.discord,
    wecom: mockNotificationSettings.wecom,
    categories: mockNotificationSettings.categories,
  }
}
14  frontend/services/version.service.ts  (new file)
@@ -0,0 +1,14 @@
import { api } from '@/lib/api-client'
import type { VersionInfo, UpdateCheckResult } from '@/types/version.types'

export class VersionService {
  static async getVersion(): Promise<VersionInfo> {
    const res = await api.get<VersionInfo>('/system/version/')
    return res.data
  }

  static async checkUpdate(): Promise<UpdateCheckResult> {
    const res = await api.get<UpdateCheckResult>('/system/check-update/')
    return res.data
  }
}

13  frontend/types/version.types.ts  (new file)
@@ -0,0 +1,13 @@
export interface VersionInfo {
  version: string
  githubRepo: string
}

export interface UpdateCheckResult {
  currentVersion: string
  latestVersion: string
  hasUpdate: boolean
  releaseUrl: string
  releaseNotes: string | null
  publishedAt: string | null
}
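`hasUpdate` presumably comes from comparing `currentVersion` against the latest GitHub release tag. A sketch of one way to compare dotted version strings like `v1.5.10-dev` (the backend's actual comparison is not shown in this diff):

```python
import re


def version_key(version: str) -> tuple[int, ...]:
    """Turn 'v1.5.10-dev' into (1, 5, 10) for numeric comparison."""
    numbers = re.findall(r'\d+', version.split('-')[0])
    return tuple(int(n) for n in numbers)


def has_update(current: str, latest: str) -> bool:
    """True when the latest release tag is numerically newer than the current one."""
    return version_key(latest) > version_key(current)
```

Note that a plain string comparison would get `v1.5.10` vs `v1.5.5` wrong, which is why the segments are compared as integers.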
108  update.sh
@@ -21,8 +21,8 @@ cd "$(dirname "$0")"

# Privilege check
if [ "$EUID" -ne 0 ]; then
echo -e "\033[0;31m[Error] Please run this script with sudo\033[0m"
echo -e " Usage: \033[1msudo ./update.sh\033[0m"
printf "\033[0;31m✗ Please run this script with sudo\033[0m\n"
printf " Usage: \033[1msudo ./update.sh\033[0m\n"
exit 1
fi

@@ -49,9 +49,17 @@ YELLOW='\033[1;33m'
RED='\033[0;31m'
BLUE='\033[0;34m'
CYAN='\033[0;36m'
DIM='\033[2m'
BOLD='\033[1m'
NC='\033[0m'

# Log helpers
log_step() { printf "${CYAN}▶${NC} %s\n" "$1"; }
log_ok() { printf " ${GREEN}✓${NC} %s\n" "$1"; }
log_info() { printf " ${DIM}→${NC} %s\n" "$1"; }
log_warn() { printf " ${YELLOW}!${NC} %s\n" "$1"; }
log_error() { printf "${RED}✗${NC} %s\n" "$1"; }

# Merge new .env keys (preserving the user's existing values)
merge_env_config() {
local example_file="docker/.env.example"
@@ -70,58 +78,68 @@ merge_env_config() {

if ! grep -q "^${key}=" "$env_file"; then
printf '%s\n' "$line" >> "$env_file"
echo -e " ${GREEN}+${NC} Added: $key"
log_info "New config key: $key"
((new_keys++))
fi
done < "$example_file"

if [ $new_keys -gt 0 ]; then
echo -e " ${GREEN}OK${NC} Added $new_keys new config entries"
log_ok "Added $new_keys new config entries"
else
echo -e " ${GREEN}OK${NC} Configuration is up to date"
log_ok "Configuration is up to date"
fi
}
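`merge_env_config` appends keys from `.env.example` that are missing from `.env`, while never touching existing values. The same idea in Python, for clarity (string-based; the file paths from the shell function are left out):

```python
def merge_env(env_text: str, example_text: str) -> tuple[str, list[str]]:
    """Append example keys missing from env_text; keep existing values untouched."""
    existing = {line.split('=', 1)[0]
                for line in env_text.splitlines()
                if '=' in line and not line.lstrip().startswith('#')}
    lines = env_text.splitlines()
    added = []
    for line in example_text.splitlines():
        if '=' in line and not line.lstrip().startswith('#'):
            key = line.split('=', 1)[0]
            if key not in existing:
                lines.append(line)       # new key: copy the example default
                added.append(key)
    return '\n'.join(lines), added
```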
 
-echo ""
-echo -e "${BOLD}${BLUE}╔════════════════════════════════════════╗${NC}"
+# 显示标题
+printf "\n"
+printf "${BOLD}${BLUE}┌────────────────────────────────────────┐${NC}\n"
 if [ "$DEV_MODE" = true ]; then
-    echo -e "${BOLD}${BLUE}║ 开发环境更新(本地构建) ║${NC}"
+    printf "${BOLD}${BLUE}│${NC} ${BOLD}XingRin 系统更新${NC} ${BOLD}${BLUE}│${NC}\n"
+    printf "${BOLD}${BLUE}│${NC} ${DIM}开发模式 · 本地构建${NC} ${BOLD}${BLUE}│${NC}\n"
 else
-    echo -e "${BOLD}${BLUE}║ 生产环境更新(Docker Hub) ║${NC}"
+    printf "${BOLD}${BLUE}│${NC} ${BOLD}XingRin 系统更新${NC} ${BOLD}${BLUE}│${NC}\n"
+    printf "${BOLD}${BLUE}│${NC} ${DIM}生产模式 · Docker Hub${NC} ${BOLD}${BLUE}│${NC}\n"
 fi
-echo -e "${BOLD}${BLUE}╚════════════════════════════════════════╝${NC}"
-echo ""
+printf "${BOLD}${BLUE}└────────────────────────────────────────┘${NC}\n"
+printf "\n"
 
-# 测试性功能警告
-echo -e "${BOLD}${YELLOW}[!] 警告:此功能为测试性功能,可能会导致升级失败${NC}"
-echo -e "${YELLOW} 建议运行 ./uninstall.sh 后重新执行 ./install.sh 进行全新安装${NC}"
-echo ""
-echo -n -e "${YELLOW}是否继续更新?(y/N) ${NC}"
+# 警告提示
+printf "${YELLOW}┌─ 注意事项 ─────────────────────────────┐${NC}\n"
+printf "${YELLOW}│${NC} • 此功能为测试性功能,可能导致升级失败 ${YELLOW}│${NC}\n"
+printf "${YELLOW}│${NC} • 升级会覆盖所有默认引擎配置 ${YELLOW}│${NC}\n"
+printf "${YELLOW}│${NC} • 自定义配置请先备份或创建新引擎 ${YELLOW}│${NC}\n"
+printf "${YELLOW}│${NC} • 推荐:卸载后重新安装以获得最佳体验 ${YELLOW}│${NC}\n"
+printf "${YELLOW}└────────────────────────────────────────┘${NC}\n"
+printf "\n"
+
+printf "${YELLOW}是否继续更新?${NC} [y/N] "
 read -r ans_continue
 ans_continue=${ans_continue:-N}
 
 if [[ ! $ans_continue =~ ^[Yy]$ ]]; then
-    echo -e "${CYAN}已取消更新。${NC}"
+    printf "\n${DIM}已取消更新${NC}\n"
     exit 0
 fi
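The confirm prompt is a default-deny pattern: `${ans_continue:-N}` turns a bare Enter into `N`, and only an explicit `y`/`Y` proceeds. A non-interactive sketch, where a POSIX `case` stands in for the script's bash regex test `[[ $ans =~ ^[Yy]$ ]]` and input is fed via a pipe instead of a terminal:

```shell
# Default-deny confirmation: empty reply falls back to N, and only
# a single y/Y (not "yes") is accepted, matching ^[Yy]$.
confirm() {
    read -r ans
    ans=${ans:-N}
    case $ans in
        [Yy]) return 0 ;;
        *)    return 1 ;;
    esac
}

printf 'y\n' | confirm && echo "accepted"
printf '\n'  | confirm || echo "empty -> declined"
printf 'n\n' | confirm || echo "n -> declined"
```

Defaulting to decline is the right choice for a destructive update path: an accidental Enter aborts rather than proceeds.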
-echo ""
+printf "\n"
 
 # Step 1: 停止服务
-echo -e "${CYAN}[1/5]${NC} 停止服务..."
-./stop.sh 2>&1 | sed 's/^/ /'
+log_step "停止服务..."
+./stop.sh 2>&1 | sed 's/^/ /'
+log_ok "服务已停止"
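The `2>&1 | sed 's/^/ /'` idiom indents a child command's combined output so it nests visually under the step header. A minimal demonstration with a stand-in for `./stop.sh`:

```shell
# Prefix every output line of a subcommand with two spaces,
# folding stderr into stdout first so both streams line up.
demo() { echo "stdout line"; echo "stderr line" >&2; }
demo 2>&1 | sed 's/^/  /'
```

Without the `2>&1`, stderr lines would bypass `sed` and print flush-left, breaking the indentation.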
 
 # Step 2: 拉取代码
-echo ""
-echo -e "${CYAN}[2/5]${NC} 拉取代码..."
-git pull --rebase 2>&1 | sed 's/^/ /'
-if [ $? -ne 0 ]; then
-    echo -e "${RED}[错误]${NC} git pull 失败,请手动解决冲突后重试"
+printf "\n"
+log_step "拉取最新代码..."
+if git pull --rebase 2>&1 | sed 's/^/ /'; then
+    log_ok "代码已更新"
+else
+    log_error "git pull 失败,请手动解决冲突后重试"
     exit 1
 fi
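One caveat with the new `if git pull … | sed …` form: `if` tests the exit status of the last command in the pipeline (`sed`), so a failed `git pull` would still take the success branch unless `pipefail` is enabled. A minimal demonstration, with `false` standing in for a failing `git pull`:

```shell
# Without pipefail, the pipeline's status is sed's (success),
# so the upstream failure is masked:
if false | sed 's/^/  /'; then
    echo "failure masked"
fi

# With pipefail, the first failing command's status wins:
set -o pipefail
if false | sed 's/^/  /'; then
    echo "unexpected"
else
    echo "failure caught"
fi
set +o pipefail
```

Adding `set -o pipefail` near the top of `update.sh` (or capturing `${PIPESTATUS[0]}` in bash) would make this error path actually reachable.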
 
 # Step 3: 检查配置更新 + 版本同步
-echo ""
-echo -e "${CYAN}[3/5]${NC} 检查配置更新..."
+printf "\n"
+log_step "同步配置..."
 merge_env_config
 
 # 版本同步:从 VERSION 文件更新 IMAGE_TAG
@@ -130,21 +148,20 @@ if [ -f "VERSION" ]; then
     if [ -n "$NEW_VERSION" ]; then
         if grep -q "^IMAGE_TAG=" "docker/.env"; then
             sed_inplace "s/^IMAGE_TAG=.*/IMAGE_TAG=$NEW_VERSION/" "docker/.env"
-            echo -e " ${GREEN}+${NC} 版本同步: IMAGE_TAG=$NEW_VERSION"
         else
            printf '%s\n' "IMAGE_TAG=$NEW_VERSION" >> "docker/.env"
-            echo -e " ${GREEN}+${NC} 新增版本: IMAGE_TAG=$NEW_VERSION"
         fi
+        log_ok "版本同步: $NEW_VERSION"
     fi
 fi
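The version sync is a replace-or-append update of a single key: rewrite `IMAGE_TAG` in place if it exists, append it otherwise. The same pattern sketched with plain `sed -i.bak` (the script's `sed_inplace` is presumably a GNU/BSD portability wrapper, not reproduced here; paths and versions are illustrative):

```shell
# Replace IMAGE_TAG if present, append it otherwise.
# A temp file stands in for docker/.env.
env_file=$(mktemp)
printf 'IMAGE_TAG=v1.5.5-dev\nOTHER=x\n' > "$env_file"
NEW_VERSION="v1.5.10-dev"

if grep -q "^IMAGE_TAG=" "$env_file"; then
    # -i.bak works on both GNU and BSD sed (leaves a .bak copy behind)
    sed -i.bak "s/^IMAGE_TAG=.*/IMAGE_TAG=$NEW_VERSION/" "$env_file"
else
    printf '%s\n' "IMAGE_TAG=$NEW_VERSION" >> "$env_file"
fi

grep '^IMAGE_TAG=' "$env_file"
```

Either branch leaves exactly one `IMAGE_TAG=` line, which is what keeps repeated updates idempotent.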
 
 # Step 4: 构建/拉取镜像
-echo ""
-echo -e "${CYAN}[4/5]${NC} 更新镜像..."
+printf "\n"
+log_step "更新镜像..."
 
 if [ "$DEV_MODE" = true ]; then
     # 开发模式:本地构建所有镜像(包括 Worker)
-    echo -e " 构建 Worker 镜像..."
+    log_info "构建 Worker 镜像..."
 
     # 读取 IMAGE_TAG
     IMAGE_TAG=$(grep "^IMAGE_TAG=" "docker/.env" | cut -d'=' -f2)
@@ -153,24 +170,23 @@ if [ "$DEV_MODE" = true ]; then
     fi
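Reading `IMAGE_TAG` back out of the env file is a one-line `grep | cut`. A sketch including a fallback default of the kind the elided lines of this hunk presumably provide (the `"latest"` value is an assumption, not the script's real default):

```shell
# Parse IMAGE_TAG from an env file. cut -d'=' -f2 takes the second
# '='-delimited field; use -f2- instead if values may contain '='.
env_file=$(mktemp)
printf 'IMAGE_TAG=v1.5.10-dev\n' > "$env_file"

IMAGE_TAG=$(grep "^IMAGE_TAG=" "$env_file" | cut -d'=' -f2)
[ -n "$IMAGE_TAG" ] || IMAGE_TAG="latest"   # assumed fallback, for illustration
echo "$IMAGE_TAG"
```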
 
     # 构建 Worker 镜像(Worker 是临时容器,不在 compose 中,需要单独构建)
-    docker build -t docker-worker -f docker/worker/Dockerfile . 2>&1 | sed 's/^/ /'
-    docker tag docker-worker docker-worker:${IMAGE_TAG} 2>&1 | sed 's/^/ /'
-    echo -e " ${GREEN}OK${NC} Worker 镜像已构建: docker-worker:${IMAGE_TAG}"
+    docker build -t docker-worker -f docker/worker/Dockerfile . 2>&1 | sed 's/^/ /'
+    docker tag docker-worker docker-worker:${IMAGE_TAG} 2>&1 | sed 's/^/ /'
+    log_ok "Worker 镜像: docker-worker:${IMAGE_TAG}"
 
     # 其他服务镜像由 start.sh --dev 构建
-    echo -e " 其他服务镜像将在启动时构建..."
+    log_info "其他服务镜像将在启动时构建"
 else
     # 生产模式:镜像由 start.sh 拉取
-    echo -e " 镜像将在启动时从 Docker Hub 拉取..."
+    log_info "镜像将在启动时从 Docker Hub 拉取"
 fi
 
 # Step 5: 启动服务
-echo ""
-echo -e "${CYAN}[5/5]${NC} 启动服务..."
+printf "\n"
+log_step "启动服务..."
 ./start.sh "$@"
 
-echo ""
-echo -e "${BOLD}${GREEN}════════════════════════════════════════${NC}"
-echo -e "${BOLD}${GREEN} 更新完成!${NC}"
-echo -e "${BOLD}${GREEN}════════════════════════════════════════${NC}"
-echo ""
+# 完成提示
+printf "\n"
+printf "${GREEN}┌────────────────────────────────────────┐${NC}\n"
+printf "${GREEN}│${NC} ${BOLD}${GREEN}✓${NC} ${BOLD}更新完成${NC} ${GREEN}│${NC}\n"
+printf "${GREEN}└────────────────────────────────────────┘${NC}\n"
+printf "\n"