Optimization: performance optimization

yyhuni
2026-01-10 09:44:43 +08:00
parent 592a1958c4
commit eba394e14e
17 changed files with 1493 additions and 1895 deletions


@@ -10,30 +10,30 @@
- 配置由 YAML 解析 - 配置由 YAML 解析
""" """
# Django 环境初始化(导入即生效)
from apps.common.prefect_django_setup import setup_django_for_prefect
from prefect import flow
from prefect.task_runners import ThreadPoolTaskRunner
import hashlib import hashlib
import logging import logging
import os
import subprocess import subprocess
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import List, Tuple from typing import List, Tuple
from apps.scan.tasks.directory_scan import ( from prefect import flow
export_sites_task,
run_and_stream_save_directories_task
)
from apps.scan.handlers.scan_flow_handlers import ( from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed, on_scan_flow_completed,
on_scan_flow_failed, on_scan_flow_failed,
on_scan_flow_running,
)
from apps.scan.tasks.directory_scan import (
export_sites_task,
run_and_stream_save_directories_task,
)
from apps.scan.utils import (
build_scan_command,
ensure_wordlist_local,
user_log,
wait_for_system_load,
) )
from apps.scan.utils import config_parser, build_scan_command, ensure_wordlist_local, user_log
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -45,118 +45,79 @@ def calculate_directory_scan_timeout(
tool_config: dict, tool_config: dict,
base_per_word: float = 1.0, base_per_word: float = 1.0,
min_timeout: int = 60, min_timeout: int = 60,
max_timeout: int = 7200
) -> int: ) -> int:
""" """
根据字典行数计算目录扫描超时时间 根据字典行数计算目录扫描超时时间
计算公式:超时时间 = 字典行数 × 每个单词基础时间 计算公式:超时时间 = 字典行数 × 每个单词基础时间
超时范围:60秒 ~ 2小时7200秒 超时范围:最小 60 秒,无上限
Args: Args:
tool_config: 工具配置字典,包含 wordlist 路径 tool_config: 工具配置字典,包含 wordlist 路径
base_per_word: 每个单词的基础时间(秒),默认 1.0秒 base_per_word: 每个单词的基础时间(秒),默认 1.0秒
min_timeout: 最小超时时间(秒),默认 60秒 min_timeout: 最小超时时间(秒),默认 60秒
max_timeout: 最大超时时间(秒),默认 7200秒2小时
Returns: Returns:
int: 计算出的超时时间(秒)范围60 ~ 7200 int: 计算出的超时时间(秒)
Example:
# 1000行字典 × 1.0秒 = 1000秒 → 限制为7200秒中的 1000秒
# 10000行字典 × 1.0秒 = 10000秒 → 限制为7200秒最大值
timeout = calculate_directory_scan_timeout(
tool_config={'wordlist': '/path/to/wordlist.txt'}
)
""" """
try: import os
# 从 tool_config 中获取 wordlist 路径
wordlist_path = tool_config.get('wordlist') wordlist_path = tool_config.get('wordlist')
if not wordlist_path: if not wordlist_path:
logger.warning("工具配置中未指定 wordlist使用默认超时: %d", min_timeout) logger.warning("工具配置中未指定 wordlist使用默认超时: %d", min_timeout)
return min_timeout return min_timeout
# 展开用户目录(~
wordlist_path = os.path.expanduser(wordlist_path) wordlist_path = os.path.expanduser(wordlist_path)
# 检查文件是否存在
if not os.path.exists(wordlist_path): if not os.path.exists(wordlist_path):
logger.warning("字典文件不存在: %s,使用默认超时: %d", wordlist_path, min_timeout) logger.warning("字典文件不存在: %s,使用默认超时: %d", wordlist_path, min_timeout)
return min_timeout return min_timeout
# 使用 wc -l 快速统计字典行数 try:
result = subprocess.run( result = subprocess.run(
['wc', '-l', wordlist_path], ['wc', '-l', wordlist_path],
capture_output=True, capture_output=True,
text=True, text=True,
check=True check=True
) )
# wc -l 输出格式:行数 + 空格 + 文件名
line_count = int(result.stdout.strip().split()[0]) line_count = int(result.stdout.strip().split()[0])
timeout = max(min_timeout, int(line_count * base_per_word))
# 计算超时时间
timeout = int(line_count * base_per_word)
# 设置合理的下限(不再设置上限)
timeout = max(min_timeout, timeout)
logger.info( logger.info(
"目录扫描超时计算 - 字典: %s, 行数: %d, 基础时间: %.3f秒/词, 计算超时: %d", "目录扫描超时计算 - 字典: %s, 行数: %d, 基础时间: %.3f秒/词, 计算超时: %d",
wordlist_path, line_count, base_per_word, timeout wordlist_path, line_count, base_per_word, timeout
) )
return timeout return timeout
except subprocess.CalledProcessError as e: except (subprocess.CalledProcessError, ValueError, IndexError) as e:
logger.error("统计字典行数失败: %s", e) logger.error("计算超时时间失败: %s", e)
# 失败时返回默认超时
return min_timeout
except (ValueError, IndexError) as e:
logger.error("解析字典行数失败: %s", e)
return min_timeout
except Exception as e:
logger.error("计算超时时间异常: %s", e)
return min_timeout return min_timeout
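The new timeout heuristic multiplies the wordlist line count by a per-word base time and only clamps the lower bound (the old 7200 s ceiling is gone). A standalone sketch of the same calculation, under an assumed function name rather than the project's:

    import subprocess
    from pathlib import Path

    def estimate_dirscan_timeout(wordlist: str, base_per_word: float = 1.0, min_timeout: int = 60) -> int:
        """Timeout = wordlist lines x base_per_word, floored at min_timeout, no upper bound."""
        path = Path(wordlist).expanduser()
        if not path.exists():
            return min_timeout
        # `wc -l` avoids loading a large wordlist into Python just to count lines
        out = subprocess.run(['wc', '-l', str(path)], capture_output=True, text=True, check=True)
        line_count = int(out.stdout.split()[0])
        return max(min_timeout, int(line_count * base_per_word))

    # e.g. a 1000-line wordlist at 1.0 s/word gives 1000 s; a 30-line wordlist hits the 60 s floor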
def _get_max_workers(tool_config: dict, default: int = DEFAULT_MAX_WORKERS) -> int: def _get_max_workers(tool_config: dict, default: int = DEFAULT_MAX_WORKERS) -> int:
""" """从单个工具配置中获取 max_workers 参数"""
从单个工具配置中获取 max_workers 参数
Args:
tool_config: 单个工具的配置字典,如 {'max_workers': 10, 'threads': 5, ...}
default: 默认值,默认为 5
Returns:
int: max_workers 值
"""
if not isinstance(tool_config, dict): if not isinstance(tool_config, dict):
return default return default
# 支持 max_workers 和 max-workersYAML 中划线会被转换)
max_workers = tool_config.get('max_workers') or tool_config.get('max-workers') max_workers = tool_config.get('max_workers') or tool_config.get('max-workers')
if max_workers is not None and isinstance(max_workers, int) and max_workers > 0: if isinstance(max_workers, int) and max_workers > 0:
return max_workers return max_workers
return default return default
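Per-tool concurrency is read from that tool's own config block and accepts both the max_workers and max-workers spellings (YAML keys written with a dash). A minimal equivalent lookup:

    DEFAULT_MAX_WORKERS = 5  # assumed default; the module defines its own constant

    def get_max_workers(tool_config: dict, default: int = DEFAULT_MAX_WORKERS) -> int:
        """Accept either spelling and fall back to the default for missing or invalid values."""
        if not isinstance(tool_config, dict):
            return default
        value = tool_config.get('max_workers') or tool_config.get('max-workers')
        return value if isinstance(value, int) and value > 0 else default

    # get_max_workers({'max-workers': 10}) -> 10; get_max_workers({}) -> 5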
def _export_site_urls(
target_id: int,
directory_scan_dir: Path
def _export_site_urls(target_id: int, target_name: str, directory_scan_dir: Path) -> tuple[str, int]: ) -> Tuple[str, int]:
""" """
导出目标下的所有站点 URL 到文件(支持懒加载) 导出目标下的所有站点 URL 到文件
Args: Args:
target_id: 目标 ID target_id: 目标 ID
target_name: 目标名称(用于懒加载创建默认站点)
directory_scan_dir: 目录扫描目录 directory_scan_dir: 目录扫描目录
Returns: Returns:
tuple: (sites_file, site_count) tuple: (sites_file, site_count)
Raises:
ValueError: 站点数量为 0
""" """
logger.info("Step 1: 导出目标的所有站点 URL") logger.info("Step 1: 导出目标的所有站点 URL")
@@ -164,11 +125,10 @@ def _export_site_urls(target_id: int, target_name: str, directory_scan_dir: Path
export_result = export_sites_task( export_result = export_sites_task(
target_id=target_id, target_id=target_id,
output_file=sites_file, output_file=sites_file,
batch_size=1000 # 每次读取 1000 条,优化内存占用 batch_size=1000
) )
site_count = export_result['total_count'] site_count = export_result['total_count']
logger.info( logger.info(
"✓ 站点 URL 导出完成 - 文件: %s, 数量: %d", "✓ 站点 URL 导出完成 - 文件: %s, 数量: %d",
export_result['output_file'], export_result['output_file'],
@@ -177,264 +137,50 @@ def _export_site_urls(target_id: int, target_name: str, directory_scan_dir: Path
if site_count == 0: if site_count == 0:
logger.warning("目标下没有站点,无法执行目录扫描") logger.warning("目标下没有站点,无法执行目录扫描")
# 不抛出异常,由上层决定如何处理
# raise ValueError("目标下没有站点,无法执行目录扫描")
return export_result['output_file'], site_count return export_result['output_file'], site_count
def _run_scans_sequentially( def _generate_log_filename(
enabled_tools: dict, tool_name: str,
sites_file: str, site_url: str,
directory_scan_dir: Path, directory_scan_dir: Path
scan_id: int, ) -> Path:
target_id: int, """生成唯一的日志文件名(使用 URL 的 hash 确保并发时不会冲突)"""
site_count: int, url_hash = hashlib.md5(
target_name: str site_url.encode(),
) -> tuple[int, int, list]: usedforsecurity=False
""" ).hexdigest()[:8]
串行执行目录扫描任务(支持多工具)- 已废弃,保留用于兼容
Args:
enabled_tools: 启用的工具配置字典
sites_file: 站点文件路径
directory_scan_dir: 目录扫描目录
scan_id: 扫描任务 ID
target_id: 目标 ID
site_count: 站点数量
target_name: 目标名称(用于错误日志)
Returns:
tuple: (total_directories, processed_sites, failed_sites)
"""
# 读取站点列表
sites = []
with open(sites_file, 'r', encoding='utf-8') as f:
for line in f:
site_url = line.strip()
if site_url:
sites.append(site_url)
logger.info("准备扫描 %d 个站点,使用工具: %s", len(sites), ', '.join(enabled_tools.keys()))
total_directories = 0
processed_sites_set = set() # 使用 set 避免重复计数
failed_sites = []
# 遍历每个工具
for tool_name, tool_config in enabled_tools.items():
logger.info("="*60)
logger.info("使用工具: %s", tool_name)
logger.info("="*60)
# 如果配置了 wordlist_name则先确保本地存在对应的字典文件含 hash 校验)
wordlist_name = tool_config.get('wordlist_name')
if wordlist_name:
try:
local_wordlist_path = ensure_wordlist_local(wordlist_name)
tool_config['wordlist'] = local_wordlist_path
except Exception as exc:
logger.error("为工具 %s 准备字典失败: %s", tool_name, exc)
# 当前工具无法执行,将所有站点视为失败,继续下一个工具
failed_sites.extend(sites)
continue
# 逐个站点执行扫描
for idx, site_url in enumerate(sites, 1):
logger.info(
"[%d/%d] 开始扫描站点: %s (工具: %s)",
idx, len(sites), site_url, tool_name
)
# 使用统一的命令构建器
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='directory_scan',
command_params={
'url': site_url
},
tool_config=tool_config
)
except Exception as e:
logger.error(
"✗ [%d/%d] 构建 %s 命令失败: %s - 站点: %s",
idx, len(sites), tool_name, e, site_url
)
failed_sites.append(site_url)
continue
# 单个站点超时:从配置中获取(支持 'auto' 动态计算)
# ffuf 逐个站点扫描timeout 就是单个站点的超时时间
site_timeout = tool_config.get('timeout', 300)
if site_timeout == 'auto':
# 动态计算超时时间(基于字典行数)
site_timeout = calculate_directory_scan_timeout(tool_config)
logger.info(f"✓ 工具 {tool_name} 动态计算 timeout: {site_timeout}")
# 生成日志文件路径
from datetime import datetime
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = directory_scan_dir / f"{tool_name}_{timestamp}_{idx}.log"
try:
# 直接调用 task串行执行
result = run_and_stream_save_directories_task(
cmd=command,
tool_name=tool_name, # 新增:工具名称
scan_id=scan_id,
target_id=target_id,
site_url=site_url,
cwd=str(directory_scan_dir),
shell=True,
batch_size=1000,
timeout=site_timeout,
log_file=str(log_file) # 新增:日志文件路径
)
total_directories += result.get('created_directories', 0)
processed_sites_set.add(site_url) # 使用 set 记录成功的站点
logger.info(
"✓ [%d/%d] 站点扫描完成: %s - 发现 %d 个目录",
idx, len(sites), site_url,
result.get('created_directories', 0)
)
except subprocess.TimeoutExpired as exc:
# 超时异常单独处理
failed_sites.append(site_url)
logger.warning(
"⚠️ [%d/%d] 站点扫描超时: %s - 超时配置: %d\n"
"注意:超时前已解析的目录数据已保存到数据库,但扫描未完全完成。",
idx, len(sites), site_url, site_timeout
)
except Exception as exc:
# 其他异常
failed_sites.append(site_url)
logger.error(
"✗ [%d/%d] 站点扫描失败: %s - 错误: %s",
idx, len(sites), site_url, exc
)
# 每 10 个站点输出进度
if idx % 10 == 0:
logger.info(
"进度: %d/%d (%.1f%%) - 已发现 %d 个目录",
idx, len(sites), idx/len(sites)*100, total_directories
)
# 计算成功和失败的站点数
processed_count = len(processed_sites_set)
if failed_sites:
logger.warning(
"部分站点扫描失败: %d/%d",
len(failed_sites), len(sites)
)
logger.info(
"✓ 串行目录扫描执行完成 - 成功: %d/%d, 失败: %d, 总目录数: %d",
processed_count, len(sites), len(failed_sites), total_directories
)
return total_directories, processed_count, failed_sites
def _generate_log_filename(tool_name: str, site_url: str, directory_scan_dir: Path) -> Path:
"""
生成唯一的日志文件名
使用 URL 的 hash 确保并发时不会冲突
Args:
tool_name: 工具名称
site_url: 站点 URL
directory_scan_dir: 目录扫描目录
Returns:
Path: 日志文件路径
"""
url_hash = hashlib.md5(site_url.encode()).hexdigest()[:8]
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f') timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
return directory_scan_dir / f"{tool_name}_{url_hash}_{timestamp}.log" return directory_scan_dir / f"{tool_name}_{url_hash}_{timestamp}.log"
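Concurrent scans of different sites with the same tool would race on a purely timestamped log name, so the filename now mixes a short URL hash with a microsecond timestamp. A sketch of the same idea (usedforsecurity=False needs Python 3.9+):

    import hashlib
    from datetime import datetime
    from pathlib import Path

    def unique_log_path(tool_name: str, site_url: str, log_dir: Path) -> Path:
        """Combine an 8-char URL hash with a microsecond timestamp so parallel runs never collide."""
        url_hash = hashlib.md5(site_url.encode(), usedforsecurity=False).hexdigest()[:8]
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
        return log_dir / f"{tool_name}_{url_hash}_{timestamp}.log"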
def _run_scans_concurrently( def _prepare_tool_wordlist(tool_name: str, tool_config: dict) -> bool:
enabled_tools: dict, """准备工具的字典文件,返回是否成功"""
sites_file: str,
directory_scan_dir: Path,
scan_id: int,
target_id: int,
site_count: int,
target_name: str
) -> Tuple[int, int, List[str]]:
"""
并发执行目录扫描任务(使用 ThreadPoolTaskRunner
Args:
enabled_tools: 启用的工具配置字典
sites_file: 站点文件路径
directory_scan_dir: 目录扫描目录
scan_id: 扫描任务 ID
target_id: 目标 ID
site_count: 站点数量
target_name: 目标名称(用于错误日志)
Returns:
tuple: (total_directories, processed_sites, failed_sites)
"""
# 读取站点列表
sites: List[str] = []
with open(sites_file, 'r', encoding='utf-8') as f:
for line in f:
site_url = line.strip()
if site_url:
sites.append(site_url)
if not sites:
logger.warning("站点列表为空")
return 0, 0, []
logger.info(
"准备并发扫描 %d 个站点,使用工具: %s",
len(sites), ', '.join(enabled_tools.keys())
)
total_directories = 0
processed_sites_count = 0
failed_sites: List[str] = []
# 遍历每个工具
for tool_name, tool_config in enabled_tools.items():
# 每个工具独立获取 max_workers 配置
max_workers = _get_max_workers(tool_config)
logger.info("="*60)
logger.info("使用工具: %s (并发模式, max_workers=%d)", tool_name, max_workers)
logger.info("="*60)
user_log(scan_id, "directory_scan", f"Running {tool_name}")
# 如果配置了 wordlist_name则先确保本地存在对应的字典文件含 hash 校验)
wordlist_name = tool_config.get('wordlist_name') wordlist_name = tool_config.get('wordlist_name')
if wordlist_name: if not wordlist_name:
return True
try: try:
local_wordlist_path = ensure_wordlist_local(wordlist_name) local_wordlist_path = ensure_wordlist_local(wordlist_name)
tool_config['wordlist'] = local_wordlist_path tool_config['wordlist'] = local_wordlist_path
return True
except Exception as exc: except Exception as exc:
logger.error("为工具 %s 准备字典失败: %s", tool_name, exc) logger.error("为工具 %s 准备字典失败: %s", tool_name, exc)
# 当前工具无法执行,将所有站点视为失败,继续下一个工具 return False
failed_sites.extend(sites)
continue
# 计算超时时间(所有站点共用)
site_timeout = tool_config.get('timeout', 300)
if site_timeout == 'auto':
site_timeout = calculate_directory_scan_timeout(tool_config)
logger.info(f"✓ 工具 {tool_name} 动态计算 timeout: {site_timeout}")
# 准备所有站点的扫描参数 def _build_scan_params(
tool_name: str,
tool_config: dict,
sites: List[str],
directory_scan_dir: Path,
site_timeout: int
) -> Tuple[List[dict], List[str]]:
"""构建所有站点的扫描参数,返回 (scan_params_list, failed_sites)"""
scan_params_list = [] scan_params_list = []
failed_sites = []
for idx, site_url in enumerate(sites, 1): for idx, site_url in enumerate(sites, 1):
try: try:
command = build_scan_command( command = build_scan_command(
@@ -458,30 +204,22 @@ def _run_scans_concurrently(
) )
failed_sites.append(site_url) failed_sites.append(site_url)
if not scan_params_list: return scan_params_list, failed_sites
logger.warning("没有有效的扫描任务")
continue
# ============================================================
# 分批执行策略:控制实际并发的 ffuf 进程数
# ============================================================
total_tasks = len(scan_params_list)
logger.info("开始分批执行 %d 个扫描任务(每批 %d 个)...", total_tasks, max_workers)
# 进度里程碑跟踪 def _execute_batch(
last_progress_percent = 0 batch_params: List[dict],
tool_directories = 0 tool_name: str,
tool_processed = 0 scan_id: int,
target_id: int,
directory_scan_dir: Path,
total_sites: int
) -> Tuple[int, List[str]]:
"""执行一批扫描任务,返回 (directories_found, failed_sites)"""
directories_found = 0
failed_sites = []
batch_num = 0 # 提交任务
for batch_start in range(0, total_tasks, max_workers):
batch_end = min(batch_start + max_workers, total_tasks)
batch_params = scan_params_list[batch_start:batch_end]
batch_num += 1
logger.info("执行第 %d 批任务(%d-%d/%d...", batch_num, batch_start + 1, batch_end, total_tasks)
# 提交当前批次的任务(非阻塞,立即返回 future
futures = [] futures = []
for params in batch_params: for params in batch_params:
future = run_and_stream_save_directories_task.submit( future = run_and_stream_save_directories_task.submit(
@@ -498,54 +236,142 @@ def _run_scans_concurrently(
) )
futures.append((params['idx'], params['site_url'], future)) futures.append((params['idx'], params['site_url'], future))
# 等待当前批次所有任务完成(阻塞,确保本批完成后再启动下一批) # 等待结果
for idx, site_url, future in futures: for idx, site_url, future in futures:
try: try:
result = future.result() # 阻塞等待单个任务完成 result = future.result()
directories_found = result.get('created_directories', 0) dirs_count = result.get('created_directories', 0)
total_directories += directories_found directories_found += dirs_count
tool_directories += directories_found
processed_sites_count += 1
tool_processed += 1
logger.info( logger.info(
"✓ [%d/%d] 站点扫描完成: %s - 发现 %d 个目录", "✓ [%d/%d] 站点扫描完成: %s - 发现 %d 个目录",
idx, len(sites), site_url, directories_found idx, total_sites, site_url, dirs_count
) )
except Exception as exc: except Exception as exc:
failed_sites.append(site_url) failed_sites.append(site_url)
if 'timeout' in str(exc).lower() or isinstance(exc, subprocess.TimeoutExpired): if 'timeout' in str(exc).lower():
logger.warning( logger.warning(
"⚠️ [%d/%d] 站点扫描超时: %s - 错误: %s", "⚠️ [%d/%d] 站点扫描超时: %s - 错误: %s",
idx, len(sites), site_url, exc idx, total_sites, site_url, exc
) )
else: else:
logger.error( logger.error(
"✗ [%d/%d] 站点扫描失败: %s - 错误: %s", "✗ [%d/%d] 站点扫描失败: %s - 错误: %s",
idx, len(sites), site_url, exc idx, total_sites, site_url, exc
) )
return directories_found, failed_sites
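The refactor pulls batching out of the old monolithic loop: each batch is submitted non-blocking with Prefect's task .submit(), then .result() is awaited before the next batch starts, so at most max_workers scanner processes run at once. A self-contained sketch of that pattern, assuming Prefect 3's ThreadPoolTaskRunner, with scan_site standing in for the real streaming task:

    from prefect import flow, task
    from prefect.task_runners import ThreadPoolTaskRunner

    @task
    def scan_site(url: str) -> int:
        # stand-in for run_and_stream_save_directories_task; returns directories found
        return 0

    @flow(task_runner=ThreadPoolTaskRunner(max_workers=5))
    def scan_in_batches(urls: list[str], batch_size: int = 5) -> int:
        total = 0
        for start in range(0, len(urls), batch_size):
            batch = urls[start:start + batch_size]
            futures = [scan_site.submit(u) for u in batch]  # non-blocking: starts the whole batch
            for future in futures:
                total += future.result()  # block until this batch is done before starting the next
        return total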
def _run_scans_concurrently(
enabled_tools: dict,
sites_file: str,
directory_scan_dir: Path,
scan_id: int,
target_id: int,
) -> Tuple[int, int, List[str]]:
"""
并发执行目录扫描任务
Returns:
tuple: (total_directories, processed_sites, failed_sites)
"""
# 读取站点列表
sites: List[str] = []
with open(sites_file, 'r', encoding='utf-8') as f:
sites = [line.strip() for line in f if line.strip()]
if not sites:
logger.warning("站点列表为空")
return 0, 0, []
logger.info(
"准备并发扫描 %d 个站点,使用工具: %s",
len(sites), ', '.join(enabled_tools.keys())
)
total_directories = 0
processed_sites_count = 0
failed_sites: List[str] = []
for tool_name, tool_config in enabled_tools.items():
max_workers = _get_max_workers(tool_config)
logger.info("=" * 60)
logger.info("使用工具: %s (并发模式, max_workers=%d)", tool_name, max_workers)
logger.info("=" * 60)
user_log(scan_id, "directory_scan", f"Running {tool_name}")
# 准备字典文件
if not _prepare_tool_wordlist(tool_name, tool_config):
failed_sites.extend(sites)
continue
# 计算超时时间
site_timeout = tool_config.get('timeout', 300)
if site_timeout == 'auto':
site_timeout = calculate_directory_scan_timeout(tool_config)
logger.info("✓ 工具 %s 动态计算 timeout: %d", tool_name, site_timeout)
# 构建扫描参数
scan_params_list, build_failed = _build_scan_params(
tool_name, tool_config, sites, directory_scan_dir, site_timeout
)
failed_sites.extend(build_failed)
if not scan_params_list:
logger.warning("没有有效的扫描任务")
continue
# 分批执行
total_tasks = len(scan_params_list)
logger.info("开始分批执行 %d 个扫描任务(每批 %d 个)...", total_tasks, max_workers)
last_progress_percent = 0
tool_directories = 0
tool_processed = 0
for batch_start in range(0, total_tasks, max_workers):
batch_end = min(batch_start + max_workers, total_tasks)
batch_params = scan_params_list[batch_start:batch_end]
batch_num = batch_start // max_workers + 1
logger.info(
"执行第 %d 批任务(%d-%d/%d...",
batch_num, batch_start + 1, batch_end, total_tasks
)
dirs_found, batch_failed = _execute_batch(
batch_params, tool_name, scan_id, target_id,
directory_scan_dir, len(sites)
)
total_directories += dirs_found
tool_directories += dirs_found
tool_processed += len(batch_params) - len(batch_failed)
processed_sites_count += len(batch_params) - len(batch_failed)
failed_sites.extend(batch_failed)
# 进度里程碑:每 20% 输出一次 # 进度里程碑:每 20% 输出一次
current_progress = int((batch_end / total_tasks) * 100) current_progress = int((batch_end / total_tasks) * 100)
if current_progress >= last_progress_percent + 20: if current_progress >= last_progress_percent + 20:
user_log(scan_id, "directory_scan", f"Progress: {batch_end}/{total_tasks} sites scanned") user_log(
scan_id, "directory_scan",
f"Progress: {batch_end}/{total_tasks} sites scanned"
)
last_progress_percent = (current_progress // 20) * 20 last_progress_percent = (current_progress // 20) * 20
# 工具完成日志(开发者日志 + 用户日志)
logger.info( logger.info(
"✓ 工具 %s 执行完成 - 已处理站点: %d/%d, 发现目录: %d", "✓ 工具 %s 执行完成 - 已处理站点: %d/%d, 发现目录: %d",
tool_name, tool_processed, total_tasks, tool_directories tool_name, tool_processed, total_tasks, tool_directories
) )
user_log(scan_id, "directory_scan", f"{tool_name} completed: found {tool_directories} directories") user_log(
scan_id, "directory_scan",
# 输出汇总信息 f"{tool_name} completed: found {tool_directories} directories"
if failed_sites:
logger.warning(
"部分站点扫描失败: %d/%d",
len(failed_sites), len(sites)
) )
if failed_sites:
logger.warning("部分站点扫描失败: %d/%d", len(failed_sites), len(sites))
logger.info( logger.info(
"✓ 并发目录扫描执行完成 - 成功: %d/%d, 失败: %d, 总目录数: %d", "✓ 并发目录扫描执行完成 - 成功: %d/%d, 失败: %d, 总目录数: %d",
processed_sites_count, len(sites), len(failed_sites), total_directories processed_sites_count, len(sites), len(failed_sites), total_directories
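Progress is now reported to the user log only at 20 % milestones rather than per batch, and the recorded percentage is floored to the nearest multiple of 20 so one large jump cannot log twice. The bookkeeping in isolation:

    def progress_milestone(done: int, total: int, last_percent: int, step: int = 20) -> tuple[bool, int]:
        """Return (should_log, new_last_percent); log only when crossing the next step% boundary."""
        current = int(done / total * 100)
        if current >= last_percent + step:
            return True, (current // step) * step
        return False, last_percent

    # progress_milestone(12, 50, 0) -> (True, 20); progress_milestone(14, 50, 20) -> (False, 20)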
@@ -576,21 +402,6 @@ def directory_scan_flow(
2. 对每个站点 URL 执行目录扫描(支持 ffuf 等工具) 2. 对每个站点 URL 执行目录扫描(支持 ffuf 等工具)
3. 流式保存扫描结果到数据库 Directory 表 3. 流式保存扫描结果到数据库 Directory 表
工作流程:
Step 0: 创建工作目录
Step 1: 导出站点 URL 列表到文件(供扫描工具使用)
Step 2: 验证工具配置
Step 3: 并发执行扫描工具并实时保存结果(使用 ThreadPoolTaskRunner
ffuf 输出字段:
- url: 发现的目录/文件 URL
- length: 响应内容长度
- status: HTTP 状态码
- words: 响应内容单词数
- lines: 响应内容行数
- content_type: 内容类型
- duration: 请求耗时
Args: Args:
scan_id: 扫描任务 ID scan_id: 扫描任务 ID
target_name: 目标名称 target_name: 目标名称
@@ -599,33 +410,15 @@ def directory_scan_flow(
enabled_tools: 启用的工具配置字典 enabled_tools: 启用的工具配置字典
Returns: Returns:
dict: { dict: 扫描结果
'success': bool,
'scan_id': int,
'target': str,
'scan_workspace_dir': str,
'sites_file': str,
'site_count': int,
'total_directories': int, # 发现的总目录数
'processed_sites': int, # 成功处理的站点数
'failed_sites_count': int, # 失败的站点数
'executed_tasks': list
}
Raises:
ValueError: 参数错误
RuntimeError: 执行失败
""" """
try: try:
logger.info( wait_for_system_load(context="directory_scan_flow")
"="*60 + "\n" +
"开始目录扫描\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
)
logger.info(
"开始目录扫描 - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
)
user_log(scan_id, "directory_scan", "Starting directory scan") user_log(scan_id, "directory_scan", "Starting directory scan")
# 参数验证 # 参数验证
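Each flow now opens with wait_for_system_load(context=...). Only the call sites appear in this diff; a rough gate of this kind could look like the following psutil-based sketch, which is an assumption for illustration and not the project's implementation:

    import logging
    import time

    import psutil  # assumed dependency; the real helper may measure load differently

    logger = logging.getLogger(__name__)

    def wait_for_system_load_sketch(context: str, cpu_limit: float = 85.0, mem_limit: float = 90.0,
                                    interval: int = 30, max_wait: int = 1800) -> None:
        """Block until CPU and memory usage drop below the limits, or max_wait expires."""
        waited = 0
        while waited < max_wait:
            cpu = psutil.cpu_percent(interval=1)
            mem = psutil.virtual_memory().percent
            if cpu < cpu_limit and mem < mem_limit:
                return
            logger.info("%s: system busy (cpu=%.0f%%, mem=%.0f%%), waiting %ds", context, cpu, mem, interval)
            time.sleep(interval)
            waited += interval
        logger.warning("%s: load still high after %ds, continuing anyway", context, max_wait)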
@@ -644,8 +437,8 @@ def directory_scan_flow(
from apps.scan.utils import setup_scan_directory from apps.scan.utils import setup_scan_directory
directory_scan_dir = setup_scan_directory(scan_workspace_dir, 'directory_scan') directory_scan_dir = setup_scan_directory(scan_workspace_dir, 'directory_scan')
# Step 1: 导出站点 URL(支持懒加载) # Step 1: 导出站点 URL
sites_file, site_count = _export_site_urls(target_id, target_name, directory_scan_dir) sites_file, site_count = _export_site_urls(target_id, directory_scan_dir)
if site_count == 0: if site_count == 0:
logger.warning("跳过目录扫描:没有站点可扫描 - Scan ID: %s", scan_id) logger.warning("跳过目录扫描:没有站点可扫描 - Scan ID: %s", scan_id)
@@ -665,13 +458,13 @@ def directory_scan_flow(
# Step 2: 工具配置信息 # Step 2: 工具配置信息
logger.info("Step 2: 工具配置信息") logger.info("Step 2: 工具配置信息")
tool_info = [] tool_info = [
for tool_name, tool_config in enabled_tools.items(): f"{name}(max_workers={_get_max_workers(cfg)})"
mw = _get_max_workers(tool_config) for name, cfg in enabled_tools.items()
tool_info.append(f"{tool_name}(max_workers={mw})") ]
logger.info("✓ 启用工具: %s", ', '.join(tool_info)) logger.info("✓ 启用工具: %s", ', '.join(tool_info))
# Step 3: 并发执行扫描工具并实时保存结果 # Step 3: 并发执行扫描
logger.info("Step 3: 并发执行扫描工具并实时保存结果") logger.info("Step 3: 并发执行扫描工具并实时保存结果")
total_directories, processed_sites, failed_sites = _run_scans_concurrently( total_directories, processed_sites, failed_sites = _run_scans_concurrently(
enabled_tools=enabled_tools, enabled_tools=enabled_tools,
@@ -679,18 +472,19 @@ def directory_scan_flow(
directory_scan_dir=directory_scan_dir, directory_scan_dir=directory_scan_dir,
scan_id=scan_id, scan_id=scan_id,
target_id=target_id, target_id=target_id,
site_count=site_count,
target_name=target_name
) )
# 检查是否所有站点都失败
if processed_sites == 0 and site_count > 0: if processed_sites == 0 and site_count > 0:
logger.warning("所有站点扫描均失败 - 总站点数: %d, 失败数: %d", site_count, len(failed_sites)) logger.warning(
# 不抛出异常,让扫描继续 "所有站点扫描均失败 - 总站点数: %d, 失败数: %d",
site_count, len(failed_sites)
)
# 记录 Flow 完成
logger.info("✓ 目录扫描完成 - 发现目录: %d", total_directories) logger.info("✓ 目录扫描完成 - 发现目录: %d", total_directories)
user_log(scan_id, "directory_scan", f"directory_scan completed: found {total_directories} directories") user_log(
scan_id, "directory_scan",
f"directory_scan completed: found {total_directories} directories"
)
return { return {
'success': True, 'success': True,


@@ -10,26 +10,22 @@
- 流式处理输出,批量更新数据库 - 流式处理输出,批量更新数据库
""" """
# Django 环境初始化(导入即生效)
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging import logging
import os
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from prefect import flow from prefect import flow
from apps.scan.handlers.scan_flow_handlers import ( from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed, on_scan_flow_completed,
on_scan_flow_failed, on_scan_flow_failed,
on_scan_flow_running,
) )
from apps.scan.tasks.fingerprint_detect import ( from apps.scan.tasks.fingerprint_detect import (
export_urls_for_fingerprint_task, export_urls_for_fingerprint_task,
run_xingfinger_and_stream_update_tech_task, run_xingfinger_and_stream_update_tech_task,
) )
from apps.scan.utils import build_scan_command, user_log from apps.scan.utils import build_scan_command, user_log, wait_for_system_load
from apps.scan.utils.fingerprint_helpers import get_fingerprint_paths from apps.scan.utils.fingerprint_helpers import get_fingerprint_paths
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -44,8 +40,7 @@ def calculate_fingerprint_detect_timeout(
根据 URL 数量计算超时时间 根据 URL 数量计算超时时间
公式:超时时间 = URL 数量 × 每 URL 基础时间 公式:超时时间 = URL 数量 × 每 URL 基础时间
最小值300秒 最小值300秒,无上限
无上限
Args: Args:
url_count: URL 数量 url_count: URL 数量
@@ -54,10 +49,8 @@ def calculate_fingerprint_detect_timeout(
Returns: Returns:
int: 计算出的超时时间(秒) int: 计算出的超时时间(秒)
""" """
timeout = int(url_count * base_per_url) return max(min_timeout, int(url_count * base_per_url))
return max(min_timeout, timeout)
@@ -90,7 +83,6 @@ def _export_urls(
) )
total_count = export_result['total_count'] total_count = export_result['total_count']
logger.info( logger.info(
"✓ URL 导出完成 - 文件: %s, 数量: %d", "✓ URL 导出完成 - 文件: %s, 数量: %d",
export_result['output_file'], export_result['output_file'],
@@ -146,13 +138,11 @@ def _run_fingerprint_detect(
command = build_scan_command( command = build_scan_command(
tool_name=tool_name, tool_name=tool_name,
scan_type='fingerprint_detect', scan_type='fingerprint_detect',
command_params={ command_params={'urls_file': urls_file},
'urls_file': urls_file
},
tool_config=tool_config_with_paths tool_config=tool_config_with_paths
) )
except Exception as e: except Exception as e:
reason = f"命令构建失败: {str(e)}" reason = f"命令构建失败: {e}"
logger.error("构建 %s 命令失败: %s", tool_name, e) logger.error("构建 %s 命令失败: %s", tool_name, e)
failed_tools.append({'tool': tool_name, 'reason': reason}) failed_tools.append({'tool': tool_name, 'reason': reason})
continue continue
@@ -199,7 +189,10 @@ def _run_fingerprint_detect(
tool_updated, tool_updated,
result.get('not_found_count', 0) result.get('not_found_count', 0)
) )
user_log(scan_id, "fingerprint_detect", f"{tool_name} completed: identified {tool_updated} fingerprints") user_log(
scan_id, "fingerprint_detect",
f"{tool_name} completed: identified {tool_updated} fingerprints"
)
except Exception as exc: except Exception as exc:
reason = str(exc) reason = str(exc)
@@ -252,31 +245,16 @@ def fingerprint_detect_flow(
enabled_tools: 启用的工具配置xingfinger enabled_tools: 启用的工具配置xingfinger
Returns: Returns:
dict: { dict: 扫描结果
'success': bool,
'scan_id': int,
'target': str,
'scan_workspace_dir': str,
'urls_file': str,
'url_count': int,
'processed_records': int,
'updated_count': int,
'created_count': int,
'snapshot_count': int,
'executed_tasks': list,
'tool_stats': dict
}
""" """
try: try:
logger.info( # 负载检查:等待系统资源充足
"="*60 + "\n" + wait_for_system_load(context="fingerprint_detect_flow")
"开始指纹识别\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
)
logger.info(
"开始指纹识别 - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
)
user_log(scan_id, "fingerprint_detect", "Starting fingerprint detection") user_log(scan_id, "fingerprint_detect", "Starting fingerprint detection")
# 参数验证 # 参数验证
@@ -302,27 +280,7 @@ def fingerprint_detect_flow(
if url_count == 0: if url_count == 0:
logger.warning("跳过指纹识别:没有 URL 可扫描 - Scan ID: %s", scan_id) logger.warning("跳过指纹识别:没有 URL 可扫描 - Scan ID: %s", scan_id)
user_log(scan_id, "fingerprint_detect", "Skipped: no URLs to scan", "warning") user_log(scan_id, "fingerprint_detect", "Skipped: no URLs to scan", "warning")
return { return _build_empty_result(scan_id, target_name, scan_workspace_dir, urls_file)
'success': True,
'scan_id': scan_id,
'target': target_name,
'scan_workspace_dir': scan_workspace_dir,
'urls_file': urls_file,
'url_count': 0,
'processed_records': 0,
'updated_count': 0,
'created_count': 0,
'snapshot_count': 0,
'executed_tasks': ['export_urls_for_fingerprint'],
'tool_stats': {
'total': 0,
'successful': 0,
'failed': 0,
'successful_tools': [],
'failed_tools': [],
'details': {}
}
}
# Step 2: 工具配置信息 # Step 2: 工具配置信息
logger.info("Step 2: 工具配置信息") logger.info("Step 2: 工具配置信息")
@@ -342,20 +300,33 @@ def fingerprint_detect_flow(
# 动态生成已执行的任务列表 # 动态生成已执行的任务列表
executed_tasks = ['export_urls_for_fingerprint'] executed_tasks = ['export_urls_for_fingerprint']
executed_tasks.extend([f'run_xingfinger ({tool})' for tool in tool_stats.keys()]) executed_tasks.extend([f'run_xingfinger ({tool})' for tool in tool_stats])
# 汇总所有工具的结果 # 汇总所有工具的结果
total_processed = sum(stats['result'].get('processed_records', 0) for stats in tool_stats.values()) total_processed = sum(
total_updated = sum(stats['result'].get('updated_count', 0) for stats in tool_stats.values()) stats['result'].get('processed_records', 0) for stats in tool_stats.values()
total_created = sum(stats['result'].get('created_count', 0) for stats in tool_stats.values()) )
total_snapshots = sum(stats['result'].get('snapshot_count', 0) for stats in tool_stats.values()) total_updated = sum(
stats['result'].get('updated_count', 0) for stats in tool_stats.values()
)
total_created = sum(
stats['result'].get('created_count', 0) for stats in tool_stats.values()
)
total_snapshots = sum(
stats['result'].get('snapshot_count', 0) for stats in tool_stats.values()
)
# 记录 Flow 完成 # 记录 Flow 完成
logger.info("✓ 指纹识别完成 - 识别指纹: %d", total_updated) logger.info("✓ 指纹识别完成 - 识别指纹: %d", total_updated)
user_log(scan_id, "fingerprint_detect", f"fingerprint_detect completed: identified {total_updated} fingerprints") user_log(
scan_id, "fingerprint_detect",
f"fingerprint_detect completed: identified {total_updated} fingerprints"
)
successful_tools = [name for name in enabled_tools.keys() successful_tools = [
if name not in [f['tool'] for f in failed_tools]] name for name in enabled_tools
if name not in [f['tool'] for f in failed_tools]
]
return { return {
'success': True, 'success': True,
@@ -388,3 +359,33 @@ def fingerprint_detect_flow(
except Exception as e: except Exception as e:
logger.exception("指纹识别失败: %s", e) logger.exception("指纹识别失败: %s", e)
raise raise
def _build_empty_result(
scan_id: int,
target_name: str,
scan_workspace_dir: str,
urls_file: str
) -> dict:
"""构建空结果(无 URL 可扫描时)"""
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'scan_workspace_dir': scan_workspace_dir,
'urls_file': urls_file,
'url_count': 0,
'processed_records': 0,
'updated_count': 0,
'created_count': 0,
'snapshot_count': 0,
'executed_tasks': ['export_urls_for_fingerprint'],
'tool_stats': {
'total': 0,
'successful': 0,
'failed': 0,
'successful_tools': [],
'failed_tools': [],
'details': {}
}
}


@@ -10,25 +10,23 @@
- 配置由 YAML 解析 - 配置由 YAML 解析
""" """
# Django 环境初始化(导入即生效)
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging import logging
import os
import subprocess import subprocess
from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Callable
from prefect import flow from prefect import flow
from apps.scan.tasks.port_scan import (
export_hosts_task,
run_and_stream_save_ports_task
)
from apps.scan.handlers.scan_flow_handlers import ( from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed, on_scan_flow_completed,
on_scan_flow_failed, on_scan_flow_failed,
on_scan_flow_running,
) )
from apps.scan.utils import config_parser, build_scan_command, user_log from apps.scan.tasks.port_scan import (
export_hosts_task,
run_and_stream_save_ports_task,
)
from apps.scan.utils import build_scan_command, user_log, wait_for_system_load
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -42,7 +40,7 @@ def calculate_port_scan_timeout(
根据目标数量和端口数量计算超时时间 根据目标数量和端口数量计算超时时间
计算公式:超时时间 = 目标数 × 端口数 × base_per_pair 计算公式:超时时间 = 目标数 × 端口数 × base_per_pair
超时范围60秒 ~ 2天172800秒 超时范围60秒 ~ 无上限
Args: Args:
tool_config: 工具配置字典包含端口配置ports, top-ports等 tool_config: 工具配置字典包含端口配置ports, top-ports等
@@ -50,18 +48,9 @@ def calculate_port_scan_timeout(
base_per_pair: 每个"端口-目标对"的基础时间(秒),默认 0.5秒 base_per_pair: 每个"端口-目标对"的基础时间(秒),默认 0.5秒
Returns: Returns:
int: 计算出的超时时间(秒),范围60 ~ 172800 int: 计算出的超时时间(秒),最小 60 秒
Example:
# 100个目标 × 100个端口 × 0.5秒 = 5000秒
# 10个目标 × 1000个端口 × 0.5秒 = 5000秒
timeout = calculate_port_scan_timeout(
tool_config={'top-ports': 100},
file_path='/path/to/domains.txt'
)
""" """
try: try:
# 1. 统计目标数量
result = subprocess.run( result = subprocess.run(
['wc', '-l', file_path], ['wc', '-l', file_path],
capture_output=True, capture_output=True,
@@ -69,30 +58,18 @@ def calculate_port_scan_timeout(
check=True check=True
) )
target_count = int(result.stdout.strip().split()[0]) target_count = int(result.stdout.strip().split()[0])
# 2. 解析端口数量
port_count = _parse_port_count(tool_config) port_count = _parse_port_count(tool_config)
# 3. 计算超时时间
# 总工作量 = 目标数 × 端口数
total_work = target_count * port_count total_work = target_count * port_count
timeout = int(total_work * base_per_pair) timeout = max(60, int(total_work * base_per_pair))
# 4. 设置合理的下限(不再设置上限)
min_timeout = 60 # 最小 60 秒
timeout = max(min_timeout, timeout)
logger.info( logger.info(
f"计算端口扫描 timeout - " "计算端口扫描 timeout - 目标数: %d, 端口数: %d, 总工作量: %d, 超时: %d",
f"目标数: {target_count}, " target_count, port_count, total_work, timeout
f"端口数: {port_count}, "
f"总工作量: {total_work}, "
f"超时: {timeout}"
) )
return timeout return timeout
except Exception as e: except Exception as e:
logger.warning(f"计算 timeout 失败: {e},使用默认值 600秒") logger.warning("计算 timeout 失败: %s,使用默认值 600秒", e)
return 600 return 600
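For port scans the timeout scales with the full workload, targets × ports × base_per_pair, again with only a lower clamp. The arithmetic in isolation, using the worked numbers from the removed docstring:

    def estimate_portscan_timeout(target_count: int, port_count: int,
                                  base_per_pair: float = 0.5, min_timeout: int = 60) -> int:
        """Timeout = targets x ports x seconds-per-pair, floored at min_timeout."""
        return max(min_timeout, int(target_count * port_count * base_per_pair))

    # 100 targets x 100 ports x 0.5 s = 5000 s; 10 targets x 1000 ports x 0.5 s = 5000 s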
@@ -113,44 +90,42 @@ def _parse_port_count(tool_config: dict) -> int:
Returns: Returns:
int: 端口数量 int: 端口数量
""" """
# 1. 检查 top-ports 配置 # 检查 top-ports 配置
if 'top-ports' in tool_config: if 'top-ports' in tool_config:
top_ports = tool_config['top-ports'] top_ports = tool_config['top-ports']
if isinstance(top_ports, int) and top_ports > 0: if isinstance(top_ports, int) and top_ports > 0:
return top_ports return top_ports
logger.warning(f"top-ports 配置无效: {top_ports},使用默认值") logger.warning("top-ports 配置无效: %s,使用默认值", top_ports)
# 2. 检查 ports 配置 # 检查 ports 配置
if 'ports' in tool_config: if 'ports' in tool_config:
ports_str = str(tool_config['ports']).strip() ports_str = str(tool_config['ports']).strip()
# 2.1 逗号分隔的端口列表80,443,8080 # 逗号分隔的端口列表80,443,8080
if ',' in ports_str: if ',' in ports_str:
port_list = [p.strip() for p in ports_str.split(',') if p.strip()] return len([p.strip() for p in ports_str.split(',') if p.strip()])
return len(port_list)
# 2.2 端口范围1-1000 # 端口范围1-1000
if '-' in ports_str: if '-' in ports_str:
try: try:
start, end = ports_str.split('-', 1) start, end = ports_str.split('-', 1)
start_port = int(start.strip()) start_port = int(start.strip())
end_port = int(end.strip()) end_port = int(end.strip())
if 1 <= start_port <= end_port <= 65535: if 1 <= start_port <= end_port <= 65535:
return end_port - start_port + 1 return end_port - start_port + 1
logger.warning(f"端口范围无效: {ports_str},使用默认值") logger.warning("端口范围无效: %s,使用默认值", ports_str)
except ValueError: except ValueError:
logger.warning(f"端口范围解析失败: {ports_str},使用默认值") logger.warning("端口范围解析失败: %s,使用默认值", ports_str)
# 2.3 单个端口 # 单个端口
try: try:
port = int(ports_str) port = int(ports_str)
if 1 <= port <= 65535: if 1 <= port <= 65535:
return 1 return 1
except ValueError: except ValueError:
logger.warning(f"端口配置解析失败: {ports_str},使用默认值") logger.warning("端口配置解析失败: %s,使用默认值", ports_str)
# 3. 默认值naabu 默认扫描 top 100 端口 # 默认值naabu 默认扫描 top 100 端口
return 100 return 100
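The port count feeding that formula comes from the naabu-style config: top-ports wins, then a comma-separated list, a range, or a single port, with 100 as the fallback (naabu's default top ports). A compact sketch with the validation and warnings left out:

    def count_ports(tool_config: dict) -> int:
        """top-ports > comma list > range > single port > default 100 (error handling omitted)."""
        top = tool_config.get('top-ports')
        if isinstance(top, int) and top > 0:
            return top
        ports = str(tool_config.get('ports', '')).strip()
        if ',' in ports:
            return len([p for p in ports.split(',') if p.strip()])
        if '-' in ports:
            start, end = ports.split('-', 1)
            return int(end) - int(start) + 1
        return 1 if ports.isdigit() else 100

    # count_ports({'top-ports': 150}) -> 150; count_ports({'ports': '1-1000'}) -> 1000; count_ports({}) -> 100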
@@ -179,7 +154,6 @@ def _export_hosts(target_id: int, port_scan_dir: Path) -> tuple[str, int, str]:
export_result = export_hosts_task( export_result = export_hosts_task(
target_id=target_id, target_id=target_id,
output_file=hosts_file, output_file=hosts_file,
batch_size=1000 # 每次读取 1000 条,优化内存占用
) )
host_count = export_result['total_count'] host_count = export_result['total_count']
@@ -187,9 +161,7 @@ def _export_hosts(target_id: int, port_scan_dir: Path) -> tuple[str, int, str]:
logger.info( logger.info(
"✓ 主机列表导出完成 - 类型: %s, 文件: %s, 数量: %d", "✓ 主机列表导出完成 - 类型: %s, 文件: %s, 数量: %d",
target_type, target_type, export_result['output_file'], host_count
export_result['output_file'],
host_count
) )
if host_count == 0: if host_count == 0:
@@ -219,67 +191,51 @@ def _run_scans_sequentially(
Returns: Returns:
tuple: (tool_stats, processed_records, successful_tool_names, failed_tools) tuple: (tool_stats, processed_records, successful_tool_names, failed_tools)
注意:端口扫描是流式输出,不生成结果文件
Raises:
RuntimeError: 所有工具均失败
""" """
# ==================== 构建命令并串行执行 ====================
tool_stats = {} tool_stats = {}
processed_records = 0 processed_records = 0
failed_tools = [] # 记录失败的工具(含原因) failed_tools = []
# for循环执行工具按顺序串行运行每个启用的端口扫描工具
for tool_name, tool_config in enabled_tools.items(): for tool_name, tool_config in enabled_tools.items():
# 1. 构建完整命令(变量替换) # 构建命令
try: try:
command = build_scan_command( command = build_scan_command(
tool_name=tool_name, tool_name=tool_name,
scan_type='port_scan', scan_type='port_scan',
command_params={ command_params={'domains_file': domains_file},
'domains_file': domains_file # 对应 {domains_file} tool_config=tool_config
},
tool_config=tool_config #yaml的工具配置
) )
except Exception as e: except Exception as e:
reason = f"命令构建失败: {str(e)}" reason = f"命令构建失败: {e}"
logger.error(f"构建 {tool_name} 命令失败: {e}") logger.error("构建 %s 命令失败: %s", tool_name, e)
failed_tools.append({'tool': tool_name, 'reason': reason}) failed_tools.append({'tool': tool_name, 'reason': reason})
continue continue
# 2. 获取超时时间(支持 'auto' 动态计算) # 获取超时时间
config_timeout = tool_config['timeout'] config_timeout = tool_config['timeout']
if config_timeout == 'auto': if config_timeout == 'auto':
# 动态计算超时时间 config_timeout = calculate_port_scan_timeout(tool_config, str(domains_file))
config_timeout = calculate_port_scan_timeout( logger.info("✓ 工具 %s 动态计算 timeout: %d", tool_name, config_timeout)
tool_config=tool_config,
file_path=str(domains_file)
)
logger.info(f"✓ 工具 {tool_name} 动态计算 timeout: {config_timeout}")
# 2.1 生成日志文件路径 # 生成日志文件路径
from datetime import datetime
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = port_scan_dir / f"{tool_name}_{timestamp}.log" log_file = port_scan_dir / f"{tool_name}_{timestamp}.log"
# 3. 执行扫描任务
logger.info("开始执行 %s 扫描(超时: %d秒)...", tool_name, config_timeout) logger.info("开始执行 %s 扫描(超时: %d秒)...", tool_name, config_timeout)
user_log(scan_id, "port_scan", f"Running {tool_name}: {command}") user_log(scan_id, "port_scan", f"Running {tool_name}: {command}")
# 执行扫描任务
try: try:
# 直接调用 task串行执行
# 注意:端口扫描是流式输出到 stdout不使用 output_file
result = run_and_stream_save_ports_task( result = run_and_stream_save_ports_task(
cmd=command, cmd=command,
tool_name=tool_name, # 工具名称 tool_name=tool_name,
scan_id=scan_id, scan_id=scan_id,
target_id=target_id, target_id=target_id,
cwd=str(port_scan_dir), cwd=str(port_scan_dir),
shell=True, shell=True,
batch_size=1000, batch_size=1000,
timeout=config_timeout, timeout=config_timeout,
log_file=str(log_file) # 新增:日志文件路径 log_file=str(log_file)
) )
tool_stats[tool_name] = { tool_stats[tool_name] = {
@@ -289,15 +245,10 @@ def _run_scans_sequentially(
} }
tool_records = result.get('processed_records', 0) tool_records = result.get('processed_records', 0)
processed_records += tool_records processed_records += tool_records
logger.info( logger.info("✓ 工具 %s 流式处理完成 - 记录数: %d", tool_name, tool_records)
"✓ 工具 %s 流式处理完成 - 记录数: %d",
tool_name, tool_records
)
user_log(scan_id, "port_scan", f"{tool_name} completed: found {tool_records} ports") user_log(scan_id, "port_scan", f"{tool_name} completed: found {tool_records} ports")
except subprocess.TimeoutExpired as exc: except subprocess.TimeoutExpired:
# 超时异常单独处理
# 注意:流式处理任务超时时,已解析的数据已保存到数据库
reason = f"timeout after {config_timeout}s" reason = f"timeout after {config_timeout}s"
failed_tools.append({'tool': tool_name, 'reason': reason}) failed_tools.append({'tool': tool_name, 'reason': reason})
logger.warning( logger.warning(
@@ -307,7 +258,6 @@ def _run_scans_sequentially(
) )
user_log(scan_id, "port_scan", f"{tool_name} failed: {reason}", "error") user_log(scan_id, "port_scan", f"{tool_name} failed: {reason}", "error")
except Exception as exc: except Exception as exc:
# 其他异常
reason = str(exc) reason = str(exc)
failed_tools.append({'tool': tool_name, 'reason': reason}) failed_tools.append({'tool': tool_name, 'reason': reason})
logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True) logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True)
@@ -322,12 +272,12 @@ def _run_scans_sequentially(
if not tool_stats: if not tool_stats:
error_details = "; ".join([f"{f['tool']}: {f['reason']}" for f in failed_tools]) error_details = "; ".join([f"{f['tool']}: {f['reason']}" for f in failed_tools])
logger.warning("所有端口扫描工具均失败 - 目标: %s, 失败工具: %s", target_name, error_details) logger.warning("所有端口扫描工具均失败 - 目标: %s, 失败工具: %s", target_name, error_details)
# 返回空结果,不抛出异常,让扫描继续
return {}, 0, [], failed_tools return {}, 0, [], failed_tools
# 动态计算成功的工具列表 successful_tool_names = [
successful_tool_names = [name for name in enabled_tools.keys() name for name in enabled_tools
if name not in [f['tool'] for f in failed_tools]] if name not in [f['tool'] for f in failed_tools]
]
logger.info( logger.info(
"✓ 串行端口扫描执行完成 - 成功: %d/%d (成功: %s, 失败: %s)", "✓ 串行端口扫描执行完成 - 成功: %d/%d (成功: %s, 失败: %s)",
@@ -367,7 +317,7 @@ def port_scan_flow(
Step 0: 创建工作目录 Step 0: 创建工作目录
Step 1: 导出域名列表到文件(供扫描工具使用) Step 1: 导出域名列表到文件(供扫描工具使用)
Step 2: 解析配置,获取启用的工具 Step 2: 解析配置,获取启用的工具
Step 3: 串行执行扫描工具,运行端口扫描工具并实时解析输出到数据库(→ HostPortMapping Step 3: 串行执行扫描工具,运行端口扫描工具并实时解析输出到数据库
Args: Args:
scan_id: 扫描任务 ID scan_id: 扫描任务 ID
@@ -377,35 +327,15 @@ def port_scan_flow(
enabled_tools: 启用的工具配置字典 enabled_tools: 启用的工具配置字典
Returns: Returns:
dict: { dict: 扫描结果
'success': bool,
'scan_id': int,
'target': str,
'scan_workspace_dir': str,
'hosts_file': str,
'host_count': int,
'processed_records': int,
'executed_tasks': list,
'tool_stats': {
'total': int, # 总工具数
'successful': int, # 成功工具数
'failed': int, # 失败工具数
'successful_tools': list[str], # 成功工具列表 ['naabu_active']
'failed_tools': list[dict], # 失败工具列表 [{'tool': 'naabu_passive', 'reason': '超时'}]
'details': dict # 详细执行结果(保留向后兼容)
}
}
Raises: Raises:
ValueError: 配置错误 ValueError: 配置错误
RuntimeError: 执行失败 RuntimeError: 执行失败
Note:
端口扫描工具(如 naabu会解析域名获取 IP输出 host + ip + port 三元组。
同一 host 可能对应多个 IPCDN、负载均衡因此使用三元映射表存储。
""" """
try: try:
# 参数验证 wait_for_system_load(context="port_scan_flow")
if scan_id is None: if scan_id is None:
raise ValueError("scan_id 不能为空") raise ValueError("scan_id 不能为空")
if not target_name: if not target_name:
@@ -418,21 +348,16 @@ def port_scan_flow(
raise ValueError("enabled_tools 不能为空") raise ValueError("enabled_tools 不能为空")
logger.info( logger.info(
"="*60 + "\n" + "开始端口扫描 - Scan ID: %s, Target: %s, Workspace: %s",
"开始端口扫描\n" + scan_id, target_name, scan_workspace_dir
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
) )
user_log(scan_id, "port_scan", "Starting port scan") user_log(scan_id, "port_scan", "Starting port scan")
# Step 0: 创建工作目录 # Step 0: 创建工作目录
from apps.scan.utils import setup_scan_directory from apps.scan.utils import setup_scan_directory
port_scan_dir = setup_scan_directory(scan_workspace_dir, 'port_scan') port_scan_dir = setup_scan_directory(scan_workspace_dir, 'port_scan')
# Step 1: 导出主机列表到文件(根据 Target 类型自动决定内容) # Step 1: 导出主机列表
hosts_file, host_count, target_type = _export_hosts(target_id, port_scan_dir) hosts_file, host_count, target_type = _export_hosts(target_id, port_scan_dir)
if host_count == 0: if host_count == 0:
@@ -460,10 +385,7 @@ def port_scan_flow(
# Step 2: 工具配置信息 # Step 2: 工具配置信息
logger.info("Step 2: 工具配置信息") logger.info("Step 2: 工具配置信息")
logger.info( logger.info("✓ 启用工具: %s", ', '.join(enabled_tools.keys()))
"✓ 启用工具: %s",
', '.join(enabled_tools.keys())
)
# Step 3: 串行执行扫描工具 # Step 3: 串行执行扫描工具
logger.info("Step 3: 串行执行扫描工具") logger.info("Step 3: 串行执行扫描工具")
@@ -476,13 +398,11 @@ def port_scan_flow(
target_name=target_name target_name=target_name
) )
# 记录 Flow 完成
logger.info("✓ 端口扫描完成 - 发现端口: %d", processed_records) logger.info("✓ 端口扫描完成 - 发现端口: %d", processed_records)
user_log(scan_id, "port_scan", f"port_scan completed: found {processed_records} ports") user_log(scan_id, "port_scan", f"port_scan completed: found {processed_records} ports")
# 动态生成已执行的任务列表
executed_tasks = ['export_hosts', 'parse_config'] executed_tasks = ['export_hosts', 'parse_config']
executed_tasks.extend([f'run_and_stream_save_ports ({tool})' for tool in tool_stats.keys()]) executed_tasks.extend([f'run_and_stream_save_ports ({tool})' for tool in tool_stats])
return { return {
'success': True, 'success': True,


@@ -11,43 +11,33 @@
2. Provider 模式:使用 TargetProvider 从任意数据源获取 URL 2. Provider 模式:使用 TargetProvider 从任意数据源获取 URL
""" """
# Django 环境初始化
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging import logging
from pathlib import Path
from typing import Optional from typing import Optional
from prefect import flow from prefect import flow
from apps.scan.tasks.screenshot import capture_screenshots_task
from apps.scan.handlers.scan_flow_handlers import ( from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed, on_scan_flow_completed,
on_scan_flow_failed, on_scan_flow_failed,
) on_scan_flow_running,
from apps.scan.utils import user_log
from apps.scan.services.target_export_service import (
get_urls_with_fallback,
DataSource,
) )
from apps.scan.providers import TargetProvider from apps.scan.providers import TargetProvider
from apps.scan.services.target_export_service import DataSource, get_urls_with_fallback
from apps.scan.tasks.screenshot import capture_screenshots_task
from apps.scan.utils import user_log, wait_for_system_load
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# URL 来源到 DataSource 的映射
_SOURCE_MAPPING = {
'websites': DataSource.WEBSITE,
'endpoints': DataSource.ENDPOINT,
}
def _parse_screenshot_config(enabled_tools: dict) -> dict: def _parse_screenshot_config(enabled_tools: dict) -> dict:
""" """解析截图配置"""
解析截图配置
Args:
enabled_tools: 启用的工具配置
Returns:
截图配置字典
"""
# 从 enabled_tools 中获取 playwright 配置
playwright_config = enabled_tools.get('playwright', {}) playwright_config = enabled_tools.get('playwright', {})
return { return {
'concurrency': playwright_config.get('concurrency', 5), 'concurrency': playwright_config.get('concurrency', 5),
'url_sources': playwright_config.get('url_sources', ['websites']) 'url_sources': playwright_config.get('url_sources', ['websites'])
@@ -55,33 +45,54 @@ def _parse_screenshot_config(enabled_tools: dict) -> dict:
def _map_url_sources_to_data_sources(url_sources: list[str]) -> list[str]: def _map_url_sources_to_data_sources(url_sources: list[str]) -> list[str]:
""" """将配置中的 url_sources 映射为 DataSource 常量"""
将配置中的 url_sources 映射为 DataSource 常量
Args:
url_sources: 配置中的来源列表,如 ['websites', 'endpoints']
Returns:
DataSource 常量列表
"""
source_mapping = {
'websites': DataSource.WEBSITE,
'endpoints': DataSource.ENDPOINT,
}
sources = [] sources = []
for source in url_sources: for source in url_sources:
if source in source_mapping: if source in _SOURCE_MAPPING:
sources.append(source_mapping[source]) sources.append(_SOURCE_MAPPING[source])
else: else:
logger.warning("未知的 URL 来源: %s,跳过", source) logger.warning("未知的 URL 来源: %s,跳过", source)
# 添加默认回退(从 subdomain 构造) # 添加默认回退(从 subdomain 构造)
sources.append(DataSource.DEFAULT) sources.append(DataSource.DEFAULT)
return sources return sources
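The screenshot flow hoists the source mapping to a module-level _SOURCE_MAPPING and always appends DataSource.DEFAULT, so unknown or empty configs still fall back to URLs built from subdomains. The shape of that mapping, with string stand-ins for the DataSource constants and the unknown-source warning omitted:

    SOURCE_MAPPING = {'websites': 'WEBSITE', 'endpoints': 'ENDPOINT'}  # stand-ins for DataSource.*

    def map_url_sources(url_sources: list[str]) -> list[str]:
        sources = [SOURCE_MAPPING[s] for s in url_sources if s in SOURCE_MAPPING]
        sources.append('DEFAULT')  # always keep the subdomain-based fallback last
        return sources

    # map_url_sources(['websites', 'bogus']) -> ['WEBSITE', 'DEFAULT']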
def _collect_urls_from_provider(provider: TargetProvider) -> tuple[list[str], str, list[str]]:
"""从 Provider 收集 URL"""
logger.info("使用 Provider 模式获取 URL - Provider: %s", type(provider).__name__)
urls = list(provider.iter_urls())
blacklist_filter = provider.get_blacklist_filter()
if blacklist_filter:
urls = [url for url in urls if blacklist_filter.is_allowed(url)]
return urls, 'provider', ['provider']
def _collect_urls_from_database(
target_id: int,
url_sources: list[str]
) -> tuple[list[str], str, list[str]]:
"""从数据库收集 URL带黑名单过滤和回退"""
data_sources = _map_url_sources_to_data_sources(url_sources)
result = get_urls_with_fallback(target_id, sources=data_sources)
return result['urls'], result['source'], result['tried_sources']
def _build_empty_result(scan_id: int, target_name: str) -> dict:
"""构建空结果"""
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'total_urls': 0,
'successful': 0,
'failed': 0,
'synced': 0
}
@flow( @flow(
name="screenshot", name="screenshot",
log_prints=True, log_prints=True,
@@ -104,12 +115,6 @@ def screenshot_flow(
1. 传统模式(向后兼容):使用 target_id 从数据库获取 URL 1. 传统模式(向后兼容):使用 target_id 从数据库获取 URL
2. Provider 模式:使用 TargetProvider 从任意数据源获取 URL 2. Provider 模式:使用 TargetProvider 从任意数据源获取 URL
工作流程:
Step 1: 解析配置
Step 2: 收集 URL 列表
Step 3: 批量截图并保存快照
Step 4: 同步到资产表
Args: Args:
scan_id: 扫描任务 ID scan_id: 扫描任务 ID
target_name: 目标名称 target_name: 目标名称
@@ -119,57 +124,31 @@ def screenshot_flow(
provider: TargetProvider 实例(新模式,可选) provider: TargetProvider 实例(新模式,可选)
Returns: Returns:
dict: { 截图结果字典
'success': bool,
'scan_id': int,
'target': str,
'total_urls': int,
'successful': int,
'failed': int,
'synced': int
}
""" """
try: try:
logger.info( # 负载检查:等待系统资源充足
"="*60 + "\n" + wait_for_system_load(context="screenshot_flow")
"开始截图扫描\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
f" Mode: {'Provider' if provider else 'Legacy'}\n" +
"="*60
)
mode = 'Provider' if provider else 'Legacy'
logger.info(
"开始截图扫描 - Scan ID: %s, Target: %s, Mode: %s",
scan_id, target_name, mode
)
user_log(scan_id, "screenshot", "Starting screenshot capture") user_log(scan_id, "screenshot", "Starting screenshot capture")
# Step 1: 解析配置 # Step 1: 解析配置
config = _parse_screenshot_config(enabled_tools) config = _parse_screenshot_config(enabled_tools)
concurrency = config['concurrency'] concurrency = config['concurrency']
url_sources = config['url_sources'] logger.info("截图配置 - 并发: %d, URL来源: %s", concurrency, config['url_sources'])
logger.info("截图配置 - 并发: %d, URL来源: %s", concurrency, url_sources)
# Step 2: 收集 URL 列表 # Step 2: 收集 URL 列表
if provider is not None: if provider is not None:
# Provider 模式:使用 TargetProvider 获取 URL urls, source_info, tried_sources = _collect_urls_from_provider(provider)
logger.info("使用 Provider 模式获取 URL - Provider: %s", type(provider).__name__)
urls = list(provider.iter_urls())
# 应用黑名单过滤(如果有)
blacklist_filter = provider.get_blacklist_filter()
if blacklist_filter:
urls = [url for url in urls if blacklist_filter.is_allowed(url)]
source_info = 'provider'
tried_sources = ['provider']
else: else:
# 传统模式:使用统一服务收集 URL带黑名单过滤和回退 urls, source_info, tried_sources = _collect_urls_from_database(
data_sources = _map_url_sources_to_data_sources(url_sources) target_id, config['url_sources']
result = get_urls_with_fallback(target_id, sources=data_sources) )
urls = result['urls']
source_info = result['source']
tried_sources = result['tried_sources']
logger.info( logger.info(
"URL 收集完成 - 来源: %s, 数量: %d, 尝试过: %s", "URL 收集完成 - 来源: %s, 数量: %d, 尝试过: %s",
@@ -179,21 +158,15 @@ def screenshot_flow(
if not urls: if not urls:
logger.warning("没有可截图的 URL跳过截图任务") logger.warning("没有可截图的 URL跳过截图任务")
user_log(scan_id, "screenshot", "Skipped: no URLs to capture", "warning") user_log(scan_id, "screenshot", "Skipped: no URLs to capture", "warning")
return { return _build_empty_result(scan_id, target_name)
'success': True,
'scan_id': scan_id,
'target': target_name,
'total_urls': 0,
'successful': 0,
'failed': 0,
'synced': 0
}
user_log(scan_id, "screenshot", f"Found {len(urls)} URLs to capture (source: {source_info})") user_log(
scan_id, "screenshot",
f"Found {len(urls)} URLs to capture (source: {source_info})"
)
# Step 3: 批量截图 # Step 3: 批量截图
logger.info("Step 3: 批量截图 - %d 个 URL", len(urls)) logger.info("批量截图 - %d 个 URL", len(urls))
capture_result = capture_screenshots_task( capture_result = capture_screenshots_task(
urls=urls, urls=urls,
scan_id=scan_id, scan_id=scan_id,
@@ -202,31 +175,34 @@ def screenshot_flow(
) )
# Step 4: 同步到资产表 # Step 4: 同步到资产表
logger.info("Step 4: 同步截图到资产表") logger.info("同步截图到资产表")
from apps.asset.services.screenshot_service import ScreenshotService from apps.asset.services.screenshot_service import ScreenshotService
screenshot_service = ScreenshotService() synced = ScreenshotService().sync_screenshots_to_asset(scan_id, target_id)
synced = screenshot_service.sync_screenshots_to_asset(scan_id, target_id)
total = capture_result['total']
successful = capture_result['successful']
failed = capture_result['failed']
logger.info( logger.info(
"✓ 截图完成 - 总数: %d, 成功: %d, 失败: %d, 同步: %d", "✓ 截图完成 - 总数: %d, 成功: %d, 失败: %d, 同步: %d",
capture_result['total'], capture_result['successful'], capture_result['failed'], synced total, successful, failed, synced
) )
user_log( user_log(
scan_id, "screenshot", scan_id, "screenshot",
f"Screenshot completed: {capture_result['successful']}/{capture_result['total']} captured, {synced} synced" f"Screenshot completed: {successful}/{total} captured, {synced} synced"
) )
return { return {
'success': True, 'success': True,
'scan_id': scan_id, 'scan_id': scan_id,
'target': target_name, 'target': target_name,
'total_urls': capture_result['total'], 'total_urls': total,
'successful': capture_result['successful'], 'successful': successful,
'failed': capture_result['failed'], 'failed': failed,
'synced': synced 'synced': synced
} }
except Exception as e: except Exception:
logger.exception("截图 Flow 失败: %s", e) logger.exception("截图 Flow 失败")
user_log(scan_id, "screenshot", f"Screenshot failed: {e}", "error") user_log(scan_id, "screenshot", "Screenshot failed", "error")
raise raise


@@ -1,4 +1,3 @@
""" """
站点扫描 Flow 站点扫描 Flow
@@ -11,23 +10,22 @@
- 配置由 YAML 解析 - 配置由 YAML 解析
""" """
# Django 环境初始化(导入即生效) from datetime import datetime
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging import logging
import os
import subprocess import subprocess
import time
from pathlib import Path from pathlib import Path
from typing import Callable
from prefect import flow from prefect import flow
from apps.scan.tasks.site_scan import export_site_urls_task, run_and_stream_save_websites_task
# Django 环境初始化导入即生效pylint: disable=unused-import
from apps.common.prefect_django_setup import setup_django_for_prefect # noqa: F401
from apps.scan.handlers.scan_flow_handlers import ( from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed, on_scan_flow_completed,
on_scan_flow_failed, on_scan_flow_failed,
on_scan_flow_running,
) )
from apps.scan.utils import config_parser, build_scan_command, user_log from apps.scan.tasks.site_scan import export_site_urls_task, run_and_stream_save_websites_task
from apps.scan.utils import build_scan_command, user_log, wait_for_system_load
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -191,7 +189,6 @@ def _run_scans_sequentially(
timeout = max(dynamic_timeout, config_timeout) timeout = max(dynamic_timeout, config_timeout)
# 2.1 生成日志文件路径(类似端口扫描) # 2.1 生成日志文件路径(类似端口扫描)
from datetime import datetime
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = site_scan_dir / f"{tool_name}_{timestamp}.log" log_file = site_scan_dir / f"{tool_name}_{timestamp}.log"
@@ -233,7 +230,7 @@ def _run_scans_sequentially(
) )
user_log(scan_id, "site_scan", f"{tool_name} completed: found {tool_created} websites") user_log(scan_id, "site_scan", f"{tool_name} completed: found {tool_created} websites")
except subprocess.TimeoutExpired as exc: except subprocess.TimeoutExpired:
# 超时异常单独处理 # 超时异常单独处理
reason = f"timeout after {timeout}s" reason = f"timeout after {timeout}s"
failed_tools.append({'tool': tool_name, 'reason': reason}) failed_tools.append({'tool': tool_name, 'reason': reason})
@@ -368,6 +365,9 @@ def site_scan_flow(
RuntimeError: 执行失败 RuntimeError: 执行失败
""" """
try: try:
# 负载检查:等待系统资源充足
wait_for_system_load(context="site_scan_flow")
logger.info( logger.info(
"="*60 + "\n" + "="*60 + "\n" +
"开始站点扫描\n" + "开始站点扫描\n" +

File diff suppressed because it is too large


@@ -10,22 +10,18 @@ URL Fetch 主 Flow
- 统一进行 httpx 验证(如果启用) - 统一进行 httpx 验证(如果启用)
""" """
# Django 环境初始化
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging import logging
import os
from pathlib import Path
from datetime import datetime from datetime import datetime
from pathlib import Path
from prefect import flow from prefect import flow
from apps.scan.handlers.scan_flow_handlers import ( from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed, on_scan_flow_completed,
on_scan_flow_failed, on_scan_flow_failed,
on_scan_flow_running,
) )
from apps.scan.utils import user_log from apps.scan.utils import user_log, wait_for_system_load
from .domain_name_url_fetch_flow import domain_name_url_fetch_flow from .domain_name_url_fetch_flow import domain_name_url_fetch_flow
from .sites_url_fetch_flow import sites_url_fetch_flow from .sites_url_fetch_flow import sites_url_fetch_flow
@@ -43,9 +39,6 @@ SITES_FILE_TOOLS = {'katana'}
POST_PROCESS_TOOLS = {'uro', 'httpx'} POST_PROCESS_TOOLS = {'uro', 'httpx'}
def _classify_tools(enabled_tools: dict) -> tuple[dict, dict, dict, dict]: def _classify_tools(enabled_tools: dict) -> tuple[dict, dict, dict, dict]:
""" """
将启用的工具按输入类型分类 将启用的工具按输入类型分类
@@ -85,7 +78,7 @@ def _merge_and_deduplicate_urls(result_files: list, url_fetch_dir: Path) -> tupl
# 统计唯一 URL 数量 # 统计唯一 URL 数量
unique_url_count = 0 unique_url_count = 0
if Path(merged_file).exists(): if Path(merged_file).exists():
with open(merged_file, 'r') as f: with open(merged_file, 'r', encoding='utf-8') as f:
unique_url_count = sum(1 for line in f if line.strip()) unique_url_count = sum(1 for line in f if line.strip())
logger.info( logger.info(
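_merge_and_deduplicate_urls 的合并去重思路可以用下面的独立示意代码理解(仅为示意,假设各工具结果文件为每行一个 URL 的文本文件,函数名为示例自拟):

from pathlib import Path

def merge_and_deduplicate(result_files: list, output_file: Path) -> int:
    """示意:合并多个 URL 结果文件并去重,返回唯一 URL 数量"""
    seen = set()
    with open(output_file, 'w', encoding='utf-8') as out:
        for file_path in result_files:
            path = Path(file_path)
            if not path.exists():
                continue
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    url = line.strip()
                    if url and url not in seen:
                        seen.add(url)
                        out.write(url + '\n')
    return len(seen)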
@@ -136,7 +129,7 @@ def _clean_urls_with_uro(
if result['success']: if result['success']:
return result['output_file'], result['output_count'], result['removed_count'] return result['output_file'], result['output_count'], result['removed_count']
else:
logger.warning("uro 清理失败: %s,使用原始合并文件", result.get('error', '未知错误')) logger.warning("uro 清理失败: %s,使用原始合并文件", result.get('error', '未知错误'))
return merged_file, result['input_count'], 0 return merged_file, result['input_count'], 0
@@ -156,10 +149,10 @@ def _validate_and_stream_save_urls(
# 统计待验证的 URL 数量 # 统计待验证的 URL 数量
try: try:
with open(merged_file, 'r') as f: with open(merged_file, 'r', encoding='utf-8') as f:
url_count = sum(1 for _ in f) url_count = sum(1 for _ in f)
logger.info("待验证 URL 数量: %d", url_count) logger.info("待验证 URL 数量: %d", url_count)
except Exception as e: except OSError as e:
logger.error("读取 URL 文件失败: %s", e) logger.error("读取 URL 文件失败: %s", e)
return 0 return 0
@@ -177,21 +170,19 @@ def _validate_and_stream_save_urls(
command_params=command_params, command_params=command_params,
tool_config=httpx_config tool_config=httpx_config
) )
except Exception as e: except (ValueError, KeyError) as e:
logger.error("构建 httpx 命令失败: %s", e) logger.error("构建 httpx 命令失败: %s", e)
logger.warning("降级处理:将直接保存所有 URL不验证存活") logger.warning("降级处理:将直接保存所有 URL不验证存活")
return _save_urls_to_database(merged_file, scan_id, target_id) return _save_urls_to_database(merged_file, scan_id, target_id)
# 计算超时时间 # 计算超时时间
raw_timeout = httpx_config.get('timeout', 'auto') raw_timeout = httpx_config.get('timeout', 'auto')
timeout = 3600
if isinstance(raw_timeout, str) and raw_timeout == 'auto': if isinstance(raw_timeout, str) and raw_timeout == 'auto':
# 按 URL 行数计算超时时间:每行 3 秒,最小 60 秒 # 按 URL 行数计算超时时间:每行 3 秒,最小 60 秒
timeout = max(60, url_count * 3) timeout = max(60, url_count * 3)
logger.info( logger.info(
"自动计算 httpx 超时时间(按行数,每行 3 秒,最小 60 秒): url_count=%d, timeout=%d", "自动计算 httpx 超时时间(按行数,每行 3 秒,最小 60 秒): url_count=%d, timeout=%d",
url_count, url_count, timeout
timeout,
) )
else: else:
try: try:
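raw_timeout 的解析逻辑auto 按每行 3 秒、最小 60 秒计算,显式配置走 int() 转换)可以用下面的小示例验证;示意代码,转换失败时的回退值 3600 为假设:

def resolve_httpx_timeout(raw_timeout, url_count: int) -> int:
    """示意:解析 httpx 超时配置"""
    if isinstance(raw_timeout, str) and raw_timeout == 'auto':
        # 按 URL 行数计算:每行 3 秒,最小 60 秒
        return max(60, url_count * 3)
    try:
        return int(raw_timeout)
    except (TypeError, ValueError):
        return 3600  # 回退默认值(假设)

assert resolve_httpx_timeout('auto', 10) == 60       # 10 × 3 = 30不足下限取 60
assert resolve_httpx_timeout('auto', 500) == 1500    # 500 × 3 = 1500
assert resolve_httpx_timeout('1200', 500) == 1200    # 显式配置直接转换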
@@ -205,7 +196,6 @@ def _validate_and_stream_save_urls(
log_file = url_fetch_dir / f"httpx_validation_{timestamp}.log" log_file = url_fetch_dir / f"httpx_validation_{timestamp}.log"
# 流式执行 # 流式执行
try:
result = run_and_stream_save_urls_task( result = run_and_stream_save_urls_task(
cmd=command, cmd=command,
tool_name='httpx', tool_name='httpx',
@@ -224,10 +214,6 @@ def _validate_and_stream_save_urls(
) )
return saved return saved
except Exception as e:
logger.error("httpx 流式验证失败: %s", e, exc_info=True)
raise
def _save_urls_to_database(merged_file: str, scan_id: int, target_id: int) -> int: def _save_urls_to_database(merged_file: str, scan_id: int, target_id: int) -> int:
"""保存 URL 到数据库(不验证存活)""" """保存 URL 到数据库(不验证存活)"""
@@ -283,24 +269,20 @@ def url_fetch_flow(
dict: 扫描结果 dict: 扫描结果
""" """
try: try:
logger.info( # 负载检查:等待系统资源充足
"="*60 + "\n" + wait_for_system_load(context="url_fetch_flow")
"开始 URL 获取扫描\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
)
logger.info(
"开始 URL 获取扫描 - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
)
user_log(scan_id, "url_fetch", "Starting URL fetch") user_log(scan_id, "url_fetch", "Starting URL fetch")
# Step 1: 准备工作目录 # Step 1: 准备工作目录
logger.info("Step 1: 准备工作目录")
from apps.scan.utils import setup_scan_directory from apps.scan.utils import setup_scan_directory
url_fetch_dir = setup_scan_directory(scan_workspace_dir, 'url_fetch') url_fetch_dir = setup_scan_directory(scan_workspace_dir, 'url_fetch')
# Step 2: 分类工具(按输入类型) # Step 2: 分类工具(按输入类型)
logger.info("Step 2: 分类工具")
domain_name_tools, sites_file_tools, uro_config, httpx_config = _classify_tools(enabled_tools) domain_name_tools, sites_file_tools, uro_config, httpx_config = _classify_tools(enabled_tools)
logger.info( logger.info(
@@ -318,14 +300,13 @@ def url_fetch_flow(
"httpx 和 uro 仅用于后处理,不能单独使用。" "httpx 和 uro 仅用于后处理,不能单独使用。"
) )
# Step 3: 并行执行子 Flow # Step 3: 执行子 Flow
all_result_files = [] all_result_files = []
all_failed_tools = [] all_failed_tools = []
all_successful_tools = [] all_successful_tools = []
# 3a: 基于 domain_nametarget_name 的 URL 被动收集(如 waymore # 3a: 基于 domain_name 的 URL 被动收集(如 waymore
if domain_name_tools: if domain_name_tools:
logger.info("Step 3a: 执行基于 domain_name 的 URL 被动收集子 Flow")
tn_result = domain_name_url_fetch_flow( tn_result = domain_name_url_fetch_flow(
scan_id=scan_id, scan_id=scan_id,
target_id=target_id, target_id=target_id,
@@ -339,7 +320,6 @@ def url_fetch_flow(
# 3b: 爬虫(以 sites_file 为输入) # 3b: 爬虫(以 sites_file 为输入)
if sites_file_tools: if sites_file_tools:
logger.info("Step 3b: 执行爬虫子 Flow")
crawl_result = sites_url_fetch_flow( crawl_result = sites_url_fetch_flow(
scan_id=scan_id, scan_id=scan_id,
target_id=target_id, target_id=target_id,
@@ -353,9 +333,10 @@ def url_fetch_flow(
# 检查是否有成功的工具 # 检查是否有成功的工具
if not all_result_files: if not all_result_files:
error_details = "; ".join([f"{f['tool']}: {f['reason']}" for f in all_failed_tools]) error_details = "; ".join([
"%s: %s" % (f['tool'], f['reason']) for f in all_failed_tools
])
logger.warning("所有 URL 获取工具均失败 - 目标: %s, 失败详情: %s", target_name, error_details) logger.warning("所有 URL 获取工具均失败 - 目标: %s, 失败详情: %s", target_name, error_details)
# 返回空结果,不抛出异常,让扫描继续
return { return {
'success': True, 'success': True,
'scan_id': scan_id, 'scan_id': scan_id,
@@ -368,29 +349,22 @@ def url_fetch_flow(
} }
# Step 4: 合并并去重 URL # Step 4: 合并并去重 URL
logger.info("Step 4: 合并并去重 URL") merged_file, _ = _merge_and_deduplicate_urls(
merged_file, unique_url_count = _merge_and_deduplicate_urls(
result_files=all_result_files, result_files=all_result_files,
url_fetch_dir=url_fetch_dir url_fetch_dir=url_fetch_dir
) )
# Step 5: 使用 uro 清理 URL如果启用 # Step 5: 使用 uro 清理 URL如果启用
url_file_for_validation = merged_file url_file_for_validation = merged_file
uro_removed_count = 0
if uro_config and uro_config.get('enabled', False): if uro_config and uro_config.get('enabled', False):
logger.info("Step 5: 使用 uro 清理 URL") url_file_for_validation, _, _ = _clean_urls_with_uro(
url_file_for_validation, cleaned_count, uro_removed_count = _clean_urls_with_uro(
merged_file=merged_file, merged_file=merged_file,
uro_config=uro_config, uro_config=uro_config,
url_fetch_dir=url_fetch_dir url_fetch_dir=url_fetch_dir
) )
else:
logger.info("Step 5: 跳过 uro 清理(未启用)")
# Step 6: 使用 httpx 验证存活并保存(如果启用) # Step 6: 使用 httpx 验证存活并保存(如果启用)
if httpx_config and httpx_config.get('enabled', False): if httpx_config and httpx_config.get('enabled', False):
logger.info("Step 6: 使用 httpx 验证 URL 存活并流式保存")
saved_count = _validate_and_stream_save_urls( saved_count = _validate_and_stream_save_urls(
merged_file=url_file_for_validation, merged_file=url_file_for_validation,
httpx_config=httpx_config, httpx_config=httpx_config,
@@ -399,7 +373,6 @@ def url_fetch_flow(
target_id=target_id target_id=target_id
) )
else: else:
logger.info("Step 6: 保存到数据库(未启用 httpx 验证)")
saved_count = _save_urls_to_database( saved_count = _save_urls_to_database(
merged_file=url_file_for_validation, merged_file=url_file_for_validation,
scan_id=scan_id, scan_id=scan_id,
@@ -408,7 +381,7 @@ def url_fetch_flow(
# 记录 Flow 完成 # 记录 Flow 完成
logger.info("✓ URL 获取完成 - 保存 endpoints: %d", saved_count) logger.info("✓ URL 获取完成 - 保存 endpoints: %d", saved_count)
user_log(scan_id, "url_fetch", f"url_fetch completed: found {saved_count} endpoints") user_log(scan_id, "url_fetch", "url_fetch completed: found %d endpoints" % saved_count)
# 构建已执行的任务列表 # 构建已执行的任务列表
executed_tasks = ['setup_directory', 'classify_tools'] executed_tasks = ['setup_directory', 'classify_tools']

View File

@@ -1,5 +1,6 @@
from apps.common.prefect_django_setup import setup_django_for_prefect """
漏洞扫描主 Flow
"""
import logging import logging
from typing import Dict, Tuple from typing import Dict, Tuple
@@ -11,7 +12,7 @@ from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_failed, on_scan_flow_failed,
) )
from apps.scan.configs.command_templates import get_command_template from apps.scan.configs.command_templates import get_command_template
from apps.scan.utils import user_log from apps.scan.utils import user_log, wait_for_system_load
from .endpoints_vuln_scan_flow import endpoints_vuln_scan_flow from .endpoints_vuln_scan_flow import endpoints_vuln_scan_flow
@@ -62,6 +63,9 @@ def vuln_scan_flow(
- nuclei: 通用漏洞扫描(流式保存,支持模板 commit hash 同步) - nuclei: 通用漏洞扫描(流式保存,支持模板 commit hash 同步)
""" """
try: try:
# 负载检查:等待系统资源充足
wait_for_system_load(context="vuln_scan_flow")
if scan_id is None: if scan_id is None:
raise ValueError("scan_id 不能为空") raise ValueError("scan_id 不能为空")
if not target_name: if not target_name:

View File

@@ -4,11 +4,11 @@
定义 ProviderContext 数据类和 TargetProvider 抽象基类。 定义 ProviderContext 数据类和 TargetProvider 抽象基类。
""" """
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Iterator, Optional, TYPE_CHECKING
import ipaddress import ipaddress
import logging import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Iterator, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from apps.common.utils import BlacklistFilter from apps.common.utils import BlacklistFilter
@@ -46,12 +46,6 @@ class TargetProvider(ABC):
""" """
def __init__(self, context: Optional[ProviderContext] = None): def __init__(self, context: Optional[ProviderContext] = None):
"""
初始化 Provider
Args:
context: Provider 上下文None 时创建默认上下文
"""
self._context = context or ProviderContext() self._context = context or ProviderContext()
@property @property
@@ -64,19 +58,10 @@ class TargetProvider(ABC):
""" """
展开主机(如果是 CIDR 则展开为多个 IP否则直接返回 展开主机(如果是 CIDR 则展开为多个 IP否则直接返回
这是一个内部方法,由 iter_hosts() 自动调用。
Args:
host: 主机字符串IP/域名/CIDR
Yields:
str: 单个主机IP 或域名)
示例: 示例:
"192.168.1.0/30""192.168.1.1", "192.168.1.2" "192.168.1.0/30""192.168.1.1", "192.168.1.2"
"192.168.1.1""192.168.1.1" "192.168.1.1""192.168.1.1"
"example.com""example.com" "example.com""example.com"
"invalid" → (跳过,不返回)
""" """
from apps.common.validators import detect_target_type from apps.common.validators import detect_target_type
from apps.targets.models import Target from apps.targets.models import Target
@@ -85,70 +70,38 @@ class TargetProvider(ABC):
if not host: if not host:
return return
# 统一使用 detect_target_type 检测类型
try: try:
target_type = detect_target_type(host) target_type = detect_target_type(host)
if target_type == Target.TargetType.CIDR: if target_type == Target.TargetType.CIDR:
# 展开 CIDR
network = ipaddress.ip_network(host, strict=False) network = ipaddress.ip_network(host, strict=False)
if network.num_addresses == 1: if network.num_addresses == 1:
yield str(network.network_address) yield str(network.network_address)
else: else:
for ip in network.hosts(): yield from (str(ip) for ip in network.hosts())
yield str(ip) elif target_type in (Target.TargetType.IP, Target.TargetType.DOMAIN):
elif target_type == Target.TargetType.IP:
# 单个 IP
yield host
elif target_type == Target.TargetType.DOMAIN:
# 域名
yield host yield host
except ValueError as e: except ValueError as e:
# 无效格式,跳过并记录警告
logger.warning("跳过无效的主机格式 '%s': %s", host, str(e)) logger.warning("跳过无效的主机格式 '%s': %s", host, str(e))
def iter_hosts(self) -> Iterator[str]: def iter_hosts(self) -> Iterator[str]:
""" """迭代主机列表(域名/IP自动展开 CIDR"""
迭代主机列表(域名/IP
自动展开 CIDR子类无需关心。
Yields:
str: 主机名或 IP 地址(单个,不包含 CIDR
"""
for host in self._iter_raw_hosts(): for host in self._iter_raw_hosts():
yield from self._expand_host(host) yield from self._expand_host(host)
@abstractmethod @abstractmethod
def _iter_raw_hosts(self) -> Iterator[str]: def _iter_raw_hosts(self) -> Iterator[str]:
""" """迭代原始主机列表(可能包含 CIDR子类实现"""
迭代原始主机列表(可能包含 CIDR
子类实现此方法,返回原始数据即可,不需要处理 CIDR 展开。
Yields:
str: 主机名、IP 地址或 CIDR
"""
pass pass
@abstractmethod @abstractmethod
def iter_urls(self) -> Iterator[str]: def iter_urls(self) -> Iterator[str]:
""" """迭代 URL 列表"""
迭代 URL 列表
Yields:
str: URL 字符串
"""
pass pass
@abstractmethod @abstractmethod
def get_blacklist_filter(self) -> Optional['BlacklistFilter']: def get_blacklist_filter(self) -> Optional['BlacklistFilter']:
""" """获取黑名单过滤器,返回 None 表示不过滤"""
获取黑名单过滤器
Returns:
BlacklistFilter: 黑名单过滤器实例,或 None不过滤
"""
pass pass
@property @property
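上面 _expand_host 对 CIDR 的展开行为可以用标准库 ipaddress 单独验证(示意片段,不依赖 Django 与 detect_target_type仅演示展开规则

import ipaddress

def expand_cidr(host: str) -> list:
    """示意:将 CIDR 展开为主机 IP 列表,非 IP/CIDR 输入原样返回"""
    try:
        network = ipaddress.ip_network(host, strict=False)
    except ValueError:
        return [host]  # 域名等非 IP/CIDR 输入:原样返回(真实实现由 detect_target_type 区分类型)
    if network.num_addresses == 1:
        return [str(network.network_address)]
    return [str(ip) for ip in network.hosts()]

print(expand_cidr("192.168.1.0/30"))   # ['192.168.1.1', '192.168.1.2']
print(expand_cidr("192.168.1.1"))      # ['192.168.1.1']
print(expand_cidr("example.com"))      # ['example.com']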

View File

@@ -5,9 +5,9 @@
""" """
import logging import logging
from typing import Iterator, Optional, TYPE_CHECKING from typing import TYPE_CHECKING, Iterator, Optional
from .base import TargetProvider, ProviderContext from .base import ProviderContext, TargetProvider
if TYPE_CHECKING: if TYPE_CHECKING:
from apps.common.utils import BlacklistFilter from apps.common.utils import BlacklistFilter
@@ -19,13 +19,8 @@ class DatabaseTargetProvider(TargetProvider):
""" """
数据库目标提供者 - 从 Target 表及关联资产表查询 数据库目标提供者 - 从 Target 表及关联资产表查询
这是现有行为的封装,保持向后兼容。
数据来源: 数据来源:
- iter_hosts(): 根据 Target 类型返回域名/IP - iter_hosts(): 根据 Target 类型返回域名/IP
- DOMAIN: 根域名 + Subdomain 表
- IP: 直接返回 IP
- CIDR: 使用 _expand_host() 展开为所有主机 IP
- iter_urls(): WebSite/Endpoint 表,带回退链 - iter_urls(): WebSite/Endpoint 表,带回退链
使用方式: 使用方式:
@@ -35,47 +30,25 @@ class DatabaseTargetProvider(TargetProvider):
""" """
def __init__(self, target_id: int, context: Optional[ProviderContext] = None): def __init__(self, target_id: int, context: Optional[ProviderContext] = None):
"""
初始化数据库目标提供者
Args:
target_id: 目标 ID必需
context: Provider 上下文
"""
ctx = context or ProviderContext() ctx = context or ProviderContext()
ctx.target_id = target_id ctx.target_id = target_id
super().__init__(ctx) super().__init__(ctx)
self._blacklist_filter: Optional['BlacklistFilter'] = None # 延迟加载 self._blacklist_filter: Optional['BlacklistFilter'] = None
def iter_hosts(self) -> Iterator[str]: def iter_hosts(self) -> Iterator[str]:
""" """从数据库查询主机列表,自动展开 CIDR 并应用黑名单过滤"""
从数据库查询主机列表,自动展开 CIDR 并应用黑名单过滤
重写基类方法以支持黑名单过滤(需要在 CIDR 展开后过滤)
"""
blacklist = self.get_blacklist_filter() blacklist = self.get_blacklist_filter()
for host in self._iter_raw_hosts(): for host in self._iter_raw_hosts():
# 展开 CIDR
for expanded_host in self._expand_host(host): for expanded_host in self._expand_host(host):
# 应用黑名单过滤
if not blacklist or blacklist.is_allowed(expanded_host): if not blacklist or blacklist.is_allowed(expanded_host):
yield expanded_host yield expanded_host
def _iter_raw_hosts(self) -> Iterator[str]: def _iter_raw_hosts(self) -> Iterator[str]:
""" """从数据库查询原始主机列表(可能包含 CIDR"""
从数据库查询原始主机列表(可能包含 CIDR
根据 Target 类型决定数据来源:
- DOMAIN: 根域名 + Subdomain 表
- IP: 直接返回 target.name
- CIDR: 返回 CIDR 字符串(由 iter_hosts() 展开)
注意:此方法不应用黑名单过滤,过滤在 iter_hosts() 中进行
"""
from apps.targets.services import TargetService
from apps.targets.models import Target
from apps.asset.services.asset.subdomain_service import SubdomainService from apps.asset.services.asset.subdomain_service import SubdomainService
from apps.targets.models import Target
from apps.targets.services import TargetService
target = TargetService().get_target(self.target_id) target = TargetService().get_target(self.target_id)
if not target: if not target:
@@ -83,38 +56,27 @@ class DatabaseTargetProvider(TargetProvider):
return return
if target.type == Target.TargetType.DOMAIN: if target.type == Target.TargetType.DOMAIN:
# 返回根域名
yield target.name yield target.name
for domain in SubdomainService().iter_subdomain_names_by_target(
# 返回子域名
subdomain_service = SubdomainService()
for domain in subdomain_service.iter_subdomain_names_by_target(
target_id=self.target_id, target_id=self.target_id,
chunk_size=1000 chunk_size=1000
): ):
if domain != target.name: # 避免重复 if domain != target.name:
yield domain yield domain
elif target.type == Target.TargetType.IP: elif target.type in (Target.TargetType.IP, Target.TargetType.CIDR):
yield target.name
elif target.type == Target.TargetType.CIDR:
# 直接返回 CIDR由 iter_hosts() 展开并过滤
yield target.name yield target.name
def iter_urls(self) -> Iterator[str]: def iter_urls(self) -> Iterator[str]:
""" """从数据库查询 URL 列表使用回退链Endpoint → WebSite → Default"""
从数据库查询 URL 列表
使用现有的回退链逻辑Endpoint → WebSite → Default
"""
from apps.scan.services.target_export_service import ( from apps.scan.services.target_export_service import (
_iter_urls_with_fallback, DataSource DataSource,
_iter_urls_with_fallback,
) )
blacklist = self.get_blacklist_filter() blacklist = self.get_blacklist_filter()
for url, source in _iter_urls_with_fallback( for url, _ in _iter_urls_with_fallback(
target_id=self.target_id, target_id=self.target_id,
sources=[DataSource.ENDPOINT, DataSource.WEBSITE, DataSource.DEFAULT], sources=[DataSource.ENDPOINT, DataSource.WEBSITE, DataSource.DEFAULT],
blacklist_filter=blacklist blacklist_filter=blacklist

View File

@@ -12,7 +12,7 @@
import ipaddress import ipaddress
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Dict, Any, Optional, List, Iterator, Tuple, Callable from typing import Dict, Any, Optional, List, Iterator, Tuple
from django.db.models import QuerySet from django.db.models import QuerySet
@@ -485,7 +485,6 @@ class TargetExportService:
""" """
from apps.targets.services import TargetService from apps.targets.services import TargetService
from apps.targets.models import Target from apps.targets.models import Target
from apps.asset.services.asset.subdomain_service import SubdomainService
output_file = Path(output_path) output_file = Path(output_path)
output_file.parent.mkdir(parents=True, exist_ok=True) output_file.parent.mkdir(parents=True, exist_ok=True)

View File

@@ -16,7 +16,7 @@ from apps.scan.services.target_export_service import (
export_urls_with_fallback, export_urls_with_fallback,
DataSource, DataSource,
) )
from apps.scan.providers import TargetProvider, DatabaseTargetProvider from apps.scan.providers import TargetProvider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -11,12 +11,12 @@
- CIDR: 展开 CIDR 范围内的所有 IP - CIDR: 展开 CIDR 范围内的所有 IP
""" """
import logging import logging
from typing import Optional
from pathlib import Path from pathlib import Path
from typing import Optional
from prefect import task from prefect import task
from apps.scan.services.target_export_service import create_export_service from apps.scan.providers import DatabaseTargetProvider, TargetProvider
from apps.scan.providers import TargetProvider, DatabaseTargetProvider, ProviderContext
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -26,7 +26,6 @@ def export_hosts_task(
output_file: str, output_file: str,
target_id: Optional[int] = None, target_id: Optional[int] = None,
provider: Optional[TargetProvider] = None, provider: Optional[TargetProvider] = None,
batch_size: int = 1000
) -> dict: ) -> dict:
""" """
导出主机列表到 TXT 文件 导出主机列表到 TXT 文件
@@ -44,7 +43,6 @@ def export_hosts_task(
output_file: 输出文件路径(绝对路径) output_file: 输出文件路径(绝对路径)
target_id: 目标 ID传统模式向后兼容 target_id: 目标 ID传统模式向后兼容
provider: TargetProvider 实例(新模式) provider: TargetProvider 实例(新模式)
batch_size: 每次读取的批次大小,默认 1000仅对 DOMAIN 类型有效)
Returns: Returns:
dict: { dict: {
@@ -58,33 +56,26 @@ def export_hosts_task(
ValueError: 参数错误target_id 和 provider 都未提供) ValueError: 参数错误target_id 和 provider 都未提供)
IOError: 文件写入失败 IOError: 文件写入失败
""" """
# 参数验证:至少提供一个
if target_id is None and provider is None: if target_id is None and provider is None:
raise ValueError("必须提供 target_id 或 provider 参数之一") raise ValueError("必须提供 target_id 或 provider 参数之一")
# 向后兼容:如果没有提供 provider使用 target_id 创建 DatabaseTargetProvider # 向后兼容:如果没有提供 provider使用 target_id 创建 DatabaseTargetProvider
if provider is None: use_legacy_mode = provider is None
if use_legacy_mode:
logger.info("使用传统模式 - Target ID: %d", target_id) logger.info("使用传统模式 - Target ID: %d", target_id)
provider = DatabaseTargetProvider(target_id=target_id) provider = DatabaseTargetProvider(target_id=target_id)
use_legacy_mode = True
else: else:
logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__) logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
use_legacy_mode = False
# 确保输出目录存在 # 确保输出目录存在
output_path = Path(output_file) output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True) output_path.parent.mkdir(parents=True, exist_ok=True)
# 使用 Provider 导出主机列表 # 使用 Provider 导出主机列表iter_hosts 内部已处理黑名单过滤)
total_count = 0 total_count = 0
blacklist_filter = provider.get_blacklist_filter()
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f: with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for host in provider.iter_hosts(): for host in provider.iter_hosts():
# 应用黑名单过滤(如果有)
if blacklist_filter and not blacklist_filter.is_allowed(host):
continue
f.write(f"{host}\n") f.write(f"{host}\n")
total_count += 1 total_count += 1
@@ -93,7 +84,6 @@ def export_hosts_task(
logger.info("✓ 主机列表导出完成 - 总数: %d, 文件: %s", total_count, str(output_path)) logger.info("✓ 主机列表导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))
# 构建返回值
result = { result = {
'success': True, 'success': True,
'output_file': str(output_path), 'output_file': str(output_path),
@@ -102,7 +92,6 @@ def export_hosts_task(
# 传统模式:保持返回值格式不变(向后兼容) # 传统模式:保持返回值格式不变(向后兼容)
if use_legacy_mode: if use_legacy_mode:
# 获取 target_type仅传统模式需要
from apps.targets.services import TargetService from apps.targets.services import TargetService
target = TargetService().get_target(target_id) target = TargetService().get_target(target_id)
result['target_type'] = target.type if target else 'unknown' result['target_type'] = target.type if target else 'unknown'
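export_hosts_task 的两种调用模式(传统 target_id 与 Provider 模式)示意如下;导入路径与具体取值均为假设,且调用需发生在 Prefect Flow 上下文中:

from apps.scan.providers import DatabaseTargetProvider
from apps.scan.tasks.export_tasks import export_hosts_task   # 导入路径为示意假设

# 传统模式:仅提供 target_id返回值额外包含 target_type向后兼容
result = export_hosts_task(
    output_file='/data/scan_1/hosts.txt',
    target_id=1,
)

# Provider 模式:显式传入 Provider 实例,返回值不含 target_type
result = export_hosts_task(
    output_file='/data/scan_1/hosts.txt',
    provider=DatabaseTargetProvider(target_id=1),
)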

View File

@@ -4,37 +4,40 @@
提供扫描相关的工具函数。 提供扫描相关的工具函数。
""" """
from .directory_cleanup import remove_directory from . import config_parser
from .command_builder import build_scan_command from .command_builder import build_scan_command
from .command_executor import execute_and_wait, execute_stream from .command_executor import execute_and_wait, execute_stream
from .wordlist_helpers import ensure_wordlist_local from .directory_cleanup import remove_directory
from .nuclei_helpers import ensure_nuclei_templates_local from .nuclei_helpers import ensure_nuclei_templates_local
from .performance import FlowPerformanceTracker, CommandPerformanceTracker from .performance import CommandPerformanceTracker, FlowPerformanceTracker
from .workspace_utils import setup_scan_workspace, setup_scan_directory from .system_load import check_system_load, wait_for_system_load
from .user_logger import user_log from .user_logger import user_log
from . import config_parser from .wordlist_helpers import ensure_wordlist_local
from .workspace_utils import setup_scan_directory, setup_scan_workspace
__all__ = [ __all__ = [
# 目录清理 # 目录清理
'remove_directory', 'remove_directory',
# 工作空间 # 工作空间
'setup_scan_workspace', # 创建 Scan 根工作空间 'setup_scan_workspace',
'setup_scan_directory', # 创建扫描子目录 'setup_scan_directory',
# 命令构建 # 命令构建
'build_scan_command', # 扫描工具命令构建(基于 f-string 'build_scan_command',
# 命令执行 # 命令执行
'execute_and_wait', # 等待式执行(文件输出) 'execute_and_wait',
'execute_stream', # 流式执行(实时处理) 'execute_stream',
# 系统负载
'wait_for_system_load',
'check_system_load',
# 字典文件 # 字典文件
'ensure_wordlist_local', # 确保本地字典文件(含 hash 校验) 'ensure_wordlist_local',
# Nuclei 模板 # Nuclei 模板
'ensure_nuclei_templates_local', # 确保本地模板(含 commit hash 校验) 'ensure_nuclei_templates_local',
# 性能监控 # 性能监控
'FlowPerformanceTracker', # Flow 性能追踪器(含系统资源采样) 'FlowPerformanceTracker',
'CommandPerformanceTracker', # 命令性能追踪器 'CommandPerformanceTracker',
# 扫描日志 # 扫描日志
'user_log', # 用户可见扫描日志记录 'user_log',
# 配置解析 # 配置解析
'config_parser', 'config_parser',
] ]

View File

@@ -12,16 +12,18 @@
import logging import logging
import os import os
from django.conf import settings
import re import re
import signal import signal
import subprocess import subprocess
import threading import threading
import time import time
from collections import deque
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Dict, Any, Optional, Generator from typing import Dict, Any, Optional, Generator
from django.conf import settings
try: try:
# 可选依赖:用于根据 CPU / 内存负载做动态并发控制 # 可选依赖:用于根据 CPU / 内存负载做动态并发控制
import psutil import psutil
@@ -669,7 +671,9 @@ class CommandExecutor:
def _read_log_tail(self, log_file: Path, max_lines: int = MAX_LOG_TAIL_LINES) -> str: def _read_log_tail(self, log_file: Path, max_lines: int = MAX_LOG_TAIL_LINES) -> str:
""" """
读取日志文件的末尾部分 读取日志文件的末尾部分(常量内存实现)
使用 seek 从文件末尾往前读取,避免将整个文件加载到内存。
Args: Args:
log_file: 日志文件路径 log_file: 日志文件路径
@@ -682,20 +686,53 @@ class CommandExecutor:
logger.debug("日志文件不存在: %s", log_file) logger.debug("日志文件不存在: %s", log_file)
return "" return ""
if log_file.stat().st_size == 0: file_size = log_file.stat().st_size
if file_size == 0:
logger.debug("日志文件为空: %s", log_file) logger.debug("日志文件为空: %s", log_file)
return "" return ""
# 每次读取的块大小8KB足够容纳大多数日志行
chunk_size = 8192
def decode_line(line_bytes: bytes) -> str:
"""解码单行:优先 UTF-8失败则降级 latin-1"""
try: try:
with open(log_file, 'r', encoding='utf-8') as f: return line_bytes.decode('utf-8')
lines = f.readlines() except UnicodeDecodeError:
return ''.join(lines[-max_lines:] if len(lines) > max_lines else lines) return line_bytes.decode('latin-1', errors='replace')
except UnicodeDecodeError as e:
logger.warning("日志文件编码错误 (%s): %s", log_file, e) try:
return f"(无法读取日志文件: 编码错误 - {e})" with open(log_file, 'rb') as f:
lines_found: deque[bytes] = deque()
remaining = b''
position = file_size
while position > 0 and len(lines_found) < max_lines:
read_size = min(chunk_size, position)
position -= read_size
f.seek(position)
chunk = f.read(read_size) + remaining
parts = chunk.split(b'\n')
# 最前面的部分可能不完整,留到下次处理
remaining = parts[0]
# 其余部分是完整的行(从后往前收集,用 appendleft 保持顺序)
for part in reversed(parts[1:]):
if len(lines_found) >= max_lines:
break
lines_found.appendleft(part)
# 处理文件开头的行
if remaining and len(lines_found) < max_lines:
lines_found.appendleft(remaining)
return '\n'.join(decode_line(line) for line in lines_found)
except PermissionError as e: except PermissionError as e:
logger.warning("日志文件权限不足 (%s): %s", log_file, e) logger.warning("日志文件权限不足 (%s): %s", log_file, e)
return f"(无法读取日志文件: 权限不足)" return "(无法读取日志文件: 权限不足)"
except IOError as e: except IOError as e:
logger.warning("日志文件读取IO错误 (%s): %s", log_file, e) logger.warning("日志文件读取IO错误 (%s): %s", log_file, e)
return f"(无法读取日志文件: IO错误 - {e})" return f"(无法读取日志文件: IO错误 - {e})"

View File

@@ -0,0 +1,77 @@
"""
系统负载检查工具
提供统一的系统负载检查功能,用于:
- Flow 入口处检查系统资源是否充足
- 防止在高负载时启动新的扫描任务
"""
import logging
import time
import psutil
from django.conf import settings
logger = logging.getLogger(__name__)
# 动态并发控制阈值(可在 Django settings 中覆盖)
SCAN_CPU_HIGH: float = getattr(settings, 'SCAN_CPU_HIGH', 90.0)
SCAN_MEM_HIGH: float = getattr(settings, 'SCAN_MEM_HIGH', 80.0)
SCAN_LOAD_CHECK_INTERVAL: int = getattr(settings, 'SCAN_LOAD_CHECK_INTERVAL', 180)
def _get_current_load() -> tuple[float, float]:
"""获取当前 CPU 和内存使用率"""
return psutil.cpu_percent(interval=0.5), psutil.virtual_memory().percent
def wait_for_system_load(
cpu_threshold: float = SCAN_CPU_HIGH,
mem_threshold: float = SCAN_MEM_HIGH,
check_interval: int = SCAN_LOAD_CHECK_INTERVAL,
context: str = "task"
) -> None:
"""
等待系统负载降到阈值以下
在高负载时阻塞等待,直到 CPU 和内存都低于阈值。
用于 Flow 入口处,防止在资源紧张时启动新任务。
"""
while True:
cpu, mem = _get_current_load()
if cpu < cpu_threshold and mem < mem_threshold:
logger.debug(
"[%s] 系统负载正常: cpu=%.1f%%, mem=%.1f%%",
context, cpu, mem
)
return
logger.info(
"[%s] 系统负载较高,等待资源释放: "
"cpu=%.1f%% (阈值 %.1f%%), mem=%.1f%% (阈值 %.1f%%)",
context, cpu, cpu_threshold, mem, mem_threshold
)
time.sleep(check_interval)
def check_system_load(
cpu_threshold: float = SCAN_CPU_HIGH,
mem_threshold: float = SCAN_MEM_HIGH
) -> dict:
"""
检查当前系统负载(非阻塞)
Returns:
dict: cpu_percent, mem_percent, cpu_threshold, mem_threshold, is_overloaded
"""
cpu, mem = _get_current_load()
return {
'cpu_percent': cpu,
'mem_percent': mem,
'cpu_threshold': cpu_threshold,
'mem_threshold': mem_threshold,
'is_overloaded': cpu >= cpu_threshold or mem >= mem_threshold,
}
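wait_for_system_load / check_system_load 的典型用法与阈值覆盖方式大致如下示意代码settings 中的覆盖取值为假设):

# Django settings 中可覆盖默认阈值(示意取值)
# SCAN_CPU_HIGH = 85.0
# SCAN_MEM_HIGH = 75.0
# SCAN_LOAD_CHECK_INTERVAL = 60

from apps.scan.utils import check_system_load, wait_for_system_load

# Flow 入口:高负载时阻塞等待,直到 CPU 和内存都低于阈值
wait_for_system_load(context="site_scan_flow")

# 非阻塞检查:只读取当前负载,由调用方决定如何处理
load = check_system_load()
if load['is_overloaded']:
    print(f"系统过载: cpu={load['cpu_percent']:.1f}%, mem={load['mem_percent']:.1f}%")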

View File

@@ -18,10 +18,28 @@ disable = [
"no-member", "no-member",
"import-error", "import-error",
"no-name-in-module", "no-name-in-module",
"wrong-import-position", # 允许函数内导入(防循环依赖)
"import-outside-toplevel", # 同上
"too-many-arguments", # Django 视图/服务方法参数常超过5个
"too-many-locals", # 复杂业务逻辑局部变量多
"duplicate-code", # 某些模式代码相似是正常的
] ]
[tool.pylint.format] [tool.pylint.format]
max-line-length = 120 max-line-length = 120
[tool.pylint.basic] [tool.pylint.basic]
good-names = ["i", "j", "k", "ex", "Run", "_", "id", "pk", "ip", "url", "db", "qs"] good-names = [
"i",
"j",
"k",
"ex",
"Run",
"_",
"id",
"pk",
"ip",
"url",
"db",
"qs",
]