mirror of
https://github.com/yyhuni/xingrin.git
synced 2026-01-31 19:53:11 +08:00
Compare commits
91 Commits
v1.2.4-dev
...
v1.4.0-dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cd5c2b9f11 | ||
|
|
54786c22dd | ||
|
|
d468f975ab | ||
|
|
a85a12b8ad | ||
|
|
a8b0d97b7b | ||
|
|
b8504921c2 | ||
|
|
ecfc1822fb | ||
|
|
81633642e6 | ||
|
|
d1ec9b7f27 | ||
|
|
2a3d9b4446 | ||
|
|
9b63203b5a | ||
|
|
4c1282e9bb | ||
|
|
ba3a9b709d | ||
|
|
283b28b46a | ||
|
|
1269e5a314 | ||
|
|
802e967906 | ||
|
|
e446326416 | ||
|
|
e0abb3ce7b | ||
|
|
d418baaf79 | ||
|
|
f8da408580 | ||
|
|
7cd4354d8f | ||
|
|
6bf35a760f | ||
|
|
be9ecadffb | ||
|
|
adb53c9f85 | ||
|
|
7b7bbed634 | ||
|
|
8dd3f0536e | ||
|
|
8a8062a12d | ||
|
|
55908a2da5 | ||
|
|
22a7d4f091 | ||
|
|
f287f18134 | ||
|
|
de27230b7a | ||
|
|
15a6295189 | ||
|
|
674acdac66 | ||
|
|
c59152bedf | ||
|
|
b4037202dc | ||
|
|
4b4f9862bf | ||
|
|
1c42e4978f | ||
|
|
57bab63997 | ||
|
|
b1f0f18ac0 | ||
|
|
ccee5471b8 | ||
|
|
0ccd362535 | ||
|
|
7f2af7f7e2 | ||
|
|
4bd0f9e8c1 | ||
|
|
68cc996e3b | ||
|
|
f1e79d638e | ||
|
|
d484133e4c | ||
|
|
fc977ae029 | ||
|
|
f328474404 | ||
|
|
68e726a066 | ||
|
|
77a6f45909 | ||
|
|
49d1f1f1bb | ||
|
|
db8ecb1644 | ||
|
|
18cc016268 | ||
|
|
23bc463283 | ||
|
|
7b903b91b2 | ||
|
|
b3136d51b9 | ||
|
|
08372588a4 | ||
|
|
236c828041 | ||
|
|
fb13bb74d8 | ||
|
|
f076c682b6 | ||
|
|
9eda2caceb | ||
|
|
b1c9e202dd | ||
|
|
918669bc29 | ||
|
|
fd70b0544d | ||
|
|
0f2df7a5f3 | ||
|
|
857ab737b5 | ||
|
|
ee2d99edda | ||
|
|
db6ce16aca | ||
|
|
ab800eca06 | ||
|
|
e8e5572339 | ||
|
|
d48d4bbcad | ||
|
|
d1cca4c083 | ||
|
|
df0810c863 | ||
|
|
d33e54c440 | ||
|
|
35a306fe8b | ||
|
|
724df82931 | ||
|
|
8dfffdf802 | ||
|
|
b8cb85ce0b | ||
|
|
da96d437a4 | ||
|
|
feaf8062e5 | ||
|
|
4bab76f233 | ||
|
|
09416b4615 | ||
|
|
bc1c5f6b0e | ||
|
|
2f2742e6fe | ||
|
|
be3c346a74 | ||
|
|
0c7a6fff12 | ||
|
|
3b4f0e3147 | ||
|
|
51212a2a0c | ||
|
|
58533bbaf6 | ||
|
|
6ccca1602d | ||
|
|
6389b0f672 |
178
.github/workflows/docker-build.yml
vendored
178
.github/workflows/docker-build.yml
vendored
@@ -19,7 +19,8 @@ permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
build:
|
||||
# AMD64 构建(原生 x64 runner)
|
||||
build-amd64:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
@@ -27,39 +28,30 @@ jobs:
|
||||
- image: xingrin-server
|
||||
dockerfile: docker/server/Dockerfile
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
- image: xingrin-frontend
|
||||
dockerfile: docker/frontend/Dockerfile
|
||||
context: .
|
||||
platforms: linux/amd64 # ARM64 构建时 Next.js 在 QEMU 下会崩溃
|
||||
- image: xingrin-worker
|
||||
dockerfile: docker/worker/Dockerfile
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
- image: xingrin-nginx
|
||||
dockerfile: docker/nginx/Dockerfile
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
- image: xingrin-agent
|
||||
dockerfile: docker/agent/Dockerfile
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
- image: xingrin-postgres
|
||||
dockerfile: docker/postgres/Dockerfile
|
||||
context: docker/postgres
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Free disk space (for large builds like worker)
|
||||
- name: Free disk space
|
||||
run: |
|
||||
echo "=== Before cleanup ==="
|
||||
df -h
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
|
||||
sudo docker image prune -af
|
||||
echo "=== After cleanup ==="
|
||||
df -h
|
||||
|
||||
- name: Generate SSL certificates for nginx build
|
||||
if: matrix.image == 'xingrin-nginx'
|
||||
@@ -69,10 +61,6 @@ jobs:
|
||||
-keyout docker/nginx/ssl/privkey.pem \
|
||||
-out docker/nginx/ssl/fullchain.pem \
|
||||
-subj "/CN=localhost"
|
||||
echo "SSL certificates generated for CI build"
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -83,7 +71,120 @@ jobs:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Get version from git tag
|
||||
- name: Get version
|
||||
id: version
|
||||
run: |
|
||||
if [[ $GITHUB_REF == refs/tags/* ]]; then
|
||||
echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "VERSION=dev-$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Build and push AMD64
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ${{ matrix.context }}
|
||||
file: ${{ matrix.dockerfile }}
|
||||
platforms: linux/amd64
|
||||
push: true
|
||||
tags: ${{ env.IMAGE_PREFIX }}/${{ matrix.image }}:${{ steps.version.outputs.VERSION }}-amd64
|
||||
build-args: IMAGE_TAG=${{ steps.version.outputs.VERSION }}
|
||||
cache-from: type=registry,ref=${{ env.IMAGE_PREFIX }}/${{ matrix.image }}:cache-amd64
|
||||
cache-to: type=registry,ref=${{ env.IMAGE_PREFIX }}/${{ matrix.image }}:cache-amd64,mode=max
|
||||
provenance: false
|
||||
sbom: false
|
||||
|
||||
# ARM64 构建(原生 ARM64 runner)
|
||||
build-arm64:
|
||||
runs-on: ubuntu-22.04-arm
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- image: xingrin-server
|
||||
dockerfile: docker/server/Dockerfile
|
||||
context: .
|
||||
- image: xingrin-frontend
|
||||
dockerfile: docker/frontend/Dockerfile
|
||||
context: .
|
||||
- image: xingrin-worker
|
||||
dockerfile: docker/worker/Dockerfile
|
||||
context: .
|
||||
- image: xingrin-nginx
|
||||
dockerfile: docker/nginx/Dockerfile
|
||||
context: .
|
||||
- image: xingrin-agent
|
||||
dockerfile: docker/agent/Dockerfile
|
||||
context: .
|
||||
- image: xingrin-postgres
|
||||
dockerfile: docker/postgres/Dockerfile
|
||||
context: docker/postgres
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Generate SSL certificates for nginx build
|
||||
if: matrix.image == 'xingrin-nginx'
|
||||
run: |
|
||||
mkdir -p docker/nginx/ssl
|
||||
openssl req -x509 -nodes -days 365 -newkey rsa:2048 \
|
||||
-keyout docker/nginx/ssl/privkey.pem \
|
||||
-out docker/nginx/ssl/fullchain.pem \
|
||||
-subj "/CN=localhost"
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Get version
|
||||
id: version
|
||||
run: |
|
||||
if [[ $GITHUB_REF == refs/tags/* ]]; then
|
||||
echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "VERSION=dev-$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Build and push ARM64
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ${{ matrix.context }}
|
||||
file: ${{ matrix.dockerfile }}
|
||||
platforms: linux/arm64
|
||||
push: true
|
||||
tags: ${{ env.IMAGE_PREFIX }}/${{ matrix.image }}:${{ steps.version.outputs.VERSION }}-arm64
|
||||
build-args: IMAGE_TAG=${{ steps.version.outputs.VERSION }}
|
||||
cache-from: type=registry,ref=${{ env.IMAGE_PREFIX }}/${{ matrix.image }}:cache-arm64
|
||||
cache-to: type=registry,ref=${{ env.IMAGE_PREFIX }}/${{ matrix.image }}:cache-arm64,mode=max
|
||||
provenance: false
|
||||
sbom: false
|
||||
|
||||
# 合并多架构 manifest
|
||||
merge-manifests:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [build-amd64, build-arm64]
|
||||
strategy:
|
||||
matrix:
|
||||
image:
|
||||
- xingrin-server
|
||||
- xingrin-frontend
|
||||
- xingrin-worker
|
||||
- xingrin-nginx
|
||||
- xingrin-agent
|
||||
- xingrin-postgres
|
||||
steps:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Get version
|
||||
id: version
|
||||
run: |
|
||||
if [[ $GITHUB_REF == refs/tags/* ]]; then
|
||||
@@ -94,28 +195,27 @@ jobs:
|
||||
echo "IS_RELEASE=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Build and push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ${{ matrix.context }}
|
||||
file: ${{ matrix.dockerfile }}
|
||||
platforms: ${{ matrix.platforms }}
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.IMAGE_PREFIX }}/${{ matrix.image }}:${{ steps.version.outputs.VERSION }}
|
||||
${{ steps.version.outputs.IS_RELEASE == 'true' && format('{0}/{1}:latest', env.IMAGE_PREFIX, matrix.image) || '' }}
|
||||
build-args: |
|
||||
IMAGE_TAG=${{ steps.version.outputs.VERSION }}
|
||||
cache-from: type=gha,scope=${{ matrix.image }}
|
||||
cache-to: type=gha,mode=max,scope=${{ matrix.image }}
|
||||
provenance: false
|
||||
sbom: false
|
||||
- name: Create and push multi-arch manifest
|
||||
run: |
|
||||
VERSION=${{ steps.version.outputs.VERSION }}
|
||||
IMAGE=${{ env.IMAGE_PREFIX }}/${{ matrix.image }}
|
||||
|
||||
docker manifest create ${IMAGE}:${VERSION} \
|
||||
${IMAGE}:${VERSION}-amd64 \
|
||||
${IMAGE}:${VERSION}-arm64
|
||||
docker manifest push ${IMAGE}:${VERSION}
|
||||
|
||||
if [[ "${{ steps.version.outputs.IS_RELEASE }}" == "true" ]]; then
|
||||
docker manifest create ${IMAGE}:latest \
|
||||
${IMAGE}:${VERSION}-amd64 \
|
||||
${IMAGE}:${VERSION}-arm64
|
||||
docker manifest push ${IMAGE}:latest
|
||||
fi
|
||||
|
||||
# 所有镜像构建成功后,更新 VERSION 文件
|
||||
# 根据 tag 所在的分支更新对应分支的 VERSION 文件
|
||||
# 更新 VERSION 文件
|
||||
update-version:
|
||||
runs-on: ubuntu-latest
|
||||
needs: build
|
||||
needs: merge-manifests
|
||||
if: startsWith(github.ref, 'refs/tags/v')
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
|
||||
62
README.md
62
README.md
@@ -13,18 +13,25 @@
|
||||
|
||||
<p align="center">
|
||||
<a href="#-功能特性">功能特性</a> •
|
||||
<a href="#-全局资产搜索">资产搜索</a> •
|
||||
<a href="#-快速开始">快速开始</a> •
|
||||
<a href="#-文档">文档</a> •
|
||||
<a href="#-技术栈">技术栈</a> •
|
||||
<a href="#-反馈与贡献">反馈与贡献</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<sub>🔍 关键词: ASM | 攻击面管理 | 漏洞扫描 | 资产发现 | Bug Bounty | 渗透测试 | Nuclei | 子域名枚举 | EASM</sub>
|
||||
<sub>🔍 关键词: ASM | 攻击面管理 | 漏洞扫描 | 资产发现 | 资产搜索 | Bug Bounty | 渗透测试 | Nuclei | 子域名枚举 | EASM</sub>
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
## 🌐 在线 Demo
|
||||
|
||||
👉 **[https://xingrin.vercel.app/](https://xingrin.vercel.app/)**
|
||||
|
||||
> ⚠️ 仅用于 UI 展示,未接入后端数据库
|
||||
|
||||
---
|
||||
|
||||
<p align="center">
|
||||
<b>🎨 现代化 UI </b>
|
||||
@@ -62,9 +69,14 @@
|
||||
- **自定义流程** - YAML 配置扫描流程,灵活编排
|
||||
- **定时扫描** - Cron 表达式配置,自动化周期扫描
|
||||
|
||||
### 🔖 指纹识别
|
||||
- **多源指纹库** - 内置 EHole、Goby、Wappalyzer、Fingers、FingerPrintHub、ARL 等 2.7W+ 指纹规则
|
||||
- **自动识别** - 扫描流程自动执行,识别 Web 应用技术栈
|
||||
- **指纹管理** - 支持查询、导入、导出指纹规则
|
||||
|
||||
#### 扫描流程架构
|
||||
|
||||
完整的扫描流程包括:子域名发现、端口扫描、站点发现、URL 收集、目录扫描、漏洞扫描等阶段
|
||||
完整的扫描流程包括:子域名发现、端口扫描、站点发现、指纹识别、URL 收集、目录扫描、漏洞扫描等阶段
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
@@ -75,7 +87,8 @@ flowchart LR
|
||||
SUB["子域名发现<br/>subfinder, amass, puredns"]
|
||||
PORT["端口扫描<br/>naabu"]
|
||||
SITE["站点识别<br/>httpx"]
|
||||
SUB --> PORT --> SITE
|
||||
FINGER["指纹识别<br/>xingfinger"]
|
||||
SUB --> PORT --> SITE --> FINGER
|
||||
end
|
||||
|
||||
subgraph STAGE2["阶段 2: 深度分析"]
|
||||
@@ -91,7 +104,7 @@ flowchart LR
|
||||
FINISH["扫描完成"]
|
||||
|
||||
START --> STAGE1
|
||||
SITE --> STAGE2
|
||||
FINGER --> STAGE2
|
||||
STAGE2 --> STAGE3
|
||||
STAGE3 --> FINISH
|
||||
|
||||
@@ -103,6 +116,7 @@ flowchart LR
|
||||
style SUB fill:#5dade2,stroke:#3498db,stroke-width:1px,color:#fff
|
||||
style PORT fill:#5dade2,stroke:#3498db,stroke-width:1px,color:#fff
|
||||
style SITE fill:#5dade2,stroke:#3498db,stroke-width:1px,color:#fff
|
||||
style FINGER fill:#5dade2,stroke:#3498db,stroke-width:1px,color:#fff
|
||||
style URL fill:#bb8fce,stroke:#9b59b6,stroke-width:1px,color:#fff
|
||||
style DIR fill:#bb8fce,stroke:#9b59b6,stroke-width:1px,color:#fff
|
||||
style VULN fill:#f0b27a,stroke:#e67e22,stroke-width:1px,color:#fff
|
||||
@@ -155,9 +169,34 @@ flowchart TB
|
||||
W3 -.心跳上报.-> REDIS
|
||||
```
|
||||
|
||||
### 🔎 全局资产搜索
|
||||
- **多类型搜索** - 支持 Website 和 Endpoint 两种资产类型
|
||||
- **表达式语法** - 支持 `=`(模糊)、`==`(精确)、`!=`(不等于)操作符
|
||||
- **逻辑组合** - 支持 `&&` (AND) 和 `||` (OR) 逻辑组合
|
||||
- **多字段查询** - 支持 host、url、title、tech、status、body、header 字段
|
||||
- **CSV 导出** - 流式导出全部搜索结果,无数量限制
|
||||
|
||||
#### 搜索语法示例
|
||||
|
||||
```bash
|
||||
# 基础搜索
|
||||
host="api" # host 包含 "api"
|
||||
status=="200" # 状态码精确等于 200
|
||||
tech="nginx" # 技术栈包含 nginx
|
||||
|
||||
# 组合搜索
|
||||
host="api" && status=="200" # host 包含 api 且状态码为 200
|
||||
tech="vue" || tech="react" # 技术栈包含 vue 或 react
|
||||
|
||||
# 复杂查询
|
||||
host="admin" && tech="php" && status=="200"
|
||||
url="/api/v1" && status!="404"
|
||||
```
|
||||
|
||||
### 📊 可视化界面
|
||||
- **数据统计** - 资产/漏洞统计仪表盘
|
||||
- **实时通知** - WebSocket 消息推送
|
||||
- **通知推送** - 实时企业微信,tg,discard消息推送服务
|
||||
|
||||
---
|
||||
|
||||
@@ -165,7 +204,8 @@ flowchart TB
|
||||
|
||||
### 环境要求
|
||||
|
||||
- **操作系统**: Ubuntu 20.04+ / Debian 11+ (推荐)
|
||||
- **操作系统**: Ubuntu 20.04+ / Debian 11+
|
||||
- **系统架构**: AMD64 (x86_64) / ARM64 (aarch64)
|
||||
- **硬件**: 2核 4G 内存起步,20GB+ 磁盘空间
|
||||
|
||||
### 一键安装
|
||||
@@ -190,6 +230,7 @@ sudo ./install.sh --mirror
|
||||
### 访问服务
|
||||
|
||||
- **Web 界面**: `https://ip:8083`
|
||||
- **默认账号**: admin / admin(首次登录后请修改密码)
|
||||
|
||||
### 常用命令
|
||||
|
||||
@@ -209,14 +250,11 @@ sudo ./uninstall.sh
|
||||
|
||||
## 🤝 反馈与贡献
|
||||
|
||||
- 🐛 **如果发现 Bug** 可以点击右边链接进行提交 [Issue](https://github.com/yyhuni/xingrin/issues)
|
||||
- 💡 **有新想法,比如UI设计,功能设计等** 欢迎点击右边链接进行提交建议 [Issue](https://github.com/yyhuni/xingrin/issues)
|
||||
- 💡 **发现 Bug,有新想法,比如UI设计,功能设计等** 欢迎点击右边链接进行提交建议 [Issue](https://github.com/yyhuni/xingrin/issues) 或者公众号私信
|
||||
|
||||
## 📧 联系
|
||||
- 目前版本就我个人使用,可能会有很多边界问题
|
||||
- 如有问题,建议,其他,优先提交[Issue](https://github.com/yyhuni/xingrin/issues),也可以直接给我的公众号发消息,我都会回复的
|
||||
|
||||
- 微信公众号: **洋洋的小黑屋**
|
||||
- 微信公众号: **塔罗安全学苑**
|
||||
- 微信群去公众号底下的菜单,有个交流群,点击就可以看到了,链接过期可以私信我拉你
|
||||
|
||||
<img src="docs/wechat-qrcode.png" alt="微信公众号" width="200">
|
||||
|
||||
|
||||
@@ -4,7 +4,3 @@ from django.apps import AppConfig
|
||||
class AssetConfig(AppConfig):
|
||||
default_auto_field = 'django.db.models.BigAutoField'
|
||||
name = 'apps.asset'
|
||||
|
||||
def ready(self):
|
||||
# 导入所有模型以确保Django发现并注册
|
||||
from . import models
|
||||
|
||||
@@ -14,12 +14,13 @@ class EndpointDTO:
|
||||
status_code: Optional[int] = None
|
||||
content_length: Optional[int] = None
|
||||
webserver: Optional[str] = None
|
||||
body_preview: Optional[str] = None
|
||||
response_body: Optional[str] = None
|
||||
content_type: Optional[str] = None
|
||||
tech: Optional[List[str]] = None
|
||||
vhost: Optional[bool] = None
|
||||
location: Optional[str] = None
|
||||
matched_gf_patterns: Optional[List[str]] = None
|
||||
response_headers: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tech is None:
|
||||
|
||||
@@ -17,9 +17,10 @@ class WebSiteDTO:
|
||||
webserver: str = ''
|
||||
content_type: str = ''
|
||||
tech: List[str] = None
|
||||
body_preview: str = ''
|
||||
response_body: str = ''
|
||||
vhost: Optional[bool] = None
|
||||
created_at: str = None
|
||||
response_headers: str = ''
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tech is None:
|
||||
|
||||
@@ -13,6 +13,7 @@ class EndpointSnapshotDTO:
|
||||
快照只属于 scan。
|
||||
"""
|
||||
scan_id: int
|
||||
target_id: int # 必填,用于同步到资产表
|
||||
url: str
|
||||
host: str = '' # 主机名(域名或IP地址)
|
||||
title: str = ''
|
||||
@@ -22,10 +23,10 @@ class EndpointSnapshotDTO:
|
||||
webserver: str = ''
|
||||
content_type: str = ''
|
||||
tech: List[str] = None
|
||||
body_preview: str = ''
|
||||
response_body: str = ''
|
||||
vhost: Optional[bool] = None
|
||||
matched_gf_patterns: List[str] = None
|
||||
target_id: Optional[int] = None # 冗余字段,用于同步到资产表
|
||||
response_headers: str = ''
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tech is None:
|
||||
@@ -42,9 +43,6 @@ class EndpointSnapshotDTO:
|
||||
"""
|
||||
from apps.asset.dtos.asset import EndpointDTO
|
||||
|
||||
if self.target_id is None:
|
||||
raise ValueError("target_id 不能为 None,无法同步到资产表")
|
||||
|
||||
return EndpointDTO(
|
||||
target_id=self.target_id,
|
||||
url=self.url,
|
||||
@@ -53,10 +51,11 @@ class EndpointSnapshotDTO:
|
||||
status_code=self.status_code,
|
||||
content_length=self.content_length,
|
||||
webserver=self.webserver,
|
||||
body_preview=self.body_preview,
|
||||
response_body=self.response_body,
|
||||
content_type=self.content_type,
|
||||
tech=self.tech if self.tech else [],
|
||||
vhost=self.vhost,
|
||||
location=self.location,
|
||||
matched_gf_patterns=self.matched_gf_patterns if self.matched_gf_patterns else []
|
||||
matched_gf_patterns=self.matched_gf_patterns if self.matched_gf_patterns else [],
|
||||
response_headers=self.response_headers,
|
||||
)
|
||||
|
||||
@@ -13,18 +13,19 @@ class WebsiteSnapshotDTO:
|
||||
快照只属于 scan,target 信息通过 scan.target 获取。
|
||||
"""
|
||||
scan_id: int
|
||||
target_id: int # 仅用于传递数据,不保存到数据库
|
||||
target_id: int # 必填,用于同步到资产表
|
||||
url: str
|
||||
host: str
|
||||
title: str = ''
|
||||
status: Optional[int] = None
|
||||
status_code: Optional[int] = None # 统一命名:status -> status_code
|
||||
content_length: Optional[int] = None
|
||||
location: str = ''
|
||||
web_server: str = ''
|
||||
webserver: str = '' # 统一命名:web_server -> webserver
|
||||
content_type: str = ''
|
||||
tech: List[str] = None
|
||||
body_preview: str = ''
|
||||
response_body: str = ''
|
||||
vhost: Optional[bool] = None
|
||||
response_headers: str = ''
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tech is None:
|
||||
@@ -44,12 +45,13 @@ class WebsiteSnapshotDTO:
|
||||
url=self.url,
|
||||
host=self.host,
|
||||
title=self.title,
|
||||
status_code=self.status,
|
||||
status_code=self.status_code,
|
||||
content_length=self.content_length,
|
||||
location=self.location,
|
||||
webserver=self.web_server,
|
||||
webserver=self.webserver,
|
||||
content_type=self.content_type,
|
||||
tech=self.tech if self.tech else [],
|
||||
body_preview=self.body_preview,
|
||||
vhost=self.vhost
|
||||
response_body=self.response_body,
|
||||
vhost=self.vhost,
|
||||
response_headers=self.response_headers,
|
||||
)
|
||||
|
||||
345
backend/apps/asset/migrations/0001_initial.py
Normal file
345
backend/apps/asset/migrations/0001_initial.py
Normal file
@@ -0,0 +1,345 @@
|
||||
# Generated by Django 5.2.7 on 2026-01-06 00:55
|
||||
|
||||
import django.contrib.postgres.fields
|
||||
import django.contrib.postgres.indexes
|
||||
import django.core.validators
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
initial = True
|
||||
|
||||
dependencies = [
|
||||
('scan', '0001_initial'),
|
||||
('targets', '0001_initial'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name='AssetStatistics',
|
||||
fields=[
|
||||
('id', models.AutoField(primary_key=True, serialize=False)),
|
||||
('total_targets', models.IntegerField(default=0, help_text='目标总数')),
|
||||
('total_subdomains', models.IntegerField(default=0, help_text='子域名总数')),
|
||||
('total_ips', models.IntegerField(default=0, help_text='IP地址总数')),
|
||||
('total_endpoints', models.IntegerField(default=0, help_text='端点总数')),
|
||||
('total_websites', models.IntegerField(default=0, help_text='网站总数')),
|
||||
('total_vulns', models.IntegerField(default=0, help_text='漏洞总数')),
|
||||
('total_assets', models.IntegerField(default=0, help_text='总资产数(子域名+IP+端点+网站)')),
|
||||
('prev_targets', models.IntegerField(default=0, help_text='上次目标总数')),
|
||||
('prev_subdomains', models.IntegerField(default=0, help_text='上次子域名总数')),
|
||||
('prev_ips', models.IntegerField(default=0, help_text='上次IP地址总数')),
|
||||
('prev_endpoints', models.IntegerField(default=0, help_text='上次端点总数')),
|
||||
('prev_websites', models.IntegerField(default=0, help_text='上次网站总数')),
|
||||
('prev_vulns', models.IntegerField(default=0, help_text='上次漏洞总数')),
|
||||
('prev_assets', models.IntegerField(default=0, help_text='上次总资产数')),
|
||||
('updated_at', models.DateTimeField(auto_now=True, help_text='最后更新时间')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '资产统计',
|
||||
'verbose_name_plural': '资产统计',
|
||||
'db_table': 'asset_statistics',
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='StatisticsHistory',
|
||||
fields=[
|
||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('date', models.DateField(help_text='统计日期', unique=True)),
|
||||
('total_targets', models.IntegerField(default=0, help_text='目标总数')),
|
||||
('total_subdomains', models.IntegerField(default=0, help_text='子域名总数')),
|
||||
('total_ips', models.IntegerField(default=0, help_text='IP地址总数')),
|
||||
('total_endpoints', models.IntegerField(default=0, help_text='端点总数')),
|
||||
('total_websites', models.IntegerField(default=0, help_text='网站总数')),
|
||||
('total_vulns', models.IntegerField(default=0, help_text='漏洞总数')),
|
||||
('total_assets', models.IntegerField(default=0, help_text='总资产数')),
|
||||
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
|
||||
('updated_at', models.DateTimeField(auto_now=True, help_text='更新时间')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '统计历史',
|
||||
'verbose_name_plural': '统计历史',
|
||||
'db_table': 'statistics_history',
|
||||
'ordering': ['-date'],
|
||||
'indexes': [models.Index(fields=['date'], name='statistics__date_1d29cd_idx')],
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='Directory',
|
||||
fields=[
|
||||
('id', models.AutoField(primary_key=True, serialize=False)),
|
||||
('url', models.CharField(help_text='完整请求 URL', max_length=2000)),
|
||||
('status', models.IntegerField(blank=True, help_text='HTTP 响应状态码', null=True)),
|
||||
('content_length', models.BigIntegerField(blank=True, help_text='响应体字节大小(Content-Length 或实际长度)', null=True)),
|
||||
('words', models.IntegerField(blank=True, help_text='响应体中单词数量(按空格分割)', null=True)),
|
||||
('lines', models.IntegerField(blank=True, help_text='响应体行数(按换行符分割)', null=True)),
|
||||
('content_type', models.CharField(blank=True, default='', help_text='响应头 Content-Type 值', max_length=200)),
|
||||
('duration', models.BigIntegerField(blank=True, help_text='请求耗时(单位:纳秒)', null=True)),
|
||||
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
|
||||
('target', models.ForeignKey(help_text='所属的扫描目标', on_delete=django.db.models.deletion.CASCADE, related_name='directories', to='targets.target')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '目录',
|
||||
'verbose_name_plural': '目录',
|
||||
'db_table': 'directory',
|
||||
'ordering': ['-created_at'],
|
||||
'indexes': [models.Index(fields=['-created_at'], name='directory_created_2cef03_idx'), models.Index(fields=['target'], name='directory_target__e310c8_idx'), models.Index(fields=['url'], name='directory_url_ba40cd_idx'), models.Index(fields=['status'], name='directory_status_40bbe6_idx'), django.contrib.postgres.indexes.GinIndex(fields=['url'], name='directory_url_trgm_idx', opclasses=['gin_trgm_ops'])],
|
||||
'constraints': [models.UniqueConstraint(fields=('target', 'url'), name='unique_directory_url_target')],
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='DirectorySnapshot',
|
||||
fields=[
|
||||
('id', models.AutoField(primary_key=True, serialize=False)),
|
||||
('url', models.CharField(help_text='目录URL', max_length=2000)),
|
||||
('status', models.IntegerField(blank=True, help_text='HTTP状态码', null=True)),
|
||||
('content_length', models.BigIntegerField(blank=True, help_text='内容长度', null=True)),
|
||||
('words', models.IntegerField(blank=True, help_text='响应体中单词数量(按空格分割)', null=True)),
|
||||
('lines', models.IntegerField(blank=True, help_text='响应体行数(按换行符分割)', null=True)),
|
||||
('content_type', models.CharField(blank=True, default='', help_text='响应头 Content-Type 值', max_length=200)),
|
||||
('duration', models.BigIntegerField(blank=True, help_text='请求耗时(单位:纳秒)', null=True)),
|
||||
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
|
||||
('scan', models.ForeignKey(help_text='所属的扫描任务', on_delete=django.db.models.deletion.CASCADE, related_name='directory_snapshots', to='scan.scan')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '目录快照',
|
||||
'verbose_name_plural': '目录快照',
|
||||
'db_table': 'directory_snapshot',
|
||||
'ordering': ['-created_at'],
|
||||
'indexes': [models.Index(fields=['scan'], name='directory_s_scan_id_c45900_idx'), models.Index(fields=['url'], name='directory_s_url_b4b72b_idx'), models.Index(fields=['status'], name='directory_s_status_e9f57e_idx'), models.Index(fields=['content_type'], name='directory_s_content_45e864_idx'), models.Index(fields=['-created_at'], name='directory_s_created_eb9d27_idx'), django.contrib.postgres.indexes.GinIndex(fields=['url'], name='dir_snap_url_trgm', opclasses=['gin_trgm_ops'])],
|
||||
'constraints': [models.UniqueConstraint(fields=('scan', 'url'), name='unique_directory_per_scan_snapshot')],
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='Endpoint',
|
||||
fields=[
|
||||
('id', models.AutoField(primary_key=True, serialize=False)),
|
||||
('url', models.TextField(help_text='最终访问的完整URL')),
|
||||
('host', models.CharField(blank=True, default='', help_text='主机名(域名或IP地址)', max_length=253)),
|
||||
('location', models.TextField(blank=True, default='', help_text='重定向地址(HTTP 3xx 响应头 Location)')),
|
||||
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
|
||||
('title', models.TextField(blank=True, default='', help_text='网页标题(HTML <title> 标签内容)')),
|
||||
('webserver', models.TextField(blank=True, default='', help_text='服务器类型(HTTP 响应头 Server 值)')),
|
||||
('response_body', models.TextField(blank=True, default='', help_text='HTTP响应体')),
|
||||
('content_type', models.TextField(blank=True, default='', help_text='响应类型(HTTP Content-Type 响应头)')),
|
||||
('tech', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='技术栈(服务器/框架/语言等)', size=None)),
|
||||
('status_code', models.IntegerField(blank=True, help_text='HTTP状态码', null=True)),
|
||||
('content_length', models.IntegerField(blank=True, help_text='响应体大小(单位字节)', null=True)),
|
||||
('vhost', models.BooleanField(blank=True, help_text='是否支持虚拟主机', null=True)),
|
||||
('matched_gf_patterns', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='匹配的GF模式列表,用于识别敏感端点(如api, debug, config等)', size=None)),
|
||||
('response_headers', models.TextField(blank=True, default='', help_text='原始HTTP响应头')),
|
||||
('target', models.ForeignKey(help_text='所属的扫描目标(主关联字段,表示所属关系,不能为空)', on_delete=django.db.models.deletion.CASCADE, related_name='endpoints', to='targets.target')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '端点',
|
||||
'verbose_name_plural': '端点',
|
||||
'db_table': 'endpoint',
|
||||
'ordering': ['-created_at'],
|
||||
'indexes': [models.Index(fields=['-created_at'], name='endpoint_created_44fe9c_idx'), models.Index(fields=['target'], name='endpoint_target__7f9065_idx'), models.Index(fields=['url'], name='endpoint_url_30f66e_idx'), models.Index(fields=['host'], name='endpoint_host_5b4cc8_idx'), models.Index(fields=['status_code'], name='endpoint_status__5d4fdd_idx'), models.Index(fields=['title'], name='endpoint_title_29e26c_idx'), django.contrib.postgres.indexes.GinIndex(fields=['tech'], name='endpoint_tech_2bfa7c_gin'), django.contrib.postgres.indexes.GinIndex(fields=['response_headers'], name='endpoint_resp_headers_trgm_idx', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['url'], name='endpoint_url_trgm_idx', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['title'], name='endpoint_title_trgm_idx', opclasses=['gin_trgm_ops'])],
|
||||
'constraints': [models.UniqueConstraint(fields=('url', 'target'), name='unique_endpoint_url_target')],
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='EndpointSnapshot',
|
||||
fields=[
|
||||
('id', models.AutoField(primary_key=True, serialize=False)),
|
||||
('url', models.TextField(help_text='端点URL')),
|
||||
('host', models.CharField(blank=True, default='', help_text='主机名(域名或IP地址)', max_length=253)),
|
||||
('title', models.TextField(blank=True, default='', help_text='页面标题')),
|
||||
('status_code', models.IntegerField(blank=True, help_text='HTTP状态码', null=True)),
|
||||
('content_length', models.IntegerField(blank=True, help_text='内容长度', null=True)),
|
||||
('location', models.TextField(blank=True, default='', help_text='重定向位置')),
|
||||
('webserver', models.TextField(blank=True, default='', help_text='Web服务器')),
|
||||
('content_type', models.TextField(blank=True, default='', help_text='内容类型')),
|
||||
('tech', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='技术栈', size=None)),
|
||||
('response_body', models.TextField(blank=True, default='', help_text='HTTP响应体')),
|
||||
('vhost', models.BooleanField(blank=True, help_text='虚拟主机标志', null=True)),
|
||||
('matched_gf_patterns', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='匹配的GF模式列表', size=None)),
|
||||
('response_headers', models.TextField(blank=True, default='', help_text='原始HTTP响应头')),
|
||||
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
|
||||
('scan', models.ForeignKey(help_text='所属的扫描任务', on_delete=django.db.models.deletion.CASCADE, related_name='endpoint_snapshots', to='scan.scan')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '端点快照',
|
||||
'verbose_name_plural': '端点快照',
|
||||
'db_table': 'endpoint_snapshot',
|
||||
'ordering': ['-created_at'],
|
||||
'indexes': [models.Index(fields=['scan'], name='endpoint_sn_scan_id_6ac9a7_idx'), models.Index(fields=['url'], name='endpoint_sn_url_205160_idx'), models.Index(fields=['host'], name='endpoint_sn_host_577bfd_idx'), models.Index(fields=['title'], name='endpoint_sn_title_516a05_idx'), models.Index(fields=['status_code'], name='endpoint_sn_status__83efb0_idx'), models.Index(fields=['webserver'], name='endpoint_sn_webserv_66be83_idx'), models.Index(fields=['-created_at'], name='endpoint_sn_created_21fb5b_idx'), django.contrib.postgres.indexes.GinIndex(fields=['tech'], name='endpoint_sn_tech_0d0752_gin'), django.contrib.postgres.indexes.GinIndex(fields=['response_headers'], name='ep_snap_resp_hdr_trgm', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['url'], name='ep_snap_url_trgm', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['title'], name='ep_snap_title_trgm', opclasses=['gin_trgm_ops'])],
|
||||
'constraints': [models.UniqueConstraint(fields=('scan', 'url'), name='unique_endpoint_per_scan_snapshot')],
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='HostPortMapping',
|
||||
fields=[
|
||||
('id', models.AutoField(primary_key=True, serialize=False)),
|
||||
('host', models.CharField(help_text='主机名(域名或IP)', max_length=1000)),
|
||||
('ip', models.GenericIPAddressField(help_text='IP地址')),
|
||||
('port', models.IntegerField(help_text='端口号(1-65535)', validators=[django.core.validators.MinValueValidator(1, message='端口号必须大于等于1'), django.core.validators.MaxValueValidator(65535, message='端口号必须小于等于65535')])),
|
||||
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
|
||||
('target', models.ForeignKey(help_text='所属的扫描目标', on_delete=django.db.models.deletion.CASCADE, related_name='host_port_mappings', to='targets.target')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '主机端口映射',
|
||||
'verbose_name_plural': '主机端口映射',
|
||||
'db_table': 'host_port_mapping',
|
||||
'ordering': ['-created_at'],
|
||||
'indexes': [models.Index(fields=['target'], name='host_port_m_target__943e9b_idx'), models.Index(fields=['host'], name='host_port_m_host_f78363_idx'), models.Index(fields=['ip'], name='host_port_m_ip_2e6f02_idx'), models.Index(fields=['port'], name='host_port_m_port_9fb9ff_idx'), models.Index(fields=['host', 'ip'], name='host_port_m_host_3ce245_idx'), models.Index(fields=['-created_at'], name='host_port_m_created_11cd22_idx')],
|
||||
'constraints': [models.UniqueConstraint(fields=('target', 'host', 'ip', 'port'), name='unique_target_host_ip_port')],
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='HostPortMappingSnapshot',
|
||||
fields=[
|
||||
('id', models.AutoField(primary_key=True, serialize=False)),
|
||||
('host', models.CharField(help_text='主机名(域名或IP)', max_length=1000)),
|
||||
('ip', models.GenericIPAddressField(help_text='IP地址')),
|
||||
('port', models.IntegerField(help_text='端口号(1-65535)', validators=[django.core.validators.MinValueValidator(1, message='端口号必须大于等于1'), django.core.validators.MaxValueValidator(65535, message='端口号必须小于等于65535')])),
|
||||
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
|
||||
('scan', models.ForeignKey(help_text='所属的扫描任务(主关联)', on_delete=django.db.models.deletion.CASCADE, related_name='host_port_mapping_snapshots', to='scan.scan')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '主机端口映射快照',
|
||||
'verbose_name_plural': '主机端口映射快照',
|
||||
'db_table': 'host_port_mapping_snapshot',
|
||||
'ordering': ['-created_at'],
|
||||
'indexes': [models.Index(fields=['scan'], name='host_port_m_scan_id_50ba0b_idx'), models.Index(fields=['host'], name='host_port_m_host_e99054_idx'), models.Index(fields=['ip'], name='host_port_m_ip_54818c_idx'), models.Index(fields=['port'], name='host_port_m_port_ed7b48_idx'), models.Index(fields=['host', 'ip'], name='host_port_m_host_8a463a_idx'), models.Index(fields=['scan', 'host'], name='host_port_m_scan_id_426fdb_idx'), models.Index(fields=['-created_at'], name='host_port_m_created_fb28b8_idx')],
|
||||
'constraints': [models.UniqueConstraint(fields=('scan', 'host', 'ip', 'port'), name='unique_scan_host_ip_port_snapshot')],
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='Subdomain',
|
||||
fields=[
|
||||
('id', models.AutoField(primary_key=True, serialize=False)),
|
||||
('name', models.CharField(help_text='子域名名称', max_length=1000)),
|
||||
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
|
||||
('target', models.ForeignKey(help_text='所属的扫描目标(主关联字段,表示所属关系,不能为空)', on_delete=django.db.models.deletion.CASCADE, related_name='subdomains', to='targets.target')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '子域名',
|
||||
'verbose_name_plural': '子域名',
|
||||
'db_table': 'subdomain',
|
||||
'ordering': ['-created_at'],
|
||||
'indexes': [models.Index(fields=['-created_at'], name='subdomain_created_e187a8_idx'), models.Index(fields=['name', 'target'], name='subdomain_name_60e1d0_idx'), models.Index(fields=['target'], name='subdomain_target__e409f0_idx'), models.Index(fields=['name'], name='subdomain_name_d40ba7_idx'), django.contrib.postgres.indexes.GinIndex(fields=['name'], name='subdomain_name_trgm_idx', opclasses=['gin_trgm_ops'])],
|
||||
'constraints': [models.UniqueConstraint(fields=('name', 'target'), name='unique_subdomain_name_target')],
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='SubdomainSnapshot',
|
||||
fields=[
|
||||
('id', models.AutoField(primary_key=True, serialize=False)),
|
||||
('name', models.CharField(help_text='子域名名称', max_length=1000)),
|
||||
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
|
||||
('scan', models.ForeignKey(help_text='所属的扫描任务', on_delete=django.db.models.deletion.CASCADE, related_name='subdomain_snapshots', to='scan.scan')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '子域名快照',
|
||||
'verbose_name_plural': '子域名快照',
|
||||
'db_table': 'subdomain_snapshot',
|
||||
'ordering': ['-created_at'],
|
||||
'indexes': [models.Index(fields=['scan'], name='subdomain_s_scan_id_68c253_idx'), models.Index(fields=['name'], name='subdomain_s_name_2da42b_idx'), models.Index(fields=['-created_at'], name='subdomain_s_created_d2b48e_idx'), django.contrib.postgres.indexes.GinIndex(fields=['name'], name='subdomain_snap_name_trgm', opclasses=['gin_trgm_ops'])],
|
||||
'constraints': [models.UniqueConstraint(fields=('scan', 'name'), name='unique_subdomain_per_scan_snapshot')],
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='Vulnerability',
|
||||
fields=[
|
||||
('id', models.AutoField(primary_key=True, serialize=False)),
|
||||
('url', models.CharField(help_text='漏洞所在的URL', max_length=2000)),
|
||||
('vuln_type', models.CharField(help_text='漏洞类型(如 xss, sqli)', max_length=100)),
|
||||
('severity', models.CharField(choices=[('unknown', '未知'), ('info', '信息'), ('low', '低'), ('medium', '中'), ('high', '高'), ('critical', '危急')], default='unknown', help_text='严重性(未知/信息/低/中/高/危急)', max_length=20)),
|
||||
('source', models.CharField(blank=True, default='', help_text='来源工具(如 dalfox, nuclei, crlfuzz)', max_length=50)),
|
||||
('cvss_score', models.DecimalField(blank=True, decimal_places=1, help_text='CVSS 评分(0.0-10.0)', max_digits=3, null=True)),
|
||||
('description', models.TextField(blank=True, default='', help_text='漏洞描述')),
|
||||
('raw_output', models.JSONField(blank=True, default=dict, help_text='工具原始输出')),
|
||||
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
|
||||
('target', models.ForeignKey(help_text='所属的扫描目标', on_delete=django.db.models.deletion.CASCADE, related_name='vulnerabilities', to='targets.target')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '漏洞',
|
||||
'verbose_name_plural': '漏洞',
|
||||
'db_table': 'vulnerability',
|
||||
'ordering': ['-created_at'],
|
||||
'indexes': [models.Index(fields=['target'], name='vulnerabili_target__755a02_idx'), models.Index(fields=['vuln_type'], name='vulnerabili_vuln_ty_3010cd_idx'), models.Index(fields=['severity'], name='vulnerabili_severit_1a798b_idx'), models.Index(fields=['source'], name='vulnerabili_source_7c7552_idx'), models.Index(fields=['url'], name='vulnerabili_url_4dcc4d_idx'), models.Index(fields=['-created_at'], name='vulnerabili_created_e25ff7_idx')],
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='VulnerabilitySnapshot',
|
||||
fields=[
|
||||
('id', models.AutoField(primary_key=True, serialize=False)),
|
||||
('url', models.CharField(help_text='漏洞所在的URL', max_length=2000)),
|
||||
('vuln_type', models.CharField(help_text='漏洞类型(如 xss, sqli)', max_length=100)),
|
||||
('severity', models.CharField(choices=[('unknown', '未知'), ('info', '信息'), ('low', '低'), ('medium', '中'), ('high', '高'), ('critical', '危急')], default='unknown', help_text='严重性(未知/信息/低/中/高/危急)', max_length=20)),
|
||||
('source', models.CharField(blank=True, default='', help_text='来源工具(如 dalfox, nuclei, crlfuzz)', max_length=50)),
|
||||
('cvss_score', models.DecimalField(blank=True, decimal_places=1, help_text='CVSS 评分(0.0-10.0)', max_digits=3, null=True)),
|
||||
('description', models.TextField(blank=True, default='', help_text='漏洞描述')),
|
||||
('raw_output', models.JSONField(blank=True, default=dict, help_text='工具原始输出')),
|
||||
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
|
||||
('scan', models.ForeignKey(help_text='所属的扫描任务', on_delete=django.db.models.deletion.CASCADE, related_name='vulnerability_snapshots', to='scan.scan')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '漏洞快照',
|
||||
'verbose_name_plural': '漏洞快照',
|
||||
'db_table': 'vulnerability_snapshot',
|
||||
'ordering': ['-created_at'],
|
||||
'indexes': [models.Index(fields=['scan'], name='vulnerabili_scan_id_7b81c9_idx'), models.Index(fields=['url'], name='vulnerabili_url_11a707_idx'), models.Index(fields=['vuln_type'], name='vulnerabili_vuln_ty_6b90ee_idx'), models.Index(fields=['severity'], name='vulnerabili_severit_4eae0d_idx'), models.Index(fields=['source'], name='vulnerabili_source_968b1f_idx'), models.Index(fields=['-created_at'], name='vulnerabili_created_53a12e_idx')],
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='WebSite',
|
||||
fields=[
|
||||
('id', models.AutoField(primary_key=True, serialize=False)),
|
||||
('url', models.TextField(help_text='最终访问的完整URL')),
|
||||
('host', models.CharField(blank=True, default='', help_text='主机名(域名或IP地址)', max_length=253)),
|
||||
('location', models.TextField(blank=True, default='', help_text='重定向地址(HTTP 3xx 响应头 Location)')),
|
||||
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
|
||||
('title', models.TextField(blank=True, default='', help_text='网页标题(HTML <title> 标签内容)')),
|
||||
('webserver', models.TextField(blank=True, default='', help_text='服务器类型(HTTP 响应头 Server 值)')),
|
||||
('response_body', models.TextField(blank=True, default='', help_text='HTTP响应体')),
|
||||
('content_type', models.TextField(blank=True, default='', help_text='响应类型(HTTP Content-Type 响应头)')),
|
||||
('tech', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='技术栈(服务器/框架/语言等)', size=None)),
|
||||
('status_code', models.IntegerField(blank=True, help_text='HTTP状态码', null=True)),
|
||||
('content_length', models.IntegerField(blank=True, help_text='响应体大小(单位字节)', null=True)),
|
||||
('vhost', models.BooleanField(blank=True, help_text='是否支持虚拟主机', null=True)),
|
||||
('response_headers', models.TextField(blank=True, default='', help_text='原始HTTP响应头')),
|
||||
('target', models.ForeignKey(help_text='所属的扫描目标(主关联字段,表示所属关系,不能为空)', on_delete=django.db.models.deletion.CASCADE, related_name='websites', to='targets.target')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '站点',
|
||||
'verbose_name_plural': '站点',
|
||||
'db_table': 'website',
|
||||
'ordering': ['-created_at'],
|
||||
'indexes': [models.Index(fields=['-created_at'], name='website_created_c9cfd2_idx'), models.Index(fields=['url'], name='website_url_b18883_idx'), models.Index(fields=['host'], name='website_host_996b50_idx'), models.Index(fields=['target'], name='website_target__2a353b_idx'), models.Index(fields=['title'], name='website_title_c2775b_idx'), models.Index(fields=['status_code'], name='website_status__51663d_idx'), django.contrib.postgres.indexes.GinIndex(fields=['tech'], name='website_tech_e3f0cb_gin'), django.contrib.postgres.indexes.GinIndex(fields=['response_headers'], name='website_resp_headers_trgm_idx', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['url'], name='website_url_trgm_idx', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['title'], name='website_title_trgm_idx', opclasses=['gin_trgm_ops'])],
|
||||
'constraints': [models.UniqueConstraint(fields=('url', 'target'), name='unique_website_url_target')],
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='WebsiteSnapshot',
|
||||
fields=[
|
||||
('id', models.AutoField(primary_key=True, serialize=False)),
|
||||
('url', models.TextField(help_text='站点URL')),
|
||||
('host', models.CharField(blank=True, default='', help_text='主机名(域名或IP地址)', max_length=253)),
|
||||
('title', models.TextField(blank=True, default='', help_text='页面标题')),
|
||||
('status_code', models.IntegerField(blank=True, help_text='HTTP状态码', null=True)),
|
||||
('content_length', models.BigIntegerField(blank=True, help_text='内容长度', null=True)),
|
||||
('location', models.TextField(blank=True, default='', help_text='重定向位置')),
|
||||
('webserver', models.TextField(blank=True, default='', help_text='Web服务器')),
|
||||
('content_type', models.TextField(blank=True, default='', help_text='内容类型')),
|
||||
('tech', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='技术栈', size=None)),
|
||||
('response_body', models.TextField(blank=True, default='', help_text='HTTP响应体')),
|
||||
('vhost', models.BooleanField(blank=True, help_text='虚拟主机标志', null=True)),
|
||||
('response_headers', models.TextField(blank=True, default='', help_text='原始HTTP响应头')),
|
||||
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
|
||||
('scan', models.ForeignKey(help_text='所属的扫描任务', on_delete=django.db.models.deletion.CASCADE, related_name='website_snapshots', to='scan.scan')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '网站快照',
|
||||
'verbose_name_plural': '网站快照',
|
||||
'db_table': 'website_snapshot',
|
||||
'ordering': ['-created_at'],
|
||||
'indexes': [models.Index(fields=['scan'], name='website_sna_scan_id_26b6dc_idx'), models.Index(fields=['url'], name='website_sna_url_801a70_idx'), models.Index(fields=['host'], name='website_sna_host_348fe1_idx'), models.Index(fields=['title'], name='website_sna_title_b1a5ee_idx'), models.Index(fields=['-created_at'], name='website_sna_created_2c149a_idx'), django.contrib.postgres.indexes.GinIndex(fields=['tech'], name='website_sna_tech_3d6d2f_gin'), django.contrib.postgres.indexes.GinIndex(fields=['response_headers'], name='ws_snap_resp_hdr_trgm', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['url'], name='ws_snap_url_trgm', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['title'], name='ws_snap_title_trgm', opclasses=['gin_trgm_ops'])],
|
||||
'constraints': [models.UniqueConstraint(fields=('scan', 'url'), name='unique_website_per_scan_snapshot')],
|
||||
},
|
||||
),
|
||||
]
|
||||
104
backend/apps/asset/migrations/0002_create_search_views.py
Normal file
104
backend/apps/asset/migrations/0002_create_search_views.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
创建资产搜索物化视图(使用 pg_ivm 增量维护)
|
||||
|
||||
这些视图用于资产搜索功能,提供高性能的全文搜索能力。
|
||||
"""
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
"""创建资产搜索所需的增量物化视图"""
|
||||
|
||||
dependencies = [
|
||||
('asset', '0001_initial'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
# 1. 确保 pg_ivm 扩展已安装
|
||||
migrations.RunSQL(
|
||||
sql="CREATE EXTENSION IF NOT EXISTS pg_ivm;",
|
||||
reverse_sql="DROP EXTENSION IF EXISTS pg_ivm;",
|
||||
),
|
||||
|
||||
# 2. 创建 Website 搜索视图
|
||||
# 注意:pg_ivm 不支持 ArrayField,所以 tech 字段需要从原表 JOIN 获取
|
||||
migrations.RunSQL(
|
||||
sql="""
|
||||
SELECT pgivm.create_immv('asset_search_view', $$
|
||||
SELECT
|
||||
w.id,
|
||||
w.url,
|
||||
w.host,
|
||||
w.title,
|
||||
w.status_code,
|
||||
w.response_headers,
|
||||
w.response_body,
|
||||
w.content_type,
|
||||
w.content_length,
|
||||
w.webserver,
|
||||
w.location,
|
||||
w.vhost,
|
||||
w.created_at,
|
||||
w.target_id
|
||||
FROM website w
|
||||
$$);
|
||||
""",
|
||||
reverse_sql="DROP TABLE IF EXISTS asset_search_view CASCADE;",
|
||||
),
|
||||
|
||||
# 3. 创建 Endpoint 搜索视图
|
||||
migrations.RunSQL(
|
||||
sql="""
|
||||
SELECT pgivm.create_immv('endpoint_search_view', $$
|
||||
SELECT
|
||||
e.id,
|
||||
e.url,
|
||||
e.host,
|
||||
e.title,
|
||||
e.status_code,
|
||||
e.response_headers,
|
||||
e.response_body,
|
||||
e.content_type,
|
||||
e.content_length,
|
||||
e.webserver,
|
||||
e.location,
|
||||
e.vhost,
|
||||
e.created_at,
|
||||
e.target_id
|
||||
FROM endpoint e
|
||||
$$);
|
||||
""",
|
||||
reverse_sql="DROP TABLE IF EXISTS endpoint_search_view CASCADE;",
|
||||
),
|
||||
|
||||
# 4. 为搜索视图创建索引(加速查询)
|
||||
migrations.RunSQL(
|
||||
sql=[
|
||||
# Website 搜索视图索引
|
||||
"CREATE INDEX IF NOT EXISTS asset_search_view_host_idx ON asset_search_view (host);",
|
||||
"CREATE INDEX IF NOT EXISTS asset_search_view_url_idx ON asset_search_view (url);",
|
||||
"CREATE INDEX IF NOT EXISTS asset_search_view_title_idx ON asset_search_view (title);",
|
||||
"CREATE INDEX IF NOT EXISTS asset_search_view_status_idx ON asset_search_view (status_code);",
|
||||
"CREATE INDEX IF NOT EXISTS asset_search_view_created_idx ON asset_search_view (created_at DESC);",
|
||||
# Endpoint 搜索视图索引
|
||||
"CREATE INDEX IF NOT EXISTS endpoint_search_view_host_idx ON endpoint_search_view (host);",
|
||||
"CREATE INDEX IF NOT EXISTS endpoint_search_view_url_idx ON endpoint_search_view (url);",
|
||||
"CREATE INDEX IF NOT EXISTS endpoint_search_view_title_idx ON endpoint_search_view (title);",
|
||||
"CREATE INDEX IF NOT EXISTS endpoint_search_view_status_idx ON endpoint_search_view (status_code);",
|
||||
"CREATE INDEX IF NOT EXISTS endpoint_search_view_created_idx ON endpoint_search_view (created_at DESC);",
|
||||
],
|
||||
reverse_sql=[
|
||||
"DROP INDEX IF EXISTS asset_search_view_host_idx;",
|
||||
"DROP INDEX IF EXISTS asset_search_view_url_idx;",
|
||||
"DROP INDEX IF EXISTS asset_search_view_title_idx;",
|
||||
"DROP INDEX IF EXISTS asset_search_view_status_idx;",
|
||||
"DROP INDEX IF EXISTS asset_search_view_created_idx;",
|
||||
"DROP INDEX IF EXISTS endpoint_search_view_host_idx;",
|
||||
"DROP INDEX IF EXISTS endpoint_search_view_url_idx;",
|
||||
"DROP INDEX IF EXISTS endpoint_search_view_title_idx;",
|
||||
"DROP INDEX IF EXISTS endpoint_search_view_status_idx;",
|
||||
"DROP INDEX IF EXISTS endpoint_search_view_created_idx;",
|
||||
],
|
||||
),
|
||||
]
|
||||
@@ -1,6 +1,7 @@
|
||||
|
||||
from django.db import models
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
from django.contrib.postgres.indexes import GinIndex
|
||||
from django.core.validators import MinValueValidator, MaxValueValidator
|
||||
|
||||
|
||||
@@ -34,6 +35,12 @@ class Subdomain(models.Model):
|
||||
models.Index(fields=['name', 'target']), # 复合索引,优化 get_by_names_and_target_id 批量查询
|
||||
models.Index(fields=['target']), # 优化从target_id快速查找下面的子域名
|
||||
models.Index(fields=['name']), # 优化从name快速查找子域名,搜索场景
|
||||
# pg_trgm GIN 索引,支持 LIKE '%keyword%' 模糊搜索
|
||||
GinIndex(
|
||||
name='subdomain_name_trgm_idx',
|
||||
fields=['name'],
|
||||
opclasses=['gin_trgm_ops']
|
||||
),
|
||||
]
|
||||
constraints = [
|
||||
# 普通唯一约束:name + target 组合唯一
|
||||
@@ -58,40 +65,35 @@ class Endpoint(models.Model):
|
||||
help_text='所属的扫描目标(主关联字段,表示所属关系,不能为空)'
|
||||
)
|
||||
|
||||
url = models.CharField(max_length=2000, help_text='最终访问的完整URL')
|
||||
url = models.TextField(help_text='最终访问的完整URL')
|
||||
host = models.CharField(
|
||||
max_length=253,
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='主机名(域名或IP地址)'
|
||||
)
|
||||
location = models.CharField(
|
||||
max_length=1000,
|
||||
location = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='重定向地址(HTTP 3xx 响应头 Location)'
|
||||
)
|
||||
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
|
||||
title = models.CharField(
|
||||
max_length=1000,
|
||||
title = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='网页标题(HTML <title> 标签内容)'
|
||||
)
|
||||
webserver = models.CharField(
|
||||
max_length=200,
|
||||
webserver = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='服务器类型(HTTP 响应头 Server 值)'
|
||||
)
|
||||
body_preview = models.CharField(
|
||||
max_length=1000,
|
||||
response_body = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='响应正文前N个字符(默认100个字符)'
|
||||
help_text='HTTP响应体'
|
||||
)
|
||||
content_type = models.CharField(
|
||||
max_length=200,
|
||||
content_type = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='响应类型(HTTP Content-Type 响应头)'
|
||||
@@ -123,6 +125,11 @@ class Endpoint(models.Model):
|
||||
default=list,
|
||||
help_text='匹配的GF模式列表,用于识别敏感端点(如api, debug, config等)'
|
||||
)
|
||||
response_headers = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='原始HTTP响应头'
|
||||
)
|
||||
|
||||
class Meta:
|
||||
db_table = 'endpoint'
|
||||
@@ -131,11 +138,28 @@ class Endpoint(models.Model):
|
||||
ordering = ['-created_at']
|
||||
indexes = [
|
||||
models.Index(fields=['-created_at']),
|
||||
models.Index(fields=['target']), # 优化从target_id快速查找下面的端点(主关联字段)
|
||||
models.Index(fields=['target']), # 优化从 target_id快速查找下面的端点(主关联字段)
|
||||
models.Index(fields=['url']), # URL索引,优化查询性能
|
||||
models.Index(fields=['host']), # host索引,优化根据主机名查询
|
||||
models.Index(fields=['status_code']), # 状态码索引,优化筛选
|
||||
models.Index(fields=['title']), # title索引,优化智能过滤搜索
|
||||
GinIndex(fields=['tech']), # GIN索引,优化 tech 数组字段的 __contains 查询
|
||||
# pg_trgm GIN 索引,支持 LIKE '%keyword%' 模糊搜索
|
||||
GinIndex(
|
||||
name='endpoint_resp_headers_trgm_idx',
|
||||
fields=['response_headers'],
|
||||
opclasses=['gin_trgm_ops']
|
||||
),
|
||||
GinIndex(
|
||||
name='endpoint_url_trgm_idx',
|
||||
fields=['url'],
|
||||
opclasses=['gin_trgm_ops']
|
||||
),
|
||||
GinIndex(
|
||||
name='endpoint_title_trgm_idx',
|
||||
fields=['title'],
|
||||
opclasses=['gin_trgm_ops']
|
||||
),
|
||||
]
|
||||
constraints = [
|
||||
# 普通唯一约束:url + target 组合唯一
|
||||
@@ -160,40 +184,35 @@ class WebSite(models.Model):
|
||||
help_text='所属的扫描目标(主关联字段,表示所属关系,不能为空)'
|
||||
)
|
||||
|
||||
url = models.CharField(max_length=2000, help_text='最终访问的完整URL')
|
||||
url = models.TextField(help_text='最终访问的完整URL')
|
||||
host = models.CharField(
|
||||
max_length=253,
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='主机名(域名或IP地址)'
|
||||
)
|
||||
location = models.CharField(
|
||||
max_length=1000,
|
||||
location = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='重定向地址(HTTP 3xx 响应头 Location)'
|
||||
)
|
||||
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
|
||||
title = models.CharField(
|
||||
max_length=1000,
|
||||
title = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='网页标题(HTML <title> 标签内容)'
|
||||
)
|
||||
webserver = models.CharField(
|
||||
max_length=200,
|
||||
webserver = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='服务器类型(HTTP 响应头 Server 值)'
|
||||
)
|
||||
body_preview = models.CharField(
|
||||
max_length=1000,
|
||||
response_body = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='响应正文前N个字符(默认100个字符)'
|
||||
help_text='HTTP响应体'
|
||||
)
|
||||
content_type = models.CharField(
|
||||
max_length=200,
|
||||
content_type = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='响应类型(HTTP Content-Type 响应头)'
|
||||
@@ -219,6 +238,11 @@ class WebSite(models.Model):
|
||||
blank=True,
|
||||
help_text='是否支持虚拟主机'
|
||||
)
|
||||
response_headers = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='原始HTTP响应头'
|
||||
)
|
||||
|
||||
class Meta:
|
||||
db_table = 'website'
|
||||
@@ -229,9 +253,26 @@ class WebSite(models.Model):
|
||||
models.Index(fields=['-created_at']),
|
||||
models.Index(fields=['url']), # URL索引,优化查询性能
|
||||
models.Index(fields=['host']), # host索引,优化根据主机名查询
|
||||
models.Index(fields=['target']), # 优化从target_id快速查找下面的站点
|
||||
models.Index(fields=['target']), # 优化从 target_id快速查找下面的站点
|
||||
models.Index(fields=['title']), # title索引,优化智能过滤搜索
|
||||
models.Index(fields=['status_code']), # 状态码索引,优化智能过滤搜索
|
||||
GinIndex(fields=['tech']), # GIN索引,优化 tech 数组字段的 __contains 查询
|
||||
# pg_trgm GIN 索引,支持 LIKE '%keyword%' 模糊搜索
|
||||
GinIndex(
|
||||
name='website_resp_headers_trgm_idx',
|
||||
fields=['response_headers'],
|
||||
opclasses=['gin_trgm_ops']
|
||||
),
|
||||
GinIndex(
|
||||
name='website_url_trgm_idx',
|
||||
fields=['url'],
|
||||
opclasses=['gin_trgm_ops']
|
||||
),
|
||||
GinIndex(
|
||||
name='website_title_trgm_idx',
|
||||
fields=['title'],
|
||||
opclasses=['gin_trgm_ops']
|
||||
),
|
||||
]
|
||||
constraints = [
|
||||
# 普通唯一约束:url + target 组合唯一
|
||||
@@ -308,6 +349,12 @@ class Directory(models.Model):
|
||||
models.Index(fields=['target']), # 优化从target_id快速查找下面的目录
|
||||
models.Index(fields=['url']), # URL索引,优化搜索和唯一约束
|
||||
models.Index(fields=['status']), # 状态码索引,优化筛选
|
||||
# pg_trgm GIN 索引,支持 LIKE '%keyword%' 模糊搜索
|
||||
GinIndex(
|
||||
name='directory_url_trgm_idx',
|
||||
fields=['url'],
|
||||
opclasses=['gin_trgm_ops']
|
||||
),
|
||||
]
|
||||
constraints = [
|
||||
# 普通唯一约束:target + url 组合唯一
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from django.db import models
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
from django.contrib.postgres.indexes import GinIndex
|
||||
from django.core.validators import MinValueValidator, MaxValueValidator
|
||||
|
||||
|
||||
@@ -26,6 +27,12 @@ class SubdomainSnapshot(models.Model):
|
||||
models.Index(fields=['scan']),
|
||||
models.Index(fields=['name']),
|
||||
models.Index(fields=['-created_at']),
|
||||
# pg_trgm GIN 索引,支持 LIKE '%keyword%' 模糊搜索
|
||||
GinIndex(
|
||||
name='subdomain_snap_name_trgm',
|
||||
fields=['name'],
|
||||
opclasses=['gin_trgm_ops']
|
||||
),
|
||||
]
|
||||
constraints = [
|
||||
# 唯一约束:同一次扫描中,同一个子域名只能记录一次
|
||||
@@ -54,22 +61,27 @@ class WebsiteSnapshot(models.Model):
|
||||
)
|
||||
|
||||
# 扫描结果数据
|
||||
url = models.CharField(max_length=2000, help_text='站点URL')
|
||||
url = models.TextField(help_text='站点URL')
|
||||
host = models.CharField(max_length=253, blank=True, default='', help_text='主机名(域名或IP地址)')
|
||||
title = models.CharField(max_length=500, blank=True, default='', help_text='页面标题')
|
||||
status = models.IntegerField(null=True, blank=True, help_text='HTTP状态码')
|
||||
title = models.TextField(blank=True, default='', help_text='页面标题')
|
||||
status_code = models.IntegerField(null=True, blank=True, help_text='HTTP状态码')
|
||||
content_length = models.BigIntegerField(null=True, blank=True, help_text='内容长度')
|
||||
location = models.CharField(max_length=1000, blank=True, default='', help_text='重定向位置')
|
||||
web_server = models.CharField(max_length=200, blank=True, default='', help_text='Web服务器')
|
||||
content_type = models.CharField(max_length=200, blank=True, default='', help_text='内容类型')
|
||||
location = models.TextField(blank=True, default='', help_text='重定向位置')
|
||||
webserver = models.TextField(blank=True, default='', help_text='Web服务器')
|
||||
content_type = models.TextField(blank=True, default='', help_text='内容类型')
|
||||
tech = ArrayField(
|
||||
models.CharField(max_length=100),
|
||||
blank=True,
|
||||
default=list,
|
||||
help_text='技术栈'
|
||||
)
|
||||
body_preview = models.TextField(blank=True, default='', help_text='响应体预览')
|
||||
response_body = models.TextField(blank=True, default='', help_text='HTTP响应体')
|
||||
vhost = models.BooleanField(null=True, blank=True, help_text='虚拟主机标志')
|
||||
response_headers = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='原始HTTP响应头'
|
||||
)
|
||||
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
|
||||
|
||||
class Meta:
|
||||
@@ -83,6 +95,23 @@ class WebsiteSnapshot(models.Model):
|
||||
models.Index(fields=['host']), # host索引,优化根据主机名查询
|
||||
models.Index(fields=['title']), # title索引,优化标题搜索
|
||||
models.Index(fields=['-created_at']),
|
||||
GinIndex(fields=['tech']), # GIN索引,优化数组字段查询
|
||||
# pg_trgm GIN 索引,支持 LIKE '%keyword%' 模糊搜索
|
||||
GinIndex(
|
||||
name='ws_snap_resp_hdr_trgm',
|
||||
fields=['response_headers'],
|
||||
opclasses=['gin_trgm_ops']
|
||||
),
|
||||
GinIndex(
|
||||
name='ws_snap_url_trgm',
|
||||
fields=['url'],
|
||||
opclasses=['gin_trgm_ops']
|
||||
),
|
||||
GinIndex(
|
||||
name='ws_snap_title_trgm',
|
||||
fields=['title'],
|
||||
opclasses=['gin_trgm_ops']
|
||||
),
|
||||
]
|
||||
constraints = [
|
||||
# 唯一约束:同一次扫描中,同一个URL只能记录一次
|
||||
@@ -132,6 +161,12 @@ class DirectorySnapshot(models.Model):
|
||||
models.Index(fields=['status']), # 状态码索引,优化筛选
|
||||
models.Index(fields=['content_type']), # content_type索引,优化内容类型搜索
|
||||
models.Index(fields=['-created_at']),
|
||||
# pg_trgm GIN 索引,支持 LIKE '%keyword%' 模糊搜索
|
||||
GinIndex(
|
||||
name='dir_snap_url_trgm',
|
||||
fields=['url'],
|
||||
opclasses=['gin_trgm_ops']
|
||||
),
|
||||
]
|
||||
constraints = [
|
||||
# 唯一约束:同一次扫描中,同一个目录URL只能记录一次
|
||||
@@ -232,26 +267,26 @@ class EndpointSnapshot(models.Model):
|
||||
)
|
||||
|
||||
# 扫描结果数据
|
||||
url = models.CharField(max_length=2000, help_text='端点URL')
|
||||
url = models.TextField(help_text='端点URL')
|
||||
host = models.CharField(
|
||||
max_length=253,
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='主机名(域名或IP地址)'
|
||||
)
|
||||
title = models.CharField(max_length=1000, blank=True, default='', help_text='页面标题')
|
||||
title = models.TextField(blank=True, default='', help_text='页面标题')
|
||||
status_code = models.IntegerField(null=True, blank=True, help_text='HTTP状态码')
|
||||
content_length = models.IntegerField(null=True, blank=True, help_text='内容长度')
|
||||
location = models.CharField(max_length=1000, blank=True, default='', help_text='重定向位置')
|
||||
webserver = models.CharField(max_length=200, blank=True, default='', help_text='Web服务器')
|
||||
content_type = models.CharField(max_length=200, blank=True, default='', help_text='内容类型')
|
||||
location = models.TextField(blank=True, default='', help_text='重定向位置')
|
||||
webserver = models.TextField(blank=True, default='', help_text='Web服务器')
|
||||
content_type = models.TextField(blank=True, default='', help_text='内容类型')
|
||||
tech = ArrayField(
|
||||
models.CharField(max_length=100),
|
||||
blank=True,
|
||||
default=list,
|
||||
help_text='技术栈'
|
||||
)
|
||||
body_preview = models.CharField(max_length=1000, blank=True, default='', help_text='响应体预览')
|
||||
response_body = models.TextField(blank=True, default='', help_text='HTTP响应体')
|
||||
vhost = models.BooleanField(null=True, blank=True, help_text='虚拟主机标志')
|
||||
matched_gf_patterns = ArrayField(
|
||||
models.CharField(max_length=100),
|
||||
@@ -259,6 +294,11 @@ class EndpointSnapshot(models.Model):
|
||||
default=list,
|
||||
help_text='匹配的GF模式列表'
|
||||
)
|
||||
response_headers = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='原始HTTP响应头'
|
||||
)
|
||||
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
|
||||
|
||||
class Meta:
|
||||
@@ -274,6 +314,23 @@ class EndpointSnapshot(models.Model):
|
||||
models.Index(fields=['status_code']), # 状态码索引,优化筛选
|
||||
models.Index(fields=['webserver']), # webserver索引,优化服务器搜索
|
||||
models.Index(fields=['-created_at']),
|
||||
GinIndex(fields=['tech']), # GIN索引,优化数组字段查询
|
||||
# pg_trgm GIN 索引,支持 LIKE '%keyword%' 模糊搜索
|
||||
GinIndex(
|
||||
name='ep_snap_resp_hdr_trgm',
|
||||
fields=['response_headers'],
|
||||
opclasses=['gin_trgm_ops']
|
||||
),
|
||||
GinIndex(
|
||||
name='ep_snap_url_trgm',
|
||||
fields=['url'],
|
||||
opclasses=['gin_trgm_ops']
|
||||
),
|
||||
GinIndex(
|
||||
name='ep_snap_title_trgm',
|
||||
fields=['title'],
|
||||
opclasses=['gin_trgm_ops']
|
||||
),
|
||||
]
|
||||
constraints = [
|
||||
# 唯一约束:同一次扫描中,同一个URL只能记录一次
|
||||
|
||||
@@ -48,12 +48,13 @@ class DjangoEndpointRepository:
|
||||
status_code=item.status_code,
|
||||
content_length=item.content_length,
|
||||
webserver=item.webserver or '',
|
||||
body_preview=item.body_preview or '',
|
||||
response_body=item.response_body or '',
|
||||
content_type=item.content_type or '',
|
||||
tech=item.tech if item.tech else [],
|
||||
vhost=item.vhost,
|
||||
location=item.location or '',
|
||||
matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else []
|
||||
matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else [],
|
||||
response_headers=item.response_headers if item.response_headers else ''
|
||||
)
|
||||
for item in unique_items
|
||||
]
|
||||
@@ -65,8 +66,8 @@ class DjangoEndpointRepository:
|
||||
unique_fields=['url', 'target'],
|
||||
update_fields=[
|
||||
'host', 'title', 'status_code', 'content_length',
|
||||
'webserver', 'body_preview', 'content_type', 'tech',
|
||||
'vhost', 'location', 'matched_gf_patterns'
|
||||
'webserver', 'response_body', 'content_type', 'tech',
|
||||
'vhost', 'location', 'matched_gf_patterns', 'response_headers'
|
||||
],
|
||||
batch_size=1000
|
||||
)
|
||||
@@ -138,12 +139,13 @@ class DjangoEndpointRepository:
|
||||
status_code=item.status_code,
|
||||
content_length=item.content_length,
|
||||
webserver=item.webserver or '',
|
||||
body_preview=item.body_preview or '',
|
||||
response_body=item.response_body or '',
|
||||
content_type=item.content_type or '',
|
||||
tech=item.tech if item.tech else [],
|
||||
vhost=item.vhost,
|
||||
location=item.location or '',
|
||||
matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else []
|
||||
matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else [],
|
||||
response_headers=item.response_headers if item.response_headers else ''
|
||||
)
|
||||
for item in unique_items
|
||||
]
|
||||
@@ -183,7 +185,7 @@ class DjangoEndpointRepository:
|
||||
.values(
|
||||
'url', 'host', 'location', 'title', 'status_code',
|
||||
'content_length', 'content_type', 'webserver', 'tech',
|
||||
'body_preview', 'vhost', 'matched_gf_patterns', 'created_at'
|
||||
'response_body', 'response_headers', 'vhost', 'matched_gf_patterns', 'created_at'
|
||||
)
|
||||
.order_by('url')
|
||||
)
|
||||
|
||||
@@ -49,12 +49,13 @@ class DjangoWebSiteRepository:
|
||||
location=item.location or '',
|
||||
title=item.title or '',
|
||||
webserver=item.webserver or '',
|
||||
body_preview=item.body_preview or '',
|
||||
response_body=item.response_body or '',
|
||||
content_type=item.content_type or '',
|
||||
tech=item.tech if item.tech else [],
|
||||
status_code=item.status_code,
|
||||
content_length=item.content_length,
|
||||
vhost=item.vhost
|
||||
vhost=item.vhost,
|
||||
response_headers=item.response_headers if item.response_headers else ''
|
||||
)
|
||||
for item in unique_items
|
||||
]
|
||||
@@ -66,8 +67,8 @@ class DjangoWebSiteRepository:
|
||||
unique_fields=['url', 'target'],
|
||||
update_fields=[
|
||||
'host', 'location', 'title', 'webserver',
|
||||
'body_preview', 'content_type', 'tech',
|
||||
'status_code', 'content_length', 'vhost'
|
||||
'response_body', 'content_type', 'tech',
|
||||
'status_code', 'content_length', 'vhost', 'response_headers'
|
||||
],
|
||||
batch_size=1000
|
||||
)
|
||||
@@ -132,12 +133,13 @@ class DjangoWebSiteRepository:
|
||||
location=item.location or '',
|
||||
title=item.title or '',
|
||||
webserver=item.webserver or '',
|
||||
body_preview=item.body_preview or '',
|
||||
response_body=item.response_body or '',
|
||||
content_type=item.content_type or '',
|
||||
tech=item.tech if item.tech else [],
|
||||
status_code=item.status_code,
|
||||
content_length=item.content_length,
|
||||
vhost=item.vhost
|
||||
vhost=item.vhost,
|
||||
response_headers=item.response_headers if item.response_headers else ''
|
||||
)
|
||||
for item in unique_items
|
||||
]
|
||||
@@ -177,7 +179,7 @@ class DjangoWebSiteRepository:
|
||||
.values(
|
||||
'url', 'host', 'location', 'title', 'status_code',
|
||||
'content_length', 'content_type', 'webserver', 'tech',
|
||||
'body_preview', 'vhost', 'created_at'
|
||||
'response_body', 'response_headers', 'vhost', 'created_at'
|
||||
)
|
||||
.order_by('url')
|
||||
)
|
||||
|
||||
@@ -44,6 +44,7 @@ class DjangoEndpointSnapshotRepository:
|
||||
snapshots.append(EndpointSnapshot(
|
||||
scan_id=item.scan_id,
|
||||
url=item.url,
|
||||
host=item.host if item.host else '',
|
||||
title=item.title,
|
||||
status_code=item.status_code,
|
||||
content_length=item.content_length,
|
||||
@@ -51,9 +52,10 @@ class DjangoEndpointSnapshotRepository:
|
||||
webserver=item.webserver,
|
||||
content_type=item.content_type,
|
||||
tech=item.tech if item.tech else [],
|
||||
body_preview=item.body_preview,
|
||||
response_body=item.response_body,
|
||||
vhost=item.vhost,
|
||||
matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else []
|
||||
matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else [],
|
||||
response_headers=item.response_headers if item.response_headers else ''
|
||||
))
|
||||
|
||||
# 批量创建(忽略冲突,基于唯一约束去重)
|
||||
@@ -100,7 +102,7 @@ class DjangoEndpointSnapshotRepository:
|
||||
.values(
|
||||
'url', 'host', 'location', 'title', 'status_code',
|
||||
'content_length', 'content_type', 'webserver', 'tech',
|
||||
'body_preview', 'vhost', 'matched_gf_patterns', 'created_at'
|
||||
'response_body', 'response_headers', 'vhost', 'matched_gf_patterns', 'created_at'
|
||||
)
|
||||
.order_by('url')
|
||||
)
|
||||
|
||||
@@ -46,14 +46,15 @@ class DjangoWebsiteSnapshotRepository:
|
||||
url=item.url,
|
||||
host=item.host,
|
||||
title=item.title,
|
||||
status=item.status,
|
||||
status_code=item.status_code,
|
||||
content_length=item.content_length,
|
||||
location=item.location,
|
||||
web_server=item.web_server,
|
||||
webserver=item.webserver,
|
||||
content_type=item.content_type,
|
||||
tech=item.tech if item.tech else [],
|
||||
body_preview=item.body_preview,
|
||||
vhost=item.vhost
|
||||
response_body=item.response_body,
|
||||
vhost=item.vhost,
|
||||
response_headers=item.response_headers if item.response_headers else ''
|
||||
))
|
||||
|
||||
# 批量创建(忽略冲突,基于唯一约束去重)
|
||||
@@ -98,26 +99,12 @@ class DjangoWebsiteSnapshotRepository:
|
||||
WebsiteSnapshot.objects
|
||||
.filter(scan_id=scan_id)
|
||||
.values(
|
||||
'url', 'host', 'location', 'title', 'status',
|
||||
'content_length', 'content_type', 'web_server', 'tech',
|
||||
'body_preview', 'vhost', 'created_at'
|
||||
'url', 'host', 'location', 'title', 'status_code',
|
||||
'content_length', 'content_type', 'webserver', 'tech',
|
||||
'response_body', 'response_headers', 'vhost', 'created_at'
|
||||
)
|
||||
.order_by('url')
|
||||
)
|
||||
|
||||
for row in qs.iterator(chunk_size=batch_size):
|
||||
# 重命名字段以匹配 CSV 表头
|
||||
yield {
|
||||
'url': row['url'],
|
||||
'host': row['host'],
|
||||
'location': row['location'],
|
||||
'title': row['title'],
|
||||
'status_code': row['status'],
|
||||
'content_length': row['content_length'],
|
||||
'content_type': row['content_type'],
|
||||
'webserver': row['web_server'],
|
||||
'tech': row['tech'],
|
||||
'body_preview': row['body_preview'],
|
||||
'vhost': row['vhost'],
|
||||
'created_at': row['created_at'],
|
||||
}
|
||||
yield row
|
||||
|
||||
@@ -67,9 +67,10 @@ class SubdomainListSerializer(serializers.ModelSerializer):
|
||||
|
||||
|
||||
class WebSiteSerializer(serializers.ModelSerializer):
|
||||
"""站点序列化器"""
|
||||
"""站点序列化器(目标详情页)"""
|
||||
|
||||
subdomain = serializers.CharField(source='subdomain.name', allow_blank=True, default='')
|
||||
responseHeaders = serializers.CharField(source='response_headers', read_only=True) # 原始HTTP响应头
|
||||
|
||||
class Meta:
|
||||
model = WebSite
|
||||
@@ -83,9 +84,10 @@ class WebSiteSerializer(serializers.ModelSerializer):
|
||||
'content_type',
|
||||
'status_code',
|
||||
'content_length',
|
||||
'body_preview',
|
||||
'response_body',
|
||||
'tech',
|
||||
'vhost',
|
||||
'responseHeaders', # HTTP响应头
|
||||
'subdomain',
|
||||
'created_at',
|
||||
]
|
||||
@@ -140,6 +142,7 @@ class EndpointListSerializer(serializers.ModelSerializer):
|
||||
source='matched_gf_patterns',
|
||||
read_only=True,
|
||||
)
|
||||
responseHeaders = serializers.CharField(source='response_headers', read_only=True) # 原始HTTP响应头
|
||||
|
||||
class Meta:
|
||||
model = Endpoint
|
||||
@@ -152,9 +155,10 @@ class EndpointListSerializer(serializers.ModelSerializer):
|
||||
'content_length',
|
||||
'content_type',
|
||||
'webserver',
|
||||
'body_preview',
|
||||
'response_body',
|
||||
'tech',
|
||||
'vhost',
|
||||
'responseHeaders', # HTTP响应头
|
||||
'gfPatterns',
|
||||
'created_at',
|
||||
]
|
||||
@@ -213,8 +217,7 @@ class WebsiteSnapshotSerializer(serializers.ModelSerializer):
|
||||
"""网站快照序列化器(用于扫描历史)"""
|
||||
|
||||
subdomain_name = serializers.CharField(source='subdomain.name', read_only=True)
|
||||
webserver = serializers.CharField(source='web_server', read_only=True) # 映射字段名
|
||||
status_code = serializers.IntegerField(source='status', read_only=True) # 映射字段名
|
||||
responseHeaders = serializers.CharField(source='response_headers', read_only=True) # 原始HTTP响应头
|
||||
|
||||
class Meta:
|
||||
model = WebsiteSnapshot
|
||||
@@ -223,13 +226,14 @@ class WebsiteSnapshotSerializer(serializers.ModelSerializer):
|
||||
'url',
|
||||
'location',
|
||||
'title',
|
||||
'webserver', # 使用映射后的字段名
|
||||
'webserver',
|
||||
'content_type',
|
||||
'status_code', # 使用映射后的字段名
|
||||
'status_code',
|
||||
'content_length',
|
||||
'body_preview',
|
||||
'response_body',
|
||||
'tech',
|
||||
'vhost',
|
||||
'responseHeaders', # HTTP响应头
|
||||
'subdomain_name',
|
||||
'created_at',
|
||||
]
|
||||
@@ -264,6 +268,7 @@ class EndpointSnapshotSerializer(serializers.ModelSerializer):
|
||||
source='matched_gf_patterns',
|
||||
read_only=True,
|
||||
)
|
||||
responseHeaders = serializers.CharField(source='response_headers', read_only=True) # 原始HTTP响应头
|
||||
|
||||
class Meta:
|
||||
model = EndpointSnapshot
|
||||
@@ -277,9 +282,10 @@ class EndpointSnapshotSerializer(serializers.ModelSerializer):
|
||||
'content_type',
|
||||
'status_code',
|
||||
'content_length',
|
||||
'body_preview',
|
||||
'response_body',
|
||||
'tech',
|
||||
'vhost',
|
||||
'responseHeaders', # HTTP响应头
|
||||
'gfPatterns',
|
||||
'created_at',
|
||||
]
|
||||
|
||||
@@ -27,7 +27,8 @@ class EndpointService:
|
||||
'url': 'url',
|
||||
'host': 'host',
|
||||
'title': 'title',
|
||||
'status': 'status_code',
|
||||
'status_code': 'status_code',
|
||||
'tech': 'tech',
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
@@ -115,7 +116,7 @@ class EndpointService:
|
||||
"""获取目标下的所有端点"""
|
||||
queryset = self.repo.get_by_target(target_id)
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING, json_array_fields=['tech'])
|
||||
return queryset
|
||||
|
||||
def count_endpoints_by_target(self, target_id: int) -> int:
|
||||
@@ -134,7 +135,7 @@ class EndpointService:
|
||||
"""获取所有端点(全局查询)"""
|
||||
queryset = self.repo.get_all()
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING, json_array_fields=['tech'])
|
||||
return queryset
|
||||
|
||||
def iter_endpoint_urls_by_target(self, target_id: int, chunk_size: int = 1000) -> Iterator[str]:
|
||||
|
||||
@@ -19,7 +19,8 @@ class WebSiteService:
|
||||
'url': 'url',
|
||||
'host': 'host',
|
||||
'title': 'title',
|
||||
'status': 'status_code',
|
||||
'status_code': 'status_code',
|
||||
'tech': 'tech',
|
||||
}
|
||||
|
||||
def __init__(self, repository=None):
|
||||
@@ -107,14 +108,14 @@ class WebSiteService:
|
||||
"""获取目标下的所有网站"""
|
||||
queryset = self.repo.get_by_target(target_id)
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING, json_array_fields=['tech'])
|
||||
return queryset
|
||||
|
||||
def get_all(self, filter_query: Optional[str] = None):
|
||||
"""获取所有网站"""
|
||||
queryset = self.repo.get_all()
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING, json_array_fields=['tech'])
|
||||
return queryset
|
||||
|
||||
def get_by_url(self, url: str, target_id: int) -> int:
|
||||
|
||||
477
backend/apps/asset/services/search_service.py
Normal file
477
backend/apps/asset/services/search_service.py
Normal file
@@ -0,0 +1,477 @@
|
||||
"""
|
||||
资产搜索服务
|
||||
|
||||
提供资产搜索的核心业务逻辑:
|
||||
- 从物化视图查询数据
|
||||
- 支持表达式语法解析
|
||||
- 支持 =(模糊)、==(精确)、!=(不等于)操作符
|
||||
- 支持 && (AND) 和 || (OR) 逻辑组合
|
||||
- 支持 Website 和 Endpoint 两种资产类型
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional, List, Dict, Any, Tuple, Literal, Iterator
|
||||
|
||||
from django.db import connection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 支持的字段映射(前端字段名 -> 数据库字段名)
|
||||
FIELD_MAPPING = {
|
||||
'host': 'host',
|
||||
'url': 'url',
|
||||
'title': 'title',
|
||||
'tech': 'tech',
|
||||
'status': 'status_code',
|
||||
'body': 'response_body',
|
||||
'header': 'response_headers',
|
||||
}
|
||||
|
||||
# 数组类型字段
|
||||
ARRAY_FIELDS = {'tech'}
|
||||
|
||||
# 资产类型到视图名的映射
|
||||
VIEW_MAPPING = {
|
||||
'website': 'asset_search_view',
|
||||
'endpoint': 'endpoint_search_view',
|
||||
}
|
||||
|
||||
# 资产类型到原表名的映射(用于 JOIN 获取数组字段)
|
||||
# ⚠️ 重要:pg_ivm 不支持 ArrayField,所有数组字段必须从原表 JOIN 获取
|
||||
TABLE_MAPPING = {
|
||||
'website': 'website',
|
||||
'endpoint': 'endpoint',
|
||||
}
|
||||
|
||||
# 有效的资产类型
|
||||
VALID_ASSET_TYPES = {'website', 'endpoint'}
|
||||
|
||||
# Website 查询字段(v=视图,t=原表)
|
||||
# ⚠️ 注意:t.tech 从原表获取,因为 pg_ivm 不支持 ArrayField
|
||||
WEBSITE_SELECT_FIELDS = """
|
||||
v.id,
|
||||
v.url,
|
||||
v.host,
|
||||
v.title,
|
||||
t.tech, -- ArrayField,从 website 表 JOIN 获取
|
||||
v.status_code,
|
||||
v.response_headers,
|
||||
v.response_body,
|
||||
v.content_type,
|
||||
v.content_length,
|
||||
v.webserver,
|
||||
v.location,
|
||||
v.vhost,
|
||||
v.created_at,
|
||||
v.target_id
|
||||
"""
|
||||
|
||||
# Endpoint 查询字段
|
||||
# ⚠️ 注意:t.tech 和 t.matched_gf_patterns 从原表获取,因为 pg_ivm 不支持 ArrayField
|
||||
ENDPOINT_SELECT_FIELDS = """
|
||||
v.id,
|
||||
v.url,
|
||||
v.host,
|
||||
v.title,
|
||||
t.tech, -- ArrayField,从 endpoint 表 JOIN 获取
|
||||
v.status_code,
|
||||
v.response_headers,
|
||||
v.response_body,
|
||||
v.content_type,
|
||||
v.content_length,
|
||||
v.webserver,
|
||||
v.location,
|
||||
v.vhost,
|
||||
t.matched_gf_patterns, -- ArrayField,从 endpoint 表 JOIN 获取
|
||||
v.created_at,
|
||||
v.target_id
|
||||
"""
|
||||
|
||||
|
||||
class SearchQueryParser:
|
||||
"""
|
||||
搜索查询解析器
|
||||
|
||||
支持语法:
|
||||
- field="value" 模糊匹配(ILIKE %value%)
|
||||
- field=="value" 精确匹配
|
||||
- field!="value" 不等于
|
||||
- && AND 连接
|
||||
- || OR 连接
|
||||
- () 分组(暂不支持嵌套)
|
||||
|
||||
示例:
|
||||
- host="api" && tech="nginx"
|
||||
- tech="vue" || tech="react"
|
||||
- status=="200" && host!="test"
|
||||
"""
|
||||
|
||||
# 匹配单个条件: field="value" 或 field=="value" 或 field!="value"
|
||||
CONDITION_PATTERN = re.compile(r'(\w+)\s*(==|!=|=)\s*"([^"]*)"')
|
||||
|
||||
@classmethod
|
||||
def parse(cls, query: str) -> Tuple[str, List[Any]]:
|
||||
"""
|
||||
解析查询字符串,返回 SQL WHERE 子句和参数
|
||||
|
||||
Args:
|
||||
query: 搜索查询字符串
|
||||
|
||||
Returns:
|
||||
(where_clause, params) 元组
|
||||
"""
|
||||
if not query or not query.strip():
|
||||
return "1=1", []
|
||||
|
||||
query = query.strip()
|
||||
|
||||
# 检查是否包含操作符语法,如果不包含则作为 host 模糊搜索
|
||||
if not cls.CONDITION_PATTERN.search(query):
|
||||
# 裸文本,默认作为 host 模糊搜索(v 是视图别名)
|
||||
return "v.host ILIKE %s", [f"%{query}%"]
|
||||
|
||||
# 按 || 分割为 OR 组
|
||||
or_groups = cls._split_by_or(query)
|
||||
|
||||
if len(or_groups) == 1:
|
||||
# 没有 OR,直接解析 AND 条件
|
||||
return cls._parse_and_group(or_groups[0])
|
||||
|
||||
# 多个 OR 组
|
||||
or_clauses = []
|
||||
all_params = []
|
||||
|
||||
for group in or_groups:
|
||||
clause, params = cls._parse_and_group(group)
|
||||
if clause and clause != "1=1":
|
||||
or_clauses.append(f"({clause})")
|
||||
all_params.extend(params)
|
||||
|
||||
if not or_clauses:
|
||||
return "1=1", []
|
||||
|
||||
return " OR ".join(or_clauses), all_params
|
||||
|
||||
@classmethod
|
||||
def _split_by_or(cls, query: str) -> List[str]:
|
||||
"""按 || 分割查询,但忽略引号内的 ||"""
|
||||
parts = []
|
||||
current = ""
|
||||
in_quotes = False
|
||||
i = 0
|
||||
|
||||
while i < len(query):
|
||||
char = query[i]
|
||||
|
||||
if char == '"':
|
||||
in_quotes = not in_quotes
|
||||
current += char
|
||||
elif not in_quotes and i + 1 < len(query) and query[i:i+2] == '||':
|
||||
if current.strip():
|
||||
parts.append(current.strip())
|
||||
current = ""
|
||||
i += 1 # 跳过第二个 |
|
||||
else:
|
||||
current += char
|
||||
|
||||
i += 1
|
||||
|
||||
if current.strip():
|
||||
parts.append(current.strip())
|
||||
|
||||
return parts if parts else [query]
|
||||
|
||||
@classmethod
|
||||
def _parse_and_group(cls, group: str) -> Tuple[str, List[Any]]:
|
||||
"""解析 AND 组(用 && 连接的条件)"""
|
||||
# 移除外层括号
|
||||
group = group.strip()
|
||||
if group.startswith('(') and group.endswith(')'):
|
||||
group = group[1:-1].strip()
|
||||
|
||||
# 按 && 分割
|
||||
parts = cls._split_by_and(group)
|
||||
|
||||
and_clauses = []
|
||||
all_params = []
|
||||
|
||||
for part in parts:
|
||||
clause, params = cls._parse_condition(part.strip())
|
||||
if clause:
|
||||
and_clauses.append(clause)
|
||||
all_params.extend(params)
|
||||
|
||||
if not and_clauses:
|
||||
return "1=1", []
|
||||
|
||||
return " AND ".join(and_clauses), all_params
|
||||
|
||||
@classmethod
|
||||
def _split_by_and(cls, query: str) -> List[str]:
|
||||
"""按 && 分割查询,但忽略引号内的 &&"""
|
||||
parts = []
|
||||
current = ""
|
||||
in_quotes = False
|
||||
i = 0
|
||||
|
||||
while i < len(query):
|
||||
char = query[i]
|
||||
|
||||
if char == '"':
|
||||
in_quotes = not in_quotes
|
||||
current += char
|
||||
elif not in_quotes and i + 1 < len(query) and query[i:i+2] == '&&':
|
||||
if current.strip():
|
||||
parts.append(current.strip())
|
||||
current = ""
|
||||
i += 1 # 跳过第二个 &
|
||||
else:
|
||||
current += char
|
||||
|
||||
i += 1
|
||||
|
||||
if current.strip():
|
||||
parts.append(current.strip())
|
||||
|
||||
return parts if parts else [query]
|
||||
|
||||
@classmethod
|
||||
def _parse_condition(cls, condition: str) -> Tuple[Optional[str], List[Any]]:
|
||||
"""
|
||||
解析单个条件
|
||||
|
||||
Returns:
|
||||
(sql_clause, params) 或 (None, []) 如果解析失败
|
||||
"""
|
||||
# 移除括号
|
||||
condition = condition.strip()
|
||||
if condition.startswith('(') and condition.endswith(')'):
|
||||
condition = condition[1:-1].strip()
|
||||
|
||||
match = cls.CONDITION_PATTERN.match(condition)
|
||||
if not match:
|
||||
logger.warning(f"无法解析条件: {condition}")
|
||||
return None, []
|
||||
|
||||
field, operator, value = match.groups()
|
||||
field = field.lower()
|
||||
|
||||
# 验证字段
|
||||
if field not in FIELD_MAPPING:
|
||||
logger.warning(f"未知字段: {field}")
|
||||
return None, []
|
||||
|
||||
db_field = FIELD_MAPPING[field]
|
||||
is_array = field in ARRAY_FIELDS
|
||||
|
||||
# 根据操作符生成 SQL
|
||||
if operator == '=':
|
||||
# 模糊匹配
|
||||
return cls._build_like_condition(db_field, value, is_array)
|
||||
elif operator == '==':
|
||||
# 精确匹配
|
||||
return cls._build_exact_condition(db_field, value, is_array)
|
||||
elif operator == '!=':
|
||||
# 不等于
|
||||
return cls._build_not_equal_condition(db_field, value, is_array)
|
||||
|
||||
return None, []
|
||||
|
||||
@classmethod
|
||||
def _build_like_condition(cls, field: str, value: str, is_array: bool) -> Tuple[str, List[Any]]:
|
||||
"""构建模糊匹配条件"""
|
||||
if is_array:
|
||||
# 数组字段:检查数组中是否有元素包含该值(从原表 t 获取)
|
||||
return f"EXISTS (SELECT 1 FROM unnest(t.{field}) AS elem WHERE elem ILIKE %s)", [f"%{value}%"]
|
||||
elif field == 'status_code':
|
||||
# 状态码是整数,模糊匹配转为精确匹配
|
||||
try:
|
||||
return f"v.{field} = %s", [int(value)]
|
||||
except ValueError:
|
||||
return f"v.{field}::text ILIKE %s", [f"%{value}%"]
|
||||
else:
|
||||
return f"v.{field} ILIKE %s", [f"%{value}%"]
|
||||
|
||||
@classmethod
|
||||
def _build_exact_condition(cls, field: str, value: str, is_array: bool) -> Tuple[str, List[Any]]:
|
||||
"""构建精确匹配条件"""
|
||||
if is_array:
|
||||
# 数组字段:检查数组中是否包含该精确值(从原表 t 获取)
|
||||
return f"%s = ANY(t.{field})", [value]
|
||||
elif field == 'status_code':
|
||||
# 状态码是整数
|
||||
try:
|
||||
return f"v.{field} = %s", [int(value)]
|
||||
except ValueError:
|
||||
return f"v.{field}::text = %s", [value]
|
||||
else:
|
||||
return f"v.{field} = %s", [value]
|
||||
|
||||
@classmethod
|
||||
def _build_not_equal_condition(cls, field: str, value: str, is_array: bool) -> Tuple[str, List[Any]]:
|
||||
"""构建不等于条件"""
|
||||
if is_array:
|
||||
# 数组字段:检查数组中不包含该值(从原表 t 获取)
|
||||
return f"NOT (%s = ANY(t.{field}))", [value]
|
||||
elif field == 'status_code':
|
||||
try:
|
||||
return f"(v.{field} IS NULL OR v.{field} != %s)", [int(value)]
|
||||
except ValueError:
|
||||
return f"(v.{field} IS NULL OR v.{field}::text != %s)", [value]
|
||||
else:
|
||||
return f"(v.{field} IS NULL OR v.{field} != %s)", [value]
|
||||
|
||||
|
||||
AssetType = Literal['website', 'endpoint']
|
||||
|
||||
|
||||
class AssetSearchService:
|
||||
"""资产搜索服务"""
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
asset_type: AssetType = 'website',
|
||||
limit: Optional[int] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
搜索资产
|
||||
|
||||
Args:
|
||||
query: 搜索查询字符串
|
||||
asset_type: 资产类型 ('website' 或 'endpoint')
|
||||
limit: 最大返回数量(可选)
|
||||
|
||||
Returns:
|
||||
List[Dict]: 搜索结果列表
|
||||
"""
|
||||
where_clause, params = SearchQueryParser.parse(query)
|
||||
|
||||
# 根据资产类型选择视图、原表和字段
|
||||
view_name = VIEW_MAPPING.get(asset_type, 'asset_search_view')
|
||||
table_name = TABLE_MAPPING.get(asset_type, 'website')
|
||||
select_fields = ENDPOINT_SELECT_FIELDS if asset_type == 'endpoint' else WEBSITE_SELECT_FIELDS
|
||||
|
||||
# JOIN 原表获取数组字段(tech, matched_gf_patterns)
|
||||
sql = f"""
|
||||
SELECT {select_fields}
|
||||
FROM {view_name} v
|
||||
JOIN {table_name} t ON v.id = t.id
|
||||
WHERE {where_clause}
|
||||
ORDER BY v.created_at DESC
|
||||
"""
|
||||
|
||||
# 添加 LIMIT
|
||||
if limit is not None and limit > 0:
|
||||
sql += f" LIMIT {int(limit)}"
|
||||
|
||||
try:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(sql, params)
|
||||
columns = [col[0] for col in cursor.description]
|
||||
results = []
|
||||
|
||||
for row in cursor.fetchall():
|
||||
result = dict(zip(columns, row))
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"搜索查询失败: {e}, SQL: {sql}, params: {params}")
|
||||
raise
|
||||
|
||||
def count(self, query: str, asset_type: AssetType = 'website', statement_timeout_ms: int = 300000) -> int:
|
||||
"""
|
||||
统计搜索结果数量
|
||||
|
||||
Args:
|
||||
query: 搜索查询字符串
|
||||
asset_type: 资产类型 ('website' 或 'endpoint')
|
||||
statement_timeout_ms: SQL 语句超时时间(毫秒),默认 5 分钟
|
||||
|
||||
Returns:
|
||||
int: 结果总数
|
||||
"""
|
||||
where_clause, params = SearchQueryParser.parse(query)
|
||||
|
||||
# 根据资产类型选择视图和原表
|
||||
view_name = VIEW_MAPPING.get(asset_type, 'asset_search_view')
|
||||
table_name = TABLE_MAPPING.get(asset_type, 'website')
|
||||
|
||||
# JOIN 原表以支持数组字段查询
|
||||
sql = f"SELECT COUNT(*) FROM {view_name} v JOIN {table_name} t ON v.id = t.id WHERE {where_clause}"
|
||||
|
||||
try:
|
||||
with connection.cursor() as cursor:
|
||||
# 为导出设置更长的超时时间(仅影响当前会话)
|
||||
cursor.execute(f"SET LOCAL statement_timeout = {statement_timeout_ms}")
|
||||
cursor.execute(sql, params)
|
||||
return cursor.fetchone()[0]
|
||||
except Exception as e:
|
||||
logger.error(f"统计查询失败: {e}")
|
||||
raise
|
||||
|
||||
def search_iter(
|
||||
self,
|
||||
query: str,
|
||||
asset_type: AssetType = 'website',
|
||||
batch_size: int = 1000,
|
||||
statement_timeout_ms: int = 300000
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
"""
|
||||
流式搜索资产(使用分批查询,内存友好)
|
||||
|
||||
Args:
|
||||
query: 搜索查询字符串
|
||||
asset_type: 资产类型 ('website' 或 'endpoint')
|
||||
batch_size: 每批获取的数量
|
||||
statement_timeout_ms: SQL 语句超时时间(毫秒),默认 5 分钟
|
||||
|
||||
Yields:
|
||||
Dict: 单条搜索结果
|
||||
"""
|
||||
where_clause, params = SearchQueryParser.parse(query)
|
||||
|
||||
# 根据资产类型选择视图、原表和字段
|
||||
view_name = VIEW_MAPPING.get(asset_type, 'asset_search_view')
|
||||
table_name = TABLE_MAPPING.get(asset_type, 'website')
|
||||
select_fields = ENDPOINT_SELECT_FIELDS if asset_type == 'endpoint' else WEBSITE_SELECT_FIELDS
|
||||
|
||||
# 使用 OFFSET/LIMIT 分批查询(Django 不支持命名游标)
|
||||
offset = 0
|
||||
|
||||
try:
|
||||
while True:
|
||||
# JOIN 原表获取数组字段
|
||||
sql = f"""
|
||||
SELECT {select_fields}
|
||||
FROM {view_name} v
|
||||
JOIN {table_name} t ON v.id = t.id
|
||||
WHERE {where_clause}
|
||||
ORDER BY v.created_at DESC
|
||||
LIMIT {batch_size} OFFSET {offset}
|
||||
"""
|
||||
|
||||
with connection.cursor() as cursor:
|
||||
# 为导出设置更长的超时时间(仅影响当前会话)
|
||||
cursor.execute(f"SET LOCAL statement_timeout = {statement_timeout_ms}")
|
||||
cursor.execute(sql, params)
|
||||
columns = [col[0] for col in cursor.description]
|
||||
rows = cursor.fetchall()
|
||||
|
||||
if not rows:
|
||||
break
|
||||
|
||||
for row in rows:
|
||||
yield dict(zip(columns, row))
|
||||
|
||||
# 如果返回的行数少于 batch_size,说明已经是最后一批
|
||||
if len(rows) < batch_size:
|
||||
break
|
||||
|
||||
offset += batch_size
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式搜索查询失败: {e}, SQL: {sql}, params: {params}")
|
||||
raise
|
||||
@@ -72,7 +72,7 @@ class EndpointSnapshotsService:
|
||||
'url': 'url',
|
||||
'host': 'host',
|
||||
'title': 'title',
|
||||
'status': 'status_code',
|
||||
'status_code': 'status_code',
|
||||
'webserver': 'webserver',
|
||||
'tech': 'tech',
|
||||
}
|
||||
|
||||
@@ -73,8 +73,8 @@ class WebsiteSnapshotsService:
|
||||
'url': 'url',
|
||||
'host': 'host',
|
||||
'title': 'title',
|
||||
'status': 'status',
|
||||
'webserver': 'web_server',
|
||||
'status_code': 'status_code',
|
||||
'webserver': 'webserver',
|
||||
'tech': 'tech',
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@ from .views import (
|
||||
DirectoryViewSet,
|
||||
VulnerabilityViewSet,
|
||||
AssetStatisticsViewSet,
|
||||
AssetSearchView,
|
||||
AssetSearchExportView,
|
||||
)
|
||||
|
||||
# 创建 DRF 路由器
|
||||
@@ -25,4 +27,6 @@ router.register(r'statistics', AssetStatisticsViewSet, basename='asset-statistic
|
||||
|
||||
urlpatterns = [
|
||||
path('assets/', include(router.urls)),
|
||||
path('assets/search/', AssetSearchView.as_view(), name='asset-search'),
|
||||
path('assets/search/export/', AssetSearchExportView.as_view(), name='asset-search-export'),
|
||||
]
|
||||
|
||||
40
backend/apps/asset/views/__init__.py
Normal file
40
backend/apps/asset/views/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
Asset 应用视图模块
|
||||
|
||||
重新导出所有视图类以保持向后兼容
|
||||
"""
|
||||
|
||||
from .asset_views import (
|
||||
AssetStatisticsViewSet,
|
||||
SubdomainViewSet,
|
||||
WebSiteViewSet,
|
||||
DirectoryViewSet,
|
||||
EndpointViewSet,
|
||||
HostPortMappingViewSet,
|
||||
VulnerabilityViewSet,
|
||||
SubdomainSnapshotViewSet,
|
||||
WebsiteSnapshotViewSet,
|
||||
DirectorySnapshotViewSet,
|
||||
EndpointSnapshotViewSet,
|
||||
HostPortMappingSnapshotViewSet,
|
||||
VulnerabilitySnapshotViewSet,
|
||||
)
|
||||
from .search_views import AssetSearchView, AssetSearchExportView
|
||||
|
||||
__all__ = [
|
||||
'AssetStatisticsViewSet',
|
||||
'SubdomainViewSet',
|
||||
'WebSiteViewSet',
|
||||
'DirectoryViewSet',
|
||||
'EndpointViewSet',
|
||||
'HostPortMappingViewSet',
|
||||
'VulnerabilityViewSet',
|
||||
'SubdomainSnapshotViewSet',
|
||||
'WebsiteSnapshotViewSet',
|
||||
'DirectorySnapshotViewSet',
|
||||
'EndpointSnapshotViewSet',
|
||||
'HostPortMappingSnapshotViewSet',
|
||||
'VulnerabilitySnapshotViewSet',
|
||||
'AssetSearchView',
|
||||
'AssetSearchExportView',
|
||||
]
|
||||
@@ -8,19 +8,18 @@ from rest_framework.request import Request
|
||||
from rest_framework.exceptions import NotFound, ValidationError as DRFValidationError
|
||||
from django.core.exceptions import ValidationError, ObjectDoesNotExist
|
||||
from django.db import DatabaseError, IntegrityError, OperationalError
|
||||
from django.http import StreamingHttpResponse
|
||||
|
||||
from .serializers import (
|
||||
from ..serializers import (
|
||||
SubdomainListSerializer, WebSiteSerializer, DirectorySerializer,
|
||||
VulnerabilitySerializer, EndpointListSerializer, IPAddressAggregatedSerializer,
|
||||
SubdomainSnapshotSerializer, WebsiteSnapshotSerializer, DirectorySnapshotSerializer,
|
||||
EndpointSnapshotSerializer, VulnerabilitySnapshotSerializer
|
||||
)
|
||||
from .services import (
|
||||
from ..services import (
|
||||
SubdomainService, WebSiteService, DirectoryService,
|
||||
VulnerabilityService, AssetStatisticsService, EndpointService, HostPortMappingService
|
||||
)
|
||||
from .services.snapshot import (
|
||||
from ..services.snapshot import (
|
||||
SubdomainSnapshotsService, WebsiteSnapshotsService, DirectorySnapshotsService,
|
||||
EndpointSnapshotsService, HostPortMappingSnapshotsService, VulnerabilitySnapshotsService
|
||||
)
|
||||
@@ -243,7 +242,7 @@ class SubdomainViewSet(viewsets.ModelViewSet):
|
||||
|
||||
CSV 列:name, created_at
|
||||
"""
|
||||
from apps.common.utils import generate_csv_rows, format_datetime
|
||||
from apps.common.utils import create_csv_export_response, format_datetime
|
||||
|
||||
target_pk = self.kwargs.get('target_pk')
|
||||
if not target_pk:
|
||||
@@ -254,12 +253,12 @@ class SubdomainViewSet(viewsets.ModelViewSet):
|
||||
headers = ['name', 'created_at']
|
||||
formatters = {'created_at': format_datetime}
|
||||
|
||||
response = StreamingHttpResponse(
|
||||
generate_csv_rows(data_iterator, headers, formatters),
|
||||
content_type='text/csv; charset=utf-8'
|
||||
return create_csv_export_response(
|
||||
data_iterator=data_iterator,
|
||||
headers=headers,
|
||||
filename=f"target-{target_pk}-subdomains.csv",
|
||||
field_formatters=formatters
|
||||
)
|
||||
response['Content-Disposition'] = f'attachment; filename="target-{target_pk}-subdomains.csv"'
|
||||
return response
|
||||
|
||||
|
||||
class WebSiteViewSet(viewsets.ModelViewSet):
|
||||
@@ -274,6 +273,7 @@ class WebSiteViewSet(viewsets.ModelViewSet):
|
||||
- host="example" 主机名模糊匹配
|
||||
- title="login" 标题模糊匹配
|
||||
- status="200,301" 状态码多值匹配
|
||||
- tech="nginx" 技术栈匹配(数组字段)
|
||||
- 多条件空格分隔 AND 关系
|
||||
"""
|
||||
|
||||
@@ -366,9 +366,9 @@ class WebSiteViewSet(viewsets.ModelViewSet):
|
||||
def export(self, request, **kwargs):
|
||||
"""导出网站为 CSV 格式
|
||||
|
||||
CSV 列:url, host, location, title, status_code, content_length, content_type, webserver, tech, body_preview, vhost, created_at
|
||||
CSV 列:url, host, location, title, status_code, content_length, content_type, webserver, tech, response_body, response_headers, vhost, created_at
|
||||
"""
|
||||
from apps.common.utils import generate_csv_rows, format_datetime, format_list_field
|
||||
from apps.common.utils import create_csv_export_response, format_datetime, format_list_field
|
||||
|
||||
target_pk = self.kwargs.get('target_pk')
|
||||
if not target_pk:
|
||||
@@ -379,19 +379,19 @@ class WebSiteViewSet(viewsets.ModelViewSet):
|
||||
headers = [
|
||||
'url', 'host', 'location', 'title', 'status_code',
|
||||
'content_length', 'content_type', 'webserver', 'tech',
|
||||
'body_preview', 'vhost', 'created_at'
|
||||
'response_body', 'response_headers', 'vhost', 'created_at'
|
||||
]
|
||||
formatters = {
|
||||
'created_at': format_datetime,
|
||||
'tech': lambda x: format_list_field(x, separator=','),
|
||||
}
|
||||
|
||||
response = StreamingHttpResponse(
|
||||
generate_csv_rows(data_iterator, headers, formatters),
|
||||
content_type='text/csv; charset=utf-8'
|
||||
return create_csv_export_response(
|
||||
data_iterator=data_iterator,
|
||||
headers=headers,
|
||||
filename=f"target-{target_pk}-websites.csv",
|
||||
field_formatters=formatters
|
||||
)
|
||||
response['Content-Disposition'] = f'attachment; filename="target-{target_pk}-websites.csv"'
|
||||
return response
|
||||
|
||||
|
||||
class DirectoryViewSet(viewsets.ModelViewSet):
|
||||
@@ -498,7 +498,7 @@ class DirectoryViewSet(viewsets.ModelViewSet):
|
||||
|
||||
CSV 列:url, status, content_length, words, lines, content_type, duration, created_at
|
||||
"""
|
||||
from apps.common.utils import generate_csv_rows, format_datetime
|
||||
from apps.common.utils import create_csv_export_response, format_datetime
|
||||
|
||||
target_pk = self.kwargs.get('target_pk')
|
||||
if not target_pk:
|
||||
@@ -514,12 +514,12 @@ class DirectoryViewSet(viewsets.ModelViewSet):
|
||||
'created_at': format_datetime,
|
||||
}
|
||||
|
||||
response = StreamingHttpResponse(
|
||||
generate_csv_rows(data_iterator, headers, formatters),
|
||||
content_type='text/csv; charset=utf-8'
|
||||
return create_csv_export_response(
|
||||
data_iterator=data_iterator,
|
||||
headers=headers,
|
||||
filename=f"target-{target_pk}-directories.csv",
|
||||
field_formatters=formatters
|
||||
)
|
||||
response['Content-Disposition'] = f'attachment; filename="target-{target_pk}-directories.csv"'
|
||||
return response
|
||||
|
||||
|
||||
class EndpointViewSet(viewsets.ModelViewSet):
|
||||
@@ -534,6 +534,7 @@ class EndpointViewSet(viewsets.ModelViewSet):
|
||||
- host="example" 主机名模糊匹配
|
||||
- title="login" 标题模糊匹配
|
||||
- status="200,301" 状态码多值匹配
|
||||
- tech="nginx" 技术栈匹配(数组字段)
|
||||
- 多条件空格分隔 AND 关系
|
||||
"""
|
||||
|
||||
@@ -626,9 +627,9 @@ class EndpointViewSet(viewsets.ModelViewSet):
|
||||
def export(self, request, **kwargs):
|
||||
"""导出端点为 CSV 格式
|
||||
|
||||
CSV 列:url, host, location, title, status_code, content_length, content_type, webserver, tech, body_preview, vhost, matched_gf_patterns, created_at
|
||||
CSV 列:url, host, location, title, status_code, content_length, content_type, webserver, tech, response_body, response_headers, vhost, matched_gf_patterns, created_at
|
||||
"""
|
||||
from apps.common.utils import generate_csv_rows, format_datetime, format_list_field
|
||||
from apps.common.utils import create_csv_export_response, format_datetime, format_list_field
|
||||
|
||||
target_pk = self.kwargs.get('target_pk')
|
||||
if not target_pk:
|
||||
@@ -639,7 +640,7 @@ class EndpointViewSet(viewsets.ModelViewSet):
|
||||
headers = [
|
||||
'url', 'host', 'location', 'title', 'status_code',
|
||||
'content_length', 'content_type', 'webserver', 'tech',
|
||||
'body_preview', 'vhost', 'matched_gf_patterns', 'created_at'
|
||||
'response_body', 'response_headers', 'vhost', 'matched_gf_patterns', 'created_at'
|
||||
]
|
||||
formatters = {
|
||||
'created_at': format_datetime,
|
||||
@@ -647,12 +648,12 @@ class EndpointViewSet(viewsets.ModelViewSet):
|
||||
'matched_gf_patterns': lambda x: format_list_field(x, separator=','),
|
||||
}
|
||||
|
||||
response = StreamingHttpResponse(
|
||||
generate_csv_rows(data_iterator, headers, formatters),
|
||||
content_type='text/csv; charset=utf-8'
|
||||
return create_csv_export_response(
|
||||
data_iterator=data_iterator,
|
||||
headers=headers,
|
||||
filename=f"target-{target_pk}-endpoints.csv",
|
||||
field_formatters=formatters
|
||||
)
|
||||
response['Content-Disposition'] = f'attachment; filename="target-{target_pk}-endpoints.csv"'
|
||||
return response
|
||||
|
||||
|
||||
class HostPortMappingViewSet(viewsets.ModelViewSet):
|
||||
@@ -705,7 +706,7 @@ class HostPortMappingViewSet(viewsets.ModelViewSet):
|
||||
|
||||
CSV 列:ip, host, port, created_at
|
||||
"""
|
||||
from apps.common.utils import generate_csv_rows, format_datetime
|
||||
from apps.common.utils import create_csv_export_response, format_datetime
|
||||
|
||||
target_pk = self.kwargs.get('target_pk')
|
||||
if not target_pk:
|
||||
@@ -720,14 +721,12 @@ class HostPortMappingViewSet(viewsets.ModelViewSet):
|
||||
'created_at': format_datetime
|
||||
}
|
||||
|
||||
# 生成流式响应
|
||||
response = StreamingHttpResponse(
|
||||
generate_csv_rows(data_iterator, headers, formatters),
|
||||
content_type='text/csv; charset=utf-8'
|
||||
return create_csv_export_response(
|
||||
data_iterator=data_iterator,
|
||||
headers=headers,
|
||||
filename=f"target-{target_pk}-ip-addresses.csv",
|
||||
field_formatters=formatters
|
||||
)
|
||||
response['Content-Disposition'] = f'attachment; filename="target-{target_pk}-ip-addresses.csv"'
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class VulnerabilityViewSet(viewsets.ModelViewSet):
|
||||
@@ -799,7 +798,7 @@ class SubdomainSnapshotViewSet(viewsets.ModelViewSet):
|
||||
|
||||
CSV 列:name, created_at
|
||||
"""
|
||||
from apps.common.utils import generate_csv_rows, format_datetime
|
||||
from apps.common.utils import create_csv_export_response, format_datetime
|
||||
|
||||
scan_pk = self.kwargs.get('scan_pk')
|
||||
if not scan_pk:
|
||||
@@ -810,12 +809,12 @@ class SubdomainSnapshotViewSet(viewsets.ModelViewSet):
|
||||
headers = ['name', 'created_at']
|
||||
formatters = {'created_at': format_datetime}
|
||||
|
||||
response = StreamingHttpResponse(
|
||||
generate_csv_rows(data_iterator, headers, formatters),
|
||||
content_type='text/csv; charset=utf-8'
|
||||
return create_csv_export_response(
|
||||
data_iterator=data_iterator,
|
||||
headers=headers,
|
||||
filename=f"scan-{scan_pk}-subdomains.csv",
|
||||
field_formatters=formatters
|
||||
)
|
||||
response['Content-Disposition'] = f'attachment; filename="scan-{scan_pk}-subdomains.csv"'
|
||||
return response
|
||||
|
||||
|
||||
class WebsiteSnapshotViewSet(viewsets.ModelViewSet):
|
||||
@@ -851,9 +850,9 @@ class WebsiteSnapshotViewSet(viewsets.ModelViewSet):
|
||||
def export(self, request, **kwargs):
|
||||
"""导出网站快照为 CSV 格式
|
||||
|
||||
CSV 列:url, host, location, title, status_code, content_length, content_type, webserver, tech, body_preview, vhost, created_at
|
||||
CSV 列:url, host, location, title, status_code, content_length, content_type, webserver, tech, response_body, response_headers, vhost, created_at
|
||||
"""
|
||||
from apps.common.utils import generate_csv_rows, format_datetime, format_list_field
|
||||
from apps.common.utils import create_csv_export_response, format_datetime, format_list_field
|
||||
|
||||
scan_pk = self.kwargs.get('scan_pk')
|
||||
if not scan_pk:
|
||||
@@ -864,19 +863,19 @@ class WebsiteSnapshotViewSet(viewsets.ModelViewSet):
|
||||
headers = [
|
||||
'url', 'host', 'location', 'title', 'status_code',
|
||||
'content_length', 'content_type', 'webserver', 'tech',
|
||||
'body_preview', 'vhost', 'created_at'
|
||||
'response_body', 'response_headers', 'vhost', 'created_at'
|
||||
]
|
||||
formatters = {
|
||||
'created_at': format_datetime,
|
||||
'tech': lambda x: format_list_field(x, separator=','),
|
||||
}
|
||||
|
||||
response = StreamingHttpResponse(
|
||||
generate_csv_rows(data_iterator, headers, formatters),
|
||||
content_type='text/csv; charset=utf-8'
|
||||
return create_csv_export_response(
|
||||
data_iterator=data_iterator,
|
||||
headers=headers,
|
||||
filename=f"scan-{scan_pk}-websites.csv",
|
||||
field_formatters=formatters
|
||||
)
|
||||
response['Content-Disposition'] = f'attachment; filename="scan-{scan_pk}-websites.csv"'
|
||||
return response
|
||||
|
||||
|
||||
class DirectorySnapshotViewSet(viewsets.ModelViewSet):
|
||||
@@ -911,7 +910,7 @@ class DirectorySnapshotViewSet(viewsets.ModelViewSet):
|
||||
|
||||
CSV 列:url, status, content_length, words, lines, content_type, duration, created_at
|
||||
"""
|
||||
from apps.common.utils import generate_csv_rows, format_datetime
|
||||
from apps.common.utils import create_csv_export_response, format_datetime
|
||||
|
||||
scan_pk = self.kwargs.get('scan_pk')
|
||||
if not scan_pk:
|
||||
@@ -927,12 +926,12 @@ class DirectorySnapshotViewSet(viewsets.ModelViewSet):
|
||||
'created_at': format_datetime,
|
||||
}
|
||||
|
||||
response = StreamingHttpResponse(
|
||||
generate_csv_rows(data_iterator, headers, formatters),
|
||||
content_type='text/csv; charset=utf-8'
|
||||
return create_csv_export_response(
|
||||
data_iterator=data_iterator,
|
||||
headers=headers,
|
||||
filename=f"scan-{scan_pk}-directories.csv",
|
||||
field_formatters=formatters
|
||||
)
|
||||
response['Content-Disposition'] = f'attachment; filename="scan-{scan_pk}-directories.csv"'
|
||||
return response
|
||||
|
||||
|
||||
class EndpointSnapshotViewSet(viewsets.ModelViewSet):
|
||||
@@ -968,9 +967,9 @@ class EndpointSnapshotViewSet(viewsets.ModelViewSet):
|
||||
def export(self, request, **kwargs):
|
||||
"""导出端点快照为 CSV 格式
|
||||
|
||||
CSV 列:url, host, location, title, status_code, content_length, content_type, webserver, tech, body_preview, vhost, matched_gf_patterns, created_at
|
||||
CSV 列:url, host, location, title, status_code, content_length, content_type, webserver, tech, response_body, response_headers, vhost, matched_gf_patterns, created_at
|
||||
"""
|
||||
from apps.common.utils import generate_csv_rows, format_datetime, format_list_field
|
||||
from apps.common.utils import create_csv_export_response, format_datetime, format_list_field
|
||||
|
||||
scan_pk = self.kwargs.get('scan_pk')
|
||||
if not scan_pk:
|
||||
@@ -981,7 +980,7 @@ class EndpointSnapshotViewSet(viewsets.ModelViewSet):
|
||||
headers = [
|
||||
'url', 'host', 'location', 'title', 'status_code',
|
||||
'content_length', 'content_type', 'webserver', 'tech',
|
||||
'body_preview', 'vhost', 'matched_gf_patterns', 'created_at'
|
||||
'response_body', 'response_headers', 'vhost', 'matched_gf_patterns', 'created_at'
|
||||
]
|
||||
formatters = {
|
||||
'created_at': format_datetime,
|
||||
@@ -989,12 +988,12 @@ class EndpointSnapshotViewSet(viewsets.ModelViewSet):
|
||||
'matched_gf_patterns': lambda x: format_list_field(x, separator=','),
|
||||
}
|
||||
|
||||
response = StreamingHttpResponse(
|
||||
generate_csv_rows(data_iterator, headers, formatters),
|
||||
content_type='text/csv; charset=utf-8'
|
||||
return create_csv_export_response(
|
||||
data_iterator=data_iterator,
|
||||
headers=headers,
|
||||
filename=f"scan-{scan_pk}-endpoints.csv",
|
||||
field_formatters=formatters
|
||||
)
|
||||
response['Content-Disposition'] = f'attachment; filename="scan-{scan_pk}-endpoints.csv"'
|
||||
return response
|
||||
|
||||
|
||||
class HostPortMappingSnapshotViewSet(viewsets.ModelViewSet):
|
||||
@@ -1029,7 +1028,7 @@ class HostPortMappingSnapshotViewSet(viewsets.ModelViewSet):
|
||||
|
||||
CSV 列:ip, host, port, created_at
|
||||
"""
|
||||
from apps.common.utils import generate_csv_rows, format_datetime
|
||||
from apps.common.utils import create_csv_export_response, format_datetime
|
||||
|
||||
scan_pk = self.kwargs.get('scan_pk')
|
||||
if not scan_pk:
|
||||
@@ -1044,14 +1043,12 @@ class HostPortMappingSnapshotViewSet(viewsets.ModelViewSet):
|
||||
'created_at': format_datetime
|
||||
}
|
||||
|
||||
# 生成流式响应
|
||||
response = StreamingHttpResponse(
|
||||
generate_csv_rows(data_iterator, headers, formatters),
|
||||
content_type='text/csv; charset=utf-8'
|
||||
return create_csv_export_response(
|
||||
data_iterator=data_iterator,
|
||||
headers=headers,
|
||||
filename=f"scan-{scan_pk}-ip-addresses.csv",
|
||||
field_formatters=formatters
|
||||
)
|
||||
response['Content-Disposition'] = f'attachment; filename="scan-{scan_pk}-ip-addresses.csv"'
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class VulnerabilitySnapshotViewSet(viewsets.ModelViewSet):
|
||||
361
backend/apps/asset/views/search_views.py
Normal file
361
backend/apps/asset/views/search_views.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""
|
||||
资产搜索 API 视图
|
||||
|
||||
提供资产搜索的 REST API 接口:
|
||||
- GET /api/assets/search/ - 搜索资产
|
||||
- GET /api/assets/search/export/ - 导出搜索结果为 CSV
|
||||
|
||||
搜索语法:
|
||||
- field="value" 模糊匹配(ILIKE %value%)
|
||||
- field=="value" 精确匹配
|
||||
- field!="value" 不等于
|
||||
- && AND 连接
|
||||
- || OR 连接
|
||||
|
||||
支持的字段:
|
||||
- host: 主机名
|
||||
- url: URL
|
||||
- title: 标题
|
||||
- tech: 技术栈
|
||||
- status: 状态码
|
||||
- body: 响应体
|
||||
- header: 响应头
|
||||
|
||||
支持的资产类型:
|
||||
- website: 站点(默认)
|
||||
- endpoint: 端点
|
||||
"""
|
||||
|
||||
import logging
|
||||
import json
|
||||
from datetime import datetime
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
from rest_framework import status
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.request import Request
|
||||
from django.db import connection
|
||||
|
||||
from apps.common.response_helpers import success_response, error_response
|
||||
from apps.common.error_codes import ErrorCodes
|
||||
from apps.asset.services.search_service import AssetSearchService, VALID_ASSET_TYPES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AssetSearchView(APIView):
|
||||
"""
|
||||
资产搜索 API
|
||||
|
||||
GET /api/assets/search/
|
||||
|
||||
Query Parameters:
|
||||
q: 搜索查询表达式
|
||||
asset_type: 资产类型 ('website' 或 'endpoint',默认 'website')
|
||||
page: 页码(从 1 开始,默认 1)
|
||||
pageSize: 每页数量(默认 10,最大 100)
|
||||
|
||||
示例查询:
|
||||
?q=host="api" && tech="nginx"
|
||||
?q=tech="vue" || tech="react"&asset_type=endpoint
|
||||
?q=status=="200" && host!="test"
|
||||
|
||||
Response:
|
||||
{
|
||||
"results": [...],
|
||||
"total": 100,
|
||||
"page": 1,
|
||||
"pageSize": 10,
|
||||
"totalPages": 10,
|
||||
"assetType": "website"
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.service = AssetSearchService()
|
||||
|
||||
def _parse_headers(self, headers_data) -> dict:
|
||||
"""解析响应头为字典"""
|
||||
if not headers_data:
|
||||
return {}
|
||||
try:
|
||||
return json.loads(headers_data)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
result = {}
|
||||
for line in str(headers_data).split('\n'):
|
||||
if ':' in line:
|
||||
key, value = line.split(':', 1)
|
||||
result[key.strip()] = value.strip()
|
||||
return result
|
||||
|
||||
def _format_result(self, result: dict, vulnerabilities_by_url: dict, asset_type: str) -> dict:
|
||||
"""格式化单个搜索结果"""
|
||||
url = result.get('url', '')
|
||||
vulns = vulnerabilities_by_url.get(url, [])
|
||||
|
||||
# 基础字段(Website 和 Endpoint 共有)
|
||||
formatted = {
|
||||
'id': result.get('id'),
|
||||
'url': url,
|
||||
'host': result.get('host', ''),
|
||||
'title': result.get('title', ''),
|
||||
'technologies': result.get('tech', []) or [],
|
||||
'statusCode': result.get('status_code'),
|
||||
'contentLength': result.get('content_length'),
|
||||
'contentType': result.get('content_type', ''),
|
||||
'webserver': result.get('webserver', ''),
|
||||
'location': result.get('location', ''),
|
||||
'vhost': result.get('vhost'),
|
||||
'responseHeaders': self._parse_headers(result.get('response_headers')),
|
||||
'responseBody': result.get('response_body', ''),
|
||||
'createdAt': result.get('created_at').isoformat() if result.get('created_at') else None,
|
||||
'targetId': result.get('target_id'),
|
||||
}
|
||||
|
||||
# Website 特有字段:漏洞关联
|
||||
if asset_type == 'website':
|
||||
formatted['vulnerabilities'] = [
|
||||
{
|
||||
'id': v.get('id'),
|
||||
'name': v.get('vuln_type', ''),
|
||||
'vulnType': v.get('vuln_type', ''),
|
||||
'severity': v.get('severity', 'info'),
|
||||
}
|
||||
for v in vulns
|
||||
]
|
||||
|
||||
# Endpoint 特有字段
|
||||
if asset_type == 'endpoint':
|
||||
formatted['matchedGfPatterns'] = result.get('matched_gf_patterns', []) or []
|
||||
|
||||
return formatted
|
||||
|
||||
def _get_vulnerabilities_by_url_prefix(self, website_urls: list) -> dict:
|
||||
"""
|
||||
根据 URL 前缀批量查询漏洞数据
|
||||
|
||||
漏洞 URL 是 website URL 的子路径,使用前缀匹配:
|
||||
- website.url: https://example.com/path?query=1
|
||||
- vulnerability.url: https://example.com/path/api/users
|
||||
|
||||
Args:
|
||||
website_urls: website URL 列表,格式为 [(url, target_id), ...]
|
||||
|
||||
Returns:
|
||||
dict: {website_url: [vulnerability_list]}
|
||||
"""
|
||||
if not website_urls:
|
||||
return {}
|
||||
|
||||
try:
|
||||
with connection.cursor() as cursor:
|
||||
# 构建 OR 条件:每个 website URL(去掉查询参数)作为前缀匹配
|
||||
conditions = []
|
||||
params = []
|
||||
url_mapping = {} # base_url -> original_url
|
||||
|
||||
for url, target_id in website_urls:
|
||||
if not url or target_id is None:
|
||||
continue
|
||||
# 使用 urlparse 去掉查询参数和片段,只保留 scheme://netloc/path
|
||||
parsed = urlparse(url)
|
||||
base_url = urlunparse((parsed.scheme, parsed.netloc, parsed.path, '', '', ''))
|
||||
url_mapping[base_url] = url
|
||||
conditions.append("(v.url LIKE %s AND v.target_id = %s)")
|
||||
params.extend([base_url + '%', target_id])
|
||||
|
||||
if not conditions:
|
||||
return {}
|
||||
|
||||
where_clause = " OR ".join(conditions)
|
||||
|
||||
sql = f"""
|
||||
SELECT v.id, v.vuln_type, v.severity, v.url, v.target_id
|
||||
FROM vulnerability v
|
||||
WHERE {where_clause}
|
||||
ORDER BY
|
||||
CASE v.severity
|
||||
WHEN 'critical' THEN 1
|
||||
WHEN 'high' THEN 2
|
||||
WHEN 'medium' THEN 3
|
||||
WHEN 'low' THEN 4
|
||||
ELSE 5
|
||||
END
|
||||
"""
|
||||
cursor.execute(sql, params)
|
||||
|
||||
# 获取所有漏洞
|
||||
all_vulns = []
|
||||
for row in cursor.fetchall():
|
||||
all_vulns.append({
|
||||
'id': row[0],
|
||||
'vuln_type': row[1],
|
||||
'name': row[1],
|
||||
'severity': row[2],
|
||||
'url': row[3],
|
||||
'target_id': row[4],
|
||||
})
|
||||
|
||||
# 按原始 website URL 分组(用于返回结果)
|
||||
result = {url: [] for url, _ in website_urls}
|
||||
for vuln in all_vulns:
|
||||
vuln_url = vuln['url']
|
||||
# 找到匹配的 website URL(最长前缀匹配)
|
||||
for website_url, target_id in website_urls:
|
||||
parsed = urlparse(website_url)
|
||||
base_url = urlunparse((parsed.scheme, parsed.netloc, parsed.path, '', '', ''))
|
||||
if vuln_url.startswith(base_url) and vuln['target_id'] == target_id:
|
||||
result[website_url].append(vuln)
|
||||
break
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"批量查询漏洞失败: {e}")
|
||||
return {}
|
||||
|
||||
def get(self, request: Request):
|
||||
"""搜索资产"""
|
||||
# 获取搜索查询
|
||||
query = request.query_params.get('q', '').strip()
|
||||
|
||||
if not query:
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='Search query (q) is required',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# 获取并验证资产类型
|
||||
asset_type = request.query_params.get('asset_type', 'website').strip().lower()
|
||||
if asset_type not in VALID_ASSET_TYPES:
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message=f'Invalid asset_type. Must be one of: {", ".join(VALID_ASSET_TYPES)}',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# 获取分页参数
|
||||
try:
|
||||
page = int(request.query_params.get('page', 1))
|
||||
page_size = int(request.query_params.get('pageSize', 10))
|
||||
except (ValueError, TypeError):
|
||||
page = 1
|
||||
page_size = 10
|
||||
|
||||
# 限制分页参数
|
||||
page = max(1, page)
|
||||
page_size = min(max(1, page_size), 100)
|
||||
|
||||
# 获取总数和搜索结果
|
||||
total = self.service.count(query, asset_type)
|
||||
total_pages = (total + page_size - 1) // page_size if total > 0 else 1
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
all_results = self.service.search(query, asset_type)
|
||||
results = all_results[offset:offset + page_size]
|
||||
|
||||
# 批量查询漏洞数据(仅 Website 类型需要)
|
||||
vulnerabilities_by_url = {}
|
||||
if asset_type == 'website':
|
||||
website_urls = [(r.get('url'), r.get('target_id')) for r in results if r.get('url') and r.get('target_id')]
|
||||
vulnerabilities_by_url = self._get_vulnerabilities_by_url_prefix(website_urls) if website_urls else {}
|
||||
|
||||
# 格式化结果
|
||||
formatted_results = [self._format_result(r, vulnerabilities_by_url, asset_type) for r in results]
|
||||
|
||||
return success_response(data={
|
||||
'results': formatted_results,
|
||||
'total': total,
|
||||
'page': page,
|
||||
'pageSize': page_size,
|
||||
'totalPages': total_pages,
|
||||
'assetType': asset_type,
|
||||
})
|
||||
|
||||
|
||||
class AssetSearchExportView(APIView):
|
||||
"""
|
||||
资产搜索导出 API
|
||||
|
||||
GET /api/assets/search/export/
|
||||
|
||||
Query Parameters:
|
||||
q: 搜索查询表达式
|
||||
asset_type: 资产类型 ('website' 或 'endpoint',默认 'website')
|
||||
|
||||
Response:
|
||||
CSV 文件(带 Content-Length,支持浏览器显示下载进度)
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.service = AssetSearchService()
|
||||
|
||||
def _get_headers_and_formatters(self, asset_type: str):
|
||||
"""获取 CSV 表头和格式化器"""
|
||||
from apps.common.utils import format_datetime, format_list_field
|
||||
|
||||
if asset_type == 'website':
|
||||
headers = ['url', 'host', 'title', 'status_code', 'content_type', 'content_length',
|
||||
'webserver', 'location', 'tech', 'vhost', 'created_at']
|
||||
else:
|
||||
headers = ['url', 'host', 'title', 'status_code', 'content_type', 'content_length',
|
||||
'webserver', 'location', 'tech', 'matched_gf_patterns', 'vhost', 'created_at']
|
||||
|
||||
formatters = {
|
||||
'created_at': format_datetime,
|
||||
'tech': lambda x: format_list_field(x, separator='; '),
|
||||
'matched_gf_patterns': lambda x: format_list_field(x, separator='; '),
|
||||
'vhost': lambda x: 'true' if x else ('false' if x is False else ''),
|
||||
}
|
||||
|
||||
return headers, formatters
|
||||
|
||||
def get(self, request: Request):
|
||||
"""导出搜索结果为 CSV(带 Content-Length,支持下载进度显示)"""
|
||||
from apps.common.utils import create_csv_export_response
|
||||
|
||||
# 获取搜索查询
|
||||
query = request.query_params.get('q', '').strip()
|
||||
|
||||
if not query:
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='Search query (q) is required',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# 获取并验证资产类型
|
||||
asset_type = request.query_params.get('asset_type', 'website').strip().lower()
|
||||
if asset_type not in VALID_ASSET_TYPES:
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message=f'Invalid asset_type. Must be one of: {", ".join(VALID_ASSET_TYPES)}',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# 检查是否有结果(快速检查,避免空导出)
|
||||
total = self.service.count(query, asset_type)
|
||||
if total == 0:
|
||||
return error_response(
|
||||
code=ErrorCodes.NOT_FOUND,
|
||||
message='No results to export',
|
||||
status_code=status.HTTP_404_NOT_FOUND
|
||||
)
|
||||
|
||||
# 获取表头和格式化器
|
||||
headers, formatters = self._get_headers_and_formatters(asset_type)
|
||||
|
||||
# 生成文件名
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
filename = f'search_{asset_type}_{timestamp}.csv'
|
||||
|
||||
# 使用通用导出工具
|
||||
data_iterator = self.service.search_iter(query, asset_type)
|
||||
return create_csv_export_response(
|
||||
data_iterator=data_iterator,
|
||||
headers=headers,
|
||||
filename=filename,
|
||||
field_formatters=formatters,
|
||||
show_progress=True # 显示下载进度
|
||||
)
|
||||
@@ -40,8 +40,14 @@ def fetch_config_and_setup_django():
|
||||
print(f"[CONFIG] 正在从配置中心获取配置: {config_url}")
|
||||
print(f"[CONFIG] IS_LOCAL={is_local}")
|
||||
try:
|
||||
# 构建请求头(包含 Worker API Key)
|
||||
headers = {}
|
||||
worker_api_key = os.environ.get("WORKER_API_KEY", "")
|
||||
if worker_api_key:
|
||||
headers["X-Worker-API-Key"] = worker_api_key
|
||||
|
||||
# verify=False: 远程 Worker 通过 HTTPS 访问时可能使用自签名证书
|
||||
resp = requests.get(config_url, timeout=10, verify=False)
|
||||
resp = requests.get(config_url, headers=headers, timeout=10, verify=False)
|
||||
resp.raise_for_status()
|
||||
config = resp.json()
|
||||
|
||||
@@ -57,9 +63,6 @@ def fetch_config_and_setup_django():
|
||||
os.environ.setdefault("DB_USER", db_user)
|
||||
os.environ.setdefault("DB_PASSWORD", config['db']['password'])
|
||||
|
||||
# Redis 配置
|
||||
os.environ.setdefault("REDIS_URL", config['redisUrl'])
|
||||
|
||||
# 日志配置
|
||||
os.environ.setdefault("LOG_DIR", config['paths']['logs'])
|
||||
os.environ.setdefault("LOG_LEVEL", config['logging']['level'])
|
||||
@@ -71,7 +74,6 @@ def fetch_config_and_setup_django():
|
||||
print(f"[CONFIG] DB_PORT: {db_port}")
|
||||
print(f"[CONFIG] DB_NAME: {db_name}")
|
||||
print(f"[CONFIG] DB_USER: {db_user}")
|
||||
print(f"[CONFIG] REDIS_URL: {config['redisUrl']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 获取配置失败: {config_url} - {e}", file=sys.stderr)
|
||||
|
||||
49
backend/apps/common/exception_handlers.py
Normal file
49
backend/apps/common/exception_handlers.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
自定义异常处理器
|
||||
|
||||
统一处理 DRF 异常,确保错误响应格式一致
|
||||
"""
|
||||
|
||||
from rest_framework.views import exception_handler
|
||||
from rest_framework import status
|
||||
from rest_framework.exceptions import AuthenticationFailed, NotAuthenticated
|
||||
|
||||
from apps.common.response_helpers import error_response
|
||||
from apps.common.error_codes import ErrorCodes
|
||||
|
||||
|
||||
def custom_exception_handler(exc, context):
|
||||
"""
|
||||
自定义异常处理器
|
||||
|
||||
处理认证相关异常,返回统一格式的错误响应
|
||||
"""
|
||||
# 先调用 DRF 默认的异常处理器
|
||||
response = exception_handler(exc, context)
|
||||
|
||||
if response is not None:
|
||||
# 处理 401 未认证错误
|
||||
if response.status_code == status.HTTP_401_UNAUTHORIZED:
|
||||
return error_response(
|
||||
code=ErrorCodes.UNAUTHORIZED,
|
||||
message='Authentication required',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED
|
||||
)
|
||||
|
||||
# 处理 403 权限不足错误
|
||||
if response.status_code == status.HTTP_403_FORBIDDEN:
|
||||
return error_response(
|
||||
code=ErrorCodes.PERMISSION_DENIED,
|
||||
message='Permission denied',
|
||||
status_code=status.HTTP_403_FORBIDDEN
|
||||
)
|
||||
|
||||
# 处理 NotAuthenticated 和 AuthenticationFailed 异常
|
||||
if isinstance(exc, (NotAuthenticated, AuthenticationFailed)):
|
||||
return error_response(
|
||||
code=ErrorCodes.UNAUTHORIZED,
|
||||
message='Authentication required',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED
|
||||
)
|
||||
|
||||
return response
|
||||
34
backend/apps/common/migrations/0001_initial.py
Normal file
34
backend/apps/common/migrations/0001_initial.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# Generated by Django 5.2.7 on 2026-01-06 00:55
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
initial = True
|
||||
|
||||
dependencies = [
|
||||
('targets', '0001_initial'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name='BlacklistRule',
|
||||
fields=[
|
||||
('id', models.AutoField(primary_key=True, serialize=False)),
|
||||
('pattern', models.CharField(help_text='规则模式,如 *.gov, 10.0.0.0/8, 192.168.1.1', max_length=255)),
|
||||
('rule_type', models.CharField(choices=[('domain', '域名'), ('ip', 'IP地址'), ('cidr', 'CIDR范围'), ('keyword', '关键词')], help_text='规则类型:domain, ip, cidr', max_length=20)),
|
||||
('scope', models.CharField(choices=[('global', '全局规则'), ('target', 'Target规则')], db_index=True, help_text='作用域:global 或 target', max_length=20)),
|
||||
('description', models.CharField(blank=True, default='', help_text='规则描述', max_length=500)),
|
||||
('created_at', models.DateTimeField(auto_now_add=True)),
|
||||
('target', models.ForeignKey(blank=True, help_text='关联的 Target(仅 scope=target 时有值)', null=True, on_delete=django.db.models.deletion.CASCADE, related_name='blacklist_rules', to='targets.target')),
|
||||
],
|
||||
options={
|
||||
'db_table': 'blacklist_rule',
|
||||
'ordering': ['-created_at'],
|
||||
'indexes': [models.Index(fields=['scope', 'rule_type'], name='blacklist_r_scope_6ff77f_idx'), models.Index(fields=['target', 'scope'], name='blacklist_r_target__191441_idx')],
|
||||
'constraints': [models.UniqueConstraint(fields=('pattern', 'scope', 'target'), name='unique_blacklist_rule')],
|
||||
},
|
||||
),
|
||||
]
|
||||
0
backend/apps/common/migrations/__init__.py
Normal file
0
backend/apps/common/migrations/__init__.py
Normal file
4
backend/apps/common/models/__init__.py
Normal file
4
backend/apps/common/models/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""Common models"""
|
||||
from apps.common.models.blacklist import BlacklistRule
|
||||
|
||||
__all__ = ['BlacklistRule']
|
||||
71
backend/apps/common/models/blacklist.py
Normal file
71
backend/apps/common/models/blacklist.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""黑名单规则模型"""
|
||||
from django.db import models
|
||||
|
||||
|
||||
class BlacklistRule(models.Model):
|
||||
"""黑名单规则模型
|
||||
|
||||
用于存储黑名单过滤规则,支持域名、IP、CIDR 三种类型。
|
||||
支持两层作用域:全局规则和 Target 级规则。
|
||||
"""
|
||||
|
||||
class RuleType(models.TextChoices):
|
||||
DOMAIN = 'domain', '域名'
|
||||
IP = 'ip', 'IP地址'
|
||||
CIDR = 'cidr', 'CIDR范围'
|
||||
KEYWORD = 'keyword', '关键词'
|
||||
|
||||
class Scope(models.TextChoices):
|
||||
GLOBAL = 'global', '全局规则'
|
||||
TARGET = 'target', 'Target规则'
|
||||
|
||||
id = models.AutoField(primary_key=True)
|
||||
pattern = models.CharField(
|
||||
max_length=255,
|
||||
help_text='规则模式,如 *.gov, 10.0.0.0/8, 192.168.1.1'
|
||||
)
|
||||
rule_type = models.CharField(
|
||||
max_length=20,
|
||||
choices=RuleType.choices,
|
||||
help_text='规则类型:domain, ip, cidr'
|
||||
)
|
||||
scope = models.CharField(
|
||||
max_length=20,
|
||||
choices=Scope.choices,
|
||||
db_index=True,
|
||||
help_text='作用域:global 或 target'
|
||||
)
|
||||
target = models.ForeignKey(
|
||||
'targets.Target',
|
||||
on_delete=models.CASCADE,
|
||||
null=True,
|
||||
blank=True,
|
||||
related_name='blacklist_rules',
|
||||
help_text='关联的 Target(仅 scope=target 时有值)'
|
||||
)
|
||||
description = models.CharField(
|
||||
max_length=500,
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='规则描述'
|
||||
)
|
||||
created_at = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
class Meta:
|
||||
db_table = 'blacklist_rule'
|
||||
indexes = [
|
||||
models.Index(fields=['scope', 'rule_type']),
|
||||
models.Index(fields=['target', 'scope']),
|
||||
]
|
||||
constraints = [
|
||||
models.UniqueConstraint(
|
||||
fields=['pattern', 'scope', 'target'],
|
||||
name='unique_blacklist_rule'
|
||||
),
|
||||
]
|
||||
ordering = ['-created_at']
|
||||
|
||||
def __str__(self):
|
||||
if self.scope == self.Scope.TARGET and self.target:
|
||||
return f"[{self.scope}:{self.target_id}] {self.pattern}"
|
||||
return f"[{self.scope}] {self.pattern}"
|
||||
80
backend/apps/common/permissions.py
Normal file
80
backend/apps/common/permissions.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
集中式权限管理
|
||||
|
||||
实现三类端点的认证逻辑:
|
||||
1. 公开端点(无需认证):登录、登出、获取当前用户状态
|
||||
2. Worker 端点(API Key 认证):注册、配置、心跳、回调、资源同步
|
||||
3. 业务端点(Session 认证):其他所有 API
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from django.conf import settings
|
||||
from rest_framework.permissions import BasePermission
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 公开端点白名单(无需任何认证)
|
||||
PUBLIC_ENDPOINTS = [
|
||||
r'^/api/auth/login/$',
|
||||
r'^/api/auth/logout/$',
|
||||
r'^/api/auth/me/$',
|
||||
]
|
||||
|
||||
# Worker API 端点(需要 API Key 认证)
|
||||
# 包括:注册、配置、心跳、回调、资源同步(字典下载)
|
||||
WORKER_ENDPOINTS = [
|
||||
r'^/api/workers/register/$',
|
||||
r'^/api/workers/config/$',
|
||||
r'^/api/workers/\d+/heartbeat/$',
|
||||
r'^/api/callbacks/',
|
||||
# 资源同步端点(Worker 需要下载字典文件)
|
||||
r'^/api/wordlists/download/$',
|
||||
# 注意:指纹导出 API 使用 Session 认证(前端用户导出用)
|
||||
# Worker 通过数据库直接获取指纹数据,不需要 HTTP API
|
||||
]
|
||||
|
||||
|
||||
class IsAuthenticatedOrPublic(BasePermission):
|
||||
"""
|
||||
自定义权限类:
|
||||
- 白名单内的端点公开访问
|
||||
- Worker 端点需要 API Key 认证
|
||||
- 其他端点需要 Session 认证
|
||||
"""
|
||||
|
||||
def has_permission(self, request, view):
|
||||
path = request.path
|
||||
|
||||
# 检查是否在公开白名单内
|
||||
for pattern in PUBLIC_ENDPOINTS:
|
||||
if re.match(pattern, path):
|
||||
return True
|
||||
|
||||
# 检查是否是 Worker 端点
|
||||
for pattern in WORKER_ENDPOINTS:
|
||||
if re.match(pattern, path):
|
||||
return self._check_worker_api_key(request)
|
||||
|
||||
# 其他路径需要 Session 认证
|
||||
return request.user and request.user.is_authenticated
|
||||
|
||||
def _check_worker_api_key(self, request):
|
||||
"""验证 Worker API Key"""
|
||||
api_key = request.headers.get('X-Worker-API-Key')
|
||||
expected_key = getattr(settings, 'WORKER_API_KEY', None)
|
||||
|
||||
if not expected_key:
|
||||
# 未配置 API Key 时,拒绝所有 Worker 请求
|
||||
logger.warning("WORKER_API_KEY 未配置,拒绝 Worker 请求")
|
||||
return False
|
||||
|
||||
if not api_key:
|
||||
logger.warning(f"Worker 请求缺少 X-Worker-API-Key Header: {request.path}")
|
||||
return False
|
||||
|
||||
if api_key != expected_key:
|
||||
logger.warning(f"Worker API Key 无效: {request.path}")
|
||||
return False
|
||||
|
||||
return True
|
||||
12
backend/apps/common/serializers/__init__.py
Normal file
12
backend/apps/common/serializers/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Common serializers"""
|
||||
from .blacklist_serializers import (
|
||||
BlacklistRuleSerializer,
|
||||
GlobalBlacklistRuleSerializer,
|
||||
TargetBlacklistRuleSerializer,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'BlacklistRuleSerializer',
|
||||
'GlobalBlacklistRuleSerializer',
|
||||
'TargetBlacklistRuleSerializer',
|
||||
]
|
||||
68
backend/apps/common/serializers/blacklist_serializers.py
Normal file
68
backend/apps/common/serializers/blacklist_serializers.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""黑名单规则序列化器"""
|
||||
from rest_framework import serializers
|
||||
|
||||
from apps.common.models import BlacklistRule
|
||||
from apps.common.utils import detect_rule_type
|
||||
|
||||
|
||||
class BlacklistRuleSerializer(serializers.ModelSerializer):
|
||||
"""黑名单规则序列化器"""
|
||||
|
||||
class Meta:
|
||||
model = BlacklistRule
|
||||
fields = [
|
||||
'id',
|
||||
'pattern',
|
||||
'rule_type',
|
||||
'scope',
|
||||
'target',
|
||||
'description',
|
||||
'created_at',
|
||||
]
|
||||
read_only_fields = ['id', 'rule_type', 'created_at']
|
||||
|
||||
def validate_pattern(self, value):
|
||||
"""验证规则模式"""
|
||||
if not value or not value.strip():
|
||||
raise serializers.ValidationError("规则模式不能为空")
|
||||
return value.strip()
|
||||
|
||||
def create(self, validated_data):
|
||||
"""创建规则时自动识别规则类型"""
|
||||
pattern = validated_data.get('pattern', '')
|
||||
validated_data['rule_type'] = detect_rule_type(pattern)
|
||||
return super().create(validated_data)
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
"""更新规则时重新识别规则类型"""
|
||||
if 'pattern' in validated_data:
|
||||
pattern = validated_data['pattern']
|
||||
validated_data['rule_type'] = detect_rule_type(pattern)
|
||||
return super().update(instance, validated_data)
|
||||
|
||||
|
||||
class GlobalBlacklistRuleSerializer(BlacklistRuleSerializer):
|
||||
"""全局黑名单规则序列化器"""
|
||||
|
||||
class Meta(BlacklistRuleSerializer.Meta):
|
||||
fields = ['id', 'pattern', 'rule_type', 'description', 'created_at']
|
||||
read_only_fields = ['id', 'rule_type', 'created_at']
|
||||
|
||||
def create(self, validated_data):
|
||||
"""创建全局规则"""
|
||||
validated_data['scope'] = BlacklistRule.Scope.GLOBAL
|
||||
validated_data['target'] = None
|
||||
return super().create(validated_data)
|
||||
|
||||
|
||||
class TargetBlacklistRuleSerializer(BlacklistRuleSerializer):
|
||||
"""Target 黑名单规则序列化器"""
|
||||
|
||||
class Meta(BlacklistRuleSerializer.Meta):
|
||||
fields = ['id', 'pattern', 'rule_type', 'description', 'created_at']
|
||||
read_only_fields = ['id', 'rule_type', 'created_at']
|
||||
|
||||
def create(self, validated_data):
|
||||
"""创建 Target 规则(target_id 由 view 设置)"""
|
||||
validated_data['scope'] = BlacklistRule.Scope.TARGET
|
||||
return super().create(validated_data)
|
||||
@@ -3,13 +3,16 @@
|
||||
|
||||
提供系统级别的公共服务,包括:
|
||||
- SystemLogService: 系统日志读取服务
|
||||
- BlacklistService: 黑名单过滤服务
|
||||
|
||||
注意:FilterService 已移至 apps.common.utils.filter_utils
|
||||
推荐使用: from apps.common.utils.filter_utils import apply_filters
|
||||
"""
|
||||
|
||||
from .system_log_service import SystemLogService
|
||||
from .blacklist_service import BlacklistService
|
||||
|
||||
__all__ = [
|
||||
'SystemLogService',
|
||||
'BlacklistService',
|
||||
]
|
||||
|
||||
176
backend/apps/common/services/blacklist_service.py
Normal file
176
backend/apps/common/services/blacklist_service.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
黑名单规则管理服务
|
||||
|
||||
负责黑名单规则的 CRUD 操作(数据库层面)。
|
||||
过滤逻辑请使用 apps.common.utils.BlacklistFilter。
|
||||
|
||||
架构说明:
|
||||
- Model: BlacklistRule (apps.common.models.blacklist)
|
||||
- Service: BlacklistService (本文件) - 规则 CRUD
|
||||
- Utils: BlacklistFilter (apps.common.utils.blacklist_filter) - 过滤逻辑
|
||||
- View: GlobalBlacklistView, TargetViewSet.blacklist
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from apps.common.utils import detect_rule_type
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize_patterns(patterns: List[str]) -> List[str]:
|
||||
"""
|
||||
规范化规则列表:去重 + 过滤空行
|
||||
|
||||
Args:
|
||||
patterns: 原始规则列表
|
||||
|
||||
Returns:
|
||||
List[str]: 去重后的规则列表(保持顺序)
|
||||
"""
|
||||
return list(dict.fromkeys(filter(None, (p.strip() for p in patterns))))
|
||||
|
||||
|
||||
class BlacklistService:
|
||||
"""
|
||||
黑名单规则管理服务
|
||||
|
||||
只负责规则的 CRUD 操作,不包含过滤逻辑。
|
||||
过滤逻辑请使用 BlacklistFilter 工具类。
|
||||
"""
|
||||
|
||||
def get_global_rules(self) -> QuerySet:
|
||||
"""
|
||||
获取全局黑名单规则列表
|
||||
|
||||
Returns:
|
||||
QuerySet: 全局规则查询集
|
||||
"""
|
||||
from apps.common.models import BlacklistRule
|
||||
return BlacklistRule.objects.filter(scope=BlacklistRule.Scope.GLOBAL)
|
||||
|
||||
def get_target_rules(self, target_id: int) -> QuerySet:
|
||||
"""
|
||||
获取 Target 级黑名单规则列表
|
||||
|
||||
Args:
|
||||
target_id: Target ID
|
||||
|
||||
Returns:
|
||||
QuerySet: Target 级规则查询集
|
||||
"""
|
||||
from apps.common.models import BlacklistRule
|
||||
return BlacklistRule.objects.filter(
|
||||
scope=BlacklistRule.Scope.TARGET,
|
||||
target_id=target_id
|
||||
)
|
||||
|
||||
def get_rules(self, target_id: Optional[int] = None) -> List:
|
||||
"""
|
||||
获取黑名单规则(全局 + Target 级)
|
||||
|
||||
Args:
|
||||
target_id: Target ID,用于加载 Target 级规则
|
||||
|
||||
Returns:
|
||||
List[BlacklistRule]: 规则列表
|
||||
"""
|
||||
from apps.common.models import BlacklistRule
|
||||
|
||||
# 加载全局规则
|
||||
rules = list(BlacklistRule.objects.filter(scope=BlacklistRule.Scope.GLOBAL))
|
||||
|
||||
# 加载 Target 级规则
|
||||
if target_id:
|
||||
target_rules = BlacklistRule.objects.filter(
|
||||
scope=BlacklistRule.Scope.TARGET,
|
||||
target_id=target_id
|
||||
)
|
||||
rules.extend(target_rules)
|
||||
|
||||
return rules
|
||||
|
||||
def replace_global_rules(self, patterns: List[str]) -> Dict[str, Any]:
|
||||
"""
|
||||
全量替换全局黑名单规则(PUT 语义)
|
||||
|
||||
Args:
|
||||
patterns: 新的规则模式列表
|
||||
|
||||
Returns:
|
||||
Dict: {'count': int} 最终规则数量
|
||||
"""
|
||||
from apps.common.models import BlacklistRule
|
||||
|
||||
count = self._replace_rules(
|
||||
patterns=patterns,
|
||||
scope=BlacklistRule.Scope.GLOBAL,
|
||||
target=None
|
||||
)
|
||||
|
||||
logger.info("全量替换全局黑名单规则: %d 条", count)
|
||||
return {'count': count}
|
||||
|
||||
def replace_target_rules(self, target, patterns: List[str]) -> Dict[str, Any]:
|
||||
"""
|
||||
全量替换 Target 级黑名单规则(PUT 语义)
|
||||
|
||||
Args:
|
||||
target: Target 对象
|
||||
patterns: 新的规则模式列表
|
||||
|
||||
Returns:
|
||||
Dict: {'count': int} 最终规则数量
|
||||
"""
|
||||
from apps.common.models import BlacklistRule
|
||||
|
||||
count = self._replace_rules(
|
||||
patterns=patterns,
|
||||
scope=BlacklistRule.Scope.TARGET,
|
||||
target=target
|
||||
)
|
||||
|
||||
logger.info("全量替换 Target 黑名单规则: %d 条 (Target: %s)", count, target.name)
|
||||
return {'count': count}
|
||||
|
||||
def _replace_rules(self, patterns: List[str], scope: str, target=None) -> int:
|
||||
"""
|
||||
内部方法:全量替换规则
|
||||
|
||||
Args:
|
||||
patterns: 规则模式列表
|
||||
scope: 规则作用域 (GLOBAL/TARGET)
|
||||
target: Target 对象(仅 TARGET 作用域需要)
|
||||
|
||||
Returns:
|
||||
int: 最终规则数量
|
||||
"""
|
||||
from apps.common.models import BlacklistRule
|
||||
from django.db import transaction
|
||||
|
||||
patterns = _normalize_patterns(patterns)
|
||||
|
||||
with transaction.atomic():
|
||||
# 1. 删除旧规则
|
||||
delete_filter = {'scope': scope}
|
||||
if target:
|
||||
delete_filter['target'] = target
|
||||
BlacklistRule.objects.filter(**delete_filter).delete()
|
||||
|
||||
# 2. 创建新规则
|
||||
if patterns:
|
||||
rules = [
|
||||
BlacklistRule(
|
||||
pattern=pattern,
|
||||
rule_type=detect_rule_type(pattern),
|
||||
scope=scope,
|
||||
target=target
|
||||
)
|
||||
for pattern in patterns
|
||||
]
|
||||
BlacklistRule.objects.bulk_create(rules)
|
||||
|
||||
return len(patterns)
|
||||
@@ -2,14 +2,24 @@
|
||||
通用模块 URL 配置
|
||||
|
||||
路由说明:
|
||||
- /api/auth/* 认证相关接口(登录、登出、用户信息)
|
||||
- /api/system/* 系统管理接口(日志查看等)
|
||||
- /api/health/ 健康检查接口(无需认证)
|
||||
- /api/auth/* 认证相关接口(登录、登出、用户信息)
|
||||
- /api/system/* 系统管理接口(日志查看等)
|
||||
- /api/blacklist/* 黑名单管理接口
|
||||
"""
|
||||
|
||||
from django.urls import path
|
||||
from .views import LoginView, LogoutView, MeView, ChangePasswordView, SystemLogsView, SystemLogFilesView
|
||||
|
||||
from .views import (
|
||||
LoginView, LogoutView, MeView, ChangePasswordView,
|
||||
SystemLogsView, SystemLogFilesView, HealthCheckView,
|
||||
GlobalBlacklistView,
|
||||
)
|
||||
|
||||
urlpatterns = [
|
||||
# 健康检查(无需认证)
|
||||
path('health/', HealthCheckView.as_view(), name='health-check'),
|
||||
|
||||
# 认证相关
|
||||
path('auth/login/', LoginView.as_view(), name='auth-login'),
|
||||
path('auth/logout/', LogoutView.as_view(), name='auth-logout'),
|
||||
@@ -19,4 +29,7 @@ urlpatterns = [
|
||||
# 系统管理
|
||||
path('system/logs/', SystemLogsView.as_view(), name='system-logs'),
|
||||
path('system/logs/files/', SystemLogFilesView.as_view(), name='system-log-files'),
|
||||
|
||||
# 黑名单管理(PUT 全量替换模式)
|
||||
path('blacklist/rules/', GlobalBlacklistView.as_view(), name='blacklist-rules'),
|
||||
]
|
||||
|
||||
@@ -11,8 +11,14 @@ from .csv_utils import (
|
||||
generate_csv_rows,
|
||||
format_list_field,
|
||||
format_datetime,
|
||||
create_csv_export_response,
|
||||
UTF8_BOM,
|
||||
)
|
||||
from .blacklist_filter import (
|
||||
BlacklistFilter,
|
||||
detect_rule_type,
|
||||
extract_host,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'deduplicate_for_bulk',
|
||||
@@ -24,5 +30,9 @@ __all__ = [
|
||||
'generate_csv_rows',
|
||||
'format_list_field',
|
||||
'format_datetime',
|
||||
'create_csv_export_response',
|
||||
'UTF8_BOM',
|
||||
'BlacklistFilter',
|
||||
'detect_rule_type',
|
||||
'extract_host',
|
||||
]
|
||||
|
||||
246
backend/apps/common/utils/blacklist_filter.py
Normal file
246
backend/apps/common/utils/blacklist_filter.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""
|
||||
黑名单过滤工具
|
||||
|
||||
提供域名、IP、CIDR、关键词的黑名单匹配功能。
|
||||
纯工具类,不涉及数据库操作。
|
||||
|
||||
支持的规则类型:
|
||||
1. 域名精确匹配: example.com
|
||||
- 规则: example.com
|
||||
- 匹配: example.com
|
||||
- 不匹配: sub.example.com, other.com
|
||||
|
||||
2. 域名后缀匹配: *.example.com
|
||||
- 规则: *.example.com
|
||||
- 匹配: sub.example.com, a.b.example.com, example.com
|
||||
- 不匹配: other.com, example.com.cn
|
||||
|
||||
3. 关键词匹配: *cdn*
|
||||
- 规则: *cdn*
|
||||
- 匹配: cdn.example.com, a.cdn.b.com, mycdn123.com
|
||||
- 不匹配: example.com (不包含 cdn)
|
||||
|
||||
4. IP 精确匹配: 192.168.1.1
|
||||
- 规则: 192.168.1.1
|
||||
- 匹配: 192.168.1.1
|
||||
- 不匹配: 192.168.1.2
|
||||
|
||||
5. CIDR 范围匹配: 192.168.0.0/24
|
||||
- 规则: 192.168.0.0/24
|
||||
- 匹配: 192.168.0.1, 192.168.0.255
|
||||
- 不匹配: 192.168.1.1
|
||||
|
||||
使用方式:
|
||||
from apps.common.utils import BlacklistFilter
|
||||
|
||||
# 创建过滤器(传入规则列表)
|
||||
rules = BlacklistRule.objects.filter(...)
|
||||
filter = BlacklistFilter(rules)
|
||||
|
||||
# 检查单个目标
|
||||
if filter.is_allowed('http://example.com'):
|
||||
process(url)
|
||||
|
||||
# 流式处理
|
||||
for url in urls:
|
||||
if filter.is_allowed(url):
|
||||
process(url)
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from apps.common.validators import is_valid_ip, validate_cidr
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def detect_rule_type(pattern: str) -> str:
|
||||
"""
|
||||
自动识别规则类型
|
||||
|
||||
支持的模式:
|
||||
- 域名精确匹配: example.com
|
||||
- 域名后缀匹配: *.example.com
|
||||
- 关键词匹配: *cdn* (匹配包含 cdn 的域名)
|
||||
- IP 精确匹配: 192.168.1.1
|
||||
- CIDR 范围: 192.168.0.0/24
|
||||
|
||||
Args:
|
||||
pattern: 规则模式字符串
|
||||
|
||||
Returns:
|
||||
str: 规则类型 ('domain', 'ip', 'cidr', 'keyword')
|
||||
"""
|
||||
if not pattern:
|
||||
return 'domain'
|
||||
|
||||
pattern = pattern.strip()
|
||||
|
||||
# 检查关键词模式: *keyword* (前后都有星号,中间无点)
|
||||
if pattern.startswith('*') and pattern.endswith('*') and len(pattern) > 2:
|
||||
keyword = pattern[1:-1]
|
||||
# 关键词中不能有点(否则可能是域名模式)
|
||||
if '.' not in keyword:
|
||||
return 'keyword'
|
||||
|
||||
# 检查 CIDR(包含 /)
|
||||
if '/' in pattern:
|
||||
try:
|
||||
validate_cidr(pattern)
|
||||
return 'cidr'
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# 检查 IP(去掉通配符前缀后验证)
|
||||
clean_pattern = pattern.lstrip('*').lstrip('.')
|
||||
if is_valid_ip(clean_pattern):
|
||||
return 'ip'
|
||||
|
||||
# 默认为域名
|
||||
return 'domain'
|
||||
|
||||
|
||||
def extract_host(target: str) -> str:
|
||||
"""
|
||||
从目标字符串中提取主机名
|
||||
|
||||
支持:
|
||||
- 纯域名:example.com
|
||||
- 纯 IP:192.168.1.1
|
||||
- URL:http://example.com/path
|
||||
|
||||
Args:
|
||||
target: 目标字符串
|
||||
|
||||
Returns:
|
||||
str: 提取的主机名
|
||||
"""
|
||||
if not target:
|
||||
return ''
|
||||
|
||||
target = target.strip()
|
||||
|
||||
# 如果是 URL,提取 hostname
|
||||
if '://' in target:
|
||||
try:
|
||||
parsed = urlparse(target)
|
||||
return parsed.hostname or target
|
||||
except Exception:
|
||||
return target
|
||||
|
||||
return target
|
||||
|
||||
|
||||
class BlacklistFilter:
|
||||
"""
|
||||
黑名单过滤器
|
||||
|
||||
预编译规则,提供高效的匹配功能。
|
||||
"""
|
||||
|
||||
def __init__(self, rules: List):
|
||||
"""
|
||||
初始化过滤器
|
||||
|
||||
Args:
|
||||
rules: BlacklistRule 对象列表
|
||||
"""
|
||||
from apps.common.models import BlacklistRule
|
||||
|
||||
# 预解析:按类型分类 + CIDR 预编译
|
||||
self._domain_rules = [] # (pattern, is_wildcard, suffix)
|
||||
self._ip_rules = set() # 精确 IP 用 set,O(1) 查找
|
||||
self._cidr_rules = [] # (pattern, network_obj)
|
||||
self._keyword_rules = [] # 关键词列表(小写)
|
||||
|
||||
# 去重:跨 scope 可能有重复规则
|
||||
seen_patterns = set()
|
||||
|
||||
for rule in rules:
|
||||
if rule.pattern in seen_patterns:
|
||||
continue
|
||||
seen_patterns.add(rule.pattern)
|
||||
if rule.rule_type == BlacklistRule.RuleType.DOMAIN:
|
||||
pattern = rule.pattern.lower()
|
||||
if pattern.startswith('*.'):
|
||||
self._domain_rules.append((pattern, True, pattern[1:]))
|
||||
else:
|
||||
self._domain_rules.append((pattern, False, None))
|
||||
elif rule.rule_type == BlacklistRule.RuleType.IP:
|
||||
self._ip_rules.add(rule.pattern)
|
||||
elif rule.rule_type == BlacklistRule.RuleType.CIDR:
|
||||
try:
|
||||
network = ipaddress.ip_network(rule.pattern, strict=False)
|
||||
self._cidr_rules.append((rule.pattern, network))
|
||||
except ValueError:
|
||||
pass
|
||||
elif rule.rule_type == BlacklistRule.RuleType.KEYWORD:
|
||||
# *cdn* -> cdn
|
||||
keyword = rule.pattern[1:-1].lower()
|
||||
self._keyword_rules.append(keyword)
|
||||
|
||||
def is_allowed(self, target: str) -> bool:
|
||||
"""
|
||||
检查目标是否通过过滤
|
||||
|
||||
Args:
|
||||
target: 要检查的目标(域名/IP/URL)
|
||||
|
||||
Returns:
|
||||
bool: True 表示通过(不在黑名单),False 表示被过滤
|
||||
"""
|
||||
if not target:
|
||||
return True
|
||||
|
||||
host = extract_host(target)
|
||||
if not host:
|
||||
return True
|
||||
|
||||
# 先判断输入类型,再走对应分支
|
||||
if is_valid_ip(host):
|
||||
return self._check_ip_rules(host)
|
||||
else:
|
||||
return self._check_domain_rules(host)
|
||||
|
||||
def _check_domain_rules(self, host: str) -> bool:
|
||||
"""检查域名规则(精确匹配 + 后缀匹配 + 关键词匹配)"""
|
||||
host_lower = host.lower()
|
||||
|
||||
# 1. 域名规则(精确 + 后缀)
|
||||
for pattern, is_wildcard, suffix in self._domain_rules:
|
||||
if is_wildcard:
|
||||
if host_lower.endswith(suffix) or host_lower == pattern[2:]:
|
||||
return False
|
||||
else:
|
||||
if host_lower == pattern:
|
||||
return False
|
||||
|
||||
# 2. 关键词匹配(字符串 in 操作,O(n*m))
|
||||
for keyword in self._keyword_rules:
|
||||
if keyword in host_lower:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _check_ip_rules(self, host: str) -> bool:
|
||||
"""检查 IP 规则(精确匹配 + CIDR)"""
|
||||
# 1. IP 精确匹配(O(1))
|
||||
if host in self._ip_rules:
|
||||
return False
|
||||
|
||||
# 2. CIDR 匹配
|
||||
if self._cidr_rules:
|
||||
try:
|
||||
ip_obj = ipaddress.ip_address(host)
|
||||
for _, network in self._cidr_rules:
|
||||
if ip_obj in network:
|
||||
return False
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@@ -4,13 +4,21 @@
|
||||
- UTF-8 BOM(Excel 兼容)
|
||||
- RFC 4180 规范转义
|
||||
- 流式生成(内存友好)
|
||||
- 带 Content-Length 的文件响应(支持浏览器下载进度显示)
|
||||
"""
|
||||
|
||||
import csv
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Iterator, Dict, Any, List, Callable, Optional
|
||||
|
||||
from django.http import FileResponse, StreamingHttpResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# UTF-8 BOM,确保 Excel 正确识别编码
|
||||
UTF8_BOM = '\ufeff'
|
||||
|
||||
@@ -114,3 +122,123 @@ def format_datetime(dt: Optional[datetime]) -> str:
|
||||
dt = timezone.localtime(dt)
|
||||
|
||||
return dt.strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
|
||||
def create_csv_export_response(
|
||||
data_iterator: Iterator[Dict[str, Any]],
|
||||
headers: List[str],
|
||||
filename: str,
|
||||
field_formatters: Optional[Dict[str, Callable]] = None,
|
||||
show_progress: bool = True
|
||||
) -> FileResponse | StreamingHttpResponse:
|
||||
"""
|
||||
创建 CSV 导出响应
|
||||
|
||||
根据 show_progress 参数选择响应类型:
|
||||
- True: 使用临时文件 + FileResponse,带 Content-Length(浏览器显示下载进度)
|
||||
- False: 使用 StreamingHttpResponse(内存更友好,但无下载进度)
|
||||
|
||||
Args:
|
||||
data_iterator: 数据迭代器,每个元素是一个字典
|
||||
headers: CSV 表头列表
|
||||
filename: 下载文件名(如 "export_2024.csv")
|
||||
field_formatters: 字段格式化函数字典
|
||||
show_progress: 是否显示下载进度(默认 True)
|
||||
|
||||
Returns:
|
||||
FileResponse 或 StreamingHttpResponse
|
||||
|
||||
Example:
|
||||
>>> data_iter = service.iter_data()
|
||||
>>> headers = ['url', 'host', 'created_at']
|
||||
>>> formatters = {'created_at': format_datetime}
|
||||
>>> response = create_csv_export_response(
|
||||
... data_iter, headers, 'websites.csv', formatters
|
||||
... )
|
||||
>>> return response
|
||||
"""
|
||||
if show_progress:
|
||||
return _create_file_response(data_iterator, headers, filename, field_formatters)
|
||||
else:
|
||||
return _create_streaming_response(data_iterator, headers, filename, field_formatters)
|
||||
|
||||
|
||||
def _create_file_response(
|
||||
data_iterator: Iterator[Dict[str, Any]],
|
||||
headers: List[str],
|
||||
filename: str,
|
||||
field_formatters: Optional[Dict[str, Callable]] = None
|
||||
) -> FileResponse:
|
||||
"""
|
||||
创建带 Content-Length 的文件响应(支持浏览器下载进度)
|
||||
|
||||
实现方式:先写入临时文件,再返回 FileResponse
|
||||
"""
|
||||
# 创建临时文件
|
||||
temp_file = tempfile.NamedTemporaryFile(
|
||||
mode='w',
|
||||
suffix='.csv',
|
||||
delete=False,
|
||||
encoding='utf-8'
|
||||
)
|
||||
temp_path = temp_file.name
|
||||
|
||||
try:
|
||||
# 流式写入 CSV 数据到临时文件
|
||||
for row in generate_csv_rows(data_iterator, headers, field_formatters):
|
||||
temp_file.write(row)
|
||||
temp_file.close()
|
||||
|
||||
# 获取文件大小
|
||||
file_size = os.path.getsize(temp_path)
|
||||
|
||||
# 创建文件响应
|
||||
response = FileResponse(
|
||||
open(temp_path, 'rb'),
|
||||
content_type='text/csv; charset=utf-8',
|
||||
as_attachment=True,
|
||||
filename=filename
|
||||
)
|
||||
response['Content-Length'] = file_size
|
||||
|
||||
# 设置清理回调:响应完成后删除临时文件
|
||||
original_close = response.file_to_stream.close
|
||||
def close_and_cleanup():
|
||||
original_close()
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
response.file_to_stream.close = close_and_cleanup
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
# 清理临时文件
|
||||
try:
|
||||
temp_file.close()
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
logger.error(f"创建 CSV 导出响应失败: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def _create_streaming_response(
|
||||
data_iterator: Iterator[Dict[str, Any]],
|
||||
headers: List[str],
|
||||
filename: str,
|
||||
field_formatters: Optional[Dict[str, Callable]] = None
|
||||
) -> StreamingHttpResponse:
|
||||
"""
|
||||
创建流式响应(无 Content-Length,内存更友好)
|
||||
"""
|
||||
response = StreamingHttpResponse(
|
||||
generate_csv_rows(data_iterator, headers, field_formatters),
|
||||
content_type='text/csv; charset=utf-8'
|
||||
)
|
||||
response['Content-Disposition'] = f'attachment; filename="{filename}"'
|
||||
return response
|
||||
|
||||
@@ -29,11 +29,19 @@ from dataclasses import dataclass
|
||||
from typing import List, Dict, Optional, Union
|
||||
from enum import Enum
|
||||
|
||||
from django.db.models import QuerySet, Q
|
||||
from django.db.models import QuerySet, Q, F, Func, CharField
|
||||
from django.db.models.functions import Cast
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ArrayToString(Func):
|
||||
"""PostgreSQL array_to_string 函数"""
|
||||
function = 'array_to_string'
|
||||
template = "%(function)s(%(expressions)s, ',')"
|
||||
output_field = CharField()
|
||||
|
||||
|
||||
class LogicalOp(Enum):
|
||||
"""逻辑运算符"""
|
||||
AND = 'AND'
|
||||
@@ -86,9 +94,21 @@ class QueryParser:
|
||||
if not query_string or not query_string.strip():
|
||||
return []
|
||||
|
||||
# 第一步:提取所有过滤条件并用占位符替换,保护引号内的空格
|
||||
filters_found = []
|
||||
placeholder_pattern = '__FILTER_{}__'
|
||||
|
||||
def replace_filter(match):
|
||||
idx = len(filters_found)
|
||||
filters_found.append(match.group(0))
|
||||
return placeholder_pattern.format(idx)
|
||||
|
||||
# 先用正则提取所有 field="value" 形式的条件
|
||||
protected = cls.FILTER_PATTERN.sub(replace_filter, query_string)
|
||||
|
||||
# 标准化逻辑运算符
|
||||
# 先处理 || 和 or -> __OR__
|
||||
normalized = cls.OR_PATTERN.sub(' __OR__ ', query_string)
|
||||
normalized = cls.OR_PATTERN.sub(' __OR__ ', protected)
|
||||
# 再处理 && 和 and -> __AND__
|
||||
normalized = cls.AND_PATTERN.sub(' __AND__ ', normalized)
|
||||
|
||||
@@ -103,20 +123,26 @@ class QueryParser:
|
||||
pending_op = LogicalOp.OR
|
||||
elif token == '__AND__':
|
||||
pending_op = LogicalOp.AND
|
||||
else:
|
||||
# 尝试解析为过滤条件
|
||||
match = cls.FILTER_PATTERN.match(token)
|
||||
if match:
|
||||
field, operator, value = match.groups()
|
||||
groups.append(FilterGroup(
|
||||
filter=ParsedFilter(
|
||||
field=field.lower(),
|
||||
operator=operator,
|
||||
value=value
|
||||
),
|
||||
logical_op=pending_op if groups else LogicalOp.AND # 第一个条件默认 AND
|
||||
))
|
||||
pending_op = LogicalOp.AND # 重置为默认 AND
|
||||
elif token.startswith('__FILTER_') and token.endswith('__'):
|
||||
# 还原占位符为原始过滤条件
|
||||
try:
|
||||
idx = int(token[9:-2]) # 提取索引
|
||||
original_filter = filters_found[idx]
|
||||
match = cls.FILTER_PATTERN.match(original_filter)
|
||||
if match:
|
||||
field, operator, value = match.groups()
|
||||
groups.append(FilterGroup(
|
||||
filter=ParsedFilter(
|
||||
field=field.lower(),
|
||||
operator=operator,
|
||||
value=value
|
||||
),
|
||||
logical_op=pending_op if groups else LogicalOp.AND
|
||||
))
|
||||
pending_op = LogicalOp.AND # 重置为默认 AND
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
# 其他 token 忽略(无效输入)
|
||||
|
||||
return groups
|
||||
|
||||
@@ -151,6 +177,21 @@ class QueryBuilder:
|
||||
|
||||
json_array_fields = json_array_fields or []
|
||||
|
||||
# 收集需要 annotate 的数组模糊搜索字段
|
||||
array_fuzzy_fields = set()
|
||||
|
||||
# 第一遍:检查是否有数组模糊匹配
|
||||
for group in filter_groups:
|
||||
f = group.filter
|
||||
db_field = field_mapping.get(f.field)
|
||||
if db_field and db_field in json_array_fields and f.operator == '=':
|
||||
array_fuzzy_fields.add(db_field)
|
||||
|
||||
# 对数组模糊搜索字段做 annotate
|
||||
for field in array_fuzzy_fields:
|
||||
annotate_name = f'{field}_text'
|
||||
queryset = queryset.annotate(**{annotate_name: ArrayToString(F(field))})
|
||||
|
||||
# 构建 Q 对象
|
||||
combined_q = None
|
||||
|
||||
@@ -187,8 +228,17 @@ class QueryBuilder:
|
||||
def _build_single_q(cls, field: str, operator: str, value: str, is_json_array: bool = False) -> Optional[Q]:
|
||||
"""构建单个条件的 Q 对象"""
|
||||
if is_json_array:
|
||||
# JSON 数组字段使用 __contains 查询
|
||||
return Q(**{f'{field}__contains': [value]})
|
||||
if operator == '==':
|
||||
# 精确匹配:数组中包含完全等于 value 的元素
|
||||
return Q(**{f'{field}__contains': [value]})
|
||||
elif operator == '!=':
|
||||
# 不包含:数组中不包含完全等于 value 的元素
|
||||
return ~Q(**{f'{field}__contains': [value]})
|
||||
else: # '=' 模糊匹配
|
||||
# 使用 annotate 后的字段进行模糊搜索
|
||||
# 字段已在 build_query 中通过 ArrayToString 转换为文本
|
||||
annotate_name = f'{field}_text'
|
||||
return Q(**{f'{annotate_name}__icontains': value})
|
||||
|
||||
if operator == '!=':
|
||||
return cls._build_not_equal_q(field, value)
|
||||
|
||||
@@ -2,11 +2,20 @@
|
||||
通用模块视图导出
|
||||
|
||||
包含:
|
||||
- 健康检查视图:Docker 健康检查
|
||||
- 认证相关视图:登录、登出、用户信息、修改密码
|
||||
- 系统日志视图:实时日志查看
|
||||
- 黑名单视图:全局黑名单规则管理
|
||||
"""
|
||||
|
||||
from .health_views import HealthCheckView
|
||||
from .auth_views import LoginView, LogoutView, MeView, ChangePasswordView
|
||||
from .system_log_views import SystemLogsView, SystemLogFilesView
|
||||
from .blacklist_views import GlobalBlacklistView
|
||||
|
||||
__all__ = ['LoginView', 'LogoutView', 'MeView', 'ChangePasswordView', 'SystemLogsView', 'SystemLogFilesView']
|
||||
__all__ = [
|
||||
'HealthCheckView',
|
||||
'LoginView', 'LogoutView', 'MeView', 'ChangePasswordView',
|
||||
'SystemLogsView', 'SystemLogFilesView',
|
||||
'GlobalBlacklistView',
|
||||
]
|
||||
|
||||
@@ -9,7 +9,7 @@ from django.utils.decorators import method_decorator
|
||||
from rest_framework import status
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.permissions import AllowAny, IsAuthenticated
|
||||
from rest_framework.permissions import AllowAny
|
||||
|
||||
from apps.common.response_helpers import success_response, error_response
|
||||
from apps.common.error_codes import ErrorCodes
|
||||
@@ -134,30 +134,10 @@ class ChangePasswordView(APIView):
|
||||
修改密码
|
||||
POST /api/auth/change-password/
|
||||
"""
|
||||
authentication_classes = [] # 禁用认证(绕过 CSRF)
|
||||
permission_classes = [AllowAny] # 手动检查登录状态
|
||||
|
||||
def post(self, request):
|
||||
# 手动检查登录状态(从 session 获取用户)
|
||||
from django.contrib.auth import get_user_model
|
||||
User = get_user_model()
|
||||
|
||||
user_id = request.session.get('_auth_user_id')
|
||||
if not user_id:
|
||||
return error_response(
|
||||
code=ErrorCodes.UNAUTHORIZED,
|
||||
message='Please login first',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED
|
||||
)
|
||||
|
||||
try:
|
||||
user = User.objects.get(pk=user_id)
|
||||
except User.DoesNotExist:
|
||||
return error_response(
|
||||
code=ErrorCodes.UNAUTHORIZED,
|
||||
message='User does not exist',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED
|
||||
)
|
||||
# 使用全局权限类验证,request.user 已经是认证用户
|
||||
user = request.user
|
||||
|
||||
# CamelCaseParser 将 oldPassword -> old_password
|
||||
old_password = request.data.get('old_password')
|
||||
|
||||
80
backend/apps/common/views/blacklist_views.py
Normal file
80
backend/apps/common/views/blacklist_views.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""全局黑名单 API 视图"""
|
||||
import logging
|
||||
|
||||
from rest_framework import status
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
|
||||
from apps.common.response_helpers import success_response, error_response
|
||||
from apps.common.services import BlacklistService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GlobalBlacklistView(APIView):
|
||||
"""
|
||||
全局黑名单规则 API
|
||||
|
||||
Endpoints:
|
||||
- GET /api/blacklist/rules/ - 获取全局黑名单列表
|
||||
- PUT /api/blacklist/rules/ - 全量替换规则(文本框保存场景)
|
||||
|
||||
设计说明:
|
||||
- 使用 PUT 全量替换模式,适合"文本框每行一个规则"的前端场景
|
||||
- 用户编辑文本框 -> 点击保存 -> 后端全量替换
|
||||
|
||||
架构:MVS 模式
|
||||
- View: 参数验证、响应格式化
|
||||
- Service: 业务逻辑(BlacklistService)
|
||||
- Model: 数据持久化(BlacklistRule)
|
||||
"""
|
||||
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.blacklist_service = BlacklistService()
|
||||
|
||||
def get(self, request):
|
||||
"""
|
||||
获取全局黑名单规则列表
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"patterns": ["*.gov", "*.edu", "10.0.0.0/8"]
|
||||
}
|
||||
"""
|
||||
rules = self.blacklist_service.get_global_rules()
|
||||
patterns = list(rules.values_list('pattern', flat=True))
|
||||
return success_response(data={'patterns': patterns})
|
||||
|
||||
def put(self, request):
|
||||
"""
|
||||
全量替换全局黑名单规则
|
||||
|
||||
请求格式:
|
||||
{
|
||||
"patterns": ["*.gov", "*.edu", "10.0.0.0/8"]
|
||||
}
|
||||
|
||||
或者空数组清空所有规则:
|
||||
{
|
||||
"patterns": []
|
||||
}
|
||||
"""
|
||||
patterns = request.data.get('patterns', [])
|
||||
|
||||
# 兼容字符串输入(换行分隔)
|
||||
if isinstance(patterns, str):
|
||||
patterns = [p for p in patterns.split('\n') if p.strip()]
|
||||
|
||||
if not isinstance(patterns, list):
|
||||
return error_response(
|
||||
code='VALIDATION_ERROR',
|
||||
message='patterns 必须是数组'
|
||||
)
|
||||
|
||||
# 调用 Service 层全量替换
|
||||
result = self.blacklist_service.replace_global_rules(patterns)
|
||||
|
||||
return success_response(data=result)
|
||||
24
backend/apps/common/views/health_views.py
Normal file
24
backend/apps/common/views/health_views.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
健康检查视图
|
||||
|
||||
提供 Docker 健康检查端点,无需认证。
|
||||
"""
|
||||
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.permissions import AllowAny
|
||||
|
||||
|
||||
class HealthCheckView(APIView):
|
||||
"""
|
||||
健康检查端点
|
||||
|
||||
GET /api/health/
|
||||
|
||||
返回服务状态,用于 Docker 健康检查。
|
||||
此端点无需认证。
|
||||
"""
|
||||
permission_classes = [AllowAny]
|
||||
|
||||
def get(self, request):
|
||||
return Response({'status': 'ok'})
|
||||
@@ -9,7 +9,6 @@ import logging
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
from rest_framework import status
|
||||
from rest_framework.permissions import AllowAny
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
|
||||
@@ -42,9 +41,6 @@ class SystemLogFilesView(APIView):
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
authentication_classes = []
|
||||
permission_classes = [AllowAny]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -80,15 +76,7 @@ class SystemLogsView(APIView):
|
||||
{
|
||||
"content": "日志内容字符串..."
|
||||
}
|
||||
|
||||
Note:
|
||||
- 当前为开发阶段,暂时允许匿名访问
|
||||
- 生产环境应添加管理员权限验证
|
||||
"""
|
||||
|
||||
# TODO: 生产环境应改为 IsAdminUser 权限
|
||||
authentication_classes = []
|
||||
permission_classes = [AllowAny]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
44
backend/apps/common/websocket_auth.py
Normal file
44
backend/apps/common/websocket_auth.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""
|
||||
WebSocket 认证基类
|
||||
|
||||
提供需要认证的 WebSocket Consumer 基类
|
||||
"""
|
||||
|
||||
import logging
|
||||
from channels.generic.websocket import AsyncWebsocketConsumer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthenticatedWebsocketConsumer(AsyncWebsocketConsumer):
|
||||
"""
|
||||
需要认证的 WebSocket Consumer 基类
|
||||
|
||||
子类应该重写 on_connect() 方法实现具体的连接逻辑
|
||||
"""
|
||||
|
||||
async def connect(self):
|
||||
"""
|
||||
连接时验证用户认证状态
|
||||
|
||||
未认证时使用 close(code=4001) 拒绝连接
|
||||
"""
|
||||
user = self.scope.get('user')
|
||||
|
||||
if not user or not user.is_authenticated:
|
||||
logger.warning(
|
||||
f"WebSocket 连接被拒绝:用户未认证 - Path: {self.scope.get('path')}"
|
||||
)
|
||||
await self.close(code=4001)
|
||||
return
|
||||
|
||||
# 调用子类的连接逻辑
|
||||
await self.on_connect()
|
||||
|
||||
async def on_connect(self):
|
||||
"""
|
||||
子类实现具体的连接逻辑
|
||||
|
||||
默认实现:接受连接
|
||||
"""
|
||||
await self.accept()
|
||||
@@ -6,17 +6,17 @@ import json
|
||||
import logging
|
||||
import asyncio
|
||||
import os
|
||||
from channels.generic.websocket import AsyncWebsocketConsumer
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
from apps.common.websocket_auth import AuthenticatedWebsocketConsumer
|
||||
from apps.engine.services import WorkerService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkerDeployConsumer(AsyncWebsocketConsumer):
|
||||
class WorkerDeployConsumer(AuthenticatedWebsocketConsumer):
|
||||
"""
|
||||
Worker 交互式终端 WebSocket Consumer
|
||||
|
||||
@@ -31,8 +31,8 @@ class WorkerDeployConsumer(AsyncWebsocketConsumer):
|
||||
self.read_task = None
|
||||
self.worker_service = WorkerService()
|
||||
|
||||
async def connect(self):
|
||||
"""连接时加入对应 Worker 的组并自动建立 SSH 连接"""
|
||||
async def on_connect(self):
|
||||
"""连接时加入对应 Worker 的组并自动建立 SSH 连接(已通过认证)"""
|
||||
self.worker_id = self.scope['url_route']['kwargs']['worker_id']
|
||||
self.group_name = f'worker_deploy_{self.worker_id}'
|
||||
|
||||
|
||||
@@ -15,9 +15,10 @@
|
||||
"""
|
||||
|
||||
from django.core.management.base import BaseCommand
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
from apps.engine.models import ScanEngine
|
||||
|
||||
@@ -44,10 +45,12 @@ class Command(BaseCommand):
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
default_config = f.read()
|
||||
|
||||
# 解析 YAML 为字典,后续用于生成子引擎配置
|
||||
# 使用 ruamel.yaml 解析,保留注释
|
||||
yaml_parser = YAML()
|
||||
yaml_parser.preserve_quotes = True
|
||||
try:
|
||||
config_dict = yaml.safe_load(default_config) or {}
|
||||
except yaml.YAMLError as e:
|
||||
config_dict = yaml_parser.load(default_config) or {}
|
||||
except Exception as e:
|
||||
self.stdout.write(self.style.ERROR(f'引擎配置 YAML 解析失败: {e}'))
|
||||
return
|
||||
|
||||
@@ -83,15 +86,13 @@ class Command(BaseCommand):
|
||||
if scan_type != 'subdomain_discovery' and 'tools' not in scan_cfg:
|
||||
continue
|
||||
|
||||
# 构造只包含当前扫描类型配置的 YAML
|
||||
# 构造只包含当前扫描类型配置的 YAML(保留注释)
|
||||
single_config = {scan_type: scan_cfg}
|
||||
try:
|
||||
single_yaml = yaml.safe_dump(
|
||||
single_config,
|
||||
sort_keys=False,
|
||||
allow_unicode=True,
|
||||
)
|
||||
except yaml.YAMLError as e:
|
||||
stream = StringIO()
|
||||
yaml_parser.dump(single_config, stream)
|
||||
single_yaml = stream.getvalue()
|
||||
except Exception as e:
|
||||
self.stdout.write(self.style.ERROR(f'生成子引擎 {scan_type} 配置失败: {e}'))
|
||||
continue
|
||||
|
||||
|
||||
213
backend/apps/engine/migrations/0001_initial.py
Normal file
213
backend/apps/engine/migrations/0001_initial.py
Normal file
@@ -0,0 +1,213 @@
|
||||
# Generated by Django 5.2.7 on 2026-01-06 00:55
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
initial = True
|
||||
|
||||
dependencies = [
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name='NucleiTemplateRepo',
|
||||
fields=[
|
||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('name', models.CharField(help_text='仓库名称,用于前端展示和配置引用', max_length=200, unique=True)),
|
||||
('repo_url', models.CharField(help_text='Git 仓库地址', max_length=500)),
|
||||
('local_path', models.CharField(blank=True, default='', help_text='本地工作目录绝对路径', max_length=500)),
|
||||
('commit_hash', models.CharField(blank=True, default='', help_text='最后同步的 Git commit hash,用于 Worker 版本校验', max_length=40)),
|
||||
('last_synced_at', models.DateTimeField(blank=True, help_text='最后一次成功同步时间', null=True)),
|
||||
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
|
||||
('updated_at', models.DateTimeField(auto_now=True, help_text='更新时间')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': 'Nuclei 模板仓库',
|
||||
'verbose_name_plural': 'Nuclei 模板仓库',
|
||||
'db_table': 'nuclei_template_repo',
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='ARLFingerprint',
|
||||
fields=[
|
||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('name', models.CharField(help_text='指纹名称', max_length=300, unique=True)),
|
||||
('rule', models.TextField(help_text='匹配规则表达式')),
|
||||
('created_at', models.DateTimeField(auto_now_add=True)),
|
||||
],
|
||||
options={
|
||||
'verbose_name': 'ARL 指纹',
|
||||
'verbose_name_plural': 'ARL 指纹',
|
||||
'db_table': 'arl_fingerprint',
|
||||
'ordering': ['-created_at'],
|
||||
'indexes': [models.Index(fields=['name'], name='arl_fingerp_name_c3a305_idx'), models.Index(fields=['-created_at'], name='arl_fingerp_created_ed1060_idx')],
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='EholeFingerprint',
|
||||
fields=[
|
||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('cms', models.CharField(help_text='产品/CMS名称', max_length=200)),
|
||||
('method', models.CharField(default='keyword', help_text='匹配方式', max_length=200)),
|
||||
('location', models.CharField(default='body', help_text='匹配位置', max_length=200)),
|
||||
('keyword', models.JSONField(default=list, help_text='关键词列表')),
|
||||
('is_important', models.BooleanField(default=False, help_text='是否重点资产')),
|
||||
('type', models.CharField(blank=True, default='-', help_text='分类', max_length=100)),
|
||||
('created_at', models.DateTimeField(auto_now_add=True)),
|
||||
],
|
||||
options={
|
||||
'verbose_name': 'EHole 指纹',
|
||||
'verbose_name_plural': 'EHole 指纹',
|
||||
'db_table': 'ehole_fingerprint',
|
||||
'ordering': ['-created_at'],
|
||||
'indexes': [models.Index(fields=['cms'], name='ehole_finge_cms_72ca2c_idx'), models.Index(fields=['method'], name='ehole_finge_method_17f0db_idx'), models.Index(fields=['location'], name='ehole_finge_locatio_7bb82b_idx'), models.Index(fields=['type'], name='ehole_finge_type_ca2bce_idx'), models.Index(fields=['is_important'], name='ehole_finge_is_impo_d56e64_idx'), models.Index(fields=['-created_at'], name='ehole_finge_created_d862b0_idx')],
|
||||
'constraints': [models.UniqueConstraint(fields=('cms', 'method', 'location'), name='unique_ehole_fingerprint')],
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='FingerPrintHubFingerprint',
|
||||
fields=[
|
||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('fp_id', models.CharField(help_text='指纹ID', max_length=200, unique=True)),
|
||||
('name', models.CharField(help_text='指纹名称', max_length=300)),
|
||||
('author', models.CharField(blank=True, default='', help_text='作者', max_length=200)),
|
||||
('tags', models.CharField(blank=True, default='', help_text='标签', max_length=500)),
|
||||
('severity', models.CharField(blank=True, default='info', help_text='严重程度', max_length=50)),
|
||||
('metadata', models.JSONField(blank=True, default=dict, help_text='元数据')),
|
||||
('http', models.JSONField(default=list, help_text='HTTP 匹配规则')),
|
||||
('source_file', models.CharField(blank=True, default='', help_text='来源文件', max_length=500)),
|
||||
('created_at', models.DateTimeField(auto_now_add=True)),
|
||||
],
|
||||
options={
|
||||
'verbose_name': 'FingerPrintHub 指纹',
|
||||
'verbose_name_plural': 'FingerPrintHub 指纹',
|
||||
'db_table': 'fingerprinthub_fingerprint',
|
||||
'ordering': ['-created_at'],
|
||||
'indexes': [models.Index(fields=['fp_id'], name='fingerprint_fp_id_df467f_idx'), models.Index(fields=['name'], name='fingerprint_name_95b6fb_idx'), models.Index(fields=['author'], name='fingerprint_author_80f54b_idx'), models.Index(fields=['severity'], name='fingerprint_severit_f70422_idx'), models.Index(fields=['-created_at'], name='fingerprint_created_bec16c_idx')],
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='FingersFingerprint',
|
||||
fields=[
|
||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('name', models.CharField(help_text='指纹名称', max_length=300, unique=True)),
|
||||
('link', models.URLField(blank=True, default='', help_text='相关链接', max_length=500)),
|
||||
('rule', models.JSONField(default=list, help_text='匹配规则数组')),
|
||||
('tag', models.JSONField(default=list, help_text='标签数组')),
|
||||
('focus', models.BooleanField(default=False, help_text='是否重点关注')),
|
||||
('default_port', models.JSONField(blank=True, default=list, help_text='默认端口数组')),
|
||||
('created_at', models.DateTimeField(auto_now_add=True)),
|
||||
],
|
||||
options={
|
||||
'verbose_name': 'Fingers 指纹',
|
||||
'verbose_name_plural': 'Fingers 指纹',
|
||||
'db_table': 'fingers_fingerprint',
|
||||
'ordering': ['-created_at'],
|
||||
'indexes': [models.Index(fields=['name'], name='fingers_fin_name_952de0_idx'), models.Index(fields=['link'], name='fingers_fin_link_4c6b7f_idx'), models.Index(fields=['focus'], name='fingers_fin_focus_568c7f_idx'), models.Index(fields=['-created_at'], name='fingers_fin_created_46fc91_idx')],
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='GobyFingerprint',
|
||||
fields=[
|
||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('name', models.CharField(help_text='产品名称', max_length=300, unique=True)),
|
||||
('logic', models.CharField(help_text='逻辑表达式', max_length=500)),
|
||||
('rule', models.JSONField(default=list, help_text='规则数组')),
|
||||
('created_at', models.DateTimeField(auto_now_add=True)),
|
||||
],
|
||||
options={
|
||||
'verbose_name': 'Goby 指纹',
|
||||
'verbose_name_plural': 'Goby 指纹',
|
||||
'db_table': 'goby_fingerprint',
|
||||
'ordering': ['-created_at'],
|
||||
'indexes': [models.Index(fields=['name'], name='goby_finger_name_82084c_idx'), models.Index(fields=['logic'], name='goby_finger_logic_a63226_idx'), models.Index(fields=['-created_at'], name='goby_finger_created_50e000_idx')],
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='ScanEngine',
|
||||
fields=[
|
||||
('id', models.AutoField(primary_key=True, serialize=False)),
|
||||
('name', models.CharField(help_text='引擎名称', max_length=200, unique=True)),
|
||||
('configuration', models.CharField(blank=True, default='', help_text='引擎配置,yaml 格式', max_length=10000)),
|
||||
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
|
||||
('updated_at', models.DateTimeField(auto_now=True, help_text='更新时间')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '扫描引擎',
|
||||
'verbose_name_plural': '扫描引擎',
|
||||
'db_table': 'scan_engine',
|
||||
'ordering': ['-created_at'],
|
||||
'indexes': [models.Index(fields=['-created_at'], name='scan_engine_created_da4870_idx')],
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='WappalyzerFingerprint',
|
||||
fields=[
|
||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('name', models.CharField(help_text='应用名称', max_length=300, unique=True)),
|
||||
('cats', models.JSONField(default=list, help_text='分类 ID 数组')),
|
||||
('cookies', models.JSONField(blank=True, default=dict, help_text='Cookie 检测规则')),
|
||||
('headers', models.JSONField(blank=True, default=dict, help_text='HTTP Header 检测规则')),
|
||||
('script_src', models.JSONField(blank=True, default=list, help_text='脚本 URL 正则数组')),
|
||||
('js', models.JSONField(blank=True, default=list, help_text='JavaScript 变量检测规则')),
|
||||
('implies', models.JSONField(blank=True, default=list, help_text='依赖关系数组')),
|
||||
('meta', models.JSONField(blank=True, default=dict, help_text='HTML meta 标签检测规则')),
|
||||
('html', models.JSONField(blank=True, default=list, help_text='HTML 内容正则数组')),
|
||||
('description', models.TextField(blank=True, default='', help_text='应用描述')),
|
||||
('website', models.URLField(blank=True, default='', help_text='官网链接', max_length=500)),
|
||||
('cpe', models.CharField(blank=True, default='', help_text='CPE 标识符', max_length=300)),
|
||||
('created_at', models.DateTimeField(auto_now_add=True)),
|
||||
],
|
||||
options={
|
||||
'verbose_name': 'Wappalyzer 指纹',
|
||||
'verbose_name_plural': 'Wappalyzer 指纹',
|
||||
'db_table': 'wappalyzer_fingerprint',
|
||||
'ordering': ['-created_at'],
|
||||
'indexes': [models.Index(fields=['name'], name='wappalyzer__name_63c669_idx'), models.Index(fields=['website'], name='wappalyzer__website_88de1c_idx'), models.Index(fields=['cpe'], name='wappalyzer__cpe_30c761_idx'), models.Index(fields=['-created_at'], name='wappalyzer__created_8e6c21_idx')],
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='Wordlist',
|
||||
fields=[
|
||||
('id', models.AutoField(primary_key=True, serialize=False)),
|
||||
('name', models.CharField(help_text='字典名称,唯一', max_length=200, unique=True)),
|
||||
('description', models.CharField(blank=True, default='', help_text='字典描述', max_length=200)),
|
||||
('file_path', models.CharField(help_text='后端保存的字典文件绝对路径', max_length=500)),
|
||||
('file_size', models.BigIntegerField(default=0, help_text='文件大小(字节)')),
|
||||
('line_count', models.IntegerField(default=0, help_text='字典行数')),
|
||||
('file_hash', models.CharField(blank=True, default='', help_text='文件 SHA-256 哈希,用于缓存校验', max_length=64)),
|
||||
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
|
||||
('updated_at', models.DateTimeField(auto_now=True, help_text='更新时间')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '字典文件',
|
||||
'verbose_name_plural': '字典文件',
|
||||
'db_table': 'wordlist',
|
||||
'ordering': ['-created_at'],
|
||||
'indexes': [models.Index(fields=['-created_at'], name='wordlist_created_4afb02_idx')],
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='WorkerNode',
|
||||
fields=[
|
||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('name', models.CharField(help_text='节点名称', max_length=100)),
|
||||
('ip_address', models.GenericIPAddressField(help_text='IP 地址(本地节点为 127.0.0.1)')),
|
||||
('ssh_port', models.IntegerField(default=22, help_text='SSH 端口')),
|
||||
('username', models.CharField(default='root', help_text='SSH 用户名', max_length=50)),
|
||||
('password', models.CharField(blank=True, default='', help_text='SSH 密码', max_length=200)),
|
||||
('is_local', models.BooleanField(default=False, help_text='是否为本地节点(Docker 容器内)')),
|
||||
('status', models.CharField(choices=[('pending', '待部署'), ('deploying', '部署中'), ('online', '在线'), ('offline', '离线'), ('updating', '更新中'), ('outdated', '版本过低')], default='pending', help_text='状态: pending/deploying/online/offline', max_length=20)),
|
||||
('created_at', models.DateTimeField(auto_now_add=True)),
|
||||
('updated_at', models.DateTimeField(auto_now=True)),
|
||||
],
|
||||
options={
|
||||
'verbose_name': 'Worker 节点',
|
||||
'db_table': 'worker_node',
|
||||
'ordering': ['-created_at'],
|
||||
'constraints': [models.UniqueConstraint(condition=models.Q(('is_local', False)), fields=('ip_address',), name='unique_remote_worker_ip'), models.UniqueConstraint(fields=('name',), name='unique_worker_name')],
|
||||
},
|
||||
),
|
||||
]
|
||||
@@ -88,6 +88,8 @@ def _register_scheduled_jobs(scheduler: BackgroundScheduler):
|
||||
replace_existing=True,
|
||||
)
|
||||
logger.info(" - 已注册: 扫描结果清理(每天 03:00)")
|
||||
|
||||
# 注意:搜索物化视图刷新已迁移到 pg_ivm 增量维护,无需定时任务
|
||||
|
||||
|
||||
def _trigger_scheduled_scans():
|
||||
|
||||
@@ -66,6 +66,7 @@ def get_start_agent_script(
|
||||
# 替换变量
|
||||
script = script.replace("{{HEARTBEAT_API_URL}}", heartbeat_api_url or '')
|
||||
script = script.replace("{{WORKER_ID}}", str(worker_id) if worker_id else '')
|
||||
script = script.replace("{{WORKER_API_KEY}}", getattr(settings, 'WORKER_API_KEY', ''))
|
||||
|
||||
# 注入镜像版本配置(确保远程节点使用相同版本)
|
||||
docker_user = getattr(settings, 'DOCKER_USER', 'yyhuni')
|
||||
|
||||
@@ -16,10 +16,9 @@ class GobyFingerprintService(BaseFingerprintService):
|
||||
"""
|
||||
校验单条 Goby 指纹
|
||||
|
||||
校验规则:
|
||||
- name 字段必须存在且非空
|
||||
- logic 字段必须存在
|
||||
- rule 字段必须是数组
|
||||
支持两种格式:
|
||||
1. 标准格式: {"name": "...", "logic": "...", "rule": [...]}
|
||||
2. JSONL 格式: {"product": "...", "rule": "..."}
|
||||
|
||||
Args:
|
||||
item: 单条指纹数据
|
||||
@@ -27,25 +26,43 @@ class GobyFingerprintService(BaseFingerprintService):
|
||||
Returns:
|
||||
bool: 是否有效
|
||||
"""
|
||||
# 标准格式:name + logic + rule(数组)
|
||||
name = item.get('name', '')
|
||||
logic = item.get('logic', '')
|
||||
rule = item.get('rule')
|
||||
return bool(name and str(name).strip()) and bool(logic) and isinstance(rule, list)
|
||||
if name and item.get('logic') is not None and isinstance(item.get('rule'), list):
|
||||
return bool(str(name).strip())
|
||||
|
||||
# JSONL 格式:product + rule(字符串)
|
||||
product = item.get('product', '')
|
||||
rule = item.get('rule', '')
|
||||
return bool(product and str(product).strip() and rule and str(rule).strip())
|
||||
|
||||
def to_model_data(self, item: dict) -> dict:
|
||||
"""
|
||||
转换 Goby JSON 格式为 Model 字段
|
||||
|
||||
支持两种输入格式:
|
||||
1. 标准格式: {"name": "...", "logic": "...", "rule": [...]}
|
||||
2. JSONL 格式: {"product": "...", "rule": "..."}
|
||||
|
||||
Args:
|
||||
item: 原始 Goby JSON 数据
|
||||
|
||||
Returns:
|
||||
dict: Model 字段数据
|
||||
"""
|
||||
# 标准格式
|
||||
if 'name' in item and isinstance(item.get('rule'), list):
|
||||
return {
|
||||
'name': str(item.get('name', '')).strip(),
|
||||
'logic': item.get('logic', ''),
|
||||
'rule': item.get('rule', []),
|
||||
}
|
||||
|
||||
# JSONL 格式:将 rule 字符串转为单元素数组
|
||||
return {
|
||||
'name': str(item.get('name', '')).strip(),
|
||||
'logic': item.get('logic', ''),
|
||||
'rule': item.get('rule', []),
|
||||
'name': str(item.get('product', '')).strip(),
|
||||
'logic': 'or', # JSONL 格式默认 or 逻辑
|
||||
'rule': [item.get('rule', '')] if item.get('rule') else [],
|
||||
}
|
||||
|
||||
def get_export_data(self) -> list:
|
||||
|
||||
@@ -264,10 +264,6 @@ class TaskDistributor:
|
||||
"""
|
||||
import shlex
|
||||
|
||||
# Docker API 版本配置(避免版本不兼容问题)
|
||||
# 默认使用 1.40 以获得最大兼容性(支持 Docker 19.03+)
|
||||
api_version = getattr(settings, 'DOCKER_API_VERSION', '1.40')
|
||||
|
||||
# 根据 Worker 类型确定网络和 Server 地址
|
||||
if worker.is_local:
|
||||
# 本地:加入 Docker 网络,使用内部服务名
|
||||
@@ -288,6 +284,7 @@ class TaskDistributor:
|
||||
env_vars = [
|
||||
f"-e SERVER_URL={shlex.quote(server_url)}",
|
||||
f"-e IS_LOCAL={is_local_str}",
|
||||
f"-e WORKER_API_KEY={shlex.quote(settings.WORKER_API_KEY)}", # Worker API 认证密钥
|
||||
"-e PREFECT_HOME=/tmp/.prefect", # 设置 Prefect 数据目录到可写位置
|
||||
"-e PREFECT_SERVER_EPHEMERAL_ENABLED=true", # 启用 ephemeral server(本地临时服务器)
|
||||
"-e PREFECT_SERVER_EPHEMERAL_STARTUP_TIMEOUT_SECONDS=120", # 增加启动超时时间
|
||||
@@ -315,9 +312,7 @@ class TaskDistributor:
|
||||
# - 本地 Worker:install.sh 已预拉取镜像,直接使用本地版本
|
||||
# - 远程 Worker:deploy 时已预拉取镜像,直接使用本地版本
|
||||
# - 避免每次任务都检查 Docker Hub,提升性能和稳定性
|
||||
# 使用双引号包裹 sh -c 命令,内部 shlex.quote 生成的单引号参数可正确解析
|
||||
# DOCKER_API_VERSION 环境变量确保客户端和服务端 API 版本兼容
|
||||
cmd = f'''DOCKER_API_VERSION={api_version} docker run --rm -d --pull=missing {network_arg} \\
|
||||
cmd = f'''docker run --rm -d --pull=missing {network_arg} \\
|
||||
{' '.join(env_vars)} \\
|
||||
{' '.join(volumes)} \\
|
||||
{self.docker_image} \\
|
||||
|
||||
@@ -139,7 +139,7 @@ class BaseFingerprintViewSet(viewsets.ModelViewSet):
|
||||
POST /api/engine/fingerprints/{type}/import_file/
|
||||
|
||||
请求格式:multipart/form-data
|
||||
- file: JSON 文件
|
||||
- file: JSON 文件(支持标准 JSON 和 JSONL 格式)
|
||||
|
||||
返回:同 batch_create
|
||||
"""
|
||||
@@ -148,9 +148,12 @@ class BaseFingerprintViewSet(viewsets.ModelViewSet):
|
||||
raise ValidationError('缺少文件')
|
||||
|
||||
try:
|
||||
json_data = json.load(file)
|
||||
content = file.read().decode('utf-8')
|
||||
json_data = self._parse_json_content(content)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValidationError(f'无效的 JSON 格式: {e}')
|
||||
except UnicodeDecodeError as e:
|
||||
raise ValidationError(f'文件编码错误: {e}')
|
||||
|
||||
fingerprints = self.parse_import_data(json_data)
|
||||
if not fingerprints:
|
||||
@@ -159,6 +162,41 @@ class BaseFingerprintViewSet(viewsets.ModelViewSet):
|
||||
result = self.get_service().batch_create_fingerprints(fingerprints)
|
||||
return success_response(data=result, status_code=status.HTTP_201_CREATED)
|
||||
|
||||
def _parse_json_content(self, content: str):
|
||||
"""
|
||||
解析 JSON 内容,支持标准 JSON 和 JSONL 格式
|
||||
|
||||
Args:
|
||||
content: 文件内容字符串
|
||||
|
||||
Returns:
|
||||
解析后的数据(list 或 dict)
|
||||
"""
|
||||
content = content.strip()
|
||||
|
||||
# 尝试标准 JSON 解析
|
||||
try:
|
||||
return json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 尝试 JSONL 格式(每行一个 JSON 对象)
|
||||
lines = content.split('\n')
|
||||
result = []
|
||||
for i, line in enumerate(lines):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
result.append(json.loads(line))
|
||||
except json.JSONDecodeError as e:
|
||||
raise json.JSONDecodeError(f'第 {i + 1} 行解析失败: {e.msg}', e.doc, e.pos)
|
||||
|
||||
if not result:
|
||||
raise json.JSONDecodeError('文件为空或格式无效', content, 0)
|
||||
|
||||
return result
|
||||
|
||||
@action(detail=False, methods=['post'], url_path='bulk-delete')
|
||||
def bulk_delete(self, request):
|
||||
"""
|
||||
|
||||
@@ -340,13 +340,12 @@ class WorkerNodeViewSet(viewsets.ModelViewSet):
|
||||
返回:
|
||||
{
|
||||
"db": {"host": "...", "port": "...", ...},
|
||||
"redisUrl": "...",
|
||||
"paths": {"results": "...", "logs": "..."}
|
||||
}
|
||||
|
||||
配置逻辑:
|
||||
- 本地 Worker (is_local=true): db_host=postgres, redis=redis:6379
|
||||
- 远程 Worker (is_local=false): db_host=PUBLIC_HOST, redis=PUBLIC_HOST:6379
|
||||
- 本地 Worker (is_local=true): db_host=postgres
|
||||
- 远程 Worker (is_local=false): db_host=PUBLIC_HOST
|
||||
"""
|
||||
from django.conf import settings
|
||||
import logging
|
||||
@@ -371,20 +370,17 @@ class WorkerNodeViewSet(viewsets.ModelViewSet):
|
||||
if is_local_worker:
|
||||
# 本地 Worker:直接用 Docker 内部服务名
|
||||
worker_db_host = 'postgres'
|
||||
worker_redis_url = 'redis://redis:6379/0'
|
||||
else:
|
||||
# 远程 Worker:通过公网 IP 访问
|
||||
public_host = settings.PUBLIC_HOST
|
||||
if public_host in ('server', 'localhost', '127.0.0.1'):
|
||||
logger.warning("远程 Worker 请求配置,但 PUBLIC_HOST=%s 不是有效的公网地址", public_host)
|
||||
worker_db_host = public_host
|
||||
worker_redis_url = f'redis://{public_host}:6379/0'
|
||||
else:
|
||||
# 远程数据库场景:所有 Worker 都用 DB_HOST
|
||||
worker_db_host = db_host
|
||||
worker_redis_url = getattr(settings, 'WORKER_REDIS_URL', 'redis://redis:6379/0')
|
||||
|
||||
logger.info("返回 Worker 配置 - db_host: %s, redis_url: %s", worker_db_host, worker_redis_url)
|
||||
logger.info("返回 Worker 配置 - db_host: %s", worker_db_host)
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
@@ -395,7 +391,6 @@ class WorkerNodeViewSet(viewsets.ModelViewSet):
|
||||
'user': settings.DATABASES['default']['USER'],
|
||||
'password': settings.DATABASES['default']['PASSWORD'],
|
||||
},
|
||||
'redisUrl': worker_redis_url,
|
||||
'paths': {
|
||||
'results': getattr(settings, 'CONTAINER_RESULTS_MOUNT', '/opt/xingrin/results'),
|
||||
'logs': getattr(settings, 'CONTAINER_LOGS_MOUNT', '/opt/xingrin/logs'),
|
||||
|
||||
@@ -13,12 +13,14 @@ SCAN_TOOLS_BASE_PATH = getattr(settings, 'SCAN_TOOLS_BASE_PATH', '/usr/local/bin
|
||||
|
||||
SUBDOMAIN_DISCOVERY_COMMANDS = {
|
||||
'subfinder': {
|
||||
# 默认使用所有数据源(更全面,略慢),并始终开启递归
|
||||
# -all 使用所有数据源
|
||||
# -recursive 对支持递归的源启用递归枚举(默认开启)
|
||||
'base': "subfinder -d {domain} -all -recursive -o '{output_file}' -silent",
|
||||
# 使用所有数据源(包括付费源,只要配置了 API key)
|
||||
# -all 使用所有数据源(slow 但全面)
|
||||
# -v 显示详细输出,包括使用的数据源(调试用)
|
||||
# 注意:不要加 -recursive,它会排除不支持递归的源(如 fofa)
|
||||
'base': "subfinder -d {domain} -all -o '{output_file}' -v",
|
||||
'optional': {
|
||||
'threads': '-t {threads}', # 控制并发 goroutine 数
|
||||
'provider_config': "-pc '{provider_config}'", # Provider 配置文件路径
|
||||
}
|
||||
},
|
||||
|
||||
@@ -97,9 +99,11 @@ SITE_SCAN_COMMANDS = {
|
||||
'base': (
|
||||
"'{scan_tools_base}/httpx' -l '{url_file}' "
|
||||
'-status-code -content-type -content-length '
|
||||
'-location -title -server -body-preview '
|
||||
'-location -title -server '
|
||||
'-tech-detect -cdn -vhost '
|
||||
'-random-agent -no-color -json'
|
||||
'-include-response '
|
||||
'-rstr 2000 '
|
||||
'-random-agent -no-color -json -silent'
|
||||
),
|
||||
'optional': {
|
||||
'threads': '-threads {threads}',
|
||||
@@ -169,9 +173,11 @@ URL_FETCH_COMMANDS = {
|
||||
'base': (
|
||||
"'{scan_tools_base}/httpx' -l '{url_file}' "
|
||||
'-status-code -content-type -content-length '
|
||||
'-location -title -server -body-preview '
|
||||
'-location -title -server '
|
||||
'-tech-detect -cdn -vhost '
|
||||
'-random-agent -no-color -json'
|
||||
'-include-response '
|
||||
'-rstr 2000 '
|
||||
'-random-agent -no-color -json -silent'
|
||||
),
|
||||
'optional': {
|
||||
'threads': '-threads {threads}',
|
||||
|
||||
@@ -4,14 +4,12 @@
|
||||
# 必需参数:enabled(是否启用)
|
||||
# 可选参数:timeout(超时秒数,默认 auto 自动计算)
|
||||
|
||||
# ==================== 子域名发现 ====================
|
||||
#
|
||||
# Stage 1: 被动收集(并行) - 必选,至少启用一个工具
|
||||
# Stage 2: 字典爆破(可选) - 使用字典暴力枚举子域名
|
||||
# Stage 3: 变异生成 + 验证(可选) - 基于已发现域名生成变异,流式验证存活
|
||||
# Stage 4: DNS 存活验证(可选) - 验证所有候选域名是否能解析
|
||||
#
|
||||
subdomain_discovery:
|
||||
# ==================== 子域名发现 ====================
|
||||
# Stage 1: 被动收集(并行) - 必选,至少启用一个工具
|
||||
# Stage 2: 字典爆破(可选) - 使用字典暴力枚举子域名
|
||||
# Stage 3: 变异生成 + 验证(可选) - 基于已发现域名生成变异,流式验证存活
|
||||
# Stage 4: DNS 存活验证(可选) - 验证所有候选域名是否能解析
|
||||
# === Stage 1: 被动收集工具(并行执行)===
|
||||
passive_tools:
|
||||
subfinder:
|
||||
@@ -55,8 +53,8 @@ subdomain_discovery:
|
||||
subdomain_resolve:
|
||||
timeout: auto # 自动根据候选子域数量计算
|
||||
|
||||
# ==================== 端口扫描 ====================
|
||||
port_scan:
|
||||
# ==================== 端口扫描 ====================
|
||||
tools:
|
||||
naabu_active:
|
||||
enabled: true
|
||||
@@ -70,8 +68,8 @@ port_scan:
|
||||
enabled: true
|
||||
# timeout: auto # 被动扫描通常较快
|
||||
|
||||
# ==================== 站点扫描 ====================
|
||||
site_scan:
|
||||
# ==================== 站点扫描 ====================
|
||||
tools:
|
||||
httpx:
|
||||
enabled: true
|
||||
@@ -81,16 +79,16 @@ site_scan:
|
||||
# request-timeout: 10 # 单个请求超时秒数(默认 10)
|
||||
# retries: 2 # 请求失败重试次数
|
||||
|
||||
# ==================== 指纹识别 ====================
|
||||
# 在 site_scan 后串行执行,识别 WebSite 的技术栈
|
||||
fingerprint_detect:
|
||||
# ==================== 指纹识别 ====================
|
||||
# 在 站点扫描 后串行执行,识别 WebSite 的技术栈
|
||||
tools:
|
||||
xingfinger:
|
||||
enabled: true
|
||||
fingerprint-libs: [ehole, goby, wappalyzer] # 启用的指纹库:ehole, goby, wappalyzer, fingers, fingerprinthub, arl
|
||||
fingerprint-libs: [ehole, goby, wappalyzer, fingers, fingerprinthub, arl] # 默认启动全部指纹库
|
||||
|
||||
# ==================== 目录扫描 ====================
|
||||
directory_scan:
|
||||
# ==================== 目录扫描 ====================
|
||||
tools:
|
||||
ffuf:
|
||||
enabled: true
|
||||
@@ -103,8 +101,8 @@ directory_scan:
|
||||
match-codes: 200,201,301,302,401,403 # 匹配的 HTTP 状态码
|
||||
# rate: 0 # 每秒请求数(默认 0 不限制)
|
||||
|
||||
# ==================== URL 获取 ====================
|
||||
url_fetch:
|
||||
# ==================== URL 获取 ====================
|
||||
tools:
|
||||
waymore:
|
||||
enabled: true
|
||||
@@ -142,8 +140,8 @@ url_fetch:
|
||||
# request-timeout: 10 # 单个请求超时秒数(默认 10)
|
||||
# retries: 2 # 请求失败重试次数
|
||||
|
||||
# ==================== 漏洞扫描 ====================
|
||||
vuln_scan:
|
||||
# ==================== 漏洞扫描 ====================
|
||||
tools:
|
||||
dalfox_xss:
|
||||
enabled: true
|
||||
|
||||
@@ -33,7 +33,7 @@ from apps.scan.handlers.scan_flow_handlers import (
|
||||
on_scan_flow_completed,
|
||||
on_scan_flow_failed,
|
||||
)
|
||||
from apps.scan.utils import config_parser, build_scan_command, ensure_wordlist_local
|
||||
from apps.scan.utils import config_parser, build_scan_command, ensure_wordlist_local, user_log
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -413,6 +413,7 @@ def _run_scans_concurrently(
|
||||
logger.info("="*60)
|
||||
logger.info("使用工具: %s (并发模式, max_workers=%d)", tool_name, max_workers)
|
||||
logger.info("="*60)
|
||||
user_log(scan_id, "directory_scan", f"Running {tool_name}")
|
||||
|
||||
# 如果配置了 wordlist_name,则先确保本地存在对应的字典文件(含 hash 校验)
|
||||
wordlist_name = tool_config.get('wordlist_name')
|
||||
@@ -467,6 +468,11 @@ def _run_scans_concurrently(
|
||||
total_tasks = len(scan_params_list)
|
||||
logger.info("开始分批执行 %d 个扫描任务(每批 %d 个)...", total_tasks, max_workers)
|
||||
|
||||
# 进度里程碑跟踪
|
||||
last_progress_percent = 0
|
||||
tool_directories = 0
|
||||
tool_processed = 0
|
||||
|
||||
batch_num = 0
|
||||
for batch_start in range(0, total_tasks, max_workers):
|
||||
batch_end = min(batch_start + max_workers, total_tasks)
|
||||
@@ -498,7 +504,9 @@ def _run_scans_concurrently(
|
||||
result = future.result() # 阻塞等待单个任务完成
|
||||
directories_found = result.get('created_directories', 0)
|
||||
total_directories += directories_found
|
||||
tool_directories += directories_found
|
||||
processed_sites_count += 1
|
||||
tool_processed += 1
|
||||
|
||||
logger.info(
|
||||
"✓ [%d/%d] 站点扫描完成: %s - 发现 %d 个目录",
|
||||
@@ -517,6 +525,19 @@ def _run_scans_concurrently(
|
||||
"✗ [%d/%d] 站点扫描失败: %s - 错误: %s",
|
||||
idx, len(sites), site_url, exc
|
||||
)
|
||||
|
||||
# 进度里程碑:每 20% 输出一次
|
||||
current_progress = int((batch_end / total_tasks) * 100)
|
||||
if current_progress >= last_progress_percent + 20:
|
||||
user_log(scan_id, "directory_scan", f"Progress: {batch_end}/{total_tasks} sites scanned")
|
||||
last_progress_percent = (current_progress // 20) * 20
|
||||
|
||||
# 工具完成日志(开发者日志 + 用户日志)
|
||||
logger.info(
|
||||
"✓ 工具 %s 执行完成 - 已处理站点: %d/%d, 发现目录: %d",
|
||||
tool_name, tool_processed, total_tasks, tool_directories
|
||||
)
|
||||
user_log(scan_id, "directory_scan", f"{tool_name} completed: found {tool_directories} directories")
|
||||
|
||||
# 输出汇总信息
|
||||
if failed_sites:
|
||||
@@ -605,6 +626,8 @@ def directory_scan_flow(
|
||||
"="*60
|
||||
)
|
||||
|
||||
user_log(scan_id, "directory_scan", "Starting directory scan")
|
||||
|
||||
# 参数验证
|
||||
if scan_id is None:
|
||||
raise ValueError("scan_id 不能为空")
|
||||
@@ -625,7 +648,8 @@ def directory_scan_flow(
|
||||
sites_file, site_count = _export_site_urls(target_id, target_name, directory_scan_dir)
|
||||
|
||||
if site_count == 0:
|
||||
logger.warning("目标下没有站点,跳过目录扫描")
|
||||
logger.warning("跳过目录扫描:没有站点可扫描 - Scan ID: %s", scan_id)
|
||||
user_log(scan_id, "directory_scan", "Skipped: no sites to scan", "warning")
|
||||
return {
|
||||
'success': True,
|
||||
'scan_id': scan_id,
|
||||
@@ -664,7 +688,9 @@ def directory_scan_flow(
|
||||
logger.warning("所有站点扫描均失败 - 总站点数: %d, 失败数: %d", site_count, len(failed_sites))
|
||||
# 不抛出异常,让扫描继续
|
||||
|
||||
logger.info("="*60 + "\n✓ 目录扫描完成\n" + "="*60)
|
||||
# 记录 Flow 完成
|
||||
logger.info("✓ 目录扫描完成 - 发现目录: %d", total_directories)
|
||||
user_log(scan_id, "directory_scan", f"directory_scan completed: found {total_directories} directories")
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
|
||||
@@ -29,7 +29,7 @@ from apps.scan.tasks.fingerprint_detect import (
|
||||
export_urls_for_fingerprint_task,
|
||||
run_xingfinger_and_stream_update_tech_task,
|
||||
)
|
||||
from apps.scan.utils import build_scan_command
|
||||
from apps.scan.utils import build_scan_command, user_log
|
||||
from apps.scan.utils.fingerprint_helpers import get_fingerprint_paths
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -37,28 +37,24 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
def calculate_fingerprint_detect_timeout(
|
||||
url_count: int,
|
||||
base_per_url: float = 3.0,
|
||||
min_timeout: int = 60
|
||||
base_per_url: float = 10.0,
|
||||
min_timeout: int = 300
|
||||
) -> int:
|
||||
"""
|
||||
根据 URL 数量计算超时时间
|
||||
|
||||
公式:超时时间 = URL 数量 × 每 URL 基础时间
|
||||
最小值:60秒
|
||||
最小值:300秒
|
||||
无上限
|
||||
|
||||
Args:
|
||||
url_count: URL 数量
|
||||
base_per_url: 每 URL 基础时间(秒),默认 3秒
|
||||
min_timeout: 最小超时时间(秒),默认 60秒
|
||||
base_per_url: 每 URL 基础时间(秒),默认 10秒
|
||||
min_timeout: 最小超时时间(秒),默认 300秒
|
||||
|
||||
Returns:
|
||||
int: 计算出的超时时间(秒)
|
||||
|
||||
示例:
|
||||
100 URL × 3秒 = 300秒
|
||||
1000 URL × 3秒 = 3000秒(50分钟)
|
||||
10000 URL × 3秒 = 30000秒(8.3小时)
|
||||
"""
|
||||
timeout = int(url_count * base_per_url)
|
||||
return max(min_timeout, timeout)
|
||||
@@ -172,6 +168,7 @@ def _run_fingerprint_detect(
|
||||
"开始执行 %s 指纹识别 - URL数: %d, 超时: %ds, 指纹库: %s",
|
||||
tool_name, url_count, timeout, list(fingerprint_paths.keys())
|
||||
)
|
||||
user_log(scan_id, "fingerprint_detect", f"Running {tool_name}: {command}")
|
||||
|
||||
# 6. 执行扫描任务
|
||||
try:
|
||||
@@ -194,17 +191,21 @@ def _run_fingerprint_detect(
|
||||
'fingerprint_libs': list(fingerprint_paths.keys())
|
||||
}
|
||||
|
||||
tool_updated = result.get('updated_count', 0)
|
||||
logger.info(
|
||||
"✓ 工具 %s 执行完成 - 处理记录: %d, 更新: %d, 未找到: %d",
|
||||
tool_name,
|
||||
result.get('processed_records', 0),
|
||||
result.get('updated_count', 0),
|
||||
tool_updated,
|
||||
result.get('not_found_count', 0)
|
||||
)
|
||||
user_log(scan_id, "fingerprint_detect", f"{tool_name} completed: identified {tool_updated} fingerprints")
|
||||
|
||||
except Exception as exc:
|
||||
failed_tools.append({'tool': tool_name, 'reason': str(exc)})
|
||||
reason = str(exc)
|
||||
failed_tools.append({'tool': tool_name, 'reason': reason})
|
||||
logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True)
|
||||
user_log(scan_id, "fingerprint_detect", f"{tool_name} failed: {reason}", "error")
|
||||
|
||||
if failed_tools:
|
||||
logger.warning(
|
||||
@@ -260,7 +261,8 @@ def fingerprint_detect_flow(
|
||||
'url_count': int,
|
||||
'processed_records': int,
|
||||
'updated_count': int,
|
||||
'not_found_count': int,
|
||||
'created_count': int,
|
||||
'snapshot_count': int,
|
||||
'executed_tasks': list,
|
||||
'tool_stats': dict
|
||||
}
|
||||
@@ -275,6 +277,8 @@ def fingerprint_detect_flow(
|
||||
"="*60
|
||||
)
|
||||
|
||||
user_log(scan_id, "fingerprint_detect", "Starting fingerprint detection")
|
||||
|
||||
# 参数验证
|
||||
if scan_id is None:
|
||||
raise ValueError("scan_id 不能为空")
|
||||
@@ -296,7 +300,8 @@ def fingerprint_detect_flow(
|
||||
urls_file, url_count = _export_urls(target_id, fingerprint_dir, source)
|
||||
|
||||
if url_count == 0:
|
||||
logger.warning("目标下没有可用的 URL,跳过指纹识别")
|
||||
logger.warning("跳过指纹识别:没有 URL 可扫描 - Scan ID: %s", scan_id)
|
||||
user_log(scan_id, "fingerprint_detect", "Skipped: no URLs to scan", "warning")
|
||||
return {
|
||||
'success': True,
|
||||
'scan_id': scan_id,
|
||||
@@ -307,6 +312,7 @@ def fingerprint_detect_flow(
|
||||
'processed_records': 0,
|
||||
'updated_count': 0,
|
||||
'created_count': 0,
|
||||
'snapshot_count': 0,
|
||||
'executed_tasks': ['export_urls_for_fingerprint'],
|
||||
'tool_stats': {
|
||||
'total': 0,
|
||||
@@ -334,8 +340,6 @@ def fingerprint_detect_flow(
|
||||
source=source
|
||||
)
|
||||
|
||||
logger.info("="*60 + "\n✓ 指纹识别完成\n" + "="*60)
|
||||
|
||||
# 动态生成已执行的任务列表
|
||||
executed_tasks = ['export_urls_for_fingerprint']
|
||||
executed_tasks.extend([f'run_xingfinger ({tool})' for tool in tool_stats.keys()])
|
||||
@@ -344,6 +348,11 @@ def fingerprint_detect_flow(
|
||||
total_processed = sum(stats['result'].get('processed_records', 0) for stats in tool_stats.values())
|
||||
total_updated = sum(stats['result'].get('updated_count', 0) for stats in tool_stats.values())
|
||||
total_created = sum(stats['result'].get('created_count', 0) for stats in tool_stats.values())
|
||||
total_snapshots = sum(stats['result'].get('snapshot_count', 0) for stats in tool_stats.values())
|
||||
|
||||
# 记录 Flow 完成
|
||||
logger.info("✓ 指纹识别完成 - 识别指纹: %d", total_updated)
|
||||
user_log(scan_id, "fingerprint_detect", f"fingerprint_detect completed: identified {total_updated} fingerprints")
|
||||
|
||||
successful_tools = [name for name in enabled_tools.keys()
|
||||
if name not in [f['tool'] for f in failed_tools]]
|
||||
@@ -358,6 +367,7 @@ def fingerprint_detect_flow(
|
||||
'processed_records': total_processed,
|
||||
'updated_count': total_updated,
|
||||
'created_count': total_created,
|
||||
'snapshot_count': total_snapshots,
|
||||
'executed_tasks': executed_tasks,
|
||||
'tool_stats': {
|
||||
'total': len(enabled_tools),
|
||||
|
||||
@@ -114,8 +114,11 @@ def initiate_scan_flow(
|
||||
|
||||
# ==================== Task 2: 获取引擎配置 ====================
|
||||
from apps.scan.models import Scan
|
||||
scan = Scan.objects.select_related('engine').get(id=scan_id)
|
||||
engine_config = scan.engine.configuration
|
||||
scan = Scan.objects.get(id=scan_id)
|
||||
engine_config = scan.yaml_configuration
|
||||
|
||||
# 使用 engine_names 进行显示
|
||||
display_engine_name = ', '.join(scan.engine_names) if scan.engine_names else engine_name
|
||||
|
||||
# ==================== Task 3: 解析配置,生成执行计划 ====================
|
||||
orchestrator = FlowOrchestrator(engine_config)
|
||||
|
||||
@@ -20,7 +20,7 @@ from pathlib import Path
|
||||
from typing import Callable
|
||||
from prefect import flow
|
||||
from apps.scan.tasks.port_scan import (
|
||||
export_scan_targets_task,
|
||||
export_hosts_task,
|
||||
run_and_stream_save_ports_task
|
||||
)
|
||||
from apps.scan.handlers.scan_flow_handlers import (
|
||||
@@ -28,7 +28,7 @@ from apps.scan.handlers.scan_flow_handlers import (
|
||||
on_scan_flow_completed,
|
||||
on_scan_flow_failed,
|
||||
)
|
||||
from apps.scan.utils import config_parser, build_scan_command
|
||||
from apps.scan.utils import config_parser, build_scan_command, user_log
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -157,9 +157,9 @@ def _parse_port_count(tool_config: dict) -> int:
|
||||
|
||||
|
||||
|
||||
def _export_scan_targets(target_id: int, port_scan_dir: Path) -> tuple[str, int, str]:
|
||||
def _export_hosts(target_id: int, port_scan_dir: Path) -> tuple[str, int, str]:
|
||||
"""
|
||||
导出扫描目标到文件
|
||||
导出主机列表到文件
|
||||
|
||||
根据 Target 类型自动决定导出内容:
|
||||
- DOMAIN: 从 Subdomain 表导出子域名
|
||||
@@ -171,31 +171,31 @@ def _export_scan_targets(target_id: int, port_scan_dir: Path) -> tuple[str, int,
|
||||
port_scan_dir: 端口扫描目录
|
||||
|
||||
Returns:
|
||||
tuple: (targets_file, target_count, target_type)
|
||||
tuple: (hosts_file, host_count, target_type)
|
||||
"""
|
||||
logger.info("Step 1: 导出扫描目标列表")
|
||||
logger.info("Step 1: 导出主机列表")
|
||||
|
||||
targets_file = str(port_scan_dir / 'targets.txt')
|
||||
export_result = export_scan_targets_task(
|
||||
hosts_file = str(port_scan_dir / 'hosts.txt')
|
||||
export_result = export_hosts_task(
|
||||
target_id=target_id,
|
||||
output_file=targets_file,
|
||||
output_file=hosts_file,
|
||||
batch_size=1000 # 每次读取 1000 条,优化内存占用
|
||||
)
|
||||
|
||||
target_count = export_result['total_count']
|
||||
host_count = export_result['total_count']
|
||||
target_type = export_result.get('target_type', 'unknown')
|
||||
|
||||
logger.info(
|
||||
"✓ 扫描目标导出完成 - 类型: %s, 文件: %s, 数量: %d",
|
||||
"✓ 主机列表导出完成 - 类型: %s, 文件: %s, 数量: %d",
|
||||
target_type,
|
||||
export_result['output_file'],
|
||||
target_count
|
||||
host_count
|
||||
)
|
||||
|
||||
if target_count == 0:
|
||||
logger.warning("目标下没有可扫描的地址,无法执行端口扫描")
|
||||
if host_count == 0:
|
||||
logger.warning("目标下没有可扫描的主机,无法执行端口扫描")
|
||||
|
||||
return export_result['output_file'], target_count, target_type
|
||||
return export_result['output_file'], host_count, target_type
|
||||
|
||||
|
||||
def _run_scans_sequentially(
|
||||
@@ -265,6 +265,7 @@ def _run_scans_sequentially(
|
||||
|
||||
# 3. 执行扫描任务
|
||||
logger.info("开始执行 %s 扫描(超时: %d秒)...", tool_name, config_timeout)
|
||||
user_log(scan_id, "port_scan", f"Running {tool_name}: {command}")
|
||||
|
||||
try:
|
||||
# 直接调用 task(串行执行)
|
||||
@@ -286,26 +287,31 @@ def _run_scans_sequentially(
|
||||
'result': result,
|
||||
'timeout': config_timeout
|
||||
}
|
||||
processed_records += result.get('processed_records', 0)
|
||||
tool_records = result.get('processed_records', 0)
|
||||
processed_records += tool_records
|
||||
logger.info(
|
||||
"✓ 工具 %s 流式处理完成 - 记录数: %d",
|
||||
tool_name, result.get('processed_records', 0)
|
||||
tool_name, tool_records
|
||||
)
|
||||
user_log(scan_id, "port_scan", f"{tool_name} completed: found {tool_records} ports")
|
||||
|
||||
except subprocess.TimeoutExpired as exc:
|
||||
# 超时异常单独处理
|
||||
# 注意:流式处理任务超时时,已解析的数据已保存到数据库
|
||||
reason = f"执行超时(配置: {config_timeout}秒)"
|
||||
reason = f"timeout after {config_timeout}s"
|
||||
failed_tools.append({'tool': tool_name, 'reason': reason})
|
||||
logger.warning(
|
||||
"⚠️ 工具 %s 执行超时 - 超时配置: %d秒\n"
|
||||
"注意:超时前已解析的端口数据已保存到数据库,但扫描未完全完成。",
|
||||
tool_name, config_timeout
|
||||
)
|
||||
user_log(scan_id, "port_scan", f"{tool_name} failed: {reason}", "error")
|
||||
except Exception as exc:
|
||||
# 其他异常
|
||||
failed_tools.append({'tool': tool_name, 'reason': str(exc)})
|
||||
reason = str(exc)
|
||||
failed_tools.append({'tool': tool_name, 'reason': reason})
|
||||
logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True)
|
||||
user_log(scan_id, "port_scan", f"{tool_name} failed: {reason}", "error")
|
||||
|
||||
if failed_tools:
|
||||
logger.warning(
|
||||
@@ -376,8 +382,8 @@ def port_scan_flow(
|
||||
'scan_id': int,
|
||||
'target': str,
|
||||
'scan_workspace_dir': str,
|
||||
'domains_file': str,
|
||||
'domain_count': int,
|
||||
'hosts_file': str,
|
||||
'host_count': int,
|
||||
'processed_records': int,
|
||||
'executed_tasks': list,
|
||||
'tool_stats': {
|
||||
@@ -420,25 +426,28 @@ def port_scan_flow(
|
||||
"="*60
|
||||
)
|
||||
|
||||
user_log(scan_id, "port_scan", "Starting port scan")
|
||||
|
||||
# Step 0: 创建工作目录
|
||||
from apps.scan.utils import setup_scan_directory
|
||||
port_scan_dir = setup_scan_directory(scan_workspace_dir, 'port_scan')
|
||||
|
||||
# Step 1: 导出扫描目标列表到文件(根据 Target 类型自动决定内容)
|
||||
targets_file, target_count, target_type = _export_scan_targets(target_id, port_scan_dir)
|
||||
# Step 1: 导出主机列表到文件(根据 Target 类型自动决定内容)
|
||||
hosts_file, host_count, target_type = _export_hosts(target_id, port_scan_dir)
|
||||
|
||||
if target_count == 0:
|
||||
logger.warning("目标下没有可扫描的地址,跳过端口扫描")
|
||||
if host_count == 0:
|
||||
logger.warning("跳过端口扫描:没有主机可扫描 - Scan ID: %s", scan_id)
|
||||
user_log(scan_id, "port_scan", "Skipped: no hosts to scan", "warning")
|
||||
return {
|
||||
'success': True,
|
||||
'scan_id': scan_id,
|
||||
'target': target_name,
|
||||
'scan_workspace_dir': scan_workspace_dir,
|
||||
'targets_file': targets_file,
|
||||
'target_count': 0,
|
||||
'hosts_file': hosts_file,
|
||||
'host_count': 0,
|
||||
'target_type': target_type,
|
||||
'processed_records': 0,
|
||||
'executed_tasks': ['export_scan_targets'],
|
||||
'executed_tasks': ['export_hosts'],
|
||||
'tool_stats': {
|
||||
'total': 0,
|
||||
'successful': 0,
|
||||
@@ -460,17 +469,19 @@ def port_scan_flow(
|
||||
logger.info("Step 3: 串行执行扫描工具")
|
||||
tool_stats, processed_records, successful_tool_names, failed_tools = _run_scans_sequentially(
|
||||
enabled_tools=enabled_tools,
|
||||
domains_file=targets_file, # 现在是 targets_file,兼容原参数名
|
||||
domains_file=hosts_file,
|
||||
port_scan_dir=port_scan_dir,
|
||||
scan_id=scan_id,
|
||||
target_id=target_id,
|
||||
target_name=target_name
|
||||
)
|
||||
|
||||
logger.info("="*60 + "\n✓ 端口扫描完成\n" + "="*60)
|
||||
# 记录 Flow 完成
|
||||
logger.info("✓ 端口扫描完成 - 发现端口: %d", processed_records)
|
||||
user_log(scan_id, "port_scan", f"port_scan completed: found {processed_records} ports")
|
||||
|
||||
# 动态生成已执行的任务列表
|
||||
executed_tasks = ['export_scan_targets', 'parse_config']
|
||||
executed_tasks = ['export_hosts', 'parse_config']
|
||||
executed_tasks.extend([f'run_and_stream_save_ports ({tool})' for tool in tool_stats.keys()])
|
||||
|
||||
return {
|
||||
@@ -478,8 +489,8 @@ def port_scan_flow(
|
||||
'scan_id': scan_id,
|
||||
'target': target_name,
|
||||
'scan_workspace_dir': scan_workspace_dir,
|
||||
'targets_file': targets_file,
|
||||
'target_count': target_count,
|
||||
'hosts_file': hosts_file,
|
||||
'host_count': host_count,
|
||||
'target_type': target_type,
|
||||
'processed_records': processed_records,
|
||||
'executed_tasks': executed_tasks,
|
||||
@@ -488,8 +499,8 @@ def port_scan_flow(
|
||||
'successful': len(successful_tool_names),
|
||||
'failed': len(failed_tools),
|
||||
'successful_tools': successful_tool_names,
|
||||
'failed_tools': failed_tools, # [{'tool': 'naabu_active', 'reason': '超时'}]
|
||||
'details': tool_stats # 详细结果(保留向后兼容)
|
||||
'failed_tools': failed_tools,
|
||||
'details': tool_stats
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from apps.common.prefect_django_setup import setup_django_for_prefect
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
from prefect import flow
|
||||
@@ -26,7 +27,7 @@ from apps.scan.handlers.scan_flow_handlers import (
|
||||
on_scan_flow_completed,
|
||||
on_scan_flow_failed,
|
||||
)
|
||||
from apps.scan.utils import config_parser, build_scan_command
|
||||
from apps.scan.utils import config_parser, build_scan_command, user_log
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -198,20 +199,20 @@ def _run_scans_sequentially(
|
||||
"开始执行 %s 站点扫描 - URL数: %d, 最终超时: %ds",
|
||||
tool_name, total_urls, timeout
|
||||
)
|
||||
user_log(scan_id, "site_scan", f"Running {tool_name}: {command}")
|
||||
|
||||
# 3. 执行扫描任务
|
||||
try:
|
||||
# 流式执行扫描并实时保存结果
|
||||
result = run_and_stream_save_websites_task(
|
||||
cmd=command,
|
||||
tool_name=tool_name, # 新增:工具名称
|
||||
tool_name=tool_name,
|
||||
scan_id=scan_id,
|
||||
target_id=target_id,
|
||||
cwd=str(site_scan_dir),
|
||||
shell=True,
|
||||
batch_size=1000,
|
||||
timeout=timeout,
|
||||
log_file=str(log_file) # 新增:日志文件路径
|
||||
log_file=str(log_file)
|
||||
)
|
||||
|
||||
tool_stats[tool_name] = {
|
||||
@@ -219,29 +220,35 @@ def _run_scans_sequentially(
|
||||
'result': result,
|
||||
'timeout': timeout
|
||||
}
|
||||
processed_records += result.get('processed_records', 0)
|
||||
tool_records = result.get('processed_records', 0)
|
||||
tool_created = result.get('created_websites', 0)
|
||||
processed_records += tool_records
|
||||
|
||||
logger.info(
|
||||
"✓ 工具 %s 流式处理完成 - 处理记录: %d, 创建站点: %d, 跳过: %d",
|
||||
tool_name,
|
||||
result.get('processed_records', 0),
|
||||
result.get('created_websites', 0),
|
||||
tool_records,
|
||||
tool_created,
|
||||
result.get('skipped_no_subdomain', 0) + result.get('skipped_failed', 0)
|
||||
)
|
||||
user_log(scan_id, "site_scan", f"{tool_name} completed: found {tool_created} websites")
|
||||
|
||||
except subprocess.TimeoutExpired as exc:
|
||||
# 超时异常单独处理
|
||||
reason = f"执行超时(配置: {timeout}秒)"
|
||||
reason = f"timeout after {timeout}s"
|
||||
failed_tools.append({'tool': tool_name, 'reason': reason})
|
||||
logger.warning(
|
||||
"⚠️ 工具 %s 执行超时 - 超时配置: %d秒\n"
|
||||
"注意:超时前已解析的站点数据已保存到数据库,但扫描未完全完成。",
|
||||
tool_name, timeout
|
||||
)
|
||||
user_log(scan_id, "site_scan", f"{tool_name} failed: {reason}", "error")
|
||||
except Exception as exc:
|
||||
# 其他异常
|
||||
failed_tools.append({'tool': tool_name, 'reason': str(exc)})
|
||||
reason = str(exc)
|
||||
failed_tools.append({'tool': tool_name, 'reason': reason})
|
||||
logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True)
|
||||
user_log(scan_id, "site_scan", f"{tool_name} failed: {reason}", "error")
|
||||
|
||||
if failed_tools:
|
||||
logger.warning(
|
||||
@@ -380,6 +387,8 @@ def site_scan_flow(
|
||||
if not scan_workspace_dir:
|
||||
raise ValueError("scan_workspace_dir 不能为空")
|
||||
|
||||
user_log(scan_id, "site_scan", "Starting site scan")
|
||||
|
||||
# Step 0: 创建工作目录
|
||||
from apps.scan.utils import setup_scan_directory
|
||||
site_scan_dir = setup_scan_directory(scan_workspace_dir, 'site_scan')
|
||||
@@ -390,7 +399,8 @@ def site_scan_flow(
|
||||
)
|
||||
|
||||
if total_urls == 0:
|
||||
logger.warning("目标下没有可用的站点URL,跳过站点扫描")
|
||||
logger.warning("跳过站点扫描:没有站点 URL 可扫描 - Scan ID: %s", scan_id)
|
||||
user_log(scan_id, "site_scan", "Skipped: no site URLs to scan", "warning")
|
||||
return {
|
||||
'success': True,
|
||||
'scan_id': scan_id,
|
||||
@@ -433,8 +443,6 @@ def site_scan_flow(
|
||||
target_name=target_name
|
||||
)
|
||||
|
||||
logger.info("="*60 + "\n✓ 站点扫描完成\n" + "="*60)
|
||||
|
||||
# 动态生成已执行的任务列表
|
||||
executed_tasks = ['export_site_urls', 'parse_config']
|
||||
executed_tasks.extend([f'run_and_stream_save_websites ({tool})' for tool in tool_stats.keys()])
|
||||
@@ -444,6 +452,10 @@ def site_scan_flow(
|
||||
total_skipped_no_subdomain = sum(stats['result'].get('skipped_no_subdomain', 0) for stats in tool_stats.values())
|
||||
total_skipped_failed = sum(stats['result'].get('skipped_failed', 0) for stats in tool_stats.values())
|
||||
|
||||
# 记录 Flow 完成
|
||||
logger.info("✓ 站点扫描完成 - 创建站点: %d", total_created)
|
||||
user_log(scan_id, "site_scan", f"site_scan completed: found {total_created} websites")
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'scan_id': scan_id,
|
||||
|
||||
@@ -30,7 +30,7 @@ from apps.scan.handlers.scan_flow_handlers import (
|
||||
on_scan_flow_completed,
|
||||
on_scan_flow_failed,
|
||||
)
|
||||
from apps.scan.utils import build_scan_command, ensure_wordlist_local
|
||||
from apps.scan.utils import build_scan_command, ensure_wordlist_local, user_log
|
||||
from apps.engine.services.wordlist_service import WordlistService
|
||||
from apps.common.normalizer import normalize_domain
|
||||
from apps.common.validators import validate_domain
|
||||
@@ -77,7 +77,9 @@ def _validate_and_normalize_target(target_name: str) -> str:
|
||||
def _run_scans_parallel(
|
||||
enabled_tools: dict,
|
||||
domain_name: str,
|
||||
result_dir: Path
|
||||
result_dir: Path,
|
||||
scan_id: int,
|
||||
provider_config_path: str = None
|
||||
) -> tuple[list, list, list]:
|
||||
"""
|
||||
并行运行所有启用的子域名扫描工具
|
||||
@@ -86,6 +88,8 @@ def _run_scans_parallel(
|
||||
enabled_tools: 启用的工具配置字典 {'tool_name': {'timeout': 600, ...}}
|
||||
domain_name: 目标域名
|
||||
result_dir: 结果输出目录
|
||||
scan_id: 扫描任务 ID(用于记录日志)
|
||||
provider_config_path: Provider 配置文件路径(可选,用于 subfinder)
|
||||
|
||||
Returns:
|
||||
tuple: (result_files, failed_tools, successful_tool_names)
|
||||
@@ -110,13 +114,19 @@ def _run_scans_parallel(
|
||||
|
||||
# 1.2 构建完整命令(变量替换)
|
||||
try:
|
||||
command_params = {
|
||||
'domain': domain_name, # 对应 {domain}
|
||||
'output_file': output_file # 对应 {output_file}
|
||||
}
|
||||
|
||||
# 如果是 subfinder 且有 provider_config,添加到参数
|
||||
if tool_name == 'subfinder' and provider_config_path:
|
||||
command_params['provider_config'] = provider_config_path
|
||||
|
||||
command = build_scan_command(
|
||||
tool_name=tool_name,
|
||||
scan_type='subdomain_discovery',
|
||||
command_params={
|
||||
'domain': domain_name, # 对应 {domain}
|
||||
'output_file': output_file # 对应 {output_file}
|
||||
},
|
||||
command_params=command_params,
|
||||
tool_config=tool_config
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -137,6 +147,9 @@ def _run_scans_parallel(
|
||||
f"提交任务 - 工具: {tool_name}, 超时: {timeout}s, 输出: {output_file}"
|
||||
)
|
||||
|
||||
# 记录工具开始执行日志
|
||||
user_log(scan_id, "subdomain_discovery", f"Running {tool_name}: {command}")
|
||||
|
||||
future = run_subdomain_discovery_task.submit(
|
||||
tool=tool_name,
|
||||
command=command,
|
||||
@@ -164,16 +177,19 @@ def _run_scans_parallel(
|
||||
if result:
|
||||
result_files.append(result)
|
||||
logger.info("✓ 扫描工具 %s 执行成功: %s", tool_name, result)
|
||||
user_log(scan_id, "subdomain_discovery", f"{tool_name} completed")
|
||||
else:
|
||||
failure_msg = f"{tool_name}: 未生成结果文件"
|
||||
failures.append(failure_msg)
|
||||
failed_tools.append({'tool': tool_name, 'reason': '未生成结果文件'})
|
||||
logger.warning("⚠️ 扫描工具 %s 未生成结果文件", tool_name)
|
||||
user_log(scan_id, "subdomain_discovery", f"{tool_name} failed: no output file", "error")
|
||||
except Exception as e:
|
||||
failure_msg = f"{tool_name}: {str(e)}"
|
||||
failures.append(failure_msg)
|
||||
failed_tools.append({'tool': tool_name, 'reason': str(e)})
|
||||
logger.warning("⚠️ 扫描工具 %s 执行失败: %s", tool_name, str(e))
|
||||
user_log(scan_id, "subdomain_discovery", f"{tool_name} failed: {str(e)}", "error")
|
||||
|
||||
# 4. 检查是否有成功的工具
|
||||
if not result_files:
|
||||
@@ -203,7 +219,8 @@ def _run_single_tool(
|
||||
tool_config: dict,
|
||||
command_params: dict,
|
||||
result_dir: Path,
|
||||
scan_type: str = 'subdomain_discovery'
|
||||
scan_type: str = 'subdomain_discovery',
|
||||
scan_id: int = None
|
||||
) -> str:
|
||||
"""
|
||||
运行单个扫描工具
|
||||
@@ -214,6 +231,7 @@ def _run_single_tool(
|
||||
command_params: 命令参数
|
||||
result_dir: 结果目录
|
||||
scan_type: 扫描类型
|
||||
scan_id: 扫描 ID(用于记录用户日志)
|
||||
|
||||
Returns:
|
||||
str: 输出文件路径,失败返回空字符串
|
||||
@@ -242,7 +260,9 @@ def _run_single_tool(
|
||||
if timeout == 'auto':
|
||||
timeout = 3600
|
||||
|
||||
logger.info(f"执行 {tool_name}: timeout={timeout}s")
|
||||
logger.info(f"执行 {tool_name}: {command}")
|
||||
if scan_id:
|
||||
user_log(scan_id, scan_type, f"Running {tool_name}: {command}")
|
||||
|
||||
try:
|
||||
result = run_subdomain_discovery_task(
|
||||
@@ -401,7 +421,6 @@ def subdomain_discovery_flow(
|
||||
logger.warning("目标域名无效,跳过子域名发现扫描: %s", e)
|
||||
return _empty_result(scan_id, target_name, scan_workspace_dir)
|
||||
|
||||
# 验证成功后打印日志
|
||||
logger.info(
|
||||
"="*60 + "\n" +
|
||||
"开始子域名发现扫描\n" +
|
||||
@@ -410,6 +429,7 @@ def subdomain_discovery_flow(
|
||||
f" Workspace: {scan_workspace_dir}\n" +
|
||||
"="*60
|
||||
)
|
||||
user_log(scan_id, "subdomain_discovery", f"Starting subdomain discovery for {domain_name}")
|
||||
|
||||
# 解析配置
|
||||
passive_tools = scan_config.get('passive_tools', {})
|
||||
@@ -428,24 +448,37 @@ def subdomain_discovery_flow(
|
||||
failed_tools = []
|
||||
successful_tool_names = []
|
||||
|
||||
# ==================== Stage 1: 被动收集(并行)====================
|
||||
logger.info("=" * 40)
|
||||
logger.info("Stage 1: 被动收集(并行)")
|
||||
logger.info("=" * 40)
|
||||
# ==================== 生成 Provider 配置文件 ====================
|
||||
# 为 subfinder 生成第三方数据源配置
|
||||
provider_config_path = None
|
||||
try:
|
||||
from apps.scan.services.subfinder_provider_config_service import SubfinderProviderConfigService
|
||||
provider_config_service = SubfinderProviderConfigService()
|
||||
provider_config_path = provider_config_service.generate(str(result_dir))
|
||||
if provider_config_path:
|
||||
logger.info(f"Provider 配置文件已生成: {provider_config_path}")
|
||||
user_log(scan_id, "subdomain_discovery", "Provider config generated for subfinder")
|
||||
except Exception as e:
|
||||
logger.warning(f"生成 Provider 配置文件失败: {e}")
|
||||
|
||||
# ==================== Stage 1: 被动收集(并行)====================
|
||||
if enabled_passive_tools:
|
||||
logger.info("=" * 40)
|
||||
logger.info("Stage 1: 被动收集(并行)")
|
||||
logger.info("=" * 40)
|
||||
logger.info("启用工具: %s", ', '.join(enabled_passive_tools.keys()))
|
||||
user_log(scan_id, "subdomain_discovery", f"Stage 1: passive collection ({', '.join(enabled_passive_tools.keys())})")
|
||||
result_files, stage1_failed, stage1_success = _run_scans_parallel(
|
||||
enabled_tools=enabled_passive_tools,
|
||||
domain_name=domain_name,
|
||||
result_dir=result_dir
|
||||
result_dir=result_dir,
|
||||
scan_id=scan_id,
|
||||
provider_config_path=provider_config_path
|
||||
)
|
||||
all_result_files.extend(result_files)
|
||||
failed_tools.extend(stage1_failed)
|
||||
successful_tool_names.extend(stage1_success)
|
||||
executed_tasks.extend([f'passive ({tool})' for tool in stage1_success])
|
||||
else:
|
||||
logger.warning("未启用任何被动收集工具")
|
||||
|
||||
# 合并 Stage 1 结果
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
@@ -456,7 +489,6 @@ def subdomain_discovery_flow(
|
||||
else:
|
||||
# 创建空文件
|
||||
Path(current_result).touch()
|
||||
logger.warning("Stage 1 无结果,创建空文件")
|
||||
|
||||
# ==================== Stage 2: 字典爆破(可选)====================
|
||||
bruteforce_enabled = bruteforce_config.get('enabled', False)
|
||||
@@ -464,6 +496,7 @@ def subdomain_discovery_flow(
|
||||
logger.info("=" * 40)
|
||||
logger.info("Stage 2: 字典爆破")
|
||||
logger.info("=" * 40)
|
||||
user_log(scan_id, "subdomain_discovery", "Stage 2: bruteforce")
|
||||
|
||||
bruteforce_tool_config = bruteforce_config.get('subdomain_bruteforce', {})
|
||||
wordlist_name = bruteforce_tool_config.get('wordlist_name', 'dns_wordlist.txt')
|
||||
@@ -496,22 +529,16 @@ def subdomain_discovery_flow(
|
||||
**bruteforce_tool_config,
|
||||
'timeout': timeout_value,
|
||||
}
|
||||
logger.info(
|
||||
"subdomain_bruteforce 使用自动 timeout: %s 秒 (字典行数=%s, 3秒/行)",
|
||||
timeout_value,
|
||||
line_count_int,
|
||||
)
|
||||
|
||||
brute_output = str(result_dir / f"subs_brute_{timestamp}.txt")
|
||||
brute_result = _run_single_tool(
|
||||
tool_name='subdomain_bruteforce',
|
||||
tool_config=bruteforce_tool_config,
|
||||
command_params={
|
||||
'domain': domain_name,
|
||||
'wordlist': local_wordlist_path,
|
||||
'output_file': brute_output
|
||||
},
|
||||
result_dir=result_dir
|
||||
result_dir=result_dir,
|
||||
scan_id=scan_id
|
||||
)
|
||||
|
||||
if brute_result:
|
||||
@@ -522,11 +549,16 @@ def subdomain_discovery_flow(
|
||||
)
|
||||
successful_tool_names.append('subdomain_bruteforce')
|
||||
executed_tasks.append('bruteforce')
|
||||
logger.info("✓ subdomain_bruteforce 执行完成")
|
||||
user_log(scan_id, "subdomain_discovery", "subdomain_bruteforce completed")
|
||||
else:
|
||||
failed_tools.append({'tool': 'subdomain_bruteforce', 'reason': '执行失败'})
|
||||
logger.warning("⚠️ subdomain_bruteforce 执行失败")
|
||||
user_log(scan_id, "subdomain_discovery", "subdomain_bruteforce failed: execution failed", "error")
|
||||
except Exception as exc:
|
||||
logger.warning("字典准备失败,跳过字典爆破: %s", exc)
|
||||
failed_tools.append({'tool': 'subdomain_bruteforce', 'reason': str(exc)})
|
||||
logger.warning("字典准备失败,跳过字典爆破: %s", exc)
|
||||
user_log(scan_id, "subdomain_discovery", f"subdomain_bruteforce failed: {str(exc)}", "error")
|
||||
|
||||
# ==================== Stage 3: 变异生成 + 验证(可选)====================
|
||||
permutation_enabled = permutation_config.get('enabled', False)
|
||||
@@ -534,6 +566,7 @@ def subdomain_discovery_flow(
|
||||
logger.info("=" * 40)
|
||||
logger.info("Stage 3: 变异生成 + 存活验证(流式管道)")
|
||||
logger.info("=" * 40)
|
||||
user_log(scan_id, "subdomain_discovery", "Stage 3: permutation + resolve")
|
||||
|
||||
permutation_tool_config = permutation_config.get('subdomain_permutation_resolve', {})
|
||||
|
||||
@@ -587,20 +620,19 @@ def subdomain_discovery_flow(
|
||||
'tool': 'subdomain_permutation_resolve',
|
||||
'reason': f"采样检测到泛解析 (膨胀率 {ratio:.1f}x)"
|
||||
})
|
||||
user_log(scan_id, "subdomain_discovery", f"subdomain_permutation_resolve skipped: wildcard detected (ratio {ratio:.1f}x)", "warning")
|
||||
else:
|
||||
# === Step 3.2: 采样通过,执行完整变异 ===
|
||||
logger.info("采样检测通过,执行完整变异...")
|
||||
|
||||
permuted_output = str(result_dir / f"subs_permuted_{timestamp}.txt")
|
||||
|
||||
permuted_result = _run_single_tool(
|
||||
tool_name='subdomain_permutation_resolve',
|
||||
tool_config=permutation_tool_config,
|
||||
command_params={
|
||||
'input_file': current_result,
|
||||
'output_file': permuted_output,
|
||||
},
|
||||
result_dir=result_dir
|
||||
result_dir=result_dir,
|
||||
scan_id=scan_id
|
||||
)
|
||||
|
||||
if permuted_result:
|
||||
@@ -611,15 +643,21 @@ def subdomain_discovery_flow(
|
||||
)
|
||||
successful_tool_names.append('subdomain_permutation_resolve')
|
||||
executed_tasks.append('permutation')
|
||||
logger.info("✓ subdomain_permutation_resolve 执行完成")
|
||||
user_log(scan_id, "subdomain_discovery", "subdomain_permutation_resolve completed")
|
||||
else:
|
||||
failed_tools.append({'tool': 'subdomain_permutation_resolve', 'reason': '执行失败'})
|
||||
logger.warning("⚠️ subdomain_permutation_resolve 执行失败")
|
||||
user_log(scan_id, "subdomain_discovery", "subdomain_permutation_resolve failed: execution failed", "error")
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning(f"采样检测超时 ({SAMPLE_TIMEOUT}秒),跳过变异")
|
||||
failed_tools.append({'tool': 'subdomain_permutation_resolve', 'reason': '采样检测超时'})
|
||||
logger.warning(f"采样检测超时 ({SAMPLE_TIMEOUT}秒),跳过变异")
|
||||
user_log(scan_id, "subdomain_discovery", "subdomain_permutation_resolve failed: sample detection timeout", "error")
|
||||
except Exception as e:
|
||||
logger.warning(f"采样检测失败: {e},跳过变异")
|
||||
failed_tools.append({'tool': 'subdomain_permutation_resolve', 'reason': f'采样检测失败: {e}'})
|
||||
logger.warning(f"采样检测失败: {e},跳过变异")
|
||||
user_log(scan_id, "subdomain_discovery", f"subdomain_permutation_resolve failed: {str(e)}", "error")
|
||||
|
||||
# ==================== Stage 4: DNS 存活验证(可选)====================
|
||||
# 无论是否启用 Stage 3,只要 resolve.enabled 为 true 就会执行,对当前所有候选子域做统一 DNS 验证
|
||||
@@ -628,6 +666,7 @@ def subdomain_discovery_flow(
|
||||
logger.info("=" * 40)
|
||||
logger.info("Stage 4: DNS 存活验证")
|
||||
logger.info("=" * 40)
|
||||
user_log(scan_id, "subdomain_discovery", "Stage 4: DNS resolve")
|
||||
|
||||
resolve_tool_config = resolve_config.get('subdomain_resolve', {})
|
||||
|
||||
@@ -651,30 +690,27 @@ def subdomain_discovery_flow(
|
||||
**resolve_tool_config,
|
||||
'timeout': timeout_value,
|
||||
}
|
||||
logger.info(
|
||||
"subdomain_resolve 使用自动 timeout: %s 秒 (候选子域数=%s, 3秒/域名)",
|
||||
timeout_value,
|
||||
line_count_int,
|
||||
)
|
||||
|
||||
alive_output = str(result_dir / f"subs_alive_{timestamp}.txt")
|
||||
|
||||
alive_result = _run_single_tool(
|
||||
tool_name='subdomain_resolve',
|
||||
tool_config=resolve_tool_config,
|
||||
command_params={
|
||||
'input_file': current_result,
|
||||
'output_file': alive_output,
|
||||
},
|
||||
result_dir=result_dir
|
||||
result_dir=result_dir,
|
||||
scan_id=scan_id
|
||||
)
|
||||
|
||||
if alive_result:
|
||||
current_result = alive_result
|
||||
successful_tool_names.append('subdomain_resolve')
|
||||
executed_tasks.append('resolve')
|
||||
logger.info("✓ subdomain_resolve 执行完成")
|
||||
user_log(scan_id, "subdomain_discovery", "subdomain_resolve completed")
|
||||
else:
|
||||
failed_tools.append({'tool': 'subdomain_resolve', 'reason': '执行失败'})
|
||||
logger.warning("⚠️ subdomain_resolve 执行失败")
|
||||
user_log(scan_id, "subdomain_discovery", "subdomain_resolve failed: execution failed", "error")
|
||||
|
||||
# ==================== Final: 保存到数据库 ====================
|
||||
logger.info("=" * 40)
|
||||
@@ -695,7 +731,9 @@ def subdomain_discovery_flow(
|
||||
processed_domains = save_result.get('processed_records', 0)
|
||||
executed_tasks.append('save_domains')
|
||||
|
||||
# 记录 Flow 完成
|
||||
logger.info("="*60 + "\n✓ 子域名发现扫描完成\n" + "="*60)
|
||||
user_log(scan_id, "subdomain_discovery", f"subdomain_discovery completed: found {processed_domains} subdomains")
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
|
||||
@@ -59,6 +59,8 @@ def domain_name_url_fetch_flow(
|
||||
- IP 和 CIDR 类型会自动跳过(waymore 等工具不支持)
|
||||
- 工具会自动收集 *.target_name 的所有历史 URL,无需遍历子域名
|
||||
"""
|
||||
from apps.scan.utils import user_log
|
||||
|
||||
try:
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
@@ -145,6 +147,9 @@ def domain_name_url_fetch_flow(
|
||||
timeout,
|
||||
)
|
||||
|
||||
# 记录工具开始执行日志
|
||||
user_log(scan_id, "url_fetch", f"Running {tool_name}: {command}")
|
||||
|
||||
future = run_url_fetcher_task.submit(
|
||||
tool_name=tool_name,
|
||||
command=command,
|
||||
@@ -163,22 +168,28 @@ def domain_name_url_fetch_flow(
|
||||
if result and result.get("success"):
|
||||
result_files.append(result["output_file"])
|
||||
successful_tools.append(tool_name)
|
||||
url_count = result.get("url_count", 0)
|
||||
logger.info(
|
||||
"✓ 工具 %s 执行成功 - 发现 URL: %d",
|
||||
tool_name,
|
||||
result.get("url_count", 0),
|
||||
url_count,
|
||||
)
|
||||
user_log(scan_id, "url_fetch", f"{tool_name} completed: found {url_count} urls")
|
||||
else:
|
||||
reason = "未生成结果或无有效 URL"
|
||||
failed_tools.append(
|
||||
{
|
||||
"tool": tool_name,
|
||||
"reason": "未生成结果或无有效 URL",
|
||||
"reason": reason,
|
||||
}
|
||||
)
|
||||
logger.warning("⚠️ 工具 %s 未生成有效结果", tool_name)
|
||||
user_log(scan_id, "url_fetch", f"{tool_name} failed: {reason}", "error")
|
||||
except Exception as e:
|
||||
failed_tools.append({"tool": tool_name, "reason": str(e)})
|
||||
reason = str(e)
|
||||
failed_tools.append({"tool": tool_name, "reason": reason})
|
||||
logger.warning("⚠️ 工具 %s 执行失败: %s", tool_name, e)
|
||||
user_log(scan_id, "url_fetch", f"{tool_name} failed: {reason}", "error")
|
||||
|
||||
logger.info(
|
||||
"基于 domain_name 的 URL 获取完成 - 成功工具: %s, 失败工具: %s",
|
||||
|
||||
@@ -25,6 +25,7 @@ from apps.scan.handlers.scan_flow_handlers import (
|
||||
on_scan_flow_completed,
|
||||
on_scan_flow_failed,
|
||||
)
|
||||
from apps.scan.utils import user_log
|
||||
|
||||
from .domain_name_url_fetch_flow import domain_name_url_fetch_flow
|
||||
from .sites_url_fetch_flow import sites_url_fetch_flow
|
||||
@@ -212,7 +213,6 @@ def _validate_and_stream_save_urls(
|
||||
target_id=target_id,
|
||||
cwd=str(url_fetch_dir),
|
||||
shell=True,
|
||||
batch_size=500,
|
||||
timeout=timeout,
|
||||
log_file=str(log_file)
|
||||
)
|
||||
@@ -292,6 +292,8 @@ def url_fetch_flow(
|
||||
"="*60
|
||||
)
|
||||
|
||||
user_log(scan_id, "url_fetch", "Starting URL fetch")
|
||||
|
||||
# Step 1: 准备工作目录
|
||||
logger.info("Step 1: 准备工作目录")
|
||||
from apps.scan.utils import setup_scan_directory
|
||||
@@ -404,7 +406,9 @@ def url_fetch_flow(
|
||||
target_id=target_id
|
||||
)
|
||||
|
||||
logger.info("="*60 + "\n✓ URL 获取扫描完成\n" + "="*60)
|
||||
# 记录 Flow 完成
|
||||
logger.info("✓ URL 获取完成 - 保存 endpoints: %d", saved_count)
|
||||
user_log(scan_id, "url_fetch", f"url_fetch completed: found {saved_count} endpoints")
|
||||
|
||||
# 构建已执行的任务列表
|
||||
executed_tasks = ['setup_directory', 'classify_tools']
|
||||
|
||||
@@ -116,7 +116,8 @@ def sites_url_fetch_flow(
|
||||
tools=enabled_tools,
|
||||
input_file=sites_file,
|
||||
input_type="sites_file",
|
||||
output_dir=output_path
|
||||
output_dir=output_path,
|
||||
scan_id=scan_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
|
||||
@@ -152,7 +152,8 @@ def run_tools_parallel(
|
||||
tools: dict,
|
||||
input_file: str,
|
||||
input_type: str,
|
||||
output_dir: Path
|
||||
output_dir: Path,
|
||||
scan_id: int
|
||||
) -> tuple[list, list, list]:
|
||||
"""
|
||||
并行执行工具列表
|
||||
@@ -162,11 +163,13 @@ def run_tools_parallel(
|
||||
input_file: 输入文件路径
|
||||
input_type: 输入类型
|
||||
output_dir: 输出目录
|
||||
scan_id: 扫描任务 ID(用于记录日志)
|
||||
|
||||
Returns:
|
||||
tuple: (result_files, failed_tools, successful_tool_names)
|
||||
"""
|
||||
from apps.scan.tasks.url_fetch import run_url_fetcher_task
|
||||
from apps.scan.utils import user_log
|
||||
|
||||
futures: dict[str, object] = {}
|
||||
failed_tools: list[dict] = []
|
||||
@@ -192,6 +195,9 @@ def run_tools_parallel(
|
||||
exec_params["timeout"],
|
||||
)
|
||||
|
||||
# 记录工具开始执行日志
|
||||
user_log(scan_id, "url_fetch", f"Running {tool_name}: {exec_params['command']}")
|
||||
|
||||
# 提交并行任务
|
||||
future = run_url_fetcher_task.submit(
|
||||
tool_name=tool_name,
|
||||
@@ -208,22 +214,28 @@ def run_tools_parallel(
|
||||
result = future.result()
|
||||
if result and result['success']:
|
||||
result_files.append(result['output_file'])
|
||||
url_count = result['url_count']
|
||||
logger.info(
|
||||
"✓ 工具 %s 执行成功 - 发现 URL: %d",
|
||||
tool_name, result['url_count']
|
||||
tool_name, url_count
|
||||
)
|
||||
user_log(scan_id, "url_fetch", f"{tool_name} completed: found {url_count} urls")
|
||||
else:
|
||||
reason = '未生成结果或无有效URL'
|
||||
failed_tools.append({
|
||||
'tool': tool_name,
|
||||
'reason': '未生成结果或无有效URL'
|
||||
'reason': reason
|
||||
})
|
||||
logger.warning("⚠️ 工具 %s 未生成有效结果", tool_name)
|
||||
user_log(scan_id, "url_fetch", f"{tool_name} failed: {reason}", "error")
|
||||
except Exception as e:
|
||||
reason = str(e)
|
||||
failed_tools.append({
|
||||
'tool': tool_name,
|
||||
'reason': str(e)
|
||||
'reason': reason
|
||||
})
|
||||
logger.warning("⚠️ 工具 %s 执行失败: %s", tool_name, e)
|
||||
user_log(scan_id, "url_fetch", f"{tool_name} failed: {reason}", "error")
|
||||
|
||||
# 计算成功的工具列表
|
||||
failed_tool_names = [f['tool'] for f in failed_tools]
|
||||
|
||||
@@ -12,7 +12,7 @@ from apps.scan.handlers.scan_flow_handlers import (
|
||||
on_scan_flow_completed,
|
||||
on_scan_flow_failed,
|
||||
)
|
||||
from apps.scan.utils import build_scan_command, ensure_nuclei_templates_local
|
||||
from apps.scan.utils import build_scan_command, ensure_nuclei_templates_local, user_log
|
||||
from apps.scan.tasks.vuln_scan import (
|
||||
export_endpoints_task,
|
||||
run_vuln_tool_task,
|
||||
@@ -141,6 +141,7 @@ def endpoints_vuln_scan_flow(
|
||||
# Dalfox XSS 使用流式任务,一边解析一边保存漏洞结果
|
||||
if tool_name == "dalfox_xss":
|
||||
logger.info("开始执行漏洞扫描工具 %s(流式保存漏洞结果,已提交任务)", tool_name)
|
||||
user_log(scan_id, "vuln_scan", f"Running {tool_name}: {command}")
|
||||
future = run_and_stream_save_dalfox_vulns_task.submit(
|
||||
cmd=command,
|
||||
tool_name=tool_name,
|
||||
@@ -163,6 +164,7 @@ def endpoints_vuln_scan_flow(
|
||||
elif tool_name == "nuclei":
|
||||
# Nuclei 使用流式任务
|
||||
logger.info("开始执行漏洞扫描工具 %s(流式保存漏洞结果,已提交任务)", tool_name)
|
||||
user_log(scan_id, "vuln_scan", f"Running {tool_name}: {command}")
|
||||
future = run_and_stream_save_nuclei_vulns_task.submit(
|
||||
cmd=command,
|
||||
tool_name=tool_name,
|
||||
@@ -185,6 +187,7 @@ def endpoints_vuln_scan_flow(
|
||||
else:
|
||||
# 其他工具仍使用非流式执行逻辑
|
||||
logger.info("开始执行漏洞扫描工具 %s(已提交任务)", tool_name)
|
||||
user_log(scan_id, "vuln_scan", f"Running {tool_name}: {command}")
|
||||
future = run_vuln_tool_task.submit(
|
||||
tool_name=tool_name,
|
||||
command=command,
|
||||
@@ -203,24 +206,34 @@ def endpoints_vuln_scan_flow(
|
||||
# 统一收集所有工具的执行结果
|
||||
for tool_name, meta in tool_futures.items():
|
||||
future = meta["future"]
|
||||
result = future.result()
|
||||
try:
|
||||
result = future.result()
|
||||
|
||||
if meta["mode"] == "streaming":
|
||||
tool_results[tool_name] = {
|
||||
"command": meta["command"],
|
||||
"timeout": meta["timeout"],
|
||||
"processed_records": result.get("processed_records"),
|
||||
"created_vulns": result.get("created_vulns"),
|
||||
"command_log_file": meta["log_file"],
|
||||
}
|
||||
else:
|
||||
tool_results[tool_name] = {
|
||||
"command": meta["command"],
|
||||
"timeout": meta["timeout"],
|
||||
"duration": result.get("duration"),
|
||||
"returncode": result.get("returncode"),
|
||||
"command_log_file": result.get("command_log_file"),
|
||||
}
|
||||
if meta["mode"] == "streaming":
|
||||
created_vulns = result.get("created_vulns", 0)
|
||||
tool_results[tool_name] = {
|
||||
"command": meta["command"],
|
||||
"timeout": meta["timeout"],
|
||||
"processed_records": result.get("processed_records"),
|
||||
"created_vulns": created_vulns,
|
||||
"command_log_file": meta["log_file"],
|
||||
}
|
||||
logger.info("✓ 工具 %s 执行完成 - 漏洞: %d", tool_name, created_vulns)
|
||||
user_log(scan_id, "vuln_scan", f"{tool_name} completed: found {created_vulns} vulnerabilities")
|
||||
else:
|
||||
tool_results[tool_name] = {
|
||||
"command": meta["command"],
|
||||
"timeout": meta["timeout"],
|
||||
"duration": result.get("duration"),
|
||||
"returncode": result.get("returncode"),
|
||||
"command_log_file": result.get("command_log_file"),
|
||||
}
|
||||
logger.info("✓ 工具 %s 执行完成 - returncode=%s", tool_name, result.get("returncode"))
|
||||
user_log(scan_id, "vuln_scan", f"{tool_name} completed")
|
||||
except Exception as e:
|
||||
reason = str(e)
|
||||
logger.error("工具 %s 执行失败: %s", tool_name, e, exc_info=True)
|
||||
user_log(scan_id, "vuln_scan", f"{tool_name} failed: {reason}", "error")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
|
||||
@@ -11,6 +11,7 @@ from apps.scan.handlers.scan_flow_handlers import (
|
||||
on_scan_flow_failed,
|
||||
)
|
||||
from apps.scan.configs.command_templates import get_command_template
|
||||
from apps.scan.utils import user_log
|
||||
from .endpoints_vuln_scan_flow import endpoints_vuln_scan_flow
|
||||
|
||||
|
||||
@@ -72,6 +73,9 @@ def vuln_scan_flow(
|
||||
if not enabled_tools:
|
||||
raise ValueError("enabled_tools 不能为空")
|
||||
|
||||
logger.info("开始漏洞扫描 - Scan ID: %s, Target: %s", scan_id, target_name)
|
||||
user_log(scan_id, "vuln_scan", "Starting vulnerability scan")
|
||||
|
||||
# Step 1: 分类工具
|
||||
endpoints_tools, other_tools = _classify_vuln_tools(enabled_tools)
|
||||
|
||||
@@ -99,6 +103,14 @@ def vuln_scan_flow(
|
||||
enabled_tools=endpoints_tools,
|
||||
)
|
||||
|
||||
# 记录 Flow 完成
|
||||
total_vulns = sum(
|
||||
r.get("created_vulns", 0)
|
||||
for r in endpoint_result.get("tool_results", {}).values()
|
||||
)
|
||||
logger.info("✓ 漏洞扫描完成 - 新增漏洞: %d", total_vulns)
|
||||
user_log(scan_id, "vuln_scan", f"vuln_scan completed: found {total_vulns} vulnerabilities")
|
||||
|
||||
# 目前只有一个子 Flow,直接返回其结果
|
||||
return endpoint_result
|
||||
|
||||
|
||||
@@ -162,6 +162,8 @@ def on_initiate_scan_flow_completed(flow: Flow, flow_run: FlowRun, state: State)
|
||||
# 执行状态更新并获取统计数据
|
||||
stats = _update_completed_status()
|
||||
|
||||
# 注意:物化视图刷新已迁移到 pg_ivm 增量维护,无需手动标记刷新
|
||||
|
||||
# 发送通知(包含统计摘要)
|
||||
logger.info("准备发送扫描完成通知 - Scan ID: %s, Target: %s", scan_id, target_name)
|
||||
try:
|
||||
|
||||
@@ -14,6 +14,7 @@ from prefect import Flow
|
||||
from prefect.client.schemas import FlowRun, State
|
||||
|
||||
from apps.scan.utils.performance import FlowPerformanceTracker
|
||||
from apps.scan.utils import user_log
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -136,6 +137,7 @@ def on_scan_flow_failed(flow: Flow, flow_run: FlowRun, state: State) -> None:
|
||||
- 更新阶段进度为 failed
|
||||
- 发送扫描失败通知
|
||||
- 记录性能指标(含错误信息)
|
||||
- 写入 ScanLog 供前端显示
|
||||
|
||||
Args:
|
||||
flow: Prefect Flow 对象
|
||||
@@ -152,6 +154,11 @@ def on_scan_flow_failed(flow: Flow, flow_run: FlowRun, state: State) -> None:
|
||||
# 提取错误信息
|
||||
error_message = str(state.message) if state.message else "未知错误"
|
||||
|
||||
# 写入 ScanLog 供前端显示
|
||||
stage = _get_stage_from_flow_name(flow.name)
|
||||
if scan_id and stage:
|
||||
user_log(scan_id, stage, f"Failed: {error_message}", "error")
|
||||
|
||||
# 记录性能指标(失败情况)
|
||||
tracker = _flow_trackers.pop(str(flow_run.id), None)
|
||||
if tracker:
|
||||
|
||||
175
backend/apps/scan/migrations/0001_initial.py
Normal file
175
backend/apps/scan/migrations/0001_initial.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# Generated by Django 5.2.7 on 2026-01-06 00:55
|
||||
|
||||
import django.contrib.postgres.fields
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
initial = True
|
||||
|
||||
dependencies = [
|
||||
('engine', '0001_initial'),
|
||||
('targets', '0001_initial'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name='NotificationSettings',
|
||||
fields=[
|
||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('discord_enabled', models.BooleanField(default=False, help_text='是否启用 Discord 通知')),
|
||||
('discord_webhook_url', models.URLField(blank=True, default='', help_text='Discord Webhook URL')),
|
||||
('categories', models.JSONField(default=dict, help_text='各分类通知开关,如 {"scan": true, "vulnerability": true, "asset": true, "system": false}')),
|
||||
('created_at', models.DateTimeField(auto_now_add=True)),
|
||||
('updated_at', models.DateTimeField(auto_now=True)),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '通知设置',
|
||||
'verbose_name_plural': '通知设置',
|
||||
'db_table': 'notification_settings',
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='SubfinderProviderSettings',
|
||||
fields=[
|
||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('providers', models.JSONField(default=dict, help_text='各 Provider 的 API Key 配置')),
|
||||
('created_at', models.DateTimeField(auto_now_add=True)),
|
||||
('updated_at', models.DateTimeField(auto_now=True)),
|
||||
],
|
||||
options={
|
||||
'verbose_name': 'Subfinder Provider 配置',
|
||||
'verbose_name_plural': 'Subfinder Provider 配置',
|
||||
'db_table': 'subfinder_provider_settings',
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='Notification',
|
||||
fields=[
|
||||
('id', models.AutoField(primary_key=True, serialize=False)),
|
||||
('category', models.CharField(choices=[('scan', '扫描任务'), ('vulnerability', '漏洞发现'), ('asset', '资产发现'), ('system', '系统消息')], db_index=True, default='system', help_text='通知分类', max_length=20)),
|
||||
('level', models.CharField(choices=[('low', '低'), ('medium', '中'), ('high', '高'), ('critical', '严重')], db_index=True, default='low', help_text='通知级别', max_length=20)),
|
||||
('title', models.CharField(help_text='通知标题', max_length=200)),
|
||||
('message', models.CharField(help_text='通知内容', max_length=2000)),
|
||||
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
|
||||
('is_read', models.BooleanField(default=False, help_text='是否已读')),
|
||||
('read_at', models.DateTimeField(blank=True, help_text='阅读时间', null=True)),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '通知',
|
||||
'verbose_name_plural': '通知',
|
||||
'db_table': 'notification',
|
||||
'ordering': ['-created_at'],
|
||||
'indexes': [models.Index(fields=['-created_at'], name='notificatio_created_c430f0_idx'), models.Index(fields=['category', '-created_at'], name='notificatio_categor_df0584_idx'), models.Index(fields=['level', '-created_at'], name='notificatio_level_0e5d12_idx'), models.Index(fields=['is_read', '-created_at'], name='notificatio_is_read_518ce0_idx')],
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='Scan',
|
||||
fields=[
|
||||
('id', models.AutoField(primary_key=True, serialize=False)),
|
||||
('engine_ids', django.contrib.postgres.fields.ArrayField(base_field=models.IntegerField(), default=list, help_text='引擎 ID 列表', size=None)),
|
||||
('engine_names', models.JSONField(default=list, help_text='引擎名称列表,如 ["引擎A", "引擎B"]')),
|
||||
('yaml_configuration', models.TextField(default='', help_text='YAML 格式的扫描配置')),
|
||||
('created_at', models.DateTimeField(auto_now_add=True, help_text='任务创建时间')),
|
||||
('stopped_at', models.DateTimeField(blank=True, help_text='扫描结束时间', null=True)),
|
||||
('status', models.CharField(choices=[('cancelled', '已取消'), ('completed', '已完成'), ('failed', '失败'), ('initiated', '初始化'), ('running', '运行中')], db_index=True, default='initiated', help_text='任务状态', max_length=20)),
|
||||
('results_dir', models.CharField(blank=True, default='', help_text='结果存储目录', max_length=100)),
|
||||
('container_ids', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='容器 ID 列表(Docker Container ID)', size=None)),
|
||||
('error_message', models.CharField(blank=True, default='', help_text='错误信息', max_length=2000)),
|
||||
('deleted_at', models.DateTimeField(blank=True, db_index=True, help_text='删除时间(NULL表示未删除)', null=True)),
|
||||
('progress', models.IntegerField(default=0, help_text='扫描进度 0-100')),
|
||||
('current_stage', models.CharField(blank=True, default='', help_text='当前扫描阶段', max_length=50)),
|
||||
('stage_progress', models.JSONField(default=dict, help_text='各阶段进度详情')),
|
||||
('cached_subdomains_count', models.IntegerField(default=0, help_text='缓存的子域名数量')),
|
||||
('cached_websites_count', models.IntegerField(default=0, help_text='缓存的网站数量')),
|
||||
('cached_endpoints_count', models.IntegerField(default=0, help_text='缓存的端点数量')),
|
||||
('cached_ips_count', models.IntegerField(default=0, help_text='缓存的IP地址数量')),
|
||||
('cached_directories_count', models.IntegerField(default=0, help_text='缓存的目录数量')),
|
||||
('cached_vulns_total', models.IntegerField(default=0, help_text='缓存的漏洞总数')),
|
||||
('cached_vulns_critical', models.IntegerField(default=0, help_text='缓存的严重漏洞数量')),
|
||||
('cached_vulns_high', models.IntegerField(default=0, help_text='缓存的高危漏洞数量')),
|
||||
('cached_vulns_medium', models.IntegerField(default=0, help_text='缓存的中危漏洞数量')),
|
||||
('cached_vulns_low', models.IntegerField(default=0, help_text='缓存的低危漏洞数量')),
|
||||
('stats_updated_at', models.DateTimeField(blank=True, help_text='统计数据最后更新时间', null=True)),
|
||||
('target', models.ForeignKey(help_text='扫描目标', on_delete=django.db.models.deletion.CASCADE, related_name='scans', to='targets.target')),
|
||||
('worker', models.ForeignKey(blank=True, help_text='执行扫描的 Worker 节点', null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='scans', to='engine.workernode')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '扫描任务',
|
||||
'verbose_name_plural': '扫描任务',
|
||||
'db_table': 'scan',
|
||||
'ordering': ['-created_at'],
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='ScanLog',
|
||||
fields=[
|
||||
('id', models.BigAutoField(primary_key=True, serialize=False)),
|
||||
('level', models.CharField(choices=[('info', 'Info'), ('warning', 'Warning'), ('error', 'Error')], default='info', help_text='日志级别', max_length=10)),
|
||||
('content', models.TextField(help_text='日志内容')),
|
||||
('created_at', models.DateTimeField(auto_now_add=True, db_index=True, help_text='创建时间')),
|
||||
('scan', models.ForeignKey(help_text='关联的扫描任务', on_delete=django.db.models.deletion.CASCADE, related_name='logs', to='scan.scan')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '扫描日志',
|
||||
'verbose_name_plural': '扫描日志',
|
||||
'db_table': 'scan_log',
|
||||
'ordering': ['created_at'],
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='ScheduledScan',
|
||||
fields=[
|
||||
('id', models.AutoField(primary_key=True, serialize=False)),
|
||||
('name', models.CharField(help_text='任务名称', max_length=200)),
|
||||
('engine_ids', django.contrib.postgres.fields.ArrayField(base_field=models.IntegerField(), default=list, help_text='引擎 ID 列表', size=None)),
|
||||
('engine_names', models.JSONField(default=list, help_text='引擎名称列表,如 ["引擎A", "引擎B"]')),
|
||||
('yaml_configuration', models.TextField(default='', help_text='YAML 格式的扫描配置')),
|
||||
('cron_expression', models.CharField(default='0 2 * * *', help_text='Cron 表达式,格式:分 时 日 月 周', max_length=100)),
|
||||
('is_enabled', models.BooleanField(db_index=True, default=True, help_text='是否启用')),
|
||||
('run_count', models.IntegerField(default=0, help_text='已执行次数')),
|
||||
('last_run_time', models.DateTimeField(blank=True, help_text='上次执行时间', null=True)),
|
||||
('next_run_time', models.DateTimeField(blank=True, help_text='下次执行时间', null=True)),
|
||||
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
|
||||
('updated_at', models.DateTimeField(auto_now=True, help_text='更新时间')),
|
||||
('organization', models.ForeignKey(blank=True, help_text='扫描组织(设置后执行时动态获取组织下所有目标)', null=True, on_delete=django.db.models.deletion.CASCADE, related_name='scheduled_scans', to='targets.organization')),
|
||||
('target', models.ForeignKey(blank=True, help_text='扫描单个目标(与 organization 二选一)', null=True, on_delete=django.db.models.deletion.CASCADE, related_name='scheduled_scans', to='targets.target')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': '定时扫描任务',
|
||||
'verbose_name_plural': '定时扫描任务',
|
||||
'db_table': 'scheduled_scan',
|
||||
'ordering': ['-created_at'],
|
||||
},
|
||||
),
|
||||
migrations.AddIndex(
|
||||
model_name='scan',
|
||||
index=models.Index(fields=['-created_at'], name='scan_created_0bb6c7_idx'),
|
||||
),
|
||||
migrations.AddIndex(
|
||||
model_name='scan',
|
||||
index=models.Index(fields=['target'], name='scan_target__718b9d_idx'),
|
||||
),
|
||||
migrations.AddIndex(
|
||||
model_name='scan',
|
||||
index=models.Index(fields=['deleted_at', '-created_at'], name='scan_deleted_eb17e8_idx'),
|
||||
),
|
||||
migrations.AddIndex(
|
||||
model_name='scanlog',
|
||||
index=models.Index(fields=['scan', 'created_at'], name='scan_log_scan_id_c4814a_idx'),
|
||||
),
|
||||
migrations.AddIndex(
|
||||
model_name='scheduledscan',
|
||||
index=models.Index(fields=['-created_at'], name='scheduled_s_created_9b9c2e_idx'),
|
||||
),
|
||||
migrations.AddIndex(
|
||||
model_name='scheduledscan',
|
||||
index=models.Index(fields=['is_enabled', '-created_at'], name='scheduled_s_is_enab_23d660_idx'),
|
||||
),
|
||||
migrations.AddIndex(
|
||||
model_name='scheduledscan',
|
||||
index=models.Index(fields=['name'], name='scheduled_s_name_bf332d_idx'),
|
||||
),
|
||||
]
|
||||
18
backend/apps/scan/models/__init__.py
Normal file
18
backend/apps/scan/models/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""Scan Models - 统一导出"""
|
||||
|
||||
from .scan_models import Scan, SoftDeleteManager
|
||||
from .scan_log_model import ScanLog
|
||||
from .scheduled_scan_model import ScheduledScan
|
||||
from .subfinder_provider_settings_model import SubfinderProviderSettings
|
||||
|
||||
# 兼容旧名称(已废弃,请使用 SubfinderProviderSettings)
|
||||
ProviderSettings = SubfinderProviderSettings
|
||||
|
||||
__all__ = [
|
||||
'Scan',
|
||||
'ScanLog',
|
||||
'ScheduledScan',
|
||||
'SoftDeleteManager',
|
||||
'SubfinderProviderSettings',
|
||||
'ProviderSettings', # 兼容旧名称
|
||||
]
|
||||
41
backend/apps/scan/models/scan_log_model.py
Normal file
41
backend/apps/scan/models/scan_log_model.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""扫描日志模型"""
|
||||
|
||||
from django.db import models
|
||||
|
||||
|
||||
class ScanLog(models.Model):
|
||||
"""扫描日志模型"""
|
||||
|
||||
class Level(models.TextChoices):
|
||||
INFO = 'info', 'Info'
|
||||
WARNING = 'warning', 'Warning'
|
||||
ERROR = 'error', 'Error'
|
||||
|
||||
id = models.BigAutoField(primary_key=True)
|
||||
scan = models.ForeignKey(
|
||||
'Scan',
|
||||
on_delete=models.CASCADE,
|
||||
related_name='logs',
|
||||
db_index=True,
|
||||
help_text='关联的扫描任务'
|
||||
)
|
||||
level = models.CharField(
|
||||
max_length=10,
|
||||
choices=Level.choices,
|
||||
default=Level.INFO,
|
||||
help_text='日志级别'
|
||||
)
|
||||
content = models.TextField(help_text='日志内容')
|
||||
created_at = models.DateTimeField(auto_now_add=True, db_index=True, help_text='创建时间')
|
||||
|
||||
class Meta:
|
||||
db_table = 'scan_log'
|
||||
verbose_name = '扫描日志'
|
||||
verbose_name_plural = '扫描日志'
|
||||
ordering = ['created_at']
|
||||
indexes = [
|
||||
models.Index(fields=['scan', 'created_at']),
|
||||
]
|
||||
|
||||
def __str__(self):
|
||||
return f"[{self.level}] {self.content[:50]}"
|
||||
@@ -1,9 +1,9 @@
|
||||
"""扫描相关模型"""
|
||||
|
||||
from django.db import models
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
|
||||
from ..common.definitions import ScanStatus
|
||||
|
||||
|
||||
from apps.common.definitions import ScanStatus
|
||||
|
||||
|
||||
class SoftDeleteManager(models.Manager):
|
||||
@@ -20,11 +20,19 @@ class Scan(models.Model):
|
||||
|
||||
target = models.ForeignKey('targets.Target', on_delete=models.CASCADE, related_name='scans', help_text='扫描目标')
|
||||
|
||||
engine = models.ForeignKey(
|
||||
'engine.ScanEngine',
|
||||
on_delete=models.CASCADE,
|
||||
related_name='scans',
|
||||
help_text='使用的扫描引擎'
|
||||
# 多引擎支持字段
|
||||
engine_ids = ArrayField(
|
||||
models.IntegerField(),
|
||||
default=list,
|
||||
help_text='引擎 ID 列表'
|
||||
)
|
||||
engine_names = models.JSONField(
|
||||
default=list,
|
||||
help_text='引擎名称列表,如 ["引擎A", "引擎B"]'
|
||||
)
|
||||
yaml_configuration = models.TextField(
|
||||
default='',
|
||||
help_text='YAML 格式的扫描配置'
|
||||
)
|
||||
|
||||
created_at = models.DateTimeField(auto_now_add=True, help_text='任务创建时间')
|
||||
@@ -89,92 +97,10 @@ class Scan(models.Model):
|
||||
verbose_name_plural = '扫描任务'
|
||||
ordering = ['-created_at']
|
||||
indexes = [
|
||||
models.Index(fields=['-created_at']), # 优化按创建时间降序排序(list 查询的默认排序)
|
||||
models.Index(fields=['target']), # 优化按目标查询扫描任务
|
||||
models.Index(fields=['deleted_at', '-created_at']), # 软删除 + 时间索引
|
||||
models.Index(fields=['-created_at']),
|
||||
models.Index(fields=['target']),
|
||||
models.Index(fields=['deleted_at', '-created_at']),
|
||||
]
|
||||
|
||||
def __str__(self):
|
||||
return f"Scan #{self.id} - {self.target.name}"
|
||||
|
||||
|
||||
class ScheduledScan(models.Model):
|
||||
"""
|
||||
定时扫描任务模型
|
||||
|
||||
调度机制:
|
||||
- APScheduler 每分钟检查 next_run_time
|
||||
- 到期任务通过 task_distributor 分发到 Worker 执行
|
||||
- 支持 cron 表达式进行灵活调度
|
||||
|
||||
扫描模式(二选一):
|
||||
- 组织扫描:设置 organization,执行时动态获取组织下所有目标
|
||||
- 目标扫描:设置 target,扫描单个目标
|
||||
- organization 优先级高于 target
|
||||
"""
|
||||
|
||||
id = models.AutoField(primary_key=True)
|
||||
|
||||
# 基本信息
|
||||
name = models.CharField(max_length=200, help_text='任务名称')
|
||||
|
||||
# 关联的扫描引擎
|
||||
engine = models.ForeignKey(
|
||||
'engine.ScanEngine',
|
||||
on_delete=models.CASCADE,
|
||||
related_name='scheduled_scans',
|
||||
help_text='使用的扫描引擎'
|
||||
)
|
||||
|
||||
# 关联的组织(组织扫描模式:执行时动态获取组织下所有目标)
|
||||
organization = models.ForeignKey(
|
||||
'targets.Organization',
|
||||
on_delete=models.CASCADE,
|
||||
related_name='scheduled_scans',
|
||||
null=True,
|
||||
blank=True,
|
||||
help_text='扫描组织(设置后执行时动态获取组织下所有目标)'
|
||||
)
|
||||
|
||||
# 关联的目标(目标扫描模式:扫描单个目标)
|
||||
target = models.ForeignKey(
|
||||
'targets.Target',
|
||||
on_delete=models.CASCADE,
|
||||
related_name='scheduled_scans',
|
||||
null=True,
|
||||
blank=True,
|
||||
help_text='扫描单个目标(与 organization 二选一)'
|
||||
)
|
||||
|
||||
# 调度配置 - 直接使用 Cron 表达式
|
||||
cron_expression = models.CharField(
|
||||
max_length=100,
|
||||
default='0 2 * * *',
|
||||
help_text='Cron 表达式,格式:分 时 日 月 周'
|
||||
)
|
||||
|
||||
# 状态
|
||||
is_enabled = models.BooleanField(default=True, db_index=True, help_text='是否启用')
|
||||
|
||||
# 执行统计
|
||||
run_count = models.IntegerField(default=0, help_text='已执行次数')
|
||||
last_run_time = models.DateTimeField(null=True, blank=True, help_text='上次执行时间')
|
||||
next_run_time = models.DateTimeField(null=True, blank=True, help_text='下次执行时间')
|
||||
|
||||
# 时间戳
|
||||
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
|
||||
updated_at = models.DateTimeField(auto_now=True, help_text='更新时间')
|
||||
|
||||
class Meta:
|
||||
db_table = 'scheduled_scan'
|
||||
verbose_name = '定时扫描任务'
|
||||
verbose_name_plural = '定时扫描任务'
|
||||
ordering = ['-created_at']
|
||||
indexes = [
|
||||
models.Index(fields=['-created_at']),
|
||||
models.Index(fields=['is_enabled', '-created_at']),
|
||||
models.Index(fields=['name']), # 优化 name 搜索
|
||||
]
|
||||
|
||||
def __str__(self):
|
||||
return f"ScheduledScan #{self.id} - {self.name}"
|
||||
73
backend/apps/scan/models/scheduled_scan_model.py
Normal file
73
backend/apps/scan/models/scheduled_scan_model.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""定时扫描任务模型"""
|
||||
|
||||
from django.db import models
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
|
||||
|
||||
class ScheduledScan(models.Model):
|
||||
"""定时扫描任务模型"""
|
||||
|
||||
id = models.AutoField(primary_key=True)
|
||||
|
||||
name = models.CharField(max_length=200, help_text='任务名称')
|
||||
|
||||
engine_ids = ArrayField(
|
||||
models.IntegerField(),
|
||||
default=list,
|
||||
help_text='引擎 ID 列表'
|
||||
)
|
||||
engine_names = models.JSONField(
|
||||
default=list,
|
||||
help_text='引擎名称列表,如 ["引擎A", "引擎B"]'
|
||||
)
|
||||
yaml_configuration = models.TextField(
|
||||
default='',
|
||||
help_text='YAML 格式的扫描配置'
|
||||
)
|
||||
|
||||
organization = models.ForeignKey(
|
||||
'targets.Organization',
|
||||
on_delete=models.CASCADE,
|
||||
related_name='scheduled_scans',
|
||||
null=True,
|
||||
blank=True,
|
||||
help_text='扫描组织(设置后执行时动态获取组织下所有目标)'
|
||||
)
|
||||
|
||||
target = models.ForeignKey(
|
||||
'targets.Target',
|
||||
on_delete=models.CASCADE,
|
||||
related_name='scheduled_scans',
|
||||
null=True,
|
||||
blank=True,
|
||||
help_text='扫描单个目标(与 organization 二选一)'
|
||||
)
|
||||
|
||||
cron_expression = models.CharField(
|
||||
max_length=100,
|
||||
default='0 2 * * *',
|
||||
help_text='Cron 表达式,格式:分 时 日 月 周'
|
||||
)
|
||||
|
||||
is_enabled = models.BooleanField(default=True, db_index=True, help_text='是否启用')
|
||||
|
||||
run_count = models.IntegerField(default=0, help_text='已执行次数')
|
||||
last_run_time = models.DateTimeField(null=True, blank=True, help_text='上次执行时间')
|
||||
next_run_time = models.DateTimeField(null=True, blank=True, help_text='下次执行时间')
|
||||
|
||||
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
|
||||
updated_at = models.DateTimeField(auto_now=True, help_text='更新时间')
|
||||
|
||||
class Meta:
|
||||
db_table = 'scheduled_scan'
|
||||
verbose_name = '定时扫描任务'
|
||||
verbose_name_plural = '定时扫描任务'
|
||||
ordering = ['-created_at']
|
||||
indexes = [
|
||||
models.Index(fields=['-created_at']),
|
||||
models.Index(fields=['is_enabled', '-created_at']),
|
||||
models.Index(fields=['name']),
|
||||
]
|
||||
|
||||
def __str__(self):
|
||||
return f"ScheduledScan #{self.id} - {self.name}"
|
||||
@@ -0,0 +1,64 @@
|
||||
"""Subfinder Provider 配置模型(单例模式)
|
||||
|
||||
用于存储 subfinder 第三方数据源的 API Key 配置
|
||||
"""
|
||||
|
||||
from django.db import models
|
||||
|
||||
|
||||
class SubfinderProviderSettings(models.Model):
|
||||
"""
|
||||
Subfinder Provider 配置(单例模式)
|
||||
存储第三方数据源的 API Key 配置,用于 subfinder 子域名发现
|
||||
|
||||
支持的 Provider:
|
||||
- fofa: email + api_key (composite)
|
||||
- censys: api_id + api_secret (composite)
|
||||
- hunter, shodan, zoomeye, securitytrails, threatbook, quake: api_key (single)
|
||||
"""
|
||||
|
||||
providers = models.JSONField(
|
||||
default=dict,
|
||||
help_text='各 Provider 的 API Key 配置'
|
||||
)
|
||||
|
||||
created_at = models.DateTimeField(auto_now_add=True)
|
||||
updated_at = models.DateTimeField(auto_now=True)
|
||||
|
||||
class Meta:
|
||||
db_table = 'subfinder_provider_settings'
|
||||
verbose_name = 'Subfinder Provider 配置'
|
||||
verbose_name_plural = 'Subfinder Provider 配置'
|
||||
|
||||
DEFAULT_PROVIDERS = {
|
||||
'fofa': {'enabled': False, 'email': '', 'api_key': ''},
|
||||
'hunter': {'enabled': False, 'api_key': ''},
|
||||
'shodan': {'enabled': False, 'api_key': ''},
|
||||
'censys': {'enabled': False, 'api_id': '', 'api_secret': ''},
|
||||
'zoomeye': {'enabled': False, 'api_key': ''},
|
||||
'securitytrails': {'enabled': False, 'api_key': ''},
|
||||
'threatbook': {'enabled': False, 'api_key': ''},
|
||||
'quake': {'enabled': False, 'api_key': ''},
|
||||
}
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
self.pk = 1
|
||||
super().save(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> 'SubfinderProviderSettings':
|
||||
"""获取或创建单例实例"""
|
||||
obj, _ = cls.objects.get_or_create(
|
||||
pk=1,
|
||||
defaults={'providers': cls.DEFAULT_PROVIDERS.copy()}
|
||||
)
|
||||
return obj
|
||||
|
||||
def get_provider_config(self, provider: str) -> dict:
|
||||
"""获取指定 Provider 的配置"""
|
||||
return self.providers.get(provider, self.DEFAULT_PROVIDERS.get(provider, {}))
|
||||
|
||||
def is_provider_enabled(self, provider: str) -> bool:
|
||||
"""检查指定 Provider 是否启用"""
|
||||
config = self.get_provider_config(provider)
|
||||
return config.get('enabled', False)
|
||||
@@ -5,12 +5,13 @@ WebSocket Consumer - 通知实时推送
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
from channels.generic.websocket import AsyncWebsocketConsumer
|
||||
|
||||
from apps.common.websocket_auth import AuthenticatedWebsocketConsumer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NotificationConsumer(AsyncWebsocketConsumer):
|
||||
class NotificationConsumer(AuthenticatedWebsocketConsumer):
|
||||
"""
|
||||
通知 WebSocket Consumer
|
||||
|
||||
@@ -23,9 +24,9 @@ class NotificationConsumer(AsyncWebsocketConsumer):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.heartbeat_task = None # 心跳任务
|
||||
|
||||
async def connect(self):
|
||||
async def on_connect(self):
|
||||
"""
|
||||
客户端连接时调用
|
||||
客户端连接时调用(已通过认证)
|
||||
加入通知广播组
|
||||
"""
|
||||
# 通知组名(所有客户端共享)
|
||||
|
||||
@@ -305,6 +305,7 @@ def _push_via_api_callback(notification: Notification, server_url: str) -> None:
|
||||
通过 HTTP 请求 Server 容器的 /api/callbacks/notification/ 接口。
|
||||
Worker 无法直接访问 Redis,需要由 Server 代为推送 WebSocket。
|
||||
"""
|
||||
import os
|
||||
import requests
|
||||
|
||||
try:
|
||||
@@ -318,8 +319,14 @@ def _push_via_api_callback(notification: Notification, server_url: str) -> None:
|
||||
'created_at': notification.created_at.isoformat()
|
||||
}
|
||||
|
||||
# 构建请求头(包含 Worker API Key)
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
worker_api_key = os.environ.get("WORKER_API_KEY", "")
|
||||
if worker_api_key:
|
||||
headers["X-Worker-API-Key"] = worker_api_key
|
||||
|
||||
# verify=False: 远程 Worker 回调 Server 时可能使用自签名证书
|
||||
resp = requests.post(callback_url, json=data, timeout=5, verify=False)
|
||||
resp = requests.post(callback_url, json=data, headers=headers, timeout=5, verify=False)
|
||||
resp.raise_for_status()
|
||||
|
||||
logger.debug(f"通知回调推送成功 - ID: {notification.id}")
|
||||
|
||||
@@ -21,9 +21,6 @@ urlpatterns = [
|
||||
|
||||
# 标记全部已读
|
||||
path('mark-all-as-read/', NotificationMarkAllAsReadView.as_view(), name='mark-all-as-read'),
|
||||
|
||||
# 测试通知
|
||||
path('test/', views.notifications_test, name='test'),
|
||||
]
|
||||
|
||||
# WebSocket 实时通知路由在 routing.py 中定义:ws://host/ws/notifications/
|
||||
|
||||
@@ -7,8 +7,7 @@ from typing import Any
|
||||
from django.http import JsonResponse
|
||||
from django.utils import timezone
|
||||
from rest_framework import status
|
||||
from rest_framework.decorators import api_view, permission_classes
|
||||
from rest_framework.permissions import AllowAny
|
||||
from rest_framework.decorators import api_view
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
@@ -24,45 +23,7 @@ from .services import NotificationService, NotificationSettingsService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def notifications_test(request):
|
||||
"""
|
||||
测试通知推送
|
||||
"""
|
||||
try:
|
||||
from .services import create_notification
|
||||
from django.http import JsonResponse
|
||||
|
||||
level_param = request.GET.get('level', NotificationLevel.LOW)
|
||||
try:
|
||||
level_choice = NotificationLevel(level_param)
|
||||
except ValueError:
|
||||
level_choice = NotificationLevel.LOW
|
||||
|
||||
title = request.GET.get('title') or "测试通知"
|
||||
message = request.GET.get('message') or "这是一条测试通知消息"
|
||||
|
||||
# 创建测试通知
|
||||
notification = create_notification(
|
||||
title=title,
|
||||
message=message,
|
||||
level=level_choice
|
||||
)
|
||||
|
||||
return JsonResponse({
|
||||
'success': True,
|
||||
'message': '测试通知已发送',
|
||||
'notification_id': notification.id
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送测试通知失败: {e}")
|
||||
return JsonResponse({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
|
||||
# build_api_response 已废弃,请使用 success_response/error_response
|
||||
|
||||
|
||||
def _parse_bool(value: str | None) -> bool | None:
|
||||
@@ -198,12 +159,13 @@ class NotificationSettingsView(APIView):
|
||||
# ============================================
|
||||
|
||||
@api_view(['POST'])
|
||||
@permission_classes([AllowAny]) # Worker 容器无认证,可考虑添加 Token 验证
|
||||
# 权限由全局 IsAuthenticatedOrPublic 处理,/api/callbacks/* 需要 Worker API Key 认证
|
||||
def notification_callback(request):
|
||||
"""
|
||||
接收 Worker 的通知推送请求
|
||||
|
||||
Worker 容器无法直接访问 Redis,通过此 API 回调让 Server 推送 WebSocket。
|
||||
需要 Worker API Key 认证(X-Worker-API-Key Header)。
|
||||
|
||||
POST /api/callbacks/notification/
|
||||
{
|
||||
|
||||
@@ -16,7 +16,6 @@ from django.utils import timezone
|
||||
|
||||
from apps.scan.models import Scan
|
||||
from apps.targets.models import Target
|
||||
from apps.engine.models import ScanEngine
|
||||
from apps.common.definitions import ScanStatus
|
||||
from apps.common.decorators import auto_ensure_db_connection
|
||||
|
||||
@@ -40,7 +39,7 @@ class DjangoScanRepository:
|
||||
|
||||
Args:
|
||||
scan_id: 扫描任务 ID
|
||||
prefetch_relations: 是否预加载关联对象(engine, target)
|
||||
prefetch_relations: 是否预加载关联对象(target, worker)
|
||||
默认 False,只在需要展示关联信息时设为 True
|
||||
for_update: 是否加锁(用于更新场景)
|
||||
|
||||
@@ -56,7 +55,7 @@ class DjangoScanRepository:
|
||||
|
||||
# 预加载关联对象(性能优化:默认不加载)
|
||||
if prefetch_relations:
|
||||
queryset = queryset.select_related('engine', 'target')
|
||||
queryset = queryset.select_related('target', 'worker')
|
||||
|
||||
return queryset.get(id=scan_id)
|
||||
except Scan.DoesNotExist: # type: ignore # pylint: disable=no-member
|
||||
@@ -79,7 +78,7 @@ class DjangoScanRepository:
|
||||
|
||||
Note:
|
||||
- 使用默认的阻塞模式(等待锁释放)
|
||||
- 不包含关联对象(engine, target),如需关联对象请使用 get_by_id()
|
||||
- 不包含关联对象(target, worker),如需关联对象请使用 get_by_id()
|
||||
"""
|
||||
try:
|
||||
return Scan.objects.select_for_update().get(id=scan_id) # type: ignore # pylint: disable=no-member
|
||||
@@ -103,7 +102,9 @@ class DjangoScanRepository:
|
||||
|
||||
def create(self,
|
||||
target: Target,
|
||||
engine: ScanEngine,
|
||||
engine_ids: List[int],
|
||||
engine_names: List[str],
|
||||
yaml_configuration: str,
|
||||
results_dir: str,
|
||||
status: ScanStatus = ScanStatus.INITIATED
|
||||
) -> Scan:
|
||||
@@ -112,7 +113,9 @@ class DjangoScanRepository:
|
||||
|
||||
Args:
|
||||
target: 扫描目标
|
||||
engine: 扫描引擎
|
||||
engine_ids: 引擎 ID 列表
|
||||
engine_names: 引擎名称列表
|
||||
yaml_configuration: YAML 格式的扫描配置
|
||||
results_dir: 结果目录
|
||||
status: 初始状态
|
||||
|
||||
@@ -121,7 +124,9 @@ class DjangoScanRepository:
|
||||
"""
|
||||
scan = Scan(
|
||||
target=target,
|
||||
engine=engine,
|
||||
engine_ids=engine_ids,
|
||||
engine_names=engine_names,
|
||||
yaml_configuration=yaml_configuration,
|
||||
results_dir=results_dir,
|
||||
status=status,
|
||||
container_ids=[]
|
||||
@@ -231,14 +236,14 @@ class DjangoScanRepository:
|
||||
获取所有扫描任务
|
||||
|
||||
Args:
|
||||
prefetch_relations: 是否预加载关联对象(engine, target)
|
||||
prefetch_relations: 是否预加载关联对象(target, worker)
|
||||
|
||||
Returns:
|
||||
Scan QuerySet
|
||||
"""
|
||||
queryset = Scan.objects.all() # type: ignore # pylint: disable=no-member
|
||||
if prefetch_relations:
|
||||
queryset = queryset.select_related('engine', 'target')
|
||||
queryset = queryset.select_related('target', 'worker')
|
||||
return queryset.order_by('-created_at')
|
||||
|
||||
|
||||
|
||||
@@ -29,7 +29,9 @@ class ScheduledScanDTO:
|
||||
"""
|
||||
id: Optional[int] = None
|
||||
name: str = ''
|
||||
engine_id: int = 0
|
||||
engine_ids: List[int] = None # 多引擎支持
|
||||
engine_names: List[str] = None # 引擎名称列表
|
||||
yaml_configuration: str = '' # YAML 格式的扫描配置
|
||||
organization_id: Optional[int] = None # 组织扫描模式
|
||||
target_id: Optional[int] = None # 目标扫描模式
|
||||
cron_expression: Optional[str] = None
|
||||
@@ -40,6 +42,11 @@ class ScheduledScanDTO:
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.engine_ids is None:
|
||||
self.engine_ids = []
|
||||
if self.engine_names is None:
|
||||
self.engine_names = []
|
||||
|
||||
|
||||
@auto_ensure_db_connection
|
||||
@@ -56,7 +63,7 @@ class DjangoScheduledScanRepository:
|
||||
def get_by_id(self, scheduled_scan_id: int) -> Optional[ScheduledScan]:
|
||||
"""根据 ID 查询定时扫描任务"""
|
||||
try:
|
||||
return ScheduledScan.objects.select_related('engine', 'organization', 'target').get(id=scheduled_scan_id)
|
||||
return ScheduledScan.objects.select_related('organization', 'target').get(id=scheduled_scan_id)
|
||||
except ScheduledScan.DoesNotExist:
|
||||
return None
|
||||
|
||||
@@ -67,7 +74,7 @@ class DjangoScheduledScanRepository:
|
||||
Returns:
|
||||
QuerySet
|
||||
"""
|
||||
return ScheduledScan.objects.select_related('engine', 'organization', 'target').order_by('-created_at')
|
||||
return ScheduledScan.objects.select_related('organization', 'target').order_by('-created_at')
|
||||
|
||||
def get_all(self, page: int = 1, page_size: int = 10) -> Tuple[List[ScheduledScan], int]:
|
||||
"""
|
||||
@@ -87,7 +94,7 @@ class DjangoScheduledScanRepository:
|
||||
def get_enabled(self) -> List[ScheduledScan]:
|
||||
"""获取所有启用的定时扫描任务"""
|
||||
return list(
|
||||
ScheduledScan.objects.select_related('engine', 'target')
|
||||
ScheduledScan.objects.select_related('target')
|
||||
.filter(is_enabled=True)
|
||||
.order_by('-created_at')
|
||||
)
|
||||
@@ -105,7 +112,9 @@ class DjangoScheduledScanRepository:
|
||||
with transaction.atomic():
|
||||
scheduled_scan = ScheduledScan.objects.create(
|
||||
name=dto.name,
|
||||
engine_id=dto.engine_id,
|
||||
engine_ids=dto.engine_ids,
|
||||
engine_names=dto.engine_names,
|
||||
yaml_configuration=dto.yaml_configuration,
|
||||
organization_id=dto.organization_id, # 组织扫描模式
|
||||
target_id=dto.target_id if not dto.organization_id else None, # 目标扫描模式
|
||||
cron_expression=dto.cron_expression,
|
||||
@@ -134,8 +143,12 @@ class DjangoScheduledScanRepository:
|
||||
# 更新基本字段
|
||||
if dto.name:
|
||||
scheduled_scan.name = dto.name
|
||||
if dto.engine_id:
|
||||
scheduled_scan.engine_id = dto.engine_id
|
||||
if dto.engine_ids is not None:
|
||||
scheduled_scan.engine_ids = dto.engine_ids
|
||||
if dto.engine_names is not None:
|
||||
scheduled_scan.engine_names = dto.engine_names
|
||||
if dto.yaml_configuration is not None:
|
||||
scheduled_scan.yaml_configuration = dto.yaml_configuration
|
||||
if dto.cron_expression is not None:
|
||||
scheduled_scan.cron_expression = dto.cron_expression
|
||||
if dto.is_enabled is not None:
|
||||
|
||||
@@ -1,245 +0,0 @@
|
||||
from rest_framework import serializers
|
||||
from django.db.models import Count
|
||||
|
||||
from .models import Scan, ScheduledScan
|
||||
|
||||
|
||||
class ScanSerializer(serializers.ModelSerializer):
|
||||
"""扫描任务序列化器"""
|
||||
target_name = serializers.SerializerMethodField()
|
||||
engine_name = serializers.SerializerMethodField()
|
||||
|
||||
class Meta:
|
||||
model = Scan
|
||||
fields = [
|
||||
'id', 'target', 'target_name', 'engine', 'engine_name',
|
||||
'created_at', 'stopped_at', 'status', 'results_dir',
|
||||
'container_ids', 'error_message'
|
||||
]
|
||||
read_only_fields = [
|
||||
'id', 'created_at', 'stopped_at', 'results_dir',
|
||||
'container_ids', 'error_message', 'status'
|
||||
]
|
||||
|
||||
def get_target_name(self, obj):
|
||||
"""获取目标名称"""
|
||||
return obj.target.name if obj.target else None
|
||||
|
||||
def get_engine_name(self, obj):
|
||||
"""获取引擎名称"""
|
||||
return obj.engine.name if obj.engine else None
|
||||
|
||||
|
||||
class ScanHistorySerializer(serializers.ModelSerializer):
|
||||
"""扫描历史列表专用序列化器
|
||||
|
||||
为前端扫描历史页面提供优化的数据格式,包括:
|
||||
- 扫描汇总统计(子域名、端点、漏洞数量)
|
||||
- 进度百分比和当前阶段
|
||||
"""
|
||||
|
||||
# 字段映射
|
||||
target_name = serializers.CharField(source='target.name', read_only=True)
|
||||
engine_name = serializers.CharField(source='engine.name', read_only=True)
|
||||
|
||||
# 计算字段
|
||||
summary = serializers.SerializerMethodField()
|
||||
|
||||
# 进度跟踪字段(直接从模型读取)
|
||||
progress = serializers.IntegerField(read_only=True)
|
||||
current_stage = serializers.CharField(read_only=True)
|
||||
stage_progress = serializers.JSONField(read_only=True)
|
||||
|
||||
class Meta:
|
||||
model = Scan
|
||||
fields = [
|
||||
'id', 'target', 'target_name', 'engine', 'engine_name',
|
||||
'created_at', 'status', 'error_message', 'summary', 'progress',
|
||||
'current_stage', 'stage_progress'
|
||||
]
|
||||
|
||||
def get_summary(self, obj):
|
||||
"""获取扫描汇总数据。
|
||||
|
||||
设计原则:
|
||||
- 子域名/网站/端点/IP/目录使用缓存字段(避免实时 COUNT)
|
||||
- 漏洞统计使用 Scan 上的缓存字段,在扫描结束时统一聚合
|
||||
"""
|
||||
# 1. 使用缓存字段构建基础统计(子域名、网站、端点、IP、目录)
|
||||
summary = {
|
||||
'subdomains': obj.cached_subdomains_count or 0,
|
||||
'websites': obj.cached_websites_count or 0,
|
||||
'endpoints': obj.cached_endpoints_count or 0,
|
||||
'ips': obj.cached_ips_count or 0,
|
||||
'directories': obj.cached_directories_count or 0,
|
||||
}
|
||||
|
||||
# 2. 使用 Scan 模型上的缓存漏洞统计(按严重性聚合)
|
||||
summary['vulnerabilities'] = {
|
||||
'total': obj.cached_vulns_total or 0,
|
||||
'critical': obj.cached_vulns_critical or 0,
|
||||
'high': obj.cached_vulns_high or 0,
|
||||
'medium': obj.cached_vulns_medium or 0,
|
||||
'low': obj.cached_vulns_low or 0,
|
||||
}
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
class QuickScanSerializer(serializers.Serializer):
|
||||
"""
|
||||
快速扫描序列化器
|
||||
|
||||
功能:
|
||||
- 接收目标列表和引擎配置
|
||||
- 自动创建/获取目标
|
||||
- 立即发起扫描
|
||||
"""
|
||||
|
||||
# 批量创建的最大数量限制
|
||||
MAX_BATCH_SIZE = 1000
|
||||
|
||||
# 目标列表
|
||||
targets = serializers.ListField(
|
||||
child=serializers.DictField(),
|
||||
help_text='目标列表,每个目标包含 name 字段'
|
||||
)
|
||||
|
||||
# 扫描引擎 ID
|
||||
engine_id = serializers.IntegerField(
|
||||
required=True,
|
||||
help_text='使用的扫描引擎 ID (必填)'
|
||||
)
|
||||
|
||||
def validate_targets(self, value):
|
||||
"""验证目标列表"""
|
||||
if not value:
|
||||
raise serializers.ValidationError("目标列表不能为空")
|
||||
|
||||
# 检查数量限制,防止服务器过载
|
||||
if len(value) > self.MAX_BATCH_SIZE:
|
||||
raise serializers.ValidationError(
|
||||
f"快速扫描最多支持 {self.MAX_BATCH_SIZE} 个目标,当前提交了 {len(value)} 个"
|
||||
)
|
||||
|
||||
# 验证每个目标的必填字段
|
||||
for idx, target in enumerate(value):
|
||||
if 'name' not in target:
|
||||
raise serializers.ValidationError(f"第 {idx + 1} 个目标缺少 name 字段")
|
||||
if not target['name']:
|
||||
raise serializers.ValidationError(f"第 {idx + 1} 个目标的 name 不能为空")
|
||||
|
||||
return value
|
||||
|
||||
|
||||
# ==================== 定时扫描序列化器 ====================
|
||||
|
||||
class ScheduledScanSerializer(serializers.ModelSerializer):
|
||||
"""定时扫描任务序列化器(用于列表和详情)"""
|
||||
|
||||
# 关联字段
|
||||
engine_name = serializers.CharField(source='engine.name', read_only=True)
|
||||
organization_id = serializers.IntegerField(source='organization.id', read_only=True, allow_null=True)
|
||||
organization_name = serializers.CharField(source='organization.name', read_only=True, allow_null=True)
|
||||
target_id = serializers.IntegerField(source='target.id', read_only=True, allow_null=True)
|
||||
target_name = serializers.CharField(source='target.name', read_only=True, allow_null=True)
|
||||
scan_mode = serializers.SerializerMethodField()
|
||||
|
||||
class Meta:
|
||||
model = ScheduledScan
|
||||
fields = [
|
||||
'id', 'name',
|
||||
'engine', 'engine_name',
|
||||
'organization_id', 'organization_name',
|
||||
'target_id', 'target_name',
|
||||
'scan_mode',
|
||||
'cron_expression',
|
||||
'is_enabled',
|
||||
'run_count', 'last_run_time', 'next_run_time',
|
||||
'created_at', 'updated_at'
|
||||
]
|
||||
read_only_fields = [
|
||||
'id', 'run_count',
|
||||
'last_run_time', 'next_run_time',
|
||||
'created_at', 'updated_at'
|
||||
]
|
||||
|
||||
def get_scan_mode(self, obj):
|
||||
"""获取扫描模式:organization 或 target"""
|
||||
return 'organization' if obj.organization_id else 'target'
|
||||
|
||||
|
||||
class CreateScheduledScanSerializer(serializers.Serializer):
|
||||
"""创建定时扫描任务序列化器
|
||||
|
||||
扫描模式(二选一):
|
||||
- 组织扫描:提供 organization_id,执行时动态获取组织下所有目标
|
||||
- 目标扫描:提供 target_id,扫描单个目标
|
||||
"""
|
||||
|
||||
name = serializers.CharField(max_length=200, help_text='任务名称')
|
||||
engine_id = serializers.IntegerField(help_text='扫描引擎 ID')
|
||||
|
||||
# 组织扫描模式
|
||||
organization_id = serializers.IntegerField(
|
||||
required=False,
|
||||
allow_null=True,
|
||||
help_text='组织 ID(组织扫描模式:执行时动态获取组织下所有目标)'
|
||||
)
|
||||
|
||||
# 目标扫描模式
|
||||
target_id = serializers.IntegerField(
|
||||
required=False,
|
||||
allow_null=True,
|
||||
help_text='目标 ID(目标扫描模式:扫描单个目标)'
|
||||
)
|
||||
|
||||
cron_expression = serializers.CharField(
|
||||
max_length=100,
|
||||
default='0 2 * * *',
|
||||
help_text='Cron 表达式,格式:分 时 日 月 周'
|
||||
)
|
||||
is_enabled = serializers.BooleanField(default=True, help_text='是否立即启用')
|
||||
|
||||
def validate(self, data):
|
||||
"""验证 organization_id 和 target_id 互斥"""
|
||||
organization_id = data.get('organization_id')
|
||||
target_id = data.get('target_id')
|
||||
|
||||
if not organization_id and not target_id:
|
||||
raise serializers.ValidationError('必须提供 organization_id 或 target_id 其中之一')
|
||||
|
||||
if organization_id and target_id:
|
||||
raise serializers.ValidationError('organization_id 和 target_id 只能提供其中之一')
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class UpdateScheduledScanSerializer(serializers.Serializer):
|
||||
"""更新定时扫描任务序列化器"""
|
||||
|
||||
name = serializers.CharField(max_length=200, required=False, help_text='任务名称')
|
||||
engine_id = serializers.IntegerField(required=False, help_text='扫描引擎 ID')
|
||||
|
||||
# 组织扫描模式
|
||||
organization_id = serializers.IntegerField(
|
||||
required=False,
|
||||
allow_null=True,
|
||||
help_text='组织 ID(设置后清空 target_id)'
|
||||
)
|
||||
|
||||
# 目标扫描模式
|
||||
target_id = serializers.IntegerField(
|
||||
required=False,
|
||||
allow_null=True,
|
||||
help_text='目标 ID(设置后清空 organization_id)'
|
||||
)
|
||||
|
||||
cron_expression = serializers.CharField(max_length=100, required=False, help_text='Cron 表达式')
|
||||
is_enabled = serializers.BooleanField(required=False, help_text='是否启用')
|
||||
|
||||
|
||||
class ToggleScheduledScanSerializer(serializers.Serializer):
|
||||
"""切换定时扫描启用状态序列化器"""
|
||||
|
||||
is_enabled = serializers.BooleanField(help_text='是否启用')
|
||||
40
backend/apps/scan/serializers/__init__.py
Normal file
40
backend/apps/scan/serializers/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Scan Serializers - 统一导出"""
|
||||
|
||||
from .mixins import ScanConfigValidationMixin
|
||||
from .scan_serializers import (
|
||||
ScanSerializer,
|
||||
ScanHistorySerializer,
|
||||
QuickScanSerializer,
|
||||
InitiateScanSerializer,
|
||||
)
|
||||
from .scan_log_serializers import ScanLogSerializer
|
||||
from .scheduled_scan_serializers import (
|
||||
ScheduledScanSerializer,
|
||||
CreateScheduledScanSerializer,
|
||||
UpdateScheduledScanSerializer,
|
||||
ToggleScheduledScanSerializer,
|
||||
)
|
||||
from .subfinder_provider_settings_serializers import SubfinderProviderSettingsSerializer
|
||||
|
||||
# 兼容旧名称
|
||||
ProviderSettingsSerializer = SubfinderProviderSettingsSerializer
|
||||
|
||||
__all__ = [
|
||||
# Mixins
|
||||
'ScanConfigValidationMixin',
|
||||
# Scan
|
||||
'ScanSerializer',
|
||||
'ScanHistorySerializer',
|
||||
'QuickScanSerializer',
|
||||
'InitiateScanSerializer',
|
||||
# ScanLog
|
||||
'ScanLogSerializer',
|
||||
# Scheduled Scan
|
||||
'ScheduledScanSerializer',
|
||||
'CreateScheduledScanSerializer',
|
||||
'UpdateScheduledScanSerializer',
|
||||
'ToggleScheduledScanSerializer',
|
||||
# Subfinder Provider Settings
|
||||
'SubfinderProviderSettingsSerializer',
|
||||
'ProviderSettingsSerializer', # 兼容旧名称
|
||||
]
|
||||
57
backend/apps/scan/serializers/mixins.py
Normal file
57
backend/apps/scan/serializers/mixins.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""序列化器通用 Mixin 和工具类"""
|
||||
|
||||
from rest_framework import serializers
|
||||
import yaml
|
||||
|
||||
|
||||
class DuplicateKeyLoader(yaml.SafeLoader):
|
||||
"""自定义 YAML Loader,检测重复 key"""
|
||||
pass
|
||||
|
||||
|
||||
def _check_duplicate_keys(loader, node, deep=False):
|
||||
"""检测 YAML mapping 中的重复 key"""
|
||||
mapping = {}
|
||||
for key_node, value_node in node.value:
|
||||
key = loader.construct_object(key_node, deep=deep)
|
||||
if key in mapping:
|
||||
raise yaml.constructor.ConstructorError(
|
||||
"while constructing a mapping", node.start_mark,
|
||||
f"发现重复的配置项 '{key}',后面的配置会覆盖前面的配置,请删除重复项", key_node.start_mark
|
||||
)
|
||||
mapping[key] = loader.construct_object(value_node, deep=deep)
|
||||
return mapping
|
||||
|
||||
|
||||
DuplicateKeyLoader.add_constructor(
|
||||
yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG,
|
||||
_check_duplicate_keys
|
||||
)
|
||||
|
||||
|
||||
class ScanConfigValidationMixin:
|
||||
"""扫描配置验证 Mixin"""
|
||||
|
||||
def validate_configuration(self, value):
|
||||
"""验证 YAML 配置格式"""
|
||||
if not value or not value.strip():
|
||||
raise serializers.ValidationError("configuration 不能为空")
|
||||
|
||||
try:
|
||||
yaml.load(value, Loader=DuplicateKeyLoader)
|
||||
except yaml.YAMLError as e:
|
||||
raise serializers.ValidationError(f"无效的 YAML 格式: {str(e)}")
|
||||
|
||||
return value
|
||||
|
||||
def validate_engine_ids(self, value):
|
||||
"""验证引擎 ID 列表"""
|
||||
if not value:
|
||||
raise serializers.ValidationError("engine_ids 不能为空,请至少选择一个扫描引擎")
|
||||
return value
|
||||
|
||||
def validate_engine_names(self, value):
|
||||
"""验证引擎名称列表"""
|
||||
if not value:
|
||||
raise serializers.ValidationError("engine_names 不能为空")
|
||||
return value
|
||||
13
backend/apps/scan/serializers/scan_log_serializers.py
Normal file
13
backend/apps/scan/serializers/scan_log_serializers.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""扫描日志序列化器"""
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from ..models import ScanLog
|
||||
|
||||
|
||||
class ScanLogSerializer(serializers.ModelSerializer):
|
||||
"""扫描日志序列化器"""
|
||||
|
||||
class Meta:
|
||||
model = ScanLog
|
||||
fields = ['id', 'level', 'content', 'created_at']
|
||||
111
backend/apps/scan/serializers/scan_serializers.py
Normal file
111
backend/apps/scan/serializers/scan_serializers.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""扫描任务序列化器"""
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from ..models import Scan
|
||||
from .mixins import ScanConfigValidationMixin
|
||||
|
||||
|
||||
class ScanSerializer(serializers.ModelSerializer):
|
||||
"""扫描任务序列化器"""
|
||||
target_name = serializers.SerializerMethodField()
|
||||
|
||||
class Meta:
|
||||
model = Scan
|
||||
fields = [
|
||||
'id', 'target', 'target_name', 'engine_ids', 'engine_names',
|
||||
'created_at', 'stopped_at', 'status', 'results_dir',
|
||||
'container_ids', 'error_message'
|
||||
]
|
||||
read_only_fields = [
|
||||
'id', 'created_at', 'stopped_at', 'results_dir',
|
||||
'container_ids', 'error_message', 'status'
|
||||
]
|
||||
|
||||
def get_target_name(self, obj):
|
||||
return obj.target.name if obj.target else None
|
||||
|
||||
|
||||
class ScanHistorySerializer(serializers.ModelSerializer):
|
||||
"""扫描历史列表序列化器"""
|
||||
|
||||
target_name = serializers.CharField(source='target.name', read_only=True)
|
||||
worker_name = serializers.CharField(source='worker.name', read_only=True, allow_null=True)
|
||||
summary = serializers.SerializerMethodField()
|
||||
progress = serializers.IntegerField(read_only=True)
|
||||
current_stage = serializers.CharField(read_only=True)
|
||||
stage_progress = serializers.JSONField(read_only=True)
|
||||
|
||||
class Meta:
|
||||
model = Scan
|
||||
fields = [
|
||||
'id', 'target', 'target_name', 'engine_ids', 'engine_names',
|
||||
'worker_name', 'created_at', 'status', 'error_message', 'summary',
|
||||
'progress', 'current_stage', 'stage_progress'
|
||||
]
|
||||
|
||||
def get_summary(self, obj):
|
||||
summary = {
|
||||
'subdomains': obj.cached_subdomains_count or 0,
|
||||
'websites': obj.cached_websites_count or 0,
|
||||
'endpoints': obj.cached_endpoints_count or 0,
|
||||
'ips': obj.cached_ips_count or 0,
|
||||
'directories': obj.cached_directories_count or 0,
|
||||
}
|
||||
summary['vulnerabilities'] = {
|
||||
'total': obj.cached_vulns_total or 0,
|
||||
'critical': obj.cached_vulns_critical or 0,
|
||||
'high': obj.cached_vulns_high or 0,
|
||||
'medium': obj.cached_vulns_medium or 0,
|
||||
'low': obj.cached_vulns_low or 0,
|
||||
}
|
||||
return summary
|
||||
|
||||
|
||||
class QuickScanSerializer(ScanConfigValidationMixin, serializers.Serializer):
|
||||
"""快速扫描序列化器"""
|
||||
|
||||
MAX_BATCH_SIZE = 5000
|
||||
|
||||
targets = serializers.ListField(
|
||||
child=serializers.DictField(),
|
||||
help_text='目标列表,每个目标包含 name 字段'
|
||||
)
|
||||
configuration = serializers.CharField(required=True, help_text='YAML 格式的扫描配置')
|
||||
engine_ids = serializers.ListField(child=serializers.IntegerField(), required=True)
|
||||
engine_names = serializers.ListField(child=serializers.CharField(), required=True)
|
||||
|
||||
def validate_targets(self, value):
|
||||
if not value:
|
||||
raise serializers.ValidationError("目标列表不能为空")
|
||||
if len(value) > self.MAX_BATCH_SIZE:
|
||||
raise serializers.ValidationError(
|
||||
f"快速扫描最多支持 {self.MAX_BATCH_SIZE} 个目标,当前提交了 {len(value)} 个"
|
||||
)
|
||||
for idx, target in enumerate(value):
|
||||
if 'name' not in target:
|
||||
raise serializers.ValidationError(f"第 {idx + 1} 个目标缺少 name 字段")
|
||||
if not target['name']:
|
||||
raise serializers.ValidationError(f"第 {idx + 1} 个目标的 name 不能为空")
|
||||
return value
|
||||
|
||||
|
||||
class InitiateScanSerializer(ScanConfigValidationMixin, serializers.Serializer):
|
||||
"""发起扫描任务序列化器"""
|
||||
|
||||
configuration = serializers.CharField(required=True, help_text='YAML 格式的扫描配置')
|
||||
engine_ids = serializers.ListField(child=serializers.IntegerField(), required=True)
|
||||
engine_names = serializers.ListField(child=serializers.CharField(), required=True)
|
||||
organization_id = serializers.IntegerField(required=False, allow_null=True)
|
||||
target_id = serializers.IntegerField(required=False, allow_null=True)
|
||||
|
||||
def validate(self, data):
|
||||
organization_id = data.get('organization_id')
|
||||
target_id = data.get('target_id')
|
||||
|
||||
if not organization_id and not target_id:
|
||||
raise serializers.ValidationError('必须提供 organization_id 或 target_id 其中之一')
|
||||
if organization_id and target_id:
|
||||
raise serializers.ValidationError('organization_id 和 target_id 只能提供其中之一')
|
||||
|
||||
return data
|
||||
84
backend/apps/scan/serializers/scheduled_scan_serializers.py
Normal file
84
backend/apps/scan/serializers/scheduled_scan_serializers.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""定时扫描序列化器"""
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from ..models import ScheduledScan
|
||||
from .mixins import ScanConfigValidationMixin
|
||||
|
||||
|
||||
class ScheduledScanSerializer(serializers.ModelSerializer):
|
||||
"""定时扫描任务序列化器(用于列表和详情)"""
|
||||
|
||||
organization_id = serializers.IntegerField(source='organization.id', read_only=True, allow_null=True)
|
||||
organization_name = serializers.CharField(source='organization.name', read_only=True, allow_null=True)
|
||||
target_id = serializers.IntegerField(source='target.id', read_only=True, allow_null=True)
|
||||
target_name = serializers.CharField(source='target.name', read_only=True, allow_null=True)
|
||||
scan_mode = serializers.SerializerMethodField()
|
||||
|
||||
class Meta:
|
||||
model = ScheduledScan
|
||||
fields = [
|
||||
'id', 'name',
|
||||
'engine_ids', 'engine_names',
|
||||
'organization_id', 'organization_name',
|
||||
'target_id', 'target_name',
|
||||
'scan_mode',
|
||||
'cron_expression',
|
||||
'is_enabled',
|
||||
'run_count', 'last_run_time', 'next_run_time',
|
||||
'created_at', 'updated_at'
|
||||
]
|
||||
read_only_fields = [
|
||||
'id', 'run_count',
|
||||
'last_run_time', 'next_run_time',
|
||||
'created_at', 'updated_at'
|
||||
]
|
||||
|
||||
def get_scan_mode(self, obj):
|
||||
return 'organization' if obj.organization_id else 'target'
|
||||
|
||||
|
||||
class CreateScheduledScanSerializer(ScanConfigValidationMixin, serializers.Serializer):
|
||||
"""创建定时扫描任务序列化器"""
|
||||
|
||||
name = serializers.CharField(max_length=200, help_text='任务名称')
|
||||
configuration = serializers.CharField(required=True, help_text='YAML 格式的扫描配置')
|
||||
engine_ids = serializers.ListField(child=serializers.IntegerField(), required=True)
|
||||
engine_names = serializers.ListField(child=serializers.CharField(), required=True)
|
||||
organization_id = serializers.IntegerField(required=False, allow_null=True)
|
||||
target_id = serializers.IntegerField(required=False, allow_null=True)
|
||||
cron_expression = serializers.CharField(max_length=100, default='0 2 * * *')
|
||||
is_enabled = serializers.BooleanField(default=True)
|
||||
|
||||
def validate(self, data):
|
||||
organization_id = data.get('organization_id')
|
||||
target_id = data.get('target_id')
|
||||
|
||||
if not organization_id and not target_id:
|
||||
raise serializers.ValidationError('必须提供 organization_id 或 target_id 其中之一')
|
||||
if organization_id and target_id:
|
||||
raise serializers.ValidationError('organization_id 和 target_id 只能提供其中之一')
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class UpdateScheduledScanSerializer(serializers.Serializer):
|
||||
"""更新定时扫描任务序列化器"""
|
||||
|
||||
name = serializers.CharField(max_length=200, required=False)
|
||||
engine_ids = serializers.ListField(child=serializers.IntegerField(), required=False)
|
||||
organization_id = serializers.IntegerField(required=False, allow_null=True)
|
||||
target_id = serializers.IntegerField(required=False, allow_null=True)
|
||||
cron_expression = serializers.CharField(max_length=100, required=False)
|
||||
is_enabled = serializers.BooleanField(required=False)
|
||||
|
||||
def validate_engine_ids(self, value):
|
||||
if value is not None and not value:
|
||||
raise serializers.ValidationError("engine_ids 不能为空")
|
||||
return value
|
||||
|
||||
|
||||
class ToggleScheduledScanSerializer(serializers.Serializer):
|
||||
"""切换定时扫描启用状态序列化器"""
|
||||
|
||||
is_enabled = serializers.BooleanField(help_text='是否启用')
|
||||
@@ -0,0 +1,55 @@
|
||||
"""Subfinder Provider 配置序列化器"""
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
|
||||
class SubfinderProviderSettingsSerializer(serializers.Serializer):
|
||||
"""Subfinder Provider 配置序列化器
|
||||
|
||||
支持的 Provider:
|
||||
- fofa: email + api_key (composite)
|
||||
- censys: api_id + api_secret (composite)
|
||||
- hunter, shodan, zoomeye, securitytrails, threatbook, quake: api_key (single)
|
||||
|
||||
注意:djangorestframework-camel-case 会自动处理 camelCase <-> snake_case 转换
|
||||
所以这里统一使用 snake_case
|
||||
"""
|
||||
|
||||
VALID_PROVIDERS = {
|
||||
'fofa', 'hunter', 'shodan', 'censys',
|
||||
'zoomeye', 'securitytrails', 'threatbook', 'quake'
|
||||
}
|
||||
|
||||
def to_internal_value(self, data):
|
||||
"""验证并转换输入数据"""
|
||||
if not isinstance(data, dict):
|
||||
raise serializers.ValidationError('Expected a dictionary')
|
||||
|
||||
result = {}
|
||||
for provider, config in data.items():
|
||||
if provider not in self.VALID_PROVIDERS:
|
||||
continue
|
||||
|
||||
if not isinstance(config, dict):
|
||||
continue
|
||||
|
||||
db_config = {'enabled': bool(config.get('enabled', False))}
|
||||
|
||||
if provider == 'fofa':
|
||||
db_config['email'] = str(config.get('email', ''))
|
||||
db_config['api_key'] = str(config.get('api_key', ''))
|
||||
elif provider == 'censys':
|
||||
db_config['api_id'] = str(config.get('api_id', ''))
|
||||
db_config['api_secret'] = str(config.get('api_secret', ''))
|
||||
else:
|
||||
db_config['api_key'] = str(config.get('api_key', ''))
|
||||
|
||||
result[provider] = db_config
|
||||
|
||||
return result
|
||||
|
||||
def to_representation(self, instance):
|
||||
"""输出数据(数据库格式,camel-case 中间件会自动转换)"""
|
||||
if isinstance(instance, dict):
|
||||
return instance
|
||||
return instance.providers if hasattr(instance, 'providers') else {}
|
||||
@@ -17,7 +17,6 @@ from .scan_state_service import ScanStateService
|
||||
from .scan_control_service import ScanControlService
|
||||
from .scan_stats_service import ScanStatsService
|
||||
from .scheduled_scan_service import ScheduledScanService
|
||||
from .blacklist_service import BlacklistService
|
||||
from .target_export_service import TargetExportService
|
||||
|
||||
__all__ = [
|
||||
@@ -27,7 +26,6 @@ __all__ = [
|
||||
'ScanControlService',
|
||||
'ScanStatsService',
|
||||
'ScheduledScanService',
|
||||
'BlacklistService', # 黑名单过滤服务
|
||||
'TargetExportService', # 目标导出服务
|
||||
]
|
||||
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
"""
|
||||
黑名单过滤服务
|
||||
|
||||
过滤敏感域名(如 .gov、.edu、.mil 等)
|
||||
|
||||
当前版本使用默认规则,后续将支持从前端配置加载。
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from django.db.models import QuerySet
|
||||
import re
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BlacklistService:
|
||||
"""
|
||||
黑名单过滤服务 - 过滤敏感域名
|
||||
|
||||
TODO: 后续版本支持从前端配置加载黑名单规则
|
||||
- 用户在开始扫描时配置黑名单 URL、域名、IP
|
||||
- 黑名单规则存储在数据库中,与 Scan 或 Engine 关联
|
||||
"""
|
||||
|
||||
# 默认黑名单正则规则
|
||||
DEFAULT_PATTERNS = [
|
||||
r'\.gov$', # .gov 结尾
|
||||
r'\.gov\.[a-z]{2}$', # .gov.cn, .gov.uk 等
|
||||
]
|
||||
|
||||
def __init__(self, patterns: Optional[List[str]] = None):
|
||||
"""
|
||||
初始化黑名单服务
|
||||
|
||||
Args:
|
||||
patterns: 正则表达式列表,None 使用默认规则
|
||||
"""
|
||||
self.patterns = patterns or self.DEFAULT_PATTERNS
|
||||
self._compiled_patterns = [re.compile(p) for p in self.patterns]
|
||||
|
||||
def filter_queryset(
|
||||
self,
|
||||
queryset: QuerySet,
|
||||
url_field: str = 'url'
|
||||
) -> QuerySet:
|
||||
"""
|
||||
数据库层面过滤 queryset
|
||||
|
||||
使用 PostgreSQL 正则表达式排除黑名单 URL
|
||||
|
||||
Args:
|
||||
queryset: 原始 queryset
|
||||
url_field: URL 字段名
|
||||
|
||||
Returns:
|
||||
QuerySet: 过滤后的 queryset
|
||||
"""
|
||||
for pattern in self.patterns:
|
||||
queryset = queryset.exclude(**{f'{url_field}__regex': pattern})
|
||||
return queryset
|
||||
|
||||
def filter_url(self, url: str) -> bool:
|
||||
"""
|
||||
检查单个 URL 是否通过黑名单过滤
|
||||
|
||||
Args:
|
||||
url: 要检查的 URL
|
||||
|
||||
Returns:
|
||||
bool: True 表示通过(不在黑名单),False 表示被过滤
|
||||
"""
|
||||
for pattern in self._compiled_patterns:
|
||||
if pattern.search(url):
|
||||
return False
|
||||
return True
|
||||
|
||||
# TODO: 后续版本实现
|
||||
# @classmethod
|
||||
# def from_scan(cls, scan_id: int) -> 'BlacklistService':
|
||||
# """从数据库加载扫描配置的黑名单规则"""
|
||||
# pass
|
||||
@@ -10,7 +10,7 @@
|
||||
import uuid
|
||||
import logging
|
||||
import threading
|
||||
from typing import List
|
||||
from typing import List, Tuple
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from django.conf import settings
|
||||
@@ -20,6 +20,7 @@ from django.core.exceptions import ValidationError, ObjectDoesNotExist
|
||||
|
||||
from apps.scan.models import Scan
|
||||
from apps.scan.repositories import DjangoScanRepository
|
||||
from apps.scan.utils.config_merger import merge_engine_configs, ConfigConflictError
|
||||
from apps.targets.repositories import DjangoTargetRepository, DjangoOrganizationRepository
|
||||
from apps.engine.repositories import DjangoEngineRepository
|
||||
from apps.targets.models import Target
|
||||
@@ -142,6 +143,106 @@ class ScanCreationService:
|
||||
|
||||
return targets, engine
|
||||
|
||||
def prepare_initiate_scan_multi_engine(
|
||||
self,
|
||||
organization_id: int | None = None,
|
||||
target_id: int | None = None,
|
||||
engine_ids: List[int] | None = None
|
||||
) -> Tuple[List[Target], str, List[str], List[int]]:
|
||||
"""
|
||||
准备多引擎扫描任务所需的数据
|
||||
|
||||
职责:
|
||||
1. 参数验证(必填项、互斥参数)
|
||||
2. 资源查询(Engines、Organization、Target)
|
||||
3. 合并引擎配置(检测冲突)
|
||||
4. 返回准备好的目标列表、合并配置和引擎信息
|
||||
|
||||
Args:
|
||||
organization_id: 组织ID(可选)
|
||||
target_id: 目标ID(可选)
|
||||
engine_ids: 扫描引擎ID列表(必填)
|
||||
|
||||
Returns:
|
||||
(目标列表, 合并配置, 引擎名称列表, 引擎ID列表) - 供 create_scans 方法使用
|
||||
|
||||
Raises:
|
||||
ValidationError: 参数验证失败或业务规则不满足
|
||||
ObjectDoesNotExist: 资源不存在(Organization/Target/ScanEngine)
|
||||
ConfigConflictError: 引擎配置存在冲突
|
||||
|
||||
Note:
|
||||
- organization_id 和 target_id 必须二选一
|
||||
- 如果提供 organization_id,返回该组织下所有目标
|
||||
- 如果提供 target_id,返回单个目标列表
|
||||
"""
|
||||
# 1. 参数验证
|
||||
if not engine_ids:
|
||||
raise ValidationError('缺少必填参数: engine_ids')
|
||||
|
||||
if not organization_id and not target_id:
|
||||
raise ValidationError('必须提供 organization_id 或 target_id 其中之一')
|
||||
|
||||
if organization_id and target_id:
|
||||
raise ValidationError('organization_id 和 target_id 只能提供其中之一')
|
||||
|
||||
# 2. 查询所有扫描引擎
|
||||
engines = []
|
||||
for engine_id in engine_ids:
|
||||
engine = self.engine_repo.get_by_id(engine_id)
|
||||
if not engine:
|
||||
logger.error("扫描引擎不存在 - Engine ID: %s", engine_id)
|
||||
raise ObjectDoesNotExist(f'ScanEngine ID {engine_id} 不存在')
|
||||
engines.append(engine)
|
||||
|
||||
# 3. 合并引擎配置(可能抛出 ConfigConflictError)
|
||||
engine_configs = [(e.name, e.configuration or '') for e in engines]
|
||||
merged_configuration = merge_engine_configs(engine_configs)
|
||||
engine_names = [e.name for e in engines]
|
||||
|
||||
logger.debug(
|
||||
"引擎配置合并成功 - 引擎: %s",
|
||||
', '.join(engine_names)
|
||||
)
|
||||
|
||||
# 4. 根据参数获取目标列表
|
||||
targets = []
|
||||
|
||||
if organization_id:
|
||||
# 根据组织ID获取所有目标
|
||||
organization = self.organization_repo.get_by_id(organization_id)
|
||||
if not organization:
|
||||
logger.error("组织不存在 - Organization ID: %s", organization_id)
|
||||
raise ObjectDoesNotExist(f'Organization ID {organization_id} 不存在')
|
||||
|
||||
targets = self.organization_repo.get_targets(organization_id)
|
||||
|
||||
if not targets:
|
||||
raise ValidationError(f'组织 ID {organization_id} 下没有目标')
|
||||
|
||||
logger.debug(
|
||||
"准备发起扫描 - 组织: %s, 目标数量: %d, 引擎: %s",
|
||||
organization.name,
|
||||
len(targets),
|
||||
', '.join(engine_names)
|
||||
)
|
||||
else:
|
||||
# 根据目标ID获取单个目标
|
||||
target = self.target_repo.get_by_id(target_id)
|
||||
if not target:
|
||||
logger.error("目标不存在 - Target ID: %s", target_id)
|
||||
raise ObjectDoesNotExist(f'Target ID {target_id} 不存在')
|
||||
|
||||
targets = [target]
|
||||
|
||||
logger.debug(
|
||||
"准备发起扫描 - 目标: %s, 引擎: %s",
|
||||
target.name,
|
||||
', '.join(engine_names)
|
||||
)
|
||||
|
||||
return targets, merged_configuration, engine_names, engine_ids
|
||||
|
||||
def _generate_scan_workspace_dir(self) -> str:
|
||||
"""
|
||||
生成 Scan 工作空间目录路径
|
||||
@@ -179,7 +280,9 @@ class ScanCreationService:
|
||||
def create_scans(
|
||||
self,
|
||||
targets: List[Target],
|
||||
engine: ScanEngine,
|
||||
engine_ids: List[int],
|
||||
engine_names: List[str],
|
||||
yaml_configuration: str,
|
||||
scheduled_scan_name: str | None = None
|
||||
) -> List[Scan]:
|
||||
"""
|
||||
@@ -187,7 +290,9 @@ class ScanCreationService:
|
||||
|
||||
Args:
|
||||
targets: 目标列表
|
||||
engine: 扫描引擎对象
|
||||
engine_ids: 引擎 ID 列表
|
||||
engine_names: 引擎名称列表
|
||||
yaml_configuration: YAML 格式的扫描配置
|
||||
scheduled_scan_name: 定时扫描任务名称(可选,用于通知显示)
|
||||
|
||||
Returns:
|
||||
@@ -205,7 +310,9 @@ class ScanCreationService:
|
||||
scan_workspace_dir = self._generate_scan_workspace_dir()
|
||||
scan = Scan(
|
||||
target=target,
|
||||
engine=engine,
|
||||
engine_ids=engine_ids,
|
||||
engine_names=engine_names,
|
||||
yaml_configuration=yaml_configuration,
|
||||
results_dir=scan_workspace_dir,
|
||||
status=ScanStatus.INITIATED,
|
||||
container_ids=[],
|
||||
@@ -236,13 +343,15 @@ class ScanCreationService:
|
||||
return []
|
||||
|
||||
# 第三步:分发任务到 Workers
|
||||
# 使用第一个引擎名称作为显示名称,或者合并显示
|
||||
display_engine_name = ', '.join(engine_names) if engine_names else ''
|
||||
scan_data = [
|
||||
{
|
||||
'scan_id': scan.id,
|
||||
'target_name': scan.target.name,
|
||||
'target_id': scan.target.id,
|
||||
'results_dir': scan.results_dir,
|
||||
'engine_name': scan.engine.name,
|
||||
'engine_name': display_engine_name,
|
||||
'scheduled_scan_name': scheduled_scan_name,
|
||||
}
|
||||
for scan in created_scans
|
||||
|
||||
@@ -96,14 +96,34 @@ class ScanService:
|
||||
organization_id, target_id, engine_id
|
||||
)
|
||||
|
||||
def prepare_initiate_scan_multi_engine(
|
||||
self,
|
||||
organization_id: int | None = None,
|
||||
target_id: int | None = None,
|
||||
engine_ids: List[int] | None = None
|
||||
) -> tuple[List[Target], str, List[str], List[int]]:
|
||||
"""
|
||||
为创建多引擎扫描任务做准备
|
||||
|
||||
Returns:
|
||||
(目标列表, 合并配置, 引擎名称列表, 引擎ID列表)
|
||||
"""
|
||||
return self.creation_service.prepare_initiate_scan_multi_engine(
|
||||
organization_id, target_id, engine_ids
|
||||
)
|
||||
|
||||
def create_scans(
|
||||
self,
|
||||
targets: List[Target],
|
||||
engine: ScanEngine,
|
||||
engine_ids: List[int],
|
||||
engine_names: List[str],
|
||||
yaml_configuration: str,
|
||||
scheduled_scan_name: str | None = None
|
||||
) -> List[Scan]:
|
||||
"""批量创建扫描任务(委托给 ScanCreationService)"""
|
||||
return self.creation_service.create_scans(targets, engine, scheduled_scan_name)
|
||||
return self.creation_service.create_scans(
|
||||
targets, engine_ids, engine_names, yaml_configuration, scheduled_scan_name
|
||||
)
|
||||
|
||||
# ==================== 状态管理方法(委托给 ScanStateService) ====================
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from django.core.exceptions import ValidationError
|
||||
|
||||
from apps.scan.models import ScheduledScan
|
||||
from apps.scan.repositories import DjangoScheduledScanRepository, ScheduledScanDTO
|
||||
from apps.scan.utils.config_merger import merge_engine_configs, ConfigConflictError
|
||||
from apps.engine.repositories import DjangoEngineRepository
|
||||
from apps.targets.services import TargetService
|
||||
|
||||
@@ -53,12 +54,13 @@ class ScheduledScanService:
|
||||
|
||||
def create(self, dto: ScheduledScanDTO) -> ScheduledScan:
|
||||
"""
|
||||
创建定时扫描任务
|
||||
创建定时扫描任务(使用引擎 ID 合并配置)
|
||||
|
||||
流程:
|
||||
1. 验证参数
|
||||
2. 创建数据库记录
|
||||
3. 计算并设置 next_run_time
|
||||
2. 合并引擎配置
|
||||
3. 创建数据库记录
|
||||
4. 计算并设置 next_run_time
|
||||
|
||||
Args:
|
||||
dto: 定时扫描 DTO
|
||||
@@ -68,11 +70,66 @@ class ScheduledScanService:
|
||||
|
||||
Raises:
|
||||
ValidationError: 参数验证失败
|
||||
ConfigConflictError: 引擎配置冲突
|
||||
"""
|
||||
# 1. 验证参数
|
||||
self._validate_create_dto(dto)
|
||||
|
||||
# 2. 创建数据库记录
|
||||
# 2. 合并引擎配置
|
||||
engines = []
|
||||
engine_names = []
|
||||
for engine_id in dto.engine_ids:
|
||||
engine = self.engine_repo.get_by_id(engine_id)
|
||||
if engine:
|
||||
engines.append((engine.name, engine.configuration or ''))
|
||||
engine_names.append(engine.name)
|
||||
|
||||
merged_configuration = merge_engine_configs(engines)
|
||||
|
||||
# 设置 DTO 的合并配置和引擎名称
|
||||
dto.engine_names = engine_names
|
||||
dto.yaml_configuration = merged_configuration
|
||||
|
||||
# 3. 创建数据库记录
|
||||
scheduled_scan = self.repo.create(dto)
|
||||
|
||||
# 4. 如果有 cron 表达式且已启用,计算下次执行时间
|
||||
if scheduled_scan.cron_expression and scheduled_scan.is_enabled:
|
||||
next_run_time = self._calculate_next_run_time(scheduled_scan)
|
||||
if next_run_time:
|
||||
self.repo.update_next_run_time(scheduled_scan.id, next_run_time)
|
||||
scheduled_scan.next_run_time = next_run_time
|
||||
|
||||
logger.info(
|
||||
"创建定时扫描任务 - ID: %s, 名称: %s, 下次执行: %s",
|
||||
scheduled_scan.id, scheduled_scan.name, scheduled_scan.next_run_time
|
||||
)
|
||||
|
||||
return scheduled_scan
|
||||
|
||||
def create_with_configuration(self, dto: ScheduledScanDTO) -> ScheduledScan:
|
||||
"""
|
||||
创建定时扫描任务(直接使用前端传递的配置)
|
||||
|
||||
流程:
|
||||
1. 验证参数
|
||||
2. 直接使用 dto.yaml_configuration
|
||||
3. 创建数据库记录
|
||||
4. 计算并设置 next_run_time
|
||||
|
||||
Args:
|
||||
dto: 定时扫描 DTO(必须包含 yaml_configuration)
|
||||
|
||||
Returns:
|
||||
创建的 ScheduledScan 对象
|
||||
|
||||
Raises:
|
||||
ValidationError: 参数验证失败
|
||||
"""
|
||||
# 1. 验证参数
|
||||
self._validate_create_dto_with_configuration(dto)
|
||||
|
||||
# 2. 创建数据库记录(直接使用 dto 中的配置)
|
||||
scheduled_scan = self.repo.create(dto)
|
||||
|
||||
# 3. 如果有 cron 表达式且已启用,计算下次执行时间
|
||||
@@ -90,18 +147,33 @@ class ScheduledScanService:
|
||||
return scheduled_scan
|
||||
|
||||
def _validate_create_dto(self, dto: ScheduledScanDTO) -> None:
|
||||
"""验证创建 DTO"""
|
||||
"""验证创建 DTO(使用引擎 ID)"""
|
||||
# 基础验证
|
||||
self._validate_base_dto(dto)
|
||||
|
||||
if not dto.engine_ids:
|
||||
raise ValidationError('必须选择扫描引擎')
|
||||
|
||||
# 验证所有引擎是否存在
|
||||
for engine_id in dto.engine_ids:
|
||||
if not self.engine_repo.get_by_id(engine_id):
|
||||
raise ValidationError(f'扫描引擎 ID {engine_id} 不存在')
|
||||
|
||||
def _validate_create_dto_with_configuration(self, dto: ScheduledScanDTO) -> None:
|
||||
"""验证创建 DTO(使用前端传递的配置)"""
|
||||
# 基础验证
|
||||
self._validate_base_dto(dto)
|
||||
|
||||
if not dto.yaml_configuration:
|
||||
raise ValidationError('配置不能为空')
|
||||
|
||||
def _validate_base_dto(self, dto: ScheduledScanDTO) -> None:
|
||||
"""验证 DTO 的基础字段(公共逻辑)"""
|
||||
from apps.targets.repositories import DjangoOrganizationRepository
|
||||
|
||||
if not dto.name:
|
||||
raise ValidationError('任务名称不能为空')
|
||||
|
||||
if not dto.engine_id:
|
||||
raise ValidationError('必须选择扫描引擎')
|
||||
|
||||
if not self.engine_repo.get_by_id(dto.engine_id):
|
||||
raise ValidationError(f'扫描引擎 ID {dto.engine_id} 不存在')
|
||||
|
||||
# 验证扫描模式(organization_id 和 target_id 互斥)
|
||||
if not dto.organization_id and not dto.target_id:
|
||||
raise ValidationError('必须选择组织或扫描目标')
|
||||
@@ -138,11 +210,28 @@ class ScheduledScanService:
|
||||
|
||||
Returns:
|
||||
更新后的 ScheduledScan 对象
|
||||
|
||||
Raises:
|
||||
ConfigConflictError: 引擎配置冲突
|
||||
"""
|
||||
existing = self.repo.get_by_id(scheduled_scan_id)
|
||||
if not existing:
|
||||
return None
|
||||
|
||||
# 如果引擎变更,重新合并配置
|
||||
if dto.engine_ids is not None:
|
||||
engines = []
|
||||
engine_names = []
|
||||
for engine_id in dto.engine_ids:
|
||||
engine = self.engine_repo.get_by_id(engine_id)
|
||||
if engine:
|
||||
engines.append((engine.name, engine.configuration or ''))
|
||||
engine_names.append(engine.name)
|
||||
|
||||
merged_configuration = merge_engine_configs(engines)
|
||||
dto.engine_names = engine_names
|
||||
dto.yaml_configuration = merged_configuration
|
||||
|
||||
# 更新数据库记录
|
||||
scheduled_scan = self.repo.update(scheduled_scan_id, dto)
|
||||
if not scheduled_scan:
|
||||
@@ -292,21 +381,25 @@ class ScheduledScanService:
|
||||
立即触发扫描(支持组织扫描和目标扫描两种模式)
|
||||
|
||||
复用 ScanService 的逻辑,与 API 调用保持一致。
|
||||
使用存储的 yaml_configuration 而不是重新合并。
|
||||
"""
|
||||
from apps.scan.services.scan_service import ScanService
|
||||
|
||||
scan_service = ScanService()
|
||||
|
||||
# 1. 准备扫描所需数据(复用 API 的逻辑)
|
||||
targets, engine = scan_service.prepare_initiate_scan(
|
||||
# 1. 准备扫描所需数据(使用存储的多引擎配置)
|
||||
targets, _, _, _ = scan_service.prepare_initiate_scan_multi_engine(
|
||||
organization_id=scheduled_scan.organization_id,
|
||||
target_id=scheduled_scan.target_id,
|
||||
engine_id=scheduled_scan.engine_id
|
||||
engine_ids=scheduled_scan.engine_ids
|
||||
)
|
||||
|
||||
# 2. 创建扫描任务,传递定时扫描名称用于通知显示
|
||||
# 2. 创建扫描任务,使用存储的合并配置
|
||||
created_scans = scan_service.create_scans(
|
||||
targets, engine,
|
||||
targets=targets,
|
||||
engine_ids=scheduled_scan.engine_ids,
|
||||
engine_names=scheduled_scan.engine_names,
|
||||
yaml_configuration=scheduled_scan.yaml_configuration,
|
||||
scheduled_scan_name=scheduled_scan.name
|
||||
)
|
||||
|
||||
|
||||
138
backend/apps/scan/services/subfinder_provider_config_service.py
Normal file
138
backend/apps/scan/services/subfinder_provider_config_service.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Subfinder Provider 配置文件生成服务
|
||||
|
||||
负责生成 subfinder 的 provider-config.yaml 配置文件
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from ..models import SubfinderProviderSettings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SubfinderProviderConfigService:
|
||||
"""Subfinder Provider 配置文件生成服务"""
|
||||
|
||||
# Provider 格式定义
|
||||
PROVIDER_FORMATS = {
|
||||
'fofa': {'type': 'composite', 'format': '{email}:{api_key}'},
|
||||
'censys': {'type': 'composite', 'format': '{api_id}:{api_secret}'},
|
||||
'hunter': {'type': 'single', 'field': 'api_key'},
|
||||
'shodan': {'type': 'single', 'field': 'api_key'},
|
||||
'zoomeye': {'type': 'single', 'field': 'api_key'},
|
||||
'securitytrails': {'type': 'single', 'field': 'api_key'},
|
||||
'threatbook': {'type': 'single', 'field': 'api_key'},
|
||||
'quake': {'type': 'single', 'field': 'api_key'},
|
||||
}
|
||||
|
||||
def generate(self, output_dir: str) -> Optional[str]:
|
||||
"""
|
||||
生成 provider-config.yaml 文件
|
||||
|
||||
Args:
|
||||
output_dir: 输出目录路径
|
||||
|
||||
Returns:
|
||||
生成的配置文件路径,如果没有启用的 provider 则返回 None
|
||||
"""
|
||||
settings = SubfinderProviderSettings.get_instance()
|
||||
|
||||
config = {}
|
||||
has_enabled = False
|
||||
|
||||
for provider, format_info in self.PROVIDER_FORMATS.items():
|
||||
provider_config = settings.providers.get(provider, {})
|
||||
|
||||
if not provider_config.get('enabled'):
|
||||
config[provider] = []
|
||||
continue
|
||||
|
||||
value = self._build_provider_value(provider, provider_config)
|
||||
if value:
|
||||
config[provider] = [value] # 单个 key 放入数组
|
||||
has_enabled = True
|
||||
logger.debug(f"Provider {provider} 已启用")
|
||||
else:
|
||||
config[provider] = []
|
||||
|
||||
# 检查是否有任何启用的 provider
|
||||
if not has_enabled:
|
||||
logger.info("没有启用的 Provider,跳过配置文件生成")
|
||||
return None
|
||||
|
||||
# 确保输出目录存在
|
||||
output_path = Path(output_dir) / 'provider-config.yaml'
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 写入 YAML 文件(使用默认列表格式,和 subfinder 一致)
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
|
||||
|
||||
# 设置文件权限为 600(仅所有者可读写)
|
||||
os.chmod(output_path, 0o600)
|
||||
|
||||
logger.info(f"Provider 配置文件已生成: {output_path}")
|
||||
return str(output_path)
|
||||
|
||||
def _build_provider_value(self, provider: str, config: dict) -> Optional[str]:
|
||||
"""根据 provider 格式规则构建配置值
|
||||
|
||||
Args:
|
||||
provider: provider 名称
|
||||
config: provider 配置字典
|
||||
|
||||
Returns:
|
||||
构建的配置值字符串,如果配置不完整则返回 None
|
||||
"""
|
||||
format_info = self.PROVIDER_FORMATS.get(provider)
|
||||
if not format_info:
|
||||
return None
|
||||
|
||||
if format_info['type'] == 'composite':
|
||||
# 复合格式:需要多个字段
|
||||
format_str = format_info['format']
|
||||
try:
|
||||
# 提取格式字符串中的字段名
|
||||
# 例如 '{email}:{api_key}' -> ['email', 'api_key']
|
||||
import re
|
||||
fields = re.findall(r'\{(\w+)\}', format_str)
|
||||
|
||||
# 检查所有字段是否都有值
|
||||
values = {}
|
||||
for field in fields:
|
||||
value = config.get(field, '').strip()
|
||||
if not value:
|
||||
logger.debug(f"Provider {provider} 缺少字段 {field}")
|
||||
return None
|
||||
values[field] = value
|
||||
|
||||
return format_str.format(**values)
|
||||
except (KeyError, ValueError) as e:
|
||||
logger.warning(f"构建 {provider} 配置值失败: {e}")
|
||||
return None
|
||||
else:
|
||||
# 单字段格式
|
||||
field = format_info['field']
|
||||
value = config.get(field, '').strip()
|
||||
if not value:
|
||||
logger.debug(f"Provider {provider} 缺少字段 {field}")
|
||||
return None
|
||||
return value
|
||||
|
||||
def cleanup(self, config_path: str) -> None:
|
||||
"""清理配置文件
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径
|
||||
"""
|
||||
try:
|
||||
if config_path and Path(config_path).exists():
|
||||
Path(config_path).unlink()
|
||||
logger.debug(f"已清理配置文件: {config_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"清理配置文件失败: {config_path} - {e}")
|
||||
@@ -10,37 +10,58 @@
|
||||
import ipaddress
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, Iterator
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from .blacklist_service import BlacklistService
|
||||
from apps.common.utils import BlacklistFilter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_export_service(target_id: int) -> 'TargetExportService':
|
||||
"""
|
||||
工厂函数:创建带黑名单过滤的导出服务
|
||||
|
||||
Args:
|
||||
target_id: 目标 ID,用于加载黑名单规则
|
||||
|
||||
Returns:
|
||||
TargetExportService: 配置好黑名单过滤器的导出服务实例
|
||||
"""
|
||||
from apps.common.services import BlacklistService
|
||||
|
||||
rules = BlacklistService().get_rules(target_id)
|
||||
blacklist_filter = BlacklistFilter(rules)
|
||||
return TargetExportService(blacklist_filter=blacklist_filter)
|
||||
|
||||
|
||||
class TargetExportService:
|
||||
"""
|
||||
目标导出服务 - 提供统一的目标提取和文件导出功能
|
||||
|
||||
使用方式:
|
||||
# Task 层决定数据源
|
||||
queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)
|
||||
from apps.common.services import BlacklistService
|
||||
from apps.common.utils import BlacklistFilter
|
||||
|
||||
# 获取规则并创建过滤器
|
||||
blacklist_service = BlacklistService()
|
||||
rules = blacklist_service.get_rules(target_id)
|
||||
blacklist_filter = BlacklistFilter(rules)
|
||||
|
||||
# 使用导出服务
|
||||
blacklist_service = BlacklistService()
|
||||
export_service = TargetExportService(blacklist_service=blacklist_service)
|
||||
export_service = TargetExportService(blacklist_filter=blacklist_filter)
|
||||
result = export_service.export_urls(target_id, output_path, queryset)
|
||||
"""
|
||||
|
||||
def __init__(self, blacklist_service: Optional[BlacklistService] = None):
|
||||
def __init__(self, blacklist_filter: Optional[BlacklistFilter] = None):
|
||||
"""
|
||||
初始化导出服务
|
||||
|
||||
Args:
|
||||
blacklist_service: 黑名单过滤服务,None 表示禁用过滤
|
||||
blacklist_filter: 黑名单过滤器,None 表示禁用过滤
|
||||
"""
|
||||
self.blacklist_service = blacklist_service
|
||||
self.blacklist_filter = blacklist_filter
|
||||
|
||||
def export_urls(
|
||||
self,
|
||||
@@ -79,19 +100,15 @@ class TargetExportService:
|
||||
|
||||
logger.info("开始导出 URL - target_id=%s, output=%s", target_id, output_path)
|
||||
|
||||
# 应用黑名单过滤(数据库层面)
|
||||
if self.blacklist_service:
|
||||
# 注意:queryset 应该是原始 queryset,不是 values_list
|
||||
# 这里假设 Task 层传入的是 values_list,需要在 Task 层处理过滤
|
||||
pass
|
||||
|
||||
total_count = 0
|
||||
filtered_count = 0
|
||||
try:
|
||||
with open(output_file, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
for url in queryset.iterator(chunk_size=batch_size):
|
||||
if url:
|
||||
# Python 层面黑名单过滤
|
||||
if self.blacklist_service and not self.blacklist_service.filter_url(url):
|
||||
# 黑名单过滤
|
||||
if self.blacklist_filter and not self.blacklist_filter.is_allowed(url):
|
||||
filtered_count += 1
|
||||
continue
|
||||
f.write(f"{url}\n")
|
||||
total_count += 1
|
||||
@@ -102,6 +119,9 @@ class TargetExportService:
|
||||
logger.error("文件写入失败: %s - %s", output_path, e)
|
||||
raise
|
||||
|
||||
if filtered_count > 0:
|
||||
logger.info("黑名单过滤: 过滤 %d 个 URL", filtered_count)
|
||||
|
||||
# 默认值回退模式
|
||||
if total_count == 0:
|
||||
total_count = self._generate_default_urls(target_id, output_file)
|
||||
@@ -206,18 +226,18 @@ class TargetExportService:
|
||||
|
||||
def _should_write_url(self, url: str) -> bool:
|
||||
"""检查 URL 是否应该写入(通过黑名单过滤)"""
|
||||
if self.blacklist_service:
|
||||
return self.blacklist_service.filter_url(url)
|
||||
if self.blacklist_filter:
|
||||
return self.blacklist_filter.is_allowed(url)
|
||||
return True
|
||||
|
||||
def export_targets(
|
||||
def export_hosts(
|
||||
self,
|
||||
target_id: int,
|
||||
output_path: str,
|
||||
batch_size: int = 1000
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
域名/IP 导出函数(用于端口扫描)
|
||||
主机列表导出函数(用于端口扫描)
|
||||
|
||||
根据 Target 类型选择导出逻辑:
|
||||
- DOMAIN: 从 Subdomain 表流式导出子域名
|
||||
@@ -255,7 +275,7 @@ class TargetExportService:
|
||||
target_name = target.name
|
||||
|
||||
logger.info(
|
||||
"开始导出扫描目标 - Target ID: %d, Name: %s, Type: %s, 输出文件: %s",
|
||||
"开始导出主机列表 - Target ID: %d, Name: %s, Type: %s, 输出文件: %s",
|
||||
target_id, target_name, target_type, output_path
|
||||
)
|
||||
|
||||
@@ -277,7 +297,7 @@ class TargetExportService:
|
||||
raise ValueError(f"不支持的目标类型: {target_type}")
|
||||
|
||||
logger.info(
|
||||
"✓ 扫描目标导出完成 - 类型: %s, 总数: %d, 文件: %s",
|
||||
"✓ 主机列表导出完成 - 类型: %s, 总数: %d, 文件: %s",
|
||||
type_desc, total_count, output_path
|
||||
)
|
||||
|
||||
@@ -295,7 +315,7 @@ class TargetExportService:
|
||||
output_path: Path,
|
||||
batch_size: int
|
||||
) -> int:
|
||||
"""导出域名类型目标的子域名"""
|
||||
"""导出域名类型目标的根域名 + 子域名"""
|
||||
from apps.asset.services.asset.subdomain_service import SubdomainService
|
||||
|
||||
subdomain_service = SubdomainService()
|
||||
@@ -305,23 +325,27 @@ class TargetExportService:
|
||||
)
|
||||
|
||||
total_count = 0
|
||||
written_domains = set() # 去重(子域名表可能已包含根域名)
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
# 1. 先写入根域名
|
||||
if self._should_write_target(target_name):
|
||||
f.write(f"{target_name}\n")
|
||||
written_domains.add(target_name)
|
||||
total_count += 1
|
||||
|
||||
# 2. 再写入子域名(跳过已写入的根域名)
|
||||
for domain_name in domain_iterator:
|
||||
if domain_name in written_domains:
|
||||
continue
|
||||
if self._should_write_target(domain_name):
|
||||
f.write(f"{domain_name}\n")
|
||||
written_domains.add(domain_name)
|
||||
total_count += 1
|
||||
|
||||
if total_count % 10000 == 0:
|
||||
logger.info("已导出 %d 个域名...", total_count)
|
||||
|
||||
# 默认值模式:如果没有子域名,使用根域名
|
||||
if total_count == 0:
|
||||
logger.info("采用默认域名:%s (target_id=%d)", target_name, target_id)
|
||||
if self._should_write_target(target_name):
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
f.write(f"{target_name}\n")
|
||||
total_count = 1
|
||||
|
||||
return total_count
|
||||
|
||||
def _export_ip(self, target_name: str, output_path: Path) -> int:
|
||||
@@ -359,6 +383,6 @@ class TargetExportService:
|
||||
|
||||
def _should_write_target(self, target: str) -> bool:
|
||||
"""检查目标是否应该写入(通过黑名单过滤)"""
|
||||
if self.blacklist_service:
|
||||
return self.blacklist_service.filter_url(target)
|
||||
if self.blacklist_filter:
|
||||
return self.blacklist_filter.is_allowed(target)
|
||||
return True
|
||||
|
||||
@@ -8,7 +8,8 @@ import logging
|
||||
from prefect import task
|
||||
|
||||
from apps.asset.models import WebSite
|
||||
from apps.scan.services import TargetExportService, BlacklistService
|
||||
from apps.scan.services import TargetExportService
|
||||
from apps.scan.services.target_export_service import create_export_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -49,9 +50,8 @@ def export_sites_task(
|
||||
# 构建数据源 queryset(Task 层决定数据源)
|
||||
queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)
|
||||
|
||||
# 使用 TargetExportService 处理导出
|
||||
blacklist_service = BlacklistService()
|
||||
export_service = TargetExportService(blacklist_service=blacklist_service)
|
||||
# 使用工厂函数创建导出服务
|
||||
export_service = create_export_service(target_id)
|
||||
|
||||
result = export_service.export_urls(
|
||||
target_id=target_id,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user