Compare commits


13 Commits

Author SHA1 Message Date
yyhuni
c60383940c Add upgrade functionality 2026-01-10 10:04:07 +08:00
yyhuni
47298c294a Performance optimization 2026-01-10 09:44:49 +08:00
yyhuni
eba394e14e Optimization: performance optimization 2026-01-10 09:44:43 +08:00
yyhuni
592a1958c4 Optimize UI 2026-01-09 16:52:50 +08:00
yyhuni
38e2856c08 feat(scan): add provider abstraction layer for flexible target sourcing
- Add TargetProvider base class and ProviderContext for unified target acquisition
- Implement DatabaseTargetProvider for database-backed target queries
- Implement ListTargetProvider for in-memory target lists (fast scan phase 1)
- Implement SnapshotTargetProvider for snapshot table reads (fast scan phase 2+)
- Implement PipelineTargetProvider for pipeline stage outputs
- Add comprehensive provider tests covering common properties and individual providers
- Update screenshot_flow to support both legacy mode (target_id) and provider mode
- Add backward compatibility layer for existing task exports (directory, fingerprint, port, site, url_fetch, vuln scans)
- Add task backward compatibility tests
- Update .gitignore to exclude .hypothesis/ cache directory
- Update frontend ANSI log viewer component
- Update backend requirements.txt with new dependencies
- Enables flexible data source integration while maintaining backward compatibility with existing database-driven workflows
2026-01-09 09:02:09 +08:00
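The provider abstraction described in this commit can be pictured with a minimal sketch (not the repository's actual code): the method names `iter_urls()` and `get_blacklist_filter()` match the calls visible in the screenshot_flow diff further down; the constructor and everything else are illustrative assumptions.

```python
from typing import Iterable, Iterator, Optional


class TargetProvider:
    """Base class: yields scan targets from an arbitrary data source."""

    def iter_urls(self) -> Iterator[str]:
        raise NotImplementedError

    def get_blacklist_filter(self) -> Optional[object]:
        # An object exposing is_allowed(url), or None when no filtering applies.
        return None


class ListTargetProvider(TargetProvider):
    """In-memory target list, e.g. for phase 1 of a fast scan (illustrative)."""

    def __init__(self, urls: Iterable[str]):
        self._urls = list(urls)

    def iter_urls(self) -> Iterator[str]:
        yield from self._urls
```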
yyhuni
f5ad8e68e9 chore(backend): add hypothesis cache directory to gitignore
- Add .hypothesis/ directory to .gitignore to exclude Hypothesis property testing cache files
- Prevents test cache artifacts from being tracked in version control
- Improves repository cleanliness by ignoring generated test data
2026-01-08 11:58:49 +08:00
yyhuni
d5f91a236c Merge branch 'main' of https://github.com/yyhuni/xingrin 2026-01-08 10:37:32 +08:00
yyhuni
24ae8b5aeb docs: restructure README features section with capability tables
- Convert feature descriptions from nested lists to organized capability tables
- Add scanning capability table with tools and descriptions for each feature
- Add platform capability table highlighting core platform features
- Improve readability and scannability of feature documentation
- Maintain scanning pipeline architecture section for reference
- Simplify feature organization for better user comprehension
2026-01-08 10:35:56 +08:00
github-actions[bot]
86f43f94a0 chore: bump version to v1.5.3 2026-01-08 02:17:58 +00:00
yyhuni
53ba03d1e5 Support Kali 2026-01-08 10:14:12 +08:00
github-actions[bot]
89c44ebd05 chore: bump version to v1.5.2 2026-01-08 00:20:11 +00:00
yyhuni
e0e3419edb chore(docker): improve worker dockerfile reliability with retry mechanism
- Add retry mechanism for apt-get install to handle ARM64 mirror sync delays
- Use --no-install-recommends flag to reduce image size and installation time
- Split apt-get update and install commands for better layer caching
- Add fallback installation logic for packages in case of initial failure
- Include explanatory comment about ARM64 ports.ubuntu.com potential delays
- Maintain compatibility with both ARM64 and AMD64 architectures
2026-01-08 08:14:24 +08:00
yyhuni
52ee4684a7 chore(docker): add apt-get update before playwright dependencies
- Add apt-get update before installing playwright chromium dependencies
- Ensures package lists are refreshed before installing system dependencies
- Prevents potential package installation failures in Docker builds
2026-01-08 08:09:21 +08:00
56 changed files with 4474 additions and 2319 deletions

.gitignore vendored
View File

@@ -64,6 +64,7 @@ backend/.env.local
.coverage
htmlcov/
*.cover
.hypothesis/
# ============================
# 后端 (Go) 相关

View File

@@ -58,33 +58,33 @@
## ✨ Features
### 🎯 Target & Asset Management
- **Organization management** - Multi-level target organization with flexible grouping
- **Target management** - Supports domain and IP target types
- **Asset discovery** - Automatic discovery of subdomains, websites, endpoints, and directories
- **Asset snapshots** - Snapshot comparison of scan results to track asset changes
### Scanning Capabilities
### 🔍 Vulnerability Scanning
- **Multi-engine support** - Integrates mainstream scanning engines such as Nuclei
- **Custom pipelines** - Scan pipelines defined in YAML for flexible orchestration
- **Scheduled scans** - Cron expressions for automated periodic scanning
| Feature | Status | Tools | Description |
|------|------|------|------|
| Subdomain scanning | ✅ | Subfinder, Amass, PureDNS | Passive collection + active brute forcing, aggregating 50+ data sources |
| Port scanning | ✅ | Naabu | Custom port ranges |
| Site discovery | ✅ | HTTPX | HTTP probing that automatically captures titles, status codes, and tech stacks |
| Fingerprinting | ✅ | XingFinger | 27,000+ fingerprint rules from multiple fingerprint libraries |
| URL collection | ✅ | Waymore, Katana | Historical data + active crawling |
| Directory scanning | ✅ | FFUF | High-speed brute forcing with smart wordlists |
| Vulnerability scanning | ✅ | Nuclei, Dalfox | 9000+ POC templates, XSS detection |
| Site screenshots | ✅ | Playwright | High-compression WebP storage |
### 🚫 Blacklist Filtering
- **Two-tier blacklist** - Global blacklist + Target-level blacklist for flexible control of scan scope
- **Smart rule recognition** - Automatically recognizes domain wildcards (`*.gov`), IPs, and CIDR ranges
### Platform Capabilities
### 🔖 Fingerprinting
- **Multi-source fingerprint library** - Ships with EHole, Goby, Wappalyzer, Fingers, FingerPrintHub, ARL, and more, totaling 27,000+ fingerprint rules
- **Automatic detection** - Runs automatically in the scan pipeline to identify web application tech stacks
- **Fingerprint management** - Supports querying, importing, and exporting fingerprint rules
| Feature | Status | Description |
|------|------|------|
| Target management | ✅ | Multi-level organization; supports domain/IP targets |
| Asset snapshots | ✅ | Compares scan results to track asset changes |
| Blacklist filtering | ✅ | Global + Target-level; supports wildcards/CIDR |
| Scheduled tasks | ✅ | Cron expressions for automated periodic scans |
| Distributed scanning | ✅ | Multiple Worker nodes with load-aware scheduling |
| Global search | ✅ | Expression syntax with multi-field compound queries |
| Notifications | ✅ | WeCom, Telegram, Discord |
| API key management | ✅ | Visual configuration of API keys for each data source |
### 📸 Site Screenshots
- **Automatic screenshots** - Uses Playwright to automatically capture discovered websites
- **WebP format** - High-compression storage; a 500 KB image compresses down to just tens of KB
- **Multiple sources** - Captures URLs from Websites, Endpoints, and other sources
- **Asset linkage** - Screenshots are automatically synced to the asset tables for easy viewing
#### Scan Pipeline Architecture
### Scan Pipeline Architecture
The full scan pipeline covers subdomain discovery, port scanning, site discovery, fingerprinting, URL collection, directory scanning, and vulnerability scanning stages

View File

@@ -1 +1 @@
v1.4.1
v1.5.3

backend/.gitignore vendored
View File

@@ -7,6 +7,7 @@ __pycache__/
*.egg-info/
dist/
build/
.hypothesis/ # Hypothesis 属性测试缓存
# 虚拟环境
venv/

View File

@@ -312,7 +312,11 @@ class TaskDistributor:
# - 本地 Worker:install.sh 已预拉取镜像,直接使用本地版本
# - 远程 Worker:deploy 时已预拉取镜像,直接使用本地版本
# - 避免每次任务都检查 Docker Hub,提升性能和稳定性
# OOM 优先级:--oom-score-adj=1000 让 Worker 在内存不足时优先被杀
# - 范围 -1000 到 1000,值越大越容易被 OOM Killer 选中
# - 保护 server/nginx/frontend 等核心服务,确保 Web 界面可用
cmd = f'''docker run --rm -d --pull=missing {network_arg} \\
--oom-score-adj=1000 \\
{' '.join(env_vars)} \\
{' '.join(volumes)} \\
{self.docker_image} \\
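As a quick sanity check of the OOM setting above (not part of this diff): inside a worker container started with `--oom-score-adj=1000`, the adjustment is visible under `/proc`.

```python
# Illustrative check, Linux only: read the OOM score adjustment of the current process.
from pathlib import Path

adj = int(Path("/proc/self/oom_score_adj").read_text().strip())
print(adj)  # expected 1000 in worker containers, so they are reaped before core services
```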

View File

@@ -24,18 +24,6 @@ SUBDOMAIN_DISCOVERY_COMMANDS = {
}
},
'amass_passive': {
# 先执行被动枚举,将结果写入 amass 内部数据库然后从数据库中导出纯域名names到 output_file
# -silent 禁用进度条和其他输出
'base': "amass enum -passive -silent -d {domain} && amass subs -names -d {domain} > '{output_file}'"
},
'amass_active': {
# 先执行主动枚举 + 爆破,将结果写入 amass 内部数据库然后从数据库中导出纯域名names到 output_file
# -silent 禁用进度条和其他输出
'base': "amass enum -active -silent -d {domain} -brute && amass subs -names -d {domain} > '{output_file}'"
},
'sublist3r': {
'base': "python3 '/usr/local/share/Sublist3r/sublist3r.py' -d {domain} -o '{output_file}'",
'optional': {

View File

@@ -17,14 +17,6 @@ subdomain_discovery:
timeout: 3600 # 1小时
# threads: 10 # 并发 goroutine 数
amass_passive:
enabled: true
timeout: 3600
amass_active:
enabled: true # 主动枚举 + 爆破
timeout: 3600
sublist3r:
enabled: true
timeout: 3600
@@ -62,7 +54,7 @@ port_scan:
threads: 200 # 并发连接数(默认 5)
# ports: 1-65535 # 扫描端口范围(默认 1-65535)
top-ports: 100 # 扫描 nmap top 100 端口
rate: 10 # 扫描速率(默认 10)
rate: 50 # 扫描速率
naabu_passive:
enabled: true

View File

@@ -10,30 +10,30 @@
- 配置由 YAML 解析
"""
# Django 环境初始化(导入即生效)
from apps.common.prefect_django_setup import setup_django_for_prefect
from prefect import flow
from prefect.task_runners import ThreadPoolTaskRunner
import hashlib
import logging
import os
import subprocess
from datetime import datetime
from pathlib import Path
from typing import List, Tuple
from apps.scan.tasks.directory_scan import (
export_sites_task,
run_and_stream_save_directories_task
)
from prefect import flow
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
on_scan_flow_running,
)
from apps.scan.tasks.directory_scan import (
export_sites_task,
run_and_stream_save_directories_task,
)
from apps.scan.utils import (
build_scan_command,
ensure_wordlist_local,
user_log,
wait_for_system_load,
)
from apps.scan.utils import config_parser, build_scan_command, ensure_wordlist_local, user_log
logger = logging.getLogger(__name__)
@@ -45,517 +45,343 @@ def calculate_directory_scan_timeout(
tool_config: dict,
base_per_word: float = 1.0,
min_timeout: int = 60,
max_timeout: int = 7200
) -> int:
"""
根据字典行数计算目录扫描超时时间
计算公式:超时时间 = 字典行数 × 每个单词基础时间
超时范围:60秒 ~ 2小时7200秒
超时范围:最小 60 秒,无上限
Args:
tool_config: 工具配置字典,包含 wordlist 路径
base_per_word: 每个单词的基础时间(秒),默认 1.0秒
min_timeout: 最小超时时间(秒),默认 60秒
max_timeout: 最大超时时间(秒),默认 7200秒2小时
Returns:
int: 计算出的超时时间(秒)范围60 ~ 7200
Example:
# 1000行字典 × 1.0秒 = 1000秒 → 限制为7200秒中的 1000秒
# 10000行字典 × 1.0秒 = 10000秒 → 限制为7200秒最大值
timeout = calculate_directory_scan_timeout(
tool_config={'wordlist': '/path/to/wordlist.txt'}
)
int: 计算出的超时时间(秒)
"""
import os
wordlist_path = tool_config.get('wordlist')
if not wordlist_path:
logger.warning("工具配置中未指定 wordlist使用默认超时: %d", min_timeout)
return min_timeout
wordlist_path = os.path.expanduser(wordlist_path)
if not os.path.exists(wordlist_path):
logger.warning("字典文件不存在: %s,使用默认超时: %d", wordlist_path, min_timeout)
return min_timeout
try:
# 从 tool_config 中获取 wordlist 路径
wordlist_path = tool_config.get('wordlist')
if not wordlist_path:
logger.warning("工具配置中未指定 wordlist使用默认超时: %d", min_timeout)
return min_timeout
# 展开用户目录(~
wordlist_path = os.path.expanduser(wordlist_path)
# 检查文件是否存在
if not os.path.exists(wordlist_path):
logger.warning("字典文件不存在: %s,使用默认超时: %d", wordlist_path, min_timeout)
return min_timeout
# 使用 wc -l 快速统计字典行数
result = subprocess.run(
['wc', '-l', wordlist_path],
capture_output=True,
text=True,
check=True
)
# wc -l 输出格式:行数 + 空格 + 文件名
line_count = int(result.stdout.strip().split()[0])
# 计算超时时间
timeout = int(line_count * base_per_word)
# 设置合理的下限(不再设置上限)
timeout = max(min_timeout, timeout)
timeout = max(min_timeout, int(line_count * base_per_word))
logger.info(
"目录扫描超时计算 - 字典: %s, 行数: %d, 基础时间: %.3f秒/词, 计算超时: %d",
wordlist_path, line_count, base_per_word, timeout
)
return timeout
except subprocess.CalledProcessError as e:
logger.error("统计字典行数失败: %s", e)
# 失败时返回默认超时
return min_timeout
except (ValueError, IndexError) as e:
logger.error("解析字典行数失败: %s", e)
return min_timeout
except Exception as e:
logger.error("计算超时时间异常: %s", e)
except (subprocess.CalledProcessError, ValueError, IndexError) as e:
logger.error("计算超时时间失败: %s", e)
return min_timeout
def _get_max_workers(tool_config: dict, default: int = DEFAULT_MAX_WORKERS) -> int:
"""
从单个工具配置中获取 max_workers 参数
Args:
tool_config: 单个工具的配置字典,如 {'max_workers': 10, 'threads': 5, ...}
default: 默认值,默认为 5
Returns:
int: max_workers 值
"""
"""从单个工具配置中获取 max_workers 参数"""
if not isinstance(tool_config, dict):
return default
# 支持 max_workers 和 max-workersYAML 中划线会被转换)
max_workers = tool_config.get('max_workers') or tool_config.get('max-workers')
if max_workers is not None and isinstance(max_workers, int) and max_workers > 0:
if isinstance(max_workers, int) and max_workers > 0:
return max_workers
return default
def _export_site_urls(target_id: int, target_name: str, directory_scan_dir: Path) -> tuple[str, int]:
def _export_site_urls(
target_id: int,
directory_scan_dir: Path
) -> Tuple[str, int]:
"""
导出目标下的所有站点 URL 到文件(支持懒加载)
导出目标下的所有站点 URL 到文件
Args:
target_id: 目标 ID
target_name: 目标名称(用于懒加载创建默认站点)
directory_scan_dir: 目录扫描目录
Returns:
tuple: (sites_file, site_count)
Raises:
ValueError: 站点数量为 0
"""
logger.info("Step 1: 导出目标的所有站点 URL")
sites_file = str(directory_scan_dir / 'sites.txt')
export_result = export_sites_task(
target_id=target_id,
output_file=sites_file,
batch_size=1000 # 每次读取 1000 条,优化内存占用
batch_size=1000
)
site_count = export_result['total_count']
logger.info(
"✓ 站点 URL 导出完成 - 文件: %s, 数量: %d",
export_result['output_file'],
site_count
)
if site_count == 0:
logger.warning("目标下没有站点,无法执行目录扫描")
# 不抛出异常,由上层决定如何处理
# raise ValueError("目标下没有站点,无法执行目录扫描")
return export_result['output_file'], site_count
def _run_scans_sequentially(
enabled_tools: dict,
sites_file: str,
directory_scan_dir: Path,
scan_id: int,
target_id: int,
site_count: int,
target_name: str
) -> tuple[int, int, list]:
"""
串行执行目录扫描任务(支持多工具)- 已废弃,保留用于兼容
Args:
enabled_tools: 启用的工具配置字典
sites_file: 站点文件路径
directory_scan_dir: 目录扫描目录
scan_id: 扫描任务 ID
target_id: 目标 ID
site_count: 站点数量
target_name: 目标名称(用于错误日志)
Returns:
tuple: (total_directories, processed_sites, failed_sites)
"""
# 读取站点列表
sites = []
with open(sites_file, 'r', encoding='utf-8') as f:
for line in f:
site_url = line.strip()
if site_url:
sites.append(site_url)
logger.info("准备扫描 %d 个站点,使用工具: %s", len(sites), ', '.join(enabled_tools.keys()))
total_directories = 0
processed_sites_set = set() # 使用 set 避免重复计数
failed_sites = []
# 遍历每个工具
for tool_name, tool_config in enabled_tools.items():
logger.info("="*60)
logger.info("使用工具: %s", tool_name)
logger.info("="*60)
# 如果配置了 wordlist_name则先确保本地存在对应的字典文件含 hash 校验)
wordlist_name = tool_config.get('wordlist_name')
if wordlist_name:
try:
local_wordlist_path = ensure_wordlist_local(wordlist_name)
tool_config['wordlist'] = local_wordlist_path
except Exception as exc:
logger.error("为工具 %s 准备字典失败: %s", tool_name, exc)
# 当前工具无法执行,将所有站点视为失败,继续下一个工具
failed_sites.extend(sites)
continue
# 逐个站点执行扫描
for idx, site_url in enumerate(sites, 1):
logger.info(
"[%d/%d] 开始扫描站点: %s (工具: %s)",
idx, len(sites), site_url, tool_name
)
# 使用统一的命令构建器
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='directory_scan',
command_params={
'url': site_url
},
tool_config=tool_config
)
except Exception as e:
logger.error(
"✗ [%d/%d] 构建 %s 命令失败: %s - 站点: %s",
idx, len(sites), tool_name, e, site_url
)
failed_sites.append(site_url)
continue
# 单个站点超时:从配置中获取(支持 'auto' 动态计算)
# ffuf 逐个站点扫描timeout 就是单个站点的超时时间
site_timeout = tool_config.get('timeout', 300)
if site_timeout == 'auto':
# 动态计算超时时间(基于字典行数)
site_timeout = calculate_directory_scan_timeout(tool_config)
logger.info(f"✓ 工具 {tool_name} 动态计算 timeout: {site_timeout}")
# 生成日志文件路径
from datetime import datetime
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = directory_scan_dir / f"{tool_name}_{timestamp}_{idx}.log"
try:
# 直接调用 task串行执行
result = run_and_stream_save_directories_task(
cmd=command,
tool_name=tool_name, # 新增:工具名称
scan_id=scan_id,
target_id=target_id,
site_url=site_url,
cwd=str(directory_scan_dir),
shell=True,
batch_size=1000,
timeout=site_timeout,
log_file=str(log_file) # 新增:日志文件路径
)
total_directories += result.get('created_directories', 0)
processed_sites_set.add(site_url) # 使用 set 记录成功的站点
logger.info(
"✓ [%d/%d] 站点扫描完成: %s - 发现 %d 个目录",
idx, len(sites), site_url,
result.get('created_directories', 0)
)
except subprocess.TimeoutExpired as exc:
# 超时异常单独处理
failed_sites.append(site_url)
logger.warning(
"⚠️ [%d/%d] 站点扫描超时: %s - 超时配置: %d\n"
"注意:超时前已解析的目录数据已保存到数据库,但扫描未完全完成。",
idx, len(sites), site_url, site_timeout
)
except Exception as exc:
# 其他异常
failed_sites.append(site_url)
logger.error(
"✗ [%d/%d] 站点扫描失败: %s - 错误: %s",
idx, len(sites), site_url, exc
)
# 每 10 个站点输出进度
if idx % 10 == 0:
logger.info(
"进度: %d/%d (%.1f%%) - 已发现 %d 个目录",
idx, len(sites), idx/len(sites)*100, total_directories
)
# 计算成功和失败的站点数
processed_count = len(processed_sites_set)
if failed_sites:
logger.warning(
"部分站点扫描失败: %d/%d",
len(failed_sites), len(sites)
)
logger.info(
"✓ 串行目录扫描执行完成 - 成功: %d/%d, 失败: %d, 总目录数: %d",
processed_count, len(sites), len(failed_sites), total_directories
)
return total_directories, processed_count, failed_sites
def _generate_log_filename(tool_name: str, site_url: str, directory_scan_dir: Path) -> Path:
"""
生成唯一的日志文件名
使用 URL 的 hash 确保并发时不会冲突
Args:
tool_name: 工具名称
site_url: 站点 URL
directory_scan_dir: 目录扫描目录
Returns:
Path: 日志文件路径
"""
url_hash = hashlib.md5(site_url.encode()).hexdigest()[:8]
def _generate_log_filename(
tool_name: str,
site_url: str,
directory_scan_dir: Path
) -> Path:
"""生成唯一的日志文件名(使用 URL 的 hash 确保并发时不会冲突)"""
url_hash = hashlib.md5(
site_url.encode(),
usedforsecurity=False
).hexdigest()[:8]
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
return directory_scan_dir / f"{tool_name}_{url_hash}_{timestamp}.log"
def _prepare_tool_wordlist(tool_name: str, tool_config: dict) -> bool:
"""准备工具的字典文件,返回是否成功"""
wordlist_name = tool_config.get('wordlist_name')
if not wordlist_name:
return True
try:
local_wordlist_path = ensure_wordlist_local(wordlist_name)
tool_config['wordlist'] = local_wordlist_path
return True
except Exception as exc:
logger.error("为工具 %s 准备字典失败: %s", tool_name, exc)
return False
def _build_scan_params(
tool_name: str,
tool_config: dict,
sites: List[str],
directory_scan_dir: Path,
site_timeout: int
) -> Tuple[List[dict], List[str]]:
"""构建所有站点的扫描参数,返回 (scan_params_list, failed_sites)"""
scan_params_list = []
failed_sites = []
for idx, site_url in enumerate(sites, 1):
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='directory_scan',
command_params={'url': site_url},
tool_config=tool_config
)
log_file = _generate_log_filename(tool_name, site_url, directory_scan_dir)
scan_params_list.append({
'idx': idx,
'site_url': site_url,
'command': command,
'log_file': str(log_file),
'timeout': site_timeout
})
except Exception as e:
logger.error(
"✗ [%d/%d] 构建 %s 命令失败: %s - 站点: %s",
idx, len(sites), tool_name, e, site_url
)
failed_sites.append(site_url)
return scan_params_list, failed_sites
def _execute_batch(
batch_params: List[dict],
tool_name: str,
scan_id: int,
target_id: int,
directory_scan_dir: Path,
total_sites: int
) -> Tuple[int, List[str]]:
"""执行一批扫描任务,返回 (directories_found, failed_sites)"""
directories_found = 0
failed_sites = []
# 提交任务
futures = []
for params in batch_params:
future = run_and_stream_save_directories_task.submit(
cmd=params['command'],
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
site_url=params['site_url'],
cwd=str(directory_scan_dir),
shell=True,
batch_size=1000,
timeout=params['timeout'],
log_file=params['log_file']
)
futures.append((params['idx'], params['site_url'], future))
# 等待结果
for idx, site_url, future in futures:
try:
result = future.result()
dirs_count = result.get('created_directories', 0)
directories_found += dirs_count
logger.info(
"✓ [%d/%d] 站点扫描完成: %s - 发现 %d 个目录",
idx, total_sites, site_url, dirs_count
)
except Exception as exc:
failed_sites.append(site_url)
if 'timeout' in str(exc).lower():
logger.warning(
"⚠️ [%d/%d] 站点扫描超时: %s - 错误: %s",
idx, total_sites, site_url, exc
)
else:
logger.error(
"✗ [%d/%d] 站点扫描失败: %s - 错误: %s",
idx, total_sites, site_url, exc
)
return directories_found, failed_sites
def _run_scans_concurrently(
enabled_tools: dict,
sites_file: str,
directory_scan_dir: Path,
scan_id: int,
target_id: int,
site_count: int,
target_name: str
) -> Tuple[int, int, List[str]]:
"""
并发执行目录扫描任务(使用 ThreadPoolTaskRunner
Args:
enabled_tools: 启用的工具配置字典
sites_file: 站点文件路径
directory_scan_dir: 目录扫描目录
scan_id: 扫描任务 ID
target_id: 目标 ID
site_count: 站点数量
target_name: 目标名称(用于错误日志)
并发执行目录扫描任务
Returns:
tuple: (total_directories, processed_sites, failed_sites)
"""
# 读取站点列表
sites: List[str] = []
with open(sites_file, 'r', encoding='utf-8') as f:
for line in f:
site_url = line.strip()
if site_url:
sites.append(site_url)
sites = [line.strip() for line in f if line.strip()]
if not sites:
logger.warning("站点列表为空")
return 0, 0, []
logger.info(
"准备并发扫描 %d 个站点,使用工具: %s",
len(sites), ', '.join(enabled_tools.keys())
)
total_directories = 0
processed_sites_count = 0
failed_sites: List[str] = []
# 遍历每个工具
for tool_name, tool_config in enabled_tools.items():
# 每个工具独立获取 max_workers 配置
max_workers = _get_max_workers(tool_config)
logger.info("="*60)
logger.info("=" * 60)
logger.info("使用工具: %s (并发模式, max_workers=%d)", tool_name, max_workers)
logger.info("="*60)
logger.info("=" * 60)
user_log(scan_id, "directory_scan", f"Running {tool_name}")
# 如果配置了 wordlist_name则先确保本地存在对应的字典文件含 hash 校验)
wordlist_name = tool_config.get('wordlist_name')
if wordlist_name:
try:
local_wordlist_path = ensure_wordlist_local(wordlist_name)
tool_config['wordlist'] = local_wordlist_path
except Exception as exc:
logger.error("为工具 %s 准备字典失败: %s", tool_name, exc)
# 当前工具无法执行,将所有站点视为失败,继续下一个工具
failed_sites.extend(sites)
continue
# 计算超时时间(所有站点共用)
# 准备字典文件
if not _prepare_tool_wordlist(tool_name, tool_config):
failed_sites.extend(sites)
continue
# 计算超时时间
site_timeout = tool_config.get('timeout', 300)
if site_timeout == 'auto':
site_timeout = calculate_directory_scan_timeout(tool_config)
logger.info(f"✓ 工具 {tool_name} 动态计算 timeout: {site_timeout}")
# 准备所有站点的扫描参数
scan_params_list = []
for idx, site_url in enumerate(sites, 1):
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='directory_scan',
command_params={'url': site_url},
tool_config=tool_config
)
log_file = _generate_log_filename(tool_name, site_url, directory_scan_dir)
scan_params_list.append({
'idx': idx,
'site_url': site_url,
'command': command,
'log_file': str(log_file),
'timeout': site_timeout
})
except Exception as e:
logger.error(
"✗ [%d/%d] 构建 %s 命令失败: %s - 站点: %s",
idx, len(sites), tool_name, e, site_url
)
failed_sites.append(site_url)
logger.info("✓ 工具 %s 动态计算 timeout: %d", tool_name, site_timeout)
# 构建扫描参数
scan_params_list, build_failed = _build_scan_params(
tool_name, tool_config, sites, directory_scan_dir, site_timeout
)
failed_sites.extend(build_failed)
if not scan_params_list:
logger.warning("没有有效的扫描任务")
continue
# ============================================================
# 分批执行策略:控制实际并发的 ffuf 进程数
# ============================================================
# 分批执行
total_tasks = len(scan_params_list)
logger.info("开始分批执行 %d 个扫描任务(每批 %d 个)...", total_tasks, max_workers)
# 进度里程碑跟踪
last_progress_percent = 0
tool_directories = 0
tool_processed = 0
batch_num = 0
for batch_start in range(0, total_tasks, max_workers):
batch_end = min(batch_start + max_workers, total_tasks)
batch_params = scan_params_list[batch_start:batch_end]
batch_num += 1
logger.info("执行第 %d 批任务(%d-%d/%d...", batch_num, batch_start + 1, batch_end, total_tasks)
# 提交当前批次的任务(非阻塞,立即返回 future
futures = []
for params in batch_params:
future = run_and_stream_save_directories_task.submit(
cmd=params['command'],
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
site_url=params['site_url'],
cwd=str(directory_scan_dir),
shell=True,
batch_size=1000,
timeout=params['timeout'],
log_file=params['log_file']
)
futures.append((params['idx'], params['site_url'], future))
# 等待当前批次所有任务完成(阻塞,确保本批完成后再启动下一批)
for idx, site_url, future in futures:
try:
result = future.result() # 阻塞等待单个任务完成
directories_found = result.get('created_directories', 0)
total_directories += directories_found
tool_directories += directories_found
processed_sites_count += 1
tool_processed += 1
logger.info(
"✓ [%d/%d] 站点扫描完成: %s - 发现 %d 个目录",
idx, len(sites), site_url, directories_found
)
except Exception as exc:
failed_sites.append(site_url)
if 'timeout' in str(exc).lower() or isinstance(exc, subprocess.TimeoutExpired):
logger.warning(
"⚠️ [%d/%d] 站点扫描超时: %s - 错误: %s",
idx, len(sites), site_url, exc
)
else:
logger.error(
"✗ [%d/%d] 站点扫描失败: %s - 错误: %s",
idx, len(sites), site_url, exc
)
batch_num = batch_start // max_workers + 1
logger.info(
"执行第 %d 批任务(%d-%d/%d...",
batch_num, batch_start + 1, batch_end, total_tasks
)
dirs_found, batch_failed = _execute_batch(
batch_params, tool_name, scan_id, target_id,
directory_scan_dir, len(sites)
)
total_directories += dirs_found
tool_directories += dirs_found
tool_processed += len(batch_params) - len(batch_failed)
processed_sites_count += len(batch_params) - len(batch_failed)
failed_sites.extend(batch_failed)
# 进度里程碑:每 20% 输出一次
current_progress = int((batch_end / total_tasks) * 100)
if current_progress >= last_progress_percent + 20:
user_log(scan_id, "directory_scan", f"Progress: {batch_end}/{total_tasks} sites scanned")
user_log(
scan_id, "directory_scan",
f"Progress: {batch_end}/{total_tasks} sites scanned"
)
last_progress_percent = (current_progress // 20) * 20
# 工具完成日志(开发者日志 + 用户日志)
logger.info(
"✓ 工具 %s 执行完成 - 已处理站点: %d/%d, 发现目录: %d",
tool_name, tool_processed, total_tasks, tool_directories
)
user_log(scan_id, "directory_scan", f"{tool_name} completed: found {tool_directories} directories")
# 输出汇总信息
if failed_sites:
logger.warning(
"部分站点扫描失败: %d/%d",
len(failed_sites), len(sites)
user_log(
scan_id, "directory_scan",
f"{tool_name} completed: found {tool_directories} directories"
)
if failed_sites:
logger.warning("部分站点扫描失败: %d/%d", len(failed_sites), len(sites))
logger.info(
"✓ 并发目录扫描执行完成 - 成功: %d/%d, 失败: %d, 总目录数: %d",
processed_sites_count, len(sites), len(failed_sites), total_directories
)
return total_directories, processed_sites_count, failed_sites
@flow(
name="directory_scan",
name="directory_scan",
log_prints=True,
on_running=[on_scan_flow_running],
on_completion=[on_scan_flow_completed],
@@ -570,64 +396,31 @@ def directory_scan_flow(
) -> dict:
"""
目录扫描 Flow
主要功能:
1. 从 target 获取所有站点的 URL
2. 对每个站点 URL 执行目录扫描(支持 ffuf 等工具)
3. 流式保存扫描结果到数据库 Directory 表
工作流程:
Step 0: 创建工作目录
Step 1: 导出站点 URL 列表到文件(供扫描工具使用)
Step 2: 验证工具配置
Step 3: 并发执行扫描工具并实时保存结果(使用 ThreadPoolTaskRunner
ffuf 输出字段:
- url: 发现的目录/文件 URL
- length: 响应内容长度
- status: HTTP 状态码
- words: 响应内容单词数
- lines: 响应内容行数
- content_type: 内容类型
- duration: 请求耗时
Args:
scan_id: 扫描任务 ID
target_name: 目标名称
target_id: 目标 ID
scan_workspace_dir: 扫描工作空间目录
enabled_tools: 启用的工具配置字典
Returns:
dict: {
'success': bool,
'scan_id': int,
'target': str,
'scan_workspace_dir': str,
'sites_file': str,
'site_count': int,
'total_directories': int, # 发现的总目录数
'processed_sites': int, # 成功处理的站点数
'failed_sites_count': int, # 失败的站点数
'executed_tasks': list
}
Raises:
ValueError: 参数错误
RuntimeError: 执行失败
dict: 扫描结果
"""
try:
wait_for_system_load(context="directory_scan_flow")
logger.info(
"="*60 + "\n" +
"开始目录扫描\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
"开始目录扫描 - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
)
user_log(scan_id, "directory_scan", "Starting directory scan")
# 参数验证
if scan_id is None:
raise ValueError("scan_id 不能为空")
@@ -639,14 +432,14 @@ def directory_scan_flow(
raise ValueError("scan_workspace_dir 不能为空")
if not enabled_tools:
raise ValueError("enabled_tools 不能为空")
# Step 0: 创建工作目录
from apps.scan.utils import setup_scan_directory
directory_scan_dir = setup_scan_directory(scan_workspace_dir, 'directory_scan')
# Step 1: 导出站点 URL(支持懒加载)
sites_file, site_count = _export_site_urls(target_id, target_name, directory_scan_dir)
# Step 1: 导出站点 URL
sites_file, site_count = _export_site_urls(target_id, directory_scan_dir)
if site_count == 0:
logger.warning("跳过目录扫描:没有站点可扫描 - Scan ID: %s", scan_id)
user_log(scan_id, "directory_scan", "Skipped: no sites to scan", "warning")
@@ -662,16 +455,16 @@ def directory_scan_flow(
'failed_sites_count': 0,
'executed_tasks': ['export_sites']
}
# Step 2: 工具配置信息
logger.info("Step 2: 工具配置信息")
tool_info = []
for tool_name, tool_config in enabled_tools.items():
mw = _get_max_workers(tool_config)
tool_info.append(f"{tool_name}(max_workers={mw})")
tool_info = [
f"{name}(max_workers={_get_max_workers(cfg)})"
for name, cfg in enabled_tools.items()
]
logger.info("✓ 启用工具: %s", ', '.join(tool_info))
# Step 3: 并发执行扫描工具并实时保存结果
# Step 3: 并发执行扫描
logger.info("Step 3: 并发执行扫描工具并实时保存结果")
total_directories, processed_sites, failed_sites = _run_scans_concurrently(
enabled_tools=enabled_tools,
@@ -679,19 +472,20 @@ def directory_scan_flow(
directory_scan_dir=directory_scan_dir,
scan_id=scan_id,
target_id=target_id,
site_count=site_count,
target_name=target_name
)
# 检查是否所有站点都失败
if processed_sites == 0 and site_count > 0:
logger.warning("所有站点扫描均失败 - 总站点数: %d, 失败数: %d", site_count, len(failed_sites))
# 不抛出异常,让扫描继续
# 记录 Flow 完成
logger.warning(
"所有站点扫描均失败 - 总站点数: %d, 失败数: %d",
site_count, len(failed_sites)
)
logger.info("✓ 目录扫描完成 - 发现目录: %d", total_directories)
user_log(scan_id, "directory_scan", f"directory_scan completed: found {total_directories} directories")
user_log(
scan_id, "directory_scan",
f"directory_scan completed: found {total_directories} directories"
)
return {
'success': True,
'scan_id': scan_id,
@@ -704,7 +498,7 @@ def directory_scan_flow(
'failed_sites_count': len(failed_sites),
'executed_tasks': ['export_sites', 'run_and_stream_save_directories']
}
except Exception as e:
logger.exception("目录扫描失败: %s", e)
raise
raise
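A worked example of the timeout rule changed in this file (timeout = wordlist line count × base_per_word, 60-second floor, upper cap removed); the helper below is a standalone restatement, not the project function:

```python
def directory_scan_timeout(line_count: int, base_per_word: float = 1.0, min_timeout: int = 60) -> int:
    # timeout = wordlist lines × seconds per word, clamped only from below
    return max(min_timeout, int(line_count * base_per_word))

assert directory_scan_timeout(30) == 60          # small wordlist: floor applies
assert directory_scan_timeout(1_000) == 1_000    # 1000 lines × 1.0 s/word
assert directory_scan_timeout(10_000) == 10_000  # no longer capped at 7200 s
```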

View File

@@ -10,26 +10,22 @@
- 流式处理输出,批量更新数据库
"""
# Django 环境初始化(导入即生效)
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging
import os
from datetime import datetime
from pathlib import Path
from prefect import flow
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
on_scan_flow_running,
)
from apps.scan.tasks.fingerprint_detect import (
export_urls_for_fingerprint_task,
run_xingfinger_and_stream_update_tech_task,
)
from apps.scan.utils import build_scan_command, user_log
from apps.scan.utils import build_scan_command, user_log, wait_for_system_load
from apps.scan.utils.fingerprint_helpers import get_fingerprint_paths
logger = logging.getLogger(__name__)
@@ -42,22 +38,19 @@ def calculate_fingerprint_detect_timeout(
) -> int:
"""
根据 URL 数量计算超时时间
公式:超时时间 = URL 数量 × 每 URL 基础时间
最小值300秒
无上限
最小值300秒,无上限
Args:
url_count: URL 数量
base_per_url: 每 URL 基础时间(秒),默认 10秒
min_timeout: 最小超时时间(秒),默认 300秒
Returns:
int: 计算出的超时时间(秒)
"""
timeout = int(url_count * base_per_url)
return max(min_timeout, timeout)
return max(min_timeout, int(url_count * base_per_url))
@@ -70,17 +63,17 @@ def _export_urls(
) -> tuple[str, int]:
"""
导出 URL 到文件
Args:
target_id: 目标 ID
fingerprint_dir: 指纹识别目录
source: 数据源类型
Returns:
tuple: (urls_file, total_count)
"""
logger.info("Step 1: 导出 URL 列表 (source=%s)", source)
urls_file = str(fingerprint_dir / 'urls.txt')
export_result = export_urls_for_fingerprint_task(
target_id=target_id,
@@ -88,15 +81,14 @@ def _export_urls(
source=source,
batch_size=1000
)
total_count = export_result['total_count']
logger.info(
"✓ URL 导出完成 - 文件: %s, 数量: %d",
export_result['output_file'],
total_count
)
return export_result['output_file'], total_count
@@ -111,7 +103,7 @@ def _run_fingerprint_detect(
) -> tuple[dict, list]:
"""
执行指纹识别任务
Args:
enabled_tools: 已启用的工具配置字典
urls_file: URL 文件路径
@@ -120,56 +112,54 @@ def _run_fingerprint_detect(
scan_id: 扫描任务 ID
target_id: 目标 ID
source: 数据源类型
Returns:
tuple: (tool_stats, failed_tools)
"""
tool_stats = {}
failed_tools = []
for tool_name, tool_config in enabled_tools.items():
# 1. 获取指纹库路径
lib_names = tool_config.get('fingerprint_libs', ['ehole'])
fingerprint_paths = get_fingerprint_paths(lib_names)
if not fingerprint_paths:
reason = f"没有可用的指纹库: {lib_names}"
logger.warning(reason)
failed_tools.append({'tool': tool_name, 'reason': reason})
continue
# 2. 将指纹库路径合并到 tool_config用于命令构建
tool_config_with_paths = {**tool_config, **fingerprint_paths}
# 3. 构建命令
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='fingerprint_detect',
command_params={
'urls_file': urls_file
},
command_params={'urls_file': urls_file},
tool_config=tool_config_with_paths
)
except Exception as e:
reason = f"命令构建失败: {str(e)}"
reason = f"命令构建失败: {e}"
logger.error("构建 %s 命令失败: %s", tool_name, e)
failed_tools.append({'tool': tool_name, 'reason': reason})
continue
# 4. 计算超时时间
timeout = calculate_fingerprint_detect_timeout(url_count)
# 5. 生成日志文件路径
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = fingerprint_dir / f"{tool_name}_{timestamp}.log"
logger.info(
"开始执行 %s 指纹识别 - URL数: %d, 超时: %ds, 指纹库: %s",
tool_name, url_count, timeout, list(fingerprint_paths.keys())
)
user_log(scan_id, "fingerprint_detect", f"Running {tool_name}: {command}")
# 6. 执行扫描任务
try:
result = run_xingfinger_and_stream_update_tech_task(
@@ -183,14 +173,14 @@ def _run_fingerprint_detect(
log_file=str(log_file),
batch_size=100
)
tool_stats[tool_name] = {
'command': command,
'result': result,
'timeout': timeout,
'fingerprint_libs': list(fingerprint_paths.keys())
}
tool_updated = result.get('updated_count', 0)
logger.info(
"✓ 工具 %s 执行完成 - 处理记录: %d, 更新: %d, 未找到: %d",
@@ -199,20 +189,23 @@ def _run_fingerprint_detect(
tool_updated,
result.get('not_found_count', 0)
)
user_log(scan_id, "fingerprint_detect", f"{tool_name} completed: identified {tool_updated} fingerprints")
user_log(
scan_id, "fingerprint_detect",
f"{tool_name} completed: identified {tool_updated} fingerprints"
)
except Exception as exc:
reason = str(exc)
failed_tools.append({'tool': tool_name, 'reason': reason})
logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True)
user_log(scan_id, "fingerprint_detect", f"{tool_name} failed: {reason}", "error")
if failed_tools:
logger.warning(
"以下指纹识别工具执行失败: %s",
', '.join([f['tool'] for f in failed_tools])
)
return tool_stats, failed_tools
@@ -232,53 +225,38 @@ def fingerprint_detect_flow(
) -> dict:
"""
指纹识别 Flow
主要功能:
1. 从数据库导出目标下所有 WebSite URL 到文件
2. 使用 xingfinger 进行技术栈识别
3. 解析结果并更新 WebSite.tech 字段(合并去重)
工作流程:
Step 0: 创建工作目录
Step 1: 导出 URL 列表
Step 2: 解析配置,获取启用的工具
Step 3: 执行 xingfinger 并解析结果
Args:
scan_id: 扫描任务 ID
target_name: 目标名称
target_id: 目标 ID
scan_workspace_dir: 扫描工作空间目录
enabled_tools: 启用的工具配置xingfinger
Returns:
dict: {
'success': bool,
'scan_id': int,
'target': str,
'scan_workspace_dir': str,
'urls_file': str,
'url_count': int,
'processed_records': int,
'updated_count': int,
'created_count': int,
'snapshot_count': int,
'executed_tasks': list,
'tool_stats': dict
}
dict: 扫描结果
"""
try:
# 负载检查:等待系统资源充足
wait_for_system_load(context="fingerprint_detect_flow")
logger.info(
"="*60 + "\n" +
"开始指纹识别\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
"开始指纹识别 - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
)
user_log(scan_id, "fingerprint_detect", "Starting fingerprint detection")
# 参数验证
if scan_id is None:
raise ValueError("scan_id 不能为空")
@@ -288,46 +266,26 @@ def fingerprint_detect_flow(
raise ValueError("target_id 不能为空")
if not scan_workspace_dir:
raise ValueError("scan_workspace_dir 不能为空")
# 数据源类型(当前只支持 website
source = 'website'
# Step 0: 创建工作目录
from apps.scan.utils import setup_scan_directory
fingerprint_dir = setup_scan_directory(scan_workspace_dir, 'fingerprint_detect')
# Step 1: 导出 URL支持懒加载
urls_file, url_count = _export_urls(target_id, fingerprint_dir, source)
if url_count == 0:
logger.warning("跳过指纹识别:没有 URL 可扫描 - Scan ID: %s", scan_id)
user_log(scan_id, "fingerprint_detect", "Skipped: no URLs to scan", "warning")
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'scan_workspace_dir': scan_workspace_dir,
'urls_file': urls_file,
'url_count': 0,
'processed_records': 0,
'updated_count': 0,
'created_count': 0,
'snapshot_count': 0,
'executed_tasks': ['export_urls_for_fingerprint'],
'tool_stats': {
'total': 0,
'successful': 0,
'failed': 0,
'successful_tools': [],
'failed_tools': [],
'details': {}
}
}
return _build_empty_result(scan_id, target_name, scan_workspace_dir, urls_file)
# Step 2: 工具配置信息
logger.info("Step 2: 工具配置信息")
logger.info("✓ 启用工具: %s", ', '.join(enabled_tools.keys()))
# Step 3: 执行指纹识别
logger.info("Step 3: 执行指纹识别")
tool_stats, failed_tools = _run_fingerprint_detect(
@@ -339,24 +297,37 @@ def fingerprint_detect_flow(
target_id=target_id,
source=source
)
# 动态生成已执行的任务列表
executed_tasks = ['export_urls_for_fingerprint']
executed_tasks.extend([f'run_xingfinger ({tool})' for tool in tool_stats.keys()])
executed_tasks.extend([f'run_xingfinger ({tool})' for tool in tool_stats])
# 汇总所有工具的结果
total_processed = sum(stats['result'].get('processed_records', 0) for stats in tool_stats.values())
total_updated = sum(stats['result'].get('updated_count', 0) for stats in tool_stats.values())
total_created = sum(stats['result'].get('created_count', 0) for stats in tool_stats.values())
total_snapshots = sum(stats['result'].get('snapshot_count', 0) for stats in tool_stats.values())
total_processed = sum(
stats['result'].get('processed_records', 0) for stats in tool_stats.values()
)
total_updated = sum(
stats['result'].get('updated_count', 0) for stats in tool_stats.values()
)
total_created = sum(
stats['result'].get('created_count', 0) for stats in tool_stats.values()
)
total_snapshots = sum(
stats['result'].get('snapshot_count', 0) for stats in tool_stats.values()
)
# 记录 Flow 完成
logger.info("✓ 指纹识别完成 - 识别指纹: %d", total_updated)
user_log(scan_id, "fingerprint_detect", f"fingerprint_detect completed: identified {total_updated} fingerprints")
successful_tools = [name for name in enabled_tools.keys()
if name not in [f['tool'] for f in failed_tools]]
user_log(
scan_id, "fingerprint_detect",
f"fingerprint_detect completed: identified {total_updated} fingerprints"
)
successful_tools = [
name for name in enabled_tools
if name not in [f['tool'] for f in failed_tools]
]
return {
'success': True,
'scan_id': scan_id,
@@ -378,7 +349,7 @@ def fingerprint_detect_flow(
'details': tool_stats
}
}
except ValueError as e:
logger.error("配置错误: %s", e)
raise
@@ -388,3 +359,33 @@ def fingerprint_detect_flow(
except Exception as e:
logger.exception("指纹识别失败: %s", e)
raise
def _build_empty_result(
scan_id: int,
target_name: str,
scan_workspace_dir: str,
urls_file: str
) -> dict:
"""构建空结果(无 URL 可扫描时)"""
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'scan_workspace_dir': scan_workspace_dir,
'urls_file': urls_file,
'url_count': 0,
'processed_records': 0,
'updated_count': 0,
'created_count': 0,
'snapshot_count': 0,
'executed_tasks': ['export_urls_for_fingerprint'],
'tool_stats': {
'total': 0,
'successful': 0,
'failed': 0,
'successful_tools': [],
'failed_tools': [],
'details': {}
}
}
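For reference, the fingerprint timeout rule in this file (URL count × base_per_url, 300-second floor, no cap) works out as in this standalone restatement:

```python
def fingerprint_detect_timeout(url_count: int, base_per_url: int = 10, min_timeout: int = 300) -> int:
    return max(min_timeout, int(url_count * base_per_url))

assert fingerprint_detect_timeout(10) == 300    # floor applies
assert fingerprint_detect_timeout(500) == 5000  # 500 URLs × 10 s each
```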

View File

@@ -1,4 +1,4 @@
"""
"""
端口扫描 Flow
负责编排端口扫描的完整流程
@@ -10,25 +10,23 @@
- 配置由 YAML 解析
"""
# Django 环境初始化(导入即生效)
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging
import os
import subprocess
from datetime import datetime
from pathlib import Path
from typing import Callable
from prefect import flow
from apps.scan.tasks.port_scan import (
export_hosts_task,
run_and_stream_save_ports_task
)
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
on_scan_flow_running,
)
from apps.scan.utils import config_parser, build_scan_command, user_log
from apps.scan.tasks.port_scan import (
export_hosts_task,
run_and_stream_save_ports_task,
)
from apps.scan.utils import build_scan_command, user_log, wait_for_system_load
logger = logging.getLogger(__name__)
@@ -40,28 +38,19 @@ def calculate_port_scan_timeout(
) -> int:
"""
根据目标数量和端口数量计算超时时间
计算公式:超时时间 = 目标数 × 端口数 × base_per_pair
超时范围:60秒 ~ 2天(172800秒)
超时范围:60秒 ~ 无上限
Args:
tool_config: 工具配置字典包含端口配置ports, top-ports等
file_path: 目标文件路径(域名/IP列表
base_per_pair: 每个"端口-目标对"的基础时间(秒),默认 0.5秒
Returns:
int: 计算出的超时时间(秒),范围60 ~ 172800
Example:
# 100个目标 × 100个端口 × 0.5秒 = 5000秒
# 10个目标 × 1000个端口 × 0.5秒 = 5000秒
timeout = calculate_port_scan_timeout(
tool_config={'top-ports': 100},
file_path='/path/to/domains.txt'
)
int: 计算出的超时时间(秒),最小 60 秒
"""
try:
# 1. 统计目标数量
result = subprocess.run(
['wc', '-l', file_path],
capture_output=True,
@@ -69,88 +58,74 @@ def calculate_port_scan_timeout(
check=True
)
target_count = int(result.stdout.strip().split()[0])
# 2. 解析端口数量
port_count = _parse_port_count(tool_config)
# 3. 计算超时时间
# 总工作量 = 目标数 × 端口数
total_work = target_count * port_count
timeout = int(total_work * base_per_pair)
# 4. 设置合理的下限(不再设置上限)
min_timeout = 60 # 最小 60 秒
timeout = max(min_timeout, timeout)
timeout = max(60, int(total_work * base_per_pair))
logger.info(
f"计算端口扫描 timeout - "
f"目标数: {target_count}, "
f"端口数: {port_count}, "
f"总工作量: {total_work}, "
f"超时: {timeout}"
"计算端口扫描 timeout - 目标数: %d, 端口数: %d, 总工作量: %d, 超时: %d",
target_count, port_count, total_work, timeout
)
return timeout
except Exception as e:
logger.warning(f"计算 timeout 失败: {e},使用默认值 600秒")
logger.warning("计算 timeout 失败: %s,使用默认值 600秒", e)
return 600
def _parse_port_count(tool_config: dict) -> int:
"""
从工具配置中解析端口数量
优先级:
1. top-ports: N → 返回 N
2. ports: "80,443,8080" → 返回逗号分隔的数量
3. ports: "1-1000" → 返回范围的大小
4. ports: "1-65535" → 返回 65535
5. 默认 → 返回 100naabu 默认扫描 top 100
Args:
tool_config: 工具配置字典
Returns:
int: 端口数量
"""
# 1. 检查 top-ports 配置
# 检查 top-ports 配置
if 'top-ports' in tool_config:
top_ports = tool_config['top-ports']
if isinstance(top_ports, int) and top_ports > 0:
return top_ports
logger.warning(f"top-ports 配置无效: {top_ports},使用默认值")
# 2. 检查 ports 配置
logger.warning("top-ports 配置无效: %s,使用默认值", top_ports)
# 检查 ports 配置
if 'ports' in tool_config:
ports_str = str(tool_config['ports']).strip()
# 2.1 逗号分隔的端口列表80,443,8080
# 逗号分隔的端口列表80,443,8080
if ',' in ports_str:
port_list = [p.strip() for p in ports_str.split(',') if p.strip()]
return len(port_list)
# 2.2 端口范围1-1000
return len([p.strip() for p in ports_str.split(',') if p.strip()])
# 端口范围1-1000
if '-' in ports_str:
try:
start, end = ports_str.split('-', 1)
start_port = int(start.strip())
end_port = int(end.strip())
if 1 <= start_port <= end_port <= 65535:
return end_port - start_port + 1
logger.warning(f"端口范围无效: {ports_str},使用默认值")
logger.warning("端口范围无效: %s,使用默认值", ports_str)
except ValueError:
logger.warning(f"端口范围解析失败: {ports_str},使用默认值")
# 2.3 单个端口
logger.warning("端口范围解析失败: %s,使用默认值", ports_str)
# 单个端口
try:
port = int(ports_str)
if 1 <= port <= 65535:
return 1
except ValueError:
logger.warning(f"端口配置解析失败: {ports_str},使用默认值")
# 3. 默认值naabu 默认扫描 top 100 端口
logger.warning("端口配置解析失败: %s,使用默认值", ports_str)
# 默认值naabu 默认扫描 top 100 端口
return 100
@@ -160,41 +135,38 @@ def _parse_port_count(tool_config: dict) -> int:
def _export_hosts(target_id: int, port_scan_dir: Path) -> tuple[str, int, str]:
"""
导出主机列表到文件
根据 Target 类型自动决定导出内容:
- DOMAIN: 从 Subdomain 表导出子域名
- IP: 直接写入 target.name
- CIDR: 展开 CIDR 范围内的所有 IP
Args:
target_id: 目标 ID
port_scan_dir: 端口扫描目录
Returns:
tuple: (hosts_file, host_count, target_type)
"""
logger.info("Step 1: 导出主机列表")
hosts_file = str(port_scan_dir / 'hosts.txt')
export_result = export_hosts_task(
target_id=target_id,
output_file=hosts_file,
batch_size=1000 # 每次读取 1000 条,优化内存占用
)
host_count = export_result['total_count']
target_type = export_result.get('target_type', 'unknown')
logger.info(
"✓ 主机列表导出完成 - 类型: %s, 文件: %s, 数量: %d",
target_type,
export_result['output_file'],
host_count
target_type, export_result['output_file'], host_count
)
if host_count == 0:
logger.warning("目标下没有可扫描的主机,无法执行端口扫描")
return export_result['output_file'], host_count, target_type
@@ -208,7 +180,7 @@ def _run_scans_sequentially(
) -> tuple[dict, int, list, list]:
"""
串行执行端口扫描任务
Args:
enabled_tools: 已启用的工具配置字典
domains_file: 域名文件路径
@@ -216,72 +188,56 @@ def _run_scans_sequentially(
scan_id: 扫描任务 ID
target_id: 目标 ID
target_name: 目标名称(用于错误日志)
Returns:
tuple: (tool_stats, processed_records, successful_tool_names, failed_tools)
注意:端口扫描是流式输出,不生成结果文件
Raises:
RuntimeError: 所有工具均失败
"""
# ==================== 构建命令并串行执行 ====================
tool_stats = {}
processed_records = 0
failed_tools = [] # 记录失败的工具(含原因)
# for循环执行工具按顺序串行运行每个启用的端口扫描工具
failed_tools = []
for tool_name, tool_config in enabled_tools.items():
# 1. 构建完整命令(变量替换)
# 构建命令
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='port_scan',
command_params={
'domains_file': domains_file # 对应 {domains_file}
},
tool_config=tool_config #yaml的工具配置
command_params={'domains_file': domains_file},
tool_config=tool_config
)
except Exception as e:
reason = f"命令构建失败: {str(e)}"
logger.error(f"构建 {tool_name} 命令失败: {e}")
reason = f"命令构建失败: {e}"
logger.error("构建 %s 命令失败: %s", tool_name, e)
failed_tools.append({'tool': tool_name, 'reason': reason})
continue
# 2. 获取超时时间(支持 'auto' 动态计算)
# 获取超时时间
config_timeout = tool_config['timeout']
if config_timeout == 'auto':
# 动态计算超时时间
config_timeout = calculate_port_scan_timeout(
tool_config=tool_config,
file_path=str(domains_file)
)
logger.info(f"✓ 工具 {tool_name} 动态计算 timeout: {config_timeout}")
# 2.1 生成日志文件路径
from datetime import datetime
config_timeout = calculate_port_scan_timeout(tool_config, str(domains_file))
logger.info("✓ 工具 %s 动态计算 timeout: %d", tool_name, config_timeout)
# 生成日志文件路径
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = port_scan_dir / f"{tool_name}_{timestamp}.log"
# 3. 执行扫描任务
logger.info("开始执行 %s 扫描(超时: %d秒)...", tool_name, config_timeout)
user_log(scan_id, "port_scan", f"Running {tool_name}: {command}")
# 执行扫描任务
try:
# 直接调用 task串行执行
# 注意:端口扫描是流式输出到 stdout不使用 output_file
result = run_and_stream_save_ports_task(
cmd=command,
tool_name=tool_name, # 工具名称
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
cwd=str(port_scan_dir),
shell=True,
batch_size=1000,
timeout=config_timeout,
log_file=str(log_file) # 新增:日志文件路径
log_file=str(log_file)
)
tool_stats[tool_name] = {
'command': command,
'result': result,
@@ -289,15 +245,10 @@ def _run_scans_sequentially(
}
tool_records = result.get('processed_records', 0)
processed_records += tool_records
logger.info(
"✓ 工具 %s 流式处理完成 - 记录数: %d",
tool_name, tool_records
)
logger.info("✓ 工具 %s 流式处理完成 - 记录数: %d", tool_name, tool_records)
user_log(scan_id, "port_scan", f"{tool_name} completed: found {tool_records} ports")
except subprocess.TimeoutExpired as exc:
# 超时异常单独处理
# 注意:流式处理任务超时时,已解析的数据已保存到数据库
except subprocess.TimeoutExpired:
reason = f"timeout after {config_timeout}s"
failed_tools.append({'tool': tool_name, 'reason': reason})
logger.warning(
@@ -307,40 +258,39 @@ def _run_scans_sequentially(
)
user_log(scan_id, "port_scan", f"{tool_name} failed: {reason}", "error")
except Exception as exc:
# 其他异常
reason = str(exc)
failed_tools.append({'tool': tool_name, 'reason': reason})
logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True)
user_log(scan_id, "port_scan", f"{tool_name} failed: {reason}", "error")
if failed_tools:
logger.warning(
"以下扫描工具执行失败: %s",
', '.join([f['tool'] for f in failed_tools])
)
if not tool_stats:
error_details = "; ".join([f"{f['tool']}: {f['reason']}" for f in failed_tools])
logger.warning("所有端口扫描工具均失败 - 目标: %s, 失败工具: %s", target_name, error_details)
# 返回空结果,不抛出异常,让扫描继续
return {}, 0, [], failed_tools
# 动态计算成功的工具列表
successful_tool_names = [name for name in enabled_tools.keys()
if name not in [f['tool'] for f in failed_tools]]
successful_tool_names = [
name for name in enabled_tools
if name not in [f['tool'] for f in failed_tools]
]
logger.info(
"✓ 串行端口扫描执行完成 - 成功: %d/%d (成功: %s, 失败: %s)",
len(tool_stats), len(enabled_tools),
', '.join(successful_tool_names) if successful_tool_names else '',
', '.join([f['tool'] for f in failed_tools]) if failed_tools else ''
)
return tool_stats, processed_records, successful_tool_names, failed_tools
@flow(
name="port_scan",
name="port_scan",
log_prints=True,
on_running=[on_scan_flow_running],
on_completion=[on_scan_flow_completed],
@@ -355,19 +305,19 @@ def port_scan_flow(
) -> dict:
"""
端口扫描 Flow
主要功能:
1. 扫描目标域名/IP 的开放端口
2. 保存 host + ip + port 三元映射到 HostPortMapping 表
输出资产:
- HostPortMapping主机端口映射host + ip + port 三元组)
工作流程:
Step 0: 创建工作目录
Step 1: 导出域名列表到文件(供扫描工具使用)
Step 2: 解析配置,获取启用的工具
Step 3: 串行执行扫描工具,运行端口扫描工具并实时解析输出到数据库(→ HostPortMapping
Step 3: 串行执行扫描工具,运行端口扫描工具并实时解析输出到数据库
Args:
scan_id: 扫描任务 ID
@@ -377,35 +327,15 @@ def port_scan_flow(
enabled_tools: 启用的工具配置字典
Returns:
dict: {
'success': bool,
'scan_id': int,
'target': str,
'scan_workspace_dir': str,
'hosts_file': str,
'host_count': int,
'processed_records': int,
'executed_tasks': list,
'tool_stats': {
'total': int, # 总工具数
'successful': int, # 成功工具数
'failed': int, # 失败工具数
'successful_tools': list[str], # 成功工具列表 ['naabu_active']
'failed_tools': list[dict], # 失败工具列表 [{'tool': 'naabu_passive', 'reason': '超时'}]
'details': dict # 详细执行结果(保留向后兼容)
}
}
dict: 扫描结果
Raises:
ValueError: 配置错误
RuntimeError: 执行失败
Note:
端口扫描工具(如 naabu会解析域名获取 IP输出 host + ip + port 三元组。
同一 host 可能对应多个 IPCDN、负载均衡因此使用三元映射表存储。
"""
try:
# 参数验证
wait_for_system_load(context="port_scan_flow")
if scan_id is None:
raise ValueError("scan_id 不能为空")
if not target_name:
@@ -416,25 +346,20 @@ def port_scan_flow(
raise ValueError("scan_workspace_dir 不能为空")
if not enabled_tools:
raise ValueError("enabled_tools 不能为空")
logger.info(
"="*60 + "\n" +
"开始端口扫描\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
"开始端口扫描 - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
)
user_log(scan_id, "port_scan", "Starting port scan")
# Step 0: 创建工作目录
from apps.scan.utils import setup_scan_directory
port_scan_dir = setup_scan_directory(scan_workspace_dir, 'port_scan')
# Step 1: 导出主机列表到文件(根据 Target 类型自动决定内容)
# Step 1: 导出主机列表
hosts_file, host_count, target_type = _export_hosts(target_id, port_scan_dir)
if host_count == 0:
logger.warning("跳过端口扫描:没有主机可扫描 - Scan ID: %s", scan_id)
user_log(scan_id, "port_scan", "Skipped: no hosts to scan", "warning")
@@ -457,14 +382,11 @@ def port_scan_flow(
'details': {}
}
}
# Step 2: 工具配置信息
logger.info("Step 2: 工具配置信息")
logger.info(
"✓ 启用工具: %s",
', '.join(enabled_tools.keys())
)
logger.info("✓ 启用工具: %s", ', '.join(enabled_tools.keys()))
# Step 3: 串行执行扫描工具
logger.info("Step 3: 串行执行扫描工具")
tool_stats, processed_records, successful_tool_names, failed_tools = _run_scans_sequentially(
@@ -475,15 +397,13 @@ def port_scan_flow(
target_id=target_id,
target_name=target_name
)
# 记录 Flow 完成
logger.info("✓ 端口扫描完成 - 发现端口: %d", processed_records)
user_log(scan_id, "port_scan", f"port_scan completed: found {processed_records} ports")
# 动态生成已执行的任务列表
executed_tasks = ['export_hosts', 'parse_config']
executed_tasks.extend([f'run_and_stream_save_ports ({tool})' for tool in tool_stats.keys()])
executed_tasks.extend([f'run_and_stream_save_ports ({tool})' for tool in tool_stats])
return {
'success': True,
'scan_id': scan_id,
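The port-scan timeout rule in this file (target count × port count × base_per_pair, 60-second floor, no cap) can be restated in standalone form; the port count comes from top-ports, a comma list, a range, a single port, or the naabu default of 100:

```python
def port_scan_timeout(target_count: int, port_count: int, base_per_pair: float = 0.5) -> int:
    return max(60, int(target_count * port_count * base_per_pair))

assert port_scan_timeout(100, 100) == 5000  # 100 targets × top-ports 100 × 0.5 s
assert port_scan_timeout(10, 1000) == 5000  # 10 targets × ports 1-1000
assert port_scan_timeout(1, 10) == 60       # floor applies
```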

View File

@@ -5,43 +5,39 @@
1. 从数据库获取 URL 列表websites 和/或 endpoints
2. 批量截图并保存快照
3. 同步到资产表
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库获取 URL
2. Provider 模式:使用 TargetProvider 从任意数据源获取 URL
"""
# Django 环境初始化
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging
from pathlib import Path
from typing import Optional
from prefect import flow
from apps.scan.tasks.screenshot import capture_screenshots_task
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
on_scan_flow_running,
)
from apps.scan.utils import user_log
from apps.scan.services.target_export_service import (
get_urls_with_fallback,
DataSource,
)
from apps.scan.providers import TargetProvider
from apps.scan.services.target_export_service import DataSource, get_urls_with_fallback
from apps.scan.tasks.screenshot import capture_screenshots_task
from apps.scan.utils import user_log, wait_for_system_load
logger = logging.getLogger(__name__)
# URL 来源到 DataSource 的映射
_SOURCE_MAPPING = {
'websites': DataSource.WEBSITE,
'endpoints': DataSource.ENDPOINT,
}
def _parse_screenshot_config(enabled_tools: dict) -> dict:
"""
解析截图配置
Args:
enabled_tools: 启用的工具配置
Returns:
截图配置字典
"""
# 从 enabled_tools 中获取 playwright 配置
"""解析截图配置"""
playwright_config = enabled_tools.get('playwright', {})
return {
'concurrency': playwright_config.get('concurrency', 5),
'url_sources': playwright_config.get('url_sources', ['websites'])
@@ -49,33 +45,54 @@ def _parse_screenshot_config(enabled_tools: dict) -> dict:
def _map_url_sources_to_data_sources(url_sources: list[str]) -> list[str]:
"""
将配置中的 url_sources 映射为 DataSource 常量
Args:
url_sources: 配置中的来源列表,如 ['websites', 'endpoints']
Returns:
DataSource 常量列表
"""
source_mapping = {
'websites': DataSource.WEBSITE,
'endpoints': DataSource.ENDPOINT,
}
"""将配置中的 url_sources 映射为 DataSource 常量"""
sources = []
for source in url_sources:
if source in source_mapping:
sources.append(source_mapping[source])
if source in _SOURCE_MAPPING:
sources.append(_SOURCE_MAPPING[source])
else:
logger.warning("未知的 URL 来源: %s,跳过", source)
# 添加默认回退(从 subdomain 构造)
sources.append(DataSource.DEFAULT)
return sources
def _collect_urls_from_provider(provider: TargetProvider) -> tuple[list[str], str, list[str]]:
"""从 Provider 收集 URL"""
logger.info("使用 Provider 模式获取 URL - Provider: %s", type(provider).__name__)
urls = list(provider.iter_urls())
blacklist_filter = provider.get_blacklist_filter()
if blacklist_filter:
urls = [url for url in urls if blacklist_filter.is_allowed(url)]
return urls, 'provider', ['provider']
def _collect_urls_from_database(
target_id: int,
url_sources: list[str]
) -> tuple[list[str], str, list[str]]:
"""从数据库收集 URL带黑名单过滤和回退"""
data_sources = _map_url_sources_to_data_sources(url_sources)
result = get_urls_with_fallback(target_id, sources=data_sources)
return result['urls'], result['source'], result['tried_sources']
def _build_empty_result(scan_id: int, target_name: str) -> dict:
"""构建空结果"""
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'total_urls': 0,
'successful': 0,
'failed': 0,
'synced': 0
}
@flow(
name="screenshot",
log_prints=True,
@@ -88,115 +105,104 @@ def screenshot_flow(
target_name: str,
target_id: int,
scan_workspace_dir: str,
enabled_tools: dict
enabled_tools: dict,
provider: Optional[TargetProvider] = None
) -> dict:
"""
截图 Flow
工作流程
Step 1: 解析配置
Step 2: 收集 URL 列表
Step 3: 批量截图并保存快照
Step 4: 同步到资产表
支持两种模式
1. 传统模式(向后兼容):使用 target_id 从数据库获取 URL
2. Provider 模式:使用 TargetProvider 从任意数据源获取 URL
Args:
scan_id: 扫描任务 ID
target_name: 目标名称
target_id: 目标 ID
scan_workspace_dir: 扫描工作空间目录
enabled_tools: 启用的工具配置
provider: TargetProvider 实例(新模式,可选)
Returns:
dict: {
'success': bool,
'scan_id': int,
'target': str,
'total_urls': int,
'successful': int,
'failed': int,
'synced': int
}
截图结果字典
"""
try:
# 负载检查:等待系统资源充足
wait_for_system_load(context="screenshot_flow")
mode = 'Provider' if provider else 'Legacy'
logger.info(
"="*60 + "\n" +
"开始截图扫描\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
"开始截图扫描 - Scan ID: %s, Target: %s, Mode: %s",
scan_id, target_name, mode
)
user_log(scan_id, "screenshot", "Starting screenshot capture")
# Step 1: 解析配置
config = _parse_screenshot_config(enabled_tools)
concurrency = config['concurrency']
url_sources = config['url_sources']
logger.info("截图配置 - 并发: %d, URL来源: %s", concurrency, url_sources)
# Step 2: 使用统一服务收集 URL带黑名单过滤和回退
data_sources = _map_url_sources_to_data_sources(url_sources)
result = get_urls_with_fallback(target_id, sources=data_sources)
urls = result['urls']
logger.info("截图配置 - 并发: %d, URL来源: %s", concurrency, config['url_sources'])
# Step 2: 收集 URL 列表
if provider is not None:
urls, source_info, tried_sources = _collect_urls_from_provider(provider)
else:
urls, source_info, tried_sources = _collect_urls_from_database(
target_id, config['url_sources']
)
logger.info(
"URL 收集完成 - 来源: %s, 数量: %d, 尝试过: %s",
result['source'], result['total_count'], result['tried_sources']
source_info, len(urls), tried_sources
)
if not urls:
logger.warning("没有可截图的 URL跳过截图任务")
user_log(scan_id, "screenshot", "Skipped: no URLs to capture", "warning")
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'total_urls': 0,
'successful': 0,
'failed': 0,
'synced': 0
}
user_log(scan_id, "screenshot", f"Found {len(urls)} URLs to capture (source: {result['source']})")
return _build_empty_result(scan_id, target_name)
user_log(
scan_id, "screenshot",
f"Found {len(urls)} URLs to capture (source: {source_info})"
)
# Step 3: 批量截图
logger.info("Step 3: 批量截图 - %d 个 URL", len(urls))
logger.info("批量截图 - %d 个 URL", len(urls))
capture_result = capture_screenshots_task(
urls=urls,
scan_id=scan_id,
target_id=target_id,
config={'concurrency': concurrency}
)
# Step 4: 同步到资产表
logger.info("Step 4: 同步截图到资产表")
logger.info("同步截图到资产表")
from apps.asset.services.screenshot_service import ScreenshotService
screenshot_service = ScreenshotService()
synced = screenshot_service.sync_screenshots_to_asset(scan_id, target_id)
synced = ScreenshotService().sync_screenshots_to_asset(scan_id, target_id)
total = capture_result['total']
successful = capture_result['successful']
failed = capture_result['failed']
logger.info(
"✓ 截图完成 - 总数: %d, 成功: %d, 失败: %d, 同步: %d",
capture_result['total'], capture_result['successful'], capture_result['failed'], synced
total, successful, failed, synced
)
user_log(
scan_id, "screenshot",
f"Screenshot completed: {capture_result['successful']}/{capture_result['total']} captured, {synced} synced"
f"Screenshot completed: {successful}/{total} captured, {synced} synced"
)
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'total_urls': capture_result['total'],
'successful': capture_result['successful'],
'failed': capture_result['failed'],
'total_urls': total,
'successful': successful,
'failed': failed,
'synced': synced
}
except Exception as e:
logger.exception("截图 Flow 失败: %s", e)
user_log(scan_id, "screenshot", f"Screenshot failed: {e}", "error")
except Exception:
logger.exception("截图 Flow 失败")
user_log(scan_id, "screenshot", "Screenshot failed", "error")
raise
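A hedged usage sketch of the two modes described in this file (argument values are made up; the real ListTargetProvider constructor may differ from the earlier sketch):

```python
# Imports omitted: the module paths for screenshot_flow / ListTargetProvider are not shown in this diff.

# Legacy mode: URLs come from the database via target_id, with blacklist filtering and fallback.
screenshot_flow(
    scan_id=1, target_name="example.com", target_id=42,
    scan_workspace_dir="/tmp/scan_1",
    enabled_tools={"playwright": {"concurrency": 5, "url_sources": ["websites"]}},
)

# Provider mode: URLs come from any TargetProvider, e.g. an in-memory list.
screenshot_flow(
    scan_id=1, target_name="example.com", target_id=42,
    scan_workspace_dir="/tmp/scan_1",
    enabled_tools={"playwright": {}},
    provider=ListTargetProvider(["https://a.example.com", "https://b.example.com"]),
)
```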

View File

@@ -1,4 +1,3 @@
"""
站点扫描 Flow
@@ -11,303 +10,319 @@
- 配置由 YAML 解析
"""
# Django 环境初始化(导入即生效)
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging
import os
import subprocess
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Callable
from typing import Optional
from prefect import flow
from apps.scan.tasks.site_scan import export_site_urls_task, run_and_stream_save_websites_task
# Django 环境初始化(导入即生效)
from apps.common.prefect_django_setup import setup_django_for_prefect # noqa: F401
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
on_scan_flow_running,
)
from apps.scan.utils import config_parser, build_scan_command, user_log
from apps.scan.tasks.site_scan import (
export_site_urls_task,
run_and_stream_save_websites_task,
)
from apps.scan.utils import build_scan_command, user_log, wait_for_system_load
logger = logging.getLogger(__name__)
def calculate_timeout_by_line_count(
tool_config: dict,
file_path: str,
base_per_time: int = 1,
min_timeout: int = 60
) -> int:
"""
根据文件行数计算 timeout
使用 wc -l 统计文件行数,根据行数和每行基础时间计算 timeout
Args:
tool_config: 工具配置字典(此函数未使用,但保持接口一致性)
file_path: 要统计行数的文件路径
base_per_time: 每行的基础时间默认1秒
min_timeout: 最小超时时间默认60秒
Returns:
int: 计算出的超时时间(秒),不低于 min_timeout
Example:
timeout = calculate_timeout_by_line_count(
tool_config={},
file_path='/path/to/urls.txt',
base_per_time=2
)
"""
@dataclass
class ScanContext:
"""扫描上下文,封装扫描参数"""
scan_id: int
target_id: int
target_name: str
site_scan_dir: Path
urls_file: str
total_urls: int
def _count_file_lines(file_path: str) -> int:
"""使用 wc -l 统计文件行数"""
try:
# 使用 wc -l 快速统计行数
result = subprocess.run(
['wc', '-l', file_path],
capture_output=True,
text=True,
check=True
)
# wc -l 输出格式:行数 + 空格 + 文件名
line_count = int(result.stdout.strip().split()[0])
# 计算 timeout行数 × 每行基础时间,不低于最小值
timeout = max(line_count * base_per_time, min_timeout)
logger.info(
f"timeout 自动计算: 文件={file_path}, "
f"行数={line_count}, 每行时间={base_per_time}秒, 最小值={min_timeout}秒, timeout={timeout}"
)
return timeout
except Exception as e:
# 如果 wc -l 失败,使用默认值
logger.warning(f"wc -l 计算行数失败: {e},使用默认 timeout: {min_timeout}")
return min_timeout
return int(result.stdout.strip().split()[0])
except (subprocess.CalledProcessError, ValueError, IndexError) as e:
logger.warning("wc -l 计算行数失败: %s,返回 0", e)
return 0
def _calculate_timeout_by_line_count(
file_path: str,
base_per_time: int = 1,
min_timeout: int = 60
) -> int:
"""
根据文件行数计算 timeout
Args:
file_path: 要统计行数的文件路径
base_per_time: 每行的基础时间默认1秒
min_timeout: 最小超时时间默认60秒
Returns:
int: 计算出的超时时间(秒),不低于 min_timeout
"""
line_count = _count_file_lines(file_path)
timeout = max(line_count * base_per_time, min_timeout)
logger.info(
"timeout 自动计算: 文件=%s, 行数=%d, 每行时间=%d秒, timeout=%d",
file_path, line_count, base_per_time, timeout
)
return timeout
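Illustrative numbers for the formula above (values are made up, not from the source):
# timeout = max(line_count * base_per_time, min_timeout)
print(max(500 * 1, 60))  # 500-line URL file -> 500 seconds
print(max(20 * 1, 60))   # 20-line file -> 60 seconds (minimum applies)
print(max(0 * 1, 60))    # wc -l failed, _count_file_lines returned 0 -> 60 seconds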
def _export_site_urls(target_id: int, site_scan_dir: Path, target_name: str = None) -> tuple[str, int, int]:
def _export_site_urls(
target_id: int,
site_scan_dir: Path
) -> tuple[str, int, int]:
"""
导出站点 URL 到文件
Args:
target_id: 目标 ID
site_scan_dir: 站点扫描目录
target_name: 目标名称(用于懒加载时写入默认值)
Returns:
tuple: (urls_file, total_urls, association_count)
Raises:
ValueError: URL 数量为 0
"""
logger.info("Step 1: 导出站点URL列表")
urls_file = str(site_scan_dir / 'site_urls.txt')
export_result = export_site_urls_task(
target_id=target_id,
output_file=urls_file,
batch_size=1000 # 每次处理1000个子域名
batch_size=1000
)
total_urls = export_result['total_urls']
association_count = export_result['association_count'] # 主机端口关联数
association_count = export_result['association_count']
logger.info(
"✓ 站点URL导出完成 - 文件: %s, URL数量: %d, 关联数: %d",
export_result['output_file'],
total_urls,
association_count
export_result['output_file'], total_urls, association_count
)
if total_urls == 0:
logger.warning("目标下没有可用的站点URL无法执行站点扫描")
# 不抛出异常,由上层决定如何处理
# raise ValueError("目标下没有可用的站点URL无法执行站点扫描")
return export_result['output_file'], total_urls, association_count
def _get_tool_timeout(tool_config: dict, urls_file: str) -> int:
"""获取工具超时时间(支持 'auto' 动态计算)"""
config_timeout = tool_config.get('timeout', 300)
if config_timeout == 'auto':
return _calculate_timeout_by_line_count(urls_file, base_per_time=1)
dynamic_timeout = _calculate_timeout_by_line_count(urls_file, base_per_time=1)
return max(dynamic_timeout, config_timeout)
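With a numeric timeout configured, the max() above makes the configured value act as a floor, and a large URL file can only raise it. A quick illustration with made-up values:
config_timeout = 300
print(max(max(1000 * 1, 60), config_timeout))  # 1000 URLs -> 1000 seconds
print(max(max(50 * 1, 60), config_timeout))    # 50 URLs  -> 300 seconds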
def _execute_single_tool(
tool_name: str,
tool_config: dict,
ctx: ScanContext
) -> Optional[dict]:
"""
执行单个扫描工具
Returns:
成功返回结果字典,失败返回 None
"""
# 构建命令
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='site_scan',
command_params={'url_file': ctx.urls_file},
tool_config=tool_config
)
except (ValueError, KeyError) as e:
logger.error("构建 %s 命令失败: %s", tool_name, e)
return None
timeout = _get_tool_timeout(tool_config, ctx.urls_file)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = ctx.site_scan_dir / f"{tool_name}_{timestamp}.log"
logger.info(
"开始执行 %s 站点扫描 - URL数: %d, 超时: %ds",
tool_name, ctx.total_urls, timeout
)
user_log(ctx.scan_id, "site_scan", f"Running {tool_name}: {command}")
try:
result = run_and_stream_save_websites_task(
cmd=command,
tool_name=tool_name,
scan_id=ctx.scan_id,
target_id=ctx.target_id,
cwd=str(ctx.site_scan_dir),
shell=True,
timeout=timeout,
log_file=str(log_file)
)
tool_created = result.get('created_websites', 0)
skipped = result.get('skipped_no_subdomain', 0) + result.get('skipped_failed', 0)
logger.info(
"✓ 工具 %s 完成 - 处理: %d, 创建: %d, 跳过: %d",
tool_name, result.get('processed_records', 0), tool_created, skipped
)
user_log(
ctx.scan_id, "site_scan",
f"{tool_name} completed: found {tool_created} websites"
)
return {'command': command, 'result': result, 'timeout': timeout}
except subprocess.TimeoutExpired:
logger.warning(
"⚠️ 工具 %s 执行超时 - 超时配置: %d秒 (超时前数据已保存)",
tool_name, timeout
)
user_log(
ctx.scan_id, "site_scan",
f"{tool_name} failed: timeout after {timeout}s", "error"
)
except (OSError, RuntimeError) as exc:
logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True)
user_log(ctx.scan_id, "site_scan", f"{tool_name} failed: {exc}", "error")
return None
def _run_scans_sequentially(
enabled_tools: dict,
urls_file: str,
total_urls: int,
site_scan_dir: Path,
scan_id: int,
target_id: int,
target_name: str
ctx: ScanContext
) -> tuple[dict, int, list, list]:
"""
串行执行站点扫描任务
Args:
enabled_tools: 已启用的工具配置字典
urls_file: URL 文件路径
total_urls: URL 总数
site_scan_dir: 站点扫描目录
scan_id: 扫描任务 ID
target_id: 目标 ID
target_name: 目标名称(用于错误日志)
Returns:
tuple: (tool_stats, processed_records, successful_tool_names, failed_tools)
Raises:
RuntimeError: 所有工具均失败
tuple: (tool_stats, processed_records, successful_tools, failed_tools)
"""
tool_stats = {}
processed_records = 0
failed_tools = []
for tool_name, tool_config in enabled_tools.items():
# 1. 构建完整命令(变量替换)
try:
command_params = {'url_file': urls_file}
command = build_scan_command(
tool_name=tool_name,
scan_type='site_scan',
command_params=command_params,
tool_config=tool_config
)
except Exception as e:
reason = f"命令构建失败: {str(e)}"
logger.error(f"构建 {tool_name} 命令失败: {e}")
failed_tools.append({'tool': tool_name, 'reason': reason})
continue
# 2. 获取超时时间(支持 'auto' 动态计算)
config_timeout = tool_config.get('timeout', 300)
if config_timeout == 'auto':
# 动态计算超时时间
timeout = calculate_timeout_by_line_count(tool_config, urls_file, base_per_time=1)
logger.info(f"✓ 工具 {tool_name} 动态计算 timeout: {timeout}")
result = _execute_single_tool(tool_name, tool_config, ctx)
if result:
tool_stats[tool_name] = result
processed_records += result['result'].get('processed_records', 0)
else:
# 使用配置的超时时间和动态计算的较大值
dynamic_timeout = calculate_timeout_by_line_count(tool_config, urls_file, base_per_time=1)
timeout = max(dynamic_timeout, config_timeout)
# 2.1 生成日志文件路径(类似端口扫描)
from datetime import datetime
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = site_scan_dir / f"{tool_name}_{timestamp}.log"
logger.info(
"开始执行 %s 站点扫描 - URL数: %d, 最终超时: %ds",
tool_name, total_urls, timeout
)
user_log(scan_id, "site_scan", f"Running {tool_name}: {command}")
# 3. 执行扫描任务
try:
# 流式执行扫描并实时保存结果
result = run_and_stream_save_websites_task(
cmd=command,
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
cwd=str(site_scan_dir),
shell=True,
timeout=timeout,
log_file=str(log_file)
)
tool_stats[tool_name] = {
'command': command,
'result': result,
'timeout': timeout
}
tool_records = result.get('processed_records', 0)
tool_created = result.get('created_websites', 0)
processed_records += tool_records
logger.info(
"✓ 工具 %s 流式处理完成 - 处理记录: %d, 创建站点: %d, 跳过: %d",
tool_name,
tool_records,
tool_created,
result.get('skipped_no_subdomain', 0) + result.get('skipped_failed', 0)
)
user_log(scan_id, "site_scan", f"{tool_name} completed: found {tool_created} websites")
except subprocess.TimeoutExpired as exc:
# 超时异常单独处理
reason = f"timeout after {timeout}s"
failed_tools.append({'tool': tool_name, 'reason': reason})
logger.warning(
"⚠️ 工具 %s 执行超时 - 超时配置: %d\n"
"注意:超时前已解析的站点数据已保存到数据库,但扫描未完全完成。",
tool_name, timeout
)
user_log(scan_id, "site_scan", f"{tool_name} failed: {reason}", "error")
except Exception as exc:
# 其他异常
reason = str(exc)
failed_tools.append({'tool': tool_name, 'reason': reason})
logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True)
user_log(scan_id, "site_scan", f"{tool_name} failed: {reason}", "error")
failed_tools.append({'tool': tool_name, 'reason': '执行失败'})
if failed_tools:
logger.warning(
"以下扫描工具执行失败: %s",
', '.join([f['tool'] for f in failed_tools])
', '.join(f['tool'] for f in failed_tools)
)
if not tool_stats:
error_details = "; ".join([f"{f['tool']}: {f['reason']}" for f in failed_tools])
logger.warning("所有站点扫描工具均失败 - 目标: %s, 失败工具: %s", target_name, error_details)
# 返回空结果,不抛出异常,让扫描继续
logger.warning(
"所有站点扫描工具均失败 - 目标: %s", ctx.target_name
)
return {}, 0, [], failed_tools
# 动态计算成功的工具列表
successful_tool_names = [name for name in enabled_tools.keys()
if name not in [f['tool'] for f in failed_tools]]
successful_tools = [
name for name in enabled_tools
if name not in {f['tool'] for f in failed_tools}
]
logger.info(
"串行站点扫描执行完成 - 成功: %d/%d (成功: %s, 失败: %s)",
len(tool_stats), len(enabled_tools),
', '.join(successful_tool_names) if successful_tool_names else '',
', '.join([f['tool'] for f in failed_tools]) if failed_tools else ''
"✓ 站点扫描执行完成 - 成功: %d/%d",
len(tool_stats), len(enabled_tools)
)
return tool_stats, processed_records, successful_tool_names, failed_tools
return tool_stats, processed_records, successful_tools, failed_tools
def calculate_timeout(url_count: int, base: int = 600, per_url: int = 1) -> int:
"""
根据 URL 数量动态计算扫描超时时间
def _build_empty_result(
scan_id: int,
target_name: str,
scan_workspace_dir: str,
urls_file: str,
association_count: int
) -> dict:
"""构建空结果(无 URL 可扫描时)"""
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'scan_workspace_dir': scan_workspace_dir,
'urls_file': urls_file,
'total_urls': 0,
'association_count': association_count,
'processed_records': 0,
'created_websites': 0,
'skipped_no_subdomain': 0,
'skipped_failed': 0,
'executed_tasks': ['export_site_urls'],
'tool_stats': {
'total': 0,
'successful': 0,
'failed': 0,
'successful_tools': [],
'failed_tools': [],
'details': {}
}
}
规则:
- 基础时间:默认 600 秒10 分钟)
- 每个 URL 额外增加:默认 1 秒
Args:
url_count: URL 数量,必须为正整数
base: 基础超时时间(秒),默认 600
per_url: 每个 URL 增加的时间(秒),默认 1
def _aggregate_tool_results(tool_stats: dict) -> tuple[int, int, int]:
"""汇总工具结果"""
total_created = sum(
s['result'].get('created_websites', 0) for s in tool_stats.values()
)
total_skipped_no_subdomain = sum(
s['result'].get('skipped_no_subdomain', 0) for s in tool_stats.values()
)
total_skipped_failed = sum(
s['result'].get('skipped_failed', 0) for s in tool_stats.values()
)
return total_created, total_skipped_no_subdomain, total_skipped_failed
Returns:
int: 计算得到的超时时间(秒),不超过 max_timeout
Raises:
ValueError: 当 url_count 为负数或 0 时抛出异常
"""
if url_count < 0:
raise ValueError(f"URL数量不能为负数: {url_count}")
if url_count == 0:
raise ValueError("URL数量不能为0")
timeout = base + int(url_count * per_url)
# 不设置上限,由调用方根据需要控制
return timeout
def _validate_flow_params(
scan_id: int,
target_name: str,
target_id: int,
scan_workspace_dir: str
) -> None:
"""验证 Flow 参数"""
if scan_id is None:
raise ValueError("scan_id 不能为空")
if not target_name:
raise ValueError("target_name 不能为空")
if target_id is None:
raise ValueError("target_id 不能为空")
if not scan_workspace_dir:
raise ValueError("scan_workspace_dir 不能为空")
@flow(
name="site_scan",
name="site_scan",
log_prints=True,
on_running=[on_scan_flow_running],
on_completion=[on_scan_flow_completed],
@@ -322,140 +337,83 @@ def site_scan_flow(
) -> dict:
"""
站点扫描 Flow
主要功能:
1. 从target获取所有子域名与其对应的端口号拼接成URL写入文件
2. 用httpx进行批量请求并实时保存到数据库流式处理
工作流程:
Step 0: 创建工作目录
Step 1: 导出站点 URL 列表
Step 2: 解析配置,获取启用的工具
Step 3: 串行执行扫描工具并实时保存结果
Args:
scan_id: 扫描任务 ID
target_name: 目标名称
target_id: 目标 ID
scan_workspace_dir: 扫描工作空间目录
enabled_tools: 启用的工具配置字典
Returns:
dict: {
'success': bool,
'scan_id': int,
'target': str,
'scan_workspace_dir': str,
'urls_file': str,
'total_urls': int,
'association_count': int,
'processed_records': int,
'created_websites': int,
'skipped_no_subdomain': int,
'skipped_failed': int,
'executed_tasks': list,
'tool_stats': {
'total': int,
'successful': int,
'failed': int,
'successful_tools': list[str],
'failed_tools': list[dict]
}
}
dict: 扫描结果
Raises:
ValueError: 配置错误
RuntimeError: 执行失败
"""
try:
wait_for_system_load(context="site_scan_flow")
logger.info(
"="*60 + "\n" +
"开始站点扫描\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
"开始站点扫描 - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
)
# 参数验证
if scan_id is None:
raise ValueError("scan_id 不能为空")
if not target_name:
raise ValueError("target_name 不能为空")
if target_id is None:
raise ValueError("target_id 不能为空")
if not scan_workspace_dir:
raise ValueError("scan_workspace_dir 不能为空")
_validate_flow_params(scan_id, target_name, target_id, scan_workspace_dir)
user_log(scan_id, "site_scan", "Starting site scan")
# Step 0: 创建工作目录
from apps.scan.utils import setup_scan_directory
site_scan_dir = setup_scan_directory(scan_workspace_dir, 'site_scan')
# Step 1: 导出站点 URL
urls_file, total_urls, association_count = _export_site_urls(
target_id, site_scan_dir, target_name
target_id, site_scan_dir
)
if total_urls == 0:
logger.warning("跳过站点扫描:没有站点 URL 可扫描 - Scan ID: %s", scan_id)
user_log(scan_id, "site_scan", "Skipped: no site URLs to scan", "warning")
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'scan_workspace_dir': scan_workspace_dir,
'urls_file': urls_file,
'total_urls': 0,
'association_count': association_count,
'processed_records': 0,
'created_websites': 0,
'skipped_no_subdomain': 0,
'skipped_failed': 0,
'executed_tasks': ['export_site_urls'],
'tool_stats': {
'total': 0,
'successful': 0,
'failed': 0,
'successful_tools': [],
'failed_tools': [],
'details': {}
}
}
return _build_empty_result(
scan_id, target_name, scan_workspace_dir, urls_file, association_count
)
# Step 2: 工具配置信息
logger.info("Step 2: 工具配置信息")
logger.info(
"✓ 启用工具: %s",
', '.join(enabled_tools.keys())
)
logger.info("✓ 启用工具: %s", ', '.join(enabled_tools))
# Step 3: 串行执行扫描工具
logger.info("Step 3: 串行执行扫描工具并实时保存结果")
tool_stats, processed_records, successful_tool_names, failed_tools = _run_scans_sequentially(
enabled_tools=enabled_tools,
urls_file=urls_file,
total_urls=total_urls,
site_scan_dir=site_scan_dir,
ctx = ScanContext(
scan_id=scan_id,
target_id=target_id,
target_name=target_name
target_name=target_name,
site_scan_dir=site_scan_dir,
urls_file=urls_file,
total_urls=total_urls
)
# 动态生成已执行的任务列表
tool_stats, processed_records, successful_tools, failed_tools = \
_run_scans_sequentially(enabled_tools, ctx)
# 汇总结果
executed_tasks = ['export_site_urls', 'parse_config']
executed_tasks.extend([f'run_and_stream_save_websites ({tool})' for tool in tool_stats.keys()])
# 汇总所有工具的结果
total_created = sum(stats['result'].get('created_websites', 0) for stats in tool_stats.values())
total_skipped_no_subdomain = sum(stats['result'].get('skipped_no_subdomain', 0) for stats in tool_stats.values())
total_skipped_failed = sum(stats['result'].get('skipped_failed', 0) for stats in tool_stats.values())
# 记录 Flow 完成
executed_tasks.extend(
f'run_and_stream_save_websites ({tool})' for tool in tool_stats
)
total_created, total_skipped_no_sub, total_skipped_failed = \
_aggregate_tool_results(tool_stats)
logger.info("✓ 站点扫描完成 - 创建站点: %d", total_created)
user_log(scan_id, "site_scan", f"site_scan completed: found {total_created} websites")
user_log(
scan_id, "site_scan",
f"site_scan completed: found {total_created} websites"
)
return {
'success': True,
'scan_id': scan_id,
@@ -466,25 +424,20 @@ def site_scan_flow(
'association_count': association_count,
'processed_records': processed_records,
'created_websites': total_created,
'skipped_no_subdomain': total_skipped_no_subdomain,
'skipped_no_subdomain': total_skipped_no_sub,
'skipped_failed': total_skipped_failed,
'executed_tasks': executed_tasks,
'tool_stats': {
'total': len(enabled_tools),
'successful': len(successful_tool_names),
'successful': len(successful_tools),
'failed': len(failed_tools),
'successful_tools': successful_tool_names,
'successful_tools': successful_tools,
'failed_tools': failed_tools,
'details': tool_stats
}
}
except ValueError as e:
logger.error("配置错误: %s", e)
except ValueError:
raise
except RuntimeError as e:
logger.error("运行时错误: %s", e)
except RuntimeError:
raise
except Exception as e:
logger.exception("站点扫描失败: %s", e)
raise

File diff suppressed because it is too large

View File

@@ -10,22 +10,18 @@ URL Fetch 主 Flow
- 统一进行 httpx 验证(如果启用)
"""
# Django 环境初始化
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging
import os
from pathlib import Path
from datetime import datetime
from pathlib import Path
from prefect import flow
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
on_scan_flow_running,
)
from apps.scan.utils import user_log
from apps.scan.utils import user_log, wait_for_system_load
from .domain_name_url_fetch_flow import domain_name_url_fetch_flow
from .sites_url_fetch_flow import sites_url_fetch_flow
@@ -43,13 +39,10 @@ SITES_FILE_TOOLS = {'katana'}
POST_PROCESS_TOOLS = {'uro', 'httpx'}
def _classify_tools(enabled_tools: dict) -> tuple[dict, dict, dict, dict]:
"""
将启用的工具按输入类型分类
Returns:
tuple: (domain_name_tools, sites_file_tools, uro_config, httpx_config)
"""
@@ -76,23 +69,23 @@ def _classify_tools(enabled_tools: dict) -> tuple[dict, dict, dict, dict]:
def _merge_and_deduplicate_urls(result_files: list, url_fetch_dir: Path) -> tuple[str, int]:
"""合并并去重 URL"""
from apps.scan.tasks.url_fetch import merge_and_deduplicate_urls_task
merged_file = merge_and_deduplicate_urls_task(
result_files=result_files,
result_dir=str(url_fetch_dir)
)
# 统计唯一 URL 数量
unique_url_count = 0
if Path(merged_file).exists():
with open(merged_file, 'r') as f:
with open(merged_file, 'r', encoding='utf-8') as f:
unique_url_count = sum(1 for line in f if line.strip())
logger.info(
"✓ URL 合并去重完成 - 合并文件: %s, 唯一 URL 数: %d",
merged_file, unique_url_count
)
return merged_file, unique_url_count
@@ -103,12 +96,12 @@ def _clean_urls_with_uro(
) -> tuple[str, int, int]:
"""使用 uro 清理合并后的 URL 列表"""
from apps.scan.tasks.url_fetch import clean_urls_task
raw_timeout = uro_config.get('timeout', 60)
whitelist = uro_config.get('whitelist')
blacklist = uro_config.get('blacklist')
filters = uro_config.get('filters')
# 计算超时时间
if isinstance(raw_timeout, str) and raw_timeout == 'auto':
timeout = calculate_timeout_by_line_count(
@@ -124,7 +117,7 @@ def _clean_urls_with_uro(
except (TypeError, ValueError):
logger.warning("uro timeout 配置无效(%s),使用默认 60 秒", raw_timeout)
timeout = 60
result = clean_urls_task(
input_file=merged_file,
output_dir=str(url_fetch_dir),
@@ -133,12 +126,12 @@ def _clean_urls_with_uro(
blacklist=blacklist,
filters=filters
)
if result['success']:
return result['output_file'], result['output_count'], result['removed_count']
else:
logger.warning("uro 清理失败: %s,使用原始合并文件", result.get('error', '未知错误'))
return merged_file, result['input_count'], 0
logger.warning("uro 清理失败: %s,使用原始合并文件", result.get('error', '未知错误'))
return merged_file, result['input_count'], 0
def _validate_and_stream_save_urls(
@@ -151,25 +144,25 @@ def _validate_and_stream_save_urls(
"""使用 httpx 验证 URL 存活并流式保存到数据库"""
from apps.scan.utils import build_scan_command
from apps.scan.tasks.url_fetch import run_and_stream_save_urls_task
logger.info("开始使用 httpx 验证 URL 存活状态...")
# 统计待验证的 URL 数量
try:
with open(merged_file, 'r') as f:
with open(merged_file, 'r', encoding='utf-8') as f:
url_count = sum(1 for _ in f)
logger.info("待验证 URL 数量: %d", url_count)
except Exception as e:
except OSError as e:
logger.error("读取 URL 文件失败: %s", e)
return 0
if url_count == 0:
logger.warning("没有需要验证的 URL")
return 0
# 构建 httpx 命令
command_params = {'url_file': merged_file}
try:
command = build_scan_command(
tool_name='httpx',
@@ -177,21 +170,19 @@ def _validate_and_stream_save_urls(
command_params=command_params,
tool_config=httpx_config
)
except Exception as e:
except (ValueError, KeyError) as e:
logger.error("构建 httpx 命令失败: %s", e)
logger.warning("降级处理:将直接保存所有 URL不验证存活")
return _save_urls_to_database(merged_file, scan_id, target_id)
# 计算超时时间
raw_timeout = httpx_config.get('timeout', 'auto')
timeout = 3600
if isinstance(raw_timeout, str) and raw_timeout == 'auto':
# 按 URL 行数计算超时时间:每行 3 秒,最小 60 秒
timeout = max(60, url_count * 3)
logger.info(
"自动计算 httpx 超时时间(按行数,每行 3 秒,最小 60 秒): url_count=%d, timeout=%d",
url_count,
timeout,
url_count, timeout
)
else:
try:
@@ -199,49 +190,44 @@ def _validate_and_stream_save_urls(
except (TypeError, ValueError):
timeout = 3600
logger.info("使用配置的 httpx 超时时间: %d", timeout)
# 生成日志文件路径
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = url_fetch_dir / f"httpx_validation_{timestamp}.log"
# 流式执行
try:
result = run_and_stream_save_urls_task(
cmd=command,
tool_name='httpx',
scan_id=scan_id,
target_id=target_id,
cwd=str(url_fetch_dir),
shell=True,
timeout=timeout,
log_file=str(log_file)
)
saved = result.get('saved_urls', 0)
logger.info(
"✓ httpx 验证完成 - 存活 URL: %d (%.1f%%)",
saved, (saved / url_count * 100) if url_count > 0 else 0
)
return saved
except Exception as e:
logger.error("httpx 流式验证失败: %s", e, exc_info=True)
raise
result = run_and_stream_save_urls_task(
cmd=command,
tool_name='httpx',
scan_id=scan_id,
target_id=target_id,
cwd=str(url_fetch_dir),
shell=True,
timeout=timeout,
log_file=str(log_file)
)
saved = result.get('saved_urls', 0)
logger.info(
"✓ httpx 验证完成 - 存活 URL: %d (%.1f%%)",
saved, (saved / url_count * 100) if url_count > 0 else 0
)
return saved
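For reference, the 'auto' timeout used earlier in this function works out as follows (illustrative values):
# timeout = max(60, url_count * 3)
print(max(60, 500 * 3))  # 500 URLs to validate -> 1500 seconds
print(max(60, 10 * 3))   # 10 URLs -> 60 seconds (floor)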
def _save_urls_to_database(merged_file: str, scan_id: int, target_id: int) -> int:
"""保存 URL 到数据库(不验证存活)"""
from apps.scan.tasks.url_fetch import save_urls_task
result = save_urls_task(
urls_file=merged_file,
scan_id=scan_id,
target_id=target_id
)
saved_count = result.get('saved_urls', 0)
logger.info("✓ URL 保存完成 - 保存数量: %d", saved_count)
return saved_count
@@ -261,7 +247,7 @@ def url_fetch_flow(
) -> dict:
"""
URL 获取主 Flow
执行流程:
1. 准备工作目录
2. 按输入类型分类工具domain_name / sites_file / 后处理)
@@ -271,36 +257,32 @@ def url_fetch_flow(
4. 合并所有子 Flow 的结果并去重
5. uro 去重(如果启用)
6. httpx 验证(如果启用)
Args:
scan_id: 扫描 ID
target_name: 目标名称
target_id: 目标 ID
scan_workspace_dir: 扫描工作目录
enabled_tools: 启用的工具配置
Returns:
dict: 扫描结果
"""
try:
# 负载检查:等待系统资源充足
wait_for_system_load(context="url_fetch_flow")
logger.info(
"="*60 + "\n" +
"开始 URL 获取扫描\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
"开始 URL 获取扫描 - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
)
user_log(scan_id, "url_fetch", "Starting URL fetch")
# Step 1: 准备工作目录
logger.info("Step 1: 准备工作目录")
from apps.scan.utils import setup_scan_directory
url_fetch_dir = setup_scan_directory(scan_workspace_dir, 'url_fetch')
# Step 2: 分类工具(按输入类型)
logger.info("Step 2: 分类工具")
domain_name_tools, sites_file_tools, uro_config, httpx_config = _classify_tools(enabled_tools)
logger.info(
@@ -317,15 +299,14 @@ def url_fetch_flow(
"URL Fetch 流程需要至少启用一个 URL 获取工具(如 waymore, katana"
"httpx 和 uro 仅用于后处理,不能单独使用。"
)
# Step 3: 并行执行子 Flow
# Step 3: 执行子 Flow
all_result_files = []
all_failed_tools = []
all_successful_tools = []
# 3a: 基于 domain_nametarget_name 的 URL 被动收集(如 waymore
# 3a: 基于 domain_name 的 URL 被动收集(如 waymore
if domain_name_tools:
logger.info("Step 3a: 执行基于 domain_name 的 URL 被动收集子 Flow")
tn_result = domain_name_url_fetch_flow(
scan_id=scan_id,
target_id=target_id,
@@ -336,10 +317,9 @@ def url_fetch_flow(
all_result_files.extend(tn_result.get('result_files', []))
all_failed_tools.extend(tn_result.get('failed_tools', []))
all_successful_tools.extend(tn_result.get('successful_tools', []))
# 3b: 爬虫(以 sites_file 为输入)
if sites_file_tools:
logger.info("Step 3b: 执行爬虫子 Flow")
crawl_result = sites_url_fetch_flow(
scan_id=scan_id,
target_id=target_id,
@@ -350,12 +330,13 @@ def url_fetch_flow(
all_result_files.extend(crawl_result.get('result_files', []))
all_failed_tools.extend(crawl_result.get('failed_tools', []))
all_successful_tools.extend(crawl_result.get('successful_tools', []))
# 检查是否有成功的工具
if not all_result_files:
error_details = "; ".join([f"{f['tool']}: {f['reason']}" for f in all_failed_tools])
error_details = "; ".join([
"%s: %s" % (f['tool'], f['reason']) for f in all_failed_tools
])
logger.warning("所有 URL 获取工具均失败 - 目标: %s, 失败详情: %s", target_name, error_details)
# 返回空结果,不抛出异常,让扫描继续
return {
'success': True,
'scan_id': scan_id,
@@ -366,31 +347,24 @@ def url_fetch_flow(
'successful_tools': [],
'message': '所有 URL 获取工具均无结果'
}
# Step 4: 合并并去重 URL
logger.info("Step 4: 合并并去重 URL")
merged_file, unique_url_count = _merge_and_deduplicate_urls(
merged_file, _ = _merge_and_deduplicate_urls(
result_files=all_result_files,
url_fetch_dir=url_fetch_dir
)
# Step 5: 使用 uro 清理 URL如果启用
url_file_for_validation = merged_file
uro_removed_count = 0
if uro_config and uro_config.get('enabled', False):
logger.info("Step 5: 使用 uro 清理 URL")
url_file_for_validation, cleaned_count, uro_removed_count = _clean_urls_with_uro(
url_file_for_validation, _, _ = _clean_urls_with_uro(
merged_file=merged_file,
uro_config=uro_config,
url_fetch_dir=url_fetch_dir
)
else:
logger.info("Step 5: 跳过 uro 清理(未启用)")
# Step 6: 使用 httpx 验证存活并保存(如果启用)
if httpx_config and httpx_config.get('enabled', False):
logger.info("Step 6: 使用 httpx 验证 URL 存活并流式保存")
saved_count = _validate_and_stream_save_urls(
merged_file=url_file_for_validation,
httpx_config=httpx_config,
@@ -399,17 +373,16 @@ def url_fetch_flow(
target_id=target_id
)
else:
logger.info("Step 6: 保存到数据库(未启用 httpx 验证)")
saved_count = _save_urls_to_database(
merged_file=url_file_for_validation,
scan_id=scan_id,
target_id=target_id
)
# 记录 Flow 完成
logger.info("✓ URL 获取完成 - 保存 endpoints: %d", saved_count)
user_log(scan_id, "url_fetch", f"url_fetch completed: found {saved_count} endpoints")
user_log(scan_id, "url_fetch", "url_fetch completed: found %d endpoints" % saved_count)
# 构建已执行的任务列表
executed_tasks = ['setup_directory', 'classify_tools']
if domain_name_tools:
@@ -423,7 +396,7 @@ def url_fetch_flow(
executed_tasks.append('httpx_validation_and_save')
else:
executed_tasks.append('save_urls')
return {
'success': True,
'scan_id': scan_id,
@@ -439,7 +412,7 @@ def url_fetch_flow(
'failed_tools': [f['tool'] for f in all_failed_tools]
}
}
except Exception as e:
logger.error("URL 获取扫描失败: %s", e, exc_info=True)
raise

View File

@@ -1,5 +1,6 @@
from apps.common.prefect_django_setup import setup_django_for_prefect
"""
漏洞扫描主 Flow
"""
import logging
from typing import Dict, Tuple
@@ -11,7 +12,7 @@ from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_failed,
)
from apps.scan.configs.command_templates import get_command_template
from apps.scan.utils import user_log
from apps.scan.utils import user_log, wait_for_system_load
from .endpoints_vuln_scan_flow import endpoints_vuln_scan_flow
@@ -62,6 +63,9 @@ def vuln_scan_flow(
- nuclei: 通用漏洞扫描(流式保存,支持模板 commit hash 同步)
"""
try:
# 负载检查:等待系统资源充足
wait_for_system_load(context="vuln_scan_flow")
if scan_id is None:
raise ValueError("scan_id 不能为空")
if not target_name:

View File

@@ -0,0 +1,56 @@
"""
扫描目标提供者模块
提供统一的目标获取接口,支持多种数据源:
- DatabaseTargetProvider: 从数据库查询(完整扫描)
- ListTargetProvider: 使用内存列表快速扫描阶段1
- SnapshotTargetProvider: 从快照表读取快速扫描阶段2+
- PipelineTargetProvider: 使用管道输出Phase 2
使用方式:
from apps.scan.providers import (
DatabaseTargetProvider,
ListTargetProvider,
SnapshotTargetProvider,
ProviderContext
)
# 数据库模式(完整扫描)
provider = DatabaseTargetProvider(target_id=123)
# 列表模式快速扫描阶段1
context = ProviderContext(target_id=1, scan_id=100)
provider = ListTargetProvider(
targets=["a.test.com"],
context=context
)
# 快照模式快速扫描阶段2+
context = ProviderContext(target_id=1, scan_id=100)
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain",
context=context
)
# 使用 Provider
for host in provider.iter_hosts():
scan(host)
"""
from .base import TargetProvider, ProviderContext
from .list_provider import ListTargetProvider
from .database_provider import DatabaseTargetProvider
from .snapshot_provider import SnapshotTargetProvider, SnapshotType
from .pipeline_provider import PipelineTargetProvider, StageOutput
__all__ = [
'TargetProvider',
'ProviderContext',
'ListTargetProvider',
'DatabaseTargetProvider',
'SnapshotTargetProvider',
'SnapshotType',
'PipelineTargetProvider',
'StageOutput',
]
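A hedged sketch of how the exported names might be combined to pick a provider per scan mode; make_provider and its arguments are hypothetical and not part of this package:
from typing import List, Optional

from apps.scan.providers import (
    DatabaseTargetProvider,
    ListTargetProvider,
    ProviderContext,
    TargetProvider,
)

def make_provider(
    target_id: int,
    quick_targets: Optional[List[str]] = None,
    scan_id: Optional[int] = None,
) -> TargetProvider:
    """Hypothetical helper: list mode for fast scans, database mode otherwise."""
    ctx = ProviderContext(target_id=target_id, scan_id=scan_id)
    if quick_targets:
        return ListTargetProvider(targets=quick_targets, context=ctx)
    return DatabaseTargetProvider(target_id=target_id, context=ctx)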

View File

@@ -0,0 +1,115 @@
"""
扫描目标提供者基础模块
定义 ProviderContext 数据类和 TargetProvider 抽象基类。
"""
import ipaddress
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Iterator, Optional
if TYPE_CHECKING:
from apps.common.utils import BlacklistFilter
logger = logging.getLogger(__name__)
@dataclass
class ProviderContext:
"""
Provider 上下文,携带元数据
Attributes:
target_id: 关联的 Target ID(用于结果保存),None 表示临时扫描(不保存)
scan_id: 扫描任务 ID
"""
target_id: Optional[int] = None
scan_id: Optional[int] = None
class TargetProvider(ABC):
"""
扫描目标提供者抽象基类
职责:
- 提供扫描目标(域名、IP、URL 等)的迭代器
- 提供黑名单过滤器
- 携带上下文信息target_id, scan_id 等)
- 自动展开 CIDR子类无需关心
使用方式:
provider = create_target_provider(target_id=123)
for host in provider.iter_hosts():
print(host)
"""
def __init__(self, context: Optional[ProviderContext] = None):
self._context = context or ProviderContext()
@property
def context(self) -> ProviderContext:
"""返回 Provider 上下文"""
return self._context
@staticmethod
def _expand_host(host: str) -> Iterator[str]:
"""
展开主机(如果是 CIDR 则展开为多个 IP,否则直接返回)
示例:
"192.168.1.0/30" → "192.168.1.1", "192.168.1.2"
"192.168.1.1" → "192.168.1.1"
"example.com" → "example.com"
"""
from apps.common.validators import detect_target_type
from apps.targets.models import Target
host = host.strip()
if not host:
return
try:
target_type = detect_target_type(host)
if target_type == Target.TargetType.CIDR:
network = ipaddress.ip_network(host, strict=False)
if network.num_addresses == 1:
yield str(network.network_address)
else:
yield from (str(ip) for ip in network.hosts())
elif target_type in (Target.TargetType.IP, Target.TargetType.DOMAIN):
yield host
except ValueError as e:
logger.warning("跳过无效的主机格式 '%s': %s", host, str(e))
def iter_hosts(self) -> Iterator[str]:
"""迭代主机列表(域名/IP自动展开 CIDR"""
for host in self._iter_raw_hosts():
yield from self._expand_host(host)
@abstractmethod
def _iter_raw_hosts(self) -> Iterator[str]:
"""迭代原始主机列表(可能包含 CIDR子类实现"""
pass
@abstractmethod
def iter_urls(self) -> Iterator[str]:
"""迭代 URL 列表"""
pass
@abstractmethod
def get_blacklist_filter(self) -> Optional['BlacklistFilter']:
"""获取黑名单过滤器,返回 None 表示不过滤"""
pass
@property
def target_id(self) -> Optional[int]:
"""返回关联的 target_id临时扫描返回 None"""
return self._context.target_id
@property
def scan_id(self) -> Optional[int]:
"""返回关联的 scan_id"""
return self._context.scan_id
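To show the contract defined by the base class, a minimal custom provider might look like the sketch below; StaticListTargetProvider is hypothetical, and CIDR expansion comes for free from the iter_hosts implementation above:
from typing import Iterable, Iterator, Optional

from apps.scan.providers.base import ProviderContext, TargetProvider

class StaticListTargetProvider(TargetProvider):
    """Hypothetical provider over an in-memory iterable of targets."""

    def __init__(self, lines: Iterable[str], context: Optional[ProviderContext] = None):
        super().__init__(context)
        self._lines = list(lines)

    def _iter_raw_hosts(self) -> Iterator[str]:
        # anything that does not look like a URL is treated as a host
        yield from (line for line in self._lines if not line.startswith("http"))

    def iter_urls(self) -> Iterator[str]:
        yield from (line for line in self._lines if line.startswith("http"))

    def get_blacklist_filter(self) -> None:
        return None  # explicitly supplied targets are not filtered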

View File

@@ -0,0 +1,93 @@
"""
数据库目标提供者模块
提供基于数据库查询的目标提供者实现。
"""
import logging
from typing import TYPE_CHECKING, Iterator, Optional
from .base import ProviderContext, TargetProvider
if TYPE_CHECKING:
from apps.common.utils import BlacklistFilter
logger = logging.getLogger(__name__)
class DatabaseTargetProvider(TargetProvider):
"""
数据库目标提供者 - 从 Target 表及关联资产表查询
数据来源:
- iter_hosts(): 根据 Target 类型返回域名/IP
- iter_urls(): WebSite/Endpoint 表,带回退链
使用方式:
provider = DatabaseTargetProvider(target_id=123)
for host in provider.iter_hosts():
scan(host)
"""
def __init__(self, target_id: int, context: Optional[ProviderContext] = None):
ctx = context or ProviderContext()
ctx.target_id = target_id
super().__init__(ctx)
self._blacklist_filter: Optional['BlacklistFilter'] = None
def iter_hosts(self) -> Iterator[str]:
"""从数据库查询主机列表,自动展开 CIDR 并应用黑名单过滤"""
blacklist = self.get_blacklist_filter()
for host in self._iter_raw_hosts():
for expanded_host in self._expand_host(host):
if not blacklist or blacklist.is_allowed(expanded_host):
yield expanded_host
def _iter_raw_hosts(self) -> Iterator[str]:
"""从数据库查询原始主机列表(可能包含 CIDR"""
from apps.asset.services.asset.subdomain_service import SubdomainService
from apps.targets.models import Target
from apps.targets.services import TargetService
target = TargetService().get_target(self.target_id)
if not target:
logger.warning("Target ID %d 不存在", self.target_id)
return
if target.type == Target.TargetType.DOMAIN:
yield target.name
for domain in SubdomainService().iter_subdomain_names_by_target(
target_id=self.target_id,
chunk_size=1000
):
if domain != target.name:
yield domain
elif target.type in (Target.TargetType.IP, Target.TargetType.CIDR):
yield target.name
def iter_urls(self) -> Iterator[str]:
"""从数据库查询 URL 列表使用回退链Endpoint → WebSite → Default"""
from apps.scan.services.target_export_service import (
DataSource,
_iter_urls_with_fallback,
)
blacklist = self.get_blacklist_filter()
for url, _ in _iter_urls_with_fallback(
target_id=self.target_id,
sources=[DataSource.ENDPOINT, DataSource.WEBSITE, DataSource.DEFAULT],
blacklist_filter=blacklist
):
yield url
def get_blacklist_filter(self) -> Optional['BlacklistFilter']:
"""获取黑名单过滤器(延迟加载)"""
if self._blacklist_filter is None:
from apps.common.services import BlacklistService
from apps.common.utils import BlacklistFilter
rules = BlacklistService().get_rules(self.target_id)
self._blacklist_filter = BlacklistFilter(rules)
return self._blacklist_filter
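Usage sketch (requires the Django environment and seeded data; results depend on the Target and its assets):
from apps.scan.providers import DatabaseTargetProvider

provider = DatabaseTargetProvider(target_id=123)
blacklist = provider.get_blacklist_filter()  # built lazily from BlacklistService
hosts = list(provider.iter_hosts())          # blacklist-filtered, CIDR already expanded
urls = list(provider.iter_urls())            # Endpoint -> WebSite -> Default fallback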

View File

@@ -0,0 +1,84 @@
"""
列表目标提供者模块
提供基于内存列表的目标提供者实现。
"""
from typing import Iterator, Optional, List
from .base import TargetProvider, ProviderContext
class ListTargetProvider(TargetProvider):
"""
列表目标提供者 - 直接使用内存中的列表
用于快速扫描、临时扫描等场景,只扫描用户指定的目标。
特点:
- 不查询数据库
- 不应用黑名单过滤(用户明确指定的目标)
- 不关联 target_id(由调用方负责创建 Target)
- 自动检测输入类型(URL/域名/IP/CIDR)
- 自动展开 CIDR
使用方式:
# 快速扫描:用户提供目标,自动识别类型
provider = ListTargetProvider(targets=[
"example.com", # 域名
"192.168.1.0/24", # CIDR自动展开
"https://api.example.com" # URL
])
for host in provider.iter_hosts():
scan(host)
"""
def __init__(
self,
targets: Optional[List[str]] = None,
context: Optional[ProviderContext] = None
):
"""
初始化列表目标提供者
Args:
targets: 目标列表(自动识别类型:URL/域名/IP/CIDR)
context: Provider 上下文
"""
from apps.common.validators import detect_input_type
ctx = context or ProviderContext()
super().__init__(ctx)
# 自动分类目标
self._hosts = []
self._urls = []
if targets:
for target in targets:
target = target.strip()
if not target:
continue
try:
input_type = detect_input_type(target)
if input_type == 'url':
self._urls.append(target)
else:
# domain/ip/cidr 都作为 host
self._hosts.append(target)
except ValueError:
# 无法识别类型,默认作为 host
self._hosts.append(target)
def _iter_raw_hosts(self) -> Iterator[str]:
"""迭代原始主机列表(可能包含 CIDR"""
yield from self._hosts
def iter_urls(self) -> Iterator[str]:
"""迭代 URL 列表"""
yield from self._urls
def get_blacklist_filter(self) -> None:
"""列表模式不使用黑名单过滤"""
return None
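Usage sketch for the list provider (illustrative targets; run inside the Django environment, since type detection uses the project's validators):
from apps.scan.providers import ListTargetProvider

provider = ListTargetProvider(targets=[
    "example.com",               # kept as a host
    "192.168.1.0/30",            # CIDR, expanded by the base class
    "https://api.example.com",   # classified as a URL
])
print(list(provider.iter_urls()))   # ['https://api.example.com']
print(list(provider.iter_hosts()))  # ['example.com', '192.168.1.1', '192.168.1.2']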

View File

@@ -0,0 +1,91 @@
"""
管道目标提供者模块
提供基于管道阶段输出的目标提供者实现。
用于 Phase 2 管道模式的阶段间数据传递。
"""
from dataclasses import dataclass, field
from typing import Iterator, Optional, List, Dict, Any
from .base import TargetProvider, ProviderContext
@dataclass
class StageOutput:
"""
阶段输出数据
用于在管道阶段之间传递数据。
Attributes:
hosts: 主机列表(域名/IP
urls: URL 列表
new_targets: 新发现的目标列表
stats: 统计信息
success: 是否成功
error: 错误信息
"""
hosts: List[str] = field(default_factory=list)
urls: List[str] = field(default_factory=list)
new_targets: List[str] = field(default_factory=list)
stats: Dict[str, Any] = field(default_factory=dict)
success: bool = True
error: Optional[str] = None
class PipelineTargetProvider(TargetProvider):
"""
管道目标提供者 - 使用上一阶段的输出
用于 Phase 2 管道模式的阶段间数据传递。
特点:
- 不查询数据库
- 不应用黑名单过滤(数据已在上一阶段过滤)
- 直接使用 StageOutput 中的数据
使用方式Phase 2
stage1_output = stage1.run(input)
provider = PipelineTargetProvider(
previous_output=stage1_output,
target_id=123
)
for host in provider.iter_hosts():
stage2.scan(host)
"""
def __init__(
self,
previous_output: StageOutput,
target_id: Optional[int] = None,
context: Optional[ProviderContext] = None
):
"""
初始化管道目标提供者
Args:
previous_output: 上一阶段的输出
target_id: 可选,关联到某个 Target用于保存结果
context: Provider 上下文
"""
ctx = context or ProviderContext(target_id=target_id)
super().__init__(ctx)
self._previous_output = previous_output
def _iter_raw_hosts(self) -> Iterator[str]:
"""迭代上一阶段输出的原始主机(可能包含 CIDR"""
yield from self._previous_output.hosts
def iter_urls(self) -> Iterator[str]:
"""迭代上一阶段输出的 URL"""
yield from self._previous_output.urls
def get_blacklist_filter(self) -> None:
"""管道传递的数据已经过滤过了"""
return None
@property
def previous_output(self) -> StageOutput:
"""返回上一阶段的输出"""
return self._previous_output
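Sketch of chaining two stages through StageOutput; the stage function is hypothetical, and iterating hosts requires the Django environment for type detection:
from apps.scan.providers import PipelineTargetProvider, StageOutput

def subdomain_stage() -> StageOutput:
    # ... run discovery tools here ...
    return StageOutput(hosts=["a.example.com", "b.example.com"], stats={"found": 2})

stage1_output = subdomain_stage()
provider = PipelineTargetProvider(previous_output=stage1_output, target_id=123)
for host in provider.iter_hosts():
    print("next stage scans:", host)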

View File

@@ -0,0 +1,175 @@
"""
快照目标提供者模块
提供基于快照表的目标提供者实现。
用于快速扫描的阶段间数据传递。
"""
import logging
from typing import Iterator, Optional, Literal
from .base import TargetProvider, ProviderContext
logger = logging.getLogger(__name__)
# 快照类型定义
SnapshotType = Literal["subdomain", "website", "endpoint", "host_port"]
class SnapshotTargetProvider(TargetProvider):
"""
快照目标提供者 - 从快照表读取本次扫描的数据
用于快速扫描的阶段间数据传递,解决精确扫描控制问题。
核心价值:
- 只返回本次扫描(scan_id)发现的资产
- 避免扫描历史数据(DatabaseTargetProvider 会扫描所有历史资产)
特点:
- 通过 scan_id 过滤快照表
- 不应用黑名单过滤(数据已在上一阶段过滤)
- 支持多种快照类型subdomain/website/endpoint/host_port
使用场景:
# 快速扫描流程
用户输入: a.test.com
创建 Target: test.com (id=1)
创建 Scan: scan_id=100
# 阶段1: 子域名发现
provider = ListTargetProvider(
targets=["a.test.com"],
context=ProviderContext(target_id=1, scan_id=100)
)
# 发现: b.a.test.com, c.a.test.com
# 保存: SubdomainSnapshot(scan_id=100) + Subdomain(target_id=1)
# 阶段2: 端口扫描
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain",
context=ProviderContext(target_id=1, scan_id=100)
)
# 只返回: b.a.test.com, c.a.test.com(本次扫描发现的)
# 不返回: www.test.com, api.test.com(历史数据)
# 阶段3: 网站扫描
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="host_port",
context=ProviderContext(target_id=1, scan_id=100)
)
# 只返回本次扫描发现的 IP:Port
"""
def __init__(
self,
scan_id: int,
snapshot_type: SnapshotType,
context: Optional[ProviderContext] = None
):
"""
初始化快照目标提供者
Args:
scan_id: 扫描任务 ID(必需)
snapshot_type: 快照类型
- "subdomain": 子域名快照(SubdomainSnapshot)
- "website": 网站快照(WebsiteSnapshot)
- "endpoint": 端点快照(EndpointSnapshot)
- "host_port": 主机端口映射快照(HostPortMappingSnapshot)
context: Provider 上下文
"""
ctx = context or ProviderContext()
ctx.scan_id = scan_id
super().__init__(ctx)
self._scan_id = scan_id
self._snapshot_type = snapshot_type
def _iter_raw_hosts(self) -> Iterator[str]:
"""
从快照表迭代主机列表
根据 snapshot_type 选择不同的快照表:
- subdomain: SubdomainSnapshot.name
- host_port: HostPortMappingSnapshot.host (返回 host:port 格式,不经过验证)
"""
if self._snapshot_type == "subdomain":
from apps.asset.services.snapshot import SubdomainSnapshotsService
service = SubdomainSnapshotsService()
yield from service.iter_subdomain_names_by_scan(
scan_id=self._scan_id,
chunk_size=1000
)
elif self._snapshot_type == "host_port":
# host_port 类型不使用 _iter_raw_hosts直接在 iter_hosts 中处理
# 这里返回空,避免被基类的 iter_hosts 调用
return
else:
# 其他类型暂不支持 iter_hosts
logger.warning(
"快照类型 '%s' 不支持 iter_hosts返回空迭代器",
self._snapshot_type
)
return
def iter_hosts(self) -> Iterator[str]:
"""
迭代主机列表
对于 host_port 类型,返回 host:port 格式,不经过 CIDR 展开验证
"""
if self._snapshot_type == "host_port":
# host_port 类型直接返回 host:port不经过 _expand_host 验证
from apps.asset.services.snapshot import HostPortMappingSnapshotsService
service = HostPortMappingSnapshotsService()
queryset = service.get_by_scan(scan_id=self._scan_id)
for mapping in queryset.iterator(chunk_size=1000):
yield f"{mapping.host}:{mapping.port}"
else:
# 其他类型使用基类的 iter_hosts会调用 _iter_raw_hosts 并展开 CIDR
yield from super().iter_hosts()
def iter_urls(self) -> Iterator[str]:
"""
从快照表迭代 URL 列表
根据 snapshot_type 选择不同的快照表:
- website: WebsiteSnapshot.url
- endpoint: EndpointSnapshot.url
"""
if self._snapshot_type == "website":
from apps.asset.services.snapshot import WebsiteSnapshotsService
service = WebsiteSnapshotsService()
yield from service.iter_website_urls_by_scan(
scan_id=self._scan_id,
chunk_size=1000
)
elif self._snapshot_type == "endpoint":
from apps.asset.services.snapshot import EndpointSnapshotsService
service = EndpointSnapshotsService()
# 从快照表获取端点 URL
queryset = service.get_by_scan(scan_id=self._scan_id)
for endpoint in queryset.iterator(chunk_size=1000):
yield endpoint.url
else:
# 其他类型暂不支持 iter_urls
logger.warning(
"快照类型 '%s' 不支持 iter_urls返回空迭代器",
self._snapshot_type
)
return
def get_blacklist_filter(self) -> None:
"""快照数据已在上一阶段过滤过了"""
return None
@property
def snapshot_type(self) -> SnapshotType:
"""返回快照类型"""
return self._snapshot_type
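Consuming the host_port snapshot type, which yields host:port strings directly (illustrative values; requires snapshot rows for the given scan in the database):
from apps.scan.providers import ProviderContext, SnapshotTargetProvider

provider = SnapshotTargetProvider(
    scan_id=100,
    snapshot_type="host_port",
    context=ProviderContext(target_id=1, scan_id=100),
)
for host_port in provider.iter_hosts():   # e.g. "10.0.0.5:8443"
    host, port = host_port.rsplit(":", 1)
    print(host, int(port))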

View File

@@ -0,0 +1,3 @@
"""
扫描目标提供者测试模块
"""

View File

@@ -0,0 +1,256 @@
"""
通用属性测试
包含跨多个 Provider 的通用属性测试:
- Property 4: Context Propagation
- Property 5: Non-Database Provider Blacklist Filter
- Property 7: CIDR Expansion Consistency
"""
import pytest
from hypothesis import given, strategies as st, settings
from ipaddress import IPv4Network
from apps.scan.providers import (
ProviderContext,
ListTargetProvider,
DatabaseTargetProvider,
PipelineTargetProvider,
SnapshotTargetProvider
)
from apps.scan.providers.pipeline_provider import StageOutput
class TestContextPropagation:
"""
Property 4: Context Propagation
*For any* ProviderContext传入 Provider 构造函数后,
Provider 的 target_id 和 scan_id 属性应该与 context 中的值一致。
**Validates: Requirements 1.3, 1.5, 7.4, 7.5**
"""
@given(
target_id=st.integers(min_value=1, max_value=10000),
scan_id=st.integers(min_value=1, max_value=10000)
)
@settings(max_examples=100)
def test_property_4_list_provider_context_propagation(self, target_id, scan_id):
"""
Property 4: Context Propagation (ListTargetProvider)
Feature: scan-target-provider, Property 4: Context Propagation
**Validates: Requirements 1.3, 1.5, 7.4, 7.5**
"""
ctx = ProviderContext(target_id=target_id, scan_id=scan_id)
provider = ListTargetProvider(targets=["example.com"], context=ctx)
assert provider.target_id == target_id
assert provider.scan_id == scan_id
assert provider.context.target_id == target_id
assert provider.context.scan_id == scan_id
@given(
target_id=st.integers(min_value=1, max_value=10000),
scan_id=st.integers(min_value=1, max_value=10000)
)
@settings(max_examples=100)
def test_property_4_database_provider_context_propagation(self, target_id, scan_id):
"""
Property 4: Context Propagation (DatabaseTargetProvider)
Feature: scan-target-provider, Property 4: Context Propagation
**Validates: Requirements 1.3, 1.5, 7.4, 7.5**
"""
ctx = ProviderContext(target_id=999, scan_id=scan_id)
# DatabaseTargetProvider 会覆盖 context 中的 target_id
provider = DatabaseTargetProvider(target_id=target_id, context=ctx)
assert provider.target_id == target_id # 使用构造函数参数
assert provider.scan_id == scan_id # 使用 context 中的值
assert provider.context.target_id == target_id
assert provider.context.scan_id == scan_id
@given(
target_id=st.integers(min_value=1, max_value=10000),
scan_id=st.integers(min_value=1, max_value=10000)
)
@settings(max_examples=100)
def test_property_4_pipeline_provider_context_propagation(self, target_id, scan_id):
"""
Property 4: Context Propagation (PipelineTargetProvider)
Feature: scan-target-provider, Property 4: Context Propagation
**Validates: Requirements 1.3, 1.5, 7.4, 7.5**
"""
ctx = ProviderContext(target_id=target_id, scan_id=scan_id)
stage_output = StageOutput(hosts=["example.com"])
provider = PipelineTargetProvider(previous_output=stage_output, context=ctx)
assert provider.target_id == target_id
assert provider.scan_id == scan_id
assert provider.context.target_id == target_id
assert provider.context.scan_id == scan_id
@given(
target_id=st.integers(min_value=1, max_value=10000),
scan_id=st.integers(min_value=1, max_value=10000)
)
@settings(max_examples=100)
def test_property_4_snapshot_provider_context_propagation(self, target_id, scan_id):
"""
Property 4: Context Propagation (SnapshotTargetProvider)
Feature: scan-target-provider, Property 4: Context Propagation
**Validates: Requirements 1.3, 1.5, 7.4, 7.5**
"""
ctx = ProviderContext(target_id=target_id, scan_id=999)
# SnapshotTargetProvider 会覆盖 context 中的 scan_id
provider = SnapshotTargetProvider(
scan_id=scan_id,
snapshot_type="subdomain",
context=ctx
)
assert provider.target_id == target_id # 使用 context 中的值
assert provider.scan_id == scan_id # 使用构造函数参数
assert provider.context.target_id == target_id
assert provider.context.scan_id == scan_id
class TestNonDatabaseProviderBlacklistFilter:
"""
Property 5: Non-Database Provider Blacklist Filter
*For any* ListTargetProvider 或 PipelineTargetProvider 实例,
get_blacklist_filter() 方法应该返回 None。
**Validates: Requirements 3.4, 9.4, 9.5**
"""
@given(targets=st.lists(st.text(min_size=1, max_size=20), max_size=10))
@settings(max_examples=100)
def test_property_5_list_provider_no_blacklist(self, targets):
"""
Property 5: Non-Database Provider Blacklist Filter (ListTargetProvider)
Feature: scan-target-provider, Property 5: Non-Database Provider Blacklist Filter
**Validates: Requirements 3.4, 9.4, 9.5**
"""
provider = ListTargetProvider(targets=targets)
assert provider.get_blacklist_filter() is None
@given(hosts=st.lists(st.text(min_size=1, max_size=20), max_size=10))
@settings(max_examples=100)
def test_property_5_pipeline_provider_no_blacklist(self, hosts):
"""
Property 5: Non-Database Provider Blacklist Filter (PipelineTargetProvider)
Feature: scan-target-provider, Property 5: Non-Database Provider Blacklist Filter
**Validates: Requirements 3.4, 9.4, 9.5**
"""
stage_output = StageOutput(hosts=hosts)
provider = PipelineTargetProvider(previous_output=stage_output)
assert provider.get_blacklist_filter() is None
def test_property_5_snapshot_provider_no_blacklist(self):
"""
Property 5: Non-Database Provider Blacklist Filter (SnapshotTargetProvider)
Feature: scan-target-provider, Property 5: Non-Database Provider Blacklist Filter
**Validates: Requirements 3.4, 9.4, 9.5**
"""
provider = SnapshotTargetProvider(scan_id=1, snapshot_type="subdomain")
assert provider.get_blacklist_filter() is None
class TestCIDRExpansionConsistency:
"""
Property 7: CIDR Expansion Consistency
*For any* CIDR 字符串(如 "192.168.1.0/24"),所有 Provider 的 iter_hosts()
方法应该将其展开为相同的单个 IP 地址列表。
**Validates: Requirements 1.1, 3.6**
"""
@given(
# 生成小的 CIDR 范围以避免测试超时
network_prefix=st.integers(min_value=1, max_value=254),
cidr_suffix=st.integers(min_value=28, max_value=30) # /28 = 16 IPs, /30 = 4 IPs
)
@settings(max_examples=50, deadline=None)
def test_property_7_cidr_expansion_consistency(self, network_prefix, cidr_suffix):
"""
Property 7: CIDR Expansion Consistency
Feature: scan-target-provider, Property 7: CIDR Expansion Consistency
**Validates: Requirements 1.1, 3.6**
For any CIDR string, all Providers should expand it to the same IP list.
"""
cidr = f"192.168.{network_prefix}.0/{cidr_suffix}"
# 计算预期的 IP 列表
network = IPv4Network(cidr, strict=False)
# 排除网络地址和广播地址
expected_ips = [str(ip) for ip in network.hosts()]
# 如果 CIDR 太小(/31 或 /32),使用所有地址
if not expected_ips:
expected_ips = [str(ip) for ip in network]
# ListTargetProvider
list_provider = ListTargetProvider(targets=[cidr])
list_result = list(list_provider.iter_hosts())
# PipelineTargetProvider
stage_output = StageOutput(hosts=[cidr])
pipeline_provider = PipelineTargetProvider(previous_output=stage_output)
pipeline_result = list(pipeline_provider.iter_hosts())
# 验证:所有 Provider 展开的结果应该一致
assert list_result == expected_ips, f"ListProvider CIDR expansion mismatch for {cidr}"
assert pipeline_result == expected_ips, f"PipelineProvider CIDR expansion mismatch for {cidr}"
assert list_result == pipeline_result, f"Providers produce different results for {cidr}"
def test_cidr_expansion_with_multiple_cidrs(self):
"""测试多个 CIDR 的展开一致性"""
cidrs = ["192.168.1.0/30", "10.0.0.0/30"]
# 计算预期结果
expected_ips = []
for cidr in cidrs:
network = IPv4Network(cidr, strict=False)
expected_ips.extend([str(ip) for ip in network.hosts()])
# ListTargetProvider
list_provider = ListTargetProvider(targets=cidrs)
list_result = list(list_provider.iter_hosts())
# PipelineTargetProvider
stage_output = StageOutput(hosts=cidrs)
pipeline_provider = PipelineTargetProvider(previous_output=stage_output)
pipeline_result = list(pipeline_provider.iter_hosts())
# 验证
assert list_result == expected_ips
assert pipeline_result == expected_ips
assert list_result == pipeline_result
def test_mixed_hosts_and_cidrs(self):
"""测试混合主机和 CIDR 的处理"""
targets = ["example.com", "192.168.1.0/30", "test.com"]
# 计算预期结果
network = IPv4Network("192.168.1.0/30", strict=False)
cidr_ips = [str(ip) for ip in network.hosts()]
expected = ["example.com"] + cidr_ips + ["test.com"]
# ListTargetProvider
list_provider = ListTargetProvider(targets=targets)
list_result = list(list_provider.iter_hosts())
# 验证
assert list_result == expected

View File

@@ -0,0 +1,158 @@
"""
DatabaseTargetProvider 属性测试
Property 7: DatabaseTargetProvider Blacklist Application
*For any* 带有黑名单规则的 target_idDatabaseTargetProvider 的 iter_hosts() 和 iter_urls()
应该过滤掉匹配黑名单规则的目标。
**Validates: Requirements 2.3, 10.1, 10.2, 10.3**
"""
import pytest
from unittest.mock import patch, MagicMock
from hypothesis import given, strategies as st, settings
from apps.scan.providers.database_provider import DatabaseTargetProvider
from apps.scan.providers.base import ProviderContext
# 生成有效域名的策略
def valid_domain_strategy():
"""生成有效的域名"""
label = st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=2,
max_size=10
)
return st.builds(
lambda a, b, c: f"{a}.{b}.{c}",
label, label, st.sampled_from(['com', 'net', 'org', 'io'])
)
class MockBlacklistFilter:
"""模拟黑名单过滤器"""
def __init__(self, blocked_patterns: list):
self.blocked_patterns = blocked_patterns
def is_allowed(self, target: str) -> bool:
"""检查目标是否被允许(不在黑名单中)"""
for pattern in self.blocked_patterns:
if pattern in target:
return False
return True
class TestDatabaseTargetProviderProperties:
"""DatabaseTargetProvider 属性测试类"""
@given(
hosts=st.lists(valid_domain_strategy(), min_size=1, max_size=20),
blocked_keyword=st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=2,
max_size=5
)
)
@settings(max_examples=100)
def test_property_7_blacklist_filters_hosts(self, hosts, blocked_keyword):
"""
Property 7: DatabaseTargetProvider Blacklist Application (hosts)
Feature: scan-target-provider, Property 7: DatabaseTargetProvider Blacklist Application
**Validates: Requirements 2.3, 10.1, 10.2, 10.3**
For any set of hosts and a blacklist keyword, the provider should filter out
all hosts containing the blocked keyword.
"""
# 创建模拟的黑名单过滤器
mock_filter = MockBlacklistFilter([blocked_keyword])
# 创建 provider 并注入模拟的黑名单过滤器
provider = DatabaseTargetProvider(target_id=1)
provider._blacklist_filter = mock_filter
# 模拟 Target 和 SubdomainService
mock_target = MagicMock()
mock_target.type = 'domain'
mock_target.name = hosts[0] if hosts else 'example.com'
with patch('apps.targets.services.TargetService') as mock_target_service, \
patch('apps.asset.services.asset.subdomain_service.SubdomainService') as mock_subdomain_service:
mock_target_service.return_value.get_target.return_value = mock_target
mock_subdomain_service.return_value.iter_subdomain_names_by_target.return_value = iter(hosts[1:] if len(hosts) > 1 else [])
# 获取结果
result = list(provider.iter_hosts())
# 验证:所有结果都不包含被阻止的关键词
for host in result:
assert blocked_keyword not in host, f"Host '{host}' should be filtered by blacklist keyword '{blocked_keyword}'"
# 验证:所有不包含关键词的主机都应该在结果中
if hosts:
all_hosts = [hosts[0]] + [h for h in hosts[1:] if h != hosts[0]]
expected_allowed = [h for h in all_hosts if blocked_keyword not in h]
else:
expected_allowed = []
assert set(result) == set(expected_allowed)
class TestDatabaseTargetProviderUnit:
"""DatabaseTargetProvider 单元测试类"""
def test_target_id_in_context(self):
"""测试 target_id 正确设置到上下文中"""
provider = DatabaseTargetProvider(target_id=123)
assert provider.target_id == 123
assert provider.context.target_id == 123
def test_context_propagation(self):
"""测试上下文传递"""
ctx = ProviderContext(scan_id=789)
provider = DatabaseTargetProvider(target_id=123, context=ctx)
assert provider.target_id == 123 # target_id 被覆盖
assert provider.scan_id == 789
def test_blacklist_filter_lazy_loading(self):
"""测试黑名单过滤器延迟加载"""
provider = DatabaseTargetProvider(target_id=123)
# 初始时 _blacklist_filter 为 None
assert provider._blacklist_filter is None
# 模拟 BlacklistService
with patch('apps.common.services.BlacklistService') as mock_service, \
patch('apps.common.utils.BlacklistFilter') as mock_filter_class:
mock_service.return_value.get_rules.return_value = []
mock_filter_instance = MagicMock()
mock_filter_class.return_value = mock_filter_instance
# 第一次调用
result1 = provider.get_blacklist_filter()
assert result1 == mock_filter_instance
# 第二次调用应该返回缓存的实例
result2 = provider.get_blacklist_filter()
assert result2 == mock_filter_instance
# BlacklistService 只应该被调用一次
mock_service.return_value.get_rules.assert_called_once_with(123)
def test_nonexistent_target_returns_empty(self):
"""测试不存在的 target 返回空迭代器"""
provider = DatabaseTargetProvider(target_id=99999)
with patch('apps.targets.services.TargetService') as mock_service, \
patch('apps.common.services.BlacklistService') as mock_blacklist_service:
mock_service.return_value.get_target.return_value = None
mock_blacklist_service.return_value.get_rules.return_value = []
result = list(provider.iter_hosts())
assert result == []

View File

@@ -0,0 +1,152 @@
"""
ListTargetProvider 属性测试
Property 1: ListTargetProvider Round-Trip
*For any* 主机列表和 URL 列表,创建 ListTargetProvider 后迭代 iter_hosts() 和 iter_urls()
应该返回与输入相同的元素(顺序相同)。
**Validates: Requirements 3.1, 3.2**
"""
import pytest
from hypothesis import given, strategies as st, settings, assume
from apps.scan.providers.list_provider import ListTargetProvider
from apps.scan.providers.base import ProviderContext
# 生成有效域名的策略
def valid_domain_strategy():
"""生成有效的域名"""
# 生成简单的域名格式: subdomain.domain.tld
label = st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=2,
max_size=10
)
return st.builds(
lambda a, b, c: f"{a}.{b}.{c}",
label, label, st.sampled_from(['com', 'net', 'org', 'io'])
)
# 生成有效 IP 地址的策略
def valid_ip_strategy():
"""生成有效的 IPv4 地址"""
octet = st.integers(min_value=1, max_value=254)
return st.builds(
lambda a, b, c, d: f"{a}.{b}.{c}.{d}",
octet, octet, octet, octet
)
# 组合策略:域名或 IP
host_strategy = st.one_of(valid_domain_strategy(), valid_ip_strategy())
# 生成有效 URL 的策略
def valid_url_strategy():
"""生成有效的 URL"""
domain = valid_domain_strategy()
return st.builds(
lambda d, path: f"https://{d}/{path}" if path else f"https://{d}",
domain,
st.one_of(
st.just(""),
st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=1,
max_size=10
)
)
)
url_strategy = valid_url_strategy()
class TestListTargetProviderProperties:
"""ListTargetProvider 属性测试类"""
@given(hosts=st.lists(host_strategy, max_size=50))
@settings(max_examples=100)
def test_property_1_hosts_round_trip(self, hosts):
"""
Property 1: ListTargetProvider Round-Trip (hosts)
Feature: scan-target-provider, Property 1: ListTargetProvider Round-Trip
**Validates: Requirements 3.1, 3.2**
For any host list, creating a ListTargetProvider and iterating iter_hosts()
should return the same elements in the same order.
"""
# ListTargetProvider 使用 targets 参数,自动分类为 hosts/urls
provider = ListTargetProvider(targets=hosts)
result = list(provider.iter_hosts())
assert result == hosts
@given(urls=st.lists(url_strategy, max_size=50))
@settings(max_examples=100)
def test_property_1_urls_round_trip(self, urls):
"""
Property 1: ListTargetProvider Round-Trip (urls)
Feature: scan-target-provider, Property 1: ListTargetProvider Round-Trip
**Validates: Requirements 3.1, 3.2**
For any URL list, creating a ListTargetProvider and iterating iter_urls()
should return the same elements in the same order.
"""
# ListTargetProvider 使用 targets 参数,自动分类为 hosts/urls
provider = ListTargetProvider(targets=urls)
result = list(provider.iter_urls())
assert result == urls
@given(
hosts=st.lists(host_strategy, max_size=30),
urls=st.lists(url_strategy, max_size=30)
)
@settings(max_examples=100)
def test_property_1_combined_round_trip(self, hosts, urls):
"""
Property 1: ListTargetProvider Round-Trip (combined)
Feature: scan-target-provider, Property 1: ListTargetProvider Round-Trip
**Validates: Requirements 3.1, 3.2**
For any combination of hosts and URLs, both should round-trip correctly.
"""
# 合并 hosts 和 urlsListTargetProvider 会自动分类
combined = hosts + urls
provider = ListTargetProvider(targets=combined)
hosts_result = list(provider.iter_hosts())
urls_result = list(provider.iter_urls())
assert hosts_result == hosts
assert urls_result == urls
class TestListTargetProviderUnit:
"""ListTargetProvider 单元测试类"""
def test_empty_lists(self):
"""测试空列表返回空迭代器 - Requirements 3.5"""
provider = ListTargetProvider()
assert list(provider.iter_hosts()) == []
assert list(provider.iter_urls()) == []
def test_blacklist_filter_returns_none(self):
"""测试黑名单过滤器返回 None - Requirements 3.4"""
provider = ListTargetProvider(targets=["example.com"])
assert provider.get_blacklist_filter() is None
def test_target_id_association(self):
"""测试 target_id 关联 - Requirements 3.3"""
ctx = ProviderContext(target_id=123)
provider = ListTargetProvider(targets=["example.com"], context=ctx)
assert provider.target_id == 123
def test_context_propagation(self):
"""测试上下文传递"""
ctx = ProviderContext(target_id=456, scan_id=789)
provider = ListTargetProvider(targets=["example.com"], context=ctx)
assert provider.target_id == 456
assert provider.scan_id == 789
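
Reviewer note: a minimal usage sketch of the provider under test, relying only on behaviour the tests above assert (targets are auto-classified into hosts vs. URLs and round-trip in input order). The exact classification rule lives inside ListTargetProvider and is not shown in this hunk; the values below are illustrative.

# Sketch only - mirrors what the property and unit tests above pin down.
from apps.scan.providers.list_provider import ListTargetProvider
from apps.scan.providers.base import ProviderContext

provider = ListTargetProvider(
    targets=["example.com", "10.0.0.1", "https://example.com/login"],
    context=ProviderContext(target_id=1, scan_id=42),
)

assert list(provider.iter_hosts()) == ["example.com", "10.0.0.1"]    # hosts keep input order
assert list(provider.iter_urls()) == ["https://example.com/login"]   # URLs keep input order
assert provider.get_blacklist_filter() is None                       # list mode applies no blacklist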

View File

@@ -0,0 +1,180 @@
"""
PipelineTargetProvider property tests
Property 3: PipelineTargetProvider Round-Trip
*For any* StageOutput object, PipelineTargetProvider's iter_hosts() and iter_urls()
should return the same elements as the hosts and urls lists in the StageOutput.
**Validates: Requirements 5.1, 5.2**
"""
import pytest
from hypothesis import given, strategies as st, settings
from apps.scan.providers.pipeline_provider import PipelineTargetProvider, StageOutput
from apps.scan.providers.base import ProviderContext
# 生成有效域名的策略
def valid_domain_strategy():
"""生成有效的域名"""
label = st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=2,
max_size=10
)
return st.builds(
lambda a, b, c: f"{a}.{b}.{c}",
label, label, st.sampled_from(['com', 'net', 'org', 'io'])
)
# 生成有效 IP 地址的策略
def valid_ip_strategy():
"""生成有效的 IPv4 地址"""
octet = st.integers(min_value=1, max_value=254)
return st.builds(
lambda a, b, c, d: f"{a}.{b}.{c}.{d}",
octet, octet, octet, octet
)
# 组合策略:域名或 IP
host_strategy = st.one_of(valid_domain_strategy(), valid_ip_strategy())
# 生成有效 URL 的策略
def valid_url_strategy():
"""生成有效的 URL"""
domain = valid_domain_strategy()
return st.builds(
lambda d, path: f"https://{d}/{path}" if path else f"https://{d}",
domain,
st.one_of(
st.just(""),
st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=1,
max_size=10
)
)
)
url_strategy = valid_url_strategy()
class TestPipelineTargetProviderProperties:
"""PipelineTargetProvider 属性测试类"""
@given(hosts=st.lists(host_strategy, max_size=50))
@settings(max_examples=100)
def test_property_3_hosts_round_trip(self, hosts):
"""
Property 3: PipelineTargetProvider Round-Trip (hosts)
Feature: scan-target-provider, Property 3: PipelineTargetProvider Round-Trip
**Validates: Requirements 5.1, 5.2**
For any StageOutput with hosts, PipelineTargetProvider should return
the same hosts in the same order.
"""
stage_output = StageOutput(hosts=hosts)
provider = PipelineTargetProvider(previous_output=stage_output)
result = list(provider.iter_hosts())
assert result == hosts
@given(urls=st.lists(url_strategy, max_size=50))
@settings(max_examples=100)
def test_property_3_urls_round_trip(self, urls):
"""
Property 3: PipelineTargetProvider Round-Trip (urls)
Feature: scan-target-provider, Property 3: PipelineTargetProvider Round-Trip
**Validates: Requirements 5.1, 5.2**
For any StageOutput with urls, PipelineTargetProvider should return
the same urls in the same order.
"""
stage_output = StageOutput(urls=urls)
provider = PipelineTargetProvider(previous_output=stage_output)
result = list(provider.iter_urls())
assert result == urls
@given(
hosts=st.lists(host_strategy, max_size=30),
urls=st.lists(url_strategy, max_size=30)
)
@settings(max_examples=100)
def test_property_3_combined_round_trip(self, hosts, urls):
"""
Property 3: PipelineTargetProvider Round-Trip (combined)
Feature: scan-target-provider, Property 3: PipelineTargetProvider Round-Trip
**Validates: Requirements 5.1, 5.2**
For any StageOutput with both hosts and urls, both should round-trip correctly.
"""
stage_output = StageOutput(hosts=hosts, urls=urls)
provider = PipelineTargetProvider(previous_output=stage_output)
hosts_result = list(provider.iter_hosts())
urls_result = list(provider.iter_urls())
assert hosts_result == hosts
assert urls_result == urls
class TestPipelineTargetProviderUnit:
"""PipelineTargetProvider 单元测试类"""
def test_empty_stage_output(self):
"""测试空 StageOutput 返回空迭代器 - Requirements 5.5"""
stage_output = StageOutput()
provider = PipelineTargetProvider(previous_output=stage_output)
assert list(provider.iter_hosts()) == []
assert list(provider.iter_urls()) == []
def test_blacklist_filter_returns_none(self):
"""测试黑名单过滤器返回 None - Requirements 5.3"""
stage_output = StageOutput(hosts=["example.com"])
provider = PipelineTargetProvider(previous_output=stage_output)
assert provider.get_blacklist_filter() is None
def test_target_id_association(self):
"""测试 target_id 关联 - Requirements 5.4"""
stage_output = StageOutput(hosts=["example.com"])
provider = PipelineTargetProvider(previous_output=stage_output, target_id=123)
assert provider.target_id == 123
def test_context_propagation(self):
"""测试上下文传递"""
ctx = ProviderContext(target_id=456, scan_id=789)
stage_output = StageOutput(hosts=["example.com"])
provider = PipelineTargetProvider(previous_output=stage_output, context=ctx)
assert provider.target_id == 456
assert provider.scan_id == 789
def test_previous_output_property(self):
"""测试 previous_output 属性"""
stage_output = StageOutput(hosts=["example.com"], urls=["https://example.com"])
provider = PipelineTargetProvider(previous_output=stage_output)
assert provider.previous_output is stage_output
assert provider.previous_output.hosts == ["example.com"]
assert provider.previous_output.urls == ["https://example.com"]
def test_stage_output_with_metadata(self):
"""测试带元数据的 StageOutput"""
stage_output = StageOutput(
hosts=["example.com"],
urls=["https://example.com"],
new_targets=["new.example.com"],
stats={"count": 1},
success=True,
error=None
)
provider = PipelineTargetProvider(previous_output=stage_output)
assert list(provider.iter_hosts()) == ["example.com"]
assert list(provider.iter_urls()) == ["https://example.com"]
assert provider.previous_output.new_targets == ["new.example.com"]
assert provider.previous_output.stats == {"count": 1}
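
Reviewer note: a compact sketch of handing one stage's results to the next through this provider. All field and attribute names are taken from the tests above; the values are illustrative.

# Sketch - wiring a previous stage's StageOutput into the next stage.
from apps.scan.providers.pipeline_provider import PipelineTargetProvider, StageOutput

previous = StageOutput(
    hosts=["a.example.com", "b.example.com"],
    urls=["https://a.example.com"],
    new_targets=[],
    stats={"count": 3},
    success=True,
)

provider = PipelineTargetProvider(previous_output=previous)
assert list(provider.iter_hosts()) == previous.hosts    # same elements, same order
assert list(provider.iter_urls()) == previous.urls
assert provider.get_blacklist_filter() is None          # pipeline mode applies no blacklist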

View File

@@ -0,0 +1,191 @@
"""
SnapshotTargetProvider unit tests
"""
import pytest
from unittest.mock import Mock, patch
from apps.scan.providers import SnapshotTargetProvider, ProviderContext
class TestSnapshotTargetProvider:
"""SnapshotTargetProvider 测试类"""
def test_init_with_scan_id_and_type(self):
"""测试初始化"""
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain"
)
assert provider.scan_id == 100
assert provider.snapshot_type == "subdomain"
assert provider.target_id is None # 默认 context
def test_init_with_context(self):
"""测试带 context 初始化"""
ctx = ProviderContext(target_id=1, scan_id=100)
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain",
context=ctx
)
assert provider.scan_id == 100
assert provider.target_id == 1
assert provider.snapshot_type == "subdomain"
@patch('apps.asset.services.snapshot.SubdomainSnapshotsService')
def test_iter_hosts_subdomain(self, mock_service_class):
"""测试从子域名快照迭代主机"""
# Mock service
mock_service = Mock()
mock_service.iter_subdomain_names_by_scan.return_value = iter([
"a.example.com",
"b.example.com"
])
mock_service_class.return_value = mock_service
# 创建 provider
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain"
)
# 迭代主机
hosts = list(provider.iter_hosts())
assert hosts == ["a.example.com", "b.example.com"]
mock_service.iter_subdomain_names_by_scan.assert_called_once_with(
scan_id=100,
chunk_size=1000
)
@patch('apps.asset.services.snapshot.HostPortMappingSnapshotsService')
def test_iter_hosts_host_port(self, mock_service_class):
"""测试从主机端口映射快照迭代主机"""
# Mock queryset
mock_mapping1 = Mock()
mock_mapping1.host = "example.com"
mock_mapping1.port = 80
mock_mapping2 = Mock()
mock_mapping2.host = "example.com"
mock_mapping2.port = 443
mock_queryset = Mock()
mock_queryset.iterator.return_value = iter([mock_mapping1, mock_mapping2])
# Mock service
mock_service = Mock()
mock_service.get_by_scan.return_value = mock_queryset
mock_service_class.return_value = mock_service
# 创建 provider
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="host_port"
)
# 迭代主机
hosts = list(provider.iter_hosts())
assert hosts == ["example.com:80", "example.com:443"]
mock_service.get_by_scan.assert_called_once_with(scan_id=100)
@patch('apps.asset.services.snapshot.WebsiteSnapshotsService')
def test_iter_urls_website(self, mock_service_class):
"""测试从网站快照迭代 URL"""
# Mock service
mock_service = Mock()
mock_service.iter_website_urls_by_scan.return_value = iter([
"http://example.com",
"https://example.com"
])
mock_service_class.return_value = mock_service
# 创建 provider
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="website"
)
# 迭代 URL
urls = list(provider.iter_urls())
assert urls == ["http://example.com", "https://example.com"]
mock_service.iter_website_urls_by_scan.assert_called_once_with(
scan_id=100,
chunk_size=1000
)
@patch('apps.asset.services.snapshot.EndpointSnapshotsService')
def test_iter_urls_endpoint(self, mock_service_class):
"""测试从端点快照迭代 URL"""
# Mock queryset
mock_endpoint1 = Mock()
mock_endpoint1.url = "http://example.com/api/v1"
mock_endpoint2 = Mock()
mock_endpoint2.url = "http://example.com/api/v2"
mock_queryset = Mock()
mock_queryset.iterator.return_value = iter([mock_endpoint1, mock_endpoint2])
# Mock service
mock_service = Mock()
mock_service.get_by_scan.return_value = mock_queryset
mock_service_class.return_value = mock_service
# 创建 provider
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="endpoint"
)
# 迭代 URL
urls = list(provider.iter_urls())
assert urls == ["http://example.com/api/v1", "http://example.com/api/v2"]
mock_service.get_by_scan.assert_called_once_with(scan_id=100)
def test_iter_hosts_unsupported_type(self):
"""测试不支持的快照类型iter_hosts"""
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="website" # website 不支持 iter_hosts
)
hosts = list(provider.iter_hosts())
assert hosts == []
def test_iter_urls_unsupported_type(self):
"""测试不支持的快照类型iter_urls"""
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain" # subdomain 不支持 iter_urls
)
urls = list(provider.iter_urls())
assert urls == []
def test_get_blacklist_filter(self):
"""测试黑名单过滤器(快照模式不使用黑名单)"""
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain"
)
assert provider.get_blacklist_filter() is None
def test_context_propagation(self):
"""测试上下文传递"""
ctx = ProviderContext(target_id=456, scan_id=789)
provider = SnapshotTargetProvider(
scan_id=100, # 会被 context 覆盖
snapshot_type="subdomain",
context=ctx
)
assert provider.target_id == 456
assert provider.scan_id == 100 # scan_id 在 __init__ 中被设置
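
Reviewer note: the tests above effectively define a small dispatch table between snapshot_type and the two iterators, with unsupported pairings yielding nothing. Restated here as a summary for the reviewer; this is not the implementation.

# Summary of the behaviour asserted above, not actual implementation code.
from apps.scan.providers import SnapshotTargetProvider

# snapshot_type -> which iterator it feeds
#   "subdomain" -> iter_hosts()  (subdomain names from SubdomainSnapshotsService)
#   "host_port" -> iter_hosts()  (HostPortMapping snapshots, formatted "host:port")
#   "website"   -> iter_urls()   (WebSite snapshot URLs)
#   "endpoint"  -> iter_urls()   (Endpoint snapshot URLs)

provider = SnapshotTargetProvider(scan_id=100, snapshot_type="website")
assert list(provider.iter_hosts()) == []   # unsupported pairing -> empty iterator
assert provider.get_blacklist_filter() is None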

View File

@@ -12,7 +12,7 @@
import ipaddress
import logging
from pathlib import Path
from typing import Dict, Any, Optional, List, Iterator, Tuple, Callable
from typing import Dict, Any, Optional, List, Iterator, Tuple
from django.db.models import QuerySet
@@ -485,8 +485,7 @@ class TargetExportService:
"""
from apps.targets.services import TargetService
from apps.targets.models import Target
from apps.asset.services.asset.subdomain_service import SubdomainService
output_file = Path(output_path)
output_file.parent.mkdir(parents=True, exist_ok=True)

View File

@@ -1,36 +1,48 @@
"""
导出站点 URL 到 TXT 文件的 Task
使用 export_urls_with_fallback 用例函数处理回退链逻辑
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库导出
2. Provider 模式:使用 TargetProvider 从任意数据源导出
数据源: WebSite.url → Default
"""
import logging
from typing import Optional
from pathlib import Path
from prefect import task
from apps.scan.services.target_export_service import (
export_urls_with_fallback,
DataSource,
)
from apps.scan.providers import TargetProvider
logger = logging.getLogger(__name__)
@task(name="export_sites")
def export_sites_task(
target_id: int,
output_file: str,
target_id: Optional[int] = None,
output_file: str = "",
provider: Optional[TargetProvider] = None,
batch_size: int = 1000,
) -> dict:
"""
导出目标下的所有站点 URL 到 TXT 文件
数据源优先级(回退链)
支持两种模式
1. 传统模式(向后兼容):传入 target_id从数据库导出
2. Provider 模式:传入 provider从任意数据源导出
数据源优先级(回退链,仅传统模式):
1. WebSite 表 - 站点级别 URL
2. 默认生成 - 根据 Target 类型生成 http(s)://target_name
Args:
target_id: 目标 ID
target_id: 目标 ID(传统模式,向后兼容)
output_file: 输出文件路径(绝对路径)
provider: TargetProvider 实例(新模式)
batch_size: 每次读取的批次大小,默认 1000
Returns:
@@ -44,6 +56,17 @@ def export_sites_task(
ValueError: 参数错误
IOError: 文件写入失败
"""
# 参数验证:至少提供一个
if target_id is None and provider is None:
raise ValueError("必须提供 target_id 或 provider 参数之一")
# Provider 模式:使用 TargetProvider 导出
if provider is not None:
logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
return _export_with_provider(output_file, provider)
# 传统模式:使用 export_urls_with_fallback
logger.info("使用传统模式 - Target ID: %d", target_id)
result = export_urls_with_fallback(
target_id=target_id,
output_file=output_file,
@@ -62,3 +85,32 @@ def export_sites_task(
'output_file': result['output_file'],
'total_count': result['total_count'],
}
def _export_with_provider(output_file: str, provider: TargetProvider) -> dict:
"""使用 Provider 导出 URL"""
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
total_count = 0
blacklist_filter = provider.get_blacklist_filter()
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for url in provider.iter_urls():
# 应用黑名单过滤(如果有)
if blacklist_filter and not blacklist_filter.is_allowed(url):
continue
f.write(f"{url}\n")
total_count += 1
if total_count % 1000 == 0:
logger.info("已导出 %d 个 URL...", total_count)
logger.info("✓ URL 导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))
return {
'success': True,
'output_file': str(output_path),
'total_count': total_count,
}
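
Reviewer note: together with ListTargetProvider, the widened signature lets a caller feed this task from memory instead of the database. A hedged sketch of the provider-mode call, using the export_sites_task defined above; the output path and URLs are made up for illustration.

# Sketch - provider mode: no target_id, data comes from an in-memory list.
from apps.scan.providers import ListTargetProvider

provider = ListTargetProvider(targets=["https://a.example.com", "https://b.example.com"])

result = export_sites_task(
    output_file="/tmp/sites.txt",   # illustrative path
    provider=provider,              # target_id omitted -> provider mode
)
assert result["success"] is True
assert result["total_count"] == 2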

View File

@@ -1,11 +1,16 @@
"""
导出 URL 任务
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库导出
2. Provider 模式:使用 TargetProvider 从任意数据源导出
用于指纹识别前导出目标下的 URL 到文件
使用 export_urls_with_fallback 用例函数处理回退链逻辑
"""
import logging
from typing import Optional
from pathlib import Path
from prefect import task
@@ -13,33 +18,51 @@ from apps.scan.services.target_export_service import (
export_urls_with_fallback,
DataSource,
)
from apps.scan.providers import TargetProvider, DatabaseTargetProvider
logger = logging.getLogger(__name__)
@task(name="export_urls_for_fingerprint")
def export_urls_for_fingerprint_task(
target_id: int,
output_file: str,
target_id: Optional[int] = None,
output_file: str = "",
source: str = 'website', # 保留参数,兼容旧调用(实际值由回退链决定)
provider: Optional[TargetProvider] = None,
batch_size: int = 1000
) -> dict:
"""
导出目标下的 URL 到文件(用于指纹识别)
数据源优先级(回退链)
支持两种模式
1. 传统模式(向后兼容):传入 target_id从数据库导出
2. Provider 模式:传入 provider从任意数据源导出
数据源优先级(回退链,仅传统模式):
1. WebSite 表 - 站点级别 URL
2. 默认生成 - 根据 Target 类型生成 http(s)://target_name
Args:
target_id: 目标 ID
target_id: 目标 ID(传统模式,向后兼容)
output_file: 输出文件路径
source: 数据源类型(保留参数,兼容旧调用,实际值由回退链决定)
provider: TargetProvider 实例(新模式)
batch_size: 批量读取大小
Returns:
dict: {'output_file': str, 'total_count': int, 'source': str}
"""
# 参数验证:至少提供一个
if target_id is None and provider is None:
raise ValueError("必须提供 target_id 或 provider 参数之一")
# Provider 模式:使用 TargetProvider 导出
if provider is not None:
logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
return _export_with_provider(output_file, provider)
# 传统模式:使用 export_urls_with_fallback
logger.info("使用传统模式 - Target ID: %d", target_id)
result = export_urls_with_fallback(
target_id=target_id,
output_file=output_file,
@@ -58,3 +81,32 @@ def export_urls_for_fingerprint_task(
'total_count': result['total_count'],
'source': result['source'],
}
def _export_with_provider(output_file: str, provider: TargetProvider) -> dict:
"""使用 Provider 导出 URL"""
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
total_count = 0
blacklist_filter = provider.get_blacklist_filter()
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for url in provider.iter_urls():
# 应用黑名单过滤(如果有)
if blacklist_filter and not blacklist_filter.is_allowed(url):
continue
f.write(f"{url}\n")
total_count += 1
if total_count % 1000 == 0:
logger.info("已导出 %d 个 URL...", total_count)
logger.info("✓ URL 导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))
return {
'output_file': str(output_path),
'total_count': total_count,
'source': 'provider',
}
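
Reviewer note: the only contract the helpers above place on get_blacklist_filter() is an is_allowed(url) method; the built-in list/pipeline providers simply return None. Below is an illustrative filter a custom provider could return; the real BlacklistFilter in apps.common.utils is not shown in this hunk.

# Illustrative only - shows the interface _export_with_provider relies on.
class KeywordBlacklistFilter:
    """Reject any URL containing one of the blocked substrings."""

    def __init__(self, blocked: list[str]):
        self.blocked = list(blocked)

    def is_allowed(self, url: str) -> bool:
        return not any(token in url for token in self.blocked)

# A custom TargetProvider could hand this back from get_blacklist_filter();
# URLs failing is_allowed() are skipped before being written to the export file.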

View File

@@ -1,7 +1,9 @@
"""
导出主机列表到 TXT 文件的 Task
使用 TargetExportService.export_hosts() 统一处理导出逻辑
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库导出
2. Provider 模式:使用 TargetProvider 从任意数据源导出
根据 Target 类型决定导出内容:
- DOMAIN: 从 Subdomain 表导出子域名
@@ -9,57 +11,89 @@
- CIDR: 展开 CIDR 范围内的所有 IP
"""
import logging
from pathlib import Path
from typing import Optional
from prefect import task
from apps.scan.services.target_export_service import create_export_service
from apps.scan.providers import DatabaseTargetProvider, TargetProvider
logger = logging.getLogger(__name__)
@task(name="export_hosts")
def export_hosts_task(
target_id: int,
output_file: str,
batch_size: int = 1000
target_id: Optional[int] = None,
provider: Optional[TargetProvider] = None,
) -> dict:
"""
导出主机列表到 TXT 文件
支持两种模式:
1. 传统模式(向后兼容):传入 target_id从数据库导出
2. Provider 模式:传入 provider从任意数据源导出
根据 Target 类型自动决定导出内容:
- DOMAIN: 从 Subdomain 表导出子域名(流式处理,支持 10万+ 域名)
- IP: 直接写入 target.name单个 IP
- CIDR: 展开 CIDR 范围内的所有可用 IP
Args:
target_id: 目标 ID
output_file: 输出文件路径(绝对路径)
batch_size: 每次读取的批次大小,默认 1000仅对 DOMAIN 类型有效
target_id: 目标 ID传统模式向后兼容
provider: TargetProvider 实例(新模式)
Returns:
dict: {
'success': bool,
'output_file': str,
'total_count': int,
'target_type': str
'target_type': str # 仅传统模式返回
}
Raises:
ValueError: Target 不存在
ValueError: 参数错误target_id 和 provider 都未提供)
IOError: 文件写入失败
"""
# 使用工厂函数创建导出服务
export_service = create_export_service(target_id)
result = export_service.export_hosts(
target_id=target_id,
output_path=output_file,
batch_size=batch_size
)
# 保持返回值格式不变(向后兼容)
return {
'success': result['success'],
'output_file': result['output_file'],
'total_count': result['total_count'],
'target_type': result['target_type']
if target_id is None and provider is None:
raise ValueError("必须提供 target_id 或 provider 参数之一")
# 向后兼容:如果没有提供 provider使用 target_id 创建 DatabaseTargetProvider
use_legacy_mode = provider is None
if use_legacy_mode:
logger.info("使用传统模式 - Target ID: %d", target_id)
provider = DatabaseTargetProvider(target_id=target_id)
else:
logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
# 确保输出目录存在
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
# 使用 Provider 导出主机列表iter_hosts 内部已处理黑名单过滤)
total_count = 0
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for host in provider.iter_hosts():
f.write(f"{host}\n")
total_count += 1
if total_count % 1000 == 0:
logger.info("已导出 %d 个主机...", total_count)
logger.info("✓ 主机列表导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))
result = {
'success': True,
'output_file': str(output_path),
'total_count': total_count,
}
# 传统模式:保持返回值格式不变(向后兼容)
if use_legacy_mode:
from apps.targets.services import TargetService
target = TargetService().get_target(target_id)
result['target_type'] = target.type if target else 'unknown'
return result
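
Reviewer note: the backward-compatibility path above boils down to a simple equivalence, sketched here with illustrative IDs and paths: a legacy call constructs a DatabaseTargetProvider internally and additionally returns target_type, while an explicit provider call skips the database lookup.

# Sketch of the two call styles accepted by export_hosts_task.
from apps.scan.providers import ListTargetProvider

legacy = export_hosts_task(target_id=123, output_file="/tmp/hosts.txt")
# -> internally: provider = DatabaseTargetProvider(target_id=123)
# -> result additionally carries 'target_type' (looked up via TargetService)

explicit = export_hosts_task(
    provider=ListTargetProvider(targets=["a.example.com", "b.example.com"]),
    output_file="/tmp/hosts.txt",
)
# -> result has 'success', 'output_file', 'total_count' but no 'target_type'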

View File

@@ -1,8 +1,9 @@
"""
导出站点URL到文件的Task
直接使用 HostPortMapping 表查询 host+port 组合拼接成URL格式写入文件
使用 TargetExportService.generate_default_urls() 处理默认值回退逻辑
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库导出
2. Provider 模式:使用 TargetProvider 从任意数据源导出
特殊逻辑:
- 80 端口:只生成 HTTP URL省略端口号
@@ -10,6 +11,7 @@
- 其他端口:生成 HTTP 和 HTTPS 两个URL带端口号
"""
import logging
from typing import Optional
from pathlib import Path
from prefect import task
@@ -17,6 +19,7 @@ from apps.asset.services import HostPortMappingService
from apps.scan.services.target_export_service import create_export_service
from apps.common.services import BlacklistService
from apps.common.utils import BlacklistFilter
from apps.scan.providers import TargetProvider, DatabaseTargetProvider, ProviderContext
logger = logging.getLogger(__name__)
@@ -39,26 +42,30 @@ def _generate_urls_from_port(host: str, port: int) -> list[str]:
@task(name="export_site_urls")
def export_site_urls_task(
target_id: int,
output_file: str,
target_id: Optional[int] = None,
provider: Optional[TargetProvider] = None,
batch_size: int = 1000
) -> dict:
"""
导出目标下的所有站点URL到文件(基于 HostPortMapping 表)
导出目标下的所有站点URL到文件
数据源: HostPortMapping (host + port) → Default
支持两种模式:
1. 传统模式(向后兼容):传入 target_id从 HostPortMapping 表导出
2. Provider 模式:传入 provider从任意数据源导出
特殊逻辑:
传统模式特殊逻辑:
- 80 端口:只生成 HTTP URL省略端口号
- 443 端口:只生成 HTTPS URL省略端口号
- 其他端口:生成 HTTP 和 HTTPS 两个URL带端口号
回退逻辑:
回退逻辑(仅传统模式)
- 如果 HostPortMapping 为空,使用 generate_default_urls() 生成默认 URL
Args:
target_id: 目标ID
output_file: 输出文件路径(绝对路径)
target_id: 目标ID传统模式向后兼容
provider: TargetProvider 实例(新模式)
batch_size: 每次处理的批次大小
Returns:
@@ -66,14 +73,62 @@ def export_site_urls_task(
'success': bool,
'output_file': str,
'total_urls': int,
'association_count': int, # 主机端口关联数量
'source': str, # 数据来源: "host_port" | "default"
'association_count': int, # 主机端口关联数量(仅传统模式)
'source': str, # 数据来源: "host_port" | "default" | "provider"
}
Raises:
ValueError: 参数错误
IOError: 文件写入失败
"""
# 参数验证:至少提供一个
if target_id is None and provider is None:
raise ValueError("必须提供 target_id 或 provider 参数之一")
# 向后兼容:如果没有提供 provider使用传统模式
if provider is None:
logger.info("使用传统模式 - Target ID: %d, 输出文件: %s", target_id, output_file)
return _export_site_urls_legacy(target_id, output_file, batch_size)
# Provider 模式
logger.info("使用 Provider 模式 - Provider: %s, 输出文件: %s", type(provider).__name__, output_file)
# 确保输出目录存在
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
# 使用 Provider 导出 URL 列表
total_urls = 0
blacklist_filter = provider.get_blacklist_filter()
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for url in provider.iter_urls():
# 应用黑名单过滤(如果有)
if blacklist_filter and not blacklist_filter.is_allowed(url):
continue
f.write(f"{url}\n")
total_urls += 1
if total_urls % 1000 == 0:
logger.info("已导出 %d 个URL...", total_urls)
logger.info("✓ URL导出完成 - 总数: %d, 文件: %s", total_urls, str(output_path))
return {
'success': True,
'output_file': str(output_path),
'total_urls': total_urls,
'source': 'provider',
}
def _export_site_urls_legacy(target_id: int, output_file: str, batch_size: int) -> dict:
"""
传统模式:从 HostPortMapping 表导出 URL
保持原有逻辑不变,确保向后兼容
"""
logger.info("开始统计站点URL - Target ID: %d, 输出文件: %s", target_id, output_file)
# 确保输出目录存在
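
Reviewer note: the port rules described in the docstring are implemented by _generate_urls_from_port, of which only the signature appears in this hunk. Below is a sketch consistent with the documented behaviour and with the expectations in the backward-compatibility tests; the real body may differ in detail.

# Sketch of the documented rule, not the actual implementation.
def _generate_urls_from_port(host: str, port: int) -> list[str]:
    if port == 80:
        return [f"http://{host}"]                 # HTTP only, port omitted
    if port == 443:
        return [f"https://{host}"]                # HTTPS only, port omitted
    return [f"http://{host}:{port}", f"https://{host}:{port}"]  # both, with explicit port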

View File

@@ -20,63 +20,40 @@ Note:
"""
import logging
import uuid
import subprocess
from pathlib import Path
import uuid
from datetime import datetime
from prefect import task
from pathlib import Path
from typing import List
from prefect import task
logger = logging.getLogger(__name__)
# 注:使用纯系统命令实现,无需 Python 缓冲区配置
# 工具amass/subfinder输出已是小写且标准化
@task(
name='merge_and_deduplicate',
retries=1,
log_prints=True
)
def merge_and_validate_task(
result_files: List[str],
result_dir: str
) -> str:
"""
合并扫描结果并去重(高性能流式处理)
流程:
1. 使用 LC_ALL=C sort -u 直接处理多文件
2. 排序去重一步完成
3. 返回去重后的文件路径
命令:LC_ALL=C sort -u file1 file2 file3 -o output
注:工具输出已标准化(小写,无空行),无需额外处理
Args:
result_files: 结果文件路径列表
result_dir: 结果目录
Returns:
str: 去重后的域名文件路径
Raises:
RuntimeError: 处理失败
Performance:
- 纯系统命令(C语言实现),单进程极简
- LC_ALL=C: 字节序比较
- sort -u: 直接处理多文件(无管道开销)
Design:
- 极简单命令,无冗余处理
- 单进程直接执行(无管道/重定向开销)
- 内存占用仅在 sort 阶段(外部排序,不会 OOM
"""
logger.info("开始合并并去重 %d 个结果文件(系统命令优化)", len(result_files))
result_path = Path(result_dir)
# 验证文件存在性
def _count_file_lines(file_path: str) -> int:
"""使用 wc -l 统计文件行数,失败时返回 0"""
try:
result = subprocess.run(
["wc", "-l", file_path],
check=True,
capture_output=True,
text=True,
)
return int(result.stdout.strip().split()[0])
except (subprocess.CalledProcessError, ValueError, IndexError):
return 0
def _calculate_timeout(total_lines: int) -> int:
"""根据总行数计算超时时间(每行约 0.1 秒,最少 600 秒)"""
if total_lines <= 0:
return 3600
return max(600, int(total_lines * 0.1))
def _validate_input_files(result_files: List[str]) -> List[str]:
"""验证输入文件存在性,返回有效文件列表"""
valid_files = []
for file_path_str in result_files:
file_path = Path(file_path_str)
@@ -84,112 +61,67 @@ def merge_and_validate_task(
valid_files.append(str(file_path))
else:
logger.warning("结果文件不存在: %s", file_path)
return valid_files
@task(name='merge_and_deduplicate', retries=1, log_prints=True)
def merge_and_validate_task(result_files: List[str], result_dir: str) -> str:
"""
合并扫描结果并去重(高性能流式处理)
使用 LC_ALL=C sort -u 直接处理多文件,排序去重一步完成。
Args:
result_files: 结果文件路径列表
result_dir: 结果目录
Returns:
去重后的域名文件路径
Raises:
RuntimeError: 处理失败
"""
logger.info("开始合并并去重 %d 个结果文件", len(result_files))
valid_files = _validate_input_files(result_files)
if not valid_files:
raise RuntimeError("所有结果文件都不存在")
# 生成输出文件路径
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
short_uuid = uuid.uuid4().hex[:4]
merged_file = result_path / f"merged_{timestamp}_{short_uuid}.txt"
merged_file = Path(result_dir) / f"merged_{timestamp}_{short_uuid}.txt"
# 计算超时时间
total_lines = sum(_count_file_lines(f) for f in valid_files)
timeout = _calculate_timeout(total_lines)
logger.info("合并去重: 输入总行数=%d, timeout=%d", total_lines, timeout)
# 执行合并去重命令
cmd = f"LC_ALL=C sort -u {' '.join(valid_files)} -o {merged_file}"
logger.debug("执行命令: %s", cmd)
try:
# ==================== 使用系统命令一步完成:排序去重 ====================
# LC_ALL=C: 使用字节序比较(比locale快20-30%)
# sort -u: 直接处理多文件,排序去重
# -o: 安全输出(比重定向更可靠)
cmd = f"LC_ALL=C sort -u {' '.join(valid_files)} -o {merged_file}"
logger.debug("执行命令: %s", cmd)
subprocess.run(cmd, shell=True, check=True, timeout=timeout)
except subprocess.TimeoutExpired as exc:
raise RuntimeError("合并去重超时,请检查数据量或系统资源") from exc
except subprocess.CalledProcessError as exc:
raise RuntimeError(f"系统命令执行失败: {exc.stderr or exc}") from exc
# 按输入文件总行数动态计算超时时间
total_lines = 0
for file_path in valid_files:
try:
line_count_proc = subprocess.run(
["wc", "-l", file_path],
check=True,
capture_output=True,
text=True,
)
total_lines += int(line_count_proc.stdout.strip().split()[0])
except (subprocess.CalledProcessError, ValueError, IndexError):
continue
# 验证输出文件
if not merged_file.exists():
raise RuntimeError("合并文件未被创建")
timeout = 3600
if total_lines > 0:
# 按行数线性计算:每行约 0.1 秒
base_per_line = 0.1
est = int(total_lines * base_per_line)
timeout = max(600, est)
unique_count = _count_file_lines(str(merged_file))
if unique_count == 0:
# 降级为 Python 统计
with open(merged_file, 'r', encoding='utf-8') as f:
unique_count = sum(1 for _ in f)
logger.info(
"Subdomain 合并去重 timeout 自动计算: 输入总行数=%d, timeout=%d",
total_lines,
timeout,
)
if unique_count == 0:
raise RuntimeError("未找到任何有效域名")
result = subprocess.run(
cmd,
shell=True,
check=True,
timeout=timeout
)
logger.debug("✓ 合并去重完成")
# ==================== 统计结果 ====================
if not merged_file.exists():
raise RuntimeError("合并文件未被创建")
# 统计行数(使用系统命令提升大文件性能)
try:
line_count_proc = subprocess.run(
["wc", "-l", str(merged_file)],
check=True,
capture_output=True,
text=True
)
unique_count = int(line_count_proc.stdout.strip().split()[0])
except (subprocess.CalledProcessError, ValueError, IndexError) as e:
logger.warning(
"wc -l 统计失败(文件: %s),降级为 Python 逐行统计 - 错误: %s",
merged_file, e
)
unique_count = 0
with open(merged_file, 'r', encoding='utf-8') as file_obj:
for _ in file_obj:
unique_count += 1
if unique_count == 0:
raise RuntimeError("未找到任何有效域名")
file_size = merged_file.stat().st_size
logger.info(
"✓ 合并去重完成 - 去重后: %d 个域名, 文件大小: %.2f KB",
unique_count,
file_size / 1024
)
return str(merged_file)
except subprocess.TimeoutExpired:
error_msg = "合并去重超时(>60分钟请检查数据量或系统资源"
logger.warning(error_msg) # 超时是可预期的
raise RuntimeError(error_msg)
except subprocess.CalledProcessError as e:
error_msg = f"系统命令执行失败: {e.stderr if e.stderr else str(e)}"
logger.warning(error_msg) # 超时是可预期的
raise RuntimeError(error_msg) from e
except IOError as e:
error_msg = f"文件读写失败: {e}"
logger.warning(error_msg) # 超时是可预期的
raise RuntimeError(error_msg) from e
except Exception as e:
error_msg = f"合并去重失败: {e}"
logger.error(error_msg, exc_info=True)
raise
file_size_kb = merged_file.stat().st_size / 1024
logger.info("✓ 合并去重完成 - 去重后: %d 个域名, 文件大小: %.2f KB", unique_count, file_size_kb)
return str(merged_file)
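
Reviewer note: a worked example of the timeout heuristic and the single external-sort invocation used above; the input file names are illustrative.

# Worked example of _calculate_timeout as defined above.
assert _calculate_timeout(0) == 3600        # unknown size -> conservative 1 hour
assert _calculate_timeout(3_000) == 600     # 300 s estimate, floored at the 600 s minimum
assert _calculate_timeout(50_000) == 5_000  # 0.1 s per input line

# The merge itself is one external-sort call, roughly:
#   LC_ALL=C sort -u subfinder.txt sublist3r.txt -o merged_20260110_093000_ab12.txt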

View File

@@ -1,7 +1,7 @@
"""
运行扫描工具任务
负责运行单个子域名扫描工具(amass、subfinder 等)
负责运行单个子域名扫描工具subfinder、sublist3r 等)
"""
import logging
@@ -58,7 +58,7 @@ def run_subdomain_discovery_task(
timeout=timeout,
log_file=log_file # 明确指定日志文件路径
)
# 验证输出文件是否生成
if not output_file_path.exists():
logger.warning(

View File

@@ -0,0 +1,240 @@
"""
Task backward-compatibility tests
Property 8: Task Backward Compatibility
*For any* task invocation, when only the target_id parameter is provided the task should create
a DatabaseTargetProvider and use it for data access, behaving exactly as it did before the refactor.
**Validates: Requirements 6.1, 6.2, 6.4, 6.5**
"""
import tempfile
import pytest
from pathlib import Path
from unittest.mock import patch, MagicMock
from hypothesis import given, strategies as st, settings
from apps.scan.tasks.port_scan.export_hosts_task import export_hosts_task
from apps.scan.tasks.site_scan.export_site_urls_task import export_site_urls_task
from apps.scan.providers import ListTargetProvider
# 生成有效域名的策略
def valid_domain_strategy():
"""生成有效的域名"""
label = st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=2,
max_size=10
)
return st.builds(
lambda a, b, c: f"{a}.{b}.{c}",
label, label, st.sampled_from(['com', 'net', 'org', 'io'])
)
class TestExportHostsTaskBackwardCompatibility:
"""export_hosts_task 向后兼容性测试"""
@given(
target_id=st.integers(min_value=1, max_value=1000),
hosts=st.lists(valid_domain_strategy(), min_size=1, max_size=10)
)
@settings(max_examples=50, deadline=None)
def test_property_8_legacy_mode_creates_database_provider(self, target_id, hosts):
"""
Property 8: Task Backward Compatibility (export_hosts_task)
Feature: scan-target-provider, Property 8: Task Backward Compatibility
**Validates: Requirements 6.1, 6.2, 6.4, 6.5**
For any target_id, when calling export_hosts_task with only target_id,
it should create a DatabaseTargetProvider and use it for data access.
"""
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
output_file = f.name
try:
# Mock Target 和 SubdomainService
mock_target = MagicMock()
mock_target.type = 'domain'
mock_target.name = hosts[0]
with patch('apps.scan.tasks.port_scan.export_hosts_task.DatabaseTargetProvider') as mock_provider_class, \
patch('apps.targets.services.TargetService') as mock_target_service:
# 创建 mock provider 实例
mock_provider = MagicMock()
mock_provider.iter_hosts.return_value = iter(hosts)
mock_provider.get_blacklist_filter.return_value = None
mock_provider_class.return_value = mock_provider
# Mock TargetService
mock_target_service.return_value.get_target.return_value = mock_target
# 调用任务(传统模式:只传 target_id
result = export_hosts_task(
output_file=output_file,
target_id=target_id
)
# 验证:应该创建了 DatabaseTargetProvider
mock_provider_class.assert_called_once_with(target_id=target_id)
# 验证:返回值包含必需字段
assert result['success'] is True
assert result['output_file'] == output_file
assert result['total_count'] == len(hosts)
assert 'target_type' in result # 传统模式应该返回 target_type
# 验证:文件内容正确
with open(output_file, 'r') as f:
lines = [line.strip() for line in f.readlines()]
assert lines == hosts
finally:
Path(output_file).unlink(missing_ok=True)
def test_legacy_mode_with_provider_parameter(self):
"""测试当同时提供 target_id 和 provider 时provider 优先"""
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
output_file = f.name
try:
hosts = ['example.com', 'test.com']
provider = ListTargetProvider(targets=hosts)
# 调用任务(同时提供 target_id 和 provider
result = export_hosts_task(
output_file=output_file,
target_id=123, # 应该被忽略
provider=provider
)
# 验证:使用了 provider
assert result['success'] is True
assert result['total_count'] == len(hosts)
assert 'target_type' not in result # Provider 模式不返回 target_type
# 验证:文件内容正确
with open(output_file, 'r') as f:
lines = [line.strip() for line in f.readlines()]
assert lines == hosts
finally:
Path(output_file).unlink(missing_ok=True)
def test_error_when_no_parameters(self):
"""测试当 target_id 和 provider 都未提供时抛出错误"""
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
output_file = f.name
try:
with pytest.raises(ValueError, match="必须提供 target_id 或 provider 参数之一"):
export_hosts_task(output_file=output_file)
finally:
Path(output_file).unlink(missing_ok=True)
class TestExportSiteUrlsTaskBackwardCompatibility:
"""export_site_urls_task 向后兼容性测试"""
def test_property_8_legacy_mode_uses_traditional_logic(self):
"""
Property 8: Task Backward Compatibility (export_site_urls_task)
Feature: scan-target-provider, Property 8: Task Backward Compatibility
**Validates: Requirements 6.1, 6.2, 6.4, 6.5**
When calling export_site_urls_task with only target_id,
it should use the traditional logic (_export_site_urls_legacy).
"""
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
output_file = f.name
try:
target_id = 123
# Mock HostPortMappingService
mock_associations = [
{'host': 'example.com', 'port': 80},
{'host': 'test.com', 'port': 443},
]
with patch('apps.scan.tasks.site_scan.export_site_urls_task.HostPortMappingService') as mock_service_class, \
patch('apps.scan.tasks.site_scan.export_site_urls_task.BlacklistService') as mock_blacklist_service:
# Mock HostPortMappingService
mock_service = MagicMock()
mock_service.iter_host_port_by_target.return_value = iter(mock_associations)
mock_service_class.return_value = mock_service
# Mock BlacklistService
mock_blacklist = MagicMock()
mock_blacklist.get_rules.return_value = []
mock_blacklist_service.return_value = mock_blacklist
# 调用任务(传统模式:只传 target_id
result = export_site_urls_task(
output_file=output_file,
target_id=target_id
)
# 验证:返回值包含传统模式的字段
assert result['success'] is True
assert result['output_file'] == output_file
assert result['total_urls'] == 2 # 80 端口生成 1 个 URL443 端口生成 1 个 URL
assert 'association_count' in result # 传统模式应该返回 association_count
assert result['association_count'] == 2
assert result['source'] == 'host_port'
# 验证:文件内容正确
with open(output_file, 'r') as f:
lines = [line.strip() for line in f.readlines()]
assert 'http://example.com' in lines
assert 'https://test.com' in lines
finally:
Path(output_file).unlink(missing_ok=True)
def test_provider_mode_uses_provider_logic(self):
"""测试当提供 provider 时使用 Provider 模式"""
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
output_file = f.name
try:
urls = ['https://example.com', 'https://test.com']
provider = ListTargetProvider(targets=urls)
# 调用任务Provider 模式)
result = export_site_urls_task(
output_file=output_file,
provider=provider
)
# 验证:使用了 provider
assert result['success'] is True
assert result['total_urls'] == len(urls)
assert 'association_count' not in result # Provider 模式不返回 association_count
assert result['source'] == 'provider'
# 验证:文件内容正确
with open(output_file, 'r') as f:
lines = [line.strip() for line in f.readlines()]
assert lines == urls
finally:
Path(output_file).unlink(missing_ok=True)
def test_error_when_no_parameters(self):
"""测试当 target_id 和 provider 都未提供时抛出错误"""
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
output_file = f.name
try:
with pytest.raises(ValueError, match="必须提供 target_id 或 provider 参数之一"):
export_site_urls_task(output_file=output_file)
finally:
Path(output_file).unlink(missing_ok=True)

View File

@@ -1,17 +1,23 @@
"""
导出站点 URL 列表任务
使用 export_urls_with_fallback 用例函数处理回退链逻辑
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库导出
2. Provider 模式:使用 TargetProvider 从任意数据源导出
数据源: WebSite.url → Default用于 katana 等爬虫工具)
"""
import logging
from typing import Optional
from pathlib import Path
from prefect import task
from apps.scan.services.target_export_service import (
export_urls_with_fallback,
DataSource,
)
from apps.scan.providers import TargetProvider, DatabaseTargetProvider
logger = logging.getLogger(__name__)
@@ -23,21 +29,27 @@ logger = logging.getLogger(__name__)
)
def export_sites_task(
output_file: str,
target_id: int,
scan_id: int,
target_id: Optional[int] = None,
scan_id: Optional[int] = None,
provider: Optional[TargetProvider] = None,
batch_size: int = 1000
) -> dict:
"""
导出站点 URL 列表到文件(用于 katana 等爬虫工具)
数据源优先级(回退链)
支持两种模式
1. 传统模式(向后兼容):传入 target_id从数据库导出
2. Provider 模式:传入 provider从任意数据源导出
数据源优先级(回退链,仅传统模式):
1. WebSite 表 - 站点级别 URL
2. 默认生成 - 根据 Target 类型生成 http(s)://target_name
Args:
output_file: 输出文件路径
target_id: 目标 ID
target_id: 目标 ID(传统模式,向后兼容)
scan_id: 扫描 ID保留参数兼容旧调用
provider: TargetProvider 实例(新模式)
batch_size: 批次大小(内存优化)
Returns:
@@ -50,6 +62,17 @@ def export_sites_task(
ValueError: 参数错误
RuntimeError: 执行失败
"""
# 参数验证:至少提供一个
if target_id is None and provider is None:
raise ValueError("必须提供 target_id 或 provider 参数之一")
# Provider 模式:使用 TargetProvider 导出
if provider is not None:
logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
return _export_with_provider(output_file, provider)
# 传统模式:使用 export_urls_with_fallback
logger.info("使用传统模式 - Target ID: %d", target_id)
result = export_urls_with_fallback(
target_id=target_id,
output_file=output_file,
@@ -67,3 +90,31 @@ def export_sites_task(
'output_file': result['output_file'],
'asset_count': result['total_count'],
}
def _export_with_provider(output_file: str, provider: TargetProvider) -> dict:
"""使用 Provider 导出 URL"""
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
total_count = 0
blacklist_filter = provider.get_blacklist_filter()
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for url in provider.iter_urls():
# 应用黑名单过滤(如果有)
if blacklist_filter and not blacklist_filter.is_allowed(url):
continue
f.write(f"{url}\n")
total_count += 1
if total_count % 1000 == 0:
logger.info("已导出 %d 个 URL...", total_count)
logger.info("✓ URL 导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))
return {
'output_file': str(output_path),
'asset_count': total_count,
}

View File

@@ -1,15 +1,18 @@
"""导出 Endpoint URL 到文件的 Task
使用 export_urls_with_fallback 用例函数处理回退链逻辑
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库导出
2. Provider 模式:使用 TargetProvider 从任意数据源导出
数据源优先级(回退链):
数据源优先级(回退链,仅传统模式
1. Endpoint.url - 最精细的 URL含路径、参数等
2. WebSite.url - 站点级别 URL
3. 默认生成 - 根据 Target 类型生成 http(s)://target_name
"""
import logging
from typing import Dict
from typing import Dict, Optional
from pathlib import Path
from prefect import task
@@ -17,26 +20,33 @@ from apps.scan.services.target_export_service import (
export_urls_with_fallback,
DataSource,
)
from apps.scan.providers import TargetProvider, DatabaseTargetProvider
logger = logging.getLogger(__name__)
@task(name="export_endpoints")
def export_endpoints_task(
target_id: int,
output_file: str,
target_id: Optional[int] = None,
output_file: str = "",
provider: Optional[TargetProvider] = None,
batch_size: int = 1000,
) -> Dict[str, object]:
"""导出目标下的所有 Endpoint URL 到文本文件。
数据源优先级(回退链)
支持两种模式
1. 传统模式(向后兼容):传入 target_id从数据库导出
2. Provider 模式:传入 provider从任意数据源导出
数据源优先级(回退链,仅传统模式):
1. Endpoint 表 - 最精细的 URL含路径、参数等
2. WebSite 表 - 站点级别 URL
3. 默认生成 - 根据 Target 类型生成 http(s)://target_name
Args:
target_id: 目标 ID
target_id: 目标 ID(传统模式,向后兼容)
output_file: 输出文件路径(绝对路径)
provider: TargetProvider 实例(新模式)
batch_size: 每次从数据库迭代的批大小
Returns:
@@ -44,9 +54,20 @@ def export_endpoints_task(
"success": bool,
"output_file": str,
"total_count": int,
"source": str, # 数据来源: "endpoint" | "website" | "default" | "none"
"source": str, # 数据来源: "endpoint" | "website" | "default" | "none" | "provider"
}
"""
# 参数验证:至少提供一个
if target_id is None and provider is None:
raise ValueError("必须提供 target_id 或 provider 参数之一")
# Provider 模式:使用 TargetProvider 导出
if provider is not None:
logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
return _export_with_provider(output_file, provider)
# 传统模式:使用 export_urls_with_fallback
logger.info("使用传统模式 - Target ID: %d", target_id)
result = export_urls_with_fallback(
target_id=target_id,
output_file=output_file,
@@ -65,3 +86,33 @@ def export_endpoints_task(
"total_count": result['total_count'],
"source": result['source'],
}
def _export_with_provider(output_file: str, provider: TargetProvider) -> Dict[str, object]:
"""使用 Provider 导出 URL"""
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
total_count = 0
blacklist_filter = provider.get_blacklist_filter()
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for url in provider.iter_urls():
# 应用黑名单过滤(如果有)
if blacklist_filter and not blacklist_filter.is_allowed(url):
continue
f.write(f"{url}\n")
total_count += 1
if total_count % 1000 == 0:
logger.info("已导出 %d 个 URL...", total_count)
logger.info("✓ URL 导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))
return {
"success": True,
"output_file": str(output_path),
"total_count": total_count,
"source": "provider",
}

View File

@@ -4,37 +4,40 @@
提供扫描相关的工具函数。
"""
from .directory_cleanup import remove_directory
from . import config_parser
from .command_builder import build_scan_command
from .command_executor import execute_and_wait, execute_stream
from .wordlist_helpers import ensure_wordlist_local
from .directory_cleanup import remove_directory
from .nuclei_helpers import ensure_nuclei_templates_local
from .performance import FlowPerformanceTracker, CommandPerformanceTracker
from .workspace_utils import setup_scan_workspace, setup_scan_directory
from .performance import CommandPerformanceTracker, FlowPerformanceTracker
from .system_load import check_system_load, wait_for_system_load
from .user_logger import user_log
from . import config_parser
from .wordlist_helpers import ensure_wordlist_local
from .workspace_utils import setup_scan_directory, setup_scan_workspace
__all__ = [
# 目录清理
'remove_directory',
# 工作空间
'setup_scan_workspace', # 创建 Scan 根工作空间
'setup_scan_directory', # 创建扫描子目录
'setup_scan_workspace',
'setup_scan_directory',
# 命令构建
'build_scan_command', # 扫描工具命令构建(基于 f-string
'build_scan_command',
# 命令执行
'execute_and_wait', # 等待式执行(文件输出)
'execute_stream', # 流式执行(实时处理)
'execute_and_wait',
'execute_stream',
# 系统负载
'wait_for_system_load',
'check_system_load',
# 字典文件
'ensure_wordlist_local', # 确保本地字典文件(含 hash 校验)
'ensure_wordlist_local',
# Nuclei 模板
'ensure_nuclei_templates_local', # 确保本地模板(含 commit hash 校验)
'ensure_nuclei_templates_local',
# 性能监控
'FlowPerformanceTracker', # Flow 性能追踪器(含系统资源采样)
'CommandPerformanceTracker', # 命令性能追踪器
'FlowPerformanceTracker',
'CommandPerformanceTracker',
# 扫描日志
'user_log', # 用户可见扫描日志记录
'user_log',
# 配置解析
'config_parser',
]

View File

@@ -12,16 +12,18 @@
import logging
import os
from django.conf import settings
import re
import signal
import subprocess
import threading
import time
from collections import deque
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, Optional, Generator
from django.conf import settings
try:
# 可选依赖:用于根据 CPU / 内存负载做动态并发控制
import psutil
@@ -354,10 +356,13 @@ class CommandExecutor:
if log_file_path:
error_output = self._read_log_tail(log_file_path, max_lines=MAX_LOG_TAIL_LINES)
logger.warning(
"扫描工具 %s 返回非零状态码: %d (执行时间: %.2f秒)%s",
tool_name, returncode, duration,
f"\n错误输出:\n{error_output}" if error_output else ""
"扫描工具 %s 返回非零状态码: %d (执行时间: %.2f秒)",
tool_name, returncode, duration
)
if error_output:
for line in error_output.strip().split('\n'):
if line.strip():
logger.warning("%s", line)
else:
logger.info("✓ 扫描工具 %s 执行完成 (执行时间: %.2f秒)", tool_name, duration)
@@ -666,33 +671,68 @@ class CommandExecutor:
def _read_log_tail(self, log_file: Path, max_lines: int = MAX_LOG_TAIL_LINES) -> str:
"""
读取日志文件的末尾部分
读取日志文件的末尾部分(常量内存实现)
使用 seek 从文件末尾往前读取,避免将整个文件加载到内存。
Args:
log_file: 日志文件路径
max_lines: 最大读取行数
Returns:
日志内容(字符串),读取失败返回错误提示
"""
if not log_file.exists():
logger.debug("日志文件不存在: %s", log_file)
return ""
if log_file.stat().st_size == 0:
file_size = log_file.stat().st_size
if file_size == 0:
logger.debug("日志文件为空: %s", log_file)
return ""
# 每次读取的块大小8KB足够容纳大多数日志行
chunk_size = 8192
def decode_line(line_bytes: bytes) -> str:
"""解码单行:优先 UTF-8失败则降级 latin-1"""
try:
return line_bytes.decode('utf-8')
except UnicodeDecodeError:
return line_bytes.decode('latin-1', errors='replace')
try:
with open(log_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
return ''.join(lines[-max_lines:] if len(lines) > max_lines else lines)
except UnicodeDecodeError as e:
logger.warning("日志文件编码错误 (%s): %s", log_file, e)
return f"(无法读取日志文件: 编码错误 - {e})"
with open(log_file, 'rb') as f:
lines_found: deque[bytes] = deque()
remaining = b''
position = file_size
while position > 0 and len(lines_found) < max_lines:
read_size = min(chunk_size, position)
position -= read_size
f.seek(position)
chunk = f.read(read_size) + remaining
parts = chunk.split(b'\n')
# 最前面的部分可能不完整,留到下次处理
remaining = parts[0]
# 其余部分是完整的行(从后往前收集,用 appendleft 保持顺序)
for part in reversed(parts[1:]):
if len(lines_found) >= max_lines:
break
lines_found.appendleft(part)
# 处理文件开头的行
if remaining and len(lines_found) < max_lines:
lines_found.appendleft(remaining)
return '\n'.join(decode_line(line) for line in lines_found)
except PermissionError as e:
logger.warning("日志文件权限不足 (%s): %s", log_file, e)
return f"(无法读取日志文件: 权限不足)"
return "(无法读取日志文件: 权限不足)"
except IOError as e:
logger.warning("日志文件读取IO错误 (%s): %s", log_file, e)
return f"(无法读取日志文件: IO错误 - {e})"

View File

@@ -0,0 +1,77 @@
"""
系统负载检查工具
提供统一的系统负载检查功能,用于:
- Flow 入口处检查系统资源是否充足
- 防止在高负载时启动新的扫描任务
"""
import logging
import time
import psutil
from django.conf import settings
logger = logging.getLogger(__name__)
# 动态并发控制阈值(可在 Django settings 中覆盖)
SCAN_CPU_HIGH: float = getattr(settings, 'SCAN_CPU_HIGH', 90.0)
SCAN_MEM_HIGH: float = getattr(settings, 'SCAN_MEM_HIGH', 80.0)
SCAN_LOAD_CHECK_INTERVAL: int = getattr(settings, 'SCAN_LOAD_CHECK_INTERVAL', 180)
def _get_current_load() -> tuple[float, float]:
"""获取当前 CPU 和内存使用率"""
return psutil.cpu_percent(interval=0.5), psutil.virtual_memory().percent
def wait_for_system_load(
cpu_threshold: float = SCAN_CPU_HIGH,
mem_threshold: float = SCAN_MEM_HIGH,
check_interval: int = SCAN_LOAD_CHECK_INTERVAL,
context: str = "task"
) -> None:
"""
等待系统负载降到阈值以下
在高负载时阻塞等待,直到 CPU 和内存都低于阈值。
用于 Flow 入口处,防止在资源紧张时启动新任务。
"""
while True:
cpu, mem = _get_current_load()
if cpu < cpu_threshold and mem < mem_threshold:
logger.debug(
"[%s] 系统负载正常: cpu=%.1f%%, mem=%.1f%%",
context, cpu, mem
)
return
logger.info(
"[%s] 系统负载较高,等待资源释放: "
"cpu=%.1f%% (阈值 %.1f%%), mem=%.1f%% (阈值 %.1f%%)",
context, cpu, cpu_threshold, mem, mem_threshold
)
time.sleep(check_interval)
def check_system_load(
cpu_threshold: float = SCAN_CPU_HIGH,
mem_threshold: float = SCAN_MEM_HIGH
) -> dict:
"""
检查当前系统负载(非阻塞)
Returns:
dict: cpu_percent, mem_percent, cpu_threshold, mem_threshold, is_overloaded
"""
cpu, mem = _get_current_load()
return {
'cpu_percent': cpu,
'mem_percent': mem,
'cpu_threshold': cpu_threshold,
'mem_threshold': mem_threshold,
'is_overloaded': cpu >= cpu_threshold or mem >= mem_threshold,
}
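
Reviewer note: a minimal sketch of how a flow entry point might use these helpers. The flow function is illustrative and assumes the scan utils package shown earlier in this diff re-exports both names.

# Sketch - gate heavy scan work on current system load.
from apps.scan.utils import check_system_load, wait_for_system_load

def run_port_scan_flow(target_id: int) -> None:
    # Blocks until CPU < SCAN_CPU_HIGH and memory < SCAN_MEM_HIGH.
    wait_for_system_load(context="port_scan")

    load = check_system_load()
    if load["is_overloaded"]:
        # Non-blocking variant: a caller could also log and defer instead.
        return
    # ... launch the actual scan tasks here ...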

View File

@@ -12,16 +12,34 @@ load-plugins = "pylint_django"
[tool.pylint.messages_control]
disable = [
"missing-docstring",
"invalid-name",
"too-few-public-methods",
"no-member",
"import-error",
"no-name-in-module",
"missing-docstring",
"invalid-name",
"too-few-public-methods",
"no-member",
"import-error",
"no-name-in-module",
"wrong-import-position", # 允许函数内导入(防循环依赖)
"import-outside-toplevel", # 同上
"too-many-arguments", # Django 视图/服务方法参数常超过5个
"too-many-locals", # 复杂业务逻辑局部变量多
"duplicate-code", # 某些模式代码相似是正常的
]
[tool.pylint.format]
max-line-length = 120
[tool.pylint.basic]
good-names = ["i", "j", "k", "ex", "Run", "_", "id", "pk", "ip", "url", "db", "qs"]
good-names = [
"i",
"j",
"k",
"ex",
"Run",
"_",
"id",
"pk",
"ip",
"url",
"db",
"qs",
]

View File

@@ -38,6 +38,7 @@ packaging>=21.0 # 版本比较
# 测试框架
pytest==8.0.0
pytest-django==4.7.0
hypothesis>=6.100.0 # 属性测试框架
# 工具库
python-dateutil==2.9.0

View File

@@ -2,7 +2,7 @@
# ============================================
# XingRin 远程节点安装脚本
# 用途:安装 Docker 环境 + 预拉取镜像
# 支持Ubuntu / Debian
# 支持Ubuntu / Debian / Kali
#
# 架构说明:
# 1. 安装 Docker 环境
@@ -101,8 +101,8 @@ detect_os() {
exit 1
fi
if [[ "$OS" != "ubuntu" && "$OS" != "debian" ]]; then
log_error "仅支持 Ubuntu/Debian 系统"
if [[ "$OS" != "ubuntu" && "$OS" != "debian" && "$OS" != "kali" ]]; then
log_error "仅支持 Ubuntu/Debian/Kali 系统"
exit 1
fi
}

View File

@@ -53,6 +53,8 @@ services:
# 统一挂载数据目录
- /opt/xingrin:/opt/xingrin
- /var/run/docker.sock:/var/run/docker.sock
# OOM 优先级:-500 保护核心服务
oom_score_adj: -500
healthcheck:
# 使用专门的健康检查端点(无需认证)
test: ["CMD", "curl", "-f", "http://localhost:8888/api/health/"]
@@ -88,6 +90,8 @@ services:
args:
IMAGE_TAG: ${IMAGE_TAG:-dev}
restart: always
# OOM 优先级:-500 保护 Web 界面
oom_score_adj: -500
depends_on:
server:
condition: service_healthy
@@ -97,6 +101,8 @@ services:
context: ..
dockerfile: docker/nginx/Dockerfile
restart: always
# OOM 优先级:-500 保护入口网关
oom_score_adj: -500
depends_on:
server:
condition: service_healthy

View File

@@ -56,6 +56,8 @@ services:
- /opt/xingrin:/opt/xingrin
# Docker Socket 挂载:允许 Django 服务器执行本地 docker 命令(用于本地 Worker 任务分发)
- /var/run/docker.sock:/var/run/docker.sock
# OOM 优先级:-500 降低被 OOM Killer 选中的概率,保护核心服务
oom_score_adj: -500
healthcheck:
# 使用专门的健康检查端点(无需认证)
test: ["CMD", "curl", "-f", "http://localhost:8888/api/health/"]
@@ -88,6 +90,8 @@ services:
frontend:
image: ${DOCKER_USER:-yyhuni}/xingrin-frontend:${IMAGE_TAG:?IMAGE_TAG is required}
restart: always
# OOM 优先级:-500 保护 Web 界面
oom_score_adj: -500
depends_on:
server:
condition: service_healthy
@@ -95,6 +99,8 @@ services:
nginx:
image: ${DOCKER_USER:-yyhuni}/xingrin-nginx:${IMAGE_TAG:?IMAGE_TAG is required}
restart: always
# OOM 优先级:-500 保护入口网关
oom_score_adj: -500
depends_on:
server:
condition: service_healthy

View File

@@ -29,9 +29,6 @@ RUN go install -v github.com/projectdiscovery/httpx/cmd/httpx@latest && \
go install -v github.com/d3mondev/puredns/v2@latest && \
go install -v github.com/yyhuni/xingfinger@latest
# 安装 Amass v5禁用 CGO 以跳过 libpostal 依赖)
RUN CGO_ENABLED=0 go install -v github.com/owasp-amass/amass/v5/cmd/amass@main
# 安装漏洞扫描器
RUN go install github.com/hahwul/dalfox/v2@latest
@@ -45,7 +42,9 @@ ENV DEBIAN_FRONTEND=noninteractive
WORKDIR /app
# 1. 安装基础工具和 Python
RUN apt-get update && apt-get install -y \
# 注意ARM64 使用 ports.ubuntu.com可能存在镜像同步延迟需要重试机制
RUN apt-get update && \
apt-get install -y --no-install-recommends \
python3 \
python3-pip \
python3-venv \
@@ -64,12 +63,16 @@ RUN apt-get update && apt-get install -y \
libnss3 \
libxss1 \
libasound2t64 \
|| (rm -rf /var/lib/apt/lists/* && apt-get update && apt-get install -y --no-install-recommends \
python3 python3-pip python3-venv pipx git curl wget unzip jq tmux nmap masscan libpcap-dev \
ca-certificates fonts-liberation libnss3 libxss1 libasound2t64) \
&& rm -rf /var/lib/apt/lists/*
# 安装 Chromium通过 Playwright 安装,支持 ARM64 和 AMD64
# Ubuntu 24.04 的 chromium-browser 是 snap 过渡包Docker 中不可用
RUN pip install playwright --break-system-packages && \
playwright install chromium && \
apt-get update && \
playwright install-deps chromium && \
rm -rf /var/lib/apt/lists/*

View File

@@ -365,6 +365,7 @@ export function DashboardDataTable() {
columns={scanColumns}
getRowId={(row) => String(row.id)}
enableRowSelection={false}
enableAutoColumnSizing
pagination={scanPagination}
onPaginationChange={setScanPagination}
paginationInfo={scanPaginationInfo}

View File

@@ -99,6 +99,8 @@ export function ScanHistoryDataTable({
hideToolbar={hideToolbar}
// Empty state
emptyMessage={t("noData")}
// Auto column sizing
enableAutoColumnSizing
// Custom search box
toolbarLeft={
<div className="flex items-center space-x-2">

View File

@@ -84,6 +84,15 @@ function formatStageDuration(seconds?: number): string | undefined {
return secs > 0 ? `${minutes}m ${secs}s` : `${minutes}m`
}
// Status priority for sorting (lower = higher priority)
const STAGE_STATUS_PRIORITY: Record<StageStatus, number> = {
running: 0,
pending: 1,
completed: 2,
failed: 3,
cancelled: 4,
}
export function ScanOverview({ scanId }: ScanOverviewProps) {
const t = useTranslations("scan.history.overview")
const tStatus = useTranslations("scan.history.status")
@@ -326,7 +335,16 @@ export function ScanOverview({ scanId }: ScanOverviewProps) {
{scan.stageProgress && Object.keys(scan.stageProgress).length > 0 ? (
<div className="space-y-1 flex-1 min-h-0 overflow-y-auto pr-1">
{Object.entries(scan.stageProgress)
.sort(([, a], [, b]) => ((a as any).order ?? 0) - ((b as any).order ?? 0))
.sort(([, a], [, b]) => {
const progressA = a as any
const progressB = b as any
const priorityA = STAGE_STATUS_PRIORITY[progressA.status as StageStatus] ?? 99
const priorityB = STAGE_STATUS_PRIORITY[progressB.status as StageStatus] ?? 99
if (priorityA !== priorityB) {
return priorityA - priorityB
}
return (progressA.order ?? 0) - (progressB.order ?? 0)
})
.map(([stageName, progress]) => {
const stageProgress = progress as any
const isRunning = stageProgress.status === "running"
@@ -346,9 +364,6 @@ export function ScanOverview({ scanId }: ScanOverviewProps) {
<span className={cn("truncate", isRunning && "font-medium text-foreground")}>
{tProgress(`stages.${stageName}`)}
</span>
{isRunning && (
<span className="text-[10px] text-[#d29922] shrink-0"></span>
)}
</div>
<span className="text-xs text-muted-foreground font-mono shrink-0 ml-2">
{stageProgress.status === "completed" && stageProgress.duration

View File

@@ -78,5 +78,9 @@ export function ScanLogList({ logs, loading }: ScanLogListProps) {
)
}
return <AnsiLogViewer content={content} />
return (
<div className="h-full">
<AnsiLogViewer content={content} />
</div>
)
}

View File

@@ -280,7 +280,9 @@ export function ScanProgressDialog({
</div>
) : (
/* Log list */
<ScanLogList logs={logs} loading={logsLoading} />
<div className="h-[300px] overflow-hidden rounded-md">
<ScanLogList logs={logs} loading={logsLoading} />
</div>
)}
</DialogContent>
</Dialog>
@@ -350,13 +352,29 @@ function getStageResultCount(stageName: string, summary: ScanRecord["summary"]):
* Stage names come directly from engine_config keys, no mapping needed
* Stage order follows the order field, consistent with Flow execution order
*/
// Status priority for sorting (lower = higher priority)
const STATUS_PRIORITY: Record<StageStatus, number> = {
running: 0,
pending: 1,
completed: 2,
failed: 3,
cancelled: 4,
}
export function buildScanProgressData(scan: ScanRecord): ScanProgressData {
const stages: StageDetail[] = []
if (scan.stageProgress) {
// Sort by order then iterate
// Sort by status priority first, then by order
const sortedEntries = Object.entries(scan.stageProgress)
.sort(([, a], [, b]) => (a.order ?? 0) - (b.order ?? 0))
.sort(([, a], [, b]) => {
const priorityA = STATUS_PRIORITY[a.status] ?? 99
const priorityB = STATUS_PRIORITY[b.status] ?? 99
if (priorityA !== priorityB) {
return priorityA - priorityB
}
return (a.order ?? 0) - (b.order ?? 0)
})
for (const [stageName, progress] of sortedEntries) {
const resultCount = progress.status === "completed"

View File

@@ -55,9 +55,10 @@ function hasAnsiCodes(text: string): boolean {
// 解析纯文本日志内容,为日志级别添加颜色
function colorizeLogContent(content: string): string {
// 匹配日志格式: [时间] [级别] [模块:行号] 消息
// 例如: [2025-01-05 10:30:00] [INFO] [apps.scan:123] 消息内容
const logLineRegex = /^(\[\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\]) (\[(DEBUG|INFO|WARNING|WARN|ERROR|CRITICAL)\]) (.*)$/
// 匹配日志格式:
// 1) 系统日志: [2026-01-10 09:51:52] [INFO] [apps.scan.xxx:123] ...
// 2) 扫描日志: [09:50:37] [INFO] [subdomain_discovery] ...
const logLineRegex = /^(\[(?:\d{4}-\d{2}-\d{2} )?\d{2}:\d{2}:\d{2}\]) (\[(DEBUG|INFO|WARNING|WARN|ERROR|CRITICAL)\]) (.*)$/i
return content
.split("\n")
@@ -66,14 +67,15 @@ function colorizeLogContent(content: string): string {
if (match) {
const [, timestamp, levelBracket, level, rest] = match
const color = LOG_LEVEL_COLORS[level] || "#d4d4d4"
const levelUpper = level.toUpperCase()
const color = LOG_LEVEL_COLORS[levelUpper] || "#d4d4d4"
// ansiConverter.toHtml 已经处理了 HTML 转义
const escapedTimestamp = ansiConverter.toHtml(timestamp)
const escapedLevelBracket = ansiConverter.toHtml(levelBracket)
const escapedRest = ansiConverter.toHtml(rest)
// 时间戳灰色,日志级别带颜色,其余默认色
return `<span style="color:#808080">${escapedTimestamp}</span> <span style="color:${color};font-weight:${level === "CRITICAL" ? "bold" : "normal"}">${escapedLevelBracket}</span> ${escapedRest}`
return `<span style="color:#808080">${escapedTimestamp}</span> <span style="color:${color};font-weight:${levelUpper === "CRITICAL" ? "bold" : "normal"}">${escapedLevelBracket}</span> ${escapedRest}`
}
// 非标准格式的行,也进行 HTML 转义
@@ -85,16 +87,24 @@ function colorizeLogContent(content: string): string {
// 高亮搜索关键词
function highlightSearch(html: string, query: string): string {
if (!query.trim()) return html
// `ansi-to-html` 在 `escapeXML: true` 时,会把非 ASCII 字符(如中文)转成实体:
// 例如 "中文" => "&#x4E2D;&#x6587;"。
// 因此这里需要用同样的转义规则来生成可匹配的搜索串。
const escapedQueryForHtml = ansiConverter.toHtml(query)
// 转义正则特殊字符
const escapedQuery = query.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")
const regex = new RegExp(`(${escapedQuery})`, "gi")
const escapedQuery = escapedQueryForHtml.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")
const regex = new RegExp(`(${escapedQuery})`, "giu")
// 在标签外的文本中高亮关键词
return html.replace(/(<[^>]+>)|([^<]+)/g, (match, tag, text) => {
if (tag) return tag
if (text) {
return text.replace(regex, '<mark style="background:#fbbf24;color:#1e1e1e;border-radius:2px;padding:0 2px">$1</mark>')
return text.replace(
regex,
'<mark style="background:#fbbf24;color:#1e1e1e;border-radius:2px;padding:0 2px">$1</mark>'
)
}
return match
})
@@ -104,6 +114,8 @@ function highlightSearch(html: string, query: string): string {
const LOG_LEVEL_PATTERNS = [
// Standard format: [2026-01-07 12:00:00] [INFO]
/^\[\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\] \[(DEBUG|INFO|WARNING|WARN|ERROR|CRITICAL)\]/i,
// Scan log format: [09:50:37] [INFO] [stage]
/^\[\d{2}:\d{2}:\d{2}\] \[(DEBUG|INFO|WARNING|WARN|ERROR|CRITICAL)\]/i,
// Prefect format: 12:01:50.419 | WARNING | prefect
/^[\d:.]+\s+\|\s+(DEBUG|INFO|WARNING|WARN|ERROR|CRITICAL)\s+\|/i,
// Simple format: [INFO] message or INFO: message

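For level detection the line only has to match one of the patterns in LOG_LEVEL_PATTERNS above; a small sketch with a hypothetical helper (the function name and sample lines are illustrative, not part of the component):

// Hypothetical helper: returns the first captured level, upper-cased, or null
function detectLogLevel(line: string): string | null {
  for (const pattern of LOG_LEVEL_PATTERNS) {
    const match = line.match(pattern)
    if (match?.[1]) return match[1].toUpperCase()
  }
  return null
}

console.log(detectLogLevel("[09:50:37] [warn] [port_scan] retrying target"))     // "WARN"
console.log(detectLogLevel("12:01:50.419 | WARNING | prefect.flow_runs ..."))    // "WARNING"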
View File

@@ -1,7 +1,7 @@
"use client"
import * as React from "react"
import { useTranslations } from "next-intl"
import { useTranslations, useLocale } from "next-intl"
import {
ColumnFiltersState,
ColumnSizingState,
@@ -17,6 +17,7 @@ import {
VisibilityState,
Updater,
} from "@tanstack/react-table"
import { calculateColumnWidths } from "@/lib/table-utils"
import {
IconChevronDown,
IconLayoutColumns,
@@ -145,8 +146,12 @@ export function UnifiedDataTable<TData>({
// Styles
className,
tableClassName,
// Auto column sizing
enableAutoColumnSizing = false,
}: UnifiedDataTableProps<TData>) {
const tActions = useTranslations("common.actions")
const locale = useLocale()
// Internal state
const [internalRowSelection, setInternalRowSelection] = React.useState<Record<string, boolean>>({})
@@ -154,6 +159,7 @@ export function UnifiedDataTable<TData>({
const [internalSorting, setInternalSorting] = React.useState<SortingState>(defaultSorting)
const [columnFilters, setColumnFilters] = React.useState<ColumnFiltersState>([])
const [columnSizing, setColumnSizing] = React.useState<ColumnSizingState>({})
const [autoSizingCalculated, setAutoSizingCalculated] = React.useState(false)
const [internalPagination, setInternalPagination] = React.useState<PaginationState>({
pageIndex: 0,
pageSize: 10,
@@ -232,6 +238,41 @@ export function UnifiedDataTable<TData>({
return (data || []).filter(item => item && typeof getRowId(item) !== 'undefined')
}, [data, getRowId])
// Auto column sizing: calculate optimal widths based on content
React.useEffect(() => {
if (!enableAutoColumnSizing || autoSizingCalculated || validData.length === 0) {
return
}
// Build header labels from column meta
const headerLabels: Record<string, string> = {}
for (const col of columns) {
const colDef = col as { accessorKey?: string; id?: string; meta?: { title?: string } }
const colId = colDef.accessorKey || colDef.id
if (colId && colDef.meta?.title) {
headerLabels[colId] = colDef.meta.title
}
}
const calculatedWidths = calculateColumnWidths({
data: validData as Record<string, unknown>[],
columns: columns as Array<{
accessorKey?: string
id?: string
size?: number
minSize?: number
maxSize?: number
}>,
headerLabels,
locale,
})
if (Object.keys(calculatedWidths).length > 0) {
setColumnSizing(calculatedWidths)
setAutoSizingCalculated(true)
}
}, [enableAutoColumnSizing, autoSizingCalculated, validData, columns])
// Create table instance
const table = useReactTable({
data: validData,

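For context, a minimal usage sketch of the new prop; the import path, row type, and column set are hypothetical, and getRowId is assumed to be the existing row-id prop — only enableAutoColumnSizing and the meta.title lookup come from the diff above.

import { UnifiedDataTable } from "@/components/unified-data-table"  // hypothetical path

interface SubdomainRow { id: string; host: string; createdAt: string }

export function SubdomainTable({ rows }: { rows: SubdomainRow[] }) {
  // meta.title feeds the header-width measurement in the auto-sizing effect
  const columns = [
    { accessorKey: "host", meta: { title: "Host" } },
    { accessorKey: "createdAt", meta: { title: "Created At" } },
  ]
  return (
    <UnifiedDataTable
      data={rows}
      columns={columns}
      getRowId={(row: SubdomainRow) => row.id}
      enableAutoColumnSizing
    />
  )
}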
179
frontend/lib/table-utils.ts Normal file
View File

@@ -0,0 +1,179 @@
/**
* Table utility functions
* Provides column width calculation and other table-related utilities
*/
// Cache for text measurement context
let measureContext: CanvasRenderingContext2D | null = null
/**
* Get or create a canvas context for measuring text width
*/
function getMeasureContext(): CanvasRenderingContext2D {
if (!measureContext) {
const canvas = document.createElement('canvas')
measureContext = canvas.getContext('2d')!
}
return measureContext
}
/**
* Measure text width using canvas
* @param text - Text to measure
* @param font - CSS font string (e.g., "14px Inter, sans-serif")
* @returns Text width in pixels
*/
export function measureTextWidth(text: string, font: string = '14px Inter, system-ui, sans-serif'): number {
const ctx = getMeasureContext()
ctx.font = font
return ctx.measureText(text).width
}
/**
* Options for calculating column widths
*/
export interface CalculateColumnWidthsOptions<TData> {
/** Table data */
data: TData[]
/** Column definitions with accessorKey */
columns: Array<{
accessorKey?: string
id?: string
size?: number
minSize?: number
maxSize?: number
/** If true, skip auto-sizing for this column */
enableAutoSize?: boolean
}>
/** Font to use for measurement */
font?: string
/** Padding to add to each cell (in pixels) */
cellPadding?: number
/** Header font (usually slightly different from cell font) */
headerFont?: string
/** Header labels for columns (keyed by accessorKey or id) */
headerLabels?: Record<string, string>
/** Maximum number of rows to sample (for performance) */
maxSampleRows?: number
/** Locale for date formatting */
locale?: string
}
/**
* Calculate optimal column widths based on content
* Returns a map of column id -> calculated width
*/
/**
* Check if a string looks like an ISO date
*/
function isISODateString(value: string): boolean {
// Match ISO 8601 format: 2024-01-09T12:00:00.000Z or 2024-01-09T12:00:00
return /^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}/.test(value)
}
/**
* Format date for display (matching the app's date format)
*/
function formatDateForMeasurement(dateString: string, locale: string): string {
try {
return new Date(dateString).toLocaleString(locale, {
year: "numeric",
month: "numeric",
day: "numeric",
hour: "2-digit",
minute: "2-digit",
second: "2-digit",
hour12: false,
})
} catch {
return dateString
}
}
export function calculateColumnWidths<TData extends Record<string, unknown>>({
data,
columns,
font = '14px Inter, system-ui, sans-serif',
cellPadding = 32, // Default padding for cell content
headerFont = '500 14px Inter, system-ui, sans-serif',
headerLabels = {},
maxSampleRows = 100,
locale = 'zh-CN',
}: CalculateColumnWidthsOptions<TData>): Record<string, number> {
const widths: Record<string, number> = {}
// Sample data for performance (don't measure all rows if there are too many)
const sampleData = data.slice(0, maxSampleRows)
for (const column of columns) {
const columnId = column.accessorKey || column.id
if (!columnId) continue
// Skip columns that explicitly disable auto-sizing
if (column.enableAutoSize === false) {
if (column.size) {
widths[columnId] = column.size
}
continue
}
// Start with header width
const headerLabel = headerLabels[columnId] || columnId
let maxWidth = measureTextWidth(headerLabel, headerFont) + cellPadding
// Measure content width for each row
for (const row of sampleData) {
const value = row[columnId]
if (value == null) continue
// Convert value to string for measurement
let textValue: string
if (typeof value === 'string') {
// Check if it's a date string and format it
if (isISODateString(value)) {
textValue = formatDateForMeasurement(value, locale)
} else {
textValue = value
}
} else if (typeof value === 'number') {
textValue = String(value)
} else if (Array.isArray(value)) {
// For arrays, join with comma (rough estimate)
textValue = value.join(', ')
} else if (typeof value === 'object') {
// Skip complex objects - they need custom renderers
continue
} else {
textValue = String(value)
}
const contentWidth = measureTextWidth(textValue, font) + cellPadding
maxWidth = Math.max(maxWidth, contentWidth)
}
// Apply min/max constraints
if (column.minSize) {
maxWidth = Math.max(maxWidth, column.minSize)
}
if (column.maxSize) {
maxWidth = Math.min(maxWidth, column.maxSize)
}
widths[columnId] = Math.ceil(maxWidth)
}
return widths
}
/**
* Hook-friendly version that returns initial column sizing state
*/
export function getInitialColumnSizing<TData extends Record<string, unknown>>(
options: CalculateColumnWidthsOptions<TData>
): Record<string, number> {
// Only run on client side
if (typeof window === 'undefined') {
return {}
}
return calculateColumnWidths(options)
}

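A small usage sketch of the utility in isolation (browser-only, since it measures text via a canvas context); the sample rows, header labels, and printed widths are hypothetical.

import { calculateColumnWidths } from "@/lib/table-utils"

const rows = [
  { host: "admin.example.com", createdAt: "2026-01-09T12:00:00.000Z", ports: [80, 443] },
  { host: "very-long-subdomain.internal.example.com", createdAt: "2026-01-10T08:30:00.000Z", ports: [22] },
]

const widths = calculateColumnWidths({
  data: rows,
  columns: [
    { accessorKey: "host", minSize: 120, maxSize: 400 },
    { accessorKey: "createdAt" },   // ISO strings are formatted with the locale before measuring
    { accessorKey: "ports" },       // arrays are joined with ", " before measuring
  ],
  headerLabels: { host: "Host", createdAt: "Created At", ports: "Open Ports" },
  locale: "en-US",
})

console.log(widths)   // e.g. { host: 310, createdAt: 178, ports: 96 } — pixel widths keyed by column id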
View File

@@ -136,6 +136,10 @@ export interface UnifiedDataTableProps<TData> {
// Styling
className?: string
tableClassName?: string
// Auto column sizing
/** Enable automatic column width calculation based on content */
enableAutoColumnSizing?: boolean
}
/**

176
update.sh Executable file
View File

@@ -0,0 +1,176 @@
#!/bin/bash
# ============================================
# XingRin system update script
# Purpose: update code + sync version + rebuild images + restart services
# ============================================
#
# Update flow:
# 1. Stop services
# 2. git pull to fetch the latest code
# 3. Merge new .env keys + sync VERSION
# 4. Build/pull images (build in dev mode, pull in production mode)
# 5. Start services (the server runs database migrations automatically on startup)
#
# Usage:
# sudo ./update.sh                       Production update (pull images from Docker Hub)
# sudo ./update.sh --dev                 Development update (build images locally)
# sudo ./update.sh --no-frontend         Start only the backend after updating
# sudo ./update.sh --dev --no-frontend   Development update, then start only the backend
cd "$(dirname "$0")"
# Privilege check
if [ "$EUID" -ne 0 ]; then
echo -e "\033[0;31m[ERROR] Please run this script with sudo\033[0m"
echo -e " Correct usage: \033[1msudo ./update.sh\033[0m"
exit 1
fi
# Cross-platform sed -i (compatible with macOS and Linux)
sed_inplace() {
if [[ "$OSTYPE" == "darwin"* ]]; then
sed -i '' "$@"
else
sed -i "$@"
fi
}
# Parse arguments to determine the mode
DEV_MODE=false
for arg in "$@"; do
case $arg in
--dev) DEV_MODE=true ;;
esac
done
# Color definitions
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
RED='\033[0;31m'
BLUE='\033[0;34m'
CYAN='\033[0;36m'
BOLD='\033[1m'
NC='\033[0m'
# Merge new .env keys (keep existing user values)
merge_env_config() {
local example_file="docker/.env.example"
local env_file="docker/.env"
if [ ! -f "$example_file" ] || [ ! -f "$env_file" ]; then
return
fi
local new_keys=0
while IFS= read -r line || [ -n "$line" ]; do
[[ -z "$line" || "$line" =~ ^# ]] && continue
local key="${line%%=*}"
[[ -z "$key" || "$key" == "$line" ]] && continue
if ! grep -q "^${key}=" "$env_file"; then
printf '%s\n' "$line" >> "$env_file"
echo -e " ${GREEN}+${NC} Added: $key"
((new_keys++))
fi
done < "$example_file"
if [ $new_keys -gt 0 ]; then
echo -e " ${GREEN}OK${NC} Added $new_keys new config key(s)"
else
echo -e " ${GREEN}OK${NC} Configuration is already up to date"
fi
}
echo ""
echo -e "${BOLD}${BLUE}╔════════════════════════════════════════╗${NC}"
if [ "$DEV_MODE" = true ]; then
echo -e "${BOLD}${BLUE}║    Development Update (local build)    ║${NC}"
else
echo -e "${BOLD}${BLUE}║     Production Update (Docker Hub)     ║${NC}"
fi
echo -e "${BOLD}${BLUE}╚════════════════════════════════════════╝${NC}"
echo ""
# Experimental feature warning
echo -e "${BOLD}${YELLOW}[!] Warning: this is an experimental feature and the upgrade may fail${NC}"
echo -e "${YELLOW} It is recommended to run ./uninstall.sh and then ./install.sh for a clean installation${NC}"
echo ""
echo -n -e "${YELLOW}Continue with the update? (y/N) ${NC}"
read -r ans_continue
ans_continue=${ans_continue:-N}
if [[ ! $ans_continue =~ ^[Yy]$ ]]; then
echo -e "${CYAN}Update cancelled.${NC}"
exit 0
fi
echo ""
# Step 1: stop services
echo -e "${CYAN}[1/5]${NC} Stopping services..."
./stop.sh 2>&1 | sed 's/^/ /'
# Step 2: pull code
echo ""
echo -e "${CYAN}[2/5]${NC} Pulling code..."
git pull --rebase 2>&1 | sed 's/^/ /'
# $? would reflect sed's exit status here; check the git pull exit code instead
if [ "${PIPESTATUS[0]}" -ne 0 ]; then
echo -e "${RED}[ERROR]${NC} git pull failed, please resolve conflicts manually and retry"
exit 1
fi
# Step 3: check for config updates + sync version
echo ""
echo -e "${CYAN}[3/5]${NC} Checking for config updates..."
merge_env_config
# Version sync: update IMAGE_TAG from the VERSION file
if [ -f "VERSION" ]; then
NEW_VERSION=$(cat VERSION | tr -d '[:space:]')
if [ -n "$NEW_VERSION" ]; then
if grep -q "^IMAGE_TAG=" "docker/.env"; then
sed_inplace "s/^IMAGE_TAG=.*/IMAGE_TAG=$NEW_VERSION/" "docker/.env"
echo -e " ${GREEN}+${NC} Version synced: IMAGE_TAG=$NEW_VERSION"
else
printf '%s\n' "IMAGE_TAG=$NEW_VERSION" >> "docker/.env"
echo -e " ${GREEN}+${NC} Version added: IMAGE_TAG=$NEW_VERSION"
fi
fi
fi
# Step 4: build/pull images
echo ""
echo -e "${CYAN}[4/5]${NC} Updating images..."
if [ "$DEV_MODE" = true ]; then
# Dev mode: build all images locally (including the Worker)
echo -e " Building the Worker image..."
# Read IMAGE_TAG
IMAGE_TAG=$(grep "^IMAGE_TAG=" "docker/.env" | cut -d'=' -f2)
if [ -z "$IMAGE_TAG" ]; then
IMAGE_TAG="dev"
fi
# Build the Worker image (the Worker is an ephemeral container, not in compose, so it is built separately)
docker build -t docker-worker -f docker/worker/Dockerfile . 2>&1 | sed 's/^/ /'
docker tag docker-worker docker-worker:${IMAGE_TAG} 2>&1 | sed 's/^/ /'
echo -e " ${GREEN}OK${NC} Worker image built: docker-worker:${IMAGE_TAG}"
# Other service images are built by start.sh --dev
echo -e " Other service images will be built at startup..."
else
# Production mode: images are pulled by start.sh
echo -e " Images will be pulled from Docker Hub at startup..."
fi
# Step 5: start services
echo ""
echo -e "${CYAN}[5/5]${NC} Starting services..."
./start.sh "$@"
echo ""
echo -e "${BOLD}${GREEN}════════════════════════════════════════${NC}"
echo -e "${BOLD}${GREEN}  Update complete!${NC}"
echo -e "${BOLD}${GREEN}════════════════════════════════════════${NC}"
echo ""
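To illustrate the .env merge in step 3 in isolation, a minimal Bash sketch of the same logic on hypothetical temp files (merge_env_config itself operates on docker/.env.example and docker/.env); existing values are never overwritten, only missing keys are appended.

# Minimal sketch, hypothetical keys and temp files
example=$(mktemp); envfile=$(mktemp)
printf 'POSTGRES_PASSWORD=changeme\nWORKER_CONCURRENCY=4\n' > "$example"
printf 'POSTGRES_PASSWORD=s3cret\n' > "$envfile"
while IFS= read -r line || [ -n "$line" ]; do
  [[ -z "$line" || "$line" =~ ^# ]] && continue
  key="${line%%=*}"
  grep -q "^${key}=" "$envfile" || printf '%s\n' "$line" >> "$envfile"
done < "$example"
cat "$envfile"
# POSTGRES_PASSWORD=s3cret      <- existing value preserved
# WORKER_CONCURRENCY=4          <- only the missing key appended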