Mirror of https://github.com/yyhuni/xingrin.git (synced 2026-01-31 19:53:11 +08:00)

Compare commits: v1.0.36...v1.3.7-dev (198 commits)
Commit SHA1s in this range:

b4037202dc 4b4f9862bf 1c42e4978f 57bab63997 b1f0f18ac0 ccee5471b8 0ccd362535 7f2af7f7e2
4bd0f9e8c1 68cc996e3b f1e79d638e d484133e4c fc977ae029 f328474404 68e726a066 77a6f45909
49d1f1f1bb db8ecb1644 18cc016268 23bc463283 7b903b91b2 b3136d51b9 236c828041 fb13bb74d8
f076c682b6 9eda2caceb b1c9e202dd 918669bc29 fd70b0544d 0f2df7a5f3 857ab737b5 ee2d99edda
db6ce16aca ab800eca06 e8e5572339 d48d4bbcad d1cca4c083 df0810c863 d33e54c440 35a306fe8b
724df82931 8dfffdf802 b8cb85ce0b da96d437a4 feaf8062e5 4bab76f233 09416b4615 bc1c5f6b0e
2f2742e6fe be3c346a74 0c7a6fff12 3b4f0e3147 51212a2a0c 58533bbaf6 6ccca1602d 6389b0f672
d7599b8599 8eff298293 3634101c5b 163973a7df 80ffecba3e 3c21ac940c 5c9f484d70 7567f6c25b
0599a0b298 f7557fe90c 13571b9772 8ee76eef69 2a31e29aa2 81abc59961 ffbfec6dd5 a0091636a8
69490ab396 7306964abf cb6b0259e3 e1b4618e58 556dcf5f62 0628eef025 38ed8bc642 2f4d6a2168
c25cb9e06b b14ab71c7f 8b5060e2d3 3c9335febf 1b95e4f2c3 d20a600afc c29b11fd37 6caf707072
2627b1fc40 ec6712b9b4 9d5e4d5408 c5d5b24c8f 671cb56b62 51025f69a8 b2403b29c4 18ef01a47b
0bf8108fb3 837ad19131 d7de9a7129 22b4e51b42 d03628ee45 0baabe0753 e1191d7abf 82a2e9a0e7
1ccd1bc338 b4d42f5372 2c66450756 119d82dc89 fba7f7c508 99d384ce29 07f36718ab 7e3f69c208
5f90473c3c e2a815b96a f86a1a9d47 d5945679aa 51e2c51748 e2cbf98dda cd72bdf7c3 35abcf7e39
09f2d343a4 54d1f86bde a3997c9676 c90a55f85e 2eab88b452 1baf0eb5e1 b61e73f7be e896734dfc
cd83f52f35 3e29554c36 18e02b536e 4c1c6f70ab a72e7675f5 93c2163764 de72c91561 3e6d060b75
766f045904 8acfe1cc33 7aec3eabb2 b1f11c36a4 d97fb5245a ddf9a1f5a4 47f9f96a4b 6f43e73162
9b7d496f3e 6390849d52 7a6d2054f6 73ebaab232 11899b29c2 877d2a56d1 dc1e94f038 9c3833d13d
92f3b722ef 9ef503c666 c3a43e94fa d6d94355fb bc638eabf4 5acaada7ab aaad3f29cf f13eb2d9b2
f1b3b60382 e249056289 dba195b83a 9b494e6c67 2841157747 f6c1fef1a6 6ec0adf9dd 22c6661567
d9ed004e35 a0d9d1f29d 8aa9ed2a97 8baf29d1c3 248e48353a 0d210be50b f7c0d0b215 d83428f27b
45a09b8173 11dfdee6fd e53a884d13 3b318c89e3 e564bc116a 410c543066 66da140801 e60aac3622
14aaa71cb1 0309dba510 967ff8a69f 9ac23d50b6 265525c61e 1b9d05ce62 737980b30f 494ee81478
452686b282 c95c68f4e9 b02f38606d b543f3d2b7 a18fb46906 bb74f61ea2
.github/workflows/docker-build.yml (vendored, 52 lines changed)
@@ -44,6 +44,10 @@ jobs:
           dockerfile: docker/agent/Dockerfile
           context: .
           platforms: linux/amd64,linux/arm64
+        - image: xingrin-postgres
+          dockerfile: docker/postgres/Dockerfile
+          context: docker/postgres
+          platforms: linux/amd64,linux/arm64

     steps:
       - name: Checkout
@@ -106,33 +110,65 @@ jobs:
             ${{ steps.version.outputs.IS_RELEASE == 'true' && format('{0}/{1}:latest', env.IMAGE_PREFIX, matrix.image) || '' }}
           build-args: |
             IMAGE_TAG=${{ steps.version.outputs.VERSION }}
-          cache-from: type=gha
-          cache-to: type=gha,mode=max
+          cache-from: type=registry,ref=${{ env.IMAGE_PREFIX }}/${{ matrix.image }}:cache
+          cache-to: type=registry,ref=${{ env.IMAGE_PREFIX }}/${{ matrix.image }}:cache,mode=max
+          provenance: false
+          sbom: false

   # After all images build successfully, update the VERSION file
+  # Update the VERSION file on whichever branch the tag came from
   update-version:
     runs-on: ubuntu-latest
     needs: build
     if: startsWith(github.ref, 'refs/tags/v')
     steps:
-      - name: Checkout
+      - name: Checkout repository
         uses: actions/checkout@v4
         with:
-          ref: main
+          fetch-depth: 0  # fetch full history, used to determine which branch the tag is on
           token: ${{ secrets.GITHUB_TOKEN }}

+      - name: Determine source branch and version
+        id: branch
+        run: |
+          VERSION="${GITHUB_REF#refs/tags/}"
+
+          # Find the branches that contain this tag
+          BRANCHES=$(git branch -r --contains ${{ github.ref_name }})
+          echo "Branches containing tag: $BRANCHES"
+
+          # Decide which branch the tag came from
+          if echo "$BRANCHES" | grep -q "origin/main"; then
+            TARGET_BRANCH="main"
+            UPDATE_LATEST="true"
+          elif echo "$BRANCHES" | grep -q "origin/dev"; then
+            TARGET_BRANCH="dev"
+            UPDATE_LATEST="false"
+          else
+            echo "Warning: Tag not found in main or dev branch, defaulting to main"
+            TARGET_BRANCH="main"
+            UPDATE_LATEST="false"
+          fi
+
+          echo "BRANCH=$TARGET_BRANCH" >> $GITHUB_OUTPUT
+          echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
+          echo "UPDATE_LATEST=$UPDATE_LATEST" >> $GITHUB_OUTPUT
+          echo "Will update VERSION on branch: $TARGET_BRANCH"
+
+      - name: Checkout target branch
+        run: |
+          git checkout ${{ steps.branch.outputs.BRANCH }}
+
       - name: Update VERSION file
         run: |
-          VERSION="${GITHUB_REF#refs/tags/}"
+          VERSION="${{ steps.branch.outputs.VERSION }}"
           echo "$VERSION" > VERSION
-          echo "Updated VERSION to $VERSION"
+          echo "Updated VERSION to $VERSION on branch ${{ steps.branch.outputs.BRANCH }}"

       - name: Commit and push
         run: |
           git config user.name "github-actions[bot]"
           git config user.email "github-actions[bot]@users.noreply.github.com"
           git add VERSION
-          git diff --staged --quiet || git commit -m "chore: bump version to ${GITHUB_REF#refs/tags/}"
-          git push
+          git diff --staged --quiet || git commit -m "chore: bump version to ${{ steps.branch.outputs.VERSION }}"
+          git push origin ${{ steps.branch.outputs.BRANCH }}
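For local verification, the branch-resolution step above reduces to a single `git branch -r --contains` call. A minimal Python sketch of the same decision logic (the helper `branch_for_tag` is ours, not part of the repository):

```python
# Minimal local sketch of the workflow's branch-resolution step, assuming git
# is on PATH and the clone has origin/main and origin/dev tracking branches.
import subprocess

def branch_for_tag(tag: str) -> tuple[str, bool]:
    """Return (target_branch, update_latest), mirroring the workflow logic."""
    out = subprocess.run(
        ["git", "branch", "-r", "--contains", tag],
        capture_output=True, text=True, check=True,
    ).stdout
    if "origin/main" in out:
        return "main", True
    if "origin/dev" in out:
        return "dev", False
    # Mirrors the workflow's fallback: warn-and-default-to-main path
    return "main", False

print(branch_for_tag("v1.3.7-dev"))
```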
.gitignore (vendored, 2 lines changed)
@@ -132,3 +132,5 @@ temp/

 HGETALL
 KEYS
+vuln_scan/input_endpoints.txt
+open-in-v0
README.md (88 lines changed)
@@ -13,35 +13,28 @@

 <p align="center">
   <a href="#-功能特性">Features</a> •
+  <a href="#-全局资产搜索">Asset Search</a> •
   <a href="#-快速开始">Quick Start</a> •
   <a href="#-文档">Documentation</a> •
   <a href="#-技术栈">Tech Stack</a> •
   <a href="#-反馈与贡献">Feedback & Contributing</a>
 </p>

 <p align="center">
-  <sub>🔍 Keywords: ASM | Attack Surface Management | Vulnerability Scanning | Asset Discovery | Bug Bounty | Penetration Testing | Nuclei | Subdomain Enumeration | EASM</sub>
+  <sub>🔍 Keywords: ASM | Attack Surface Management | Vulnerability Scanning | Asset Discovery | Asset Search | Bug Bounty | Penetration Testing | Nuclei | Subdomain Enumeration | EASM</sub>
 </p>

 ---

 <p align="center">
-  <b>🌗 Light/Dark Mode Toggle</b>
+  <b>🎨 Modern UI</b>
 </p>

 <p align="center">
-  <img src="docs/screenshots/light.png" alt="Light Mode" width="49%">
-  <img src="docs/screenshots/dark.png" alt="Dark Mode" width="49%">
-</p>
-
-<p align="center">
-  <b>🎨 Multiple UI Themes</b>
-</p>
-
-<p align="center">
-  <img src="docs/screenshots/bubblegum.png" alt="Bubblegum" width="32%">
-  <img src="docs/screenshots/cosmic-night.png" alt="Cosmic Night" width="32%">
-  <img src="docs/screenshots/quantum-rose.png" alt="Quantum Rose" width="32%">
+  <img src="docs/screenshots/light.png" alt="Light Mode" width="24%">
+  <img src="docs/screenshots/bubblegum.png" alt="Bubblegum" width="24%">
+  <img src="docs/screenshots/cosmic-night.png" alt="Cosmic Night" width="24%">
+  <img src="docs/screenshots/quantum-rose.png" alt="Quantum Rose" width="24%">
 </p>

 ## 📚 Documentation
@@ -69,9 +62,14 @@
 - **Custom pipelines** - Scan pipelines defined in YAML for flexible orchestration
 - **Scheduled scans** - Cron-expression configuration for automated periodic scans

+### 🔖 Fingerprinting
+- **Multi-source fingerprint library** - 27,000+ built-in rules from EHole, Goby, Wappalyzer, Fingers, FingerPrintHub, ARL, and more
+- **Automatic detection** - Runs automatically in the scan pipeline to identify a web application's tech stack
+- **Fingerprint management** - Query, import, and export fingerprint rules
+
 #### Scan Pipeline Architecture

-The full scan pipeline includes: subdomain discovery, port scanning, site discovery, URL collection, directory scanning, and vulnerability scanning
+The full scan pipeline includes: subdomain discovery, port scanning, site discovery, fingerprinting, URL collection, directory scanning, and vulnerability scanning

 ```mermaid
 flowchart LR
@@ -82,7 +80,8 @@ flowchart LR
         SUB["Subdomain discovery<br/>subfinder, amass, puredns"]
         PORT["Port scanning<br/>naabu"]
         SITE["Site detection<br/>httpx"]
-        SUB --> PORT --> SITE
+        FINGER["Fingerprinting<br/>xingfinger"]
+        SUB --> PORT --> SITE --> FINGER
     end

     subgraph STAGE2["Stage 2: Deep analysis"]
@@ -98,7 +97,7 @@ flowchart LR
     FINISH["Scan complete"]

     START --> STAGE1
-    SITE --> STAGE2
+    FINGER --> STAGE2
     STAGE2 --> STAGE3
     STAGE3 --> FINISH

@@ -110,6 +109,7 @@ flowchart LR
     style SUB fill:#5dade2,stroke:#3498db,stroke-width:1px,color:#fff
     style PORT fill:#5dade2,stroke:#3498db,stroke-width:1px,color:#fff
     style SITE fill:#5dade2,stroke:#3498db,stroke-width:1px,color:#fff
+    style FINGER fill:#5dade2,stroke:#3498db,stroke-width:1px,color:#fff
     style URL fill:#bb8fce,stroke:#9b59b6,stroke-width:1px,color:#fff
     style DIR fill:#bb8fce,stroke:#9b59b6,stroke-width:1px,color:#fff
     style VULN fill:#f0b27a,stroke:#e67e22,stroke-width:1px,color:#fff
@@ -162,25 +162,42 @@ flowchart TB
     W3 -.heartbeat report.-> REDIS
 ```

+### 🔎 Global Asset Search
+- **Multiple asset types** - Supports both Website and Endpoint asset types
+- **Expression syntax** - Supports the `=` (fuzzy), `==` (exact), and `!=` (not equal) operators
+- **Logical combination** - Supports `&&` (AND) and `||` (OR) combinations
+- **Multi-field queries** - Supports the host, url, title, tech, status, body, and header fields
+- **CSV export** - Streams the full result set to CSV, with no row limit
+
+#### Search Syntax Examples
+
+```bash
+# Basic searches
+host="api"        # host contains "api"
+status=="200"     # status code is exactly 200
+tech="nginx"      # tech stack contains nginx
+
+# Combined searches
+host="api" && status=="200"   # host contains api AND status code is 200
+tech="vue" || tech="react"    # tech stack contains vue OR react
+
+# Complex queries
+host="admin" && tech="php" && status=="200"
+url="/api/v1" && status!="404"
+```
+
 ### 📊 Visual Interface
 - **Statistics** - Asset/vulnerability statistics dashboard
 - **Real-time notifications** - WebSocket message push
 - **Dark theme** - Light/dark theme switching
+- **Notification push** - Real-time message delivery to WeCom, Telegram, and Discord

 ---

 ## 🛠️ Tech Stack

 - **Frontend**: Next.js + React + TailwindCSS
 - **Backend**: Django + Django REST Framework
 - **Database**: PostgreSQL + Redis
 - **Deployment**: Docker + Nginx

 ## 📦 Quick Start

 ### Requirements

-- **OS**: Ubuntu 20.04+ / Debian 11+ (recommended)
+- **OS**: Ubuntu 20.04+ / Debian 11+
 - **Hardware**: 2 CPU cores and 4 GB RAM minimum, 20 GB+ disk space

 ### One-Command Install
@@ -192,11 +209,20 @@ cd xingrin

 # Install and start (production mode)
 sudo ./install.sh
+
+# 🇨🇳 Users in mainland China should use mirror acceleration (third-party mirror services may fail; long-term availability is not guaranteed)
+sudo ./install.sh --mirror
 ```

+> **💡 About the --mirror flag**
+> - Automatically configures Docker registry mirrors (mainland China sources)
+> - Speeds up Git repository cloning (Nuclei templates, etc.)
+> - Greatly shortens install time and avoids network timeouts
+
 ### Accessing the Service

-- **Web UI**: `https://localhost`
+- **Web UI**: `https://ip:8083`
 - **Default account**: admin / admin (change the password after first login)

 ### Common Commands

@@ -212,22 +238,18 @@ sudo ./restart.sh

 # Uninstall
 sudo ./uninstall.sh
+
+# Update
+sudo ./update.sh
 ```

 ## 🤝 Feedback & Contributing

-- 🐛 **Found a bug?** Submit it via this link: [Issue](https://github.com/yyhuni/xingrin/issues)
-- 💡 **Have new ideas, e.g. UI or feature design?** Suggestions are welcome via this link: [Issue](https://github.com/yyhuni/xingrin/issues)
-- 🔧 **Want to contribute?** Follow my WeChat official account and contact me directly
-
 ## 📧 Contact
+- So far I am the only user of the current version, so there may be many edge-case issues
+- For questions, suggestions, or anything else, please file an [Issue](https://github.com/yyhuni/xingrin/issues) first; you can also message my WeChat official account directly, and I will reply

-- WeChat official account: **洋洋的小黑屋**
+- WeChat official account: **塔罗安全学苑**

 <img src="docs/wechat-qrcode.png" alt="WeChat official account" width="200">
backend/apps/asset/apps.py

@@ -1,5 +1,10 @@
+import logging
+import sys
+
 from django.apps import AppConfig

+logger = logging.getLogger(__name__)
+

 class AssetConfig(AppConfig):
     default_auto_field = 'django.db.models.BigAutoField'

@@ -8,3 +13,94 @@ class AssetConfig(AppConfig):
     def ready(self):
         # Import all models so Django discovers and registers them
         from . import models
+
+        # Enable the pg_trgm extension (used by the fuzzy text-search indexes)
+        # Needed when upgrading an existing database
+        self._ensure_pg_trgm_extension()
+
+        # Verify that the pg_ivm extension is available (used for IMMV incremental maintenance)
+        self._verify_pg_ivm_extension()
+
+    def _ensure_pg_trgm_extension(self):
+        """
+        Ensure the pg_trgm extension is enabled.
+        The extension backs the GIN indexes on the response_body and
+        response_headers fields, enabling efficient fuzzy text search.
+        """
+        from django.db import connection
+
+        # Only applies to PostgreSQL databases
+        if connection.vendor != 'postgresql':
+            logger.debug("Skipping pg_trgm extension: current database is not PostgreSQL")
+            return
+
+        try:
+            with connection.cursor() as cursor:
+                cursor.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
+                logger.debug("pg_trgm extension enabled")
+        except Exception as e:
+            # Log the error but do not block application startup.
+            # Common cause: insufficient privileges (superuser required).
+            logger.warning(
+                "Could not create the pg_trgm extension: %s. "
+                "The GIN indexes on response_body and response_headers may not work. "
+                "Please run manually: CREATE EXTENSION IF NOT EXISTS pg_trgm;",
+                str(e)
+            )
+
+    def _verify_pg_ivm_extension(self):
+        """
+        Verify that the pg_ivm extension is available.
+        pg_ivm powers the IMMVs (incrementally maintained materialized views)
+        and is required by the system. If it is unavailable, log an error and exit.
+        """
+        from django.db import connection
+
+        # Only applies to PostgreSQL databases
+        if connection.vendor != 'postgresql':
+            logger.debug("Skipping pg_ivm verification: current database is not PostgreSQL")
+            return
+
+        # Skip certain management commands (e.g. migrate, makemigrations)
+        import sys
+        if len(sys.argv) > 1 and sys.argv[1] in ('migrate', 'makemigrations', 'collectstatic', 'check'):
+            logger.debug("Skipping pg_ivm verification: running a management command")
+            return
+
+        try:
+            with connection.cursor() as cursor:
+                # Check whether the pg_ivm extension is installed
+                cursor.execute("""
+                    SELECT COUNT(*) FROM pg_extension WHERE extname = 'pg_ivm'
+                """)
+                count = cursor.fetchone()[0]
+
+                if count > 0:
+                    logger.info("✓ pg_ivm extension enabled")
+                else:
+                    # Try to create the extension
+                    try:
+                        cursor.execute("CREATE EXTENSION IF NOT EXISTS pg_ivm;")
+                        logger.info("✓ pg_ivm extension created and enabled")
+                    except Exception as create_error:
+                        logger.error(
+                            "=" * 60 + "\n"
+                            "Error: the pg_ivm extension is not installed\n"
+                            + "=" * 60 + "\n"
+                            "pg_ivm is required by the system for incrementally maintained materialized views.\n\n"
+                            "Install pg_ivm on the PostgreSQL server:\n"
+                            "  curl -sSL https://raw.githubusercontent.com/yyhuni/xingrin/main/docker/scripts/install-pg-ivm.sh | sudo bash\n\n"
+                            "Or install manually:\n"
+                            "  1. apt install build-essential postgresql-server-dev-15 git\n"
+                            "  2. git clone https://github.com/sraoss/pg_ivm.git && cd pg_ivm && make && make install\n"
+                            "  3. Add to postgresql.conf: shared_preload_libraries = 'pg_ivm'\n"
+                            "  4. Restart PostgreSQL\n"
+                            + "=" * 60
+                        )
+                        # Exit in production; warn only in development
+                        from django.conf import settings
+                        if not settings.DEBUG:
+                            sys.exit(1)
+
+        except Exception as e:
+            logger.error(f"pg_ivm extension verification failed: {e}")
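The pg_trgm indexes exist so that substring filters stay fast at scale. A minimal ORM sketch of the kind of query they serve, assuming the `Endpoint` model and index names defined later in this changeset (everything else is illustrative):

```python
# `icontains` compiles to ILIKE '%...%', which PostgreSQL can answer from a
# gin_trgm_ops index such as endpoint_resp_headers_trgm_idx instead of a
# sequential scan over every stored response.
from apps.asset.models import Endpoint

urls = (
    Endpoint.objects
    .filter(response_headers__icontains="nginx")  # fuzzy substring match
    .values_list("url", flat=True)[:20]
)
print(list(urls))
```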
backend/apps/asset/dtos/asset.py

@@ -14,12 +14,13 @@ class EndpointDTO:
     status_code: Optional[int] = None
     content_length: Optional[int] = None
     webserver: Optional[str] = None
-    body_preview: Optional[str] = None
+    response_body: Optional[str] = None
     content_type: Optional[str] = None
     tech: Optional[List[str]] = None
     vhost: Optional[bool] = None
     location: Optional[str] = None
     matched_gf_patterns: Optional[List[str]] = None
+    response_headers: Optional[str] = None

     def __post_init__(self):
         if self.tech is None:
@@ -9,7 +9,7 @@ class WebSiteDTO:
     """Website data transfer object"""
     target_id: int
     url: str
-    host: str
+    host: str = ''
     title: str = ''
     status_code: Optional[int] = None
     content_length: Optional[int] = None

@@ -17,9 +17,10 @@ class WebSiteDTO:
     webserver: str = ''
     content_type: str = ''
     tech: List[str] = None
-    body_preview: str = ''
+    response_body: str = ''
     vhost: Optional[bool] = None
     created_at: str = None
+    response_headers: str = ''

     def __post_init__(self):
         if self.tech is None:
@@ -13,6 +13,7 @@ class EndpointSnapshotDTO:
     Snapshots belong only to a scan.
     """
     scan_id: int
+    target_id: int  # required; used to sync into the asset tables
     url: str
     host: str = ''  # hostname (domain or IP address)
     title: str = ''

@@ -22,10 +23,10 @@ class EndpointSnapshotDTO:
     webserver: str = ''
     content_type: str = ''
     tech: List[str] = None
-    body_preview: str = ''
+    response_body: str = ''
     vhost: Optional[bool] = None
     matched_gf_patterns: List[str] = None
-    target_id: Optional[int] = None  # redundant field, used to sync into the asset tables
+    response_headers: str = ''

     def __post_init__(self):
         if self.tech is None:

@@ -42,9 +43,6 @@ class EndpointSnapshotDTO:
         """
         from apps.asset.dtos.asset import EndpointDTO

-        if self.target_id is None:
-            raise ValueError("target_id cannot be None; cannot sync into the asset tables")
-
         return EndpointDTO(
             target_id=self.target_id,
             url=self.url,

@@ -53,10 +51,11 @@ class EndpointSnapshotDTO:
             status_code=self.status_code,
             content_length=self.content_length,
             webserver=self.webserver,
-            body_preview=self.body_preview,
+            response_body=self.response_body,
             content_type=self.content_type,
             tech=self.tech if self.tech else [],
             vhost=self.vhost,
             location=self.location,
-            matched_gf_patterns=self.matched_gf_patterns if self.matched_gf_patterns else []
+            matched_gf_patterns=self.matched_gf_patterns if self.matched_gf_patterns else [],
+            response_headers=self.response_headers,
         )
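Putting the snapshot-to-asset conversion together, a usage sketch; the conversion method name is an assumption, since the diff shows only the method body, and `EndpointSnapshotDTO` is assumed to be imported from its defining module:

```python
# Hypothetical usage of the conversion shown above; `to_endpoint_dto` is an
# assumed name for the enclosing method, whose signature the diff omits.
snap = EndpointSnapshotDTO(
    scan_id=1,
    target_id=42,  # now required, so the snapshot can sync into the asset tables
    url="https://api.example.com/v1/users",
    host="api.example.com",
    status_code=200,
    tech=["nginx"],
    response_headers="Server: nginx\r\nContent-Type: application/json",
)
endpoint_dto = snap.to_endpoint_dto()  # assumed method name
assert endpoint_dto.target_id == 42
assert endpoint_dto.response_headers.startswith("Server:")
```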
@@ -13,18 +13,19 @@ class WebsiteSnapshotDTO:
     Snapshots belong only to a scan; target information comes from scan.target.
     """
     scan_id: int
-    target_id: int  # only used to pass data through, not saved to the database
+    target_id: int  # required; used to sync into the asset tables
     url: str
     host: str
     title: str = ''
-    status: Optional[int] = None
+    status_code: Optional[int] = None  # unified naming: status -> status_code
     content_length: Optional[int] = None
     location: str = ''
-    web_server: str = ''
+    webserver: str = ''  # unified naming: web_server -> webserver
     content_type: str = ''
     tech: List[str] = None
-    body_preview: str = ''
+    response_body: str = ''
     vhost: Optional[bool] = None
+    response_headers: str = ''

     def __post_init__(self):
         if self.tech is None:

@@ -44,12 +45,13 @@ class WebsiteSnapshotDTO:
             url=self.url,
             host=self.host,
             title=self.title,
-            status_code=self.status,
+            status_code=self.status_code,
             content_length=self.content_length,
             location=self.location,
-            webserver=self.web_server,
+            webserver=self.webserver,
             content_type=self.content_type,
             tech=self.tech if self.tech else [],
-            body_preview=self.body_preview,
-            vhost=self.vhost
+            response_body=self.response_body,
+            vhost=self.vhost,
+            response_headers=self.response_headers,
         )
backend/apps/asset/migrations/0001_initial.py (new file, 345 lines)
@@ -0,0 +1,345 @@
# Generated by Django 5.2.7 on 2026-01-02 04:45

import django.contrib.postgres.fields
import django.contrib.postgres.indexes
import django.core.validators
import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):

    initial = True

    dependencies = [
        ('scan', '0001_initial'),
        ('targets', '0001_initial'),
    ]

    operations = [
        migrations.CreateModel(
            name='AssetStatistics',
            fields=[
                ('id', models.AutoField(primary_key=True, serialize=False)),
                ('total_targets', models.IntegerField(default=0, help_text='Total targets')),
                ('total_subdomains', models.IntegerField(default=0, help_text='Total subdomains')),
                ('total_ips', models.IntegerField(default=0, help_text='Total IP addresses')),
                ('total_endpoints', models.IntegerField(default=0, help_text='Total endpoints')),
                ('total_websites', models.IntegerField(default=0, help_text='Total websites')),
                ('total_vulns', models.IntegerField(default=0, help_text='Total vulnerabilities')),
                ('total_assets', models.IntegerField(default=0, help_text='Total assets (subdomains + IPs + endpoints + websites)')),
                ('prev_targets', models.IntegerField(default=0, help_text='Previous total targets')),
                ('prev_subdomains', models.IntegerField(default=0, help_text='Previous total subdomains')),
                ('prev_ips', models.IntegerField(default=0, help_text='Previous total IP addresses')),
                ('prev_endpoints', models.IntegerField(default=0, help_text='Previous total endpoints')),
                ('prev_websites', models.IntegerField(default=0, help_text='Previous total websites')),
                ('prev_vulns', models.IntegerField(default=0, help_text='Previous total vulnerabilities')),
                ('prev_assets', models.IntegerField(default=0, help_text='Previous total assets')),
                ('updated_at', models.DateTimeField(auto_now=True, help_text='Last updated')),
            ],
            options={
                'verbose_name': 'Asset statistics',
                'verbose_name_plural': 'Asset statistics',
                'db_table': 'asset_statistics',
            },
        ),
        migrations.CreateModel(
            name='StatisticsHistory',
            fields=[
                ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('date', models.DateField(help_text='Statistics date', unique=True)),
                ('total_targets', models.IntegerField(default=0, help_text='Total targets')),
                ('total_subdomains', models.IntegerField(default=0, help_text='Total subdomains')),
                ('total_ips', models.IntegerField(default=0, help_text='Total IP addresses')),
                ('total_endpoints', models.IntegerField(default=0, help_text='Total endpoints')),
                ('total_websites', models.IntegerField(default=0, help_text='Total websites')),
                ('total_vulns', models.IntegerField(default=0, help_text='Total vulnerabilities')),
                ('total_assets', models.IntegerField(default=0, help_text='Total assets')),
                ('created_at', models.DateTimeField(auto_now_add=True, help_text='Created at')),
                ('updated_at', models.DateTimeField(auto_now=True, help_text='Updated at')),
            ],
            options={
                'verbose_name': 'Statistics history',
                'verbose_name_plural': 'Statistics history',
                'db_table': 'statistics_history',
                'ordering': ['-date'],
                'indexes': [models.Index(fields=['date'], name='statistics__date_1d29cd_idx')],
            },
        ),
        migrations.CreateModel(
            name='Directory',
            fields=[
                ('id', models.AutoField(primary_key=True, serialize=False)),
                ('url', models.CharField(help_text='Full request URL', max_length=2000)),
                ('status', models.IntegerField(blank=True, help_text='HTTP response status code', null=True)),
                ('content_length', models.BigIntegerField(blank=True, help_text='Response body size in bytes (Content-Length or actual length)', null=True)),
                ('words', models.IntegerField(blank=True, help_text='Word count of the response body (split on spaces)', null=True)),
                ('lines', models.IntegerField(blank=True, help_text='Line count of the response body (split on newlines)', null=True)),
                ('content_type', models.CharField(blank=True, default='', help_text='Content-Type response header value', max_length=200)),
                ('duration', models.BigIntegerField(blank=True, help_text='Request duration (nanoseconds)', null=True)),
                ('created_at', models.DateTimeField(auto_now_add=True, help_text='Created at')),
                ('target', models.ForeignKey(help_text='Owning scan target', on_delete=django.db.models.deletion.CASCADE, related_name='directories', to='targets.target')),
            ],
            options={
                'verbose_name': 'Directory',
                'verbose_name_plural': 'Directories',
                'db_table': 'directory',
                'ordering': ['-created_at'],
                'indexes': [models.Index(fields=['-created_at'], name='directory_created_2cef03_idx'), models.Index(fields=['target'], name='directory_target__e310c8_idx'), models.Index(fields=['url'], name='directory_url_ba40cd_idx'), models.Index(fields=['status'], name='directory_status_40bbe6_idx'), django.contrib.postgres.indexes.GinIndex(fields=['url'], name='directory_url_trgm_idx', opclasses=['gin_trgm_ops'])],
                'constraints': [models.UniqueConstraint(fields=('target', 'url'), name='unique_directory_url_target')],
            },
        ),
        migrations.CreateModel(
            name='DirectorySnapshot',
            fields=[
                ('id', models.AutoField(primary_key=True, serialize=False)),
                ('url', models.CharField(help_text='Directory URL', max_length=2000)),
                ('status', models.IntegerField(blank=True, help_text='HTTP status code', null=True)),
                ('content_length', models.BigIntegerField(blank=True, help_text='Content length', null=True)),
                ('words', models.IntegerField(blank=True, help_text='Word count of the response body (split on spaces)', null=True)),
                ('lines', models.IntegerField(blank=True, help_text='Line count of the response body (split on newlines)', null=True)),
                ('content_type', models.CharField(blank=True, default='', help_text='Content-Type response header value', max_length=200)),
                ('duration', models.BigIntegerField(blank=True, help_text='Request duration (nanoseconds)', null=True)),
                ('created_at', models.DateTimeField(auto_now_add=True, help_text='Created at')),
                ('scan', models.ForeignKey(help_text='Owning scan task', on_delete=django.db.models.deletion.CASCADE, related_name='directory_snapshots', to='scan.scan')),
            ],
            options={
                'verbose_name': 'Directory snapshot',
                'verbose_name_plural': 'Directory snapshots',
                'db_table': 'directory_snapshot',
                'ordering': ['-created_at'],
                'indexes': [models.Index(fields=['scan'], name='directory_s_scan_id_c45900_idx'), models.Index(fields=['url'], name='directory_s_url_b4b72b_idx'), models.Index(fields=['status'], name='directory_s_status_e9f57e_idx'), models.Index(fields=['content_type'], name='directory_s_content_45e864_idx'), models.Index(fields=['-created_at'], name='directory_s_created_eb9d27_idx'), django.contrib.postgres.indexes.GinIndex(fields=['url'], name='dir_snap_url_trgm', opclasses=['gin_trgm_ops'])],
                'constraints': [models.UniqueConstraint(fields=('scan', 'url'), name='unique_directory_per_scan_snapshot')],
            },
        ),
        migrations.CreateModel(
            name='Endpoint',
            fields=[
                ('id', models.AutoField(primary_key=True, serialize=False)),
                ('url', models.TextField(help_text='Final full URL visited')),
                ('host', models.CharField(blank=True, default='', help_text='Hostname (domain or IP address)', max_length=253)),
                ('location', models.TextField(blank=True, default='', help_text='Redirect target (HTTP 3xx Location response header)')),
                ('created_at', models.DateTimeField(auto_now_add=True, help_text='Created at')),
                ('title', models.TextField(blank=True, default='', help_text='Page title (contents of the HTML <title> tag)')),
                ('webserver', models.TextField(blank=True, default='', help_text='Server type (HTTP Server response header value)')),
                ('response_body', models.TextField(blank=True, default='', help_text='HTTP response body')),
                ('content_type', models.TextField(blank=True, default='', help_text='Response type (HTTP Content-Type response header)')),
                ('tech', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='Tech stack (server/framework/language, etc.)', size=None)),
                ('status_code', models.IntegerField(blank=True, help_text='HTTP status code', null=True)),
                ('content_length', models.IntegerField(blank=True, help_text='Response body size (bytes)', null=True)),
                ('vhost', models.BooleanField(blank=True, help_text='Whether virtual hosting is supported', null=True)),
                ('matched_gf_patterns', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='List of matched GF patterns, used to flag sensitive endpoints (e.g. api, debug, config)', size=None)),
                ('response_headers', models.TextField(blank=True, default='', help_text='Raw HTTP response headers')),
                ('target', models.ForeignKey(help_text='Owning scan target (primary association field; expresses ownership, cannot be null)', on_delete=django.db.models.deletion.CASCADE, related_name='endpoints', to='targets.target')),
            ],
            options={
                'verbose_name': 'Endpoint',
                'verbose_name_plural': 'Endpoints',
                'db_table': 'endpoint',
                'ordering': ['-created_at'],
                'indexes': [models.Index(fields=['-created_at'], name='endpoint_created_44fe9c_idx'), models.Index(fields=['target'], name='endpoint_target__7f9065_idx'), models.Index(fields=['url'], name='endpoint_url_30f66e_idx'), models.Index(fields=['host'], name='endpoint_host_5b4cc8_idx'), models.Index(fields=['status_code'], name='endpoint_status__5d4fdd_idx'), models.Index(fields=['title'], name='endpoint_title_29e26c_idx'), django.contrib.postgres.indexes.GinIndex(fields=['tech'], name='endpoint_tech_2bfa7c_gin'), django.contrib.postgres.indexes.GinIndex(fields=['response_headers'], name='endpoint_resp_headers_trgm_idx', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['url'], name='endpoint_url_trgm_idx', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['title'], name='endpoint_title_trgm_idx', opclasses=['gin_trgm_ops'])],
                'constraints': [models.UniqueConstraint(fields=('url', 'target'), name='unique_endpoint_url_target')],
            },
        ),
        migrations.CreateModel(
            name='EndpointSnapshot',
            fields=[
                ('id', models.AutoField(primary_key=True, serialize=False)),
                ('url', models.TextField(help_text='Endpoint URL')),
                ('host', models.CharField(blank=True, default='', help_text='Hostname (domain or IP address)', max_length=253)),
                ('title', models.TextField(blank=True, default='', help_text='Page title')),
                ('status_code', models.IntegerField(blank=True, help_text='HTTP status code', null=True)),
                ('content_length', models.IntegerField(blank=True, help_text='Content length', null=True)),
                ('location', models.TextField(blank=True, default='', help_text='Redirect location')),
                ('webserver', models.TextField(blank=True, default='', help_text='Web server')),
                ('content_type', models.TextField(blank=True, default='', help_text='Content type')),
                ('tech', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='Tech stack', size=None)),
                ('response_body', models.TextField(blank=True, default='', help_text='HTTP response body')),
                ('vhost', models.BooleanField(blank=True, help_text='Virtual host flag', null=True)),
                ('matched_gf_patterns', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='List of matched GF patterns', size=None)),
                ('response_headers', models.TextField(blank=True, default='', help_text='Raw HTTP response headers')),
                ('created_at', models.DateTimeField(auto_now_add=True, help_text='Created at')),
                ('scan', models.ForeignKey(help_text='Owning scan task', on_delete=django.db.models.deletion.CASCADE, related_name='endpoint_snapshots', to='scan.scan')),
            ],
            options={
                'verbose_name': 'Endpoint snapshot',
                'verbose_name_plural': 'Endpoint snapshots',
                'db_table': 'endpoint_snapshot',
                'ordering': ['-created_at'],
                'indexes': [models.Index(fields=['scan'], name='endpoint_sn_scan_id_6ac9a7_idx'), models.Index(fields=['url'], name='endpoint_sn_url_205160_idx'), models.Index(fields=['host'], name='endpoint_sn_host_577bfd_idx'), models.Index(fields=['title'], name='endpoint_sn_title_516a05_idx'), models.Index(fields=['status_code'], name='endpoint_sn_status__83efb0_idx'), models.Index(fields=['webserver'], name='endpoint_sn_webserv_66be83_idx'), models.Index(fields=['-created_at'], name='endpoint_sn_created_21fb5b_idx'), django.contrib.postgres.indexes.GinIndex(fields=['tech'], name='endpoint_sn_tech_0d0752_gin'), django.contrib.postgres.indexes.GinIndex(fields=['response_headers'], name='ep_snap_resp_hdr_trgm', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['url'], name='ep_snap_url_trgm', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['title'], name='ep_snap_title_trgm', opclasses=['gin_trgm_ops'])],
                'constraints': [models.UniqueConstraint(fields=('scan', 'url'), name='unique_endpoint_per_scan_snapshot')],
            },
        ),
        migrations.CreateModel(
            name='HostPortMapping',
            fields=[
                ('id', models.AutoField(primary_key=True, serialize=False)),
                ('host', models.CharField(help_text='Hostname (domain or IP)', max_length=1000)),
                ('ip', models.GenericIPAddressField(help_text='IP address')),
                ('port', models.IntegerField(help_text='Port number (1-65535)', validators=[django.core.validators.MinValueValidator(1, message='Port number must be at least 1'), django.core.validators.MaxValueValidator(65535, message='Port number must be at most 65535')])),
                ('created_at', models.DateTimeField(auto_now_add=True, help_text='Created at')),
                ('target', models.ForeignKey(help_text='Owning scan target', on_delete=django.db.models.deletion.CASCADE, related_name='host_port_mappings', to='targets.target')),
            ],
            options={
                'verbose_name': 'Host-port mapping',
                'verbose_name_plural': 'Host-port mappings',
                'db_table': 'host_port_mapping',
                'ordering': ['-created_at'],
                'indexes': [models.Index(fields=['target'], name='host_port_m_target__943e9b_idx'), models.Index(fields=['host'], name='host_port_m_host_f78363_idx'), models.Index(fields=['ip'], name='host_port_m_ip_2e6f02_idx'), models.Index(fields=['port'], name='host_port_m_port_9fb9ff_idx'), models.Index(fields=['host', 'ip'], name='host_port_m_host_3ce245_idx'), models.Index(fields=['-created_at'], name='host_port_m_created_11cd22_idx')],
                'constraints': [models.UniqueConstraint(fields=('target', 'host', 'ip', 'port'), name='unique_target_host_ip_port')],
            },
        ),
        migrations.CreateModel(
            name='HostPortMappingSnapshot',
            fields=[
                ('id', models.AutoField(primary_key=True, serialize=False)),
                ('host', models.CharField(help_text='Hostname (domain or IP)', max_length=1000)),
                ('ip', models.GenericIPAddressField(help_text='IP address')),
                ('port', models.IntegerField(help_text='Port number (1-65535)', validators=[django.core.validators.MinValueValidator(1, message='Port number must be at least 1'), django.core.validators.MaxValueValidator(65535, message='Port number must be at most 65535')])),
                ('created_at', models.DateTimeField(auto_now_add=True, help_text='Created at')),
                ('scan', models.ForeignKey(help_text='Owning scan task (primary association)', on_delete=django.db.models.deletion.CASCADE, related_name='host_port_mapping_snapshots', to='scan.scan')),
            ],
            options={
                'verbose_name': 'Host-port mapping snapshot',
                'verbose_name_plural': 'Host-port mapping snapshots',
                'db_table': 'host_port_mapping_snapshot',
                'ordering': ['-created_at'],
                'indexes': [models.Index(fields=['scan'], name='host_port_m_scan_id_50ba0b_idx'), models.Index(fields=['host'], name='host_port_m_host_e99054_idx'), models.Index(fields=['ip'], name='host_port_m_ip_54818c_idx'), models.Index(fields=['port'], name='host_port_m_port_ed7b48_idx'), models.Index(fields=['host', 'ip'], name='host_port_m_host_8a463a_idx'), models.Index(fields=['scan', 'host'], name='host_port_m_scan_id_426fdb_idx'), models.Index(fields=['-created_at'], name='host_port_m_created_fb28b8_idx')],
                'constraints': [models.UniqueConstraint(fields=('scan', 'host', 'ip', 'port'), name='unique_scan_host_ip_port_snapshot')],
            },
        ),
        migrations.CreateModel(
            name='Subdomain',
            fields=[
                ('id', models.AutoField(primary_key=True, serialize=False)),
                ('name', models.CharField(help_text='Subdomain name', max_length=1000)),
                ('created_at', models.DateTimeField(auto_now_add=True, help_text='Created at')),
                ('target', models.ForeignKey(help_text='Owning scan target (primary association field; expresses ownership, cannot be null)', on_delete=django.db.models.deletion.CASCADE, related_name='subdomains', to='targets.target')),
            ],
            options={
                'verbose_name': 'Subdomain',
                'verbose_name_plural': 'Subdomains',
                'db_table': 'subdomain',
                'ordering': ['-created_at'],
                'indexes': [models.Index(fields=['-created_at'], name='subdomain_created_e187a8_idx'), models.Index(fields=['name', 'target'], name='subdomain_name_60e1d0_idx'), models.Index(fields=['target'], name='subdomain_target__e409f0_idx'), models.Index(fields=['name'], name='subdomain_name_d40ba7_idx'), django.contrib.postgres.indexes.GinIndex(fields=['name'], name='subdomain_name_trgm_idx', opclasses=['gin_trgm_ops'])],
                'constraints': [models.UniqueConstraint(fields=('name', 'target'), name='unique_subdomain_name_target')],
            },
        ),
        migrations.CreateModel(
            name='SubdomainSnapshot',
            fields=[
                ('id', models.AutoField(primary_key=True, serialize=False)),
                ('name', models.CharField(help_text='Subdomain name', max_length=1000)),
                ('created_at', models.DateTimeField(auto_now_add=True, help_text='Created at')),
                ('scan', models.ForeignKey(help_text='Owning scan task', on_delete=django.db.models.deletion.CASCADE, related_name='subdomain_snapshots', to='scan.scan')),
            ],
            options={
                'verbose_name': 'Subdomain snapshot',
                'verbose_name_plural': 'Subdomain snapshots',
                'db_table': 'subdomain_snapshot',
                'ordering': ['-created_at'],
                'indexes': [models.Index(fields=['scan'], name='subdomain_s_scan_id_68c253_idx'), models.Index(fields=['name'], name='subdomain_s_name_2da42b_idx'), models.Index(fields=['-created_at'], name='subdomain_s_created_d2b48e_idx'), django.contrib.postgres.indexes.GinIndex(fields=['name'], name='subdomain_snap_name_trgm', opclasses=['gin_trgm_ops'])],
                'constraints': [models.UniqueConstraint(fields=('scan', 'name'), name='unique_subdomain_per_scan_snapshot')],
            },
        ),
        migrations.CreateModel(
            name='Vulnerability',
            fields=[
                ('id', models.AutoField(primary_key=True, serialize=False)),
                ('url', models.CharField(help_text='URL where the vulnerability was found', max_length=2000)),
                ('vuln_type', models.CharField(help_text='Vulnerability type (e.g. xss, sqli)', max_length=100)),
                ('severity', models.CharField(choices=[('unknown', 'Unknown'), ('info', 'Info'), ('low', 'Low'), ('medium', 'Medium'), ('high', 'High'), ('critical', 'Critical')], default='unknown', help_text='Severity (unknown/info/low/medium/high/critical)', max_length=20)),
                ('source', models.CharField(blank=True, default='', help_text='Source tool (e.g. dalfox, nuclei, crlfuzz)', max_length=50)),
                ('cvss_score', models.DecimalField(blank=True, decimal_places=1, help_text='CVSS score (0.0-10.0)', max_digits=3, null=True)),
                ('description', models.TextField(blank=True, default='', help_text='Vulnerability description')),
                ('raw_output', models.JSONField(blank=True, default=dict, help_text='Raw tool output')),
                ('created_at', models.DateTimeField(auto_now_add=True, help_text='Created at')),
                ('target', models.ForeignKey(help_text='Owning scan target', on_delete=django.db.models.deletion.CASCADE, related_name='vulnerabilities', to='targets.target')),
            ],
            options={
                'verbose_name': 'Vulnerability',
                'verbose_name_plural': 'Vulnerabilities',
                'db_table': 'vulnerability',
                'ordering': ['-created_at'],
                'indexes': [models.Index(fields=['target'], name='vulnerabili_target__755a02_idx'), models.Index(fields=['vuln_type'], name='vulnerabili_vuln_ty_3010cd_idx'), models.Index(fields=['severity'], name='vulnerabili_severit_1a798b_idx'), models.Index(fields=['source'], name='vulnerabili_source_7c7552_idx'), models.Index(fields=['url'], name='vulnerabili_url_4dcc4d_idx'), models.Index(fields=['-created_at'], name='vulnerabili_created_e25ff7_idx')],
            },
        ),
        migrations.CreateModel(
            name='VulnerabilitySnapshot',
            fields=[
                ('id', models.AutoField(primary_key=True, serialize=False)),
                ('url', models.CharField(help_text='URL where the vulnerability was found', max_length=2000)),
                ('vuln_type', models.CharField(help_text='Vulnerability type (e.g. xss, sqli)', max_length=100)),
                ('severity', models.CharField(choices=[('unknown', 'Unknown'), ('info', 'Info'), ('low', 'Low'), ('medium', 'Medium'), ('high', 'High'), ('critical', 'Critical')], default='unknown', help_text='Severity (unknown/info/low/medium/high/critical)', max_length=20)),
                ('source', models.CharField(blank=True, default='', help_text='Source tool (e.g. dalfox, nuclei, crlfuzz)', max_length=50)),
                ('cvss_score', models.DecimalField(blank=True, decimal_places=1, help_text='CVSS score (0.0-10.0)', max_digits=3, null=True)),
                ('description', models.TextField(blank=True, default='', help_text='Vulnerability description')),
                ('raw_output', models.JSONField(blank=True, default=dict, help_text='Raw tool output')),
                ('created_at', models.DateTimeField(auto_now_add=True, help_text='Created at')),
                ('scan', models.ForeignKey(help_text='Owning scan task', on_delete=django.db.models.deletion.CASCADE, related_name='vulnerability_snapshots', to='scan.scan')),
            ],
            options={
                'verbose_name': 'Vulnerability snapshot',
                'verbose_name_plural': 'Vulnerability snapshots',
                'db_table': 'vulnerability_snapshot',
                'ordering': ['-created_at'],
                'indexes': [models.Index(fields=['scan'], name='vulnerabili_scan_id_7b81c9_idx'), models.Index(fields=['url'], name='vulnerabili_url_11a707_idx'), models.Index(fields=['vuln_type'], name='vulnerabili_vuln_ty_6b90ee_idx'), models.Index(fields=['severity'], name='vulnerabili_severit_4eae0d_idx'), models.Index(fields=['source'], name='vulnerabili_source_968b1f_idx'), models.Index(fields=['-created_at'], name='vulnerabili_created_53a12e_idx')],
            },
        ),
        migrations.CreateModel(
            name='WebSite',
            fields=[
                ('id', models.AutoField(primary_key=True, serialize=False)),
                ('url', models.TextField(help_text='Final full URL visited')),
                ('host', models.CharField(blank=True, default='', help_text='Hostname (domain or IP address)', max_length=253)),
                ('location', models.TextField(blank=True, default='', help_text='Redirect target (HTTP 3xx Location response header)')),
                ('created_at', models.DateTimeField(auto_now_add=True, help_text='Created at')),
                ('title', models.TextField(blank=True, default='', help_text='Page title (contents of the HTML <title> tag)')),
                ('webserver', models.TextField(blank=True, default='', help_text='Server type (HTTP Server response header value)')),
                ('response_body', models.TextField(blank=True, default='', help_text='HTTP response body')),
                ('content_type', models.TextField(blank=True, default='', help_text='Response type (HTTP Content-Type response header)')),
                ('tech', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='Tech stack (server/framework/language, etc.)', size=None)),
                ('status_code', models.IntegerField(blank=True, help_text='HTTP status code', null=True)),
                ('content_length', models.IntegerField(blank=True, help_text='Response body size (bytes)', null=True)),
                ('vhost', models.BooleanField(blank=True, help_text='Whether virtual hosting is supported', null=True)),
                ('response_headers', models.TextField(blank=True, default='', help_text='Raw HTTP response headers')),
                ('target', models.ForeignKey(help_text='Owning scan target (primary association field; expresses ownership, cannot be null)', on_delete=django.db.models.deletion.CASCADE, related_name='websites', to='targets.target')),
            ],
            options={
                'verbose_name': 'Website',
                'verbose_name_plural': 'Websites',
                'db_table': 'website',
                'ordering': ['-created_at'],
                'indexes': [models.Index(fields=['-created_at'], name='website_created_c9cfd2_idx'), models.Index(fields=['url'], name='website_url_b18883_idx'), models.Index(fields=['host'], name='website_host_996b50_idx'), models.Index(fields=['target'], name='website_target__2a353b_idx'), models.Index(fields=['title'], name='website_title_c2775b_idx'), models.Index(fields=['status_code'], name='website_status__51663d_idx'), django.contrib.postgres.indexes.GinIndex(fields=['tech'], name='website_tech_e3f0cb_gin'), django.contrib.postgres.indexes.GinIndex(fields=['response_headers'], name='website_resp_headers_trgm_idx', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['url'], name='website_url_trgm_idx', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['title'], name='website_title_trgm_idx', opclasses=['gin_trgm_ops'])],
                'constraints': [models.UniqueConstraint(fields=('url', 'target'), name='unique_website_url_target')],
            },
        ),
        migrations.CreateModel(
            name='WebsiteSnapshot',
            fields=[
                ('id', models.AutoField(primary_key=True, serialize=False)),
                ('url', models.TextField(help_text='Website URL')),
                ('host', models.CharField(blank=True, default='', help_text='Hostname (domain or IP address)', max_length=253)),
                ('title', models.TextField(blank=True, default='', help_text='Page title')),
                ('status_code', models.IntegerField(blank=True, help_text='HTTP status code', null=True)),
                ('content_length', models.BigIntegerField(blank=True, help_text='Content length', null=True)),
                ('location', models.TextField(blank=True, default='', help_text='Redirect location')),
                ('webserver', models.TextField(blank=True, default='', help_text='Web server')),
                ('content_type', models.TextField(blank=True, default='', help_text='Content type')),
                ('tech', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='Tech stack', size=None)),
                ('response_body', models.TextField(blank=True, default='', help_text='HTTP response body')),
                ('vhost', models.BooleanField(blank=True, help_text='Virtual host flag', null=True)),
                ('response_headers', models.TextField(blank=True, default='', help_text='Raw HTTP response headers')),
                ('created_at', models.DateTimeField(auto_now_add=True, help_text='Created at')),
                ('scan', models.ForeignKey(help_text='Owning scan task', on_delete=django.db.models.deletion.CASCADE, related_name='website_snapshots', to='scan.scan')),
            ],
            options={
                'verbose_name': 'Website snapshot',
                'verbose_name_plural': 'Website snapshots',
                'db_table': 'website_snapshot',
                'ordering': ['-created_at'],
                'indexes': [models.Index(fields=['scan'], name='website_sna_scan_id_26b6dc_idx'), models.Index(fields=['url'], name='website_sna_url_801a70_idx'), models.Index(fields=['host'], name='website_sna_host_348fe1_idx'), models.Index(fields=['title'], name='website_sna_title_b1a5ee_idx'), models.Index(fields=['-created_at'], name='website_sna_created_2c149a_idx'), django.contrib.postgres.indexes.GinIndex(fields=['tech'], name='website_sna_tech_3d6d2f_gin'), django.contrib.postgres.indexes.GinIndex(fields=['response_headers'], name='ws_snap_resp_hdr_trgm', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['url'], name='ws_snap_url_trgm', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['title'], name='ws_snap_title_trgm', opclasses=['gin_trgm_ops'])],
                'constraints': [models.UniqueConstraint(fields=('scan', 'url'), name='unique_website_per_scan_snapshot')],
            },
        ),
    ]
@@ -0,0 +1,187 @@
"""
Create the asset-search IMMVs (incrementally maintained materialized views).

Uses the pg_ivm extension to create IMMVs that update incrementally and
automatically whenever the underlying data changes; no manual refresh needed.

Contains:
1. asset_search_view - the Website search view
2. endpoint_search_view - the Endpoint search view
"""

from django.db import migrations


class Migration(migrations.Migration):

    dependencies = [
        ('asset', '0001_initial'),
    ]

    operations = [
        # 1. Make sure the pg_ivm extension is enabled
        migrations.RunSQL(
            sql="CREATE EXTENSION IF NOT EXISTS pg_ivm;",
            reverse_sql="-- pg_ivm extension kept for other uses"
        ),

        # ==================== Website IMMV ====================

        # 2. Create the asset_search_view IMMV
        migrations.RunSQL(
            sql="""
                SELECT pgivm.create_immv('asset_search_view', $$
                    SELECT
                        w.id,
                        w.url,
                        w.host,
                        w.title,
                        w.tech,
                        w.status_code,
                        w.response_headers,
                        w.response_body,
                        w.content_type,
                        w.content_length,
                        w.webserver,
                        w.location,
                        w.vhost,
                        w.created_at,
                        w.target_id
                    FROM website w
                $$);
            """,
            reverse_sql="SELECT pgivm.drop_immv('asset_search_view');"
        ),

        # 3. Create the asset_search_view indexes
        migrations.RunSQL(
            sql="""
                -- unique index
                CREATE UNIQUE INDEX IF NOT EXISTS asset_search_view_id_idx
                ON asset_search_view (id);

                -- host fuzzy-search index
                CREATE INDEX IF NOT EXISTS asset_search_view_host_trgm_idx
                ON asset_search_view USING gin (host gin_trgm_ops);

                -- title fuzzy-search index
                CREATE INDEX IF NOT EXISTS asset_search_view_title_trgm_idx
                ON asset_search_view USING gin (title gin_trgm_ops);

                -- url fuzzy-search index
                CREATE INDEX IF NOT EXISTS asset_search_view_url_trgm_idx
                ON asset_search_view USING gin (url gin_trgm_ops);

                -- response_headers fuzzy-search index
                CREATE INDEX IF NOT EXISTS asset_search_view_headers_trgm_idx
                ON asset_search_view USING gin (response_headers gin_trgm_ops);

                -- response_body fuzzy-search index
                CREATE INDEX IF NOT EXISTS asset_search_view_body_trgm_idx
                ON asset_search_view USING gin (response_body gin_trgm_ops);

                -- tech array index
                CREATE INDEX IF NOT EXISTS asset_search_view_tech_idx
                ON asset_search_view USING gin (tech);

                -- status_code index
                CREATE INDEX IF NOT EXISTS asset_search_view_status_idx
                ON asset_search_view (status_code);

                -- created_at ordering index
                CREATE INDEX IF NOT EXISTS asset_search_view_created_idx
                ON asset_search_view (created_at DESC);
            """,
            reverse_sql="""
                DROP INDEX IF EXISTS asset_search_view_id_idx;
                DROP INDEX IF EXISTS asset_search_view_host_trgm_idx;
                DROP INDEX IF EXISTS asset_search_view_title_trgm_idx;
                DROP INDEX IF EXISTS asset_search_view_url_trgm_idx;
                DROP INDEX IF EXISTS asset_search_view_headers_trgm_idx;
                DROP INDEX IF EXISTS asset_search_view_body_trgm_idx;
                DROP INDEX IF EXISTS asset_search_view_tech_idx;
                DROP INDEX IF EXISTS asset_search_view_status_idx;
                DROP INDEX IF EXISTS asset_search_view_created_idx;
            """
        ),

        # ==================== Endpoint IMMV ====================

        # 4. Create the endpoint_search_view IMMV
        migrations.RunSQL(
            sql="""
                SELECT pgivm.create_immv('endpoint_search_view', $$
                    SELECT
                        e.id,
                        e.url,
                        e.host,
                        e.title,
                        e.tech,
                        e.status_code,
                        e.response_headers,
                        e.response_body,
                        e.content_type,
                        e.content_length,
                        e.webserver,
                        e.location,
                        e.vhost,
                        e.matched_gf_patterns,
                        e.created_at,
                        e.target_id
                    FROM endpoint e
                $$);
            """,
            reverse_sql="SELECT pgivm.drop_immv('endpoint_search_view');"
        ),

        # 5. Create the endpoint_search_view indexes
        migrations.RunSQL(
            sql="""
                -- unique index
                CREATE UNIQUE INDEX IF NOT EXISTS endpoint_search_view_id_idx
                ON endpoint_search_view (id);

                -- host fuzzy-search index
                CREATE INDEX IF NOT EXISTS endpoint_search_view_host_trgm_idx
                ON endpoint_search_view USING gin (host gin_trgm_ops);

                -- title fuzzy-search index
                CREATE INDEX IF NOT EXISTS endpoint_search_view_title_trgm_idx
                ON endpoint_search_view USING gin (title gin_trgm_ops);

                -- url fuzzy-search index
                CREATE INDEX IF NOT EXISTS endpoint_search_view_url_trgm_idx
                ON endpoint_search_view USING gin (url gin_trgm_ops);

                -- response_headers fuzzy-search index
                CREATE INDEX IF NOT EXISTS endpoint_search_view_headers_trgm_idx
                ON endpoint_search_view USING gin (response_headers gin_trgm_ops);

                -- response_body fuzzy-search index
                CREATE INDEX IF NOT EXISTS endpoint_search_view_body_trgm_idx
                ON endpoint_search_view USING gin (response_body gin_trgm_ops);

                -- tech array index
                CREATE INDEX IF NOT EXISTS endpoint_search_view_tech_idx
                ON endpoint_search_view USING gin (tech);

                -- status_code index
                CREATE INDEX IF NOT EXISTS endpoint_search_view_status_idx
                ON endpoint_search_view (status_code);

                -- created_at ordering index
                CREATE INDEX IF NOT EXISTS endpoint_search_view_created_idx
                ON endpoint_search_view (created_at DESC);
            """,
            reverse_sql="""
                DROP INDEX IF EXISTS endpoint_search_view_id_idx;
                DROP INDEX IF EXISTS endpoint_search_view_host_trgm_idx;
                DROP INDEX IF EXISTS endpoint_search_view_title_trgm_idx;
                DROP INDEX IF EXISTS endpoint_search_view_url_trgm_idx;
                DROP INDEX IF EXISTS endpoint_search_view_headers_trgm_idx;
                DROP INDEX IF EXISTS endpoint_search_view_body_trgm_idx;
                DROP INDEX IF EXISTS endpoint_search_view_tech_idx;
                DROP INDEX IF EXISTS endpoint_search_view_status_idx;
                DROP INDEX IF EXISTS endpoint_search_view_created_idx;
            """
        ),
    ]
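Because these are IMMVs rather than ordinary materialized views, writes to the base tables show up in the views immediately. A minimal sketch of how one might verify that from Django, assuming a PostgreSQL database with pg_ivm installed and an existing target row with id 1:

```python
# Count rows in the IMMV before and after inserting a website; with pg_ivm
# the view is maintained incrementally, so no REFRESH MATERIALIZED VIEW runs.
from django.db import connection
from apps.asset.models import WebSite

def view_rows() -> int:
    with connection.cursor() as cur:
        cur.execute("SELECT COUNT(*) FROM asset_search_view;")
        return cur.fetchone()[0]

before = view_rows()
WebSite.objects.create(target_id=1, url="https://example.com", host="example.com")
after = view_rows()
print(before, after)  # expect after == before + 1, with no manual refresh step
```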
@@ -1,6 +1,7 @@
|
||||
|
||||
from django.db import models
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
from django.contrib.postgres.indexes import GinIndex
|
||||
from django.core.validators import MinValueValidator, MaxValueValidator
|
||||
|
||||
|
||||
@@ -22,18 +23,24 @@ class Subdomain(models.Model):
|
||||
help_text='所属的扫描目标(主关联字段,表示所属关系,不能为空)'
|
||||
)
|
||||
name = models.CharField(max_length=1000, help_text='子域名名称')
|
||||
discovered_at = models.DateTimeField(auto_now_add=True, help_text='首次发现时间')
|
||||
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
|
||||
|
||||
class Meta:
|
||||
db_table = 'subdomain'
|
||||
verbose_name = '子域名'
|
||||
verbose_name_plural = '子域名'
|
||||
ordering = ['-discovered_at']
|
||||
ordering = ['-created_at']
|
||||
indexes = [
|
||||
models.Index(fields=['-discovered_at']),
|
||||
models.Index(fields=['-created_at']),
|
||||
models.Index(fields=['name', 'target']), # 复合索引,优化 get_by_names_and_target_id 批量查询
|
||||
models.Index(fields=['target']), # 优化从target_id快速查找下面的子域名
|
||||
models.Index(fields=['name']), # 优化从name快速查找子域名,搜索场景
|
||||
# pg_trgm GIN 索引,支持 LIKE '%keyword%' 模糊搜索
|
||||
GinIndex(
|
||||
name='subdomain_name_trgm_idx',
|
||||
fields=['name'],
|
||||
opclasses=['gin_trgm_ops']
|
||||
),
|
||||
]
|
||||
constraints = [
|
||||
# 普通唯一约束:name + target 组合唯一
|
||||
@@ -58,40 +65,35 @@ class Endpoint(models.Model):
|
||||
help_text='所属的扫描目标(主关联字段,表示所属关系,不能为空)'
|
||||
)
|
||||
|
||||
url = models.CharField(max_length=2000, help_text='最终访问的完整URL')
|
||||
url = models.TextField(help_text='最终访问的完整URL')
|
||||
host = models.CharField(
|
||||
max_length=253,
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='主机名(域名或IP地址)'
|
||||
)
|
||||
location = models.CharField(
|
||||
max_length=1000,
|
||||
location = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='重定向地址(HTTP 3xx 响应头 Location)'
|
||||
)
|
||||
discovered_at = models.DateTimeField(auto_now_add=True, help_text='发现时间')
|
||||
title = models.CharField(
|
||||
max_length=1000,
|
||||
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
|
||||
title = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='网页标题(HTML <title> 标签内容)'
|
||||
)
|
||||
webserver = models.CharField(
|
||||
max_length=200,
|
||||
webserver = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='服务器类型(HTTP 响应头 Server 值)'
|
||||
)
|
||||
body_preview = models.CharField(
|
||||
max_length=1000,
|
||||
response_body = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='响应正文前N个字符(默认100个字符)'
|
||||
help_text='HTTP响应体'
|
||||
)
|
||||
content_type = models.CharField(
|
||||
max_length=200,
|
||||
content_type = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='响应类型(HTTP Content-Type 响应头)'
|
||||
@@ -123,18 +125,41 @@ class Endpoint(models.Model):
|
||||
default=list,
|
||||
help_text='匹配的GF模式列表,用于识别敏感端点(如api, debug, config等)'
|
||||
)
|
||||
response_headers = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='原始HTTP响应头'
|
||||
)
|
||||
|
||||
class Meta:
|
||||
db_table = 'endpoint'
|
||||
verbose_name = '端点'
|
||||
verbose_name_plural = '端点'
|
||||
ordering = ['-discovered_at']
|
||||
ordering = ['-created_at']
|
||||
indexes = [
|
||||
models.Index(fields=['-discovered_at']),
|
||||
models.Index(fields=['target']), # 优化从target_id快速查找下面的端点(主关联字段)
|
||||
models.Index(fields=['-created_at']),
|
||||
models.Index(fields=['target']), # 优化从 target_id快速查找下面的端点(主关联字段)
|
||||
models.Index(fields=['url']), # URL索引,优化查询性能
|
||||
models.Index(fields=['host']), # host索引,优化根据主机名查询
|
||||
models.Index(fields=['status_code']), # 状态码索引,优化筛选
|
||||
models.Index(fields=['title']), # title索引,优化智能过滤搜索
|
||||
GinIndex(fields=['tech']), # GIN索引,优化 tech 数组字段的 __contains 查询
|
||||
# pg_trgm GIN 索引,支持 LIKE '%keyword%' 模糊搜索
|
||||
GinIndex(
|
||||
name='endpoint_resp_headers_trgm_idx',
|
||||
fields=['response_headers'],
|
||||
opclasses=['gin_trgm_ops']
|
||||
),
|
||||
GinIndex(
|
||||
name='endpoint_url_trgm_idx',
|
||||
fields=['url'],
|
||||
opclasses=['gin_trgm_ops']
|
||||
),
|
||||
GinIndex(
|
||||
name='endpoint_title_trgm_idx',
|
||||
fields=['title'],
|
||||
opclasses=['gin_trgm_ops']
|
||||
),
|
||||
]
|
||||
constraints = [
|
||||
# 普通唯一约束:url + target 组合唯一
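
The plain GinIndex(fields=['tech']) entry in the Endpoint indexes above backs array-membership lookups. A sketch of the query shape it serves, using the Endpoint model from this diff (the value is illustrative):

from apps.asset.models import Endpoint

# __contains on an ArrayField compiles to PostgreSQL's @> operator,
# which a GIN index over the array column can answer directly.
nginx_endpoints = Endpoint.objects.filter(tech__contains=['nginx'])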
@@ -159,40 +184,35 @@ class WebSite(models.Model):
        help_text='Owning scan target (primary association field; expresses ownership and must not be null)'
    )

    url = models.CharField(max_length=2000, help_text='Final full URL that was visited')
    url = models.TextField(help_text='Final full URL that was visited')
    host = models.CharField(
        max_length=253,
        blank=True,
        default='',
        help_text='Hostname (domain or IP address)'
    )
    location = models.CharField(
        max_length=1000,
    location = models.TextField(
        blank=True,
        default='',
        help_text='Redirect target (HTTP 3xx Location response header)'
    )
    discovered_at = models.DateTimeField(auto_now_add=True, help_text='Discovery time')
    title = models.CharField(
        max_length=1000,
    created_at = models.DateTimeField(auto_now_add=True, help_text='Creation time')
    title = models.TextField(
        blank=True,
        default='',
        help_text='Page title (content of the HTML <title> tag)'
    )
    webserver = models.CharField(
        max_length=200,
    webserver = models.TextField(
        blank=True,
        default='',
        help_text='Server type (value of the HTTP Server response header)'
    )
    body_preview = models.CharField(
        max_length=1000,
    response_body = models.TextField(
        blank=True,
        default='',
        help_text='First N characters of the response body (100 by default)'
        help_text='HTTP response body'
    )
    content_type = models.CharField(
        max_length=200,
    content_type = models.TextField(
        blank=True,
        default='',
        help_text='Response type (HTTP Content-Type response header)'
@@ -218,17 +238,41 @@ class WebSite(models.Model):
        blank=True,
        help_text='Whether virtual hosts are supported'
    )
    response_headers = models.TextField(
        blank=True,
        default='',
        help_text='Raw HTTP response headers'
    )

    class Meta:
        db_table = 'website'
        verbose_name = 'Website'
        verbose_name_plural = 'Websites'
        ordering = ['-discovered_at']
        ordering = ['-created_at']
        indexes = [
            models.Index(fields=['-discovered_at']),
            models.Index(fields=['-created_at']),
            models.Index(fields=['url']),  # URL index; improves query performance
            models.Index(fields=['host']),  # host index; speeds up lookups by hostname
            models.Index(fields=['target']),  # Speeds up finding the websites under a target_id
            models.Index(fields=['target']),  # Speeds up finding the websites under a target_id
            models.Index(fields=['title']),  # title index; speeds up smart-filter searches
            models.Index(fields=['status_code']),  # Status-code index; speeds up smart-filter searches
            GinIndex(fields=['tech']),  # GIN index; speeds up __contains queries on the tech array field
            # pg_trgm GIN indexes; support LIKE '%keyword%' fuzzy search
            GinIndex(
                name='website_resp_headers_trgm_idx',
                fields=['response_headers'],
                opclasses=['gin_trgm_ops']
            ),
            GinIndex(
                name='website_url_trgm_idx',
                fields=['url'],
                opclasses=['gin_trgm_ops']
            ),
            GinIndex(
                name='website_title_trgm_idx',
                fields=['title'],
                opclasses=['gin_trgm_ops']
            ),
        ]
        constraints = [
            # Plain unique constraint: the url + target combination is unique
@@ -293,18 +337,24 @@ class Directory(models.Model):
        help_text='Request duration (in nanoseconds)'
    )

    discovered_at = models.DateTimeField(auto_now_add=True, help_text='Discovery time')
    created_at = models.DateTimeField(auto_now_add=True, help_text='Creation time')

    class Meta:
        db_table = 'directory'
        verbose_name = 'Directory'
        verbose_name_plural = 'Directories'
        ordering = ['-discovered_at']
        ordering = ['-created_at']
        indexes = [
            models.Index(fields=['-discovered_at']),
            models.Index(fields=['-created_at']),
            models.Index(fields=['target']),  # Speeds up finding the directories under a target_id
            models.Index(fields=['url']),  # URL index; speeds up searches and backs the unique constraint
            models.Index(fields=['status']),  # Status-code index; speeds up filtering
            # pg_trgm GIN index; supports LIKE '%keyword%' fuzzy search
            GinIndex(
                name='directory_url_trgm_idx',
                fields=['url'],
                opclasses=['gin_trgm_ops']
            ),
        ]
        constraints = [
            # Plain unique constraint: the target + url combination is unique
@@ -358,23 +408,23 @@ class HostPortMapping(models.Model):
    )

    # ==================== Timestamp fields ====================
    discovered_at = models.DateTimeField(
    created_at = models.DateTimeField(
        auto_now_add=True,
        help_text='Discovery time'
        help_text='Creation time'
    )

    class Meta:
        db_table = 'host_port_mapping'
        verbose_name = 'Host-port mapping'
        verbose_name_plural = 'Host-port mappings'
        ordering = ['-discovered_at']
        ordering = ['-created_at']
        indexes = [
            models.Index(fields=['target']),  # Speeds up queries by target
            models.Index(fields=['host']),  # Speeds up queries by hostname
            models.Index(fields=['ip']),  # Speeds up queries by IP
            models.Index(fields=['port']),  # Speeds up queries by port
            models.Index(fields=['host', 'ip']),  # Speeds up combined queries
            models.Index(fields=['-discovered_at']),  # Speeds up time-based ordering
            models.Index(fields=['-created_at']),  # Speeds up time-based ordering
        ]
        constraints = [
            # Composite unique constraint: the target + host + ip + port combination is unique
@@ -408,7 +458,7 @@ class Vulnerability(models.Model):
    )

    # ==================== Core fields ====================
    url = models.TextField(help_text='URL where the vulnerability was found')
    url = models.CharField(max_length=2000, help_text='URL where the vulnerability was found')
    vuln_type = models.CharField(max_length=100, help_text='Vulnerability type (e.g. xss, sqli)')
    severity = models.CharField(
        max_length=20,
@@ -422,19 +472,20 @@ class Vulnerability(models.Model):
    raw_output = models.JSONField(blank=True, default=dict, help_text='Raw tool output')

    # ==================== Timestamp fields ====================
    discovered_at = models.DateTimeField(auto_now_add=True, help_text='Time of first discovery')
    created_at = models.DateTimeField(auto_now_add=True, help_text='Creation time')

    class Meta:
        db_table = 'vulnerability'
        verbose_name = 'Vulnerability'
        verbose_name_plural = 'Vulnerabilities'
        ordering = ['-discovered_at']
        ordering = ['-created_at']
        indexes = [
            models.Index(fields=['target']),
            models.Index(fields=['vuln_type']),
            models.Index(fields=['severity']),
            models.Index(fields=['source']),
            models.Index(fields=['-discovered_at']),
            models.Index(fields=['url']),  # url index; speeds up smart-filter searches
            models.Index(fields=['-created_at']),
        ]

    def __str__(self):

@@ -1,5 +1,6 @@
from django.db import models
from django.contrib.postgres.fields import ArrayField
from django.contrib.postgres.indexes import GinIndex
from django.core.validators import MinValueValidator, MaxValueValidator


@@ -15,17 +16,23 @@ class SubdomainSnapshot(models.Model):
    )

    name = models.CharField(max_length=1000, help_text='Subdomain name')
    discovered_at = models.DateTimeField(auto_now_add=True, help_text='Discovery time')
    created_at = models.DateTimeField(auto_now_add=True, help_text='Creation time')

    class Meta:
        db_table = 'subdomain_snapshot'
        verbose_name = 'Subdomain snapshot'
        verbose_name_plural = 'Subdomain snapshots'
        ordering = ['-discovered_at']
        ordering = ['-created_at']
        indexes = [
            models.Index(fields=['scan']),
            models.Index(fields=['name']),
            models.Index(fields=['-discovered_at']),
            models.Index(fields=['-created_at']),
            # pg_trgm GIN index; supports LIKE '%keyword%' fuzzy search
            GinIndex(
                name='subdomain_snap_name_trgm',
                fields=['name'],
                opclasses=['gin_trgm_ops']
            ),
        ]
        constraints = [
            # Unique constraint: within one scan, a subdomain may only be recorded once
@@ -54,34 +61,57 @@ class WebsiteSnapshot(models.Model):
    )

    # Scan result data
    url = models.CharField(max_length=2000, help_text='Website URL')
    url = models.TextField(help_text='Website URL')
    host = models.CharField(max_length=253, blank=True, default='', help_text='Hostname (domain or IP address)')
    title = models.CharField(max_length=500, blank=True, default='', help_text='Page title')
    status = models.IntegerField(null=True, blank=True, help_text='HTTP status code')
    title = models.TextField(blank=True, default='', help_text='Page title')
    status_code = models.IntegerField(null=True, blank=True, help_text='HTTP status code')
    content_length = models.BigIntegerField(null=True, blank=True, help_text='Content length')
    location = models.CharField(max_length=1000, blank=True, default='', help_text='Redirect location')
    web_server = models.CharField(max_length=200, blank=True, default='', help_text='Web server')
    content_type = models.CharField(max_length=200, blank=True, default='', help_text='Content type')
    location = models.TextField(blank=True, default='', help_text='Redirect location')
    webserver = models.TextField(blank=True, default='', help_text='Web server')
    content_type = models.TextField(blank=True, default='', help_text='Content type')
    tech = ArrayField(
        models.CharField(max_length=100),
        blank=True,
        default=list,
        help_text='Technology stack'
    )
    body_preview = models.TextField(blank=True, default='', help_text='Response body preview')
    response_body = models.TextField(blank=True, default='', help_text='HTTP response body')
    vhost = models.BooleanField(null=True, blank=True, help_text='Virtual host flag')
    discovered_at = models.DateTimeField(auto_now_add=True, help_text='Discovery time')
    response_headers = models.TextField(
        blank=True,
        default='',
        help_text='Raw HTTP response headers'
    )
    created_at = models.DateTimeField(auto_now_add=True, help_text='Creation time')

    class Meta:
        db_table = 'website_snapshot'
        verbose_name = 'Website snapshot'
        verbose_name_plural = 'Website snapshots'
        ordering = ['-discovered_at']
        ordering = ['-created_at']
        indexes = [
            models.Index(fields=['scan']),
            models.Index(fields=['url']),
            models.Index(fields=['host']),  # host index; speeds up lookups by hostname
            models.Index(fields=['-discovered_at']),
            models.Index(fields=['title']),  # title index; speeds up title searches
            models.Index(fields=['-created_at']),
            GinIndex(fields=['tech']),  # GIN index; speeds up array-field queries
            # pg_trgm GIN indexes; support LIKE '%keyword%' fuzzy search
            GinIndex(
                name='ws_snap_resp_hdr_trgm',
                fields=['response_headers'],
                opclasses=['gin_trgm_ops']
            ),
            GinIndex(
                name='ws_snap_url_trgm',
                fields=['url'],
                opclasses=['gin_trgm_ops']
            ),
            GinIndex(
                name='ws_snap_title_trgm',
                fields=['title'],
                opclasses=['gin_trgm_ops']
            ),
        ]
        constraints = [
            # Unique constraint: within one scan, a URL may only be recorded once
@@ -118,18 +148,25 @@ class DirectorySnapshot(models.Model):
    lines = models.IntegerField(null=True, blank=True, help_text='Number of lines in the response body (split on newlines)')
    content_type = models.CharField(max_length=200, blank=True, default='', help_text='Value of the Content-Type response header')
    duration = models.BigIntegerField(null=True, blank=True, help_text='Request duration (in nanoseconds)')
    discovered_at = models.DateTimeField(auto_now_add=True, help_text='Discovery time')
    created_at = models.DateTimeField(auto_now_add=True, help_text='Creation time')

    class Meta:
        db_table = 'directory_snapshot'
        verbose_name = 'Directory snapshot'
        verbose_name_plural = 'Directory snapshots'
        ordering = ['-discovered_at']
        ordering = ['-created_at']
        indexes = [
            models.Index(fields=['scan']),
            models.Index(fields=['url']),
            models.Index(fields=['status']),  # Status-code index; speeds up filtering
            models.Index(fields=['-discovered_at']),
            models.Index(fields=['content_type']),  # content_type index; speeds up content-type searches
            models.Index(fields=['-created_at']),
            # pg_trgm GIN index; supports LIKE '%keyword%' fuzzy search
            GinIndex(
                name='dir_snap_url_trgm',
                fields=['url'],
                opclasses=['gin_trgm_ops']
            ),
        ]
        constraints = [
            # Unique constraint: within one scan, a directory URL may only be recorded once
@@ -183,16 +220,16 @@ class HostPortMappingSnapshot(models.Model):
    )

    # ==================== Timestamp fields ====================
    discovered_at = models.DateTimeField(
    created_at = models.DateTimeField(
        auto_now_add=True,
        help_text='Discovery time'
        help_text='Creation time'
    )

    class Meta:
        db_table = 'host_port_mapping_snapshot'
        verbose_name = 'Host-port mapping snapshot'
        verbose_name_plural = 'Host-port mapping snapshots'
        ordering = ['-discovered_at']
        ordering = ['-created_at']
        indexes = [
            models.Index(fields=['scan']),  # Speeds up queries by scan
            models.Index(fields=['host']),  # Speeds up queries by hostname
@@ -200,7 +237,7 @@ class HostPortMappingSnapshot(models.Model):
            models.Index(fields=['port']),  # Speeds up queries by port
            models.Index(fields=['host', 'ip']),  # Speeds up combined queries
            models.Index(fields=['scan', 'host']),  # Speeds up scan + host queries
            models.Index(fields=['-discovered_at']),  # Speeds up time-based ordering
            models.Index(fields=['-created_at']),  # Speeds up time-based ordering
        ]
        constraints = [
            # Composite unique constraint: within one scan, the scan + host + ip + port combination is unique
@@ -230,26 +267,26 @@ class EndpointSnapshot(models.Model):
    )

    # Scan result data
    url = models.CharField(max_length=2000, help_text='Endpoint URL')
    url = models.TextField(help_text='Endpoint URL')
    host = models.CharField(
        max_length=253,
        blank=True,
        default='',
        help_text='Hostname (domain or IP address)'
    )
    title = models.CharField(max_length=1000, blank=True, default='', help_text='Page title')
    title = models.TextField(blank=True, default='', help_text='Page title')
    status_code = models.IntegerField(null=True, blank=True, help_text='HTTP status code')
    content_length = models.IntegerField(null=True, blank=True, help_text='Content length')
    location = models.CharField(max_length=1000, blank=True, default='', help_text='Redirect location')
    webserver = models.CharField(max_length=200, blank=True, default='', help_text='Web server')
    content_type = models.CharField(max_length=200, blank=True, default='', help_text='Content type')
    location = models.TextField(blank=True, default='', help_text='Redirect location')
    webserver = models.TextField(blank=True, default='', help_text='Web server')
    content_type = models.TextField(blank=True, default='', help_text='Content type')
    tech = ArrayField(
        models.CharField(max_length=100),
        blank=True,
        default=list,
        help_text='Technology stack'
    )
    body_preview = models.CharField(max_length=1000, blank=True, default='', help_text='Response body preview')
    response_body = models.TextField(blank=True, default='', help_text='HTTP response body')
    vhost = models.BooleanField(null=True, blank=True, help_text='Virtual host flag')
    matched_gf_patterns = ArrayField(
        models.CharField(max_length=100),
@@ -257,19 +294,43 @@ class EndpointSnapshot(models.Model):
        default=list,
        help_text='List of matched GF patterns'
    )
    discovered_at = models.DateTimeField(auto_now_add=True, help_text='Discovery time')
    response_headers = models.TextField(
        blank=True,
        default='',
        help_text='Raw HTTP response headers'
    )
    created_at = models.DateTimeField(auto_now_add=True, help_text='Creation time')

    class Meta:
        db_table = 'endpoint_snapshot'
        verbose_name = 'Endpoint snapshot'
        verbose_name_plural = 'Endpoint snapshots'
        ordering = ['-discovered_at']
        ordering = ['-created_at']
        indexes = [
            models.Index(fields=['scan']),
            models.Index(fields=['url']),
            models.Index(fields=['host']),  # host index; speeds up lookups by hostname
            models.Index(fields=['title']),  # title index; speeds up title searches
            models.Index(fields=['status_code']),  # Status-code index; speeds up filtering
            models.Index(fields=['-discovered_at']),
            models.Index(fields=['webserver']),  # webserver index; speeds up server searches
            models.Index(fields=['-created_at']),
            GinIndex(fields=['tech']),  # GIN index; speeds up array-field queries
            # pg_trgm GIN indexes; support LIKE '%keyword%' fuzzy search
            GinIndex(
                name='ep_snap_resp_hdr_trgm',
                fields=['response_headers'],
                opclasses=['gin_trgm_ops']
            ),
            GinIndex(
                name='ep_snap_url_trgm',
                fields=['url'],
                opclasses=['gin_trgm_ops']
            ),
            GinIndex(
                name='ep_snap_title_trgm',
                fields=['title'],
                opclasses=['gin_trgm_ops']
            ),
        ]
        constraints = [
            # Unique constraint: within one scan, a URL may only be recorded once
@@ -302,7 +363,7 @@ class VulnerabilitySnapshot(models.Model):
    )

    # ==================== Core fields ====================
    url = models.TextField(help_text='URL where the vulnerability was found')
    url = models.CharField(max_length=2000, help_text='URL where the vulnerability was found')
    vuln_type = models.CharField(max_length=100, help_text='Vulnerability type (e.g. xss, sqli)')
    severity = models.CharField(
        max_length=20,
@@ -316,19 +377,20 @@ class VulnerabilitySnapshot(models.Model):
    raw_output = models.JSONField(blank=True, default=dict, help_text='Raw tool output')

    # ==================== Timestamp fields ====================
    discovered_at = models.DateTimeField(auto_now_add=True, help_text='Discovery time')
    created_at = models.DateTimeField(auto_now_add=True, help_text='Creation time')

    class Meta:
        db_table = 'vulnerability_snapshot'
        verbose_name = 'Vulnerability snapshot'
        verbose_name_plural = 'Vulnerability snapshots'
        ordering = ['-discovered_at']
        ordering = ['-created_at']
        indexes = [
            models.Index(fields=['scan']),
            models.Index(fields=['url']),  # url index; speeds up URL searches
            models.Index(fields=['vuln_type']),
            models.Index(fields=['severity']),
            models.Index(fields=['source']),
            models.Index(fields=['-discovered_at']),
            models.Index(fields=['-created_at']),
        ]

    def __str__(self):

@@ -9,6 +9,7 @@ from django.db import transaction
from apps.asset.models.asset_models import Directory
from apps.asset.dtos import DirectoryDTO
from apps.common.decorators import auto_ensure_db_connection
from apps.common.utils import deduplicate_for_bulk

logger = logging.getLogger(__name__)

@@ -24,6 +25,8 @@ class DjangoDirectoryRepository:
        Updates all fields when a record exists, creates it otherwise.
        Uses Django's native update_conflicts.

        Note: deduplicates automatically by the model's unique constraint, keeping the last record.

        Args:
            items: list of Directory DTOs

@@ -34,6 +37,9 @@ class DjangoDirectoryRepository:
            return 0

        try:
            # Deduplicate automatically by the model's unique constraint
            unique_items = deduplicate_for_bulk(items, Directory)

            # Build models directly from the DTO fields
            directories = [
                Directory(
@@ -46,7 +52,7 @@ class DjangoDirectoryRepository:
                    content_type=item.content_type or '',
                    duration=item.duration
                )
                for item in items
                for item in unique_items
            ]

            with transaction.atomic():
@@ -61,20 +67,74 @@ class DjangoDirectoryRepository:
                    batch_size=1000
                )

            logger.debug(f"Bulk upsert of Directory succeeded: {len(items)} records")
            return len(items)
            logger.debug(f"Bulk upsert of Directory succeeded: {len(unique_items)} records")
            return len(unique_items)

        except Exception as e:
            logger.error(f"Bulk upsert of Directory failed: {e}")
            raise
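
deduplicate_for_bulk comes from apps.common.utils and is not shown in this diff; the docstrings only promise deduplication on the model's unique constraint, keeping the last record. A hypothetical sketch of such a helper (the real implementation may well differ):

from django.db import models


def deduplicate_for_bulk(items, model):
    """Keep the last DTO per unique-constraint key of `model` (sketch)."""
    constraint = next(
        c for c in model._meta.constraints
        if isinstance(c, models.UniqueConstraint)
    )
    seen = {}
    for item in items:
        # FK fields such as 'target' are assumed to live on the DTO as
        # 'target_id'; the real helper may resolve this differently.
        key = tuple(
            getattr(item, field, getattr(item, field + '_id', None))
            for field in constraint.fields
        )
        seen[key] = item  # later items overwrite earlier ones
    return list(seen.values())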

    def bulk_create_ignore_conflicts(self, items: List[DirectoryDTO]) -> int:
        """
        Bulk-create Directory records (skipped when they already exist)

        Unlike bulk_upsert, this method does not update existing records.
        Intended for bulk-add scenarios where only URLs are supplied and no other field data.

        Note: deduplicates automatically by the model's unique constraint, keeping the last record.

        Args:
            items: list of Directory DTOs

        Returns:
            int: number of records processed
        """
        if not items:
            return 0

        try:
            # Deduplicate automatically by the model's unique constraint
            unique_items = deduplicate_for_bulk(items, Directory)

            directories = [
                Directory(
                    target_id=item.target_id,
                    url=item.url,
                    status=item.status,
                    content_length=item.content_length,
                    words=item.words,
                    lines=item.lines,
                    content_type=item.content_type or '',
                    duration=item.duration
                )
                for item in unique_items
            ]

            with transaction.atomic():
                Directory.objects.bulk_create(
                    directories,
                    ignore_conflicts=True,
                    batch_size=1000
                )

            logger.debug(f"Bulk create of Directory succeeded (ignore_conflicts): {len(unique_items)} records")
            return len(unique_items)

        except Exception as e:
            logger.error(f"Bulk create of Directory failed: {e}")
            raise

    def count_by_target(self, target_id: int) -> int:
        """Count the directories under a target"""
        return Directory.objects.filter(target_id=target_id).count()

    def get_all(self):
        """Fetch all directories"""
        return Directory.objects.all().order_by('-discovered_at')
        return Directory.objects.all().order_by('-created_at')

    def get_by_target(self, target_id: int):
        """Fetch all directories under a target"""
        return Directory.objects.filter(target_id=target_id).order_by('-discovered_at')
        return Directory.objects.filter(target_id=target_id).order_by('-created_at')

    def get_urls_for_export(self, target_id: int, batch_size: int = 1000) -> Iterator[str]:
        """Stream-export all directory URLs under a target"""
@@ -91,3 +151,31 @@ class DjangoDirectoryRepository:
        except Exception as e:
            logger.error("Streaming export of directory URLs failed - Target ID: %s, error: %s", target_id, e)
            raise

    def iter_raw_data_for_export(
        self,
        target_id: int,
        batch_size: int = 1000
    ) -> Iterator[dict]:
        """
        Stream raw data for CSV export

        Args:
            target_id: target ID
            batch_size: rows per batch

        Yields:
            dicts containing all directory fields
        """
        qs = (
            Directory.objects
            .filter(target_id=target_id)
            .values(
                'url', 'status', 'content_length', 'words',
                'lines', 'content_type', 'duration', 'created_at'
            )
            .order_by('url')
        )

        for row in qs.iterator(chunk_size=batch_size):
            yield row
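
A usage sketch for the streaming export above: feeding the yielded dicts straight into csv.DictWriter keeps memory flat regardless of table size (the output path and target ID are illustrative):

import csv

from apps.asset.repositories import DjangoDirectoryRepository  # assumed module path

repo = DjangoDirectoryRepository()
fields = ['url', 'status', 'content_length', 'words',
          'lines', 'content_type', 'duration', 'created_at']
with open('/tmp/directories.csv', 'w', newline='') as fh:
    writer = csv.DictWriter(fh, fieldnames=fields)
    writer.writeheader()
    for row in repo.iter_raw_data_for_export(target_id=1):
        writer.writerow(row)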

@@ -1,11 +1,12 @@
"""Endpoint Repository - Django ORM implementation"""

import logging
from typing import List
from typing import List, Iterator

from apps.asset.models import Endpoint
from apps.asset.dtos.asset import EndpointDTO
from apps.common.decorators import auto_ensure_db_connection
from apps.common.utils import deduplicate_for_bulk
from django.db import transaction

logger = logging.getLogger(__name__)
@@ -22,6 +23,8 @@ class DjangoEndpointRepository:
        Updates all fields when a record exists, creates it otherwise.
        Uses Django's native update_conflicts.

        Note: deduplicates automatically by the model's unique constraint, keeping the last record.

        Args:
            items: list of endpoint DTOs

@@ -32,6 +35,9 @@ class DjangoEndpointRepository:
            return 0

        try:
            # Deduplicate automatically by the model's unique constraint
            unique_items = deduplicate_for_bulk(items, Endpoint)

            # Build models directly from the DTO fields
            endpoints = [
                Endpoint(
@@ -42,14 +48,15 @@ class DjangoEndpointRepository:
                    status_code=item.status_code,
                    content_length=item.content_length,
                    webserver=item.webserver or '',
                    body_preview=item.body_preview or '',
                    response_body=item.response_body or '',
                    content_type=item.content_type or '',
                    tech=item.tech if item.tech else [],
                    vhost=item.vhost,
                    location=item.location or '',
                    matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else []
                    matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else [],
                    response_headers=item.response_headers if item.response_headers else ''
                )
                for item in items
                for item in unique_items
            ]

            with transaction.atomic():
@@ -59,14 +66,14 @@ class DjangoEndpointRepository:
                    unique_fields=['url', 'target'],
                    update_fields=[
                        'host', 'title', 'status_code', 'content_length',
                        'webserver', 'body_preview', 'content_type', 'tech',
                        'vhost', 'location', 'matched_gf_patterns'
                        'webserver', 'response_body', 'content_type', 'tech',
                        'vhost', 'location', 'matched_gf_patterns', 'response_headers'
                    ],
                    batch_size=1000
                )

            logger.debug(f"Bulk upsert of endpoints succeeded: {len(items)} records")
            return len(items)
            logger.debug(f"Bulk upsert of endpoints succeeded: {len(unique_items)} records")
            return len(unique_items)

        except Exception as e:
            logger.error(f"Bulk upsert of endpoints failed: {e}")

@@ -74,7 +81,7 @@ class DjangoEndpointRepository:

    def get_all(self):
        """Fetch all endpoints (global query)"""
        return Endpoint.objects.all().order_by('-discovered_at')
        return Endpoint.objects.all().order_by('-created_at')

    def get_by_target(self, target_id: int):
        """
@@ -86,7 +93,7 @@ class DjangoEndpointRepository:
        Returns:
            QuerySet: endpoint queryset
        """
        return Endpoint.objects.filter(target_id=target_id).order_by('-discovered_at')
        return Endpoint.objects.filter(target_id=target_id).order_by('-created_at')

    def count_by_target(self, target_id: int) -> int:
        """
@@ -99,3 +106,89 @@ class DjangoEndpointRepository:
            int: number of endpoints
        """
        return Endpoint.objects.filter(target_id=target_id).count()

    def bulk_create_ignore_conflicts(self, items: List[EndpointDTO]) -> int:
        """
        Bulk-create endpoints (skipped when they already exist)

        Unlike bulk_upsert, this method does not update existing records.
        Intended for fast-scan scenarios where only URLs are supplied and no other field data.

        Note: deduplicates automatically by the model's unique constraint, keeping the last record.

        Args:
            items: list of endpoint DTOs

        Returns:
            int: number of records processed
        """
        if not items:
            return 0

        try:
            # Deduplicate automatically by the model's unique constraint
            unique_items = deduplicate_for_bulk(items, Endpoint)

            # Build models directly from the DTO fields
            endpoints = [
                Endpoint(
                    target_id=item.target_id,
                    url=item.url,
                    host=item.host or '',
                    title=item.title or '',
                    status_code=item.status_code,
                    content_length=item.content_length,
                    webserver=item.webserver or '',
                    response_body=item.response_body or '',
                    content_type=item.content_type or '',
                    tech=item.tech if item.tech else [],
                    vhost=item.vhost,
                    location=item.location or '',
                    matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else [],
                    response_headers=item.response_headers if item.response_headers else ''
                )
                for item in unique_items
            ]

            with transaction.atomic():
                Endpoint.objects.bulk_create(
                    endpoints,
                    ignore_conflicts=True,
                    batch_size=1000
                )

            logger.debug(f"Bulk create of endpoints succeeded (ignore_conflicts): {len(unique_items)} records")
            return len(unique_items)

        except Exception as e:
            logger.error(f"Bulk create of endpoints failed: {e}")
            raise

    def iter_raw_data_for_export(
        self,
        target_id: int,
        batch_size: int = 1000
    ) -> Iterator[dict]:
        """
        Stream raw data for CSV export

        Args:
            target_id: target ID
            batch_size: rows per batch

        Yields:
            dicts containing all endpoint fields
        """
        qs = (
            Endpoint.objects
            .filter(target_id=target_id)
            .values(
                'url', 'host', 'location', 'title', 'status_code',
                'content_length', 'content_type', 'webserver', 'tech',
                'response_body', 'response_headers', 'vhost', 'matched_gf_patterns', 'created_at'
            )
            .order_by('url')
        )

        for row in qs.iterator(chunk_size=batch_size):
            yield row

@@ -1,32 +1,36 @@
"""HostPortMapping Repository - Django ORM implementation"""

import logging
from typing import List, Iterator
from typing import List, Iterator, Dict, Optional

from django.db.models import QuerySet, Min

from apps.asset.models.asset_models import HostPortMapping
from apps.asset.dtos.asset import HostPortMappingDTO
from apps.common.decorators import auto_ensure_db_connection
from apps.common.utils import deduplicate_for_bulk

logger = logging.getLogger(__name__)


@auto_ensure_db_connection
class DjangoHostPortMappingRepository:
    """HostPortMapping Repository - Django ORM implementation"""
    """HostPortMapping Repository - Django ORM implementation

    Responsibility: pure data access, no business logic
    """

    def bulk_create_ignore_conflicts(self, items: List[HostPortMappingDTO]) -> int:
        """
        Bulk-create host-port mappings (ignoring conflicts)

        Note: deduplicates automatically by the model's unique constraint, keeping the last record.

        Args:
            items: list of host-port mapping DTOs

        Returns:
            int: number of records actually created (note: may be 0 with ignore_conflicts)

        Note:
            - Deduplicates automatically on the unique constraint (target + host + ip + port)
            - Existing records are ignored, not updated
            int: number of records actually created
        """
        try:
            logger.debug("Preparing to bulk-create host-port mappings - count: %d", len(items))
@@ -34,8 +38,10 @@ class DjangoHostPortMappingRepository:
            if not items:
                logger.debug("Host-port mapping list is empty; skipping creation")
                return 0

            # Deduplicate automatically by the model's unique constraint
            unique_items = deduplicate_for_bulk(items, HostPortMapping)

            # Build the record objects
            records = [
                HostPortMapping(
                    target_id=item.target_id,
@@ -43,10 +49,9 @@ class DjangoHostPortMappingRepository:
                    ip=item.ip,
                    port=item.port
                )
                for item in items
                for item in unique_items
            ]

            # Bulk create (ignore conflicts; deduplicated by the unique constraint)
            created = HostPortMapping.objects.bulk_create(
                records,
                ignore_conflicts=True
@@ -90,79 +95,47 @@ class DjangoHostPortMappingRepository:
        for ip in queryset:
            yield ip

    def get_ip_aggregation_by_target(self, target_id: int, search: str = None):
        from django.db.models import Min
    def get_queryset_by_target(self, target_id: int) -> QuerySet:
        """Get the QuerySet for a target"""
        return HostPortMapping.objects.filter(target_id=target_id)

        qs = HostPortMapping.objects.filter(target_id=target_id)
        if search:
            qs = qs.filter(ip__icontains=search)
    def get_all_queryset(self) -> QuerySet:
        """Get the QuerySet for all records"""
        return HostPortMapping.objects.all()

        ip_aggregated = (
            qs
            .values('ip')
            .annotate(
                discovered_at=Min('discovered_at')
            )
            .order_by('-discovered_at')
    def get_queryset_by_ip(self, ip: str, target_id: Optional[int] = None) -> QuerySet:
        """Get the QuerySet for a given IP"""
        qs = HostPortMapping.objects.filter(ip=ip)
        if target_id:
            qs = qs.filter(target_id=target_id)
        return qs

    def iter_raw_data_for_export(
        self,
        target_id: int,
        batch_size: int = 1000
    ) -> Iterator[dict]:
        """
        Stream raw data for CSV export

        Args:
            target_id: target ID
            batch_size: rows per batch

        Yields:
            {
                'ip': '192.168.1.1',
                'host': 'example.com',
                'port': 80,
                'created_at': datetime
            }
        """
        qs = (
            HostPortMapping.objects
            .filter(target_id=target_id)
            .values('ip', 'host', 'port', 'created_at')
            .order_by('ip', 'host', 'port')
        )

        results = []
        for item in ip_aggregated:
            ip = item['ip']
            mappings = (
                HostPortMapping.objects
                .filter(target_id=target_id, ip=ip)
                .values('host', 'port')
                .distinct()
            )

            hosts = sorted({m['host'] for m in mappings})
            ports = sorted({m['port'] for m in mappings})

            results.append({
                'ip': ip,
                'hosts': hosts,
                'ports': ports,
                'discovered_at': item['discovered_at'],
            })

        return results

    def get_all_ip_aggregation(self, search: str = None):
        """Fetch aggregated data for all IPs (global query)"""
        from django.db.models import Min

        qs = HostPortMapping.objects.all()
        if search:
            qs = qs.filter(ip__icontains=search)

        ip_aggregated = (
            qs
            .values('ip')
            .annotate(
                discovered_at=Min('discovered_at')
            )
            .order_by('-discovered_at')
        )

        results = []
        for item in ip_aggregated:
            ip = item['ip']
            mappings = (
                HostPortMapping.objects
                .filter(ip=ip)
                .values('host', 'port')
                .distinct()
            )

            hosts = sorted({m['host'] for m in mappings})
            ports = sorted({m['port'] for m in mappings})

            results.append({
                'ip': ip,
                'hosts': hosts,
                'ports': ports,
                'discovered_at': item['discovered_at'],
            })

        return results

        for row in qs.iterator(chunk_size=batch_size):
            yield row

@@ -8,6 +8,7 @@ from django.db import transaction
from apps.asset.models.asset_models import Subdomain
from apps.asset.dtos import SubdomainDTO
from apps.common.decorators import auto_ensure_db_connection
from apps.common.utils import deduplicate_for_bulk

logger = logging.getLogger(__name__)

@@ -20,6 +21,8 @@ class DjangoSubdomainRepository:
        """
        Bulk-create subdomains, ignoring conflicts

        Note: deduplicates automatically by the model's unique constraint, keeping the last record.

        Args:
            items: list of subdomain DTOs
        """
@@ -27,12 +30,15 @@ class DjangoSubdomainRepository:
            return

        try:
            # Deduplicate automatically by the model's unique constraint
            unique_items = deduplicate_for_bulk(items, Subdomain)

            subdomain_objects = [
                Subdomain(
                    name=item.name,
                    target_id=item.target_id,
                )
                for item in items
                for item in unique_items
            ]

            with transaction.atomic():
@@ -41,7 +47,7 @@ class DjangoSubdomainRepository:
                    ignore_conflicts=True,
                )

            logger.debug(f"Successfully processed {len(items)} subdomain records")
            logger.debug(f"Successfully processed {len(unique_items)} subdomain records")

        except Exception as e:
            logger.error(f"Bulk insert of subdomains failed: {e}")
@@ -49,11 +55,11 @@ class DjangoSubdomainRepository:

    def get_all(self):
        """Fetch all subdomains"""
        return Subdomain.objects.all().order_by('-discovered_at')
        return Subdomain.objects.all().order_by('-created_at')

    def get_by_target(self, target_id: int):
        """Fetch all subdomains under a target"""
        return Subdomain.objects.filter(target_id=target_id).order_by('-discovered_at')
        return Subdomain.objects.filter(target_id=target_id).order_by('-created_at')

    def count_by_target(self, target_id: int) -> int:
        """Count the subdomains under a target"""
@@ -76,3 +82,28 @@ class DjangoSubdomainRepository:
        ).only('id', 'name')

        return {sd.name: sd for sd in subdomains}

    def iter_raw_data_for_export(
        self,
        target_id: int,
        batch_size: int = 1000
    ) -> Iterator[dict]:
        """
        Stream raw data for CSV export

        Args:
            target_id: target ID
            batch_size: rows per batch

        Yields:
            {'name': 'sub.example.com', 'created_at': datetime}
        """
        qs = (
            Subdomain.objects
            .filter(target_id=target_id)
            .values('name', 'created_at')
            .order_by('name')
        )

        for row in qs.iterator(chunk_size=batch_size):
            yield row

@@ -3,12 +3,13 @@ WebSite Repository implemented with the Django ORM
"""

import logging
from typing import List, Generator, Optional
from typing import List, Generator, Optional, Iterator
from django.db import transaction

from apps.asset.models.asset_models import WebSite
from apps.asset.dtos import WebSiteDTO
from apps.common.decorators import auto_ensure_db_connection
from apps.common.utils import deduplicate_for_bulk

logger = logging.getLogger(__name__)

@@ -24,6 +25,8 @@ class DjangoWebSiteRepository:
        Updates all fields when a record exists, creates it otherwise.
        Uses Django's native update_conflicts.

        Note: deduplicates automatically by the model's unique constraint, keeping the last record.

        Args:
            items: list of WebSite DTOs

@@ -34,6 +37,9 @@ class DjangoWebSiteRepository:
            return 0

        try:
            # Deduplicate automatically by the model's unique constraint
            unique_items = deduplicate_for_bulk(items, WebSite)

            # Build models directly from the DTO fields
            websites = [
                WebSite(
@@ -43,14 +49,15 @@ class DjangoWebSiteRepository:
                    location=item.location or '',
                    title=item.title or '',
                    webserver=item.webserver or '',
                    body_preview=item.body_preview or '',
                    response_body=item.response_body or '',
                    content_type=item.content_type or '',
                    tech=item.tech if item.tech else [],
                    status_code=item.status_code,
                    content_length=item.content_length,
                    vhost=item.vhost
                    vhost=item.vhost,
                    response_headers=item.response_headers if item.response_headers else ''
                )
                for item in items
                for item in unique_items
            ]

            with transaction.atomic():
@@ -60,14 +67,14 @@ class DjangoWebSiteRepository:
                    unique_fields=['url', 'target'],
                    update_fields=[
                        'host', 'location', 'title', 'webserver',
                        'body_preview', 'content_type', 'tech',
                        'status_code', 'content_length', 'vhost'
                        'response_body', 'content_type', 'tech',
                        'status_code', 'content_length', 'vhost', 'response_headers'
                    ],
                    batch_size=1000
                )

            logger.debug(f"Bulk upsert of WebSite succeeded: {len(items)} records")
            return len(items)
            logger.debug(f"Bulk upsert of WebSite succeeded: {len(unique_items)} records")
            return len(unique_items)

        except Exception as e:
            logger.error(f"Bulk upsert of WebSite failed: {e}")
@@ -76,13 +83,6 @@ class DjangoWebSiteRepository:
    def get_urls_for_export(self, target_id: int, batch_size: int = 1000) -> Generator[str, None, None]:
        """
        Stream-export all website URLs under a target

        Args:
            target_id: target ID
            batch_size: batch size

        Yields:
            str: website URL
        """
        try:
            queryset = WebSite.objects.filter(
@@ -97,26 +97,92 @@ class DjangoWebSiteRepository:

    def get_all(self):
        """Fetch all websites"""
        return WebSite.objects.all().order_by('-discovered_at')
        return WebSite.objects.all().order_by('-created_at')

    def get_by_target(self, target_id: int):
        """Fetch all websites under a target"""
        return WebSite.objects.filter(target_id=target_id).order_by('-discovered_at')
        return WebSite.objects.filter(target_id=target_id).order_by('-created_at')

    def count_by_target(self, target_id: int) -> int:
        """Count the websites under a target"""
        return WebSite.objects.filter(target_id=target_id).count()

    def get_by_url(self, url: str, target_id: int) -> Optional[int]:
        """
        Look up a website ID by URL and target_id

        Args:
            url: website URL
            target_id: target ID

        Returns:
            Optional[int]: the website ID, or None if it does not exist
        """
        """Look up a website ID by URL and target_id"""
        website = WebSite.objects.filter(url=url, target_id=target_id).first()
        return website.id if website else None

    def bulk_create_ignore_conflicts(self, items: List[WebSiteDTO]) -> int:
        """
        Bulk-create WebSite records (skipped when they already exist)

        Note: deduplicates automatically by the model's unique constraint, keeping the last record.
        """
        if not items:
            return 0

        try:
            # Deduplicate automatically by the model's unique constraint
            unique_items = deduplicate_for_bulk(items, WebSite)

            websites = [
                WebSite(
                    target_id=item.target_id,
                    url=item.url,
                    host=item.host or '',
                    location=item.location or '',
                    title=item.title or '',
                    webserver=item.webserver or '',
                    response_body=item.response_body or '',
                    content_type=item.content_type or '',
                    tech=item.tech if item.tech else [],
                    status_code=item.status_code,
                    content_length=item.content_length,
                    vhost=item.vhost,
                    response_headers=item.response_headers if item.response_headers else ''
                )
                for item in unique_items
            ]

            with transaction.atomic():
                WebSite.objects.bulk_create(
                    websites,
                    ignore_conflicts=True,
                    batch_size=1000
                )

            logger.debug(f"Bulk create of WebSite succeeded (ignore_conflicts): {len(unique_items)} records")
            return len(unique_items)

        except Exception as e:
            logger.error(f"Bulk create of WebSite failed: {e}")
            raise

    def iter_raw_data_for_export(
        self,
        target_id: int,
        batch_size: int = 1000
    ) -> Iterator[dict]:
        """
        Stream raw data for CSV export

        Args:
            target_id: target ID
            batch_size: rows per batch

        Yields:
            dicts containing all website fields
        """
        qs = (
            WebSite.objects
            .filter(target_id=target_id)
            .values(
                'url', 'host', 'location', 'title', 'status_code',
                'content_length', 'content_type', 'webserver', 'tech',
                'response_body', 'response_headers', 'vhost', 'created_at'
            )
            .order_by('url')
        )

        for row in qs.iterator(chunk_size=batch_size):
            yield row

@@ -1,12 +1,13 @@
"""Directory Snapshot Repository - directory snapshot data-access layer"""

import logging
from typing import List
from typing import List, Iterator
from django.db import transaction

from apps.asset.models import DirectorySnapshot
from apps.asset.dtos.snapshot import DirectorySnapshotDTO
from apps.common.decorators import auto_ensure_db_connection
from apps.common.utils import deduplicate_for_bulk

logger = logging.getLogger(__name__)

@@ -25,6 +26,8 @@ class DjangoDirectorySnapshotRepository:

        Uses the ignore_conflicts strategy: an already existing snapshot (same scan + url) is skipped

        Note: deduplicates automatically by (scan_id, url), keeping the last record.

        Args:
            items: list of directory snapshot DTOs

@@ -37,6 +40,9 @@ class DjangoDirectorySnapshotRepository:
            return

        try:
            # Deduplicate automatically by the model's unique constraint
            unique_items = deduplicate_for_bulk(items, DirectorySnapshot)

            # Convert to Django model objects
            snapshot_objects = [
                DirectorySnapshot(
@@ -49,7 +55,7 @@ class DjangoDirectorySnapshotRepository:
                    content_type=item.content_type,
                    duration=item.duration
                )
                for item in items
                for item in unique_items
            ]

            with transaction.atomic():
@@ -60,7 +66,7 @@ class DjangoDirectorySnapshotRepository:
                    ignore_conflicts=True
                )

            logger.debug("Successfully saved %d directory snapshot records", len(items))
            logger.debug("Successfully saved %d directory snapshot records", len(unique_items))

        except Exception as e:
            logger.error(
@@ -72,7 +78,35 @@ class DjangoDirectorySnapshotRepository:
            raise

    def get_by_scan(self, scan_id: int):
        return DirectorySnapshot.objects.filter(scan_id=scan_id).order_by('-discovered_at')
        return DirectorySnapshot.objects.filter(scan_id=scan_id).order_by('-created_at')

    def get_all(self):
        return DirectorySnapshot.objects.all().order_by('-discovered_at')
        return DirectorySnapshot.objects.all().order_by('-created_at')

    def iter_raw_data_for_export(
        self,
        scan_id: int,
        batch_size: int = 1000
    ) -> Iterator[dict]:
        """
        Stream raw data for CSV export

        Args:
            scan_id: scan ID
            batch_size: rows per batch

        Yields:
            dicts containing all directory fields
        """
        qs = (
            DirectorySnapshot.objects
            .filter(scan_id=scan_id)
            .values(
                'url', 'status', 'content_length', 'words',
                'lines', 'content_type', 'duration', 'created_at'
            )
            .order_by('url')
        )

        for row in qs.iterator(chunk_size=batch_size):
            yield row

@@ -1,11 +1,12 @@
"""EndpointSnapshot Repository - Django ORM implementation"""

import logging
from typing import List
from typing import List, Iterator

from apps.asset.models.snapshot_models import EndpointSnapshot
from apps.asset.dtos.snapshot import EndpointSnapshotDTO
from apps.common.decorators import auto_ensure_db_connection
from apps.common.utils import deduplicate_for_bulk

logger = logging.getLogger(__name__)

@@ -18,6 +19,8 @@ class DjangoEndpointSnapshotRepository:
        """
        Save endpoint snapshots

        Note: deduplicates automatically by (scan_id, url), keeping the last record.

        Args:
            items: list of endpoint snapshot DTOs

@@ -31,13 +34,17 @@ class DjangoEndpointSnapshotRepository:
        if not items:
            logger.debug("Endpoint snapshot list is empty; skipping save")
            return

        # Deduplicate automatically by the model's unique constraint
        unique_items = deduplicate_for_bulk(items, EndpointSnapshot)

        # Build the snapshot objects
        snapshots = []
        for item in items:
        for item in unique_items:
            snapshots.append(EndpointSnapshot(
                scan_id=item.scan_id,
                url=item.url,
                host=item.host if item.host else '',
                title=item.title,
                status_code=item.status_code,
                content_length=item.content_length,
@@ -45,9 +52,10 @@ class DjangoEndpointSnapshotRepository:
                webserver=item.webserver,
                content_type=item.content_type,
                tech=item.tech if item.tech else [],
                body_preview=item.body_preview,
                response_body=item.response_body,
                vhost=item.vhost,
                matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else []
                matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else [],
                response_headers=item.response_headers if item.response_headers else ''
            ))

        # Bulk create (ignore conflicts; deduplicated by the unique constraint)
@@ -68,7 +76,36 @@ class DjangoEndpointSnapshotRepository:
        raise

    def get_by_scan(self, scan_id: int):
        return EndpointSnapshot.objects.filter(scan_id=scan_id).order_by('-discovered_at')
        return EndpointSnapshot.objects.filter(scan_id=scan_id).order_by('-created_at')

    def get_all(self):
        return EndpointSnapshot.objects.all().order_by('-discovered_at')
        return EndpointSnapshot.objects.all().order_by('-created_at')

    def iter_raw_data_for_export(
        self,
        scan_id: int,
        batch_size: int = 1000
    ) -> Iterator[dict]:
        """
        Stream raw data for CSV export

        Args:
            scan_id: scan ID
            batch_size: rows per batch

        Yields:
            dicts containing all endpoint fields
        """
        qs = (
            EndpointSnapshot.objects
            .filter(scan_id=scan_id)
            .values(
                'url', 'host', 'location', 'title', 'status_code',
                'content_length', 'content_type', 'webserver', 'tech',
                'response_body', 'response_headers', 'vhost', 'matched_gf_patterns', 'created_at'
            )
            .order_by('url')
        )

        for row in qs.iterator(chunk_size=batch_size):
            yield row

@@ -6,6 +6,7 @@ from typing import List, Iterator
from apps.asset.models.snapshot_models import HostPortMappingSnapshot
from apps.asset.dtos.snapshot import HostPortMappingSnapshotDTO
from apps.common.decorators import auto_ensure_db_connection
from apps.common.utils import deduplicate_for_bulk

logger = logging.getLogger(__name__)

@@ -18,6 +19,8 @@ class DjangoHostPortMappingSnapshotRepository:
        """
        Save host-port mapping snapshots

        Note: deduplicates automatically by (scan_id, host, ip, port), keeping the last record.

        Args:
            items: list of host-port mapping snapshot DTOs

@@ -31,10 +34,13 @@ class DjangoHostPortMappingSnapshotRepository:
        if not items:
            logger.debug("Host-port mapping snapshot list is empty; skipping save")
            return

        # Deduplicate automatically by the model's unique constraint
        unique_items = deduplicate_for_bulk(items, HostPortMappingSnapshot)

        # Build the snapshot objects
        snapshots = []
        for item in items:
        for item in unique_items:
            snapshots.append(HostPortMappingSnapshot(
                scan_id=item.scan_id,
                host=item.host,
@@ -59,20 +65,28 @@ class DjangoHostPortMappingSnapshotRepository:
        )
        raise

    def get_ip_aggregation_by_scan(self, scan_id: int, search: str = None):
    def get_ip_aggregation_by_scan(self, scan_id: int, filter_query: str = None):
        from django.db.models import Min
        from apps.common.utils.filter_utils import apply_filters

        qs = HostPortMappingSnapshot.objects.filter(scan_id=scan_id)
        if search:
            qs = qs.filter(ip__icontains=search)

        # Apply smart filtering
        if filter_query:
            field_mapping = {
                'ip': 'ip',
                'port': 'port',
                'host': 'host',
            }
            qs = apply_filters(qs, filter_query, field_mapping)

        ip_aggregated = (
            qs
            .values('ip')
            .annotate(
                discovered_at=Min('discovered_at')
                created_at=Min('created_at')
            )
            .order_by('-discovered_at')
            .order_by('-created_at')
        )

        results = []
@@ -92,24 +106,32 @@ class DjangoHostPortMappingSnapshotRepository:
                'ip': ip,
                'hosts': hosts,
                'ports': ports,
                'discovered_at': item['discovered_at'],
                'created_at': item['created_at'],
            })

        return results
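
apply_filters and its query grammar live in apps.common.utils.filter_utils and are not shown in this diff; judging by the field_mapping above, a call presumably narrows the queryset like this before the per-IP aggregation runs (the filter string is an assumed example of the smart-filter syntax):

from apps.common.utils.filter_utils import apply_filters
from apps.asset.models.snapshot_models import HostPortMappingSnapshot

qs = HostPortMappingSnapshot.objects.filter(scan_id=42)
# Maps user-facing filter keys to ORM field paths before filtering.
qs = apply_filters(qs, 'ip="10.0.0.1"', {'ip': 'ip', 'port': 'port', 'host': 'host'})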
|
||||
|
||||
def get_all_ip_aggregation(self, search: str = None):
|
||||
def get_all_ip_aggregation(self, filter_query: str = None):
|
||||
"""获取所有 IP 聚合数据"""
|
||||
from django.db.models import Min
|
||||
from apps.common.utils.filter_utils import apply_filters
|
||||
|
||||
qs = HostPortMappingSnapshot.objects.all()
|
||||
if search:
|
||||
qs = qs.filter(ip__icontains=search)
|
||||
|
||||
# 应用智能过滤
|
||||
if filter_query:
|
||||
field_mapping = {
|
||||
'ip': 'ip',
|
||||
'port': 'port',
|
||||
'host': 'host',
|
||||
}
|
||||
qs = apply_filters(qs, filter_query, field_mapping)
|
||||
|
||||
ip_aggregated = (
|
||||
qs
|
||||
.values('ip')
|
||||
.annotate(discovered_at=Min('discovered_at'))
|
||||
.order_by('-discovered_at')
|
||||
.annotate(created_at=Min('created_at'))
|
||||
.order_by('-created_at')
|
||||
)
|
||||
|
||||
results = []
|
||||
@@ -127,7 +149,7 @@ class DjangoHostPortMappingSnapshotRepository:
|
||||
'ip': ip,
|
||||
'hosts': hosts,
|
||||
'ports': ports,
|
||||
'discovered_at': item['discovered_at'],
|
||||
'created_at': item['created_at'],
|
||||
})
|
||||
return results
|
||||
|
||||
@@ -143,3 +165,33 @@ class DjangoHostPortMappingSnapshotRepository:
|
||||
)
|
||||
for ip in queryset:
|
||||
yield ip
|
||||
|
||||
def iter_raw_data_for_export(
|
||||
self,
|
||||
scan_id: int,
|
||||
batch_size: int = 1000
|
||||
) -> Iterator[dict]:
|
||||
"""
|
||||
流式获取原始数据用于 CSV 导出
|
||||
|
||||
Args:
|
||||
scan_id: 扫描 ID
|
||||
batch_size: 每批数据量
|
||||
|
||||
Yields:
|
||||
{
|
||||
'ip': '192.168.1.1',
|
||||
'host': 'example.com',
|
||||
'port': 80,
|
||||
'created_at': datetime
|
||||
}
|
||||
"""
|
||||
qs = (
|
||||
HostPortMappingSnapshot.objects
|
||||
.filter(scan_id=scan_id)
|
||||
.values('ip', 'host', 'port', 'created_at')
|
||||
.order_by('ip', 'host', 'port')
|
||||
)
|
||||
|
||||
for row in qs.iterator(chunk_size=batch_size):
|
||||
yield row
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"""Django ORM 实现的 SubdomainSnapshot Repository"""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import List, Iterator
|
||||
|
||||
from apps.asset.models.snapshot_models import SubdomainSnapshot
|
||||
from apps.asset.dtos import SubdomainSnapshotDTO
|
||||
from apps.common.decorators import auto_ensure_db_connection
|
||||
from apps.common.utils import deduplicate_for_bulk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -18,6 +19,8 @@ class DjangoSubdomainSnapshotRepository:
|
||||
"""
|
||||
保存子域名快照
|
||||
|
||||
注意:会自动按 (scan_id, name) 去重,保留最后一条记录。
|
||||
|
||||
Args:
|
||||
items: 子域名快照 DTO 列表
|
||||
|
||||
@@ -31,10 +34,13 @@ class DjangoSubdomainSnapshotRepository:
|
||||
if not items:
|
||||
logger.debug("子域名快照为空,跳过保存")
|
||||
return
|
||||
|
||||
# 根据模型唯一约束自动去重
|
||||
unique_items = deduplicate_for_bulk(items, SubdomainSnapshot)
|
||||
|
||||
# 构建快照对象
|
||||
snapshots = []
|
||||
for item in items:
|
||||
for item in unique_items:
|
||||
snapshots.append(SubdomainSnapshot(
|
||||
scan_id=item.scan_id,
|
||||
name=item.name,
|
||||
@@ -55,7 +61,32 @@ class DjangoSubdomainSnapshotRepository:
|
||||
raise
|
||||
|
||||
def get_by_scan(self, scan_id: int):
|
||||
return SubdomainSnapshot.objects.filter(scan_id=scan_id).order_by('-discovered_at')
|
||||
return SubdomainSnapshot.objects.filter(scan_id=scan_id).order_by('-created_at')
|
||||
|
||||
def get_all(self):
|
||||
return SubdomainSnapshot.objects.all().order_by('-discovered_at')
|
||||
return SubdomainSnapshot.objects.all().order_by('-created_at')
|
||||
|
||||
def iter_raw_data_for_export(
|
||||
self,
|
||||
scan_id: int,
|
||||
batch_size: int = 1000
|
||||
) -> Iterator[dict]:
|
||||
"""
|
||||
流式获取原始数据用于 CSV 导出
|
||||
|
||||
Args:
|
||||
scan_id: 扫描 ID
|
||||
batch_size: 每批数据量
|
||||
|
||||
Yields:
|
||||
{'name': 'sub.example.com', 'created_at': datetime}
|
||||
"""
|
||||
qs = (
|
||||
SubdomainSnapshot.objects
|
||||
.filter(scan_id=scan_id)
|
||||
.values('name', 'created_at')
|
||||
.order_by('name')
|
||||
)
|
||||
|
||||
for row in qs.iterator(chunk_size=batch_size):
|
||||
yield row
|
||||
|
||||
@@ -8,6 +8,7 @@ from django.db import transaction
|
||||
from apps.asset.models import VulnerabilitySnapshot
|
||||
from apps.asset.dtos.snapshot import VulnerabilitySnapshotDTO
|
||||
from apps.common.decorators import auto_ensure_db_connection
|
||||
from apps.common.utils import deduplicate_for_bulk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -21,12 +22,17 @@ class DjangoVulnerabilitySnapshotRepository:
|
||||
|
||||
使用 ``ignore_conflicts`` 策略,如果快照已存在则跳过。
|
||||
具体唯一约束由数据库模型控制。
|
||||
|
||||
注意:会自动按唯一约束字段去重,保留最后一条记录。
|
||||
"""
|
||||
if not items:
|
||||
logger.warning("漏洞快照列表为空,跳过保存")
|
||||
return
|
||||
|
||||
try:
|
||||
# 根据模型唯一约束自动去重
|
||||
unique_items = deduplicate_for_bulk(items, VulnerabilitySnapshot)
|
||||
|
||||
snapshot_objects = [
|
||||
VulnerabilitySnapshot(
|
||||
scan_id=item.scan_id,
|
||||
@@ -38,7 +44,7 @@ class DjangoVulnerabilitySnapshotRepository:
|
||||
description=item.description,
|
||||
raw_output=item.raw_output,
|
||||
)
|
||||
for item in items
|
||||
for item in unique_items
|
||||
]
|
||||
|
||||
with transaction.atomic():
|
||||
@@ -47,7 +53,7 @@ class DjangoVulnerabilitySnapshotRepository:
|
||||
ignore_conflicts=True,
|
||||
)
|
||||
|
||||
logger.debug("成功保存 %d 条漏洞快照记录", len(items))
|
||||
logger.debug("成功保存 %d 条漏洞快照记录", len(unique_items))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@@ -60,7 +66,7 @@ class DjangoVulnerabilitySnapshotRepository:
|
||||
|
||||
def get_by_scan(self, scan_id: int):
|
||||
"""按扫描任务获取漏洞快照 QuerySet。"""
|
||||
return VulnerabilitySnapshot.objects.filter(scan_id=scan_id).order_by("-discovered_at")
|
||||
return VulnerabilitySnapshot.objects.filter(scan_id=scan_id).order_by("-created_at")
|
||||
|
||||
def get_all(self):
|
||||
return VulnerabilitySnapshot.objects.all().order_by('-discovered_at')
|
||||
return VulnerabilitySnapshot.objects.all().order_by('-created_at')
|
||||
|
||||
@@ -1,11 +1,12 @@
"""WebsiteSnapshot Repository - Django ORM implementation"""

import logging
-from typing import List
+from typing import List, Iterator

from apps.asset.models.snapshot_models import WebsiteSnapshot
from apps.asset.dtos.snapshot import WebsiteSnapshotDTO
from apps.common.decorators import auto_ensure_db_connection
+from apps.common.utils import deduplicate_for_bulk

logger = logging.getLogger(__name__)

@@ -18,6 +19,8 @@ class DjangoWebsiteSnapshotRepository:
        """
        Save website snapshots.

+        Note: items are deduplicated automatically by (scan_id, url), keeping the last record.
+
        Args:
            items: list of website snapshot DTOs

@@ -31,23 +34,27 @@
        if not items:
            logger.debug("网站快照为空,跳过保存")
            return

+        # Deduplicate automatically based on the model's unique constraints
+        unique_items = deduplicate_for_bulk(items, WebsiteSnapshot)
+
        # Build the snapshot objects
        snapshots = []
-        for item in items:
+        for item in unique_items:
            snapshots.append(WebsiteSnapshot(
                scan_id=item.scan_id,
                url=item.url,
                host=item.host,
                title=item.title,
-                status=item.status,
+                status_code=item.status_code,
                content_length=item.content_length,
                location=item.location,
-                web_server=item.web_server,
+                webserver=item.webserver,
                content_type=item.content_type,
                tech=item.tech if item.tech else [],
-                body_preview=item.body_preview,
-                vhost=item.vhost
+                response_body=item.response_body,
+                vhost=item.vhost,
+                response_headers=item.response_headers if item.response_headers else ''
            ))

        # Bulk create (ignore conflicts; deduplication relies on the unique constraint)
@@ -68,7 +75,36 @@
            raise

    def get_by_scan(self, scan_id: int):
-        return WebsiteSnapshot.objects.filter(scan_id=scan_id).order_by('-discovered_at')
+        return WebsiteSnapshot.objects.filter(scan_id=scan_id).order_by('-created_at')

    def get_all(self):
-        return WebsiteSnapshot.objects.all().order_by('-discovered_at')
+        return WebsiteSnapshot.objects.all().order_by('-created_at')

    def iter_raw_data_for_export(
        self,
        scan_id: int,
        batch_size: int = 1000
    ) -> Iterator[dict]:
        """
        Stream raw data for CSV export.

        Args:
            scan_id: scan ID
            batch_size: number of rows per batch

        Yields:
            dict containing all website fields
        """
        qs = (
            WebsiteSnapshot.objects
            .filter(scan_id=scan_id)
            .values(
                'url', 'host', 'location', 'title', 'status_code',
                'content_length', 'content_type', 'webserver', 'tech',
                'response_body', 'response_headers', 'vhost', 'created_at'
            )
            .order_by('url')
        )

        for row in qs.iterator(chunk_size=batch_size):
            yield row
@@ -26,9 +26,9 @@ class SubdomainSerializer(serializers.ModelSerializer):
    class Meta:
        model = Subdomain
        fields = [
-            'id', 'name', 'discovered_at', 'target'
+            'id', 'name', 'created_at', 'target'
        ]
-        read_only_fields = ['id', 'discovered_at']
+        read_only_fields = ['id', 'created_at']


class SubdomainListSerializer(serializers.ModelSerializer):
@@ -41,9 +41,9 @@ class SubdomainListSerializer(serializers.ModelSerializer):
    class Meta:
        model = Subdomain
        fields = [
-            'id', 'name', 'discovered_at'
+            'id', 'name', 'created_at'
        ]
-        read_only_fields = ['id', 'discovered_at']
+        read_only_fields = ['id', 'created_at']


# class IPAddressListSerializer(serializers.ModelSerializer):
@@ -67,9 +67,10 @@ class SubdomainListSerializer(serializers.ModelSerializer):


class WebSiteSerializer(serializers.ModelSerializer):
-    """Site serializer"""
+    """Site serializer (target detail page)"""

    subdomain = serializers.CharField(source='subdomain.name', allow_blank=True, default='')
+    responseHeaders = serializers.CharField(source='response_headers', read_only=True)  # raw HTTP response headers

    class Meta:
        model = WebSite
@@ -83,11 +84,12 @@
            'content_type',
            'status_code',
            'content_length',
-            'body_preview',
+            'response_body',
            'tech',
            'vhost',
+            'responseHeaders',  # HTTP response headers
            'subdomain',
-            'discovered_at',
+            'created_at',
        ]
        read_only_fields = fields

@@ -107,7 +109,7 @@ class VulnerabilitySerializer(serializers.ModelSerializer):
            'cvss_score',
            'description',
            'raw_output',
-            'discovered_at',
+            'created_at',
        ]
        read_only_fields = fields

@@ -126,7 +128,7 @@ class VulnerabilitySnapshotSerializer(serializers.ModelSerializer):
            'cvss_score',
            'description',
            'raw_output',
-            'discovered_at',
+            'created_at',
        ]
        read_only_fields = fields

@@ -134,12 +136,13 @@
class EndpointListSerializer(serializers.ModelSerializer):
    """Endpoint list serializer (for the target endpoint list page)"""

-    # Map GF matched patterns to the tags field used by the frontend
-    tags = serializers.ListField(
+    # GF matched patterns (sensitive URL patterns matched by the gf-patterns tool)
+    gfPatterns = serializers.ListField(
        child=serializers.CharField(),
        source='matched_gf_patterns',
        read_only=True,
    )
+    responseHeaders = serializers.CharField(source='response_headers', read_only=True)  # raw HTTP response headers

    class Meta:
        model = Endpoint
@@ -152,11 +155,12 @@
            'content_length',
            'content_type',
            'webserver',
-            'body_preview',
+            'response_body',
            'tech',
            'vhost',
-            'tags',
-            'discovered_at',
+            'responseHeaders',  # HTTP response headers
+            'gfPatterns',
+            'created_at',
        ]
        read_only_fields = fields

@@ -164,7 +168,7 @@
class DirectorySerializer(serializers.ModelSerializer):
    """Directory serializer"""

-    discovered_at = serializers.DateTimeField(read_only=True)
+    created_at = serializers.DateTimeField(read_only=True)

    class Meta:
        model = Directory
@@ -177,7 +181,7 @@
            'lines',
            'content_type',
            'duration',
-            'discovered_at',
+            'created_at',
        ]
        read_only_fields = fields

@@ -190,12 +194,12 @@ class IPAddressAggregatedSerializer(serializers.Serializer):
    - ip: IP address
    - hosts: all hostnames associated with the IP
    - ports: all ports associated with the IP
-    - discovered_at: first discovery time
+    - created_at: creation time
    """
    ip = serializers.IPAddressField(read_only=True)
    hosts = serializers.ListField(child=serializers.CharField(), read_only=True)
    ports = serializers.ListField(child=serializers.IntegerField(), read_only=True)
-    discovered_at = serializers.DateTimeField(read_only=True)
+    created_at = serializers.DateTimeField(read_only=True)


# ==================== Snapshot serializers ====================
@@ -205,7 +209,7 @@ class SubdomainSnapshotSerializer(serializers.ModelSerializer):

    class Meta:
        model = SubdomainSnapshot
-        fields = ['id', 'name', 'discovered_at']
+        fields = ['id', 'name', 'created_at']
        read_only_fields = fields


@@ -213,8 +217,7 @@ class WebsiteSnapshotSerializer(serializers.ModelSerializer):
    """Website snapshot serializer (for scan history)"""

    subdomain_name = serializers.CharField(source='subdomain.name', read_only=True)
-    webserver = serializers.CharField(source='web_server', read_only=True)  # mapped field name
-    status_code = serializers.IntegerField(source='status', read_only=True)  # mapped field name
+    responseHeaders = serializers.CharField(source='response_headers', read_only=True)  # raw HTTP response headers

    class Meta:
        model = WebsiteSnapshot
@@ -223,15 +226,16 @@
            'url',
            'location',
            'title',
-            'webserver',  # mapped field name
+            'webserver',
            'content_type',
-            'status_code',  # mapped field name
+            'status_code',
            'content_length',
-            'body_preview',
+            'response_body',
            'tech',
            'vhost',
+            'responseHeaders',  # HTTP response headers
            'subdomain_name',
-            'discovered_at',
+            'created_at',
        ]
        read_only_fields = fields

@@ -250,7 +254,7 @@ class DirectorySnapshotSerializer(serializers.ModelSerializer):
            'lines',
            'content_type',
            'duration',
-            'discovered_at',
+            'created_at',
        ]
        read_only_fields = fields

@@ -258,12 +262,13 @@
class EndpointSnapshotSerializer(serializers.ModelSerializer):
    """Endpoint snapshot serializer (for scan history)"""

-    # Map GF matched patterns to the tags field used by the frontend
-    tags = serializers.ListField(
+    # GF matched patterns (sensitive URL patterns matched by the gf-patterns tool)
+    gfPatterns = serializers.ListField(
        child=serializers.CharField(),
        source='matched_gf_patterns',
        read_only=True,
    )
+    responseHeaders = serializers.CharField(source='response_headers', read_only=True)  # raw HTTP response headers

    class Meta:
        model = EndpointSnapshot
@@ -277,10 +282,11 @@
            'content_type',
            'status_code',
            'content_length',
-            'body_preview',
+            'response_body',
            'tech',
            'vhost',
-            'tags',
-            'discovered_at',
+            'responseHeaders',  # HTTP response headers
+            'gfPatterns',
+            'created_at',
        ]
        read_only_fields = fields
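(A quick illustration, not taken from the repository, of what the tags -> gfPatterns rename means on the wire: the serializer still reads the model field matched_gf_patterns but now emits a camelCase key, and response_headers is exposed as responseHeaders. Field values below are invented.)

# Hypothetical item from the endpoint list response after this change:
endpoint_item = {
    "url": "https://sub.example.com/login?next=/",
    "status_code": 200,
    "response_body": "<html>...</html>",
    "responseHeaders": "HTTP/1.1 200 OK\r\nServer: nginx\r\n...",
    "gfPatterns": ["redirect", "ssrf"],    # was "tags" before this change
    "created_at": "2025-01-01T00:00:00Z",  # was "discovered_at"
}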
@@ -1,10 +1,12 @@
"""Directory Service - directory business-logic layer"""

import logging
-from typing import List, Iterator
+from typing import List, Iterator, Optional

from apps.asset.repositories import DjangoDirectoryRepository
from apps.asset.dtos import DirectoryDTO
+from apps.common.validators import is_valid_url, is_url_match_target
+from apps.common.utils.filter_utils import apply_filters

logger = logging.getLogger(__name__)

@@ -12,6 +14,12 @@ logger = logging.getLogger(__name__)
class DirectoryService:
    """Directory business-logic layer"""

+    # Smart-filter field mapping
+    FILTER_FIELD_MAPPING = {
+        'url': 'url',
+        'status': 'status',
+    }
+
    def __init__(self, repository=None):
        """Initialize the directory service"""
        self.repo = repository or DjangoDirectoryRepository()
@@ -37,17 +45,91 @@
            logger.error(f"批量 upsert 目录失败: {e}")
            raise

-    def get_directories_by_target(self, target_id: int):
-        """Get all directories under a target"""
-        return self.repo.get_by_target(target_id)
+    def bulk_create_urls(self, target_id: int, target_name: str, target_type: str, urls: List[str]) -> int:
+        """
+        Bulk-create directories (URL only, using ignore_conflicts).
+
+        Validates URL format and target match, filters invalid or mismatched
+        URLs, deduplicates, then bulk-creates. Existing records are skipped.
+
+        Args:
+            target_id: target ID
+            target_name: target name (used for match validation)
+            target_type: target type ('domain', 'ip', 'cidr')
+            urls: list of URLs
+
+        Returns:
+            int: number of records actually created
+        """
+        if not urls:
+            return 0
+
+        # Keep valid URLs and deduplicate
+        valid_urls = []
+        seen = set()
+
+        for url in urls:
+            if not isinstance(url, str):
+                continue
+            url = url.strip()
+            if not url or url in seen:
+                continue
+            if not is_valid_url(url):
+                continue
+
+            # Match validation (the frontend already blocks mismatched submissions; the backend is a second safeguard)
+            if not is_url_match_target(url, target_name, target_type):
+                continue
+
+            seen.add(url)
+            valid_urls.append(url)
+
+        if not valid_urls:
+            return 0
+
+        # Count before creation
+        count_before = self.repo.count_by_target(target_id)
+
+        # Build the DTO list and bulk-create
+        directory_dtos = [
+            DirectoryDTO(url=url, target_id=target_id)
+            for url in valid_urls
+        ]
+        self.repo.bulk_create_ignore_conflicts(directory_dtos)
+
+        # Count after creation
+        count_after = self.repo.count_by_target(target_id)
+        return count_after - count_before

-    def get_all(self):
+    def get_directories_by_target(self, target_id: int, filter_query: Optional[str] = None):
+        """Get all directories under a target"""
+        queryset = self.repo.get_by_target(target_id)
+        if filter_query:
+            queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
+        return queryset
+
+    def get_all(self, filter_query: Optional[str] = None):
        """Get all directories"""
-        return self.repo.get_all()
+        queryset = self.repo.get_all()
+        if filter_query:
+            queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
+        return queryset

    def iter_directory_urls_by_target(self, target_id: int, chunk_size: int = 1000) -> Iterator[str]:
        """Stream all directory URLs under a target"""
        return self.repo.get_urls_for_export(target_id=target_id, batch_size=chunk_size)

    def iter_raw_data_for_csv_export(self, target_id: int) -> Iterator[dict]:
        """
        Stream raw data for CSV export.

        Args:
            target_id: target ID

        Yields:
            raw data dict
        """
        return self.repo.iter_raw_data_for_export(target_id=target_id)


__all__ = ['DirectoryService']
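(Illustrative call, not from the repository; target values are invented. It walks through the validation/dedup pipeline of bulk_create_urls and the meaning of its return value.)

service = DirectoryService()
created = service.bulk_create_urls(
    target_id=1,
    target_name='example.com',
    target_type='domain',
    urls=[
        'https://example.com/admin/',   # valid and matches the target
        'https://evil.com/admin/',      # dropped: does not match the target
        'not-a-url',                    # dropped: invalid URL
        'https://example.com/admin/',   # dropped: duplicate within the batch
    ],
)
# `created` is the count_after - count_before delta, so URLs already present
# in the table (skipped by ignore_conflicts) do not count as created.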
@@ -5,10 +5,12 @@ Endpoint service layer
"""

import logging
-from typing import List, Iterator
+from typing import List, Iterator, Optional

from apps.asset.dtos.asset import EndpointDTO
from apps.asset.repositories.asset import DjangoEndpointRepository
+from apps.common.validators import is_valid_url, is_url_match_target
+from apps.common.utils.filter_utils import apply_filters

logger = logging.getLogger(__name__)

@@ -20,6 +22,15 @@ class EndpointService:
    Business logic for Endpoints (URLs)
    """

+    # Smart-filter field mapping
+    FILTER_FIELD_MAPPING = {
+        'url': 'url',
+        'host': 'host',
+        'title': 'title',
+        'status_code': 'status_code',
+        'tech': 'tech',
+    }
+
    def __init__(self):
        """Initialize the Endpoint service"""
        self.repo = DjangoEndpointRepository()
@@ -45,9 +56,68 @@
            logger.error(f"批量 upsert 端点失败: {e}")
            raise

-    def get_endpoints_by_target(self, target_id: int):
+    def bulk_create_urls(self, target_id: int, target_name: str, target_type: str, urls: List[str]) -> int:
+        """
+        Bulk-create endpoints (URL only, using ignore_conflicts).
+
+        Validates URL format and target match, filters invalid or mismatched
+        URLs, deduplicates, then bulk-creates. Existing records are skipped.
+
+        Args:
+            target_id: target ID
+            target_name: target name (used for match validation)
+            target_type: target type ('domain', 'ip', 'cidr')
+            urls: list of URLs
+
+        Returns:
+            int: number of records actually created
+        """
+        if not urls:
+            return 0
+
+        # Keep valid URLs and deduplicate
+        valid_urls = []
+        seen = set()
+
+        for url in urls:
+            if not isinstance(url, str):
+                continue
+            url = url.strip()
+            if not url or url in seen:
+                continue
+            if not is_valid_url(url):
+                continue
+
+            # Match validation (the frontend already blocks mismatched submissions; the backend is a second safeguard)
+            if not is_url_match_target(url, target_name, target_type):
+                continue
+
+            seen.add(url)
+            valid_urls.append(url)
+
+        if not valid_urls:
+            return 0
+
+        # Count before creation
+        count_before = self.repo.count_by_target(target_id)
+
+        # Build the DTO list and bulk-create
+        endpoint_dtos = [
+            EndpointDTO(url=url, target_id=target_id)
+            for url in valid_urls
+        ]
+        self.repo.bulk_create_ignore_conflicts(endpoint_dtos)
+
+        # Count after creation
+        count_after = self.repo.count_by_target(target_id)
+        return count_after - count_before
+
+    def get_endpoints_by_target(self, target_id: int, filter_query: Optional[str] = None):
        """Get all endpoints under a target"""
-        return self.repo.get_by_target(target_id)
+        queryset = self.repo.get_by_target(target_id)
+        if filter_query:
+            queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING, json_array_fields=['tech'])
+        return queryset

    def count_endpoints_by_target(self, target_id: int) -> int:
        """
@@ -61,12 +131,27 @@
        """
        return self.repo.count_by_target(target_id)

-    def get_all(self):
+    def get_all(self, filter_query: Optional[str] = None):
        """Get all endpoints (global query)"""
-        return self.repo.get_all()
+        queryset = self.repo.get_all()
+        if filter_query:
+            queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING, json_array_fields=['tech'])
+        return queryset

    def iter_endpoint_urls_by_target(self, target_id: int, chunk_size: int = 1000) -> Iterator[str]:
        """Stream all endpoint URLs under a target, for export."""
        queryset = self.repo.get_by_target(target_id)
        for url in queryset.values_list('url', flat=True).iterator(chunk_size=chunk_size):
            yield url

    def iter_raw_data_for_csv_export(self, target_id: int) -> Iterator[dict]:
        """
        Stream raw data for CSV export.

        Args:
            target_id: target ID

        Yields:
            raw data dict
        """
        return self.repo.iter_raw_data_for_export(target_id=target_id)
@@ -1,16 +1,31 @@
"""HostPortMapping Service - business-logic layer"""

import logging
-from typing import List, Iterator
+from typing import List, Iterator, Optional, Dict

+from django.db.models import Min
+
from apps.asset.repositories.asset import DjangoHostPortMappingRepository
from apps.asset.dtos.asset import HostPortMappingDTO
+from apps.common.utils.filter_utils import apply_filters

logger = logging.getLogger(__name__)


class HostPortMappingService:
-    """Host-port mapping service - business logic for host-port mapping data"""
+    """Host-port mapping service - business logic for host-port mapping data
+
+    Responsibilities:
+    - business-logic processing (filtering, aggregation)
+    - delegating data access to the Repository
+    """

+    # Smart-filter field mapping
+    FILTER_FIELD_MAPPING = {
+        'ip': 'ip',
+        'port': 'port',
+        'host': 'host',
+    }
+
    def __init__(self):
        self.repo = DjangoHostPortMappingRepository()
@@ -49,13 +64,106 @@
    def iter_host_port_by_target(self, target_id: int, batch_size: int = 1000):
        return self.repo.get_for_export(target_id=target_id, batch_size=batch_size)

-    def get_ip_aggregation_by_target(self, target_id: int, search: str = None):
-        return self.repo.get_ip_aggregation_by_target(target_id, search=search)
+    def get_ip_aggregation_by_target(
+        self,
+        target_id: int,
+        filter_query: Optional[str] = None
+    ) -> List[Dict]:
+        """Get aggregated IP data for a target
+
+        Args:
+            target_id: target ID
+            filter_query: smart-filter expression string
+
+        Returns:
+            list of aggregated IP entries
+        """
+        # Get the base QuerySet from the Repository
+        qs = self.repo.get_queryset_by_target(target_id)
+
+        # The Service layer applies the filtering logic
+        if filter_query:
+            qs = apply_filters(qs, filter_query, self.FILTER_FIELD_MAPPING)
+
+        # The Service layer handles the aggregation logic
+        return self._aggregate_by_ip(qs, filter_query, target_id=target_id)

-    def get_all_ip_aggregation(self, search: str = None):
-        """Get all aggregated IP data (global query)"""
-        return self.repo.get_all_ip_aggregation(search=search)
+    def get_all_ip_aggregation(self, filter_query: Optional[str] = None) -> List[Dict]:
+        """Get all aggregated IP data (global query)
+
+        Args:
+            filter_query: smart-filter expression string
+
+        Returns:
+            list of aggregated IP entries
+        """
+        # Get the base QuerySet from the Repository
+        qs = self.repo.get_all_queryset()
+
+        # The Service layer applies the filtering logic
+        if filter_query:
+            qs = apply_filters(qs, filter_query, self.FILTER_FIELD_MAPPING)
+
+        # The Service layer handles the aggregation logic
+        return self._aggregate_by_ip(qs, filter_query)
+
+    def _aggregate_by_ip(
+        self,
+        qs,
+        filter_query: Optional[str] = None,
+        target_id: Optional[int] = None
+    ) -> List[Dict]:
+        """Aggregate the data by IP
+
+        Args:
+            qs: already-filtered QuerySet
+            filter_query: filter expression (also applied to the sub-queries)
+            target_id: target ID (scopes the sub-queries)
+
+        Returns:
+            aggregated rows
+        """
+        ip_aggregated = (
+            qs
+            .values('ip')
+            .annotate(created_at=Min('created_at'))
+            .order_by('-created_at')
+        )
+
+        results = []
+        for item in ip_aggregated:
+            ip = item['ip']
+
+            # Fetch every host and port of the IP (the filter must apply here too)
+            mappings_qs = self.repo.get_queryset_by_ip(ip, target_id=target_id)
+            if filter_query:
+                mappings_qs = apply_filters(mappings_qs, filter_query, self.FILTER_FIELD_MAPPING)
+
+            mappings = mappings_qs.values('host', 'port').distinct()
+            hosts = sorted({m['host'] for m in mappings})
+            ports = sorted({m['port'] for m in mappings})
+
+            results.append({
+                'ip': ip,
+                'hosts': hosts,
+                'ports': ports,
+                'created_at': item['created_at'],
+            })
+
+        return results

    def iter_ips_by_target(self, target_id: int, batch_size: int = 1000) -> Iterator[str]:
        """Stream all unique IP addresses under a target."""
        return self.repo.get_ips_for_export(target_id=target_id, batch_size=batch_size)

    def iter_raw_data_for_csv_export(self, target_id: int) -> Iterator[dict]:
        """
        Stream raw data for CSV export.

        Args:
            target_id: target ID

        Yields:
            raw data dict {ip, host, port, created_at}
        """
        return self.repo.iter_raw_data_for_export(target_id=target_id)
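(Shape of the rows _aggregate_by_ip() returns, with invented values. Worth noting as a design trade-off: the method runs one extra host/port sub-query per aggregated IP, i.e. N+1 queries for N IPs.)

from datetime import datetime

example_aggregation = [
    {
        'ip': '93.184.216.34',
        'hosts': ['example.com', 'www.example.com'],
        'ports': [80, 443],
        'created_at': datetime(2025, 1, 1, 0, 0),  # Min('created_at') across the IP's rows
    },
]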
@@ -1,15 +1,33 @@
import logging
-from typing import Tuple, List, Dict
+from typing import List, Dict, Optional
+from dataclasses import dataclass

from apps.asset.repositories import DjangoSubdomainRepository
from apps.asset.dtos import SubdomainDTO
+from apps.common.validators import is_valid_domain
+from apps.common.utils.filter_utils import apply_filters

logger = logging.getLogger(__name__)


+@dataclass
+class BulkCreateResult:
+    """Result of a bulk create"""
+    created_count: int
+    skipped_count: int
+    invalid_count: int
+    mismatched_count: int
+    total_received: int
+
+
class SubdomainService:
    """Subdomain business-logic layer"""

+    # Smart-filter field mapping
+    FILTER_FIELD_MAPPING = {
+        'name': 'name',
+    }
+
    def __init__(self, repository=None):
        """
        Initialize the subdomain service
@@ -21,30 +39,50 @@

    # ==================== Queries ====================

-    def get_all(self):
+    def get_all(self, filter_query: Optional[str] = None):
        """
        Get all subdomains

+        Args:
+            filter_query: smart-filter expression string
+
        Returns:
            QuerySet: subdomain queryset
        """
        logger.debug("获取所有子域名")
-        return self.repo.get_all()
+        queryset = self.repo.get_all()
+        if filter_query:
+            queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
+        return queryset

-    # ==================== Creation ====================
-
-    def bulk_create_ignore_conflicts(self, items: List[SubdomainDTO]) -> None:
-        """
-        Bulk-create subdomains, ignoring conflicts
-
-        Args:
-            items: list of subdomain DTOs
-
-        Note:
-            Uses the ignore_conflicts strategy; duplicate records are skipped
-        """
-        logger.debug("批量创建子域名 - 数量: %d", len(items))
-        return self.repo.bulk_create_ignore_conflicts(items)
+    def get_subdomains_by_target(self, target_id: int, filter_query: Optional[str] = None):
+        """
+        Get the subdomains under a target
+
+        Args:
+            target_id: target ID
+            filter_query: smart-filter expression string
+
+        Returns:
+            QuerySet: subdomain queryset
+        """
+        queryset = self.repo.get_by_target(target_id)
+        if filter_query:
+            queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
+        return queryset
+
+    def count_subdomains_by_target(self, target_id: int) -> int:
+        """
+        Count the subdomains under a target
+
+        Args:
+            target_id: target ID
+
+        Returns:
+            int: subdomain count
+        """
+        logger.debug("统计目标下子域名数量 - Target ID: %d", target_id)
+        return self.repo.count_by_target(target_id)

    def get_by_names_and_target_id(self, names: set, target_id: int) -> dict:
        """
@@ -71,25 +109,8 @@
            List[str]: subdomain names
        """
        logger.debug("获取目标下所有子域名 - Target ID: %d", target_id)
        # Database access goes through the repository layer, which already streams via iterator()
        return list(self.repo.get_domains_for_export(target_id=target_id))

-    def get_subdomains_by_target(self, target_id: int):
-        return self.repo.get_by_target(target_id)
-
-    def count_subdomains_by_target(self, target_id: int) -> int:
-        """
-        Count the subdomains under a target
-
-        Args:
-            target_id: target ID
-
-        Returns:
-            int: subdomain count
-        """
-        logger.debug("统计目标下子域名数量 - Target ID: %d", target_id)
-        return self.repo.count_by_target(target_id)
-
    def iter_subdomain_names_by_target(self, target_id: int, chunk_size: int = 1000):
        """
        Stream all subdomain names under a target (memory-friendly)
@@ -102,8 +123,123 @@
            str: subdomain name
        """
        logger.debug("流式获取目标下所有子域名 - Target ID: %d, 批次大小: %d", target_id, chunk_size)
        # Database access goes through the repository layer, which already streams via iterator()
        return self.repo.get_domains_for_export(target_id=target_id, batch_size=chunk_size)

    def iter_raw_data_for_csv_export(self, target_id: int):
        """
        Stream raw data for CSV export.

        Args:
            target_id: target ID

        Yields:
            raw data dict {name, created_at}
        """
        return self.repo.iter_raw_data_for_export(target_id=target_id)

-__all__ = ['SubdomainService']
+    # ==================== Creation ====================
+
+    def bulk_create_ignore_conflicts(self, items: List[SubdomainDTO]) -> None:
+        """
+        Bulk-create subdomains, ignoring conflicts
+
+        Args:
+            items: list of subdomain DTOs
+
+        Note:
+            Uses the ignore_conflicts strategy; duplicate records are skipped
+        """
+        logger.debug("批量创建子域名 - 数量: %d", len(items))
+        return self.repo.bulk_create_ignore_conflicts(items)
+
+    def bulk_create_subdomains(
+        self,
+        target_id: int,
+        target_name: str,
+        subdomains: List[str]
+    ) -> BulkCreateResult:
+        """
+        Bulk-create subdomains (with validation)
+
+        Args:
+            target_id: target ID
+            target_name: target domain (used for match validation)
+            subdomains: list of subdomain names
+
+        Returns:
+            BulkCreateResult: creation statistics
+        """
+        total_received = len(subdomains)
+        target_name = target_name.lower().strip()
+
+        def is_subdomain_match(subdomain: str) -> bool:
+            """Check whether the subdomain belongs to the target domain"""
+            if subdomain == target_name:
+                return True
+            if subdomain.endswith('.' + target_name):
+                return True
+            return False
+
+        # Keep only valid subdomains
+        valid_subdomains = []
+        invalid_count = 0
+        mismatched_count = 0
+
+        for subdomain in subdomains:
+            if not isinstance(subdomain, str) or not subdomain.strip():
+                continue
+
+            subdomain = subdomain.lower().strip()
+
+            # Validate format
+            if not is_valid_domain(subdomain):
+                invalid_count += 1
+                continue
+
+            # Validate target match
+            if not is_subdomain_match(subdomain):
+                mismatched_count += 1
+                continue
+
+            valid_subdomains.append(subdomain)
+
+        # Deduplicate
+        unique_subdomains = list(set(valid_subdomains))
+        duplicate_count = len(valid_subdomains) - len(unique_subdomains)
+
+        if not unique_subdomains:
+            return BulkCreateResult(
+                created_count=0,
+                skipped_count=duplicate_count,
+                invalid_count=invalid_count,
+                mismatched_count=mismatched_count,
+                total_received=total_received,
+            )
+
+        # Count before creation
+        count_before = self.repo.count_by_target(target_id)
+
+        # Build the DTO list and bulk-create
+        subdomain_dtos = [
+            SubdomainDTO(name=name, target_id=target_id)
+            for name in unique_subdomains
+        ]
+        self.repo.bulk_create_ignore_conflicts(subdomain_dtos)
+
+        # Count after creation
+        count_after = self.repo.count_by_target(target_id)
+        created_count = count_after - count_before
+
+        # Count rows skipped due to database conflicts
+        db_skipped = len(unique_subdomains) - created_count
+
+        return BulkCreateResult(
+            created_count=created_count,
+            skipped_count=duplicate_count + db_skipped,
+            invalid_count=invalid_count,
+            mismatched_count=mismatched_count,
+            total_received=total_received,
+        )
+
+
+__all__ = ['SubdomainService', 'BulkCreateResult']
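(Worked example of the BulkCreateResult accounting, with invented inputs. In-batch duplicates and rows already present in the database both end up in skipped_count; only the count_after - count_before delta is reported as created.)

result = SubdomainService().bulk_create_subdomains(
    target_id=1,
    target_name='example.com',
    subdomains=[
        'a.example.com',   # valid
        'A.example.com ',  # same name after lower().strip() -> in-batch duplicate
        'b.other.com',     # valid format but wrong domain -> mismatched_count
        'not a domain',    # fails is_valid_domain -> invalid_count
    ],
)
assert result.total_received == 4
assert result.invalid_count == 1
assert result.mismatched_count == 1
# result.created_count + result.skipped_count == 2 (the unique valid name plus
# the in-batch duplicate), split depending on what already existed in the table.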
@@ -1,11 +1,13 @@
"""Vulnerability Service - vulnerability asset business-logic layer"""

import logging
-from typing import List
+from typing import List, Optional

from apps.asset.models import Vulnerability
from apps.asset.dtos.asset import VulnerabilityDTO
from apps.common.decorators import auto_ensure_db_connection
+from apps.common.utils import deduplicate_for_bulk
+from apps.common.utils.filter_utils import apply_filters

logger = logging.getLogger(__name__)

@@ -16,10 +18,20 @@ class VulnerabilityService:

    Currently offers basic bulk creation, using ignore_conflicts and relying on database unique constraints for deduplication.
    """

+    # Smart-filter field mapping
+    FILTER_FIELD_MAPPING = {
+        'type': 'vuln_type',
+        'severity': 'severity',
+        'source': 'source',
+        'url': 'url',
+    }
+
    def bulk_create_ignore_conflicts(self, items: List[VulnerabilityDTO]) -> None:
        """Bulk-create vulnerability asset records, ignoring conflicts.

+        Note: items are deduplicated automatically by (target_id, url, vuln_type, source), keeping the last record.
+
        Note:
            - Whether deduplication happens depends on the unique/partial-unique constraints on the model;
            - the current Vulnerability model defines no unique constraint, so all records are kept.
@@ -29,6 +41,9 @@
            return

        try:
+            # Deduplicate by the model's unique constraints (a no-op when the model defines none)
+            unique_items = deduplicate_for_bulk(items, Vulnerability)
+
            vulns = [
                Vulnerability(
                    target_id=item.target_id,
@@ -40,7 +55,7 @@
                    description=item.description,
                    raw_output=item.raw_output,
                )
-                for item in items
+                for item in unique_items
            ]

            Vulnerability.objects.bulk_create(vulns, ignore_conflicts=True)
@@ -57,24 +72,34 @@

    # ==================== Queries ====================

-    def get_all(self):
-        """Get the QuerySet of all vulnerabilities (for the global list).
-
-        Returns:
-            QuerySet[Vulnerability]: all vulnerabilities, newest discovered first
-        """
-        return Vulnerability.objects.all().order_by("-discovered_at")
+    def get_all(self, filter_query: Optional[str] = None):
+        """Get the QuerySet of all vulnerabilities (for the global list).
+
+        Args:
+            filter_query: smart-filter expression string
+
+        Returns:
+            QuerySet[Vulnerability]: all vulnerabilities, newest created first
+        """
+        queryset = Vulnerability.objects.all().order_by("-created_at")
+        if filter_query:
+            queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
+        return queryset

-    def get_vulnerabilities_by_target(self, target_id: int):
+    def get_vulnerabilities_by_target(self, target_id: int, filter_query: Optional[str] = None):
        """Get the vulnerability QuerySet for a target (for pagination).

        Args:
            target_id: target ID
+            filter_query: smart-filter expression string

        Returns:
-            QuerySet[Vulnerability]: all vulnerabilities under the target, newest discovered first
+            QuerySet[Vulnerability]: all vulnerabilities under the target, newest created first
        """
-        return Vulnerability.objects.filter(target_id=target_id).order_by("-discovered_at")
+        queryset = Vulnerability.objects.filter(target_id=target_id).order_by("-created_at")
+        if filter_query:
+            queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
+        return queryset

    def count_by_target(self, target_id: int) -> int:
        """Count the vulnerabilities under a target."""
@@ -1,10 +1,12 @@
"""WebSite Service - website business-logic layer"""

import logging
-from typing import List
+from typing import List, Iterator, Optional

from apps.asset.repositories import DjangoWebSiteRepository
from apps.asset.dtos import WebSiteDTO
+from apps.common.validators import is_valid_url, is_url_match_target
+from apps.common.utils.filter_utils import apply_filters

logger = logging.getLogger(__name__)

@@ -12,6 +14,15 @@ logger = logging.getLogger(__name__)
class WebSiteService:
    """Website business-logic layer"""

+    # Smart-filter field mapping
+    FILTER_FIELD_MAPPING = {
+        'url': 'url',
+        'host': 'host',
+        'title': 'title',
+        'status_code': 'status_code',
+        'tech': 'tech',
+    }
+
    def __init__(self, repository=None):
        """Initialize the website service"""
        self.repo = repository or DjangoWebSiteRepository()
@@ -37,13 +48,75 @@
            logger.error(f"批量 upsert 网站失败: {e}")
            raise

-    def get_websites_by_target(self, target_id: int):
-        """Get all websites under a target"""
-        return self.repo.get_by_target(target_id)
+    def bulk_create_urls(self, target_id: int, target_name: str, target_type: str, urls: List[str]) -> int:
+        """
+        Bulk-create websites (URL only, using ignore_conflicts).
+
+        Validates URL format and target match, filters invalid or mismatched
+        URLs, deduplicates, then bulk-creates. Existing records are skipped.
+
+        Args:
+            target_id: target ID
+            target_name: target name (used for match validation)
+            target_type: target type ('domain', 'ip', 'cidr')
+            urls: list of URLs
+
+        Returns:
+            int: number of records actually created
+        """
+        if not urls:
+            return 0
+
+        # Keep valid URLs and deduplicate
+        valid_urls = []
+        seen = set()
+
+        for url in urls:
+            if not isinstance(url, str):
+                continue
+            url = url.strip()
+            if not url or url in seen:
+                continue
+            if not is_valid_url(url):
+                continue
+
+            # Match validation (the frontend already blocks mismatched submissions; the backend is a second safeguard)
+            if not is_url_match_target(url, target_name, target_type):
+                continue
+
+            seen.add(url)
+            valid_urls.append(url)
+
+        if not valid_urls:
+            return 0
+
+        # Count before creation
+        count_before = self.repo.count_by_target(target_id)
+
+        # Build the DTO list and bulk-create
+        website_dtos = [
+            WebSiteDTO(url=url, target_id=target_id)
+            for url in valid_urls
+        ]
+        self.repo.bulk_create_ignore_conflicts(website_dtos)
+
+        # Count after creation
+        count_after = self.repo.count_by_target(target_id)
+        return count_after - count_before

-    def get_all(self):
+    def get_websites_by_target(self, target_id: int, filter_query: Optional[str] = None):
+        """Get all websites under a target"""
+        queryset = self.repo.get_by_target(target_id)
+        if filter_query:
+            queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING, json_array_fields=['tech'])
+        return queryset
+
+    def get_all(self, filter_query: Optional[str] = None):
        """Get all websites"""
-        return self.repo.get_all()
+        queryset = self.repo.get_all()
+        if filter_query:
+            queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING, json_array_fields=['tech'])
+        return queryset

    def get_by_url(self, url: str, target_id: int) -> int:
        """Find a website ID by URL and target_id"""
@@ -53,5 +126,17 @@
        """Stream all site URLs under a target"""
        return self.repo.get_urls_for_export(target_id=target_id, batch_size=chunk_size)

    def iter_raw_data_for_csv_export(self, target_id: int) -> Iterator[dict]:
        """
        Stream raw data for CSV export.

        Args:
            target_id: target ID

        Yields:
            raw data dict
        """
        return self.repo.iter_raw_data_for_export(target_id=target_id)


__all__ = ['WebSiteService']
439
backend/apps/asset/services/search_service.py
Normal file
@@ -0,0 +1,439 @@
"""
Asset search service

Core business logic for asset search:
- queries data from materialized views
- supports an expression syntax
- supports the = (fuzzy), == (exact) and != (not equal) operators
- supports && (AND) and || (OR) logical combinations
- supports both Website and Endpoint asset types
"""

import logging
import re
from typing import Optional, List, Dict, Any, Tuple, Literal

from django.db import connection

logger = logging.getLogger(__name__)

# Supported field mapping (frontend field name -> database field name)
FIELD_MAPPING = {
    'host': 'host',
    'url': 'url',
    'title': 'title',
    'tech': 'tech',
    'status': 'status_code',
    'body': 'response_body',
    'header': 'response_headers',
}

# Array-typed fields
ARRAY_FIELDS = {'tech'}

# Asset type -> view name mapping
VIEW_MAPPING = {
    'website': 'asset_search_view',
    'endpoint': 'endpoint_search_view',
}

# Valid asset types
VALID_ASSET_TYPES = {'website', 'endpoint'}

# Website query fields
WEBSITE_SELECT_FIELDS = """
    id,
    url,
    host,
    title,
    tech,
    status_code,
    response_headers,
    response_body,
    content_type,
    content_length,
    webserver,
    location,
    vhost,
    created_at,
    target_id
"""

# Endpoint query fields (includes matched_gf_patterns)
ENDPOINT_SELECT_FIELDS = """
    id,
    url,
    host,
    title,
    tech,
    status_code,
    response_headers,
    response_body,
    content_type,
    content_length,
    webserver,
    location,
    vhost,
    matched_gf_patterns,
    created_at,
    target_id
"""


class SearchQueryParser:
    """
    Search query parser

    Supported syntax:
    - field="value"   fuzzy match (ILIKE %value%)
    - field=="value"  exact match
    - field!="value"  not equal
    - &&              AND
    - ||              OR
    - ()              grouping (nesting not supported yet)

    Examples:
    - host="api" && tech="nginx"
    - tech="vue" || tech="react"
    - status=="200" && host!="test"
    """

    # Matches a single condition: field="value", field=="value" or field!="value"
    CONDITION_PATTERN = re.compile(r'(\w+)\s*(==|!=|=)\s*"([^"]*)"')

    @classmethod
    def parse(cls, query: str) -> Tuple[str, List[Any]]:
        """
        Parse the query string into a SQL WHERE clause plus parameters.

        Args:
            query: search query string

        Returns:
            (where_clause, params) tuple
        """
        if not query or not query.strip():
            return "1=1", []

        query = query.strip()

        # If the query contains no operator syntax, treat it as a fuzzy host search
        if not cls.CONDITION_PATTERN.search(query):
            # Bare text defaults to a fuzzy host match
            return "host ILIKE %s", [f"%{query}%"]

        # Split into OR groups on ||
        or_groups = cls._split_by_or(query)

        if len(or_groups) == 1:
            # No OR; parse the AND conditions directly
            return cls._parse_and_group(or_groups[0])

        # Multiple OR groups
        or_clauses = []
        all_params = []

        for group in or_groups:
            clause, params = cls._parse_and_group(group)
            if clause and clause != "1=1":
                or_clauses.append(f"({clause})")
                all_params.extend(params)

        if not or_clauses:
            return "1=1", []

        return " OR ".join(or_clauses), all_params
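(Quick sanity check, not from the source, of what parse() yields for the grammar documented above.)

where, params = SearchQueryParser.parse('host="api" && tech="nginx"')
# where  -> 'host ILIKE %s AND EXISTS (SELECT 1 FROM unnest(tech) AS t WHERE t ILIKE %s)'
# params -> ['%api%', '%nginx%']

where, params = SearchQueryParser.parse('plain text')
# no operator syntax, so it falls back to a fuzzy host match:
# ('host ILIKE %s', ['%plain text%'])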
    @classmethod
    def _split_by_or(cls, query: str) -> List[str]:
        """Split the query on ||, ignoring || inside quotes"""
        parts = []
        current = ""
        in_quotes = False
        i = 0

        while i < len(query):
            char = query[i]

            if char == '"':
                in_quotes = not in_quotes
                current += char
            elif not in_quotes and i + 1 < len(query) and query[i:i+2] == '||':
                if current.strip():
                    parts.append(current.strip())
                current = ""
                i += 1  # skip the second |
            else:
                current += char

            i += 1

        if current.strip():
            parts.append(current.strip())

        return parts if parts else [query]

    @classmethod
    def _parse_and_group(cls, group: str) -> Tuple[str, List[Any]]:
        """Parse an AND group (conditions joined by &&)"""
        # Strip the outer parentheses
        group = group.strip()
        if group.startswith('(') and group.endswith(')'):
            group = group[1:-1].strip()

        # Split on &&
        parts = cls._split_by_and(group)

        and_clauses = []
        all_params = []

        for part in parts:
            clause, params = cls._parse_condition(part.strip())
            if clause:
                and_clauses.append(clause)
                all_params.extend(params)

        if not and_clauses:
            return "1=1", []

        return " AND ".join(and_clauses), all_params

    @classmethod
    def _split_by_and(cls, query: str) -> List[str]:
        """Split the query on &&, ignoring && inside quotes"""
        parts = []
        current = ""
        in_quotes = False
        i = 0

        while i < len(query):
            char = query[i]

            if char == '"':
                in_quotes = not in_quotes
                current += char
            elif not in_quotes and i + 1 < len(query) and query[i:i+2] == '&&':
                if current.strip():
                    parts.append(current.strip())
                current = ""
                i += 1  # skip the second &
            else:
                current += char

            i += 1

        if current.strip():
            parts.append(current.strip())

        return parts if parts else [query]

    @classmethod
    def _parse_condition(cls, condition: str) -> Tuple[Optional[str], List[Any]]:
        """
        Parse a single condition.

        Returns:
            (sql_clause, params), or (None, []) when parsing fails
        """
        # Strip parentheses
        condition = condition.strip()
        if condition.startswith('(') and condition.endswith(')'):
            condition = condition[1:-1].strip()

        match = cls.CONDITION_PATTERN.match(condition)
        if not match:
            logger.warning(f"无法解析条件: {condition}")
            return None, []

        field, operator, value = match.groups()
        field = field.lower()

        # Validate the field
        if field not in FIELD_MAPPING:
            logger.warning(f"未知字段: {field}")
            return None, []

        db_field = FIELD_MAPPING[field]
        is_array = field in ARRAY_FIELDS

        # Generate SQL according to the operator
        if operator == '=':
            # Fuzzy match
            return cls._build_like_condition(db_field, value, is_array)
        elif operator == '==':
            # Exact match
            return cls._build_exact_condition(db_field, value, is_array)
        elif operator == '!=':
            # Not equal
            return cls._build_not_equal_condition(db_field, value, is_array)

        return None, []

    @classmethod
    def _build_like_condition(cls, field: str, value: str, is_array: bool) -> Tuple[str, List[Any]]:
        """Build a fuzzy-match condition"""
        if is_array:
            # Array field: check whether any element contains the value
            return f"EXISTS (SELECT 1 FROM unnest({field}) AS t WHERE t ILIKE %s)", [f"%{value}%"]
        elif field == 'status_code':
            # Status code is an integer; a fuzzy match becomes an exact match
            try:
                return f"{field} = %s", [int(value)]
            except ValueError:
                return f"{field}::text ILIKE %s", [f"%{value}%"]
        else:
            return f"{field} ILIKE %s", [f"%{value}%"]

    @classmethod
    def _build_exact_condition(cls, field: str, value: str, is_array: bool) -> Tuple[str, List[Any]]:
        """Build an exact-match condition"""
        if is_array:
            # Array field: check whether the array contains the exact value
            return f"%s = ANY({field})", [value]
        elif field == 'status_code':
            # Status code is an integer
            try:
                return f"{field} = %s", [int(value)]
            except ValueError:
                return f"{field}::text = %s", [value]
        else:
            return f"{field} = %s", [value]

    @classmethod
    def _build_not_equal_condition(cls, field: str, value: str, is_array: bool) -> Tuple[str, List[Any]]:
        """Build a not-equal condition"""
        if is_array:
            # Array field: check that the array does not contain the value
            return f"NOT (%s = ANY({field}))", [value]
        elif field == 'status_code':
            try:
                return f"({field} IS NULL OR {field} != %s)", [int(value)]
            except ValueError:
                return f"({field} IS NULL OR {field}::text != %s)", [value]
        else:
            return f"({field} IS NULL OR {field} != %s)", [value]


AssetType = Literal['website', 'endpoint']


class AssetSearchService:
    """Asset search service"""

    def search(
        self,
        query: str,
        asset_type: AssetType = 'website',
        limit: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Search assets.

        Args:
            query: search query string
            asset_type: asset type ('website' or 'endpoint')
            limit: maximum number of rows (optional)

        Returns:
            List[Dict]: search results
        """
        where_clause, params = SearchQueryParser.parse(query)

        # Pick the view and fields for the asset type
        view_name = VIEW_MAPPING.get(asset_type, 'asset_search_view')
        select_fields = ENDPOINT_SELECT_FIELDS if asset_type == 'endpoint' else WEBSITE_SELECT_FIELDS

        sql = f"""
            SELECT {select_fields}
            FROM {view_name}
            WHERE {where_clause}
            ORDER BY created_at DESC
        """

        # Append the LIMIT
        if limit is not None and limit > 0:
            sql += f" LIMIT {int(limit)}"

        try:
            with connection.cursor() as cursor:
                cursor.execute(sql, params)
                columns = [col[0] for col in cursor.description]
                results = []

                for row in cursor.fetchall():
                    result = dict(zip(columns, row))
                    results.append(result)

                return results
        except Exception as e:
            logger.error(f"搜索查询失败: {e}, SQL: {sql}, params: {params}")
            raise

    def count(self, query: str, asset_type: AssetType = 'website') -> int:
        """
        Count the search results.

        Args:
            query: search query string
            asset_type: asset type ('website' or 'endpoint')

        Returns:
            int: total number of matches
        """
        where_clause, params = SearchQueryParser.parse(query)

        # Pick the view for the asset type
        view_name = VIEW_MAPPING.get(asset_type, 'asset_search_view')

        sql = f"SELECT COUNT(*) FROM {view_name} WHERE {where_clause}"

        try:
            with connection.cursor() as cursor:
                cursor.execute(sql, params)
                return cursor.fetchone()[0]
        except Exception as e:
            logger.error(f"统计查询失败: {e}")
            raise

    def search_iter(
        self,
        query: str,
        asset_type: AssetType = 'website',
        batch_size: int = 1000
    ):
        """
        Stream search results (server-side cursor, memory-friendly).

        Args:
            query: search query string
            asset_type: asset type ('website' or 'endpoint')
            batch_size: rows fetched per batch

        Yields:
            Dict: a single search result
        """
        where_clause, params = SearchQueryParser.parse(query)

        # Pick the view and fields for the asset type
        view_name = VIEW_MAPPING.get(asset_type, 'asset_search_view')
        select_fields = ENDPOINT_SELECT_FIELDS if asset_type == 'endpoint' else WEBSITE_SELECT_FIELDS

        sql = f"""
            SELECT {select_fields}
            FROM {view_name}
            WHERE {where_clause}
            ORDER BY created_at DESC
        """

        try:
            # Use a server-side cursor to avoid loading everything into memory
            with connection.cursor(name='export_cursor') as cursor:
                cursor.itersize = batch_size
                cursor.execute(sql, params)
                columns = [col[0] for col in cursor.description]

                for row in cursor:
                    yield dict(zip(columns, row))
        except Exception as e:
            logger.error(f"流式搜索查询失败: {e}, SQL: {sql}, params: {params}")
            raise
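(Illustrative use of the service above; query strings and the limit are invented. Note that search_iter relies on a named, server-side cursor, which assumes the PostgreSQL backend.)

service = AssetSearchService()
total = service.count('tech="nginx" || tech="openresty"', asset_type='website')
rows = service.search('tech="nginx" || tech="openresty"', asset_type='website', limit=100)

# For exports, the streaming variant keeps memory flat:
for row in service.search_iter('status=="200"', asset_type='endpoint', batch_size=500):
    pass  # e.g. write the row into a CSV response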
@@ -50,7 +50,7 @@ class DirectorySnapshotsService:

        # Step 2: convert to asset DTOs and save to the asset table (upsert)
        # - new records: inserted into the asset table
-        # - existing records: fields updated (discovered_at untouched, keeping the first-discovery time)
+        # - existing records: fields updated (created_at untouched, keeping the creation time)
        logger.debug("步骤 2: 同步到资产表(通过 Service 层,upsert)")
        asset_items = [item.to_asset_dto() for item in items]

@@ -67,15 +67,44 @@
            )
            raise

-    def get_by_scan(self, scan_id: int):
-        return self.snapshot_repo.get_by_scan(scan_id)
+    # Smart-filter field mapping
+    FILTER_FIELD_MAPPING = {
+        'url': 'url',
+        'status': 'status',
+        'content_type': 'content_type',
+    }
+
+    def get_by_scan(self, scan_id: int, filter_query: str = None):
+        from apps.common.utils.filter_utils import apply_filters
+
+        queryset = self.snapshot_repo.get_by_scan(scan_id)
+        if filter_query:
+            queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
+        return queryset

-    def get_all(self):
+    def get_all(self, filter_query: str = None):
        """Get all directory snapshots"""
-        return self.snapshot_repo.get_all()
+        from apps.common.utils.filter_utils import apply_filters
+
+        queryset = self.snapshot_repo.get_all()
+        if filter_query:
+            queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
+        return queryset

    def iter_directory_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
        """Stream all directory URLs for a scan."""
        queryset = self.snapshot_repo.get_by_scan(scan_id)
        for snapshot in queryset.iterator(chunk_size=chunk_size):
            yield snapshot.url

    def iter_raw_data_for_csv_export(self, scan_id: int) -> Iterator[dict]:
        """
        Stream raw data for CSV export.

        Args:
            scan_id: scan ID

        Yields:
            raw data dict
        """
        return self.snapshot_repo.iter_raw_data_for_export(scan_id=scan_id)
@@ -67,15 +67,47 @@ class EndpointSnapshotsService:
            )
            raise

-    def get_by_scan(self, scan_id: int):
-        return self.snapshot_repo.get_by_scan(scan_id)
+    # Smart-filter field mapping
+    FILTER_FIELD_MAPPING = {
+        'url': 'url',
+        'host': 'host',
+        'title': 'title',
+        'status_code': 'status_code',
+        'webserver': 'webserver',
+        'tech': 'tech',
+    }
+
+    def get_by_scan(self, scan_id: int, filter_query: str = None):
+        from apps.common.utils.filter_utils import apply_filters
+
+        queryset = self.snapshot_repo.get_by_scan(scan_id)
+        if filter_query:
+            queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
+        return queryset

-    def get_all(self):
+    def get_all(self, filter_query: str = None):
        """Get all endpoint snapshots"""
-        return self.snapshot_repo.get_all()
+        from apps.common.utils.filter_utils import apply_filters
+
+        queryset = self.snapshot_repo.get_all()
+        if filter_query:
+            queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
+        return queryset

    def iter_endpoint_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
        """Stream all endpoint URLs for a scan."""
        queryset = self.snapshot_repo.get_by_scan(scan_id)
        for snapshot in queryset.iterator(chunk_size=chunk_size):
            yield snapshot.url

    def iter_raw_data_for_csv_export(self, scan_id: int) -> Iterator[dict]:
        """
        Stream raw data for CSV export.

        Args:
            scan_id: scan ID

        Yields:
            raw data dict
        """
        return self.snapshot_repo.iter_raw_data_for_export(scan_id=scan_id)
@@ -69,13 +69,25 @@ class HostPortMappingSnapshotsService:
            )
            raise

-    def get_ip_aggregation_by_scan(self, scan_id: int, search: str = None):
-        return self.snapshot_repo.get_ip_aggregation_by_scan(scan_id, search=search)
+    def get_ip_aggregation_by_scan(self, scan_id: int, filter_query: str = None):
+        return self.snapshot_repo.get_ip_aggregation_by_scan(scan_id, filter_query=filter_query)

-    def get_all_ip_aggregation(self, search: str = None):
+    def get_all_ip_aggregation(self, filter_query: str = None):
        """Get all aggregated IP data"""
-        return self.snapshot_repo.get_all_ip_aggregation(search=search)
+        return self.snapshot_repo.get_all_ip_aggregation(filter_query=filter_query)

    def iter_ips_by_scan(self, scan_id: int, batch_size: int = 1000) -> Iterator[str]:
        """Stream all unique IP addresses for a scan."""
        return self.snapshot_repo.get_ips_for_export(scan_id=scan_id, batch_size=batch_size)

    def iter_raw_data_for_csv_export(self, scan_id: int) -> Iterator[dict]:
        """
        Stream raw data for CSV export.

        Args:
            scan_id: scan ID

        Yields:
            raw data dict {ip, host, port, created_at}
        """
        return self.snapshot_repo.iter_raw_data_for_export(scan_id=scan_id)
@@ -66,14 +66,41 @@ class SubdomainSnapshotsService:
            )
            raise

-    def get_by_scan(self, scan_id: int):
-        return self.subdomain_snapshot_repo.get_by_scan(scan_id)
+    # Smart-filter field mapping
+    FILTER_FIELD_MAPPING = {
+        'name': 'name',
+    }
+
+    def get_by_scan(self, scan_id: int, filter_query: str = None):
+        from apps.common.utils.filter_utils import apply_filters
+
+        queryset = self.subdomain_snapshot_repo.get_by_scan(scan_id)
+        if filter_query:
+            queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
+        return queryset

-    def get_all(self):
+    def get_all(self, filter_query: str = None):
        """Get all subdomain snapshots"""
-        return self.subdomain_snapshot_repo.get_all()
+        from apps.common.utils.filter_utils import apply_filters
+
+        queryset = self.subdomain_snapshot_repo.get_all()
+        if filter_query:
+            queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
+        return queryset

    def iter_subdomain_names_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
        queryset = self.subdomain_snapshot_repo.get_by_scan(scan_id)
        for snapshot in queryset.iterator(chunk_size=chunk_size):
            yield snapshot.name

    def iter_raw_data_for_csv_export(self, scan_id: int) -> Iterator[dict]:
        """
        Stream raw data for CSV export.

        Args:
            scan_id: scan ID

        Yields:
            raw data dict {name, created_at}
        """
        return self.subdomain_snapshot_repo.iter_raw_data_for_export(scan_id=scan_id)
@@ -66,13 +66,31 @@ class VulnerabilitySnapshotsService:
            )
            raise

-    def get_by_scan(self, scan_id: int):
-        """Get all vulnerability snapshots for a scan."""
-        return self.snapshot_repo.get_by_scan(scan_id)
+    # Smart-filter field mapping
+    FILTER_FIELD_MAPPING = {
+        'type': 'vuln_type',
+        'url': 'url',
+        'severity': 'severity',
+        'source': 'source',
+    }

-    def get_all(self):
+    def get_by_scan(self, scan_id: int, filter_query: str = None):
+        """Get all vulnerability snapshots for a scan."""
+        from apps.common.utils.filter_utils import apply_filters
+
+        queryset = self.snapshot_repo.get_by_scan(scan_id)
+        if filter_query:
+            queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
+        return queryset
+
+    def get_all(self, filter_query: str = None):
        """Get all vulnerability snapshots"""
-        return self.snapshot_repo.get_all()
+        from apps.common.utils.filter_utils import apply_filters
+
+        queryset = self.snapshot_repo.get_all()
+        if filter_query:
+            queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
+        return queryset

    def iter_vuln_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
        """Stream all vulnerability URLs for a scan."""
@@ -51,7 +51,7 @@ class WebsiteSnapshotsService:

        # Step 2: convert to asset DTOs and save to the asset table (upsert)
        # - new records: inserted into the asset table
-        # - existing records: fields updated (discovered_at untouched, keeping the first-discovery time)
+        # - existing records: fields updated (created_at untouched, keeping the creation time)
        logger.debug("步骤 2: 同步到资产表(通过 Service 层,upsert)")
        asset_items = [item.to_asset_dto() for item in items]

@@ -68,15 +68,47 @@
            )
            raise

-    def get_by_scan(self, scan_id: int):
-        return self.snapshot_repo.get_by_scan(scan_id)
+    # Smart-filter field mapping
+    FILTER_FIELD_MAPPING = {
+        'url': 'url',
+        'host': 'host',
+        'title': 'title',
+        'status_code': 'status_code',
+        'webserver': 'webserver',
+        'tech': 'tech',
+    }
+
+    def get_by_scan(self, scan_id: int, filter_query: str = None):
+        from apps.common.utils.filter_utils import apply_filters
+
+        queryset = self.snapshot_repo.get_by_scan(scan_id)
+        if filter_query:
+            queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
+        return queryset

-    def get_all(self):
+    def get_all(self, filter_query: str = None):
        """Get all website snapshots"""
-        return self.snapshot_repo.get_all()
+        from apps.common.utils.filter_utils import apply_filters
+
+        queryset = self.snapshot_repo.get_all()
+        if filter_query:
+            queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
+        return queryset

    def iter_website_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
-        """Stream all site URLs for a scan (newest discovered first)."""
+        """Stream all site URLs for a scan (newest created first)."""
        queryset = self.snapshot_repo.get_by_scan(scan_id)
        for snapshot in queryset.iterator(chunk_size=chunk_size):
            yield snapshot.url

    def iter_raw_data_for_csv_export(self, scan_id: int) -> Iterator[dict]:
        """
        Stream raw data for CSV export.

        Args:
            scan_id: scan ID

        Yields:
            raw data dict
        """
        return self.snapshot_repo.iter_raw_data_for_export(scan_id=scan_id)
7
backend/apps/asset/tasks/__init__.py
Normal file
@@ -0,0 +1,7 @@
"""
Task module for the Asset app.

Note: materialized-view refreshes have moved to APScheduler scheduled jobs (apps.engine.scheduler).
"""

__all__ = []
@@ -10,6 +10,8 @@ from .views import (
     DirectoryViewSet,
     VulnerabilityViewSet,
     AssetStatisticsViewSet,
+    AssetSearchView,
+    AssetSearchExportView,
 )
 
 # Create the DRF router
@@ -25,4 +27,6 @@ router.register(r'statistics', AssetStatisticsViewSet, basename='asset-statistic
 
 urlpatterns = [
     path('assets/', include(router.urls)),
+    path('assets/search/', AssetSearchView.as_view(), name='asset-search'),
+    path('assets/search/export/', AssetSearchExportView.as_view(), name='asset-search-export'),
 ]

@@ -1,562 +0,0 @@
import logging
from rest_framework import viewsets, status, filters
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.request import Request
from rest_framework.exceptions import NotFound, ValidationError as DRFValidationError
from django.core.exceptions import ValidationError, ObjectDoesNotExist
from django.db import DatabaseError, IntegrityError, OperationalError
from django.http import StreamingHttpResponse

from .serializers import (
    SubdomainListSerializer, WebSiteSerializer, DirectorySerializer,
    VulnerabilitySerializer, EndpointListSerializer, IPAddressAggregatedSerializer,
    SubdomainSnapshotSerializer, WebsiteSnapshotSerializer, DirectorySnapshotSerializer,
    EndpointSnapshotSerializer, VulnerabilitySnapshotSerializer
)
from .services import (
    SubdomainService, WebSiteService, DirectoryService,
    VulnerabilityService, AssetStatisticsService, EndpointService, HostPortMappingService
)
from .services.snapshot import (
    SubdomainSnapshotsService, WebsiteSnapshotsService, DirectorySnapshotsService,
    EndpointSnapshotsService, HostPortMappingSnapshotsService, VulnerabilitySnapshotsService
)
from apps.common.pagination import BasePagination

logger = logging.getLogger(__name__)


class AssetStatisticsViewSet(viewsets.ViewSet):
    """
    Asset statistics API.

    Provides the pre-aggregated statistics the dashboard needs (read from a cache table).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = AssetStatisticsService()

    def list(self, request):
        """
        Get asset statistics.

        GET /assets/statistics/

        Returns:
        - totalTargets: total number of targets
        - totalSubdomains: total number of subdomains
        - totalIps: total number of IPs
        - totalEndpoints: total number of endpoints
        - totalWebsites: total number of websites
        - totalVulns: total number of vulnerabilities
        - totalAssets: total number of assets
        - runningScans: number of running scans
        - updatedAt: when the statistics were last updated
        """
        try:
            stats = self.service.get_statistics()
            return Response({
                'totalTargets': stats['total_targets'],
                'totalSubdomains': stats['total_subdomains'],
                'totalIps': stats['total_ips'],
                'totalEndpoints': stats['total_endpoints'],
                'totalWebsites': stats['total_websites'],
                'totalVulns': stats['total_vulns'],
                'totalAssets': stats['total_assets'],
                'runningScans': stats['running_scans'],
                'updatedAt': stats['updated_at'],
                # Deltas
                'changeTargets': stats['change_targets'],
                'changeSubdomains': stats['change_subdomains'],
                'changeIps': stats['change_ips'],
                'changeEndpoints': stats['change_endpoints'],
                'changeWebsites': stats['change_websites'],
                'changeVulns': stats['change_vulns'],
                'changeAssets': stats['change_assets'],
                # Vulnerability severity distribution
                'vulnBySeverity': stats['vuln_by_severity'],
            })
        except (DatabaseError, OperationalError) as e:
            logger.exception("Failed to fetch asset statistics")
            return Response(
                {'error': 'Failed to fetch statistics'},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR
            )

    @action(detail=False, methods=['get'], url_path='history')
    def history(self, request: Request):
        """
        Get historical statistics (for the line chart).

        GET /assets/statistics/history/?days=7

        Query Parameters:
            days: how many recent days to fetch; default 7, max 90

        Returns:
            list of history entries
        """
        try:
            days_param = request.query_params.get('days', '7')
            try:
                days = int(days_param)
            except (ValueError, TypeError):
                days = 7
            days = min(max(days, 1), 90)  # clamp to 1-90 days

            history = self.service.get_statistics_history(days=days)
            return Response(history)
        except (DatabaseError, OperationalError) as e:
            logger.exception("Failed to fetch statistics history")
            return Response(
                {'error': 'Failed to fetch history data'},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR
            )


# Note: the IPAddress model has been refactored into HostPortMapping.
# IPAddressViewSet has been removed and needs to be reimplemented for the new architecture.

class SubdomainViewSet(viewsets.ModelViewSet):
    """Subdomain management ViewSet.

    Two access patterns are supported:
    1. Nested route: GET /api/targets/{target_pk}/subdomains/
    2. Standalone route: GET /api/subdomains/ (global query)
    """

    serializer_class = SubdomainListSerializer
    pagination_class = BasePagination
    filter_backends = [filters.SearchFilter, filters.OrderingFilter]
    search_fields = ['name']
    ordering = ['-discovered_at']

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = SubdomainService()

    def get_queryset(self):
        """Scope the query depending on whether target_pk is present."""
        target_pk = self.kwargs.get('target_pk')
        if target_pk:
            return self.service.get_subdomains_by_target(target_pk)
        return self.service.get_all()

    @action(detail=False, methods=['get'], url_path='export')
    def export(self, request):
        """Export subdomains (plain text, one per line)."""
        target_pk = self.kwargs.get('target_pk')
        if not target_pk:
            raise DRFValidationError('Export must be scoped to a target')

        def line_iterator():
            for name in self.service.iter_subdomain_names_by_target(target_pk):
                yield f"{name}\n"

        response = StreamingHttpResponse(
            line_iterator(),
            content_type='text/plain; charset=utf-8',
        )
        response['Content-Disposition'] = f'attachment; filename="target-{target_pk}-subdomains.txt"'
        return response

class WebSiteViewSet(viewsets.ModelViewSet):
    """Website management ViewSet.

    Two access patterns are supported:
    1. Nested route: GET /api/targets/{target_pk}/websites/
    2. Standalone route: GET /api/websites/ (global query)
    """

    serializer_class = WebSiteSerializer
    pagination_class = BasePagination
    filter_backends = [filters.SearchFilter, filters.OrderingFilter]
    search_fields = ['host']
    ordering = ['-discovered_at']

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = WebSiteService()

    def get_queryset(self):
        """Scope the query depending on whether target_pk is present."""
        target_pk = self.kwargs.get('target_pk')
        if target_pk:
            return self.service.get_websites_by_target(target_pk)
        return self.service.get_all()

    @action(detail=False, methods=['get'], url_path='export')
    def export(self, request):
        """Export website URLs (plain text, one per line)."""
        target_pk = self.kwargs.get('target_pk')
        if not target_pk:
            raise DRFValidationError('Export must be scoped to a target')

        def line_iterator():
            for url in self.service.iter_website_urls_by_target(target_pk):
                yield f"{url}\n"

        response = StreamingHttpResponse(
            line_iterator(),
            content_type='text/plain; charset=utf-8',
        )
        response['Content-Disposition'] = f'attachment; filename="target-{target_pk}-websites.txt"'
        return response

class DirectoryViewSet(viewsets.ModelViewSet):
    """Directory management ViewSet.

    Two access patterns are supported:
    1. Nested route: GET /api/targets/{target_pk}/directories/
    2. Standalone route: GET /api/directories/ (global query)
    """

    serializer_class = DirectorySerializer
    pagination_class = BasePagination
    filter_backends = [filters.SearchFilter, filters.OrderingFilter]
    search_fields = ['url']
    ordering = ['-discovered_at']

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = DirectoryService()

    def get_queryset(self):
        """Scope the query depending on whether target_pk is present."""
        target_pk = self.kwargs.get('target_pk')
        if target_pk:
            return self.service.get_directories_by_target(target_pk)
        return self.service.get_all()

    @action(detail=False, methods=['get'], url_path='export')
    def export(self, request):
        """Export directory URLs (plain text, one per line)."""
        target_pk = self.kwargs.get('target_pk')
        if not target_pk:
            raise DRFValidationError('Export must be scoped to a target')

        def line_iterator():
            for url in self.service.iter_directory_urls_by_target(target_pk):
                yield f"{url}\n"

        response = StreamingHttpResponse(
            line_iterator(),
            content_type='text/plain; charset=utf-8',
        )
        response['Content-Disposition'] = f'attachment; filename="target-{target_pk}-directories.txt"'
        return response

class EndpointViewSet(viewsets.ModelViewSet):
    """Endpoint management ViewSet.

    Two access patterns are supported:
    1. Nested route: GET /api/targets/{target_pk}/endpoints/
    2. Standalone route: GET /api/endpoints/ (global query)
    """

    serializer_class = EndpointListSerializer
    pagination_class = BasePagination
    filter_backends = [filters.SearchFilter, filters.OrderingFilter]
    search_fields = ['host']
    ordering = ['-discovered_at']

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = EndpointService()

    def get_queryset(self):
        """Scope the query depending on whether target_pk is present."""
        target_pk = self.kwargs.get('target_pk')
        if target_pk:
            return self.service.get_endpoints_by_target(target_pk)
        return self.service.get_all()

    @action(detail=False, methods=['get'], url_path='export')
    def export(self, request):
        """Export endpoint URLs (plain text, one per line)."""
        target_pk = self.kwargs.get('target_pk')
        if not target_pk:
            raise DRFValidationError('Export must be scoped to a target')

        def line_iterator():
            for url in self.service.iter_endpoint_urls_by_target(target_pk):
                yield f"{url}\n"

        response = StreamingHttpResponse(
            line_iterator(),
            content_type='text/plain; charset=utf-8',
        )
        response['Content-Disposition'] = f'attachment; filename="target-{target_pk}-endpoints.txt"'
        return response

class HostPortMappingViewSet(viewsets.ModelViewSet):
    """Host-port mapping management ViewSet (IP-address aggregated view).

    Two access patterns are supported:
    1. Nested route: GET /api/targets/{target_pk}/ip-addresses/
    2. Standalone route: GET /api/ip-addresses/ (global query)

    Returns data aggregated per IP; each IP lists all of its associated hosts and ports.

    Note: because the result is aggregated data (a list of dicts), DRF's SearchFilter is not supported.
    """

    serializer_class = IPAddressAggregatedSerializer
    pagination_class = BasePagination

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = HostPortMappingService()

    def get_queryset(self):
        """Scope the query by target_pk if present; returns IP-aggregated data."""
        target_pk = self.kwargs.get('target_pk')
        search = self.request.query_params.get('search', None)
        if target_pk:
            return self.service.get_ip_aggregation_by_target(target_pk, search=search)
        return self.service.get_all_ip_aggregation(search=search)

    @action(detail=False, methods=['get'], url_path='export')
    def export(self, request):
        """Export IP addresses (plain text, one per line)."""
        target_pk = self.kwargs.get('target_pk')
        if not target_pk:
            raise DRFValidationError('Export must be scoped to a target')

        def line_iterator():
            for ip in self.service.iter_ips_by_target(target_pk):
                yield f"{ip}\n"

        response = StreamingHttpResponse(
            line_iterator(),
            content_type='text/plain; charset=utf-8',
        )
        response['Content-Disposition'] = f'attachment; filename="target-{target_pk}-ip-addresses.txt"'
        return response

class VulnerabilityViewSet(viewsets.ModelViewSet):
    """Vulnerability asset management ViewSet (read-only).

    Two access patterns are supported:
    1. Nested route: GET /api/targets/{target_pk}/vulnerabilities/
    2. Standalone route: GET /api/vulnerabilities/ (global query)
    """

    serializer_class = VulnerabilitySerializer
    pagination_class = BasePagination
    filter_backends = [filters.SearchFilter, filters.OrderingFilter]
    search_fields = ['vuln_type']
    ordering = ['-discovered_at']

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = VulnerabilityService()

    def get_queryset(self):
        """Scope the query depending on whether target_pk is present."""
        target_pk = self.kwargs.get('target_pk')
        if target_pk:
            return self.service.get_vulnerabilities_by_target(target_pk)
        return self.service.get_all()

# ==================== Snapshot ViewSets (nested under Scan) ====================

class SubdomainSnapshotViewSet(viewsets.ModelViewSet):
    """Subdomain snapshot ViewSet - nested route: GET /api/scans/{scan_pk}/subdomains/"""

    serializer_class = SubdomainSnapshotSerializer
    pagination_class = BasePagination
    filter_backends = [filters.SearchFilter, filters.OrderingFilter]
    search_fields = ['name']
    ordering_fields = ['name', 'discovered_at']
    ordering = ['-discovered_at']

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = SubdomainSnapshotsService()

    def get_queryset(self):
        scan_pk = self.kwargs.get('scan_pk')
        if scan_pk:
            return self.service.get_by_scan(scan_pk)
        return self.service.get_all()

    @action(detail=False, methods=['get'], url_path='export')
    def export(self, request):
        scan_pk = self.kwargs.get('scan_pk')
        if not scan_pk:
            raise DRFValidationError('Export must be scoped to a scan')

        def line_iterator():
            for name in self.service.iter_subdomain_names_by_scan(scan_pk):
                yield f"{name}\n"

        response = StreamingHttpResponse(line_iterator(), content_type='text/plain; charset=utf-8')
        response['Content-Disposition'] = f'attachment; filename="scan-{scan_pk}-subdomains.txt"'
        return response


class WebsiteSnapshotViewSet(viewsets.ModelViewSet):
    """Website snapshot ViewSet - nested route: GET /api/scans/{scan_pk}/websites/"""

    serializer_class = WebsiteSnapshotSerializer
    pagination_class = BasePagination
    filter_backends = [filters.SearchFilter, filters.OrderingFilter]
    search_fields = ['host']
    ordering = ['-discovered_at']

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = WebsiteSnapshotsService()

    def get_queryset(self):
        scan_pk = self.kwargs.get('scan_pk')
        if scan_pk:
            return self.service.get_by_scan(scan_pk)
        return self.service.get_all()

    @action(detail=False, methods=['get'], url_path='export')
    def export(self, request):
        scan_pk = self.kwargs.get('scan_pk')
        if not scan_pk:
            raise DRFValidationError('Export must be scoped to a scan')

        def line_iterator():
            for url in self.service.iter_website_urls_by_scan(scan_pk):
                yield f"{url}\n"

        response = StreamingHttpResponse(line_iterator(), content_type='text/plain; charset=utf-8')
        response['Content-Disposition'] = f'attachment; filename="scan-{scan_pk}-websites.txt"'
        return response


class DirectorySnapshotViewSet(viewsets.ModelViewSet):
    """Directory snapshot ViewSet - nested route: GET /api/scans/{scan_pk}/directories/"""

    serializer_class = DirectorySnapshotSerializer
    pagination_class = BasePagination
    filter_backends = [filters.SearchFilter, filters.OrderingFilter]
    search_fields = ['url']
    ordering = ['-discovered_at']

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = DirectorySnapshotsService()

    def get_queryset(self):
        scan_pk = self.kwargs.get('scan_pk')
        if scan_pk:
            return self.service.get_by_scan(scan_pk)
        return self.service.get_all()

    @action(detail=False, methods=['get'], url_path='export')
    def export(self, request):
        scan_pk = self.kwargs.get('scan_pk')
        if not scan_pk:
            raise DRFValidationError('Export must be scoped to a scan')

        def line_iterator():
            for url in self.service.iter_directory_urls_by_scan(scan_pk):
                yield f"{url}\n"

        response = StreamingHttpResponse(line_iterator(), content_type='text/plain; charset=utf-8')
        response['Content-Disposition'] = f'attachment; filename="scan-{scan_pk}-directories.txt"'
        return response


class EndpointSnapshotViewSet(viewsets.ModelViewSet):
    """Endpoint snapshot ViewSet - nested route: GET /api/scans/{scan_pk}/endpoints/"""

    serializer_class = EndpointSnapshotSerializer
    pagination_class = BasePagination
    filter_backends = [filters.SearchFilter, filters.OrderingFilter]
    search_fields = ['host']
    ordering = ['-discovered_at']

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = EndpointSnapshotsService()

    def get_queryset(self):
        scan_pk = self.kwargs.get('scan_pk')
        if scan_pk:
            return self.service.get_by_scan(scan_pk)
        return self.service.get_all()

    @action(detail=False, methods=['get'], url_path='export')
    def export(self, request):
        scan_pk = self.kwargs.get('scan_pk')
        if not scan_pk:
            raise DRFValidationError('Export must be scoped to a scan')

        def line_iterator():
            for url in self.service.iter_endpoint_urls_by_scan(scan_pk):
                yield f"{url}\n"

        response = StreamingHttpResponse(line_iterator(), content_type='text/plain; charset=utf-8')
        response['Content-Disposition'] = f'attachment; filename="scan-{scan_pk}-endpoints.txt"'
        return response


class HostPortMappingSnapshotViewSet(viewsets.ModelViewSet):
    """Host-port mapping snapshot ViewSet - nested route: GET /api/scans/{scan_pk}/ip-addresses/

    Note: because the result is aggregated data (a list of dicts), DRF's SearchFilter is not supported.
    """

    serializer_class = IPAddressAggregatedSerializer
    pagination_class = BasePagination

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = HostPortMappingSnapshotsService()

    def get_queryset(self):
        scan_pk = self.kwargs.get('scan_pk')
        search = self.request.query_params.get('search', None)
        if scan_pk:
            return self.service.get_ip_aggregation_by_scan(scan_pk, search=search)
        return self.service.get_all_ip_aggregation(search=search)

    @action(detail=False, methods=['get'], url_path='export')
    def export(self, request):
        scan_pk = self.kwargs.get('scan_pk')
        if not scan_pk:
            raise DRFValidationError('Export must be scoped to a scan')

        def line_iterator():
            for ip in self.service.iter_ips_by_scan(scan_pk):
                yield f"{ip}\n"

        response = StreamingHttpResponse(line_iterator(), content_type='text/plain; charset=utf-8')
        response['Content-Disposition'] = f'attachment; filename="scan-{scan_pk}-ip-addresses.txt"'
        return response


class VulnerabilitySnapshotViewSet(viewsets.ModelViewSet):
    """Vulnerability snapshot ViewSet - nested route: GET /api/scans/{scan_pk}/vulnerabilities/"""

    serializer_class = VulnerabilitySnapshotSerializer
    pagination_class = BasePagination
    filter_backends = [filters.SearchFilter, filters.OrderingFilter]
    search_fields = ['vuln_type']
    ordering = ['-discovered_at']

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = VulnerabilitySnapshotsService()

    def get_queryset(self):
        scan_pk = self.kwargs.get('scan_pk')
        if scan_pk:
            return self.service.get_by_scan(scan_pk)
        return self.service.get_all()
40
backend/apps/asset/views/__init__.py
Normal file
@@ -0,0 +1,40 @@
"""
Views module for the Asset app.

Re-exports every view class for backwards compatibility.
"""

from .asset_views import (
    AssetStatisticsViewSet,
    SubdomainViewSet,
    WebSiteViewSet,
    DirectoryViewSet,
    EndpointViewSet,
    HostPortMappingViewSet,
    VulnerabilityViewSet,
    SubdomainSnapshotViewSet,
    WebsiteSnapshotViewSet,
    DirectorySnapshotViewSet,
    EndpointSnapshotViewSet,
    HostPortMappingSnapshotViewSet,
    VulnerabilitySnapshotViewSet,
)
from .search_views import AssetSearchView, AssetSearchExportView

__all__ = [
    'AssetStatisticsViewSet',
    'SubdomainViewSet',
    'WebSiteViewSet',
    'DirectoryViewSet',
    'EndpointViewSet',
    'HostPortMappingViewSet',
    'VulnerabilityViewSet',
    'SubdomainSnapshotViewSet',
    'WebsiteSnapshotViewSet',
    'DirectorySnapshotViewSet',
    'EndpointSnapshotViewSet',
    'HostPortMappingSnapshotViewSet',
    'VulnerabilitySnapshotViewSet',
    'AssetSearchView',
    'AssetSearchExportView',
]
1084
backend/apps/asset/views/asset_views.py
Normal file
File diff suppressed because it is too large
364
backend/apps/asset/views/search_views.py
Normal file
@@ -0,0 +1,364 @@
"""
Asset search API views.

REST API endpoints for asset search:
- GET /api/assets/search/ - search assets
- GET /api/assets/search/export/ - export search results as CSV

Search syntax:
- field="value"  fuzzy match (ILIKE %value%)
- field=="value" exact match
- field!="value" not equal
- &&  AND
- ||  OR

Supported fields:
- host: hostname
- url: URL
- title: page title
- tech: technology stack
- status: status code
- body: response body
- header: response headers

Supported asset types:
- website: websites (default)
- endpoint: endpoints
"""

import logging
import json
from datetime import datetime
from urllib.parse import urlparse, urlunparse
from rest_framework import status
from rest_framework.views import APIView
from rest_framework.request import Request
from django.http import StreamingHttpResponse
from django.db import connection

from apps.common.response_helpers import success_response, error_response
from apps.common.error_codes import ErrorCodes
from apps.asset.services.search_service import AssetSearchService, VALID_ASSET_TYPES

logger = logging.getLogger(__name__)


class AssetSearchView(APIView):
    """
    Asset search API.

    GET /api/assets/search/

    Query Parameters:
        q: search query expression
        asset_type: asset type ('website' or 'endpoint'; default 'website')
        page: page number (1-based, default 1)
        pageSize: page size (default 10, max 100)

    Example queries:
        ?q=host="api" && tech="nginx"
        ?q=tech="vue" || tech="react"&asset_type=endpoint
        ?q=status=="200" && host!="test"

    Response:
        {
            "results": [...],
            "total": 100,
            "page": 1,
            "pageSize": 10,
            "totalPages": 10,
            "assetType": "website"
        }
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = AssetSearchService()

    def _parse_headers(self, headers_data) -> dict:
        """Parse response headers into a dict."""
        if not headers_data:
            return {}
        try:
            return json.loads(headers_data)
        except (json.JSONDecodeError, TypeError):
            result = {}
            for line in str(headers_data).split('\n'):
                if ':' in line:
                    key, value = line.split(':', 1)
                    result[key.strip()] = value.strip()
            return result

    def _format_result(self, result: dict, vulnerabilities_by_url: dict, asset_type: str) -> dict:
        """Format a single search result."""
        url = result.get('url', '')
        vulns = vulnerabilities_by_url.get(url, [])

        # Base fields (shared by Website and Endpoint)
        formatted = {
            'id': result.get('id'),
            'url': url,
            'host': result.get('host', ''),
            'title': result.get('title', ''),
            'technologies': result.get('tech', []) or [],
            'statusCode': result.get('status_code'),
            'contentLength': result.get('content_length'),
            'contentType': result.get('content_type', ''),
            'webserver': result.get('webserver', ''),
            'location': result.get('location', ''),
            'vhost': result.get('vhost'),
            'responseHeaders': self._parse_headers(result.get('response_headers')),
            'responseBody': result.get('response_body', ''),
            'createdAt': result.get('created_at').isoformat() if result.get('created_at') else None,
            'targetId': result.get('target_id'),
        }

        # Website-only field: associated vulnerabilities
        if asset_type == 'website':
            formatted['vulnerabilities'] = [
                {
                    'id': v.get('id'),
                    'name': v.get('vuln_type', ''),
                    'vulnType': v.get('vuln_type', ''),
                    'severity': v.get('severity', 'info'),
                }
                for v in vulns
            ]

        # Endpoint-only field
        if asset_type == 'endpoint':
            formatted['matchedGfPatterns'] = result.get('matched_gf_patterns', []) or []

        return formatted

    def _get_vulnerabilities_by_url_prefix(self, website_urls: list) -> dict:
        """
        Batch-query vulnerabilities by URL prefix.

        A vulnerability URL is a sub-path of its website URL, so prefix matching is used:
        - website.url: https://example.com/path?query=1
        - vulnerability.url: https://example.com/path/api/users

        Args:
            website_urls: list of website URLs as [(url, target_id), ...]

        Returns:
            dict: {website_url: [vulnerability_list]}
        """
        if not website_urls:
            return {}

        try:
            with connection.cursor() as cursor:
                # Build OR conditions: each website URL (query string stripped) becomes a prefix match
                conditions = []
                params = []
                url_mapping = {}  # base_url -> original_url

                for url, target_id in website_urls:
                    if not url or target_id is None:
                        continue
                    # Use urlparse to drop the query string and fragment, keeping scheme://netloc/path
                    parsed = urlparse(url)
                    base_url = urlunparse((parsed.scheme, parsed.netloc, parsed.path, '', '', ''))
                    url_mapping[base_url] = url
                    conditions.append("(v.url LIKE %s AND v.target_id = %s)")
                    params.extend([base_url + '%', target_id])

                if not conditions:
                    return {}

                where_clause = " OR ".join(conditions)

                sql = f"""
                    SELECT v.id, v.vuln_type, v.severity, v.url, v.target_id
                    FROM vulnerability v
                    WHERE {where_clause}
                    ORDER BY
                        CASE v.severity
                            WHEN 'critical' THEN 1
                            WHEN 'high' THEN 2
                            WHEN 'medium' THEN 3
                            WHEN 'low' THEN 4
                            ELSE 5
                        END
                """
                cursor.execute(sql, params)

                # Fetch all vulnerabilities
                all_vulns = []
                for row in cursor.fetchall():
                    all_vulns.append({
                        'id': row[0],
                        'vuln_type': row[1],
                        'name': row[1],
                        'severity': row[2],
                        'url': row[3],
                        'target_id': row[4],
                    })

                # Group by original website URL (for the returned mapping)
                result = {url: [] for url, _ in website_urls}
                for vuln in all_vulns:
                    vuln_url = vuln['url']
                    # Find the matching website URL (longest-prefix match)
                    for website_url, target_id in website_urls:
                        parsed = urlparse(website_url)
                        base_url = urlunparse((parsed.scheme, parsed.netloc, parsed.path, '', '', ''))
                        if vuln_url.startswith(base_url) and vuln['target_id'] == target_id:
                            result[website_url].append(vuln)
                            break

                return result
        except Exception as e:
            logger.error(f"Batch vulnerability query failed: {e}")
            return {}

    def get(self, request: Request):
        """Search assets."""
        # Get the search query
        query = request.query_params.get('q', '').strip()

        if not query:
            return error_response(
                code=ErrorCodes.VALIDATION_ERROR,
                message='Search query (q) is required',
                status_code=status.HTTP_400_BAD_REQUEST
            )

        # Get and validate the asset type
        asset_type = request.query_params.get('asset_type', 'website').strip().lower()
        if asset_type not in VALID_ASSET_TYPES:
            return error_response(
                code=ErrorCodes.VALIDATION_ERROR,
                message=f'Invalid asset_type. Must be one of: {", ".join(VALID_ASSET_TYPES)}',
                status_code=status.HTTP_400_BAD_REQUEST
            )

        # Get the pagination parameters
        try:
            page = int(request.query_params.get('page', 1))
            page_size = int(request.query_params.get('pageSize', 10))
        except (ValueError, TypeError):
            page = 1
            page_size = 10

        # Clamp the pagination parameters
        page = max(1, page)
        page_size = min(max(1, page_size), 100)

        # Get the total count and the search results
        total = self.service.count(query, asset_type)
        total_pages = (total + page_size - 1) // page_size if total > 0 else 1
        offset = (page - 1) * page_size

        all_results = self.service.search(query, asset_type)
        results = all_results[offset:offset + page_size]

        # Batch-query vulnerabilities (Website type only)
        vulnerabilities_by_url = {}
        if asset_type == 'website':
            website_urls = [(r.get('url'), r.get('target_id')) for r in results if r.get('url') and r.get('target_id')]
            vulnerabilities_by_url = self._get_vulnerabilities_by_url_prefix(website_urls) if website_urls else {}

        # Format the results
        formatted_results = [self._format_result(r, vulnerabilities_by_url, asset_type) for r in results]

        return success_response(data={
            'results': formatted_results,
            'total': total,
            'page': page,
            'pageSize': page_size,
            'totalPages': total_pages,
            'assetType': asset_type,
        })


class AssetSearchExportView(APIView):
    """
    Asset search export API.

    GET /api/assets/search/export/

    Query Parameters:
        q: search query expression
        asset_type: asset type ('website' or 'endpoint'; default 'website')

    Response:
        streamed CSV file (uses a server-side cursor, so large exports are supported)
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = AssetSearchService()

    def _get_headers_and_formatters(self, asset_type: str):
        """Get the CSV headers and field formatters."""
        from apps.common.utils import format_datetime, format_list_field

        if asset_type == 'website':
            headers = ['url', 'host', 'title', 'status_code', 'content_type', 'content_length',
                       'webserver', 'location', 'tech', 'vhost', 'created_at']
        else:
            headers = ['url', 'host', 'title', 'status_code', 'content_type', 'content_length',
                       'webserver', 'location', 'tech', 'matched_gf_patterns', 'vhost', 'created_at']

        formatters = {
            'created_at': format_datetime,
            'tech': lambda x: format_list_field(x, separator='; '),
            'matched_gf_patterns': lambda x: format_list_field(x, separator='; '),
            'vhost': lambda x: 'true' if x else ('false' if x is False else ''),
        }

        return headers, formatters

    def get(self, request: Request):
        """Export search results as CSV (streamed, no row limit)."""
        from apps.common.utils import generate_csv_rows

        # Get the search query
        query = request.query_params.get('q', '').strip()

        if not query:
            return error_response(
                code=ErrorCodes.VALIDATION_ERROR,
                message='Search query (q) is required',
                status_code=status.HTTP_400_BAD_REQUEST
            )

        # Get and validate the asset type
        asset_type = request.query_params.get('asset_type', 'website').strip().lower()
        if asset_type not in VALID_ASSET_TYPES:
            return error_response(
                code=ErrorCodes.VALIDATION_ERROR,
                message=f'Invalid asset_type. Must be one of: {", ".join(VALID_ASSET_TYPES)}',
                status_code=status.HTTP_400_BAD_REQUEST
            )

        # Quick existence check to avoid exporting an empty file
        total = self.service.count(query, asset_type)
        if total == 0:
            return error_response(
                code=ErrorCodes.NOT_FOUND,
                message='No results to export',
                status_code=status.HTTP_404_NOT_FOUND
            )

        # Get the headers and formatters
        headers, formatters = self._get_headers_and_formatters(asset_type)

        # Get the streaming data iterator
        data_iterator = self.service.search_iter(query, asset_type)

        # Build the filename
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        filename = f'search_{asset_type}_{timestamp}.csv'

        # Return a streaming response
        response = StreamingHttpResponse(
            generate_csv_rows(data_iterator, headers, formatters),
            content_type='text/csv; charset=utf-8'
        )
        response['Content-Disposition'] = f'attachment; filename="{filename}"'

        return response
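A client-side sketch of the search endpoint defined above (the host and credentials are placeholders; session auth is what the permission layer elsewhere in this diff requires for business endpoints):

    import requests

    BASE = "https://xingrin.example.com"  # placeholder host
    session = requests.Session()
    session.post(f"{BASE}/api/auth/login/", json={"username": "admin", "password": "..."})

    resp = session.get(
        f"{BASE}/api/assets/search/",
        params={
            "q": 'host="api" && tech="nginx"',  # the documented smart-filter syntax
            "asset_type": "website",
            "page": 1,
            "pageSize": 10,
        },
    )
    print(resp.json()["total"])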
@@ -40,8 +40,14 @@ def fetch_config_and_setup_django():
     print(f"[CONFIG] Fetching configuration from the config center: {config_url}")
     print(f"[CONFIG] IS_LOCAL={is_local}")
     try:
+        # Build the request headers (including the Worker API key)
+        headers = {}
+        worker_api_key = os.environ.get("WORKER_API_KEY", "")
+        if worker_api_key:
+            headers["X-Worker-API-Key"] = worker_api_key
+
         # verify=False: remote workers may reach HTTPS endpoints with a self-signed certificate
-        resp = requests.get(config_url, timeout=10, verify=False)
+        resp = requests.get(config_url, headers=headers, timeout=10, verify=False)
         resp.raise_for_status()
         config = resp.json()
 
@@ -57,9 +63,6 @@ def fetch_config_and_setup_django():
         os.environ.setdefault("DB_USER", db_user)
         os.environ.setdefault("DB_PASSWORD", config['db']['password'])
 
-        # Redis configuration
-        os.environ.setdefault("REDIS_URL", config['redisUrl'])
-
         # Logging configuration
         os.environ.setdefault("LOG_DIR", config['paths']['logs'])
         os.environ.setdefault("LOG_LEVEL", config['logging']['level'])
 
@@ -71,7 +74,6 @@ def fetch_config_and_setup_django():
         print(f"[CONFIG] DB_PORT: {db_port}")
        print(f"[CONFIG] DB_NAME: {db_name}")
         print(f"[CONFIG] DB_USER: {db_user}")
-        print(f"[CONFIG] REDIS_URL: {config['redisUrl']}")
 
     except Exception as e:
         print(f"[ERROR] Failed to fetch configuration: {config_url} - {e}", file=sys.stderr)

31
backend/apps/common/error_codes.py
Normal file
@@ -0,0 +1,31 @@
"""
Standardized error code definitions.

A deliberately small scheme (following Stripe, GitHub, and similar providers):
- define only 5-10 generic error codes
- unknown errors fall back to a generic code
- code format: uppercase letters and underscores
"""


class ErrorCodes:
    """Standardized error codes.

    Only generic codes are defined; other errors use a generic message.
    This mirrors the standard practice at Stripe, GitHub, and similar providers.

    Code format conventions:
    - uppercase letters and underscores
    - short and self-explanatory
    - the frontend maps codes to i18n keys
    """

    # Generic error codes (8)
    VALIDATION_ERROR = 'VALIDATION_ERROR'    # input validation failed
    NOT_FOUND = 'NOT_FOUND'                  # resource not found
    PERMISSION_DENIED = 'PERMISSION_DENIED'  # insufficient permissions
    SERVER_ERROR = 'SERVER_ERROR'            # internal server error
    BAD_REQUEST = 'BAD_REQUEST'              # malformed request
    CONFLICT = 'CONFLICT'                    # resource conflict (e.g. duplicate creation)
    UNAUTHORIZED = 'UNAUTHORIZED'            # not authenticated
    RATE_LIMITED = 'RATE_LIMITED'            # too many requests
49
backend/apps/common/exception_handlers.py
Normal file
@@ -0,0 +1,49 @@
"""
Custom exception handler.

Handles DRF exceptions uniformly so error responses share one format.
"""

from rest_framework.views import exception_handler
from rest_framework import status
from rest_framework.exceptions import AuthenticationFailed, NotAuthenticated

from apps.common.response_helpers import error_response
from apps.common.error_codes import ErrorCodes


def custom_exception_handler(exc, context):
    """
    Custom exception handler.

    Handles auth-related exceptions and returns errors in the unified format.
    """
    # Run DRF's default exception handler first
    response = exception_handler(exc, context)

    if response is not None:
        # Handle 401 (not authenticated) errors
        if response.status_code == status.HTTP_401_UNAUTHORIZED:
            return error_response(
                code=ErrorCodes.UNAUTHORIZED,
                message='Authentication required',
                status_code=status.HTTP_401_UNAUTHORIZED
            )

        # Handle 403 (permission denied) errors
        if response.status_code == status.HTTP_403_FORBIDDEN:
            return error_response(
                code=ErrorCodes.PERMISSION_DENIED,
                message='Permission denied',
                status_code=status.HTTP_403_FORBIDDEN
            )

    # Handle NotAuthenticated and AuthenticationFailed exceptions
    if isinstance(exc, (NotAuthenticated, AuthenticationFailed)):
        return error_response(
            code=ErrorCodes.UNAUTHORIZED,
            message='Authentication required',
            status_code=status.HTTP_401_UNAUTHORIZED
        )

    return response
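For this handler to take effect, DRF must be pointed at it in settings; a minimal sketch of that wiring (EXCEPTION_HANDLER is a standard DRF setting; the dotted path is an assumption based on the file location):

    # settings.py (sketch)
    REST_FRAMEWORK = {
        # Assumed dotted path for backend/apps/common/exception_handlers.py
        'EXCEPTION_HANDLER': 'apps.common.exception_handlers.custom_exception_handler',
    }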
80
backend/apps/common/permissions.py
Normal file
@@ -0,0 +1,80 @@
"""
Centralized permission management.

Implements authentication for three classes of endpoints:
1. Public endpoints (no auth): login, logout, current-user status
2. Worker endpoints (API-key auth): registration, config, heartbeat, callbacks, resource sync
3. Business endpoints (session auth): every other API
"""

import re
import logging
from django.conf import settings
from rest_framework.permissions import BasePermission

logger = logging.getLogger(__name__)

# Public endpoint allowlist (no authentication at all)
PUBLIC_ENDPOINTS = [
    r'^/api/auth/login/$',
    r'^/api/auth/logout/$',
    r'^/api/auth/me/$',
]

# Worker API endpoints (require API-key auth)
# Covers: registration, config, heartbeat, callbacks, resource sync (wordlist downloads)
WORKER_ENDPOINTS = [
    r'^/api/workers/register/$',
    r'^/api/workers/config/$',
    r'^/api/workers/\d+/heartbeat/$',
    r'^/api/callbacks/',
    # Resource sync endpoint (workers need to download wordlist files)
    r'^/api/wordlists/download/$',
    # Note: the fingerprint export API uses session auth (it serves frontend users).
    # Workers read fingerprint data straight from the database, not over HTTP.
]


class IsAuthenticatedOrPublic(BasePermission):
    """
    Custom permission class:
    - allowlisted endpoints are publicly accessible
    - worker endpoints require API-key auth
    - everything else requires session auth
    """

    def has_permission(self, request, view):
        path = request.path

        # Is the path on the public allowlist?
        for pattern in PUBLIC_ENDPOINTS:
            if re.match(pattern, path):
                return True

        # Is it a worker endpoint?
        for pattern in WORKER_ENDPOINTS:
            if re.match(pattern, path):
                return self._check_worker_api_key(request)

        # Everything else requires session auth
        return request.user and request.user.is_authenticated

    def _check_worker_api_key(self, request):
        """Validate the worker API key."""
        api_key = request.headers.get('X-Worker-API-Key')
        expected_key = getattr(settings, 'WORKER_API_KEY', None)

        if not expected_key:
            # With no API key configured, reject all worker requests
            logger.warning("WORKER_API_KEY is not configured; rejecting worker request")
            return False

        if not api_key:
            logger.warning(f"Worker request is missing the X-Worker-API-Key header: {request.path}")
            return False

        if api_key != expected_key:
            logger.warning(f"Invalid worker API key: {request.path}")
            return False

        return True
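A sketch of how such a permission class is typically enabled globally, plus the header a worker then sends (the settings wiring is an assumption; DEFAULT_PERMISSION_CLASSES is standard DRF):

    # settings.py (sketch)
    REST_FRAMEWORK = {
        'DEFAULT_PERMISSION_CLASSES': [
            'apps.common.permissions.IsAuthenticatedOrPublic',  # assumed dotted path
        ],
    }

    # A worker request then carries the shared key:
    #   GET /api/workers/config/
    #   X-Worker-API-Key: <value of the WORKER_API_KEY setting>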
@@ -16,6 +16,7 @@ def setup_django_for_prefect():
     1. Add the project root to the Python path
     2. Set the DJANGO_SETTINGS_MODULE environment variable
     3. Call django.setup() to initialize Django
+    4. Close stale database connections so fresh ones are used
 
     Usage:
         from apps.common.prefect_django_setup import setup_django_for_prefect
@@ -36,6 +37,25 @@ def setup_django_for_prefect():
     # Initialize Django
     import django
     django.setup()
 
+    # Close all stale database connections so worker processes use fresh ones.
+    # Fixes the "server closed the connection unexpectedly" problem.
+    from django.db import connections
+    connections.close_all()
+
+
+def close_old_db_connections():
+    """
+    Close stale database connections.
+
+    Call this in long-running tasks to make sure a valid database connection is used.
+    Appropriate:
+    - before a flow starts
+    - before a task starts
+    - before resuming work after a long idle period
+    """
+    from django.db import connections
+    connections.close_all()
 
 
 # Auto-initialize on import (takes effect at import time)

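A sketch of the intended call pattern inside a long-running Prefect task (hypothetical task body; assumes a Prefect 2-style @task and that importing the module has already run the setup, as the comment above states):

    from prefect import task

    @task
    def refresh_assets():
        # Importing the module triggers setup_django_for_prefect() (init on import).
        from apps.common.prefect_django_setup import close_old_db_connections

        # Drop connections that may have gone stale while the worker idled.
        close_old_db_connections()

        from apps.asset.models import Subdomain  # hypothetical model import
        return Subdomain.objects.count()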
88
backend/apps/common/response_helpers.py
Normal file
@@ -0,0 +1,88 @@
"""
Standardized API response helpers.

Follows industry standards (RFC 9457 Problem Details) and big-provider practice (Google, Stripe, GitHub):
- success responses carry only data, without a message field
- error responses use machine-readable error codes that the frontend maps to i18n messages
"""
from typing import Any, Dict, List, Optional, Union

from rest_framework import status
from rest_framework.response import Response


def success_response(
    data: Optional[Union[Dict[str, Any], List[Any]]] = None,
    status_code: int = status.HTTP_200_OK
) -> Response:
    """
    Standardized success response.

    Returns the data directly, without wrapping, in line with Stripe/GitHub practice.

    Args:
        data: response data (dict or list)
        status_code: HTTP status code, default 200

    Returns:
        Response: a DRF Response object

    Examples:
        # Single resource
        >>> success_response(data={'id': 1, 'name': 'Test'})
        {'id': 1, 'name': 'Test'}

        # Operation result
        >>> success_response(data={'count': 3, 'scans': [...]})
        {'count': 3, 'scans': [...]}

        # Resource creation
        >>> success_response(data={'id': 1}, status_code=201)
    """
    # Note: `data or {}` would be wrong here, because an empty list [] would become {}
    if data is None:
        data = {}
    return Response(data, status=status_code)


def error_response(
    code: str,
    message: Optional[str] = None,
    details: Optional[List[Dict[str, Any]]] = None,
    status_code: int = status.HTTP_400_BAD_REQUEST
) -> Response:
    """
    Standardized error response.

    Args:
        code: error code (e.g. 'VALIDATION_ERROR', 'NOT_FOUND');
            format: uppercase letters and underscores
        message: developer-facing debug message (not shown to users)
        details: detailed error info (e.g. field-level validation errors)
        status_code: HTTP status code, default 400

    Returns:
        Response: a DRF Response object

    Examples:
        # Simple error
        >>> error_response(code='NOT_FOUND', status_code=404)
        {'error': {'code': 'NOT_FOUND'}}

        # With debug info
        >>> error_response(
        ...     code='VALIDATION_ERROR',
        ...     message='Invalid input data',
        ...     details=[{'field': 'name', 'message': 'Required'}]
        ... )
        {'error': {'code': 'VALIDATION_ERROR', 'message': '...', 'details': [...]}}
    """
    error_body: Dict[str, Any] = {'code': code}

    if message:
        error_body['message'] = message

    if details:
        error_body['details'] = details

    return Response({'error': error_body}, status=status_code)
@@ -3,8 +3,13 @@
 
 Provides system-level shared services, including:
 - SystemLogService: system log reading service
 
+Note: FilterService has moved to apps.common.utils.filter_utils.
+Preferred usage: from apps.common.utils.filter_utils import apply_filters
 """
 
 from .system_log_service import SystemLogService
 
-__all__ = ['SystemLogService']
+__all__ = [
+    'SystemLogService',
+]

@@ -4,15 +4,28 @@
 
 Provides reading of system logs, supporting:
 - reading log files from the log directory
 - capping the number of returned lines to prevent memory exhaustion
+- listing the available log files
 """
 
+import fnmatch
 import logging
+import os
 import subprocess
+from datetime import datetime, timezone
+from typing import TypedDict
 
 
 logger = logging.getLogger(__name__)
 
 
+class LogFileInfo(TypedDict):
+    """Log file metadata."""
+    filename: str
+    category: str  # 'system' | 'error' | 'performance' | 'container'
+    size: int
+    modifiedAt: str  # ISO 8601 format
 
 
 class SystemLogService:
     """
     System log service class.
@@ -20,23 +33,131 @@ class SystemLogService:
 
     Reads system log files, from either the in-container path or the host-mounted path.
     """
 
+    # Log file categorization rules
+    CATEGORY_RULES = [
+        ('xingrin.log', 'system'),
+        ('xingrin_error.log', 'error'),
+        ('performance.log', 'performance'),
+        ('container_*.log', 'container'),
+    ]
 
     def __init__(self):
-        # Log file path (in-container path, volume-mounted to /opt/xingrin/logs on the host)
-        self.log_file = "/app/backend/logs/xingrin.log"
-        self.default_lines = 200  # default number of lines returned
-        self.max_lines = 10000    # hard cap on returned lines
-        self.timeout_seconds = 3  # timeout for the tail command
+        # Log directory path
+        self.log_dir = "/opt/xingrin/logs"
+        self.default_file = "xingrin.log"  # default log file
+        self.default_lines = 200  # default number of lines returned
+        self.max_lines = 10000    # hard cap on returned lines
+        self.timeout_seconds = 3  # timeout for the tail command
 
-    def get_logs_content(self, lines: int | None = None) -> str:
+    def _categorize_file(self, filename: str) -> str | None:
+        """
+        Determine a file's log category from its name.
+
+        Returns:
+            the category name, or None if the file is not a log file
+        """
+        for pattern, category in self.CATEGORY_RULES:
+            if fnmatch.fnmatch(filename, pattern):
+                return category
+        return None
+
+    def _validate_filename(self, filename: str) -> bool:
+        """
+        Validate a filename (guards against path traversal attacks).
+
+        Args:
+            filename: the filename to validate
+
+        Returns:
+            bool: whether the filename is valid
+        """
+        # No path separators allowed
+        if '/' in filename or '\\' in filename:
+            return False
+        # No .. path traversal allowed
+        if '..' in filename:
+            return False
+        # Must be a known log file type
+        return self._categorize_file(filename) is not None
+
+    def get_log_files(self) -> list[LogFileInfo]:
+        """
+        List all available log files.
+
+        Returns:
+            log file metadata, sorted by category then filename
+        """
+        files: list[LogFileInfo] = []
+
+        if not os.path.isdir(self.log_dir):
+            logger.warning("Log directory does not exist: %s", self.log_dir)
+            return files
+
+        for filename in os.listdir(self.log_dir):
+            filepath = os.path.join(self.log_dir, filename)
+
+            # Files only; skip directories
+            if not os.path.isfile(filepath):
+                continue
+
+            # Determine the category
+            category = self._categorize_file(filename)
+            if category is None:
+                continue
+
+            # Collect file metadata
+            try:
+                stat = os.stat(filepath)
+                modified_at = datetime.fromtimestamp(
+                    stat.st_mtime, tz=timezone.utc
+                ).isoformat()
+
+                files.append({
+                    'filename': filename,
+                    'category': category,
+                    'size': stat.st_size,
+                    'modifiedAt': modified_at,
+                })
+            except OSError as e:
+                logger.warning("Failed to stat file %s: %s", filepath, e)
+                continue
+
+        # Sort: by category priority (system > error > performance > container), then filename
+        category_order = {'system': 0, 'error': 1, 'performance': 2, 'container': 3}
+        files.sort(key=lambda f: (category_order.get(f['category'], 99), f['filename']))
+
+        return files
+
+    def get_logs_content(self, filename: str | None = None, lines: int | None = None) -> str:
+        """
+        Get system log content.
+
+        Args:
+            filename: log filename, defaults to xingrin.log
+            lines: number of lines to return; default 200, max 10000
+
+        Returns:
+            str: log content, newline-separated, in original order
+
+        Raises:
+            ValueError: invalid filename
+            FileNotFoundError: log file does not exist
+        """
+        # Resolve the filename
+        if filename is None:
+            filename = self.default_file
+
+        # Validate the filename
+        if not self._validate_filename(filename):
+            raise ValueError(f"Invalid filename: {filename}")
+
+        # Build the full path
+        log_file = os.path.join(self.log_dir, filename)
+
+        # Make sure the file exists
+        if not os.path.isfile(log_file):
+            raise FileNotFoundError(f"Log file does not exist: {filename}")
+
         # Parameter validation and defaults
         if lines is None:
             lines = self.default_lines
@@ -48,7 +169,7 @@ class SystemLogService:
             lines = self.max_lines
 
         # Use tail to read the end of the log file
-        cmd = ["tail", "-n", str(lines), self.log_file]
+        cmd = ["tail", "-n", str(lines), log_file]
 
         result = subprocess.run(
             cmd,

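A small usage sketch for the service above (hypothetical caller; filenames must match the CATEGORY_RULES patterns):

    service = SystemLogService()

    # List what is available, then tail one of the files.
    for info in service.get_log_files():
        print(info['category'], info['filename'], info['size'])

    print(service.get_logs_content(filename='xingrin_error.log', lines=100))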
@@ -2,14 +2,18 @@
 URL configuration for the common module.
 
 Routes:
+- /api/health/   health check (no auth required)
 - /api/auth/*    auth endpoints (login, logout, user info)
 - /api/system/*  system administration (log viewing, etc.)
 """
 
 from django.urls import path
-from .views import LoginView, LogoutView, MeView, ChangePasswordView, SystemLogsView
+from .views import LoginView, LogoutView, MeView, ChangePasswordView, SystemLogsView, SystemLogFilesView, HealthCheckView
 
 urlpatterns = [
+    # Health check (no auth required)
+    path('health/', HealthCheckView.as_view(), name='health-check'),
+
     # Auth
     path('auth/login/', LoginView.as_view(), name='auth-login'),
     path('auth/logout/', LogoutView.as_view(), name='auth-logout'),
@@ -18,4 +22,5 @@ urlpatterns = [
 
     # System administration
     path('system/logs/', SystemLogsView.as_view(), name='system-logs'),
+    path('system/logs/files/', SystemLogFilesView.as_view(), name='system-log-files'),
 ]

28
backend/apps/common/utils/__init__.py
Normal file
@@ -0,0 +1,28 @@
"""Common utilities"""

from .dedup import deduplicate_for_bulk, get_unique_fields
from .hash import (
    calc_file_sha256,
    calc_stream_sha256,
    safe_calc_file_sha256,
    is_file_hash_match,
)
from .csv_utils import (
    generate_csv_rows,
    format_list_field,
    format_datetime,
    UTF8_BOM,
)

__all__ = [
    'deduplicate_for_bulk',
    'get_unique_fields',
    'calc_file_sha256',
    'calc_stream_sha256',
    'safe_calc_file_sha256',
    'is_file_hash_match',
    'generate_csv_rows',
    'format_list_field',
    'format_datetime',
    'UTF8_BOM',
]
116
backend/apps/common/utils/csv_utils.py
Normal file
@@ -0,0 +1,116 @@
"""CSV export utilities.

Provides streaming CSV generation, supporting:
- UTF-8 BOM (Excel compatibility)
- RFC 4180-compliant escaping
- streaming generation (memory friendly)
"""

import csv
import io
from datetime import datetime
from typing import Iterator, Dict, Any, List, Callable, Optional

# UTF-8 BOM so Excel detects the encoding correctly
UTF8_BOM = '\ufeff'


def generate_csv_rows(
    data_iterator: Iterator[Dict[str, Any]],
    headers: List[str],
    field_formatters: Optional[Dict[str, Callable]] = None
) -> Iterator[str]:
    """
    Generate CSV rows as a stream.

    Args:
        data_iterator: iterator of dicts, one per row
        headers: list of CSV headers
        field_formatters: per-field formatter functions, keyed by field name

    Yields:
        CSV row strings (newline included)

    Example:
        >>> data = [{'ip': '192.168.1.1', 'hosts': ['a.com', 'b.com']}]
        >>> headers = ['ip', 'hosts']
        >>> formatters = {'hosts': format_list_field}
        >>> for row in generate_csv_rows(iter(data), headers, formatters):
        ...     print(row, end='')
    """
    # Emit the BOM plus the header row
    output = io.StringIO()
    writer = csv.writer(output, quoting=csv.QUOTE_MINIMAL)
    writer.writerow(headers)
    yield UTF8_BOM + output.getvalue()

    # Emit the data rows
    for row_data in data_iterator:
        output = io.StringIO()
        writer = csv.writer(output, quoting=csv.QUOTE_MINIMAL)

        row = []
        for header in headers:
            value = row_data.get(header, '')
            if field_formatters and header in field_formatters:
                value = field_formatters[header](value)
            row.append(value if value is not None else '')

        writer.writerow(row)
        yield output.getvalue()


def format_list_field(values: List, separator: str = ';') -> str:
    """
    Format a list field as a separator-joined string.

    Args:
        values: list of values
        separator: separator, defaults to a semicolon

    Returns:
        the joined string

    Example:
        >>> format_list_field(['a.com', 'b.com'])
        'a.com;b.com'
        >>> format_list_field([80, 443])
        '80;443'
        >>> format_list_field([])
        ''
        >>> format_list_field(None)
        ''
    """
    if not values:
        return ''
    return separator.join(str(v) for v in values)


def format_datetime(dt: Optional[datetime]) -> str:
    """
    Format a datetime as a string (converted to the local time zone).

    Args:
        dt: a datetime object or None

    Returns:
        formatted datetime string, YYYY-MM-DD HH:MM:SS (local time zone)

    Example:
        >>> from datetime import datetime
        >>> format_datetime(datetime(2024, 1, 15, 10, 30, 0))
        '2024-01-15 10:30:00'
        >>> format_datetime(None)
        ''
    """
    if dt is None:
        return ''
    if isinstance(dt, str):
        return dt

    # Convert to the local time zone (taken from Django settings)
    from django.utils import timezone
    if timezone.is_aware(dt):
        dt = timezone.localtime(dt)

    return dt.strftime('%Y-%m-%d %H:%M:%S')
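A sketch of generate_csv_rows feeding a streaming download (hypothetical view; mirrors the search-export pattern earlier in this diff):

    from django.http import StreamingHttpResponse

    def export_hosts_csv(request, rows):
        # rows: an iterator of dicts, e.g. queryset.values('ip', 'hosts').iterator()
        headers = ['ip', 'hosts']
        formatters = {'hosts': format_list_field}

        response = StreamingHttpResponse(
            generate_csv_rows(rows, headers, formatters),
            content_type='text/csv; charset=utf-8',
        )
        response['Content-Disposition'] = 'attachment; filename="hosts.csv"'
        return response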
101
backend/apps/common/utils/dedup.py
Normal file
@@ -0,0 +1,101 @@
"""
Batch deduplication utilities.

Used to deduplicate within a batch before bulk_create, avoiding PostgreSQL
ON CONFLICT errors. The unique-constraint fields are read automatically from
the Django model, so they never need to be specified by hand.
"""

import logging
from typing import List, TypeVar, Tuple, Optional

from django.db import models

logger = logging.getLogger(__name__)

T = TypeVar('T')


def get_unique_fields(model: type[models.Model]) -> Optional[Tuple[str, ...]]:
    """
    Read the unique-constraint fields from a Django model.

    Looked up in priority order:
    1. UniqueConstraint entries in Meta.constraints
    2. Meta.unique_together

    Args:
        model: the Django model class

    Returns:
        the unique-constraint field tuple, or None if there is none
    """
    meta = model._meta

    # 1. Prefer UniqueConstraint
    for constraint in getattr(meta, 'constraints', []):
        if isinstance(constraint, models.UniqueConstraint):
            # Skip conditional (partial unique) constraints
            if getattr(constraint, 'condition', None) is None:
                return tuple(constraint.fields)

    # 2. Fall back to unique_together
    unique_together = getattr(meta, 'unique_together', None)
    if unique_together:
        # unique_together may be (('a', 'b'),) or ('a', 'b')
        if unique_together and isinstance(unique_together[0], (list, tuple)):
            return tuple(unique_together[0])
        return tuple(unique_together)

    return None


def deduplicate_for_bulk(items: List[T], model: type[models.Model]) -> List[T]:
    """
    Deduplicate a batch according to the model's unique constraint.

    Reads the unique-constraint fields from the model automatically and builds
    the dedup key from them. The last record wins (later data is usually newer).

    Args:
        items: records to deduplicate (DTO or model objects)
        model: the Django model class (used to read the unique constraint)

    Returns:
        the deduplicated list

    Example:
        # Reads the (url, target) unique constraint from the Endpoint model
        unique_items = deduplicate_for_bulk(items, Endpoint)
    """
    if not items:
        return items

    unique_fields = get_unique_fields(model)
    if unique_fields is None:
        # The model has no unique constraint, so there is nothing to deduplicate
        logger.debug(f"{model.__name__} has no unique constraint; skipping dedup")
        return items

    # Handle foreign-key field names (target -> target_id)
    def make_key(item: T) -> tuple:
        key_parts = []
        for field in unique_fields:
            # Try both the field_id (foreign key) and plain field forms
            value = getattr(item, f'{field}_id', None)
            if value is None:
                value = getattr(item, field, None)
            key_parts.append(value)
        return tuple(key_parts)

    # Deduplicate via a dict, keeping the last record
    seen = {}
    for item in items:
        key = make_key(item)
        seen[key] = item

    unique_items = list(seen.values())

    if len(unique_items) < len(items):
        logger.debug(f"{model.__name__} dedup: {len(items)} -> {len(unique_items)} records")

    return unique_items
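A sketch of the intended call site (hypothetical; Endpoint stands in for any model with a unique constraint):

    # Deduplicate inside the batch first, so a single INSERT never hands
    # PostgreSQL two rows with the same conflict key.
    unique_items = deduplicate_for_bulk(items, Endpoint)
    Endpoint.objects.bulk_create(
        unique_items,
        ignore_conflicts=True,  # rows already present in the table are skipped
    )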
331
backend/apps/common/utils/filter_utils.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""智能过滤工具 - 通用查询语法解析和 Django ORM 查询构建
|
||||
|
||||
支持的语法:
|
||||
- field="value" 模糊匹配(包含)
|
||||
- field=="value" 精确匹配
|
||||
- field!="value" 不等于
|
||||
|
||||
逻辑运算符:
|
||||
- AND: && 或 and 或 空格(默认)
|
||||
- OR: || 或 or
|
||||
|
||||
示例:
|
||||
type="xss" || type="sqli" # OR
|
||||
type="xss" or type="sqli" # OR(等价)
|
||||
severity="high" && source="nuclei" # AND
|
||||
severity="high" source="nuclei" # AND(空格默认为 AND)
|
||||
severity="high" and source="nuclei" # AND(等价)
|
||||
|
||||
使用示例:
|
||||
from apps.common.utils.filter_utils import apply_filters
|
||||
|
||||
field_mapping = {'ip': 'ip', 'port': 'port', 'host': 'host'}
|
||||
queryset = apply_filters(queryset, 'ip="192" || port="80"', field_mapping)
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Optional, Union
|
||||
from enum import Enum
|
||||
|
||||
from django.db.models import QuerySet, Q, F, Func, CharField
|
||||
from django.db.models.functions import Cast
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ArrayToString(Func):
|
||||
"""PostgreSQL array_to_string 函数"""
|
||||
function = 'array_to_string'
|
||||
template = "%(function)s(%(expressions)s, ',')"
|
||||
output_field = CharField()
|
||||
|
||||
|
||||
class LogicalOp(Enum):
|
||||
"""逻辑运算符"""
|
||||
AND = 'AND'
|
||||
OR = 'OR'
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedFilter:
|
||||
"""解析后的过滤条件"""
|
||||
field: str # 字段名
|
||||
operator: str # 操作符: '=', '==', '!='
|
||||
value: str # 原始值
|
||||
|
||||
|
||||
@dataclass
|
||||
class FilterGroup:
|
||||
"""过滤条件组(带逻辑运算符)"""
|
||||
filter: ParsedFilter
|
||||
logical_op: LogicalOp # 与前一个条件的逻辑关系
|
||||
|
||||
|
||||
class QueryParser:
|
||||
"""查询语法解析器
|
||||
|
||||
支持 ||/or (OR) 和 &&/and/空格 (AND) 逻辑运算符
|
||||
"""
|
||||
|
||||
# 正则匹配: field="value", field=="value", field!="value"
|
||||
FILTER_PATTERN = re.compile(r'(\w+)(==|!=|=)"([^"]*)"')
|
||||
|
||||
# 逻辑运算符模式(带空格)
|
||||
OR_PATTERN = re.compile(r'\s*(\|\||(?<![a-zA-Z])or(?![a-zA-Z]))\s*', re.IGNORECASE)
|
||||
AND_PATTERN = re.compile(r'\s*(&&|(?<![a-zA-Z])and(?![a-zA-Z]))\s*', re.IGNORECASE)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, query_string: str) -> List[FilterGroup]:
|
||||
"""解析查询语法字符串
|
||||
|
||||
Args:
|
||||
query_string: 查询语法字符串
|
||||
|
||||
Returns:
|
||||
解析后的过滤条件组列表
|
||||
|
||||
Examples:
|
||||
>>> QueryParser.parse('type="xss" || type="sqli"')
|
||||
[FilterGroup(filter=..., logical_op=AND), # 第一个默认 AND
|
||||
FilterGroup(filter=..., logical_op=OR)]
|
||||
"""
|
||||
if not query_string or not query_string.strip():
|
||||
return []
|
||||
|
||||
# 第一步:提取所有过滤条件并用占位符替换,保护引号内的空格
|
||||
filters_found = []
|
||||
placeholder_pattern = '__FILTER_{}__'
|
||||
|
||||
def replace_filter(match):
|
||||
idx = len(filters_found)
|
||||
filters_found.append(match.group(0))
|
||||
return placeholder_pattern.format(idx)
|
||||
|
||||
# 先用正则提取所有 field="value" 形式的条件
|
||||
protected = cls.FILTER_PATTERN.sub(replace_filter, query_string)
|
||||
|
||||
# 标准化逻辑运算符
|
||||
# 先处理 || 和 or -> __OR__
|
||||
normalized = cls.OR_PATTERN.sub(' __OR__ ', protected)
|
||||
# 再处理 && 和 and -> __AND__
|
||||
normalized = cls.AND_PATTERN.sub(' __AND__ ', normalized)
|
||||
|
||||
# 分词:按空格分割,保留逻辑运算符标记
|
||||
tokens = normalized.split()
|
||||
|
||||
groups = []
|
||||
pending_op = LogicalOp.AND # 默认 AND
|
||||
|
||||
for token in tokens:
|
||||
if token == '__OR__':
|
||||
pending_op = LogicalOp.OR
|
||||
elif token == '__AND__':
|
||||
pending_op = LogicalOp.AND
|
||||
elif token.startswith('__FILTER_') and token.endswith('__'):
|
||||
# 还原占位符为原始过滤条件
|
||||
try:
|
||||
idx = int(token[9:-2]) # 提取索引
|
||||
original_filter = filters_found[idx]
|
||||
match = cls.FILTER_PATTERN.match(original_filter)
|
||||
if match:
|
||||
field, operator, value = match.groups()
|
||||
groups.append(FilterGroup(
|
||||
filter=ParsedFilter(
|
||||
field=field.lower(),
|
||||
operator=operator,
|
||||
value=value
|
||||
),
|
||||
logical_op=pending_op if groups else LogicalOp.AND
|
||||
))
|
||||
pending_op = LogicalOp.AND # 重置为默认 AND
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
# 其他 token 忽略(无效输入)
|
||||
|
||||
return groups
|
||||
|
||||
|
||||
class QueryBuilder:
|
||||
"""Django ORM 查询构建器
|
||||
|
||||
将解析后的过滤条件转换为 Django ORM 查询,支持 AND/OR 逻辑
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def build_query(
|
||||
cls,
|
||||
queryset: QuerySet,
|
||||
filter_groups: List[FilterGroup],
|
||||
field_mapping: Dict[str, str],
|
||||
json_array_fields: List[str] = None
|
||||
) -> QuerySet:
|
||||
"""构建 Django ORM 查询
|
||||
|
||||
Args:
|
||||
queryset: Django QuerySet
|
||||
filter_groups: 解析后的过滤条件组列表
|
||||
field_mapping: 字段映射
|
||||
json_array_fields: JSON 数组字段列表(使用 __contains 查询)
|
||||
|
||||
Returns:
|
||||
过滤后的 QuerySet
|
||||
"""
|
||||
if not filter_groups:
|
||||
return queryset
|
||||
|
||||
json_array_fields = json_array_fields or []
|
||||
|
||||
# 收集需要 annotate 的数组模糊搜索字段
|
||||
array_fuzzy_fields = set()
|
||||
|
||||
# 第一遍:检查是否有数组模糊匹配
|
||||
for group in filter_groups:
|
||||
f = group.filter
|
||||
db_field = field_mapping.get(f.field)
|
||||
if db_field and db_field in json_array_fields and f.operator == '=':
|
||||
array_fuzzy_fields.add(db_field)
|
||||
|
||||
# 对数组模糊搜索字段做 annotate
|
||||
for field in array_fuzzy_fields:
|
||||
annotate_name = f'{field}_text'
|
||||
queryset = queryset.annotate(**{annotate_name: ArrayToString(F(field))})
|
||||
|
||||
# 构建 Q 对象
|
||||
combined_q = None
|
||||
|
||||
for group in filter_groups:
|
||||
f = group.filter
|
||||
|
||||
# 字段映射
|
||||
db_field = field_mapping.get(f.field)
|
||||
if not db_field:
|
||||
logger.debug(f"忽略未知字段: {f.field}")
|
||||
continue
|
||||
|
||||
# 判断是否为 JSON 数组字段
|
||||
is_json_array = db_field in json_array_fields
|
||||
|
||||
# 构建单个条件的 Q 对象
|
||||
q = cls._build_single_q(db_field, f.operator, f.value, is_json_array)
|
||||
if q is None:
|
||||
continue
|
||||
|
||||
# 组合 Q 对象
|
||||
if combined_q is None:
|
||||
combined_q = q
|
||||
elif group.logical_op == LogicalOp.OR:
|
||||
combined_q = combined_q | q
|
||||
else: # AND
|
||||
combined_q = combined_q & q
|
||||
|
||||
if combined_q is not None:
|
||||
return queryset.filter(combined_q)
|
||||
return queryset
|
||||
|
||||
@classmethod
|
||||
def _build_single_q(cls, field: str, operator: str, value: str, is_json_array: bool = False) -> Optional[Q]:
|
||||
"""构建单个条件的 Q 对象"""
|
||||
if is_json_array:
|
||||
if operator == '==':
|
||||
# 精确匹配:数组中包含完全等于 value 的元素
|
||||
return Q(**{f'{field}__contains': [value]})
|
||||
elif operator == '!=':
|
||||
# 不包含:数组中不包含完全等于 value 的元素
|
||||
return ~Q(**{f'{field}__contains': [value]})
|
||||
else: # '=' 模糊匹配
|
||||
# 使用 annotate 后的字段进行模糊搜索
|
||||
# 字段已在 build_query 中通过 ArrayToString 转换为文本
|
||||
annotate_name = f'{field}_text'
|
||||
return Q(**{f'{annotate_name}__icontains': value})
|
||||
|
||||
if operator == '!=':
|
||||
return cls._build_not_equal_q(field, value)
|
||||
elif operator == '==':
|
||||
return cls._build_exact_q(field, value)
|
||||
else: # '='
|
||||
return cls._build_fuzzy_q(field, value)
|
||||
|
||||
@classmethod
|
||||
def _try_convert_to_int(cls, value: str) -> Optional[int]:
|
||||
"""尝试将值转换为整数"""
|
||||
try:
|
||||
return int(value.strip())
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _build_fuzzy_q(cls, field: str, value: str) -> Q:
|
||||
"""模糊匹配: 包含"""
|
||||
return Q(**{f'{field}__icontains': value})
|
||||
|
||||
@classmethod
|
||||
def _build_exact_q(cls, field: str, value: str) -> Q:
|
||||
"""精确匹配"""
|
||||
int_val = cls._try_convert_to_int(value)
|
||||
if int_val is not None:
|
||||
return Q(**{f'{field}__exact': int_val})
|
||||
return Q(**{f'{field}__exact': value})
|
||||
|
||||
@classmethod
|
||||
def _build_not_equal_q(cls, field: str, value: str) -> Q:
|
||||
"""不等于"""
|
||||
int_val = cls._try_convert_to_int(value)
|
||||
if int_val is not None:
|
||||
return ~Q(**{f'{field}__exact': int_val})
|
||||
return ~Q(**{f'{field}__exact': value})
|
||||
|
||||
|
||||
def apply_filters(
|
||||
queryset: QuerySet,
|
||||
query_string: str,
|
||||
field_mapping: Dict[str, str],
|
||||
json_array_fields: List[str] = None
|
||||
) -> QuerySet:
|
||||
"""应用过滤条件到 QuerySet
|
||||
|
||||
Args:
|
||||
queryset: Django QuerySet
|
||||
query_string: 查询语法字符串
|
||||
field_mapping: 字段映射
|
||||
json_array_fields: JSON 数组字段列表(使用 __contains 查询)
|
||||
|
||||
Returns:
|
||||
过滤后的 QuerySet
|
||||
|
||||
Examples:
|
||||
# OR 查询
|
||||
apply_filters(qs, 'type="xss" || type="sqli"', mapping)
|
||||
apply_filters(qs, 'type="xss" or type="sqli"', mapping)
|
||||
|
||||
# AND 查询
|
||||
apply_filters(qs, 'severity="high" && source="nuclei"', mapping)
|
||||
apply_filters(qs, 'severity="high" source="nuclei"', mapping)
|
||||
|
||||
# 混合查询
|
||||
apply_filters(qs, 'type="xss" || type="sqli" && severity="high"', mapping)
|
||||
|
||||
# JSON 数组字段查询
|
||||
apply_filters(qs, 'implies="PHP"', mapping, json_array_fields=['implies'])
|
||||
"""
|
||||
if not query_string or not query_string.strip():
|
||||
return queryset
|
||||
|
||||
try:
|
||||
filter_groups = QueryParser.parse(query_string)
|
||||
if not filter_groups:
|
||||
logger.debug(f"未解析到有效过滤条件: {query_string}")
|
||||
return queryset
|
||||
|
||||
logger.debug(f"解析过滤条件: {filter_groups}")
|
||||
return QueryBuilder.build_query(
|
||||
queryset,
|
||||
filter_groups,
|
||||
field_mapping,
|
||||
json_array_fields=json_array_fields
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"过滤解析错误: {e}, query: {query_string}")
|
||||
return queryset # 静默降级
|
||||
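For orientation, a short sketch of what the parser emits and how it feeds the builder. Note that build_query folds conditions strictly left to right, so && gains no extra precedence over ||. The Vulnerability model and its mapping are illustrative assumptions:

from apps.common.utils.filter_utils import QueryParser, apply_filters

groups = QueryParser.parse('type="xss" || type="sqli" && severity="high"')
for g in groups:
    print(g.logical_op.value, g.filter.field, g.filter.operator, g.filter.value)
# AND type = xss       (the first condition always defaults to AND)
# OR  type = sqli
# AND severity = high  -> evaluated as ((type~xss OR type~sqli) AND severity~high)

# Against a hypothetical Vulnerability queryset:
mapping = {'type': 'vuln_type', 'severity': 'severity'}
# qs = apply_filters(Vulnerability.objects.all(), 'type="xss" or severity=="high"', mapping)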
@@ -7,7 +7,6 @@

import hashlib
import logging
from pathlib import Path
from typing import Optional, BinaryIO

logger = logging.getLogger(__name__)
@@ -91,11 +90,3 @@ def is_file_hash_match(file_path: str, expected_hash: str) -> bool:
         return False
 
     return actual_hash.lower() == expected_hash.lower()
-
-
-__all__ = [
-    "calc_file_sha256",
-    "calc_stream_sha256",
-    "safe_calc_file_sha256",
-    "is_file_hash_match",
-]
@@ -1,6 +1,8 @@
-"""Validation helpers for domains, IPs, ports and targets"""
+"""Validation helpers for domains, IPs, ports, URLs and targets"""
 import ipaddress
 import logging
+from urllib.parse import urlparse
 
 import validators
+
 logger = logging.getLogger(__name__)
@@ -25,6 +27,21 @@ def validate_domain(domain: str) -> None:
         raise ValueError(f"Invalid domain format: {domain}")
 
 
+def is_valid_domain(domain: str) -> bool:
+    """
+    Check whether a string is a valid domain (never raises)
+
+    Args:
+        domain: domain string
+
+    Returns:
+        bool: whether the domain is valid
+    """
+    if not domain or len(domain) > 253:
+        return False
+    return bool(validators.domain(domain))
+
+
 def validate_ip(ip: str) -> None:
     """
     Validate an IP address (IPv4 and IPv6)
@@ -44,6 +61,25 @@ def validate_ip(ip: str) -> None:
         raise ValueError(f"Invalid IP address format: {ip}")
 
 
+def is_valid_ip(ip: str) -> bool:
+    """
+    Check whether a string is a valid IP address (never raises)
+
+    Args:
+        ip: IP address string
+
+    Returns:
+        bool: whether the IP address is valid
+    """
+    if not ip:
+        return False
+    try:
+        ipaddress.ip_address(ip)
+        return True
+    except ValueError:
+        return False
+
+
 def validate_cidr(cidr: str) -> None:
     """
     Validate CIDR notation (IPv4 and IPv6)
@@ -140,3 +176,136 @@ def validate_port(port: any) -> tuple[bool, int | None]:
     except (ValueError, TypeError):
         logger.warning("Invalid port number, cannot convert to integer: %s", port)
         return False, None
+
+
+# ==================== URL validation ====================
+
+def validate_url(url: str) -> None:
+    """
+    Validate a URL; it must carry a scheme (http:// or https://)
+
+    Args:
+        url: URL string
+
+    Raises:
+        ValueError: the URL is malformed or missing its scheme
+    """
+    if not url:
+        raise ValueError("URL must not be empty")
+
+    # Check for a scheme
+    if not url.startswith('http://') and not url.startswith('https://'):
+        raise ValueError("URL must include a scheme (http:// or https://)")
+
+    try:
+        parsed = urlparse(url)
+    except Exception:
+        raise ValueError(f"Invalid URL format: {url}")
+    # Checked outside the try block, so this specific error is not
+    # swallowed by the generic handler above
+    if not parsed.hostname:
+        raise ValueError("URL must include a hostname")
+
+
+def is_valid_url(url: str, max_length: int = 2000) -> bool:
+    """
+    Check whether a string is a valid URL (never raises)
+
+    Args:
+        url: URL string
+        max_length: maximum URL length, 2000 by default
+
+    Returns:
+        bool: whether the URL is valid
+    """
+    if not url or len(url) > max_length:
+        return False
+    try:
+        validate_url(url)
+        return True
+    except ValueError:
+        return False
+
+
+def is_url_match_target(url: str, target_name: str, target_type: str) -> bool:
+    """
+    Check whether a URL matches a target
+
+    Args:
+        url: URL string
+        target_name: target name (domain, IP or CIDR)
+        target_type: target type ('domain', 'ip', 'cidr')
+
+    Returns:
+        bool: whether it matches
+    """
+    try:
+        parsed = urlparse(url)
+        hostname = parsed.hostname
+        if not hostname:
+            return False
+
+        hostname = hostname.lower()
+        target_name = target_name.lower()
+
+        if target_type == 'domain':
+            # Domain target: hostname equals target_name or ends with .target_name
+            return hostname == target_name or hostname.endswith('.' + target_name)
+
+        elif target_type == 'ip':
+            # IP target: hostname must equal target_name exactly
+            return hostname == target_name
+
+        elif target_type == 'cidr':
+            # CIDR target: hostname must be an IP inside the CIDR range
+            try:
+                ip = ipaddress.ip_address(hostname)
+                network = ipaddress.ip_network(target_name, strict=False)
+                return ip in network
+            except ValueError:
+                # hostname is not a valid IP
+                return False
+
+        return False
+    except Exception:
+        return False
+
+
+def detect_input_type(input_str: str) -> str:
+    """
+    Detect the input type (used when parsing quick-scan input)
+
+    Args:
+        input_str: input string (expected to be stripped already)
+
+    Returns:
+        str: input type ('url', 'domain', 'ip', 'cidr')
+    """
+    if not input_str:
+        raise ValueError("Input must not be empty")
+
+    # 1. Anything containing :// is a URL
+    if '://' in input_str:
+        return 'url'
+
+    # 2. A '/' means either a CIDR or a URL missing its scheme
+    if '/' in input_str:
+        # CIDR form: IP/prefix, e.g. 10.0.0.0/8
+        parts = input_str.split('/')
+        if len(parts) == 2:
+            ip_part, prefix_part = parts
+            # If the part after the slash is a number in 0-32, check for CIDR
+            if prefix_part.isdigit() and 0 <= int(prefix_part) <= 32:
+                ip_parts = ip_part.split('.')
+                if len(ip_parts) == 4 and all(p.isdigit() for p in ip_parts):
+                    return 'cidr'
+        # Not a CIDR; treat as a URL (missing scheme, later validation rejects it)
+        return 'url'
+
+    # 3. Check whether it is an IP address
+    try:
+        ipaddress.ip_address(input_str)
+        return 'ip'
+    except ValueError:
+        pass
+
+    # 4. Fall back to domain
+    return 'domain'
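A few illustrative checks of the new helpers; the import path mirrors the file's location under apps/common/utils and is otherwise an assumption:

from apps.common.utils.validators import (  # assumed module path
    detect_input_type, is_valid_url, is_url_match_target,
)

assert detect_input_type('https://a.example.com/x') == 'url'
assert detect_input_type('10.0.0.0/8') == 'cidr'
assert detect_input_type('app.example.com/login') == 'url'  # has '/', not a CIDR
assert detect_input_type('192.168.1.5') == 'ip'
assert detect_input_type('example.com') == 'domain'

assert is_valid_url('https://example.com')
assert not is_valid_url('example.com')  # scheme is required

assert is_url_match_target('https://a.example.com/', 'example.com', 'domain')
assert is_url_match_target('http://10.1.2.3:8080/', '10.0.0.0/8', 'cidr')
assert not is_url_match_target('https://evilexample.com/', 'example.com', 'domain')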
@@ -2,11 +2,17 @@
 Common module view exports
 
 Includes:
+- Health check view: Docker health checks
 - Auth views: login, logout, user info, change password
 - System log views: live log viewing
 """
 
+from .health_views import HealthCheckView
 from .auth_views import LoginView, LogoutView, MeView, ChangePasswordView
-from .system_log_views import SystemLogsView
+from .system_log_views import SystemLogsView, SystemLogFilesView
 
-__all__ = ['LoginView', 'LogoutView', 'MeView', 'ChangePasswordView', 'SystemLogsView']
+__all__ = [
+    'HealthCheckView',
+    'LoginView', 'LogoutView', 'MeView', 'ChangePasswordView',
+    'SystemLogsView', 'SystemLogFilesView',
+]
@@ -9,7 +9,10 @@ from django.utils.decorators import method_decorator
 from rest_framework import status
 from rest_framework.views import APIView
 from rest_framework.response import Response
-from rest_framework.permissions import AllowAny, IsAuthenticated
+from rest_framework.permissions import AllowAny
+
+from apps.common.response_helpers import success_response, error_response
+from apps.common.error_codes import ErrorCodes
 
 logger = logging.getLogger(__name__)
 
@@ -28,9 +31,10 @@ class LoginView(APIView):
         password = request.data.get('password')
 
         if not username or not password:
-            return Response(
-                {'error': 'Please provide a username and password'},
-                status=status.HTTP_400_BAD_REQUEST
+            return error_response(
+                code=ErrorCodes.VALIDATION_ERROR,
+                message='Username and password are required',
+                status_code=status.HTTP_400_BAD_REQUEST
             )
 
         user = authenticate(request, username=username, password=password)
 
@@ -38,20 +42,22 @@ class LoginView(APIView):
         if user is not None:
             login(request, user)
             logger.info(f"User {username} logged in")
-            return Response({
-                'message': 'Login successful',
-                'user': {
-                    'id': user.id,
-                    'username': user.username,
-                    'isStaff': user.is_staff,
-                    'isSuperuser': user.is_superuser,
+            return success_response(
+                data={
+                    'user': {
+                        'id': user.id,
+                        'username': user.username,
+                        'isStaff': user.is_staff,
+                        'isSuperuser': user.is_superuser,
+                    }
                 }
-            })
+            )
         else:
             logger.warning(f"Login failed for user {username}: invalid username or password")
-            return Response(
-                {'error': 'Invalid username or password'},
-                status=status.HTTP_401_UNAUTHORIZED
+            return error_response(
+                code=ErrorCodes.UNAUTHORIZED,
+                message='Invalid username or password',
+                status_code=status.HTTP_401_UNAUTHORIZED
            )
 
 
@@ -79,7 +85,7 @@ class LogoutView(APIView):
             logout(request)
         else:
             logout(request)
-        return Response({'message': 'Logged out'})
+        return success_response()
 
 
 @method_decorator(csrf_exempt, name='dispatch')
@@ -100,22 +106,26 @@ class MeView(APIView):
         if user_id:
             try:
                 user = User.objects.get(pk=user_id)
-                return Response({
-                    'authenticated': True,
-                    'user': {
-                        'id': user.id,
-                        'username': user.username,
-                        'isStaff': user.is_staff,
-                        'isSuperuser': user.is_superuser,
+                return success_response(
+                    data={
+                        'authenticated': True,
+                        'user': {
+                            'id': user.id,
+                            'username': user.username,
+                            'isStaff': user.is_staff,
+                            'isSuperuser': user.is_superuser,
+                        }
                     }
-                })
+                )
             except User.DoesNotExist:
                 pass
 
-        return Response({
-            'authenticated': False,
-            'user': None
-        })
+        return success_response(
+            data={
+                'authenticated': False,
+                'user': None
+            }
+        )
 
 
 @method_decorator(csrf_exempt, name='dispatch')
@@ -124,43 +134,27 @@ class ChangePasswordView(APIView):
     Change password
     POST /api/auth/change-password/
     """
-    authentication_classes = []  # authentication disabled (bypasses CSRF)
-    permission_classes = [AllowAny]  # login state checked manually
-
     def post(self, request):
-        # Check the login state manually (read the user from the session)
-        from django.contrib.auth import get_user_model
-        User = get_user_model()
-
-        user_id = request.session.get('_auth_user_id')
-        if not user_id:
-            return Response(
-                {'error': 'Please log in first'},
-                status=status.HTTP_401_UNAUTHORIZED
-            )
-
-        try:
-            user = User.objects.get(pk=user_id)
-        except User.DoesNotExist:
-            return Response(
-                {'error': 'User does not exist'},
-                status=status.HTTP_401_UNAUTHORIZED
-            )
+        # Verified by the global permission classes; request.user is already authenticated
+        user = request.user
 
         # CamelCaseParser turns oldPassword into old_password
         old_password = request.data.get('old_password')
         new_password = request.data.get('new_password')
 
         if not old_password or not new_password:
-            return Response(
-                {'error': 'Please provide the old and new passwords'},
-                status=status.HTTP_400_BAD_REQUEST
+            return error_response(
+                code=ErrorCodes.VALIDATION_ERROR,
+                message='Old password and new password are required',
+                status_code=status.HTTP_400_BAD_REQUEST
            )
 
         if not user.check_password(old_password):
-            return Response(
-                {'error': 'Old password is incorrect'},
-                status=status.HTTP_400_BAD_REQUEST
+            return error_response(
+                code=ErrorCodes.VALIDATION_ERROR,
+                message='Old password is incorrect',
+                status_code=status.HTTP_400_BAD_REQUEST
            )
 
         user.set_password(new_password)
@@ -170,4 +164,4 @@ class ChangePasswordView(APIView):
         update_session_auth_hash(request, user)
 
         logger.info(f"User {user.username} changed their password")
-        return Response({'message': 'Password changed successfully'})
+        return success_response()
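The views above lean on success_response/error_response from apps.common.response_helpers, whose bodies are outside this diff. Inferred from the call sites, they plausibly look like this minimal sketch; the envelope field names are assumptions:

# Minimal sketch inferred from usage; the real helpers may differ in detail.
from rest_framework import status
from rest_framework.response import Response

def success_response(data=None, status_code=status.HTTP_200_OK):
    # assumed envelope shape
    return Response({'success': True, 'data': data}, status=status_code)

def error_response(code, message, status_code=status.HTTP_400_BAD_REQUEST):
    return Response(
        {'success': False, 'error': {'code': code, 'message': message}},
        status=status_code,
    )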
24  backend/apps/common/views/health_views.py  Normal file
@@ -0,0 +1,24 @@
"""
|
||||
健康检查视图
|
||||
|
||||
提供 Docker 健康检查端点,无需认证。
|
||||
"""
|
||||
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.permissions import AllowAny
|
||||
|
||||
|
||||
class HealthCheckView(APIView):
|
||||
"""
|
||||
健康检查端点
|
||||
|
||||
GET /api/health/
|
||||
|
||||
返回服务状态,用于 Docker 健康检查。
|
||||
此端点无需认证。
|
||||
"""
|
||||
permission_classes = [AllowAny]
|
||||
|
||||
def get(self, request):
|
||||
return Response({'status': 'ok'})
|
||||
@@ -9,16 +9,57 @@ import logging
 from django.utils.decorators import method_decorator
 from django.views.decorators.csrf import csrf_exempt
 from rest_framework import status
 from rest_framework.permissions import AllowAny
 from rest_framework.response import Response
 from rest_framework.views import APIView
 
+from apps.common.response_helpers import success_response, error_response
+from apps.common.error_codes import ErrorCodes
 from apps.common.services.system_log_service import SystemLogService
 
 
 logger = logging.getLogger(__name__)
 
 
+@method_decorator(csrf_exempt, name="dispatch")
+class SystemLogFilesView(APIView):
+    """
+    Log file list API view
+
+    GET /api/system/logs/files/
+    Lists every available log file
+
+    Response:
+        {
+            "files": [
+                {
+                    "filename": "xingrin.log",
+                    "category": "system",
+                    "size": 1048576,
+                    "modifiedAt": "2025-01-15T10:30:00+00:00"
+                },
+                ...
+            ]
+        }
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.service = SystemLogService()
+
+    def get(self, request):
+        """List the log files"""
+        try:
+            files = self.service.get_log_files()
+            return success_response(data={"files": files})
+        except Exception:
+            logger.exception("Failed to list log files")
+            return error_response(
+                code=ErrorCodes.SERVER_ERROR,
+                message='Failed to get log files',
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
+            )
+
+
 @method_decorator(csrf_exempt, name="dispatch")
 class SystemLogsView(APIView):
     """
@@ -28,21 +69,14 @@ class SystemLogsView(APIView):
     Returns the system log content
 
     Query Parameters:
+        file (str, optional): log file name, defaults to xingrin.log
         lines (int, optional): number of log lines to return, default 200, max 10000
 
     Response:
         {
             "content": "log content string..."
         }
-
-    Note:
-        - Anonymous access is allowed for now during development
-        - Production should add an admin permission check
     """
 
-    # TODO: switch to the IsAdminUser permission in production
-    authentication_classes = []
-    permission_classes = [AllowAny]
-
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -52,18 +86,33 @@ class SystemLogsView(APIView):
         """
        Fetch the system logs
 
-        The lines parameter controls how many lines are returned, for paging or live refresh.
+        The file and lines parameters control what is returned.
         """
         try:
-            # Parse the lines parameter
+            # Parse the parameters
+            filename = request.query_params.get("file")
             lines_raw = request.query_params.get("lines")
             lines = int(lines_raw) if lines_raw is not None else None
 
             # Fetch the log content from the service
-            content = self.service.get_logs_content(lines=lines)
-            return Response({"content": content})
-        except ValueError:
-            return Response({"error": "lines must be an integer"}, status=status.HTTP_400_BAD_REQUEST)
+            content = self.service.get_logs_content(filename=filename, lines=lines)
+            return success_response(data={"content": content})
+        except ValueError as e:
+            return error_response(
+                code=ErrorCodes.VALIDATION_ERROR,
+                message=str(e) if 'file' in str(e).lower() else 'lines must be an integer',
+                status_code=status.HTTP_400_BAD_REQUEST
+            )
+        except FileNotFoundError as e:
+            return error_response(
+                code=ErrorCodes.NOT_FOUND,
+                message=str(e),
+                status_code=status.HTTP_404_NOT_FOUND
+            )
         except Exception:
             logger.exception("Failed to fetch system logs")
-            return Response({"error": "Failed to fetch system logs"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
+            return error_response(
+                code=ErrorCodes.SERVER_ERROR,
+                message='Failed to get system logs',
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
            )
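Illustrative client calls for the two endpoints. The deployment host is an assumption, the logs path itself sits outside this hunk (assumed /api/system/logs/ from the files path), and the session cookie required by the global permissions is omitted for brevity:

import requests

BASE = 'https://xingrin.local'  # assumed deployment host

# list the available log files
files = requests.get(f'{BASE}/api/system/logs/files/').json()

# tail the last 500 lines of a specific file
logs = requests.get(
    f'{BASE}/api/system/logs/',  # assumed path
    params={'file': 'xingrin.log', 'lines': 500},
).json()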
44  backend/apps/common/websocket_auth.py  Normal file
@@ -0,0 +1,44 @@
"""
|
||||
WebSocket 认证基类
|
||||
|
||||
提供需要认证的 WebSocket Consumer 基类
|
||||
"""
|
||||
|
||||
import logging
|
||||
from channels.generic.websocket import AsyncWebsocketConsumer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthenticatedWebsocketConsumer(AsyncWebsocketConsumer):
|
||||
"""
|
||||
需要认证的 WebSocket Consumer 基类
|
||||
|
||||
子类应该重写 on_connect() 方法实现具体的连接逻辑
|
||||
"""
|
||||
|
||||
async def connect(self):
|
||||
"""
|
||||
连接时验证用户认证状态
|
||||
|
||||
未认证时使用 close(code=4001) 拒绝连接
|
||||
"""
|
||||
user = self.scope.get('user')
|
||||
|
||||
if not user or not user.is_authenticated:
|
||||
logger.warning(
|
||||
f"WebSocket 连接被拒绝:用户未认证 - Path: {self.scope.get('path')}"
|
||||
)
|
||||
await self.close(code=4001)
|
||||
return
|
||||
|
||||
# 调用子类的连接逻辑
|
||||
await self.on_connect()
|
||||
|
||||
async def on_connect(self):
|
||||
"""
|
||||
子类实现具体的连接逻辑
|
||||
|
||||
默认实现:接受连接
|
||||
"""
|
||||
await self.accept()
|
||||
@@ -6,17 +6,17 @@ import json
 import logging
 import asyncio
 import os
-from channels.generic.websocket import AsyncWebsocketConsumer
 from asgiref.sync import sync_to_async
 
 from django.conf import settings
 
+from apps.common.websocket_auth import AuthenticatedWebsocketConsumer
 from apps.engine.services import WorkerService
 
 logger = logging.getLogger(__name__)
 
 
-class WorkerDeployConsumer(AsyncWebsocketConsumer):
+class WorkerDeployConsumer(AuthenticatedWebsocketConsumer):
     """
     Worker interactive terminal WebSocket Consumer
 
@@ -31,8 +31,8 @@ class WorkerDeployConsumer(AsyncWebsocketConsumer):
         self.read_task = None
         self.worker_service = WorkerService()
 
-    async def connect(self):
-        """Join the worker's group and open the SSH connection"""
+    async def on_connect(self):
+        """Join the worker's group and open the SSH connection (already authenticated)"""
         self.worker_id = self.scope['url_route']['kwargs']['worker_id']
         self.group_name = f'worker_deploy_{self.worker_id}'
 
@@ -242,8 +242,9 @@ class WorkerDeployConsumer(AsyncWebsocketConsumer):
             return
 
         # Remote workers go through nginx HTTPS (nginx proxies to backend 8888)
-        # Use https://{PUBLIC_HOST} rather than hitting port 8888 directly
-        heartbeat_api_url = f"https://{public_host}"  # base URL; the agent appends /api/...
+        # Use https://{PUBLIC_HOST}:{PUBLIC_PORT} rather than hitting port 8888 directly
+        public_port = getattr(settings, 'PUBLIC_PORT', '8083')
+        heartbeat_api_url = f"https://{public_host}:{public_port}"
 
         session_name = f'xingrin_deploy_{self.worker_id}'
         remote_script_path = '/tmp/xingrin_deploy.sh'
@@ -15,9 +15,10 @@
 """
 
 from django.core.management.base import BaseCommand
+from io import StringIO
 from pathlib import Path
 
-import yaml
+from ruamel.yaml import YAML
 
 from apps.engine.models import ScanEngine
 
@@ -44,10 +45,12 @@ class Command(BaseCommand):
         with open(config_path, 'r', encoding='utf-8') as f:
             default_config = f.read()
 
-        # Parse the YAML into a dict, used later to build the sub-engine configs
+        # Parse with ruamel.yaml so comments are preserved
+        yaml_parser = YAML()
+        yaml_parser.preserve_quotes = True
         try:
-            config_dict = yaml.safe_load(default_config) or {}
-        except yaml.YAMLError as e:
+            config_dict = yaml_parser.load(default_config) or {}
+        except Exception as e:
             self.stdout.write(self.style.ERROR(f'Failed to parse engine config YAML: {e}'))
             return
 
@@ -83,15 +86,13 @@ class Command(BaseCommand):
         if scan_type != 'subdomain_discovery' and 'tools' not in scan_cfg:
             continue
 
-        # Build a YAML containing only the current scan type's config
+        # Build a YAML containing only the current scan type's config (keeping comments)
         single_config = {scan_type: scan_cfg}
         try:
-            single_yaml = yaml.safe_dump(
-                single_config,
-                sort_keys=False,
-                allow_unicode=True,
-            )
-        except yaml.YAMLError as e:
+            stream = StringIO()
+            yaml_parser.dump(single_config, stream)
+            single_yaml = stream.getvalue()
+        except Exception as e:
             self.stdout.write(self.style.ERROR(f'Failed to build sub-engine {scan_type} config: {e}'))
             continue
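The switch from PyYAML to ruamel.yaml exists because ruamel's round-trip mode keeps YAML comments through a load/dump cycle, which safe_load/safe_dump discard. A toy illustration (the config snippet is invented, not the real engine config):

from io import StringIO
from ruamel.yaml import YAML

text = """\
port_scan:
  # keep scanning fast
  rate: 1000
"""

parser = YAML()  # default round-trip mode preserves comments
parser.preserve_quotes = True
data = parser.load(text)

out = StringIO()
parser.dump(data, out)
print(out.getvalue())  # the '# keep scanning fast' comment survives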
205  backend/apps/engine/management/commands/init_fingerprints.py  Normal file
@@ -0,0 +1,205 @@
"""初始化内置指纹库
|
||||
|
||||
- EHole 指纹: ehole.json -> 导入到数据库
|
||||
- Goby 指纹: goby.json -> 导入到数据库
|
||||
- Wappalyzer 指纹: wappalyzer.json -> 导入到数据库
|
||||
- Fingers 指纹: fingers_http.json -> 导入到数据库
|
||||
- FingerPrintHub 指纹: fingerprinthub_web.json -> 导入到数据库
|
||||
- ARL 指纹: ARL.yaml -> 导入到数据库
|
||||
|
||||
可重复执行:如果数据库已有数据则跳过,只在空库时导入。
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from django.conf import settings
|
||||
from django.core.management.base import BaseCommand
|
||||
|
||||
from apps.engine.models import (
|
||||
EholeFingerprint,
|
||||
GobyFingerprint,
|
||||
WappalyzerFingerprint,
|
||||
FingersFingerprint,
|
||||
FingerPrintHubFingerprint,
|
||||
ARLFingerprint,
|
||||
)
|
||||
from apps.engine.services.fingerprints import (
|
||||
EholeFingerprintService,
|
||||
GobyFingerprintService,
|
||||
WappalyzerFingerprintService,
|
||||
FingersFingerprintService,
|
||||
FingerPrintHubFingerprintService,
|
||||
ARLFingerprintService,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# 内置指纹配置
|
||||
DEFAULT_FINGERPRINTS = [
|
||||
{
|
||||
"type": "ehole",
|
||||
"filename": "ehole.json",
|
||||
"model": EholeFingerprint,
|
||||
"service": EholeFingerprintService,
|
||||
"data_key": "fingerprint", # JSON 中指纹数组的 key
|
||||
"file_format": "json",
|
||||
},
|
||||
{
|
||||
"type": "goby",
|
||||
"filename": "goby.json",
|
||||
"model": GobyFingerprint,
|
||||
"service": GobyFingerprintService,
|
||||
"data_key": None, # Goby 是数组格式,直接使用整个 JSON
|
||||
"file_format": "json",
|
||||
},
|
||||
{
|
||||
"type": "wappalyzer",
|
||||
"filename": "wappalyzer.json",
|
||||
"model": WappalyzerFingerprint,
|
||||
"service": WappalyzerFingerprintService,
|
||||
"data_key": "apps", # Wappalyzer 使用 apps 对象
|
||||
"file_format": "json",
|
||||
},
|
||||
{
|
||||
"type": "fingers",
|
||||
"filename": "fingers_http.json",
|
||||
"model": FingersFingerprint,
|
||||
"service": FingersFingerprintService,
|
||||
"data_key": None, # Fingers 是数组格式
|
||||
"file_format": "json",
|
||||
},
|
||||
{
|
||||
"type": "fingerprinthub",
|
||||
"filename": "fingerprinthub_web.json",
|
||||
"model": FingerPrintHubFingerprint,
|
||||
"service": FingerPrintHubFingerprintService,
|
||||
"data_key": None, # FingerPrintHub 是数组格式
|
||||
"file_format": "json",
|
||||
},
|
||||
{
|
||||
"type": "arl",
|
||||
"filename": "ARL.yaml",
|
||||
"model": ARLFingerprint,
|
||||
"service": ARLFingerprintService,
|
||||
"data_key": None, # ARL 是 YAML 数组格式
|
||||
"file_format": "yaml",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = "初始化内置指纹库"
|
||||
|
||||
def handle(self, *args, **options):
|
||||
project_base = Path(settings.BASE_DIR).parent # /app/backend -> /app
|
||||
fingerprints_dir = project_base / "backend" / "fingerprints"
|
||||
|
||||
initialized = 0
|
||||
skipped = 0
|
||||
failed = 0
|
||||
|
||||
for item in DEFAULT_FINGERPRINTS:
|
||||
fp_type = item["type"]
|
||||
filename = item["filename"]
|
||||
model = item["model"]
|
||||
service_class = item["service"]
|
||||
data_key = item["data_key"]
|
||||
file_format = item.get("file_format", "json")
|
||||
|
||||
# 检查数据库是否已有数据
|
||||
existing_count = model.objects.count()
|
||||
if existing_count > 0:
|
||||
self.stdout.write(self.style.SUCCESS(
|
||||
f"[{fp_type}] 数据库已有 {existing_count} 条记录,跳过初始化"
|
||||
))
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# 查找源文件
|
||||
src_path = fingerprints_dir / filename
|
||||
if not src_path.exists():
|
||||
self.stdout.write(self.style.WARNING(
|
||||
f"[{fp_type}] 未找到内置指纹文件: {src_path},跳过"
|
||||
))
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
# 读取并解析文件(支持 JSON 和 YAML)
|
||||
try:
|
||||
with open(src_path, "r", encoding="utf-8") as f:
|
||||
if file_format == "yaml":
|
||||
file_data = yaml.safe_load(f)
|
||||
else:
|
||||
file_data = json.load(f)
|
||||
except (json.JSONDecodeError, yaml.YAMLError, OSError) as exc:
|
||||
self.stdout.write(self.style.ERROR(
|
||||
f"[{fp_type}] 读取指纹文件失败: {exc}"
|
||||
))
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
# 提取指纹数据(根据不同格式处理)
|
||||
fingerprints = self._extract_fingerprints(file_data, data_key, fp_type)
|
||||
if not fingerprints:
|
||||
self.stdout.write(self.style.WARNING(
|
||||
f"[{fp_type}] 指纹文件中没有有效数据,跳过"
|
||||
))
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
# 使用 Service 批量导入
|
||||
try:
|
||||
service = service_class()
|
||||
result = service.batch_create_fingerprints(fingerprints)
|
||||
created = result.get("created", 0)
|
||||
failed_count = result.get("failed", 0)
|
||||
|
||||
self.stdout.write(self.style.SUCCESS(
|
||||
f"[{fp_type}] 导入成功: 创建 {created} 条,失败 {failed_count} 条"
|
||||
))
|
||||
initialized += 1
|
||||
except Exception as exc:
|
||||
self.stdout.write(self.style.ERROR(
|
||||
f"[{fp_type}] 导入失败: {exc}"
|
||||
))
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
self.stdout.write(self.style.SUCCESS(
|
||||
f"指纹初始化完成: 成功 {initialized}, 已存在跳过 {skipped}, 失败 {failed}"
|
||||
))
|
||||
|
||||
def _extract_fingerprints(self, json_data, data_key, fp_type):
|
||||
"""
|
||||
根据不同格式提取指纹数据,兼容数组和对象两种格式
|
||||
|
||||
支持的格式:
|
||||
- 数组格式: [...] 或 {"key": [...]}
|
||||
- 对象格式: {...} 或 {"key": {...}} -> 转换为 [{"name": k, ...v}]
|
||||
"""
|
||||
# 获取目标数据
|
||||
if data_key is None:
|
||||
# 直接使用整个 JSON
|
||||
target = json_data
|
||||
else:
|
||||
# 从指定 key 获取,支持多个可能的 key(如 apps/technologies)
|
||||
if data_key == "apps":
|
||||
target = json_data.get("apps") or json_data.get("technologies") or {}
|
||||
else:
|
||||
target = json_data.get(data_key, [])
|
||||
|
||||
# 根据数据类型处理
|
||||
if isinstance(target, list):
|
||||
# 已经是数组格式,直接返回
|
||||
return target
|
||||
elif isinstance(target, dict):
|
||||
# 对象格式,转换为数组 [{"name": key, ...value}]
|
||||
return [{"name": name, **data} if isinstance(data, dict) else {"name": name}
|
||||
for name, data in target.items()]
|
||||
|
||||
return []
|
||||
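A toy illustration of the two input shapes _extract_fingerprints accepts, runnable inside a Django shell (the sample data is invented):

from apps.engine.management.commands.init_fingerprints import Command

cmd = Command()

# object form ({"apps": {name: {...}}}) is flattened to [{"name": ..., **attrs}]
wappalyzer_style = {"apps": {"Nginx": {"cats": [22]}}}
print(cmd._extract_fingerprints(wappalyzer_style, "apps", "wappalyzer"))
# [{'name': 'Nginx', 'cats': [22]}]

# array form ({"fingerprint": [...]}) passes straight through
ehole_style = {"fingerprint": [{"cms": "ThinkPHP", "keyword": ["thinkphp"]}]}
print(cmd._extract_fingerprints(ehole_style, "fingerprint", "ehole"))
# [{'cms': 'ThinkPHP', 'keyword': ['thinkphp']}]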
@@ -3,12 +3,17 @@
 Run this command after installation to create the official template repo records.
 
 Usage:
-    python manage.py init_nuclei_templates         # records only
+    python manage.py init_nuclei_templates         # records only (detects a pre-existing local repo)
     python manage.py init_nuclei_templates --sync  # create and sync (git clone)
 """
 
 import logging
+import subprocess
+from pathlib import Path
 
+from django.conf import settings
 from django.core.management.base import BaseCommand
 from django.utils import timezone
 
 from apps.engine.models import NucleiTemplateRepo
 from apps.engine.services import NucleiTemplateRepoService
@@ -26,6 +31,20 @@ DEFAULT_REPOS = [
 ]
 
 
+def get_local_commit_hash(local_path: Path) -> str:
+    """Return the commit hash of a local Git repository"""
+    if not (local_path / ".git").is_dir():
+        return ""
+    result = subprocess.run(
+        ["git", "-C", str(local_path), "rev-parse", "HEAD"],
+        check=False,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        text=True,
+    )
+    return result.stdout.strip() if result.returncode == 0 else ""
+
+
 class Command(BaseCommand):
     help = "Initialize the Nuclei template repositories (creates the official repo records)"
 
@@ -46,6 +65,8 @@ class Command(BaseCommand):
         force = options.get("force", False)
 
         service = NucleiTemplateRepoService()
+        base_dir = Path(getattr(settings, "NUCLEI_TEMPLATES_REPOS_BASE_DIR", "/opt/xingrin/nuclei-repos"))
+
         created = 0
         skipped = 0
         synced = 0
@@ -87,20 +108,30 @@ class Command(BaseCommand):
 
         # Create the new repository record
         try:
+            # Check for a pre-existing local repo (pre-downloaded by install.sh)
+            local_path = base_dir / name
+            local_commit = get_local_commit_hash(local_path)
+
             repo = NucleiTemplateRepo.objects.create(
                 name=name,
                 repo_url=repo_url,
+                local_path=str(local_path) if local_commit else "",
+                commit_hash=local_commit,
+                last_synced_at=timezone.now() if local_commit else None,
            )
-            self.stdout.write(self.style.SUCCESS(
-                f"[{name}] created: id={repo.id}"
-            ))
+
+            if local_commit:
+                self.stdout.write(self.style.SUCCESS(
+                    f"[{name}] created (local repo detected): commit={local_commit[:8]}"
+                ))
+            else:
+                self.stdout.write(self.style.SUCCESS(
+                    f"[{name}] created: id={repo.id}"
+                ))
             created += 1
 
             # Initialize the local path
             service.ensure_local_path(repo)
 
-            # Sync if requested
-            if do_sync:
+            # Sync only when no local repo exists and syncing was requested
+            if not local_commit and do_sync:
                 try:
                     self.stdout.write(self.style.WARNING(
                         f"[{name}] syncing (the first run may take a few minutes)..."
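A quick check of the pre-downloaded-repo detection. The base path is the settings default from the hunk above, and the repo directory name is an assumption (DEFAULT_REPOS is elided in this diff):

from pathlib import Path
from apps.engine.management.commands.init_nuclei_templates import get_local_commit_hash

repo_dir = Path('/opt/xingrin/nuclei-repos/nuclei-templates')  # assumed repo name
commit = get_local_commit_hash(repo_dir)
if commit:
    print(f'local clone found at {commit[:8]}')  # record will be created as already synced
else:
    print('no local clone; --sync would git clone it')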
@@ -1,7 +1,8 @@
 """Initialize every built-in wordlist record
 
-- default directory-scan wordlist: dir_default.txt -> /app/backend/wordlist/dir_default.txt
-- default subdomain brute-force wordlist: subdomains-top1million-110000.txt -> /app/backend/wordlist/subdomains-top1million-110000.txt
+Built-in wordlists are copied from /app/backend/wordlist/ inside the image to the runtime directory /opt/xingrin/wordlists/:
+- default directory-scan wordlist: dir_default.txt
+- default subdomain brute-force wordlist: subdomains-top1million-110000.txt
 
 Safe to run repeatedly: records that exist with a valid file are skipped; creation/repair happens only when missing or the file is gone.
 """
@@ -13,7 +14,7 @@ from pathlib import Path
 from django.conf import settings
 from django.core.management.base import BaseCommand
 
-from apps.common.hash_utils import safe_calc_file_sha256
+from apps.common.utils import safe_calc_file_sha256
 from apps.engine.models import Wordlist
213  backend/apps/engine/migrations/0001_initial.py  Normal file
@@ -0,0 +1,213 @@
# Generated by Django 5.2.7 on 2026-01-02 04:45

from django.db import migrations, models


class Migration(migrations.Migration):

    initial = True

    dependencies = [
    ]

    operations = [
        migrations.CreateModel(
            name='NucleiTemplateRepo',
            fields=[
                ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('name', models.CharField(help_text='Repository name, used for display and config references', max_length=200, unique=True)),
                ('repo_url', models.CharField(help_text='Git repository URL', max_length=500)),
                ('local_path', models.CharField(blank=True, default='', help_text='Absolute path of the local working directory', max_length=500)),
                ('commit_hash', models.CharField(blank=True, default='', help_text='Last synced Git commit hash, used for worker version checks', max_length=40)),
                ('last_synced_at', models.DateTimeField(blank=True, help_text='Time of the last successful sync', null=True)),
                ('created_at', models.DateTimeField(auto_now_add=True, help_text='Creation time')),
                ('updated_at', models.DateTimeField(auto_now=True, help_text='Update time')),
            ],
            options={
                'verbose_name': 'Nuclei template repository',
                'verbose_name_plural': 'Nuclei template repositories',
                'db_table': 'nuclei_template_repo',
            },
        ),
        migrations.CreateModel(
            name='ARLFingerprint',
            fields=[
                ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('name', models.CharField(help_text='Fingerprint name', max_length=300, unique=True)),
                ('rule', models.TextField(help_text='Matching rule expression')),
                ('created_at', models.DateTimeField(auto_now_add=True)),
            ],
            options={
                'verbose_name': 'ARL fingerprint',
                'verbose_name_plural': 'ARL fingerprints',
                'db_table': 'arl_fingerprint',
                'ordering': ['-created_at'],
                'indexes': [models.Index(fields=['name'], name='arl_fingerp_name_c3a305_idx'), models.Index(fields=['-created_at'], name='arl_fingerp_created_ed1060_idx')],
            },
        ),
        migrations.CreateModel(
            name='EholeFingerprint',
            fields=[
                ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('cms', models.CharField(help_text='Product/CMS name', max_length=200)),
                ('method', models.CharField(default='keyword', help_text='Match method', max_length=200)),
                ('location', models.CharField(default='body', help_text='Match location', max_length=200)),
                ('keyword', models.JSONField(default=list, help_text='Keyword list')),
                ('is_important', models.BooleanField(default=False, help_text='Whether this is a key asset')),
                ('type', models.CharField(blank=True, default='-', help_text='Category', max_length=100)),
                ('created_at', models.DateTimeField(auto_now_add=True)),
            ],
            options={
                'verbose_name': 'EHole fingerprint',
                'verbose_name_plural': 'EHole fingerprints',
                'db_table': 'ehole_fingerprint',
                'ordering': ['-created_at'],
                'indexes': [models.Index(fields=['cms'], name='ehole_finge_cms_72ca2c_idx'), models.Index(fields=['method'], name='ehole_finge_method_17f0db_idx'), models.Index(fields=['location'], name='ehole_finge_locatio_7bb82b_idx'), models.Index(fields=['type'], name='ehole_finge_type_ca2bce_idx'), models.Index(fields=['is_important'], name='ehole_finge_is_impo_d56e64_idx'), models.Index(fields=['-created_at'], name='ehole_finge_created_d862b0_idx')],
                'constraints': [models.UniqueConstraint(fields=('cms', 'method', 'location'), name='unique_ehole_fingerprint')],
            },
        ),
        migrations.CreateModel(
            name='FingerPrintHubFingerprint',
            fields=[
                ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('fp_id', models.CharField(help_text='Fingerprint ID', max_length=200, unique=True)),
                ('name', models.CharField(help_text='Fingerprint name', max_length=300)),
                ('author', models.CharField(blank=True, default='', help_text='Author', max_length=200)),
                ('tags', models.CharField(blank=True, default='', help_text='Tags', max_length=500)),
                ('severity', models.CharField(blank=True, default='info', help_text='Severity', max_length=50)),
                ('metadata', models.JSONField(blank=True, default=dict, help_text='Metadata')),
                ('http', models.JSONField(default=list, help_text='HTTP matching rules')),
                ('source_file', models.CharField(blank=True, default='', help_text='Source file', max_length=500)),
                ('created_at', models.DateTimeField(auto_now_add=True)),
            ],
            options={
                'verbose_name': 'FingerPrintHub fingerprint',
                'verbose_name_plural': 'FingerPrintHub fingerprints',
                'db_table': 'fingerprinthub_fingerprint',
                'ordering': ['-created_at'],
                'indexes': [models.Index(fields=['fp_id'], name='fingerprint_fp_id_df467f_idx'), models.Index(fields=['name'], name='fingerprint_name_95b6fb_idx'), models.Index(fields=['author'], name='fingerprint_author_80f54b_idx'), models.Index(fields=['severity'], name='fingerprint_severit_f70422_idx'), models.Index(fields=['-created_at'], name='fingerprint_created_bec16c_idx')],
            },
        ),
        migrations.CreateModel(
            name='FingersFingerprint',
            fields=[
                ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('name', models.CharField(help_text='Fingerprint name', max_length=300, unique=True)),
                ('link', models.URLField(blank=True, default='', help_text='Related link', max_length=500)),
                ('rule', models.JSONField(default=list, help_text='Matching rule array')),
                ('tag', models.JSONField(default=list, help_text='Tag array')),
                ('focus', models.BooleanField(default=False, help_text='Whether to prioritise')),
                ('default_port', models.JSONField(blank=True, default=list, help_text='Default port array')),
                ('created_at', models.DateTimeField(auto_now_add=True)),
            ],
            options={
                'verbose_name': 'Fingers fingerprint',
                'verbose_name_plural': 'Fingers fingerprints',
                'db_table': 'fingers_fingerprint',
                'ordering': ['-created_at'],
                'indexes': [models.Index(fields=['name'], name='fingers_fin_name_952de0_idx'), models.Index(fields=['link'], name='fingers_fin_link_4c6b7f_idx'), models.Index(fields=['focus'], name='fingers_fin_focus_568c7f_idx'), models.Index(fields=['-created_at'], name='fingers_fin_created_46fc91_idx')],
            },
        ),
        migrations.CreateModel(
            name='GobyFingerprint',
            fields=[
                ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('name', models.CharField(help_text='Product name', max_length=300, unique=True)),
                ('logic', models.CharField(help_text='Logic expression', max_length=500)),
                ('rule', models.JSONField(default=list, help_text='Rule array')),
                ('created_at', models.DateTimeField(auto_now_add=True)),
            ],
            options={
                'verbose_name': 'Goby fingerprint',
                'verbose_name_plural': 'Goby fingerprints',
                'db_table': 'goby_fingerprint',
                'ordering': ['-created_at'],
                'indexes': [models.Index(fields=['name'], name='goby_finger_name_82084c_idx'), models.Index(fields=['logic'], name='goby_finger_logic_a63226_idx'), models.Index(fields=['-created_at'], name='goby_finger_created_50e000_idx')],
            },
        ),
        migrations.CreateModel(
            name='ScanEngine',
            fields=[
                ('id', models.AutoField(primary_key=True, serialize=False)),
                ('name', models.CharField(help_text='Engine name', max_length=200, unique=True)),
                ('configuration', models.CharField(blank=True, default='', help_text='Engine configuration in YAML format', max_length=10000)),
                ('created_at', models.DateTimeField(auto_now_add=True, help_text='Creation time')),
                ('updated_at', models.DateTimeField(auto_now=True, help_text='Update time')),
            ],
            options={
                'verbose_name': 'Scan engine',
                'verbose_name_plural': 'Scan engines',
                'db_table': 'scan_engine',
                'ordering': ['-created_at'],
                'indexes': [models.Index(fields=['-created_at'], name='scan_engine_created_da4870_idx')],
            },
        ),
        migrations.CreateModel(
            name='WappalyzerFingerprint',
            fields=[
                ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('name', models.CharField(help_text='Application name', max_length=300, unique=True)),
                ('cats', models.JSONField(default=list, help_text='Category ID array')),
                ('cookies', models.JSONField(blank=True, default=dict, help_text='Cookie detection rules')),
                ('headers', models.JSONField(blank=True, default=dict, help_text='HTTP header detection rules')),
                ('script_src', models.JSONField(blank=True, default=list, help_text='Script URL regex array')),
                ('js', models.JSONField(blank=True, default=list, help_text='JavaScript variable detection rules')),
                ('implies', models.JSONField(blank=True, default=list, help_text='Implied-technology array')),
                ('meta', models.JSONField(blank=True, default=dict, help_text='HTML meta tag detection rules')),
                ('html', models.JSONField(blank=True, default=list, help_text='HTML content regex array')),
                ('description', models.TextField(blank=True, default='', help_text='Application description')),
                ('website', models.URLField(blank=True, default='', help_text='Official website', max_length=500)),
                ('cpe', models.CharField(blank=True, default='', help_text='CPE identifier', max_length=300)),
                ('created_at', models.DateTimeField(auto_now_add=True)),
            ],
            options={
                'verbose_name': 'Wappalyzer fingerprint',
                'verbose_name_plural': 'Wappalyzer fingerprints',
                'db_table': 'wappalyzer_fingerprint',
                'ordering': ['-created_at'],
                'indexes': [models.Index(fields=['name'], name='wappalyzer__name_63c669_idx'), models.Index(fields=['website'], name='wappalyzer__website_88de1c_idx'), models.Index(fields=['cpe'], name='wappalyzer__cpe_30c761_idx'), models.Index(fields=['-created_at'], name='wappalyzer__created_8e6c21_idx')],
            },
        ),
        migrations.CreateModel(
            name='Wordlist',
            fields=[
                ('id', models.AutoField(primary_key=True, serialize=False)),
                ('name', models.CharField(help_text='Wordlist name, unique', max_length=200, unique=True)),
                ('description', models.CharField(blank=True, default='', help_text='Wordlist description', max_length=200)),
                ('file_path', models.CharField(help_text='Absolute path of the wordlist file stored on the backend', max_length=500)),
                ('file_size', models.BigIntegerField(default=0, help_text='File size in bytes')),
                ('line_count', models.IntegerField(default=0, help_text='Number of lines')),
                ('file_hash', models.CharField(blank=True, default='', help_text='SHA-256 hash of the file, used for cache validation', max_length=64)),
                ('created_at', models.DateTimeField(auto_now_add=True, help_text='Creation time')),
                ('updated_at', models.DateTimeField(auto_now=True, help_text='Update time')),
            ],
            options={
                'verbose_name': 'Wordlist file',
                'verbose_name_plural': 'Wordlist files',
                'db_table': 'wordlist',
                'ordering': ['-created_at'],
                'indexes': [models.Index(fields=['-created_at'], name='wordlist_created_4afb02_idx')],
            },
        ),
        migrations.CreateModel(
            name='WorkerNode',
            fields=[
                ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('name', models.CharField(help_text='Node name', max_length=100)),
                ('ip_address', models.GenericIPAddressField(help_text='IP address (127.0.0.1 for the local node)')),
                ('ssh_port', models.IntegerField(default=22, help_text='SSH port')),
                ('username', models.CharField(default='root', help_text='SSH username', max_length=50)),
                ('password', models.CharField(blank=True, default='', help_text='SSH password', max_length=200)),
                ('is_local', models.BooleanField(default=False, help_text='Whether this is the local node (inside the Docker container)')),
                ('status', models.CharField(choices=[('pending', 'Pending deployment'), ('deploying', 'Deploying'), ('online', 'Online'), ('offline', 'Offline'), ('updating', 'Updating'), ('outdated', 'Version too old')], default='pending', help_text='Status: pending/deploying/online/offline', max_length=20)),
                ('created_at', models.DateTimeField(auto_now_add=True)),
                ('updated_at', models.DateTimeField(auto_now=True)),
            ],
            options={
                'verbose_name': 'Worker node',
                'db_table': 'worker_node',
                'ordering': ['-created_at'],
                'constraints': [models.UniqueConstraint(condition=models.Q(('is_local', False)), fields=('ip_address',), name='unique_remote_worker_ip'), models.UniqueConstraint(fields=('name',), name='unique_worker_name')],
            },
        ),
    ]
29  backend/apps/engine/models/__init__.py  Normal file
@@ -0,0 +1,29 @@
"""Engine Models
|
||||
|
||||
导出所有 Engine 模块的 Models
|
||||
"""
|
||||
|
||||
from .engine import WorkerNode, ScanEngine, Wordlist, NucleiTemplateRepo
|
||||
from .fingerprints import (
|
||||
EholeFingerprint,
|
||||
GobyFingerprint,
|
||||
WappalyzerFingerprint,
|
||||
FingersFingerprint,
|
||||
FingerPrintHubFingerprint,
|
||||
ARLFingerprint,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 核心 Models
|
||||
"WorkerNode",
|
||||
"ScanEngine",
|
||||
"Wordlist",
|
||||
"NucleiTemplateRepo",
|
||||
# 指纹 Models
|
||||
"EholeFingerprint",
|
||||
"GobyFingerprint",
|
||||
"WappalyzerFingerprint",
|
||||
"FingersFingerprint",
|
||||
"FingerPrintHubFingerprint",
|
||||
"ARLFingerprint",
|
||||
]
|
||||
@@ -1,3 +1,8 @@
"""Engine module core models

Contains WorkerNode, ScanEngine, Wordlist, NucleiTemplateRepo
"""

from django.db import models

@@ -78,6 +83,7 @@ class ScanEngine(models.Model):
         indexes = [
             models.Index(fields=['-created_at']),
         ]
 
+    def __str__(self):
+        return str(self.name or f'ScanEngine {self.id}')
195  backend/apps/engine/models/fingerprints.py  Normal file
@@ -0,0 +1,195 @@
"""指纹相关 Models
|
||||
|
||||
包含 EHole、Goby、Wappalyzer 等指纹格式的数据模型
|
||||
"""
|
||||
|
||||
from django.db import models
|
||||
|
||||
|
||||
class GobyFingerprint(models.Model):
|
||||
"""Goby 格式指纹规则
|
||||
|
||||
Goby 使用逻辑表达式和规则数组进行匹配:
|
||||
- logic: 逻辑表达式,如 "a||b", "(a&&b)||c"
|
||||
- rule: 规则数组,每条规则包含 label, feature, is_equal
|
||||
"""
|
||||
|
||||
name = models.CharField(max_length=300, unique=True, help_text='产品名称')
|
||||
logic = models.CharField(max_length=500, help_text='逻辑表达式')
|
||||
rule = models.JSONField(default=list, help_text='规则数组')
|
||||
created_at = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
class Meta:
|
||||
db_table = 'goby_fingerprint'
|
||||
verbose_name = 'Goby 指纹'
|
||||
verbose_name_plural = 'Goby 指纹'
|
||||
ordering = ['-created_at']
|
||||
indexes = [
|
||||
models.Index(fields=['name']),
|
||||
models.Index(fields=['logic']),
|
||||
models.Index(fields=['-created_at']),
|
||||
]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.name} ({self.logic})"
|
||||
|
||||
|
||||
class EholeFingerprint(models.Model):
|
||||
"""EHole 格式指纹规则(字段与 ehole.json 一致)"""
|
||||
|
||||
cms = models.CharField(max_length=200, help_text='产品/CMS名称')
|
||||
method = models.CharField(max_length=200, default='keyword', help_text='匹配方式')
|
||||
location = models.CharField(max_length=200, default='body', help_text='匹配位置')
|
||||
keyword = models.JSONField(default=list, help_text='关键词列表')
|
||||
is_important = models.BooleanField(default=False, help_text='是否重点资产')
|
||||
type = models.CharField(max_length=100, blank=True, default='-', help_text='分类')
|
||||
created_at = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
class Meta:
|
||||
db_table = 'ehole_fingerprint'
|
||||
verbose_name = 'EHole 指纹'
|
||||
verbose_name_plural = 'EHole 指纹'
|
||||
ordering = ['-created_at']
|
||||
indexes = [
|
||||
# 搜索过滤字段索引
|
||||
models.Index(fields=['cms']),
|
||||
models.Index(fields=['method']),
|
||||
models.Index(fields=['location']),
|
||||
models.Index(fields=['type']),
|
||||
models.Index(fields=['is_important']),
|
||||
# 排序字段索引
|
||||
models.Index(fields=['-created_at']),
|
||||
]
|
||||
constraints = [
|
||||
# 唯一约束:cms + method + location 组合不能重复
|
||||
models.UniqueConstraint(
|
||||
fields=['cms', 'method', 'location'],
|
||||
name='unique_ehole_fingerprint'
|
||||
),
|
||||
]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.cms} ({self.method}@{self.location})"
|
||||
|
||||
|
||||
class WappalyzerFingerprint(models.Model):
|
||||
"""Wappalyzer 格式指纹规则
|
||||
|
||||
Wappalyzer 支持多种检测方式:cookies, headers, scriptSrc, js, meta, html 等
|
||||
"""
|
||||
|
||||
name = models.CharField(max_length=300, unique=True, help_text='应用名称')
|
||||
cats = models.JSONField(default=list, help_text='分类 ID 数组')
|
||||
cookies = models.JSONField(default=dict, blank=True, help_text='Cookie 检测规则')
|
||||
headers = models.JSONField(default=dict, blank=True, help_text='HTTP Header 检测规则')
|
||||
script_src = models.JSONField(default=list, blank=True, help_text='脚本 URL 正则数组')
|
||||
js = models.JSONField(default=list, blank=True, help_text='JavaScript 变量检测规则')
|
||||
implies = models.JSONField(default=list, blank=True, help_text='依赖关系数组')
|
||||
meta = models.JSONField(default=dict, blank=True, help_text='HTML meta 标签检测规则')
|
||||
html = models.JSONField(default=list, blank=True, help_text='HTML 内容正则数组')
|
||||
description = models.TextField(blank=True, default='', help_text='应用描述')
|
||||
website = models.URLField(max_length=500, blank=True, default='', help_text='官网链接')
|
||||
cpe = models.CharField(max_length=300, blank=True, default='', help_text='CPE 标识符')
|
||||
created_at = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
class Meta:
|
||||
db_table = 'wappalyzer_fingerprint'
|
||||
verbose_name = 'Wappalyzer 指纹'
|
||||
verbose_name_plural = 'Wappalyzer 指纹'
|
||||
ordering = ['-created_at']
|
||||
indexes = [
|
||||
models.Index(fields=['name']),
|
||||
models.Index(fields=['website']),
|
||||
models.Index(fields=['cpe']),
|
||||
models.Index(fields=['-created_at']),
|
||||
]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.name}"
|
||||
|
||||
|
||||
class FingersFingerprint(models.Model):
|
||||
"""Fingers 格式指纹规则 (fingers_http.json)
|
||||
|
||||
使用正则表达式和标签进行匹配,支持 favicon hash、header、body 等多种检测方式
|
||||
"""
|
||||
|
||||
name = models.CharField(max_length=300, unique=True, help_text='指纹名称')
|
||||
link = models.URLField(max_length=500, blank=True, default='', help_text='相关链接')
|
||||
rule = models.JSONField(default=list, help_text='匹配规则数组')
|
||||
tag = models.JSONField(default=list, help_text='标签数组')
|
||||
focus = models.BooleanField(default=False, help_text='是否重点关注')
|
||||
default_port = models.JSONField(default=list, blank=True, help_text='默认端口数组')
|
||||
created_at = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
class Meta:
|
||||
db_table = 'fingers_fingerprint'
|
||||
verbose_name = 'Fingers 指纹'
|
||||
verbose_name_plural = 'Fingers 指纹'
|
||||
ordering = ['-created_at']
|
||||
indexes = [
|
||||
models.Index(fields=['name']),
|
||||
models.Index(fields=['link']),
|
||||
models.Index(fields=['focus']),
|
||||
models.Index(fields=['-created_at']),
|
||||
]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.name}"
|
||||
|
||||
|
||||
class FingerPrintHubFingerprint(models.Model):
|
||||
"""FingerPrintHub 格式指纹规则 (fingerprinthub_web.json)
|
||||
|
||||
基于 nuclei 模板格式,使用 HTTP 请求和响应特征进行匹配
|
||||
"""
|
||||
|
||||
fp_id = models.CharField(max_length=200, unique=True, help_text='指纹ID')
|
||||
name = models.CharField(max_length=300, help_text='指纹名称')
|
||||
author = models.CharField(max_length=200, blank=True, default='', help_text='作者')
|
||||
tags = models.CharField(max_length=500, blank=True, default='', help_text='标签')
|
||||
severity = models.CharField(max_length=50, blank=True, default='info', help_text='严重程度')
|
||||
metadata = models.JSONField(default=dict, blank=True, help_text='元数据')
|
||||
http = models.JSONField(default=list, help_text='HTTP 匹配规则')
|
||||
source_file = models.CharField(max_length=500, blank=True, default='', help_text='来源文件')
|
||||
created_at = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
class Meta:
|
||||
db_table = 'fingerprinthub_fingerprint'
|
||||
verbose_name = 'FingerPrintHub 指纹'
|
||||
verbose_name_plural = 'FingerPrintHub 指纹'
|
||||
ordering = ['-created_at']
|
||||
indexes = [
|
||||
models.Index(fields=['fp_id']),
|
||||
models.Index(fields=['name']),
|
||||
models.Index(fields=['author']),
|
||||
models.Index(fields=['severity']),
|
||||
models.Index(fields=['-created_at']),
|
||||
]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.name} ({self.fp_id})"
|
||||
|
||||
|
||||
class ARLFingerprint(models.Model):
|
||||
"""ARL 格式指纹规则 (ARL.yaml)
|
||||
|
||||
使用简单的 name + rule 表达式格式
|
||||
"""
|
||||
|
||||
name = models.CharField(max_length=300, unique=True, help_text='指纹名称')
|
||||
rule = models.TextField(help_text='匹配规则表达式')
|
||||
created_at = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
class Meta:
|
||||
db_table = 'arl_fingerprint'
|
||||
verbose_name = 'ARL 指纹'
|
||||
verbose_name_plural = 'ARL 指纹'
|
||||
ordering = ['-created_at']
|
||||
indexes = [
|
||||
models.Index(fields=['name']),
|
||||
models.Index(fields=['-created_at']),
|
||||
]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.name}"
|
||||
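As an aside, the logic expression format documented on GobyFingerprint ("a||b", "(a&&b)||c") can be evaluated mechanically once each rule label has a match result. The sketch below is illustrative only; the evaluate_goby_logic helper (and its eval-based shortcut, which a real parser should avoid) is not part of this changeset:

import re

def evaluate_goby_logic(logic: str, results: dict) -> bool:
    # Translate Goby's operators into Python boolean syntax.
    expr = logic.replace('&&', ' and ').replace('||', ' or ')
    # Replace each rule label with its boolean match result,
    # skipping the Python keywords we just introduced.
    expr = re.sub(
        r'\b(?!and\b|or\b|True\b|False\b)([a-zA-Z_]\w*)\b',
        lambda m: str(results.get(m.group(1), False)),
        expr,
    )
    return eval(expr)  # acceptable for a sketch; do not use eval on untrusted input

print(evaluate_goby_logic("(a&&b)||c", {"a": True, "b": False, "c": True}))  # True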
@@ -88,6 +88,8 @@ def _register_scheduled_jobs(scheduler: BackgroundScheduler):
        replace_existing=True,
    )
    logger.info(" - 已注册: 扫描结果清理(每天 03:00)")

    # Note: the search materialized-view refresh has been migrated to pg_ivm incremental maintenance, so no scheduled job is needed


def _trigger_scheduled_scans():

backend/apps/engine/serializers/fingerprints/__init__.py (new file, 20 lines)
@@ -0,0 +1,20 @@
"""Fingerprint management serializers.

Exports all fingerprint-related Serializer classes.
"""

from .ehole import EholeFingerprintSerializer
from .goby import GobyFingerprintSerializer
from .wappalyzer import WappalyzerFingerprintSerializer
from .fingers import FingersFingerprintSerializer
from .fingerprinthub import FingerPrintHubFingerprintSerializer
from .arl import ARLFingerprintSerializer

__all__ = [
    "EholeFingerprintSerializer",
    "GobyFingerprintSerializer",
    "WappalyzerFingerprintSerializer",
    "FingersFingerprintSerializer",
    "FingerPrintHubFingerprintSerializer",
    "ARLFingerprintSerializer",
]
backend/apps/engine/serializers/fingerprints/arl.py (new file, 31 lines)
@@ -0,0 +1,31 @@
"""ARL fingerprint serializer."""

from rest_framework import serializers

from apps.engine.models import ARLFingerprint


class ARLFingerprintSerializer(serializers.ModelSerializer):
    """ARL fingerprint serializer.

    Field mapping:
    - name: fingerprint name (required, unique)
    - rule: match-rule expression (required)
    """

    class Meta:
        model = ARLFingerprint
        fields = ['id', 'name', 'rule', 'created_at']
        read_only_fields = ['id', 'created_at']

    def validate_name(self, value):
        """Validate the name field."""
        if not value or not value.strip():
            raise serializers.ValidationError("name 字段不能为空")
        return value.strip()

    def validate_rule(self, value):
        """Validate the rule field."""
        if not value or not value.strip():
            raise serializers.ValidationError("rule 字段不能为空")
        return value.strip()
backend/apps/engine/serializers/fingerprints/ehole.py (new file, 27 lines)
@@ -0,0 +1,27 @@
"""EHole fingerprint serializer."""

from rest_framework import serializers

from apps.engine.models import EholeFingerprint


class EholeFingerprintSerializer(serializers.ModelSerializer):
    """EHole fingerprint serializer."""

    class Meta:
        model = EholeFingerprint
        fields = ['id', 'cms', 'method', 'location', 'keyword',
                  'is_important', 'type', 'created_at']
        read_only_fields = ['id', 'created_at']

    def validate_cms(self, value):
        """Validate the cms field."""
        if not value or not value.strip():
            raise serializers.ValidationError("cms 字段不能为空")
        return value.strip()

    def validate_keyword(self, value):
        """Validate the keyword field."""
        if not isinstance(value, list):
            raise serializers.ValidationError("keyword 必须是数组")
        return value
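For concreteness, a payload that would pass this serializer's validation might look like the following; the field names come from the serializer above, while the values are invented for illustration:

# Hypothetical example payload for EholeFingerprintSerializer.
payload = {
    "cms": "WordPress",
    "method": "keyword",
    "location": "body",
    "keyword": ["wp-content", "wp-includes"],  # must be a list
    "is_important": True,
    "type": "CMS",
}
# serializer = EholeFingerprintSerializer(data=payload)
# serializer.is_valid(raise_exception=True)  # validate_cms strips surrounding whitespace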
backend/apps/engine/serializers/fingerprints/fingerprinthub.py (new file, 50 lines)
@@ -0,0 +1,50 @@
"""FingerPrintHub fingerprint serializer."""

from rest_framework import serializers

from apps.engine.models import FingerPrintHubFingerprint


class FingerPrintHubFingerprintSerializer(serializers.ModelSerializer):
    """FingerPrintHub fingerprint serializer.

    Field mapping:
    - fp_id: fingerprint ID (required, unique)
    - name: fingerprint name (required)
    - author: author (optional)
    - tags: tag string (optional)
    - severity: severity (optional, default 'info')
    - metadata: metadata JSON (optional)
    - http: HTTP match-rule array (required)
    - source_file: source file (optional)
    """

    class Meta:
        model = FingerPrintHubFingerprint
        fields = ['id', 'fp_id', 'name', 'author', 'tags', 'severity',
                  'metadata', 'http', 'source_file', 'created_at']
        read_only_fields = ['id', 'created_at']

    def validate_fp_id(self, value):
        """Validate the fp_id field."""
        if not value or not value.strip():
            raise serializers.ValidationError("fp_id 字段不能为空")
        return value.strip()

    def validate_name(self, value):
        """Validate the name field."""
        if not value or not value.strip():
            raise serializers.ValidationError("name 字段不能为空")
        return value.strip()

    def validate_http(self, value):
        """Validate the http field."""
        if not isinstance(value, list):
            raise serializers.ValidationError("http 必须是数组")
        return value

    def validate_metadata(self, value):
        """Validate the metadata field."""
        if not isinstance(value, dict):
            raise serializers.ValidationError("metadata 必须是对象")
        return value
backend/apps/engine/serializers/fingerprints/fingers.py (new file, 48 lines)
@@ -0,0 +1,48 @@
"""Fingers fingerprint serializer."""

from rest_framework import serializers

from apps.engine.models import FingersFingerprint


class FingersFingerprintSerializer(serializers.ModelSerializer):
    """Fingers fingerprint serializer.

    Field mapping:
    - name: fingerprint name (required, unique)
    - link: related link (optional)
    - rule: match-rule array (required)
    - tag: tag array (optional)
    - focus: whether the asset deserves attention (optional, default False)
    - default_port: default port array (optional)
    """

    class Meta:
        model = FingersFingerprint
        fields = ['id', 'name', 'link', 'rule', 'tag', 'focus',
                  'default_port', 'created_at']
        read_only_fields = ['id', 'created_at']

    def validate_name(self, value):
        """Validate the name field."""
        if not value or not value.strip():
            raise serializers.ValidationError("name 字段不能为空")
        return value.strip()

    def validate_rule(self, value):
        """Validate the rule field."""
        if not isinstance(value, list):
            raise serializers.ValidationError("rule 必须是数组")
        return value

    def validate_tag(self, value):
        """Validate the tag field."""
        if not isinstance(value, list):
            raise serializers.ValidationError("tag 必须是数组")
        return value

    def validate_default_port(self, value):
        """Validate the default_port field."""
        if not isinstance(value, list):
            raise serializers.ValidationError("default_port 必须是数组")
        return value
backend/apps/engine/serializers/fingerprints/goby.py (new file, 26 lines)
@@ -0,0 +1,26 @@
"""Goby fingerprint serializer."""

from rest_framework import serializers

from apps.engine.models import GobyFingerprint


class GobyFingerprintSerializer(serializers.ModelSerializer):
    """Goby fingerprint serializer."""

    class Meta:
        model = GobyFingerprint
        fields = ['id', 'name', 'logic', 'rule', 'created_at']
        read_only_fields = ['id', 'created_at']

    def validate_name(self, value):
        """Validate the name field."""
        if not value or not value.strip():
            raise serializers.ValidationError("name 字段不能为空")
        return value.strip()

    def validate_rule(self, value):
        """Validate the rule field."""
        if not isinstance(value, list):
            raise serializers.ValidationError("rule 必须是数组")
        return value
backend/apps/engine/serializers/fingerprints/wappalyzer.py (new file, 24 lines)
@@ -0,0 +1,24 @@
"""Wappalyzer fingerprint serializer."""

from rest_framework import serializers

from apps.engine.models import WappalyzerFingerprint


class WappalyzerFingerprintSerializer(serializers.ModelSerializer):
    """Wappalyzer fingerprint serializer."""

    class Meta:
        model = WappalyzerFingerprint
        fields = [
            'id', 'name', 'cats', 'cookies', 'headers', 'script_src',
            'js', 'implies', 'meta', 'html', 'description', 'website',
            'cpe', 'created_at'
        ]
        read_only_fields = ['id', 'created_at']

    def validate_name(self, value):
        """Validate the name field."""
        if not value or not value.strip():
            raise serializers.ValidationError("name 字段不能为空")
        return value.strip()
@@ -66,6 +66,7 @@ def get_start_agent_script(
    # Substitute the variables
    script = script.replace("{{HEARTBEAT_API_URL}}", heartbeat_api_url or '')
    script = script.replace("{{WORKER_ID}}", str(worker_id) if worker_id else '')
    script = script.replace("{{WORKER_API_KEY}}", getattr(settings, 'WORKER_API_KEY', ''))

    # Inject the image version configuration (so remote nodes use the same version)
    docker_user = getattr(settings, 'DOCKER_USER', 'yyhuni')

backend/apps/engine/services/fingerprints/__init__.py (new file, 22 lines)
@@ -0,0 +1,22 @@
"""Fingerprint management services.

Exports all fingerprint-related Service classes.
"""

from .base import BaseFingerprintService
from .ehole import EholeFingerprintService
from .goby import GobyFingerprintService
from .wappalyzer import WappalyzerFingerprintService
from .fingers_service import FingersFingerprintService
from .fingerprinthub_service import FingerPrintHubFingerprintService
from .arl_service import ARLFingerprintService

__all__ = [
    "BaseFingerprintService",
    "EholeFingerprintService",
    "GobyFingerprintService",
    "WappalyzerFingerprintService",
    "FingersFingerprintService",
    "FingerPrintHubFingerprintService",
    "ARLFingerprintService",
]
backend/apps/engine/services/fingerprints/arl_service.py (new file, 110 lines)
@@ -0,0 +1,110 @@
"""ARL fingerprint management service.

Implements validation, conversion and export logic for ARL-format fingerprints.
Supports YAML import and export.
"""

import logging
import yaml

from apps.engine.models import ARLFingerprint
from .base import BaseFingerprintService

logger = logging.getLogger(__name__)


class ARLFingerprintService(BaseFingerprintService):
    """ARL fingerprint management service (inherits the base class, adds ARL-specific logic)."""

    model = ARLFingerprint

    def validate_fingerprint(self, item: dict) -> bool:
        """
        Validate a single ARL fingerprint.

        Validation rules:
        - name must be present and non-empty
        - rule must be present and non-empty

        Args:
            item: a single fingerprint record

        Returns:
            bool: whether the record is valid
        """
        name = item.get('name', '')
        rule = item.get('rule', '')
        return bool(name and str(name).strip()) and bool(rule and str(rule).strip())

    def to_model_data(self, item: dict) -> dict:
        """
        Convert the ARL YAML format into model fields.

        Args:
            item: raw ARL YAML record

        Returns:
            dict: model field data
        """
        return {
            'name': str(item.get('name', '')).strip(),
            'rule': str(item.get('rule', '')).strip(),
        }

    def get_export_data(self) -> list:
        """
        Get the export data (ARL format: an array, used for YAML export).

        Returns:
            list: data in ARL format (array form)
            [
                {"name": "...", "rule": "..."},
                ...
            ]
        """
        fingerprints = self.model.objects.all()
        return [
            {
                'name': fp.name,
                'rule': fp.rule,
            }
            for fp in fingerprints
        ]

    def export_to_yaml(self, output_path: str) -> int:
        """
        Export all fingerprints to a YAML file.

        Args:
            output_path: output file path

        Returns:
            int: number of exported fingerprints
        """
        data = self.get_export_data()
        with open(output_path, 'w', encoding='utf-8') as f:
            yaml.dump(data, f, allow_unicode=True, default_flow_style=False, sort_keys=False)
        count = len(data)
        logger.info("导出 ARL 指纹文件: %s, 数量: %d", output_path, count)
        return count

    def parse_yaml_import(self, yaml_content: str) -> list:
        """
        Parse YAML-formatted import content.

        Args:
            yaml_content: YAML string content

        Returns:
            list: parsed fingerprint records

        Raises:
            ValueError: when the YAML is invalid
        """
        try:
            data = yaml.safe_load(yaml_content)
            if not isinstance(data, list):
                raise ValueError("ARL YAML 文件必须是数组格式")
            return data
        except yaml.YAMLError as e:
            raise ValueError(f"无效的 YAML 格式: {e}")
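The YAML produced by export_to_yaml and consumed by parse_yaml_import is a plain array of name/rule mappings. A small round-trip sketch, with invented entries:

# Illustrative only: the array-of-mappings YAML shape this service handles.
import yaml

sample = """
- name: Example-App
  rule: body="example-app"
- name: Another-App
  rule: title="Another App"
"""
data = yaml.safe_load(sample)
assert isinstance(data, list) and data[0]["name"] == "Example-App"
# ARLFingerprintService().batch_create_fingerprints(data) would then
# validate each record and bulk-insert the valid ones.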
backend/apps/engine/services/fingerprints/base.py (new file, 144 lines)
@@ -0,0 +1,144 @@
"""Fingerprint management base service.

Provides shared batch-operation and caching logic for the EHole/Goby/Wappalyzer (and other) subclasses to inherit.
"""

import json
import logging
from typing import Any

logger = logging.getLogger(__name__)


class BaseFingerprintService:
    """Base fingerprint service providing shared batch-operation and caching logic."""

    model = None  # subclasses must set this
    BATCH_SIZE = 1000  # records processed per batch

    def validate_fingerprint(self, item: dict) -> bool:
        """
        Validate a single fingerprint; subclasses must implement this.

        Args:
            item: a single fingerprint record

        Returns:
            bool: whether the record is valid
        """
        raise NotImplementedError("子类必须实现 validate_fingerprint 方法")

    def validate_fingerprints(self, raw_data: list) -> tuple[list, list]:
        """
        Validate fingerprint records in bulk.

        Args:
            raw_data: raw fingerprint records

        Returns:
            tuple: (valid_items, invalid_items)
        """
        valid, invalid = [], []
        for item in raw_data:
            if self.validate_fingerprint(item):
                valid.append(item)
            else:
                invalid.append(item)
        return valid, invalid

    def to_model_data(self, item: dict) -> dict:
        """
        Convert a record into model fields; subclasses must implement this.

        Args:
            item: raw fingerprint record

        Returns:
            dict: model field data
        """
        raise NotImplementedError("子类必须实现 to_model_data 方法")

    def bulk_create(self, fingerprints: list) -> int:
        """
        Bulk-create fingerprint records (already validated).

        Args:
            fingerprints: validated fingerprint records

        Returns:
            int: number of records created
        """
        if not fingerprints:
            return 0

        objects = [self.model(**self.to_model_data(item)) for item in fingerprints]
        created = self.model.objects.bulk_create(objects, ignore_conflicts=True)
        return len(created)

    def batch_create_fingerprints(self, raw_data: list) -> dict:
        """
        The full pipeline: validate in batches, then bulk-create.

        Args:
            raw_data: raw fingerprint records

        Returns:
            dict: {'created': int, 'failed': int}
        """
        total_created = 0
        total_failed = 0

        for i in range(0, len(raw_data), self.BATCH_SIZE):
            batch = raw_data[i:i + self.BATCH_SIZE]
            valid, invalid = self.validate_fingerprints(batch)
            total_created += self.bulk_create(valid)
            total_failed += len(invalid)

        logger.info(
            "批量创建指纹完成: created=%d, failed=%d, total=%d",
            total_created, total_failed, len(raw_data)
        )
        return {'created': total_created, 'failed': total_failed}

    def get_export_data(self) -> dict:
        """
        Get the export data; subclasses must implement this.

        Returns:
            dict: exported JSON data
        """
        raise NotImplementedError("子类必须实现 get_export_data 方法")

    def export_to_file(self, output_path: str) -> int:
        """
        Export all fingerprints to a JSON file.

        Args:
            output_path: output file path

        Returns:
            int: number of exported fingerprints
        """
        data = self.get_export_data()
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False)
        count = len(data.get('fingerprint', []))
        logger.info("导出指纹文件: %s, 数量: %d", output_path, count)
        return count

    def get_fingerprint_version(self) -> str:
        """
        Get a fingerprint-library version identifier (used for cache validation).

        Returns:
            str: version identifier in the form "{count}_{latest_timestamp}"

        Scenarios that change the version:
        - a record is added: count changes
        - a record is deleted: count changes
        - everything is cleared: count becomes 0
        """
        count = self.model.objects.count()
        latest = self.model.objects.order_by('-created_at').first()
        latest_ts = int(latest.created_at.timestamp()) if latest else 0
        return f"{count}_{latest_ts}"
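Taken together, the base class gives every subclass the same two-call import path. A hedged usage sketch (the data values are invented; GobyFingerprintService and batch_create_fingerprints come from the code above):

# Hypothetical usage of a concrete service built on BaseFingerprintService.
from apps.engine.services.fingerprints import GobyFingerprintService

raw = [
    {"name": "demo-app", "logic": "a",
     "rule": [{"label": "a", "feature": "x", "is_equal": True}]},
    {"name": "", "logic": "a", "rule": []},  # invalid: blank name, counted as failed
]
service = GobyFingerprintService()
result = service.batch_create_fingerprints(raw)  # validates in BATCH_SIZE chunks
print(result)  # e.g. {'created': 1, 'failed': 1}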
backend/apps/engine/services/fingerprints/ehole.py (new file, 84 lines)
@@ -0,0 +1,84 @@
"""EHole fingerprint management service.

Implements validation, conversion and export logic for EHole-format fingerprints.
"""

from apps.engine.models import EholeFingerprint
from .base import BaseFingerprintService


class EholeFingerprintService(BaseFingerprintService):
    """EHole fingerprint management service (inherits the base class, adds EHole-specific logic)."""

    model = EholeFingerprint

    def validate_fingerprint(self, item: dict) -> bool:
        """
        Validate a single EHole fingerprint.

        Validation rules:
        - cms must be present and non-empty
        - keyword must be a list

        Args:
            item: a single fingerprint record

        Returns:
            bool: whether the record is valid
        """
        cms = item.get('cms', '')
        keyword = item.get('keyword')
        return bool(cms and str(cms).strip()) and isinstance(keyword, list)

    def to_model_data(self, item: dict) -> dict:
        """
        Convert the EHole JSON format into model fields.

        Field mapping:
        - isImportant (JSON) → is_important (Model)

        Args:
            item: raw EHole JSON record

        Returns:
            dict: model field data
        """
        return {
            'cms': str(item.get('cms', '')).strip(),
            'method': item.get('method', 'keyword'),
            'location': item.get('location', 'body'),
            'keyword': item.get('keyword', []),
            'is_important': item.get('isImportant', False),
            'type': item.get('type', '-'),
        }

    def get_export_data(self) -> dict:
        """
        Get the export data (EHole JSON format).

        Returns:
            dict: data in EHole JSON format
            {
                "fingerprint": [
                    {"cms": "...", "method": "...", "location": "...",
                     "keyword": [...], "isImportant": false, "type": "..."},
                    ...
                ],
                "version": "1000_1703836800"
            }
        """
        fingerprints = self.model.objects.all()
        data = []
        for fp in fingerprints:
            data.append({
                'cms': fp.cms,
                'method': fp.method,
                'location': fp.location,
                'keyword': fp.keyword,
                'isImportant': fp.is_important,  # converted back to the JSON field name
                'type': fp.type,
            })
        return {
            'fingerprint': data,
            'version': self.get_fingerprint_version(),
        }
backend/apps/engine/services/fingerprints/fingerprinthub_service.py (new file, 110 lines)
@@ -0,0 +1,110 @@
"""FingerPrintHub fingerprint management service.

Implements validation, conversion and export logic for FingerPrintHub-format fingerprints.
"""

from apps.engine.models import FingerPrintHubFingerprint
from .base import BaseFingerprintService


class FingerPrintHubFingerprintService(BaseFingerprintService):
    """FingerPrintHub fingerprint management service (inherits the base class, adds FingerPrintHub-specific logic)."""

    model = FingerPrintHubFingerprint

    def validate_fingerprint(self, item: dict) -> bool:
        """
        Validate a single FingerPrintHub fingerprint.

        Validation rules:
        - id must be present and non-empty
        - info must be present and contain name
        - http must be a list

        Args:
            item: a single fingerprint record

        Returns:
            bool: whether the record is valid
        """
        fp_id = item.get('id', '')
        info = item.get('info', {})
        http = item.get('http')

        if not fp_id or not str(fp_id).strip():
            return False
        if not isinstance(info, dict) or not info.get('name'):
            return False
        if not isinstance(http, list):
            return False

        return True

    def to_model_data(self, item: dict) -> dict:
        """
        Convert the FingerPrintHub JSON format into model fields.

        Field mapping (nested structure flattened):
        - id (JSON) → fp_id (Model)
        - info.name (JSON) → name (Model)
        - info.author (JSON) → author (Model)
        - info.tags (JSON) → tags (Model)
        - info.severity (JSON) → severity (Model)
        - info.metadata (JSON) → metadata (Model)
        - http (JSON) → http (Model)
        - _source_file (JSON) → source_file (Model)

        Args:
            item: raw FingerPrintHub JSON record

        Returns:
            dict: model field data
        """
        info = item.get('info', {})
        return {
            'fp_id': str(item.get('id', '')).strip(),
            'name': str(info.get('name', '')).strip(),
            'author': info.get('author', ''),
            'tags': info.get('tags', ''),
            'severity': info.get('severity', 'info'),
            'metadata': info.get('metadata', {}),
            'http': item.get('http', []),
            'source_file': item.get('_source_file', ''),
        }

    def get_export_data(self) -> list:
        """
        Get the export data (FingerPrintHub JSON format: an array).

        Returns:
            list: data in FingerPrintHub JSON format (array form)
            [
                {
                    "id": "...",
                    "info": {"name": "...", "author": "...", "tags": "...",
                             "severity": "...", "metadata": {...}},
                    "http": [...],
                    "_source_file": "..."
                },
                ...
            ]
        """
        fingerprints = self.model.objects.all()
        data = []
        for fp in fingerprints:
            item = {
                'id': fp.fp_id,
                'info': {
                    'name': fp.name,
                    'author': fp.author,
                    'tags': fp.tags,
                    'severity': fp.severity,
                    'metadata': fp.metadata,
                },
                'http': fp.http,
            }
            # Only add source_file when it is non-empty
            if fp.source_file:
                item['_source_file'] = fp.source_file
            data.append(item)
        return data
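To make the nested-to-flat mapping concrete, here is a hypothetical input record in the nuclei-template shape and the dict to_model_data() would produce for it (all values invented):

# Illustrative FingerPrintHub input and its flattened model data.
item = {
    "id": "demo-fingerprint",
    "info": {"name": "Demo", "author": "someone", "tags": "web",
             "severity": "info", "metadata": {}},
    "http": [{"matchers": []}],
    "_source_file": "demo.yaml",
}
# FingerPrintHubFingerprintService().to_model_data(item) returns:
# {'fp_id': 'demo-fingerprint', 'name': 'Demo', 'author': 'someone',
#  'tags': 'web', 'severity': 'info', 'metadata': {},
#  'http': [{'matchers': []}], 'source_file': 'demo.yaml'}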
backend/apps/engine/services/fingerprints/fingers_service.py (new file, 83 lines)
@@ -0,0 +1,83 @@
"""Fingers fingerprint management service.

Implements validation, conversion and export logic for Fingers-format fingerprints.
"""

from apps.engine.models import FingersFingerprint
from .base import BaseFingerprintService


class FingersFingerprintService(BaseFingerprintService):
    """Fingers fingerprint management service (inherits the base class, adds Fingers-specific logic)."""

    model = FingersFingerprint

    def validate_fingerprint(self, item: dict) -> bool:
        """
        Validate a single Fingers fingerprint.

        Validation rules:
        - name must be present and non-empty
        - rule must be a list

        Args:
            item: a single fingerprint record

        Returns:
            bool: whether the record is valid
        """
        name = item.get('name', '')
        rule = item.get('rule')
        return bool(name and str(name).strip()) and isinstance(rule, list)

    def to_model_data(self, item: dict) -> dict:
        """
        Convert the Fingers JSON format into model fields.

        Field mapping:
        - default_port (JSON) → default_port (Model)

        Args:
            item: raw Fingers JSON record

        Returns:
            dict: model field data
        """
        return {
            'name': str(item.get('name', '')).strip(),
            'link': item.get('link', ''),
            'rule': item.get('rule', []),
            'tag': item.get('tag', []),
            'focus': item.get('focus', False),
            'default_port': item.get('default_port', []),
        }

    def get_export_data(self) -> list:
        """
        Get the export data (Fingers JSON format: an array).

        Returns:
            list: data in Fingers JSON format (array form)
            [
                {"name": "...", "link": "...", "rule": [...], "tag": [...],
                 "focus": false, "default_port": [...]},
                ...
            ]
        """
        fingerprints = self.model.objects.all()
        data = []
        for fp in fingerprints:
            item = {
                'name': fp.name,
                'link': fp.link,
                'rule': fp.rule,
                'tag': fp.tag,
            }
            # Only add focus when it is True (stays consistent with the original format)
            if fp.focus:
                item['focus'] = fp.focus
            # Only add default_port when it is non-empty
            if fp.default_port:
                item['default_port'] = fp.default_port
            data.append(item)
        return data
backend/apps/engine/services/fingerprints/goby.py (new file, 70 lines)
@@ -0,0 +1,70 @@
"""Goby fingerprint management service.

Implements validation, conversion and export logic for Goby-format fingerprints.
"""

from apps.engine.models import GobyFingerprint
from .base import BaseFingerprintService


class GobyFingerprintService(BaseFingerprintService):
    """Goby fingerprint management service (inherits the base class, adds Goby-specific logic)."""

    model = GobyFingerprint

    def validate_fingerprint(self, item: dict) -> bool:
        """
        Validate a single Goby fingerprint.

        Validation rules:
        - name must be present and non-empty
        - logic must be present
        - rule must be a list

        Args:
            item: a single fingerprint record

        Returns:
            bool: whether the record is valid
        """
        name = item.get('name', '')
        logic = item.get('logic', '')
        rule = item.get('rule')
        return bool(name and str(name).strip()) and bool(logic) and isinstance(rule, list)

    def to_model_data(self, item: dict) -> dict:
        """
        Convert the Goby JSON format into model fields.

        Args:
            item: raw Goby JSON record

        Returns:
            dict: model field data
        """
        return {
            'name': str(item.get('name', '')).strip(),
            'logic': item.get('logic', ''),
            'rule': item.get('rule', []),
        }

    def get_export_data(self) -> list:
        """
        Get the export data (Goby JSON format: an array).

        Returns:
            list: data in Goby JSON format (array form)
            [
                {"name": "...", "logic": "...", "rule": [...]},
                ...
            ]
        """
        fingerprints = self.model.objects.all()
        return [
            {
                'name': fp.name,
                'logic': fp.logic,
                'rule': fp.rule,
            }
            for fp in fingerprints
        ]
backend/apps/engine/services/fingerprints/wappalyzer.py (new file, 99 lines)
@@ -0,0 +1,99 @@
"""Wappalyzer fingerprint management service.

Implements validation, conversion and export logic for Wappalyzer-format fingerprints.
"""

from apps.engine.models import WappalyzerFingerprint
from .base import BaseFingerprintService


class WappalyzerFingerprintService(BaseFingerprintService):
    """Wappalyzer fingerprint management service (inherits the base class, adds Wappalyzer-specific logic)."""

    model = WappalyzerFingerprint

    def validate_fingerprint(self, item: dict) -> bool:
        """
        Validate a single Wappalyzer fingerprint.

        Validation rules:
        - name must be present and non-empty (passed in from the key of the apps object)

        Args:
            item: a single fingerprint record

        Returns:
            bool: whether the record is valid
        """
        name = item.get('name', '')
        return bool(name and str(name).strip())

    def to_model_data(self, item: dict) -> dict:
        """
        Convert the Wappalyzer JSON format into model fields.

        Field mapping:
        - scriptSrc (JSON) → script_src (Model)

        Args:
            item: raw Wappalyzer JSON record

        Returns:
            dict: model field data
        """
        return {
            'name': str(item.get('name', '')).strip(),
            'cats': item.get('cats', []),
            'cookies': item.get('cookies', {}),
            'headers': item.get('headers', {}),
            'script_src': item.get('scriptSrc', []),  # JSON: scriptSrc -> Model: script_src
            'js': item.get('js', []),
            'implies': item.get('implies', []),
            'meta': item.get('meta', {}),
            'html': item.get('html', []),
            'description': item.get('description', ''),
            'website': item.get('website', ''),
            'cpe': item.get('cpe', ''),
        }

    def get_export_data(self) -> dict:
        """
        Get the export data (Wappalyzer JSON format).

        Returns:
            dict: data in Wappalyzer JSON format
            {
                "apps": {
                    "AppName": {"cats": [...], "cookies": {...}, ...},
                    ...
                }
            }
        """
        fingerprints = self.model.objects.all()
        apps = {}
        for fp in fingerprints:
            app_data = {}
            if fp.cats:
                app_data['cats'] = fp.cats
            if fp.cookies:
                app_data['cookies'] = fp.cookies
            if fp.headers:
                app_data['headers'] = fp.headers
            if fp.script_src:
                app_data['scriptSrc'] = fp.script_src  # Model: script_src -> JSON: scriptSrc
            if fp.js:
                app_data['js'] = fp.js
            if fp.implies:
                app_data['implies'] = fp.implies
            if fp.meta:
                app_data['meta'] = fp.meta
            if fp.html:
                app_data['html'] = fp.html
            if fp.description:
                app_data['description'] = fp.description
            if fp.website:
                app_data['website'] = fp.website
            if fp.cpe:
                app_data['cpe'] = fp.cpe
            apps[fp.name] = app_data
        return {'apps': apps}
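The export side inverts the field mapping and drops empty values, yielding the classic Wappalyzer apps dictionary. An illustrative fragment of what get_export_data() would return for one stored app (the app data is invented):

# Hypothetical shape of get_export_data() output for a single app.
exported = {
    "apps": {
        "DemoCMS": {                        # key is the model's name field
            "cats": [1],
            "headers": {"X-Powered-By": "DemoCMS"},
            "scriptSrc": ["democms\\.js"],  # model script_src -> JSON scriptSrc
        }
    }
}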
@@ -196,6 +196,9 @@ class NucleiTemplateRepoService:
         cmd: List[str]
         action: str

+        # Use the original URL directly (Git acceleration is no longer used)
+        repo_url = obj.repo_url
+
         # Decide between clone and pull
         if git_dir.is_dir():
             # Check whether the remote URL has changed
@@ -208,12 +211,13 @@ class NucleiTemplateRepoService:
             )
             current_url = current_remote.stdout.strip() if current_remote.returncode == 0 else ""

-            if current_url != obj.repo_url:
+            # Check whether a fresh clone is needed
+            if current_url != repo_url:
                 # The remote URL changed; delete the old directory and re-clone
-                logger.info("nuclei 模板仓库 %s 远程地址变化,重新 clone: %s -> %s", obj.id, current_url, obj.repo_url)
+                logger.info("nuclei 模板仓库 %s 远程地址变化,重新 clone: %s -> %s", obj.id, current_url, repo_url)
                 shutil.rmtree(local_path)
                 local_path.mkdir(parents=True, exist_ok=True)
-                cmd = ["git", "clone", "--depth", "1", obj.repo_url, str(local_path)]
+                cmd = ["git", "clone", "--depth", "1", repo_url, str(local_path)]
                 action = "clone"
             else:
                 # The repository exists and the URL is unchanged; pull
@@ -224,7 +228,7 @@ class NucleiTemplateRepoService:
         if local_path.exists() and not local_path.is_dir():
             raise RuntimeError(f"本地路径已存在且不是目录: {local_path}")
         # --depth 1 shallow clone: fetch only the latest commit to save space and time
-        cmd = ["git", "clone", "--depth", "1", obj.repo_url, str(local_path)]
+        cmd = ["git", "clone", "--depth", "1", repo_url, str(local_path)]
         action = "clone"

         # Run the Git command

@@ -76,8 +76,8 @@ class TaskDistributor:
         self.docker_image = settings.TASK_EXECUTOR_IMAGE
         if not self.docker_image:
             raise ValueError("TASK_EXECUTOR_IMAGE 未配置,请确保 IMAGE_TAG 环境变量已设置")
-        self.results_mount = getattr(settings, 'CONTAINER_RESULTS_MOUNT', '/app/backend/results')
-        self.logs_mount = getattr(settings, 'CONTAINER_LOGS_MOUNT', '/app/backend/logs')
+        # Use paths under /opt/xingrin consistently
+        self.logs_mount = "/opt/xingrin/logs"
         self.submit_interval = getattr(settings, 'TASK_SUBMIT_INTERVAL', 5)

     def get_online_workers(self) -> list[WorkerNode]:
@@ -153,30 +153,68 @@ class TaskDistributor:
             else:
                 scored_workers.append((worker, score, cpu, mem))

-        # Fallback strategy: if no worker has normal load, wait and re-select
+        # Fallback strategy: if no worker has normal load, wait in a loop and re-check
         if not scored_workers:
             if high_load_workers:
-                # Under high load, wait first to give the system breathing room (default 60 seconds)
+                # High-load wait parameters (default: check every 60 seconds, at most 10 times)
                 high_load_wait = getattr(settings, 'HIGH_LOAD_WAIT_SECONDS', 60)
-                logger.warning("所有 Worker 高负载,等待 %d 秒后重试...", high_load_wait)
-                time.sleep(high_load_wait)
+                high_load_max_retries = getattr(settings, 'HIGH_LOAD_MAX_RETRIES', 10)

-                # Re-select (recursive call; the load may have dropped)
-                # To avoid unbounded recursion, pick the least-loaded of the high-load workers directly
+                # Send the high-load notification before starting to wait
                 high_load_workers.sort(key=lambda x: x[1])
-                best_worker, _, cpu, mem = high_load_workers[0]
-
-                # Send the high-load notification
+                _, _, first_cpu, first_mem = high_load_workers[0]
                 from apps.common.signals import all_workers_high_load
                 all_workers_high_load.send(
                     sender=self.__class__,
-                    worker_name=best_worker.name,
-                    cpu=cpu,
-                    mem=mem
+                    worker_name="所有节点",
+                    cpu=first_cpu,
+                    mem=first_mem
                 )

-                logger.info("选择 Worker: %s (CPU: %.1f%%, MEM: %.1f%%)", best_worker.name, cpu, mem)
-                return best_worker
+                for retry in range(high_load_max_retries):
+                    logger.warning(
+                        "所有 Worker 高负载,等待 %d 秒后重试... (%d/%d)",
+                        high_load_wait, retry + 1, high_load_max_retries
+                    )
+                    time.sleep(high_load_wait)
+
+                    # Refresh the load data
+                    loads = worker_load_service.get_all_loads(worker_ids)
+
+                    # Re-evaluate
+                    scored_workers = []
+                    high_load_workers = []
+
+                    for worker in workers:
+                        load = loads.get(worker.id)
+                        if not load:
+                            continue
+
+                        cpu = load.get('cpu', 0)
+                        mem = load.get('mem', 0)
+                        score = cpu * 0.7 + mem * 0.3
+
+                        if cpu > 85 or mem > 85:
+                            high_load_workers.append((worker, score, cpu, mem))
+                        else:
+                            scored_workers.append((worker, score, cpu, mem))
+
+                    # If a normally loaded worker appeared, stop waiting
+                    if scored_workers:
+                        logger.info("检测到正常负载 Worker,结束等待")
+                        break
+
+                # Timed out or still under high load: pick the least-loaded worker
+                if not scored_workers and high_load_workers:
+                    high_load_workers.sort(key=lambda x: x[1])
+                    best_worker, _, cpu, mem = high_load_workers[0]
+
+                    logger.warning(
+                        "等待超时,强制分发到高负载 Worker: %s (CPU: %.1f%%, MEM: %.1f%%)",
+                        best_worker.name, cpu, mem
+                    )
+                    return best_worker
+                return best_worker
             else:
                 logger.warning("没有可用的 Worker")
                 return None
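The selection weight favors CPU over memory (0.7 vs 0.3), with 85% on either resource marking a worker as high-load. A quick worked example of the score and cutoff, using the formula from the code above:

# Worked example of the scoring used in the hunk above.
def score(cpu: float, mem: float) -> float:
    return cpu * 0.7 + mem * 0.3

print(score(50, 90))  # 62.0, but mem > 85, so the worker counts as high-load
print(score(80, 40))  # 68.0, normal load, eligible for direct selection
# Among high-load workers the list is sorted by score ascending, so the
# lowest combined score is what gets dispatched to as a last resort.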
@@ -234,11 +272,10 @@ class TaskDistributor:
|
||||
else:
|
||||
# 远程:通过 Nginx 反向代理访问(HTTPS,不直连 8888 端口)
|
||||
network_arg = ""
|
||||
server_url = f"https://{settings.PUBLIC_HOST}"
|
||||
server_url = f"https://{settings.PUBLIC_HOST}:{settings.PUBLIC_PORT}"
|
||||
|
||||
# 挂载路径(所有节点统一使用固定路径)
|
||||
host_results_dir = settings.HOST_RESULTS_DIR # /opt/xingrin/results
|
||||
host_logs_dir = settings.HOST_LOGS_DIR # /opt/xingrin/logs
|
||||
# 挂载路径(统一挂载 /opt/xingrin,扫描工具在 /opt/xingrin-tools/bin 不受影响)
|
||||
host_xingrin_dir = "/opt/xingrin"
|
||||
|
||||
# 环境变量:SERVER_URL + IS_LOCAL,其他配置容器启动时从配置中心获取
|
||||
# IS_LOCAL 用于 Worker 向配置中心声明身份,决定返回的数据库地址
|
||||
@@ -247,19 +284,17 @@ class TaskDistributor:
|
||||
env_vars = [
|
||||
f"-e SERVER_URL={shlex.quote(server_url)}",
|
||||
f"-e IS_LOCAL={is_local_str}",
|
||||
f"-e WORKER_API_KEY={shlex.quote(settings.WORKER_API_KEY)}", # Worker API 认证密钥
|
||||
"-e PREFECT_HOME=/tmp/.prefect", # 设置 Prefect 数据目录到可写位置
|
||||
"-e PREFECT_SERVER_EPHEMERAL_ENABLED=true", # 启用 ephemeral server(本地临时服务器)
|
||||
"-e PREFECT_SERVER_EPHEMERAL_STARTUP_TIMEOUT_SECONDS=120", # 增加启动超时时间
|
||||
"-e PREFECT_SERVER_DATABASE_CONNECTION_URL=sqlite+aiosqlite:////tmp/.prefect/prefect.db", # 使用 /tmp 下的 SQLite
|
||||
"-e PREFECT_LOGGING_LEVEL=DEBUG", # 启用 DEBUG 级别日志
|
||||
"-e PREFECT_LOGGING_SERVER_LEVEL=DEBUG", # Server 日志级别
|
||||
"-e PREFECT_DEBUG_MODE=true", # 启用调试模式
|
||||
"-e PREFECT_LOGGING_LEVEL=WARNING", # 日志级别(减少 DEBUG 噪音)
|
||||
]
|
||||
|
||||
# 挂载卷
|
||||
# 挂载卷(统一挂载整个 /opt/xingrin 目录)
|
||||
volumes = [
|
||||
f"-v {host_results_dir}:{self.results_mount}",
|
||||
f"-v {host_logs_dir}:{self.logs_mount}",
|
||||
f"-v {host_xingrin_dir}:{host_xingrin_dir}",
|
||||
]
|
||||
|
||||
# 构建命令行参数
|
||||
@@ -277,11 +312,10 @@ class TaskDistributor:
|
||||
# - 本地 Worker:install.sh 已预拉取镜像,直接使用本地版本
|
||||
# - 远程 Worker:deploy 时已预拉取镜像,直接使用本地版本
|
||||
# - 避免每次任务都检查 Docker Hub,提升性能和稳定性
|
||||
# 使用双引号包裹 sh -c 命令,内部 shlex.quote 生成的单引号参数可正确解析
|
||||
cmd = f'''docker run --rm -d --pull=missing {network_arg} \
|
||||
{' '.join(env_vars)} \
|
||||
{' '.join(volumes)} \
|
||||
{self.docker_image} \
|
||||
cmd = f'''docker run --rm -d --pull=missing {network_arg} \\
|
||||
{' '.join(env_vars)} \\
|
||||
{' '.join(volumes)} \\
|
||||
{self.docker_image} \\
|
||||
sh -c "{inner_cmd}"'''
|
||||
|
||||
return cmd
|
||||
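With those pieces substituted in, the generated command has roughly the following shape. Every value here is a placeholder rather than output from a real deployment (the image tag and inner command in particular are hypothetical):

# Placeholder illustration of the assembled docker run string.
example_cmd = (
    'docker run --rm -d --pull=missing '
    "-e SERVER_URL='https://example.host:443' -e IS_LOCAL=false "
    '-v /opt/xingrin:/opt/xingrin '
    'yyhuni/task-executor:latest '       # hypothetical image tag
    'sh -c "python run_task.py ..."'     # inner_cmd is built elsewhere
)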
@@ -520,7 +554,7 @@ class TaskDistributor:
|
||||
try:
|
||||
# 构建 docker run 命令(清理过期扫描结果目录)
|
||||
script_args = {
|
||||
'results_dir': '/app/backend/results',
|
||||
'results_dir': '/opt/xingrin/results',
|
||||
'retention_days': retention_days,
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from django.conf import settings
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.core.files.uploadedfile import UploadedFile
|
||||
|
||||
from apps.common.hash_utils import safe_calc_file_sha256
|
||||
from apps.common.utils import safe_calc_file_sha256
|
||||
from apps.engine.models import Wordlist
|
||||
from apps.engine.repositories import DjangoWordlistRepository
|
||||
|
||||
|
||||
@@ -7,6 +7,14 @@ from .views import (
|
||||
WordlistViewSet,
|
||||
NucleiTemplateRepoViewSet,
|
||||
)
|
||||
from .views.fingerprints import (
|
||||
EholeFingerprintViewSet,
|
||||
GobyFingerprintViewSet,
|
||||
WappalyzerFingerprintViewSet,
|
||||
FingersFingerprintViewSet,
|
||||
FingerPrintHubFingerprintViewSet,
|
||||
ARLFingerprintViewSet,
|
||||
)
|
||||
|
||||
|
||||
# 创建路由器
|
||||
@@ -15,6 +23,13 @@ router.register(r"engines", ScanEngineViewSet, basename="engine")
|
||||
router.register(r"workers", WorkerNodeViewSet, basename="worker")
|
||||
router.register(r"wordlists", WordlistViewSet, basename="wordlist")
|
||||
router.register(r"nuclei/repos", NucleiTemplateRepoViewSet, basename="nuclei-repos")
|
||||
# 指纹管理
|
||||
router.register(r"fingerprints/ehole", EholeFingerprintViewSet, basename="ehole-fingerprint")
|
||||
router.register(r"fingerprints/goby", GobyFingerprintViewSet, basename="goby-fingerprint")
|
||||
router.register(r"fingerprints/wappalyzer", WappalyzerFingerprintViewSet, basename="wappalyzer-fingerprint")
|
||||
router.register(r"fingerprints/fingers", FingersFingerprintViewSet, basename="fingers-fingerprint")
|
||||
router.register(r"fingerprints/fingerprinthub", FingerPrintHubFingerprintViewSet, basename="fingerprinthub-fingerprint")
|
||||
router.register(r"fingerprints/arl", ARLFingerprintViewSet, basename="arl-fingerprint")
|
||||
|
||||
urlpatterns = [
|
||||
path("", include(router.urls)),
|
||||
|
||||
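Each registered route then exposes the CRUD-plus-batch surface that the ViewSets below document. A hedged example call against one of them (host and payload invented; the batch_create action path follows the ViewSet docstrings):

# Hypothetical client call to a batch_create action registered above.
import requests  # assumes the API is reachable and appropriately authenticated

resp = requests.post(
    "https://example.host/api/engine/fingerprints/goby/batch_create/",
    json={"fingerprints": [{"name": "demo", "logic": "a", "rule": []}]},
)
print(resp.json())  # e.g. {"created": 1, "failed": 0} inside the success envelope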
backend/apps/engine/views/fingerprints/__init__.py (new file, 22 lines)
@@ -0,0 +1,22 @@
"""Fingerprint management ViewSets.

Exports all fingerprint-related ViewSet classes.
"""

from .base import BaseFingerprintViewSet
from .ehole import EholeFingerprintViewSet
from .goby import GobyFingerprintViewSet
from .wappalyzer import WappalyzerFingerprintViewSet
from .fingers import FingersFingerprintViewSet
from .fingerprinthub import FingerPrintHubFingerprintViewSet
from .arl import ARLFingerprintViewSet

__all__ = [
    "BaseFingerprintViewSet",
    "EholeFingerprintViewSet",
    "GobyFingerprintViewSet",
    "WappalyzerFingerprintViewSet",
    "FingersFingerprintViewSet",
    "FingerPrintHubFingerprintViewSet",
    "ARLFingerprintViewSet",
]
backend/apps/engine/views/fingerprints/arl.py (new file, 122 lines)
@@ -0,0 +1,122 @@
"""ARL fingerprint management ViewSet."""

import yaml
from django.http import HttpResponse
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError

from apps.common.pagination import BasePagination
from apps.common.response_helpers import success_response
from apps.engine.models import ARLFingerprint
from apps.engine.serializers.fingerprints import ARLFingerprintSerializer
from apps.engine.services.fingerprints import ARLFingerprintService

from .base import BaseFingerprintViewSet


class ARLFingerprintViewSet(BaseFingerprintViewSet):
    """ARL fingerprint management ViewSet.

    Inherits from BaseFingerprintViewSet and provides the following APIs:

    Standard CRUD (ModelViewSet):
    - GET /            list query (paginated)
    - POST /           create a single record
    - GET /{id}/       fetch details
    - PUT /{id}/       update
    - DELETE /{id}/    delete

    Batch operations (inherited from the base class):
    - POST /batch_create/   batch create (JSON body)
    - POST /import_file/    file import (multipart/form-data, supports YAML)
    - POST /bulk-delete/    bulk delete
    - POST /delete-all/     delete everything
    - GET /export/          export download (YAML format)

    Smart filter syntax (the filter parameter):
    - name="word"        fuzzy match on the name field
    - name=="WordPress"  exact match
    - rule="body="       filter by rule content
    """

    queryset = ARLFingerprint.objects.all()
    serializer_class = ARLFingerprintSerializer
    pagination_class = BasePagination
    service_class = ARLFingerprintService

    # Ordering configuration
    ordering_fields = ['created_at', 'name']
    ordering = ['-created_at']

    # ARL filter-field mapping
    FILTER_FIELD_MAPPING = {
        'name': 'name',
        'rule': 'rule',
    }

    def parse_import_data(self, json_data) -> list:
        """
        Parse ARL-format import data (JSON form).

        Input format: an array, [{...}, {...}]
        Returns: a list of fingerprints
        """
        if isinstance(json_data, list):
            return json_data
        return []

    def get_export_filename(self) -> str:
        """Export file name."""
        return 'ARL.yaml'

    @action(detail=False, methods=['post'])
    def import_file(self, request):
        """
        File import (supports YAML and JSON formats).
        POST /api/engine/fingerprints/arl/import_file/

        Request format: multipart/form-data
        - file: a YAML or JSON file

        Returns: same as batch_create
        """
        file = request.FILES.get('file')
        if not file:
            raise ValidationError('缺少文件')

        filename = file.name.lower()
        content = file.read().decode('utf-8')

        try:
            if filename.endswith('.yaml') or filename.endswith('.yml'):
                # YAML format
                fingerprints = yaml.safe_load(content)
            else:
                # JSON format
                import json
                fingerprints = json.loads(content)
        except (yaml.YAMLError, json.JSONDecodeError) as e:
            raise ValidationError(f'无效的文件格式: {e}')

        if not isinstance(fingerprints, list):
            raise ValidationError('文件内容必须是数组格式')

        if not fingerprints:
            raise ValidationError('文件中没有有效的指纹数据')

        result = self.get_service().batch_create_fingerprints(fingerprints)
        return success_response(data=result)

    @action(detail=False, methods=['get'])
    def export(self, request):
        """
        Export fingerprints (YAML format).
        GET /api/engine/fingerprints/arl/export/

        Returns: a YAML file download
        """
        data = self.get_service().get_export_data()
        content = yaml.dump(data, allow_unicode=True, default_flow_style=False, sort_keys=False)
        response = HttpResponse(content, content_type='application/x-yaml')
        response['Content-Disposition'] = f'attachment; filename="{self.get_export_filename()}"'
        return response
backend/apps/engine/views/fingerprints/base.py (new file, 203 lines)
@@ -0,0 +1,203 @@
"""Fingerprint management base ViewSet.

Provides shared CRUD and batch operations for the EHole/Goby/Wappalyzer (and other) subclasses to inherit.
"""

import json
import logging

from django.http import HttpResponse
from rest_framework import viewsets, status, filters
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.exceptions import ValidationError

from apps.common.pagination import BasePagination
from apps.common.response_helpers import success_response
from apps.common.utils.filter_utils import apply_filters

logger = logging.getLogger(__name__)


class BaseFingerprintViewSet(viewsets.ModelViewSet):
    """Base fingerprint ViewSet, inherited by the EHole/Goby/Wappalyzer (etc.) subclasses.

    Provided APIs:

    Standard CRUD (inherited from ModelViewSet):
    - GET /            list query (pagination + smart filtering)
    - POST /           create a single record
    - GET /{id}/       fetch details
    - PUT /{id}/       update
    - DELETE /{id}/    delete

    Batch operations (implemented here):
    - POST /batch_create/   batch create (JSON body)
    - POST /import_file/    file import (multipart/form-data, suited to 10MB+ files)
    - POST /bulk-delete/    bulk delete
    - POST /delete-all/     delete everything
    - GET /export/          export download

    Smart filter syntax (the filter parameter):
    - field="value"    fuzzy match (contains)
    - field=="value"   exact match
    - space-separated conditions are combined with AND
    - || or the word "or" combines conditions with OR

    Subclasses must implement:
    - service_class        the Service class
    - parse_import_data    parsing of the import data format
    - get_export_filename  the export file name
    """

    pagination_class = BasePagination
    filter_backends = [filters.OrderingFilter]
    ordering = ['-created_at']

    # Subclasses must set this
    service_class = None  # the Service class

    # Smart-filter field mapping; subclasses must override
    FILTER_FIELD_MAPPING = {}

    # JSON array fields (queried via __contains); subclasses may override
    JSON_ARRAY_FIELDS = []

    def get_queryset(self):
        """Supports the smart filter syntax."""
        queryset = super().get_queryset()
        filter_query = self.request.query_params.get('filter', None)
        if filter_query:
            queryset = apply_filters(
                queryset,
                filter_query,
                self.FILTER_FIELD_MAPPING,
                json_array_fields=getattr(self, 'JSON_ARRAY_FIELDS', [])
            )
        return queryset

    def get_service(self):
        """Get a Service instance."""
        if self.service_class is None:
            raise NotImplementedError("子类必须指定 service_class")
        return self.service_class()

    def parse_import_data(self, json_data: dict) -> list:
        """
        Parse the import data; subclasses must implement this.

        Args:
            json_data: the parsed JSON data

        Returns:
            list: fingerprint records
        """
        raise NotImplementedError("子类必须实现 parse_import_data 方法")

    def get_export_filename(self) -> str:
        """
        Export file name; subclasses must implement this.

        Returns:
            str: the file name
        """
        raise NotImplementedError("子类必须实现 get_export_filename 方法")

    @action(detail=False, methods=['post'])
    def batch_create(self, request):
        """
        Batch-create fingerprint rules.
        POST /api/engine/fingerprints/{type}/batch_create/

        Request format:
        {
            "fingerprints": [
                {"cms": "WordPress", "method": "keyword", ...},
                ...
            ]
        }

        Returns:
        {
            "created": 2,
            "failed": 0
        }
        """
        fingerprints = request.data.get('fingerprints', [])
        if not fingerprints:
            raise ValidationError('fingerprints 不能为空')
        if not isinstance(fingerprints, list):
            raise ValidationError('fingerprints 必须是数组')

        result = self.get_service().batch_create_fingerprints(fingerprints)
        return success_response(data=result, status_code=status.HTTP_201_CREATED)

    @action(detail=False, methods=['post'])
    def import_file(self, request):
        """
        File import (suited to large files, 10MB+).
        POST /api/engine/fingerprints/{type}/import_file/

        Request format: multipart/form-data
        - file: a JSON file

        Returns: same as batch_create
        """
        file = request.FILES.get('file')
        if not file:
            raise ValidationError('缺少文件')

        try:
            json_data = json.load(file)
        except json.JSONDecodeError as e:
            raise ValidationError(f'无效的 JSON 格式: {e}')

        fingerprints = self.parse_import_data(json_data)
        if not fingerprints:
            raise ValidationError('文件中没有有效的指纹数据')

        result = self.get_service().batch_create_fingerprints(fingerprints)
        return success_response(data=result, status_code=status.HTTP_201_CREATED)

    @action(detail=False, methods=['post'], url_path='bulk-delete')
    def bulk_delete(self, request):
        """
        Bulk delete.
        POST /api/engine/fingerprints/{type}/bulk-delete/

        Request format: {"ids": [1, 2, 3]}
        Returns: {"deleted": 3}
        """
        ids = request.data.get('ids', [])
        if not ids:
            raise ValidationError('ids 不能为空')
        if not isinstance(ids, list):
            raise ValidationError('ids 必须是数组')

        deleted_count = self.queryset.model.objects.filter(id__in=ids).delete()[0]
        return success_response(data={'deleted': deleted_count})

    @action(detail=False, methods=['post'], url_path='delete-all')
    def delete_all(self, request):
        """
        Delete all fingerprints.
        POST /api/engine/fingerprints/{type}/delete-all/

        Returns: {"deleted": 1000}
        """
        deleted_count = self.queryset.model.objects.all().delete()[0]
        return success_response(data={'deleted': deleted_count})

    @action(detail=False, methods=['get'])
    def export(self, request):
        """
        Export fingerprints (frontend download).
        GET /api/engine/fingerprints/{type}/export/

        Returns: a JSON file download
        """
        data = self.get_service().get_export_data()
        content = json.dumps(data, ensure_ascii=False, indent=2)
        response = HttpResponse(content, content_type='application/json')
        response['Content-Disposition'] = f'attachment; filename="{self.get_export_filename()}"'
        return response
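The filter grammar documented above maps straight onto query strings. Illustrative requests, with the grammar taken from the docstring and the values invented:

# Hypothetical filter queries against the list endpoints.
# Fuzzy match:  GET /api/engine/fingerprints/ehole/?filter=cms="word"
# Exact match:  GET /api/engine/fingerprints/ehole/?filter=cms=="WordPress"
# AND (space):  GET /api/engine/fingerprints/ehole/?filter=cms="word" type=="CMS"
# OR:           GET /api/engine/fingerprints/ehole/?filter=cms="word" || cms="joomla"
params = {"filter": 'cms="word" type=="CMS"'}  # as passed by an HTTP client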
backend/apps/engine/views/fingerprints/ehole.py (new file, 67 lines)
@@ -0,0 +1,67 @@
"""EHole fingerprint management ViewSet."""

from apps.common.pagination import BasePagination
from apps.engine.models import EholeFingerprint
from apps.engine.serializers.fingerprints import EholeFingerprintSerializer
from apps.engine.services.fingerprints import EholeFingerprintService

from .base import BaseFingerprintViewSet


class EholeFingerprintViewSet(BaseFingerprintViewSet):
    """EHole fingerprint management ViewSet.

    Inherits from BaseFingerprintViewSet and provides the following APIs:

    Standard CRUD (ModelViewSet):
    - GET /            list query (paginated)
    - POST /           create a single record
    - GET /{id}/       fetch details
    - PUT /{id}/       update
    - DELETE /{id}/    delete

    Batch operations (inherited from the base class):
    - POST /batch_create/   batch create (JSON body)
    - POST /import_file/    file import (multipart/form-data)
    - POST /bulk-delete/    bulk delete
    - POST /delete-all/     delete everything
    - GET /export/          export download

    Smart filter syntax (the filter parameter):
    - cms="word"         fuzzy match on the cms field
    - cms=="WordPress"   exact match
    - type="CMS"         filter by type
    - method="keyword"   filter by match method
    - location="body"    filter by match location
    """

    queryset = EholeFingerprint.objects.all()
    serializer_class = EholeFingerprintSerializer
    pagination_class = BasePagination
    service_class = EholeFingerprintService

    # Ordering configuration
    ordering_fields = ['created_at', 'cms']
    ordering = ['-created_at']

    # EHole filter-field mapping
    FILTER_FIELD_MAPPING = {
        'cms': 'cms',
        'method': 'method',
        'location': 'location',
        'type': 'type',
        'isImportant': 'is_important',
    }

    def parse_import_data(self, json_data: dict) -> list:
        """
        Parse EHole JSON-format import data.

        Input format: {"fingerprint": [...]}
        Returns: a list of fingerprints
        """
        return json_data.get('fingerprint', [])

    def get_export_filename(self) -> str:
        """Export file name."""
        return 'ehole.json'
backend/apps/engine/views/fingerprints/fingerprinthub.py (new file, 73 lines)
@@ -0,0 +1,73 @@
"""FingerPrintHub fingerprint management ViewSet."""

from apps.common.pagination import BasePagination
from apps.engine.models import FingerPrintHubFingerprint
from apps.engine.serializers.fingerprints import FingerPrintHubFingerprintSerializer
from apps.engine.services.fingerprints import FingerPrintHubFingerprintService

from .base import BaseFingerprintViewSet


class FingerPrintHubFingerprintViewSet(BaseFingerprintViewSet):
    """FingerPrintHub fingerprint management ViewSet.

    Inherits from BaseFingerprintViewSet and provides the following APIs:

    Standard CRUD (ModelViewSet):
    - GET /            list query (paginated)
    - POST /           create a single record
    - GET /{id}/       fetch details
    - PUT /{id}/       update
    - DELETE /{id}/    delete

    Batch operations (inherited from the base class):
    - POST /batch_create/   batch create (JSON body)
    - POST /import_file/    file import (multipart/form-data)
    - POST /bulk-delete/    bulk delete
    - POST /delete-all/     delete everything
    - GET /export/          export download

    Smart filter syntax (the filter parameter):
    - name="word"      fuzzy match on the name field
    - fp_id=="xxx"     exact match on the fingerprint ID
    - author="xxx"     filter by author
    - severity="info"  filter by severity
    - tags="cms"       filter by tag
    """

    queryset = FingerPrintHubFingerprint.objects.all()
    serializer_class = FingerPrintHubFingerprintSerializer
    pagination_class = BasePagination
    service_class = FingerPrintHubFingerprintService

    # Ordering configuration
    ordering_fields = ['created_at', 'name', 'severity']
    ordering = ['-created_at']

    # FingerPrintHub filter-field mapping
    FILTER_FIELD_MAPPING = {
        'fp_id': 'fp_id',
        'name': 'name',
        'author': 'author',
        'tags': 'tags',
        'severity': 'severity',
        'source_file': 'source_file',
    }

    # JSON array fields (queried via __contains)
    JSON_ARRAY_FIELDS = ['http']

    def parse_import_data(self, json_data) -> list:
        """
        Parse FingerPrintHub JSON-format import data.

        Input format: an array, [{...}, {...}]
        Returns: a list of fingerprints
        """
        if isinstance(json_data, list):
            return json_data
        return []

    def get_export_filename(self) -> str:
        """Export file name."""
        return 'fingerprinthub_web.json'
backend/apps/engine/views/fingerprints/fingers.py (new file, 69 lines)
@@ -0,0 +1,69 @@
"""Fingers fingerprint management ViewSet."""

from apps.common.pagination import BasePagination
from apps.engine.models import FingersFingerprint
from apps.engine.serializers.fingerprints import FingersFingerprintSerializer
from apps.engine.services.fingerprints import FingersFingerprintService

from .base import BaseFingerprintViewSet


class FingersFingerprintViewSet(BaseFingerprintViewSet):
    """Fingers fingerprint management ViewSet.

    Inherits from BaseFingerprintViewSet and provides the following APIs:

    Standard CRUD (ModelViewSet):
    - GET /            list query (paginated)
    - POST /           create a single record
    - GET /{id}/       fetch details
    - PUT /{id}/       update
    - DELETE /{id}/    delete

    Batch operations (inherited from the base class):
    - POST /batch_create/   batch create (JSON body)
    - POST /import_file/    file import (multipart/form-data)
    - POST /bulk-delete/    bulk delete
    - POST /delete-all/     delete everything
    - GET /export/          export download

    Smart filter syntax (the filter parameter):
    - name="word"        fuzzy match on the name field
    - name=="WordPress"  exact match
    - tag="cms"          filter by tag
    - focus="true"       filter by the focus flag
    """

    queryset = FingersFingerprint.objects.all()
    serializer_class = FingersFingerprintSerializer
    pagination_class = BasePagination
    service_class = FingersFingerprintService

    # Ordering configuration
    ordering_fields = ['created_at', 'name']
    ordering = ['-created_at']

    # Fingers filter-field mapping
    FILTER_FIELD_MAPPING = {
        'name': 'name',
        'link': 'link',
        'focus': 'focus',
    }

    # JSON array fields (queried via __contains)
    JSON_ARRAY_FIELDS = ['tag', 'rule', 'default_port']

    def parse_import_data(self, json_data) -> list:
        """
        Parse Fingers JSON-format import data.

        Input format: an array, [{...}, {...}]
        Returns: a list of fingerprints
        """
        if isinstance(json_data, list):
            return json_data
        return []

    def get_export_filename(self) -> str:
        """Export file name."""
        return 'fingers_http.json'
Some files were not shown because too many files have changed in this diff.