Mirror of https://github.com/yyhuni/xingrin.git (synced 2026-02-02 04:33:10 +08:00)

Compare commits: v1.4.0-dev ... v1.5.7 (43 commits)
| SHA1 |
|---|
| 8a6f1b6f24 |
| 255d505aba |
| d06a9bab1f |
| 6d5c776bf7 |
| bf058dd67b |
| 0532d7c8b8 |
| 2ee9b5ffa2 |
| 648a1888d4 |
| 2508268a45 |
| c60383940c |
| 47298c294a |
| eba394e14e |
| 592a1958c4 |
| 38e2856c08 |
| f5ad8e68e9 |
| d5f91a236c |
| 24ae8b5aeb |
| 86f43f94a0 |
| 53ba03d1e5 |
| 89c44ebd05 |
| e0e3419edb |
| 52ee4684a7 |
| ce8cebf11d |
| ec006d8f54 |
| 48976a570f |
| 5da7229873 |
| 8bb737a9fa |
| 2d018d33f3 |
| 0c07cc8497 |
| 225b039985 |
| d1624627bc |
| 7bb15e4ae4 |
| 8e8cc29669 |
| d6d5338acb |
| c521bdb511 |
| abf2d95f6f |
| ab58cf0d85 |
| fb0111adf2 |
| 161ee9a2b1 |
| 0cf75585d5 |
| 1d8d5f51d9 |
| 3f8de07c8c |
| 6ff86e14ec |
.gitignore (vendored, 1 change)

@@ -64,6 +64,7 @@ backend/.env.local
.coverage
htmlcov/
*.cover
.hypothesis/

# ============================
# Backend (Go) related
README.md (69 changes)

@@ -27,7 +27,7 @@

## 🌐 Online Demo

👉 **[https://xingrin.vercel.app/](https://xingrin.vercel.app/)**
**[https://xingrin.vercel.app/](https://xingrin.vercel.app/)**

> ⚠️ UI showcase only; no backend database is connected

@@ -58,23 +58,33 @@

## ✨ Features

### 🎯 Target & Asset Management
- **Organization management** - Multi-level target organizations with flexible grouping
- **Target management** - Supports domain and IP target types
- **Asset discovery** - Automatic discovery of subdomains, websites, endpoints, and directories
- **Asset snapshots** - Compare scan-result snapshots to track asset changes
### Scanning Capabilities

### 🔍 Vulnerability Scanning
- **Multi-engine support** - Integrates mainstream scanners such as Nuclei
- **Custom pipelines** - Scan pipelines configured in YAML for flexible orchestration
- **Scheduled scans** - Cron-expression configuration for automated periodic scans
| Feature | Status | Tools | Notes |
|------|------|------|------|
| Subdomain scanning | ✅ | Subfinder, Amass, PureDNS | Passive collection + active brute forcing, aggregating 50+ data sources |
| Port scanning | ✅ | Naabu | Custom port ranges |
| Site discovery | ✅ | HTTPX | HTTP probing; automatically captures titles, status codes, and tech stacks |
| Fingerprinting | ✅ | XingFinger | 27,000+ fingerprint rules from a multi-source library |
| URL collection | ✅ | Waymore, Katana | Historical data + active crawling |
| Directory scanning | ✅ | FFUF | High-speed brute forcing with smart wordlists |
| Vulnerability scanning | ✅ | Nuclei, Dalfox | 9,000+ POC templates, XSS detection |
| Site screenshots | ✅ | Playwright | Highly compressed WebP storage |

### 🔖 Fingerprinting
- **Multi-source fingerprint library** - Ships with 27,000+ rules from EHole, Goby, Wappalyzer, Fingers, FingerPrintHub, ARL, and more
- **Automatic identification** - Runs automatically in the scan pipeline to identify web application tech stacks
- **Fingerprint management** - Query, import, and export fingerprint rules
### Platform Capabilities

#### Scan Pipeline Architecture
| Feature | Status | Notes |
|------|------|------|
| Target management | ✅ | Multi-level organizations; domain/IP targets |
| Asset snapshots | ✅ | Compare scan results to track asset changes |
| Blacklist filtering | ✅ | Global + per-target; wildcard/CIDR support |
| Scheduled tasks | ✅ | Cron expressions for automated periodic scans |
| Distributed scanning | ✅ | Multiple worker nodes with load-aware scheduling |
| Global search | ✅ | Expression syntax with multi-field combined queries |
| Notifications | ✅ | WeCom, Telegram, Discord |
| API key management | ✅ | Visual configuration of API keys for each data source |

### Scan Pipeline Architecture

The full scan pipeline covers subdomain discovery, port scanning, site discovery, fingerprinting, URL collection, directory scanning, and vulnerability scanning

@@ -95,6 +105,7 @@ flowchart LR
direction TB
URL["URL Collection<br/>waymore, katana"]
DIR["Directory Scan<br/>ffuf"]
SCREENSHOT["Site Screenshot<br/>playwright"]
end

subgraph STAGE3["Stage 3: Vulnerability Detection"]
@@ -119,6 +130,7 @@ flowchart LR
style FINGER fill:#5dade2,stroke:#3498db,stroke-width:1px,color:#fff
style URL fill:#bb8fce,stroke:#9b59b6,stroke-width:1px,color:#fff
style DIR fill:#bb8fce,stroke:#9b59b6,stroke-width:1px,color:#fff
style SCREENSHOT fill:#bb8fce,stroke:#9b59b6,stroke-width:1px,color:#fff
style VULN fill:#f0b27a,stroke:#e67e22,stroke-width:1px,color:#fff
```

@@ -225,7 +237,6 @@ sudo ./install.sh --mirror
> **💡 About the --mirror flag**
> - Automatically configures Docker registry mirrors (mainland-China sources)
> - Speeds up Git repository clones (Nuclei templates, etc.)
> - Greatly shortens installation time and avoids network timeouts

### Accessing the Services

@@ -258,6 +269,32 @@ sudo ./uninstall.sh

<img src="docs/wechat-qrcode.png" alt="WeChat Official Account" width="200">

### 🎁 Follow the Official Account for a Free Fingerprint Library

| Fingerprint library | Count |
|--------|------|
| ehole.json | 21,977 |
| ARL.yaml | 9,264 |
| goby.json | 7,086 |
| FingerprintHub.json | 3,147 |

> 💡 Follow the official account and reply "指纹" (fingerprint) to receive it

## ☕ Sponsorship

If this project helps you, consider buying me a Mixue drink; your stars and sponsorships keep the free updates coming

<p>
<img src="docs/wx_pay.jpg" alt="WeChat Pay" width="200">
<img src="docs/zfb_pay.jpg" alt="Alipay" width="200">
</p>

### 🙏 Thanks to These Sponsors

| Nickname | Amount |
|------|------|
| X(闭关中) | ¥88 |


## ⚠️ Disclaimer
backend/.gitignore (vendored, 1 change)

@@ -7,6 +7,7 @@ __pycache__/
*.egg-info/
dist/
build/
.hypothesis/  # Hypothesis property-based testing cache

# Virtual environments
venv/
backend/apps/asset/migrations/0003_add_screenshot_models.py (new file, 53 lines)

@@ -0,0 +1,53 @@
# Generated by Django 5.2.7 on 2026-01-07 02:21

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):

    dependencies = [
        ('asset', '0002_create_search_views'),
        ('scan', '0001_initial'),
        ('targets', '0001_initial'),
    ]

    operations = [
        migrations.CreateModel(
            name='Screenshot',
            fields=[
                ('id', models.AutoField(primary_key=True, serialize=False)),
                ('url', models.TextField(help_text='URL the screenshot was taken of')),
                ('image', models.BinaryField(help_text='Screenshot as compressed WebP binary data')),
                ('created_at', models.DateTimeField(auto_now_add=True, help_text='Creation time')),
                ('updated_at', models.DateTimeField(auto_now=True, help_text='Update time')),
                ('target', models.ForeignKey(help_text='Owning target', on_delete=django.db.models.deletion.CASCADE, related_name='screenshots', to='targets.target')),
            ],
            options={
                'verbose_name': 'Screenshot',
                'verbose_name_plural': 'Screenshots',
                'db_table': 'screenshot',
                'ordering': ['-created_at'],
                'indexes': [models.Index(fields=['target'], name='screenshot_target__2f01f6_idx'), models.Index(fields=['-created_at'], name='screenshot_created_c0ad4b_idx')],
                'constraints': [models.UniqueConstraint(fields=('target', 'url'), name='unique_screenshot_per_target')],
            },
        ),
        migrations.CreateModel(
            name='ScreenshotSnapshot',
            fields=[
                ('id', models.AutoField(primary_key=True, serialize=False)),
                ('url', models.TextField(help_text='URL the screenshot was taken of')),
                ('image', models.BinaryField(help_text='Screenshot as compressed WebP binary data')),
                ('created_at', models.DateTimeField(auto_now_add=True, help_text='Creation time')),
                ('scan', models.ForeignKey(help_text='Owning scan task', on_delete=django.db.models.deletion.CASCADE, related_name='screenshot_snapshots', to='scan.scan')),
            ],
            options={
                'verbose_name': 'Screenshot snapshot',
                'verbose_name_plural': 'Screenshot snapshots',
                'db_table': 'screenshot_snapshot',
                'ordering': ['-created_at'],
                'indexes': [models.Index(fields=['scan'], name='screenshot__scan_id_fb8c4d_idx'), models.Index(fields=['-created_at'], name='screenshot__created_804117_idx')],
                'constraints': [models.UniqueConstraint(fields=('scan', 'url'), name='unique_screenshot_per_scan_snapshot')],
            },
        ),
    ]
@@ -0,0 +1,23 @@
# Generated by Django 5.2.7 on 2026-01-07 13:29

from django.db import migrations, models


class Migration(migrations.Migration):

    dependencies = [
        ('asset', '0003_add_screenshot_models'),
    ]

    operations = [
        migrations.AddField(
            model_name='screenshot',
            name='status_code',
            field=models.SmallIntegerField(blank=True, help_text='HTTP response status code', null=True),
        ),
        migrations.AddField(
            model_name='screenshotsnapshot',
            name='status_code',
            field=models.SmallIntegerField(blank=True, help_text='HTTP response status code', null=True),
        ),
    ]
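These migrations are applied the standard Django way; for instance, programmatically (an ordinary invocation, not something this diff adds):

```python
# Equivalent to `python manage.py migrate asset` from the shell
from django.core.management import call_command

call_command('migrate', 'asset')
```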
@@ -20,6 +20,12 @@ from .snapshot_models import (
    VulnerabilitySnapshot,
)

# Screenshot models
from .screenshot_models import (
    Screenshot,
    ScreenshotSnapshot,
)

# Statistics models
from .statistics_models import AssetStatistics, StatisticsHistory

@@ -39,6 +45,9 @@ __all__ = [
    'HostPortMappingSnapshot',
    'EndpointSnapshot',
    'VulnerabilitySnapshot',
    # Screenshot models
    'Screenshot',
    'ScreenshotSnapshot',
    # Statistics models
    'AssetStatistics',
    'StatisticsHistory',
backend/apps/asset/models/screenshot_models.py (new file, 80 lines)

@@ -0,0 +1,80 @@
from django.db import models


class ScreenshotSnapshot(models.Model):
    """
    Screenshot snapshot

    Records: website screenshots captured during a given scan
    """

    id = models.AutoField(primary_key=True)
    scan = models.ForeignKey(
        'scan.Scan',
        on_delete=models.CASCADE,
        related_name='screenshot_snapshots',
        help_text='Owning scan task'
    )
    url = models.TextField(help_text='URL the screenshot was taken of')
    status_code = models.SmallIntegerField(null=True, blank=True, help_text='HTTP response status code')
    image = models.BinaryField(help_text='Screenshot as compressed WebP binary data')
    created_at = models.DateTimeField(auto_now_add=True, help_text='Creation time')

    class Meta:
        db_table = 'screenshot_snapshot'
        verbose_name = 'Screenshot snapshot'
        verbose_name_plural = 'Screenshot snapshots'
        ordering = ['-created_at']
        indexes = [
            models.Index(fields=['scan']),
            models.Index(fields=['-created_at']),
        ]
        constraints = [
            models.UniqueConstraint(
                fields=['scan', 'url'],
                name='unique_screenshot_per_scan_snapshot'
            ),
        ]

    def __str__(self):
        return f'{self.url} (Scan #{self.scan_id})'


class Screenshot(models.Model):
    """
    Screenshot asset

    Stores: the latest screenshot for a target (synced from snapshots)
    """

    id = models.AutoField(primary_key=True)
    target = models.ForeignKey(
        'targets.Target',
        on_delete=models.CASCADE,
        related_name='screenshots',
        help_text='Owning target'
    )
    url = models.TextField(help_text='URL the screenshot was taken of')
    status_code = models.SmallIntegerField(null=True, blank=True, help_text='HTTP response status code')
    image = models.BinaryField(help_text='Screenshot as compressed WebP binary data')
    created_at = models.DateTimeField(auto_now_add=True, help_text='Creation time')
    updated_at = models.DateTimeField(auto_now=True, help_text='Update time')

    class Meta:
        db_table = 'screenshot'
        verbose_name = 'Screenshot'
        verbose_name_plural = 'Screenshots'
        ordering = ['-created_at']
        indexes = [
            models.Index(fields=['target']),
            models.Index(fields=['-created_at']),
        ]
        constraints = [
            models.UniqueConstraint(
                fields=['target', 'url'],
                name='unique_screenshot_per_target'
            ),
        ]

    def __str__(self):
        return f'{self.url} (Target #{self.target_id})'
@@ -7,6 +7,7 @@ from .models.snapshot_models import (
    EndpointSnapshot,
    VulnerabilitySnapshot,
)
from .models.screenshot_models import Screenshot, ScreenshotSnapshot


# Note: the IPAddress and Port models have been refactored into HostPortMapping

@@ -290,3 +291,23 @@ class EndpointSnapshotSerializer(serializers.ModelSerializer):
        'created_at',
    ]
    read_only_fields = fields


# ==================== Screenshot serializers ====================

class ScreenshotListSerializer(serializers.ModelSerializer):
    """List serializer for screenshot assets (excludes the image field)"""

    class Meta:
        model = Screenshot
        fields = ['id', 'url', 'status_code', 'created_at', 'updated_at']
        read_only_fields = fields


class ScreenshotSnapshotListSerializer(serializers.ModelSerializer):
    """List serializer for screenshot snapshots (excludes the image field)"""

    class Meta:
        model = ScreenshotSnapshot
        fields = ['id', 'url', 'status_code', 'created_at']
        read_only_fields = fields
backend/apps/asset/services/playwright_screenshot_service.py (new file, 186 lines)

@@ -0,0 +1,186 @@
"""
Playwright screenshot service

Captures website screenshots in async batches with Playwright
"""
import asyncio
import logging
from typing import Optional, AsyncGenerator

logger = logging.getLogger(__name__)


class PlaywrightScreenshotService:
    """Playwright screenshot service - concurrent screenshots over multiple async Pages"""

    # Built-in defaults (not exposed to users)
    DEFAULT_VIEWPORT_WIDTH = 1920
    DEFAULT_VIEWPORT_HEIGHT = 1080
    DEFAULT_TIMEOUT = 30000  # milliseconds
    DEFAULT_JPEG_QUALITY = 85

    def __init__(
        self,
        viewport_width: int = DEFAULT_VIEWPORT_WIDTH,
        viewport_height: int = DEFAULT_VIEWPORT_HEIGHT,
        timeout: int = DEFAULT_TIMEOUT,
        concurrency: int = 5
    ):
        """
        Initialize the Playwright screenshot service

        Args:
            viewport_width: viewport width in pixels
            viewport_height: viewport height in pixels
            timeout: page-load timeout in milliseconds
            concurrency: number of concurrent screenshots
        """
        self.viewport_width = viewport_width
        self.viewport_height = viewport_height
        self.timeout = timeout
        self.concurrency = concurrency

    async def capture_screenshot(self, url: str, page) -> tuple[Optional[bytes], Optional[int]]:
        """
        Capture a screenshot of a single URL

        Args:
            url: target URL
            page: Playwright Page object

        Returns:
            (screenshot_bytes, status_code) tuple
            - screenshot_bytes: JPEG screenshot bytes, or None on failure
            - status_code: HTTP response status code, or None on failure
        """
        status_code = None
        try:
            # Try to load the page; continue to the screenshot even on an error status
            try:
                response = await page.goto(url, timeout=self.timeout, wait_until='networkidle')
                if response:
                    status_code = response.status
            except Exception as goto_error:
                # The page load failed (4xx/5xx or another error), but it may be
                # partially rendered, so still try to capture the error page
                logger.debug("Page load failed but attempting screenshot: %s, error: %s", url, str(goto_error)[:50])

            # Attempt the screenshot (even if goto failed)
            screenshot_bytes = await page.screenshot(
                type='jpeg',
                quality=self.DEFAULT_JPEG_QUALITY,
                full_page=False
            )
            return (screenshot_bytes, status_code)
        except asyncio.TimeoutError:
            logger.warning("Screenshot timed out: %s", url)
            return (None, None)
        except Exception as e:
            logger.warning("Screenshot failed: %s, error: %s", url, str(e)[:100])
            return (None, None)

    async def _capture_with_semaphore(
        self,
        url: str,
        context,
        semaphore: asyncio.Semaphore
    ) -> tuple[str, Optional[bytes], Optional[int]]:
        """
        Screenshot task with semaphore-bounded concurrency

        Args:
            url: target URL
            context: Playwright BrowserContext
            semaphore: concurrency-limiting semaphore

        Returns:
            (url, screenshot_bytes, status_code) tuple
        """
        async with semaphore:
            page = await context.new_page()
            try:
                screenshot_bytes, status_code = await self.capture_screenshot(url, page)
                return (url, screenshot_bytes, status_code)
            finally:
                await page.close()

    async def capture_batch(
        self,
        urls: list[str]
    ) -> AsyncGenerator[tuple[str, Optional[bytes], Optional[int]], None]:
        """
        Capture screenshots in batch (async generator)

        Uses a single BrowserContext with multiple concurrent Pages,
        bounded by a Semaphore

        Args:
            urls: list of URLs

        Yields:
            (url, screenshot_bytes, status_code) tuples
        """
        if not urls:
            return

        from playwright.async_api import async_playwright

        async with async_playwright() as p:
            # Launch the browser (headless mode)
            browser = await p.chromium.launch(
                headless=True,
                args=[
                    '--no-sandbox',
                    '--disable-setuid-sandbox',
                    '--disable-dev-shm-usage',
                    '--disable-gpu'
                ]
            )

            try:
                # Create a single context
                context = await browser.new_context(
                    viewport={
                        'width': self.viewport_width,
                        'height': self.viewport_height
                    },
                    ignore_https_errors=True,
                    user_agent='Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36'
                )

                # Bound concurrency with a Semaphore
                semaphore = asyncio.Semaphore(self.concurrency)

                # Create all tasks
                tasks = [
                    self._capture_with_semaphore(url, context, semaphore)
                    for url in urls
                ]

                # Stream results back as they complete
                for coro in asyncio.as_completed(tasks):
                    result = await coro
                    yield result

                await context.close()

            finally:
                await browser.close()

    async def capture_batch_collect(
        self,
        urls: list[str]
    ) -> list[tuple[str, Optional[bytes], Optional[int]]]:
        """
        Capture screenshots in batch (collect all results)

        Args:
            urls: list of URLs

        Returns:
            [(url, screenshot_bytes, status_code), ...] list
        """
        results = []
        async for result in self.capture_batch(urls):
            results.append(result)
        return results
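A minimal usage sketch for the service above; the URLs are placeholders, and it assumes Playwright plus its Chromium build are installed (e.g. via `playwright install chromium`):

```python
import asyncio

from apps.asset.services.playwright_screenshot_service import PlaywrightScreenshotService


async def main():
    service = PlaywrightScreenshotService(concurrency=3)
    # Hypothetical URLs, for illustration only
    results = await service.capture_batch_collect([
        'https://example.com',
        'https://example.org',
    ])
    for url, image_bytes, status_code in results:
        size = len(image_bytes) if image_bytes else 0
        print(f'{url}: status={status_code}, jpeg_bytes={size}')


asyncio.run(main())
```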
backend/apps/asset/services/screenshot_service.py (new file, 185 lines)

@@ -0,0 +1,185 @@
"""
Screenshot service

Handles screenshot compression, persistence, and syncing
"""
import io
import logging
import os
from typing import Optional

from PIL import Image

logger = logging.getLogger(__name__)


class ScreenshotService:
    """Screenshot service - compression, persistence, and syncing"""

    def __init__(self, max_width: int = 800, target_kb: int = 100):
        """
        Initialize the screenshot service

        Args:
            max_width: maximum width in pixels
            target_kb: target file size in KB
        """
        self.max_width = max_width
        self.target_kb = target_kb

    def compress_screenshot(self, image_path: str) -> Optional[bytes]:
        """
        Compress a screenshot to WebP

        Args:
            image_path: path to a PNG screenshot file

        Returns:
            compressed WebP binary data, or None on failure
        """
        if not os.path.exists(image_path):
            logger.warning(f"Screenshot file does not exist: {image_path}")
            return None

        try:
            with Image.open(image_path) as img:
                return self._compress_image(img)
        except Exception as e:
            logger.error(f"Failed to compress screenshot: {image_path}, error: {e}")
            return None

    def compress_from_bytes(self, image_bytes: bytes) -> Optional[bytes]:
        """
        Compress screenshot bytes to WebP

        Args:
            image_bytes: JPEG/PNG image bytes

        Returns:
            compressed WebP binary data, or None on failure
        """
        if not image_bytes:
            return None

        try:
            img = Image.open(io.BytesIO(image_bytes))
            return self._compress_image(img)
        except Exception as e:
            logger.error(f"Failed to compress screenshot from bytes: {e}")
            return None

    def _compress_image(self, img: Image.Image) -> Optional[bytes]:
        """
        Compress a PIL Image to WebP

        Args:
            img: PIL Image object

        Returns:
            compressed WebP binary data
        """
        try:
            if img.mode in ('RGBA', 'P'):
                img = img.convert('RGB')

            width, height = img.size
            if width > self.max_width:
                ratio = self.max_width / width
                new_size = (self.max_width, int(height * ratio))
                img = img.resize(new_size, Image.Resampling.LANCZOS)

            # Step the quality down until the output fits within target_kb
            quality = 80
            while quality >= 10:
                buffer = io.BytesIO()
                img.save(buffer, format='WEBP', quality=quality, method=6)
                if len(buffer.getvalue()) <= self.target_kb * 1024:
                    return buffer.getvalue()
                quality -= 10

            return buffer.getvalue()
        except Exception as e:
            logger.error(f"Failed to compress image: {e}")
            return None

    def save_screenshot_snapshot(
        self,
        scan_id: int,
        url: str,
        image_data: bytes,
        status_code: int | None = None
    ) -> bool:
        """
        Save a screenshot snapshot to the ScreenshotSnapshot table

        Args:
            scan_id: scan ID
            url: URL the screenshot was taken of
            image_data: compressed image binary data
            status_code: HTTP response status code

        Returns:
            whether the save succeeded
        """
        from apps.asset.models import ScreenshotSnapshot

        try:
            ScreenshotSnapshot.objects.update_or_create(
                scan_id=scan_id,
                url=url,
                defaults={'image': image_data, 'status_code': status_code}
            )
            return True
        except Exception as e:
            logger.error(f"Failed to save screenshot snapshot: scan_id={scan_id}, url={url}, error: {e}")
            return False

    def sync_screenshots_to_asset(self, scan_id: int, target_id: int) -> int:
        """
        Sync a scan's screenshot snapshots into the asset table

        Args:
            scan_id: scan ID
            target_id: target ID

        Returns:
            number of screenshots synced
        """
        from apps.asset.models import Screenshot, ScreenshotSnapshot

        snapshots = ScreenshotSnapshot.objects.filter(scan_id=scan_id)
        count = 0

        for snapshot in snapshots:
            try:
                Screenshot.objects.update_or_create(
                    target_id=target_id,
                    url=snapshot.url,
                    defaults={
                        'image': snapshot.image,
                        'status_code': snapshot.status_code
                    }
                )
                count += 1
            except Exception as e:
                logger.error(f"Failed to sync screenshot to the asset table: url={snapshot.url}, error: {e}")

        logger.info(f"Screenshot sync complete: scan_id={scan_id}, target_id={target_id}, count={count}")
        return count

    def process_and_save_screenshot(self, scan_id: int, url: str, image_path: str) -> bool:
        """
        Process and save a screenshot (compress + save snapshot)

        Args:
            scan_id: scan ID
            url: URL the screenshot was taken of
            image_path: path to a PNG screenshot file

        Returns:
            whether processing succeeded
        """
        image_data = self.compress_screenshot(image_path)
        if image_data is None:
            return False

        return self.save_screenshot_snapshot(scan_id, url, image_data)
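For orientation, a sketch of how the capture and storage services could be wired together; the function below and its arguments are illustrative, not the project's actual task code:

```python
import asyncio

from apps.asset.services.playwright_screenshot_service import PlaywrightScreenshotService
from apps.asset.services.screenshot_service import ScreenshotService


def screenshot_flow(scan_id: int, target_id: int, urls: list[str]) -> None:
    capture = PlaywrightScreenshotService(concurrency=5)
    store = ScreenshotService(max_width=800, target_kb=100)

    # Capture everything first (async), then persist synchronously so the
    # Django ORM is never called from inside the event loop
    results = asyncio.run(capture.capture_batch_collect(urls))

    for url, jpeg_bytes, status_code in results:
        if jpeg_bytes is None:
            continue  # capture failed; nothing to store
        webp_bytes = store.compress_from_bytes(jpeg_bytes)
        if webp_bytes:
            store.save_screenshot_snapshot(scan_id, url, webp_bytes, status_code)

    # Promote this scan's snapshots into the latest-screenshot asset table
    store.sync_screenshots_to_asset(scan_id, target_id)
```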
@@ -12,17 +12,22 @@ from .views import (
    AssetStatisticsViewSet,
    AssetSearchView,
    AssetSearchExportView,
    EndpointViewSet,
    HostPortMappingViewSet,
    ScreenshotViewSet,
)

# Create the DRF router
router = DefaultRouter()

# Register ViewSets
# Note: the IPAddress model has been refactored into HostPortMapping; its routes were removed
router.register(r'subdomains', SubdomainViewSet, basename='subdomain')
router.register(r'websites', WebSiteViewSet, basename='website')
router.register(r'directories', DirectoryViewSet, basename='directory')
router.register(r'endpoints', EndpointViewSet, basename='endpoint')
router.register(r'ip-addresses', HostPortMappingViewSet, basename='ip-address')
router.register(r'vulnerabilities', VulnerabilityViewSet, basename='vulnerability')
router.register(r'screenshots', ScreenshotViewSet, basename='screenshot')
router.register(r'statistics', AssetStatisticsViewSet, basename='asset-statistics')

urlpatterns = [
@@ -18,6 +18,8 @@ from .asset_views import (
    EndpointSnapshotViewSet,
    HostPortMappingSnapshotViewSet,
    VulnerabilitySnapshotViewSet,
    ScreenshotViewSet,
    ScreenshotSnapshotViewSet,
)
from .search_views import AssetSearchView, AssetSearchExportView

@@ -35,6 +37,8 @@ __all__ = [
    'EndpointSnapshotViewSet',
    'HostPortMappingSnapshotViewSet',
    'VulnerabilitySnapshotViewSet',
    'ScreenshotViewSet',
    'ScreenshotSnapshotViewSet',
    'AssetSearchView',
    'AssetSearchExportView',
]
@@ -260,6 +260,35 @@ class SubdomainViewSet(viewsets.ModelViewSet):
            field_formatters=formatters
        )

    @action(detail=False, methods=['post'], url_path='bulk-delete')
    def bulk_delete(self, request, **kwargs):
        """Bulk-delete subdomains

        POST /api/assets/subdomains/bulk-delete/

        Request body: {"ids": [1, 2, 3]}
        Response: {"deletedCount": 3}
        """
        ids = request.data.get('ids', [])
        if not ids or not isinstance(ids, list):
            return error_response(
                code=ErrorCodes.VALIDATION_ERROR,
                message='ids is required and must be a list',
                status_code=status.HTTP_400_BAD_REQUEST
            )

        try:
            from ..models import Subdomain
            deleted_count, _ = Subdomain.objects.filter(id__in=ids).delete()
            return success_response(data={'deletedCount': deleted_count})
        except Exception as e:
            logger.exception("Bulk subdomain delete failed")
            return error_response(
                code=ErrorCodes.SERVER_ERROR,
                message='Failed to delete subdomains',
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
            )
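For illustration, a client-side call to the endpoint above might look like this; the host, port, and omitted authentication are assumptions, not part of the diff:

```python
import requests

# Hypothetical base URL; authentication (e.g. a session cookie) is omitted
resp = requests.post(
    'http://localhost:8000/api/assets/subdomains/bulk-delete/',
    json={'ids': [1, 2, 3]},
    timeout=10,
)
print(resp.json())  # the payload carries e.g. {"deletedCount": 3}
```

The websites, directories, endpoints, IP addresses, and screenshots endpoints below follow the same shape.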
class WebSiteViewSet(viewsets.ModelViewSet):
    """Website management ViewSet

@@ -393,6 +422,35 @@ class WebSiteViewSet(viewsets.ModelViewSet):
            field_formatters=formatters
        )

    @action(detail=False, methods=['post'], url_path='bulk-delete')
    def bulk_delete(self, request, **kwargs):
        """Bulk-delete websites

        POST /api/assets/websites/bulk-delete/

        Request body: {"ids": [1, 2, 3]}
        Response: {"deletedCount": 3}
        """
        ids = request.data.get('ids', [])
        if not ids or not isinstance(ids, list):
            return error_response(
                code=ErrorCodes.VALIDATION_ERROR,
                message='ids is required and must be a list',
                status_code=status.HTTP_400_BAD_REQUEST
            )

        try:
            from ..models import WebSite
            deleted_count, _ = WebSite.objects.filter(id__in=ids).delete()
            return success_response(data={'deletedCount': deleted_count})
        except Exception as e:
            logger.exception("Bulk website delete failed")
            return error_response(
                code=ErrorCodes.SERVER_ERROR,
                message='Failed to delete websites',
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
            )


class DirectoryViewSet(viewsets.ModelViewSet):
    """Directory management ViewSet

@@ -521,6 +579,35 @@ class DirectoryViewSet(viewsets.ModelViewSet):
            field_formatters=formatters
        )

    @action(detail=False, methods=['post'], url_path='bulk-delete')
    def bulk_delete(self, request, **kwargs):
        """Bulk-delete directories

        POST /api/assets/directories/bulk-delete/

        Request body: {"ids": [1, 2, 3]}
        Response: {"deletedCount": 3}
        """
        ids = request.data.get('ids', [])
        if not ids or not isinstance(ids, list):
            return error_response(
                code=ErrorCodes.VALIDATION_ERROR,
                message='ids is required and must be a list',
                status_code=status.HTTP_400_BAD_REQUEST
            )

        try:
            from ..models import Directory
            deleted_count, _ = Directory.objects.filter(id__in=ids).delete()
            return success_response(data={'deletedCount': deleted_count})
        except Exception as e:
            logger.exception("Bulk directory delete failed")
            return error_response(
                code=ErrorCodes.SERVER_ERROR,
                message='Failed to delete directories',
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
            )


class EndpointViewSet(viewsets.ModelViewSet):
    """Endpoint management ViewSet

@@ -655,6 +742,35 @@ class EndpointViewSet(viewsets.ModelViewSet):
            field_formatters=formatters
        )

    @action(detail=False, methods=['post'], url_path='bulk-delete')
    def bulk_delete(self, request, **kwargs):
        """Bulk-delete endpoints

        POST /api/assets/endpoints/bulk-delete/

        Request body: {"ids": [1, 2, 3]}
        Response: {"deletedCount": 3}
        """
        ids = request.data.get('ids', [])
        if not ids or not isinstance(ids, list):
            return error_response(
                code=ErrorCodes.VALIDATION_ERROR,
                message='ids is required and must be a list',
                status_code=status.HTTP_400_BAD_REQUEST
            )

        try:
            from ..models import Endpoint
            deleted_count, _ = Endpoint.objects.filter(id__in=ids).delete()
            return success_response(data={'deletedCount': deleted_count})
        except Exception as e:
            logger.exception("Bulk endpoint delete failed")
            return error_response(
                code=ErrorCodes.SERVER_ERROR,
                message='Failed to delete endpoints',
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
            )


class HostPortMappingViewSet(viewsets.ModelViewSet):
    """Host-port mapping ViewSet (aggregated IP address view)

@@ -728,6 +844,38 @@ class HostPortMappingViewSet(viewsets.ModelViewSet):
            field_formatters=formatters
        )

    @action(detail=False, methods=['post'], url_path='bulk-delete')
    def bulk_delete(self, request, **kwargs):
        """Bulk-delete IP address mappings

        POST /api/assets/ip-addresses/bulk-delete/

        Request body: {"ips": ["192.168.1.1", "10.0.0.1"]}
        Response: {"deletedCount": 3}

        Note: because IP addresses are displayed aggregated, deletion takes a list
        of IPs and removes every host:port mapping record under each IP
        """
        ips = request.data.get('ips', [])
        if not ips or not isinstance(ips, list):
            return error_response(
                code=ErrorCodes.VALIDATION_ERROR,
                message='ips is required and must be a list',
                status_code=status.HTTP_400_BAD_REQUEST
            )

        try:
            from ..models import HostPortMapping
            deleted_count, _ = HostPortMapping.objects.filter(ip__in=ips).delete()
            return success_response(data={'deletedCount': deleted_count})
        except Exception as e:
            logger.exception("Bulk IP address mapping delete failed")
            return error_response(
                code=ErrorCodes.SERVER_ERROR,
                message='Failed to delete ip addresses',
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
            )


class VulnerabilityViewSet(viewsets.ModelViewSet):
    """Vulnerability asset ViewSet (read-only)
@@ -1077,3 +1225,162 @@ class VulnerabilitySnapshotViewSet(viewsets.ModelViewSet):
        if scan_pk:
            return self.service.get_by_scan(scan_pk, filter_query=filter_query)
        return self.service.get_all(filter_query=filter_query)


# ==================== Screenshot ViewSets ====================

class ScreenshotViewSet(viewsets.ModelViewSet):
    """Screenshot asset ViewSet

    Supports two access patterns:
    1. Nested route: GET /api/targets/{target_pk}/screenshots/
    2. Standalone route: GET /api/screenshots/ (global query)

    Supports smart filter syntax (filter parameter):
    - url="example"  fuzzy URL match
    """

    from ..serializers import ScreenshotListSerializer

    serializer_class = ScreenshotListSerializer
    pagination_class = BasePagination
    filter_backends = [filters.OrderingFilter]
    ordering = ['-created_at']

    def get_queryset(self):
        """Scope the query depending on whether target_pk is present"""
        from ..models import Screenshot

        target_pk = self.kwargs.get('target_pk')
        filter_query = self.request.query_params.get('filter', None)

        queryset = Screenshot.objects.all()
        if target_pk:
            queryset = queryset.filter(target_id=target_pk)

        if filter_query:
            # Simple fuzzy URL match
            queryset = queryset.filter(url__icontains=filter_query)

        return queryset.order_by('-created_at')

    @action(detail=True, methods=['get'], url_path='image')
    def image(self, request, pk=None, **kwargs):
        """Fetch the screenshot image

        GET /api/assets/screenshots/{id}/image/

        Returns the image as WebP binary data
        """
        from django.http import HttpResponse
        from ..models import Screenshot

        try:
            screenshot = Screenshot.objects.get(pk=pk)
            if not screenshot.image:
                return error_response(
                    code=ErrorCodes.NOT_FOUND,
                    message='Screenshot image not found',
                    status_code=status.HTTP_404_NOT_FOUND
                )

            response = HttpResponse(screenshot.image, content_type='image/webp')
            response['Content-Disposition'] = f'inline; filename="screenshot_{pk}.webp"'
            return response
        except Screenshot.DoesNotExist:
            return error_response(
                code=ErrorCodes.NOT_FOUND,
                message='Screenshot not found',
                status_code=status.HTTP_404_NOT_FOUND
            )

    @action(detail=False, methods=['post'], url_path='bulk-delete')
    def bulk_delete(self, request, **kwargs):
        """Bulk-delete screenshots

        POST /api/assets/screenshots/bulk-delete/

        Request body: {"ids": [1, 2, 3]}
        Response: {"deletedCount": 3}
        """
        ids = request.data.get('ids', [])
        if not ids or not isinstance(ids, list):
            return error_response(
                code=ErrorCodes.VALIDATION_ERROR,
                message='ids is required and must be a list',
                status_code=status.HTTP_400_BAD_REQUEST
            )

        try:
            from ..models import Screenshot
            deleted_count, _ = Screenshot.objects.filter(id__in=ids).delete()
            return success_response(data={'deletedCount': deleted_count})
        except Exception as e:
            logger.exception("Bulk screenshot delete failed")
            return error_response(
                code=ErrorCodes.SERVER_ERROR,
                message='Failed to delete screenshots',
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
            )


class ScreenshotSnapshotViewSet(viewsets.ModelViewSet):
    """Screenshot snapshot ViewSet - nested route: GET /api/scans/{scan_pk}/screenshots/

    Supports smart filter syntax (filter parameter):
    - url="example"  fuzzy URL match
    """

    from ..serializers import ScreenshotSnapshotListSerializer

    serializer_class = ScreenshotSnapshotListSerializer
    pagination_class = BasePagination
    filter_backends = [filters.OrderingFilter]
    ordering = ['-created_at']

    def get_queryset(self):
        """Query by the scan_pk parameter"""
        from ..models import ScreenshotSnapshot

        scan_pk = self.kwargs.get('scan_pk')
        filter_query = self.request.query_params.get('filter', None)

        queryset = ScreenshotSnapshot.objects.all()
        if scan_pk:
            queryset = queryset.filter(scan_id=scan_pk)

        if filter_query:
            # Simple fuzzy URL match
            queryset = queryset.filter(url__icontains=filter_query)

        return queryset.order_by('-created_at')

    @action(detail=True, methods=['get'], url_path='image')
    def image(self, request, pk=None, **kwargs):
        """Fetch the screenshot snapshot image

        GET /api/scans/{scan_pk}/screenshots/{id}/image/

        Returns the image as WebP binary data
        """
        from django.http import HttpResponse
        from ..models import ScreenshotSnapshot

        try:
            screenshot = ScreenshotSnapshot.objects.get(pk=pk)
            if not screenshot.image:
                return error_response(
                    code=ErrorCodes.NOT_FOUND,
                    message='Screenshot image not found',
                    status_code=status.HTTP_404_NOT_FOUND
                )

            response = HttpResponse(screenshot.image, content_type='image/webp')
            response['Content-Disposition'] = f'inline; filename="screenshot_snapshot_{pk}.webp"'
            return response
        except ScreenshotSnapshot.DoesNotExist:
            return error_response(
                code=ErrorCodes.NOT_FOUND,
                message='Screenshot snapshot not found',
                status_code=status.HTTP_404_NOT_FOUND
            )
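A sketch of fetching one screenshot through the `image` action above and writing it to disk; the base URL and the ID are placeholder assumptions:

```python
import requests

screenshot_id = 42  # hypothetical ID
resp = requests.get(
    f'http://localhost:8000/api/assets/screenshots/{screenshot_id}/image/',
    timeout=10,
)
if resp.status_code == 200 and resp.headers.get('Content-Type') == 'image/webp':
    with open(f'screenshot_{screenshot_id}.webp', 'wb') as f:
        f.write(resp.content)
```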
@@ -14,6 +14,7 @@ from .views import (
    LoginView, LogoutView, MeView, ChangePasswordView,
    SystemLogsView, SystemLogFilesView, HealthCheckView,
    GlobalBlacklistView,
    VersionView, CheckUpdateView,
)

urlpatterns = [
@@ -29,6 +30,8 @@ urlpatterns = [
    # System administration
    path('system/logs/', SystemLogsView.as_view(), name='system-logs'),
    path('system/logs/files/', SystemLogFilesView.as_view(), name='system-log-files'),
    path('system/version/', VersionView.as_view(), name='system-version'),
    path('system/check-update/', CheckUpdateView.as_view(), name='system-check-update'),

    # Blacklist management (PUT full-replacement mode)
    path('blacklist/rules/', GlobalBlacklistView.as_view(), name='blacklist-rules'),
@@ -6,16 +6,19 @@
- Auth views: login, logout, user info, password change
- System log views: real-time log viewing
- Blacklist views: global blacklist rule management
- Version views: system version and update checks
"""

from .health_views import HealthCheckView
from .auth_views import LoginView, LogoutView, MeView, ChangePasswordView
from .system_log_views import SystemLogsView, SystemLogFilesView
from .blacklist_views import GlobalBlacklistView
from .version_views import VersionView, CheckUpdateView

__all__ = [
    'HealthCheckView',
    'LoginView', 'LogoutView', 'MeView', 'ChangePasswordView',
    'SystemLogsView', 'SystemLogFilesView',
    'GlobalBlacklistView',
    'VersionView', 'CheckUpdateView',
]
backend/apps/common/views/version_views.py (new file, 136 lines)

@@ -0,0 +1,136 @@
"""
System version views
"""

import logging
from pathlib import Path

import requests
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.views import APIView

from apps.common.error_codes import ErrorCodes
from apps.common.response_helpers import error_response, success_response

logger = logging.getLogger(__name__)

# GitHub repository info
GITHUB_REPO = "yyhuni/xingrin"
GITHUB_API_URL = f"https://api.github.com/repos/{GITHUB_REPO}/releases/latest"
GITHUB_RELEASES_URL = f"https://github.com/{GITHUB_REPO}/releases"


def get_current_version() -> str:
    """Read the current version number"""
    import os

    # Option 1: from an environment variable (recommended inside Docker containers)
    version = os.environ.get('IMAGE_TAG', '')
    if version:
        return version

    # Option 2: from a file (development environments)
    possible_paths = [
        Path('/app/VERSION'),
        Path(__file__).parent.parent.parent.parent.parent / 'VERSION',
    ]

    for path in possible_paths:
        try:
            return path.read_text(encoding='utf-8').strip()
        except (FileNotFoundError, OSError):
            continue

    return "unknown"


def compare_versions(current: str, latest: str) -> bool:
    """
    Compare version numbers and decide whether an update is available

    Returns:
        True if an update is available
    """
    def parse_version(v: str) -> tuple:
        v = v.lstrip('v')
        parts = v.split('.')
        result = []
        for part in parts:
            if '-' in part:
                num, _ = part.split('-', 1)
                result.append(int(num))
            else:
                result.append(int(part))
        return tuple(result)

    try:
        return parse_version(latest) > parse_version(current)
    except (ValueError, AttributeError):
        return False


class VersionView(APIView):
    """Return the current system version"""

    def get(self, _request: Request) -> Response:
        """Return current version info"""
        return success_response(data={
            'version': get_current_version(),
            'github_repo': GITHUB_REPO,
        })


class CheckUpdateView(APIView):
    """Check for system updates"""

    def get(self, _request: Request) -> Response:
        """
        Check whether a newer version exists

        Returns:
            - current_version: current version
            - latest_version: latest version
            - has_update: whether an update is available
            - release_url: release page URL
            - release_notes: release notes (if any)
        """
        current_version = get_current_version()

        try:
            response = requests.get(
                GITHUB_API_URL,
                headers={'Accept': 'application/vnd.github.v3+json'},
                timeout=10
            )

            if response.status_code == 404:
                return success_response(data={
                    'current_version': current_version,
                    'latest_version': current_version,
                    'has_update': False,
                    'release_url': GITHUB_RELEASES_URL,
                    'release_notes': None,
                })

            response.raise_for_status()
            release_data = response.json()

            latest_version = release_data.get('tag_name', current_version)
            has_update = compare_versions(current_version, latest_version)

            return success_response(data={
                'current_version': current_version,
                'latest_version': latest_version,
                'has_update': has_update,
                'release_url': release_data.get('html_url', GITHUB_RELEASES_URL),
                'release_notes': release_data.get('body'),
                'published_at': release_data.get('published_at'),
            })

        except requests.RequestException as e:
            logger.warning("Update check failed: %s", e)
            return error_response(
                code=ErrorCodes.SERVER_ERROR,
                message="Could not reach GitHub; please try again later",
            )
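A quick illustration of `compare_versions` (assuming the function is in scope): the pre-release suffix is dropped before the numeric tuple comparison, and unparseable versions report no update:

```python
# 'v1.4.0-dev' parses as (1, 4, 0); the '-dev' suffix is discarded
assert compare_versions('v1.4.0-dev', 'v1.5.7') is True
# Equal versions report no update
assert compare_versions('v1.5.7', 'v1.5.7') is False
# Unparseable input (e.g. the "unknown" fallback) also reports no update
assert compare_versions('unknown', 'v1.5.7') is False
```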
@@ -2,8 +2,9 @@
Initialize the default scan engines

Usage:
    python manage.py init_default_engine          # only create missing engines (do not overwrite)
    python manage.py init_default_engine --force  # force-overwrite all engine configs
    python manage.py init_default_engine              # only create missing engines (do not overwrite)
    python manage.py init_default_engine --force      # force-overwrite all engine configs
    python manage.py init_default_engine --force-sub  # only overwrite sub-engines, keep full scan

cd /root/my-vulun-scan/docker
docker compose exec server python backend/manage.py init_default_engine --force
@@ -12,6 +13,7 @@
- Reads engine_config_example.yaml as the default configuration
- Creates the full scan (default engine) plus one sub-engine per scan type
- Existing configs are not overwritten by default; pass --force to overwrite
- Pass --force-sub to overwrite only sub-engine configs, keeping a user-customized full scan
"""

from django.core.management.base import BaseCommand
@@ -30,11 +32,18 @@ class Command(BaseCommand):
        parser.add_argument(
            '--force',
            action='store_true',
            help='Force-overwrite existing engine configs',
            help='Force-overwrite existing engine configs (both full scan and sub-engines)',
        )
        parser.add_argument(
            '--force-sub',
            action='store_true',
            help='Overwrite only sub-engine configs, keep full scan (for upgrades)',
        )

    def handle(self, *args, **options):
        force = options.get('force', False)
        force_sub = options.get('force_sub', False)

        # Read the default config file
        config_path = Path(__file__).resolve().parent.parent.parent.parent / 'scan' / 'configs' / 'engine_config_example.yaml'

@@ -99,15 +108,22 @@ class Command(BaseCommand):
            engine_name = f"{scan_type}"
            sub_engine = ScanEngine.objects.filter(name=engine_name).first()
            if sub_engine:
                if force:
                # Either --force or --force-sub overwrites sub-engines
                if force or force_sub:
                    sub_engine.configuration = single_yaml
                    sub_engine.save()
                    self.stdout.write(self.style.SUCCESS(f'  ✓ Sub-engine {engine_name} config updated (ID: {sub_engine.id})'))
                    self.stdout.write(self.style.SUCCESS(
                        f'  ✓ Sub-engine {engine_name} config updated (ID: {sub_engine.id})'
                    ))
                else:
                    self.stdout.write(self.style.WARNING(f'  ⊘ {engine_name} already exists, skipping (use --force to overwrite)'))
                    self.stdout.write(self.style.WARNING(
                        f'  ⊘ {engine_name} already exists, skipping (use --force to overwrite)'
                    ))
            else:
                sub_engine = ScanEngine.objects.create(
                    name=engine_name,
                    configuration=single_yaml,
                )
                self.stdout.write(self.style.SUCCESS(f'  ✓ Sub-engine {engine_name} created (ID: {sub_engine.id})'))
                self.stdout.write(self.style.SUCCESS(
                    f'  ✓ Sub-engine {engine_name} created (ID: {sub_engine.id})'
                ))
@@ -312,7 +312,11 @@ class TaskDistributor:
        # - Local workers: install.sh pre-pulls the image, so the local version is used directly
        # - Remote workers: the image is pre-pulled at deploy time, so the local version is used directly
        # - Avoids checking Docker Hub on every task, improving performance and stability
        # OOM priority: --oom-score-adj=1000 makes workers the first to be killed under memory pressure
        # - Range is -1000 to 1000; higher values are more likely to be picked by the OOM killer
        # - Protects core services (server/nginx/frontend) so the web UI stays available
        cmd = f'''docker run --rm -d --pull=missing {network_arg} \\
            --oom-score-adj=1000 \\
            {' '.join(env_vars)} \\
            {' '.join(volumes)} \\
            {self.docker_image} \\
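As a hedged aside, the effect of the flag can be verified from inside a running container by reading the value the kernel exposes under `/proc`; this check is illustrative, not part of the diff:

```python
from pathlib import Path

# Inside a worker container started with --oom-score-adj=1000, the kernel
# exposes the effective value for the container's processes under /proc
adj = int(Path('/proc/self/oom_score_adj').read_text().strip())
print(adj)  # expected: 1000 in a worker container, 0 by default
```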
@@ -24,18 +24,6 @@ SUBDOMAIN_DISCOVERY_COMMANDS = {
        }
    },

    'amass_passive': {
        # Run passive enumeration first, writing results into amass's internal database,
        # then export the bare domain names (names) from the database to output_file
        # -silent disables the progress bar and other output
        'base': "amass enum -passive -silent -d {domain} && amass subs -names -d {domain} > '{output_file}'"
    },

    'amass_active': {
        # Run active enumeration + brute forcing first, writing results into amass's internal database,
        # then export the bare domain names (names) from the database to output_file
        # -silent disables the progress bar and other output
        'base': "amass enum -active -silent -d {domain} -brute && amass subs -names -d {domain} > '{output_file}'"
    },

    'sublist3r': {
        'base': "python3 '/usr/local/share/Sublist3r/sublist3r.py' -d {domain} -o '{output_file}'",
        'optional': {
@@ -263,11 +251,16 @@ COMMAND_TEMPLATES = {
    'directory_scan': DIRECTORY_SCAN_COMMANDS,
    'url_fetch': URL_FETCH_COMMANDS,
    'vuln_scan': VULN_SCAN_COMMANDS,
    'screenshot': {},  # Uses a native Python library (Playwright); no command templates
}

# ==================== Scan type configuration ====================

# Execution stage definitions (run in order)
# Stage 1: asset discovery - subdomains → ports → site probing → fingerprinting
# Stage 2: URL collection - URL fetching + directory scanning (parallel)
# Stage 3: screenshots - run after URL collection to capture more of the discovered pages
# Stage 4: vulnerability scan - runs last
EXECUTION_STAGES = [
    {
        'mode': 'sequential',
@@ -277,6 +270,10 @@ EXECUTION_STAGES = [
        'mode': 'parallel',
        'flows': ['url_fetch', 'directory_scan']
    },
    {
        'mode': 'sequential',
        'flows': ['screenshot']
    },
    {
        'mode': 'sequential',
        'flows': ['vuln_scan']
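A minimal sketch of how a scheduler might walk `EXECUTION_STAGES`; `run_flow` here is a hypothetical callable standing in for the real flow launcher:

```python
import concurrent.futures
from typing import Callable


def run_stages(stages: list[dict], run_flow: Callable[[str], None]) -> None:
    """Run stages in order; within a stage, honor its 'mode'."""
    for stage in stages:
        flows = stage['flows']
        if stage['mode'] == 'parallel':
            # e.g. Stage 2: url_fetch and directory_scan run side by side
            with concurrent.futures.ThreadPoolExecutor() as pool:
                list(pool.map(run_flow, flows))
        else:
            # 'sequential': e.g. the screenshot and vuln_scan stages
            for flow_name in flows:
                run_flow(flow_name)
```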
@@ -17,14 +17,6 @@ subdomain_discovery:
      timeout: 3600  # 1 hour
      # threads: 10  # concurrent goroutines

    amass_passive:
      enabled: true
      timeout: 3600

    amass_active:
      enabled: true  # active enumeration + brute forcing
      timeout: 3600

    sublist3r:
      enabled: true
      timeout: 3600
@@ -62,7 +54,7 @@ port_scan:
      threads: 200  # concurrent connections (default 5)
      # ports: 1-65535  # port range to scan (default 1-65535)
      top-ports: 100  # scan the nmap top 100 ports
      rate: 10  # scan rate (default 10)
      rate: 50  # scan rate

  naabu_passive:
    enabled: true
@@ -101,6 +93,16 @@ directory_scan:
      match-codes: 200,201,301,302,401,403  # HTTP status codes to match
      # rate: 0  # requests per second (default 0, unlimited)

screenshot:
  # ==================== Website screenshots ====================
  # Uses Playwright to capture website screenshots, stored as WebP
  # Runs in its own stage after url_fetch and directory_scan complete
  tools:
    playwright:
      enabled: true
      concurrency: 5  # concurrent screenshots (default 5)
      url_sources: [websites]  # URL sources; currently screenshots websites, can also be [websites, endpoints]

url_fetch:
  # ==================== URL fetching ====================
  tools:
@@ -10,30 +10,30 @@
- Configuration parsed from YAML
"""

# Django environment initialization (takes effect on import)
from apps.common.prefect_django_setup import setup_django_for_prefect

from prefect import flow
from prefect.task_runners import ThreadPoolTaskRunner

import hashlib
import logging
import os
import subprocess
from datetime import datetime
from pathlib import Path
from typing import List, Tuple

from apps.scan.tasks.directory_scan import (
    export_sites_task,
    run_and_stream_save_directories_task
)
from prefect import flow

from apps.scan.handlers.scan_flow_handlers import (
    on_scan_flow_running,
    on_scan_flow_completed,
    on_scan_flow_failed,
    on_scan_flow_running,
)
from apps.scan.tasks.directory_scan import (
    export_sites_task,
    run_and_stream_save_directories_task,
)
from apps.scan.utils import (
    build_scan_command,
    ensure_wordlist_local,
    user_log,
    wait_for_system_load,
)
from apps.scan.utils import config_parser, build_scan_command, ensure_wordlist_local, user_log

logger = logging.getLogger(__name__)

@@ -45,517 +45,343 @@ def calculate_directory_scan_timeout(
    tool_config: dict,
    base_per_word: float = 1.0,
    min_timeout: int = 60,
    max_timeout: int = 7200
) -> int:
    """
    Compute the directory-scan timeout from the wordlist line count

    Formula: timeout = wordlist line count × base time per word
    Timeout range: 60 seconds to 2 hours (7200 seconds)

    Timeout range: minimum 60 seconds, no upper bound

    Args:
        tool_config: tool config dict containing the wordlist path
        base_per_word: base time per word in seconds (default 1.0)
        min_timeout: minimum timeout in seconds (default 60)
        max_timeout: maximum timeout in seconds (default 7200, i.e. 2 hours)

    Returns:
        int: computed timeout in seconds, in the range 60 to 7200

    Example:
        # 1000-line wordlist × 1.0 s = 1000 s → within the 7200 s cap, stays 1000 s
        # 10000-line wordlist × 1.0 s = 10000 s → capped at 7200 s (the maximum)
        timeout = calculate_directory_scan_timeout(
            tool_config={'wordlist': '/path/to/wordlist.txt'}
        )
        int: computed timeout in seconds
    """
    import os

    wordlist_path = tool_config.get('wordlist')
    if not wordlist_path:
        logger.warning("No wordlist specified in the tool config; using default timeout: %d s", min_timeout)
        return min_timeout

    wordlist_path = os.path.expanduser(wordlist_path)

    if not os.path.exists(wordlist_path):
        logger.warning("Wordlist file does not exist: %s; using default timeout: %d s", wordlist_path, min_timeout)
        return min_timeout

    try:
        # Get the wordlist path from tool_config
        wordlist_path = tool_config.get('wordlist')
        if not wordlist_path:
            logger.warning("No wordlist specified in the tool config; using default timeout: %d s", min_timeout)
            return min_timeout

        # Expand the user home directory (~)
        wordlist_path = os.path.expanduser(wordlist_path)

        # Check that the file exists
        if not os.path.exists(wordlist_path):
            logger.warning("Wordlist file does not exist: %s; using default timeout: %d s", wordlist_path, min_timeout)
            return min_timeout

        # Use wc -l to count the wordlist lines quickly
        result = subprocess.run(
            ['wc', '-l', wordlist_path],
            capture_output=True,
            text=True,
            check=True
        )
        # wc -l output format: line count + space + filename
        line_count = int(result.stdout.strip().split()[0])

        # Compute the timeout
        timeout = int(line_count * base_per_word)

        # Apply a sensible floor (the cap has been removed)
        timeout = max(min_timeout, timeout)

        timeout = max(min_timeout, int(line_count * base_per_word))

        logger.info(
            "Directory-scan timeout - wordlist: %s, lines: %d, base: %.3f s/word, computed timeout: %d s",
            wordlist_path, line_count, base_per_word, timeout
        )

        return timeout

    except subprocess.CalledProcessError as e:
        logger.error("Failed to count wordlist lines: %s", e)
        # Fall back to the default timeout on failure
        return min_timeout
    except (ValueError, IndexError) as e:
        logger.error("Failed to parse the wordlist line count: %s", e)
        return min_timeout
    except Exception as e:
        logger.error("Unexpected error computing the timeout: %s", e)

    except (subprocess.CalledProcessError, ValueError, IndexError) as e:
        logger.error("Failed to compute the timeout: %s", e)
        return min_timeout
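A worked example of the formula above, assuming `calculate_directory_scan_timeout` is in scope and using a throwaway wordlist (illustrative only; the function shells out to `wc`, so it assumes a Unix-like environment):

```python
import tempfile

# 150-line wordlist: 150 words x 1.0 s/word = 150 s (above the 60 s floor)
with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
    f.write('\n'.join(f'word{i}' for i in range(150)) + '\n')
    wordlist = f.name

print(calculate_directory_scan_timeout({'wordlist': wordlist}))  # 150

# A 30-line wordlist would compute 30 s and be clamped up to min_timeout (60 s);
# since the cap was removed, a 100,000-line wordlist now yields 100,000 s.
```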
|
||||
def _get_max_workers(tool_config: dict, default: int = DEFAULT_MAX_WORKERS) -> int:
|
||||
"""
|
||||
从单个工具配置中获取 max_workers 参数
|
||||
|
||||
Args:
|
||||
tool_config: 单个工具的配置字典,如 {'max_workers': 10, 'threads': 5, ...}
|
||||
default: 默认值,默认为 5
|
||||
|
||||
Returns:
|
||||
int: max_workers 值
|
||||
"""
|
||||
"""从单个工具配置中获取 max_workers 参数"""
|
||||
if not isinstance(tool_config, dict):
|
||||
return default
|
||||
|
||||
# 支持 max_workers 和 max-workers(YAML 中划线会被转换)
|
||||
|
||||
max_workers = tool_config.get('max_workers') or tool_config.get('max-workers')
|
||||
if max_workers is not None and isinstance(max_workers, int) and max_workers > 0:
|
||||
if isinstance(max_workers, int) and max_workers > 0:
|
||||
return max_workers
|
||||
return default
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def _export_site_urls(target_id: int, target_name: str, directory_scan_dir: Path) -> tuple[str, int]:
|
||||
def _export_site_urls(
|
||||
target_id: int,
|
||||
directory_scan_dir: Path
|
||||
) -> Tuple[str, int]:
|
||||
"""
|
||||
导出目标下的所有站点 URL 到文件(支持懒加载)
|
||||
|
||||
导出目标下的所有站点 URL 到文件
|
||||
|
||||
Args:
|
||||
target_id: 目标 ID
|
||||
target_name: 目标名称(用于懒加载创建默认站点)
|
||||
directory_scan_dir: 目录扫描目录
|
||||
|
||||
|
||||
Returns:
|
||||
tuple: (sites_file, site_count)
|
||||
|
||||
Raises:
|
||||
ValueError: 站点数量为 0
|
||||
"""
|
||||
logger.info("Step 1: 导出目标的所有站点 URL")
|
||||
|
||||
|
||||
sites_file = str(directory_scan_dir / 'sites.txt')
|
||||
export_result = export_sites_task(
|
||||
target_id=target_id,
|
||||
output_file=sites_file,
|
||||
batch_size=1000 # 每次读取 1000 条,优化内存占用
|
||||
batch_size=1000
|
||||
)
|
||||
|
||||
|
||||
site_count = export_result['total_count']
|
||||
|
||||
logger.info(
|
||||
"✓ 站点 URL 导出完成 - 文件: %s, 数量: %d",
|
||||
export_result['output_file'],
|
||||
site_count
|
||||
)
|
||||
|
||||
|
||||
if site_count == 0:
|
||||
logger.warning("目标下没有站点,无法执行目录扫描")
|
||||
# 不抛出异常,由上层决定如何处理
|
||||
# raise ValueError("目标下没有站点,无法执行目录扫描")
|
||||
|
||||
|
||||
return export_result['output_file'], site_count
|
||||
|
||||
|
||||
def _run_scans_sequentially(
|
||||
enabled_tools: dict,
|
||||
sites_file: str,
|
||||
directory_scan_dir: Path,
|
||||
scan_id: int,
|
||||
target_id: int,
|
||||
site_count: int,
|
||||
    target_name: str
) -> tuple[int, int, list]:
    """
    Run directory scan tasks serially (multi-tool) - deprecated, kept for compatibility

    Args:
        enabled_tools: dict of enabled tool configurations
        sites_file: path to the sites file
        directory_scan_dir: directory-scan working directory
        scan_id: scan task ID
        target_id: target ID
        site_count: number of sites
        target_name: target name (used in error logs)

    Returns:
        tuple: (total_directories, processed_sites, failed_sites)
    """
    # Read the site list
    sites = []
    with open(sites_file, 'r', encoding='utf-8') as f:
        for line in f:
            site_url = line.strip()
            if site_url:
                sites.append(site_url)

    logger.info("Preparing to scan %d sites with tools: %s", len(sites), ', '.join(enabled_tools.keys()))

    total_directories = 0
    processed_sites_set = set()  # use a set to avoid double-counting
    failed_sites = []

    # Iterate over each tool
    for tool_name, tool_config in enabled_tools.items():
        logger.info("=" * 60)
        logger.info("Using tool: %s", tool_name)
        logger.info("=" * 60)

        # If wordlist_name is configured, first make sure the wordlist exists locally (with hash check)
        wordlist_name = tool_config.get('wordlist_name')
        if wordlist_name:
            try:
                local_wordlist_path = ensure_wordlist_local(wordlist_name)
                tool_config['wordlist'] = local_wordlist_path
            except Exception as exc:
                logger.error("Failed to prepare wordlist for tool %s: %s", tool_name, exc)
                # This tool cannot run; mark all sites as failed and move on to the next tool
                failed_sites.extend(sites)
                continue

        # Scan the sites one by one
        for idx, site_url in enumerate(sites, 1):
            logger.info(
                "[%d/%d] Starting site scan: %s (tool: %s)",
                idx, len(sites), site_url, tool_name
            )

            # Use the unified command builder
            try:
                command = build_scan_command(
                    tool_name=tool_name,
                    scan_type='directory_scan',
                    command_params={
                        'url': site_url
                    },
                    tool_config=tool_config
                )
            except Exception as e:
                logger.error(
                    "✗ [%d/%d] Failed to build %s command: %s - site: %s",
                    idx, len(sites), tool_name, e, site_url
                )
                failed_sites.append(site_url)
                continue

            # Per-site timeout: taken from config (supports 'auto' for dynamic calculation)
            # ffuf scans one site at a time, so timeout is the per-site limit
            site_timeout = tool_config.get('timeout', 300)
            if site_timeout == 'auto':
                # Calculate the timeout dynamically (based on wordlist line count)
                site_timeout = calculate_directory_scan_timeout(tool_config)
                logger.info(f"✓ Tool {tool_name} auto-calculated timeout: {site_timeout}s")

            # Build the log file path
            from datetime import datetime
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            log_file = directory_scan_dir / f"{tool_name}_{timestamp}_{idx}.log"

            try:
                # Call the task directly (serial execution)
                result = run_and_stream_save_directories_task(
                    cmd=command,
                    tool_name=tool_name,  # new: tool name
                    scan_id=scan_id,
                    target_id=target_id,
                    site_url=site_url,
                    cwd=str(directory_scan_dir),
                    shell=True,
                    batch_size=1000,
                    timeout=site_timeout,
                    log_file=str(log_file)  # new: log file path
                )

                total_directories += result.get('created_directories', 0)
                processed_sites_set.add(site_url)  # record successful sites in the set

                logger.info(
                    "✓ [%d/%d] Site scan finished: %s - found %d directories",
                    idx, len(sites), site_url,
                    result.get('created_directories', 0)
                )

            except subprocess.TimeoutExpired as exc:
                # Handle timeouts separately
                failed_sites.append(site_url)
                logger.warning(
                    "⚠️ [%d/%d] Site scan timed out: %s - configured timeout: %ds\n"
                    "Note: directories parsed before the timeout were already saved to the database, but the scan did not fully complete.",
                    idx, len(sites), site_url, site_timeout
                )
            except Exception as exc:
                # Other exceptions
                failed_sites.append(site_url)
                logger.error(
                    "✗ [%d/%d] Site scan failed: %s - error: %s",
                    idx, len(sites), site_url, exc
                )

            # Report progress every 10 sites
            if idx % 10 == 0:
                logger.info(
                    "Progress: %d/%d (%.1f%%) - %d directories found so far",
                    idx, len(sites), idx / len(sites) * 100, total_directories
                )

    # Count successful and failed sites
    processed_count = len(processed_sites_set)

    if failed_sites:
        logger.warning(
            "Some sites failed to scan: %d/%d",
            len(failed_sites), len(sites)
        )

    logger.info(
        "✓ Serial directory scan finished - succeeded: %d/%d, failed: %d, total directories: %d",
        processed_count, len(sites), len(failed_sites), total_directories
    )

    return total_directories, processed_count, failed_sites


def _generate_log_filename(
    tool_name: str,
    site_url: str,
    directory_scan_dir: Path
) -> Path:
    """Generate a unique log filename (an 8-char URL hash avoids collisions between concurrent scans)"""
    url_hash = hashlib.md5(
        site_url.encode(),
        usedforsecurity=False
    ).hexdigest()[:8]
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
    return directory_scan_dir / f"{tool_name}_{url_hash}_{timestamp}.log"

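# --- Editor's illustrative sketch (not part of the original diff): filename shape.
# With the helper above, a generated path looks like
#   directory_scan_dir / "ffuf_<8-hex-chars>_<YYYYmmdd_HHMMSS_microseconds>.log"
# so two workers scanning different URLs within the same timestamp granularity
# still diverge in the hash segment.
example = _generate_log_filename('ffuf', 'https://example.com', Path('/tmp/scan'))
assert example.name.startswith('ffuf_') and example.suffix == '.log'
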
def _prepare_tool_wordlist(tool_name: str, tool_config: dict) -> bool:
    """Prepare the tool's wordlist file; returns whether it succeeded"""
    wordlist_name = tool_config.get('wordlist_name')
    if not wordlist_name:
        return True

    try:
        local_wordlist_path = ensure_wordlist_local(wordlist_name)
        tool_config['wordlist'] = local_wordlist_path
        return True
    except Exception as exc:
        logger.error("Failed to prepare wordlist for tool %s: %s", tool_name, exc)
        return False

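# --- Editor's illustrative note (not part of the original diff): the helper's
# contract is True when no wordlist is configured at all, or when
# ensure_wordlist_local() succeeded and tool_config['wordlist'] now points at the
# local file; False (after logging) when the wordlist could not be prepared.
assert _prepare_tool_wordlist('ffuf', {}) is True  # no wordlist_name configured
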
def _build_scan_params(
    tool_name: str,
    tool_config: dict,
    sites: List[str],
    directory_scan_dir: Path,
    site_timeout: int
) -> Tuple[List[dict], List[str]]:
    """Build scan parameters for every site; returns (scan_params_list, failed_sites)"""
    scan_params_list = []
    failed_sites = []

    for idx, site_url in enumerate(sites, 1):
        try:
            command = build_scan_command(
                tool_name=tool_name,
                scan_type='directory_scan',
                command_params={'url': site_url},
                tool_config=tool_config
            )
            log_file = _generate_log_filename(tool_name, site_url, directory_scan_dir)
            scan_params_list.append({
                'idx': idx,
                'site_url': site_url,
                'command': command,
                'log_file': str(log_file),
                'timeout': site_timeout
            })
        except Exception as e:
            logger.error(
                "✗ [%d/%d] Failed to build %s command: %s - site: %s",
                idx, len(sites), tool_name, e, site_url
            )
            failed_sites.append(site_url)

    return scan_params_list, failed_sites

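# --- Editor's illustrative note (not part of the original diff): the shape of one
# entry in scan_params_list (values assumed for illustration) ---
# {'idx': 1,
#  'site_url': 'https://example.com',
#  'command': '<tool command produced by build_scan_command>',
#  'log_file': '/workspace/directory_scan/ffuf_<hash>_<ts>.log',
#  'timeout': 300}
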
def _execute_batch(
    batch_params: List[dict],
    tool_name: str,
    scan_id: int,
    target_id: int,
    directory_scan_dir: Path,
    total_sites: int
) -> Tuple[int, List[str]]:
    """Execute one batch of scan tasks; returns (directories_found, failed_sites)"""
    directories_found = 0
    failed_sites = []

    # Submit the tasks
    futures = []
    for params in batch_params:
        future = run_and_stream_save_directories_task.submit(
            cmd=params['command'],
            tool_name=tool_name,
            scan_id=scan_id,
            target_id=target_id,
            site_url=params['site_url'],
            cwd=str(directory_scan_dir),
            shell=True,
            batch_size=1000,
            timeout=params['timeout'],
            log_file=params['log_file']
        )
        futures.append((params['idx'], params['site_url'], future))

    # Wait for the results
    for idx, site_url, future in futures:
        try:
            result = future.result()
            dirs_count = result.get('created_directories', 0)
            directories_found += dirs_count
            logger.info(
                "✓ [%d/%d] Site scan finished: %s - found %d directories",
                idx, total_sites, site_url, dirs_count
            )
        except Exception as exc:
            failed_sites.append(site_url)
            if 'timeout' in str(exc).lower():
                logger.warning(
                    "⚠️ [%d/%d] Site scan timed out: %s - error: %s",
                    idx, total_sites, site_url, exc
                )
            else:
                logger.error(
                    "✗ [%d/%d] Site scan failed: %s - error: %s",
                    idx, total_sites, site_url, exc
                )

    return directories_found, failed_sites

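# --- Editor's illustrative sketch (not part of the original diff): the batch-slicing
# pattern the caller below uses to cap how many tool processes run at once.
# With 7 tasks and max_workers=3 the batches are [0:3], [3:6], [6:7]:
tasks = list(range(7))
batches = [tasks[i:i + 3] for i in range(0, len(tasks), 3)]
assert batches == [[0, 1, 2], [3, 4, 5], [6]]
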
def _run_scans_concurrently(
    enabled_tools: dict,
    sites_file: str,
    directory_scan_dir: Path,
    scan_id: int,
    target_id: int,
    site_count: int,
    target_name: str
) -> Tuple[int, int, List[str]]:
    """
    Run directory scan tasks concurrently

    Args:
        enabled_tools: dict of enabled tool configurations
        sites_file: path to the sites file
        directory_scan_dir: directory-scan working directory
        scan_id: scan task ID
        target_id: target ID
        site_count: number of sites
        target_name: target name (used in error logs)

    Returns:
        tuple: (total_directories, processed_sites, failed_sites)
    """
    # Read the site list
    sites: List[str] = []
    with open(sites_file, 'r', encoding='utf-8') as f:
        sites = [line.strip() for line in f if line.strip()]

    if not sites:
        logger.warning("Site list is empty")
        return 0, 0, []

    logger.info(
        "Preparing concurrent scan of %d sites with tools: %s",
        len(sites), ', '.join(enabled_tools.keys())
    )

    total_directories = 0
    processed_sites_count = 0
    failed_sites: List[str] = []

    # Iterate over each tool
    for tool_name, tool_config in enabled_tools.items():
        # Each tool gets its own max_workers setting
        max_workers = _get_max_workers(tool_config)

        logger.info("=" * 60)
        logger.info("Using tool: %s (concurrent mode, max_workers=%d)", tool_name, max_workers)
        logger.info("=" * 60)
        user_log(scan_id, "directory_scan", f"Running {tool_name}")

        # Prepare the wordlist file
        if not _prepare_tool_wordlist(tool_name, tool_config):
            failed_sites.extend(sites)
            continue

        # Calculate the timeout
        site_timeout = tool_config.get('timeout', 300)
        if site_timeout == 'auto':
            site_timeout = calculate_directory_scan_timeout(tool_config)
            logger.info("✓ Tool %s auto-calculated timeout: %ds", tool_name, site_timeout)

        # Build the scan parameters
        scan_params_list, build_failed = _build_scan_params(
            tool_name, tool_config, sites, directory_scan_dir, site_timeout
        )
        failed_sites.extend(build_failed)

        if not scan_params_list:
            logger.warning("No valid scan tasks")
            continue

        # Batched execution: caps how many tool processes actually run in parallel
        total_tasks = len(scan_params_list)
        logger.info("Starting batched execution of %d scan tasks (%d per batch)...", total_tasks, max_workers)

        # Progress milestone tracking
        last_progress_percent = 0
        tool_directories = 0
        tool_processed = 0

        for batch_start in range(0, total_tasks, max_workers):
            batch_end = min(batch_start + max_workers, total_tasks)
            batch_params = scan_params_list[batch_start:batch_end]
            batch_num = batch_start // max_workers + 1

            logger.info(
                "Running batch %d (%d-%d/%d)...",
                batch_num, batch_start + 1, batch_end, total_tasks
            )

            dirs_found, batch_failed = _execute_batch(
                batch_params, tool_name, scan_id, target_id,
                directory_scan_dir, len(sites)
            )

            total_directories += dirs_found
            tool_directories += dirs_found
            tool_processed += len(batch_params) - len(batch_failed)
            processed_sites_count += len(batch_params) - len(batch_failed)
            failed_sites.extend(batch_failed)

            # Progress milestone: report every 20%
            current_progress = int((batch_end / total_tasks) * 100)
            if current_progress >= last_progress_percent + 20:
                user_log(
                    scan_id, "directory_scan",
                    f"Progress: {batch_end}/{total_tasks} sites scanned"
                )
                last_progress_percent = (current_progress // 20) * 20

        # Per-tool completion log (developer log + user log)
        logger.info(
            "✓ Tool %s finished - processed sites: %d/%d, directories found: %d",
            tool_name, tool_processed, total_tasks, tool_directories
        )
        user_log(
            scan_id, "directory_scan",
            f"{tool_name} completed: found {tool_directories} directories"
        )

    # Summary
    if failed_sites:
        logger.warning("Some sites failed to scan: %d/%d", len(failed_sites), len(sites))

    logger.info(
        "✓ Concurrent directory scan finished - succeeded: %d/%d, failed: %d, total directories: %d",
        processed_sites_count, len(sites), len(failed_sites), total_directories
    )

    return total_directories, processed_sites_count, failed_sites


@flow(
    name="directory_scan",
    log_prints=True,
    on_running=[on_scan_flow_running],
    on_completion=[on_scan_flow_completed],
@@ -570,64 +396,31 @@ def directory_scan_flow(
) -> dict:
    """
    Directory scan Flow

    Main responsibilities:
    1. Fetch all site URLs for the target
    2. Run a directory scan against each site URL (supports ffuf and similar tools)
    3. Stream the results into the Directory table in the database

    Workflow:
        Step 0: Create the working directory
        Step 1: Export the site URL list to a file (consumed by the scan tools)
        Step 2: Validate the tool configuration
        Step 3: Run the scan tools concurrently and save results in real time (ThreadPoolTaskRunner)

    ffuf output fields:
        - url: discovered directory/file URL
        - length: response content length
        - status: HTTP status code
        - words: response word count
        - lines: response line count
        - content_type: content type
        - duration: request duration

    Args:
        scan_id: scan task ID
        target_name: target name
        target_id: target ID
        scan_workspace_dir: scan workspace directory
        enabled_tools: dict of enabled tool configurations

    Returns:
        dict: scan result
    """
    try:
        wait_for_system_load(context="directory_scan_flow")

        logger.info(
            "Starting directory scan - Scan ID: %s, Target: %s, Workspace: %s",
            scan_id, target_name, scan_workspace_dir
        )
        user_log(scan_id, "directory_scan", "Starting directory scan")

        # Parameter validation
        if scan_id is None:
            raise ValueError("scan_id must not be empty")
@@ -639,14 +432,14 @@ def directory_scan_flow(
            raise ValueError("scan_workspace_dir must not be empty")
        if not enabled_tools:
            raise ValueError("enabled_tools must not be empty")

        # Step 0: Create the working directory
        from apps.scan.utils import setup_scan_directory
        directory_scan_dir = setup_scan_directory(scan_workspace_dir, 'directory_scan')

        # Step 1: Export the site URLs
        sites_file, site_count = _export_site_urls(target_id, directory_scan_dir)

        if site_count == 0:
            logger.warning("Skipping directory scan: no sites to scan - Scan ID: %s", scan_id)
            user_log(scan_id, "directory_scan", "Skipped: no sites to scan", "warning")
@@ -662,16 +455,16 @@
                'failed_sites_count': 0,
                'executed_tasks': ['export_sites']
            }

        # Step 2: Tool configuration info
        logger.info("Step 2: Tool configuration info")
        tool_info = [
            f"{name}(max_workers={_get_max_workers(cfg)})"
            for name, cfg in enabled_tools.items()
        ]
        logger.info("✓ Enabled tools: %s", ', '.join(tool_info))

        # Step 3: Run the scan tools concurrently
        logger.info("Step 3: Running scan tools concurrently and saving results in real time")
        total_directories, processed_sites, failed_sites = _run_scans_concurrently(
            enabled_tools=enabled_tools,
@@ -679,19 +472,20 @@ def directory_scan_flow(
            directory_scan_dir=directory_scan_dir,
            scan_id=scan_id,
            target_id=target_id,
            site_count=site_count,
            target_name=target_name
        )

        # Check whether every site failed
        if processed_sites == 0 and site_count > 0:
            # Do not raise; let the scan continue
            logger.warning(
                "All site scans failed - total sites: %d, failures: %d",
                site_count, len(failed_sites)
            )

        logger.info("✓ Directory scan finished - directories found: %d", total_directories)
        user_log(
            scan_id, "directory_scan",
            f"directory_scan completed: found {total_directories} directories"
        )

        return {
            'success': True,
            'scan_id': scan_id,
@@ -704,7 +498,7 @@ def directory_scan_flow(
            'failed_sites_count': len(failed_sites),
            'executed_tasks': ['export_sites', 'run_and_stream_save_directories']
        }

    except Exception as e:
        logger.exception("Directory scan failed: %s", e)
        raise

@@ -10,26 +10,22 @@
- Streams tool output and updates the database in batches
"""

# Django environment initialization (takes effect on import)
from apps.common.prefect_django_setup import setup_django_for_prefect

import logging
import os
from datetime import datetime
from pathlib import Path

from prefect import flow

from apps.scan.handlers.scan_flow_handlers import (
    on_scan_flow_completed,
    on_scan_flow_failed,
    on_scan_flow_running,
)
from apps.scan.tasks.fingerprint_detect import (
    export_urls_for_fingerprint_task,
    run_xingfinger_and_stream_update_tech_task,
)
from apps.scan.utils import build_scan_command, user_log, wait_for_system_load
from apps.scan.utils.fingerprint_helpers import get_fingerprint_paths

logger = logging.getLogger(__name__)
@@ -42,22 +38,19 @@ def calculate_fingerprint_detect_timeout(
) -> int:
    """
    Calculate the timeout from the URL count

    Formula: timeout = URL count × base time per URL
    Minimum: 300 s, no upper bound

    Args:
        url_count: number of URLs
        base_per_url: base time per URL in seconds, default 10
        min_timeout: minimum timeout in seconds, default 300

    Returns:
        int: the calculated timeout in seconds
    """
    return max(min_timeout, int(url_count * base_per_url))

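# --- Editor's illustrative check (not part of the original diff), assuming the
# documented defaults base_per_url=10 and min_timeout=300:
#   20 URLs  -> 20 × 10 = 200 s, clamped up to the 300 s floor
#   120 URLs -> 120 × 10 = 1200 s, used as-is
assert calculate_fingerprint_detect_timeout(20) == 300
assert calculate_fingerprint_detect_timeout(120) == 1200
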
@@ -70,17 +63,17 @@ def _export_urls(
) -> tuple[str, int]:
    """
    Export URLs to a file

    Args:
        target_id: target ID
        fingerprint_dir: fingerprint-detection working directory
        source: data source type

    Returns:
        tuple: (urls_file, total_count)
    """
    logger.info("Step 1: Exporting the URL list (source=%s)", source)

    urls_file = str(fingerprint_dir / 'urls.txt')
    export_result = export_urls_for_fingerprint_task(
        target_id=target_id,
@@ -88,15 +81,14 @@ def _export_urls(
        source=source,
        batch_size=1000
    )

    total_count = export_result['total_count']

    logger.info(
        "✓ URL export finished - file: %s, count: %d",
        export_result['output_file'],
        total_count
    )

    return export_result['output_file'], total_count

@@ -111,7 +103,7 @@ def _run_fingerprint_detect(
) -> tuple[dict, list]:
    """
    Run the fingerprint detection tasks

    Args:
        enabled_tools: dict of enabled tool configurations
        urls_file: path to the URLs file
@@ -120,56 +112,54 @@ def _run_fingerprint_detect(
        scan_id: scan task ID
        target_id: target ID
        source: data source type

    Returns:
        tuple: (tool_stats, failed_tools)
    """
    tool_stats = {}
    failed_tools = []

    for tool_name, tool_config in enabled_tools.items():
        # 1. Resolve the fingerprint library paths
        lib_names = tool_config.get('fingerprint_libs', ['ehole'])
        fingerprint_paths = get_fingerprint_paths(lib_names)

        if not fingerprint_paths:
            reason = f"No usable fingerprint libraries: {lib_names}"
            logger.warning(reason)
            failed_tools.append({'tool': tool_name, 'reason': reason})
            continue

        # 2. Merge the fingerprint paths into tool_config (used for command building)
        tool_config_with_paths = {**tool_config, **fingerprint_paths}

        # 3. Build the command
        try:
            command = build_scan_command(
                tool_name=tool_name,
                scan_type='fingerprint_detect',
                command_params={'urls_file': urls_file},
                tool_config=tool_config_with_paths
            )
        except Exception as e:
            reason = f"Command build failed: {e}"
            logger.error("Failed to build %s command: %s", tool_name, e)
            failed_tools.append({'tool': tool_name, 'reason': reason})
            continue

        # 4. Calculate the timeout
        timeout = calculate_fingerprint_detect_timeout(url_count)

        # 5. Build the log file path
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_file = fingerprint_dir / f"{tool_name}_{timestamp}.log"

        logger.info(
            "Starting %s fingerprint detection - URLs: %d, timeout: %ds, libraries: %s",
            tool_name, url_count, timeout, list(fingerprint_paths.keys())
        )
        user_log(scan_id, "fingerprint_detect", f"Running {tool_name}: {command}")

        # 6. Run the scan task
        try:
            result = run_xingfinger_and_stream_update_tech_task(
@@ -183,14 +173,14 @@ def _run_fingerprint_detect(
                log_file=str(log_file),
                batch_size=100
            )

            tool_stats[tool_name] = {
                'command': command,
                'result': result,
                'timeout': timeout,
                'fingerprint_libs': list(fingerprint_paths.keys())
            }

            tool_updated = result.get('updated_count', 0)
            logger.info(
                "✓ Tool %s finished - records processed: %d, updated: %d, not found: %d",
@@ -199,20 +189,23 @@ def _run_fingerprint_detect(
                tool_updated,
                result.get('not_found_count', 0)
            )
            user_log(
                scan_id, "fingerprint_detect",
                f"{tool_name} completed: identified {tool_updated} fingerprints"
            )

        except Exception as exc:
            reason = str(exc)
            failed_tools.append({'tool': tool_name, 'reason': reason})
            logger.error("Tool %s failed: %s", tool_name, exc, exc_info=True)
            user_log(scan_id, "fingerprint_detect", f"{tool_name} failed: {reason}", "error")

    if failed_tools:
        logger.warning(
            "The following fingerprint tools failed: %s",
            ', '.join([f['tool'] for f in failed_tools])
        )

    return tool_stats, failed_tools


@@ -232,53 +225,38 @@ def fingerprint_detect_flow(
) -> dict:
    """
    Fingerprint detection Flow

    Main responsibilities:
    1. Export all WebSite URLs under the target from the database to a file
    2. Run technology-stack detection with xingfinger
    3. Parse the results and update WebSite.tech (merged and deduplicated)

    Workflow:
        Step 0: Create the working directory
        Step 1: Export the URL list
        Step 2: Parse the configuration and resolve the enabled tools
        Step 3: Run xingfinger and parse the results

    Args:
        scan_id: scan task ID
        target_name: target name
        target_id: target ID
        scan_workspace_dir: scan workspace directory
        enabled_tools: enabled tool configuration (xingfinger)

    Returns:
        dict: scan result
    """
    try:
        # Load check: wait until system resources are sufficient
        wait_for_system_load(context="fingerprint_detect_flow")

        logger.info(
            "Starting fingerprint detection - Scan ID: %s, Target: %s, Workspace: %s",
            scan_id, target_name, scan_workspace_dir
        )
        user_log(scan_id, "fingerprint_detect", "Starting fingerprint detection")

        # Parameter validation
        if scan_id is None:
            raise ValueError("scan_id must not be empty")
@@ -288,46 +266,26 @@ def fingerprint_detect_flow(
            raise ValueError("target_id must not be empty")
        if not scan_workspace_dir:
            raise ValueError("scan_workspace_dir must not be empty")

        # Data source type (currently only website is supported)
        source = 'website'

        # Step 0: Create the working directory
        from apps.scan.utils import setup_scan_directory
        fingerprint_dir = setup_scan_directory(scan_workspace_dir, 'fingerprint_detect')

        # Step 1: Export the URLs (supports lazy loading)
        urls_file, url_count = _export_urls(target_id, fingerprint_dir, source)

        if url_count == 0:
            logger.warning("Skipping fingerprint detection: no URLs to scan - Scan ID: %s", scan_id)
            user_log(scan_id, "fingerprint_detect", "Skipped: no URLs to scan", "warning")
            return _build_empty_result(scan_id, target_name, scan_workspace_dir, urls_file)

        # Step 2: Tool configuration info
        logger.info("Step 2: Tool configuration info")
        logger.info("✓ Enabled tools: %s", ', '.join(enabled_tools.keys()))

        # Step 3: Run fingerprint detection
        logger.info("Step 3: Running fingerprint detection")
        tool_stats, failed_tools = _run_fingerprint_detect(
@@ -339,24 +297,37 @@ def fingerprint_detect_flow(
            target_id=target_id,
            source=source
        )

        # Build the list of executed tasks dynamically
        executed_tasks = ['export_urls_for_fingerprint']
        executed_tasks.extend([f'run_xingfinger ({tool})' for tool in tool_stats])

        # Aggregate the results of all tools
        total_processed = sum(
            stats['result'].get('processed_records', 0) for stats in tool_stats.values()
        )
        total_updated = sum(
            stats['result'].get('updated_count', 0) for stats in tool_stats.values()
        )
        total_created = sum(
            stats['result'].get('created_count', 0) for stats in tool_stats.values()
        )
        total_snapshots = sum(
            stats['result'].get('snapshot_count', 0) for stats in tool_stats.values()
        )

        # Record Flow completion
        logger.info("✓ Fingerprint detection finished - fingerprints identified: %d", total_updated)
        user_log(
            scan_id, "fingerprint_detect",
            f"fingerprint_detect completed: identified {total_updated} fingerprints"
        )

        successful_tools = [
            name for name in enabled_tools
            if name not in [f['tool'] for f in failed_tools]
        ]

        return {
            'success': True,
            'scan_id': scan_id,
@@ -378,7 +349,7 @@ def fingerprint_detect_flow(
                'details': tool_stats
            }
        }

    except ValueError as e:
        logger.error("Configuration error: %s", e)
        raise
@@ -388,3 +359,33 @@ def fingerprint_detect_flow(
    except Exception as e:
        logger.exception("Fingerprint detection failed: %s", e)
        raise

def _build_empty_result(
    scan_id: int,
    target_name: str,
    scan_workspace_dir: str,
    urls_file: str
) -> dict:
    """Build an empty result (used when there are no URLs to scan)"""
    return {
        'success': True,
        'scan_id': scan_id,
        'target': target_name,
        'scan_workspace_dir': scan_workspace_dir,
        'urls_file': urls_file,
        'url_count': 0,
        'processed_records': 0,
        'updated_count': 0,
        'created_count': 0,
        'snapshot_count': 0,
        'executed_tasks': ['export_urls_for_fingerprint'],
        'tool_stats': {
            'total': 0,
            'successful': 0,
            'failed': 0,
            'successful_tools': [],
            'failed_tools': [],
            'details': {}
        }
    }

@@ -99,15 +99,13 @@ def initiate_scan_flow(
        raise ValueError("engine_name is required")

    logger.info("="*60)
    logger.info("Starting scan task initialization")
    logger.info(f"Scan ID: {scan_id}")
    logger.info(f"Target: {target_name}")
    logger.info(f"Engine: {engine_name}")
    logger.info(f"Workspace: {scan_workspace_dir}")
    logger.info("="*60)

    # ==================== Task 1: Create the Scan workspace ====================
    scan_workspace_path = setup_scan_workspace(scan_workspace_dir)
@@ -126,11 +124,9 @@ def initiate_scan_flow(
    # FlowOrchestrator has already parsed all tool configurations
    enabled_tools_by_type = orchestrator.enabled_tools_by_type

    logger.info("Execution plan generated")
    logger.info(f"Scan types: {' → '.join(orchestrator.scan_types)}")
    logger.info(f"{len(orchestrator.scan_types)} Flows in total")

    # ==================== Initialize stage progress ====================
    # Initialize right after the config is parsed, when the full scan_types list is available
@@ -209,9 +205,13 @@ def initiate_scan_flow(
    for mode, enabled_flows in orchestrator.get_execution_stages():
        if mode == 'sequential':
            # Sequential execution
            logger.info("="*60)
            logger.info(f"Sequential stage: {', '.join(enabled_flows)}")
            logger.info("="*60)
            for scan_type, flow_func, flow_specific_kwargs in get_valid_flows(enabled_flows):
                logger.info("="*60)
                logger.info(f"Running Flow: {scan_type}")
                logger.info("="*60)
                try:
                    result = flow_func(**flow_specific_kwargs)
                    record_flow_result(scan_type, result=result)
@@ -220,12 +220,16 @@ def initiate_scan_flow(

        elif mode == 'parallel':
            # Parallel stage: wrap sub-Flows in Tasks and run them concurrently with a Prefect TaskRunner
            logger.info("="*60)
            logger.info(f"Parallel stage: {', '.join(enabled_flows)}")
            logger.info("="*60)
            futures = []

            # Submit all parallel sub-Flow tasks
            for scan_type, flow_func, flow_specific_kwargs in get_valid_flows(enabled_flows):
                logger.info("="*60)
                logger.info(f"Submitting parallel sub-Flow task: {scan_type}")
                logger.info("="*60)
                future = _run_subflow_task.submit(
                    scan_type=scan_type,
                    flow_func=flow_func,
@@ -246,12 +250,10 @@ def initiate_scan_flow(
                    record_flow_result(scan_type, error=e)

    # ==================== Done ====================
    logger.info("="*60)
    logger.info("✓ Scan task initialization finished")
    logger.info(f"Executed Flows: {', '.join(executed_flows)}")
    logger.info("="*60)

    # ==================== Return the result ====================
    return {

@@ -1,4 +1,4 @@
"""
Port scan Flow

Orchestrates the complete port-scan pipeline
@@ -10,25 +10,23 @@
- Configuration is parsed from YAML
"""

# Django environment initialization (takes effect on import)
from apps.common.prefect_django_setup import setup_django_for_prefect

import logging
import os
import subprocess
from datetime import datetime
from pathlib import Path
from typing import Callable

from prefect import flow

from apps.scan.handlers.scan_flow_handlers import (
    on_scan_flow_completed,
    on_scan_flow_failed,
    on_scan_flow_running,
)
from apps.scan.tasks.port_scan import (
    export_hosts_task,
    run_and_stream_save_ports_task,
)
from apps.scan.utils import build_scan_command, user_log, wait_for_system_load

logger = logging.getLogger(__name__)

@@ -40,28 +38,19 @@ def calculate_port_scan_timeout(
) -> int:
    """
    Calculate the timeout from the target count and port count

    Formula: timeout = targets × ports × base_per_pair
    Range: 60 s minimum, no upper bound

    Args:
        tool_config: tool configuration dict, including the port settings (ports, top-ports, etc.)
        file_path: path to the targets file (domain/IP list)
        base_per_pair: base time per port-target pair in seconds, default 0.5

    Returns:
        int: the calculated timeout in seconds, minimum 60
    """
    try:
        # 1. Count the targets
        result = subprocess.run(
            ['wc', '-l', file_path],
            capture_output=True,
@@ -69,88 +58,74 @@ def calculate_port_scan_timeout(
            check=True
        )
        target_count = int(result.stdout.strip().split()[0])

        # 2. Parse the port count
        port_count = _parse_port_count(tool_config)

        # 3. Calculate the timeout
        # Total work = targets × ports
        total_work = target_count * port_count
        timeout = max(60, int(total_work * base_per_pair))

        logger.info(
            "Calculated port scan timeout - targets: %d, ports: %d, total work: %d, timeout: %ds",
            target_count, port_count, total_work, timeout
        )
        return timeout

    except Exception as e:
        logger.warning("Timeout calculation failed: %s, falling back to 600 s", e)
        return 600

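# --- Editor's illustrative sketch (not part of the original diff), assuming the
# documented default base_per_pair=0.5:
#   100 targets × 100 ports × 0.5 s = 5000 s
#   1 target × 1 port × 0.5 s = 0.5 s, clamped up to the 60 s floor
def _demo_port_timeout(targets: int, ports: int, base_per_pair: float = 0.5) -> int:
    return max(60, int(targets * ports * base_per_pair))

assert _demo_port_timeout(100, 100) == 5000
assert _demo_port_timeout(1, 1) == 60
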
def _parse_port_count(tool_config: dict) -> int:
    """
    Parse the port count from the tool configuration

    Priority:
    1. top-ports: N → returns N
    2. ports: "80,443,8080" → returns the number of comma-separated entries
    3. ports: "1-1000" → returns the size of the range
    4. ports: "1-65535" → returns 65535
    5. default → returns 100 (naabu scans the top 100 ports by default)

    Args:
        tool_config: tool configuration dict

    Returns:
        int: port count
    """
    # Check the top-ports setting
    if 'top-ports' in tool_config:
        top_ports = tool_config['top-ports']
        if isinstance(top_ports, int) and top_ports > 0:
            return top_ports
        logger.warning("Invalid top-ports setting: %s, using the default", top_ports)

    # Check the ports setting
    if 'ports' in tool_config:
        ports_str = str(tool_config['ports']).strip()

        # Comma-separated port list: 80,443,8080
        if ',' in ports_str:
            return len([p.strip() for p in ports_str.split(',') if p.strip()])

        # Port range: 1-1000
        if '-' in ports_str:
            try:
                start, end = ports_str.split('-', 1)
                start_port = int(start.strip())
                end_port = int(end.strip())
                if 1 <= start_port <= end_port <= 65535:
                    return end_port - start_port + 1
                logger.warning("Invalid port range: %s, using the default", ports_str)
            except ValueError:
                logger.warning("Failed to parse port range: %s, using the default", ports_str)

        # Single port
        try:
            port = int(ports_str)
            if 1 <= port <= 65535:
                return 1
        except ValueError:
            logger.warning("Failed to parse port setting: %s, using the default", ports_str)

    # Default: naabu scans the top 100 ports
    return 100

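# --- Editor's illustrative checks (not part of the original diff) ---
assert _parse_port_count({'top-ports': 1000}) == 1000
assert _parse_port_count({'ports': '80,443,8080'}) == 3
assert _parse_port_count({'ports': '1-1000'}) == 1000
assert _parse_port_count({'ports': '8080'}) == 1
assert _parse_port_count({}) == 100  # naabu default
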
@@ -160,41 +135,38 @@ def _parse_port_count(tool_config: dict) -> int:
def _export_hosts(target_id: int, port_scan_dir: Path) -> tuple[str, int, str]:
    """
    Export the host list to a file

    The export content is decided automatically by the Target type:
    - DOMAIN: export subdomains from the Subdomain table
    - IP: write target.name directly
    - CIDR: expand every IP in the CIDR range

    Args:
        target_id: target ID
        port_scan_dir: port-scan working directory

    Returns:
        tuple: (hosts_file, host_count, target_type)
    """
    logger.info("Step 1: Exporting the host list")

    hosts_file = str(port_scan_dir / 'hosts.txt')
    export_result = export_hosts_task(
        target_id=target_id,
        output_file=hosts_file,
        batch_size=1000  # read 1000 rows at a time to keep memory usage low
    )

    host_count = export_result['total_count']
    target_type = export_result.get('target_type', 'unknown')

    logger.info(
        "✓ Host list exported - type: %s, file: %s, count: %d",
        target_type, export_result['output_file'], host_count
    )

    if host_count == 0:
        logger.warning("The target has no scannable hosts; port scan cannot run")

    return export_result['output_file'], host_count, target_type

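# --- Editor's illustrative note (not part of the original diff): expected export
# behaviour per target type (values assumed for illustration) ---
#   DOMAIN target "example.com"   -> hosts.txt lists its subdomains, one per line
#   IP target "203.0.113.10"      -> hosts.txt contains exactly that address
#   CIDR target "203.0.113.0/30"  -> hosts.txt expands to every address in the range
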
@@ -208,7 +180,7 @@ def _run_scans_sequentially(
) -> tuple[dict, int, list, list]:
    """
    Run port scan tasks serially

    Args:
        enabled_tools: dict of enabled tool configurations
        domains_file: path to the domains file
@@ -216,72 +188,56 @@ def _run_scans_sequentially(
        scan_id: scan task ID
        target_id: target ID
        target_name: target name (used in error logs)

    Returns:
        tuple: (tool_stats, processed_records, successful_tool_names, failed_tools)
        Note: port scanning streams its output and produces no result file

    Raises:
        RuntimeError: every tool failed
    """
    tool_stats = {}
    processed_records = 0
    failed_tools = []

    for tool_name, tool_config in enabled_tools.items():
        # Build the command
        try:
            command = build_scan_command(
                tool_name=tool_name,
                scan_type='port_scan',
                command_params={'domains_file': domains_file},
                tool_config=tool_config
            )
        except Exception as e:
            reason = f"Command build failed: {e}"
            logger.error("Failed to build %s command: %s", tool_name, e)
            failed_tools.append({'tool': tool_name, 'reason': reason})
            continue

        # Resolve the timeout (supports 'auto' for dynamic calculation)
        config_timeout = tool_config['timeout']
        if config_timeout == 'auto':
            config_timeout = calculate_port_scan_timeout(tool_config, str(domains_file))
            logger.info("✓ Tool %s auto-calculated timeout: %ds", tool_name, config_timeout)

        # Build the log file path
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_file = port_scan_dir / f"{tool_name}_{timestamp}.log"

        logger.info("Starting %s scan (timeout: %ds)...", tool_name, config_timeout)
        user_log(scan_id, "port_scan", f"Running {tool_name}: {command}")

        # Run the scan task
        try:
            # Call the task directly (serial execution)
            # Note: port scans stream to stdout and do not use output_file
            result = run_and_stream_save_ports_task(
                cmd=command,
                tool_name=tool_name,
                scan_id=scan_id,
                target_id=target_id,
                cwd=str(port_scan_dir),
                shell=True,
                batch_size=1000,
                timeout=config_timeout,
                log_file=str(log_file)
            )

            tool_stats[tool_name] = {
                'command': command,
                'result': result,
@@ -289,15 +245,10 @@ def _run_scans_sequentially(
            }
            tool_records = result.get('processed_records', 0)
            processed_records += tool_records
            logger.info("✓ Tool %s finished streaming - records: %d", tool_name, tool_records)
            user_log(scan_id, "port_scan", f"{tool_name} completed: found {tool_records} ports")

        except subprocess.TimeoutExpired:
            reason = f"timeout after {config_timeout}s"
            failed_tools.append({'tool': tool_name, 'reason': reason})
            logger.warning(
@@ -307,40 +258,39 @@ def _run_scans_sequentially(
            )
            user_log(scan_id, "port_scan", f"{tool_name} failed: {reason}", "error")
        except Exception as exc:
            reason = str(exc)
            failed_tools.append({'tool': tool_name, 'reason': reason})
            logger.error("Tool %s failed: %s", tool_name, exc, exc_info=True)
            user_log(scan_id, "port_scan", f"{tool_name} failed: {reason}", "error")

    if failed_tools:
        logger.warning(
            "The following scan tools failed: %s",
            ', '.join([f['tool'] for f in failed_tools])
        )

    if not tool_stats:
        error_details = "; ".join([f"{f['tool']}: {f['reason']}" for f in failed_tools])
        logger.warning("All port scan tools failed - target: %s, failed tools: %s", target_name, error_details)
        # Return an empty result instead of raising, so the scan can continue
        return {}, 0, [], failed_tools

    # Compute the list of successful tools
    successful_tool_names = [
        name for name in enabled_tools
        if name not in [f['tool'] for f in failed_tools]
    ]

    logger.info(
        "✓ Serial port scan finished - succeeded: %d/%d (successful: %s, failed: %s)",
        len(tool_stats), len(enabled_tools),
        ', '.join(successful_tool_names) if successful_tool_names else 'none',
        ', '.join([f['tool'] for f in failed_tools]) if failed_tools else 'none'
    )

    return tool_stats, processed_records, successful_tool_names, failed_tools


@flow(
    name="port_scan",
    log_prints=True,
    on_running=[on_scan_flow_running],
    on_completion=[on_scan_flow_completed],
@@ -355,19 +305,19 @@ def port_scan_flow(
) -> dict:
    """
    Port scan Flow

    Main responsibilities:
    1. Scan the target domains/IPs for open ports
    2. Save host + ip + port triples to the HostPortMapping table

    Output assets:
    - HostPortMapping: host-port mapping (host + ip + port triples)

    Workflow:
        Step 0: Create the working directory
        Step 1: Export the domain list to a file (consumed by the scan tools)
        Step 2: Parse the configuration and resolve the enabled tools
        Step 3: Run the scan tools serially and stream their output into the database

    Args:
        scan_id: scan task ID
@@ -377,35 +327,15 @@ def port_scan_flow(
        enabled_tools: dict of enabled tool configurations

    Returns:
        dict: scan result

    Raises:
        ValueError: configuration error
        RuntimeError: execution failure

    Note:
        Port scan tools such as naabu resolve domains to IPs and emit host + ip + port triples.
        One host may map to several IPs (CDN, load balancing), hence the triple mapping table.
    """
    try:
        wait_for_system_load(context="port_scan_flow")

        if scan_id is None:
            raise ValueError("scan_id must not be empty")
        if not target_name:
@@ -416,25 +346,20 @@ def port_scan_flow(
            raise ValueError("scan_workspace_dir must not be empty")
        if not enabled_tools:
            raise ValueError("enabled_tools must not be empty")

        logger.info(
            "Starting port scan - Scan ID: %s, Target: %s, Workspace: %s",
            scan_id, target_name, scan_workspace_dir
        )
        user_log(scan_id, "port_scan", "Starting port scan")

        # Step 0: Create the working directory
        from apps.scan.utils import setup_scan_directory
        port_scan_dir = setup_scan_directory(scan_workspace_dir, 'port_scan')

        # Step 1: Export the host list
        hosts_file, host_count, target_type = _export_hosts(target_id, port_scan_dir)

        if host_count == 0:
            logger.warning("Skipping port scan: no hosts to scan - Scan ID: %s", scan_id)
            user_log(scan_id, "port_scan", "Skipped: no hosts to scan", "warning")
@@ -457,14 +382,11 @@ def port_scan_flow(
                    'details': {}
                }
            }

        # Step 2: Tool configuration info
        logger.info("Step 2: Tool configuration info")
        logger.info("✓ Enabled tools: %s", ', '.join(enabled_tools.keys()))

        # Step 3: Run the scan tools serially
        logger.info("Step 3: Running scan tools serially")
        tool_stats, processed_records, successful_tool_names, failed_tools = _run_scans_sequentially(
@@ -475,15 +397,13 @@ def port_scan_flow(
            target_id=target_id,
            target_name=target_name
        )

        logger.info("✓ Port scan finished - ports found: %d", processed_records)
        user_log(scan_id, "port_scan", f"port_scan completed: found {processed_records} ports")

        # Build the list of executed tasks dynamically
        executed_tasks = ['export_hosts', 'parse_config']
        executed_tasks.extend([f'run_and_stream_save_ports ({tool})' for tool in tool_stats])

        return {
            'success': True,
            'scan_id': scan_id,

backend/apps/scan/flows/screenshot_flow.py (new file, 208 lines)
@@ -0,0 +1,208 @@
"""
Screenshot Flow

Orchestrates the complete screenshot pipeline:
1. Fetch the URL list from the database (websites and/or endpoints)
2. Capture screenshots in batches and save snapshots
3. Sync them to the asset tables

Two modes are supported:
1. Legacy mode (backward compatible): fetch URLs from the database via target_id
2. Provider mode: fetch URLs from any data source via a TargetProvider
"""

import logging
from typing import Optional

from prefect import flow

from apps.scan.handlers.scan_flow_handlers import (
    on_scan_flow_completed,
    on_scan_flow_failed,
    on_scan_flow_running,
)
from apps.scan.providers import TargetProvider
from apps.scan.services.target_export_service import DataSource, get_urls_with_fallback
from apps.scan.tasks.screenshot import capture_screenshots_task
from apps.scan.utils import user_log, wait_for_system_load

logger = logging.getLogger(__name__)

# Mapping from URL source names to DataSource constants
_SOURCE_MAPPING = {
    'websites': DataSource.WEBSITE,
    'endpoints': DataSource.ENDPOINT,
}


def _parse_screenshot_config(enabled_tools: dict) -> dict:
    """Parse the screenshot configuration"""
    playwright_config = enabled_tools.get('playwright', {})
    return {
        'concurrency': playwright_config.get('concurrency', 5),
        'url_sources': playwright_config.get('url_sources', ['websites'])
    }

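# --- Editor's illustrative check (not part of the original diff): the defaults
# apply when no 'playwright' section is configured ---
assert _parse_screenshot_config({}) == {'concurrency': 5, 'url_sources': ['websites']}
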
def _map_url_sources_to_data_sources(url_sources: list[str]) -> list[str]:
    """Map the configured url_sources to DataSource constants"""
    sources = []
    for source in url_sources:
        if source in _SOURCE_MAPPING:
            sources.append(_SOURCE_MAPPING[source])
        else:
            logger.warning("Unknown URL source: %s, skipping", source)

    # Append the default fallback (built from subdomains)
    sources.append(DataSource.DEFAULT)
    return sources

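# --- Editor's illustrative check (not part of the original diff) ---
# Unknown names are logged and dropped, and DataSource.DEFAULT is always appended
# as the subdomain-based fallback:
assert _map_url_sources_to_data_sources(['websites', 'bogus']) == [
    DataSource.WEBSITE, DataSource.DEFAULT
]
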
def _collect_urls_from_provider(provider: TargetProvider) -> tuple[list[str], str, list[str]]:
    """Collect URLs from a Provider"""
    logger.info("Fetching URLs in Provider mode - Provider: %s", type(provider).__name__)
    urls = list(provider.iter_urls())

    blacklist_filter = provider.get_blacklist_filter()
    if blacklist_filter:
        urls = [url for url in urls if blacklist_filter.is_allowed(url)]

    return urls, 'provider', ['provider']


def _collect_urls_from_database(
    target_id: int,
    url_sources: list[str]
) -> tuple[list[str], str, list[str]]:
    """Collect URLs from the database (with blacklist filtering and fallback)"""
    data_sources = _map_url_sources_to_data_sources(url_sources)
    result = get_urls_with_fallback(target_id, sources=data_sources)
    return result['urls'], result['source'], result['tried_sources']


def _build_empty_result(scan_id: int, target_name: str) -> dict:
    """Build an empty result"""
    return {
        'success': True,
        'scan_id': scan_id,
        'target': target_name,
        'total_urls': 0,
        'successful': 0,
        'failed': 0,
        'synced': 0
    }


@flow(
    name="screenshot",
    log_prints=True,
    on_running=[on_scan_flow_running],
    on_completion=[on_scan_flow_completed],
    on_failure=[on_scan_flow_failed],
)
def screenshot_flow(
    scan_id: int,
    target_name: str,
    target_id: int,
    scan_workspace_dir: str,
    enabled_tools: dict,
    provider: Optional[TargetProvider] = None
) -> dict:
    """
    Screenshot Flow

    Two modes are supported:
    1. Legacy mode (backward compatible): fetch URLs from the database via target_id
    2. Provider mode: fetch URLs from any data source via a TargetProvider

    Args:
        scan_id: scan task ID
        target_name: target name
        target_id: target ID
        scan_workspace_dir: scan workspace directory
        enabled_tools: enabled tool configuration
        provider: TargetProvider instance (new mode, optional)

    Returns:
        dict with the screenshot results
    """
    try:
        # Load check: wait until system resources are sufficient
        wait_for_system_load(context="screenshot_flow")

        mode = 'Provider' if provider else 'Legacy'
        logger.info(
            "Starting screenshot scan - Scan ID: %s, Target: %s, Mode: %s",
            scan_id, target_name, mode
        )
        user_log(scan_id, "screenshot", "Starting screenshot capture")

        # Step 1: Parse the configuration
        config = _parse_screenshot_config(enabled_tools)
        concurrency = config['concurrency']
        logger.info("Screenshot config - concurrency: %d, URL sources: %s", concurrency, config['url_sources'])

        # Step 2: Collect the URL list
        if provider is not None:
            urls, source_info, tried_sources = _collect_urls_from_provider(provider)
        else:
            urls, source_info, tried_sources = _collect_urls_from_database(
                target_id, config['url_sources']
            )

        logger.info(
            "URL collection finished - source: %s, count: %d, tried: %s",
            source_info, len(urls), tried_sources
        )

        if not urls:
            logger.warning("No URLs to capture; skipping the screenshot task")
            user_log(scan_id, "screenshot", "Skipped: no URLs to capture", "warning")
            return _build_empty_result(scan_id, target_name)

        user_log(
            scan_id, "screenshot",
            f"Found {len(urls)} URLs to capture (source: {source_info})"
        )

        # Step 3: Capture screenshots in batches
        logger.info("Batch screenshot capture - %d URLs", len(urls))
        capture_result = capture_screenshots_task(
            urls=urls,
            scan_id=scan_id,
            target_id=target_id,
            config={'concurrency': concurrency}
        )

        # Step 4: Sync to the asset tables
        logger.info("Syncing screenshots to the asset tables")
        from apps.asset.services.screenshot_service import ScreenshotService
        synced = ScreenshotService().sync_screenshots_to_asset(scan_id, target_id)

        total = capture_result['total']
        successful = capture_result['successful']
        failed = capture_result['failed']

        logger.info(
            "✓ Screenshots finished - total: %d, successful: %d, failed: %d, synced: %d",
            total, successful, failed, synced
        )
        user_log(
            scan_id, "screenshot",
            f"Screenshot completed: {successful}/{total} captured, {synced} synced"
        )

        return {
            'success': True,
            'scan_id': scan_id,
            'target': target_name,
            'total_urls': total,
            'successful': successful,
            'failed': failed,
            'synced': synced
        }

    except Exception:
        logger.exception("Screenshot Flow failed")
        user_log(scan_id, "screenshot", "Screenshot failed", "error")
        raise

@@ -1,4 +1,3 @@
"""
Site scan Flow

@@ -11,303 +10,319 @@
- Configuration is parsed from YAML
"""

import logging
import os
import subprocess
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Optional

from prefect import flow

# Django environment initialization (takes effect on import)
from apps.common.prefect_django_setup import setup_django_for_prefect  # noqa: F401
from apps.scan.handlers.scan_flow_handlers import (
    on_scan_flow_completed,
    on_scan_flow_failed,
    on_scan_flow_running,
)
from apps.scan.tasks.site_scan import (
    export_site_urls_task,
    run_and_stream_save_websites_task,
)
from apps.scan.utils import build_scan_command, user_log, wait_for_system_load

logger = logging.getLogger(__name__)


@dataclass
class ScanContext:
    """Scan context bundling the scan parameters"""
    scan_id: int
    target_id: int
    target_name: str
    site_scan_dir: Path
    urls_file: str
    total_urls: int

def _count_file_lines(file_path: str) -> int:
|
||||
"""使用 wc -l 统计文件行数"""
|
||||
try:
|
||||
# 使用 wc -l 快速统计行数
|
||||
result = subprocess.run(
|
||||
['wc', '-l', file_path],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True
|
||||
)
|
||||
# wc -l 输出格式:行数 + 空格 + 文件名
|
||||
line_count = int(result.stdout.strip().split()[0])
|
||||
|
||||
# 计算 timeout:行数 × 每行基础时间,不低于最小值
|
||||
timeout = max(line_count * base_per_time, min_timeout)
|
||||
|
||||
logger.info(
|
||||
f"timeout 自动计算: 文件={file_path}, "
|
||||
f"行数={line_count}, 每行时间={base_per_time}秒, 最小值={min_timeout}秒, timeout={timeout}秒"
|
||||
)
|
||||
|
||||
return timeout
|
||||
|
||||
except Exception as e:
|
||||
# 如果 wc -l 失败,使用默认值
|
||||
logger.warning(f"wc -l 计算行数失败: {e},使用默认 timeout: {min_timeout}秒")
|
||||
return min_timeout
|
||||
return int(result.stdout.strip().split()[0])
|
||||
except (subprocess.CalledProcessError, ValueError, IndexError) as e:
|
||||
logger.warning("wc -l 计算行数失败: %s,返回 0", e)
|
||||
return 0
|
||||
|
||||
|
||||
def _calculate_timeout_by_line_count(
|
||||
file_path: str,
|
||||
base_per_time: int = 1,
|
||||
min_timeout: int = 60
|
||||
) -> int:
|
||||
"""
|
||||
根据文件行数计算 timeout
|
||||
|
||||
Args:
|
||||
file_path: 要统计行数的文件路径
|
||||
base_per_time: 每行的基础时间(秒),默认1秒
|
||||
min_timeout: 最小超时时间(秒),默认60秒
|
||||
|
||||
Returns:
|
||||
int: 计算出的超时时间(秒),不低于 min_timeout
|
||||
"""
|
||||
line_count = _count_file_lines(file_path)
|
||||
timeout = max(line_count * base_per_time, min_timeout)
|
||||
|
||||
logger.info(
|
||||
"timeout 自动计算: 文件=%s, 行数=%d, 每行时间=%d秒, timeout=%d秒",
|
||||
file_path, line_count, base_per_time, timeout
|
||||
)
|
||||
return timeout
|
||||
|
||||
|
||||
|
||||
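As a quick sanity check of the formula above (line_count × base_per_time, floored at min_timeout), a pure-Python line counter can stand in for wc -l on platforms where it is unavailable; this fallback is an illustrative assumption, not part of the module:

# Illustrative fallback: count lines in pure Python instead of shelling out to
# wc -l, then apply the same max(line_count * base_per_time, min_timeout) rule
# used by _calculate_timeout_by_line_count above.
def count_lines_py(file_path: str) -> int:
    try:
        with open(file_path, 'rb') as f:
            return sum(buf.count(b'\n') for buf in iter(lambda: f.read(1 << 20), b''))
    except OSError:
        return 0

def timeout_for(file_path: str, base_per_time: int = 1, min_timeout: int = 60) -> int:
    return max(count_lines_py(file_path) * base_per_time, min_timeout)

# A 10,000-line URL file at 1 s/line yields a 10,000 s timeout;
# an empty or unreadable file falls back to the 60 s floor.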
def _export_site_urls(target_id: int, site_scan_dir: Path, target_name: str = None) -> tuple[str, int, int]:
def _export_site_urls(
    target_id: int,
    site_scan_dir: Path
) -> tuple[str, int, int]:
    """
    Export site URLs to a file

    Args:
        target_id: target ID
        site_scan_dir: site scan directory
        target_name: target name (used to write defaults during lazy loading)

    Returns:
        tuple: (urls_file, total_urls, association_count)

    Raises:
        ValueError: when the URL count is 0
    """
    logger.info("Step 1: exporting the site URL list")

    urls_file = str(site_scan_dir / 'site_urls.txt')
    export_result = export_site_urls_task(
        target_id=target_id,
        output_file=urls_file,
        batch_size=1000  # process 1000 subdomains per batch
        batch_size=1000
    )

    total_urls = export_result['total_urls']
    association_count = export_result['association_count']  # host-port association count

    association_count = export_result['association_count']

    logger.info(
        "✓ Site URL export finished - file: %s, URLs: %d, associations: %d",
        export_result['output_file'],
        total_urls,
        association_count
        export_result['output_file'], total_urls, association_count
    )

    if total_urls == 0:
        logger.warning("No usable site URLs under this target; the site scan cannot run")
        # Do not raise; let the caller decide how to handle it
        # raise ValueError("No usable site URLs under this target; the site scan cannot run")

    return export_result['output_file'], total_urls, association_count


def _get_tool_timeout(tool_config: dict, urls_file: str) -> int:
    """Get the tool timeout (supports 'auto' for dynamic computation)"""
    config_timeout = tool_config.get('timeout', 300)

    if config_timeout == 'auto':
        return _calculate_timeout_by_line_count(urls_file, base_per_time=1)

    dynamic_timeout = _calculate_timeout_by_line_count(urls_file, base_per_time=1)
    return max(dynamic_timeout, config_timeout)


def _execute_single_tool(
    tool_name: str,
    tool_config: dict,
    ctx: ScanContext
) -> Optional[dict]:
    """
    Execute a single scan tool

    Returns:
        a result dict on success, None on failure
    """
    # Build the command
    try:
        command = build_scan_command(
            tool_name=tool_name,
            scan_type='site_scan',
            command_params={'url_file': ctx.urls_file},
            tool_config=tool_config
        )
    except (ValueError, KeyError) as e:
        logger.error("Failed to build %s command: %s", tool_name, e)
        return None

    timeout = _get_tool_timeout(tool_config, ctx.urls_file)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = ctx.site_scan_dir / f"{tool_name}_{timestamp}.log"

    logger.info(
        "Starting %s site scan - URLs: %d, timeout: %ds",
        tool_name, ctx.total_urls, timeout
    )
    user_log(ctx.scan_id, "site_scan", f"Running {tool_name}: {command}")

    try:
        result = run_and_stream_save_websites_task(
            cmd=command,
            tool_name=tool_name,
            scan_id=ctx.scan_id,
            target_id=ctx.target_id,
            cwd=str(ctx.site_scan_dir),
            shell=True,
            timeout=timeout,
            log_file=str(log_file)
        )

        tool_created = result.get('created_websites', 0)
        skipped = result.get('skipped_no_subdomain', 0) + result.get('skipped_failed', 0)

        logger.info(
            "✓ Tool %s finished - processed: %d, created: %d, skipped: %d",
            tool_name, result.get('processed_records', 0), tool_created, skipped
        )
        user_log(
            ctx.scan_id, "site_scan",
            f"{tool_name} completed: found {tool_created} websites"
        )

        return {'command': command, 'result': result, 'timeout': timeout}

    except subprocess.TimeoutExpired:
        logger.warning(
            "⚠️ Tool %s timed out - configured timeout: %ds (data parsed before the timeout was saved)",
            tool_name, timeout
        )
        user_log(
            ctx.scan_id, "site_scan",
            f"{tool_name} failed: timeout after {timeout}s", "error"
        )
    except (OSError, RuntimeError) as exc:
        logger.error("Tool %s failed: %s", tool_name, exc, exc_info=True)
        user_log(ctx.scan_id, "site_scan", f"{tool_name} failed: {exc}", "error")

    return None
def _run_scans_sequentially(
    enabled_tools: dict,
    urls_file: str,
    total_urls: int,
    site_scan_dir: Path,
    scan_id: int,
    target_id: int,
    target_name: str
    ctx: ScanContext
) -> tuple[dict, int, list, list]:
    """
    Run the site scan tools sequentially

    Args:
        enabled_tools: dict of enabled tool configs
        urls_file: URL file path
        total_urls: total URL count
        site_scan_dir: site scan directory
        scan_id: scan task ID
        target_id: target ID
        target_name: target name (for error logs)

    Returns:
        tuple: (tool_stats, processed_records, successful_tool_names, failed_tools)

    Raises:
        RuntimeError: when every tool fails
        tuple: (tool_stats, processed_records, successful_tools, failed_tools)
    """
    tool_stats = {}
    processed_records = 0
    failed_tools = []

    for tool_name, tool_config in enabled_tools.items():
        # 1. Build the full command (variable substitution)
        try:
            command = build_scan_command(
                tool_name=tool_name,
                scan_type='site_scan',
                command_params={
                    'url_file': urls_file
                },
                tool_config=tool_config
            )
        except Exception as e:
            reason = f"command build failed: {str(e)}"
            logger.error(f"Failed to build {tool_name} command: {e}")
            failed_tools.append({'tool': tool_name, 'reason': reason})
            continue

        # 2. Get the timeout (supports 'auto' for dynamic computation)
        config_timeout = tool_config.get('timeout', 300)
        if config_timeout == 'auto':
            # Compute the timeout dynamically
            timeout = calculate_timeout_by_line_count(tool_config, urls_file, base_per_time=1)
            logger.info(f"✓ Tool {tool_name} dynamically computed timeout: {timeout}s")
        result = _execute_single_tool(tool_name, tool_config, ctx)

        if result:
            tool_stats[tool_name] = result
            processed_records += result['result'].get('processed_records', 0)
        else:
            # Use the larger of the configured and dynamically computed timeouts
            dynamic_timeout = calculate_timeout_by_line_count(tool_config, urls_file, base_per_time=1)
            timeout = max(dynamic_timeout, config_timeout)

            # 2.1 Build the log file path (as in the port scan)
            from datetime import datetime
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            log_file = site_scan_dir / f"{tool_name}_{timestamp}.log"

            logger.info(
                "Starting %s site scan - URLs: %d, final timeout: %ds",
                tool_name, total_urls, timeout
            )
            user_log(scan_id, "site_scan", f"Running {tool_name}: {command}")

        # 3. Run the scan task
        try:
            # Stream the scan and save results in real time
            result = run_and_stream_save_websites_task(
                cmd=command,
                tool_name=tool_name,
                scan_id=scan_id,
                target_id=target_id,
                cwd=str(site_scan_dir),
                shell=True,
                timeout=timeout,
                log_file=str(log_file)
            )

            tool_stats[tool_name] = {
                'command': command,
                'result': result,
                'timeout': timeout
            }
            tool_records = result.get('processed_records', 0)
            tool_created = result.get('created_websites', 0)
            processed_records += tool_records

            logger.info(
                "✓ Tool %s streaming finished - processed records: %d, created sites: %d, skipped: %d",
                tool_name,
                tool_records,
                tool_created,
                result.get('skipped_no_subdomain', 0) + result.get('skipped_failed', 0)
            )
            user_log(scan_id, "site_scan", f"{tool_name} completed: found {tool_created} websites")

        except subprocess.TimeoutExpired as exc:
            # Handle timeouts separately
            reason = f"timeout after {timeout}s"
            failed_tools.append({'tool': tool_name, 'reason': reason})
            logger.warning(
                "⚠️ Tool %s timed out - configured timeout: %ds\n"
                "Note: site data parsed before the timeout was saved to the database, but the scan did not fully complete.",
                tool_name, timeout
            )
            user_log(scan_id, "site_scan", f"{tool_name} failed: {reason}", "error")
        except Exception as exc:
            # Other exceptions
            reason = str(exc)
            failed_tools.append({'tool': tool_name, 'reason': reason})
            logger.error("Tool %s failed: %s", tool_name, exc, exc_info=True)
            user_log(scan_id, "site_scan", f"{tool_name} failed: {reason}", "error")

            failed_tools.append({'tool': tool_name, 'reason': 'execution failed'})

    if failed_tools:
        logger.warning(
            "The following scan tools failed: %s",
            ', '.join([f['tool'] for f in failed_tools])
            ', '.join(f['tool'] for f in failed_tools)
        )

    if not tool_stats:
        error_details = "; ".join([f"{f['tool']}: {f['reason']}" for f in failed_tools])
        logger.warning("All site scan tools failed - target: %s, failed tools: %s", target_name, error_details)
        # Return an empty result instead of raising, so the scan can continue
        logger.warning(
            "All site scan tools failed - target: %s", ctx.target_name
        )
        return {}, 0, [], failed_tools

    # Derive the list of successful tools
    successful_tool_names = [name for name in enabled_tools.keys()
                             if name not in [f['tool'] for f in failed_tools]]

    successful_tools = [
        name for name in enabled_tools
        if name not in {f['tool'] for f in failed_tools}
    ]

    logger.info(
        "✓ Sequential site scan finished - successful: %d/%d (ok: %s, failed: %s)",
        len(tool_stats), len(enabled_tools),
        ', '.join(successful_tool_names) if successful_tool_names else 'none',
        ', '.join([f['tool'] for f in failed_tools]) if failed_tools else 'none'
        "✓ Site scan finished - successful: %d/%d",
        len(tool_stats), len(enabled_tools)
    )

    return tool_stats, processed_records, successful_tool_names, failed_tools

    return tool_stats, processed_records, successful_tools, failed_tools
def calculate_timeout(url_count: int, base: int = 600, per_url: int = 1) -> int:
    """
    Compute the scan timeout dynamically from the URL count
def _build_empty_result(
    scan_id: int,
    target_name: str,
    scan_workspace_dir: str,
    urls_file: str,
    association_count: int
) -> dict:
    """Build an empty result (when there are no URLs to scan)"""
    return {
        'success': True,
        'scan_id': scan_id,
        'target': target_name,
        'scan_workspace_dir': scan_workspace_dir,
        'urls_file': urls_file,
        'total_urls': 0,
        'association_count': association_count,
        'processed_records': 0,
        'created_websites': 0,
        'skipped_no_subdomain': 0,
        'skipped_failed': 0,
        'executed_tasks': ['export_site_urls'],
        'tool_stats': {
            'total': 0,
            'successful': 0,
            'failed': 0,
            'successful_tools': [],
            'failed_tools': [],
            'details': {}
        }
    }

    Rules:
    - base time: default 600 seconds (10 minutes)
    - extra time per URL: default 1 second

    Args:
        url_count: URL count, must be a positive integer
        base: base timeout in seconds, default 600
        per_url: extra seconds per URL, default 1
def _aggregate_tool_results(tool_stats: dict) -> tuple[int, int, int]:
    """Aggregate the per-tool results"""
    total_created = sum(
        s['result'].get('created_websites', 0) for s in tool_stats.values()
    )
    total_skipped_no_subdomain = sum(
        s['result'].get('skipped_no_subdomain', 0) for s in tool_stats.values()
    )
    total_skipped_failed = sum(
        s['result'].get('skipped_failed', 0) for s in tool_stats.values()
    )
    return total_created, total_skipped_no_subdomain, total_skipped_failed

    Returns:
        int: computed timeout in seconds, capped at max_timeout

    Raises:
        ValueError: when url_count is negative or 0
    """
    if url_count < 0:
        raise ValueError(f"URL count must not be negative: {url_count}")
    if url_count == 0:
        raise ValueError("URL count must not be 0")

    timeout = base + int(url_count * per_url)

    # No upper bound; the caller controls it as needed
    return timeout
def _validate_flow_params(
    scan_id: int,
    target_name: str,
    target_id: int,
    scan_workspace_dir: str
) -> None:
    """Validate the flow parameters"""
    if scan_id is None:
        raise ValueError("scan_id must not be empty")
    if not target_name:
        raise ValueError("target_name must not be empty")
    if target_id is None:
        raise ValueError("target_id must not be empty")
    if not scan_workspace_dir:
        raise ValueError("scan_workspace_dir must not be empty")
@flow(
    name="site_scan",
    name="site_scan",
    log_prints=True,
    on_running=[on_scan_flow_running],
    on_completion=[on_scan_flow_completed],
@@ -322,140 +337,83 @@ def site_scan_flow(
) -> dict:
    """
    Site scan flow

    Main responsibilities:
    1. Fetch all subdomains and their ports from the target, join them into URLs, and write them to a file
    2. Batch-request them with httpx and save to the database in real time (streaming)

    Workflow:
        Step 0: create the working directory
        Step 1: export the site URL list
        Step 2: parse the config and get the enabled tools
        Step 3: run the scan tools sequentially and save results in real time

    Args:
        scan_id: scan task ID
        target_name: target name
        target_id: target ID
        scan_workspace_dir: scan workspace directory
        enabled_tools: dict of enabled tool configs

    Returns:
        dict: {
            'success': bool,
            'scan_id': int,
            'target': str,
            'scan_workspace_dir': str,
            'urls_file': str,
            'total_urls': int,
            'association_count': int,
            'processed_records': int,
            'created_websites': int,
            'skipped_no_subdomain': int,
            'skipped_failed': int,
            'executed_tasks': list,
            'tool_stats': {
                'total': int,
                'successful': int,
                'failed': int,
                'successful_tools': list[str],
                'failed_tools': list[dict]
            }
        }

        dict: scan result

    Raises:
        ValueError: configuration error
        RuntimeError: execution failure
    """
    try:
        wait_for_system_load(context="site_scan_flow")

        logger.info(
            "="*60 + "\n" +
            "Starting site scan\n" +
            f"  Scan ID: {scan_id}\n" +
            f"  Target: {target_name}\n" +
            f"  Workspace: {scan_workspace_dir}\n" +
            "="*60
            "Starting site scan - Scan ID: %s, Target: %s, Workspace: %s",
            scan_id, target_name, scan_workspace_dir
        )

        # Parameter validation
        if scan_id is None:
            raise ValueError("scan_id must not be empty")
        if not target_name:
            raise ValueError("target_name must not be empty")
        if target_id is None:
            raise ValueError("target_id must not be empty")
        if not scan_workspace_dir:
            raise ValueError("scan_workspace_dir must not be empty")

        _validate_flow_params(scan_id, target_name, target_id, scan_workspace_dir)
        user_log(scan_id, "site_scan", "Starting site scan")

        # Step 0: create the working directory
        from apps.scan.utils import setup_scan_directory
        site_scan_dir = setup_scan_directory(scan_workspace_dir, 'site_scan')

        # Step 1: export the site URLs
        urls_file, total_urls, association_count = _export_site_urls(
            target_id, site_scan_dir, target_name
            target_id, site_scan_dir
        )

        if total_urls == 0:
            logger.warning("Skipping site scan: no site URLs to scan - Scan ID: %s", scan_id)
            user_log(scan_id, "site_scan", "Skipped: no site URLs to scan", "warning")
            return {
                'success': True,
                'scan_id': scan_id,
                'target': target_name,
                'scan_workspace_dir': scan_workspace_dir,
                'urls_file': urls_file,
                'total_urls': 0,
                'association_count': association_count,
                'processed_records': 0,
                'created_websites': 0,
                'skipped_no_subdomain': 0,
                'skipped_failed': 0,
                'executed_tasks': ['export_site_urls'],
                'tool_stats': {
                    'total': 0,
                    'successful': 0,
                    'failed': 0,
                    'successful_tools': [],
                    'failed_tools': [],
                    'details': {}
                }
            }

            return _build_empty_result(
                scan_id, target_name, scan_workspace_dir, urls_file, association_count
            )

        # Step 2: tool configuration
        logger.info("Step 2: tool configuration")
        logger.info(
            "✓ Enabled tools: %s",
            ', '.join(enabled_tools.keys())
        )

        logger.info("✓ Enabled tools: %s", ', '.join(enabled_tools))

        # Step 3: run the scan tools sequentially
        logger.info("Step 3: running scan tools sequentially and saving results in real time")
        tool_stats, processed_records, successful_tool_names, failed_tools = _run_scans_sequentially(
            enabled_tools=enabled_tools,
            urls_file=urls_file,
            total_urls=total_urls,
            site_scan_dir=site_scan_dir,
        ctx = ScanContext(
            scan_id=scan_id,
            target_id=target_id,
            target_name=target_name
            target_name=target_name,
            site_scan_dir=site_scan_dir,
            urls_file=urls_file,
            total_urls=total_urls
        )

        # Build the list of executed tasks dynamically

        tool_stats, processed_records, successful_tools, failed_tools = \
            _run_scans_sequentially(enabled_tools, ctx)

        # Aggregate results
        executed_tasks = ['export_site_urls', 'parse_config']
        executed_tasks.extend([f'run_and_stream_save_websites ({tool})' for tool in tool_stats.keys()])

        # Aggregate the results of all tools
        total_created = sum(stats['result'].get('created_websites', 0) for stats in tool_stats.values())
        total_skipped_no_subdomain = sum(stats['result'].get('skipped_no_subdomain', 0) for stats in tool_stats.values())
        total_skipped_failed = sum(stats['result'].get('skipped_failed', 0) for stats in tool_stats.values())

        # Record flow completion
        executed_tasks.extend(
            f'run_and_stream_save_websites ({tool})' for tool in tool_stats
        )

        total_created, total_skipped_no_sub, total_skipped_failed = \
            _aggregate_tool_results(tool_stats)

        logger.info("✓ Site scan finished - created sites: %d", total_created)
        user_log(scan_id, "site_scan", f"site_scan completed: found {total_created} websites")

        user_log(
            scan_id, "site_scan",
            f"site_scan completed: found {total_created} websites"
        )

        return {
            'success': True,
            'scan_id': scan_id,
@@ -466,25 +424,20 @@ def site_scan_flow(
            'association_count': association_count,
            'processed_records': processed_records,
            'created_websites': total_created,
            'skipped_no_subdomain': total_skipped_no_subdomain,
            'skipped_no_subdomain': total_skipped_no_sub,
            'skipped_failed': total_skipped_failed,
            'executed_tasks': executed_tasks,
            'tool_stats': {
                'total': len(enabled_tools),
                'successful': len(successful_tool_names),
                'successful': len(successful_tools),
                'failed': len(failed_tools),
                'successful_tools': successful_tool_names,
                'successful_tools': successful_tools,
                'failed_tools': failed_tools,
                'details': tool_stats
            }
        }

    except ValueError as e:
        logger.error("Configuration error: %s", e)

    except ValueError:
        raise
    except RuntimeError as e:
        logger.error("Runtime error: %s", e)
    except RuntimeError:
        raise
    except Exception as e:
        logger.exception("Site scan failed: %s", e)
        raise
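After this refactor, the per-tool plumbing travels through one ScanContext instead of seven positional parameters. A minimal sketch of driving the refactored helper directly; the tool config shown is an assumed example, not a shipped default:

# Sketch: invoking the refactored _run_scans_sequentially with an assumed
# httpx tool config (the config shape is a guess for illustration).
from pathlib import Path

ctx = ScanContext(
    scan_id=100,
    target_id=1,
    target_name="example.com",
    site_scan_dir=Path("/tmp/scan_100/site_scan"),
    urls_file="/tmp/scan_100/site_scan/site_urls.txt",
    total_urls=42,
)
enabled_tools = {"httpx": {"enabled": True, "timeout": "auto"}}  # assumed shape
tool_stats, processed, ok_tools, failed = _run_scans_sequentially(enabled_tools, ctx)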
File diff suppressed because it is too large
@@ -10,22 +10,18 @@ URL fetch main flow
- Run httpx validation uniformly (if enabled)
"""

# Django environment initialization
from apps.common.prefect_django_setup import setup_django_for_prefect

import logging
import os
from pathlib import Path
from datetime import datetime
from pathlib import Path

from prefect import flow

from apps.scan.handlers.scan_flow_handlers import (
    on_scan_flow_running,
    on_scan_flow_completed,
    on_scan_flow_failed,
    on_scan_flow_running,
)
from apps.scan.utils import user_log
from apps.scan.utils import user_log, wait_for_system_load

from .domain_name_url_fetch_flow import domain_name_url_fetch_flow
from .sites_url_fetch_flow import sites_url_fetch_flow
@@ -43,13 +39,10 @@ SITES_FILE_TOOLS = {'katana'}
POST_PROCESS_TOOLS = {'uro', 'httpx'}


def _classify_tools(enabled_tools: dict) -> tuple[dict, dict, dict, dict]:
    """
    Classify the enabled tools by input type

    Returns:
        tuple: (domain_name_tools, sites_file_tools, uro_config, httpx_config)
    """
@@ -76,23 +69,23 @@ def _classify_tools(enabled_tools: dict) -> tuple[dict, dict, dict, dict]:
def _merge_and_deduplicate_urls(result_files: list, url_fetch_dir: Path) -> tuple[str, int]:
    """Merge and deduplicate URLs"""
    from apps.scan.tasks.url_fetch import merge_and_deduplicate_urls_task

    merged_file = merge_and_deduplicate_urls_task(
        result_files=result_files,
        result_dir=str(url_fetch_dir)
    )

    # Count the unique URLs
    unique_url_count = 0
    if Path(merged_file).exists():
        with open(merged_file, 'r') as f:
        with open(merged_file, 'r', encoding='utf-8') as f:
            unique_url_count = sum(1 for line in f if line.strip())

    logger.info(
        "✓ URL merge/dedupe finished - merged file: %s, unique URLs: %d",
        merged_file, unique_url_count
    )

    return merged_file, unique_url_count
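merge_and_deduplicate_urls_task itself is imported from the tasks module and not shown in this diff. A minimal order-preserving merge/dedupe equivalent, for illustration only (the output file name is an assumption):

# Illustrative stand-in for merge_and_deduplicate_urls_task: concatenates the
# per-tool result files and drops duplicate URLs while preserving first-seen order.
from pathlib import Path

def merge_and_dedupe(result_files: list[str], result_dir: str) -> str:
    out_path = Path(result_dir) / "merged_urls.txt"  # assumed output name
    seen: set[str] = set()
    with open(out_path, "w", encoding="utf-8") as out:
        for rf in result_files:
            p = Path(rf)
            if not p.exists():
                continue
            with open(p, "r", encoding="utf-8", errors="ignore") as f:
                for line in f:
                    url = line.strip()
                    if url and url not in seen:
                        seen.add(url)
                        out.write(url + "\n")
    return str(out_path)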
@@ -103,12 +96,12 @@ def _clean_urls_with_uro(
) -> tuple[str, int, int]:
    """Clean the merged URL list with uro"""
    from apps.scan.tasks.url_fetch import clean_urls_task

    raw_timeout = uro_config.get('timeout', 60)
    whitelist = uro_config.get('whitelist')
    blacklist = uro_config.get('blacklist')
    filters = uro_config.get('filters')

    # Compute the timeout
    if isinstance(raw_timeout, str) and raw_timeout == 'auto':
        timeout = calculate_timeout_by_line_count(
@@ -124,7 +117,7 @@ def _clean_urls_with_uro(
        except (TypeError, ValueError):
            logger.warning("Invalid uro timeout config (%s), using default 60s", raw_timeout)
            timeout = 60

    result = clean_urls_task(
        input_file=merged_file,
        output_dir=str(url_fetch_dir),
@@ -133,12 +126,12 @@ def _clean_urls_with_uro(
        blacklist=blacklist,
        filters=filters
    )

    if result['success']:
        return result['output_file'], result['output_count'], result['removed_count']
    else:
        logger.warning("uro cleaning failed: %s, using the original merged file", result.get('error', 'unknown error'))
        return merged_file, result['input_count'], 0

    logger.warning("uro cleaning failed: %s, using the original merged file", result.get('error', 'unknown error'))
    return merged_file, result['input_count'], 0
def _validate_and_stream_save_urls(
@@ -151,25 +144,25 @@ def _validate_and_stream_save_urls(
    """Validate URL liveness with httpx and stream-save to the database"""
    from apps.scan.utils import build_scan_command
    from apps.scan.tasks.url_fetch import run_and_stream_save_urls_task

    logger.info("Validating URL liveness with httpx...")

    # Count the URLs awaiting validation
    try:
        with open(merged_file, 'r') as f:
        with open(merged_file, 'r', encoding='utf-8') as f:
            url_count = sum(1 for _ in f)
        logger.info("URLs awaiting validation: %d", url_count)
    except Exception as e:
    except OSError as e:
        logger.error("Failed to read the URL file: %s", e)
        return 0

    if url_count == 0:
        logger.warning("No URLs to validate")
        return 0

    # Build the httpx command
    command_params = {'url_file': merged_file}

    try:
        command = build_scan_command(
            tool_name='httpx',
@@ -177,21 +170,19 @@ def _validate_and_stream_save_urls(
            command_params=command_params,
            tool_config=httpx_config
        )
    except Exception as e:
    except (ValueError, KeyError) as e:
        logger.error("Failed to build the httpx command: %s", e)
        logger.warning("Falling back: saving all URLs directly (without liveness validation)")
        return _save_urls_to_database(merged_file, scan_id, target_id)

    # Compute the timeout
    raw_timeout = httpx_config.get('timeout', 'auto')
    timeout = 3600
    if isinstance(raw_timeout, str) and raw_timeout == 'auto':
        # Timeout by URL line count: 3 seconds per line, minimum 60 seconds
        timeout = max(60, url_count * 3)
        logger.info(
            "Auto-computed httpx timeout (by line count, 3s per line, 60s minimum): url_count=%d, timeout=%ds",
            url_count,
            timeout,
            url_count, timeout
        )
    else:
        try:
@@ -199,49 +190,44 @@ def _validate_and_stream_save_urls(
        except (TypeError, ValueError):
            timeout = 3600
        logger.info("Using the configured httpx timeout: %ds", timeout)

    # Build the log file path
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = url_fetch_dir / f"httpx_validation_{timestamp}.log"

    # Streaming execution
    try:
        result = run_and_stream_save_urls_task(
            cmd=command,
            tool_name='httpx',
            scan_id=scan_id,
            target_id=target_id,
            cwd=str(url_fetch_dir),
            shell=True,
            timeout=timeout,
            log_file=str(log_file)
        )

        saved = result.get('saved_urls', 0)
        logger.info(
            "✓ httpx validation finished - live URLs: %d (%.1f%%)",
            saved, (saved / url_count * 100) if url_count > 0 else 0
        )
        return saved

    except Exception as e:
        logger.error("httpx streaming validation failed: %s", e, exc_info=True)
        raise
    result = run_and_stream_save_urls_task(
        cmd=command,
        tool_name='httpx',
        scan_id=scan_id,
        target_id=target_id,
        cwd=str(url_fetch_dir),
        shell=True,
        timeout=timeout,
        log_file=str(log_file)
    )

    saved = result.get('saved_urls', 0)
    logger.info(
        "✓ httpx validation finished - live URLs: %d (%.1f%%)",
        saved, (saved / url_count * 100) if url_count > 0 else 0
    )
    return saved
def _save_urls_to_database(merged_file: str, scan_id: int, target_id: int) -> int:
    """Save URLs to the database (without liveness validation)"""
    from apps.scan.tasks.url_fetch import save_urls_task

    result = save_urls_task(
        urls_file=merged_file,
        scan_id=scan_id,
        target_id=target_id
    )

    saved_count = result.get('saved_urls', 0)
    logger.info("✓ URL save finished - saved: %d", saved_count)

    return saved_count


@@ -261,7 +247,7 @@ def url_fetch_flow(
) -> dict:
    """
    URL fetch main flow

    Workflow:
    1. Prepare the working directory
    2. Classify the tools by input type (domain_name / sites_file / post-processing)
@@ -271,36 +257,32 @@ def url_fetch_flow(
    4. Merge the results of all sub-flows and deduplicate
    5. uro deduplication (if enabled)
    6. httpx validation (if enabled)

    Args:
        scan_id: scan ID
        target_name: target name
        target_id: target ID
        scan_workspace_dir: scan working directory
        enabled_tools: enabled tool configs

    Returns:
        dict: scan result
    """
    try:
        # Load check: wait until system resources are sufficient
        wait_for_system_load(context="url_fetch_flow")

        logger.info(
            "="*60 + "\n" +
            "Starting URL fetch scan\n" +
            f"  Scan ID: {scan_id}\n" +
            f"  Target: {target_name}\n" +
            f"  Workspace: {scan_workspace_dir}\n" +
            "="*60
            "Starting URL fetch scan - Scan ID: %s, Target: %s, Workspace: %s",
            scan_id, target_name, scan_workspace_dir
        )
        user_log(scan_id, "url_fetch", "Starting URL fetch")

        # Step 1: prepare the working directory
        logger.info("Step 1: preparing the working directory")
        from apps.scan.utils import setup_scan_directory
        url_fetch_dir = setup_scan_directory(scan_workspace_dir, 'url_fetch')

        # Step 2: classify the tools (by input type)
        logger.info("Step 2: classifying tools")
        domain_name_tools, sites_file_tools, uro_config, httpx_config = _classify_tools(enabled_tools)

        logger.info(
@@ -317,15 +299,14 @@ def url_fetch_flow(
                "The URL fetch flow needs at least one enabled URL fetch tool (e.g. waymore, katana). "
                "httpx and uro are post-processing only and cannot be used alone."
            )

        # Step 3: run the sub-flows in parallel

        # Step 3: run the sub-flows
        all_result_files = []
        all_failed_tools = []
        all_successful_tools = []

        # 3a: passive URL collection based on domain_name (target_name), e.g. waymore

        # 3a: passive URL collection based on domain_name (e.g. waymore)
        if domain_name_tools:
            logger.info("Step 3a: running the domain_name-based passive URL collection sub-flow")
            tn_result = domain_name_url_fetch_flow(
                scan_id=scan_id,
                target_id=target_id,
@@ -336,10 +317,9 @@ def url_fetch_flow(
            all_result_files.extend(tn_result.get('result_files', []))
            all_failed_tools.extend(tn_result.get('failed_tools', []))
            all_successful_tools.extend(tn_result.get('successful_tools', []))

        # 3b: crawling (with sites_file as input)
        if sites_file_tools:
            logger.info("Step 3b: running the crawler sub-flow")
            crawl_result = sites_url_fetch_flow(
                scan_id=scan_id,
                target_id=target_id,
@@ -350,12 +330,13 @@ def url_fetch_flow(
            all_result_files.extend(crawl_result.get('result_files', []))
            all_failed_tools.extend(crawl_result.get('failed_tools', []))
            all_successful_tools.extend(crawl_result.get('successful_tools', []))

        # Check whether any tool succeeded
        if not all_result_files:
            error_details = "; ".join([f"{f['tool']}: {f['reason']}" for f in all_failed_tools])
            error_details = "; ".join([
                "%s: %s" % (f['tool'], f['reason']) for f in all_failed_tools
            ])
            logger.warning("All URL fetch tools failed - target: %s, details: %s", target_name, error_details)
            # Return an empty result instead of raising, so the scan can continue
            return {
                'success': True,
                'scan_id': scan_id,
@@ -366,31 +347,24 @@ def url_fetch_flow(
                'successful_tools': [],
                'message': 'All URL fetch tools produced no results'
            }

        # Step 4: merge and deduplicate the URLs
        logger.info("Step 4: merging and deduplicating URLs")
        merged_file, unique_url_count = _merge_and_deduplicate_urls(
        merged_file, _ = _merge_and_deduplicate_urls(
            result_files=all_result_files,
            url_fetch_dir=url_fetch_dir
        )

        # Step 5: clean the URLs with uro (if enabled)
        url_file_for_validation = merged_file
        uro_removed_count = 0

        if uro_config and uro_config.get('enabled', False):
            logger.info("Step 5: cleaning URLs with uro")
            url_file_for_validation, cleaned_count, uro_removed_count = _clean_urls_with_uro(
            url_file_for_validation, _, _ = _clean_urls_with_uro(
                merged_file=merged_file,
                uro_config=uro_config,
                url_fetch_dir=url_fetch_dir
            )
        else:
            logger.info("Step 5: skipping uro cleaning (not enabled)")

        # Step 6: validate liveness with httpx and save (if enabled)
        if httpx_config and httpx_config.get('enabled', False):
            logger.info("Step 6: validating URL liveness with httpx and stream-saving")
            saved_count = _validate_and_stream_save_urls(
                merged_file=url_file_for_validation,
                httpx_config=httpx_config,
@@ -399,17 +373,16 @@ def url_fetch_flow(
                target_id=target_id
            )
        else:
            logger.info("Step 6: saving to the database (httpx validation not enabled)")
            saved_count = _save_urls_to_database(
                merged_file=url_file_for_validation,
                scan_id=scan_id,
                target_id=target_id
            )

        # Record flow completion
        logger.info("✓ URL fetch finished - saved endpoints: %d", saved_count)
        user_log(scan_id, "url_fetch", f"url_fetch completed: found {saved_count} endpoints")

        user_log(scan_id, "url_fetch", "url_fetch completed: found %d endpoints" % saved_count)

        # Build the list of executed tasks
        executed_tasks = ['setup_directory', 'classify_tools']
        if domain_name_tools:
@@ -423,7 +396,7 @@ def url_fetch_flow(
            executed_tasks.append('httpx_validation_and_save')
        else:
            executed_tasks.append('save_urls')

        return {
            'success': True,
            'scan_id': scan_id,
@@ -439,7 +412,7 @@ def url_fetch_flow(
            'failed_tools': [f['tool'] for f in all_failed_tools]
        }
        }

    except Exception as e:
        logger.error("URL fetch scan failed: %s", e, exc_info=True)
        raise
@@ -1,5 +1,6 @@
from apps.common.prefect_django_setup import setup_django_for_prefect

"""
Vulnerability scan main flow
"""
import logging
from typing import Dict, Tuple

@@ -11,7 +12,7 @@ from apps.scan.handlers.scan_flow_handlers import (
    on_scan_flow_failed,
)
from apps.scan.configs.command_templates import get_command_template
from apps.scan.utils import user_log
from apps.scan.utils import user_log, wait_for_system_load
from .endpoints_vuln_scan_flow import endpoints_vuln_scan_flow


@@ -62,6 +63,9 @@ def vuln_scan_flow(
    - nuclei: general vulnerability scanning (streamed saves, supports template commit hash sync)
    """
    try:
        # Load check: wait until system resources are sufficient
        wait_for_system_load(context="vuln_scan_flow")

        if scan_id is None:
            raise ValueError("scan_id must not be empty")
        if not target_name:
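wait_for_system_load is now called at the top of every heavy flow (site scan, URL fetch, vulnerability scan). Its implementation is not shown in this diff; the following is purely an assumed sketch of what such a load gate typically does, using psutil:

# Assumed sketch only -- the real apps.scan.utils.wait_for_system_load is not
# shown in this diff. Blocks until CPU and memory usage drop below thresholds.
import logging
import time

import psutil

logger = logging.getLogger(__name__)

def wait_for_system_load_sketch(
    context: str,
    cpu_limit: float = 85.0,
    mem_limit: float = 90.0,
    poll_seconds: int = 10,
) -> None:
    while True:
        cpu = psutil.cpu_percent(interval=1)
        mem = psutil.virtual_memory().percent
        if cpu < cpu_limit and mem < mem_limit:
            return
        logger.info("[%s] load too high (cpu=%.1f%%, mem=%.1f%%), waiting...",
                    context, cpu, mem)
        time.sleep(poll_seconds)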
@@ -0,0 +1,18 @@
# Generated by Django 5.2.7 on 2026-01-07 14:03

from django.db import migrations, models


class Migration(migrations.Migration):

    dependencies = [
        ('scan', '0001_initial'),
    ]

    operations = [
        migrations.AddField(
            model_name='scan',
            name='cached_screenshots_count',
            field=models.IntegerField(default=0, help_text='Number of cached screenshots'),
        ),
    ]

backend/apps/scan/migrations/0003_add_wecom_fields.py (new file, 23 lines)
@@ -0,0 +1,23 @@
# Generated manually for WeCom notification support

from django.db import migrations, models


class Migration(migrations.Migration):

    dependencies = [
        ('scan', '0002_add_cached_screenshots_count'),
    ]

    operations = [
        migrations.AddField(
            model_name='notificationsettings',
            name='wecom_enabled',
            field=models.BooleanField(default=False, help_text='Whether WeCom notifications are enabled'),
        ),
        migrations.AddField(
            model_name='notificationsettings',
            name='wecom_webhook_url',
            field=models.URLField(blank=True, default='', help_text='WeCom bot webhook URL'),
        ),
    ]
@@ -84,6 +84,7 @@ class Scan(models.Model):
    cached_endpoints_count = models.IntegerField(default=0, help_text='Number of cached endpoints')
    cached_ips_count = models.IntegerField(default=0, help_text='Number of cached IP addresses')
    cached_directories_count = models.IntegerField(default=0, help_text='Number of cached directories')
    cached_screenshots_count = models.IntegerField(default=0, help_text='Number of cached screenshots')
    cached_vulns_total = models.IntegerField(default=0, help_text='Number of cached vulnerabilities')
    cached_vulns_critical = models.IntegerField(default=0, help_text='Number of cached critical vulnerabilities')
    cached_vulns_high = models.IntegerField(default=0, help_text='Number of cached high-severity vulnerabilities')
@@ -1,8 +1,14 @@
"""Notification system data models"""

from django.db import models
import logging
from datetime import timedelta

from .types import NotificationLevel, NotificationCategory
from django.db import models
from django.utils import timezone

from .types import NotificationCategory, NotificationLevel

logger = logging.getLogger(__name__)


class NotificationSettings(models.Model):
@@ -10,31 +16,34 @@ class NotificationSettings(models.Model):
    Notification settings (singleton model)
    Stores the Discord webhook config and the per-category notification switches
    """

    # Discord config
    discord_enabled = models.BooleanField(default=False, help_text='Whether Discord notifications are enabled')
    discord_webhook_url = models.URLField(blank=True, default='', help_text='Discord webhook URL')

    # WeCom config
    wecom_enabled = models.BooleanField(default=False, help_text='Whether WeCom notifications are enabled')
    wecom_webhook_url = models.URLField(blank=True, default='', help_text='WeCom bot webhook URL')

    # Per-category switches (stored in a JSONField)
    categories = models.JSONField(
        default=dict,
        help_text='Per-category notification switches, e.g. {"scan": true, "vulnerability": true, "asset": true, "system": false}'
    )

    # Timestamps
    created_at = models.DateTimeField(auto_now_add=True)
    updated_at = models.DateTimeField(auto_now=True)

    class Meta:
        db_table = 'notification_settings'
        verbose_name = 'Notification settings'
        verbose_name_plural = 'Notification settings'

    def save(self, *args, **kwargs):
        # Singleton pattern: force a single row
        self.pk = 1
        self.pk = 1  # singleton
        super().save(*args, **kwargs)

    @classmethod
    def get_instance(cls) -> 'NotificationSettings':
        """Get or create the singleton instance"""
@@ -52,7 +61,7 @@ class NotificationSettings(models.Model):
            }
        )
        return obj

    def is_category_enabled(self, category: str) -> bool:
        """Check whether notifications are enabled for the given category"""
        return self.categories.get(category, False)
@@ -60,10 +69,9 @@ class NotificationSettings(models.Model):

class Notification(models.Model):
    """Notification model"""

    id = models.AutoField(primary_key=True)

    # Notification category

    category = models.CharField(
        max_length=20,
        choices=NotificationCategory.choices,
@@ -71,8 +79,7 @@ class Notification(models.Model):
        db_index=True,
        help_text='Notification category'
    )

    # Notification level

    level = models.CharField(
        max_length=20,
        choices=NotificationLevel.choices,
@@ -80,16 +87,15 @@ class Notification(models.Model):
        db_index=True,
        help_text='Notification level'
    )

    title = models.CharField(max_length=200, help_text='Notification title')
    message = models.CharField(max_length=2000, help_text='Notification body')

    # Timestamps

    created_at = models.DateTimeField(auto_now_add=True, help_text='Creation time')

    is_read = models.BooleanField(default=False, help_text='Whether the notification has been read')
    read_at = models.DateTimeField(null=True, blank=True, help_text='Read time')

    class Meta:
        db_table = 'notification'
        verbose_name = 'Notification'
@@ -101,44 +107,26 @@ class Notification(models.Model):
            models.Index(fields=['level', '-created_at']),
            models.Index(fields=['is_read', '-created_at']),
        ]

    def __str__(self):
        return f"{self.get_level_display()} - {self.title}"

    @classmethod
    def cleanup_old_notifications(cls):
        """
        Delete notifications older than 15 days (hard-coded)

        Returns:
            int: number of deleted notifications
        """
        from datetime import timedelta
        from django.utils import timezone

        # Hard-coded: keep only the last 15 days of notifications
    def cleanup_old_notifications(cls) -> int:
        """Delete notifications older than 15 days"""
        cutoff_date = timezone.now() - timedelta(days=15)
        delete_result = cls.objects.filter(created_at__lt=cutoff_date).delete()

        return delete_result[0] if delete_result[0] else 0

        deleted_count, _ = cls.objects.filter(created_at__lt=cutoff_date).delete()
        return deleted_count or 0

    def save(self, *args, **kwargs):
        """
        Override save to automatically prune old notifications when a new one is created
        """
        """Override save to automatically prune old notifications when a new one is created"""
        is_new = self.pk is None
        super().save(*args, **kwargs)

        # Run the cleanup only when a new notification is created (prunes anything older than 15 days)

        if is_new:
            try:
                deleted_count = self.__class__.cleanup_old_notifications()
                if deleted_count > 0:
                    import logging
                    logger = logging.getLogger(__name__)
                    logger.info(f"Automatically pruned {deleted_count} notifications older than 15 days")
            except Exception as e:
                # Cleanup failures must not break notification creation
                import logging
                logger = logging.getLogger(__name__)
                logger.warning(f"Automatic notification cleanup failed: {e}")
                logger.info("Automatically pruned %d notifications older than 15 days", deleted_count)
            except Exception:
                logger.warning("Automatic notification cleanup failed", exc_info=True)
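Because save() pins pk=1, every save writes the same row, and creating a notification implicitly triggers the 15-day retention pass. A short usage sketch (the level/category literals are assumptions about the choices enums, not confirmed values):

# Usage sketch for the singleton settings row and the retention hook.
settings = NotificationSettings.get_instance()   # creates row pk=1 on first call
settings.wecom_enabled = True
settings.save()                                  # still row pk=1: a singleton

# Creating a notification implicitly prunes anything older than 15 days.
Notification.objects.create(
    category='system', level='info',             # assumed choice values
    title='Demo', message='Retention check',
)
deleted = Notification.cleanup_old_notifications()  # can also be called directly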
@@ -1,52 +1,70 @@
"""Notification system repository layer"""

import logging
from typing import TypedDict
from dataclasses import dataclass
from typing import Optional

from django.db.models import QuerySet
from django.utils import timezone

from apps.common.decorators import auto_ensure_db_connection
from .models import Notification, NotificationSettings

from .models import Notification, NotificationSettings

logger = logging.getLogger(__name__)


class NotificationSettingsData(TypedDict):
    """Notification settings data structure"""
@dataclass
class NotificationSettingsData:
    """Notification settings update payload"""

    discord_enabled: bool
    discord_webhook_url: str
    categories: dict[str, bool]
    wecom_enabled: bool = False
    wecom_webhook_url: str = ''


@auto_ensure_db_connection
class NotificationSettingsRepository:
    """Notification settings repository"""

    def get_settings(self) -> NotificationSettings:
        """Get the notification settings singleton"""
        return NotificationSettings.get_instance()

    def update_settings(
        self,
        discord_enabled: bool,
        discord_webhook_url: str,
        categories: dict[str, bool]
    ) -> NotificationSettings:

    def update_settings(self, data: NotificationSettingsData) -> NotificationSettings:
        """Update the notification settings"""
        settings = NotificationSettings.get_instance()
        settings.discord_enabled = discord_enabled
        settings.discord_webhook_url = discord_webhook_url
        settings.categories = categories
        settings.discord_enabled = data.discord_enabled
        settings.discord_webhook_url = data.discord_webhook_url
        settings.wecom_enabled = data.wecom_enabled
        settings.wecom_webhook_url = data.wecom_webhook_url
        settings.categories = data.categories
        settings.save()
        return settings

    def is_category_enabled(self, category: str) -> bool:
        """Check whether the given category is enabled"""
        settings = self.get_settings()
        return settings.is_category_enabled(category)
        return self.get_settings().is_category_enabled(category)


@auto_ensure_db_connection
class DjangoNotificationRepository:
    def get_filtered(self, level: str | None = None, unread: bool | None = None):
    """Notification data repository"""

    def get_filtered(
        self,
        level: Optional[str] = None,
        unread: Optional[bool] = None
    ) -> QuerySet[Notification]:
        """
        Get the filtered notification list

        Args:
            level: notification level filter
            unread: read-state filter (True=unread, False=read, None=all)
        """
        queryset = Notification.objects.all()

        if level:
@@ -60,16 +78,24 @@ class DjangoNotificationRepository:
        return queryset.order_by("-created_at")

    def get_unread_count(self) -> int:
        """Get the number of unread notifications"""
        return Notification.objects.filter(is_read=False).count()

    def mark_all_as_read(self) -> int:
        updated = Notification.objects.filter(is_read=False).update(
        """Mark all notifications as read; returns the number updated"""
        return Notification.objects.filter(is_read=False).update(
            is_read=True,
            read_at=timezone.now(),
        )
        return updated

    def create(self, title: str, message: str, level: str, category: str = 'system') -> Notification:
    def create(
        self,
        title: str,
        message: str,
        level: str,
        category: str = 'system'
    ) -> Notification:
        """Create a new notification"""
        return Notification.objects.create(
            category=category,
            level=level,
@@ -60,13 +60,12 @@ def push_to_external_channels(notification: Notification) -> None:
        except Exception as e:
            logger.warning(f"Discord push failed: {e}")

    # Future extension: Slack
    # if settings.slack_enabled and settings.slack_webhook_url:
    #     _send_slack(notification, settings.slack_webhook_url)

    # Future extension: Telegram
    # if settings.telegram_enabled and settings.telegram_bot_token:
    #     _send_telegram(notification, settings.telegram_chat_id)
    # WeCom channel
    if settings.wecom_enabled and settings.wecom_webhook_url:
        try:
            _send_wecom(notification, settings.wecom_webhook_url)
        except Exception as e:
            logger.warning(f"WeCom push failed: {e}")


def _send_discord(notification: Notification, webhook_url: str) -> bool:
@@ -103,6 +102,41 @@ def _send_discord(notification: Notification, webhook_url: str) -> bool:
        return False


def _send_wecom(notification: Notification, webhook_url: str) -> bool:
    """Send to a WeCom bot webhook"""
    try:
        emoji = CATEGORY_EMOJI.get(notification.category, '📢')

        # WeCom markdown format
        content = f"""**{emoji} {notification.title}**
> Level: {notification.get_level_display()}
> Category: {notification.get_category_display()}

{notification.message}"""

        payload = {
            'msgtype': 'markdown',
            'markdown': {'content': content}
        }

        response = requests.post(webhook_url, json=payload, timeout=10)

        if response.status_code == 200:
            result = response.json()
            if result.get('errcode') == 0:
                logger.info(f"WeCom notification sent - {notification.title}")
                return True
            logger.warning(f"WeCom send failed - errcode: {result.get('errcode')}, errmsg: {result.get('errmsg')}")
            return False

        logger.warning(f"WeCom send failed - status code: {response.status_code}")
        return False

    except requests.RequestException as e:
        logger.error(f"WeCom network error: {e}")
        return False
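The WeCom robot API acknowledges delivery with errcode 0 in the JSON body, which is why _send_wecom checks the body as well as the HTTP status. A minimal standalone smoke test of a webhook URL, using the simpler text msgtype:

# Standalone webhook smoke test: WeCom robots accept a plain text msgtype too.
import requests

def wecom_ping(webhook_url: str) -> bool:
    payload = {'msgtype': 'text', 'text': {'content': 'xingrin notification test'}}
    resp = requests.post(webhook_url, json=payload, timeout=10)
    return resp.status_code == 200 and resp.json().get('errcode') == 0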
# ============================================================
# Settings service
# ============================================================
@@ -121,31 +155,43 @@ class NotificationSettingsService:
                'enabled': settings.discord_enabled,
                'webhookUrl': settings.discord_webhook_url,
            },
            'wecom': {
                'enabled': settings.wecom_enabled,
                'webhookUrl': settings.wecom_webhook_url,
            },
            'categories': settings.categories,
        }

    def update_settings(self, data: dict) -> dict:
        """Update the notification settings

        Note: DRF's CamelCaseJSONParser converts the frontend's webhookUrl to webhook_url
        """
        discord_data = data.get('discord', {})
        wecom_data = data.get('wecom', {})
        categories = data.get('categories', {})

        # After CamelCaseJSONParser the field name is webhook_url
        webhook_url = discord_data.get('webhook_url', '')

        discord_webhook_url = discord_data.get('webhook_url', '')
        wecom_webhook_url = wecom_data.get('webhook_url', '')

        settings = self.repo.update_settings(
            discord_enabled=discord_data.get('enabled', False),
            discord_webhook_url=webhook_url,
            discord_webhook_url=discord_webhook_url,
            wecom_enabled=wecom_data.get('enabled', False),
            wecom_webhook_url=wecom_webhook_url,
            categories=categories,
        )

        return {
            'discord': {
                'enabled': settings.discord_enabled,
                'webhookUrl': settings.discord_webhook_url,
            },
            'wecom': {
                'enabled': settings.wecom_enabled,
                'webhookUrl': settings.wecom_webhook_url,
            },
            'categories': settings.categories,
        }
@@ -147,10 +147,10 @@ class FlowOrchestrator:
                return True
            return False

        # Other scan types: check tools
        # Other scan types (including screenshot): check tools
        tools = scan_config.get('tools', {})
        for tool_config in tools.values():
            if tool_config.get('enabled', False):
            if isinstance(tool_config, dict) and tool_config.get('enabled', False):
                return True

        return False
@@ -222,6 +222,10 @@ class FlowOrchestrator:
            from apps.scan.flows.vuln_scan import vuln_scan_flow
            return vuln_scan_flow

        elif scan_type == 'screenshot':
            from apps.scan.flows.screenshot_flow import screenshot_flow
            return screenshot_flow

        else:
            logger.warning(f"Unimplemented scan type: {scan_type}")
            return None
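The new isinstance guard matters because a tools mapping can carry non-dict scalar entries, on which .get() would raise. A quick illustration of the enabled check, with an assumed config shape:

# Assumed scan_config shape illustrating why the isinstance guard is needed:
scan_config = {
    'tools': {
        'httpx': {'enabled': True, 'timeout': 'auto'},
        'note': 'tuned for small targets',   # a scalar entry would crash .get()
    }
}

def any_tool_enabled(scan_config: dict) -> bool:
    tools = scan_config.get('tools', {})
    return any(
        isinstance(cfg, dict) and cfg.get('enabled', False)
        for cfg in tools.values()
    )

assert any_tool_enabled(scan_config) is True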
backend/apps/scan/providers/__init__.py (new file, 56 lines)
@@ -0,0 +1,56 @@
"""
Scan target provider module

Provides a unified interface for obtaining scan targets, backed by several sources:
- DatabaseTargetProvider: queries the database (full scans)
- ListTargetProvider: uses an in-memory list (quick scan, stage 1)
- SnapshotTargetProvider: reads from the snapshot tables (quick scan, stage 2+)
- PipelineTargetProvider: uses pipeline output (Phase 2)

Usage:
    from apps.scan.providers import (
        DatabaseTargetProvider,
        ListTargetProvider,
        SnapshotTargetProvider,
        ProviderContext
    )

    # Database mode (full scan)
    provider = DatabaseTargetProvider(target_id=123)

    # List mode (quick scan, stage 1)
    context = ProviderContext(target_id=1, scan_id=100)
    provider = ListTargetProvider(
        targets=["a.test.com"],
        context=context
    )

    # Snapshot mode (quick scan, stage 2+)
    context = ProviderContext(target_id=1, scan_id=100)
    provider = SnapshotTargetProvider(
        scan_id=100,
        snapshot_type="subdomain",
        context=context
    )

    # Using a provider
    for host in provider.iter_hosts():
        scan(host)
"""

from .base import TargetProvider, ProviderContext
from .list_provider import ListTargetProvider
from .database_provider import DatabaseTargetProvider
from .snapshot_provider import SnapshotTargetProvider, SnapshotType
from .pipeline_provider import PipelineTargetProvider, StageOutput

__all__ = [
    'TargetProvider',
    'ProviderContext',
    'ListTargetProvider',
    'DatabaseTargetProvider',
    'SnapshotTargetProvider',
    'SnapshotType',
    'PipelineTargetProvider',
    'StageOutput',
]
backend/apps/scan/providers/base.py (new file, 115 lines)
@@ -0,0 +1,115 @@
"""
Scan target provider base module

Defines the ProviderContext dataclass and the TargetProvider abstract base class.
"""

import ipaddress
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Iterator, Optional

if TYPE_CHECKING:
    from apps.common.utils import BlacklistFilter

logger = logging.getLogger(__name__)


@dataclass
class ProviderContext:
    """
    Provider context carrying metadata

    Attributes:
        target_id: associated Target ID (used when saving results); None means a temporary scan (not saved)
        scan_id: scan task ID
    """
    target_id: Optional[int] = None
    scan_id: Optional[int] = None


class TargetProvider(ABC):
    """
    Abstract base class for scan target providers

    Responsibilities:
    - Provide an iterator of scan targets (domains, IPs, URLs, ...)
    - Provide a blacklist filter
    - Carry context information (target_id, scan_id, ...)
    - Expand CIDR automatically (subclasses need not care)

    Usage:
        provider = create_target_provider(target_id=123)
        for host in provider.iter_hosts():
            print(host)
    """

    def __init__(self, context: Optional[ProviderContext] = None):
        self._context = context or ProviderContext()

    @property
    def context(self) -> ProviderContext:
        """Return the provider context"""
        return self._context

    @staticmethod
    def _expand_host(host: str) -> Iterator[str]:
        """
        Expand a host (a CIDR expands to multiple IPs, anything else passes through)

        Examples:
            "192.168.1.0/30" → "192.168.1.1", "192.168.1.2"
            "192.168.1.1" → "192.168.1.1"
            "example.com" → "example.com"
        """
        from apps.common.validators import detect_target_type
        from apps.targets.models import Target

        host = host.strip()
        if not host:
            return

        try:
            target_type = detect_target_type(host)

            if target_type == Target.TargetType.CIDR:
                network = ipaddress.ip_network(host, strict=False)
                if network.num_addresses == 1:
                    yield str(network.network_address)
                else:
                    yield from (str(ip) for ip in network.hosts())
            elif target_type in (Target.TargetType.IP, Target.TargetType.DOMAIN):
                yield host
        except ValueError as e:
            logger.warning("Skipping host with invalid format '%s': %s", host, str(e))

    def iter_hosts(self) -> Iterator[str]:
        """Iterate over hosts (domains/IPs), expanding CIDR automatically"""
        for host in self._iter_raw_hosts():
            yield from self._expand_host(host)

    @abstractmethod
    def _iter_raw_hosts(self) -> Iterator[str]:
        """Iterate over raw hosts (may contain CIDR); implemented by subclasses"""
        pass

    @abstractmethod
    def iter_urls(self) -> Iterator[str]:
        """Iterate over URLs"""
        pass

    @abstractmethod
    def get_blacklist_filter(self) -> Optional['BlacklistFilter']:
        """Get the blacklist filter; None means no filtering"""
        pass

    @property
    def target_id(self) -> Optional[int]:
        """Return the associated target_id; temporary scans return None"""
        return self._context.target_id

    @property
    def scan_id(self) -> Optional[int]:
        """Return the associated scan_id"""
        return self._context.scan_id
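Any concrete provider only has to implement _iter_raw_hosts, iter_urls, and get_blacklist_filter; CIDR expansion comes for free from the base class. A toy in-memory provider for illustration (it still needs an initialized Django environment, since _expand_host imports the validators and the Target model):

# Toy provider: demonstrates that the base class expands CIDR automatically.
from typing import Iterator, Optional

class StaticProvider(TargetProvider):
    def __init__(self, hosts: list[str]):
        super().__init__(ProviderContext(target_id=None, scan_id=None))
        self._hosts = hosts

    def _iter_raw_hosts(self) -> Iterator[str]:
        yield from self._hosts

    def iter_urls(self) -> Iterator[str]:
        return iter(())          # no URL source in this toy

    def get_blacklist_filter(self) -> Optional['BlacklistFilter']:
        return None              # no filtering

# "192.168.1.0/30" expands to its usable hosts via TargetProvider.iter_hosts():
# provider = StaticProvider(["example.com", "192.168.1.0/30"])
# list(provider.iter_hosts()) -> ["example.com", "192.168.1.1", "192.168.1.2"]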
93
backend/apps/scan/providers/database_provider.py
Normal file
93
backend/apps/scan/providers/database_provider.py
Normal file
@@ -0,0 +1,93 @@
"""
Database target provider module.

Provides a target provider implementation backed by database queries.
"""

import logging
from typing import TYPE_CHECKING, Iterator, Optional

from .base import ProviderContext, TargetProvider

if TYPE_CHECKING:
    from apps.common.utils import BlacklistFilter

logger = logging.getLogger(__name__)


class DatabaseTargetProvider(TargetProvider):
    """
    Database target provider - queries the Target table and related asset tables.

    Data sources:
        - iter_hosts(): domains/IPs depending on the Target type
        - iter_urls(): WebSite/Endpoint tables, with a fallback chain

    Usage:
        provider = DatabaseTargetProvider(target_id=123)
        for host in provider.iter_hosts():
            scan(host)
    """

    def __init__(self, target_id: int, context: Optional[ProviderContext] = None):
        ctx = context or ProviderContext()
        ctx.target_id = target_id
        super().__init__(ctx)
        self._blacklist_filter: Optional['BlacklistFilter'] = None

    def iter_hosts(self) -> Iterator[str]:
        """Query hosts from the database, expanding CIDRs and applying the blacklist."""
        blacklist = self.get_blacklist_filter()

        for host in self._iter_raw_hosts():
            for expanded_host in self._expand_host(host):
                if not blacklist or blacklist.is_allowed(expanded_host):
                    yield expanded_host

    def _iter_raw_hosts(self) -> Iterator[str]:
        """Query raw hosts (may contain CIDRs) from the database."""
        from apps.asset.services.asset.subdomain_service import SubdomainService
        from apps.targets.models import Target
        from apps.targets.services import TargetService

        target = TargetService().get_target(self.target_id)
        if not target:
            logger.warning("Target ID %d does not exist", self.target_id)
            return

        if target.type == Target.TargetType.DOMAIN:
            yield target.name
            for domain in SubdomainService().iter_subdomain_names_by_target(
                target_id=self.target_id,
                chunk_size=1000
            ):
                if domain != target.name:
                    yield domain

        elif target.type in (Target.TargetType.IP, Target.TargetType.CIDR):
            yield target.name

    def iter_urls(self) -> Iterator[str]:
        """Query URLs from the database via the fallback chain: Endpoint → WebSite → Default."""
        from apps.scan.services.target_export_service import (
            DataSource,
            _iter_urls_with_fallback,
        )

        blacklist = self.get_blacklist_filter()

        for url, _ in _iter_urls_with_fallback(
            target_id=self.target_id,
            sources=[DataSource.ENDPOINT, DataSource.WEBSITE, DataSource.DEFAULT],
            blacklist_filter=blacklist
        ):
            yield url

    def get_blacklist_filter(self) -> Optional['BlacklistFilter']:
        """Return the blacklist filter (lazily loaded)."""
        if self._blacklist_filter is None:
            from apps.common.services import BlacklistService
            from apps.common.utils import BlacklistFilter
            rules = BlacklistService().get_rules(self.target_id)
            self._blacklist_filter = BlacklistFilter(rules)
        return self._blacklist_filter
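A sketch of how a task might consume this provider to feed a scanner input file (the output path is hypothetical); because iter_hosts() streams and filters lazily, memory stays flat even for large CIDR targets.

# Illustrative only.
provider = DatabaseTargetProvider(target_id=123)
with open("/tmp/hosts.txt", "w", encoding="utf-8") as f:
    for host in provider.iter_hosts():  # blacklist already applied
        f.write(f"{host}\n")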
84
backend/apps/scan/providers/list_provider.py
Normal file
@@ -0,0 +1,84 @@
"""
List target provider module.

Provides a target provider implementation backed by an in-memory list.
"""

from typing import Iterator, Optional, List

from .base import TargetProvider, ProviderContext


class ListTargetProvider(TargetProvider):
    """
    List target provider - works directly on an in-memory list.

    Used for quick scans, ad-hoc scans, and similar cases where only
    user-specified targets should be scanned.

    Characteristics:
        - No database queries
        - No blacklist filtering (the targets were explicitly specified by the user)
        - No target_id association (the caller is responsible for creating a Target)
        - Auto-detects input type (URL/domain/IP/CIDR)
        - Auto-expands CIDRs

    Usage:
        # Quick scan: user-supplied targets, types detected automatically
        provider = ListTargetProvider(targets=[
            "example.com",              # domain
            "192.168.1.0/24",           # CIDR (auto-expanded)
            "https://api.example.com"   # URL
        ])
        for host in provider.iter_hosts():
            scan(host)
    """

    def __init__(
        self,
        targets: Optional[List[str]] = None,
        context: Optional[ProviderContext] = None
    ):
        """
        Initialize the list target provider.

        Args:
            targets: target list (types auto-detected: URL/domain/IP/CIDR)
            context: provider context
        """
        from apps.common.validators import detect_input_type

        ctx = context or ProviderContext()
        super().__init__(ctx)

        # Classify targets automatically
        self._hosts = []
        self._urls = []

        if targets:
            for target in targets:
                target = target.strip()
                if not target:
                    continue

                try:
                    input_type = detect_input_type(target)
                    if input_type == 'url':
                        self._urls.append(target)
                    else:
                        # domain/ip/cidr are all treated as hosts
                        self._hosts.append(target)
                except ValueError:
                    # Unrecognized type: treat as a host by default
                    self._hosts.append(target)

    def _iter_raw_hosts(self) -> Iterator[str]:
        """Iterate over raw hosts (may contain CIDRs)."""
        yield from self._hosts

    def iter_urls(self) -> Iterator[str]:
        """Iterate over URLs."""
        yield from self._urls

    def get_blacklist_filter(self) -> None:
        """List mode does not use blacklist filtering."""
        return None
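A quick illustration of the auto-classification described above, assuming detect_input_type labels the inputs as noted in the comments:

# Illustrative only.
provider = ListTargetProvider(targets=[
    "https://api.example.com/v1",  # detected as a URL
    "example.com",                 # detected as a host
    "10.0.0.0/30",                 # detected as a host (CIDR, expanded on iteration)
])
# list(provider.iter_urls())  → ["https://api.example.com/v1"]
# list(provider.iter_hosts()) → ["example.com", "10.0.0.1", "10.0.0.2"]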
91
backend/apps/scan/providers/pipeline_provider.py
Normal file
@@ -0,0 +1,91 @@
"""
Pipeline target provider module.

Provides a target provider implementation backed by the output of a
pipeline stage. Used to pass data between stages in Phase 2 pipeline mode.
"""

from dataclasses import dataclass, field
from typing import Iterator, Optional, List, Dict, Any

from .base import TargetProvider, ProviderContext


@dataclass
class StageOutput:
    """
    Output data of a pipeline stage.

    Used to pass data between pipeline stages.

    Attributes:
        hosts: host list (domains/IPs)
        urls: URL list
        new_targets: newly discovered targets
        stats: statistics
        success: whether the stage succeeded
        error: error message
    """
    hosts: List[str] = field(default_factory=list)
    urls: List[str] = field(default_factory=list)
    new_targets: List[str] = field(default_factory=list)
    stats: Dict[str, Any] = field(default_factory=dict)
    success: bool = True
    error: Optional[str] = None


class PipelineTargetProvider(TargetProvider):
    """
    Pipeline target provider - consumes the previous stage's output.

    Used to pass data between stages in Phase 2 pipeline mode.

    Characteristics:
        - No database queries
        - No blacklist filtering (data was already filtered in the previous stage)
        - Reads directly from the StageOutput

    Usage (Phase 2):
        stage1_output = stage1.run(input)
        provider = PipelineTargetProvider(
            previous_output=stage1_output,
            target_id=123
        )
        for host in provider.iter_hosts():
            stage2.scan(host)
    """

    def __init__(
        self,
        previous_output: StageOutput,
        target_id: Optional[int] = None,
        context: Optional[ProviderContext] = None
    ):
        """
        Initialize the pipeline target provider.

        Args:
            previous_output: output of the previous stage
            target_id: optional Target association (used for saving results)
            context: provider context
        """
        ctx = context or ProviderContext(target_id=target_id)
        super().__init__(ctx)
        self._previous_output = previous_output

    def _iter_raw_hosts(self) -> Iterator[str]:
        """Iterate over the previous stage's raw hosts (may contain CIDRs)."""
        yield from self._previous_output.hosts

    def iter_urls(self) -> Iterator[str]:
        """Iterate over the previous stage's URLs."""
        yield from self._previous_output.urls

    def get_blacklist_filter(self) -> None:
        """Data passed through the pipeline has already been filtered."""
        return None

    @property
    def previous_output(self) -> StageOutput:
        """Return the previous stage's output."""
        return self._previous_output
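A sketch of two stages handing data off through StageOutput; the stage-two work here is a placeholder, and only StageOutput and PipelineTargetProvider come from the diff above.

# Illustrative only.
stage1_output = StageOutput(hosts=["a.example.com", "b.example.com"],
                            stats={"found": 2})

provider = PipelineTargetProvider(previous_output=stage1_output, target_id=1)
for host in provider.iter_hosts():  # any CIDRs in hosts are expanded here
    print("stage 2 scans:", host)

if not provider.previous_output.success:
    print("stage 1 failed:", provider.previous_output.error)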
175
backend/apps/scan/providers/snapshot_provider.py
Normal file
@@ -0,0 +1,175 @@
"""
Snapshot target provider module.

Provides a target provider implementation backed by snapshot tables.
Used to pass data between the stages of a quick scan.
"""

import logging
from typing import Iterator, Optional, Literal

from .base import TargetProvider, ProviderContext

logger = logging.getLogger(__name__)

# Snapshot type definition
SnapshotType = Literal["subdomain", "website", "endpoint", "host_port"]


class SnapshotTargetProvider(TargetProvider):
    """
    Snapshot target provider - reads this scan's data from the snapshot tables.

    Used to pass data between quick-scan stages, giving precise control over
    what gets scanned.

    Core value:
        - Returns only assets discovered by this scan (scan_id)
        - Avoids rescanning historical data (DatabaseTargetProvider would scan
          all historical assets)

    Characteristics:
        - Filters the snapshot tables by scan_id
        - No blacklist filtering (data was already filtered in the previous stage)
        - Supports multiple snapshot types (subdomain/website/endpoint/host_port)

    Usage scenario:
        # Quick-scan flow
        User input: a.test.com
        Create Target: test.com (id=1)
        Create Scan: scan_id=100

        # Stage 1: subdomain discovery
        provider = ListTargetProvider(
            targets=["a.test.com"],
            context=ProviderContext(target_id=1, scan_id=100)
        )
        # Discovers: b.a.test.com, c.a.test.com
        # Saves: SubdomainSnapshot(scan_id=100) + Subdomain(target_id=1)

        # Stage 2: port scan
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="subdomain",
            context=ProviderContext(target_id=1, scan_id=100)
        )
        # Returns only: b.a.test.com, c.a.test.com (discovered by this scan)
        # Does not return: www.test.com, api.test.com (historical data)

        # Stage 3: website scan
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="host_port",
            context=ProviderContext(target_id=1, scan_id=100)
        )
        # Returns only the IP:port pairs discovered by this scan
    """

    def __init__(
        self,
        scan_id: int,
        snapshot_type: SnapshotType,
        context: Optional[ProviderContext] = None
    ):
        """
        Initialize the snapshot target provider.

        Args:
            scan_id: scan task ID (required)
            snapshot_type: snapshot type
                - "subdomain": subdomain snapshots (SubdomainSnapshot)
                - "website": website snapshots (WebsiteSnapshot)
                - "endpoint": endpoint snapshots (EndpointSnapshot)
                - "host_port": host-port mapping snapshots (HostPortMappingSnapshot)
            context: provider context
        """
        ctx = context or ProviderContext()
        ctx.scan_id = scan_id
        super().__init__(ctx)
        self._scan_id = scan_id
        self._snapshot_type = snapshot_type

    def _iter_raw_hosts(self) -> Iterator[str]:
        """
        Iterate over hosts from the snapshot tables.

        Selects the snapshot table based on snapshot_type:
            - subdomain: SubdomainSnapshot.name
            - host_port: HostPortMappingSnapshot.host (yields host:port pairs,
              bypassing validation)
        """
        if self._snapshot_type == "subdomain":
            from apps.asset.services.snapshot import SubdomainSnapshotsService
            service = SubdomainSnapshotsService()
            yield from service.iter_subdomain_names_by_scan(
                scan_id=self._scan_id,
                chunk_size=1000
            )

        elif self._snapshot_type == "host_port":
            # host_port does not use _iter_raw_hosts; it is handled directly
            # in iter_hosts. Return nothing so the base iter_hosts stays empty.
            return

        else:
            # Other types do not support iter_hosts yet
            logger.warning(
                "Snapshot type '%s' does not support iter_hosts; returning an empty iterator",
                self._snapshot_type
            )
            return

    def iter_hosts(self) -> Iterator[str]:
        """
        Iterate over hosts.

        For the host_port type, yields host:port strings without CIDR expansion
        or validation.
        """
        if self._snapshot_type == "host_port":
            # host_port yields host:port directly, bypassing _expand_host
            from apps.asset.services.snapshot import HostPortMappingSnapshotsService
            service = HostPortMappingSnapshotsService()
            queryset = service.get_by_scan(scan_id=self._scan_id)
            for mapping in queryset.iterator(chunk_size=1000):
                yield f"{mapping.host}:{mapping.port}"
        else:
            # Other types use the base iter_hosts (calls _iter_raw_hosts and expands CIDRs)
            yield from super().iter_hosts()

    def iter_urls(self) -> Iterator[str]:
        """
        Iterate over URLs from the snapshot tables.

        Selects the snapshot table based on snapshot_type:
            - website: WebsiteSnapshot.url
            - endpoint: EndpointSnapshot.url
        """
        if self._snapshot_type == "website":
            from apps.asset.services.snapshot import WebsiteSnapshotsService
            service = WebsiteSnapshotsService()
            yield from service.iter_website_urls_by_scan(
                scan_id=self._scan_id,
                chunk_size=1000
            )

        elif self._snapshot_type == "endpoint":
            from apps.asset.services.snapshot import EndpointSnapshotsService
            service = EndpointSnapshotsService()
            # Fetch endpoint URLs from the snapshot table
            queryset = service.get_by_scan(scan_id=self._scan_id)
            for endpoint in queryset.iterator(chunk_size=1000):
                yield endpoint.url

        else:
            # Other types do not support iter_urls yet
            logger.warning(
                "Snapshot type '%s' does not support iter_urls; returning an empty iterator",
                self._snapshot_type
            )
            return

    def get_blacklist_filter(self) -> None:
        """Snapshot data was already filtered in the previous stage."""
        return None

    @property
    def snapshot_type(self) -> SnapshotType:
        """Return the snapshot type."""
        return self._snapshot_type
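A short sketch of the host_port special case described above: host:port strings are yielded verbatim because they would not survive the base class's host validation or CIDR expansion. The sample values are hypothetical.

# Illustrative only - assuming snapshot rows exist for scan 100.
provider = SnapshotTargetProvider(scan_id=100, snapshot_type="host_port")
for pair in provider.iter_hosts():
    # e.g. "10.0.0.5:80", "10.0.0.5:443" - taken directly from the snapshot rows
    print("probe:", pair)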
3
backend/apps/scan/providers/tests/__init__.py
Normal file
@@ -0,0 +1,3 @@
"""
Tests for the scan target providers.
"""
256
backend/apps/scan/providers/tests/test_common_properties.py
Normal file
@@ -0,0 +1,256 @@
"""
Shared property tests.

Property tests that span multiple providers:
    - Property 4: Context Propagation
    - Property 5: Non-Database Provider Blacklist Filter
    - Property 7: CIDR Expansion Consistency
"""

import pytest
from hypothesis import given, strategies as st, settings
from ipaddress import IPv4Network

from apps.scan.providers import (
    ProviderContext,
    ListTargetProvider,
    DatabaseTargetProvider,
    PipelineTargetProvider,
    SnapshotTargetProvider
)
from apps.scan.providers.pipeline_provider import StageOutput


class TestContextPropagation:
    """
    Property 4: Context Propagation

    *For any* ProviderContext passed to a provider constructor, the provider's
    target_id and scan_id properties should match the values in the context.

    **Validates: Requirements 1.3, 1.5, 7.4, 7.5**
    """

    @given(
        target_id=st.integers(min_value=1, max_value=10000),
        scan_id=st.integers(min_value=1, max_value=10000)
    )
    @settings(max_examples=100)
    def test_property_4_list_provider_context_propagation(self, target_id, scan_id):
        """
        Property 4: Context Propagation (ListTargetProvider)

        Feature: scan-target-provider, Property 4: Context Propagation
        **Validates: Requirements 1.3, 1.5, 7.4, 7.5**
        """
        ctx = ProviderContext(target_id=target_id, scan_id=scan_id)
        provider = ListTargetProvider(targets=["example.com"], context=ctx)

        assert provider.target_id == target_id
        assert provider.scan_id == scan_id
        assert provider.context.target_id == target_id
        assert provider.context.scan_id == scan_id

    @given(
        target_id=st.integers(min_value=1, max_value=10000),
        scan_id=st.integers(min_value=1, max_value=10000)
    )
    @settings(max_examples=100)
    def test_property_4_database_provider_context_propagation(self, target_id, scan_id):
        """
        Property 4: Context Propagation (DatabaseTargetProvider)

        Feature: scan-target-provider, Property 4: Context Propagation
        **Validates: Requirements 1.3, 1.5, 7.4, 7.5**
        """
        ctx = ProviderContext(target_id=999, scan_id=scan_id)
        # DatabaseTargetProvider overrides the target_id in the context
        provider = DatabaseTargetProvider(target_id=target_id, context=ctx)

        assert provider.target_id == target_id  # from the constructor argument
        assert provider.scan_id == scan_id      # from the context
        assert provider.context.target_id == target_id
        assert provider.context.scan_id == scan_id

    @given(
        target_id=st.integers(min_value=1, max_value=10000),
        scan_id=st.integers(min_value=1, max_value=10000)
    )
    @settings(max_examples=100)
    def test_property_4_pipeline_provider_context_propagation(self, target_id, scan_id):
        """
        Property 4: Context Propagation (PipelineTargetProvider)

        Feature: scan-target-provider, Property 4: Context Propagation
        **Validates: Requirements 1.3, 1.5, 7.4, 7.5**
        """
        ctx = ProviderContext(target_id=target_id, scan_id=scan_id)
        stage_output = StageOutput(hosts=["example.com"])
        provider = PipelineTargetProvider(previous_output=stage_output, context=ctx)

        assert provider.target_id == target_id
        assert provider.scan_id == scan_id
        assert provider.context.target_id == target_id
        assert provider.context.scan_id == scan_id

    @given(
        target_id=st.integers(min_value=1, max_value=10000),
        scan_id=st.integers(min_value=1, max_value=10000)
    )
    @settings(max_examples=100)
    def test_property_4_snapshot_provider_context_propagation(self, target_id, scan_id):
        """
        Property 4: Context Propagation (SnapshotTargetProvider)

        Feature: scan-target-provider, Property 4: Context Propagation
        **Validates: Requirements 1.3, 1.5, 7.4, 7.5**
        """
        ctx = ProviderContext(target_id=target_id, scan_id=999)
        # SnapshotTargetProvider overrides the scan_id in the context
        provider = SnapshotTargetProvider(
            scan_id=scan_id,
            snapshot_type="subdomain",
            context=ctx
        )

        assert provider.target_id == target_id  # from the context
        assert provider.scan_id == scan_id      # from the constructor argument
        assert provider.context.target_id == target_id
        assert provider.context.scan_id == scan_id


class TestNonDatabaseProviderBlacklistFilter:
    """
    Property 5: Non-Database Provider Blacklist Filter

    *For any* ListTargetProvider or PipelineTargetProvider instance,
    get_blacklist_filter() should return None.

    **Validates: Requirements 3.4, 9.4, 9.5**
    """

    @given(targets=st.lists(st.text(min_size=1, max_size=20), max_size=10))
    @settings(max_examples=100)
    def test_property_5_list_provider_no_blacklist(self, targets):
        """
        Property 5: Non-Database Provider Blacklist Filter (ListTargetProvider)

        Feature: scan-target-provider, Property 5: Non-Database Provider Blacklist Filter
        **Validates: Requirements 3.4, 9.4, 9.5**
        """
        provider = ListTargetProvider(targets=targets)
        assert provider.get_blacklist_filter() is None

    @given(hosts=st.lists(st.text(min_size=1, max_size=20), max_size=10))
    @settings(max_examples=100)
    def test_property_5_pipeline_provider_no_blacklist(self, hosts):
        """
        Property 5: Non-Database Provider Blacklist Filter (PipelineTargetProvider)

        Feature: scan-target-provider, Property 5: Non-Database Provider Blacklist Filter
        **Validates: Requirements 3.4, 9.4, 9.5**
        """
        stage_output = StageOutput(hosts=hosts)
        provider = PipelineTargetProvider(previous_output=stage_output)
        assert provider.get_blacklist_filter() is None

    def test_property_5_snapshot_provider_no_blacklist(self):
        """
        Property 5: Non-Database Provider Blacklist Filter (SnapshotTargetProvider)

        Feature: scan-target-provider, Property 5: Non-Database Provider Blacklist Filter
        **Validates: Requirements 3.4, 9.4, 9.5**
        """
        provider = SnapshotTargetProvider(scan_id=1, snapshot_type="subdomain")
        assert provider.get_blacklist_filter() is None


class TestCIDRExpansionConsistency:
    """
    Property 7: CIDR Expansion Consistency

    *For any* CIDR string (e.g. "192.168.1.0/24"), iter_hosts() on every
    provider should expand it into the same list of individual IP addresses.

    **Validates: Requirements 1.1, 3.6**
    """

    @given(
        # Keep the CIDR ranges small to avoid test timeouts
        network_prefix=st.integers(min_value=1, max_value=254),
        cidr_suffix=st.integers(min_value=28, max_value=30)  # /28 = 16 IPs, /30 = 4 IPs
    )
    @settings(max_examples=50, deadline=None)
    def test_property_7_cidr_expansion_consistency(self, network_prefix, cidr_suffix):
        """
        Property 7: CIDR Expansion Consistency

        Feature: scan-target-provider, Property 7: CIDR Expansion Consistency
        **Validates: Requirements 1.1, 3.6**

        For any CIDR string, all Providers should expand it to the same IP list.
        """
        cidr = f"192.168.{network_prefix}.0/{cidr_suffix}"

        # Compute the expected IP list
        network = IPv4Network(cidr, strict=False)
        # Exclude the network and broadcast addresses
        expected_ips = [str(ip) for ip in network.hosts()]

        # For tiny CIDRs (/31 or /32), use all addresses
        if not expected_ips:
            expected_ips = [str(ip) for ip in network]

        # ListTargetProvider
        list_provider = ListTargetProvider(targets=[cidr])
        list_result = list(list_provider.iter_hosts())

        # PipelineTargetProvider
        stage_output = StageOutput(hosts=[cidr])
        pipeline_provider = PipelineTargetProvider(previous_output=stage_output)
        pipeline_result = list(pipeline_provider.iter_hosts())

        # All providers must expand to the same result
        assert list_result == expected_ips, f"ListProvider CIDR expansion mismatch for {cidr}"
        assert pipeline_result == expected_ips, f"PipelineProvider CIDR expansion mismatch for {cidr}"
        assert list_result == pipeline_result, f"Providers produce different results for {cidr}"

    def test_cidr_expansion_with_multiple_cidrs(self):
        """Expansion of multiple CIDRs stays consistent across providers."""
        cidrs = ["192.168.1.0/30", "10.0.0.0/30"]

        # Compute the expected result
        expected_ips = []
        for cidr in cidrs:
            network = IPv4Network(cidr, strict=False)
            expected_ips.extend([str(ip) for ip in network.hosts()])

        # ListTargetProvider
        list_provider = ListTargetProvider(targets=cidrs)
        list_result = list(list_provider.iter_hosts())

        # PipelineTargetProvider
        stage_output = StageOutput(hosts=cidrs)
        pipeline_provider = PipelineTargetProvider(previous_output=stage_output)
        pipeline_result = list(pipeline_provider.iter_hosts())

        # Verify
        assert list_result == expected_ips
        assert pipeline_result == expected_ips
        assert list_result == pipeline_result

    def test_mixed_hosts_and_cidrs(self):
        """Hosts and CIDRs mixed in one list expand in place."""
        targets = ["example.com", "192.168.1.0/30", "test.com"]

        # Compute the expected result
        network = IPv4Network("192.168.1.0/30", strict=False)
        cidr_ips = [str(ip) for ip in network.hosts()]
        expected = ["example.com"] + cidr_ips + ["test.com"]

        # ListTargetProvider
        list_provider = ListTargetProvider(targets=targets)
        list_result = list(list_provider.iter_hosts())

        # Verify
        assert list_result == expected
158
backend/apps/scan/providers/tests/test_database_provider.py
Normal file
@@ -0,0 +1,158 @@
"""
Property tests for DatabaseTargetProvider.

Property 7: DatabaseTargetProvider Blacklist Application
*For any* target_id with blacklist rules, iter_hosts() and iter_urls() on
DatabaseTargetProvider should filter out targets that match the rules.

**Validates: Requirements 2.3, 10.1, 10.2, 10.3**
"""

import pytest
from unittest.mock import patch, MagicMock
from hypothesis import given, strategies as st, settings

from apps.scan.providers.database_provider import DatabaseTargetProvider
from apps.scan.providers.base import ProviderContext


# Strategy that generates valid domain names
def valid_domain_strategy():
    """Generate valid domain names."""
    label = st.text(
        alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
        min_size=2,
        max_size=10
    )
    return st.builds(
        lambda a, b, c: f"{a}.{b}.{c}",
        label, label, st.sampled_from(['com', 'net', 'org', 'io'])
    )


class MockBlacklistFilter:
    """Mock blacklist filter."""

    def __init__(self, blocked_patterns: list):
        self.blocked_patterns = blocked_patterns

    def is_allowed(self, target: str) -> bool:
        """Check whether a target is allowed (i.e. not blacklisted)."""
        for pattern in self.blocked_patterns:
            if pattern in target:
                return False
        return True


class TestDatabaseTargetProviderProperties:
    """Property tests for DatabaseTargetProvider."""

    @given(
        hosts=st.lists(valid_domain_strategy(), min_size=1, max_size=20),
        blocked_keyword=st.text(
            alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
            min_size=2,
            max_size=5
        )
    )
    @settings(max_examples=100)
    def test_property_7_blacklist_filters_hosts(self, hosts, blocked_keyword):
        """
        Property 7: DatabaseTargetProvider Blacklist Application (hosts)

        Feature: scan-target-provider, Property 7: DatabaseTargetProvider Blacklist Application
        **Validates: Requirements 2.3, 10.1, 10.2, 10.3**

        For any set of hosts and a blacklist keyword, the provider should filter out
        all hosts containing the blocked keyword.
        """
        # Create the mock blacklist filter
        mock_filter = MockBlacklistFilter([blocked_keyword])

        # Create the provider and inject the mock blacklist filter
        provider = DatabaseTargetProvider(target_id=1)
        provider._blacklist_filter = mock_filter

        # Mock Target and SubdomainService
        mock_target = MagicMock()
        mock_target.type = 'domain'
        mock_target.name = hosts[0] if hosts else 'example.com'

        with patch('apps.targets.services.TargetService') as mock_target_service, \
             patch('apps.asset.services.asset.subdomain_service.SubdomainService') as mock_subdomain_service:

            mock_target_service.return_value.get_target.return_value = mock_target
            mock_subdomain_service.return_value.iter_subdomain_names_by_target.return_value = iter(hosts[1:] if len(hosts) > 1 else [])

            # Collect the results
            result = list(provider.iter_hosts())

        # No result may contain the blocked keyword
        for host in result:
            assert blocked_keyword not in host, f"Host '{host}' should be filtered by blacklist keyword '{blocked_keyword}'"

        # Every host that does not contain the keyword must appear in the results
        if hosts:
            all_hosts = [hosts[0]] + [h for h in hosts[1:] if h != hosts[0]]
            expected_allowed = [h for h in all_hosts if blocked_keyword not in h]
        else:
            expected_allowed = []

        assert set(result) == set(expected_allowed)


class TestDatabaseTargetProviderUnit:
    """Unit tests for DatabaseTargetProvider."""

    def test_target_id_in_context(self):
        """target_id is stored in the context."""
        provider = DatabaseTargetProvider(target_id=123)
        assert provider.target_id == 123
        assert provider.context.target_id == 123

    def test_context_propagation(self):
        """Context values propagate through the provider."""
        ctx = ProviderContext(scan_id=789)
        provider = DatabaseTargetProvider(target_id=123, context=ctx)

        assert provider.target_id == 123  # target_id is overridden
        assert provider.scan_id == 789

    def test_blacklist_filter_lazy_loading(self):
        """The blacklist filter is loaded lazily and cached."""
        provider = DatabaseTargetProvider(target_id=123)

        # _blacklist_filter starts as None
        assert provider._blacklist_filter is None

        # Mock BlacklistService
        with patch('apps.common.services.BlacklistService') as mock_service, \
             patch('apps.common.utils.BlacklistFilter') as mock_filter_class:

            mock_service.return_value.get_rules.return_value = []
            mock_filter_instance = MagicMock()
            mock_filter_class.return_value = mock_filter_instance

            # First call
            result1 = provider.get_blacklist_filter()
            assert result1 == mock_filter_instance

            # The second call should return the cached instance
            result2 = provider.get_blacklist_filter()
            assert result2 == mock_filter_instance

            # BlacklistService must be called exactly once
            mock_service.return_value.get_rules.assert_called_once_with(123)

    def test_nonexistent_target_returns_empty(self):
        """A nonexistent target yields an empty iterator."""
        provider = DatabaseTargetProvider(target_id=99999)

        with patch('apps.targets.services.TargetService') as mock_service, \
             patch('apps.common.services.BlacklistService') as mock_blacklist_service:

            mock_service.return_value.get_target.return_value = None
            mock_blacklist_service.return_value.get_rules.return_value = []

            result = list(provider.iter_hosts())
            assert result == []
152
backend/apps/scan/providers/tests/test_list_provider.py
Normal file
@@ -0,0 +1,152 @@
"""
Property tests for ListTargetProvider.

Property 1: ListTargetProvider Round-Trip
*For any* host list and URL list, creating a ListTargetProvider and iterating
iter_hosts() and iter_urls() should return the same elements in the same order.

**Validates: Requirements 3.1, 3.2**
"""

import pytest
from hypothesis import given, strategies as st, settings, assume

from apps.scan.providers.list_provider import ListTargetProvider
from apps.scan.providers.base import ProviderContext


# Strategy that generates valid domain names
def valid_domain_strategy():
    """Generate valid domain names."""
    # Simple domain format: subdomain.domain.tld
    label = st.text(
        alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
        min_size=2,
        max_size=10
    )
    return st.builds(
        lambda a, b, c: f"{a}.{b}.{c}",
        label, label, st.sampled_from(['com', 'net', 'org', 'io'])
    )


# Strategy that generates valid IP addresses
def valid_ip_strategy():
    """Generate valid IPv4 addresses."""
    octet = st.integers(min_value=1, max_value=254)
    return st.builds(
        lambda a, b, c, d: f"{a}.{b}.{c}.{d}",
        octet, octet, octet, octet
    )


# Combined strategy: domain or IP
host_strategy = st.one_of(valid_domain_strategy(), valid_ip_strategy())


# Strategy that generates valid URLs
def valid_url_strategy():
    """Generate valid URLs."""
    domain = valid_domain_strategy()
    return st.builds(
        lambda d, path: f"https://{d}/{path}" if path else f"https://{d}",
        domain,
        st.one_of(
            st.just(""),
            st.text(
                alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
                min_size=1,
                max_size=10
            )
        )
    )


url_strategy = valid_url_strategy()


class TestListTargetProviderProperties:
    """Property tests for ListTargetProvider."""

    @given(hosts=st.lists(host_strategy, max_size=50))
    @settings(max_examples=100)
    def test_property_1_hosts_round_trip(self, hosts):
        """
        Property 1: ListTargetProvider Round-Trip (hosts)

        Feature: scan-target-provider, Property 1: ListTargetProvider Round-Trip
        **Validates: Requirements 3.1, 3.2**

        For any host list, creating a ListTargetProvider and iterating iter_hosts()
        should return the same elements in the same order.
        """
        # ListTargetProvider takes a targets argument and classifies hosts/urls itself
        provider = ListTargetProvider(targets=hosts)
        result = list(provider.iter_hosts())
        assert result == hosts

    @given(urls=st.lists(url_strategy, max_size=50))
    @settings(max_examples=100)
    def test_property_1_urls_round_trip(self, urls):
        """
        Property 1: ListTargetProvider Round-Trip (urls)

        Feature: scan-target-provider, Property 1: ListTargetProvider Round-Trip
        **Validates: Requirements 3.1, 3.2**

        For any URL list, creating a ListTargetProvider and iterating iter_urls()
        should return the same elements in the same order.
        """
        # ListTargetProvider takes a targets argument and classifies hosts/urls itself
        provider = ListTargetProvider(targets=urls)
        result = list(provider.iter_urls())
        assert result == urls

    @given(
        hosts=st.lists(host_strategy, max_size=30),
        urls=st.lists(url_strategy, max_size=30)
    )
    @settings(max_examples=100)
    def test_property_1_combined_round_trip(self, hosts, urls):
        """
        Property 1: ListTargetProvider Round-Trip (combined)

        Feature: scan-target-provider, Property 1: ListTargetProvider Round-Trip
        **Validates: Requirements 3.1, 3.2**

        For any combination of hosts and URLs, both should round-trip correctly.
        """
        # Merge hosts and urls; ListTargetProvider classifies them automatically
        combined = hosts + urls
        provider = ListTargetProvider(targets=combined)

        hosts_result = list(provider.iter_hosts())
        urls_result = list(provider.iter_urls())

        assert hosts_result == hosts
        assert urls_result == urls


class TestListTargetProviderUnit:
    """Unit tests for ListTargetProvider."""

    def test_empty_lists(self):
        """Empty input yields empty iterators - Requirements 3.5"""
        provider = ListTargetProvider()
        assert list(provider.iter_hosts()) == []
        assert list(provider.iter_urls()) == []

    def test_blacklist_filter_returns_none(self):
        """The blacklist filter is None - Requirements 3.4"""
        provider = ListTargetProvider(targets=["example.com"])
        assert provider.get_blacklist_filter() is None

    def test_target_id_association(self):
        """target_id association - Requirements 3.3"""
        ctx = ProviderContext(target_id=123)
        provider = ListTargetProvider(targets=["example.com"], context=ctx)
        assert provider.target_id == 123

    def test_context_propagation(self):
        """Context values propagate through the provider."""
        ctx = ProviderContext(target_id=456, scan_id=789)
        provider = ListTargetProvider(targets=["example.com"], context=ctx)

        assert provider.target_id == 456
        assert provider.scan_id == 789
180
backend/apps/scan/providers/tests/test_pipeline_provider.py
Normal file
@@ -0,0 +1,180 @@
"""
Property tests for PipelineTargetProvider.

Property 3: PipelineTargetProvider Round-Trip
*For any* StageOutput object, iter_hosts() and iter_urls() on
PipelineTargetProvider should return the same elements as the hosts and urls
lists in the StageOutput.

**Validates: Requirements 5.1, 5.2**
"""

import pytest
from hypothesis import given, strategies as st, settings

from apps.scan.providers.pipeline_provider import PipelineTargetProvider, StageOutput
from apps.scan.providers.base import ProviderContext


# Strategy that generates valid domain names
def valid_domain_strategy():
    """Generate valid domain names."""
    label = st.text(
        alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
        min_size=2,
        max_size=10
    )
    return st.builds(
        lambda a, b, c: f"{a}.{b}.{c}",
        label, label, st.sampled_from(['com', 'net', 'org', 'io'])
    )


# Strategy that generates valid IP addresses
def valid_ip_strategy():
    """Generate valid IPv4 addresses."""
    octet = st.integers(min_value=1, max_value=254)
    return st.builds(
        lambda a, b, c, d: f"{a}.{b}.{c}.{d}",
        octet, octet, octet, octet
    )


# Combined strategy: domain or IP
host_strategy = st.one_of(valid_domain_strategy(), valid_ip_strategy())


# Strategy that generates valid URLs
def valid_url_strategy():
    """Generate valid URLs."""
    domain = valid_domain_strategy()
    return st.builds(
        lambda d, path: f"https://{d}/{path}" if path else f"https://{d}",
        domain,
        st.one_of(
            st.just(""),
            st.text(
                alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
                min_size=1,
                max_size=10
            )
        )
    )


url_strategy = valid_url_strategy()


class TestPipelineTargetProviderProperties:
    """Property tests for PipelineTargetProvider."""

    @given(hosts=st.lists(host_strategy, max_size=50))
    @settings(max_examples=100)
    def test_property_3_hosts_round_trip(self, hosts):
        """
        Property 3: PipelineTargetProvider Round-Trip (hosts)

        Feature: scan-target-provider, Property 3: PipelineTargetProvider Round-Trip
        **Validates: Requirements 5.1, 5.2**

        For any StageOutput with hosts, PipelineTargetProvider should return
        the same hosts in the same order.
        """
        stage_output = StageOutput(hosts=hosts)
        provider = PipelineTargetProvider(previous_output=stage_output)
        result = list(provider.iter_hosts())
        assert result == hosts

    @given(urls=st.lists(url_strategy, max_size=50))
    @settings(max_examples=100)
    def test_property_3_urls_round_trip(self, urls):
        """
        Property 3: PipelineTargetProvider Round-Trip (urls)

        Feature: scan-target-provider, Property 3: PipelineTargetProvider Round-Trip
        **Validates: Requirements 5.1, 5.2**

        For any StageOutput with urls, PipelineTargetProvider should return
        the same urls in the same order.
        """
        stage_output = StageOutput(urls=urls)
        provider = PipelineTargetProvider(previous_output=stage_output)
        result = list(provider.iter_urls())
        assert result == urls

    @given(
        hosts=st.lists(host_strategy, max_size=30),
        urls=st.lists(url_strategy, max_size=30)
    )
    @settings(max_examples=100)
    def test_property_3_combined_round_trip(self, hosts, urls):
        """
        Property 3: PipelineTargetProvider Round-Trip (combined)

        Feature: scan-target-provider, Property 3: PipelineTargetProvider Round-Trip
        **Validates: Requirements 5.1, 5.2**

        For any StageOutput with both hosts and urls, both should round-trip correctly.
        """
        stage_output = StageOutput(hosts=hosts, urls=urls)
        provider = PipelineTargetProvider(previous_output=stage_output)

        hosts_result = list(provider.iter_hosts())
        urls_result = list(provider.iter_urls())

        assert hosts_result == hosts
        assert urls_result == urls


class TestPipelineTargetProviderUnit:
    """Unit tests for PipelineTargetProvider."""

    def test_empty_stage_output(self):
        """An empty StageOutput yields empty iterators - Requirements 5.5"""
        stage_output = StageOutput()
        provider = PipelineTargetProvider(previous_output=stage_output)

        assert list(provider.iter_hosts()) == []
        assert list(provider.iter_urls()) == []

    def test_blacklist_filter_returns_none(self):
        """The blacklist filter is None - Requirements 5.3"""
        stage_output = StageOutput(hosts=["example.com"])
        provider = PipelineTargetProvider(previous_output=stage_output)
        assert provider.get_blacklist_filter() is None

    def test_target_id_association(self):
        """target_id association - Requirements 5.4"""
        stage_output = StageOutput(hosts=["example.com"])
        provider = PipelineTargetProvider(previous_output=stage_output, target_id=123)
        assert provider.target_id == 123

    def test_context_propagation(self):
        """Context values propagate through the provider."""
        ctx = ProviderContext(target_id=456, scan_id=789)
        stage_output = StageOutput(hosts=["example.com"])
        provider = PipelineTargetProvider(previous_output=stage_output, context=ctx)

        assert provider.target_id == 456
        assert provider.scan_id == 789

    def test_previous_output_property(self):
        """The previous_output property exposes the original StageOutput."""
        stage_output = StageOutput(hosts=["example.com"], urls=["https://example.com"])
        provider = PipelineTargetProvider(previous_output=stage_output)

        assert provider.previous_output is stage_output
        assert provider.previous_output.hosts == ["example.com"]
        assert provider.previous_output.urls == ["https://example.com"]

    def test_stage_output_with_metadata(self):
        """A StageOutput carrying metadata."""
        stage_output = StageOutput(
            hosts=["example.com"],
            urls=["https://example.com"],
            new_targets=["new.example.com"],
            stats={"count": 1},
            success=True,
            error=None
        )
        provider = PipelineTargetProvider(previous_output=stage_output)

        assert list(provider.iter_hosts()) == ["example.com"]
        assert list(provider.iter_urls()) == ["https://example.com"]
        assert provider.previous_output.new_targets == ["new.example.com"]
        assert provider.previous_output.stats == {"count": 1}
191
backend/apps/scan/providers/tests/test_snapshot_provider.py
Normal file
@@ -0,0 +1,191 @@
"""
Unit tests for SnapshotTargetProvider.
"""

import pytest
from unittest.mock import Mock, patch

from apps.scan.providers import SnapshotTargetProvider, ProviderContext


class TestSnapshotTargetProvider:
    """Tests for SnapshotTargetProvider."""

    def test_init_with_scan_id_and_type(self):
        """Initialization."""
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="subdomain"
        )

        assert provider.scan_id == 100
        assert provider.snapshot_type == "subdomain"
        assert provider.target_id is None  # default context

    def test_init_with_context(self):
        """Initialization with a context."""
        ctx = ProviderContext(target_id=1, scan_id=100)
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="subdomain",
            context=ctx
        )

        assert provider.scan_id == 100
        assert provider.target_id == 1
        assert provider.snapshot_type == "subdomain"

    @patch('apps.asset.services.snapshot.SubdomainSnapshotsService')
    def test_iter_hosts_subdomain(self, mock_service_class):
        """Iterate hosts from subdomain snapshots."""
        # Mock service
        mock_service = Mock()
        mock_service.iter_subdomain_names_by_scan.return_value = iter([
            "a.example.com",
            "b.example.com"
        ])
        mock_service_class.return_value = mock_service

        # Create the provider
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="subdomain"
        )

        # Iterate hosts
        hosts = list(provider.iter_hosts())

        assert hosts == ["a.example.com", "b.example.com"]
        mock_service.iter_subdomain_names_by_scan.assert_called_once_with(
            scan_id=100,
            chunk_size=1000
        )

    @patch('apps.asset.services.snapshot.HostPortMappingSnapshotsService')
    def test_iter_hosts_host_port(self, mock_service_class):
        """Iterate hosts from host-port mapping snapshots."""
        # Mock queryset
        mock_mapping1 = Mock()
        mock_mapping1.host = "example.com"
        mock_mapping1.port = 80

        mock_mapping2 = Mock()
        mock_mapping2.host = "example.com"
        mock_mapping2.port = 443

        mock_queryset = Mock()
        mock_queryset.iterator.return_value = iter([mock_mapping1, mock_mapping2])

        # Mock service
        mock_service = Mock()
        mock_service.get_by_scan.return_value = mock_queryset
        mock_service_class.return_value = mock_service

        # Create the provider
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="host_port"
        )

        # Iterate hosts
        hosts = list(provider.iter_hosts())

        assert hosts == ["example.com:80", "example.com:443"]
        mock_service.get_by_scan.assert_called_once_with(scan_id=100)

    @patch('apps.asset.services.snapshot.WebsiteSnapshotsService')
    def test_iter_urls_website(self, mock_service_class):
        """Iterate URLs from website snapshots."""
        # Mock service
        mock_service = Mock()
        mock_service.iter_website_urls_by_scan.return_value = iter([
            "http://example.com",
            "https://example.com"
        ])
        mock_service_class.return_value = mock_service

        # Create the provider
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="website"
        )

        # Iterate URLs
        urls = list(provider.iter_urls())

        assert urls == ["http://example.com", "https://example.com"]
        mock_service.iter_website_urls_by_scan.assert_called_once_with(
            scan_id=100,
            chunk_size=1000
        )

    @patch('apps.asset.services.snapshot.EndpointSnapshotsService')
    def test_iter_urls_endpoint(self, mock_service_class):
        """Iterate URLs from endpoint snapshots."""
        # Mock queryset
        mock_endpoint1 = Mock()
        mock_endpoint1.url = "http://example.com/api/v1"

        mock_endpoint2 = Mock()
        mock_endpoint2.url = "http://example.com/api/v2"

        mock_queryset = Mock()
        mock_queryset.iterator.return_value = iter([mock_endpoint1, mock_endpoint2])

        # Mock service
        mock_service = Mock()
        mock_service.get_by_scan.return_value = mock_queryset
        mock_service_class.return_value = mock_service

        # Create the provider
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="endpoint"
        )

        # Iterate URLs
        urls = list(provider.iter_urls())

        assert urls == ["http://example.com/api/v1", "http://example.com/api/v2"]
        mock_service.get_by_scan.assert_called_once_with(scan_id=100)

    def test_iter_hosts_unsupported_type(self):
        """An unsupported snapshot type (iter_hosts) yields nothing."""
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="website"  # website does not support iter_hosts
        )

        hosts = list(provider.iter_hosts())
        assert hosts == []

    def test_iter_urls_unsupported_type(self):
        """An unsupported snapshot type (iter_urls) yields nothing."""
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="subdomain"  # subdomain does not support iter_urls
        )

        urls = list(provider.iter_urls())
        assert urls == []

    def test_get_blacklist_filter(self):
        """The blacklist filter (snapshot mode does not use one)."""
        provider = SnapshotTargetProvider(
            scan_id=100,
            snapshot_type="subdomain"
        )

        assert provider.get_blacklist_filter() is None

    def test_context_propagation(self):
        """Context values propagate through the provider."""
        ctx = ProviderContext(target_id=456, scan_id=789)
        provider = SnapshotTargetProvider(
            scan_id=100,  # overrides the scan_id in the context
            snapshot_type="subdomain",
            context=ctx
        )

        assert provider.target_id == 456
        assert provider.scan_id == 100  # scan_id is set in __init__
@@ -464,6 +464,7 @@ class DjangoScanRepository:
             'endpoints': scan.endpoint_snapshots.count(),
             'ips': ips_count,
             'directories': scan.directory_snapshots.count(),
+            'screenshots': scan.screenshot_snapshots.count(),
             'vulns_total': total_vulns,
             'vulns_critical': severity_stats['critical'],
             'vulns_high': severity_stats['high'],
@@ -478,6 +479,7 @@ class DjangoScanRepository:
             'cached_endpoints_count': stats['endpoints'],
             'cached_ips_count': stats['ips'],
             'cached_directories_count': stats['directories'],
+            'cached_screenshots_count': stats['screenshots'],
             'cached_vulns_total': stats['vulns_total'],
             'cached_vulns_critical': stats['vulns_critical'],
             'cached_vulns_high': stats['vulns_high'],
@@ -41,7 +41,7 @@ class ScanHistorySerializer(serializers.ModelSerializer):
         fields = [
             'id', 'target', 'target_name', 'engine_ids', 'engine_names',
             'worker_name', 'created_at', 'status', 'error_message', 'summary',
-            'progress', 'current_stage', 'stage_progress'
+            'progress', 'current_stage', 'stage_progress', 'yaml_configuration'
         ]

     def get_summary(self, obj):
@@ -51,6 +51,7 @@ class ScanHistorySerializer(serializers.ModelSerializer):
             'endpoints': obj.cached_endpoints_count or 0,
             'ips': obj.cached_ips_count or 0,
             'directories': obj.cached_directories_count or 0,
+            'screenshots': obj.cached_screenshots_count or 0,
         }
         summary['vulnerabilities'] = {
             'total': obj.cached_vulns_total or 0,
@@ -17,7 +17,12 @@ from .scan_state_service import ScanStateService
 from .scan_control_service import ScanControlService
 from .scan_stats_service import ScanStatsService
 from .scheduled_scan_service import ScheduledScanService
-from .target_export_service import TargetExportService
+from .target_export_service import (
+    TargetExportService,
+    create_export_service,
+    export_urls_with_fallback,
+    DataSource,
+)

 __all__ = [
     'ScanService',  # main entry point (backward compatible)
@@ -27,5 +32,8 @@ __all__ = [
     'ScanStatsService',
     'ScheduledScanService',
     'TargetExportService',  # target export service
+    'create_export_service',
+    'export_urls_with_fallback',
+    'DataSource',
 ]
@@ -2,7 +2,9 @@
 Target export service.

 Provides unified target extraction and file export, supporting:
-- URL export (streaming writes + default-value fallback)
+- URL export (pure export, no implicit fallback)
+- Default URL generation (standalone method)
+- URL export with a fallback chain (orchestrated at the use-case layer)
 - Domain/IP export (for port scanning)
 - Blacklist filter integration
 """
@@ -10,7 +12,7 @@
 import ipaddress
 import logging
 from pathlib import Path
-from typing import Dict, Any, Optional, List
+from typing import Dict, Any, Optional, List, Iterator, Tuple

 from django.db.models import QuerySet

@@ -19,6 +21,14 @@ from apps.common.utils import BlacklistFilter
 logger = logging.getLogger(__name__)

class DataSource:
    """Data source type constants."""
    ENDPOINT = "endpoint"
    WEBSITE = "website"
    HOST_PORT = "host_port"
    DEFAULT = "default"


def create_export_service(target_id: int) -> 'TargetExportService':
    """
    Factory function: create an export service with blacklist filtering.
@@ -36,21 +46,281 @@ def create_export_service(target_id: int) -> 'TargetExportService':
    return TargetExportService(blacklist_filter=blacklist_filter)


def _iter_default_urls_from_target(
    target_id: int,
    blacklist_filter: Optional[BlacklistFilter] = None
) -> Iterator[str]:
    """
    Internal generator: build default URLs from the Target itself.

    URLs are generated according to the Target type:
        - DOMAIN: http(s)://domain
        - IP: http(s)://ip
        - CIDR: expanded to http(s)://ip for every IP
        - URL: the target URL is used as-is

    Args:
        target_id: target ID
        blacklist_filter: blacklist filter

    Yields:
        str: URL
    """
    from apps.targets.services import TargetService
    from apps.targets.models import Target

    target_service = TargetService()
    target = target_service.get_target(target_id)

    if not target:
        logger.warning("Target ID %d does not exist; cannot generate default URLs", target_id)
        return

    target_name = target.name
    target_type = target.type

    # Generate URLs according to the Target type
    if target_type == Target.TargetType.DOMAIN:
        urls = [f"http://{target_name}", f"https://{target_name}"]
    elif target_type == Target.TargetType.IP:
        urls = [f"http://{target_name}", f"https://{target_name}"]
    elif target_type == Target.TargetType.CIDR:
        try:
            network = ipaddress.ip_network(target_name, strict=False)
            urls = []
            for ip in network.hosts():
                urls.extend([f"http://{ip}", f"https://{ip}"])
            # Special-case /32 and /128
            if not urls:
                ip = str(network.network_address)
                urls = [f"http://{ip}", f"https://{ip}"]
        except ValueError as e:
            logger.error("Failed to parse CIDR: %s - %s", target_name, e)
            return
    elif target_type == Target.TargetType.URL:
        urls = [target_name]
    else:
        logger.warning("Unsupported Target type: %s", target_type)
        return

    # Filter and yield
    for url in urls:
        if blacklist_filter and not blacklist_filter.is_allowed(url):
            continue
        yield url
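A small worked example of the default-URL expansion above for a hypothetical CIDR target:

# Illustrative only - a Target named "10.0.0.0/30" of type CIDR would yield:
#   http://10.0.0.1, https://10.0.0.1, http://10.0.0.2, https://10.0.0.2
# ip_network().hosts() skips the network and broadcast addresses; the /32
# special case falls back to the single network address.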
def _iter_urls_with_fallback(
    target_id: int,
    sources: List[str],
    blacklist_filter: Optional[BlacklistFilter] = None,
    batch_size: int = 1000,
    tried_sources: Optional[List[str]] = None
) -> Iterator[Tuple[str, str]]:
    """
    Internal generator: stream URLs with a fallback chain.

    Tries each data source in the order given by `sources` until one yields data.

    Fallback logic:
        - Source has data and it passes the filter → yield URLs, stop falling back
        - Source has data but all of it is filtered out → stop without falling back
          (avoids accidental exposure)
        - Source is empty → try the next source

    Args:
        target_id: target ID
        sources: data sources in priority order
        blacklist_filter: blacklist filter
        batch_size: batch size
        tried_sources: optional list for recording attempted sources
            (passed in by the caller; mutated in place)

    Yields:
        Tuple[str, str]: (url, source) - the URL and its source label
    """
    from apps.asset.models import Endpoint, WebSite

    for source in sources:
        if tried_sources is not None:
            tried_sources.append(source)

        has_output = False    # produced output (passed the filter)
        has_raw_data = False  # had raw data (before filtering)

        if source == DataSource.DEFAULT:
            # Default URL generation (built from the Target itself, reusing the shared generator)
            for url in _iter_default_urls_from_target(target_id, blacklist_filter):
                has_raw_data = True
                has_output = True
                yield url, source

            # Check for raw data separately, since the generator may be empty after filtering
            if not has_raw_data:
                # Re-check whether the Target exists
                from apps.targets.services import TargetService
                target = TargetService().get_target(target_id)
                has_raw_data = target is not None

            if has_raw_data:
                if not has_output:
                    logger.info("%s has data but all of it was blacklisted; not falling back", source)
                return
            continue

        # Build the queryset for the data source
        if source == DataSource.ENDPOINT:
            queryset = Endpoint.objects.filter(target_id=target_id).values_list('url', flat=True)
        elif source == DataSource.WEBSITE:
            queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)
        else:
            logger.warning("Unknown data source type: %s; skipping", source)
            continue

        for url in queryset.iterator(chunk_size=batch_size):
            if url:
                has_raw_data = True
                if blacklist_filter and not blacklist_filter.is_allowed(url):
                    continue
                has_output = True
                yield url, source

        # Stop as soon as a source has raw data, whether or not it was filtered
        if has_raw_data:
            if not has_output:
                logger.info("%s has data but all of it was blacklisted; not falling back", source)
            return

        logger.info("%s is empty; trying the next data source", source)
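A worked illustration of the fallback semantics above, with hypothetical table contents: the endpoint table is empty, and the website table has rows that the blacklist removes entirely.

# Illustrative only - expected behaviour for the hypothetical data described above.
tried = []
urls = list(_iter_urls_with_fallback(
    target_id=1,
    sources=[DataSource.ENDPOINT, DataSource.WEBSITE, DataSource.DEFAULT],
    blacklist_filter=my_filter,   # hypothetical filter that blocks every website URL
    tried_sources=tried,
))
# urls  == []                       # website rows existed but were all filtered
# tried == ["endpoint", "website"]  # DEFAULT is never attempted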
def get_urls_with_fallback(
    target_id: int,
    sources: List[str],
    batch_size: int = 1000
) -> Dict[str, Any]:
    """
    Use-case function: fetch URLs with a fallback chain (returns a list).

    Tries each data source in the order given by `sources` until one yields data.

    Args:
        target_id: target ID
        sources: data sources in priority order, e.g. ["website", "endpoint", "default"]
        batch_size: batch size

    Returns:
        dict: {
            'success': bool,
            'urls': List[str],
            'total_count': int,
            'source': str,                # the data source actually used
            'tried_sources': List[str],   # the data sources attempted
        }
    """
    from apps.common.services import BlacklistService

    rules = BlacklistService().get_rules(target_id)
    blacklist_filter = BlacklistFilter(rules)

    urls = []
    actual_source = 'none'
    tried_sources = []

    for url, source in _iter_urls_with_fallback(target_id, sources, blacklist_filter, batch_size, tried_sources):
        urls.append(url)
        actual_source = source

    if urls:
        logger.info("Fetched %d URLs from %s", len(urls), actual_source)
    else:
        logger.warning("All data sources are empty; no URLs available")

    return {
        'success': True,
        'urls': urls,
        'total_count': len(urls),
        'source': actual_source,
        'tried_sources': tried_sources,
    }
def export_urls_with_fallback(
    target_id: int,
    output_file: str,
    sources: List[str],
    batch_size: int = 1000
) -> Dict[str, Any]:
    """
    Use-case function: export URLs with a fallback chain (writes to a file).

    Tries each data source in the order given by `sources` until one yields data.
    Streams the writes, keeping memory usage O(1).

    Args:
        target_id: target ID
        output_file: output file path
        sources: data sources in priority order, e.g. ["endpoint", "website", "default"]
        batch_size: batch size

    Returns:
        dict: {
            'success': bool,
            'output_file': str,
            'total_count': int,
            'source': str,                # the data source actually used
            'tried_sources': List[str],   # the data sources attempted
        }
    """
    from apps.common.services import BlacklistService

    rules = BlacklistService().get_rules(target_id)
    blacklist_filter = BlacklistFilter(rules)

    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    total_count = 0
    actual_source = 'none'
    tried_sources = []

    with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
        for url, source in _iter_urls_with_fallback(target_id, sources, blacklist_filter, batch_size, tried_sources):
            f.write(f"{url}\n")
            total_count += 1
            actual_source = source

            if total_count % 10000 == 0:
                logger.info("Exported %d URLs so far...", total_count)

    if total_count > 0:
        logger.info("Exported %d URLs from %s to %s", total_count, actual_source, output_file)
    else:
        logger.warning("All data sources are empty; no URLs exported")

    return {
        'success': True,
        'output_file': str(output_path),
        'total_count': total_count,
        'source': actual_source,
        'tried_sources': tried_sources,
    }
class TargetExportService:
    """
    目标导出服务 - 提供统一的目标提取和文件导出功能

    使用方式:
        # 方式 1:使用用例函数(推荐)
        from apps.scan.services.target_export_service import export_urls_with_fallback, DataSource

        result = export_urls_with_fallback(
            target_id=1,
            output_file='/path/to/output.txt',
            sources=[DataSource.ENDPOINT, DataSource.WEBSITE, DataSource.DEFAULT]
        )

        # 方式 2:直接使用 Service(纯导出,不带回退)
        from apps.common.services import BlacklistService
        from apps.common.utils import BlacklistFilter

        # 获取规则并创建过滤器
        rules = BlacklistService().get_rules(target_id)
        blacklist_filter = BlacklistFilter(rules)
        export_service = TargetExportService(blacklist_filter=blacklist_filter)
        # 或使用工厂函数:export_service = create_export_service(target_id)
        result = export_service.export_urls(target_id, output_path, queryset)
    """

@@ -72,16 +342,14 @@ class TargetExportService:
        batch_size: int = 1000
    ) -> Dict[str, Any]:
        """
        统一 URL 导出函数
        纯 URL 导出函数 - 只负责将 queryset 数据写入文件

        自动判断数据库有无数据:
        - 有数据:流式写入数据库数据到文件
        - 无数据:调用默认值生成器生成 URL
        不做任何隐式回退或默认 URL 生成。

        Args:
            target_id: 目标 ID
            output_path: 输出文件路径
            queryset: 数据源 queryset(由 Task 层构建,应为 values_list flat=True)
            queryset: 数据源 queryset(由调用方构建,应为 values_list flat=True)
            url_field: URL 字段名(用于黑名单过滤)
            batch_size: 批次大小

@@ -89,7 +357,9 @@ class TargetExportService:
            dict: {
                'success': bool,
                'output_file': str,
                'total_count': int
                'total_count': int,      # 实际写入数量
                'queryset_count': int,   # 原始数据数量(迭代计数)
                'filtered_count': int,   # 被黑名单过滤的数量
            }

        Raises:
@@ -102,9 +372,12 @@ class TargetExportService:

        total_count = 0
        filtered_count = 0
        queryset_count = 0

        try:
            with open(output_file, 'w', encoding='utf-8', buffering=8192) as f:
                for url in queryset.iterator(chunk_size=batch_size):
                    queryset_count += 1
                    if url:
                        # 黑名单过滤
                        if self.blacklist_filter and not self.blacklist_filter.is_allowed(url):
@@ -122,25 +395,26 @@ class TargetExportService:
        if filtered_count > 0:
            logger.info("黑名单过滤: 过滤 %d 个 URL", filtered_count)

        # 默认值回退模式
        if total_count == 0:
            total_count = self._generate_default_urls(target_id, output_file)

        logger.info("✓ URL 导出完成 - 数量: %d, 文件: %s", total_count, output_path)
        logger.info(
            "✓ URL 导出完成 - 写入: %d, 原始: %d, 过滤: %d, 文件: %s",
            total_count, queryset_count, filtered_count, output_path
        )

        return {
            'success': True,
            'output_file': str(output_file),
            'total_count': total_count
            'total_count': total_count,
            'queryset_count': queryset_count,
            'filtered_count': filtered_count,
        }

    def _generate_default_urls(
    def generate_default_urls(
        self,
        target_id: int,
        output_path: Path
    ) -> int:
        output_path: str
    ) -> Dict[str, Any]:
        """
        默认值生成器(内部函数)
        默认 URL 生成器

        根据 Target 类型生成默认 URL:
        - DOMAIN: http(s)://domain
@@ -153,82 +427,34 @@ class TargetExportService:
            output_path: 输出文件路径

        Returns:
            int: 写入的 URL 总数
            dict: {
                'success': bool,
                'output_file': str,
                'total_count': int,
            }
        """
        from apps.targets.services import TargetService
        from apps.targets.models import Target

        output_file = Path(output_path)
        output_file.parent.mkdir(parents=True, exist_ok=True)

        target_service = TargetService()
        target = target_service.get_target(target_id)

        if not target:
            logger.warning("Target ID %d 不存在,无法生成默认 URL", target_id)
            return 0

        target_name = target.name
        target_type = target.type

        logger.info("懒加载模式:Target 类型=%s, 名称=%s", target_type, target_name)
        logger.info("生成默认 URL - target_id=%d", target_id)

        total_urls = 0

        with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
            if target_type == Target.TargetType.DOMAIN:
                urls = [f"http://{target_name}", f"https://{target_name}"]
                for url in urls:
                    if self._should_write_url(url):
                        f.write(f"{url}\n")
                        total_urls += 1

            elif target_type == Target.TargetType.IP:
                urls = [f"http://{target_name}", f"https://{target_name}"]
                for url in urls:
                    if self._should_write_url(url):
                        f.write(f"{url}\n")
                        total_urls += 1

            elif target_type == Target.TargetType.CIDR:
                try:
                    network = ipaddress.ip_network(target_name, strict=False)

                    for ip in network.hosts():
                        urls = [f"http://{ip}", f"https://{ip}"]
                        for url in urls:
                            if self._should_write_url(url):
                                f.write(f"{url}\n")
                                total_urls += 1

                        if total_urls % 10000 == 0:
                            logger.info("已生成 %d 个 URL...", total_urls)

                    # /32 或 /128 特殊处理
                    if total_urls == 0:
                        ip = str(network.network_address)
                        urls = [f"http://{ip}", f"https://{ip}"]
                        for url in urls:
                            if self._should_write_url(url):
                                f.write(f"{url}\n")
                                total_urls += 1

                except ValueError as e:
                    logger.error("CIDR 解析失败: %s - %s", target_name, e)
                    raise ValueError(f"无效的 CIDR: {target_name}") from e

            elif target_type == Target.TargetType.URL:
                if self._should_write_url(target_name):
                    f.write(f"{target_name}\n")
                    total_urls = 1
            else:
                logger.warning("不支持的 Target 类型: %s", target_type)
        with open(output_file, 'w', encoding='utf-8', buffering=8192) as f:
            for url in _iter_default_urls_from_target(target_id, self.blacklist_filter):
                f.write(f"{url}\n")
                total_urls += 1

                if total_urls % 10000 == 0:
                    logger.info("已生成 %d 个 URL...", total_urls)

        logger.info("✓ 懒加载生成默认 URL - 数量: %d", total_urls)
        return total_urls

    def _should_write_url(self, url: str) -> bool:
        """检查 URL 是否应该写入(通过黑名单过滤)"""
        if self.blacklist_filter:
            return self.blacklist_filter.is_allowed(url)
        return True
        logger.info("✓ 默认 URL 生成完成 - 数量: %d", total_urls)

        return {
            'success': True,
            'output_file': str(output_file),
            'total_count': total_urls,
        }

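新实现引用的 `_iter_default_urls_from_target` 没有出现在本次 diff 里;按旧版内联逻辑推测,其行为大致如下(仅为示意草图,非仓库中的真实实现;为简洁起见 CIDR 分支物化了 hosts 列表):

```python
import ipaddress
from typing import Iterator, Optional


def _iter_default_urls_from_target(target_id: int, blacklist_filter: Optional[object]) -> Iterator[str]:
    """示意:按 Target 类型生成默认 URL,逐条经过黑名单过滤。"""
    from apps.targets.models import Target
    from apps.targets.services import TargetService

    target = TargetService().get_target(target_id)
    if not target:
        return

    if target.type in (Target.TargetType.DOMAIN, Target.TargetType.IP):
        candidates = [f"http://{target.name}", f"https://{target.name}"]
    elif target.type == Target.TargetType.CIDR:
        network = ipaddress.ip_network(target.name, strict=False)
        # hosts() 对 /32、/128 返回空,退回 network_address
        hosts = list(network.hosts()) or [network.network_address]
        candidates = [f"{scheme}://{ip}" for ip in hosts for scheme in ("http", "https")]
    elif target.type == Target.TargetType.URL:
        candidates = [target.name]
    else:
        return

    for url in candidates:
        if blacklist_filter is None or blacklist_filter.is_allowed(url):
            yield url
```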
    def export_hosts(
        self,
@@ -259,8 +485,7 @@ class TargetExportService:
        """
        from apps.targets.services import TargetService
        from apps.targets.models import Target
        from apps.asset.services.asset.subdomain_service import SubdomainService


        output_file = Path(output_path)
        output_file.parent.mkdir(parents=True, exist_ok=True)


@@ -1,39 +1,48 @@
"""
导出站点 URL 到 TXT 文件的 Task

使用 TargetExportService 统一处理导出逻辑和默认值回退
数据源: WebSite.url
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库导出
2. Provider 模式:使用 TargetProvider 从任意数据源导出

数据源: WebSite.url → Default
"""
import logging
from typing import Optional
from pathlib import Path
from prefect import task

from apps.asset.models import WebSite
from apps.scan.services import TargetExportService
from apps.scan.services.target_export_service import create_export_service
from apps.scan.services.target_export_service import (
    export_urls_with_fallback,
    DataSource,
)
from apps.scan.providers import TargetProvider

logger = logging.getLogger(__name__)


@task(name="export_sites")
def export_sites_task(
    target_id: int,
    output_file: str,
    target_id: Optional[int] = None,
    output_file: str = "",
    provider: Optional[TargetProvider] = None,
    batch_size: int = 1000,
) -> dict:
    """
    导出目标下的所有站点 URL 到 TXT 文件

    数据源: WebSite.url

    懒加载模式:
    - 如果数据库为空,根据 Target 类型生成默认 URL
    - DOMAIN: http(s)://domain
    - IP: http(s)://ip
    - CIDR: 展开为所有 IP 的 URL
    支持两种模式:
    1. 传统模式(向后兼容):传入 target_id,从数据库导出
    2. Provider 模式:传入 provider,从任意数据源导出

    数据源优先级(回退链,仅传统模式):
    1. WebSite 表 - 站点级别 URL
    2. 默认生成 - 根据 Target 类型生成 http(s)://target_name

    Args:
        target_id: 目标 ID
        target_id: 目标 ID(传统模式,向后兼容)
        output_file: 输出文件路径(绝对路径)
        provider: TargetProvider 实例(新模式)
        batch_size: 每次读取的批次大小,默认 1000

    Returns:
@@ -47,25 +56,61 @@ def export_sites_task(
        ValueError: 参数错误
        IOError: 文件写入失败
    """
    # 构建数据源 queryset(Task 层决定数据源)
    queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)
    # 参数验证:至少提供一个
    if target_id is None and provider is None:
        raise ValueError("必须提供 target_id 或 provider 参数之一")

    # 使用工厂函数创建导出服务
    export_service = create_export_service(target_id)
    # Provider 模式:使用 TargetProvider 导出
    if provider is not None:
        logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
        return _export_with_provider(output_file, provider)

    result = export_service.export_urls(
    # 传统模式:使用 export_urls_with_fallback
    logger.info("使用传统模式 - Target ID: %d", target_id)
    result = export_urls_with_fallback(
        target_id=target_id,
        output_path=output_file,
        queryset=queryset,
        batch_size=batch_size
        output_file=output_file,
        sources=[DataSource.WEBSITE, DataSource.DEFAULT],
        batch_size=batch_size,
    )

    logger.info(
        "站点 URL 导出完成 - source=%s, count=%d",
        result['source'], result['total_count']
    )

    # 保持返回值格式不变(向后兼容)
    return {
        'success': result['success'],
        'output_file': result['output_file'],
        'total_count': result['total_count']
        'total_count': result['total_count'],
    }


def _export_with_provider(output_file: str, provider: TargetProvider) -> dict:
    """使用 Provider 导出 URL"""
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    total_count = 0
    blacklist_filter = provider.get_blacklist_filter()

    with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
        for url in provider.iter_urls():
            # 应用黑名单过滤(如果有)
            if blacklist_filter and not blacklist_filter.is_allowed(url):
                continue

            f.write(f"{url}\n")
            total_count += 1

            if total_count % 1000 == 0:
                logger.info("已导出 %d 个 URL...", total_count)

    logger.info("✓ URL 导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))

    return {
        'success': True,
        'output_file': str(output_path),
        'total_count': total_count,
    }

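Provider 模式的一个最小调用示意(`ListTargetProvider` 的导入路径见下文测试文件;URL 与路径均为假设值):

```python
# 示意:用内存列表作为数据源驱动导出任务
from apps.scan.providers import ListTargetProvider

provider = ListTargetProvider(targets=["https://a.example.com", "https://b.example.com"])
result = export_sites_task(
    output_file="/tmp/scan/sites.txt",
    provider=provider,  # 提供 provider 时走 Provider 模式,target_id 可省略
)
assert result['total_count'] == 2
```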
@@ -1,64 +1,112 @@
"""
导出 URL 任务

支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库导出
2. Provider 模式:使用 TargetProvider 从任意数据源导出

用于指纹识别前导出目标下的 URL 到文件
使用 TargetExportService 统一处理导出逻辑和默认值回退
"""

import logging
from typing import Optional
from pathlib import Path

from prefect import task

from apps.asset.models import WebSite
from apps.scan.services.target_export_service import create_export_service
from apps.scan.services.target_export_service import (
    export_urls_with_fallback,
    DataSource,
)
from apps.scan.providers import TargetProvider, DatabaseTargetProvider

logger = logging.getLogger(__name__)


@task(name="export_urls_for_fingerprint")
def export_urls_for_fingerprint_task(
    target_id: int,
    output_file: str,
    source: str = 'website',
    target_id: Optional[int] = None,
    output_file: str = "",
    source: str = 'website',  # 保留参数,兼容旧调用(实际值由回退链决定)
    provider: Optional[TargetProvider] = None,
    batch_size: int = 1000
) -> dict:
    """
    导出目标下的 URL 到文件(用于指纹识别)

    数据源: WebSite.url
    支持两种模式:
    1. 传统模式(向后兼容):传入 target_id,从数据库导出
    2. Provider 模式:传入 provider,从任意数据源导出

    懒加载模式:
    - 如果数据库为空,根据 Target 类型生成默认 URL
    - DOMAIN: http(s)://domain
    - IP: http(s)://ip
    - CIDR: 展开为所有 IP 的 URL
    - URL: 直接使用目标 URL
    数据源优先级(回退链,仅传统模式):
    1. WebSite 表 - 站点级别 URL
    2. 默认生成 - 根据 Target 类型生成 http(s)://target_name

    Args:
        target_id: 目标 ID
        target_id: 目标 ID(传统模式,向后兼容)
        output_file: 输出文件路径
        source: 数据源类型(保留参数,兼容旧调用)
        source: 数据源类型(保留参数,兼容旧调用,实际值由回退链决定)
        provider: TargetProvider 实例(新模式)
        batch_size: 批量读取大小

    Returns:
        dict: {'output_file': str, 'total_count': int, 'source': str}
    """
    # 构建数据源 queryset(Task 层决定数据源)
    queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)
    # 参数验证:至少提供一个
    if target_id is None and provider is None:
        raise ValueError("必须提供 target_id 或 provider 参数之一")

    # 使用工厂函数创建导出服务
    export_service = create_export_service(target_id)
    # Provider 模式:使用 TargetProvider 导出
    if provider is not None:
        logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
        return _export_with_provider(output_file, provider)

    result = export_service.export_urls(
    # 传统模式:使用 export_urls_with_fallback
    logger.info("使用传统模式 - Target ID: %d", target_id)
    result = export_urls_with_fallback(
        target_id=target_id,
        output_path=output_file,
        queryset=queryset,
        batch_size=batch_size
        output_file=output_file,
        sources=[DataSource.WEBSITE, DataSource.DEFAULT],
        batch_size=batch_size,
    )

    # 保持返回值格式不变(向后兼容)
    logger.info(
        "指纹识别 URL 导出完成 - source=%s, count=%d",
        result['source'], result['total_count']
    )

    # 返回实际使用的数据源(不再固定为 "website")
    return {
        'output_file': result['output_file'],
        'total_count': result['total_count'],
        'source': source
        'source': result['source'],
    }


def _export_with_provider(output_file: str, provider: TargetProvider) -> dict:
    """使用 Provider 导出 URL"""
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    total_count = 0
    blacklist_filter = provider.get_blacklist_filter()

    with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
        for url in provider.iter_urls():
            # 应用黑名单过滤(如果有)
            if blacklist_filter and not blacklist_filter.is_allowed(url):
                continue

            f.write(f"{url}\n")
            total_count += 1

            if total_count % 1000 == 0:
                logger.info("已导出 %d 个 URL...", total_count)

    logger.info("✓ URL 导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))

    return {
        'output_file': str(output_path),
        'total_count': total_count,
        'source': 'provider',
    }

@@ -1,7 +1,9 @@
"""
导出主机列表到 TXT 文件的 Task

使用 TargetExportService.export_hosts() 统一处理导出逻辑
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库导出
2. Provider 模式:使用 TargetProvider 从任意数据源导出

根据 Target 类型决定导出内容:
- DOMAIN: 从 Subdomain 表导出子域名
@@ -9,57 +11,89 @@
- CIDR: 展开 CIDR 范围内的所有 IP
"""
import logging
from pathlib import Path
from typing import Optional

from prefect import task

from apps.scan.services.target_export_service import create_export_service
from apps.scan.providers import DatabaseTargetProvider, TargetProvider

logger = logging.getLogger(__name__)


@task(name="export_hosts")
def export_hosts_task(
    target_id: int,
    output_file: str,
    batch_size: int = 1000
    target_id: Optional[int] = None,
    provider: Optional[TargetProvider] = None,
) -> dict:
    """
    导出主机列表到 TXT 文件

    支持两种模式:
    1. 传统模式(向后兼容):传入 target_id,从数据库导出
    2. Provider 模式:传入 provider,从任意数据源导出

    根据 Target 类型自动决定导出内容:
    - DOMAIN: 从 Subdomain 表导出子域名(流式处理,支持 10万+ 域名)
    - IP: 直接写入 target.name(单个 IP)
    - CIDR: 展开 CIDR 范围内的所有可用 IP

    Args:
        target_id: 目标 ID
        output_file: 输出文件路径(绝对路径)
        batch_size: 每次读取的批次大小,默认 1000(仅对 DOMAIN 类型有效)
        target_id: 目标 ID(传统模式,向后兼容)
        provider: TargetProvider 实例(新模式)

    Returns:
        dict: {
            'success': bool,
            'output_file': str,
            'total_count': int,
            'target_type': str
            'target_type': str  # 仅传统模式返回
        }

    Raises:
        ValueError: Target 不存在
        ValueError: 参数错误(target_id 和 provider 都未提供)
        IOError: 文件写入失败
    """
    # 使用工厂函数创建导出服务
    export_service = create_export_service(target_id)

    result = export_service.export_hosts(
        target_id=target_id,
        output_path=output_file,
        batch_size=batch_size
    )

    # 保持返回值格式不变(向后兼容)
    return {
        'success': result['success'],
        'output_file': result['output_file'],
        'total_count': result['total_count'],
        'target_type': result['target_type']
    if target_id is None and provider is None:
        raise ValueError("必须提供 target_id 或 provider 参数之一")

    # 向后兼容:如果没有提供 provider,使用 target_id 创建 DatabaseTargetProvider
    use_legacy_mode = provider is None
    if use_legacy_mode:
        logger.info("使用传统模式 - Target ID: %d", target_id)
        provider = DatabaseTargetProvider(target_id=target_id)
    else:
        logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)

    # 确保输出目录存在
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # 使用 Provider 导出主机列表(iter_hosts 内部已处理黑名单过滤)
    total_count = 0

    with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
        for host in provider.iter_hosts():
            f.write(f"{host}\n")
            total_count += 1

            if total_count % 1000 == 0:
                logger.info("已导出 %d 个主机...", total_count)

    logger.info("✓ 主机列表导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))

    result = {
        'success': True,
        'output_file': str(output_path),
        'total_count': total_count,
    }

    # 传统模式:保持返回值格式不变(向后兼容)
    if use_legacy_mode:
        from apps.targets.services import TargetService
        target = TargetService().get_target(target_id)
        result['target_type'] = target.type if target else 'unknown'

    return result

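传统模式的调用示意(`target_id` 与输出路径均为假设值;传统模式会额外返回 `target_type`):

```python
result = export_hosts_task(
    target_id=7,
    output_file="/tmp/scan/hosts.txt",
)
print(result['total_count'], result['target_type'])  # Provider 模式则不含 target_type
```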
12  backend/apps/scan/tasks/screenshot/__init__.py  Normal file
@@ -0,0 +1,12 @@
"""
截图任务模块

包含截图相关的所有任务:
- capture_screenshots_task: 批量截图任务
"""

from .capture_screenshots_task import capture_screenshots_task

__all__ = [
    'capture_screenshots_task',
]
194  backend/apps/scan/tasks/screenshot/capture_screenshots_task.py  Normal file
@@ -0,0 +1,194 @@
"""
批量截图任务

使用 Playwright 批量捕获网站截图,压缩后保存到数据库
"""
import asyncio
import logging
import time
from prefect import task

logger = logging.getLogger(__name__)


def _run_async(coro):
    """
    在同步环境中运行异步协程

    Args:
        coro: 异步协程

    Returns:
        协程执行结果
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    return loop.run_until_complete(coro)


def _save_screenshot_with_retry(
    screenshot_service,
    scan_id: int,
    url: str,
    webp_data: bytes,
    status_code: int | None = None,
    max_retries: int = 3
) -> bool:
    """
    保存截图到数据库(带重试机制)

    Args:
        screenshot_service: ScreenshotService 实例
        scan_id: 扫描 ID
        url: URL
        webp_data: WebP 图片数据
        status_code: HTTP 响应状态码
        max_retries: 最大重试次数

    Returns:
        是否保存成功
    """
    for attempt in range(max_retries):
        try:
            if screenshot_service.save_screenshot_snapshot(scan_id, url, webp_data, status_code):
                return True
            # save 返回 False,等待后重试
            if attempt < max_retries - 1:
                wait_time = 2 ** attempt  # 指数退避:1s, 2s, 4s
                logger.warning(
                    "保存截图失败(第 %d 次尝试),%d秒后重试: %s",
                    attempt + 1, wait_time, url
                )
                time.sleep(wait_time)
        except Exception as e:
            if attempt < max_retries - 1:
                wait_time = 2 ** attempt
                logger.warning(
                    "保存截图异常(第 %d 次尝试),%d秒后重试: %s, 错误: %s",
                    attempt + 1, wait_time, url, str(e)[:100]
                )
                time.sleep(wait_time)
            else:
                logger.error("保存截图失败(已重试 %d 次): %s", max_retries, url)

    return False


async def _capture_and_save_screenshots(
    urls: list[str],
    scan_id: int,
    concurrency: int
) -> dict:
    """
    异步批量截图并保存

    Args:
        urls: URL 列表
        scan_id: 扫描 ID
        concurrency: 并发数

    Returns:
        统计信息字典
    """
    from asgiref.sync import sync_to_async
    from apps.asset.services.playwright_screenshot_service import PlaywrightScreenshotService
    from apps.asset.services.screenshot_service import ScreenshotService

    # 初始化服务
    playwright_service = PlaywrightScreenshotService(concurrency=concurrency)
    screenshot_service = ScreenshotService()

    # 包装同步的保存函数为异步
    async_save_with_retry = sync_to_async(_save_screenshot_with_retry, thread_sensitive=True)

    # 统计
    total = len(urls)
    successful = 0
    failed = 0

    logger.info("开始批量截图 - URL数: %d, 并发数: %d", total, concurrency)

    # 批量截图
    async for url, screenshot_bytes, status_code in playwright_service.capture_batch(urls):
        if screenshot_bytes is None:
            failed += 1
            continue

        # 压缩为 WebP
        webp_data = screenshot_service.compress_from_bytes(screenshot_bytes)
        if webp_data is None:
            logger.warning("压缩截图失败: %s", url)
            failed += 1
            continue

        # 保存到数据库(带重试,使用 sync_to_async)
        if await async_save_with_retry(screenshot_service, scan_id, url, webp_data, status_code):
            successful += 1
            if successful % 10 == 0:
                logger.info("截图进度: %d/%d 成功", successful, total)
        else:
            failed += 1

    return {
        'total': total,
        'successful': successful,
        'failed': failed
    }


@task(name='capture_screenshots', retries=0)
def capture_screenshots_task(
    urls: list[str],
    scan_id: int,
    target_id: int,
    config: dict
) -> dict:
    """
    批量截图任务

    Args:
        urls: URL 列表
        scan_id: 扫描 ID
        target_id: 目标 ID(用于日志)
        config: 截图配置
            - concurrency: 并发数(默认 5)

    Returns:
        dict: {
            'total': int,       # 总 URL 数
            'successful': int,  # 成功截图数
            'failed': int       # 失败数
        }
    """
    if not urls:
        logger.info("URL 列表为空,跳过截图任务")
        return {'total': 0, 'successful': 0, 'failed': 0}

    concurrency = config.get('concurrency', 5)

    logger.info(
        "开始截图任务 - scan_id=%d, target_id=%d, URL数=%d, 并发=%d",
        scan_id, target_id, len(urls), concurrency
    )

    try:
        result = _run_async(_capture_and_save_screenshots(
            urls=urls,
            scan_id=scan_id,
            concurrency=concurrency
        ))

        logger.info(
            "✓ 截图任务完成 - 总数: %d, 成功: %d, 失败: %d",
            result['total'], result['successful'], result['failed']
        )

        return result

    except Exception as e:
        logger.error("截图任务失败: %s", e, exc_info=True)
        raise RuntimeError(f"截图任务失败: {e}") from e
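截图任务的调用示意(URL 与各 ID 均为示例值;`config` 目前仅识别 `concurrency`):

```python
result = capture_screenshots_task(
    urls=["https://example.com", "https://test.example.com"],
    scan_id=42,
    target_id=7,
    config={"concurrency": 5},  # 并发数,默认 5
)
# 返回形如 {'total': 2, 'successful': 2, 'failed': 0}
```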
@@ -1,8 +1,9 @@
"""
导出站点URL到文件的Task

直接使用 HostPortMapping 表查询 host+port 组合,拼接成URL格式写入文件
使用 TargetExportService 处理默认值回退逻辑
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库导出
2. Provider 模式:使用 TargetProvider 从任意数据源导出

特殊逻辑:
- 80 端口:只生成 HTTP URL(省略端口号)
@@ -10,6 +11,7 @@
- 其他端口:生成 HTTP 和 HTTPS 两个URL(带端口号)
"""
import logging
from typing import Optional
from pathlib import Path
from prefect import task

@@ -17,6 +19,7 @@ from apps.asset.services import HostPortMappingService
from apps.scan.services.target_export_service import create_export_service
from apps.common.services import BlacklistService
from apps.common.utils import BlacklistFilter
from apps.scan.providers import TargetProvider, DatabaseTargetProvider, ProviderContext

logger = logging.getLogger(__name__)

@@ -39,29 +42,30 @@ def _generate_urls_from_port(host: str, port: int) -> list[str]:

@task(name="export_site_urls")
def export_site_urls_task(
    target_id: int,
    output_file: str,
    target_id: Optional[int] = None,
    provider: Optional[TargetProvider] = None,
    batch_size: int = 1000
) -> dict:
    """
    导出目标下的所有站点URL到文件(基于 HostPortMapping 表)
    导出目标下的所有站点URL到文件

    数据源: HostPortMapping (host + port)
    支持两种模式:
    1. 传统模式(向后兼容):传入 target_id,从 HostPortMapping 表导出
    2. Provider 模式:传入 provider,从任意数据源导出

    特殊逻辑:
    传统模式特殊逻辑:
    - 80 端口:只生成 HTTP URL(省略端口号)
    - 443 端口:只生成 HTTPS URL(省略端口号)
    - 其他端口:生成 HTTP 和 HTTPS 两个URL(带端口号)

    懒加载模式:
    - 如果数据库为空,根据 Target 类型生成默认 URL
    - DOMAIN: http(s)://domain
    - IP: http(s)://ip
    - CIDR: 展开为所有 IP 的 URL
    回退逻辑(仅传统模式):
    - 如果 HostPortMapping 为空,使用 generate_default_urls() 生成默认 URL

    Args:
        target_id: 目标ID
        output_file: 输出文件路径(绝对路径)
        target_id: 目标ID(传统模式,向后兼容)
        provider: TargetProvider 实例(新模式)
        batch_size: 每次处理的批次大小

    Returns:
@@ -69,13 +73,62 @@ def export_site_urls_task(
        'success': bool,
        'output_file': str,
        'total_urls': int,
        'association_count': int  # 主机端口关联数量
        'association_count': int,  # 主机端口关联数量(仅传统模式)
        'source': str,  # 数据来源: "host_port" | "default" | "provider"
        }

    Raises:
        ValueError: 参数错误
        IOError: 文件写入失败
    """
    # 参数验证:至少提供一个
    if target_id is None and provider is None:
        raise ValueError("必须提供 target_id 或 provider 参数之一")

    # 向后兼容:如果没有提供 provider,使用传统模式
    if provider is None:
        logger.info("使用传统模式 - Target ID: %d, 输出文件: %s", target_id, output_file)
        return _export_site_urls_legacy(target_id, output_file, batch_size)

    # Provider 模式
    logger.info("使用 Provider 模式 - Provider: %s, 输出文件: %s", type(provider).__name__, output_file)

    # 确保输出目录存在
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # 使用 Provider 导出 URL 列表
    total_urls = 0
    blacklist_filter = provider.get_blacklist_filter()

    with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
        for url in provider.iter_urls():
            # 应用黑名单过滤(如果有)
            if blacklist_filter and not blacklist_filter.is_allowed(url):
                continue

            f.write(f"{url}\n")
            total_urls += 1

            if total_urls % 1000 == 0:
                logger.info("已导出 %d 个URL...", total_urls)

    logger.info("✓ URL导出完成 - 总数: %d, 文件: %s", total_urls, str(output_path))

    return {
        'success': True,
        'output_file': str(output_path),
        'total_urls': total_urls,
        'source': 'provider',
    }


def _export_site_urls_legacy(target_id: int, output_file: str, batch_size: int) -> dict:
    """
    传统模式:从 HostPortMapping 表导出 URL

    保持原有逻辑不变,确保向后兼容
    """
    logger.info("开始统计站点URL - Target ID: %d, 输出文件: %s", target_id, output_file)

    # 确保输出目录存在
@@ -94,6 +147,7 @@ def export_site_urls_task(

    total_urls = 0
    association_count = 0
    filtered_count = 0

    # 流式写入文件(特殊端口逻辑)
    with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
@@ -104,6 +158,7 @@ def export_site_urls_task(

            # 先校验 host,通过了再生成 URL
            if not blacklist_filter.is_allowed(host):
                filtered_count += 1
                continue

            # 根据端口号生成URL
@@ -114,19 +169,40 @@ def export_site_urls_task(
            if association_count % 1000 == 0:
                logger.info("已处理 %d 条关联,生成 %d 个URL...", association_count, total_urls)

    if filtered_count > 0:
        logger.info("黑名单过滤: 过滤 %d 条关联", filtered_count)

    logger.info(
        "✓ 站点URL导出完成 - 关联数: %d, 总URL数: %d, 文件: %s",
        association_count, total_urls, str(output_path)
    )

    # 默认值回退模式:使用工厂函数创建导出服务
    # 判断数据来源
    source = "host_port"

    # 数据存在但全被过滤,不回退
    if association_count > 0 and total_urls == 0:
        logger.info("HostPortMapping 有 %d 条数据,但全被黑名单过滤,不回退", association_count)
        return {
            'success': True,
            'output_file': str(output_path),
            'total_urls': 0,
            'association_count': association_count,
            'source': source,
        }

    # 数据源为空,回退到默认 URL 生成
    if total_urls == 0:
        logger.info("HostPortMapping 为空,使用默认 URL 生成")
        export_service = create_export_service(target_id)
        total_urls = export_service._generate_default_urls(target_id, output_path)
        result = export_service.generate_default_urls(target_id, str(output_path))
        total_urls = result['total_count']
        source = "default"

    return {
        'success': True,
        'output_file': str(output_path),
        'total_urls': total_urls,
        'association_count': association_count
        'association_count': association_count,
        'source': source,
    }

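钩子函数 `_generate_urls_from_port` 的函数体未包含在本次 diff 中;按上文 docstring 的端口规则,其实现大致如下(示意草图):

```python
def _generate_urls_from_port(host: str, port: int) -> list[str]:
    """示意:80 只生成 HTTP、443 只生成 HTTPS(均省略端口号),其余端口两种都生成。"""
    if port == 80:
        return [f"http://{host}"]
    if port == 443:
        return [f"https://{host}"]
    return [f"http://{host}:{port}", f"https://{host}:{port}"]
```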

@@ -341,11 +341,12 @@ def _save_batch(
            )

            snapshot_items.append(snapshot_dto)

        except Exception as e:
            logger.error("处理记录失败: %s,错误: %s", record.url, e)
            continue

    # ========== Step 3: 保存快照并同步到资产表(通过快照 Service)==========
    # ========== Step 2: 保存快照并同步到资产表(通过快照 Service)==========
    if snapshot_items:
        services.snapshot.save_and_sync(snapshot_items)

@@ -20,63 +20,40 @@ Note:
"""

import logging
import uuid
import subprocess
from pathlib import Path
import uuid
from datetime import datetime
from prefect import task
from pathlib import Path
from typing import List

from prefect import task

logger = logging.getLogger(__name__)

# 注:使用纯系统命令实现,无需 Python 缓冲区配置
# 工具(amass/subfinder)输出已是小写且标准化

@task(
    name='merge_and_deduplicate',
    retries=1,
    log_prints=True
)
def merge_and_validate_task(
    result_files: List[str],
    result_dir: str
) -> str:
    """
    合并扫描结果并去重(高性能流式处理)

    流程:
    1. 使用 LC_ALL=C sort -u 直接处理多文件
    2. 排序去重一步完成
    3. 返回去重后的文件路径

    命令:LC_ALL=C sort -u file1 file2 file3 -o output
    注:工具输出已标准化(小写,无空行),无需额外处理

    Args:
        result_files: 结果文件路径列表
        result_dir: 结果目录

    Returns:
        str: 去重后的域名文件路径

    Raises:
        RuntimeError: 处理失败

    Performance:
        - 纯系统命令(C语言实现),单进程极简
        - LC_ALL=C: 字节序比较
        - sort -u: 直接处理多文件(无管道开销)

    Design:
        - 极简单命令,无冗余处理
        - 单进程直接执行(无管道/重定向开销)
        - 内存占用仅在 sort 阶段(外部排序,不会 OOM)
    """
    logger.info("开始合并并去重 %d 个结果文件(系统命令优化)", len(result_files))

    result_path = Path(result_dir)

    # 验证文件存在性
def _count_file_lines(file_path: str) -> int:
    """使用 wc -l 统计文件行数,失败时返回 0"""
    try:
        result = subprocess.run(
            ["wc", "-l", file_path],
            check=True,
            capture_output=True,
            text=True,
        )
        return int(result.stdout.strip().split()[0])
    except (subprocess.CalledProcessError, ValueError, IndexError):
        return 0


def _calculate_timeout(total_lines: int) -> int:
    """根据总行数计算超时时间(每行约 0.1 秒,最少 600 秒)"""
    if total_lines <= 0:
        return 3600
    return max(600, int(total_lines * 0.1))


def _validate_input_files(result_files: List[str]) -> List[str]:
    """验证输入文件存在性,返回有效文件列表"""
    valid_files = []
    for file_path_str in result_files:
        file_path = Path(file_path_str)
@@ -84,112 +61,67 @@ def merge_and_validate_task(
            valid_files.append(str(file_path))
        else:
            logger.warning("结果文件不存在: %s", file_path)

    return valid_files


@task(name='merge_and_deduplicate', retries=1, log_prints=True)
def merge_and_validate_task(result_files: List[str], result_dir: str) -> str:
    """
    合并扫描结果并去重(高性能流式处理)

    使用 LC_ALL=C sort -u 直接处理多文件,排序去重一步完成。

    Args:
        result_files: 结果文件路径列表
        result_dir: 结果目录

    Returns:
        去重后的域名文件路径

    Raises:
        RuntimeError: 处理失败
    """
    logger.info("开始合并并去重 %d 个结果文件", len(result_files))

    valid_files = _validate_input_files(result_files)
    if not valid_files:
        raise RuntimeError("所有结果文件都不存在")


    # 生成输出文件路径
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    short_uuid = uuid.uuid4().hex[:4]
    merged_file = result_path / f"merged_{timestamp}_{short_uuid}.txt"

    merged_file = Path(result_dir) / f"merged_{timestamp}_{short_uuid}.txt"

    # 计算超时时间
    total_lines = sum(_count_file_lines(f) for f in valid_files)
    timeout = _calculate_timeout(total_lines)
    logger.info("合并去重: 输入总行数=%d, timeout=%d秒", total_lines, timeout)

    # 执行合并去重命令
    cmd = f"LC_ALL=C sort -u {' '.join(valid_files)} -o {merged_file}"
    logger.debug("执行命令: %s", cmd)

    try:
        # ==================== 使用系统命令一步完成:排序去重 ====================
        # LC_ALL=C: 使用字节序比较(比locale快20-30%)
        # sort -u: 直接处理多文件,排序去重
        # -o: 安全输出(比重定向更可靠)
        cmd = f"LC_ALL=C sort -u {' '.join(valid_files)} -o {merged_file}"

        logger.debug("执行命令: %s", cmd)
        subprocess.run(cmd, shell=True, check=True, timeout=timeout)
    except subprocess.TimeoutExpired as exc:
        raise RuntimeError("合并去重超时,请检查数据量或系统资源") from exc
    except subprocess.CalledProcessError as exc:
        raise RuntimeError(f"系统命令执行失败: {exc.stderr or exc}") from exc

        # 按输入文件总行数动态计算超时时间
        total_lines = 0
        for file_path in valid_files:
            try:
                line_count_proc = subprocess.run(
                    ["wc", "-l", file_path],
                    check=True,
                    capture_output=True,
                    text=True,
                )
                total_lines += int(line_count_proc.stdout.strip().split()[0])
            except (subprocess.CalledProcessError, ValueError, IndexError):
                continue
    # 验证输出文件
    if not merged_file.exists():
        raise RuntimeError("合并文件未被创建")

        timeout = 3600
        if total_lines > 0:
            # 按行数线性计算:每行约 0.1 秒
            base_per_line = 0.1
            est = int(total_lines * base_per_line)
            timeout = max(600, est)
    unique_count = _count_file_lines(str(merged_file))
    if unique_count == 0:
        # 降级为 Python 统计
        with open(merged_file, 'r', encoding='utf-8') as f:
            unique_count = sum(1 for _ in f)

        logger.info(
            "Subdomain 合并去重 timeout 自动计算: 输入总行数=%d, timeout=%d秒",
            total_lines,
            timeout,
        )
    if unique_count == 0:
        raise RuntimeError("未找到任何有效域名")

        result = subprocess.run(
            cmd,
            shell=True,
            check=True,
            timeout=timeout
        )

        logger.debug("✓ 合并去重完成")

        # ==================== 统计结果 ====================
        if not merged_file.exists():
            raise RuntimeError("合并文件未被创建")

        # 统计行数(使用系统命令提升大文件性能)
        try:
            line_count_proc = subprocess.run(
                ["wc", "-l", str(merged_file)],
                check=True,
                capture_output=True,
                text=True
            )
            unique_count = int(line_count_proc.stdout.strip().split()[0])
        except (subprocess.CalledProcessError, ValueError, IndexError) as e:
            logger.warning(
                "wc -l 统计失败(文件: %s),降级为 Python 逐行统计 - 错误: %s",
                merged_file, e
            )
            unique_count = 0
            with open(merged_file, 'r', encoding='utf-8') as file_obj:
                for _ in file_obj:
                    unique_count += 1

        if unique_count == 0:
            raise RuntimeError("未找到任何有效域名")

        file_size = merged_file.stat().st_size

        logger.info(
            "✓ 合并去重完成 - 去重后: %d 个域名, 文件大小: %.2f KB",
            unique_count,
            file_size / 1024
        )

        return str(merged_file)

    except subprocess.TimeoutExpired:
        error_msg = "合并去重超时(>60分钟),请检查数据量或系统资源"
        logger.warning(error_msg)  # 超时是可预期的
        raise RuntimeError(error_msg)

    except subprocess.CalledProcessError as e:
        error_msg = f"系统命令执行失败: {e.stderr if e.stderr else str(e)}"
        logger.warning(error_msg)
        raise RuntimeError(error_msg) from e

    except IOError as e:
        error_msg = f"文件读写失败: {e}"
        logger.warning(error_msg)
        raise RuntimeError(error_msg) from e

    except Exception as e:
        error_msg = f"合并去重失败: {e}"
        logger.error(error_msg, exc_info=True)
        raise
    file_size_kb = merged_file.stat().st_size / 1024
    logger.info("✓ 合并去重完成 - 去重后: %d 个域名, 文件大小: %.2f KB", unique_count, file_size_kb)

    return str(merged_file)

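上文的合并命令以 `shell=True` 拼接字符串执行;等价的参数列表写法如下(仅为示意,可避免文件路径含空格或特殊字符时的拼接问题,并非仓库现有实现):

```python
import os
import subprocess
from pathlib import Path


def merge_sorted_unique(valid_files: list[str], merged_file: Path, timeout: int) -> None:
    """示意:以参数列表形式执行 LC_ALL=C sort -u,语义等价于上文的 shell 命令。"""
    subprocess.run(
        ["sort", "-u", *valid_files, "-o", str(merged_file)],
        env={**os.environ, "LC_ALL": "C"},  # 仅覆盖 LC_ALL,保留 PATH 等其余变量
        check=True,
        timeout=timeout,
    )
```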
@@ -1,7 +1,7 @@
"""
运行扫描工具任务

负责运行单个子域名扫描工具(amass、subfinder 等)
负责运行单个子域名扫描工具(subfinder、sublist3r 等)
"""

import logging
@@ -58,7 +58,7 @@ def run_subdomain_discovery_task(
        timeout=timeout,
        log_file=log_file  # 明确指定日志文件路径
    )


    # 验证输出文件是否生成
    if not output_file_path.exists():
        logger.warning(

@@ -0,0 +1,240 @@
"""
Task 向后兼容性测试

Property 8: Task Backward Compatibility
*For any* 任务调用,当仅提供 target_id 参数时,任务应该创建 DatabaseTargetProvider
并使用它进行数据访问,行为与改造前一致。

**Validates: Requirements 6.1, 6.2, 6.4, 6.5**
"""

import tempfile
import pytest
from pathlib import Path
from unittest.mock import patch, MagicMock
from hypothesis import given, strategies as st, settings

from apps.scan.tasks.port_scan.export_hosts_task import export_hosts_task
from apps.scan.tasks.site_scan.export_site_urls_task import export_site_urls_task
from apps.scan.providers import ListTargetProvider


# 生成有效域名的策略
def valid_domain_strategy():
    """生成有效的域名"""
    label = st.text(
        alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
        min_size=2,
        max_size=10
    )
    return st.builds(
        lambda a, b, c: f"{a}.{b}.{c}",
        label, label, st.sampled_from(['com', 'net', 'org', 'io'])
    )


class TestExportHostsTaskBackwardCompatibility:
    """export_hosts_task 向后兼容性测试"""

    @given(
        target_id=st.integers(min_value=1, max_value=1000),
        hosts=st.lists(valid_domain_strategy(), min_size=1, max_size=10)
    )
    @settings(max_examples=50, deadline=None)
    def test_property_8_legacy_mode_creates_database_provider(self, target_id, hosts):
        """
        Property 8: Task Backward Compatibility (export_hosts_task)

        Feature: scan-target-provider, Property 8: Task Backward Compatibility
        **Validates: Requirements 6.1, 6.2, 6.4, 6.5**

        For any target_id, when calling export_hosts_task with only target_id,
        it should create a DatabaseTargetProvider and use it for data access.
        """
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
            output_file = f.name

        try:
            # Mock Target 和 SubdomainService
            mock_target = MagicMock()
            mock_target.type = 'domain'
            mock_target.name = hosts[0]

            with patch('apps.scan.tasks.port_scan.export_hosts_task.DatabaseTargetProvider') as mock_provider_class, \
                 patch('apps.targets.services.TargetService') as mock_target_service:

                # 创建 mock provider 实例
                mock_provider = MagicMock()
                mock_provider.iter_hosts.return_value = iter(hosts)
                mock_provider.get_blacklist_filter.return_value = None
                mock_provider_class.return_value = mock_provider

                # Mock TargetService
                mock_target_service.return_value.get_target.return_value = mock_target

                # 调用任务(传统模式:只传 target_id)
                result = export_hosts_task(
                    output_file=output_file,
                    target_id=target_id
                )

                # 验证:应该创建了 DatabaseTargetProvider
                mock_provider_class.assert_called_once_with(target_id=target_id)

                # 验证:返回值包含必需字段
                assert result['success'] is True
                assert result['output_file'] == output_file
                assert result['total_count'] == len(hosts)
                assert 'target_type' in result  # 传统模式应该返回 target_type

                # 验证:文件内容正确
                with open(output_file, 'r') as f:
                    lines = [line.strip() for line in f.readlines()]
                    assert lines == hosts

        finally:
            Path(output_file).unlink(missing_ok=True)

    def test_legacy_mode_with_provider_parameter(self):
        """测试当同时提供 target_id 和 provider 时,provider 优先"""
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
            output_file = f.name

        try:
            hosts = ['example.com', 'test.com']
            provider = ListTargetProvider(targets=hosts)

            # 调用任务(同时提供 target_id 和 provider)
            result = export_hosts_task(
                output_file=output_file,
                target_id=123,  # 应该被忽略
                provider=provider
            )

            # 验证:使用了 provider
            assert result['success'] is True
            assert result['total_count'] == len(hosts)
            assert 'target_type' not in result  # Provider 模式不返回 target_type

            # 验证:文件内容正确
            with open(output_file, 'r') as f:
                lines = [line.strip() for line in f.readlines()]
                assert lines == hosts

        finally:
            Path(output_file).unlink(missing_ok=True)

    def test_error_when_no_parameters(self):
        """测试当 target_id 和 provider 都未提供时抛出错误"""
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
            output_file = f.name

        try:
            with pytest.raises(ValueError, match="必须提供 target_id 或 provider 参数之一"):
                export_hosts_task(output_file=output_file)

        finally:
            Path(output_file).unlink(missing_ok=True)


class TestExportSiteUrlsTaskBackwardCompatibility:
    """export_site_urls_task 向后兼容性测试"""

    def test_property_8_legacy_mode_uses_traditional_logic(self):
        """
        Property 8: Task Backward Compatibility (export_site_urls_task)

        Feature: scan-target-provider, Property 8: Task Backward Compatibility
        **Validates: Requirements 6.1, 6.2, 6.4, 6.5**

        When calling export_site_urls_task with only target_id,
        it should use the traditional logic (_export_site_urls_legacy).
        """
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
            output_file = f.name

        try:
            target_id = 123

            # Mock HostPortMappingService
            mock_associations = [
                {'host': 'example.com', 'port': 80},
                {'host': 'test.com', 'port': 443},
            ]

            with patch('apps.scan.tasks.site_scan.export_site_urls_task.HostPortMappingService') as mock_service_class, \
                 patch('apps.scan.tasks.site_scan.export_site_urls_task.BlacklistService') as mock_blacklist_service:

                # Mock HostPortMappingService
                mock_service = MagicMock()
                mock_service.iter_host_port_by_target.return_value = iter(mock_associations)
                mock_service_class.return_value = mock_service

                # Mock BlacklistService
                mock_blacklist = MagicMock()
                mock_blacklist.get_rules.return_value = []
                mock_blacklist_service.return_value = mock_blacklist

                # 调用任务(传统模式:只传 target_id)
                result = export_site_urls_task(
                    output_file=output_file,
                    target_id=target_id
                )

                # 验证:返回值包含传统模式的字段
                assert result['success'] is True
                assert result['output_file'] == output_file
                assert result['total_urls'] == 2  # 80 端口生成 1 个 URL,443 端口生成 1 个 URL
                assert 'association_count' in result  # 传统模式应该返回 association_count
                assert result['association_count'] == 2
                assert result['source'] == 'host_port'

                # 验证:文件内容正确
                with open(output_file, 'r') as f:
                    lines = [line.strip() for line in f.readlines()]
                    assert 'http://example.com' in lines
                    assert 'https://test.com' in lines

        finally:
            Path(output_file).unlink(missing_ok=True)

    def test_provider_mode_uses_provider_logic(self):
        """测试当提供 provider 时使用 Provider 模式"""
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
            output_file = f.name

        try:
            urls = ['https://example.com', 'https://test.com']
            provider = ListTargetProvider(targets=urls)

            # 调用任务(Provider 模式)
            result = export_site_urls_task(
                output_file=output_file,
                provider=provider
            )

            # 验证:使用了 provider
            assert result['success'] is True
            assert result['total_urls'] == len(urls)
            assert 'association_count' not in result  # Provider 模式不返回 association_count
            assert result['source'] == 'provider'

            # 验证:文件内容正确
            with open(output_file, 'r') as f:
                lines = [line.strip() for line in f.readlines()]
                assert lines == urls

        finally:
            Path(output_file).unlink(missing_ok=True)

    def test_error_when_no_parameters(self):
        """测试当 target_id 和 provider 都未提供时抛出错误"""
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
            output_file = f.name

        try:
            with pytest.raises(ValueError, match="必须提供 target_id 或 provider 参数之一"):
                export_site_urls_task(output_file=output_file)

        finally:
            Path(output_file).unlink(missing_ok=True)
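测试中使用的 `ListTargetProvider` 未在本次 diff 中给出定义;按上文任务与测试的用法推断,其接口大致如下(示意草图,非仓库真实定义):

```python
from typing import Iterator, Optional


class ListTargetProvider:
    """示意:从内存列表提供目标;iter_hosts/iter_urls 直接遍历同一份列表。"""

    def __init__(self, targets: list[str]):
        self._targets = list(targets)

    def iter_hosts(self) -> Iterator[str]:
        yield from self._targets

    def iter_urls(self) -> Iterator[str]:
        yield from self._targets

    def get_blacklist_filter(self) -> Optional[object]:
        return None  # 列表来源默认不做黑名单过滤
```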
@@ -1,16 +1,23 @@
"""
导出站点 URL 列表任务

使用 TargetExportService 统一处理导出逻辑和默认值回退
数据源: WebSite.url(用于 katana 等爬虫工具)
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库导出
2. Provider 模式:使用 TargetProvider 从任意数据源导出

数据源: WebSite.url → Default(用于 katana 等爬虫工具)
"""

import logging
from prefect import task
from typing import Optional
from pathlib import Path
from prefect import task

from apps.asset.models import WebSite
from apps.scan.services.target_export_service import create_export_service
from apps.scan.services.target_export_service import (
    export_urls_with_fallback,
    DataSource,
)
from apps.scan.providers import TargetProvider, DatabaseTargetProvider

logger = logging.getLogger(__name__)

@@ -22,25 +29,27 @@ logger = logging.getLogger(__name__)
)
def export_sites_task(
    output_file: str,
    target_id: int,
    scan_id: int,
    target_id: Optional[int] = None,
    scan_id: Optional[int] = None,
    provider: Optional[TargetProvider] = None,
    batch_size: int = 1000
) -> dict:
    """
    导出站点 URL 列表到文件(用于 katana 等爬虫工具)

    数据源: WebSite.url
    支持两种模式:
    1. 传统模式(向后兼容):传入 target_id,从数据库导出
    2. Provider 模式:传入 provider,从任意数据源导出

    懒加载模式:
    - 如果数据库为空,根据 Target 类型生成默认 URL
    - DOMAIN: http(s)://domain
    - IP: http(s)://ip
    - CIDR: 展开为所有 IP 的 URL
    数据源优先级(回退链,仅传统模式):
    1. WebSite 表 - 站点级别 URL
    2. 默认生成 - 根据 Target 类型生成 http(s)://target_name

    Args:
        output_file: 输出文件路径
        target_id: 目标 ID
        target_id: 目标 ID(传统模式,向后兼容)
        scan_id: 扫描 ID(保留参数,兼容旧调用)
        provider: TargetProvider 实例(新模式)
        batch_size: 批次大小(内存优化)

    Returns:
@@ -53,17 +62,27 @@ def export_sites_task(
        ValueError: 参数错误
        RuntimeError: 执行失败
    """
    # 构建数据源 queryset(Task 层决定数据源)
    queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)
    # 参数验证:至少提供一个
    if target_id is None and provider is None:
        raise ValueError("必须提供 target_id 或 provider 参数之一")

    # 使用工厂函数创建导出服务
    export_service = create_export_service(target_id)
    # Provider 模式:使用 TargetProvider 导出
    if provider is not None:
        logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
        return _export_with_provider(output_file, provider)

    result = export_service.export_urls(
    # 传统模式:使用 export_urls_with_fallback
    logger.info("使用传统模式 - Target ID: %d", target_id)
    result = export_urls_with_fallback(
        target_id=target_id,
        output_path=output_file,
        queryset=queryset,
        batch_size=batch_size
        output_file=output_file,
        sources=[DataSource.WEBSITE, DataSource.DEFAULT],
        batch_size=batch_size,
    )

    logger.info(
        "站点 URL 导出完成 - source=%s, count=%d",
        result['source'], result['total_count']
    )

    # 保持返回值格式不变(向后兼容)
@@ -71,3 +90,31 @@ def export_sites_task(
        'output_file': result['output_file'],
        'asset_count': result['total_count'],
    }


def _export_with_provider(output_file: str, provider: TargetProvider) -> dict:
    """使用 Provider 导出 URL"""
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    total_count = 0
    blacklist_filter = provider.get_blacklist_filter()

    with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
        for url in provider.iter_urls():
            # 应用黑名单过滤(如果有)
            if blacklist_filter and not blacklist_filter.is_allowed(url):
                continue

            f.write(f"{url}\n")
            total_count += 1

            if total_count % 1000 == 0:
                logger.info("已导出 %d 个 URL...", total_count)

    logger.info("✓ URL 导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))

    return {
        'output_file': str(output_path),
        'asset_count': total_count,
    }

@@ -1,41 +1,52 @@
|
||||
"""导出 Endpoint URL 到文件的 Task
|
||||
|
||||
使用 TargetExportService 统一处理导出逻辑和默认值回退
|
||||
支持两种模式:
|
||||
1. 传统模式(向后兼容):使用 target_id 从数据库导出
|
||||
2. Provider 模式:使用 TargetProvider 从任意数据源导出
|
||||
|
||||
数据源优先级(回退链):
|
||||
数据源优先级(回退链,仅传统模式):
|
||||
1. Endpoint.url - 最精细的 URL(含路径、参数等)
|
||||
2. WebSite.url - 站点级别 URL
|
||||
3. 默认生成 - 根据 Target 类型生成 http(s)://target_name
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Optional
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
from prefect import task
|
||||
|
||||
from apps.asset.models import Endpoint, WebSite
|
||||
from apps.scan.services.target_export_service import create_export_service
|
||||
from apps.scan.services.target_export_service import (
|
||||
export_urls_with_fallback,
|
||||
DataSource,
|
||||
)
|
||||
from apps.scan.providers import TargetProvider, DatabaseTargetProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@task(name="export_endpoints")
|
||||
def export_endpoints_task(
|
||||
target_id: int,
|
||||
output_file: str,
|
||||
target_id: Optional[int] = None,
|
||||
output_file: str = "",
|
||||
    provider: Optional[TargetProvider] = None,
    batch_size: int = 1000,
) -> Dict[str, object]:
    """Export all Endpoint URLs under a target to a text file.

-   Data source priority (fallback chain):
+   Two modes are supported:
+   1. Legacy mode (backward compatible): pass target_id to export from the database
+   2. Provider mode: pass provider to export from any data source
+
+   Data source priority (fallback chain, legacy mode only):
    1. Endpoint table - the most fine-grained URLs (with paths, parameters, etc.)
    2. WebSite table - site-level URLs
    3. Default generation - http(s)://target_name based on the Target type

    Args:
-       target_id: Target ID
+       target_id: Target ID (legacy mode, backward compatible)
        output_file: Output file path (absolute)
+       provider: TargetProvider instance (new mode)
        batch_size: Batch size for each database iteration

    Returns:
@@ -43,55 +54,65 @@ def export_endpoints_task(
        "success": bool,
        "output_file": str,
        "total_count": int,
-       "source": str,  # Data source: "endpoint" | "website" | "default"
+       "source": str,  # Data source: "endpoint" | "website" | "default" | "none" | "provider"
    }
    """
-   export_service = create_export_service(target_id)
+   # Argument validation: at least one of the two must be provided
+   if target_id is None and provider is None:
+       raise ValueError("Either target_id or provider must be provided")
+
+   # Provider mode: export via the TargetProvider
+   if provider is not None:
+       logger.info("Using provider mode - Provider: %s", type(provider).__name__)
+       return _export_with_provider(output_file, provider)
+
+   # Legacy mode: use export_urls_with_fallback
+   logger.info("Using legacy mode - Target ID: %d", target_id)
+   result = export_urls_with_fallback(
+       target_id=target_id,
+       output_file=output_file,
+       sources=[DataSource.ENDPOINT, DataSource.WEBSITE, DataSource.DEFAULT],
+       batch_size=batch_size,
+   )
+
+   logger.info(
+       "URL export finished - source=%s, count=%d, tried=%s",
+       result['source'], result['total_count'], result['tried_sources']
+   )
+
+   return {
+       "success": result['success'],
+       "output_file": result['output_file'],
+       "total_count": result['total_count'],
+       "source": result['source'],
+   }
+
+
+def _export_with_provider(output_file: str, provider: TargetProvider) -> Dict[str, object]:
+   """Export URLs via a Provider."""
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

-   # 1. Prefer exporting from the Endpoint table
-   endpoint_queryset = Endpoint.objects.filter(target_id=target_id).values_list('url', flat=True)
-   result = export_service.export_urls(
-       target_id=target_id,
-       output_path=output_file,
-       queryset=endpoint_queryset,
-       batch_size=batch_size
-   )
+   total_count = 0
+   blacklist_filter = provider.get_blacklist_filter()

-   if result['total_count'] > 0:
-       logger.info("Exported %d URLs from the Endpoint table", result['total_count'])
-       return {
-           "success": True,
-           "output_file": result['output_file'],
-           "total_count": result['total_count'],
-           "source": "endpoint",
-       }
+   with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
+       for url in provider.iter_urls():
+           # Apply the blacklist filter (if any)
+           if blacklist_filter and not blacklist_filter.is_allowed(url):
+               continue
+
+           f.write(f"{url}\n")
+           total_count += 1
+
+           if total_count % 1000 == 0:
+               logger.info("Exported %d URLs so far...", total_count)

-   # 2. Endpoint table is empty, fall back to the WebSite table
-   logger.info("Endpoint table is empty, falling back to the WebSite table")
-   website_queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)
-   result = export_service.export_urls(
-       target_id=target_id,
-       output_path=output_file,
-       queryset=website_queryset,
-       batch_size=batch_size
-   )
+   logger.info("✓ URL export finished - total: %d, file: %s", total_count, str(output_path))

-   if result['total_count'] > 0:
-       logger.info("Exported %d URLs from the WebSite table", result['total_count'])
-       return {
-           "success": True,
-           "output_file": result['output_file'],
-           "total_count": result['total_count'],
-           "source": "website",
-       }
-
-   # 3. WebSite is also empty, generate default URLs (handled inside export_urls)
-   logger.info("WebSite table is also empty, using default URL generation")
    return {
        "success": True,
-       "output_file": result['output_file'],
-       "total_count": result['total_count'],
-       "source": "default",
+       "output_file": str(output_path),
+       "total_count": total_count,
+       "source": "provider",
    }

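The diff never shows the `TargetProvider` interface itself, so here is a minimal sketch of what `_export_with_provider` appears to rely on. The method names `iter_urls` and `get_blacklist_filter` are taken from the calls above; everything else (the `Protocol` base, the `BlacklistFilter` shape) is an assumption for illustration.

```python
# Hypothetical sketch: only iter_urls()/get_blacklist_filter() are attested by the diff above.
from typing import Iterator, Optional, Protocol


class BlacklistFilter(Protocol):
    def is_allowed(self, url: str) -> bool:
        """Return False for URLs that must be excluded from the export."""
        ...


class TargetProvider(Protocol):
    def iter_urls(self) -> Iterator[str]:
        """Yield candidate URLs one at a time, so exports stay constant-memory."""
        ...

    def get_blacklist_filter(self) -> Optional[BlacklistFilter]:
        """Return a filter to apply during export, or None to export everything."""
        ...
```

Streaming through `iter_urls()` rather than materializing a queryset is what lets the same export path serve both database-backed and ad-hoc sources.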
@@ -4,7 +4,8 @@ from .views import ScanViewSet, ScheduledScanViewSet, ScanLogListView, Subfinder
 from .notifications.views import notification_callback
 from apps.asset.views import (
     SubdomainSnapshotViewSet, WebsiteSnapshotViewSet, DirectorySnapshotViewSet,
-    EndpointSnapshotViewSet, HostPortMappingSnapshotViewSet, VulnerabilitySnapshotViewSet
+    EndpointSnapshotViewSet, HostPortMappingSnapshotViewSet, VulnerabilitySnapshotViewSet,
+    ScreenshotSnapshotViewSet
 )

 # Create the router
@@ -26,6 +27,8 @@ scan_endpoints_export = EndpointSnapshotViewSet.as_view({'get': 'export'})
 scan_ip_addresses_list = HostPortMappingSnapshotViewSet.as_view({'get': 'list'})
 scan_ip_addresses_export = HostPortMappingSnapshotViewSet.as_view({'get': 'export'})
 scan_vulnerabilities_list = VulnerabilitySnapshotViewSet.as_view({'get': 'list'})
+scan_screenshots_list = ScreenshotSnapshotViewSet.as_view({'get': 'list'})
+scan_screenshots_image = ScreenshotSnapshotViewSet.as_view({'get': 'image'})

 urlpatterns = [
     path('', include(router.urls)),
@@ -47,5 +50,7 @@ urlpatterns = [
     path('scans/<int:scan_pk>/ip-addresses/', scan_ip_addresses_list, name='scan-ip-addresses-list'),
     path('scans/<int:scan_pk>/ip-addresses/export/', scan_ip_addresses_export, name='scan-ip-addresses-export'),
     path('scans/<int:scan_pk>/vulnerabilities/', scan_vulnerabilities_list, name='scan-vulnerabilities-list'),
+    path('scans/<int:scan_pk>/screenshots/', scan_screenshots_list, name='scan-screenshots-list'),
+    path('scans/<int:scan_pk>/screenshots/<int:pk>/image/', scan_screenshots_image, name='scan-screenshots-image'),
 ]

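The two new routes nest under a scan: one lists screenshot metadata, the other serves the image itself. A hedged sketch of how a client might consume them (the paths come from the urlpatterns above; the base URL, auth, and response fields are assumptions):

```python
# Sketch only: paths are from the urlpatterns above; everything else is assumed.
import requests

BASE = "http://localhost:8888/api"
scan_id = 42  # hypothetical scan primary key

resp = requests.get(f"{BASE}/scans/{scan_id}/screenshots/", timeout=30)
resp.raise_for_status()
for shot in resp.json().get("results", []):
    # Each item is assumed to carry an id; the image route streams the stored WebP.
    image = requests.get(
        f"{BASE}/scans/{scan_id}/screenshots/{shot['id']}/image/", timeout=30
    )
    with open(f"screenshot_{shot['id']}.webp", "wb") as f:
        f.write(image.content)
```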
@@ -4,37 +4,40 @@
 Utility functions used by the scan app.
 """

-from .directory_cleanup import remove_directory
+from . import config_parser
 from .command_builder import build_scan_command
 from .command_executor import execute_and_wait, execute_stream
-from .wordlist_helpers import ensure_wordlist_local
+from .directory_cleanup import remove_directory
 from .nuclei_helpers import ensure_nuclei_templates_local
-from .performance import FlowPerformanceTracker, CommandPerformanceTracker
-from .workspace_utils import setup_scan_workspace, setup_scan_directory
+from .performance import CommandPerformanceTracker, FlowPerformanceTracker
+from .system_load import check_system_load, wait_for_system_load
 from .user_logger import user_log
-from . import config_parser
+from .wordlist_helpers import ensure_wordlist_local
+from .workspace_utils import setup_scan_directory, setup_scan_workspace

 __all__ = [
     # Directory cleanup
     'remove_directory',
     # Workspace
-    'setup_scan_workspace',  # create the Scan root workspace
-    'setup_scan_directory',  # create scan subdirectories
+    'setup_scan_workspace',
+    'setup_scan_directory',
     # Command building
-    'build_scan_command',  # build scan tool commands (f-string based)
+    'build_scan_command',
     # Command execution
-    'execute_and_wait',  # blocking execution (file output)
-    'execute_stream',  # streaming execution (real-time processing)
+    'execute_and_wait',
+    'execute_stream',
+    # System load
+    'wait_for_system_load',
+    'check_system_load',
     # Wordlist files
-    'ensure_wordlist_local',  # ensure a local wordlist (with hash check)
+    'ensure_wordlist_local',
     # Nuclei templates
-    'ensure_nuclei_templates_local',  # ensure local templates (with commit hash check)
+    'ensure_nuclei_templates_local',
     # Performance monitoring
-    'FlowPerformanceTracker',  # flow performance tracker (with system resource sampling)
-    'CommandPerformanceTracker',  # command performance tracker
+    'FlowPerformanceTracker',
+    'CommandPerformanceTracker',
     # Scan logging
-    'user_log',  # user-visible scan log records
+    'user_log',
     # Config parsing
     'config_parser',
 ]

@@ -12,16 +12,18 @@

 import logging
 import os
-from django.conf import settings
 import re
 import signal
 import subprocess
 import threading
 import time
+from collections import deque
 from datetime import datetime
 from pathlib import Path
 from typing import Dict, Any, Optional, Generator

+from django.conf import settings
+
 try:
     # Optional dependency: used for dynamic concurrency control based on CPU/memory load
     import psutil
@@ -354,10 +356,13 @@ class CommandExecutor:
         if log_file_path:
             error_output = self._read_log_tail(log_file_path, max_lines=MAX_LOG_TAIL_LINES)
         logger.warning(
-            "Scan tool %s returned a non-zero exit code: %d (runtime: %.2fs)%s",
-            tool_name, returncode, duration,
-            f"\nError output:\n{error_output}" if error_output else ""
+            "Scan tool %s returned a non-zero exit code: %d (runtime: %.2fs)",
+            tool_name, returncode, duration
         )
+        if error_output:
+            for line in error_output.strip().split('\n'):
+                if line.strip():
+                    logger.warning("%s", line)
     else:
         logger.info("✓ Scan tool %s finished (runtime: %.2fs)", tool_name, duration)

@@ -666,33 +671,68 @@ class CommandExecutor:

    def _read_log_tail(self, log_file: Path, max_lines: int = MAX_LOG_TAIL_LINES) -> str:
        """
-       Read the tail of a log file
+       Read the tail of a log file (constant-memory implementation)
+
+       Seeks backwards from the end of the file, so the whole file is never
+       loaded into memory.

        Args:
            log_file: Path to the log file
            max_lines: Maximum number of lines to read

        Returns:
            The log content as a string; an error hint if the read fails
        """
        if not log_file.exists():
            logger.debug("Log file does not exist: %s", log_file)
            return ""

-       if log_file.stat().st_size == 0:
+       file_size = log_file.stat().st_size
+       if file_size == 0:
            logger.debug("Log file is empty: %s", log_file)
            return ""

+       # Chunk size per read (8 KB, enough for most log lines)
+       chunk_size = 8192
+
+       def decode_line(line_bytes: bytes) -> str:
+           """Decode a single line: UTF-8 first, falling back to latin-1."""
+           try:
+               return line_bytes.decode('utf-8')
+           except UnicodeDecodeError:
+               return line_bytes.decode('latin-1', errors='replace')
+
        try:
-           with open(log_file, 'r', encoding='utf-8') as f:
-               lines = f.readlines()
-               return ''.join(lines[-max_lines:] if len(lines) > max_lines else lines)
-       except UnicodeDecodeError as e:
-           logger.warning("Log file encoding error (%s): %s", log_file, e)
-           return f"(unable to read log file: encoding error - {e})"
+           with open(log_file, 'rb') as f:
+               lines_found: deque[bytes] = deque()
+               remaining = b''
+               position = file_size
+
+               while position > 0 and len(lines_found) < max_lines:
+                   read_size = min(chunk_size, position)
+                   position -= read_size
+
+                   f.seek(position)
+                   chunk = f.read(read_size) + remaining
+                   parts = chunk.split(b'\n')
+
+                   # The first part may be incomplete; keep it for the next pass
+                   remaining = parts[0]
+
+                   # The rest are complete lines (collected back-to-front;
+                   # appendleft preserves order)
+                   for part in reversed(parts[1:]):
+                       if len(lines_found) >= max_lines:
+                           break
+                       lines_found.appendleft(part)
+
+               # Handle the line at the start of the file
+               if remaining and len(lines_found) < max_lines:
+                   lines_found.appendleft(remaining)
+
+               return '\n'.join(decode_line(line) for line in lines_found)
+
        except PermissionError as e:
            logger.warning("Insufficient permission for log file (%s): %s", log_file, e)
-           return f"(unable to read log file: permission denied)"
+           return "(unable to read log file: permission denied)"
        except IOError as e:
            logger.warning("I/O error reading log file (%s): %s", log_file, e)
            return f"(unable to read log file: I/O error - {e})"
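To see why the backward scan stays constant-memory, the same loop can be exercised standalone against a throwaway file; `tail_lines` below is a hypothetical free-function rendition of the method above, not project code:

```python
# Standalone sketch of the same backward-seek loop, for illustration only.
from collections import deque
from pathlib import Path


def tail_lines(path: Path, max_lines: int, chunk_size: int = 8192) -> list[bytes]:
    size = path.stat().st_size
    lines: deque[bytes] = deque()
    remaining, position = b'', size
    with open(path, 'rb') as f:
        while position > 0 and len(lines) < max_lines:
            read_size = min(chunk_size, position)
            position -= read_size
            f.seek(position)
            parts = (f.read(read_size) + remaining).split(b'\n')
            remaining = parts[0]  # possibly a partial line; defer it to the next chunk
            for part in reversed(parts[1:]):
                if len(lines) >= max_lines:
                    break
                lines.appendleft(part)
    if remaining and len(lines) < max_lines:
        lines.appendleft(remaining)
    return list(lines)


if __name__ == "__main__":
    p = Path("demo.log")
    p.write_bytes(b"\n".join(f"line {i}".encode() for i in range(10000)))
    print(tail_lines(p, 3))  # [b'line 9997', b'line 9998', b'line 9999']
```

At most one chunk plus the deferred partial line is held in memory at any time, regardless of file size.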
backend/apps/scan/utils/system_load.py (new file, 77 lines)
@@ -0,0 +1,77 @@
"""
System load checking utilities

Unified system-load checks, used to:
- verify at Flow entry that system resources are sufficient
- avoid launching new scan tasks while the host is under high load
"""

import logging
import time

import psutil
from django.conf import settings

logger = logging.getLogger(__name__)

# Dynamic concurrency-control thresholds (can be overridden in Django settings)
SCAN_CPU_HIGH: float = getattr(settings, 'SCAN_CPU_HIGH', 90.0)
SCAN_MEM_HIGH: float = getattr(settings, 'SCAN_MEM_HIGH', 80.0)
SCAN_LOAD_CHECK_INTERVAL: int = getattr(settings, 'SCAN_LOAD_CHECK_INTERVAL', 180)


def _get_current_load() -> tuple[float, float]:
    """Return the current CPU and memory utilisation."""
    return psutil.cpu_percent(interval=0.5), psutil.virtual_memory().percent


def wait_for_system_load(
    cpu_threshold: float = SCAN_CPU_HIGH,
    mem_threshold: float = SCAN_MEM_HIGH,
    check_interval: int = SCAN_LOAD_CHECK_INTERVAL,
    context: str = "task"
) -> None:
    """
    Block until system load drops below the thresholds

    Waits while the host is under high load, returning only once both CPU
    and memory are below their thresholds. Intended for Flow entry points,
    to avoid starting new tasks when resources are tight.
    """
    while True:
        cpu, mem = _get_current_load()

        if cpu < cpu_threshold and mem < mem_threshold:
            logger.debug(
                "[%s] System load is normal: cpu=%.1f%%, mem=%.1f%%",
                context, cpu, mem
            )
            return

        logger.info(
            "[%s] System load is high, waiting for resources: "
            "cpu=%.1f%% (threshold %.1f%%), mem=%.1f%% (threshold %.1f%%)",
            context, cpu, cpu_threshold, mem, mem_threshold
        )
        time.sleep(check_interval)


def check_system_load(
    cpu_threshold: float = SCAN_CPU_HIGH,
    mem_threshold: float = SCAN_MEM_HIGH
) -> dict:
    """
    Check the current system load (non-blocking)

    Returns:
        dict: cpu_percent, mem_percent, cpu_threshold, mem_threshold, is_overloaded
    """
    cpu, mem = _get_current_load()

    return {
        'cpu_percent': cpu,
        'mem_percent': mem,
        'cpu_threshold': cpu_threshold,
        'mem_threshold': mem_threshold,
        'is_overloaded': cpu >= cpu_threshold or mem >= mem_threshold,
    }

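Going by the module docstring, the typical call site is the top of a flow function. The sketch below is illustrative only; the `run_port_scan_flow` name and its body are invented:

```python
# Hypothetical flow entry showing the intended gating pattern.
from apps.scan.utils import wait_for_system_load


def run_port_scan_flow(scan_id: int) -> None:
    # Block here until CPU < 90% and memory < 80% (or the configured overrides),
    # re-checking every SCAN_LOAD_CHECK_INTERVAL seconds.
    wait_for_system_load(context=f"scan:{scan_id}:port_scan")

    # ...launch naabu / masscan here once resources are available...
```

Because the wait happens before any container or subprocess is spawned, an overloaded worker simply defers new flows instead of thrashing.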
@@ -119,6 +119,7 @@ class TargetDetailSerializer(serializers.ModelSerializer):
     - endpoints: number of endpoints
     - ips: number of IP addresses
     - directories: number of directories
+    - screenshots: number of screenshots
     - vulnerabilities: vulnerability statistics (returns 0 for now, to be implemented later)

     Performance notes:
@@ -134,6 +135,7 @@ class TargetDetailSerializer(serializers.ModelSerializer):
         endpoints_count = obj.endpoints.count()
         ips_count = obj.host_port_mappings.values('ip').distinct().count()
         directories_count = obj.directories.count()
+        screenshots_count = obj.screenshots.count()

         # Vulnerability stats: computed per target on the fly from the Vulnerability asset table
         vuln_qs = obj.vulnerabilities.all()
@@ -159,6 +161,7 @@ class TargetDetailSerializer(serializers.ModelSerializer):
             'endpoints': endpoints_count,
             'ips': ips_count,
             'directories': directories_count,
+            'screenshots': screenshots_count,
             'vulnerabilities': {
                 'total': total,
                 **severity_stats,
@@ -182,12 +185,12 @@ class BatchCreateTargetSerializer(serializers.Serializer):
     Serializer for batch target creation

     Safety limits:
-    - At most 1000 targets per batch
+    - At most 5000 targets per batch
     - Prevents a malicious user from overloading the server with huge payloads
     """

     # Maximum batch size for bulk creation
-    MAX_BATCH_SIZE = 1000
+    MAX_BATCH_SIZE = 5000

     # Target list
     targets = serializers.ListField(

@@ -3,7 +3,8 @@ from rest_framework.routers import DefaultRouter
 from .views import OrganizationViewSet, TargetViewSet
 from apps.asset.views import (
     SubdomainViewSet, WebSiteViewSet, DirectoryViewSet,
-    EndpointViewSet, HostPortMappingViewSet, VulnerabilityViewSet
+    EndpointViewSet, HostPortMappingViewSet, VulnerabilityViewSet,
+    ScreenshotViewSet
 )

 # Create the router
@@ -29,6 +30,8 @@ target_endpoints_bulk_create = EndpointViewSet.as_view({'post': 'bulk_create'})
 target_ip_addresses_list = HostPortMappingViewSet.as_view({'get': 'list'})
 target_ip_addresses_export = HostPortMappingViewSet.as_view({'get': 'export'})
 target_vulnerabilities_list = VulnerabilityViewSet.as_view({'get': 'list'})
+target_screenshots_list = ScreenshotViewSet.as_view({'get': 'list'})
+target_screenshots_bulk_delete = ScreenshotViewSet.as_view({'post': 'bulk_delete'})

 urlpatterns = [
     path('', include(router.urls)),
@@ -48,4 +51,6 @@ urlpatterns = [
     path('targets/<int:target_pk>/ip-addresses/', target_ip_addresses_list, name='target-ip-addresses-list'),
     path('targets/<int:target_pk>/ip-addresses/export/', target_ip_addresses_export, name='target-ip-addresses-export'),
     path('targets/<int:target_pk>/vulnerabilities/', target_vulnerabilities_list, name='target-vulnerabilities-list'),
+    path('targets/<int:target_pk>/screenshots/', target_screenshots_list, name='target-screenshots-list'),
+    path('targets/<int:target_pk>/screenshots/bulk-delete/', target_screenshots_bulk_delete, name='target-screenshots-bulk-delete'),
 ]

@@ -12,16 +12,34 @@ load-plugins = "pylint_django"

 [tool.pylint.messages_control]
 disable = [
     "missing-docstring",
     "invalid-name",
     "too-few-public-methods",
     "no-member",
     "import-error",
     "no-name-in-module",
+    "wrong-import-position",   # allow in-function imports (avoids circular deps)
+    "import-outside-toplevel", # same as above
+    "too-many-arguments",      # Django view/service methods often take more than 5 args
+    "too-many-locals",         # complex business logic needs many locals
+    "duplicate-code",          # some patterns legitimately look alike
 ]

 [tool.pylint.format]
 max-line-length = 120

 [tool.pylint.basic]
-good-names = ["i", "j", "k", "ex", "Run", "_", "id", "pk", "ip", "url", "db", "qs"]
+good-names = [
+    "i",
+    "j",
+    "k",
+    "ex",
+    "Run",
+    "_",
+    "id",
+    "pk",
+    "ip",
+    "url",
+    "db",
+    "qs",
+]

@@ -38,9 +38,12 @@ packaging>=21.0 # version comparison
 # Test frameworks
 pytest==8.0.0
 pytest-django==4.7.0
+hypothesis>=6.100.0 # property-based testing

 # Utility libraries
 python-dateutil==2.9.0
+Pillow>=10.0.0 # image processing (screenshot service)
+playwright>=1.40.0 # browser automation (screenshot service)
 pytz==2024.1
 validators==0.22.0
 PyYAML==6.0.1

@@ -639,19 +639,19 @@ class TestDataGenerator:
     target_id, engine_ids, engine_names, yaml_configuration, status, worker_id, progress, current_stage,
     results_dir, error_message, container_ids, stage_progress,
     cached_subdomains_count, cached_websites_count, cached_endpoints_count,
-    cached_ips_count, cached_directories_count, cached_vulns_total,
+    cached_ips_count, cached_directories_count, cached_screenshots_count, cached_vulns_total,
     cached_vulns_critical, cached_vulns_high, cached_vulns_medium, cached_vulns_low,
     created_at, stopped_at, deleted_at
 ) VALUES (
     %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s,
-    %s, %s, %s, %s, %s, %s, %s, %s, %s, %s,
+    %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s,
     NOW() - INTERVAL '%s days', %s, NULL
 )
 RETURNING id
 """, (
     target_id, selected_engine_ids, json.dumps(selected_engine_names), '', status, worker_id, progress, stage,
     f'/app/results/scan_{target_id}_{random.randint(1000, 9999)}', error_msg, '{}', '{}',
-    subdomains, websites, endpoints, ips, directories, vulns_total,
+    subdomains, websites, endpoints, ips, directories, 0, vulns_total,
     vulns_critical, vulns_high, vulns_medium, vulns_low,
     days_ago,
     datetime.now() - timedelta(days=days_ago, hours=random.randint(0, 23)) if status in ['completed', 'failed', 'cancelled'] else None

@@ -2,7 +2,7 @@
 # ============================================
 # XingRin remote node install script
 # Purpose: install the Docker environment + pre-pull images
-# Supported: Ubuntu / Debian
+# Supported: Ubuntu / Debian / Kali
 #
 # Architecture notes:
 # 1. Install the Docker environment
@@ -101,8 +101,8 @@ detect_os() {
         exit 1
     fi

-    if [[ "$OS" != "ubuntu" && "$OS" != "debian" ]]; then
-        log_error "Only Ubuntu/Debian systems are supported"
+    if [[ "$OS" != "ubuntu" && "$OS" != "debian" && "$OS" != "kali" ]]; then
+        log_error "Only Ubuntu/Debian/Kali systems are supported"
         exit 1
     fi
 }

@@ -44,6 +44,8 @@ services:
     restart: always
     env_file:
       - .env
+    environment:
+      - IMAGE_TAG=${IMAGE_TAG:-dev}
     ports:
       - "8888:8888"
     depends_on:
@@ -53,6 +55,8 @@ services:
       # Unified data directory mount
       - /opt/xingrin:/opt/xingrin
       - /var/run/docker.sock:/var/run/docker.sock
+    # OOM priority: -500 protects the core service
+    oom_score_adj: -500
     healthcheck:
       # Use the dedicated health-check endpoint (no auth required)
       test: ["CMD", "curl", "-f", "http://localhost:8888/api/health/"]
@@ -88,6 +92,8 @@ services:
       args:
         IMAGE_TAG: ${IMAGE_TAG:-dev}
     restart: always
+    # OOM priority: -500 protects the web UI
+    oom_score_adj: -500
     depends_on:
       server:
         condition: service_healthy
@@ -97,6 +103,8 @@ services:
       context: ..
      dockerfile: docker/nginx/Dockerfile
     restart: always
+    # OOM priority: -500 protects the entry gateway
+    oom_score_adj: -500
     depends_on:
       server:
         condition: service_healthy

@@ -48,6 +48,8 @@ services:
     restart: always
     env_file:
       - .env
+    environment:
+      - IMAGE_TAG=${IMAGE_TAG}
     depends_on:
       redis:
         condition: service_healthy
@@ -56,6 +58,8 @@ services:
       - /opt/xingrin:/opt/xingrin
       # Docker socket mount: lets the Django server run local docker commands (for local worker task dispatch)
       - /var/run/docker.sock:/var/run/docker.sock
+    # OOM priority: -500 lowers the odds of being picked by the OOM killer, protecting the core service
+    oom_score_adj: -500
     healthcheck:
       # Use the dedicated health-check endpoint (no auth required)
       test: ["CMD", "curl", "-f", "http://localhost:8888/api/health/"]
@@ -88,6 +92,8 @@ services:
   frontend:
     image: ${DOCKER_USER:-yyhuni}/xingrin-frontend:${IMAGE_TAG:?IMAGE_TAG is required}
     restart: always
+    # OOM priority: -500 protects the web UI
+    oom_score_adj: -500
     depends_on:
       server:
         condition: service_healthy
@@ -95,6 +101,8 @@ services:
   nginx:
     image: ${DOCKER_USER:-yyhuni}/xingrin-nginx:${IMAGE_TAG:?IMAGE_TAG is required}
     restart: always
+    # OOM priority: -500 protects the entry gateway
+    oom_score_adj: -500
     depends_on:
       server:
         condition: service_healthy

@@ -17,7 +17,8 @@
 set -e

 SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
-cd "$SCRIPT_DIR"
+DOCKER_DIR="$(dirname "$SCRIPT_DIR")"
+cd "$DOCKER_DIR"

 # Colored output
 GREEN='\033[0;32m'
@@ -33,7 +34,9 @@ log_step() { echo -e "  ${CYAN}>>${NC} $1"; }

 # Check whether the service is running
 check_server() {
-    if ! docker compose ps --status running 2>/dev/null | grep -q "server"; then
+    # Use docker compose ps with --format to read service states;
+    # this only checks the service name and does not depend on container naming.
+    if ! docker compose ps --format '{{.Service}} {{.State}}' 2>/dev/null | grep -E "^server\s+running" > /dev/null; then
        echo "Server container is not running, skipping data initialisation"
        return 1
    fi
@@ -80,20 +83,20 @@ if not yaml_path.exists():
    print('Config file not found, skipping')
    exit(0)

+new_config = yaml_path.read_text()
+
 # Check whether a "full scan" engine already exists
 engine = ScanEngine.objects.filter(name='full scan').first()
 if engine:
-    if not engine.configuration or not engine.configuration.strip():
-        engine.configuration = yaml_path.read_text()
-        engine.save(update_fields=['configuration'])
-        print(f'Initialised engine configuration: {engine.name}')
-    else:
-        print(f'Engine already configured, skipping')
+    # Overwrite with the latest configuration unconditionally
+    engine.configuration = new_config
+    engine.save(update_fields=['configuration'])
+    print(f'Updated engine configuration: {engine.name}')
 else:
     # Create the engine
     engine = ScanEngine.objects.create(
         name='full scan',
-        configuration=yaml_path.read_text(),
+        configuration=new_config,
     )
     print(f'Created engine: {engine.name}')
 "

@@ -10,7 +10,7 @@ python manage.py migrate --noinput
 echo "  ✓ Database migration complete"

 echo "  [1.1/3] Initialising the default scan engine..."
-python manage.py init_default_engine
+python manage.py init_default_engine --force
 echo "  ✓ Default scan engine ready"

 echo "  [1.2/3] Initialising the default directory wordlist..."

@@ -182,7 +182,7 @@ echo -e "${BOLD}${GREEN}══════════════════
 echo ""
 echo -e "${BOLD}Access URLs${NC}"
 if [ "$WITH_FRONTEND" = true ]; then
-    echo -e "  XingRin:  ${CYAN}https://${ACCESS_HOST}/${NC}"
+    echo -e "  XingRin:  ${CYAN}https://${ACCESS_HOST}:8083/${NC}"
     echo -e "  ${YELLOW}(HTTP redirects to HTTPS automatically)${NC}"
 else
     echo -e "  API:     ${CYAN}via the frontend or nginx (backend does not expose 8888)${NC}"

@@ -29,9 +29,6 @@ RUN go install -v github.com/projectdiscovery/httpx/cmd/httpx@latest && \
     go install -v github.com/d3mondev/puredns/v2@latest && \
     go install -v github.com/yyhuni/xingfinger@latest

-# Install Amass v5 (CGO disabled to skip the libpostal dependency)
-RUN CGO_ENABLED=0 go install -v github.com/owasp-amass/amass/v5/cmd/amass@main
-
 # Install vulnerability scanners
 RUN go install github.com/hahwul/dalfox/v2@latest

@@ -45,7 +42,9 @@ ENV DEBIAN_FRONTEND=noninteractive
 WORKDIR /app

 # 1. Install base tools and Python
-RUN apt-get update && apt-get install -y \
+# Note: ARM64 uses ports.ubuntu.com, which can lag behind mirror sync, so a retry is needed
+RUN apt-get update && \
+    apt-get install -y --no-install-recommends \
     python3 \
     python3-pip \
     python3-venv \
@@ -60,8 +59,32 @@ RUN apt-get update && apt-get install -y \
     masscan \
     libpcap-dev \
     ca-certificates \
+    fonts-liberation \
+    libnss3 \
+    libxss1 \
+    libasound2t64 \
+    || (rm -rf /var/lib/apt/lists/* && apt-get update && apt-get install -y --no-install-recommends \
+    python3 python3-pip python3-venv pipx git curl wget unzip jq tmux nmap masscan libpcap-dev \
+    ca-certificates fonts-liberation libnss3 libxss1 libasound2t64) \
     && rm -rf /var/lib/apt/lists/*

+# Install Chromium (via Playwright, which supports both ARM64 and AMD64)
+# Ubuntu 24.04's chromium-browser is a snap transition package and unusable in Docker
+RUN pip install playwright --break-system-packages && \
+    playwright install chromium && \
+    apt-get update && \
+    playwright install-deps chromium && \
+    rm -rf /var/lib/apt/lists/*
+
+# Point CHROME_PATH at the Playwright install location for tools such as httpx
+ENV CHROME_PATH=/root/.cache/ms-playwright/chromium-*/chrome-linux/chrome
+# Create symlinks so httpx's -system-chrome can find the browser
+RUN CHROME_BIN=$(find /root/.cache/ms-playwright -name chrome -type f 2>/dev/null | head -1) && \
+    ln -sf "$CHROME_BIN" /usr/bin/chromium-browser && \
+    ln -sf "$CHROME_BIN" /usr/bin/chromium && \
+    ln -sf "$CHROME_BIN" /usr/bin/chrome && \
+    ln -sf "$CHROME_BIN" /usr/bin/google-chrome-stable

 # Create the python symlink
 RUN ln -s /usr/bin/python3 /usr/bin/python

@@ -54,10 +54,10 @@ flowchart TB
     TARGET --> SUBLIST3R
     TARGET --> ASSETFINDER

-    subgraph STAGE2["Stage 2: Analysis Parallel"]
+    subgraph STAGE2["Stage 2: URL Collection Parallel"]
         direction TB

-        subgraph URL["URL Collection"]
+        subgraph URL["URL Fetch"]
             direction TB
             WAYMORE[waymore<br/>Historical URLs]
             KATANA[katana<br/>Crawler]
@@ -78,7 +78,15 @@ flowchart TB
     XINGFINGER --> KATANA
     XINGFINGER --> FFUF

-    subgraph STAGE3["Stage 3: Vulnerability Sequential"]
+    subgraph STAGE3["Stage 3: Screenshot Sequential"]
+        direction TB
+        SCREENSHOT[Playwright<br/>Page Screenshot]
+    end
+
+    HTTPX2 --> SCREENSHOT
+    FFUF --> SCREENSHOT
+
+    subgraph STAGE4["Stage 4: Vulnerability Sequential"]
         direction TB

         subgraph VULN["Vulnerability Scan"]
@@ -88,12 +96,11 @@ flowchart TB
         end
     end

-    HTTPX2 --> DALFOX
-    HTTPX2 --> NUCLEI
+    SCREENSHOT --> DALFOX
+    SCREENSHOT --> NUCLEI

     DALFOX --> FINISH
     NUCLEI --> FINISH
-    FFUF --> FINISH

     FINISH[Scan Complete]

@@ -109,9 +116,14 @@ flowchart TB

```python
# backend/apps/scan/configs/command_templates.py
# Stage 1: asset discovery - subdomains → ports → site probing → fingerprinting
# Stage 2: URL collection - URL fetch + directory scan (parallel)
# Stage 3: screenshots - run after URL collection, to capture more of the discovered pages
# Stage 4: vulnerability scan - runs last
EXECUTION_STAGES = [
    {'mode': 'sequential', 'flows': ['subdomain_discovery', 'port_scan', 'site_scan', 'fingerprint_detect']},
    {'mode': 'parallel', 'flows': ['url_fetch', 'directory_scan']},
    {'mode': 'sequential', 'flows': ['screenshot']},
    {'mode': 'sequential', 'flows': ['vuln_scan']},
]
```
@@ -126,4 +138,5 @@ EXECUTION_STAGES = [
 | fingerprint_detect | xingfinger | WebSite.tech (updated) |
 | url_fetch | waymore, katana, uro, httpx | Endpoint |
 | directory_scan | ffuf | Directory |
+| screenshot | Playwright | Screenshot |
 | vuln_scan | dalfox, nuclei | Vulnerability |
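`EXECUTION_STAGES` is plain data; the runner that interprets it is not shown in this diff. As a rough sketch of the semantics described above (stages run one after another, while flows inside a `parallel` stage fan out concurrently), something along these lines would work; `run_flow` and its body are placeholders, not project code:

```python
# Illustrative interpreter for an EXECUTION_STAGES-style config; not the
# project's actual runner. run_flow stands in for whatever executes a flow.
from concurrent.futures import ThreadPoolExecutor


def run_flow(name: str) -> None:
    print(f"running flow: {name}")  # placeholder


def run_stages(stages: list[dict]) -> None:
    for stage in stages:  # stages always execute one after another
        flows = stage['flows']
        if stage['mode'] == 'parallel':
            # Fan the stage's flows out and wait for all of them to finish
            with ThreadPoolExecutor(max_workers=len(flows)) as pool:
                list(pool.map(run_flow, flows))
        else:  # 'sequential'
            for flow in flows:
                run_flow(flow)


run_stages([
    {'mode': 'sequential', 'flows': ['subdomain_discovery', 'port_scan']},
    {'mode': 'parallel', 'flows': ['url_fetch', 'directory_scan']},
    {'mode': 'sequential', 'flows': ['screenshot']},
])
```

Keeping the orchestration as declarative data is what lets the screenshot stage slot in between URL collection and vulnerability scanning without touching the runner.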
(New binary files, not shown: docs/wx_pay.jpg, 156 KiB; docs/zfb_pay.jpg, 144 KiB.)
@@ -34,6 +34,7 @@ const FEATURE_LIST = [
   { key: "site_scan" },
   { key: "fingerprint_detect" },
   { key: "directory_scan" },
+  { key: "screenshot" },
   { key: "url_fetch" },
   { key: "vuln_scan" },
 ] as const
@@ -48,6 +49,7 @@ function parseEngineFeatures(engine: ScanEngine): Record<FeatureKey, boolean> {
     site_scan: false,
     fingerprint_detect: false,
     directory_scan: false,
+    screenshot: false,
     url_fetch: false,
     vuln_scan: false,
   }
@@ -64,6 +66,7 @@ function parseEngineFeatures(engine: ScanEngine): Record<FeatureKey, boolean> {
     site_scan: !!config.site_scan,
     fingerprint_detect: !!config.fingerprint_detect,
     directory_scan: !!config.directory_scan,
+    screenshot: !!config.screenshot,
     url_fetch: !!config.url_fetch,
     vuln_scan: !!config.vuln_scan,
   }

@@ -3,9 +3,10 @@
 import React from "react"
 import { usePathname, useParams } from "next/navigation"
 import Link from "next/link"
-import { Target } from "lucide-react"
+import { Target, LayoutDashboard, Package, Image, ShieldAlert } from "lucide-react"
 import { Tabs, TabsList, TabsTrigger } from "@/components/ui/tabs"
 import { Badge } from "@/components/ui/badge"
+import { Skeleton } from "@/components/ui/skeleton"
 import { useScan } from "@/hooks/use-scans"
 import { useTranslations } from "next-intl"

@@ -19,104 +20,136 @@ export default function ScanHistoryLayout({
   const { data: scanData, isLoading } = useScan(parseInt(id))
   const t = useTranslations("scan.history")

-  const getActiveTab = () => {
-    if (pathname.includes("/subdomain")) return "subdomain"
-    if (pathname.includes("/endpoints")) return "endpoints"
-    if (pathname.includes("/websites")) return "websites"
-    if (pathname.includes("/directories")) return "directories"
+  // Get primary navigation active tab
+  const getPrimaryTab = () => {
+    if (pathname.includes("/overview")) return "overview"
+    if (pathname.includes("/screenshots")) return "screenshots"
     if (pathname.includes("/vulnerabilities")) return "vulnerabilities"
-    if (pathname.includes("/ip-addresses")) return "ip-addresses"
-    return ""
+    // All asset pages fall under "assets"
+    if (
+      pathname.includes("/websites") ||
+      pathname.includes("/subdomain") ||
+      pathname.includes("/ip-addresses") ||
+      pathname.includes("/endpoints") ||
+      pathname.includes("/directories")
+    ) {
+      return "assets"
+    }
+    return "overview"
   }

+  // Get secondary navigation active tab (for assets)
+  const getSecondaryTab = () => {
+    if (pathname.includes("/websites")) return "websites"
+    if (pathname.includes("/subdomain")) return "subdomain"
+    if (pathname.includes("/ip-addresses")) return "ip-addresses"
+    if (pathname.includes("/endpoints")) return "endpoints"
+    if (pathname.includes("/directories")) return "directories"
+    return "websites"
+  }
+
+  // Check if we should show secondary navigation
+  const showSecondaryNav = getPrimaryTab() === "assets"
+
   const basePath = `/scan/history/${id}`
-  const tabPaths = {
-    subdomain: `${basePath}/subdomain/`,
-    endpoints: `${basePath}/endpoints/`,
-    websites: `${basePath}/websites/`,
-    directories: `${basePath}/directories/`,
-    vulnerabilities: `${basePath}/vulnerabilities/`,
-    "ip-addresses": `${basePath}/ip-addresses/`,
+  const primaryPaths = {
+    overview: `${basePath}/overview/`,
+    assets: `${basePath}/websites/`, // Default to websites when clicking assets
+    screenshots: `${basePath}/screenshots/`,
+    vulnerabilities: `${basePath}/vulnerabilities/`,
+  }
+
+  const secondaryPaths = {
+    websites: `${basePath}/websites/`,
+    subdomain: `${basePath}/subdomain/`,
+    "ip-addresses": `${basePath}/ip-addresses/`,
+    endpoints: `${basePath}/endpoints/`,
+    directories: `${basePath}/directories/`,
   }

   // Get counts for each tab from scan data
+  const summary = scanData?.summary as any
   const counts = {
-    subdomain: scanData?.summary?.subdomains || 0,
-    endpoints: scanData?.summary?.endpoints || 0,
-    websites: scanData?.summary?.websites || 0,
-    directories: scanData?.summary?.directories || 0,
-    vulnerabilities: scanData?.summary?.vulnerabilities?.total || 0,
-    "ip-addresses": scanData?.summary?.ips || 0,
+    subdomain: summary?.subdomains || 0,
+    endpoints: summary?.endpoints || 0,
+    websites: summary?.websites || 0,
+    directories: summary?.directories || 0,
+    screenshots: summary?.screenshots || 0,
+    vulnerabilities: summary?.vulnerabilities?.total || 0,
+    "ip-addresses": summary?.ips || 0,
   }

+  // Calculate total assets count
+  const totalAssets = counts.websites + counts.subdomain + counts["ip-addresses"] + counts.endpoints + counts.directories
+
+  // Loading state
+  if (isLoading) {
+    return (
+      <div className="flex flex-col gap-4 py-4 md:gap-6 md:py-6">
+        {/* Header skeleton */}
+        <div className="flex items-center gap-2 px-4 lg:px-6">
+          <Skeleton className="h-4 w-16" />
+          <span className="text-muted-foreground">/</span>
+          <Skeleton className="h-4 w-32" />
+        </div>
+        {/* Tabs skeleton */}
+        <div className="flex gap-1 px-4 lg:px-6">
+          <Skeleton className="h-9 w-20" />
+          <Skeleton className="h-9 w-20" />
+          <Skeleton className="h-9 w-24" />
+        </div>
+      </div>
+    )
+  }
+
   return (
-    <div className="flex flex-col gap-4 py-4 md:gap-6 md:py-6">
-      <div className="flex items-center justify-between px-4 lg:px-6">
-        <div>
-          <h2 className="text-2xl font-bold tracking-tight flex items-center gap-2">
-            <Target />
-            Scan Results
-          </h2>
-          <p className="text-muted-foreground">{t("taskId", { id })}</p>
-        </div>
+    <div className="flex flex-col gap-4 py-4 md:gap-6 md:py-6 h-full">
+      {/* Header: Page label + Scan info */}
+      <div className="flex items-center gap-2 text-sm px-4 lg:px-6">
+        <span className="text-muted-foreground">{t("breadcrumb.scanHistory")}</span>
+        <span className="text-muted-foreground">/</span>
+        <span className="font-medium flex items-center gap-1.5">
+          <Target className="h-4 w-4" />
+          {(scanData?.target as any)?.name || t("taskId", { id })}
+        </span>
       </div>

-      <div className="flex items-center justify-between px-4 lg:px-6">
-        <Tabs value={getActiveTab()} className="w-full">
+      {/* Primary navigation */}
+      <div className="px-4 lg:px-6">
+        <Tabs value={getPrimaryTab()}>
           <TabsList>
-            <TabsTrigger value="websites" asChild>
-              <Link href={tabPaths.websites} className="flex items-center gap-0.5">
-                Websites
-                {counts.websites > 0 && (
+            <TabsTrigger value="overview" asChild>
+              <Link href={primaryPaths.overview} className="flex items-center gap-1.5">
+                <LayoutDashboard className="h-4 w-4" />
+                {t("tabs.overview")}
+              </Link>
+            </TabsTrigger>
+            <TabsTrigger value="assets" asChild>
+              <Link href={primaryPaths.assets} className="flex items-center gap-1.5">
+                <Package className="h-4 w-4" />
+                {t("tabs.assets")}
+                {totalAssets > 0 && (
                   <Badge variant="secondary" className="ml-1.5 h-5 min-w-5 rounded-full px-1.5 text-xs">
-                    {counts.websites}
+                    {totalAssets}
                   </Badge>
                 )}
               </Link>
             </TabsTrigger>
-            <TabsTrigger value="subdomain" asChild>
-              <Link href={tabPaths.subdomain} className="flex items-center gap-0.5">
-                Subdomains
-                {counts.subdomain > 0 && (
+            <TabsTrigger value="screenshots" asChild>
+              <Link href={primaryPaths.screenshots} className="flex items-center gap-1.5">
+                <Image className="h-4 w-4" />
+                {t("tabs.screenshots")}
+                {counts.screenshots > 0 && (
                   <Badge variant="secondary" className="ml-1.5 h-5 min-w-5 rounded-full px-1.5 text-xs">
-                    {counts.subdomain}
-                  </Badge>
-                )}
-              </Link>
-            </TabsTrigger>
-            <TabsTrigger value="ip-addresses" asChild>
-              <Link href={tabPaths["ip-addresses"]} className="flex items-center gap-0.5">
-                IP Addresses
-                {counts["ip-addresses"] > 0 && (
-                  <Badge variant="secondary" className="ml-1.5 h-5 min-w-5 rounded-full px-1.5 text-xs">
-                    {counts["ip-addresses"]}
-                  </Badge>
-                )}
-              </Link>
-            </TabsTrigger>
-            <TabsTrigger value="endpoints" asChild>
-              <Link href={tabPaths.endpoints} className="flex items-center gap-0.5">
-                URLs
-                {counts.endpoints > 0 && (
-                  <Badge variant="secondary" className="ml-1.5 h-5 min-w-5 rounded-full px-1.5 text-xs">
-                    {counts.endpoints}
-                  </Badge>
-                )}
-              </Link>
-            </TabsTrigger>
-            <TabsTrigger value="directories" asChild>
-              <Link href={tabPaths.directories} className="flex items-center gap-0.5">
-                Directories
-                {counts.directories > 0 && (
-                  <Badge variant="secondary" className="ml-1.5 h-5 min-w-5 rounded-full px-1.5 text-xs">
-                    {counts.directories}
+                    {counts.screenshots}
                   </Badge>
                 )}
               </Link>
             </TabsTrigger>
             <TabsTrigger value="vulnerabilities" asChild>
-              <Link href={tabPaths.vulnerabilities} className="flex items-center gap-0.5">
-                Vulnerabilities
+              <Link href={primaryPaths.vulnerabilities} className="flex items-center gap-1.5">
+                <ShieldAlert className="h-4 w-4" />
+                {t("tabs.vulnerabilities")}
                 {counts.vulnerabilities > 0 && (
                   <Badge variant="secondary" className="ml-1.5 h-5 min-w-5 rounded-full px-1.5 text-xs">
                     {counts.vulnerabilities}
@@ -128,6 +161,67 @@ export default function ScanHistoryLayout({
         </Tabs>
       </div>

+      {/* Secondary navigation (only for assets) */}
+      {showSecondaryNav && (
+        <div className="flex items-center px-4 lg:px-6">
+          <Tabs value={getSecondaryTab()} className="w-full">
+            <TabsList variant="underline">
+              <TabsTrigger value="websites" variant="underline" asChild>
+                <Link href={secondaryPaths.websites} className="flex items-center gap-0.5">
+                  Websites
+                  {counts.websites > 0 && (
+                    <Badge variant="secondary" className="ml-1.5 h-5 min-w-5 rounded-full px-1.5 text-xs">
+                      {counts.websites}
+                    </Badge>
+                  )}
+                </Link>
+              </TabsTrigger>
+              <TabsTrigger value="subdomain" variant="underline" asChild>
+                <Link href={secondaryPaths.subdomain} className="flex items-center gap-0.5">
+                  Subdomains
+                  {counts.subdomain > 0 && (
+                    <Badge variant="secondary" className="ml-1.5 h-5 min-w-5 rounded-full px-1.5 text-xs">
+                      {counts.subdomain}
+                    </Badge>
+                  )}
+                </Link>
+              </TabsTrigger>
+              <TabsTrigger value="ip-addresses" variant="underline" asChild>
+                <Link href={secondaryPaths["ip-addresses"]} className="flex items-center gap-0.5">
+                  IPs
+                  {counts["ip-addresses"] > 0 && (
+                    <Badge variant="secondary" className="ml-1.5 h-5 min-w-5 rounded-full px-1.5 text-xs">
+                      {counts["ip-addresses"]}
+                    </Badge>
+                  )}
+                </Link>
+              </TabsTrigger>
+              <TabsTrigger value="endpoints" variant="underline" asChild>
+                <Link href={secondaryPaths.endpoints} className="flex items-center gap-0.5">
+                  URLs
+                  {counts.endpoints > 0 && (
+                    <Badge variant="secondary" className="ml-1.5 h-5 min-w-5 rounded-full px-1.5 text-xs">
+                      {counts.endpoints}
+                    </Badge>
+                  )}
+                </Link>
+              </TabsTrigger>
+              <TabsTrigger value="directories" variant="underline" asChild>
+                <Link href={secondaryPaths.directories} className="flex items-center gap-0.5">
+                  Directories
+                  {counts.directories > 0 && (
+                    <Badge variant="secondary" className="ml-1.5 h-5 min-w-5 rounded-full px-1.5 text-xs">
+                      {counts.directories}
+                    </Badge>
+                  )}
+                </Link>
+              </TabsTrigger>
+            </TabsList>
+          </Tabs>
+        </div>
+      )}
+
+      {/* Sub-page content */}
       {children}
     </div>
   )

frontend/app/[locale]/scan/history/[id]/overview/page.tsx (new file, 19 lines)
@@ -0,0 +1,19 @@
"use client"

import { useParams } from "next/navigation"
import { ScanOverview } from "@/components/scan/history/scan-overview"

/**
 * Scan overview page
 * Displays scan statistics and summary information
 */
export default function ScanOverviewPage() {
  const { id } = useParams<{ id: string }>()
  const scanId = Number(id)

  return (
    <div className="flex-1 flex flex-col min-h-0 px-4 lg:px-6">
      <ScanOverview scanId={scanId} />
    </div>
  )
}
@@ -8,7 +8,7 @@ export default function ScanHistoryDetailPage() {
   const router = useRouter()

   useEffect(() => {
-    router.replace(`/scan/history/${id}/websites/`)
+    router.replace(`/scan/history/${id}/overview/`)
   }, [id, router])

   return null

frontend/app/[locale]/scan/history/[id]/screenshots/page.tsx (new file, 15 lines)
@@ -0,0 +1,15 @@
"use client"

import { useParams } from "next/navigation"
import { ScreenshotsGallery } from "@/components/screenshots/screenshots-gallery"

export default function ScanScreenshotsPage() {
  const { id } = useParams<{ id: string }>()
  const scanId = Number(id)

  return (
    <div className="px-4 lg:px-6">
      <ScreenshotsGallery scanId={scanId} />
    </div>
  )
}
@@ -29,6 +29,10 @@ export default function NotificationSettingsPage() {
       enabled: z.boolean(),
       webhookUrl: z.string().url(t("discord.urlInvalid")).or(z.literal('')),
     }),
+    wecom: z.object({
+      enabled: z.boolean(),
+      webhookUrl: z.string().url(t("wecom.urlInvalid")).or(z.literal('')),
+    }),
     categories: z.object({
       scan: z.boolean(),
       vulnerability: z.boolean(),
@@ -46,6 +50,15 @@ export default function NotificationSettingsPage() {
         })
       }
     }
+    if (val.wecom.enabled) {
+      if (!val.wecom.webhookUrl || val.wecom.webhookUrl.trim() === '') {
+        ctx.addIssue({
+          code: z.ZodIssueCode.custom,
+          message: t("wecom.requiredError"),
+          path: ['wecom', 'webhookUrl'],
+        })
+      }
+    }
   })

   const NOTIFICATION_CATEGORIES = [
@@ -79,6 +92,7 @@ export default function NotificationSettingsPage() {
     resolver: zodResolver(schema),
     values: data ?? {
       discord: { enabled: false, webhookUrl: '' },
+      wecom: { enabled: false, webhookUrl: '' },
       categories: {
         scan: true,
         vulnerability: true,
@@ -93,6 +107,7 @@ export default function NotificationSettingsPage() {
   }

   const discordEnabled = form.watch('discord.enabled')
+  const wecomEnabled = form.watch('wecom.enabled')

   return (
     <div className="p-4 md:p-6 space-y-6">
@@ -187,25 +202,59 @@ export default function NotificationSettingsPage() {
         </CardHeader>
       </Card>

-      {/* Feishu/DingTalk/WeCom - Coming soon */}
-      <Card className="opacity-60">
+      {/* WeCom */}
+      <Card>
         <CardHeader className="pb-4">
           <div className="flex items-center justify-between">
             <div className="flex items-center gap-3">
-              <div className="flex h-10 w-10 items-center justify-center rounded-lg bg-muted">
-                <IconBrandSlack className="h-5 w-5 text-muted-foreground" />
+              <div className="flex h-10 w-10 items-center justify-center rounded-lg bg-[#07C160]/10">
+                <IconBrandSlack className="h-5 w-5 text-[#07C160]" />
               </div>
               <div>
-                <div className="flex items-center gap-2">
-                  <CardTitle className="text-base">{t("enterprise.title")}</CardTitle>
-                  <Badge variant="secondary" className="text-xs">{t("emailChannel.comingSoon")}</Badge>
-                </div>
-                <CardDescription>{t("enterprise.description")}</CardDescription>
+                <CardTitle className="text-base">{t("wecom.title")}</CardTitle>
+                <CardDescription>{t("wecom.description")}</CardDescription>
               </div>
             </div>
-            <Switch disabled />
+            <FormField
+              control={form.control}
+              name="wecom.enabled"
+              render={({ field }) => (
+                <FormControl>
+                  <Switch
+                    checked={field.value}
+                    onCheckedChange={field.onChange}
+                    disabled={isLoading || updateMutation.isPending}
+                  />
+                </FormControl>
+              )}
+            />
           </div>
         </CardHeader>
+        {wecomEnabled && (
+          <CardContent className="pt-0">
+            <Separator className="mb-4" />
+            <FormField
+              control={form.control}
+              name="wecom.webhookUrl"
+              render={({ field }) => (
+                <FormItem>
+                  <FormLabel>{t("wecom.webhookLabel")}</FormLabel>
+                  <FormControl>
+                    <Input
+                      placeholder={t("wecom.webhookPlaceholder")}
+                      {...field}
+                      disabled={isLoading || updateMutation.isPending}
+                    />
+                  </FormControl>
+                  <FormDescription>
+                    {t("wecom.webhookHelp")}
+                  </FormDescription>
+                  <FormMessage />
+                </FormItem>
+              )}
+            />
+          </CardContent>
+        )}
       </Card>
     </TabsContent>

|
||||
"use client"
|
||||
|
||||
import { useTranslations } from "next-intl"
|
||||
import { SystemLogsView } from "@/components/settings/system-logs"
|
||||
|
||||
export default function SystemLogsPage() {
|
||||
const t = useTranslations("settings.systemLogs")
|
||||
|
||||
return (
|
||||
<div className="flex flex-1 flex-col gap-4 p-4">
|
||||
<div>
|
||||
<h1 className="text-2xl font-bold tracking-tight">{t("title")}</h1>
|
||||
<p className="text-muted-foreground">{t("description")}</p>
|
||||
</div>
|
||||
<div className="flex flex-1 flex-col p-4 h-full">
|
||||
<SystemLogsView />
|
||||
</div>
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
 import React from "react"
 import { usePathname, useParams } from "next/navigation"
 import Link from "next/link"
-import { Target } from "lucide-react"
+import { Target, LayoutDashboard, Package, Image, ShieldAlert, Settings } from "lucide-react"
 import { Skeleton } from "@/components/ui/skeleton"
 import { Tabs, TabsList, TabsTrigger } from "@/components/ui/tabs"
 import { Badge } from "@/components/ui/badge"
@@ -34,6 +34,7 @@ export default function TargetLayout({
   // Get primary navigation active tab
   const getPrimaryTab = () => {
     if (pathname.includes("/overview")) return "overview"
+    if (pathname.includes("/screenshots")) return "screenshots"
     if (pathname.includes("/vulnerabilities")) return "vulnerabilities"
     if (pathname.includes("/settings")) return "settings"
     // All asset pages fall under "assets"
@@ -67,6 +68,7 @@ export default function TargetLayout({
   const primaryPaths = {
     overview: `${basePath}/overview/`,
     assets: `${basePath}/websites/`, // Default to websites when clicking assets
+    screenshots: `${basePath}/screenshots/`,
     vulnerabilities: `${basePath}/vulnerabilities/`,
     settings: `${basePath}/settings/`,
   }
@@ -87,6 +89,7 @@ export default function TargetLayout({
     directories: (target as any)?.summary?.directories || 0,
     vulnerabilities: (target as any)?.summary?.vulnerabilities?.total || 0,
     "ip-addresses": (target as any)?.summary?.ips || 0,
+    screenshots: (target as any)?.summary?.screenshots || 0,
   }

   // Calculate total assets count
@@ -162,12 +165,14 @@ export default function TargetLayout({
         <Tabs value={getPrimaryTab()}>
           <TabsList>
             <TabsTrigger value="overview" asChild>
-              <Link href={primaryPaths.overview} className="flex items-center gap-0.5">
+              <Link href={primaryPaths.overview} className="flex items-center gap-1.5">
+                <LayoutDashboard className="h-4 w-4" />
                 {t("tabs.overview")}
               </Link>
             </TabsTrigger>
             <TabsTrigger value="assets" asChild>
-              <Link href={primaryPaths.assets} className="flex items-center gap-0.5">
+              <Link href={primaryPaths.assets} className="flex items-center gap-1.5">
+                <Package className="h-4 w-4" />
                 {t("tabs.assets")}
                 {totalAssets > 0 && (
                   <Badge variant="secondary" className="ml-1.5 h-5 min-w-5 rounded-full px-1.5 text-xs">
@@ -176,8 +181,20 @@ export default function TargetLayout({
                 )}
               </Link>
             </TabsTrigger>
+            <TabsTrigger value="screenshots" asChild>
+              <Link href={primaryPaths.screenshots} className="flex items-center gap-1.5">
+                <Image className="h-4 w-4" />
+                {t("tabs.screenshots")}
+                {counts.screenshots > 0 && (
+                  <Badge variant="secondary" className="ml-1.5 h-5 min-w-5 rounded-full px-1.5 text-xs">
+                    {counts.screenshots}
+                  </Badge>
+                )}
+              </Link>
+            </TabsTrigger>
             <TabsTrigger value="vulnerabilities" asChild>
-              <Link href={primaryPaths.vulnerabilities} className="flex items-center gap-0.5">
+              <Link href={primaryPaths.vulnerabilities} className="flex items-center gap-1.5">
+                <ShieldAlert className="h-4 w-4" />
                 {t("tabs.vulnerabilities")}
                 {counts.vulnerabilities > 0 && (
                   <Badge variant="secondary" className="ml-1.5 h-5 min-w-5 rounded-full px-1.5 text-xs">
@@ -187,7 +204,8 @@ export default function TargetLayout({
               </Link>
             </TabsTrigger>
             <TabsTrigger value="settings" asChild>
-              <Link href={primaryPaths.settings} className="flex items-center gap-0.5">
+              <Link href={primaryPaths.settings} className="flex items-center gap-1.5">
+                <Settings className="h-4 w-4" />
                 {t("tabs.settings")}
               </Link>
             </TabsTrigger>

frontend/app/[locale]/target/[id]/screenshots/page.tsx (new file, 15 lines)
@@ -0,0 +1,15 @@
"use client"

import { useParams } from "next/navigation"
import { ScreenshotsGallery } from "@/components/screenshots/screenshots-gallery"

export default function ScreenshotsPage() {
  const { id } = useParams<{ id: string }>()
  const targetId = Number(id)

  return (
    <div className="px-4 lg:px-6">
      <ScreenshotsGallery targetId={targetId} />
    </div>
  )
}
frontend/components/about-dialog.tsx (new file, 189 lines)
@@ -0,0 +1,189 @@
"use client"

import { useState } from 'react'
import { useTranslations } from 'next-intl'
import { useQueryClient } from '@tanstack/react-query'
import {
  IconRadar,
  IconRefresh,
  IconExternalLink,
  IconBrandGithub,
  IconMessageReport,
  IconBook,
  IconFileText,
  IconCheck,
  IconArrowUp,
} from '@tabler/icons-react'

import {
  Dialog,
  DialogContent,
  DialogHeader,
  DialogTitle,
  DialogTrigger,
} from '@/components/ui/dialog'
import { Button } from '@/components/ui/button'
import { Separator } from '@/components/ui/separator'
import { Badge } from '@/components/ui/badge'
import { useVersion } from '@/hooks/use-version'
import { VersionService } from '@/services/version.service'
import type { UpdateCheckResult } from '@/types/version.types'

interface AboutDialogProps {
  children: React.ReactNode
}

export function AboutDialog({ children }: AboutDialogProps) {
  const t = useTranslations('about')
  const { data: versionData } = useVersion()
  const queryClient = useQueryClient()

  const [isChecking, setIsChecking] = useState(false)
  const [updateResult, setUpdateResult] = useState<UpdateCheckResult | null>(null)
  const [checkError, setCheckError] = useState<string | null>(null)

  const handleCheckUpdate = async () => {
    setIsChecking(true)
    setCheckError(null)
    try {
      const result = await VersionService.checkUpdate()
      setUpdateResult(result)
      queryClient.setQueryData(['check-update'], result)
    } catch {
      setCheckError(t('checkFailed'))
    } finally {
      setIsChecking(false)
    }
  }

  const currentVersion = updateResult?.currentVersion || versionData?.version || '-'
  const latestVersion = updateResult?.latestVersion
  const hasUpdate = updateResult?.hasUpdate

  return (
    <Dialog>
      <DialogTrigger asChild>
        {children}
      </DialogTrigger>
      <DialogContent className="sm:max-w-md">
        <DialogHeader>
          <DialogTitle>{t('title')}</DialogTitle>
        </DialogHeader>

        <div className="space-y-6">
          {/* Logo and name */}
          <div className="flex flex-col items-center py-4">
            <div className="flex h-16 w-16 items-center justify-center rounded-2xl bg-primary/10 mb-3">
              <IconRadar className="h-8 w-8 text-primary" />
            </div>
            <h2 className="text-xl font-semibold">XingRin</h2>
            <p className="text-sm text-muted-foreground">{t('description')}</p>
          </div>

          {/* Version info */}
          <div className="rounded-lg border p-4 space-y-3">
            <div className="flex items-center justify-between">
              <span className="text-sm text-muted-foreground">{t('currentVersion')}</span>
              <span className="font-mono text-sm">{currentVersion}</span>
            </div>

            {updateResult && (
              <div className="flex items-center justify-between">
                <span className="text-sm text-muted-foreground">{t('latestVersion')}</span>
                <div className="flex items-center gap-2">
                  <span className="font-mono text-sm">{latestVersion}</span>
                  {hasUpdate ? (
                    <Badge variant="default" className="gap-1">
                      <IconArrowUp className="h-3 w-3" />
                      {t('updateAvailable')}
                    </Badge>
                  ) : (
                    <Badge variant="secondary" className="gap-1">
                      <IconCheck className="h-3 w-3" />
                      {t('upToDate')}
                    </Badge>
                  )}
                </div>
              </div>
            )}

            {checkError && (
              <p className="text-sm text-destructive">{checkError}</p>
            )}

            <div className="flex gap-2">
              <Button
                variant="outline"
                size="sm"
                className="flex-1"
                onClick={handleCheckUpdate}
                disabled={isChecking}
              >
                <IconRefresh className={`h-4 w-4 mr-2 ${isChecking ? 'animate-spin' : ''}`} />
                {isChecking ? t('checking') : t('checkUpdate')}
              </Button>

              {hasUpdate && updateResult?.releaseUrl && (
                <Button
                  variant="default"
                  size="sm"
                  className="flex-1"
                  asChild
                >
                  <a href={updateResult.releaseUrl} target="_blank" rel="noopener noreferrer">
                    <IconExternalLink className="h-4 w-4 mr-2" />
                    {t('viewRelease')}
                  </a>
                </Button>
              )}
            </div>

            {hasUpdate && (
              <div className="rounded-md bg-muted p-3 text-sm text-muted-foreground">
                <p>{t('updateHint')}</p>
                <code className="mt-1 block rounded bg-background px-2 py-1 font-mono text-xs">
                  sudo ./update.sh
                </code>
              </div>
            )}
          </div>

          <Separator />

          {/* Links */}
          <div className="grid grid-cols-2 gap-2">
            <Button variant="ghost" size="sm" className="justify-start" asChild>
              <a href="https://github.com/yyhuni/xingrin" target="_blank" rel="noopener noreferrer">
                <IconBrandGithub className="h-4 w-4 mr-2" />
                GitHub
              </a>
            </Button>
            <Button variant="ghost" size="sm" className="justify-start" asChild>
              <a href="https://github.com/yyhuni/xingrin/releases" target="_blank" rel="noopener noreferrer">
                <IconFileText className="h-4 w-4 mr-2" />
                {t('changelog')}
              </a>
            </Button>
            <Button variant="ghost" size="sm" className="justify-start" asChild>
              <a href="https://github.com/yyhuni/xingrin/issues" target="_blank" rel="noopener noreferrer">
                <IconMessageReport className="h-4 w-4 mr-2" />
                {t('feedback')}
              </a>
            </Button>
            <Button variant="ghost" size="sm" className="justify-start" asChild>
              <a href="https://github.com/yyhuni/xingrin#readme" target="_blank" rel="noopener noreferrer">
                <IconBook className="h-4 w-4 mr-2" />
                {t('docs')}
              </a>
            </Button>
          </div>

          {/* Footer */}
          <p className="text-center text-xs text-muted-foreground">
            © 2025 XingRin · MIT License
          </p>
        </div>
      </DialogContent>
    </Dialog>
  )
}
@@ -5,7 +5,6 @@ import type * as React from "react"
|
||||
// Import various icons from Tabler Icons library
|
||||
import {
|
||||
IconDashboard, // Dashboard icon
|
||||
IconHelp, // Help icon
|
||||
IconListDetails, // List details icon
|
||||
IconSettings, // Settings icon
|
||||
IconUsers, // Users icon
|
||||
@@ -15,10 +14,10 @@ import {
|
||||
IconServer, // Server icon
|
||||
IconTerminal2, // Terminal icon
|
||||
IconBug, // Vulnerability icon
|
||||
IconMessageReport, // Feedback icon
|
||||
IconSearch, // Search icon
|
||||
IconKey, // API Key icon
|
||||
IconBan, // Blacklist icon
|
||||
IconInfoCircle, // About icon
|
||||
} from "@tabler/icons-react"
|
||||
// Import internationalization hook
|
||||
import { useTranslations } from 'next-intl'
|
||||
@@ -27,8 +26,8 @@ import { Link, usePathname } from '@/i18n/navigation'
|
||||
|
||||
// Import custom navigation components
|
||||
import { NavSystem } from "@/components/nav-system"
|
||||
import { NavSecondary } from "@/components/nav-secondary"
|
||||
import { NavUser } from "@/components/nav-user"
|
||||
import { AboutDialog } from "@/components/about-dialog"
|
||||
// Import sidebar UI components
|
||||
import {
|
||||
Sidebar,
|
||||
@@ -139,20 +138,6 @@ export function AppSidebar({ ...props }: React.ComponentProps<typeof Sidebar>) {
     },
   ]

-  // Secondary navigation menu items
-  const navSecondary = [
-    {
-      title: t('feedback'),
-      url: "https://github.com/yyhuni/xingrin/issues",
-      icon: IconMessageReport,
-    },
-    {
-      title: t('help'),
-      url: "https://github.com/yyhuni/xingrin",
-      icon: IconHelp,
-    },
-  ]
-
   // System settings related menu items
   const documents = [
     {
@@ -271,8 +256,21 @@ export function AppSidebar({ ...props }: React.ComponentProps<typeof Sidebar>) {

         {/* System settings navigation menu */}
         <NavSystem items={documents} />
-        {/* Secondary navigation menu, using mt-auto to push to bottom */}
-        <NavSecondary items={navSecondary} className="mt-auto" />
+        {/* About system button */}
+        <SidebarGroup className="mt-auto">
+          <SidebarGroupContent>
+            <SidebarMenu>
+              <SidebarMenuItem>
+                <AboutDialog>
+                  <SidebarMenuButton>
+                    <IconInfoCircle />
+                    <span>{t('about')}</span>
+                  </SidebarMenuButton>
+                </AboutDialog>
+              </SidebarMenuItem>
+            </SidebarMenu>
+          </SidebarGroupContent>
+        </SidebarGroup>
       </SidebarContent>

       {/* Sidebar footer */}
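The hunk above renders AboutDialog as a wrapper around its trigger, so any sidebar button can open it. A plausible shape for that trigger-wrapper pattern, sketched on the shadcn/ui Dialog primitives already visible in the dialog's markup; the prop names are assumptions:

```tsx
import * as React from "react"
// Dialog and DialogContent appear in the AboutDialog markup above;
// DialogTrigger with asChild is the usual shadcn/ui way to wrap a trigger.
import { Dialog, DialogTrigger, DialogContent } from "@/components/ui/dialog"

interface AboutDialogProps {
  children: React.ReactNode // e.g. the SidebarMenuButton from the hunk above
}

export function AboutDialog({ children }: AboutDialogProps) {
  return (
    <Dialog>
      <DialogTrigger asChild>{children}</DialogTrigger>
      <DialogContent>{/* version info, update check, links, footer */}</DialogContent>
    </Dialog>
  )
}
```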
@@ -365,6 +365,7 @@ export function DashboardDataTable() {
   columns={scanColumns}
   getRowId={(row) => String(row.id)}
   enableRowSelection={false}
+  enableAutoColumnSizing
   pagination={scanPagination}
   onPaginationChange={setScanPagination}
   paginationInfo={scanPaginationInfo}
@@ -114,7 +114,7 @@ export function DirectoriesDataTable({
   onSelectionChange={handleSelectionChange}
   // Bulk operations
   onBulkDelete={onBulkDelete}
-  bulkDeleteLabel="Delete"
+  bulkDeleteLabel={tActions("delete")}
   showAddButton={false}
   // Bulk add button
   onBulkAdd={onBulkAdd}
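The one-line change above replaces a hardcoded English label with a locale-aware one. A minimal sketch of the assumed next-intl wiring; the namespace string is a guess, and only useTranslations itself is confirmed by the imports in this diff:

```tsx
import { useTranslations } from "next-intl"

export function BulkDeleteLabelExample() {
  // "common.actions" is a hypothetical namespace; the real key layout
  // lives in the app's messages/*.json files.
  const tActions = useTranslations("common.actions")

  // The label now follows the active locale instead of always reading "Delete".
  return <span>{tActions("delete")}</span>
}
```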
@@ -11,6 +11,7 @@ import { useTargetDirectories, useScanDirectories } from "@/hooks/use-directorie
 import { useTarget } from "@/hooks/use-targets"
 import { DirectoryService } from "@/services/directory.service"
 import { BulkAddUrlsDialog } from "@/components/common/bulk-add-urls-dialog"
+import { ConfirmDialog } from "@/components/ui/confirm-dialog"
 import { getDateLocale } from "@/lib/date-utils"
 import type { TargetType } from "@/lib/url-validator"
 import type { Directory } from "@/types/directory.types"
@@ -29,6 +30,8 @@ export function DirectoriesView({
   })
   const [selectedDirectories, setSelectedDirectories] = useState<Directory[]>([])
   const [bulkAddDialogOpen, setBulkAddDialogOpen] = useState(false)
+  const [deleteDialogOpen, setDeleteDialogOpen] = useState(false)
+  const [isDeleting, setIsDeleting] = useState(false)

   const [filterQuery, setFilterQuery] = useState("")
   const [isSearching, setIsSearching] = useState(false)
@@ -240,6 +243,26 @@ export function DirectoriesView({
     URL.revokeObjectURL(url)
   }

+  // Handle bulk delete
+  const handleBulkDelete = async () => {
+    if (selectedDirectories.length === 0) return
+
+    setIsDeleting(true)
+    try {
+      const ids = selectedDirectories.map(d => d.id)
+      const result = await DirectoryService.bulkDelete(ids)
+      toast.success(tToast("deleteSuccess", { count: result.deletedCount }))
+      setSelectedDirectories([])
+      setDeleteDialogOpen(false)
+      refetch()
+    } catch (error) {
+      console.error("Failed to delete directories", error)
+      toast.error(tToast("deleteFailed"))
+    } finally {
+      setIsDeleting(false)
+    }
+  }
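The success toast above reads `result.deletedCount`, which pins down the service contract. A sketch of what `DirectoryService.bulkDelete` plausibly looks like; the endpoint path and error handling are assumptions, while the method name and return field come from the handler:

```ts
// Hypothetical sketch; the real method lives in services/directory.service.ts.
interface BulkDeleteResult {
  deletedCount: number // consumed by tToast("deleteSuccess", { count })
}

export class DirectoryService {
  static async bulkDelete(ids: number[]): Promise<BulkDeleteResult> {
    // Assumed REST route; the backend defines the actual path.
    const res = await fetch("/api/directories/bulk-delete", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ ids }),
    })
    if (!res.ok) throw new Error(`Bulk delete failed: ${res.status}`)
    return res.json() as Promise<BulkDeleteResult>
  }
}
```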
   if (error) {
     return (
       <div className="flex flex-col items-center justify-center py-12">
@@ -280,6 +303,7 @@ export function DirectoriesView({
   onSelectionChange={handleSelectionChange}
   onDownloadAll={handleDownloadAll}
   onDownloadSelected={handleDownloadSelected}
+  onBulkDelete={targetId ? () => setDeleteDialogOpen(true) : undefined}
   onBulkAdd={targetId ? () => setBulkAddDialogOpen(true) : undefined}
 />
@@ -295,6 +319,17 @@ export function DirectoriesView({
           onSuccess={() => refetch()}
         />
       )}

+      {/* Delete confirmation dialog */}
+      <ConfirmDialog
+        open={deleteDialogOpen}
+        onOpenChange={setDeleteDialogOpen}
+        title={tCommon("actions.confirmDelete")}
+        description={tCommon("actions.deleteConfirmMessage", { count: selectedDirectories.length })}
+        onConfirm={handleBulkDelete}
+        loading={isDeleting}
+        variant="destructive"
+      />
     </>
   )
 }
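Between this usage and the endpoints one later in the diff, ConfirmDialog's prop surface is fully exercised. The implied prop shape, as a sketch inferred from usage; the "default" variant value is an assumption:

```tsx
// Inferred from the two <ConfirmDialog> usages in this diff; the real
// definition is components/ui/confirm-dialog.tsx.
interface ConfirmDialogProps {
  open: boolean
  onOpenChange: (open: boolean) => void
  title: string
  description: string
  onConfirm: () => void | Promise<void>
  loading?: boolean                    // shows a busy state while deleting
  variant?: "default" | "destructive" // "default" is an assumed fallback
}
```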
@@ -36,6 +36,7 @@ interface EndpointsDataTableProps<TData extends { id: number | string }, TValue>
   onAddNew?: () => void
   addButtonText?: string
   onSelectionChange?: (selectedRows: TData[]) => void
+  onBulkDelete?: () => void
   pagination?: { pageIndex: number; pageSize: number }
   onPaginationChange?: (pagination: { pageIndex: number; pageSize: number }) => void
   totalCount?: number
@@ -54,6 +55,7 @@ export function EndpointsDataTable<TData extends { id: number | string }, TValue
   onAddNew,
   addButtonText = "Add",
   onSelectionChange,
+  onBulkDelete,
   pagination: externalPagination,
   onPaginationChange,
   totalCount,
@@ -135,7 +137,8 @@ export function EndpointsDataTable<TData extends { id: number | string }, TValue
   // Selection
   onSelectionChange={onSelectionChange}
   // Bulk operations
-  showBulkDelete={false}
+  onBulkDelete={onBulkDelete}
+  bulkDeleteLabel={tActions("delete")}
   onAddNew={onAddNew}
   addButtonLabel={addButtonText}
   // Bulk add button
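The generic constraint `TData extends { id: number | string }` is what lets these tables derive stable row ids and selection sets without knowing the concrete row type. A tiny standalone illustration of the same idea, not taken from the diff:

```ts
// Any row type with an `id` can be keyed the same way the tables do
// via getRowId={(row) => String(row.id)}.
function getRowIds<TData extends { id: number | string }>(rows: TData[]): string[] {
  return rows.map((row) => String(row.id))
}

// Example: works for both Endpoint-like and Directory-like rows.
console.log(getRowIds([{ id: 1 }, { id: "abc" }])) // ["1", "abc"]
```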
@@ -10,6 +10,7 @@ import { createEndpointColumns } from "./endpoints-columns"
 import { LoadingSpinner } from "@/components/loading-spinner"
 import { DataTableSkeleton } from "@/components/ui/data-table-skeleton"
 import { BulkAddUrlsDialog } from "@/components/common/bulk-add-urls-dialog"
+import { ConfirmDialog } from "@/components/ui/confirm-dialog"
 import { getDateLocale } from "@/lib/date-utils"
 import type { TargetType } from "@/lib/url-validator"
 import {
@@ -41,6 +42,8 @@ export function EndpointsDetailView({
   const [endpointToDelete, setEndpointToDelete] = useState<Endpoint | null>(null)
   const [selectedEndpoints, setSelectedEndpoints] = useState<Endpoint[]>([])
   const [bulkAddDialogOpen, setBulkAddDialogOpen] = useState(false)
+  const [bulkDeleteDialogOpen, setBulkDeleteDialogOpen] = useState(false)
+  const [isDeleting, setIsDeleting] = useState(false)

   // Pagination state management
   const [pagination, setPagination] = useState({
@@ -280,6 +283,26 @@ export function EndpointsDetailView({
     URL.revokeObjectURL(url)
   }

+  // Handle bulk delete
+  const handleBulkDelete = async () => {
+    if (selectedEndpoints.length === 0) return
+
+    setIsDeleting(true)
+    try {
+      const ids = selectedEndpoints.map(e => e.id)
+      const result = await EndpointService.bulkDelete(ids)
+      toast.success(tToast("deleteSuccess", { count: result.deletedCount }))
+      setSelectedEndpoints([])
+      setBulkDeleteDialogOpen(false)
+      refetch()
+    } catch (error) {
+      console.error("Failed to delete endpoints", error)
+      toast.error(tToast("deleteFailed"))
+    } finally {
+      setIsDeleting(false)
+    }
+  }
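This handler mirrors the directories one line for line, differing only in the service and state setters. If a third table gains bulk delete, the shared flow could be factored into a helper along these lines; a refactor sketch, not part of the diff:

```ts
// Hypothetical helper, not in the diff. Captures the shared
// select -> delete -> toast -> reset -> refetch flow of both handlers.
async function runBulkDelete<T extends { id: number }>(options: {
  selected: T[]
  bulkDelete: (ids: number[]) => Promise<{ deletedCount: number }>
  setLoading: (loading: boolean) => void
  onSuccess: (deletedCount: number) => void // toast, clear selection, close dialog, refetch
  onError: (err: unknown) => void           // toast and log
}): Promise<void> {
  if (options.selected.length === 0) return
  options.setLoading(true)
  try {
    const result = await options.bulkDelete(options.selected.map((s) => s.id))
    options.onSuccess(result.deletedCount)
  } catch (err) {
    options.onError(err)
  } finally {
    options.setLoading(false)
  }
}
```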
   // Error state
   if (error) {
     return (
@@ -327,6 +350,7 @@ export function EndpointsDetailView({
   onSelectionChange={handleSelectionChange}
   onDownloadAll={handleDownloadAll}
   onDownloadSelected={handleDownloadSelected}
+  onBulkDelete={targetId ? () => setBulkDeleteDialogOpen(true) : undefined}
   onBulkAdd={targetId ? () => setBulkAddDialogOpen(true) : undefined}
 />
@@ -343,7 +367,18 @@ export function EndpointsDetailView({
         />
       )}

-      {/* Delete confirmation dialog */}
+      {/* Bulk delete confirmation dialog */}
+      <ConfirmDialog
+        open={bulkDeleteDialogOpen}
+        onOpenChange={setBulkDeleteDialogOpen}
+        title={tConfirm("deleteTitle")}
+        description={tCommon("actions.deleteConfirmMessage", { count: selectedEndpoints.length })}
+        onConfirm={handleBulkDelete}
+        loading={isDeleting}
+        variant="destructive"
+      />
+
+      {/* Single delete confirmation dialog */}
       <AlertDialog open={deleteDialogOpen} onOpenChange={setDeleteDialogOpen}>
         <AlertDialogContent>
           <AlertDialogHeader>
Some files were not shown because too many files have changed in this diff.