Mirror of https://github.com/yyhuni/xingrin.git (synced 2026-01-31 11:46:16 +08:00)
Commit: 新增:重构导出逻辑代码,加入黑名单过滤 (refactor the export logic and add blacklist filtering)
@@ -140,28 +140,7 @@ def _get_max_workers(tool_config: dict, default: int = DEFAULT_MAX_WORKERS) -> i
    return default


def _setup_directory_scan_directory(scan_workspace_dir: str) -> Path:
    """
    创建并验证目录扫描工作目录

    Args:
        scan_workspace_dir: 扫描工作空间目录

    Returns:
        Path: 目录扫描目录路径

    Raises:
        RuntimeError: 目录创建或验证失败
    """
    directory_scan_dir = Path(scan_workspace_dir) / 'directory_scan'
    directory_scan_dir.mkdir(parents=True, exist_ok=True)

    if not directory_scan_dir.is_dir():
        raise RuntimeError(f"目录扫描目录创建失败: {directory_scan_dir}")
    if not os.access(directory_scan_dir, os.W_OK):
        raise RuntimeError(f"目录扫描目录不可写: {directory_scan_dir}")

    return directory_scan_dir


def _export_site_urls(target_id: int, target_name: str, directory_scan_dir: Path) -> tuple[str, int]:

@@ -640,7 +619,8 @@ def directory_scan_flow(
        raise ValueError("enabled_tools 不能为空")

    # Step 0: 创建工作目录
    directory_scan_dir = _setup_directory_scan_directory(scan_workspace_dir)
    from apps.scan.utils import setup_scan_directory
    directory_scan_dir = setup_scan_directory(scan_workspace_dir, 'directory_scan')

    # Step 1: 导出站点 URL(支持懒加载)
    sites_file, site_count = _export_site_urls(target_id, target_name, directory_scan_dir)
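Every stage flow in this commit swaps its private _setup_*_directory helper for the shared setup_scan_directory utility from apps.scan.utils. The new helper's body is not part of this diff; a minimal sketch of what it presumably does, mirroring the removed per-flow helpers (create the sub-directory, then verify it exists and is writable):

# Hypothetical sketch only - the real apps.scan.utils.setup_scan_directory is
# not shown in this commit; the logic below mirrors the deleted helpers.
import os
from pathlib import Path

def setup_scan_directory(scan_workspace_dir: str, subdir: str) -> Path:
    """Create and validate a per-stage scan directory, e.g. 'directory_scan'."""
    scan_dir = Path(scan_workspace_dir) / subdir
    scan_dir.mkdir(parents=True, exist_ok=True)

    if not scan_dir.is_dir():
        raise RuntimeError(f"Failed to create scan directory: {scan_dir}")
    if not os.access(scan_dir, os.W_OK):
        raise RuntimeError(f"Scan directory is not writable: {scan_dir}")

    return scan_dir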
@@ -64,28 +64,7 @@ def calculate_fingerprint_detect_timeout(
    return max(min_timeout, timeout)


def _setup_fingerprint_detect_directory(scan_workspace_dir: str) -> Path:
    """
    创建并验证指纹识别工作目录

    Args:
        scan_workspace_dir: 扫描工作空间目录

    Returns:
        Path: 指纹识别目录路径

    Raises:
        RuntimeError: 目录创建或验证失败
    """
    fingerprint_dir = Path(scan_workspace_dir) / 'fingerprint_detect'
    fingerprint_dir.mkdir(parents=True, exist_ok=True)

    if not fingerprint_dir.is_dir():
        raise RuntimeError(f"指纹识别目录创建失败: {fingerprint_dir}")
    if not os.access(fingerprint_dir, os.W_OK):
        raise RuntimeError(f"指纹识别目录不可写: {fingerprint_dir}")

    return fingerprint_dir


def _export_urls(

@@ -313,7 +292,8 @@ def fingerprint_detect_flow(
        source = 'website'

    # Step 0: 创建工作目录
    fingerprint_dir = _setup_fingerprint_detect_directory(scan_workspace_dir)
    from apps.scan.utils import setup_scan_directory
    fingerprint_dir = setup_scan_directory(scan_workspace_dir, 'fingerprint_detect')

    # Step 1: 导出 URL(支持懒加载)
    urls_file, url_count = _export_urls(target_id, fingerprint_dir, target_name, source)
@@ -30,7 +30,7 @@ from apps.scan.handlers import (
    on_initiate_scan_flow_failed,
)
from prefect.futures import wait
from apps.scan.tasks.workspace_tasks import create_scan_workspace_task
from apps.scan.utils import setup_scan_workspace
from apps.scan.orchestrators import FlowOrchestrator

logger = logging.getLogger(__name__)

@@ -110,7 +110,7 @@ def initiate_scan_flow(
    )

    # ==================== Task 1: 创建 Scan 工作空间 ====================
    scan_workspace_path = create_scan_workspace_task(scan_workspace_dir)
    scan_workspace_path = setup_scan_workspace(scan_workspace_dir)

    # ==================== Task 2: 获取引擎配置 ====================
    from apps.scan.models import Scan
@@ -154,28 +154,7 @@ def _parse_port_count(tool_config: dict) -> int:
    return 100


def _setup_port_scan_directory(scan_workspace_dir: str) -> Path:
    """
    创建并验证端口扫描工作目录

    Args:
        scan_workspace_dir: 扫描工作空间目录

    Returns:
        Path: 端口扫描目录路径

    Raises:
        RuntimeError: 目录创建或验证失败
    """
    port_scan_dir = Path(scan_workspace_dir) / 'port_scan'
    port_scan_dir.mkdir(parents=True, exist_ok=True)

    if not port_scan_dir.is_dir():
        raise RuntimeError(f"端口扫描目录创建失败: {port_scan_dir}")
    if not os.access(port_scan_dir, os.W_OK):
        raise RuntimeError(f"端口扫描目录不可写: {port_scan_dir}")

    return port_scan_dir


def _export_scan_targets(target_id: int, port_scan_dir: Path) -> tuple[str, int, str]:

@@ -442,7 +421,8 @@ def port_scan_flow(
    )

    # Step 0: 创建工作目录
    port_scan_dir = _setup_port_scan_directory(scan_workspace_dir)
    from apps.scan.utils import setup_scan_directory
    port_scan_dir = setup_scan_directory(scan_workspace_dir, 'port_scan')

    # Step 1: 导出扫描目标列表到文件(根据 Target 类型自动决定内容)
    targets_file, target_count, target_type = _export_scan_targets(target_id, port_scan_dir)
@@ -85,28 +85,7 @@ def calculate_timeout_by_line_count(
    return min_timeout


def _setup_site_scan_directory(scan_workspace_dir: str) -> Path:
    """
    创建并验证站点扫描工作目录

    Args:
        scan_workspace_dir: 扫描工作空间目录

    Returns:
        Path: 站点扫描目录路径

    Raises:
        RuntimeError: 目录创建或验证失败
    """
    site_scan_dir = Path(scan_workspace_dir) / 'site_scan'
    site_scan_dir.mkdir(parents=True, exist_ok=True)

    if not site_scan_dir.is_dir():
        raise RuntimeError(f"站点扫描目录创建失败: {site_scan_dir}")
    if not os.access(site_scan_dir, os.W_OK):
        raise RuntimeError(f"站点扫描目录不可写: {site_scan_dir}")

    return site_scan_dir


def _export_site_urls(target_id: int, site_scan_dir: Path, target_name: str = None) -> tuple[str, int, int]:

@@ -403,7 +382,8 @@ def site_scan_flow(
        raise ValueError("scan_workspace_dir 不能为空")

    # Step 0: 创建工作目录
    site_scan_dir = _setup_site_scan_directory(scan_workspace_dir)
    from apps.scan.utils import setup_scan_directory
    site_scan_dir = setup_scan_directory(scan_workspace_dir, 'site_scan')

    # Step 1: 导出站点 URL
    urls_file, total_urls, association_count = _export_site_urls(
@@ -41,28 +41,7 @@ import subprocess
logger = logging.getLogger(__name__)


def _setup_subdomain_directory(scan_workspace_dir: str) -> Path:
    """
    创建并验证子域名扫描工作目录

    Args:
        scan_workspace_dir: 扫描工作空间目录

    Returns:
        Path: 子域名扫描目录路径

    Raises:
        RuntimeError: 目录创建或验证失败
    """
    result_dir = Path(scan_workspace_dir) / 'subdomain_discovery'
    result_dir.mkdir(parents=True, exist_ok=True)

    if not result_dir.is_dir():
        raise RuntimeError(f"子域名扫描目录创建失败: {result_dir}")
    if not os.access(result_dir, os.W_OK):
        raise RuntimeError(f"子域名扫描目录不可写: {result_dir}")

    return result_dir


def _validate_and_normalize_target(target_name: str) -> str:

@@ -119,12 +98,7 @@ def _run_scans_parallel(

    # 生成时间戳(所有工具共用)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    # TODO: 接入代理池管理系统
    # from apps.proxy.services import proxy_pool
    # proxy_stats = proxy_pool.get_stats()
    # logger.info(f"代理池状态: {proxy_stats['healthy']}/{proxy_stats['total']} 可用")

    failures = [] # 记录命令构建失败的工具
    futures = {}

@@ -417,7 +391,8 @@ def subdomain_discovery_flow(
    )

    # Step 0: 准备工作
    result_dir = _setup_subdomain_directory(scan_workspace_dir)
    from apps.scan.utils import setup_scan_directory
    result_dir = setup_scan_directory(scan_workspace_dir, 'subdomain_discovery')

    # 验证并规范化目标域名
    try:
@@ -42,17 +42,7 @@ SITES_FILE_TOOLS = {'katana'}
POST_PROCESS_TOOLS = {'uro', 'httpx'}


def _setup_url_fetch_directory(scan_workspace_dir: str) -> Path:
    """创建并验证 URL 获取工作目录"""
    url_fetch_dir = Path(scan_workspace_dir) / 'url_fetch'
    url_fetch_dir.mkdir(parents=True, exist_ok=True)

    if not url_fetch_dir.is_dir():
        raise RuntimeError(f"URL 获取目录创建失败: {url_fetch_dir}")
    if not os.access(url_fetch_dir, os.W_OK):
        raise RuntimeError(f"URL 获取目录不可写: {url_fetch_dir}")

    return url_fetch_dir


def _classify_tools(enabled_tools: dict) -> tuple[dict, dict, dict, dict]:

@@ -304,7 +294,8 @@ def url_fetch_flow(

    # Step 1: 准备工作目录
    logger.info("Step 1: 准备工作目录")
    url_fetch_dir = _setup_url_fetch_directory(scan_workspace_dir)
    from apps.scan.utils import setup_scan_directory
    url_fetch_dir = setup_scan_directory(scan_workspace_dir, 'url_fetch')

    # Step 2: 分类工具(按输入类型)
    logger.info("Step 2: 分类工具")
@@ -25,10 +25,7 @@ from .utils import calculate_timeout_by_line_count
logger = logging.getLogger(__name__)


def _setup_vuln_scan_directory(scan_workspace_dir: str) -> Path:
    vuln_scan_dir = Path(scan_workspace_dir) / "vuln_scan"
    vuln_scan_dir.mkdir(parents=True, exist_ok=True)
    return vuln_scan_dir


@flow(

@@ -55,14 +52,14 @@ def endpoints_vuln_scan_flow(
    if not enabled_tools:
        raise ValueError("enabled_tools 不能为空")

    vuln_scan_dir = _setup_vuln_scan_directory(scan_workspace_dir)
    from apps.scan.utils import setup_scan_directory
    vuln_scan_dir = setup_scan_directory(scan_workspace_dir, 'vuln_scan')
    endpoints_file = vuln_scan_dir / "input_endpoints.txt"

    # Step 1: 导出 Endpoint URL
    export_result = export_endpoints_task(
        target_id=target_id,
        output_file=str(endpoints_file),
        target_name=target_name, # 传入 target_name 用于生成默认端点
    )
    total_endpoints = export_result.get("total_count", 0)
@@ -17,6 +17,8 @@ from .scan_state_service import ScanStateService
from .scan_control_service import ScanControlService
from .scan_stats_service import ScanStatsService
from .scheduled_scan_service import ScheduledScanService
from .blacklist_service import BlacklistService
from .target_export_service import TargetExportService

__all__ = [
    'ScanService', # 主入口(向后兼容)

@@ -25,5 +27,7 @@ __all__ = [
    'ScanControlService',
    'ScanStatsService',
    'ScheduledScanService',
    'BlacklistService', # 黑名单过滤服务
    'TargetExportService', # 目标导出服务
]
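The services package now exposes BlacklistService and TargetExportService. The export task hunks later in this commit all follow the same pattern: the task picks the data source as a flat queryset of URLs and delegates writing, blacklist filtering, and the default-URL fallback to the service. A condensed sketch of that call pattern, using the names exactly as they appear in the diff (the service internals themselves are not shown in this commit):

# Call pattern used by the refactored export tasks (see the export task hunks below).
from apps.asset.models import WebSite
from apps.scan.services import TargetExportService, BlacklistService

def export_website_urls(target_id: int, output_file: str, batch_size: int = 1000) -> dict:
    # Task layer decides the data source; the service handles streaming writes,
    # blacklist filtering and the lazy default-URL fallback.
    queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)
    export_service = TargetExportService(blacklist_service=BlacklistService())
    return export_service.export_urls(
        target_id=target_id,
        output_path=output_file,
        queryset=queryset,
        batch_size=batch_size,
    )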
@@ -9,9 +9,6 @@
- Tasks 负责具体操作,Flow 负责编排
"""

# Prefect Tasks
from .workspace_tasks import create_scan_workspace_task

# 子域名发现任务(已重构为多个子任务)
from .subdomain_discovery import (
    run_subdomain_discovery_task,

@@ -30,10 +27,9 @@ from .fingerprint_detect import (
# - finalize_scan_task 已废弃(Handler 统一管理状态)
# - initiate_scan_task 已迁移到 flows/initiate_scan_flow.py
# - cleanup_old_scans_task 已迁移到 flows(cleanup_old_scans_flow)
# - create_scan_workspace_task 已删除,直接使用 setup_scan_workspace()

__all__ = [
    # Prefect Tasks
    'create_scan_workspace_task',
    # 子域名发现任务
    'run_subdomain_discovery_task',
    'merge_and_validate_task',
@@ -1,20 +1,14 @@
|
||||
"""
|
||||
导出站点 URL 到 TXT 文件的 Task
|
||||
|
||||
使用流式处理,避免大量站点导致内存溢出
|
||||
支持默认值模式:如果没有站点,根据 Target 类型生成默认 URL
|
||||
- DOMAIN: http(s)://target_name
|
||||
- IP: http(s)://ip
|
||||
- CIDR: 展开为所有 IP 的 http(s)://ip
|
||||
使用 TargetExportService 统一处理导出逻辑和默认值回退
|
||||
数据源: WebSite.url
|
||||
"""
|
||||
import logging
|
||||
import ipaddress
|
||||
from pathlib import Path
|
||||
from prefect import task
|
||||
|
||||
from apps.asset.repositories import DjangoWebSiteRepository
|
||||
from apps.targets.services import TargetService
|
||||
from apps.targets.models import Target
|
||||
from apps.asset.models import WebSite
|
||||
from apps.scan.services import TargetExportService, BlacklistService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -24,19 +18,22 @@ def export_sites_task(
|
||||
target_id: int,
|
||||
output_file: str,
|
||||
batch_size: int = 1000,
|
||||
target_name: str = None
|
||||
) -> dict:
|
||||
"""
|
||||
导出目标下的所有站点 URL 到 TXT 文件
|
||||
|
||||
使用流式处理,支持大规模数据导出(10万+站点)
|
||||
支持默认值模式:如果没有站点,自动使用默认站点 URL(http(s)://target_name)
|
||||
数据源: WebSite.url
|
||||
|
||||
懒加载模式:
|
||||
- 如果数据库为空,根据 Target 类型生成默认 URL
|
||||
- DOMAIN: http(s)://domain
|
||||
- IP: http(s)://ip
|
||||
- CIDR: 展开为所有 IP 的 URL
|
||||
|
||||
Args:
|
||||
target_id: 目标 ID
|
||||
output_file: 输出文件路径(绝对路径)
|
||||
batch_size: 每次读取的批次大小,默认 1000
|
||||
target_name: 目标名称(用于默认值模式)
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
@@ -49,134 +46,26 @@ def export_sites_task(
|
||||
ValueError: 参数错误
|
||||
IOError: 文件写入失败
|
||||
"""
|
||||
try:
|
||||
# 初始化 Repository
|
||||
repository = DjangoWebSiteRepository()
|
||||
|
||||
logger.info("开始导出站点 URL - Target ID: %d, 输出文件: %s", target_id, output_file)
|
||||
|
||||
# 确保输出目录存在
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 使用 Repository 流式查询站点 URL
|
||||
url_iterator = repository.get_urls_for_export(
|
||||
target_id=target_id,
|
||||
batch_size=batch_size
|
||||
)
|
||||
|
||||
# 流式写入文件
|
||||
total_count = 0
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
for url in url_iterator:
|
||||
# 每次只处理一个 URL,边读边写
|
||||
f.write(f"{url}\n")
|
||||
total_count += 1
|
||||
|
||||
# 每写入 10000 条记录打印一次进度
|
||||
if total_count % 10000 == 0:
|
||||
logger.info("已导出 %d 个站点 URL...", total_count)
|
||||
|
||||
# ==================== 懒加载模式:根据 Target 类型生成默认 URL ====================
|
||||
if total_count == 0:
|
||||
total_count = _write_default_urls(target_id, target_name, output_path)
|
||||
|
||||
logger.info(
|
||||
"✓ 站点 URL 导出完成 - 总数: %d, 文件: %s (%.2f KB)",
|
||||
total_count,
|
||||
str(output_path), # 使用绝对路径
|
||||
output_path.stat().st_size / 1024
|
||||
)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'output_file': str(output_path),
|
||||
'total_count': total_count
|
||||
}
|
||||
|
||||
except FileNotFoundError as e:
|
||||
logger.error("输出目录不存在: %s", e)
|
||||
raise
|
||||
except PermissionError as e:
|
||||
logger.error("文件写入权限不足: %s", e)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("导出站点 URL 失败: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
def _write_default_urls(target_id: int, target_name: str, output_path: Path) -> int:
|
||||
"""
|
||||
懒加载模式:根据 Target 类型生成默认 URL
|
||||
# 构建数据源 queryset(Task 层决定数据源)
|
||||
queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)
|
||||
|
||||
Args:
|
||||
target_id: 目标 ID
|
||||
target_name: 目标名称(可选,如果为空则从数据库查询)
|
||||
output_path: 输出文件路径
|
||||
|
||||
Returns:
|
||||
int: 生成的 URL 数量
|
||||
"""
|
||||
# 获取 Target 信息
|
||||
target_service = TargetService()
|
||||
target = target_service.get_target(target_id)
|
||||
# 使用 TargetExportService 处理导出
|
||||
blacklist_service = BlacklistService()
|
||||
export_service = TargetExportService(blacklist_service=blacklist_service)
|
||||
|
||||
if not target:
|
||||
logger.warning("Target ID %d 不存在,无法生成默认 URL", target_id)
|
||||
return 0
|
||||
result = export_service.export_urls(
|
||||
target_id=target_id,
|
||||
output_path=output_file,
|
||||
queryset=queryset,
|
||||
batch_size=batch_size
|
||||
)
|
||||
|
||||
target_name = target.name
|
||||
target_type = target.type
|
||||
|
||||
logger.info("懒加载模式:Target 类型=%s, 名称=%s", target_type, target_name)
|
||||
|
||||
total_urls = 0
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
if target_type == Target.TargetType.DOMAIN:
|
||||
# 域名类型:生成 http(s)://domain
|
||||
f.write(f"http://{target_name}\n")
|
||||
f.write(f"https://{target_name}\n")
|
||||
total_urls = 2
|
||||
logger.info("✓ 域名默认 URL 已写入: http(s)://%s", target_name)
|
||||
|
||||
elif target_type == Target.TargetType.IP:
|
||||
# IP 类型:生成 http(s)://ip
|
||||
f.write(f"http://{target_name}\n")
|
||||
f.write(f"https://{target_name}\n")
|
||||
total_urls = 2
|
||||
logger.info("✓ IP 默认 URL 已写入: http(s)://%s", target_name)
|
||||
|
||||
elif target_type == Target.TargetType.CIDR:
|
||||
# CIDR 类型:展开为所有 IP 的 URL
|
||||
try:
|
||||
network = ipaddress.ip_network(target_name, strict=False)
|
||||
|
||||
for ip in network.hosts(): # 排除网络地址和广播地址
|
||||
f.write(f"http://{ip}\n")
|
||||
f.write(f"https://{ip}\n")
|
||||
total_urls += 2
|
||||
|
||||
if total_urls % 10000 == 0:
|
||||
logger.info("已生成 %d 个 URL...", total_urls)
|
||||
|
||||
# 如果是 /32 或 /128(单个 IP),hosts() 会为空
|
||||
if total_urls == 0:
|
||||
ip = str(network.network_address)
|
||||
f.write(f"http://{ip}\n")
|
||||
f.write(f"https://{ip}\n")
|
||||
total_urls = 2
|
||||
|
||||
logger.info("✓ CIDR 默认 URL 已写入: %d 个 URL (来自 %s)", total_urls, target_name)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error("CIDR 解析失败: %s - %s", target_name, e)
|
||||
return 0
|
||||
else:
|
||||
logger.warning("不支持的 Target 类型: %s", target_type)
|
||||
return 0
|
||||
|
||||
return total_urls
|
||||
# 保持返回值格式不变(向后兼容)
|
||||
return {
|
||||
'success': result['success'],
|
||||
'output_file': result['output_file'],
|
||||
'total_count': result['total_count']
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -2,55 +2,30 @@
|
||||
导出 URL 任务
|
||||
|
||||
用于指纹识别前导出目标下的 URL 到文件
|
||||
支持懒加载模式:如果数据库为空,根据 Target 类型生成默认 URL
|
||||
使用 TargetExportService 统一处理导出逻辑和默认值回退
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
import importlib
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from prefect import task
|
||||
|
||||
from apps.asset.models import WebSite
|
||||
from apps.scan.services import TargetExportService, BlacklistService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# 数据源映射:source → (module_path, model_name, url_field)
|
||||
SOURCE_MODEL_MAP = {
|
||||
'website': ('apps.asset.models', 'WebSite', 'url'),
|
||||
# 以后扩展:
|
||||
# 'endpoint': ('apps.asset.models', 'Endpoint', 'url'),
|
||||
# 'directory': ('apps.asset.models', 'Directory', 'url'),
|
||||
}
|
||||
|
||||
|
||||
def _get_model_class(source: str):
|
||||
"""
|
||||
根据数据源类型获取 Model 类
|
||||
"""
|
||||
if source not in SOURCE_MODEL_MAP:
|
||||
raise ValueError(f"不支持的数据源: {source},支持的类型: {list(SOURCE_MODEL_MAP.keys())}")
|
||||
|
||||
module_path, model_name, _ = SOURCE_MODEL_MAP[source]
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, model_name)
|
||||
|
||||
|
||||
@task(name="export_urls_for_fingerprint")
|
||||
def export_urls_for_fingerprint_task(
|
||||
target_id: int,
|
||||
output_file: str,
|
||||
target_name: str = None,
|
||||
source: str = 'website',
|
||||
batch_size: int = 1000
|
||||
) -> dict:
|
||||
"""
|
||||
导出目标下的 URL 到文件(用于指纹识别)
|
||||
|
||||
支持多种数据源,预留扩展:
|
||||
- website: WebSite 表(当前实现)
|
||||
- endpoint: Endpoint 表(以后扩展)
|
||||
- directory: Directory 表(以后扩展)
|
||||
数据源: WebSite.url
|
||||
|
||||
懒加载模式:
|
||||
- 如果数据库为空,根据 Target 类型生成默认 URL
|
||||
@@ -62,76 +37,29 @@ def export_urls_for_fingerprint_task(
|
||||
Args:
|
||||
target_id: 目标 ID
|
||||
output_file: 输出文件路径
|
||||
target_name: 目标名称(用于懒加载)
|
||||
source: 数据源类型
|
||||
source: 数据源类型(保留参数,兼容旧调用)
|
||||
batch_size: 批量读取大小
|
||||
|
||||
Returns:
|
||||
dict: {'output_file': str, 'total_count': int, 'source': str}
|
||||
"""
|
||||
from apps.targets.services import TargetService
|
||||
from apps.targets.models import Target
|
||||
# 构建数据源 queryset(Task 层决定数据源)
|
||||
queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)
|
||||
|
||||
logger.info("开始导出 URL - target_id=%s, source=%s, output=%s", target_id, source, output_file)
|
||||
# 使用 TargetExportService 处理导出
|
||||
blacklist_service = BlacklistService()
|
||||
export_service = TargetExportService(blacklist_service=blacklist_service)
|
||||
|
||||
Model = _get_model_class(source)
|
||||
_, _, url_field = SOURCE_MODEL_MAP[source]
|
||||
|
||||
output_path = Path(output_file)
|
||||
|
||||
# 分批导出
|
||||
total_count = 0
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
queryset = Model.objects.filter(target_id=target_id).values_list(url_field, flat=True)
|
||||
for url in queryset.iterator(chunk_size=batch_size):
|
||||
if url:
|
||||
f.write(url + '\n')
|
||||
total_count += 1
|
||||
|
||||
# ==================== 懒加载模式:根据 Target 类型生成默认 URL ====================
|
||||
if total_count == 0:
|
||||
target_service = TargetService()
|
||||
target = target_service.get_target(target_id)
|
||||
|
||||
if target:
|
||||
target_name = target.name
|
||||
target_type = target.type
|
||||
|
||||
logger.info("懒加载模式:Target 类型=%s, 名称=%s", target_type, target_name)
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
if target_type == Target.TargetType.DOMAIN:
|
||||
f.write(f"http://{target_name}\n")
|
||||
f.write(f"https://{target_name}\n")
|
||||
total_count = 2
|
||||
|
||||
elif target_type == Target.TargetType.IP:
|
||||
f.write(f"http://{target_name}\n")
|
||||
f.write(f"https://{target_name}\n")
|
||||
total_count = 2
|
||||
|
||||
elif target_type == Target.TargetType.CIDR:
|
||||
try:
|
||||
network = ipaddress.ip_network(target_name, strict=False)
|
||||
for ip in network.hosts():
|
||||
f.write(f"http://{ip}\n")
|
||||
f.write(f"https://{ip}\n")
|
||||
total_count += 2
|
||||
except ValueError as e:
|
||||
logger.warning("CIDR 解析失败: %s", e)
|
||||
|
||||
elif target_type == Target.TargetType.URL:
|
||||
f.write(f"{target_name}\n")
|
||||
total_count = 1
|
||||
|
||||
logger.info("✓ 懒加载生成默认 URL - 数量: %d", total_count)
|
||||
else:
|
||||
logger.warning("Target ID %d 不存在,无法生成默认 URL", target_id)
|
||||
|
||||
logger.info("✓ URL 导出完成 - 数量: %d, 文件: %s", total_count, output_file)
|
||||
result = export_service.export_urls(
|
||||
target_id=target_id,
|
||||
output_path=output_file,
|
||||
queryset=queryset,
|
||||
batch_size=batch_size
|
||||
)
|
||||
|
||||
# 保持返回值格式不变(向后兼容)
|
||||
return {
|
||||
'output_file': output_file,
|
||||
'total_count': total_count,
|
||||
'output_file': result['output_file'],
|
||||
'total_count': result['total_count'],
|
||||
'source': source
|
||||
}
|
||||
|
||||
@@ -1,119 +1,21 @@
|
||||
"""
|
||||
导出扫描目标到 TXT 文件的 Task
|
||||
|
||||
使用 TargetExportService.export_targets() 统一处理导出逻辑
|
||||
|
||||
根据 Target 类型决定导出内容:
|
||||
- DOMAIN: 从 Subdomain 表导出子域名
|
||||
- IP: 直接写入 target.name
|
||||
- CIDR: 展开 CIDR 范围内的所有 IP
|
||||
|
||||
使用流式处理,避免大量数据导致内存溢出
|
||||
"""
|
||||
import logging
|
||||
import ipaddress
|
||||
from pathlib import Path
|
||||
from prefect import task
|
||||
|
||||
from apps.asset.services.asset.subdomain_service import SubdomainService
|
||||
from apps.targets.services import TargetService
|
||||
from apps.targets.models import Target # 仅用于 TargetType 常量
|
||||
from apps.scan.services import TargetExportService, BlacklistService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _export_domains(target_id: int, target_name: str, output_path: Path, batch_size: int) -> int:
|
||||
"""
|
||||
导出域名类型目标的子域名(支持默认值模式)
|
||||
|
||||
Args:
|
||||
target_id: 目标 ID
|
||||
target_name: 目标名称(域名)
|
||||
output_path: 输出文件路径
|
||||
batch_size: 批次大小
|
||||
|
||||
Returns:
|
||||
int: 导出的记录数
|
||||
|
||||
默认值模式:
|
||||
如果没有子域名,自动使用根域名作为默认子域名
|
||||
"""
|
||||
subdomain_service = SubdomainService()
|
||||
domain_iterator = subdomain_service.iter_subdomain_names_by_target(
|
||||
target_id=target_id,
|
||||
chunk_size=batch_size
|
||||
)
|
||||
|
||||
total_count = 0
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
for domain_name in domain_iterator:
|
||||
f.write(f"{domain_name}\n")
|
||||
total_count += 1
|
||||
|
||||
if total_count % 10000 == 0:
|
||||
logger.info("已导出 %d 个域名...", total_count)
|
||||
|
||||
# ==================== 采用默认域名:如果没有子域名,使用根域名 ====================
|
||||
# 只写入文件供扫描工具使用,不写入数据库
|
||||
# 数据库只存储扫描发现的真实资产
|
||||
if total_count == 0:
|
||||
logger.info("采用默认域名:%s (target_id=%d)", target_name, target_id)
|
||||
|
||||
# 只写入文件,不写入数据库
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
f.write(f"{target_name}\n")
|
||||
total_count = 1
|
||||
|
||||
logger.info("✓ 默认域名已写入文件: %s", target_name)
|
||||
|
||||
return total_count
|
||||
|
||||
|
||||
def _export_ip(target_name: str, output_path: Path) -> int:
|
||||
"""
|
||||
导出 IP 类型目标
|
||||
|
||||
Args:
|
||||
target_name: IP 地址
|
||||
output_path: 输出文件路径
|
||||
|
||||
Returns:
|
||||
int: 导出的记录数(始终为 1)
|
||||
"""
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
f.write(f"{target_name}\n")
|
||||
return 1
|
||||
|
||||
|
||||
def _export_cidr(target_name: str, output_path: Path) -> int:
|
||||
"""
|
||||
导出 CIDR 类型目标,展开为每个 IP
|
||||
|
||||
Args:
|
||||
target_name: CIDR 范围(如 192.168.1.0/24)
|
||||
output_path: 输出文件路径
|
||||
|
||||
Returns:
|
||||
int: 导出的 IP 数量
|
||||
"""
|
||||
network = ipaddress.ip_network(target_name, strict=False)
|
||||
total_count = 0
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
for ip in network.hosts(): # 排除网络地址和广播地址
|
||||
f.write(f"{ip}\n")
|
||||
total_count += 1
|
||||
|
||||
if total_count % 10000 == 0:
|
||||
logger.info("已导出 %d 个 IP...", total_count)
|
||||
|
||||
# 如果是 /32 或 /128(单个 IP),hosts() 会为空,需要特殊处理
|
||||
if total_count == 0:
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
f.write(f"{network.network_address}\n")
|
||||
total_count = 1
|
||||
|
||||
return total_count
|
||||
|
||||
|
||||
@task(name="export_scan_targets")
|
||||
def export_scan_targets_task(
|
||||
target_id: int,
|
||||
@@ -145,62 +47,20 @@ def export_scan_targets_task(
|
||||
ValueError: Target 不存在
|
||||
IOError: 文件写入失败
|
||||
"""
|
||||
try:
|
||||
# 1. 通过 Service 层获取 Target
|
||||
target_service = TargetService()
|
||||
target = target_service.get_target(target_id)
|
||||
if not target:
|
||||
raise ValueError(f"Target ID {target_id} 不存在")
|
||||
|
||||
target_type = target.type
|
||||
target_name = target.name
|
||||
|
||||
logger.info(
|
||||
"开始导出扫描目标 - Target ID: %d, Name: %s, Type: %s, 输出文件: %s",
|
||||
target_id, target_name, target_type, output_file
|
||||
)
|
||||
|
||||
# 2. 确保输出目录存在
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 3. 根据类型导出
|
||||
if target_type == Target.TargetType.DOMAIN:
|
||||
total_count = _export_domains(target_id, target_name, output_path, batch_size)
|
||||
type_desc = "域名"
|
||||
elif target_type == Target.TargetType.IP:
|
||||
total_count = _export_ip(target_name, output_path)
|
||||
type_desc = "IP"
|
||||
elif target_type == Target.TargetType.CIDR:
|
||||
total_count = _export_cidr(target_name, output_path)
|
||||
type_desc = "CIDR IP"
|
||||
else:
|
||||
raise ValueError(f"不支持的目标类型: {target_type}")
|
||||
|
||||
logger.info(
|
||||
"✓ 扫描目标导出完成 - 类型: %s, 总数: %d, 文件: %s (%.2f KB)",
|
||||
type_desc,
|
||||
total_count,
|
||||
str(output_path),
|
||||
output_path.stat().st_size / 1024
|
||||
)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'output_file': str(output_path),
|
||||
'total_count': total_count,
|
||||
'target_type': target_type
|
||||
}
|
||||
|
||||
except FileNotFoundError as e:
|
||||
logger.error("输出目录不存在: %s", e)
|
||||
raise
|
||||
except PermissionError as e:
|
||||
logger.error("文件写入权限不足: %s", e)
|
||||
raise
|
||||
except ValueError as e:
|
||||
logger.error("参数错误: %s", e)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("导出扫描目标失败: %s", e)
|
||||
raise
|
||||
# 使用 TargetExportService 处理导出
|
||||
blacklist_service = BlacklistService()
|
||||
export_service = TargetExportService(blacklist_service=blacklist_service)
|
||||
|
||||
result = export_service.export_targets(
|
||||
target_id=target_id,
|
||||
output_path=output_file,
|
||||
batch_size=batch_size
|
||||
)
|
||||
|
||||
# 保持返回值格式不变(向后兼容)
|
||||
return {
|
||||
'success': result['success'],
|
||||
'output_file': result['output_file'],
|
||||
'total_count': result['total_count'],
|
||||
'target_type': result['target_type']
|
||||
}
|
||||
|
||||
@@ -2,52 +2,65 @@
导出站点URL到文件的Task

直接使用 HostPortMapping 表查询 host+port 组合,拼接成URL格式写入文件
使用 TargetExportService 处理默认值回退逻辑

默认值模式:
- 如果没有 HostPortMapping 数据,写入默认 URL 到文件(不写入数据库)
- DOMAIN: http(s)://target_name
- IP: http(s)://ip
- CIDR: 展开为所有 IP 的 http(s)://ip
特殊逻辑:
- 80 端口:只生成 HTTP URL(省略端口号)
- 443 端口:只生成 HTTPS URL(省略端口号)
- 其他端口:生成 HTTP 和 HTTPS 两个URL(带端口号)
"""
import logging
import ipaddress
from pathlib import Path
from prefect import task
from typing import Optional

from apps.asset.services import HostPortMappingService
from apps.targets.services import TargetService
from apps.targets.models import Target
from apps.scan.services import TargetExportService, BlacklistService

logger = logging.getLogger(__name__)


def _generate_urls_from_port(host: str, port: int) -> list[str]:
    """
    根据端口生成 URL 列表

    - 80 端口:只生成 HTTP URL(省略端口号)
    - 443 端口:只生成 HTTPS URL(省略端口号)
    - 其他端口:生成 HTTP 和 HTTPS 两个URL(带端口号)
    """
    if port == 80:
        return [f"http://{host}"]
    elif port == 443:
        return [f"https://{host}"]
    else:
        return [f"http://{host}:{port}", f"https://{host}:{port}"]
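The port rules documented above translate directly into URL lists; for example (hosts below are illustrative only):

# Expected behaviour of _generate_urls_from_port:
# _generate_urls_from_port("example.com", 80)   -> ["http://example.com"]
# _generate_urls_from_port("example.com", 443)  -> ["https://example.com"]
# _generate_urls_from_port("10.0.0.1", 8443)    -> ["http://10.0.0.1:8443", "https://10.0.0.1:8443"]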
@task(name="export_site_urls")
|
||||
def export_site_urls_task(
|
||||
target_id: int,
|
||||
output_file: str,
|
||||
target_name: Optional[str] = None,
|
||||
batch_size: int = 1000
|
||||
) -> dict:
|
||||
"""
|
||||
导出目标下的所有站点URL到文件(基于 HostPortMapping 表)
|
||||
|
||||
功能:
|
||||
1. 从 HostPortMapping 表查询 target 下所有 host+port 组合
|
||||
2. 拼接成URL格式(标准端口80/443将省略端口号)
|
||||
3. 写入到指定文件中
|
||||
数据源: HostPortMapping (host + port)
|
||||
|
||||
默认值模式(懒加载):
|
||||
- 如果没有 HostPortMapping 数据,根据 Target 类型生成默认 URL
|
||||
- DOMAIN: http(s)://target_name
|
||||
特殊逻辑:
|
||||
- 80 端口:只生成 HTTP URL(省略端口号)
|
||||
- 443 端口:只生成 HTTPS URL(省略端口号)
|
||||
- 其他端口:生成 HTTP 和 HTTPS 两个URL(带端口号)
|
||||
|
||||
懒加载模式:
|
||||
- 如果数据库为空,根据 Target 类型生成默认 URL
|
||||
- DOMAIN: http(s)://domain
|
||||
- IP: http(s)://ip
|
||||
- CIDR: 展开为所有 IP 的 http(s)://ip
|
||||
- CIDR: 展开为所有 IP 的 URL
|
||||
|
||||
Args:
|
||||
target_id: 目标ID
|
||||
output_file: 输出文件路径(绝对路径)
|
||||
target_name: 目标名称(用于懒加载时写入默认值)
|
||||
batch_size: 每次处理的批次大小,默认1000(暂未使用,预留)
|
||||
batch_size: 每次处理的批次大小
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
@@ -61,155 +74,54 @@ def export_site_urls_task(
|
||||
ValueError: 参数错误
|
||||
IOError: 文件写入失败
|
||||
"""
|
||||
try:
|
||||
logger.info("开始统计站点URL - Target ID: %d, 输出文件: %s", target_id, output_file)
|
||||
|
||||
# 确保输出目录存在
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 直接查询 HostPortMapping 表,按 host 排序
|
||||
service = HostPortMappingService()
|
||||
associations = service.iter_host_port_by_target(
|
||||
target_id=target_id,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
total_urls = 0
|
||||
association_count = 0
|
||||
|
||||
# 流式写入文件
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
for assoc in associations:
|
||||
association_count += 1
|
||||
host = assoc['host']
|
||||
port = assoc['port']
|
||||
|
||||
# 根据端口号生成URL
|
||||
# 80 端口:只生成 HTTP URL(省略端口号)
|
||||
# 443 端口:只生成 HTTPS URL(省略端口号)
|
||||
# 其他端口:生成 HTTP 和 HTTPS 两个URL(带端口号)
|
||||
if port == 80:
|
||||
# HTTP 标准端口,省略端口号
|
||||
url = f"http://{host}"
|
||||
f.write(f"{url}\n")
|
||||
total_urls += 1
|
||||
elif port == 443:
|
||||
# HTTPS 标准端口,省略端口号
|
||||
url = f"https://{host}"
|
||||
f.write(f"{url}\n")
|
||||
total_urls += 1
|
||||
else:
|
||||
# 非标准端口,生成 HTTP 和 HTTPS 两个URL
|
||||
http_url = f"http://{host}:{port}"
|
||||
https_url = f"https://{host}:{port}"
|
||||
f.write(f"{http_url}\n")
|
||||
f.write(f"{https_url}\n")
|
||||
total_urls += 2
|
||||
|
||||
# 每处理1000条记录打印一次进度
|
||||
if association_count % 1000 == 0:
|
||||
logger.info("已处理 %d 条关联,生成 %d 个URL...", association_count, total_urls)
|
||||
|
||||
logger.info(
|
||||
"✓ 站点URL导出完成 - 关联数: %d, 总URL数: %d, 文件: %s (%.2f KB)",
|
||||
association_count,
|
||||
total_urls,
|
||||
str(output_path),
|
||||
output_path.stat().st_size / 1024
|
||||
)
|
||||
|
||||
# ==================== 懒加载模式:根据 Target 类型生成默认 URL ====================
|
||||
if total_urls == 0:
|
||||
total_urls = _write_default_urls(target_id, target_name, output_path)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'output_file': str(output_path),
|
||||
'total_urls': total_urls,
|
||||
'association_count': association_count
|
||||
}
|
||||
|
||||
except FileNotFoundError as e:
|
||||
logger.error("输出目录不存在: %s", e)
|
||||
raise
|
||||
except PermissionError as e:
|
||||
logger.error("文件写入权限不足: %s", e)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("导出站点URL失败: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
def _write_default_urls(target_id: int, target_name: Optional[str], output_path: Path) -> int:
|
||||
"""
|
||||
懒加载模式:根据 Target 类型生成默认 URL
|
||||
logger.info("开始统计站点URL - Target ID: %d, 输出文件: %s", target_id, output_file)
|
||||
|
||||
Args:
|
||||
target_id: 目标 ID
|
||||
target_name: 目标名称(可选,如果为空则从数据库查询)
|
||||
output_path: 输出文件路径
|
||||
|
||||
Returns:
|
||||
int: 生成的 URL 数量
|
||||
"""
|
||||
# 获取 Target 信息
|
||||
target_service = TargetService()
|
||||
target = target_service.get_target(target_id)
|
||||
# 确保输出目录存在
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if not target:
|
||||
logger.warning("Target ID %d 不存在,无法生成默认 URL", target_id)
|
||||
return 0
|
||||
# 初始化黑名单服务
|
||||
blacklist_service = BlacklistService()
|
||||
|
||||
target_name = target.name
|
||||
target_type = target.type
|
||||
|
||||
logger.info("懒加载模式:Target 类型=%s, 名称=%s", target_type, target_name)
|
||||
# 直接查询 HostPortMapping 表,按 host 排序
|
||||
service = HostPortMappingService()
|
||||
associations = service.iter_host_port_by_target(
|
||||
target_id=target_id,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
total_urls = 0
|
||||
association_count = 0
|
||||
|
||||
# 流式写入文件(特殊端口逻辑)
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
if target_type == Target.TargetType.DOMAIN:
|
||||
# 域名类型:生成 http(s)://domain
|
||||
f.write(f"http://{target_name}\n")
|
||||
f.write(f"https://{target_name}\n")
|
||||
total_urls = 2
|
||||
logger.info("✓ 域名默认 URL 已写入: http(s)://%s", target_name)
|
||||
for assoc in associations:
|
||||
association_count += 1
|
||||
host = assoc['host']
|
||||
port = assoc['port']
|
||||
|
||||
elif target_type == Target.TargetType.IP:
|
||||
# IP 类型:生成 http(s)://ip
|
||||
f.write(f"http://{target_name}\n")
|
||||
f.write(f"https://{target_name}\n")
|
||||
total_urls = 2
|
||||
logger.info("✓ IP 默认 URL 已写入: http(s)://%s", target_name)
|
||||
# 根据端口号生成URL
|
||||
for url in _generate_urls_from_port(host, port):
|
||||
if blacklist_service.filter_url(url):
|
||||
f.write(f"{url}\n")
|
||||
total_urls += 1
|
||||
|
||||
elif target_type == Target.TargetType.CIDR:
|
||||
# CIDR 类型:展开为所有 IP 的 URL
|
||||
try:
|
||||
network = ipaddress.ip_network(target_name, strict=False)
|
||||
|
||||
for ip in network.hosts(): # 排除网络地址和广播地址
|
||||
f.write(f"http://{ip}\n")
|
||||
f.write(f"https://{ip}\n")
|
||||
total_urls += 2
|
||||
|
||||
if total_urls % 10000 == 0:
|
||||
logger.info("已生成 %d 个 URL...", total_urls)
|
||||
|
||||
# 如果是 /32 或 /128(单个 IP),hosts() 会为空
|
||||
if total_urls == 0:
|
||||
ip = str(network.network_address)
|
||||
f.write(f"http://{ip}\n")
|
||||
f.write(f"https://{ip}\n")
|
||||
total_urls = 2
|
||||
|
||||
logger.info("✓ CIDR 默认 URL 已写入: %d 个 URL (来自 %s)", total_urls, target_name)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error("CIDR 解析失败: %s - %s", target_name, e)
|
||||
return 0
|
||||
else:
|
||||
logger.warning("不支持的 Target 类型: %s", target_type)
|
||||
return 0
|
||||
if association_count % 1000 == 0:
|
||||
logger.info("已处理 %d 条关联,生成 %d 个URL...", association_count, total_urls)
|
||||
|
||||
return total_urls
|
||||
logger.info(
|
||||
"✓ 站点URL导出完成 - 关联数: %d, 总URL数: %d, 文件: %s",
|
||||
association_count, total_urls, str(output_path)
|
||||
)
|
||||
|
||||
# 默认值回退模式:使用 TargetExportService
|
||||
if total_urls == 0:
|
||||
export_service = TargetExportService(blacklist_service=blacklist_service)
|
||||
total_urls = export_service._generate_default_urls(target_id, output_path)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'output_file': str(output_path),
|
||||
'total_urls': total_urls,
|
||||
'association_count': association_count
|
||||
}
|
||||
|
||||
@@ -1,25 +1,16 @@
|
||||
"""
|
||||
导出站点 URL 列表任务
|
||||
|
||||
从 WebSite 表导出站点 URL 列表到文件(用于 katana 等爬虫工具)
|
||||
|
||||
使用流式写入,避免内存溢出
|
||||
|
||||
懒加载模式:
|
||||
- 如果 WebSite 表为空,根据 Target 类型生成默认 URL
|
||||
- DOMAIN: 写入 http(s)://domain
|
||||
- IP: 写入 http(s)://ip
|
||||
- CIDR: 展开为所有 IP
|
||||
使用 TargetExportService 统一处理导出逻辑和默认值回退
|
||||
数据源: WebSite.url(用于 katana 等爬虫工具)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import ipaddress
|
||||
from pathlib import Path
|
||||
from prefect import task
|
||||
from typing import Optional
|
||||
|
||||
from apps.targets.services import TargetService
|
||||
from apps.targets.models import Target
|
||||
from apps.asset.models import WebSite
|
||||
from apps.scan.services import TargetExportService, BlacklistService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -33,21 +24,23 @@ def export_sites_task(
|
||||
output_file: str,
|
||||
target_id: int,
|
||||
scan_id: int,
|
||||
target_name: Optional[str] = None,
|
||||
batch_size: int = 1000
|
||||
) -> dict:
|
||||
"""
|
||||
导出站点 URL 列表到文件(用于 katana 等爬虫工具)
|
||||
|
||||
数据源: WebSite.url
|
||||
|
||||
懒加载模式:
|
||||
- 如果 WebSite 表为空,根据 Target 类型生成默认 URL
|
||||
- 数据库只存储"真实发现"的资产
|
||||
- 如果数据库为空,根据 Target 类型生成默认 URL
|
||||
- DOMAIN: http(s)://domain
|
||||
- IP: http(s)://ip
|
||||
- CIDR: 展开为所有 IP 的 URL
|
||||
|
||||
Args:
|
||||
output_file: 输出文件路径
|
||||
target_id: 目标 ID
|
||||
scan_id: 扫描 ID
|
||||
target_name: 目标名称(用于懒加载时写入默认值)
|
||||
scan_id: 扫描 ID(保留参数,兼容旧调用)
|
||||
batch_size: 批次大小(内存优化)
|
||||
|
||||
Returns:
|
||||
@@ -60,109 +53,22 @@ def export_sites_task(
|
||||
ValueError: 参数错误
|
||||
RuntimeError: 执行失败
|
||||
"""
|
||||
try:
|
||||
logger.info("开始导出站点 URL 列表 - Target ID: %d", target_id)
|
||||
|
||||
# 确保输出目录存在
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 从 WebSite 表导出站点 URL
|
||||
from apps.asset.services import WebSiteService
|
||||
|
||||
website_service = WebSiteService()
|
||||
|
||||
# 流式写入文件
|
||||
asset_count = 0
|
||||
with open(output_path, 'w') as f:
|
||||
for url in website_service.iter_website_urls_by_target(target_id, batch_size):
|
||||
f.write(f"{url}\n")
|
||||
asset_count += 1
|
||||
|
||||
if asset_count % batch_size == 0:
|
||||
f.flush()
|
||||
|
||||
# ==================== 懒加载模式:根据 Target 类型生成默认 URL ====================
|
||||
if asset_count == 0:
|
||||
asset_count = _write_default_urls(target_id, target_name, output_path)
|
||||
|
||||
logger.info("✓ 站点 URL 导出完成 - 文件: %s, 数量: %d", output_file, asset_count)
|
||||
|
||||
return {
|
||||
'output_file': output_file,
|
||||
'asset_count': asset_count,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("导出站点 URL 失败: %s", e, exc_info=True)
|
||||
raise RuntimeError(f"导出站点 URL 失败: {e}") from e
|
||||
|
||||
|
||||
def _write_default_urls(target_id: int, target_name: Optional[str], output_path: Path) -> int:
|
||||
"""
|
||||
懒加载模式:根据 Target 类型生成默认 URL 列表
|
||||
# 构建数据源 queryset(Task 层决定数据源)
|
||||
queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)
|
||||
|
||||
Args:
|
||||
target_id: 目标 ID
|
||||
target_name: 目标名称
|
||||
output_path: 输出文件路径
|
||||
|
||||
Returns:
|
||||
int: 生成的 URL 数量
|
||||
"""
|
||||
target_service = TargetService()
|
||||
target = target_service.get_target(target_id)
|
||||
# 使用 TargetExportService 处理导出
|
||||
blacklist_service = BlacklistService()
|
||||
export_service = TargetExportService(blacklist_service=blacklist_service)
|
||||
|
||||
if not target:
|
||||
logger.warning("Target ID %d 不存在,无法生成默认 URL", target_id)
|
||||
return 0
|
||||
result = export_service.export_urls(
|
||||
target_id=target_id,
|
||||
output_path=output_file,
|
||||
queryset=queryset,
|
||||
batch_size=batch_size
|
||||
)
|
||||
|
||||
target_name = target.name
|
||||
target_type = target.type
|
||||
|
||||
logger.info("懒加载模式:Target 类型=%s, 名称=%s", target_type, target_name)
|
||||
|
||||
total_urls = 0
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
if target_type == Target.TargetType.DOMAIN:
|
||||
f.write(f"http://{target_name}\n")
|
||||
f.write(f"https://{target_name}\n")
|
||||
total_urls = 2
|
||||
logger.info("✓ 域名默认 URL 已写入: http(s)://%s", target_name)
|
||||
|
||||
elif target_type == Target.TargetType.IP:
|
||||
f.write(f"http://{target_name}\n")
|
||||
f.write(f"https://{target_name}\n")
|
||||
total_urls = 2
|
||||
logger.info("✓ IP 默认 URL 已写入: http(s)://%s", target_name)
|
||||
|
||||
elif target_type == Target.TargetType.CIDR:
|
||||
try:
|
||||
network = ipaddress.ip_network(target_name, strict=False)
|
||||
|
||||
for ip in network.hosts():
|
||||
f.write(f"http://{ip}\n")
|
||||
f.write(f"https://{ip}\n")
|
||||
total_urls += 2
|
||||
|
||||
if total_urls % 10000 == 0:
|
||||
logger.info("已生成 %d 个 URL...", total_urls)
|
||||
|
||||
# /32 或 /128 特殊处理
|
||||
if total_urls == 0:
|
||||
ip = str(network.network_address)
|
||||
f.write(f"http://{ip}\n")
|
||||
f.write(f"https://{ip}\n")
|
||||
total_urls = 2
|
||||
|
||||
logger.info("✓ CIDR 默认 URL 已写入: %d 个 URL (来自 %s)", total_urls, target_name)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error("CIDR 解析失败: %s - %s", target_name, e)
|
||||
return 0
|
||||
else:
|
||||
logger.warning("不支持的 Target 类型: %s", target_type)
|
||||
return 0
|
||||
|
||||
return total_urls
|
||||
# 保持返回值格式不变(向后兼容)
|
||||
return {
|
||||
'output_file': result['output_file'],
|
||||
'asset_count': result['total_count'],
|
||||
}
|
||||
|
||||
@@ -1,25 +1,16 @@
|
||||
"""导出 Endpoint URL 到文件的 Task
|
||||
|
||||
基于 EndpointService.iter_endpoint_urls_by_target 按目标流式导出端点 URL,
|
||||
用于漏洞扫描(如 Dalfox XSS)的输入文件生成。
|
||||
|
||||
默认值模式:
|
||||
- 如果没有 Endpoint,根据 Target 类型生成默认 URL
|
||||
- DOMAIN: http(s)://target_name
|
||||
- IP: http(s)://ip
|
||||
- CIDR: 展开为所有 IP 的 http(s)://ip
|
||||
使用 TargetExportService 统一处理导出逻辑和默认值回退
|
||||
数据源: Endpoint.url
|
||||
"""
|
||||
|
||||
import logging
|
||||
import ipaddress
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
from prefect import task
|
||||
|
||||
from apps.asset.services import EndpointService
|
||||
from apps.targets.services import TargetService
|
||||
from apps.targets.models import Target
|
||||
from apps.asset.models import Endpoint
|
||||
from apps.scan.services import TargetExportService, BlacklistService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -29,17 +20,21 @@ def export_endpoints_task(
|
||||
target_id: int,
|
||||
output_file: str,
|
||||
batch_size: int = 1000,
|
||||
target_name: Optional[str] = None,
|
||||
) -> Dict[str, object]:
|
||||
"""导出目标下的所有 Endpoint URL 到文本文件。
|
||||
|
||||
默认值模式:如果没有 Endpoint,根据 Target 类型生成默认 URL
|
||||
数据源: Endpoint.url
|
||||
|
||||
懒加载模式:
|
||||
- 如果数据库为空,根据 Target 类型生成默认 URL
|
||||
- DOMAIN: http(s)://domain
|
||||
- IP: http(s)://ip
|
||||
- CIDR: 展开为所有 IP 的 URL
|
||||
|
||||
Args:
|
||||
target_id: 目标 ID
|
||||
output_file: 输出文件路径(绝对路径)
|
||||
batch_size: 每次从数据库迭代的批大小
|
||||
target_name: 目标名称(用于默认值模式)
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
@@ -48,117 +43,23 @@ def export_endpoints_task(
|
||||
"total_count": int,
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info("开始导出 Endpoint URL - Target ID: %d, 输出文件: %s", target_id, output_file)
|
||||
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
service = EndpointService()
|
||||
url_iterator = service.iter_endpoint_urls_by_target(target_id, chunk_size=batch_size)
|
||||
|
||||
total_count = 0
|
||||
with open(output_path, "w", encoding="utf-8", buffering=8192) as f:
|
||||
for url in url_iterator:
|
||||
f.write(f"{url}\n")
|
||||
total_count += 1
|
||||
|
||||
if total_count % 10000 == 0:
|
||||
logger.info("已导出 %d 个 Endpoint URL...", total_count)
|
||||
|
||||
# ==================== 懒加载模式:根据 Target 类型生成默认 URL ====================
|
||||
if total_count == 0:
|
||||
total_count = _write_default_urls(target_id, target_name, output_path)
|
||||
|
||||
logger.info(
|
||||
"✓ Endpoint URL 导出完成 - 总数: %d, 文件: %s (%.2f KB)",
|
||||
total_count,
|
||||
str(output_path),
|
||||
output_path.stat().st_size / 1024,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"output_file": str(output_path),
|
||||
"total_count": total_count,
|
||||
}
|
||||
|
||||
except FileNotFoundError as e:
|
||||
logger.error("输出目录不存在: %s", e)
|
||||
raise
|
||||
except PermissionError as e:
|
||||
logger.error("文件写入权限不足: %s", e)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("导出 Endpoint URL 失败: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
def _write_default_urls(target_id: int, target_name: Optional[str], output_path: Path) -> int:
|
||||
"""
|
||||
懒加载模式:根据 Target 类型生成默认 URL
|
||||
# 构建数据源 queryset(Task 层决定数据源)
|
||||
queryset = Endpoint.objects.filter(target_id=target_id).values_list('url', flat=True)
|
||||
|
||||
Args:
|
||||
target_id: 目标 ID
|
||||
target_name: 目标名称(可选,如果为空则从数据库查询)
|
||||
output_path: 输出文件路径
|
||||
|
||||
Returns:
|
||||
int: 生成的 URL 数量
|
||||
"""
|
||||
target_service = TargetService()
|
||||
target = target_service.get_target(target_id)
|
||||
# 使用 TargetExportService 处理导出
|
||||
blacklist_service = BlacklistService()
|
||||
export_service = TargetExportService(blacklist_service=blacklist_service)
|
||||
|
||||
if not target:
|
||||
logger.warning("Target ID %d 不存在,无法生成默认 URL", target_id)
|
||||
return 0
|
||||
result = export_service.export_urls(
|
||||
target_id=target_id,
|
||||
output_path=output_file,
|
||||
queryset=queryset,
|
||||
batch_size=batch_size
|
||||
)
|
||||
|
||||
target_name = target.name
|
||||
target_type = target.type
|
||||
|
||||
logger.info("懒加载模式:Target 类型=%s, 名称=%s", target_type, target_name)
|
||||
|
||||
total_urls = 0
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
if target_type == Target.TargetType.DOMAIN:
|
||||
f.write(f"http://{target_name}\n")
|
||||
f.write(f"https://{target_name}\n")
|
||||
total_urls = 2
|
||||
logger.info("✓ 域名默认 URL 已写入: http(s)://%s", target_name)
|
||||
|
||||
elif target_type == Target.TargetType.IP:
|
||||
f.write(f"http://{target_name}\n")
|
||||
f.write(f"https://{target_name}\n")
|
||||
total_urls = 2
|
||||
logger.info("✓ IP 默认 URL 已写入: http(s)://%s", target_name)
|
||||
|
||||
elif target_type == Target.TargetType.CIDR:
|
||||
try:
|
||||
network = ipaddress.ip_network(target_name, strict=False)
|
||||
|
||||
for ip in network.hosts():
|
||||
f.write(f"http://{ip}\n")
|
||||
f.write(f"https://{ip}\n")
|
||||
total_urls += 2
|
||||
|
||||
if total_urls % 10000 == 0:
|
||||
logger.info("已生成 %d 个 URL...", total_urls)
|
||||
|
||||
# /32 或 /128 特殊处理
|
||||
if total_urls == 0:
|
||||
ip = str(network.network_address)
|
||||
f.write(f"http://{ip}\n")
|
||||
f.write(f"https://{ip}\n")
|
||||
total_urls = 2
|
||||
|
||||
logger.info("✓ CIDR 默认 URL 已写入: %d 个 URL (来自 %s)", total_urls, target_name)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error("CIDR 解析失败: %s - %s", target_name, e)
|
||||
return 0
|
||||
else:
|
||||
logger.warning("不支持的 Target 类型: %s", target_type)
|
||||
return 0
|
||||
|
||||
return total_urls
|
||||
# 保持返回值格式不变(向后兼容)
|
||||
return {
|
||||
"success": result['success'],
|
||||
"output_file": result['output_file'],
|
||||
"total_count": result['total_count'],
|
||||
}
|
||||
|
||||
@@ -1,54 +0,0 @@
"""
工作空间相关的 Prefect Tasks

负责扫描工作空间的创建、验证和管理
"""

from pathlib import Path
from prefect import task
import logging

logger = logging.getLogger(__name__)


@task(
    name="create_scan_workspace",
    description="创建并验证 Scan 工作空间目录",
    retries=2,
    retry_delay_seconds=5
)
def create_scan_workspace_task(scan_workspace_dir: str) -> Path:
    """
    创建并验证 Scan 工作空间目录

    Args:
        scan_workspace_dir: Scan 工作空间目录路径

    Returns:
        Path: 创建的 Scan 工作空间路径对象

    Raises:
        OSError: 目录创建失败或不可写
    """
    scan_workspace_path = Path(scan_workspace_dir)

    # 创建目录
    try:
        scan_workspace_path.mkdir(parents=True, exist_ok=True)
        logger.info("✓ Scan 工作空间已创建: %s", scan_workspace_path)
    except OSError as e:
        logger.error("创建 Scan 工作空间失败: %s - %s", scan_workspace_dir, e)
        raise

    # 验证目录是否可写
    test_file = scan_workspace_path / ".test_write"
    try:
        test_file.touch()
        test_file.unlink()
        logger.info("✓ Scan 工作空间验证通过(可写): %s", scan_workspace_path)
    except OSError as e:
        error_msg = f"Scan 工作空间不可写: {scan_workspace_path}"
        logger.error(error_msg)
        raise OSError(error_msg) from e

    return scan_workspace_path
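This file is deleted outright; initiate_scan_flow now calls a plain setup_scan_workspace utility instead of the Prefect task. That utility's implementation is not included in the diff; a minimal sketch, assuming it keeps the deleted task's create-then-verify-writable behaviour without the Prefect retry wrapper:

# Hypothetical sketch of setup_scan_workspace (re-exported from apps.scan.utils);
# the actual implementation is not part of this commit's diff.
import logging
from pathlib import Path

logger = logging.getLogger(__name__)

def setup_scan_workspace(scan_workspace_dir: str) -> Path:
    """Create the scan workspace root and verify it is writable."""
    scan_workspace_path = Path(scan_workspace_dir)
    scan_workspace_path.mkdir(parents=True, exist_ok=True)

    # Same writability probe as the deleted Prefect task.
    probe = scan_workspace_path / ".test_write"
    try:
        probe.touch()
        probe.unlink()
    except OSError as e:
        raise OSError(f"Scan workspace is not writable: {scan_workspace_path}") from e

    logger.info("Scan workspace ready: %s", scan_workspace_path)
    return scan_workspace_path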
@@ -10,11 +10,15 @@ from .command_executor import execute_and_wait, execute_stream
from .wordlist_helpers import ensure_wordlist_local
from .nuclei_helpers import ensure_nuclei_templates_local
from .performance import FlowPerformanceTracker, CommandPerformanceTracker
from .workspace_utils import setup_scan_workspace, setup_scan_directory
from . import config_parser

__all__ = [
    # 目录清理
    'remove_directory',
    # 工作空间
    'setup_scan_workspace', # 创建 Scan 根工作空间
    'setup_scan_directory', # 创建扫描子目录
    # 命令构建
    'build_scan_command', # 扫描工具命令构建(基于 f-string)
    # 命令执行
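Taken together, the two helpers re-exported here replace both the deleted Prefect workspace task and the per-flow setup functions; the flow-level usage shown in the hunks above reduces to:

from apps.scan.utils import setup_scan_workspace, setup_scan_directory

# initiate_scan_flow: create the scan workspace root
scan_workspace_path = setup_scan_workspace(scan_workspace_dir)

# each stage flow: create its own sub-directory under the workspace
port_scan_dir = setup_scan_directory(scan_workspace_dir, 'port_scan')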