Compare commits

...

133 Commits

Author SHA1 Message Date
yyhuni
6caf707072 refactor: replace Chinese comments with English in frontend components
- Replace all Chinese inline comments with English equivalents across 24 frontend component files
- Update JSDoc comments to use English for better code documentation
- Improve code readability and maintainability for international development team
- Standardize comment style across directories, endpoints, ip-addresses, subdomains, and websites components
- Ensure consistency with previous frontend refactoring efforts
2025-12-29 23:01:16 +08:00
yyhuni
2627b1fc40 refactor: replace Chinese comments with English across frontend components
- Replace Chinese comments with English in fingerprint components (ehole, goby, wappalyzer)
- Update comments in scan engine, history, and scheduled scan modules
- Translate comments in worker deployment and configuration dialogs
- Update comments in subdomain management and target components
- Translate comments in tools configuration and command modules
- Replace Chinese comments in vulnerability components
- Improve code maintainability and consistency with English documentation standards
- Update Docker build workflow cache configuration with image-specific scopes for better cache isolation
2025-12-29 22:14:12 +08:00
yyhuni
ec6712b9b4 fix: add null coalescing to prevent undefined values in i18n translations
- Add null coalescing operator (?? "") to all i18n translation parameters across components
- Fix scheduled scan deletion dialog to handle undefined scheduled scan name
- Fix nuclei page to pass locale parameter to formatDateTime function
- Fix organization detail view unlink target dialog to handle undefined target name
- Fix organization list deletion dialog to handle undefined organization name
- Fix organization targets detail view unlink dialog to handle undefined target name
- Fix engine edit dialog to handle undefined engine name
- Fix scan history list deletion and stop dialogs to handle undefined target names
- Fix worker list deletion dialog to handle undefined worker name
- Fix all targets detail view deletion dialog to handle undefined target name
- Fix custom tools and opensource tools lists to handle undefined tool names
- Fix vulnerabilities detail view to handle undefined vulnerability names
- Prevents runtime errors when translation parameters are undefined or null
2025-12-29 21:03:47 +08:00
yyhuni
9d5e4d5408 fix(scan/engine): handle undefined engine name in delete confirmation
- Add nullish coalescing operator to prevent undefined value in delete confirmation message
- Ensure engineToDelete?.name defaults to empty string when undefined
- Improve robustness of alert dialog description rendering
2025-12-29 20:54:00 +08:00
yyhuni
c5d5b24c8f update GitHub Actions: dev versions no longer bump VERSION 2025-12-29 20:48:42 +08:00
yyhuni
671cb56b62 fix: speed up nuclei template sync; templates downloaded to the host stay in sync 2025-12-29 20:43:49 +08:00
yyhuni
51025f69a8 fix: mainland China mirror acceleration 2025-12-29 20:15:25 +08:00
yyhuni
b2403b29c4 remove update.sh 2025-12-29 20:08:40 +08:00
yyhuni
18ef01a47b fix: CN mirror acceleration 2025-12-29 20:03:14 +08:00
yyhuni
0bf8108fb3 fix: registry mirror acceleration 2025-12-29 19:51:33 +08:00
yyhuni
837ad19131 fix: registry mirror acceleration issue 2025-12-29 19:48:48 +08:00
yyhuni
d7de9a7129 fix: registry mirror acceleration issue 2025-12-29 19:39:59 +08:00
yyhuni
22b4e51b42 feat(xget): add Git URL acceleration support via Xget proxy
- Add xget_proxy utility module to convert Git repository URLs to Xget proxy format
- Support domain mapping for GitHub, GitLab, Gitea, and Codeberg repositories
- Integrate Xget proxy into Nuclei template repository cloning process
- Add XGET_MIRROR environment variable configuration in container bootstrap
- Export XGET_MIRROR setting to worker node configuration endpoint
- Add --mirror flag to install.sh for easy Xget acceleration setup
- Add configure_docker_mirror function to install.sh for Docker registry mirror configuration
- Enable Git clone acceleration for faster template repository downloads in air-gapped or bandwidth-limited environments
2025-12-29 19:32:05 +08:00
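
The Git URL rewrite described in the commit above can be pictured with a short sketch. This is a minimal, hypothetical version of the xget_proxy conversion; the domain-prefix map and function name are illustrative assumptions, not the repository's actual implementation:

from urllib.parse import urlparse

# Hypothetical forge-to-prefix map; Xget-style proxies route each forge under its own path.
XGET_DOMAIN_MAP = {
    "github.com": "gh",
    "gitlab.com": "gl",
    "gitea.com": "gitea",
    "codeberg.org": "codeberg",
}

def to_xget_url(repo_url: str, mirror: str) -> str:
    """Rewrite a Git repo URL to go through an Xget mirror; pass through if unmapped."""
    parsed = urlparse(repo_url)
    prefix = XGET_DOMAIN_MAP.get(parsed.netloc)
    if not mirror or prefix is None:
        return repo_url  # no mirror configured or unsupported host
    return f"{mirror.rstrip('/')}/{prefix}{parsed.path}"

# to_xget_url("https://github.com/projectdiscovery/nuclei-templates.git",
#             "https://xget.example.com")
# -> "https://xget.example.com/gh/projectdiscovery/nuclei-templates.git"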
yyhuni
d03628ee45 feat(i18n): translate Chinese comments to English in scan history component
- Replace Chinese console error messages with English equivalents
- Translate all inline code comments from Chinese to English
- Update dialog and section comments for consistency
- Improve code readability and maintainability for international development team
2025-12-29 18:42:13 +08:00
yyhuni
0baabe0753 feat(i18n): internationalize frontend components with English translations
- Replace Chinese comments with English equivalents across auth, dashboard, and scan components
- Update UI text labels and descriptions from Chinese to English in bulk-add-urls-dialog
- Translate placeholder text and dialog titles in asset management components
- Update column headers and data table labels to English in organization and engine modules
- Standardize English documentation strings in auth-guard and auth-layout components
- Improve code maintainability and accessibility for international users
- Align with existing internationalization efforts across the frontend codebase
2025-12-29 18:39:25 +08:00
yyhuni
e1191d7abf internationalize frontend UI 2025-12-29 18:10:05 +08:00
yyhuni
82a2e9a0e7 internationalize frontend 2025-12-29 18:09:57 +08:00
yyhuni
1ccd1bc338 update gfPatterns 2025-12-28 20:26:32 +08:00
yyhuni
b4d42f5372 update fingerprint management search 2025-12-28 20:18:26 +08:00
yyhuni
2c66450756 unify UI 2025-12-28 20:10:46 +08:00
yyhuni
119d82dc89 update UI 2025-12-28 20:06:17 +08:00
yyhuni
fba7f7c508 update UI 2025-12-28 19:55:57 +08:00
yyhuni
99d384ce29 fix frontend column widths 2025-12-28 16:37:35 +08:00
yyhuni
07f36718ab refactor frontend 2025-12-28 16:27:01 +08:00
yyhuni
7e3f69c208 refactor frontend components 2025-12-28 12:05:47 +08:00
yyhuni
5f90473c3c fix: UI 2025-12-28 08:48:25 +08:00
yyhuni
e2a815b96a add: goby and wappalyzer fingerprints 2025-12-28 08:42:37 +08:00
yyhuni
f86a1a9d47 polish UI 2025-12-27 22:01:40 +08:00
yyhuni
d5945679aa add logging 2025-12-27 21:50:43 +08:00
yyhuni
51e2c51748 fix: directory creation for volume mounts 2025-12-27 21:44:47 +08:00
yyhuni
e2cbf98dda fix: bug where target name was stripped 2025-12-27 21:27:05 +08:00
yyhuni
cd72bdf7c3 integrate fingerprinting 2025-12-27 20:19:25 +08:00
yyhuni
35abcf7e39 add blacklist logic 2025-12-27 20:12:01 +08:00
yyhuni
09f2d343a4 new: refactor export logic and add blacklist filtering 2025-12-27 20:11:50 +08:00
yyhuni
54d1f86bde fix: installation error 2025-12-27 17:51:32 +08:00
yyhuni
a3997c9676 update yaml 2025-12-27 12:52:49 +08:00
yyhuni
c90a55f85e update load-balancing logic 2025-12-27 12:49:14 +08:00
yyhuni
2eab88b452 chore(install): Add banner display and update confirmation
- Add show_banner() function to display XingRin ASCII art logo
- Call show_banner() before header in install.sh initialization
- Add experimental feature warning in update.sh with user confirmation
- Prompt user to confirm before proceeding with update operation
- Suggest full reinstall via uninstall.sh and install.sh as alternative
- Improve user experience with visual feedback and safety checks
2025-12-27 12:41:04 +08:00
yyhuni
1baf0eb5e1 fix: fingerprint scan command 2025-12-27 12:29:50 +08:00
yyhuni
b61e73f7be fix: JSON output 2025-12-27 12:14:35 +08:00
yyhuni
e896734dfc feat(scan-engine): Add fingerprint detection feature flag
- Add fingerprint_detect feature flag to engine configuration parser
- Enable fingerprint detection capability in scan engine features
- Integrate fingerprint detection into existing feature detection logic
2025-12-27 11:59:51 +08:00
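
As a rough illustration of the feature-flag parse this commit describes, flags like this are typically read from the engine's YAML config. The schema below (a top-level features map) is an assumption, not the project's actual layout:

import yaml  # PyYAML

def parse_engine_features(config_text: str) -> dict:
    """Extract boolean feature flags from an engine YAML config (sketch)."""
    config = yaml.safe_load(config_text) or {}
    features = config.get("features", {})
    return {
        # hypothetical flag names; fingerprint_detect matches the commit above
        "fingerprint_detect": bool(features.get("fingerprint_detect", False)),
        "vuln_scan": bool(features.get("vuln_scan", False)),
    }

flags = parse_engine_features("features:\n  fingerprint_detect: true\n")
assert flags["fingerprint_detect"] is True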
yyhuni
cd83f52f35 add fingerprint detection 2025-12-27 11:39:26 +08:00
yyhuni
3e29554c36 new: fingerprint detection 2025-12-27 11:39:19 +08:00
yyhuni
18e02b536e add: fingerprint detection 2025-12-27 10:06:23 +08:00
yyhuni
4c1c6f70ab update fingerprints 2025-12-26 21:50:38 +08:00
yyhuni
a72e7675f5 update UI 2025-12-26 21:40:56 +08:00
yyhuni
93c2163764 new: ehole fingerprint import 2025-12-26 21:34:36 +08:00
yyhuni
de72c91561 update UI 2025-12-25 18:31:09 +08:00
github-actions[bot]
3e6d060b75 chore: bump version to v1.1.14 2025-12-25 10:11:08 +00:00
yyhuni
766f045904 fix: ffuf concurrency issue 2025-12-25 18:02:25 +08:00
yyhuni
8acfe1cc33 adjust log levels 2025-12-25 17:44:31 +08:00
github-actions[bot]
7aec3eabb2 chore: bump version to v1.1.13 2025-12-25 08:29:39 +00:00
yyhuni
b1f11c36a4 fix: wordlist download port 2025-12-25 16:21:32 +08:00
yyhuni
d97fb5245a fix: prompt text 2025-12-25 16:18:46 +08:00
github-actions[bot]
ddf9a1f5a4 chore: bump version to v1.1.12 2025-12-25 08:10:57 +00:00
yyhuni
47f9f96a4b update docs 2025-12-25 16:07:30 +08:00
yyhuni
6f43e73162 readme update 2025-12-25 16:06:01 +08:00
yyhuni
9b7d496f3e update: port number to 8083 2025-12-25 16:02:55 +08:00
github-actions[bot]
6390849d52 chore: bump version to v1.1.11 2025-12-25 03:58:05 +00:00
yyhuni
7a6d2054f6 update: UI 2025-12-25 11:50:21 +08:00
yyhuni
73ebaab232 update: UI 2025-12-25 11:31:25 +08:00
github-actions[bot]
11899b29c2 chore: bump version to v1.1.10 2025-12-25 03:20:57 +00:00
github-actions[bot]
877d2a56d1 chore: bump version to v1.1.9 2025-12-25 03:13:58 +00:00
yyhuni
dc1e94f038 update: UI 2025-12-25 11:12:51 +08:00
yyhuni
9c3833d13d update: UI 2025-12-25 11:06:00 +08:00
github-actions[bot]
92f3b722ef chore: bump version to v1.1.8 2025-12-25 02:16:12 +00:00
yyhuni
9ef503c666 update: UI 2025-12-25 10:12:06 +08:00
yyhuni
c3a43e94fa fix: UI 2025-12-25 10:08:25 +08:00
github-actions[bot]
d6d94355fb chore: bump version to v1.1.7 2025-12-25 02:02:27 +00:00
yyhuni
bc638eabf4 update: UI 2025-12-25 10:02:13 +08:00
yyhuni
5acaada7ab new: multi-field search support 2025-12-25 09:54:50 +08:00
github-actions[bot]
aaad3f29cf chore: bump version to v1.1.6 2025-12-24 12:19:12 +00:00
yyhuni
f13eb2d9b2 update: UI style 2025-12-24 20:10:12 +08:00
yyhuni
f1b3b60382 new: EVA theme 2025-12-24 19:57:26 +08:00
yyhuni
e249056289 Update README.md 2025-12-24 19:14:22 +08:00
yyhuni
dba195b83a update readme 2025-12-24 17:28:08 +08:00
github-actions[bot]
9b494e6c67 chore: bump version to v1.1.5 2025-12-24 09:23:21 +00:00
yyhuni
2841157747 polish: font rendering 2025-12-24 17:14:45 +08:00
yyhuni
f6c1fef1a6 fix: deletion issue on dashboard page 2025-12-24 17:10:48 +08:00
yyhuni
6ec0adf9dd polish: log output 2025-12-24 16:39:13 +08:00
yyhuni
22c6661567 update: UI 2025-12-24 16:25:41 +08:00
github-actions[bot]
d9ed004e35 chore: bump version to v1.1.4 2025-12-24 08:23:12 +00:00
yyhuni
a0d9d1f29d new: bulk-add assets 2025-12-24 16:15:33 +08:00
yyhuni
8aa9ed2a97 new: bulk-add assets from the target detail page 2025-12-24 16:15:22 +08:00
yyhuni
8baf29d1c3 new: subdomain add feature 2025-12-24 11:27:48 +08:00
yyhuni
248e48353a update: database field to created_at 2025-12-24 10:35:55 +08:00
yyhuni
0d210be50b update: subdomain field discovered_at to created_at 2025-12-24 10:19:01 +08:00
github-actions[bot]
f7c0d0b215 chore: bump version to v1.1.3 2025-12-24 02:11:23 +00:00
github-actions[bot]
d83428f27b chore: bump version to v1.1.2 2025-12-24 02:08:28 +00:00
yyhuni
45a09b8173 polish: improve database connection stability 2025-12-24 10:03:24 +08:00
yyhuni
11dfdee6fd update UI 2025-12-24 09:57:39 +08:00
yyhuni
e53a884d13 update: UI 2025-12-24 09:54:48 +08:00
yyhuni
3b318c89e3 fix: theme UI 2025-12-24 09:46:51 +08:00
github-actions[bot]
e564bc116a chore: bump version to v1.1.1 2025-12-23 12:17:52 +00:00
yyhuni
410c543066 polish: extensive UI improvements 2025-12-23 20:03:27 +08:00
github-actions[bot]
66da140801 chore: bump version to v1.1.0 2025-12-23 11:20:10 +00:00
yyhuni
e60aac3622 update input field height 2025-12-23 19:18:58 +08:00
yyhuni
14aaa71cb1 adjust: scan parameters 2025-12-23 19:09:09 +08:00
yyhuni
0309dba510 polish: compatibility 2025-12-23 19:08:12 +08:00
yyhuni
967ff8a69f adjust scan parameters 2025-12-23 19:05:01 +08:00
yyhuni
9ac23d50b6 fix: vulnerability scan issue 2025-12-23 18:59:40 +08:00
yyhuni
265525c61e fix: vulnerability scan issue 2025-12-23 18:59:27 +08:00
yyhuni
1b9d05ce62 fix: increase vulnerability scan timeout to 10 minutes 2025-12-23 16:54:26 +08:00
yyhuni
737980b30f new: page export to CSV 2025-12-23 16:34:24 +08:00
yyhuni
494ee81478 new: CSV export on the IP address page 2025-12-23 12:34:41 +08:00
yyhuni
452686b282 fix: ffuf path concatenation issue 2025-12-23 11:14:31 +08:00
yyhuni
c95c68f4e9 refactor(asset): Extract deduplication logic into reusable utility
- Create new `deduplicate_for_bulk` utility function in `apps/common/utils/dedup.py`
- Move hash utility from `apps/common/utils/hash.py` to `apps/common/utils/__init__.py`
- Update all asset repositories to use centralized deduplication before bulk operations
- Apply deduplication to directory, endpoint, host_port_mapping, subdomain, and website repositories
- Apply deduplication to all snapshot repositories for consistency
- Update vulnerability service to use new deduplication utility
- Update wordlist service and related helpers to use new utility structure
- Update organization and target repositories to use new utility
- Automatically deduplicate records by model unique constraints, keeping last occurrence
- Improve code reusability and reduce duplication across repositories
2025-12-23 11:09:17 +08:00
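
A minimal sketch of what a helper like deduplicate_for_bulk can look like, keyed off the model's unique constraint and keeping the last occurrence; the exact signature and constraint introspection here are assumptions, not the repository's code:

def deduplicate_for_bulk(items, model):
    """Drop duplicates by the model's unique constraint, keeping the last occurrence."""
    constraint = model._meta.constraints[0]  # assumes a single UniqueConstraint
    seen = {}
    for item in items:  # later items overwrite earlier ones -> "keep last"
        key = tuple(
            # DTOs carry FK ids (e.g. target_id) where the constraint names the FK (target)
            getattr(item, field) if hasattr(item, field) else getattr(item, f"{field}_id")
            for field in constraint.fields
        )
        seen[key] = item
    return list(seen.values())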
yyhuni
b02f38606d fix(scan): Add quotes to file paths in command templates
- Wrap all file path variables with single quotes to handle paths with spaces
- Update subfinder, amass, sublist3r, assetfinder commands with quoted output paths
- Quote wordlist and input file paths in subdomain bruteforce and resolve commands
- Add quotes to dnsgen pipeline input/output file paths
- Quote domains_file parameter in naabu port scan commands
- Wrap url_file and scan_tools_base paths in httpx site scan command
- Quote wordlist and url parameters in ffuf directory scan command
- Add quotes to output file paths in waymore and katana URL fetch commands
- Quote input/output file paths in uro command
- Add quotes to endpoints_file in dalfox and nuclei vulnerability scan commands
- Prevents command execution failures when file paths contain spaces or special characters
2025-12-23 10:45:37 +08:00
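
The failure mode this commit guards against is easy to demonstrate. A small sketch, using Python's shlex.quote as one robust alternative to hand-placed single quotes:

import shlex

wordlist = "/opt/word lists/common.txt"  # a path containing a space

naive = f"ffuf -w {wordlist} -u https://example.com/FUZZ"
quoted = f"ffuf -w {shlex.quote(wordlist)} -u https://example.com/FUZZ"

print(naive)   # the shell splits the path at the space -> ffuf sees two arguments
print(quoted)  # -w '/opt/word lists/common.txt' -> the path survives intact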
yyhuni
b543f3d2b7 feat(scan): add quick scan service and bulk operation support
Add QuickScanService for fast scans over minimal data
Add bulk_create_ignore_conflicts to EndpointRepository for efficient bulk endpoint creation
Add bulk_create_ignore_conflicts to WebSiteRepository for efficient bulk website creation
Add validate_url function that requires URLs to include a protocol scheme
Add is_valid_ip helper to validate IP addresses without raising exceptions
Add detect_input_type function to auto-detect input type (url/domain/ip/cidr)
Update validators module docs with URL validation notes
Improve the frontend quick-scan-dialog component with stronger input validation
Update the frontend target-validator utility with better input handling
Add a directory-scan max-workers setting to engine_config_example.yaml
Optimize directory_scan_flow to support per-tool concurrency settings
Update scan_views to support the quick scan endpoint
2025-12-23 10:15:39 +08:00
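
A minimal sketch of what helpers like is_valid_ip and detect_input_type can look like; the exact rules and return labels are assumptions based on the commit description above:

import ipaddress
from urllib.parse import urlparse

def is_valid_ip(value: str) -> bool:
    """Validate an IP address without raising exceptions."""
    try:
        ipaddress.ip_address(value)
        return True
    except ValueError:
        return False

def detect_input_type(value: str) -> str:
    """Classify user input as url / ip / cidr / domain."""
    if "://" in value and urlparse(value).scheme in ("http", "https"):
        return "url"
    if is_valid_ip(value):
        return "ip"
    try:
        ipaddress.ip_network(value, strict=False)
        return "cidr"
    except ValueError:
        return "domain"

assert detect_input_type("https://example.com/a") == "url"
assert detect_input_type("10.0.0.0/8") == "cidr"
assert detect_input_type("api.example.com") == "domain"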
yyhuni
a18fb46906 update readme 2025-12-23 08:09:06 +08:00
github-actions[bot]
bb74f61ea2 chore: bump version to v1.0.36 2025-12-22 23:50:38 +00:00
yyhuni
654f3664f8 polish: scan logic, support CIDR and IP targets 2025-12-22 23:36:09 +08:00
yyhuni
30defe08d2 fix: enforce a 60s minimum on computed timeouts 2025-12-22 21:57:51 +08:00
yyhuni
41266bd931 fix: scan paths 2025-12-22 21:31:11 +08:00
yyhuni
9eebd0a47c fix: httpx binary shadowed by Python's httpx package 2025-12-22 21:29:20 +08:00
yyhuni
e7f4d25e58 update: site_scan lazy-loading logic 2025-12-22 21:21:07 +08:00
yyhuni
56cc810783 update: UI display 2025-12-22 21:10:34 +08:00
yyhuni
efe20bbf69 fix: update lock file 2025-12-22 20:52:05 +08:00
yyhuni
d88cf19a68 fix: vuln 2025-12-22 20:50:09 +08:00
yyhuni
8e74f842f0 polish: UI 2025-12-22 20:36:44 +08:00
yyhuni
5e9773a183 polish: remove Directory's foreign key to WebSite 2025-12-22 20:30:58 +08:00
yyhuni
a952ef5b6b update: replace ignore_conflicts with upsert 2025-12-22 20:14:50 +08:00
yyhuni
815c409a9e update: use upsert so assets always reflect the latest data 2025-12-22 20:14:27 +08:00
yyhuni
7ca85b8d7d polish: ffuf directory scan parallelism
new: adding a target auto-creates its URL and website records
2025-12-22 16:11:52 +08:00
github-actions[bot]
73291e6c4c chore: bump version to v1.0.35 2025-12-22 04:28:38 +00:00
yyhuni
dcafe03ea2 UI update 2025-12-22 12:20:08 +08:00
yyhuni
0390e05397 improve: tab count display in UI 2025-12-22 12:13:22 +08:00
yyhuni
088b69b61a polish: extensive UI improvements 2025-12-22 12:06:38 +08:00
yyhuni
de34567b53 polish: UI 2025-12-22 11:14:46 +08:00
yyhuni
bf40532ce4 fix: theme switching 2025-12-22 10:06:23 +08:00
yyhuni
252759c822 update: themes 2025-12-22 10:04:27 +08:00
yyhuni
2d43204639 polish: UI 2025-12-21 21:58:11 +08:00
github-actions[bot]
7715d0cf01 chore: bump version to v1.0.34 2025-12-21 13:44:08 +00:00
430 changed files with 184046 additions and 15406 deletions

View File

@@ -106,16 +106,17 @@ jobs:
${{ steps.version.outputs.IS_RELEASE == 'true' && format('{0}/{1}:latest', env.IMAGE_PREFIX, matrix.image) || '' }}
build-args: |
IMAGE_TAG=${{ steps.version.outputs.VERSION }}
cache-from: type=gha
cache-to: type=gha,mode=max
cache-from: type=gha,scope=${{ matrix.image }}
cache-to: type=gha,mode=max,scope=${{ matrix.image }}
provenance: false
sbom: false
# After all images are built successfully, update the VERSION file
# Only official releases (without -dev, -alpha, -beta, -rc suffixes) trigger the update
update-version:
runs-on: ubuntu-latest
needs: build
if: startsWith(github.ref, 'refs/tags/v')
if: startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-')
steps:
- name: Checkout
uses: actions/checkout@v4

.gitignore
View File

@@ -96,6 +96,7 @@ backend/vendor/
.idea/
.cursor/
.claude/
.kiro/
.playwright-mcp/
*.swp
*.swo
@@ -131,3 +132,5 @@ temp/
HGETALL
KEYS
vuln_scan/input_endpoints.txt
open-in-v0

View File

@@ -25,23 +25,16 @@
---
<p align="center">
<b>🌗 Light/Dark Mode Toggle</b>
<b>🎨 Modern UI</b>
</p>
<p align="center">
<img src="docs/screenshots/light.png" alt="Light Mode" width="49%">
<img src="docs/screenshots/dark.png" alt="Dark Mode" width="49%">
</p>
<p align="center">
<b>🎨 Multiple UI Themes</b>
</p>
<p align="center">
<img src="docs/screenshots/bubblegum.png" alt="Bubblegum" width="32%">
<img src="docs/screenshots/cosmic-night.png" alt="Cosmic Night" width="32%">
<img src="docs/screenshots/quantum-rose.png" alt="Quantum Rose" width="32%">
<img src="docs/screenshots/light.png" alt="Light Mode" width="24%">
<img src="docs/screenshots/bubblegum.png" alt="Bubblegum" width="24%">
<img src="docs/screenshots/cosmic-night.png" alt="Cosmic Night" width="24%">
<img src="docs/screenshots/quantum-rose.png" alt="Quantum Rose" width="24%">
</p>
## 📚 Documentation
@@ -165,17 +158,9 @@ flowchart TB
### 📊 Visual Interface
- **Data Statistics** - asset/vulnerability dashboards
- **Real-time Notifications** - WebSocket message push
- **Dark Theme** - light/dark theme switching
---
## 🛠️ Tech Stack
- **Frontend**: Next.js + React + TailwindCSS
- **Backend**: Django + Django REST Framework
- **Database**: PostgreSQL + Redis
- **Deployment**: Docker + Nginx
## 📦 Quick Start
### Requirements
@@ -192,11 +177,19 @@ cd xingrin
# Install and start (production mode)
sudo ./install.sh
# 🇨🇳 Mirror acceleration recommended for mainland China users
sudo ./install.sh --mirror
```
> **💡 About the --mirror flag**
> - Automatically configures Docker registry mirrors (domestic mirror sources)
> - Accelerates Git repository cloning (Nuclei templates, etc.)
> - Greatly speeds up installation and avoids network timeouts
### Accessing the Service
- **Web UI**: `https://localhost`
- **Web UI**: `https://ip:8083`
### Common Commands
@@ -212,9 +205,6 @@ sudo ./restart.sh
# Uninstall
sudo ./uninstall.sh
# Update
sudo ./update.sh
```
## 🤝 Feedback & Contributions

View File

@@ -1 +1 @@
v1.0.30
v1.1.14

View File

@@ -7,7 +7,6 @@ from typing import Optional
@dataclass
class DirectoryDTO:
"""目录数据传输对象"""
website_id: int
target_id: int
url: str
status: Optional[int] = None

View File

@@ -9,7 +9,7 @@ class WebSiteDTO:
"""网站数据传输对象"""
target_id: int
url: str
host: str
host: str = ''
title: str = ''
status_code: Optional[int] = None
content_length: Optional[int] = None

View File

@@ -12,11 +12,10 @@ class DirectorySnapshotDTO:
Saves directory information discovered during a scan to the snapshot table.
Note: website_id and target_id are only used to pass data and convert to the asset DTO; they are not saved to the snapshot table.
Note: target_id is only used to pass data and convert to the asset DTO; it is not saved to the snapshot table.
A snapshot belongs only to its scan.
"""
scan_id: int
website_id: int # only used to pass data, not saved to the database
target_id: int # only used to pass data, not saved to the database
url: str
status: Optional[int] = None
@@ -36,7 +35,6 @@ class DirectorySnapshotDTO:
DirectoryDTO: asset-table DTO
"""
return DirectoryDTO(
website_id=self.website_id,
target_id=self.target_id,
url=self.url,
status=self.status,

View File

@@ -4,13 +4,6 @@ from django.contrib.postgres.fields import ArrayField
from django.core.validators import MinValueValidator, MaxValueValidator
class SoftDeleteManager(models.Manager):
"""软删除管理器:默认只返回未删除的记录"""
def get_queryset(self):
return super().get_queryset().filter(deleted_at__isnull=True)
class Subdomain(models.Model):
"""
子域名模型(纯资产表)
@@ -29,33 +22,24 @@ class Subdomain(models.Model):
help_text='所属的扫描目标(主关联字段,表示所属关系,不能为空)'
)
name = models.CharField(max_length=1000, help_text='子域名名称')
discovered_at = models.DateTimeField(auto_now_add=True, help_text='首次发现时间')
# ==================== 软删除字段 ====================
deleted_at = models.DateTimeField(null=True, blank=True, db_index=True, help_text='删除时间NULL表示未删除')
# ==================== 管理器 ====================
objects = SoftDeleteManager() # 默认管理器:只返回未删除的记录
all_objects = models.Manager() # 全量管理器:包括已删除的记录(用于硬删除)
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
class Meta:
db_table = 'subdomain'
verbose_name = '子域名'
verbose_name_plural = '子域名'
ordering = ['-discovered_at']
ordering = ['-created_at']
indexes = [
models.Index(fields=['-discovered_at']),
models.Index(fields=['-created_at']),
models.Index(fields=['name', 'target']), # 复合索引,优化 get_by_names_and_target_id 批量查询
models.Index(fields=['target']), # 优化从target_id快速查找下面的子域名
models.Index(fields=['name']), # 优化从name快速查找子域名搜索场景
models.Index(fields=['deleted_at', '-discovered_at']), # 软删除 + 时间索引
]
constraints = [
# 部分唯一约束:只对未删除记录生效
# 普通唯一约束:name + target 组合唯一
models.UniqueConstraint(
fields=['name', 'target'],
condition=models.Q(deleted_at__isnull=True),
name='unique_name_target_active'
name='unique_subdomain_name_target'
)
]
@@ -87,7 +71,7 @@ class Endpoint(models.Model):
default='',
help_text='重定向地址HTTP 3xx 响应头 Location'
)
discovered_at = models.DateTimeField(auto_now_add=True, help_text='发现时间')
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
title = models.CharField(
max_length=1000,
blank=True,
@@ -139,33 +123,25 @@ class Endpoint(models.Model):
default=list,
help_text='匹配的GF模式列表用于识别敏感端点如api, debug, config等'
)
# ==================== 软删除字段 ====================
deleted_at = models.DateTimeField(null=True, blank=True, db_index=True, help_text='删除时间NULL表示未删除')
# ==================== 管理器 ====================
objects = SoftDeleteManager() # 默认管理器:只返回未删除的记录
all_objects = models.Manager() # 全量管理器:包括已删除的记录(用于硬删除)
class Meta:
db_table = 'endpoint'
verbose_name = '端点'
verbose_name_plural = '端点'
ordering = ['-discovered_at']
ordering = ['-created_at']
indexes = [
models.Index(fields=['-discovered_at']),
models.Index(fields=['-created_at']),
models.Index(fields=['target']), # 优化从target_id快速查找下面的端点主关联字段
models.Index(fields=['url']), # URL索引优化查询性能
models.Index(fields=['host']), # host索引优化根据主机名查询
models.Index(fields=['status_code']), # 状态码索引,优化筛选
models.Index(fields=['deleted_at', '-discovered_at']), # 软删除 + 时间索引
models.Index(fields=['title']), # title索引优化智能过滤搜索
]
constraints = [
# 部分唯一约束:只对未删除记录生效
# 普通唯一约束:url + target 组合唯一
models.UniqueConstraint(
fields=['url', 'target'],
condition=models.Q(deleted_at__isnull=True),
name='unique_endpoint_url_target_active'
name='unique_endpoint_url_target'
)
]
@@ -197,7 +173,7 @@ class WebSite(models.Model):
default='',
help_text='重定向地址HTTP 3xx 响应头 Location'
)
discovered_at = models.DateTimeField(auto_now_add=True, help_text='发现时间')
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
title = models.CharField(
max_length=1000,
blank=True,
@@ -243,32 +219,25 @@ class WebSite(models.Model):
blank=True,
help_text='是否支持虚拟主机'
)
# ==================== 软删除字段 ====================
deleted_at = models.DateTimeField(null=True, blank=True, db_index=True, help_text='删除时间NULL表示未删除')
# ==================== 管理器 ====================
objects = SoftDeleteManager() # 默认管理器:只返回未删除的记录
all_objects = models.Manager() # 全量管理器:包括已删除的记录(用于硬删除)
class Meta:
db_table = 'website'
verbose_name = '站点'
verbose_name_plural = '站点'
ordering = ['-discovered_at']
ordering = ['-created_at']
indexes = [
models.Index(fields=['-discovered_at']),
models.Index(fields=['-created_at']),
models.Index(fields=['url']), # URL索引优化查询性能
models.Index(fields=['host']), # host索引优化根据主机名查询
models.Index(fields=['target']), # 优化从target_id快速查找下面的站点
models.Index(fields=['deleted_at', '-discovered_at']), # 软删除 + 时间索引
models.Index(fields=['title']), # title索引优化智能过滤搜索
models.Index(fields=['status_code']), # 状态码索引,优化智能过滤搜索
]
constraints = [
# 部分唯一约束:只对未删除记录生效
# 普通唯一约束:url + target 组合唯一
models.UniqueConstraint(
fields=['url', 'target'],
condition=models.Q(deleted_at__isnull=True),
name='unique_website_url_target_active'
name='unique_website_url_target'
)
]
@@ -282,19 +251,11 @@ class Directory(models.Model):
"""
id = models.AutoField(primary_key=True)
website = models.ForeignKey(
'Website',
on_delete=models.CASCADE,
related_name='directories',
help_text='所属的站点(主关联字段,表示所属关系,不能为空)'
)
target = models.ForeignKey(
'targets.Target', # 使用字符串引用
'targets.Target',
on_delete=models.CASCADE,
related_name='directories',
null=True,
blank=True,
help_text='所属的扫描目标(冗余字段,用于快速查询)'
help_text='所属的扫描目标'
)
url = models.CharField(
@@ -335,34 +296,24 @@ class Directory(models.Model):
help_text='请求耗时(单位:纳秒)'
)
discovered_at = models.DateTimeField(auto_now_add=True, help_text='发现时间')
# ==================== 软删除字段 ====================
deleted_at = models.DateTimeField(null=True, blank=True, db_index=True, help_text='删除时间NULL表示未删除')
# ==================== 管理器 ====================
objects = SoftDeleteManager() # 默认管理器:只返回未删除的记录
all_objects = models.Manager() # 全量管理器:包括已删除的记录(用于硬删除)
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
class Meta:
db_table = 'directory'
verbose_name = '目录'
verbose_name_plural = '目录'
ordering = ['-discovered_at']
ordering = ['-created_at']
indexes = [
models.Index(fields=['-discovered_at']),
models.Index(fields=['-created_at']),
models.Index(fields=['target']), # 优化从target_id快速查找下面的目录
models.Index(fields=['url']), # URL索引优化搜索和唯一约束
models.Index(fields=['website']), # 站点索引,优化按站点查询
models.Index(fields=['status']), # 状态码索引,优化筛选
models.Index(fields=['deleted_at', '-discovered_at']), # 软删除 + 时间索引
]
constraints = [
# 部分唯一约束:只对未删除记录生效
# 普通唯一约束:target + url 组合唯一
models.UniqueConstraint(
fields=['website', 'url'],
condition=models.Q(deleted_at__isnull=True),
name='unique_directory_url_website_active'
fields=['target', 'url'],
name='unique_directory_url_target'
),
]
@@ -410,43 +361,29 @@ class HostPortMapping(models.Model):
)
# ==================== 时间字段 ====================
discovered_at = models.DateTimeField(
created_at = models.DateTimeField(
auto_now_add=True,
help_text='发现时间'
help_text='创建时间'
)
# ==================== 软删除字段 ====================
deleted_at = models.DateTimeField(
null=True,
blank=True,
db_index=True,
help_text='删除时间NULL表示未删除'
)
# ==================== 管理器 ====================
objects = SoftDeleteManager() # 默认管理器:只返回未删除的记录
all_objects = models.Manager() # 全量管理器:包括已删除的记录(用于硬删除)
class Meta:
db_table = 'host_port_mapping'
verbose_name = '主机端口映射'
verbose_name_plural = '主机端口映射'
ordering = ['-discovered_at']
ordering = ['-created_at']
indexes = [
models.Index(fields=['target']), # 优化按目标查询
models.Index(fields=['host']), # 优化按主机名查询
models.Index(fields=['ip']), # 优化按IP查询
models.Index(fields=['port']), # 优化按端口查询
models.Index(fields=['host', 'ip']), # 优化组合查询
models.Index(fields=['-discovered_at']), # 优化时间排序
models.Index(fields=['deleted_at', '-discovered_at']), # 软删除 + 时间索引
models.Index(fields=['-created_at']), # 优化时间排序
]
constraints = [
# 复合唯一约束target + host + ip + port 组合唯一(只对未删除记录生效)
# 复合唯一约束target + host + ip + port 组合唯一
models.UniqueConstraint(
fields=['target', 'host', 'ip', 'port'],
condition=models.Q(deleted_at__isnull=True),
name='unique_target_host_ip_port_active'
name='unique_target_host_ip_port'
),
]
@@ -474,7 +411,7 @@ class Vulnerability(models.Model):
)
# ==================== 核心字段 ====================
url = models.TextField(help_text='漏洞所在的URL')
url = models.CharField(max_length=2000, help_text='漏洞所在的URL')
vuln_type = models.CharField(max_length=100, help_text='漏洞类型(如 xss, sqli')
severity = models.CharField(
max_length=20,
@@ -488,27 +425,20 @@ class Vulnerability(models.Model):
raw_output = models.JSONField(blank=True, default=dict, help_text='工具原始输出')
# ==================== 时间字段 ====================
discovered_at = models.DateTimeField(auto_now_add=True, help_text='首次发现时间')
# ==================== 软删除字段 ====================
deleted_at = models.DateTimeField(null=True, blank=True, db_index=True, help_text='删除时间NULL表示未删除')
# ==================== 管理器 ====================
objects = SoftDeleteManager()
all_objects = models.Manager()
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
class Meta:
db_table = 'vulnerability'
verbose_name = '漏洞'
verbose_name_plural = '漏洞'
ordering = ['-discovered_at']
ordering = ['-created_at']
indexes = [
models.Index(fields=['target']),
models.Index(fields=['vuln_type']),
models.Index(fields=['severity']),
models.Index(fields=['source']),
models.Index(fields=['-discovered_at']),
models.Index(fields=['deleted_at', '-discovered_at']),
models.Index(fields=['url']), # url索引优化智能过滤搜索
models.Index(fields=['-created_at']),
]
def __str__(self):

View File

@@ -15,17 +15,17 @@ class SubdomainSnapshot(models.Model):
)
name = models.CharField(max_length=1000, help_text='子域名名称')
discovered_at = models.DateTimeField(auto_now_add=True, help_text='发现时间')
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
class Meta:
db_table = 'subdomain_snapshot'
verbose_name = '子域名快照'
verbose_name_plural = '子域名快照'
ordering = ['-discovered_at']
ordering = ['-created_at']
indexes = [
models.Index(fields=['scan']),
models.Index(fields=['name']),
models.Index(fields=['-discovered_at']),
models.Index(fields=['-created_at']),
]
constraints = [
# 唯一约束:同一次扫描中,同一个子域名只能记录一次
@@ -70,18 +70,19 @@ class WebsiteSnapshot(models.Model):
)
body_preview = models.TextField(blank=True, default='', help_text='响应体预览')
vhost = models.BooleanField(null=True, blank=True, help_text='虚拟主机标志')
discovered_at = models.DateTimeField(auto_now_add=True, help_text='发现时间')
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
class Meta:
db_table = 'website_snapshot'
verbose_name = '网站快照'
verbose_name_plural = '网站快照'
ordering = ['-discovered_at']
ordering = ['-created_at']
indexes = [
models.Index(fields=['scan']),
models.Index(fields=['url']),
models.Index(fields=['host']), # host索引优化根据主机名查询
models.Index(fields=['-discovered_at']),
models.Index(fields=['title']), # title索引优化标题搜索
models.Index(fields=['-created_at']),
]
constraints = [
# 唯一约束同一次扫描中同一个URL只能记录一次
@@ -118,18 +119,19 @@ class DirectorySnapshot(models.Model):
lines = models.IntegerField(null=True, blank=True, help_text='响应体行数(按换行符分割)')
content_type = models.CharField(max_length=200, blank=True, default='', help_text='响应头 Content-Type 值')
duration = models.BigIntegerField(null=True, blank=True, help_text='请求耗时(单位:纳秒)')
discovered_at = models.DateTimeField(auto_now_add=True, help_text='发现时间')
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
class Meta:
db_table = 'directory_snapshot'
verbose_name = '目录快照'
verbose_name_plural = '目录快照'
ordering = ['-discovered_at']
ordering = ['-created_at']
indexes = [
models.Index(fields=['scan']),
models.Index(fields=['url']),
models.Index(fields=['status']), # 状态码索引,优化筛选
models.Index(fields=['-discovered_at']),
models.Index(fields=['content_type']), # content_type索引优化内容类型搜索
models.Index(fields=['-created_at']),
]
constraints = [
# 唯一约束同一次扫描中同一个目录URL只能记录一次
@@ -183,16 +185,16 @@ class HostPortMappingSnapshot(models.Model):
)
# ==================== 时间字段 ====================
discovered_at = models.DateTimeField(
created_at = models.DateTimeField(
auto_now_add=True,
help_text='发现时间'
help_text='创建时间'
)
class Meta:
db_table = 'host_port_mapping_snapshot'
verbose_name = '主机端口映射快照'
verbose_name_plural = '主机端口映射快照'
ordering = ['-discovered_at']
ordering = ['-created_at']
indexes = [
models.Index(fields=['scan']), # 优化按扫描查询
models.Index(fields=['host']), # 优化按主机名查询
@@ -200,7 +202,7 @@ class HostPortMappingSnapshot(models.Model):
models.Index(fields=['port']), # 优化按端口查询
models.Index(fields=['host', 'ip']), # 优化组合查询
models.Index(fields=['scan', 'host']), # 优化扫描+主机查询
models.Index(fields=['-discovered_at']), # 优化时间排序
models.Index(fields=['-created_at']), # 优化时间排序
]
constraints = [
# 复合唯一约束同一次扫描中scan + host + ip + port 组合唯一
@@ -257,19 +259,21 @@ class EndpointSnapshot(models.Model):
default=list,
help_text='匹配的GF模式列表'
)
discovered_at = models.DateTimeField(auto_now_add=True, help_text='发现时间')
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
class Meta:
db_table = 'endpoint_snapshot'
verbose_name = '端点快照'
verbose_name_plural = '端点快照'
ordering = ['-discovered_at']
ordering = ['-created_at']
indexes = [
models.Index(fields=['scan']),
models.Index(fields=['url']),
models.Index(fields=['host']), # host索引优化根据主机名查询
models.Index(fields=['title']), # title索引优化标题搜索
models.Index(fields=['status_code']), # 状态码索引,优化筛选
models.Index(fields=['-discovered_at']),
models.Index(fields=['webserver']), # webserver索引优化服务器搜索
models.Index(fields=['-created_at']),
]
constraints = [
# 唯一约束同一次扫描中同一个URL只能记录一次
@@ -302,7 +306,7 @@ class VulnerabilitySnapshot(models.Model):
)
# ==================== 核心字段 ====================
url = models.TextField(help_text='漏洞所在的URL')
url = models.CharField(max_length=2000, help_text='漏洞所在的URL')
vuln_type = models.CharField(max_length=100, help_text='漏洞类型(如 xss, sqli')
severity = models.CharField(
max_length=20,
@@ -316,19 +320,20 @@ class VulnerabilitySnapshot(models.Model):
raw_output = models.JSONField(blank=True, default=dict, help_text='工具原始输出')
# ==================== 时间字段 ====================
discovered_at = models.DateTimeField(auto_now_add=True, help_text='发现时间')
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
class Meta:
db_table = 'vulnerability_snapshot'
verbose_name = '漏洞快照'
verbose_name_plural = '漏洞快照'
ordering = ['-discovered_at']
ordering = ['-created_at']
indexes = [
models.Index(fields=['scan']),
models.Index(fields=['url']), # url索引优化URL搜索
models.Index(fields=['vuln_type']),
models.Index(fields=['severity']),
models.Index(fields=['source']),
models.Index(fields=['-discovered_at']),
models.Index(fields=['-created_at']),
]
def __str__(self):

View File

@@ -3,162 +3,141 @@ Django ORM 实现的 Directory Repository
"""
import logging
from typing import List, Tuple, Dict, Iterator
from django.db import transaction, IntegrityError, OperationalError, DatabaseError
from django.utils import timezone
from typing import List, Iterator
from django.db import transaction
from apps.asset.models.asset_models import Directory
from apps.asset.dtos import DirectoryDTO
from apps.common.decorators import auto_ensure_db_connection
from apps.common.utils import deduplicate_for_bulk
logger = logging.getLogger(__name__)
@auto_ensure_db_connection
class DjangoDirectoryRepository:
"""Django ORM 实现的 Directory Repository"""
def bulk_create_ignore_conflicts(self, items: List[DirectoryDTO]) -> int:
def bulk_upsert(self, items: List[DirectoryDTO]) -> int:
"""
批量创建 Directory,忽略冲突
批量创建或更新 Directoryupsert
存在则更新所有字段,不存在则创建。
使用 Django 原生 update_conflicts。
注意:自动按模型唯一约束去重,保留最后一条记录。
Args:
items: Directory DTO 列表
Returns:
int: 实际创建的记录数
Raises:
IntegrityError: 数据完整性错误
OperationalError: 数据库操作错误
DatabaseError: 数据库错误
int: 处理的记录数
"""
if not items:
return 0
try:
# 转换为 Django 模型对象
directory_objects = [
# 自动按模型唯一约束去重
unique_items = deduplicate_for_bulk(items, Directory)
# 直接从 DTO 字段构建 Model
directories = [
Directory(
website_id=item.website_id,
target_id=item.target_id,
url=item.url,
status=item.status,
content_length=item.content_length,
words=item.words,
lines=item.lines,
content_type=item.content_type,
content_type=item.content_type or '',
duration=item.duration
)
for item in items
for item in unique_items
]
with transaction.atomic():
# 批量插入或忽略冲突
# 如果 website + url 已存在,忽略冲突
Directory.objects.bulk_create(
directory_objects,
ignore_conflicts=True
directories,
update_conflicts=True,
unique_fields=['target', 'url'],
update_fields=[
'status', 'content_length', 'words',
'lines', 'content_type', 'duration'
],
batch_size=1000
)
logger.debug(f"成功处理 {len(items)} Directory 记录")
return len(items)
except IntegrityError as e:
logger.error(
f"批量插入 Directory 失败 - 数据完整性错误: {e}, "
f"记录数: {len(items)}"
)
raise
except OperationalError as e:
logger.error(
f"批量插入 Directory 失败 - 数据库操作错误: {e}, "
f"记录数: {len(items)}"
)
raise
except DatabaseError as e:
logger.error(
f"批量插入 Directory 失败 - 数据库错误: {e}, "
f"记录数: {len(items)}"
)
raise
logger.debug(f"批量 upsert Directory 成功: {len(unique_items)}")
return len(unique_items)
except Exception as e:
logger.error(
f"批量插入 Directory 失败 - 未知错误: {e}, "
f"记录数: {len(items)}, "
f"错误类型: {type(e).__name__}",
exc_info=True
)
logger.error(f"批量 upsert Directory 失败: {e}")
raise
def get_by_website(self, website_id: int) -> List[DirectoryDTO]:
def bulk_create_ignore_conflicts(self, items: List[DirectoryDTO]) -> int:
"""
获取指定站点的所有目录
批量创建 Directory存在即跳过
与 bulk_upsert 不同,此方法不会更新已存在的记录。
适用于批量添加场景,只提供 URL没有其他字段数据。
注意:自动按模型唯一约束去重,保留最后一条记录。
Args:
website_id: 站点 ID
items: Directory DTO 列表
Returns:
List[DirectoryDTO]: 目录列表
int: 处理的记录数
"""
if not items:
return 0
try:
directories = Directory.objects.filter(website_id=website_id)
return [
DirectoryDTO(
website_id=d.website_id,
target_id=d.target_id,
url=d.url,
status=d.status,
content_length=d.content_length,
words=d.words,
lines=d.lines,
content_type=d.content_type,
duration=d.duration
# 自动按模型唯一约束去重
unique_items = deduplicate_for_bulk(items, Directory)
directories = [
Directory(
target_id=item.target_id,
url=item.url,
status=item.status,
content_length=item.content_length,
words=item.words,
lines=item.lines,
content_type=item.content_type or '',
duration=item.duration
)
for d in directories
for item in unique_items
]
except Exception as e:
logger.error(f"获取目录列表失败 - Website ID: {website_id}, 错误: {e}")
raise
def count_by_website(self, website_id: int) -> int:
"""
统计指定站点的目录总数
Args:
website_id: 站点 ID
Returns:
int: 目录总数
"""
try:
count = Directory.objects.filter(website_id=website_id).count()
logger.debug(f"Website {website_id} 的目录总数: {count}")
return count
with transaction.atomic():
Directory.objects.bulk_create(
directories,
ignore_conflicts=True,
batch_size=1000
)
logger.debug(f"批量创建 Directory 成功ignore_conflicts: {len(unique_items)}")
return len(unique_items)
except Exception as e:
logger.error(f"统计目录数量失败 - Website ID: {website_id}, 错误: {e}")
logger.error(f"批量创建 Directory 失败: {e}")
raise
def count_by_target(self, target_id: int) -> int:
"""统计目标下的目录总数"""
return Directory.objects.filter(target_id=target_id).count()
def get_all(self):
"""
获取所有目录
Returns:
QuerySet: 目录查询集
"""
return Directory.objects.all()
"""获取所有目录"""
return Directory.objects.all().order_by('-created_at')
def get_by_target(self, target_id: int):
return Directory.objects.filter(target_id=target_id).select_related('website').order_by('-discovered_at')
"""获取目标下的所有目录"""
return Directory.objects.filter(target_id=target_id).order_by('-created_at')
def get_urls_for_export(self, target_id: int, batch_size: int = 1000) -> Iterator[str]:
"""流式导出目标下的所有目录 URL(只查 url 字段,避免加载多余数据)。"""
"""流式导出目标下的所有目录 URL"""
try:
queryset = (
Directory.objects
@@ -172,78 +151,31 @@ class DjangoDirectoryRepository:
except Exception as e:
logger.error("流式导出目录 URL 失败 - Target ID: %s, 错误: %s", target_id, e)
raise
def soft_delete_by_ids(self, directory_ids: List[int]) -> int:
def iter_raw_data_for_export(
self,
target_id: int,
batch_size: int = 1000
) -> Iterator[dict]:
"""
根据 ID 列表批量软删除Directory
流式获取原始数据用于 CSV 导出
Args:
directory_ids: Directory ID 列表
target_id: 目标 ID
batch_size: 每批数据量
Returns:
软删除的记录数
Yields:
包含所有目录字段的字典
"""
try:
updated_count = (
Directory.objects
.filter(id__in=directory_ids)
.update(deleted_at=timezone.now())
qs = (
Directory.objects
.filter(target_id=target_id)
.values(
'url', 'status', 'content_length', 'words',
'lines', 'content_type', 'duration', 'created_at'
)
logger.debug(
"批量软删除Directory成功 - Count: %s, 更新记录: %s",
len(directory_ids),
updated_count
)
return updated_count
except Exception as e:
logger.error(
"批量软删除Directory失败 - IDs: %s, 错误: %s",
directory_ids,
e
)
raise
def hard_delete_by_ids(self, directory_ids: List[int]) -> Tuple[int, Dict[str, int]]:
"""
根据 ID 列表硬删除Directory使用数据库级 CASCADE
.order_by('url')
)
Args:
directory_ids: Directory ID 列表
Returns:
(删除的记录数, 删除详情字典)
"""
try:
batch_size = 1000
total_deleted = 0
logger.debug(f"开始批量删除 {len(directory_ids)} 个Directory数据库 CASCADE...")
for i in range(0, len(directory_ids), batch_size):
batch_ids = directory_ids[i:i + batch_size]
count, _ = Directory.all_objects.filter(id__in=batch_ids).delete()
total_deleted += count
logger.debug(f"批次删除完成: {len(batch_ids)} 个Directory删除 {count} 条记录")
deleted_details = {
'directories': len(directory_ids),
'total': total_deleted,
'note': 'Database CASCADE - detailed stats unavailable'
}
logger.debug(
"批量硬删除成功CASCADE- Directory数: %s, 总删除记录: %s",
len(directory_ids),
total_deleted
)
return total_deleted, deleted_details
except Exception as e:
logger.error(
"批量硬删除失败CASCADE- Directory数: %s, 错误: %s",
len(directory_ids),
str(e),
exc_info=True
)
raise
for row in qs.iterator(chunk_size=batch_size):
yield row

View File

@@ -1,11 +1,12 @@
"""Endpoint Repository - Django ORM 实现"""
import logging
from typing import List, Optional, Tuple, Dict, Any
from typing import List, Iterator
from apps.asset.models import Endpoint
from apps.asset.dtos.asset import EndpointDTO
from apps.common.decorators import auto_ensure_db_connection
from apps.common.utils import deduplicate_for_bulk
from django.db import transaction
logger = logging.getLogger(__name__)
@@ -15,25 +16,31 @@ logger = logging.getLogger(__name__)
class DjangoEndpointRepository:
"""端点 Repository - 负责端点表的数据访问"""
def bulk_create_ignore_conflicts(self, items: List[EndpointDTO]) -> int:
def bulk_upsert(self, items: List[EndpointDTO]) -> int:
"""
批量创建端点(忽略冲突
批量创建或更新端点upsert
存在则更新所有字段,不存在则创建。
使用 Django 原生 update_conflicts。
注意:自动按模型唯一约束去重,保留最后一条记录。
Args:
items: 端点 DTO 列表
Returns:
int: 创建的记录数
int: 处理的记录数
"""
if not items:
return 0
try:
endpoints = []
for item in items:
# Endpoint 模型当前只关联 target不再依赖 website 外键
# 这里按照 EndpointDTO 字段映射构造 Endpoint 实例
endpoints.append(Endpoint(
# 自动按模型唯一约束去重
unique_items = deduplicate_for_bulk(items, Endpoint)
# 直接从 DTO 字段构建 Model
endpoints = [
Endpoint(
target_id=item.target_id,
url=item.url,
host=item.host or '',
@@ -47,62 +54,35 @@ class DjangoEndpointRepository:
vhost=item.vhost,
location=item.location or '',
matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else []
))
)
for item in unique_items
]
with transaction.atomic():
created = Endpoint.objects.bulk_create(
Endpoint.objects.bulk_create(
endpoints,
ignore_conflicts=True,
update_conflicts=True,
unique_fields=['url', 'target'],
update_fields=[
'host', 'title', 'status_code', 'content_length',
'webserver', 'body_preview', 'content_type', 'tech',
'vhost', 'location', 'matched_gf_patterns'
],
batch_size=1000
)
return len(created)
logger.debug(f"批量 upsert 端点成功: {len(unique_items)}")
return len(unique_items)
except Exception as e:
logger.error(f"批量创建端点失败: {e}")
logger.error(f"批量 upsert 端点失败: {e}")
raise
def get_by_website(self, website_id: int) -> List[EndpointDTO]:
"""
获取网站下的所有端点
Args:
website_id: 网站 ID
Returns:
List[EndpointDTO]: 端点列表
"""
endpoints = Endpoint.objects.filter(
website_id=website_id
).order_by('-discovered_at')
result = []
for endpoint in endpoints:
result.append(EndpointDTO(
website_id=endpoint.website_id,
target_id=endpoint.target_id,
url=endpoint.url,
title=endpoint.title,
status_code=endpoint.status_code,
content_length=endpoint.content_length,
webserver=endpoint.webserver,
body_preview=endpoint.body_preview,
content_type=endpoint.content_type,
tech=endpoint.tech,
vhost=endpoint.vhost,
location=endpoint.location,
matched_gf_patterns=endpoint.matched_gf_patterns
))
return result
def get_queryset_by_target(self, target_id: int):
return Endpoint.objects.filter(target_id=target_id).order_by('-discovered_at')
def get_all(self):
"""获取所有端点(全局查询)"""
return Endpoint.objects.all().order_by('-discovered_at')
return Endpoint.objects.all().order_by('-created_at')
def get_by_target(self, target_id: int) -> List[EndpointDTO]:
def get_by_target(self, target_id: int):
"""
获取目标下的所有端点
@@ -110,43 +90,9 @@ class DjangoEndpointRepository:
target_id: 目标 ID
Returns:
List[EndpointDTO]: 端点列表
QuerySet: 端点查询集
"""
endpoints = Endpoint.objects.filter(
target_id=target_id
).order_by('-discovered_at')
result = []
for endpoint in endpoints:
result.append(EndpointDTO(
website_id=endpoint.website_id,
target_id=endpoint.target_id,
url=endpoint.url,
title=endpoint.title,
status_code=endpoint.status_code,
content_length=endpoint.content_length,
webserver=endpoint.webserver,
body_preview=endpoint.body_preview,
content_type=endpoint.content_type,
tech=endpoint.tech,
vhost=endpoint.vhost,
location=endpoint.location,
matched_gf_patterns=endpoint.matched_gf_patterns
))
return result
def count_by_website(self, website_id: int) -> int:
"""
统计网站下的端点数量
Args:
website_id: 网站 ID
Returns:
int: 端点数量
"""
return Endpoint.objects.filter(website_id=website_id).count()
return Endpoint.objects.filter(target_id=target_id).order_by('-created_at')
def count_by_target(self, target_id: int) -> int:
"""
@@ -159,34 +105,88 @@ class DjangoEndpointRepository:
int: 端点数量
"""
return Endpoint.objects.filter(target_id=target_id).count()
def soft_delete_by_ids(self, ids: List[int]) -> int:
def bulk_create_ignore_conflicts(self, items: List[EndpointDTO]) -> int:
"""
软删除端点(批量)
批量创建端点(存在即跳过
与 bulk_upsert 不同,此方法不会更新已存在的记录。
适用于快速扫描场景,只提供 URL没有其他字段数据。
注意:自动按模型唯一约束去重,保留最后一条记录。
Args:
ids: 端点 ID 列表
items: 端点 DTO 列表
Returns:
int: 更新的记录数
int: 处理的记录数
"""
from django.utils import timezone
return Endpoint.objects.filter(
id__in=ids
).update(deleted_at=timezone.now())
def hard_delete_by_ids(self, ids: List[int]) -> Tuple[int, Dict[str, int]]:
if not items:
return 0
try:
# 自动按模型唯一约束去重
unique_items = deduplicate_for_bulk(items, Endpoint)
# 直接从 DTO 字段构建 Model
endpoints = [
Endpoint(
target_id=item.target_id,
url=item.url,
host=item.host or '',
title=item.title or '',
status_code=item.status_code,
content_length=item.content_length,
webserver=item.webserver or '',
body_preview=item.body_preview or '',
content_type=item.content_type or '',
tech=item.tech if item.tech else [],
vhost=item.vhost,
location=item.location or '',
matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else []
)
for item in unique_items
]
with transaction.atomic():
Endpoint.objects.bulk_create(
endpoints,
ignore_conflicts=True,
batch_size=1000
)
logger.debug(f"批量创建端点成功ignore_conflicts: {len(unique_items)}")
return len(unique_items)
except Exception as e:
logger.error(f"批量创建端点失败: {e}")
raise
def iter_raw_data_for_export(
self,
target_id: int,
batch_size: int = 1000
) -> Iterator[dict]:
"""
硬删除端点(批量)
流式获取原始数据用于 CSV 导出
Args:
ids: 端点 ID 列表
Returns:
Tuple[int, Dict[str, int]]: (删除总数, 详细信息)
"""
deleted_count, details = Endpoint.all_objects.filter(
id__in=ids
).delete()
target_id: 目标 ID
batch_size: 每批数据量
return deleted_count, details
Yields:
包含所有端点字段的字典
"""
qs = (
Endpoint.objects
.filter(target_id=target_id)
.values(
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'body_preview', 'vhost', 'matched_gf_patterns', 'created_at'
)
.order_by('url')
)
for row in qs.iterator(chunk_size=batch_size):
yield row

View File

@@ -1,32 +1,36 @@
"""HostPortMapping Repository - Django ORM 实现"""
import logging
from typing import List, Iterator
from typing import List, Iterator, Dict, Optional
from django.db.models import QuerySet, Min
from apps.asset.models.asset_models import HostPortMapping
from apps.asset.dtos.asset import HostPortMappingDTO
from apps.common.decorators import auto_ensure_db_connection
from apps.common.utils import deduplicate_for_bulk
logger = logging.getLogger(__name__)
@auto_ensure_db_connection
class DjangoHostPortMappingRepository:
"""HostPortMapping Repository - Django ORM 实现"""
"""HostPortMapping Repository - Django ORM 实现
职责:纯数据访问,不包含业务逻辑
"""
def bulk_create_ignore_conflicts(self, items: List[HostPortMappingDTO]) -> int:
"""
批量创建主机端口关联(忽略冲突)
注意:自动按模型唯一约束去重,保留最后一条记录。
Args:
items: 主机端口关联 DTO 列表
Returns:
int: 实际创建的记录数注意ignore_conflicts 时可能为 0
Note:
- 基于唯一约束 (target + host + ip + port) 自动去重
- 忽略已存在的记录,不更新
int: 实际创建的记录数
"""
try:
logger.debug("准备批量创建主机端口关联 - 数量: %d", len(items))
@@ -34,18 +38,20 @@ class DjangoHostPortMappingRepository:
if not items:
logger.debug("主机端口关联为空,跳过创建")
return 0
# 自动按模型唯一约束去重
unique_items = deduplicate_for_bulk(items, HostPortMapping)
# 构建记录对象
records = []
for item in items:
records.append(HostPortMapping(
records = [
HostPortMapping(
target_id=item.target_id,
host=item.host,
ip=item.ip,
port=item.port
))
)
for item in unique_items
]
# 批量创建(忽略冲突,基于唯一约束去重)
created = HostPortMapping.objects.bulk_create(
records,
ignore_conflicts=True
@@ -89,79 +95,47 @@ class DjangoHostPortMappingRepository:
for ip in queryset:
yield ip
def get_ip_aggregation_by_target(self, target_id: int, search: str = None):
from django.db.models import Min
def get_queryset_by_target(self, target_id: int) -> QuerySet:
"""获取目标下的 QuerySet"""
return HostPortMapping.objects.filter(target_id=target_id)
qs = HostPortMapping.objects.filter(target_id=target_id)
if search:
qs = qs.filter(ip__icontains=search)
def get_all_queryset(self) -> QuerySet:
"""获取所有记录的 QuerySet"""
return HostPortMapping.objects.all()
ip_aggregated = (
qs
.values('ip')
.annotate(
discovered_at=Min('discovered_at')
)
.order_by('-discovered_at')
def get_queryset_by_ip(self, ip: str, target_id: Optional[int] = None) -> QuerySet:
"""获取指定 IP 的 QuerySet"""
qs = HostPortMapping.objects.filter(ip=ip)
if target_id:
qs = qs.filter(target_id=target_id)
return qs
def iter_raw_data_for_export(
self,
target_id: int,
batch_size: int = 1000
) -> Iterator[dict]:
"""
流式获取原始数据用于 CSV 导出
Args:
target_id: 目标 ID
batch_size: 每批数据量
Yields:
{
'ip': '192.168.1.1',
'host': 'example.com',
'port': 80,
'created_at': datetime
}
"""
qs = (
HostPortMapping.objects
.filter(target_id=target_id)
.values('ip', 'host', 'port', 'created_at')
.order_by('ip', 'host', 'port')
)
results = []
for item in ip_aggregated:
ip = item['ip']
mappings = (
HostPortMapping.objects
.filter(target_id=target_id, ip=ip)
.values('host', 'port')
.distinct()
)
hosts = sorted({m['host'] for m in mappings})
ports = sorted({m['port'] for m in mappings})
results.append({
'ip': ip,
'hosts': hosts,
'ports': ports,
'discovered_at': item['discovered_at'],
})
return results
def get_all_ip_aggregation(self, search: str = None):
"""获取所有 IP 聚合数据(全局查询)"""
from django.db.models import Min
qs = HostPortMapping.objects.all()
if search:
qs = qs.filter(ip__icontains=search)
ip_aggregated = (
qs
.values('ip')
.annotate(
discovered_at=Min('discovered_at')
)
.order_by('-discovered_at')
)
results = []
for item in ip_aggregated:
ip = item['ip']
mappings = (
HostPortMapping.objects
.filter(ip=ip)
.values('host', 'port')
.distinct()
)
hosts = sorted({m['host'] for m in mappings})
ports = sorted({m['port'] for m in mappings})
results.append({
'ip': ip,
'hosts': hosts,
'ports': ports,
'discovered_at': item['discovered_at'],
})
return results
for row in qs.iterator(chunk_size=batch_size):
yield row

View File

@@ -1,117 +1,72 @@
"""Subdomain Repository - Django ORM 实现"""
import logging
from typing import List, Iterator
from django.db import transaction, IntegrityError, OperationalError, DatabaseError
from django.utils import timezone
from typing import Tuple, Dict
from django.db import transaction
from apps.asset.models.asset_models import Subdomain
from apps.asset.dtos import SubdomainDTO
from apps.common.decorators import auto_ensure_db_connection
from apps.common.utils import deduplicate_for_bulk
logger = logging.getLogger(__name__)
@auto_ensure_db_connection
class DjangoSubdomainRepository:
"""基于 Django ORM 的子域名仓储实现"""
"""基于 Django ORM 的子域名仓储实现"""
def bulk_create_ignore_conflicts(self, items: List[SubdomainDTO]) -> None:
"""
批量创建子域名,忽略冲突
注意:自动按模型唯一约束去重,保留最后一条记录。
Args:
items: 子域名 DTO 列表
Raises:
IntegrityError: 数据完整性错误(如唯一约束冲突)
OperationalError: 数据库操作错误(如连接失败)
DatabaseError: 其他数据库错误
"""
if not items:
return
try:
# 自动按模型唯一约束去重
unique_items = deduplicate_for_bulk(items, Subdomain)
subdomain_objects = [
Subdomain(
name=item.name,
target_id=item.target_id,
)
for item in items
for item in unique_items
]
with transaction.atomic():
# 使用 ignore_conflicts 策略:
# - 新子域名INSERT 完整记录
# - 已存在子域名:忽略(不更新,因为没有探测字段数据)
# 注意ignore_conflicts 无法返回实际创建的数量
Subdomain.objects.bulk_create( # type: ignore[attr-defined]
Subdomain.objects.bulk_create(
subdomain_objects,
ignore_conflicts=True, # 忽略重复记录
ignore_conflicts=True,
)
logger.debug(f"成功处理 {len(items)} 条子域名记录")
except IntegrityError as e:
logger.error(
f"批量插入子域名失败 - 数据完整性错误: {e}, "
f"记录数: {len(items)}, "
f"示例域名: {items[0].name if items else 'N/A'}"
)
raise
except OperationalError as e:
logger.error(
f"批量插入子域名失败 - 数据库操作错误: {e}, "
f"记录数: {len(items)}"
)
raise
except DatabaseError as e:
logger.error(
f"批量插入子域名失败 - 数据库错误: {e}, "
f"记录数: {len(items)}"
)
raise
logger.debug(f"成功处理 {len(unique_items)} 条子域名记录")
except Exception as e:
logger.error(
f"批量插入子域名失败 - 未知错误: {e}, "
f"记录数: {len(items)}, "
f"错误类型: {type(e).__name__}",
exc_info=True
)
logger.error(f"批量插入子域名失败: {e}")
raise
def get_or_create(self, name: str, target_id: int) -> Tuple[Subdomain, bool]:
"""
获取或创建子域名
Args:
name: 子域名名称
target_id: 目标 ID
Returns:
(Subdomain对象, 是否新创建)
"""
return Subdomain.objects.get_or_create(
name=name,
target_id=target_id,
)
def get_all(self):
"""获取所有子域名"""
return Subdomain.objects.all().order_by('-created_at')
def get_by_target(self, target_id: int):
"""获取目标下的所有子域名"""
return Subdomain.objects.filter(target_id=target_id).order_by('-created_at')
def count_by_target(self, target_id: int) -> int:
"""统计目标下的域名数量"""
return Subdomain.objects.filter(target_id=target_id).count()
def get_domains_for_export(self, target_id: int, batch_size: int = 1000) -> Iterator[str]:
"""
流式导出域名(用于生成扫描工具输入文件)
使用 iterator() 进行流式查询,避免一次性加载所有数据到内存
Args:
target_id: 目标 ID
batch_size: 每次从数据库读取的行数
Yields:
str: 域名
"""
"""流式导出域名"""
queryset = Subdomain.objects.filter(
target_id=target_id
).only('name').iterator(chunk_size=batch_size)
@@ -119,138 +74,36 @@ class DjangoSubdomainRepository:
for subdomain in queryset:
yield subdomain.name
def get_by_target(self, target_id: int):
return Subdomain.objects.filter(target_id=target_id).order_by('-discovered_at')
def count_by_target(self, target_id: int) -> int:
"""
统计目标下的域名数量
Args:
target_id: 目标 ID
Returns:
int: 域名数量
"""
return Subdomain.objects.filter(target_id=target_id).count()
def get_by_names_and_target_id(self, names: set, target_id: int) -> dict:
"""
根据域名列表和目标ID批量查询 Subdomain
Args:
names: 域名集合
target_id: 目标 ID
Returns:
dict: {domain_name: Subdomain对象}
"""
"""根据域名列表和目标ID批量查询 Subdomain"""
subdomains = Subdomain.objects.filter(
name__in=names,
target_id=target_id
).only('id', 'name')
return {sd.name: sd for sd in subdomains}
def get_all(self):
def iter_raw_data_for_export(
self,
target_id: int,
batch_size: int = 1000
) -> Iterator[dict]:
"""
获取所有子域名
Returns:
QuerySet: 子域名查询集
"""
return Subdomain.objects.all()
def soft_delete_by_ids(self, subdomain_ids: List[int]) -> int:
"""
根据 ID 列表批量软删除子域名
流式获取原始数据用于 CSV 导出
Args:
subdomain_ids: 子域名 ID 列表
target_id: 目标 ID
batch_size: 每批数据量
Returns:
软删除的记录数
Note:
- 使用软删除:只标记为已删除,不真正删除数据库记录
- 保留所有关联数据,可恢复
Yields:
{'name': 'sub.example.com', 'created_at': datetime}
"""
try:
updated_count = (
Subdomain.objects
.filter(id__in=subdomain_ids)
.update(deleted_at=timezone.now())
)
logger.debug(
"批量软删除子域名成功 - Count: %s, 更新记录: %s",
len(subdomain_ids),
updated_count
)
return updated_count
except Exception as e:
logger.error(
"批量软删除子域名失败 - IDs: %s, 错误: %s",
subdomain_ids,
e
)
raise
def hard_delete_by_ids(self, subdomain_ids: List[int]) -> Tuple[int, Dict[str, int]]:
"""
根据 ID 列表硬删除子域名(使用数据库级 CASCADE
qs = (
Subdomain.objects
.filter(target_id=target_id)
.values('name', 'created_at')
.order_by('name')
)
Args:
subdomain_ids: 子域名 ID 列表
Returns:
(删除的记录数, 删除详情字典)
Strategy:
使用数据库级 CASCADE 删除,性能最优
Note:
- 硬删除:从数据库中永久删除
- 数据库自动处理所有外键级联删除
- 不触发 Django 信号pre_delete/post_delete
"""
try:
batch_size = 1000 # 每批处理1000个子域名
total_deleted = 0
logger.debug(f"开始批量删除 {len(subdomain_ids)} 个子域名(数据库 CASCADE...")
# 分批处理子域名ID避免单次删除过多
for i in range(0, len(subdomain_ids), batch_size):
batch_ids = subdomain_ids[i:i + batch_size]
# 直接删除子域名,数据库自动级联删除所有关联数据
count, _ = Subdomain.all_objects.filter(id__in=batch_ids).delete()
total_deleted += count
logger.debug(f"批次删除完成: {len(batch_ids)} 个子域名,删除 {count} 条记录")
# 由于使用数据库 CASCADE无法获取详细统计
deleted_details = {
'subdomains': len(subdomain_ids),
'total': total_deleted,
'note': 'Database CASCADE - detailed stats unavailable'
}
logger.debug(
"批量硬删除成功CASCADE- 子域名数: %s, 总删除记录: %s",
len(subdomain_ids),
total_deleted
)
return total_deleted, deleted_details
except Exception as e:
logger.error(
"批量硬删除失败CASCADE- 子域名数: %s, 错误: %s",
len(subdomain_ids),
str(e),
exc_info=True
)
raise
for row in qs.iterator(chunk_size=batch_size):
yield row

View File

@@ -3,110 +3,87 @@ Django ORM 实现的 WebSite Repository
"""
import logging
from typing import List, Generator, Tuple, Dict, Optional
from django.db import transaction, IntegrityError, OperationalError, DatabaseError
from django.utils import timezone
from typing import List, Generator, Optional, Iterator
from django.db import transaction
from apps.asset.models.asset_models import WebSite
from apps.asset.dtos import WebSiteDTO
from apps.common.decorators import auto_ensure_db_connection
from apps.common.utils import deduplicate_for_bulk
logger = logging.getLogger(__name__)
@auto_ensure_db_connection
class DjangoWebSiteRepository:
"""Django ORM 实现的 WebSite Repository"""
def bulk_create_ignore_conflicts(self, items: List[WebSiteDTO]) -> None:
def bulk_upsert(self, items: List[WebSiteDTO]) -> int:
"""
Bulk-create WebSite records, ignoring conflicts
Bulk-create or update WebSite records (upsert)
Updates every field when a record exists, creates it otherwise.
Uses Django's native update_conflicts.
Note: items are automatically deduplicated against the model's unique constraints, keeping the last record.
Args:
items: list of WebSite DTOs
Raises:
IntegrityError: data integrity error
OperationalError: database operation error
DatabaseError: database error
Returns:
int: number of records processed
"""
if not items:
return
return 0
try:
# Convert to Django model objects
website_objects = [
# Deduplicate automatically by the model's unique constraints
unique_items = deduplicate_for_bulk(items, WebSite)
# Build Model instances directly from DTO fields
websites = [
WebSite(
target_id=item.target_id,
url=item.url,
host=item.host,
location=item.location,
title=item.title,
webserver=item.webserver,
body_preview=item.body_preview,
content_type=item.content_type,
tech=item.tech,
host=item.host or '',
location=item.location or '',
title=item.title or '',
webserver=item.webserver or '',
body_preview=item.body_preview or '',
content_type=item.content_type or '',
tech=item.tech if item.tech else [],
status_code=item.status_code,
content_length=item.content_length,
vhost=item.vhost
)
for item in items
for item in unique_items
]
with transaction.atomic():
# Bulk insert or update
# If the URL and target already exist, ignore the conflict
WebSite.objects.bulk_create(
website_objects,
ignore_conflicts=True
websites,
update_conflicts=True,
unique_fields=['url', 'target'],
update_fields=[
'host', 'location', 'title', 'webserver',
'body_preview', 'content_type', 'tech',
'status_code', 'content_length', 'vhost'
],
batch_size=1000
)
logger.debug(f"成功处理 {len(items)} WebSite 记录")
except IntegrityError as e:
logger.error(
f"批量插入 WebSite 失败 - 数据完整性错误: {e}, "
f"记录数: {len(items)}"
)
raise
except OperationalError as e:
logger.error(
f"批量插入 WebSite 失败 - 数据库操作错误: {e}, "
f"记录数: {len(items)}"
)
raise
except DatabaseError as e:
logger.error(
f"批量插入 WebSite 失败 - 数据库错误: {e}, "
f"记录数: {len(items)}"
)
raise
logger.debug(f"批量 upsert WebSite 成功: {len(unique_items)}")
return len(unique_items)
except Exception as e:
logger.error(
f"批量插入 WebSite 失败 - 未知错误: {e}, "
f"记录数: {len(items)}, "
f"错误类型: {type(e).__name__}",
exc_info=True
)
logger.error(f"批量 upsert WebSite 失败: {e}")
raise
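For context, the update_conflicts path used above maps to PostgreSQL's INSERT ... ON CONFLICT ... DO UPDATE and requires Django 4.1+. A self-contained sketch against a deliberately simplified, hypothetical model (the real WebSite model carries more fields and a (url, target) unique constraint):

from django.db import models

class Site(models.Model):
    # Hypothetical minimal model, for illustration only
    url = models.CharField(max_length=500)
    title = models.CharField(max_length=200, blank=True)

    class Meta:
        constraints = [
            models.UniqueConstraint(fields=['url'], name='uniq_site_url'),
        ]

def upsert_sites(rows):
    # Roughly: INSERT ... ON CONFLICT (url) DO UPDATE SET title = EXCLUDED.title
    Site.objects.bulk_create(
        [Site(url=r['url'], title=r.get('title', '')) for r in rows],
        update_conflicts=True,
        unique_fields=['url'],
        update_fields=['title'],
        batch_size=1000,
    )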
def get_urls_for_export(self, target_id: int, batch_size: int = 1000) -> Generator[str, None, None]:
"""
Stream-export all website URLs under a target
Args:
target_id: target ID
batch_size: batch size
Yields:
str: website URL
"""
try:
# Select only the URL field to avoid transferring unnecessary data
queryset = WebSite.objects.filter(
target_id=target_id
).values_list('url', flat=True).iterator(chunk_size=batch_size)
@@ -117,144 +94,93 @@ class DjangoWebSiteRepository:
logger.error(f"Streaming website URL export failed - Target ID: {target_id}, error: {e}")
raise
def get_all(self):
"""获取所有网站"""
return WebSite.objects.all().order_by('-created_at')
def get_by_target(self, target_id: int):
return WebSite.objects.filter(target_id=target_id).order_by('-discovered_at')
"""获取目标下的所有网站"""
return WebSite.objects.filter(target_id=target_id).order_by('-created_at')
def count_by_target(self, target_id: int) -> int:
"""
统计目标下的站点总数
Args:
target_id: 目标 ID
Returns:
int: 站点总数
"""
try:
count = WebSite.objects.filter(target_id=target_id).count()
logger.debug(f"Target {target_id} 的站点总数: {count}")
return count
except Exception as e:
logger.error(f"统计站点数量失败 - Target ID: {target_id}, 错误: {e}")
raise
def count_by_scan(self, scan_id: int) -> int:
"""
统计扫描下的站点总数
"""
try:
count = WebSite.objects.filter(scan_id=scan_id).count()
logger.debug(f"Scan {scan_id} 的站点总数: {count}")
return count
except Exception as e:
logger.error(f"统计站点数量失败 - Scan ID: {scan_id}, 错误: {e}")
raise
"""统计目标下的站点总数"""
return WebSite.objects.filter(target_id=target_id).count()
def get_by_url(self, url: str, target_id: int) -> Optional[int]:
"""根据 URL 和 target_id 查找站点 ID"""
website = WebSite.objects.filter(url=url, target_id=target_id).first()
return website.id if website else None
def bulk_create_ignore_conflicts(self, items: List[WebSiteDTO]) -> int:
"""
根据 URL 和 target_id 查找站点 ID
批量创建 WebSite存在即跳过
注意:自动按模型唯一约束去重,保留最后一条记录。
"""
if not items:
return 0
try:
# Deduplicate automatically by the model's unique constraints
unique_items = deduplicate_for_bulk(items, WebSite)
websites = [
WebSite(
target_id=item.target_id,
url=item.url,
host=item.host or '',
location=item.location or '',
title=item.title or '',
webserver=item.webserver or '',
body_preview=item.body_preview or '',
content_type=item.content_type or '',
tech=item.tech if item.tech else [],
status_code=item.status_code,
content_length=item.content_length,
vhost=item.vhost
)
for item in unique_items
]
with transaction.atomic():
WebSite.objects.bulk_create(
websites,
ignore_conflicts=True,
batch_size=1000
)
logger.debug(f"批量创建 WebSite 成功ignore_conflicts: {len(unique_items)}")
return len(unique_items)
except Exception as e:
logger.error(f"批量创建 WebSite 失败: {e}")
raise
def iter_raw_data_for_export(
self,
target_id: int,
batch_size: int = 1000
) -> Iterator[dict]:
"""
Stream raw data for CSV export
Args:
target_id: target ID
batch_size: rows per batch
Yields:
dict containing all website fields
"""
qs = (
WebSite.objects
.filter(target_id=target_id)
.values(
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'body_preview', 'vhost', 'created_at'
)
.order_by('url')
)
for row in qs.iterator(chunk_size=batch_size):
yield row
def get_by_url(self, url: str, target_id: int) -> Optional[int]:
"""
Find a website ID by URL and target_id
Args:
url: website URL
target_id: target ID
Returns:
Optional[int]: website ID, or None if it does not exist
Raises:
ValueError: when multiple websites are found
"""
try:
website = WebSite.objects.filter(url=url, target_id=target_id).first()
if website:
return website.id
return None
except Exception as e:
logger.error(f"Website lookup failed - URL: {url}, Target ID: {target_id}, error: {e}")
raise
def get_all(self):
"""
Get all websites
Returns:
QuerySet: website queryset
"""
return WebSite.objects.all()
def soft_delete_by_ids(self, website_ids: List[int]) -> int:
"""
Batch soft-delete WebSite records by ID list
Args:
website_ids: list of WebSite IDs
Returns:
number of records soft-deleted
"""
try:
updated_count = (
WebSite.objects
.filter(id__in=website_ids)
.update(deleted_at=timezone.now())
)
logger.debug(
"Batch soft delete of WebSite succeeded - Count: %s, updated rows: %s",
len(website_ids),
updated_count
)
return updated_count
except Exception as e:
logger.error(
"Batch soft delete of WebSite failed - IDs: %s, error: %s",
website_ids,
e
)
raise
def hard_delete_by_ids(self, website_ids: List[int]) -> Tuple[int, Dict[str, int]]:
"""
Batch hard-delete WebSite records by ID list (database-level CASCADE)
Args:
website_ids: list of WebSite IDs
Returns:
(number of deleted records, deletion detail dict)
"""
try:
batch_size = 1000
total_deleted = 0
logger.debug(f"Starting batch delete of {len(website_ids)} WebSite records (database CASCADE)...")
for i in range(0, len(website_ids), batch_size):
batch_ids = website_ids[i:i + batch_size]
count, _ = WebSite.all_objects.filter(id__in=batch_ids).delete()
total_deleted += count
logger.debug(f"Batch done: {len(batch_ids)} WebSite records, {count} rows deleted")
deleted_details = {
'websites': len(website_ids),
'total': total_deleted,
'note': 'Database CASCADE - detailed stats unavailable'
}
logger.debug(
"Batch hard delete succeeded (CASCADE) - WebSites: %s, total rows deleted: %s",
len(website_ids),
total_deleted
)
return total_deleted, deleted_details
except Exception as e:
logger.error(
"Batch hard delete failed (CASCADE) - WebSites: %s, error: %s",
len(website_ids),
str(e),
exc_info=True
)
raise
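The objects / all_objects split used by the delete paths above implies a soft-delete base model with two managers. The actual base class is not shown in this diff, but a plausible sketch of the pattern is:

from django.db import models
from django.utils import timezone

class NotDeletedManager(models.Manager):
    def get_queryset(self):
        # The default manager hides soft-deleted rows
        return super().get_queryset().filter(deleted_at__isnull=True)

class SoftDeleteModel(models.Model):
    deleted_at = models.DateTimeField(null=True, blank=True)

    objects = NotDeletedManager()   # hides soft-deleted rows
    all_objects = models.Manager()  # sees everything; used by hard delete

    class Meta:
        abstract = True

    def soft_delete(self):
        self.deleted_at = timezone.now()
        self.save(update_fields=['deleted_at'])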

View File

@@ -1,12 +1,13 @@
"""Directory Snapshot Repository - 目录快照数据访问层"""
import logging
from typing import List
from typing import List, Iterator
from django.db import transaction
from apps.asset.models import DirectorySnapshot
from apps.asset.dtos.snapshot import DirectorySnapshotDTO
from apps.common.decorators import auto_ensure_db_connection
from apps.common.utils import deduplicate_for_bulk
logger = logging.getLogger(__name__)
@@ -25,6 +26,8 @@ class DjangoDirectorySnapshotRepository:
Uses the ignore_conflicts strategy: snapshots that already exist (same scan + url) are skipped
Note: items are automatically deduplicated by (scan_id, url), keeping the last record.
Args:
items: list of directory snapshot DTOs
@@ -37,6 +40,9 @@ class DjangoDirectorySnapshotRepository:
return
try:
# Deduplicate automatically by the model's unique constraints
unique_items = deduplicate_for_bulk(items, DirectorySnapshot)
# Convert to Django model objects
snapshot_objects = [
DirectorySnapshot(
@@ -49,7 +55,7 @@ class DjangoDirectorySnapshotRepository:
content_type=item.content_type,
duration=item.duration
)
for item in items
for item in unique_items
]
with transaction.atomic():
@@ -60,7 +66,7 @@ class DjangoDirectorySnapshotRepository:
ignore_conflicts=True
)
logger.debug("成功保存 %d 条目录快照记录", len(items))
logger.debug("成功保存 %d 条目录快照记录", len(unique_items))
except Exception as e:
logger.error(
@@ -72,7 +78,35 @@ class DjangoDirectorySnapshotRepository:
raise
def get_by_scan(self, scan_id: int):
return DirectorySnapshot.objects.filter(scan_id=scan_id).order_by('-discovered_at')
return DirectorySnapshot.objects.filter(scan_id=scan_id).order_by('-created_at')
def get_all(self):
return DirectorySnapshot.objects.all().order_by('-discovered_at')
return DirectorySnapshot.objects.all().order_by('-created_at')
def iter_raw_data_for_export(
self,
scan_id: int,
batch_size: int = 1000
) -> Iterator[dict]:
"""
Stream raw data for CSV export
Args:
scan_id: scan ID
batch_size: rows per batch
Yields:
dict containing all directory fields
"""
qs = (
DirectorySnapshot.objects
.filter(scan_id=scan_id)
.values(
'url', 'status', 'content_length', 'words',
'lines', 'content_type', 'duration', 'created_at'
)
.order_by('url')
)
for row in qs.iterator(chunk_size=batch_size):
yield row

View File

@@ -1,11 +1,12 @@
"""EndpointSnapshot Repository - Django ORM 实现"""
import logging
from typing import List
from typing import List, Iterator
from apps.asset.models.snapshot_models import EndpointSnapshot
from apps.asset.dtos.snapshot import EndpointSnapshotDTO
from apps.common.decorators import auto_ensure_db_connection
from apps.common.utils import deduplicate_for_bulk
logger = logging.getLogger(__name__)
@@ -18,6 +19,8 @@ class DjangoEndpointSnapshotRepository:
"""
Save endpoint snapshots
Note: items are automatically deduplicated by (scan_id, url), keeping the last record.
Args:
items: list of endpoint snapshot DTOs
@@ -31,10 +34,13 @@ class DjangoEndpointSnapshotRepository:
if not items:
logger.debug("端点快照为空,跳过保存")
return
# 根据模型唯一约束自动去重
unique_items = deduplicate_for_bulk(items, EndpointSnapshot)
# 构建快照对象
snapshots = []
for item in items:
for item in unique_items:
snapshots.append(EndpointSnapshot(
scan_id=item.scan_id,
url=item.url,
@@ -68,7 +74,36 @@ class DjangoEndpointSnapshotRepository:
raise
def get_by_scan(self, scan_id: int):
return EndpointSnapshot.objects.filter(scan_id=scan_id).order_by('-discovered_at')
return EndpointSnapshot.objects.filter(scan_id=scan_id).order_by('-created_at')
def get_all(self):
return EndpointSnapshot.objects.all().order_by('-discovered_at')
return EndpointSnapshot.objects.all().order_by('-created_at')
def iter_raw_data_for_export(
self,
scan_id: int,
batch_size: int = 1000
) -> Iterator[dict]:
"""
Stream raw data for CSV export
Args:
scan_id: scan ID
batch_size: rows per batch
Yields:
dict containing all endpoint fields
"""
qs = (
EndpointSnapshot.objects
.filter(scan_id=scan_id)
.values(
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'body_preview', 'vhost', 'matched_gf_patterns', 'created_at'
)
.order_by('url')
)
for row in qs.iterator(chunk_size=batch_size):
yield row

View File

@@ -6,6 +6,7 @@ from typing import List, Iterator
from apps.asset.models.snapshot_models import HostPortMappingSnapshot
from apps.asset.dtos.snapshot import HostPortMappingSnapshotDTO
from apps.common.decorators import auto_ensure_db_connection
from apps.common.utils import deduplicate_for_bulk
logger = logging.getLogger(__name__)
@@ -18,6 +19,8 @@ class DjangoHostPortMappingSnapshotRepository:
"""
Save host-port mapping snapshots
Note: items are automatically deduplicated by (scan_id, host, ip, port), keeping the last record.
Args:
items: list of host-port mapping snapshot DTOs
@@ -31,10 +34,13 @@ class DjangoHostPortMappingSnapshotRepository:
if not items:
logger.debug("主机端口关联快照为空,跳过保存")
return
# 根据模型唯一约束自动去重
unique_items = deduplicate_for_bulk(items, HostPortMappingSnapshot)
# 构建快照对象
snapshots = []
for item in items:
for item in unique_items:
snapshots.append(HostPortMappingSnapshot(
scan_id=item.scan_id,
host=item.host,
@@ -59,20 +65,28 @@ class DjangoHostPortMappingSnapshotRepository:
)
raise
def get_ip_aggregation_by_scan(self, scan_id: int, search: str = None):
def get_ip_aggregation_by_scan(self, scan_id: int, filter_query: str = None):
from django.db.models import Min
from apps.common.utils.filter_utils import apply_filters
qs = HostPortMappingSnapshot.objects.filter(scan_id=scan_id)
if search:
qs = qs.filter(ip__icontains=search)
# Apply smart filtering
if filter_query:
field_mapping = {
'ip': 'ip',
'port': 'port',
'host': 'host',
}
qs = apply_filters(qs, filter_query, field_mapping)
ip_aggregated = (
qs
.values('ip')
.annotate(
discovered_at=Min('discovered_at')
created_at=Min('created_at')
)
.order_by('-discovered_at')
.order_by('-created_at')
)
results = []
@@ -92,24 +106,32 @@ class DjangoHostPortMappingSnapshotRepository:
'ip': ip,
'hosts': hosts,
'ports': ports,
'discovered_at': item['discovered_at'],
'created_at': item['created_at'],
})
return results
def get_all_ip_aggregation(self, search: str = None):
def get_all_ip_aggregation(self, filter_query: str = None):
"""获取所有 IP 聚合数据"""
from django.db.models import Min
from apps.common.utils.filter_utils import apply_filters
qs = HostPortMappingSnapshot.objects.all()
if search:
qs = qs.filter(ip__icontains=search)
# Apply smart filtering
if filter_query:
field_mapping = {
'ip': 'ip',
'port': 'port',
'host': 'host',
}
qs = apply_filters(qs, filter_query, field_mapping)
ip_aggregated = (
qs
.values('ip')
.annotate(discovered_at=Min('discovered_at'))
.order_by('-discovered_at')
.annotate(created_at=Min('created_at'))
.order_by('-created_at')
)
results = []
@@ -127,7 +149,7 @@ class DjangoHostPortMappingSnapshotRepository:
'ip': ip,
'hosts': hosts,
'ports': ports,
'discovered_at': item['discovered_at'],
'created_at': item['created_at'],
})
return results
@@ -143,3 +165,33 @@ class DjangoHostPortMappingSnapshotRepository:
)
for ip in queryset:
yield ip
def iter_raw_data_for_export(
self,
scan_id: int,
batch_size: int = 1000
) -> Iterator[dict]:
"""
Stream raw data for CSV export
Args:
scan_id: scan ID
batch_size: rows per batch
Yields:
{
'ip': '192.168.1.1',
'host': 'example.com',
'port': 80,
'created_at': datetime
}
"""
qs = (
HostPortMappingSnapshot.objects
.filter(scan_id=scan_id)
.values('ip', 'host', 'port', 'created_at')
.order_by('ip', 'host', 'port')
)
for row in qs.iterator(chunk_size=batch_size):
yield row

View File

@@ -1,11 +1,12 @@
"""Django ORM 实现的 SubdomainSnapshot Repository"""
import logging
from typing import List
from typing import List, Iterator
from apps.asset.models.snapshot_models import SubdomainSnapshot
from apps.asset.dtos import SubdomainSnapshotDTO
from apps.common.decorators import auto_ensure_db_connection
from apps.common.utils import deduplicate_for_bulk
logger = logging.getLogger(__name__)
@@ -18,6 +19,8 @@ class DjangoSubdomainSnapshotRepository:
"""
Save subdomain snapshots
Note: items are automatically deduplicated by (scan_id, name), keeping the last record.
Args:
items: list of subdomain snapshot DTOs
@@ -31,10 +34,13 @@ class DjangoSubdomainSnapshotRepository:
if not items:
logger.debug("子域名快照为空,跳过保存")
return
# 根据模型唯一约束自动去重
unique_items = deduplicate_for_bulk(items, SubdomainSnapshot)
# 构建快照对象
snapshots = []
for item in items:
for item in unique_items:
snapshots.append(SubdomainSnapshot(
scan_id=item.scan_id,
name=item.name,
@@ -55,7 +61,32 @@ class DjangoSubdomainSnapshotRepository:
raise
def get_by_scan(self, scan_id: int):
return SubdomainSnapshot.objects.filter(scan_id=scan_id).order_by('-discovered_at')
return SubdomainSnapshot.objects.filter(scan_id=scan_id).order_by('-created_at')
def get_all(self):
return SubdomainSnapshot.objects.all().order_by('-discovered_at')
return SubdomainSnapshot.objects.all().order_by('-created_at')
def iter_raw_data_for_export(
self,
scan_id: int,
batch_size: int = 1000
) -> Iterator[dict]:
"""
Stream raw data for CSV export
Args:
scan_id: scan ID
batch_size: rows per batch
Yields:
{'name': 'sub.example.com', 'created_at': datetime}
"""
qs = (
SubdomainSnapshot.objects
.filter(scan_id=scan_id)
.values('name', 'created_at')
.order_by('name')
)
for row in qs.iterator(chunk_size=batch_size):
yield row

View File

@@ -8,6 +8,7 @@ from django.db import transaction
from apps.asset.models import VulnerabilitySnapshot
from apps.asset.dtos.snapshot import VulnerabilitySnapshotDTO
from apps.common.decorators import auto_ensure_db_connection
from apps.common.utils import deduplicate_for_bulk
logger = logging.getLogger(__name__)
@@ -21,12 +22,17 @@ class DjangoVulnerabilitySnapshotRepository:
Uses the ``ignore_conflicts`` strategy: snapshots that already exist are skipped.
The exact unique constraint is defined on the database model.
Note: items are automatically deduplicated by the unique-constraint fields, keeping the last record.
"""
if not items:
logger.warning("漏洞快照列表为空,跳过保存")
return
try:
# Deduplicate automatically by the model's unique constraints
unique_items = deduplicate_for_bulk(items, VulnerabilitySnapshot)
snapshot_objects = [
VulnerabilitySnapshot(
scan_id=item.scan_id,
@@ -38,7 +44,7 @@ class DjangoVulnerabilitySnapshotRepository:
description=item.description,
raw_output=item.raw_output,
)
for item in items
for item in unique_items
]
with transaction.atomic():
@@ -47,7 +53,7 @@ class DjangoVulnerabilitySnapshotRepository:
ignore_conflicts=True,
)
logger.debug("成功保存 %d 条漏洞快照记录", len(items))
logger.debug("成功保存 %d 条漏洞快照记录", len(unique_items))
except Exception as e:
logger.error(
@@ -60,7 +66,7 @@ class DjangoVulnerabilitySnapshotRepository:
def get_by_scan(self, scan_id: int):
"""按扫描任务获取漏洞快照 QuerySet。"""
return VulnerabilitySnapshot.objects.filter(scan_id=scan_id).order_by("-discovered_at")
return VulnerabilitySnapshot.objects.filter(scan_id=scan_id).order_by("-created_at")
def get_all(self):
return VulnerabilitySnapshot.objects.all().order_by('-discovered_at')
return VulnerabilitySnapshot.objects.all().order_by('-created_at')

View File

@@ -1,11 +1,12 @@
"""WebsiteSnapshot Repository - Django ORM 实现"""
import logging
from typing import List
from typing import List, Iterator
from apps.asset.models.snapshot_models import WebsiteSnapshot
from apps.asset.dtos.snapshot import WebsiteSnapshotDTO
from apps.common.decorators import auto_ensure_db_connection
from apps.common.utils import deduplicate_for_bulk
logger = logging.getLogger(__name__)
@@ -18,6 +19,8 @@ class DjangoWebsiteSnapshotRepository:
"""
Save website snapshots
Note: items are automatically deduplicated by (scan_id, url), keeping the last record.
Args:
items: list of website snapshot DTOs
@@ -31,10 +34,13 @@ class DjangoWebsiteSnapshotRepository:
if not items:
logger.debug("网站快照为空,跳过保存")
return
# 根据模型唯一约束自动去重
unique_items = deduplicate_for_bulk(items, WebsiteSnapshot)
# 构建快照对象
snapshots = []
for item in items:
for item in unique_items:
snapshots.append(WebsiteSnapshot(
scan_id=item.scan_id,
url=item.url,
@@ -68,7 +74,50 @@ class DjangoWebsiteSnapshotRepository:
raise
def get_by_scan(self, scan_id: int):
return WebsiteSnapshot.objects.filter(scan_id=scan_id).order_by('-discovered_at')
return WebsiteSnapshot.objects.filter(scan_id=scan_id).order_by('-created_at')
def get_all(self):
return WebsiteSnapshot.objects.all().order_by('-discovered_at')
return WebsiteSnapshot.objects.all().order_by('-created_at')
def iter_raw_data_for_export(
self,
scan_id: int,
batch_size: int = 1000
) -> Iterator[dict]:
"""
Stream raw data for CSV export
Args:
scan_id: scan ID
batch_size: rows per batch
Yields:
dict containing all website fields
"""
qs = (
WebsiteSnapshot.objects
.filter(scan_id=scan_id)
.values(
'url', 'host', 'location', 'title', 'status',
'content_length', 'content_type', 'web_server', 'tech',
'body_preview', 'vhost', 'created_at'
)
.order_by('url')
)
for row in qs.iterator(chunk_size=batch_size):
# Rename fields to match the CSV headers
yield {
'url': row['url'],
'host': row['host'],
'location': row['location'],
'title': row['title'],
'status_code': row['status'],
'content_length': row['content_length'],
'content_type': row['content_type'],
'webserver': row['web_server'],
'tech': row['tech'],
'body_preview': row['body_preview'],
'vhost': row['vhost'],
'created_at': row['created_at'],
}
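Because the iterator renames 'status' to 'status_code' and 'web_server' to 'webserver' up front, downstream CSV code can stay purely header-driven. An illustrative consumer; the function name and header list are assumptions, not from this diff:

import csv
import sys

HEADERS = ['url', 'host', 'location', 'title', 'status_code', 'content_length',
           'content_type', 'webserver', 'tech', 'body_preview', 'vhost', 'created_at']

def write_snapshot_csv(repo, scan_id, out=sys.stdout):
    writer = csv.DictWriter(out, fieldnames=HEADERS)
    writer.writeheader()
    for row in repo.iter_raw_data_for_export(scan_id):
        row['tech'] = ';'.join(row['tech'] or [])  # flatten the JSON list field
        writer.writerow(row)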

View File

@@ -26,9 +26,9 @@ class SubdomainSerializer(serializers.ModelSerializer):
class Meta:
model = Subdomain
fields = [
'id', 'name', 'discovered_at', 'target'
'id', 'name', 'created_at', 'target'
]
read_only_fields = ['id', 'discovered_at']
read_only_fields = ['id', 'created_at']
class SubdomainListSerializer(serializers.ModelSerializer):
@@ -41,9 +41,9 @@ class SubdomainListSerializer(serializers.ModelSerializer):
class Meta:
model = Subdomain
fields = [
'id', 'name', 'discovered_at'
'id', 'name', 'created_at'
]
read_only_fields = ['id', 'discovered_at']
read_only_fields = ['id', 'created_at']
# class IPAddressListSerializer(serializers.ModelSerializer):
@@ -87,7 +87,7 @@ class WebSiteSerializer(serializers.ModelSerializer):
'tech',
'vhost',
'subdomain',
'discovered_at',
'created_at',
]
read_only_fields = fields
@@ -107,7 +107,7 @@ class VulnerabilitySerializer(serializers.ModelSerializer):
'cvss_score',
'description',
'raw_output',
'discovered_at',
'created_at',
]
read_only_fields = fields
@@ -126,7 +126,7 @@ class VulnerabilitySnapshotSerializer(serializers.ModelSerializer):
'cvss_score',
'description',
'raw_output',
'discovered_at',
'created_at',
]
read_only_fields = fields
@@ -134,8 +134,8 @@ class VulnerabilitySnapshotSerializer(serializers.ModelSerializer):
class EndpointListSerializer(serializers.ModelSerializer):
"""端点列表序列化器(用于目标端点列表页)"""
# GF 匹配模式映射为前端使用的 tags 字段
tags = serializers.ListField(
# GF 匹配模式gf-patterns 工具匹配的敏感 URL 模式)
gfPatterns = serializers.ListField(
child=serializers.CharField(),
source='matched_gf_patterns',
read_only=True,
@@ -155,8 +155,8 @@ class EndpointListSerializer(serializers.ModelSerializer):
'body_preview',
'tech',
'vhost',
'tags',
'discovered_at',
'gfPatterns',
'created_at',
]
read_only_fields = fields
@@ -164,8 +164,7 @@ class EndpointListSerializer(serializers.ModelSerializer):
class DirectorySerializer(serializers.ModelSerializer):
"""目录序列化器"""
website_url = serializers.CharField(source='website.url', read_only=True)
discovered_at = serializers.DateTimeField(read_only=True)
created_at = serializers.DateTimeField(read_only=True)
class Meta:
model = Directory
@@ -178,8 +177,7 @@ class DirectorySerializer(serializers.ModelSerializer):
'lines',
'content_type',
'duration',
'website_url',
'discovered_at',
'created_at',
]
read_only_fields = fields
@@ -192,12 +190,12 @@ class IPAddressAggregatedSerializer(serializers.Serializer):
- ip: IP address
- hosts: list of all hostnames associated with this IP
- ports: list of all ports associated with this IP
- discovered_at: time of first discovery
- created_at: creation time
"""
ip = serializers.IPAddressField(read_only=True)
hosts = serializers.ListField(child=serializers.CharField(), read_only=True)
ports = serializers.ListField(child=serializers.IntegerField(), read_only=True)
discovered_at = serializers.DateTimeField(read_only=True)
created_at = serializers.DateTimeField(read_only=True)
# ==================== Snapshot serializers ====================
@@ -207,7 +205,7 @@ class SubdomainSnapshotSerializer(serializers.ModelSerializer):
class Meta:
model = SubdomainSnapshot
fields = ['id', 'name', 'discovered_at']
fields = ['id', 'name', 'created_at']
read_only_fields = fields
@@ -233,7 +231,7 @@ class WebsiteSnapshotSerializer(serializers.ModelSerializer):
'tech',
'vhost',
'subdomain_name',
'discovered_at',
'created_at',
]
read_only_fields = fields
@@ -241,9 +239,6 @@ class WebsiteSnapshotSerializer(serializers.ModelSerializer):
class DirectorySnapshotSerializer(serializers.ModelSerializer):
"""目录快照序列化器(用于扫描历史)"""
# DirectorySnapshot 当前不再关联 Website这里暂时将 website_url 映射为自身的 url保证字段兼容
website_url = serializers.CharField(source='url', read_only=True)
class Meta:
model = DirectorySnapshot
fields = [
@@ -255,8 +250,7 @@ class DirectorySnapshotSerializer(serializers.ModelSerializer):
'lines',
'content_type',
'duration',
'website_url',
'discovered_at',
'created_at',
]
read_only_fields = fields
@@ -264,8 +258,8 @@ class DirectorySnapshotSerializer(serializers.ModelSerializer):
class EndpointSnapshotSerializer(serializers.ModelSerializer):
"""端点快照序列化器(用于扫描历史)"""
# GF 匹配模式映射为前端使用的 tags 字段
tags = serializers.ListField(
# GF 匹配模式gf-patterns 工具匹配的敏感 URL 模式)
gfPatterns = serializers.ListField(
child=serializers.CharField(),
source='matched_gf_patterns',
read_only=True,
@@ -286,7 +280,7 @@ class EndpointSnapshotSerializer(serializers.ModelSerializer):
'body_preview',
'tech',
'vhost',
'tags',
'discovered_at',
'gfPatterns',
'created_at',
]
read_only_fields = fields
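The tags to gfPatterns rename relies on DRF's source argument, which decouples the API field name from the model field. A standalone sketch of the same idiom; the serializer name is hypothetical:

from rest_framework import serializers

class ExampleEndpointSerializer(serializers.Serializer):
    # The model stores matched_gf_patterns; the API exposes it as gfPatterns
    url = serializers.CharField()
    gfPatterns = serializers.ListField(
        child=serializers.CharField(),
        source='matched_gf_patterns',
        read_only=True,
    )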

View File

@@ -1,8 +1,12 @@
import logging
from typing import Tuple, Iterator
"""Directory Service - 目录业务逻辑层"""
import logging
from typing import List, Iterator, Optional
from apps.asset.models.asset_models import Directory
from apps.asset.repositories import DjangoDirectoryRepository
from apps.asset.dtos import DirectoryDTO
from apps.common.validators import is_valid_url, is_url_match_target
from apps.common.utils.filter_utils import apply_filters
logger = logging.getLogger(__name__)
@@ -10,46 +14,122 @@ logger = logging.getLogger(__name__)
class DirectoryService:
"""目录业务逻辑层"""
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'url': 'url',
'status': 'status',
}
def __init__(self, repository=None):
"""
Initialize the directory service
Args:
repository: directory repository instance (for dependency injection)
"""
"""Initialize the directory service"""
self.repo = repository or DjangoDirectoryRepository()
# ==================== Create operations ====================
def bulk_create_ignore_conflicts(self, directory_dtos: list) -> None:
def bulk_upsert(self, directory_dtos: List[DirectoryDTO]) -> int:
"""
Bulk-create directory records, ignoring conflicts (used by scan tasks)
Bulk-create or update directories (upsert)
Updates every field when a record exists, creates it otherwise.
Args:
directory_dtos: list of DirectoryDTOs
"""
return self.repo.bulk_create_ignore_conflicts(directory_dtos)
# ==================== Query operations ====================
def get_all(self):
"""
Get all directories
Returns:
QuerySet: directory queryset
int: number of records processed
"""
logger.debug("Fetching all directories")
return self.repo.get_all()
if not directory_dtos:
return 0
try:
return self.repo.bulk_upsert(directory_dtos)
except Exception as e:
logger.error(f"批量 upsert 目录失败: {e}")
raise
def get_directories_by_target(self, target_id: int):
logger.debug("获取目标下所有目录 - Target ID: %d", target_id)
return self.repo.get_by_target(target_id)
def bulk_create_urls(self, target_id: int, target_name: str, target_type: str, urls: List[str]) -> int:
"""
批量创建目录(仅 URL使用 ignore_conflicts
验证 URL 格式和匹配,过滤无效/不匹配 URL去重后批量创建。
已存在的记录会被跳过。
Args:
target_id: 目标 ID
target_name: 目标名称(用于匹配验证)
target_type: 目标类型 ('domain', 'ip', 'cidr')
urls: URL 列表
Returns:
int: 实际创建的记录数
"""
if not urls:
return 0
# Filter valid URLs and deduplicate
valid_urls = []
seen = set()
for url in urls:
if not isinstance(url, str):
continue
url = url.strip()
if not url or url in seen:
continue
if not is_valid_url(url):
continue
# Match validation (the frontend already blocks mismatched submissions; the backend is a second safeguard)
if not is_url_match_target(url, target_name, target_type):
continue
seen.add(url)
valid_urls.append(url)
if not valid_urls:
return 0
# Record the count before creation
count_before = self.repo.count_by_target(target_id)
# Build the DTO list and bulk-create
directory_dtos = [
DirectoryDTO(url=url, target_id=target_id)
for url in valid_urls
]
self.repo.bulk_create_ignore_conflicts(directory_dtos)
# Record the count after creation
count_after = self.repo.count_by_target(target_id)
return count_after - count_before
def get_directories_by_target(self, target_id: int, filter_query: Optional[str] = None):
"""获取目标下的所有目录"""
queryset = self.repo.get_by_target(target_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_all(self, filter_query: Optional[str] = None):
"""获取所有目录"""
queryset = self.repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def iter_directory_urls_by_target(self, target_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取目标下的所有目录 URL,用于导出大批量数据。"""
logger.debug("流式导出目标下目录 URL - Target ID: %d", target_id)
"""流式获取目标下的所有目录 URL"""
return self.repo.get_urls_for_export(target_id=target_id, batch_size=chunk_size)
def iter_raw_data_for_csv_export(self, target_id: int) -> Iterator[dict]:
"""
Stream raw data for CSV export
Args:
target_id: target ID
Yields:
raw data dicts
"""
return self.repo.iter_raw_data_for_export(target_id=target_id)
__all__ = ['DirectoryService']
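apply_filters and the per-service FILTER_FIELD_MAPPING dictionaries appear throughout these services, but filter_utils itself is not part of this diff. A minimal stand-in consistent with the call sites might look like the following; the real parser very likely supports richer operators than plain equality:

from django.db.models import QuerySet

def apply_filters(queryset: QuerySet, filter_query: str, field_mapping: dict) -> QuerySet:
    """Sketch: parse space-separated 'key=value' tokens and AND them together."""
    for token in filter_query.split():
        if '=' not in token:
            continue
        key, _, value = token.partition('=')
        field = field_mapping.get(key)
        if field:
            # Map the public filter key to the real model field
            queryset = queryset.filter(**{f'{field}__icontains': value})
    return queryset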

View File

@@ -5,10 +5,12 @@ Endpoint service layer
"""
import logging
from typing import List, Optional, Dict, Any, Iterator
from typing import List, Iterator, Optional
from apps.asset.dtos.asset import EndpointDTO
from apps.asset.repositories.asset import DjangoEndpointRepository
from apps.common.validators import is_valid_url, is_url_match_target
from apps.common.utils.filter_utils import apply_filters
logger = logging.getLogger(__name__)
@@ -20,101 +22,101 @@ class EndpointService:
Provides business logic for Endpoints (URLs)
"""
# Smart-filter field mapping
FILTER_FIELD_MAPPING = {
'url': 'url',
'host': 'host',
'title': 'title',
'status': 'status_code',
}
def __init__(self):
"""初始化 Endpoint 服务"""
self.repo = DjangoEndpointRepository()
def bulk_create_endpoints(
self,
endpoints: List[EndpointDTO],
ignore_conflicts: bool = True
) -> int:
def bulk_upsert(self, endpoints: List[EndpointDTO]) -> int:
"""
Bulk-create endpoint records
Bulk-create or update endpoints (upsert)
Updates every field when a record exists, creates it otherwise.
Args:
endpoints: list of endpoint data
ignore_conflicts: whether to ignore conflicts (deduplicate)
Returns:
int: number of records created
int: number of records processed
"""
if not endpoints:
return 0
try:
if ignore_conflicts:
return self.repo.bulk_create_ignore_conflicts(endpoints)
else:
# If a non-ignore-conflicts version is ever needed, add it in the repository
return self.repo.bulk_create_ignore_conflicts(endpoints)
return self.repo.bulk_upsert(endpoints)
except Exception as e:
logger.error(f"批量创建端点失败: {e}")
logger.error(f"批量 upsert 端点失败: {e}")
raise
def get_endpoints_by_website(
self,
website_id: int,
limit: Optional[int] = None
) -> List[Dict[str, Any]]:
def bulk_create_urls(self, target_id: int, target_name: str, target_type: str, urls: List[str]) -> int:
"""
Get the endpoint list under a website
Bulk-create endpoints (URL only), using ignore_conflicts
Args:
website_id: website ID
limit: maximum number of results
Returns:
List[Dict]: endpoint list
"""
endpoints_dto = self.repo.get_by_website(website_id)
if limit:
endpoints_dto = endpoints_dto[:limit]
endpoints = []
for dto in endpoints_dto:
endpoints.append({
'url': dto.url,
'title': dto.title,
'status_code': dto.status_code,
'content_length': dto.content_length,
'webserver': dto.webserver
})
return endpoints
def get_endpoints_by_target(
self,
target_id: int,
limit: Optional[int] = None
) -> List[Dict[str, Any]]:
"""
Get the endpoint list under a target
Validates URL format and target matching, filters invalid/mismatched URLs, deduplicates, then bulk-creates.
Existing records are skipped.
Args:
target_id: target ID
limit: maximum number of results
target_name: target name (for match validation)
target_type: target type ('domain', 'ip', 'cidr')
urls: list of URLs
Returns:
List[Dict]: endpoint list
int: number of records actually created
"""
endpoints_dto = self.repo.get_by_target(target_id)
if not urls:
return 0
if limit:
endpoints_dto = endpoints_dto[:limit]
# Filter valid URLs and deduplicate
valid_urls = []
seen = set()
endpoints = []
for dto in endpoints_dto:
endpoints.append({
'url': dto.url,
'title': dto.title,
'status_code': dto.status_code,
'content_length': dto.content_length,
'webserver': dto.webserver
})
for url in urls:
if not isinstance(url, str):
continue
url = url.strip()
if not url or url in seen:
continue
if not is_valid_url(url):
continue
# Match validation (the frontend already blocks mismatched submissions; the backend is a second safeguard)
if not is_url_match_target(url, target_name, target_type):
continue
seen.add(url)
valid_urls.append(url)
return endpoints
if not valid_urls:
return 0
# Record the count before creation
count_before = self.repo.count_by_target(target_id)
# Build the DTO list and bulk-create
endpoint_dtos = [
EndpointDTO(url=url, target_id=target_id)
for url in valid_urls
]
self.repo.bulk_create_ignore_conflicts(endpoint_dtos)
# Record the count after creation
count_after = self.repo.count_by_target(target_id)
return count_after - count_before
def get_endpoints_by_target(self, target_id: int, filter_query: Optional[str] = None):
"""获取目标下的所有端点"""
queryset = self.repo.get_by_target(target_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def count_endpoints_by_target(self, target_id: int) -> int:
"""
@@ -127,52 +129,28 @@ class EndpointService:
int: endpoint count
"""
return self.repo.count_by_target(target_id)
def get_queryset_by_target(self, target_id: int):
return self.repo.get_queryset_by_target(target_id)
def get_all(self):
def get_all(self, filter_query: Optional[str] = None):
"""获取所有端点(全局查询)"""
return self.repo.get_all()
queryset = self.repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def iter_endpoint_urls_by_target(self, target_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取目标下的所有端点 URL用于导出。"""
queryset = self.repo.get_queryset_by_target(target_id)
queryset = self.repo.get_by_target(target_id)
for url in queryset.values_list('url', flat=True).iterator(chunk_size=chunk_size):
yield url
def count_endpoints_by_website(self, website_id: int) -> int:
def iter_raw_data_for_csv_export(self, target_id: int) -> Iterator[dict]:
"""
Count endpoints under a website
Stream raw data for CSV export
Args:
website_id: website ID
Returns:
int: endpoint count
"""
return self.repo.count_by_website(website_id)
def soft_delete_endpoints(self, endpoint_ids: List[int]) -> int:
"""
Soft-delete endpoints
target_id: target ID
Args:
endpoint_ids: list of endpoint IDs
Returns:
int: number of rows updated
Yields:
raw data dicts
"""
return self.repo.soft_delete_by_ids(endpoint_ids)
def hard_delete_endpoints(self, endpoint_ids: List[int]) -> tuple:
"""
Hard-delete endpoints
Args:
endpoint_ids: list of endpoint IDs
Returns:
tuple: (total deleted, detail info)
"""
return self.repo.hard_delete_by_ids(endpoint_ids)
return self.repo.iter_raw_data_for_export(target_id=target_id)

View File

@@ -1,16 +1,31 @@
"""HostPortMapping Service - 业务逻辑层"""
import logging
from typing import List, Iterator
from typing import List, Iterator, Optional, Dict
from django.db.models import Min
from apps.asset.repositories.asset import DjangoHostPortMappingRepository
from apps.asset.dtos.asset import HostPortMappingDTO
from apps.common.utils.filter_utils import apply_filters
logger = logging.getLogger(__name__)
class HostPortMappingService:
"""主机端口映射服务 - 负责主机端口映射数据的业务逻辑"""
"""主机端口映射服务 - 负责主机端口映射数据的业务逻辑
职责:
- 业务逻辑处理(过滤、聚合)
- 调用 Repository 进行数据访问
"""
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'ip': 'ip',
'port': 'port',
'host': 'host',
}
def __init__(self):
self.repo = DjangoHostPortMappingRepository()
@@ -49,13 +64,106 @@ class HostPortMappingService:
def iter_host_port_by_target(self, target_id: int, batch_size: int = 1000):
return self.repo.get_for_export(target_id=target_id, batch_size=batch_size)
def get_ip_aggregation_by_target(self, target_id: int, search: str = None):
return self.repo.get_ip_aggregation_by_target(target_id, search=search)
def get_ip_aggregation_by_target(
self,
target_id: int,
filter_query: Optional[str] = None
) -> List[Dict]:
"""获取目标下的 IP 聚合数据
Args:
target_id: 目标 ID
filter_query: 智能过滤语法字符串
Returns:
聚合后的 IP 数据列表
"""
# 从 Repository 获取基础 QuerySet
qs = self.repo.get_queryset_by_target(target_id)
# Service 层应用过滤逻辑
if filter_query:
qs = apply_filters(qs, filter_query, self.FILTER_FIELD_MAPPING)
# Service 层处理聚合逻辑
return self._aggregate_by_ip(qs, filter_query, target_id=target_id)
def get_all_ip_aggregation(self, search: str = None):
"""获取所有 IP 聚合数据(全局查询)"""
return self.repo.get_all_ip_aggregation(search=search)
def get_all_ip_aggregation(self, filter_query: Optional[str] = None) -> List[Dict]:
"""获取所有 IP 聚合数据(全局查询)
Args:
filter_query: 智能过滤语法字符串
Returns:
聚合后的 IP 数据列表
"""
# 从 Repository 获取基础 QuerySet
qs = self.repo.get_all_queryset()
# Service 层应用过滤逻辑
if filter_query:
qs = apply_filters(qs, filter_query, self.FILTER_FIELD_MAPPING)
# Service 层处理聚合逻辑
return self._aggregate_by_ip(qs, filter_query)
def _aggregate_by_ip(
self,
qs,
filter_query: Optional[str] = None,
target_id: Optional[int] = None
) -> List[Dict]:
"""按 IP 聚合数据
Args:
qs: 已过滤的 QuerySet
filter_query: 过滤条件(用于子查询)
target_id: 目标 ID用于子查询限定范围
Returns:
聚合后的数据列表
"""
ip_aggregated = (
qs
.values('ip')
.annotate(created_at=Min('created_at'))
.order_by('-created_at')
)
results = []
for item in ip_aggregated:
ip = item['ip']
# Get all hosts and ports for this IP (filter conditions apply here too)
mappings_qs = self.repo.get_queryset_by_ip(ip, target_id=target_id)
if filter_query:
mappings_qs = apply_filters(mappings_qs, filter_query, self.FILTER_FIELD_MAPPING)
mappings = mappings_qs.values('host', 'port').distinct()
hosts = sorted({m['host'] for m in mappings})
ports = sorted({m['port'] for m in mappings})
results.append({
'ip': ip,
'hosts': hosts,
'ports': ports,
'created_at': item['created_at'],
})
return results
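The first query in _aggregate_by_ip is a plain GROUP BY in disguise. Roughly, assuming a HostPortMapping model with ip and created_at columns (the import path is an assumption):

from django.db.models import Min
from apps.asset.models.asset_models import HostPortMapping  # import path assumed

# Roughly: SELECT ip, MIN(created_at) AS created_at
#          FROM host_port_mapping GROUP BY ip ORDER BY created_at DESC
ips = (
    HostPortMapping.objects
    .values('ip')                            # GROUP BY ip
    .annotate(created_at=Min('created_at'))  # MIN(created_at) per group
    .order_by('-created_at')
)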
def iter_ips_by_target(self, target_id: int, batch_size: int = 1000) -> Iterator[str]:
"""流式获取目标下的所有唯一 IP 地址。"""
return self.repo.get_ips_for_export(target_id=target_id, batch_size=batch_size)
def iter_raw_data_for_csv_export(self, target_id: int) -> Iterator[dict]:
"""
Stream raw data for CSV export
Args:
target_id: target ID
Yields:
raw data dicts {ip, host, port, created_at}
"""
return self.repo.iter_raw_data_for_export(target_id=target_id)

View File

@@ -1,15 +1,33 @@
import logging
from typing import Tuple, List, Dict
from typing import List, Dict, Optional
from dataclasses import dataclass
from apps.asset.repositories import DjangoSubdomainRepository
from apps.asset.dtos import SubdomainDTO
from apps.common.validators import is_valid_domain
from apps.common.utils.filter_utils import apply_filters
logger = logging.getLogger(__name__)
@dataclass
class BulkCreateResult:
"""批量创建结果"""
created_count: int
skipped_count: int
invalid_count: int
mismatched_count: int
total_received: int
class SubdomainService:
"""子域名业务逻辑层"""
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'name': 'name',
}
def __init__(self, repository=None):
"""
Initialize the subdomain service
@@ -21,44 +39,50 @@ class SubdomainService:
# ==================== Query operations ====================
def get_all(self):
def get_all(self, filter_query: Optional[str] = None):
"""
Get all subdomains
Args:
filter_query: smart-filter syntax string
Returns:
QuerySet: subdomain queryset
"""
logger.debug("Fetching all subdomains")
return self.repo.get_all()
queryset = self.repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
# ==================== Create operations ====================
def get_or_create(self, name: str, target_id: int) -> Tuple[any, bool]:
def get_subdomains_by_target(self, target_id: int, filter_query: Optional[str] = None):
"""
Get or create a subdomain
Get subdomains under a target
Args:
target_id: target ID
filter_query: smart-filter syntax string
Returns:
QuerySet: subdomain queryset
"""
queryset = self.repo.get_by_target(target_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def count_subdomains_by_target(self, target_id: int) -> int:
"""
Count subdomains under a target
Args:
name: subdomain name
target_id: target ID
Returns:
(Subdomain object, whether newly created)
int: subdomain count
"""
logger.debug("获取或创建子域名 - Name: %s, Target ID: %d", name, target_id)
return self.repo.get_or_create(name, target_id)
def bulk_create_ignore_conflicts(self, items: List[SubdomainDTO]) -> None:
"""
批量创建子域名,忽略冲突
Args:
items: 子域名 DTO 列表
Note:
使用 ignore_conflicts 策略,重复记录会被跳过
"""
logger.debug("批量创建子域名 - 数量: %d", len(items))
return self.repo.bulk_create_ignore_conflicts(items)
logger.debug("统计目标下子域名数量 - Target ID: %d", target_id)
return self.repo.count_by_target(target_id)
def get_by_names_and_target_id(self, names: set, target_id: int) -> dict:
"""
@@ -85,25 +109,8 @@ class SubdomainService:
List[str]: list of subdomain names
"""
logger.debug("Fetching all subdomains under target - Target ID: %d", target_id)
# Database access goes through the repository layer, which already streams via iterator()
return list(self.repo.get_domains_for_export(target_id=target_id))
def get_subdomains_by_target(self, target_id: int):
return self.repo.get_by_target(target_id)
def count_subdomains_by_target(self, target_id: int) -> int:
"""
Count subdomains under a target
Args:
target_id: target ID
Returns:
int: subdomain count
"""
logger.debug("Counting subdomains under target - Target ID: %d", target_id)
return self.repo.count_by_target(target_id)
def iter_subdomain_names_by_target(self, target_id: int, chunk_size: int = 1000):
"""
Stream all subdomain names under a target (memory-optimized)
@@ -116,8 +123,123 @@ class SubdomainService:
str: subdomain name
"""
logger.debug("Streaming all subdomains under target - Target ID: %d, batch size: %d", target_id, chunk_size)
# Database access goes through the repository layer, which already streams via iterator()
return self.repo.get_domains_for_export(target_id=target_id, batch_size=chunk_size)
def iter_raw_data_for_csv_export(self, target_id: int):
"""
Stream raw data for CSV export
Args:
target_id: target ID
Yields:
raw data dicts {name, created_at}
"""
return self.repo.iter_raw_data_for_export(target_id=target_id)
__all__ = ['SubdomainService']
# ==================== Create operations ====================
def bulk_create_ignore_conflicts(self, items: List[SubdomainDTO]) -> None:
"""
Bulk-create subdomains, ignoring conflicts
Args:
items: list of subdomain DTOs
Note:
Uses the ignore_conflicts strategy; duplicate records are skipped
"""
logger.debug("Bulk-creating subdomains - count: %d", len(items))
return self.repo.bulk_create_ignore_conflicts(items)
def bulk_create_subdomains(
self,
target_id: int,
target_name: str,
subdomains: List[str]
) -> BulkCreateResult:
"""
Bulk-create subdomains (with validation)
Args:
target_id: target ID
target_name: target domain (for match validation)
subdomains: list of subdomains
Returns:
BulkCreateResult: creation result statistics
"""
total_received = len(subdomains)
target_name = target_name.lower().strip()
def is_subdomain_match(subdomain: str) -> bool:
"""验证子域名是否匹配目标域名"""
if subdomain == target_name:
return True
if subdomain.endswith('.' + target_name):
return True
return False
# Filter valid subdomains
valid_subdomains = []
invalid_count = 0
mismatched_count = 0
for subdomain in subdomains:
if not isinstance(subdomain, str) or not subdomain.strip():
continue
subdomain = subdomain.lower().strip()
# Validate format
if not is_valid_domain(subdomain):
invalid_count += 1
continue
# Validate match
if not is_subdomain_match(subdomain):
mismatched_count += 1
continue
valid_subdomains.append(subdomain)
# Deduplicate
unique_subdomains = list(set(valid_subdomains))
duplicate_count = len(valid_subdomains) - len(unique_subdomains)
if not unique_subdomains:
return BulkCreateResult(
created_count=0,
skipped_count=duplicate_count,
invalid_count=invalid_count,
mismatched_count=mismatched_count,
total_received=total_received,
)
# Record the count before creation
count_before = self.repo.count_by_target(target_id)
# Build the DTO list and bulk-create
subdomain_dtos = [
SubdomainDTO(name=name, target_id=target_id)
for name in unique_subdomains
]
self.repo.bulk_create_ignore_conflicts(subdomain_dtos)
# Record the count after creation
count_after = self.repo.count_by_target(target_id)
created_count = count_after - count_before
# Compute the number skipped due to database conflicts
db_skipped = len(unique_subdomains) - created_count
return BulkCreateResult(
created_count=created_count,
skipped_count=duplicate_count + db_skipped,
invalid_count=invalid_count,
mismatched_count=mismatched_count,
total_received=total_received,
)
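A usage sketch of the validation counters, assuming the target's table is empty so exactly one row is actually created:

# Illustrative call; the counters let the API report exactly what happened.
service = SubdomainService()
result = service.bulk_create_subdomains(
    target_id=1,
    target_name='example.com',
    subdomains=['a.example.com', 'A.EXAMPLE.COM', 'bad domain', 'x.other.com'],
)
# -> created_count=1, skipped_count=1 (case-folded duplicate),
#    invalid_count=1, mismatched_count=1, total_received=4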
__all__ = ['SubdomainService', 'BulkCreateResult']

View File

@@ -1,11 +1,13 @@
"""Vulnerability Service - 漏洞资产业务逻辑层"""
import logging
from typing import List
from typing import List, Optional
from apps.asset.models import Vulnerability
from apps.asset.dtos.asset import VulnerabilityDTO
from apps.common.decorators import auto_ensure_db_connection
from apps.common.utils import deduplicate_for_bulk
from apps.common.utils.filter_utils import apply_filters
logger = logging.getLogger(__name__)
@@ -16,10 +18,20 @@ class VulnerabilityService:
Currently provides basic bulk-create capability, using ignore_conflicts and the database's unique constraints for deduplication.
"""
# Smart-filter field mapping
FILTER_FIELD_MAPPING = {
'type': 'vuln_type',
'severity': 'severity',
'source': 'source',
'url': 'url',
}
def bulk_create_ignore_conflicts(self, items: List[VulnerabilityDTO]) -> None:
"""批量创建漏洞资产记录,忽略冲突。
注意:会自动按 (target_id, url, vuln_type, source) 去重,保留最后一条记录。
Note:
- 是否去重取决于模型上的唯一/部分唯一约束;
- 当前 Vulnerability 模型未定义唯一约束,因此会保留全部记录。
@@ -29,6 +41,9 @@ class VulnerabilityService:
return
try:
# Deduplicate automatically by the model's unique constraints (skipped if the model has none)
unique_items = deduplicate_for_bulk(items, Vulnerability)
vulns = [
Vulnerability(
target_id=item.target_id,
@@ -40,7 +55,7 @@ class VulnerabilityService:
description=item.description,
raw_output=item.raw_output,
)
for item in items
for item in unique_items
]
Vulnerability.objects.bulk_create(vulns, ignore_conflicts=True)
@@ -57,24 +72,34 @@ class VulnerabilityService:
# ==================== Query methods ====================
def get_all(self):
def get_all(self, filter_query: Optional[str] = None):
"""获取所有漏洞 QuerySet用于全局漏洞列表
Returns:
QuerySet[Vulnerability]: 所有漏洞,按发现时间倒序
"""
return Vulnerability.objects.filter(deleted_at__isnull=True).order_by("-discovered_at")
Args:
filter_query: smart-filter syntax string
def get_queryset_by_target(self, target_id: int):
Returns:
QuerySet[Vulnerability]: all vulnerabilities, newest first by creation time
"""
queryset = Vulnerability.objects.all().order_by("-created_at")
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_vulnerabilities_by_target(self, target_id: int, filter_query: Optional[str] = None):
"""按目标获取漏洞 QuerySet用于分页
Args:
target_id: 目标 ID
filter_query: 智能过滤语法字符串
Returns:
QuerySet[Vulnerability]: 目标下的所有漏洞,按发现时间倒序
QuerySet[Vulnerability]: 目标下的所有漏洞,按创建时间倒序
"""
return Vulnerability.objects.filter(target_id=target_id).order_by("-discovered_at")
queryset = Vulnerability.objects.filter(target_id=target_id).order_by("-created_at")
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def count_by_target(self, target_id: int) -> int:
"""统计目标下的漏洞数量。"""

View File

@@ -1,8 +1,12 @@
"""WebSite Service - 网站业务逻辑层"""
import logging
from typing import Tuple, List
from typing import List, Iterator, Optional
from apps.asset.repositories import DjangoWebSiteRepository
from apps.asset.dtos import WebSiteDTO
from apps.common.validators import is_valid_url, is_url_match_target
from apps.common.utils.filter_utils import apply_filters
logger = logging.getLogger(__name__)
@@ -10,82 +14,128 @@ logger = logging.getLogger(__name__)
class WebSiteService:
"""网站业务逻辑层"""
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'url': 'url',
'host': 'host',
'title': 'title',
'status': 'status_code',
}
def __init__(self, repository=None):
"""
Initialize the website service
Args:
repository: website repository instance (for dependency injection)
"""
"""Initialize the website service"""
self.repo = repository or DjangoWebSiteRepository()
# ==================== Create operations ====================
def bulk_create_ignore_conflicts(self, website_dtos: List[WebSiteDTO]) -> None:
def bulk_upsert(self, website_dtos: List[WebSiteDTO]) -> int:
"""
Bulk-create website records, ignoring conflicts (used by scan tasks)
Bulk-create or update websites (upsert)
Updates every field when a record exists, creates it otherwise.
Args:
website_dtos: list of WebSiteDTOs
Note:
Uses the ignore_conflicts strategy; duplicate records are skipped
"""
logger.debug("Bulk-creating websites - count: %d", len(website_dtos))
return self.repo.bulk_create_ignore_conflicts(website_dtos)
# ==================== Query operations ====================
def get_by_url(self, url: str, target_id: int) -> int:
"""
Find a website ID by URL and target_id
Args:
url: website URL
target_id: target ID
Returns:
int: website ID, or None if not found
int: number of records processed
"""
return self.repo.get_by_url(url=url, target_id=target_id)
# ==================== Query operations ====================
def get_all(self):
"""
Get all websites
if not website_dtos:
return 0
Returns:
QuerySet: website queryset
"""
logger.debug("Fetching all websites")
return self.repo.get_all()
try:
return self.repo.bulk_upsert(website_dtos)
except Exception as e:
logger.error(f"批量 upsert 网站失败: {e}")
raise
def get_websites_by_target(self, target_id: int):
return self.repo.get_by_target(target_id)
def count_websites_by_scan(self, scan_id: int) -> int:
def bulk_create_urls(self, target_id: int, target_name: str, target_type: str, urls: List[str]) -> int:
"""
Count websites under a scan
Bulk-create websites (URL only), using ignore_conflicts
Validates URL format and target matching, filters invalid/mismatched URLs, deduplicates, then bulk-creates.
Existing records are skipped.
Args:
scan_id: scan ID
target_id: target ID
target_name: target name (for match validation)
target_type: target type ('domain', 'ip', 'cidr')
urls: list of URLs
Returns:
int: website count
int: number of records actually created
"""
logger.debug("Counting websites under scan - Scan ID: %d", scan_id)
return self.repo.count_by_scan(scan_id)
if not urls:
return 0
# Filter valid URLs and deduplicate
valid_urls = []
seen = set()
for url in urls:
if not isinstance(url, str):
continue
url = url.strip()
if not url or url in seen:
continue
if not is_valid_url(url):
continue
# Match validation (the frontend already blocks mismatched submissions; the backend is a second safeguard)
if not is_url_match_target(url, target_name, target_type):
continue
seen.add(url)
valid_urls.append(url)
if not valid_urls:
return 0
# Record the count before creation
count_before = self.repo.count_by_target(target_id)
# Build the DTO list and bulk-create
website_dtos = [
WebSiteDTO(url=url, target_id=target_id)
for url in valid_urls
]
self.repo.bulk_create_ignore_conflicts(website_dtos)
# Record the count after creation
count_after = self.repo.count_by_target(target_id)
return count_after - count_before
def get_websites_by_target(self, target_id: int, filter_query: Optional[str] = None):
"""获取目标下的所有网站"""
queryset = self.repo.get_by_target(target_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_all(self, filter_query: Optional[str] = None):
"""获取所有网站"""
queryset = self.repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_by_url(self, url: str, target_id: int) -> int:
"""根据 URL 和 target_id 查找网站 ID"""
return self.repo.get_by_url(url=url, target_id=target_id)
def iter_website_urls_by_target(self, target_id: int, chunk_size: int = 1000):
"""流式获取目标下的所有站点 URL(内存优化,委托给 Repository 层)"""
logger.debug(
"流式获取目标下所有站点 URL - Target ID: %d, 批次大小: %d",
target_id,
chunk_size,
)
# 通过仓储层统一访问数据库,避免 Service 直接依赖 ORM
"""流式获取目标下的所有站点 URL"""
return self.repo.get_urls_for_export(target_id=target_id, batch_size=chunk_size)
def iter_raw_data_for_csv_export(self, target_id: int) -> Iterator[dict]:
"""
Stream raw data for CSV export
Args:
target_id: target ID
Yields:
raw data dicts
"""
return self.repo.iter_raw_data_for_export(target_id=target_id)
__all__ = ['WebSiteService']

View File

@@ -26,10 +26,9 @@ class DirectorySnapshotsService:
2. Sync to the asset table (deduplicated, without scan_id)
Args:
items: list of directory snapshot DTOs (must include website_id)
items: list of directory snapshot DTOs (must include target_id)
Raises:
ValueError: if an item's website_id is None
Exception: database operation failure
"""
if not items:
@@ -49,14 +48,13 @@ class DirectorySnapshotsService:
logger.debug("步骤 1: 保存到快照表")
self.snapshot_repo.save_snapshots(items)
# 步骤 2: 转换为资产 DTO 并保存到资产表
# 注意:去重是通过数据库的 UNIQUE 约束 + ignore_conflicts 实现的
# 步骤 2: 转换为资产 DTO 并保存到资产表upsert
# - 新记录:插入资产表
# - 已存在的记录:自动跳过
logger.debug("步骤 2: 同步到资产表(通过 Service 层)")
# - 已存在的记录:更新字段created_at 不更新,保留创建时间)
logger.debug("步骤 2: 同步到资产表(通过 Service 层upsert")
asset_items = [item.to_asset_dto() for item in items]
self.asset_service.bulk_create_ignore_conflicts(asset_items)
self.asset_service.bulk_upsert(asset_items)
logger.info("目录快照和资产数据保存成功 - 数量: %d", len(items))
@@ -69,15 +67,44 @@ class DirectorySnapshotsService:
)
raise
def get_by_scan(self, scan_id: int):
return self.snapshot_repo.get_by_scan(scan_id)
# Smart-filter field mapping
FILTER_FIELD_MAPPING = {
'url': 'url',
'status': 'status',
'content_type': 'content_type',
}
def get_by_scan(self, scan_id: int, filter_query: str = None):
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_by_scan(scan_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_all(self):
def get_all(self, filter_query: str = None):
"""获取所有目录快照"""
return self.snapshot_repo.get_all()
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def iter_directory_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取某次扫描下的所有目录 URL。"""
queryset = self.snapshot_repo.get_by_scan(scan_id)
for snapshot in queryset.iterator(chunk_size=chunk_size):
yield snapshot.url
def iter_raw_data_for_csv_export(self, scan_id: int) -> Iterator[dict]:
"""
Stream raw data for CSV export
Args:
scan_id: scan ID
Yields:
raw data dicts
"""
return self.snapshot_repo.iter_raw_data_for_export(scan_id=scan_id)

View File

@@ -50,13 +50,11 @@ class EndpointSnapshotsService:
self.snapshot_repo.save_snapshots(items)
# Step 2: convert to asset DTOs and save to the asset table
# Note: deduplication relies on the database UNIQUE constraint + ignore_conflicts
# - New records: inserted into the asset table
# - Existing records: skipped automatically
# Upsert: new records are inserted, existing records are updated
logger.debug("Step 2: sync to the asset table (via the Service layer)")
asset_items = [item.to_asset_dto() for item in items]
self.asset_service.bulk_create_endpoints(asset_items)
self.asset_service.bulk_upsert(asset_items)
logger.info("Endpoint snapshots and asset data saved - count: %d", len(items))
@@ -69,15 +67,47 @@ class EndpointSnapshotsService:
)
raise
def get_by_scan(self, scan_id: int):
return self.snapshot_repo.get_by_scan(scan_id)
# Smart-filter field mapping
FILTER_FIELD_MAPPING = {
'url': 'url',
'host': 'host',
'title': 'title',
'status': 'status_code',
'webserver': 'webserver',
'tech': 'tech',
}
def get_by_scan(self, scan_id: int, filter_query: str = None):
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_by_scan(scan_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_all(self):
def get_all(self, filter_query: str = None):
"""获取所有端点快照"""
return self.snapshot_repo.get_all()
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def iter_endpoint_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取某次扫描下的所有端点 URL。"""
queryset = self.snapshot_repo.get_by_scan(scan_id)
for snapshot in queryset.iterator(chunk_size=chunk_size):
yield snapshot.url
def iter_raw_data_for_csv_export(self, scan_id: int) -> Iterator[dict]:
"""
Stream raw data for CSV export
Args:
scan_id: scan ID
Yields:
raw data dicts
"""
return self.snapshot_repo.iter_raw_data_for_export(scan_id=scan_id)

View File

@@ -69,13 +69,25 @@ class HostPortMappingSnapshotsService:
)
raise
def get_ip_aggregation_by_scan(self, scan_id: int, search: str = None):
return self.snapshot_repo.get_ip_aggregation_by_scan(scan_id, search=search)
def get_ip_aggregation_by_scan(self, scan_id: int, filter_query: str = None):
return self.snapshot_repo.get_ip_aggregation_by_scan(scan_id, filter_query=filter_query)
def get_all_ip_aggregation(self, search: str = None):
def get_all_ip_aggregation(self, filter_query: str = None):
"""获取所有 IP 聚合数据"""
return self.snapshot_repo.get_all_ip_aggregation(search=search)
return self.snapshot_repo.get_all_ip_aggregation(filter_query=filter_query)
def iter_ips_by_scan(self, scan_id: int, batch_size: int = 1000) -> Iterator[str]:
"""流式获取某次扫描下的所有唯一 IP 地址。"""
return self.snapshot_repo.get_ips_for_export(scan_id=scan_id, batch_size=batch_size)
def iter_raw_data_for_csv_export(self, scan_id: int) -> Iterator[dict]:
"""
Stream raw data for CSV export
Args:
scan_id: scan ID
Yields:
raw data dicts {ip, host, port, created_at}
"""
return self.snapshot_repo.iter_raw_data_for_export(scan_id=scan_id)

View File

@@ -66,14 +66,41 @@ class SubdomainSnapshotsService:
)
raise
def get_by_scan(self, scan_id: int):
return self.subdomain_snapshot_repo.get_by_scan(scan_id)
# Smart-filter field mapping
FILTER_FIELD_MAPPING = {
'name': 'name',
}
def get_by_scan(self, scan_id: int, filter_query: str = None):
from apps.common.utils.filter_utils import apply_filters
queryset = self.subdomain_snapshot_repo.get_by_scan(scan_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_all(self):
def get_all(self, filter_query: str = None):
"""获取所有子域名快照"""
return self.subdomain_snapshot_repo.get_all()
from apps.common.utils.filter_utils import apply_filters
queryset = self.subdomain_snapshot_repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def iter_subdomain_names_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
queryset = self.subdomain_snapshot_repo.get_by_scan(scan_id)
for snapshot in queryset.iterator(chunk_size=chunk_size):
yield snapshot.name
def iter_raw_data_for_csv_export(self, scan_id: int) -> Iterator[dict]:
"""
Stream raw data for CSV export
Args:
scan_id: scan ID
Yields:
raw data dicts {name, created_at}
"""
return self.subdomain_snapshot_repo.iter_raw_data_for_export(scan_id=scan_id)

View File

@@ -66,13 +66,31 @@ class VulnerabilitySnapshotsService:
)
raise
def get_by_scan(self, scan_id: int):
"""按扫描任务获取所有漏洞快照。"""
return self.snapshot_repo.get_by_scan(scan_id)
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'type': 'vuln_type',
'url': 'url',
'severity': 'severity',
'source': 'source',
}
def get_all(self):
def get_by_scan(self, scan_id: int, filter_query: str = None):
"""按扫描任务获取所有漏洞快照。"""
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_by_scan(scan_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_all(self, filter_query: str = None):
"""获取所有漏洞快照"""
return self.snapshot_repo.get_all()
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def iter_vuln_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取某次扫描下的所有漏洞 URL。"""

View File

@@ -49,14 +49,13 @@ class WebsiteSnapshotsService:
logger.debug("步骤 1: 保存到快照表")
self.snapshot_repo.save_snapshots(items)
# 步骤 2: 转换为资产 DTO 并保存到资产表
# 注意:去重是通过数据库的 UNIQUE 约束 + ignore_conflicts 实现的
# 步骤 2: 转换为资产 DTO 并保存到资产表upsert
# - 新记录:插入资产表
# - 已存在的记录:自动跳过
logger.debug("步骤 2: 同步到资产表(通过 Service 层)")
# - 已存在的记录:更新字段created_at 不更新,保留创建时间)
logger.debug("步骤 2: 同步到资产表(通过 Service 层upsert")
asset_items = [item.to_asset_dto() for item in items]
self.asset_service.bulk_create_ignore_conflicts(asset_items)
self.asset_service.bulk_upsert(asset_items)
logger.info("网站快照和资产数据保存成功 - 数量: %d", len(items))
@@ -69,15 +68,47 @@ class WebsiteSnapshotsService:
)
raise
def get_by_scan(self, scan_id: int):
return self.snapshot_repo.get_by_scan(scan_id)
# Smart-filter field mapping
FILTER_FIELD_MAPPING = {
'url': 'url',
'host': 'host',
'title': 'title',
'status': 'status',
'webserver': 'web_server',
'tech': 'tech',
}
def get_by_scan(self, scan_id: int, filter_query: str = None):
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_by_scan(scan_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_all(self):
def get_all(self, filter_query: str = None):
"""获取所有网站快照"""
return self.snapshot_repo.get_all()
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def iter_website_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取某次扫描下的所有站点 URL发现时间倒序)。"""
"""流式获取某次扫描下的所有站点 URL创建时间倒序)。"""
queryset = self.snapshot_repo.get_by_scan(scan_id)
for snapshot in queryset.iterator(chunk_size=chunk_size):
yield snapshot.url
def iter_raw_data_for_csv_export(self, scan_id: int) -> Iterator[dict]:
"""
Stream raw data for CSV export.
Args:
scan_id: scan ID
Yields:
raw data dicts
"""
return self.snapshot_repo.iter_raw_data_for_export(scan_id=scan_id)

File diff suppressed because it is too large

View File

@@ -66,12 +66,19 @@ def fetch_config_and_setup_django():
os.environ.setdefault("ENABLE_COMMAND_LOGGING", str(config['logging']['enableCommandLogging']).lower())
os.environ.setdefault("DEBUG", str(config['debug']))
# Git mirror configuration (used to accelerate git clone)
git_mirror = config.get('gitMirror', '')
if git_mirror:
os.environ.setdefault("GIT_MIRROR", git_mirror)
print(f"[CONFIG] ✓ 配置获取成功")
print(f"[CONFIG] DB_HOST: {db_host}")
print(f"[CONFIG] DB_PORT: {db_port}")
print(f"[CONFIG] DB_NAME: {db_name}")
print(f"[CONFIG] DB_USER: {db_user}")
print(f"[CONFIG] REDIS_URL: {config['redisUrl']}")
if git_mirror:
print(f"[CONFIG] GIT_MIRROR: {git_mirror}")
except Exception as e:
print(f"[ERROR] 获取配置失败: {config_url} - {e}", file=sys.stderr)

View File

@@ -16,6 +16,7 @@ def setup_django_for_prefect():
1. Add the project root to the Python path
2. Set the DJANGO_SETTINGS_MODULE environment variable
3. Call django.setup() to initialize Django
4. Close stale database connections so fresh ones are used
Usage:
from apps.common.prefect_django_setup import setup_django_for_prefect
@@ -36,6 +37,25 @@ def setup_django_for_prefect():
# Initialize Django
import django
django.setup()
# Close all stale database connections so worker processes use fresh ones
# Fixes the "server closed the connection unexpectedly" error
from django.db import connections
connections.close_all()
def close_old_db_connections():
"""
Close stale database connections.
Call this in long-running tasks to make sure a valid database connection is used.
Appropriate before:
- a Flow starts
- a Task starts
- resuming work after a long idle period
"""
from django.db import connections
connections.close_all()
# Run initialization automatically (takes effect on import)
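A minimal sketch of how a long-running task might use this helper (the flow and task names are illustrative assumptions):

from prefect import flow, task
from apps.common.prefect_django_setup import close_old_db_connections

@task
def import_results(scan_id: int):
    # Drop stale connections before touching the ORM; Django reconnects lazily
    close_old_db_connections()
    ...  # ORM work goes here

@flow
def scan_flow(scan_id: int):
    import_results(scan_id)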

View File

@@ -3,8 +3,13 @@
Provides system-level shared services, including:
- SystemLogService: system log reading service
Note: FilterService has moved to apps.common.utils.filter_utils
Preferred usage: from apps.common.utils.filter_utils import apply_filters
"""
from .system_log_service import SystemLogService
__all__ = ['SystemLogService']
__all__ = [
'SystemLogService',
]

View File

@@ -21,8 +21,8 @@ class SystemLogService:
"""
def __init__(self):
# Log file path (path inside the container, volume-mounted to /opt/xingrin/logs on the host)
self.log_file = "/app/backend/logs/xingrin.log"
# Log file path (unified under /opt/xingrin/logs)
self.log_file = "/opt/xingrin/logs/xingrin.log"
self.default_lines = 200 # default number of lines returned
self.max_lines = 10000 # hard cap on returned lines
self.timeout_seconds = 3 # timeout for the tail command

View File

@@ -0,0 +1,30 @@
"""Common utilities"""
from .dedup import deduplicate_for_bulk, get_unique_fields
from .hash import (
calc_file_sha256,
calc_stream_sha256,
safe_calc_file_sha256,
is_file_hash_match,
)
from .csv_utils import (
generate_csv_rows,
format_list_field,
format_datetime,
UTF8_BOM,
)
from .git_proxy import get_git_proxy_url
__all__ = [
'deduplicate_for_bulk',
'get_unique_fields',
'calc_file_sha256',
'calc_stream_sha256',
'safe_calc_file_sha256',
'is_file_hash_match',
'generate_csv_rows',
'format_list_field',
'format_datetime',
'UTF8_BOM',
'get_git_proxy_url',
]

View File

@@ -0,0 +1,116 @@
"""CSV 导出工具模块
提供流式 CSV 生成功能,支持:
- UTF-8 BOMExcel 兼容)
- RFC 4180 规范转义
- 流式生成(内存友好)
"""
import csv
import io
from datetime import datetime
from typing import Iterator, Dict, Any, List, Callable, Optional
# UTF-8 BOM确保 Excel 正确识别编码
UTF8_BOM = '\ufeff'
def generate_csv_rows(
data_iterator: Iterator[Dict[str, Any]],
headers: List[str],
field_formatters: Optional[Dict[str, Callable]] = None
) -> Iterator[str]:
"""
Stream CSV rows one at a time.
Args:
data_iterator: iterator of dicts, one per data row
headers: list of CSV header names
field_formatters: dict mapping field name to a formatter function
Yields:
CSV row strings (newline included)
Example:
>>> data = [{'ip': '192.168.1.1', 'hosts': ['a.com', 'b.com']}]
>>> headers = ['ip', 'hosts']
>>> formatters = {'hosts': format_list_field}
>>> for row in generate_csv_rows(iter(data), headers, formatters):
... print(row, end='')
"""
# Emit BOM + header row
output = io.StringIO()
writer = csv.writer(output, quoting=csv.QUOTE_MINIMAL)
writer.writerow(headers)
yield UTF8_BOM + output.getvalue()
# Emit data rows
for row_data in data_iterator:
output = io.StringIO()
writer = csv.writer(output, quoting=csv.QUOTE_MINIMAL)
row = []
for header in headers:
value = row_data.get(header, '')
if field_formatters and header in field_formatters:
value = field_formatters[header](value)
row.append(value if value is not None else '')
writer.writerow(row)
yield output.getvalue()
def format_list_field(values: List, separator: str = ';') -> str:
"""
Format a list field as a separator-joined string.
Args:
values: list of values
separator: separator string, semicolon by default
Returns:
values joined by the separator
Example:
>>> format_list_field(['a.com', 'b.com'])
'a.com;b.com'
>>> format_list_field([80, 443])
'80;443'
>>> format_list_field([])
''
>>> format_list_field(None)
''
"""
if not values:
return ''
return separator.join(str(v) for v in values)
def format_datetime(dt: Optional[datetime]) -> str:
"""
Format a datetime as a string (converted to the local timezone).
Args:
dt: datetime object or None
Returns:
formatted datetime string, YYYY-MM-DD HH:MM:SS in the local timezone
Example:
>>> from datetime import datetime
>>> format_datetime(datetime(2024, 1, 15, 10, 30, 0))
'2024-01-15 10:30:00'
>>> format_datetime(None)
''
"""
if dt is None:
return ''
if isinstance(dt, str):
return dt
# Convert to the local timezone (configured in Django settings)
from django.utils import timezone
if timezone.is_aware(dt):
dt = timezone.localtime(dt)
return dt.strftime('%Y-%m-%d %H:%M:%S')
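A quick usage sketch tying the three helpers together (the data is made up for illustration):

from apps.common.utils.csv_utils import generate_csv_rows, format_list_field, format_datetime

data = iter([
    {'ip': '192.168.1.1', 'hosts': ['a.com', 'b.com'], 'created_at': None},
])
rows = generate_csv_rows(
    data,
    headers=['ip', 'hosts', 'created_at'],
    field_formatters={'hosts': format_list_field, 'created_at': format_datetime},
)
print(''.join(rows))
# (BOM)ip,hosts,created_at
# 192.168.1.1,a.com;b.com,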

View File

@@ -0,0 +1,101 @@
"""
Bulk data deduplication utilities
Dedupes within a batch before bulk_create to avoid PostgreSQL ON CONFLICT errors.
Reads the unique-constraint fields from the Django model automatically; nothing needs to be specified by hand.
"""
import logging
from typing import List, TypeVar, Tuple, Optional
from django.db import models
logger = logging.getLogger(__name__)
T = TypeVar('T')
def get_unique_fields(model: type[models.Model]) -> Optional[Tuple[str, ...]]:
"""
Get the unique-constraint fields from a Django model.
Checked in priority order:
1. UniqueConstraint entries in Meta.constraints
2. Meta.unique_together
Args:
model: Django model class
Returns:
tuple of unique-constraint field names, or None if there are none
"""
meta = model._meta
# 1. Prefer UniqueConstraint
for constraint in getattr(meta, 'constraints', []):
if isinstance(constraint, models.UniqueConstraint):
# Skip conditional constraints (partial unique)
if getattr(constraint, 'condition', None) is None:
return tuple(constraint.fields)
# 2. Fall back to unique_together
unique_together = getattr(meta, 'unique_together', None)
if unique_together:
# unique_together may be (('a', 'b'),) or ('a', 'b')
if unique_together and isinstance(unique_together[0], (list, tuple)):
return tuple(unique_together[0])
return tuple(unique_together)
return None
def deduplicate_for_bulk(items: List[T], model: type[models.Model]) -> List[T]:
"""
Deduplicate items according to the model's unique constraint.
Reads the unique-constraint fields from the model automatically and builds a dedup key.
Keeps the last record per key (later items are usually fresher).
Args:
items: items to dedupe (DTOs or Model instances)
model: Django model class (used to read the unique constraint)
Returns:
deduplicated list of items
Example:
# Reads the unique constraint (url, target) from the Endpoint model automatically
unique_items = deduplicate_for_bulk(items, Endpoint)
"""
if not items:
return items
unique_fields = get_unique_fields(model)
if unique_fields is None:
# Model has no unique constraint; nothing to dedupe
logger.debug(f"{model.__name__} has no unique constraint, skipping dedup")
return items
# Handle foreign-key field names (target -> target_id)
def make_key(item: T) -> tuple:
key_parts = []
for field in unique_fields:
# Try both the field_id (foreign key) and field forms
value = getattr(item, f'{field}_id', None)
if value is None:
value = getattr(item, field, None)
key_parts.append(value)
return tuple(key_parts)
# Dedupe via dict, keeping the last occurrence
seen = {}
for item in items:
key = make_key(item)
seen[key] = item
unique_items = list(seen.values())
if len(unique_items) < len(items):
logger.debug(f"{model.__name__} 去重: {len(items)} -> {len(unique_items)}")
return unique_items
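A minimal sketch of the intended call pattern before a bulk insert (the Endpoint model is borrowed from the docstring example; the surrounding function and import path are assumptions):

from apps.common.utils import deduplicate_for_bulk
from apps.asset.models import Endpoint  # import path assumed

def save_endpoints(items: list) -> None:
    # Collapse batch-internal duplicates so ON CONFLICT never sees the same key twice
    unique_items = deduplicate_for_bulk(items, Endpoint)
    Endpoint.objects.bulk_create(unique_items, ignore_conflicts=True)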

View File

@@ -0,0 +1,281 @@
"""智能过滤工具 - 通用查询语法解析和 Django ORM 查询构建
支持的语法:
- field="value" 模糊匹配(包含)
- field=="value" 精确匹配
- field!="value" 不等于
逻辑运算符:
- AND: && 或 and 或 空格(默认)
- OR: || 或 or
示例:
type="xss" || type="sqli" # OR
type="xss" or type="sqli" # OR等价
severity="high" && source="nuclei" # AND
severity="high" source="nuclei" # AND空格默认为 AND
severity="high" and source="nuclei" # AND等价
使用示例:
from apps.common.utils.filter_utils import apply_filters
field_mapping = {'ip': 'ip', 'port': 'port', 'host': 'host'}
queryset = apply_filters(queryset, 'ip="192" || port="80"', field_mapping)
"""
import re
import logging
from dataclasses import dataclass
from typing import List, Dict, Optional, Union
from enum import Enum
from django.db.models import QuerySet, Q
logger = logging.getLogger(__name__)
class LogicalOp(Enum):
"""逻辑运算符"""
AND = 'AND'
OR = 'OR'
@dataclass
class ParsedFilter:
"""解析后的过滤条件"""
field: str # 字段名
operator: str # 操作符: '=', '==', '!='
value: str # 原始值
@dataclass
class FilterGroup:
"""过滤条件组(带逻辑运算符)"""
filter: ParsedFilter
logical_op: LogicalOp # 与前一个条件的逻辑关系
class QueryParser:
"""查询语法解析器
支持 ||/or (OR) 和 &&/and/空格 (AND) 逻辑运算符
"""
# 正则匹配: field="value", field=="value", field!="value"
FILTER_PATTERN = re.compile(r'(\w+)(==|!=|=)"([^"]*)"')
# 逻辑运算符模式(带空格)
OR_PATTERN = re.compile(r'\s*(\|\||(?<![a-zA-Z])or(?![a-zA-Z]))\s*', re.IGNORECASE)
AND_PATTERN = re.compile(r'\s*(&&|(?<![a-zA-Z])and(?![a-zA-Z]))\s*', re.IGNORECASE)
@classmethod
def parse(cls, query_string: str) -> List[FilterGroup]:
"""解析查询语法字符串
Args:
query_string: 查询语法字符串
Returns:
解析后的过滤条件组列表
Examples:
>>> QueryParser.parse('type="xss" || type="sqli"')
[FilterGroup(filter=..., logical_op=AND), # 第一个默认 AND
FilterGroup(filter=..., logical_op=OR)]
"""
if not query_string or not query_string.strip():
return []
# Normalize logical operators
# First rewrite || and or -> __OR__
normalized = cls.OR_PATTERN.sub(' __OR__ ', query_string)
# Then rewrite && and and -> __AND__
normalized = cls.AND_PATTERN.sub(' __AND__ ', normalized)
# Tokenize: split on whitespace, keeping the operator markers
tokens = normalized.split()
groups = []
pending_op = LogicalOp.AND # default AND
for token in tokens:
if token == '__OR__':
pending_op = LogicalOp.OR
elif token == '__AND__':
pending_op = LogicalOp.AND
else:
# Try to parse as a filter condition
match = cls.FILTER_PATTERN.match(token)
if match:
field, operator, value = match.groups()
groups.append(FilterGroup(
filter=ParsedFilter(
field=field.lower(),
operator=operator,
value=value
),
logical_op=pending_op if groups else LogicalOp.AND # first condition defaults to AND
))
pending_op = LogicalOp.AND # reset to default AND
return groups
class QueryBuilder:
"""Django ORM 查询构建器
将解析后的过滤条件转换为 Django ORM 查询,支持 AND/OR 逻辑
"""
@classmethod
def build_query(
cls,
queryset: QuerySet,
filter_groups: List[FilterGroup],
field_mapping: Dict[str, str],
json_array_fields: List[str] = None
) -> QuerySet:
"""构建 Django ORM 查询
Args:
queryset: Django QuerySet
filter_groups: 解析后的过滤条件组列表
field_mapping: 字段映射
json_array_fields: JSON 数组字段列表(使用 __contains 查询)
Returns:
过滤后的 QuerySet
"""
if not filter_groups:
return queryset
json_array_fields = json_array_fields or []
# Build the combined Q object
combined_q = None
for group in filter_groups:
f = group.filter
# Map the public field name to the DB field
db_field = field_mapping.get(f.field)
if not db_field:
logger.debug(f"Ignoring unknown field: {f.field}")
continue
# Is this a JSON array field?
is_json_array = db_field in json_array_fields
# Build a Q object for this single condition
q = cls._build_single_q(db_field, f.operator, f.value, is_json_array)
if q is None:
continue
# Combine Q objects
if combined_q is None:
combined_q = q
elif group.logical_op == LogicalOp.OR:
combined_q = combined_q | q
else: # AND
combined_q = combined_q & q
if combined_q is not None:
return queryset.filter(combined_q)
return queryset
@classmethod
def _build_single_q(cls, field: str, operator: str, value: str, is_json_array: bool = False) -> Optional[Q]:
"""构建单个条件的 Q 对象"""
if is_json_array:
# JSON 数组字段使用 __contains 查询
return Q(**{f'{field}__contains': [value]})
if operator == '!=':
return cls._build_not_equal_q(field, value)
elif operator == '==':
return cls._build_exact_q(field, value)
else: # '='
return cls._build_fuzzy_q(field, value)
@classmethod
def _try_convert_to_int(cls, value: str) -> Optional[int]:
"""尝试将值转换为整数"""
try:
return int(value.strip())
except (ValueError, TypeError):
return None
@classmethod
def _build_fuzzy_q(cls, field: str, value: str) -> Q:
"""模糊匹配: 包含"""
return Q(**{f'{field}__icontains': value})
@classmethod
def _build_exact_q(cls, field: str, value: str) -> Q:
"""精确匹配"""
int_val = cls._try_convert_to_int(value)
if int_val is not None:
return Q(**{f'{field}__exact': int_val})
return Q(**{f'{field}__exact': value})
@classmethod
def _build_not_equal_q(cls, field: str, value: str) -> Q:
"""不等于"""
int_val = cls._try_convert_to_int(value)
if int_val is not None:
return ~Q(**{f'{field}__exact': int_val})
return ~Q(**{f'{field}__exact': value})
def apply_filters(
queryset: QuerySet,
query_string: str,
field_mapping: Dict[str, str],
json_array_fields: List[str] = None
) -> QuerySet:
"""应用过滤条件到 QuerySet
Args:
queryset: Django QuerySet
query_string: 查询语法字符串
field_mapping: 字段映射
json_array_fields: JSON 数组字段列表(使用 __contains 查询)
Returns:
过滤后的 QuerySet
Examples:
# OR 查询
apply_filters(qs, 'type="xss" || type="sqli"', mapping)
apply_filters(qs, 'type="xss" or type="sqli"', mapping)
# AND 查询
apply_filters(qs, 'severity="high" && source="nuclei"', mapping)
apply_filters(qs, 'severity="high" source="nuclei"', mapping)
# 混合查询
apply_filters(qs, 'type="xss" || type="sqli" && severity="high"', mapping)
# JSON 数组字段查询
apply_filters(qs, 'implies="PHP"', mapping, json_array_fields=['implies'])
"""
if not query_string or not query_string.strip():
return queryset
try:
filter_groups = QueryParser.parse(query_string)
if not filter_groups:
logger.debug(f"未解析到有效过滤条件: {query_string}")
return queryset
logger.debug(f"解析过滤条件: {filter_groups}")
return QueryBuilder.build_query(
queryset,
filter_groups,
field_mapping,
json_array_fields=json_array_fields
)
except Exception as e:
logger.warning(f"过滤解析错误: {e}, query: {query_string}")
return queryset # 静默降级
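A short sketch of the end-to-end behavior (the model and mapping are illustrative):

from apps.common.utils.filter_utils import apply_filters

FIELD_MAPPING = {'type': 'vuln_type', 'severity': 'severity'}

qs = Vulnerability.objects.all()  # assumed model, for illustration only
# Conditions combine left-to-right; there is no precedence between && and ||,
# so this yields (vuln_type contains "xss" OR vuln_type contains "sqli") AND severity == "high"
qs = apply_filters(qs, 'type="xss" || type="sqli" && severity=="high"', FIELD_MAPPING)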

View File

@@ -0,0 +1,39 @@
"""Git proxy utilities for URL acceleration."""
import os
from urllib.parse import urlparse
def get_git_proxy_url(original_url: str) -> str:
"""
Convert Git repository URL to proxy format for acceleration.
Supports multiple mirror services (standard format):
- gh-proxy.org: https://gh-proxy.org/https://github.com/user/repo.git
- ghproxy.com: https://ghproxy.com/https://github.com/user/repo.git
- mirror.ghproxy.com: https://mirror.ghproxy.com/https://github.com/user/repo.git
- ghps.cc: https://ghps.cc/https://github.com/user/repo.git
Args:
original_url: Original repository URL, e.g., https://github.com/user/repo.git
Returns:
Converted URL based on GIT_MIRROR setting.
If GIT_MIRROR is not set, returns the original URL unchanged.
"""
git_mirror = os.getenv("GIT_MIRROR", "").strip()
if not git_mirror:
return original_url
# Remove trailing slash from mirror URL if present
git_mirror = git_mirror.rstrip("/")
parsed = urlparse(original_url)
host = parsed.netloc.lower()
# Only support GitHub for now
if "github.com" not in host:
return original_url
# Standard format: https://mirror.example.com/https://github.com/user/repo.git
return f"{git_mirror}/{original_url}"

View File

@@ -7,7 +7,6 @@
import hashlib
import logging
from pathlib import Path
from typing import Optional, BinaryIO
logger = logging.getLogger(__name__)
@@ -91,11 +90,3 @@ def is_file_hash_match(file_path: str, expected_hash: str) -> bool:
return False
return actual_hash.lower() == expected_hash.lower()
__all__ = [
"calc_file_sha256",
"calc_stream_sha256",
"safe_calc_file_sha256",
"is_file_hash_match",
]

View File

@@ -1,6 +1,8 @@
"""域名、IP、端口和目标验证工具函数"""
"""域名、IP、端口、URL 和目标验证工具函数"""
import ipaddress
import logging
from urllib.parse import urlparse
import validators
logger = logging.getLogger(__name__)
@@ -25,6 +27,21 @@ def validate_domain(domain: str) -> None:
raise ValueError(f"域名格式无效: {domain}")
def is_valid_domain(domain: str) -> bool:
"""
Check whether a string is a valid domain (does not raise).
Args:
domain: domain string
Returns:
bool: whether the domain is valid
"""
if not domain or len(domain) > 253:
return False
return bool(validators.domain(domain))
def validate_ip(ip: str) -> None:
"""
Validate IP address format (IPv4 and IPv6 supported).
@@ -44,6 +61,25 @@ def validate_ip(ip: str) -> None:
raise ValueError(f"IP 地址格式无效: {ip}")
def is_valid_ip(ip: str) -> bool:
"""
Check whether a string is a valid IP address (does not raise).
Args:
ip: IP address string
Returns:
bool: whether the IP is valid
"""
if not ip:
return False
try:
ipaddress.ip_address(ip)
return True
except ValueError:
return False
def validate_cidr(cidr: str) -> None:
"""
Validate CIDR format (IPv4 and IPv6 supported).
@@ -140,3 +176,136 @@ def validate_port(port: any) -> tuple[bool, int | None]:
except (ValueError, TypeError):
logger.warning("端口号格式错误,无法转换为整数: %s", port)
return False, None
# ==================== URL validation helpers ====================
def validate_url(url: str) -> None:
"""
Validate URL format; a scheme (http:// or https://) is required.
Args:
url: URL string
Raises:
ValueError: the URL is malformed or missing a scheme
"""
if not url:
raise ValueError("URL cannot be empty")
# Check for a scheme
if not url.startswith('http://') and not url.startswith('https://'):
raise ValueError("URL must include a protocol (http:// or https://)")
try:
parsed = urlparse(url)
if not parsed.hostname:
raise ValueError("URL 必须包含主机名")
except Exception:
raise ValueError(f"URL 格式无效: {url}")
def is_valid_url(url: str, max_length: int = 2000) -> bool:
"""
Check whether a string is a valid URL (does not raise).
Args:
url: URL string
max_length: maximum URL length, default 2000
Returns:
bool: whether the URL is valid
"""
if not url or len(url) > max_length:
return False
try:
validate_url(url)
return True
except ValueError:
return False
def is_url_match_target(url: str, target_name: str, target_type: str) -> bool:
"""
Check whether a URL matches a target.
Args:
url: URL string
target_name: target name (domain, IP, or CIDR)
target_type: target type ('domain', 'ip', 'cidr')
Returns:
bool: whether it matches
"""
try:
parsed = urlparse(url)
hostname = parsed.hostname
if not hostname:
return False
hostname = hostname.lower()
target_name = target_name.lower()
if target_type == 'domain':
# Domain targets: hostname equals target_name or ends with .target_name
return hostname == target_name or hostname.endswith('.' + target_name)
elif target_type == 'ip':
# IP targets: hostname must equal target_name exactly
return hostname == target_name
elif target_type == 'cidr':
# CIDR targets: hostname must be an IP inside the CIDR range
try:
ip = ipaddress.ip_address(hostname)
network = ipaddress.ip_network(target_name, strict=False)
return ip in network
except ValueError:
# hostname is not a valid IP
return False
return False
except Exception:
return False
def detect_input_type(input_str: str) -> str:
"""
Detect the input type (used when parsing quick-scan input).
Args:
input_str: input string (expected to be stripped already)
Returns:
str: input type ('url', 'domain', 'ip', 'cidr')
"""
if not input_str:
raise ValueError("Input cannot be empty")
# 1. Anything containing :// is a URL
if '://' in input_str:
return 'url'
# 2. Containing / means either a CIDR or a URL missing its scheme
if '/' in input_str:
# CIDR format: IP/prefix, e.g. 10.0.0.0/8
parts = input_str.split('/')
if len(parts) == 2:
ip_part, prefix_part = parts
# If the part after the slash is a number in 0-32, check for a CIDR
if prefix_part.isdigit() and 0 <= int(prefix_part) <= 32:
ip_parts = ip_part.split('.')
if len(ip_parts) == 4 and all(p.isdigit() for p in ip_parts):
return 'cidr'
# Not a CIDR; treat as a URL missing its scheme (later validation will reject it)
return 'url'
# 3. Check for an IP address
try:
ipaddress.ip_address(input_str)
return 'ip'
except ValueError:
pass
# 4. Default to domain
return 'domain'
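A quick sketch of the classification this helper performs (the import path is an assumption):

from apps.common.validators import detect_input_type  # module path assumed

assert detect_input_type('https://example.com/login') == 'url'
assert detect_input_type('10.0.0.0/8') == 'cidr'
assert detect_input_type('192.168.1.1') == 'ip'
assert detect_input_type('example.com') == 'domain'
# Has a path but no scheme -> classified as a URL; validation rejects it later
assert detect_input_type('example.com/admin') == 'url'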

View File

@@ -242,8 +242,9 @@ class WorkerDeployConsumer(AsyncWebsocketConsumer):
return
# Remote workers go through nginx over HTTPS (nginx reverse-proxies to backend 8888)
# Use https://{PUBLIC_HOST} instead of hitting port 8888 directly
heartbeat_api_url = f"https://{public_host}" # base URL; the agent appends /api/...
# Use https://{PUBLIC_HOST}:{PUBLIC_PORT} instead of hitting port 8888 directly
public_port = getattr(settings, 'PUBLIC_PORT', '8083')
heartbeat_api_url = f"https://{public_host}:{public_port}"
session_name = f'xingrin_deploy_{self.worker_id}'
remote_script_path = '/tmp/xingrin_deploy.sh'

View File

@@ -0,0 +1,160 @@
"""初始化内置指纹库
- EHole 指纹: ehole.json -> 导入到数据库
- Goby 指纹: goby.json -> 导入到数据库
- Wappalyzer 指纹: wappalyzer.json -> 导入到数据库
可重复执行:如果数据库已有数据则跳过,只在空库时导入。
"""
import json
import logging
from pathlib import Path
from django.conf import settings
from django.core.management.base import BaseCommand
from apps.engine.models import EholeFingerprint, GobyFingerprint, WappalyzerFingerprint
from apps.engine.services.fingerprints import (
EholeFingerprintService,
GobyFingerprintService,
WappalyzerFingerprintService,
)
logger = logging.getLogger(__name__)
# Built-in fingerprint configuration
DEFAULT_FINGERPRINTS = [
{
"type": "ehole",
"filename": "ehole.json",
"model": EholeFingerprint,
"service": EholeFingerprintService,
"data_key": "fingerprint", # JSON 中指纹数组的 key
},
{
"type": "goby",
"filename": "goby.json",
"model": GobyFingerprint,
"service": GobyFingerprintService,
"data_key": None, # Goby 是数组格式,直接使用整个 JSON
},
{
"type": "wappalyzer",
"filename": "wappalyzer.json",
"model": WappalyzerFingerprint,
"service": WappalyzerFingerprintService,
"data_key": "apps", # Wappalyzer 使用 apps 对象
},
]
class Command(BaseCommand):
help = "初始化内置指纹库"
def handle(self, *args, **options):
project_base = Path(settings.BASE_DIR).parent # /app/backend -> /app
fingerprints_dir = project_base / "backend" / "fingerprints"
initialized = 0
skipped = 0
failed = 0
for item in DEFAULT_FINGERPRINTS:
fp_type = item["type"]
filename = item["filename"]
model = item["model"]
service_class = item["service"]
data_key = item["data_key"]
# Check whether the database already has data
existing_count = model.objects.count()
if existing_count > 0:
self.stdout.write(self.style.SUCCESS(
f"[{fp_type}] database already has {existing_count} records, skipping initialization"
))
skipped += 1
continue
# Locate the source file
src_path = fingerprints_dir / filename
if not src_path.exists():
self.stdout.write(self.style.WARNING(
f"[{fp_type}] built-in fingerprint file not found: {src_path}, skipping"
))
failed += 1
continue
# Read and parse the JSON
try:
with open(src_path, "r", encoding="utf-8") as f:
json_data = json.load(f)
except (json.JSONDecodeError, OSError) as exc:
self.stdout.write(self.style.ERROR(
f"[{fp_type}] failed to read fingerprint file: {exc}"
))
failed += 1
continue
# Extract fingerprint data (format-specific handling)
fingerprints = self._extract_fingerprints(json_data, data_key, fp_type)
if not fingerprints:
self.stdout.write(self.style.WARNING(
f"[{fp_type}] fingerprint file contains no valid data, skipping"
))
failed += 1
continue
# Bulk import via the service
try:
service = service_class()
result = service.batch_create_fingerprints(fingerprints)
created = result.get("created", 0)
failed_count = result.get("failed", 0)
self.stdout.write(self.style.SUCCESS(
f"[{fp_type}] import succeeded: created {created}, failed {failed_count}"
))
initialized += 1
except Exception as exc:
self.stdout.write(self.style.ERROR(
f"[{fp_type}] import failed: {exc}"
))
failed += 1
continue
self.stdout.write(self.style.SUCCESS(
f"Fingerprint initialization finished: succeeded {initialized}, skipped (existing) {skipped}, failed {failed}"
))
def _extract_fingerprints(self, json_data, data_key, fp_type):
"""
Extract fingerprint data, handling both array and object formats.
Supported formats:
- Array format: [...] or {"key": [...]}
- Object format: {...} or {"key": {...}} -> converted to [{"name": k, ...v}]
"""
# Get the target data
if data_key is None:
# Use the whole JSON directly
target = json_data
else:
# Read from the given key, allowing alternative keys (e.g. apps/technologies)
if data_key == "apps":
target = json_data.get("apps") or json_data.get("technologies") or {}
else:
target = json_data.get(data_key, [])
# Handle by data type
if isinstance(target, list):
# Already an array; return as-is
return target
elif isinstance(target, dict):
# Object format; convert to an array [{"name": key, ...value}]
return [{"name": name, **data} if isinstance(data, dict) else {"name": name}
for name, data in target.items()]
return []
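A minimal sketch of invoking the command programmatically, equivalent to running it from manage.py (the command name is assumed from this file's purpose):

from django.core.management import call_command

# Safe to re-run: populated tables are skipped, only empty ones are imported
call_command('init_fingerprints')  # command name assumed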

View File

@@ -3,12 +3,17 @@
Run this command after installation to create the official template repository records automatically.
Usage:
python manage.py init_nuclei_templates # create records only
python manage.py init_nuclei_templates # create records only (detects pre-existing local repos)
python manage.py init_nuclei_templates --sync # create and sync (git clone)
"""
import logging
import subprocess
from pathlib import Path
from django.conf import settings
from django.core.management.base import BaseCommand
from django.utils import timezone
from apps.engine.models import NucleiTemplateRepo
from apps.engine.services import NucleiTemplateRepoService
@@ -26,6 +31,20 @@ DEFAULT_REPOS = [
]
def get_local_commit_hash(local_path: Path) -> str:
"""获取本地 Git 仓库的 commit hash"""
if not (local_path / ".git").is_dir():
return ""
result = subprocess.run(
["git", "-C", str(local_path), "rev-parse", "HEAD"],
check=False,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
return result.stdout.strip() if result.returncode == 0 else ""
class Command(BaseCommand):
help = "初始化 Nuclei 模板仓库(创建官方模板仓库记录)"
@@ -46,6 +65,8 @@ class Command(BaseCommand):
force = options.get("force", False)
service = NucleiTemplateRepoService()
base_dir = Path(getattr(settings, "NUCLEI_TEMPLATES_REPOS_BASE_DIR", "/opt/xingrin/nuclei-repos"))
created = 0
skipped = 0
synced = 0
@@ -87,20 +108,30 @@ class Command(BaseCommand):
# Create a new repository record
try:
# Check for a pre-existing local repo (pre-downloaded by install.sh)
local_path = base_dir / name
local_commit = get_local_commit_hash(local_path)
repo = NucleiTemplateRepo.objects.create(
name=name,
repo_url=repo_url,
local_path=str(local_path) if local_commit else "",
commit_hash=local_commit,
last_synced_at=timezone.now() if local_commit else None,
)
self.stdout.write(self.style.SUCCESS(
f"[{name}] 创建成功: id={repo.id}"
))
if local_commit:
self.stdout.write(self.style.SUCCESS(
f"[{name}] 创建成功(检测到本地仓库): commit={local_commit[:8]}"
))
else:
self.stdout.write(self.style.SUCCESS(
f"[{name}] 创建成功: id={repo.id}"
))
created += 1
# Initialize the local path
service.ensure_local_path(repo)
# If a sync is requested
if do_sync:
# Sync only when there is no local repo yet and a sync was requested
if not local_commit and do_sync:
try:
self.stdout.write(self.style.WARNING(
f"[{name}] syncing (the first run may take a few minutes)..."

View File

@@ -1,7 +1,8 @@
"""初始化所有内置字典 Wordlist 记录
- 目录扫描默认字典: dir_default.txt -> /app/backend/wordlist/dir_default.txt
- 子域名爆破默认字典: subdomains-top1million-110000.txt -> /app/backend/wordlist/subdomains-top1million-110000.txt
内置字典从镜像内 /app/backend/wordlist/ 复制到运行时目录 /opt/xingrin/wordlists/
- 目录扫描默认字典: dir_default.txt
- 子域名爆破默认字典: subdomains-top1million-110000.txt
可重复执行:如果已存在同名记录且文件有效则跳过,只在缺失或文件丢失时创建/修复。
"""
@@ -13,7 +14,7 @@ from pathlib import Path
from django.conf import settings
from django.core.management.base import BaseCommand
from apps.common.hash_utils import safe_calc_file_sha256
from apps.common.utils import safe_calc_file_sha256
from apps.engine.models import Wordlist

View File

@@ -0,0 +1,19 @@
"""Engine Models
Exports all models of the Engine module
"""
from .engine import WorkerNode, ScanEngine, Wordlist, NucleiTemplateRepo
from .fingerprints import EholeFingerprint, GobyFingerprint, WappalyzerFingerprint
__all__ = [
# Core models
"WorkerNode",
"ScanEngine",
"Wordlist",
"NucleiTemplateRepo",
# Fingerprint models
"EholeFingerprint",
"GobyFingerprint",
"WappalyzerFingerprint",
]

View File

@@ -1,3 +1,8 @@
"""Engine 模块核心 Models
包含 WorkerNode, ScanEngine, Wordlist, NucleiTemplateRepo
"""
from django.db import models
@@ -78,6 +83,7 @@ class ScanEngine(models.Model):
indexes = [
models.Index(fields=['-created_at']),
]
def __str__(self):
return str(self.name or f'ScanEngine {self.id}')

View File

@@ -0,0 +1,108 @@
"""指纹相关 Models
包含 EHole、Goby、Wappalyzer 等指纹格式的数据模型
"""
from django.db import models
class GobyFingerprint(models.Model):
"""Goby 格式指纹规则
Goby 使用逻辑表达式和规则数组进行匹配:
- logic: 逻辑表达式,如 "a||b", "(a&&b)||c"
- rule: 规则数组,每条规则包含 label, feature, is_equal
"""
name = models.CharField(max_length=300, unique=True, help_text='产品名称')
logic = models.CharField(max_length=500, help_text='逻辑表达式')
rule = models.JSONField(default=list, help_text='规则数组')
created_at = models.DateTimeField(auto_now_add=True)
class Meta:
db_table = 'goby_fingerprint'
verbose_name = 'Goby fingerprint'
verbose_name_plural = 'Goby fingerprints'
ordering = ['-created_at']
indexes = [
models.Index(fields=['name']),
models.Index(fields=['logic']),
models.Index(fields=['-created_at']),
]
def __str__(self) -> str:
return f"{self.name} ({self.logic})"
class EholeFingerprint(models.Model):
"""EHole 格式指纹规则(字段与 ehole.json 一致)"""
cms = models.CharField(max_length=200, help_text='产品/CMS名称')
method = models.CharField(max_length=200, default='keyword', help_text='匹配方式')
location = models.CharField(max_length=200, default='body', help_text='匹配位置')
keyword = models.JSONField(default=list, help_text='关键词列表')
is_important = models.BooleanField(default=False, help_text='是否重点资产')
type = models.CharField(max_length=100, blank=True, default='-', help_text='分类')
created_at = models.DateTimeField(auto_now_add=True)
class Meta:
db_table = 'ehole_fingerprint'
verbose_name = 'EHole fingerprint'
verbose_name_plural = 'EHole fingerprints'
ordering = ['-created_at']
indexes = [
# Indexes for search/filter fields
models.Index(fields=['cms']),
models.Index(fields=['method']),
models.Index(fields=['location']),
models.Index(fields=['type']),
models.Index(fields=['is_important']),
# Index for the sort field
models.Index(fields=['-created_at']),
]
constraints = [
# 唯一约束cms + method + location 组合不能重复
models.UniqueConstraint(
fields=['cms', 'method', 'location'],
name='unique_ehole_fingerprint'
),
]
def __str__(self) -> str:
return f"{self.cms} ({self.method}@{self.location})"
class WappalyzerFingerprint(models.Model):
"""Wappalyzer 格式指纹规则
Wappalyzer 支持多种检测方式cookies, headers, scriptSrc, js, meta, html 等
"""
name = models.CharField(max_length=300, unique=True, help_text='应用名称')
cats = models.JSONField(default=list, help_text='分类 ID 数组')
cookies = models.JSONField(default=dict, blank=True, help_text='Cookie 检测规则')
headers = models.JSONField(default=dict, blank=True, help_text='HTTP Header 检测规则')
script_src = models.JSONField(default=list, blank=True, help_text='脚本 URL 正则数组')
js = models.JSONField(default=list, blank=True, help_text='JavaScript 变量检测规则')
implies = models.JSONField(default=list, blank=True, help_text='依赖关系数组')
meta = models.JSONField(default=dict, blank=True, help_text='HTML meta 标签检测规则')
html = models.JSONField(default=list, blank=True, help_text='HTML 内容正则数组')
description = models.TextField(blank=True, default='', help_text='应用描述')
website = models.URLField(max_length=500, blank=True, default='', help_text='官网链接')
cpe = models.CharField(max_length=300, blank=True, default='', help_text='CPE 标识符')
created_at = models.DateTimeField(auto_now_add=True)
class Meta:
db_table = 'wappalyzer_fingerprint'
verbose_name = 'Wappalyzer fingerprint'
verbose_name_plural = 'Wappalyzer fingerprints'
ordering = ['-created_at']
indexes = [
models.Index(fields=['name']),
models.Index(fields=['website']),
models.Index(fields=['cpe']),
models.Index(fields=['-created_at']),
]
def __str__(self) -> str:
return f"{self.name}"

View File

@@ -0,0 +1,14 @@
"""指纹管理 Serializers
导出所有指纹相关的 Serializer 类
"""
from .ehole import EholeFingerprintSerializer
from .goby import GobyFingerprintSerializer
from .wappalyzer import WappalyzerFingerprintSerializer
__all__ = [
"EholeFingerprintSerializer",
"GobyFingerprintSerializer",
"WappalyzerFingerprintSerializer",
]

View File

@@ -0,0 +1,27 @@
"""EHole 指纹 Serializer"""
from rest_framework import serializers
from apps.engine.models import EholeFingerprint
class EholeFingerprintSerializer(serializers.ModelSerializer):
"""EHole 指纹序列化器"""
class Meta:
model = EholeFingerprint
fields = ['id', 'cms', 'method', 'location', 'keyword',
'is_important', 'type', 'created_at']
read_only_fields = ['id', 'created_at']
def validate_cms(self, value):
"""校验 cms 字段"""
if not value or not value.strip():
raise serializers.ValidationError("cms 字段不能为空")
return value.strip()
def validate_keyword(self, value):
"""校验 keyword 字段"""
if not isinstance(value, list):
raise serializers.ValidationError("keyword 必须是数组")
return value

View File

@@ -0,0 +1,26 @@
"""Goby 指纹 Serializer"""
from rest_framework import serializers
from apps.engine.models import GobyFingerprint
class GobyFingerprintSerializer(serializers.ModelSerializer):
"""Goby 指纹序列化器"""
class Meta:
model = GobyFingerprint
fields = ['id', 'name', 'logic', 'rule', 'created_at']
read_only_fields = ['id', 'created_at']
def validate_name(self, value):
"""校验 name 字段"""
if not value or not value.strip():
raise serializers.ValidationError("name 字段不能为空")
return value.strip()
def validate_rule(self, value):
"""校验 rule 字段"""
if not isinstance(value, list):
raise serializers.ValidationError("rule 必须是数组")
return value

View File

@@ -0,0 +1,24 @@
"""Wappalyzer 指纹 Serializer"""
from rest_framework import serializers
from apps.engine.models import WappalyzerFingerprint
class WappalyzerFingerprintSerializer(serializers.ModelSerializer):
"""Wappalyzer 指纹序列化器"""
class Meta:
model = WappalyzerFingerprint
fields = [
'id', 'name', 'cats', 'cookies', 'headers', 'script_src',
'js', 'implies', 'meta', 'html', 'description', 'website',
'cpe', 'created_at'
]
read_only_fields = ['id', 'created_at']
def validate_name(self, value):
"""校验 name 字段"""
if not value or not value.strip():
raise serializers.ValidationError("name 字段不能为空")
return value.strip()

View File

@@ -0,0 +1,16 @@
"""指纹管理 Services
导出所有指纹相关的 Service 类
"""
from .base import BaseFingerprintService
from .ehole import EholeFingerprintService
from .goby import GobyFingerprintService
from .wappalyzer import WappalyzerFingerprintService
__all__ = [
"BaseFingerprintService",
"EholeFingerprintService",
"GobyFingerprintService",
"WappalyzerFingerprintService",
]

View File

@@ -0,0 +1,144 @@
"""指纹管理基类 Service
提供通用的批量操作和缓存逻辑,供 EHole/Goby/Wappalyzer 等子类继承
"""
import json
import logging
from typing import Any
logger = logging.getLogger(__name__)
class BaseFingerprintService:
"""指纹管理基类 Service提供通用的批量操作和缓存逻辑"""
model = None # 子类必须指定
BATCH_SIZE = 1000 # 每批处理数量
def validate_fingerprint(self, item: dict) -> bool:
"""
Validate a single fingerprint; subclasses must implement this.
Args:
item: single fingerprint record
Returns:
bool: whether it is valid
"""
raise NotImplementedError("Subclasses must implement validate_fingerprint")
def validate_fingerprints(self, raw_data: list) -> tuple[list, list]:
"""
Validate fingerprint data in bulk.
Args:
raw_data: list of raw fingerprint records
Returns:
tuple: (valid_items, invalid_items)
"""
valid, invalid = [], []
for item in raw_data:
if self.validate_fingerprint(item):
valid.append(item)
else:
invalid.append(item)
return valid, invalid
def to_model_data(self, item: dict) -> dict:
"""
Convert to model fields; subclasses must implement this.
Args:
item: raw fingerprint record
Returns:
dict: model field data
"""
raise NotImplementedError("Subclasses must implement to_model_data")
def bulk_create(self, fingerprints: list) -> int:
"""
Bulk-create fingerprint records (pre-validated data).
Args:
fingerprints: list of validated fingerprint records
Returns:
int: number of records created
"""
if not fingerprints:
return 0
objects = [self.model(**self.to_model_data(item)) for item in fingerprints]
created = self.model.objects.bulk_create(objects, ignore_conflicts=True)
return len(created)
def batch_create_fingerprints(self, raw_data: list) -> dict:
"""
Full pipeline: validate in batches + bulk create.
Args:
raw_data: list of raw fingerprint records
Returns:
dict: {'created': int, 'failed': int}
"""
total_created = 0
total_failed = 0
for i in range(0, len(raw_data), self.BATCH_SIZE):
batch = raw_data[i:i + self.BATCH_SIZE]
valid, invalid = self.validate_fingerprints(batch)
total_created += self.bulk_create(valid)
total_failed += len(invalid)
logger.info(
"批量创建指纹完成: created=%d, failed=%d, total=%d",
total_created, total_failed, len(raw_data)
)
return {'created': total_created, 'failed': total_failed}
def get_export_data(self) -> dict:
"""
Get export data; subclasses must implement this.
Returns:
dict: exported JSON data
"""
raise NotImplementedError("Subclasses must implement get_export_data")
def export_to_file(self, output_path: str) -> int:
"""
Export all fingerprints to a JSON file.
Args:
output_path: output file path
Returns:
int: number of fingerprints exported
"""
data = self.get_export_data()
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False)
count = len(data.get('fingerprint', []))
logger.info("导出指纹文件: %s, 数量: %d", output_path, count)
return count
def get_fingerprint_version(self) -> str:
"""
Get a version tag for the fingerprint library (used for cache validation).
Returns:
str: version tag, formatted "{count}_{latest_timestamp}"
The version changes when:
- a record is added → count changes
- a record is deleted → count changes
- everything is cleared → count becomes 0
"""
count = self.model.objects.count()
latest = self.model.objects.order_by('-created_at').first()
latest_ts = int(latest.created_at.timestamp()) if latest else 0
return f"{count}_{latest_ts}"

View File

@@ -0,0 +1,84 @@
"""EHole 指纹管理 Service
实现 EHole 格式指纹的校验、转换和导出逻辑
"""
from apps.engine.models import EholeFingerprint
from .base import BaseFingerprintService
class EholeFingerprintService(BaseFingerprintService):
"""EHole 指纹管理服务(继承基类,实现 EHole 特定逻辑)"""
model = EholeFingerprint
def validate_fingerprint(self, item: dict) -> bool:
"""
Validate a single EHole fingerprint.
Rules:
- the cms field must be present and non-empty
- the keyword field must be an array
Args:
item: single fingerprint record
Returns:
bool: whether it is valid
"""
cms = item.get('cms', '')
keyword = item.get('keyword')
return bool(cms and str(cms).strip()) and isinstance(keyword, list)
def to_model_data(self, item: dict) -> dict:
"""
Convert the EHole JSON format to model fields.
Field mapping:
- isImportant (JSON) → is_important (Model)
Args:
item: raw EHole JSON record
Returns:
dict: model field data
"""
return {
'cms': str(item.get('cms', '')).strip(),
'method': item.get('method', 'keyword'),
'location': item.get('location', 'body'),
'keyword': item.get('keyword', []),
'is_important': item.get('isImportant', False),
'type': item.get('type', '-'),
}
def get_export_data(self) -> dict:
"""
获取导出数据EHole JSON 格式)
Returns:
dict: EHole 格式的 JSON 数据
{
"fingerprint": [
{"cms": "...", "method": "...", "location": "...",
"keyword": [...], "isImportant": false, "type": "..."},
...
],
"version": "1000_1703836800"
}
"""
fingerprints = self.model.objects.all()
data = []
for fp in fingerprints:
data.append({
'cms': fp.cms,
'method': fp.method,
'location': fp.location,
'keyword': fp.keyword,
'isImportant': fp.is_important, # converted back to the JSON key
'type': fp.type,
})
return {
'fingerprint': data,
'version': self.get_fingerprint_version(),
}

View File

@@ -0,0 +1,70 @@
"""Goby 指纹管理 Service
实现 Goby 格式指纹的校验、转换和导出逻辑
"""
from apps.engine.models import GobyFingerprint
from .base import BaseFingerprintService
class GobyFingerprintService(BaseFingerprintService):
"""Goby 指纹管理服务(继承基类,实现 Goby 特定逻辑)"""
model = GobyFingerprint
def validate_fingerprint(self, item: dict) -> bool:
"""
Validate a single Goby fingerprint.
Rules:
- the name field must be present and non-empty
- the logic field must be present
- the rule field must be an array
Args:
item: single fingerprint record
Returns:
bool: whether it is valid
"""
name = item.get('name', '')
logic = item.get('logic', '')
rule = item.get('rule')
return bool(name and str(name).strip()) and bool(logic) and isinstance(rule, list)
def to_model_data(self, item: dict) -> dict:
"""
Convert the Goby JSON format to model fields.
Args:
item: raw Goby JSON record
Returns:
dict: model field data
"""
return {
'name': str(item.get('name', '')).strip(),
'logic': item.get('logic', ''),
'rule': item.get('rule', []),
}
def get_export_data(self) -> list:
"""
获取导出数据Goby JSON 格式 - 数组)
Returns:
list: Goby 格式的 JSON 数据(数组格式)
[
{"name": "...", "logic": "...", "rule": [...]},
...
]
"""
fingerprints = self.model.objects.all()
return [
{
'name': fp.name,
'logic': fp.logic,
'rule': fp.rule,
}
for fp in fingerprints
]

View File

@@ -0,0 +1,99 @@
"""Wappalyzer 指纹管理 Service
实现 Wappalyzer 格式指纹的校验、转换和导出逻辑
"""
from apps.engine.models import WappalyzerFingerprint
from .base import BaseFingerprintService
class WappalyzerFingerprintService(BaseFingerprintService):
"""Wappalyzer 指纹管理服务(继承基类,实现 Wappalyzer 特定逻辑)"""
model = WappalyzerFingerprint
def validate_fingerprint(self, item: dict) -> bool:
"""
Validate a single Wappalyzer fingerprint.
Rules:
- the name field must be present and non-empty (injected from the apps object's key)
Args:
item: single fingerprint record
Returns:
bool: whether it is valid
"""
name = item.get('name', '')
return bool(name and str(name).strip())
def to_model_data(self, item: dict) -> dict:
"""
Convert the Wappalyzer JSON format to model fields.
Field mapping:
- scriptSrc (JSON) → script_src (Model)
Args:
item: raw Wappalyzer JSON record
Returns:
dict: model field data
"""
return {
'name': str(item.get('name', '')).strip(),
'cats': item.get('cats', []),
'cookies': item.get('cookies', {}),
'headers': item.get('headers', {}),
'script_src': item.get('scriptSrc', []), # JSON: scriptSrc -> Model: script_src
'js': item.get('js', []),
'implies': item.get('implies', []),
'meta': item.get('meta', {}),
'html': item.get('html', []),
'description': item.get('description', ''),
'website': item.get('website', ''),
'cpe': item.get('cpe', ''),
}
def get_export_data(self) -> dict:
"""
获取导出数据Wappalyzer JSON 格式)
Returns:
dict: Wappalyzer 格式的 JSON 数据
{
"apps": {
"AppName": {"cats": [...], "cookies": {...}, ...},
...
}
}
"""
fingerprints = self.model.objects.all()
apps = {}
for fp in fingerprints:
app_data = {}
if fp.cats:
app_data['cats'] = fp.cats
if fp.cookies:
app_data['cookies'] = fp.cookies
if fp.headers:
app_data['headers'] = fp.headers
if fp.script_src:
app_data['scriptSrc'] = fp.script_src # Model: script_src -> JSON: scriptSrc
if fp.js:
app_data['js'] = fp.js
if fp.implies:
app_data['implies'] = fp.implies
if fp.meta:
app_data['meta'] = fp.meta
if fp.html:
app_data['html'] = fp.html
if fp.description:
app_data['description'] = fp.description
if fp.website:
app_data['website'] = fp.website
if fp.cpe:
app_data['cpe'] = fp.cpe
apps[fp.name] = app_data
return {'apps': apps}

View File

@@ -186,6 +186,7 @@ class NucleiTemplateRepoService:
RuntimeError: Git command execution failed
"""
import subprocess
from apps.common.utils.git_proxy import get_git_proxy_url
obj = self._get_repo_obj(repo_id)
@@ -196,9 +197,14 @@ class NucleiTemplateRepoService:
cmd: List[str]
action: str
# Get the proxied URL (when Git acceleration is enabled)
proxied_url = get_git_proxy_url(obj.repo_url)
if proxied_url != obj.repo_url:
logger.info("Using Git acceleration: %s -> %s", obj.repo_url, proxied_url)
# Decide between clone and pull
if git_dir.is_dir():
# Check whether the remote URL has changed
# Check whether the remote URL has changed (compare the original URL, not the proxied one)
current_remote = subprocess.run(
["git", "-C", str(local_path), "remote", "get-url", "origin"],
check=False,
@@ -208,12 +214,13 @@ class NucleiTemplateRepoService:
)
current_url = current_remote.stdout.strip() if current_remote.returncode == 0 else ""
if current_url != obj.repo_url:
# Check whether a re-clone is needed (a change in either the original or the proxied URL requires one)
if current_url not in [obj.repo_url, proxied_url]:
# Remote URL changed; remove the old directory and re-clone
logger.info("nuclei template repo %s remote URL changed, re-cloning: %s -> %s", obj.id, current_url, obj.repo_url)
shutil.rmtree(local_path)
local_path.mkdir(parents=True, exist_ok=True)
cmd = ["git", "clone", "--depth", "1", obj.repo_url, str(local_path)]
cmd = ["git", "clone", "--depth", "1", proxied_url, str(local_path)]
action = "clone"
else:
# Repo exists and the URL is unchanged; pull
@@ -224,7 +231,7 @@ class NucleiTemplateRepoService:
if local_path.exists() and not local_path.is_dir():
raise RuntimeError(f"本地路径已存在且不是目录: {local_path}")
# --depth 1 浅克隆,只获取最新提交,节省空间和时间
cmd = ["git", "clone", "--depth", "1", obj.repo_url, str(local_path)]
cmd = ["git", "clone", "--depth", "1", proxied_url, str(local_path)]
action = "clone"
# Run the Git command

View File

@@ -76,8 +76,8 @@ class TaskDistributor:
self.docker_image = settings.TASK_EXECUTOR_IMAGE
if not self.docker_image:
raise ValueError("TASK_EXECUTOR_IMAGE 未配置,请确保 IMAGE_TAG 环境变量已设置")
self.results_mount = getattr(settings, 'CONTAINER_RESULTS_MOUNT', '/app/backend/results')
self.logs_mount = getattr(settings, 'CONTAINER_LOGS_MOUNT', '/app/backend/logs')
# Use paths under /opt/xingrin uniformly
self.logs_mount = "/opt/xingrin/logs"
self.submit_interval = getattr(settings, 'TASK_SUBMIT_INTERVAL', 5)
def get_online_workers(self) -> list[WorkerNode]:
@@ -153,30 +153,68 @@ class TaskDistributor:
else:
scored_workers.append((worker, score, cpu, mem))
# Fallback: if no worker has normal load, wait and re-select
# Fallback: if no worker has normal load, loop, wait, and re-check
if not scored_workers:
if high_load_workers:
# Wait first under high load to give the system room to breathe (default 60 seconds)
# High-load wait parameters (default: check every 60 seconds, at most 10 times)
high_load_wait = getattr(settings, 'HIGH_LOAD_WAIT_SECONDS', 60)
logger.warning("All workers under high load, retrying in %d seconds...", high_load_wait)
time.sleep(high_load_wait)
high_load_max_retries = getattr(settings, 'HIGH_LOAD_MAX_RETRIES', 10)
# Re-select (recursive call; load may have dropped by now)
# To avoid unbounded recursion, pick the least-loaded of the high-load workers directly
# Send the high-load notification before starting to wait
high_load_workers.sort(key=lambda x: x[1])
best_worker, _, cpu, mem = high_load_workers[0]
# Send the high-load notification
_, _, first_cpu, first_mem = high_load_workers[0]
from apps.common.signals import all_workers_high_load
all_workers_high_load.send(
sender=self.__class__,
worker_name=best_worker.name,
cpu=cpu,
mem=mem
worker_name="所有节点",
cpu=first_cpu,
mem=first_mem
)
logger.info("选择 Worker: %s (CPU: %.1f%%, MEM: %.1f%%)", best_worker.name, cpu, mem)
return best_worker
for retry in range(high_load_max_retries):
logger.warning(
"所有 Worker 高负载,等待 %d 秒后重试... (%d/%d)",
high_load_wait, retry + 1, high_load_max_retries
)
time.sleep(high_load_wait)
# Re-fetch load data
loads = worker_load_service.get_all_loads(worker_ids)
# Re-evaluate
scored_workers = []
high_load_workers = []
for worker in workers:
load = loads.get(worker.id)
if not load:
continue
cpu = load.get('cpu', 0)
mem = load.get('mem', 0)
score = cpu * 0.7 + mem * 0.3
if cpu > 85 or mem > 85:
high_load_workers.append((worker, score, cpu, mem))
else:
scored_workers.append((worker, score, cpu, mem))
# A worker with normal load appeared; stop waiting
if scored_workers:
logger.info("Worker with normal load detected, ending the wait")
break
# Timed out or still high load; pick the least-loaded worker
if not scored_workers and high_load_workers:
high_load_workers.sort(key=lambda x: x[1])
best_worker, _, cpu, mem = high_load_workers[0]
logger.warning(
"等待超时,强制分发到高负载 Worker: %s (CPU: %.1f%%, MEM: %.1f%%)",
best_worker.name, cpu, mem
)
return best_worker
else:
logger.warning("没有可用的 Worker")
return None
@@ -234,11 +272,10 @@ class TaskDistributor:
else:
# Remote: access via the Nginx reverse proxy (HTTPS, no direct connection to port 8888)
network_arg = ""
server_url = f"https://{settings.PUBLIC_HOST}"
server_url = f"https://{settings.PUBLIC_HOST}:{settings.PUBLIC_PORT}"
# Mount paths (fixed paths on every node)
host_results_dir = settings.HOST_RESULTS_DIR # /opt/xingrin/results
host_logs_dir = settings.HOST_LOGS_DIR # /opt/xingrin/logs
# Mount path (mount all of /opt/xingrin)
host_xingrin_dir = "/opt/xingrin"
# 环境变量SERVER_URL + IS_LOCAL其他配置容器启动时从配置中心获取
# IS_LOCAL 用于 Worker 向配置中心声明身份,决定返回的数据库地址
@@ -251,15 +288,12 @@ class TaskDistributor:
"-e PREFECT_SERVER_EPHEMERAL_ENABLED=true", # 启用 ephemeral server本地临时服务器
"-e PREFECT_SERVER_EPHEMERAL_STARTUP_TIMEOUT_SECONDS=120", # 增加启动超时时间
"-e PREFECT_SERVER_DATABASE_CONNECTION_URL=sqlite+aiosqlite:////tmp/.prefect/prefect.db", # 使用 /tmp 下的 SQLite
"-e PREFECT_LOGGING_LEVEL=DEBUG", # 启用 DEBUG 级别日志
"-e PREFECT_LOGGING_SERVER_LEVEL=DEBUG", # Server 日志级别
"-e PREFECT_DEBUG_MODE=true", # 启用调试模式
"-e PREFECT_LOGGING_LEVEL=WARNING", # 日志级别(减少 DEBUG 噪音)
]
# Mount volumes
# Mount volumes (mount the whole /opt/xingrin directory)
volumes = [
f"-v {host_results_dir}:{self.results_mount}",
f"-v {host_logs_dir}:{self.logs_mount}",
f"-v {host_xingrin_dir}:{host_xingrin_dir}",
]
# Build command-line arguments
@@ -520,7 +554,7 @@ class TaskDistributor:
try:
# Build the docker run command (cleans up expired scan result directories)
script_args = {
'results_dir': '/app/backend/results',
'results_dir': '/opt/xingrin/results',
'retention_days': retention_days,
}

View File

@@ -13,7 +13,7 @@ from django.conf import settings
from django.core.exceptions import ValidationError
from django.core.files.uploadedfile import UploadedFile
from apps.common.hash_utils import safe_calc_file_sha256
from apps.common.utils import safe_calc_file_sha256
from apps.engine.models import Wordlist
from apps.engine.repositories import DjangoWordlistRepository

View File

@@ -7,6 +7,11 @@ from .views import (
WordlistViewSet,
NucleiTemplateRepoViewSet,
)
from .views.fingerprints import (
EholeFingerprintViewSet,
GobyFingerprintViewSet,
WappalyzerFingerprintViewSet,
)
# Create the router
@@ -15,6 +20,10 @@ router.register(r"engines", ScanEngineViewSet, basename="engine")
router.register(r"workers", WorkerNodeViewSet, basename="worker")
router.register(r"wordlists", WordlistViewSet, basename="wordlist")
router.register(r"nuclei/repos", NucleiTemplateRepoViewSet, basename="nuclei-repos")
# Fingerprint management
router.register(r"fingerprints/ehole", EholeFingerprintViewSet, basename="ehole-fingerprint")
router.register(r"fingerprints/goby", GobyFingerprintViewSet, basename="goby-fingerprint")
router.register(r"fingerprints/wappalyzer", WappalyzerFingerprintViewSet, basename="wappalyzer-fingerprint")
urlpatterns = [
path("", include(router.urls)),

View File

@@ -0,0 +1,16 @@
"""指纹管理 ViewSets
导出所有指纹相关的 ViewSet 类
"""
from .base import BaseFingerprintViewSet
from .ehole import EholeFingerprintViewSet
from .goby import GobyFingerprintViewSet
from .wappalyzer import WappalyzerFingerprintViewSet
__all__ = [
"BaseFingerprintViewSet",
"EholeFingerprintViewSet",
"GobyFingerprintViewSet",
"WappalyzerFingerprintViewSet",
]

View File

@@ -0,0 +1,202 @@
"""指纹管理基类 ViewSet
提供通用的 CRUD 和批量操作,供 EHole/Goby/Wappalyzer 等子类继承
"""
import json
import logging
from django.http import HttpResponse
from rest_framework import viewsets, status, filters
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.exceptions import ValidationError
from apps.common.pagination import BasePagination
from apps.common.utils.filter_utils import apply_filters
logger = logging.getLogger(__name__)
class BaseFingerprintViewSet(viewsets.ModelViewSet):
"""指纹管理基类 ViewSet供 EHole/Goby/Wappalyzer 等子类继承
提供的 API
标准 CRUD继承自 ModelViewSet
- GET / 列表查询(分页 + 智能过滤)
- POST / 创建单条
- GET /{id}/ 获取详情
- PUT /{id}/ 更新
- DELETE /{id}/ 删除
批量操作(本类实现):
- POST /batch_create/ 批量创建JSON body
- POST /import_file/ 文件导入multipart/form-data适合 10MB+ 大文件)
- POST /bulk-delete/ 批量删除
- POST /delete-all/ 删除所有
- GET /export/ 导出下载
智能过滤语法filter 参数):
- field="value" 模糊匹配(包含)
- field=="value" 精确匹配
- 多条件空格分隔 AND 关系
- || 或 or OR 关系
子类必须实现:
- service_class Service 类
- parse_import_data 解析导入数据格式
- get_export_filename 导出文件名
"""
pagination_class = BasePagination
filter_backends = [filters.OrderingFilter]
ordering = ['-created_at']
# Subclasses must set this
service_class = None # Service class
# Smart filter field mapping; subclasses must override
FILTER_FIELD_MAPPING = {}
# JSON array fields (queried with __contains); subclasses may override
JSON_ARRAY_FIELDS = []
def get_queryset(self):
"""支持智能过滤语法"""
queryset = super().get_queryset()
filter_query = self.request.query_params.get('filter', None)
if filter_query:
queryset = apply_filters(
queryset,
filter_query,
self.FILTER_FIELD_MAPPING,
json_array_fields=getattr(self, 'JSON_ARRAY_FIELDS', [])
)
return queryset
def get_service(self):
"""获取 Service 实例"""
if self.service_class is None:
raise NotImplementedError("子类必须指定 service_class")
return self.service_class()
def parse_import_data(self, json_data: dict) -> list:
"""
Parse import data; subclasses must implement this.
Args:
json_data: parsed JSON data
Returns:
list: list of fingerprint records
"""
raise NotImplementedError("Subclasses must implement parse_import_data")
def get_export_filename(self) -> str:
"""
Export file name; subclasses must implement this.
Returns:
str: file name
"""
raise NotImplementedError("Subclasses must implement get_export_filename")
@action(detail=False, methods=['post'])
def batch_create(self, request):
"""
Bulk-create fingerprint rules.
POST /api/engine/fingerprints/{type}/batch_create/
Request body:
{
"fingerprints": [
{"cms": "WordPress", "method": "keyword", ...},
...
]
}
Returns:
{
"created": 2,
"failed": 0
}
"""
fingerprints = request.data.get('fingerprints', [])
if not fingerprints:
raise ValidationError('fingerprints cannot be empty')
if not isinstance(fingerprints, list):
raise ValidationError('fingerprints must be an array')
result = self.get_service().batch_create_fingerprints(fingerprints)
return Response(result, status=status.HTTP_201_CREATED)
@action(detail=False, methods=['post'])
def import_file(self, request):
"""
文件导入适合大文件10MB+
POST /api/engine/fingerprints/{type}/import_file/
请求格式multipart/form-data
- file: JSON 文件
返回:同 batch_create
"""
file = request.FILES.get('file')
if not file:
raise ValidationError('缺少文件')
try:
json_data = json.load(file)
except json.JSONDecodeError as e:
raise ValidationError(f'无效的 JSON 格式: {e}')
fingerprints = self.parse_import_data(json_data)
if not fingerprints:
raise ValidationError('文件中没有有效的指纹数据')
result = self.get_service().batch_create_fingerprints(fingerprints)
return Response(result, status=status.HTTP_201_CREATED)
@action(detail=False, methods=['post'], url_path='bulk-delete')
def bulk_delete(self, request):
"""
Bulk delete.
POST /api/engine/fingerprints/{type}/bulk-delete/
Request body: {"ids": [1, 2, 3]}
Returns: {"deleted": 3}
"""
ids = request.data.get('ids', [])
if not ids:
raise ValidationError('ids cannot be empty')
if not isinstance(ids, list):
raise ValidationError('ids must be an array')
deleted_count = self.queryset.model.objects.filter(id__in=ids).delete()[0]
return Response({'deleted': deleted_count})
@action(detail=False, methods=['post'], url_path='delete-all')
def delete_all(self, request):
"""
Delete all fingerprints.
POST /api/engine/fingerprints/{type}/delete-all/
Returns: {"deleted": 1000}
"""
deleted_count = self.queryset.model.objects.all().delete()[0]
return Response({'deleted': deleted_count})
@action(detail=False, methods=['get'])
def export(self, request):
"""
Export fingerprints (frontend download).
GET /api/engine/fingerprints/{type}/export/
Returns: a JSON file download
"""
data = self.get_service().get_export_data()
content = json.dumps(data, ensure_ascii=False, indent=2)
response = HttpResponse(content, content_type='application/json')
response['Content-Disposition'] = f'attachment; filename="{self.get_export_filename()}"'
return response
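For illustration, the smart filter syntax reaches these endpoints as an ordinary query parameter (the host and URL prefix are assumptions):

import requests

# Fuzzy match on cms, exact match on type; whitespace means AND
resp = requests.get(
    'https://xingrin.example.com/api/engine/fingerprints/ehole/',
    params={'filter': 'cms="word" type=="CMS"'},
)
print(resp.json())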

View File

@@ -0,0 +1,67 @@
"""EHole 指纹管理 ViewSet"""
from apps.common.pagination import BasePagination
from apps.engine.models import EholeFingerprint
from apps.engine.serializers.fingerprints import EholeFingerprintSerializer
from apps.engine.services.fingerprints import EholeFingerprintService
from .base import BaseFingerprintViewSet
class EholeFingerprintViewSet(BaseFingerprintViewSet):
"""EHole 指纹管理 ViewSet
继承自 BaseFingerprintViewSet提供以下 API
标准 CRUDModelViewSet
- GET / 列表查询(分页)
- POST / 创建单条
- GET /{id}/ 获取详情
- PUT /{id}/ 更新
- DELETE /{id}/ 删除
批量操作(继承自基类):
- POST /batch_create/ 批量创建JSON body
- POST /import_file/ 文件导入multipart/form-data
- POST /bulk-delete/ 批量删除
- POST /delete-all/ 删除所有
- GET /export/ 导出下载
智能过滤语法filter 参数):
- cms="word" 模糊匹配 cms 字段
- cms=="WordPress" 精确匹配
- type="CMS" 按类型筛选
- method="keyword" 按匹配方式筛选
- location="body" 按匹配位置筛选
"""
queryset = EholeFingerprint.objects.all()
serializer_class = EholeFingerprintSerializer
pagination_class = BasePagination
service_class = EholeFingerprintService
# Ordering configuration
ordering_fields = ['created_at', 'cms']
ordering = ['-created_at']
# EHole filter field mapping
FILTER_FIELD_MAPPING = {
'cms': 'cms',
'method': 'method',
'location': 'location',
'type': 'type',
'isImportant': 'is_important',
}
def parse_import_data(self, json_data: dict) -> list:
"""
Parse import data in EHole JSON format.
Input format: {"fingerprint": [...]}
Returns: list of fingerprints
"""
return json_data.get('fingerprint', [])
def get_export_filename(self) -> str:
"""导出文件名"""
return 'ehole.json'

View File

@@ -0,0 +1,65 @@
"""Goby 指纹管理 ViewSet"""
from apps.common.pagination import BasePagination
from apps.engine.models import GobyFingerprint
from apps.engine.serializers.fingerprints import GobyFingerprintSerializer
from apps.engine.services.fingerprints import GobyFingerprintService
from .base import BaseFingerprintViewSet
class GobyFingerprintViewSet(BaseFingerprintViewSet):
"""Goby 指纹管理 ViewSet
继承自 BaseFingerprintViewSet提供以下 API
标准 CRUDModelViewSet
- GET / 列表查询(分页)
- POST / 创建单条
- GET /{id}/ 获取详情
- PUT /{id}/ 更新
- DELETE /{id}/ 删除
批量操作(继承自基类):
- POST /batch_create/ 批量创建JSON body
- POST /import_file/ 文件导入multipart/form-data
- POST /bulk-delete/ 批量删除
- POST /delete-all/ 删除所有
- GET /export/ 导出下载
智能过滤语法filter 参数):
- name="word" 模糊匹配 name 字段
- name=="ProductName" 精确匹配
"""
queryset = GobyFingerprint.objects.all()
serializer_class = GobyFingerprintSerializer
pagination_class = BasePagination
service_class = GobyFingerprintService
# Ordering configuration
ordering_fields = ['created_at', 'name']
ordering = ['-created_at']
# Goby filter field mapping
FILTER_FIELD_MAPPING = {
'name': 'name',
'logic': 'logic',
}
def parse_import_data(self, json_data) -> list:
"""
Parse import data in Goby JSON format.
Goby uses an array format: [{...}, {...}, ...]
Input format: [{"name": "...", "logic": "...", "rule": [...]}, ...]
Returns: list of fingerprints
"""
if isinstance(json_data, list):
return json_data
return []
def get_export_filename(self) -> str:
"""导出文件名"""
return 'goby.json'

View File

@@ -0,0 +1,75 @@
"""Wappalyzer 指纹管理 ViewSet"""
from apps.common.pagination import BasePagination
from apps.engine.models import WappalyzerFingerprint
from apps.engine.serializers.fingerprints import WappalyzerFingerprintSerializer
from apps.engine.services.fingerprints import WappalyzerFingerprintService
from .base import BaseFingerprintViewSet
class WappalyzerFingerprintViewSet(BaseFingerprintViewSet):
"""Wappalyzer 指纹管理 ViewSet
继承自 BaseFingerprintViewSet提供以下 API
标准 CRUDModelViewSet
- GET / 列表查询(分页)
- POST / 创建单条
- GET /{id}/ 获取详情
- PUT /{id}/ 更新
- DELETE /{id}/ 删除
批量操作(继承自基类):
- POST /batch_create/ 批量创建JSON body
- POST /import_file/ 文件导入multipart/form-data
- POST /bulk-delete/ 批量删除
- POST /delete-all/ 删除所有
- GET /export/ 导出下载
智能过滤语法filter 参数):
- name="word" 模糊匹配 name 字段
- name=="AppName" 精确匹配
"""
queryset = WappalyzerFingerprint.objects.all()
serializer_class = WappalyzerFingerprintSerializer
pagination_class = BasePagination
service_class = WappalyzerFingerprintService
# Ordering configuration
ordering_fields = ['created_at', 'name']
ordering = ['-created_at']
# Wappalyzer filter field mapping
# Note: implies is a JSON array field, queried with __contains
FILTER_FIELD_MAPPING = {
'name': 'name',
'description': 'description',
'website': 'website',
'cpe': 'cpe',
'implies': 'implies', # JSON array field
}
# JSON array fields (queried with __contains)
JSON_ARRAY_FIELDS = ['implies']
def parse_import_data(self, json_data: dict) -> list:
"""
Parse import data in Wappalyzer JSON format.
Wappalyzer uses an apps object: {"apps": {"AppName": {...}, ...}}
Input format: {"apps": {"1C-Bitrix": {"cats": [...], ...}, ...}}
Returns: list of fingerprints (each app becomes a dict with a name field)
"""
apps = json_data.get('apps', {})
fingerprints = []
for name, data in apps.items():
item = {'name': name, **data}
fingerprints.append(item)
return fingerprints
def get_export_filename(self) -> str:
"""导出文件名"""
return 'wappalyzer.json'

View File

@@ -238,7 +238,7 @@ class WorkerNodeViewSet(viewsets.ModelViewSet):
docker run -d --pull=always \
--name xingrin-agent \
--restart always \
-e HEARTBEAT_API_URL="https://{django_settings.PUBLIC_HOST}" \
-e HEARTBEAT_API_URL="https://{django_settings.PUBLIC_HOST}:{getattr(django_settings, 'PUBLIC_PORT', '8083')}" \
-e WORKER_ID="{worker_id}" \
-e IMAGE_TAG="{target_version}" \
-v /proc:/host/proc:ro \
@@ -390,12 +390,14 @@ class WorkerNodeViewSet(viewsets.ModelViewSet):
},
'redisUrl': worker_redis_url,
'paths': {
'results': getattr(settings, 'CONTAINER_RESULTS_MOUNT', '/app/backend/results'),
'logs': getattr(settings, 'CONTAINER_LOGS_MOUNT', '/app/backend/logs'),
'results': getattr(settings, 'CONTAINER_RESULTS_MOUNT', '/opt/xingrin/results'),
'logs': getattr(settings, 'CONTAINER_LOGS_MOUNT', '/opt/xingrin/logs'),
},
'logging': {
'level': os.getenv('LOG_LEVEL', 'INFO'),
'enableCommandLogging': os.getenv('ENABLE_COMMAND_LOGGING', 'true').lower() == 'true',
},
'debug': settings.DEBUG
'debug': settings.DEBUG,
# Git mirror configuration (used to accelerate git clone, e.g. for Nuclei template repos)
'gitMirror': os.getenv('GIT_MIRROR', ''),
})

View File
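The HEARTBEAT_API_URL hunk above relies on a getattr fallback so the generated docker run command always carries an explicit port; a minimal sketch of the pattern (the settings values here are assumptions):

# Hypothetical settings object for illustration only.
class django_settings:
    PUBLIC_HOST = "scan.example.com"  # assumed value
    # PUBLIC_PORT intentionally absent

# getattr falls back to '8083' when PUBLIC_PORT is not configured.
port = getattr(django_settings, 'PUBLIC_PORT', '8083')
assert f"https://{django_settings.PUBLIC_HOST}:{port}" == "https://scan.example.com:8083"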

@@ -16,7 +16,7 @@ SUBDOMAIN_DISCOVERY_COMMANDS = {
# Use all data sources by default (more thorough, slightly slower) and always enable recursion
# -all use all data sources
# -recursive enable recursive enumeration for sources that support it (on by default)
'base': 'subfinder -d {domain} -all -recursive -o {output_file} -silent',
'base': "subfinder -d {domain} -all -recursive -o '{output_file}' -silent",
'optional': {
'threads': '-t {threads}', # controls the number of concurrent goroutines
}
@@ -25,31 +25,31 @@ SUBDOMAIN_DISCOVERY_COMMANDS = {
'amass_passive': {
# Run passive enumeration first, writing results into amass's internal database, then export plain domain names (names) from the database to output_file
# -silent disable the progress bar and other output
'base': 'amass enum -passive -silent -d {domain} && amass subs -names -d {domain} > {output_file}'
'base': "amass enum -passive -silent -d {domain} && amass subs -names -d {domain} > '{output_file}'"
},
'amass_active': {
# Run active enumeration + bruteforce first, writing results into amass's internal database, then export plain domain names (names) from the database to output_file
# -silent disable the progress bar and other output
'base': 'amass enum -active -silent -d {domain} -brute && amass subs -names -d {domain} > {output_file}'
'base': "amass enum -active -silent -d {domain} -brute && amass subs -names -d {domain} > '{output_file}'"
},
'sublist3r': {
'base': 'python3 {scan_tools_base}/Sublist3r/sublist3r.py -d {domain} -o {output_file}',
'base': "python3 '{scan_tools_base}/Sublist3r/sublist3r.py' -d {domain} -o '{output_file}'",
'optional': {
'threads': '-t {threads}'
}
},
'assetfinder': {
'base': 'assetfinder --subs-only {domain} > {output_file}',
'base': "assetfinder --subs-only {domain} > '{output_file}'",
},
# === Active wordlist bruteforce ===
'subdomain_bruteforce': {
# DNS bruteforce of the target domain using a wordlist
# -d target domain, -w wordlist file, -o output file
'base': 'puredns bruteforce {wordlist} {domain} -r /app/backend/resources/resolvers.txt --write {output_file} --quiet',
'base': "puredns bruteforce '{wordlist}' {domain} -r /app/backend/resources/resolvers.txt --write '{output_file}' --quiet",
'optional': {},
},
@@ -57,7 +57,7 @@ SUBDOMAIN_DISCOVERY_COMMANDS = {
'subdomain_resolve': {
# Verify that subdomains resolve (liveness check)
# Input is a candidate subdomain list; output is the list of live subdomains
'base': 'puredns resolve {input_file} -r /app/backend/resources/resolvers.txt --write {output_file} --wildcard-tests 50 --wildcard-batch 1000000 --quiet',
'base': "puredns resolve '{input_file}' -r /app/backend/resources/resolvers.txt --write '{output_file}' --wildcard-tests 50 --wildcard-batch 1000000 --quiet",
'optional': {},
},
@@ -65,7 +65,7 @@ SUBDOMAIN_DISCOVERY_COMMANDS = {
'subdomain_permutation_resolve': {
# Streaming pipeline: dnsgen generates permuted domains | puredns resolve verifies liveness
# No intermediate files on disk, avoiding memory blowups; no wildcard filtering
'base': 'cat {input_file} | dnsgen - | puredns resolve -r /app/backend/resources/resolvers.txt --write {output_file} --wildcard-tests 50 --wildcard-batch 1000000 --quiet',
'base': "cat '{input_file}' | dnsgen - | puredns resolve -r /app/backend/resources/resolvers.txt --write '{output_file}' --wildcard-tests 50 --wildcard-batch 1000000 --quiet",
'optional': {},
},
}
@@ -75,7 +75,7 @@ SUBDOMAIN_DISCOVERY_COMMANDS = {
PORT_SCAN_COMMANDS = {
'naabu_active': {
'base': 'naabu -exclude-cdn -warm-up-time 5 -verify -list {domains_file} -json -silent',
'base': "naabu -exclude-cdn -warm-up-time 5 -verify -list '{domains_file}' -json -silent",
'optional': {
'threads': '-c {threads}',
'ports': '-p {ports}',
@@ -85,7 +85,7 @@ PORT_SCAN_COMMANDS = {
},
'naabu_passive': {
'base': 'naabu -list {domains_file} -passive -json -silent'
'base': "naabu -list '{domains_file}' -passive -json -silent"
},
}
@@ -95,7 +95,7 @@ PORT_SCAN_COMMANDS = {
SITE_SCAN_COMMANDS = {
'httpx': {
'base': (
'httpx -l {url_file} '
"'{scan_tools_base}/httpx' -l '{url_file}' "
'-status-code -content-type -content-length '
'-location -title -server -body-preview '
'-tech-detect -cdn -vhost '
@@ -115,7 +115,7 @@ SITE_SCAN_COMMANDS = {
DIRECTORY_SCAN_COMMANDS = {
'ffuf': {
'base': 'ffuf -u {url}/FUZZ -se -ac -sf -json -w {wordlist}',
'base': "ffuf -u '{url}FUZZ' -se -ac -sf -json -w '{wordlist}'",
'optional': {
'delay': '-p {delay}',
'threads': '-t {threads}',
@@ -131,13 +131,13 @@ DIRECTORY_SCAN_COMMANDS = {
URL_FETCH_COMMANDS = {
'waymore': {
'base': 'waymore -i {domain_name} -mode U -oU {output_file}',
'base': "waymore -i {domain_name} -mode U -oU '{output_file}'",
'input_type': 'domain_name'
},
'katana': {
'base': (
'katana -list {sites_file} -o {output_file} '
"katana -list '{sites_file}' -o '{output_file}' "
'-jc ' # enable JavaScript crawling + auto-parse every endpoint inside .js files (most important)
'-xhr ' # additionally extract XHR/Fetch API paths from JS (digs up another 10-20% of hidden endpoints)
'-kf all ' # auto-fuzz known sensitive files under every directory (.env, .git, backup, config, ds_store, etc., 5000+ entries)
@@ -157,7 +157,7 @@ URL_FETCH_COMMANDS = {
},
'uro': {
'base': 'uro -i {input_file} -o {output_file}',
'base': "uro -i '{input_file}' -o '{output_file}'",
'optional': {
'whitelist': '-w {whitelist}', # keep only URLs with the given extensions (space-separated)
'blacklist': '-b {blacklist}', # exclude URLs with the given extensions (space-separated)
@@ -167,7 +167,7 @@ URL_FETCH_COMMANDS = {
'httpx': {
'base': (
'httpx -l {url_file} '
"'{scan_tools_base}/httpx' -l '{url_file}' "
'-status-code -content-type -content-length '
'-location -title -server -body-preview '
'-tech-detect -cdn -vhost '
@@ -187,7 +187,7 @@ VULN_SCAN_COMMANDS = {
'base': (
'dalfox --silence --no-color --no-spinner '
'--skip-bav '
'file {endpoints_file} '
"file '{endpoints_file}' "
'--waf-evasion '
'--format json'
),
@@ -205,11 +205,11 @@ VULN_SCAN_COMMANDS = {
},
'nuclei': {
# nuclei vulnerability scanning
# -j: JSON output
# -j: JSON output (one complete JSON object per line)
# -silent: silent mode
# -l: input URL list file
# -t: template directory path (multiple repos supported; repeated -t flags are concatenated via template_args)
'base': 'nuclei -j -silent -l {endpoints_file} {template_args}',
'base': "nuclei -j -silent -l '{endpoints_file}' {template_args}",
'optional': {
'concurrency': '-c {concurrency}', # concurrency (default 25)
'rate_limit': '-rl {rate_limit}', # requests-per-second limit
@@ -225,12 +225,32 @@ VULN_SCAN_COMMANDS = {
}
# ==================== Fingerprint detection ====================
FINGERPRINT_DETECT_COMMANDS = {
'xingfinger': {
# Streaming output mode (no -o; results go to stdout)
# -l: URL list file input
# -s: silent mode, only print hits
# --json: JSON output (one record per line)
'base': "xingfinger -l '{urls_file}' -s --json",
'optional': {
# custom fingerprint library paths
'ehole': '--ehole {ehole}',
'goby': '--goby {goby}',
'wappalyzer': '--wappalyzer {wappalyzer}',
}
},
}
# ==================== Tool mapping ====================
COMMAND_TEMPLATES = {
'subdomain_discovery': SUBDOMAIN_DISCOVERY_COMMANDS,
'port_scan': PORT_SCAN_COMMANDS,
'site_scan': SITE_SCAN_COMMANDS,
'fingerprint_detect': FINGERPRINT_DETECT_COMMANDS,
'directory_scan': DIRECTORY_SCAN_COMMANDS,
'url_fetch': URL_FETCH_COMMANDS,
'vuln_scan': VULN_SCAN_COMMANDS,
@@ -242,7 +262,7 @@ COMMAND_TEMPLATES = {
EXECUTION_STAGES = [
{
'mode': 'sequential',
'flows': ['subdomain_discovery', 'port_scan', 'site_scan']
'flows': ['subdomain_discovery', 'port_scan', 'site_scan', 'fingerprint_detect']
},
{
'mode': 'parallel',

View File
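The quoting changes in this file wrap file-path placeholders in single quotes so paths with spaces survive shell parsing; a minimal sketch of how a template plus its optional flags could be rendered (build_scan_command's actual implementation in apps.scan.utils is not shown, so this rendering is an assumption):

# Hypothetical rendering; the real build_scan_command may differ.
base = "subfinder -d {domain} -all -recursive -o '{output_file}' -silent"
optional = {'threads': '-t {threads}'}

params = {'domain': 'example.com', 'output_file': '/tmp/scan results/subs.txt', 'threads': 10}
cmd = base.format(**params)
for key, flag in optional.items():
    if key in params:  # optional flags are appended only when configured
        cmd += ' ' + flag.format(**params)
print(cmd)
# subfinder -d example.com -all -recursive -o '/tmp/scan results/subs.txt' -silent -t 10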

@@ -1,7 +1,8 @@
# Engine configuration
#
# Parameter naming: use hyphens consistently (e.g. rate-limit); the system converts them to underscores automatically
# Required parameters: enabled (on/off), timeout (seconds; auto means auto-calculated)
# Required parameter: enabled (on/off)
# Optional parameter: timeout (seconds; defaults to auto, auto-calculated)
# ==================== Subdomain discovery ====================
#
@@ -15,31 +16,31 @@ subdomain_discovery:
passive_tools:
subfinder:
enabled: true
timeout: 7200 # 2 hours
timeout: 3600 # 1 hour
# threads: 10 # concurrent goroutines
amass_passive:
enabled: true
timeout: 7200
timeout: 3600
amass_active:
enabled: true # active enumeration + bruteforce
timeout: 7200
timeout: 3600
sublist3r:
enabled: true
timeout: 7200
timeout: 3600
# threads: 50 # thread count
assetfinder:
enabled: true
timeout: 7200
timeout: 3600
# === Stage 2: active wordlist bruteforce (optional) ===
bruteforce:
enabled: false
subdomain_bruteforce:
timeout: auto # auto-calculated from the wordlist line count
# timeout: auto # auto-calculated from the wordlist line count
wordlist-name: subdomains-top1million-110000.txt # matches Wordlist.name in "Wordlist Management"
# === Stage 3: permutation generation + liveness verification (optional) ===
@@ -52,14 +53,14 @@ subdomain_discovery:
resolve:
enabled: true
subdomain_resolve:
timeout: auto # auto-calculated from the number of candidate subdomains
# timeout: auto # auto-calculated from the number of candidate subdomains
# ==================== Port scan ====================
port_scan:
tools:
naabu_active:
enabled: true
timeout: auto # auto-calculated (targets × ports × 0.5s), clamped to 60s ~ 2 days
# timeout: auto # auto-calculated (targets × ports × 0.5s), clamped to 60s ~ 2 days
threads: 200 # concurrent connections (default 5)
# ports: 1-65535 # port range to scan (default 1-65535)
top-ports: 100 # scan the nmap top 100 ports
@@ -67,25 +68,34 @@ port_scan:
naabu_passive:
enabled: true
timeout: auto # passive scans are usually fast
# timeout: auto # passive scans are usually fast
# ==================== Site scan ====================
site_scan:
tools:
httpx:
enabled: true
timeout: auto # auto-calculated (~1 second per URL)
# timeout: auto # auto-calculated (~1 second per URL)
# threads: 50 # concurrent threads (default 50)
# rate-limit: 150 # requests per second (default 150)
# request-timeout: 10 # per-request timeout in seconds (default 10)
# retries: 2 # retries on request failure
# ==================== Fingerprint detection ====================
# Runs serially after site_scan to identify each WebSite's tech stack
fingerprint_detect:
tools:
xingfinger:
enabled: true
fingerprint-libs: [ehole, goby, wappalyzer] # enabled fingerprint libraries (ehole, goby, wappalyzer, fingers, fingerprinthub)
# ==================== Directory scan ====================
directory_scan:
tools:
ffuf:
enabled: true
timeout: auto # auto-calculated (wordlist lines × 0.02s), clamped to 60s ~ 2 hours
# timeout: auto # auto-calculated (wordlist lines × 0.02s), clamped to 60s ~ 2 hours
max-workers: 5 # sites scanned concurrently (default 5)
wordlist-name: dir_default.txt # matches Wordlist.name in "Wordlist Management"
delay: 0.1-2.0 # request interval; supports randomized ranges (e.g. "0.1-2.0")
threads: 10 # concurrent threads (default 40)
@@ -102,7 +112,7 @@ url_fetch:
katana:
enabled: true
timeout: auto # auto-calculated (based on site count)
# timeout: auto # auto-calculated (based on site count)
depth: 5 # max crawl depth (default 3)
threads: 10 # global concurrency
rate-limit: 30 # max requests per second
@@ -112,7 +122,7 @@ url_fetch:
uro:
enabled: true
timeout: auto # auto-calculated (~1s per 100 URLs), clamped to 30 ~ 300 seconds
# timeout: auto # auto-calculated (~1s per 100 URLs), clamped to 30 ~ 300 seconds
# whitelist: # keep only the given extensions
# - php
# - asp
@@ -126,7 +136,7 @@ url_fetch:
httpx:
enabled: true
timeout: auto # auto-calculated (~1 second per URL)
# timeout: auto # auto-calculated (~1 second per URL)
# threads: 50 # concurrent threads (default 50)
# rate-limit: 150 # requests per second (default 150)
# request-timeout: 10 # per-request timeout in seconds (default 10)
@@ -137,18 +147,18 @@ vuln_scan:
tools:
dalfox_xss:
enabled: true
timeout: auto # auto-calculated (endpoint lines × 100s)
request-timeout: 10 # per-request timeout in seconds
# timeout: auto # auto-calculated (endpoint lines × 100s)
request-timeout: 10 # per-request timeout in seconds
only-poc: r # output only PoC results (r: reflected)
ignore-return: "302,404,403" # status codes to ignore
delay: 100 # internal scan delay
worker: 10 # number of workers
delay: 50 # request interval (ms)
worker: 30 # number of workers
user-agent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
# blind-xss-server: xxx # blind XSS callback server address
nuclei:
enabled: true
timeout: auto # auto-calculated (based on endpoint line count)
# timeout: auto # auto-calculated (based on endpoint line count)
template-repo-names: # template repo list; matches repo names in "Nuclei Templates"
- nuclei-templates
# - nuclei-custom # custom repos can be appended

View File
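The header above notes that hyphenated YAML keys (e.g. rate-limit) are converted to underscores before tools read them; a minimal sketch of that normalization (the project's real config_parser is not shown, so this is an assumption):

# Hypothetical key normalization; the real apps.scan.utils.config_parser may differ.
def normalize_keys(cfg):
    if isinstance(cfg, dict):
        return {k.replace('-', '_'): normalize_keys(v) for k, v in cfg.items()}
    if isinstance(cfg, list):
        return [normalize_keys(v) for v in cfg]
    return cfg

raw = {'rate-limit': 150, 'wordlist-name': 'dir_default.txt'}
assert normalize_keys(raw) == {'rate_limit': 150, 'wordlist_name': 'dir_default.txt'}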

@@ -5,8 +5,10 @@
from .initiate_scan_flow import initiate_scan_flow
from .subdomain_discovery_flow import subdomain_discovery_flow
from .fingerprint_detect_flow import fingerprint_detect_flow
__all__ = [
'initiate_scan_flow',
'subdomain_discovery_flow',
'fingerprint_detect_flow',
]

View File

@@ -5,7 +5,7 @@
Architecture:
- The Flow orchestrates multiple atomic Tasks
- Supports running scan tools serially (streaming)
- Supports running scan tools concurrently (via ThreadPoolTaskRunner)
- Each Task can be retried independently
- Configuration is parsed from YAML
"""
@@ -14,11 +14,15 @@
from apps.common.prefect_django_setup import setup_django_for_prefect
from prefect import flow
from prefect.task_runners import ThreadPoolTaskRunner
import hashlib
import logging
import os
import subprocess
from datetime import datetime
from pathlib import Path
from typing import List, Tuple
from apps.scan.tasks.directory_scan import (
export_sites_task,
@@ -33,6 +37,9 @@ from apps.scan.utils import config_parser, build_scan_command, ensure_wordlist_l
logger = logging.getLogger(__name__)
# Default max concurrency
DEFAULT_MAX_WORKERS = 5
def calculate_directory_scan_timeout(
tool_config: dict,
@@ -112,36 +119,37 @@ def calculate_directory_scan_timeout(
return min_timeout
def _setup_directory_scan_directory(scan_workspace_dir: str) -> Path:
def _get_max_workers(tool_config: dict, default: int = DEFAULT_MAX_WORKERS) -> int:
"""
Create and validate the directory-scan working directory
Get the max_workers parameter from a single tool's config
Args:
scan_workspace_dir: scan workspace directory
tool_config: a single tool's config dict, e.g. {'max_workers': 10, 'threads': 5, ...}
default: fallback value, defaults to 5
Returns:
Path: path to the directory-scan directory
Raises:
RuntimeError: directory creation or validation failed
int: the max_workers value
"""
directory_scan_dir = Path(scan_workspace_dir) / 'directory_scan'
directory_scan_dir.mkdir(parents=True, exist_ok=True)
if not isinstance(tool_config, dict):
return default
if not directory_scan_dir.is_dir():
raise RuntimeError(f"目录扫描目录创建失败: {directory_scan_dir}")
if not os.access(directory_scan_dir, os.W_OK):
raise RuntimeError(f"目录扫描目录不可写: {directory_scan_dir}")
return directory_scan_dir
# Supports both max_workers and max-workers (hyphens in YAML get converted)
max_workers = tool_config.get('max_workers') or tool_config.get('max-workers')
if max_workers is not None and isinstance(max_workers, int) and max_workers > 0:
return max_workers
return default
def _export_site_urls(target_id: int, directory_scan_dir: Path) -> tuple[str, int]:
def _export_site_urls(target_id: int, target_name: str, directory_scan_dir: Path) -> tuple[str, int]:
"""
Export all site URLs under the target to a file
Export all site URLs under the target to a file (supports lazy loading)
Args:
target_id: target ID
target_name: target name (used to lazily create a default site)
directory_scan_dir: directory-scan directory
Returns:
@@ -185,7 +193,7 @@ def _run_scans_sequentially(
target_name: str
) -> tuple[int, int, list]:
"""
Run directory-scan tasks serially (multi-tool)
Run directory-scan tasks serially (multi-tool) - deprecated, kept for compatibility
Args:
enabled_tools: dict of enabled tool configs
@@ -333,6 +341,198 @@ def _run_scans_sequentially(
return total_directories, processed_count, failed_sites
def _generate_log_filename(tool_name: str, site_url: str, directory_scan_dir: Path) -> Path:
"""
Generate a unique log filename
Uses a hash of the URL so concurrent runs do not collide
Args:
tool_name: tool name
site_url: site URL
directory_scan_dir: directory-scan directory
Returns:
Path: log file path
"""
url_hash = hashlib.md5(site_url.encode()).hexdigest()[:8]
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
return directory_scan_dir / f"{tool_name}_{url_hash}_{timestamp}.log"
def _run_scans_concurrently(
enabled_tools: dict,
sites_file: str,
directory_scan_dir: Path,
scan_id: int,
target_id: int,
site_count: int,
target_name: str
) -> Tuple[int, int, List[str]]:
"""
Run directory-scan tasks concurrently (via ThreadPoolTaskRunner)
Args:
enabled_tools: dict of enabled tool configs
sites_file: path to the sites file
directory_scan_dir: directory-scan directory
scan_id: scan task ID
target_id: target ID
site_count: number of sites
target_name: target name (for error logs)
Returns:
tuple: (total_directories, processed_sites, failed_sites)
"""
# Read the site list
sites: List[str] = []
with open(sites_file, 'r', encoding='utf-8') as f:
for line in f:
site_url = line.strip()
if site_url:
sites.append(site_url)
if not sites:
logger.warning("站点列表为空")
return 0, 0, []
logger.info(
"准备并发扫描 %d 个站点,使用工具: %s",
len(sites), ', '.join(enabled_tools.keys())
)
total_directories = 0
processed_sites_count = 0
failed_sites: List[str] = []
# Iterate over each tool
for tool_name, tool_config in enabled_tools.items():
# Each tool resolves its own max_workers setting
max_workers = _get_max_workers(tool_config)
logger.info("="*60)
logger.info("Using tool: %s (concurrent mode, max_workers=%d)", tool_name, max_workers)
logger.info("="*60)
# If wordlist_name is configured, first make sure the wordlist exists locally (with hash verification)
wordlist_name = tool_config.get('wordlist_name')
if wordlist_name:
try:
local_wordlist_path = ensure_wordlist_local(wordlist_name)
tool_config['wordlist'] = local_wordlist_path
except Exception as exc:
logger.error("为工具 %s 准备字典失败: %s", tool_name, exc)
# 当前工具无法执行,将所有站点视为失败,继续下一个工具
failed_sites.extend(sites)
continue
# Compute the timeout (shared by all sites)
site_timeout = tool_config.get('timeout', 300)
if site_timeout == 'auto':
site_timeout = calculate_directory_scan_timeout(tool_config)
logger.info(f"✓ 工具 {tool_name} 动态计算 timeout: {site_timeout}")
# 准备所有站点的扫描参数
scan_params_list = []
for idx, site_url in enumerate(sites, 1):
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='directory_scan',
command_params={'url': site_url},
tool_config=tool_config
)
log_file = _generate_log_filename(tool_name, site_url, directory_scan_dir)
scan_params_list.append({
'idx': idx,
'site_url': site_url,
'command': command,
'log_file': str(log_file),
'timeout': site_timeout
})
except Exception as e:
logger.error(
"✗ [%d/%d] 构建 %s 命令失败: %s - 站点: %s",
idx, len(sites), tool_name, e, site_url
)
failed_sites.append(site_url)
if not scan_params_list:
logger.warning("没有有效的扫描任务")
continue
# ============================================================
# Batched execution strategy: bounds how many ffuf processes actually run concurrently
# ============================================================
total_tasks = len(scan_params_list)
logger.info("Starting batched execution of %d scan tasks (%d per batch)...", total_tasks, max_workers)
batch_num = 0
for batch_start in range(0, total_tasks, max_workers):
batch_end = min(batch_start + max_workers, total_tasks)
batch_params = scan_params_list[batch_start:batch_end]
batch_num += 1
logger.info("执行第 %d 批任务(%d-%d/%d...", batch_num, batch_start + 1, batch_end, total_tasks)
# 提交当前批次的任务(非阻塞,立即返回 future
futures = []
for params in batch_params:
future = run_and_stream_save_directories_task.submit(
cmd=params['command'],
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
site_url=params['site_url'],
cwd=str(directory_scan_dir),
shell=True,
batch_size=1000,
timeout=params['timeout'],
log_file=params['log_file']
)
futures.append((params['idx'], params['site_url'], future))
# Wait for every task in this batch to finish (blocking, so the next batch starts only after this one completes)
for idx, site_url, future in futures:
try:
result = future.result() # block until this task completes
directories_found = result.get('created_directories', 0)
total_directories += directories_found
processed_sites_count += 1
logger.info(
"✓ [%d/%d] 站点扫描完成: %s - 发现 %d 个目录",
idx, len(sites), site_url, directories_found
)
except Exception as exc:
failed_sites.append(site_url)
if 'timeout' in str(exc).lower() or isinstance(exc, subprocess.TimeoutExpired):
logger.warning(
"⚠️ [%d/%d] 站点扫描超时: %s - 错误: %s",
idx, len(sites), site_url, exc
)
else:
logger.error(
"✗ [%d/%d] 站点扫描失败: %s - 错误: %s",
idx, len(sites), site_url, exc
)
# Emit summary
if failed_sites:
logger.warning(
"Some site scans failed: %d/%d",
len(failed_sites), len(sites)
)
logger.info(
"✓ 并发目录扫描执行完成 - 成功: %d/%d, 失败: %d, 总目录数: %d",
processed_sites_count, len(sites), len(failed_sites), total_directories
)
return total_directories, processed_sites_count, failed_sites
@flow(
name="directory_scan",
log_prints=True,
@@ -359,7 +559,7 @@ def directory_scan_flow(
Step 0: create the working directory
Step 1: export the site URL list to a file (for the scan tools)
Step 2: validate the tool configuration
Step 3: run scan tools serially and save results in real time
Step 3: run scan tools concurrently and save results in real time (via ThreadPoolTaskRunner)
ffuf output fields:
- url: discovered directory/file URL
@@ -418,10 +618,11 @@ def directory_scan_flow(
raise ValueError("enabled_tools 不能为空")
# Step 0: 创建工作目录
directory_scan_dir = _setup_directory_scan_directory(scan_workspace_dir)
from apps.scan.utils import setup_scan_directory
directory_scan_dir = setup_scan_directory(scan_workspace_dir, 'directory_scan')
# Step 1: export site URLs
sites_file, site_count = _export_site_urls(target_id, directory_scan_dir)
# Step 1: export site URLs (supports lazy loading)
sites_file, site_count = _export_site_urls(target_id, target_name, directory_scan_dir)
if site_count == 0:
logger.warning("目标下没有站点,跳过目录扫描")
@@ -440,14 +641,15 @@ def directory_scan_flow(
# Step 2: tool configuration info
logger.info("Step 2: tool configuration info")
logger.info(
"✓ Enabled tools: %s",
', '.join(enabled_tools.keys())
)
tool_info = []
for tool_name, tool_config in enabled_tools.items():
mw = _get_max_workers(tool_config)
tool_info.append(f"{tool_name}(max_workers={mw})")
logger.info("✓ 启用工具: %s", ', '.join(tool_info))
# Step 3: 串行执行扫描工具并实时保存结果
logger.info("Step 3: 串行执行扫描工具并实时保存结果")
total_directories, processed_sites, failed_sites = _run_scans_sequentially(
# Step 3: 并发执行扫描工具并实时保存结果
logger.info("Step 3: 并发执行扫描工具并实时保存结果")
total_directories, processed_sites, failed_sites = _run_scans_concurrently(
enabled_tools=enabled_tools,
sites_file=sites_file,
directory_scan_dir=directory_scan_dir,

View File
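A stripped-down sketch of the batch strategy implemented in _run_scans_concurrently above, with Prefect's .submit()/.result() replaced by a plain thread pool purely for illustration:

# Hypothetical illustration only; the real flow submits Prefect tasks.
from concurrent.futures import ThreadPoolExecutor

def scan_site(url: str) -> str:
    return f"scanned {url}"  # stand-in for run_and_stream_save_directories_task

sites = [f"https://s{i}.example.com" for i in range(12)]
max_workers = 5
for start in range(0, len(sites), max_workers):
    batch = sites[start:start + max_workers]
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # the next batch starts only after every task in this one resolves
        for result in pool.map(scan_site, batch):
            print(result)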

@@ -0,0 +1,380 @@
"""
Fingerprint detection Flow
Orchestrates the complete fingerprint detection pipeline
Architecture:
- The Flow orchestrates multiple atomic Tasks
- Runs serially after site_scan
- Uses the xingfinger tool to identify tech stacks
- Streams the output and updates the database in batches
"""
# Django environment init (takes effect on import)
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging
import os
from datetime import datetime
from pathlib import Path
from prefect import flow
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
)
from apps.scan.tasks.fingerprint_detect import (
export_urls_for_fingerprint_task,
run_xingfinger_and_stream_update_tech_task,
)
from apps.scan.utils import build_scan_command
from apps.scan.utils.fingerprint_helpers import get_fingerprint_paths
logger = logging.getLogger(__name__)
def calculate_fingerprint_detect_timeout(
url_count: int,
base_per_url: float = 3.0,
min_timeout: int = 60
) -> int:
"""
Compute the timeout from the URL count
Formula: timeout = URL count × base time per URL
Minimum: 60 seconds
No upper bound
Args:
url_count: number of URLs
base_per_url: base time per URL (seconds), default 3
min_timeout: minimum timeout (seconds), default 60
Returns:
int: computed timeout (seconds)
Examples:
100 URLs × 3s = 300s
1000 URLs × 3s = 3000s (50 minutes)
10000 URLs × 3s = 30000s (~8.3 hours)
"""
timeout = int(url_count * base_per_url)
return max(min_timeout, timeout)
def _export_urls(
target_id: int,
fingerprint_dir: Path,
source: str = 'website'
) -> tuple[str, int]:
"""
Export URLs to a file
Args:
target_id: target ID
fingerprint_dir: fingerprint-detection directory
source: data source type
Returns:
tuple: (urls_file, total_count)
"""
logger.info("Step 1: 导出 URL 列表 (source=%s)", source)
urls_file = str(fingerprint_dir / 'urls.txt')
export_result = export_urls_for_fingerprint_task(
target_id=target_id,
output_file=urls_file,
source=source,
batch_size=1000
)
total_count = export_result['total_count']
logger.info(
"✓ URL 导出完成 - 文件: %s, 数量: %d",
export_result['output_file'],
total_count
)
return export_result['output_file'], total_count
def _run_fingerprint_detect(
enabled_tools: dict,
urls_file: str,
url_count: int,
fingerprint_dir: Path,
scan_id: int,
target_id: int,
source: str
) -> tuple[dict, list]:
"""
Run the fingerprint detection tasks
Args:
enabled_tools: dict of enabled tool configs
urls_file: URL file path
url_count: total URL count
fingerprint_dir: fingerprint-detection directory
scan_id: scan task ID
target_id: target ID
source: data source type
Returns:
tuple: (tool_stats, failed_tools)
"""
tool_stats = {}
failed_tools = []
for tool_name, tool_config in enabled_tools.items():
# 1. Resolve fingerprint library paths
lib_names = tool_config.get('fingerprint_libs', ['ehole'])
fingerprint_paths = get_fingerprint_paths(lib_names)
if not fingerprint_paths:
reason = f"没有可用的指纹库: {lib_names}"
logger.warning(reason)
failed_tools.append({'tool': tool_name, 'reason': reason})
continue
# 2. Merge the fingerprint library paths into tool_config (for command building)
tool_config_with_paths = {**tool_config, **fingerprint_paths}
# 3. Build the command
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='fingerprint_detect',
command_params={
'urls_file': urls_file
},
tool_config=tool_config_with_paths
)
except Exception as e:
reason = f"命令构建失败: {str(e)}"
logger.error("构建 %s 命令失败: %s", tool_name, e)
failed_tools.append({'tool': tool_name, 'reason': reason})
continue
# 4. Compute the timeout
timeout = calculate_fingerprint_detect_timeout(url_count)
# 5. Generate the log file path
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = fingerprint_dir / f"{tool_name}_{timestamp}.log"
logger.info(
"开始执行 %s 指纹识别 - URL数: %d, 超时: %ds, 指纹库: %s",
tool_name, url_count, timeout, list(fingerprint_paths.keys())
)
# 6. Run the scan task
try:
result = run_xingfinger_and_stream_update_tech_task(
cmd=command,
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
source=source,
cwd=str(fingerprint_dir),
timeout=timeout,
log_file=str(log_file),
batch_size=100
)
tool_stats[tool_name] = {
'command': command,
'result': result,
'timeout': timeout,
'fingerprint_libs': list(fingerprint_paths.keys())
}
logger.info(
"✓ 工具 %s 执行完成 - 处理记录: %d, 更新: %d, 未找到: %d",
tool_name,
result.get('processed_records', 0),
result.get('updated_count', 0),
result.get('not_found_count', 0)
)
except Exception as exc:
failed_tools.append({'tool': tool_name, 'reason': str(exc)})
logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True)
if failed_tools:
logger.warning(
"以下指纹识别工具执行失败: %s",
', '.join([f['tool'] for f in failed_tools])
)
return tool_stats, failed_tools
@flow(
name="fingerprint_detect",
log_prints=True,
on_running=[on_scan_flow_running],
on_completion=[on_scan_flow_completed],
on_failure=[on_scan_flow_failed],
)
def fingerprint_detect_flow(
scan_id: int,
target_name: str,
target_id: int,
scan_workspace_dir: str,
enabled_tools: dict
) -> dict:
"""
Fingerprint detection Flow
Main responsibilities:
1. Export all WebSite URLs under the target from the database to a file
2. Run tech-stack detection with xingfinger
3. Parse the results and update the WebSite.tech field (merged and deduplicated)
Workflow:
Step 0: create the working directory
Step 1: export the URL list
Step 2: parse the config and collect the enabled tools
Step 3: run xingfinger and parse the results
Args:
scan_id: scan task ID
target_name: target name
target_id: target ID
scan_workspace_dir: scan workspace directory
enabled_tools: enabled tool configs (xingfinger)
Returns:
dict: {
'success': bool,
'scan_id': int,
'target': str,
'scan_workspace_dir': str,
'urls_file': str,
'url_count': int,
'processed_records': int,
'updated_count': int,
'not_found_count': int,
'executed_tasks': list,
'tool_stats': dict
}
"""
try:
logger.info(
"="*60 + "\n" +
"开始指纹识别\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
)
# Parameter validation
if scan_id is None:
raise ValueError("scan_id must not be empty")
if not target_name:
raise ValueError("target_name must not be empty")
if target_id is None:
raise ValueError("target_id must not be empty")
if not scan_workspace_dir:
raise ValueError("scan_workspace_dir must not be empty")
# Data source type (currently only website is supported)
source = 'website'
# Step 0: create the working directory
from apps.scan.utils import setup_scan_directory
fingerprint_dir = setup_scan_directory(scan_workspace_dir, 'fingerprint_detect')
# Step 1: export URLs (supports lazy loading)
urls_file, url_count = _export_urls(target_id, fingerprint_dir, source)
if url_count == 0:
logger.warning("目标下没有可用的 URL跳过指纹识别")
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'scan_workspace_dir': scan_workspace_dir,
'urls_file': urls_file,
'url_count': 0,
'processed_records': 0,
'updated_count': 0,
'created_count': 0,
'executed_tasks': ['export_urls_for_fingerprint'],
'tool_stats': {
'total': 0,
'successful': 0,
'failed': 0,
'successful_tools': [],
'failed_tools': [],
'details': {}
}
}
# Step 2: tool configuration info
logger.info("Step 2: tool configuration info")
logger.info("✓ Enabled tools: %s", ', '.join(enabled_tools.keys()))
# Step 3: run fingerprint detection
logger.info("Step 3: run fingerprint detection")
tool_stats, failed_tools = _run_fingerprint_detect(
enabled_tools=enabled_tools,
urls_file=urls_file,
url_count=url_count,
fingerprint_dir=fingerprint_dir,
scan_id=scan_id,
target_id=target_id,
source=source
)
logger.info("="*60 + "\n✓ 指纹识别完成\n" + "="*60)
# 动态生成已执行的任务列表
executed_tasks = ['export_urls_for_fingerprint']
executed_tasks.extend([f'run_xingfinger ({tool})' for tool in tool_stats.keys()])
# Aggregate the results of all tools
total_processed = sum(stats['result'].get('processed_records', 0) for stats in tool_stats.values())
total_updated = sum(stats['result'].get('updated_count', 0) for stats in tool_stats.values())
total_created = sum(stats['result'].get('created_count', 0) for stats in tool_stats.values())
successful_tools = [name for name in enabled_tools.keys()
if name not in [f['tool'] for f in failed_tools]]
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'scan_workspace_dir': scan_workspace_dir,
'urls_file': urls_file,
'url_count': url_count,
'processed_records': total_processed,
'updated_count': total_updated,
'created_count': total_created,
'executed_tasks': executed_tasks,
'tool_stats': {
'total': len(enabled_tools),
'successful': len(successful_tools),
'failed': len(failed_tools),
'successful_tools': successful_tools,
'failed_tools': failed_tools,
'details': tool_stats
}
}
except ValueError as e:
logger.error("配置错误: %s", e)
raise
except RuntimeError as e:
logger.error("运行时错误: %s", e)
raise
except Exception as e:
logger.exception("指纹识别失败: %s", e)
raise

View File
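Step 3 above merges detected technologies into WebSite.tech with deduplication; a minimal sketch of that merge (the real work happens inside run_xingfinger_and_stream_update_tech_task, whose body is not shown, so this helper is an assumption):

# Hypothetical merge-and-dedup of tech lists; the real task may differ.
def merge_tech(existing: list[str], detected: list[str]) -> list[str]:
    merged = list(existing)
    for tech in detected:
        if tech not in merged:  # keep order, drop duplicates
            merged.append(tech)
    return merged

assert merge_tech(['nginx', 'PHP'], ['PHP', 'WordPress']) == ['nginx', 'PHP', 'WordPress']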

@@ -30,7 +30,7 @@ from apps.scan.handlers import (
on_initiate_scan_flow_failed,
)
from prefect.futures import wait
from apps.scan.tasks.workspace_tasks import create_scan_workspace_task
from apps.scan.utils import setup_scan_workspace
from apps.scan.orchestrators import FlowOrchestrator
logger = logging.getLogger(__name__)
@@ -110,7 +110,7 @@ def initiate_scan_flow(
)
# ==================== Task 1: create the Scan workspace ====================
scan_workspace_path = create_scan_workspace_task(scan_workspace_dir)
scan_workspace_path = setup_scan_workspace(scan_workspace_dir)
# ==================== Task 2: fetch the engine config ====================
from apps.scan.models import Scan

View File

@@ -154,28 +154,7 @@ def _parse_port_count(tool_config: dict) -> int:
return 100
def _setup_port_scan_directory(scan_workspace_dir: str) -> Path:
"""
Create and validate the port-scan working directory
Args:
scan_workspace_dir: scan workspace directory
Returns:
Path: path to the port-scan directory
Raises:
RuntimeError: directory creation or validation failed
"""
port_scan_dir = Path(scan_workspace_dir) / 'port_scan'
port_scan_dir.mkdir(parents=True, exist_ok=True)
if not port_scan_dir.is_dir():
raise RuntimeError(f"端口扫描目录创建失败: {port_scan_dir}")
if not os.access(port_scan_dir, os.W_OK):
raise RuntimeError(f"端口扫描目录不可写: {port_scan_dir}")
return port_scan_dir
def _export_scan_targets(target_id: int, port_scan_dir: Path) -> tuple[str, int, str]:
@@ -442,7 +421,8 @@ def port_scan_flow(
)
# Step 0: create the working directory
port_scan_dir = _setup_port_scan_directory(scan_workspace_dir)
from apps.scan.utils import setup_scan_directory
port_scan_dir = setup_scan_directory(scan_workspace_dir, 'port_scan')
# Step 1: export the scan target list to a file (contents decided automatically by Target type)
targets_file, target_count, target_type = _export_scan_targets(target_id, port_scan_dir)

View File
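The per-flow _setup_*_directory helpers removed in this and the surrounding files are replaced by one shared util; a minimal sketch of what setup_scan_directory plausibly does, inferred from the removed helpers (the real implementation in apps.scan.utils is not shown):

# Hypothetical consolidation of the removed per-flow directory helpers.
import os
from pathlib import Path

def setup_scan_directory(scan_workspace_dir: str, name: str) -> Path:
    scan_dir = Path(scan_workspace_dir) / name
    scan_dir.mkdir(parents=True, exist_ok=True)
    if not scan_dir.is_dir():
        raise RuntimeError(f"Failed to create scan directory: {scan_dir}")
    if not os.access(scan_dir, os.W_OK):
        raise RuntimeError(f"Scan directory is not writable: {scan_dir}")
    return scan_dir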

@@ -34,7 +34,8 @@ logger = logging.getLogger(__name__)
def calculate_timeout_by_line_count(
tool_config: dict,
file_path: str,
base_per_time: int = 1
base_per_time: int = 1,
min_timeout: int = 60
) -> int:
"""
Compute the timeout from a file's line count
@@ -45,9 +46,10 @@ def calculate_timeout_by_line_count(
tool_config: tool config dict (unused here; kept for interface consistency)
file_path: path of the file whose lines are counted
base_per_time: base time per line (default 1 second)
min_timeout: minimum timeout (default 60 seconds)
Returns:
int: computed timeout (seconds)
int: computed timeout (seconds), floored at min_timeout
Example:
timeout = calculate_timeout_by_line_count(
@@ -67,53 +69,33 @@ def calculate_timeout_by_line_count(
# wc -l output format: line count + space + filename
line_count = int(result.stdout.strip().split()[0])
# Compute timeout: line count × base time per line
timeout = line_count * base_per_time
# Compute timeout: line count × base time per line, floored at the minimum
timeout = max(line_count * base_per_time, min_timeout)
logger.info(
f"Auto-computed timeout: file={file_path}, "
f"lines={line_count}, per-line={base_per_time}s, timeout={timeout}"
f"lines={line_count}, per-line={base_per_time}s, minimum={min_timeout}s, timeout={timeout}"
)
return timeout
except Exception as e:
# If wc -l fails, fall back to the default
logger.warning(f"wc -l line count failed: {e}; using default timeout: 600")
return 600
logger.warning(f"wc -l line count failed: {e}; using default timeout: {min_timeout}")
return min_timeout
def _setup_site_scan_directory(scan_workspace_dir: str) -> Path:
"""
Create and validate the site-scan working directory
Args:
scan_workspace_dir: scan workspace directory
Returns:
Path: path to the site-scan directory
Raises:
RuntimeError: directory creation or validation failed
"""
site_scan_dir = Path(scan_workspace_dir) / 'site_scan'
site_scan_dir.mkdir(parents=True, exist_ok=True)
if not site_scan_dir.is_dir():
raise RuntimeError(f"Failed to create site-scan directory: {site_scan_dir}")
if not os.access(site_scan_dir, os.W_OK):
raise RuntimeError(f"Site-scan directory is not writable: {site_scan_dir}")
return site_scan_dir
def _export_site_urls(target_id: int, site_scan_dir: Path) -> tuple[str, int, int]:
def _export_site_urls(target_id: int, site_scan_dir: Path, target_name: str = None) -> tuple[str, int, int]:
"""
Export site URLs to a file
Args:
target_id: target ID
site_scan_dir: site-scan directory
target_name: target name (used to write a default value during lazy loading)
Returns:
tuple: (urls_file, total_urls, association_count)
@@ -399,11 +381,12 @@ def site_scan_flow(
raise ValueError("scan_workspace_dir 不能为空")
# Step 0: 创建工作目录
site_scan_dir = _setup_site_scan_directory(scan_workspace_dir)
from apps.scan.utils import setup_scan_directory
site_scan_dir = setup_scan_directory(scan_workspace_dir, 'site_scan')
# Step 1: export site URLs
urls_file, total_urls, association_count = _export_site_urls(
target_id, site_scan_dir
target_id, site_scan_dir, target_name
)
if total_urls == 0:

View File

@@ -41,28 +41,7 @@ import subprocess
logger = logging.getLogger(__name__)
def _setup_subdomain_directory(scan_workspace_dir: str) -> Path:
"""
Create and validate the subdomain-scan working directory
Args:
scan_workspace_dir: scan workspace directory
Returns:
Path: path to the subdomain-scan directory
Raises:
RuntimeError: directory creation or validation failed
"""
result_dir = Path(scan_workspace_dir) / 'subdomain_discovery'
result_dir.mkdir(parents=True, exist_ok=True)
if not result_dir.is_dir():
raise RuntimeError(f"Failed to create subdomain-scan directory: {result_dir}")
if not os.access(result_dir, os.W_OK):
raise RuntimeError(f"Subdomain-scan directory is not writable: {result_dir}")
return result_dir
def _validate_and_normalize_target(target_name: str) -> str:
@@ -119,12 +98,7 @@ def _run_scans_parallel(
# Generate a timestamp (shared by all tools)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
# TODO: integrate with the proxy pool management system
# from apps.proxy.services import proxy_pool
# proxy_stats = proxy_pool.get_stats()
# logger.info(f"Proxy pool status: {proxy_stats['healthy']}/{proxy_stats['total']} available")
failures = [] # tools whose command build failed
futures = {}
@@ -352,6 +326,10 @@ def subdomain_discovery_flow(
Stage 4: DNS liveness verification (optional) - generic liveness check
Final: save to the database
Note:
- Subdomain discovery only makes sense for DOMAIN-type targets
- IP and CIDR targets are skipped automatically
Args:
scan_id: scan task ID
target_name: target name (domain)
@@ -390,6 +368,21 @@ def subdomain_discovery_flow(
logger.warning("未提供目标域名,跳过子域名发现扫描")
return _empty_result(scan_id, '', scan_workspace_dir)
# ==================== Check the Target type ====================
# Subdomain discovery only makes sense for DOMAIN targets; IP and CIDR types are skipped
from apps.targets.services import TargetService
from apps.targets.models import Target
target_service = TargetService()
target = target_service.get_target(target_id)
if target and target.type != Target.TargetType.DOMAIN:
logger.info(
"跳过子域名发现扫描: Target 类型为 %s (ID=%d, Name=%s),子域名发现仅适用于域名类型",
target.type, target_id, target_name
)
return _empty_result(scan_id, target_name, scan_workspace_dir)
# Import the task functions
from apps.scan.tasks.subdomain_discovery import (
run_subdomain_discovery_task,
@@ -398,7 +391,8 @@ def subdomain_discovery_flow(
)
# Step 0: preparation
result_dir = _setup_subdomain_directory(scan_workspace_dir)
from apps.scan.utils import setup_scan_directory
result_dir = setup_scan_directory(scan_workspace_dir, 'subdomain_discovery')
# Validate and normalize the target domain
try:

View File

@@ -4,18 +4,15 @@ URL Fetch Flow module
Provides the URL-fetch-related Flows
- url_fetch_flow: main Flow (orchestrates by input type + unified post-processing)
- domain_name_url_fetch_flow: URL-fetch sub-Flow driven by domain_name (from target_name), e.g. waymore
- domains_url_fetch_flow: URL-fetch sub-Flow driven by domains_file, e.g. gau, waybackurls
- sites_url_fetch_flow: URL-fetch sub-Flow driven by sites_file, e.g. crawlers like katana
"""
from .main_flow import url_fetch_flow
from .domain_name_url_fetch_flow import domain_name_url_fetch_flow
from .domains_url_fetch_flow import domains_url_fetch_flow
from .sites_url_fetch_flow import sites_url_fetch_flow
__all__ = [
'url_fetch_flow',
'domain_name_url_fetch_flow',
'domains_url_fetch_flow',
'sites_url_fetch_flow',
]

View File

@@ -1,9 +1,14 @@
"""
URL fetch Flow driven by domain_name (a domain)
Passive URL collection Flow driven by the Target's root domain
Mainly for tools like waymore that take a domain as input (input_type = 'domain_name'):
- Runs passive URL collection directly against the target domain (target_name/domain_name)
- No longer depends on domains_file (the subdomain list file)
For passive collection tools such as waymore:
- Input: the Target's root domain (target_name, e.g. example.com)
- The tool automatically queries third-party sources (Wayback Machine, Common Crawl, etc.) for historical URLs of the domain and its subdomains
- No need to iterate over a subdomain list; the tool handles *.example.com internally
Note:
- This Flow only applies to DOMAIN-type Targets
- IP and CIDR types are skipped automatically (passive collection tools do not support them)
"""
# Django environment init
@@ -34,18 +39,49 @@ def domain_name_url_fetch_flow(
domain_name_tools: Dict[str, dict],
) -> dict:
"""
URL-fetch sub-Flow driven by the target_name/domain_name domain (currently mainly waymore)
Passive URL collection against the Target domain (currently mainly waymore)
Execution flow:
1. Validate that target_name is a domain
2. Use the supplied domain_name_tools tool list
3. Build a command for each tool and run them in parallel
1. Check the Target type (IP/CIDR types are skipped)
2. Run passive collection against the root domain using the supplied tool list
3. The tools internally query historical URLs for the domain and its subdomains
4. Aggregate the list of result files
Args:
scan_id: scan ID
target_id: target ID
target_name: the Target's root domain (e.g. example.com), not a subdomain list
output_dir: output directory
domain_name_tools: passive collection tool configs (e.g. waymore)
Note:
- This Flow only applies to DOMAIN-type Targets
- IP and CIDR types are skipped automatically (tools like waymore do not support them)
- The tools automatically collect all historical URLs of *.target_name; no need to iterate subdomains
"""
try:
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# Check the Target type (skip IP/CIDR types)
from apps.targets.services import TargetService
from apps.targets.models import Target
target_service = TargetService()
target = target_service.get_target(target_id)
if target and target.type != Target.TargetType.DOMAIN:
logger.info(
"跳过 domain_name URL 获取: Target 类型为 %s (ID=%d, Name=%s)waymore 等工具仅适用于域名类型",
target.type, target_id, target_name
)
return {
"success": True,
"result_files": [],
"failed_tools": [],
"successful_tools": [],
}
# Reuse the shared domain validation logic to ensure target_name is a valid domain
validate_domain(target_name)

View File

@@ -1,139 +0,0 @@
"""
Passive URL collection Flow
Collects URLs from passive sources such as historical archives and search engines
Tools: waymore, gau, waybackurls, etc.
Input: domains_file (subdomain list)
"""
# Django environment init
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging
from pathlib import Path
from prefect import flow
from .utils import run_tools_parallel
logger = logging.getLogger(__name__)
def _export_domains_file(target_id: int, scan_id: int, output_dir: Path) -> tuple[str, int]:
"""
Export the subdomain list to a file
Args:
target_id: target ID
scan_id: scan ID
output_dir: output directory
Returns:
tuple: (file_path, count)
"""
from apps.scan.tasks.url_fetch import export_target_assets_task
output_file = str(output_dir / "domains.txt")
result = export_target_assets_task(
output_file=output_file,
target_id=target_id,
scan_id=scan_id,
input_type="domains_file"
)
count = result['asset_count']
if count == 0:
logger.warning("子域名列表为空,被动收集可能无法正常工作")
else:
logger.info("✓ 子域名列表导出完成 - 数量: %d", count)
return output_file, count
@flow(name="domains_url_fetch_flow", log_prints=True)
def domains_url_fetch_flow(
scan_id: int,
target_id: int,
target_name: str,
output_dir: str,
enabled_tools: dict
) -> dict:
"""
Passive URL collection sub-Flow
Execution flow:
1. Export the subdomain list (domains_file)
2. Run the passive collection tools in parallel
3. Return the list of result files
Args:
scan_id: scan ID
target_id: target ID
target_name: target name
output_dir: output directory
enabled_tools: enabled passive collection tool configs
Returns:
dict: {
'success': bool,
'result_files': list,
'failed_tools': list,
'successful_tools': list,
'domains_count': int
}
"""
try:
output_path = Path(output_dir)
logger.info(
"开始 URL 被动收集 - Target: %s, Tools: %s",
target_name, ', '.join(enabled_tools.keys())
)
# Step 1: export the subdomain list
domains_file, domains_count = _export_domains_file(
target_id=target_id,
scan_id=scan_id,
output_dir=output_path
)
if domains_count == 0:
logger.warning("没有可用的子域名,跳过被动收集")
return {
'success': True,
'result_files': [],
'failed_tools': [],
'successful_tools': [],
'domains_count': 0
}
# Step 2: run the passive collection tools in parallel
result_files, failed_tools, successful_tools = run_tools_parallel(
tools=enabled_tools,
input_file=domains_file,
input_type="domains_file",
output_dir=output_path
)
logger.info(
"✓ 被动收集完成 - 成功: %d/%d, 结果文件: %d",
len(successful_tools), len(enabled_tools), len(result_files)
)
return {
'success': True,
'result_files': result_files,
'failed_tools': failed_tools,
'successful_tools': successful_tools,
'domains_count': domains_count
}
except Exception as e:
logger.error("URL 被动收集失败: %s", e, exc_info=True)
return {
'success': False,
'result_files': [],
'failed_tools': [{'tool': 'domains_url_fetch_flow', 'reason': str(e)}],
'successful_tools': [],
'domains_count': 0
}

View File

@@ -1,10 +1,10 @@
"""
URL Fetch 主 Flow
Orchestrates the URL-fetch sub-Flows for the different input types (domain_name / domains_file / sites_file) plus unified post-processing (uro dedup, httpx verification)
Orchestrates the URL-fetch sub-Flows for the different input types (domain_name / sites_file) plus unified post-processing (uro dedup, httpx verification)
Architecture:
- Calls domain_name_url_fetch_flow (domain_name input), domains_url_fetch_flow (domains_file input), and sites_url_fetch_flow (sites_file input)
- Calls domain_name_url_fetch_flow (domain_name input) and sites_url_fetch_flow (sites_file input)
- Merges the results of the sub-Flows
- Runs uro dedup uniformly (if enabled)
- Runs httpx verification uniformly (if enabled)
@@ -27,7 +27,6 @@ from apps.scan.handlers.scan_flow_handlers import (
)
from .domain_name_url_fetch_flow import domain_name_url_fetch_flow
from .domains_url_fetch_flow import domains_url_fetch_flow
from .sites_url_fetch_flow import sites_url_fetch_flow
from .utils import calculate_timeout_by_line_count
@@ -37,36 +36,23 @@ logger = logging.getLogger(__name__)
# ==================== Tool classification config ====================
# URL-fetch tools that take target_name (domain_name) as input
DOMAIN_NAME_TOOLS = {'waymore'}
# URL-fetch tools that take domains_file as input
DOMAINS_FILE_TOOLS = {'gau', 'waybackurls'}
# URL-fetch tools that take sites_file as input
SITES_FILE_TOOLS = {'katana', 'gospider', 'hakrawler'}
SITES_FILE_TOOLS = {'katana'}
# Post-processing tools: not used for fetching, only for cleanup and verification
POST_PROCESS_TOOLS = {'uro', 'httpx'}
def _setup_url_fetch_directory(scan_workspace_dir: str) -> Path:
"""创建并验证 URL 获取工作目录"""
url_fetch_dir = Path(scan_workspace_dir) / 'url_fetch'
url_fetch_dir.mkdir(parents=True, exist_ok=True)
if not url_fetch_dir.is_dir():
raise RuntimeError(f"URL 获取目录创建失败: {url_fetch_dir}")
if not os.access(url_fetch_dir, os.W_OK):
raise RuntimeError(f"URL 获取目录不可写: {url_fetch_dir}")
return url_fetch_dir
def _classify_tools(enabled_tools: dict) -> tuple[dict, dict, dict, dict, dict]:
def _classify_tools(enabled_tools: dict) -> tuple[dict, dict, dict, dict]:
"""
Classify the enabled tools by input type
Returns:
tuple: (domain_name_tools, domains_file_tools, sites_file_tools, uro_config, httpx_config)
tuple: (domain_name_tools, sites_file_tools, uro_config, httpx_config)
"""
domain_name_tools: dict = {}
domains_file_tools: dict = {}
sites_file_tools: dict = {}
uro_config = None
httpx_config = None
@@ -74,8 +60,6 @@ def _classify_tools(enabled_tools: dict) -> tuple[dict, dict, dict, dict, dict]:
for tool_name, tool_config in enabled_tools.items():
if tool_name in DOMAIN_NAME_TOOLS:
domain_name_tools[tool_name] = tool_config
elif tool_name in DOMAINS_FILE_TOOLS:
domains_file_tools[tool_name] = tool_config
elif tool_name in SITES_FILE_TOOLS:
sites_file_tools[tool_name] = tool_config
elif tool_name == 'uro':
@@ -83,10 +67,9 @@ def _classify_tools(enabled_tools: dict) -> tuple[dict, dict, dict, dict, dict]:
elif tool_name == 'httpx':
httpx_config = tool_config
else:
logger.warning("未知工具类型: %s将尝试作为 domains_file 输入的被动收集工具", tool_name)
domains_file_tools[tool_name] = tool_config
logger.warning("未知工具类型: %s跳过", tool_name)
return domain_name_tools, domains_file_tools, sites_file_tools, uro_config, httpx_config
return domain_name_tools, sites_file_tools, uro_config, httpx_config
def _merge_and_deduplicate_urls(result_files: list, url_fetch_dir: Path) -> tuple[str, int]:
@@ -131,9 +114,9 @@ def _clean_urls_with_uro(
tool_config=uro_config,
file_path=merged_file,
base_per_time=1,
min_timeout=60,
)
timeout = max(30, timeout)
logger.info("uro 自动计算超时时间(按行数,每行 1 秒): %d", timeout)
logger.info("uro 自动计算超时时间(按行数,每行 1 秒,最小 60 秒): %d", timeout)
else:
try:
timeout = int(raw_timeout)
@@ -202,11 +185,10 @@ def _validate_and_stream_save_urls(
raw_timeout = httpx_config.get('timeout', 'auto')
timeout = 3600
if isinstance(raw_timeout, str) and raw_timeout == 'auto':
# Timeout from URL line count: 3s per line, no upper bound
timeout = url_count * 3
timeout = max(600, timeout)
# Timeout from URL line count: 3s per line, minimum 60s
timeout = max(60, url_count * 3)
logger.info(
"Auto-computed httpx timeout (per line, 3s each): url_count=%d, timeout=%d",
"Auto-computed httpx timeout (per line, 3s each, minimum 60s): url_count=%d, timeout=%d",
url_count,
timeout,
)
@@ -282,10 +264,9 @@ def url_fetch_flow(
Execution flow:
1. Prepare the working directory
2. Classify tools by input type (domain_name / domains_file / sites_file / post-processing)
2. Classify tools by input type (domain_name / sites_file / post-processing)
3. Run the sub-Flows in parallel
- domain_name_url_fetch_flow: URL fetch driven by domain_name (from target_name), e.g. waymore
- domains_url_fetch_flow: URL fetch driven by domains_file, e.g. gau, waybackurls
- sites_url_fetch_flow: crawlers driven by sites_file, e.g. katana
4. Merge and deduplicate the results of all sub-Flows
5. uro dedup (if enabled)
@@ -313,23 +294,23 @@ def url_fetch_flow(
# Step 1: prepare the working directory
logger.info("Step 1: prepare the working directory")
url_fetch_dir = _setup_url_fetch_directory(scan_workspace_dir)
from apps.scan.utils import setup_scan_directory
url_fetch_dir = setup_scan_directory(scan_workspace_dir, 'url_fetch')
# Step 2: classify tools (by input type)
logger.info("Step 2: classify tools")
domain_name_tools, domains_file_tools, sites_file_tools, uro_config, httpx_config = _classify_tools(enabled_tools)
domain_name_tools, sites_file_tools, uro_config, httpx_config = _classify_tools(enabled_tools)
logger.info(
"工具分类 - domain_name: %s, domains_file: %s, sites_file: %s, uro: %s, httpx: %s",
"工具分类 - domain_name: %s, sites_file: %s, uro: %s, httpx: %s",
list(domain_name_tools.keys()) or '',
list(domains_file_tools.keys()) or '',
list(sites_file_tools.keys()) or '',
'enabled' if uro_config else 'disabled',
'enabled' if httpx_config else 'disabled'
)
# Check that at least one fetch tool is enabled
if not domain_name_tools and not domains_file_tools and not sites_file_tools:
if not domain_name_tools and not sites_file_tools:
raise ValueError(
"URL Fetch 流程需要至少启用一个 URL 获取工具(如 waymore, katana"
"httpx 和 uro 仅用于后处理,不能单独使用。"
@@ -353,24 +334,10 @@ def url_fetch_flow(
all_result_files.extend(tn_result.get('result_files', []))
all_failed_tools.extend(tn_result.get('failed_tools', []))
all_successful_tools.extend(tn_result.get('successful_tools', []))
# 3b: passive URL collection driven by domains_file
if domains_file_tools:
logger.info("Step 3b: run the domains_file-driven passive URL collection sub-Flow")
passive_result = domains_url_fetch_flow(
scan_id=scan_id,
target_id=target_id,
target_name=target_name,
output_dir=str(url_fetch_dir),
enabled_tools=domains_file_tools,
)
all_result_files.extend(passive_result.get('result_files', []))
all_failed_tools.extend(passive_result.get('failed_tools', []))
all_successful_tools.extend(passive_result.get('successful_tools', []))
# 3c: crawlers (sites_file as input)
# 3b: crawlers (sites_file as input)
if sites_file_tools:
logger.info("Step 3c: run the crawler sub-Flow")
logger.info("Step 3b: run the crawler sub-Flow")
crawl_result = sites_url_fetch_flow(
scan_id=scan_id,
target_id=target_id,
@@ -443,8 +410,6 @@ def url_fetch_flow(
executed_tasks = ['setup_directory', 'classify_tools']
if domain_name_tools:
executed_tasks.append('domain_name_url_fetch_flow')
if domains_file_tools:
executed_tasks.append('domains_url_fetch_flow')
if sites_file_tools:
executed_tasks.append('sites_url_fetch_flow')
executed_tasks.append('merge_and_deduplicate')
@@ -463,7 +428,7 @@ def url_fetch_flow(
'total': saved_count,
'executed_tasks': executed_tasks,
'tool_stats': {
'total': len(domain_name_tools) + len(domains_file_tools) + len(sites_file_tools),
'total': len(domain_name_tools) + len(sites_file_tools),
'successful': len(all_successful_tools),
'failed': len(all_failed_tools),
'successful_tools': all_successful_tools,

View File
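Step 4 of the flow above merges every sub-Flow's result files and deduplicates before the uro/httpx post-processing; a minimal sketch of what _merge_and_deduplicate_urls (whose body is not shown in this diff) plausibly does:

# Hypothetical merge step; the real _merge_and_deduplicate_urls may differ.
from pathlib import Path

def merge_and_deduplicate(result_files: list[str], out_dir: Path) -> tuple[str, int]:
    urls = set()
    for rf in result_files:
        with open(rf, encoding='utf-8') as f:
            urls.update(line.strip() for line in f if line.strip())
    merged = out_dir / 'merged_urls.txt'
    merged.write_text('\n'.join(sorted(urls)) + '\n', encoding='utf-8')
    return str(merged), len(urls)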

@@ -19,33 +19,35 @@ from .utils import run_tools_parallel
logger = logging.getLogger(__name__)
def _export_sites_file(target_id: int, scan_id: int, output_dir: Path) -> tuple[str, int]:
def _export_sites_file(target_id: int, scan_id: int, target_name: str, output_dir: Path) -> tuple[str, int]:
"""
Export the site URL list to a file
Lazy-loading mode: if the WebSite table is empty, generate default URLs based on the Target type
Args:
target_id: target ID
scan_id: scan ID
target_name: target name (used for lazy loading)
output_dir: output directory
Returns:
tuple: (file_path, count)
"""
from apps.scan.tasks.url_fetch import export_target_assets_task
from apps.scan.tasks.url_fetch import export_sites_task
output_file = str(output_dir / "sites.txt")
result = export_target_assets_task(
result = export_sites_task(
output_file=output_file,
target_id=target_id,
scan_id=scan_id,
input_type="sites_file"
scan_id=scan_id
)
count = result['asset_count']
if count == 0:
logger.warning("站点列表为空,爬虫可能无法正常工作")
else:
if count > 0:
logger.info("✓ 站点列表导出完成 - 数量: %d", count)
else:
logger.warning("站点列表为空,爬虫可能无法正常工作")
return output_file, count
@@ -94,9 +96,11 @@ def sites_url_fetch_flow(
sites_file, sites_count = _export_sites_file(
target_id=target_id,
scan_id=scan_id,
target_name=target_name,
output_dir=output_path
)
# In default-value mode there is a default URL as input even when no sites existed originally
if sites_count == 0:
logger.warning("没有可用的站点,跳过爬虫")
return {

View File

@@ -17,6 +17,7 @@ def calculate_timeout_by_line_count(
tool_config: dict,
file_path: str,
base_per_time: int = 1,
min_timeout: int = 60,
) -> int:
"""
Auto-compute the timeout from the file's line count
@@ -25,9 +26,10 @@ def calculate_timeout_by_line_count(
tool_config: tool config (reserved; may be used for more complex calculations later)
file_path: input file path
base_per_time: base time per line (seconds)
min_timeout: minimum timeout (default 60 seconds)
Returns:
int: computed timeout (seconds)
int: computed timeout (seconds), floored at min_timeout
"""
try:
result = subprocess.run(
@@ -37,18 +39,19 @@ def calculate_timeout_by_line_count(
check=True,
)
line_count = int(result.stdout.strip().split()[0])
timeout = line_count * base_per_time
timeout = max(line_count * base_per_time, min_timeout)
logger.info(
"timeout 自动计算: 文件=%s, 行数=%d, 每行时间=%d秒, timeout=%d",
"timeout 自动计算: 文件=%s, 行数=%d, 每行时间=%d秒, 最小值=%d秒, timeout=%d",
file_path,
line_count,
base_per_time,
min_timeout,
timeout,
)
return timeout
except Exception as e:
logger.warning("wc -l 计算行数失败: %s,将使用默认 timeout: 600", e)
return 600
logger.warning("wc -l 计算行数失败: %s,将使用默认 timeout: %d", e, min_timeout)
return min_timeout
def prepare_tool_execution(

View File

@@ -25,10 +25,7 @@ from .utils import calculate_timeout_by_line_count
logger = logging.getLogger(__name__)
def _setup_vuln_scan_directory(scan_workspace_dir: str) -> Path:
vuln_scan_dir = Path(scan_workspace_dir) / "vuln_scan"
vuln_scan_dir.mkdir(parents=True, exist_ok=True)
return vuln_scan_dir
@flow(
@@ -55,7 +52,8 @@ def endpoints_vuln_scan_flow(
if not enabled_tools:
raise ValueError("enabled_tools 不能为空")
vuln_scan_dir = _setup_vuln_scan_directory(scan_workspace_dir)
from apps.scan.utils import setup_scan_directory
vuln_scan_dir = setup_scan_directory(scan_workspace_dir, 'vuln_scan')
endpoints_file = vuln_scan_dir / "input_endpoints.txt"
# Step 1: export Endpoint URLs
@@ -119,7 +117,6 @@ def endpoints_vuln_scan_flow(
)
raw_timeout = tool_config.get("timeout", 600)
timeout = 600
if isinstance(raw_timeout, str) and raw_timeout == "auto":
# With timeout=auto, compute the timeout from the endpoints_file line count
@@ -134,7 +131,6 @@ def endpoints_vuln_scan_flow(
try:
timeout = int(raw_timeout)
except (TypeError, ValueError) as e:
# Configuration errors should surface immediately; silently falling back to defaults makes troubleshooting hard
raise ValueError(
f"Invalid timeout config for tool {tool_name}: {raw_timeout!r}"
) from e
@@ -174,7 +170,7 @@ def endpoints_vuln_scan_flow(
target_id=target_id,
cwd=str(vuln_scan_dir),
shell=True,
batch_size=10,
batch_size=1,
timeout=timeout,
log_file=str(log_file),
)

View File

@@ -12,6 +12,7 @@ def calculate_timeout_by_line_count(
tool_config: dict,
file_path: str,
base_per_time: int = 1,
min_timeout: int = 600,
) -> int:
"""
Auto-compute the timeout from the file's line count
@@ -20,9 +21,10 @@ def calculate_timeout_by_line_count(
tool_config: tool config (reserved; may be used for more complex calculations later)
file_path: input file path
base_per_time: base time per line (seconds)
min_timeout: minimum timeout (default 600 seconds / 10 minutes)
Returns:
int: computed timeout (seconds)
int: computed timeout (seconds), floored at min_timeout
"""
try:
result = subprocess.run(
@@ -32,15 +34,16 @@ def calculate_timeout_by_line_count(
check=True,
)
line_count = int(result.stdout.strip().split()[0])
timeout = line_count * base_per_time
timeout = max(line_count * base_per_time, min_timeout)
logger.info(
"timeout 自动计算: 文件=%s, 行数=%d, 每行时间=%d秒, timeout=%d",
"timeout 自动计算: 文件=%s, 行数=%d, 每行时间=%d秒, 最小值=%d秒, timeout=%d",
file_path,
line_count,
base_per_time,
min_timeout,
timeout,
)
return timeout
except Exception as e:
logger.error("wc -l 计算行数失败: %s", e)
raise RuntimeError(f"自动计算超时时间失败: {e}") from e
logger.warning("wc -l 计算行数失败: %s,使用最小超时: %d", e, min_timeout)
return min_timeout

View File
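A quick worked example of the floor introduced above: with a 40-line endpoints file at 1 second per line, the raw product is 40, so the returned timeout is max(40 × 1, 600) = 600 seconds; only inputs longer than 600 lines grow past the minimum, and a wc -l failure now degrades to that same minimum instead of raising.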


@@ -1,567 +0,0 @@
"""
Management command for generating test data
Usage:
python manage.py generate_test_data --target test.com --count 100000
Performance testing:
python manage.py generate_test_data --target test.com --count 10000 --batch-size 500 --benchmark
"""
import random
import string
import time
from django.core.management.base import BaseCommand
from django.db import transaction, connection
from django.utils import timezone
from apps.targets.models import Target
from apps.scan.models import Scan
from apps.asset.models.asset_models import Subdomain, IPAddress, Port, WebSite, Directory
class Command(BaseCommand):
help = 'Generate large volumes of test data for a given target'
def add_arguments(self, parser):
parser.add_argument(
'--target',
type=str,
required=True,
help='Target domain (e.g. test.com)'
)
parser.add_argument(
'--count',
type=int,
default=100000,
help='Number of records to generate per table (default 100000)'
)
parser.add_argument(
'--batch-size',
type=int,
default=1000,
help='Batch size for bulk inserts (default 1000)'
)
parser.add_argument(
'--benchmark',
action='store_true',
help='Enable benchmark mode (prints detailed performance metrics)'
)
parser.add_argument(
'--test-batch-sizes',
action='store_true',
help='Benchmark different batch sizes (100, 500, 1000, 2000, 5000)'
)
def handle(self, *args, **options):
target_name = options['target']
count = options['count']
batch_size = options['batch_size']
benchmark = options['benchmark']
test_batch_sizes = options['test_batch_sizes']
# Batch-size benchmarking mode
if test_batch_sizes:
self._test_batch_sizes(target_name, count)
return
self.stdout.write(f'\n{"="*60}')
self.stdout.write(f' Starting test data generation')
self.stdout.write(f'{"="*60}\n')
self.stdout.write(f'Target: {target_name}')
self.stdout.write(f'Records per table: {count:,}')
self.stdout.write(f'Batch size: {batch_size:,}')
if benchmark:
self.stdout.write('Mode: performance benchmark ⚡')
self._print_db_info()
self.stdout.write('')
# Record the overall start time
total_start_time = time.time()
# 1. Fetch or create the target
try:
target = Target.objects.get(name=target_name)
self.stdout.write(self.style.SUCCESS(f'✓ Found target: {target.name} (ID: {target.id})'))
except Target.DoesNotExist:
self.stdout.write(self.style.ERROR(f'✗ Target does not exist: {target_name}'))
return
# 2. Create a new test scan task
from apps.engine.models import ScanEngine
engine = ScanEngine.objects.first()
if not engine:
self.stdout.write(self.style.ERROR('✗ No scan engine available'))
return
scan = Scan.objects.create(
target=target,
engine=engine,
status='completed',
results_dir=f'/tmp/test_{target_name}_{int(time.time())}'
)
self.stdout.write(self.style.SUCCESS(f'✓ Created new test scan task (ID: {scan.id})'))
# 3. Generate subdomains
self.stdout.write(f'\n[1/5] Generating {count:,} subdomains...')
subdomains, stats1 = self._generate_subdomains(target, scan, count, batch_size, benchmark)
# 4. Generate IP addresses
self.stdout.write(f'\n[2/5] Generating {count:,} IP addresses...')
ips, stats2 = self._generate_ips(target, scan, subdomains, count, batch_size, benchmark)
# 5. Generate ports
self.stdout.write(f'\n[3/5] Generating {count:,} ports...')
stats3 = self._generate_ports(scan, ips, subdomains, count, batch_size, benchmark)
# 6. Generate websites
self.stdout.write(f'\n[4/5] Generating {count:,} websites...')
websites, stats4 = self._generate_websites(target, scan, subdomains, count, batch_size, benchmark)
# 7. Generate directories
self.stdout.write(f'\n[5/5] Generating {count:,} directories...')
stats5 = self._generate_directories(target, scan, websites, count, batch_size, benchmark)
# Compute the total elapsed time
total_time = time.time() - total_start_time
self.stdout.write(f'\n{"="*60}')
self.stdout.write(self.style.SUCCESS(' ✓ Test data generation complete!'))
self.stdout.write(f'{"="*60}')
self.stdout.write(f'Total time: {total_time:.2f} s ({total_time/60:.2f} min)\n')
if benchmark:
self._print_performance_summary([stats1, stats2, stats3, stats4, stats5])
def _generate_subdomains(self, target, scan, count, batch_size, benchmark=False):
"""生成子域名"""
subdomains = []
created_subdomains = []
start_time = time.time()
batch_times = []
for i in range(count):
# Generate a unique subdomain
subdomain_name = f'test-{i:07d}.{target.name}'
subdomains.append(Subdomain(
target=target,
scan=scan,
name=subdomain_name,
cname=[],
is_cdn=random.choice([True, False]),
cdn_name=random.choice(['', 'cloudflare', 'akamai', 'fastly'])
))
# Bulk insert
if len(subdomains) >= batch_size:
batch_start = time.time()
with transaction.atomic():
created = Subdomain.objects.bulk_create(subdomains, ignore_conflicts=True)
created_subdomains.extend(created)
batch_time = time.time() - batch_start
batch_times.append(batch_time)
if benchmark:
speed = len(subdomains) / batch_time
self.stdout.write(f' Inserted {len(subdomains):,} | time: {batch_time:.2f}s | speed: {speed:.0f} rows/s')
else:
self.stdout.write(f' Inserting {len(subdomains):,} subdomains... (progress: {i+1:,}/{count:,})')
subdomains = []
# Insert the remainder
if subdomains:
with transaction.atomic():
created = Subdomain.objects.bulk_create(subdomains, ignore_conflicts=True)
created_subdomains.extend(created)
self.stdout.write(f' Inserting {len(subdomains):,} subdomains... (progress: {count:,}/{count:,})')
total_time = time.time() - start_time
avg_batch_time = sum(batch_times) / len(batch_times) if batch_times else 0
total_speed = len(created_subdomains) / total_time if total_time > 0 else 0
self.stdout.write(self.style.SUCCESS(
f' ✓ Done! Created {len(created_subdomains):,} total | '
f'total time: {total_time:.2f}s | '
f'average speed: {total_speed:.0f} rows/s'
))
return created_subdomains, {
'name': 'subdomains',
'count': len(created_subdomains),
'time': total_time,
'speed': total_speed,
'avg_batch_time': avg_batch_time
}
def _generate_ips(self, target, scan, subdomains, count, batch_size, benchmark=False):
"""生成 IP 地址"""
# 重新从数据库查询 subdomain确保有 ID
subdomain_list = list(Subdomain.objects.filter(scan=scan).values_list('id', flat=True))
ips = []
created_ips = []
start_time = time.time()
batch_times = []
for i in range(count):
# Generate a random IP
ip_addr = f'192.168.{random.randint(0, 255)}.{random.randint(1, 254)}'
subdomain_id = random.choice(subdomain_list) if subdomain_list else None
if subdomain_id:
ips.append(IPAddress(
target=target,
scan=scan,
subdomain_id=subdomain_id,
ip=f'{ip_addr}-{i}', # suffix appended to guarantee uniqueness
protocol_version='IPv4',
is_private=True
))
# Bulk insert
if len(ips) >= batch_size:
batch_start = time.time()
with transaction.atomic():
created = IPAddress.objects.bulk_create(ips, ignore_conflicts=True)
created_ips.extend(created)
batch_time = time.time() - batch_start
batch_times.append(batch_time)
if benchmark:
speed = len(ips) / batch_time
self.stdout.write(f' Inserted {len(ips):,} | time: {batch_time:.2f}s | speed: {speed:.0f} rows/s')
else:
self.stdout.write(f' Inserting {len(ips):,} IP addresses... (progress: {i+1:,}/{count:,})')
ips = []
# Insert the remainder
if ips:
with transaction.atomic():
created = IPAddress.objects.bulk_create(ips, ignore_conflicts=True)
created_ips.extend(created)
self.stdout.write(f' Inserting {len(ips):,} IP addresses... (progress: {count:,}/{count:,})')
total_time = time.time() - start_time
avg_batch_time = sum(batch_times) / len(batch_times) if batch_times else 0
total_speed = len(created_ips) / total_time if total_time > 0 else 0
self.stdout.write(self.style.SUCCESS(
f' ✓ Done! Created {len(created_ips):,} total | '
f'total time: {total_time:.2f}s | '
f'average speed: {total_speed:.0f} rows/s'
))
return created_ips, {
'name': 'IP addresses',
'count': len(created_ips),
'time': total_time,
'speed': total_speed,
'avg_batch_time': avg_batch_time
}
def _generate_ports(self, scan, ips, subdomains, count, batch_size, benchmark=False):
"""生成端口"""
# 重新查询 IP 和 subdomain 的 ID
ip_list = list(IPAddress.objects.filter(scan=scan).values_list('id', flat=True))
subdomain_list = list(Subdomain.objects.filter(scan=scan).values_list('id', flat=True))
ports = []
total_created = 0
start_time = time.time()
batch_times = []
for i in range(count):
ip_id = random.choice(ip_list) if ip_list else None
subdomain_id = random.choice(subdomain_list) if subdomain_list else None
if ip_id:
ports.append(Port(
ip_address_id=ip_id,
subdomain_id=subdomain_id,
number=random.randint(1, 65535),
service_name=random.choice(['http', 'https', 'ssh', 'ftp', 'mysql']),
is_uncommon=random.choice([True, False])
))
# 批量插入
if len(ports) >= batch_size:
batch_start = time.time()
with transaction.atomic():
Port.objects.bulk_create(ports, ignore_conflicts=True)
total_created += len(ports)
batch_time = time.time() - batch_start
batch_times.append(batch_time)
if benchmark:
speed = len(ports) / batch_time
self.stdout.write(f' 插入 {len(ports):,} 个 | 耗时: {batch_time:.2f}s | 速度: {speed:.0f} 条/秒')
else:
self.stdout.write(f' 插入 {len(ports):,} 个端口... (进度: {i+1:,}/{count:,})')
ports = []
# 插入剩余的
if ports:
with transaction.atomic():
Port.objects.bulk_create(ports, ignore_conflicts=True)
total_created += len(ports)
self.stdout.write(f' 插入 {len(ports):,} 个端口... (进度: {count:,}/{count:,})')
total_time = time.time() - start_time
avg_batch_time = sum(batch_times) / len(batch_times) if batch_times else 0
total_speed = total_created / total_time if total_time > 0 else 0
self.stdout.write(self.style.SUCCESS(
f' ✓ 完成!共创建 {total_created:,} 个 | '
f'总耗时: {total_time:.2f}s | '
f'平均速度: {total_speed:.0f} 条/秒'
))
return {
'name': '端口',
'count': total_created,
'time': total_time,
'speed': total_speed,
'avg_batch_time': avg_batch_time
}
def _generate_websites(self, target, scan, subdomains, count, batch_size, benchmark=False):
"""Generate websites"""
# Re-query subdomain info from the database
subdomain_data = list(Subdomain.objects.filter(scan=scan).values('id', 'name'))
websites = []
created_websites = []
start_time = time.time()
batch_times = []
for i in range(count):
subdomain = random.choice(subdomain_data) if subdomain_data else None
if subdomain:
protocol = random.choice(['http', 'https'])
url = f'{protocol}://{subdomain["name"]}'
websites.append(WebSite(
target=target,
scan=scan,
subdomain_id=subdomain['id'],
url=f'{url}?id={i}',  # append a query param to keep URLs unique
title=f'Test Website {i}',
status_code=random.choice([200, 301, 302, 404, 500]),
content_length=random.randint(1000, 100000),
webserver=random.choice(['nginx', 'apache', 'IIS']),
content_type='text/html',
tech=['Python', 'Django'] if i % 2 == 0 else ['Node.js', 'React'],
vhost=random.choice([True, False, None])
))
# Bulk insert once the batch is full
if len(websites) >= batch_size:
batch_start = time.time()
with transaction.atomic():
created = WebSite.objects.bulk_create(websites, ignore_conflicts=True)
created_websites.extend(created)
batch_time = time.time() - batch_start
batch_times.append(batch_time)
if benchmark:
speed = len(websites) / batch_time
self.stdout.write(f'  Inserted {len(websites):,} rows | time: {batch_time:.2f}s | speed: {speed:.0f} rows/s')
else:
self.stdout.write(f'  Inserted {len(websites):,} websites... (progress: {i+1:,}/{count:,})')
websites = []
# Insert the remainder
if websites:
with transaction.atomic():
created = WebSite.objects.bulk_create(websites, ignore_conflicts=True)
created_websites.extend(created)
self.stdout.write(f'  Inserted {len(websites):,} websites... (progress: {count:,}/{count:,})')
total_time = time.time() - start_time
avg_batch_time = sum(batch_times) / len(batch_times) if batch_times else 0
total_speed = len(created_websites) / total_time if total_time > 0 else 0
self.stdout.write(self.style.SUCCESS(
f'  ✓ Done! Created {len(created_websites):,} rows | '
f'total time: {total_time:.2f}s | '
f'avg speed: {total_speed:.0f} rows/s'
))
return created_websites, {
'name': 'Websites',
'count': len(created_websites),
'time': total_time,
'speed': total_speed,
'avg_batch_time': avg_batch_time
}
def _generate_directories(self, target, scan, websites, count, batch_size, benchmark=False):
"""Generate directories"""
# Re-query website info from the database
website_data = list(WebSite.objects.filter(scan=scan).values('id', 'url'))
directories = []
total_created = 0
start_time = time.time()
batch_times = []
for i in range(count):
website = random.choice(website_data) if website_data else None
if website:
path = ''.join(random.choices(string.ascii_lowercase, k=10))
directories.append(Directory(
target=target,
scan=scan,
website_id=website['id'],
url=f'{website["url"]}/dir/{path}/{i}',  # append a suffix to keep URLs unique
status=random.choice([200, 301, 403, 404]),
length=random.randint(1000, 50000),
words=random.randint(100, 5000),
lines=random.randint(50, 1000),
content_type='text/html'
))
# Bulk insert once the batch is full
if len(directories) >= batch_size:
batch_start = time.time()
with transaction.atomic():
Directory.objects.bulk_create(directories, ignore_conflicts=True)
total_created += len(directories)
batch_time = time.time() - batch_start
batch_times.append(batch_time)
if benchmark:
speed = len(directories) / batch_time
self.stdout.write(f'  Inserted {len(directories):,} rows | time: {batch_time:.2f}s | speed: {speed:.0f} rows/s')
else:
self.stdout.write(f'  Inserted {len(directories):,} directories... (progress: {i+1:,}/{count:,})')
directories = []
# Insert the remainder
if directories:
with transaction.atomic():
Directory.objects.bulk_create(directories, ignore_conflicts=True)
total_created += len(directories)
self.stdout.write(f'  Inserted {len(directories):,} directories... (progress: {count:,}/{count:,})')
total_time = time.time() - start_time
avg_batch_time = sum(batch_times) / len(batch_times) if batch_times else 0
total_speed = total_created / total_time if total_time > 0 else 0
self.stdout.write(self.style.SUCCESS(
f'  ✓ Done! Created {total_created:,} rows | '
f'total time: {total_time:.2f}s | '
f'avg speed: {total_speed:.0f} rows/s'
))
return {
'name': 'Directories',
'count': total_created,
'time': total_time,
'speed': total_speed,
'avg_batch_time': avg_batch_time
}
def _print_db_info(self):
"""Print database connection info"""
db_settings = connection.settings_dict
self.stdout.write(f'\nDatabase info:')
self.stdout.write(f'  Host: {db_settings["HOST"]}')
self.stdout.write(f'  Port: {db_settings["PORT"]}')
self.stdout.write(f'  Database: {db_settings["NAME"]}')
self.stdout.write(f'  Engine: {db_settings["ENGINE"].split(".")[-1]}')
def _print_performance_summary(self, stats_list):
"""Print a performance summary"""
self.stdout.write(f'\n{"="*60}')
self.stdout.write(' Performance Test Report')
self.stdout.write(f'{"="*60}\n')
total_records = sum(s['count'] for s in stats_list)
total_time = sum(s['time'] for s in stats_list)
overall_speed = total_records / total_time if total_time > 0 else 0
self.stdout.write(f'{"Table":<12} {"Records":<12} {"Time (s)":<12} {"Speed (rows/s)":<15} {"Avg batch (s)"}')
self.stdout.write('-' * 65)
for stats in stats_list:
self.stdout.write(
f'{stats["name"]:<12} '
f'{stats["count"]:<12,} '
f'{stats["time"]:<12.2f} '
f'{stats["speed"]:<15.0f} '
f'{stats.get("avg_batch_time", 0):<.3f}'
)
self.stdout.write('-' * 65)
self.stdout.write(
f'{"Total":<12} '
f'{total_records:<12,} '
f'{total_time:<12.2f} '
f'{overall_speed:<15.0f}'
)
self.stdout.write('')
def _test_batch_sizes(self, target_name, count):
"""Benchmark different batch sizes"""
batch_sizes = [100, 500, 1000, 2000, 5000]
test_count = min(count, 10000)  # cap the amount of test data
self.stdout.write(f'\n{"="*60}')
self.stdout.write(' Batch Size Benchmark')
self.stdout.write(f'{"="*60}\n')
self.stdout.write(f'Test data volume: {test_count:,}')
self.stdout.write(f'Batch sizes under test: {batch_sizes}\n')
results = []
for batch_size in batch_sizes:
self.stdout.write(f'\nTesting batch size: {batch_size}')
self.stdout.write('-' * 40)
# Only subdomain insertion performance is benchmarked here
try:
target = Target.objects.get(name=target_name)
except Target.DoesNotExist:
self.stdout.write(self.style.ERROR(f'Target does not exist: {target_name}'))
return
scan = Scan.objects.filter(target=target).first()
if not scan:
from apps.engine.models import ScanEngine
engine = ScanEngine.objects.first()
scan = Scan.objects.create(
target=target,
engine=engine,
status='completed',
results_dir=f'/tmp/test_{target_name}'
)
_, stats = self._generate_subdomains(target, scan, test_count, batch_size, benchmark=True)
results.append((batch_size, stats))
# Clean up test data
Subdomain.objects.filter(scan=scan, name__startswith='test-').delete()
# Print the comparison
self.stdout.write(f'\n{"="*60}')
self.stdout.write(' Batch Size Comparison')
self.stdout.write(f'{"="*60}\n')
self.stdout.write(f'{"Batch size":<12} {"Total time (s)":<15} {"Speed (rows/s)":<15} {"Avg batch (s)"}')
self.stdout.write('-' * 60)
for batch_size, stats in results:
self.stdout.write(
f'{batch_size:<12} '
f'{stats["time"]:<15.2f} '
f'{stats["speed"]:<15.0f} '
f'{stats["avg_batch_time"]:<.3f}'
)
# Find the fastest batch size
fastest = min(results, key=lambda x: x[1]['time'])
self.stdout.write(f'\nRecommended batch size: {fastest[0]} (fastest: {fastest[1]["time"]:.2f}s)')
self.stdout.write('')
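For reference, a minimal sketch of how this command could be driven from a test or shell, assuming the management command is named generate_test_data and exposes target, count, batch_size, and benchmark options (none of these names appear in this diff, so treat them all as hypothetical):

from django.core.management import call_command

# Hypothetical invocation: seed 10,000 fake subdomains for an existing target
# and benchmark each bulk_create batch along the way.
call_command(
    'generate_test_data',   # assumed command name
    target='example.com',   # assumed option names
    count=10000,
    batch_size=1000,
    benchmark=True,
)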

View File

@@ -87,8 +87,8 @@ def on_all_workers_high_load(sender, worker_name, cpu, mem, **kwargs):
"""Notification handler for when all workers are under high load"""
create_notification(
title="High system load",
message=f"All nodes are under high load; the least-loaded node {worker_name} (CPU: {cpu:.1f}%, memory: {mem:.1f}%) was selected to run the task. Scan speed may be affected",
message=f"All nodes are under high load (least-loaded node CPU: {cpu:.1f}%, memory: {mem:.1f}%). The system will wait up to 10 minutes before dispatching tasks. Scan speed may be affected",
level=NotificationLevel.MEDIUM,
category=NotificationCategory.SYSTEM
)
logger.warning("High-load notification sent - worker=%s, cpu=%.1f%%, mem=%.1f%%", worker_name, cpu, mem)
logger.warning("High-load notification sent - cpu=%.1f%%, mem=%.1f%%", cpu, mem)

View File

@@ -206,6 +206,10 @@ class FlowOrchestrator:
from apps.scan.flows.site_scan_flow import site_scan_flow
return site_scan_flow
elif scan_type == 'fingerprint_detect':
from apps.scan.flows.fingerprint_detect_flow import fingerprint_detect_flow
return fingerprint_detect_flow
elif scan_type == 'directory_scan':
from apps.scan.flows.directory_scan_flow import directory_scan_flow
return directory_scan_flow
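Each branch imports its flow locally, so a flow module is only loaded when its scan_type is actually dispatched, which keeps startup light and avoids import cycles between flow modules. A standalone sketch of the same deferred-import dispatch pattern (any name not shown in the hunk above is hypothetical):

def get_flow(scan_type: str):
    # Deferred imports: each flow module loads only on first dispatch.
    if scan_type == 'fingerprint_detect':
        from apps.scan.flows.fingerprint_detect_flow import fingerprint_detect_flow
        return fingerprint_detect_flow
    if scan_type == 'directory_scan':
        from apps.scan.flows.directory_scan_flow import directory_scan_flow
        return directory_scan_flow
    raise ValueError(f'Unknown scan_type: {scan_type}')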

View File

@@ -83,7 +83,7 @@ def cleanup_results(results_dir: str, retention_days: int) -> dict:
def main():
parser = argparse.ArgumentParser(description="Cleanup task")
parser.add_argument("--results_dir", type=str, default="/app/backend/results", help="Scan results directory")
parser.add_argument("--results_dir", type=str, default="/opt/xingrin/results", help="Scan results directory")
parser.add_argument("--retention_days", type=int, default=7, help="Retention days")
args = parser.parse_args()

View File

@@ -17,6 +17,8 @@ from .scan_state_service import ScanStateService
from .scan_control_service import ScanControlService
from .scan_stats_service import ScanStatsService
from .scheduled_scan_service import ScheduledScanService
from .blacklist_service import BlacklistService
from .target_export_service import TargetExportService
__all__ = [
'ScanService',  # main entry point (backward compatible)
@@ -25,5 +27,7 @@ __all__ = [
'ScanControlService',
'ScanStatsService',
'ScheduledScanService',
'BlacklistService',  # blacklist filtering service
'TargetExportService',  # target export service
]

View File

@@ -0,0 +1,85 @@
"""
黑名单过滤服务
过滤敏感域名(如 .gov、.edu、.mil 等)
当前版本使用默认规则,后续将支持从前端配置加载。
"""
from typing import List, Optional
from django.db.models import QuerySet
import re
import logging
logger = logging.getLogger(__name__)
class BlacklistService:
"""
黑名单过滤服务 - 过滤敏感域名
TODO: 后续版本支持从前端配置加载黑名单规则
- 用户在开始扫描时配置黑名单 URL、域名、IP
- 黑名单规则存储在数据库中,与 Scan 或 Engine 关联
"""
# 默认黑名单正则规则
DEFAULT_PATTERNS = [
r'\.gov$', # .gov 结尾
r'\.gov\.[a-z]{2}$', # .gov.cn, .gov.uk 等
r'\.edu$', # .edu 结尾
r'\.edu\.[a-z]{2}$', # .edu.cn 等
r'\.mil$', # .mil 结尾
]
def __init__(self, patterns: Optional[List[str]] = None):
"""
初始化黑名单服务
Args:
patterns: 正则表达式列表None 使用默认规则
"""
self.patterns = patterns or self.DEFAULT_PATTERNS
self._compiled_patterns = [re.compile(p) for p in self.patterns]
def filter_queryset(
self,
queryset: QuerySet,
url_field: str = 'url'
) -> QuerySet:
"""
数据库层面过滤 queryset
使用 PostgreSQL 正则表达式排除黑名单 URL
Args:
queryset: 原始 queryset
url_field: URL 字段名
Returns:
QuerySet: 过滤后的 queryset
"""
for pattern in self.patterns:
queryset = queryset.exclude(**{f'{url_field}__regex': pattern})
return queryset
def filter_url(self, url: str) -> bool:
"""
检查单个 URL 是否通过黑名单过滤
Args:
url: 要检查的 URL
Returns:
bool: True 表示通过不在黑名单False 表示被过滤
"""
for pattern in self._compiled_patterns:
if pattern.search(url):
return False
return True
# TODO: 后续版本实现
# @classmethod
# def from_scan(cls, scan_id: int) -> 'BlacklistService':
# """从数据库加载扫描配置的黑名单规则"""
# pass
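A quick usage sketch of the service defined above (the import path follows the apps.scan.services package seen elsewhere in this diff):

from apps.scan.services import BlacklistService

svc = BlacklistService()
svc.filter_url('https://portal.example.gov')   # False - matches the \.gov$ rule
svc.filter_url('https://www.example.com')      # True - not blacklisted

# Custom patterns replace the defaults entirely:
custom = BlacklistService(patterns=[r'\.internal$'])
custom.filter_url('http://api.internal')       # False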

View File

@@ -0,0 +1,295 @@
"""
快速扫描服务
负责解析用户输入URL、域名、IP、CIDR并创建对应的资产数据
"""
import logging
from dataclasses import dataclass
from typing import Optional, Literal, List, Dict, Any
from urllib.parse import urlparse
from django.db import transaction
from apps.common.validators import validate_url, detect_input_type, validate_domain, validate_ip, validate_cidr, is_valid_ip
from apps.targets.services.target_service import TargetService
from apps.targets.models import Target
from apps.asset.dtos import WebSiteDTO
from apps.asset.dtos.asset import EndpointDTO
from apps.asset.repositories.asset.website_repository import DjangoWebSiteRepository
from apps.asset.repositories.asset.endpoint_repository import DjangoEndpointRepository
logger = logging.getLogger(__name__)
@dataclass
class ParsedInputDTO:
"""
解析输入 DTO
只在快速扫描流程中使用
"""
original_input: str
input_type: Literal['url', 'domain', 'ip', 'cidr']
target_name: str # host/domain/ip/cidr
target_type: Literal['domain', 'ip', 'cidr']
website_url: Optional[str] = None # 根 URLscheme://host[:port]
endpoint_url: Optional[str] = None # 完整 URL含路径
is_valid: bool = True
error: Optional[str] = None
line_number: Optional[int] = None
class QuickScanService:
"""Quick scan service - parses inputs and creates assets"""
def __init__(self):
self.target_service = TargetService()
self.website_repo = DjangoWebSiteRepository()
self.endpoint_repo = DjangoEndpointRepository()
def parse_inputs(self, inputs: List[str]) -> List[ParsedInputDTO]:
"""
Parse multi-line input
Args:
inputs: list of input strings (one per line)
Returns:
List of parse results (empty lines are skipped)
"""
results = []
for line_number, input_str in enumerate(inputs, start=1):
input_str = input_str.strip()
# Skip empty lines
if not input_str:
continue
try:
# Detect the input type
input_type = detect_input_type(input_str)
if input_type == 'url':
dto = self._parse_url_input(input_str, line_number)
else:
dto = self._parse_target_input(input_str, input_type, line_number)
results.append(dto)
except ValueError as e:
# Parsing failed; record the error
results.append(ParsedInputDTO(
original_input=input_str,
input_type='domain',  # default type
target_name=input_str,
target_type='domain',
is_valid=False,
error=str(e),
line_number=line_number
))
return results
def _parse_url_input(self, url_str: str, line_number: int) -> ParsedInputDTO:
"""
Parse a URL input
Args:
url_str: URL string
line_number: line number
Returns:
ParsedInputDTO
"""
# Validate the URL format
validate_url(url_str)
# Parse with the standard library
parsed = urlparse(url_str)
host = parsed.hostname  # without port
has_path = parsed.path and parsed.path != '/'
# Build root_url: scheme://host[:port]
root_url = f"{parsed.scheme}://{parsed.netloc}"
# Detect the host type (domain or ip)
target_type = 'ip' if is_valid_ip(host) else 'domain'
return ParsedInputDTO(
original_input=url_str,
input_type='url',
target_name=host,
target_type=target_type,
website_url=root_url,
endpoint_url=url_str if has_path else None,
line_number=line_number
)
def _parse_target_input(
self,
input_str: str,
input_type: str,
line_number: int
) -> ParsedInputDTO:
"""
Parse a non-URL input (domain/ip/cidr)
Args:
input_str: input string
input_type: input type
line_number: line number
Returns:
ParsedInputDTO
"""
# Validate the format
if input_type == 'domain':
validate_domain(input_str)
target_type = 'domain'
elif input_type == 'ip':
validate_ip(input_str)
target_type = 'ip'
elif input_type == 'cidr':
validate_cidr(input_str)
target_type = 'cidr'
else:
raise ValueError(f"Unknown input type: {input_type}")
return ParsedInputDTO(
original_input=input_str,
input_type=input_type,
target_name=input_str,
target_type=target_type,
website_url=None,
endpoint_url=None,
line_number=line_number
)
@transaction.atomic
def process_quick_scan(
self,
inputs: List[str],
engine_id: int
) -> Dict[str, Any]:
"""
Handle a quick scan request
Args:
inputs: list of input strings
engine_id: scan engine ID
Returns:
Result dictionary
"""
# 1. Parse the inputs
parsed_inputs = self.parse_inputs(inputs)
# Separate valid and invalid inputs
valid_inputs = [p for p in parsed_inputs if p.is_valid]
invalid_inputs = [p for p in parsed_inputs if not p.is_valid]
if not valid_inputs:
return {
'targets': [],
'target_stats': {'created': 0, 'reused': 0, 'failed': len(invalid_inputs)},
'asset_stats': {'websites_created': 0, 'endpoints_created': 0},
'errors': [
{'line_number': p.line_number, 'input': p.original_input, 'error': p.error}
for p in invalid_inputs
]
}
# 2. Create the assets
asset_result = self.create_assets_from_parsed_inputs(valid_inputs)
# 3. Return the result
return {
'targets': asset_result['targets'],
'target_stats': asset_result['target_stats'],
'asset_stats': asset_result['asset_stats'],
'errors': [
{'line_number': p.line_number, 'input': p.original_input, 'error': p.error}
for p in invalid_inputs
]
}
def create_assets_from_parsed_inputs(
self,
parsed_inputs: List[ParsedInputDTO]
) -> Dict[str, Any]:
"""
Create assets from parse results
Args:
parsed_inputs: list of parse results (valid inputs only)
Returns:
Creation result dictionary
"""
# 1. Collect all target data (in-memory, deduplicated)
targets_data = {}
for dto in parsed_inputs:
if dto.target_name not in targets_data:
targets_data[dto.target_name] = {'name': dto.target_name, 'type': dto.target_type}
targets_list = list(targets_data.values())
# 2. Bulk-create Targets (reuses the existing method)
target_result = self.target_service.batch_create_targets(targets_list)
# 3. Query the just-created Targets and build a name → id map
target_names = [d['name'] for d in targets_list]
targets = Target.objects.filter(name__in=target_names)
target_id_map = {t.name: t.id for t in targets}
# 4. Collect Website DTOs (in-memory, deduplicated)
website_dtos = []
seen_websites = set()
for dto in parsed_inputs:
if dto.website_url and dto.website_url not in seen_websites:
seen_websites.add(dto.website_url)
target_id = target_id_map.get(dto.target_name)
if target_id:
website_dtos.append(WebSiteDTO(
target_id=target_id,
url=dto.website_url,
host=dto.target_name
))
# 5. Bulk-create Websites (existing rows are skipped)
websites_created = 0
if website_dtos:
websites_created = self.website_repo.bulk_create_ignore_conflicts(website_dtos)
# 6. Collect Endpoint DTOs (in-memory, deduplicated)
endpoint_dtos = []
seen_endpoints = set()
for dto in parsed_inputs:
if dto.endpoint_url and dto.endpoint_url not in seen_endpoints:
seen_endpoints.add(dto.endpoint_url)
target_id = target_id_map.get(dto.target_name)
if target_id:
endpoint_dtos.append(EndpointDTO(
target_id=target_id,
url=dto.endpoint_url,
host=dto.target_name
))
# 7. Bulk-create Endpoints (existing rows are skipped)
endpoints_created = 0
if endpoint_dtos:
endpoints_created = self.endpoint_repo.bulk_create_ignore_conflicts(endpoint_dtos)
return {
'targets': list(targets),
'target_stats': {
'created': target_result['created_count'],
'reused': 0,  # bulk_create cannot distinguish created from reused
'failed': target_result['failed_count']
},
'asset_stats': {
'websites_created': websites_created,
'endpoints_created': endpoints_created
}
}
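A usage sketch, under the assumption that QuickScanService is exported from its services package (the engine_id value below is hypothetical; it is accepted but not consumed in the code shown here):

service = QuickScanService()
result = service.process_quick_scan(
    inputs=[
        'https://example.com/admin/login',  # URL with a path: creates a Website and an Endpoint
        'example.org',                      # bare domain: creates a Target only
        '10.0.0.0/30',                      # CIDR: creates a Target only
        'not a valid input !!',             # expected to land in result['errors']
    ],
    engine_id=1,
)
print(result['target_stats'], result['asset_stats'], result['errors'])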

View File

@@ -0,0 +1,364 @@
"""
目标导出服务
提供统一的目标提取和文件导出功能,支持:
- URL 导出(流式写入 + 默认值回退)
- 域名/IP 导出(用于端口扫描)
- 黑名单过滤集成
"""
import ipaddress
import logging
from pathlib import Path
from typing import Dict, Any, Optional, Iterator
from django.db.models import QuerySet
from .blacklist_service import BlacklistService
logger = logging.getLogger(__name__)
class TargetExportService:
"""
目标导出服务 - 提供统一的目标提取和文件导出功能
使用方式:
# Task 层决定数据源
queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)
# 使用导出服务
blacklist_service = BlacklistService()
export_service = TargetExportService(blacklist_service=blacklist_service)
result = export_service.export_urls(target_id, output_path, queryset)
"""
def __init__(self, blacklist_service: Optional[BlacklistService] = None):
"""
初始化导出服务
Args:
blacklist_service: 黑名单过滤服务None 表示禁用过滤
"""
self.blacklist_service = blacklist_service
def export_urls(
self,
target_id: int,
output_path: str,
queryset: QuerySet,
url_field: str = 'url',
batch_size: int = 1000
) -> Dict[str, Any]:
"""
Unified URL export function
Automatically checks whether the database has data:
- With data: streams the database rows to the file
- Without data: calls the default-value generator to produce URLs
Args:
target_id: target ID
output_path: output file path
queryset: data-source queryset built by the Task layer; should be values_list(..., flat=True)
url_field: name of the URL field (used for blacklist filtering)
batch_size: batch size
Returns:
dict: {
'success': bool,
'output_file': str,
'total_count': int
}
Raises:
IOError: file write failed
"""
output_file = Path(output_path)
output_file.parent.mkdir(parents=True, exist_ok=True)
logger.info("Starting URL export - target_id=%s, output=%s", target_id, output_path)
# Apply blacklist filtering (database level)
if self.blacklist_service:
# Note: this would need the raw queryset, not a values_list
# Since the Task layer passes a values_list, database-level filtering must happen at the Task layer
pass
total_count = 0
try:
with open(output_file, 'w', encoding='utf-8', buffering=8192) as f:
for url in queryset.iterator(chunk_size=batch_size):
if url:
# Python-level blacklist filtering
if self.blacklist_service and not self.blacklist_service.filter_url(url):
continue
f.write(f"{url}\n")
total_count += 1
if total_count % 10000 == 0:
logger.info("Exported %d URLs...", total_count)
except IOError as e:
logger.error("File write failed: %s - %s", output_path, e)
raise
# Default-value fallback mode
if total_count == 0:
total_count = self._generate_default_urls(target_id, output_file)
logger.info("✓ URL export finished - count: %d, file: %s", total_count, output_path)
return {
'success': True,
'output_file': str(output_file),
'total_count': total_count
}
def _generate_default_urls(
self,
target_id: int,
output_path: Path
) -> int:
"""
Default-value generator (internal)
Generates default URLs based on the Target type:
- DOMAIN: http(s)://domain
- IP: http(s)://ip
- CIDR: expands to http(s)://ip for every IP
- URL: uses the target URL directly
Args:
target_id: target ID
output_path: output file path
Returns:
int: total number of URLs written
"""
from apps.targets.services import TargetService
from apps.targets.models import Target
target_service = TargetService()
target = target_service.get_target(target_id)
if not target:
logger.warning("Target ID %d does not exist; cannot generate default URLs", target_id)
return 0
target_name = target.name
target_type = target.type
logger.info("Lazy-load mode: Target type=%s, name=%s", target_type, target_name)
total_urls = 0
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
if target_type == Target.TargetType.DOMAIN:
urls = [f"http://{target_name}", f"https://{target_name}"]
for url in urls:
if self._should_write_url(url):
f.write(f"{url}\n")
total_urls += 1
elif target_type == Target.TargetType.IP:
urls = [f"http://{target_name}", f"https://{target_name}"]
for url in urls:
if self._should_write_url(url):
f.write(f"{url}\n")
total_urls += 1
elif target_type == Target.TargetType.CIDR:
try:
network = ipaddress.ip_network(target_name, strict=False)
for ip in network.hosts():
urls = [f"http://{ip}", f"https://{ip}"]
for url in urls:
if self._should_write_url(url):
f.write(f"{url}\n")
total_urls += 1
if total_urls % 10000 == 0:
logger.info("Generated %d URLs...", total_urls)
# Special-case /32 and /128 networks
if total_urls == 0:
ip = str(network.network_address)
urls = [f"http://{ip}", f"https://{ip}"]
for url in urls:
if self._should_write_url(url):
f.write(f"{url}\n")
total_urls += 1
except ValueError as e:
logger.error("CIDR parsing failed: %s - %s", target_name, e)
raise ValueError(f"Invalid CIDR: {target_name}") from e
elif target_type == Target.TargetType.URL:
if self._should_write_url(target_name):
f.write(f"{target_name}\n")
total_urls = 1
else:
logger.warning("Unsupported Target type: %s", target_type)
logger.info("✓ Lazily generated default URLs - count: %d", total_urls)
return total_urls
def _should_write_url(self, url: str) -> bool:
"""Check whether a URL should be written (passes the blacklist filter)"""
if self.blacklist_service:
return self.blacklist_service.filter_url(url)
return True
def export_targets(
self,
target_id: int,
output_path: str,
batch_size: int = 1000
) -> Dict[str, Any]:
"""
Domain/IP export function (for port scanning)
Chooses the export logic based on the Target type:
- DOMAIN: streams subdomains from the Subdomain table
- IP: writes the IP address directly
- CIDR: expands to all host IPs
Args:
target_id: target ID
output_path: output file path
batch_size: batch size
Returns:
dict: {
'success': bool,
'output_file': str,
'total_count': int,
'target_type': str
}
"""
from apps.targets.services import TargetService
from apps.targets.models import Target
from apps.asset.services.asset.subdomain_service import SubdomainService
output_file = Path(output_path)
output_file.parent.mkdir(parents=True, exist_ok=True)
# Fetch the Target
target_service = TargetService()
target = target_service.get_target(target_id)
if not target:
raise ValueError(f"Target ID {target_id} does not exist")
target_type = target.type
target_name = target.name
logger.info(
"Starting scan-target export - Target ID: %d, Name: %s, Type: %s, output file: %s",
target_id, target_name, target_type, output_path
)
total_count = 0
if target_type == Target.TargetType.DOMAIN:
total_count = self._export_domains(target_id, target_name, output_file, batch_size)
type_desc = "domains"
elif target_type == Target.TargetType.IP:
total_count = self._export_ip(target_name, output_file)
type_desc = "IP"
elif target_type == Target.TargetType.CIDR:
total_count = self._export_cidr(target_name, output_file)
type_desc = "CIDR IPs"
else:
raise ValueError(f"Unsupported target type: {target_type}")
logger.info(
"✓ Scan-target export finished - type: %s, total: %d, file: %s",
type_desc, total_count, output_path
)
return {
'success': True,
'output_file': str(output_file),
'total_count': total_count,
'target_type': target_type
}
def _export_domains(
self,
target_id: int,
target_name: str,
output_path: Path,
batch_size: int
) -> int:
"""Export the subdomains of a domain-type target"""
from apps.asset.services.asset.subdomain_service import SubdomainService
subdomain_service = SubdomainService()
domain_iterator = subdomain_service.iter_subdomain_names_by_target(
target_id=target_id,
chunk_size=batch_size
)
total_count = 0
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for domain_name in domain_iterator:
if self._should_write_target(domain_name):
f.write(f"{domain_name}\n")
total_count += 1
if total_count % 10000 == 0:
logger.info("Exported %d domains...", total_count)
# Default-value mode: fall back to the root domain when there are no subdomains
if total_count == 0:
logger.info("Falling back to the root domain: %s (target_id=%d)", target_name, target_id)
if self._should_write_target(target_name):
with open(output_path, 'w', encoding='utf-8') as f:
f.write(f"{target_name}\n")
total_count = 1
return total_count
def _export_ip(self, target_name: str, output_path: Path) -> int:
"""Export an IP-type target"""
if self._should_write_target(target_name):
with open(output_path, 'w', encoding='utf-8') as f:
f.write(f"{target_name}\n")
return 1
return 0
def _export_cidr(self, target_name: str, output_path: Path) -> int:
"""Export a CIDR-type target, expanded into individual IPs"""
network = ipaddress.ip_network(target_name, strict=False)
total_count = 0
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for ip in network.hosts():
ip_str = str(ip)
if self._should_write_target(ip_str):
f.write(f"{ip_str}\n")
total_count += 1
if total_count % 10000 == 0:
logger.info("Exported %d IPs...", total_count)
# Special-case /32 and /128 networks
if total_count == 0:
ip_str = str(network.network_address)
if self._should_write_target(ip_str):
with open(output_path, 'w', encoding='utf-8') as f:
f.write(f"{ip_str}\n")
total_count = 1
return total_count
def _should_write_target(self, target: str) -> bool:
"""Check whether a target should be written (passes the blacklist filter)"""
if self.blacklist_service:
return self.blacklist_service.filter_url(target)
return True
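A sketch of the two export paths, with a hypothetical target_id of 42 (the import paths match the services __init__ and the task module shown elsewhere in this diff):

from apps.scan.services import TargetExportService, BlacklistService
from apps.asset.models import WebSite

exporter = TargetExportService(blacklist_service=BlacklistService())

# Domain/IP export for port scanning; a DOMAIN target with no stored
# subdomains falls back to the root domain, so total_count is at least 1.
result = exporter.export_targets(target_id=42, output_path='/tmp/targets.txt')
print(result['total_count'], result['target_type'])

# URL export with the Task layer supplying the data source; an empty
# queryset triggers the default-URL generator.
qs = WebSite.objects.filter(target_id=42).values_list('url', flat=True)
result = exporter.export_urls(target_id=42, output_path='/tmp/urls.txt', queryset=qs)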

View File

@@ -9,9 +9,6 @@
- Tasks perform the concrete operations; Flows handle orchestration
"""
# Prefect Tasks
from .workspace_tasks import create_scan_workspace_task
# Subdomain discovery tasks (refactored into multiple subtasks)
from .subdomain_discovery import (
run_subdomain_discovery_task,
@@ -19,17 +16,25 @@ from .subdomain_discovery import (
save_domains_task,
)
# Fingerprint detection tasks
from .fingerprint_detect import (
export_urls_for_fingerprint_task,
run_xingfinger_and_stream_update_tech_task,
)
# Notes:
# - subdomain_discovery_task has been refactored into multiple subtasks (subdomain_discovery/)
# - finalize_scan_task is deprecated (the Handler manages state centrally)
# - initiate_scan_task moved to flows/initiate_scan_flow.py
# - cleanup_old_scans_task moved to flows (cleanup_old_scans_flow)
# - create_scan_workspace_task was removed; use setup_scan_workspace() directly
__all__ = [
# Prefect Tasks
'create_scan_workspace_task',
# Subdomain discovery tasks
'run_subdomain_discovery_task',
'merge_and_validate_task',
'save_domains_task',
# Fingerprint detection tasks
'export_urls_for_fingerprint_task',
'run_xingfinger_and_stream_update_tech_task',
]

View File

@@ -1,13 +1,14 @@
"""
Task that exports site URLs to a TXT file
Uses streaming to avoid memory exhaustion with large numbers of sites
Uses TargetExportService for unified export logic and default-value fallback
Data source: WebSite.url
"""
import logging
from pathlib import Path
from prefect import task
from apps.asset.repositories import DjangoWebSiteRepository
from apps.asset.models import WebSite
from apps.scan.services import TargetExportService, BlacklistService
logger = logging.getLogger(__name__)
@@ -16,12 +17,18 @@ logger = logging.getLogger(__name__)
def export_sites_task(
target_id: int,
output_file: str,
batch_size: int = 1000
batch_size: int = 1000,
) -> dict:
"""
Export all site URLs under a target to a TXT file
Uses streaming to support large-scale exports (100k+ sites)
Data source: WebSite.url
Lazy-load mode:
- If the database is empty, generates default URLs based on the Target type
- DOMAIN: http(s)://domain
- IP: http(s)://ip
- CIDR: expands to URLs for every IP
Args:
target_id: target ID
@@ -39,56 +46,26 @@ def export_sites_task(
ValueError: invalid arguments
IOError: file write failed
"""
try:
# Initialize the Repository
repository = DjangoWebSiteRepository()
logger.info("Starting site URL export - Target ID: %d, output file: %s", target_id, output_file)
# Make sure the output directory exists
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
# Stream site URLs via the Repository
url_iterator = repository.get_urls_for_export(
target_id=target_id,
batch_size=batch_size
)
# Stream-write to the file
total_count = 0
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for url in url_iterator:
# Handle one URL at a time: read and write as we go
f.write(f"{url}\n")
total_count += 1
# Log progress every 10000 records
if total_count % 10000 == 0:
logger.info("Exported %d site URLs...", total_count)
logger.info(
"✓ Site URL export finished - total: %d, file: %s (%.2f KB)",
total_count,
str(output_path),  # absolute path
output_path.stat().st_size / 1024
)
return {
'success': True,
'output_file': str(output_path),
'total_count': total_count
}
except FileNotFoundError as e:
logger.error("Output directory does not exist: %s", e)
raise
except PermissionError as e:
logger.error("Insufficient permission to write the file: %s", e)
raise
except Exception as e:
logger.exception("Site URL export failed: %s", e)
raise
# Build the data-source queryset (the Task layer decides the data source)
queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)
# Delegate the export to TargetExportService
blacklist_service = BlacklistService()
export_service = TargetExportService(blacklist_service=blacklist_service)
result = export_service.export_urls(
target_id=target_id,
output_path=output_file,
queryset=queryset,
batch_size=batch_size
)
# Keep the return format unchanged (backward compatible)
return {
'success': result['success'],
'output_file': result['output_file'],
'total_count': result['total_count']
}

View File

@@ -30,7 +30,6 @@ from django.db import IntegrityError, OperationalError, DatabaseError
from psycopg2 import InterfaceError
from dataclasses import dataclass
from apps.asset.services import WebSiteService
from apps.asset.dtos.snapshot import DirectorySnapshotDTO
from apps.scan.utils import execute_stream
@@ -48,7 +47,6 @@ class ServiceSet:
Provides the Service instances needed by directory scanning, making it easy to inject mocks in tests
"""
website: WebSiteService
snapshot: "DirectorySnapshotsService"
@classmethod
@@ -56,7 +54,6 @@ class ServiceSet:
"""创建默认的 Service 集合"""
from apps.asset.services.snapshot import DirectorySnapshotsService
return cls(
website=WebSiteService(),
snapshot=DirectorySnapshotsService()
)
@@ -79,19 +76,19 @@ def _parse_and_validate_line(line: str) -> Optional[dict]:
try:
# Step 1: parse JSON
try:
line_data = json.loads(line)
line_data = json.loads(line, strict=False)
except json.JSONDecodeError:
# logger.debug("Skipping non-JSON line: %s", line[:100])
# logger.debug("Skipping non-JSON line: %s", line)
return None
# Step 2: validate the data type
if not isinstance(line_data, dict):
logger.warning("Parsed data is not a dict, skipping: %s", str(line_data)[:100])
logger.debug("Skipping non-dict data")
return None
# Step 3: validate required fields
if not line_data.get('url'):
logger.debug("URL is empty, skipping")
logger.info("URL is empty, skipping - data: %s", str(line_data)[:200])
return None
# Return a valid record
@@ -105,8 +102,8 @@ def _parse_and_validate_line(line: str) -> Optional[dict]:
'duration': line_data.get('duration')
}
except Exception as e:
logger.error("Exception while parsing line: %s - data: %s", e, line[:100])
except Exception:
logger.info("Skipping unparseable line: %s", line[:100])
return None
@@ -176,7 +173,6 @@ def _parse_ffuf_stream_output(
def _save_batch_with_retry(
batch: list,
website_id: int,
scan_id: int,
target_id: int,
batch_num: int,
@@ -188,7 +184,6 @@ def _save_batch_with_retry(
Args:
batch: data batch
website_id: website ID
scan_id: scan task ID
target_id: target ID
batch_num: batch number
@@ -203,7 +198,7 @@ def _save_batch_with_retry(
"""
for attempt in range(max_retries):
try:
count = _save_batch(batch, website_id, scan_id, target_id, batch_num, services)
count = _save_batch(batch, scan_id, target_id, batch_num, services)
return {
'success': True,
'created_directories': count
@@ -257,7 +252,6 @@ def _save_batch_with_retry(
def _save_batch(
batch: list,
website_id: int,
scan_id: int,
target_id: int,
batch_num: int,
@@ -267,7 +261,7 @@ def _save_batch(
Save one batch of data to the database (via the snapshot Service)
Data relationship chain:
WebSite (existing) → DirectorySnapshot (to create) → Directory (auto-synced)
Target → DirectorySnapshot (to create) → Directory (auto-synced)
Processing flow:
1. Build DirectorySnapshotDTOs (including scan_id and target_id)
@@ -275,7 +269,6 @@
Args:
batch: data batch (list of dict)
website_id: website ID
scan_id: scan task ID
target_id: target ID
batch_num: batch number (for logging)
@@ -289,14 +282,11 @@ def _save_batch(
return 0
# ========== Step 1: prepare DirectorySnapshot data (in-memory, no transaction needed) ==========
snapshot_items = []
for record in batch:
# Create a DirectorySnapshot DTO
snapshot_dto = DirectorySnapshotDTO(
# Build the DTO list with a list comprehension
snapshot_items = [
DirectorySnapshotDTO(
scan_id=scan_id,
website_id=website_id,
target_id=target_id,  # redundant field, used to sync to the asset table
target_id=target_id,
url=record['url'],
status=record.get('status'),
content_length=record.get('length'),
@@ -305,8 +295,8 @@
content_type=record.get('content_type', ''),
duration=record.get('duration')
)
snapshot_items.append(snapshot_dto)
for record in batch
]
# ========== Step 2: save snapshots and sync to the asset table (via the snapshot Service) ==========
if snapshot_items:
@@ -373,19 +363,10 @@ def run_and_stream_save_directories_task(
# 1. Initialize services
services = ServiceSet.create_default()
# 2. Look up the website (via the Service)
website_id = services.website.get_by_url(url=site_url, target_id=target_id)
if website_id is None:
logger.error("Website does not exist: %s", site_url)
raise ValueError(f"Website does not exist: {site_url}")
logger.info("Found website: %s (ID: %d)", site_url, website_id)
# 3. Initialize resources
# 2. Initialize resources (no WebSite lookup needed any more; Directory is linked to Target directly)
data_generator = _parse_ffuf_stream_output(cmd=cmd, tool_name=tool_name, cwd=cwd, shell=shell, timeout=timeout, log_file=log_file)
# 4. Stream records and save in batches
# 3. Stream records and save in batches
total_records = 0
batch_num = 0
failed_batches = []
@@ -400,7 +381,7 @@ def run_and_stream_save_directories_task(
if len(batch) >= batch_size:
batch_num += 1
result = _save_batch_with_retry(
batch, website_id, scan_id, target_id, batch_num, services
batch, scan_id, target_id, batch_num, services
)
total_created += result.get('created_directories', 0)
@@ -421,7 +402,7 @@ def run_and_stream_save_directories_task(
if batch:
batch_num += 1
result = _save_batch_with_retry(
batch, website_id, scan_id, target_id, batch_num, services
batch, scan_id, target_id, batch_num, services
)
total_created += result.get('created_directories', 0)

Some files were not shown because too many files have changed in this diff