diff --git a/backend/apps/scan/flows/directory_scan_flow.py b/backend/apps/scan/flows/directory_scan_flow.py index 631a5b28..ccb28cd1 100644 --- a/backend/apps/scan/flows/directory_scan_flow.py +++ b/backend/apps/scan/flows/directory_scan_flow.py @@ -10,30 +10,30 @@ - 配置由 YAML 解析 """ -# Django 环境初始化(导入即生效) -from apps.common.prefect_django_setup import setup_django_for_prefect - -from prefect import flow -from prefect.task_runners import ThreadPoolTaskRunner - import hashlib import logging -import os import subprocess from datetime import datetime from pathlib import Path from typing import List, Tuple -from apps.scan.tasks.directory_scan import ( - export_sites_task, - run_and_stream_save_directories_task -) +from prefect import flow + from apps.scan.handlers.scan_flow_handlers import ( - on_scan_flow_running, on_scan_flow_completed, on_scan_flow_failed, + on_scan_flow_running, +) +from apps.scan.tasks.directory_scan import ( + export_sites_task, + run_and_stream_save_directories_task, +) +from apps.scan.utils import ( + build_scan_command, + ensure_wordlist_local, + user_log, + wait_for_system_load, ) -from apps.scan.utils import config_parser, build_scan_command, ensure_wordlist_local, user_log logger = logging.getLogger(__name__) @@ -45,517 +45,343 @@ def calculate_directory_scan_timeout( tool_config: dict, base_per_word: float = 1.0, min_timeout: int = 60, - max_timeout: int = 7200 ) -> int: """ 根据字典行数计算目录扫描超时时间 - + 计算公式:超时时间 = 字典行数 × 每个单词基础时间 - 超时范围:60秒 ~ 2小时(7200秒) - + 超时范围:最小 60 秒,无上限 + Args: tool_config: 工具配置字典,包含 wordlist 路径 base_per_word: 每个单词的基础时间(秒),默认 1.0秒 min_timeout: 最小超时时间(秒),默认 60秒 - max_timeout: 最大超时时间(秒),默认 7200秒(2小时) - + Returns: - int: 计算出的超时时间(秒),范围:60 ~ 7200 - - Example: - # 1000行字典 × 1.0秒 = 1000秒 → 限制为7200秒中的 1000秒 - # 10000行字典 × 1.0秒 = 10000秒 → 限制为7200秒(最大值) - timeout = calculate_directory_scan_timeout( - tool_config={'wordlist': '/path/to/wordlist.txt'} - ) + int: 计算出的超时时间(秒) """ + import os + + wordlist_path = tool_config.get('wordlist') + if not wordlist_path: + logger.warning("工具配置中未指定 wordlist,使用默认超时: %d秒", min_timeout) + return min_timeout + + wordlist_path = os.path.expanduser(wordlist_path) + + if not os.path.exists(wordlist_path): + logger.warning("字典文件不存在: %s,使用默认超时: %d秒", wordlist_path, min_timeout) + return min_timeout + try: - # 从 tool_config 中获取 wordlist 路径 - wordlist_path = tool_config.get('wordlist') - if not wordlist_path: - logger.warning("工具配置中未指定 wordlist,使用默认超时: %d秒", min_timeout) - return min_timeout - - # 展开用户目录(~) - wordlist_path = os.path.expanduser(wordlist_path) - - # 检查文件是否存在 - if not os.path.exists(wordlist_path): - logger.warning("字典文件不存在: %s,使用默认超时: %d秒", wordlist_path, min_timeout) - return min_timeout - - # 使用 wc -l 快速统计字典行数 result = subprocess.run( ['wc', '-l', wordlist_path], capture_output=True, text=True, check=True ) - # wc -l 输出格式:行数 + 空格 + 文件名 line_count = int(result.stdout.strip().split()[0]) - - # 计算超时时间 - timeout = int(line_count * base_per_word) - - # 设置合理的下限(不再设置上限) - timeout = max(min_timeout, timeout) - + timeout = max(min_timeout, int(line_count * base_per_word)) + logger.info( "目录扫描超时计算 - 字典: %s, 行数: %d, 基础时间: %.3f秒/词, 计算超时: %d秒", wordlist_path, line_count, base_per_word, timeout ) - return timeout - - except subprocess.CalledProcessError as e: - logger.error("统计字典行数失败: %s", e) - # 失败时返回默认超时 - return min_timeout - except (ValueError, IndexError) as e: - logger.error("解析字典行数失败: %s", e) - return min_timeout - except Exception as e: - logger.error("计算超时时间异常: %s", e) + + except (subprocess.CalledProcessError, ValueError, 
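# A minimal sketch of the wordlist-based timeout above (illustration only,
# assuming the same defaults: base_per_word=1.0, min_timeout=60, no upper cap):
def estimate_directory_scan_timeout(line_count: int,
                                    base_per_word: float = 1.0,
                                    min_timeout: int = 60) -> int:
    # timeout scales linearly with wordlist size, floored at min_timeout
    return max(min_timeout, int(line_count * base_per_word))

assert estimate_directory_scan_timeout(1000) == 1000  # 1000 lines -> 1000s
assert estimate_directory_scan_timeout(30) == 60      # small lists hit the floor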
IndexError) as e: + logger.error("计算超时时间失败: %s", e) return min_timeout def _get_max_workers(tool_config: dict, default: int = DEFAULT_MAX_WORKERS) -> int: - """ - 从单个工具配置中获取 max_workers 参数 - - Args: - tool_config: 单个工具的配置字典,如 {'max_workers': 10, 'threads': 5, ...} - default: 默认值,默认为 5 - - Returns: - int: max_workers 值 - """ + """从单个工具配置中获取 max_workers 参数""" if not isinstance(tool_config, dict): return default - - # 支持 max_workers 和 max-workers(YAML 中划线会被转换) + max_workers = tool_config.get('max_workers') or tool_config.get('max-workers') - if max_workers is not None and isinstance(max_workers, int) and max_workers > 0: + if isinstance(max_workers, int) and max_workers > 0: return max_workers return default - - - -def _export_site_urls(target_id: int, target_name: str, directory_scan_dir: Path) -> tuple[str, int]: +def _export_site_urls( + target_id: int, + directory_scan_dir: Path +) -> Tuple[str, int]: """ - 导出目标下的所有站点 URL 到文件(支持懒加载) - + 导出目标下的所有站点 URL 到文件 + Args: target_id: 目标 ID - target_name: 目标名称(用于懒加载创建默认站点) directory_scan_dir: 目录扫描目录 - + Returns: tuple: (sites_file, site_count) - - Raises: - ValueError: 站点数量为 0 """ logger.info("Step 1: 导出目标的所有站点 URL") - + sites_file = str(directory_scan_dir / 'sites.txt') export_result = export_sites_task( target_id=target_id, output_file=sites_file, - batch_size=1000 # 每次读取 1000 条,优化内存占用 + batch_size=1000 ) - + site_count = export_result['total_count'] - logger.info( "✓ 站点 URL 导出完成 - 文件: %s, 数量: %d", export_result['output_file'], site_count ) - + if site_count == 0: logger.warning("目标下没有站点,无法执行目录扫描") - # 不抛出异常,由上层决定如何处理 - # raise ValueError("目标下没有站点,无法执行目录扫描") - + return export_result['output_file'], site_count -def _run_scans_sequentially( - enabled_tools: dict, - sites_file: str, - directory_scan_dir: Path, - scan_id: int, - target_id: int, - site_count: int, - target_name: str -) -> tuple[int, int, list]: - """ - 串行执行目录扫描任务(支持多工具)- 已废弃,保留用于兼容 - - Args: - enabled_tools: 启用的工具配置字典 - sites_file: 站点文件路径 - directory_scan_dir: 目录扫描目录 - scan_id: 扫描任务 ID - target_id: 目标 ID - site_count: 站点数量 - target_name: 目标名称(用于错误日志) - - Returns: - tuple: (total_directories, processed_sites, failed_sites) - """ - # 读取站点列表 - sites = [] - with open(sites_file, 'r', encoding='utf-8') as f: - for line in f: - site_url = line.strip() - if site_url: - sites.append(site_url) - - logger.info("准备扫描 %d 个站点,使用工具: %s", len(sites), ', '.join(enabled_tools.keys())) - - total_directories = 0 - processed_sites_set = set() # 使用 set 避免重复计数 - failed_sites = [] - - # 遍历每个工具 - for tool_name, tool_config in enabled_tools.items(): - logger.info("="*60) - logger.info("使用工具: %s", tool_name) - logger.info("="*60) - - # 如果配置了 wordlist_name,则先确保本地存在对应的字典文件(含 hash 校验) - wordlist_name = tool_config.get('wordlist_name') - if wordlist_name: - try: - local_wordlist_path = ensure_wordlist_local(wordlist_name) - tool_config['wordlist'] = local_wordlist_path - except Exception as exc: - logger.error("为工具 %s 准备字典失败: %s", tool_name, exc) - # 当前工具无法执行,将所有站点视为失败,继续下一个工具 - failed_sites.extend(sites) - continue - - # 逐个站点执行扫描 - for idx, site_url in enumerate(sites, 1): - logger.info( - "[%d/%d] 开始扫描站点: %s (工具: %s)", - idx, len(sites), site_url, tool_name - ) - - # 使用统一的命令构建器 - try: - command = build_scan_command( - tool_name=tool_name, - scan_type='directory_scan', - command_params={ - 'url': site_url - }, - tool_config=tool_config - ) - except Exception as e: - logger.error( - "✗ [%d/%d] 构建 %s 命令失败: %s - 站点: %s", - idx, len(sites), tool_name, e, site_url - ) - failed_sites.append(site_url) - continue - - # 
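# A minimal sketch of the max_workers lookup above, assuming YAML configs may
# spell the key with either an underscore or a hyphen:
def get_max_workers(tool_config: dict, default: int = 5) -> int:
    if not isinstance(tool_config, dict):
        return default
    value = tool_config.get('max_workers') or tool_config.get('max-workers')
    return value if isinstance(value, int) and value > 0 else default

assert get_max_workers({'max-workers': 10}) == 10
assert get_max_workers({'max_workers': -1}) == 5  # invalid values fall back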
单个站点超时:从配置中获取(支持 'auto' 动态计算) - # ffuf 逐个站点扫描,timeout 就是单个站点的超时时间 - site_timeout = tool_config.get('timeout', 300) - if site_timeout == 'auto': - # 动态计算超时时间(基于字典行数) - site_timeout = calculate_directory_scan_timeout(tool_config) - logger.info(f"✓ 工具 {tool_name} 动态计算 timeout: {site_timeout}秒") - - # 生成日志文件路径 - from datetime import datetime - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - log_file = directory_scan_dir / f"{tool_name}_{timestamp}_{idx}.log" - - try: - # 直接调用 task(串行执行) - result = run_and_stream_save_directories_task( - cmd=command, - tool_name=tool_name, # 新增:工具名称 - scan_id=scan_id, - target_id=target_id, - site_url=site_url, - cwd=str(directory_scan_dir), - shell=True, - batch_size=1000, - timeout=site_timeout, - log_file=str(log_file) # 新增:日志文件路径 - ) - - total_directories += result.get('created_directories', 0) - processed_sites_set.add(site_url) # 使用 set 记录成功的站点 - - logger.info( - "✓ [%d/%d] 站点扫描完成: %s - 发现 %d 个目录", - idx, len(sites), site_url, - result.get('created_directories', 0) - ) - - except subprocess.TimeoutExpired as exc: - # 超时异常单独处理 - failed_sites.append(site_url) - logger.warning( - "⚠️ [%d/%d] 站点扫描超时: %s - 超时配置: %d秒\n" - "注意:超时前已解析的目录数据已保存到数据库,但扫描未完全完成。", - idx, len(sites), site_url, site_timeout - ) - except Exception as exc: - # 其他异常 - failed_sites.append(site_url) - logger.error( - "✗ [%d/%d] 站点扫描失败: %s - 错误: %s", - idx, len(sites), site_url, exc - ) - - # 每 10 个站点输出进度 - if idx % 10 == 0: - logger.info( - "进度: %d/%d (%.1f%%) - 已发现 %d 个目录", - idx, len(sites), idx/len(sites)*100, total_directories - ) - - # 计算成功和失败的站点数 - processed_count = len(processed_sites_set) - - if failed_sites: - logger.warning( - "部分站点扫描失败: %d/%d", - len(failed_sites), len(sites) - ) - - logger.info( - "✓ 串行目录扫描执行完成 - 成功: %d/%d, 失败: %d, 总目录数: %d", - processed_count, len(sites), len(failed_sites), total_directories - ) - - return total_directories, processed_count, failed_sites - - -def _generate_log_filename(tool_name: str, site_url: str, directory_scan_dir: Path) -> Path: - """ - 生成唯一的日志文件名 - - 使用 URL 的 hash 确保并发时不会冲突 - - Args: - tool_name: 工具名称 - site_url: 站点 URL - directory_scan_dir: 目录扫描目录 - - Returns: - Path: 日志文件路径 - """ - url_hash = hashlib.md5(site_url.encode()).hexdigest()[:8] +def _generate_log_filename( + tool_name: str, + site_url: str, + directory_scan_dir: Path +) -> Path: + """生成唯一的日志文件名(使用 URL 的 hash 确保并发时不会冲突)""" + url_hash = hashlib.md5( + site_url.encode(), + usedforsecurity=False + ).hexdigest()[:8] timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f') return directory_scan_dir / f"{tool_name}_{url_hash}_{timestamp}.log" +def _prepare_tool_wordlist(tool_name: str, tool_config: dict) -> bool: + """准备工具的字典文件,返回是否成功""" + wordlist_name = tool_config.get('wordlist_name') + if not wordlist_name: + return True + + try: + local_wordlist_path = ensure_wordlist_local(wordlist_name) + tool_config['wordlist'] = local_wordlist_path + return True + except Exception as exc: + logger.error("为工具 %s 准备字典失败: %s", tool_name, exc) + return False + + +def _build_scan_params( + tool_name: str, + tool_config: dict, + sites: List[str], + directory_scan_dir: Path, + site_timeout: int +) -> Tuple[List[dict], List[str]]: + """构建所有站点的扫描参数,返回 (scan_params_list, failed_sites)""" + scan_params_list = [] + failed_sites = [] + + for idx, site_url in enumerate(sites, 1): + try: + command = build_scan_command( + tool_name=tool_name, + scan_type='directory_scan', + command_params={'url': site_url}, + tool_config=tool_config + ) + log_file = _generate_log_filename(tool_name, site_url, 
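# A minimal sketch of the collision-safe log naming above: an 8-char URL hash
# plus a microsecond timestamp keeps concurrent scans from sharing a file:
import hashlib
from datetime import datetime

def log_name(tool: str, url: str) -> str:
    url_hash = hashlib.md5(url.encode(), usedforsecurity=False).hexdigest()[:8]
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
    return f"{tool}_{url_hash}_{timestamp}.log"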
directory_scan_dir) + scan_params_list.append({ + 'idx': idx, + 'site_url': site_url, + 'command': command, + 'log_file': str(log_file), + 'timeout': site_timeout + }) + except Exception as e: + logger.error( + "✗ [%d/%d] 构建 %s 命令失败: %s - 站点: %s", + idx, len(sites), tool_name, e, site_url + ) + failed_sites.append(site_url) + + return scan_params_list, failed_sites + + +def _execute_batch( + batch_params: List[dict], + tool_name: str, + scan_id: int, + target_id: int, + directory_scan_dir: Path, + total_sites: int +) -> Tuple[int, List[str]]: + """执行一批扫描任务,返回 (directories_found, failed_sites)""" + directories_found = 0 + failed_sites = [] + + # 提交任务 + futures = [] + for params in batch_params: + future = run_and_stream_save_directories_task.submit( + cmd=params['command'], + tool_name=tool_name, + scan_id=scan_id, + target_id=target_id, + site_url=params['site_url'], + cwd=str(directory_scan_dir), + shell=True, + batch_size=1000, + timeout=params['timeout'], + log_file=params['log_file'] + ) + futures.append((params['idx'], params['site_url'], future)) + + # 等待结果 + for idx, site_url, future in futures: + try: + result = future.result() + dirs_count = result.get('created_directories', 0) + directories_found += dirs_count + logger.info( + "✓ [%d/%d] 站点扫描完成: %s - 发现 %d 个目录", + idx, total_sites, site_url, dirs_count + ) + except Exception as exc: + failed_sites.append(site_url) + if 'timeout' in str(exc).lower(): + logger.warning( + "⚠️ [%d/%d] 站点扫描超时: %s - 错误: %s", + idx, total_sites, site_url, exc + ) + else: + logger.error( + "✗ [%d/%d] 站点扫描失败: %s - 错误: %s", + idx, total_sites, site_url, exc + ) + + return directories_found, failed_sites + + def _run_scans_concurrently( enabled_tools: dict, sites_file: str, directory_scan_dir: Path, scan_id: int, target_id: int, - site_count: int, - target_name: str ) -> Tuple[int, int, List[str]]: """ - 并发执行目录扫描任务(使用 ThreadPoolTaskRunner) - - Args: - enabled_tools: 启用的工具配置字典 - sites_file: 站点文件路径 - directory_scan_dir: 目录扫描目录 - scan_id: 扫描任务 ID - target_id: 目标 ID - site_count: 站点数量 - target_name: 目标名称(用于错误日志) - + 并发执行目录扫描任务 + Returns: tuple: (total_directories, processed_sites, failed_sites) """ # 读取站点列表 sites: List[str] = [] with open(sites_file, 'r', encoding='utf-8') as f: - for line in f: - site_url = line.strip() - if site_url: - sites.append(site_url) - + sites = [line.strip() for line in f if line.strip()] + if not sites: logger.warning("站点列表为空") return 0, 0, [] - + logger.info( "准备并发扫描 %d 个站点,使用工具: %s", len(sites), ', '.join(enabled_tools.keys()) ) - + total_directories = 0 processed_sites_count = 0 failed_sites: List[str] = [] - - # 遍历每个工具 + for tool_name, tool_config in enabled_tools.items(): - # 每个工具独立获取 max_workers 配置 max_workers = _get_max_workers(tool_config) - - logger.info("="*60) + + logger.info("=" * 60) logger.info("使用工具: %s (并发模式, max_workers=%d)", tool_name, max_workers) - logger.info("="*60) + logger.info("=" * 60) user_log(scan_id, "directory_scan", f"Running {tool_name}") - # 如果配置了 wordlist_name,则先确保本地存在对应的字典文件(含 hash 校验) - wordlist_name = tool_config.get('wordlist_name') - if wordlist_name: - try: - local_wordlist_path = ensure_wordlist_local(wordlist_name) - tool_config['wordlist'] = local_wordlist_path - except Exception as exc: - logger.error("为工具 %s 准备字典失败: %s", tool_name, exc) - # 当前工具无法执行,将所有站点视为失败,继续下一个工具 - failed_sites.extend(sites) - continue - - # 计算超时时间(所有站点共用) + # 准备字典文件 + if not _prepare_tool_wordlist(tool_name, tool_config): + failed_sites.extend(sites) + continue + + # 计算超时时间 site_timeout = tool_config.get('timeout', 
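# A minimal sketch of the submit-then-join pattern used by _execute_batch,
# with a stand-in task; the real task takes many more parameters:
from prefect import flow, task

@task
def scan_one(url: str) -> int:
    return 1  # stand-in: pretend one directory was found

@flow
def scan_batched(urls: list, batch_size: int = 5) -> int:
    total = 0
    for start in range(0, len(urls), batch_size):
        futures = [scan_one.submit(u) for u in urls[start:start + batch_size]]
        # joining the whole batch bounds the number of live scanner processes
        total += sum(f.result() for f in futures)
    return total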
300) if site_timeout == 'auto': site_timeout = calculate_directory_scan_timeout(tool_config) - logger.info(f"✓ 工具 {tool_name} 动态计算 timeout: {site_timeout}秒") - - # 准备所有站点的扫描参数 - scan_params_list = [] - for idx, site_url in enumerate(sites, 1): - try: - command = build_scan_command( - tool_name=tool_name, - scan_type='directory_scan', - command_params={'url': site_url}, - tool_config=tool_config - ) - log_file = _generate_log_filename(tool_name, site_url, directory_scan_dir) - scan_params_list.append({ - 'idx': idx, - 'site_url': site_url, - 'command': command, - 'log_file': str(log_file), - 'timeout': site_timeout - }) - except Exception as e: - logger.error( - "✗ [%d/%d] 构建 %s 命令失败: %s - 站点: %s", - idx, len(sites), tool_name, e, site_url - ) - failed_sites.append(site_url) - + logger.info("✓ 工具 %s 动态计算 timeout: %d秒", tool_name, site_timeout) + + # 构建扫描参数 + scan_params_list, build_failed = _build_scan_params( + tool_name, tool_config, sites, directory_scan_dir, site_timeout + ) + failed_sites.extend(build_failed) + if not scan_params_list: logger.warning("没有有效的扫描任务") continue - - # ============================================================ - # 分批执行策略:控制实际并发的 ffuf 进程数 - # ============================================================ + + # 分批执行 total_tasks = len(scan_params_list) logger.info("开始分批执行 %d 个扫描任务(每批 %d 个)...", total_tasks, max_workers) - - # 进度里程碑跟踪 + last_progress_percent = 0 tool_directories = 0 tool_processed = 0 - - batch_num = 0 + for batch_start in range(0, total_tasks, max_workers): batch_end = min(batch_start + max_workers, total_tasks) batch_params = scan_params_list[batch_start:batch_end] - batch_num += 1 - - logger.info("执行第 %d 批任务(%d-%d/%d)...", batch_num, batch_start + 1, batch_end, total_tasks) - - # 提交当前批次的任务(非阻塞,立即返回 future) - futures = [] - for params in batch_params: - future = run_and_stream_save_directories_task.submit( - cmd=params['command'], - tool_name=tool_name, - scan_id=scan_id, - target_id=target_id, - site_url=params['site_url'], - cwd=str(directory_scan_dir), - shell=True, - batch_size=1000, - timeout=params['timeout'], - log_file=params['log_file'] - ) - futures.append((params['idx'], params['site_url'], future)) - - # 等待当前批次所有任务完成(阻塞,确保本批完成后再启动下一批) - for idx, site_url, future in futures: - try: - result = future.result() # 阻塞等待单个任务完成 - directories_found = result.get('created_directories', 0) - total_directories += directories_found - tool_directories += directories_found - processed_sites_count += 1 - tool_processed += 1 - - logger.info( - "✓ [%d/%d] 站点扫描完成: %s - 发现 %d 个目录", - idx, len(sites), site_url, directories_found - ) - - except Exception as exc: - failed_sites.append(site_url) - if 'timeout' in str(exc).lower() or isinstance(exc, subprocess.TimeoutExpired): - logger.warning( - "⚠️ [%d/%d] 站点扫描超时: %s - 错误: %s", - idx, len(sites), site_url, exc - ) - else: - logger.error( - "✗ [%d/%d] 站点扫描失败: %s - 错误: %s", - idx, len(sites), site_url, exc - ) - + batch_num = batch_start // max_workers + 1 + + logger.info( + "执行第 %d 批任务(%d-%d/%d)...", + batch_num, batch_start + 1, batch_end, total_tasks + ) + + dirs_found, batch_failed = _execute_batch( + batch_params, tool_name, scan_id, target_id, + directory_scan_dir, len(sites) + ) + + total_directories += dirs_found + tool_directories += dirs_found + tool_processed += len(batch_params) - len(batch_failed) + processed_sites_count += len(batch_params) - len(batch_failed) + failed_sites.extend(batch_failed) + # 进度里程碑:每 20% 输出一次 current_progress = int((batch_end / total_tasks) * 100) if current_progress >= 
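# A minimal sketch of the 20% milestone logic above: emit at most one progress
# line per 20% bucket, regardless of batch size:
def milestone_batches(total: int, batch_size: int) -> list:
    logged, last_percent = [], 0
    for end in range(batch_size, total + batch_size, batch_size):
        end = min(end, total)
        percent = int(end / total * 100)
        if percent >= last_percent + 20:
            logged.append(end)
            last_percent = (percent // 20) * 20
    return logged

assert milestone_batches(total=100, batch_size=10) == [20, 40, 60, 80, 100]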
last_progress_percent + 20: - user_log(scan_id, "directory_scan", f"Progress: {batch_end}/{total_tasks} sites scanned") + user_log( + scan_id, "directory_scan", + f"Progress: {batch_end}/{total_tasks} sites scanned" + ) last_progress_percent = (current_progress // 20) * 20 - - # 工具完成日志(开发者日志 + 用户日志) + logger.info( "✓ 工具 %s 执行完成 - 已处理站点: %d/%d, 发现目录: %d", tool_name, tool_processed, total_tasks, tool_directories ) - user_log(scan_id, "directory_scan", f"{tool_name} completed: found {tool_directories} directories") - - # 输出汇总信息 - if failed_sites: - logger.warning( - "部分站点扫描失败: %d/%d", - len(failed_sites), len(sites) + user_log( + scan_id, "directory_scan", + f"{tool_name} completed: found {tool_directories} directories" ) - + + if failed_sites: + logger.warning("部分站点扫描失败: %d/%d", len(failed_sites), len(sites)) + logger.info( "✓ 并发目录扫描执行完成 - 成功: %d/%d, 失败: %d, 总目录数: %d", processed_sites_count, len(sites), len(failed_sites), total_directories ) - + return total_directories, processed_sites_count, failed_sites @flow( - name="directory_scan", + name="directory_scan", log_prints=True, on_running=[on_scan_flow_running], on_completion=[on_scan_flow_completed], @@ -570,64 +396,31 @@ def directory_scan_flow( ) -> dict: """ 目录扫描 Flow - + 主要功能: 1. 从 target 获取所有站点的 URL 2. 对每个站点 URL 执行目录扫描(支持 ffuf 等工具) 3. 流式保存扫描结果到数据库 Directory 表 - - 工作流程: - Step 0: 创建工作目录 - Step 1: 导出站点 URL 列表到文件(供扫描工具使用) - Step 2: 验证工具配置 - Step 3: 并发执行扫描工具并实时保存结果(使用 ThreadPoolTaskRunner) - - ffuf 输出字段: - - url: 发现的目录/文件 URL - - length: 响应内容长度 - - status: HTTP 状态码 - - words: 响应内容单词数 - - lines: 响应内容行数 - - content_type: 内容类型 - - duration: 请求耗时 - + Args: scan_id: 扫描任务 ID target_name: 目标名称 target_id: 目标 ID scan_workspace_dir: 扫描工作空间目录 enabled_tools: 启用的工具配置字典 - + Returns: - dict: { - 'success': bool, - 'scan_id': int, - 'target': str, - 'scan_workspace_dir': str, - 'sites_file': str, - 'site_count': int, - 'total_directories': int, # 发现的总目录数 - 'processed_sites': int, # 成功处理的站点数 - 'failed_sites_count': int, # 失败的站点数 - 'executed_tasks': list - } - - Raises: - ValueError: 参数错误 - RuntimeError: 执行失败 + dict: 扫描结果 """ try: + wait_for_system_load(context="directory_scan_flow") + logger.info( - "="*60 + "\n" + - "开始目录扫描\n" + - f" Scan ID: {scan_id}\n" + - f" Target: {target_name}\n" + - f" Workspace: {scan_workspace_dir}\n" + - "="*60 + "开始目录扫描 - Scan ID: %s, Target: %s, Workspace: %s", + scan_id, target_name, scan_workspace_dir ) - user_log(scan_id, "directory_scan", "Starting directory scan") - + # 参数验证 if scan_id is None: raise ValueError("scan_id 不能为空") @@ -639,14 +432,14 @@ def directory_scan_flow( raise ValueError("scan_workspace_dir 不能为空") if not enabled_tools: raise ValueError("enabled_tools 不能为空") - + # Step 0: 创建工作目录 from apps.scan.utils import setup_scan_directory directory_scan_dir = setup_scan_directory(scan_workspace_dir, 'directory_scan') - - # Step 1: 导出站点 URL(支持懒加载) - sites_file, site_count = _export_site_urls(target_id, target_name, directory_scan_dir) - + + # Step 1: 导出站点 URL + sites_file, site_count = _export_site_urls(target_id, directory_scan_dir) + if site_count == 0: logger.warning("跳过目录扫描:没有站点可扫描 - Scan ID: %s", scan_id) user_log(scan_id, "directory_scan", "Skipped: no sites to scan", "warning") @@ -662,16 +455,16 @@ def directory_scan_flow( 'failed_sites_count': 0, 'executed_tasks': ['export_sites'] } - + # Step 2: 工具配置信息 logger.info("Step 2: 工具配置信息") - tool_info = [] - for tool_name, tool_config in enabled_tools.items(): - mw = _get_max_workers(tool_config) - tool_info.append(f"{tool_name}(max_workers={mw})") + tool_info = 
[ + f"{name}(max_workers={_get_max_workers(cfg)})" + for name, cfg in enabled_tools.items() + ] logger.info("✓ 启用工具: %s", ', '.join(tool_info)) - - # Step 3: 并发执行扫描工具并实时保存结果 + + # Step 3: 并发执行扫描 logger.info("Step 3: 并发执行扫描工具并实时保存结果") total_directories, processed_sites, failed_sites = _run_scans_concurrently( enabled_tools=enabled_tools, @@ -679,19 +472,20 @@ def directory_scan_flow( directory_scan_dir=directory_scan_dir, scan_id=scan_id, target_id=target_id, - site_count=site_count, - target_name=target_name ) - - # 检查是否所有站点都失败 + if processed_sites == 0 and site_count > 0: - logger.warning("所有站点扫描均失败 - 总站点数: %d, 失败数: %d", site_count, len(failed_sites)) - # 不抛出异常,让扫描继续 - - # 记录 Flow 完成 + logger.warning( + "所有站点扫描均失败 - 总站点数: %d, 失败数: %d", + site_count, len(failed_sites) + ) + logger.info("✓ 目录扫描完成 - 发现目录: %d", total_directories) - user_log(scan_id, "directory_scan", f"directory_scan completed: found {total_directories} directories") - + user_log( + scan_id, "directory_scan", + f"directory_scan completed: found {total_directories} directories" + ) + return { 'success': True, 'scan_id': scan_id, @@ -704,7 +498,7 @@ def directory_scan_flow( 'failed_sites_count': len(failed_sites), 'executed_tasks': ['export_sites', 'run_and_stream_save_directories'] } - + except Exception as e: logger.exception("目录扫描失败: %s", e) - raise \ No newline at end of file + raise diff --git a/backend/apps/scan/flows/fingerprint_detect_flow.py b/backend/apps/scan/flows/fingerprint_detect_flow.py index 04dc44cd..b2e9f9ad 100644 --- a/backend/apps/scan/flows/fingerprint_detect_flow.py +++ b/backend/apps/scan/flows/fingerprint_detect_flow.py @@ -10,26 +10,22 @@ - 流式处理输出,批量更新数据库 """ -# Django 环境初始化(导入即生效) -from apps.common.prefect_django_setup import setup_django_for_prefect - import logging -import os from datetime import datetime from pathlib import Path from prefect import flow from apps.scan.handlers.scan_flow_handlers import ( - on_scan_flow_running, on_scan_flow_completed, on_scan_flow_failed, + on_scan_flow_running, ) from apps.scan.tasks.fingerprint_detect import ( export_urls_for_fingerprint_task, run_xingfinger_and_stream_update_tech_task, ) -from apps.scan.utils import build_scan_command, user_log +from apps.scan.utils import build_scan_command, user_log, wait_for_system_load from apps.scan.utils.fingerprint_helpers import get_fingerprint_paths logger = logging.getLogger(__name__) @@ -42,22 +38,19 @@ def calculate_fingerprint_detect_timeout( ) -> int: """ 根据 URL 数量计算超时时间 - + 公式:超时时间 = URL 数量 × 每 URL 基础时间 - 最小值:300秒 - 无上限 - + 最小值:300秒,无上限 + Args: url_count: URL 数量 base_per_url: 每 URL 基础时间(秒),默认 10秒 min_timeout: 最小超时时间(秒),默认 300秒 - + Returns: int: 计算出的超时时间(秒) - """ - timeout = int(url_count * base_per_url) - return max(min_timeout, timeout) + return max(min_timeout, int(url_count * base_per_url)) @@ -70,17 +63,17 @@ def _export_urls( ) -> tuple[str, int]: """ 导出 URL 到文件 - + Args: target_id: 目标 ID fingerprint_dir: 指纹识别目录 source: 数据源类型 - + Returns: tuple: (urls_file, total_count) """ logger.info("Step 1: 导出 URL 列表 (source=%s)", source) - + urls_file = str(fingerprint_dir / 'urls.txt') export_result = export_urls_for_fingerprint_task( target_id=target_id, @@ -88,15 +81,14 @@ def _export_urls( source=source, batch_size=1000 ) - + total_count = export_result['total_count'] - logger.info( "✓ URL 导出完成 - 文件: %s, 数量: %d", export_result['output_file'], total_count ) - + return export_result['output_file'], total_count @@ -111,7 +103,7 @@ def _run_fingerprint_detect( ) -> tuple[dict, list]: """ 执行指纹识别任务 - + Args: 
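# A minimal sketch of the URL-count timeout above: url_count x base_per_url,
# floored at 300s with no upper bound:
def fingerprint_timeout(url_count: int, base_per_url: int = 10,
                        min_timeout: int = 300) -> int:
    return max(min_timeout, int(url_count * base_per_url))

assert fingerprint_timeout(50) == 500   # 50 URLs x 10s
assert fingerprint_timeout(10) == 300   # small batches still get the floor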
enabled_tools: 已启用的工具配置字典 urls_file: URL 文件路径 @@ -120,56 +112,54 @@ def _run_fingerprint_detect( scan_id: 扫描任务 ID target_id: 目标 ID source: 数据源类型 - + Returns: tuple: (tool_stats, failed_tools) """ tool_stats = {} failed_tools = [] - + for tool_name, tool_config in enabled_tools.items(): # 1. 获取指纹库路径 lib_names = tool_config.get('fingerprint_libs', ['ehole']) fingerprint_paths = get_fingerprint_paths(lib_names) - + if not fingerprint_paths: reason = f"没有可用的指纹库: {lib_names}" logger.warning(reason) failed_tools.append({'tool': tool_name, 'reason': reason}) continue - + # 2. 将指纹库路径合并到 tool_config(用于命令构建) tool_config_with_paths = {**tool_config, **fingerprint_paths} - + # 3. 构建命令 try: command = build_scan_command( tool_name=tool_name, scan_type='fingerprint_detect', - command_params={ - 'urls_file': urls_file - }, + command_params={'urls_file': urls_file}, tool_config=tool_config_with_paths ) except Exception as e: - reason = f"命令构建失败: {str(e)}" + reason = f"命令构建失败: {e}" logger.error("构建 %s 命令失败: %s", tool_name, e) failed_tools.append({'tool': tool_name, 'reason': reason}) continue - + # 4. 计算超时时间 timeout = calculate_fingerprint_detect_timeout(url_count) - + # 5. 生成日志文件路径 timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') log_file = fingerprint_dir / f"{tool_name}_{timestamp}.log" - + logger.info( "开始执行 %s 指纹识别 - URL数: %d, 超时: %ds, 指纹库: %s", tool_name, url_count, timeout, list(fingerprint_paths.keys()) ) user_log(scan_id, "fingerprint_detect", f"Running {tool_name}: {command}") - + # 6. 执行扫描任务 try: result = run_xingfinger_and_stream_update_tech_task( @@ -183,14 +173,14 @@ def _run_fingerprint_detect( log_file=str(log_file), batch_size=100 ) - + tool_stats[tool_name] = { 'command': command, 'result': result, 'timeout': timeout, 'fingerprint_libs': list(fingerprint_paths.keys()) } - + tool_updated = result.get('updated_count', 0) logger.info( "✓ 工具 %s 执行完成 - 处理记录: %d, 更新: %d, 未找到: %d", @@ -199,20 +189,23 @@ def _run_fingerprint_detect( tool_updated, result.get('not_found_count', 0) ) - user_log(scan_id, "fingerprint_detect", f"{tool_name} completed: identified {tool_updated} fingerprints") - + user_log( + scan_id, "fingerprint_detect", + f"{tool_name} completed: identified {tool_updated} fingerprints" + ) + except Exception as exc: reason = str(exc) failed_tools.append({'tool': tool_name, 'reason': reason}) logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True) user_log(scan_id, "fingerprint_detect", f"{tool_name} failed: {reason}", "error") - + if failed_tools: logger.warning( "以下指纹识别工具执行失败: %s", ', '.join([f['tool'] for f in failed_tools]) ) - + return tool_stats, failed_tools @@ -232,53 +225,38 @@ def fingerprint_detect_flow( ) -> dict: """ 指纹识别 Flow - + 主要功能: 1. 从数据库导出目标下所有 WebSite URL 到文件 2. 使用 xingfinger 进行技术栈识别 3. 
解析结果并更新 WebSite.tech 字段(合并去重) - + 工作流程: Step 0: 创建工作目录 Step 1: 导出 URL 列表 Step 2: 解析配置,获取启用的工具 Step 3: 执行 xingfinger 并解析结果 - + Args: scan_id: 扫描任务 ID target_name: 目标名称 target_id: 目标 ID scan_workspace_dir: 扫描工作空间目录 enabled_tools: 启用的工具配置(xingfinger) - + Returns: - dict: { - 'success': bool, - 'scan_id': int, - 'target': str, - 'scan_workspace_dir': str, - 'urls_file': str, - 'url_count': int, - 'processed_records': int, - 'updated_count': int, - 'created_count': int, - 'snapshot_count': int, - 'executed_tasks': list, - 'tool_stats': dict - } + dict: 扫描结果 """ try: + # 负载检查:等待系统资源充足 + wait_for_system_load(context="fingerprint_detect_flow") + logger.info( - "="*60 + "\n" + - "开始指纹识别\n" + - f" Scan ID: {scan_id}\n" + - f" Target: {target_name}\n" + - f" Workspace: {scan_workspace_dir}\n" + - "="*60 + "开始指纹识别 - Scan ID: %s, Target: %s, Workspace: %s", + scan_id, target_name, scan_workspace_dir ) - user_log(scan_id, "fingerprint_detect", "Starting fingerprint detection") - + # 参数验证 if scan_id is None: raise ValueError("scan_id 不能为空") @@ -288,46 +266,26 @@ def fingerprint_detect_flow( raise ValueError("target_id 不能为空") if not scan_workspace_dir: raise ValueError("scan_workspace_dir 不能为空") - + # 数据源类型(当前只支持 website) source = 'website' - + # Step 0: 创建工作目录 from apps.scan.utils import setup_scan_directory fingerprint_dir = setup_scan_directory(scan_workspace_dir, 'fingerprint_detect') - + # Step 1: 导出 URL(支持懒加载) urls_file, url_count = _export_urls(target_id, fingerprint_dir, source) - + if url_count == 0: logger.warning("跳过指纹识别:没有 URL 可扫描 - Scan ID: %s", scan_id) user_log(scan_id, "fingerprint_detect", "Skipped: no URLs to scan", "warning") - return { - 'success': True, - 'scan_id': scan_id, - 'target': target_name, - 'scan_workspace_dir': scan_workspace_dir, - 'urls_file': urls_file, - 'url_count': 0, - 'processed_records': 0, - 'updated_count': 0, - 'created_count': 0, - 'snapshot_count': 0, - 'executed_tasks': ['export_urls_for_fingerprint'], - 'tool_stats': { - 'total': 0, - 'successful': 0, - 'failed': 0, - 'successful_tools': [], - 'failed_tools': [], - 'details': {} - } - } - + return _build_empty_result(scan_id, target_name, scan_workspace_dir, urls_file) + # Step 2: 工具配置信息 logger.info("Step 2: 工具配置信息") logger.info("✓ 启用工具: %s", ', '.join(enabled_tools.keys())) - + # Step 3: 执行指纹识别 logger.info("Step 3: 执行指纹识别") tool_stats, failed_tools = _run_fingerprint_detect( @@ -339,24 +297,37 @@ def fingerprint_detect_flow( target_id=target_id, source=source ) - + # 动态生成已执行的任务列表 executed_tasks = ['export_urls_for_fingerprint'] - executed_tasks.extend([f'run_xingfinger ({tool})' for tool in tool_stats.keys()]) - + executed_tasks.extend([f'run_xingfinger ({tool})' for tool in tool_stats]) + # 汇总所有工具的结果 - total_processed = sum(stats['result'].get('processed_records', 0) for stats in tool_stats.values()) - total_updated = sum(stats['result'].get('updated_count', 0) for stats in tool_stats.values()) - total_created = sum(stats['result'].get('created_count', 0) for stats in tool_stats.values()) - total_snapshots = sum(stats['result'].get('snapshot_count', 0) for stats in tool_stats.values()) - + total_processed = sum( + stats['result'].get('processed_records', 0) for stats in tool_stats.values() + ) + total_updated = sum( + stats['result'].get('updated_count', 0) for stats in tool_stats.values() + ) + total_created = sum( + stats['result'].get('created_count', 0) for stats in tool_stats.values() + ) + total_snapshots = sum( + stats['result'].get('snapshot_count', 0) for stats in tool_stats.values() + ) + # 记录 
Flow 完成 logger.info("✓ 指纹识别完成 - 识别指纹: %d", total_updated) - user_log(scan_id, "fingerprint_detect", f"fingerprint_detect completed: identified {total_updated} fingerprints") - - successful_tools = [name for name in enabled_tools.keys() - if name not in [f['tool'] for f in failed_tools]] - + user_log( + scan_id, "fingerprint_detect", + f"fingerprint_detect completed: identified {total_updated} fingerprints" + ) + + successful_tools = [ + name for name in enabled_tools + if name not in [f['tool'] for f in failed_tools] + ] + return { 'success': True, 'scan_id': scan_id, @@ -378,7 +349,7 @@ def fingerprint_detect_flow( 'details': tool_stats } } - + except ValueError as e: logger.error("配置错误: %s", e) raise @@ -388,3 +359,33 @@ def fingerprint_detect_flow( except Exception as e: logger.exception("指纹识别失败: %s", e) raise + + +def _build_empty_result( + scan_id: int, + target_name: str, + scan_workspace_dir: str, + urls_file: str +) -> dict: + """构建空结果(无 URL 可扫描时)""" + return { + 'success': True, + 'scan_id': scan_id, + 'target': target_name, + 'scan_workspace_dir': scan_workspace_dir, + 'urls_file': urls_file, + 'url_count': 0, + 'processed_records': 0, + 'updated_count': 0, + 'created_count': 0, + 'snapshot_count': 0, + 'executed_tasks': ['export_urls_for_fingerprint'], + 'tool_stats': { + 'total': 0, + 'successful': 0, + 'failed': 0, + 'successful_tools': [], + 'failed_tools': [], + 'details': {} + } + } diff --git a/backend/apps/scan/flows/port_scan_flow.py b/backend/apps/scan/flows/port_scan_flow.py index 8116c1bd..1918afd9 100644 --- a/backend/apps/scan/flows/port_scan_flow.py +++ b/backend/apps/scan/flows/port_scan_flow.py @@ -1,4 +1,4 @@ -""" +""" 端口扫描 Flow 负责编排端口扫描的完整流程 @@ -10,25 +10,23 @@ - 配置由 YAML 解析 """ -# Django 环境初始化(导入即生效) -from apps.common.prefect_django_setup import setup_django_for_prefect - import logging -import os import subprocess +from datetime import datetime from pathlib import Path -from typing import Callable + from prefect import flow -from apps.scan.tasks.port_scan import ( - export_hosts_task, - run_and_stream_save_ports_task -) + from apps.scan.handlers.scan_flow_handlers import ( - on_scan_flow_running, on_scan_flow_completed, on_scan_flow_failed, + on_scan_flow_running, ) -from apps.scan.utils import config_parser, build_scan_command, user_log +from apps.scan.tasks.port_scan import ( + export_hosts_task, + run_and_stream_save_ports_task, +) +from apps.scan.utils import build_scan_command, user_log, wait_for_system_load logger = logging.getLogger(__name__) @@ -40,28 +38,19 @@ def calculate_port_scan_timeout( ) -> int: """ 根据目标数量和端口数量计算超时时间 - + 计算公式:超时时间 = 目标数 × 端口数 × base_per_pair - 超时范围:60秒 ~ 2天(172800秒) - + 超时范围:60秒 ~ 无上限 + Args: tool_config: 工具配置字典,包含端口配置(ports, top-ports等) file_path: 目标文件路径(域名/IP列表) base_per_pair: 每个"端口-目标对"的基础时间(秒),默认 0.5秒 - + Returns: - int: 计算出的超时时间(秒),范围:60 ~ 172800 - - Example: - # 100个目标 × 100个端口 × 0.5秒 = 5000秒 - # 10个目标 × 1000个端口 × 0.5秒 = 5000秒 - timeout = calculate_port_scan_timeout( - tool_config={'top-ports': 100}, - file_path='/path/to/domains.txt' - ) + int: 计算出的超时时间(秒),最小 60 秒 """ try: - # 1. 统计目标数量 result = subprocess.run( ['wc', '-l', file_path], capture_output=True, @@ -69,88 +58,74 @@ def calculate_port_scan_timeout( check=True ) target_count = int(result.stdout.strip().split()[0]) - - # 2. 解析端口数量 port_count = _parse_port_count(tool_config) - - # 3. 计算超时时间 - # 总工作量 = 目标数 × 端口数 total_work = target_count * port_count - timeout = int(total_work * base_per_pair) - - # 4. 
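# A minimal sketch of the port-scan timeout above:
# targets x ports x base_per_pair, floored at 60s, no upper bound:
def port_scan_timeout(target_count: int, port_count: int,
                      base_per_pair: float = 0.5) -> int:
    return max(60, int(target_count * port_count * base_per_pair))

assert port_scan_timeout(100, 100) == 5000  # 100 targets x 100 ports x 0.5s
assert port_scan_timeout(1, 10) == 60       # tiny scans hit the 60s floor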
设置合理的下限(不再设置上限) - min_timeout = 60 # 最小 60 秒 - timeout = max(min_timeout, timeout) - + timeout = max(60, int(total_work * base_per_pair)) + logger.info( - f"计算端口扫描 timeout - " - f"目标数: {target_count}, " - f"端口数: {port_count}, " - f"总工作量: {total_work}, " - f"超时: {timeout}秒" + "计算端口扫描 timeout - 目标数: %d, 端口数: %d, 总工作量: %d, 超时: %d秒", + target_count, port_count, total_work, timeout ) return timeout - + except Exception as e: - logger.warning(f"计算 timeout 失败: {e},使用默认值 600秒") + logger.warning("计算 timeout 失败: %s,使用默认值 600秒", e) return 600 def _parse_port_count(tool_config: dict) -> int: """ 从工具配置中解析端口数量 - + 优先级: 1. top-ports: N → 返回 N 2. ports: "80,443,8080" → 返回逗号分隔的数量 3. ports: "1-1000" → 返回范围的大小 4. ports: "1-65535" → 返回 65535 5. 默认 → 返回 100(naabu 默认扫描 top 100) - + Args: tool_config: 工具配置字典 - + Returns: int: 端口数量 """ - # 1. 检查 top-ports 配置 + # 检查 top-ports 配置 if 'top-ports' in tool_config: top_ports = tool_config['top-ports'] if isinstance(top_ports, int) and top_ports > 0: return top_ports - logger.warning(f"top-ports 配置无效: {top_ports},使用默认值") - - # 2. 检查 ports 配置 + logger.warning("top-ports 配置无效: %s,使用默认值", top_ports) + + # 检查 ports 配置 if 'ports' in tool_config: ports_str = str(tool_config['ports']).strip() - - # 2.1 逗号分隔的端口列表:80,443,8080 + + # 逗号分隔的端口列表:80,443,8080 if ',' in ports_str: - port_list = [p.strip() for p in ports_str.split(',') if p.strip()] - return len(port_list) - - # 2.2 端口范围:1-1000 + return len([p.strip() for p in ports_str.split(',') if p.strip()]) + + # 端口范围:1-1000 if '-' in ports_str: try: start, end = ports_str.split('-', 1) start_port = int(start.strip()) end_port = int(end.strip()) - if 1 <= start_port <= end_port <= 65535: return end_port - start_port + 1 - logger.warning(f"端口范围无效: {ports_str},使用默认值") + logger.warning("端口范围无效: %s,使用默认值", ports_str) except ValueError: - logger.warning(f"端口范围解析失败: {ports_str},使用默认值") - - # 2.3 单个端口 + logger.warning("端口范围解析失败: %s,使用默认值", ports_str) + + # 单个端口 try: port = int(ports_str) if 1 <= port <= 65535: return 1 except ValueError: - logger.warning(f"端口配置解析失败: {ports_str},使用默认值") - - # 3. 
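# A minimal sketch of the parsing precedence above (validation elided):
# top-ports > comma list > range > single port > naabu's default top 100.
def parse_port_count(cfg: dict) -> int:
    top = cfg.get('top-ports')
    if isinstance(top, int) and top > 0:
        return top
    ports = str(cfg.get('ports', '')).strip()
    if ',' in ports:
        return len([p for p in ports.split(',') if p.strip()])
    if '-' in ports:
        start, end = (int(x) for x in ports.split('-', 1))
        return end - start + 1
    return 1 if ports.isdigit() else 100

assert parse_port_count({'top-ports': 1000}) == 1000
assert parse_port_count({'ports': '80,443,8080'}) == 3
assert parse_port_count({'ports': '1-1000'}) == 1000
assert parse_port_count({}) == 100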
默认值:naabu 默认扫描 top 100 端口 + logger.warning("端口配置解析失败: %s,使用默认值", ports_str) + + # 默认值:naabu 默认扫描 top 100 端口 return 100 @@ -160,41 +135,38 @@ def _parse_port_count(tool_config: dict) -> int: def _export_hosts(target_id: int, port_scan_dir: Path) -> tuple[str, int, str]: """ 导出主机列表到文件 - + 根据 Target 类型自动决定导出内容: - DOMAIN: 从 Subdomain 表导出子域名 - IP: 直接写入 target.name - CIDR: 展开 CIDR 范围内的所有 IP - + Args: target_id: 目标 ID port_scan_dir: 端口扫描目录 - + Returns: tuple: (hosts_file, host_count, target_type) """ logger.info("Step 1: 导出主机列表") - + hosts_file = str(port_scan_dir / 'hosts.txt') export_result = export_hosts_task( target_id=target_id, output_file=hosts_file, - batch_size=1000 # 每次读取 1000 条,优化内存占用 ) - + host_count = export_result['total_count'] target_type = export_result.get('target_type', 'unknown') - + logger.info( "✓ 主机列表导出完成 - 类型: %s, 文件: %s, 数量: %d", - target_type, - export_result['output_file'], - host_count + target_type, export_result['output_file'], host_count ) - + if host_count == 0: logger.warning("目标下没有可扫描的主机,无法执行端口扫描") - + return export_result['output_file'], host_count, target_type @@ -208,7 +180,7 @@ def _run_scans_sequentially( ) -> tuple[dict, int, list, list]: """ 串行执行端口扫描任务 - + Args: enabled_tools: 已启用的工具配置字典 domains_file: 域名文件路径 @@ -216,72 +188,56 @@ def _run_scans_sequentially( scan_id: 扫描任务 ID target_id: 目标 ID target_name: 目标名称(用于错误日志) - + Returns: tuple: (tool_stats, processed_records, successful_tool_names, failed_tools) - 注意:端口扫描是流式输出,不生成结果文件 - - Raises: - RuntimeError: 所有工具均失败 """ - # ==================== 构建命令并串行执行 ==================== - tool_stats = {} processed_records = 0 - failed_tools = [] # 记录失败的工具(含原因) - - # for循环执行工具:按顺序串行运行每个启用的端口扫描工具 + failed_tools = [] + for tool_name, tool_config in enabled_tools.items(): - # 1. 构建完整命令(变量替换) + # 构建命令 try: command = build_scan_command( tool_name=tool_name, scan_type='port_scan', - command_params={ - 'domains_file': domains_file # 对应 {domains_file} - }, - tool_config=tool_config #yaml的工具配置 + command_params={'domains_file': domains_file}, + tool_config=tool_config ) except Exception as e: - reason = f"命令构建失败: {str(e)}" - logger.error(f"构建 {tool_name} 命令失败: {e}") + reason = f"命令构建失败: {e}" + logger.error("构建 %s 命令失败: %s", tool_name, e) failed_tools.append({'tool': tool_name, 'reason': reason}) continue - - # 2. 获取超时时间(支持 'auto' 动态计算) + + # 获取超时时间 config_timeout = tool_config['timeout'] if config_timeout == 'auto': - # 动态计算超时时间 - config_timeout = calculate_port_scan_timeout( - tool_config=tool_config, - file_path=str(domains_file) - ) - logger.info(f"✓ 工具 {tool_name} 动态计算 timeout: {config_timeout}秒") - - # 2.1 生成日志文件路径 - from datetime import datetime + config_timeout = calculate_port_scan_timeout(tool_config, str(domains_file)) + logger.info("✓ 工具 %s 动态计算 timeout: %d秒", tool_name, config_timeout) + + # 生成日志文件路径 timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') log_file = port_scan_dir / f"{tool_name}_{timestamp}.log" - - # 3. 
执行扫描任务 + logger.info("开始执行 %s 扫描(超时: %d秒)...", tool_name, config_timeout) user_log(scan_id, "port_scan", f"Running {tool_name}: {command}") - + + # 执行扫描任务 try: - # 直接调用 task(串行执行) - # 注意:端口扫描是流式输出到 stdout,不使用 output_file result = run_and_stream_save_ports_task( cmd=command, - tool_name=tool_name, # 工具名称 + tool_name=tool_name, scan_id=scan_id, target_id=target_id, cwd=str(port_scan_dir), shell=True, batch_size=1000, timeout=config_timeout, - log_file=str(log_file) # 新增:日志文件路径 + log_file=str(log_file) ) - + tool_stats[tool_name] = { 'command': command, 'result': result, @@ -289,15 +245,10 @@ def _run_scans_sequentially( } tool_records = result.get('processed_records', 0) processed_records += tool_records - logger.info( - "✓ 工具 %s 流式处理完成 - 记录数: %d", - tool_name, tool_records - ) + logger.info("✓ 工具 %s 流式处理完成 - 记录数: %d", tool_name, tool_records) user_log(scan_id, "port_scan", f"{tool_name} completed: found {tool_records} ports") - - except subprocess.TimeoutExpired as exc: - # 超时异常单独处理 - # 注意:流式处理任务超时时,已解析的数据已保存到数据库 + + except subprocess.TimeoutExpired: reason = f"timeout after {config_timeout}s" failed_tools.append({'tool': tool_name, 'reason': reason}) logger.warning( @@ -307,40 +258,39 @@ def _run_scans_sequentially( ) user_log(scan_id, "port_scan", f"{tool_name} failed: {reason}", "error") except Exception as exc: - # 其他异常 reason = str(exc) failed_tools.append({'tool': tool_name, 'reason': reason}) logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True) user_log(scan_id, "port_scan", f"{tool_name} failed: {reason}", "error") - + if failed_tools: logger.warning( "以下扫描工具执行失败: %s", ', '.join([f['tool'] for f in failed_tools]) ) - + if not tool_stats: error_details = "; ".join([f"{f['tool']}: {f['reason']}" for f in failed_tools]) logger.warning("所有端口扫描工具均失败 - 目标: %s, 失败工具: %s", target_name, error_details) - # 返回空结果,不抛出异常,让扫描继续 return {}, 0, [], failed_tools - - # 动态计算成功的工具列表 - successful_tool_names = [name for name in enabled_tools.keys() - if name not in [f['tool'] for f in failed_tools]] - + + successful_tool_names = [ + name for name in enabled_tools + if name not in [f['tool'] for f in failed_tools] + ] + logger.info( "✓ 串行端口扫描执行完成 - 成功: %d/%d (成功: %s, 失败: %s)", len(tool_stats), len(enabled_tools), ', '.join(successful_tool_names) if successful_tool_names else '无', ', '.join([f['tool'] for f in failed_tools]) if failed_tools else '无' ) - + return tool_stats, processed_records, successful_tool_names, failed_tools @flow( - name="port_scan", + name="port_scan", log_prints=True, on_running=[on_scan_flow_running], on_completion=[on_scan_flow_completed], @@ -355,19 +305,19 @@ def port_scan_flow( ) -> dict: """ 端口扫描 Flow - + 主要功能: 1. 扫描目标域名/IP 的开放端口 2. 
保存 host + ip + port 三元映射到 HostPortMapping 表 - + 输出资产: - HostPortMapping:主机端口映射(host + ip + port 三元组) - + 工作流程: Step 0: 创建工作目录 Step 1: 导出域名列表到文件(供扫描工具使用) Step 2: 解析配置,获取启用的工具 - Step 3: 串行执行扫描工具,运行端口扫描工具并实时解析输出到数据库(→ HostPortMapping) + Step 3: 串行执行扫描工具,运行端口扫描工具并实时解析输出到数据库 Args: scan_id: 扫描任务 ID @@ -377,35 +327,15 @@ def port_scan_flow( enabled_tools: 启用的工具配置字典 Returns: - dict: { - 'success': bool, - 'scan_id': int, - 'target': str, - 'scan_workspace_dir': str, - 'hosts_file': str, - 'host_count': int, - 'processed_records': int, - 'executed_tasks': list, - 'tool_stats': { - 'total': int, # 总工具数 - 'successful': int, # 成功工具数 - 'failed': int, # 失败工具数 - 'successful_tools': list[str], # 成功工具列表 ['naabu_active'] - 'failed_tools': list[dict], # 失败工具列表 [{'tool': 'naabu_passive', 'reason': '超时'}] - 'details': dict # 详细执行结果(保留向后兼容) - } - } + dict: 扫描结果 Raises: ValueError: 配置错误 RuntimeError: 执行失败 - - Note: - 端口扫描工具(如 naabu)会解析域名获取 IP,输出 host + ip + port 三元组。 - 同一 host 可能对应多个 IP(CDN、负载均衡),因此使用三元映射表存储。 """ try: - # 参数验证 + wait_for_system_load(context="port_scan_flow") + if scan_id is None: raise ValueError("scan_id 不能为空") if not target_name: @@ -416,25 +346,20 @@ def port_scan_flow( raise ValueError("scan_workspace_dir 不能为空") if not enabled_tools: raise ValueError("enabled_tools 不能为空") - + logger.info( - "="*60 + "\n" + - "开始端口扫描\n" + - f" Scan ID: {scan_id}\n" + - f" Target: {target_name}\n" + - f" Workspace: {scan_workspace_dir}\n" + - "="*60 + "开始端口扫描 - Scan ID: %s, Target: %s, Workspace: %s", + scan_id, target_name, scan_workspace_dir ) - user_log(scan_id, "port_scan", "Starting port scan") - + # Step 0: 创建工作目录 from apps.scan.utils import setup_scan_directory port_scan_dir = setup_scan_directory(scan_workspace_dir, 'port_scan') - - # Step 1: 导出主机列表到文件(根据 Target 类型自动决定内容) + + # Step 1: 导出主机列表 hosts_file, host_count, target_type = _export_hosts(target_id, port_scan_dir) - + if host_count == 0: logger.warning("跳过端口扫描:没有主机可扫描 - Scan ID: %s", scan_id) user_log(scan_id, "port_scan", "Skipped: no hosts to scan", "warning") @@ -457,14 +382,11 @@ def port_scan_flow( 'details': {} } } - + # Step 2: 工具配置信息 logger.info("Step 2: 工具配置信息") - logger.info( - "✓ 启用工具: %s", - ', '.join(enabled_tools.keys()) - ) - + logger.info("✓ 启用工具: %s", ', '.join(enabled_tools.keys())) + # Step 3: 串行执行扫描工具 logger.info("Step 3: 串行执行扫描工具") tool_stats, processed_records, successful_tool_names, failed_tools = _run_scans_sequentially( @@ -475,15 +397,13 @@ def port_scan_flow( target_id=target_id, target_name=target_name ) - - # 记录 Flow 完成 + logger.info("✓ 端口扫描完成 - 发现端口: %d", processed_records) user_log(scan_id, "port_scan", f"port_scan completed: found {processed_records} ports") - - # 动态生成已执行的任务列表 + executed_tasks = ['export_hosts', 'parse_config'] - executed_tasks.extend([f'run_and_stream_save_ports ({tool})' for tool in tool_stats.keys()]) - + executed_tasks.extend([f'run_and_stream_save_ports ({tool})' for tool in tool_stats]) + return { 'success': True, 'scan_id': scan_id, diff --git a/backend/apps/scan/flows/screenshot_flow.py b/backend/apps/scan/flows/screenshot_flow.py index bfab626b..579de55e 100644 --- a/backend/apps/scan/flows/screenshot_flow.py +++ b/backend/apps/scan/flows/screenshot_flow.py @@ -11,43 +11,33 @@ 2. 
Provider 模式:使用 TargetProvider 从任意数据源获取 URL """ -# Django 环境初始化 -from apps.common.prefect_django_setup import setup_django_for_prefect - import logging -from pathlib import Path from typing import Optional + from prefect import flow -from apps.scan.tasks.screenshot import capture_screenshots_task from apps.scan.handlers.scan_flow_handlers import ( - on_scan_flow_running, on_scan_flow_completed, on_scan_flow_failed, -) -from apps.scan.utils import user_log -from apps.scan.services.target_export_service import ( - get_urls_with_fallback, - DataSource, + on_scan_flow_running, ) from apps.scan.providers import TargetProvider +from apps.scan.services.target_export_service import DataSource, get_urls_with_fallback +from apps.scan.tasks.screenshot import capture_screenshots_task +from apps.scan.utils import user_log, wait_for_system_load logger = logging.getLogger(__name__) +# URL 来源到 DataSource 的映射 +_SOURCE_MAPPING = { + 'websites': DataSource.WEBSITE, + 'endpoints': DataSource.ENDPOINT, +} + def _parse_screenshot_config(enabled_tools: dict) -> dict: - """ - 解析截图配置 - - Args: - enabled_tools: 启用的工具配置 - - Returns: - 截图配置字典 - """ - # 从 enabled_tools 中获取 playwright 配置 + """解析截图配置""" playwright_config = enabled_tools.get('playwright', {}) - return { 'concurrency': playwright_config.get('concurrency', 5), 'url_sources': playwright_config.get('url_sources', ['websites']) @@ -55,33 +45,54 @@ def _parse_screenshot_config(enabled_tools: dict) -> dict: def _map_url_sources_to_data_sources(url_sources: list[str]) -> list[str]: - """ - 将配置中的 url_sources 映射为 DataSource 常量 - - Args: - url_sources: 配置中的来源列表,如 ['websites', 'endpoints'] - - Returns: - DataSource 常量列表 - """ - source_mapping = { - 'websites': DataSource.WEBSITE, - 'endpoints': DataSource.ENDPOINT, - } - + """将配置中的 url_sources 映射为 DataSource 常量""" sources = [] for source in url_sources: - if source in source_mapping: - sources.append(source_mapping[source]) + if source in _SOURCE_MAPPING: + sources.append(_SOURCE_MAPPING[source]) else: logger.warning("未知的 URL 来源: %s,跳过", source) - + # 添加默认回退(从 subdomain 构造) sources.append(DataSource.DEFAULT) - return sources +def _collect_urls_from_provider(provider: TargetProvider) -> tuple[list[str], str, list[str]]: + """从 Provider 收集 URL""" + logger.info("使用 Provider 模式获取 URL - Provider: %s", type(provider).__name__) + urls = list(provider.iter_urls()) + + blacklist_filter = provider.get_blacklist_filter() + if blacklist_filter: + urls = [url for url in urls if blacklist_filter.is_allowed(url)] + + return urls, 'provider', ['provider'] + + +def _collect_urls_from_database( + target_id: int, + url_sources: list[str] +) -> tuple[list[str], str, list[str]]: + """从数据库收集 URL(带黑名单过滤和回退)""" + data_sources = _map_url_sources_to_data_sources(url_sources) + result = get_urls_with_fallback(target_id, sources=data_sources) + return result['urls'], result['source'], result['tried_sources'] + + +def _build_empty_result(scan_id: int, target_name: str) -> dict: + """构建空结果""" + return { + 'success': True, + 'scan_id': scan_id, + 'target': target_name, + 'total_urls': 0, + 'successful': 0, + 'failed': 0, + 'synced': 0 + } + + @flow( name="screenshot", log_prints=True, @@ -99,17 +110,11 @@ def screenshot_flow( ) -> dict: """ 截图 Flow - + 支持两种模式: 1. 传统模式(向后兼容):使用 target_id 从数据库获取 URL 2. 
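# A minimal sketch of the url_sources mapping above; 'WEBSITE' etc. are
# stand-ins for the real DataSource constants:
_MAPPING = {'websites': 'WEBSITE', 'endpoints': 'ENDPOINT'}

def map_sources(url_sources: list) -> list:
    mapped = [_MAPPING[s] for s in url_sources if s in _MAPPING]
    return mapped + ['DEFAULT']  # always fall back to subdomain-built URLs

assert map_sources(['websites', 'bogus']) == ['WEBSITE', 'DEFAULT']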
Provider 模式:使用 TargetProvider 从任意数据源获取 URL - - 工作流程: - Step 1: 解析配置 - Step 2: 收集 URL 列表 - Step 3: 批量截图并保存快照 - Step 4: 同步到资产表 - + Args: scan_id: 扫描任务 ID target_name: 目标名称 @@ -117,116 +122,87 @@ def screenshot_flow( scan_workspace_dir: 扫描工作空间目录 enabled_tools: 启用的工具配置 provider: TargetProvider 实例(新模式,可选) - + Returns: - dict: { - 'success': bool, - 'scan_id': int, - 'target': str, - 'total_urls': int, - 'successful': int, - 'failed': int, - 'synced': int - } + 截图结果字典 """ try: + # 负载检查:等待系统资源充足 + wait_for_system_load(context="screenshot_flow") + + mode = 'Provider' if provider else 'Legacy' logger.info( - "="*60 + "\n" + - "开始截图扫描\n" + - f" Scan ID: {scan_id}\n" + - f" Target: {target_name}\n" + - f" Workspace: {scan_workspace_dir}\n" + - f" Mode: {'Provider' if provider else 'Legacy'}\n" + - "="*60 + "开始截图扫描 - Scan ID: %s, Target: %s, Mode: %s", + scan_id, target_name, mode ) - user_log(scan_id, "screenshot", "Starting screenshot capture") - + # Step 1: 解析配置 config = _parse_screenshot_config(enabled_tools) concurrency = config['concurrency'] - url_sources = config['url_sources'] - - logger.info("截图配置 - 并发: %d, URL来源: %s", concurrency, url_sources) - + logger.info("截图配置 - 并发: %d, URL来源: %s", concurrency, config['url_sources']) + # Step 2: 收集 URL 列表 if provider is not None: - # Provider 模式:使用 TargetProvider 获取 URL - logger.info("使用 Provider 模式获取 URL - Provider: %s", type(provider).__name__) - urls = list(provider.iter_urls()) - - # 应用黑名单过滤(如果有) - blacklist_filter = provider.get_blacklist_filter() - if blacklist_filter: - urls = [url for url in urls if blacklist_filter.is_allowed(url)] - - source_info = 'provider' - tried_sources = ['provider'] + urls, source_info, tried_sources = _collect_urls_from_provider(provider) else: - # 传统模式:使用统一服务收集 URL(带黑名单过滤和回退) - data_sources = _map_url_sources_to_data_sources(url_sources) - result = get_urls_with_fallback(target_id, sources=data_sources) - - urls = result['urls'] - source_info = result['source'] - tried_sources = result['tried_sources'] - + urls, source_info, tried_sources = _collect_urls_from_database( + target_id, config['url_sources'] + ) + logger.info( "URL 收集完成 - 来源: %s, 数量: %d, 尝试过: %s", source_info, len(urls), tried_sources ) - + if not urls: logger.warning("没有可截图的 URL,跳过截图任务") user_log(scan_id, "screenshot", "Skipped: no URLs to capture", "warning") - return { - 'success': True, - 'scan_id': scan_id, - 'target': target_name, - 'total_urls': 0, - 'successful': 0, - 'failed': 0, - 'synced': 0 - } - - user_log(scan_id, "screenshot", f"Found {len(urls)} URLs to capture (source: {source_info})") - + return _build_empty_result(scan_id, target_name) + + user_log( + scan_id, "screenshot", + f"Found {len(urls)} URLs to capture (source: {source_info})" + ) + # Step 3: 批量截图 - logger.info("Step 3: 批量截图 - %d 个 URL", len(urls)) - + logger.info("批量截图 - %d 个 URL", len(urls)) capture_result = capture_screenshots_task( urls=urls, scan_id=scan_id, target_id=target_id, config={'concurrency': concurrency} ) - + # Step 4: 同步到资产表 - logger.info("Step 4: 同步截图到资产表") + logger.info("同步截图到资产表") from apps.asset.services.screenshot_service import ScreenshotService - screenshot_service = ScreenshotService() - synced = screenshot_service.sync_screenshots_to_asset(scan_id, target_id) - + synced = ScreenshotService().sync_screenshots_to_asset(scan_id, target_id) + + total = capture_result['total'] + successful = capture_result['successful'] + failed = capture_result['failed'] + logger.info( "✓ 截图完成 - 总数: %d, 成功: %d, 失败: %d, 同步: %d", - capture_result['total'], 
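# A minimal sketch of the provider-mode collection above, assuming a provider
# exposing iter_urls() and an optional blacklist with is_allowed():
def collect_provider_urls(provider) -> list:
    urls = list(provider.iter_urls())
    blacklist = provider.get_blacklist_filter()
    if blacklist is not None:
        urls = [u for u in urls if blacklist.is_allowed(u)]
    return urls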
capture_result['successful'], capture_result['failed'], synced + total, successful, failed, synced ) user_log( scan_id, "screenshot", - f"Screenshot completed: {capture_result['successful']}/{capture_result['total']} captured, {synced} synced" + f"Screenshot completed: {successful}/{total} captured, {synced} synced" ) - + return { 'success': True, 'scan_id': scan_id, 'target': target_name, - 'total_urls': capture_result['total'], - 'successful': capture_result['successful'], - 'failed': capture_result['failed'], + 'total_urls': total, + 'successful': successful, + 'failed': failed, 'synced': synced } - - except Exception as e: - logger.exception("截图 Flow 失败: %s", e) - user_log(scan_id, "screenshot", f"Screenshot failed: {e}", "error") + + except Exception: + logger.exception("截图 Flow 失败") + user_log(scan_id, "screenshot", "Screenshot failed", "error") raise diff --git a/backend/apps/scan/flows/site_scan_flow.py b/backend/apps/scan/flows/site_scan_flow.py index 97e70610..e9897989 100644 --- a/backend/apps/scan/flows/site_scan_flow.py +++ b/backend/apps/scan/flows/site_scan_flow.py @@ -1,4 +1,3 @@ - """ 站点扫描 Flow @@ -11,23 +10,22 @@ - 配置由 YAML 解析 """ -# Django 环境初始化(导入即生效) -from apps.common.prefect_django_setup import setup_django_for_prefect - +from datetime import datetime import logging -import os import subprocess -import time from pathlib import Path -from typing import Callable + from prefect import flow -from apps.scan.tasks.site_scan import export_site_urls_task, run_and_stream_save_websites_task + +# Django 环境初始化(导入即生效,pylint: disable=unused-import) +from apps.common.prefect_django_setup import setup_django_for_prefect # noqa: F401 from apps.scan.handlers.scan_flow_handlers import ( - on_scan_flow_running, on_scan_flow_completed, on_scan_flow_failed, + on_scan_flow_running, ) -from apps.scan.utils import config_parser, build_scan_command, user_log +from apps.scan.tasks.site_scan import export_site_urls_task, run_and_stream_save_websites_task +from apps.scan.utils import build_scan_command, user_log, wait_for_system_load logger = logging.getLogger(__name__) @@ -191,7 +189,6 @@ def _run_scans_sequentially( timeout = max(dynamic_timeout, config_timeout) # 2.1 生成日志文件路径(类似端口扫描) - from datetime import datetime timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') log_file = site_scan_dir / f"{tool_name}_{timestamp}.log" @@ -233,7 +230,7 @@ def _run_scans_sequentially( ) user_log(scan_id, "site_scan", f"{tool_name} completed: found {tool_created} websites") - except subprocess.TimeoutExpired as exc: + except subprocess.TimeoutExpired: # 超时异常单独处理 reason = f"timeout after {timeout}s" failed_tools.append({'tool': tool_name, 'reason': reason}) @@ -368,6 +365,9 @@ def site_scan_flow( RuntimeError: 执行失败 """ try: + # 负载检查:等待系统资源充足 + wait_for_system_load(context="site_scan_flow") + logger.info( "="*60 + "\n" + "开始站点扫描\n" + diff --git a/backend/apps/scan/flows/subdomain_discovery_flow.py b/backend/apps/scan/flows/subdomain_discovery_flow.py index 7d8aff13..88f2f28b 100644 --- a/backend/apps/scan/flows/subdomain_discovery_flow.py +++ b/backend/apps/scan/flows/subdomain_discovery_flow.py @@ -14,298 +14,82 @@ Stage 2: 字典爆破(可选) - 子域名字典爆破 Stage 3: 变异生成 + 验证(可选) - dnsgen + 通用存活验证 Stage 4: DNS 存活验证(可选) - 通用存活验证 - + 各阶段可灵活开关,最终结果根据实际执行的阶段动态决定 """ -# Django 环境初始化(导入即生效) -from apps.common.prefect_django_setup import setup_django_for_prefect +import logging +import subprocess +import uuid +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from typing import 
Optional from prefect import flow -from pathlib import Path -import logging -import os -from apps.scan.handlers.scan_flow_handlers import ( - on_scan_flow_running, - on_scan_flow_completed, - on_scan_flow_failed, -) -from apps.scan.utils import build_scan_command, ensure_wordlist_local, user_log -from apps.engine.services.wordlist_service import WordlistService + +# Django 环境初始化(导入即生效,pylint: disable=unused-import) +from apps.common.prefect_django_setup import setup_django_for_prefect # noqa: F401 from apps.common.normalizer import normalize_domain from apps.common.validators import validate_domain -from datetime import datetime -import uuid -import subprocess +from apps.engine.services.wordlist_service import WordlistService +from apps.scan.handlers.scan_flow_handlers import ( + on_scan_flow_completed, + on_scan_flow_failed, + on_scan_flow_running, +) +from apps.scan.utils import ( + build_scan_command, + ensure_wordlist_local, + user_log, + wait_for_system_load, +) logger = logging.getLogger(__name__) +# 泛解析检测配置 +_SAMPLE_MULTIPLIER = 100 # 采样数量 = 原文件 × 100 +_EXPANSION_THRESHOLD = 50 # 膨胀阈值 = 原文件 × 50 +_SAMPLE_TIMEOUT = 7200 # 采样超时 2 小时 +@dataclass +class ScanContext: + """扫描上下文,用于在各阶段间传递状态""" + scan_id: int + target_id: int + domain_name: str + result_dir: Path + timestamp: str + current_result: str = "" + executed_tasks: list = field(default_factory=list) + failed_tools: list = field(default_factory=list) + successful_tools: list = field(default_factory=list) def _validate_and_normalize_target(target_name: str) -> str: - """ - 验证并规范化目标域名 - - Args: - target_name: 原始目标域名 - - Returns: - str: 规范化后的域名 - - Raises: - ValueError: 域名无效时抛出异常 - - Example: - >>> _validate_and_normalize_target('EXAMPLE.COM') - 'example.com' - >>> _validate_and_normalize_target('http://example.com') - 'example.com' - """ + """验证并规范化目标域名""" try: normalized_target = normalize_domain(target_name) validate_domain(normalized_target) logger.debug("域名验证通过: %s -> %s", target_name, normalized_target) return normalized_target except ValueError as e: - error_msg = f"无效的目标域名: {target_name} - {e}" - logger.error(error_msg) - raise ValueError(error_msg) from e - - -def _run_scans_parallel( - enabled_tools: dict, - domain_name: str, - result_dir: Path, - scan_id: int, - provider_config_path: str = None -) -> tuple[list, list, list]: - """ - 并行运行所有启用的子域名扫描工具 - - Args: - enabled_tools: 启用的工具配置字典 {'tool_name': {'timeout': 600, ...}} - domain_name: 目标域名 - result_dir: 结果输出目录 - scan_id: 扫描任务 ID(用于记录日志) - provider_config_path: Provider 配置文件路径(可选,用于 subfinder) - - Returns: - tuple: (result_files, failed_tools, successful_tool_names) - - Raises: - RuntimeError: 所有工具均失败 - """ - # 导入任务函数 - from apps.scan.tasks.subdomain_discovery import run_subdomain_discovery_task - - # 生成时间戳(所有工具共用) - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - - failures = [] # 记录命令构建失败的工具 - futures = {} - - # 1. 
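# A minimal sketch of how the ScanContext dataclass above threads state between
# stages (fields abbreviated for illustration):
from dataclasses import dataclass, field
from pathlib import Path

@dataclass
class Ctx:
    scan_id: int
    domain_name: str
    result_dir: Path
    current_result: str = ""
    executed_tasks: list = field(default_factory=list)

ctx = Ctx(scan_id=1, domain_name="example.com", result_dir=Path("/tmp/scan"))
ctx.executed_tasks.append("passive_collection")  # each stage records itself
ctx.current_result = str(ctx.result_dir / "stage1_merged.txt")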
构建命令并提交并行任务 - for tool_name, tool_config in enabled_tools.items(): - # 1.1 生成唯一的输出文件路径(绝对路径) - short_uuid = uuid.uuid4().hex[:4] - output_file = str(result_dir / f"{tool_name}_{timestamp}_{short_uuid}.txt") - - # 1.2 构建完整命令(变量替换) - try: - command_params = { - 'domain': domain_name, # 对应 {domain} - 'output_file': output_file # 对应 {output_file} - } - - # 如果是 subfinder 且有 provider_config,添加到参数 - if tool_name == 'subfinder' and provider_config_path: - command_params['provider_config'] = provider_config_path - - command = build_scan_command( - tool_name=tool_name, - scan_type='subdomain_discovery', - command_params=command_params, - tool_config=tool_config - ) - except Exception as e: - failure_msg = f"{tool_name}: 命令构建失败 - {e}" - failures.append(failure_msg) - logger.error(f"构建 {tool_name} 命令失败: {e}") - continue - - # 1.3 获取超时时间(支持 'auto' 动态计算) - timeout = tool_config['timeout'] - if timeout == 'auto': - # 子域名发现工具通常运行时间较长,使用默认值 600 秒 - timeout = 600 - logger.info(f"✓ 工具 {tool_name} 使用默认 timeout: {timeout}秒") - - # 1.4 提交任务 - logger.debug( - f"提交任务 - 工具: {tool_name}, 超时: {timeout}s, 输出: {output_file}" - ) - - # 记录工具开始执行日志 - user_log(scan_id, "subdomain_discovery", f"Running {tool_name}: {command}") - - future = run_subdomain_discovery_task.submit( - tool=tool_name, - command=command, - timeout=timeout, - output_file=output_file - ) - futures[tool_name] = future - - # 2. 检查是否有任何工具成功提交 - if not futures: - logger.warning( - "所有扫描工具均无法启动 - 目标: %s, 失败详情: %s", - domain_name, "; ".join(failures) - ) - # 返回空结果,不抛出异常,让扫描继续 - return [], [{'tool': 'all', 'reason': '所有工具均无法启动'}], [] - - # 3. 等待并行任务完成,获取结果 - result_files = [] - failed_tools = [] - - for tool_name, future in futures.items(): - try: - result = future.result() # 返回文件路径(字符串)或 ""(失败) - if result: - result_files.append(result) - logger.info("✓ 扫描工具 %s 执行成功: %s", tool_name, result) - user_log(scan_id, "subdomain_discovery", f"{tool_name} completed") - else: - failure_msg = f"{tool_name}: 未生成结果文件" - failures.append(failure_msg) - failed_tools.append({'tool': tool_name, 'reason': '未生成结果文件'}) - logger.warning("⚠️ 扫描工具 %s 未生成结果文件", tool_name) - user_log(scan_id, "subdomain_discovery", f"{tool_name} failed: no output file", "error") - except Exception as e: - failure_msg = f"{tool_name}: {str(e)}" - failures.append(failure_msg) - failed_tools.append({'tool': tool_name, 'reason': str(e)}) - logger.warning("⚠️ 扫描工具 %s 执行失败: %s", tool_name, str(e)) - user_log(scan_id, "subdomain_discovery", f"{tool_name} failed: {str(e)}", "error") - - # 4. 检查是否有成功的工具 - if not result_files: - logger.warning( - "所有扫描工具均失败 - 目标: %s, 失败详情: %s", - domain_name, "; ".join(failures) - ) - # 返回空结果,不抛出异常,让扫描继续 - return [], failed_tools, [] - - # 5. 
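
The submit-then-collect structure being deleted here survives unchanged in the new _run_scans_parallel below; it is the standard Prefect fan-out/fan-in shape. A self-contained sketch under that assumption:

from prefect import flow, task

@task
def run_tool(tool: str) -> str:
    return f"{tool}.txt"   # stand-in for a real scanner invocation

@flow
def fan_out_fan_in(tools: list[str]) -> list[str]:
    futures = {t: run_tool.submit(t) for t in tools}   # fan-out: tasks run concurrently
    results = []
    for fut in futures.values():                       # fan-in: collect per-tool output
        try:
            results.append(fut.result())
        except Exception:
            pass   # one failing tool must not sink the whole stage
    return results
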
动态计算成功的工具列表 - successful_tool_names = [name for name in futures.keys() - if name not in [f['tool'] for f in failed_tools]] - - logger.info( - "✓ 扫描工具并行执行完成 - 成功: %d/%d (成功: %s, 失败: %s)", - len(result_files), len(futures), - ', '.join(successful_tool_names) if successful_tool_names else '无', - ', '.join([f['tool'] for f in failed_tools]) if failed_tools else '无' - ) - - return result_files, failed_tools, successful_tool_names - - -def _run_single_tool( - tool_name: str, - tool_config: dict, - command_params: dict, - result_dir: Path, - scan_type: str = 'subdomain_discovery', - scan_id: int = None -) -> str: - """ - 运行单个扫描工具 - - Args: - tool_name: 工具名称 - tool_config: 工具配置 - command_params: 命令参数 - result_dir: 结果目录 - scan_type: 扫描类型 - scan_id: 扫描 ID(用于记录用户日志) - - Returns: - str: 输出文件路径,失败返回空字符串 - """ - from apps.scan.tasks.subdomain_discovery import run_subdomain_discovery_task - - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - short_uuid = uuid.uuid4().hex[:4] - output_file = str(result_dir / f"{tool_name}_{timestamp}_{short_uuid}.txt") - - # 添加 output_file 到参数 - command_params['output_file'] = output_file - - try: - command = build_scan_command( - tool_name=tool_name, - scan_type=scan_type, - command_params=command_params, - tool_config=tool_config - ) - except Exception as e: - logger.error(f"构建 {tool_name} 命令失败: {e}") - return "" - - timeout = tool_config.get('timeout', 3600) - if timeout == 'auto': - timeout = 3600 - - logger.info(f"执行 {tool_name}: {command}") - if scan_id: - user_log(scan_id, scan_type, f"Running {tool_name}: {command}") - - try: - result = run_subdomain_discovery_task( - tool=tool_name, - command=command, - timeout=timeout, - output_file=output_file - ) - return result if result else "" - except Exception as e: - logger.warning(f"{tool_name} 执行失败: {e}") - return "" + raise ValueError(f"无效的目标域名: {target_name} - {e}") from e def _count_lines(file_path: str) -> int: - """ - 统计文件非空行数 - - Args: - file_path: 文件路径 - - Returns: - int: 非空行数量 - """ + """统计文件非空行数""" try: with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: return sum(1 for line in f if line.strip()) - except Exception as e: - logger.warning(f"统计文件行数失败: {file_path} - {e}") + except OSError as e: + logger.warning("统计文件行数失败: %s - %s", file_path, e) return 0 def _merge_files(file_list: list, output_file: str) -> str: - """ - 合并多个文件并去重 - - Args: - file_list: 文件路径列表 - output_file: 输出文件路径 - - Returns: - str: 输出文件路径 - """ + """合并多个文件并去重""" domains = set() for f in file_list: if f and Path(f).exists(): @@ -314,455 +98,418 @@ def _merge_files(file_list: list, output_file: str) -> str: line = line.strip() if line: domains.add(line) - + with open(output_file, 'w', encoding='utf-8') as fp: for domain in sorted(domains): fp.write(domain + '\n') - - logger.info(f"合并完成: {len(domains)} 个域名 -> {output_file}") + + logger.info("合并完成: %d 个域名 -> %s", len(domains), output_file) return output_file -@flow( - name="subdomain_discovery", - log_prints=True, - on_running=[on_scan_flow_running], - on_completion=[on_scan_flow_completed], - on_failure=[on_scan_flow_failed], -) -def subdomain_discovery_flow( - scan_id: int, - target_name: str, - target_id: int, - scan_workspace_dir: str, - enabled_tools: dict -) -> dict: - """子域名发现扫描流程 - - 工作流程(4 阶段): - Stage 1: 被动收集(并行) - 必选 - Stage 2: 字典爆破(可选) - 子域名字典爆破 - Stage 3: 变异生成 + 验证(可选) - dnsgen + 通用存活验证 - Stage 4: DNS 存活验证(可选) - 通用存活验证 - Final: 保存到数据库 - - 注意: - - 子域名发现只对 DOMAIN 类型目标有意义 - - IP 和 CIDR 类型目标会自动跳过 - - Args: - scan_id: 扫描任务 ID - target_name: 目标名称(域名) - target_id: 目标 ID 
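
The trimmed _count_lines/_merge_files keep their behavior: blank lines are ignored, duplicates collapse via a set, and output is sorted. A quick check of those semantics, assuming the helpers are imported from this module:

import tempfile
from pathlib import Path

tmp = Path(tempfile.mkdtemp())
(tmp / "a.txt").write_text("www.example.com\napi.example.com\n")
(tmp / "b.txt").write_text("www.example.com\n\n")   # duplicate plus a blank line

merged = _merge_files([str(tmp / "a.txt"), str(tmp / "b.txt")], str(tmp / "out.txt"))
assert Path(merged).read_text() == "api.example.com\nwww.example.com\n"
assert _count_lines(merged) == 2
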
- scan_workspace_dir: Scan 工作空间目录(由 Service 层创建) - enabled_tools: 扫描配置字典: - { - 'passive_tools': {...}, - 'bruteforce': {...}, - 'permutation': {...}, - 'resolve': {...} - } - - Returns: - dict: 扫描结果 - - Raises: - ValueError: 配置错误 - RuntimeError: 执行失败 - """ +def _calculate_auto_timeout(file_path: str, multiplier: int = 3, default: int = 3600) -> int: + """根据文件行数计算超时时间""" try: - # ==================== 参数验证 ==================== - if scan_id is None: - raise ValueError("scan_id 不能为空") - if target_id is None: - raise ValueError("target_id 不能为空") - if not scan_workspace_dir: - raise ValueError("scan_workspace_dir 不能为空") - if enabled_tools is None: - raise ValueError("enabled_tools 不能为空") - - scan_config = enabled_tools - - # 如果未提供目标域名,跳过扫描 - if not target_name: - logger.warning("未提供目标域名,跳过子域名发现扫描") - return _empty_result(scan_id, '', scan_workspace_dir) - - # ==================== 检查 Target 类型 ==================== - # 子域名发现只对 DOMAIN 类型有意义,IP 和 CIDR 类型跳过 - from apps.targets.services import TargetService - from apps.targets.models import Target - - target_service = TargetService() - target = target_service.get_target(target_id) - - if target and target.type != Target.TargetType.DOMAIN: - logger.info( - "跳过子域名发现扫描: Target 类型为 %s (ID=%d, Name=%s),子域名发现仅适用于域名类型", - target.type, target_id, target_name - ) - return _empty_result(scan_id, target_name, scan_workspace_dir) - - # 导入任务函数 - from apps.scan.tasks.subdomain_discovery import ( - run_subdomain_discovery_task, - merge_and_validate_task, - save_domains_task + with open(file_path, 'rb') as f: + line_count = sum(1 for _ in f) + return line_count * multiplier if line_count > 0 else default + except OSError: + return default + + +def _run_single_tool( + tool_name: str, + tool_config: dict, + command_params: dict, + result_dir: Path, + scan_id: Optional[int] = None, + scan_type: str = 'subdomain_discovery' +) -> str: + """运行单个扫描工具,返回输出文件路径,失败返回空字符串""" + from apps.scan.tasks.subdomain_discovery import run_subdomain_discovery_task + + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + short_uuid = uuid.uuid4().hex[:4] + output_file = str(result_dir / f"{tool_name}_{timestamp}_{short_uuid}.txt") + command_params['output_file'] = output_file + + try: + command = build_scan_command( + tool_name=tool_name, + scan_type=scan_type, + command_params=command_params, + tool_config=tool_config ) - - # Step 0: 准备工作 - from apps.scan.utils import setup_scan_directory - result_dir = setup_scan_directory(scan_workspace_dir, 'subdomain_discovery') - - # 验证并规范化目标域名 - try: - domain_name = _validate_and_normalize_target(target_name) - except ValueError as e: - logger.warning("目标域名无效,跳过子域名发现扫描: %s", e) - return _empty_result(scan_id, target_name, scan_workspace_dir) - - logger.info( - "="*60 + "\n" + - "开始子域名发现扫描\n" + - f" Scan ID: {scan_id}\n" + - f" Domain: {domain_name}\n" + - f" Workspace: {scan_workspace_dir}\n" + - "="*60 + except (ValueError, KeyError) as e: + logger.error("构建 %s 命令失败: %s", tool_name, e) + return "" + + timeout = tool_config.get('timeout', 3600) + if timeout == 'auto': + timeout = 3600 + + logger.info("执行 %s: %s", tool_name, command) + if scan_id: + user_log(scan_id, scan_type, f"Running {tool_name}: {command}") + + try: + result = run_subdomain_discovery_task( + tool=tool_name, + command=command, + timeout=timeout, + output_file=output_file ) - user_log(scan_id, "subdomain_discovery", f"Starting subdomain discovery for {domain_name}") - - # 解析配置 - passive_tools = scan_config.get('passive_tools', {}) - bruteforce_config = 
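
Note that _calculate_auto_timeout below counts raw lines in binary mode (unlike _count_lines, it does not skip blanks) and multiplies. A worked example under that reading:

from pathlib import Path

Path("subs.txt").write_text("a.example.com\n" * 50_000)
assert _calculate_auto_timeout("subs.txt", multiplier=3, default=3600) == 150_000
assert _calculate_auto_timeout("missing.txt") == 3600   # OSError falls back to default
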
scan_config.get('bruteforce', {}) - permutation_config = scan_config.get('permutation', {}) - resolve_config = scan_config.get('resolve', {}) - - # 过滤出启用的被动工具 - enabled_passive_tools = { - k: v for k, v in passive_tools.items() - if v.get('enabled', True) - } - - executed_tasks = [] - all_result_files = [] - failed_tools = [] - successful_tool_names = [] - - # ==================== 生成 Provider 配置文件 ==================== - # 为 subfinder 生成第三方数据源配置 - provider_config_path = None + return result if result else "" + except (subprocess.TimeoutExpired, OSError) as e: + logger.warning("%s 执行失败: %s", tool_name, e) + return "" + + +def _run_scans_parallel( + enabled_tools: dict, + domain_name: str, + result_dir: Path, + scan_id: int, + provider_config_path: Optional[str] = None +) -> tuple[list, list, list]: + """并行运行所有启用的子域名扫描工具""" + from apps.scan.tasks.subdomain_discovery import run_subdomain_discovery_task + + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + futures = {} + failed_tools = [] + + for tool_name, tool_config in enabled_tools.items(): + short_uuid = uuid.uuid4().hex[:4] + output_file = str(result_dir / f"{tool_name}_{timestamp}_{short_uuid}.txt") + + command_params = {'domain': domain_name, 'output_file': output_file} + if tool_name == 'subfinder' and provider_config_path: + command_params['provider_config'] = provider_config_path + try: - from apps.scan.services.subfinder_provider_config_service import SubfinderProviderConfigService - provider_config_service = SubfinderProviderConfigService() - provider_config_path = provider_config_service.generate(str(result_dir)) - if provider_config_path: - logger.info(f"Provider 配置文件已生成: {provider_config_path}") - user_log(scan_id, "subdomain_discovery", "Provider config generated for subfinder") - except Exception as e: - logger.warning(f"生成 Provider 配置文件失败: {e}") - - # ==================== Stage 1: 被动收集(并行)==================== - if enabled_passive_tools: - logger.info("=" * 40) - logger.info("Stage 1: 被动收集(并行)") - logger.info("=" * 40) - logger.info("启用工具: %s", ', '.join(enabled_passive_tools.keys())) - user_log(scan_id, "subdomain_discovery", f"Stage 1: passive collection ({', '.join(enabled_passive_tools.keys())})") - result_files, stage1_failed, stage1_success = _run_scans_parallel( - enabled_tools=enabled_passive_tools, - domain_name=domain_name, - result_dir=result_dir, - scan_id=scan_id, - provider_config_path=provider_config_path + command = build_scan_command( + tool_name=tool_name, + scan_type='subdomain_discovery', + command_params=command_params, + tool_config=tool_config ) - all_result_files.extend(result_files) - failed_tools.extend(stage1_failed) - successful_tool_names.extend(stage1_success) - executed_tasks.extend([f'passive ({tool})' for tool in stage1_success]) - - # 合并 Stage 1 结果 - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - current_result = str(result_dir / f"subs_passive_{timestamp}.txt") - if all_result_files: - current_result = _merge_files(all_result_files, current_result) - executed_tasks.append('merge_passive') - else: - # 创建空文件 - Path(current_result).touch() - - # ==================== Stage 2: 字典爆破(可选)==================== - bruteforce_enabled = bruteforce_config.get('enabled', False) - if bruteforce_enabled: - logger.info("=" * 40) - logger.info("Stage 2: 字典爆破") - logger.info("=" * 40) - user_log(scan_id, "subdomain_discovery", "Stage 2: bruteforce") - - bruteforce_tool_config = bruteforce_config.get('subdomain_bruteforce', {}) - wordlist_name = bruteforce_tool_config.get('wordlist_name', 
'dns_wordlist.txt') - - try: - # 确保本地存在字典文件(含 hash 校验) - local_wordlist_path = ensure_wordlist_local(wordlist_name) - - # 获取字典记录用于计算 timeout - wordlist_service = WordlistService() - wordlist = wordlist_service.get_wordlist_by_name(wordlist_name) - - timeout_value = bruteforce_tool_config.get('timeout', 3600) - if timeout_value == 'auto' and wordlist: - line_count = getattr(wordlist, 'line_count', None) - if line_count is None: - try: - with open(local_wordlist_path, 'rb') as f: - line_count = sum(1 for _ in f) - except OSError: - line_count = 0 + except (ValueError, KeyError) as e: + logger.error("构建 %s 命令失败: %s", tool_name, e) + failed_tools.append({'tool': tool_name, 'reason': f'命令构建失败: {e}'}) + continue - try: - line_count_int = int(line_count) - except (TypeError, ValueError): - line_count_int = 0 + timeout = tool_config.get('timeout', 600) + if timeout == 'auto': + timeout = 600 + logger.info("✓ 工具 %s 使用默认 timeout: %d秒", tool_name, timeout) - timeout_value = line_count_int * 3 if line_count_int > 0 else 3600 - bruteforce_tool_config = { - **bruteforce_tool_config, - 'timeout': timeout_value, - } + logger.debug("提交任务 - 工具: %s, 超时: %ds, 输出: %s", tool_name, timeout, output_file) + user_log(scan_id, "subdomain_discovery", f"Running {tool_name}: {command}") - brute_result = _run_single_tool( - tool_name='subdomain_bruteforce', - tool_config=bruteforce_tool_config, - command_params={ - 'domain': domain_name, - 'wordlist': local_wordlist_path, - }, - result_dir=result_dir, - scan_id=scan_id - ) - - if brute_result: - # 合并 Stage 1 + Stage 2 - current_result = _merge_files( - [current_result, brute_result], - str(result_dir / f"subs_merged_{timestamp}.txt") - ) - successful_tool_names.append('subdomain_bruteforce') - executed_tasks.append('bruteforce') - logger.info("✓ subdomain_bruteforce 执行完成") - user_log(scan_id, "subdomain_discovery", "subdomain_bruteforce completed") - else: - failed_tools.append({'tool': 'subdomain_bruteforce', 'reason': '执行失败'}) - logger.warning("⚠️ subdomain_bruteforce 执行失败") - user_log(scan_id, "subdomain_discovery", "subdomain_bruteforce failed: execution failed", "error") - except Exception as exc: - failed_tools.append({'tool': 'subdomain_bruteforce', 'reason': str(exc)}) - logger.warning("字典准备失败,跳过字典爆破: %s", exc) - user_log(scan_id, "subdomain_discovery", f"subdomain_bruteforce failed: {str(exc)}", "error") - - # ==================== Stage 3: 变异生成 + 验证(可选)==================== - permutation_enabled = permutation_config.get('enabled', False) - if permutation_enabled: - logger.info("=" * 40) - logger.info("Stage 3: 变异生成 + 存活验证(流式管道)") - logger.info("=" * 40) - user_log(scan_id, "subdomain_discovery", "Stage 3: permutation + resolve") - - permutation_tool_config = permutation_config.get('subdomain_permutation_resolve', {}) - - # === Step 3.1: 泛解析采样检测 === - # 生成原文件 100 倍的变异样本,检查解析结果是否超过 50 倍 - before_count = _count_lines(current_result) - - # 配置参数 - SAMPLE_MULTIPLIER = 100 # 采样数量 = 原文件 × 100 - EXPANSION_THRESHOLD = 50 # 膨胀阈值 = 原文件 × 50 - SAMPLE_TIMEOUT = 7200 # 采样超时 2 小时 - - sample_size = before_count * SAMPLE_MULTIPLIER - max_allowed = before_count * EXPANSION_THRESHOLD - - sample_output = str(result_dir / f"subs_permuted_sample_{timestamp}.txt") - sample_cmd = ( - f"cat {current_result} | dnsgen - | head -n {sample_size} | " - f"puredns resolve -r /app/backend/resources/resolvers.txt " - f"--write {sample_output} --wildcard-tests 50 --wildcard-batch 1000000 --quiet" - ) - - logger.info( - f"泛解析采样检测: 原文件 {before_count} 个, " - f"采样 {sample_size} 个, 阈值 {max_allowed} 个" 
- ) - - try: - subprocess.run( - sample_cmd, - shell=True, - timeout=SAMPLE_TIMEOUT, - check=False, - capture_output=True - ) - sample_result_count = _count_lines(sample_output) if Path(sample_output).exists() else 0 - - logger.info( - f"采样结果: {sample_result_count} 个域名存活 " - f"(原文件: {before_count}, 阈值: {max_allowed})" - ) - - if sample_result_count > max_allowed: - # 采样结果超过阈值,说明存在泛解析,跳过完整变异 - ratio = sample_result_count / before_count if before_count > 0 else sample_result_count - logger.warning( - f"跳过变异: 采样检测到泛解析 " - f"({sample_result_count} > {max_allowed}, 膨胀率 {ratio:.1f}x)" - ) - failed_tools.append({ - 'tool': 'subdomain_permutation_resolve', - 'reason': f"采样检测到泛解析 (膨胀率 {ratio:.1f}x)" - }) - user_log(scan_id, "subdomain_discovery", f"subdomain_permutation_resolve skipped: wildcard detected (ratio {ratio:.1f}x)", "warning") - else: - # === Step 3.2: 采样通过,执行完整变异 === - logger.info("采样检测通过,执行完整变异...") - - permuted_result = _run_single_tool( - tool_name='subdomain_permutation_resolve', - tool_config=permutation_tool_config, - command_params={ - 'input_file': current_result, - }, - result_dir=result_dir, - scan_id=scan_id - ) - - if permuted_result: - # 合并原结果 + 变异验证结果 - current_result = _merge_files( - [current_result, permuted_result], - str(result_dir / f"subs_with_permuted_{timestamp}.txt") - ) - successful_tool_names.append('subdomain_permutation_resolve') - executed_tasks.append('permutation') - logger.info("✓ subdomain_permutation_resolve 执行完成") - user_log(scan_id, "subdomain_discovery", "subdomain_permutation_resolve completed") - else: - failed_tools.append({'tool': 'subdomain_permutation_resolve', 'reason': '执行失败'}) - logger.warning("⚠️ subdomain_permutation_resolve 执行失败") - user_log(scan_id, "subdomain_discovery", "subdomain_permutation_resolve failed: execution failed", "error") - - except subprocess.TimeoutExpired: - failed_tools.append({'tool': 'subdomain_permutation_resolve', 'reason': '采样检测超时'}) - logger.warning(f"采样检测超时 ({SAMPLE_TIMEOUT}秒),跳过变异") - user_log(scan_id, "subdomain_discovery", "subdomain_permutation_resolve failed: sample detection timeout", "error") - except Exception as e: - failed_tools.append({'tool': 'subdomain_permutation_resolve', 'reason': f'采样检测失败: {e}'}) - logger.warning(f"采样检测失败: {e},跳过变异") - user_log(scan_id, "subdomain_discovery", f"subdomain_permutation_resolve failed: {str(e)}", "error") - - # ==================== Stage 4: DNS 存活验证(可选)==================== - # 无论是否启用 Stage 3,只要 resolve.enabled 为 true 就会执行,对当前所有候选子域做统一 DNS 验证 - resolve_enabled = resolve_config.get('enabled', False) - if resolve_enabled: - logger.info("=" * 40) - logger.info("Stage 4: DNS 存活验证") - logger.info("=" * 40) - user_log(scan_id, "subdomain_discovery", "Stage 4: DNS resolve") - - resolve_tool_config = resolve_config.get('subdomain_resolve', {}) + future = run_subdomain_discovery_task.submit( + tool=tool_name, + command=command, + timeout=timeout, + output_file=output_file + ) + futures[tool_name] = future - # 根据当前候选子域数量动态计算 timeout(支持 timeout: auto) - timeout_value = resolve_tool_config.get('timeout', 3600) - if timeout_value == 'auto': - line_count = 0 - try: - with open(current_result, 'rb') as f: - line_count = sum(1 for _ in f) - except OSError: - line_count = 0 + if not futures: + logger.warning("所有扫描工具均无法启动 - 目标: %s", domain_name) + return [], [{'tool': 'all', 'reason': '所有工具均无法启动'}], [] - try: - line_count_int = int(line_count) - except (TypeError, ValueError): - line_count_int = 0 - - timeout_value = line_count_int * 3 if line_count_int > 0 else 3600 - 
resolve_tool_config = { - **resolve_tool_config, - 'timeout': timeout_value, - } - - alive_result = _run_single_tool( - tool_name='subdomain_resolve', - tool_config=resolve_tool_config, - command_params={ - 'input_file': current_result, - }, - result_dir=result_dir, - scan_id=scan_id - ) - - if alive_result: - current_result = alive_result - successful_tool_names.append('subdomain_resolve') - executed_tasks.append('resolve') - logger.info("✓ subdomain_resolve 执行完成") - user_log(scan_id, "subdomain_discovery", "subdomain_resolve completed") + result_files = [] + for tool_name, future in futures.items(): + try: + result = future.result() + if result: + result_files.append(result) + logger.info("✓ 扫描工具 %s 执行成功: %s", tool_name, result) + user_log(scan_id, "subdomain_discovery", f"{tool_name} completed") else: - failed_tools.append({'tool': 'subdomain_resolve', 'reason': '执行失败'}) - logger.warning("⚠️ subdomain_resolve 执行失败") - user_log(scan_id, "subdomain_discovery", "subdomain_resolve failed: execution failed", "error") - - # ==================== Final: 保存到数据库 ==================== - logger.info("=" * 40) - logger.info("Final: 保存到数据库") - logger.info("=" * 40) - - # 最终验证和保存 - final_file = merge_and_validate_task( - result_files=[current_result], - result_dir=str(result_dir) + failed_tools.append({'tool': tool_name, 'reason': '未生成结果文件'}) + logger.warning("⚠️ 扫描工具 %s 未生成结果文件", tool_name) + user_log(scan_id, "subdomain_discovery", f"{tool_name} failed: no output", "error") + except (subprocess.TimeoutExpired, OSError) as e: + failed_tools.append({'tool': tool_name, 'reason': str(e)}) + logger.warning("⚠️ 扫描工具 %s 执行失败: %s", tool_name, e) + user_log(scan_id, "subdomain_discovery", f"{tool_name} failed: {e}", "error") + + successful_tools = [name for name in futures if name not in [f['tool'] for f in failed_tools]] + + logger.info( + "✓ 扫描工具并行执行完成 - 成功: %d/%d", + len(result_files), len(futures) + ) + + return result_files, failed_tools, successful_tools + + +def _generate_provider_config(result_dir: Path, scan_id: int) -> Optional[str]: + """为 subfinder 生成第三方数据源配置""" + try: + from apps.scan.services.subfinder_provider_config_service import ( + SubfinderProviderConfigService, ) - - save_result = save_domains_task( - domains_file=final_file, - scan_id=scan_id, - target_id=target_id + config_path = SubfinderProviderConfigService().generate(str(result_dir)) + if config_path: + logger.info("Provider 配置文件已生成: %s", config_path) + user_log(scan_id, "subdomain_discovery", "Provider config generated for subfinder") + return config_path + except (ImportError, OSError) as e: + logger.warning("生成 Provider 配置文件失败: %s", e) + return None + + +def _run_stage1_passive(ctx: ScanContext, enabled_tools: dict, provider_config: Optional[str]): + """Stage 1: 被动收集(并行)""" + if not enabled_tools: + return + + logger.info("=" * 40) + logger.info("Stage 1: 被动收集(并行)") + logger.info("=" * 40) + logger.info("启用工具: %s", ', '.join(enabled_tools.keys())) + user_log( + ctx.scan_id, "subdomain_discovery", + f"Stage 1: passive collection ({', '.join(enabled_tools.keys())})" + ) + + result_files, failed, successful = _run_scans_parallel( + enabled_tools=enabled_tools, + domain_name=ctx.domain_name, + result_dir=ctx.result_dir, + scan_id=ctx.scan_id, + provider_config_path=provider_config + ) + + ctx.failed_tools.extend(failed) + ctx.successful_tools.extend(successful) + ctx.executed_tasks.extend([f'passive ({tool})' for tool in successful]) + + # 合并结果 + ctx.current_result = str(ctx.result_dir / f"subs_passive_{ctx.timestamp}.txt") + if 
result_files: + ctx.current_result = _merge_files(result_files, ctx.current_result) + ctx.executed_tasks.append('merge_passive') + else: + Path(ctx.current_result).touch() + + +def _run_stage2_bruteforce(ctx: ScanContext, bruteforce_config: dict): + """Stage 2: 字典爆破(可选)""" + if not bruteforce_config.get('enabled', False): + return + + logger.info("=" * 40) + logger.info("Stage 2: 字典爆破") + logger.info("=" * 40) + user_log(ctx.scan_id, "subdomain_discovery", "Stage 2: bruteforce") + + tool_config = bruteforce_config.get('subdomain_bruteforce', {}) + wordlist_name = tool_config.get('wordlist_name', 'dns_wordlist.txt') + + try: + local_wordlist_path = ensure_wordlist_local(wordlist_name) + + # 计算 timeout + timeout_value = tool_config.get('timeout', 3600) + if timeout_value == 'auto': + wordlist = WordlistService().get_wordlist_by_name(wordlist_name) + line_count = getattr(wordlist, 'line_count', None) if wordlist else None + if line_count is None: + line_count = _calculate_auto_timeout(local_wordlist_path, 1, 0) + timeout_value = int(line_count) * 3 if line_count else 3600 + tool_config = {**tool_config, 'timeout': timeout_value} + + result = _run_single_tool( + tool_name='subdomain_bruteforce', + tool_config=tool_config, + command_params={'domain': ctx.domain_name, 'wordlist': local_wordlist_path}, + result_dir=ctx.result_dir, + scan_id=ctx.scan_id ) - processed_domains = save_result.get('processed_records', 0) - executed_tasks.append('save_domains') - - # 记录 Flow 完成 - logger.info("="*60) - logger.info("✓ 子域名发现扫描完成") - logger.info("="*60) - user_log(scan_id, "subdomain_discovery", f"subdomain_discovery completed: found {processed_domains} subdomains") - - return { - 'success': True, - 'scan_id': scan_id, - 'target': domain_name, - 'scan_workspace_dir': scan_workspace_dir, - 'total': processed_domains, - 'executed_tasks': executed_tasks, - 'tool_stats': { - 'total': len(enabled_passive_tools) + (1 if bruteforce_enabled else 0) + - (1 if permutation_enabled else 0) + (1 if resolve_enabled else 0), - 'successful': len(successful_tool_names), - 'failed': len(failed_tools), - 'successful_tools': successful_tool_names, - 'failed_tools': failed_tools - } - } - - except ValueError as e: - logger.error("配置错误: %s", e) - raise - except RuntimeError as e: - logger.error("运行时错误: %s", e) - raise - except Exception as e: - logger.exception("子域名发现扫描失败: %s", e) - raise + + if result: + ctx.current_result = _merge_files( + [ctx.current_result, result], + str(ctx.result_dir / f"subs_merged_{ctx.timestamp}.txt") + ) + ctx.successful_tools.append('subdomain_bruteforce') + ctx.executed_tasks.append('bruteforce') + logger.info("✓ subdomain_bruteforce 执行完成") + user_log(ctx.scan_id, "subdomain_discovery", "subdomain_bruteforce completed") + else: + ctx.failed_tools.append({'tool': 'subdomain_bruteforce', 'reason': '执行失败'}) + logger.warning("⚠️ subdomain_bruteforce 执行失败") + user_log(ctx.scan_id, "subdomain_discovery", "subdomain_bruteforce failed", "error") + + except (ValueError, OSError) as exc: + ctx.failed_tools.append({'tool': 'subdomain_bruteforce', 'reason': str(exc)}) + logger.warning("字典准备失败,跳过字典爆破: %s", exc) + user_log(ctx.scan_id, "subdomain_discovery", f"subdomain_bruteforce failed: {exc}", "error") + + +def _run_stage3_permutation(ctx: ScanContext, permutation_config: dict): + """Stage 3: 变异生成 + 验证(可选)""" + if not permutation_config.get('enabled', False): + return + + logger.info("=" * 40) + logger.info("Stage 3: 变异生成 + 存活验证(流式管道)") + logger.info("=" * 40) + user_log(ctx.scan_id, "subdomain_discovery", 
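
Note the trick in Stage 2 above: calling _calculate_auto_timeout with multiplier=1, default=0 degenerates it into a plain line counter, so the DB-recorded line_count and the on-disk count share one code path. A condensed restatement of the rule:

def bruteforce_timeout(db_line_count, wordlist_path: str) -> int:
    line_count = db_line_count
    if line_count is None:
        # multiplier=1, default=0 turns the timeout helper into a line counter
        line_count = _calculate_auto_timeout(wordlist_path, multiplier=1, default=0)
    return int(line_count) * 3 if line_count else 3600   # 3 s per word, 1 h fallback

# 100,000-line wordlist -> 300,000 s; unknown or empty wordlist -> 3,600 s
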
"Stage 3: permutation + resolve") + + tool_config = permutation_config.get('subdomain_permutation_resolve', {}) + before_count = _count_lines(ctx.current_result) + + sample_size = before_count * _SAMPLE_MULTIPLIER + max_allowed = before_count * _EXPANSION_THRESHOLD + sample_output = str(ctx.result_dir / f"subs_permuted_sample_{ctx.timestamp}.txt") + + sample_cmd = ( + f"cat {ctx.current_result} | dnsgen - | head -n {sample_size} | " + f"puredns resolve -r /app/backend/resources/resolvers.txt " + f"--write {sample_output} --wildcard-tests 50 --wildcard-batch 1000000 --quiet" + ) + + logger.info( + "泛解析采样检测: 原文件 %d 个, 采样 %d 个, 阈值 %d 个", + before_count, sample_size, max_allowed + ) + + try: + subprocess.run( + sample_cmd, + shell=True, # noqa: S602 + timeout=_SAMPLE_TIMEOUT, + check=False, + capture_output=True + ) + sample_count = _count_lines(sample_output) if Path(sample_output).exists() else 0 + + logger.info( + "采样结果: %d 个域名存活 (原文件: %d, 阈值: %d)", + sample_count, before_count, max_allowed + ) + + if sample_count > max_allowed: + ratio = sample_count / before_count if before_count > 0 else sample_count + logger.warning( + "跳过变异: 采样检测到泛解析 (%d > %d, 膨胀率 %.1fx)", + sample_count, max_allowed, ratio + ) + ctx.failed_tools.append({ + 'tool': 'subdomain_permutation_resolve', + 'reason': f"采样检测到泛解析 (膨胀率 {ratio:.1f}x)" + }) + user_log( + ctx.scan_id, "subdomain_discovery", + f"subdomain_permutation_resolve skipped: wildcard (ratio {ratio:.1f}x)", + "warning" + ) + return + + # 采样通过,执行完整变异 + logger.info("采样检测通过,执行完整变异...") + result = _run_single_tool( + tool_name='subdomain_permutation_resolve', + tool_config=tool_config, + command_params={'input_file': ctx.current_result}, + result_dir=ctx.result_dir, + scan_id=ctx.scan_id + ) + + if result: + ctx.current_result = _merge_files( + [ctx.current_result, result], + str(ctx.result_dir / f"subs_with_permuted_{ctx.timestamp}.txt") + ) + ctx.successful_tools.append('subdomain_permutation_resolve') + ctx.executed_tasks.append('permutation') + logger.info("✓ subdomain_permutation_resolve 执行完成") + user_log(ctx.scan_id, "subdomain_discovery", "subdomain_permutation_resolve completed") + else: + ctx.failed_tools.append({'tool': 'subdomain_permutation_resolve', 'reason': '执行失败'}) + logger.warning("⚠️ subdomain_permutation_resolve 执行失败") + user_log( + ctx.scan_id, "subdomain_discovery", + "subdomain_permutation_resolve failed", "error" + ) + + except subprocess.TimeoutExpired: + ctx.failed_tools.append({'tool': 'subdomain_permutation_resolve', 'reason': '采样检测超时'}) + logger.warning("采样检测超时 (%d秒),跳过变异", _SAMPLE_TIMEOUT) + user_log( + ctx.scan_id, "subdomain_discovery", + "subdomain_permutation_resolve failed: timeout", "error" + ) + except OSError as e: + ctx.failed_tools.append({'tool': 'subdomain_permutation_resolve', 'reason': f'采样检测失败: {e}'}) + logger.warning("采样检测失败: %s,跳过变异", e) + user_log(ctx.scan_id, "subdomain_discovery", f"subdomain_permutation_resolve failed: {e}", "error") + + +def _run_stage4_resolve(ctx: ScanContext, resolve_config: dict): + """Stage 4: DNS 存活验证(可选)""" + if not resolve_config.get('enabled', False): + return + + logger.info("=" * 40) + logger.info("Stage 4: DNS 存活验证") + logger.info("=" * 40) + user_log(ctx.scan_id, "subdomain_discovery", "Stage 4: DNS resolve") + + tool_config = resolve_config.get('subdomain_resolve', {}) + + # 动态计算 timeout + timeout_value = tool_config.get('timeout', 3600) + if timeout_value == 'auto': + timeout_value = _calculate_auto_timeout(ctx.current_result, 3, 3600) + tool_config = {**tool_config, 'timeout': 
timeout_value} + + result = _run_single_tool( + tool_name='subdomain_resolve', + tool_config=tool_config, + command_params={'input_file': ctx.current_result}, + result_dir=ctx.result_dir, + scan_id=ctx.scan_id + ) + + if result: + ctx.current_result = result + ctx.successful_tools.append('subdomain_resolve') + ctx.executed_tasks.append('resolve') + logger.info("✓ subdomain_resolve 执行完成") + user_log(ctx.scan_id, "subdomain_discovery", "subdomain_resolve completed") + else: + ctx.failed_tools.append({'tool': 'subdomain_resolve', 'reason': '执行失败'}) + logger.warning("⚠️ subdomain_resolve 执行失败") + user_log(ctx.scan_id, "subdomain_discovery", "subdomain_resolve failed", "error") + + +def _save_to_database(ctx: ScanContext) -> int: + """Final: 保存到数据库""" + from apps.scan.tasks.subdomain_discovery import merge_and_validate_task, save_domains_task + + logger.info("=" * 40) + logger.info("Final: 保存到数据库") + logger.info("=" * 40) + + final_file = merge_and_validate_task( + result_files=[ctx.current_result], + result_dir=str(ctx.result_dir) + ) + + save_result = save_domains_task( + domains_file=final_file, + scan_id=ctx.scan_id, + target_id=ctx.target_id + ) + + ctx.executed_tasks.append('save_domains') + return save_result.get('processed_records', 0) def _empty_result(scan_id: int, target: str, scan_workspace_dir: str) -> dict: @@ -782,3 +529,148 @@ def _empty_result(scan_id: int, target: str, scan_workspace_dir: str) -> dict: 'failed_tools': [] } } + + +@flow( + name="subdomain_discovery", + log_prints=True, + on_running=[on_scan_flow_running], + on_completion=[on_scan_flow_completed], + on_failure=[on_scan_flow_failed], +) +def subdomain_discovery_flow( + scan_id: int, + target_name: str, + target_id: int, + scan_workspace_dir: str, + enabled_tools: dict +) -> dict: + """子域名发现扫描流程 + + 工作流程(4 阶段): + Stage 1: 被动收集(并行) - 必选 + Stage 2: 字典爆破(可选) - 子域名字典爆破 + Stage 3: 变异生成 + 验证(可选) - dnsgen + 通用存活验证 + Stage 4: DNS 存活验证(可选) - 通用存活验证 + Final: 保存到数据库 + + 注意: + - 子域名发现只对 DOMAIN 类型目标有意义 + - IP 和 CIDR 类型目标会自动跳过 + """ + try: + wait_for_system_load(context="subdomain_discovery_flow") + + # 参数验证 + if scan_id is None: + raise ValueError("scan_id 不能为空") + if target_id is None: + raise ValueError("target_id 不能为空") + if not scan_workspace_dir: + raise ValueError("scan_workspace_dir 不能为空") + if enabled_tools is None: + raise ValueError("enabled_tools 不能为空") + + if not target_name: + logger.warning("未提供目标域名,跳过子域名发现扫描") + return _empty_result(scan_id, '', scan_workspace_dir) + + # 检查 Target 类型 + from apps.targets.models import Target + from apps.targets.services import TargetService + + target = TargetService().get_target(target_id) + if target and target.type != Target.TargetType.DOMAIN: + logger.info( + "跳过子域名发现扫描: Target 类型为 %s (ID=%d),仅适用于域名类型", + target.type, target_id + ) + return _empty_result(scan_id, target_name, scan_workspace_dir) + + # 验证并规范化目标域名 + try: + domain_name = _validate_and_normalize_target(target_name) + except ValueError as e: + logger.warning("目标域名无效,跳过子域名发现扫描: %s", e) + return _empty_result(scan_id, target_name, scan_workspace_dir) + + # 准备工作目录 + from apps.scan.utils import setup_scan_directory + result_dir = setup_scan_directory(scan_workspace_dir, 'subdomain_discovery') + + logger.info( + "开始子域名发现扫描 - Scan ID: %s, Domain: %s, Workspace: %s", + scan_id, domain_name, scan_workspace_dir + ) + user_log(scan_id, "subdomain_discovery", f"Starting subdomain discovery for {domain_name}") + + # 解析配置 + scan_config = enabled_tools + passive_tools = scan_config.get('passive_tools', {}) + bruteforce_config = 
scan_config.get('bruteforce', {}) + permutation_config = scan_config.get('permutation', {}) + resolve_config = scan_config.get('resolve', {}) + + enabled_passive_tools = { + k: v for k, v in passive_tools.items() + if v.get('enabled', True) + } + + # 创建扫描上下文 + ctx = ScanContext( + scan_id=scan_id, + target_id=target_id, + domain_name=domain_name, + result_dir=result_dir, + timestamp=datetime.now().strftime('%Y%m%d_%H%M%S') + ) + + # 生成 Provider 配置 + provider_config = _generate_provider_config(result_dir, scan_id) + + # 执行各阶段 + _run_stage1_passive(ctx, enabled_passive_tools, provider_config) + _run_stage2_bruteforce(ctx, bruteforce_config) + _run_stage3_permutation(ctx, permutation_config) + _run_stage4_resolve(ctx, resolve_config) + + # 保存到数据库 + processed_domains = _save_to_database(ctx) + + logger.info("✓ 子域名发现扫描完成") + user_log( + scan_id, "subdomain_discovery", + f"subdomain_discovery completed: found {processed_domains} subdomains" + ) + + # 计算工具总数 + total_tools = len(enabled_passive_tools) + if bruteforce_config.get('enabled', False): + total_tools += 1 + if permutation_config.get('enabled', False): + total_tools += 1 + if resolve_config.get('enabled', False): + total_tools += 1 + + return { + 'success': True, + 'scan_id': scan_id, + 'target': domain_name, + 'scan_workspace_dir': scan_workspace_dir, + 'total': processed_domains, + 'executed_tasks': ctx.executed_tasks, + 'tool_stats': { + 'total': total_tools, + 'successful': len(ctx.successful_tools), + 'failed': len(ctx.failed_tools), + 'successful_tools': ctx.successful_tools, + 'failed_tools': ctx.failed_tools + } + } + + except ValueError as e: + logger.error("配置错误: %s", e) + raise + except RuntimeError as e: + logger.error("运行时错误: %s", e) + raise diff --git a/backend/apps/scan/flows/url_fetch/main_flow.py b/backend/apps/scan/flows/url_fetch/main_flow.py index 904ed723..2f0d1a43 100644 --- a/backend/apps/scan/flows/url_fetch/main_flow.py +++ b/backend/apps/scan/flows/url_fetch/main_flow.py @@ -10,22 +10,18 @@ URL Fetch 主 Flow - 统一进行 httpx 验证(如果启用) """ -# Django 环境初始化 -from apps.common.prefect_django_setup import setup_django_for_prefect - import logging -import os -from pathlib import Path from datetime import datetime +from pathlib import Path from prefect import flow from apps.scan.handlers.scan_flow_handlers import ( - on_scan_flow_running, on_scan_flow_completed, on_scan_flow_failed, + on_scan_flow_running, ) -from apps.scan.utils import user_log +from apps.scan.utils import user_log, wait_for_system_load from .domain_name_url_fetch_flow import domain_name_url_fetch_flow from .sites_url_fetch_flow import sites_url_fetch_flow @@ -43,13 +39,10 @@ SITES_FILE_TOOLS = {'katana'} POST_PROCESS_TOOLS = {'uro', 'httpx'} - - - def _classify_tools(enabled_tools: dict) -> tuple[dict, dict, dict, dict]: """ 将启用的工具按输入类型分类 - + Returns: tuple: (domain_name_tools, sites_file_tools, uro_config, httpx_config) """ @@ -76,23 +69,23 @@ def _classify_tools(enabled_tools: dict) -> tuple[dict, dict, dict, dict]: def _merge_and_deduplicate_urls(result_files: list, url_fetch_dir: Path) -> tuple[str, int]: """合并并去重 URL""" from apps.scan.tasks.url_fetch import merge_and_deduplicate_urls_task - + merged_file = merge_and_deduplicate_urls_task( result_files=result_files, result_dir=str(url_fetch_dir) ) - + # 统计唯一 URL 数量 unique_url_count = 0 if Path(merged_file).exists(): - with open(merged_file, 'r') as f: + with open(merged_file, 'r', encoding='utf-8') as f: unique_url_count = sum(1 for line in f if line.strip()) - + logger.info( "✓ URL 合并去重完成 - 合并文件: %s, 唯一 
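
For orientation, a sketch of what _classify_tools is expected to return; its body lies outside this hunk, DOMAIN_NAME_TOOLS presumably contains waymore, and disabled tools appear to keep their config so the caller can test 'enabled' itself (the flow later checks httpx_config.get('enabled', False)):

enabled_tools = {
    'waymore': {'enabled': True},    # domain_name input (assumed member of DOMAIN_NAME_TOOLS)
    'katana': {'enabled': True},     # sites_file input
    'uro': {'enabled': True},        # post-processing
    'httpx': {'enabled': False},     # post-processing, disabled
}
domain_tools, site_tools, uro_cfg, httpx_cfg = _classify_tools(enabled_tools)
# expected: domain_tools keeps waymore, site_tools keeps katana,
# uro_cfg/httpx_cfg are the raw configs, enabled flags checked by the caller
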
URL 数: %d", merged_file, unique_url_count ) - + return merged_file, unique_url_count @@ -103,12 +96,12 @@ def _clean_urls_with_uro( ) -> tuple[str, int, int]: """使用 uro 清理合并后的 URL 列表""" from apps.scan.tasks.url_fetch import clean_urls_task - + raw_timeout = uro_config.get('timeout', 60) whitelist = uro_config.get('whitelist') blacklist = uro_config.get('blacklist') filters = uro_config.get('filters') - + # 计算超时时间 if isinstance(raw_timeout, str) and raw_timeout == 'auto': timeout = calculate_timeout_by_line_count( @@ -124,7 +117,7 @@ def _clean_urls_with_uro( except (TypeError, ValueError): logger.warning("uro timeout 配置无效(%s),使用默认 60 秒", raw_timeout) timeout = 60 - + result = clean_urls_task( input_file=merged_file, output_dir=str(url_fetch_dir), @@ -133,12 +126,12 @@ def _clean_urls_with_uro( blacklist=blacklist, filters=filters ) - + if result['success']: return result['output_file'], result['output_count'], result['removed_count'] - else: - logger.warning("uro 清理失败: %s,使用原始合并文件", result.get('error', '未知错误')) - return merged_file, result['input_count'], 0 + + logger.warning("uro 清理失败: %s,使用原始合并文件", result.get('error', '未知错误')) + return merged_file, result['input_count'], 0 def _validate_and_stream_save_urls( @@ -151,25 +144,25 @@ def _validate_and_stream_save_urls( """使用 httpx 验证 URL 存活并流式保存到数据库""" from apps.scan.utils import build_scan_command from apps.scan.tasks.url_fetch import run_and_stream_save_urls_task - + logger.info("开始使用 httpx 验证 URL 存活状态...") - + # 统计待验证的 URL 数量 try: - with open(merged_file, 'r') as f: + with open(merged_file, 'r', encoding='utf-8') as f: url_count = sum(1 for _ in f) logger.info("待验证 URL 数量: %d", url_count) - except Exception as e: + except OSError as e: logger.error("读取 URL 文件失败: %s", e) return 0 - + if url_count == 0: logger.warning("没有需要验证的 URL") return 0 - + # 构建 httpx 命令 command_params = {'url_file': merged_file} - + try: command = build_scan_command( tool_name='httpx', @@ -177,21 +170,19 @@ def _validate_and_stream_save_urls( command_params=command_params, tool_config=httpx_config ) - except Exception as e: + except (ValueError, KeyError) as e: logger.error("构建 httpx 命令失败: %s", e) logger.warning("降级处理:将直接保存所有 URL(不验证存活)") return _save_urls_to_database(merged_file, scan_id, target_id) - + # 计算超时时间 raw_timeout = httpx_config.get('timeout', 'auto') - timeout = 3600 if isinstance(raw_timeout, str) and raw_timeout == 'auto': # 按 URL 行数计算超时时间:每行 3 秒,最小 60 秒 timeout = max(60, url_count * 3) logger.info( "自动计算 httpx 超时时间(按行数,每行 3 秒,最小 60 秒): url_count=%d, timeout=%d 秒", - url_count, - timeout, + url_count, timeout ) else: try: @@ -199,49 +190,44 @@ def _validate_and_stream_save_urls( except (TypeError, ValueError): timeout = 3600 logger.info("使用配置的 httpx 超时时间: %d 秒", timeout) - + # 生成日志文件路径 timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') log_file = url_fetch_dir / f"httpx_validation_{timestamp}.log" - + # 流式执行 - try: - result = run_and_stream_save_urls_task( - cmd=command, - tool_name='httpx', - scan_id=scan_id, - target_id=target_id, - cwd=str(url_fetch_dir), - shell=True, - timeout=timeout, - log_file=str(log_file) - ) - - saved = result.get('saved_urls', 0) - logger.info( - "✓ httpx 验证完成 - 存活 URL: %d (%.1f%%)", - saved, (saved / url_count * 100) if url_count > 0 else 0 - ) - return saved - - except Exception as e: - logger.error("httpx 流式验证失败: %s", e, exc_info=True) - raise + result = run_and_stream_save_urls_task( + cmd=command, + tool_name='httpx', + scan_id=scan_id, + target_id=target_id, + cwd=str(url_fetch_dir), + shell=True, + timeout=timeout, + 
log_file=str(log_file) + ) + + saved = result.get('saved_urls', 0) + logger.info( + "✓ httpx 验证完成 - 存活 URL: %d (%.1f%%)", + saved, (saved / url_count * 100) if url_count > 0 else 0 + ) + return saved def _save_urls_to_database(merged_file: str, scan_id: int, target_id: int) -> int: """保存 URL 到数据库(不验证存活)""" from apps.scan.tasks.url_fetch import save_urls_task - + result = save_urls_task( urls_file=merged_file, scan_id=scan_id, target_id=target_id ) - + saved_count = result.get('saved_urls', 0) logger.info("✓ URL 保存完成 - 保存数量: %d", saved_count) - + return saved_count @@ -261,7 +247,7 @@ def url_fetch_flow( ) -> dict: """ URL 获取主 Flow - + 执行流程: 1. 准备工作目录 2. 按输入类型分类工具(domain_name / sites_file / 后处理) @@ -271,36 +257,32 @@ def url_fetch_flow( 4. 合并所有子 Flow 的结果并去重 5. uro 去重(如果启用) 6. httpx 验证(如果启用) - + Args: scan_id: 扫描 ID target_name: 目标名称 target_id: 目标 ID scan_workspace_dir: 扫描工作目录 enabled_tools: 启用的工具配置 - + Returns: dict: 扫描结果 """ try: + # 负载检查:等待系统资源充足 + wait_for_system_load(context="url_fetch_flow") + logger.info( - "="*60 + "\n" + - "开始 URL 获取扫描\n" + - f" Scan ID: {scan_id}\n" + - f" Target: {target_name}\n" + - f" Workspace: {scan_workspace_dir}\n" + - "="*60 + "开始 URL 获取扫描 - Scan ID: %s, Target: %s, Workspace: %s", + scan_id, target_name, scan_workspace_dir ) - user_log(scan_id, "url_fetch", "Starting URL fetch") - + # Step 1: 准备工作目录 - logger.info("Step 1: 准备工作目录") from apps.scan.utils import setup_scan_directory url_fetch_dir = setup_scan_directory(scan_workspace_dir, 'url_fetch') - + # Step 2: 分类工具(按输入类型) - logger.info("Step 2: 分类工具") domain_name_tools, sites_file_tools, uro_config, httpx_config = _classify_tools(enabled_tools) logger.info( @@ -317,15 +299,14 @@ def url_fetch_flow( "URL Fetch 流程需要至少启用一个 URL 获取工具(如 waymore, katana)。" "httpx 和 uro 仅用于后处理,不能单独使用。" ) - - # Step 3: 并行执行子 Flow + + # Step 3: 执行子 Flow all_result_files = [] all_failed_tools = [] all_successful_tools = [] - - # 3a: 基于 domain_name(target_name) 的 URL 被动收集(如 waymore) + + # 3a: 基于 domain_name 的 URL 被动收集(如 waymore) if domain_name_tools: - logger.info("Step 3a: 执行基于 domain_name 的 URL 被动收集子 Flow") tn_result = domain_name_url_fetch_flow( scan_id=scan_id, target_id=target_id, @@ -336,10 +317,9 @@ def url_fetch_flow( all_result_files.extend(tn_result.get('result_files', [])) all_failed_tools.extend(tn_result.get('failed_tools', [])) all_successful_tools.extend(tn_result.get('successful_tools', [])) - + # 3b: 爬虫(以 sites_file 为输入) if sites_file_tools: - logger.info("Step 3b: 执行爬虫子 Flow") crawl_result = sites_url_fetch_flow( scan_id=scan_id, target_id=target_id, @@ -350,12 +330,13 @@ def url_fetch_flow( all_result_files.extend(crawl_result.get('result_files', [])) all_failed_tools.extend(crawl_result.get('failed_tools', [])) all_successful_tools.extend(crawl_result.get('successful_tools', [])) - + # 检查是否有成功的工具 if not all_result_files: - error_details = "; ".join([f"{f['tool']}: {f['reason']}" for f in all_failed_tools]) + error_details = "; ".join([ + "%s: %s" % (f['tool'], f['reason']) for f in all_failed_tools + ]) logger.warning("所有 URL 获取工具均失败 - 目标: %s, 失败详情: %s", target_name, error_details) - # 返回空结果,不抛出异常,让扫描继续 return { 'success': True, 'scan_id': scan_id, @@ -366,31 +347,24 @@ def url_fetch_flow( 'successful_tools': [], 'message': '所有 URL 获取工具均无结果' } - + # Step 4: 合并并去重 URL - logger.info("Step 4: 合并并去重 URL") - merged_file, unique_url_count = _merge_and_deduplicate_urls( + merged_file, _ = _merge_and_deduplicate_urls( result_files=all_result_files, url_fetch_dir=url_fetch_dir ) - + # Step 5: 使用 uro 清理 URL(如果启用) 
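
The httpx timeout rule applied in _validate_and_stream_save_urls above, restated as a standalone function for clarity: 3 s per URL with a 60 s floor in 'auto' mode, integer passthrough otherwise, 3600 s on unparseable config.

def httpx_timeout(raw_timeout, url_count: int) -> int:
    if isinstance(raw_timeout, str) and raw_timeout == 'auto':
        return max(60, url_count * 3)
    try:
        return int(raw_timeout)
    except (TypeError, ValueError):
        return 3600

assert httpx_timeout('auto', 10) == 60         # floor dominates tiny URL lists
assert httpx_timeout('auto', 5_000) == 15_000
assert httpx_timeout('90', 5_000) == 90        # explicit config wins
assert httpx_timeout(None, 5_000) == 3600      # unparseable -> safe default
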
url_file_for_validation = merged_file - uro_removed_count = 0 - if uro_config and uro_config.get('enabled', False): - logger.info("Step 5: 使用 uro 清理 URL") - url_file_for_validation, cleaned_count, uro_removed_count = _clean_urls_with_uro( + url_file_for_validation, _, _ = _clean_urls_with_uro( merged_file=merged_file, uro_config=uro_config, url_fetch_dir=url_fetch_dir ) - else: - logger.info("Step 5: 跳过 uro 清理(未启用)") - + # Step 6: 使用 httpx 验证存活并保存(如果启用) if httpx_config and httpx_config.get('enabled', False): - logger.info("Step 6: 使用 httpx 验证 URL 存活并流式保存") saved_count = _validate_and_stream_save_urls( merged_file=url_file_for_validation, httpx_config=httpx_config, @@ -399,17 +373,16 @@ def url_fetch_flow( target_id=target_id ) else: - logger.info("Step 6: 保存到数据库(未启用 httpx 验证)") saved_count = _save_urls_to_database( merged_file=url_file_for_validation, scan_id=scan_id, target_id=target_id ) - + # 记录 Flow 完成 logger.info("✓ URL 获取完成 - 保存 endpoints: %d", saved_count) - user_log(scan_id, "url_fetch", f"url_fetch completed: found {saved_count} endpoints") - + user_log(scan_id, "url_fetch", "url_fetch completed: found %d endpoints" % saved_count) + # 构建已执行的任务列表 executed_tasks = ['setup_directory', 'classify_tools'] if domain_name_tools: @@ -423,7 +396,7 @@ def url_fetch_flow( executed_tasks.append('httpx_validation_and_save') else: executed_tasks.append('save_urls') - + return { 'success': True, 'scan_id': scan_id, @@ -439,7 +412,7 @@ def url_fetch_flow( 'failed_tools': [f['tool'] for f in all_failed_tools] } } - + except Exception as e: logger.error("URL 获取扫描失败: %s", e, exc_info=True) raise diff --git a/backend/apps/scan/flows/vuln_scan/main_flow.py b/backend/apps/scan/flows/vuln_scan/main_flow.py index 6ae8dd70..94330dc0 100644 --- a/backend/apps/scan/flows/vuln_scan/main_flow.py +++ b/backend/apps/scan/flows/vuln_scan/main_flow.py @@ -1,5 +1,6 @@ -from apps.common.prefect_django_setup import setup_django_for_prefect - +""" +漏洞扫描主 Flow +""" import logging from typing import Dict, Tuple @@ -11,7 +12,7 @@ from apps.scan.handlers.scan_flow_handlers import ( on_scan_flow_failed, ) from apps.scan.configs.command_templates import get_command_template -from apps.scan.utils import user_log +from apps.scan.utils import user_log, wait_for_system_load from .endpoints_vuln_scan_flow import endpoints_vuln_scan_flow @@ -62,6 +63,9 @@ def vuln_scan_flow( - nuclei: 通用漏洞扫描(流式保存,支持模板 commit hash 同步) """ try: + # 负载检查:等待系统资源充足 + wait_for_system_load(context="vuln_scan_flow") + if scan_id is None: raise ValueError("scan_id 不能为空") if not target_name: diff --git a/backend/apps/scan/providers/base.py b/backend/apps/scan/providers/base.py index eb51e764..5e499c92 100644 --- a/backend/apps/scan/providers/base.py +++ b/backend/apps/scan/providers/base.py @@ -4,11 +4,11 @@ 定义 ProviderContext 数据类和 TargetProvider 抽象基类。 """ -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Iterator, Optional, TYPE_CHECKING import ipaddress import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Iterator, Optional if TYPE_CHECKING: from apps.common.utils import BlacklistFilter @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) class ProviderContext: """ Provider 上下文,携带元数据 - + Attributes: target_id: 关联的 Target ID(用于结果保存),None 表示临时扫描(不保存) scan_id: 扫描任务 ID @@ -32,130 +32,83 @@ class ProviderContext: class TargetProvider(ABC): """ 扫描目标提供者抽象基类 - + 职责: - 提供扫描目标(域名、IP、URL 等)的迭代器 - 提供黑名单过滤器 - 携带上下文信息(target_id, scan_id 等) - 自动展开 
CIDR(子类无需关心) - + 使用方式: provider = create_target_provider(target_id=123) for host in provider.iter_hosts(): print(host) """ - + def __init__(self, context: Optional[ProviderContext] = None): - """ - 初始化 Provider - - Args: - context: Provider 上下文,None 时创建默认上下文 - """ self._context = context or ProviderContext() - + @property def context(self) -> ProviderContext: """返回 Provider 上下文""" return self._context - + @staticmethod def _expand_host(host: str) -> Iterator[str]: """ 展开主机(如果是 CIDR 则展开为多个 IP,否则直接返回) - - 这是一个内部方法,由 iter_hosts() 自动调用。 - - Args: - host: 主机字符串(IP/域名/CIDR) - - Yields: - str: 单个主机(IP 或域名) - + 示例: "192.168.1.0/30" → "192.168.1.1", "192.168.1.2" "192.168.1.1" → "192.168.1.1" "example.com" → "example.com" - "invalid" → (跳过,不返回) """ from apps.common.validators import detect_target_type from apps.targets.models import Target - + host = host.strip() if not host: return - - # 统一使用 detect_target_type 检测类型 + try: target_type = detect_target_type(host) - + if target_type == Target.TargetType.CIDR: - # 展开 CIDR network = ipaddress.ip_network(host, strict=False) if network.num_addresses == 1: yield str(network.network_address) else: - for ip in network.hosts(): - yield str(ip) - elif target_type == Target.TargetType.IP: - # 单个 IP - yield host - elif target_type == Target.TargetType.DOMAIN: - # 域名 + yield from (str(ip) for ip in network.hosts()) + elif target_type in (Target.TargetType.IP, Target.TargetType.DOMAIN): yield host except ValueError as e: - # 无效格式,跳过并记录警告 logger.warning("跳过无效的主机格式 '%s': %s", host, str(e)) - + def iter_hosts(self) -> Iterator[str]: - """ - 迭代主机列表(域名/IP) - - 自动展开 CIDR,子类无需关心。 - - Yields: - str: 主机名或 IP 地址(单个,不包含 CIDR) - """ + """迭代主机列表(域名/IP),自动展开 CIDR""" for host in self._iter_raw_hosts(): yield from self._expand_host(host) - + @abstractmethod def _iter_raw_hosts(self) -> Iterator[str]: - """ - 迭代原始主机列表(可能包含 CIDR) - - 子类实现此方法,返回原始数据即可,不需要处理 CIDR 展开。 - - Yields: - str: 主机名、IP 地址或 CIDR - """ + """迭代原始主机列表(可能包含 CIDR),子类实现""" pass - + @abstractmethod def iter_urls(self) -> Iterator[str]: - """ - 迭代 URL 列表 - - Yields: - str: URL 字符串 - """ + """迭代 URL 列表""" pass - + @abstractmethod def get_blacklist_filter(self) -> Optional['BlacklistFilter']: - """ - 获取黑名单过滤器 - - Returns: - BlacklistFilter: 黑名单过滤器实例,或 None(不过滤) - """ + """获取黑名单过滤器,返回 None 表示不过滤""" pass - + @property def target_id(self) -> Optional[int]: """返回关联的 target_id,临时扫描返回 None""" return self._context.target_id - + @property def scan_id(self) -> Optional[int]: """返回关联的 scan_id""" diff --git a/backend/apps/scan/providers/database_provider.py b/backend/apps/scan/providers/database_provider.py index 77e6664c..abe5516b 100644 --- a/backend/apps/scan/providers/database_provider.py +++ b/backend/apps/scan/providers/database_provider.py @@ -5,9 +5,9 @@ """ import logging -from typing import Iterator, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Iterator, Optional -from .base import TargetProvider, ProviderContext +from .base import ProviderContext, TargetProvider if TYPE_CHECKING: from apps.common.utils import BlacklistFilter @@ -18,109 +18,71 @@ logger = logging.getLogger(__name__) class DatabaseTargetProvider(TargetProvider): """ 数据库目标提供者 - 从 Target 表及关联资产表查询 - - 这是现有行为的封装,保持向后兼容。 - + 数据来源: - iter_hosts(): 根据 Target 类型返回域名/IP - - DOMAIN: 根域名 + Subdomain 表 - - IP: 直接返回 IP - - CIDR: 使用 _expand_host() 展开为所有主机 IP - iter_urls(): WebSite/Endpoint 表,带回退链 - + 使用方式: provider = DatabaseTargetProvider(target_id=123) for host in provider.iter_hosts(): scan(host) """ - + def __init__(self, target_id: int, context: 
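
Since _expand_host()/iter_hosts() now carry all CIDR handling, a concrete provider only has to supply raw strings. A minimal toy subclass under that contract (illustrative only, not part of the patch; it relies on detect_target_type classifying inputs as in the docstring examples):

from typing import Iterator, Optional

class StaticTargetProvider(TargetProvider):
    """Toy provider over an in-memory list; the base class expands CIDR."""

    def __init__(self, hosts: list[str], context: Optional[ProviderContext] = None):
        super().__init__(context)
        self._hosts = hosts

    def _iter_raw_hosts(self) -> Iterator[str]:
        yield from self._hosts

    def iter_urls(self) -> Iterator[str]:
        return iter(())   # this toy source has no URLs

    def get_blacklist_filter(self):
        return None       # no filtering

provider = StaticTargetProvider(["192.168.1.0/30", "example.com"])
# list(provider.iter_hosts()) -> ['192.168.1.1', '192.168.1.2', 'example.com']
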
Optional[ProviderContext] = None): - """ - 初始化数据库目标提供者 - - Args: - target_id: 目标 ID(必需) - context: Provider 上下文 - """ ctx = context or ProviderContext() ctx.target_id = target_id super().__init__(ctx) - self._blacklist_filter: Optional['BlacklistFilter'] = None # 延迟加载 - + self._blacklist_filter: Optional['BlacklistFilter'] = None + def iter_hosts(self) -> Iterator[str]: - """ - 从数据库查询主机列表,自动展开 CIDR 并应用黑名单过滤 - - 重写基类方法以支持黑名单过滤(需要在 CIDR 展开后过滤) - """ + """从数据库查询主机列表,自动展开 CIDR 并应用黑名单过滤""" blacklist = self.get_blacklist_filter() - + for host in self._iter_raw_hosts(): - # 展开 CIDR for expanded_host in self._expand_host(host): - # 应用黑名单过滤 if not blacklist or blacklist.is_allowed(expanded_host): yield expanded_host - + def _iter_raw_hosts(self) -> Iterator[str]: - """ - 从数据库查询原始主机列表(可能包含 CIDR) - - 根据 Target 类型决定数据来源: - - DOMAIN: 根域名 + Subdomain 表 - - IP: 直接返回 target.name - - CIDR: 返回 CIDR 字符串(由 iter_hosts() 展开) - - 注意:此方法不应用黑名单过滤,过滤在 iter_hosts() 中进行 - """ - from apps.targets.services import TargetService - from apps.targets.models import Target + """从数据库查询原始主机列表(可能包含 CIDR)""" from apps.asset.services.asset.subdomain_service import SubdomainService - + from apps.targets.models import Target + from apps.targets.services import TargetService + target = TargetService().get_target(self.target_id) if not target: logger.warning("Target ID %d 不存在", self.target_id) return - + if target.type == Target.TargetType.DOMAIN: - # 返回根域名 yield target.name - - # 返回子域名 - subdomain_service = SubdomainService() - for domain in subdomain_service.iter_subdomain_names_by_target( + for domain in SubdomainService().iter_subdomain_names_by_target( target_id=self.target_id, chunk_size=1000 ): - if domain != target.name: # 避免重复 + if domain != target.name: yield domain - - elif target.type == Target.TargetType.IP: + + elif target.type in (Target.TargetType.IP, Target.TargetType.CIDR): yield target.name - - elif target.type == Target.TargetType.CIDR: - # 直接返回 CIDR,由 iter_hosts() 展开并过滤 - yield target.name - + def iter_urls(self) -> Iterator[str]: - """ - 从数据库查询 URL 列表 - - 使用现有的回退链逻辑:Endpoint → WebSite → Default - """ + """从数据库查询 URL 列表,使用回退链:Endpoint → WebSite → Default""" from apps.scan.services.target_export_service import ( - _iter_urls_with_fallback, DataSource + DataSource, + _iter_urls_with_fallback, ) - + blacklist = self.get_blacklist_filter() - - for url, source in _iter_urls_with_fallback( + + for url, _ in _iter_urls_with_fallback( target_id=self.target_id, sources=[DataSource.ENDPOINT, DataSource.WEBSITE, DataSource.DEFAULT], blacklist_filter=blacklist ): yield url - + def get_blacklist_filter(self) -> Optional['BlacklistFilter']: """获取黑名单过滤器(延迟加载)""" if self._blacklist_filter is None: diff --git a/backend/apps/scan/services/target_export_service.py b/backend/apps/scan/services/target_export_service.py index a44c4379..3d0ad1a0 100644 --- a/backend/apps/scan/services/target_export_service.py +++ b/backend/apps/scan/services/target_export_service.py @@ -12,7 +12,7 @@ import ipaddress import logging from pathlib import Path -from typing import Dict, Any, Optional, List, Iterator, Tuple, Callable +from typing import Dict, Any, Optional, List, Iterator, Tuple from django.db.models import QuerySet @@ -485,8 +485,7 @@ class TargetExportService: """ from apps.targets.services import TargetService from apps.targets.models import Target - from apps.asset.services.asset.subdomain_service import SubdomainService - + output_file = Path(output_path) output_file.parent.mkdir(parents=True, exist_ok=True) diff --git 
a/backend/apps/scan/tasks/directory_scan/export_sites_task.py b/backend/apps/scan/tasks/directory_scan/export_sites_task.py index 25720c72..defe7957 100644 --- a/backend/apps/scan/tasks/directory_scan/export_sites_task.py +++ b/backend/apps/scan/tasks/directory_scan/export_sites_task.py @@ -16,7 +16,7 @@ from apps.scan.services.target_export_service import ( export_urls_with_fallback, DataSource, ) -from apps.scan.providers import TargetProvider, DatabaseTargetProvider +from apps.scan.providers import TargetProvider logger = logging.getLogger(__name__) diff --git a/backend/apps/scan/tasks/port_scan/export_hosts_task.py b/backend/apps/scan/tasks/port_scan/export_hosts_task.py index 8519abf6..d7578d42 100644 --- a/backend/apps/scan/tasks/port_scan/export_hosts_task.py +++ b/backend/apps/scan/tasks/port_scan/export_hosts_task.py @@ -11,12 +11,12 @@ - CIDR: 展开 CIDR 范围内的所有 IP """ import logging -from typing import Optional from pathlib import Path +from typing import Optional + from prefect import task -from apps.scan.services.target_export_service import create_export_service -from apps.scan.providers import TargetProvider, DatabaseTargetProvider, ProviderContext +from apps.scan.providers import DatabaseTargetProvider, TargetProvider logger = logging.getLogger(__name__) @@ -26,15 +26,14 @@ def export_hosts_task( output_file: str, target_id: Optional[int] = None, provider: Optional[TargetProvider] = None, - batch_size: int = 1000 ) -> dict: """ 导出主机列表到 TXT 文件 - + 支持两种模式: 1. 传统模式(向后兼容):传入 target_id,从数据库导出 2. Provider 模式:传入 provider,从任意数据源导出 - + 根据 Target 类型自动决定导出内容: - DOMAIN: 从 Subdomain 表导出子域名(流式处理,支持 10万+ 域名) - IP: 直接写入 target.name(单个 IP) @@ -44,7 +43,6 @@ def export_hosts_task( output_file: 输出文件路径(绝对路径) target_id: 目标 ID(传统模式,向后兼容) provider: TargetProvider 实例(新模式) - batch_size: 每次读取的批次大小,默认 1000(仅对 DOMAIN 类型有效) Returns: dict: { @@ -58,53 +56,44 @@ def export_hosts_task( ValueError: 参数错误(target_id 和 provider 都未提供) IOError: 文件写入失败 """ - # 参数验证:至少提供一个 if target_id is None and provider is None: raise ValueError("必须提供 target_id 或 provider 参数之一") - + # 向后兼容:如果没有提供 provider,使用 target_id 创建 DatabaseTargetProvider - if provider is None: + use_legacy_mode = provider is None + if use_legacy_mode: logger.info("使用传统模式 - Target ID: %d", target_id) provider = DatabaseTargetProvider(target_id=target_id) - use_legacy_mode = True else: logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__) - use_legacy_mode = False - + # 确保输出目录存在 output_path = Path(output_file) output_path.parent.mkdir(parents=True, exist_ok=True) - - # 使用 Provider 导出主机列表 + + # 使用 Provider 导出主机列表(iter_hosts 内部已处理黑名单过滤) total_count = 0 - blacklist_filter = provider.get_blacklist_filter() - + with open(output_path, 'w', encoding='utf-8', buffering=8192) as f: for host in provider.iter_hosts(): - # 应用黑名单过滤(如果有) - if blacklist_filter and not blacklist_filter.is_allowed(host): - continue - f.write(f"{host}\n") total_count += 1 - + if total_count % 1000 == 0: logger.info("已导出 %d 个主机...", total_count) - + logger.info("✓ 主机列表导出完成 - 总数: %d, 文件: %s", total_count, str(output_path)) - - # 构建返回值 + result = { 'success': True, 'output_file': str(output_path), 'total_count': total_count, } - + # 传统模式:保持返回值格式不变(向后兼容) if use_legacy_mode: - # 获取 target_type(仅传统模式需要) from apps.targets.services import TargetService target = TargetService().get_target(target_id) result['target_type'] = target.type if target else 'unknown' - + return result diff --git a/backend/apps/scan/utils/__init__.py b/backend/apps/scan/utils/__init__.py index 26a3084b..f7a1e13d 
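
export_hosts_task now leans on iter_hosts() for blacklist filtering (DatabaseTargetProvider filters after CIDR expansion), which is why the task's own filter loop could be dropped. Both calling conventions, sketched with illustrative IDs and Prefect's .fn to bypass the task engine:

legacy = export_hosts_task.fn(
    output_file="/tmp/scan/hosts.txt",
    target_id=123,                                    # legacy mode, provider built internally
)
assert 'target_type' in legacy                        # extra key kept for compatibility

modern = export_hosts_task.fn(
    output_file="/tmp/scan/hosts.txt",
    provider=DatabaseTargetProvider(target_id=123),   # provider mode
)
assert 'target_type' not in modern
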
diff --git a/backend/apps/scan/utils/__init__.py b/backend/apps/scan/utils/__init__.py
index 26a3084b..f7a1e13d 100644
--- a/backend/apps/scan/utils/__init__.py
+++ b/backend/apps/scan/utils/__init__.py
@@ -4,37 +4,40 @@
 提供扫描相关的工具函数。
 """
 
-from .directory_cleanup import remove_directory
+from . import config_parser
 from .command_builder import build_scan_command
 from .command_executor import execute_and_wait, execute_stream
-from .wordlist_helpers import ensure_wordlist_local
+from .directory_cleanup import remove_directory
 from .nuclei_helpers import ensure_nuclei_templates_local
-from .performance import FlowPerformanceTracker, CommandPerformanceTracker
-from .workspace_utils import setup_scan_workspace, setup_scan_directory
+from .performance import CommandPerformanceTracker, FlowPerformanceTracker
+from .system_load import check_system_load, wait_for_system_load
 from .user_logger import user_log
-from . import config_parser
+from .wordlist_helpers import ensure_wordlist_local
+from .workspace_utils import setup_scan_directory, setup_scan_workspace
 
 __all__ = [
     # 目录清理
     'remove_directory',
     # 工作空间
-    'setup_scan_workspace',  # 创建 Scan 根工作空间
-    'setup_scan_directory',  # 创建扫描子目录
+    'setup_scan_workspace',
+    'setup_scan_directory',
     # 命令构建
-    'build_scan_command',  # 扫描工具命令构建(基于 f-string)
+    'build_scan_command',
     # 命令执行
-    'execute_and_wait',  # 等待式执行(文件输出)
-    'execute_stream',  # 流式执行(实时处理)
+    'execute_and_wait',
+    'execute_stream',
+    # 系统负载
+    'wait_for_system_load',
+    'check_system_load',
     # 字典文件
-    'ensure_wordlist_local',  # 确保本地字典文件(含 hash 校验)
+    'ensure_wordlist_local',
     # Nuclei 模板
-    'ensure_nuclei_templates_local',  # 确保本地模板(含 commit hash 校验)
+    'ensure_nuclei_templates_local',
     # 性能监控
-    'FlowPerformanceTracker',  # Flow 性能追踪器(含系统资源采样)
-    'CommandPerformanceTracker',  # 命令性能追踪器
+    'FlowPerformanceTracker',
+    'CommandPerformanceTracker',
     # 扫描日志
-    'user_log',  # 用户可见扫描日志记录
+    'user_log',
     # 配置解析
     'config_parser',
 ]
-
diff --git a/backend/apps/scan/utils/command_executor.py b/backend/apps/scan/utils/command_executor.py
index 9e1302a6..ea085756 100644
--- a/backend/apps/scan/utils/command_executor.py
+++ b/backend/apps/scan/utils/command_executor.py
@@ -12,16 +12,18 @@
 
 import logging
 import os
-from django.conf import settings
 import re
 import signal
 import subprocess
 import threading
 import time
+from collections import deque
 from datetime import datetime
 from pathlib import Path
 from typing import Dict, Any, Optional, Generator
 
+from django.conf import settings
+
 try:
     # 可选依赖:用于根据 CPU / 内存负载做动态并发控制
     import psutil
@@ -669,33 +671,68 @@ class CommandExecutor:
 
     def _read_log_tail(self, log_file: Path, max_lines: int = MAX_LOG_TAIL_LINES) -> str:
         """
-        读取日志文件的末尾部分
-
+        读取日志文件的末尾部分(常量内存实现)
+
+        使用 seek 从文件末尾往前读取,避免将整个文件加载到内存。
+
         Args:
             log_file: 日志文件路径
             max_lines: 最大读取行数
-
+
         Returns:
             日志内容(字符串),读取失败返回错误提示
         """
         if not log_file.exists():
             logger.debug("日志文件不存在: %s", log_file)
             return ""
-
-        if log_file.stat().st_size == 0:
+
+        file_size = log_file.stat().st_size
+        if file_size == 0:
             logger.debug("日志文件为空: %s", log_file)
             return ""
-
+
+        # 每次读取的块大小(8KB,足够容纳大多数日志行)
+        chunk_size = 8192
+
+        def decode_line(line_bytes: bytes) -> str:
+            """解码单行:优先 UTF-8,失败则降级 latin-1"""
+            try:
+                return line_bytes.decode('utf-8')
+            except UnicodeDecodeError:
+                return line_bytes.decode('latin-1', errors='replace')
+
         try:
-            with open(log_file, 'r', encoding='utf-8') as f:
-                lines = f.readlines()
-            return ''.join(lines[-max_lines:] if len(lines) > max_lines else lines)
-        except UnicodeDecodeError as e:
-            logger.warning("日志文件编码错误 (%s): %s", log_file, e)
-            return f"(无法读取日志文件: 编码错误 - {e})"
+            with open(log_file, 'rb') as f:
+                lines_found: deque[bytes] = deque()
+                remaining = b''
+                position = file_size
+
+                while position > 0 and len(lines_found) < max_lines:
+                    read_size = min(chunk_size, position)
+                    position -= read_size
+
+                    f.seek(position)
+                    chunk = f.read(read_size) + remaining
+                    parts = chunk.split(b'\n')
+
+                    # 最前面的部分可能不完整,留到下次处理
+                    remaining = parts[0]
+
+                    # 其余部分是完整的行(从后往前收集,用 appendleft 保持顺序)
+                    for part in reversed(parts[1:]):
+                        if len(lines_found) >= max_lines:
+                            break
+                        lines_found.appendleft(part)
+
+                # 处理文件开头的行
+                if remaining and len(lines_found) < max_lines:
+                    lines_found.appendleft(remaining)
+
+            return '\n'.join(decode_line(line) for line in lines_found)
+
         except PermissionError as e:
             logger.warning("日志文件权限不足 (%s): %s", log_file, e)
-            return f"(无法读取日志文件: 权限不足)"
+            return "(无法读取日志文件: 权限不足)"
         except IOError as e:
             logger.warning("日志文件读取IO错误 (%s): %s", log_file, e)
             return f"(无法读取日志文件: IO错误 - {e})"
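To make the backward-seek invariant in `_read_log_tail` easier to verify, here is the same technique as a standalone sketch (logging and error handling stripped; the names are mine, not from the patch). `remaining` always holds the possibly incomplete first line of everything read so far, so it is only emitted once the loop reaches the start of the file:

```python
# Standalone sketch of the constant-memory tail read used above.
from collections import deque
from pathlib import Path


def tail_lines(path: Path, max_lines: int, chunk_size: int = 8192) -> list[bytes]:
    lines: deque[bytes] = deque()
    with open(path, 'rb') as f:
        position = path.stat().st_size
        remaining = b''
        while position > 0 and len(lines) < max_lines:
            read_size = min(chunk_size, position)
            position -= read_size
            f.seek(position)
            # Append the leftover partial line from the previous (later)
            # chunk to the newly read chunk, then split into candidate lines.
            parts = (f.read(read_size) + remaining).split(b'\n')
            remaining = parts[0]  # may still be incomplete
            for part in reversed(parts[1:]):
                if len(lines) >= max_lines:
                    break
                lines.appendleft(part)
        if remaining and len(lines) < max_lines:
            lines.appendleft(remaining)  # line at the very start of the file
    return list(lines)
```

Like the patched method, this keeps at most `max_lines` lines plus one chunk in memory regardless of file size.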
diff --git a/backend/apps/scan/utils/system_load.py b/backend/apps/scan/utils/system_load.py
new file mode 100644
index 00000000..48619f65
--- /dev/null
+++ b/backend/apps/scan/utils/system_load.py
@@ -0,0 +1,77 @@
+"""
+系统负载检查工具
+
+提供统一的系统负载检查功能,用于:
+- Flow 入口处检查系统资源是否充足
+- 防止在高负载时启动新的扫描任务
+"""
+
+import logging
+import time
+
+import psutil
+from django.conf import settings
+
+logger = logging.getLogger(__name__)
+
+# 动态并发控制阈值(可在 Django settings 中覆盖)
+SCAN_CPU_HIGH: float = getattr(settings, 'SCAN_CPU_HIGH', 90.0)
+SCAN_MEM_HIGH: float = getattr(settings, 'SCAN_MEM_HIGH', 80.0)
+SCAN_LOAD_CHECK_INTERVAL: int = getattr(settings, 'SCAN_LOAD_CHECK_INTERVAL', 180)
+
+
+def _get_current_load() -> tuple[float, float]:
+    """获取当前 CPU 和内存使用率"""
+    return psutil.cpu_percent(interval=0.5), psutil.virtual_memory().percent
+
+
+def wait_for_system_load(
+    cpu_threshold: float = SCAN_CPU_HIGH,
+    mem_threshold: float = SCAN_MEM_HIGH,
+    check_interval: int = SCAN_LOAD_CHECK_INTERVAL,
+    context: str = "task"
+) -> None:
+    """
+    等待系统负载降到阈值以下
+
+    在高负载时阻塞等待,直到 CPU 和内存都低于阈值。
+    用于 Flow 入口处,防止在资源紧张时启动新任务。
+    """
+    while True:
+        cpu, mem = _get_current_load()
+
+        if cpu < cpu_threshold and mem < mem_threshold:
+            logger.debug(
+                "[%s] 系统负载正常: cpu=%.1f%%, mem=%.1f%%",
+                context, cpu, mem
+            )
+            return
+
+        logger.info(
+            "[%s] 系统负载较高,等待资源释放: "
+            "cpu=%.1f%% (阈值 %.1f%%), mem=%.1f%% (阈值 %.1f%%)",
+            context, cpu, cpu_threshold, mem, mem_threshold
+        )
+        time.sleep(check_interval)
+
+
+def check_system_load(
+    cpu_threshold: float = SCAN_CPU_HIGH,
+    mem_threshold: float = SCAN_MEM_HIGH
+) -> dict:
+    """
+    检查当前系统负载(非阻塞)
+
+    Returns:
+        dict: cpu_percent, mem_percent, cpu_threshold, mem_threshold, is_overloaded
+    """
+    cpu, mem = _get_current_load()
+
+    return {
+        'cpu_percent': cpu,
+        'mem_percent': mem,
+        'cpu_threshold': cpu_threshold,
+        'mem_threshold': mem_threshold,
+        'is_overloaded': cpu >= cpu_threshold or mem >= mem_threshold,
+    }
+
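A hedged sketch of how the new helpers are meant to be consumed at a flow entry point (the actual call sites are outside this diff; the flow function name is hypothetical):

```python
# Hypothetical call sites for the new system_load helpers.
from apps.scan.utils import check_system_load, wait_for_system_load


def directory_scan_entry():
    # Blocks until CPU < SCAN_CPU_HIGH and memory < SCAN_MEM_HIGH,
    # re-checking every SCAN_LOAD_CHECK_INTERVAL seconds (180s by default).
    wait_for_system_load(context="directory_scan")
    # ... launch the actual scan tasks ...


# Non-blocking variant, e.g. for a status endpoint or pre-flight check:
status = check_system_load()
if status['is_overloaded']:
    print(f"overloaded: cpu={status['cpu_percent']}%, mem={status['mem_percent']}%")
```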
["i", "j", "k", "ex", "Run", "_", "id", "pk", "ip", "url", "db", "qs"] +good-names = [ + "i", + "j", + "k", + "ex", + "Run", + "_", + "id", + "pk", + "ip", + "url", + "db", + "qs", +]