diff --git a/backend/apps/common/prefect_django_setup.py b/backend/apps/common/django_setup.py similarity index 78% rename from backend/apps/common/prefect_django_setup.py rename to backend/apps/common/django_setup.py index b4452663..97d9254c 100644 --- a/backend/apps/common/prefect_django_setup.py +++ b/backend/apps/common/django_setup.py @@ -1,43 +1,43 @@ """ -Prefect Flow Django 环境初始化模块 +Django 环境初始化模块 -在所有 Prefect Flow 文件开头导入此模块即可自动配置 Django 环境 +在所有 Worker 脚本开头导入此模块即可自动配置 Django 环境。 """ import os import sys -def setup_django_for_prefect(): +def setup_django(): """ - 为 Prefect Flow 配置 Django 环境 - + 配置 Django 环境 + 此函数会: 1. 添加项目根目录到 Python 路径 2. 设置 DJANGO_SETTINGS_MODULE 环境变量 3. 调用 django.setup() 初始化 Django 4. 关闭旧的数据库连接,确保使用新连接 - + 使用方式: - from apps.common.prefect_django_setup import setup_django_for_prefect - setup_django_for_prefect() + from apps.common.django_setup import setup_django + setup_django() """ # 获取项目根目录(backend 目录) current_dir = os.path.dirname(os.path.abspath(__file__)) backend_dir = os.path.join(current_dir, '../..') backend_dir = os.path.abspath(backend_dir) - + # 添加到 Python 路径 if backend_dir not in sys.path: sys.path.insert(0, backend_dir) - + # 配置 Django os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings') - + # 初始化 Django import django django.setup() - + # 关闭所有旧的数据库连接,确保 Worker 进程使用新连接 # 解决 "server closed the connection unexpectedly" 问题 from django.db import connections @@ -47,7 +47,7 @@ def setup_django_for_prefect(): def close_old_db_connections(): """ 关闭旧的数据库连接 - + 在长时间运行的任务中调用此函数,可以确保使用有效的数据库连接。 适用于: - Flow 开始前 @@ -59,4 +59,4 @@ def close_old_db_connections(): # 自动执行初始化(导入即生效) -setup_django_for_prefect() +setup_django() diff --git a/backend/apps/engine/services/task_distributor.py b/backend/apps/engine/services/task_distributor.py index 14672642..88829394 100644 --- a/backend/apps/engine/services/task_distributor.py +++ b/backend/apps/engine/services/task_distributor.py @@ -279,17 +279,11 @@ class TaskDistributor: # 环境变量:SERVER_URL + IS_LOCAL,其他配置容器启动时从配置中心获取 # IS_LOCAL 用于 Worker 向配置中心声明身份,决定返回的数据库地址 - # Prefect 本地模式配置:启用 ephemeral server(本地临时服务器) is_local_str = "true" if worker.is_local else "false" env_vars = [ f"-e SERVER_URL={shlex.quote(server_url)}", f"-e IS_LOCAL={is_local_str}", f"-e WORKER_API_KEY={shlex.quote(settings.WORKER_API_KEY)}", # Worker API 认证密钥 - "-e PREFECT_HOME=/tmp/.prefect", # 设置 Prefect 数据目录到可写位置 - "-e PREFECT_SERVER_EPHEMERAL_ENABLED=true", # 启用 ephemeral server(本地临时服务器) - "-e PREFECT_SERVER_EPHEMERAL_STARTUP_TIMEOUT_SECONDS=120", # 增加启动超时时间 - "-e PREFECT_SERVER_DATABASE_CONNECTION_URL=sqlite+aiosqlite:////tmp/.prefect/prefect.db", # 使用 /tmp 下的 SQLite - "-e PREFECT_LOGGING_LEVEL=WARNING", # 日志级别(减少 DEBUG 噪音) ] # 挂载卷(统一挂载整个 /opt/xingrin 目录) diff --git a/backend/apps/scan/decorators.py b/backend/apps/scan/decorators.py new file mode 100644 index 00000000..897ef77d --- /dev/null +++ b/backend/apps/scan/decorators.py @@ -0,0 +1,200 @@ +""" +扫描流程装饰器模块 + +提供轻量级的 @scan_flow 和 @scan_task 装饰器,替代 Prefect 的 @flow 和 @task。 + +核心功能: +- @scan_flow: 状态管理、通知、性能追踪 +- @scan_task: 重试逻辑(大部分 task 不需要重试,可直接移除装饰器) + +设计原则: +- 保持与 Prefect 装饰器相同的使用方式 +- 零依赖,无额外内存开销 +- 保留原函数签名和返回值 +""" + +import functools +import logging +import time +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, Callable, Optional + +logger = logging.getLogger(__name__) + + +@dataclass +class FlowContext: + """ + Flow 执行上下文 + + 替代 Prefect 的 Flow、FlowRun、State 参数,传递给回调函数。 + """ + flow_name: str + stage_name: str + scan_id: 
Optional[int] = None + target_id: Optional[int] = None + target_name: Optional[str] = None + parameters: dict = field(default_factory=dict) + start_time: datetime = field(default_factory=datetime.now) + end_time: Optional[datetime] = None + result: Any = None + error: Optional[Exception] = None + error_message: Optional[str] = None + + +def scan_flow( + name: Optional[str] = None, + stage_name: Optional[str] = None, + on_running: Optional[list[Callable]] = None, + on_completion: Optional[list[Callable]] = None, + on_failure: Optional[list[Callable]] = None, + log_prints: bool = True, # 保持与 Prefect 兼容,但不使用 +): + """ + 扫描流程装饰器 + + 替代 Prefect 的 @flow 装饰器,提供: + - 自动状态管理(start_stage/complete_stage/fail_stage) + - 生命周期回调(on_running/on_completion/on_failure) + - 性能追踪(FlowPerformanceTracker) + - 失败通知 + + Args: + name: Flow 名称,默认使用函数名 + stage_name: 阶段名称,默认使用 name + on_running: 流程开始时的回调列表 + on_completion: 流程完成时的回调列表 + on_failure: 流程失败时的回调列表 + log_prints: 保持与 Prefect 兼容,不使用 + + Usage: + @scan_flow(name="site_scan", on_running=[on_scan_flow_running]) + def site_scan_flow(scan_id: int, target_id: int, ...): + ... + """ + def decorator(func: Callable) -> Callable: + flow_name = name or func.__name__ + actual_stage_name = stage_name or flow_name + + @functools.wraps(func) + def wrapper(*args, **kwargs) -> Any: + # 提取参数 + scan_id = kwargs.get('scan_id') + target_id = kwargs.get('target_id') + target_name = kwargs.get('target_name') + + # 创建上下文 + context = FlowContext( + flow_name=flow_name, + stage_name=actual_stage_name, + scan_id=scan_id, + target_id=target_id, + target_name=target_name, + parameters=kwargs.copy(), + start_time=datetime.now(), + ) + + # 执行 on_running 回调 + if on_running: + for callback in on_running: + try: + callback(context) + except Exception as e: + logger.warning("on_running 回调执行失败: %s", e) + + try: + # 执行原函数 + result = func(*args, **kwargs) + + # 更新上下文 + context.end_time = datetime.now() + context.result = result + + # 执行 on_completion 回调 + if on_completion: + for callback in on_completion: + try: + callback(context) + except Exception as e: + logger.warning("on_completion 回调执行失败: %s", e) + + return result + + except Exception as e: + # 更新上下文 + context.end_time = datetime.now() + context.error = e + context.error_message = str(e) + + # 执行 on_failure 回调 + if on_failure: + for callback in on_failure: + try: + callback(context) + except Exception as cb_error: + logger.warning("on_failure 回调执行失败: %s", cb_error) + + # 重新抛出异常 + raise + + return wrapper + return decorator + + +def scan_task( + retries: int = 0, + retry_delay: float = 1.0, + name: Optional[str] = None, # 保持与 Prefect 兼容 +): + """ + 扫描任务装饰器 + + 替代 Prefect 的 @task 装饰器,提供重试能力。 + + 注意:当前代码中大部分 @task 都是 retries=0,可以直接移除装饰器。 + 只有需要重试的 task 才需要使用此装饰器。 + + Args: + retries: 失败后重试次数,默认 0(不重试) + retry_delay: 重试间隔(秒),默认 1.0 + name: 任务名称,保持与 Prefect 兼容,不使用 + + Usage: + @scan_task(retries=3, retry_delay=2.0) + def run_scan_tool(command: str, timeout: int): + ... 
+    """
+    def decorator(func: Callable) -> Callable:
+        task_name = name or func.__name__
+
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs) -> Any:
+            last_exception = None
+
+            for attempt in range(retries + 1):
+                try:
+                    return func(*args, **kwargs)
+                except Exception as e:
+                    last_exception = e
+                    if attempt < retries:
+                        logger.warning(
+                            "任务 %s 重试 %d/%d: %s",
+                            task_name, attempt + 1, retries, e
+                        )
+                        time.sleep(retry_delay)
+                    else:
+                        logger.error(
+                            "任务 %s 重试耗尽 (%d 次): %s",
+                            task_name, retries + 1, e
+                        )
+
+            # 重试耗尽,抛出最后一个异常
+            raise last_exception
+
+        # 保留原函数引用(对应 Prefect 的 task.fn),便于测试或绕过重试直接调用
+        # 注意:不再提供 submit() 方法,并行调度已改由调用方的 ThreadPoolExecutor 承担
+        wrapper.fn = func
+
+        return wrapper
+    return decorator
diff --git a/backend/apps/scan/flows/directory_scan_flow.py b/backend/apps/scan/flows/directory_scan_flow.py
index 2f8a54c3..f5e23dc3 100644
--- a/backend/apps/scan/flows/directory_scan_flow.py
+++ b/backend/apps/scan/flows/directory_scan_flow.py
@@ -17,8 +17,9 @@ from datetime import datetime
 from pathlib import Path
 from typing import List, Tuple
 
-from prefect import flow
+from concurrent.futures import ThreadPoolExecutor
 
+from apps.scan.decorators import scan_flow
 from apps.scan.handlers.scan_flow_handlers import (
     on_scan_flow_completed,
     on_scan_flow_failed,
@@ -220,45 +221,51 @@ def _execute_batch(
     directories_found = 0
     failed_sites = []
 
-    # 提交任务
-    futures = []
-    for params in batch_params:
-        future = run_and_stream_save_directories_task.submit(
-            cmd=params['command'],
-            tool_name=tool_name,
-            scan_id=scan_id,
-            target_id=target_id,
-            site_url=params['site_url'],
-            cwd=str(directory_scan_dir),
-            shell=True,
-            batch_size=1000,
-            timeout=params['timeout'],
-            log_file=params['log_file']
-        )
-        futures.append((params['idx'], params['site_url'], future))
-
-    # 等待结果
-    for idx, site_url, future in futures:
-        try:
-            result = future.result()
-            dirs_count = result.get('created_directories', 0)
-            directories_found += dirs_count
-            logger.info(
-                "✓ [%d/%d] 站点扫描完成: %s - 发现 %d 个目录",
-                idx, total_sites, site_url, dirs_count
+    # 使用 ThreadPoolExecutor 并行执行
+    # 空批次直接返回,避免 ThreadPoolExecutor(max_workers=0) 抛出 ValueError
+    if not batch_params:
+        return directories_found, failed_sites
+
+    with ThreadPoolExecutor(max_workers=len(batch_params)) as executor:
+        futures = []
+        for params in batch_params:
+            future = executor.submit(
+                run_and_stream_save_directories_task,
+                cmd=params['command'],
+                tool_name=tool_name,
+                scan_id=scan_id,
+                target_id=target_id,
+                site_url=params['site_url'],
+                cwd=str(directory_scan_dir),
+                shell=True,
+                batch_size=1000,
+                timeout=params['timeout'],
+                log_file=params['log_file']
             )
-        except Exception as exc:
-            failed_sites.append(site_url)
-            if 'timeout' in str(exc).lower():
-                logger.warning(
-                    "⚠️ [%d/%d] 站点扫描超时: %s - 错误: %s",
-                    idx, total_sites, site_url, exc
-                )
-            else:
-                logger.error(
-                    "✗ [%d/%d] 站点扫描失败: %s - 错误: %s",
-                    idx, total_sites, site_url, exc
+            futures.append((params['idx'], params['site_url'], future))
+
+        # 等待结果
+        for idx, site_url, future in futures:
+            try:
+                result = future.result()
+                dirs_count = result.get('created_directories', 0)
+                directories_found += dirs_count
+                logger.info(
+                    "✓ [%d/%d] 站点扫描完成: %s - 发现 %d 个目录",
+                    idx, total_sites, site_url, dirs_count
                 )
+            except Exception as exc:
+                failed_sites.append(site_url)
+                if 'timeout' in str(exc).lower():
+                    logger.warning(
+                        "⚠️ [%d/%d] 站点扫描超时: %s - 错误: %s",
+                        idx, total_sites, site_url, exc
+                    )
+                else:
+                    logger.error(
+                        "✗ [%d/%d] 站点扫描失败: %s - 错误: %s",
+                        idx, total_sites, site_url, exc
+                    )
 
     return directories_found, failed_sites
 
@@ -381,9 +384,8 @@ def _run_scans_concurrently(
     return total_directories, processed_sites_count, failed_sites
 
 
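Every flow touched by this patch swaps Prefect's `task.submit()` / `future.result()` pair for the same submit-and-collect shape shown above. A minimal, runnable sketch of the pattern, with a hypothetical `run_tool` standing in for tasks like `run_and_stream_save_directories_task`:

    from concurrent.futures import ThreadPoolExecutor

    def run_tool(site_url: str) -> dict:
        # Hypothetical task: pretend each site yields one directory.
        return {'created_directories': 1}

    def execute_batch(site_urls: list[str]) -> tuple[int, list[str]]:
        found, failed = 0, []
        if not site_urls:  # ThreadPoolExecutor rejects max_workers=0
            return found, failed
        with ThreadPoolExecutor(max_workers=len(site_urls)) as executor:
            # Submit everything first, then collect; result() blocks until the
            # worker finishes and re-raises any exception it threw.
            futures = [(url, executor.submit(run_tool, url)) for url in site_urls]
            for url, future in futures:
                try:
                    found += future.result().get('created_directories', 0)
                except Exception:
                    failed.append(url)
        return found, failed

    print(execute_batch(['https://a.example', 'https://b.example']))  # (2, [])
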
-@flow( +@scan_flow( name="directory_scan", - log_prints=True, on_running=[on_scan_flow_running], on_completion=[on_scan_flow_completed], on_failure=[on_scan_flow_failed], diff --git a/backend/apps/scan/flows/fingerprint_detect_flow.py b/backend/apps/scan/flows/fingerprint_detect_flow.py index e8a5fabd..9cb82ca6 100644 --- a/backend/apps/scan/flows/fingerprint_detect_flow.py +++ b/backend/apps/scan/flows/fingerprint_detect_flow.py @@ -16,8 +16,7 @@ from datetime import datetime from pathlib import Path from typing import Optional -from prefect import flow - +from apps.scan.decorators import scan_flow from apps.scan.handlers.scan_flow_handlers import ( on_scan_flow_completed, on_scan_flow_failed, @@ -193,9 +192,8 @@ def _aggregate_results(tool_stats: dict) -> dict: } -@flow( +@scan_flow( name="fingerprint_detect", - log_prints=True, on_running=[on_scan_flow_running], on_completion=[on_scan_flow_completed], on_failure=[on_scan_flow_failed], diff --git a/backend/apps/scan/flows/initiate_scan_flow.py b/backend/apps/scan/flows/initiate_scan_flow.py index 4b54a0fb..915d535c 100644 --- a/backend/apps/scan/flows/initiate_scan_flow.py +++ b/backend/apps/scan/flows/initiate_scan_flow.py @@ -5,13 +5,13 @@ 职责: - 使用 FlowOrchestrator 解析 YAML 配置 -- 在 Prefect Flow 中执行子 Flow(Subflow) +- 执行子 Flow(Subflow) - 按照 YAML 顺序编排工作流 - 根据 scan_mode 创建对应的 Provider - 不包含具体业务逻辑(由 Tasks 和 FlowOrchestrator 实现) 架构: -- Flow: Prefect 编排层(本文件) +- Flow: 编排层(本文件) - FlowOrchestrator: 配置解析和执行计划(apps/scan/services/) - Tasks: 执行层(apps/scan/tasks/) - Handlers: 状态管理(apps/scan/handlers/) @@ -19,13 +19,12 @@ # Django 环境初始化(导入即生效) # 注意:动态扫描容器应使用 run_initiate_scan.py 启动,以便在导入前设置环境变量 -import apps.common.prefect_django_setup # noqa: F401 +import apps.common.django_setup # noqa: F401 import logging +from concurrent.futures import ThreadPoolExecutor -from prefect import flow, task -from prefect.futures import wait - +from apps.scan.decorators import scan_flow from apps.scan.handlers import ( on_initiate_scan_flow_running, on_initiate_scan_flow_completed, @@ -37,13 +36,6 @@ from apps.scan.utils import setup_scan_workspace logger = logging.getLogger(__name__) -@task(name="run_subflow") -def _run_subflow_task(scan_type: str, flow_func, flow_kwargs: dict): - """包装子 Flow 的 Task,用于在并行阶段并发执行子 Flow。""" - logger.info("开始执行子 Flow: %s", scan_type) - return flow_func(**flow_kwargs) - - def _create_provider(scan, target_id: int, scan_id: int): """根据 scan_mode 创建对应的 Provider""" from apps.scan.models import Scan @@ -83,40 +75,36 @@ def _execute_sequential_flows(valid_flows: list, results: dict, executed_flows: def _execute_parallel_flows(valid_flows: list, results: dict, executed_flows: list): - """并行执行 Flow 列表""" - futures = [] - for scan_type, flow_func, flow_kwargs in valid_flows: - logger.info("=" * 60) - logger.info("提交并行子 Flow 任务: %s", scan_type) - logger.info("=" * 60) - future = _run_subflow_task.submit( - scan_type=scan_type, - flow_func=flow_func, - flow_kwargs=flow_kwargs, - ) - futures.append((scan_type, future)) - - if not futures: + """并行执行 Flow 列表(使用 ThreadPoolExecutor)""" + if not valid_flows: return - wait([f for _, f in futures]) + logger.info("并行执行 %d 个 Flow", len(valid_flows)) - for scan_type, future in futures: - try: - result = future.result() - executed_flows.append(scan_type) - results[scan_type] = result - logger.info("✓ %s 执行成功", scan_type) - except Exception as e: - logger.warning("%s 执行失败: %s", scan_type, e) - executed_flows.append(f"{scan_type} (失败)") - results[scan_type] = {'success': False, 'error': str(e)} + with 
ThreadPoolExecutor(max_workers=len(valid_flows)) as executor: + futures = [] + for scan_type, flow_func, flow_kwargs in valid_flows: + logger.info("=" * 60) + logger.info("提交并行子 Flow 任务: %s", scan_type) + logger.info("=" * 60) + future = executor.submit(flow_func, **flow_kwargs) + futures.append((scan_type, future)) + + # 收集结果 + for scan_type, future in futures: + try: + result = future.result() + executed_flows.append(scan_type) + results[scan_type] = result + logger.info("✓ %s 执行成功", scan_type) + except Exception as e: + logger.warning("%s 执行失败: %s", scan_type, e) + executed_flows.append(f"{scan_type} (失败)") + results[scan_type] = {'success': False, 'error': str(e)} -@flow( +@scan_flow( name='initiate_scan', - description='扫描任务初始化流程', - log_prints=True, on_running=[on_initiate_scan_flow_running], on_completion=[on_initiate_scan_flow_completed], on_failure=[on_initiate_scan_flow_failed], diff --git a/backend/apps/scan/flows/port_scan_flow.py b/backend/apps/scan/flows/port_scan_flow.py index c71fec8e..87a33714 100644 --- a/backend/apps/scan/flows/port_scan_flow.py +++ b/backend/apps/scan/flows/port_scan_flow.py @@ -15,8 +15,7 @@ import subprocess from datetime import datetime from pathlib import Path -from prefect import flow - +from apps.scan.decorators import scan_flow from apps.scan.handlers.scan_flow_handlers import ( on_scan_flow_completed, on_scan_flow_failed, @@ -283,9 +282,8 @@ def _run_scans_sequentially( return tool_stats, processed_records, successful_tool_names, failed_tools -@flow( +@scan_flow( name="port_scan", - log_prints=True, on_running=[on_scan_flow_running], on_completion=[on_scan_flow_completed], on_failure=[on_scan_flow_failed], diff --git a/backend/apps/scan/flows/screenshot_flow.py b/backend/apps/scan/flows/screenshot_flow.py index 55702fb3..02e100d5 100644 --- a/backend/apps/scan/flows/screenshot_flow.py +++ b/backend/apps/scan/flows/screenshot_flow.py @@ -9,8 +9,7 @@ import logging -from prefect import flow - +from apps.scan.decorators import scan_flow from apps.scan.handlers.scan_flow_handlers import ( on_scan_flow_completed, on_scan_flow_failed, @@ -34,9 +33,9 @@ def _parse_screenshot_config(enabled_tools: dict) -> dict: def _collect_urls_from_provider(provider: TargetProvider) -> tuple[list[str], str]: """ 从 Provider 收集网站 URL(带回退逻辑) - + 优先级:WebSite → HostPortMapping → Default URL - + Returns: tuple: (urls, source) - urls: URL 列表 @@ -75,9 +74,8 @@ def _build_empty_result(scan_id: int, target_name: str) -> dict: } -@flow( +@scan_flow( name="screenshot", - log_prints=True, on_running=[on_scan_flow_running], on_completion=[on_scan_flow_completed], on_failure=[on_scan_flow_failed], diff --git a/backend/apps/scan/flows/site_scan_flow.py b/backend/apps/scan/flows/site_scan_flow.py index 725ff76b..69589cbf 100644 --- a/backend/apps/scan/flows/site_scan_flow.py +++ b/backend/apps/scan/flows/site_scan_flow.py @@ -17,10 +17,8 @@ from datetime import datetime from pathlib import Path from typing import Optional -from prefect import flow +from apps.scan.decorators import scan_flow -# Django 环境初始化(导入即生效) -from apps.common.prefect_django_setup import setup_django_for_prefect # noqa: F401 from apps.scan.handlers.scan_flow_handlers import ( on_scan_flow_completed, on_scan_flow_failed, @@ -314,9 +312,8 @@ def _validate_flow_params( raise ValueError("scan_workspace_dir 不能为空") -@flow( +@scan_flow( name="site_scan", - log_prints=True, on_running=[on_scan_flow_running], on_completion=[on_scan_flow_completed], on_failure=[on_scan_flow_failed], diff --git 
a/backend/apps/scan/flows/subdomain_discovery_flow.py b/backend/apps/scan/flows/subdomain_discovery_flow.py index d2b58e45..a76ee111 100644 --- a/backend/apps/scan/flows/subdomain_discovery_flow.py +++ b/backend/apps/scan/flows/subdomain_discovery_flow.py @@ -26,10 +26,12 @@ from datetime import datetime from pathlib import Path from typing import Optional -from prefect import flow +from concurrent.futures import ThreadPoolExecutor + +from apps.scan.decorators import scan_flow # Django 环境初始化(导入即生效,pylint: disable=unused-import) -from apps.common.prefect_django_setup import setup_django_for_prefect # noqa: F401 +from apps.common.django_setup import setup_django # noqa: F401 from apps.common.normalizer import normalize_domain from apps.common.validators import validate_domain from apps.engine.services.wordlist_service import WordlistService @@ -178,7 +180,9 @@ def _run_scans_parallel( timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') futures = {} failed_tools = [] + tool_params = {} # 存储每个工具的参数 + # 准备所有工具的参数 for tool_name, tool_config in enabled_tools.items(): short_uuid = uuid.uuid4().hex[:4] output_file = str(result_dir / f"{tool_name}_{timestamp}_{short_uuid}.txt") @@ -207,40 +211,51 @@ def _run_scans_parallel( logger.debug("提交任务 - 工具: %s, 超时: %ds, 输出: %s", tool_name, timeout, output_file) user_log(scan_id, "subdomain_discovery", f"Running {tool_name}: {command}") - future = run_subdomain_discovery_task.submit( - tool=tool_name, - command=command, - timeout=timeout, - output_file=output_file - ) - futures[tool_name] = future + tool_params[tool_name] = { + 'command': command, + 'timeout': timeout, + 'output_file': output_file + } - if not futures: + if not tool_params: logger.warning("所有扫描工具均无法启动 - 目标: %s", domain_name) return [], [{'tool': 'all', 'reason': '所有工具均无法启动'}], [] - result_files = [] - for tool_name, future in futures.items(): - try: - result = future.result() - if result: - result_files.append(result) - logger.info("✓ 扫描工具 %s 执行成功: %s", tool_name, result) - user_log(scan_id, "subdomain_discovery", f"{tool_name} completed") - else: - failed_tools.append({'tool': tool_name, 'reason': '未生成结果文件'}) - logger.warning("⚠️ 扫描工具 %s 未生成结果文件", tool_name) - user_log(scan_id, "subdomain_discovery", f"{tool_name} failed: no output", "error") - except (subprocess.TimeoutExpired, OSError) as e: - failed_tools.append({'tool': tool_name, 'reason': str(e)}) - logger.warning("⚠️ 扫描工具 %s 执行失败: %s", tool_name, e) - user_log(scan_id, "subdomain_discovery", f"{tool_name} failed: {e}", "error") + # 使用 ThreadPoolExecutor 并行执行 + with ThreadPoolExecutor(max_workers=len(tool_params)) as executor: + for tool_name, params in tool_params.items(): + future = executor.submit( + run_subdomain_discovery_task, + tool=tool_name, + command=params['command'], + timeout=params['timeout'], + output_file=params['output_file'] + ) + futures[tool_name] = future - successful_tools = [name for name in futures if name not in [f['tool'] for f in failed_tools]] + # 收集结果 + result_files = [] + for tool_name, future in futures.items(): + try: + result = future.result() + if result: + result_files.append(result) + logger.info("✓ 扫描工具 %s 执行成功: %s", tool_name, result) + user_log(scan_id, "subdomain_discovery", f"{tool_name} completed") + else: + failed_tools.append({'tool': tool_name, 'reason': '未生成结果文件'}) + logger.warning("⚠️ 扫描工具 %s 未生成结果文件", tool_name) + user_log(scan_id, "subdomain_discovery", f"{tool_name} failed: no output", "error") + except (subprocess.TimeoutExpired, OSError) as e: + failed_tools.append({'tool': 
tool_name, 'reason': str(e)}) + logger.warning("⚠️ 扫描工具 %s 执行失败: %s", tool_name, e) + user_log(scan_id, "subdomain_discovery", f"{tool_name} failed: {e}", "error") + + successful_tools = [name for name in tool_params if name not in [f['tool'] for f in failed_tools]] logger.info( "✓ 扫描工具并行执行完成 - 成功: %d/%d", - len(result_files), len(futures) + len(result_files), len(tool_params) ) return result_files, failed_tools, successful_tools @@ -531,9 +546,8 @@ def _empty_result(scan_id: int, target: str, scan_workspace_dir: str) -> dict: } -@flow( +@scan_flow( name="subdomain_discovery", - log_prints=True, on_running=[on_scan_flow_running], on_completion=[on_scan_flow_completed], on_failure=[on_scan_flow_failed], diff --git a/backend/apps/scan/flows/url_fetch/domain_name_url_fetch_flow.py b/backend/apps/scan/flows/url_fetch/domain_name_url_fetch_flow.py index 15db18e1..5463bf32 100644 --- a/backend/apps/scan/flows/url_fetch/domain_name_url_fetch_flow.py +++ b/backend/apps/scan/flows/url_fetch/domain_name_url_fetch_flow.py @@ -11,17 +11,14 @@ - IP 和 CIDR 类型会自动跳过(被动收集工具不支持) """ -# Django 环境初始化 -from apps.common.prefect_django_setup import setup_django_for_prefect - import logging import uuid +from concurrent.futures import ThreadPoolExecutor from datetime import datetime from pathlib import Path from typing import Dict -from prefect import flow - +from apps.scan.decorators import scan_flow from apps.common.validators import validate_domain from apps.scan.tasks.url_fetch import run_url_fetcher_task from apps.scan.utils import build_scan_command @@ -30,7 +27,7 @@ from apps.scan.utils import build_scan_command logger = logging.getLogger(__name__) -@flow(name="domain_name_url_fetch_flow", log_prints=True) +@scan_flow(name="domain_name_url_fetch_flow") def domain_name_url_fetch_flow( scan_id: int, target_id: int, @@ -77,7 +74,7 @@ def domain_name_url_fetch_flow( if target and target.type != Target.TargetType.DOMAIN: logger.info( - "跳过 domain_name URL 获取: Target 类型为 %s (ID=%d, Name=%s),waymore 等工具仅适用于域名类型", + "跳过 domain_name URL 获取: Target 类型为 %s (ID=%d, Name=%s)", target.type, target_id, target_name ) return { @@ -96,10 +93,10 @@ def domain_name_url_fetch_flow( ", ".join(domain_name_tools.keys()) if domain_name_tools else "无", ) - futures: dict[str, object] = {} + tool_params = {} # 存储每个工具的参数 failed_tools: list[dict] = [] - # 提交所有基于域名的 URL 获取任务 + # 准备所有工具的参数 for tool_name, tool_config in domain_name_tools.items(): timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") short_uuid = uuid.uuid4().hex[:4] @@ -153,46 +150,62 @@ def domain_name_url_fetch_flow( # 记录工具开始执行日志 user_log(scan_id, "url_fetch", f"Running {tool_name}: {command}") - future = run_url_fetcher_task.submit( - tool_name=tool_name, - command=command, - timeout=timeout, - output_file=output_file, - ) - futures[tool_name] = future + tool_params[tool_name] = { + 'command': command, + 'timeout': timeout, + 'output_file': output_file + } result_files: list[str] = [] successful_tools: list[str] = [] - # 收集执行结果 - for tool_name, future in futures.items(): - try: - result = future.result() - if result and result.get("success"): - result_files.append(result["output_file"]) - successful_tools.append(tool_name) - url_count = result.get("url_count", 0) - logger.info( - "✓ 工具 %s 执行成功 - 发现 URL: %d", - tool_name, - url_count, + # 使用 ThreadPoolExecutor 并行执行 + if tool_params: + with ThreadPoolExecutor(max_workers=len(tool_params)) as executor: + futures = {} + for tool_name, params in tool_params.items(): + future = executor.submit( + run_url_fetcher_task, + 
tool_name=tool_name, + command=params['command'], + timeout=params['timeout'], + output_file=params['output_file'], ) - user_log(scan_id, "url_fetch", f"{tool_name} completed: found {url_count} urls") - else: - reason = "未生成结果或无有效 URL" - failed_tools.append( - { - "tool": tool_name, - "reason": reason, - } - ) - logger.warning("⚠️ 工具 %s 未生成有效结果", tool_name) - user_log(scan_id, "url_fetch", f"{tool_name} failed: {reason}", "error") - except Exception as e: - reason = str(e) - failed_tools.append({"tool": tool_name, "reason": reason}) - logger.warning("⚠️ 工具 %s 执行失败: %s", tool_name, e) - user_log(scan_id, "url_fetch", f"{tool_name} failed: {reason}", "error") + futures[tool_name] = future + + # 收集执行结果 + for tool_name, future in futures.items(): + try: + result = future.result() + if result and result.get("success"): + result_files.append(result["output_file"]) + successful_tools.append(tool_name) + url_count = result.get("url_count", 0) + logger.info( + "✓ 工具 %s 执行成功 - 发现 URL: %d", + tool_name, + url_count, + ) + user_log( + scan_id, "url_fetch", + f"{tool_name} completed: found {url_count} urls" + ) + else: + reason = "未生成结果或无有效 URL" + failed_tools.append({"tool": tool_name, "reason": reason}) + logger.warning("⚠️ 工具 %s 未生成有效结果", tool_name) + user_log( + scan_id, "url_fetch", + f"{tool_name} failed: {reason}", "error" + ) + except Exception as e: + reason = str(e) + failed_tools.append({"tool": tool_name, "reason": reason}) + logger.warning("⚠️ 工具 %s 执行失败: %s", tool_name, e) + user_log( + scan_id, "url_fetch", + f"{tool_name} failed: {reason}", "error" + ) logger.info( "基于 domain_name 的 URL 获取完成 - 成功工具: %s, 失败工具: %s", diff --git a/backend/apps/scan/flows/url_fetch/main_flow.py b/backend/apps/scan/flows/url_fetch/main_flow.py index 1bda1445..4f7bd513 100644 --- a/backend/apps/scan/flows/url_fetch/main_flow.py +++ b/backend/apps/scan/flows/url_fetch/main_flow.py @@ -14,8 +14,7 @@ import logging from datetime import datetime from pathlib import Path -from prefect import flow - +from apps.scan.decorators import scan_flow from apps.scan.handlers.scan_flow_handlers import ( on_scan_flow_completed, on_scan_flow_failed, @@ -231,9 +230,8 @@ def _save_urls_to_database(merged_file: str, scan_id: int, target_id: int) -> in return saved_count -@flow( +@scan_flow( name="url_fetch", - log_prints=True, on_running=[on_scan_flow_running], on_completion=[on_scan_flow_completed], on_failure=[on_scan_flow_failed], diff --git a/backend/apps/scan/flows/url_fetch/sites_url_fetch_flow.py b/backend/apps/scan/flows/url_fetch/sites_url_fetch_flow.py index bd1b7aeb..6d4f373b 100644 --- a/backend/apps/scan/flows/url_fetch/sites_url_fetch_flow.py +++ b/backend/apps/scan/flows/url_fetch/sites_url_fetch_flow.py @@ -6,14 +6,10 @@ URL 爬虫 Flow 输入:sites_file(站点 URL 列表) """ -# Django 环境初始化 -from apps.common.prefect_django_setup import setup_django_for_prefect - import logging from pathlib import Path -from prefect import flow - +from apps.scan.decorators import scan_flow from .utils import run_tools_parallel logger = logging.getLogger(__name__) @@ -25,32 +21,32 @@ def _export_sites_file( ) -> tuple[str, int]: """ 导出站点 URL 列表到文件 - + Args: output_dir: 输出目录 provider: TargetProvider 实例 - + Returns: tuple: (file_path, count) """ from apps.scan.tasks.url_fetch import export_sites_task - + output_file = str(output_dir / "sites.txt") result = export_sites_task( output_file=output_file, provider=provider ) - + count = result['asset_count'] if count > 0: logger.info("✓ 站点列表导出完成 - 数量: %d", count) else: logger.warning("站点列表为空,爬虫可能无法正常工作") 
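+    # 即使站点数为 0,也会返回文件路径与计数,是否跳过爬虫由调用方决定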
- + return output_file, count -@flow(name="sites_url_fetch_flow", log_prints=True) +@scan_flow(name="sites_url_fetch_flow") def sites_url_fetch_flow( scan_id: int, target_id: int, @@ -100,7 +96,7 @@ def sites_url_fetch_flow( output_dir=output_path, provider=provider ) - + # 默认值模式下,即使原本没有站点,也会有默认 URL 作为输入 if sites_count == 0: logger.warning("没有可用的站点,跳过爬虫") @@ -111,7 +107,7 @@ def sites_url_fetch_flow( 'successful_tools': [], 'sites_count': 0 } - + # Step 2: 并行执行爬虫工具 result_files, failed_tools, successful_tools = run_tools_parallel( tools=enabled_tools, @@ -120,12 +116,12 @@ def sites_url_fetch_flow( output_dir=output_path, scan_id=scan_id ) - + logger.info( "✓ 爬虫完成 - 成功: %d/%d, 结果文件: %d", len(successful_tools), len(enabled_tools), len(result_files) ) - + return { 'success': True, 'result_files': result_files, @@ -133,7 +129,7 @@ def sites_url_fetch_flow( 'successful_tools': successful_tools, 'sites_count': sites_count } - + except Exception as e: logger.error("URL 爬虫失败: %s", e, exc_info=True) return { diff --git a/backend/apps/scan/flows/url_fetch/utils.py b/backend/apps/scan/flows/url_fetch/utils.py index 8a11bbc2..8d6db137 100644 --- a/backend/apps/scan/flows/url_fetch/utils.py +++ b/backend/apps/scan/flows/url_fetch/utils.py @@ -5,6 +5,7 @@ URL Fetch 共享工具函数 import logging import subprocess import uuid +from concurrent.futures import ThreadPoolExecutor from datetime import datetime from pathlib import Path @@ -21,13 +22,13 @@ def calculate_timeout_by_line_count( ) -> int: """ 根据文件行数自动计算超时时间 - + Args: tool_config: 工具配置(保留参数,未来可能用于更复杂的计算) file_path: 输入文件路径 base_per_time: 每行的基础时间(秒) min_timeout: 最小超时时间(秒),默认60秒 - + Returns: int: 计算出的超时时间(秒),不低于 min_timeout """ @@ -64,7 +65,7 @@ def prepare_tool_execution( ) -> dict: """ 准备单个工具的执行参数 - + Args: tool_name: 工具名称 tool_config: 工具配置 @@ -72,7 +73,7 @@ def prepare_tool_execution( input_type: 输入类型(domains_file 或 sites_file) output_dir: 输出目录 scan_type: 扫描类型 - + Returns: dict: 执行参数,包含 command, input_file, output_file, timeout 或包含 error 键表示失败 @@ -110,7 +111,7 @@ def prepare_tool_execution( # 4. 
计算超时时间(支持 auto 和显式整数) raw_timeout = tool_config.get("timeout", 3600) timeout = 3600 - + if isinstance(raw_timeout, str) and raw_timeout == "auto": try: # katana / waymore 每个站点需要更长时间 @@ -157,24 +158,24 @@ def run_tools_parallel( ) -> tuple[list, list, list]: """ 并行执行工具列表 - + Args: tools: 工具配置字典 {tool_name: tool_config} input_file: 输入文件路径 input_type: 输入类型 output_dir: 输出目录 scan_id: 扫描任务 ID(用于记录日志) - + Returns: tuple: (result_files, failed_tools, successful_tool_names) """ from apps.scan.tasks.url_fetch import run_url_fetcher_task from apps.scan.utils import user_log - futures: dict[str, object] = {} + tool_params = {} # 存储每个工具的参数 failed_tools: list[dict] = [] - # 提交所有工具的并行任务 + # 准备所有工具的参数 for tool_name, tool_config in tools.items(): exec_params = prepare_tool_execution( tool_name=tool_name, @@ -198,44 +199,54 @@ def run_tools_parallel( # 记录工具开始执行日志 user_log(scan_id, "url_fetch", f"Running {tool_name}: {exec_params['command']}") - # 提交并行任务 - future = run_url_fetcher_task.submit( - tool_name=tool_name, - command=exec_params["command"], - timeout=exec_params["timeout"], - output_file=exec_params["output_file"], - ) - futures[tool_name] = future + tool_params[tool_name] = exec_params - # 收集执行结果 + # 使用 ThreadPoolExecutor 并行执行 result_files = [] - for tool_name, future in futures.items(): - try: - result = future.result() - if result and result['success']: - result_files.append(result['output_file']) - url_count = result['url_count'] - logger.info( - "✓ 工具 %s 执行成功 - 发现 URL: %d", - tool_name, url_count + if tool_params: + with ThreadPoolExecutor(max_workers=len(tool_params)) as executor: + futures = {} + for tool_name, params in tool_params.items(): + future = executor.submit( + run_url_fetcher_task, + tool_name=tool_name, + command=params["command"], + timeout=params["timeout"], + output_file=params["output_file"], ) - user_log(scan_id, "url_fetch", f"{tool_name} completed: found {url_count} urls") - else: - reason = '未生成结果或无有效URL' - failed_tools.append({ - 'tool': tool_name, - 'reason': reason - }) - logger.warning("⚠️ 工具 %s 未生成有效结果", tool_name) - user_log(scan_id, "url_fetch", f"{tool_name} failed: {reason}", "error") - except Exception as e: - reason = str(e) - failed_tools.append({ - 'tool': tool_name, - 'reason': reason - }) - logger.warning("⚠️ 工具 %s 执行失败: %s", tool_name, e) - user_log(scan_id, "url_fetch", f"{tool_name} failed: {reason}", "error") + futures[tool_name] = future + + # 收集执行结果 + for tool_name, future in futures.items(): + try: + result = future.result() + if result and result['success']: + result_files.append(result['output_file']) + url_count = result['url_count'] + logger.info( + "✓ 工具 %s 执行成功 - 发现 URL: %d", + tool_name, url_count + ) + user_log( + scan_id, "url_fetch", + f"{tool_name} completed: found {url_count} urls" + ) + else: + reason = '未生成结果或无有效URL' + failed_tools.append({'tool': tool_name, 'reason': reason}) + logger.warning("⚠️ 工具 %s 未生成有效结果", tool_name) + user_log( + scan_id, "url_fetch", + f"{tool_name} failed: {reason}", "error" + ) + except Exception as e: + reason = str(e) + failed_tools.append({'tool': tool_name, 'reason': reason}) + logger.warning("⚠️ 工具 %s 执行失败: %s", tool_name, e) + user_log( + scan_id, "url_fetch", + f"{tool_name} failed: {reason}", "error" + ) # 计算成功的工具列表 failed_tool_names = [f['tool'] for f in failed_tools] diff --git a/backend/apps/scan/flows/vuln_scan/endpoints_vuln_scan_flow.py b/backend/apps/scan/flows/vuln_scan/endpoints_vuln_scan_flow.py index daa05e0c..fd45499a 100644 --- 
a/backend/apps/scan/flows/vuln_scan/endpoints_vuln_scan_flow.py +++ b/backend/apps/scan/flows/vuln_scan/endpoints_vuln_scan_flow.py @@ -1,17 +1,13 @@ -from apps.common.prefect_django_setup import setup_django_for_prefect +""" +基于 Endpoint 的漏洞扫描 Flow +""" import logging +from concurrent.futures import ThreadPoolExecutor from datetime import datetime -from pathlib import Path from typing import Dict -from prefect import flow - -from apps.scan.handlers.scan_flow_handlers import ( - on_scan_flow_running, - on_scan_flow_completed, - on_scan_flow_failed, -) +from apps.scan.decorators import scan_flow from apps.scan.utils import build_scan_command, ensure_nuclei_templates_local, user_log from apps.scan.tasks.vuln_scan import ( export_endpoints_task, @@ -25,13 +21,7 @@ from .utils import calculate_timeout_by_line_count logger = logging.getLogger(__name__) - - - -@flow( - name="endpoints_vuln_scan_flow", - log_prints=True, -) +@scan_flow(name="endpoints_vuln_scan_flow") def endpoints_vuln_scan_flow( scan_id: int, target_id: int, @@ -82,12 +72,9 @@ def endpoints_vuln_scan_flow( logger.info("Endpoint 导出完成,共 %d 条,开始执行漏洞扫描", total_endpoints) tool_results: Dict[str, dict] = {} + tool_params: Dict[str, dict] = {} # 存储每个工具的参数 - # Step 2: 并行执行每个漏洞扫描工具(目前主要是 Dalfox) - # 1)先为每个工具 submit Prefect Task,让 Worker 并行调度 - # 2)再统一收集各自的结果,组装成 tool_results - tool_futures: Dict[str, dict] = {} - + # Step 2: 准备每个漏洞扫描工具的参数 for tool_name, tool_config in enabled_tools.items(): # Nuclei 需要先确保本地模板存在(支持多个模板仓库) template_args = "" @@ -144,102 +131,105 @@ def endpoints_vuln_scan_flow( timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") log_file = vuln_scan_dir / f"{tool_name}_{timestamp}.log" - # Dalfox XSS 使用流式任务,一边解析一边保存漏洞结果 + logger.info("开始执行漏洞扫描工具 %s", tool_name) + user_log(scan_id, "vuln_scan", f"Running {tool_name}: {command}") + + # 确定工具类型 if tool_name == "dalfox_xss": - logger.info("开始执行漏洞扫描工具 %s(流式保存漏洞结果,已提交任务)", tool_name) - user_log(scan_id, "vuln_scan", f"Running {tool_name}: {command}") - future = run_and_stream_save_dalfox_vulns_task.submit( - cmd=command, - tool_name=tool_name, - scan_id=scan_id, - target_id=target_id, - cwd=str(vuln_scan_dir), - shell=True, - batch_size=1, - timeout=timeout, - log_file=str(log_file), - ) - - tool_futures[tool_name] = { - "future": future, - "command": command, - "timeout": timeout, - "log_file": str(log_file), - "mode": "streaming", - } + mode = "dalfox" elif tool_name == "nuclei": - # Nuclei 使用流式任务 - logger.info("开始执行漏洞扫描工具 %s(流式保存漏洞结果,已提交任务)", tool_name) - user_log(scan_id, "vuln_scan", f"Running {tool_name}: {command}") - future = run_and_stream_save_nuclei_vulns_task.submit( - cmd=command, - tool_name=tool_name, - scan_id=scan_id, - target_id=target_id, - cwd=str(vuln_scan_dir), - shell=True, - batch_size=1, - timeout=timeout, - log_file=str(log_file), - ) - - tool_futures[tool_name] = { - "future": future, - "command": command, - "timeout": timeout, - "log_file": str(log_file), - "mode": "streaming", - } + mode = "nuclei" else: - # 其他工具仍使用非流式执行逻辑 - logger.info("开始执行漏洞扫描工具 %s(已提交任务)", tool_name) - user_log(scan_id, "vuln_scan", f"Running {tool_name}: {command}") - future = run_vuln_tool_task.submit( - tool_name=tool_name, - command=command, - timeout=timeout, - log_file=str(log_file), - ) + mode = "normal" - tool_futures[tool_name] = { - "future": future, - "command": command, - "timeout": timeout, - "log_file": str(log_file), - "mode": "normal", - } + tool_params[tool_name] = { + "command": command, + "timeout": timeout, + "log_file": str(log_file), + "mode": mode, + } - # 
统一收集所有工具的执行结果 - for tool_name, meta in tool_futures.items(): - future = meta["future"] - try: - result = future.result() + # Step 3: 使用 ThreadPoolExecutor 并行执行 + if tool_params: + with ThreadPoolExecutor(max_workers=len(tool_params)) as executor: + futures = {} + for tool_name, params in tool_params.items(): + if params["mode"] == "dalfox": + future = executor.submit( + run_and_stream_save_dalfox_vulns_task, + cmd=params["command"], + tool_name=tool_name, + scan_id=scan_id, + target_id=target_id, + cwd=str(vuln_scan_dir), + shell=True, + batch_size=1, + timeout=params["timeout"], + log_file=params["log_file"], + ) + elif params["mode"] == "nuclei": + future = executor.submit( + run_and_stream_save_nuclei_vulns_task, + cmd=params["command"], + tool_name=tool_name, + scan_id=scan_id, + target_id=target_id, + cwd=str(vuln_scan_dir), + shell=True, + batch_size=1, + timeout=params["timeout"], + log_file=params["log_file"], + ) + else: + future = executor.submit( + run_vuln_tool_task, + tool_name=tool_name, + command=params["command"], + timeout=params["timeout"], + log_file=params["log_file"], + ) + futures[tool_name] = future - if meta["mode"] == "streaming": - created_vulns = result.get("created_vulns", 0) - tool_results[tool_name] = { - "command": meta["command"], - "timeout": meta["timeout"], - "processed_records": result.get("processed_records"), - "created_vulns": created_vulns, - "command_log_file": meta["log_file"], - } - logger.info("✓ 工具 %s 执行完成 - 漏洞: %d", tool_name, created_vulns) - user_log(scan_id, "vuln_scan", f"{tool_name} completed: found {created_vulns} vulnerabilities") - else: - tool_results[tool_name] = { - "command": meta["command"], - "timeout": meta["timeout"], - "duration": result.get("duration"), - "returncode": result.get("returncode"), - "command_log_file": result.get("command_log_file"), - } - logger.info("✓ 工具 %s 执行完成 - returncode=%s", tool_name, result.get("returncode")) - user_log(scan_id, "vuln_scan", f"{tool_name} completed") - except Exception as e: - reason = str(e) - logger.error("工具 %s 执行失败: %s", tool_name, e, exc_info=True) - user_log(scan_id, "vuln_scan", f"{tool_name} failed: {reason}", "error") + # 收集结果 + for tool_name, future in futures.items(): + params = tool_params[tool_name] + try: + result = future.result() + + if params["mode"] in ("dalfox", "nuclei"): + created_vulns = result.get("created_vulns", 0) + tool_results[tool_name] = { + "command": params["command"], + "timeout": params["timeout"], + "processed_records": result.get("processed_records"), + "created_vulns": created_vulns, + "command_log_file": params["log_file"], + } + logger.info( + "✓ 工具 %s 执行完成 - 漏洞: %d", + tool_name, created_vulns + ) + user_log( + scan_id, "vuln_scan", + f"{tool_name} completed: found {created_vulns} vulnerabilities" + ) + else: + tool_results[tool_name] = { + "command": params["command"], + "timeout": params["timeout"], + "duration": result.get("duration"), + "returncode": result.get("returncode"), + "command_log_file": result.get("command_log_file"), + } + logger.info( + "✓ 工具 %s 执行完成 - returncode=%s", + tool_name, result.get("returncode") + ) + user_log(scan_id, "vuln_scan", f"{tool_name} completed") + except Exception as e: + reason = str(e) + logger.error("工具 %s 执行失败: %s", tool_name, e, exc_info=True) + user_log(scan_id, "vuln_scan", f"{tool_name} failed: {reason}", "error") return { "success": True, diff --git a/backend/apps/scan/flows/vuln_scan/main_flow.py b/backend/apps/scan/flows/vuln_scan/main_flow.py index 694cb9ac..0f9e531d 100644 --- 
a/backend/apps/scan/flows/vuln_scan/main_flow.py +++ b/backend/apps/scan/flows/vuln_scan/main_flow.py @@ -4,8 +4,7 @@ import logging from typing import Dict, Tuple -from prefect import flow - +from apps.scan.decorators import scan_flow from apps.scan.handlers.scan_flow_handlers import ( on_scan_flow_running, on_scan_flow_completed, @@ -58,9 +57,8 @@ def _classify_vuln_tools( return endpoints_tools, websites_tools, other_tools -@flow( +@scan_flow( name="vuln_scan", - log_prints=True, on_running=[on_scan_flow_running], on_completion=[on_scan_flow_completed], on_failure=[on_scan_flow_failed], diff --git a/backend/apps/scan/flows/vuln_scan/websites_vuln_scan_flow.py b/backend/apps/scan/flows/vuln_scan/websites_vuln_scan_flow.py index c199a109..09e48313 100644 --- a/backend/apps/scan/flows/vuln_scan/websites_vuln_scan_flow.py +++ b/backend/apps/scan/flows/vuln_scan/websites_vuln_scan_flow.py @@ -9,8 +9,9 @@ import logging from datetime import datetime from typing import Dict -from prefect import flow +from concurrent.futures import ThreadPoolExecutor +from apps.scan.decorators import scan_flow from apps.scan.utils import build_scan_command, ensure_nuclei_templates_local, user_log from apps.scan.tasks.vuln_scan import run_and_stream_save_nuclei_vulns_task from apps.scan.tasks.vuln_scan.export_websites_task import export_websites_task @@ -19,10 +20,7 @@ from .utils import calculate_timeout_by_line_count logger = logging.getLogger(__name__) -@flow( - name="websites_vuln_scan_flow", - log_prints=True, -) +@scan_flow(name="websites_vuln_scan_flow") def websites_vuln_scan_flow( scan_id: int, target_id: int, @@ -134,47 +132,56 @@ def websites_vuln_scan_flow( logger.info("开始执行 %s 漏洞扫描(WebSite 模式)", tool_name) user_log(scan_id, "vuln_scan", f"Running {tool_name} (websites): {command}") - future = run_and_stream_save_nuclei_vulns_task.submit( - cmd=command, - tool_name=tool_name, - scan_id=scan_id, - target_id=target_id, - cwd=str(vuln_scan_dir), - shell=True, - batch_size=1, - timeout=timeout, - log_file=str(log_file), - ) - tool_futures[tool_name] = { - "future": future, "command": command, "timeout": timeout, "log_file": str(log_file), } - # 收集结果 - for tool_name, meta in tool_futures.items(): - future = meta["future"] - try: - result = future.result() - created_vulns = result.get("created_vulns", 0) - tool_results[tool_name] = { - "command": meta["command"], - "timeout": meta["timeout"], - "processed_records": result.get("processed_records"), - "created_vulns": created_vulns, - "command_log_file": meta["log_file"], - } - logger.info("✓ 工具 %s (websites) 执行完成 - 漏洞: %d", tool_name, created_vulns) - user_log( - scan_id, "vuln_scan", - f"{tool_name} (websites) completed: found {created_vulns} vulnerabilities" - ) - except Exception as e: - reason = str(e) - logger.error("工具 %s 执行失败: %s", tool_name, e, exc_info=True) - user_log(scan_id, "vuln_scan", f"{tool_name} failed: {reason}", "error") + # 使用 ThreadPoolExecutor 并行执行 + if tool_futures: + with ThreadPoolExecutor(max_workers=len(tool_futures)) as executor: + futures = {} + for tool_name, meta in tool_futures.items(): + future = executor.submit( + run_and_stream_save_nuclei_vulns_task, + cmd=meta["command"], + tool_name=tool_name, + scan_id=scan_id, + target_id=target_id, + cwd=str(vuln_scan_dir), + shell=True, + batch_size=1, + timeout=meta["timeout"], + log_file=meta["log_file"], + ) + futures[tool_name] = future + + # 收集结果 + for tool_name, future in futures.items(): + meta = tool_futures[tool_name] + try: + result = future.result() + created_vulns = 
result.get("created_vulns", 0) + tool_results[tool_name] = { + "command": meta["command"], + "timeout": meta["timeout"], + "processed_records": result.get("processed_records"), + "created_vulns": created_vulns, + "command_log_file": meta["log_file"], + } + logger.info( + "✓ 工具 %s (websites) 执行完成 - 漏洞: %d", + tool_name, created_vulns + ) + user_log( + scan_id, "vuln_scan", + f"{tool_name} (websites) completed: found {created_vulns} vulnerabilities" + ) + except Exception as e: + reason = str(e) + logger.error("工具 %s 执行失败: %s", tool_name, e, exc_info=True) + user_log(scan_id, "vuln_scan", f"{tool_name} failed: {reason}", "error") return { "success": True, diff --git a/backend/apps/scan/handlers/initiate_scan_flow_handlers.py b/backend/apps/scan/handlers/initiate_scan_flow_handlers.py index c7556b32..649faaac 100644 --- a/backend/apps/scan/handlers/initiate_scan_flow_handlers.py +++ b/backend/apps/scan/handlers/initiate_scan_flow_handlers.py @@ -12,57 +12,49 @@ initiate_scan_flow 状态处理器 """ import logging -from prefect import Flow -from prefect.client.schemas import FlowRun, State + +from apps.scan.decorators import FlowContext logger = logging.getLogger(__name__) - -def on_initiate_scan_flow_running(flow: Flow, flow_run: FlowRun, state: State) -> None: +def on_initiate_scan_flow_running(context: FlowContext) -> None: """ initiate_scan_flow 开始运行时的回调 - + 职责:更新 Scan 状态为 RUNNING + 发送通知 - - 触发时机: - - Prefect Flow 状态变为 Running 时自动触发 - - 在 Flow 函数体执行之前调用 - + Args: - flow: Prefect Flow 对象 - flow_run: Flow 运行实例 - state: Flow 当前状态 + context: Flow 执行上下文 """ - logger.info("🚀 initiate_scan_flow_running 回调开始运行 - Flow Run: %s", flow_run.id) - - scan_id = flow_run.parameters.get('scan_id') - target_name = flow_run.parameters.get('target_name') - engine_name = flow_run.parameters.get('engine_name') - scheduled_scan_name = flow_run.parameters.get('scheduled_scan_name') - + logger.info("🚀 initiate_scan_flow_running 回调开始运行 - Flow: %s", context.flow_name) + + scan_id = context.scan_id + target_name = context.parameters.get('target_name') + engine_name = context.parameters.get('engine_name') + scheduled_scan_name = context.parameters.get('scheduled_scan_name') + if not scan_id: logger.warning( - "Flow 参数中缺少 scan_id,跳过状态更新 - Flow Run: %s", - flow_run.id + "Flow 参数中缺少 scan_id,跳过状态更新 - Flow: %s", + context.flow_name ) return - + def _update_running_status(): from apps.scan.services import ScanService from apps.common.definitions import ScanStatus - + service = ScanService() success = service.update_status( - scan_id, + scan_id, ScanStatus.RUNNING ) - + if success: logger.info( - "✓ Flow 状态回调:扫描状态已更新为 RUNNING - Scan ID: %s, Flow Run: %s", - scan_id, - flow_run.id + "✓ Flow 状态回调:扫描状态已更新为 RUNNING - Scan ID: %s", + scan_id ) else: logger.error( @@ -70,15 +62,17 @@ def on_initiate_scan_flow_running(flow: Flow, flow_run: FlowRun, state: State) - scan_id ) return success - - # 执行状态更新(Repository 层已有 @auto_ensure_db_connection 保证连接可靠性) + + # 执行状态更新 _update_running_status() - + # 发送通知 logger.info("准备发送扫描开始通知 - Scan ID: %s, Target: %s", scan_id, target_name) try: - from apps.scan.notifications import create_notification, NotificationLevel, NotificationCategory - + from apps.scan.notifications import ( + create_notification, NotificationLevel, NotificationCategory + ) + # 根据是否为定时扫描构建不同的标题和消息 if scheduled_scan_name: title = f"⏰ {target_name} 扫描开始" @@ -86,7 +80,7 @@ def on_initiate_scan_flow_running(flow: Flow, flow_run: FlowRun, state: State) - else: title = f"{target_name} 扫描开始" message = f"引擎:{engine_name}" - + 
create_notification( title=title, message=message, @@ -95,47 +89,34 @@ def on_initiate_scan_flow_running(flow: Flow, flow_run: FlowRun, state: State) - ) logger.info("✓ 扫描开始通知已发送 - Scan ID: %s, Target: %s", scan_id, target_name) except Exception as e: - logger.error(f"发送扫描开始通知失败 - Scan ID: {scan_id}: {e}", exc_info=True) + logger.error("发送扫描开始通知失败 - Scan ID: %s: %s", scan_id, e, exc_info=True) -def on_initiate_scan_flow_completed(flow: Flow, flow_run: FlowRun, state: State) -> None: +def on_initiate_scan_flow_completed(context: FlowContext) -> None: """ initiate_scan_flow 成功完成时的回调 - + 职责:更新 Scan 状态为 COMPLETED - - 触发时机: - - Prefect Flow 正常执行完成时自动触发 - - 在 Flow 函数体返回之后调用 - - 策略:快速失败(Fail-Fast) - - Flow 成功完成 = 所有任务成功 → COMPLETED - - Flow 执行失败 = 有任务失败 → FAILED (由 on_failed 处理) - - 竞态条件处理: - - 如果用户已手动取消(状态已是 CANCELLED),保持终态,不覆盖 - + Args: - flow: Prefect Flow 对象 - flow_run: Flow 运行实例 - state: Flow 当前状态 + context: Flow 执行上下文 """ - logger.info("✅ initiate_scan_flow_completed 回调开始运行 - Flow Run: %s", flow_run.id) - - scan_id = flow_run.parameters.get('scan_id') - target_name = flow_run.parameters.get('target_name') - engine_name = flow_run.parameters.get('engine_name') - + logger.info("✅ initiate_scan_flow_completed 回调开始运行 - Flow: %s", context.flow_name) + + scan_id = context.scan_id + target_name = context.parameters.get('target_name') + engine_name = context.parameters.get('engine_name') + if not scan_id: return - + def _update_completed_status(): from apps.scan.services import ScanService from apps.common.definitions import ScanStatus from django.utils import timezone - + service = ScanService() - + # 仅在运行中时更新为 COMPLETED;其他状态保持不变 completed_updated = service.update_status_if_match( scan_id=scan_id, @@ -143,32 +124,30 @@ def on_initiate_scan_flow_completed(flow: Flow, flow_run: FlowRun, state: State) new_status=ScanStatus.COMPLETED, stopped_at=timezone.now() ) - + if completed_updated: logger.info( - "✓ Flow 状态回调:扫描状态已原子更新为 COMPLETED - Scan ID: %s, Flow Run: %s", - scan_id, - flow_run.id + "✓ Flow 状态回调:扫描状态已原子更新为 COMPLETED - Scan ID: %s", + scan_id ) return service.update_cached_stats(scan_id) else: logger.info( - "ℹ️ Flow 状态回调:状态未更新(可能已是终态)- Scan ID: %s, Flow Run: %s", - scan_id, - flow_run.id + "ℹ️ Flow 状态回调:状态未更新(可能已是终态)- Scan ID: %s", + scan_id ) return None - + # 执行状态更新并获取统计数据 stats = _update_completed_status() - - # 注意:物化视图刷新已迁移到 pg_ivm 增量维护,无需手动标记刷新 - + # 发送通知(包含统计摘要) logger.info("准备发送扫描完成通知 - Scan ID: %s, Target: %s", scan_id, target_name) try: - from apps.scan.notifications import create_notification, NotificationLevel, NotificationCategory - + from apps.scan.notifications import ( + create_notification, NotificationLevel, NotificationCategory + ) + # 构建通知消息 message = f"引擎:{engine_name}" if stats: @@ -180,11 +159,17 @@ def on_initiate_scan_flow_completed(flow: Flow, flow_run: FlowRun, state: State) results.append(f"目录: {stats.get('directories', 0)}") vulns_total = stats.get('vulns_total', 0) if vulns_total > 0: - results.append(f"漏洞: {vulns_total} (严重:{stats.get('vulns_critical', 0)} 高:{stats.get('vulns_high', 0)} 中:{stats.get('vulns_medium', 0)} 低:{stats.get('vulns_low', 0)})") + results.append( + f"漏洞: {vulns_total} " + f"(严重:{stats.get('vulns_critical', 0)} " + f"高:{stats.get('vulns_high', 0)} " + f"中:{stats.get('vulns_medium', 0)} " + f"低:{stats.get('vulns_low', 0)})" + ) else: results.append("漏洞: 0") message += f"\n结果:{' | '.join(results)}" - + create_notification( title=f"{target_name} 扫描完成", message=message, @@ -193,46 +178,35 @@ def on_initiate_scan_flow_completed(flow: Flow, 
flow_run: FlowRun, state: State) ) logger.info("✓ 扫描完成通知已发送 - Scan ID: %s, Target: %s", scan_id, target_name) except Exception as e: - logger.error(f"发送扫描完成通知失败 - Scan ID: {scan_id}: {e}", exc_info=True) + logger.error("发送扫描完成通知失败 - Scan ID: %s: %s", scan_id, e, exc_info=True) -def on_initiate_scan_flow_failed(flow: Flow, flow_run: FlowRun, state: State) -> None: +def on_initiate_scan_flow_failed(context: FlowContext) -> None: """ initiate_scan_flow 失败时的回调 - + 职责:更新 Scan 状态为 FAILED,并记录错误信息 - - 触发时机: - - Prefect Flow 执行失败或抛出异常时自动触发 - - Flow 超时、任务失败等所有失败场景都会触发此回调 - - 竞态条件处理: - - 如果用户已手动取消(状态已是 CANCELLED),保持终态,不覆盖 - + Args: - flow: Prefect Flow 对象 - flow_run: Flow 运行实例 - state: Flow 当前状态(包含错误信息) + context: Flow 执行上下文 """ - logger.info("❌ initiate_scan_flow_failed 回调开始运行 - Flow Run: %s", flow_run.id) - - scan_id = flow_run.parameters.get('scan_id') - target_name = flow_run.parameters.get('target_name') - engine_name = flow_run.parameters.get('engine_name') - + logger.info("❌ initiate_scan_flow_failed 回调开始运行 - Flow: %s", context.flow_name) + + scan_id = context.scan_id + target_name = context.parameters.get('target_name') + engine_name = context.parameters.get('engine_name') + error_message = context.error_message or "Flow 执行失败" + if not scan_id: return - + def _update_failed_status(): from apps.scan.services import ScanService from apps.common.definitions import ScanStatus from django.utils import timezone - + service = ScanService() - - # 提取错误信息 - error_message = str(state.message) if state.message else "Flow 执行失败" - + # 仅在运行中时更新为 FAILED;其他状态保持不变 failed_updated = service.update_status_if_match( scan_id=scan_id, @@ -240,33 +214,32 @@ def on_initiate_scan_flow_failed(flow: Flow, flow_run: FlowRun, state: State) -> new_status=ScanStatus.FAILED, stopped_at=timezone.now() ) - + if failed_updated: # 成功更新(正常失败流程) logger.error( - "✗ Flow 状态回调:扫描状态已原子更新为 FAILED - Scan ID: %s, Flow Run: %s, 错误: %s", + "✗ Flow 状态回调:扫描状态已原子更新为 FAILED - Scan ID: %s, 错误: %s", scan_id, - flow_run.id, error_message ) # 更新缓存统计数据(终态) service.update_cached_stats(scan_id) else: logger.warning( - "⚠️ Flow 状态回调:未更新任何记录(可能已被其他进程处理)- Scan ID: %s, Flow Run: %s", - scan_id, - flow_run.id + "⚠️ Flow 状态回调:未更新任何记录(可能已被其他进程处理)- Scan ID: %s", + scan_id ) return True - + # 执行状态更新 _update_failed_status() - + # 发送通知 logger.info("准备发送扫描失败通知 - Scan ID: %s, Target: %s", scan_id, target_name) try: - from apps.scan.notifications import create_notification, NotificationLevel, NotificationCategory - error_message = str(state.message) if state.message else "未知错误" + from apps.scan.notifications import ( + create_notification, NotificationLevel, NotificationCategory + ) message = f"引擎:{engine_name}\n错误:{error_message}" create_notification( title=f"{target_name} 扫描失败", @@ -276,4 +249,4 @@ def on_initiate_scan_flow_failed(flow: Flow, flow_run: FlowRun, state: State) -> ) logger.info("✓ 扫描失败通知已发送 - Scan ID: %s, Target: %s", scan_id, target_name) except Exception as e: - logger.error(f"发送扫描失败通知失败 - Scan ID: {scan_id}: {e}", exc_info=True) + logger.error("发送扫描失败通知失败 - Scan ID: %s: %s", scan_id, e, exc_info=True) diff --git a/backend/apps/scan/handlers/scan_flow_handlers.py b/backend/apps/scan/handlers/scan_flow_handlers.py index 5c6b09fa..3fe44dc4 100644 --- a/backend/apps/scan/handlers/scan_flow_handlers.py +++ b/backend/apps/scan/handlers/scan_flow_handlers.py @@ -10,22 +10,26 @@ """ import logging -from prefect import Flow -from prefect.client.schemas import FlowRun, State +from apps.scan.decorators import FlowContext from apps.scan.utils.performance 
import FlowPerformanceTracker from apps.scan.utils import user_log logger = logging.getLogger(__name__) -# 存储每个 flow_run 的性能追踪器 +# 存储每个 flow 的性能追踪器(使用 scan_id + stage_name 作为 key) _flow_trackers: dict[str, FlowPerformanceTracker] = {} +def _get_tracker_key(scan_id: int, stage_name: str) -> str: + """生成追踪器的唯一 key""" + return f"{scan_id}_{stage_name}" + + def _get_stage_from_flow_name(flow_name: str) -> str | None: """ 从 Flow name 获取对应的 stage - + Flow name 直接作为 stage(与 engine_config 的 key 一致) 排除主 Flow(initiate_scan) """ @@ -35,80 +39,81 @@ def _get_stage_from_flow_name(flow_name: str) -> str | None: return flow_name -def on_scan_flow_running(flow: Flow, flow_run: FlowRun, state: State) -> None: +def on_scan_flow_running(context: FlowContext) -> None: """ 扫描流程开始运行时的回调 - + 职责: - 更新阶段进度为 running - 发送扫描开始通知 - 启动性能追踪 - + Args: - flow: Prefect Flow 对象 - flow_run: Flow 运行实例 - state: Flow 当前状态 + context: Flow 执行上下文 """ - logger.info("🚀 扫描流程开始运行 - Flow: %s, Run ID: %s", flow.name, flow_run.id) - - # 提取流程参数 - flow_params = flow_run.parameters or {} - scan_id = flow_params.get('scan_id') - target_name = flow_params.get('target_name', 'unknown') - target_id = flow_params.get('target_id') - + logger.info( + "🚀 扫描流程开始运行 - Flow: %s, Scan ID: %s", + context.flow_name, context.scan_id + ) + + scan_id = context.scan_id + target_name = context.target_name or 'unknown' + target_id = context.target_id + # 启动性能追踪 if scan_id: - tracker = FlowPerformanceTracker(flow.name, scan_id) + tracker_key = _get_tracker_key(scan_id, context.stage_name) + tracker = FlowPerformanceTracker(context.flow_name, scan_id) tracker.start(target_id=target_id, target_name=target_name) - _flow_trackers[str(flow_run.id)] = tracker - + _flow_trackers[tracker_key] = tracker + # 更新阶段进度 - stage = _get_stage_from_flow_name(flow.name) + stage = _get_stage_from_flow_name(context.flow_name) if scan_id and stage: try: from apps.scan.services import ScanService service = ScanService() service.start_stage(scan_id, stage) - logger.info(f"✓ 阶段进度已更新为 running - Scan ID: {scan_id}, Stage: {stage}") + logger.info( + "✓ 阶段进度已更新为 running - Scan ID: %s, Stage: %s", + scan_id, stage + ) except Exception as e: - logger.error(f"更新阶段进度失败 - Scan ID: {scan_id}, Stage: {stage}: {e}") + logger.error( + "更新阶段进度失败 - Scan ID: %s, Stage: %s: %s", + scan_id, stage, e + ) -def on_scan_flow_completed(flow: Flow, flow_run: FlowRun, state: State) -> None: +def on_scan_flow_completed(context: FlowContext) -> None: """ 扫描流程完成时的回调 - + 职责: - 更新阶段进度为 completed - 发送扫描完成通知(可选) - 记录性能指标 - + Args: - flow: Prefect Flow 对象 - flow_run: Flow 运行实例 - state: Flow 当前状态 + context: Flow 执行上下文 """ - logger.info("✅ 扫描流程完成 - Flow: %s, Run ID: %s", flow.name, flow_run.id) - - # 提取流程参数 - flow_params = flow_run.parameters or {} - scan_id = flow_params.get('scan_id') - - # 获取 flow result - result = None - try: - result = state.result() if state.result else None - except Exception: - pass - + logger.info( + "✅ 扫描流程完成 - Flow: %s, Scan ID: %s", + context.flow_name, context.scan_id + ) + + scan_id = context.scan_id + result = context.result + # 记录性能指标 - tracker = _flow_trackers.pop(str(flow_run.id), None) - if tracker: - tracker.finish(success=True) - + if scan_id: + tracker_key = _get_tracker_key(scan_id, context.stage_name) + tracker = _flow_trackers.pop(tracker_key, None) + if tracker: + tracker.finish(success=True) + # 更新阶段进度 - stage = _get_stage_from_flow_name(flow.name) + stage = _get_stage_from_flow_name(context.flow_name) if scan_id and stage: try: from apps.scan.services import ScanService @@ 
-def on_scan_flow_completed(flow: Flow, flow_run: FlowRun, state: State) -> None:
+def on_scan_flow_completed(context: FlowContext) -> None:
     """
     Callback for when a scan flow completes
-
+
     Responsibilities:
     - Update the stage progress to completed
     - Send the scan-completed notification (optional)
     - Record performance metrics
-
+
     Args:
-        flow: Prefect Flow object
-        flow_run: Flow run instance
-        state: current Flow state
+        context: Flow execution context
     """
-    logger.info("✅ Scan flow completed - Flow: %s, Run ID: %s", flow.name, flow_run.id)
-
-    # Extract flow parameters
-    flow_params = flow_run.parameters or {}
-    scan_id = flow_params.get('scan_id')
-
-    # Get the flow result
-    result = None
-    try:
-        result = state.result() if state.result else None
-    except Exception:
-        pass
-
+    logger.info(
+        "✅ Scan flow completed - Flow: %s, Scan ID: %s",
+        context.flow_name, context.scan_id
+    )
+
+    scan_id = context.scan_id
+    result = context.result
+
     # Record performance metrics
-    tracker = _flow_trackers.pop(str(flow_run.id), None)
-    if tracker:
-        tracker.finish(success=True)
-
+    if scan_id:
+        tracker_key = _get_tracker_key(scan_id, context.stage_name)
+        tracker = _flow_trackers.pop(tracker_key, None)
+        if tracker:
+            tracker.finish(success=True)
+
     # Update stage progress
-    stage = _get_stage_from_flow_name(flow.name)
+    stage = _get_stage_from_flow_name(context.flow_name)
     if scan_id and stage:
         try:
             from apps.scan.services import ScanService
@@ -118,72 +123,88 @@ def on_scan_flow_completed(flow: Flow, flow_run: FlowRun, state: State) -> None:
             if isinstance(result, dict):
                 detail = result.get('detail')
             service.complete_stage(scan_id, stage, detail)
-            logger.info(f"✓ Stage progress updated to completed - Scan ID: {scan_id}, Stage: {stage}")
+            logger.info(
+                "✓ Stage progress updated to completed - Scan ID: %s, Stage: %s",
+                scan_id, stage
+            )
             # Refresh cached stats after each stage so the frontend sees increments in real time
             try:
                 service.update_cached_stats(scan_id)
                 logger.info("✓ Cached stats refreshed after stage completion - Scan ID: %s", scan_id)
             except Exception as e:
-                logger.error("Failed to refresh cached stats after stage completion - Scan ID: %s, error: %s", scan_id, e)
+                logger.error(
+                    "Failed to refresh cached stats after stage completion - Scan ID: %s, error: %s",
+                    scan_id, e
+                )
         except Exception as e:
-            logger.error(f"Failed to update stage progress - Scan ID: {scan_id}, Stage: {stage}: {e}")
+            logger.error(
+                "Failed to update stage progress - Scan ID: %s, Stage: %s: %s",
+                scan_id, stage, e
+            )
 
 
-def on_scan_flow_failed(flow: Flow, flow_run: FlowRun, state: State) -> None:
+def on_scan_flow_failed(context: FlowContext) -> None:
     """
     Callback for when a scan flow fails
-
+
     Responsibilities:
     - Update the stage progress to failed
     - Send the scan-failed notification
     - Record performance metrics (including the error)
     - Write to ScanLog for frontend display
-
+
     Args:
-        flow: Prefect Flow object
-        flow_run: Flow run instance
-        state: current Flow state
+        context: Flow execution context
     """
-    logger.info("❌ Scan flow failed - Flow: %s, Run ID: %s", flow.name, flow_run.id)
-
-    # Extract flow parameters
-    flow_params = flow_run.parameters or {}
-    scan_id = flow_params.get('scan_id')
-    target_name = flow_params.get('target_name', 'unknown')
-
-    # Extract the error message
-    error_message = str(state.message) if state.message else "Unknown error"
-
+    logger.info(
+        "❌ Scan flow failed - Flow: %s, Scan ID: %s",
+        context.flow_name, context.scan_id
+    )
+
+    scan_id = context.scan_id
+    target_name = context.target_name or 'unknown'
+    error_message = context.error_message or "Unknown error"
+
     # Write to ScanLog for frontend display
-    stage = _get_stage_from_flow_name(flow.name)
+    stage = _get_stage_from_flow_name(context.flow_name)
    if scan_id and stage:
         user_log(scan_id, stage, f"Failed: {error_message}", "error")
-
+
     # Record performance metrics (failure case)
-    tracker = _flow_trackers.pop(str(flow_run.id), None)
-    if tracker:
-        tracker.finish(success=False, error_message=error_message)
-
+    if scan_id:
+        tracker_key = _get_tracker_key(scan_id, context.stage_name)
+        tracker = _flow_trackers.pop(tracker_key, None)
+        if tracker:
+            tracker.finish(success=False, error_message=error_message)
+
     # Update stage progress
-    stage = _get_stage_from_flow_name(flow.name)
     if scan_id and stage:
         try:
             from apps.scan.services import ScanService
             service = ScanService()
             service.fail_stage(scan_id, stage, error_message)
-            logger.info(f"✓ Stage progress updated to failed - Scan ID: {scan_id}, Stage: {stage}")
+            logger.info(
+                "✓ Stage progress updated to failed - Scan ID: %s, Stage: %s",
+                scan_id, stage
+            )
         except Exception as e:
-            logger.error(f"Failed to update stage progress - Scan ID: {scan_id}, Stage: {stage}: {e}")
-
+            logger.error(
+                "Failed to update stage progress - Scan ID: %s, Stage: %s: %s",
+                scan_id, stage, e
+            )
+
     # Send the notification
     try:
         from apps.scan.notifications import create_notification, NotificationLevel
-        message = f"Task: {flow.name}\nStatus: execution failed\nError: {error_message}"
+        message = f"Task: {context.flow_name}\nStatus: execution failed\nError: {error_message}"
         create_notification(
             title=target_name,
             message=message,
             level=NotificationLevel.HIGH
         )
-        logger.error(f"✓ Scan-failed notification sent - Target: {target_name}, Flow: {flow.name}, Error: {error_message}")
+        logger.error(
+            "✓ Scan-failed notification sent - Target: %s, Flow: %s, Error: %s",
+            target_name, context.flow_name, error_message
+        )
     except Exception as e:
-        logger.error(f"Failed to send scan-failed notification - Flow: {flow.name}: {e}")
+        logger.error("Failed to send scan-failed notification - Flow: %s: %s", context.flow_name, e)
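With every handler now taking a single FlowContext, wiring them onto a stage flow is just a matter of listing them in the @scan_flow callback arguments. A hedged sketch (the flow name, signature, and return value are illustrative):

    from apps.scan.decorators import scan_flow
    from apps.scan.handlers.scan_flow_handlers import (
        on_scan_flow_completed,
        on_scan_flow_failed,
        on_scan_flow_running,
    )

    @scan_flow(
        name="site_scan",
        on_running=[on_scan_flow_running],
        on_completion=[on_scan_flow_completed],
        on_failure=[on_scan_flow_failed],
    )
    def site_scan_flow(scan_id: int, target_id: int, target_name: str) -> dict:
        # The decorator builds the FlowContext from keyword arguments, so callers
        # must invoke this flow with scan_id=..., target_id=..., target_name=...
        # for the handlers above to see those values.
        return {"detail": "illustrative result"}  # completed handler reads 'detail'

Because on_scan_flow_completed pulls detail out of a dict result, stage flows that want a detail string persisted via complete_stage should return a dict.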
diff --git a/backend/apps/scan/scripts/run_initiate_scan.py b/backend/apps/scan/scripts/run_initiate_scan.py
index 5f3832ba..3339b709 100644
--- a/backend/apps/scan/scripts/run_initiate_scan.py
+++ b/backend/apps/scan/scripts/run_initiate_scan.py
@@ -11,109 +11,6 @@ import os
 import traceback
 
 
-def diagnose_prefect_environment():
-    """Diagnose the Prefect runtime environment, printing details for troubleshooting"""
-    print("\n" + "="*60)
-    print("Prefect environment diagnostics")
-    print("="*60)
-
-    # 1. Check Prefect-related environment variables
-    print("\n[Diag] Prefect environment variables:")
-    prefect_vars = [
-        'PREFECT_HOME',
-        'PREFECT_API_URL',
-        'PREFECT_SERVER_EPHEMERAL_ENABLED',
-        'PREFECT_SERVER_EPHEMERAL_STARTUP_TIMEOUT_SECONDS',
-        'PREFECT_SERVER_DATABASE_CONNECTION_URL',
-        'PREFECT_LOGGING_LEVEL',
-        'PREFECT_DEBUG_MODE',
-    ]
-    for var in prefect_vars:
-        value = os.environ.get(var, 'NOT SET')
-        print(f"  {var}={value}")
-
-    # 2. Check the PREFECT_HOME directory
-    prefect_home = os.environ.get('PREFECT_HOME', os.path.expanduser('~/.prefect'))
-    print(f"\n[Diag] PREFECT_HOME directory: {prefect_home}")
-    if os.path.exists(prefect_home):
-        print(f"  ✓ Directory exists")
-        print(f"  Writable: {os.access(prefect_home, os.W_OK)}")
-        try:
-            files = os.listdir(prefect_home)
-            print(f"  Files: {files[:10]}{'...' if len(files) > 10 else ''}")
-        except Exception as e:
-            print(f"  ✗ Could not list files: {e}")
-    else:
-        print(f"  Directory missing, attempting to create...")
-        try:
-            os.makedirs(prefect_home, exist_ok=True)
-            print(f"  ✓ Created")
-        except Exception as e:
-            print(f"  ✗ Creation failed: {e}")
-
-    # 3. Check whether uvicorn is available
-    print(f"\n[Diag] uvicorn availability:")
-    import shutil
-    uvicorn_path = shutil.which('uvicorn')
-    if uvicorn_path:
-        print(f"  ✓ uvicorn path: {uvicorn_path}")
-    else:
-        print(f"  ✗ uvicorn not on PATH")
-        print(f"  PATH: {os.environ.get('PATH', 'NOT SET')}")
-
-    # 4. Check the Prefect version
-    print(f"\n[Diag] Prefect version:")
-    try:
-        import prefect
-        print(f"  ✓ prefect=={prefect.__version__}")
-    except Exception as e:
-        print(f"  ✗ Could not import prefect: {e}")
-
-    # 5. Check SQLite support
-    print(f"\n[Diag] SQLite support:")
-    try:
-        import sqlite3
-        print(f"  ✓ sqlite3 version: {sqlite3.sqlite_version}")
-        # Test database creation
-        test_db = os.path.join(prefect_home, 'test.db')
-        conn = sqlite3.connect(test_db)
-        conn.execute('CREATE TABLE IF NOT EXISTS test (id INTEGER)')
-        conn.close()
-        os.remove(test_db)
-        print(f"  ✓ SQLite read/write test passed")
-    except Exception as e:
-        print(f"  ✗ SQLite test failed: {e}")
-
-    # 6. Check port-binding capability
-    print(f"\n[Diag] Port binding test:")
-    try:
-        import socket
-        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        sock.bind(('127.0.0.1', 0))
-        port = sock.getsockname()[1]
-        sock.close()
-        print(f"  ✓ Can bind to 127.0.0.1 (test port: {port})")
-    except Exception as e:
-        print(f"  ✗ Port binding failed: {e}")
-
-    # 7. Check memory
-    print(f"\n[Diag] System resources:")
-    try:
-        import psutil
-        mem = psutil.virtual_memory()
-        print(f"  Total memory: {mem.total / 1024 / 1024:.0f} MB")
-        print(f"  Available memory: {mem.available / 1024 / 1024:.0f} MB")
-        print(f"  Memory usage: {mem.percent}%")
-    except ImportError:
-        print(f"  psutil not installed, skipping memory check")
-    except Exception as e:
-        print(f"  ✗ Resource check failed: {e}")
-
-    print("\n" + "="*60)
-    print("Diagnostics complete")
-    print("="*60 + "\n")
-
-
 def main():
     print("="*60)
     print("run_initiate_scan.py starting")
@@ -143,17 +40,13 @@ def main():
     parser.add_argument("--scheduled_scan_name", type=str, default=None,
                         help="Scheduled scan name (optional)")
     args = parser.parse_args()
 
-    print(f"[2/4] ✓ Arguments parsed:")
+    print("[2/4] ✓ Arguments parsed:")
     print(f"    scan_id: {args.scan_id}")
     print(f"    target_id: {args.target_id}")
     print(f"    scan_workspace_dir: {args.scan_workspace_dir}")
     print(f"    engine_name: {args.engine_name}")
     print(f"    scheduled_scan_name: {args.scheduled_scan_name}")
 
-    # 2.5. Run Prefect environment diagnostics (DEBUG mode only)
-    if os.environ.get('DEBUG', '').lower() == 'true':
-        diagnose_prefect_environment()
-
     # 3. Django-dependent modules can now be imported safely
     print("[3/4] Importing initiate_scan_flow...")
     try:
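The `# 3.` step above is the import-order contract this script relies on: apps.common.django_setup configures Django as an import side effect, so it must be imported before anything that touches the ORM. A minimal sketch of that ordering (the flow's module path is a placeholder):

    # Importing the setup module runs django.setup() and closes stale DB
    # connections before any model class is loaded.
    import apps.common.django_setup  # noqa: F401

    # Only now is it safe to import Django-dependent code:
    from apps.scan.flows.initiate_scan import initiate_scan_flow  # hypothetical path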
diff --git a/backend/apps/scan/tasks/directory_scan/export_sites_task.py b/backend/apps/scan/tasks/directory_scan/export_sites_task.py
index e584d5fb..d84ca6c2 100644
--- a/backend/apps/scan/tasks/directory_scan/export_sites_task.py
+++ b/backend/apps/scan/tasks/directory_scan/export_sites_task.py
@@ -7,14 +7,14 @@
 """
 import logging
 from pathlib import Path
-from prefect import task
+
 from apps.scan.providers import TargetProvider
 
 logger = logging.getLogger(__name__)
 
 
-@task(name="export_sites")
+
 def export_sites_task(
     output_file: str,
     provider: TargetProvider,
diff --git a/backend/apps/scan/tasks/directory_scan/run_and_stream_save_directories_task.py b/backend/apps/scan/tasks/directory_scan/run_and_stream_save_directories_task.py
index 575b8eea..339f014f 100644
--- a/backend/apps/scan/tasks/directory_scan/run_and_stream_save_directories_task.py
+++ b/backend/apps/scan/tasks/directory_scan/run_and_stream_save_directories_task.py
@@ -24,7 +24,7 @@ import json
 import subprocess
 import time
 from pathlib import Path
-from prefect import task
+
 from typing import Generator, Optional, TYPE_CHECKING
 from django.db import IntegrityError, OperationalError, DatabaseError
 from psycopg2 import InterfaceError
@@ -305,11 +305,11 @@ def _save_batch(
     return len(snapshot_items)
 
 
-@task(
-    name='run_and_stream_save_directories',
-    retries=0,
-    log_prints=True
-)
+
+
+
+
+
 def run_and_stream_save_directories_task(
     cmd: str,
     tool_name: str,
diff --git a/backend/apps/scan/tasks/fingerprint_detect/export_site_urls_task.py b/backend/apps/scan/tasks/fingerprint_detect/export_site_urls_task.py
index 0e8665a3..0c9fb71b 100644
--- a/backend/apps/scan/tasks/fingerprint_detect/export_site_urls_task.py
+++ b/backend/apps/scan/tasks/fingerprint_detect/export_site_urls_task.py
@@ -9,14 +9,14 @@
 
 import logging
 from pathlib import Path
-from prefect import task
+
 from apps.scan.providers import TargetProvider
 
 logger = logging.getLogger(__name__)
 
 
-@task(name="export_site_urls_for_fingerprint")
+
 def export_site_urls_for_fingerprint_task(
     output_file: str,
     provider: TargetProvider,
diff --git a/backend/apps/scan/tasks/fingerprint_detect/run_xingfinger_task.py b/backend/apps/scan/tasks/fingerprint_detect/run_xingfinger_task.py
index e33ada97..333c6bca 100644
--- a/backend/apps/scan/tasks/fingerprint_detect/run_xingfinger_task.py
+++ b/backend/apps/scan/tasks/fingerprint_detect/run_xingfinger_task.py
@@ -11,7 +11,7 @@ from typing import Optional, Generator
 from urllib.parse import urlparse
 
 from django.db import connection
-from prefect import task
+
 from apps.scan.utils import execute_stream
 from apps.asset.dtos.snapshot import WebsiteSnapshotDTO
@@ -189,7 +189,7 @@ def _parse_xingfinger_stream_output(
     logger.info("Stream parsing finished - total lines: %d, valid records: %d", total_lines, valid_records)
 
 
-@task(name="run_xingfinger_and_stream_update_tech")
+
 def run_xingfinger_and_stream_update_tech_task(
     cmd: str,
     tool_name: str,
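_parse_xingfinger_stream_output above reports totals after consuming the tool's stdout line by line; a minimal sketch of that streaming-generator pattern, with hypothetical record fields:

    import json
    import logging
    from typing import Generator, Iterable

    logger = logging.getLogger(__name__)

    def parse_stream(lines: Iterable[str]) -> Generator[dict, None, None]:
        """Yield valid JSON records as they arrive, keeping memory flat."""
        total_lines = 0
        valid_records = 0
        for raw in lines:
            total_lines += 1
            try:
                record = json.loads(raw)
            except json.JSONDecodeError:
                continue  # skip tool banners and malformed lines
            if record.get("url"):  # hypothetical required field
                valid_records += 1
                yield record
        logger.info("Stream parsing finished - total lines: %d, valid records: %d",
                    total_lines, valid_records)

Yielding one record at a time is what lets the caller batch database writes without ever buffering the full tool output.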
diff --git a/backend/apps/scan/tasks/port_scan/export_hosts_task.py b/backend/apps/scan/tasks/port_scan/export_hosts_task.py
index d6a0be7d..1750e452 100644
--- a/backend/apps/scan/tasks/port_scan/export_hosts_task.py
+++ b/backend/apps/scan/tasks/port_scan/export_hosts_task.py
@@ -6,14 +6,14 @@
 
 import logging
 from pathlib import Path
-from prefect import task
+
 from apps.scan.providers import TargetProvider
 
 logger = logging.getLogger(__name__)
 
 
-@task(name="export_hosts")
+
 def export_hosts_task(
     output_file: str,
     provider: TargetProvider,
diff --git a/backend/apps/scan/tasks/port_scan/run_and_stream_save_ports_task.py b/backend/apps/scan/tasks/port_scan/run_and_stream_save_ports_task.py
index 84db75f3..11b05764 100644
--- a/backend/apps/scan/tasks/port_scan/run_and_stream_save_ports_task.py
+++ b/backend/apps/scan/tasks/port_scan/run_and_stream_save_ports_task.py
@@ -26,7 +26,7 @@ import subprocess
 import time
 from asyncio import CancelledError
 from pathlib import Path
-from prefect import task
+
 from typing import Generator, List, Optional, TYPE_CHECKING
 from django.db import IntegrityError, OperationalError, DatabaseError
 from psycopg2 import InterfaceError
@@ -582,11 +582,11 @@ def _cleanup_resources(data_generator) -> None:
     )
 
 
-@task(
-    name='run_and_stream_save_ports',
-    retries=0,
-    log_prints=True
-)
+
+
+
+
+
 def run_and_stream_save_ports_task(
     cmd: str,
     tool_name: str,
diff --git a/backend/apps/scan/tasks/screenshot/capture_screenshots_task.py b/backend/apps/scan/tasks/screenshot/capture_screenshots_task.py
index c478f3ff..e6e3e19a 100644
--- a/backend/apps/scan/tasks/screenshot/capture_screenshots_task.py
+++ b/backend/apps/scan/tasks/screenshot/capture_screenshots_task.py
@@ -6,7 +6,7 @@
 import asyncio
 import logging
 import time
-from prefect import task
+
 
 logger = logging.getLogger(__name__)
@@ -140,7 +140,7 @@ async def _capture_and_save_screenshots(
     }
 
 
-@task(name='capture_screenshots', retries=0)
+
 def capture_screenshots_task(
     urls: list[str],
     scan_id: int,
diff --git a/backend/apps/scan/tasks/site_scan/export_site_urls_task.py b/backend/apps/scan/tasks/site_scan/export_site_urls_task.py
index c094ed40..d9a501ce 100644
--- a/backend/apps/scan/tasks/site_scan/export_site_urls_task.py
+++ b/backend/apps/scan/tasks/site_scan/export_site_urls_task.py
@@ -7,14 +7,14 @@
 """
 import logging
 from pathlib import Path
-from prefect import task
+
 from apps.scan.providers import TargetProvider
 
 logger = logging.getLogger(__name__)
 
 
-@task(name="export_site_urls")
+
 def export_site_urls_task(
     output_file: str,
     provider: TargetProvider,
diff --git a/backend/apps/scan/tasks/site_scan/run_and_stream_save_websites_task.py b/backend/apps/scan/tasks/site_scan/run_and_stream_save_websites_task.py
index df5035b5..c64f65d4 100644
--- a/backend/apps/scan/tasks/site_scan/run_and_stream_save_websites_task.py
+++ b/backend/apps/scan/tasks/site_scan/run_and_stream_save_websites_task.py
@@ -25,7 +25,7 @@ import json
 import subprocess
 import time
 from pathlib import Path
-from prefect import task
+
 from typing import Generator, Optional, Dict, Any, TYPE_CHECKING
 from django.db import IntegrityError, OperationalError, DatabaseError
 from dataclasses import dataclass
@@ -659,7 +659,7 @@ def _cleanup_resources(data_generator) -> None:
         logger.error("Error while closing generator: %s", gen_close_error)
 
 
-@task(name='run_and_stream_save_websites', retries=0)
+
 def run_and_stream_save_websites_task(
     cmd: str,
     tool_name: str,
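capture_screenshots_task above pairs a synchronous entry point with the async helper _capture_and_save_screenshots; with no Prefect runtime underneath, the bridge is plain asyncio. A sketch under the assumption that the helper takes the URL list and scan id (its real signature may differ):

    import asyncio

    async def _capture_and_save_screenshots(urls: list[str], scan_id: int) -> dict:
        ...  # stand-in for the real async implementation in the file above

    def capture_screenshots_task(urls: list[str], scan_id: int) -> dict:
        # asyncio.run creates and tears down a fresh event loop per call, which
        # is acceptable for a task invoked once per stage from sync flow code.
        return asyncio.run(_capture_and_save_screenshots(urls, scan_id))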
diff --git a/backend/apps/scan/tasks/subdomain_discovery/merge_and_validate_task.py b/backend/apps/scan/tasks/subdomain_discovery/merge_and_validate_task.py
index 5d3d3aa2..f03d3b5f 100644
--- a/backend/apps/scan/tasks/subdomain_discovery/merge_and_validate_task.py
+++ b/backend/apps/scan/tasks/subdomain_discovery/merge_and_validate_task.py
@@ -26,7 +26,7 @@ from datetime import datetime
 from pathlib import Path
 from typing import List
 
-from prefect import task
+
 
 logger = logging.getLogger(__name__)
@@ -64,7 +64,7 @@ def _validate_input_files(result_files: List[str]) -> List[str]:
     return valid_files
 
 
-@task(name='merge_and_deduplicate', retries=1, log_prints=True)
+
 def merge_and_validate_task(result_files: List[str], result_dir: str) -> str:
     """
     Merge scan results and deduplicate (high-performance streaming)
diff --git a/backend/apps/scan/tasks/subdomain_discovery/run_subdomain_discovery_task.py b/backend/apps/scan/tasks/subdomain_discovery/run_subdomain_discovery_task.py
index b33bae89..d599279b 100644
--- a/backend/apps/scan/tasks/subdomain_discovery/run_subdomain_discovery_task.py
+++ b/backend/apps/scan/tasks/subdomain_discovery/run_subdomain_discovery_task.py
@@ -6,17 +6,17 @@
 
 import logging
 from pathlib import Path
-from prefect import task
+
 from apps.scan.utils import execute_and_wait
 
 logger = logging.getLogger(__name__)
 
 
-@task(
-    name='run_subdomain_discovery',
-    retries=0,  # Explicitly disable retries
-    log_prints=True
-)
+
+
+
+
+
 def run_subdomain_discovery_task(
     tool: str,
     command: str,
diff --git a/backend/apps/scan/tasks/subdomain_discovery/save_domains_task.py b/backend/apps/scan/tasks/subdomain_discovery/save_domains_task.py
index 6f0b0893..5f7551aa 100644
--- a/backend/apps/scan/tasks/subdomain_discovery/save_domains_task.py
+++ b/backend/apps/scan/tasks/subdomain_discovery/save_domains_task.py
@@ -7,7 +7,7 @@
 import logging
 import time
 from pathlib import Path
-from prefect import task
+
 from typing import List
 from dataclasses import dataclass
 from django.db import IntegrityError, OperationalError, DatabaseError
@@ -35,11 +35,11 @@ class ServiceSet:
     )
 
 
-@task(
-    name='save_domains',
-    retries=0,
-    log_prints=True
-)
+
+
+
+
+
 def save_domains_task(
     domains_file: str,
     scan_id: int,
diff --git a/backend/apps/scan/tasks/url_fetch/clean_urls_task.py b/backend/apps/scan/tasks/url_fetch/clean_urls_task.py
index d1483958..8dd11286 100644
--- a/backend/apps/scan/tasks/url_fetch/clean_urls_task.py
+++ b/backend/apps/scan/tasks/url_fetch/clean_urls_task.py
@@ -11,7 +11,7 @@ import logging
 import subprocess
 from pathlib import Path
 from datetime import datetime
-from prefect import task
+
 from typing import Optional
 
 from apps.scan.utils import execute_and_wait
@@ -19,11 +19,11 @@ from apps.scan.utils import execute_and_wait
 logger = logging.getLogger(__name__)
 
 
-@task(
-    name='clean_urls_with_uro',
-    retries=1,
-    log_prints=True
-)
+
+
+
+
+
 def clean_urls_task(
     input_file: str,
     output_dir: str,
diff --git a/backend/apps/scan/tasks/url_fetch/export_sites_task.py b/backend/apps/scan/tasks/url_fetch/export_sites_task.py
index 2bf2cb98..d391fd3b 100644
--- a/backend/apps/scan/tasks/url_fetch/export_sites_task.py
+++ b/backend/apps/scan/tasks/url_fetch/export_sites_task.py
@@ -8,18 +8,18 @@
 
 import logging
 from pathlib import Path
-from prefect import task
+
 from apps.scan.providers import TargetProvider
 
 logger = logging.getLogger(__name__)
 
 
-@task(
-    name='export_sites_for_url_fetch',
-    retries=1,
-    log_prints=True
-)
+
+
+
+
+
 def export_sites_task(
     output_file: str,
     provider: TargetProvider,
diff --git a/backend/apps/scan/tasks/url_fetch/merge_and_deduplicate_urls_task.py b/backend/apps/scan/tasks/url_fetch/merge_and_deduplicate_urls_task.py
index 1784b5b9..8c7d0ca5 100644
--- a/backend/apps/scan/tasks/url_fetch/merge_and_deduplicate_urls_task.py
+++ b/backend/apps/scan/tasks/url_fetch/merge_and_deduplicate_urls_task.py
@@ -10,17 +10,17 @@ import uuid
 import subprocess
 from pathlib import Path
 from datetime import datetime
-from prefect import task
+
 from typing import List
 
 logger = logging.getLogger(__name__)
 
 
-@task(
-    name='merge_and_deduplicate_urls',
-    retries=1,
-    log_prints=True
-)
+
+
+
+
+
 def merge_and_deduplicate_urls_task(
     result_files: List[str],
     result_dir: str
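Several decorators removed in this group (merge_and_deduplicate, clean_urls_with_uro, export_sites_for_url_fetch, merge_and_deduplicate_urls, plus save_urls further down) carried retries=1, so dropping @task also silently drops one retry. If that behavior still matters, @scan_task from apps/scan/decorators.py can restore it; a sketch on one of these tasks:

    from typing import List

    from apps.scan.decorators import scan_task

    @scan_task(retries=1, retry_delay=2.0)  # retry once, as the old @task did
    def merge_and_deduplicate_urls_task(result_files: List[str], result_dir: str) -> str:
        ...  # body unchanged from the file above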
diff --git a/backend/apps/scan/tasks/url_fetch/run_and_stream_save_urls_task.py b/backend/apps/scan/tasks/url_fetch/run_and_stream_save_urls_task.py
index a706a2b9..89c159dd 100644
--- a/backend/apps/scan/tasks/url_fetch/run_and_stream_save_urls_task.py
+++ b/backend/apps/scan/tasks/url_fetch/run_and_stream_save_urls_task.py
@@ -22,7 +22,7 @@ import json
 import subprocess
 import time
 from pathlib import Path
-from prefect import task
+
 from typing import Generator, Optional, Dict, Any
 from django.db import IntegrityError, OperationalError, DatabaseError
 from psycopg2 import InterfaceError
@@ -582,7 +582,7 @@ def _process_records_in_batches(
     }
 
 
-@task(name="run_and_stream_save_urls", retries=0)
+
 def run_and_stream_save_urls_task(
     cmd: str,
     tool_name: str,
diff --git a/backend/apps/scan/tasks/url_fetch/run_url_fetcher_task.py b/backend/apps/scan/tasks/url_fetch/run_url_fetcher_task.py
index d871079a..5e95fe91 100644
--- a/backend/apps/scan/tasks/url_fetch/run_url_fetcher_task.py
+++ b/backend/apps/scan/tasks/url_fetch/run_url_fetcher_task.py
@@ -10,17 +10,17 @@
 
 import logging
 from pathlib import Path
-from prefect import task
+
 from apps.scan.utils import execute_and_wait
 
 logger = logging.getLogger(__name__)
 
 
-@task(
-    name='run_url_fetcher',
-    retries=0,  # No retries; the tool handles this itself
-    log_prints=True
-)
+
+
+
+
+
 def run_url_fetcher_task(
     tool_name: str,
     command: str,
diff --git a/backend/apps/scan/tasks/url_fetch/save_urls_task.py b/backend/apps/scan/tasks/url_fetch/save_urls_task.py
index 01e32e56..080597d4 100644
--- a/backend/apps/scan/tasks/url_fetch/save_urls_task.py
+++ b/backend/apps/scan/tasks/url_fetch/save_urls_task.py
@@ -7,7 +7,7 @@
 
 import logging
 from pathlib import Path
-from prefect import task
+
 from typing import List, Optional
 from urllib.parse import urlparse
 from dataclasses import dataclass
@@ -70,11 +70,11 @@ def _parse_url(url: str) -> Optional[ParsedURL]:
         return None
 
 
-@task(
-    name='save_urls',
-    retries=1,
-    log_prints=True
-)
+
+
+
+
+
 def save_urls_task(
     urls_file: str,
     scan_id: int,
diff --git a/backend/apps/scan/tasks/vuln_scan/export_endpoints_task.py b/backend/apps/scan/tasks/vuln_scan/export_endpoints_task.py
index f3febe5e..4885e4c9 100644
--- a/backend/apps/scan/tasks/vuln_scan/export_endpoints_task.py
+++ b/backend/apps/scan/tasks/vuln_scan/export_endpoints_task.py
@@ -9,14 +9,14 @@
 
 import logging
 from typing import Dict
 from pathlib import Path
-from prefect import task
+
 from apps.scan.providers import TargetProvider
 
 logger = logging.getLogger(__name__)
 
 
-@task(name="export_endpoints")
+
 def export_endpoints_task(
     output_file: str,
     provider: TargetProvider,
diff --git a/backend/apps/scan/tasks/vuln_scan/export_websites_task.py b/backend/apps/scan/tasks/vuln_scan/export_websites_task.py
index b188d106..1c110469 100644
--- a/backend/apps/scan/tasks/vuln_scan/export_websites_task.py
+++ b/backend/apps/scan/tasks/vuln_scan/export_websites_task.py
@@ -8,14 +8,14 @@
 
 import logging
 from pathlib import Path
-from prefect import task
+
 from apps.scan.providers import TargetProvider
 
 logger = logging.getLogger(__name__)
 
 
-@task(name="export_websites_for_vuln_scan")
+
 def export_websites_task(
     output_file: str,
     provider: TargetProvider,
diff --git a/backend/apps/scan/tasks/vuln_scan/run_and_stream_save_dalfox_vulns_task.py b/backend/apps/scan/tasks/vuln_scan/run_and_stream_save_dalfox_vulns_task.py
index b03d3cc8..895eef17 100644
--- a/backend/apps/scan/tasks/vuln_scan/run_and_stream_save_dalfox_vulns_task.py
+++ b/backend/apps/scan/tasks/vuln_scan/run_and_stream_save_dalfox_vulns_task.py
@@ -25,7 +25,7 @@ from pathlib import Path
 from dataclasses import dataclass
 from typing import Generator, Optional, TYPE_CHECKING
 
-from prefect import task
+
 from django.db import IntegrityError, OperationalError, DatabaseError
 from psycopg2 import InterfaceError
@@ -393,11 +393,11 @@ def _cleanup_resources(data_generator) -> None:
         logger.error("Error while closing generator: %s", gen_close_error)
 
 
-@task(
-    name="run_and_stream_save_dalfox_vulns",
-    retries=0,
-    log_prints=True,
-)
+
+
+
+
+
 def run_and_stream_save_dalfox_vulns_task(
     cmd: str,
     tool_name: str,
diff --git a/backend/apps/scan/tasks/vuln_scan/run_and_stream_save_nuclei_vulns_task.py b/backend/apps/scan/tasks/vuln_scan/run_and_stream_save_nuclei_vulns_task.py
index ee2f4025..d8295a83 100644
--- a/backend/apps/scan/tasks/vuln_scan/run_and_stream_save_nuclei_vulns_task.py
+++ b/backend/apps/scan/tasks/vuln_scan/run_and_stream_save_nuclei_vulns_task.py
@@ -22,7 +22,7 @@ from pathlib import Path
 from dataclasses import dataclass
 from typing import Generator, Optional, TYPE_CHECKING
 
-from prefect import task
+
 from django.db import IntegrityError, OperationalError, DatabaseError
 from psycopg2 import InterfaceError
@@ -395,11 +395,11 @@ def _cleanup_resources(data_generator) -> None:
         logger.error("Error while closing generator: %s", gen_close_error)
 
 
-@task(
-    name="run_and_stream_save_nuclei_vulns",
-    retries=0,
-    log_prints=True,
-)
+
+
+
+
+
 def run_and_stream_save_nuclei_vulns_task(
     cmd: str,
     tool_name: str,
diff --git a/backend/apps/scan/tasks/vuln_scan/run_vuln_tool_task.py b/backend/apps/scan/tasks/vuln_scan/run_vuln_tool_task.py
index 3c1f2091..3c8478a1 100644
--- a/backend/apps/scan/tasks/vuln_scan/run_vuln_tool_task.py
+++ b/backend/apps/scan/tasks/vuln_scan/run_vuln_tool_task.py
@@ -10,18 +10,18 @@
 
 import logging
 from typing import Dict
-from prefect import task
+
 from apps.scan.utils import execute_and_wait
 
 logger = logging.getLogger(__name__)
 
 
-@task(
-    name="run_vuln_tool",
-    retries=0,
-    log_prints=True,
-)
+
+
+
+
+
 def run_vuln_tool_task(
     tool_name: str,
     command: str,
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 7c448e71..e2d2e701 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -17,9 +17,7 @@ django-filter==24.3
 # Environment variable management
 python-dotenv==1.0.1
 
-# Async tasks and workflow orchestration
-prefect==3.4.25
-fastapi==0.115.5  # Pinned: 0.123+ is incompatible with Prefect
+# Async tasks
 redis==5.0.3  # Optional: used for caching
 APScheduler>=3.10.0  # Scheduler for periodic tasks
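With prefect gone from requirements.txt, any retry semantics must now come from @scan_task itself. A minimal sketch of the loop such a decorator needs (one initial attempt plus `retries` re-attempts); the shipped implementation in apps/scan/decorators.py may differ:

    import functools
    import logging
    import time
    from typing import Any, Callable, Optional

    logger = logging.getLogger(__name__)

    def scan_task_sketch(retries: int = 0, retry_delay: float = 1.0,
                         name: Optional[str] = None):
        def decorator(func: Callable) -> Callable:
            @functools.wraps(func)
            def wrapper(*args, **kwargs) -> Any:
                for attempt in range(retries + 1):
                    try:
                        return func(*args, **kwargs)
                    except Exception as e:
                        if attempt >= retries:
                            raise  # out of attempts: propagate to the flow
                        logger.warning(
                            "%s failed (attempt %d/%d): %s; retrying in %.1fs",
                            name or func.__name__, attempt + 1, retries + 1,
                            e, retry_delay,
                        )
                        time.sleep(retry_delay)
            return wrapper
        return decorator

Exceptions from the final attempt propagate unchanged, which is what lets @scan_flow's on_failure callbacks and the handlers above see the real error.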