optimize: performance optimization

This commit is contained in:
yyhuni
2026-01-10 09:44:43 +08:00
parent 592a1958c4
commit eba394e14e
17 changed files with 1493 additions and 1895 deletions

View File

@@ -10,30 +10,30 @@
- 配置由 YAML 解析
"""
# Django 环境初始化(导入即生效)
from apps.common.prefect_django_setup import setup_django_for_prefect
from prefect import flow
from prefect.task_runners import ThreadPoolTaskRunner
import hashlib
import logging
import os
import subprocess
from datetime import datetime
from pathlib import Path
from typing import List, Tuple
from apps.scan.tasks.directory_scan import (
export_sites_task,
run_and_stream_save_directories_task
)
from prefect import flow
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
on_scan_flow_running,
)
from apps.scan.tasks.directory_scan import (
export_sites_task,
run_and_stream_save_directories_task,
)
from apps.scan.utils import (
build_scan_command,
ensure_wordlist_local,
user_log,
wait_for_system_load,
)
from apps.scan.utils import config_parser, build_scan_command, ensure_wordlist_local, user_log
logger = logging.getLogger(__name__)
@@ -45,517 +45,343 @@ def calculate_directory_scan_timeout(
tool_config: dict,
base_per_word: float = 1.0,
min_timeout: int = 60,
max_timeout: int = 7200
) -> int:
"""
根据字典行数计算目录扫描超时时间
计算公式:超时时间 = 字典行数 × 每个单词基础时间
超时范围:60秒 ~ 2小时(7200秒)
超时范围:最小 60 秒,无上限
Args:
tool_config: 工具配置字典,包含 wordlist 路径
base_per_word: 每个单词的基础时间(秒),默认 1.0秒
min_timeout: 最小超时时间(秒),默认 60秒
max_timeout: 最大超时时间(秒),默认 7200秒(2小时)
Returns:
int: 计算出的超时时间(秒),范围 60 ~ 7200
Example:
# 1000行字典 × 1.0秒 = 1000秒 → 在 7200秒 上限内,取 1000秒
# 10000行字典 × 1.0秒 = 10000秒 → 限制为 7200秒(最大值)
timeout = calculate_directory_scan_timeout(
tool_config={'wordlist': '/path/to/wordlist.txt'}
)
int: 计算出的超时时间(秒)
"""
import os
wordlist_path = tool_config.get('wordlist')
if not wordlist_path:
logger.warning("工具配置中未指定 wordlist使用默认超时: %d", min_timeout)
return min_timeout
wordlist_path = os.path.expanduser(wordlist_path)
if not os.path.exists(wordlist_path):
logger.warning("字典文件不存在: %s,使用默认超时: %d", wordlist_path, min_timeout)
return min_timeout
try:
# 从 tool_config 中获取 wordlist 路径
wordlist_path = tool_config.get('wordlist')
if not wordlist_path:
logger.warning("工具配置中未指定 wordlist使用默认超时: %d", min_timeout)
return min_timeout
# 展开用户目录(~
wordlist_path = os.path.expanduser(wordlist_path)
# 检查文件是否存在
if not os.path.exists(wordlist_path):
logger.warning("字典文件不存在: %s,使用默认超时: %d", wordlist_path, min_timeout)
return min_timeout
# 使用 wc -l 快速统计字典行数
result = subprocess.run(
['wc', '-l', wordlist_path],
capture_output=True,
text=True,
check=True
)
# wc -l 输出格式:行数 + 空格 + 文件名
line_count = int(result.stdout.strip().split()[0])
# 计算超时时间
timeout = int(line_count * base_per_word)
# 设置合理的下限(不再设置上限)
timeout = max(min_timeout, timeout)
timeout = max(min_timeout, int(line_count * base_per_word))
logger.info(
"目录扫描超时计算 - 字典: %s, 行数: %d, 基础时间: %.3f秒/词, 计算超时: %d",
wordlist_path, line_count, base_per_word, timeout
)
return timeout
except subprocess.CalledProcessError as e:
logger.error("统计字典行数失败: %s", e)
# 失败时返回默认超时
return min_timeout
except (ValueError, IndexError) as e:
logger.error("解析字典行数失败: %s", e)
return min_timeout
except Exception as e:
logger.error("计算超时时间异常: %s", e)
except (subprocess.CalledProcessError, ValueError, IndexError) as e:
logger.error("计算超时时间失败: %s", e)
return min_timeout
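The rule above is simply wordlist lines × base_per_word with a 60-second floor; a hedged, self-contained sketch of the same arithmetic (hypothetical line counts, no `wc -l` call):

def _estimate_dir_scan_timeout(line_count: int, base_per_word: float = 1.0, min_timeout: int = 60) -> int:
    # Mirrors the formula documented above; not the real function, which shells out to `wc -l`.
    return max(min_timeout, int(line_count * base_per_word))

assert _estimate_dir_scan_timeout(30) == 60        # small wordlist -> 60s floor
assert _estimate_dir_scan_timeout(1000) == 1000    # 1000 lines x 1.0s = 1000s
assert _estimate_dir_scan_timeout(10000, base_per_word=0.5) == 5000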
def _get_max_workers(tool_config: dict, default: int = DEFAULT_MAX_WORKERS) -> int:
"""
从单个工具配置中获取 max_workers 参数
Args:
tool_config: 单个工具的配置字典,如 {'max_workers': 10, 'threads': 5, ...}
default: 默认值,默认为 5
Returns:
int: max_workers 值
"""
"""从单个工具配置中获取 max_workers 参数"""
if not isinstance(tool_config, dict):
return default
# 支持 max_workers 和 max-workers(YAML 中划线会被转换)
max_workers = tool_config.get('max_workers') or tool_config.get('max-workers')
if max_workers is not None and isinstance(max_workers, int) and max_workers > 0:
if isinstance(max_workers, int) and max_workers > 0:
return max_workers
return default
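A quick illustration of the lookup above; the tool config dicts are made up, and the sketch assumes this module's _get_max_workers and DEFAULT_MAX_WORKERS are in scope:

assert _get_max_workers({'max_workers': 10, 'threads': 5}) == 10
assert _get_max_workers({'max-workers': 8}) == 8                 # dash form from YAML
assert _get_max_workers({'threads': 5}) == DEFAULT_MAX_WORKERS   # no setting -> default
assert _get_max_workers('not-a-dict') == DEFAULT_MAX_WORKERS     # defensive fallback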
def _export_site_urls(target_id: int, target_name: str, directory_scan_dir: Path) -> tuple[str, int]:
def _export_site_urls(
target_id: int,
directory_scan_dir: Path
) -> Tuple[str, int]:
"""
导出目标下的所有站点 URL 到文件(支持懒加载)
导出目标下的所有站点 URL 到文件
Args:
target_id: 目标 ID
target_name: 目标名称(用于懒加载创建默认站点)
directory_scan_dir: 目录扫描目录
Returns:
tuple: (sites_file, site_count)
Raises:
ValueError: 站点数量为 0
"""
logger.info("Step 1: 导出目标的所有站点 URL")
sites_file = str(directory_scan_dir / 'sites.txt')
export_result = export_sites_task(
target_id=target_id,
output_file=sites_file,
batch_size=1000 # 每次读取 1000 条,优化内存占用
batch_size=1000
)
site_count = export_result['total_count']
logger.info(
"✓ 站点 URL 导出完成 - 文件: %s, 数量: %d",
export_result['output_file'],
site_count
)
if site_count == 0:
logger.warning("目标下没有站点,无法执行目录扫描")
# 不抛出异常,由上层决定如何处理
# raise ValueError("目标下没有站点,无法执行目录扫描")
return export_result['output_file'], site_count
def _run_scans_sequentially(
enabled_tools: dict,
sites_file: str,
directory_scan_dir: Path,
scan_id: int,
target_id: int,
site_count: int,
target_name: str
) -> tuple[int, int, list]:
"""
串行执行目录扫描任务(支持多工具)- 已废弃,保留用于兼容
Args:
enabled_tools: 启用的工具配置字典
sites_file: 站点文件路径
directory_scan_dir: 目录扫描目录
scan_id: 扫描任务 ID
target_id: 目标 ID
site_count: 站点数量
target_name: 目标名称(用于错误日志)
Returns:
tuple: (total_directories, processed_sites, failed_sites)
"""
# 读取站点列表
sites = []
with open(sites_file, 'r', encoding='utf-8') as f:
for line in f:
site_url = line.strip()
if site_url:
sites.append(site_url)
logger.info("准备扫描 %d 个站点,使用工具: %s", len(sites), ', '.join(enabled_tools.keys()))
total_directories = 0
processed_sites_set = set() # 使用 set 避免重复计数
failed_sites = []
# 遍历每个工具
for tool_name, tool_config in enabled_tools.items():
logger.info("="*60)
logger.info("使用工具: %s", tool_name)
logger.info("="*60)
# 如果配置了 wordlist_name,则先确保本地存在对应的字典文件(含 hash 校验)
wordlist_name = tool_config.get('wordlist_name')
if wordlist_name:
try:
local_wordlist_path = ensure_wordlist_local(wordlist_name)
tool_config['wordlist'] = local_wordlist_path
except Exception as exc:
logger.error("为工具 %s 准备字典失败: %s", tool_name, exc)
# 当前工具无法执行,将所有站点视为失败,继续下一个工具
failed_sites.extend(sites)
continue
# 逐个站点执行扫描
for idx, site_url in enumerate(sites, 1):
logger.info(
"[%d/%d] 开始扫描站点: %s (工具: %s)",
idx, len(sites), site_url, tool_name
)
# 使用统一的命令构建器
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='directory_scan',
command_params={
'url': site_url
},
tool_config=tool_config
)
except Exception as e:
logger.error(
"✗ [%d/%d] 构建 %s 命令失败: %s - 站点: %s",
idx, len(sites), tool_name, e, site_url
)
failed_sites.append(site_url)
continue
# 单个站点超时:从配置中获取(支持 'auto' 动态计算)
# ffuf 逐个站点扫描timeout 就是单个站点的超时时间
site_timeout = tool_config.get('timeout', 300)
if site_timeout == 'auto':
# 动态计算超时时间(基于字典行数)
site_timeout = calculate_directory_scan_timeout(tool_config)
logger.info(f"✓ 工具 {tool_name} 动态计算 timeout: {site_timeout}")
# 生成日志文件路径
from datetime import datetime
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = directory_scan_dir / f"{tool_name}_{timestamp}_{idx}.log"
try:
# 直接调用 task串行执行
result = run_and_stream_save_directories_task(
cmd=command,
tool_name=tool_name, # 新增:工具名称
scan_id=scan_id,
target_id=target_id,
site_url=site_url,
cwd=str(directory_scan_dir),
shell=True,
batch_size=1000,
timeout=site_timeout,
log_file=str(log_file) # 新增:日志文件路径
)
total_directories += result.get('created_directories', 0)
processed_sites_set.add(site_url) # 使用 set 记录成功的站点
logger.info(
"✓ [%d/%d] 站点扫描完成: %s - 发现 %d 个目录",
idx, len(sites), site_url,
result.get('created_directories', 0)
)
except subprocess.TimeoutExpired as exc:
# 超时异常单独处理
failed_sites.append(site_url)
logger.warning(
"⚠️ [%d/%d] 站点扫描超时: %s - 超时配置: %d\n"
"注意:超时前已解析的目录数据已保存到数据库,但扫描未完全完成。",
idx, len(sites), site_url, site_timeout
)
except Exception as exc:
# 其他异常
failed_sites.append(site_url)
logger.error(
"✗ [%d/%d] 站点扫描失败: %s - 错误: %s",
idx, len(sites), site_url, exc
)
# 每 10 个站点输出进度
if idx % 10 == 0:
logger.info(
"进度: %d/%d (%.1f%%) - 已发现 %d 个目录",
idx, len(sites), idx/len(sites)*100, total_directories
)
# 计算成功和失败的站点数
processed_count = len(processed_sites_set)
if failed_sites:
logger.warning(
"部分站点扫描失败: %d/%d",
len(failed_sites), len(sites)
)
logger.info(
"✓ 串行目录扫描执行完成 - 成功: %d/%d, 失败: %d, 总目录数: %d",
processed_count, len(sites), len(failed_sites), total_directories
)
return total_directories, processed_count, failed_sites
def _generate_log_filename(tool_name: str, site_url: str, directory_scan_dir: Path) -> Path:
"""
生成唯一的日志文件名
使用 URL 的 hash 确保并发时不会冲突
Args:
tool_name: 工具名称
site_url: 站点 URL
directory_scan_dir: 目录扫描目录
Returns:
Path: 日志文件路径
"""
url_hash = hashlib.md5(site_url.encode()).hexdigest()[:8]
def _generate_log_filename(
tool_name: str,
site_url: str,
directory_scan_dir: Path
) -> Path:
"""生成唯一的日志文件名(使用 URL 的 hash 确保并发时不会冲突)"""
url_hash = hashlib.md5(
site_url.encode(),
usedforsecurity=False
).hexdigest()[:8]
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
return directory_scan_dir / f"{tool_name}_{url_hash}_{timestamp}.log"
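Why the hash matters: two sites scanned by the same tool at the same instant still get distinct log files. A standalone sketch with hypothetical hosts:

import hashlib

a = hashlib.md5(b"https://a.example.com", usedforsecurity=False).hexdigest()[:8]
b = hashlib.md5(b"https://b.example.com", usedforsecurity=False).hexdigest()[:8]
assert a != b   # different URLs -> different ffuf_<hash>_<timestamp>.log names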
def _prepare_tool_wordlist(tool_name: str, tool_config: dict) -> bool:
"""准备工具的字典文件,返回是否成功"""
wordlist_name = tool_config.get('wordlist_name')
if not wordlist_name:
return True
try:
local_wordlist_path = ensure_wordlist_local(wordlist_name)
tool_config['wordlist'] = local_wordlist_path
return True
except Exception as exc:
logger.error("为工具 %s 准备字典失败: %s", tool_name, exc)
return False
def _build_scan_params(
tool_name: str,
tool_config: dict,
sites: List[str],
directory_scan_dir: Path,
site_timeout: int
) -> Tuple[List[dict], List[str]]:
"""构建所有站点的扫描参数,返回 (scan_params_list, failed_sites)"""
scan_params_list = []
failed_sites = []
for idx, site_url in enumerate(sites, 1):
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='directory_scan',
command_params={'url': site_url},
tool_config=tool_config
)
log_file = _generate_log_filename(tool_name, site_url, directory_scan_dir)
scan_params_list.append({
'idx': idx,
'site_url': site_url,
'command': command,
'log_file': str(log_file),
'timeout': site_timeout
})
except Exception as e:
logger.error(
"✗ [%d/%d] 构建 %s 命令失败: %s - 站点: %s",
idx, len(sites), tool_name, e, site_url
)
failed_sites.append(site_url)
return scan_params_list, failed_sites
def _execute_batch(
batch_params: List[dict],
tool_name: str,
scan_id: int,
target_id: int,
directory_scan_dir: Path,
total_sites: int
) -> Tuple[int, List[str]]:
"""执行一批扫描任务,返回 (directories_found, failed_sites)"""
directories_found = 0
failed_sites = []
# 提交任务
futures = []
for params in batch_params:
future = run_and_stream_save_directories_task.submit(
cmd=params['command'],
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
site_url=params['site_url'],
cwd=str(directory_scan_dir),
shell=True,
batch_size=1000,
timeout=params['timeout'],
log_file=params['log_file']
)
futures.append((params['idx'], params['site_url'], future))
# 等待结果
for idx, site_url, future in futures:
try:
result = future.result()
dirs_count = result.get('created_directories', 0)
directories_found += dirs_count
logger.info(
"✓ [%d/%d] 站点扫描完成: %s - 发现 %d 个目录",
idx, total_sites, site_url, dirs_count
)
except Exception as exc:
failed_sites.append(site_url)
if 'timeout' in str(exc).lower():
logger.warning(
"⚠️ [%d/%d] 站点扫描超时: %s - 错误: %s",
idx, total_sites, site_url, exc
)
else:
logger.error(
"✗ [%d/%d] 站点扫描失败: %s - 错误: %s",
idx, total_sites, site_url, exc
)
return directories_found, failed_sites
def _run_scans_concurrently(
enabled_tools: dict,
sites_file: str,
directory_scan_dir: Path,
scan_id: int,
target_id: int,
site_count: int,
target_name: str
) -> Tuple[int, int, List[str]]:
"""
并发执行目录扫描任务(使用 ThreadPoolTaskRunner
Args:
enabled_tools: 启用的工具配置字典
sites_file: 站点文件路径
directory_scan_dir: 目录扫描目录
scan_id: 扫描任务 ID
target_id: 目标 ID
site_count: 站点数量
target_name: 目标名称(用于错误日志)
并发执行目录扫描任务
Returns:
tuple: (total_directories, processed_sites, failed_sites)
"""
# 读取站点列表
sites: List[str] = []
with open(sites_file, 'r', encoding='utf-8') as f:
for line in f:
site_url = line.strip()
if site_url:
sites.append(site_url)
sites = [line.strip() for line in f if line.strip()]
if not sites:
logger.warning("站点列表为空")
return 0, 0, []
logger.info(
"准备并发扫描 %d 个站点,使用工具: %s",
len(sites), ', '.join(enabled_tools.keys())
)
total_directories = 0
processed_sites_count = 0
failed_sites: List[str] = []
# 遍历每个工具
for tool_name, tool_config in enabled_tools.items():
# 每个工具独立获取 max_workers 配置
max_workers = _get_max_workers(tool_config)
logger.info("="*60)
logger.info("=" * 60)
logger.info("使用工具: %s (并发模式, max_workers=%d)", tool_name, max_workers)
logger.info("="*60)
logger.info("=" * 60)
user_log(scan_id, "directory_scan", f"Running {tool_name}")
# 如果配置了 wordlist_name,则先确保本地存在对应的字典文件(含 hash 校验)
wordlist_name = tool_config.get('wordlist_name')
if wordlist_name:
try:
local_wordlist_path = ensure_wordlist_local(wordlist_name)
tool_config['wordlist'] = local_wordlist_path
except Exception as exc:
logger.error("为工具 %s 准备字典失败: %s", tool_name, exc)
# 当前工具无法执行,将所有站点视为失败,继续下一个工具
failed_sites.extend(sites)
continue
# 计算超时时间(所有站点共用)
# 准备字典文件
if not _prepare_tool_wordlist(tool_name, tool_config):
failed_sites.extend(sites)
continue
# 计算超时时间
site_timeout = tool_config.get('timeout', 300)
if site_timeout == 'auto':
site_timeout = calculate_directory_scan_timeout(tool_config)
logger.info(f"✓ 工具 {tool_name} 动态计算 timeout: {site_timeout}")
# 准备所有站点的扫描参数
scan_params_list = []
for idx, site_url in enumerate(sites, 1):
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='directory_scan',
command_params={'url': site_url},
tool_config=tool_config
)
log_file = _generate_log_filename(tool_name, site_url, directory_scan_dir)
scan_params_list.append({
'idx': idx,
'site_url': site_url,
'command': command,
'log_file': str(log_file),
'timeout': site_timeout
})
except Exception as e:
logger.error(
"✗ [%d/%d] 构建 %s 命令失败: %s - 站点: %s",
idx, len(sites), tool_name, e, site_url
)
failed_sites.append(site_url)
logger.info("✓ 工具 %s 动态计算 timeout: %d", tool_name, site_timeout)
# 构建扫描参数
scan_params_list, build_failed = _build_scan_params(
tool_name, tool_config, sites, directory_scan_dir, site_timeout
)
failed_sites.extend(build_failed)
if not scan_params_list:
logger.warning("没有有效的扫描任务")
continue
# ============================================================
# 分批执行策略:控制实际并发的 ffuf 进程数
# ============================================================
# 分批执行
total_tasks = len(scan_params_list)
logger.info("开始分批执行 %d 个扫描任务(每批 %d 个)...", total_tasks, max_workers)
# 进度里程碑跟踪
last_progress_percent = 0
tool_directories = 0
tool_processed = 0
batch_num = 0
for batch_start in range(0, total_tasks, max_workers):
batch_end = min(batch_start + max_workers, total_tasks)
batch_params = scan_params_list[batch_start:batch_end]
batch_num += 1
logger.info("执行第 %d 批任务(%d-%d/%d...", batch_num, batch_start + 1, batch_end, total_tasks)
# 提交当前批次的任务(非阻塞,立即返回 future
futures = []
for params in batch_params:
future = run_and_stream_save_directories_task.submit(
cmd=params['command'],
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
site_url=params['site_url'],
cwd=str(directory_scan_dir),
shell=True,
batch_size=1000,
timeout=params['timeout'],
log_file=params['log_file']
)
futures.append((params['idx'], params['site_url'], future))
# 等待当前批次所有任务完成(阻塞,确保本批完成后再启动下一批)
for idx, site_url, future in futures:
try:
result = future.result() # 阻塞等待单个任务完成
directories_found = result.get('created_directories', 0)
total_directories += directories_found
tool_directories += directories_found
processed_sites_count += 1
tool_processed += 1
logger.info(
"✓ [%d/%d] 站点扫描完成: %s - 发现 %d 个目录",
idx, len(sites), site_url, directories_found
)
except Exception as exc:
failed_sites.append(site_url)
if 'timeout' in str(exc).lower() or isinstance(exc, subprocess.TimeoutExpired):
logger.warning(
"⚠️ [%d/%d] 站点扫描超时: %s - 错误: %s",
idx, len(sites), site_url, exc
)
else:
logger.error(
"✗ [%d/%d] 站点扫描失败: %s - 错误: %s",
idx, len(sites), site_url, exc
)
batch_num = batch_start // max_workers + 1
logger.info(
"执行第 %d 批任务(%d-%d/%d...",
batch_num, batch_start + 1, batch_end, total_tasks
)
dirs_found, batch_failed = _execute_batch(
batch_params, tool_name, scan_id, target_id,
directory_scan_dir, len(sites)
)
total_directories += dirs_found
tool_directories += dirs_found
tool_processed += len(batch_params) - len(batch_failed)
processed_sites_count += len(batch_params) - len(batch_failed)
failed_sites.extend(batch_failed)
# 进度里程碑:每 20% 输出一次
current_progress = int((batch_end / total_tasks) * 100)
if current_progress >= last_progress_percent + 20:
user_log(scan_id, "directory_scan", f"Progress: {batch_end}/{total_tasks} sites scanned")
user_log(
scan_id, "directory_scan",
f"Progress: {batch_end}/{total_tasks} sites scanned"
)
last_progress_percent = (current_progress // 20) * 20
# 工具完成日志(开发者日志 + 用户日志)
logger.info(
"✓ 工具 %s 执行完成 - 已处理站点: %d/%d, 发现目录: %d",
tool_name, tool_processed, total_tasks, tool_directories
)
user_log(scan_id, "directory_scan", f"{tool_name} completed: found {tool_directories} directories")
# 输出汇总信息
if failed_sites:
logger.warning(
"部分站点扫描失败: %d/%d",
len(failed_sites), len(sites)
user_log(
scan_id, "directory_scan",
f"{tool_name} completed: found {tool_directories} directories"
)
if failed_sites:
logger.warning("部分站点扫描失败: %d/%d", len(failed_sites), len(sites))
logger.info(
"✓ 并发目录扫描执行完成 - 成功: %d/%d, 失败: %d, 总目录数: %d",
processed_sites_count, len(sites), len(failed_sites), total_directories
)
return total_directories, processed_sites_count, failed_sites
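The batching pattern used above (submit at most max_workers Prefect futures, block on the whole batch, then start the next one) shown in isolation. This is a hedged sketch: scan_one is a stand-in task, not the real run_and_stream_save_directories_task.

from prefect import flow, task
from prefect.task_runners import ThreadPoolTaskRunner

@task
def scan_one(site: str) -> int:
    return 1  # pretend exactly one directory was found for this site

@flow(task_runner=ThreadPoolTaskRunner(max_workers=3))
def batched_scan_demo(sites: list[str], max_workers: int = 3) -> int:
    total = 0
    for start in range(0, len(sites), max_workers):
        batch = sites[start:start + max_workers]
        futures = [scan_one.submit(s) for s in batch]   # non-blocking submit
        total += sum(f.result() for f in futures)       # block until this batch finishes
    return total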
@flow(
name="directory_scan",
name="directory_scan",
log_prints=True,
on_running=[on_scan_flow_running],
on_completion=[on_scan_flow_completed],
@@ -570,64 +396,31 @@ def directory_scan_flow(
) -> dict:
"""
目录扫描 Flow
主要功能:
1. 从 target 获取所有站点的 URL
2. 对每个站点 URL 执行目录扫描(支持 ffuf 等工具)
3. 流式保存扫描结果到数据库 Directory 表
工作流程:
Step 0: 创建工作目录
Step 1: 导出站点 URL 列表到文件(供扫描工具使用)
Step 2: 验证工具配置
Step 3: 并发执行扫描工具并实时保存结果(使用 ThreadPoolTaskRunner
ffuf 输出字段:
- url: 发现的目录/文件 URL
- length: 响应内容长度
- status: HTTP 状态码
- words: 响应内容单词数
- lines: 响应内容行数
- content_type: 内容类型
- duration: 请求耗时
Args:
scan_id: 扫描任务 ID
target_name: 目标名称
target_id: 目标 ID
scan_workspace_dir: 扫描工作空间目录
enabled_tools: 启用的工具配置字典
Returns:
dict: {
'success': bool,
'scan_id': int,
'target': str,
'scan_workspace_dir': str,
'sites_file': str,
'site_count': int,
'total_directories': int, # 发现的总目录数
'processed_sites': int, # 成功处理的站点数
'failed_sites_count': int, # 失败的站点数
'executed_tasks': list
}
Raises:
ValueError: 参数错误
RuntimeError: 执行失败
dict: 扫描结果
"""
try:
wait_for_system_load(context="directory_scan_flow")
logger.info(
"="*60 + "\n" +
"开始目录扫描\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
"开始目录扫描 - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
)
user_log(scan_id, "directory_scan", "Starting directory scan")
# 参数验证
if scan_id is None:
raise ValueError("scan_id 不能为空")
@@ -639,14 +432,14 @@ def directory_scan_flow(
raise ValueError("scan_workspace_dir 不能为空")
if not enabled_tools:
raise ValueError("enabled_tools 不能为空")
# Step 0: 创建工作目录
from apps.scan.utils import setup_scan_directory
directory_scan_dir = setup_scan_directory(scan_workspace_dir, 'directory_scan')
# Step 1: 导出站点 URL(支持懒加载)
sites_file, site_count = _export_site_urls(target_id, target_name, directory_scan_dir)
# Step 1: 导出站点 URL
sites_file, site_count = _export_site_urls(target_id, directory_scan_dir)
if site_count == 0:
logger.warning("跳过目录扫描:没有站点可扫描 - Scan ID: %s", scan_id)
user_log(scan_id, "directory_scan", "Skipped: no sites to scan", "warning")
@@ -662,16 +455,16 @@ def directory_scan_flow(
'failed_sites_count': 0,
'executed_tasks': ['export_sites']
}
# Step 2: 工具配置信息
logger.info("Step 2: 工具配置信息")
tool_info = []
for tool_name, tool_config in enabled_tools.items():
mw = _get_max_workers(tool_config)
tool_info.append(f"{tool_name}(max_workers={mw})")
tool_info = [
f"{name}(max_workers={_get_max_workers(cfg)})"
for name, cfg in enabled_tools.items()
]
logger.info("✓ 启用工具: %s", ', '.join(tool_info))
# Step 3: 并发执行扫描工具并实时保存结果
# Step 3: 并发执行扫描
logger.info("Step 3: 并发执行扫描工具并实时保存结果")
total_directories, processed_sites, failed_sites = _run_scans_concurrently(
enabled_tools=enabled_tools,
@@ -679,19 +472,20 @@ def directory_scan_flow(
directory_scan_dir=directory_scan_dir,
scan_id=scan_id,
target_id=target_id,
site_count=site_count,
target_name=target_name
)
# 检查是否所有站点都失败
if processed_sites == 0 and site_count > 0:
logger.warning("所有站点扫描均失败 - 总站点数: %d, 失败数: %d", site_count, len(failed_sites))
# 不抛出异常,让扫描继续
# 记录 Flow 完成
logger.warning(
"所有站点扫描均失败 - 总站点数: %d, 失败数: %d",
site_count, len(failed_sites)
)
logger.info("✓ 目录扫描完成 - 发现目录: %d", total_directories)
user_log(scan_id, "directory_scan", f"directory_scan completed: found {total_directories} directories")
user_log(
scan_id, "directory_scan",
f"directory_scan completed: found {total_directories} directories"
)
return {
'success': True,
'scan_id': scan_id,
@@ -704,7 +498,7 @@ def directory_scan_flow(
'failed_sites_count': len(failed_sites),
'executed_tasks': ['export_sites', 'run_and_stream_save_directories']
}
except Exception as e:
logger.exception("目录扫描失败: %s", e)
raise
raise

View File

@@ -10,26 +10,22 @@
- 流式处理输出,批量更新数据库
"""
# Django 环境初始化(导入即生效)
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging
import os
from datetime import datetime
from pathlib import Path
from prefect import flow
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
on_scan_flow_running,
)
from apps.scan.tasks.fingerprint_detect import (
export_urls_for_fingerprint_task,
run_xingfinger_and_stream_update_tech_task,
)
from apps.scan.utils import build_scan_command, user_log
from apps.scan.utils import build_scan_command, user_log, wait_for_system_load
from apps.scan.utils.fingerprint_helpers import get_fingerprint_paths
logger = logging.getLogger(__name__)
@@ -42,22 +38,19 @@ def calculate_fingerprint_detect_timeout(
) -> int:
"""
根据 URL 数量计算超时时间
公式:超时时间 = URL 数量 × 每 URL 基础时间
最小值:300秒
无上限
最小值:300秒,无上限
Args:
url_count: URL 数量
base_per_url: 每 URL 基础时间(秒),默认 10秒
min_timeout: 最小超时时间(秒),默认 300秒
Returns:
int: 计算出的超时时间(秒)
"""
timeout = int(url_count * base_per_url)
return max(min_timeout, timeout)
return max(min_timeout, int(url_count * base_per_url))
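Worked numbers for the formula above, assuming the documented defaults (10 seconds per URL, 300-second floor):

assert calculate_fingerprint_detect_timeout(url_count=500) == 5000   # 500 x 10s
assert calculate_fingerprint_detect_timeout(url_count=10) == 300     # 100s -> floored to 300s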
@@ -70,17 +63,17 @@ def _export_urls(
) -> tuple[str, int]:
"""
导出 URL 到文件
Args:
target_id: 目标 ID
fingerprint_dir: 指纹识别目录
source: 数据源类型
Returns:
tuple: (urls_file, total_count)
"""
logger.info("Step 1: 导出 URL 列表 (source=%s)", source)
urls_file = str(fingerprint_dir / 'urls.txt')
export_result = export_urls_for_fingerprint_task(
target_id=target_id,
@@ -88,15 +81,14 @@ def _export_urls(
source=source,
batch_size=1000
)
total_count = export_result['total_count']
logger.info(
"✓ URL 导出完成 - 文件: %s, 数量: %d",
export_result['output_file'],
total_count
)
return export_result['output_file'], total_count
@@ -111,7 +103,7 @@ def _run_fingerprint_detect(
) -> tuple[dict, list]:
"""
执行指纹识别任务
Args:
enabled_tools: 已启用的工具配置字典
urls_file: URL 文件路径
@@ -120,56 +112,54 @@ def _run_fingerprint_detect(
scan_id: 扫描任务 ID
target_id: 目标 ID
source: 数据源类型
Returns:
tuple: (tool_stats, failed_tools)
"""
tool_stats = {}
failed_tools = []
for tool_name, tool_config in enabled_tools.items():
# 1. 获取指纹库路径
lib_names = tool_config.get('fingerprint_libs', ['ehole'])
fingerprint_paths = get_fingerprint_paths(lib_names)
if not fingerprint_paths:
reason = f"没有可用的指纹库: {lib_names}"
logger.warning(reason)
failed_tools.append({'tool': tool_name, 'reason': reason})
continue
# 2. 将指纹库路径合并到 tool_config用于命令构建
tool_config_with_paths = {**tool_config, **fingerprint_paths}
# 3. 构建命令
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='fingerprint_detect',
command_params={
'urls_file': urls_file
},
command_params={'urls_file': urls_file},
tool_config=tool_config_with_paths
)
except Exception as e:
reason = f"命令构建失败: {str(e)}"
reason = f"命令构建失败: {e}"
logger.error("构建 %s 命令失败: %s", tool_name, e)
failed_tools.append({'tool': tool_name, 'reason': reason})
continue
# 4. 计算超时时间
timeout = calculate_fingerprint_detect_timeout(url_count)
# 5. 生成日志文件路径
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = fingerprint_dir / f"{tool_name}_{timestamp}.log"
logger.info(
"开始执行 %s 指纹识别 - URL数: %d, 超时: %ds, 指纹库: %s",
tool_name, url_count, timeout, list(fingerprint_paths.keys())
)
user_log(scan_id, "fingerprint_detect", f"Running {tool_name}: {command}")
# 6. 执行扫描任务
try:
result = run_xingfinger_and_stream_update_tech_task(
@@ -183,14 +173,14 @@ def _run_fingerprint_detect(
log_file=str(log_file),
batch_size=100
)
tool_stats[tool_name] = {
'command': command,
'result': result,
'timeout': timeout,
'fingerprint_libs': list(fingerprint_paths.keys())
}
tool_updated = result.get('updated_count', 0)
logger.info(
"✓ 工具 %s 执行完成 - 处理记录: %d, 更新: %d, 未找到: %d",
@@ -199,20 +189,23 @@ def _run_fingerprint_detect(
tool_updated,
result.get('not_found_count', 0)
)
user_log(scan_id, "fingerprint_detect", f"{tool_name} completed: identified {tool_updated} fingerprints")
user_log(
scan_id, "fingerprint_detect",
f"{tool_name} completed: identified {tool_updated} fingerprints"
)
except Exception as exc:
reason = str(exc)
failed_tools.append({'tool': tool_name, 'reason': reason})
logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True)
user_log(scan_id, "fingerprint_detect", f"{tool_name} failed: {reason}", "error")
if failed_tools:
logger.warning(
"以下指纹识别工具执行失败: %s",
', '.join([f['tool'] for f in failed_tools])
)
return tool_stats, failed_tools
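The config enrichment at the top of the per-tool loop is a plain dict merge; a tiny hedged illustration (the ehole_path key and path are made up — the real keys come from get_fingerprint_paths):

tool_config = {'timeout': 600, 'fingerprint_libs': ['ehole']}
fingerprint_paths = {'ehole_path': '/opt/fingerprints/ehole.json'}   # hypothetical key/path
tool_config_with_paths = {**tool_config, **fingerprint_paths}
assert tool_config_with_paths['timeout'] == 600                      # tool's own keys preserved
assert tool_config_with_paths['ehole_path'].endswith('ehole.json')   # library paths added for the command builder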
@@ -232,53 +225,38 @@ def fingerprint_detect_flow(
) -> dict:
"""
指纹识别 Flow
主要功能:
1. 从数据库导出目标下所有 WebSite URL 到文件
2. 使用 xingfinger 进行技术栈识别
3. 解析结果并更新 WebSite.tech 字段(合并去重)
工作流程:
Step 0: 创建工作目录
Step 1: 导出 URL 列表
Step 2: 解析配置,获取启用的工具
Step 3: 执行 xingfinger 并解析结果
Args:
scan_id: 扫描任务 ID
target_name: 目标名称
target_id: 目标 ID
scan_workspace_dir: 扫描工作空间目录
enabled_tools: 启用的工具配置xingfinger
Returns:
dict: {
'success': bool,
'scan_id': int,
'target': str,
'scan_workspace_dir': str,
'urls_file': str,
'url_count': int,
'processed_records': int,
'updated_count': int,
'created_count': int,
'snapshot_count': int,
'executed_tasks': list,
'tool_stats': dict
}
dict: 扫描结果
"""
try:
# 负载检查:等待系统资源充足
wait_for_system_load(context="fingerprint_detect_flow")
logger.info(
"="*60 + "\n" +
"开始指纹识别\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
"开始指纹识别 - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
)
user_log(scan_id, "fingerprint_detect", "Starting fingerprint detection")
# 参数验证
if scan_id is None:
raise ValueError("scan_id 不能为空")
@@ -288,46 +266,26 @@ def fingerprint_detect_flow(
raise ValueError("target_id 不能为空")
if not scan_workspace_dir:
raise ValueError("scan_workspace_dir 不能为空")
# 数据源类型(当前只支持 website)
source = 'website'
# Step 0: 创建工作目录
from apps.scan.utils import setup_scan_directory
fingerprint_dir = setup_scan_directory(scan_workspace_dir, 'fingerprint_detect')
# Step 1: 导出 URL(支持懒加载)
urls_file, url_count = _export_urls(target_id, fingerprint_dir, source)
if url_count == 0:
logger.warning("跳过指纹识别:没有 URL 可扫描 - Scan ID: %s", scan_id)
user_log(scan_id, "fingerprint_detect", "Skipped: no URLs to scan", "warning")
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'scan_workspace_dir': scan_workspace_dir,
'urls_file': urls_file,
'url_count': 0,
'processed_records': 0,
'updated_count': 0,
'created_count': 0,
'snapshot_count': 0,
'executed_tasks': ['export_urls_for_fingerprint'],
'tool_stats': {
'total': 0,
'successful': 0,
'failed': 0,
'successful_tools': [],
'failed_tools': [],
'details': {}
}
}
return _build_empty_result(scan_id, target_name, scan_workspace_dir, urls_file)
# Step 2: 工具配置信息
logger.info("Step 2: 工具配置信息")
logger.info("✓ 启用工具: %s", ', '.join(enabled_tools.keys()))
# Step 3: 执行指纹识别
logger.info("Step 3: 执行指纹识别")
tool_stats, failed_tools = _run_fingerprint_detect(
@@ -339,24 +297,37 @@ def fingerprint_detect_flow(
target_id=target_id,
source=source
)
# 动态生成已执行的任务列表
executed_tasks = ['export_urls_for_fingerprint']
executed_tasks.extend([f'run_xingfinger ({tool})' for tool in tool_stats.keys()])
executed_tasks.extend([f'run_xingfinger ({tool})' for tool in tool_stats])
# 汇总所有工具的结果
total_processed = sum(stats['result'].get('processed_records', 0) for stats in tool_stats.values())
total_updated = sum(stats['result'].get('updated_count', 0) for stats in tool_stats.values())
total_created = sum(stats['result'].get('created_count', 0) for stats in tool_stats.values())
total_snapshots = sum(stats['result'].get('snapshot_count', 0) for stats in tool_stats.values())
total_processed = sum(
stats['result'].get('processed_records', 0) for stats in tool_stats.values()
)
total_updated = sum(
stats['result'].get('updated_count', 0) for stats in tool_stats.values()
)
total_created = sum(
stats['result'].get('created_count', 0) for stats in tool_stats.values()
)
total_snapshots = sum(
stats['result'].get('snapshot_count', 0) for stats in tool_stats.values()
)
# 记录 Flow 完成
logger.info("✓ 指纹识别完成 - 识别指纹: %d", total_updated)
user_log(scan_id, "fingerprint_detect", f"fingerprint_detect completed: identified {total_updated} fingerprints")
successful_tools = [name for name in enabled_tools.keys()
if name not in [f['tool'] for f in failed_tools]]
user_log(
scan_id, "fingerprint_detect",
f"fingerprint_detect completed: identified {total_updated} fingerprints"
)
successful_tools = [
name for name in enabled_tools
if name not in [f['tool'] for f in failed_tools]
]
return {
'success': True,
'scan_id': scan_id,
@@ -378,7 +349,7 @@ def fingerprint_detect_flow(
'details': tool_stats
}
}
except ValueError as e:
logger.error("配置错误: %s", e)
raise
@@ -388,3 +359,33 @@ def fingerprint_detect_flow(
except Exception as e:
logger.exception("指纹识别失败: %s", e)
raise
def _build_empty_result(
scan_id: int,
target_name: str,
scan_workspace_dir: str,
urls_file: str
) -> dict:
"""构建空结果(无 URL 可扫描时)"""
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'scan_workspace_dir': scan_workspace_dir,
'urls_file': urls_file,
'url_count': 0,
'processed_records': 0,
'updated_count': 0,
'created_count': 0,
'snapshot_count': 0,
'executed_tasks': ['export_urls_for_fingerprint'],
'tool_stats': {
'total': 0,
'successful': 0,
'failed': 0,
'successful_tools': [],
'failed_tools': [],
'details': {}
}
}

View File

@@ -1,4 +1,4 @@
"""
"""
端口扫描 Flow
负责编排端口扫描的完整流程
@@ -10,25 +10,23 @@
- 配置由 YAML 解析
"""
# Django 环境初始化(导入即生效)
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging
import os
import subprocess
from datetime import datetime
from pathlib import Path
from typing import Callable
from prefect import flow
from apps.scan.tasks.port_scan import (
export_hosts_task,
run_and_stream_save_ports_task
)
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
on_scan_flow_running,
)
from apps.scan.utils import config_parser, build_scan_command, user_log
from apps.scan.tasks.port_scan import (
export_hosts_task,
run_and_stream_save_ports_task,
)
from apps.scan.utils import build_scan_command, user_log, wait_for_system_load
logger = logging.getLogger(__name__)
@@ -40,28 +38,19 @@ def calculate_port_scan_timeout(
) -> int:
"""
根据目标数量和端口数量计算超时时间
计算公式:超时时间 = 目标数 × 端口数 × base_per_pair
超时范围:60秒 ~ 2天(172800秒)
超时范围:60秒 ~ 无上限
Args:
tool_config: 工具配置字典,包含端口配置(ports, top-ports 等)
file_path: 目标文件路径(域名/IP列表
base_per_pair: 每个"端口-目标对"的基础时间(秒),默认 0.5秒
Returns:
int: 计算出的超时时间(秒),范围 60 ~ 172800
Example:
# 100个目标 × 100个端口 × 0.5秒 = 5000秒
# 10个目标 × 1000个端口 × 0.5秒 = 5000秒
timeout = calculate_port_scan_timeout(
tool_config={'top-ports': 100},
file_path='/path/to/domains.txt'
)
int: 计算出的超时时间(秒),最小 60 秒
"""
try:
# 1. 统计目标数量
result = subprocess.run(
['wc', '-l', file_path],
capture_output=True,
@@ -69,88 +58,74 @@ def calculate_port_scan_timeout(
check=True
)
target_count = int(result.stdout.strip().split()[0])
# 2. 解析端口数量
port_count = _parse_port_count(tool_config)
# 3. 计算超时时间
# 总工作量 = 目标数 × 端口数
total_work = target_count * port_count
timeout = int(total_work * base_per_pair)
# 4. 设置合理的下限(不再设置上限)
min_timeout = 60 # 最小 60 秒
timeout = max(min_timeout, timeout)
timeout = max(60, int(total_work * base_per_pair))
logger.info(
f"计算端口扫描 timeout - "
f"目标数: {target_count}, "
f"端口数: {port_count}, "
f"总工作量: {total_work}, "
f"超时: {timeout}"
"计算端口扫描 timeout - 目标数: %d, 端口数: %d, 总工作量: %d, 超时: %d",
target_count, port_count, total_work, timeout
)
return timeout
except Exception as e:
logger.warning(f"计算 timeout 失败: {e},使用默认值 600秒")
logger.warning("计算 timeout 失败: %s,使用默认值 600秒", e)
return 600
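Worked arithmetic for the targets × ports × base_per_pair rule above (numbers are illustrative; the real function counts targets with `wc -l`):

assert max(60, int(100 * 100 * 0.5)) == 5000   # 100 targets x 100 ports x 0.5s
assert max(60, int(10 * 1000 * 0.5)) == 5000   # 10 targets x 1000 ports x 0.5s
assert max(60, int(3 * 10 * 0.5)) == 60        # tiny scans are floored at 60s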
def _parse_port_count(tool_config: dict) -> int:
"""
从工具配置中解析端口数量
优先级:
1. top-ports: N → 返回 N
2. ports: "80,443,8080" → 返回逗号分隔的数量
3. ports: "1-1000" → 返回范围的大小
4. ports: "1-65535" → 返回 65535
5. 默认 → 返回 100naabu 默认扫描 top 100
Args:
tool_config: 工具配置字典
Returns:
int: 端口数量
"""
# 1. 检查 top-ports 配置
# 检查 top-ports 配置
if 'top-ports' in tool_config:
top_ports = tool_config['top-ports']
if isinstance(top_ports, int) and top_ports > 0:
return top_ports
logger.warning(f"top-ports 配置无效: {top_ports},使用默认值")
# 2. 检查 ports 配置
logger.warning("top-ports 配置无效: %s,使用默认值", top_ports)
# 检查 ports 配置
if 'ports' in tool_config:
ports_str = str(tool_config['ports']).strip()
# 2.1 逗号分隔的端口列表(80,443,8080)
# 逗号分隔的端口列表(80,443,8080)
if ',' in ports_str:
port_list = [p.strip() for p in ports_str.split(',') if p.strip()]
return len(port_list)
# 2.2 端口范围(1-1000)
return len([p.strip() for p in ports_str.split(',') if p.strip()])
# 端口范围(1-1000)
if '-' in ports_str:
try:
start, end = ports_str.split('-', 1)
start_port = int(start.strip())
end_port = int(end.strip())
if 1 <= start_port <= end_port <= 65535:
return end_port - start_port + 1
logger.warning(f"端口范围无效: {ports_str},使用默认值")
logger.warning("端口范围无效: %s,使用默认值", ports_str)
except ValueError:
logger.warning(f"端口范围解析失败: {ports_str},使用默认值")
# 2.3 单个端口
logger.warning("端口范围解析失败: %s,使用默认值", ports_str)
# 单个端口
try:
port = int(ports_str)
if 1 <= port <= 65535:
return 1
except ValueError:
logger.warning(f"端口配置解析失败: {ports_str},使用默认值")
# 3. 默认值naabu 默认扫描 top 100 端口
logger.warning("端口配置解析失败: %s,使用默认值", ports_str)
# 默认值naabu 默认扫描 top 100 端口
return 100
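Each branch of the parser above, exercised with made-up naabu-style configs:

assert _parse_port_count({'top-ports': 1000}) == 1000
assert _parse_port_count({'ports': '80,443,8080'}) == 3
assert _parse_port_count({'ports': '1-1000'}) == 1000
assert _parse_port_count({'ports': '8080'}) == 1
assert _parse_port_count({}) == 100   # naabu default: top 100 ports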
@@ -160,41 +135,38 @@ def _parse_port_count(tool_config: dict) -> int:
def _export_hosts(target_id: int, port_scan_dir: Path) -> tuple[str, int, str]:
"""
导出主机列表到文件
根据 Target 类型自动决定导出内容:
- DOMAIN: 从 Subdomain 表导出子域名
- IP: 直接写入 target.name
- CIDR: 展开 CIDR 范围内的所有 IP
Args:
target_id: 目标 ID
port_scan_dir: 端口扫描目录
Returns:
tuple: (hosts_file, host_count, target_type)
"""
logger.info("Step 1: 导出主机列表")
hosts_file = str(port_scan_dir / 'hosts.txt')
export_result = export_hosts_task(
target_id=target_id,
output_file=hosts_file,
batch_size=1000 # 每次读取 1000 条,优化内存占用
)
host_count = export_result['total_count']
target_type = export_result.get('target_type', 'unknown')
logger.info(
"✓ 主机列表导出完成 - 类型: %s, 文件: %s, 数量: %d",
target_type,
export_result['output_file'],
host_count
target_type, export_result['output_file'], host_count
)
if host_count == 0:
logger.warning("目标下没有可扫描的主机,无法执行端口扫描")
return export_result['output_file'], host_count, target_type
@@ -208,7 +180,7 @@ def _run_scans_sequentially(
) -> tuple[dict, int, list, list]:
"""
串行执行端口扫描任务
Args:
enabled_tools: 已启用的工具配置字典
domains_file: 域名文件路径
@@ -216,72 +188,56 @@ def _run_scans_sequentially(
scan_id: 扫描任务 ID
target_id: 目标 ID
target_name: 目标名称(用于错误日志)
Returns:
tuple: (tool_stats, processed_records, successful_tool_names, failed_tools)
注意:端口扫描是流式输出,不生成结果文件
Raises:
RuntimeError: 所有工具均失败
"""
# ==================== 构建命令并串行执行 ====================
tool_stats = {}
processed_records = 0
failed_tools = [] # 记录失败的工具(含原因)
# for循环执行工具按顺序串行运行每个启用的端口扫描工具
failed_tools = []
for tool_name, tool_config in enabled_tools.items():
# 1. 构建完整命令(变量替换)
# 构建命令
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='port_scan',
command_params={
'domains_file': domains_file # 对应 {domains_file}
},
tool_config=tool_config #yaml的工具配置
command_params={'domains_file': domains_file},
tool_config=tool_config
)
except Exception as e:
reason = f"命令构建失败: {str(e)}"
logger.error(f"构建 {tool_name} 命令失败: {e}")
reason = f"命令构建失败: {e}"
logger.error("构建 %s 命令失败: %s", tool_name, e)
failed_tools.append({'tool': tool_name, 'reason': reason})
continue
# 2. 获取超时时间(支持 'auto' 动态计算)
# 获取超时时间
config_timeout = tool_config['timeout']
if config_timeout == 'auto':
# 动态计算超时时间
config_timeout = calculate_port_scan_timeout(
tool_config=tool_config,
file_path=str(domains_file)
)
logger.info(f"✓ 工具 {tool_name} 动态计算 timeout: {config_timeout}")
# 2.1 生成日志文件路径
from datetime import datetime
config_timeout = calculate_port_scan_timeout(tool_config, str(domains_file))
logger.info("✓ 工具 %s 动态计算 timeout: %d", tool_name, config_timeout)
# 生成日志文件路径
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = port_scan_dir / f"{tool_name}_{timestamp}.log"
# 3. 执行扫描任务
logger.info("开始执行 %s 扫描(超时: %d秒)...", tool_name, config_timeout)
user_log(scan_id, "port_scan", f"Running {tool_name}: {command}")
# 执行扫描任务
try:
# 直接调用 task串行执行
# 注意:端口扫描是流式输出到 stdout不使用 output_file
result = run_and_stream_save_ports_task(
cmd=command,
tool_name=tool_name, # 工具名称
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
cwd=str(port_scan_dir),
shell=True,
batch_size=1000,
timeout=config_timeout,
log_file=str(log_file) # 新增:日志文件路径
log_file=str(log_file)
)
tool_stats[tool_name] = {
'command': command,
'result': result,
@@ -289,15 +245,10 @@ def _run_scans_sequentially(
}
tool_records = result.get('processed_records', 0)
processed_records += tool_records
logger.info(
"✓ 工具 %s 流式处理完成 - 记录数: %d",
tool_name, tool_records
)
logger.info("✓ 工具 %s 流式处理完成 - 记录数: %d", tool_name, tool_records)
user_log(scan_id, "port_scan", f"{tool_name} completed: found {tool_records} ports")
except subprocess.TimeoutExpired as exc:
# 超时异常单独处理
# 注意:流式处理任务超时时,已解析的数据已保存到数据库
except subprocess.TimeoutExpired:
reason = f"timeout after {config_timeout}s"
failed_tools.append({'tool': tool_name, 'reason': reason})
logger.warning(
@@ -307,40 +258,39 @@ def _run_scans_sequentially(
)
user_log(scan_id, "port_scan", f"{tool_name} failed: {reason}", "error")
except Exception as exc:
# 其他异常
reason = str(exc)
failed_tools.append({'tool': tool_name, 'reason': reason})
logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True)
user_log(scan_id, "port_scan", f"{tool_name} failed: {reason}", "error")
if failed_tools:
logger.warning(
"以下扫描工具执行失败: %s",
', '.join([f['tool'] for f in failed_tools])
)
if not tool_stats:
error_details = "; ".join([f"{f['tool']}: {f['reason']}" for f in failed_tools])
logger.warning("所有端口扫描工具均失败 - 目标: %s, 失败工具: %s", target_name, error_details)
# 返回空结果,不抛出异常,让扫描继续
return {}, 0, [], failed_tools
# 动态计算成功的工具列表
successful_tool_names = [name for name in enabled_tools.keys()
if name not in [f['tool'] for f in failed_tools]]
successful_tool_names = [
name for name in enabled_tools
if name not in [f['tool'] for f in failed_tools]
]
logger.info(
"✓ 串行端口扫描执行完成 - 成功: %d/%d (成功: %s, 失败: %s)",
len(tool_stats), len(enabled_tools),
', '.join(successful_tool_names) if successful_tool_names else '',
', '.join([f['tool'] for f in failed_tools]) if failed_tools else ''
)
return tool_stats, processed_records, successful_tool_names, failed_tools
@flow(
name="port_scan",
name="port_scan",
log_prints=True,
on_running=[on_scan_flow_running],
on_completion=[on_scan_flow_completed],
@@ -355,19 +305,19 @@ def port_scan_flow(
) -> dict:
"""
端口扫描 Flow
主要功能:
1. 扫描目标域名/IP 的开放端口
2. 保存 host + ip + port 三元映射到 HostPortMapping 表
输出资产:
- HostPortMapping:主机端口映射(host + ip + port 三元组)
工作流程:
Step 0: 创建工作目录
Step 1: 导出域名列表到文件(供扫描工具使用)
Step 2: 解析配置,获取启用的工具
Step 3: 串行执行扫描工具,运行端口扫描工具并实时解析输出到数据库(→ HostPortMapping
Step 3: 串行执行扫描工具,运行端口扫描工具并实时解析输出到数据库
Args:
scan_id: 扫描任务 ID
@@ -377,35 +327,15 @@ def port_scan_flow(
enabled_tools: 启用的工具配置字典
Returns:
dict: {
'success': bool,
'scan_id': int,
'target': str,
'scan_workspace_dir': str,
'hosts_file': str,
'host_count': int,
'processed_records': int,
'executed_tasks': list,
'tool_stats': {
'total': int, # 总工具数
'successful': int, # 成功工具数
'failed': int, # 失败工具数
'successful_tools': list[str], # 成功工具列表 ['naabu_active']
'failed_tools': list[dict], # 失败工具列表 [{'tool': 'naabu_passive', 'reason': '超时'}]
'details': dict # 详细执行结果(保留向后兼容)
}
}
dict: 扫描结果
Raises:
ValueError: 配置错误
RuntimeError: 执行失败
Note:
端口扫描工具(如 naabu)会解析域名获取 IP,输出 host + ip + port 三元组。
同一 host 可能对应多个 IP(CDN、负载均衡),因此使用三元映射表存储。
"""
try:
# 参数验证
wait_for_system_load(context="port_scan_flow")
if scan_id is None:
raise ValueError("scan_id 不能为空")
if not target_name:
@@ -416,25 +346,20 @@ def port_scan_flow(
raise ValueError("scan_workspace_dir 不能为空")
if not enabled_tools:
raise ValueError("enabled_tools 不能为空")
logger.info(
"="*60 + "\n" +
"开始端口扫描\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
"开始端口扫描 - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
)
user_log(scan_id, "port_scan", "Starting port scan")
# Step 0: 创建工作目录
from apps.scan.utils import setup_scan_directory
port_scan_dir = setup_scan_directory(scan_workspace_dir, 'port_scan')
# Step 1: 导出主机列表到文件(根据 Target 类型自动决定内容)
# Step 1: 导出主机列表
hosts_file, host_count, target_type = _export_hosts(target_id, port_scan_dir)
if host_count == 0:
logger.warning("跳过端口扫描:没有主机可扫描 - Scan ID: %s", scan_id)
user_log(scan_id, "port_scan", "Skipped: no hosts to scan", "warning")
@@ -457,14 +382,11 @@ def port_scan_flow(
'details': {}
}
}
# Step 2: 工具配置信息
logger.info("Step 2: 工具配置信息")
logger.info(
"✓ 启用工具: %s",
', '.join(enabled_tools.keys())
)
logger.info("✓ 启用工具: %s", ', '.join(enabled_tools.keys()))
# Step 3: 串行执行扫描工具
logger.info("Step 3: 串行执行扫描工具")
tool_stats, processed_records, successful_tool_names, failed_tools = _run_scans_sequentially(
@@ -475,15 +397,13 @@ def port_scan_flow(
target_id=target_id,
target_name=target_name
)
# 记录 Flow 完成
logger.info("✓ 端口扫描完成 - 发现端口: %d", processed_records)
user_log(scan_id, "port_scan", f"port_scan completed: found {processed_records} ports")
# 动态生成已执行的任务列表
executed_tasks = ['export_hosts', 'parse_config']
executed_tasks.extend([f'run_and_stream_save_ports ({tool})' for tool in tool_stats.keys()])
executed_tasks.extend([f'run_and_stream_save_ports ({tool})' for tool in tool_stats])
return {
'success': True,
'scan_id': scan_id,

View File

@@ -11,43 +11,33 @@
2. Provider 模式:使用 TargetProvider 从任意数据源获取 URL
"""
# Django 环境初始化
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging
from pathlib import Path
from typing import Optional
from prefect import flow
from apps.scan.tasks.screenshot import capture_screenshots_task
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
)
from apps.scan.utils import user_log
from apps.scan.services.target_export_service import (
get_urls_with_fallback,
DataSource,
on_scan_flow_running,
)
from apps.scan.providers import TargetProvider
from apps.scan.services.target_export_service import DataSource, get_urls_with_fallback
from apps.scan.tasks.screenshot import capture_screenshots_task
from apps.scan.utils import user_log, wait_for_system_load
logger = logging.getLogger(__name__)
# URL 来源到 DataSource 的映射
_SOURCE_MAPPING = {
'websites': DataSource.WEBSITE,
'endpoints': DataSource.ENDPOINT,
}
def _parse_screenshot_config(enabled_tools: dict) -> dict:
"""
解析截图配置
Args:
enabled_tools: 启用的工具配置
Returns:
截图配置字典
"""
# 从 enabled_tools 中获取 playwright 配置
"""解析截图配置"""
playwright_config = enabled_tools.get('playwright', {})
return {
'concurrency': playwright_config.get('concurrency', 5),
'url_sources': playwright_config.get('url_sources', ['websites'])
@@ -55,33 +45,54 @@ def _parse_screenshot_config(enabled_tools: dict) -> dict:
def _map_url_sources_to_data_sources(url_sources: list[str]) -> list[str]:
"""
将配置中的 url_sources 映射为 DataSource 常量
Args:
url_sources: 配置中的来源列表,如 ['websites', 'endpoints']
Returns:
DataSource 常量列表
"""
source_mapping = {
'websites': DataSource.WEBSITE,
'endpoints': DataSource.ENDPOINT,
}
"""将配置中的 url_sources 映射为 DataSource 常量"""
sources = []
for source in url_sources:
if source in source_mapping:
sources.append(source_mapping[source])
if source in _SOURCE_MAPPING:
sources.append(_SOURCE_MAPPING[source])
else:
logger.warning("未知的 URL 来源: %s,跳过", source)
# 添加默认回退(从 subdomain 构造)
sources.append(DataSource.DEFAULT)
return sources
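Behaviour sketch for the mapping above: unknown sources are skipped with a warning and the subdomain-based DataSource.DEFAULT fallback is always appended last (the 'bogus' source name is made up):

sources = _map_url_sources_to_data_sources(['websites', 'bogus'])
assert sources[0] == DataSource.WEBSITE
assert sources[-1] == DataSource.DEFAULT   # fallback is always last
assert len(sources) == 2                   # the unknown 'bogus' source was dropped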
def _collect_urls_from_provider(provider: TargetProvider) -> tuple[list[str], str, list[str]]:
"""从 Provider 收集 URL"""
logger.info("使用 Provider 模式获取 URL - Provider: %s", type(provider).__name__)
urls = list(provider.iter_urls())
blacklist_filter = provider.get_blacklist_filter()
if blacklist_filter:
urls = [url for url in urls if blacklist_filter.is_allowed(url)]
return urls, 'provider', ['provider']
def _collect_urls_from_database(
target_id: int,
url_sources: list[str]
) -> tuple[list[str], str, list[str]]:
"""从数据库收集 URL带黑名单过滤和回退"""
data_sources = _map_url_sources_to_data_sources(url_sources)
result = get_urls_with_fallback(target_id, sources=data_sources)
return result['urls'], result['source'], result['tried_sources']
def _build_empty_result(scan_id: int, target_name: str) -> dict:
"""构建空结果"""
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'total_urls': 0,
'successful': 0,
'failed': 0,
'synced': 0
}
@flow(
name="screenshot",
log_prints=True,
@@ -99,17 +110,11 @@ def screenshot_flow(
) -> dict:
"""
截图 Flow
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库获取 URL
2. Provider 模式:使用 TargetProvider 从任意数据源获取 URL
工作流程:
Step 1: 解析配置
Step 2: 收集 URL 列表
Step 3: 批量截图并保存快照
Step 4: 同步到资产表
Args:
scan_id: 扫描任务 ID
target_name: 目标名称
@@ -117,116 +122,87 @@ def screenshot_flow(
scan_workspace_dir: 扫描工作空间目录
enabled_tools: 启用的工具配置
provider: TargetProvider 实例(新模式,可选)
Returns:
dict: {
'success': bool,
'scan_id': int,
'target': str,
'total_urls': int,
'successful': int,
'failed': int,
'synced': int
}
截图结果字典
"""
try:
# 负载检查:等待系统资源充足
wait_for_system_load(context="screenshot_flow")
mode = 'Provider' if provider else 'Legacy'
logger.info(
"="*60 + "\n" +
"开始截图扫描\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
f" Mode: {'Provider' if provider else 'Legacy'}\n" +
"="*60
"开始截图扫描 - Scan ID: %s, Target: %s, Mode: %s",
scan_id, target_name, mode
)
user_log(scan_id, "screenshot", "Starting screenshot capture")
# Step 1: 解析配置
config = _parse_screenshot_config(enabled_tools)
concurrency = config['concurrency']
url_sources = config['url_sources']
logger.info("截图配置 - 并发: %d, URL来源: %s", concurrency, url_sources)
logger.info("截图配置 - 并发: %d, URL来源: %s", concurrency, config['url_sources'])
# Step 2: 收集 URL 列表
if provider is not None:
# Provider 模式:使用 TargetProvider 获取 URL
logger.info("使用 Provider 模式获取 URL - Provider: %s", type(provider).__name__)
urls = list(provider.iter_urls())
# 应用黑名单过滤(如果有)
blacklist_filter = provider.get_blacklist_filter()
if blacklist_filter:
urls = [url for url in urls if blacklist_filter.is_allowed(url)]
source_info = 'provider'
tried_sources = ['provider']
urls, source_info, tried_sources = _collect_urls_from_provider(provider)
else:
# 传统模式:使用统一服务收集 URL带黑名单过滤和回退
data_sources = _map_url_sources_to_data_sources(url_sources)
result = get_urls_with_fallback(target_id, sources=data_sources)
urls = result['urls']
source_info = result['source']
tried_sources = result['tried_sources']
urls, source_info, tried_sources = _collect_urls_from_database(
target_id, config['url_sources']
)
logger.info(
"URL 收集完成 - 来源: %s, 数量: %d, 尝试过: %s",
source_info, len(urls), tried_sources
)
if not urls:
logger.warning("没有可截图的 URL跳过截图任务")
user_log(scan_id, "screenshot", "Skipped: no URLs to capture", "warning")
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'total_urls': 0,
'successful': 0,
'failed': 0,
'synced': 0
}
user_log(scan_id, "screenshot", f"Found {len(urls)} URLs to capture (source: {source_info})")
return _build_empty_result(scan_id, target_name)
user_log(
scan_id, "screenshot",
f"Found {len(urls)} URLs to capture (source: {source_info})"
)
# Step 3: 批量截图
logger.info("Step 3: 批量截图 - %d 个 URL", len(urls))
logger.info("批量截图 - %d 个 URL", len(urls))
capture_result = capture_screenshots_task(
urls=urls,
scan_id=scan_id,
target_id=target_id,
config={'concurrency': concurrency}
)
# Step 4: 同步到资产表
logger.info("Step 4: 同步截图到资产表")
logger.info("同步截图到资产表")
from apps.asset.services.screenshot_service import ScreenshotService
screenshot_service = ScreenshotService()
synced = screenshot_service.sync_screenshots_to_asset(scan_id, target_id)
synced = ScreenshotService().sync_screenshots_to_asset(scan_id, target_id)
total = capture_result['total']
successful = capture_result['successful']
failed = capture_result['failed']
logger.info(
"✓ 截图完成 - 总数: %d, 成功: %d, 失败: %d, 同步: %d",
capture_result['total'], capture_result['successful'], capture_result['failed'], synced
total, successful, failed, synced
)
user_log(
scan_id, "screenshot",
f"Screenshot completed: {capture_result['successful']}/{capture_result['total']} captured, {synced} synced"
f"Screenshot completed: {successful}/{total} captured, {synced} synced"
)
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'total_urls': capture_result['total'],
'successful': capture_result['successful'],
'failed': capture_result['failed'],
'total_urls': total,
'successful': successful,
'failed': failed,
'synced': synced
}
except Exception as e:
logger.exception("截图 Flow 失败: %s", e)
user_log(scan_id, "screenshot", f"Screenshot failed: {e}", "error")
except Exception:
logger.exception("截图 Flow 失败")
user_log(scan_id, "screenshot", "Screenshot failed", "error")
raise

View File

@@ -1,4 +1,3 @@
"""
站点扫描 Flow
@@ -11,23 +10,22 @@
- 配置由 YAML 解析
"""
# Django 环境初始化(导入即生效)
from apps.common.prefect_django_setup import setup_django_for_prefect
from datetime import datetime
import logging
import os
import subprocess
import time
from pathlib import Path
from typing import Callable
from prefect import flow
from apps.scan.tasks.site_scan import export_site_urls_task, run_and_stream_save_websites_task
# Django 环境初始化(导入即生效)pylint: disable=unused-import
from apps.common.prefect_django_setup import setup_django_for_prefect # noqa: F401
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
on_scan_flow_running,
)
from apps.scan.utils import config_parser, build_scan_command, user_log
from apps.scan.tasks.site_scan import export_site_urls_task, run_and_stream_save_websites_task
from apps.scan.utils import build_scan_command, user_log, wait_for_system_load
logger = logging.getLogger(__name__)
@@ -191,7 +189,6 @@ def _run_scans_sequentially(
timeout = max(dynamic_timeout, config_timeout)
# 2.1 生成日志文件路径(类似端口扫描)
from datetime import datetime
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = site_scan_dir / f"{tool_name}_{timestamp}.log"
@@ -233,7 +230,7 @@ def _run_scans_sequentially(
)
user_log(scan_id, "site_scan", f"{tool_name} completed: found {tool_created} websites")
except subprocess.TimeoutExpired as exc:
except subprocess.TimeoutExpired:
# 超时异常单独处理
reason = f"timeout after {timeout}s"
failed_tools.append({'tool': tool_name, 'reason': reason})
@@ -368,6 +365,9 @@ def site_scan_flow(
RuntimeError: 执行失败
"""
try:
# 负载检查:等待系统资源充足
wait_for_system_load(context="site_scan_flow")
logger.info(
"="*60 + "\n" +
"开始站点扫描\n" +

File diff suppressed because it is too large

View File

@@ -10,22 +10,18 @@ URL Fetch 主 Flow
- 统一进行 httpx 验证(如果启用)
"""
# Django 环境初始化
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging
import os
from pathlib import Path
from datetime import datetime
from pathlib import Path
from prefect import flow
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
on_scan_flow_running,
)
from apps.scan.utils import user_log
from apps.scan.utils import user_log, wait_for_system_load
from .domain_name_url_fetch_flow import domain_name_url_fetch_flow
from .sites_url_fetch_flow import sites_url_fetch_flow
@@ -43,13 +39,10 @@ SITES_FILE_TOOLS = {'katana'}
POST_PROCESS_TOOLS = {'uro', 'httpx'}
def _classify_tools(enabled_tools: dict) -> tuple[dict, dict, dict, dict]:
"""
将启用的工具按输入类型分类
Returns:
tuple: (domain_name_tools, sites_file_tools, uro_config, httpx_config)
"""
@@ -76,23 +69,23 @@ def _classify_tools(enabled_tools: dict) -> tuple[dict, dict, dict, dict]:
def _merge_and_deduplicate_urls(result_files: list, url_fetch_dir: Path) -> tuple[str, int]:
"""合并并去重 URL"""
from apps.scan.tasks.url_fetch import merge_and_deduplicate_urls_task
merged_file = merge_and_deduplicate_urls_task(
result_files=result_files,
result_dir=str(url_fetch_dir)
)
# 统计唯一 URL 数量
unique_url_count = 0
if Path(merged_file).exists():
with open(merged_file, 'r') as f:
with open(merged_file, 'r', encoding='utf-8') as f:
unique_url_count = sum(1 for line in f if line.strip())
logger.info(
"✓ URL 合并去重完成 - 合并文件: %s, 唯一 URL 数: %d",
merged_file, unique_url_count
)
return merged_file, unique_url_count
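The unique-URL count above is just "non-blank lines in the merged file"; a self-contained check with a temporary file (the merge/dedup itself lives in merge_and_deduplicate_urls_task and is not reproduced here):

import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    merged = Path(tmp) / 'merged_urls.txt'
    merged.write_text("https://a.example.com/\n\nhttps://b.example.com/\n", encoding='utf-8')
    with open(merged, 'r', encoding='utf-8') as f:
        assert sum(1 for line in f if line.strip()) == 2   # blank line is ignored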
@@ -103,12 +96,12 @@ def _clean_urls_with_uro(
) -> tuple[str, int, int]:
"""使用 uro 清理合并后的 URL 列表"""
from apps.scan.tasks.url_fetch import clean_urls_task
raw_timeout = uro_config.get('timeout', 60)
whitelist = uro_config.get('whitelist')
blacklist = uro_config.get('blacklist')
filters = uro_config.get('filters')
# 计算超时时间
if isinstance(raw_timeout, str) and raw_timeout == 'auto':
timeout = calculate_timeout_by_line_count(
@@ -124,7 +117,7 @@ def _clean_urls_with_uro(
except (TypeError, ValueError):
logger.warning("uro timeout 配置无效(%s),使用默认 60 秒", raw_timeout)
timeout = 60
result = clean_urls_task(
input_file=merged_file,
output_dir=str(url_fetch_dir),
@@ -133,12 +126,12 @@ def _clean_urls_with_uro(
blacklist=blacklist,
filters=filters
)
if result['success']:
return result['output_file'], result['output_count'], result['removed_count']
else:
logger.warning("uro 清理失败: %s,使用原始合并文件", result.get('error', '未知错误'))
return merged_file, result['input_count'], 0
logger.warning("uro 清理失败: %s,使用原始合并文件", result.get('error', '未知错误'))
return merged_file, result['input_count'], 0
def _validate_and_stream_save_urls(
@@ -151,25 +144,25 @@ def _validate_and_stream_save_urls(
"""使用 httpx 验证 URL 存活并流式保存到数据库"""
from apps.scan.utils import build_scan_command
from apps.scan.tasks.url_fetch import run_and_stream_save_urls_task
logger.info("开始使用 httpx 验证 URL 存活状态...")
# 统计待验证的 URL 数量
try:
with open(merged_file, 'r') as f:
with open(merged_file, 'r', encoding='utf-8') as f:
url_count = sum(1 for _ in f)
logger.info("待验证 URL 数量: %d", url_count)
except Exception as e:
except OSError as e:
logger.error("读取 URL 文件失败: %s", e)
return 0
if url_count == 0:
logger.warning("没有需要验证的 URL")
return 0
# 构建 httpx 命令
command_params = {'url_file': merged_file}
try:
command = build_scan_command(
tool_name='httpx',
@@ -177,21 +170,19 @@ def _validate_and_stream_save_urls(
command_params=command_params,
tool_config=httpx_config
)
except Exception as e:
except (ValueError, KeyError) as e:
logger.error("构建 httpx 命令失败: %s", e)
logger.warning("降级处理:将直接保存所有 URL不验证存活")
return _save_urls_to_database(merged_file, scan_id, target_id)
# 计算超时时间
raw_timeout = httpx_config.get('timeout', 'auto')
timeout = 3600
if isinstance(raw_timeout, str) and raw_timeout == 'auto':
# 按 URL 行数计算超时时间:每行 3 秒,最小 60 秒
timeout = max(60, url_count * 3)
logger.info(
"自动计算 httpx 超时时间(按行数,每行 3 秒,最小 60 秒): url_count=%d, timeout=%d",
url_count,
timeout,
url_count, timeout
)
else:
try:
@@ -199,49 +190,44 @@ def _validate_and_stream_save_urls(
except (TypeError, ValueError):
timeout = 3600
logger.info("使用配置的 httpx 超时时间: %d", timeout)
# 生成日志文件路径
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = url_fetch_dir / f"httpx_validation_{timestamp}.log"
# 流式执行
try:
result = run_and_stream_save_urls_task(
cmd=command,
tool_name='httpx',
scan_id=scan_id,
target_id=target_id,
cwd=str(url_fetch_dir),
shell=True,
timeout=timeout,
log_file=str(log_file)
)
saved = result.get('saved_urls', 0)
logger.info(
"✓ httpx 验证完成 - 存活 URL: %d (%.1f%%)",
saved, (saved / url_count * 100) if url_count > 0 else 0
)
return saved
except Exception as e:
logger.error("httpx 流式验证失败: %s", e, exc_info=True)
raise
result = run_and_stream_save_urls_task(
cmd=command,
tool_name='httpx',
scan_id=scan_id,
target_id=target_id,
cwd=str(url_fetch_dir),
shell=True,
timeout=timeout,
log_file=str(log_file)
)
saved = result.get('saved_urls', 0)
logger.info(
"✓ httpx 验证完成 - 存活 URL: %d (%.1f%%)",
saved, (saved / url_count * 100) if url_count > 0 else 0
)
return saved
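The 'auto' timeout rule used above (3 seconds per URL with a 60-second floor, otherwise the configured value, falling back to 3600 s when unparsable), restated as a small hedged helper rather than the actual code path:

def _httpx_timeout(raw_timeout, url_count: int) -> int:
    if isinstance(raw_timeout, str) and raw_timeout == 'auto':
        return max(60, url_count * 3)
    try:
        return int(raw_timeout)
    except (TypeError, ValueError):
        return 3600

assert _httpx_timeout('auto', 10) == 60        # floor applies
assert _httpx_timeout('auto', 5000) == 15000   # 3s per URL
assert _httpx_timeout('900', 5000) == 900      # explicit config wins
assert _httpx_timeout(None, 5000) == 3600      # unparsable -> safe default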
def _save_urls_to_database(merged_file: str, scan_id: int, target_id: int) -> int:
"""保存 URL 到数据库(不验证存活)"""
from apps.scan.tasks.url_fetch import save_urls_task
result = save_urls_task(
urls_file=merged_file,
scan_id=scan_id,
target_id=target_id
)
saved_count = result.get('saved_urls', 0)
logger.info("✓ URL 保存完成 - 保存数量: %d", saved_count)
return saved_count
@@ -261,7 +247,7 @@ def url_fetch_flow(
) -> dict:
"""
URL 获取主 Flow
执行流程:
1. 准备工作目录
2. 按输入类型分类工具(domain_name / sites_file / 后处理)
@@ -271,36 +257,32 @@ def url_fetch_flow(
4. 合并所有子 Flow 的结果并去重
5. uro 去重(如果启用)
6. httpx 验证(如果启用)
Args:
scan_id: 扫描 ID
target_name: 目标名称
target_id: 目标 ID
scan_workspace_dir: 扫描工作目录
enabled_tools: 启用的工具配置
Returns:
dict: 扫描结果
"""
try:
# 负载检查:等待系统资源充足
wait_for_system_load(context="url_fetch_flow")
logger.info(
"="*60 + "\n" +
"开始 URL 获取扫描\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
"开始 URL 获取扫描 - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
)
user_log(scan_id, "url_fetch", "Starting URL fetch")
# Step 1: 准备工作目录
logger.info("Step 1: 准备工作目录")
from apps.scan.utils import setup_scan_directory
url_fetch_dir = setup_scan_directory(scan_workspace_dir, 'url_fetch')
# Step 2: 分类工具(按输入类型)
logger.info("Step 2: 分类工具")
domain_name_tools, sites_file_tools, uro_config, httpx_config = _classify_tools(enabled_tools)
logger.info(
@@ -317,15 +299,14 @@ def url_fetch_flow(
"URL Fetch 流程需要至少启用一个 URL 获取工具(如 waymore, katana"
"httpx 和 uro 仅用于后处理,不能单独使用。"
)
# Step 3: 并行执行子 Flow
# Step 3: 执行子 Flow
all_result_files = []
all_failed_tools = []
all_successful_tools = []
# 3a: 基于 domain_name(target_name)的 URL 被动收集(如 waymore)
# 3a: 基于 domain_name 的 URL 被动收集(如 waymore)
if domain_name_tools:
logger.info("Step 3a: 执行基于 domain_name 的 URL 被动收集子 Flow")
tn_result = domain_name_url_fetch_flow(
scan_id=scan_id,
target_id=target_id,
@@ -336,10 +317,9 @@ def url_fetch_flow(
all_result_files.extend(tn_result.get('result_files', []))
all_failed_tools.extend(tn_result.get('failed_tools', []))
all_successful_tools.extend(tn_result.get('successful_tools', []))
# 3b: 爬虫(以 sites_file 为输入)
if sites_file_tools:
logger.info("Step 3b: 执行爬虫子 Flow")
crawl_result = sites_url_fetch_flow(
scan_id=scan_id,
target_id=target_id,
@@ -350,12 +330,13 @@ def url_fetch_flow(
all_result_files.extend(crawl_result.get('result_files', []))
all_failed_tools.extend(crawl_result.get('failed_tools', []))
all_successful_tools.extend(crawl_result.get('successful_tools', []))
# 检查是否有成功的工具
if not all_result_files:
error_details = "; ".join([f"{f['tool']}: {f['reason']}" for f in all_failed_tools])
error_details = "; ".join([
"%s: %s" % (f['tool'], f['reason']) for f in all_failed_tools
])
logger.warning("所有 URL 获取工具均失败 - 目标: %s, 失败详情: %s", target_name, error_details)
# 返回空结果,不抛出异常,让扫描继续
return {
'success': True,
'scan_id': scan_id,
@@ -366,31 +347,24 @@ def url_fetch_flow(
'successful_tools': [],
'message': '所有 URL 获取工具均无结果'
}
# Step 4: 合并并去重 URL
logger.info("Step 4: 合并并去重 URL")
merged_file, unique_url_count = _merge_and_deduplicate_urls(
merged_file, _ = _merge_and_deduplicate_urls(
result_files=all_result_files,
url_fetch_dir=url_fetch_dir
)
# Step 5: 使用 uro 清理 URL如果启用
url_file_for_validation = merged_file
uro_removed_count = 0
if uro_config and uro_config.get('enabled', False):
logger.info("Step 5: 使用 uro 清理 URL")
url_file_for_validation, cleaned_count, uro_removed_count = _clean_urls_with_uro(
url_file_for_validation, _, _ = _clean_urls_with_uro(
merged_file=merged_file,
uro_config=uro_config,
url_fetch_dir=url_fetch_dir
)
else:
logger.info("Step 5: 跳过 uro 清理(未启用)")
# Step 6: 使用 httpx 验证存活并保存(如果启用)
if httpx_config and httpx_config.get('enabled', False):
logger.info("Step 6: 使用 httpx 验证 URL 存活并流式保存")
saved_count = _validate_and_stream_save_urls(
merged_file=url_file_for_validation,
httpx_config=httpx_config,
@@ -399,17 +373,16 @@ def url_fetch_flow(
target_id=target_id
)
else:
logger.info("Step 6: 保存到数据库(未启用 httpx 验证)")
saved_count = _save_urls_to_database(
merged_file=url_file_for_validation,
scan_id=scan_id,
target_id=target_id
)
# 记录 Flow 完成
logger.info("✓ URL 获取完成 - 保存 endpoints: %d", saved_count)
user_log(scan_id, "url_fetch", f"url_fetch completed: found {saved_count} endpoints")
user_log(scan_id, "url_fetch", "url_fetch completed: found %d endpoints" % saved_count)
# 构建已执行的任务列表
executed_tasks = ['setup_directory', 'classify_tools']
if domain_name_tools:
@@ -423,7 +396,7 @@ def url_fetch_flow(
executed_tasks.append('httpx_validation_and_save')
else:
executed_tasks.append('save_urls')
return {
'success': True,
'scan_id': scan_id,
@@ -439,7 +412,7 @@ def url_fetch_flow(
'failed_tools': [f['tool'] for f in all_failed_tools]
}
}
except Exception as e:
logger.error("URL 获取扫描失败: %s", e, exc_info=True)
raise
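Step 4 调用的 _merge_and_deduplicate_urls 未出现在本 hunk 中。下面按其调用签名给出一个极简示意实现(仅用于说明“多文件合并 + 去重”的思路,实际实现可能采用外部排序等方式控制内存,并非本提交代码):

from pathlib import Path
from typing import List, Tuple

def merge_and_deduplicate_urls_sketch(result_files: List[str],
                                      url_fetch_dir: Path) -> Tuple[str, int]:
    """合并多个工具的 URL 结果文件并去重,返回 (合并文件路径, 去重后数量)"""
    seen = set()
    merged_file = Path(url_fetch_dir) / 'merged_urls.txt'
    with open(merged_file, 'w', encoding='utf-8') as out:
        for result_file in result_files:
            with open(result_file, 'r', encoding='utf-8', errors='replace') as f:
                for line in f:
                    url = line.strip()
                    if url and url not in seen:
                        seen.add(url)
                        out.write(url + '\n')
    return str(merged_file), len(seen)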

View File

@@ -1,5 +1,6 @@
from apps.common.prefect_django_setup import setup_django_for_prefect
"""
漏洞扫描主 Flow
"""
import logging
from typing import Dict, Tuple
@@ -11,7 +12,7 @@ from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_failed,
)
from apps.scan.configs.command_templates import get_command_template
from apps.scan.utils import user_log
from apps.scan.utils import user_log, wait_for_system_load
from .endpoints_vuln_scan_flow import endpoints_vuln_scan_flow
@@ -62,6 +63,9 @@ def vuln_scan_flow(
- nuclei: 通用漏洞扫描(流式保存,支持模板 commit hash 同步)
"""
try:
# 负载检查:等待系统资源充足
wait_for_system_load(context="vuln_scan_flow")
if scan_id is None:
raise ValueError("scan_id 不能为空")
if not target_name:

View File

@@ -4,11 +4,11 @@
定义 ProviderContext 数据类和 TargetProvider 抽象基类。
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Iterator, Optional, TYPE_CHECKING
import ipaddress
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Iterator, Optional
if TYPE_CHECKING:
from apps.common.utils import BlacklistFilter
@@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
class ProviderContext:
"""
Provider 上下文,携带元数据
Attributes:
target_id: 关联的 Target ID(用于结果保存),None 表示临时扫描(不保存)
scan_id: 扫描任务 ID
@@ -32,130 +32,83 @@ class ProviderContext:
class TargetProvider(ABC):
"""
扫描目标提供者抽象基类
职责:
- 提供扫描目标(域名、IP、URL 等)的迭代器
- 提供黑名单过滤器
- 携带上下文信息(target_id, scan_id 等)
- 自动展开 CIDR,子类无需关心
使用方式:
provider = create_target_provider(target_id=123)
for host in provider.iter_hosts():
print(host)
"""
def __init__(self, context: Optional[ProviderContext] = None):
"""
初始化 Provider
Args:
context: Provider 上下文(None 时创建默认上下文)
"""
self._context = context or ProviderContext()
@property
def context(self) -> ProviderContext:
"""返回 Provider 上下文"""
return self._context
@staticmethod
def _expand_host(host: str) -> Iterator[str]:
"""
展开主机(如果是 CIDR 则展开为多个 IP,否则直接返回)
这是一个内部方法,由 iter_hosts() 自动调用。
Args:
host: 主机字符串(IP/域名/CIDR)
Yields:
str: 单个主机(IP 或域名)
示例:
"192.168.1.0/30""192.168.1.1", "192.168.1.2"
"192.168.1.1""192.168.1.1"
"example.com""example.com"
"invalid" → (跳过,不返回)
"""
from apps.common.validators import detect_target_type
from apps.targets.models import Target
host = host.strip()
if not host:
return
# 统一使用 detect_target_type 检测类型
try:
target_type = detect_target_type(host)
if target_type == Target.TargetType.CIDR:
# 展开 CIDR
network = ipaddress.ip_network(host, strict=False)
if network.num_addresses == 1:
yield str(network.network_address)
else:
for ip in network.hosts():
yield str(ip)
elif target_type == Target.TargetType.IP:
# 单个 IP
yield host
elif target_type == Target.TargetType.DOMAIN:
# 域名
yield from (str(ip) for ip in network.hosts())
elif target_type in (Target.TargetType.IP, Target.TargetType.DOMAIN):
yield host
except ValueError as e:
# 无效格式,跳过并记录警告
logger.warning("跳过无效的主机格式 '%s': %s", host, str(e))
def iter_hosts(self) -> Iterator[str]:
"""
迭代主机列表(域名/IP)
自动展开 CIDR,子类无需关心。
Yields:
str: 主机名或 IP 地址(单个,不包含 CIDR)
"""
"""迭代主机列表(域名/IP自动展开 CIDR"""
for host in self._iter_raw_hosts():
yield from self._expand_host(host)
@abstractmethod
def _iter_raw_hosts(self) -> Iterator[str]:
"""
迭代原始主机列表(可能包含 CIDR)
子类实现此方法,返回原始数据即可,不需要处理 CIDR 展开。
Yields:
str: 主机名、IP 地址或 CIDR
"""
"""迭代原始主机列表(可能包含 CIDR子类实现"""
pass
@abstractmethod
def iter_urls(self) -> Iterator[str]:
"""
迭代 URL 列表
Yields:
str: URL 字符串
"""
"""迭代 URL 列表"""
pass
@abstractmethod
def get_blacklist_filter(self) -> Optional['BlacklistFilter']:
"""
获取黑名单过滤器
Returns:
BlacklistFilter: 黑名单过滤器实例,或 None(不过滤)
"""
"""获取黑名单过滤器,返回 None 表示不过滤"""
pass
@property
def target_id(self) -> Optional[int]:
"""返回关联的 target_id临时扫描返回 None"""
return self._context.target_id
@property
def scan_id(self) -> Optional[int]:
"""返回关联的 scan_id"""

View File

@@ -5,9 +5,9 @@
"""
import logging
from typing import Iterator, Optional, TYPE_CHECKING
from typing import TYPE_CHECKING, Iterator, Optional
from .base import TargetProvider, ProviderContext
from .base import ProviderContext, TargetProvider
if TYPE_CHECKING:
from apps.common.utils import BlacklistFilter
@@ -18,109 +18,71 @@ logger = logging.getLogger(__name__)
class DatabaseTargetProvider(TargetProvider):
"""
数据库目标提供者 - 从 Target 表及关联资产表查询
这是现有行为的封装,保持向后兼容。
数据来源:
- iter_hosts(): 根据 Target 类型返回域名/IP
- DOMAIN: 根域名 + Subdomain 表
- IP: 直接返回 IP
- CIDR: 使用 _expand_host() 展开为所有主机 IP
- iter_urls(): WebSite/Endpoint 表,带回退链
使用方式:
provider = DatabaseTargetProvider(target_id=123)
for host in provider.iter_hosts():
scan(host)
"""
def __init__(self, target_id: int, context: Optional[ProviderContext] = None):
"""
初始化数据库目标提供者
Args:
target_id: 目标 ID(必需)
context: Provider 上下文
"""
ctx = context or ProviderContext()
ctx.target_id = target_id
super().__init__(ctx)
self._blacklist_filter: Optional['BlacklistFilter'] = None # 延迟加载
self._blacklist_filter: Optional['BlacklistFilter'] = None
def iter_hosts(self) -> Iterator[str]:
"""
从数据库查询主机列表,自动展开 CIDR 并应用黑名单过滤
重写基类方法以支持黑名单过滤(需要在 CIDR 展开后过滤)
"""
"""从数据库查询主机列表,自动展开 CIDR 并应用黑名单过滤"""
blacklist = self.get_blacklist_filter()
for host in self._iter_raw_hosts():
# 展开 CIDR
for expanded_host in self._expand_host(host):
# 应用黑名单过滤
if not blacklist or blacklist.is_allowed(expanded_host):
yield expanded_host
def _iter_raw_hosts(self) -> Iterator[str]:
"""
从数据库查询原始主机列表(可能包含 CIDR)
根据 Target 类型决定数据来源:
- DOMAIN: 根域名 + Subdomain 表
- IP: 直接返回 target.name
- CIDR: 返回 CIDR 字符串(由 iter_hosts() 展开)
注意:此方法不应用黑名单过滤,过滤在 iter_hosts() 中进行
"""
from apps.targets.services import TargetService
from apps.targets.models import Target
"""从数据库查询原始主机列表(可能包含 CIDR"""
from apps.asset.services.asset.subdomain_service import SubdomainService
from apps.targets.models import Target
from apps.targets.services import TargetService
target = TargetService().get_target(self.target_id)
if not target:
logger.warning("Target ID %d 不存在", self.target_id)
return
if target.type == Target.TargetType.DOMAIN:
# 返回根域名
yield target.name
# 返回子域名
subdomain_service = SubdomainService()
for domain in subdomain_service.iter_subdomain_names_by_target(
for domain in SubdomainService().iter_subdomain_names_by_target(
target_id=self.target_id,
chunk_size=1000
):
if domain != target.name: # 避免重复
if domain != target.name:
yield domain
elif target.type == Target.TargetType.IP:
elif target.type in (Target.TargetType.IP, Target.TargetType.CIDR):
yield target.name
elif target.type == Target.TargetType.CIDR:
# 直接返回 CIDR,由 iter_hosts() 展开并过滤
yield target.name
def iter_urls(self) -> Iterator[str]:
"""
从数据库查询 URL 列表
使用现有的回退链逻辑(Endpoint → WebSite → Default)
"""
"""从数据库查询 URL 列表使用回退链Endpoint → WebSite → Default"""
from apps.scan.services.target_export_service import (
_iter_urls_with_fallback, DataSource
DataSource,
_iter_urls_with_fallback,
)
blacklist = self.get_blacklist_filter()
for url, source in _iter_urls_with_fallback(
for url, _ in _iter_urls_with_fallback(
target_id=self.target_id,
sources=[DataSource.ENDPOINT, DataSource.WEBSITE, DataSource.DEFAULT],
blacklist_filter=blacklist
):
yield url
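iter_urls 使用的 _iter_urls_with_fallback 不在本 hunk 中,其“回退链”含义是依次尝试 Endpoint、WebSite、Default,取第一个非空来源。下面用一个简化的示意函数说明这一思路(函数与数据均为演示用假设,并非实际实现):

from typing import Callable, Iterator, List

def iter_with_fallback_sketch(sources: List[Callable[[], List[str]]]) -> Iterator[str]:
    """依次尝试各数据来源,产出第一个非空来源的全部条目"""
    for load in sources:
        items = load()
        if items:
            yield from items
            return  # 命中即停止,不再尝试后续来源

# 演示:Endpoint 为空时回退到 WebSite
endpoint_urls = lambda: []
website_urls = lambda: ['https://example.com/', 'https://example.com/admin']
print(list(iter_with_fallback_sketch([endpoint_urls, website_urls])))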
def get_blacklist_filter(self) -> Optional['BlacklistFilter']:
"""获取黑名单过滤器(延迟加载)"""
if self._blacklist_filter is None:

View File

@@ -12,7 +12,7 @@
import ipaddress
import logging
from pathlib import Path
from typing import Dict, Any, Optional, List, Iterator, Tuple, Callable
from typing import Dict, Any, Optional, List, Iterator, Tuple
from django.db.models import QuerySet
@@ -485,8 +485,7 @@ class TargetExportService:
"""
from apps.targets.services import TargetService
from apps.targets.models import Target
from apps.asset.services.asset.subdomain_service import SubdomainService
output_file = Path(output_path)
output_file.parent.mkdir(parents=True, exist_ok=True)

View File

@@ -16,7 +16,7 @@ from apps.scan.services.target_export_service import (
export_urls_with_fallback,
DataSource,
)
from apps.scan.providers import TargetProvider, DatabaseTargetProvider
from apps.scan.providers import TargetProvider
logger = logging.getLogger(__name__)

View File

@@ -11,12 +11,12 @@
- CIDR: 展开 CIDR 范围内的所有 IP
"""
import logging
from typing import Optional
from pathlib import Path
from typing import Optional
from prefect import task
from apps.scan.services.target_export_service import create_export_service
from apps.scan.providers import TargetProvider, DatabaseTargetProvider, ProviderContext
from apps.scan.providers import DatabaseTargetProvider, TargetProvider
logger = logging.getLogger(__name__)
@@ -26,15 +26,14 @@ def export_hosts_task(
output_file: str,
target_id: Optional[int] = None,
provider: Optional[TargetProvider] = None,
batch_size: int = 1000
) -> dict:
"""
导出主机列表到 TXT 文件
支持两种模式:
1. 传统模式(向后兼容):传入 target_id,从数据库导出
2. Provider 模式:传入 provider,从任意数据源导出
根据 Target 类型自动决定导出内容:
- DOMAIN: 从 Subdomain 表导出子域名(流式处理,支持 10万+ 域名)
- IP: 直接写入 target.name(单个 IP)
@@ -44,7 +43,6 @@ def export_hosts_task(
output_file: 输出文件路径(绝对路径)
target_id: 目标 ID(传统模式,向后兼容)
provider: TargetProvider 实例(新模式)
batch_size: 每次读取的批次大小,默认 1000(仅对 DOMAIN 类型有效)
Returns:
dict: {
@@ -58,53 +56,44 @@ def export_hosts_task(
ValueError: 参数错误(target_id 和 provider 都未提供)
IOError: 文件写入失败
"""
# 参数验证:至少提供一个
if target_id is None and provider is None:
raise ValueError("必须提供 target_id 或 provider 参数之一")
# 向后兼容:如果没有提供 provider,使用 target_id 创建 DatabaseTargetProvider
if provider is None:
use_legacy_mode = provider is None
if use_legacy_mode:
logger.info("使用传统模式 - Target ID: %d", target_id)
provider = DatabaseTargetProvider(target_id=target_id)
use_legacy_mode = True
else:
logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
use_legacy_mode = False
# 确保输出目录存在
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
# 使用 Provider 导出主机列表
# 使用 Provider 导出主机列表(iter_hosts 内部已处理黑名单过滤)
total_count = 0
blacklist_filter = provider.get_blacklist_filter()
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for host in provider.iter_hosts():
# 应用黑名单过滤(如果有)
if blacklist_filter and not blacklist_filter.is_allowed(host):
continue
f.write(f"{host}\n")
total_count += 1
if total_count % 1000 == 0:
logger.info("已导出 %d 个主机...", total_count)
logger.info("✓ 主机列表导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))
# 构建返回值
result = {
'success': True,
'output_file': str(output_path),
'total_count': total_count,
}
# 传统模式:保持返回值格式不变(向后兼容)
if use_legacy_mode:
# 获取 target_type(仅传统模式需要)
from apps.targets.services import TargetService
target = TargetService().get_target(target_id)
result['target_type'] = target.type if target else 'unknown'
return result
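export_hosts_task 现在同时支持传统模式与 Provider 模式,且仅传统模式的返回值包含 target_type。下面是一个调用示意(路径与 target_id 均为假设值,通常在 Prefect Flow 内部调用;StaticTargetProvider 沿用前文的演示类,导入路径以项目实际模块为准):

# 传统模式:仅传 target_id,内部自动构造 DatabaseTargetProvider
result = export_hosts_task(
    output_file='/tmp/scan_1/hosts.txt',
    target_id=123,
)
print(result['total_count'], result.get('target_type'))

# Provider 模式:传入任意 TargetProvider 实例
result = export_hosts_task(
    output_file='/tmp/scan_1/hosts_static.txt',
    provider=StaticTargetProvider(hosts=['example.com', '10.0.0.0/30']),
)
print(result['total_count'])   # Provider 模式的返回值不含 target_type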

View File

@@ -4,37 +4,40 @@
提供扫描相关的工具函数。
"""
from .directory_cleanup import remove_directory
from . import config_parser
from .command_builder import build_scan_command
from .command_executor import execute_and_wait, execute_stream
from .wordlist_helpers import ensure_wordlist_local
from .directory_cleanup import remove_directory
from .nuclei_helpers import ensure_nuclei_templates_local
from .performance import FlowPerformanceTracker, CommandPerformanceTracker
from .workspace_utils import setup_scan_workspace, setup_scan_directory
from .performance import CommandPerformanceTracker, FlowPerformanceTracker
from .system_load import check_system_load, wait_for_system_load
from .user_logger import user_log
from . import config_parser
from .wordlist_helpers import ensure_wordlist_local
from .workspace_utils import setup_scan_directory, setup_scan_workspace
__all__ = [
# 目录清理
'remove_directory',
# 工作空间
'setup_scan_workspace', # 创建 Scan 根工作空间
'setup_scan_directory', # 创建扫描子目录
'setup_scan_workspace',
'setup_scan_directory',
# 命令构建
'build_scan_command', # 扫描工具命令构建(基于 f-string)
'build_scan_command',
# 命令执行
'execute_and_wait', # 等待式执行(文件输出)
'execute_stream', # 流式执行(实时处理)
'execute_and_wait',
'execute_stream',
# 系统负载
'wait_for_system_load',
'check_system_load',
# 字典文件
'ensure_wordlist_local', # 确保本地字典文件(含 hash 校验)
'ensure_wordlist_local',
# Nuclei 模板
'ensure_nuclei_templates_local', # 确保本地模板(含 commit hash 校验)
'ensure_nuclei_templates_local',
# 性能监控
'FlowPerformanceTracker', # Flow 性能追踪器(含系统资源采样)
'CommandPerformanceTracker', # 命令性能追踪器
'FlowPerformanceTracker',
'CommandPerformanceTracker',
# 扫描日志
'user_log', # 用户可见扫描日志记录
'user_log',
# 配置解析
'config_parser',
]

View File

@@ -12,16 +12,18 @@
import logging
import os
from django.conf import settings
import re
import signal
import subprocess
import threading
import time
from collections import deque
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, Optional, Generator
from django.conf import settings
try:
# 可选依赖:用于根据 CPU / 内存负载做动态并发控制
import psutil
@@ -669,33 +671,68 @@ class CommandExecutor:
def _read_log_tail(self, log_file: Path, max_lines: int = MAX_LOG_TAIL_LINES) -> str:
"""
读取日志文件的末尾部分
读取日志文件的末尾部分(常量内存实现)
使用 seek 从文件末尾往前读取,避免将整个文件加载到内存。
Args:
log_file: 日志文件路径
max_lines: 最大读取行数
Returns:
日志内容(字符串),读取失败返回错误提示
"""
if not log_file.exists():
logger.debug("日志文件不存在: %s", log_file)
return ""
if log_file.stat().st_size == 0:
file_size = log_file.stat().st_size
if file_size == 0:
logger.debug("日志文件为空: %s", log_file)
return ""
# 每次读取的块大小(8KB),足够容纳大多数日志行
chunk_size = 8192
def decode_line(line_bytes: bytes) -> str:
"""解码单行:优先 UTF-8失败则降级 latin-1"""
try:
return line_bytes.decode('utf-8')
except UnicodeDecodeError:
return line_bytes.decode('latin-1', errors='replace')
try:
with open(log_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
return ''.join(lines[-max_lines:] if len(lines) > max_lines else lines)
except UnicodeDecodeError as e:
logger.warning("日志文件编码错误 (%s): %s", log_file, e)
return f"(无法读取日志文件: 编码错误 - {e})"
with open(log_file, 'rb') as f:
lines_found: deque[bytes] = deque()
remaining = b''
position = file_size
while position > 0 and len(lines_found) < max_lines:
read_size = min(chunk_size, position)
position -= read_size
f.seek(position)
chunk = f.read(read_size) + remaining
parts = chunk.split(b'\n')
# 最前面的部分可能不完整,留到下次处理
remaining = parts[0]
# 其余部分是完整的行(从后往前收集,用 appendleft 保持顺序)
for part in reversed(parts[1:]):
if len(lines_found) >= max_lines:
break
lines_found.appendleft(part)
# 处理文件开头的行
if remaining and len(lines_found) < max_lines:
lines_found.appendleft(remaining)
return '\n'.join(decode_line(line) for line in lines_found)
except PermissionError as e:
logger.warning("日志文件权限不足 (%s): %s", log_file, e)
return f"(无法读取日志文件: 权限不足)"
return "(无法读取日志文件: 权限不足)"
except IOError as e:
logger.warning("日志文件读取IO错误 (%s): %s", log_file, e)
return f"(无法读取日志文件: IO错误 - {e})"

View File

@@ -0,0 +1,77 @@
"""
系统负载检查工具
提供统一的系统负载检查功能,用于:
- Flow 入口处检查系统资源是否充足
- 防止在高负载时启动新的扫描任务
"""
import logging
import time
import psutil
from django.conf import settings
logger = logging.getLogger(__name__)
# 动态并发控制阈值(可在 Django settings 中覆盖)
SCAN_CPU_HIGH: float = getattr(settings, 'SCAN_CPU_HIGH', 90.0)
SCAN_MEM_HIGH: float = getattr(settings, 'SCAN_MEM_HIGH', 80.0)
SCAN_LOAD_CHECK_INTERVAL: int = getattr(settings, 'SCAN_LOAD_CHECK_INTERVAL', 180)
def _get_current_load() -> tuple[float, float]:
"""获取当前 CPU 和内存使用率"""
return psutil.cpu_percent(interval=0.5), psutil.virtual_memory().percent
def wait_for_system_load(
cpu_threshold: float = SCAN_CPU_HIGH,
mem_threshold: float = SCAN_MEM_HIGH,
check_interval: int = SCAN_LOAD_CHECK_INTERVAL,
context: str = "task"
) -> None:
"""
等待系统负载降到阈值以下
在高负载时阻塞等待,直到 CPU 和内存都低于阈值。
用于 Flow 入口处,防止在资源紧张时启动新任务。
"""
while True:
cpu, mem = _get_current_load()
if cpu < cpu_threshold and mem < mem_threshold:
logger.debug(
"[%s] 系统负载正常: cpu=%.1f%%, mem=%.1f%%",
context, cpu, mem
)
return
logger.info(
"[%s] 系统负载较高,等待资源释放: "
"cpu=%.1f%% (阈值 %.1f%%), mem=%.1f%% (阈值 %.1f%%)",
context, cpu, cpu_threshold, mem, mem_threshold
)
time.sleep(check_interval)
def check_system_load(
cpu_threshold: float = SCAN_CPU_HIGH,
mem_threshold: float = SCAN_MEM_HIGH
) -> dict:
"""
检查当前系统负载(非阻塞)
Returns:
dict: cpu_percent, mem_percent, cpu_threshold, mem_threshold, is_overloaded
"""
cpu, mem = _get_current_load()
return {
'cpu_percent': cpu,
'mem_percent': mem,
'cpu_threshold': cpu_threshold,
'mem_threshold': mem_threshold,
'is_overloaded': cpu >= cpu_threshold or mem >= mem_threshold,
}
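两个函数的典型用法:Flow 入口用 wait_for_system_load 阻塞等待,接口或调度层用 check_system_load 做非阻塞查询。示意如下(阈值为演示用假设值):

from apps.scan.utils import check_system_load, wait_for_system_load

# Flow 入口处阻塞等待,直到 CPU 与内存都低于阈值才继续执行
wait_for_system_load(context="demo_flow")

# 非阻塞检查,可用于状态接口或任务调度前的预判
status = check_system_load(cpu_threshold=85.0, mem_threshold=75.0)
if status['is_overloaded']:
    print("系统过载: cpu=%.1f%%, mem=%.1f%%" % (status['cpu_percent'], status['mem_percent']))
else:
    print("负载正常,可以调度新任务")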

View File

@@ -12,16 +12,34 @@ load-plugins = "pylint_django"
[tool.pylint.messages_control]
disable = [
"missing-docstring",
"invalid-name",
"too-few-public-methods",
"no-member",
"import-error",
"no-name-in-module",
"missing-docstring",
"invalid-name",
"too-few-public-methods",
"no-member",
"import-error",
"no-name-in-module",
"wrong-import-position", # 允许函数内导入(防循环依赖)
"import-outside-toplevel", # 同上
"too-many-arguments", # Django 视图/服务方法参数常超过5个
"too-many-locals", # 复杂业务逻辑局部变量多
"duplicate-code", # 某些模式代码相似是正常的
]
[tool.pylint.format]
max-line-length = 120
[tool.pylint.basic]
good-names = ["i", "j", "k", "ex", "Run", "_", "id", "pk", "ip", "url", "db", "qs"]
good-names = [
"i",
"j",
"k",
"ex",
"Run",
"_",
"id",
"pk",
"ip",
"url",
"db",
"qs",
]