Compare commits

...

7 Commits

Author SHA1 Message Date
yyhuni
2d2ec93626 perf(screenshot): optimize memory usage and add URL collection fallback logic
- Add iterator(chunk_size=50) to ScreenshotSnapshot query to prevent BinaryField data caching and reduce memory consumption
- Implement fallback logic in URL collection: WebSite → HostPortMapping → Default URL with priority handling
- Update _collect_urls_from_provider to return tuple with data source information for better logging and debugging
- Add detailed logging to track which data source was used during URL collection
- Improve code documentation with clear return type hints and fallback priority explanation
- Prevents memory spikes when processing large screenshot datasets with binary image data
2026-01-11 16:14:56 +08:00
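A minimal sketch of the iterator pattern this commit applies, reusing the real model import from the diff below (the helper name is illustrative, not the actual service method):
from apps.asset.models import ScreenshotSnapshot

def count_snapshots(scan_id: int) -> int:
    # iterator(chunk_size=50) streams rows from the database cursor in batches
    # of 50 instead of materializing and caching the whole result set (and its
    # binary image payloads) in the QuerySet.
    count = 0
    for snapshot in ScreenshotSnapshot.objects.filter(scan_id=scan_id).iterator(chunk_size=50):
        count += 1  # handle one snapshot at a time; each batch is freed afterwards
    return count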
github-actions[bot]
ced9f811f4 chore: bump version to v1.5.8-dev 2026-01-11 08:09:37 +00:00
yyhuni
aa99b26f50 fix(vuln_scan): use tool-specific parameter names for endpoint scanning
- Add conditional logic to use "input_file" parameter for nuclei tool
- Use "endpoints_file" parameter for other scanning tools
- Improve compatibility with different vulnerability scanning tools
- Ensure correct parameter naming based on tool requirements
2026-01-11 15:59:39 +08:00
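A minimal sketch of the parameter-name selection described above (the helper name is hypothetical; the real branch lives in endpoints_vuln_scan_flow, shown in the diffs below):
def build_command_params(tool_name: str, endpoints_file: str) -> dict:
    # nuclei's command template expects 'input_file'; other tools keep 'endpoints_file'
    if tool_name == "nuclei":
        return {"input_file": endpoints_file}
    return {"endpoints_file": endpoints_file}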
yyhuni
8342f196db nuclei: add website scanning as the default 2026-01-11 12:13:27 +08:00
yyhuni
1bd2a6ed88 refactor: complete the provider implementation 2026-01-11 11:15:59 +08:00
yyhuni
033ff89aee refactor: adopt providers to supply data 2026-01-11 10:29:27 +08:00
yyhuni
4284a0cd9a refactor(scan): remove deprecated provider implementations and cleanup
- Delete ListTargetProvider implementation and related tests
- Delete PipelineTargetProvider implementation and related tests
- Remove target_export_service.py unused service module
- Remove test files for common properties validation
- Update engine-preset-selector component in frontend
- Remove sponsor acknowledgment section from README
- Simplify provider architecture by consolidating implementations
2026-01-10 23:53:52 +08:00
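The provider surface this consolidation converges on, inferred from the call sites in the diffs below (a sketch of the assumed interface, not the verbatim source):
from typing import Iterator, Optional, Protocol

class TargetProvider(Protocol):
    # Methods inferred from usage in the flows below; exact signatures are assumptions
    def get_target_name(self) -> Optional[str]: ...
    def iter_urls(self) -> Iterator[str]: ...
    def iter_websites(self) -> Iterator[str]: ...
    def iter_host_port_urls(self) -> Iterator[str]: ...
    def iter_default_urls(self) -> Iterator[str]: ...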
60 changed files with 2192 additions and 3688 deletions

View File

@@ -289,12 +289,6 @@ sudo ./uninstall.sh
<img src="docs/zfb_pay.jpg" alt="Alipay" width="200">
</p>
### Thanks to the following sponsors
| Nickname | Amount |
|------|------|
| X闭关中 | ¥88 |
## Disclaimer

View File

@@ -1 +1 @@
v1.5.7
v1.5.8-dev

View File

@@ -195,3 +195,32 @@ class DjangoHostPortMappingSnapshotRepository:
for row in qs.iterator(chunk_size=batch_size):
yield row
def iter_unique_host_ports_by_scan(
self,
scan_id: int,
batch_size: int = 1000
) -> Iterator[dict]:
"""
Stream the unique host:port combinations under a scan (deduplicated).
Used to avoid duplicates when generating URLs: the same host:port may map
to multiple IPs, but only one URL needs to be generated.
Args:
scan_id: scan ID
batch_size: rows per batch
Yields:
{'host': 'example.com', 'port': 80}
"""
qs = (
HostPortMappingSnapshot.objects
.filter(scan_id=scan_id)
.values('host', 'port')
.distinct()
.order_by('host', 'port')
)
for row in qs.iterator(chunk_size=batch_size):
yield row
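A hypothetical consumer of the method above, turning deduplicated host:port rows into candidate URLs (the scheme heuristic is an assumption, not part of this diff):
def iter_candidate_urls(repo: DjangoHostPortMappingSnapshotRepository, scan_id: int):
    for row in repo.iter_unique_host_ports_by_scan(scan_id=scan_id):
        scheme = "https" if row["port"] == 443 else "http"  # assumed scheme choice
        yield f"{scheme}://{row['host']}:{row['port']}"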

View File

@@ -146,7 +146,9 @@ class ScreenshotService:
"""
from apps.asset.models import Screenshot, ScreenshotSnapshot
snapshots = ScreenshotSnapshot.objects.filter(scan_id=scan_id)
# Use iterator() so the QuerySet does not cache large BinaryField payloads and spike memory
# chunk_size=50: load only 50 rows at a time and release them once processed
snapshots = ScreenshotSnapshot.objects.filter(scan_id=scan_id).iterator(chunk_size=50)
count = 0
for snapshot in snapshots:

View File

@@ -1,72 +1,18 @@
"""Endpoint Snapshots Service - 业务逻辑层"""
import logging
from typing import List, Iterator
from typing import Iterator, List, Optional
from apps.asset.dtos.snapshot import EndpointSnapshotDTO
from apps.asset.repositories.snapshot import DjangoEndpointSnapshotRepository
from apps.asset.services.asset import EndpointService
from apps.asset.dtos.snapshot import EndpointSnapshotDTO
logger = logging.getLogger(__name__)
class EndpointSnapshotsService:
"""端点快照服务 - 统一管理快照和资产同步"""
def __init__(self):
self.snapshot_repo = DjangoEndpointSnapshotRepository()
self.asset_service = EndpointService()
def save_and_sync(self, items: List[EndpointSnapshotDTO]) -> None:
"""
Save endpoint snapshots and sync them to the asset table (unified entry point).
Flow:
1. Save to the snapshot table (full records)
2. Sync to the asset table (deduplicated)
Args:
items: list of endpoint snapshot DTOs (must include target_id)
Raises:
ValueError: if target_id is None for any item
Exception: database operation failure
"""
if not items:
return
# Check that the Scan still exists (guards against racy writes after deletion)
scan_id = items[0].scan_id
from apps.scan.repositories import DjangoScanRepository
if not DjangoScanRepository().exists(scan_id):
logger.warning("Scan deleted, skipping endpoint snapshot save - scan_id=%s, count=%d", scan_id, len(items))
return
try:
logger.debug("Saving endpoint snapshots and syncing to asset table - count: %d", len(items))
# Step 1: save to the snapshot table
logger.debug("Step 1: save to the snapshot table")
self.snapshot_repo.save_snapshots(items)
# Step 2: convert to asset DTOs and save to the asset table
# Uses upsert (insert new records, update existing ones)
logger.debug("Step 2: sync to the asset table (via the service layer)")
asset_items = [item.to_asset_dto() for item in items]
self.asset_service.bulk_upsert(asset_items)
logger.info("Endpoint snapshots and asset data saved - count: %d", len(items))
except Exception as e:
logger.error(
"Failed to save endpoint snapshots - count: %d, error: %s",
len(items),
str(e),
exc_info=True
)
raise
# Field mapping for smart filtering
FILTER_FIELD_MAPPING = {
'url': 'url',
@@ -76,26 +22,89 @@ class EndpointSnapshotsService:
'webserver': 'webserver',
'tech': 'tech',
}
def get_by_scan(self, scan_id: int, filter_query: str = None):
def __init__(self):
self.snapshot_repo = DjangoEndpointSnapshotRepository()
self.asset_service = EndpointService()
def save_and_sync(self, items: List[EndpointSnapshotDTO]) -> None:
"""
Save endpoint snapshots and sync them to the asset table (unified entry point).
Flow:
1. Save to the snapshot table (full records)
2. Sync to the asset table (deduplicated)
Args:
items: list of endpoint snapshot DTOs (must include target_id)
Raises:
ValueError: if target_id is None for any item
Exception: database operation failure
"""
if not items:
return
# Check that the Scan still exists (guards against racy writes after deletion)
scan_id = items[0].scan_id
from apps.scan.repositories import DjangoScanRepository
if not DjangoScanRepository().exists(scan_id):
logger.warning("Scan deleted, skipping endpoint snapshot save - scan_id=%s, count=%d", scan_id, len(items))
return
try:
logger.debug("Saving endpoint snapshots and syncing to asset table - count: %d", len(items))
# Step 1: save to the snapshot table
self.snapshot_repo.save_snapshots(items)
# Step 2: convert to asset DTOs and upsert into the asset table
asset_items = [item.to_asset_dto() for item in items]
self.asset_service.bulk_upsert(asset_items)
logger.info("Endpoint snapshots and asset data saved - count: %d", len(items))
except Exception as e:
logger.error("Failed to save endpoint snapshots - count: %d, error: %s", len(items), str(e), exc_info=True)
raise
def get_by_scan(self, scan_id: int, filter_query: Optional[str] = None):
"""
Get endpoint snapshots for the given scan.
Args:
scan_id: scan ID
filter_query: filter query string
Returns:
QuerySet: endpoint snapshot queryset
"""
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_by_scan(scan_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_all(self, filter_query: str = None):
"""获取所有端点快照"""
def get_all(self, filter_query: Optional[str] = None):
"""
Get all endpoint snapshots.
Args:
filter_query: filter query string
Returns:
QuerySet: endpoint snapshot queryset
"""
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def iter_endpoint_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取某次扫描下的所有端点 URL"""
"""流式获取某次扫描下的所有端点 URL"""
queryset = self.snapshot_repo.get_by_scan(scan_id)
for snapshot in queryset.iterator(chunk_size=chunk_size):
yield snapshot.url
@@ -103,10 +112,10 @@ class EndpointSnapshotsService:
def iter_raw_data_for_csv_export(self, scan_id: int) -> Iterator[dict]:
"""
Stream raw data for CSV export.
Args:
scan_id: scan ID
Yields:
raw data dicts
"""

View File

@@ -91,3 +91,25 @@ class HostPortMappingSnapshotsService:
Raw data dicts {ip, host, port, created_at}
"""
return self.snapshot_repo.iter_raw_data_for_export(scan_id=scan_id)
def iter_unique_host_ports_by_scan(
self,
scan_id: int,
batch_size: int = 1000
) -> Iterator[dict]:
"""
Stream the unique host:port combinations under a scan (deduplicated).
Used to avoid duplicates when generating URLs.
Args:
scan_id: scan ID
batch_size: rows per batch
Yields:
{'host': 'example.com', 'port': 80}
"""
return self.snapshot_repo.iter_unique_host_ports_by_scan(
scan_id=scan_id,
batch_size=batch_size
)

View File

@@ -449,34 +449,33 @@ class TaskDistributor:
def execute_scan_flow(
self,
scan_id: int,
target_name: str,
target_id: int,
target_name: str,
scan_workspace_dir: str,
engine_name: str,
scheduled_scan_name: str | None = None,
) -> tuple[bool, str, Optional[str], Optional[int]]:
"""
Execute the scan Flow on a remote or local worker.
Args:
scan_id: scan task ID
target_name: target name
target_id: target ID
scan_workspace_dir: scan working directory
engine_name: engine name
scheduled_scan_name: scheduled scan name (optional)
Returns:
(success, message, container_id, worker_id) tuple
Note:
engine_config is fetched inside the Flow by querying the database with scan_id
"""
logger.info("="*60)
logger.info("execute_scan_flow 开始")
logger.info(" scan_id: %s", scan_id)
logger.info(" target_name: %s", target_name)
logger.info(" target_id: %s", target_id)
logger.info(" target_name: %s", target_name)
logger.info(" scan_workspace_dir: %s", scan_workspace_dir)
logger.info(" engine_name: %s", engine_name)
logger.info(" docker_image: %s", self.docker_image)
@@ -495,23 +494,22 @@ class TaskDistributor:
# 3. Build the docker run command
script_args = {
'scan_id': scan_id,
'target_name': target_name,
'target_id': target_id,
'scan_workspace_dir': scan_workspace_dir,
'engine_name': engine_name,
}
if scheduled_scan_name:
script_args['scheduled_scan_name'] = scheduled_scan_name
docker_cmd = self._build_docker_command(
worker=worker,
script_module='apps.scan.scripts.run_initiate_scan',
script_args=script_args,
)
logger.info(
"提交扫描任务到 Worker: %s - Scan ID: %d, Target: %s",
worker.name, scan_id, target_name
"提交扫描任务到 Worker: %s - Scan ID: %d, Target: %s (ID: %d)",
worker.name, scan_id, target_name, target_id
)
# 4. Run docker run (executed directly on local workers, via SSH on remote ones)

View File

@@ -203,7 +203,7 @@ VULN_SCAN_COMMANDS = {
# -silent: silent mode
# -l: input URL list file
# -t: template directory path (multiple repos supported; repeated -t flags are concatenated via template_args)
'base': "nuclei -j -silent -l '{endpoints_file}' {template_args}",
'base': "nuclei -j -silent -l '{input_file}' {template_args}",
'optional': {
'concurrency': '-c {concurrency}', # concurrency (default 25)
'rate_limit': '-rl {rate_limit}', # requests-per-second limit
@@ -214,7 +214,12 @@ VULN_SCAN_COMMANDS = {
'tags': '-tags {tags}', # filter by tags
'exclude_tags': '-etags {exclude_tags}', # exclude tags
},
'input_type': 'endpoints_file',
# Multiple input types supported; users choose via scan_endpoints/scan_websites
'input_types': ['endpoints_file', 'websites_file'],
'defaults': {
'scan_endpoints': False, # do not scan endpoints by default
'scan_websites': True, # scan websites by default
},
},
}

View File

@@ -158,7 +158,9 @@ vuln_scan:
nuclei:
enabled: true
# timeout: auto # auto-calculated (from the number of endpoint lines)
# timeout: auto # auto-calculated (from the number of input URL lines)
scan-endpoints: false # whether to scan endpoints (off by default)
scan-websites: true # whether to scan websites (on by default)
template-repo-names: # template repo list; names match the repos under "Nuclei templates"
- nuclei-templates
# - nuclei-custom # custom repositories can be appended

View File

@@ -107,7 +107,8 @@ def _get_max_workers(tool_config: dict, default: int = DEFAULT_MAX_WORKERS) -> i
def _export_site_urls(
target_id: int,
directory_scan_dir: Path
directory_scan_dir: Path,
provider,
) -> Tuple[str, int]:
"""
Export all site URLs under the target to a file.
@@ -115,6 +116,7 @@ def _export_site_urls(
Args:
target_id: target ID
directory_scan_dir: directory-scan directory
provider: TargetProvider instance
Returns:
tuple: (sites_file, site_count)
@@ -123,9 +125,8 @@ def _export_site_urls(
sites_file = str(directory_scan_dir / 'sites.txt')
export_result = export_sites_task(
target_id=target_id,
output_file=sites_file,
batch_size=1000
provider=provider,
)
site_count = export_result['total_count']
@@ -389,10 +390,10 @@ def _run_scans_concurrently(
)
def directory_scan_flow(
scan_id: int,
target_name: str,
target_id: int,
scan_workspace_dir: str,
enabled_tools: dict
enabled_tools: dict,
provider,
) -> dict:
"""
Directory scan Flow
@@ -404,10 +405,10 @@ def directory_scan_flow(
Args:
scan_id: scan task ID
target_name: target name
target_id: target ID
scan_workspace_dir: scan workspace directory
enabled_tools: enabled tool configuration dict
provider: TargetProvider instance
Returns:
dict: scan result
@@ -415,6 +416,11 @@ def directory_scan_flow(
try:
wait_for_system_load(context="directory_scan_flow")
# Get target_name from the provider
target_name = provider.get_target_name()
if not target_name:
raise ValueError("unable to resolve the Target name")
logger.info(
"开始目录扫描 - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
@@ -424,8 +430,6 @@ def directory_scan_flow(
# Parameter validation
if scan_id is None:
raise ValueError("scan_id must not be empty")
if not target_name:
raise ValueError("target_name must not be empty")
if target_id is None:
raise ValueError("target_id must not be empty")
if not scan_workspace_dir:
@@ -438,7 +442,9 @@ def directory_scan_flow(
directory_scan_dir = setup_scan_directory(scan_workspace_dir, 'directory_scan')
# Step 1: export site URLs
sites_file, site_count = _export_site_urls(target_id, directory_scan_dir)
sites_file, site_count = _export_site_urls(
target_id, directory_scan_dir, provider
)
if site_count == 0:
logger.warning("跳过目录扫描:没有站点可扫描 - Scan ID: %s", scan_id)

View File

@@ -11,8 +11,10 @@
"""
import logging
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Optional
from prefect import flow
@@ -22,183 +24,147 @@ from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
)
from apps.scan.tasks.fingerprint_detect import (
export_urls_for_fingerprint_task,
export_site_urls_for_fingerprint_task,
run_xingfinger_and_stream_update_tech_task,
)
from apps.scan.utils import build_scan_command, user_log, wait_for_system_load
from apps.scan.utils import build_scan_command, setup_scan_directory, user_log, wait_for_system_load
from apps.scan.utils.fingerprint_helpers import get_fingerprint_paths
logger = logging.getLogger(__name__)
@dataclass
class FingerprintContext:
"""指纹识别上下文,用于在各函数间传递状态"""
scan_id: int
target_id: int
target_name: str
scan_workspace_dir: str
fingerprint_dir: Optional[Path] = None
urls_file: str = ""
url_count: int = 0
source: str = "website"
def calculate_fingerprint_detect_timeout(
url_count: int,
base_per_url: float = 10.0,
min_timeout: int = 300
) -> int:
"""
Calculate the timeout from the URL count.
Formula: timeout = URL count × base time per URL
Minimum 300 seconds, no upper bound.
Args:
url_count: number of URLs
base_per_url: base time per URL in seconds (default 10)
min_timeout: minimum timeout in seconds (default 300)
Returns:
int: the computed timeout in seconds
"""
"""根据 URL 数量计算超时时间(最小 300 秒)"""
return max(min_timeout, int(url_count * base_per_url))
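For instance, with the defaults above, 120 URLs give max(300, 120 × 10.0) = 1200 seconds:
timeout = calculate_fingerprint_detect_timeout(url_count=120)  # -> 1200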
def _export_urls(
target_id: int,
fingerprint_dir: Path,
source: str = 'website'
) -> tuple[str, int]:
"""
Export URLs to a file.
Args:
target_id: target ID
fingerprint_dir: fingerprint-detection directory
source: data source type
Returns:
tuple: (urls_file, total_count)
"""
logger.info("Step 1: 导出 URL 列表 (source=%s)", source)
def _export_urls(fingerprint_dir: Path, provider) -> tuple[str, int]:
"""导出 URL 到文件,返回 (urls_file, total_count)"""
logger.info("Step 1: 导出 URL 列表")
urls_file = str(fingerprint_dir / 'urls.txt')
export_result = export_urls_for_fingerprint_task(
target_id=target_id,
export_result = export_site_urls_for_fingerprint_task(
output_file=urls_file,
source=source,
batch_size=1000
provider=provider,
)
total_count = export_result['total_count']
logger.info(
"✓ URL 导出完成 - 文件: %s, 数量: %d",
export_result['output_file'],
total_count
)
logger.info("✓ URL 导出完成 - 文件: %s, 数量: %d", export_result['output_file'], total_count)
return export_result['output_file'], total_count
def _run_fingerprint_detect(
enabled_tools: dict,
urls_file: str,
url_count: int,
fingerprint_dir: Path,
scan_id: int,
target_id: int,
source: str
) -> tuple[dict, list]:
"""
Run the fingerprint-detection tasks.
def _run_single_tool(
tool_name: str,
tool_config: dict,
ctx: FingerprintContext
) -> tuple[Optional[dict], Optional[dict]]:
"""执行单个指纹识别工具,返回 (stats, failed_info)"""
# 获取指纹库路径
lib_names = tool_config.get('fingerprint_libs', ['ehole'])
fingerprint_paths = get_fingerprint_paths(lib_names)
Args:
enabled_tools: enabled tool configuration dict
urls_file: URL file path
url_count: total URL count
fingerprint_dir: fingerprint-detection directory
scan_id: scan task ID
target_id: target ID
source: data source type
if not fingerprint_paths:
reason = f"没有可用的指纹库: {lib_names}"
logger.warning(reason)
return None, {'tool': tool_name, 'reason': reason}
Returns:
tuple: (tool_stats, failed_tools)
"""
# Build the command
tool_config_with_paths = {**tool_config, **fingerprint_paths}
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='fingerprint_detect',
command_params={'urls_file': ctx.urls_file},
tool_config=tool_config_with_paths
)
except Exception as e:
reason = f"命令构建失败: {e}"
logger.error("构建 %s 命令失败: %s", tool_name, e)
return None, {'tool': tool_name, 'reason': reason}
# Compute the timeout and log file
timeout = calculate_fingerprint_detect_timeout(ctx.url_count)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = ctx.fingerprint_dir / f"{tool_name}_{timestamp}.log"
logger.info(
"开始执行 %s 指纹识别 - URL数: %d, 超时: %ds, 指纹库: %s",
tool_name, ctx.url_count, timeout, list(fingerprint_paths.keys())
)
user_log(ctx.scan_id, "fingerprint_detect", f"Running {tool_name}: {command}")
# Run the scan task
try:
result = run_xingfinger_and_stream_update_tech_task(
cmd=command,
tool_name=tool_name,
scan_id=ctx.scan_id,
target_id=ctx.target_id,
source=ctx.source,
cwd=str(ctx.fingerprint_dir),
timeout=timeout,
log_file=str(log_file),
batch_size=100
)
stats = {
'command': command,
'result': result,
'timeout': timeout,
'fingerprint_libs': list(fingerprint_paths.keys())
}
tool_updated = result.get('updated_count', 0)
logger.info(
"✓ 工具 %s 执行完成 - 处理记录: %d, 更新: %d, 未找到: %d",
tool_name,
result.get('processed_records', 0),
tool_updated,
result.get('not_found_count', 0)
)
user_log(
ctx.scan_id, "fingerprint_detect",
f"{tool_name} completed: identified {tool_updated} fingerprints"
)
return stats, None
except Exception as exc:
reason = str(exc)
logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True)
user_log(ctx.scan_id, "fingerprint_detect", f"{tool_name} failed: {reason}", "error")
return None, {'tool': tool_name, 'reason': reason}
def _run_fingerprint_detect(enabled_tools: dict, ctx: FingerprintContext) -> tuple[dict, list]:
"""执行指纹识别任务,返回 (tool_stats, failed_tools)"""
tool_stats = {}
failed_tools = []
for tool_name, tool_config in enabled_tools.items():
# 1. Get fingerprint library paths
lib_names = tool_config.get('fingerprint_libs', ['ehole'])
fingerprint_paths = get_fingerprint_paths(lib_names)
if not fingerprint_paths:
reason = f"没有可用的指纹库: {lib_names}"
logger.warning(reason)
failed_tools.append({'tool': tool_name, 'reason': reason})
continue
# 2. Merge fingerprint library paths into tool_config (used for command building)
tool_config_with_paths = {**tool_config, **fingerprint_paths}
# 3. Build the command
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='fingerprint_detect',
command_params={'urls_file': urls_file},
tool_config=tool_config_with_paths
)
except Exception as e:
reason = f"命令构建失败: {e}"
logger.error("构建 %s 命令失败: %s", tool_name, e)
failed_tools.append({'tool': tool_name, 'reason': reason})
continue
# 4. Compute the timeout
timeout = calculate_fingerprint_detect_timeout(url_count)
# 5. Generate the log file path
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = fingerprint_dir / f"{tool_name}_{timestamp}.log"
logger.info(
"开始执行 %s 指纹识别 - URL数: %d, 超时: %ds, 指纹库: %s",
tool_name, url_count, timeout, list(fingerprint_paths.keys())
)
user_log(scan_id, "fingerprint_detect", f"Running {tool_name}: {command}")
# 6. Run the scan task
try:
result = run_xingfinger_and_stream_update_tech_task(
cmd=command,
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
source=source,
cwd=str(fingerprint_dir),
timeout=timeout,
log_file=str(log_file),
batch_size=100
)
tool_stats[tool_name] = {
'command': command,
'result': result,
'timeout': timeout,
'fingerprint_libs': list(fingerprint_paths.keys())
}
tool_updated = result.get('updated_count', 0)
logger.info(
"✓ 工具 %s 执行完成 - 处理记录: %d, 更新: %d, 未找到: %d",
tool_name,
result.get('processed_records', 0),
tool_updated,
result.get('not_found_count', 0)
)
user_log(
scan_id, "fingerprint_detect",
f"{tool_name} completed: identified {tool_updated} fingerprints"
)
except Exception as exc:
reason = str(exc)
failed_tools.append({'tool': tool_name, 'reason': reason})
logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True)
user_log(scan_id, "fingerprint_detect", f"{tool_name} failed: {reason}", "error")
stats, failed_info = _run_single_tool(tool_name, tool_config, ctx)
if stats:
tool_stats[tool_name] = stats
if failed_info:
failed_tools.append(failed_info)
if failed_tools:
logger.warning(
@@ -209,6 +175,24 @@ def _run_fingerprint_detect(
return tool_stats, failed_tools
def _aggregate_results(tool_stats: dict) -> dict:
"""汇总所有工具的结果"""
return {
'processed_records': sum(
s['result'].get('processed_records', 0) for s in tool_stats.values()
),
'updated_count': sum(
s['result'].get('updated_count', 0) for s in tool_stats.values()
),
'created_count': sum(
s['result'].get('created_count', 0) for s in tool_stats.values()
),
'snapshot_count': sum(
s['result'].get('snapshot_count', 0) for s in tool_stats.values()
),
}
@flow(
name="fingerprint_detect",
log_prints=True,
@@ -218,10 +202,10 @@ def _run_fingerprint_detect(
)
def fingerprint_detect_flow(
scan_id: int,
target_name: str,
target_id: int,
scan_workspace_dir: str,
enabled_tools: dict
enabled_tools: dict,
provider,
) -> dict:
"""
指纹识别 Flow
@@ -230,57 +214,45 @@ def fingerprint_detect_flow(
1. 从数据库导出目标下所有 WebSite URL 到文件
2. 使用 xingfinger 进行技术栈识别
3. 解析结果并更新 WebSite.tech 字段(合并去重)
工作流程:
Step 0: 创建工作目录
Step 1: 导出 URL 列表
Step 2: 解析配置,获取启用的工具
Step 3: 执行 xingfinger 并解析结果
Args:
scan_id: 扫描任务 ID
target_name: 目标名称
target_id: 目标 ID
scan_workspace_dir: 扫描工作空间目录
enabled_tools: 启用的工具配置xingfinger
Returns:
dict: 扫描结果
"""
try:
# Load check: wait until system resources are sufficient
wait_for_system_load(context="fingerprint_detect_flow")
# Get target_name from the provider
target_name = provider.get_target_name()
if not target_name:
raise ValueError("unable to resolve the Target name")
# Parameter validation
if scan_id is None:
raise ValueError("scan_id must not be empty")
if target_id is None:
raise ValueError("target_id must not be empty")
if not scan_workspace_dir:
raise ValueError("scan_workspace_dir must not be empty")
logger.info(
"Starting fingerprint detection - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
)
user_log(scan_id, "fingerprint_detect", "Starting fingerprint detection")
# Parameter validation
if scan_id is None:
raise ValueError("scan_id must not be empty")
if not target_name:
raise ValueError("target_name must not be empty")
if target_id is None:
raise ValueError("target_id must not be empty")
if not scan_workspace_dir:
raise ValueError("scan_workspace_dir must not be empty")
# Create the context
ctx = FingerprintContext(
scan_id=scan_id,
target_id=target_id,
target_name=target_name,
scan_workspace_dir=scan_workspace_dir,
fingerprint_dir=setup_scan_directory(scan_workspace_dir, 'fingerprint_detect')
)
# Data source type (currently only 'website' is supported)
source = 'website'
# Step 1: export URLs
ctx.urls_file, ctx.url_count = _export_urls(ctx.fingerprint_dir, provider)
# Step 0: create the working directory
from apps.scan.utils import setup_scan_directory
fingerprint_dir = setup_scan_directory(scan_workspace_dir, 'fingerprint_detect')
# Step 1: export URLs (lazy loading supported)
urls_file, url_count = _export_urls(target_id, fingerprint_dir, source)
if url_count == 0:
if ctx.url_count == 0:
logger.warning("跳过指纹识别:没有 URL 可扫描 - Scan ID: %s", scan_id)
user_log(scan_id, "fingerprint_detect", "Skipped: no URLs to scan", "warning")
return _build_empty_result(scan_id, target_name, scan_workspace_dir, urls_file)
return _build_empty_result(scan_id, target_name, scan_workspace_dir, ctx.urls_file)
# Step 2: tool configuration info
logger.info("Step 2: tool configuration info")
@@ -288,57 +260,30 @@ def fingerprint_detect_flow(
# Step 3: run fingerprint detection
logger.info("Step 3: run fingerprint detection")
tool_stats, failed_tools = _run_fingerprint_detect(
enabled_tools=enabled_tools,
urls_file=urls_file,
url_count=url_count,
fingerprint_dir=fingerprint_dir,
scan_id=scan_id,
target_id=target_id,
source=source
)
tool_stats, failed_tools = _run_fingerprint_detect(enabled_tools, ctx)
# Dynamically build the list of executed tasks
executed_tasks = ['export_urls_for_fingerprint']
executed_tasks.extend([f'run_xingfinger ({tool})' for tool in tool_stats])
# Aggregate results
totals = _aggregate_results(tool_stats)
failed_tool_names = {f['tool'] for f in failed_tools}
successful_tools = [name for name in enabled_tools if name not in failed_tool_names]
# Aggregate results across all tools
total_processed = sum(
stats['result'].get('processed_records', 0) for stats in tool_stats.values()
)
total_updated = sum(
stats['result'].get('updated_count', 0) for stats in tool_stats.values()
)
total_created = sum(
stats['result'].get('created_count', 0) for stats in tool_stats.values()
)
total_snapshots = sum(
stats['result'].get('snapshot_count', 0) for stats in tool_stats.values()
)
# Record Flow completion
logger.info("✓ Fingerprint detection complete - fingerprints identified: %d", total_updated)
logger.info("✓ Fingerprint detection complete - fingerprints identified: %d", totals['updated_count'])
user_log(
scan_id, "fingerprint_detect",
f"fingerprint_detect completed: identified {total_updated} fingerprints"
f"fingerprint_detect completed: identified {totals['updated_count']} fingerprints"
)
successful_tools = [
name for name in enabled_tools
if name not in [f['tool'] for f in failed_tools]
]
executed_tasks = ['export_site_urls_for_fingerprint']
executed_tasks.extend([f'run_xingfinger ({tool})' for tool in tool_stats])
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'scan_workspace_dir': scan_workspace_dir,
'urls_file': urls_file,
'url_count': url_count,
'processed_records': total_processed,
'updated_count': total_updated,
'created_count': total_created,
'snapshot_count': total_snapshots,
'urls_file': ctx.urls_file,
'url_count': ctx.url_count,
**totals,
'executed_tasks': executed_tasks,
'tool_stats': {
'total': len(enabled_tools),
@@ -379,7 +324,7 @@ def _build_empty_result(
'updated_count': 0,
'created_count': 0,
'snapshot_count': 0,
'executed_tasks': ['export_urls_for_fingerprint'],
'executed_tasks': ['export_site_urls_for_fingerprint'],
'tool_stats': {
'total': 0,
'successful': 0,

View File

@@ -7,6 +7,7 @@
- Uses FlowOrchestrator to parse the YAML configuration
- Runs sub-Flows (subflows) inside the Prefect Flow
- Orchestrates the workflow in YAML order
- Creates the matching Provider based on scan_mode
- Contains no concrete business logic (implemented by Tasks and FlowOrchestrator)
Architecture:
@@ -18,20 +19,20 @@
# Django environment initialization (takes effect on import)
# Note: dynamic scan containers should start via run_initiate_scan.py so env vars are set before import
from apps.common.prefect_django_setup import setup_django_for_prefect
import apps.common.prefect_django_setup # noqa: F401
import logging
from prefect import flow, task
from pathlib import Path
import logging
from prefect.futures import wait
from apps.scan.handlers import (
on_initiate_scan_flow_running,
on_initiate_scan_flow_completed,
on_initiate_scan_flow_failed,
)
from prefect.futures import wait
from apps.scan.utils import setup_scan_workspace
from apps.scan.orchestrators import FlowOrchestrator
from apps.scan.utils import setup_scan_workspace
logger = logging.getLogger(__name__)
@@ -43,6 +44,75 @@ def _run_subflow_task(scan_type: str, flow_func, flow_kwargs: dict):
return flow_func(**flow_kwargs)
def _create_provider(scan, target_id: int, scan_id: int):
"""根据 scan_mode 创建对应的 Provider"""
from apps.scan.models import Scan
from apps.scan.providers import (
DatabaseTargetProvider,
SnapshotTargetProvider,
ProviderContext,
)
provider_context = ProviderContext(target_id=target_id, scan_id=scan_id)
if scan.scan_mode == Scan.ScanMode.QUICK:
provider = SnapshotTargetProvider(scan_id=scan_id, context=provider_context)
logger.info("✓ 快速扫描模式 - 创建 SnapshotTargetProvider")
else:
provider = DatabaseTargetProvider(target_id=target_id, context=provider_context)
logger.info("✓ 完整扫描模式 - 使用 DatabaseTargetProvider")
return provider
def _execute_sequential_flows(valid_flows: list, results: dict, executed_flows: list):
"""顺序执行 Flow 列表"""
for scan_type, flow_func, flow_kwargs in valid_flows:
logger.info("=" * 60)
logger.info("执行 Flow: %s", scan_type)
logger.info("=" * 60)
try:
result = flow_func(**flow_kwargs)
executed_flows.append(scan_type)
results[scan_type] = result
logger.info("%s 执行成功", scan_type)
except Exception as e:
logger.warning("%s 执行失败: %s", scan_type, e)
executed_flows.append(f"{scan_type} (失败)")
results[scan_type] = {'success': False, 'error': str(e)}
def _execute_parallel_flows(valid_flows: list, results: dict, executed_flows: list):
"""并行执行 Flow 列表"""
futures = []
for scan_type, flow_func, flow_kwargs in valid_flows:
logger.info("=" * 60)
logger.info("提交并行子 Flow 任务: %s", scan_type)
logger.info("=" * 60)
future = _run_subflow_task.submit(
scan_type=scan_type,
flow_func=flow_func,
flow_kwargs=flow_kwargs,
)
futures.append((scan_type, future))
if not futures:
return
wait([f for _, f in futures])
for scan_type, future in futures:
try:
result = future.result()
executed_flows.append(scan_type)
results[scan_type] = result
logger.info("%s 执行成功", scan_type)
except Exception as e:
logger.warning("%s 执行失败: %s", scan_type, e)
executed_flows.append(f"{scan_type} (失败)")
results[scan_type] = {'success': False, 'error': str(e)}
@flow(
name='initiate_scan',
description='scan task initialization flow',
@@ -53,15 +123,14 @@ def _run_subflow_task(scan_type: str, flow_func, flow_kwargs: dict):
)
def initiate_scan_flow(
scan_id: int,
target_name: str,
target_id: int,
scan_workspace_dir: str,
engine_name: str,
scheduled_scan_name: str | None = None,
scheduled_scan_name: str | None = None, # noqa: ARG001
) -> dict:
"""
Initialize the scan task (dynamic workflow orchestration).
Dynamically orchestrates the workflow from the YAML configuration:
- Fetch engine_config (YAML) from the database
- Detect enabled scan types
@@ -73,189 +142,112 @@ def initiate_scan_flow(
Stage 2: Analysis (run in parallel)
- url_fetch
- directory_scan
Args:
scan_id: scan task ID
target_name: target name
target_id: target ID
scan_workspace_dir: Scan workspace directory path
engine_name: engine name (for display)
scheduled_scan_name: scheduled scan name (optional, shown in notifications)
Returns:
dict: execution result summary
Raises:
ValueError: parameter validation failed or invalid configuration
RuntimeError: execution failed
"""
try:
# ==================== Parameter validation ====================
# Parameter validation
if not scan_id:
raise ValueError("scan_id is required")
if not scan_workspace_dir:
raise ValueError("scan_workspace_dir is required")
if not engine_name:
raise ValueError("engine_name is required")
logger.info("="*60)
logger.info("开始初始化扫描任务")
logger.info(f"Scan ID: {scan_id}")
logger.info(f"Target: {target_name}")
logger.info(f"Engine: {engine_name}")
logger.info(f"Workspace: {scan_workspace_dir}")
logger.info("="*60)
# ==================== Task 1: create the Scan workspace ====================
# Create the workspace
scan_workspace_path = setup_scan_workspace(scan_workspace_dir)
# ==================== Task 2: fetch the engine configuration ====================
# Fetch the engine configuration
from apps.scan.models import Scan
scan = Scan.objects.get(id=scan_id)
engine_config = scan.yaml_configuration
# Use engine_names for display
display_engine_name = ', '.join(scan.engine_names) if scan.engine_names else engine_name
# ==================== Task 3: parse config and build the execution plan ====================
# Create the Provider
provider = _create_provider(scan, target_id, scan_id)
# Get target_name for log display
target_name = provider.get_target_name()
if not target_name:
raise ValueError("无法获取 Target 名称")
logger.info("=" * 60)
logger.info("开始初始化扫描任务")
logger.info("Scan ID: %s, Target: %s, Engine: %s", scan_id, target_name, engine_name)
logger.info("Workspace: %s", scan_workspace_dir)
logger.info("=" * 60)
# Parse the configuration and build the execution plan
orchestrator = FlowOrchestrator(engine_config)
# FlowOrchestrator has already parsed every tool configuration
enabled_tools_by_type = orchestrator.enabled_tools_by_type
logger.info("执行计划生成成功")
logger.info(f"扫描类型: {''.join(orchestrator.scan_types)}")
logger.info(f"总共 {len(orchestrator.scan_types)} 个 Flow")
# ==================== Initialize stage progress ====================
# Initialize right after parsing the config, when the full scan_types list is available
logger.info("执行计划: %s (共 %d 个 Flow)",
''.join(orchestrator.scan_types), len(orchestrator.scan_types))
# Initialize stage progress
from apps.scan.services import ScanService
scan_service = ScanService()
scan_service.init_stage_progress(scan_id, orchestrator.scan_types)
logger.info(f"✓ 初始化阶段进度 - Stages: {orchestrator.scan_types}")
# ==================== 更新 Target 最后扫描时间 ====================
# 在开始扫描时更新,表示"最后一次扫描开始时间"
ScanService().init_stage_progress(scan_id, orchestrator.scan_types)
logger.info("✓ 初始化阶段进度 - Stages: %s", orchestrator.scan_types)
# 更新 Target 最后扫描时间
from apps.targets.services import TargetService
target_service = TargetService()
target_service.update_last_scanned_at(target_id)
logger.info(f"✓ 更新 Target 最后扫描时间 - Target ID: {target_id}")
# ==================== Task 3: 执行 Flow动态阶段执行====================
# 注意:各阶段状态更新由 scan_flow_handlers.py 自动处理running/completed/failed
TargetService().update_last_scanned_at(target_id)
logger.info("✓ 更新 Target 最后扫描时间 - Target ID: %s", target_id)
# 执行 Flow
executed_flows = []
results = {}
# Common execution parameters
flow_kwargs = {
base_kwargs = {
'scan_id': scan_id,
'target_name': target_name,
'target_id': target_id,
'scan_workspace_dir': str(scan_workspace_path)
}
def record_flow_result(scan_type, result=None, error=None):
"""
Unified result-recording helper.
Args:
scan_type: scan type name
result: execution result (on success)
error: exception object (on failure)
"""
if error:
# Failure handling: log the error without raising, so later stages keep running
error_msg = f"{scan_type} failed: {str(error)}"
logger.warning(error_msg)
executed_flows.append(f"{scan_type} (failed)")
results[scan_type] = {'success': False, 'error': str(error)}
# No longer raises, letting the scan continue
else:
# Success handling
executed_flows.append(scan_type)
results[scan_type] = result
logger.info(f"{scan_type} 执行成功")
def get_valid_flows(flow_names):
"""
Collect the valid Flow functions and prepare per-Flow arguments.
Args:
flow_names: list of scan type names
Returns:
list: [(scan_type, flow_func, flow_specific_kwargs), ...] list of valid functions
"""
valid_flows = []
def get_valid_flows(flow_names: list) -> list:
"""获取有效的 Flow 函数列表"""
valid = []
for scan_type in flow_names:
flow_func = orchestrator.get_flow_function(scan_type)
if flow_func:
# Prepare per-Flow arguments (including the matching enabled_tools)
flow_specific_kwargs = dict(flow_kwargs)
flow_specific_kwargs['enabled_tools'] = enabled_tools_by_type.get(scan_type, {})
valid_flows.append((scan_type, flow_func, flow_specific_kwargs))
else:
logger.warning(f"跳过未实现的 Flow: {scan_type}")
return valid_flows
if not flow_func:
logger.warning("跳过未实现的 Flow: %s", scan_type)
continue
kwargs = dict(base_kwargs)
kwargs['enabled_tools'] = enabled_tools_by_type.get(scan_type, {})
kwargs['provider'] = provider
valid.append((scan_type, flow_func, kwargs))
return valid
# ---------------------------------------------------------
# Dynamic stage execution (as defined by FlowOrchestrator)
# ---------------------------------------------------------
# Dynamic stage execution
for mode, enabled_flows in orchestrator.get_execution_stages():
valid_flows = get_valid_flows(enabled_flows)
if not valid_flows:
continue
logger.info("=" * 60)
logger.info("%s执行阶段: %s", "顺序" if mode == 'sequential' else "并行",
', '.join(enabled_flows))
logger.info("=" * 60)
if mode == 'sequential':
# Sequential execution
logger.info("="*60)
logger.info(f"Sequential execution stage: {', '.join(enabled_flows)}")
logger.info("="*60)
for scan_type, flow_func, flow_specific_kwargs in get_valid_flows(enabled_flows):
logger.info("="*60)
logger.info(f"执行 Flow: {scan_type}")
logger.info("="*60)
try:
result = flow_func(**flow_specific_kwargs)
record_flow_result(scan_type, result=result)
except Exception as e:
record_flow_result(scan_type, error=e)
elif mode == 'parallel':
# Parallel execution stage: wrap sub-Flows in Tasks and run them concurrently on the Prefect TaskRunner
logger.info("="*60)
logger.info(f"Parallel execution stage: {', '.join(enabled_flows)}")
logger.info("="*60)
futures = []
_execute_sequential_flows(valid_flows, results, executed_flows)
else:
_execute_parallel_flows(valid_flows, results, executed_flows)
# Submit all parallel sub-Flow tasks
for scan_type, flow_func, flow_specific_kwargs in get_valid_flows(enabled_flows):
logger.info("="*60)
logger.info(f"提交并行子 Flow 任务: {scan_type}")
logger.info("="*60)
future = _run_subflow_task.submit(
scan_type=scan_type,
flow_func=flow_func,
flow_kwargs=flow_specific_kwargs,
)
futures.append((scan_type, future))
logger.info("=" * 60)
logger.info("✓ 扫描任务初始化完成 - 执行的 Flow: %s", ', '.join(executed_flows))
logger.info("=" * 60)
# Wait for all parallel sub-Flows to finish
if futures:
wait([f for _, f in futures])
# Check results (reusing the unified result-handling logic)
for scan_type, future in futures:
try:
result = future.result()
record_flow_result(scan_type, result=result)
except Exception as e:
record_flow_result(scan_type, error=e)
# ==================== Done ====================
logger.info("="*60)
logger.info("✓ Scan task initialization complete")
logger.info(f"Executed Flows: {', '.join(executed_flows)}")
logger.info("="*60)
# ==================== Return result ====================
return {
'success': True,
'scan_id': scan_id,
@@ -264,21 +256,16 @@ def initiate_scan_flow(
'executed_flows': executed_flows,
'results': results
}
except ValueError as e:
# Parameter error
logger.error("Parameter error: %s", e)
raise
except RuntimeError as e:
# Execution failure
logger.error("Runtime error: %s", e)
raise
except OSError as e:
# Filesystem error (workspace creation failed)
logger.error("Filesystem error: %s", e)
raise
except Exception as e:
# Other unexpected errors
logger.exception("Scan task initialization failed: %s", e)
# Note: failure status updates are handled automatically by Prefect state handlers
raise

View File

@@ -132,42 +132,36 @@ def _parse_port_count(tool_config: dict) -> int:
def _export_hosts(target_id: int, port_scan_dir: Path) -> tuple[str, int, str]:
def _export_hosts(port_scan_dir: Path, provider) -> tuple[str, int]:
"""
Export the host list to a file.
The exported content is chosen automatically by Target type:
- DOMAIN: export subdomains from the Subdomain table
- IP: write target.name directly
- CIDR: expand all IPs in the CIDR range
Args:
target_id: target ID
port_scan_dir: port-scan directory
provider: TargetProvider instance
Returns:
tuple: (hosts_file, host_count, target_type)
tuple: (hosts_file, host_count)
"""
logger.info("Step 1: 导出主机列表")
hosts_file = str(port_scan_dir / 'hosts.txt')
export_result = export_hosts_task(
target_id=target_id,
output_file=hosts_file,
provider=provider,
)
host_count = export_result['total_count']
target_type = export_result.get('target_type', 'unknown')
logger.info(
"✓ 主机列表导出完成 - 类型: %s, 文件: %s, 数量: %d",
target_type, export_result['output_file'], host_count
"✓ 主机列表导出完成 - 文件: %s, 数量: %d",
export_result['output_file'], host_count
)
if host_count == 0:
logger.warning("目标下没有可扫描的主机,无法执行端口扫描")
return export_result['output_file'], host_count, target_type
return export_result['output_file'], host_count
def _run_scans_sequentially(
@@ -176,7 +170,7 @@ def _run_scans_sequentially(
port_scan_dir: Path,
scan_id: int,
target_id: int,
target_name: str
target_name: str,
) -> tuple[dict, int, list, list]:
"""
Run port-scan tasks sequentially.
@@ -187,7 +181,7 @@ def _run_scans_sequentially(
port_scan_dir: port-scan directory
scan_id: scan task ID
target_id: target ID
target_name: target name (for error logs)
target_name: target name (for log display)
Returns:
tuple: (tool_stats, processed_records, successful_tool_names, failed_tools)
@@ -271,7 +265,7 @@ def _run_scans_sequentially(
if not tool_stats:
error_details = "; ".join([f"{f['tool']}: {f['reason']}" for f in failed_tools])
logger.warning("所有端口扫描工具均失败 - 目标: %s, 失败工具: %s", target_name, error_details)
logger.warning("所有端口扫描工具均失败 - Target: %s, 失败工具: %s", target_name, error_details)
return {}, 0, [], failed_tools
successful_tool_names = [
@@ -298,10 +292,10 @@ def _run_scans_sequentially(
)
def port_scan_flow(
scan_id: int,
target_name: str,
target_id: int,
scan_workspace_dir: str,
enabled_tools: dict
enabled_tools: dict,
provider,
) -> dict:
"""
Port scan Flow
@@ -321,10 +315,10 @@ def port_scan_flow(
Args:
scan_id: scan task ID
target_name: domain name
target_id: target ID
scan_workspace_dir: Scan workspace directory
enabled_tools: enabled tool configuration dict
provider: TargetProvider instance
Returns:
dict: scan result
@@ -336,10 +330,13 @@ def port_scan_flow(
try:
wait_for_system_load(context="port_scan_flow")
# Get target_name from the provider
target_name = provider.get_target_name()
if not target_name:
raise ValueError("unable to resolve the Target name")
if scan_id is None:
raise ValueError("scan_id must not be empty")
if not target_name:
raise ValueError("target_name must not be empty")
if target_id is None:
raise ValueError("target_id must not be empty")
if not scan_workspace_dir:
@@ -358,7 +355,7 @@ def port_scan_flow(
port_scan_dir = setup_scan_directory(scan_workspace_dir, 'port_scan')
# Step 1: export the host list
hosts_file, host_count, target_type = _export_hosts(target_id, port_scan_dir)
hosts_file, host_count = _export_hosts(port_scan_dir, provider)
if host_count == 0:
logger.warning("跳过端口扫描:没有主机可扫描 - Scan ID: %s", scan_id)
@@ -370,7 +367,6 @@ def port_scan_flow(
'scan_workspace_dir': scan_workspace_dir,
'hosts_file': hosts_file,
'host_count': 0,
'target_type': target_type,
'processed_records': 0,
'executed_tasks': ['export_hosts'],
'tool_stats': {
@@ -395,7 +391,7 @@ def port_scan_flow(
port_scan_dir=port_scan_dir,
scan_id=scan_id,
target_id=target_id,
target_name=target_name
target_name=target_name,
)
logger.info("✓ 端口扫描完成 - 发现端口: %d", processed_records)
@@ -411,7 +407,6 @@ def port_scan_flow(
'scan_workspace_dir': scan_workspace_dir,
'hosts_file': hosts_file,
'host_count': host_count,
'target_type': target_type,
'processed_records': processed_records,
'executed_tasks': executed_tasks,
'tool_stats': {

View File

@@ -2,17 +2,12 @@
Screenshot Flow
Orchestrates the full screenshot pipeline:
1. Fetch the URL list from the database (websites and/or endpoints)
1. Fetch the URL list from the Provider
2. Capture screenshots in batches and save snapshots
3. Sync to the asset table
Two modes are supported:
1. Legacy mode (backward compatible): fetch URLs from the database via target_id
2. Provider mode: fetch URLs from any data source via a TargetProvider
"""
import logging
from typing import Optional
from prefect import flow
@@ -22,62 +17,49 @@ from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
)
from apps.scan.providers import TargetProvider
from apps.scan.services.target_export_service import DataSource, get_urls_with_fallback
from apps.scan.tasks.screenshot import capture_screenshots_task
from apps.scan.utils import user_log, wait_for_system_load
logger = logging.getLogger(__name__)
# Mapping from URL source to DataSource
_SOURCE_MAPPING = {
'websites': DataSource.WEBSITE,
'endpoints': DataSource.ENDPOINT,
}
def _parse_screenshot_config(enabled_tools: dict) -> dict:
"""解析截图配置"""
playwright_config = enabled_tools.get('playwright', {})
return {
'concurrency': playwright_config.get('concurrency', 5),
'url_sources': playwright_config.get('url_sources', ['websites'])
}
def _map_url_sources_to_data_sources(url_sources: list[str]) -> list[str]:
"""将配置中的 url_sources 映射为 DataSource 常量"""
sources = []
for source in url_sources:
if source in _SOURCE_MAPPING:
sources.append(_SOURCE_MAPPING[source])
else:
logger.warning("未知的 URL 来源: %s,跳过", source)
def _collect_urls_from_provider(provider: TargetProvider) -> tuple[list[str], str]:
"""
Collect website URLs from the Provider (with fallback logic).
Priority: WebSite → HostPortMapping → Default URL
Returns:
tuple: (urls, source)
- urls: URL list
- source: data source ('website' | 'host_port' | 'default')
"""
logger.info("从 Provider 获取网站 URL - Provider: %s", type(provider).__name__)
# 添加默认回退(从 subdomain 构造)
sources.append(DataSource.DEFAULT)
return sources
# Prefer WebSite data
urls = list(provider.iter_websites())
if urls:
logger.info("使用 WebSite 数据源 - 数量: %d", len(urls))
return urls, "website"
# Fall back to HostPortMapping
urls = list(provider.iter_host_port_urls())
if urls:
logger.info("WebSite 为空,回退到 HostPortMapping - 数量: %d", len(urls))
return urls, "host_port"
def _collect_urls_from_provider(provider: TargetProvider) -> tuple[list[str], str, list[str]]:
"""从 Provider 收集 URL"""
logger.info("使用 Provider 模式获取 URL - Provider: %s", type(provider).__name__)
urls = list(provider.iter_urls())
blacklist_filter = provider.get_blacklist_filter()
if blacklist_filter:
urls = [url for url in urls if blacklist_filter.is_allowed(url)]
return urls, 'provider', ['provider']
def _collect_urls_from_database(
target_id: int,
url_sources: list[str]
) -> tuple[list[str], str, list[str]]:
"""从数据库收集 URL带黑名单过滤和回退"""
data_sources = _map_url_sources_to_data_sources(url_sources)
result = get_urls_with_fallback(target_id, sources=data_sources)
return result['urls'], result['source'], result['tried_sources']
# Final fallback to default URLs
urls = list(provider.iter_default_urls())
logger.info("HostPortMapping 为空,回退到默认 URL - 数量: %d", len(urls))
return urls, "default"
def _build_empty_result(scan_id: int, target_name: str) -> dict:
@@ -102,68 +84,53 @@ def _build_empty_result(scan_id: int, target_name: str) -> dict:
)
def screenshot_flow(
scan_id: int,
target_name: str,
target_id: int,
scan_workspace_dir: str,
enabled_tools: dict,
provider: Optional[TargetProvider] = None
provider: TargetProvider,
) -> dict:
"""
Screenshot Flow
Two modes are supported:
1. Legacy mode (backward compatible): fetch URLs from the database via target_id
2. Provider mode: fetch URLs from any data source via a TargetProvider
Args:
scan_id: scan task ID
target_name: target name
target_id: target ID
scan_workspace_dir: scan workspace directory
enabled_tools: enabled tool configuration
provider: TargetProvider instance (new mode, optional)
provider: TargetProvider instance
Returns:
screenshot result dict
"""
try:
# Load check: wait until system resources are sufficient
wait_for_system_load(context="screenshot_flow")
mode = 'Provider' if provider else 'Legacy'
# Get target_name from the provider
target_name = provider.get_target_name()
if not target_name:
raise ValueError("unable to resolve the Target name")
logger.info(
"开始截图扫描 - Scan ID: %s, Target: %s, Mode: %s",
scan_id, target_name, mode
"开始截图扫描 - Scan ID: %s, Target: %s",
scan_id, target_name
)
user_log(scan_id, "screenshot", "Starting screenshot capture")
# Step 1: parse the configuration
config = _parse_screenshot_config(enabled_tools)
concurrency = config['concurrency']
logger.info("截图配置 - 并发: %d, URL来源: %s", concurrency, config['url_sources'])
logger.info("截图配置 - 并发: %d", concurrency)
# Step 2: collect the URL list
if provider is not None:
urls, source_info, tried_sources = _collect_urls_from_provider(provider)
else:
urls, source_info, tried_sources = _collect_urls_from_database(
target_id, config['url_sources']
)
logger.info(
"URL 收集完成 - 来源: %s, 数量: %d, 尝试过: %s",
source_info, len(urls), tried_sources
)
# Step 2: collect the URL list from the Provider (with fallback logic)
urls, source = _collect_urls_from_provider(provider)
logger.info("URL 收集完成 - 来源: %s, 数量: %d", source, len(urls))
if not urls:
logger.warning("没有可截图的 URL跳过截图任务")
user_log(scan_id, "screenshot", "Skipped: no URLs to capture", "warning")
return _build_empty_result(scan_id, target_name)
user_log(
scan_id, "screenshot",
f"Found {len(urls)} URLs to capture (source: {source_info})"
)
user_log(scan_id, "screenshot", f"Found {len(urls)} URLs to capture")
# Step 3: batch screenshot capture
logger.info("Capturing screenshots in batches - %d URLs", len(urls))

View File

@@ -88,40 +88,38 @@ def _calculate_timeout_by_line_count(
def _export_site_urls(
target_id: int,
site_scan_dir: Path
) -> tuple[str, int, int]:
site_scan_dir: Path,
provider,
) -> tuple[str, int]:
"""
Export site URLs to a file.
Args:
target_id: target ID
site_scan_dir: site-scan directory
provider: TargetProvider instance
Returns:
tuple: (urls_file, total_urls, association_count)
tuple: (urls_file, total_urls)
"""
logger.info("Step 1: 导出站点URL列表")
urls_file = str(site_scan_dir / 'site_urls.txt')
export_result = export_site_urls_task(
target_id=target_id,
output_file=urls_file,
batch_size=1000
provider=provider,
)
total_urls = export_result['total_urls']
association_count = export_result['association_count']
logger.info(
"✓ 站点URL导出完成 - 文件: %s, URL数量: %d, 关联数: %d",
export_result['output_file'], total_urls, association_count
"✓ 站点URL导出完成 - 文件: %s, URL数量: %d",
export_result['output_file'], total_urls
)
if total_urls == 0:
logger.warning("目标下没有可用的站点URL无法执行站点扫描")
return export_result['output_file'], total_urls, association_count
return export_result['output_file'], total_urls
def _get_tool_timeout(tool_config: dict, urls_file: str) -> int:
@@ -263,7 +261,6 @@ def _build_empty_result(
target_name: str,
scan_workspace_dir: str,
urls_file: str,
association_count: int
) -> dict:
"""构建空结果(无 URL 可扫描时)"""
return {
@@ -273,7 +270,6 @@ def _build_empty_result(
'scan_workspace_dir': scan_workspace_dir,
'urls_file': urls_file,
'total_urls': 0,
'association_count': association_count,
'processed_records': 0,
'created_websites': 0,
'skipped_no_subdomain': 0,
@@ -306,15 +302,12 @@ def _aggregate_tool_results(tool_stats: dict) -> tuple[int, int, int]:
def _validate_flow_params(
scan_id: int,
target_name: str,
target_id: int,
scan_workspace_dir: str
) -> None:
"""验证 Flow 参数"""
if scan_id is None:
raise ValueError("scan_id 不能为空")
if not target_name:
raise ValueError("target_name 不能为空")
if target_id is None:
raise ValueError("target_id 不能为空")
if not scan_workspace_dir:
@@ -330,10 +323,10 @@ def _validate_flow_params(
)
def site_scan_flow(
scan_id: int,
target_name: str,
target_id: int,
scan_workspace_dir: str,
enabled_tools: dict
enabled_tools: dict,
provider,
) -> dict:
"""
Site scan Flow
@@ -344,10 +337,10 @@ def site_scan_flow(
Args:
scan_id: scan task ID
target_name: target name
target_id: target ID
scan_workspace_dir: scan workspace directory
enabled_tools: enabled tool configuration dict
provider: TargetProvider instance
Returns:
dict: scan result
@@ -359,12 +352,17 @@ def site_scan_flow(
try:
wait_for_system_load(context="site_scan_flow")
# Get target_name from the provider
target_name = provider.get_target_name()
if not target_name:
raise ValueError("unable to resolve the Target name")
logger.info(
"开始站点扫描 - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
)
_validate_flow_params(scan_id, target_name, target_id, scan_workspace_dir)
_validate_flow_params(scan_id, target_id, scan_workspace_dir)
user_log(scan_id, "site_scan", "Starting site scan")
# Step 0: create the working directory
@@ -372,15 +370,15 @@ def site_scan_flow(
site_scan_dir = setup_scan_directory(scan_workspace_dir, 'site_scan')
# Step 1: export site URLs
urls_file, total_urls, association_count = _export_site_urls(
target_id, site_scan_dir
urls_file, total_urls = _export_site_urls(
site_scan_dir, provider
)
if total_urls == 0:
logger.warning("跳过站点扫描:没有站点 URL 可扫描 - Scan ID: %s", scan_id)
user_log(scan_id, "site_scan", "Skipped: no site URLs to scan", "warning")
return _build_empty_result(
scan_id, target_name, scan_workspace_dir, urls_file, association_count
scan_id, target_name, scan_workspace_dir, urls_file
)
# Step 2: tool configuration info
@@ -421,7 +419,6 @@ def site_scan_flow(
'scan_workspace_dir': scan_workspace_dir,
'urls_file': urls_file,
'total_urls': total_urls,
'association_count': association_count,
'processed_records': processed_records,
'created_websites': total_created,
'skipped_no_subdomain': total_skipped_no_sub,

View File

@@ -540,10 +540,10 @@ def _empty_result(scan_id: int, target: str, scan_workspace_dir: str) -> dict:
)
def subdomain_discovery_flow(
scan_id: int,
target_name: str,
target_id: int,
scan_workspace_dir: str,
enabled_tools: dict
enabled_tools: dict,
provider,
) -> dict:
"""子域名发现扫描流程
@@ -571,6 +571,8 @@ def subdomain_discovery_flow(
if enabled_tools is None:
raise ValueError("enabled_tools 不能为空")
# Get target_name from the provider
target_name = provider.get_target_name()
if not target_name:
logger.warning("未提供目标域名,跳过子域名发现扫描")
return _empty_result(scan_id, '', scan_workspace_dir)

View File

@@ -34,9 +34,9 @@ logger = logging.getLogger(__name__)
def domain_name_url_fetch_flow(
scan_id: int,
target_id: int,
target_name: str,
output_dir: str,
domain_name_tools: Dict[str, dict],
provider,
) -> dict:
"""
Run passive URL collection against the Target root domain (currently mainly for waymore)
@@ -46,32 +46,35 @@ def domain_name_url_fetch_flow(
2. Run passive collection against the root domain with the given tool list
3. The tools automatically query historical URLs for the domain and its subdomains
4. Aggregate the list of result files
Args:
scan_id: scan ID
target_id: target ID
target_name: Target root domain (e.g. example.com, not a subdomain list)
output_dir: output directory
domain_name_tools: passive-collection tool configuration (e.g. waymore)
provider: TargetProvider instance
Notes:
- This Flow is only valid for DOMAIN-type Targets
- IP and CIDR types are skipped automatically (tools like waymore do not support them)
- The tools collect all historical URLs for *.target_name, so iterating subdomains is unnecessary
"""
from apps.scan.utils import user_log
try:
# Get target_name from the provider
target_name = provider.get_target_name()
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# Check the Target type (IP/CIDR types are skipped)
from apps.targets.services import TargetService
from apps.targets.models import Target
target_service = TargetService()
target = target_service.get_target(target_id)
if target and target.type != Target.TargetType.DOMAIN:
logger.info(
"跳过 domain_name URL 获取: Target 类型为 %s (ID=%d, Name=%s)waymore 等工具仅适用于域名类型",

View File

@@ -240,10 +240,10 @@ def _save_urls_to_database(merged_file: str, scan_id: int, target_id: int) -> in
)
def url_fetch_flow(
scan_id: int,
target_name: str,
target_id: int,
scan_workspace_dir: str,
enabled_tools: dict
enabled_tools: dict,
provider,
) -> dict:
"""
URL fetch main Flow
@@ -252,7 +252,7 @@ def url_fetch_flow(
1. Prepare the working directory
2. Classify tools by input type (domain_name / sites_file / post-processing)
3. Run sub-Flows in parallel
- domain_name_url_fetch_flow: URL fetching based on domain_name (from target_name), e.g. waymore
- domain_name_url_fetch_flow: URL fetching based on domain_name (from the provider), e.g. waymore
- sites_url_fetch_flow: crawling based on sites_file (e.g. katana)
4. Merge all sub-Flow results and deduplicate
5. uro deduplication (if enabled)
@@ -260,10 +260,10 @@ def url_fetch_flow(
Args:
scan_id: scan ID
target_name: target name
target_id: target ID
scan_workspace_dir: scan working directory
enabled_tools: enabled tool configuration
provider: TargetProvider instance
Returns:
dict: scan result
@@ -272,6 +272,11 @@ def url_fetch_flow(
# Load check: wait until system resources are sufficient
wait_for_system_load(context="url_fetch_flow")
# Get target_name from the provider
target_name = provider.get_target_name()
if not target_name:
raise ValueError("unable to resolve the Target name")
logger.info(
"开始 URL 获取扫描 - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
@@ -310,9 +315,9 @@ def url_fetch_flow(
tn_result = domain_name_url_fetch_flow(
scan_id=scan_id,
target_id=target_id,
target_name=target_name,
output_dir=str(url_fetch_dir),
domain_name_tools=domain_name_tools,
provider=provider,
)
all_result_files.extend(tn_result.get('result_files', []))
all_failed_tools.extend(tn_result.get('failed_tools', []))
@@ -323,9 +328,9 @@ def url_fetch_flow(
crawl_result = sites_url_fetch_flow(
scan_id=scan_id,
target_id=target_id,
target_name=target_name,
output_dir=str(url_fetch_dir),
enabled_tools=sites_file_tools
enabled_tools=sites_file_tools,
provider=provider
)
all_result_files.extend(crawl_result.get('result_files', []))
all_failed_tools.extend(crawl_result.get('failed_tools', []))

View File

@@ -19,17 +19,16 @@ from .utils import run_tools_parallel
logger = logging.getLogger(__name__)
def _export_sites_file(target_id: int, scan_id: int, target_name: str, output_dir: Path) -> tuple[str, int]:
def _export_sites_file(
output_dir: Path,
provider,
) -> tuple[str, int]:
"""
Export the site URL list to a file.
Lazy-load mode: if the WebSite table is empty, generate default URLs by Target type.
Args:
target_id: target ID
scan_id: scan ID
target_name: target name (used for lazy loading)
output_dir: output directory
provider: TargetProvider instance
Returns:
tuple: (file_path, count)
@@ -39,8 +38,7 @@ def _export_sites_file(target_id: int, scan_id: int, target_name: str, output_di
output_file = str(output_dir / "sites.txt")
result = export_sites_task(
output_file=output_file,
target_id=target_id,
scan_id=scan_id
provider=provider
)
count = result['asset_count']
@@ -56,25 +54,25 @@ def _export_sites_file(target_id: int, scan_id: int, target_name: str, output_di
def sites_url_fetch_flow(
scan_id: int,
target_id: int,
target_name: str,
output_dir: str,
enabled_tools: dict
enabled_tools: dict,
provider,
) -> dict:
"""
URL crawler sub-Flow
Execution flow:
1. Export the site URL list (sites_file)
2. Run crawler tools in parallel
3. Return the list of result files
Args:
scan_id: scan ID
target_id: target ID
target_name: target name
output_dir: output directory
enabled_tools: enabled crawler tool configuration
provider: TargetProvider instance
Returns:
dict: {
'success': bool,
@@ -85,19 +83,22 @@ def sites_url_fetch_flow(
}
"""
try:
# Get target_name from the provider
target_name = provider.get_target_name()
if not target_name:
raise ValueError("unable to resolve the Target name")
output_path = Path(output_dir)
logger.info(
"开始 URL 爬虫 - Target: %s, Tools: %s",
target_name, ', '.join(enabled_tools.keys())
)
# Step 1: export the site URL list
sites_file, sites_count = _export_sites_file(
target_id=target_id,
scan_id=scan_id,
target_name=target_name,
output_dir=output_path
output_dir=output_path,
provider=provider
)
# In default-value mode there are default URLs as input even when no sites originally exist

View File

@@ -34,17 +34,20 @@ logger = logging.getLogger(__name__)
)
def endpoints_vuln_scan_flow(
scan_id: int,
target_name: str,
target_id: int,
scan_workspace_dir: str,
enabled_tools: Dict[str, dict],
provider,
) -> dict:
"""基于 Endpoint 的漏洞扫描 Flow串行执行 Dalfox 等工具)。"""
try:
# Get target_name from the provider
target_name = provider.get_target_name()
if not target_name:
raise ValueError("unable to resolve the Target name")
if scan_id is None:
raise ValueError("scan_id must not be empty")
if not target_name:
raise ValueError("target_name must not be empty")
if target_id is None:
raise ValueError("target_id must not be empty")
if not scan_workspace_dir:
@@ -58,8 +61,8 @@ def endpoints_vuln_scan_flow(
# Step 1: export Endpoint URLs
export_result = export_endpoints_task(
target_id=target_id,
output_file=str(endpoints_file),
provider=provider,
)
total_endpoints = export_result.get("total_count", 0)
@@ -104,8 +107,11 @@ def endpoints_vuln_scan_flow(
continue
template_args = " ".join(f"-t {p}" for p in template_paths)
# Build command parameters
command_params = {"endpoints_file": str(endpoints_file)}
# Build command parameters (use the parameter name the tool's template expects)
if tool_name == "nuclei":
command_params = {"input_file": str(endpoints_file)}
else:
command_params = {"endpoints_file": str(endpoints_file)}
if template_args:
command_params["template_args"] = template_args

View File

@@ -14,32 +14,48 @@ from apps.scan.handlers.scan_flow_handlers import (
from apps.scan.configs.command_templates import get_command_template
from apps.scan.utils import user_log, wait_for_system_load
from .endpoints_vuln_scan_flow import endpoints_vuln_scan_flow
from .websites_vuln_scan_flow import websites_vuln_scan_flow
logger = logging.getLogger(__name__)
def _classify_vuln_tools(enabled_tools: Dict[str, dict]) -> Tuple[Dict[str, dict], Dict[str, dict]]:
"""根据命令模板中的 input_type 对漏洞扫描工具进行分类。
def _classify_vuln_tools(
enabled_tools: Dict[str, dict]
) -> Tuple[Dict[str, dict], Dict[str, dict], Dict[str, dict]]:
"""根据用户配置分类漏洞扫描工具。
当前支持
- endpoints_file: 以端点列表文件为输入(例如 Dalfox XSS
预留:
- 其他 input_type 将被归类到 other_tools暂不处理。
分类逻辑
- 读取 scan_endpoints / scan_websites 配置
- 默认值从模板的 defaults 或 input_type 推断
Returns:
(endpoints_tools, websites_tools, other_tools) 三元组
"""
endpoints_tools: Dict[str, dict] = {}
websites_tools: Dict[str, dict] = {}
other_tools: Dict[str, dict] = {}
for tool_name, tool_config in enabled_tools.items():
template = get_command_template("vuln_scan", tool_name) or {}
input_type = template.get("input_type", "endpoints_file")
defaults = template.get("defaults", {})
if input_type == "endpoints_file":
# Infer defaults from input_type (compatibility with older tools)
input_type = template.get("input_type")
default_endpoints = defaults.get("scan_endpoints", input_type == "endpoints_file")
default_websites = defaults.get("scan_websites", input_type == "websites_file")
scan_endpoints = tool_config.get("scan_endpoints", default_endpoints)
scan_websites = tool_config.get("scan_websites", default_websites)
if scan_endpoints:
endpoints_tools[tool_name] = tool_config
else:
if scan_websites:
websites_tools[tool_name] = tool_config
if not scan_endpoints and not scan_websites:
other_tools[tool_name] = tool_config
return endpoints_tools, other_tools
return endpoints_tools, websites_tools, other_tools
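# Usage sketch of the classifier above (configs illustrative; defaults
# depend on each tool's template):
#   ep, ws, other = _classify_vuln_tools({"dalfox_xss": {}, "nuclei": {}})
# A tool can land in both ep and ws, since scan_endpoints and
# scan_websites are independent flags; only a tool with both flags
# off ends up in other.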
@flow(
@@ -51,25 +67,28 @@ def _classify_vuln_tools(enabled_tools: Dict[str, dict]) -> Tuple[Dict[str, dict
)
def vuln_scan_flow(
scan_id: int,
target_name: str,
target_id: int,
scan_workspace_dir: str,
enabled_tools: Dict[str, dict],
provider,
) -> dict:
"""漏洞扫描主 Flow串行编排各类漏洞扫描子 Flow。
支持工具:
- dalfox_xss: XSS 漏洞扫描(流式保存)
- nuclei: 通用漏洞扫描(流式保存,支持模板 commit hash 同步)
- nuclei: 通用漏洞扫描(流式保存,支持 endpoints 和 websites 两种输入)
"""
try:
# 负载检查:等待系统资源充足
wait_for_system_load(context="vuln_scan_flow")
# 从 provider 获取 target_name
target_name = provider.get_target_name()
if not target_name:
raise ValueError("无法获取 Target 名称")
if scan_id is None:
raise ValueError("scan_id 不能为空")
if not target_name:
raise ValueError("target_name 不能为空")
if target_id is None:
raise ValueError("target_id 不能为空")
if not scan_workspace_dir:
@@ -81,11 +100,12 @@ def vuln_scan_flow(
user_log(scan_id, "vuln_scan", "Starting vulnerability scan")
# Step 1: 分类工具
endpoints_tools, other_tools = _classify_vuln_tools(enabled_tools)
endpoints_tools, websites_tools, other_tools = _classify_vuln_tools(enabled_tools)
logger.info(
"漏洞扫描工具分类 - endpoints_file: %s, 其他: %s",
"漏洞扫描工具分类 - endpoints: %s, websites: %s, 其他: %s",
list(endpoints_tools.keys()) or "",
list(websites_tools.keys()) or "",
list(other_tools.keys()) or "",
)
@@ -95,28 +115,58 @@ def vuln_scan_flow(
list(other_tools.keys()),
)
if not endpoints_tools:
raise ValueError("漏洞扫描需要至少启用一个以 endpoints_file 为输入的工具(如 dalfox_xss、nuclei")
if not endpoints_tools and not websites_tools:
raise ValueError(
"漏洞扫描需要至少启用一个工具endpoints 或 websites 模式)"
)
# Step 2: 执行 Endpoint 漏洞扫描子 Flow串行
endpoint_result = endpoints_vuln_scan_flow(
scan_id=scan_id,
target_name=target_name,
target_id=target_id,
scan_workspace_dir=scan_workspace_dir,
enabled_tools=endpoints_tools,
)
total_vulns = 0
results = {}
# Step 2: 执行 Endpoint 漏洞扫描子 Flow
if endpoints_tools:
logger.info("执行 Endpoint 漏洞扫描 - 工具: %s", list(endpoints_tools.keys()))
endpoint_result = endpoints_vuln_scan_flow(
scan_id=scan_id,
target_id=target_id,
scan_workspace_dir=scan_workspace_dir,
enabled_tools=endpoints_tools,
provider=provider,
)
results["endpoints"] = endpoint_result
total_vulns += sum(
r.get("created_vulns", 0)
for r in endpoint_result.get("tool_results", {}).values()
)
# Step 3: 执行 WebSite 漏洞扫描子 Flow
if websites_tools:
logger.info("执行 WebSite 漏洞扫描 - 工具: %s", list(websites_tools.keys()))
website_result = websites_vuln_scan_flow(
scan_id=scan_id,
target_id=target_id,
scan_workspace_dir=scan_workspace_dir,
enabled_tools=websites_tools,
provider=provider,
)
results["websites"] = website_result
total_vulns += sum(
r.get("created_vulns", 0)
for r in website_result.get("tool_results", {}).values()
)
# 记录 Flow 完成
total_vulns = sum(
r.get("created_vulns", 0)
for r in endpoint_result.get("tool_results", {}).values()
)
logger.info("✓ 漏洞扫描完成 - 新增漏洞: %d", total_vulns)
user_log(scan_id, "vuln_scan", f"vuln_scan completed: found {total_vulns} vulnerabilities")
# 目前只有一个子 Flow,直接返回其结果
return endpoint_result
return {
"success": True,
"scan_id": scan_id,
"target": target_name,
"scan_workspace_dir": scan_workspace_dir,
"total_vulns": total_vulns,
"sub_flow_results": results,
}
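# e.g. an endpoints sub-flow reporting 3 created_vulns and a websites
# sub-flow reporting 2 yields total_vulns == 5 in the dict above.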
except Exception as e:
logger.exception("漏洞扫描主 Flow 失败: %s", e)

View File

@@ -0,0 +1,192 @@
"""
基于 WebSite 的漏洞扫描 Flow
与 endpoints_vuln_scan_flow 类似,但数据源是 WebSite 而不是 Endpoint。
主要用于 nuclei 扫描已存活的网站。
"""
import logging
from datetime import datetime
from typing import Dict
from prefect import flow
from apps.scan.utils import build_scan_command, ensure_nuclei_templates_local, user_log
from apps.scan.tasks.vuln_scan import run_and_stream_save_nuclei_vulns_task
from apps.scan.tasks.vuln_scan.export_websites_task import export_websites_task
from .utils import calculate_timeout_by_line_count
logger = logging.getLogger(__name__)
@flow(
name="websites_vuln_scan_flow",
log_prints=True,
)
def websites_vuln_scan_flow(
scan_id: int,
target_id: int,
scan_workspace_dir: str,
enabled_tools: Dict[str, dict],
provider,
) -> dict:
"""基于 WebSite 的漏洞扫描 Flow主要用于 nuclei"""
try:
target_name = provider.get_target_name()
if not target_name:
raise ValueError("无法获取 Target 名称")
if scan_id is None:
raise ValueError("scan_id 不能为空")
if target_id is None:
raise ValueError("target_id 不能为空")
if not scan_workspace_dir:
raise ValueError("scan_workspace_dir 不能为空")
if not enabled_tools:
raise ValueError("enabled_tools 不能为空")
from apps.scan.utils import setup_scan_directory
vuln_scan_dir = setup_scan_directory(scan_workspace_dir, 'vuln_scan')
websites_file = vuln_scan_dir / "input_websites.txt"
# Step 1: 导出 WebSite URL
export_result = export_websites_task(
output_file=str(websites_file),
provider=provider,
)
total_websites = export_result.get("total_count", 0)
if total_websites == 0:
logger.warning("目标下没有可用 WebSite跳过漏洞扫描")
return {
"success": True,
"scan_id": scan_id,
"target": target_name,
"scan_workspace_dir": scan_workspace_dir,
"websites_file": str(websites_file),
"website_count": 0,
"executed_tools": [],
"tool_results": {},
}
logger.info("WebSite 导出完成,共 %d 条,开始执行漏洞扫描", total_websites)
tool_results: Dict[str, dict] = {}
tool_futures: Dict[str, dict] = {}
# Step 2: 执行漏洞扫描工具
for tool_name, tool_config in enabled_tools.items():
# 目前只支持 nuclei
if tool_name != "nuclei":
logger.warning("websites_vuln_scan_flow 暂不支持工具: %s", tool_name)
continue
# 确保 nuclei 模板存在
repo_names = tool_config.get("template_repo_names")
if not repo_names or not isinstance(repo_names, (list, tuple)):
logger.error("Nuclei 配置缺少 template_repo_names数组跳过")
continue
template_paths = []
try:
for repo_name in repo_names:
path = ensure_nuclei_templates_local(repo_name)
template_paths.append(path)
logger.info("Nuclei 模板路径 [%s]: %s", repo_name, path)
except Exception as e:
logger.error("获取 Nuclei 模板失败: %s,跳过 nuclei 扫描", e)
continue
template_args = " ".join(f"-t {p}" for p in template_paths)
# 构建命令(使用 websites_file 作为输入)
command_params = {
"input_file": str(websites_file),
"template_args": template_args,
}
command = build_scan_command(
tool_name=tool_name,
scan_type="vuln_scan",
command_params=command_params,
tool_config=tool_config,
)
# 计算超时时间
raw_timeout = tool_config.get("timeout", 600)
if isinstance(raw_timeout, str) and raw_timeout == "auto":
timeout = calculate_timeout_by_line_count(
tool_config=tool_config,
file_path=str(websites_file),
base_per_time=30,
)
else:
try:
timeout = int(raw_timeout)
except (TypeError, ValueError) as e:
raise ValueError(
f"工具 {tool_name} 的 timeout 配置无效: {raw_timeout!r}"
) from e
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_file = vuln_scan_dir / f"{tool_name}_websites_{timestamp}.log"
logger.info("开始执行 %s 漏洞扫描WebSite 模式)", tool_name)
user_log(scan_id, "vuln_scan", f"Running {tool_name} (websites): {command}")
future = run_and_stream_save_nuclei_vulns_task.submit(
cmd=command,
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
cwd=str(vuln_scan_dir),
shell=True,
batch_size=1,
timeout=timeout,
log_file=str(log_file),
)
tool_futures[tool_name] = {
"future": future,
"command": command,
"timeout": timeout,
"log_file": str(log_file),
}
# 收集结果
for tool_name, meta in tool_futures.items():
future = meta["future"]
try:
result = future.result()
created_vulns = result.get("created_vulns", 0)
tool_results[tool_name] = {
"command": meta["command"],
"timeout": meta["timeout"],
"processed_records": result.get("processed_records"),
"created_vulns": created_vulns,
"command_log_file": meta["log_file"],
}
logger.info("✓ 工具 %s (websites) 执行完成 - 漏洞: %d", tool_name, created_vulns)
user_log(
scan_id, "vuln_scan",
f"{tool_name} (websites) completed: found {created_vulns} vulnerabilities"
)
except Exception as e:
reason = str(e)
logger.error("工具 %s 执行失败: %s", tool_name, e, exc_info=True)
user_log(scan_id, "vuln_scan", f"{tool_name} failed: {reason}", "error")
return {
"success": True,
"scan_id": scan_id,
"target": target_name,
"scan_workspace_dir": scan_workspace_dir,
"websites_file": str(websites_file),
"website_count": total_websites,
"executed_tools": list(enabled_tools.keys()),
"tool_results": tool_results,
}
except Exception as e:
logger.exception("WebSite 漏洞扫描失败: %s", e)
raise
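The "auto" branch above delegates to calculate_timeout_by_line_count; a minimal sketch of such a heuristic, assuming it scales a per-entry budget by the input file's line count (the real helper may weigh concurrency and tool defaults differently):
def estimate_timeout(file_path: str, base_per_time: int = 30, floor: int = 600) -> int:
    # one budget unit per non-empty input line, never below the floor
    with open(file_path, encoding="utf-8") as f:
        entries = sum(1 for line in f if line.strip())
    return max(floor, entries * base_per_time)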

View File

@@ -0,0 +1,35 @@
# Generated by Django 5.2.7 on 2026-01-10 03:51
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('scan', '0003_add_wecom_fields'),
]
operations = [
migrations.AddField(
model_name='scan',
name='scan_mode',
field=models.CharField(choices=[('full', '完整扫描'), ('quick', '快速扫描')], default='full', help_text='扫描模式:full=完整扫描,quick=快速扫描', max_length=10),
),
migrations.CreateModel(
name='ScanInputTarget',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('value', models.CharField(help_text='用户输入的原始值', max_length=2000)),
('input_type', models.CharField(choices=[('domain', '域名'), ('ip', 'IP地址'), ('cidr', 'CIDR'), ('url', 'URL')], help_text='输入类型', max_length=10)),
('created_at', models.DateTimeField(auto_now_add=True)),
('scan', models.ForeignKey(help_text='所属的扫描任务', on_delete=django.db.models.deletion.CASCADE, related_name='input_targets', to='scan.scan')),
],
options={
'verbose_name': '扫描输入目标',
'verbose_name_plural': '扫描输入目标',
'db_table': 'scan_input_target',
'indexes': [models.Index(fields=['scan'], name='scan_input__scan_id_0a3227_idx'), models.Index(fields=['input_type'], name='scan_input__input_t_e3f681_idx')],
},
),
]

View File

@@ -4,6 +4,7 @@ from .scan_models import Scan, SoftDeleteManager
from .scan_log_model import ScanLog
from .scheduled_scan_model import ScheduledScan
from .subfinder_provider_settings_model import SubfinderProviderSettings
from .scan_input_target import ScanInputTarget
# 兼容旧名称(已废弃,请使用 SubfinderProviderSettings
ProviderSettings = SubfinderProviderSettings
@@ -15,4 +16,5 @@ __all__ = [
'SoftDeleteManager',
'SubfinderProviderSettings',
'ProviderSettings', # 兼容旧名称
'ScanInputTarget',
]

View File

@@ -0,0 +1,47 @@
"""
扫描输入目标模型
存储快速扫描时用户输入的目标,支持大量数据(1万+)的分块迭代。
用于快速扫描的第一阶段。
"""
from django.db import models
class ScanInputTarget(models.Model):
"""扫描输入目标表"""
class InputType(models.TextChoices):
"""输入类型枚举"""
DOMAIN = 'domain', '域名'
IP = 'ip', 'IP地址'
CIDR = 'cidr', 'CIDR'
URL = 'url', 'URL'
id = models.AutoField(primary_key=True)
scan = models.ForeignKey(
'scan.Scan',
on_delete=models.CASCADE,
related_name='input_targets',
help_text='所属的扫描任务'
)
value = models.CharField(max_length=2000, help_text='用户输入的原始值')
input_type = models.CharField(
max_length=10,
choices=InputType.choices,
help_text='输入类型'
)
created_at = models.DateTimeField(auto_now_add=True)
class Meta:
"""模型元数据"""
db_table = 'scan_input_target'
verbose_name = '扫描输入目标'
verbose_name_plural = '扫描输入目标'
indexes = [
models.Index(fields=['scan']),
models.Index(fields=['input_type']),
]
def __str__(self):
return f"ScanInputTarget #{self.id} - {self.value} ({self.input_type})"

View File

@@ -8,17 +8,28 @@ from apps.common.definitions import ScanStatus
class SoftDeleteManager(models.Manager):
"""软删除管理器:默认只返回未删除的记录"""
def get_queryset(self):
"""返回未删除记录的查询集"""
return super().get_queryset().filter(deleted_at__isnull=True)
class Scan(models.Model):
"""扫描任务模型"""
class ScanMode(models.TextChoices):
"""扫描模式枚举"""
FULL = 'full', '完整扫描'
QUICK = 'quick', '快速扫描'
id = models.AutoField(primary_key=True)
target = models.ForeignKey('targets.Target', on_delete=models.CASCADE, related_name='scans', help_text='扫描目标')
target = models.ForeignKey(
'targets.Target',
on_delete=models.CASCADE,
related_name='scans',
help_text='扫描目标'
)
# 多引擎支持字段
engine_ids = ArrayField(
@@ -35,6 +46,14 @@ class Scan(models.Model):
help_text='YAML 格式的扫描配置'
)
# 扫描模式
scan_mode = models.CharField(
max_length=10,
choices=ScanMode.choices,
default=ScanMode.FULL,
help_text='扫描模式:full=完整扫描,quick=快速扫描'
)
created_at = models.DateTimeField(auto_now_add=True, help_text='任务创建时间')
stopped_at = models.DateTimeField(null=True, blank=True, help_text='扫描结束时间')
@@ -46,7 +65,12 @@ class Scan(models.Model):
help_text='任务状态'
)
results_dir = models.CharField(max_length=100, blank=True, default='', help_text='结果存储目录')
results_dir = models.CharField(
max_length=100,
blank=True,
default='',
help_text='结果存储目录'
)
container_ids = ArrayField(
models.CharField(max_length=100),
@@ -54,7 +78,7 @@ class Scan(models.Model):
default=list,
help_text='容器 ID 列表Docker Container ID'
)
worker = models.ForeignKey(
'engine.WorkerNode',
on_delete=models.SET_NULL,
@@ -64,35 +88,46 @@ class Scan(models.Model):
help_text='执行扫描的 Worker 节点'
)
error_message = models.CharField(max_length=2000, blank=True, default='', help_text='错误信息')
error_message = models.CharField(
max_length=2000,
blank=True,
default='',
help_text='错误信息'
)
# ==================== 软删除字段 ====================
deleted_at = models.DateTimeField(null=True, blank=True, db_index=True, help_text='删除时间,NULL 表示未删除')
# 软删除字段
deleted_at = models.DateTimeField(
null=True,
blank=True,
db_index=True,
help_text='删除时间,NULL 表示未删除'
)
# ==================== 管理器 ====================
objects = SoftDeleteManager() # 默认管理器:只返回未删除的记录
all_objects = models.Manager() # 全量管理器:包括已删除的记录(用于硬删除)
# 管理器
objects = SoftDeleteManager()
all_objects = models.Manager()
# ==================== 进度跟踪字段 ====================
# 进度跟踪字段
progress = models.IntegerField(default=0, help_text='扫描进度 0-100')
current_stage = models.CharField(max_length=50, blank=True, default='', help_text='当前扫描阶段')
stage_progress = models.JSONField(default=dict, help_text='各阶段进度详情')
# ==================== 缓存统计字段 ====================
cached_subdomains_count = models.IntegerField(default=0, help_text='缓存的子域名数量')
cached_websites_count = models.IntegerField(default=0, help_text='缓存的网站数量')
cached_endpoints_count = models.IntegerField(default=0, help_text='缓存的端点数量')
cached_ips_count = models.IntegerField(default=0, help_text='缓存的IP地址数量')
cached_directories_count = models.IntegerField(default=0, help_text='缓存的目录数量')
cached_screenshots_count = models.IntegerField(default=0, help_text='缓存的截图数量')
cached_vulns_total = models.IntegerField(default=0, help_text='缓存的漏洞总数')
cached_vulns_critical = models.IntegerField(default=0, help_text='缓存的严重漏洞数量')
cached_vulns_high = models.IntegerField(default=0, help_text='缓存的高危漏洞数量')
cached_vulns_medium = models.IntegerField(default=0, help_text='缓存的中危漏洞数量')
cached_vulns_low = models.IntegerField(default=0, help_text='缓存的低危漏洞数量')
# 缓存统计字段
cached_subdomains_count = models.IntegerField(default=0, help_text='子域名数量')
cached_websites_count = models.IntegerField(default=0, help_text='网站数量')
cached_endpoints_count = models.IntegerField(default=0, help_text='端点数量')
cached_ips_count = models.IntegerField(default=0, help_text='IP地址数量')
cached_directories_count = models.IntegerField(default=0, help_text='目录数量')
cached_screenshots_count = models.IntegerField(default=0, help_text='截图数量')
cached_vulns_total = models.IntegerField(default=0, help_text='漏洞总数')
cached_vulns_critical = models.IntegerField(default=0, help_text='严重漏洞数量')
cached_vulns_high = models.IntegerField(default=0, help_text='高危漏洞数量')
cached_vulns_medium = models.IntegerField(default=0, help_text='中危漏洞数量')
cached_vulns_low = models.IntegerField(default=0, help_text='低危漏洞数量')
stats_updated_at = models.DateTimeField(null=True, blank=True, help_text='统计数据最后更新时间')
class Meta:
"""模型元数据配置"""
db_table = 'scan'
verbose_name = '扫描任务'
verbose_name_plural = '扫描任务'
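A minimal sketch of how the two managers above behave (soft_delete_scan is a hypothetical helper):
from django.utils import timezone

def soft_delete_scan(scan_id: int) -> None:
    # Setting deleted_at hides the row from Scan.objects (SoftDeleteManager);
    # Scan.all_objects still returns it, e.g. for later hard deletion.
    Scan.all_objects.filter(id=scan_id).update(deleted_at=timezone.now())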

View File

@@ -3,54 +3,49 @@
提供统一的目标获取接口,支持多种数据源:
- DatabaseTargetProvider: 从数据库查询(完整扫描)
- ListTargetProvider: 使用内存列表(快速扫描阶段1)
- SnapshotTargetProvider: 从快照表读取(快速扫描阶段2+)
- PipelineTargetProvider: 使用管道输出(Phase 2)
- SnapshotTargetProvider: 从快照表读取(快速扫描)
Provider 方法:
- get_target_name(): Target 名称(根域名/IP/CIDR)
- iter_subdomains(): 子域名列表
- iter_host_port_urls(): 从 host:port 生成的 URL(站点探测用)
- iter_websites(): 已存活网站 URL(截图、指纹、目录扫描用)
- iter_endpoints(): 端点 URL(漏洞扫描用)
使用方式:
from apps.scan.providers import (
DatabaseTargetProvider,
ListTargetProvider,
SnapshotTargetProvider,
ProviderContext
)
# 数据库模式(完整扫描)
provider = DatabaseTargetProvider(target_id=123)
# 列表模式(快速扫描阶段1)
context = ProviderContext(target_id=1, scan_id=100)
provider = ListTargetProvider(
targets=["a.test.com"],
context=context
)
# 快照模式(快速扫描阶段2+)
context = ProviderContext(target_id=1, scan_id=100)
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain",
context=context
)
# 使用 Provider
for host in provider.iter_hosts():
scan(host)
# 端口扫描:显式组合 target_name + subdomains
target_name = provider.get_target_name()
if target_name:
scan_port(target_name) # CIDR 需要调用方自己展开
for subdomain in provider.iter_subdomains():
scan_port(subdomain)
# 截图
for url in provider.iter_websites():
take_screenshot(url)
# 快照模式(快速扫描)
provider = SnapshotTargetProvider(scan_id=100)
for url in provider.iter_websites():
take_screenshot(url)
"""
from .base import TargetProvider, ProviderContext
from .list_provider import ListTargetProvider
from .database_provider import DatabaseTargetProvider
from .snapshot_provider import SnapshotTargetProvider, SnapshotType
from .pipeline_provider import PipelineTargetProvider, StageOutput
from .snapshot_provider import SnapshotTargetProvider
__all__ = [
'TargetProvider',
'ProviderContext',
'ListTargetProvider',
'DatabaseTargetProvider',
'SnapshotTargetProvider',
'SnapshotType',
'PipelineTargetProvider',
'StageOutput',
]
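A caller-side fallback sketch matching the usage notes above (priority Endpoint → WebSite → default URLs; collect_urls is a hypothetical helper):
def collect_urls(provider):
    # Try each data source in priority order and report which one hit,
    # mirroring the "调用方控制回退" pattern in the docstrings.
    for source, iterate in (
        ("endpoint", provider.iter_endpoints),
        ("website", provider.iter_websites),
        ("default", provider.iter_default_urls),
    ):
        urls = list(iterate())
        if urls:
            return source, urls
    return "none", []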

View File

@@ -4,7 +4,6 @@
定义 ProviderContext 数据类和 TargetProvider 抽象基类。
"""
import ipaddress
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
@@ -37,72 +36,184 @@ class TargetProvider(ABC):
- 提供扫描目标(域名、IP、URL 等)的迭代器
- 提供黑名单过滤器
- 携带上下文信息target_id, scan_id 等)
- 自动展开 CIDR子类无需关心
方法说明:
- get_target_name(): Target 名称(根域名/IP/CIDR)
- iter_subdomains(): 子域名列表
- iter_host_port_urls(): 从 host:port 生成的 URL(站点探测用)
- iter_websites(): 已存活网站 URL(截图、指纹、目录扫描用)
- iter_endpoints(): 端点 URL(漏洞扫描用)
使用方式:
provider = create_target_provider(target_id=123)
for host in provider.iter_hosts():
print(host)
provider = DatabaseTargetProvider(target_id=123)
# 端口扫描:显式组合 target_name + subdomains
target_name = provider.get_target_name()
if target_name:
scan_port(target_name) # CIDR 需要调用方自己展开
for subdomain in provider.iter_subdomains():
scan_port(subdomain)
# 截图
for url in provider.iter_websites():
take_screenshot(url)
"""
def __init__(self, context: Optional[ProviderContext] = None):
self._context = context or ProviderContext()
self._target_name: Optional[str] = None # 缓存 target_name
@property
def context(self) -> ProviderContext:
"""返回 Provider 上下文"""
return self._context
@staticmethod
def _expand_host(host: str) -> Iterator[str]:
def get_target_name(self) -> Optional[str]:
"""
展开主机(如果是 CIDR 则展开为多个 IP,否则直接返回)
获取 Target 名称(根域名/IP/CIDR)
示例:
"192.168.1.0/30" → "192.168.1.1", "192.168.1.2"
"192.168.1.1" → "192.168.1.1"
"example.com" → "example.com"
Returns:
Target 名称,不存在时返回 None
注意:CIDR 不会自动展开,调用方需要自己处理
"""
# 使用缓存避免重复查询
if self._target_name is not None:
return self._target_name
if not self.target_id:
logger.warning("target_id 未设置,无法获取 Target 名称")
return None
from apps.targets.services import TargetService
target = TargetService().get_target(self.target_id)
self._target_name = target.name if target else None
return self._target_name
def iter_target_hosts(self) -> Iterator[str]:
"""
迭代 Target 展开后的主机列表(已过滤黑名单)
- DOMAIN/IP: 直接返回
- CIDR: 展开为所有 IP
Returns:
主机迭代器(域名或 IP)
"""
import ipaddress
from apps.common.validators import detect_target_type
from apps.targets.models import Target
host = host.strip()
if not host:
target_name = self.get_target_name()
if not target_name:
return
try:
target_type = detect_target_type(host)
blacklist = self.get_blacklist_filter()
target_type = detect_target_type(target_name)
if target_type == Target.TargetType.CIDR:
network = ipaddress.ip_network(host, strict=False)
if network.num_addresses == 1:
yield str(network.network_address)
else:
yield from (str(ip) for ip in network.hosts())
elif target_type in (Target.TargetType.IP, Target.TargetType.DOMAIN):
if target_type == Target.TargetType.CIDR:
# CIDR 展开
network = ipaddress.ip_network(target_name, strict=False)
if network.num_addresses == 1:
hosts = [str(network.network_address)]
else:
hosts = [str(ip) for ip in network.hosts()]
else:
# DOMAIN / IP 直接返回
hosts = [target_name]
for host in hosts:
if not blacklist or blacklist.is_allowed(host):
yield host
except ValueError as e:
logger.warning("跳过无效的主机格式 '%s': %s", host, str(e))
def iter_hosts(self) -> Iterator[str]:
"""迭代主机列表(域名/IP自动展开 CIDR"""
for host in self._iter_raw_hosts():
yield from self._expand_host(host)
@abstractmethod
def _iter_raw_hosts(self) -> Iterator[str]:
"""迭代原始主机列表(可能包含 CIDR,子类实现"""
pass
def iter_subdomains(self) -> Iterator[str]:
"""迭代子域名列表,子类实现"""
@abstractmethod
def iter_urls(self) -> Iterator[str]:
"""迭代 URL 列表"""
pass
def iter_host_port_urls(self) -> Iterator[str]:
"""
迭代 host:port 生成的 URL(待探测)
用于站点扫描(httpx 探测),从 HostPortMapping 生成 URL。
返回格式:http://host:port 或 https://host:port
"""
@abstractmethod
def iter_websites(self) -> Iterator[str]:
"""
迭代已存活网站 URL
用于截图、指纹识别、目录扫描、URL 爬虫。
数据来源:WebSite 表(已确认存活的网站)
"""
@abstractmethod
def iter_endpoints(self) -> Iterator[str]:
"""
迭代端点 URL
用于漏洞扫描。
数据来源:Endpoint 表(带参数的 URL)
"""
def iter_default_urls(self) -> Iterator[str]:
"""
从 Target 本身生成默认 URL
用于跳过前置阶段直接扫描的场景。
根据 Target 类型生成:
- DOMAIN: http(s)://domain
- IP: http(s)://ip
- CIDR: 展开为所有 IP 的 http(s)://ip
"""
import ipaddress
from apps.targets.models import Target
from apps.targets.services import TargetService
if not self.target_id:
logger.warning("target_id 未设置,无法生成默认 URL")
return
target = TargetService().get_target(self.target_id)
if not target:
logger.warning("Target ID %d 不存在,无法生成默认 URL", self.target_id)
return
target_name = target.name
target_type = target.type
blacklist = self.get_blacklist_filter()
if target_type == Target.TargetType.DOMAIN:
urls = [f"http://{target_name}", f"https://{target_name}"]
elif target_type == Target.TargetType.IP:
urls = [f"http://{target_name}", f"https://{target_name}"]
elif target_type == Target.TargetType.CIDR:
try:
network = ipaddress.ip_network(target_name, strict=False)
urls = []
for ip in network.hosts():
urls.extend([f"http://{ip}", f"https://{ip}"])
# /32 或 /128 特殊处理
if not urls:
ip = str(network.network_address)
urls = [f"http://{ip}", f"https://{ip}"]
except ValueError as e:
logger.error("CIDR 解析失败: %s - %s", target_name, e)
return
else:
logger.warning("不支持的 Target 类型: %s", target_type)
return
for url in urls:
if not blacklist or blacklist.is_allowed(url):
yield url
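# Worked example (values illustrative): a CIDR target "192.168.1.0/30"
# yields http://192.168.1.1, https://192.168.1.1, http://192.168.1.2,
# https://192.168.1.2, minus any URLs rejected by the blacklist filter.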
@abstractmethod
def get_blacklist_filter(self) -> Optional['BlacklistFilter']:
"""获取黑名单过滤器,返回 None 表示不过滤"""
pass
@property
def target_id(self) -> Optional[int]:

View File

@@ -2,6 +2,7 @@
数据库目标提供者模块
提供基于数据库查询的目标提供者实现。
用于完整扫描模式,从 Target 关联的资产表查询数据。
"""
import logging
@@ -19,14 +20,33 @@ class DatabaseTargetProvider(TargetProvider):
"""
数据库目标提供者 - 从 Target 表及关联资产表查询
用于完整扫描模式,查询目标下的所有历史资产。
数据来源:
- iter_hosts(): 根据 Target 类型返回域名/IP
- iter_urls(): WebSite/Endpoint,带回退链
- get_target_name(): Target 表(根域名/IP/CIDR)
- iter_subdomains(): Subdomain 表
- iter_host_port_urls(): HostPortMapping 表
- iter_websites(): WebSite 表
- iter_endpoints(): Endpoint 表
- iter_default_urls(): 从 Target 本身生成默认 URL
回退逻辑由调用方(Task/Flow)决定,Provider 只负责单一数据源查询。
使用方式:
provider = DatabaseTargetProvider(target_id=123)
for host in provider.iter_hosts():
scan(host)
# 端口扫描:显式组合
target_name = provider.get_target_name()
if target_name:
    scan_port(target_name)  # CIDR 需要调用方自己展开
for subdomain in provider.iter_subdomains():
scan_port(subdomain)
# 调用方控制回退
urls = list(provider.iter_endpoints())
if not urls:
urls = list(provider.iter_websites())
if not urls:
urls = list(provider.iter_default_urls())
"""
def __init__(self, target_id: int, context: Optional[ProviderContext] = None):
@@ -35,53 +55,73 @@ class DatabaseTargetProvider(TargetProvider):
super().__init__(ctx)
self._blacklist_filter: Optional['BlacklistFilter'] = None
def iter_hosts(self) -> Iterator[str]:
"""数据库查询主机列表,自动展开 CIDR 并应用黑名单过滤"""
blacklist = self.get_blacklist_filter()
for host in self._iter_raw_hosts():
for expanded_host in self._expand_host(host):
if not blacklist or blacklist.is_allowed(expanded_host):
yield expanded_host
def _iter_raw_hosts(self) -> Iterator[str]:
"""从数据库查询原始主机列表(可能包含 CIDR"""
def iter_subdomains(self) -> Iterator[str]:
""" Subdomain 表查询子域名列表"""
from apps.asset.services.asset.subdomain_service import SubdomainService
from apps.targets.models import Target
from apps.targets.services import TargetService
target = TargetService().get_target(self.target_id)
if not target:
logger.warning("Target ID %d 不存在", self.target_id)
return
if target.type == Target.TargetType.DOMAIN:
yield target.name
for domain in SubdomainService().iter_subdomain_names_by_target(
target_id=self.target_id,
chunk_size=1000
):
if domain != target.name:
yield domain
elif target.type in (Target.TargetType.IP, Target.TargetType.CIDR):
yield target.name
def iter_urls(self) -> Iterator[str]:
"""从数据库查询 URL 列表使用回退链Endpoint → WebSite → Default"""
from apps.scan.services.target_export_service import (
DataSource,
_iter_urls_with_fallback,
)
blacklist = self.get_blacklist_filter()
for url, _ in _iter_urls_with_fallback(
for domain in SubdomainService().iter_subdomain_names_by_target(
target_id=self.target_id,
sources=[DataSource.ENDPOINT, DataSource.WEBSITE, DataSource.DEFAULT],
blacklist_filter=blacklist
chunk_size=1000
):
yield url
if not blacklist or blacklist.is_allowed(domain):
yield domain
def iter_host_port_urls(self) -> Iterator[str]:
"""从 HostPortMapping 表生成待探测的 URL"""
from apps.asset.models import HostPortMapping
blacklist = self.get_blacklist_filter()
queryset = HostPortMapping.objects.filter(
target_id=self.target_id
).values('host', 'port').distinct()
for mapping in queryset.iterator(chunk_size=1000):
host = mapping['host']
port = mapping['port']
if port == 80:
urls = [f"http://{host}"]
elif port == 443:
urls = [f"https://{host}"]
else:
urls = [f"http://{host}:{port}", f"https://{host}:{port}"]
for url in urls:
if not blacklist or blacklist.is_allowed(url):
yield url
def iter_websites(self) -> Iterator[str]:
"""从 WebSite 表查询已存活网站 URL"""
from apps.asset.models import WebSite
blacklist = self.get_blacklist_filter()
queryset = WebSite.objects.filter(
target_id=self.target_id
).values_list('url', flat=True)
for url in queryset.iterator(chunk_size=1000):
if url:
if not blacklist or blacklist.is_allowed(url):
yield url
def iter_endpoints(self) -> Iterator[str]:
"""从 Endpoint 表查询端点 URL"""
from apps.asset.models import Endpoint
blacklist = self.get_blacklist_filter()
queryset = Endpoint.objects.filter(
target_id=self.target_id
).values_list('url', flat=True)
for url in queryset.iterator(chunk_size=1000):
if url:
if not blacklist or blacklist.is_allowed(url):
yield url
def get_blacklist_filter(self) -> Optional['BlacklistFilter']:
"""获取黑名单过滤器(延迟加载)"""

View File

@@ -1,84 +0,0 @@
"""
列表目标提供者模块
提供基于内存列表的目标提供者实现。
"""
from typing import Iterator, Optional, List
from .base import TargetProvider, ProviderContext
class ListTargetProvider(TargetProvider):
"""
列表目标提供者 - 直接使用内存中的列表
用于快速扫描、临时扫描等场景,只扫描用户指定的目标。
特点:
- 不查询数据库
- 不应用黑名单过滤(用户明确指定的目标)
- 不关联 target_id由调用方负责创建 Target
- 自动检测输入类型(URL/域名/IP/CIDR)
- 自动展开 CIDR
使用方式:
# 快速扫描:用户提供目标,自动识别类型
provider = ListTargetProvider(targets=[
"example.com", # 域名
"192.168.1.0/24", # CIDR自动展开
"https://api.example.com" # URL
])
for host in provider.iter_hosts():
scan(host)
"""
def __init__(
self,
targets: Optional[List[str]] = None,
context: Optional[ProviderContext] = None
):
"""
初始化列表目标提供者
Args:
targets: 目标列表,自动识别类型(URL/域名/IP/CIDR)
context: Provider 上下文
"""
from apps.common.validators import detect_input_type
ctx = context or ProviderContext()
super().__init__(ctx)
# 自动分类目标
self._hosts = []
self._urls = []
if targets:
for target in targets:
target = target.strip()
if not target:
continue
try:
input_type = detect_input_type(target)
if input_type == 'url':
self._urls.append(target)
else:
# domain/ip/cidr 都作为 host
self._hosts.append(target)
except ValueError:
# 无法识别类型,默认作为 host
self._hosts.append(target)
def _iter_raw_hosts(self) -> Iterator[str]:
"""迭代原始主机列表(可能包含 CIDR"""
yield from self._hosts
def iter_urls(self) -> Iterator[str]:
"""迭代 URL 列表"""
yield from self._urls
def get_blacklist_filter(self) -> None:
"""列表模式不使用黑名单过滤"""
return None

View File

@@ -1,91 +0,0 @@
"""
管道目标提供者模块
提供基于管道阶段输出的目标提供者实现。
用于 Phase 2 管道模式的阶段间数据传递。
"""
from dataclasses import dataclass, field
from typing import Iterator, Optional, List, Dict, Any
from .base import TargetProvider, ProviderContext
@dataclass
class StageOutput:
"""
阶段输出数据
用于在管道阶段之间传递数据。
Attributes:
hosts: 主机列表(域名/IP
urls: URL 列表
new_targets: 新发现的目标列表
stats: 统计信息
success: 是否成功
error: 错误信息
"""
hosts: List[str] = field(default_factory=list)
urls: List[str] = field(default_factory=list)
new_targets: List[str] = field(default_factory=list)
stats: Dict[str, Any] = field(default_factory=dict)
success: bool = True
error: Optional[str] = None
class PipelineTargetProvider(TargetProvider):
"""
管道目标提供者 - 使用上一阶段的输出
用于 Phase 2 管道模式的阶段间数据传递。
特点:
- 不查询数据库
- 不应用黑名单过滤(数据已在上一阶段过滤)
- 直接使用 StageOutput 中的数据
使用方式Phase 2
stage1_output = stage1.run(input)
provider = PipelineTargetProvider(
previous_output=stage1_output,
target_id=123
)
for host in provider.iter_hosts():
stage2.scan(host)
"""
def __init__(
self,
previous_output: StageOutput,
target_id: Optional[int] = None,
context: Optional[ProviderContext] = None
):
"""
初始化管道目标提供者
Args:
previous_output: 上一阶段的输出
target_id: 可选,关联到某个 Target用于保存结果
context: Provider 上下文
"""
ctx = context or ProviderContext(target_id=target_id)
super().__init__(ctx)
self._previous_output = previous_output
def _iter_raw_hosts(self) -> Iterator[str]:
"""迭代上一阶段输出的原始主机(可能包含 CIDR"""
yield from self._previous_output.hosts
def iter_urls(self) -> Iterator[str]:
"""迭代上一阶段输出的 URL"""
yield from self._previous_output.urls
def get_blacklist_filter(self) -> None:
"""管道传递的数据已经过滤过了"""
return None
@property
def previous_output(self) -> StageOutput:
"""返回上一阶段的输出"""
return self._previous_output

View File

@@ -6,170 +6,106 @@
"""
import logging
from typing import Iterator, Optional, Literal
from typing import Iterator, Optional
from .base import TargetProvider, ProviderContext
from .base import ProviderContext, TargetProvider
logger = logging.getLogger(__name__)
# 快照类型定义
SnapshotType = Literal["subdomain", "website", "endpoint", "host_port"]
class SnapshotTargetProvider(TargetProvider):
"""
快照目标提供者 - 从快照表读取本次扫描的数据
用于快速扫描的阶段间数据传递,解决精确扫描控制问题。
核心价值:
- 只返回本次扫描(scan_id)发现的资产
- 避免扫描历史数据(DatabaseTargetProvider 会扫描所有历史资产)
特点:
- 通过 scan_id 过滤快照表
- 不应用黑名单过滤(数据已在上一阶段过滤)
- 支持多种快照类型(subdomain/website/endpoint/host_port)
- 每个 iter_* 方法只查对应的快照表(单一职责)
- 回退逻辑由调用方(Task/Flow)决定
使用场景:
# 快速扫描流程
用户输入: a.test.com
创建 Target: test.com (id=1)
创建 Scan: scan_id=100
# 阶段1: 子域名发现
provider = ListTargetProvider(
targets=["a.test.com"],
context=ProviderContext(target_id=1, scan_id=100)
)
# 发现: b.a.test.com, c.a.test.com
# 保存: SubdomainSnapshot(scan_id=100) + Subdomain(target_id=1)
# 阶段2: 端口扫描
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain",
context=ProviderContext(target_id=1, scan_id=100)
)
# 只返回: b.a.test.com, c.a.test.com(本次扫描发现的)
# 不返回: www.test.com, api.test.com(历史数据)
# 阶段3: 网站扫描
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="host_port",
context=ProviderContext(target_id=1, scan_id=100)
)
# 只返回本次扫描发现的 IP:Port
provider = SnapshotTargetProvider(scan_id=100)
# 单一数据源
for url in provider.iter_websites():
take_screenshot(url)
# 调用方控制回退
urls = list(provider.iter_endpoints())
if not urls:
urls = list(provider.iter_websites())
if not urls:
urls = list(provider.iter_default_urls())
"""
def __init__(
self,
scan_id: int,
snapshot_type: SnapshotType,
context: Optional[ProviderContext] = None
):
"""
初始化快照目标提供者
Args:
scan_id: 扫描任务 ID必需
snapshot_type: 快照类型
- "subdomain": 子域名快照SubdomainSnapshot
- "website": 网站快照WebsiteSnapshot
- "endpoint": 端点快照EndpointSnapshot
- "host_port": 主机端口映射快照HostPortMappingSnapshot
context: Provider 上下文
"""
ctx = context or ProviderContext()
ctx.scan_id = scan_id
super().__init__(ctx)
self._scan_id = scan_id
self._snapshot_type = snapshot_type
def _iter_raw_hosts(self) -> Iterator[str]:
"""
从快照表迭代主机列表
根据 snapshot_type 选择不同的快照表:
- subdomain: SubdomainSnapshot.name
- host_port: HostPortMappingSnapshot.host (返回 host:port 格式,不经过验证)
"""
if self._snapshot_type == "subdomain":
from apps.asset.services.snapshot import SubdomainSnapshotsService
service = SubdomainSnapshotsService()
yield from service.iter_subdomain_names_by_scan(
scan_id=self._scan_id,
chunk_size=1000
)
elif self._snapshot_type == "host_port":
# host_port 类型不使用 _iter_raw_hosts,直接在 iter_hosts 中处理
# 这里返回空,避免被基类的 iter_hosts 调用
return
else:
# 其他类型暂不支持 iter_hosts
logger.warning(
"快照类型 '%s' 不支持 iter_hosts返回空迭代器",
self._snapshot_type
)
return
def iter_hosts(self) -> Iterator[str]:
"""
迭代主机列表
对于 host_port 类型,返回 host:port 格式,不经过 CIDR 展开验证
"""
if self._snapshot_type == "host_port":
# host_port 类型直接返回 host:port,不经过 _expand_host 验证
from apps.asset.services.snapshot import HostPortMappingSnapshotsService
service = HostPortMappingSnapshotsService()
queryset = service.get_by_scan(scan_id=self._scan_id)
for mapping in queryset.iterator(chunk_size=1000):
yield f"{mapping.host}:{mapping.port}"
else:
# 其他类型使用基类的 iter_hosts会调用 _iter_raw_hosts 并展开 CIDR
yield from super().iter_hosts()
def iter_urls(self) -> Iterator[str]:
"""
从快照表迭代 URL 列表
根据 snapshot_type 选择不同的快照表:
- website: WebsiteSnapshot.url
- endpoint: EndpointSnapshot.url
"""
if self._snapshot_type == "website":
from apps.asset.services.snapshot import WebsiteSnapshotsService
service = WebsiteSnapshotsService()
yield from service.iter_website_urls_by_scan(
scan_id=self._scan_id,
chunk_size=1000
)
elif self._snapshot_type == "endpoint":
from apps.asset.services.snapshot import EndpointSnapshotsService
service = EndpointSnapshotsService()
# 从快照表获取端点 URL
queryset = service.get_by_scan(scan_id=self._scan_id)
for endpoint in queryset.iterator(chunk_size=1000):
yield endpoint.url
else:
# 其他类型暂不支持 iter_urls
logger.warning(
"快照类型 '%s' 不支持 iter_urls返回空迭代器",
self._snapshot_type
)
return
def iter_subdomains(self) -> Iterator[str]:
"""从 SubdomainSnapshot 迭代子域名列表"""
from apps.asset.services.snapshot import SubdomainSnapshotsService
service = SubdomainSnapshotsService()
yield from service.iter_subdomain_names_by_scan(
scan_id=self._scan_id,
chunk_size=1000
)
def iter_host_port_urls(self) -> Iterator[str]:
"""从 HostPortMappingSnapshot 生成待探测的 URL"""
from apps.asset.services.snapshot import HostPortMappingSnapshotsService
service = HostPortMappingSnapshotsService()
for mapping in service.iter_unique_host_ports_by_scan(
scan_id=self._scan_id,
batch_size=1000
):
host = mapping['host']
port = mapping['port']
if port == 80:
yield f"http://{host}"
elif port == 443:
yield f"https://{host}"
else:
yield f"http://{host}:{port}"
yield f"https://{host}:{port}"
def iter_websites(self) -> Iterator[str]:
"""从 WebsiteSnapshot 迭代网站 URL"""
from apps.asset.services.snapshot import WebsiteSnapshotsService
service = WebsiteSnapshotsService()
yield from service.iter_website_urls_by_scan(
scan_id=self._scan_id,
chunk_size=1000
)
def iter_endpoints(self) -> Iterator[str]:
"""从 EndpointSnapshot 迭代端点 URL"""
from apps.asset.services.snapshot import EndpointSnapshotsService
service = EndpointSnapshotsService()
queryset = service.get_by_scan(scan_id=self._scan_id)
for endpoint in queryset.iterator(chunk_size=1000):
yield endpoint.url
def get_blacklist_filter(self) -> None:
"""快照数据已在上一阶段过滤过了"""
return None
@property
def snapshot_type(self) -> SnapshotType:
"""返回快照类型"""
return self._snapshot_type

View File

@@ -1,256 +0,0 @@
"""
通用属性测试
包含跨多个 Provider 的通用属性测试:
- Property 4: Context Propagation
- Property 5: Non-Database Provider Blacklist Filter
- Property 7: CIDR Expansion Consistency
"""
import pytest
from hypothesis import given, strategies as st, settings
from ipaddress import IPv4Network
from apps.scan.providers import (
ProviderContext,
ListTargetProvider,
DatabaseTargetProvider,
PipelineTargetProvider,
SnapshotTargetProvider
)
from apps.scan.providers.pipeline_provider import StageOutput
class TestContextPropagation:
"""
Property 4: Context Propagation
*For any* ProviderContext,传入 Provider 构造函数后,
Provider 的 target_id 和 scan_id 属性应该与 context 中的值一致。
**Validates: Requirements 1.3, 1.5, 7.4, 7.5**
"""
@given(
target_id=st.integers(min_value=1, max_value=10000),
scan_id=st.integers(min_value=1, max_value=10000)
)
@settings(max_examples=100)
def test_property_4_list_provider_context_propagation(self, target_id, scan_id):
"""
Property 4: Context Propagation (ListTargetProvider)
Feature: scan-target-provider, Property 4: Context Propagation
**Validates: Requirements 1.3, 1.5, 7.4, 7.5**
"""
ctx = ProviderContext(target_id=target_id, scan_id=scan_id)
provider = ListTargetProvider(targets=["example.com"], context=ctx)
assert provider.target_id == target_id
assert provider.scan_id == scan_id
assert provider.context.target_id == target_id
assert provider.context.scan_id == scan_id
@given(
target_id=st.integers(min_value=1, max_value=10000),
scan_id=st.integers(min_value=1, max_value=10000)
)
@settings(max_examples=100)
def test_property_4_database_provider_context_propagation(self, target_id, scan_id):
"""
Property 4: Context Propagation (DatabaseTargetProvider)
Feature: scan-target-provider, Property 4: Context Propagation
**Validates: Requirements 1.3, 1.5, 7.4, 7.5**
"""
ctx = ProviderContext(target_id=999, scan_id=scan_id)
# DatabaseTargetProvider 会覆盖 context 中的 target_id
provider = DatabaseTargetProvider(target_id=target_id, context=ctx)
assert provider.target_id == target_id # 使用构造函数参数
assert provider.scan_id == scan_id # 使用 context 中的值
assert provider.context.target_id == target_id
assert provider.context.scan_id == scan_id
@given(
target_id=st.integers(min_value=1, max_value=10000),
scan_id=st.integers(min_value=1, max_value=10000)
)
@settings(max_examples=100)
def test_property_4_pipeline_provider_context_propagation(self, target_id, scan_id):
"""
Property 4: Context Propagation (PipelineTargetProvider)
Feature: scan-target-provider, Property 4: Context Propagation
**Validates: Requirements 1.3, 1.5, 7.4, 7.5**
"""
ctx = ProviderContext(target_id=target_id, scan_id=scan_id)
stage_output = StageOutput(hosts=["example.com"])
provider = PipelineTargetProvider(previous_output=stage_output, context=ctx)
assert provider.target_id == target_id
assert provider.scan_id == scan_id
assert provider.context.target_id == target_id
assert provider.context.scan_id == scan_id
@given(
target_id=st.integers(min_value=1, max_value=10000),
scan_id=st.integers(min_value=1, max_value=10000)
)
@settings(max_examples=100)
def test_property_4_snapshot_provider_context_propagation(self, target_id, scan_id):
"""
Property 4: Context Propagation (SnapshotTargetProvider)
Feature: scan-target-provider, Property 4: Context Propagation
**Validates: Requirements 1.3, 1.5, 7.4, 7.5**
"""
ctx = ProviderContext(target_id=target_id, scan_id=999)
# SnapshotTargetProvider 会覆盖 context 中的 scan_id
provider = SnapshotTargetProvider(
scan_id=scan_id,
snapshot_type="subdomain",
context=ctx
)
assert provider.target_id == target_id # 使用 context 中的值
assert provider.scan_id == scan_id # 使用构造函数参数
assert provider.context.target_id == target_id
assert provider.context.scan_id == scan_id
class TestNonDatabaseProviderBlacklistFilter:
"""
Property 5: Non-Database Provider Blacklist Filter
*For any* ListTargetProvider 或 PipelineTargetProvider 实例,
get_blacklist_filter() 方法应该返回 None。
**Validates: Requirements 3.4, 9.4, 9.5**
"""
@given(targets=st.lists(st.text(min_size=1, max_size=20), max_size=10))
@settings(max_examples=100)
def test_property_5_list_provider_no_blacklist(self, targets):
"""
Property 5: Non-Database Provider Blacklist Filter (ListTargetProvider)
Feature: scan-target-provider, Property 5: Non-Database Provider Blacklist Filter
**Validates: Requirements 3.4, 9.4, 9.5**
"""
provider = ListTargetProvider(targets=targets)
assert provider.get_blacklist_filter() is None
@given(hosts=st.lists(st.text(min_size=1, max_size=20), max_size=10))
@settings(max_examples=100)
def test_property_5_pipeline_provider_no_blacklist(self, hosts):
"""
Property 5: Non-Database Provider Blacklist Filter (PipelineTargetProvider)
Feature: scan-target-provider, Property 5: Non-Database Provider Blacklist Filter
**Validates: Requirements 3.4, 9.4, 9.5**
"""
stage_output = StageOutput(hosts=hosts)
provider = PipelineTargetProvider(previous_output=stage_output)
assert provider.get_blacklist_filter() is None
def test_property_5_snapshot_provider_no_blacklist(self):
"""
Property 5: Non-Database Provider Blacklist Filter (SnapshotTargetProvider)
Feature: scan-target-provider, Property 5: Non-Database Provider Blacklist Filter
**Validates: Requirements 3.4, 9.4, 9.5**
"""
provider = SnapshotTargetProvider(scan_id=1, snapshot_type="subdomain")
assert provider.get_blacklist_filter() is None
class TestCIDRExpansionConsistency:
"""
Property 7: CIDR Expansion Consistency
*For any* CIDR 字符串(如 "192.168.1.0/24"),所有 Provider 的 iter_hosts()
方法应该将其展开为相同的单个 IP 地址列表。
**Validates: Requirements 1.1, 3.6**
"""
@given(
# 生成小的 CIDR 范围以避免测试超时
network_prefix=st.integers(min_value=1, max_value=254),
cidr_suffix=st.integers(min_value=28, max_value=30) # /28 = 16 IPs, /30 = 4 IPs
)
@settings(max_examples=50, deadline=None)
def test_property_7_cidr_expansion_consistency(self, network_prefix, cidr_suffix):
"""
Property 7: CIDR Expansion Consistency
Feature: scan-target-provider, Property 7: CIDR Expansion Consistency
**Validates: Requirements 1.1, 3.6**
For any CIDR string, all Providers should expand it to the same IP list.
"""
cidr = f"192.168.{network_prefix}.0/{cidr_suffix}"
# 计算预期的 IP 列表
network = IPv4Network(cidr, strict=False)
# 排除网络地址和广播地址
expected_ips = [str(ip) for ip in network.hosts()]
# 如果 CIDR 太小(/31 或 /32),使用所有地址
if not expected_ips:
expected_ips = [str(ip) for ip in network]
# ListTargetProvider
list_provider = ListTargetProvider(targets=[cidr])
list_result = list(list_provider.iter_hosts())
# PipelineTargetProvider
stage_output = StageOutput(hosts=[cidr])
pipeline_provider = PipelineTargetProvider(previous_output=stage_output)
pipeline_result = list(pipeline_provider.iter_hosts())
# 验证:所有 Provider 展开的结果应该一致
assert list_result == expected_ips, f"ListProvider CIDR expansion mismatch for {cidr}"
assert pipeline_result == expected_ips, f"PipelineProvider CIDR expansion mismatch for {cidr}"
assert list_result == pipeline_result, f"Providers produce different results for {cidr}"
def test_cidr_expansion_with_multiple_cidrs(self):
"""测试多个 CIDR 的展开一致性"""
cidrs = ["192.168.1.0/30", "10.0.0.0/30"]
# 计算预期结果
expected_ips = []
for cidr in cidrs:
network = IPv4Network(cidr, strict=False)
expected_ips.extend([str(ip) for ip in network.hosts()])
# ListTargetProvider
list_provider = ListTargetProvider(targets=cidrs)
list_result = list(list_provider.iter_hosts())
# PipelineTargetProvider
stage_output = StageOutput(hosts=cidrs)
pipeline_provider = PipelineTargetProvider(previous_output=stage_output)
pipeline_result = list(pipeline_provider.iter_hosts())
# 验证
assert list_result == expected_ips
assert pipeline_result == expected_ips
assert list_result == pipeline_result
def test_mixed_hosts_and_cidrs(self):
"""测试混合主机和 CIDR 的处理"""
targets = ["example.com", "192.168.1.0/30", "test.com"]
# 计算预期结果
network = IPv4Network("192.168.1.0/30", strict=False)
cidr_ips = [str(ip) for ip in network.hosts()]
expected = ["example.com"] + cidr_ips + ["test.com"]
# ListTargetProvider
list_provider = ListTargetProvider(targets=targets)
list_result = list(list_provider.iter_hosts())
# 验证
assert list_result == expected

View File

@@ -2,7 +2,7 @@
DatabaseTargetProvider 属性测试
Property 7: DatabaseTargetProvider Blacklist Application
*For any* 带有黑名单规则的 target_id,DatabaseTargetProvider 的 iter_hosts() 和 iter_urls()
*For any* 带有黑名单规则的 target_id,DatabaseTargetProvider 的 iter_subdomains()
应该过滤掉匹配黑名单规则的目标。
**Validates: Requirements 2.3, 10.1, 10.2, 10.3**
@@ -48,7 +48,7 @@ class TestDatabaseTargetProviderProperties:
"""DatabaseTargetProvider 属性测试类"""
@given(
hosts=st.lists(valid_domain_strategy(), min_size=1, max_size=20),
subdomains=st.lists(valid_domain_strategy(), min_size=1, max_size=20),
blocked_keyword=st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=2,
@@ -56,15 +56,15 @@ class TestDatabaseTargetProviderProperties:
)
)
@settings(max_examples=100)
def test_property_7_blacklist_filters_hosts(self, hosts, blocked_keyword):
def test_property_7_blacklist_filters_subdomains(self, subdomains, blocked_keyword):
"""
Property 7: DatabaseTargetProvider Blacklist Application (hosts)
Property 7: DatabaseTargetProvider Blacklist Application (subdomains)
Feature: scan-target-provider, Property 7: DatabaseTargetProvider Blacklist Application
**Validates: Requirements 2.3, 10.1, 10.2, 10.3**
For any set of hosts and a blacklist keyword, the provider should filter out
all hosts containing the blocked keyword.
For any set of subdomains and a blacklist keyword, the provider should filter out
all subdomains containing the blocked keyword.
"""
# 创建模拟的黑名单过滤器
mock_filter = MockBlacklistFilter([blocked_keyword])
@@ -73,31 +73,18 @@ class TestDatabaseTargetProviderProperties:
provider = DatabaseTargetProvider(target_id=1)
provider._blacklist_filter = mock_filter
# 模拟 Target 和 SubdomainService
mock_target = MagicMock()
mock_target.type = 'domain'
mock_target.name = hosts[0] if hosts else 'example.com'
with patch('apps.targets.services.TargetService') as mock_target_service, \
patch('apps.asset.services.asset.subdomain_service.SubdomainService') as mock_subdomain_service:
mock_target_service.return_value.get_target.return_value = mock_target
mock_subdomain_service.return_value.iter_subdomain_names_by_target.return_value = iter(hosts[1:] if len(hosts) > 1 else [])
with patch('apps.asset.services.asset.subdomain_service.SubdomainService') as mock_subdomain_service:
mock_subdomain_service.return_value.iter_subdomain_names_by_target.return_value = iter(subdomains)
# 获取结果
result = list(provider.iter_hosts())
result = list(provider.iter_subdomains())
# 验证:所有结果都不包含被阻止的关键词
for host in result:
assert blocked_keyword not in host, f"Host '{host}' should be filtered by blacklist keyword '{blocked_keyword}'"
# 验证:所有不包含关键词的主机都应该在结果中
if hosts:
all_hosts = [hosts[0]] + [h for h in hosts[1:] if h != hosts[0]]
expected_allowed = [h for h in all_hosts if blocked_keyword not in h]
else:
expected_allowed = []
for subdomain in result:
assert blocked_keyword not in subdomain, f"Subdomain '{subdomain}' should be filtered by blacklist keyword '{blocked_keyword}'"
# 验证:所有不包含关键词的子域名都应该在结果中
expected_allowed = [s for s in subdomains if blocked_keyword not in s]
assert set(result) == set(expected_allowed)
@@ -144,15 +131,38 @@ class TestDatabaseTargetProviderUnit:
# BlacklistService 只应该被调用一次
mock_service.return_value.get_rules.assert_called_once_with(123)
def test_nonexistent_target_returns_empty(self):
"""测试不存在的 target 返回空迭代器"""
def test_get_target_name(self):
"""测试 get_target_name 返回 Target 名称"""
provider = DatabaseTargetProvider(target_id=123)
mock_target = MagicMock()
mock_target.name = 'example.com'
with patch('apps.targets.services.TargetService') as mock_service:
mock_service.return_value.get_target.return_value = mock_target
result = provider.get_target_name()
assert result == 'example.com'
def test_get_target_name_nonexistent(self):
"""测试不存在的 target 返回 None"""
provider = DatabaseTargetProvider(target_id=99999)
with patch('apps.targets.services.TargetService') as mock_service, \
with patch('apps.targets.services.TargetService') as mock_service:
mock_service.return_value.get_target.return_value = None
result = provider.get_target_name()
assert result is None
def test_iter_subdomains_empty(self):
"""测试空子域名列表"""
provider = DatabaseTargetProvider(target_id=123)
with patch('apps.asset.services.asset.subdomain_service.SubdomainService') as mock_service, \
patch('apps.common.services.BlacklistService') as mock_blacklist_service:
mock_service.return_value.get_target.return_value = None
mock_service.return_value.iter_subdomain_names_by_target.return_value = iter([])
mock_blacklist_service.return_value.get_rules.return_value = []
result = list(provider.iter_hosts())
result = list(provider.iter_subdomains())
assert result == []

View File

@@ -1,152 +0,0 @@
"""
ListTargetProvider 属性测试
Property 1: ListTargetProvider Round-Trip
*For any* 主机列表和 URL 列表,创建 ListTargetProvider 后迭代 iter_hosts() 和 iter_urls()
应该返回与输入相同的元素(顺序相同)。
**Validates: Requirements 3.1, 3.2**
"""
import pytest
from hypothesis import given, strategies as st, settings, assume
from apps.scan.providers.list_provider import ListTargetProvider
from apps.scan.providers.base import ProviderContext
# 生成有效域名的策略
def valid_domain_strategy():
"""生成有效的域名"""
# 生成简单的域名格式: subdomain.domain.tld
label = st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=2,
max_size=10
)
return st.builds(
lambda a, b, c: f"{a}.{b}.{c}",
label, label, st.sampled_from(['com', 'net', 'org', 'io'])
)
# 生成有效 IP 地址的策略
def valid_ip_strategy():
"""生成有效的 IPv4 地址"""
octet = st.integers(min_value=1, max_value=254)
return st.builds(
lambda a, b, c, d: f"{a}.{b}.{c}.{d}",
octet, octet, octet, octet
)
# 组合策略:域名或 IP
host_strategy = st.one_of(valid_domain_strategy(), valid_ip_strategy())
# 生成有效 URL 的策略
def valid_url_strategy():
"""生成有效的 URL"""
domain = valid_domain_strategy()
return st.builds(
lambda d, path: f"https://{d}/{path}" if path else f"https://{d}",
domain,
st.one_of(
st.just(""),
st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=1,
max_size=10
)
)
)
url_strategy = valid_url_strategy()
class TestListTargetProviderProperties:
"""ListTargetProvider 属性测试类"""
@given(hosts=st.lists(host_strategy, max_size=50))
@settings(max_examples=100)
def test_property_1_hosts_round_trip(self, hosts):
"""
Property 1: ListTargetProvider Round-Trip (hosts)
Feature: scan-target-provider, Property 1: ListTargetProvider Round-Trip
**Validates: Requirements 3.1, 3.2**
For any host list, creating a ListTargetProvider and iterating iter_hosts()
should return the same elements in the same order.
"""
# ListTargetProvider 使用 targets 参数,自动分类为 hosts/urls
provider = ListTargetProvider(targets=hosts)
result = list(provider.iter_hosts())
assert result == hosts
@given(urls=st.lists(url_strategy, max_size=50))
@settings(max_examples=100)
def test_property_1_urls_round_trip(self, urls):
"""
Property 1: ListTargetProvider Round-Trip (urls)
Feature: scan-target-provider, Property 1: ListTargetProvider Round-Trip
**Validates: Requirements 3.1, 3.2**
For any URL list, creating a ListTargetProvider and iterating iter_urls()
should return the same elements in the same order.
"""
# ListTargetProvider 使用 targets 参数,自动分类为 hosts/urls
provider = ListTargetProvider(targets=urls)
result = list(provider.iter_urls())
assert result == urls
@given(
hosts=st.lists(host_strategy, max_size=30),
urls=st.lists(url_strategy, max_size=30)
)
@settings(max_examples=100)
def test_property_1_combined_round_trip(self, hosts, urls):
"""
Property 1: ListTargetProvider Round-Trip (combined)
Feature: scan-target-provider, Property 1: ListTargetProvider Round-Trip
**Validates: Requirements 3.1, 3.2**
For any combination of hosts and URLs, both should round-trip correctly.
"""
# 合并 hosts 和 urlsListTargetProvider 会自动分类
combined = hosts + urls
provider = ListTargetProvider(targets=combined)
hosts_result = list(provider.iter_hosts())
urls_result = list(provider.iter_urls())
assert hosts_result == hosts
assert urls_result == urls
class TestListTargetProviderUnit:
"""ListTargetProvider 单元测试类"""
def test_empty_lists(self):
"""测试空列表返回空迭代器 - Requirements 3.5"""
provider = ListTargetProvider()
assert list(provider.iter_hosts()) == []
assert list(provider.iter_urls()) == []
def test_blacklist_filter_returns_none(self):
"""测试黑名单过滤器返回 None - Requirements 3.4"""
provider = ListTargetProvider(targets=["example.com"])
assert provider.get_blacklist_filter() is None
def test_target_id_association(self):
"""测试 target_id 关联 - Requirements 3.3"""
ctx = ProviderContext(target_id=123)
provider = ListTargetProvider(targets=["example.com"], context=ctx)
assert provider.target_id == 123
def test_context_propagation(self):
"""测试上下文传递"""
ctx = ProviderContext(target_id=456, scan_id=789)
provider = ListTargetProvider(targets=["example.com"], context=ctx)
assert provider.target_id == 456
assert provider.scan_id == 789

View File

@@ -1,180 +0,0 @@
"""
PipelineTargetProvider 属性测试
Property 3: PipelineTargetProvider Round-Trip
*For any* StageOutput 对象,PipelineTargetProvider 的 iter_hosts() 和 iter_urls()
应该返回与 StageOutput 中 hosts 和 urls 列表相同的元素。
**Validates: Requirements 5.1, 5.2**
"""
import pytest
from hypothesis import given, strategies as st, settings
from apps.scan.providers.pipeline_provider import PipelineTargetProvider, StageOutput
from apps.scan.providers.base import ProviderContext
# 生成有效域名的策略
def valid_domain_strategy():
"""生成有效的域名"""
label = st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=2,
max_size=10
)
return st.builds(
lambda a, b, c: f"{a}.{b}.{c}",
label, label, st.sampled_from(['com', 'net', 'org', 'io'])
)
# 生成有效 IP 地址的策略
def valid_ip_strategy():
"""生成有效的 IPv4 地址"""
octet = st.integers(min_value=1, max_value=254)
return st.builds(
lambda a, b, c, d: f"{a}.{b}.{c}.{d}",
octet, octet, octet, octet
)
# 组合策略:域名或 IP
host_strategy = st.one_of(valid_domain_strategy(), valid_ip_strategy())
# 生成有效 URL 的策略
def valid_url_strategy():
"""生成有效的 URL"""
domain = valid_domain_strategy()
return st.builds(
lambda d, path: f"https://{d}/{path}" if path else f"https://{d}",
domain,
st.one_of(
st.just(""),
st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=1,
max_size=10
)
)
)
url_strategy = valid_url_strategy()
class TestPipelineTargetProviderProperties:
"""PipelineTargetProvider 属性测试类"""
@given(hosts=st.lists(host_strategy, max_size=50))
@settings(max_examples=100)
def test_property_3_hosts_round_trip(self, hosts):
"""
Property 3: PipelineTargetProvider Round-Trip (hosts)
Feature: scan-target-provider, Property 3: PipelineTargetProvider Round-Trip
**Validates: Requirements 5.1, 5.2**
For any StageOutput with hosts, PipelineTargetProvider should return
the same hosts in the same order.
"""
stage_output = StageOutput(hosts=hosts)
provider = PipelineTargetProvider(previous_output=stage_output)
result = list(provider.iter_hosts())
assert result == hosts
@given(urls=st.lists(url_strategy, max_size=50))
@settings(max_examples=100)
def test_property_3_urls_round_trip(self, urls):
"""
Property 3: PipelineTargetProvider Round-Trip (urls)
Feature: scan-target-provider, Property 3: PipelineTargetProvider Round-Trip
**Validates: Requirements 5.1, 5.2**
For any StageOutput with urls, PipelineTargetProvider should return
the same urls in the same order.
"""
stage_output = StageOutput(urls=urls)
provider = PipelineTargetProvider(previous_output=stage_output)
result = list(provider.iter_urls())
assert result == urls
@given(
hosts=st.lists(host_strategy, max_size=30),
urls=st.lists(url_strategy, max_size=30)
)
@settings(max_examples=100)
def test_property_3_combined_round_trip(self, hosts, urls):
"""
Property 3: PipelineTargetProvider Round-Trip (combined)
Feature: scan-target-provider, Property 3: PipelineTargetProvider Round-Trip
**Validates: Requirements 5.1, 5.2**
For any StageOutput with both hosts and urls, both should round-trip correctly.
"""
stage_output = StageOutput(hosts=hosts, urls=urls)
provider = PipelineTargetProvider(previous_output=stage_output)
hosts_result = list(provider.iter_hosts())
urls_result = list(provider.iter_urls())
assert hosts_result == hosts
assert urls_result == urls
class TestPipelineTargetProviderUnit:
"""PipelineTargetProvider 单元测试类"""
def test_empty_stage_output(self):
"""测试空 StageOutput 返回空迭代器 - Requirements 5.5"""
stage_output = StageOutput()
provider = PipelineTargetProvider(previous_output=stage_output)
assert list(provider.iter_hosts()) == []
assert list(provider.iter_urls()) == []
def test_blacklist_filter_returns_none(self):
"""测试黑名单过滤器返回 None - Requirements 5.3"""
stage_output = StageOutput(hosts=["example.com"])
provider = PipelineTargetProvider(previous_output=stage_output)
assert provider.get_blacklist_filter() is None
def test_target_id_association(self):
"""测试 target_id 关联 - Requirements 5.4"""
stage_output = StageOutput(hosts=["example.com"])
provider = PipelineTargetProvider(previous_output=stage_output, target_id=123)
assert provider.target_id == 123
def test_context_propagation(self):
"""测试上下文传递"""
ctx = ProviderContext(target_id=456, scan_id=789)
stage_output = StageOutput(hosts=["example.com"])
provider = PipelineTargetProvider(previous_output=stage_output, context=ctx)
assert provider.target_id == 456
assert provider.scan_id == 789
def test_previous_output_property(self):
"""测试 previous_output 属性"""
stage_output = StageOutput(hosts=["example.com"], urls=["https://example.com"])
provider = PipelineTargetProvider(previous_output=stage_output)
assert provider.previous_output is stage_output
assert provider.previous_output.hosts == ["example.com"]
assert provider.previous_output.urls == ["https://example.com"]
def test_stage_output_with_metadata(self):
"""测试带元数据的 StageOutput"""
stage_output = StageOutput(
hosts=["example.com"],
urls=["https://example.com"],
new_targets=["new.example.com"],
stats={"count": 1},
success=True,
error=None
)
provider = PipelineTargetProvider(previous_output=stage_output)
assert list(provider.iter_hosts()) == ["example.com"]
assert list(provider.iter_urls()) == ["https://example.com"]
assert provider.previous_output.new_targets == ["new.example.com"]
assert provider.previous_output.stats == {"count": 1}
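# A minimal sketch of a provider satisfying the round-trip properties above —
# an assumption for illustration, not the project's actual implementation.
# It simply replays the hosts/urls captured in the previous stage's StageOutput.
class PipelineTargetProviderSketch:
    def __init__(self, previous_output, target_id=None, context=None):
        self.previous_output = previous_output
        # context, when given, wins over the explicit target_id (mirrors the tests)
        self.target_id = context.target_id if context else target_id
        self.scan_id = context.scan_id if context else None

    def iter_hosts(self):
        yield from (self.previous_output.hosts or [])

    def iter_urls(self):
        yield from (self.previous_output.urls or [])

    def get_blacklist_filter(self):
        return None  # pipeline mode: upstream stages already applied filtering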

View File

@@ -10,182 +10,112 @@ from apps.scan.providers import SnapshotTargetProvider, ProviderContext
class TestSnapshotTargetProvider:
"""SnapshotTargetProvider 测试类"""
def test_init_with_scan_id_and_type(self):
def test_init_with_scan_id(self):
"""测试初始化"""
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain"
)
provider = SnapshotTargetProvider(scan_id=100)
assert provider.scan_id == 100
assert provider.snapshot_type == "subdomain"
assert provider.target_id is None # 默认 context
assert provider.target_id is None
def test_init_with_context(self):
"""测试带 context 初始化"""
ctx = ProviderContext(target_id=1, scan_id=100)
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain",
context=ctx
)
provider = SnapshotTargetProvider(scan_id=100, context=ctx)
assert provider.scan_id == 100
assert provider.target_id == 1
assert provider.snapshot_type == "subdomain"
@patch('apps.asset.services.snapshot.SubdomainSnapshotsService')
def test_iter_hosts_subdomain(self, mock_service_class):
"""测试从子域名快照迭代主机"""
# Mock service
def test_iter_subdomains(self, mock_service_class):
"""测试从子域名快照迭代子域名"""
mock_service = Mock()
mock_service.iter_subdomain_names_by_scan.return_value = iter([
"a.example.com",
"b.example.com"
])
mock_service_class.return_value = mock_service
# 创建 provider
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain"
)
# 迭代主机
hosts = list(provider.iter_hosts())
assert hosts == ["a.example.com", "b.example.com"]
provider = SnapshotTargetProvider(scan_id=100)
subdomains = list(provider.iter_subdomains())
assert subdomains == ["a.example.com", "b.example.com"]
mock_service.iter_subdomain_names_by_scan.assert_called_once_with(
scan_id=100,
chunk_size=1000
)
@patch('apps.asset.services.snapshot.HostPortMappingSnapshotsService')
def test_iter_hosts_host_port(self, mock_service_class):
"""测试从主机端口映射快照迭代主机"""
# Mock queryset
mock_mapping1 = Mock()
mock_mapping1.host = "example.com"
mock_mapping1.port = 80
mock_mapping2 = Mock()
mock_mapping2.host = "example.com"
mock_mapping2.port = 443
mock_queryset = Mock()
mock_queryset.iterator.return_value = iter([mock_mapping1, mock_mapping2])
# Mock service
def test_iter_host_port_urls(self, mock_service_class):
"""测试从主机端口映射快照生成 URL"""
mock_service = Mock()
mock_service.get_by_scan.return_value = mock_queryset
mock_service.iter_unique_host_ports_by_scan.return_value = iter([
{'host': 'example.com', 'port': 80},
{'host': 'example.com', 'port': 443},
{'host': 'example.com', 'port': 8080},
])
mock_service_class.return_value = mock_service
# 创建 provider
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="host_port"
)
# 迭代主机
hosts = list(provider.iter_hosts())
assert hosts == ["example.com:80", "example.com:443"]
mock_service.get_by_scan.assert_called_once_with(scan_id=100)
provider = SnapshotTargetProvider(scan_id=100)
urls = list(provider.iter_host_port_urls())
assert urls == [
"http://example.com",
"https://example.com",
"http://example.com:8080",
"https://example.com:8080",
]
@patch('apps.asset.services.snapshot.WebsiteSnapshotsService')
def test_iter_urls_website(self, mock_service_class):
def test_iter_websites(self, mock_service_class):
"""测试从网站快照迭代 URL"""
# Mock service
mock_service = Mock()
mock_service.iter_website_urls_by_scan.return_value = iter([
"http://example.com",
"https://example.com"
])
mock_service_class.return_value = mock_service
# 创建 provider
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="website"
)
# 迭代 URL
urls = list(provider.iter_urls())
provider = SnapshotTargetProvider(scan_id=100)
urls = list(provider.iter_websites())
assert urls == ["http://example.com", "https://example.com"]
mock_service.iter_website_urls_by_scan.assert_called_once_with(
scan_id=100,
chunk_size=1000
)
@patch('apps.asset.services.snapshot.EndpointSnapshotsService')
def test_iter_urls_endpoint(self, mock_service_class):
def test_iter_endpoints(self, mock_service_class):
"""测试从端点快照迭代 URL"""
# Mock queryset
mock_endpoint1 = Mock()
mock_endpoint1.url = "http://example.com/api/v1"
mock_endpoint2 = Mock()
mock_endpoint2.url = "http://example.com/api/v2"
mock_queryset = Mock()
mock_queryset.iterator.return_value = iter([mock_endpoint1, mock_endpoint2])
# Mock service
mock_service = Mock()
mock_service.get_by_scan.return_value = mock_queryset
mock_service_class.return_value = mock_service
# 创建 provider
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="endpoint"
)
# 迭代 URL
urls = list(provider.iter_urls())
provider = SnapshotTargetProvider(scan_id=100)
urls = list(provider.iter_endpoints())
assert urls == ["http://example.com/api/v1", "http://example.com/api/v2"]
mock_service.get_by_scan.assert_called_once_with(scan_id=100)
def test_iter_hosts_unsupported_type(self):
"""测试不支持的快照类型iter_hosts"""
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="website" # website 不支持 iter_hosts
)
hosts = list(provider.iter_hosts())
assert hosts == []
def test_iter_urls_unsupported_type(self):
"""测试不支持的快照类型iter_urls"""
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain" # subdomain 不支持 iter_urls
)
urls = list(provider.iter_urls())
assert urls == []
def test_get_blacklist_filter(self):
"""测试黑名单过滤器(快照模式不使用黑名单)"""
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain"
)
provider = SnapshotTargetProvider(scan_id=100)
assert provider.get_blacklist_filter() is None
def test_context_propagation(self):
"""测试上下文传递"""
ctx = ProviderContext(target_id=456, scan_id=789)
provider = SnapshotTargetProvider(
scan_id=100, # 会被 context 覆盖
snapshot_type="subdomain",
context=ctx
)
provider = SnapshotTargetProvider(scan_id=100, context=ctx)
assert provider.target_id == 456
assert provider.scan_id == 100 # scan_id 在 __init__ 中被设置
assert provider.scan_id == 100
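# Sketch of the URL-generation rule the tests above assert, assuming the provider
# delegates to HostPortMappingSnapshotsService.iter_unique_host_ports_by_scan
# (the service patched in test_iter_host_port_urls); not the actual provider code.
def _sketch_iter_host_port_urls(scan_id):
    from apps.asset.services.snapshot import HostPortMappingSnapshotsService
    service = HostPortMappingSnapshotsService()
    for row in service.iter_unique_host_ports_by_scan(scan_id=scan_id):
        host, port = row['host'], row['port']
        if port == 80:
            yield f"http://{host}"          # 80  → HTTP only, port omitted
        elif port == 443:
            yield f"https://{host}"         # 443 → HTTPS only, port omitted
        else:
            yield f"http://{host}:{port}"   # other ports → both schemes, port kept
            yield f"https://{host}:{port}"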

View File

@@ -137,16 +137,14 @@ def main():
print("[2/4] 解析命令行参数...")
parser = argparse.ArgumentParser(description="执行扫描初始化 Flow")
parser.add_argument("--scan_id", type=int, required=True, help="扫描任务 ID")
parser.add_argument("--target_name", type=str, required=True, help="目标名称")
parser.add_argument("--target_id", type=int, required=True, help="目标 ID")
parser.add_argument("--scan_workspace_dir", type=str, required=True, help="扫描工作目录")
parser.add_argument("--engine_name", type=str, required=True, help="引擎名称")
parser.add_argument("--scheduled_scan_name", type=str, default=None, help="定时扫描任务名称(可选)")
args = parser.parse_args()
print(f"[2/4] ✓ 参数解析成功:")
print(f" scan_id: {args.scan_id}")
print(f" target_name: {args.target_name}")
print(f" target_id: {args.target_id}")
print(f" scan_workspace_dir: {args.scan_workspace_dir}")
print(f" engine_name: {args.engine_name}")
@@ -171,7 +169,6 @@ def main():
try:
result = initiate_scan_flow(
scan_id=args.scan_id,
target_name=args.target_name,
target_id=args.target_id,
scan_workspace_dir=args.scan_workspace_dir,
engine_name=args.engine_name,

View File

@@ -15,11 +15,11 @@ class ScanSerializer(serializers.ModelSerializer):
fields = [
'id', 'target', 'target_name', 'engine_ids', 'engine_names',
'created_at', 'stopped_at', 'status', 'results_dir',
'container_ids', 'error_message'
'container_ids', 'error_message', 'scan_mode'
]
read_only_fields = [
'id', 'created_at', 'stopped_at', 'results_dir',
'container_ids', 'error_message', 'status'
'container_ids', 'error_message', 'status', 'scan_mode'
]
def get_target_name(self, obj):
@@ -39,9 +39,10 @@ class ScanHistorySerializer(serializers.ModelSerializer):
class Meta:
model = Scan
fields = [
'id', 'target', 'target_name', 'engine_ids', 'engine_names',
'worker_name', 'created_at', 'status', 'error_message', 'summary',
'progress', 'current_stage', 'stage_progress', 'yaml_configuration',
'scan_mode'
]
def get_summary(self, obj):

View File

@@ -17,23 +17,15 @@ from .scan_state_service import ScanStateService
from .scan_control_service import ScanControlService
from .scan_stats_service import ScanStatsService
from .scheduled_scan_service import ScheduledScanService
from .target_export_service import (
TargetExportService,
create_export_service,
export_urls_with_fallback,
DataSource,
)
from .scan_input_target_service import ScanInputTargetService
__all__ = [
'ScanService', # 主入口(向后兼容)
'ScanService',
'ScanCreationService',
'ScanStateService',
'ScanControlService',
'ScanStatsService',
'ScheduledScanService',
'TargetExportService', # 目标导出服务
'create_export_service',
'export_urls_with_fallback',
'DataSource',
'ScanInputTargetService',
]

View File

@@ -5,13 +5,16 @@
"""
import logging
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional, Literal, List, Dict, Any
from urllib.parse import urlparse
from django.db import transaction
from apps.common.validators import validate_url, detect_input_type, validate_domain, validate_ip, validate_cidr, is_valid_ip
from apps.common.validators import (
validate_url, detect_input_type, validate_domain,
validate_ip, validate_cidr, is_valid_ip
)
from apps.targets.services.target_service import TargetService
from apps.targets.models import Target
from apps.asset.dtos import WebSiteDTO
@@ -24,98 +27,72 @@ logger = logging.getLogger(__name__)
@dataclass
class ParsedInputDTO:
"""
解析输入 DTO
只在快速扫描流程中使用
"""
"""解析输入 DTO只在快速扫描流程中使用"""
original_input: str
input_type: Literal['url', 'domain', 'ip', 'cidr']
target_name: str # host/domain/ip/cidr
target_name: str
target_type: Literal['domain', 'ip', 'cidr']
website_url: Optional[str] = None # 根 URLscheme://host[:port]
endpoint_url: Optional[str] = None # 完整 URL含路径
is_valid: bool = True
error: Optional[str] = None
website_url: Optional[str] = None
endpoint_url: Optional[str] = None
line_number: Optional[int] = None
# 验证状态放在嵌套结构中,减少顶层属性数量
validation: Dict[str, Any] = field(default_factory=lambda: {
'is_valid': True,
'error': None
})
@property
def is_valid(self) -> bool:
return self.validation.get('is_valid', True)
@property
def error(self) -> Optional[str]:
return self.validation.get('error')
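# The properties above keep the old attribute surface intact, e.g. (illustrative):
#
#   dto = ParsedInputDTO(original_input='x', input_type='domain',
#                        target_name='x', target_type='domain',
#                        validation={'is_valid': False, 'error': 'bad input'})
#   dto.is_valid  # False  (read from validation['is_valid'])
#   dto.error     # 'bad input'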
class QuickScanService:
"""快速扫描服务 - 解析输入并创建资产"""
def __init__(self):
self.target_service = TargetService()
self.website_repo = DjangoWebSiteRepository()
self.endpoint_repo = DjangoEndpointRepository()
def parse_inputs(self, inputs: List[str]) -> List[ParsedInputDTO]:
"""
解析多行输入
Args:
inputs: 输入字符串列表(每行一个)
Returns:
解析结果列表(跳过空行)
"""
"""解析多行输入,返回解析结果列表(跳过空行)"""
results = []
for line_number, input_str in enumerate(inputs, start=1):
input_str = input_str.strip()
# 空行跳过
if not input_str:
continue
try:
# 检测输入类型
input_type = detect_input_type(input_str)
if input_type == 'url':
dto = self._parse_url_input(input_str, line_number)
else:
dto = self._parse_target_input(input_str, input_type, line_number)
results.append(dto)
except ValueError as e:
# 解析失败,记录错误
results.append(ParsedInputDTO(
original_input=input_str,
input_type='domain', # 默认类型
input_type='domain',
target_name=input_str,
target_type='domain',
is_valid=False,
error=str(e),
line_number=line_number
line_number=line_number,
validation={'is_valid': False, 'error': str(e)}
))
return results
def _parse_url_input(self, url_str: str, line_number: int) -> ParsedInputDTO:
"""
解析 URL 输入
Args:
url_str: URL 字符串
line_number: 行号
Returns:
ParsedInputDTO
"""
# 验证 URL 格式
"""解析 URL 输入"""
validate_url(url_str)
# 使用标准库解析
parsed = urlparse(url_str)
host = parsed.hostname # 不含端口
host = parsed.hostname
has_path = parsed.path and parsed.path != '/'
# 构建 root_url: scheme://host[:port]
root_url = f"{parsed.scheme}://{parsed.netloc}"
# 检测 host 类型domain 或 ip
target_type = 'ip' if is_valid_ip(host) else 'domain'
return ParsedInputDTO(
original_input=url_str,
input_type='url',
@@ -125,167 +102,98 @@ class QuickScanService:
endpoint_url=url_str if has_path else None,
line_number=line_number
)
def _parse_target_input(
self,
input_str: str,
input_type: str,
line_number: int
) -> ParsedInputDTO:
"""
解析非 URL 输入domain/ip/cidr
Args:
input_str: 输入字符串
input_type: 输入类型
line_number: 行号
Returns:
ParsedInputDTO
"""
# 验证格式
if input_type == 'domain':
validate_domain(input_str)
target_type = 'domain'
elif input_type == 'ip':
validate_ip(input_str)
target_type = 'ip'
elif input_type == 'cidr':
validate_cidr(input_str)
target_type = 'cidr'
else:
"""解析非 URL 输入domain/ip/cidr"""
validators = {
'domain': (validate_domain, 'domain'),
'ip': (validate_ip, 'ip'),
'cidr': (validate_cidr, 'cidr'),
}
if input_type not in validators:
raise ValueError(f"未知的输入类型: {input_type}")
validator, target_type = validators[input_type]
validator(input_str)
return ParsedInputDTO(
original_input=input_str,
input_type=input_type,
target_name=input_str,
target_type=target_type,
website_url=None,
endpoint_url=None,
line_number=line_number
)
@transaction.atomic
def process_quick_scan(
self,
inputs: List[str],
engine_id: int
) -> Dict[str, Any]:
"""
处理快速扫描请求
Args:
inputs: 输入字符串列表
engine_id: 扫描引擎 ID
Returns:
处理结果字典
"""
# 1. 解析输入
def process_quick_scan(self, inputs: List[str]) -> Dict[str, Any]:
"""处理快速扫描请求"""
parsed_inputs = self.parse_inputs(inputs)
# 分离有效和无效输入
valid_inputs = [p for p in parsed_inputs if p.is_valid]
invalid_inputs = [p for p in parsed_inputs if not p.is_valid]
errors = [
{'line_number': p.line_number, 'input': p.original_input, 'error': p.error}
for p in invalid_inputs
]
if not valid_inputs:
return {
'targets': [],
'target_stats': {'created': 0, 'reused': 0, 'failed': len(invalid_inputs)},
'asset_stats': {'websites_created': 0, 'endpoints_created': 0},
'errors': [
{'line_number': p.line_number, 'input': p.original_input, 'error': p.error}
for p in invalid_inputs
]
'errors': errors
}
# 2. 创建资产
asset_result = self.create_assets_from_parsed_inputs(valid_inputs)
# 3. 返回结果
# 构建 target_name → inputs 映射
target_inputs_map: Dict[str, List[str]] = {}
for p in valid_inputs:
target_inputs_map.setdefault(p.target_name, []).append(p.original_input)
return {
'targets': asset_result['targets'],
'target_stats': asset_result['target_stats'],
'asset_stats': asset_result['asset_stats'],
'errors': [
{'line_number': p.line_number, 'input': p.original_input, 'error': p.error}
for p in invalid_inputs
]
'target_inputs_map': target_inputs_map,
'errors': errors
}
def create_assets_from_parsed_inputs(
self,
parsed_inputs: List[ParsedInputDTO]
) -> Dict[str, Any]:
"""
从解析结果创建资产
Args:
parsed_inputs: 解析结果列表(只包含有效输入)
Returns:
创建结果字典
"""
# 1. 收集所有 target 数据(内存操作,去重)
targets_data = {}
for dto in parsed_inputs:
if dto.target_name not in targets_data:
targets_data[dto.target_name] = {'name': dto.target_name, 'type': dto.target_type}
"""从解析结果创建资产(只包含有效输入)"""
# 1. 收集并去重 target 数据
targets_data = {
dto.target_name: {'name': dto.target_name, 'type': dto.target_type}
for dto in parsed_inputs
}
targets_list = list(targets_data.values())
# 2. 批量创建 Target(复用现有方法)
# 2. 批量创建 Target
target_result = self.target_service.batch_create_targets(targets_list)
# 3. 查询刚创建的 Target建立 name → id 映射
# 3. 建立 name → id 映射
target_names = [d['name'] for d in targets_list]
targets = Target.objects.filter(name__in=target_names)
target_id_map = {t.name: t.id for t in targets}
# 4. 收集 Website DTO内存操作去重
website_dtos = []
seen_websites = set()
for dto in parsed_inputs:
if dto.website_url and dto.website_url not in seen_websites:
seen_websites.add(dto.website_url)
target_id = target_id_map.get(dto.target_name)
if target_id:
website_dtos.append(WebSiteDTO(
target_id=target_id,
url=dto.website_url,
host=dto.target_name
))
# 5. 批量创建 Website存在即跳过
websites_created = 0
if website_dtos:
websites_created = self.website_repo.bulk_create_ignore_conflicts(website_dtos)
# 6. 收集 Endpoint DTO内存操作去重
endpoint_dtos = []
seen_endpoints = set()
for dto in parsed_inputs:
if dto.endpoint_url and dto.endpoint_url not in seen_endpoints:
seen_endpoints.add(dto.endpoint_url)
target_id = target_id_map.get(dto.target_name)
if target_id:
endpoint_dtos.append(EndpointDTO(
target_id=target_id,
url=dto.endpoint_url,
host=dto.target_name
))
# 7. 批量创建 Endpoint存在即跳过
endpoints_created = 0
if endpoint_dtos:
endpoints_created = self.endpoint_repo.bulk_create_ignore_conflicts(endpoint_dtos)
# 4. 批量创建 Website 和 Endpoint
websites_created = self._bulk_create_websites(parsed_inputs, target_id_map)
endpoints_created = self._bulk_create_endpoints(parsed_inputs, target_id_map)
return {
'targets': list(targets),
'target_stats': {
'created': target_result['created_count'],
'reused': 0, # bulk_create 无法区分新建和复用
'reused': 0,
'failed': target_result['failed_count']
},
'asset_stats': {
@@ -293,3 +201,53 @@ class QuickScanService:
'endpoints_created': endpoints_created
}
}
def _bulk_create_websites(
self,
parsed_inputs: List[ParsedInputDTO],
target_id_map: Dict[str, int]
) -> int:
"""批量创建 Website存在即跳过"""
website_dtos = []
seen = set()
for dto in parsed_inputs:
if not dto.website_url or dto.website_url in seen:
continue
seen.add(dto.website_url)
target_id = target_id_map.get(dto.target_name)
if target_id:
website_dtos.append(WebSiteDTO(
target_id=target_id,
url=dto.website_url,
host=dto.target_name
))
if not website_dtos:
return 0
return self.website_repo.bulk_create_ignore_conflicts(website_dtos)
def _bulk_create_endpoints(
self,
parsed_inputs: List[ParsedInputDTO],
target_id_map: Dict[str, int]
) -> int:
"""批量创建 Endpoint存在即跳过"""
endpoint_dtos = []
seen = set()
for dto in parsed_inputs:
if not dto.endpoint_url or dto.endpoint_url in seen:
continue
seen.add(dto.endpoint_url)
target_id = target_id_map.get(dto.target_name)
if target_id:
endpoint_dtos.append(EndpointDTO(
target_id=target_id,
url=dto.endpoint_url,
host=dto.target_name
))
if not endpoint_dtos:
return 0
return self.endpoint_repo.bulk_create_ignore_conflicts(endpoint_dtos)
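# Hypothetical usage of the service above (values illustrative only):
#
#   service = QuickScanService()
#   result = service.process_quick_scan([
#       "https://shop.example.com/login",  # URL → target + website + endpoint
#       "example.org",                     # domain → target only
#       "not a target!!",                  # invalid → collected in result['errors']
#   ])
#   result['target_inputs_map']
#   # {'shop.example.com': ['https://shop.example.com/login'],
#   #  'example.org': ['example.org']}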

View File

@@ -283,7 +283,8 @@ class ScanCreationService:
engine_ids: List[int],
engine_names: List[str],
yaml_configuration: str,
scheduled_scan_name: str | None = None
scheduled_scan_name: str | None = None,
scan_mode: str = 'full'
) -> List[Scan]:
"""
为多个目标批量创建扫描任务,后台异步分发到 Worker
@@ -294,6 +295,7 @@ class ScanCreationService:
engine_names: 引擎名称列表
yaml_configuration: YAML 格式的扫描配置
scheduled_scan_name: 定时扫描任务名称(可选,用于通知显示)
scan_mode: 扫描模式,'full''quick'(默认 'full'
Returns:
创建的 Scan 对象列表(立即返回,不等待分发完成)
@@ -316,6 +318,7 @@ class ScanCreationService:
results_dir=scan_workspace_dir,
status=ScanStatus.INITIATED,
container_ids=[],
scan_mode=scan_mode,
)
scans_to_create.append(scan)
except (ValidationError, ValueError) as e:
@@ -392,13 +395,13 @@ class ScanCreationService:
for data in scan_data:
scan_id = data['scan_id']
logger.info("-"*40)
logger.info("准备分发扫描任务 - Scan ID: %s, Target: %s", scan_id, data['target_name'])
logger.info("准备分发扫描任务 - Scan ID: %s, Target ID: %s", scan_id, data['target_id'])
try:
logger.info("调用 distributor.execute_scan_flow...")
success, message, container_id, worker_id = distributor.execute_scan_flow(
scan_id=scan_id,
target_name=data['target_name'],
target_id=data['target_id'],
target_name=data['target_name'],
scan_workspace_dir=data['results_dir'],
engine_name=data['engine_name'],
scheduled_scan_name=data.get('scheduled_scan_name'),

View File

@@ -0,0 +1,54 @@
"""
扫描输入目标服务
提供 ScanInputTarget 的写入操作。
"""
import logging
from typing import List
from apps.common.validators import detect_input_type
from apps.scan.models import ScanInputTarget
logger = logging.getLogger(__name__)
class ScanInputTargetService:
"""扫描输入目标服务,负责批量写入操作。"""
BATCH_SIZE = 1000
def bulk_create(self, scan_id: int, inputs: List[str]) -> int:
"""
批量创建扫描输入目标
Args:
scan_id: 扫描任务 ID
inputs: 输入字符串列表
Returns:
创建的记录数
"""
if not inputs:
return 0
records = []
for raw_input in inputs:
value = raw_input.strip()
if not value:
continue
try:
records.append(ScanInputTarget(
scan_id=scan_id,
value=value,
input_type=detect_input_type(value)
))
except ValueError as e:
logger.warning("跳过无效输入 '%s': %s", value, e)
if not records:
return 0
ScanInputTarget.objects.bulk_create(records, batch_size=self.BATCH_SIZE)
logger.info("批量创建 %d 条扫描输入目标 (scan_id=%d)", len(records), scan_id)
return len(records)
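# Hypothetical usage (values illustrative; invalid lines are logged and skipped,
# blank lines ignored):
#
#   created = ScanInputTargetService().bulk_create(
#       scan_id=42,
#       inputs=["example.com", "10.0.0.0/24", "   ", "https://example.com/a"],
#   )
#   # created == 3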

View File

@@ -1,25 +1,17 @@
"""
扫描任务服务
负责 Scan 模型的所有业务逻辑
负责 Scan 模型的所有业务逻辑,协调各个子服务
"""
from __future__ import annotations
import logging
import uuid
from typing import Dict, List, TYPE_CHECKING
from datetime import datetime
from pathlib import Path
from django.conf import settings
from django.db import transaction
from django.db.utils import DatabaseError, IntegrityError, OperationalError
from django.core.exceptions import ValidationError, ObjectDoesNotExist
from typing import Dict, List
from apps.scan.models import Scan
from apps.scan.repositories import DjangoScanRepository
from apps.targets.repositories import DjangoTargetRepository, DjangoOrganizationRepository
from apps.engine.repositories import DjangoEngineRepository
from apps.targets.models import Target
from apps.engine.models import ScanEngine
from apps.common.definitions import ScanStatus
@@ -30,115 +22,84 @@ logger = logging.getLogger(__name__)
class ScanService:
"""
扫描任务服务(协调者)
职责:
- 协调各个子服务
- 提供统一的公共接口
- 保持向后兼容
注意:
- 具体业务逻辑已拆分到子服务
- 本类主要负责委托和协调
职责:协调各个子服务,提供统一的公共接口
"""
# 终态集合:这些状态一旦设置,不应该被覆盖
FINAL_STATUSES = {
ScanStatus.COMPLETED,
ScanStatus.FAILED,
ScanStatus.CANCELLED
}
def __init__(self):
"""
初始化服务
"""
# 初始化子服务
from apps.scan.services.scan_creation_service import ScanCreationService
from apps.scan.services.scan_state_service import ScanStateService
from apps.scan.services.scan_control_service import ScanControlService
from apps.scan.services.scan_stats_service import ScanStatsService
self.creation_service = ScanCreationService()
self.state_service = ScanStateService()
self.control_service = ScanControlService()
self.stats_service = ScanStatsService()
# 保留 ScanRepository用于 get_scan 方法)
self.scan_repo = DjangoScanRepository()
def get_scan(self, scan_id: int, prefetch_relations: bool) -> Scan | None:
"""
获取扫描任务(包含关联对象)
自动预加载 engine 和 target避免 N+1 查询问题
Args:
scan_id: 扫描任务 ID
Returns:
Scan 对象(包含 engine 和 target或 None
"""
"""获取扫描任务(包含关联对象)"""
return self.scan_repo.get_by_id(scan_id, prefetch_relations)
def get_all_scans(self, prefetch_relations: bool = True):
"""获取所有扫描任务"""
return self.scan_repo.get_all(prefetch_relations=prefetch_relations)
def prepare_initiate_scan(
self,
organization_id: int | None = None,
target_id: int | None = None,
engine_id: int | None = None
) -> tuple[List[Target], ScanEngine]:
"""
为创建扫描任务做准备,返回所需的目标列表和扫描引擎
"""
"""为创建扫描任务做准备,返回目标列表和扫描引擎"""
return self.creation_service.prepare_initiate_scan(
organization_id, target_id, engine_id
)
def prepare_initiate_scan_multi_engine(
self,
organization_id: int | None = None,
target_id: int | None = None,
engine_ids: List[int] | None = None
) -> tuple[List[Target], str, List[str], List[int]]:
"""
为创建多引擎扫描任务做准备
Returns:
(目标列表, 合并配置, 引擎名称列表, 引擎ID列表)
"""
"""为创建多引擎扫描任务做准备"""
return self.creation_service.prepare_initiate_scan_multi_engine(
organization_id, target_id, engine_ids
)
def create_scans(
self,
targets: List[Target],
engine_ids: List[int],
engine_names: List[str],
yaml_configuration: str,
scheduled_scan_name: str | None = None
scheduled_scan_name: str | None = None,
scan_mode: str = 'full'
) -> List[Scan]:
"""批量创建扫描任务(委托给 ScanCreationService"""
"""批量创建扫描任务"""
return self.creation_service.create_scans(
targets, engine_ids, engine_names, yaml_configuration, scheduled_scan_name
targets, engine_ids, engine_names, yaml_configuration, scheduled_scan_name, scan_mode
)
# ==================== 状态管理方法(委托给 ScanStateService ====================
# ==================== 状态管理方法 ====================
def update_status(
self,
scan_id: int,
status: ScanStatus,
error_message: str | None = None,
stopped_at: datetime | None = None
) -> bool:
"""更新 Scan 状态(委托给 ScanStateService"""
return self.state_service.update_status(
scan_id, status, error_message, stopped_at
)
"""更新 Scan 状态"""
return self.state_service.update_status(scan_id, status, error_message, stopped_at)
def update_status_if_match(
self,
scan_id: int,
@@ -146,113 +107,56 @@ class ScanService:
new_status: ScanStatus,
stopped_at: datetime | None = None
) -> bool:
"""条件更新 Scan 状态(委托给 ScanStateService"""
"""条件更新 Scan 状态"""
return self.state_service.update_status_if_match(
scan_id, current_status, new_status, stopped_at
)
def update_cached_stats(self, scan_id: int) -> dict | None:
"""更新缓存统计数据(委托给 ScanStateService,返回统计数据字典"""
"""更新缓存统计数据,返回统计数据字典"""
return self.state_service.update_cached_stats(scan_id)
# ==================== 进度跟踪方法(委托给 ScanStateService ====================
# ==================== 进度跟踪方法 ====================
def init_stage_progress(self, scan_id: int, stages: list[str]) -> bool:
"""初始化阶段进度(委托给 ScanStateService"""
"""初始化阶段进度"""
return self.state_service.init_stage_progress(scan_id, stages)
def start_stage(self, scan_id: int, stage: str) -> bool:
"""开始执行某个阶段(委托给 ScanStateService"""
"""开始执行某个阶段"""
return self.state_service.start_stage(scan_id, stage)
def complete_stage(self, scan_id: int, stage: str, detail: str | None = None) -> bool:
"""完成某个阶段(委托给 ScanStateService"""
"""完成某个阶段"""
return self.state_service.complete_stage(scan_id, stage, detail)
def fail_stage(self, scan_id: int, stage: str, error: str | None = None) -> bool:
"""标记某个阶段失败(委托给 ScanStateService"""
"""标记某个阶段失败"""
return self.state_service.fail_stage(scan_id, stage, error)
def cancel_running_stages(self, scan_id: int, final_status: str = "cancelled") -> bool:
"""取消所有正在运行的阶段(委托给 ScanStateService"""
"""取消所有正在运行的阶段"""
return self.state_service.cancel_running_stages(scan_id, final_status)
# TODO待接入
def add_command_to_scan(self, scan_id: int, stage_name: str, tool_name: str, command: str) -> bool:
"""
增量添加命令到指定扫描阶段
Args:
scan_id: 扫描任务ID
stage_name: 阶段名称(如 'subdomain_discovery', 'port_scan'
tool_name: 工具名称
command: 执行命令
Returns:
bool: 是否成功添加
"""
try:
scan = self.get_scan(scan_id, prefetch_relations=False)
if not scan:
logger.error(f"扫描任务不存在: {scan_id}")
return False
stage_progress = scan.stage_progress or {}
# 确保指定阶段存在
if stage_name not in stage_progress:
stage_progress[stage_name] = {'status': 'running', 'commands': []}
# 确保 commands 列表存在
if 'commands' not in stage_progress[stage_name]:
stage_progress[stage_name]['commands'] = []
# 增量添加命令
command_entry = f"{tool_name}: {command}"
stage_progress[stage_name]['commands'].append(command_entry)
scan.stage_progress = stage_progress
scan.save(update_fields=['stage_progress'])
command_count = len(stage_progress[stage_name]['commands'])
logger.info(f"✓ 记录命令: {stage_name}.{tool_name} (总计: {command_count})")
return True
except Exception as e:
logger.error(f"记录命令失败: {e}")
return False
# ==================== 删除和控制方法(委托给 ScanControlService ====================
# ==================== 删除和控制方法 ====================
def delete_scans_two_phase(self, scan_ids: List[int]) -> dict:
"""两阶段删除扫描任务(委托给 ScanControlService"""
"""两阶段删除扫描任务"""
return self.control_service.delete_scans_two_phase(scan_ids)
def stop_scan(self, scan_id: int) -> tuple[bool, int]:
"""停止扫描任务(委托给 ScanControlService"""
"""停止扫描任务"""
return self.control_service.stop_scan(scan_id)
def hard_delete_scans(self, scan_ids: List[int]) -> tuple[int, Dict[str, int]]:
"""
硬删除扫描任务(真正删除数据)
用于 Worker 容器中执行,删除已软删除的扫描及其关联数据。
Args:
scan_ids: 扫描任务 ID 列表
Returns:
(删除数量, 详情字典)
"""
"""硬删除扫描任务(真正删除数据)"""
return self.scan_repo.hard_delete_by_ids(scan_ids)
# ==================== 统计方法(委托给 ScanStatsService ====================
# ==================== 统计方法 ====================
def get_statistics(self) -> dict:
"""获取扫描统计数据(委托给 ScanStatsService"""
"""获取扫描统计数据"""
return self.stats_service.get_statistics()
# 导出接口
__all__ = ['ScanService']

View File

@@ -1,613 +0,0 @@
"""
目标导出服务
提供统一的目标提取和文件导出功能,支持:
- URL 导出(纯导出,不做隐式回退)
- 默认 URL 生成(独立方法)
- 带回退链的 URL 导出(用例层编排)
- 域名/IP 导出(用于端口扫描)
- 黑名单过滤集成
"""
import ipaddress
import logging
from pathlib import Path
from typing import Dict, Any, Optional, List, Iterator, Tuple
from django.db.models import QuerySet
from apps.common.utils import BlacklistFilter
logger = logging.getLogger(__name__)
class DataSource:
"""数据源类型常量"""
ENDPOINT = "endpoint"
WEBSITE = "website"
HOST_PORT = "host_port"
DEFAULT = "default"
def create_export_service(target_id: int) -> 'TargetExportService':
"""
工厂函数:创建带黑名单过滤的导出服务
Args:
target_id: 目标 ID用于加载黑名单规则
Returns:
TargetExportService: 配置好黑名单过滤器的导出服务实例
"""
from apps.common.services import BlacklistService
rules = BlacklistService().get_rules(target_id)
blacklist_filter = BlacklistFilter(rules)
return TargetExportService(blacklist_filter=blacklist_filter)
def _iter_default_urls_from_target(
target_id: int,
blacklist_filter: Optional[BlacklistFilter] = None
) -> Iterator[str]:
"""
内部生成器:从 Target 本身生成默认 URL
根据 Target 类型生成 URL
- DOMAIN: http(s)://domain
- IP: http(s)://ip
- CIDR: 展开为所有 IP 的 http(s)://ip
- URL: 直接使用目标 URL
Args:
target_id: 目标 ID
blacklist_filter: 黑名单过滤器
Yields:
str: URL
"""
from apps.targets.services import TargetService
from apps.targets.models import Target
target_service = TargetService()
target = target_service.get_target(target_id)
if not target:
logger.warning("Target ID %d 不存在,无法生成默认 URL", target_id)
return
target_name = target.name
target_type = target.type
# 根据 Target 类型生成 URL
if target_type == Target.TargetType.DOMAIN:
urls = [f"http://{target_name}", f"https://{target_name}"]
elif target_type == Target.TargetType.IP:
urls = [f"http://{target_name}", f"https://{target_name}"]
elif target_type == Target.TargetType.CIDR:
try:
network = ipaddress.ip_network(target_name, strict=False)
urls = []
for ip in network.hosts():
urls.extend([f"http://{ip}", f"https://{ip}"])
# /32 或 /128 特殊处理
if not urls:
ip = str(network.network_address)
urls = [f"http://{ip}", f"https://{ip}"]
except ValueError as e:
logger.error("CIDR 解析失败: %s - %s", target_name, e)
return
elif target_type == Target.TargetType.URL:
urls = [target_name]
else:
logger.warning("不支持的 Target 类型: %s", target_type)
return
# 过滤并产出
for url in urls:
if blacklist_filter and not blacklist_filter.is_allowed(url):
continue
yield url
def _iter_urls_with_fallback(
target_id: int,
sources: List[str],
blacklist_filter: Optional[BlacklistFilter] = None,
batch_size: int = 1000,
tried_sources: Optional[List[str]] = None
) -> Iterator[Tuple[str, str]]:
"""
内部生成器:流式产出 URL带回退链
按 sources 顺序尝试每个数据源,直到有数据返回。
回退逻辑:
- 数据源有数据且通过过滤 → 产出 URL停止回退
- 数据源有数据但全被过滤 → 不回退,停止(避免意外暴露)
- 数据源为空 → 继续尝试下一个
Args:
target_id: 目标 ID
sources: 数据源优先级列表
blacklist_filter: 黑名单过滤器
batch_size: 批次大小
tried_sources: 可选,用于记录尝试过的数据源(外部传入列表,会被修改)
Yields:
Tuple[str, str]: (url, source) - URL 和来源标识
"""
from apps.asset.models import Endpoint, WebSite
for source in sources:
if tried_sources is not None:
tried_sources.append(source)
has_output = False # 是否有输出(通过过滤的)
has_raw_data = False # 是否有原始数据(过滤前)
if source == DataSource.DEFAULT:
# 默认 URL 生成(从 Target 本身构造,复用共用生成器)
for url in _iter_default_urls_from_target(target_id, blacklist_filter):
has_raw_data = True
has_output = True
yield url, source
# 检查是否有原始数据(需要单独判断,因为生成器可能被过滤后为空)
if not has_raw_data:
# 再次检查 Target 是否存在
from apps.targets.services import TargetService
target = TargetService().get_target(target_id)
has_raw_data = target is not None
if has_raw_data:
if not has_output:
logger.info("%s 有数据但全被黑名单过滤,不回退", source)
return
continue
# 构建对应数据源的 queryset
if source == DataSource.ENDPOINT:
queryset = Endpoint.objects.filter(target_id=target_id).values_list('url', flat=True)
elif source == DataSource.WEBSITE:
queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)
else:
logger.warning("未知的数据源类型: %s,跳过", source)
continue
for url in queryset.iterator(chunk_size=batch_size):
if url:
has_raw_data = True
if blacklist_filter and not blacklist_filter.is_allowed(url):
continue
has_output = True
yield url, source
# 有原始数据就停止(不管是否被过滤)
if has_raw_data:
if not has_output:
logger.info("%s 有数据但全被黑名单过滤,不回退", source)
return
logger.info("%s 为空,尝试下一个数据源", source)
def get_urls_with_fallback(
target_id: int,
sources: List[str],
batch_size: int = 1000
) -> Dict[str, Any]:
"""
带回退链的 URL 获取用例函数(返回列表)
按 sources 顺序尝试每个数据源,直到有数据返回。
Args:
target_id: 目标 ID
sources: 数据源优先级列表,如 ["website", "endpoint", "default"]
batch_size: 批次大小
Returns:
dict: {
'success': bool,
'urls': List[str],
'total_count': int,
'source': str, # 实际使用的数据源
'tried_sources': List[str], # 尝试过的数据源
}
"""
from apps.common.services import BlacklistService
rules = BlacklistService().get_rules(target_id)
blacklist_filter = BlacklistFilter(rules)
urls = []
actual_source = 'none'
tried_sources = []
for url, source in _iter_urls_with_fallback(target_id, sources, blacklist_filter, batch_size, tried_sources):
urls.append(url)
actual_source = source
if urls:
logger.info("%s 获取 %d 条 URL", actual_source, len(urls))
else:
logger.warning("所有数据源都为空,无法获取 URL")
return {
'success': True,
'urls': urls,
'total_count': len(urls),
'source': actual_source,
'tried_sources': tried_sources,
}
def export_urls_with_fallback(
target_id: int,
output_file: str,
sources: List[str],
batch_size: int = 1000
) -> Dict[str, Any]:
"""
带回退链的 URL 导出用例函数(写入文件)
按 sources 顺序尝试每个数据源,直到有数据返回。
流式写入,内存占用 O(1)。
Args:
target_id: 目标 ID
output_file: 输出文件路径
sources: 数据源优先级列表,如 ["endpoint", "website", "default"]
batch_size: 批次大小
Returns:
dict: {
'success': bool,
'output_file': str,
'total_count': int,
'source': str, # 实际使用的数据源
'tried_sources': List[str], # 尝试过的数据源
}
"""
from apps.common.services import BlacklistService
rules = BlacklistService().get_rules(target_id)
blacklist_filter = BlacklistFilter(rules)
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
total_count = 0
actual_source = 'none'
tried_sources = []
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for url, source in _iter_urls_with_fallback(target_id, sources, blacklist_filter, batch_size, tried_sources):
f.write(f"{url}\n")
total_count += 1
actual_source = source
if total_count % 10000 == 0:
logger.info("已导出 %d 个 URL...", total_count)
if total_count > 0:
logger.info("%s 导出 %d 条 URL 到 %s", actual_source, total_count, output_file)
else:
logger.warning("所有数据源都为空,无法导出 URL")
return {
'success': True,
'output_file': str(output_path),
'total_count': total_count,
'source': actual_source,
'tried_sources': tried_sources,
}
class TargetExportService:
"""
目标导出服务 - 提供统一的目标提取和文件导出功能
使用方式:
# 方式 1使用用例函数推荐
from apps.scan.services.target_export_service import export_urls_with_fallback, DataSource
result = export_urls_with_fallback(
target_id=1,
output_file='/path/to/output.txt',
sources=[DataSource.ENDPOINT, DataSource.WEBSITE, DataSource.DEFAULT]
)
# 方式 2直接使用 Service纯导出不带回退
export_service = create_export_service(target_id)
result = export_service.export_urls(target_id, output_path, queryset)
"""
def __init__(self, blacklist_filter: Optional[BlacklistFilter] = None):
"""
初始化导出服务
Args:
blacklist_filter: 黑名单过滤器None 表示禁用过滤
"""
self.blacklist_filter = blacklist_filter
def export_urls(
self,
target_id: int,
output_path: str,
queryset: QuerySet,
url_field: str = 'url',
batch_size: int = 1000
) -> Dict[str, Any]:
"""
纯 URL 导出函数 - 只负责将 queryset 数据写入文件
不做任何隐式回退或默认 URL 生成。
Args:
target_id: 目标 ID
output_path: 输出文件路径
queryset: 数据源 queryset由调用方构建应为 values_list flat=True
url_field: URL 字段名(用于黑名单过滤)
batch_size: 批次大小
Returns:
dict: {
'success': bool,
'output_file': str,
'total_count': int, # 实际写入数量
'queryset_count': int, # 原始数据数量(迭代计数)
'filtered_count': int, # 被黑名单过滤的数量
}
Raises:
IOError: 文件写入失败
"""
output_file = Path(output_path)
output_file.parent.mkdir(parents=True, exist_ok=True)
logger.info("开始导出 URL - target_id=%s, output=%s", target_id, output_path)
total_count = 0
filtered_count = 0
queryset_count = 0
try:
with open(output_file, 'w', encoding='utf-8', buffering=8192) as f:
for url in queryset.iterator(chunk_size=batch_size):
queryset_count += 1
if url:
# 黑名单过滤
if self.blacklist_filter and not self.blacklist_filter.is_allowed(url):
filtered_count += 1
continue
f.write(f"{url}\n")
total_count += 1
if total_count % 10000 == 0:
logger.info("已导出 %d 个 URL...", total_count)
except IOError as e:
logger.error("文件写入失败: %s - %s", output_path, e)
raise
if filtered_count > 0:
logger.info("黑名单过滤: 过滤 %d 个 URL", filtered_count)
logger.info(
"✓ URL 导出完成 - 写入: %d, 原始: %d, 过滤: %d, 文件: %s",
total_count, queryset_count, filtered_count, output_path
)
return {
'success': True,
'output_file': str(output_file),
'total_count': total_count,
'queryset_count': queryset_count,
'filtered_count': filtered_count,
}
def generate_default_urls(
self,
target_id: int,
output_path: str
) -> Dict[str, Any]:
"""
默认 URL 生成器
根据 Target 类型生成默认 URL
- DOMAIN: http(s)://domain
- IP: http(s)://ip
- CIDR: 展开为所有 IP 的 http(s)://ip
- URL: 直接使用目标 URL
Args:
target_id: 目标 ID
output_path: 输出文件路径
Returns:
dict: {
'success': bool,
'output_file': str,
'total_count': int,
}
"""
output_file = Path(output_path)
output_file.parent.mkdir(parents=True, exist_ok=True)
logger.info("生成默认 URL - target_id=%d", target_id)
total_urls = 0
with open(output_file, 'w', encoding='utf-8', buffering=8192) as f:
for url in _iter_default_urls_from_target(target_id, self.blacklist_filter):
f.write(f"{url}\n")
total_urls += 1
if total_urls % 10000 == 0:
logger.info("已生成 %d 个 URL...", total_urls)
logger.info("✓ 默认 URL 生成完成 - 数量: %d", total_urls)
return {
'success': True,
'output_file': str(output_file),
'total_count': total_urls,
}
def export_hosts(
self,
target_id: int,
output_path: str,
batch_size: int = 1000
) -> Dict[str, Any]:
"""
主机列表导出函数(用于端口扫描)
根据 Target 类型选择导出逻辑:
- DOMAIN: 从 Subdomain 表流式导出子域名
- IP: 直接写入 IP 地址
- CIDR: 展开为所有主机 IP
Args:
target_id: 目标 ID
output_path: 输出文件路径
batch_size: 批次大小
Returns:
dict: {
'success': bool,
'output_file': str,
'total_count': int,
'target_type': str
}
"""
from apps.targets.services import TargetService
from apps.targets.models import Target
output_file = Path(output_path)
output_file.parent.mkdir(parents=True, exist_ok=True)
# 获取 Target 信息
target_service = TargetService()
target = target_service.get_target(target_id)
if not target:
raise ValueError(f"Target ID {target_id} 不存在")
target_type = target.type
target_name = target.name
logger.info(
"开始导出主机列表 - Target ID: %d, Name: %s, Type: %s, 输出文件: %s",
target_id, target_name, target_type, output_path
)
total_count = 0
if target_type == Target.TargetType.DOMAIN:
total_count = self._export_domains(target_id, target_name, output_file, batch_size)
type_desc = "域名"
elif target_type == Target.TargetType.IP:
total_count = self._export_ip(target_name, output_file)
type_desc = "IP"
elif target_type == Target.TargetType.CIDR:
total_count = self._export_cidr(target_name, output_file)
type_desc = "CIDR IP"
else:
raise ValueError(f"不支持的目标类型: {target_type}")
logger.info(
"✓ 主机列表导出完成 - 类型: %s, 总数: %d, 文件: %s",
type_desc, total_count, output_path
)
return {
'success': True,
'output_file': str(output_file),
'total_count': total_count,
'target_type': target_type
}
def _export_domains(
self,
target_id: int,
target_name: str,
output_path: Path,
batch_size: int
) -> int:
"""导出域名类型目标的根域名 + 子域名"""
from apps.asset.services.asset.subdomain_service import SubdomainService
subdomain_service = SubdomainService()
domain_iterator = subdomain_service.iter_subdomain_names_by_target(
target_id=target_id,
chunk_size=batch_size
)
total_count = 0
written_domains = set() # 去重(子域名表可能已包含根域名)
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
# 1. 先写入根域名
if self._should_write_target(target_name):
f.write(f"{target_name}\n")
written_domains.add(target_name)
total_count += 1
# 2. 再写入子域名(跳过已写入的根域名)
for domain_name in domain_iterator:
if domain_name in written_domains:
continue
if self._should_write_target(domain_name):
f.write(f"{domain_name}\n")
written_domains.add(domain_name)
total_count += 1
if total_count % 10000 == 0:
logger.info("已导出 %d 个域名...", total_count)
return total_count
def _export_ip(self, target_name: str, output_path: Path) -> int:
"""导出 IP 类型目标"""
if self._should_write_target(target_name):
with open(output_path, 'w', encoding='utf-8') as f:
f.write(f"{target_name}\n")
return 1
return 0
def _export_cidr(self, target_name: str, output_path: Path) -> int:
"""导出 CIDR 类型目标,展开为每个 IP"""
network = ipaddress.ip_network(target_name, strict=False)
total_count = 0
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for ip in network.hosts():
ip_str = str(ip)
if self._should_write_target(ip_str):
f.write(f"{ip_str}\n")
total_count += 1
if total_count % 10000 == 0:
logger.info("已导出 %d 个 IP...", total_count)
# /32 或 /128 特殊处理
if total_count == 0:
ip_str = str(network.network_address)
if self._should_write_target(ip_str):
with open(output_path, 'w', encoding='utf-8') as f:
f.write(f"{ip_str}\n")
total_count = 1
return total_count
def _should_write_target(self, target: str) -> bool:
"""检查目标是否应该写入(通过黑名单过滤)"""
if self.blacklist_filter:
return self.blacklist_filter.is_allowed(target)
return True

View File

@@ -18,7 +18,7 @@ from .subdomain_discovery import (
# 指纹识别任务
from .fingerprint_detect import (
export_urls_for_fingerprint_task,
export_site_urls_for_fingerprint_task,
run_xingfinger_and_stream_update_tech_task,
)
@@ -35,6 +35,6 @@ __all__ = [
'merge_and_validate_task',
'save_domains_task',
# 指纹识别任务
'export_urls_for_fingerprint_task',
'export_site_urls_for_fingerprint_task',
'run_xingfinger_and_stream_update_tech_task',
]

View File

@@ -1,21 +1,14 @@
"""
导出站点 URL 到 TXT 文件的 Task
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库导出
2. Provider 模式:使用 TargetProvider 从任意数据源导出
使用 TargetProvider 从任意数据源导出 URL用于目录扫描
数据源: WebSite.url → Default
数据源WebSite,为空时回退到默认 URL
"""
import logging
from typing import Optional
from pathlib import Path
from prefect import task
from apps.scan.services.target_export_service import (
export_urls_with_fallback,
DataSource,
)
from apps.scan.providers import TargetProvider
logger = logging.getLogger(__name__)
@@ -23,94 +16,61 @@ logger = logging.getLogger(__name__)
@task(name="export_sites")
def export_sites_task(
target_id: Optional[int] = None,
output_file: str = "",
provider: Optional[TargetProvider] = None,
batch_size: int = 1000,
output_file: str,
provider: TargetProvider,
) -> dict:
"""
导出目标下的所有站点 URL 到 TXT 文件
支持两种模式:
1. 传统模式(向后兼容):传入 target_id从数据库导出
2. Provider 模式:传入 provider从任意数据源导出
数据源优先级(回退链,仅传统模式):
1. WebSite 表 - 站点级别 URL
2. 默认生成 - 根据 Target 类型生成 http(s)://target_name
数据源WebSite为空时回退到默认 URL
Args:
target_id: 目标 ID传统模式向后兼容
output_file: 输出文件路径(绝对路径)
provider: TargetProvider 实例(新模式)
batch_size: 每次读取的批次大小,默认 1000
provider: TargetProvider 实例
Returns:
dict: {
'success': bool,
'output_file': str,
'total_count': int
'total_count': int,
'source': str, # website | default
}
Raises:
ValueError: 参数错误
IOError: 文件写入失败
ValueError: provider 未提供
"""
# 参数验证:至少提供一个
if target_id is None and provider is None:
raise ValueError("必须提供 target_id 或 provider 参数之一")
# Provider 模式:使用 TargetProvider 导出
if provider is not None:
logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
return _export_with_provider(output_file, provider)
# 传统模式:使用 export_urls_with_fallback
logger.info("使用传统模式 - Target ID: %d", target_id)
result = export_urls_with_fallback(
target_id=target_id,
output_file=output_file,
sources=[DataSource.WEBSITE, DataSource.DEFAULT],
batch_size=batch_size,
)
logger.info(
"站点 URL 导出完成 - source=%s, count=%d",
result['source'], result['total_count']
)
# 保持返回值格式不变(向后兼容)
return {
'success': result['success'],
'output_file': result['output_file'],
'total_count': result['total_count'],
}
if provider is None:
raise ValueError("必须提供 provider 参数")
logger.info("导出 URL - Provider: %s", type(provider).__name__)
def _export_with_provider(output_file: str, provider: TargetProvider) -> dict:
"""使用 Provider 导出 URL"""
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
# 按优先级获取数据源
urls = list(provider.iter_websites())
source = "website"
if not urls:
logger.info("WebSite 为空,生成默认 URL")
urls = list(provider.iter_default_urls())
source = "default"
# 写入文件
total_count = 0
blacklist_filter = provider.get_blacklist_filter()
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for url in provider.iter_urls():
# 应用黑名单过滤(如果有)
if blacklist_filter and not blacklist_filter.is_allowed(url):
continue
for url in urls:
f.write(f"{url}\n")
total_count += 1
if total_count % 1000 == 0:
logger.info("已导出 %d 个 URL...", total_count)
logger.info("✓ URL 导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))
logger.info(
"✓ URL 导出完成 - 来源: %s, 总数: %d, 文件: %s",
source, total_count, str(output_path)
)
return {
'success': True,
'output_file': str(output_path),
'total_count': total_count,
'source': source,
}
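# Usage sketch (provider construction assumed; Task.fn runs the undecorated
# function directly, outside the Prefect engine — convenient in unit tests):
#
#   result = export_sites_task.fn(
#       output_file="/tmp/sites.txt",
#       provider=SnapshotTargetProvider(scan_id=100),
#   )
#   result["source"]  # "website", or "default" when WebSite was empty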

View File

@@ -2,14 +2,14 @@
指纹识别任务模块
包含:
- export_urls_for_fingerprint_task: 导出 URL 到文件
- export_site_urls_for_fingerprint_task: 导出站点 URL 到文件
- run_xingfinger_and_stream_update_tech_task: 流式执行 xingfinger 并更新 tech
"""
from .export_urls_task import export_urls_for_fingerprint_task
from .export_site_urls_task import export_site_urls_for_fingerprint_task
from .run_xingfinger_task import run_xingfinger_and_stream_update_tech_task
__all__ = [
'export_urls_for_fingerprint_task',
'export_site_urls_for_fingerprint_task',
'run_xingfinger_and_stream_update_tech_task',
]

View File

@@ -0,0 +1,73 @@
"""
导出站点 URL 任务
使用 TargetProvider 从任意数据源导出站点 URL用于指纹识别
数据源WebSite为空时回退到默认 URL
"""
import logging
from pathlib import Path
from prefect import task
from apps.scan.providers import TargetProvider
logger = logging.getLogger(__name__)
@task(name="export_site_urls_for_fingerprint")
def export_site_urls_for_fingerprint_task(
output_file: str,
provider: TargetProvider,
) -> dict:
"""
导出目标下的 URL 到文件(用于指纹识别)
数据源WebSite为空时回退到默认 URL
Args:
output_file: 输出文件路径
provider: TargetProvider 实例
Returns:
dict: {
'output_file': str,
'total_count': int,
'source': str, # website | default
}
"""
if provider is None:
raise ValueError("必须提供 provider 参数")
logger.info("导出 URL - Provider: %s", type(provider).__name__)
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
# 按优先级获取数据源
urls = list(provider.iter_websites())
source = "website"
if not urls:
logger.info("WebSite 为空,生成默认 URL")
urls = list(provider.iter_default_urls())
source = "default"
# 写入文件
total_count = 0
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for url in urls:
f.write(f"{url}\n")
total_count += 1
logger.info(
"✓ URL 导出完成 - 来源: %s, 总数: %d, 文件: %s",
source, total_count, str(output_path)
)
return {
'output_file': str(output_path),
'total_count': total_count,
'source': source,
}

View File

@@ -1,112 +0,0 @@
"""
导出 URL 任务
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库导出
2. Provider 模式:使用 TargetProvider 从任意数据源导出
用于指纹识别前导出目标下的 URL 到文件
"""
import logging
from typing import Optional
from pathlib import Path
from prefect import task
from apps.scan.services.target_export_service import (
export_urls_with_fallback,
DataSource,
)
from apps.scan.providers import TargetProvider, DatabaseTargetProvider
logger = logging.getLogger(__name__)
@task(name="export_urls_for_fingerprint")
def export_urls_for_fingerprint_task(
target_id: Optional[int] = None,
output_file: str = "",
source: str = 'website', # 保留参数,兼容旧调用(实际值由回退链决定)
provider: Optional[TargetProvider] = None,
batch_size: int = 1000
) -> dict:
"""
导出目标下的 URL 到文件(用于指纹识别)
支持两种模式:
1. 传统模式(向后兼容):传入 target_id从数据库导出
2. Provider 模式:传入 provider从任意数据源导出
数据源优先级(回退链,仅传统模式):
1. WebSite 表 - 站点级别 URL
2. 默认生成 - 根据 Target 类型生成 http(s)://target_name
Args:
target_id: 目标 ID传统模式向后兼容
output_file: 输出文件路径
source: 数据源类型(保留参数,兼容旧调用,实际值由回退链决定)
provider: TargetProvider 实例(新模式)
batch_size: 批量读取大小
Returns:
dict: {'output_file': str, 'total_count': int, 'source': str}
"""
# 参数验证:至少提供一个
if target_id is None and provider is None:
raise ValueError("必须提供 target_id 或 provider 参数之一")
# Provider 模式:使用 TargetProvider 导出
if provider is not None:
logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
return _export_with_provider(output_file, provider)
# 传统模式:使用 export_urls_with_fallback
logger.info("使用传统模式 - Target ID: %d", target_id)
result = export_urls_with_fallback(
target_id=target_id,
output_file=output_file,
sources=[DataSource.WEBSITE, DataSource.DEFAULT],
batch_size=batch_size,
)
logger.info(
"指纹识别 URL 导出完成 - source=%s, count=%d",
result['source'], result['total_count']
)
# 返回实际使用的数据源(不再固定为 "website"
return {
'output_file': result['output_file'],
'total_count': result['total_count'],
'source': result['source'],
}
def _export_with_provider(output_file: str, provider: TargetProvider) -> dict:
"""使用 Provider 导出 URL"""
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
total_count = 0
blacklist_filter = provider.get_blacklist_filter()
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for url in provider.iter_urls():
# 应用黑名单过滤(如果有)
if blacklist_filter and not blacklist_filter.is_allowed(url):
continue
f.write(f"{url}\n")
total_count += 1
if total_count % 1000 == 0:
logger.info("已导出 %d 个 URL...", total_count)
logger.info("✓ URL 导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))
return {
'output_file': str(output_path),
'total_count': total_count,
'source': 'provider',
}

View File

@@ -1,22 +1,14 @@
"""
导出主机列表到 TXT 文件的 Task
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库导出
2. Provider 模式:使用 TargetProvider 从任意数据源导出
根据 Target 类型决定导出内容:
- DOMAIN: 从 Subdomain 表导出子域名
- IP: 直接写入 target.name
- CIDR: 展开 CIDR 范围内的所有 IP
使用 TargetProvider 从任意数据源导出主机列表。
"""
import logging
from pathlib import Path
from typing import Optional
from prefect import task
from apps.scan.providers import DatabaseTargetProvider, TargetProvider
from apps.scan.providers import TargetProvider
logger = logging.getLogger(__name__)
@@ -24,76 +16,56 @@ logger = logging.getLogger(__name__)
@task(name="export_hosts")
def export_hosts_task(
output_file: str,
target_id: Optional[int] = None,
provider: Optional[TargetProvider] = None,
provider: TargetProvider,
) -> dict:
"""
导出主机列表到 TXT 文件
支持两种模式:
1. 传统模式(向后兼容):传入 target_id从数据库导出
2. Provider 模式:传入 provider从任意数据源导出
根据 Target 类型自动决定导出内容:
- DOMAIN: 从 Subdomain 表导出子域名(流式处理,支持 10万+ 域名)
- IP: 直接写入 target.name单个 IP
- CIDR: 展开 CIDR 范围内的所有可用 IP
显式组合 iter_target_hosts() + iter_subdomains()。
Args:
output_file: 输出文件路径(绝对路径)
target_id: 目标 ID传统模式向后兼容
provider: TargetProvider 实例(新模式)
provider: TargetProvider 实例
Returns:
dict: {
'success': bool,
'output_file': str,
'total_count': int,
'target_type': str # 仅传统模式返回
}
Raises:
ValueError: 参数错误target_id 和 provider 未提供
ValueError: provider 未提供
IOError: 文件写入失败
"""
if target_id is None and provider is None:
raise ValueError("必须提供 target_id 或 provider 参数之一")
if provider is None:
raise ValueError("必须提供 provider 参数")
# 向后兼容:如果没有提供 provider使用 target_id 创建 DatabaseTargetProvider
use_legacy_mode = provider is None
if use_legacy_mode:
logger.info("使用传统模式 - Target ID: %d", target_id)
provider = DatabaseTargetProvider(target_id=target_id)
else:
logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
logger.info("导出主机列表 - Provider: %s", type(provider).__name__)
# 确保输出目录存在
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
# 使用 Provider 导出主机列表iter_hosts 内部已处理黑名单过滤)
total_count = 0
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for host in provider.iter_hosts():
# 1. 导出 Target 主机CIDR 自动展开,已过滤黑名单)
for host in provider.iter_target_hosts():
f.write(f"{host}\n")
total_count += 1
# 2. 导出子域名Provider 内部已过滤黑名单)
for subdomain in provider.iter_subdomains():
f.write(f"{subdomain}\n")
total_count += 1
if total_count % 1000 == 0:
logger.info("已导出 %d 个主机...", total_count)
logger.info("✓ 主机列表导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))
result = {
return {
'success': True,
'output_file': str(output_path),
'total_count': total_count,
}
# 传统模式:保持返回值格式不变(向后兼容)
if use_legacy_mode:
from apps.targets.services import TargetService
target = TargetService().get_target(target_id)
result['target_type'] = target.type if target else 'unknown'
return result

View File

@@ -1,208 +1,76 @@
"""
导出站点URL到文件的Task
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库导出
2. Provider 模式:使用 TargetProvider 从任意数据源导出
使用 TargetProvider 从任意数据源导出 URL用于 httpx 站点探测)。
特殊逻辑:
- 80 端口:只生成 HTTP URL省略端口号
- 443 端口:只生成 HTTPS URL省略端口号
- 其他端口:生成 HTTP 和 HTTPS 两个URL带端口号
数据源HostPortMapping为空时回退到默认 URL
"""
import logging
from typing import Optional
from pathlib import Path
from prefect import task
from apps.asset.services import HostPortMappingService
from apps.scan.services.target_export_service import create_export_service
from apps.common.services import BlacklistService
from apps.common.utils import BlacklistFilter
from apps.scan.providers import TargetProvider, DatabaseTargetProvider, ProviderContext
from apps.scan.providers import TargetProvider
logger = logging.getLogger(__name__)
def _generate_urls_from_port(host: str, port: int) -> list[str]:
"""
根据端口生成 URL 列表
- 80 端口:只生成 HTTP URL省略端口号
- 443 端口:只生成 HTTPS URL省略端口号
- 其他端口:生成 HTTP 和 HTTPS 两个URL带端口号
"""
if port == 80:
return [f"http://{host}"]
elif port == 443:
return [f"https://{host}"]
else:
return [f"http://{host}:{port}", f"https://{host}:{port}"]
@task(name="export_site_urls")
def export_site_urls_task(
output_file: str,
target_id: Optional[int] = None,
provider: Optional[TargetProvider] = None,
batch_size: int = 1000
provider: TargetProvider,
) -> dict:
"""
导出目标下的所有站点URL到文件
支持两种模式:
1. 传统模式(向后兼容):传入 target_id从 HostPortMapping 表导出
2. Provider 模式:传入 provider从任意数据源导出
传统模式特殊逻辑:
- 80 端口:只生成 HTTP URL省略端口号
- 443 端口:只生成 HTTPS URL省略端口号
- 其他端口:生成 HTTP 和 HTTPS 两个URL带端口号
回退逻辑(仅传统模式):
- 如果 HostPortMapping 为空,使用 generate_default_urls() 生成默认 URL
数据源HostPortMapping为空时回退到默认 URL
Args:
output_file: 输出文件路径(绝对路径)
target_id: 目标ID传统模式向后兼容
provider: TargetProvider 实例(新模式)
batch_size: 每次处理的批次大小
provider: TargetProvider 实例
Returns:
dict: {
'success': bool,
'output_file': str,
'total_urls': int,
'association_count': int, # 主机端口关联数量(仅传统模式)
'source': str, # 数据来源: "host_port" | "default" | "provider"
'source': str, # host_port | default
}
Raises:
ValueError: 参数错误
IOError: 文件写入失败
ValueError: provider 未提供
"""
# 参数验证:至少提供一个
if target_id is None and provider is None:
raise ValueError("必须提供 target_id 或 provider 参数之一")
# 向后兼容:如果没有提供 provider使用传统模式
if provider is None:
logger.info("使用传统模式 - Target ID: %d, 输出文件: %s", target_id, output_file)
return _export_site_urls_legacy(target_id, output_file, batch_size)
# Provider 模式
logger.info("使用 Provider 模式 - Provider: %s, 输出文件: %s", type(provider).__name__, output_file)
# 确保输出目录存在
raise ValueError("必须提供 provider 参数")
logger.info("导出 URL - Provider: %s", type(provider).__name__)
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
# 使用 Provider 导出 URL 列表
# 按优先级获取数据源
urls = list(provider.iter_host_port_urls())
source = "host_port"
if not urls:
logger.info("HostPortMapping 为空,生成默认 URL")
urls = list(provider.iter_default_urls())
source = "default"
# 写入文件
total_urls = 0
blacklist_filter = provider.get_blacklist_filter()
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for url in provider.iter_urls():
# 应用黑名单过滤(如果有)
if blacklist_filter and not blacklist_filter.is_allowed(url):
continue
for url in urls:
f.write(f"{url}\n")
total_urls += 1
if total_urls % 1000 == 0:
logger.info("已导出 %d 个URL...", total_urls)
logger.info("✓ URL导出完成 - 总数: %d, 文件: %s", total_urls, str(output_path))
return {
'success': True,
'output_file': str(output_path),
'total_urls': total_urls,
'source': 'provider',
}
def _export_site_urls_legacy(target_id: int, output_file: str, batch_size: int) -> dict:
"""
传统模式:从 HostPortMapping 表导出 URL
保持原有逻辑不变,确保向后兼容
"""
logger.info("开始统计站点URL - Target ID: %d, 输出文件: %s", target_id, output_file)
# 确保输出目录存在
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
# 获取规则并创建过滤器
blacklist_filter = BlacklistFilter(BlacklistService().get_rules(target_id))
# 直接查询 HostPortMapping 表,按 host 排序
service = HostPortMappingService()
associations = service.iter_host_port_by_target(
target_id=target_id,
batch_size=batch_size,
)
total_urls = 0
association_count = 0
filtered_count = 0
# 流式写入文件(特殊端口逻辑)
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for assoc in associations:
association_count += 1
host = assoc['host']
port = assoc['port']
# 先校验 host通过了再生成 URL
if not blacklist_filter.is_allowed(host):
filtered_count += 1
continue
# 根据端口号生成URL
for url in _generate_urls_from_port(host, port):
f.write(f"{url}\n")
total_urls += 1
if association_count % 1000 == 0:
logger.info("已处理 %d 条关联,生成 %d 个URL...", association_count, total_urls)
if filtered_count > 0:
logger.info("黑名单过滤: 过滤 %d 条关联", filtered_count)
logger.info(
"站点URL导出完成 - 关联数: %d, 总URL数: %d, 文件: %s",
association_count, total_urls, str(output_path)
"✓ URL 导出完成 - 来源: %s, 总数: %d, 文件: %s",
source, total_urls, str(output_path)
)
# 判断数据来源
source = "host_port"
# 数据存在但全被过滤,不回退
if association_count > 0 and total_urls == 0:
logger.info("HostPortMapping 有 %d 条数据,但全被黑名单过滤,不回退", association_count)
return {
'success': True,
'output_file': str(output_path),
'total_urls': 0,
'association_count': association_count,
'source': source,
}
# 数据源为空,回退到默认 URL 生成
if total_urls == 0:
logger.info("HostPortMapping 为空,使用默认 URL 生成")
export_service = create_export_service(target_id)
result = export_service.generate_default_urls(target_id, str(output_path))
total_urls = result['total_count']
source = "default"
return {
'success': True,
'output_file': str(output_path),
'total_urls': total_urls,
'association_count': association_count,
'source': source,
}
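The rewritten task depends only on two iterator methods of the provider, so the fallback chain (HostPortMapping → default URLs) is easy to sketch with a duck-typed stand-in. This is illustrative, not the project's real TargetProvider; the class name and sample data are invented:

from typing import Iterator

class StaticHostPortProvider:
    """Stand-in provider for illustration; implements just the two hooks the task calls."""

    def __init__(self, host_ports: list[tuple[str, int]], default_host: str):
        self._host_ports = host_ports
        self._default_host = default_host

    def iter_host_port_urls(self) -> Iterator[str]:
        # Same 80/443 special-casing as _generate_urls_from_port above
        for host, port in self._host_ports:
            if port == 80:
                yield f"http://{host}"
            elif port == 443:
                yield f"https://{host}"
            else:
                yield f"http://{host}:{port}"
                yield f"https://{host}:{port}"

    def iter_default_urls(self) -> Iterator[str]:
        # Fallback used when no host:port mappings exist
        yield f"http://{self._default_host}"
        yield f"https://{self._default_host}"

With an empty host_ports list, the task above would take the fallback branch and report source='default'.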

View File

@@ -119,7 +119,8 @@ def merge_and_validate_task(result_files: List[str], result_dir: str) -> str:
         unique_count = sum(1 for _ in f)
     if unique_count == 0:
-        raise RuntimeError("未找到任何有效域名")
+        logger.warning("未找到任何有效域名,返回空文件")
+        # 不抛出异常,返回空文件让后续流程正常处理
     file_size_kb = merged_file.stat().st_size / 1024
     logger.info("✓ 合并去重完成 - 去重后: %d 个域名, 文件大小: %.2f KB", unique_count, file_size_kb)

View File

@@ -1,23 +1,16 @@
"""
导出站点 URL 列表任务
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库导出
2. Provider 模式:使用 TargetProvider 从任意数据源导出
使用 TargetProvider 从任意数据源导出 URL用于 katana 等爬虫工具)。
数据源: WebSite.url → Default用于 katana 等爬虫工具)
数据源WebSite,为空时回退到默认 URL
"""
import logging
from typing import Optional
from pathlib import Path
from prefect import task
from apps.scan.services.target_export_service import (
export_urls_with_fallback,
DataSource,
)
from apps.scan.providers import TargetProvider, DatabaseTargetProvider
from apps.scan.providers import TargetProvider
logger = logging.getLogger(__name__)
@@ -29,92 +22,58 @@ logger = logging.getLogger(__name__)
 )
 def export_sites_task(
     output_file: str,
-    target_id: Optional[int] = None,
-    scan_id: Optional[int] = None,
-    provider: Optional[TargetProvider] = None,
-    batch_size: int = 1000
+    provider: TargetProvider,
 ) -> dict:
     """
     导出站点 URL 列表到文件(用于 katana 等爬虫工具)
-    支持两种模式:
-    1. 传统模式(向后兼容):传入 target_id从数据库导出
-    2. Provider 模式:传入 provider从任意数据源导出
-    数据源优先级(回退链,仅传统模式):
-    1. WebSite 表 - 站点级别 URL
-    2. 默认生成 - 根据 Target 类型生成 http(s)://target_name
+    数据源WebSite为空时回退到默认 URL
     Args:
         output_file: 输出文件路径
-        target_id: 目标 ID传统模式向后兼容
-        scan_id: 扫描 ID保留参数兼容旧调用
-        provider: TargetProvider 实例(新模式)
-        batch_size: 批次大小(内存优化)
+        provider: TargetProvider 实例
     Returns:
         dict: {
-            'output_file': str,  # 输出文件路径
-            'asset_count': int,  # 资产数量
+            'output_file': str,
+            'asset_count': int,
+            'source': str,  # website | default
         }
     Raises:
-        ValueError: 参数错误
-        RuntimeError: 执行失败
+        ValueError: provider 未提供
     """
-    # 参数验证:至少提供一个
-    if target_id is None and provider is None:
-        raise ValueError("必须提供 target_id 或 provider 参数之一")
-    # Provider 模式:使用 TargetProvider 导出
-    if provider is not None:
-        logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
-        return _export_with_provider(output_file, provider)
-    # 传统模式:使用 export_urls_with_fallback
-    logger.info("使用传统模式 - Target ID: %d", target_id)
-    result = export_urls_with_fallback(
-        target_id=target_id,
-        output_file=output_file,
-        sources=[DataSource.WEBSITE, DataSource.DEFAULT],
-        batch_size=batch_size,
-    )
-    logger.info(
-        "站点 URL 导出完成 - source=%s, count=%d",
-        result['source'], result['total_count']
-    )
-    # 保持返回值格式不变(向后兼容)
-    return {
-        'output_file': result['output_file'],
-        'asset_count': result['total_count'],
-    }
+    if provider is None:
+        raise ValueError("必须提供 provider 参数")
+    logger.info("导出 URL - Provider: %s", type(provider).__name__)
-def _export_with_provider(output_file: str, provider: TargetProvider) -> dict:
-    """使用 Provider 导出 URL"""
     output_path = Path(output_file)
     output_path.parent.mkdir(parents=True, exist_ok=True)
+    # 按优先级获取数据源
+    urls = list(provider.iter_websites())
+    source = "website"
+    if not urls:
+        logger.info("WebSite 为空,生成默认 URL")
+        urls = list(provider.iter_default_urls())
+        source = "default"
     # 写入文件
     total_count = 0
-    blacklist_filter = provider.get_blacklist_filter()
     with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
-        for url in provider.iter_urls():
-            # 应用黑名单过滤(如果有)
-            if blacklist_filter and not blacklist_filter.is_allowed(url):
-                continue
+        for url in urls:
             f.write(f"{url}\n")
             total_count += 1
-            if total_count % 1000 == 0:
-                logger.info("已导出 %d 个 URL...", total_count)
-    logger.info("✓ URL 导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))
+    logger.info(
+        "✓ URL 导出完成 - 来源: %s, 总数: %d, 文件: %s",
+        source, total_count, str(output_path)
+    )
     return {
         'output_file': str(output_path),
         'asset_count': total_count,
+        'source': source,
     }
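A sketch of wiring the rewritten task into a Prefect flow. The stub provider and flow name are invented; only output_file and provider mirror the real signature, and export_sites_task is assumed importable from its task module:

from pathlib import Path
from prefect import flow

class _WebsiteOnlyProvider:
    """Illustrative provider exposing just the hooks export_sites_task calls."""
    def iter_websites(self):
        yield "https://example.com"
    def iter_default_urls(self):
        yield "http://example.com"

@flow(name="site_export_demo")
def site_export_demo(result_dir: str) -> dict:
    # Returns {'output_file': ..., 'asset_count': ..., 'source': 'website' | 'default'}
    return export_sites_task(
        output_file=str(Path(result_dir) / "sites.txt"),
        provider=_WebsiteOnlyProvider(),
    )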

View File

@@ -2,18 +2,21 @@
 包含:
 - export_endpoints_task: 导出端点 URL 到文件
+- export_websites_task: 导出网站 URL 到文件
 - run_vuln_tool_task: 执行漏洞扫描工具(非流式)
 - run_and_stream_save_dalfox_vulns_task: Dalfox 流式执行并保存漏洞结果
 - run_and_stream_save_nuclei_vulns_task: Nuclei 流式执行并保存漏洞结果
 """
 from .export_endpoints_task import export_endpoints_task
+from .export_websites_task import export_websites_task
 from .run_vuln_tool_task import run_vuln_tool_task
 from .run_and_stream_save_dalfox_vulns_task import run_and_stream_save_dalfox_vulns_task
 from .run_and_stream_save_nuclei_vulns_task import run_and_stream_save_nuclei_vulns_task
 __all__ = [
     "export_endpoints_task",
+    "export_websites_task",
     "run_vuln_tool_task",
     "run_and_stream_save_dalfox_vulns_task",
     "run_and_stream_save_nuclei_vulns_task",
View File

@@ -1,118 +1,74 @@
"""导出 Endpoint URL 到文件的 Task
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库导出
2. Provider 模式:使用 TargetProvider 从任意数据源导出
使用 TargetProvider 从任意数据源导出 URL。
数据源优先级(回退链,仅传统模式):
1. Endpoint.url - 最精细的 URL含路径、参数等
2. WebSite.url - 站点级别 URL
3. 默认生成 - 根据 Target 类型生成 http(s)://target_name
数据源Endpoint为空时回退到默认 URL
"""
import logging
from typing import Dict, Optional
from typing import Dict
from pathlib import Path
from prefect import task
from apps.scan.services.target_export_service import (
export_urls_with_fallback,
DataSource,
)
from apps.scan.providers import TargetProvider, DatabaseTargetProvider
from apps.scan.providers import TargetProvider
logger = logging.getLogger(__name__)
@task(name="export_endpoints")
def export_endpoints_task(
target_id: Optional[int] = None,
output_file: str = "",
provider: Optional[TargetProvider] = None,
batch_size: int = 1000,
output_file: str,
provider: TargetProvider,
) -> Dict[str, object]:
"""导出目标下的所有 Endpoint URL 到文本文件。
支持两种模式:
1. 传统模式(向后兼容):传入 target_id从数据库导出
2. Provider 模式:传入 provider从任意数据源导出
数据源优先级(回退链,仅传统模式):
1. Endpoint 表 - 最精细的 URL含路径、参数等
2. WebSite 表 - 站点级别 URL
3. 默认生成 - 根据 Target 类型生成 http(s)://target_name
数据源优先级Endpoint → 默认生成
Args:
target_id: 目标 ID传统模式向后兼容
output_file: 输出文件路径(绝对路径)
provider: TargetProvider 实例(新模式)
batch_size: 每次从数据库迭代的批大小
provider: TargetProvider 实例
Returns:
dict: {
"success": bool,
"output_file": str,
"total_count": int,
"source": str, # 数据来源: "endpoint" | "website" | "default" | "none" | "provider"
"source": str, # endpoint | default
}
"""
# 参数验证:至少提供一个
if target_id is None and provider is None:
raise ValueError("必须提供 target_id 或 provider 参数之一")
# Provider 模式:使用 TargetProvider 导出
if provider is not None:
logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
return _export_with_provider(output_file, provider)
# 传统模式:使用 export_urls_with_fallback
logger.info("使用传统模式 - Target ID: %d", target_id)
result = export_urls_with_fallback(
target_id=target_id,
output_file=output_file,
sources=[DataSource.ENDPOINT, DataSource.WEBSITE, DataSource.DEFAULT],
batch_size=batch_size,
)
logger.info(
"URL 导出完成 - source=%s, count=%d, tried=%s",
result['source'], result['total_count'], result['tried_sources']
)
return {
"success": result['success'],
"output_file": result['output_file'],
"total_count": result['total_count'],
"source": result['source'],
}
if provider is None:
raise ValueError("必须提供 provider 参数")
logger.info("导出 URL - Provider: %s", type(provider).__name__)
def _export_with_provider(output_file: str, provider: TargetProvider) -> Dict[str, object]:
"""使用 Provider 导出 URL"""
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
# 获取数据,为空时回退到默认 URL
urls = list(provider.iter_endpoints())
source = "endpoint"
if not urls:
logger.info("Endpoint 为空,生成默认 URL")
urls = list(provider.iter_default_urls())
source = "default"
# 写入文件
total_count = 0
blacklist_filter = provider.get_blacklist_filter()
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for url in provider.iter_urls():
# 应用黑名单过滤(如果有)
if blacklist_filter and not blacklist_filter.is_allowed(url):
continue
for url in urls:
f.write(f"{url}\n")
total_count += 1
if total_count % 1000 == 0:
logger.info("已导出 %d 个 URL...", total_count)
logger.info("✓ URL 导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))
logger.info(
"✓ URL 导出完成 - 来源: %s, 总数: %d, 文件: %s",
source, total_count, str(output_path)
)
return {
"success": True,
"output_file": str(output_path),
"total_count": total_count,
"source": "provider",
"source": source,
}

View File

@@ -0,0 +1,73 @@
"""导出 WebSite URL 到文件的 Task
使用 TargetProvider 从任意数据源导出 URL。
数据源WebSite为空时回退到默认 URL
"""
import logging
from pathlib import Path
from prefect import task
from apps.scan.providers import TargetProvider
logger = logging.getLogger(__name__)
@task(name="export_websites_for_vuln_scan")
def export_websites_task(
output_file: str,
provider: TargetProvider,
) -> dict:
"""导出目标下的所有 WebSite URL 到文本文件。
数据源优先级WebSite → 默认生成
Args:
output_file: 输出文件路径(绝对路径)
provider: TargetProvider 实例
Returns:
dict: {
"success": bool,
"output_file": str,
"total_count": int,
"source": str, # website | default
}
"""
if provider is None:
raise ValueError("必须提供 provider 参数")
logger.info("导出 URL - Provider: %s", type(provider).__name__)
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
# 获取数据,为空时回退到默认 URL
urls = list(provider.iter_websites())
source = "website"
if not urls:
logger.info("WebSite 为空,生成默认 URL")
urls = list(provider.iter_default_urls())
source = "default"
# 写入文件
total_count = 0
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for url in urls:
f.write(f"{url}\n")
total_count += 1
logger.info(
"✓ URL 导出完成 - 来源: %s, 总数: %d, 文件: %s",
source, total_count, str(output_path)
)
return {
"success": True,
"output_file": str(output_path),
"total_count": total_count,
"source": source,
}
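Because the task's only dependency is the provider interface, it can be exercised directly with a stub. A hedged pytest-style sketch (the stub and assertions are invented; .fn runs the Prefect task's wrapped function without the engine):

from pathlib import Path

class _EmptyProvider:
    """Stub that forces the default-URL fallback branch."""
    def iter_websites(self):
        return iter(())
    def iter_default_urls(self):
        yield "http://example.com"

def test_export_websites_falls_back_to_default(tmp_path: Path):
    out = tmp_path / "websites.txt"
    result = export_websites_task.fn(output_file=str(out), provider=_EmptyProvider())
    assert result["source"] == "default"
    assert out.read_text(encoding="utf-8").splitlines() == ["http://example.com"]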

View File

@@ -410,6 +410,14 @@ class CommandExecutor:
         # 关键修复:确保进程树被清理
         if process:
             self._kill_process_tree(process)
+            # 回收子进程,避免产生 zombie 进程
+            try:
+                process.wait(timeout=GRACEFUL_SHUTDOWN_TIMEOUT)
+            except subprocess.TimeoutExpired:
+                # kill 之后仍未退出:避免阻塞,继续清理后续资源
+                pass
+            except Exception:
+                pass
         # 关闭文件句柄
         if log_file_handle:
View File

@@ -1,94 +1,76 @@
-from rest_framework import viewsets, status
-from rest_framework.decorators import action
-from rest_framework.response import Response
-from rest_framework.exceptions import NotFound, APIException
-from rest_framework.filters import SearchFilter
-from django_filters.rest_framework import DjangoFilterBackend
-from django.core.exceptions import ObjectDoesNotExist, ValidationError
-from django.db.utils import DatabaseError, IntegrityError, OperationalError
+"""扫描任务视图集"""
+import logging
-from apps.common.response_helpers import success_response, error_response
+from django.core.exceptions import ObjectDoesNotExist, ValidationError
+from django.db.utils import DatabaseError, IntegrityError, OperationalError
+from django_filters.rest_framework import DjangoFilterBackend
+from rest_framework import status, viewsets
+from rest_framework.decorators import action
+from rest_framework.filters import SearchFilter
+from apps.common.definitions import ScanStatus
 from apps.common.error_codes import ErrorCodes
 from apps.scan.utils.config_merger import ConfigConflictError
+from apps.common.pagination import BasePagination
+from apps.common.response_helpers import error_response, success_response
+from apps.targets.repositories import DjangoOrganizationRepository, DjangoTargetRepository
+from ..models import Scan
+from ..serializers import (
+    InitiateScanSerializer,
+    QuickScanSerializer,
+    ScanHistorySerializer,
+    ScanSerializer,
+)
+from ..services.quick_scan_service import QuickScanService
+from ..services.scan_input_target_service import ScanInputTargetService
+from ..services.scan_service import ScanService
+logger = logging.getLogger(__name__)
-from ..models import Scan, ScheduledScan
-from ..serializers import (
-    ScanSerializer, ScanHistorySerializer, QuickScanSerializer,
-    InitiateScanSerializer, ScheduledScanSerializer, CreateScheduledScanSerializer,
-    UpdateScheduledScanSerializer, ToggleScheduledScanSerializer
-)
-from ..services.scan_service import ScanService
-from ..services.scheduled_scan_service import ScheduledScanService
-from ..repositories import ScheduledScanDTO
-from apps.targets.services.target_service import TargetService
-from apps.targets.services.organization_service import OrganizationService
-from apps.engine.services.engine_service import EngineService
-from apps.common.definitions import ScanStatus
-from apps.common.pagination import BasePagination
+def _handle_database_error():
+    """处理数据库错误的通用响应"""
+    return error_response(
+        code=ErrorCodes.SERVER_ERROR,
+        message='Database error',
+        status_code=status.HTTP_503_SERVICE_UNAVAILABLE
+    )
 class ScanViewSet(viewsets.ModelViewSet):
     """扫描任务视图集"""
     serializer_class = ScanSerializer
     pagination_class = BasePagination
     filter_backends = [DjangoFilterBackend, SearchFilter]
-    filterset_fields = ['target']  # 支持 ?target=123 过滤
-    search_fields = ['target__name']  # 按目标名称搜索
+    filterset_fields = ['target']
+    search_fields = ['target__name']
     def get_queryset(self):
-        """优化查询集提升API性能
-        查询优化策略:
-        - select_related: 预加载 target 和 engine一对一/多对一关系,使用 JOIN
-        - 移除 prefetch_related: 避免加载大量资产数据到内存
-        - order_by: 按创建时间降序排列(最新创建的任务排在最前面)
-        性能优化原理:
-        - 列表页使用缓存统计字段cached_*_count避免实时 COUNT 查询
-        - 序列化器:严格验证缓存字段,确保数据一致性
-        - 分页场景每页只显示10条记录查询高效
-        - 避免大数据加载:不再预加载所有关联的资产数据
-        """
-        # 只保留必要的 select_related移除所有 prefetch_related
+        """优化查询集提升API性能"""
         scan_service = ScanService()
-        queryset = scan_service.get_all_scans(prefetch_relations=True)
-        return queryset
+        return scan_service.get_all_scans(prefetch_relations=True)
     def get_serializer_class(self):
-        """根据不同的 action 返回不同的序列化器
-        - list action: 使用 ScanHistorySerializer包含 summary 和 progress
-        - retrieve action: 使用 ScanHistorySerializer包含 summary 和 progress
-        - 其他 action: 使用标准的 ScanSerializer
-        """
+        """根据不同的 action 返回不同的序列化器"""
         if self.action in ['list', 'retrieve']:
             return ScanHistorySerializer
         return ScanSerializer
     def destroy(self, request, *args, **kwargs):
-        """
-        删除单个扫描任务(两阶段删除)
-        1. 软删除:立即对用户不可见
-        2. 硬删除:后台异步执行
-        """
+        """删除单个扫描任务(两阶段删除)"""
         try:
             scan = self.get_object()
             scan_service = ScanService()
             result = scan_service.delete_scans_two_phase([scan.id])
-            return success_response(
-                data={
-                    'scanId': scan.id,
-                    'deletedCount': result['soft_deleted_count'],
-                    'deletedScans': result['scan_names']
-                }
-            )
+            return success_response(data={
+                'scanId': scan.id,
+                'deletedCount': result['soft_deleted_count'],
+                'deletedScans': result['scan_names']
+            })
         except Scan.DoesNotExist:
             return error_response(
                 code=ErrorCodes.NOT_FOUND,
@@ -100,80 +82,57 @@ class ScanViewSet(viewsets.ModelViewSet):
                 message=str(e),
                 status_code=status.HTTP_404_NOT_FOUND
             )
-        except Exception as e:
+        except Exception:
+            logger.exception("删除扫描任务时发生错误")
             return error_response(
                 code=ErrorCodes.SERVER_ERROR,
                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
             )
     @action(detail=False, methods=['post'])
     def quick(self, request):
         """
         快速扫描接口
         功能:
         1. 接收目标列表和 YAML 配置
         2. 自动解析输入(支持 URL、域名、IP、CIDR
         3. 批量创建 Target、Website、Endpoint 资产
-        4. 立即发起批量扫描
-        请求参数:
-        {
-            "targets": [{"name": "example.com"}, {"name": "https://example.com/api"}],
-            "configuration": "subdomain_discovery:\n  enabled: true\n  ...",
-            "engine_ids": [1, 2],  // 可选,用于记录
-            "engine_names": ["引擎A", "引擎B"]  // 可选,用于记录
-        }
-        支持的输入格式:
-        - 域名: example.com
-        - IP: 192.168.1.1
-        - CIDR: 10.0.0.0/8
-        - URL: https://example.com/api/v1
+        4. 立即发起批量扫描scan_mode='quick'
+        5. 将用户输入写入 ScanInputTarget 表
         """
-        from ..services.quick_scan_service import QuickScanService
         serializer = QuickScanSerializer(data=request.data)
         serializer.is_valid(raise_exception=True)
-        targets_data = serializer.validated_data['targets']
-        configuration = serializer.validated_data['configuration']
-        engine_ids = serializer.validated_data.get('engine_ids', [])
-        engine_names = serializer.validated_data.get('engine_names', [])
+        data = serializer.validated_data
         try:
-            # 提取输入字符串列表
-            inputs = [t['name'] for t in targets_data]
-            # 1. 使用 QuickScanService 解析输入并创建资产
-            quick_scan_service = QuickScanService()
-            result = quick_scan_service.process_quick_scan(inputs, engine_ids[0] if engine_ids else None)
-            targets = result['targets']
-            if not targets:
+            inputs = [t['name'] for t in data['targets']]
+            # 1. 解析输入并创建资产
+            result = QuickScanService().process_quick_scan(inputs)
+            if not result['targets']:
                 return error_response(
                     code=ErrorCodes.VALIDATION_ERROR,
                     message='No valid targets for scanning',
                     details=result.get('errors', []),
                     status_code=status.HTTP_400_BAD_REQUEST
                 )
-            # 2. 直接使用前端传递的配置创建扫描
-            scan_service = ScanService()
-            created_scans = scan_service.create_scans(
-                targets=targets,
-                engine_ids=engine_ids,
-                engine_names=engine_names,
-                yaml_configuration=configuration
+            # 2. 创建扫描scan_mode='quick'
+            created_scans = ScanService().create_scans(
+                targets=result['targets'],
+                engine_ids=data.get('engine_ids', []),
+                engine_names=data.get('engine_names', []),
+                yaml_configuration=data['configuration'],
+                scan_mode='quick'
             )
-            # 检查是否成功创建扫描任务
             if not created_scans:
                 return error_response(
                     code=ErrorCodes.VALIDATION_ERROR,
-                    message='No scan tasks were created. All targets may already have active scans.',
+                    message='No scan tasks were created. '
+                            'All targets may already have active scans.',
                     details={
                         'targetStats': result['target_stats'],
                         'assetStats': result['asset_stats'],
@@ -181,317 +140,210 @@ class ScanViewSet(viewsets.ModelViewSet):
                     },
                     status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
                 )
-            # 序列化返回结果
-            scan_serializer = ScanSerializer(created_scans, many=True)
+            # 3. 将用户输入写入 ScanInputTarget 表
+            scan_input_service = ScanInputTargetService()
+            target_inputs_map = result.get('target_inputs_map', {})
+            for scan in created_scans:
+                inputs_for_target = target_inputs_map.get(scan.target.name, [])
+                if inputs_for_target:
+                    scan_input_service.bulk_create(scan.id, inputs_for_target)
             return success_response(
                 data={
                     'count': len(created_scans),
                     'targetStats': result['target_stats'],
                     'assetStats': result['asset_stats'],
                     'errors': result.get('errors', []),
-                    'scans': scan_serializer.data
+                    'scans': ScanSerializer(created_scans, many=True).data
                 },
                 status_code=status.HTTP_201_CREATED
             )
         except ValidationError as e:
             return error_response(
                 code=ErrorCodes.VALIDATION_ERROR,
                 message=str(e),
                 status_code=status.HTTP_400_BAD_REQUEST
             )
-        except Exception as e:
+        except (DatabaseError, IntegrityError, OperationalError):
+            logger.exception("快速扫描启动失败")
-            return error_response(
-                code=ErrorCodes.SERVER_ERROR,
-                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
-            )
+            return _handle_database_error()
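For reference, the request shape the quick endpoint accepts (taken from the docstring the refactor trimmed; the base URL, auth, and response envelope below are assumptions about the deployment):

import requests

payload = {
    "targets": [{"name": "example.com"}, {"name": "https://example.com/api"}],
    "configuration": "subdomain_discovery:\n  enabled: true\n",
    "engine_ids": [1],          # optional, recorded with the scan
    "engine_names": ["引擎A"],   # optional, recorded with the scan
}
resp = requests.post("http://localhost:8000/api/scans/quick/", json=payload, timeout=30)
resp.raise_for_status()
print(resp.json())  # expected to include count, targetStats, assetStats, scans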
     @action(detail=False, methods=['post'])
     def initiate(self, request):
         """
         发起扫描任务
         请求参数:
         - organization_id: 组织ID (int, 可选)
         - target_id: 目标ID (int, 可选)
         - configuration: YAML 配置字符串 (str, 必填)
         - engine_ids: 扫描引擎ID列表 (list[int], 必填)
         - engine_names: 引擎名称列表 (list[str], 必填)
         注意: organization_id 和 target_id 二选一
         返回:
         - 扫描任务详情(单个或多个)
         """
-        # 使用 serializer 验证请求数据
         serializer = InitiateScanSerializer(data=request.data)
         serializer.is_valid(raise_exception=True)
-        # 获取验证后的数据
-        organization_id = serializer.validated_data.get('organization_id')
-        target_id = serializer.validated_data.get('target_id')
-        configuration = serializer.validated_data['configuration']
-        engine_ids = serializer.validated_data['engine_ids']
-        engine_names = serializer.validated_data['engine_names']
+        data = serializer.validated_data
         try:
-            # 获取目标列表
-            scan_service = ScanService()
-            if organization_id:
-                from apps.targets.repositories import DjangoOrganizationRepository
-                org_repo = DjangoOrganizationRepository()
-                organization = org_repo.get_by_id(organization_id)
-                if not organization:
-                    raise ObjectDoesNotExist(f'Organization ID {organization_id} 不存在')
-                targets = org_repo.get_targets(organization_id)
-                if not targets:
-                    raise ValidationError(f'组织 ID {organization_id} 下没有目标')
-            else:
-                from apps.targets.repositories import DjangoTargetRepository
-                target_repo = DjangoTargetRepository()
-                target = target_repo.get_by_id(target_id)
-                if not target:
-                    raise ObjectDoesNotExist(f'Target ID {target_id} 不存在')
-                targets = [target]
-            # 直接使用前端传递的配置创建扫描
-            created_scans = scan_service.create_scans(
-                targets=targets,
-                engine_ids=engine_ids,
-                engine_names=engine_names,
-                yaml_configuration=configuration
+            targets = self._get_targets_for_scan(
+                data.get('organization_id'),
+                data.get('target_id')
             )
-            # 检查是否成功创建扫描任务
+            created_scans = ScanService().create_scans(
+                targets=targets,
+                engine_ids=data['engine_ids'],
+                engine_names=data['engine_names'],
+                yaml_configuration=data['configuration']
+            )
             if not created_scans:
                 return error_response(
                     code=ErrorCodes.VALIDATION_ERROR,
-                    message='No scan tasks were created. All targets may already have active scans.',
+                    message='No scan tasks were created. '
+                            'All targets may already have active scans.',
                     status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
                 )
-            # 序列化返回结果
-            scan_serializer = ScanSerializer(created_scans, many=True)
             return success_response(
                 data={
                     'count': len(created_scans),
-                    'scans': scan_serializer.data
+                    'scans': ScanSerializer(created_scans, many=True).data
                 },
                 status_code=status.HTTP_201_CREATED
             )
         except ObjectDoesNotExist as e:
-            # 资源不存在错误(由 service 层抛出)
             return error_response(
                 code=ErrorCodes.NOT_FOUND,
                 message=str(e),
                 status_code=status.HTTP_404_NOT_FOUND
             )
         except ValidationError as e:
-            # 参数验证错误(由 service 层抛出)
             return error_response(
                 code=ErrorCodes.VALIDATION_ERROR,
                 message=str(e),
                 status_code=status.HTTP_400_BAD_REQUEST
             )
         except (DatabaseError, IntegrityError, OperationalError):
-            # 数据库错误
-            return error_response(
-                code=ErrorCodes.SERVER_ERROR,
-                message='Database error',
-                status_code=status.HTTP_503_SERVICE_UNAVAILABLE
-            )
+            return _handle_database_error()
     # 所有快照相关的 action 和 export 已迁移到 asset/views.py 中的快照 ViewSet
     # GET /api/scans/{id}/subdomains/ -> SubdomainSnapshotViewSet
     # GET /api/scans/{id}/subdomains/export/ -> SubdomainSnapshotViewSet.export
     # GET /api/scans/{id}/websites/ -> WebsiteSnapshotViewSet
     # GET /api/scans/{id}/websites/export/ -> WebsiteSnapshotViewSet.export
     # GET /api/scans/{id}/directories/ -> DirectorySnapshotViewSet
     # GET /api/scans/{id}/directories/export/ -> DirectorySnapshotViewSet.export
     # GET /api/scans/{id}/endpoints/ -> EndpointSnapshotViewSet
     # GET /api/scans/{id}/endpoints/export/ -> EndpointSnapshotViewSet.export
     # GET /api/scans/{id}/ip-addresses/ -> HostPortMappingSnapshotViewSet
     # GET /api/scans/{id}/ip-addresses/export/ -> HostPortMappingSnapshotViewSet.export
     # GET /api/scans/{id}/vulnerabilities/ -> VulnerabilitySnapshotViewSet
+    def _get_targets_for_scan(self, organization_id, target_id):
+        """根据组织ID或目标ID获取扫描目标列表"""
+        if organization_id:
+            org_repo = DjangoOrganizationRepository()
+            organization = org_repo.get_by_id(organization_id)
+            if not organization:
+                raise ObjectDoesNotExist(f'Organization ID {organization_id} 不存在')
+            targets = org_repo.get_targets(organization_id)
+            if not targets:
+                raise ValidationError(f'组织 ID {organization_id} 下没有目标')
+            return targets
+        target_repo = DjangoTargetRepository()
+        target = target_repo.get_by_id(target_id)
+        if not target:
+            raise ObjectDoesNotExist(f'Target ID {target_id} 不存在')
+        return [target]
     @action(detail=False, methods=['post', 'delete'], url_path='bulk-delete')
     def bulk_delete(self, request):
-        """
-        批量删除扫描记录
-        请求参数:
-        - ids: 扫描ID列表 (list[int], 必填)
-        示例请求:
-        POST /api/scans/bulk-delete/
-        {
-            "ids": [1, 2, 3]
-        }
-        返回:
-        - message: 成功消息
-        - deletedCount: 实际删除的记录数
-        注意:
-        - 使用级联删除,会同时删除关联的子域名、端点等数据
-        - 只删除存在的记录不存在的ID会被忽略
-        """
+        """批量删除扫描记录"""
         ids = request.data.get('ids', [])
-        # 参数验证
         if not ids:
             return error_response(
                 code=ErrorCodes.VALIDATION_ERROR,
                 message='Missing required parameter: ids',
                 status_code=status.HTTP_400_BAD_REQUEST
             )
         if not isinstance(ids, list):
             return error_response(
                 code=ErrorCodes.VALIDATION_ERROR,
                 message='ids must be an array',
                 status_code=status.HTTP_400_BAD_REQUEST
             )
         if not all(isinstance(i, int) for i in ids):
             return error_response(
                 code=ErrorCodes.VALIDATION_ERROR,
                 message='All elements in ids array must be integers',
                 status_code=status.HTTP_400_BAD_REQUEST
             )
         try:
-            # 使用 Service 层批量删除(两阶段删除)
             scan_service = ScanService()
             result = scan_service.delete_scans_two_phase(ids)
-            return success_response(
-                data={
-                    'deletedCount': result['soft_deleted_count'],
-                    'deletedScans': result['scan_names']
-                }
-            )
+            return success_response(data={
+                'deletedCount': result['soft_deleted_count'],
+                'deletedScans': result['scan_names']
+            })
         except ValueError as e:
-            # 未找到记录
             return error_response(
                 code=ErrorCodes.NOT_FOUND,
                 message=str(e),
                 status_code=status.HTTP_404_NOT_FOUND
             )
-        except Exception as e:
+        except (DatabaseError, IntegrityError, OperationalError):
+            logger.exception("批量删除扫描任务时发生错误")
-            return error_response(
-                code=ErrorCodes.SERVER_ERROR,
-                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
-            )
+            return _handle_database_error()
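delete_scans_two_phase is only referenced here; conceptually, phase one soft-deletes so the rows vanish from the UI immediately, and phase two hard-deletes (with cascades) in the background. A rough sketch under loudly invented assumptions -- the is_deleted field and schedule_hard_delete hook do not appear in this diff:

from typing import List

def delete_scans_two_phase(scan_ids: List[int]) -> dict:
    """Two-phase delete sketch; field names and the async hook are hypothetical."""
    scans = list(Scan.objects.filter(id__in=scan_ids))
    if not scans:
        raise ValueError("No matching scans found")
    names = [scan.target.name for scan in scans]
    # Phase 1: soft delete -- immediately hidden from users
    Scan.objects.filter(id__in=scan_ids).update(is_deleted=True)
    # Phase 2: cascading hard delete runs asynchronously (hypothetical hook)
    schedule_hard_delete(scan_ids)
    return {'soft_deleted_count': len(scans), 'scan_names': names}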
     @action(detail=False, methods=['get'])
-    def statistics(self, request):
-        """
-        获取扫描统计数据
-        返回扫描任务的汇总统计信息,用于仪表板和扫描历史页面。
-        使用缓存字段聚合查询,性能优异。
-        返回:
-        - total: 总扫描次数
-        - running: 运行中的扫描数量
-        - completed: 已完成的扫描数量
-        - failed: 失败的扫描数量
-        - totalVulns: 总共发现的漏洞数量
-        - totalSubdomains: 总共发现的子域名数量
-        - totalEndpoints: 总共发现的端点数量
-        - totalAssets: 总资产数
-        """
+    def statistics(self, request):  # pylint: disable=unused-argument
+        """获取扫描统计数据"""
         try:
-            # 使用 Service 层获取统计数据
-            scan_service = ScanService()
-            stats = scan_service.get_statistics()
-            return success_response(
-                data={
-                    'total': stats['total'],
-                    'running': stats['running'],
-                    'completed': stats['completed'],
-                    'failed': stats['failed'],
-                    'totalVulns': stats['total_vulns'],
-                    'totalSubdomains': stats['total_subdomains'],
-                    'totalEndpoints': stats['total_endpoints'],
-                    'totalWebsites': stats['total_websites'],
-                    'totalAssets': stats['total_assets'],
-                }
-            )
+            stats = ScanService().get_statistics()
+            return success_response(data={
+                'total': stats['total'],
+                'running': stats['running'],
+                'completed': stats['completed'],
+                'failed': stats['failed'],
+                'totalVulns': stats['total_vulns'],
+                'totalSubdomains': stats['total_subdomains'],
+                'totalEndpoints': stats['total_endpoints'],
+                'totalWebsites': stats['total_websites'],
+                'totalAssets': stats['total_assets'],
+            })
         except (DatabaseError, OperationalError):
-            return error_response(
-                code=ErrorCodes.SERVER_ERROR,
-                message='Database error',
-                status_code=status.HTTP_503_SERVICE_UNAVAILABLE
-            )
+            return _handle_database_error()
     @action(detail=True, methods=['post'])
     def stop(self, request, pk=None):  # pylint: disable=unused-argument
-        """
-        停止扫描任务
-        URL: POST /api/scans/{id}/stop/
-        功能:
-        - 终止正在运行或初始化的扫描任务
-        - 更新扫描状态为 CANCELLED
-        状态限制:
-        - 只能停止 RUNNING 或 INITIATED 状态的扫描
-        - 已完成、失败或取消的扫描无法停止
-        返回:
-        - message: 成功消息
-        - revokedTaskCount: 取消的 Flow Run 数量
-        """
+        """停止扫描任务"""
         try:
-            # 使用 Service 层处理停止逻辑
             scan_service = ScanService()
             success, revoked_count = scan_service.stop_scan(scan_id=pk)
             if not success:
-                # 检查是否是状态不允许的问题
                 scan = scan_service.get_scan(scan_id=pk, prefetch_relations=False)
                 if scan and scan.status not in [ScanStatus.RUNNING, ScanStatus.INITIATED]:
                     return error_response(
                         code=ErrorCodes.BAD_REQUEST,
-                        message=f'Cannot stop scan: current status is {ScanStatus(scan.status).label}',
+                        message=f'Cannot stop scan: current status is '
+                                f'{ScanStatus(scan.status).label}',
                         status_code=status.HTTP_400_BAD_REQUEST
                     )
-                # 其他失败原因
                 return error_response(
                     code=ErrorCodes.SERVER_ERROR,
                     status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
                 )
-            return success_response(
-                data={'revokedTaskCount': revoked_count}
-            )
+            return success_response(data={'revokedTaskCount': revoked_count})
         except ObjectDoesNotExist:
             return error_response(
                 code=ErrorCodes.NOT_FOUND,
                 message=f'Scan ID {pk} not found',
                 status_code=status.HTTP_404_NOT_FOUND
             )
         except (DatabaseError, IntegrityError, OperationalError):
-            return error_response(
-                code=ErrorCodes.SERVER_ERROR,
-                message='Database error',
-                status_code=status.HTTP_503_SERVICE_UNAVAILABLE
-            )
+            return _handle_database_error()

View File

@@ -54,7 +54,7 @@ export function EnginePresetSelector({
   engines.forEach(e => {
     const caps = parseEngineCapabilities(e.configuration || "")
-    const hasRecon = caps.includes("subdomain_discovery") || caps.includes("port_scan") || caps.includes("site_scan") || caps.includes("directory_scan") || caps.includes("url_fetch")
+    const hasRecon = caps.includes("subdomain_discovery") || caps.includes("port_scan") || caps.includes("site_scan") || caps.includes("fingerprint_detect") || caps.includes("directory_scan") || caps.includes("url_fetch") || caps.includes("screenshot")
     const hasVuln = caps.includes("vuln_scan")
     if (hasRecon && hasVuln) {