Compare commits


49 Commits
v1.4.1 ... dev

Author SHA1 Message Date
yyhuni
5345a34cbd refactor: remove Prefect 2026-01-11 19:31:47 +08:00
github-actions[bot]
3ca56abc3e chore: bump version to v1.5.12-dev 2026-01-11 09:22:30 +00:00
yyhuni
9703add22d feat(nuclei): support configurable Nuclei templates repository with Gitee mirror
- Add NUCLEI_TEMPLATES_REPO_URL setting to allow runtime configuration of template repository URL
- Refactor install.sh mirror parameter handling to use boolean flag instead of URL string
- Replace hardcoded GitHub repository URL with Gitee mirror option for faster downloads in mainland China
- Update environment variable configuration to persist Nuclei repository URL in .env file
- Improve shell script variable quoting and conditional syntax for better reliability
- Simplify mirror detection logic by using USE_MIRROR boolean flag throughout installation process
- Add support for automatic Gitee mirror selection when --mirror flag is enabled
2026-01-11 17:19:09 +08:00
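A minimal sketch of how such a runtime-configurable repository URL can be exposed through Django settings (the setting name comes from the commit; reading it from an identically named environment variable persisted in .env by install.sh is an assumption):

```python
# settings.py (sketch): let install.sh persist a Gitee mirror URL into .env
# when --mirror is enabled, and fall back to the official GitHub repository.
import os

NUCLEI_TEMPLATES_REPO_URL = os.environ.get(
    "NUCLEI_TEMPLATES_REPO_URL",
    "https://github.com/projectdiscovery/nuclei-templates.git",
)
```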
github-actions[bot]
f5a489e2d6 chore: bump version to v1.5.11-dev 2026-01-11 08:54:04 +00:00
yyhuni
d75a3f6882 fix(task_distributor): adjust high load wait parameters and improve timeout handling
- Increase high load wait interval from 60 to 120 seconds (2 minutes)
- Increase max retries from 10 to 60 to support up to 2 hours total wait time
- Improve timeout message to show actual wait duration in minutes
- Remove duplicate return statement in worker selection logic
- Update notification message to reflect new wait parameters (2 minutes check interval, 2 hours max wait)
- Clean up trailing whitespace in task_distributor.py
- Remove redundant error message from install.sh about missing/incorrect image versions
- Better handling of high load scenarios with clearer logging and user communication
2026-01-11 16:41:05 +08:00
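A simplified sketch of the adjusted wait loop (120-second interval, 60 retries ≈ 2 hours; the callback and worker lists are illustrative, not the actual task_distributor API):

```python
import time

HIGH_LOAD_WAIT_SECONDS = 120   # check interval: 2 minutes (was 60 seconds)
HIGH_LOAD_MAX_RETRIES = 60     # up to 2 hours of total waiting (was 10 retries)

def wait_for_available_worker(poll_workers):
    """Poll until some worker drops below the load threshold or the wait budget runs out."""
    normal, overloaded = [], []
    for _ in range(HIGH_LOAD_MAX_RETRIES):
        time.sleep(HIGH_LOAD_WAIT_SECONDS)
        normal, overloaded = poll_workers()  # illustrative callback returning (normal, overloaded) lists
        if normal:
            return normal[0]
    waited_min = (HIGH_LOAD_WAIT_SECONDS * HIGH_LOAD_MAX_RETRIES) // 60
    print(f"Still overloaded after {waited_min} minutes; forcing dispatch to the least-loaded worker")
    return overloaded[0] if overloaded else None
```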
github-actions[bot]
59e48e5b15 chore: bump version to v1.5.10-dev 2026-01-11 08:19:39 +00:00
yyhuni
2d2ec93626 perf(screenshot): optimize memory usage and add URL collection fallback logic
- Add iterator(chunk_size=50) to ScreenshotSnapshot query to prevent BinaryField data caching and reduce memory consumption
- Implement fallback logic in URL collection: WebSite → HostPortMapping → Default URL with priority handling
- Update _collect_urls_from_provider to return tuple with data source information for better logging and debugging
- Add detailed logging to track which data source was used during URL collection
- Improve code documentation with clear return type hints and fallback priority explanation
- Prevents memory spikes when processing large screenshot datasets with binary image data
2026-01-11 16:14:56 +08:00
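The fallback order described above, as a small sketch (the provider methods and stub class are illustrative; the real logic lives in _collect_urls_from_provider):

```python
def collect_urls(provider, default_url: str) -> tuple[list[str], str]:
    """Fallback priority from the commit: WebSite -> HostPortMapping -> default URL.

    Returns (urls, source) so callers can log which data source was used.
    """
    urls = provider.get_website_urls()
    if urls:
        return urls, "website"
    urls = provider.get_host_port_urls()
    if urls:
        return urls, "host_port_mapping"
    return [default_url], "default"

class StubProvider:
    """Stand-in provider, for illustration only."""
    def get_website_urls(self): return []
    def get_host_port_urls(self): return ["http://example.com:8080"]

print(collect_urls(StubProvider(), "http://example.com"))
# (['http://example.com:8080'], 'host_port_mapping')
```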
github-actions[bot]
ced9f811f4 chore: bump version to v1.5.8-dev 2026-01-11 08:09:37 +00:00
yyhuni
aa99b26f50 fix(vuln_scan): use tool-specific parameter names for endpoint scanning
- Add conditional logic to use "input_file" parameter for nuclei tool
- Use "endpoints_file" parameter for other scanning tools
- Improve compatibility with different vulnerability scanning tools
- Ensure correct parameter naming based on tool requirements
2026-01-11 15:59:39 +08:00
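The command-template diff further down confirms that nuclei now takes {input_file}; the selection logic amounts to something like:

```python
def build_input_param(tool_name: str, endpoints_path: str) -> dict:
    # nuclei reads its URL list through the generic "input_file" slot,
    # while the other scanners still expect "endpoints_file".
    if tool_name == "nuclei":
        return {"input_file": endpoints_path}
    return {"endpoints_file": endpoints_path}

print(build_input_param("nuclei", "/tmp/endpoints.txt"))   # {'input_file': '/tmp/endpoints.txt'}
print(build_input_param("dalfox", "/tmp/endpoints.txt"))   # {'endpoints_file': '/tmp/endpoints.txt'}
```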
yyhuni
8342f196db nuclei: enable website scanning by default 2026-01-11 12:13:27 +08:00
yyhuni
1bd2a6ed88 refactor: complete provider implementation 2026-01-11 11:15:59 +08:00
yyhuni
033ff89aee refactor: source data through providers 2026-01-11 10:29:27 +08:00
yyhuni
4284a0cd9a refactor(scan): remove deprecated provider implementations and cleanup
- Delete ListTargetProvider implementation and related tests
- Delete PipelineTargetProvider implementation and related tests
- Remove target_export_service.py unused service module
- Remove test files for common properties validation
- Update engine-preset-selector component in frontend
- Remove sponsor acknowledgment section from README
- Simplify provider architecture by consolidating implementations
2026-01-10 23:53:52 +08:00
yyhuni
943a4cb960 docs(docker): remove default credentials from startup message
- Remove hardcoded default username and password display from docker startup script
- Remove warning message about changing password after first login
- Improve security by not exposing default credentials in startup output
- Simplifies startup message output for cleaner user experience
2026-01-10 11:21:14 +08:00
yyhuni
eb2d853b76 docs: remove emoji symbols from README for better accessibility
- Remove shield emoji (🛡️) from main title
- Replace emoji prefixes in navigation links with plain text anchors
- Remove emoji icons from section headers (🌐, 📚, 📦, 🤝, 📧, 🎁, 🙏, ⚠️, 🌟, 📄)
- Replace emoji status indicators (⚠️, 🔍, 💡) with plain text equivalents
- Remove emoji bullet points and replace with standard formatting
- Simplify documentation for improved readability and cross-platform compatibility
2026-01-10 11:17:43 +08:00
github-actions[bot]
1184c18b74 chore: bump version to v1.5.7 2026-01-10 03:10:45 +00:00
yyhuni
8a6f1b6f24 feat(engine): add --force-sub flag for selective engine config updates
- Add --force-sub command flag to init_default_engine management command
- Allow updating only sub-engines while preserving user-customized full scan config
- Update docker/scripts/init-data.sh to always update full scan engine configuration
- Change docker/server/start.sh to use --force flag for initial engine setup
- Improve update.sh with better logging functions and formatted output
- Add color-coded log functions (log_step, log_ok, log_info, log_warn, log_error)
- Enhance update.sh UI with better visual formatting and warning messages
- Refactor error messages and user prompts for improved clarity
- This enables safer upgrades by preserving custom full scan configurations while updating sub-engines
2026-01-10 11:04:42 +08:00
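The flag follows Django's management-command argument pattern; a condensed sketch of the handling (the real command also creates the engines and writes colored output):

```python
from django.core.management.base import BaseCommand

class Command(BaseCommand):
    def add_arguments(self, parser):
        parser.add_argument('--force', action='store_true',
                            help='Overwrite all engine configs, including full scan')
        parser.add_argument('--force-sub', action='store_true',
                            help='Overwrite sub-engine configs only; keep a customized full scan')

    def handle(self, *args, **options):
        force = options.get('force', False)
        force_sub = options.get('force_sub', False)
        # Either flag overwrites sub-engines; only --force touches the full scan config.
        overwrite_sub_engines = force or force_sub
        self.stdout.write(f'force={force}, force_sub={force_sub}, overwrite_sub={overwrite_sub_engines}')
```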
yyhuni
255d505aba refactor(scan): remove deprecated amass engine configurations
- Remove amass_passive engine configuration from subdomain discovery defaults
- Remove amass_active engine configuration from subdomain discovery defaults
- Simplify engine configuration by eliminating unused amass-based scanners
- Streamline the default engine template for better maintainability
2026-01-10 10:51:07 +08:00
github-actions[bot]
d06a9bab1f chore: bump version to v1.5.7-dev 2026-01-10 02:48:21 +00:00
yyhuni
6d5c776bf7 chore: improve version detection and update deployment configuration
- Update version detection to support IMAGE_TAG environment variable for Docker containers
- Add fallback mechanism to check multiple version file paths (/app/VERSION and project root)
- Add IMAGE_TAG environment variable to docker-compose.dev.yml and docker-compose.yml
- Fix frontend access URL in start.sh to include correct port (8083)
- Update upgrade warning message in update.sh to recommend fresh installation with latest code
- Improve robustness of version retrieval with better error handling for missing files
2026-01-10 10:41:36 +08:00
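The lookup order is environment variable first, then candidate VERSION files; a condensed version of the helper that appears in the version_views.py diff further down:

```python
import os
from pathlib import Path

def get_current_version() -> str:
    # 1. Docker containers: IMAGE_TAG is injected via docker-compose.
    version = os.environ.get("IMAGE_TAG", "")
    if version:
        return version
    # 2. Fallback: read a VERSION file from /app or the project root.
    for path in (Path("/app/VERSION"), Path.cwd() / "VERSION"):
        try:
            return path.read_text(encoding="utf-8").strip()
        except OSError:
            continue
    return "unknown"
```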
github-actions[bot]
bf058dd67b chore: bump version to v1.5.6-dev 2026-01-10 02:33:15 +00:00
yyhuni
0532d7c8b8 feat(notifications): add WeChat Work (WeChat Enterprise) notification support
- Add wecom notification channel configuration to mock notification settings
- Initialize wecom with disabled state and empty webhook URL by default
- Update notification settings response to include wecom configuration
- Enable WeChat Work as an alternative notification channel alongside Discord
2026-01-10 10:29:33 +08:00
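A sketch of what the default wecom entry might look like in the mock notification settings (field names are assumptions based on the commit message):

```python
DEFAULT_NOTIFICATION_SETTINGS = {
    "discord": {"enabled": False, "webhook_url": ""},
    # WeChat Work (企业微信) starts disabled with an empty webhook URL.
    "wecom": {"enabled": False, "webhook_url": ""},
}
```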
yyhuni
2ee9b5ffa2 update version 2026-01-10 10:27:48 +08:00
yyhuni
648a1888d4 add WeChat Work (enterprise WeChat) support 2026-01-10 10:16:01 +08:00
github-actions[bot]
2508268a45 chore: bump version to v1.5.4-dev 2026-01-10 02:10:05 +00:00
yyhuni
c60383940c add upgrade functionality 2026-01-10 10:04:07 +08:00
yyhuni
47298c294a performance optimization 2026-01-10 09:44:49 +08:00
yyhuni
eba394e14e perf: performance optimization 2026-01-10 09:44:43 +08:00
yyhuni
592a1958c4 improve UI 2026-01-09 16:52:50 +08:00
yyhuni
38e2856c08 feat(scan): add provider abstraction layer for flexible target sourcing
- Add TargetProvider base class and ProviderContext for unified target acquisition
- Implement DatabaseTargetProvider for database-backed target queries
- Implement ListTargetProvider for in-memory target lists (fast scan phase 1)
- Implement SnapshotTargetProvider for snapshot table reads (fast scan phase 2+)
- Implement PipelineTargetProvider for pipeline stage outputs
- Add comprehensive provider tests covering common properties and individual providers
- Update screenshot_flow to support both legacy mode (target_id) and provider mode
- Add backward compatibility layer for existing task exports (directory, fingerprint, port, site, url_fetch, vuln scans)
- Add task backward compatibility tests
- Update .gitignore to exclude .hypothesis/ cache directory
- Update frontend ANSI log viewer component
- Update backend requirements.txt with new dependencies
- Enables flexible data source integration while maintaining backward compatibility with existing database-driven workflows
2026-01-09 09:02:09 +08:00
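A condensed sketch of the provider abstraction (class names follow the commit message; the method name and context fields are illustrative):

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Iterator

@dataclass
class ProviderContext:
    scan_id: int | None = None
    target_id: int | None = None

class TargetProvider(ABC):
    """Unified interface for acquiring scan targets from different sources."""
    def __init__(self, context: ProviderContext):
        self.context = context

    @abstractmethod
    def iter_targets(self) -> Iterator[str]:
        """Yield targets (hosts/URLs) from the underlying source."""

class ListTargetProvider(TargetProvider):
    """In-memory list of targets, e.g. for fast-scan phase 1."""
    def __init__(self, context: ProviderContext, targets: list[str]):
        super().__init__(context)
        self._targets = targets

    def iter_targets(self) -> Iterator[str]:
        yield from self._targets

provider = ListTargetProvider(ProviderContext(scan_id=1), ["example.com", "example.org"])
print(list(provider.iter_targets()))
```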
yyhuni
f5ad8e68e9 chore(backend): add hypothesis cache directory to gitignore
- Add .hypothesis/ directory to .gitignore to exclude Hypothesis property testing cache files
- Prevents test cache artifacts from being tracked in version control
- Improves repository cleanliness by ignoring generated test data
2026-01-08 11:58:49 +08:00
yyhuni
d5f91a236c Merge branch 'main' of https://github.com/yyhuni/xingrin 2026-01-08 10:37:32 +08:00
yyhuni
24ae8b5aeb docs: restructure README features section with capability tables
- Convert feature descriptions from nested lists to organized capability tables
- Add scanning capability table with tools and descriptions for each feature
- Add platform capability table highlighting core platform features
- Improve readability and scannability of feature documentation
- Maintain scanning pipeline architecture section for reference
- Simplify feature organization for better user comprehension
2026-01-08 10:35:56 +08:00
github-actions[bot]
86f43f94a0 chore: bump version to v1.5.3 2026-01-08 02:17:58 +00:00
yyhuni
53ba03d1e5 add Kali support 2026-01-08 10:14:12 +08:00
github-actions[bot]
89c44ebd05 chore: bump version to v1.5.2 2026-01-08 00:20:11 +00:00
yyhuni
e0e3419edb chore(docker): improve worker dockerfile reliability with retry mechanism
- Add retry mechanism for apt-get install to handle ARM64 mirror sync delays
- Use --no-install-recommends flag to reduce image size and installation time
- Split apt-get update and install commands for better layer caching
- Add fallback installation logic for packages in case of initial failure
- Include explanatory comment about ARM64 ports.ubuntu.com potential delays
- Maintain compatibility with both ARM64 and AMD64 architectures
2026-01-08 08:14:24 +08:00
yyhuni
52ee4684a7 chore(docker): add apt-get update before playwright dependencies
- Add apt-get update before installing playwright chromium dependencies
- Ensures package lists are refreshed before installing system dependencies
- Prevents potential package installation failures in Docker builds
2026-01-08 08:09:21 +08:00
yyhuni
ce8cebf11d chore(frontend): update pnpm-lock.yaml with @radix-ui/react-hover-card
- Add @radix-ui/react-hover-card@1.1.15 package resolution entry
- Add package snapshot with all required dependencies and peer dependencies
- Update lock file to reflect new hover card component dependency
- Ensures consistent dependency management across the frontend environment
2026-01-08 07:57:58 +08:00
yyhuni
ec006d8f54 chore(frontend): add @radix-ui/react-hover-card dependency
- Add @radix-ui/react-hover-card v1.1.6 to project dependencies
- Enables hover card UI component functionality for improved user interactions
- Maintains consistency with existing Radix UI component library usage
2026-01-08 07:56:07 +08:00
yyhuni
48976a570f docs: update README with screenshot feature and sponsorship info
- Add screenshot feature documentation to features section with Playwright details
- Include WebP format compression benefits and multi-source URL support
- Add screenshot stage to scan flow architecture diagram with styling
- Add fingerprint library table with counts for public distribution
- Add sponsorship section with WeChat Pay and Alipay QR codes
- Add sponsor appreciation table
- Update frontend dependencies with @radix-ui/react-visually-hidden package
- Remove redundant installation speed note from mirror parameter documentation
- Clean up demo link formatting in online demo section
2026-01-08 07:54:31 +08:00
yyhuni
5da7229873 feat(scan-overview): add yaml configuration tab and improve logs layout
- Add yaml_configuration field to ScanHistorySerializer for backend exposure
- Implement tabbed interface with Logs and Configuration tabs in scan overview
- Add YamlEditor component to display scan configuration in read-only mode
- Refactor logs section to show status bar only when logs tab is active
- Move auto-refresh toggle to logs tab header for better UX
- Add padding to stage progress items for improved visual alignment
- Add internationalization strings for new UI elements (en and zh)
- Update ScanHistory type to include yamlConfiguration field
- Improve tab switching state management with activeTab state
2026-01-08 07:31:54 +08:00
yyhuni
8bb737a9fa feat(scan-history): add auto-refresh toggle and improve layout
- Add auto-refresh toggle switch to scan logs section for manual control
- Implement flexible polling based on auto-refresh state and scan status
- Restructure scan overview layout to use left-right split (stages + logs)
- Move stage progress to left column with vulnerability statistics
- Implement scrollable logs panel on right side with proper height constraints
- Update component imports to use Switch and Label instead of Button
- Add full-height flex layout to parent containers for proper scrolling
- Refactor grid layout from 2-column to fixed-width left + flexible right
- Update translations for new UI elements and labels
- Improve responsive design with better flex constraints and min-height handling
2026-01-07 23:30:27 +08:00
yyhuni
2d018d33f3 improve scan history detail page 2026-01-07 22:44:46 +08:00
yyhuni
0c07cc8497 refactor(scan-flows): simplify logger calls by splitting multiline strings
- Split multiline logger.info() calls into separate single-line calls in initiate_scan_flow.py
- Improved log readability by removing string concatenation with newlines and separators
- Refactored 6 logger.info() calls across sequential, parallel, and completion stages
- Updated subdomain_discovery_flow.py to use consistent single-line logger pattern
- Maintains same log output while improving code maintainability and consistency
2026-01-07 22:21:50 +08:00
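The before/after pattern, with illustrative messages:

```python
import logging

logger = logging.getLogger(__name__)
sub_count, site_count = 12, 5  # illustrative values

# Before: one call with embedded newlines and separator strings
logger.info("Stage complete\n----------\nsubdomains: %d\nwebsites: %d", sub_count, site_count)

# After: separate single-line calls carrying the same information
logger.info("Stage complete")
logger.info("subdomains: %d", sub_count)
logger.info("websites: %d", site_count)
```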
yyhuni
225b039985 style(system-logs): adjust log level filter dropdown width
- Increase SelectTrigger width from 100px to 130px for better label visibility
- Improve UI consistency in log toolbar component
- Prevent text truncation in log level filter dropdown
2026-01-07 22:17:07 +08:00
yyhuni
d1624627bc add icons to top-level tabs 2026-01-07 22:14:42 +08:00
yyhuni
7bb15e4ae4 feat: add screenshot feature 2026-01-07 22:10:51 +08:00
github-actions[bot]
8e8cc29669 chore: bump version to v1.4.1 2026-01-07 01:33:29 +00:00
167 changed files with 9597 additions and 5354 deletions

1
.gitignore vendored
View File

@@ -64,6 +64,7 @@ backend/.env.local
.coverage
htmlcov/
*.cover
.hypothesis/
# ============================
# 后端 (Go) 相关

152
README.md
View File

@@ -1,7 +1,7 @@
<h1 align="center">XingRin - 星环</h1>
<p align="center">
<b>🛡️ 攻击面管理平台 (ASM) | 自动化资产发现与漏洞扫描系统</b>
<b>攻击面管理平台 (ASM) | 自动化资产发现与漏洞扫描系统</b>
</p>
<p align="center">
@@ -12,29 +12,29 @@
</p>
<p align="center">
<a href="#-功能特性">功能特性</a> •
<a href="#-全局资产搜索">资产搜索</a> •
<a href="#-快速开始">快速开始</a> •
<a href="#-文档">文档</a> •
<a href="#-反馈与贡献">反馈与贡献</a>
<a href="#功能特性">功能特性</a> •
<a href="#全局资产搜索">资产搜索</a> •
<a href="#快速开始">快速开始</a> •
<a href="#文档">文档</a> •
<a href="#反馈与贡献">反馈与贡献</a>
</p>
<p align="center">
<sub>🔍 关键词: ASM | 攻击面管理 | 漏洞扫描 | 资产发现 | 资产搜索 | Bug Bounty | 渗透测试 | Nuclei | 子域名枚举 | EASM</sub>
<sub>关键词: ASM | 攻击面管理 | 漏洞扫描 | 资产发现 | 资产搜索 | Bug Bounty | 渗透测试 | Nuclei | 子域名枚举 | EASM</sub>
</p>
---
## 🌐 在线 Demo
## 在线 Demo
👉 **[https://xingrin.vercel.app/](https://xingrin.vercel.app/)**
**[https://xingrin.vercel.app/](https://xingrin.vercel.app/)**
> ⚠️ 仅用于 UI 展示,未接入后端数据库
> 仅用于 UI 展示,未接入后端数据库
---
<p align="center">
<b>🎨 现代化 UI </b>
<b>现代化 UI</b>
</p>
<p align="center">
@@ -44,43 +44,47 @@
<img src="docs/screenshots/quantum-rose.png" alt="Quantum Rose" width="24%">
</p>
## 📚 文档
## 文档
- [📖 技术文档](./docs/README.md) - 技术文档导航(🚧 持续完善中)
- [🚀 快速开始](./docs/quick-start.md) - 一键安装和部署指南
- [🔄 版本管理](./docs/version-management.md) - Git Tag 驱动的自动化版本管理系统
- [📦 Nuclei 模板架构](./docs/nuclei-template-architecture.md) - 模板仓库的存储与同步
- [📖 字典文件架构](./docs/wordlist-architecture.md) - 字典文件的存储与同步
- [🔍 扫描流程架构](./docs/scan-flow-architecture.md) - 完整扫描流程与工具编排
- [技术文档](./docs/README.md) - 技术文档导航(持续完善中)
- [快速开始](./docs/quick-start.md) - 一键安装和部署指南
- [版本管理](./docs/version-management.md) - Git Tag 驱动的自动化版本管理系统
- [Nuclei 模板架构](./docs/nuclei-template-architecture.md) - 模板仓库的存储与同步
- [字典文件架构](./docs/wordlist-architecture.md) - 字典文件的存储与同步
- [扫描流程架构](./docs/scan-flow-architecture.md) - 完整扫描流程与工具编排
---
## 功能特性
## 功能特性
### 🎯 目标与资产管理
- **组织管理** - 多层级目标组织,灵活分组
- **目标管理** - 支持域名、IP目标类型
- **资产发现** - 子域名、网站、端点、目录自动发现
- **资产快照** - 扫描结果快照对比,追踪资产变化
### 扫描能力
### 🔍 漏洞扫描
- **多引擎支持** - 集成 Nuclei 等主流扫描引擎
- **自定义流程** - YAML 配置扫描流程,灵活编排
- **定时扫描** - Cron 表达式配置,自动化周期扫描
| 功能 | 状态 | 工具 | 说明 |
|------|------|------|------|
| 子域名扫描 | 已完成 | Subfinder, Amass, PureDNS | 被动收集 + 主动爆破,聚合 50+ 数据源 |
| 端口扫描 | 已完成 | Naabu | 自定义端口范围 |
| 站点发现 | 已完成 | HTTPX | HTTP 探测,自动获取标题、状态码、技术栈 |
| 指纹识别 | 已完成 | XingFinger | 2.7W+ 指纹规则,多源指纹库 |
| URL 收集 | 已完成 | Waymore, Katana | 历史数据 + 主动爬取 |
| 目录扫描 | 已完成 | FFUF | 高速爆破,智能字典 |
| 漏洞扫描 | 已完成 | Nuclei, Dalfox | 9000+ POC 模板,XSS 检测 |
| 站点截图 | 已完成 | Playwright | WebP 高压缩存储 |
### 🚫 黑名单过滤
- **两层黑名单** - 全局黑名单 + Target 级黑名单,灵活控制扫描范围
- **智能规则识别** - 自动识别域名通配符(`*.gov`)、IP、CIDR 网段
- **敏感目标保护** - 过滤政府、军事、教育等敏感域名,防止误扫
- **内网过滤** - 支持 `10.0.0.0/8`、`172.16.0.0/12`、`192.168.0.0/16` 等私有网段
### 平台能力
### 🔖 指纹识别
- **多源指纹库** - 内置 EHole、Goby、Wappalyzer、Fingers、FingerPrintHub、ARL 等 2.7W+ 指纹规则
- **自动识别** - 扫描流程自动执行,识别 Web 应用技术栈
- **指纹管理** - 支持查询、导入、导出指纹规则
| 功能 | 状态 | 说明 |
|------|------|------|
| 目标管理 | 已完成 | 多层级组织,支持域名/IP 目标 |
| 资产快照 | 已完成 | 扫描结果对比,追踪资产变化 |
| 黑名单过滤 | 已完成 | 全局 + Target 级,支持通配符/CIDR |
| 定时任务 | 已完成 | Cron 表达式,自动化周期扫描 |
| 分布式扫描 | 已完成 | 多 Worker 节点,负载感知调度 |
| 全局搜索 | 已完成 | 表达式语法,多字段组合查询 |
| 通知推送 | 已完成 | 企业微信、Telegram、Discord |
| API 密钥管理 | 已完成 | 可视化配置各数据源 API Key |
#### 扫描流程架构
### 扫描流程架构
完整的扫描流程包括子域名发现、端口扫描、站点发现、指纹识别、URL 收集、目录扫描、漏洞扫描等阶段
@@ -101,6 +105,7 @@ flowchart LR
direction TB
URL["URL 收集<br/>waymore, katana"]
DIR["目录扫描<br/>ffuf"]
SCREENSHOT["站点截图<br/>playwright"]
end
subgraph STAGE3["阶段 3: 漏洞检测"]
@@ -125,12 +130,13 @@ flowchart LR
style FINGER fill:#5dade2,stroke:#3498db,stroke-width:1px,color:#fff
style URL fill:#bb8fce,stroke:#9b59b6,stroke-width:1px,color:#fff
style DIR fill:#bb8fce,stroke:#9b59b6,stroke-width:1px,color:#fff
style SCREENSHOT fill:#bb8fce,stroke:#9b59b6,stroke-width:1px,color:#fff
style VULN fill:#f0b27a,stroke:#e67e22,stroke-width:1px,color:#fff
```
详细说明请查看 [扫描流程架构文档](./docs/scan-flow-architecture.md)
### 🖥️ 分布式架构
### 分布式架构
- **多节点扫描** - 支持部署多个 Worker 节点,横向扩展扫描能力
- **本地节点** - 零配置,安装即自动注册本地 Docker Worker
- **远程节点** - SSH 一键部署远程 VPS 作为扫描节点
@@ -175,7 +181,7 @@ flowchart TB
W3 -.心跳上报.-> REDIS
```
### 🔎 全局资产搜索
### 全局资产搜索
- **多类型搜索** - 支持 Website 和 Endpoint 两种资产类型
- **表达式语法** - 支持 `=`(模糊)、`==`(精确)、`!=`(不等于)操作符
- **逻辑组合** - 支持 `&&` (AND) 和 `||` (OR) 逻辑组合
@@ -199,14 +205,14 @@ host="admin" && tech="php" && status=="200"
url="/api/v1" && status!="404"
```
### 📊 可视化界面
### 可视化界面
- **数据统计** - 资产/漏洞统计仪表盘
- **实时通知** - WebSocket 消息推送
- **通知推送** - 实时企业微信、TG、Discord 消息推送服务
---
## 📦 快速开始
## 快速开始
### 环境要求
@@ -224,14 +230,13 @@ cd xingrin
# 安装并启动(生产模式)
sudo ./install.sh
# 🇨🇳 中国大陆用户推荐使用镜像加速(第三方加速服务可能会失效,不保证长期可用)
# 中国大陆用户推荐使用镜像加速(第三方加速服务可能会失效,不保证长期可用)
sudo ./install.sh --mirror
```
> **💡 --mirror 参数说明**
> **--mirror 参数说明**
> - 自动配置 Docker 镜像加速(国内镜像源)
> - 加速 Git 仓库克隆Nuclei 模板等)
> - 大幅提升安装速度,避免网络超时
### 访问服务
@@ -254,18 +259,38 @@ sudo ./restart.sh
sudo ./uninstall.sh
```
## 🤝 反馈与贡献
## 反馈与贡献
- 💡 **发现 Bug有新想法比如UI设计功能设计等** 欢迎点击右边链接进行提交建议 [Issue](https://github.com/yyhuni/xingrin/issues) 或者公众号私信
- **发现 Bug有新想法比如UI设计功能设计等** 欢迎点击右边链接进行提交建议 [Issue](https://github.com/yyhuni/xingrin/issues) 或者公众号私信
## 📧 联系
## 联系
- 微信公众号: **塔罗安全学苑**
- 微信群去公众号底下的菜单,有个交流群,点击就可以看到了,链接过期可以私信我拉你
<img src="docs/wechat-qrcode.png" alt="微信公众号" width="200">
### 关注公众号免费领取指纹库
## ⚠️ 免责声明
| 指纹库 | 数量 |
|--------|------|
| ehole.json | 21,977 |
| ARL.yaml | 9,264 |
| goby.json | 7,086 |
| FingerprintHub.json | 3,147 |
> 关注公众号回复「指纹」即可获取
## 赞助支持
如果这个项目对你有帮助谢谢请我能喝杯蜜雪冰城你的star和赞助是我免费更新的动力
<p>
<img src="docs/wx_pay.jpg" alt="微信支付" width="200">
<img src="docs/zfb_pay.jpg" alt="支付宝" width="200">
</p>
## 免责声明
**重要:请在使用前仔细阅读**
@@ -280,30 +305,29 @@ sudo ./uninstall.sh
- 遵守所在地区的法律法规
- 承担因滥用产生的一切后果
## 🌟 Star History
## Star History
如果这个项目对你有帮助,请给一个 Star 支持一下!
如果这个项目对你有帮助,请给一个 Star 支持一下!
[![Star History Chart](https://api.star-history.com/svg?repos=yyhuni/xingrin&type=Date)](https://star-history.com/#yyhuni/xingrin&Date)
## 📄 许可证
## 许可证
本项目采用 [GNU General Public License v3.0](LICENSE) 许可证。
### 允许的用途
- 个人学习和研究
- 商业和非商业使用
- 修改和分发
- 专利使用
- 私人使用
- 个人学习和研究
- 商业和非商业使用
- 修改和分发
- 专利使用
- 私人使用
### 义务和限制
- 📋 **开源义务**:分发时必须提供源代码
- 📋 **相同许可**:衍生作品必须使用相同许可证
- 📋 **版权声明**:必须保留原始版权和许可证声明
- **责任免除**:不提供任何担保
- 未经授权的渗透测试
- 任何违法行为
- **开源义务**:分发时必须提供源代码
- **相同许可**:衍生作品必须使用相同许可证
- **版权声明**:必须保留原始版权和许可证声明
- **责任免除**:不提供任何担保
- 未经授权的渗透测试
- 任何违法行为

View File

@@ -1 +1 @@
v1.4.0
v1.5.12-dev

1
backend/.gitignore vendored
View File

@@ -7,6 +7,7 @@ __pycache__/
*.egg-info/
dist/
build/
.hypothesis/ # Hypothesis 属性测试缓存
# 虚拟环境
venv/

View File

@@ -0,0 +1,53 @@
# Generated by Django 5.2.7 on 2026-01-07 02:21
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('asset', '0002_create_search_views'),
('scan', '0001_initial'),
('targets', '0001_initial'),
]
operations = [
migrations.CreateModel(
name='Screenshot',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('url', models.TextField(help_text='截图对应的 URL')),
('image', models.BinaryField(help_text='截图 WebP 二进制数据(压缩后)')),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('updated_at', models.DateTimeField(auto_now=True, help_text='更新时间')),
('target', models.ForeignKey(help_text='所属目标', on_delete=django.db.models.deletion.CASCADE, related_name='screenshots', to='targets.target')),
],
options={
'verbose_name': '截图',
'verbose_name_plural': '截图',
'db_table': 'screenshot',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['target'], name='screenshot_target__2f01f6_idx'), models.Index(fields=['-created_at'], name='screenshot_created_c0ad4b_idx')],
'constraints': [models.UniqueConstraint(fields=('target', 'url'), name='unique_screenshot_per_target')],
},
),
migrations.CreateModel(
name='ScreenshotSnapshot',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('url', models.TextField(help_text='截图对应的 URL')),
('image', models.BinaryField(help_text='截图 WebP 二进制数据(压缩后)')),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('scan', models.ForeignKey(help_text='所属的扫描任务', on_delete=django.db.models.deletion.CASCADE, related_name='screenshot_snapshots', to='scan.scan')),
],
options={
'verbose_name': '截图快照',
'verbose_name_plural': '截图快照',
'db_table': 'screenshot_snapshot',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['scan'], name='screenshot__scan_id_fb8c4d_idx'), models.Index(fields=['-created_at'], name='screenshot__created_804117_idx')],
'constraints': [models.UniqueConstraint(fields=('scan', 'url'), name='unique_screenshot_per_scan_snapshot')],
},
),
]

View File

@@ -0,0 +1,23 @@
# Generated by Django 5.2.7 on 2026-01-07 13:29
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('asset', '0003_add_screenshot_models'),
]
operations = [
migrations.AddField(
model_name='screenshot',
name='status_code',
field=models.SmallIntegerField(blank=True, help_text='HTTP 响应状态码', null=True),
),
migrations.AddField(
model_name='screenshotsnapshot',
name='status_code',
field=models.SmallIntegerField(blank=True, help_text='HTTP 响应状态码', null=True),
),
]

View File

@@ -20,6 +20,12 @@ from .snapshot_models import (
VulnerabilitySnapshot,
)
# 截图模型
from .screenshot_models import (
Screenshot,
ScreenshotSnapshot,
)
# 统计模型
from .statistics_models import AssetStatistics, StatisticsHistory
@@ -39,6 +45,9 @@ __all__ = [
'HostPortMappingSnapshot',
'EndpointSnapshot',
'VulnerabilitySnapshot',
# 截图模型
'Screenshot',
'ScreenshotSnapshot',
# 统计模型
'AssetStatistics',
'StatisticsHistory',

View File

@@ -0,0 +1,80 @@
from django.db import models
class ScreenshotSnapshot(models.Model):
"""
截图快照
记录:某次扫描中捕获的网站截图
"""
id = models.AutoField(primary_key=True)
scan = models.ForeignKey(
'scan.Scan',
on_delete=models.CASCADE,
related_name='screenshot_snapshots',
help_text='所属的扫描任务'
)
url = models.TextField(help_text='截图对应的 URL')
status_code = models.SmallIntegerField(null=True, blank=True, help_text='HTTP 响应状态码')
image = models.BinaryField(help_text='截图 WebP 二进制数据(压缩后)')
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
class Meta:
db_table = 'screenshot_snapshot'
verbose_name = '截图快照'
verbose_name_plural = '截图快照'
ordering = ['-created_at']
indexes = [
models.Index(fields=['scan']),
models.Index(fields=['-created_at']),
]
constraints = [
models.UniqueConstraint(
fields=['scan', 'url'],
name='unique_screenshot_per_scan_snapshot'
),
]
def __str__(self):
return f'{self.url} (Scan #{self.scan_id})'
class Screenshot(models.Model):
"""
截图资产
存储:目标的最新截图(从快照同步)
"""
id = models.AutoField(primary_key=True)
target = models.ForeignKey(
'targets.Target',
on_delete=models.CASCADE,
related_name='screenshots',
help_text='所属目标'
)
url = models.TextField(help_text='截图对应的 URL')
status_code = models.SmallIntegerField(null=True, blank=True, help_text='HTTP 响应状态码')
image = models.BinaryField(help_text='截图 WebP 二进制数据(压缩后)')
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
updated_at = models.DateTimeField(auto_now=True, help_text='更新时间')
class Meta:
db_table = 'screenshot'
verbose_name = '截图'
verbose_name_plural = '截图'
ordering = ['-created_at']
indexes = [
models.Index(fields=['target']),
models.Index(fields=['-created_at']),
]
constraints = [
models.UniqueConstraint(
fields=['target', 'url'],
name='unique_screenshot_per_target'
),
]
def __str__(self):
return f'{self.url} (Target #{self.target_id})'

View File

@@ -195,3 +195,32 @@ class DjangoHostPortMappingSnapshotRepository:
for row in qs.iterator(chunk_size=batch_size):
yield row
def iter_unique_host_ports_by_scan(
self,
scan_id: int,
batch_size: int = 1000
) -> Iterator[dict]:
"""
流式获取扫描下的唯一 host:port 组合(去重)
用于生成 URL 时避免重复,同一个 host:port 可能对应多个 IP
但生成 URL 时只需要一个。
Args:
scan_id: 扫描 ID
batch_size: 每批数据量
Yields:
{'host': 'example.com', 'port': 80}
"""
qs = (
HostPortMappingSnapshot.objects
.filter(scan_id=scan_id)
.values('host', 'port')
.distinct()
.order_by('host', 'port')
)
for row in qs.iterator(chunk_size=batch_size):
yield row

View File

@@ -7,6 +7,7 @@ from .models.snapshot_models import (
EndpointSnapshot,
VulnerabilitySnapshot,
)
from .models.screenshot_models import Screenshot, ScreenshotSnapshot
# 注意IPAddress 和 Port 模型已被重构为 HostPortMapping
@@ -290,3 +291,23 @@ class EndpointSnapshotSerializer(serializers.ModelSerializer):
'created_at',
]
read_only_fields = fields
# ==================== 截图序列化器 ====================
class ScreenshotListSerializer(serializers.ModelSerializer):
"""截图资产列表序列化器(不包含 image 字段)"""
class Meta:
model = Screenshot
fields = ['id', 'url', 'status_code', 'created_at', 'updated_at']
read_only_fields = fields
class ScreenshotSnapshotListSerializer(serializers.ModelSerializer):
"""截图快照列表序列化器(不包含 image 字段)"""
class Meta:
model = ScreenshotSnapshot
fields = ['id', 'url', 'status_code', 'created_at']
read_only_fields = fields

View File

@@ -0,0 +1,186 @@
"""
Playwright 截图服务
使用 Playwright 异步批量捕获网站截图
"""
import asyncio
import logging
from typing import Optional, AsyncGenerator
logger = logging.getLogger(__name__)
class PlaywrightScreenshotService:
"""Playwright 截图服务 - 异步多 Page 并发截图"""
# 内置默认值(不暴露给用户)
DEFAULT_VIEWPORT_WIDTH = 1920
DEFAULT_VIEWPORT_HEIGHT = 1080
DEFAULT_TIMEOUT = 30000 # 毫秒
DEFAULT_JPEG_QUALITY = 85
def __init__(
self,
viewport_width: int = DEFAULT_VIEWPORT_WIDTH,
viewport_height: int = DEFAULT_VIEWPORT_HEIGHT,
timeout: int = DEFAULT_TIMEOUT,
concurrency: int = 5
):
"""
初始化 Playwright 截图服务
Args:
viewport_width: 视口宽度(像素)
viewport_height: 视口高度(像素)
timeout: 页面加载超时时间(毫秒)
concurrency: 并发截图数
"""
self.viewport_width = viewport_width
self.viewport_height = viewport_height
self.timeout = timeout
self.concurrency = concurrency
async def capture_screenshot(self, url: str, page) -> tuple[Optional[bytes], Optional[int]]:
"""
捕获单个 URL 的截图
Args:
url: 目标 URL
page: Playwright Page 对象
Returns:
(screenshot_bytes, status_code) 元组
- screenshot_bytes: JPEG 格式的截图字节数据,失败返回 None
- status_code: HTTP 响应状态码,失败返回 None
"""
status_code = None
try:
# 尝试加载页面,即使返回错误状态码也继续截图
try:
response = await page.goto(url, timeout=self.timeout, wait_until='networkidle')
if response:
status_code = response.status
except Exception as goto_error:
# 页面加载失败4xx/5xx 或其他错误),但页面可能已部分渲染
# 仍然尝试截图以捕获错误页面
logger.debug("页面加载异常但尝试截图: %s, 错误: %s", url, str(goto_error)[:50])
# 尝试截图(即使 goto 失败)
screenshot_bytes = await page.screenshot(
type='jpeg',
quality=self.DEFAULT_JPEG_QUALITY,
full_page=False
)
return (screenshot_bytes, status_code)
except asyncio.TimeoutError:
logger.warning("截图超时: %s", url)
return (None, None)
except Exception as e:
logger.warning("截图失败: %s, 错误: %s", url, str(e)[:100])
return (None, None)
async def _capture_with_semaphore(
self,
url: str,
context,
semaphore: asyncio.Semaphore
) -> tuple[str, Optional[bytes], Optional[int]]:
"""
使用信号量控制并发的截图任务
Args:
url: 目标 URL
context: Playwright BrowserContext
semaphore: 并发控制信号量
Returns:
(url, screenshot_bytes, status_code) 元组
"""
async with semaphore:
page = await context.new_page()
try:
screenshot_bytes, status_code = await self.capture_screenshot(url, page)
return (url, screenshot_bytes, status_code)
finally:
await page.close()
async def capture_batch(
self,
urls: list[str]
) -> AsyncGenerator[tuple[str, Optional[bytes], Optional[int]], None]:
"""
批量捕获截图(异步生成器)
使用单个 BrowserContext + 多 Page 并发模式
通过 Semaphore 控制并发数
Args:
urls: URL 列表
Yields:
(url, screenshot_bytes, status_code) 元组
"""
if not urls:
return
from playwright.async_api import async_playwright
async with async_playwright() as p:
# 启动浏览器headless 模式)
browser = await p.chromium.launch(
headless=True,
args=[
'--no-sandbox',
'--disable-setuid-sandbox',
'--disable-dev-shm-usage',
'--disable-gpu'
]
)
try:
# 创建单个 context
context = await browser.new_context(
viewport={
'width': self.viewport_width,
'height': self.viewport_height
},
ignore_https_errors=True,
user_agent='Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36'
)
# 使用 Semaphore 控制并发
semaphore = asyncio.Semaphore(self.concurrency)
# 创建所有任务
tasks = [
self._capture_with_semaphore(url, context, semaphore)
for url in urls
]
# 使用 as_completed 实现流式返回
for coro in asyncio.as_completed(tasks):
result = await coro
yield result
await context.close()
finally:
await browser.close()
async def capture_batch_collect(
self,
urls: list[str]
) -> list[tuple[str, Optional[bytes], Optional[int]]]:
"""
批量捕获截图(收集所有结果)
Args:
urls: URL 列表
Returns:
[(url, screenshot_bytes, status_code), ...] 列表
"""
results = []
async for result in self.capture_batch(urls):
results.append(result)
return results
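A minimal usage sketch for the batch API above (URLs are illustrative; Playwright and its Chromium browser must be installed):

```python
import asyncio

async def main():
    service = PlaywrightScreenshotService(concurrency=3)
    results = await service.capture_batch_collect([
        "https://example.com",
        "https://example.org",
    ])
    for url, image_bytes, status_code in results:
        size = len(image_bytes) if image_bytes else 0
        print(f"{url}: status={status_code}, {size} bytes")

asyncio.run(main())
```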

View File

@@ -0,0 +1,187 @@
"""
截图服务
负责截图的压缩、保存和同步
"""
import io
import logging
import os
from typing import Optional
from PIL import Image
logger = logging.getLogger(__name__)
class ScreenshotService:
"""截图服务 - 负责压缩、保存和同步"""
def __init__(self, max_width: int = 800, target_kb: int = 100):
"""
初始化截图服务
Args:
max_width: 最大宽度(像素)
target_kb: 目标文件大小KB
"""
self.max_width = max_width
self.target_kb = target_kb
def compress_screenshot(self, image_path: str) -> Optional[bytes]:
"""
压缩截图为 WebP 格式
Args:
image_path: PNG 截图文件路径
Returns:
压缩后的 WebP 二进制数据,失败返回 None
"""
if not os.path.exists(image_path):
logger.warning(f"截图文件不存在: {image_path}")
return None
try:
with Image.open(image_path) as img:
return self._compress_image(img)
except Exception as e:
logger.error(f"压缩截图失败: {image_path}, 错误: {e}")
return None
def compress_from_bytes(self, image_bytes: bytes) -> Optional[bytes]:
"""
从字节数据压缩截图为 WebP 格式
Args:
image_bytes: JPEG/PNG 图片字节数据
Returns:
压缩后的 WebP 二进制数据,失败返回 None
"""
if not image_bytes:
return None
try:
img = Image.open(io.BytesIO(image_bytes))
return self._compress_image(img)
except Exception as e:
logger.error(f"从字节压缩截图失败: {e}")
return None
def _compress_image(self, img: Image.Image) -> Optional[bytes]:
"""
压缩 PIL Image 对象为 WebP 格式
Args:
img: PIL Image 对象
Returns:
压缩后的 WebP 二进制数据
"""
try:
if img.mode in ('RGBA', 'P'):
img = img.convert('RGB')
width, height = img.size
if width > self.max_width:
ratio = self.max_width / width
new_size = (self.max_width, int(height * ratio))
img = img.resize(new_size, Image.Resampling.LANCZOS)
quality = 80
while quality >= 10:
buffer = io.BytesIO()
img.save(buffer, format='WEBP', quality=quality, method=6)
if len(buffer.getvalue()) <= self.target_kb * 1024:
return buffer.getvalue()
quality -= 10
return buffer.getvalue()
except Exception as e:
logger.error(f"压缩图片失败: {e}")
return None
def save_screenshot_snapshot(
self,
scan_id: int,
url: str,
image_data: bytes,
status_code: int | None = None
) -> bool:
"""
保存截图快照到 ScreenshotSnapshot 表
Args:
scan_id: 扫描 ID
url: 截图对应的 URL
image_data: 压缩后的图片二进制数据
status_code: HTTP 响应状态码
Returns:
是否保存成功
"""
from apps.asset.models import ScreenshotSnapshot
try:
ScreenshotSnapshot.objects.update_or_create(
scan_id=scan_id,
url=url,
defaults={'image': image_data, 'status_code': status_code}
)
return True
except Exception as e:
logger.error(f"保存截图快照失败: scan_id={scan_id}, url={url}, 错误: {e}")
return False
def sync_screenshots_to_asset(self, scan_id: int, target_id: int) -> int:
"""
将扫描的截图快照同步到资产表
Args:
scan_id: 扫描 ID
target_id: 目标 ID
Returns:
同步的截图数量
"""
from apps.asset.models import Screenshot, ScreenshotSnapshot
# 使用 iterator() 避免 QuerySet 缓存大量 BinaryField 数据导致内存飙升
# chunk_size=50: 每次只加载 50 条记录,处理完后释放内存
snapshots = ScreenshotSnapshot.objects.filter(scan_id=scan_id).iterator(chunk_size=50)
count = 0
for snapshot in snapshots:
try:
Screenshot.objects.update_or_create(
target_id=target_id,
url=snapshot.url,
defaults={
'image': snapshot.image,
'status_code': snapshot.status_code
}
)
count += 1
except Exception as e:
logger.error(f"同步截图到资产表失败: url={snapshot.url}, 错误: {e}")
logger.info(f"同步截图完成: scan_id={scan_id}, target_id={target_id}, 数量={count}")
return count
def process_and_save_screenshot(self, scan_id: int, url: str, image_path: str) -> bool:
"""
处理并保存截图(压缩 + 保存快照)
Args:
scan_id: 扫描 ID
url: 截图对应的 URL
image_path: PNG 截图文件路径
Returns:
是否处理成功
"""
image_data = self.compress_screenshot(image_path)
if image_data is None:
return False
return self.save_screenshot_snapshot(scan_id, url, image_data)
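A usage sketch for this service (scan/target IDs and the PNG path are illustrative; Pillow and the project's Django environment are required):

```python
service = ScreenshotService(max_width=800, target_kb=100)

# Compress a captured PNG and store it as a WebP snapshot for scan 1.
ok = service.process_and_save_screenshot(scan_id=1, url="https://example.com", image_path="/tmp/example.png")
print("saved" if ok else "failed")

# After the scan finishes, promote the snapshots to the asset table of target 1.
service.sync_screenshots_to_asset(scan_id=1, target_id=1)
```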

View File

@@ -1,72 +1,18 @@
"""Endpoint Snapshots Service - 业务逻辑层"""
import logging
from typing import List, Iterator
from typing import Iterator, List, Optional
from apps.asset.dtos.snapshot import EndpointSnapshotDTO
from apps.asset.repositories.snapshot import DjangoEndpointSnapshotRepository
from apps.asset.services.asset import EndpointService
from apps.asset.dtos.snapshot import EndpointSnapshotDTO
logger = logging.getLogger(__name__)
class EndpointSnapshotsService:
"""端点快照服务 - 统一管理快照和资产同步"""
def __init__(self):
self.snapshot_repo = DjangoEndpointSnapshotRepository()
self.asset_service = EndpointService()
def save_and_sync(self, items: List[EndpointSnapshotDTO]) -> None:
"""
保存端点快照并同步到资产表(统一入口)
流程:
1. 保存到快照表(完整记录)
2. 同步到资产表(去重)
Args:
items: 端点快照 DTO 列表(必须包含 target_id
Raises:
ValueError: 如果 items 中的 target_id 为 None
Exception: 数据库操作失败
"""
if not items:
return
# 检查 Scan 是否仍存在(防止删除后竞态写入)
scan_id = items[0].scan_id
from apps.scan.repositories import DjangoScanRepository
if not DjangoScanRepository().exists(scan_id):
logger.warning("Scan 已删除,跳过端点快照保存 - scan_id=%s, 数量=%d", scan_id, len(items))
return
try:
logger.debug("保存端点快照并同步到资产表 - 数量: %d", len(items))
# 步骤 1: 保存到快照表
logger.debug("步骤 1: 保存到快照表")
self.snapshot_repo.save_snapshots(items)
# 步骤 2: 转换为资产 DTO 并保存到资产表
# 使用 upsert新记录插入已存在的记录更新
logger.debug("步骤 2: 同步到资产表(通过 Service 层)")
asset_items = [item.to_asset_dto() for item in items]
self.asset_service.bulk_upsert(asset_items)
logger.info("端点快照和资产数据保存成功 - 数量: %d", len(items))
except Exception as e:
logger.error(
"保存端点快照失败 - 数量: %d, 错误: %s",
len(items),
str(e),
exc_info=True
)
raise
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'url': 'url',
@@ -76,26 +22,89 @@ class EndpointSnapshotsService:
'webserver': 'webserver',
'tech': 'tech',
}
def get_by_scan(self, scan_id: int, filter_query: str = None):
def __init__(self):
self.snapshot_repo = DjangoEndpointSnapshotRepository()
self.asset_service = EndpointService()
def save_and_sync(self, items: List[EndpointSnapshotDTO]) -> None:
"""
保存端点快照并同步到资产表(统一入口)
流程:
1. 保存到快照表(完整记录)
2. 同步到资产表(去重)
Args:
items: 端点快照 DTO 列表(必须包含 target_id
Raises:
ValueError: 如果 items 中的 target_id 为 None
Exception: 数据库操作失败
"""
if not items:
return
# 检查 Scan 是否仍存在(防止删除后竞态写入)
scan_id = items[0].scan_id
from apps.scan.repositories import DjangoScanRepository
if not DjangoScanRepository().exists(scan_id):
logger.warning("Scan 已删除,跳过端点快照保存 - scan_id=%s, 数量=%d", scan_id, len(items))
return
try:
logger.debug("保存端点快照并同步到资产表 - 数量: %d", len(items))
# 步骤 1: 保存到快照表
self.snapshot_repo.save_snapshots(items)
# 步骤 2: 转换为资产 DTO 并保存到资产表upsert
asset_items = [item.to_asset_dto() for item in items]
self.asset_service.bulk_upsert(asset_items)
logger.info("端点快照和资产数据保存成功 - 数量: %d", len(items))
except Exception as e:
logger.error("保存端点快照失败 - 数量: %d, 错误: %s", len(items), str(e), exc_info=True)
raise
def get_by_scan(self, scan_id: int, filter_query: Optional[str] = None):
"""
获取指定扫描的端点快照
Args:
scan_id: 扫描 ID
filter_query: 过滤查询字符串
Returns:
QuerySet: 端点快照查询集
"""
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_by_scan(scan_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_all(self, filter_query: str = None):
"""获取所有端点快照"""
def get_all(self, filter_query: Optional[str] = None):
"""
获取所有端点快照
Args:
filter_query: 过滤查询字符串
Returns:
QuerySet: 端点快照查询集
"""
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def iter_endpoint_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取某次扫描下的所有端点 URL"""
"""流式获取某次扫描下的所有端点 URL"""
queryset = self.snapshot_repo.get_by_scan(scan_id)
for snapshot in queryset.iterator(chunk_size=chunk_size):
yield snapshot.url
@@ -103,10 +112,10 @@ class EndpointSnapshotsService:
def iter_raw_data_for_csv_export(self, scan_id: int) -> Iterator[dict]:
"""
流式获取原始数据用于 CSV 导出
Args:
scan_id: 扫描 ID
Yields:
原始数据字典
"""

View File

@@ -91,3 +91,25 @@ class HostPortMappingSnapshotsService:
原始数据字典 {ip, host, port, created_at}
"""
return self.snapshot_repo.iter_raw_data_for_export(scan_id=scan_id)
def iter_unique_host_ports_by_scan(
self,
scan_id: int,
batch_size: int = 1000
) -> Iterator[dict]:
"""
流式获取扫描下的唯一 host:port 组合(去重)
用于生成 URL 时避免重复。
Args:
scan_id: 扫描 ID
batch_size: 每批数据量
Yields:
{'host': 'example.com', 'port': 80}
"""
return self.snapshot_repo.iter_unique_host_ports_by_scan(
scan_id=scan_id,
batch_size=batch_size
)

View File

@@ -14,6 +14,7 @@ from .views import (
AssetSearchExportView,
EndpointViewSet,
HostPortMappingViewSet,
ScreenshotViewSet,
)
# 创建 DRF 路由器
@@ -26,6 +27,7 @@ router.register(r'directories', DirectoryViewSet, basename='directory')
router.register(r'endpoints', EndpointViewSet, basename='endpoint')
router.register(r'ip-addresses', HostPortMappingViewSet, basename='ip-address')
router.register(r'vulnerabilities', VulnerabilityViewSet, basename='vulnerability')
router.register(r'screenshots', ScreenshotViewSet, basename='screenshot')
router.register(r'statistics', AssetStatisticsViewSet, basename='asset-statistics')
urlpatterns = [

View File

@@ -18,6 +18,8 @@ from .asset_views import (
EndpointSnapshotViewSet,
HostPortMappingSnapshotViewSet,
VulnerabilitySnapshotViewSet,
ScreenshotViewSet,
ScreenshotSnapshotViewSet,
)
from .search_views import AssetSearchView, AssetSearchExportView
@@ -35,6 +37,8 @@ __all__ = [
'EndpointSnapshotViewSet',
'HostPortMappingSnapshotViewSet',
'VulnerabilitySnapshotViewSet',
'ScreenshotViewSet',
'ScreenshotSnapshotViewSet',
'AssetSearchView',
'AssetSearchExportView',
]

View File

@@ -1225,3 +1225,162 @@ class VulnerabilitySnapshotViewSet(viewsets.ModelViewSet):
if scan_pk:
return self.service.get_by_scan(scan_pk, filter_query=filter_query)
return self.service.get_all(filter_query=filter_query)
# ==================== 截图 ViewSet ====================
class ScreenshotViewSet(viewsets.ModelViewSet):
"""截图资产 ViewSet
支持两种访问方式:
1. 嵌套路由GET /api/targets/{target_pk}/screenshots/
2. 独立路由GET /api/screenshots/(全局查询)
支持智能过滤语法filter 参数):
- url="example" URL 模糊匹配
"""
from ..serializers import ScreenshotListSerializer
serializer_class = ScreenshotListSerializer
pagination_class = BasePagination
filter_backends = [filters.OrderingFilter]
ordering = ['-created_at']
def get_queryset(self):
"""根据是否有 target_pk 参数决定查询范围"""
from ..models import Screenshot
target_pk = self.kwargs.get('target_pk')
filter_query = self.request.query_params.get('filter', None)
queryset = Screenshot.objects.all()
if target_pk:
queryset = queryset.filter(target_id=target_pk)
if filter_query:
# 简单的 URL 模糊匹配
queryset = queryset.filter(url__icontains=filter_query)
return queryset.order_by('-created_at')
@action(detail=True, methods=['get'], url_path='image')
def image(self, request, pk=None, **kwargs):
"""获取截图图片
GET /api/assets/screenshots/{id}/image/
返回 WebP 格式的图片二进制数据
"""
from django.http import HttpResponse
from ..models import Screenshot
try:
screenshot = Screenshot.objects.get(pk=pk)
if not screenshot.image:
return error_response(
code=ErrorCodes.NOT_FOUND,
message='Screenshot image not found',
status_code=status.HTTP_404_NOT_FOUND
)
response = HttpResponse(screenshot.image, content_type='image/webp')
response['Content-Disposition'] = f'inline; filename="screenshot_{pk}.webp"'
return response
except Screenshot.DoesNotExist:
return error_response(
code=ErrorCodes.NOT_FOUND,
message='Screenshot not found',
status_code=status.HTTP_404_NOT_FOUND
)
@action(detail=False, methods=['post'], url_path='bulk-delete')
def bulk_delete(self, request, **kwargs):
"""批量删除截图
POST /api/assets/screenshots/bulk-delete/
请求体: {"ids": [1, 2, 3]}
响应: {"deletedCount": 3}
"""
ids = request.data.get('ids', [])
if not ids or not isinstance(ids, list):
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='ids is required and must be a list',
status_code=status.HTTP_400_BAD_REQUEST
)
try:
from ..models import Screenshot
deleted_count, _ = Screenshot.objects.filter(id__in=ids).delete()
return success_response(data={'deletedCount': deleted_count})
except Exception as e:
logger.exception("批量删除截图失败")
return error_response(
code=ErrorCodes.SERVER_ERROR,
message='Failed to delete screenshots',
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
)
class ScreenshotSnapshotViewSet(viewsets.ModelViewSet):
"""截图快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/screenshots/
支持智能过滤语法filter 参数):
- url="example" URL 模糊匹配
"""
from ..serializers import ScreenshotSnapshotListSerializer
serializer_class = ScreenshotSnapshotListSerializer
pagination_class = BasePagination
filter_backends = [filters.OrderingFilter]
ordering = ['-created_at']
def get_queryset(self):
"""根据 scan_pk 参数查询"""
from ..models import ScreenshotSnapshot
scan_pk = self.kwargs.get('scan_pk')
filter_query = self.request.query_params.get('filter', None)
queryset = ScreenshotSnapshot.objects.all()
if scan_pk:
queryset = queryset.filter(scan_id=scan_pk)
if filter_query:
# 简单的 URL 模糊匹配
queryset = queryset.filter(url__icontains=filter_query)
return queryset.order_by('-created_at')
@action(detail=True, methods=['get'], url_path='image')
def image(self, request, pk=None, **kwargs):
"""获取截图快照图片
GET /api/scans/{scan_pk}/screenshots/{id}/image/
返回 WebP 格式的图片二进制数据
"""
from django.http import HttpResponse
from ..models import ScreenshotSnapshot
try:
screenshot = ScreenshotSnapshot.objects.get(pk=pk)
if not screenshot.image:
return error_response(
code=ErrorCodes.NOT_FOUND,
message='Screenshot image not found',
status_code=status.HTTP_404_NOT_FOUND
)
response = HttpResponse(screenshot.image, content_type='image/webp')
response['Content-Disposition'] = f'inline; filename="screenshot_snapshot_{pk}.webp"'
return response
except ScreenshotSnapshot.DoesNotExist:
return error_response(
code=ErrorCodes.NOT_FOUND,
message='Screenshot snapshot not found',
status_code=status.HTTP_404_NOT_FOUND
)

View File

@@ -1,43 +1,43 @@
"""
Prefect Flow Django 环境初始化模块
Django 环境初始化模块
在所有 Prefect Flow 文件开头导入此模块即可自动配置 Django 环境
在所有 Worker 脚本开头导入此模块即可自动配置 Django 环境
"""
import os
import sys
def setup_django_for_prefect():
def setup_django():
"""
Prefect Flow 配置 Django 环境
配置 Django 环境
此函数会
1. 添加项目根目录到 Python 路径
2. 设置 DJANGO_SETTINGS_MODULE 环境变量
3. 调用 django.setup() 初始化 Django
4. 关闭旧的数据库连接确保使用新连接
使用方式
from apps.common.prefect_django_setup import setup_django_for_prefect
setup_django_for_prefect()
from apps.common.django_setup import setup_django
setup_django()
"""
# 获取项目根目录backend 目录)
current_dir = os.path.dirname(os.path.abspath(__file__))
backend_dir = os.path.join(current_dir, '../..')
backend_dir = os.path.abspath(backend_dir)
# 添加到 Python 路径
if backend_dir not in sys.path:
sys.path.insert(0, backend_dir)
# 配置 Django
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')
# 初始化 Django
import django
django.setup()
# 关闭所有旧的数据库连接,确保 Worker 进程使用新连接
# 解决 "server closed the connection unexpectedly" 问题
from django.db import connections
@@ -47,7 +47,7 @@ def setup_django_for_prefect():
def close_old_db_connections():
"""
关闭旧的数据库连接
在长时间运行的任务中调用此函数可以确保使用有效的数据库连接
适用于
- Flow 开始前
@@ -59,4 +59,4 @@ def close_old_db_connections():
# 自动执行初始化(导入即生效)
setup_django_for_prefect()
setup_django()

View File

@@ -14,6 +14,7 @@ from .views import (
LoginView, LogoutView, MeView, ChangePasswordView,
SystemLogsView, SystemLogFilesView, HealthCheckView,
GlobalBlacklistView,
VersionView, CheckUpdateView,
)
urlpatterns = [
@@ -29,6 +30,8 @@ urlpatterns = [
# 系统管理
path('system/logs/', SystemLogsView.as_view(), name='system-logs'),
path('system/logs/files/', SystemLogFilesView.as_view(), name='system-log-files'),
path('system/version/', VersionView.as_view(), name='system-version'),
path('system/check-update/', CheckUpdateView.as_view(), name='system-check-update'),
# 黑名单管理PUT 全量替换模式)
path('blacklist/rules/', GlobalBlacklistView.as_view(), name='blacklist-rules'),

View File

@@ -6,16 +6,19 @@
- 认证相关视图:登录、登出、用户信息、修改密码
- 系统日志视图:实时日志查看
- 黑名单视图:全局黑名单规则管理
- 版本视图:系统版本和更新检查
"""
from .health_views import HealthCheckView
from .auth_views import LoginView, LogoutView, MeView, ChangePasswordView
from .system_log_views import SystemLogsView, SystemLogFilesView
from .blacklist_views import GlobalBlacklistView
from .version_views import VersionView, CheckUpdateView
__all__ = [
'HealthCheckView',
'LoginView', 'LogoutView', 'MeView', 'ChangePasswordView',
'SystemLogsView', 'SystemLogFilesView',
'GlobalBlacklistView',
'VersionView', 'CheckUpdateView',
]

View File

@@ -0,0 +1,136 @@
"""
系统版本相关视图
"""
import logging
from pathlib import Path
import requests
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.views import APIView
from apps.common.error_codes import ErrorCodes
from apps.common.response_helpers import error_response, success_response
logger = logging.getLogger(__name__)
# GitHub 仓库信息
GITHUB_REPO = "yyhuni/xingrin"
GITHUB_API_URL = f"https://api.github.com/repos/{GITHUB_REPO}/releases/latest"
GITHUB_RELEASES_URL = f"https://github.com/{GITHUB_REPO}/releases"
def get_current_version() -> str:
"""读取当前版本号"""
import os
# 方式1从环境变量读取Docker 容器中推荐)
version = os.environ.get('IMAGE_TAG', '')
if version:
return version
# 方式2从文件读取开发环境
possible_paths = [
Path('/app/VERSION'),
Path(__file__).parent.parent.parent.parent.parent / 'VERSION',
]
for path in possible_paths:
try:
return path.read_text(encoding='utf-8').strip()
except (FileNotFoundError, OSError):
continue
return "unknown"
def compare_versions(current: str, latest: str) -> bool:
"""
比较版本号,判断是否有更新
Returns:
True 表示有更新可用
"""
def parse_version(v: str) -> tuple:
v = v.lstrip('v')
parts = v.split('.')
result = []
for part in parts:
if '-' in part:
num, _ = part.split('-', 1)
result.append(int(num))
else:
result.append(int(part))
return tuple(result)
try:
return parse_version(latest) > parse_version(current)
except (ValueError, AttributeError):
return False
class VersionView(APIView):
"""获取当前系统版本"""
def get(self, _request: Request) -> Response:
"""获取当前版本信息"""
return success_response(data={
'version': get_current_version(),
'github_repo': GITHUB_REPO,
})
class CheckUpdateView(APIView):
"""检查系统更新"""
def get(self, _request: Request) -> Response:
"""
检查是否有新版本
Returns:
- current_version: 当前版本
- latest_version: 最新版本
- has_update: 是否有更新
- release_url: 发布页面 URL
- release_notes: 更新说明(如果有)
"""
current_version = get_current_version()
try:
response = requests.get(
GITHUB_API_URL,
headers={'Accept': 'application/vnd.github.v3+json'},
timeout=10
)
if response.status_code == 404:
return success_response(data={
'current_version': current_version,
'latest_version': current_version,
'has_update': False,
'release_url': GITHUB_RELEASES_URL,
'release_notes': None,
})
response.raise_for_status()
release_data = response.json()
latest_version = release_data.get('tag_name', current_version)
has_update = compare_versions(current_version, latest_version)
return success_response(data={
'current_version': current_version,
'latest_version': latest_version,
'has_update': has_update,
'release_url': release_data.get('html_url', GITHUB_RELEASES_URL),
'release_notes': release_data.get('body'),
'published_at': release_data.get('published_at'),
})
except requests.RequestException as e:
logger.warning("检查更新失败: %s", e)
return error_response(
code=ErrorCodes.SERVER_ERROR,
message="无法连接到 GitHub请稍后重试",
)
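A few worked comparisons for the helper above, using versions that appear in this compare view (the leading "v" and any "-dev" style suffix on a component are stripped before numeric comparison):

```python
print(compare_versions("v1.5.7", "v1.5.12"))       # True:  (1, 5, 12) > (1, 5, 7)
print(compare_versions("v1.5.12-dev", "v1.5.12"))  # False: equal once the suffix is dropped
print(compare_versions("v1.5.12", "v1.4.1"))       # False: latest is older
```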

View File

@@ -2,8 +2,9 @@
初始化默认扫描引擎
用法:
python manage.py init_default_engine # 只创建不存在的引擎(不覆盖已有)
python manage.py init_default_engine --force # 强制覆盖所有引擎配置
python manage.py init_default_engine # 只创建不存在的引擎(不覆盖已有)
python manage.py init_default_engine --force # 强制覆盖所有引擎配置
python manage.py init_default_engine --force-sub # 只覆盖子引擎,保留 full scan
cd /root/my-vulun-scan/docker
docker compose exec server python backend/manage.py init_default_engine --force
@@ -12,6 +13,7 @@
- 读取 engine_config_example.yaml 作为默认配置
- 创建 full scan默认引擎+ 各扫描类型的子引擎
- 默认不覆盖已有配置,加 --force 才会覆盖
- 加 --force-sub 只覆盖子引擎配置,保留用户自定义的 full scan
"""
from django.core.management.base import BaseCommand
@@ -30,11 +32,18 @@ class Command(BaseCommand):
parser.add_argument(
'--force',
action='store_true',
help='强制覆盖已有的引擎配置',
help='强制覆盖已有的引擎配置(包括 full scan 和子引擎)',
)
parser.add_argument(
'--force-sub',
action='store_true',
help='只覆盖子引擎配置,保留 full scan升级时使用',
)
def handle(self, *args, **options):
force = options.get('force', False)
force_sub = options.get('force_sub', False)
# 读取默认配置文件
config_path = Path(__file__).resolve().parent.parent.parent.parent / 'scan' / 'configs' / 'engine_config_example.yaml'
@@ -99,15 +108,22 @@ class Command(BaseCommand):
engine_name = f"{scan_type}"
sub_engine = ScanEngine.objects.filter(name=engine_name).first()
if sub_engine:
if force:
# force 或 force_sub 都会覆盖子引擎
if force or force_sub:
sub_engine.configuration = single_yaml
sub_engine.save()
self.stdout.write(self.style.SUCCESS(f' ✓ 子引擎 {engine_name} 配置已更新 (ID: {sub_engine.id})'))
self.stdout.write(self.style.SUCCESS(
f' ✓ 子引擎 {engine_name} 配置已更新 (ID: {sub_engine.id})'
))
else:
self.stdout.write(self.style.WARNING(f'{engine_name} 已存在,跳过(使用 --force 覆盖)'))
self.stdout.write(self.style.WARNING(
f'{engine_name} 已存在,跳过(使用 --force 覆盖)'
))
else:
sub_engine = ScanEngine.objects.create(
name=engine_name,
configuration=single_yaml,
)
self.stdout.write(self.style.SUCCESS(f' ✓ 子引擎 {engine_name} 已创建 (ID: {sub_engine.id})'))
self.stdout.write(self.style.SUCCESS(
f' ✓ 子引擎 {engine_name} 已创建 (ID: {sub_engine.id})'
))

View File

@@ -21,11 +21,11 @@ from apps.engine.services import NucleiTemplateRepoService
logger = logging.getLogger(__name__)
# 默认仓库配置
# 默认仓库配置(从 settings 读取,支持 Gitee 镜像)
DEFAULT_REPOS = [
{
"name": "nuclei-templates",
"repo_url": "https://github.com/projectdiscovery/nuclei-templates.git",
"repo_url": getattr(settings, 'NUCLEI_TEMPLATES_REPO_URL', 'https://github.com/projectdiscovery/nuclei-templates.git'),
"description": "Nuclei 官方模板仓库,包含数千个漏洞检测模板",
},
]

View File

@@ -156,10 +156,10 @@ class TaskDistributor:
# 降级策略:如果没有正常负载的,循环等待后重新检测
if not scored_workers:
if high_load_workers:
# 高负载等待参数(默认每 60 秒检测一次,最多 10 次
high_load_wait = getattr(settings, 'HIGH_LOAD_WAIT_SECONDS', 60)
high_load_max_retries = getattr(settings, 'HIGH_LOAD_MAX_RETRIES', 10)
# 高负载等待参数(每 2 分钟检测一次,最多等待 2 小时
high_load_wait = getattr(settings, 'HIGH_LOAD_WAIT_SECONDS', 120)
high_load_max_retries = getattr(settings, 'HIGH_LOAD_MAX_RETRIES', 60)
# 开始等待前发送高负载通知
high_load_workers.sort(key=lambda x: x[1])
_, _, first_cpu, first_mem = high_load_workers[0]
@@ -170,51 +170,51 @@ class TaskDistributor:
cpu=first_cpu,
mem=first_mem
)
for retry in range(high_load_max_retries):
logger.warning(
"所有 Worker 高负载,等待 %d 秒后重试... (%d/%d)",
high_load_wait, retry + 1, high_load_max_retries
)
time.sleep(high_load_wait)
# 重新获取负载数据
loads = worker_load_service.get_all_loads(worker_ids)
# 重新评估
scored_workers = []
high_load_workers = []
for worker in workers:
load = loads.get(worker.id)
if not load:
continue
cpu = load.get('cpu', 0)
mem = load.get('mem', 0)
score = cpu * 0.7 + mem * 0.3
if cpu > 85 or mem > 85:
high_load_workers.append((worker, score, cpu, mem))
else:
scored_workers.append((worker, score, cpu, mem))
# 如果有正常负载的 Worker跳出循环
if scored_workers:
logger.info("检测到正常负载 Worker结束等待")
break
# 超时或仍然高负载,选择负载最低的
# 超时后强制派发到负载最低的 Worker
if not scored_workers and high_load_workers:
high_load_workers.sort(key=lambda x: x[1])
best_worker, _, cpu, mem = high_load_workers[0]
logger.warning(
"等待超时,强制分发到高负载 Worker: %s (CPU: %.1f%%, MEM: %.1f%%)",
"等待 %d 分钟后仍高负载,强制分发到 Worker: %s (CPU: %.1f%%, MEM: %.1f%%)",
(high_load_wait * high_load_max_retries) // 60,
best_worker.name, cpu, mem
)
return best_worker
return best_worker
else:
logger.warning("没有可用的 Worker")
return None
@@ -279,17 +279,11 @@ class TaskDistributor:
# 环境变量SERVER_URL + IS_LOCAL其他配置容器启动时从配置中心获取
# IS_LOCAL 用于 Worker 向配置中心声明身份,决定返回的数据库地址
# Prefect 本地模式配置:启用 ephemeral server本地临时服务器
is_local_str = "true" if worker.is_local else "false"
env_vars = [
f"-e SERVER_URL={shlex.quote(server_url)}",
f"-e IS_LOCAL={is_local_str}",
f"-e WORKER_API_KEY={shlex.quote(settings.WORKER_API_KEY)}", # Worker API 认证密钥
"-e PREFECT_HOME=/tmp/.prefect", # 设置 Prefect 数据目录到可写位置
"-e PREFECT_SERVER_EPHEMERAL_ENABLED=true", # 启用 ephemeral server本地临时服务器
"-e PREFECT_SERVER_EPHEMERAL_STARTUP_TIMEOUT_SECONDS=120", # 增加启动超时时间
"-e PREFECT_SERVER_DATABASE_CONNECTION_URL=sqlite+aiosqlite:////tmp/.prefect/prefect.db", # 使用 /tmp 下的 SQLite
"-e PREFECT_LOGGING_LEVEL=WARNING", # 日志级别(减少 DEBUG 噪音)
]
# 挂载卷(统一挂载整个 /opt/xingrin 目录)
@@ -312,7 +306,11 @@ class TaskDistributor:
# - 本地 Workerinstall.sh 已预拉取镜像,直接使用本地版本
# - 远程 Workerdeploy 时已预拉取镜像,直接使用本地版本
# - 避免每次任务都检查 Docker Hub提升性能和稳定性
# OOM 优先级:--oom-score-adj=1000 让 Worker 在内存不足时优先被杀
# - 范围 -1000 到 1000值越大越容易被 OOM Killer 选中
# - 保护 server/nginx/frontend 等核心服务,确保 Web 界面可用
cmd = f'''docker run --rm -d --pull=missing {network_arg} \\
--oom-score-adj=1000 \\
{' '.join(env_vars)} \\
{' '.join(volumes)} \\
{self.docker_image} \\
@@ -445,34 +443,33 @@ class TaskDistributor:
def execute_scan_flow(
self,
scan_id: int,
target_name: str,
target_id: int,
target_name: str,
scan_workspace_dir: str,
engine_name: str,
scheduled_scan_name: str | None = None,
) -> tuple[bool, str, Optional[str], Optional[int]]:
"""
在远程或本地 Worker 上执行扫描 Flow
Args:
scan_id: 扫描任务 ID
target_name: 目标名称
target_id: 目标 ID
scan_workspace_dir: 扫描工作目录
engine_name: 引擎名称
scheduled_scan_name: 定时扫描任务名称(可选)
Returns:
(success, message, container_id, worker_id) 元组
Note:
engine_config 由 Flow 内部通过 scan_id 查询数据库获取
"""
logger.info("="*60)
logger.info("execute_scan_flow 开始")
logger.info(" scan_id: %s", scan_id)
logger.info(" target_name: %s", target_name)
logger.info(" target_id: %s", target_id)
logger.info(" target_name: %s", target_name)
logger.info(" scan_workspace_dir: %s", scan_workspace_dir)
logger.info(" engine_name: %s", engine_name)
logger.info(" docker_image: %s", self.docker_image)
@@ -491,23 +488,22 @@ class TaskDistributor:
# 3. 构建 docker run 命令
script_args = {
'scan_id': scan_id,
'target_name': target_name,
'target_id': target_id,
'scan_workspace_dir': scan_workspace_dir,
'engine_name': engine_name,
}
if scheduled_scan_name:
script_args['scheduled_scan_name'] = scheduled_scan_name
docker_cmd = self._build_docker_command(
worker=worker,
script_module='apps.scan.scripts.run_initiate_scan',
script_args=script_args,
)
logger.info(
"提交扫描任务到 Worker: %s - Scan ID: %d, Target: %s",
worker.name, scan_id, target_name
"提交扫描任务到 Worker: %s - Scan ID: %d, Target: %s (ID: %d)",
worker.name, scan_id, target_name, target_id
)
# 4. 执行 docker run本地直接执行远程通过 SSH

View File

@@ -24,18 +24,6 @@ SUBDOMAIN_DISCOVERY_COMMANDS = {
}
},
'amass_passive': {
# 先执行被动枚举,将结果写入 amass 内部数据库,然后从数据库中导出纯域名(names)到 output_file
# -silent 禁用进度条和其他输出
'base': "amass enum -passive -silent -d {domain} && amass subs -names -d {domain} > '{output_file}'"
},
'amass_active': {
# 先执行主动枚举 + 爆破,将结果写入 amass 内部数据库,然后从数据库中导出纯域名(names)到 output_file
# -silent 禁用进度条和其他输出
'base': "amass enum -active -silent -d {domain} -brute && amass subs -names -d {domain} > '{output_file}'"
},
'sublist3r': {
'base': "python3 '/usr/local/share/Sublist3r/sublist3r.py' -d {domain} -o '{output_file}'",
'optional': {
@@ -215,7 +203,7 @@ VULN_SCAN_COMMANDS = {
# -silent: 静默模式
# -l: 输入 URL 列表文件
# -t: 模板目录路径(支持多个仓库,多次 -t 由 template_args 直接拼接)
'base': "nuclei -j -silent -l '{endpoints_file}' {template_args}",
'base': "nuclei -j -silent -l '{input_file}' {template_args}",
'optional': {
'concurrency': '-c {concurrency}', # 并发数(默认 25)
'rate_limit': '-rl {rate_limit}', # 每秒请求数限制
@@ -226,7 +214,12 @@ VULN_SCAN_COMMANDS = {
'tags': '-tags {tags}', # 过滤标签
'exclude_tags': '-etags {exclude_tags}', # 排除标签
},
'input_type': 'endpoints_file',
# 支持多种输入类型,用户通过 scan_endpoints/scan_websites 选择
'input_types': ['endpoints_file', 'websites_file'],
'defaults': {
'scan_endpoints': False, # 默认不扫描 endpoints
'scan_websites': True, # 默认扫描 websites
},
},
}
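# A rough sketch of how a caller can honour the 'input_types'/'defaults' above when invoking the
# nuclei template: the helper below is hypothetical (not from the repository) and only illustrates
# mapping the chosen URL list onto the template's {input_file} placeholder.
def pick_nuclei_inputs(tool_config: dict, endpoints_file: str, websites_file: str) -> list[dict]:
    """Hypothetical helper: one command-params dict per enabled input source."""
    params = []
    if tool_config.get('scan-websites', True):    # mirrors defaults: scan_websites=True
        params.append({'input_file': websites_file})
    if tool_config.get('scan-endpoints', False):  # mirrors defaults: scan_endpoints=False
        params.append({'input_file': endpoints_file})
    return params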
@@ -263,11 +256,16 @@ COMMAND_TEMPLATES = {
'directory_scan': DIRECTORY_SCAN_COMMANDS,
'url_fetch': URL_FETCH_COMMANDS,
'vuln_scan': VULN_SCAN_COMMANDS,
'screenshot': {}, # 使用 Python 原生库(Playwright),无命令模板
}
# ==================== 扫描类型配置 ====================
# 执行阶段定义(按顺序执行)
# Stage 1: 资产发现 - 子域名 → 端口 → 站点探测 → 指纹识别
# Stage 2: URL 收集 - URL 获取 + 目录扫描(并行)
# Stage 3: 截图 - 在 URL 收集完成后执行,捕获更多发现的页面
# Stage 4: 漏洞扫描 - 最后执行
EXECUTION_STAGES = [
{
'mode': 'sequential',
@@ -277,6 +275,10 @@ EXECUTION_STAGES = [
'mode': 'parallel',
'flows': ['url_fetch', 'directory_scan']
},
{
'mode': 'sequential',
'flows': ['screenshot']
},
{
'mode': 'sequential',
'flows': ['vuln_scan']

View File

@@ -17,14 +17,6 @@ subdomain_discovery:
timeout: 3600 # 1小时
# threads: 10 # 并发 goroutine 数
amass_passive:
enabled: true
timeout: 3600
amass_active:
enabled: true # 主动枚举 + 爆破
timeout: 3600
sublist3r:
enabled: true
timeout: 3600
@@ -62,7 +54,7 @@ port_scan:
threads: 200 # 并发连接数(默认 5)
# ports: 1-65535 # 扫描端口范围(默认 1-65535)
top-ports: 100 # 扫描 nmap top 100 端口
rate: 10 # 扫描速率(默认 10)
rate: 50 # 扫描速率
naabu_passive:
enabled: true
@@ -101,6 +93,16 @@ directory_scan:
match-codes: 200,201,301,302,401,403 # 匹配的 HTTP 状态码
# rate: 0 # 每秒请求数(默认 0 不限制)
screenshot:
# ==================== 网站截图 ====================
# 使用 Playwright 对网站进行截图,保存为 WebP 格式
# 在 URL 收集(url_fetch、directory_scan)完成后的 Stage 3 执行
tools:
playwright:
enabled: true
concurrency: 5 # 并发截图数(默认 5)
url_sources: [websites] # URL 来源,当前对 website 截图,还可以用 [websites, endpoints]
url_fetch:
# ==================== URL 获取 ====================
tools:
@@ -156,7 +158,9 @@ vuln_scan:
nuclei:
enabled: true
# timeout: auto # 自动计算(根据 endpoints 行数)
# timeout: auto # 自动计算(根据输入 URL 行数)
scan-endpoints: false # 是否扫描 endpoints(默认关闭)
scan-websites: true # 是否扫描 websites(默认开启)
template-repo-names: # 模板仓库列表,对应「Nuclei 模板」中的仓库名
- nuclei-templates
# - nuclei-custom # 可追加自定义仓库

View File

@@ -0,0 +1,200 @@
"""
扫描流程装饰器模块
提供轻量级的 @scan_flow 和 @scan_task 装饰器,替代 Prefect 的 @flow 和 @task。
核心功能:
- @scan_flow: 状态管理、通知、性能追踪
- @scan_task: 重试逻辑(大部分 task 不需要重试,可直接移除装饰器)
设计原则:
- 保持与 Prefect 装饰器相同的使用方式
- 零依赖,无额外内存开销
- 保留原函数签名和返回值
"""
import functools
import logging
import time
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Callable, Optional
logger = logging.getLogger(__name__)
@dataclass
class FlowContext:
"""
Flow 执行上下文
替代 Prefect 的 Flow、FlowRun、State 参数,传递给回调函数。
"""
flow_name: str
stage_name: str
scan_id: Optional[int] = None
target_id: Optional[int] = None
target_name: Optional[str] = None
parameters: dict = field(default_factory=dict)
start_time: datetime = field(default_factory=datetime.now)
end_time: Optional[datetime] = None
result: Any = None
error: Optional[Exception] = None
error_message: Optional[str] = None
def scan_flow(
name: Optional[str] = None,
stage_name: Optional[str] = None,
on_running: Optional[list[Callable]] = None,
on_completion: Optional[list[Callable]] = None,
on_failure: Optional[list[Callable]] = None,
log_prints: bool = True, # 保持与 Prefect 兼容,但不使用
):
"""
扫描流程装饰器
替代 Prefect 的 @flow 装饰器,提供:
- 自动状态管理(start_stage/complete_stage/fail_stage)
- 生命周期回调(on_running/on_completion/on_failure)
- 性能追踪(FlowPerformanceTracker)
- 失败通知
Args:
name: Flow 名称,默认使用函数名
stage_name: 阶段名称,默认使用 name
on_running: 流程开始时的回调列表
on_completion: 流程完成时的回调列表
on_failure: 流程失败时的回调列表
log_prints: 保持与 Prefect 兼容,不使用
Usage:
@scan_flow(name="site_scan", on_running=[on_scan_flow_running])
def site_scan_flow(scan_id: int, target_id: int, ...):
...
"""
def decorator(func: Callable) -> Callable:
flow_name = name or func.__name__
actual_stage_name = stage_name or flow_name
@functools.wraps(func)
def wrapper(*args, **kwargs) -> Any:
# 提取参数
scan_id = kwargs.get('scan_id')
target_id = kwargs.get('target_id')
target_name = kwargs.get('target_name')
# 创建上下文
context = FlowContext(
flow_name=flow_name,
stage_name=actual_stage_name,
scan_id=scan_id,
target_id=target_id,
target_name=target_name,
parameters=kwargs.copy(),
start_time=datetime.now(),
)
# 执行 on_running 回调
if on_running:
for callback in on_running:
try:
callback(context)
except Exception as e:
logger.warning("on_running 回调执行失败: %s", e)
try:
# 执行原函数
result = func(*args, **kwargs)
# 更新上下文
context.end_time = datetime.now()
context.result = result
# 执行 on_completion 回调
if on_completion:
for callback in on_completion:
try:
callback(context)
except Exception as e:
logger.warning("on_completion 回调执行失败: %s", e)
return result
except Exception as e:
# 更新上下文
context.end_time = datetime.now()
context.error = e
context.error_message = str(e)
# 执行 on_failure 回调
if on_failure:
for callback in on_failure:
try:
callback(context)
except Exception as cb_error:
logger.warning("on_failure 回调执行失败: %s", cb_error)
# 重新抛出异常
raise
return wrapper
return decorator
def scan_task(
retries: int = 0,
retry_delay: float = 1.0,
name: Optional[str] = None, # 保持与 Prefect 兼容
):
"""
扫描任务装饰器
替代 Prefect 的 @task 装饰器,提供重试能力。
注意:当前代码中大部分 @task 都是 retries=0,可以直接移除装饰器。
只有需要重试的 task 才需要使用此装饰器。
Args:
retries: 失败后重试次数,默认 0不重试
retry_delay: 重试间隔(秒),默认 1.0
name: 任务名称,保持与 Prefect 兼容,不使用
Usage:
@scan_task(retries=3, retry_delay=2.0)
def run_scan_tool(command: str, timeout: int):
...
"""
def decorator(func: Callable) -> Callable:
task_name = name or func.__name__
@functools.wraps(func)
def wrapper(*args, **kwargs) -> Any:
last_exception = None
for attempt in range(retries + 1):
try:
return func(*args, **kwargs)
except Exception as e:
last_exception = e
if attempt < retries:
logger.warning(
"任务 %s 重试 %d/%d: %s",
task_name, attempt + 1, retries, e
)
time.sleep(retry_delay)
else:
logger.error(
"任务 %s 重试耗尽 (%d 次): %s",
task_name, retries + 1, e
)
# 重试耗尽,抛出最后一个异常
raise last_exception
# 添加 submit 方法以保持与 Prefect task.submit() 的兼容性
# 注意:这只是为了迁移过渡,最终应该使用 ThreadPoolExecutor
wrapper.fn = func
return wrapper
return decorator
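# Self-contained usage sketch of the two decorators above (flow/task names and the callback are
# illustrative only):
def log_failure(ctx: FlowContext) -> None:
    logger.error("flow %s failed (scan_id=%s): %s", ctx.flow_name, ctx.scan_id, ctx.error_message)

@scan_task(retries=2, retry_delay=0.5)
def probe(host: str) -> str:
    return f"probed {host}"

@scan_flow(name="demo_scan", on_failure=[log_failure])
def demo_scan_flow(scan_id: int, target_id: int, target_name: str) -> dict:
    return {"probe": probe(target_name)}

# demo_scan_flow(scan_id=1, target_id=1, target_name="example.com") -> {'probe': 'probed example.com'}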

View File

@@ -10,30 +10,31 @@
- 配置由 YAML 解析
"""
# Django 环境初始化(导入即生效)
from apps.common.prefect_django_setup import setup_django_for_prefect
from prefect import flow
from prefect.task_runners import ThreadPoolTaskRunner
import hashlib
import logging
import os
import subprocess
from datetime import datetime
from pathlib import Path
from typing import List, Tuple
from apps.scan.tasks.directory_scan import (
export_sites_task,
run_and_stream_save_directories_task
)
from concurrent.futures import ThreadPoolExecutor
from apps.scan.decorators import scan_flow
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
on_scan_flow_running,
)
from apps.scan.tasks.directory_scan import (
export_sites_task,
run_and_stream_save_directories_task,
)
from apps.scan.utils import (
build_scan_command,
ensure_wordlist_local,
user_log,
wait_for_system_load,
)
from apps.scan.utils import config_parser, build_scan_command, ensure_wordlist_local, user_log
logger = logging.getLogger(__name__)
@@ -45,608 +46,408 @@ def calculate_directory_scan_timeout(
tool_config: dict,
base_per_word: float = 1.0,
min_timeout: int = 60,
max_timeout: int = 7200
) -> int:
"""
根据字典行数计算目录扫描超时时间
计算公式:超时时间 = 字典行数 × 每个单词基础时间
超时范围:60秒 ~ 2小时(7200秒)
超时范围:最小 60 秒,无上限
Args:
tool_config: 工具配置字典,包含 wordlist 路径
base_per_word: 每个单词的基础时间(秒),默认 1.0秒
min_timeout: 最小超时时间(秒),默认 60秒
max_timeout: 最大超时时间(秒),默认 7200秒(2小时)
Returns:
int: 计算出的超时时间(秒),范围 60 ~ 7200
Example:
# 1000行字典 × 1.0秒 = 1000秒 → 未超过 7200秒上限,取 1000秒
# 10000行字典 × 1.0秒 = 10000秒 → 超过上限,取 7200秒(最大值)
timeout = calculate_directory_scan_timeout(
tool_config={'wordlist': '/path/to/wordlist.txt'}
)
int: 计算出的超时时间(秒)
"""
import os
wordlist_path = tool_config.get('wordlist')
if not wordlist_path:
logger.warning("工具配置中未指定 wordlist使用默认超时: %d", min_timeout)
return min_timeout
wordlist_path = os.path.expanduser(wordlist_path)
if not os.path.exists(wordlist_path):
logger.warning("字典文件不存在: %s,使用默认超时: %d", wordlist_path, min_timeout)
return min_timeout
try:
# 从 tool_config 中获取 wordlist 路径
wordlist_path = tool_config.get('wordlist')
if not wordlist_path:
logger.warning("工具配置中未指定 wordlist使用默认超时: %d", min_timeout)
return min_timeout
# 展开用户目录(~
wordlist_path = os.path.expanduser(wordlist_path)
# 检查文件是否存在
if not os.path.exists(wordlist_path):
logger.warning("字典文件不存在: %s,使用默认超时: %d", wordlist_path, min_timeout)
return min_timeout
# 使用 wc -l 快速统计字典行数
result = subprocess.run(
['wc', '-l', wordlist_path],
capture_output=True,
text=True,
check=True
)
# wc -l 输出格式:行数 + 空格 + 文件名
line_count = int(result.stdout.strip().split()[0])
# 计算超时时间
timeout = int(line_count * base_per_word)
# 设置合理的下限(不再设置上限)
timeout = max(min_timeout, timeout)
timeout = max(min_timeout, int(line_count * base_per_word))
logger.info(
"目录扫描超时计算 - 字典: %s, 行数: %d, 基础时间: %.3f秒/词, 计算超时: %d",
wordlist_path, line_count, base_per_word, timeout
)
return timeout
except subprocess.CalledProcessError as e:
logger.error("统计字典行数失败: %s", e)
# 失败时返回默认超时
return min_timeout
except (ValueError, IndexError) as e:
logger.error("解析字典行数失败: %s", e)
return min_timeout
except Exception as e:
logger.error("计算超时时间异常: %s", e)
except (subprocess.CalledProcessError, ValueError, IndexError) as e:
logger.error("计算超时时间失败: %s", e)
return min_timeout
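# Quick worked example of the formula above, with the wc -l line count replaced by literal numbers:
assert max(60, int(5000 * 1.0)) == 5000   # 5000-line wordlist -> 5000 seconds
assert max(60, int(30 * 1.0)) == 60       # tiny 30-line wordlist -> floor of 60 seconds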
def _get_max_workers(tool_config: dict, default: int = DEFAULT_MAX_WORKERS) -> int:
"""
从单个工具配置中获取 max_workers 参数
Args:
tool_config: 单个工具的配置字典,如 {'max_workers': 10, 'threads': 5, ...}
default: 默认值,默认为 5
Returns:
int: max_workers 值
"""
"""从单个工具配置中获取 max_workers 参数"""
if not isinstance(tool_config, dict):
return default
# 支持 max_workers 和 max-workers(YAML 中划线会被转换)
max_workers = tool_config.get('max_workers') or tool_config.get('max-workers')
if max_workers is not None and isinstance(max_workers, int) and max_workers > 0:
if isinstance(max_workers, int) and max_workers > 0:
return max_workers
return default
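# Examples: both spellings are accepted; a missing or invalid value falls back to the default.
assert _get_max_workers({'max_workers': 10}) == 10
assert _get_max_workers({'max-workers': 8}) == 8
assert _get_max_workers({'threads': 50}) == DEFAULT_MAX_WORKERS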
def _export_site_urls(target_id: int, target_name: str, directory_scan_dir: Path) -> tuple[str, int]:
def _export_site_urls(
target_id: int,
directory_scan_dir: Path,
provider,
) -> Tuple[str, int]:
"""
导出目标下的所有站点 URL 到文件(支持懒加载)
导出目标下的所有站点 URL 到文件
Args:
target_id: 目标 ID
target_name: 目标名称(用于懒加载创建默认站点)
directory_scan_dir: 目录扫描目录
provider: TargetProvider 实例
Returns:
tuple: (sites_file, site_count)
Raises:
ValueError: 站点数量为 0
"""
logger.info("Step 1: 导出目标的所有站点 URL")
sites_file = str(directory_scan_dir / 'sites.txt')
export_result = export_sites_task(
target_id=target_id,
output_file=sites_file,
batch_size=1000 # 每次读取 1000 条,优化内存占用
provider=provider,
)
site_count = export_result['total_count']
logger.info(
"✓ 站点 URL 导出完成 - 文件: %s, 数量: %d",
export_result['output_file'],
site_count
)
if site_count == 0:
logger.warning("目标下没有站点,无法执行目录扫描")
# 不抛出异常,由上层决定如何处理
# raise ValueError("目标下没有站点,无法执行目录扫描")
return export_result['output_file'], site_count
def _run_scans_sequentially(
enabled_tools: dict,
sites_file: str,
directory_scan_dir: Path,
scan_id: int,
target_id: int,
site_count: int,
target_name: str
) -> tuple[int, int, list]:
"""
串行执行目录扫描任务(支持多工具)- 已废弃,保留用于兼容
Args:
enabled_tools: 启用的工具配置字典
sites_file: 站点文件路径
directory_scan_dir: 目录扫描目录
scan_id: 扫描任务 ID
target_id: 目标 ID
site_count: 站点数量
target_name: 目标名称(用于错误日志)
Returns:
tuple: (total_directories, processed_sites, failed_sites)
"""
# 读取站点列表
sites = []
with open(sites_file, 'r', encoding='utf-8') as f:
for line in f:
site_url = line.strip()
if site_url:
sites.append(site_url)
logger.info("准备扫描 %d 个站点,使用工具: %s", len(sites), ', '.join(enabled_tools.keys()))
total_directories = 0
processed_sites_set = set() # 使用 set 避免重复计数
failed_sites = []
# 遍历每个工具
for tool_name, tool_config in enabled_tools.items():
logger.info("="*60)
logger.info("使用工具: %s", tool_name)
logger.info("="*60)
# 如果配置了 wordlist_name,则先确保本地存在对应的字典文件(含 hash 校验)
wordlist_name = tool_config.get('wordlist_name')
if wordlist_name:
try:
local_wordlist_path = ensure_wordlist_local(wordlist_name)
tool_config['wordlist'] = local_wordlist_path
except Exception as exc:
logger.error("为工具 %s 准备字典失败: %s", tool_name, exc)
# 当前工具无法执行,将所有站点视为失败,继续下一个工具
failed_sites.extend(sites)
continue
# 逐个站点执行扫描
for idx, site_url in enumerate(sites, 1):
logger.info(
"[%d/%d] 开始扫描站点: %s (工具: %s)",
idx, len(sites), site_url, tool_name
)
# 使用统一的命令构建器
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='directory_scan',
command_params={
'url': site_url
},
tool_config=tool_config
)
except Exception as e:
logger.error(
"✗ [%d/%d] 构建 %s 命令失败: %s - 站点: %s",
idx, len(sites), tool_name, e, site_url
)
failed_sites.append(site_url)
continue
# 单个站点超时:从配置中获取(支持 'auto' 动态计算)
# ffuf 逐个站点扫描,timeout 就是单个站点的超时时间
site_timeout = tool_config.get('timeout', 300)
if site_timeout == 'auto':
# 动态计算超时时间(基于字典行数)
site_timeout = calculate_directory_scan_timeout(tool_config)
logger.info(f"✓ 工具 {tool_name} 动态计算 timeout: {site_timeout}")
# 生成日志文件路径
from datetime import datetime
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = directory_scan_dir / f"{tool_name}_{timestamp}_{idx}.log"
try:
# 直接调用 task串行执行
result = run_and_stream_save_directories_task(
cmd=command,
tool_name=tool_name, # 新增:工具名称
scan_id=scan_id,
target_id=target_id,
site_url=site_url,
cwd=str(directory_scan_dir),
shell=True,
batch_size=1000,
timeout=site_timeout,
log_file=str(log_file) # 新增:日志文件路径
)
total_directories += result.get('created_directories', 0)
processed_sites_set.add(site_url) # 使用 set 记录成功的站点
logger.info(
"✓ [%d/%d] 站点扫描完成: %s - 发现 %d 个目录",
idx, len(sites), site_url,
result.get('created_directories', 0)
)
except subprocess.TimeoutExpired as exc:
# 超时异常单独处理
failed_sites.append(site_url)
logger.warning(
"⚠️ [%d/%d] 站点扫描超时: %s - 超时配置: %d\n"
"注意:超时前已解析的目录数据已保存到数据库,但扫描未完全完成。",
idx, len(sites), site_url, site_timeout
)
except Exception as exc:
# 其他异常
failed_sites.append(site_url)
logger.error(
"✗ [%d/%d] 站点扫描失败: %s - 错误: %s",
idx, len(sites), site_url, exc
)
# 每 10 个站点输出进度
if idx % 10 == 0:
logger.info(
"进度: %d/%d (%.1f%%) - 已发现 %d 个目录",
idx, len(sites), idx/len(sites)*100, total_directories
)
# 计算成功和失败的站点数
processed_count = len(processed_sites_set)
if failed_sites:
logger.warning(
"部分站点扫描失败: %d/%d",
len(failed_sites), len(sites)
)
logger.info(
"✓ 串行目录扫描执行完成 - 成功: %d/%d, 失败: %d, 总目录数: %d",
processed_count, len(sites), len(failed_sites), total_directories
)
return total_directories, processed_count, failed_sites
def _generate_log_filename(tool_name: str, site_url: str, directory_scan_dir: Path) -> Path:
"""
生成唯一的日志文件名
使用 URL 的 hash 确保并发时不会冲突
Args:
tool_name: 工具名称
site_url: 站点 URL
directory_scan_dir: 目录扫描目录
Returns:
Path: 日志文件路径
"""
url_hash = hashlib.md5(site_url.encode()).hexdigest()[:8]
def _generate_log_filename(
tool_name: str,
site_url: str,
directory_scan_dir: Path
) -> Path:
"""生成唯一的日志文件名(使用 URL 的 hash 确保并发时不会冲突)"""
url_hash = hashlib.md5(
site_url.encode(),
usedforsecurity=False
).hexdigest()[:8]
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
return directory_scan_dir / f"{tool_name}_{url_hash}_{timestamp}.log"
def _prepare_tool_wordlist(tool_name: str, tool_config: dict) -> bool:
"""准备工具的字典文件,返回是否成功"""
wordlist_name = tool_config.get('wordlist_name')
if not wordlist_name:
return True
try:
local_wordlist_path = ensure_wordlist_local(wordlist_name)
tool_config['wordlist'] = local_wordlist_path
return True
except Exception as exc:
logger.error("为工具 %s 准备字典失败: %s", tool_name, exc)
return False
def _build_scan_params(
tool_name: str,
tool_config: dict,
sites: List[str],
directory_scan_dir: Path,
site_timeout: int
) -> Tuple[List[dict], List[str]]:
"""构建所有站点的扫描参数,返回 (scan_params_list, failed_sites)"""
scan_params_list = []
failed_sites = []
for idx, site_url in enumerate(sites, 1):
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='directory_scan',
command_params={'url': site_url},
tool_config=tool_config
)
log_file = _generate_log_filename(tool_name, site_url, directory_scan_dir)
scan_params_list.append({
'idx': idx,
'site_url': site_url,
'command': command,
'log_file': str(log_file),
'timeout': site_timeout
})
except Exception as e:
logger.error(
"✗ [%d/%d] 构建 %s 命令失败: %s - 站点: %s",
idx, len(sites), tool_name, e, site_url
)
failed_sites.append(site_url)
return scan_params_list, failed_sites
def _execute_batch(
batch_params: List[dict],
tool_name: str,
scan_id: int,
target_id: int,
directory_scan_dir: Path,
total_sites: int
) -> Tuple[int, List[str]]:
"""执行一批扫描任务,返回 (directories_found, failed_sites)"""
directories_found = 0
failed_sites = []
# 使用 ThreadPoolExecutor 并行执行
with ThreadPoolExecutor(max_workers=len(batch_params)) as executor:
futures = []
for params in batch_params:
future = executor.submit(
run_and_stream_save_directories_task,
cmd=params['command'],
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
site_url=params['site_url'],
cwd=str(directory_scan_dir),
shell=True,
batch_size=1000,
timeout=params['timeout'],
log_file=params['log_file']
)
futures.append((params['idx'], params['site_url'], future))
# 等待结果
for idx, site_url, future in futures:
try:
result = future.result()
dirs_count = result.get('created_directories', 0)
directories_found += dirs_count
logger.info(
"✓ [%d/%d] 站点扫描完成: %s - 发现 %d 个目录",
idx, total_sites, site_url, dirs_count
)
except Exception as exc:
failed_sites.append(site_url)
if 'timeout' in str(exc).lower():
logger.warning(
"⚠️ [%d/%d] 站点扫描超时: %s - 错误: %s",
idx, total_sites, site_url, exc
)
else:
logger.error(
"✗ [%d/%d] 站点扫描失败: %s - 错误: %s",
idx, total_sites, site_url, exc
)
return directories_found, failed_sites
def _run_scans_concurrently(
enabled_tools: dict,
sites_file: str,
directory_scan_dir: Path,
scan_id: int,
target_id: int,
site_count: int,
target_name: str
) -> Tuple[int, int, List[str]]:
"""
并发执行目录扫描任务(使用 ThreadPoolTaskRunner)
Args:
enabled_tools: 启用的工具配置字典
sites_file: 站点文件路径
directory_scan_dir: 目录扫描目录
scan_id: 扫描任务 ID
target_id: 目标 ID
site_count: 站点数量
target_name: 目标名称(用于错误日志)
并发执行目录扫描任务
Returns:
tuple: (total_directories, processed_sites, failed_sites)
"""
# 读取站点列表
sites: List[str] = []
with open(sites_file, 'r', encoding='utf-8') as f:
for line in f:
site_url = line.strip()
if site_url:
sites.append(site_url)
sites = [line.strip() for line in f if line.strip()]
if not sites:
logger.warning("站点列表为空")
return 0, 0, []
logger.info(
"准备并发扫描 %d 个站点,使用工具: %s",
len(sites), ', '.join(enabled_tools.keys())
)
total_directories = 0
processed_sites_count = 0
failed_sites: List[str] = []
# 遍历每个工具
for tool_name, tool_config in enabled_tools.items():
# 每个工具独立获取 max_workers 配置
max_workers = _get_max_workers(tool_config)
logger.info("="*60)
logger.info("=" * 60)
logger.info("使用工具: %s (并发模式, max_workers=%d)", tool_name, max_workers)
logger.info("="*60)
logger.info("=" * 60)
user_log(scan_id, "directory_scan", f"Running {tool_name}")
# 如果配置了 wordlist_name,则先确保本地存在对应的字典文件(含 hash 校验)
wordlist_name = tool_config.get('wordlist_name')
if wordlist_name:
try:
local_wordlist_path = ensure_wordlist_local(wordlist_name)
tool_config['wordlist'] = local_wordlist_path
except Exception as exc:
logger.error("为工具 %s 准备字典失败: %s", tool_name, exc)
# 当前工具无法执行,将所有站点视为失败,继续下一个工具
failed_sites.extend(sites)
continue
# 计算超时时间(所有站点共用)
# 准备字典文件
if not _prepare_tool_wordlist(tool_name, tool_config):
failed_sites.extend(sites)
continue
# 计算超时时间
site_timeout = tool_config.get('timeout', 300)
if site_timeout == 'auto':
site_timeout = calculate_directory_scan_timeout(tool_config)
logger.info(f"✓ 工具 {tool_name} 动态计算 timeout: {site_timeout}")
# 准备所有站点的扫描参数
scan_params_list = []
for idx, site_url in enumerate(sites, 1):
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='directory_scan',
command_params={'url': site_url},
tool_config=tool_config
)
log_file = _generate_log_filename(tool_name, site_url, directory_scan_dir)
scan_params_list.append({
'idx': idx,
'site_url': site_url,
'command': command,
'log_file': str(log_file),
'timeout': site_timeout
})
except Exception as e:
logger.error(
"✗ [%d/%d] 构建 %s 命令失败: %s - 站点: %s",
idx, len(sites), tool_name, e, site_url
)
failed_sites.append(site_url)
logger.info("✓ 工具 %s 动态计算 timeout: %d", tool_name, site_timeout)
# 构建扫描参数
scan_params_list, build_failed = _build_scan_params(
tool_name, tool_config, sites, directory_scan_dir, site_timeout
)
failed_sites.extend(build_failed)
if not scan_params_list:
logger.warning("没有有效的扫描任务")
continue
# ============================================================
# 分批执行策略:控制实际并发的 ffuf 进程数
# ============================================================
# 分批执行
total_tasks = len(scan_params_list)
logger.info("开始分批执行 %d 个扫描任务(每批 %d 个)...", total_tasks, max_workers)
# 进度里程碑跟踪
last_progress_percent = 0
tool_directories = 0
tool_processed = 0
batch_num = 0
for batch_start in range(0, total_tasks, max_workers):
batch_end = min(batch_start + max_workers, total_tasks)
batch_params = scan_params_list[batch_start:batch_end]
batch_num += 1
logger.info("执行第 %d 批任务(%d-%d/%d...", batch_num, batch_start + 1, batch_end, total_tasks)
# 提交当前批次的任务(非阻塞,立即返回 future
futures = []
for params in batch_params:
future = run_and_stream_save_directories_task.submit(
cmd=params['command'],
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
site_url=params['site_url'],
cwd=str(directory_scan_dir),
shell=True,
batch_size=1000,
timeout=params['timeout'],
log_file=params['log_file']
)
futures.append((params['idx'], params['site_url'], future))
# 等待当前批次所有任务完成(阻塞,确保本批完成后再启动下一批)
for idx, site_url, future in futures:
try:
result = future.result() # 阻塞等待单个任务完成
directories_found = result.get('created_directories', 0)
total_directories += directories_found
tool_directories += directories_found
processed_sites_count += 1
tool_processed += 1
logger.info(
"✓ [%d/%d] 站点扫描完成: %s - 发现 %d 个目录",
idx, len(sites), site_url, directories_found
)
except Exception as exc:
failed_sites.append(site_url)
if 'timeout' in str(exc).lower() or isinstance(exc, subprocess.TimeoutExpired):
logger.warning(
"⚠️ [%d/%d] 站点扫描超时: %s - 错误: %s",
idx, len(sites), site_url, exc
)
else:
logger.error(
"✗ [%d/%d] 站点扫描失败: %s - 错误: %s",
idx, len(sites), site_url, exc
)
batch_num = batch_start // max_workers + 1
logger.info(
"执行第 %d 批任务(%d-%d/%d...",
batch_num, batch_start + 1, batch_end, total_tasks
)
dirs_found, batch_failed = _execute_batch(
batch_params, tool_name, scan_id, target_id,
directory_scan_dir, len(sites)
)
total_directories += dirs_found
tool_directories += dirs_found
tool_processed += len(batch_params) - len(batch_failed)
processed_sites_count += len(batch_params) - len(batch_failed)
failed_sites.extend(batch_failed)
# 进度里程碑:每 20% 输出一次
current_progress = int((batch_end / total_tasks) * 100)
if current_progress >= last_progress_percent + 20:
user_log(scan_id, "directory_scan", f"Progress: {batch_end}/{total_tasks} sites scanned")
user_log(
scan_id, "directory_scan",
f"Progress: {batch_end}/{total_tasks} sites scanned"
)
last_progress_percent = (current_progress // 20) * 20
# 工具完成日志(开发者日志 + 用户日志)
logger.info(
"✓ 工具 %s 执行完成 - 已处理站点: %d/%d, 发现目录: %d",
tool_name, tool_processed, total_tasks, tool_directories
)
user_log(scan_id, "directory_scan", f"{tool_name} completed: found {tool_directories} directories")
# 输出汇总信息
if failed_sites:
logger.warning(
"部分站点扫描失败: %d/%d",
len(failed_sites), len(sites)
user_log(
scan_id, "directory_scan",
f"{tool_name} completed: found {tool_directories} directories"
)
if failed_sites:
logger.warning("部分站点扫描失败: %d/%d", len(failed_sites), len(sites))
logger.info(
"✓ 并发目录扫描执行完成 - 成功: %d/%d, 失败: %d, 总目录数: %d",
processed_sites_count, len(sites), len(failed_sites), total_directories
)
return total_directories, processed_sites_count, failed_sites
@flow(
name="directory_scan",
log_prints=True,
@scan_flow(
name="directory_scan",
on_running=[on_scan_flow_running],
on_completion=[on_scan_flow_completed],
on_failure=[on_scan_flow_failed],
)
def directory_scan_flow(
scan_id: int,
target_name: str,
target_id: int,
scan_workspace_dir: str,
enabled_tools: dict
enabled_tools: dict,
provider,
) -> dict:
"""
目录扫描 Flow
主要功能:
1. 从 target 获取所有站点的 URL
2. 对每个站点 URL 执行目录扫描(支持 ffuf 等工具)
3. 流式保存扫描结果到数据库 Directory 表
工作流程:
Step 0: 创建工作目录
Step 1: 导出站点 URL 列表到文件(供扫描工具使用)
Step 2: 验证工具配置
Step 3: 并发执行扫描工具并实时保存结果(使用 ThreadPoolTaskRunner)
ffuf 输出字段:
- url: 发现的目录/文件 URL
- length: 响应内容长度
- status: HTTP 状态码
- words: 响应内容单词数
- lines: 响应内容行数
- content_type: 内容类型
- duration: 请求耗时
Args:
scan_id: 扫描任务 ID
target_name: 目标名称
target_id: 目标 ID
scan_workspace_dir: 扫描工作空间目录
enabled_tools: 启用的工具配置字典
provider: TargetProvider 实例
Returns:
dict: {
'success': bool,
'scan_id': int,
'target': str,
'scan_workspace_dir': str,
'sites_file': str,
'site_count': int,
'total_directories': int, # 发现的总目录数
'processed_sites': int, # 成功处理的站点数
'failed_sites_count': int, # 失败的站点数
'executed_tasks': list
}
Raises:
ValueError: 参数错误
RuntimeError: 执行失败
dict: 扫描结果
"""
try:
wait_for_system_load(context="directory_scan_flow")
# 从 provider 获取 target_name
target_name = provider.get_target_name()
if not target_name:
raise ValueError("无法获取 Target 名称")
logger.info(
"="*60 + "\n" +
"开始目录扫描\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
"开始目录扫描 - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
)
user_log(scan_id, "directory_scan", "Starting directory scan")
# 参数验证
if scan_id is None:
raise ValueError("scan_id 不能为空")
if not target_name:
raise ValueError("target_name 不能为空")
if target_id is None:
raise ValueError("target_id 不能为空")
if not scan_workspace_dir:
raise ValueError("scan_workspace_dir 不能为空")
if not enabled_tools:
raise ValueError("enabled_tools 不能为空")
# Step 0: 创建工作目录
from apps.scan.utils import setup_scan_directory
directory_scan_dir = setup_scan_directory(scan_workspace_dir, 'directory_scan')
# Step 1: 导出站点 URL(支持懒加载)
sites_file, site_count = _export_site_urls(target_id, target_name, directory_scan_dir)
# Step 1: 导出站点 URL
sites_file, site_count = _export_site_urls(
target_id, directory_scan_dir, provider
)
if site_count == 0:
logger.warning("跳过目录扫描:没有站点可扫描 - Scan ID: %s", scan_id)
user_log(scan_id, "directory_scan", "Skipped: no sites to scan", "warning")
@@ -662,16 +463,16 @@ def directory_scan_flow(
'failed_sites_count': 0,
'executed_tasks': ['export_sites']
}
# Step 2: 工具配置信息
logger.info("Step 2: 工具配置信息")
tool_info = []
for tool_name, tool_config in enabled_tools.items():
mw = _get_max_workers(tool_config)
tool_info.append(f"{tool_name}(max_workers={mw})")
tool_info = [
f"{name}(max_workers={_get_max_workers(cfg)})"
for name, cfg in enabled_tools.items()
]
logger.info("✓ 启用工具: %s", ', '.join(tool_info))
# Step 3: 并发执行扫描工具并实时保存结果
# Step 3: 并发执行扫描
logger.info("Step 3: 并发执行扫描工具并实时保存结果")
total_directories, processed_sites, failed_sites = _run_scans_concurrently(
enabled_tools=enabled_tools,
@@ -679,19 +480,20 @@ def directory_scan_flow(
directory_scan_dir=directory_scan_dir,
scan_id=scan_id,
target_id=target_id,
site_count=site_count,
target_name=target_name
)
# 检查是否所有站点都失败
if processed_sites == 0 and site_count > 0:
logger.warning("所有站点扫描均失败 - 总站点数: %d, 失败数: %d", site_count, len(failed_sites))
# 不抛出异常,让扫描继续
# 记录 Flow 完成
logger.warning(
"所有站点扫描均失败 - 总站点数: %d, 失败数: %d",
site_count, len(failed_sites)
)
logger.info("✓ 目录扫描完成 - 发现目录: %d", total_directories)
user_log(scan_id, "directory_scan", f"directory_scan completed: found {total_directories} directories")
user_log(
scan_id, "directory_scan",
f"directory_scan completed: found {total_directories} directories"
)
return {
'success': True,
'scan_id': scan_id,
@@ -704,7 +506,7 @@ def directory_scan_flow(
'failed_sites_count': len(failed_sites),
'executed_tasks': ['export_sites', 'run_and_stream_save_directories']
}
except Exception as e:
logger.exception("目录扫描失败: %s", e)
raise
raise

View File

@@ -10,364 +10,278 @@
- 流式处理输出,批量更新数据库
"""
# Django 环境初始化(导入即生效)
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging
import os
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Optional
from prefect import flow
from apps.scan.decorators import scan_flow
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
on_scan_flow_running,
)
from apps.scan.tasks.fingerprint_detect import (
export_urls_for_fingerprint_task,
export_site_urls_for_fingerprint_task,
run_xingfinger_and_stream_update_tech_task,
)
from apps.scan.utils import build_scan_command, user_log
from apps.scan.utils import build_scan_command, setup_scan_directory, user_log, wait_for_system_load
from apps.scan.utils.fingerprint_helpers import get_fingerprint_paths
logger = logging.getLogger(__name__)
@dataclass
class FingerprintContext:
"""指纹识别上下文,用于在各函数间传递状态"""
scan_id: int
target_id: int
target_name: str
scan_workspace_dir: str
fingerprint_dir: Optional[Path] = None
urls_file: str = ""
url_count: int = 0
source: str = "website"
def calculate_fingerprint_detect_timeout(
url_count: int,
base_per_url: float = 10.0,
min_timeout: int = 300
) -> int:
"""
根据 URL 数量计算超时时间
公式:超时时间 = URL 数量 × 每 URL 基础时间
最小值:300秒
无上限
Args:
url_count: URL 数量
base_per_url: 每 URL 基础时间(秒),默认 10秒
min_timeout: 最小超时时间(秒),默认 300秒
Returns:
int: 计算出的超时时间(秒)
"""
timeout = int(url_count * base_per_url)
return max(min_timeout, timeout)
"""根据 URL 数量计算超时时间(最小 300 秒)"""
return max(min_timeout, int(url_count * base_per_url))
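# For example, with the defaults above:
assert calculate_fingerprint_detect_timeout(50) == 500   # 50 URLs x 10 s/URL
assert calculate_fingerprint_detect_timeout(10) == 300   # below the 300 s floor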
def _export_urls(fingerprint_dir: Path, provider) -> tuple[str, int]:
"""导出 URL 到文件,返回 (urls_file, total_count)"""
logger.info("Step 1: 导出 URL 列表")
def _export_urls(
target_id: int,
fingerprint_dir: Path,
source: str = 'website'
) -> tuple[str, int]:
"""
导出 URL 到文件
Args:
target_id: 目标 ID
fingerprint_dir: 指纹识别目录
source: 数据源类型
Returns:
tuple: (urls_file, total_count)
"""
logger.info("Step 1: 导出 URL 列表 (source=%s)", source)
urls_file = str(fingerprint_dir / 'urls.txt')
export_result = export_urls_for_fingerprint_task(
target_id=target_id,
export_result = export_site_urls_for_fingerprint_task(
output_file=urls_file,
source=source,
batch_size=1000
provider=provider,
)
total_count = export_result['total_count']
logger.info(
"✓ URL 导出完成 - 文件: %s, 数量: %d",
export_result['output_file'],
total_count
)
logger.info("✓ URL 导出完成 - 文件: %s, 数量: %d", export_result['output_file'], total_count)
return export_result['output_file'], total_count
def _run_fingerprint_detect(
enabled_tools: dict,
urls_file: str,
url_count: int,
fingerprint_dir: Path,
scan_id: int,
target_id: int,
source: str
) -> tuple[dict, list]:
"""
执行指纹识别任务
Args:
enabled_tools: 已启用的工具配置字典
urls_file: URL 文件路径
url_count: URL 总数
fingerprint_dir: 指纹识别目录
scan_id: 扫描任务 ID
target_id: 目标 ID
source: 数据源类型
Returns:
tuple: (tool_stats, failed_tools)
"""
def _run_single_tool(
tool_name: str,
tool_config: dict,
ctx: FingerprintContext
) -> tuple[Optional[dict], Optional[dict]]:
"""执行单个指纹识别工具,返回 (stats, failed_info)"""
# 获取指纹库路径
lib_names = tool_config.get('fingerprint_libs', ['ehole'])
fingerprint_paths = get_fingerprint_paths(lib_names)
if not fingerprint_paths:
reason = f"没有可用的指纹库: {lib_names}"
logger.warning(reason)
return None, {'tool': tool_name, 'reason': reason}
# 构建命令
tool_config_with_paths = {**tool_config, **fingerprint_paths}
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='fingerprint_detect',
command_params={'urls_file': ctx.urls_file},
tool_config=tool_config_with_paths
)
except Exception as e:
reason = f"命令构建失败: {e}"
logger.error("构建 %s 命令失败: %s", tool_name, e)
return None, {'tool': tool_name, 'reason': reason}
# 计算超时时间和日志文件
timeout = calculate_fingerprint_detect_timeout(ctx.url_count)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = ctx.fingerprint_dir / f"{tool_name}_{timestamp}.log"
logger.info(
"开始执行 %s 指纹识别 - URL数: %d, 超时: %ds, 指纹库: %s",
tool_name, ctx.url_count, timeout, list(fingerprint_paths.keys())
)
user_log(ctx.scan_id, "fingerprint_detect", f"Running {tool_name}: {command}")
# 执行扫描任务
try:
result = run_xingfinger_and_stream_update_tech_task(
cmd=command,
tool_name=tool_name,
scan_id=ctx.scan_id,
target_id=ctx.target_id,
source=ctx.source,
cwd=str(ctx.fingerprint_dir),
timeout=timeout,
log_file=str(log_file),
batch_size=100
)
stats = {
'command': command,
'result': result,
'timeout': timeout,
'fingerprint_libs': list(fingerprint_paths.keys())
}
tool_updated = result.get('updated_count', 0)
logger.info(
"✓ 工具 %s 执行完成 - 处理记录: %d, 更新: %d, 未找到: %d",
tool_name,
result.get('processed_records', 0),
tool_updated,
result.get('not_found_count', 0)
)
user_log(
ctx.scan_id, "fingerprint_detect",
f"{tool_name} completed: identified {tool_updated} fingerprints"
)
return stats, None
except Exception as exc:
reason = str(exc)
logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True)
user_log(ctx.scan_id, "fingerprint_detect", f"{tool_name} failed: {reason}", "error")
return None, {'tool': tool_name, 'reason': reason}
def _run_fingerprint_detect(enabled_tools: dict, ctx: FingerprintContext) -> tuple[dict, list]:
"""执行指纹识别任务,返回 (tool_stats, failed_tools)"""
tool_stats = {}
failed_tools = []
for tool_name, tool_config in enabled_tools.items():
# 1. 获取指纹库路径
lib_names = tool_config.get('fingerprint_libs', ['ehole'])
fingerprint_paths = get_fingerprint_paths(lib_names)
if not fingerprint_paths:
reason = f"没有可用的指纹库: {lib_names}"
logger.warning(reason)
failed_tools.append({'tool': tool_name, 'reason': reason})
continue
# 2. 将指纹库路径合并到 tool_config(用于命令构建)
tool_config_with_paths = {**tool_config, **fingerprint_paths}
# 3. 构建命令
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='fingerprint_detect',
command_params={
'urls_file': urls_file
},
tool_config=tool_config_with_paths
)
except Exception as e:
reason = f"命令构建失败: {str(e)}"
logger.error("构建 %s 命令失败: %s", tool_name, e)
failed_tools.append({'tool': tool_name, 'reason': reason})
continue
# 4. 计算超时时间
timeout = calculate_fingerprint_detect_timeout(url_count)
# 5. 生成日志文件路径
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = fingerprint_dir / f"{tool_name}_{timestamp}.log"
logger.info(
"开始执行 %s 指纹识别 - URL数: %d, 超时: %ds, 指纹库: %s",
tool_name, url_count, timeout, list(fingerprint_paths.keys())
)
user_log(scan_id, "fingerprint_detect", f"Running {tool_name}: {command}")
# 6. 执行扫描任务
try:
result = run_xingfinger_and_stream_update_tech_task(
cmd=command,
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
source=source,
cwd=str(fingerprint_dir),
timeout=timeout,
log_file=str(log_file),
batch_size=100
)
tool_stats[tool_name] = {
'command': command,
'result': result,
'timeout': timeout,
'fingerprint_libs': list(fingerprint_paths.keys())
}
tool_updated = result.get('updated_count', 0)
logger.info(
"✓ 工具 %s 执行完成 - 处理记录: %d, 更新: %d, 未找到: %d",
tool_name,
result.get('processed_records', 0),
tool_updated,
result.get('not_found_count', 0)
)
user_log(scan_id, "fingerprint_detect", f"{tool_name} completed: identified {tool_updated} fingerprints")
except Exception as exc:
reason = str(exc)
failed_tools.append({'tool': tool_name, 'reason': reason})
logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True)
user_log(scan_id, "fingerprint_detect", f"{tool_name} failed: {reason}", "error")
stats, failed_info = _run_single_tool(tool_name, tool_config, ctx)
if stats:
tool_stats[tool_name] = stats
if failed_info:
failed_tools.append(failed_info)
if failed_tools:
logger.warning(
"以下指纹识别工具执行失败: %s",
', '.join([f['tool'] for f in failed_tools])
)
return tool_stats, failed_tools
@flow(
def _aggregate_results(tool_stats: dict) -> dict:
"""汇总所有工具的结果"""
return {
'processed_records': sum(
s['result'].get('processed_records', 0) for s in tool_stats.values()
),
'updated_count': sum(
s['result'].get('updated_count', 0) for s in tool_stats.values()
),
'created_count': sum(
s['result'].get('created_count', 0) for s in tool_stats.values()
),
'snapshot_count': sum(
s['result'].get('snapshot_count', 0) for s in tool_stats.values()
),
}
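# Illustrative aggregation (numbers invented): two entries shaped like the ones kept in tool_stats.
_example_stats = {
    'xingfinger': {'result': {'processed_records': 120, 'updated_count': 30, 'created_count': 5, 'snapshot_count': 120}},
    'other_tool': {'result': {'processed_records': 80, 'updated_count': 10, 'created_count': 0, 'snapshot_count': 80}},
}
assert _aggregate_results(_example_stats) == {
    'processed_records': 200, 'updated_count': 40, 'created_count': 5, 'snapshot_count': 200,
}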
@scan_flow(
name="fingerprint_detect",
log_prints=True,
on_running=[on_scan_flow_running],
on_completion=[on_scan_flow_completed],
on_failure=[on_scan_flow_failed],
)
def fingerprint_detect_flow(
scan_id: int,
target_name: str,
target_id: int,
scan_workspace_dir: str,
enabled_tools: dict
enabled_tools: dict,
provider,
) -> dict:
"""
指纹识别 Flow
主要功能:
1. 从数据库导出目标下所有 WebSite URL 到文件
2. 使用 xingfinger 进行技术栈识别
3. 解析结果并更新 WebSite.tech 字段(合并去重)
工作流程:
Step 0: 创建工作目录
Step 1: 导出 URL 列表
Step 2: 解析配置,获取启用的工具
Step 3: 执行 xingfinger 并解析结果
Args:
scan_id: 扫描任务 ID
target_name: 目标名称
target_id: 目标 ID
scan_workspace_dir: 扫描工作空间目录
enabled_tools: 启用的工具配置(xingfinger)
Returns:
dict: {
'success': bool,
'scan_id': int,
'target': str,
'scan_workspace_dir': str,
'urls_file': str,
'url_count': int,
'processed_records': int,
'updated_count': int,
'created_count': int,
'snapshot_count': int,
'executed_tasks': list,
'tool_stats': dict
}
"""
try:
logger.info(
"="*60 + "\n" +
"开始指纹识别\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
)
user_log(scan_id, "fingerprint_detect", "Starting fingerprint detection")
wait_for_system_load(context="fingerprint_detect_flow")
# 从 provider 获取 target_name
target_name = provider.get_target_name()
if not target_name:
raise ValueError("无法获取 Target 名称")
# 参数验证
if scan_id is None:
raise ValueError("scan_id 不能为空")
if not target_name:
raise ValueError("target_name 不能为空")
if target_id is None:
raise ValueError("target_id 不能为空")
if not scan_workspace_dir:
raise ValueError("scan_workspace_dir 不能为空")
# 数据源类型(当前只支持 website)
source = 'website'
# Step 0: 创建工作目录
from apps.scan.utils import setup_scan_directory
fingerprint_dir = setup_scan_directory(scan_workspace_dir, 'fingerprint_detect')
# Step 1: 导出 URL(支持懒加载)
urls_file, url_count = _export_urls(target_id, fingerprint_dir, source)
if url_count == 0:
logger.info(
"开始指纹识别 - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
)
user_log(scan_id, "fingerprint_detect", "Starting fingerprint detection")
# 创建上下文
ctx = FingerprintContext(
scan_id=scan_id,
target_id=target_id,
target_name=target_name,
scan_workspace_dir=scan_workspace_dir,
fingerprint_dir=setup_scan_directory(scan_workspace_dir, 'fingerprint_detect')
)
# Step 1: 导出 URL
ctx.urls_file, ctx.url_count = _export_urls(ctx.fingerprint_dir, provider)
if ctx.url_count == 0:
logger.warning("跳过指纹识别:没有 URL 可扫描 - Scan ID: %s", scan_id)
user_log(scan_id, "fingerprint_detect", "Skipped: no URLs to scan", "warning")
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'scan_workspace_dir': scan_workspace_dir,
'urls_file': urls_file,
'url_count': 0,
'processed_records': 0,
'updated_count': 0,
'created_count': 0,
'snapshot_count': 0,
'executed_tasks': ['export_urls_for_fingerprint'],
'tool_stats': {
'total': 0,
'successful': 0,
'failed': 0,
'successful_tools': [],
'failed_tools': [],
'details': {}
}
}
return _build_empty_result(scan_id, target_name, scan_workspace_dir, ctx.urls_file)
# Step 2: 工具配置信息
logger.info("Step 2: 工具配置信息")
logger.info("✓ 启用工具: %s", ', '.join(enabled_tools.keys()))
# Step 3: 执行指纹识别
logger.info("Step 3: 执行指纹识别")
tool_stats, failed_tools = _run_fingerprint_detect(
enabled_tools=enabled_tools,
urls_file=urls_file,
url_count=url_count,
fingerprint_dir=fingerprint_dir,
scan_id=scan_id,
target_id=target_id,
source=source
tool_stats, failed_tools = _run_fingerprint_detect(enabled_tools, ctx)
# 汇总结果
totals = _aggregate_results(tool_stats)
failed_tool_names = {f['tool'] for f in failed_tools}
successful_tools = [name for name in enabled_tools if name not in failed_tool_names]
logger.info("✓ 指纹识别完成 - 识别指纹: %d", totals['updated_count'])
user_log(
scan_id, "fingerprint_detect",
f"fingerprint_detect completed: identified {totals['updated_count']} fingerprints"
)
# 动态生成已执行的任务列表
executed_tasks = ['export_urls_for_fingerprint']
executed_tasks.extend([f'run_xingfinger ({tool})' for tool in tool_stats.keys()])
# 汇总所有工具的结果
total_processed = sum(stats['result'].get('processed_records', 0) for stats in tool_stats.values())
total_updated = sum(stats['result'].get('updated_count', 0) for stats in tool_stats.values())
total_created = sum(stats['result'].get('created_count', 0) for stats in tool_stats.values())
total_snapshots = sum(stats['result'].get('snapshot_count', 0) for stats in tool_stats.values())
# 记录 Flow 完成
logger.info("✓ 指纹识别完成 - 识别指纹: %d", total_updated)
user_log(scan_id, "fingerprint_detect", f"fingerprint_detect completed: identified {total_updated} fingerprints")
successful_tools = [name for name in enabled_tools.keys()
if name not in [f['tool'] for f in failed_tools]]
executed_tasks = ['export_site_urls_for_fingerprint']
executed_tasks.extend([f'run_xingfinger ({tool})' for tool in tool_stats])
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'scan_workspace_dir': scan_workspace_dir,
'urls_file': urls_file,
'url_count': url_count,
'processed_records': total_processed,
'updated_count': total_updated,
'created_count': total_created,
'snapshot_count': total_snapshots,
'urls_file': ctx.urls_file,
'url_count': ctx.url_count,
**totals,
'executed_tasks': executed_tasks,
'tool_stats': {
'total': len(enabled_tools),
@@ -378,7 +292,7 @@ def fingerprint_detect_flow(
'details': tool_stats
}
}
except ValueError as e:
logger.error("配置错误: %s", e)
raise
@@ -388,3 +302,33 @@ def fingerprint_detect_flow(
except Exception as e:
logger.exception("指纹识别失败: %s", e)
raise
def _build_empty_result(
scan_id: int,
target_name: str,
scan_workspace_dir: str,
urls_file: str
) -> dict:
"""构建空结果(无 URL 可扫描时)"""
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'scan_workspace_dir': scan_workspace_dir,
'urls_file': urls_file,
'url_count': 0,
'processed_records': 0,
'updated_count': 0,
'created_count': 0,
'snapshot_count': 0,
'executed_tasks': ['export_site_urls_for_fingerprint'],
'tool_stats': {
'total': 0,
'successful': 0,
'failed': 0,
'successful_tools': [],
'failed_tools': [],
'details': {}
}
}

View File

@@ -5,12 +5,13 @@
职责:
- 使用 FlowOrchestrator 解析 YAML 配置
- 在 Prefect Flow 中执行子 Flow(Subflow)
- 执行子 Flow(Subflow)
- 按照 YAML 顺序编排工作流
- 根据 scan_mode 创建对应的 Provider
- 不包含具体业务逻辑(由 Tasks 和 FlowOrchestrator 实现)
架构:
- Flow: Prefect 编排层(本文件)
- Flow: 编排层(本文件)
- FlowOrchestrator: 配置解析和执行计划(apps/scan/services/)
- Tasks: 执行层(apps/scan/tasks/)
- Handlers: 状态管理(apps/scan/handlers/)
@@ -18,50 +19,106 @@
# Django 环境初始化(导入即生效)
# 注意:动态扫描容器应使用 run_initiate_scan.py 启动,以便在导入前设置环境变量
from apps.common.prefect_django_setup import setup_django_for_prefect
import apps.common.django_setup # noqa: F401
from prefect import flow, task
from pathlib import Path
import logging
from concurrent.futures import ThreadPoolExecutor
from apps.scan.decorators import scan_flow
from apps.scan.handlers import (
on_initiate_scan_flow_running,
on_initiate_scan_flow_completed,
on_initiate_scan_flow_failed,
)
from prefect.futures import wait
from apps.scan.utils import setup_scan_workspace
from apps.scan.orchestrators import FlowOrchestrator
from apps.scan.utils import setup_scan_workspace
logger = logging.getLogger(__name__)
@task(name="run_subflow")
def _run_subflow_task(scan_type: str, flow_func, flow_kwargs: dict):
"""包装子 Flow 的 Task用于在并行阶段并发执行子 Flow。"""
logger.info("开始执行子 Flow: %s", scan_type)
return flow_func(**flow_kwargs)
def _create_provider(scan, target_id: int, scan_id: int):
"""根据 scan_mode 创建对应的 Provider"""
from apps.scan.models import Scan
from apps.scan.providers import (
DatabaseTargetProvider,
SnapshotTargetProvider,
ProviderContext,
)
provider_context = ProviderContext(target_id=target_id, scan_id=scan_id)
if scan.scan_mode == Scan.ScanMode.QUICK:
provider = SnapshotTargetProvider(scan_id=scan_id, context=provider_context)
logger.info("✓ 快速扫描模式 - 创建 SnapshotTargetProvider")
else:
provider = DatabaseTargetProvider(target_id=target_id, context=provider_context)
logger.info("✓ 完整扫描模式 - 使用 DatabaseTargetProvider")
return provider
@flow(
def _execute_sequential_flows(valid_flows: list, results: dict, executed_flows: list):
"""顺序执行 Flow 列表"""
for scan_type, flow_func, flow_kwargs in valid_flows:
logger.info("=" * 60)
logger.info("执行 Flow: %s", scan_type)
logger.info("=" * 60)
try:
result = flow_func(**flow_kwargs)
executed_flows.append(scan_type)
results[scan_type] = result
logger.info("%s 执行成功", scan_type)
except Exception as e:
logger.warning("%s 执行失败: %s", scan_type, e)
executed_flows.append(f"{scan_type} (失败)")
results[scan_type] = {'success': False, 'error': str(e)}
def _execute_parallel_flows(valid_flows: list, results: dict, executed_flows: list):
"""并行执行 Flow 列表(使用 ThreadPoolExecutor"""
if not valid_flows:
return
logger.info("并行执行 %d 个 Flow", len(valid_flows))
with ThreadPoolExecutor(max_workers=len(valid_flows)) as executor:
futures = []
for scan_type, flow_func, flow_kwargs in valid_flows:
logger.info("=" * 60)
logger.info("提交并行子 Flow 任务: %s", scan_type)
logger.info("=" * 60)
future = executor.submit(flow_func, **flow_kwargs)
futures.append((scan_type, future))
# 收集结果
for scan_type, future in futures:
try:
result = future.result()
executed_flows.append(scan_type)
results[scan_type] = result
logger.info("%s 执行成功", scan_type)
except Exception as e:
logger.warning("%s 执行失败: %s", scan_type, e)
executed_flows.append(f"{scan_type} (失败)")
results[scan_type] = {'success': False, 'error': str(e)}
@scan_flow(
name='initiate_scan',
description='扫描任务初始化流程',
log_prints=True,
on_running=[on_initiate_scan_flow_running],
on_completion=[on_initiate_scan_flow_completed],
on_failure=[on_initiate_scan_flow_failed],
)
def initiate_scan_flow(
scan_id: int,
target_name: str,
target_id: int,
scan_workspace_dir: str,
engine_name: str,
scheduled_scan_name: str | None = None,
scheduled_scan_name: str | None = None, # noqa: ARG001
) -> dict:
"""
初始化扫描任务(动态工作流编排)
根据 YAML 配置动态编排工作流:
- 从数据库获取 engine_config (YAML)
- 检测启用的扫描类型
@@ -73,187 +130,112 @@ def initiate_scan_flow(
Stage 2: Analysis (并行执行)
- url_fetch
- directory_scan
Args:
scan_id: 扫描任务 ID
target_name: 目标名称
target_id: 目标 ID
scan_workspace_dir: Scan 工作空间目录路径
engine_name: 引擎名称(用于显示)
scheduled_scan_name: 定时扫描任务名称(可选,用于通知显示)
Returns:
dict: 执行结果摘要
Raises:
ValueError: 参数验证失败或配置无效
RuntimeError: 执行失败
"""
try:
# ==================== 参数验证 ====================
# 参数验证
if not scan_id:
raise ValueError("scan_id is required")
if not scan_workspace_dir:
raise ValueError("scan_workspace_dir is required")
if not engine_name:
raise ValueError("engine_name is required")
logger.info(
"="*60 + "\n" +
"开始初始化扫描任务\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Engine: {engine_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
)
# ==================== Task 1: 创建 Scan 工作空间 ====================
# 创建工作空间
scan_workspace_path = setup_scan_workspace(scan_workspace_dir)
# ==================== Task 2: 获取引擎配置 ====================
# 获取引擎配置
from apps.scan.models import Scan
scan = Scan.objects.get(id=scan_id)
engine_config = scan.yaml_configuration
# 使用 engine_names 进行显示
display_engine_name = ', '.join(scan.engine_names) if scan.engine_names else engine_name
# ==================== Task 3: 解析配置,生成执行计划 ====================
# 创建 Provider
provider = _create_provider(scan, target_id, scan_id)
# 获取 target_name 用于日志显示
target_name = provider.get_target_name()
if not target_name:
raise ValueError("无法获取 Target 名称")
logger.info("=" * 60)
logger.info("开始初始化扫描任务")
logger.info("Scan ID: %s, Target: %s, Engine: %s", scan_id, target_name, engine_name)
logger.info("Workspace: %s", scan_workspace_dir)
logger.info("=" * 60)
# 解析配置,生成执行计划
orchestrator = FlowOrchestrator(engine_config)
# FlowOrchestrator 已经解析了所有工具配置
enabled_tools_by_type = orchestrator.enabled_tools_by_type
logger.info(
f"执行计划生成成功:\n"
f" 扫描类型: {''.join(orchestrator.scan_types)}\n"
f" 总共 {len(orchestrator.scan_types)} 个 Flow"
)
# ==================== 初始化阶段进度 ====================
# 在解析完配置后立即初始化,此时已有完整的 scan_types 列表
logger.info("执行计划: %s (共 %d 个 Flow)",
'、'.join(orchestrator.scan_types), len(orchestrator.scan_types))
# 初始化阶段进度
from apps.scan.services import ScanService
scan_service = ScanService()
scan_service.init_stage_progress(scan_id, orchestrator.scan_types)
logger.info(f"✓ 初始化阶段进度 - Stages: {orchestrator.scan_types}")
# ==================== 更新 Target 最后扫描时间 ====================
# 在开始扫描时更新,表示"最后一次扫描开始时间"
ScanService().init_stage_progress(scan_id, orchestrator.scan_types)
logger.info("✓ 初始化阶段进度 - Stages: %s", orchestrator.scan_types)
# 更新 Target 最后扫描时间
from apps.targets.services import TargetService
target_service = TargetService()
target_service.update_last_scanned_at(target_id)
logger.info(f"✓ 更新 Target 最后扫描时间 - Target ID: {target_id}")
# ==================== Task 3: 执行 Flow(动态阶段执行)====================
# 注意:各阶段状态更新由 scan_flow_handlers.py 自动处理(running/completed/failed)
TargetService().update_last_scanned_at(target_id)
logger.info("✓ 更新 Target 最后扫描时间 - Target ID: %s", target_id)
# 执行 Flow
executed_flows = []
results = {}
# 通用执行参数
flow_kwargs = {
base_kwargs = {
'scan_id': scan_id,
'target_name': target_name,
'target_id': target_id,
'scan_workspace_dir': str(scan_workspace_path)
}
def record_flow_result(scan_type, result=None, error=None):
"""
统一的结果记录函数
Args:
scan_type: 扫描类型名称
result: 执行结果(成功时)
error: 异常对象(失败时)
"""
if error:
# 失败处理:记录错误但不抛出异常,让扫描继续执行后续阶段
error_msg = f"{scan_type} 执行失败: {str(error)}"
logger.warning(error_msg)
executed_flows.append(f"{scan_type} (失败)")
results[scan_type] = {'success': False, 'error': str(error)}
# 不再抛出异常,让扫描继续
else:
# 成功处理
executed_flows.append(scan_type)
results[scan_type] = result
logger.info(f"{scan_type} 执行成功")
def get_valid_flows(flow_names):
"""
获取有效的 Flow 函数列表,并为每个 Flow 准备专属参数
Args:
flow_names: 扫描类型名称列表
Returns:
list: [(scan_type, flow_func, flow_specific_kwargs), ...] 有效的函数列表
"""
valid_flows = []
def get_valid_flows(flow_names: list) -> list:
"""获取有效的 Flow 函数列表"""
valid = []
for scan_type in flow_names:
flow_func = orchestrator.get_flow_function(scan_type)
if flow_func:
# 为每个 Flow 准备专属的参数(包含对应的 enabled_tools)
flow_specific_kwargs = dict(flow_kwargs)
flow_specific_kwargs['enabled_tools'] = enabled_tools_by_type.get(scan_type, {})
valid_flows.append((scan_type, flow_func, flow_specific_kwargs))
else:
logger.warning(f"跳过未实现的 Flow: {scan_type}")
return valid_flows
if not flow_func:
logger.warning("跳过未实现的 Flow: %s", scan_type)
continue
kwargs = dict(base_kwargs)
kwargs['enabled_tools'] = enabled_tools_by_type.get(scan_type, {})
kwargs['provider'] = provider
valid.append((scan_type, flow_func, kwargs))
return valid
# ---------------------------------------------------------
# 动态阶段执行(基于 FlowOrchestrator 定义)
# ---------------------------------------------------------
# 动态阶段执行
for mode, enabled_flows in orchestrator.get_execution_stages():
valid_flows = get_valid_flows(enabled_flows)
if not valid_flows:
continue
logger.info("=" * 60)
logger.info("%s执行阶段: %s", "顺序" if mode == 'sequential' else "并行",
', '.join(enabled_flows))
logger.info("=" * 60)
if mode == 'sequential':
# 顺序执行
logger.info(f"\n{'='*60}\n顺序执行阶段: {', '.join(enabled_flows)}\n{'='*60}")
for scan_type, flow_func, flow_specific_kwargs in get_valid_flows(enabled_flows):
logger.info(f"\n{'='*60}\n执行 Flow: {scan_type}\n{'='*60}")
try:
result = flow_func(**flow_specific_kwargs)
record_flow_result(scan_type, result=result)
except Exception as e:
record_flow_result(scan_type, error=e)
elif mode == 'parallel':
# 并行执行阶段:通过 Task 包装子 Flow,并使用 Prefect TaskRunner 并发运行
logger.info(f"\n{'='*60}\n并行执行阶段: {', '.join(enabled_flows)}\n{'='*60}")
futures = []
_execute_sequential_flows(valid_flows, results, executed_flows)
else:
_execute_parallel_flows(valid_flows, results, executed_flows)
# 提交所有并行子 Flow 任务
for scan_type, flow_func, flow_specific_kwargs in get_valid_flows(enabled_flows):
logger.info(f"\n{'='*60}\n提交并行子 Flow 任务: {scan_type}\n{'='*60}")
future = _run_subflow_task.submit(
scan_type=scan_type,
flow_func=flow_func,
flow_kwargs=flow_specific_kwargs,
)
futures.append((scan_type, future))
logger.info("=" * 60)
logger.info("✓ 扫描任务初始化完成 - 执行的 Flow: %s", ', '.join(executed_flows))
logger.info("=" * 60)
# 等待所有并行子 Flow 完成
if futures:
wait([f for _, f in futures])
# 检查结果(复用统一的结果处理逻辑)
for scan_type, future in futures:
try:
result = future.result()
record_flow_result(scan_type, result=result)
except Exception as e:
record_flow_result(scan_type, error=e)
# ==================== 完成 ====================
logger.info(
"="*60 + "\n" +
"✓ 扫描任务初始化完成\n" +
f" 执行的 Flow: {', '.join(executed_flows)}\n" +
"="*60
)
# ==================== 返回结果 ====================
return {
'success': True,
'scan_id': scan_id,
@@ -262,21 +244,16 @@ def initiate_scan_flow(
'executed_flows': executed_flows,
'results': results
}
except ValueError as e:
# 参数错误
logger.error("参数错误: %s", e)
raise
except RuntimeError as e:
# 执行失败
logger.error("运行时错误: %s", e)
raise
except OSError as e:
# 文件系统错误(工作空间创建失败)
logger.error("文件系统错误: %s", e)
raise
except Exception as e:
# 其他未预期错误
logger.exception("初始化扫描任务失败: %s", e)
# 注意:失败状态更新由 Prefect State Handlers 自动处理
raise

View File

@@ -1,4 +1,4 @@
"""
"""
端口扫描 Flow
负责编排端口扫描的完整流程
@@ -10,25 +10,22 @@
- 配置由 YAML 解析
"""
# Django 环境初始化(导入即生效)
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging
import os
import subprocess
from datetime import datetime
from pathlib import Path
from typing import Callable
from prefect import flow
from apps.scan.tasks.port_scan import (
export_hosts_task,
run_and_stream_save_ports_task
)
from apps.scan.decorators import scan_flow
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
on_scan_flow_running,
)
from apps.scan.utils import config_parser, build_scan_command, user_log
from apps.scan.tasks.port_scan import (
export_hosts_task,
run_and_stream_save_ports_task,
)
from apps.scan.utils import build_scan_command, user_log, wait_for_system_load
logger = logging.getLogger(__name__)
@@ -40,28 +37,19 @@ def calculate_port_scan_timeout(
) -> int:
"""
根据目标数量和端口数量计算超时时间
计算公式:超时时间 = 目标数 × 端口数 × base_per_pair
超时范围:60秒 ~ 2天(172800秒)
超时范围:60秒 ~ 无上限
Args:
tool_config: 工具配置字典,包含端口配置(ports, top-ports 等)
file_path: 目标文件路径(域名/IP列表)
base_per_pair: 每个"端口-目标对"的基础时间(秒),默认 0.5秒
Returns:
int: 计算出的超时时间(秒),范围:60 ~ 172800
Example:
# 100个目标 × 100个端口 × 0.5秒 = 5000秒
# 10个目标 × 1000个端口 × 0.5秒 = 5000秒
timeout = calculate_port_scan_timeout(
tool_config={'top-ports': 100},
file_path='/path/to/domains.txt'
)
int: 计算出的超时时间(秒),最小 60 秒
"""
try:
# 1. 统计目标数量
result = subprocess.run(
['wc', '-l', file_path],
capture_output=True,
@@ -69,133 +57,110 @@ def calculate_port_scan_timeout(
check=True
)
target_count = int(result.stdout.strip().split()[0])
# 2. 解析端口数量
port_count = _parse_port_count(tool_config)
# 3. 计算超时时间
# 总工作量 = 目标数 × 端口数
total_work = target_count * port_count
timeout = int(total_work * base_per_pair)
# 4. 设置合理的下限(不再设置上限)
min_timeout = 60 # 最小 60 秒
timeout = max(min_timeout, timeout)
timeout = max(60, int(total_work * base_per_pair))
logger.info(
f"计算端口扫描 timeout - "
f"目标数: {target_count}, "
f"端口数: {port_count}, "
f"总工作量: {total_work}, "
f"超时: {timeout}"
"计算端口扫描 timeout - 目标数: %d, 端口数: %d, 总工作量: %d, 超时: %d",
target_count, port_count, total_work, timeout
)
return timeout
except Exception as e:
logger.warning(f"计算 timeout 失败: {e},使用默认值 600秒")
logger.warning("计算 timeout 失败: %s,使用默认值 600秒", e)
return 600
def _parse_port_count(tool_config: dict) -> int:
"""
从工具配置中解析端口数量
优先级:
1. top-ports: N → 返回 N
2. ports: "80,443,8080" → 返回逗号分隔的数量
3. ports: "1-1000" → 返回范围的大小
4. ports: "1-65535" → 返回 65535
5. 默认 → 返回 100(naabu 默认扫描 top 100)
Args:
tool_config: 工具配置字典
Returns:
int: 端口数量
"""
# 1. 检查 top-ports 配置
# 检查 top-ports 配置
if 'top-ports' in tool_config:
top_ports = tool_config['top-ports']
if isinstance(top_ports, int) and top_ports > 0:
return top_ports
logger.warning(f"top-ports 配置无效: {top_ports},使用默认值")
# 2. 检查 ports 配置
logger.warning("top-ports 配置无效: %s,使用默认值", top_ports)
# 检查 ports 配置
if 'ports' in tool_config:
ports_str = str(tool_config['ports']).strip()
# 2.1 逗号分隔的端口列表80,443,8080
# 逗号分隔的端口列表80,443,8080
if ',' in ports_str:
port_list = [p.strip() for p in ports_str.split(',') if p.strip()]
return len(port_list)
# 2.2 端口范围1-1000
return len([p.strip() for p in ports_str.split(',') if p.strip()])
# 端口范围1-1000
if '-' in ports_str:
try:
start, end = ports_str.split('-', 1)
start_port = int(start.strip())
end_port = int(end.strip())
if 1 <= start_port <= end_port <= 65535:
return end_port - start_port + 1
logger.warning(f"端口范围无效: {ports_str},使用默认值")
logger.warning("端口范围无效: %s,使用默认值", ports_str)
except ValueError:
logger.warning(f"端口范围解析失败: {ports_str},使用默认值")
# 2.3 单个端口
logger.warning("端口范围解析失败: %s,使用默认值", ports_str)
# 单个端口
try:
port = int(ports_str)
if 1 <= port <= 65535:
return 1
except ValueError:
logger.warning(f"端口配置解析失败: {ports_str},使用默认值")
# 3. 默认值naabu 默认扫描 top 100 端口
logger.warning("端口配置解析失败: %s,使用默认值", ports_str)
# 默认值naabu 默认扫描 top 100 端口
return 100
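Hypothetical calls illustrating the parsing priority documented above (values follow directly from the rules shown; not part of the changeset):

# _parse_port_count({'top-ports': 150})        -> 150   (top-ports wins)
# _parse_port_count({'ports': '80,443,8080'})  -> 3     (comma-separated list)
# _parse_port_count({'ports': '1-1000'})       -> 1000  (range size: end - start + 1)
# _parse_port_count({'ports': '443'})          -> 1     (single port)
# _parse_port_count({})                        -> 100   (naabu default: top 100)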
def _export_hosts(target_id: int, port_scan_dir: Path) -> tuple[str, int, str]:
def _export_hosts(port_scan_dir: Path, provider) -> tuple[str, int]:
"""
导出主机列表到文件
根据 Target 类型自动决定导出内容:
- DOMAIN: 从 Subdomain 表导出子域名
- IP: 直接写入 target.name
- CIDR: 展开 CIDR 范围内的所有 IP
Args:
target_id: 目标 ID
port_scan_dir: 端口扫描目录
provider: TargetProvider 实例
Returns:
tuple: (hosts_file, host_count, target_type)
tuple: (hosts_file, host_count)
"""
logger.info("Step 1: 导出主机列表")
hosts_file = str(port_scan_dir / 'hosts.txt')
export_result = export_hosts_task(
target_id=target_id,
output_file=hosts_file,
batch_size=1000 # 每次读取 1000 条,优化内存占用
provider=provider,
)
host_count = export_result['total_count']
target_type = export_result.get('target_type', 'unknown')
logger.info(
"✓ 主机列表导出完成 - 类型: %s, 文件: %s, 数量: %d",
target_type,
export_result['output_file'],
host_count
"✓ 主机列表导出完成 - 文件: %s, 数量: %d",
export_result['output_file'], host_count
)
if host_count == 0:
logger.warning("目标下没有可扫描的主机,无法执行端口扫描")
return export_result['output_file'], host_count, target_type
return export_result['output_file'], host_count
def _run_scans_sequentially(
@@ -204,84 +169,68 @@ def _run_scans_sequentially(
port_scan_dir: Path,
scan_id: int,
target_id: int,
target_name: str
target_name: str,
) -> tuple[dict, int, list, list]:
"""
串行执行端口扫描任务
Args:
enabled_tools: 已启用的工具配置字典
domains_file: 域名文件路径
port_scan_dir: 端口扫描目录
scan_id: 扫描任务 ID
target_id: 目标 ID
target_name: 目标名称(用于错误日志)
target_name: 目标名称(用于日志显示
Returns:
tuple: (tool_stats, processed_records, successful_tool_names, failed_tools)
注意:端口扫描是流式输出,不生成结果文件
Raises:
RuntimeError: 所有工具均失败
"""
# ==================== 构建命令并串行执行 ====================
tool_stats = {}
processed_records = 0
failed_tools = [] # 记录失败的工具(含原因)
# for循环执行工具按顺序串行运行每个启用的端口扫描工具
failed_tools = []
for tool_name, tool_config in enabled_tools.items():
# 1. 构建完整命令(变量替换)
# 构建命令
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='port_scan',
command_params={
'domains_file': domains_file # 对应 {domains_file}
},
tool_config=tool_config #yaml的工具配置
command_params={'domains_file': domains_file},
tool_config=tool_config
)
except Exception as e:
reason = f"命令构建失败: {str(e)}"
logger.error(f"构建 {tool_name} 命令失败: {e}")
reason = f"命令构建失败: {e}"
logger.error("构建 %s 命令失败: %s", tool_name, e)
failed_tools.append({'tool': tool_name, 'reason': reason})
continue
# 2. 获取超时时间(支持 'auto' 动态计算)
# 获取超时时间
config_timeout = tool_config['timeout']
if config_timeout == 'auto':
# 动态计算超时时间
config_timeout = calculate_port_scan_timeout(
tool_config=tool_config,
file_path=str(domains_file)
)
logger.info(f"✓ 工具 {tool_name} 动态计算 timeout: {config_timeout}")
# 2.1 生成日志文件路径
from datetime import datetime
config_timeout = calculate_port_scan_timeout(tool_config, str(domains_file))
logger.info("✓ 工具 %s 动态计算 timeout: %d", tool_name, config_timeout)
# 生成日志文件路径
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = port_scan_dir / f"{tool_name}_{timestamp}.log"
# 3. 执行扫描任务
logger.info("开始执行 %s 扫描(超时: %d秒)...", tool_name, config_timeout)
user_log(scan_id, "port_scan", f"Running {tool_name}: {command}")
# 执行扫描任务
try:
# 直接调用 task串行执行
# 注意:端口扫描是流式输出到 stdout不使用 output_file
result = run_and_stream_save_ports_task(
cmd=command,
tool_name=tool_name, # 工具名称
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
cwd=str(port_scan_dir),
shell=True,
batch_size=1000,
timeout=config_timeout,
log_file=str(log_file) # 新增:日志文件路径
log_file=str(log_file)
)
tool_stats[tool_name] = {
'command': command,
'result': result,
@@ -289,15 +238,10 @@ def _run_scans_sequentially(
}
tool_records = result.get('processed_records', 0)
processed_records += tool_records
logger.info(
"✓ 工具 %s 流式处理完成 - 记录数: %d",
tool_name, tool_records
)
logger.info("✓ 工具 %s 流式处理完成 - 记录数: %d", tool_name, tool_records)
user_log(scan_id, "port_scan", f"{tool_name} completed: found {tool_records} ports")
except subprocess.TimeoutExpired as exc:
# 超时异常单独处理
# 注意:流式处理任务超时时,已解析的数据已保存到数据库
except subprocess.TimeoutExpired:
reason = f"timeout after {config_timeout}s"
failed_tools.append({'tool': tool_name, 'reason': reason})
logger.warning(
@@ -307,134 +251,110 @@ def _run_scans_sequentially(
)
user_log(scan_id, "port_scan", f"{tool_name} failed: {reason}", "error")
except Exception as exc:
# 其他异常
reason = str(exc)
failed_tools.append({'tool': tool_name, 'reason': reason})
logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True)
user_log(scan_id, "port_scan", f"{tool_name} failed: {reason}", "error")
if failed_tools:
logger.warning(
"以下扫描工具执行失败: %s",
', '.join([f['tool'] for f in failed_tools])
)
if not tool_stats:
error_details = "; ".join([f"{f['tool']}: {f['reason']}" for f in failed_tools])
logger.warning("所有端口扫描工具均失败 - 目标: %s, 失败工具: %s", target_name, error_details)
# 返回空结果,不抛出异常,让扫描继续
logger.warning("所有端口扫描工具均失败 - Target: %s, 失败工具: %s", target_name, error_details)
return {}, 0, [], failed_tools
# 动态计算成功的工具列表
successful_tool_names = [name for name in enabled_tools.keys()
if name not in [f['tool'] for f in failed_tools]]
successful_tool_names = [
name for name in enabled_tools
if name not in [f['tool'] for f in failed_tools]
]
logger.info(
"✓ 串行端口扫描执行完成 - 成功: %d/%d (成功: %s, 失败: %s)",
len(tool_stats), len(enabled_tools),
', '.join(successful_tool_names) if successful_tool_names else '',
', '.join([f['tool'] for f in failed_tools]) if failed_tools else ''
)
return tool_stats, processed_records, successful_tool_names, failed_tools
@flow(
name="port_scan",
log_prints=True,
@scan_flow(
name="port_scan",
on_running=[on_scan_flow_running],
on_completion=[on_scan_flow_completed],
on_failure=[on_scan_flow_failed],
)
def port_scan_flow(
scan_id: int,
target_name: str,
target_id: int,
scan_workspace_dir: str,
enabled_tools: dict
enabled_tools: dict,
provider,
) -> dict:
"""
端口扫描 Flow
主要功能:
1. 扫描目标域名/IP 的开放端口
2. 保存 host + ip + port 三元映射到 HostPortMapping 表
输出资产:
- HostPortMapping:主机端口映射(host + ip + port 三元组)
工作流程:
Step 0: 创建工作目录
Step 1: 导出域名列表到文件(供扫描工具使用)
Step 2: 解析配置,获取启用的工具
Step 3: 串行执行扫描工具,运行端口扫描工具并实时解析输出到数据库(→ HostPortMapping)
Step 3: 串行执行扫描工具,运行端口扫描工具并实时解析输出到数据库
Args:
scan_id: 扫描任务 ID
target_name: 域名
target_id: 目标 ID
scan_workspace_dir: Scan 工作空间目录
enabled_tools: 启用的工具配置字典
provider: TargetProvider 实例
Returns:
dict: {
'success': bool,
'scan_id': int,
'target': str,
'scan_workspace_dir': str,
'hosts_file': str,
'host_count': int,
'processed_records': int,
'executed_tasks': list,
'tool_stats': {
'total': int, # 总工具数
'successful': int, # 成功工具数
'failed': int, # 失败工具数
'successful_tools': list[str], # 成功工具列表 ['naabu_active']
'failed_tools': list[dict], # 失败工具列表 [{'tool': 'naabu_passive', 'reason': '超时'}]
'details': dict # 详细执行结果(保留向后兼容)
}
}
dict: 扫描结果
Raises:
ValueError: 配置错误
RuntimeError: 执行失败
Note:
端口扫描工具(如 naabu)会解析域名获取 IP,输出 host + ip + port 三元组。
同一 host 可能对应多个 IP(CDN、负载均衡),因此使用三元映射表存储。
"""
try:
# 参数验证
wait_for_system_load(context="port_scan_flow")
# 从 provider 获取 target_name
target_name = provider.get_target_name()
if not target_name:
raise ValueError("无法获取 Target 名称")
if scan_id is None:
raise ValueError("scan_id 不能为空")
if not target_name:
raise ValueError("target_name 不能为空")
if target_id is None:
raise ValueError("target_id 不能为空")
if not scan_workspace_dir:
raise ValueError("scan_workspace_dir 不能为空")
if not enabled_tools:
raise ValueError("enabled_tools 不能为空")
logger.info(
"="*60 + "\n" +
"开始端口扫描\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
"开始端口扫描 - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
)
user_log(scan_id, "port_scan", "Starting port scan")
# Step 0: 创建工作目录
from apps.scan.utils import setup_scan_directory
port_scan_dir = setup_scan_directory(scan_workspace_dir, 'port_scan')
# Step 1: 导出主机列表到文件(根据 Target 类型自动决定内容)
hosts_file, host_count, target_type = _export_hosts(target_id, port_scan_dir)
# Step 1: 导出主机列表
hosts_file, host_count = _export_hosts(port_scan_dir, provider)
if host_count == 0:
logger.warning("跳过端口扫描:没有主机可扫描 - Scan ID: %s", scan_id)
user_log(scan_id, "port_scan", "Skipped: no hosts to scan", "warning")
@@ -445,7 +365,6 @@ def port_scan_flow(
'scan_workspace_dir': scan_workspace_dir,
'hosts_file': hosts_file,
'host_count': 0,
'target_type': target_type,
'processed_records': 0,
'executed_tasks': ['export_hosts'],
'tool_stats': {
@@ -457,14 +376,11 @@ def port_scan_flow(
'details': {}
}
}
# Step 2: 工具配置信息
logger.info("Step 2: 工具配置信息")
logger.info(
"✓ 启用工具: %s",
', '.join(enabled_tools.keys())
)
logger.info("✓ 启用工具: %s", ', '.join(enabled_tools.keys()))
# Step 3: 串行执行扫描工具
logger.info("Step 3: 串行执行扫描工具")
tool_stats, processed_records, successful_tool_names, failed_tools = _run_scans_sequentially(
@@ -473,17 +389,15 @@ def port_scan_flow(
port_scan_dir=port_scan_dir,
scan_id=scan_id,
target_id=target_id,
target_name=target_name
target_name=target_name,
)
# 记录 Flow 完成
logger.info("✓ 端口扫描完成 - 发现端口: %d", processed_records)
user_log(scan_id, "port_scan", f"port_scan completed: found {processed_records} ports")
# 动态生成已执行的任务列表
executed_tasks = ['export_hosts', 'parse_config']
executed_tasks.extend([f'run_and_stream_save_ports ({tool})' for tool in tool_stats.keys()])
executed_tasks.extend([f'run_and_stream_save_ports ({tool})' for tool in tool_stats])
return {
'success': True,
'scan_id': scan_id,
@@ -491,7 +405,6 @@ def port_scan_flow(
'scan_workspace_dir': scan_workspace_dir,
'hosts_file': hosts_file,
'host_count': host_count,
'target_type': target_type,
'processed_records': processed_records,
'executed_tasks': executed_tasks,
'tool_stats': {

View File

@@ -0,0 +1,173 @@
"""
截图 Flow
负责编排截图的完整流程:
1. 从 Provider 获取 URL 列表
2. 批量截图并保存快照
3. 同步到资产表
"""
import logging
from apps.scan.decorators import scan_flow
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_completed,
on_scan_flow_failed,
on_scan_flow_running,
)
from apps.scan.providers import TargetProvider
from apps.scan.tasks.screenshot import capture_screenshots_task
from apps.scan.utils import user_log, wait_for_system_load
logger = logging.getLogger(__name__)
def _parse_screenshot_config(enabled_tools: dict) -> dict:
"""解析截图配置"""
playwright_config = enabled_tools.get('playwright', {})
return {
'concurrency': playwright_config.get('concurrency', 5),
}
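Illustrative behavior of the config parsing above (hypothetical inputs, shown only to make the default explicit):

# _parse_screenshot_config({'playwright': {'concurrency': 10}})  -> {'concurrency': 10}
# _parse_screenshot_config({})                                   -> {'concurrency': 5}   (default)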
def _collect_urls_from_provider(provider: TargetProvider) -> tuple[list[str], str]:
"""
从 Provider 收集网站 URL带回退逻辑
优先级WebSite → HostPortMapping → Default URL
Returns:
tuple: (urls, source)
- urls: URL 列表
- source: 数据来源 ('website' | 'host_port' | 'default')
"""
logger.info("从 Provider 获取网站 URL - Provider: %s", type(provider).__name__)
# 优先从 WebSite 获取
urls = list(provider.iter_websites())
if urls:
logger.info("使用 WebSite 数据源 - 数量: %d", len(urls))
return urls, "website"
# 回退到 HostPortMapping
urls = list(provider.iter_host_port_urls())
if urls:
logger.info("WebSite 为空,回退到 HostPortMapping - 数量: %d", len(urls))
return urls, "host_port"
# 最终回退到默认 URL
urls = list(provider.iter_default_urls())
logger.info("HostPortMapping 为空,回退到默认 URL - 数量: %d", len(urls))
return urls, "default"
def _build_empty_result(scan_id: int, target_name: str) -> dict:
"""构建空结果"""
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'total_urls': 0,
'successful': 0,
'failed': 0,
'synced': 0
}
@scan_flow(
name="screenshot",
on_running=[on_scan_flow_running],
on_completion=[on_scan_flow_completed],
on_failure=[on_scan_flow_failed],
)
def screenshot_flow(
scan_id: int,
target_id: int,
scan_workspace_dir: str,
enabled_tools: dict,
provider: TargetProvider,
) -> dict:
"""
截图 Flow
Args:
scan_id: 扫描任务 ID
target_id: 目标 ID
scan_workspace_dir: 扫描工作空间目录
enabled_tools: 启用的工具配置
provider: TargetProvider 实例
Returns:
截图结果字典
"""
try:
wait_for_system_load(context="screenshot_flow")
# 从 provider 获取 target_name
target_name = provider.get_target_name()
if not target_name:
raise ValueError("无法获取 Target 名称")
logger.info(
"开始截图扫描 - Scan ID: %s, Target: %s",
scan_id, target_name
)
user_log(scan_id, "screenshot", "Starting screenshot capture")
# Step 1: 解析配置
config = _parse_screenshot_config(enabled_tools)
concurrency = config['concurrency']
logger.info("截图配置 - 并发: %d", concurrency)
# Step 2: 从 Provider 收集 URL 列表(带回退逻辑)
urls, source = _collect_urls_from_provider(provider)
logger.info("URL 收集完成 - 来源: %s, 数量: %d", source, len(urls))
if not urls:
logger.warning("没有可截图的 URL跳过截图任务")
user_log(scan_id, "screenshot", "Skipped: no URLs to capture", "warning")
return _build_empty_result(scan_id, target_name)
user_log(scan_id, "screenshot", f"Found {len(urls)} URLs to capture")
# Step 3: 批量截图
logger.info("批量截图 - %d 个 URL", len(urls))
capture_result = capture_screenshots_task(
urls=urls,
scan_id=scan_id,
target_id=target_id,
config={'concurrency': concurrency}
)
# Step 4: 同步到资产表
logger.info("同步截图到资产表")
from apps.asset.services.screenshot_service import ScreenshotService
synced = ScreenshotService().sync_screenshots_to_asset(scan_id, target_id)
total = capture_result['total']
successful = capture_result['successful']
failed = capture_result['failed']
logger.info(
"✓ 截图完成 - 总数: %d, 成功: %d, 失败: %d, 同步: %d",
total, successful, failed, synced
)
user_log(
scan_id, "screenshot",
f"Screenshot completed: {successful}/{total} captured, {synced} synced"
)
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'total_urls': total,
'successful': successful,
'failed': failed,
'synced': synced
}
except Exception:
logger.exception("截图 Flow 失败")
user_log(scan_id, "screenshot", "Screenshot failed", "error")
raise

View File

@@ -1,4 +1,3 @@
"""
站点扫描 Flow
@@ -11,451 +10,405 @@
- 配置由 YAML 解析
"""
# Django 环境初始化(导入即生效)
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging
import os
import subprocess
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Callable
from prefect import flow
from apps.scan.tasks.site_scan import export_site_urls_task, run_and_stream_save_websites_task
from typing import Optional
from apps.scan.decorators import scan_flow
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
on_scan_flow_running,
)
from apps.scan.utils import config_parser, build_scan_command, user_log
from apps.scan.tasks.site_scan import (
export_site_urls_task,
run_and_stream_save_websites_task,
)
from apps.scan.utils import build_scan_command, user_log, wait_for_system_load
logger = logging.getLogger(__name__)
def calculate_timeout_by_line_count(
tool_config: dict,
file_path: str,
base_per_time: int = 1,
min_timeout: int = 60
) -> int:
"""
根据文件行数计算 timeout
使用 wc -l 统计文件行数,根据行数和每行基础时间计算 timeout
Args:
tool_config: 工具配置字典(此函数未使用,但保持接口一致性)
file_path: 要统计行数的文件路径
base_per_time: 每行的基础时间默认1秒
min_timeout: 最小超时时间默认60秒
Returns:
int: 计算出的超时时间(秒),不低于 min_timeout
Example:
timeout = calculate_timeout_by_line_count(
tool_config={},
file_path='/path/to/urls.txt',
base_per_time=2
)
"""
@dataclass
class ScanContext:
"""扫描上下文,封装扫描参数"""
scan_id: int
target_id: int
target_name: str
site_scan_dir: Path
urls_file: str
total_urls: int
def _count_file_lines(file_path: str) -> int:
"""使用 wc -l 统计文件行数"""
try:
# 使用 wc -l 快速统计行数
result = subprocess.run(
['wc', '-l', file_path],
capture_output=True,
text=True,
check=True
)
# wc -l 输出格式:行数 + 空格 + 文件名
line_count = int(result.stdout.strip().split()[0])
# 计算 timeout行数 × 每行基础时间,不低于最小值
timeout = max(line_count * base_per_time, min_timeout)
logger.info(
f"timeout 自动计算: 文件={file_path}, "
f"行数={line_count}, 每行时间={base_per_time}秒, 最小值={min_timeout}秒, timeout={timeout}"
)
return timeout
except Exception as e:
# 如果 wc -l 失败,使用默认值
logger.warning(f"wc -l 计算行数失败: {e},使用默认 timeout: {min_timeout}")
return min_timeout
return int(result.stdout.strip().split()[0])
except (subprocess.CalledProcessError, ValueError, IndexError) as e:
logger.warning("wc -l 计算行数失败: %s,返回 0", e)
return 0
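As a side note, an equivalent pure-Python line count (an illustrative alternative; the diff itself keeps the wc -l subprocess call):

def _count_file_lines_py(file_path: str) -> int:
    # Same contract as _count_file_lines: return 0 on failure instead of raising.
    try:
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
            return sum(1 for _ in f)
    except OSError as e:
        logger.warning("统计行数失败: %s,返回 0", e)
        return 0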
def _calculate_timeout_by_line_count(
file_path: str,
base_per_time: int = 1,
min_timeout: int = 60
) -> int:
"""
根据文件行数计算 timeout
Args:
file_path: 要统计行数的文件路径
base_per_time: 每行的基础时间默认1秒
min_timeout: 最小超时时间默认60秒
Returns:
int: 计算出的超时时间(秒),不低于 min_timeout
"""
line_count = _count_file_lines(file_path)
timeout = max(line_count * base_per_time, min_timeout)
logger.info(
"timeout 自动计算: 文件=%s, 行数=%d, 每行时间=%d秒, timeout=%d",
file_path, line_count, base_per_time, timeout
)
return timeout
def _export_site_urls(target_id: int, site_scan_dir: Path, target_name: str = None) -> tuple[str, int, int]:
def _export_site_urls(
site_scan_dir: Path,
provider,
) -> tuple[str, int]:
"""
导出站点 URL 到文件
Args:
target_id: 目标 ID
site_scan_dir: 站点扫描目录
target_name: 目标名称(用于懒加载时写入默认值)
provider: TargetProvider 实例
Returns:
tuple: (urls_file, total_urls, association_count)
Raises:
ValueError: URL 数量为 0
tuple: (urls_file, total_urls)
"""
logger.info("Step 1: 导出站点URL列表")
urls_file = str(site_scan_dir / 'site_urls.txt')
export_result = export_site_urls_task(
target_id=target_id,
output_file=urls_file,
batch_size=1000 # 每次处理1000个子域名
provider=provider,
)
total_urls = export_result['total_urls']
association_count = export_result['association_count'] # 主机端口关联数
logger.info(
"✓ 站点URL导出完成 - 文件: %s, URL数量: %d, 关联数: %d",
export_result['output_file'],
total_urls,
association_count
"✓ 站点URL导出完成 - 文件: %s, URL数量: %d",
export_result['output_file'], total_urls
)
if total_urls == 0:
logger.warning("目标下没有可用的站点URL无法执行站点扫描")
# 不抛出异常,由上层决定如何处理
# raise ValueError("目标下没有可用的站点URL无法执行站点扫描")
return export_result['output_file'], total_urls, association_count
return export_result['output_file'], total_urls
def _get_tool_timeout(tool_config: dict, urls_file: str) -> int:
"""获取工具超时时间(支持 'auto' 动态计算)"""
config_timeout = tool_config.get('timeout', 300)
if config_timeout == 'auto':
return _calculate_timeout_by_line_count(urls_file, base_per_time=1)
dynamic_timeout = _calculate_timeout_by_line_count(urls_file, base_per_time=1)
return max(dynamic_timeout, config_timeout)
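Illustrative values for the two branches above, assuming the urls file has 500 lines (so the dynamic estimate is max(500 * 1, 60) = 500):

# timeout: 'auto'  -> 500             (dynamic estimate returned directly)
# timeout: 300     -> max(500, 300) = 500
# timeout: 900     -> max(500, 900) = 900
# timeout missing  -> default 300, then max(500, 300) = 500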
def _execute_single_tool(
tool_name: str,
tool_config: dict,
ctx: ScanContext
) -> Optional[dict]:
"""
执行单个扫描工具
Returns:
成功返回结果字典,失败返回 None
"""
# 构建命令
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='site_scan',
command_params={'url_file': ctx.urls_file},
tool_config=tool_config
)
except (ValueError, KeyError) as e:
logger.error("构建 %s 命令失败: %s", tool_name, e)
return None
timeout = _get_tool_timeout(tool_config, ctx.urls_file)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = ctx.site_scan_dir / f"{tool_name}_{timestamp}.log"
logger.info(
"开始执行 %s 站点扫描 - URL数: %d, 超时: %ds",
tool_name, ctx.total_urls, timeout
)
user_log(ctx.scan_id, "site_scan", f"Running {tool_name}: {command}")
try:
result = run_and_stream_save_websites_task(
cmd=command,
tool_name=tool_name,
scan_id=ctx.scan_id,
target_id=ctx.target_id,
cwd=str(ctx.site_scan_dir),
shell=True,
timeout=timeout,
log_file=str(log_file)
)
tool_created = result.get('created_websites', 0)
skipped = result.get('skipped_no_subdomain', 0) + result.get('skipped_failed', 0)
logger.info(
"✓ 工具 %s 完成 - 处理: %d, 创建: %d, 跳过: %d",
tool_name, result.get('processed_records', 0), tool_created, skipped
)
user_log(
ctx.scan_id, "site_scan",
f"{tool_name} completed: found {tool_created} websites"
)
return {'command': command, 'result': result, 'timeout': timeout}
except subprocess.TimeoutExpired:
logger.warning(
"⚠️ 工具 %s 执行超时 - 超时配置: %d秒 (超时前数据已保存)",
tool_name, timeout
)
user_log(
ctx.scan_id, "site_scan",
f"{tool_name} failed: timeout after {timeout}s", "error"
)
except (OSError, RuntimeError) as exc:
logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True)
user_log(ctx.scan_id, "site_scan", f"{tool_name} failed: {exc}", "error")
return None
def _run_scans_sequentially(
enabled_tools: dict,
urls_file: str,
total_urls: int,
site_scan_dir: Path,
scan_id: int,
target_id: int,
target_name: str
ctx: ScanContext
) -> tuple[dict, int, list, list]:
"""
串行执行站点扫描任务
Args:
enabled_tools: 已启用的工具配置字典
urls_file: URL 文件路径
total_urls: URL 总数
site_scan_dir: 站点扫描目录
scan_id: 扫描任务 ID
target_id: 目标 ID
target_name: 目标名称(用于错误日志)
Returns:
tuple: (tool_stats, processed_records, successful_tool_names, failed_tools)
Raises:
RuntimeError: 所有工具均失败
tuple: (tool_stats, processed_records, successful_tools, failed_tools)
"""
tool_stats = {}
processed_records = 0
failed_tools = []
for tool_name, tool_config in enabled_tools.items():
# 1. 构建完整命令(变量替换)
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='site_scan',
command_params={
'url_file': urls_file
},
tool_config=tool_config
)
except Exception as e:
reason = f"命令构建失败: {str(e)}"
logger.error(f"构建 {tool_name} 命令失败: {e}")
failed_tools.append({'tool': tool_name, 'reason': reason})
continue
# 2. 获取超时时间(支持 'auto' 动态计算)
config_timeout = tool_config.get('timeout', 300)
if config_timeout == 'auto':
# 动态计算超时时间
timeout = calculate_timeout_by_line_count(tool_config, urls_file, base_per_time=1)
logger.info(f"✓ 工具 {tool_name} 动态计算 timeout: {timeout}")
result = _execute_single_tool(tool_name, tool_config, ctx)
if result:
tool_stats[tool_name] = result
processed_records += result['result'].get('processed_records', 0)
else:
# 使用配置的超时时间和动态计算的较大值
dynamic_timeout = calculate_timeout_by_line_count(tool_config, urls_file, base_per_time=1)
timeout = max(dynamic_timeout, config_timeout)
# 2.1 生成日志文件路径(类似端口扫描)
from datetime import datetime
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = site_scan_dir / f"{tool_name}_{timestamp}.log"
logger.info(
"开始执行 %s 站点扫描 - URL数: %d, 最终超时: %ds",
tool_name, total_urls, timeout
)
user_log(scan_id, "site_scan", f"Running {tool_name}: {command}")
# 3. 执行扫描任务
try:
# 流式执行扫描并实时保存结果
result = run_and_stream_save_websites_task(
cmd=command,
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
cwd=str(site_scan_dir),
shell=True,
timeout=timeout,
log_file=str(log_file)
)
tool_stats[tool_name] = {
'command': command,
'result': result,
'timeout': timeout
}
tool_records = result.get('processed_records', 0)
tool_created = result.get('created_websites', 0)
processed_records += tool_records
logger.info(
"✓ 工具 %s 流式处理完成 - 处理记录: %d, 创建站点: %d, 跳过: %d",
tool_name,
tool_records,
tool_created,
result.get('skipped_no_subdomain', 0) + result.get('skipped_failed', 0)
)
user_log(scan_id, "site_scan", f"{tool_name} completed: found {tool_created} websites")
except subprocess.TimeoutExpired as exc:
# 超时异常单独处理
reason = f"timeout after {timeout}s"
failed_tools.append({'tool': tool_name, 'reason': reason})
logger.warning(
"⚠️ 工具 %s 执行超时 - 超时配置: %d\n"
"注意:超时前已解析的站点数据已保存到数据库,但扫描未完全完成。",
tool_name, timeout
)
user_log(scan_id, "site_scan", f"{tool_name} failed: {reason}", "error")
except Exception as exc:
# 其他异常
reason = str(exc)
failed_tools.append({'tool': tool_name, 'reason': reason})
logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True)
user_log(scan_id, "site_scan", f"{tool_name} failed: {reason}", "error")
failed_tools.append({'tool': tool_name, 'reason': '执行失败'})
if failed_tools:
logger.warning(
"以下扫描工具执行失败: %s",
', '.join([f['tool'] for f in failed_tools])
', '.join(f['tool'] for f in failed_tools)
)
if not tool_stats:
error_details = "; ".join([f"{f['tool']}: {f['reason']}" for f in failed_tools])
logger.warning("所有站点扫描工具均失败 - 目标: %s, 失败工具: %s", target_name, error_details)
# 返回空结果,不抛出异常,让扫描继续
logger.warning(
"所有站点扫描工具均失败 - 目标: %s", ctx.target_name
)
return {}, 0, [], failed_tools
# 动态计算成功的工具列表
successful_tool_names = [name for name in enabled_tools.keys()
if name not in [f['tool'] for f in failed_tools]]
successful_tools = [
name for name in enabled_tools
if name not in {f['tool'] for f in failed_tools}
]
logger.info(
"串行站点扫描执行完成 - 成功: %d/%d (成功: %s, 失败: %s)",
len(tool_stats), len(enabled_tools),
', '.join(successful_tool_names) if successful_tool_names else '',
', '.join([f['tool'] for f in failed_tools]) if failed_tools else ''
"✓ 站点扫描执行完成 - 成功: %d/%d",
len(tool_stats), len(enabled_tools)
)
return tool_stats, processed_records, successful_tool_names, failed_tools
return tool_stats, processed_records, successful_tools, failed_tools
def calculate_timeout(url_count: int, base: int = 600, per_url: int = 1) -> int:
"""
根据 URL 数量动态计算扫描超时时间
规则:
- 基础时间:默认 600 秒10 分钟)
- 每个 URL 额外增加:默认 1 秒
Args:
url_count: URL 数量,必须为正整数
base: 基础超时时间(秒),默认 600
per_url: 每个 URL 增加的时间(秒),默认 1
Returns:
int: 计算得到的超时时间(秒),不超过 max_timeout
Raises:
ValueError: 当 url_count 为负数或 0 时抛出异常
"""
if url_count < 0:
raise ValueError(f"URL数量不能为负数: {url_count}")
if url_count == 0:
raise ValueError("URL数量不能为0")
timeout = base + int(url_count * per_url)
# 不设置上限,由调用方根据需要控制
return timeout
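A quick worked example of the rule above (base 600 seconds plus 1 second per URL, no upper bound):

# calculate_timeout(1)    -> 600 + 1    = 601
# calculate_timeout(1000) -> 600 + 1000 = 1600
# calculate_timeout(0)    -> raises ValueError ("URL数量不能为0")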
def _build_empty_result(
scan_id: int,
target_name: str,
scan_workspace_dir: str,
urls_file: str,
) -> dict:
"""构建空结果(无 URL 可扫描时)"""
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'scan_workspace_dir': scan_workspace_dir,
'urls_file': urls_file,
'total_urls': 0,
'processed_records': 0,
'created_websites': 0,
'skipped_no_subdomain': 0,
'skipped_failed': 0,
'executed_tasks': ['export_site_urls'],
'tool_stats': {
'total': 0,
'successful': 0,
'failed': 0,
'successful_tools': [],
'failed_tools': [],
'details': {}
}
}
@flow(
name="site_scan",
log_prints=True,
def _aggregate_tool_results(tool_stats: dict) -> tuple[int, int, int]:
"""汇总工具结果"""
total_created = sum(
s['result'].get('created_websites', 0) for s in tool_stats.values()
)
total_skipped_no_subdomain = sum(
s['result'].get('skipped_no_subdomain', 0) for s in tool_stats.values()
)
total_skipped_failed = sum(
s['result'].get('skipped_failed', 0) for s in tool_stats.values()
)
return total_created, total_skipped_no_subdomain, total_skipped_failed
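Illustrative input/output for the aggregation above (hypothetical tool names and counts):

# tool_stats = {
#     'tool_a': {'result': {'created_websites': 12, 'skipped_no_subdomain': 1, 'skipped_failed': 0}},
#     'tool_b': {'result': {'created_websites': 3,  'skipped_no_subdomain': 0, 'skipped_failed': 2}},
# }
# _aggregate_tool_results(tool_stats) -> (15, 1, 2)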
def _validate_flow_params(
scan_id: int,
target_id: int,
scan_workspace_dir: str
) -> None:
"""验证 Flow 参数"""
if scan_id is None:
raise ValueError("scan_id 不能为空")
if target_id is None:
raise ValueError("target_id 不能为空")
if not scan_workspace_dir:
raise ValueError("scan_workspace_dir 不能为空")
@scan_flow(
name="site_scan",
on_running=[on_scan_flow_running],
on_completion=[on_scan_flow_completed],
on_failure=[on_scan_flow_failed],
)
def site_scan_flow(
scan_id: int,
target_name: str,
target_id: int,
scan_workspace_dir: str,
enabled_tools: dict
enabled_tools: dict,
provider,
) -> dict:
"""
站点扫描 Flow
主要功能:
1. 从target获取所有子域名与其对应的端口号,拼接成URL写入文件
2. 用httpx进行批量请求并实时保存到数据库(流式处理)
工作流程:
Step 0: 创建工作目录
Step 1: 导出站点 URL 列表
Step 2: 解析配置,获取启用的工具
Step 3: 串行执行扫描工具并实时保存结果
Args:
scan_id: 扫描任务 ID
target_name: 目标名称
target_id: 目标 ID
scan_workspace_dir: 扫描工作空间目录
enabled_tools: 启用的工具配置字典
provider: TargetProvider 实例
Returns:
dict: {
'success': bool,
'scan_id': int,
'target': str,
'scan_workspace_dir': str,
'urls_file': str,
'total_urls': int,
'association_count': int,
'processed_records': int,
'created_websites': int,
'skipped_no_subdomain': int,
'skipped_failed': int,
'executed_tasks': list,
'tool_stats': {
'total': int,
'successful': int,
'failed': int,
'successful_tools': list[str],
'failed_tools': list[dict]
}
}
dict: 扫描结果
Raises:
ValueError: 配置错误
RuntimeError: 执行失败
"""
try:
logger.info(
"="*60 + "\n" +
"开始站点扫描\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
)
# 参数验证
if scan_id is None:
raise ValueError("scan_id 不能为空")
wait_for_system_load(context="site_scan_flow")
# 从 provider 获取 target_name
target_name = provider.get_target_name()
if not target_name:
raise ValueError("target_name 不能为空")
if target_id is None:
raise ValueError("target_id 不能为空")
if not scan_workspace_dir:
raise ValueError("scan_workspace_dir 不能为空")
raise ValueError("无法获取 Target 名称")
logger.info(
"开始站点扫描 - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
)
_validate_flow_params(scan_id, target_id, scan_workspace_dir)
user_log(scan_id, "site_scan", "Starting site scan")
# Step 0: 创建工作目录
from apps.scan.utils import setup_scan_directory
site_scan_dir = setup_scan_directory(scan_workspace_dir, 'site_scan')
# Step 1: 导出站点 URL
urls_file, total_urls, association_count = _export_site_urls(
target_id, site_scan_dir, target_name
urls_file, total_urls = _export_site_urls(
site_scan_dir, provider
)
if total_urls == 0:
logger.warning("跳过站点扫描:没有站点 URL 可扫描 - Scan ID: %s", scan_id)
user_log(scan_id, "site_scan", "Skipped: no site URLs to scan", "warning")
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'scan_workspace_dir': scan_workspace_dir,
'urls_file': urls_file,
'total_urls': 0,
'association_count': association_count,
'processed_records': 0,
'created_websites': 0,
'skipped_no_subdomain': 0,
'skipped_failed': 0,
'executed_tasks': ['export_site_urls'],
'tool_stats': {
'total': 0,
'successful': 0,
'failed': 0,
'successful_tools': [],
'failed_tools': [],
'details': {}
}
}
return _build_empty_result(
scan_id, target_name, scan_workspace_dir, urls_file
)
# Step 2: 工具配置信息
logger.info("Step 2: 工具配置信息")
logger.info(
"✓ 启用工具: %s",
', '.join(enabled_tools.keys())
)
logger.info("✓ 启用工具: %s", ', '.join(enabled_tools))
# Step 3: 串行执行扫描工具
logger.info("Step 3: 串行执行扫描工具并实时保存结果")
tool_stats, processed_records, successful_tool_names, failed_tools = _run_scans_sequentially(
enabled_tools=enabled_tools,
urls_file=urls_file,
total_urls=total_urls,
site_scan_dir=site_scan_dir,
ctx = ScanContext(
scan_id=scan_id,
target_id=target_id,
target_name=target_name
target_name=target_name,
site_scan_dir=site_scan_dir,
urls_file=urls_file,
total_urls=total_urls
)
# 动态生成已执行的任务列表
tool_stats, processed_records, successful_tools, failed_tools = \
_run_scans_sequentially(enabled_tools, ctx)
# 汇总结果
executed_tasks = ['export_site_urls', 'parse_config']
executed_tasks.extend([f'run_and_stream_save_websites ({tool})' for tool in tool_stats.keys()])
# 汇总所有工具的结果
total_created = sum(stats['result'].get('created_websites', 0) for stats in tool_stats.values())
total_skipped_no_subdomain = sum(stats['result'].get('skipped_no_subdomain', 0) for stats in tool_stats.values())
total_skipped_failed = sum(stats['result'].get('skipped_failed', 0) for stats in tool_stats.values())
# 记录 Flow 完成
executed_tasks.extend(
f'run_and_stream_save_websites ({tool})' for tool in tool_stats
)
total_created, total_skipped_no_sub, total_skipped_failed = \
_aggregate_tool_results(tool_stats)
logger.info("✓ 站点扫描完成 - 创建站点: %d", total_created)
user_log(scan_id, "site_scan", f"site_scan completed: found {total_created} websites")
user_log(
scan_id, "site_scan",
f"site_scan completed: found {total_created} websites"
)
return {
'success': True,
'scan_id': scan_id,
@@ -463,28 +416,22 @@ def site_scan_flow(
'scan_workspace_dir': scan_workspace_dir,
'urls_file': urls_file,
'total_urls': total_urls,
'association_count': association_count,
'processed_records': processed_records,
'created_websites': total_created,
'skipped_no_subdomain': total_skipped_no_subdomain,
'skipped_no_subdomain': total_skipped_no_sub,
'skipped_failed': total_skipped_failed,
'executed_tasks': executed_tasks,
'tool_stats': {
'total': len(enabled_tools),
'successful': len(successful_tool_names),
'successful': len(successful_tools),
'failed': len(failed_tools),
'successful_tools': successful_tool_names,
'successful_tools': successful_tools,
'failed_tools': failed_tools,
'details': tool_stats
}
}
except ValueError as e:
logger.error("配置错误: %s", e)
except ValueError:
raise
except RuntimeError as e:
logger.error("运行时错误: %s", e)
except RuntimeError:
raise
except Exception as e:
logger.exception("站点扫描失败: %s", e)
raise

File diff suppressed because it is too large Load Diff

View File

@@ -11,17 +11,14 @@
- IP 和 CIDR 类型会自动跳过(被动收集工具不支持)
"""
# Django 环境初始化
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging
import uuid
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
from typing import Dict
from prefect import flow
from apps.scan.decorators import scan_flow
from apps.common.validators import validate_domain
from apps.scan.tasks.url_fetch import run_url_fetcher_task
from apps.scan.utils import build_scan_command
@@ -30,13 +27,13 @@ from apps.scan.utils import build_scan_command
logger = logging.getLogger(__name__)
@flow(name="domain_name_url_fetch_flow", log_prints=True)
@scan_flow(name="domain_name_url_fetch_flow")
def domain_name_url_fetch_flow(
scan_id: int,
target_id: int,
target_name: str,
output_dir: str,
domain_name_tools: Dict[str, dict],
provider,
) -> dict:
"""
基于 Target 根域名执行 URL 被动收集(当前主要用于 waymore
@@ -46,35 +43,38 @@ def domain_name_url_fetch_flow(
2. 使用传入的工具列表对根域名执行被动收集
3. 工具内部会自动查询该域名及其子域名的历史 URL
4. 汇总结果文件列表
Args:
scan_id: 扫描 ID
target_id: 目标 ID
target_name: Target 根域名(如 example.com),不是子域名列表
output_dir: 输出目录
domain_name_tools: 被动收集工具配置(如 waymore
provider: TargetProvider 实例
注意:
- 此 Flow 只对 DOMAIN 类型 Target 有效
- IP 和 CIDR 类型会自动跳过(waymore 等工具不支持)
- 工具会自动收集 *.target_name 的所有历史 URL,无需遍历子域名
"""
from apps.scan.utils import user_log
try:
# 从 provider 获取 target_name
target_name = provider.get_target_name()
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# 检查 Target 类型IP/CIDR 类型跳过
from apps.targets.services import TargetService
from apps.targets.models import Target
target_service = TargetService()
target = target_service.get_target(target_id)
if target and target.type != Target.TargetType.DOMAIN:
logger.info(
"跳过 domain_name URL 获取: Target 类型为 %s (ID=%d, Name=%s)waymore 等工具仅适用于域名类型",
"跳过 domain_name URL 获取: Target 类型为 %s (ID=%d, Name=%s)",
target.type, target_id, target_name
)
return {
@@ -93,10 +93,10 @@ def domain_name_url_fetch_flow(
", ".join(domain_name_tools.keys()) if domain_name_tools else "",
)
futures: dict[str, object] = {}
tool_params = {} # 存储每个工具的参数
failed_tools: list[dict] = []
# 提交所有基于域名的 URL 获取任务
# 准备所有工具的参数
for tool_name, tool_config in domain_name_tools.items():
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
short_uuid = uuid.uuid4().hex[:4]
@@ -150,46 +150,62 @@ def domain_name_url_fetch_flow(
# 记录工具开始执行日志
user_log(scan_id, "url_fetch", f"Running {tool_name}: {command}")
future = run_url_fetcher_task.submit(
tool_name=tool_name,
command=command,
timeout=timeout,
output_file=output_file,
)
futures[tool_name] = future
tool_params[tool_name] = {
'command': command,
'timeout': timeout,
'output_file': output_file
}
result_files: list[str] = []
successful_tools: list[str] = []
# 收集执行结果
for tool_name, future in futures.items():
try:
result = future.result()
if result and result.get("success"):
result_files.append(result["output_file"])
successful_tools.append(tool_name)
url_count = result.get("url_count", 0)
logger.info(
"✓ 工具 %s 执行成功 - 发现 URL: %d",
tool_name,
url_count,
# 使用 ThreadPoolExecutor 并行执行
if tool_params:
with ThreadPoolExecutor(max_workers=len(tool_params)) as executor:
futures = {}
for tool_name, params in tool_params.items():
future = executor.submit(
run_url_fetcher_task,
tool_name=tool_name,
command=params['command'],
timeout=params['timeout'],
output_file=params['output_file'],
)
user_log(scan_id, "url_fetch", f"{tool_name} completed: found {url_count} urls")
else:
reason = "未生成结果或无有效 URL"
failed_tools.append(
{
"tool": tool_name,
"reason": reason,
}
)
logger.warning("⚠️ 工具 %s 未生成有效结果", tool_name)
user_log(scan_id, "url_fetch", f"{tool_name} failed: {reason}", "error")
except Exception as e:
reason = str(e)
failed_tools.append({"tool": tool_name, "reason": reason})
logger.warning("⚠️ 工具 %s 执行失败: %s", tool_name, e)
user_log(scan_id, "url_fetch", f"{tool_name} failed: {reason}", "error")
futures[tool_name] = future
# 收集执行结果
for tool_name, future in futures.items():
try:
result = future.result()
if result and result.get("success"):
result_files.append(result["output_file"])
successful_tools.append(tool_name)
url_count = result.get("url_count", 0)
logger.info(
"✓ 工具 %s 执行成功 - 发现 URL: %d",
tool_name,
url_count,
)
user_log(
scan_id, "url_fetch",
f"{tool_name} completed: found {url_count} urls"
)
else:
reason = "未生成结果或无有效 URL"
failed_tools.append({"tool": tool_name, "reason": reason})
logger.warning("⚠️ 工具 %s 未生成有效结果", tool_name)
user_log(
scan_id, "url_fetch",
f"{tool_name} failed: {reason}", "error"
)
except Exception as e:
reason = str(e)
failed_tools.append({"tool": tool_name, "reason": reason})
logger.warning("⚠️ 工具 %s 执行失败: %s", tool_name, e)
user_log(
scan_id, "url_fetch",
f"{tool_name} failed: {reason}", "error"
)
logger.info(
"基于 domain_name 的 URL 获取完成 - 成功工具: %s, 失败工具: %s",

View File

@@ -10,22 +10,17 @@ URL Fetch 主 Flow
- 统一进行 httpx 验证(如果启用)
"""
# Django 环境初始化
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging
import os
from pathlib import Path
from datetime import datetime
from pathlib import Path
from prefect import flow
from apps.scan.decorators import scan_flow
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
on_scan_flow_running,
)
from apps.scan.utils import user_log
from apps.scan.utils import user_log, wait_for_system_load
from .domain_name_url_fetch_flow import domain_name_url_fetch_flow
from .sites_url_fetch_flow import sites_url_fetch_flow
@@ -43,13 +38,10 @@ SITES_FILE_TOOLS = {'katana'}
POST_PROCESS_TOOLS = {'uro', 'httpx'}
def _classify_tools(enabled_tools: dict) -> tuple[dict, dict, dict, dict]:
"""
将启用的工具按输入类型分类
Returns:
tuple: (domain_name_tools, sites_file_tools, uro_config, httpx_config)
"""
@@ -76,23 +68,23 @@ def _classify_tools(enabled_tools: dict) -> tuple[dict, dict, dict, dict]:
def _merge_and_deduplicate_urls(result_files: list, url_fetch_dir: Path) -> tuple[str, int]:
"""合并并去重 URL"""
from apps.scan.tasks.url_fetch import merge_and_deduplicate_urls_task
merged_file = merge_and_deduplicate_urls_task(
result_files=result_files,
result_dir=str(url_fetch_dir)
)
# 统计唯一 URL 数量
unique_url_count = 0
if Path(merged_file).exists():
with open(merged_file, 'r') as f:
with open(merged_file, 'r', encoding='utf-8') as f:
unique_url_count = sum(1 for line in f if line.strip())
logger.info(
"✓ URL 合并去重完成 - 合并文件: %s, 唯一 URL 数: %d",
merged_file, unique_url_count
)
return merged_file, unique_url_count
@@ -103,12 +95,12 @@ def _clean_urls_with_uro(
) -> tuple[str, int, int]:
"""使用 uro 清理合并后的 URL 列表"""
from apps.scan.tasks.url_fetch import clean_urls_task
raw_timeout = uro_config.get('timeout', 60)
whitelist = uro_config.get('whitelist')
blacklist = uro_config.get('blacklist')
filters = uro_config.get('filters')
# 计算超时时间
if isinstance(raw_timeout, str) and raw_timeout == 'auto':
timeout = calculate_timeout_by_line_count(
@@ -124,7 +116,7 @@ def _clean_urls_with_uro(
except (TypeError, ValueError):
logger.warning("uro timeout 配置无效(%s),使用默认 60 秒", raw_timeout)
timeout = 60
result = clean_urls_task(
input_file=merged_file,
output_dir=str(url_fetch_dir),
@@ -133,12 +125,12 @@ def _clean_urls_with_uro(
blacklist=blacklist,
filters=filters
)
if result['success']:
return result['output_file'], result['output_count'], result['removed_count']
else:
logger.warning("uro 清理失败: %s,使用原始合并文件", result.get('error', '未知错误'))
return merged_file, result['input_count'], 0
logger.warning("uro 清理失败: %s,使用原始合并文件", result.get('error', '未知错误'))
return merged_file, result['input_count'], 0
def _validate_and_stream_save_urls(
@@ -151,25 +143,25 @@ def _validate_and_stream_save_urls(
"""使用 httpx 验证 URL 存活并流式保存到数据库"""
from apps.scan.utils import build_scan_command
from apps.scan.tasks.url_fetch import run_and_stream_save_urls_task
logger.info("开始使用 httpx 验证 URL 存活状态...")
# 统计待验证的 URL 数量
try:
with open(merged_file, 'r') as f:
with open(merged_file, 'r', encoding='utf-8') as f:
url_count = sum(1 for _ in f)
logger.info("待验证 URL 数量: %d", url_count)
except Exception as e:
except OSError as e:
logger.error("读取 URL 文件失败: %s", e)
return 0
if url_count == 0:
logger.warning("没有需要验证的 URL")
return 0
# 构建 httpx 命令
command_params = {'url_file': merged_file}
try:
command = build_scan_command(
tool_name='httpx',
@@ -177,21 +169,19 @@ def _validate_and_stream_save_urls(
command_params=command_params,
tool_config=httpx_config
)
except Exception as e:
except (ValueError, KeyError) as e:
logger.error("构建 httpx 命令失败: %s", e)
logger.warning("降级处理:将直接保存所有 URL不验证存活")
return _save_urls_to_database(merged_file, scan_id, target_id)
# 计算超时时间
raw_timeout = httpx_config.get('timeout', 'auto')
timeout = 3600
if isinstance(raw_timeout, str) and raw_timeout == 'auto':
# 按 URL 行数计算超时时间:每行 3 秒,最小 60 秒
timeout = max(60, url_count * 3)
logger.info(
"自动计算 httpx 超时时间(按行数,每行 3 秒,最小 60 秒): url_count=%d, timeout=%d",
url_count,
timeout,
url_count, timeout
)
else:
try:
@@ -199,108 +189,103 @@ def _validate_and_stream_save_urls(
except (TypeError, ValueError):
timeout = 3600
logger.info("使用配置的 httpx 超时时间: %d", timeout)
# 生成日志文件路径
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = url_fetch_dir / f"httpx_validation_{timestamp}.log"
# 流式执行
try:
result = run_and_stream_save_urls_task(
cmd=command,
tool_name='httpx',
scan_id=scan_id,
target_id=target_id,
cwd=str(url_fetch_dir),
shell=True,
timeout=timeout,
log_file=str(log_file)
)
saved = result.get('saved_urls', 0)
logger.info(
"✓ httpx 验证完成 - 存活 URL: %d (%.1f%%)",
saved, (saved / url_count * 100) if url_count > 0 else 0
)
return saved
except Exception as e:
logger.error("httpx 流式验证失败: %s", e, exc_info=True)
raise
result = run_and_stream_save_urls_task(
cmd=command,
tool_name='httpx',
scan_id=scan_id,
target_id=target_id,
cwd=str(url_fetch_dir),
shell=True,
timeout=timeout,
log_file=str(log_file)
)
saved = result.get('saved_urls', 0)
logger.info(
"✓ httpx 验证完成 - 存活 URL: %d (%.1f%%)",
saved, (saved / url_count * 100) if url_count > 0 else 0
)
return saved
def _save_urls_to_database(merged_file: str, scan_id: int, target_id: int) -> int:
"""保存 URL 到数据库(不验证存活)"""
from apps.scan.tasks.url_fetch import save_urls_task
result = save_urls_task(
urls_file=merged_file,
scan_id=scan_id,
target_id=target_id
)
saved_count = result.get('saved_urls', 0)
logger.info("✓ URL 保存完成 - 保存数量: %d", saved_count)
return saved_count
@flow(
@scan_flow(
name="url_fetch",
log_prints=True,
on_running=[on_scan_flow_running],
on_completion=[on_scan_flow_completed],
on_failure=[on_scan_flow_failed],
)
def url_fetch_flow(
scan_id: int,
target_name: str,
target_id: int,
scan_workspace_dir: str,
enabled_tools: dict
enabled_tools: dict,
provider,
) -> dict:
"""
URL 获取主 Flow
执行流程:
1. 准备工作目录
2. 按输入类型分类工具domain_name / sites_file / 后处理)
3. 并行执行子 Flow
- domain_name_url_fetch_flow: 基于 domain_name(来自 target_name)执行 URL 获取(如 waymore)
- domain_name_url_fetch_flow: 基于 domain_name(来自 provider)执行 URL 获取(如 waymore)
- sites_url_fetch_flow: 基于 sites_file 执行爬虫(如 katana 等)
4. 合并所有子 Flow 的结果并去重
5. uro 去重(如果启用)
6. httpx 验证(如果启用)
Args:
scan_id: 扫描 ID
target_name: 目标名称
target_id: 目标 ID
scan_workspace_dir: 扫描工作目录
enabled_tools: 启用的工具配置
provider: TargetProvider 实例
Returns:
dict: 扫描结果
"""
try:
# 负载检查:等待系统资源充足
wait_for_system_load(context="url_fetch_flow")
# 从 provider 获取 target_name
target_name = provider.get_target_name()
if not target_name:
raise ValueError("无法获取 Target 名称")
logger.info(
"="*60 + "\n" +
"开始 URL 获取扫描\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
"开始 URL 获取扫描 - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
)
user_log(scan_id, "url_fetch", "Starting URL fetch")
# Step 1: 准备工作目录
logger.info("Step 1: 准备工作目录")
from apps.scan.utils import setup_scan_directory
url_fetch_dir = setup_scan_directory(scan_workspace_dir, 'url_fetch')
# Step 2: 分类工具(按输入类型)
logger.info("Step 2: 分类工具")
domain_name_tools, sites_file_tools, uro_config, httpx_config = _classify_tools(enabled_tools)
logger.info(
@@ -317,45 +302,44 @@ def url_fetch_flow(
"URL Fetch 流程需要至少启用一个 URL 获取工具(如 waymore, katana"
"httpx 和 uro 仅用于后处理,不能单独使用。"
)
# Step 3: 并行执行子 Flow
# Step 3: 执行子 Flow
all_result_files = []
all_failed_tools = []
all_successful_tools = []
# 3a: 基于 domain_name(target_name)的 URL 被动收集(如 waymore)
# 3a: 基于 domain_name 的 URL 被动收集(如 waymore)
if domain_name_tools:
logger.info("Step 3a: 执行基于 domain_name 的 URL 被动收集子 Flow")
tn_result = domain_name_url_fetch_flow(
scan_id=scan_id,
target_id=target_id,
target_name=target_name,
output_dir=str(url_fetch_dir),
domain_name_tools=domain_name_tools,
provider=provider,
)
all_result_files.extend(tn_result.get('result_files', []))
all_failed_tools.extend(tn_result.get('failed_tools', []))
all_successful_tools.extend(tn_result.get('successful_tools', []))
# 3b: 爬虫(以 sites_file 为输入)
if sites_file_tools:
logger.info("Step 3b: 执行爬虫子 Flow")
crawl_result = sites_url_fetch_flow(
scan_id=scan_id,
target_id=target_id,
target_name=target_name,
output_dir=str(url_fetch_dir),
enabled_tools=sites_file_tools
enabled_tools=sites_file_tools,
provider=provider
)
all_result_files.extend(crawl_result.get('result_files', []))
all_failed_tools.extend(crawl_result.get('failed_tools', []))
all_successful_tools.extend(crawl_result.get('successful_tools', []))
# 检查是否有成功的工具
if not all_result_files:
error_details = "; ".join([f"{f['tool']}: {f['reason']}" for f in all_failed_tools])
error_details = "; ".join([
"%s: %s" % (f['tool'], f['reason']) for f in all_failed_tools
])
logger.warning("所有 URL 获取工具均失败 - 目标: %s, 失败详情: %s", target_name, error_details)
# 返回空结果,不抛出异常,让扫描继续
return {
'success': True,
'scan_id': scan_id,
@@ -366,31 +350,24 @@ def url_fetch_flow(
'successful_tools': [],
'message': '所有 URL 获取工具均无结果'
}
# Step 4: 合并并去重 URL
logger.info("Step 4: 合并并去重 URL")
merged_file, unique_url_count = _merge_and_deduplicate_urls(
merged_file, _ = _merge_and_deduplicate_urls(
result_files=all_result_files,
url_fetch_dir=url_fetch_dir
)
# Step 5: 使用 uro 清理 URL如果启用
url_file_for_validation = merged_file
uro_removed_count = 0
if uro_config and uro_config.get('enabled', False):
logger.info("Step 5: 使用 uro 清理 URL")
url_file_for_validation, cleaned_count, uro_removed_count = _clean_urls_with_uro(
url_file_for_validation, _, _ = _clean_urls_with_uro(
merged_file=merged_file,
uro_config=uro_config,
url_fetch_dir=url_fetch_dir
)
else:
logger.info("Step 5: 跳过 uro 清理(未启用)")
# Step 6: 使用 httpx 验证存活并保存(如果启用)
if httpx_config and httpx_config.get('enabled', False):
logger.info("Step 6: 使用 httpx 验证 URL 存活并流式保存")
saved_count = _validate_and_stream_save_urls(
merged_file=url_file_for_validation,
httpx_config=httpx_config,
@@ -399,17 +376,16 @@ def url_fetch_flow(
target_id=target_id
)
else:
logger.info("Step 6: 保存到数据库(未启用 httpx 验证)")
saved_count = _save_urls_to_database(
merged_file=url_file_for_validation,
scan_id=scan_id,
target_id=target_id
)
# 记录 Flow 完成
logger.info("✓ URL 获取完成 - 保存 endpoints: %d", saved_count)
user_log(scan_id, "url_fetch", f"url_fetch completed: found {saved_count} endpoints")
user_log(scan_id, "url_fetch", "url_fetch completed: found %d endpoints" % saved_count)
# 构建已执行的任务列表
executed_tasks = ['setup_directory', 'classify_tools']
if domain_name_tools:
@@ -423,7 +399,7 @@ def url_fetch_flow(
executed_tasks.append('httpx_validation_and_save')
else:
executed_tasks.append('save_urls')
return {
'success': True,
'scan_id': scan_id,
@@ -439,7 +415,7 @@ def url_fetch_flow(
'failed_tools': [f['tool'] for f in all_failed_tools]
}
}
except Exception as e:
logger.error("URL 获取扫描失败: %s", e, exc_info=True)
raise

View File

@@ -6,75 +6,69 @@ URL 爬虫 Flow
输入sites_file站点 URL 列表)
"""
# Django 环境初始化
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging
from pathlib import Path
from prefect import flow
from apps.scan.decorators import scan_flow
from .utils import run_tools_parallel
logger = logging.getLogger(__name__)
def _export_sites_file(target_id: int, scan_id: int, target_name: str, output_dir: Path) -> tuple[str, int]:
def _export_sites_file(
output_dir: Path,
provider,
) -> tuple[str, int]:
"""
导出站点 URL 列表到文件
懒加载模式:如果 WebSite 表为空,根据 Target 类型生成默认 URL
Args:
target_id: 目标 ID
scan_id: 扫描 ID
target_name: 目标名称(用于懒加载)
output_dir: 输出目录
provider: TargetProvider 实例
Returns:
tuple: (file_path, count)
"""
from apps.scan.tasks.url_fetch import export_sites_task
output_file = str(output_dir / "sites.txt")
result = export_sites_task(
output_file=output_file,
target_id=target_id,
scan_id=scan_id
provider=provider
)
count = result['asset_count']
if count > 0:
logger.info("✓ 站点列表导出完成 - 数量: %d", count)
else:
logger.warning("站点列表为空,爬虫可能无法正常工作")
return output_file, count
@flow(name="sites_url_fetch_flow", log_prints=True)
@scan_flow(name="sites_url_fetch_flow")
def sites_url_fetch_flow(
scan_id: int,
target_id: int,
target_name: str,
output_dir: str,
enabled_tools: dict
enabled_tools: dict,
provider,
) -> dict:
"""
URL 爬虫子 Flow
执行流程:
1. 导出站点 URL 列表sites_file
2. 并行执行爬虫工具
3. 返回结果文件列表
Args:
scan_id: 扫描 ID
target_id: 目标 ID
target_name: 目标名称
output_dir: 输出目录
enabled_tools: 启用的爬虫工具配置
provider: TargetProvider 实例
Returns:
dict: {
'success': bool,
@@ -85,21 +79,24 @@ def sites_url_fetch_flow(
}
"""
try:
# 从 provider 获取 target_name
target_name = provider.get_target_name()
if not target_name:
raise ValueError("无法获取 Target 名称")
output_path = Path(output_dir)
logger.info(
"开始 URL 爬虫 - Target: %s, Tools: %s",
target_name, ', '.join(enabled_tools.keys())
)
# Step 1: 导出站点 URL 列表
sites_file, sites_count = _export_sites_file(
target_id=target_id,
scan_id=scan_id,
target_name=target_name,
output_dir=output_path
output_dir=output_path,
provider=provider
)
# 默认值模式下,即使原本没有站点,也会有默认 URL 作为输入
if sites_count == 0:
logger.warning("没有可用的站点,跳过爬虫")
@@ -110,7 +107,7 @@ def sites_url_fetch_flow(
'successful_tools': [],
'sites_count': 0
}
# Step 2: 并行执行爬虫工具
result_files, failed_tools, successful_tools = run_tools_parallel(
tools=enabled_tools,
@@ -119,12 +116,12 @@ def sites_url_fetch_flow(
output_dir=output_path,
scan_id=scan_id
)
logger.info(
"✓ 爬虫完成 - 成功: %d/%d, 结果文件: %d",
len(successful_tools), len(enabled_tools), len(result_files)
)
return {
'success': True,
'result_files': result_files,
@@ -132,7 +129,7 @@ def sites_url_fetch_flow(
'successful_tools': successful_tools,
'sites_count': sites_count
}
except Exception as e:
logger.error("URL 爬虫失败: %s", e, exc_info=True)
return {

View File

@@ -5,6 +5,7 @@ URL Fetch 共享工具函数
import logging
import subprocess
import uuid
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
@@ -21,13 +22,13 @@ def calculate_timeout_by_line_count(
) -> int:
"""
根据文件行数自动计算超时时间
Args:
tool_config: 工具配置(保留参数,未来可能用于更复杂的计算)
file_path: 输入文件路径
base_per_time: 每行的基础时间(秒)
min_timeout: 最小超时时间默认60秒
Returns:
int: 计算出的超时时间(秒),不低于 min_timeout
"""
@@ -64,7 +65,7 @@ def prepare_tool_execution(
) -> dict:
"""
准备单个工具的执行参数
Args:
tool_name: 工具名称
tool_config: 工具配置
@@ -72,7 +73,7 @@ def prepare_tool_execution(
input_type: 输入类型domains_file 或 sites_file
output_dir: 输出目录
scan_type: 扫描类型
Returns:
dict: 执行参数,包含 command, input_file, output_file, timeout
或包含 error 键表示失败
@@ -110,7 +111,7 @@ def prepare_tool_execution(
# 4. 计算超时时间(支持 auto 和显式整数)
raw_timeout = tool_config.get("timeout", 3600)
timeout = 3600
if isinstance(raw_timeout, str) and raw_timeout == "auto":
try:
# katana / waymore 每个站点需要更长时间
@@ -157,24 +158,24 @@ def run_tools_parallel(
) -> tuple[list, list, list]:
"""
并行执行工具列表
Args:
tools: 工具配置字典 {tool_name: tool_config}
input_file: 输入文件路径
input_type: 输入类型
output_dir: 输出目录
scan_id: 扫描任务 ID用于记录日志
Returns:
tuple: (result_files, failed_tools, successful_tool_names)
"""
from apps.scan.tasks.url_fetch import run_url_fetcher_task
from apps.scan.utils import user_log
futures: dict[str, object] = {}
tool_params = {} # 存储每个工具的参数
failed_tools: list[dict] = []
# 提交所有工具的并行任务
# 准备所有工具的参数
for tool_name, tool_config in tools.items():
exec_params = prepare_tool_execution(
tool_name=tool_name,
@@ -198,44 +199,54 @@ def run_tools_parallel(
# 记录工具开始执行日志
user_log(scan_id, "url_fetch", f"Running {tool_name}: {exec_params['command']}")
# 提交并行任务
future = run_url_fetcher_task.submit(
tool_name=tool_name,
command=exec_params["command"],
timeout=exec_params["timeout"],
output_file=exec_params["output_file"],
)
futures[tool_name] = future
tool_params[tool_name] = exec_params
# 收集执行结果
# 使用 ThreadPoolExecutor 并行执行
result_files = []
for tool_name, future in futures.items():
try:
result = future.result()
if result and result['success']:
result_files.append(result['output_file'])
url_count = result['url_count']
logger.info(
"✓ 工具 %s 执行成功 - 发现 URL: %d",
tool_name, url_count
if tool_params:
with ThreadPoolExecutor(max_workers=len(tool_params)) as executor:
futures = {}
for tool_name, params in tool_params.items():
future = executor.submit(
run_url_fetcher_task,
tool_name=tool_name,
command=params["command"],
timeout=params["timeout"],
output_file=params["output_file"],
)
user_log(scan_id, "url_fetch", f"{tool_name} completed: found {url_count} urls")
else:
reason = '未生成结果或无有效URL'
failed_tools.append({
'tool': tool_name,
'reason': reason
})
logger.warning("⚠️ 工具 %s 未生成有效结果", tool_name)
user_log(scan_id, "url_fetch", f"{tool_name} failed: {reason}", "error")
except Exception as e:
reason = str(e)
failed_tools.append({
'tool': tool_name,
'reason': reason
})
logger.warning("⚠️ 工具 %s 执行失败: %s", tool_name, e)
user_log(scan_id, "url_fetch", f"{tool_name} failed: {reason}", "error")
futures[tool_name] = future
# 收集执行结果
for tool_name, future in futures.items():
try:
result = future.result()
if result and result['success']:
result_files.append(result['output_file'])
url_count = result['url_count']
logger.info(
"✓ 工具 %s 执行成功 - 发现 URL: %d",
tool_name, url_count
)
user_log(
scan_id, "url_fetch",
f"{tool_name} completed: found {url_count} urls"
)
else:
reason = '未生成结果或无有效URL'
failed_tools.append({'tool': tool_name, 'reason': reason})
logger.warning("⚠️ 工具 %s 未生成有效结果", tool_name)
user_log(
scan_id, "url_fetch",
f"{tool_name} failed: {reason}", "error"
)
except Exception as e:
reason = str(e)
failed_tools.append({'tool': tool_name, 'reason': reason})
logger.warning("⚠️ 工具 %s 执行失败: %s", tool_name, e)
user_log(
scan_id, "url_fetch",
f"{tool_name} failed: {reason}", "error"
)
# 计算成功的工具列表
failed_tool_names = [f['tool'] for f in failed_tools]

View File

@@ -1,17 +1,13 @@
from apps.common.prefect_django_setup import setup_django_for_prefect
"""
基于 Endpoint 的漏洞扫描 Flow
"""
import logging
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
from typing import Dict
from prefect import flow
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
)
from apps.scan.decorators import scan_flow
from apps.scan.utils import build_scan_command, ensure_nuclei_templates_local, user_log
from apps.scan.tasks.vuln_scan import (
export_endpoints_task,
@@ -25,26 +21,23 @@ from .utils import calculate_timeout_by_line_count
logger = logging.getLogger(__name__)
@flow(
name="endpoints_vuln_scan_flow",
log_prints=True,
)
@scan_flow(name="endpoints_vuln_scan_flow")
def endpoints_vuln_scan_flow(
scan_id: int,
target_name: str,
target_id: int,
scan_workspace_dir: str,
enabled_tools: Dict[str, dict],
provider,
) -> dict:
"""基于 Endpoint 的漏洞扫描 Flow串行执行 Dalfox 等工具)。"""
try:
# 从 provider 获取 target_name
target_name = provider.get_target_name()
if not target_name:
raise ValueError("无法获取 Target 名称")
if scan_id is None:
raise ValueError("scan_id 不能为空")
if not target_name:
raise ValueError("target_name 不能为空")
if target_id is None:
raise ValueError("target_id 不能为空")
if not scan_workspace_dir:
@@ -58,8 +51,8 @@ def endpoints_vuln_scan_flow(
# Step 1: 导出 Endpoint URL
export_result = export_endpoints_task(
target_id=target_id,
output_file=str(endpoints_file),
provider=provider,
)
total_endpoints = export_result.get("total_count", 0)
@@ -79,12 +72,9 @@ def endpoints_vuln_scan_flow(
logger.info("Endpoint 导出完成,共 %d 条,开始执行漏洞扫描", total_endpoints)
tool_results: Dict[str, dict] = {}
tool_params: Dict[str, dict] = {} # 存储每个工具的参数
# Step 2: 并行执行每个漏洞扫描工具(目前主要是 Dalfox)
# 1)先为每个工具 submit Prefect Task,让 Worker 并行调度
# 2)再统一收集各自的结果,组装成 tool_results
tool_futures: Dict[str, dict] = {}
# Step 2: 准备每个漏洞扫描工具的参数
for tool_name, tool_config in enabled_tools.items():
# Nuclei 需要先确保本地模板存在(支持多个模板仓库)
template_args = ""
@@ -104,8 +94,11 @@ def endpoints_vuln_scan_flow(
continue
template_args = " ".join(f"-t {p}" for p in template_paths)
# 构建命令参数
command_params = {"endpoints_file": str(endpoints_file)}
# 构建命令参数(根据工具模板使用不同的参数名)
if tool_name == "nuclei":
command_params = {"input_file": str(endpoints_file)}
else:
command_params = {"endpoints_file": str(endpoints_file)}
if template_args:
command_params["template_args"] = template_args
@@ -138,102 +131,105 @@ def endpoints_vuln_scan_flow(
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_file = vuln_scan_dir / f"{tool_name}_{timestamp}.log"
# Dalfox XSS 使用流式任务,一边解析一边保存漏洞结果
logger.info("开始执行漏洞扫描工具 %s", tool_name)
user_log(scan_id, "vuln_scan", f"Running {tool_name}: {command}")
# 确定工具类型
if tool_name == "dalfox_xss":
logger.info("开始执行漏洞扫描工具 %s(流式保存漏洞结果,已提交任务)", tool_name)
user_log(scan_id, "vuln_scan", f"Running {tool_name}: {command}")
future = run_and_stream_save_dalfox_vulns_task.submit(
cmd=command,
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
cwd=str(vuln_scan_dir),
shell=True,
batch_size=1,
timeout=timeout,
log_file=str(log_file),
)
tool_futures[tool_name] = {
"future": future,
"command": command,
"timeout": timeout,
"log_file": str(log_file),
"mode": "streaming",
}
mode = "dalfox"
elif tool_name == "nuclei":
# Nuclei 使用流式任务
logger.info("开始执行漏洞扫描工具 %s(流式保存漏洞结果,已提交任务)", tool_name)
user_log(scan_id, "vuln_scan", f"Running {tool_name}: {command}")
future = run_and_stream_save_nuclei_vulns_task.submit(
cmd=command,
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
cwd=str(vuln_scan_dir),
shell=True,
batch_size=1,
timeout=timeout,
log_file=str(log_file),
)
tool_futures[tool_name] = {
"future": future,
"command": command,
"timeout": timeout,
"log_file": str(log_file),
"mode": "streaming",
}
mode = "nuclei"
else:
# 其他工具仍使用非流式执行逻辑
logger.info("开始执行漏洞扫描工具 %s(已提交任务)", tool_name)
user_log(scan_id, "vuln_scan", f"Running {tool_name}: {command}")
future = run_vuln_tool_task.submit(
tool_name=tool_name,
command=command,
timeout=timeout,
log_file=str(log_file),
)
mode = "normal"
tool_futures[tool_name] = {
"future": future,
"command": command,
"timeout": timeout,
"log_file": str(log_file),
"mode": "normal",
}
tool_params[tool_name] = {
"command": command,
"timeout": timeout,
"log_file": str(log_file),
"mode": mode,
}
# 统一收集所有工具的执行结果
for tool_name, meta in tool_futures.items():
future = meta["future"]
try:
result = future.result()
# Step 3: 使用 ThreadPoolExecutor 并行执行
if tool_params:
with ThreadPoolExecutor(max_workers=len(tool_params)) as executor:
futures = {}
for tool_name, params in tool_params.items():
if params["mode"] == "dalfox":
future = executor.submit(
run_and_stream_save_dalfox_vulns_task,
cmd=params["command"],
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
cwd=str(vuln_scan_dir),
shell=True,
batch_size=1,
timeout=params["timeout"],
log_file=params["log_file"],
)
elif params["mode"] == "nuclei":
future = executor.submit(
run_and_stream_save_nuclei_vulns_task,
cmd=params["command"],
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
cwd=str(vuln_scan_dir),
shell=True,
batch_size=1,
timeout=params["timeout"],
log_file=params["log_file"],
)
else:
future = executor.submit(
run_vuln_tool_task,
tool_name=tool_name,
command=params["command"],
timeout=params["timeout"],
log_file=params["log_file"],
)
futures[tool_name] = future
if meta["mode"] == "streaming":
created_vulns = result.get("created_vulns", 0)
tool_results[tool_name] = {
"command": meta["command"],
"timeout": meta["timeout"],
"processed_records": result.get("processed_records"),
"created_vulns": created_vulns,
"command_log_file": meta["log_file"],
}
logger.info("✓ 工具 %s 执行完成 - 漏洞: %d", tool_name, created_vulns)
user_log(scan_id, "vuln_scan", f"{tool_name} completed: found {created_vulns} vulnerabilities")
else:
tool_results[tool_name] = {
"command": meta["command"],
"timeout": meta["timeout"],
"duration": result.get("duration"),
"returncode": result.get("returncode"),
"command_log_file": result.get("command_log_file"),
}
logger.info("✓ 工具 %s 执行完成 - returncode=%s", tool_name, result.get("returncode"))
user_log(scan_id, "vuln_scan", f"{tool_name} completed")
except Exception as e:
reason = str(e)
logger.error("工具 %s 执行失败: %s", tool_name, e, exc_info=True)
user_log(scan_id, "vuln_scan", f"{tool_name} failed: {reason}", "error")
# 收集结果
for tool_name, future in futures.items():
params = tool_params[tool_name]
try:
result = future.result()
if params["mode"] in ("dalfox", "nuclei"):
created_vulns = result.get("created_vulns", 0)
tool_results[tool_name] = {
"command": params["command"],
"timeout": params["timeout"],
"processed_records": result.get("processed_records"),
"created_vulns": created_vulns,
"command_log_file": params["log_file"],
}
logger.info(
"✓ 工具 %s 执行完成 - 漏洞: %d",
tool_name, created_vulns
)
user_log(
scan_id, "vuln_scan",
f"{tool_name} completed: found {created_vulns} vulnerabilities"
)
else:
tool_results[tool_name] = {
"command": params["command"],
"timeout": params["timeout"],
"duration": result.get("duration"),
"returncode": result.get("returncode"),
"command_log_file": result.get("command_log_file"),
}
logger.info(
"✓ 工具 %s 执行完成 - returncode=%s",
tool_name, result.get("returncode")
)
user_log(scan_id, "vuln_scan", f"{tool_name} completed")
except Exception as e:
reason = str(e)
logger.error("工具 %s 执行失败: %s", tool_name, e, exc_info=True)
user_log(scan_id, "vuln_scan", f"{tool_name} failed: {reason}", "error")
return {
"success": True,

View File

@@ -1,71 +1,92 @@
from apps.common.prefect_django_setup import setup_django_for_prefect
"""
漏洞扫描主 Flow
"""
import logging
from typing import Dict, Tuple
from prefect import flow
from apps.scan.decorators import scan_flow
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
)
from apps.scan.configs.command_templates import get_command_template
from apps.scan.utils import user_log
from apps.scan.utils import user_log, wait_for_system_load
from .endpoints_vuln_scan_flow import endpoints_vuln_scan_flow
from .websites_vuln_scan_flow import websites_vuln_scan_flow
logger = logging.getLogger(__name__)
def _classify_vuln_tools(enabled_tools: Dict[str, dict]) -> Tuple[Dict[str, dict], Dict[str, dict]]:
"""根据命令模板中的 input_type 对漏洞扫描工具进行分类。
def _classify_vuln_tools(
enabled_tools: Dict[str, dict]
) -> Tuple[Dict[str, dict], Dict[str, dict], Dict[str, dict]]:
"""根据用户配置分类漏洞扫描工具。
当前支持:
- endpoints_file: 以端点列表文件为输入(例如 Dalfox XSS)
预留:
- 其他 input_type 将被归类到 other_tools,暂不处理。
分类逻辑:
- 读取 scan_endpoints / scan_websites 配置
- 默认值从模板的 defaults 或 input_type 推断
Returns:
(endpoints_tools, websites_tools, other_tools) 三元组
"""
endpoints_tools: Dict[str, dict] = {}
websites_tools: Dict[str, dict] = {}
other_tools: Dict[str, dict] = {}
for tool_name, tool_config in enabled_tools.items():
template = get_command_template("vuln_scan", tool_name) or {}
input_type = template.get("input_type", "endpoints_file")
defaults = template.get("defaults", {})
if input_type == "endpoints_file":
# 根据 input_type 推断默认值(兼容老工具)
input_type = template.get("input_type")
default_endpoints = defaults.get("scan_endpoints", input_type == "endpoints_file")
default_websites = defaults.get("scan_websites", input_type == "websites_file")
scan_endpoints = tool_config.get("scan_endpoints", default_endpoints)
scan_websites = tool_config.get("scan_websites", default_websites)
if scan_endpoints:
endpoints_tools[tool_name] = tool_config
else:
if scan_websites:
websites_tools[tool_name] = tool_config
if not scan_endpoints and not scan_websites:
other_tools[tool_name] = tool_config
return endpoints_tools, other_tools
return endpoints_tools, websites_tools, other_tools
@flow(
@scan_flow(
name="vuln_scan",
log_prints=True,
on_running=[on_scan_flow_running],
on_completion=[on_scan_flow_completed],
on_failure=[on_scan_flow_failed],
)
def vuln_scan_flow(
scan_id: int,
target_name: str,
target_id: int,
scan_workspace_dir: str,
enabled_tools: Dict[str, dict],
provider,
) -> dict:
"""漏洞扫描主 Flow串行编排各类漏洞扫描子 Flow。
支持工具:
- dalfox_xss: XSS 漏洞扫描(流式保存)
- nuclei: 通用漏洞扫描(流式保存,支持模板 commit hash 同步)
- nuclei: 通用漏洞扫描(流式保存,支持 endpoints 和 websites 两种输入)
"""
try:
# 负载检查:等待系统资源充足
wait_for_system_load(context="vuln_scan_flow")
# 从 provider 获取 target_name
target_name = provider.get_target_name()
if not target_name:
raise ValueError("无法获取 Target 名称")
if scan_id is None:
raise ValueError("scan_id 不能为空")
if not target_name:
raise ValueError("target_name 不能为空")
if target_id is None:
raise ValueError("target_id 不能为空")
if not scan_workspace_dir:
@@ -77,11 +98,12 @@ def vuln_scan_flow(
user_log(scan_id, "vuln_scan", "Starting vulnerability scan")
# Step 1: 分类工具
endpoints_tools, other_tools = _classify_vuln_tools(enabled_tools)
endpoints_tools, websites_tools, other_tools = _classify_vuln_tools(enabled_tools)
logger.info(
"漏洞扫描工具分类 - endpoints_file: %s, 其他: %s",
"漏洞扫描工具分类 - endpoints: %s, websites: %s, 其他: %s",
list(endpoints_tools.keys()) or "",
list(websites_tools.keys()) or "",
list(other_tools.keys()) or "",
)
@@ -91,28 +113,58 @@ def vuln_scan_flow(
list(other_tools.keys()),
)
if not endpoints_tools:
raise ValueError("漏洞扫描需要至少启用一个以 endpoints_file 为输入的工具(如 dalfox_xss、nuclei)")
if not endpoints_tools and not websites_tools:
raise ValueError(
"漏洞扫描需要至少启用一个工具(endpoints 或 websites 模式)"
)
# Step 2: 执行 Endpoint 漏洞扫描子 Flow(串行)
endpoint_result = endpoints_vuln_scan_flow(
scan_id=scan_id,
target_name=target_name,
target_id=target_id,
scan_workspace_dir=scan_workspace_dir,
enabled_tools=endpoints_tools,
)
total_vulns = 0
results = {}
# Step 2: 执行 Endpoint 漏洞扫描子 Flow
if endpoints_tools:
logger.info("执行 Endpoint 漏洞扫描 - 工具: %s", list(endpoints_tools.keys()))
endpoint_result = endpoints_vuln_scan_flow(
scan_id=scan_id,
target_id=target_id,
scan_workspace_dir=scan_workspace_dir,
enabled_tools=endpoints_tools,
provider=provider,
)
results["endpoints"] = endpoint_result
total_vulns += sum(
r.get("created_vulns", 0)
for r in endpoint_result.get("tool_results", {}).values()
)
# Step 3: 执行 WebSite 漏洞扫描子 Flow
if websites_tools:
logger.info("执行 WebSite 漏洞扫描 - 工具: %s", list(websites_tools.keys()))
website_result = websites_vuln_scan_flow(
scan_id=scan_id,
target_id=target_id,
scan_workspace_dir=scan_workspace_dir,
enabled_tools=websites_tools,
provider=provider,
)
results["websites"] = website_result
total_vulns += sum(
r.get("created_vulns", 0)
for r in website_result.get("tool_results", {}).values()
)
# 记录 Flow 完成
total_vulns = sum(
r.get("created_vulns", 0)
for r in endpoint_result.get("tool_results", {}).values()
)
logger.info("✓ 漏洞扫描完成 - 新增漏洞: %d", total_vulns)
user_log(scan_id, "vuln_scan", f"vuln_scan completed: found {total_vulns} vulnerabilities")
# 目前只有一个子 Flow直接返回其结果
return endpoint_result
return {
"success": True,
"scan_id": scan_id,
"target": target_name,
"scan_workspace_dir": scan_workspace_dir,
"total_vulns": total_vulns,
"sub_flow_results": results,
}
except Exception as e:
logger.exception("漏洞扫描主 Flow 失败: %s", e)
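To make the three-way split concrete, here is a self-contained sketch of the classification rule described in _classify_vuln_tools; the template table and the sample tool configs below are hypothetical stand-ins for get_command_template(), not the project's real templates.

from typing import Dict, Tuple

TEMPLATES = {
    "dalfox_xss": {"input_type": "endpoints_file"},
    "nuclei": {"input_type": "endpoints_file", "defaults": {"scan_websites": True}},
}

def classify(enabled_tools: Dict[str, dict]) -> Tuple[dict, dict, dict]:
    endpoints_tools, websites_tools, other_tools = {}, {}, {}
    for name, cfg in enabled_tools.items():
        template = TEMPLATES.get(name, {})
        defaults = template.get("defaults", {})
        input_type = template.get("input_type")
        # Defaults fall back to the template's input_type for older tools.
        scan_endpoints = cfg.get("scan_endpoints", defaults.get("scan_endpoints", input_type == "endpoints_file"))
        scan_websites = cfg.get("scan_websites", defaults.get("scan_websites", input_type == "websites_file"))
        if scan_endpoints:
            endpoints_tools[name] = cfg
        if scan_websites:
            websites_tools[name] = cfg
        if not scan_endpoints and not scan_websites:
            other_tools[name] = cfg
    return endpoints_tools, websites_tools, other_tools

# A tool may land in both buckets, e.g. nuclei scanning endpoints and websites.
print(classify({"dalfox_xss": {}, "nuclei": {"scan_endpoints": True}}))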

View File

@@ -0,0 +1,199 @@
"""
基于 WebSite 的漏洞扫描 Flow
与 endpoints_vuln_scan_flow 类似,但数据源是 WebSite 而不是 Endpoint。
主要用于 nuclei 扫描已存活的网站。
"""
import logging
from datetime import datetime
from typing import Dict
from concurrent.futures import ThreadPoolExecutor
from apps.scan.decorators import scan_flow
from apps.scan.utils import build_scan_command, ensure_nuclei_templates_local, user_log
from apps.scan.tasks.vuln_scan import run_and_stream_save_nuclei_vulns_task
from apps.scan.tasks.vuln_scan.export_websites_task import export_websites_task
from .utils import calculate_timeout_by_line_count
logger = logging.getLogger(__name__)
@scan_flow(name="websites_vuln_scan_flow")
def websites_vuln_scan_flow(
scan_id: int,
target_id: int,
scan_workspace_dir: str,
enabled_tools: Dict[str, dict],
provider,
) -> dict:
"""基于 WebSite 的漏洞扫描 Flow(主要用于 nuclei)"""
try:
target_name = provider.get_target_name()
if not target_name:
raise ValueError("无法获取 Target 名称")
if scan_id is None:
raise ValueError("scan_id 不能为空")
if target_id is None:
raise ValueError("target_id 不能为空")
if not scan_workspace_dir:
raise ValueError("scan_workspace_dir 不能为空")
if not enabled_tools:
raise ValueError("enabled_tools 不能为空")
from apps.scan.utils import setup_scan_directory
vuln_scan_dir = setup_scan_directory(scan_workspace_dir, 'vuln_scan')
websites_file = vuln_scan_dir / "input_websites.txt"
# Step 1: 导出 WebSite URL
export_result = export_websites_task(
output_file=str(websites_file),
provider=provider,
)
total_websites = export_result.get("total_count", 0)
if total_websites == 0:
logger.warning("目标下没有可用 WebSite,跳过漏洞扫描")
return {
"success": True,
"scan_id": scan_id,
"target": target_name,
"scan_workspace_dir": scan_workspace_dir,
"websites_file": str(websites_file),
"website_count": 0,
"executed_tools": [],
"tool_results": {},
}
logger.info("WebSite 导出完成,共 %d 条,开始执行漏洞扫描", total_websites)
tool_results: Dict[str, dict] = {}
tool_futures: Dict[str, dict] = {}
# Step 2: 执行漏洞扫描工具
for tool_name, tool_config in enabled_tools.items():
# 目前只支持 nuclei
if tool_name != "nuclei":
logger.warning("websites_vuln_scan_flow 暂不支持工具: %s", tool_name)
continue
# 确保 nuclei 模板存在
repo_names = tool_config.get("template_repo_names")
if not repo_names or not isinstance(repo_names, (list, tuple)):
logger.error("Nuclei 配置缺少 template_repo_names(数组),跳过")
continue
template_paths = []
try:
for repo_name in repo_names:
path = ensure_nuclei_templates_local(repo_name)
template_paths.append(path)
logger.info("Nuclei 模板路径 [%s]: %s", repo_name, path)
except Exception as e:
logger.error("获取 Nuclei 模板失败: %s,跳过 nuclei 扫描", e)
continue
template_args = " ".join(f"-t {p}" for p in template_paths)
# 构建命令(使用 websites_file 作为输入)
command_params = {
"input_file": str(websites_file),
"template_args": template_args,
}
command = build_scan_command(
tool_name=tool_name,
scan_type="vuln_scan",
command_params=command_params,
tool_config=tool_config,
)
# 计算超时时间
raw_timeout = tool_config.get("timeout", 600)
if isinstance(raw_timeout, str) and raw_timeout == "auto":
timeout = calculate_timeout_by_line_count(
tool_config=tool_config,
file_path=str(websites_file),
base_per_time=30,
)
else:
try:
timeout = int(raw_timeout)
except (TypeError, ValueError) as e:
raise ValueError(
f"工具 {tool_name} 的 timeout 配置无效: {raw_timeout!r}"
) from e
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_file = vuln_scan_dir / f"{tool_name}_websites_{timestamp}.log"
logger.info("开始执行 %s 漏洞扫描(WebSite 模式)", tool_name)
user_log(scan_id, "vuln_scan", f"Running {tool_name} (websites): {command}")
tool_futures[tool_name] = {
"command": command,
"timeout": timeout,
"log_file": str(log_file),
}
# 使用 ThreadPoolExecutor 并行执行
if tool_futures:
with ThreadPoolExecutor(max_workers=len(tool_futures)) as executor:
futures = {}
for tool_name, meta in tool_futures.items():
future = executor.submit(
run_and_stream_save_nuclei_vulns_task,
cmd=meta["command"],
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
cwd=str(vuln_scan_dir),
shell=True,
batch_size=1,
timeout=meta["timeout"],
log_file=meta["log_file"],
)
futures[tool_name] = future
# 收集结果
for tool_name, future in futures.items():
meta = tool_futures[tool_name]
try:
result = future.result()
created_vulns = result.get("created_vulns", 0)
tool_results[tool_name] = {
"command": meta["command"],
"timeout": meta["timeout"],
"processed_records": result.get("processed_records"),
"created_vulns": created_vulns,
"command_log_file": meta["log_file"],
}
logger.info(
"✓ 工具 %s (websites) 执行完成 - 漏洞: %d",
tool_name, created_vulns
)
user_log(
scan_id, "vuln_scan",
f"{tool_name} (websites) completed: found {created_vulns} vulnerabilities"
)
except Exception as e:
reason = str(e)
logger.error("工具 %s 执行失败: %s", tool_name, e, exc_info=True)
user_log(scan_id, "vuln_scan", f"{tool_name} failed: {reason}", "error")
return {
"success": True,
"scan_id": scan_id,
"target": target_name,
"scan_workspace_dir": scan_workspace_dir,
"websites_file": str(websites_file),
"website_count": total_websites,
"executed_tools": list(enabled_tools.keys()),
"tool_results": tool_results,
}
except Exception as e:
logger.exception("WebSite 漏洞扫描失败: %s", e)
raise
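The timeout: "auto" branch above delegates to calculate_timeout_by_line_count, whose implementation is not part of this diff; the helper below is only a guess at its general shape (input line count times a per-line budget, with a floor), included for orientation.

from pathlib import Path

def estimate_timeout(file_path: str, base_per_time: int = 30, floor: int = 600) -> int:
    """Hypothetical sketch: scale the timeout with the number of input lines, never below a floor."""
    try:
        lines = sum(1 for line in Path(file_path).read_text().splitlines() if line.strip())
    except OSError:
        return floor
    return max(floor, lines * base_per_time)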

View File

@@ -12,57 +12,49 @@ initiate_scan_flow 状态处理器
"""
import logging
from prefect import Flow
from prefect.client.schemas import FlowRun, State
from apps.scan.decorators import FlowContext
logger = logging.getLogger(__name__)
def on_initiate_scan_flow_running(flow: Flow, flow_run: FlowRun, state: State) -> None:
def on_initiate_scan_flow_running(context: FlowContext) -> None:
"""
initiate_scan_flow 开始运行时的回调
职责:更新 Scan 状态为 RUNNING + 发送通知
触发时机:
- Prefect Flow 状态变为 Running 时自动触发
- 在 Flow 函数体执行之前调用
Args:
flow: Prefect Flow 对象
flow_run: Flow 运行实例
state: Flow 当前状态
context: Flow 执行上下文
"""
logger.info("🚀 initiate_scan_flow_running 回调开始运行 - Flow Run: %s", flow_run.id)
scan_id = flow_run.parameters.get('scan_id')
target_name = flow_run.parameters.get('target_name')
engine_name = flow_run.parameters.get('engine_name')
scheduled_scan_name = flow_run.parameters.get('scheduled_scan_name')
logger.info("🚀 initiate_scan_flow_running 回调开始运行 - Flow: %s", context.flow_name)
scan_id = context.scan_id
target_name = context.parameters.get('target_name')
engine_name = context.parameters.get('engine_name')
scheduled_scan_name = context.parameters.get('scheduled_scan_name')
if not scan_id:
logger.warning(
"Flow 参数中缺少 scan_id跳过状态更新 - Flow Run: %s",
flow_run.id
"Flow 参数中缺少 scan_id跳过状态更新 - Flow: %s",
context.flow_name
)
return
def _update_running_status():
from apps.scan.services import ScanService
from apps.common.definitions import ScanStatus
service = ScanService()
success = service.update_status(
scan_id,
scan_id,
ScanStatus.RUNNING
)
if success:
logger.info(
"✓ Flow 状态回调:扫描状态已更新为 RUNNING - Scan ID: %s, Flow Run: %s",
scan_id,
flow_run.id
"✓ Flow 状态回调:扫描状态已更新为 RUNNING - Scan ID: %s",
scan_id
)
else:
logger.error(
@@ -70,15 +62,17 @@ def on_initiate_scan_flow_running(flow: Flow, flow_run: FlowRun, state: State) -
scan_id
)
return success
# 执行状态更新Repository 层已有 @auto_ensure_db_connection 保证连接可靠性)
# 执行状态更新
_update_running_status()
# 发送通知
logger.info("准备发送扫描开始通知 - Scan ID: %s, Target: %s", scan_id, target_name)
try:
from apps.scan.notifications import create_notification, NotificationLevel, NotificationCategory
from apps.scan.notifications import (
create_notification, NotificationLevel, NotificationCategory
)
# 根据是否为定时扫描构建不同的标题和消息
if scheduled_scan_name:
title = f"{target_name} 扫描开始"
@@ -86,7 +80,7 @@ def on_initiate_scan_flow_running(flow: Flow, flow_run: FlowRun, state: State) -
else:
title = f"{target_name} 扫描开始"
message = f"引擎:{engine_name}"
create_notification(
title=title,
message=message,
@@ -95,47 +89,34 @@ def on_initiate_scan_flow_running(flow: Flow, flow_run: FlowRun, state: State) -
)
logger.info("✓ 扫描开始通知已发送 - Scan ID: %s, Target: %s", scan_id, target_name)
except Exception as e:
logger.error(f"发送扫描开始通知失败 - Scan ID: {scan_id}: {e}", exc_info=True)
logger.error("发送扫描开始通知失败 - Scan ID: %s: %s", scan_id, e, exc_info=True)
def on_initiate_scan_flow_completed(flow: Flow, flow_run: FlowRun, state: State) -> None:
def on_initiate_scan_flow_completed(context: FlowContext) -> None:
"""
initiate_scan_flow 成功完成时的回调
职责:更新 Scan 状态为 COMPLETED
触发时机:
- Prefect Flow 正常执行完成时自动触发
- 在 Flow 函数体返回之后调用
策略:快速失败(Fail-Fast)
- Flow 成功完成 = 所有任务成功 → COMPLETED
- Flow 执行失败 = 有任务失败 → FAILED (由 on_failed 处理)
竞态条件处理:
- 如果用户已手动取消(状态已是 CANCELLED),保持终态不覆盖
Args:
flow: Prefect Flow 对象
flow_run: Flow 运行实例
state: Flow 当前状态
context: Flow 执行上下文
"""
logger.info("✅ initiate_scan_flow_completed 回调开始运行 - Flow Run: %s", flow_run.id)
scan_id = flow_run.parameters.get('scan_id')
target_name = flow_run.parameters.get('target_name')
engine_name = flow_run.parameters.get('engine_name')
logger.info("✅ initiate_scan_flow_completed 回调开始运行 - Flow: %s", context.flow_name)
scan_id = context.scan_id
target_name = context.parameters.get('target_name')
engine_name = context.parameters.get('engine_name')
if not scan_id:
return
def _update_completed_status():
from apps.scan.services import ScanService
from apps.common.definitions import ScanStatus
from django.utils import timezone
service = ScanService()
# 仅在运行中时更新为 COMPLETED其他状态保持不变
completed_updated = service.update_status_if_match(
scan_id=scan_id,
@@ -143,32 +124,30 @@ def on_initiate_scan_flow_completed(flow: Flow, flow_run: FlowRun, state: State)
new_status=ScanStatus.COMPLETED,
stopped_at=timezone.now()
)
if completed_updated:
logger.info(
"✓ Flow 状态回调:扫描状态已原子更新为 COMPLETED - Scan ID: %s, Flow Run: %s",
scan_id,
flow_run.id
"✓ Flow 状态回调:扫描状态已原子更新为 COMPLETED - Scan ID: %s",
scan_id
)
return service.update_cached_stats(scan_id)
else:
logger.info(
" Flow 状态回调:状态未更新(可能已是终态)- Scan ID: %s, Flow Run: %s",
scan_id,
flow_run.id
" Flow 状态回调:状态未更新(可能已是终态)- Scan ID: %s",
scan_id
)
return None
# 执行状态更新并获取统计数据
stats = _update_completed_status()
# 注意:物化视图刷新已迁移到 pg_ivm 增量维护,无需手动标记刷新
# 发送通知(包含统计摘要)
logger.info("准备发送扫描完成通知 - Scan ID: %s, Target: %s", scan_id, target_name)
try:
from apps.scan.notifications import create_notification, NotificationLevel, NotificationCategory
from apps.scan.notifications import (
create_notification, NotificationLevel, NotificationCategory
)
# 构建通知消息
message = f"引擎:{engine_name}"
if stats:
@@ -180,11 +159,17 @@ def on_initiate_scan_flow_completed(flow: Flow, flow_run: FlowRun, state: State)
results.append(f"目录: {stats.get('directories', 0)}")
vulns_total = stats.get('vulns_total', 0)
if vulns_total > 0:
results.append(f"漏洞: {vulns_total} (严重:{stats.get('vulns_critical', 0)} 高:{stats.get('vulns_high', 0)} 中:{stats.get('vulns_medium', 0)} 低:{stats.get('vulns_low', 0)})")
results.append(
f"漏洞: {vulns_total} "
f"(严重:{stats.get('vulns_critical', 0)} "
f"高:{stats.get('vulns_high', 0)} "
f"中:{stats.get('vulns_medium', 0)} "
f"低:{stats.get('vulns_low', 0)})"
)
else:
results.append("漏洞: 0")
message += f"\n结果:{' | '.join(results)}"
create_notification(
title=f"{target_name} 扫描完成",
message=message,
@@ -193,46 +178,35 @@ def on_initiate_scan_flow_completed(flow: Flow, flow_run: FlowRun, state: State)
)
logger.info("✓ 扫描完成通知已发送 - Scan ID: %s, Target: %s", scan_id, target_name)
except Exception as e:
logger.error(f"发送扫描完成通知失败 - Scan ID: {scan_id}: {e}", exc_info=True)
logger.error("发送扫描完成通知失败 - Scan ID: %s: %s", scan_id, e, exc_info=True)
def on_initiate_scan_flow_failed(flow: Flow, flow_run: FlowRun, state: State) -> None:
def on_initiate_scan_flow_failed(context: FlowContext) -> None:
"""
initiate_scan_flow 失败时的回调
职责:更新 Scan 状态为 FAILED并记录错误信息
触发时机:
- Prefect Flow 执行失败或抛出异常时自动触发
- Flow 超时、任务失败等所有失败场景都会触发此回调
竞态条件处理:
- 如果用户已手动取消(状态已是 CANCELLED),保持终态不覆盖
Args:
flow: Prefect Flow 对象
flow_run: Flow 运行实例
state: Flow 当前状态(包含错误信息)
context: Flow 执行上下文
"""
logger.info("❌ initiate_scan_flow_failed 回调开始运行 - Flow Run: %s", flow_run.id)
scan_id = flow_run.parameters.get('scan_id')
target_name = flow_run.parameters.get('target_name')
engine_name = flow_run.parameters.get('engine_name')
logger.info("❌ initiate_scan_flow_failed 回调开始运行 - Flow: %s", context.flow_name)
scan_id = context.scan_id
target_name = context.parameters.get('target_name')
engine_name = context.parameters.get('engine_name')
error_message = context.error_message or "Flow 执行失败"
if not scan_id:
return
def _update_failed_status():
from apps.scan.services import ScanService
from apps.common.definitions import ScanStatus
from django.utils import timezone
service = ScanService()
# 提取错误信息
error_message = str(state.message) if state.message else "Flow 执行失败"
# 仅在运行中时更新为 FAILED其他状态保持不变
failed_updated = service.update_status_if_match(
scan_id=scan_id,
@@ -240,33 +214,32 @@ def on_initiate_scan_flow_failed(flow: Flow, flow_run: FlowRun, state: State) ->
new_status=ScanStatus.FAILED,
stopped_at=timezone.now()
)
if failed_updated:
# 成功更新(正常失败流程)
logger.error(
"✗ Flow 状态回调:扫描状态已原子更新为 FAILED - Scan ID: %s, Flow Run: %s, 错误: %s",
"✗ Flow 状态回调:扫描状态已原子更新为 FAILED - Scan ID: %s, 错误: %s",
scan_id,
flow_run.id,
error_message
)
# 更新缓存统计数据(终态)
service.update_cached_stats(scan_id)
else:
logger.warning(
"⚠️ Flow 状态回调:未更新任何记录(可能已被其他进程处理)- Scan ID: %s, Flow Run: %s",
scan_id,
flow_run.id
"⚠️ Flow 状态回调:未更新任何记录(可能已被其他进程处理)- Scan ID: %s",
scan_id
)
return True
# 执行状态更新
_update_failed_status()
# 发送通知
logger.info("准备发送扫描失败通知 - Scan ID: %s, Target: %s", scan_id, target_name)
try:
from apps.scan.notifications import create_notification, NotificationLevel, NotificationCategory
error_message = str(state.message) if state.message else "未知错误"
from apps.scan.notifications import (
create_notification, NotificationLevel, NotificationCategory
)
message = f"引擎:{engine_name}\n错误:{error_message}"
create_notification(
title=f"{target_name} 扫描失败",
@@ -276,4 +249,4 @@ def on_initiate_scan_flow_failed(flow: Flow, flow_run: FlowRun, state: State) ->
)
logger.info("✓ 扫描失败通知已发送 - Scan ID: %s, Target: %s", scan_id, target_name)
except Exception as e:
logger.error(f"发送扫描失败通知失败 - Scan ID: {scan_id}: {e}", exc_info=True)
logger.error("发送扫描失败通知失败 - Scan ID: %s: %s", scan_id, e, exc_info=True)
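These callbacks now receive a FlowContext instead of Prefect's (flow, flow_run, state) triple. The dataclass itself lives in apps.scan.decorators and is not shown in this diff, so the sketch below only infers its shape from the attributes the handlers access (flow_name, stage_name, scan_id, parameters, result, error_message, ...); treat every field as an assumption.

from dataclasses import dataclass, field
from typing import Any, Optional

@dataclass
class FlowContext:
    # Inferred sketch of the context consumed by the flow callbacks above.
    flow_name: str
    stage_name: str = ""
    scan_id: Optional[int] = None
    target_id: Optional[int] = None
    target_name: Optional[str] = None
    parameters: dict[str, Any] = field(default_factory=dict)
    result: Any = None
    error_message: Optional[str] = None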

View File

@@ -10,22 +10,26 @@
"""
import logging
from prefect import Flow
from prefect.client.schemas import FlowRun, State
from apps.scan.decorators import FlowContext
from apps.scan.utils.performance import FlowPerformanceTracker
from apps.scan.utils import user_log
logger = logging.getLogger(__name__)
# 存储每个 flow_run 的性能追踪器
# 存储每个 flow 的性能追踪器(使用 scan_id + stage_name 作为 key)
_flow_trackers: dict[str, FlowPerformanceTracker] = {}
def _get_tracker_key(scan_id: int, stage_name: str) -> str:
"""生成追踪器的唯一 key"""
return f"{scan_id}_{stage_name}"
def _get_stage_from_flow_name(flow_name: str) -> str | None:
"""
从 Flow name 获取对应的 stage
Flow name 直接作为 stage(与 engine_config 的 key 一致)
排除主 Flow(initiate_scan)
"""
@@ -35,80 +39,81 @@ def _get_stage_from_flow_name(flow_name: str) -> str | None:
return flow_name
def on_scan_flow_running(flow: Flow, flow_run: FlowRun, state: State) -> None:
def on_scan_flow_running(context: FlowContext) -> None:
"""
扫描流程开始运行时的回调
职责:
- 更新阶段进度为 running
- 发送扫描开始通知
- 启动性能追踪
Args:
flow: Prefect Flow 对象
flow_run: Flow 运行实例
state: Flow 当前状态
context: Flow 执行上下文
"""
logger.info("🚀 扫描流程开始运行 - Flow: %s, Run ID: %s", flow.name, flow_run.id)
# 提取流程参数
flow_params = flow_run.parameters or {}
scan_id = flow_params.get('scan_id')
target_name = flow_params.get('target_name', 'unknown')
target_id = flow_params.get('target_id')
logger.info(
"🚀 扫描流程开始运行 - Flow: %s, Scan ID: %s",
context.flow_name, context.scan_id
)
scan_id = context.scan_id
target_name = context.target_name or 'unknown'
target_id = context.target_id
# 启动性能追踪
if scan_id:
tracker = FlowPerformanceTracker(flow.name, scan_id)
tracker_key = _get_tracker_key(scan_id, context.stage_name)
tracker = FlowPerformanceTracker(context.flow_name, scan_id)
tracker.start(target_id=target_id, target_name=target_name)
_flow_trackers[str(flow_run.id)] = tracker
_flow_trackers[tracker_key] = tracker
# 更新阶段进度
stage = _get_stage_from_flow_name(flow.name)
stage = _get_stage_from_flow_name(context.flow_name)
if scan_id and stage:
try:
from apps.scan.services import ScanService
service = ScanService()
service.start_stage(scan_id, stage)
logger.info(f"✓ 阶段进度已更新为 running - Scan ID: {scan_id}, Stage: {stage}")
logger.info(
"✓ 阶段进度已更新为 running - Scan ID: %s, Stage: %s",
scan_id, stage
)
except Exception as e:
logger.error(f"更新阶段进度失败 - Scan ID: {scan_id}, Stage: {stage}: {e}")
logger.error(
"更新阶段进度失败 - Scan ID: %s, Stage: %s: %s",
scan_id, stage, e
)
def on_scan_flow_completed(flow: Flow, flow_run: FlowRun, state: State) -> None:
def on_scan_flow_completed(context: FlowContext) -> None:
"""
扫描流程完成时的回调
职责:
- 更新阶段进度为 completed
- 发送扫描完成通知(可选)
- 记录性能指标
Args:
flow: Prefect Flow 对象
flow_run: Flow 运行实例
state: Flow 当前状态
context: Flow 执行上下文
"""
logger.info("✅ 扫描流程完成 - Flow: %s, Run ID: %s", flow.name, flow_run.id)
# 提取流程参数
flow_params = flow_run.parameters or {}
scan_id = flow_params.get('scan_id')
# 获取 flow result
result = None
try:
result = state.result() if state.result else None
except Exception:
pass
logger.info(
"✅ 扫描流程完成 - Flow: %s, Scan ID: %s",
context.flow_name, context.scan_id
)
scan_id = context.scan_id
result = context.result
# 记录性能指标
tracker = _flow_trackers.pop(str(flow_run.id), None)
if tracker:
tracker.finish(success=True)
if scan_id:
tracker_key = _get_tracker_key(scan_id, context.stage_name)
tracker = _flow_trackers.pop(tracker_key, None)
if tracker:
tracker.finish(success=True)
# 更新阶段进度
stage = _get_stage_from_flow_name(flow.name)
stage = _get_stage_from_flow_name(context.flow_name)
if scan_id and stage:
try:
from apps.scan.services import ScanService
@@ -118,72 +123,88 @@ def on_scan_flow_completed(flow: Flow, flow_run: FlowRun, state: State) -> None:
if isinstance(result, dict):
detail = result.get('detail')
service.complete_stage(scan_id, stage, detail)
logger.info(f"✓ 阶段进度已更新为 completed - Scan ID: {scan_id}, Stage: {stage}")
logger.info(
"✓ 阶段进度已更新为 completed - Scan ID: %s, Stage: %s",
scan_id, stage
)
# 每个阶段完成后刷新缓存统计,便于前端实时看到增量
try:
service.update_cached_stats(scan_id)
logger.info("✓ 阶段完成后已刷新缓存统计 - Scan ID: %s", scan_id)
except Exception as e:
logger.error("阶段完成后刷新缓存统计失败 - Scan ID: %s, 错误: %s", scan_id, e)
logger.error(
"阶段完成后刷新缓存统计失败 - Scan ID: %s, 错误: %s",
scan_id, e
)
except Exception as e:
logger.error(f"更新阶段进度失败 - Scan ID: {scan_id}, Stage: {stage}: {e}")
logger.error(
"更新阶段进度失败 - Scan ID: %s, Stage: %s: %s",
scan_id, stage, e
)
def on_scan_flow_failed(flow: Flow, flow_run: FlowRun, state: State) -> None:
def on_scan_flow_failed(context: FlowContext) -> None:
"""
扫描流程失败时的回调
职责:
- 更新阶段进度为 failed
- 发送扫描失败通知
- 记录性能指标(含错误信息)
- 写入 ScanLog 供前端显示
Args:
flow: Prefect Flow 对象
flow_run: Flow 运行实例
state: Flow 当前状态
context: Flow 执行上下文
"""
logger.info("❌ 扫描流程失败 - Flow: %s, Run ID: %s", flow.name, flow_run.id)
# 提取流程参数
flow_params = flow_run.parameters or {}
scan_id = flow_params.get('scan_id')
target_name = flow_params.get('target_name', 'unknown')
# 提取错误信息
error_message = str(state.message) if state.message else "未知错误"
logger.info(
"❌ 扫描流程失败 - Flow: %s, Scan ID: %s",
context.flow_name, context.scan_id
)
scan_id = context.scan_id
target_name = context.target_name or 'unknown'
error_message = context.error_message or "未知错误"
# 写入 ScanLog 供前端显示
stage = _get_stage_from_flow_name(flow.name)
stage = _get_stage_from_flow_name(context.flow_name)
if scan_id and stage:
user_log(scan_id, stage, f"Failed: {error_message}", "error")
# 记录性能指标(失败情况)
tracker = _flow_trackers.pop(str(flow_run.id), None)
if tracker:
tracker.finish(success=False, error_message=error_message)
if scan_id:
tracker_key = _get_tracker_key(scan_id, context.stage_name)
tracker = _flow_trackers.pop(tracker_key, None)
if tracker:
tracker.finish(success=False, error_message=error_message)
# 更新阶段进度
stage = _get_stage_from_flow_name(flow.name)
if scan_id and stage:
try:
from apps.scan.services import ScanService
service = ScanService()
service.fail_stage(scan_id, stage, error_message)
logger.info(f"✓ 阶段进度已更新为 failed - Scan ID: {scan_id}, Stage: {stage}")
logger.info(
"✓ 阶段进度已更新为 failed - Scan ID: %s, Stage: %s",
scan_id, stage
)
except Exception as e:
logger.error(f"更新阶段进度失败 - Scan ID: {scan_id}, Stage: {stage}: {e}")
logger.error(
"更新阶段进度失败 - Scan ID: %s, Stage: %s: %s",
scan_id, stage, e
)
# 发送通知
try:
from apps.scan.notifications import create_notification, NotificationLevel
message = f"任务:{flow.name}\n状态:执行失败\n错误:{error_message}"
message = f"任务:{context.flow_name}\n状态:执行失败\n错误:{error_message}"
create_notification(
title=target_name,
message=message,
level=NotificationLevel.HIGH
)
logger.error(f"✓ 扫描失败通知已发送 - Target: {target_name}, Flow: {flow.name}, Error: {error_message}")
logger.error(
"✓ 扫描失败通知已发送 - Target: %s, Flow: %s, Error: %s",
target_name, context.flow_name, error_message
)
except Exception as e:
logger.error(f"发送扫描失败通知失败 - Flow: {flow.name}: {e}")
logger.error("发送扫描失败通知失败 - Flow: %s: %s", context.flow_name, e)
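The tracker registry switched its key from flow_run.id to scan_id plus stage_name; since a scan runs each stage once, that pair still pairs up the running/failed/completed callbacks without any Prefect run ID. A toy registry illustrating the idea (FlowPerformanceTracker itself is not shown in this diff):

import time

_trackers: dict[str, dict] = {}

def _key(scan_id: int, stage_name: str) -> str:
    return f"{scan_id}_{stage_name}"

def start_tracking(scan_id: int, stage_name: str) -> None:
    _trackers[_key(scan_id, stage_name)] = {"started": time.monotonic()}

def finish_tracking(scan_id: int, stage_name: str, success: bool) -> float | None:
    # pop() mirrors the handlers above: each stage is finished at most once.
    entry = _trackers.pop(_key(scan_id, stage_name), None)
    if entry is None:
        return None  # running callback never fired, or already finished
    return time.monotonic() - entry["started"]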

View File

@@ -0,0 +1,18 @@
# Generated by Django 5.2.7 on 2026-01-07 14:03
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('scan', '0001_initial'),
]
operations = [
migrations.AddField(
model_name='scan',
name='cached_screenshots_count',
field=models.IntegerField(default=0, help_text='缓存的截图数量'),
),
]

View File

@@ -0,0 +1,23 @@
# Generated manually for WeCom notification support
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('scan', '0002_add_cached_screenshots_count'),
]
operations = [
migrations.AddField(
model_name='notificationsettings',
name='wecom_enabled',
field=models.BooleanField(default=False, help_text='是否启用企业微信通知'),
),
migrations.AddField(
model_name='notificationsettings',
name='wecom_webhook_url',
field=models.URLField(blank=True, default='', help_text='企业微信机器人 Webhook URL'),
),
]

View File

@@ -0,0 +1,35 @@
# Generated by Django 5.2.7 on 2026-01-10 03:51
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('scan', '0003_add_wecom_fields'),
]
operations = [
migrations.AddField(
model_name='scan',
name='scan_mode',
field=models.CharField(choices=[('full', '完整扫描'), ('quick', '快速扫描')], default='full', help_text='扫描模式:full=完整扫描,quick=快速扫描', max_length=10),
),
migrations.CreateModel(
name='ScanInputTarget',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('value', models.CharField(help_text='用户输入的原始值', max_length=2000)),
('input_type', models.CharField(choices=[('domain', '域名'), ('ip', 'IP地址'), ('cidr', 'CIDR'), ('url', 'URL')], help_text='输入类型', max_length=10)),
('created_at', models.DateTimeField(auto_now_add=True)),
('scan', models.ForeignKey(help_text='所属的扫描任务', on_delete=django.db.models.deletion.CASCADE, related_name='input_targets', to='scan.scan')),
],
options={
'verbose_name': '扫描输入目标',
'verbose_name_plural': '扫描输入目标',
'db_table': 'scan_input_target',
'indexes': [models.Index(fields=['scan'], name='scan_input__scan_id_0a3227_idx'), models.Index(fields=['input_type'], name='scan_input__input_t_e3f681_idx')],
},
),
]

View File

@@ -4,6 +4,7 @@ from .scan_models import Scan, SoftDeleteManager
from .scan_log_model import ScanLog
from .scheduled_scan_model import ScheduledScan
from .subfinder_provider_settings_model import SubfinderProviderSettings
from .scan_input_target import ScanInputTarget
# 兼容旧名称(已废弃,请使用 SubfinderProviderSettings)
ProviderSettings = SubfinderProviderSettings
@@ -15,4 +16,5 @@ __all__ = [
'SoftDeleteManager',
'SubfinderProviderSettings',
'ProviderSettings', # 兼容旧名称
'ScanInputTarget',
]

View File

@@ -0,0 +1,47 @@
"""
扫描输入目标模型
存储快速扫描时用户输入的目标,支持大量数据(1万+)的分块迭代。
用于快速扫描的第一阶段。
"""
from django.db import models
class ScanInputTarget(models.Model):
"""扫描输入目标表"""
class InputType(models.TextChoices):
"""输入类型枚举"""
DOMAIN = 'domain', '域名'
IP = 'ip', 'IP地址'
CIDR = 'cidr', 'CIDR'
URL = 'url', 'URL'
id = models.AutoField(primary_key=True)
scan = models.ForeignKey(
'scan.Scan',
on_delete=models.CASCADE,
related_name='input_targets',
help_text='所属的扫描任务'
)
value = models.CharField(max_length=2000, help_text='用户输入的原始值')
input_type = models.CharField(
max_length=10,
choices=InputType.choices,
help_text='输入类型'
)
created_at = models.DateTimeField(auto_now_add=True)
class Meta:
"""模型元数据"""
db_table = 'scan_input_target'
verbose_name = '扫描输入目标'
verbose_name_plural = '扫描输入目标'
indexes = [
models.Index(fields=['scan']),
models.Index(fields=['input_type']),
]
def __str__(self):
return f"ScanInputTarget #{self.id} - {self.value} ({self.input_type})"
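Given the 10k+ row target in the docstring, writes and reads for this table would typically be chunked; a hedged sketch follows, assuming a scan instance and a pre-parsed (value, input_type) list — the helper names are illustrative only.

from apps.scan.models import ScanInputTarget

def save_inputs(scan, parsed: list[tuple[str, str]]) -> None:
    # Batch inserts keep a 10k+ input list to a handful of INSERT statements.
    objs = [ScanInputTarget(scan=scan, value=v, input_type=t) for v, t in parsed]
    ScanInputTarget.objects.bulk_create(objs, batch_size=1000)

def iter_inputs(scan_id: int, input_type: str | None = None):
    qs = ScanInputTarget.objects.filter(scan_id=scan_id)
    if input_type:
        qs = qs.filter(input_type=input_type)
    # iterator() keeps memory flat when streaming large result sets.
    yield from qs.values_list("value", flat=True).iterator(chunk_size=1000)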

View File

@@ -8,17 +8,28 @@ from apps.common.definitions import ScanStatus
class SoftDeleteManager(models.Manager):
"""软删除管理器:默认只返回未删除的记录"""
def get_queryset(self):
"""返回未删除记录的查询集"""
return super().get_queryset().filter(deleted_at__isnull=True)
class Scan(models.Model):
"""扫描任务模型"""
class ScanMode(models.TextChoices):
"""扫描模式枚举"""
FULL = 'full', '完整扫描'
QUICK = 'quick', '快速扫描'
id = models.AutoField(primary_key=True)
target = models.ForeignKey('targets.Target', on_delete=models.CASCADE, related_name='scans', help_text='扫描目标')
target = models.ForeignKey(
'targets.Target',
on_delete=models.CASCADE,
related_name='scans',
help_text='扫描目标'
)
# 多引擎支持字段
engine_ids = ArrayField(
@@ -35,6 +46,14 @@ class Scan(models.Model):
help_text='YAML 格式的扫描配置'
)
# 扫描模式
scan_mode = models.CharField(
max_length=10,
choices=ScanMode.choices,
default=ScanMode.FULL,
help_text='扫描模式:full=完整扫描,quick=快速扫描'
)
created_at = models.DateTimeField(auto_now_add=True, help_text='任务创建时间')
stopped_at = models.DateTimeField(null=True, blank=True, help_text='扫描结束时间')
@@ -46,7 +65,12 @@ class Scan(models.Model):
help_text='任务状态'
)
results_dir = models.CharField(max_length=100, blank=True, default='', help_text='结果存储目录')
results_dir = models.CharField(
max_length=100,
blank=True,
default='',
help_text='结果存储目录'
)
container_ids = ArrayField(
models.CharField(max_length=100),
@@ -54,7 +78,7 @@ class Scan(models.Model):
default=list,
help_text='容器 ID 列表(Docker Container ID)'
)
worker = models.ForeignKey(
'engine.WorkerNode',
on_delete=models.SET_NULL,
@@ -64,34 +88,46 @@ class Scan(models.Model):
help_text='执行扫描的 Worker 节点'
)
error_message = models.CharField(max_length=2000, blank=True, default='', help_text='错误信息')
error_message = models.CharField(
max_length=2000,
blank=True,
default='',
help_text='错误信息'
)
# ==================== 软删除字段 ====================
deleted_at = models.DateTimeField(null=True, blank=True, db_index=True, help_text='删除时间NULL表示未删除')
# 软删除字段
deleted_at = models.DateTimeField(
null=True,
blank=True,
db_index=True,
help_text='删除时间,NULL 表示未删除'
)
# ==================== 管理器 ====================
objects = SoftDeleteManager() # 默认管理器:只返回未删除的记录
all_objects = models.Manager() # 全量管理器:包括已删除的记录(用于硬删除)
# 管理器
objects = SoftDeleteManager()
all_objects = models.Manager()
# ==================== 进度跟踪字段 ====================
# 进度跟踪字段
progress = models.IntegerField(default=0, help_text='扫描进度 0-100')
current_stage = models.CharField(max_length=50, blank=True, default='', help_text='当前扫描阶段')
stage_progress = models.JSONField(default=dict, help_text='各阶段进度详情')
# ==================== 缓存统计字段 ====================
cached_subdomains_count = models.IntegerField(default=0, help_text='缓存的子域名数量')
cached_websites_count = models.IntegerField(default=0, help_text='缓存的网站数量')
cached_endpoints_count = models.IntegerField(default=0, help_text='缓存的端点数量')
cached_ips_count = models.IntegerField(default=0, help_text='缓存的IP地址数量')
cached_directories_count = models.IntegerField(default=0, help_text='缓存的目录数量')
cached_vulns_total = models.IntegerField(default=0, help_text='缓存的漏洞总数')
cached_vulns_critical = models.IntegerField(default=0, help_text='缓存的严重漏洞数量')
cached_vulns_high = models.IntegerField(default=0, help_text='缓存的高危漏洞数量')
cached_vulns_medium = models.IntegerField(default=0, help_text='缓存的中危漏洞数量')
cached_vulns_low = models.IntegerField(default=0, help_text='缓存的低危漏洞数量')
# 缓存统计字段
cached_subdomains_count = models.IntegerField(default=0, help_text='子域名数量')
cached_websites_count = models.IntegerField(default=0, help_text='网站数量')
cached_endpoints_count = models.IntegerField(default=0, help_text='端点数量')
cached_ips_count = models.IntegerField(default=0, help_text='IP地址数量')
cached_directories_count = models.IntegerField(default=0, help_text='目录数量')
cached_screenshots_count = models.IntegerField(default=0, help_text='截图数量')
cached_vulns_total = models.IntegerField(default=0, help_text='漏洞总数')
cached_vulns_critical = models.IntegerField(default=0, help_text='严重漏洞数量')
cached_vulns_high = models.IntegerField(default=0, help_text='高危漏洞数量')
cached_vulns_medium = models.IntegerField(default=0, help_text='中危漏洞数量')
cached_vulns_low = models.IntegerField(default=0, help_text='低危漏洞数量')
stats_updated_at = models.DateTimeField(null=True, blank=True, help_text='统计数据最后更新时间')
class Meta:
"""模型元数据配置"""
db_table = 'scan'
verbose_name = '扫描任务'
verbose_name_plural = '扫描任务'

View File

@@ -1,8 +1,14 @@
"""通知系统数据模型"""
from django.db import models
import logging
from datetime import timedelta
from .types import NotificationLevel, NotificationCategory
from django.db import models
from django.utils import timezone
from .types import NotificationCategory, NotificationLevel
logger = logging.getLogger(__name__)
class NotificationSettings(models.Model):
@@ -10,31 +16,34 @@ class NotificationSettings(models.Model):
通知设置(单例模型)
存储 Discord webhook 配置和各分类的通知开关
"""
# Discord 配置
discord_enabled = models.BooleanField(default=False, help_text='是否启用 Discord 通知')
discord_webhook_url = models.URLField(blank=True, default='', help_text='Discord Webhook URL')
# 企业微信配置
wecom_enabled = models.BooleanField(default=False, help_text='是否启用企业微信通知')
wecom_webhook_url = models.URLField(blank=True, default='', help_text='企业微信机器人 Webhook URL')
# 分类开关(使用 JSONField 存储)
categories = models.JSONField(
default=dict,
help_text='各分类通知开关,如 {"scan": true, "vulnerability": true, "asset": true, "system": false}'
)
# 时间信息
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
class Meta:
db_table = 'notification_settings'
verbose_name = '通知设置'
verbose_name_plural = '通知设置'
def save(self, *args, **kwargs):
# 单例模式:强制只有一条记录
self.pk = 1
self.pk = 1 # 单例模式
super().save(*args, **kwargs)
@classmethod
def get_instance(cls) -> 'NotificationSettings':
"""获取或创建单例实例"""
@@ -52,7 +61,7 @@ class NotificationSettings(models.Model):
}
)
return obj
def is_category_enabled(self, category: str) -> bool:
"""检查指定分类是否启用通知"""
return self.categories.get(category, False)
@@ -60,10 +69,9 @@ class NotificationSettings(models.Model):
class Notification(models.Model):
"""通知模型"""
id = models.AutoField(primary_key=True)
# 通知分类
category = models.CharField(
max_length=20,
choices=NotificationCategory.choices,
@@ -71,8 +79,7 @@ class Notification(models.Model):
db_index=True,
help_text='通知分类'
)
# 通知级别
level = models.CharField(
max_length=20,
choices=NotificationLevel.choices,
@@ -80,16 +87,15 @@ class Notification(models.Model):
db_index=True,
help_text='通知级别'
)
title = models.CharField(max_length=200, help_text='通知标题')
message = models.CharField(max_length=2000, help_text='通知内容')
# 时间信息
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
is_read = models.BooleanField(default=False, help_text='是否已读')
read_at = models.DateTimeField(null=True, blank=True, help_text='阅读时间')
class Meta:
db_table = 'notification'
verbose_name = '通知'
@@ -101,44 +107,26 @@ class Notification(models.Model):
models.Index(fields=['level', '-created_at']),
models.Index(fields=['is_read', '-created_at']),
]
def __str__(self):
return f"{self.get_level_display()} - {self.title}"
@classmethod
def cleanup_old_notifications(cls):
"""
清理超过15天的旧通知硬编码
Returns:
int: 删除的通知数量
"""
from datetime import timedelta
from django.utils import timezone
# 硬编码只保留最近15天的通知
def cleanup_old_notifications(cls) -> int:
"""清理超过15天的旧通知"""
cutoff_date = timezone.now() - timedelta(days=15)
delete_result = cls.objects.filter(created_at__lt=cutoff_date).delete()
return delete_result[0] if delete_result[0] else 0
deleted_count, _ = cls.objects.filter(created_at__lt=cutoff_date).delete()
return deleted_count or 0
def save(self, *args, **kwargs):
"""
重写save方法在创建新通知时自动清理旧通知
"""
"""重写save方法在创建新通知时自动清理旧通知"""
is_new = self.pk is None
super().save(*args, **kwargs)
# 只在创建新通知时执行清理自动清理超过15天的通知
if is_new:
try:
deleted_count = self.__class__.cleanup_old_notifications()
if deleted_count > 0:
import logging
logger = logging.getLogger(__name__)
logger.info(f"自动清理{deleted_count} 条超过15天的旧通知")
except Exception as e:
# 清理失败不应影响通知创建
import logging
logger = logging.getLogger(__name__)
logger.warning(f"通知自动清理失败: {e}")
logger.info("自动清理了 %d 条超过15天的旧通知", deleted_count)
except Exception:
logger.warning("通知自动清理失败", exc_info=True)

View File

@@ -87,7 +87,7 @@ def on_all_workers_high_load(sender, worker_name, cpu, mem, **kwargs):
"""所有 Worker 高负载时的通知处理"""
create_notification(
title="系统负载较高",
message=f"所有节点负载较高(最低负载节点 CPU: {cpu:.1f}%, 内存: {mem:.1f}%),系统将等待最多 10 分钟后分发任务,扫描速度可能受影响",
message=f"所有节点负载较高(最低负载节点 CPU: {cpu:.1f}%, 内存: {mem:.1f}%),系统将每 2 分钟检测一次,最多等待 2 小时后分发任务",
level=NotificationLevel.MEDIUM,
category=NotificationCategory.SYSTEM
)

View File

@@ -1,52 +1,70 @@
"""通知系统仓储层模块"""
import logging
from typing import TypedDict
from dataclasses import dataclass
from typing import Optional
from django.db.models import QuerySet
from django.utils import timezone
from apps.common.decorators import auto_ensure_db_connection
from .models import Notification, NotificationSettings
from .models import Notification, NotificationSettings
logger = logging.getLogger(__name__)
class NotificationSettingsData(TypedDict):
"""通知设置数据结构"""
@dataclass
class NotificationSettingsData:
"""通知设置更新数据"""
discord_enabled: bool
discord_webhook_url: str
categories: dict[str, bool]
wecom_enabled: bool = False
wecom_webhook_url: str = ''
@auto_ensure_db_connection
class NotificationSettingsRepository:
"""通知设置仓储层"""
def get_settings(self) -> NotificationSettings:
"""获取通知设置单例"""
return NotificationSettings.get_instance()
def update_settings(
self,
discord_enabled: bool,
discord_webhook_url: str,
categories: dict[str, bool]
) -> NotificationSettings:
def update_settings(self, data: NotificationSettingsData) -> NotificationSettings:
"""更新通知设置"""
settings = NotificationSettings.get_instance()
settings.discord_enabled = discord_enabled
settings.discord_webhook_url = discord_webhook_url
settings.categories = categories
settings.discord_enabled = data.discord_enabled
settings.discord_webhook_url = data.discord_webhook_url
settings.wecom_enabled = data.wecom_enabled
settings.wecom_webhook_url = data.wecom_webhook_url
settings.categories = data.categories
settings.save()
return settings
def is_category_enabled(self, category: str) -> bool:
"""检查指定分类是否启用"""
settings = self.get_settings()
return settings.is_category_enabled(category)
return self.get_settings().is_category_enabled(category)
@auto_ensure_db_connection
class DjangoNotificationRepository:
def get_filtered(self, level: str | None = None, unread: bool | None = None):
"""通知数据仓储层"""
def get_filtered(
self,
level: Optional[str] = None,
unread: Optional[bool] = None
) -> QuerySet[Notification]:
"""
获取过滤后的通知列表
Args:
level: 通知级别过滤
unread: 已读状态过滤 (True=未读, False=已读, None=全部)
"""
queryset = Notification.objects.all()
if level:
@@ -60,16 +78,24 @@ class DjangoNotificationRepository:
return queryset.order_by("-created_at")
def get_unread_count(self) -> int:
"""获取未读通知数量"""
return Notification.objects.filter(is_read=False).count()
def mark_all_as_read(self) -> int:
updated = Notification.objects.filter(is_read=False).update(
"""标记所有通知为已读,返回更新数量"""
return Notification.objects.filter(is_read=False).update(
is_read=True,
read_at=timezone.now(),
)
return updated
def create(self, title: str, message: str, level: str, category: str = 'system') -> Notification:
def create(
self,
title: str,
message: str,
level: str,
category: str = 'system'
) -> Notification:
"""创建新通知"""
return Notification.objects.create(
category=category,
level=level,

View File

@@ -60,13 +60,12 @@ def push_to_external_channels(notification: Notification) -> None:
except Exception as e:
logger.warning(f"Discord 推送失败: {e}")
# 未来扩展Slack
# if settings.slack_enabled and settings.slack_webhook_url:
# _send_slack(notification, settings.slack_webhook_url)
# 未来扩展Telegram
# if settings.telegram_enabled and settings.telegram_bot_token:
# _send_telegram(notification, settings.telegram_chat_id)
# 企业微信渠道
if settings.wecom_enabled and settings.wecom_webhook_url:
try:
_send_wecom(notification, settings.wecom_webhook_url)
except Exception as e:
logger.warning(f"企业微信推送失败: {e}")
def _send_discord(notification: Notification, webhook_url: str) -> bool:
@@ -103,6 +102,41 @@ def _send_discord(notification: Notification, webhook_url: str) -> bool:
return False
def _send_wecom(notification: Notification, webhook_url: str) -> bool:
"""发送到企业微信机器人 Webhook"""
try:
emoji = CATEGORY_EMOJI.get(notification.category, '📢')
# 企业微信 Markdown 格式
content = f"""**{emoji} {notification.title}**
> 级别:{notification.get_level_display()}
> 分类:{notification.get_category_display()}
{notification.message}"""
payload = {
'msgtype': 'markdown',
'markdown': {'content': content}
}
response = requests.post(webhook_url, json=payload, timeout=10)
if response.status_code == 200:
result = response.json()
if result.get('errcode') == 0:
logger.info(f"企业微信通知发送成功 - {notification.title}")
return True
logger.warning(f"企业微信发送失败 - errcode: {result.get('errcode')}, errmsg: {result.get('errmsg')}")
return False
logger.warning(f"企业微信发送失败 - 状态码: {response.status_code}")
return False
except requests.RequestException as e:
logger.error(f"企业微信网络错误: {e}")
return False
# ============================================================
# 设置服务
# ============================================================
@@ -121,31 +155,43 @@ class NotificationSettingsService:
'enabled': settings.discord_enabled,
'webhookUrl': settings.discord_webhook_url,
},
'wecom': {
'enabled': settings.wecom_enabled,
'webhookUrl': settings.wecom_webhook_url,
},
'categories': settings.categories,
}
def update_settings(self, data: dict) -> dict:
"""更新通知设置
注意:DRF CamelCaseJSONParser 会将前端的 webhookUrl 转换为 webhook_url
"""
discord_data = data.get('discord', {})
wecom_data = data.get('wecom', {})
categories = data.get('categories', {})
# CamelCaseJSONParser 转换后的字段名是 webhook_url
webhook_url = discord_data.get('webhook_url', '')
discord_webhook_url = discord_data.get('webhook_url', '')
wecom_webhook_url = wecom_data.get('webhook_url', '')
settings = self.repo.update_settings(
discord_enabled=discord_data.get('enabled', False),
discord_webhook_url=webhook_url,
discord_webhook_url=discord_webhook_url,
wecom_enabled=wecom_data.get('enabled', False),
wecom_webhook_url=wecom_webhook_url,
categories=categories,
)
return {
'discord': {
'enabled': settings.discord_enabled,
'webhookUrl': settings.discord_webhook_url,
},
'wecom': {
'enabled': settings.wecom_enabled,
'webhookUrl': settings.wecom_webhook_url,
},
'categories': settings.categories,
}
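The repository shown earlier now takes a single NotificationSettingsData argument, so a dataclass-based call from this service would look roughly like the sketch below; the import path and the build_settings_payload helper are assumptions for illustration, not code from the commit.

from apps.notification.repositories import NotificationSettingsData  # import path assumed

def build_settings_payload(data: dict) -> NotificationSettingsData:
    # Mirrors the field names of the dataclass shown in the repository diff.
    discord = data.get('discord', {})
    wecom = data.get('wecom', {})
    return NotificationSettingsData(
        discord_enabled=discord.get('enabled', False),
        discord_webhook_url=discord.get('webhook_url', ''),
        categories=data.get('categories', {}),
        wecom_enabled=wecom.get('enabled', False),
        wecom_webhook_url=wecom.get('webhook_url', ''),
    )

# settings = repo.update_settings(build_settings_payload(request_data))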

View File

@@ -147,10 +147,10 @@ class FlowOrchestrator:
return True
return False
# 其他扫描类型:检查 tools
# 其他扫描类型(包括 screenshot):检查 tools
tools = scan_config.get('tools', {})
for tool_config in tools.values():
if tool_config.get('enabled', False):
if isinstance(tool_config, dict) and tool_config.get('enabled', False):
return True
return False
@@ -222,6 +222,10 @@ class FlowOrchestrator:
from apps.scan.flows.vuln_scan import vuln_scan_flow
return vuln_scan_flow
elif scan_type == 'screenshot':
from apps.scan.flows.screenshot_flow import screenshot_flow
return screenshot_flow
else:
logger.warning(f"未实现的扫描类型: {scan_type}")
return None
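The isinstance guard added to the tools loop above means a malformed entry (for example a bare boolean instead of a config dict) no longer raises; a compact sketch of the same check with hypothetical configs:

def has_enabled_tools(scan_config: dict) -> bool:
    # Tolerant enabled-tool check mirroring the guarded loop above.
    tools = scan_config.get('tools', {})
    return any(
        isinstance(cfg, dict) and cfg.get('enabled', False)
        for cfg in tools.values()
    )

print(has_enabled_tools({'tools': {'nuclei': {'enabled': True}}}))   # True
print(has_enabled_tools({'tools': {'nuclei': True}}))                # False, and no AttributeError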

View File

@@ -0,0 +1,51 @@
"""
扫描目标提供者模块
提供统一的目标获取接口,支持多种数据源:
- DatabaseTargetProvider: 从数据库查询(完整扫描)
- SnapshotTargetProvider: 从快照表读取(快速扫描)
Provider 方法:
- get_target_name(): Target 名称(根域名/IP/CIDR)
- iter_subdomains(): 子域名列表
- iter_host_port_urls(): 从 host:port 生成的 URL(站点探测用)
- iter_websites(): 已存活网站 URL(截图、指纹、目录扫描用)
- iter_endpoints(): 端点 URL(漏洞扫描用)
使用方式:
from apps.scan.providers import (
DatabaseTargetProvider,
SnapshotTargetProvider,
ProviderContext
)
# 数据库模式(完整扫描)
provider = DatabaseTargetProvider(target_id=123)
# 端口扫描:显式组合 target_name + subdomains
target_name = provider.get_target_name()
if target_name:
scan_port(target_name) # CIDR 需要调用方自己展开
for subdomain in provider.iter_subdomains():
scan_port(subdomain)
# 截图
for url in provider.iter_websites():
take_screenshot(url)
# 快照模式(快速扫描)
provider = SnapshotTargetProvider(scan_id=100)
for url in provider.iter_websites():
take_screenshot(url)
"""
from .base import TargetProvider, ProviderContext
from .database_provider import DatabaseTargetProvider
from .snapshot_provider import SnapshotTargetProvider
__all__ = [
'TargetProvider',
'ProviderContext',
'DatabaseTargetProvider',
'SnapshotTargetProvider',
]

View File

@@ -0,0 +1,226 @@
"""
扫描目标提供者基础模块
定义 ProviderContext 数据类和 TargetProvider 抽象基类。
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Iterator, Optional
if TYPE_CHECKING:
from apps.common.utils import BlacklistFilter
logger = logging.getLogger(__name__)
@dataclass
class ProviderContext:
"""
Provider 上下文,携带元数据
Attributes:
target_id: 关联的 Target ID(用于结果保存),None 表示临时扫描(不保存)
scan_id: 扫描任务 ID
"""
target_id: Optional[int] = None
scan_id: Optional[int] = None
class TargetProvider(ABC):
"""
扫描目标提供者抽象基类
职责:
- 提供扫描目标(域名、IP、URL 等)的迭代器
- 提供黑名单过滤器
- 携带上下文信息(target_id, scan_id 等)
方法说明:
- get_target_name(): Target 名称(根域名/IP/CIDR)
- iter_subdomains(): 子域名列表
- iter_host_port_urls(): 从 host:port 生成的 URL(站点探测用)
- iter_websites(): 已存活网站 URL(截图、指纹、目录扫描用)
- iter_endpoints(): 端点 URL(漏洞扫描用)
使用方式:
provider = DatabaseTargetProvider(target_id=123)
# 端口扫描:显式组合 target_name + subdomains
target_name = provider.get_target_name()
if target_name:
scan_port(target_name) # CIDR 需要调用方自己展开
for subdomain in provider.iter_subdomains():
scan_port(subdomain)
# 截图
for url in provider.iter_websites():
take_screenshot(url)
"""
def __init__(self, context: Optional[ProviderContext] = None):
self._context = context or ProviderContext()
self._target_name: Optional[str] = None # 缓存 target_name
@property
def context(self) -> ProviderContext:
"""返回 Provider 上下文"""
return self._context
def get_target_name(self) -> Optional[str]:
"""
获取 Target 名称(根域名/IP/CIDR)
Returns:
Target 名称,不存在时返回 None
注意CIDR 不会自动展开,调用方需要自己处理
"""
# 使用缓存避免重复查询
if self._target_name is not None:
return self._target_name
if not self.target_id:
logger.warning("target_id 未设置,无法获取 Target 名称")
return None
from apps.targets.services import TargetService
target = TargetService().get_target(self.target_id)
self._target_name = target.name if target else None
return self._target_name
def iter_target_hosts(self) -> Iterator[str]:
"""
迭代 Target 展开后的主机列表(已过滤黑名单)
- DOMAIN/IP: 直接返回
- CIDR: 展开为所有 IP
Returns:
主机迭代器(域名或 IP)
"""
import ipaddress
from apps.common.validators import detect_target_type
from apps.targets.models import Target
target_name = self.get_target_name()
if not target_name:
return
blacklist = self.get_blacklist_filter()
target_type = detect_target_type(target_name)
if target_type == Target.TargetType.CIDR:
# CIDR 展开
network = ipaddress.ip_network(target_name, strict=False)
if network.num_addresses == 1:
hosts = [str(network.network_address)]
else:
hosts = [str(ip) for ip in network.hosts()]
else:
# DOMAIN / IP 直接返回
hosts = [target_name]
for host in hosts:
if not blacklist or blacklist.is_allowed(host):
yield host
@abstractmethod
def iter_subdomains(self) -> Iterator[str]:
"""迭代子域名列表,子类实现"""
@abstractmethod
def iter_host_port_urls(self) -> Iterator[str]:
"""
迭代 host:port 生成的 URL(待探测)
用于站点扫描(httpx 探测),从 HostPortMapping 生成 URL。
返回格式:http://host:port 或 https://host:port
"""
@abstractmethod
def iter_websites(self) -> Iterator[str]:
"""
迭代已存活网站 URL
用于截图、指纹识别、目录扫描、URL 爬虫。
数据来源:WebSite 表(已确认存活的网站)
"""
@abstractmethod
def iter_endpoints(self) -> Iterator[str]:
"""
迭代端点 URL
用于漏洞扫描。
数据来源:Endpoint 表(带参数的 URL)
"""
def iter_default_urls(self) -> Iterator[str]:
"""
从 Target 本身生成默认 URL
用于跳过前置阶段直接扫描的场景。
根据 Target 类型生成:
- DOMAIN: http(s)://domain
- IP: http(s)://ip
- CIDR: 展开为所有 IP 的 http(s)://ip
"""
import ipaddress
from apps.targets.models import Target
from apps.targets.services import TargetService
if not self.target_id:
logger.warning("target_id 未设置,无法生成默认 URL")
return
target = TargetService().get_target(self.target_id)
if not target:
logger.warning("Target ID %d 不存在,无法生成默认 URL", self.target_id)
return
target_name = target.name
target_type = target.type
blacklist = self.get_blacklist_filter()
if target_type == Target.TargetType.DOMAIN:
urls = [f"http://{target_name}", f"https://{target_name}"]
elif target_type == Target.TargetType.IP:
urls = [f"http://{target_name}", f"https://{target_name}"]
elif target_type == Target.TargetType.CIDR:
try:
network = ipaddress.ip_network(target_name, strict=False)
urls = []
for ip in network.hosts():
urls.extend([f"http://{ip}", f"https://{ip}"])
# /32 或 /128 特殊处理
if not urls:
ip = str(network.network_address)
urls = [f"http://{ip}", f"https://{ip}"]
except ValueError as e:
logger.error("CIDR 解析失败: %s - %s", target_name, e)
return
else:
logger.warning("不支持的 Target 类型: %s", target_type)
return
for url in urls:
if not blacklist or blacklist.is_allowed(url):
yield url
@abstractmethod
def get_blacklist_filter(self) -> Optional['BlacklistFilter']:
"""获取黑名单过滤器,返回 None 表示不过滤"""
@property
def target_id(self) -> Optional[int]:
"""返回关联的 target_id临时扫描返回 None"""
return self._context.target_id
@property
def scan_id(self) -> Optional[int]:
"""返回关联的 scan_id"""
return self._context.scan_id
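To make the abstract contract concrete, a minimal in-memory provider is sketched below; it is hypothetical (useful mainly for tests) and not one of the providers shipped in this commit.

from typing import Iterator, Optional
from apps.scan.providers import TargetProvider, ProviderContext

class StaticTargetProvider(TargetProvider):
    """Hypothetical in-memory provider for unit tests."""

    def __init__(self, websites: list[str], endpoints: list[str] | None = None,
                 context: Optional[ProviderContext] = None):
        super().__init__(context)
        self._websites = websites
        self._endpoints = endpoints or []

    def iter_subdomains(self) -> Iterator[str]:
        return iter(())

    def iter_host_port_urls(self) -> Iterator[str]:
        return iter(())

    def iter_websites(self) -> Iterator[str]:
        yield from self._websites

    def iter_endpoints(self) -> Iterator[str]:
        yield from self._endpoints

    def get_blacklist_filter(self):
        return None  # no filtering needed for static test data

provider = StaticTargetProvider(websites=["http://example.com"])
print(list(provider.iter_websites()))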

View File

@@ -0,0 +1,133 @@
"""
数据库目标提供者模块
提供基于数据库查询的目标提供者实现。
用于完整扫描模式,从 Target 关联的资产表查询数据。
"""
import logging
from typing import TYPE_CHECKING, Iterator, Optional
from .base import ProviderContext, TargetProvider
if TYPE_CHECKING:
from apps.common.utils import BlacklistFilter
logger = logging.getLogger(__name__)
class DatabaseTargetProvider(TargetProvider):
"""
数据库目标提供者 - 从 Target 表及关联资产表查询
用于完整扫描模式,查询目标下的所有历史资产。
数据来源:
- iter_target_name(): Target 表(根域名/IP/CIDR)
- iter_subdomains(): Subdomain 表
- iter_host_port_urls(): HostPortMapping 表
- iter_websites(): WebSite 表
- iter_endpoints(): Endpoint 表
- iter_default_urls(): 从 Target 本身生成默认 URL
回退逻辑由调用方(Task/Flow)决定,Provider 只负责单一数据源查询。
使用方式:
provider = DatabaseTargetProvider(target_id=123)
# 端口扫描:显式组合
for name in provider.iter_target_name():
scan_port(name) # CIDR 需要调用方自己展开
for subdomain in provider.iter_subdomains():
scan_port(subdomain)
# 调用方控制回退
urls = list(provider.iter_endpoints())
if not urls:
urls = list(provider.iter_websites())
if not urls:
urls = list(provider.iter_default_urls())
"""
def __init__(self, target_id: int, context: Optional[ProviderContext] = None):
ctx = context or ProviderContext()
ctx.target_id = target_id
super().__init__(ctx)
self._blacklist_filter: Optional['BlacklistFilter'] = None
def iter_subdomains(self) -> Iterator[str]:
"""从 Subdomain 表查询子域名列表"""
from apps.asset.services.asset.subdomain_service import SubdomainService
blacklist = self.get_blacklist_filter()
for domain in SubdomainService().iter_subdomain_names_by_target(
target_id=self.target_id,
chunk_size=1000
):
if not blacklist or blacklist.is_allowed(domain):
yield domain
def iter_host_port_urls(self) -> Iterator[str]:
"""从 HostPortMapping 表生成待探测的 URL"""
from apps.asset.models import HostPortMapping
blacklist = self.get_blacklist_filter()
queryset = HostPortMapping.objects.filter(
target_id=self.target_id
).values('host', 'port').distinct()
for mapping in queryset.iterator(chunk_size=1000):
host = mapping['host']
port = mapping['port']
if port == 80:
urls = [f"http://{host}"]
elif port == 443:
urls = [f"https://{host}"]
else:
urls = [f"http://{host}:{port}", f"https://{host}:{port}"]
for url in urls:
if not blacklist or blacklist.is_allowed(url):
yield url
def iter_websites(self) -> Iterator[str]:
"""从 WebSite 表查询已存活网站 URL"""
from apps.asset.models import WebSite
blacklist = self.get_blacklist_filter()
queryset = WebSite.objects.filter(
target_id=self.target_id
).values_list('url', flat=True)
for url in queryset.iterator(chunk_size=1000):
if url:
if not blacklist or blacklist.is_allowed(url):
yield url
def iter_endpoints(self) -> Iterator[str]:
"""从 Endpoint 表查询端点 URL"""
from apps.asset.models import Endpoint
blacklist = self.get_blacklist_filter()
queryset = Endpoint.objects.filter(
target_id=self.target_id
).values_list('url', flat=True)
for url in queryset.iterator(chunk_size=1000):
if url:
if not blacklist or blacklist.is_allowed(url):
yield url
def get_blacklist_filter(self) -> Optional['BlacklistFilter']:
"""获取黑名单过滤器(延迟加载)"""
if self._blacklist_filter is None:
from apps.common.services import BlacklistService
from apps.common.utils import BlacklistFilter
rules = BlacklistService().get_rules(self.target_id)
self._blacklist_filter = BlacklistFilter(rules)
return self._blacklist_filter
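The port-to-scheme mapping used by iter_host_port_urls() above can be exercised on its own; a small illustrative helper (not part of the diff):

def host_port_to_urls(host: str, port: int) -> list[str]:
    # Mirrors the branching above: 80 -> http only, 443 -> https only,
    # any other port -> both schemes with the port made explicit.
    if port == 80:
        return [f"http://{host}"]
    if port == 443:
        return [f"https://{host}"]
    return [f"http://{host}:{port}", f"https://{host}:{port}"]

# host_port_to_urls("example.com", 8443)
# -> ['http://example.com:8443', 'https://example.com:8443']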

View File

@@ -0,0 +1,111 @@
"""
快照目标提供者模块
提供基于快照表的目标提供者实现。
用于快速扫描的阶段间数据传递。
"""
import logging
from typing import Iterator, Optional
from .base import ProviderContext, TargetProvider
logger = logging.getLogger(__name__)
class SnapshotTargetProvider(TargetProvider):
"""
快照目标提供者 - 从快照表读取本次扫描的数据
用于快速扫描的阶段间数据传递,解决精确扫描控制问题。
核心价值:
- 只返回本次扫描(scan_id)发现的资产
- 避免扫描历史数据(DatabaseTargetProvider 会扫描所有历史资产)
特点:
- 通过 scan_id 过滤快照表
- 不应用黑名单过滤(数据已在上一阶段过滤)
- 每个 iter_* 方法只查对应的快照表(单一职责)
- 回退逻辑由调用方(Task/Flow)决定
使用场景:
provider = SnapshotTargetProvider(scan_id=100)
# 单一数据源
for url in provider.iter_websites():
take_screenshot(url)
# 调用方控制回退
urls = list(provider.iter_endpoints())
if not urls:
urls = list(provider.iter_websites())
if not urls:
urls = list(provider.iter_default_urls())
"""
def __init__(
self,
scan_id: int,
context: Optional[ProviderContext] = None
):
"""
初始化快照目标提供者
Args:
scan_id: 扫描任务 ID(必需)
context: Provider 上下文
"""
ctx = context or ProviderContext()
ctx.scan_id = scan_id
super().__init__(ctx)
self._scan_id = scan_id
def iter_subdomains(self) -> Iterator[str]:
"""从 SubdomainSnapshot 迭代子域名列表"""
from apps.asset.services.snapshot import SubdomainSnapshotsService
service = SubdomainSnapshotsService()
yield from service.iter_subdomain_names_by_scan(
scan_id=self._scan_id,
chunk_size=1000
)
def iter_host_port_urls(self) -> Iterator[str]:
"""从 HostPortMappingSnapshot 生成待探测的 URL"""
from apps.asset.services.snapshot import HostPortMappingSnapshotsService
service = HostPortMappingSnapshotsService()
for mapping in service.iter_unique_host_ports_by_scan(
scan_id=self._scan_id,
batch_size=1000
):
host = mapping['host']
port = mapping['port']
if port == 80:
yield f"http://{host}"
elif port == 443:
yield f"https://{host}"
else:
yield f"http://{host}:{port}"
yield f"https://{host}:{port}"
def iter_websites(self) -> Iterator[str]:
"""从 WebsiteSnapshot 迭代网站 URL"""
from apps.asset.services.snapshot import WebsiteSnapshotsService
service = WebsiteSnapshotsService()
yield from service.iter_website_urls_by_scan(
scan_id=self._scan_id,
chunk_size=1000
)
def iter_endpoints(self) -> Iterator[str]:
"""从 EndpointSnapshot 迭代端点 URL"""
from apps.asset.services.snapshot import EndpointSnapshotsService
service = EndpointSnapshotsService()
queryset = service.get_by_scan(scan_id=self._scan_id)
for endpoint in queryset.iterator(chunk_size=1000):
yield endpoint.url
def get_blacklist_filter(self) -> None:
"""快照数据已在上一阶段过滤过了"""
return None
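Because both providers expose the same iterator surface, the fallback chain stays in the caller, as the docstrings above describe; a hedged sketch of a call site (the mapping of scan_mode to provider is an assumption here, and collect_vuln_scan_urls is a hypothetical helper):

from apps.scan.providers import SnapshotTargetProvider
from apps.scan.providers.database_provider import DatabaseTargetProvider

def build_provider(scan_mode: str, target_id: int, scan_id: int):
    # Assumed convention: 'quick' reads only this run's snapshots,
    # 'full' reads all historical assets for the target.
    if scan_mode == 'quick':
        return SnapshotTargetProvider(scan_id=scan_id)
    return DatabaseTargetProvider(target_id=target_id)

def collect_vuln_scan_urls(provider) -> list[str]:
    # Caller-controlled fallback: Endpoint -> WebSite -> default URLs.
    urls = list(provider.iter_endpoints())
    if not urls:
        urls = list(provider.iter_websites())
    if not urls:
        urls = list(provider.iter_default_urls())
    return urls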

View File

@@ -0,0 +1,3 @@
"""
扫描目标提供者测试模块
"""

View File

@@ -0,0 +1,168 @@
"""
DatabaseTargetProvider 属性测试
Property 7: DatabaseTargetProvider Blacklist Application
*For any* 带有黑名单规则的 target_id,DatabaseTargetProvider 的 iter_subdomains()
应该过滤掉匹配黑名单规则的目标。
**Validates: Requirements 2.3, 10.1, 10.2, 10.3**
"""
import pytest
from unittest.mock import patch, MagicMock
from hypothesis import given, strategies as st, settings
from apps.scan.providers.database_provider import DatabaseTargetProvider
from apps.scan.providers.base import ProviderContext
# 生成有效域名的策略
def valid_domain_strategy():
"""生成有效的域名"""
label = st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=2,
max_size=10
)
return st.builds(
lambda a, b, c: f"{a}.{b}.{c}",
label, label, st.sampled_from(['com', 'net', 'org', 'io'])
)
class MockBlacklistFilter:
"""模拟黑名单过滤器"""
def __init__(self, blocked_patterns: list):
self.blocked_patterns = blocked_patterns
def is_allowed(self, target: str) -> bool:
"""检查目标是否被允许(不在黑名单中)"""
for pattern in self.blocked_patterns:
if pattern in target:
return False
return True
class TestDatabaseTargetProviderProperties:
"""DatabaseTargetProvider 属性测试类"""
@given(
subdomains=st.lists(valid_domain_strategy(), min_size=1, max_size=20),
blocked_keyword=st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=2,
max_size=5
)
)
@settings(max_examples=100)
def test_property_7_blacklist_filters_subdomains(self, subdomains, blocked_keyword):
"""
Property 7: DatabaseTargetProvider Blacklist Application (subdomains)
Feature: scan-target-provider, Property 7: DatabaseTargetProvider Blacklist Application
**Validates: Requirements 2.3, 10.1, 10.2, 10.3**
For any set of subdomains and a blacklist keyword, the provider should filter out
all subdomains containing the blocked keyword.
"""
# 创建模拟的黑名单过滤器
mock_filter = MockBlacklistFilter([blocked_keyword])
# 创建 provider 并注入模拟的黑名单过滤器
provider = DatabaseTargetProvider(target_id=1)
provider._blacklist_filter = mock_filter
with patch('apps.asset.services.asset.subdomain_service.SubdomainService') as mock_subdomain_service:
mock_subdomain_service.return_value.iter_subdomain_names_by_target.return_value = iter(subdomains)
# 获取结果
result = list(provider.iter_subdomains())
# 验证:所有结果都不包含被阻止的关键词
for subdomain in result:
assert blocked_keyword not in subdomain, f"Subdomain '{subdomain}' should be filtered by blacklist keyword '{blocked_keyword}'"
# 验证:所有不包含关键词的子域名都应该在结果中
expected_allowed = [s for s in subdomains if blocked_keyword not in s]
assert set(result) == set(expected_allowed)
class TestDatabaseTargetProviderUnit:
"""DatabaseTargetProvider 单元测试类"""
def test_target_id_in_context(self):
"""测试 target_id 正确设置到上下文中"""
provider = DatabaseTargetProvider(target_id=123)
assert provider.target_id == 123
assert provider.context.target_id == 123
def test_context_propagation(self):
"""测试上下文传递"""
ctx = ProviderContext(scan_id=789)
provider = DatabaseTargetProvider(target_id=123, context=ctx)
assert provider.target_id == 123 # target_id 被覆盖
assert provider.scan_id == 789
def test_blacklist_filter_lazy_loading(self):
"""测试黑名单过滤器延迟加载"""
provider = DatabaseTargetProvider(target_id=123)
# 初始时 _blacklist_filter 为 None
assert provider._blacklist_filter is None
# 模拟 BlacklistService
with patch('apps.common.services.BlacklistService') as mock_service, \
patch('apps.common.utils.BlacklistFilter') as mock_filter_class:
mock_service.return_value.get_rules.return_value = []
mock_filter_instance = MagicMock()
mock_filter_class.return_value = mock_filter_instance
# 第一次调用
result1 = provider.get_blacklist_filter()
assert result1 == mock_filter_instance
# 第二次调用应该返回缓存的实例
result2 = provider.get_blacklist_filter()
assert result2 == mock_filter_instance
# BlacklistService 只应该被调用一次
mock_service.return_value.get_rules.assert_called_once_with(123)
def test_get_target_name(self):
"""测试 get_target_name 返回 Target 名称"""
provider = DatabaseTargetProvider(target_id=123)
mock_target = MagicMock()
mock_target.name = 'example.com'
with patch('apps.targets.services.TargetService') as mock_service:
mock_service.return_value.get_target.return_value = mock_target
result = provider.get_target_name()
assert result == 'example.com'
def test_get_target_name_nonexistent(self):
"""测试不存在的 target 返回 None"""
provider = DatabaseTargetProvider(target_id=99999)
with patch('apps.targets.services.TargetService') as mock_service:
mock_service.return_value.get_target.return_value = None
result = provider.get_target_name()
assert result is None
def test_iter_subdomains_empty(self):
"""测试空子域名列表"""
provider = DatabaseTargetProvider(target_id=123)
with patch('apps.asset.services.asset.subdomain_service.SubdomainService') as mock_service, \
patch('apps.common.services.BlacklistService') as mock_blacklist_service:
mock_service.return_value.iter_subdomain_names_by_target.return_value = iter([])
mock_blacklist_service.return_value.get_rules.return_value = []
result = list(provider.iter_subdomains())
assert result == []
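valid_domain_strategy() above is equivalent to composing a lowercase-letter alphabet with st.builds; a quick interactive check of what it generates (assumes hypothesis is installed; .example() is intended for exploration only, not for use inside tests):

from hypothesis import strategies as st

label = st.text(alphabet="abcdefghijklmnopqrstuvwxyz", min_size=2, max_size=10)
domain = st.builds(
    lambda a, b, tld: f"{a}.{b}.{tld}",
    label, label, st.sampled_from(["com", "net", "org", "io"]),
)
print(domain.example())   # e.g. "qw.backend.io"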

View File

@@ -0,0 +1,121 @@
"""
SnapshotTargetProvider 单元测试
"""
import pytest
from unittest.mock import Mock, patch
from apps.scan.providers import SnapshotTargetProvider, ProviderContext
class TestSnapshotTargetProvider:
"""SnapshotTargetProvider 测试类"""
def test_init_with_scan_id(self):
"""测试初始化"""
provider = SnapshotTargetProvider(scan_id=100)
assert provider.scan_id == 100
assert provider.target_id is None
def test_init_with_context(self):
"""测试带 context 初始化"""
ctx = ProviderContext(target_id=1, scan_id=100)
provider = SnapshotTargetProvider(scan_id=100, context=ctx)
assert provider.scan_id == 100
assert provider.target_id == 1
@patch('apps.asset.services.snapshot.SubdomainSnapshotsService')
def test_iter_subdomains(self, mock_service_class):
"""测试从子域名快照迭代子域名"""
mock_service = Mock()
mock_service.iter_subdomain_names_by_scan.return_value = iter([
"a.example.com",
"b.example.com"
])
mock_service_class.return_value = mock_service
provider = SnapshotTargetProvider(scan_id=100)
subdomains = list(provider.iter_subdomains())
assert subdomains == ["a.example.com", "b.example.com"]
mock_service.iter_subdomain_names_by_scan.assert_called_once_with(
scan_id=100,
chunk_size=1000
)
@patch('apps.asset.services.snapshot.HostPortMappingSnapshotsService')
def test_iter_host_port_urls(self, mock_service_class):
"""测试从主机端口映射快照生成 URL"""
mock_service = Mock()
mock_service.iter_unique_host_ports_by_scan.return_value = iter([
{'host': 'example.com', 'port': 80},
{'host': 'example.com', 'port': 443},
{'host': 'example.com', 'port': 8080},
])
mock_service_class.return_value = mock_service
provider = SnapshotTargetProvider(scan_id=100)
urls = list(provider.iter_host_port_urls())
assert urls == [
"http://example.com",
"https://example.com",
"http://example.com:8080",
"https://example.com:8080",
]
@patch('apps.asset.services.snapshot.WebsiteSnapshotsService')
def test_iter_websites(self, mock_service_class):
"""测试从网站快照迭代 URL"""
mock_service = Mock()
mock_service.iter_website_urls_by_scan.return_value = iter([
"http://example.com",
"https://example.com"
])
mock_service_class.return_value = mock_service
provider = SnapshotTargetProvider(scan_id=100)
urls = list(provider.iter_websites())
assert urls == ["http://example.com", "https://example.com"]
mock_service.iter_website_urls_by_scan.assert_called_once_with(
scan_id=100,
chunk_size=1000
)
@patch('apps.asset.services.snapshot.EndpointSnapshotsService')
def test_iter_endpoints(self, mock_service_class):
"""测试从端点快照迭代 URL"""
mock_endpoint1 = Mock()
mock_endpoint1.url = "http://example.com/api/v1"
mock_endpoint2 = Mock()
mock_endpoint2.url = "http://example.com/api/v2"
mock_queryset = Mock()
mock_queryset.iterator.return_value = iter([mock_endpoint1, mock_endpoint2])
mock_service = Mock()
mock_service.get_by_scan.return_value = mock_queryset
mock_service_class.return_value = mock_service
provider = SnapshotTargetProvider(scan_id=100)
urls = list(provider.iter_endpoints())
assert urls == ["http://example.com/api/v1", "http://example.com/api/v2"]
mock_service.get_by_scan.assert_called_once_with(scan_id=100)
def test_get_blacklist_filter(self):
"""测试黑名单过滤器(快照模式不使用黑名单)"""
provider = SnapshotTargetProvider(scan_id=100)
assert provider.get_blacklist_filter() is None
def test_context_propagation(self):
"""测试上下文传递"""
ctx = ProviderContext(target_id=456, scan_id=789)
provider = SnapshotTargetProvider(scan_id=100, context=ctx)
assert provider.target_id == 456
assert provider.scan_id == 100

View File

@@ -464,6 +464,7 @@ class DjangoScanRepository:
'endpoints': scan.endpoint_snapshots.count(),
'ips': ips_count,
'directories': scan.directory_snapshots.count(),
'screenshots': scan.screenshot_snapshots.count(),
'vulns_total': total_vulns,
'vulns_critical': severity_stats['critical'],
'vulns_high': severity_stats['high'],
@@ -478,6 +479,7 @@ class DjangoScanRepository:
'cached_endpoints_count': stats['endpoints'],
'cached_ips_count': stats['ips'],
'cached_directories_count': stats['directories'],
'cached_screenshots_count': stats['screenshots'],
'cached_vulns_total': stats['vulns_total'],
'cached_vulns_critical': stats['vulns_critical'],
'cached_vulns_high': stats['vulns_high'],

View File

@@ -11,109 +11,6 @@ import os
import traceback
def diagnose_prefect_environment():
"""诊断 Prefect 运行环境,输出详细信息用于排查问题"""
print("\n" + "="*60)
print("Prefect 环境诊断")
print("="*60)
# 1. 检查 Prefect 相关环境变量
print("\n[诊断] Prefect 环境变量:")
prefect_vars = [
'PREFECT_HOME',
'PREFECT_API_URL',
'PREFECT_SERVER_EPHEMERAL_ENABLED',
'PREFECT_SERVER_EPHEMERAL_STARTUP_TIMEOUT_SECONDS',
'PREFECT_SERVER_DATABASE_CONNECTION_URL',
'PREFECT_LOGGING_LEVEL',
'PREFECT_DEBUG_MODE',
]
for var in prefect_vars:
value = os.environ.get(var, 'NOT SET')
print(f" {var}={value}")
# 2. 检查 PREFECT_HOME 目录
prefect_home = os.environ.get('PREFECT_HOME', os.path.expanduser('~/.prefect'))
print(f"\n[诊断] PREFECT_HOME 目录: {prefect_home}")
if os.path.exists(prefect_home):
print(f" ✓ 目录存在")
print(f" 可写: {os.access(prefect_home, os.W_OK)}")
try:
files = os.listdir(prefect_home)
print(f" 文件列表: {files[:10]}{'...' if len(files) > 10 else ''}")
except Exception as e:
print(f" ✗ 无法列出文件: {e}")
else:
print(f" 目录不存在,尝试创建...")
try:
os.makedirs(prefect_home, exist_ok=True)
print(f" ✓ 创建成功")
except Exception as e:
print(f" ✗ 创建失败: {e}")
# 3. 检查 uvicorn 是否可用
print(f"\n[诊断] uvicorn 可用性:")
import shutil
uvicorn_path = shutil.which('uvicorn')
if uvicorn_path:
print(f" ✓ uvicorn 路径: {uvicorn_path}")
else:
print(f" ✗ uvicorn 不在 PATH 中")
print(f" PATH: {os.environ.get('PATH', 'NOT SET')}")
# 4. 检查 Prefect 版本
print(f"\n[诊断] Prefect 版本:")
try:
import prefect
print(f" ✓ prefect=={prefect.__version__}")
except Exception as e:
print(f" ✗ 无法导入 prefect: {e}")
# 5. 检查 SQLite 支持
print(f"\n[诊断] SQLite 支持:")
try:
import sqlite3
print(f" ✓ sqlite3 版本: {sqlite3.sqlite_version}")
# 测试创建数据库
test_db = os.path.join(prefect_home, 'test.db')
conn = sqlite3.connect(test_db)
conn.execute('CREATE TABLE IF NOT EXISTS test (id INTEGER)')
conn.close()
os.remove(test_db)
print(f" ✓ SQLite 读写测试通过")
except Exception as e:
print(f" ✗ SQLite 测试失败: {e}")
# 6. 检查端口绑定能力
print(f"\n[诊断] 端口绑定测试:")
try:
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(('127.0.0.1', 0))
port = sock.getsockname()[1]
sock.close()
print(f" ✓ 可以绑定 127.0.0.1 端口 (测试端口: {port})")
except Exception as e:
print(f" ✗ 端口绑定失败: {e}")
# 7. 检查内存情况
print(f"\n[诊断] 系统资源:")
try:
import psutil
mem = psutil.virtual_memory()
print(f" 内存总量: {mem.total / 1024 / 1024:.0f} MB")
print(f" 可用内存: {mem.available / 1024 / 1024:.0f} MB")
print(f" 内存使用率: {mem.percent}%")
except ImportError:
print(f" psutil 未安装,跳过内存检查")
except Exception as e:
print(f" ✗ 资源检查失败: {e}")
print("\n" + "="*60)
print("诊断完成")
print("="*60 + "\n")
def main():
print("="*60)
print("run_initiate_scan.py 启动")
@@ -137,25 +34,19 @@ def main():
print("[2/4] 解析命令行参数...")
parser = argparse.ArgumentParser(description="执行扫描初始化 Flow")
parser.add_argument("--scan_id", type=int, required=True, help="扫描任务 ID")
parser.add_argument("--target_name", type=str, required=True, help="目标名称")
parser.add_argument("--target_id", type=int, required=True, help="目标 ID")
parser.add_argument("--scan_workspace_dir", type=str, required=True, help="扫描工作目录")
parser.add_argument("--engine_name", type=str, required=True, help="引擎名称")
parser.add_argument("--scheduled_scan_name", type=str, default=None, help="定时扫描任务名称(可选)")
args = parser.parse_args()
print(f"[2/4] ✓ 参数解析成功:")
print("[2/4] ✓ 参数解析成功:")
print(f" scan_id: {args.scan_id}")
print(f" target_name: {args.target_name}")
print(f" target_id: {args.target_id}")
print(f" scan_workspace_dir: {args.scan_workspace_dir}")
print(f" engine_name: {args.engine_name}")
print(f" scheduled_scan_name: {args.scheduled_scan_name}")
# 2.5. 运行 Prefect 环境诊断(仅在 DEBUG 模式下)
if os.environ.get('DEBUG', '').lower() == 'true':
diagnose_prefect_environment()
# 3. 现在可以安全导入 Django 相关模块
print("[3/4] 导入 initiate_scan_flow...")
try:
@@ -171,7 +62,6 @@ def main():
try:
result = initiate_scan_flow(
scan_id=args.scan_id,
target_name=args.target_name,
target_id=args.target_id,
scan_workspace_dir=args.scan_workspace_dir,
engine_name=args.engine_name,

View File

@@ -15,11 +15,11 @@ class ScanSerializer(serializers.ModelSerializer):
fields = [
'id', 'target', 'target_name', 'engine_ids', 'engine_names',
'created_at', 'stopped_at', 'status', 'results_dir',
'container_ids', 'error_message'
'container_ids', 'error_message', 'scan_mode'
]
read_only_fields = [
'id', 'created_at', 'stopped_at', 'results_dir',
'container_ids', 'error_message', 'status'
'container_ids', 'error_message', 'status', 'scan_mode'
]
def get_target_name(self, obj):
@@ -39,9 +39,10 @@ class ScanHistorySerializer(serializers.ModelSerializer):
class Meta:
model = Scan
fields = [
'id', 'target', 'target_name', 'engine_ids', 'engine_names',
'worker_name', 'created_at', 'status', 'error_message', 'summary',
'progress', 'current_stage', 'stage_progress'
'id', 'target', 'target_name', 'engine_ids', 'engine_names',
'worker_name', 'created_at', 'status', 'error_message', 'summary',
'progress', 'current_stage', 'stage_progress', 'yaml_configuration',
'scan_mode'
]
def get_summary(self, obj):
@@ -51,6 +52,7 @@ class ScanHistorySerializer(serializers.ModelSerializer):
'endpoints': obj.cached_endpoints_count or 0,
'ips': obj.cached_ips_count or 0,
'directories': obj.cached_directories_count or 0,
'screenshots': obj.cached_screenshots_count or 0,
}
summary['vulnerabilities'] = {
'total': obj.cached_vulns_total or 0,

View File

@@ -17,23 +17,15 @@ from .scan_state_service import ScanStateService
from .scan_control_service import ScanControlService
from .scan_stats_service import ScanStatsService
from .scheduled_scan_service import ScheduledScanService
from .target_export_service import (
TargetExportService,
create_export_service,
export_urls_with_fallback,
DataSource,
)
from .scan_input_target_service import ScanInputTargetService
__all__ = [
'ScanService', # 主入口(向后兼容)
'ScanService',
'ScanCreationService',
'ScanStateService',
'ScanControlService',
'ScanStatsService',
'ScheduledScanService',
'TargetExportService', # 目标导出服务
'create_export_service',
'export_urls_with_fallback',
'DataSource',
'ScanInputTargetService',
]

View File

@@ -5,13 +5,16 @@
"""
import logging
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional, Literal, List, Dict, Any
from urllib.parse import urlparse
from django.db import transaction
from apps.common.validators import validate_url, detect_input_type, validate_domain, validate_ip, validate_cidr, is_valid_ip
from apps.common.validators import (
validate_url, detect_input_type, validate_domain,
validate_ip, validate_cidr, is_valid_ip
)
from apps.targets.services.target_service import TargetService
from apps.targets.models import Target
from apps.asset.dtos import WebSiteDTO
@@ -24,98 +27,72 @@ logger = logging.getLogger(__name__)
@dataclass
class ParsedInputDTO:
"""
解析输入 DTO
只在快速扫描流程中使用
"""
"""解析输入 DTO只在快速扫描流程中使用"""
original_input: str
input_type: Literal['url', 'domain', 'ip', 'cidr']
target_name: str # host/domain/ip/cidr
target_name: str
target_type: Literal['domain', 'ip', 'cidr']
website_url: Optional[str] = None # 根 URL(scheme://host[:port])
endpoint_url: Optional[str] = None # 完整 URL(含路径)
is_valid: bool = True
error: Optional[str] = None
website_url: Optional[str] = None
endpoint_url: Optional[str] = None
line_number: Optional[int] = None
# 验证状态放在嵌套结构中,减少顶层属性数量
validation: Dict[str, Any] = field(default_factory=lambda: {
'is_valid': True,
'error': None
})
@property
def is_valid(self) -> bool:
return self.validation.get('is_valid', True)
@property
def error(self) -> Optional[str]:
return self.validation.get('error')
class QuickScanService:
"""快速扫描服务 - 解析输入并创建资产"""
def __init__(self):
self.target_service = TargetService()
self.website_repo = DjangoWebSiteRepository()
self.endpoint_repo = DjangoEndpointRepository()
def parse_inputs(self, inputs: List[str]) -> List[ParsedInputDTO]:
"""
解析多行输入
Args:
inputs: 输入字符串列表(每行一个)
Returns:
解析结果列表(跳过空行)
"""
"""解析多行输入,返回解析结果列表(跳过空行)"""
results = []
for line_number, input_str in enumerate(inputs, start=1):
input_str = input_str.strip()
# 空行跳过
if not input_str:
continue
try:
# 检测输入类型
input_type = detect_input_type(input_str)
if input_type == 'url':
dto = self._parse_url_input(input_str, line_number)
else:
dto = self._parse_target_input(input_str, input_type, line_number)
results.append(dto)
except ValueError as e:
# 解析失败,记录错误
results.append(ParsedInputDTO(
original_input=input_str,
input_type='domain', # 默认类型
input_type='domain',
target_name=input_str,
target_type='domain',
is_valid=False,
error=str(e),
line_number=line_number
line_number=line_number,
validation={'is_valid': False, 'error': str(e)}
))
return results
def _parse_url_input(self, url_str: str, line_number: int) -> ParsedInputDTO:
"""
解析 URL 输入
Args:
url_str: URL 字符串
line_number: 行号
Returns:
ParsedInputDTO
"""
# 验证 URL 格式
"""解析 URL 输入"""
validate_url(url_str)
# 使用标准库解析
parsed = urlparse(url_str)
host = parsed.hostname # 不含端口
host = parsed.hostname
has_path = parsed.path and parsed.path != '/'
# 构建 root_url: scheme://host[:port]
root_url = f"{parsed.scheme}://{parsed.netloc}"
# 检测 host 类型(domain 或 ip)
target_type = 'ip' if is_valid_ip(host) else 'domain'
return ParsedInputDTO(
original_input=url_str,
input_type='url',
@@ -125,167 +102,98 @@ class QuickScanService:
endpoint_url=url_str if has_path else None,
line_number=line_number
)
def _parse_target_input(
self,
input_str: str,
input_type: str,
self,
input_str: str,
input_type: str,
line_number: int
) -> ParsedInputDTO:
"""
解析非 URL 输入(domain/ip/cidr)
Args:
input_str: 输入字符串
input_type: 输入类型
line_number: 行号
Returns:
ParsedInputDTO
"""
# 验证格式
if input_type == 'domain':
validate_domain(input_str)
target_type = 'domain'
elif input_type == 'ip':
validate_ip(input_str)
target_type = 'ip'
elif input_type == 'cidr':
validate_cidr(input_str)
target_type = 'cidr'
else:
"""解析非 URL 输入domain/ip/cidr"""
validators = {
'domain': (validate_domain, 'domain'),
'ip': (validate_ip, 'ip'),
'cidr': (validate_cidr, 'cidr'),
}
if input_type not in validators:
raise ValueError(f"未知的输入类型: {input_type}")
validator, target_type = validators[input_type]
validator(input_str)
return ParsedInputDTO(
original_input=input_str,
input_type=input_type,
target_name=input_str,
target_type=target_type,
website_url=None,
endpoint_url=None,
line_number=line_number
)
@transaction.atomic
def process_quick_scan(
self,
inputs: List[str],
engine_id: int
) -> Dict[str, Any]:
"""
处理快速扫描请求
Args:
inputs: 输入字符串列表
engine_id: 扫描引擎 ID
Returns:
处理结果字典
"""
# 1. 解析输入
def process_quick_scan(self, inputs: List[str]) -> Dict[str, Any]:
"""处理快速扫描请求"""
parsed_inputs = self.parse_inputs(inputs)
# 分离有效和无效输入
valid_inputs = [p for p in parsed_inputs if p.is_valid]
invalid_inputs = [p for p in parsed_inputs if not p.is_valid]
errors = [
{'line_number': p.line_number, 'input': p.original_input, 'error': p.error}
for p in invalid_inputs
]
if not valid_inputs:
return {
'targets': [],
'target_stats': {'created': 0, 'reused': 0, 'failed': len(invalid_inputs)},
'asset_stats': {'websites_created': 0, 'endpoints_created': 0},
'errors': [
{'line_number': p.line_number, 'input': p.original_input, 'error': p.error}
for p in invalid_inputs
]
'errors': errors
}
# 2. 创建资产
asset_result = self.create_assets_from_parsed_inputs(valid_inputs)
# 3. 返回结果
# 构建 target_name → inputs 映射
target_inputs_map: Dict[str, List[str]] = {}
for p in valid_inputs:
target_inputs_map.setdefault(p.target_name, []).append(p.original_input)
return {
'targets': asset_result['targets'],
'target_stats': asset_result['target_stats'],
'asset_stats': asset_result['asset_stats'],
'errors': [
{'line_number': p.line_number, 'input': p.original_input, 'error': p.error}
for p in invalid_inputs
]
'target_inputs_map': target_inputs_map,
'errors': errors
}
def create_assets_from_parsed_inputs(
self,
self,
parsed_inputs: List[ParsedInputDTO]
) -> Dict[str, Any]:
"""
从解析结果创建资产
Args:
parsed_inputs: 解析结果列表(只包含有效输入)
Returns:
创建结果字典
"""
# 1. 收集所有 target 数据(内存操作,去重)
targets_data = {}
for dto in parsed_inputs:
if dto.target_name not in targets_data:
targets_data[dto.target_name] = {'name': dto.target_name, 'type': dto.target_type}
"""从解析结果创建资产(只包含有效输入)"""
# 1. 收集并去重 target 数据
targets_data = {
dto.target_name: {'name': dto.target_name, 'type': dto.target_type}
for dto in parsed_inputs
}
targets_list = list(targets_data.values())
# 2. 批量创建 Target(复用现有方法)
# 2. 批量创建 Target
target_result = self.target_service.batch_create_targets(targets_list)
# 3. 查询刚创建的 Target,建立 name → id 映射
# 3. 建立 name → id 映射
target_names = [d['name'] for d in targets_list]
targets = Target.objects.filter(name__in=target_names)
target_id_map = {t.name: t.id for t in targets}
# 4. 收集 Website DTO(内存操作,去重)
website_dtos = []
seen_websites = set()
for dto in parsed_inputs:
if dto.website_url and dto.website_url not in seen_websites:
seen_websites.add(dto.website_url)
target_id = target_id_map.get(dto.target_name)
if target_id:
website_dtos.append(WebSiteDTO(
target_id=target_id,
url=dto.website_url,
host=dto.target_name
))
# 5. 批量创建 Website(存在即跳过)
websites_created = 0
if website_dtos:
websites_created = self.website_repo.bulk_create_ignore_conflicts(website_dtos)
# 6. 收集 Endpoint DTO(内存操作,去重)
endpoint_dtos = []
seen_endpoints = set()
for dto in parsed_inputs:
if dto.endpoint_url and dto.endpoint_url not in seen_endpoints:
seen_endpoints.add(dto.endpoint_url)
target_id = target_id_map.get(dto.target_name)
if target_id:
endpoint_dtos.append(EndpointDTO(
target_id=target_id,
url=dto.endpoint_url,
host=dto.target_name
))
# 7. 批量创建 Endpoint(存在即跳过)
endpoints_created = 0
if endpoint_dtos:
endpoints_created = self.endpoint_repo.bulk_create_ignore_conflicts(endpoint_dtos)
# 4. 批量创建 Website 和 Endpoint
websites_created = self._bulk_create_websites(parsed_inputs, target_id_map)
endpoints_created = self._bulk_create_endpoints(parsed_inputs, target_id_map)
return {
'targets': list(targets),
'target_stats': {
'created': target_result['created_count'],
'reused': 0, # bulk_create 无法区分新建和复用
'reused': 0,
'failed': target_result['failed_count']
},
'asset_stats': {
@@ -293,3 +201,53 @@ class QuickScanService:
'endpoints_created': endpoints_created
}
}
def _bulk_create_websites(
self,
parsed_inputs: List[ParsedInputDTO],
target_id_map: Dict[str, int]
) -> int:
"""批量创建 Website存在即跳过"""
website_dtos = []
seen = set()
for dto in parsed_inputs:
if not dto.website_url or dto.website_url in seen:
continue
seen.add(dto.website_url)
target_id = target_id_map.get(dto.target_name)
if target_id:
website_dtos.append(WebSiteDTO(
target_id=target_id,
url=dto.website_url,
host=dto.target_name
))
if not website_dtos:
return 0
return self.website_repo.bulk_create_ignore_conflicts(website_dtos)
def _bulk_create_endpoints(
self,
parsed_inputs: List[ParsedInputDTO],
target_id_map: Dict[str, int]
) -> int:
"""批量创建 Endpoint存在即跳过"""
endpoint_dtos = []
seen = set()
for dto in parsed_inputs:
if not dto.endpoint_url or dto.endpoint_url in seen:
continue
seen.add(dto.endpoint_url)
target_id = target_id_map.get(dto.target_name)
if target_id:
endpoint_dtos.append(EndpointDTO(
target_id=target_id,
url=dto.endpoint_url,
host=dto.target_name
))
if not endpoint_dtos:
return 0
return self.endpoint_repo.bulk_create_ignore_conflicts(endpoint_dtos)
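The URL branch of parse_inputs() leans entirely on urlparse; the derivation of host, root URL and endpoint URL can be reproduced standalone (a sketch of the same steps, without the validate_url check):

from urllib.parse import urlparse

url = "https://shop.example.com:8443/cart?item=1"
parsed = urlparse(url)

host = parsed.hostname                             # 'shop.example.com' (port stripped)
root_url = f"{parsed.scheme}://{parsed.netloc}"    # 'https://shop.example.com:8443'
has_path = bool(parsed.path and parsed.path != "/")
endpoint_url = url if has_path else None           # kept only when the URL carries a path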

View File

@@ -283,7 +283,8 @@ class ScanCreationService:
engine_ids: List[int],
engine_names: List[str],
yaml_configuration: str,
scheduled_scan_name: str | None = None
scheduled_scan_name: str | None = None,
scan_mode: str = 'full'
) -> List[Scan]:
"""
为多个目标批量创建扫描任务,后台异步分发到 Worker
@@ -294,6 +295,7 @@ class ScanCreationService:
engine_names: 引擎名称列表
yaml_configuration: YAML 格式的扫描配置
scheduled_scan_name: 定时扫描任务名称(可选,用于通知显示)
scan_mode: 扫描模式,'full' 或 'quick'(默认 'full')
Returns:
创建的 Scan 对象列表(立即返回,不等待分发完成)
@@ -316,6 +318,7 @@ class ScanCreationService:
results_dir=scan_workspace_dir,
status=ScanStatus.INITIATED,
container_ids=[],
scan_mode=scan_mode,
)
scans_to_create.append(scan)
except (ValidationError, ValueError) as e:
@@ -392,13 +395,13 @@ class ScanCreationService:
for data in scan_data:
scan_id = data['scan_id']
logger.info("-"*40)
logger.info("准备分发扫描任务 - Scan ID: %s, Target: %s", scan_id, data['target_name'])
logger.info("准备分发扫描任务 - Scan ID: %s, Target ID: %s", scan_id, data['target_id'])
try:
logger.info("调用 distributor.execute_scan_flow...")
success, message, container_id, worker_id = distributor.execute_scan_flow(
scan_id=scan_id,
target_name=data['target_name'],
target_id=data['target_id'],
target_name=data['target_name'],
scan_workspace_dir=data['results_dir'],
engine_name=data['engine_name'],
scheduled_scan_name=data.get('scheduled_scan_name'),

View File

@@ -0,0 +1,54 @@
"""
扫描输入目标服务
提供 ScanInputTarget 的写入操作。
"""
import logging
from typing import List
from apps.common.validators import detect_input_type
from apps.scan.models import ScanInputTarget
logger = logging.getLogger(__name__)
class ScanInputTargetService:
"""扫描输入目标服务,负责批量写入操作。"""
BATCH_SIZE = 1000
def bulk_create(self, scan_id: int, inputs: List[str]) -> int:
"""
批量创建扫描输入目标
Args:
scan_id: 扫描任务 ID
inputs: 输入字符串列表
Returns:
创建的记录数
"""
if not inputs:
return 0
records = []
for raw_input in inputs:
value = raw_input.strip()
if not value:
continue
try:
records.append(ScanInputTarget(
scan_id=scan_id,
value=value,
input_type=detect_input_type(value)
))
except ValueError as e:
logger.warning("跳过无效输入 '%s': %s", value, e)
if not records:
return 0
ScanInputTarget.objects.bulk_create(records, batch_size=self.BATCH_SIZE)
logger.info("批量创建 %d 条扫描输入目标 (scan_id=%d)", len(records), scan_id)
return len(records)
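A hedged usage sketch (assumes a configured Django environment; the last value is expected to be logged and skipped because detect_input_type should raise ValueError for it):

service = ScanInputTargetService()
created = service.bulk_create(
    scan_id=100,
    inputs=["example.com", "", "10.0.0.0/24", "https://example.com/login", "!!not-a-target!!"],
)
# Blank lines are dropped before parsing, so `created` counts only the rows
# actually written with bulk_create(batch_size=1000).
print(created)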

View File

@@ -1,25 +1,17 @@
"""
扫描任务服务
负责 Scan 模型的所有业务逻辑
负责 Scan 模型的所有业务逻辑,协调各个子服务
"""
from __future__ import annotations
import logging
import uuid
from typing import Dict, List, TYPE_CHECKING
from datetime import datetime
from pathlib import Path
from django.conf import settings
from django.db import transaction
from django.db.utils import DatabaseError, IntegrityError, OperationalError
from django.core.exceptions import ValidationError, ObjectDoesNotExist
from typing import Dict, List
from apps.scan.models import Scan
from apps.scan.repositories import DjangoScanRepository
from apps.targets.repositories import DjangoTargetRepository, DjangoOrganizationRepository
from apps.engine.repositories import DjangoEngineRepository
from apps.targets.models import Target
from apps.engine.models import ScanEngine
from apps.common.definitions import ScanStatus
@@ -30,115 +22,84 @@ logger = logging.getLogger(__name__)
class ScanService:
"""
扫描任务服务(协调者)
职责:
- 协调各个子服务
- 提供统一的公共接口
- 保持向后兼容
注意:
- 具体业务逻辑已拆分到子服务
- 本类主要负责委托和协调
职责:协调各个子服务,提供统一的公共接口
"""
# 终态集合:这些状态一旦设置,不应该被覆盖
FINAL_STATUSES = {
ScanStatus.COMPLETED,
ScanStatus.FAILED,
ScanStatus.CANCELLED
}
def __init__(self):
"""
初始化服务
"""
# 初始化子服务
from apps.scan.services.scan_creation_service import ScanCreationService
from apps.scan.services.scan_state_service import ScanStateService
from apps.scan.services.scan_control_service import ScanControlService
from apps.scan.services.scan_stats_service import ScanStatsService
self.creation_service = ScanCreationService()
self.state_service = ScanStateService()
self.control_service = ScanControlService()
self.stats_service = ScanStatsService()
# 保留 ScanRepository用于 get_scan 方法)
self.scan_repo = DjangoScanRepository()
def get_scan(self, scan_id: int, prefetch_relations: bool) -> Scan | None:
"""
获取扫描任务(包含关联对象)
自动预加载 engine 和 target,避免 N+1 查询问题
Args:
scan_id: 扫描任务 ID
Returns:
Scan 对象(包含 engine 和 target)或 None
"""
"""获取扫描任务(包含关联对象)"""
return self.scan_repo.get_by_id(scan_id, prefetch_relations)
def get_all_scans(self, prefetch_relations: bool = True):
"""获取所有扫描任务"""
return self.scan_repo.get_all(prefetch_relations=prefetch_relations)
def prepare_initiate_scan(
self,
organization_id: int | None = None,
target_id: int | None = None,
engine_id: int | None = None
) -> tuple[List[Target], ScanEngine]:
"""
为创建扫描任务做准备,返回所需的目标列表和扫描引擎
"""
"""为创建扫描任务做准备,返回目标列表和扫描引擎"""
return self.creation_service.prepare_initiate_scan(
organization_id, target_id, engine_id
)
def prepare_initiate_scan_multi_engine(
self,
organization_id: int | None = None,
target_id: int | None = None,
engine_ids: List[int] | None = None
) -> tuple[List[Target], str, List[str], List[int]]:
"""
为创建多引擎扫描任务做准备
Returns:
(目标列表, 合并配置, 引擎名称列表, 引擎ID列表)
"""
"""为创建多引擎扫描任务做准备"""
return self.creation_service.prepare_initiate_scan_multi_engine(
organization_id, target_id, engine_ids
)
def create_scans(
self,
targets: List[Target],
engine_ids: List[int],
engine_names: List[str],
yaml_configuration: str,
scheduled_scan_name: str | None = None
scheduled_scan_name: str | None = None,
scan_mode: str = 'full'
) -> List[Scan]:
"""批量创建扫描任务(委托给 ScanCreationService"""
"""批量创建扫描任务"""
return self.creation_service.create_scans(
targets, engine_ids, engine_names, yaml_configuration, scheduled_scan_name
targets, engine_ids, engine_names, yaml_configuration, scheduled_scan_name, scan_mode
)
# ==================== 状态管理方法(委托给 ScanStateService) ====================
# ==================== 状态管理方法 ====================
def update_status(
self,
scan_id: int,
status: ScanStatus,
self,
scan_id: int,
status: ScanStatus,
error_message: str | None = None,
stopped_at: datetime | None = None
) -> bool:
"""更新 Scan 状态(委托给 ScanStateService"""
return self.state_service.update_status(
scan_id, status, error_message, stopped_at
)
"""更新 Scan 状态"""
return self.state_service.update_status(scan_id, status, error_message, stopped_at)
def update_status_if_match(
self,
scan_id: int,
@@ -146,113 +107,56 @@ class ScanService:
new_status: ScanStatus,
stopped_at: datetime | None = None
) -> bool:
"""条件更新 Scan 状态(委托给 ScanStateService"""
"""条件更新 Scan 状态"""
return self.state_service.update_status_if_match(
scan_id, current_status, new_status, stopped_at
)
def update_cached_stats(self, scan_id: int) -> dict | None:
"""更新缓存统计数据(委托给 ScanStateService,返回统计数据字典"""
"""更新缓存统计数据,返回统计数据字典"""
return self.state_service.update_cached_stats(scan_id)
# ==================== 进度跟踪方法(委托给 ScanStateService) ====================
# ==================== 进度跟踪方法 ====================
def init_stage_progress(self, scan_id: int, stages: list[str]) -> bool:
"""初始化阶段进度(委托给 ScanStateService"""
"""初始化阶段进度"""
return self.state_service.init_stage_progress(scan_id, stages)
def start_stage(self, scan_id: int, stage: str) -> bool:
"""开始执行某个阶段(委托给 ScanStateService"""
"""开始执行某个阶段"""
return self.state_service.start_stage(scan_id, stage)
def complete_stage(self, scan_id: int, stage: str, detail: str | None = None) -> bool:
"""完成某个阶段(委托给 ScanStateService"""
"""完成某个阶段"""
return self.state_service.complete_stage(scan_id, stage, detail)
def fail_stage(self, scan_id: int, stage: str, error: str | None = None) -> bool:
"""标记某个阶段失败(委托给 ScanStateService"""
"""标记某个阶段失败"""
return self.state_service.fail_stage(scan_id, stage, error)
def cancel_running_stages(self, scan_id: int, final_status: str = "cancelled") -> bool:
"""取消所有正在运行的阶段(委托给 ScanStateService"""
"""取消所有正在运行的阶段"""
return self.state_service.cancel_running_stages(scan_id, final_status)
# TODO:待接入
def add_command_to_scan(self, scan_id: int, stage_name: str, tool_name: str, command: str) -> bool:
"""
增量添加命令到指定扫描阶段
Args:
scan_id: 扫描任务ID
stage_name: 阶段名称(如 'subdomain_discovery', 'port_scan')
tool_name: 工具名称
command: 执行命令
Returns:
bool: 是否成功添加
"""
try:
scan = self.get_scan(scan_id, prefetch_relations=False)
if not scan:
logger.error(f"扫描任务不存在: {scan_id}")
return False
stage_progress = scan.stage_progress or {}
# 确保指定阶段存在
if stage_name not in stage_progress:
stage_progress[stage_name] = {'status': 'running', 'commands': []}
# 确保 commands 列表存在
if 'commands' not in stage_progress[stage_name]:
stage_progress[stage_name]['commands'] = []
# 增量添加命令
command_entry = f"{tool_name}: {command}"
stage_progress[stage_name]['commands'].append(command_entry)
scan.stage_progress = stage_progress
scan.save(update_fields=['stage_progress'])
command_count = len(stage_progress[stage_name]['commands'])
logger.info(f"✓ 记录命令: {stage_name}.{tool_name} (总计: {command_count})")
return True
except Exception as e:
logger.error(f"记录命令失败: {e}")
return False
# ==================== 删除和控制方法(委托给 ScanControlService) ====================
# ==================== 删除和控制方法 ====================
def delete_scans_two_phase(self, scan_ids: List[int]) -> dict:
"""两阶段删除扫描任务(委托给 ScanControlService"""
"""两阶段删除扫描任务"""
return self.control_service.delete_scans_two_phase(scan_ids)
def stop_scan(self, scan_id: int) -> tuple[bool, int]:
"""停止扫描任务(委托给 ScanControlService"""
"""停止扫描任务"""
return self.control_service.stop_scan(scan_id)
def hard_delete_scans(self, scan_ids: List[int]) -> tuple[int, Dict[str, int]]:
"""
硬删除扫描任务(真正删除数据)
用于 Worker 容器中执行,删除已软删除的扫描及其关联数据。
Args:
scan_ids: 扫描任务 ID 列表
Returns:
(删除数量, 详情字典)
"""
"""硬删除扫描任务(真正删除数据)"""
return self.scan_repo.hard_delete_by_ids(scan_ids)
# ==================== 统计方法(委托给 ScanStatsService) ====================
# ==================== 统计方法 ====================
def get_statistics(self) -> dict:
"""获取扫描统计数据(委托给 ScanStatsService"""
"""获取扫描统计数据"""
return self.stats_service.get_statistics()
# 导出接口
__all__ = ['ScanService']
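add_command_to_scan() is still marked TODO above; once wired in, a call site might look like the following (values are illustrative, a Django context is assumed):

recorded = ScanService().add_command_to_scan(
    scan_id=42,
    stage_name="port_scan",
    tool_name="naabu",
    command="naabu -list hosts.txt -json -o ports.json",
)
# Appends "naabu: naabu -list hosts.txt -json -o ports.json" to
# stage_progress["port_scan"]["commands"] and returns True on success.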

View File

@@ -1,526 +0,0 @@
"""
目标导出服务
提供统一的目标提取和文件导出功能,支持:
- URL 导出(纯导出,不做隐式回退)
- 默认 URL 生成(独立方法)
- 带回退链的 URL 导出(用例层编排)
- 域名/IP 导出(用于端口扫描)
- 黑名单过滤集成
"""
import ipaddress
import logging
from pathlib import Path
from typing import Dict, Any, Optional, List, Callable
from django.db.models import QuerySet
from apps.common.utils import BlacklistFilter
logger = logging.getLogger(__name__)
class DataSource:
"""数据源类型常量"""
ENDPOINT = "endpoint"
WEBSITE = "website"
HOST_PORT = "host_port"
DEFAULT = "default"
def create_export_service(target_id: int) -> 'TargetExportService':
"""
工厂函数:创建带黑名单过滤的导出服务
Args:
target_id: 目标 ID用于加载黑名单规则
Returns:
TargetExportService: 配置好黑名单过滤器的导出服务实例
"""
from apps.common.services import BlacklistService
rules = BlacklistService().get_rules(target_id)
blacklist_filter = BlacklistFilter(rules)
return TargetExportService(blacklist_filter=blacklist_filter)
def export_urls_with_fallback(
target_id: int,
output_file: str,
sources: List[str],
batch_size: int = 1000
) -> Dict[str, Any]:
"""
带回退链的 URL 导出用例函数
按 sources 顺序尝试每个数据源,直到有数据返回。
回退逻辑:
1. 遍历 sources 列表
2. 对每个 source 构建 queryset 并调用 export_urls()
3. 如果 total_count > 0,返回
4. 如果 queryset_count > 0 但 total_count == 0(全被黑名单过滤),不回退
5. 如果 source == "default",调用 generate_default_urls()
Args:
target_id: 目标 ID
output_file: 输出文件路径
sources: 数据源优先级列表,如 ["endpoint", "website", "default"]
batch_size: 批次大小
Returns:
dict: {
'success': bool,
'output_file': str,
'total_count': int,
'source': str, # 实际使用的数据源
'tried_sources': List[str], # 尝试过的数据源
}
"""
from apps.asset.models import Endpoint, WebSite
export_service = create_export_service(target_id)
tried_sources = []
for source in sources:
tried_sources.append(source)
if source == DataSource.DEFAULT:
# 默认 URL 生成
result = export_service.generate_default_urls(target_id, output_file)
return {
'success': result['success'],
'output_file': result['output_file'],
'total_count': result['total_count'],
'source': DataSource.DEFAULT,
'tried_sources': tried_sources,
}
# 构建对应数据源的 queryset
if source == DataSource.ENDPOINT:
queryset = Endpoint.objects.filter(target_id=target_id).values_list('url', flat=True)
elif source == DataSource.WEBSITE:
queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)
else:
logger.warning("未知的数据源类型: %s,跳过", source)
continue
result = export_service.export_urls(
target_id=target_id,
output_path=output_file,
queryset=queryset,
batch_size=batch_size
)
# 有数据写入,返回
if result['total_count'] > 0:
logger.info("%s 导出 %d 条 URL", source, result['total_count'])
return {
'success': result['success'],
'output_file': result['output_file'],
'total_count': result['total_count'],
'source': source,
'tried_sources': tried_sources,
}
# 数据存在但全被黑名单过滤,不回退
if result['queryset_count'] > 0:
logger.info(
"%s%d 条数据但全被黑名单过滤filtered=%d),不回退",
source, result['queryset_count'], result['filtered_count']
)
return {
'success': result['success'],
'output_file': result['output_file'],
'total_count': 0,
'source': source,
'tried_sources': tried_sources,
}
# 数据源为空,继续尝试下一个
logger.info("%s 为空,尝试下一个数据源", source)
# 所有数据源都为空
logger.warning("所有数据源都为空,无法导出 URL")
return {
'success': True,
'output_file': output_file,
'total_count': 0,
'source': 'none',
'tried_sources': tried_sources,
}
class TargetExportService:
"""
目标导出服务 - 提供统一的目标提取和文件导出功能
使用方式:
# 方式 1使用用例函数推荐
from apps.scan.services.target_export_service import export_urls_with_fallback, DataSource
result = export_urls_with_fallback(
target_id=1,
output_file='/path/to/output.txt',
sources=[DataSource.ENDPOINT, DataSource.WEBSITE, DataSource.DEFAULT]
)
# 方式 2直接使用 Service纯导出不带回退
export_service = create_export_service(target_id)
result = export_service.export_urls(target_id, output_path, queryset)
"""
def __init__(self, blacklist_filter: Optional[BlacklistFilter] = None):
"""
初始化导出服务
Args:
blacklist_filter: 黑名单过滤器None 表示禁用过滤
"""
self.blacklist_filter = blacklist_filter
def export_urls(
self,
target_id: int,
output_path: str,
queryset: QuerySet,
url_field: str = 'url',
batch_size: int = 1000
) -> Dict[str, Any]:
"""
纯 URL 导出函数 - 只负责将 queryset 数据写入文件
不做任何隐式回退或默认 URL 生成。
Args:
target_id: 目标 ID
output_path: 输出文件路径
queryset: 数据源 queryset(由调用方构建,应为 values_list flat=True)
url_field: URL 字段名(用于黑名单过滤)
batch_size: 批次大小
Returns:
dict: {
'success': bool,
'output_file': str,
'total_count': int, # 实际写入数量
'queryset_count': int, # 原始数据数量(迭代计数)
'filtered_count': int, # 被黑名单过滤的数量
}
Raises:
IOError: 文件写入失败
"""
output_file = Path(output_path)
output_file.parent.mkdir(parents=True, exist_ok=True)
logger.info("开始导出 URL - target_id=%s, output=%s", target_id, output_path)
total_count = 0
filtered_count = 0
queryset_count = 0
try:
with open(output_file, 'w', encoding='utf-8', buffering=8192) as f:
for url in queryset.iterator(chunk_size=batch_size):
queryset_count += 1
if url:
# 黑名单过滤
if self.blacklist_filter and not self.blacklist_filter.is_allowed(url):
filtered_count += 1
continue
f.write(f"{url}\n")
total_count += 1
if total_count % 10000 == 0:
logger.info("已导出 %d 个 URL...", total_count)
except IOError as e:
logger.error("文件写入失败: %s - %s", output_path, e)
raise
if filtered_count > 0:
logger.info("黑名单过滤: 过滤 %d 个 URL", filtered_count)
logger.info(
"✓ URL 导出完成 - 写入: %d, 原始: %d, 过滤: %d, 文件: %s",
total_count, queryset_count, filtered_count, output_path
)
return {
'success': True,
'output_file': str(output_file),
'total_count': total_count,
'queryset_count': queryset_count,
'filtered_count': filtered_count,
}
def generate_default_urls(
self,
target_id: int,
output_path: str
) -> Dict[str, Any]:
"""
默认 URL 生成器
根据 Target 类型生成默认 URL
- DOMAIN: http(s)://domain
- IP: http(s)://ip
- CIDR: 展开为所有 IP 的 http(s)://ip
- URL: 直接使用目标 URL
Args:
target_id: 目标 ID
output_path: 输出文件路径
Returns:
dict: {
'success': bool,
'output_file': str,
'total_count': int,
}
"""
from apps.targets.services import TargetService
from apps.targets.models import Target
output_file = Path(output_path)
output_file.parent.mkdir(parents=True, exist_ok=True)
target_service = TargetService()
target = target_service.get_target(target_id)
if not target:
logger.warning("Target ID %d 不存在,无法生成默认 URL", target_id)
return {
'success': True,
'output_file': str(output_file),
'total_count': 0,
}
target_name = target.name
target_type = target.type
logger.info("生成默认 URLTarget 类型=%s, 名称=%s", target_type, target_name)
total_urls = 0
with open(output_file, 'w', encoding='utf-8', buffering=8192) as f:
if target_type == Target.TargetType.DOMAIN:
urls = [f"http://{target_name}", f"https://{target_name}"]
for url in urls:
if self._should_write_url(url):
f.write(f"{url}\n")
total_urls += 1
elif target_type == Target.TargetType.IP:
urls = [f"http://{target_name}", f"https://{target_name}"]
for url in urls:
if self._should_write_url(url):
f.write(f"{url}\n")
total_urls += 1
elif target_type == Target.TargetType.CIDR:
try:
network = ipaddress.ip_network(target_name, strict=False)
for ip in network.hosts():
urls = [f"http://{ip}", f"https://{ip}"]
for url in urls:
if self._should_write_url(url):
f.write(f"{url}\n")
total_urls += 1
if total_urls % 10000 == 0:
logger.info("已生成 %d 个 URL...", total_urls)
# /32 或 /128 特殊处理
if total_urls == 0:
ip = str(network.network_address)
urls = [f"http://{ip}", f"https://{ip}"]
for url in urls:
if self._should_write_url(url):
f.write(f"{url}\n")
total_urls += 1
except ValueError as e:
logger.error("CIDR 解析失败: %s - %s", target_name, e)
raise ValueError(f"无效的 CIDR: {target_name}") from e
elif target_type == Target.TargetType.URL:
if self._should_write_url(target_name):
f.write(f"{target_name}\n")
total_urls = 1
else:
logger.warning("不支持的 Target 类型: %s", target_type)
logger.info("✓ 默认 URL 生成完成 - 数量: %d", total_urls)
return {
'success': True,
'output_file': str(output_file),
'total_count': total_urls,
}
def _should_write_url(self, url: str) -> bool:
"""检查 URL 是否应该写入(通过黑名单过滤)"""
if self.blacklist_filter:
return self.blacklist_filter.is_allowed(url)
return True
def export_hosts(
self,
target_id: int,
output_path: str,
batch_size: int = 1000
) -> Dict[str, Any]:
"""
主机列表导出函数(用于端口扫描)
根据 Target 类型选择导出逻辑:
- DOMAIN: 从 Subdomain 表流式导出子域名
- IP: 直接写入 IP 地址
- CIDR: 展开为所有主机 IP
Args:
target_id: 目标 ID
output_path: 输出文件路径
batch_size: 批次大小
Returns:
dict: {
'success': bool,
'output_file': str,
'total_count': int,
'target_type': str
}
"""
from apps.targets.services import TargetService
from apps.targets.models import Target
from apps.asset.services.asset.subdomain_service import SubdomainService
output_file = Path(output_path)
output_file.parent.mkdir(parents=True, exist_ok=True)
# 获取 Target 信息
target_service = TargetService()
target = target_service.get_target(target_id)
if not target:
raise ValueError(f"Target ID {target_id} 不存在")
target_type = target.type
target_name = target.name
logger.info(
"开始导出主机列表 - Target ID: %d, Name: %s, Type: %s, 输出文件: %s",
target_id, target_name, target_type, output_path
)
total_count = 0
if target_type == Target.TargetType.DOMAIN:
total_count = self._export_domains(target_id, target_name, output_file, batch_size)
type_desc = "域名"
elif target_type == Target.TargetType.IP:
total_count = self._export_ip(target_name, output_file)
type_desc = "IP"
elif target_type == Target.TargetType.CIDR:
total_count = self._export_cidr(target_name, output_file)
type_desc = "CIDR IP"
else:
raise ValueError(f"不支持的目标类型: {target_type}")
logger.info(
"✓ 主机列表导出完成 - 类型: %s, 总数: %d, 文件: %s",
type_desc, total_count, output_path
)
return {
'success': True,
'output_file': str(output_file),
'total_count': total_count,
'target_type': target_type
}
def _export_domains(
self,
target_id: int,
target_name: str,
output_path: Path,
batch_size: int
) -> int:
"""导出域名类型目标的根域名 + 子域名"""
from apps.asset.services.asset.subdomain_service import SubdomainService
subdomain_service = SubdomainService()
domain_iterator = subdomain_service.iter_subdomain_names_by_target(
target_id=target_id,
chunk_size=batch_size
)
total_count = 0
written_domains = set() # 去重(子域名表可能已包含根域名)
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
# 1. 先写入根域名
if self._should_write_target(target_name):
f.write(f"{target_name}\n")
written_domains.add(target_name)
total_count += 1
# 2. 再写入子域名(跳过已写入的根域名)
for domain_name in domain_iterator:
if domain_name in written_domains:
continue
if self._should_write_target(domain_name):
f.write(f"{domain_name}\n")
written_domains.add(domain_name)
total_count += 1
if total_count % 10000 == 0:
logger.info("已导出 %d 个域名...", total_count)
return total_count
def _export_ip(self, target_name: str, output_path: Path) -> int:
"""导出 IP 类型目标"""
if self._should_write_target(target_name):
with open(output_path, 'w', encoding='utf-8') as f:
f.write(f"{target_name}\n")
return 1
return 0
def _export_cidr(self, target_name: str, output_path: Path) -> int:
"""导出 CIDR 类型目标,展开为每个 IP"""
network = ipaddress.ip_network(target_name, strict=False)
total_count = 0
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for ip in network.hosts():
ip_str = str(ip)
if self._should_write_target(ip_str):
f.write(f"{ip_str}\n")
total_count += 1
if total_count % 10000 == 0:
logger.info("已导出 %d 个 IP...", total_count)
# /32 或 /128 特殊处理
if total_count == 0:
ip_str = str(network.network_address)
if self._should_write_target(ip_str):
with open(output_path, 'w', encoding='utf-8') as f:
f.write(f"{ip_str}\n")
total_count = 1
return total_count
def _should_write_target(self, target: str) -> bool:
"""检查目标是否应该写入(通过黑名单过滤)"""
if self.blacklist_filter:
return self.blacklist_filter.is_allowed(target)
return True

View File

@@ -18,7 +18,7 @@ from .subdomain_discovery import (
# 指纹识别任务
from .fingerprint_detect import (
export_urls_for_fingerprint_task,
export_site_urls_for_fingerprint_task,
run_xingfinger_and_stream_update_tech_task,
)
@@ -35,6 +35,6 @@ __all__ = [
'merge_and_validate_task',
'save_domains_task',
# 指纹识别任务
'export_urls_for_fingerprint_task',
'export_site_urls_for_fingerprint_task',
'run_xingfinger_and_stream_update_tech_task',
]

View File

@@ -1,64 +1,76 @@
"""
导出站点 URL 到 TXT 文件的 Task
使用 export_urls_with_fallback 用例函数处理回退链逻辑
数据源: WebSite.url → Default
使用 TargetProvider 从任意数据源导出 URL(用于目录扫描)
数据源:WebSite,为空时回退到默认 URL
"""
import logging
from prefect import task
from pathlib import Path
from apps.scan.services.target_export_service import (
export_urls_with_fallback,
DataSource,
)
from apps.scan.providers import TargetProvider
logger = logging.getLogger(__name__)
@task(name="export_sites")
def export_sites_task(
target_id: int,
output_file: str,
batch_size: int = 1000,
provider: TargetProvider,
) -> dict:
"""
导出目标下的所有站点 URL 到 TXT 文件
数据源优先级(回退链):
1. WebSite 表 - 站点级别 URL
2. 默认生成 - 根据 Target 类型生成 http(s)://target_name
数据源:WebSite,为空时回退到默认 URL
Args:
target_id: 目标 ID
output_file: 输出文件路径(绝对路径)
batch_size: 每次读取的批次大小,默认 1000
provider: TargetProvider 实例
Returns:
dict: {
'success': bool,
'output_file': str,
'total_count': int
'total_count': int,
'source': str, # website | default
}
Raises:
ValueError: 参数错误
IOError: 文件写入失败
ValueError: provider 未提供
"""
result = export_urls_with_fallback(
target_id=target_id,
output_file=output_file,
sources=[DataSource.WEBSITE, DataSource.DEFAULT],
batch_size=batch_size,
)
if provider is None:
raise ValueError("必须提供 provider 参数")
logger.info("导出 URL - Provider: %s", type(provider).__name__)
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
# 按优先级获取数据源
urls = list(provider.iter_websites())
source = "website"
if not urls:
logger.info("WebSite 为空,生成默认 URL")
urls = list(provider.iter_default_urls())
source = "default"
# 写入文件
total_count = 0
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for url in urls:
f.write(f"{url}\n")
total_count += 1
logger.info(
"站点 URL 导出完成 - source=%s, count=%d",
result['source'], result['total_count']
" URL 导出完成 - 来源: %s, 总数: %d, 文件: %s",
source, total_count, str(output_path)
)
# 保持返回值格式不变(向后兼容)
return {
'success': result['success'],
'output_file': result['output_file'],
'total_count': result['total_count'],
'success': True,
'output_file': str(output_path),
'total_count': total_count,
'source': source,
}

View File

@@ -24,7 +24,7 @@ import json
import subprocess
import time
from pathlib import Path
from prefect import task
from typing import Generator, Optional, TYPE_CHECKING
from django.db import IntegrityError, OperationalError, DatabaseError
from psycopg2 import InterfaceError
@@ -305,11 +305,11 @@ def _save_batch(
return len(snapshot_items)
@task(
name='run_and_stream_save_directories',
retries=0,
log_prints=True
)
def run_and_stream_save_directories_task(
cmd: str,
tool_name: str,

View File

@@ -2,14 +2,14 @@
指纹识别任务模块
包含:
- export_urls_for_fingerprint_task: 导出 URL 到文件
- export_site_urls_for_fingerprint_task: 导出站点 URL 到文件
- run_xingfinger_and_stream_update_tech_task: 流式执行 xingfinger 并更新 tech
"""
from .export_urls_task import export_urls_for_fingerprint_task
from .export_site_urls_task import export_site_urls_for_fingerprint_task
from .run_xingfinger_task import run_xingfinger_and_stream_update_tech_task
__all__ = [
'export_urls_for_fingerprint_task',
'export_site_urls_for_fingerprint_task',
'run_xingfinger_and_stream_update_tech_task',
]

View File

@@ -0,0 +1,73 @@
"""
导出站点 URL 任务
使用 TargetProvider 从任意数据源导出站点 URL(用于指纹识别)
数据源:WebSite,为空时回退到默认 URL
"""
import logging
from pathlib import Path
from apps.scan.providers import TargetProvider
logger = logging.getLogger(__name__)
def export_site_urls_for_fingerprint_task(
output_file: str,
provider: TargetProvider,
) -> dict:
"""
导出目标下的 URL 到文件(用于指纹识别)
数据源:WebSite,为空时回退到默认 URL
Args:
output_file: 输出文件路径
provider: TargetProvider 实例
Returns:
dict: {
'output_file': str,
'total_count': int,
'source': str, # website | default
}
"""
if provider is None:
raise ValueError("必须提供 provider 参数")
logger.info("导出 URL - Provider: %s", type(provider).__name__)
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
# 按优先级获取数据源
urls = list(provider.iter_websites())
source = "website"
if not urls:
logger.info("WebSite 为空,生成默认 URL")
urls = list(provider.iter_default_urls())
source = "default"
# 写入文件
total_count = 0
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for url in urls:
f.write(f"{url}\n")
total_count += 1
logger.info(
"✓ URL 导出完成 - 来源: %s, 总数: %d, 文件: %s",
source, total_count, str(output_path)
)
return {
'output_file': str(output_path),
'total_count': total_count,
'source': source,
}
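A hedged call-site sketch for the new provider-based signature (assumes a Django runtime where the snapshot services are importable; the output path is illustrative):

from apps.scan.providers import SnapshotTargetProvider

provider = SnapshotTargetProvider(scan_id=100)
result = export_site_urls_for_fingerprint_task(
    output_file="/tmp/scan_100/fingerprint_urls.txt",
    provider=provider,
)
# result["source"] is "website" when WebsiteSnapshot had rows for this scan,
# otherwise "default"; result["total_count"] is the number of lines written.
print(result["source"], result["total_count"])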

View File

@@ -1,60 +0,0 @@
"""
导出 URL 任务
用于指纹识别前导出目标下的 URL 到文件
使用 export_urls_with_fallback 用例函数处理回退链逻辑
"""
import logging
from prefect import task
from apps.scan.services.target_export_service import (
export_urls_with_fallback,
DataSource,
)
logger = logging.getLogger(__name__)
@task(name="export_urls_for_fingerprint")
def export_urls_for_fingerprint_task(
target_id: int,
output_file: str,
source: str = 'website', # 保留参数,兼容旧调用(实际值由回退链决定)
batch_size: int = 1000
) -> dict:
"""
导出目标下的 URL 到文件(用于指纹识别)
数据源优先级(回退链):
1. WebSite 表 - 站点级别 URL
2. 默认生成 - 根据 Target 类型生成 http(s)://target_name
Args:
target_id: 目标 ID
output_file: 输出文件路径
source: 数据源类型(保留参数,兼容旧调用,实际值由回退链决定)
batch_size: 批量读取大小
Returns:
dict: {'output_file': str, 'total_count': int, 'source': str}
"""
result = export_urls_with_fallback(
target_id=target_id,
output_file=output_file,
sources=[DataSource.WEBSITE, DataSource.DEFAULT],
batch_size=batch_size,
)
logger.info(
"指纹识别 URL 导出完成 - source=%s, count=%d",
result['source'], result['total_count']
)
# 返回实际使用的数据源(不再固定为 "website"
return {
'output_file': result['output_file'],
'total_count': result['total_count'],
'source': result['source'],
}

View File

@@ -11,7 +11,7 @@ from typing import Optional, Generator
from urllib.parse import urlparse
from django.db import connection
from prefect import task
from apps.scan.utils import execute_stream
from apps.asset.dtos.snapshot import WebsiteSnapshotDTO
@@ -189,7 +189,7 @@ def _parse_xingfinger_stream_output(
logger.info("流式解析完成 - 总行数: %d, 有效记录: %d", total_lines, valid_records)
@task(name="run_xingfinger_and_stream_update_tech")
def run_xingfinger_and_stream_update_tech_task(
cmd: str,
tool_name: str,

View File

@@ -1,65 +1,71 @@
"""
Task that exports the host list to a TXT file
Uses TargetExportService.export_hosts() to handle the export logic in one place
The exported content depends on the Target type:
- DOMAIN: export subdomains from the Subdomain table
- IP: write target.name directly
- CIDR: expand all IPs in the CIDR range
Uses TargetProvider to export the host list from any data source.
"""
import logging
from prefect import task
from pathlib import Path
from apps.scan.services.target_export_service import create_export_service
from apps.scan.providers import TargetProvider
logger = logging.getLogger(__name__)
@task(name="export_hosts")
def export_hosts_task(
target_id: int,
output_file: str,
batch_size: int = 1000
provider: TargetProvider,
) -> dict:
"""
Export the host list to a TXT file
The exported content is decided automatically by the Target type:
- DOMAIN: export subdomains from the Subdomain table (streaming, supports 100k+ domains)
- IP: write target.name directly (a single IP)
- CIDR: expand all usable IPs in the CIDR range
Explicitly combines iter_target_hosts() + iter_subdomains().
Args:
target_id: target ID
output_file: output file path (absolute path)
batch_size: batch size per read, default 1000 (only applies to the DOMAIN type)
provider: TargetProvider instance
Returns:
dict: {
'success': bool,
'output_file': str,
'total_count': int,
'target_type': str
}
Raises:
ValueError: Target does not exist
ValueError: provider not supplied
IOError: file write failed
"""
# Create the export service via the factory function
export_service = create_export_service(target_id)
result = export_service.export_hosts(
target_id=target_id,
output_path=output_file,
batch_size=batch_size
)
# Keep the return value format unchanged (backward compatible)
if provider is None:
raise ValueError("必须提供 provider 参数")
logger.info("导出主机列表 - Provider: %s", type(provider).__name__)
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
total_count = 0
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
# 1. Export Target hosts (CIDR is expanded automatically; blacklist already applied)
for host in provider.iter_target_hosts():
f.write(f"{host}\n")
total_count += 1
# 2. Export subdomains (the Provider applies the blacklist internally)
for subdomain in provider.iter_subdomains():
f.write(f"{subdomain}\n")
total_count += 1
if total_count % 1000 == 0:
logger.info("已导出 %d 个主机...", total_count)
logger.info("✓ 主机列表导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))
return {
'success': result['success'],
'output_file': result['output_file'],
'total_count': result['total_count'],
'target_type': result['target_type']
'success': True,
'output_file': str(output_path),
'total_count': total_count,
}
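The TargetProvider interface itself is not shown in this diff; judging from the calls made by the export tasks, it would look roughly like the following sketch (method names are taken from the tasks, the Protocol shape is an assumption):
from typing import Iterator, Protocol

class TargetProvider(Protocol):
    """Assumed shape of the data-source abstraction the export tasks program against."""
    def iter_target_hosts(self) -> Iterator[str]: ...   # target hosts; CIDR expanded, blacklist applied
    def iter_subdomains(self) -> Iterator[str]: ...     # subdomains; blacklist applied by the provider
    def iter_websites(self) -> Iterator[str]: ...       # WebSite URLs
    def iter_host_port_urls(self) -> Iterator[str]: ... # URLs built from HostPortMapping rows
    def iter_endpoints(self) -> Iterator[str]: ...      # Endpoint URLs
    def iter_default_urls(self) -> Iterator[str]: ...   # http(s)://target_name fallback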

View File

@@ -26,7 +26,7 @@ import subprocess
import time
from asyncio import CancelledError
from pathlib import Path
from prefect import task
from typing import Generator, List, Optional, TYPE_CHECKING
from django.db import IntegrityError, OperationalError, DatabaseError
from psycopg2 import InterfaceError
@@ -582,11 +582,11 @@ def _cleanup_resources(data_generator) -> None:
)
@task(
name='run_and_stream_save_ports',
retries=0,
log_prints=True
)
def run_and_stream_save_ports_task(
cmd: str,
tool_name: str,

View File

@@ -0,0 +1,12 @@
"""
Screenshot task module
Contains all screenshot-related tasks:
- capture_screenshots_task: batch screenshot task
"""
from .capture_screenshots_task import capture_screenshots_task
__all__ = [
'capture_screenshots_task',
]

View File

@@ -0,0 +1,194 @@
"""
Batch screenshot task
Uses Playwright to capture website screenshots in batches, compresses them, and saves them to the database
"""
import asyncio
import logging
import time
logger = logging.getLogger(__name__)
def _run_async(coro):
"""
Run an async coroutine from a synchronous context
Args:
coro: the coroutine to run
Returns:
the coroutine's result
"""
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(coro)
def _save_screenshot_with_retry(
screenshot_service,
scan_id: int,
url: str,
webp_data: bytes,
status_code: int | None = None,
max_retries: int = 3
) -> bool:
"""
Save a screenshot to the database (with retries)
Args:
screenshot_service: ScreenshotService instance
scan_id: scan ID
url: URL
webp_data: WebP image data
status_code: HTTP response status code
max_retries: maximum number of retries
Returns:
whether the save succeeded
"""
for attempt in range(max_retries):
try:
if screenshot_service.save_screenshot_snapshot(scan_id, url, webp_data, status_code):
return True
# save returned False; wait and retry
if attempt < max_retries - 1:
wait_time = 2 ** attempt  # exponential backoff: 1s, 2s, 4s
logger.warning(
"保存截图失败(第 %d 次尝试),%d秒后重试: %s",
attempt + 1, wait_time, url
)
time.sleep(wait_time)
except Exception as e:
if attempt < max_retries - 1:
wait_time = 2 ** attempt
logger.warning(
"保存截图异常(第 %d 次尝试),%d秒后重试: %s, 错误: %s",
attempt + 1, wait_time, url, str(e)[:100]
)
time.sleep(wait_time)
else:
logger.error("保存截图失败(已重试 %d 次): %s", max_retries, url)
return False
async def _capture_and_save_screenshots(
urls: list[str],
scan_id: int,
concurrency: int
) -> dict:
"""
Asynchronously capture screenshots in batches and save them
Args:
urls: list of URLs
scan_id: scan ID
concurrency: concurrency level
Returns:
statistics dict
"""
from asgiref.sync import sync_to_async
from apps.asset.services.playwright_screenshot_service import PlaywrightScreenshotService
from apps.asset.services.screenshot_service import ScreenshotService
# 初始化服务
playwright_service = PlaywrightScreenshotService(concurrency=concurrency)
screenshot_service = ScreenshotService()
# Wrap the synchronous save function as async
async_save_with_retry = sync_to_async(_save_screenshot_with_retry, thread_sensitive=True)
# 统计
total = len(urls)
successful = 0
failed = 0
logger.info("开始批量截图 - URL数: %d, 并发数: %d", total, concurrency)
# 批量截图
async for url, screenshot_bytes, status_code in playwright_service.capture_batch(urls):
if screenshot_bytes is None:
failed += 1
continue
# 压缩为 WebP
webp_data = screenshot_service.compress_from_bytes(screenshot_bytes)
if webp_data is None:
logger.warning("压缩截图失败: %s", url)
failed += 1
continue
# Save to the database (with retries, via sync_to_async)
if await async_save_with_retry(screenshot_service, scan_id, url, webp_data, status_code):
successful += 1
if successful % 10 == 0:
logger.info("截图进度: %d/%d 成功", successful, total)
else:
failed += 1
return {
'total': total,
'successful': successful,
'failed': failed
}
def capture_screenshots_task(
urls: list[str],
scan_id: int,
target_id: int,
config: dict
) -> dict:
"""
Batch screenshot task
Args:
urls: list of URLs
scan_id: scan ID
target_id: target ID (for logging)
config: screenshot configuration
- concurrency: concurrency level (default 5)
Returns:
dict: {
'total': int, # total number of URLs
'successful': int, # successful screenshots
'failed': int # failures
}
"""
if not urls:
logger.info("URL 列表为空,跳过截图任务")
return {'total': 0, 'successful': 0, 'failed': 0}
concurrency = config.get('concurrency', 5)
logger.info(
"开始截图任务 - scan_id=%d, target_id=%d, URL数=%d, 并发=%d",
scan_id, target_id, len(urls), concurrency
)
try:
result = _run_async(_capture_and_save_screenshots(
urls=urls,
scan_id=scan_id,
concurrency=concurrency
))
logger.info(
"✓ 截图任务完成 - 总数: %d, 成功: %d, 失败: %d",
result['total'], result['successful'], result['failed']
)
return result
except Exception as e:
logger.error("截图任务失败: %s", e, exc_info=True)
raise RuntimeError(f"截图任务失败: {e}") from e
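A hedged usage sketch of the task above; the URL list, IDs, and config values are invented for illustration and would normally come from the surrounding scan pipeline:
# Hypothetical invocation of capture_screenshots_task:
stats = capture_screenshots_task(
    urls=["https://example.com", "https://example.org"],
    scan_id=42,
    target_id=7,
    config={"concurrency": 5},   # default concurrency per the docstring
)
print(stats)   # e.g. {'total': 2, 'successful': 2, 'failed': 0}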

View File

@@ -1,153 +1,76 @@
"""
Task that exports site URLs to a file
Queries the HostPortMapping table directly for host+port pairs, joins them into URLs, and writes them to a file
Uses TargetExportService.generate_default_urls() to handle the default-value fallback logic
Uses TargetProvider to export URLs from any data source (for httpx site probing).
Special rules:
- Port 80: generate only the HTTP URL (port number omitted)
- Port 443: generate only the HTTPS URL (port number omitted)
- Other ports: generate both HTTP and HTTPS URLs (with the port number)
Data source: HostPortMapping; falls back to default URLs when empty
"""
import logging
from pathlib import Path
from prefect import task
from apps.asset.services import HostPortMappingService
from apps.scan.services.target_export_service import create_export_service
from apps.common.services import BlacklistService
from apps.common.utils import BlacklistFilter
from apps.scan.providers import TargetProvider
logger = logging.getLogger(__name__)
def _generate_urls_from_port(host: str, port: int) -> list[str]:
"""
Generate the URL list for a given port
- Port 80: generate only the HTTP URL (port number omitted)
- Port 443: generate only the HTTPS URL (port number omitted)
- Other ports: generate both HTTP and HTTPS URLs (with the port number)
"""
if port == 80:
return [f"http://{host}"]
elif port == 443:
return [f"https://{host}"]
else:
return [f"http://{host}:{port}", f"https://{host}:{port}"]
@task(name="export_site_urls")
def export_site_urls_task(
target_id: int,
output_file: str,
batch_size: int = 1000
provider: TargetProvider,
) -> dict:
"""
Export all site URLs under the target to a file (based on the HostPortMapping table)
Data source: HostPortMapping (host + port) → Default
Special rules:
- Port 80: generate only the HTTP URL (port number omitted)
- Port 443: generate only the HTTPS URL (port number omitted)
- Other ports: generate both HTTP and HTTPS URLs (with the port number)
Fallback:
- If HostPortMapping is empty, use generate_default_urls() to generate default URLs
Export all site URLs under the target to a file
Data source: HostPortMapping; falls back to default URLs when empty
Args:
target_id: target ID
output_file: output file path (absolute path)
batch_size: batch size per iteration
provider: TargetProvider instance
Returns:
dict: {
'success': bool,
'output_file': str,
'total_urls': int,
'association_count': int, # number of host-port associations
'source': str, # data source: "host_port" | "default"
'source': str, # host_port | default
}
Raises:
ValueError: invalid arguments
IOError: file write failed
ValueError: provider not supplied
"""
logger.info("开始统计站点URL - Target ID: %d, 输出文件: %s", target_id, output_file)
# 确保输出目录存在
if provider is None:
raise ValueError("必须提供 provider 参数")
logger.info("导出 URL - Provider: %s", type(provider).__name__)
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
# Fetch the blacklist rules and create the filter
blacklist_filter = BlacklistFilter(BlacklistService().get_rules(target_id))
# Query the HostPortMapping table directly, ordered by host
service = HostPortMappingService()
associations = service.iter_host_port_by_target(
target_id=target_id,
batch_size=batch_size,
)
total_urls = 0
association_count = 0
filtered_count = 0
# Stream-write to the file (special port rules)
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for assoc in associations:
association_count += 1
host = assoc['host']
port = assoc['port']
# Validate the host first; only generate URLs if it passes
if not blacklist_filter.is_allowed(host):
filtered_count += 1
continue
# 根据端口号生成URL
for url in _generate_urls_from_port(host, port):
f.write(f"{url}\n")
total_urls += 1
if association_count % 1000 == 0:
logger.info("已处理 %d 条关联,生成 %d 个URL...", association_count, total_urls)
if filtered_count > 0:
logger.info("黑名单过滤: 过滤 %d 条关联", filtered_count)
logger.info(
"✓ 站点URL导出完成 - 关联数: %d, 总URL数: %d, 文件: %s",
association_count, total_urls, str(output_path)
)
# Determine the data source
# Fetch the data source in priority order
urls = list(provider.iter_host_port_urls())
source = "host_port"
# Data exists but was entirely filtered out; do not fall back
if association_count > 0 and total_urls == 0:
logger.info("HostPortMapping 有 %d 条数据,但全被黑名单过滤,不回退", association_count)
return {
'success': True,
'output_file': str(output_path),
'total_urls': 0,
'association_count': association_count,
'source': source,
}
# Data source is empty; fall back to default URL generation
if total_urls == 0:
logger.info("HostPortMapping 为空,使用默认 URL 生成")
export_service = create_export_service(target_id)
result = export_service.generate_default_urls(target_id, str(output_path))
total_urls = result['total_count']
if not urls:
logger.info("HostPortMapping 为空,生成默认 URL")
urls = list(provider.iter_default_urls())
source = "default"
# 写入文件
total_urls = 0
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for url in urls:
f.write(f"{url}\n")
total_urls += 1
logger.info(
"✓ URL 导出完成 - 来源: %s, 总数: %d, 文件: %s",
source, total_urls, str(output_path)
)
return {
'success': True,
'output_file': str(output_path),
'total_urls': total_urls,
'association_count': association_count,
'source': source,
}
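For reference, the port rule described above maps as follows (a small self-check against the pre-refactor helper _generate_urls_from_port shown in this hunk):
assert _generate_urls_from_port("example.com", 80) == ["http://example.com"]
assert _generate_urls_from_port("example.com", 443) == ["https://example.com"]
assert _generate_urls_from_port("example.com", 8080) == [
    "http://example.com:8080",
    "https://example.com:8080",
]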

View File

@@ -25,7 +25,7 @@ import json
import subprocess
import time
from pathlib import Path
from prefect import task
from typing import Generator, Optional, Dict, Any, TYPE_CHECKING
from django.db import IntegrityError, OperationalError, DatabaseError
from dataclasses import dataclass
@@ -341,11 +341,12 @@ def _save_batch(
)
snapshot_items.append(snapshot_dto)
except Exception as e:
logger.error("处理记录失败: %s,错误: %s", record.url, e)
continue
# ========== Step 3: Save snapshots and sync to asset tables (via the snapshot Service) ==========
# ========== Step 2: Save snapshots and sync to asset tables (via the snapshot Service) ==========
if snapshot_items:
services.snapshot.save_and_sync(snapshot_items)
@@ -658,7 +659,7 @@ def _cleanup_resources(data_generator) -> None:
logger.error("关闭生成器时出错: %s", gen_close_error)
@task(name='run_and_stream_save_websites', retries=0)
def run_and_stream_save_websites_task(
cmd: str,
tool_name: str,

View File

@@ -20,63 +20,40 @@ Note:
"""
import logging
import uuid
import subprocess
from pathlib import Path
import uuid
from datetime import datetime
from prefect import task
from pathlib import Path
from typing import List
logger = logging.getLogger(__name__)
# Note: implemented with plain system commands; no Python buffer configuration needed
# Tool output (amass/subfinder) is already lowercase and normalized
@task(
name='merge_and_deduplicate',
retries=1,
log_prints=True
)
def merge_and_validate_task(
result_files: List[str],
result_dir: str
) -> str:
"""
Merge scan results and deduplicate (high-performance streaming)
Steps:
1. Use LC_ALL=C sort -u on multiple files directly
2. Sorting and deduplication finish in one pass
3. Return the path of the deduplicated file
Command: LC_ALL=C sort -u file1 file2 file3 -o output
Note: tool output is already normalized (lowercase, no blank lines), so no extra processing is needed
Args:
result_files: list of result file paths
result_dir: result directory
Returns:
str: path of the deduplicated domain file
Raises:
RuntimeError: processing failed
Performance:
- Plain system commands (C implementations), a single minimal process
- LC_ALL=C: byte-order comparison
- sort -u: handles multiple files directly (no pipe overhead)
Design:
- One minimal command, no redundant processing
- Direct single-process execution (no pipe/redirection overhead)
- Memory is used only in the sort phase (external sort, will not OOM)
"""
logger.info("开始合并并去重 %d 个结果文件(系统命令优化)", len(result_files))
result_path = Path(result_dir)
# 验证文件存在性
def _count_file_lines(file_path: str) -> int:
"""使用 wc -l 统计文件行数,失败时返回 0"""
try:
result = subprocess.run(
["wc", "-l", file_path],
check=True,
capture_output=True,
text=True,
)
return int(result.stdout.strip().split()[0])
except (subprocess.CalledProcessError, ValueError, IndexError):
return 0
def _calculate_timeout(total_lines: int) -> int:
"""根据总行数计算超时时间(每行约 0.1 秒,最少 600 秒)"""
if total_lines <= 0:
return 3600
return max(600, int(total_lines * 0.1))
def _validate_input_files(result_files: List[str]) -> List[str]:
"""验证输入文件存在性,返回有效文件列表"""
valid_files = []
for file_path_str in result_files:
file_path = Path(file_path_str)
@@ -84,112 +61,68 @@ def merge_and_validate_task(
valid_files.append(str(file_path))
else:
logger.warning("结果文件不存在: %s", file_path)
return valid_files
def merge_and_validate_task(result_files: List[str], result_dir: str) -> str:
"""
Merge scan results and deduplicate (high-performance streaming)
Uses LC_ALL=C sort -u on multiple files directly; sorting and deduplication finish in one pass.
Args:
result_files: list of result file paths
result_dir: result directory
Returns:
path of the deduplicated domain file
Raises:
RuntimeError: processing failed
"""
logger.info("开始合并并去重 %d 个结果文件", len(result_files))
valid_files = _validate_input_files(result_files)
if not valid_files:
raise RuntimeError("所有结果文件都不存在")
# 生成输出文件路径
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
short_uuid = uuid.uuid4().hex[:4]
merged_file = result_path / f"merged_{timestamp}_{short_uuid}.txt"
merged_file = Path(result_dir) / f"merged_{timestamp}_{short_uuid}.txt"
# 计算超时时间
total_lines = sum(_count_file_lines(f) for f in valid_files)
timeout = _calculate_timeout(total_lines)
logger.info("合并去重: 输入总行数=%d, timeout=%d", total_lines, timeout)
# 执行合并去重命令
cmd = f"LC_ALL=C sort -u {' '.join(valid_files)} -o {merged_file}"
logger.debug("执行命令: %s", cmd)
try:
# ==================== One pass via system command: sort and deduplicate ====================
# LC_ALL=C: byte-order comparison (20-30% faster than locale-aware sorting)
# sort -u: handles multiple files directly, sorting and deduplicating
# -o: safe output (more reliable than shell redirection)
cmd = f"LC_ALL=C sort -u {' '.join(valid_files)} -o {merged_file}"
logger.debug("执行命令: %s", cmd)
subprocess.run(cmd, shell=True, check=True, timeout=timeout)
except subprocess.TimeoutExpired as exc:
raise RuntimeError("合并去重超时,请检查数据量或系统资源") from exc
except subprocess.CalledProcessError as exc:
raise RuntimeError(f"系统命令执行失败: {exc.stderr or exc}") from exc
# Compute the timeout dynamically from the total input line count
total_lines = 0
for file_path in valid_files:
try:
line_count_proc = subprocess.run(
["wc", "-l", file_path],
check=True,
capture_output=True,
text=True,
)
total_lines += int(line_count_proc.stdout.strip().split()[0])
except (subprocess.CalledProcessError, ValueError, IndexError):
continue
# 验证输出文件
if not merged_file.exists():
raise RuntimeError("合并文件未被创建")
timeout = 3600
if total_lines > 0:
# 按行数线性计算:每行约 0.1 秒
base_per_line = 0.1
est = int(total_lines * base_per_line)
timeout = max(600, est)
unique_count = _count_file_lines(str(merged_file))
if unique_count == 0:
# 降级为 Python 统计
with open(merged_file, 'r', encoding='utf-8') as f:
unique_count = sum(1 for _ in f)
logger.info(
"Subdomain 合并去重 timeout 自动计算: 输入总行数=%d, timeout=%d",
total_lines,
timeout,
)
if unique_count == 0:
logger.warning("未找到任何有效域名,返回空文件")
# Do not raise; return the empty file so downstream steps proceed normally
result = subprocess.run(
cmd,
shell=True,
check=True,
timeout=timeout
)
logger.debug("✓ 合并去重完成")
# ==================== 统计结果 ====================
if not merged_file.exists():
raise RuntimeError("合并文件未被创建")
# Count lines (system command for better performance on large files)
try:
line_count_proc = subprocess.run(
["wc", "-l", str(merged_file)],
check=True,
capture_output=True,
text=True
)
unique_count = int(line_count_proc.stdout.strip().split()[0])
except (subprocess.CalledProcessError, ValueError, IndexError) as e:
logger.warning(
"wc -l 统计失败(文件: %s),降级为 Python 逐行统计 - 错误: %s",
merged_file, e
)
unique_count = 0
with open(merged_file, 'r', encoding='utf-8') as file_obj:
for _ in file_obj:
unique_count += 1
if unique_count == 0:
raise RuntimeError("未找到任何有效域名")
file_size = merged_file.stat().st_size
logger.info(
"✓ 合并去重完成 - 去重后: %d 个域名, 文件大小: %.2f KB",
unique_count,
file_size / 1024
)
return str(merged_file)
except subprocess.TimeoutExpired:
error_msg = "合并去重超时(>60分钟请检查数据量或系统资源"
logger.warning(error_msg) # 超时是可预期的
raise RuntimeError(error_msg)
except subprocess.CalledProcessError as e:
error_msg = f"系统命令执行失败: {e.stderr if e.stderr else str(e)}"
logger.warning(error_msg) # 超时是可预期的
raise RuntimeError(error_msg) from e
except IOError as e:
error_msg = f"文件读写失败: {e}"
logger.warning(error_msg) # 超时是可预期的
raise RuntimeError(error_msg) from e
except Exception as e:
error_msg = f"合并去重失败: {e}"
logger.error(error_msg, exc_info=True)
raise
file_size_kb = merged_file.stat().st_size / 1024
logger.info("✓ 合并去重完成 - 去重后: %d 个域名, 文件大小: %.2f KB", unique_count, file_size_kb)
return str(merged_file)
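A quick worked check of the timeout rule above (roughly 0.1 s per input line, a 600 s floor, and a 3600 s default when the line count is unknown), followed by the command it guards:
assert _calculate_timeout(0) == 3600        # unknown/empty input -> 1 hour default
assert _calculate_timeout(1_000) == 600     # 100 s estimate, raised to the 600 s floor
assert _calculate_timeout(50_000) == 5000   # 0.1 s per line once above the floor

# The guarded command, built the same way as in the task (file names illustrative):
# LC_ALL=C sort -u amass.txt subfinder.txt -o merged_20260111_120000_ab12.txt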

View File

@@ -1,22 +1,22 @@
"""
Run-scan-tool task
Responsible for running a single subdomain scanning tool (amass, subfinder, etc.)
Responsible for running a single subdomain scanning tool (subfinder, sublist3r, etc.)
"""
import logging
from pathlib import Path
from prefect import task
from apps.scan.utils import execute_and_wait
logger = logging.getLogger(__name__)
@task(
name='run_subdomain_discovery',
retries=0,  # explicitly disable retries
log_prints=True
)
def run_subdomain_discovery_task(
tool: str,
command: str,
@@ -58,7 +58,7 @@ def run_subdomain_discovery_task(
timeout=timeout,
log_file=log_file # 明确指定日志文件路径
)
# 验证输出文件是否生成
if not output_file_path.exists():
logger.warning(

View File

@@ -7,7 +7,7 @@
import logging
import time
from pathlib import Path
from prefect import task
from typing import List
from dataclasses import dataclass
from django.db import IntegrityError, OperationalError, DatabaseError
@@ -35,11 +35,11 @@ class ServiceSet:
)
@task(
name='save_domains',
retries=0,
log_prints=True
)
def save_domains_task(
domains_file: str,
scan_id: int,

View File

@@ -0,0 +1,240 @@
"""
Task backward-compatibility tests
Property 8: Task Backward Compatibility
*For any* task invocation, when only the target_id parameter is provided, the task should create a DatabaseTargetProvider
and use it for data access, behaving exactly as it did before the refactor.
**Validates: Requirements 6.1, 6.2, 6.4, 6.5**
"""
import tempfile
import pytest
from pathlib import Path
from unittest.mock import patch, MagicMock
from hypothesis import given, strategies as st, settings
from apps.scan.tasks.port_scan.export_hosts_task import export_hosts_task
from apps.scan.tasks.site_scan.export_site_urls_task import export_site_urls_task
from apps.scan.providers import ListTargetProvider
# Strategy for generating valid domain names
def valid_domain_strategy():
"""生成有效的域名"""
label = st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=2,
max_size=10
)
return st.builds(
lambda a, b, c: f"{a}.{b}.{c}",
label, label, st.sampled_from(['com', 'net', 'org', 'io'])
)
class TestExportHostsTaskBackwardCompatibility:
"""export_hosts_task 向后兼容性测试"""
@given(
target_id=st.integers(min_value=1, max_value=1000),
hosts=st.lists(valid_domain_strategy(), min_size=1, max_size=10)
)
@settings(max_examples=50, deadline=None)
def test_property_8_legacy_mode_creates_database_provider(self, target_id, hosts):
"""
Property 8: Task Backward Compatibility (export_hosts_task)
Feature: scan-target-provider, Property 8: Task Backward Compatibility
**Validates: Requirements 6.1, 6.2, 6.4, 6.5**
For any target_id, when calling export_hosts_task with only target_id,
it should create a DatabaseTargetProvider and use it for data access.
"""
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
output_file = f.name
try:
# Mock Target 和 SubdomainService
mock_target = MagicMock()
mock_target.type = 'domain'
mock_target.name = hosts[0]
with patch('apps.scan.tasks.port_scan.export_hosts_task.DatabaseTargetProvider') as mock_provider_class, \
patch('apps.targets.services.TargetService') as mock_target_service:
# 创建 mock provider 实例
mock_provider = MagicMock()
mock_provider.iter_hosts.return_value = iter(hosts)
mock_provider.get_blacklist_filter.return_value = None
mock_provider_class.return_value = mock_provider
# Mock TargetService
mock_target_service.return_value.get_target.return_value = mock_target
# Invoke the task (legacy mode: target_id only)
result = export_hosts_task(
output_file=output_file,
target_id=target_id
)
# 验证:应该创建了 DatabaseTargetProvider
mock_provider_class.assert_called_once_with(target_id=target_id)
# 验证:返回值包含必需字段
assert result['success'] is True
assert result['output_file'] == output_file
assert result['total_count'] == len(hosts)
assert 'target_type' in result # 传统模式应该返回 target_type
# 验证:文件内容正确
with open(output_file, 'r') as f:
lines = [line.strip() for line in f.readlines()]
assert lines == hosts
finally:
Path(output_file).unlink(missing_ok=True)
def test_legacy_mode_with_provider_parameter(self):
"""测试当同时提供 target_id 和 provider 时provider 优先"""
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
output_file = f.name
try:
hosts = ['example.com', 'test.com']
provider = ListTargetProvider(targets=hosts)
# Invoke the task (both target_id and provider supplied)
result = export_hosts_task(
output_file=output_file,
target_id=123, # 应该被忽略
provider=provider
)
# 验证:使用了 provider
assert result['success'] is True
assert result['total_count'] == len(hosts)
assert 'target_type' not in result # Provider 模式不返回 target_type
# 验证:文件内容正确
with open(output_file, 'r') as f:
lines = [line.strip() for line in f.readlines()]
assert lines == hosts
finally:
Path(output_file).unlink(missing_ok=True)
def test_error_when_no_parameters(self):
"""测试当 target_id 和 provider 都未提供时抛出错误"""
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
output_file = f.name
try:
with pytest.raises(ValueError, match="必须提供 target_id 或 provider 参数之一"):
export_hosts_task(output_file=output_file)
finally:
Path(output_file).unlink(missing_ok=True)
class TestExportSiteUrlsTaskBackwardCompatibility:
"""export_site_urls_task 向后兼容性测试"""
def test_property_8_legacy_mode_uses_traditional_logic(self):
"""
Property 8: Task Backward Compatibility (export_site_urls_task)
Feature: scan-target-provider, Property 8: Task Backward Compatibility
**Validates: Requirements 6.1, 6.2, 6.4, 6.5**
When calling export_site_urls_task with only target_id,
it should use the traditional logic (_export_site_urls_legacy).
"""
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
output_file = f.name
try:
target_id = 123
# Mock HostPortMappingService
mock_associations = [
{'host': 'example.com', 'port': 80},
{'host': 'test.com', 'port': 443},
]
with patch('apps.scan.tasks.site_scan.export_site_urls_task.HostPortMappingService') as mock_service_class, \
patch('apps.scan.tasks.site_scan.export_site_urls_task.BlacklistService') as mock_blacklist_service:
# Mock HostPortMappingService
mock_service = MagicMock()
mock_service.iter_host_port_by_target.return_value = iter(mock_associations)
mock_service_class.return_value = mock_service
# Mock BlacklistService
mock_blacklist = MagicMock()
mock_blacklist.get_rules.return_value = []
mock_blacklist_service.return_value = mock_blacklist
# Invoke the task (legacy mode: target_id only)
result = export_site_urls_task(
output_file=output_file,
target_id=target_id
)
# 验证:返回值包含传统模式的字段
assert result['success'] is True
assert result['output_file'] == output_file
assert result['total_urls'] == 2  # port 80 yields 1 URL, port 443 yields 1 URL
assert 'association_count' in result # 传统模式应该返回 association_count
assert result['association_count'] == 2
assert result['source'] == 'host_port'
# 验证:文件内容正确
with open(output_file, 'r') as f:
lines = [line.strip() for line in f.readlines()]
assert 'http://example.com' in lines
assert 'https://test.com' in lines
finally:
Path(output_file).unlink(missing_ok=True)
def test_provider_mode_uses_provider_logic(self):
"""测试当提供 provider 时使用 Provider 模式"""
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
output_file = f.name
try:
urls = ['https://example.com', 'https://test.com']
provider = ListTargetProvider(targets=urls)
# Invoke the task (provider mode)
result = export_site_urls_task(
output_file=output_file,
provider=provider
)
# 验证:使用了 provider
assert result['success'] is True
assert result['total_urls'] == len(urls)
assert 'association_count' not in result # Provider 模式不返回 association_count
assert result['source'] == 'provider'
# 验证:文件内容正确
with open(output_file, 'r') as f:
lines = [line.strip() for line in f.readlines()]
assert lines == urls
finally:
Path(output_file).unlink(missing_ok=True)
def test_error_when_no_parameters(self):
"""测试当 target_id 和 provider 都未提供时抛出错误"""
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
output_file = f.name
try:
with pytest.raises(ValueError, match="必须提供 target_id 或 provider 参数之一"):
export_site_urls_task(output_file=output_file)
finally:
Path(output_file).unlink(missing_ok=True)
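Taken together, the tests above pin down two calling conventions. A side-by-side sketch (ListTargetProvider usage as in the tests; the legacy path relies on the task constructing a DatabaseTargetProvider internally; paths and IDs are illustrative):
# Provider mode: the caller supplies the data source explicitly.
provider = ListTargetProvider(targets=["example.com", "test.com"])
result = export_hosts_task(output_file="/tmp/hosts.txt", provider=provider)

# Legacy mode: only target_id is given; the task builds a DatabaseTargetProvider itself.
result = export_hosts_task(output_file="/tmp/hosts.txt", target_id=123)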

View File

@@ -11,7 +11,7 @@ import logging
import subprocess
from pathlib import Path
from datetime import datetime
from prefect import task
from typing import Optional
from apps.scan.utils import execute_and_wait
@@ -19,11 +19,11 @@ from apps.scan.utils import execute_and_wait
logger = logging.getLogger(__name__)
@task(
name='clean_urls_with_uro',
retries=1,
log_prints=True
)
def clean_urls_task(
input_file: str,
output_dir: str,

View File

@@ -1,69 +1,79 @@
"""
Export site URL list task
Uses the export_urls_with_fallback use-case function to handle the fallback-chain logic
Data source: WebSite.url → Default (for crawlers such as katana)
Uses TargetProvider to export URLs from any data source (for crawlers such as katana).
Data source: WebSite; falls back to default URLs when empty
"""
import logging
from prefect import task
from pathlib import Path
from apps.scan.services.target_export_service import (
export_urls_with_fallback,
DataSource,
)
from apps.scan.providers import TargetProvider
logger = logging.getLogger(__name__)
@task(
name='export_sites_for_url_fetch',
retries=1,
log_prints=True
)
def export_sites_task(
output_file: str,
target_id: int,
scan_id: int,
batch_size: int = 1000
provider: TargetProvider,
) -> dict:
"""
Export the site URL list to a file (for crawlers such as katana)
Data source priority (fallback chain):
1. WebSite table - site-level URLs
2. Default generation - build http(s)://target_name based on the Target type
Data source: WebSite; falls back to default URLs when empty
Args:
output_file: output file path
target_id: target ID
scan_id: scan ID (retained for backward compatibility)
batch_size: batch size (memory optimization)
provider: TargetProvider instance
Returns:
dict: {
'output_file': str, # output file path
'asset_count': int, # asset count
'output_file': str,
'asset_count': int,
'source': str, # website | default
}
Raises:
ValueError: invalid arguments
RuntimeError: execution failed
ValueError: provider not supplied
"""
result = export_urls_with_fallback(
target_id=target_id,
output_file=output_file,
sources=[DataSource.WEBSITE, DataSource.DEFAULT],
batch_size=batch_size,
)
if provider is None:
raise ValueError("必须提供 provider 参数")
logger.info("导出 URL - Provider: %s", type(provider).__name__)
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
# 按优先级获取数据源
urls = list(provider.iter_websites())
source = "website"
if not urls:
logger.info("WebSite 为空,生成默认 URL")
urls = list(provider.iter_default_urls())
source = "default"
# 写入文件
total_count = 0
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for url in urls:
f.write(f"{url}\n")
total_count += 1
logger.info(
"站点 URL 导出完成 - source=%s, count=%d",
result['source'], result['total_count']
" URL 导出完成 - 来源: %s, 总数: %d, 文件: %s",
source, total_count, str(output_path)
)
# Keep the return value format unchanged (backward compatible)
return {
'output_file': result['output_file'],
'asset_count': result['total_count'],
'output_file': str(output_path),
'asset_count': total_count,
'source': source,
}

View File

@@ -10,17 +10,17 @@ import uuid
import subprocess
from pathlib import Path
from datetime import datetime
from prefect import task
from typing import List
logger = logging.getLogger(__name__)
@task(
name='merge_and_deduplicate_urls',
retries=1,
log_prints=True
)
def merge_and_deduplicate_urls_task(
result_files: List[str],
result_dir: str

View File

@@ -22,7 +22,7 @@ import json
import subprocess
import time
from pathlib import Path
from prefect import task
from typing import Generator, Optional, Dict, Any
from django.db import IntegrityError, OperationalError, DatabaseError
from psycopg2 import InterfaceError
@@ -582,7 +582,7 @@ def _process_records_in_batches(
}
@task(name="run_and_stream_save_urls", retries=0)
def run_and_stream_save_urls_task(
cmd: str,
tool_name: str,

View File

@@ -10,17 +10,17 @@
import logging
from pathlib import Path
from prefect import task
from apps.scan.utils import execute_and_wait
logger = logging.getLogger(__name__)
@task(
name='run_url_fetcher',
retries=0,  # no retries; the tool handles them itself
log_prints=True
)
def run_url_fetcher_task(
tool_name: str,
command: str,

View File

@@ -7,7 +7,7 @@
import logging
from pathlib import Path
from prefect import task
from typing import List, Optional
from urllib.parse import urlparse
from dataclasses import dataclass
@@ -70,11 +70,11 @@ def _parse_url(url: str) -> Optional[ParsedURL]:
return None
@task(
name='save_urls',
retries=1,
log_prints=True
)
def save_urls_task(
urls_file: str,
scan_id: int,

View File

@@ -2,18 +2,21 @@
Includes:
- export_endpoints_task: export endpoint URLs to a file
- export_websites_task: export website URLs to a file
- run_vuln_tool_task: run a vulnerability scanning tool (non-streaming)
- run_and_stream_save_dalfox_vulns_task: run Dalfox in streaming mode and save vulnerability results
- run_and_stream_save_nuclei_vulns_task: run Nuclei in streaming mode and save vulnerability results
"""
from .export_endpoints_task import export_endpoints_task
from .export_websites_task import export_websites_task
from .run_vuln_tool_task import run_vuln_tool_task
from .run_and_stream_save_dalfox_vulns_task import run_and_stream_save_dalfox_vulns_task
from .run_and_stream_save_nuclei_vulns_task import run_and_stream_save_nuclei_vulns_task
__all__ = [
"export_endpoints_task",
"export_websites_task",
"run_vuln_tool_task",
"run_and_stream_save_dalfox_vulns_task",
"run_and_stream_save_nuclei_vulns_task",

View File

@@ -1,67 +1,74 @@
"""导出 Endpoint URL 到文件的 Task
使用 export_urls_with_fallback 用例函数处理回退链逻辑
使用 TargetProvider 从任意数据源导出 URL。
数据源优先级(回退链):
1. Endpoint.url - 最精细的 URL含路径、参数等
2. WebSite.url - 站点级别 URL
3. 默认生成 - 根据 Target 类型生成 http(s)://target_name
数据源Endpoint为空时回退到默认 URL
"""
import logging
from typing import Dict
from pathlib import Path
from prefect import task
from apps.scan.services.target_export_service import (
export_urls_with_fallback,
DataSource,
)
from apps.scan.providers import TargetProvider
logger = logging.getLogger(__name__)
@task(name="export_endpoints")
def export_endpoints_task(
target_id: int,
output_file: str,
batch_size: int = 1000,
provider: TargetProvider,
) -> Dict[str, object]:
"""导出目标下的所有 Endpoint URL 到文本文件。
数据源优先级(回退链)
1. Endpoint 表 - 最精细的 URL含路径、参数等
2. WebSite 表 - 站点级别 URL
3. 默认生成 - 根据 Target 类型生成 http(s)://target_name
数据源优先级:Endpoint → 默认生成
Args:
target_id: 目标 ID
output_file: 输出文件路径(绝对路径)
batch_size: 每次从数据库迭代的批大小
provider: TargetProvider 实例
Returns:
dict: {
"success": bool,
"output_file": str,
"total_count": int,
"source": str, # 数据来源: "endpoint" | "website" | "default" | "none"
"source": str, # endpoint | default
}
"""
result = export_urls_with_fallback(
target_id=target_id,
output_file=output_file,
sources=[DataSource.ENDPOINT, DataSource.WEBSITE, DataSource.DEFAULT],
batch_size=batch_size,
)
if provider is None:
raise ValueError("必须提供 provider 参数")
logger.info("导出 URL - Provider: %s", type(provider).__name__)
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
# 获取数据,为空时回退到默认 URL
urls = list(provider.iter_endpoints())
source = "endpoint"
if not urls:
logger.info("Endpoint 为空,生成默认 URL")
urls = list(provider.iter_default_urls())
source = "default"
# 写入文件
total_count = 0
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for url in urls:
f.write(f"{url}\n")
total_count += 1
logger.info(
"URL 导出完成 - source=%s, count=%d, tried=%s",
result['source'], result['total_count'], result['tried_sources']
"URL 导出完成 - 来源: %s, 总数: %d, 文件: %s",
source, total_count, str(output_path)
)
return {
"success": result['success'],
"output_file": result['output_file'],
"total_count": result['total_count'],
"source": result['source'],
"success": True,
"output_file": str(output_path),
"total_count": total_count,
"source": source,
}

View File

@@ -0,0 +1,73 @@
"""导出 WebSite URL 到文件的 Task
使用 TargetProvider 从任意数据源导出 URL。
数据源WebSite为空时回退到默认 URL
"""
import logging
from pathlib import Path
from apps.scan.providers import TargetProvider
logger = logging.getLogger(__name__)
def export_websites_task(
output_file: str,
provider: TargetProvider,
) -> dict:
"""导出目标下的所有 WebSite URL 到文本文件。
数据源优先级WebSite → 默认生成
Args:
output_file: 输出文件路径(绝对路径)
provider: TargetProvider 实例
Returns:
dict: {
"success": bool,
"output_file": str,
"total_count": int,
"source": str, # website | default
}
"""
if provider is None:
raise ValueError("必须提供 provider 参数")
logger.info("导出 URL - Provider: %s", type(provider).__name__)
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
# 获取数据,为空时回退到默认 URL
urls = list(provider.iter_websites())
source = "website"
if not urls:
logger.info("WebSite 为空,生成默认 URL")
urls = list(provider.iter_default_urls())
source = "default"
# 写入文件
total_count = 0
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for url in urls:
f.write(f"{url}\n")
total_count += 1
logger.info(
"✓ URL 导出完成 - 来源: %s, 总数: %d, 文件: %s",
source, total_count, str(output_path)
)
return {
"success": True,
"output_file": str(output_path),
"total_count": total_count,
"source": source,
}

View File

@@ -25,7 +25,7 @@ from pathlib import Path
from dataclasses import dataclass
from typing import Generator, Optional, TYPE_CHECKING
from prefect import task
from django.db import IntegrityError, OperationalError, DatabaseError
from psycopg2 import InterfaceError
@@ -393,11 +393,11 @@ def _cleanup_resources(data_generator) -> None:
logger.error("关闭生成器时出错: %s", gen_close_error)
@task(
name="run_and_stream_save_dalfox_vulns",
retries=0,
log_prints=True,
)
def run_and_stream_save_dalfox_vulns_task(
cmd: str,
tool_name: str,

View File

@@ -22,7 +22,7 @@ from pathlib import Path
from dataclasses import dataclass
from typing import Generator, Optional, TYPE_CHECKING
from prefect import task
from django.db import IntegrityError, OperationalError, DatabaseError
from psycopg2 import InterfaceError
@@ -395,11 +395,11 @@ def _cleanup_resources(data_generator) -> None:
logger.error("关闭生成器时出错: %s", gen_close_error)
@task(
name="run_and_stream_save_nuclei_vulns",
retries=0,
log_prints=True,
)
def run_and_stream_save_nuclei_vulns_task(
cmd: str,
tool_name: str,

View File

@@ -10,18 +10,18 @@
import logging
from typing import Dict
from prefect import task
from apps.scan.utils import execute_and_wait
logger = logging.getLogger(__name__)
@task(
name="run_vuln_tool",
retries=0,
log_prints=True,
)
def run_vuln_tool_task(
tool_name: str,
command: str,

View File

@@ -4,7 +4,8 @@ from .views import ScanViewSet, ScheduledScanViewSet, ScanLogListView, Subfinder
from .notifications.views import notification_callback
from apps.asset.views import (
SubdomainSnapshotViewSet, WebsiteSnapshotViewSet, DirectorySnapshotViewSet,
EndpointSnapshotViewSet, HostPortMappingSnapshotViewSet, VulnerabilitySnapshotViewSet
EndpointSnapshotViewSet, HostPortMappingSnapshotViewSet, VulnerabilitySnapshotViewSet,
ScreenshotSnapshotViewSet
)
# Create the router
@@ -26,6 +27,8 @@ scan_endpoints_export = EndpointSnapshotViewSet.as_view({'get': 'export'})
scan_ip_addresses_list = HostPortMappingSnapshotViewSet.as_view({'get': 'list'})
scan_ip_addresses_export = HostPortMappingSnapshotViewSet.as_view({'get': 'export'})
scan_vulnerabilities_list = VulnerabilitySnapshotViewSet.as_view({'get': 'list'})
scan_screenshots_list = ScreenshotSnapshotViewSet.as_view({'get': 'list'})
scan_screenshots_image = ScreenshotSnapshotViewSet.as_view({'get': 'image'})
urlpatterns = [
path('', include(router.urls)),
@@ -47,5 +50,7 @@ urlpatterns = [
path('scans/<int:scan_pk>/ip-addresses/', scan_ip_addresses_list, name='scan-ip-addresses-list'),
path('scans/<int:scan_pk>/ip-addresses/export/', scan_ip_addresses_export, name='scan-ip-addresses-export'),
path('scans/<int:scan_pk>/vulnerabilities/', scan_vulnerabilities_list, name='scan-vulnerabilities-list'),
path('scans/<int:scan_pk>/screenshots/', scan_screenshots_list, name='scan-screenshots-list'),
path('scans/<int:scan_pk>/screenshots/<int:pk>/image/', scan_screenshots_image, name='scan-screenshots-image'),
]
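A hedged client-side sketch of the two new screenshot routes; the base URL, authentication, and response shape are assumptions about a typical deployment, not part of this diff:
import requests

BASE = "http://localhost:8000/api"   # assumed API prefix
scan_id = 42                         # hypothetical scan

# List screenshot snapshots for a scan.
resp = requests.get(f"{BASE}/scans/{scan_id}/screenshots/")
items = resp.json()                  # response shape assumed

# Fetch the stored WebP image for one snapshot.
if items:
    pk = items[0]["id"]              # field name assumed
    img = requests.get(f"{BASE}/scans/{scan_id}/screenshots/{pk}/image/")
    with open("screenshot.webp", "wb") as fh:
        fh.write(img.content)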

Some files were not shown because too many files have changed in this diff.