perf(screenshot): optimize memory usage and add URL collection fallback logic

- Add iterator(chunk_size=50) to ScreenshotSnapshot query to prevent BinaryField data caching and reduce memory consumption
- Implement fallback logic in URL collection: WebSite → HostPortMapping → Default URL with priority handling
- Update _collect_urls_from_provider to return tuple with data source information for better logging and debugging
- Add detailed logging to track which data source was used during URL collection
- Improve code documentation with clear return type hints and fallback priority explanation
- Prevents memory spikes when processing large screenshot datasets with binary image data
This commit is contained in:
yyhuni
2026-01-11 16:14:48 +08:00
parent ced9f811f4
commit 2d2ec93626
2 changed files with 34 additions and 7 deletions

View File

@@ -146,7 +146,9 @@ class ScreenshotService:
"""
from apps.asset.models import Screenshot, ScreenshotSnapshot
snapshots = ScreenshotSnapshot.objects.filter(scan_id=scan_id)
# 使用 iterator() 避免 QuerySet 缓存大量 BinaryField 数据导致内存飙升
# chunk_size=50: 每次只加载 50 条记录,处理完后释放内存
snapshots = ScreenshotSnapshot.objects.filter(scan_id=scan_id).iterator(chunk_size=50)
count = 0
for snapshot in snapshots:

View File

@@ -31,10 +31,35 @@ def _parse_screenshot_config(enabled_tools: dict) -> dict:
}
def _collect_urls_from_provider(provider: TargetProvider) -> list[str]:
"""从 Provider 收集网站 URL"""
def _collect_urls_from_provider(provider: TargetProvider) -> tuple[list[str], str]:
"""
从 Provider 收集网站 URL带回退逻辑
优先级WebSite → HostPortMapping → Default URL
Returns:
tuple: (urls, source)
- urls: URL 列表
- source: 数据来源 ('website' | 'host_port' | 'default')
"""
logger.info("从 Provider 获取网站 URL - Provider: %s", type(provider).__name__)
return list(provider.iter_websites())
# 优先从 WebSite 获取
urls = list(provider.iter_websites())
if urls:
logger.info("使用 WebSite 数据源 - 数量: %d", len(urls))
return urls, "website"
# 回退到 HostPortMapping
urls = list(provider.iter_host_port_urls())
if urls:
logger.info("WebSite 为空,回退到 HostPortMapping - 数量: %d", len(urls))
return urls, "host_port"
# 最终回退到默认 URL
urls = list(provider.iter_default_urls())
logger.info("HostPortMapping 为空,回退到默认 URL - 数量: %d", len(urls))
return urls, "default"
def _build_empty_result(scan_id: int, target_name: str) -> dict:
@@ -96,9 +121,9 @@ def screenshot_flow(
concurrency = config['concurrency']
logger.info("截图配置 - 并发: %d", concurrency)
# Step 2: 从 Provider 收集 URL 列表
urls = _collect_urls_from_provider(provider)
logger.info("URL 收集完成 - 数量: %d", len(urls))
# Step 2: 从 Provider 收集 URL 列表(带回退逻辑)
urls, source = _collect_urls_from_provider(provider)
logger.info("URL 收集完成 - 来源: %s, 数量: %d", source, len(urls))
if not urls:
logger.warning("没有可截图的 URL跳过截图任务")