diff --git a/backend/apps/scan/flows/directory_scan_flow.py b/backend/apps/scan/flows/directory_scan_flow.py index ecc0be2a..9025baae 100644 --- a/backend/apps/scan/flows/directory_scan_flow.py +++ b/backend/apps/scan/flows/directory_scan_flow.py @@ -140,28 +140,7 @@ def _get_max_workers(tool_config: dict, default: int = DEFAULT_MAX_WORKERS) -> i return default -def _setup_directory_scan_directory(scan_workspace_dir: str) -> Path: - """ - 创建并验证目录扫描工作目录 - - Args: - scan_workspace_dir: 扫描工作空间目录 - - Returns: - Path: 目录扫描目录路径 - - Raises: - RuntimeError: 目录创建或验证失败 - """ - directory_scan_dir = Path(scan_workspace_dir) / 'directory_scan' - directory_scan_dir.mkdir(parents=True, exist_ok=True) - - if not directory_scan_dir.is_dir(): - raise RuntimeError(f"目录扫描目录创建失败: {directory_scan_dir}") - if not os.access(directory_scan_dir, os.W_OK): - raise RuntimeError(f"目录扫描目录不可写: {directory_scan_dir}") - - return directory_scan_dir + def _export_site_urls(target_id: int, target_name: str, directory_scan_dir: Path) -> tuple[str, int]: @@ -640,7 +619,8 @@ def directory_scan_flow( raise ValueError("enabled_tools 不能为空") # Step 0: 创建工作目录 - directory_scan_dir = _setup_directory_scan_directory(scan_workspace_dir) + from apps.scan.utils import setup_scan_directory + directory_scan_dir = setup_scan_directory(scan_workspace_dir, 'directory_scan') # Step 1: 导出站点 URL(支持懒加载) sites_file, site_count = _export_site_urls(target_id, target_name, directory_scan_dir) diff --git a/backend/apps/scan/flows/fingerprint_detect_flow.py b/backend/apps/scan/flows/fingerprint_detect_flow.py index a4d39114..f038a8ff 100644 --- a/backend/apps/scan/flows/fingerprint_detect_flow.py +++ b/backend/apps/scan/flows/fingerprint_detect_flow.py @@ -64,28 +64,7 @@ def calculate_fingerprint_detect_timeout( return max(min_timeout, timeout) -def _setup_fingerprint_detect_directory(scan_workspace_dir: str) -> Path: - """ - 创建并验证指纹识别工作目录 - - Args: - scan_workspace_dir: 扫描工作空间目录 - - Returns: - Path: 指纹识别目录路径 - - Raises: - RuntimeError: 目录创建或验证失败 - """ - fingerprint_dir = Path(scan_workspace_dir) / 'fingerprint_detect' - fingerprint_dir.mkdir(parents=True, exist_ok=True) - - if not fingerprint_dir.is_dir(): - raise RuntimeError(f"指纹识别目录创建失败: {fingerprint_dir}") - if not os.access(fingerprint_dir, os.W_OK): - raise RuntimeError(f"指纹识别目录不可写: {fingerprint_dir}") - - return fingerprint_dir + def _export_urls( @@ -313,7 +292,8 @@ def fingerprint_detect_flow( source = 'website' # Step 0: 创建工作目录 - fingerprint_dir = _setup_fingerprint_detect_directory(scan_workspace_dir) + from apps.scan.utils import setup_scan_directory + fingerprint_dir = setup_scan_directory(scan_workspace_dir, 'fingerprint_detect') # Step 1: 导出 URL(支持懒加载) urls_file, url_count = _export_urls(target_id, fingerprint_dir, target_name, source) diff --git a/backend/apps/scan/flows/initiate_scan_flow.py b/backend/apps/scan/flows/initiate_scan_flow.py index de75ad2c..ee05d7a4 100644 --- a/backend/apps/scan/flows/initiate_scan_flow.py +++ b/backend/apps/scan/flows/initiate_scan_flow.py @@ -30,7 +30,7 @@ from apps.scan.handlers import ( on_initiate_scan_flow_failed, ) from prefect.futures import wait -from apps.scan.tasks.workspace_tasks import create_scan_workspace_task +from apps.scan.utils import setup_scan_workspace from apps.scan.orchestrators import FlowOrchestrator logger = logging.getLogger(__name__) @@ -110,7 +110,7 @@ def initiate_scan_flow( ) # ==================== Task 1: 创建 Scan 工作空间 ==================== - scan_workspace_path = create_scan_workspace_task(scan_workspace_dir) + 
scan_workspace_path = setup_scan_workspace(scan_workspace_dir) # ==================== Task 2: 获取引擎配置 ==================== from apps.scan.models import Scan diff --git a/backend/apps/scan/flows/port_scan_flow.py b/backend/apps/scan/flows/port_scan_flow.py index 95a4f510..6c8c7939 100644 --- a/backend/apps/scan/flows/port_scan_flow.py +++ b/backend/apps/scan/flows/port_scan_flow.py @@ -154,28 +154,7 @@ def _parse_port_count(tool_config: dict) -> int: return 100 -def _setup_port_scan_directory(scan_workspace_dir: str) -> Path: - """ - 创建并验证端口扫描工作目录 - - Args: - scan_workspace_dir: 扫描工作空间目录 - - Returns: - Path: 端口扫描目录路径 - - Raises: - RuntimeError: 目录创建或验证失败 - """ - port_scan_dir = Path(scan_workspace_dir) / 'port_scan' - port_scan_dir.mkdir(parents=True, exist_ok=True) - - if not port_scan_dir.is_dir(): - raise RuntimeError(f"端口扫描目录创建失败: {port_scan_dir}") - if not os.access(port_scan_dir, os.W_OK): - raise RuntimeError(f"端口扫描目录不可写: {port_scan_dir}") - - return port_scan_dir + def _export_scan_targets(target_id: int, port_scan_dir: Path) -> tuple[str, int, str]: @@ -442,7 +421,8 @@ def port_scan_flow( ) # Step 0: 创建工作目录 - port_scan_dir = _setup_port_scan_directory(scan_workspace_dir) + from apps.scan.utils import setup_scan_directory + port_scan_dir = setup_scan_directory(scan_workspace_dir, 'port_scan') # Step 1: 导出扫描目标列表到文件(根据 Target 类型自动决定内容) targets_file, target_count, target_type = _export_scan_targets(target_id, port_scan_dir) diff --git a/backend/apps/scan/flows/site_scan_flow.py b/backend/apps/scan/flows/site_scan_flow.py index 0dbd45b8..9fb0ea82 100644 --- a/backend/apps/scan/flows/site_scan_flow.py +++ b/backend/apps/scan/flows/site_scan_flow.py @@ -85,28 +85,7 @@ def calculate_timeout_by_line_count( return min_timeout -def _setup_site_scan_directory(scan_workspace_dir: str) -> Path: - """ - 创建并验证站点扫描工作目录 - - Args: - scan_workspace_dir: 扫描工作空间目录 - - Returns: - Path: 站点扫描目录路径 - - Raises: - RuntimeError: 目录创建或验证失败 - """ - site_scan_dir = Path(scan_workspace_dir) / 'site_scan' - site_scan_dir.mkdir(parents=True, exist_ok=True) - - if not site_scan_dir.is_dir(): - raise RuntimeError(f"站点扫描目录创建失败: {site_scan_dir}") - if not os.access(site_scan_dir, os.W_OK): - raise RuntimeError(f"站点扫描目录不可写: {site_scan_dir}") - - return site_scan_dir + def _export_site_urls(target_id: int, site_scan_dir: Path, target_name: str = None) -> tuple[str, int, int]: @@ -403,7 +382,8 @@ def site_scan_flow( raise ValueError("scan_workspace_dir 不能为空") # Step 0: 创建工作目录 - site_scan_dir = _setup_site_scan_directory(scan_workspace_dir) + from apps.scan.utils import setup_scan_directory + site_scan_dir = setup_scan_directory(scan_workspace_dir, 'site_scan') # Step 1: 导出站点 URL urls_file, total_urls, association_count = _export_site_urls( diff --git a/backend/apps/scan/flows/subdomain_discovery_flow.py b/backend/apps/scan/flows/subdomain_discovery_flow.py index e4da722d..09a553e6 100644 --- a/backend/apps/scan/flows/subdomain_discovery_flow.py +++ b/backend/apps/scan/flows/subdomain_discovery_flow.py @@ -41,28 +41,7 @@ import subprocess logger = logging.getLogger(__name__) -def _setup_subdomain_directory(scan_workspace_dir: str) -> Path: - """ - 创建并验证子域名扫描工作目录 - - Args: - scan_workspace_dir: 扫描工作空间目录 - - Returns: - Path: 子域名扫描目录路径 - - Raises: - RuntimeError: 目录创建或验证失败 - """ - result_dir = Path(scan_workspace_dir) / 'subdomain_discovery' - result_dir.mkdir(parents=True, exist_ok=True) - - if not result_dir.is_dir(): - raise RuntimeError(f"子域名扫描目录创建失败: {result_dir}") - if not os.access(result_dir, os.W_OK): - raise 
RuntimeError(f"子域名扫描目录不可写: {result_dir}") - - return result_dir + def _validate_and_normalize_target(target_name: str) -> str: @@ -119,12 +98,7 @@ def _run_scans_parallel( # 生成时间戳(所有工具共用) timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - - # TODO: 接入代理池管理系统 - # from apps.proxy.services import proxy_pool - # proxy_stats = proxy_pool.get_stats() - # logger.info(f"代理池状态: {proxy_stats['healthy']}/{proxy_stats['total']} 可用") - + failures = [] # 记录命令构建失败的工具 futures = {} @@ -417,7 +391,8 @@ def subdomain_discovery_flow( ) # Step 0: 准备工作 - result_dir = _setup_subdomain_directory(scan_workspace_dir) + from apps.scan.utils import setup_scan_directory + result_dir = setup_scan_directory(scan_workspace_dir, 'subdomain_discovery') # 验证并规范化目标域名 try: diff --git a/backend/apps/scan/flows/url_fetch/main_flow.py b/backend/apps/scan/flows/url_fetch/main_flow.py index 63d51cb3..31c2cd6f 100644 --- a/backend/apps/scan/flows/url_fetch/main_flow.py +++ b/backend/apps/scan/flows/url_fetch/main_flow.py @@ -42,17 +42,7 @@ SITES_FILE_TOOLS = {'katana'} POST_PROCESS_TOOLS = {'uro', 'httpx'} -def _setup_url_fetch_directory(scan_workspace_dir: str) -> Path: - """创建并验证 URL 获取工作目录""" - url_fetch_dir = Path(scan_workspace_dir) / 'url_fetch' - url_fetch_dir.mkdir(parents=True, exist_ok=True) - - if not url_fetch_dir.is_dir(): - raise RuntimeError(f"URL 获取目录创建失败: {url_fetch_dir}") - if not os.access(url_fetch_dir, os.W_OK): - raise RuntimeError(f"URL 获取目录不可写: {url_fetch_dir}") - - return url_fetch_dir + def _classify_tools(enabled_tools: dict) -> tuple[dict, dict, dict, dict]: @@ -304,7 +294,8 @@ def url_fetch_flow( # Step 1: 准备工作目录 logger.info("Step 1: 准备工作目录") - url_fetch_dir = _setup_url_fetch_directory(scan_workspace_dir) + from apps.scan.utils import setup_scan_directory + url_fetch_dir = setup_scan_directory(scan_workspace_dir, 'url_fetch') # Step 2: 分类工具(按输入类型) logger.info("Step 2: 分类工具") diff --git a/backend/apps/scan/flows/vuln_scan/endpoints_vuln_scan_flow.py b/backend/apps/scan/flows/vuln_scan/endpoints_vuln_scan_flow.py index 2607fabf..c427cadf 100644 --- a/backend/apps/scan/flows/vuln_scan/endpoints_vuln_scan_flow.py +++ b/backend/apps/scan/flows/vuln_scan/endpoints_vuln_scan_flow.py @@ -25,10 +25,7 @@ from .utils import calculate_timeout_by_line_count logger = logging.getLogger(__name__) -def _setup_vuln_scan_directory(scan_workspace_dir: str) -> Path: - vuln_scan_dir = Path(scan_workspace_dir) / "vuln_scan" - vuln_scan_dir.mkdir(parents=True, exist_ok=True) - return vuln_scan_dir + @flow( @@ -55,14 +52,14 @@ def endpoints_vuln_scan_flow( if not enabled_tools: raise ValueError("enabled_tools 不能为空") - vuln_scan_dir = _setup_vuln_scan_directory(scan_workspace_dir) + from apps.scan.utils import setup_scan_directory + vuln_scan_dir = setup_scan_directory(scan_workspace_dir, 'vuln_scan') endpoints_file = vuln_scan_dir / "input_endpoints.txt" # Step 1: 导出 Endpoint URL export_result = export_endpoints_task( target_id=target_id, output_file=str(endpoints_file), - target_name=target_name, # 传入 target_name 用于生成默认端点 ) total_endpoints = export_result.get("total_count", 0) diff --git a/backend/apps/scan/services/__init__.py b/backend/apps/scan/services/__init__.py index 1d02c1d8..e40d6ff1 100644 --- a/backend/apps/scan/services/__init__.py +++ b/backend/apps/scan/services/__init__.py @@ -17,6 +17,8 @@ from .scan_state_service import ScanStateService from .scan_control_service import ScanControlService from .scan_stats_service import ScanStatsService from .scheduled_scan_service import ScheduledScanService +from 
.blacklist_service import BlacklistService +from .target_export_service import TargetExportService __all__ = [ 'ScanService', # 主入口(向后兼容) @@ -25,5 +27,7 @@ __all__ = [ 'ScanControlService', 'ScanStatsService', 'ScheduledScanService', + 'BlacklistService', # 黑名单过滤服务 + 'TargetExportService', # 目标导出服务 ] diff --git a/backend/apps/scan/tasks/__init__.py b/backend/apps/scan/tasks/__init__.py index ad0dff60..6bf919c8 100644 --- a/backend/apps/scan/tasks/__init__.py +++ b/backend/apps/scan/tasks/__init__.py @@ -9,9 +9,6 @@ - Tasks 负责具体操作,Flow 负责编排 """ -# Prefect Tasks -from .workspace_tasks import create_scan_workspace_task - # 子域名发现任务(已重构为多个子任务) from .subdomain_discovery import ( run_subdomain_discovery_task, @@ -30,10 +27,9 @@ from .fingerprint_detect import ( # - finalize_scan_task 已废弃(Handler 统一管理状态) # - initiate_scan_task 已迁移到 flows/initiate_scan_flow.py # - cleanup_old_scans_task 已迁移到 flows(cleanup_old_scans_flow) +# - create_scan_workspace_task 已删除,直接使用 setup_scan_workspace() __all__ = [ - # Prefect Tasks - 'create_scan_workspace_task', # 子域名发现任务 'run_subdomain_discovery_task', 'merge_and_validate_task', diff --git a/backend/apps/scan/tasks/directory_scan/export_sites_task.py b/backend/apps/scan/tasks/directory_scan/export_sites_task.py index 89662052..6b3bf6fe 100644 --- a/backend/apps/scan/tasks/directory_scan/export_sites_task.py +++ b/backend/apps/scan/tasks/directory_scan/export_sites_task.py @@ -1,20 +1,14 @@ """ 导出站点 URL 到 TXT 文件的 Task -使用流式处理,避免大量站点导致内存溢出 -支持默认值模式:如果没有站点,根据 Target 类型生成默认 URL -- DOMAIN: http(s)://target_name -- IP: http(s)://ip -- CIDR: 展开为所有 IP 的 http(s)://ip +使用 TargetExportService 统一处理导出逻辑和默认值回退 +数据源: WebSite.url """ import logging -import ipaddress -from pathlib import Path from prefect import task -from apps.asset.repositories import DjangoWebSiteRepository -from apps.targets.services import TargetService -from apps.targets.models import Target +from apps.asset.models import WebSite +from apps.scan.services import TargetExportService, BlacklistService logger = logging.getLogger(__name__) @@ -24,19 +18,22 @@ def export_sites_task( target_id: int, output_file: str, batch_size: int = 1000, - target_name: str = None ) -> dict: """ 导出目标下的所有站点 URL 到 TXT 文件 - 使用流式处理,支持大规模数据导出(10万+站点) - 支持默认值模式:如果没有站点,自动使用默认站点 URL(http(s)://target_name) + 数据源: WebSite.url + + 懒加载模式: + - 如果数据库为空,根据 Target 类型生成默认 URL + - DOMAIN: http(s)://domain + - IP: http(s)://ip + - CIDR: 展开为所有 IP 的 URL Args: target_id: 目标 ID output_file: 输出文件路径(绝对路径) batch_size: 每次读取的批次大小,默认 1000 - target_name: 目标名称(用于默认值模式) Returns: dict: { @@ -49,134 +46,26 @@ def export_sites_task( ValueError: 参数错误 IOError: 文件写入失败 """ - try: - # 初始化 Repository - repository = DjangoWebSiteRepository() - - logger.info("开始导出站点 URL - Target ID: %d, 输出文件: %s", target_id, output_file) - - # 确保输出目录存在 - output_path = Path(output_file) - output_path.parent.mkdir(parents=True, exist_ok=True) - - # 使用 Repository 流式查询站点 URL - url_iterator = repository.get_urls_for_export( - target_id=target_id, - batch_size=batch_size - ) - - # 流式写入文件 - total_count = 0 - with open(output_path, 'w', encoding='utf-8', buffering=8192) as f: - for url in url_iterator: - # 每次只处理一个 URL,边读边写 - f.write(f"{url}\n") - total_count += 1 - - # 每写入 10000 条记录打印一次进度 - if total_count % 10000 == 0: - logger.info("已导出 %d 个站点 URL...", total_count) - - # ==================== 懒加载模式:根据 Target 类型生成默认 URL ==================== - if total_count == 0: - total_count = _write_default_urls(target_id, target_name, output_path) - - logger.info( - "✓ 站点 URL 导出完成 - 总数: %d, 文件: %s (%.2f KB)", - 
total_count, - str(output_path), # 使用绝对路径 - output_path.stat().st_size / 1024 - ) - - return { - 'success': True, - 'output_file': str(output_path), - 'total_count': total_count - } - - except FileNotFoundError as e: - logger.error("输出目录不存在: %s", e) - raise - except PermissionError as e: - logger.error("文件写入权限不足: %s", e) - raise - except Exception as e: - logger.exception("导出站点 URL 失败: %s", e) - raise - - -def _write_default_urls(target_id: int, target_name: str, output_path: Path) -> int: - """ - 懒加载模式:根据 Target 类型生成默认 URL + # 构建数据源 queryset(Task 层决定数据源) + queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True) - Args: - target_id: 目标 ID - target_name: 目标名称(可选,如果为空则从数据库查询) - output_path: 输出文件路径 - - Returns: - int: 生成的 URL 数量 - """ - # 获取 Target 信息 - target_service = TargetService() - target = target_service.get_target(target_id) + # 使用 TargetExportService 处理导出 + blacklist_service = BlacklistService() + export_service = TargetExportService(blacklist_service=blacklist_service) - if not target: - logger.warning("Target ID %d 不存在,无法生成默认 URL", target_id) - return 0 + result = export_service.export_urls( + target_id=target_id, + output_path=output_file, + queryset=queryset, + batch_size=batch_size + ) - target_name = target.name - target_type = target.type - - logger.info("懒加载模式:Target 类型=%s, 名称=%s", target_type, target_name) - - total_urls = 0 - - with open(output_path, 'w', encoding='utf-8', buffering=8192) as f: - if target_type == Target.TargetType.DOMAIN: - # 域名类型:生成 http(s)://domain - f.write(f"http://{target_name}\n") - f.write(f"https://{target_name}\n") - total_urls = 2 - logger.info("✓ 域名默认 URL 已写入: http(s)://%s", target_name) - - elif target_type == Target.TargetType.IP: - # IP 类型:生成 http(s)://ip - f.write(f"http://{target_name}\n") - f.write(f"https://{target_name}\n") - total_urls = 2 - logger.info("✓ IP 默认 URL 已写入: http(s)://%s", target_name) - - elif target_type == Target.TargetType.CIDR: - # CIDR 类型:展开为所有 IP 的 URL - try: - network = ipaddress.ip_network(target_name, strict=False) - - for ip in network.hosts(): # 排除网络地址和广播地址 - f.write(f"http://{ip}\n") - f.write(f"https://{ip}\n") - total_urls += 2 - - if total_urls % 10000 == 0: - logger.info("已生成 %d 个 URL...", total_urls) - - # 如果是 /32 或 /128(单个 IP),hosts() 会为空 - if total_urls == 0: - ip = str(network.network_address) - f.write(f"http://{ip}\n") - f.write(f"https://{ip}\n") - total_urls = 2 - - logger.info("✓ CIDR 默认 URL 已写入: %d 个 URL (来自 %s)", total_urls, target_name) - - except ValueError as e: - logger.error("CIDR 解析失败: %s - %s", target_name, e) - return 0 - else: - logger.warning("不支持的 Target 类型: %s", target_type) - return 0 - - return total_urls + # 保持返回值格式不变(向后兼容) + return { + 'success': result['success'], + 'output_file': result['output_file'], + 'total_count': result['total_count'] + } diff --git a/backend/apps/scan/tasks/fingerprint_detect/export_urls_task.py b/backend/apps/scan/tasks/fingerprint_detect/export_urls_task.py index 4e2fe7ea..74ff5935 100644 --- a/backend/apps/scan/tasks/fingerprint_detect/export_urls_task.py +++ b/backend/apps/scan/tasks/fingerprint_detect/export_urls_task.py @@ -2,55 +2,30 @@ 导出 URL 任务 用于指纹识别前导出目标下的 URL 到文件 -支持懒加载模式:如果数据库为空,根据 Target 类型生成默认 URL +使用 TargetExportService 统一处理导出逻辑和默认值回退 """ -import ipaddress -import importlib import logging -from pathlib import Path from prefect import task +from apps.asset.models import WebSite +from apps.scan.services import TargetExportService, BlacklistService + logger = logging.getLogger(__name__) -# 数据源映射:source → 
(module_path, model_name, url_field) -SOURCE_MODEL_MAP = { - 'website': ('apps.asset.models', 'WebSite', 'url'), - # 以后扩展: - # 'endpoint': ('apps.asset.models', 'Endpoint', 'url'), - # 'directory': ('apps.asset.models', 'Directory', 'url'), -} - - -def _get_model_class(source: str): - """ - 根据数据源类型获取 Model 类 - """ - if source not in SOURCE_MODEL_MAP: - raise ValueError(f"不支持的数据源: {source},支持的类型: {list(SOURCE_MODEL_MAP.keys())}") - - module_path, model_name, _ = SOURCE_MODEL_MAP[source] - module = importlib.import_module(module_path) - return getattr(module, model_name) - - @task(name="export_urls_for_fingerprint") def export_urls_for_fingerprint_task( target_id: int, output_file: str, - target_name: str = None, source: str = 'website', batch_size: int = 1000 ) -> dict: """ 导出目标下的 URL 到文件(用于指纹识别) - 支持多种数据源,预留扩展: - - website: WebSite 表(当前实现) - - endpoint: Endpoint 表(以后扩展) - - directory: Directory 表(以后扩展) + 数据源: WebSite.url 懒加载模式: - 如果数据库为空,根据 Target 类型生成默认 URL @@ -62,76 +37,29 @@ def export_urls_for_fingerprint_task( Args: target_id: 目标 ID output_file: 输出文件路径 - target_name: 目标名称(用于懒加载) - source: 数据源类型 + source: 数据源类型(保留参数,兼容旧调用) batch_size: 批量读取大小 Returns: dict: {'output_file': str, 'total_count': int, 'source': str} """ - from apps.targets.services import TargetService - from apps.targets.models import Target + # 构建数据源 queryset(Task 层决定数据源) + queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True) - logger.info("开始导出 URL - target_id=%s, source=%s, output=%s", target_id, source, output_file) + # 使用 TargetExportService 处理导出 + blacklist_service = BlacklistService() + export_service = TargetExportService(blacklist_service=blacklist_service) - Model = _get_model_class(source) - _, _, url_field = SOURCE_MODEL_MAP[source] - - output_path = Path(output_file) - - # 分批导出 - total_count = 0 - with open(output_path, 'w', encoding='utf-8') as f: - queryset = Model.objects.filter(target_id=target_id).values_list(url_field, flat=True) - for url in queryset.iterator(chunk_size=batch_size): - if url: - f.write(url + '\n') - total_count += 1 - - # ==================== 懒加载模式:根据 Target 类型生成默认 URL ==================== - if total_count == 0: - target_service = TargetService() - target = target_service.get_target(target_id) - - if target: - target_name = target.name - target_type = target.type - - logger.info("懒加载模式:Target 类型=%s, 名称=%s", target_type, target_name) - - with open(output_path, 'w', encoding='utf-8') as f: - if target_type == Target.TargetType.DOMAIN: - f.write(f"http://{target_name}\n") - f.write(f"https://{target_name}\n") - total_count = 2 - - elif target_type == Target.TargetType.IP: - f.write(f"http://{target_name}\n") - f.write(f"https://{target_name}\n") - total_count = 2 - - elif target_type == Target.TargetType.CIDR: - try: - network = ipaddress.ip_network(target_name, strict=False) - for ip in network.hosts(): - f.write(f"http://{ip}\n") - f.write(f"https://{ip}\n") - total_count += 2 - except ValueError as e: - logger.warning("CIDR 解析失败: %s", e) - - elif target_type == Target.TargetType.URL: - f.write(f"{target_name}\n") - total_count = 1 - - logger.info("✓ 懒加载生成默认 URL - 数量: %d", total_count) - else: - logger.warning("Target ID %d 不存在,无法生成默认 URL", target_id) - - logger.info("✓ URL 导出完成 - 数量: %d, 文件: %s", total_count, output_file) + result = export_service.export_urls( + target_id=target_id, + output_path=output_file, + queryset=queryset, + batch_size=batch_size + ) + # 保持返回值格式不变(向后兼容) return { - 'output_file': output_file, - 'total_count': total_count, + 
'output_file': result['output_file'], + 'total_count': result['total_count'], 'source': source } diff --git a/backend/apps/scan/tasks/port_scan/export_scan_targets_task.py b/backend/apps/scan/tasks/port_scan/export_scan_targets_task.py index e2d84298..bc334d0a 100644 --- a/backend/apps/scan/tasks/port_scan/export_scan_targets_task.py +++ b/backend/apps/scan/tasks/port_scan/export_scan_targets_task.py @@ -1,119 +1,21 @@ """ 导出扫描目标到 TXT 文件的 Task +使用 TargetExportService.export_targets() 统一处理导出逻辑 + 根据 Target 类型决定导出内容: - DOMAIN: 从 Subdomain 表导出子域名 - IP: 直接写入 target.name - CIDR: 展开 CIDR 范围内的所有 IP - -使用流式处理,避免大量数据导致内存溢出 """ import logging -import ipaddress -from pathlib import Path from prefect import task -from apps.asset.services.asset.subdomain_service import SubdomainService -from apps.targets.services import TargetService -from apps.targets.models import Target # 仅用于 TargetType 常量 +from apps.scan.services import TargetExportService, BlacklistService logger = logging.getLogger(__name__) -def _export_domains(target_id: int, target_name: str, output_path: Path, batch_size: int) -> int: - """ - 导出域名类型目标的子域名(支持默认值模式) - - Args: - target_id: 目标 ID - target_name: 目标名称(域名) - output_path: 输出文件路径 - batch_size: 批次大小 - - Returns: - int: 导出的记录数 - - 默认值模式: - 如果没有子域名,自动使用根域名作为默认子域名 - """ - subdomain_service = SubdomainService() - domain_iterator = subdomain_service.iter_subdomain_names_by_target( - target_id=target_id, - chunk_size=batch_size - ) - - total_count = 0 - with open(output_path, 'w', encoding='utf-8', buffering=8192) as f: - for domain_name in domain_iterator: - f.write(f"{domain_name}\n") - total_count += 1 - - if total_count % 10000 == 0: - logger.info("已导出 %d 个域名...", total_count) - - # ==================== 采用默认域名:如果没有子域名,使用根域名 ==================== - # 只写入文件供扫描工具使用,不写入数据库 - # 数据库只存储扫描发现的真实资产 - if total_count == 0: - logger.info("采用默认域名:%s (target_id=%d)", target_name, target_id) - - # 只写入文件,不写入数据库 - with open(output_path, 'w', encoding='utf-8') as f: - f.write(f"{target_name}\n") - total_count = 1 - - logger.info("✓ 默认域名已写入文件: %s", target_name) - - return total_count - - -def _export_ip(target_name: str, output_path: Path) -> int: - """ - 导出 IP 类型目标 - - Args: - target_name: IP 地址 - output_path: 输出文件路径 - - Returns: - int: 导出的记录数(始终为 1) - """ - with open(output_path, 'w', encoding='utf-8') as f: - f.write(f"{target_name}\n") - return 1 - - -def _export_cidr(target_name: str, output_path: Path) -> int: - """ - 导出 CIDR 类型目标,展开为每个 IP - - Args: - target_name: CIDR 范围(如 192.168.1.0/24) - output_path: 输出文件路径 - - Returns: - int: 导出的 IP 数量 - """ - network = ipaddress.ip_network(target_name, strict=False) - total_count = 0 - - with open(output_path, 'w', encoding='utf-8', buffering=8192) as f: - for ip in network.hosts(): # 排除网络地址和广播地址 - f.write(f"{ip}\n") - total_count += 1 - - if total_count % 10000 == 0: - logger.info("已导出 %d 个 IP...", total_count) - - # 如果是 /32 或 /128(单个 IP),hosts() 会为空,需要特殊处理 - if total_count == 0: - with open(output_path, 'w', encoding='utf-8') as f: - f.write(f"{network.network_address}\n") - total_count = 1 - - return total_count - - @task(name="export_scan_targets") def export_scan_targets_task( target_id: int, @@ -145,62 +47,20 @@ def export_scan_targets_task( ValueError: Target 不存在 IOError: 文件写入失败 """ - try: - # 1. 
通过 Service 层获取 Target - target_service = TargetService() - target = target_service.get_target(target_id) - if not target: - raise ValueError(f"Target ID {target_id} 不存在") - - target_type = target.type - target_name = target.name - - logger.info( - "开始导出扫描目标 - Target ID: %d, Name: %s, Type: %s, 输出文件: %s", - target_id, target_name, target_type, output_file - ) - - # 2. 确保输出目录存在 - output_path = Path(output_file) - output_path.parent.mkdir(parents=True, exist_ok=True) - - # 3. 根据类型导出 - if target_type == Target.TargetType.DOMAIN: - total_count = _export_domains(target_id, target_name, output_path, batch_size) - type_desc = "域名" - elif target_type == Target.TargetType.IP: - total_count = _export_ip(target_name, output_path) - type_desc = "IP" - elif target_type == Target.TargetType.CIDR: - total_count = _export_cidr(target_name, output_path) - type_desc = "CIDR IP" - else: - raise ValueError(f"不支持的目标类型: {target_type}") - - logger.info( - "✓ 扫描目标导出完成 - 类型: %s, 总数: %d, 文件: %s (%.2f KB)", - type_desc, - total_count, - str(output_path), - output_path.stat().st_size / 1024 - ) - - return { - 'success': True, - 'output_file': str(output_path), - 'total_count': total_count, - 'target_type': target_type - } - - except FileNotFoundError as e: - logger.error("输出目录不存在: %s", e) - raise - except PermissionError as e: - logger.error("文件写入权限不足: %s", e) - raise - except ValueError as e: - logger.error("参数错误: %s", e) - raise - except Exception as e: - logger.exception("导出扫描目标失败: %s", e) - raise + # 使用 TargetExportService 处理导出 + blacklist_service = BlacklistService() + export_service = TargetExportService(blacklist_service=blacklist_service) + + result = export_service.export_targets( + target_id=target_id, + output_path=output_file, + batch_size=batch_size + ) + + # 保持返回值格式不变(向后兼容) + return { + 'success': result['success'], + 'output_file': result['output_file'], + 'total_count': result['total_count'], + 'target_type': result['target_type'] + } diff --git a/backend/apps/scan/tasks/site_scan/export_site_urls_task.py b/backend/apps/scan/tasks/site_scan/export_site_urls_task.py index 6be51255..a5a07d80 100644 --- a/backend/apps/scan/tasks/site_scan/export_site_urls_task.py +++ b/backend/apps/scan/tasks/site_scan/export_site_urls_task.py @@ -2,52 +2,65 @@ 导出站点URL到文件的Task 直接使用 HostPortMapping 表查询 host+port 组合,拼接成URL格式写入文件 +使用 TargetExportService 处理默认值回退逻辑 -默认值模式: -- 如果没有 HostPortMapping 数据,写入默认 URL 到文件(不写入数据库) -- DOMAIN: http(s)://target_name -- IP: http(s)://ip -- CIDR: 展开为所有 IP 的 http(s)://ip +特殊逻辑: +- 80 端口:只生成 HTTP URL(省略端口号) +- 443 端口:只生成 HTTPS URL(省略端口号) +- 其他端口:生成 HTTP 和 HTTPS 两个URL(带端口号) """ import logging -import ipaddress from pathlib import Path from prefect import task -from typing import Optional from apps.asset.services import HostPortMappingService -from apps.targets.services import TargetService -from apps.targets.models import Target +from apps.scan.services import TargetExportService, BlacklistService logger = logging.getLogger(__name__) +def _generate_urls_from_port(host: str, port: int) -> list[str]: + """ + 根据端口生成 URL 列表 + + - 80 端口:只生成 HTTP URL(省略端口号) + - 443 端口:只生成 HTTPS URL(省略端口号) + - 其他端口:生成 HTTP 和 HTTPS 两个URL(带端口号) + """ + if port == 80: + return [f"http://{host}"] + elif port == 443: + return [f"https://{host}"] + else: + return [f"http://{host}:{port}", f"https://{host}:{port}"] + + @task(name="export_site_urls") def export_site_urls_task( target_id: int, output_file: str, - target_name: Optional[str] = None, batch_size: int = 1000 ) -> dict: """ 导出目标下的所有站点URL到文件(基于 HostPortMapping 表) - 
功能: - 1. 从 HostPortMapping 表查询 target 下所有 host+port 组合 - 2. 拼接成URL格式(标准端口80/443将省略端口号) - 3. 写入到指定文件中 + 数据源: HostPortMapping (host + port) - 默认值模式(懒加载): - - 如果没有 HostPortMapping 数据,根据 Target 类型生成默认 URL - - DOMAIN: http(s)://target_name + 特殊逻辑: + - 80 端口:只生成 HTTP URL(省略端口号) + - 443 端口:只生成 HTTPS URL(省略端口号) + - 其他端口:生成 HTTP 和 HTTPS 两个URL(带端口号) + + 懒加载模式: + - 如果数据库为空,根据 Target 类型生成默认 URL + - DOMAIN: http(s)://domain - IP: http(s)://ip - - CIDR: 展开为所有 IP 的 http(s)://ip + - CIDR: 展开为所有 IP 的 URL Args: target_id: 目标ID output_file: 输出文件路径(绝对路径) - target_name: 目标名称(用于懒加载时写入默认值) - batch_size: 每次处理的批次大小,默认1000(暂未使用,预留) + batch_size: 每次处理的批次大小 Returns: dict: { @@ -61,155 +74,54 @@ def export_site_urls_task( ValueError: 参数错误 IOError: 文件写入失败 """ - try: - logger.info("开始统计站点URL - Target ID: %d, 输出文件: %s", target_id, output_file) - - # 确保输出目录存在 - output_path = Path(output_file) - output_path.parent.mkdir(parents=True, exist_ok=True) - - # 直接查询 HostPortMapping 表,按 host 排序 - service = HostPortMappingService() - associations = service.iter_host_port_by_target( - target_id=target_id, - batch_size=batch_size, - ) - - total_urls = 0 - association_count = 0 - - # 流式写入文件 - with open(output_path, 'w', encoding='utf-8', buffering=8192) as f: - for assoc in associations: - association_count += 1 - host = assoc['host'] - port = assoc['port'] - - # 根据端口号生成URL - # 80 端口:只生成 HTTP URL(省略端口号) - # 443 端口:只生成 HTTPS URL(省略端口号) - # 其他端口:生成 HTTP 和 HTTPS 两个URL(带端口号) - if port == 80: - # HTTP 标准端口,省略端口号 - url = f"http://{host}" - f.write(f"{url}\n") - total_urls += 1 - elif port == 443: - # HTTPS 标准端口,省略端口号 - url = f"https://{host}" - f.write(f"{url}\n") - total_urls += 1 - else: - # 非标准端口,生成 HTTP 和 HTTPS 两个URL - http_url = f"http://{host}:{port}" - https_url = f"https://{host}:{port}" - f.write(f"{http_url}\n") - f.write(f"{https_url}\n") - total_urls += 2 - - # 每处理1000条记录打印一次进度 - if association_count % 1000 == 0: - logger.info("已处理 %d 条关联,生成 %d 个URL...", association_count, total_urls) - - logger.info( - "✓ 站点URL导出完成 - 关联数: %d, 总URL数: %d, 文件: %s (%.2f KB)", - association_count, - total_urls, - str(output_path), - output_path.stat().st_size / 1024 - ) - - # ==================== 懒加载模式:根据 Target 类型生成默认 URL ==================== - if total_urls == 0: - total_urls = _write_default_urls(target_id, target_name, output_path) - - return { - 'success': True, - 'output_file': str(output_path), - 'total_urls': total_urls, - 'association_count': association_count - } - - except FileNotFoundError as e: - logger.error("输出目录不存在: %s", e) - raise - except PermissionError as e: - logger.error("文件写入权限不足: %s", e) - raise - except Exception as e: - logger.exception("导出站点URL失败: %s", e) - raise - - -def _write_default_urls(target_id: int, target_name: Optional[str], output_path: Path) -> int: - """ - 懒加载模式:根据 Target 类型生成默认 URL + logger.info("开始统计站点URL - Target ID: %d, 输出文件: %s", target_id, output_file) - Args: - target_id: 目标 ID - target_name: 目标名称(可选,如果为空则从数据库查询) - output_path: 输出文件路径 - - Returns: - int: 生成的 URL 数量 - """ - # 获取 Target 信息 - target_service = TargetService() - target = target_service.get_target(target_id) + # 确保输出目录存在 + output_path = Path(output_file) + output_path.parent.mkdir(parents=True, exist_ok=True) - if not target: - logger.warning("Target ID %d 不存在,无法生成默认 URL", target_id) - return 0 + # 初始化黑名单服务 + blacklist_service = BlacklistService() - target_name = target.name - target_type = target.type - - logger.info("懒加载模式:Target 类型=%s, 名称=%s", target_type, target_name) + # 直接查询 HostPortMapping 表,按 host 排序 + service = 
HostPortMappingService() + associations = service.iter_host_port_by_target( + target_id=target_id, + batch_size=batch_size, + ) total_urls = 0 + association_count = 0 + # 流式写入文件(特殊端口逻辑) with open(output_path, 'w', encoding='utf-8', buffering=8192) as f: - if target_type == Target.TargetType.DOMAIN: - # 域名类型:生成 http(s)://domain - f.write(f"http://{target_name}\n") - f.write(f"https://{target_name}\n") - total_urls = 2 - logger.info("✓ 域名默认 URL 已写入: http(s)://%s", target_name) + for assoc in associations: + association_count += 1 + host = assoc['host'] + port = assoc['port'] - elif target_type == Target.TargetType.IP: - # IP 类型:生成 http(s)://ip - f.write(f"http://{target_name}\n") - f.write(f"https://{target_name}\n") - total_urls = 2 - logger.info("✓ IP 默认 URL 已写入: http(s)://%s", target_name) + # 根据端口号生成URL + for url in _generate_urls_from_port(host, port): + if blacklist_service.filter_url(url): + f.write(f"{url}\n") + total_urls += 1 - elif target_type == Target.TargetType.CIDR: - # CIDR 类型:展开为所有 IP 的 URL - try: - network = ipaddress.ip_network(target_name, strict=False) - - for ip in network.hosts(): # 排除网络地址和广播地址 - f.write(f"http://{ip}\n") - f.write(f"https://{ip}\n") - total_urls += 2 - - if total_urls % 10000 == 0: - logger.info("已生成 %d 个 URL...", total_urls) - - # 如果是 /32 或 /128(单个 IP),hosts() 会为空 - if total_urls == 0: - ip = str(network.network_address) - f.write(f"http://{ip}\n") - f.write(f"https://{ip}\n") - total_urls = 2 - - logger.info("✓ CIDR 默认 URL 已写入: %d 个 URL (来自 %s)", total_urls, target_name) - - except ValueError as e: - logger.error("CIDR 解析失败: %s - %s", target_name, e) - return 0 - else: - logger.warning("不支持的 Target 类型: %s", target_type) - return 0 + if association_count % 1000 == 0: + logger.info("已处理 %d 条关联,生成 %d 个URL...", association_count, total_urls) - return total_urls + logger.info( + "✓ 站点URL导出完成 - 关联数: %d, 总URL数: %d, 文件: %s", + association_count, total_urls, str(output_path) + ) + + # 默认值回退模式:使用 TargetExportService + if total_urls == 0: + export_service = TargetExportService(blacklist_service=blacklist_service) + total_urls = export_service._generate_default_urls(target_id, output_path) + + return { + 'success': True, + 'output_file': str(output_path), + 'total_urls': total_urls, + 'association_count': association_count + } diff --git a/backend/apps/scan/tasks/url_fetch/export_sites_task.py b/backend/apps/scan/tasks/url_fetch/export_sites_task.py index cd75ff2b..b562293f 100644 --- a/backend/apps/scan/tasks/url_fetch/export_sites_task.py +++ b/backend/apps/scan/tasks/url_fetch/export_sites_task.py @@ -1,25 +1,16 @@ """ 导出站点 URL 列表任务 -从 WebSite 表导出站点 URL 列表到文件(用于 katana 等爬虫工具) - -使用流式写入,避免内存溢出 - -懒加载模式: -- 如果 WebSite 表为空,根据 Target 类型生成默认 URL -- DOMAIN: 写入 http(s)://domain -- IP: 写入 http(s)://ip -- CIDR: 展开为所有 IP +使用 TargetExportService 统一处理导出逻辑和默认值回退 +数据源: WebSite.url(用于 katana 等爬虫工具) """ import logging -import ipaddress -from pathlib import Path from prefect import task from typing import Optional -from apps.targets.services import TargetService -from apps.targets.models import Target +from apps.asset.models import WebSite +from apps.scan.services import TargetExportService, BlacklistService logger = logging.getLogger(__name__) @@ -33,21 +24,23 @@ def export_sites_task( output_file: str, target_id: int, scan_id: int, - target_name: Optional[str] = None, batch_size: int = 1000 ) -> dict: """ 导出站点 URL 列表到文件(用于 katana 等爬虫工具) + 数据源: WebSite.url + 懒加载模式: - - 如果 WebSite 表为空,根据 Target 类型生成默认 URL - - 数据库只存储"真实发现"的资产 + - 如果数据库为空,根据 Target 类型生成默认 URL + - DOMAIN: 
http(s)://domain + - IP: http(s)://ip + - CIDR: 展开为所有 IP 的 URL Args: output_file: 输出文件路径 target_id: 目标 ID - scan_id: 扫描 ID - target_name: 目标名称(用于懒加载时写入默认值) + scan_id: 扫描 ID(保留参数,兼容旧调用) batch_size: 批次大小(内存优化) Returns: @@ -60,109 +53,22 @@ def export_sites_task( ValueError: 参数错误 RuntimeError: 执行失败 """ - try: - logger.info("开始导出站点 URL 列表 - Target ID: %d", target_id) - - # 确保输出目录存在 - output_path = Path(output_file) - output_path.parent.mkdir(parents=True, exist_ok=True) - - # 从 WebSite 表导出站点 URL - from apps.asset.services import WebSiteService - - website_service = WebSiteService() - - # 流式写入文件 - asset_count = 0 - with open(output_path, 'w') as f: - for url in website_service.iter_website_urls_by_target(target_id, batch_size): - f.write(f"{url}\n") - asset_count += 1 - - if asset_count % batch_size == 0: - f.flush() - - # ==================== 懒加载模式:根据 Target 类型生成默认 URL ==================== - if asset_count == 0: - asset_count = _write_default_urls(target_id, target_name, output_path) - - logger.info("✓ 站点 URL 导出完成 - 文件: %s, 数量: %d", output_file, asset_count) - - return { - 'output_file': output_file, - 'asset_count': asset_count, - } - - except Exception as e: - logger.error("导出站点 URL 失败: %s", e, exc_info=True) - raise RuntimeError(f"导出站点 URL 失败: {e}") from e - - -def _write_default_urls(target_id: int, target_name: Optional[str], output_path: Path) -> int: - """ - 懒加载模式:根据 Target 类型生成默认 URL 列表 + # 构建数据源 queryset(Task 层决定数据源) + queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True) - Args: - target_id: 目标 ID - target_name: 目标名称 - output_path: 输出文件路径 - - Returns: - int: 生成的 URL 数量 - """ - target_service = TargetService() - target = target_service.get_target(target_id) + # 使用 TargetExportService 处理导出 + blacklist_service = BlacklistService() + export_service = TargetExportService(blacklist_service=blacklist_service) - if not target: - logger.warning("Target ID %d 不存在,无法生成默认 URL", target_id) - return 0 + result = export_service.export_urls( + target_id=target_id, + output_path=output_file, + queryset=queryset, + batch_size=batch_size + ) - target_name = target.name - target_type = target.type - - logger.info("懒加载模式:Target 类型=%s, 名称=%s", target_type, target_name) - - total_urls = 0 - - with open(output_path, 'w', encoding='utf-8', buffering=8192) as f: - if target_type == Target.TargetType.DOMAIN: - f.write(f"http://{target_name}\n") - f.write(f"https://{target_name}\n") - total_urls = 2 - logger.info("✓ 域名默认 URL 已写入: http(s)://%s", target_name) - - elif target_type == Target.TargetType.IP: - f.write(f"http://{target_name}\n") - f.write(f"https://{target_name}\n") - total_urls = 2 - logger.info("✓ IP 默认 URL 已写入: http(s)://%s", target_name) - - elif target_type == Target.TargetType.CIDR: - try: - network = ipaddress.ip_network(target_name, strict=False) - - for ip in network.hosts(): - f.write(f"http://{ip}\n") - f.write(f"https://{ip}\n") - total_urls += 2 - - if total_urls % 10000 == 0: - logger.info("已生成 %d 个 URL...", total_urls) - - # /32 或 /128 特殊处理 - if total_urls == 0: - ip = str(network.network_address) - f.write(f"http://{ip}\n") - f.write(f"https://{ip}\n") - total_urls = 2 - - logger.info("✓ CIDR 默认 URL 已写入: %d 个 URL (来自 %s)", total_urls, target_name) - - except ValueError as e: - logger.error("CIDR 解析失败: %s - %s", target_name, e) - return 0 - else: - logger.warning("不支持的 Target 类型: %s", target_type) - return 0 - - return total_urls + # 保持返回值格式不变(向后兼容) + return { + 'output_file': result['output_file'], + 'asset_count': result['total_count'], + } diff --git 
a/backend/apps/scan/tasks/vuln_scan/export_endpoints_task.py b/backend/apps/scan/tasks/vuln_scan/export_endpoints_task.py index 8cf4425b..483b6874 100644 --- a/backend/apps/scan/tasks/vuln_scan/export_endpoints_task.py +++ b/backend/apps/scan/tasks/vuln_scan/export_endpoints_task.py @@ -1,25 +1,16 @@ """导出 Endpoint URL 到文件的 Task -基于 EndpointService.iter_endpoint_urls_by_target 按目标流式导出端点 URL, -用于漏洞扫描(如 Dalfox XSS)的输入文件生成。 - -默认值模式: -- 如果没有 Endpoint,根据 Target 类型生成默认 URL -- DOMAIN: http(s)://target_name -- IP: http(s)://ip -- CIDR: 展开为所有 IP 的 http(s)://ip +使用 TargetExportService 统一处理导出逻辑和默认值回退 +数据源: Endpoint.url """ import logging -import ipaddress -from pathlib import Path from typing import Dict, Optional from prefect import task -from apps.asset.services import EndpointService -from apps.targets.services import TargetService -from apps.targets.models import Target +from apps.asset.models import Endpoint +from apps.scan.services import TargetExportService, BlacklistService logger = logging.getLogger(__name__) @@ -29,17 +20,21 @@ def export_endpoints_task( target_id: int, output_file: str, batch_size: int = 1000, - target_name: Optional[str] = None, ) -> Dict[str, object]: """导出目标下的所有 Endpoint URL 到文本文件。 - 默认值模式:如果没有 Endpoint,根据 Target 类型生成默认 URL + 数据源: Endpoint.url + + 懒加载模式: + - 如果数据库为空,根据 Target 类型生成默认 URL + - DOMAIN: http(s)://domain + - IP: http(s)://ip + - CIDR: 展开为所有 IP 的 URL Args: target_id: 目标 ID output_file: 输出文件路径(绝对路径) batch_size: 每次从数据库迭代的批大小 - target_name: 目标名称(用于默认值模式) Returns: dict: { @@ -48,117 +43,23 @@ def export_endpoints_task( "total_count": int, } """ - try: - logger.info("开始导出 Endpoint URL - Target ID: %d, 输出文件: %s", target_id, output_file) - - output_path = Path(output_file) - output_path.parent.mkdir(parents=True, exist_ok=True) - - service = EndpointService() - url_iterator = service.iter_endpoint_urls_by_target(target_id, chunk_size=batch_size) - - total_count = 0 - with open(output_path, "w", encoding="utf-8", buffering=8192) as f: - for url in url_iterator: - f.write(f"{url}\n") - total_count += 1 - - if total_count % 10000 == 0: - logger.info("已导出 %d 个 Endpoint URL...", total_count) - - # ==================== 懒加载模式:根据 Target 类型生成默认 URL ==================== - if total_count == 0: - total_count = _write_default_urls(target_id, target_name, output_path) - - logger.info( - "✓ Endpoint URL 导出完成 - 总数: %d, 文件: %s (%.2f KB)", - total_count, - str(output_path), - output_path.stat().st_size / 1024, - ) - - return { - "success": True, - "output_file": str(output_path), - "total_count": total_count, - } - - except FileNotFoundError as e: - logger.error("输出目录不存在: %s", e) - raise - except PermissionError as e: - logger.error("文件写入权限不足: %s", e) - raise - except Exception as e: - logger.exception("导出 Endpoint URL 失败: %s", e) - raise - - -def _write_default_urls(target_id: int, target_name: Optional[str], output_path: Path) -> int: - """ - 懒加载模式:根据 Target 类型生成默认 URL + # 构建数据源 queryset(Task 层决定数据源) + queryset = Endpoint.objects.filter(target_id=target_id).values_list('url', flat=True) - Args: - target_id: 目标 ID - target_name: 目标名称(可选,如果为空则从数据库查询) - output_path: 输出文件路径 - - Returns: - int: 生成的 URL 数量 - """ - target_service = TargetService() - target = target_service.get_target(target_id) + # 使用 TargetExportService 处理导出 + blacklist_service = BlacklistService() + export_service = TargetExportService(blacklist_service=blacklist_service) - if not target: - logger.warning("Target ID %d 不存在,无法生成默认 URL", target_id) - return 0 + result = export_service.export_urls( + target_id=target_id, + 
output_path=output_file, + queryset=queryset, + batch_size=batch_size + ) - target_name = target.name - target_type = target.type - - logger.info("懒加载模式:Target 类型=%s, 名称=%s", target_type, target_name) - - total_urls = 0 - - with open(output_path, 'w', encoding='utf-8', buffering=8192) as f: - if target_type == Target.TargetType.DOMAIN: - f.write(f"http://{target_name}\n") - f.write(f"https://{target_name}\n") - total_urls = 2 - logger.info("✓ 域名默认 URL 已写入: http(s)://%s", target_name) - - elif target_type == Target.TargetType.IP: - f.write(f"http://{target_name}\n") - f.write(f"https://{target_name}\n") - total_urls = 2 - logger.info("✓ IP 默认 URL 已写入: http(s)://%s", target_name) - - elif target_type == Target.TargetType.CIDR: - try: - network = ipaddress.ip_network(target_name, strict=False) - - for ip in network.hosts(): - f.write(f"http://{ip}\n") - f.write(f"https://{ip}\n") - total_urls += 2 - - if total_urls % 10000 == 0: - logger.info("已生成 %d 个 URL...", total_urls) - - # /32 或 /128 特殊处理 - if total_urls == 0: - ip = str(network.network_address) - f.write(f"http://{ip}\n") - f.write(f"https://{ip}\n") - total_urls = 2 - - logger.info("✓ CIDR 默认 URL 已写入: %d 个 URL (来自 %s)", total_urls, target_name) - - except ValueError as e: - logger.error("CIDR 解析失败: %s - %s", target_name, e) - return 0 - else: - logger.warning("不支持的 Target 类型: %s", target_type) - return 0 - - return total_urls + # 保持返回值格式不变(向后兼容) + return { + "success": result['success'], + "output_file": result['output_file'], + "total_count": result['total_count'], + } diff --git a/backend/apps/scan/tasks/workspace_tasks.py b/backend/apps/scan/tasks/workspace_tasks.py deleted file mode 100644 index 15880523..00000000 --- a/backend/apps/scan/tasks/workspace_tasks.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -工作空间相关的 Prefect Tasks - -负责扫描工作空间的创建、验证和管理 -""" - -from pathlib import Path -from prefect import task -import logging - -logger = logging.getLogger(__name__) - - -@task( - name="create_scan_workspace", - description="创建并验证 Scan 工作空间目录", - retries=2, - retry_delay_seconds=5 -) -def create_scan_workspace_task(scan_workspace_dir: str) -> Path: - """ - 创建并验证 Scan 工作空间目录 - - Args: - scan_workspace_dir: Scan 工作空间目录路径 - - Returns: - Path: 创建的 Scan 工作空间路径对象 - - Raises: - OSError: 目录创建失败或不可写 - """ - scan_workspace_path = Path(scan_workspace_dir) - - # 创建目录 - try: - scan_workspace_path.mkdir(parents=True, exist_ok=True) - logger.info("✓ Scan 工作空间已创建: %s", scan_workspace_path) - except OSError as e: - logger.error("创建 Scan 工作空间失败: %s - %s", scan_workspace_dir, e) - raise - - # 验证目录是否可写 - test_file = scan_workspace_path / ".test_write" - try: - test_file.touch() - test_file.unlink() - logger.info("✓ Scan 工作空间验证通过(可写): %s", scan_workspace_path) - except OSError as e: - error_msg = f"Scan 工作空间不可写: {scan_workspace_path}" - logger.error(error_msg) - raise OSError(error_msg) from e - - return scan_workspace_path diff --git a/backend/apps/scan/utils/__init__.py b/backend/apps/scan/utils/__init__.py index e3daaae9..3cdb64d5 100644 --- a/backend/apps/scan/utils/__init__.py +++ b/backend/apps/scan/utils/__init__.py @@ -10,11 +10,15 @@ from .command_executor import execute_and_wait, execute_stream from .wordlist_helpers import ensure_wordlist_local from .nuclei_helpers import ensure_nuclei_templates_local from .performance import FlowPerformanceTracker, CommandPerformanceTracker +from .workspace_utils import setup_scan_workspace, setup_scan_directory from . 
import config_parser __all__ = [ # 目录清理 'remove_directory', + # 工作空间 + 'setup_scan_workspace', # 创建 Scan 根工作空间 + 'setup_scan_directory', # 创建扫描子目录 # 命令构建 'build_scan_command', # 扫描工具命令构建(基于 f-string) # 命令执行
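
Note on the replacement helpers: the per-flow _setup_*_directory functions deleted above, together with create_scan_workspace_task, are consolidated into apps/scan/utils/workspace_utils.py, which this patch imports but does not include. The following is a minimal sketch of what that module plausibly contains, reconstructed from the deleted helpers; the exact error messages and logging are assumptions, not the actual implementation.

# workspace_utils.py (sketch, not part of this patch)
import os
import logging
from pathlib import Path

logger = logging.getLogger(__name__)


def setup_scan_workspace(scan_workspace_dir: str) -> Path:
    """Create and verify the Scan root workspace (replaces create_scan_workspace_task)."""
    workspace = Path(scan_workspace_dir)
    workspace.mkdir(parents=True, exist_ok=True)
    if not workspace.is_dir() or not os.access(workspace, os.W_OK):
        raise RuntimeError(f"Scan workspace is not a writable directory: {workspace}")
    logger.info("Scan workspace ready: %s", workspace)
    return workspace


def setup_scan_directory(scan_workspace_dir: str, sub_dir: str) -> Path:
    """Create and verify a per-stage sub-directory, e.g. 'port_scan' or 'url_fetch'."""
    scan_dir = Path(scan_workspace_dir) / sub_dir
    scan_dir.mkdir(parents=True, exist_ok=True)
    if not scan_dir.is_dir():
        raise RuntimeError(f"Failed to create scan directory: {scan_dir}")
    if not os.access(scan_dir, os.W_OK):
        raise RuntimeError(f"Scan directory is not writable: {scan_dir}")
    return scan_dir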
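
Similarly, the export tasks now delegate to TargetExportService and BlacklistService from apps/scan/services, neither of which appears in this patch. Judging from the call sites above (export_urls(target_id=..., output_path=..., queryset=..., batch_size=...) returning success / output_file / total_count, filter_url() returning True for URLs that should be kept, and _generate_default_urls() as the empty-database fallback) and from the deleted per-task implementations, the URL export path might look roughly like the sketch below; method bodies and details are assumptions, not the actual service code.

# target_export_service.py (sketch, not part of this patch)
import ipaddress
import logging
from pathlib import Path

from apps.targets.models import Target
from apps.targets.services import TargetService

logger = logging.getLogger(__name__)


class TargetExportService:
    """Streams URLs to a text file, applies blacklist filtering, and falls back
    to default URLs derived from the Target when the source table is empty."""

    def __init__(self, blacklist_service=None):
        self.blacklist_service = blacklist_service

    def export_urls(self, target_id: int, output_path: str, queryset, batch_size: int = 1000) -> dict:
        path = Path(output_path)
        path.parent.mkdir(parents=True, exist_ok=True)

        total_count = 0
        with open(path, 'w', encoding='utf-8', buffering=8192) as f:
            for url in queryset.iterator(chunk_size=batch_size):
                if not url:
                    continue
                if self.blacklist_service and not self.blacklist_service.filter_url(url):
                    continue  # drop blacklisted URLs
                f.write(f"{url}\n")
                total_count += 1

        # Lazy-load fallback: nothing in the database, derive defaults from the Target
        if total_count == 0:
            total_count = self._generate_default_urls(target_id, path)

        return {'success': True, 'output_file': str(path), 'total_count': total_count}

    def _generate_default_urls(self, target_id: int, output_path: Path) -> int:
        """DOMAIN/IP -> http(s)://name, CIDR -> expand hosts, URL -> name as-is."""
        target = TargetService().get_target(target_id)
        if not target:
            logger.warning("Target ID %d does not exist, cannot generate default URLs", target_id)
            return 0

        total = 0
        with open(output_path, 'w', encoding='utf-8') as f:
            if target.type in (Target.TargetType.DOMAIN, Target.TargetType.IP):
                f.write(f"http://{target.name}\nhttps://{target.name}\n")
                total = 2
            elif target.type == Target.TargetType.CIDR:
                network = ipaddress.ip_network(target.name, strict=False)
                for ip in network.hosts():
                    f.write(f"http://{ip}\nhttps://{ip}\n")
                    total += 2
                if total == 0:  # /32 or /128: hosts() is empty
                    ip = network.network_address
                    f.write(f"http://{ip}\nhttps://{ip}\n")
                    total = 2
            elif target.type == Target.TargetType.URL:
                f.write(f"{target.name}\n")
                total = 1
        return total

The port-scan path (export_targets()) presumably follows the same pattern with the Subdomain table as its source and the root domain as the empty-database fallback, but it is likewise not visible in this diff.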