Initial commit: Xingrin v1.0.0

Author: yyhuni
Date: 2025-12-12 18:04:57 +08:00
Commit: 25db990bc3
571 changed files with 227914 additions and 0 deletions

.agent/rules/project.md (new file, 13 additions)

@@ -0,0 +1,13 @@
---
trigger: always_on
---
1. The backend web service runs on port 8888.
2. Run backend commands inside the project virtual environment: ~/Desktop/scanner/.venv/bin/python
3. Add a trailing slash to every frontend route so it matches Django DRF's URL rules (see the sketch after this list).
4. Pages can be tested with curl.
8. All frontend API calls go in @services; all type definitions go in @types.
10. Use React Query to manage frontend loading state and caching automatically.
17. Toasts for business operations belong in hooks.
19. For now, the backend project does not need security-hardening code.
23. Avoid window.location.href for frontend navigation; prefer Next.js client-side routing.
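Rule 3 reflects how DRF's default router works: it generates URL patterns that end in a slash, so a slash-less request 404s unless Django's APPEND_SLASH redirect kicks in. A minimal sketch of that behavior, using a hypothetical stub viewset rather than a class from this repo:

```python
# Minimal sketch of DRF's trailing-slash behavior (rule 3).
# PingViewSet is a hypothetical stub, not a class from this repo.
from rest_framework import viewsets
from rest_framework.response import Response
from rest_framework.routers import DefaultRouter

class PingViewSet(viewsets.ViewSet):  # hypothetical stub
    def list(self, request):
        return Response({'ok': True})

# DefaultRouter defaults to trailing_slash=True, so it generates
# /targets/ and /targets/{pk}/ — a request to /targets (no slash)
# 404s unless APPEND_SLASH redirects it. Hence the frontend rule.
router = DefaultRouter()
router.register(r'targets', PingViewSet, basename='target')
urlpatterns = router.urls
```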

.dockerignore (new file, 23 additions)

@@ -0,0 +1,23 @@
# Node modules (local frontend build artifacts; reinstalled during the Docker build)
frontend/node_modules
frontend/.next
# Python virtual environment
.venv
__pycache__
*.pyc
# Logs and temporary files
*.log
.DS_Store
# Git
.git
.gitignore
# IDE
.idea
.vscode
# Docker files (avoid nesting)
docker/.env

.github/workflows/docker-build.yml (new file, vendored, 84 additions)

@@ -0,0 +1,84 @@
name: Build and Push Docker Images

on:
  push:
    branches: [main]
    paths:
      - 'backend/**'
      - 'frontend/**'
      - 'docker/**'
      - '.github/workflows/**'
  workflow_dispatch: # manual trigger

# Concurrency control: keep only the newest build per branch; cancel runs already in progress
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

env:
  REGISTRY: docker.io
  IMAGE_PREFIX: yyhuni

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        include:
          - image: xingrin-server
            dockerfile: docker/server/Dockerfile
            context: .
          - image: xingrin-frontend
            dockerfile: docker/frontend/Dockerfile
            context: .
          - image: xingrin-worker
            dockerfile: docker/worker/Dockerfile
            context: .
          - image: xingrin-nginx
            dockerfile: docker/nginx/Dockerfile
            context: .
          - image: xingrin-agent
            dockerfile: docker/agent/Dockerfile
            context: .
    steps:
      - name: Checkout
        uses: actions/checkout@v4
      - name: Free disk space (for large builds like worker)
        run: |
          echo "=== Before cleanup ==="
          df -h
          # Remove large preinstalled packages we don't need
          sudo rm -rf /usr/share/dotnet
          sudo rm -rf /usr/local/lib/android
          sudo rm -rf /opt/ghc
          sudo rm -rf /opt/hostedtoolcache/CodeQL
          sudo docker image prune -af
          echo "=== After cleanup ==="
          df -h
      - name: Set up QEMU
        uses: docker/setup-qemu-action@v3
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Login to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}
      - name: Build and push
        uses: docker/build-push-action@v5
        with:
          context: ${{ matrix.context }}
          file: ${{ matrix.dockerfile }}
          platforms: linux/amd64,linux/arm64
          push: true
          tags: |
            ${{ env.IMAGE_PREFIX }}/${{ matrix.image }}:latest
            ${{ env.IMAGE_PREFIX }}/${{ matrix.image }}:${{ github.sha }}
          cache-from: type=gha
          cache-to: type=gha,mode=max

.gitignore (new file, vendored, 133 additions)

@@ -0,0 +1,133 @@
# ============================
# Operating system files
# ============================
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
# ============================
# Frontend (Next.js/Node.js)
# ============================
# Dependency directories
front-back/node_modules/
front-back/.pnpm-store/
# Next.js build output
front-back/.next/
front-back/out/
front-back/dist/
# Environment variable files
front-back/.env
front-back/.env.local
front-back/.env.development.local
front-back/.env.test.local
front-back/.env.production.local
# Runtime and caches
front-back/.turbo/
front-back/.swc/
front-back/.eslintcache
front-back/.tsbuildinfo
# ============================
# Backend (Python/Django)
# ============================
# Python virtual environments
.venv/
venv/
env/
ENV/
# Python bytecode
*.pyc
*.pyo
*.pyd
__pycache__/
*.py[cod]
*$py.class
# Django
backend/db.sqlite3
backend/db.sqlite3-journal
backend/media/
backend/staticfiles/
backend/.env
backend/.env.local
# Python tests and coverage
.pytest_cache/
.coverage
htmlcov/
*.cover
# ============================
# Backend (Go)
# ============================
# Build artifacts
backend/bin/
backend/dist/
backend/*.exe
backend/*.exe~
backend/*.dll
backend/*.so
backend/*.dylib
# Test artifacts
backend/*.test
backend/*.out
backend/*.prof
# Go workspace files
backend/go.work
backend/go.work.sum
# Go vendored dependencies
backend/vendor/
# ============================
# IDEs and editors
# ============================
.vscode/
.idea/
.cursor/
.claude/
.playwright-mcp/
*.swp
*.swo
*~
# ============================
# Docker
# ============================
docker/.env
docker/.env.local
# SSL certificates and private keys (never commit these)
docker/nginx/ssl/*.pem
docker/nginx/ssl/*.key
docker/nginx/ssl/*.crt
# ============================
# Logs and scan results
# ============================
*.log
logs/
results/
# Dev-script runtime files (process IDs and startup logs)
backend/scripts/dev/.pids/
# ============================
# Temporary files
# ============================
tmp/
temp/
.cache/
HGETALL
KEYS

[new rules file; name not captured in the export]

@@ -0,0 +1,13 @@
---
trigger: always_on
---
1. The backend web service runs on port 8888.
3. Add a trailing slash to every frontend route so it matches Django DRF's URL rules.
4. Pages can be tested with curl.
8. All frontend API calls go in @services; all type definitions go in @types.
10. Use React Query to manage frontend loading state and caching automatically.
17. Toasts for business operations belong in hooks.
19. For now, the backend project does not need security-hardening code.
23. Avoid window.location.href for frontend navigation; prefer Next.js client-side routing.
24. For UI work, check via MCP whether an existing shared, polished component already covers the need.

[new rules file; name not captured in the export]

@@ -0,0 +1,85 @@
---
trigger: manual
description: This rule must be invoked whenever a code review is performed
---
### **0. Logical Correctness & Bug Hunting** *(highest priority; must be traced by hand)*
**Goal**: without relying on tests, proactively find logic errors where the code runs but produces wrong results.
1. **Trace the critical paths by hand**
   - Pick 2–3 representative inputs (including boundary cases) and **step through the execution mentally or on paper**.
   - Does the output match expectations? Is every variable change correct at each step?
2. **Common logic bugs**
   - **Off-by-one**: loops, array indexing, pagination
   - **Broken conditionals**: `and`/`or` precedence, misused short-circuit evaluation
   - **State confusion**: uninitialized variables, accidental overwrites
   - **Algorithmic drift**: sorting, searching, binary-search midpoint handling
   - **Floating-point precision**: is `==` misused to compare floats?
3. **Control-flow review**
   - Are all `if/else` branches covered? Any unreachable code?
   - Does every `switch`/`match` have a `default`? Any missing cases?
   - What do error paths return? Is any `finally` cleanup missing?
4. **Business-logic consistency**
   - Does the code follow the **business rules**? (e.g. "order total = item price × quantity + shipping − discount")
   - Are implicit constraints honored? (e.g. "users may only review completed orders")
### **1. Functionality & Correctness** *(blocking issues must be fixed)*
1. **Requirements coverage**: does it cover 100% of the requirements? Any missing or extra features?
2. **Boundary conditions**
   - Inputs: `null`, empty, extreme values, malformed data
   - Collections: empty, single element, very large (e.g. 10⁶)
   - Loop termination conditions, off-by-one
3. **Error handling**
   - Exceptions caught comprehensively? Failure paths degrade gracefully?
   - Error messages clear, without leaking stack traces?
4. **Concurrency safety**
   - Race or deadlock risks? Shared resources synchronized?
   - Are `volatile`/`synchronized`/`Lock`/atomics used where needed?
5. **Unit tests**
   - Coverage ≥ 80%? Positive, boundary, and failure cases included?
   - Tests independent, with no external dependencies?
### **2. Code Quality & Readability**
1. **Naming**: self-explanatory? Follows conventions?
2. **Function design**
   - **Single responsibility**? Parameters ≤ 4? Length ideally < 50 lines (adjust per language)
   - Can pieces be extracted into utility functions?
3. **Structure & complexity**
   - No duplicated code? Cyclomatic complexity < 10?
   - Nesting ≤ 3 levels? Guard clauses used for early returns?
4. **Comments**: explain **why**, not **what**; complex logic must be commented
5. **Consistent style**: auto-format with `Prettier`/`ESLint`/`Spotless`
### **3. Architecture & Design**
1. **SOLID**: single responsibility, open/closed, dependency inversion?
2. **Dependencies**: depend on interfaces rather than implementations? No circular dependencies?
3. **Testability**: supports dependency injection? Avoids hard-coded `new`?
4. **Extensibility**: does adding a feature require changing only one place?
### **4. Performance**
- **N+1 queries**? IO/logging/allocation inside loops?
- Reasonable algorithmic complexity (can an O(n²) step be improved)?
- Memory: no leaks? Large objects released promptly? Cache invalidation in place?
### **5. Other**
1. **Maintainability**: do logs carry context? Is the code cleaner after the change?
2. **Compatibility**: are API/database changes backward compatible?
3. **Dependency management**: is each new library necessary? Licenses compliant?
---
### **Review Best Practices**
- **Review in small batches**: ≤ 200 lines per pass
- **Suggesting tone**: `"Consider splitting this function for readability"` rather than `"This function is too long"`
- **Automate first**: style/null-pointer/security scans belong in CI tooling
- **Triage findings**:
  - 🛑 **Blocking**: functional errors, security vulnerabilities
  - ⚠️ **Must fix**: design flaws, performance bottlenecks
  - 💡 **Suggestion**: style, naming, readability

[new rules file; name not captured in the export]

@@ -0,0 +1,195 @@
---
trigger: always_on
---
## Standard layered-architecture call order
Following **DDD (domain-driven design) and clean-architecture** principles, the call order is:
```
HTTP request → Views → Tasks → Services → Repositories → Models
```
---
### 📊 Full call-chain diagram
```
┌─────────────────────────────────────────────────────────────┐
│                   HTTP Request (frontend)                    │
└────────────────────────┬────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│ Views (HTTP layer)                                           │
│ - parameter validation                                       │
│ - permission checks                                          │
│ - calls Tasks/Services                                       │
│ - returns the HTTP response                                  │
└────────────────────────┬────────────────────────────────────┘
        ┌────────────────┴────────────────┐
        ↓ (async)                         ↓ (sync)
┌───────────────────────┐      ┌───────────────────────────┐
│ Tasks (task layer)    │      │ Services (business layer) │
│ - async execution     │      │ - business logic          │
│ - background jobs     │─────>│ - transaction management  │
│ - notifications       │      │ - data validation         │
└───────────────────────┘      └────────────┬──────────────┘
                          ┌─────────────────────────────────┐
                          │ Repositories (data-access layer)│
                          │ - data access                   │
                          │ - query encapsulation           │
                          │ - bulk operations               │
                          └────────────┬────────────────────┘
                          ┌─────────────────────────────────┐
                          │ Models (model layer)            │
                          │ - ORM definitions               │
                          │ - data structures               │
                          │ - relation mappings             │
                          └─────────────────────────────────┘
```
---
### 🔄 Concrete call examples
### **Scenario 1: synchronous delete (Views → Services → Repositories → Models)**
```python
# 1. Views layer (views.py)
def some_sync_delete(self, request):
    # Validate parameters
    target_ids = request.data.get('ids')
    # Call the Service layer
    service = TargetService()
    result = service.bulk_delete_targets(target_ids)
    # Return the response
    return Response({'message': 'deleted'})

# 2. Services layer (services/target_service.py)
class TargetService:
    def bulk_delete_targets(self, target_ids):
        # Business-level validation
        logger.info("Preparing to delete...")
        # Call the Repository layer
        deleted_count = self.repo.bulk_delete_by_ids(target_ids)
        # Return the result
        return deleted_count

# 3. Repositories layer (repositories/django_target_repository.py)
class DjangoTargetRepository:
    def bulk_delete_by_ids(self, target_ids):
        # Data-access operation
        return Target.objects.filter(id__in=target_ids).delete()

# 4. Models layer (models.py)
class Target(models.Model):
    # ORM definition
    name = models.CharField(...)
```
---
### **Scenario 2: asynchronous delete (Views → Tasks → Services → Repositories → Models)**
```python
# 1. Views layer (views.py)
def destroy(self, request, *args, **kwargs):
    target = self.get_object()
    # Call the Tasks layer (async)
    async_bulk_delete_targets([target.id], [target.name])
    # Return 202 immediately
    return Response(status=202)

# 2. Tasks layer (tasks/target_tasks.py)
def async_bulk_delete_targets(target_ids, target_names):
    def _delete():
        # Send a notification
        create_notification("Deleting...")
        # Call the Service layer
        service = TargetService()
        result = service.bulk_delete_targets(target_ids)
        # Send a completion notification
        create_notification("Deleted successfully")
    # Run in a background thread
    threading.Thread(target=_delete).start()

# 3. Services layer (services/target_service.py)
class TargetService:
    def bulk_delete_targets(self, target_ids):
        # Business logic
        return self.repo.bulk_delete_by_ids(target_ids)

# 4. Repositories layer (repositories/django_target_repository.py)
class DjangoTargetRepository:
    def bulk_delete_by_ids(self, target_ids):
        # Data access
        return Target.objects.filter(id__in=target_ids).delete()

# 5. Models layer (models.py)
class Target(models.Model):
    # ORM definition
    ...
```
---
### 📋 Per-layer responsibilities
| Layer | Responsibilities | Must not |
| --- | --- | --- |
| **Views** | HTTP request handling, parameter validation, permission checks | ❌ access Models directly<br>❌ business logic |
| **Tasks** | async execution, background jobs, notifications | ❌ access Models directly<br>❌ HTTP responses |
| **Services** | business logic, transaction management, data validation | ❌ raw SQL<br>❌ anything HTTP |
| **Repositories** | data access, query encapsulation, bulk operations | ❌ business logic<br>❌ notifications |
| **Models** | ORM definitions, data structures, relation mappings | ❌ business logic<br>❌ complex queries |
---
### ✅ Best-practice principles
1. **One-way dependencies**: layers may only call downward, never upward
```
Views → Tasks → Services → Repositories → Models
(upper)                                   (lower)
```
2. **Layer isolation**: only adjacent layers interact; skipping layers is forbidden
- ✅ Views → Services
- ✅ Tasks → Services
- ✅ Services → Repositories
- ❌ Views → Repositories (skips a layer)
- ❌ Tasks → Models (skips a layer)
3. **Dependency injection**: inject dependencies through the constructor
```python
class TargetService:
    def __init__(self):
        self.repo = DjangoTargetRepository()  # injected here
```
4. **Interface abstraction**: define interfaces with Protocol
```python
class TargetRepository(Protocol):
    def bulk_delete_by_ids(self, ids): ...
```
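Putting points 3 and 4 together: if the constructor accepts the repository as an optional parameter instead of hard-coding it, the Protocol makes the service trivially testable without a database. A minimal sketch under those assumptions (names mirror the snippets above; `FakeTargetRepository` is a hypothetical test double):

```python
# Sketch: constructor injection + a Protocol interface for testability.
from typing import List, Optional, Protocol

class TargetRepository(Protocol):
    def bulk_delete_by_ids(self, ids: List[int]) -> int: ...

class DjangoTargetRepository:
    """Production implementation (stubbed here; the real one uses the ORM)."""
    def bulk_delete_by_ids(self, ids: List[int]) -> int:
        raise NotImplementedError

class TargetService:
    def __init__(self, repo: Optional[TargetRepository] = None):
        # Default to the Django implementation; tests inject a fake instead.
        self.repo = repo or DjangoTargetRepository()

    def bulk_delete_targets(self, target_ids: List[int]) -> int:
        return self.repo.bulk_delete_by_ids(target_ids)

class FakeTargetRepository:
    """Hypothetical in-memory test double satisfying the Protocol."""
    def bulk_delete_by_ids(self, ids: List[int]) -> int:
        return len(ids)

assert TargetService(repo=FakeTargetRepository()).bulk_delete_targets([1, 2]) == 2
```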

LICENSE (new file, 131 additions)

@@ -0,0 +1,131 @@
# PolyForm Noncommercial License 1.0.0
<https://polyformproject.org/licenses/noncommercial/1.0.0>
## Acceptance
In order to get any license under these terms, you must agree
to them as both strict obligations and conditions to all
your licenses.
## Copyright License
The licensor grants you a copyright license for the
software to do everything you might do with the software
that would otherwise infringe the licensor's copyright
in it for any permitted purpose. However, you may
only distribute the software according to [Distribution
License](#distribution-license) and make changes or new works
based on the software according to [Changes and New Works
License](#changes-and-new-works-license).
## Distribution License
The licensor grants you an additional copyright license
to distribute copies of the software. Your license
to distribute covers distributing the software with
changes and new works permitted by [Changes and New Works
License](#changes-and-new-works-license).
## Notices
You must ensure that anyone who gets a copy of any part of
the software from you also gets a copy of these terms or the
URL for them above, as well as copies of any plain-text lines
beginning with `Required Notice:` that the licensor provided
with the software. For example:
> Required Notice: Copyright Yuhang Yang (yyhuni)
## Changes and New Works License
The licensor grants you an additional copyright license to
make changes and new works based on the software for any
permitted purpose.
## Patent License
The licensor grants you a patent license for the software that
covers patent claims the licensor can license, or becomes able
to license, that you would infringe by using the software.
## Noncommercial Purposes
Any noncommercial purpose is a permitted purpose.
## Personal Uses
Personal use for research, experiment, and testing for
the benefit of public knowledge, personal study, private
entertainment, hobby projects, amateur pursuits, or religious
observance, without any anticipated commercial application,
is use for a permitted purpose.
## Noncommercial Organizations
Use by any charitable organization, educational institution,
public research organization, public safety or health
organization, environmental protection organization,
or government institution is use for a permitted purpose
regardless of the source of funding or obligations resulting
from the funding.
## Fair Use
You may have "fair use" rights for the software under the
law. These terms do not limit them.
## No Other Rights
These terms do not allow you to sublicense or transfer any of
your licenses to anyone else, or prevent the licensor from
granting licenses to anyone else. These terms do not imply
any other licenses.
## Patent Defense
If you make any written claim that the software infringes or
contributes to infringement of any patent, your patent license
for the software granted under these terms ends immediately. If
your company makes such a claim, your patent license ends
immediately for work on behalf of your company.
## Violations
The first time you are notified in writing that you have
violated any of these terms, or done anything with the software
not covered by your licenses, your licenses can nonetheless
continue if you come into full compliance with these terms,
and take practical steps to correct past violations, within
32 days of receiving notice. Otherwise, all your licenses
end immediately.
## No Liability
***As far as the law allows, the software comes as is, without
any warranty or condition, and the licensor will not be liable
to you for any damages arising out of these terms or the use
or nature of the software, under any kind of legal claim.***
## Definitions
The **licensor** is the individual or entity offering these
terms, and the **software** is the software the licensor makes
available under these terms.
**You** refers to the individual or entity agreeing to these
terms.
**Your company** is any legal entity, sole proprietorship,
or other kind of organization that you work for, plus all
organizations that have control over, are under the control of,
or are under common control with that organization. **Control**
means ownership of substantially all the assets of an entity,
or the power to direct its management and policies by vote,
contract, or otherwise. Control can be direct or indirect.
**Your licenses** are all the licenses granted to you for the
software under these terms.
**Use** means anything you do with the software requiring one
of your licenses.

README.md (new file, 140 additions)

@@ -0,0 +1,140 @@
<h1 align="center">Xingrin - 星环</h1>
<p align="center">
<b>一款现代化的企业级漏洞扫描与资产管理平台</b><br>
提供自动化安全检测、资产发现、漏洞管理等功能
</p>
---
## ✨ 功能特性
### 🎯 目标与资产管理
- **组织管理** - 多层级目标组织,灵活分组
- **目标管理** - 支持域名、IP、URL 等多种目标类型
- **资产发现** - 子域名、网站、端点、目录自动发现
- **资产快照** - 扫描结果快照对比,追踪资产变化
### 🔍 漏洞扫描
- **多引擎支持** - 集成 Nuclei 等主流扫描引擎
- **自定义流程** - YAML 配置扫描流程,灵活编排
- **漏洞分级** - 严重/高危/中危/低危 四级分类
- **定时扫描** - Cron 表达式配置,自动化周期扫描
### 🖥️ 分布式架构
- **Worker 节点** - 支持多节点分布式扫描
- **本地/远程** - 本地 Docker 节点 + SSH 远程节点
- **负载均衡** - 自动任务分发与负载监控
- **实时状态** - WebSocket 实时推送扫描进度
### 📊 可视化界面
- **数据统计** - 资产/漏洞统计仪表盘
- **实时通知** - WebSocket 消息推送
- **暗色主题** - 支持明暗主题切换
---
## 🛠️ 技术栈
- **前端**: Next.js + React + TailwindCSS
- **后端**: Django + Django REST Framework
- **数据库**: PostgreSQL + Redis
- **部署**: Docker + Nginx
- **扫描引擎**: Nuclei
---
## 📦 快速开始
### 环境要求
- Docker 20.10+
- Docker Compose 2.0+
- 推荐 2核 4G 内存起步
- 10GB+ 磁盘空间
### 一键安装
```bash
# 克隆项目
git clone https://github.com/yyhuni/xingrin.git
cd xingrin
# 安装并启动(生产模式)
sudo ./install.sh
# 开发模式
sudo ./install.sh --dev
```
### 访问服务
- **Web 界面**: `https://localhost``http://localhost`
- **API 接口**: `http://localhost:8888/api/`
- **API 文档**: `http://localhost:8888/swagger/`
### 常用命令
```bash
# 启动服务
sudo ./start.sh
# 停止服务
sudo ./stop.sh
# 重启服务
sudo ./restart.sh
# 卸载
sudo ./uninstall.sh
# 更新
sudo ./update.sh
```
## ⚠️ 免责声明
**重要:请在使用前仔细阅读**
1. 本工具仅供**授权的安全测试**和**安全研究**使用
2. 使用者必须确保已获得目标系统的**合法授权**
3. **严禁**将本工具用于未经授权的渗透测试或攻击行为
4. 未经授权扫描他人系统属于**违法行为**,可能面临法律责任
5. 开发者**不对任何滥用行为负责**
使用本工具即表示您同意:
- 仅在合法授权范围内使用
- 遵守所在地区的法律法规
- 承担因滥用产生的一切后果
## 📄 许可证
本项目采用 [PolyForm Noncommercial License 1.0.0](LICENSE) 许可证。
### 允许的用途
- ✅ 个人学习和研究
- ✅ 非商业安全测试
- ✅ 教育机构使用
- ✅ 非营利组织使用
### 禁止的用途
-**商业用途**包括但不限于出售、商业服务、SaaS 等)
- ❌ 未经授权的渗透测试
- ❌ 任何违法行为
如需商业授权,请联系作者。
## 🤝 反馈与贡献
- 🐛 **发现 Bug** 欢迎提交 [Issue](https://github.com/yyhuni/xingrin/issues)
- 💡 **有新想法?** 欢迎提交功能建议
- 🔧 **想参与开发?** 欢迎提交 Pull Request
## 📧 联系
- GitHub: [@yyhuni](https://github.com/yyhuni)
- 微信公众号: **洋洋的小黑屋**
<img src="docs/wechat-qrcode.png" alt="微信公众号" width="200">

backend/.gitignore (new file, vendored, 51 additions)

@@ -0,0 +1,51 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
*.egg-info/
dist/
build/
# Virtual environments
venv/
env/
ENV/
# Django
*.log
db.sqlite3
db.sqlite3-journal
/media
/staticfiles
# Runtime files (Flower, PIDs)
/var/*
!/var/.gitkeep
flower.db
pids/
script/dev/.pids/
# Scan results and logs (backend data)
/results/*
!/results/.gitkeep
/logs/*
!/logs/.gitkeep
# Environment variables (secrets)
.env
.env.development
.env.production
.env.staging
# Commit only the template files: .env.*.example
# IDE
.vscode/
.idea/
*.swp
*.swo
*~
# macOS
.DS_Store

backend/apps/__init__.py (new file, empty)

[empty new file; name not captured in the export]

[new file; name not captured in the export]

@@ -0,0 +1,10 @@
from django.apps import AppConfig

class AssetConfig(AppConfig):
    default_auto_field = 'django.db.models.BigAutoField'
    name = 'apps.asset'

    def ready(self):
        # Import all models so Django discovers and registers them
        from . import models

[new file; name not captured in the export]

@@ -0,0 +1,28 @@
"""Asset DTOs - data transfer objects"""

# Asset DTOs
from .asset import (
    SubdomainDTO,
    WebSiteDTO,
    IPAddressDTO,
    DirectoryDTO,
    PortDTO,
    EndpointDTO,
)

# Snapshot DTOs
from .snapshot import (
    SubdomainSnapshotDTO,
)

__all__ = [
    # Asset DTOs
    'SubdomainDTO',
    'WebSiteDTO',
    'IPAddressDTO',
    'DirectoryDTO',
    'PortDTO',
    'EndpointDTO',
    # Snapshot DTOs
    'SubdomainSnapshotDTO',
]

[new file; name not captured in the export]

@@ -0,0 +1,21 @@
"""Asset DTOs - data transfer objects"""
from .subdomain_dto import SubdomainDTO
from .ip_address_dto import IPAddressDTO
from .port_dto import PortDTO
from .website_dto import WebSiteDTO
from .directory_dto import DirectoryDTO
from .host_port_mapping_dto import HostPortMappingDTO
from .endpoint_dto import EndpointDTO
from .vulnerability_dto import VulnerabilityDTO

__all__ = [
    'SubdomainDTO',
    'IPAddressDTO',
    'PortDTO',
    'WebSiteDTO',
    'DirectoryDTO',
    'HostPortMappingDTO',
    'EndpointDTO',
    'VulnerabilityDTO',
]

[new file; name not captured in the export]

@@ -0,0 +1,18 @@
"""Directory DTO"""
from dataclasses import dataclass
from typing import Optional

@dataclass
class DirectoryDTO:
    """Directory data transfer object"""
    website_id: int
    target_id: int
    url: str
    status: Optional[int] = None
    content_length: Optional[int] = None
    words: Optional[int] = None
    lines: Optional[int] = None
    content_type: str = ''
    duration: Optional[int] = None

[new file; name not captured in the export]

@@ -0,0 +1,28 @@
"""Endpoint DTO"""
from dataclasses import dataclass
from typing import Optional, List

@dataclass
class EndpointDTO:
    """Endpoint DTO - data transfer object for the asset table"""
    target_id: int
    url: str
    host: Optional[str] = None
    title: Optional[str] = None
    status_code: Optional[int] = None
    content_length: Optional[int] = None
    webserver: Optional[str] = None
    body_preview: Optional[str] = None
    content_type: Optional[str] = None
    tech: Optional[List[str]] = None
    vhost: Optional[bool] = None
    location: Optional[str] = None
    matched_gf_patterns: Optional[List[str]] = None

    def __post_init__(self):
        if self.tech is None:
            self.tech = []
        if self.matched_gf_patterns is None:
            self.matched_gf_patterns = []
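The `None` defaults plus `__post_init__` above work around Python's rule that dataclass fields may not use mutable defaults directly; `field(default_factory=list)` is the equivalent shortcut. A sketch with a hypothetical `ExampleDTO` (not a class from this repo):

```python
# Why the DTOs default list fields to None: a mutable default like
# `tech: List[str] = []` raises ValueError at class-definition time.
# field(default_factory=list) achieves what __post_init__ does above.
from dataclasses import dataclass, field
from typing import List

@dataclass
class ExampleDTO:  # hypothetical illustration
    url: str
    tech: List[str] = field(default_factory=list)  # fresh list per instance

a, b = ExampleDTO(url='https://a.example'), ExampleDTO(url='https://b.example')
a.tech.append('nginx')
assert b.tech == []  # instances do not share the list
```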

[new file; name not captured in the export]

@@ -0,0 +1,12 @@
"""HostPortMapping DTO"""
from dataclasses import dataclass

@dataclass
class HostPortMappingDTO:
    """Host-port mapping DTO (asset table)"""
    target_id: int
    host: str
    ip: str
    port: int

[new file; name not captured in the export]

@@ -0,0 +1,17 @@
"""IPAddress DTO"""
from dataclasses import dataclass

@dataclass
class IPAddressDTO:
    """
    IP address data transfer object

    Holds only the IP's own attributes, with no relations.
    Relations are managed through SubdomainIPAssociationDTO.
    """
    ip: str
    protocol_version: str = ''
    is_private: bool = False
    reverse_pointer: str = ''

[new file; name not captured in the export]

@@ -0,0 +1,13 @@
"""Port DTO"""
from dataclasses import dataclass
from typing import Optional

@dataclass
class PortDTO:
    """Port data transfer object"""
    ip_address_id: int
    number: int
    service_name: str = ''
    target_id: Optional[int] = None
    scan_id: Optional[int] = None

[new file; name not captured in the export]

@@ -0,0 +1,15 @@
"""Subdomain DTO"""
from dataclasses import dataclass

@dataclass
class SubdomainDTO:
    """
    Subdomain DTO (pure asset table)

    Carries subdomain asset data and nothing else;
    scan-related information lives in the snapshot tables.
    """
    name: str
    target_id: int  # required: a subdomain must belong to a target

[new file; name not captured in the export]

@@ -0,0 +1,18 @@
"""Vulnerability DTO"""
from dataclasses import dataclass, field
from typing import Optional, Dict, Any
from decimal import Decimal

@dataclass
class VulnerabilityDTO:
    """Vulnerability data transfer object (for the asset table)"""
    target_id: int
    url: str
    vuln_type: str
    severity: str
    source: str = ""
    cvss_score: Optional[Decimal] = None
    description: str = ""
    raw_output: Dict[str, Any] = field(default_factory=dict)

[new file; name not captured in the export]

@@ -0,0 +1,26 @@
"""WebSite DTO"""
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class WebSiteDTO:
    """Website data transfer object"""
    target_id: int
    url: str
    host: str
    title: str = ''
    status_code: Optional[int] = None
    content_length: Optional[int] = None
    location: str = ''
    webserver: str = ''
    content_type: str = ''
    tech: Optional[List[str]] = None
    body_preview: str = ''
    vhost: Optional[bool] = None
    created_at: Optional[str] = None

    def __post_init__(self):
        if self.tech is None:
            self.tech = []

[new file; name not captured in the export]

@@ -0,0 +1,17 @@
"""Snapshot DTOs"""
from .subdomain_snapshot_dto import SubdomainSnapshotDTO
from .host_port_mapping_snapshot_dto import HostPortMappingSnapshotDTO
from .website_snapshot_dto import WebsiteSnapshotDTO
from .directory_snapshot_dto import DirectorySnapshotDTO
from .endpoint_snapshot_dto import EndpointSnapshotDTO
from .vulnerability_snapshot_dto import VulnerabilitySnapshotDTO

__all__ = [
    'SubdomainSnapshotDTO',
    'HostPortMappingSnapshotDTO',
    'WebsiteSnapshotDTO',
    'DirectorySnapshotDTO',
    'EndpointSnapshotDTO',
    'VulnerabilitySnapshotDTO',
]

[new file; name not captured in the export]

@@ -0,0 +1,48 @@
"""Directory Snapshot DTO"""
from dataclasses import dataclass
from typing import Optional
from apps.asset.dtos.asset import DirectoryDTO

@dataclass
class DirectorySnapshotDTO:
    """
    Directory snapshot data transfer object

    Saves directories discovered during a scan into the snapshot table.

    Note: website_id and target_id are only used to carry data and to
    convert to the asset DTO; they are not stored in the snapshot table.
    A snapshot belongs to a scan only.
    """
    scan_id: int
    website_id: int  # carried for conversion only; not persisted
    target_id: int  # carried for conversion only; not persisted
    url: str
    status: Optional[int] = None
    content_length: Optional[int] = None
    words: Optional[int] = None
    lines: Optional[int] = None
    content_type: str = ''
    duration: Optional[int] = None

    def to_asset_dto(self) -> DirectoryDTO:
        """
        Convert to the asset DTO for syncing into the asset table.

        Note: scan_id is dropped because the asset table does not need it.

        Returns:
            DirectoryDTO: the asset-table DTO
        """
        return DirectoryDTO(
            website_id=self.website_id,
            target_id=self.target_id,
            url=self.url,
            status=self.status,
            content_length=self.content_length,
            words=self.words,
            lines=self.lines,
            content_type=self.content_type,
            duration=self.duration
        )

[new file; name not captured in the export]

@@ -0,0 +1,62 @@
"""EndpointSnapshot DTO"""
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class EndpointSnapshotDTO:
    """
    Endpoint snapshot DTO

    Note: target_id is only used to carry data and to convert to the
    asset DTO; it is not stored in the snapshot table. A snapshot
    belongs to a scan only.
    """
    scan_id: int
    url: str
    host: str = ''  # hostname (domain or IP address)
    title: str = ''
    status_code: Optional[int] = None
    content_length: Optional[int] = None
    location: str = ''
    webserver: str = ''
    content_type: str = ''
    tech: Optional[List[str]] = None
    body_preview: str = ''
    vhost: Optional[bool] = None
    matched_gf_patterns: Optional[List[str]] = None
    target_id: Optional[int] = None  # redundant field, used for syncing to the asset table

    def __post_init__(self):
        if self.tech is None:
            self.tech = []
        if self.matched_gf_patterns is None:
            self.matched_gf_patterns = []

    def to_asset_dto(self):
        """
        Convert to the asset DTO for syncing into the asset table.

        Returns:
            EndpointDTO: the asset-table DTO (scan_id removed)
        """
        from apps.asset.dtos.asset import EndpointDTO
        if self.target_id is None:
            raise ValueError("target_id must not be None; cannot sync to the asset table")
        return EndpointDTO(
            target_id=self.target_id,
            url=self.url,
            host=self.host,
            title=self.title,
            status_code=self.status_code,
            content_length=self.content_length,
            webserver=self.webserver,
            body_preview=self.body_preview,
            content_type=self.content_type,
            tech=self.tech if self.tech else [],
            vhost=self.vhost,
            location=self.location,
            matched_gf_patterns=self.matched_gf_patterns if self.matched_gf_patterns else []
        )

[new file; name not captured in the export]

@@ -0,0 +1,33 @@
"""HostPortMappingSnapshot DTO"""
from dataclasses import dataclass
from typing import Optional

@dataclass
class HostPortMappingSnapshotDTO:
    """Host-port mapping snapshot DTO"""
    scan_id: int
    host: str
    ip: str
    port: int
    target_id: Optional[int] = None  # redundant field, used for syncing to the asset table

    def to_asset_dto(self):
        """
        Convert to the asset DTO for syncing into the asset table.

        Returns:
            HostPortMappingDTO: the asset-table DTO (scan_id removed)
        """
        from apps.asset.dtos.asset import HostPortMappingDTO
        if self.target_id is None:
            raise ValueError("target_id must not be None; cannot sync to the asset table")
        return HostPortMappingDTO(
            target_id=self.target_id,
            host=self.host,
            ip=self.ip,
            port=self.port
        )

[new file; name not captured in the export]

@@ -0,0 +1,34 @@
"""SubdomainSnapshot DTO"""
from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from apps.asset.dtos import SubdomainDTO

@dataclass
class SubdomainSnapshotDTO:
    """
    Subdomain snapshot DTO

    Carries snapshot data with full business context.
    The snapshot table records the history of each scan.
    """
    name: str
    scan_id: int  # required: a snapshot must be tied to a scan task
    target_id: int  # required: target ID, used for converting to the asset DTO

    def to_asset_dto(self) -> 'SubdomainDTO':
        """
        Convert to the asset DTO for saving into the asset table.

        Returns:
            SubdomainDTO: the asset DTO (no scan_id)

        Note:
            The asset table stores core data only; the scan context
            (scan_id) is not persisted there. target_id is already part
            of the DTO, so no extra argument is needed.
        """
        from apps.asset.dtos import SubdomainDTO
        return SubdomainDTO(name=self.name, target_id=self.target_id)
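All of the snapshot DTOs follow this same snapshot-to-asset conversion pattern. A small usage sketch (the IDs are illustrative, not real data):

```python
# Sketch of the snapshot -> asset conversion flow; IDs are made up.
from apps.asset.dtos.snapshot import SubdomainSnapshotDTO

snapshot = SubdomainSnapshotDTO(name='www.example.com', scan_id=42, target_id=7)

# The snapshot row keeps the scan context; the asset row drops it.
asset = snapshot.to_asset_dto()
assert asset.name == 'www.example.com'
assert asset.target_id == 7
assert not hasattr(asset, 'scan_id')  # the asset DTO carries no scan context
```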

[new file; name not captured in the export]

@@ -0,0 +1,42 @@
"""VulnerabilitySnapshot DTO"""
from dataclasses import dataclass, field
from typing import Optional, Dict, Any
from decimal import Decimal

@dataclass
class VulnerabilitySnapshotDTO:
    """Vulnerability snapshot DTO

    Mirrors the VulnerabilitySnapshot model; used to pass vulnerability
    results between Services and Tasks.

    Designed like the other snapshot DTOs:
    - scan_id: belongs to the snapshot table only
    - target_id: used only for converting to the asset DTO; not stored
      in the snapshot table
    """
    scan_id: int
    target_id: int  # carried for conversion only; not persisted in the snapshot table
    url: str
    vuln_type: str
    severity: str
    source: str = ""
    cvss_score: Optional[Decimal] = None
    description: str = ""
    raw_output: Dict[str, Any] = field(default_factory=dict)

    def to_asset_dto(self):
        """Convert to the vulnerability asset DTO (for syncing to the Vulnerability table)."""
        from apps.asset.dtos.asset import VulnerabilityDTO
        return VulnerabilityDTO(
            target_id=self.target_id,
            url=self.url,
            vuln_type=self.vuln_type,
            severity=self.severity,
            source=self.source,
            cvss_score=self.cvss_score,
            description=self.description,
            raw_output=self.raw_output,
        )

[new file; name not captured in the export]

@@ -0,0 +1,55 @@
"""WebsiteSnapshot DTO"""
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class WebsiteSnapshotDTO:
    """
    Website snapshot DTO

    Note: target_id is only used to carry data and to convert to the
    asset DTO; it is not stored in the snapshot table. A snapshot
    belongs to a scan only; target information is reached via scan.target.
    """
    scan_id: int
    target_id: int  # carried for conversion only; not persisted
    url: str
    host: str
    title: str = ''
    status: Optional[int] = None
    content_length: Optional[int] = None
    location: str = ''
    web_server: str = ''
    content_type: str = ''
    tech: Optional[List[str]] = None
    body_preview: str = ''
    vhost: Optional[bool] = None

    def __post_init__(self):
        if self.tech is None:
            self.tech = []

    def to_asset_dto(self):
        """
        Convert to the asset DTO for syncing into the asset table.

        Returns:
            WebSiteDTO: the asset-table DTO (scan_id removed)
        """
        from apps.asset.dtos.asset import WebSiteDTO
        return WebSiteDTO(
            target_id=self.target_id,
            url=self.url,
            host=self.host,
            title=self.title,
            status_code=self.status,
            content_length=self.content_length,
            location=self.location,
            webserver=self.web_server,
            content_type=self.content_type,
            tech=self.tech if self.tech else [],
            body_preview=self.body_preview,
            vhost=self.vhost
        )

[new file; name not captured in the export]

@@ -0,0 +1,45 @@
# Import all models so Django can discover them

# Business models
from .asset_models import (
    Subdomain,
    WebSite,
    Endpoint,
    Directory,
    HostPortMapping,
    Vulnerability,
)

# Snapshot models
from .snapshot_models import (
    SubdomainSnapshot,
    WebsiteSnapshot,
    DirectorySnapshot,
    HostPortMappingSnapshot,
    EndpointSnapshot,
    VulnerabilitySnapshot,
)

# Statistics models
from .statistics_models import AssetStatistics, StatisticsHistory

# Export all models for external imports
__all__ = [
    # Business models
    'Subdomain',
    'WebSite',
    'Endpoint',
    'Directory',
    'HostPortMapping',
    'Vulnerability',
    # Snapshot models
    'SubdomainSnapshot',
    'WebsiteSnapshot',
    'DirectorySnapshot',
    'HostPortMappingSnapshot',
    'EndpointSnapshot',
    'VulnerabilitySnapshot',
    # Statistics models
    'AssetStatistics',
    'StatisticsHistory',
]

[new file; name not captured in the export]

@@ -0,0 +1,515 @@
from django.db import models
from django.contrib.postgres.fields import ArrayField
from django.core.validators import MinValueValidator, MaxValueValidator

class SoftDeleteManager(models.Manager):
    """Soft-delete manager: returns only non-deleted records by default"""
    def get_queryset(self):
        return super().get_queryset().filter(deleted_at__isnull=True)

class Subdomain(models.Model):
    """
    Subdomain model (pure asset table)

    Design notes:
    - Stores subdomain asset data only
    - No direct relation to the other asset tables (IPAddress, Port)
    - Scan history is kept in the SubdomainSnapshot table
    """
    id = models.AutoField(primary_key=True)
    target = models.ForeignKey(
        'targets.Target',  # string reference to avoid a circular import
        on_delete=models.CASCADE,
        related_name='subdomains',
        help_text='Owning scan target (primary relation; required)'
    )
    name = models.CharField(max_length=1000, help_text='Subdomain name')
    discovered_at = models.DateTimeField(auto_now_add=True, help_text='First discovery time')
    # ==================== Soft-delete field ====================
    deleted_at = models.DateTimeField(null=True, blank=True, db_index=True, help_text='Deletion time (NULL means not deleted)')
    # ==================== Managers ====================
    objects = SoftDeleteManager()  # default manager: non-deleted records only
    all_objects = models.Manager()  # full manager: includes deleted records (for hard deletes)

    class Meta:
        db_table = 'subdomain'
        verbose_name = 'Subdomain'
        verbose_name_plural = 'Subdomains'
        ordering = ['-discovered_at']
        indexes = [
            models.Index(fields=['-discovered_at']),
            models.Index(fields=['name', 'target']),  # composite index for get_by_names_and_target_id batch lookups
            models.Index(fields=['target']),  # fast lookup of a target's subdomains
            models.Index(fields=['name']),  # fast lookup by name (search)
            models.Index(fields=['deleted_at', '-discovered_at']),  # soft delete + time
        ]
        constraints = [
            # Partial unique constraint: applies to non-deleted records only
            models.UniqueConstraint(
                fields=['name', 'target'],
                condition=models.Q(deleted_at__isnull=True),
                name='unique_name_target_active'
            )
        ]

    def __str__(self):
        return str(self.name or f'Subdomain {self.id}')
class Endpoint(models.Model):
    """Endpoint model"""
    id = models.AutoField(primary_key=True)
    target = models.ForeignKey(
        'targets.Target',  # string reference
        on_delete=models.CASCADE,
        related_name='endpoints',
        help_text='Owning scan target (primary relation; required)'
    )
    url = models.CharField(max_length=2000, help_text='Final full URL requested')
    host = models.CharField(
        max_length=253,
        blank=True,
        default='',
        help_text='Hostname (domain or IP address)'
    )
    location = models.CharField(
        max_length=1000,
        blank=True,
        default='',
        help_text='Redirect target (HTTP 3xx Location header)'
    )
    discovered_at = models.DateTimeField(auto_now_add=True, help_text='Discovery time')
    title = models.CharField(
        max_length=1000,
        blank=True,
        default='',
        help_text='Page title (HTML <title> content)'
    )
    webserver = models.CharField(
        max_length=200,
        blank=True,
        default='',
        help_text='Server type (HTTP Server response header)'
    )
    body_preview = models.CharField(
        max_length=1000,
        blank=True,
        default='',
        help_text='First N characters of the response body (100 by default)'
    )
    content_type = models.CharField(
        max_length=200,
        blank=True,
        default='',
        help_text='Response type (HTTP Content-Type header)'
    )
    tech = ArrayField(
        models.CharField(max_length=100),
        blank=True,
        default=list,
        help_text='Technology stack (server/framework/language, etc.)'
    )
    status_code = models.IntegerField(
        null=True,
        blank=True,
        help_text='HTTP status code'
    )
    content_length = models.IntegerField(
        null=True,
        blank=True,
        help_text='Response body size in bytes'
    )
    vhost = models.BooleanField(
        null=True,
        blank=True,
        help_text='Whether virtual hosts are supported'
    )
    matched_gf_patterns = ArrayField(
        models.CharField(max_length=100),
        blank=True,
        default=list,
        help_text='Matched GF patterns, used to flag sensitive endpoints (api, debug, config, ...)'
    )
    # ==================== Soft-delete field ====================
    deleted_at = models.DateTimeField(null=True, blank=True, db_index=True, help_text='Deletion time (NULL means not deleted)')
    # ==================== Managers ====================
    objects = SoftDeleteManager()  # default manager: non-deleted records only
    all_objects = models.Manager()  # full manager: includes deleted records (for hard deletes)

    class Meta:
        db_table = 'endpoint'
        verbose_name = 'Endpoint'
        verbose_name_plural = 'Endpoints'
        ordering = ['-discovered_at']
        indexes = [
            models.Index(fields=['-discovered_at']),
            models.Index(fields=['target']),  # fast lookup of a target's endpoints (primary relation)
            models.Index(fields=['url']),  # URL index for query performance
            models.Index(fields=['host']),  # host index for hostname lookups
            models.Index(fields=['status_code']),  # status-code index for filtering
            models.Index(fields=['deleted_at', '-discovered_at']),  # soft delete + time
        ]
        constraints = [
            # Partial unique constraint: applies to non-deleted records only
            models.UniqueConstraint(
                fields=['url', 'target'],
                condition=models.Q(deleted_at__isnull=True),
                name='unique_endpoint_url_target_active'
            )
        ]

    def __str__(self):
        return str(self.url or f'Endpoint {self.id}')
class WebSite(models.Model):
    """Website model"""
    id = models.AutoField(primary_key=True)
    target = models.ForeignKey(
        'targets.Target',  # string reference
        on_delete=models.CASCADE,
        related_name='websites',
        help_text='Owning scan target (primary relation; required)'
    )
    url = models.CharField(max_length=2000, help_text='Final full URL requested')
    host = models.CharField(
        max_length=253,
        blank=True,
        default='',
        help_text='Hostname (domain or IP address)'
    )
    location = models.CharField(
        max_length=1000,
        blank=True,
        default='',
        help_text='Redirect target (HTTP 3xx Location header)'
    )
    discovered_at = models.DateTimeField(auto_now_add=True, help_text='Discovery time')
    title = models.CharField(
        max_length=1000,
        blank=True,
        default='',
        help_text='Page title (HTML <title> content)'
    )
    webserver = models.CharField(
        max_length=200,
        blank=True,
        default='',
        help_text='Server type (HTTP Server response header)'
    )
    body_preview = models.CharField(
        max_length=1000,
        blank=True,
        default='',
        help_text='First N characters of the response body (100 by default)'
    )
    content_type = models.CharField(
        max_length=200,
        blank=True,
        default='',
        help_text='Response type (HTTP Content-Type header)'
    )
    tech = ArrayField(
        models.CharField(max_length=100),
        blank=True,
        default=list,
        help_text='Technology stack (server/framework/language, etc.)'
    )
    status_code = models.IntegerField(
        null=True,
        blank=True,
        help_text='HTTP status code'
    )
    content_length = models.IntegerField(
        null=True,
        blank=True,
        help_text='Response body size in bytes'
    )
    vhost = models.BooleanField(
        null=True,
        blank=True,
        help_text='Whether virtual hosts are supported'
    )
    # ==================== Soft-delete field ====================
    deleted_at = models.DateTimeField(null=True, blank=True, db_index=True, help_text='Deletion time (NULL means not deleted)')
    # ==================== Managers ====================
    objects = SoftDeleteManager()  # default manager: non-deleted records only
    all_objects = models.Manager()  # full manager: includes deleted records (for hard deletes)

    class Meta:
        db_table = 'website'
        verbose_name = 'Website'
        verbose_name_plural = 'Websites'
        ordering = ['-discovered_at']
        indexes = [
            models.Index(fields=['-discovered_at']),
            models.Index(fields=['url']),  # URL index for query performance
            models.Index(fields=['host']),  # host index for hostname lookups
            models.Index(fields=['target']),  # fast lookup of a target's websites
            models.Index(fields=['deleted_at', '-discovered_at']),  # soft delete + time
        ]
        constraints = [
            # Partial unique constraint: applies to non-deleted records only
            models.UniqueConstraint(
                fields=['url', 'target'],
                condition=models.Q(deleted_at__isnull=True),
                name='unique_website_url_target_active'
            )
        ]

    def __str__(self):
        return str(self.url or f'Website {self.id}')
class Directory(models.Model):
    """
    Directory model
    """
    id = models.AutoField(primary_key=True)
    website = models.ForeignKey(
        'Website',
        on_delete=models.CASCADE,
        related_name='directories',
        help_text='Owning website (primary relation; required)'
    )
    target = models.ForeignKey(
        'targets.Target',  # string reference
        on_delete=models.CASCADE,
        related_name='directories',
        null=True,
        blank=True,
        help_text='Owning scan target (redundant field for fast lookups)'
    )
    url = models.CharField(
        null=False,
        blank=False,
        max_length=2000,
        help_text='Full request URL'
    )
    status = models.IntegerField(
        null=True,
        blank=True,
        help_text='HTTP response status code'
    )
    content_length = models.BigIntegerField(
        null=True,
        blank=True,
        help_text='Response body size in bytes (Content-Length or actual length)'
    )
    words = models.IntegerField(
        null=True,
        blank=True,
        help_text='Word count of the response body (split on spaces)'
    )
    lines = models.IntegerField(
        null=True,
        blank=True,
        help_text='Line count of the response body (split on newlines)'
    )
    content_type = models.CharField(
        max_length=200,
        blank=True,
        default='',
        help_text='Content-Type response header value'
    )
    duration = models.BigIntegerField(
        null=True,
        blank=True,
        help_text='Request duration (nanoseconds)'
    )
    discovered_at = models.DateTimeField(auto_now_add=True, help_text='Discovery time')
    # ==================== Soft-delete field ====================
    deleted_at = models.DateTimeField(null=True, blank=True, db_index=True, help_text='Deletion time (NULL means not deleted)')
    # ==================== Managers ====================
    objects = SoftDeleteManager()  # default manager: non-deleted records only
    all_objects = models.Manager()  # full manager: includes deleted records (for hard deletes)

    class Meta:
        db_table = 'directory'
        verbose_name = 'Directory'
        verbose_name_plural = 'Directories'
        ordering = ['-discovered_at']
        indexes = [
            models.Index(fields=['-discovered_at']),
            models.Index(fields=['target']),  # fast lookup of a target's directories
            models.Index(fields=['url']),  # URL index for search and the unique constraint
            models.Index(fields=['website']),  # website index for per-site queries
            models.Index(fields=['status']),  # status-code index for filtering
            models.Index(fields=['deleted_at', '-discovered_at']),  # soft delete + time
        ]
        constraints = [
            # Partial unique constraint: applies to non-deleted records only
            models.UniqueConstraint(
                fields=['website', 'url'],
                condition=models.Q(deleted_at__isnull=True),
                name='unique_directory_url_website_active'
            ),
        ]

    def __str__(self):
        return str(self.url or f'Directory {self.id}')
class HostPortMapping(models.Model):
    """
    Host-port mapping table

    Design notes:
    - Stores the (host, IP, port) triple mapping
    - Relates only to target_id, not to other asset tables
    - target + host + ip + port form a composite unique constraint
    """
    id = models.AutoField(primary_key=True)
    # ==================== Relations ====================
    target = models.ForeignKey(
        'targets.Target',
        on_delete=models.CASCADE,
        related_name='host_port_mappings',
        help_text='Owning scan target'
    )
    # ==================== Core fields ====================
    host = models.CharField(
        max_length=1000,
        blank=False,
        help_text='Hostname (domain or IP)'
    )
    ip = models.GenericIPAddressField(
        blank=False,
        help_text='IP address'
    )
    port = models.IntegerField(
        blank=False,
        validators=[
            MinValueValidator(1, message='Port number must be at least 1'),
            MaxValueValidator(65535, message='Port number must be at most 65535')
        ],
        help_text='Port number (1-65535)'
    )
    # ==================== Timestamps ====================
    discovered_at = models.DateTimeField(
        auto_now_add=True,
        help_text='Discovery time'
    )
    # ==================== Soft-delete field ====================
    deleted_at = models.DateTimeField(
        null=True,
        blank=True,
        db_index=True,
        help_text='Deletion time (NULL means not deleted)'
    )
    # ==================== Managers ====================
    objects = SoftDeleteManager()  # default manager: non-deleted records only
    all_objects = models.Manager()  # full manager: includes deleted records (for hard deletes)

    class Meta:
        db_table = 'host_port_mapping'
        verbose_name = 'Host-port mapping'
        verbose_name_plural = 'Host-port mappings'
        ordering = ['-discovered_at']
        indexes = [
            models.Index(fields=['target']),  # per-target queries
            models.Index(fields=['host']),  # per-host queries
            models.Index(fields=['ip']),  # per-IP queries
            models.Index(fields=['port']),  # per-port queries
            models.Index(fields=['host', 'ip']),  # combined queries
            models.Index(fields=['-discovered_at']),  # time ordering
            models.Index(fields=['deleted_at', '-discovered_at']),  # soft delete + time
        ]
        constraints = [
            # Composite unique constraint on target + host + ip + port (non-deleted records only)
            models.UniqueConstraint(
                fields=['target', 'host', 'ip', 'port'],
                condition=models.Q(deleted_at__isnull=True),
                name='unique_target_host_ip_port_active'
            ),
        ]

    def __str__(self):
        return f'{self.host} ({self.ip}:{self.port})'
class Vulnerability(models.Model):
    """
    Vulnerability model (asset table)

    Stores discovered vulnerability assets, related to Target.
    Scan history is kept in the VulnerabilitySnapshot table.
    """
    # Deferred import to avoid a circular reference
    from apps.common.definitions import VulnSeverity

    id = models.AutoField(primary_key=True)
    target = models.ForeignKey(
        'targets.Target',
        on_delete=models.CASCADE,
        related_name='vulnerabilities',
        help_text='Owning scan target'
    )
    # ==================== Core fields ====================
    url = models.TextField(help_text='URL where the vulnerability was found')
    vuln_type = models.CharField(max_length=100, help_text='Vulnerability type (e.g. xss, sqli)')
    severity = models.CharField(
        max_length=20,
        choices=VulnSeverity.choices,
        default=VulnSeverity.UNKNOWN,
        help_text='Severity (unknown/info/low/medium/high/critical)'
    )
    source = models.CharField(max_length=50, blank=True, default='', help_text='Source tool (e.g. dalfox, nuclei, crlfuzz)')
    cvss_score = models.DecimalField(max_digits=3, decimal_places=1, null=True, blank=True, help_text='CVSS score (0.0-10.0)')
    description = models.TextField(blank=True, default='', help_text='Vulnerability description')
    raw_output = models.JSONField(blank=True, default=dict, help_text='Raw tool output')
    # ==================== Timestamps ====================
    discovered_at = models.DateTimeField(auto_now_add=True, help_text='First discovery time')
    # ==================== Soft-delete field ====================
    deleted_at = models.DateTimeField(null=True, blank=True, db_index=True, help_text='Deletion time (NULL means not deleted)')
    # ==================== Managers ====================
    objects = SoftDeleteManager()
    all_objects = models.Manager()

    class Meta:
        db_table = 'vulnerability'
        verbose_name = 'Vulnerability'
        verbose_name_plural = 'Vulnerabilities'
        ordering = ['-discovered_at']
        indexes = [
            models.Index(fields=['target']),
            models.Index(fields=['vuln_type']),
            models.Index(fields=['severity']),
            models.Index(fields=['source']),
            models.Index(fields=['-discovered_at']),
            models.Index(fields=['deleted_at', '-discovered_at']),
        ]

    def __str__(self):
        return f'{self.vuln_type} - {self.url[:50]}'
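All the asset models above share the same soft-delete pattern. A short sketch of how the two managers and the partial unique constraints interact (queryset calls are standard Django; the IDs are illustrative):

```python
# Illustrative use of the soft-delete pattern defined above.
from django.utils import timezone
from apps.asset.models import Subdomain

# Soft delete: the row stays in the table but vanishes from the default manager.
Subdomain.objects.filter(id=1).update(deleted_at=timezone.now())
assert not Subdomain.objects.filter(id=1).exists()   # hidden by SoftDeleteManager
assert Subdomain.all_objects.filter(id=1).exists()   # still stored

# Because unique_name_target_active only covers rows with deleted_at IS NULL,
# the same (name, target) pair can be recreated after a soft delete.
# Hard delete goes through the unfiltered manager:
Subdomain.all_objects.filter(id=1).delete()
```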

[new file; name not captured in the export]

@@ -0,0 +1,335 @@
from django.db import models
from django.contrib.postgres.fields import ArrayField
from django.core.validators import MinValueValidator, MaxValueValidator

class SubdomainSnapshot(models.Model):
    """Subdomain snapshot"""
    id = models.AutoField(primary_key=True)
    scan = models.ForeignKey(
        'scan.Scan',
        on_delete=models.CASCADE,
        related_name='subdomain_snapshots',
        help_text='Owning scan task'
    )
    name = models.CharField(max_length=1000, help_text='Subdomain name')
    discovered_at = models.DateTimeField(auto_now_add=True, help_text='Discovery time')

    class Meta:
        db_table = 'subdomain_snapshot'
        verbose_name = 'Subdomain snapshot'
        verbose_name_plural = 'Subdomain snapshots'
        ordering = ['-discovered_at']
        indexes = [
            models.Index(fields=['scan']),
            models.Index(fields=['name']),
            models.Index(fields=['-discovered_at']),
        ]
        constraints = [
            # Unique constraint: a subdomain is recorded once per scan
            models.UniqueConstraint(
                fields=['scan', 'name'],
                name='unique_subdomain_per_scan_snapshot'
            ),
        ]

    def __str__(self):
        return f'{self.name} (Scan #{self.scan_id})'
class WebsiteSnapshot(models.Model):
    """
    Website snapshot

    Records the websites discovered in a given scan.
    """
    id = models.AutoField(primary_key=True)
    scan = models.ForeignKey(
        'scan.Scan',
        on_delete=models.CASCADE,
        related_name='website_snapshots',
        help_text='Owning scan task'
    )
    # Scan result data
    url = models.CharField(max_length=2000, help_text='Site URL')
    host = models.CharField(max_length=253, blank=True, default='', help_text='Hostname (domain or IP address)')
    title = models.CharField(max_length=500, blank=True, default='', help_text='Page title')
    status = models.IntegerField(null=True, blank=True, help_text='HTTP status code')
    content_length = models.BigIntegerField(null=True, blank=True, help_text='Content length')
    location = models.CharField(max_length=1000, blank=True, default='', help_text='Redirect location')
    web_server = models.CharField(max_length=200, blank=True, default='', help_text='Web server')
    content_type = models.CharField(max_length=200, blank=True, default='', help_text='Content type')
    tech = ArrayField(
        models.CharField(max_length=100),
        blank=True,
        default=list,
        help_text='Technology stack'
    )
    body_preview = models.TextField(blank=True, default='', help_text='Response body preview')
    vhost = models.BooleanField(null=True, blank=True, help_text='Virtual-host flag')
    discovered_at = models.DateTimeField(auto_now_add=True, help_text='Discovery time')

    class Meta:
        db_table = 'website_snapshot'
        verbose_name = 'Website snapshot'
        verbose_name_plural = 'Website snapshots'
        ordering = ['-discovered_at']
        indexes = [
            models.Index(fields=['scan']),
            models.Index(fields=['url']),
            models.Index(fields=['host']),  # host index for hostname lookups
            models.Index(fields=['-discovered_at']),
        ]
        constraints = [
            # Unique constraint: a URL is recorded once per scan
            models.UniqueConstraint(
                fields=['scan', 'url'],
                name='unique_website_per_scan_snapshot'
            ),
        ]

    def __str__(self):
        return f'{self.url} (Scan #{self.scan_id})'
class DirectorySnapshot(models.Model):
    """
    Directory snapshot

    Records the directories discovered in a given scan.
    """
    id = models.AutoField(primary_key=True)
    scan = models.ForeignKey(
        'scan.Scan',
        on_delete=models.CASCADE,
        related_name='directory_snapshots',
        help_text='Owning scan task'
    )
    # Scan result data
    url = models.CharField(max_length=2000, help_text='Directory URL')
    status = models.IntegerField(null=True, blank=True, help_text='HTTP status code')
    content_length = models.BigIntegerField(null=True, blank=True, help_text='Content length')
    words = models.IntegerField(null=True, blank=True, help_text='Word count of the response body (split on spaces)')
    lines = models.IntegerField(null=True, blank=True, help_text='Line count of the response body (split on newlines)')
    content_type = models.CharField(max_length=200, blank=True, default='', help_text='Content-Type response header value')
    duration = models.BigIntegerField(null=True, blank=True, help_text='Request duration (nanoseconds)')
    discovered_at = models.DateTimeField(auto_now_add=True, help_text='Discovery time')

    class Meta:
        db_table = 'directory_snapshot'
        verbose_name = 'Directory snapshot'
        verbose_name_plural = 'Directory snapshots'
        ordering = ['-discovered_at']
        indexes = [
            models.Index(fields=['scan']),
            models.Index(fields=['url']),
            models.Index(fields=['status']),  # status-code index for filtering
            models.Index(fields=['-discovered_at']),
        ]
        constraints = [
            # Unique constraint: a directory URL is recorded once per scan
            models.UniqueConstraint(
                fields=['scan', 'url'],
                name='unique_directory_per_scan_snapshot'
            ),
        ]

    def __str__(self):
        return f'{self.url} (Scan #{self.scan_id})'
class HostPortMappingSnapshot(models.Model):
    """
    Host-port mapping snapshot table

    Design notes:
    - Stores the (host, IP, port) triples discovered in a given scan
    - Primary relation is scan_id, recording scan history
    - scan + host + ip + port form a composite unique constraint
    """
    id = models.AutoField(primary_key=True)
    # ==================== Relations ====================
    scan = models.ForeignKey(
        'scan.Scan',
        on_delete=models.CASCADE,
        related_name='host_port_mapping_snapshots',
        help_text='Owning scan task (primary relation)'
    )
    # ==================== Core fields ====================
    host = models.CharField(
        max_length=1000,
        blank=False,
        help_text='Hostname (domain or IP)'
    )
    ip = models.GenericIPAddressField(
        blank=False,
        help_text='IP address'
    )
    port = models.IntegerField(
        blank=False,
        validators=[
            MinValueValidator(1, message='Port number must be at least 1'),
            MaxValueValidator(65535, message='Port number must be at most 65535')
        ],
        help_text='Port number (1-65535)'
    )
    # ==================== Timestamps ====================
    discovered_at = models.DateTimeField(
        auto_now_add=True,
        help_text='Discovery time'
    )

    class Meta:
        db_table = 'host_port_mapping_snapshot'
        verbose_name = 'Host-port mapping snapshot'
        verbose_name_plural = 'Host-port mapping snapshots'
        ordering = ['-discovered_at']
        indexes = [
            models.Index(fields=['scan']),  # per-scan queries
            models.Index(fields=['host']),  # per-host queries
            models.Index(fields=['ip']),  # per-IP queries
            models.Index(fields=['port']),  # per-port queries
            models.Index(fields=['host', 'ip']),  # combined queries
            models.Index(fields=['scan', 'host']),  # scan + host queries
            models.Index(fields=['-discovered_at']),  # time ordering
        ]
        constraints = [
            # Composite unique constraint: scan + host + ip + port unique within a scan
            models.UniqueConstraint(
                fields=['scan', 'host', 'ip', 'port'],
                name='unique_scan_host_ip_port_snapshot'
            ),
        ]

    def __str__(self):
        return f'{self.host} ({self.ip}:{self.port}) [Scan #{self.scan_id}]'
class EndpointSnapshot(models.Model):
    """
    Endpoint snapshot

    Records the endpoints discovered in a given scan.
    """
    id = models.AutoField(primary_key=True)
    scan = models.ForeignKey(
        'scan.Scan',
        on_delete=models.CASCADE,
        related_name='endpoint_snapshots',
        help_text='Owning scan task'
    )
    # Scan result data
    url = models.CharField(max_length=2000, help_text='Endpoint URL')
    host = models.CharField(
        max_length=253,
        blank=True,
        default='',
        help_text='Hostname (domain or IP address)'
    )
    title = models.CharField(max_length=1000, blank=True, default='', help_text='Page title')
    status_code = models.IntegerField(null=True, blank=True, help_text='HTTP status code')
    content_length = models.IntegerField(null=True, blank=True, help_text='Content length')
    location = models.CharField(max_length=1000, blank=True, default='', help_text='Redirect location')
    webserver = models.CharField(max_length=200, blank=True, default='', help_text='Web server')
    content_type = models.CharField(max_length=200, blank=True, default='', help_text='Content type')
    tech = ArrayField(
        models.CharField(max_length=100),
        blank=True,
        default=list,
        help_text='Technology stack'
    )
    body_preview = models.CharField(max_length=1000, blank=True, default='', help_text='Response body preview')
    vhost = models.BooleanField(null=True, blank=True, help_text='Virtual-host flag')
    matched_gf_patterns = ArrayField(
        models.CharField(max_length=100),
        blank=True,
        default=list,
        help_text='Matched GF patterns'
    )
    discovered_at = models.DateTimeField(auto_now_add=True, help_text='Discovery time')

    class Meta:
        db_table = 'endpoint_snapshot'
        verbose_name = 'Endpoint snapshot'
        verbose_name_plural = 'Endpoint snapshots'
        ordering = ['-discovered_at']
        indexes = [
            models.Index(fields=['scan']),
            models.Index(fields=['url']),
            models.Index(fields=['host']),  # host index for hostname lookups
            models.Index(fields=['status_code']),  # status-code index for filtering
            models.Index(fields=['-discovered_at']),
        ]
        constraints = [
            # Unique constraint: a URL is recorded once per scan
            models.UniqueConstraint(
                fields=['scan', 'url'],
                name='unique_endpoint_per_scan_snapshot'
            ),
        ]

    def __str__(self):
        return f'{self.url} (Scan #{self.scan_id})'
class VulnerabilitySnapshot(models.Model):
    """
    Vulnerability snapshot

    Records the vulnerabilities discovered in a given scan.
    """
    # Deferred import to avoid a circular reference
    from apps.common.definitions import VulnSeverity

    id = models.AutoField(primary_key=True)
    scan = models.ForeignKey(
        'scan.Scan',
        on_delete=models.CASCADE,
        related_name='vulnerability_snapshots',
        help_text='Owning scan task'
    )
    # ==================== Core fields ====================
    url = models.TextField(help_text='URL where the vulnerability was found')
    vuln_type = models.CharField(max_length=100, help_text='Vulnerability type (e.g. xss, sqli)')
    severity = models.CharField(
        max_length=20,
        choices=VulnSeverity.choices,
        default=VulnSeverity.UNKNOWN,
        help_text='Severity (unknown/info/low/medium/high/critical)'
    )
    source = models.CharField(max_length=50, blank=True, default='', help_text='Source tool (e.g. dalfox, nuclei, crlfuzz)')
    cvss_score = models.DecimalField(max_digits=3, decimal_places=1, null=True, blank=True, help_text='CVSS score (0.0-10.0)')
    description = models.TextField(blank=True, default='', help_text='Vulnerability description')
    raw_output = models.JSONField(blank=True, default=dict, help_text='Raw tool output')
    # ==================== Timestamps ====================
    discovered_at = models.DateTimeField(auto_now_add=True, help_text='Discovery time')

    class Meta:
        db_table = 'vulnerability_snapshot'
        verbose_name = 'Vulnerability snapshot'
        verbose_name_plural = 'Vulnerability snapshots'
        ordering = ['-discovered_at']
        indexes = [
            models.Index(fields=['scan']),
            models.Index(fields=['vuln_type']),
            models.Index(fields=['severity']),
            models.Index(fields=['source']),
            models.Index(fields=['-discovered_at']),
        ]

    def __str__(self):
        return f'{self.vuln_type} - {self.url[:50]} (Scan #{self.scan_id})'

[new file; name not captured in the export]

@@ -0,0 +1,82 @@
from django.db import models

class AssetStatistics(models.Model):
    """
    Asset statistics table

    Stores pre-aggregated global statistics so the dashboard avoids live
    COUNTs over large tables. Refreshed periodically by a scheduled task
    (Prefect flow).
    """
    id = models.AutoField(primary_key=True)
    # ==================== Current counts ====================
    total_targets = models.IntegerField(default=0, help_text='Total targets')
    total_subdomains = models.IntegerField(default=0, help_text='Total subdomains')
    total_ips = models.IntegerField(default=0, help_text='Total IP addresses')
    total_endpoints = models.IntegerField(default=0, help_text='Total endpoints')
    total_websites = models.IntegerField(default=0, help_text='Total websites')
    total_vulns = models.IntegerField(default=0, help_text='Total vulnerabilities')
    total_assets = models.IntegerField(default=0, help_text='Total assets (subdomains + IPs + endpoints + websites)')
    # ==================== Previous counts (for trend calculation) ====================
    prev_targets = models.IntegerField(default=0, help_text='Previous total targets')
    prev_subdomains = models.IntegerField(default=0, help_text='Previous total subdomains')
    prev_ips = models.IntegerField(default=0, help_text='Previous total IP addresses')
    prev_endpoints = models.IntegerField(default=0, help_text='Previous total endpoints')
    prev_websites = models.IntegerField(default=0, help_text='Previous total websites')
    prev_vulns = models.IntegerField(default=0, help_text='Previous total vulnerabilities')
    prev_assets = models.IntegerField(default=0, help_text='Previous total assets')
    # ==================== Timestamps ====================
    updated_at = models.DateTimeField(auto_now=True, help_text='Last update time')

    class Meta:
        db_table = 'asset_statistics'
        verbose_name = 'Asset statistics'
        verbose_name_plural = 'Asset statistics'

    def __str__(self):
        return f'AssetStatistics (updated: {self.updated_at})'

    @classmethod
    def get_or_create_singleton(cls) -> 'AssetStatistics':
        """Fetch or create the singleton statistics record"""
        obj, _ = cls.objects.get_or_create(pk=1)
        return obj

class StatisticsHistory(models.Model):
    """
    Statistics history table (for line charts)

    One snapshot row per day, used to display trends.
    Written automatically when the scheduled task refreshes statistics.
    """
    date = models.DateField(unique=True, help_text='Statistics date')
    # Per-asset-type counts
    total_targets = models.IntegerField(default=0, help_text='Total targets')
    total_subdomains = models.IntegerField(default=0, help_text='Total subdomains')
    total_ips = models.IntegerField(default=0, help_text='Total IP addresses')
    total_endpoints = models.IntegerField(default=0, help_text='Total endpoints')
    total_websites = models.IntegerField(default=0, help_text='Total websites')
    total_vulns = models.IntegerField(default=0, help_text='Total vulnerabilities')
    total_assets = models.IntegerField(default=0, help_text='Total assets')
    created_at = models.DateTimeField(auto_now_add=True, help_text='Creation time')
    updated_at = models.DateTimeField(auto_now=True, help_text='Update time')

    class Meta:
        db_table = 'statistics_history'
        verbose_name = 'Statistics history'
        verbose_name_plural = 'Statistics history'
        ordering = ['-date']
        indexes = [
            models.Index(fields=['date']),
        ]

    def __str__(self):
        return f'StatisticsHistory ({self.date})'
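The prev_* columns exist so a refresh task can roll the current counts into the "previous" slot before writing new ones, which is what makes dashboard trend arrows possible. A sketch of one plausible refresh step (the `refresh_asset_statistics` helper is hypothetical, not code from this repo):

```python
# Sketch of how a refresh task might use the singleton and prev_* columns.
from apps.asset.models import AssetStatistics, Subdomain

def refresh_asset_statistics() -> AssetStatistics:  # hypothetical helper
    stats = AssetStatistics.get_or_create_singleton()
    # Roll the current value into the "previous" slot to keep the trend.
    stats.prev_subdomains = stats.total_subdomains
    # Recount from the source table (soft-deleted rows are already excluded
    # because the default manager is SoftDeleteManager).
    stats.total_subdomains = Subdomain.objects.count()
    stats.save()
    return stats

# Dashboard trend: positive means growth since the last refresh.
stats = refresh_asset_statistics()
delta = stats.total_subdomains - stats.prev_subdomains
```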

[new file; name not captured in the export]

@@ -0,0 +1,41 @@
"""Asset Repositories - data access layer"""

# Asset repositories
from .asset import (
    DjangoSubdomainRepository,
    DjangoWebSiteRepository,
    DjangoDirectoryRepository,
    DjangoHostPortMappingRepository,
    DjangoEndpointRepository,
)

# Snapshot repositories
from .snapshot import (
    DjangoSubdomainSnapshotRepository,
    DjangoHostPortMappingSnapshotRepository,
    DjangoWebsiteSnapshotRepository,
    DjangoDirectorySnapshotRepository,
    DjangoEndpointSnapshotRepository,
)

# Statistics repository
from .statistics_repository import AssetStatisticsRepository

__all__ = [
    # Asset repositories
    'DjangoSubdomainRepository',
    'DjangoWebSiteRepository',
    'DjangoDirectoryRepository',
    'DjangoHostPortMappingRepository',
    'DjangoEndpointRepository',
    # Snapshot repositories
    'DjangoSubdomainSnapshotRepository',
    'DjangoHostPortMappingSnapshotRepository',
    'DjangoWebsiteSnapshotRepository',
    'DjangoDirectorySnapshotRepository',
    'DjangoEndpointSnapshotRepository',
    # Statistics repository
    'AssetStatisticsRepository',
]

[new file; name not captured in the export]

@@ -0,0 +1,15 @@
"""Asset Repositories - data access layer"""
from .subdomain_repository import DjangoSubdomainRepository
from .website_repository import DjangoWebSiteRepository
from .directory_repository import DjangoDirectoryRepository
from .host_port_mapping_repository import DjangoHostPortMappingRepository
from .endpoint_repository import DjangoEndpointRepository

__all__ = [
    'DjangoSubdomainRepository',
    'DjangoWebSiteRepository',
    'DjangoDirectoryRepository',
    'DjangoHostPortMappingRepository',
    'DjangoEndpointRepository',
]

[new file; name not captured in the export]

@@ -0,0 +1,249 @@
"""
Django ORM implementation of the Directory repository
"""
import logging
from typing import List, Tuple, Dict, Iterator
from django.db import transaction, IntegrityError, OperationalError, DatabaseError
from django.utils import timezone
from apps.asset.models.asset_models import Directory
from apps.asset.dtos import DirectoryDTO
from apps.common.decorators import auto_ensure_db_connection

logger = logging.getLogger(__name__)

@auto_ensure_db_connection
class DjangoDirectoryRepository:
    """Django ORM implementation of the Directory repository"""

    def bulk_create_ignore_conflicts(self, items: List[DirectoryDTO]) -> int:
        """
        Bulk-create Directory rows, ignoring conflicts.

        Args:
            items: list of Directory DTOs
        Returns:
            int: number of records submitted (conflicting rows are skipped,
                 so the actual insert count may be lower)
        Raises:
            IntegrityError: data integrity error
            OperationalError: database operation error
            DatabaseError: database error
        """
        if not items:
            return 0
        try:
            # Convert to Django model instances
            directory_objects = [
                Directory(
                    website_id=item.website_id,
                    target_id=item.target_id,
                    url=item.url,
                    status=item.status,
                    content_length=item.content_length,
                    words=item.words,
                    lines=item.lines,
                    content_type=item.content_type,
                    duration=item.duration
                )
                for item in items
            ]
            with transaction.atomic():
                # Bulk insert, ignoring conflicts:
                # if website + url already exists, the row is skipped
                Directory.objects.bulk_create(
                    directory_objects,
                    ignore_conflicts=True
                )
            logger.debug(f"Processed {len(items)} Directory records")
            return len(items)
        except IntegrityError as e:
            logger.error(
                f"Bulk insert of Directory failed - integrity error: {e}, "
                f"record count: {len(items)}"
            )
            raise
        except OperationalError as e:
            logger.error(
                f"Bulk insert of Directory failed - operational error: {e}, "
                f"record count: {len(items)}"
            )
            raise
        except DatabaseError as e:
            logger.error(
                f"Bulk insert of Directory failed - database error: {e}, "
                f"record count: {len(items)}"
            )
            raise
        except Exception as e:
            logger.error(
                f"Bulk insert of Directory failed - unexpected error: {e}, "
                f"record count: {len(items)}, "
                f"error type: {type(e).__name__}",
                exc_info=True
            )
            raise
    def get_by_website(self, website_id: int) -> List[DirectoryDTO]:
        """
        Fetch all directories of a given website.

        Args:
            website_id: website ID
        Returns:
            List[DirectoryDTO]: directory list
        """
        try:
            directories = Directory.objects.filter(website_id=website_id)
            return [
                DirectoryDTO(
                    website_id=d.website_id,
                    target_id=d.target_id,
                    url=d.url,
                    status=d.status,
                    content_length=d.content_length,
                    words=d.words,
                    lines=d.lines,
                    content_type=d.content_type,
                    duration=d.duration
                )
                for d in directories
            ]
        except Exception as e:
            logger.error(f"Failed to fetch directory list - Website ID: {website_id}, error: {e}")
            raise

    def count_by_website(self, website_id: int) -> int:
        """
        Count the directories of a given website.

        Args:
            website_id: website ID
        Returns:
            int: directory count
        """
        try:
            count = Directory.objects.filter(website_id=website_id).count()
            logger.debug(f"Website {website_id} has {count} directories")
            return count
        except Exception as e:
            logger.error(f"Failed to count directories - Website ID: {website_id}, error: {e}")
            raise

    def get_all(self):
        """
        Fetch all directories.

        Returns:
            QuerySet: directory queryset
        """
        return Directory.objects.all()

    def get_by_target(self, target_id: int):
        return Directory.objects.filter(target_id=target_id).select_related('website').order_by('-discovered_at')

    def get_urls_for_export(self, target_id: int, batch_size: int = 1000) -> Iterator[str]:
        """Stream all directory URLs under a target (fetches only the url column to avoid loading extra data)."""
        try:
            queryset = (
                Directory.objects
                .filter(target_id=target_id)
                .values_list('url', flat=True)
                .order_by('url')
                .iterator(chunk_size=batch_size)
            )
            for url in queryset:
                yield url
        except Exception as e:
            logger.error("Streaming directory-URL export failed - Target ID: %s, error: %s", target_id, e)
            raise
def soft_delete_by_ids(self, directory_ids: List[int]) -> int:
"""
根据 ID 列表批量软删除Directory
Args:
directory_ids: Directory ID 列表
Returns:
软删除的记录数
"""
try:
updated_count = (
Directory.objects
.filter(id__in=directory_ids)
.update(deleted_at=timezone.now())
)
logger.debug(
"批量软删除Directory成功 - Count: %s, 更新记录: %s",
len(directory_ids),
updated_count
)
return updated_count
except Exception as e:
logger.error(
"批量软删除Directory失败 - IDs: %s, 错误: %s",
directory_ids,
e
)
raise
def hard_delete_by_ids(self, directory_ids: List[int]) -> Tuple[int, Dict[str, int]]:
"""
根据 ID 列表硬删除Directory使用数据库级 CASCADE
Args:
directory_ids: Directory ID 列表
Returns:
(删除的记录数, 删除详情字典)
"""
try:
batch_size = 1000
total_deleted = 0
logger.debug(f"开始批量删除 {len(directory_ids)} 个Directory数据库 CASCADE...")
for i in range(0, len(directory_ids), batch_size):
batch_ids = directory_ids[i:i + batch_size]
count, _ = Directory.all_objects.filter(id__in=batch_ids).delete()
total_deleted += count
logger.debug(f"批次删除完成: {len(batch_ids)} 个Directory删除 {count} 条记录")
deleted_details = {
'directories': len(directory_ids),
'total': total_deleted,
'note': 'Database CASCADE - detailed stats unavailable'
}
logger.debug(
"批量硬删除成功CASCADE- Directory数: %s, 总删除记录: %s",
len(directory_ids),
total_deleted
)
return total_deleted, deleted_details
except Exception as e:
logger.error(
"批量硬删除失败CASCADE- Directory数: %s, 错误: %s",
len(directory_ids),
str(e),
exc_info=True
)
raise
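
Usage sketch for the repository above (IDs and field values are hypothetical; it assumes Django settings are loaded and that DirectoryDTO accepts the same keyword fields the repository maps):

from apps.asset.dtos import DirectoryDTO
from apps.asset.repositories import DjangoDirectoryRepository

repo = DjangoDirectoryRepository()
repo.bulk_create_ignore_conflicts([
    DirectoryDTO(
        website_id=1, target_id=1, url="https://example.com/admin/",
        status=403, content_length=1234, words=120, lines=40,
        content_type="text/html", duration=0.21,
    ),
])  # rows colliding on (website, url) are skipped silently
for url in repo.get_urls_for_export(target_id=1, batch_size=500):
    print(url)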

View File

@@ -0,0 +1,192 @@
"""Endpoint Repository - Django ORM 实现"""
import logging
from typing import List, Optional, Tuple, Dict, Any
from apps.asset.models import Endpoint
from apps.asset.dtos.asset import EndpointDTO
from apps.common.decorators import auto_ensure_db_connection
from django.db import transaction
logger = logging.getLogger(__name__)
@auto_ensure_db_connection
class DjangoEndpointRepository:
"""端点 Repository - 负责端点表的数据访问"""
def bulk_create_ignore_conflicts(self, items: List[EndpointDTO]) -> int:
"""
批量创建端点(忽略冲突)
Args:
items: 端点 DTO 列表
Returns:
int: 创建的记录数
"""
if not items:
return 0
try:
endpoints = []
for item in items:
# Endpoint 模型当前只关联 target不再依赖 website 外键
# 这里按照 EndpointDTO 的字段映射构造 Endpoint 实例
endpoints.append(Endpoint(
target_id=item.target_id,
url=item.url,
host=item.host or '',
title=item.title or '',
status_code=item.status_code,
content_length=item.content_length,
webserver=item.webserver or '',
body_preview=item.body_preview or '',
content_type=item.content_type or '',
tech=item.tech if item.tech else [],
vhost=item.vhost,
location=item.location or '',
matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else []
))
with transaction.atomic():
created = Endpoint.objects.bulk_create(
endpoints,
ignore_conflicts=True,
batch_size=1000
)
return len(created)
except Exception as e:
logger.error(f"批量创建端点失败: {e}")
raise
def get_by_website(self, website_id: int) -> List[EndpointDTO]:
"""
获取网站下的所有端点
Args:
website_id: 网站 ID
Returns:
List[EndpointDTO]: 端点列表
"""
endpoints = Endpoint.objects.filter(
website_id=website_id
).order_by('-discovered_at')
result = []
for endpoint in endpoints:
result.append(EndpointDTO(
website_id=endpoint.website_id,
target_id=endpoint.target_id,
url=endpoint.url,
title=endpoint.title,
status_code=endpoint.status_code,
content_length=endpoint.content_length,
webserver=endpoint.webserver,
body_preview=endpoint.body_preview,
content_type=endpoint.content_type,
tech=endpoint.tech,
vhost=endpoint.vhost,
location=endpoint.location,
matched_gf_patterns=endpoint.matched_gf_patterns
))
return result
def get_queryset_by_target(self, target_id: int):
return Endpoint.objects.filter(target_id=target_id).order_by('-discovered_at')
def get_all(self):
"""获取所有端点(全局查询)"""
return Endpoint.objects.all().order_by('-discovered_at')
def get_by_target(self, target_id: int) -> List[EndpointDTO]:
"""
获取目标下的所有端点
Args:
target_id: 目标 ID
Returns:
List[EndpointDTO]: 端点列表
"""
endpoints = Endpoint.objects.filter(
target_id=target_id
).order_by('-discovered_at')
result = []
for endpoint in endpoints:
result.append(EndpointDTO(
website_id=endpoint.website_id,
target_id=endpoint.target_id,
url=endpoint.url,
title=endpoint.title,
status_code=endpoint.status_code,
content_length=endpoint.content_length,
webserver=endpoint.webserver,
body_preview=endpoint.body_preview,
content_type=endpoint.content_type,
tech=endpoint.tech,
vhost=endpoint.vhost,
location=endpoint.location,
matched_gf_patterns=endpoint.matched_gf_patterns
))
return result
def count_by_website(self, website_id: int) -> int:
"""
统计网站下的端点数量
Args:
website_id: 网站 ID
Returns:
int: 端点数量
"""
return Endpoint.objects.filter(website_id=website_id).count()
def count_by_target(self, target_id: int) -> int:
"""
统计目标下的端点数量
Args:
target_id: 目标 ID
Returns:
int: 端点数量
"""
return Endpoint.objects.filter(target_id=target_id).count()
def soft_delete_by_ids(self, ids: List[int]) -> int:
"""
软删除端点(批量)
Args:
ids: 端点 ID 列表
Returns:
int: 更新的记录数
"""
from django.utils import timezone
return Endpoint.objects.filter(
id__in=ids
).update(deleted_at=timezone.now())
def hard_delete_by_ids(self, ids: List[int]) -> Tuple[int, Dict[str, int]]:
"""
硬删除端点(批量)
Args:
ids: 端点 ID 列表
Returns:
Tuple[int, Dict[str, int]]: (删除总数, 详细信息)
"""
deleted_count, details = Endpoint.all_objects.filter(
id__in=ids
).delete()
return deleted_count, details
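
A short sketch of how the two delete paths above differ (IDs are hypothetical; "objects" is assumed to exclude soft-deleted rows while "all_objects" includes them, which is what the code's use of both managers suggests):

from apps.asset.repositories.asset import DjangoEndpointRepository

repo = DjangoEndpointRepository()
# Soft delete: rows stay in the table, only deleted_at is stamped
updated = repo.soft_delete_by_ids([101, 102])
# Hard delete: rows are removed permanently via the all_objects manager,
# so even already soft-deleted rows can be purged
total, details = repo.hard_delete_by_ids([101, 102])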

View File

@@ -0,0 +1,167 @@
"""HostPortMapping Repository - Django ORM 实现"""
import logging
from typing import List, Iterator
from apps.asset.models.asset_models import HostPortMapping
from apps.asset.dtos.asset import HostPortMappingDTO
from apps.common.decorators import auto_ensure_db_connection
logger = logging.getLogger(__name__)
@auto_ensure_db_connection
class DjangoHostPortMappingRepository:
"""HostPortMapping Repository - Django ORM 实现"""
def bulk_create_ignore_conflicts(self, items: List[HostPortMappingDTO]) -> int:
"""
批量创建主机端口关联(忽略冲突)
Args:
items: 主机端口关联 DTO 列表
Returns:
int: 实际创建的记录数注意ignore_conflicts 时可能为 0
Note:
- 基于唯一约束 (target + host + ip + port) 自动去重
- 忽略已存在的记录,不更新
"""
try:
logger.debug("准备批量创建主机端口关联 - 数量: %d", len(items))
if not items:
logger.debug("主机端口关联为空,跳过创建")
return 0
# 构建记录对象
records = []
for item in items:
records.append(HostPortMapping(
target_id=item.target_id,
host=item.host,
ip=item.ip,
port=item.port
))
# 批量创建(忽略冲突,基于唯一约束去重)
created = HostPortMapping.objects.bulk_create(
records,
ignore_conflicts=True
)
created_count = len(created) if created else 0
logger.debug("主机端口关联创建完成 - 数量: %d", created_count)
return created_count
except Exception as e:
logger.error(
"批量创建主机端口关联失败 - 数量: %d, 错误: %s",
len(items),
str(e),
exc_info=True
)
raise
def get_for_export(self, target_id: int, batch_size: int = 1000):
queryset = (
HostPortMapping.objects
.filter(target_id=target_id)
.order_by("host", "port")
.values("host", "port")
.iterator(chunk_size=batch_size)
)
for item in queryset:
yield item
def get_ips_for_export(self, target_id: int, batch_size: int = 1000) -> Iterator[str]:
"""流式导出目标下的所有唯一 IP 地址。"""
queryset = (
HostPortMapping.objects
.filter(target_id=target_id)
.values_list("ip", flat=True)
.distinct()
.order_by("ip")
.iterator(chunk_size=batch_size)
)
for ip in queryset:
yield ip
def get_ip_aggregation_by_target(self, target_id: int, search: str = None):
from django.db.models import Min
qs = HostPortMapping.objects.filter(target_id=target_id)
if search:
qs = qs.filter(ip__icontains=search)
ip_aggregated = (
qs
.values('ip')
.annotate(
discovered_at=Min('discovered_at')
)
.order_by('-discovered_at')
)
results = []
for item in ip_aggregated:
ip = item['ip']
mappings = (
HostPortMapping.objects
.filter(target_id=target_id, ip=ip)
.values('host', 'port')
.distinct()
)
hosts = sorted({m['host'] for m in mappings})
ports = sorted({m['port'] for m in mappings})
results.append({
'ip': ip,
'hosts': hosts,
'ports': ports,
'discovered_at': item['discovered_at'],
})
return results
def get_all_ip_aggregation(self, search: str = None):
"""获取所有 IP 聚合数据(全局查询)"""
from django.db.models import Min
qs = HostPortMapping.objects.all()
if search:
qs = qs.filter(ip__icontains=search)
ip_aggregated = (
qs
.values('ip')
.annotate(
discovered_at=Min('discovered_at')
)
.order_by('-discovered_at')
)
results = []
for item in ip_aggregated:
ip = item['ip']
mappings = (
HostPortMapping.objects
.filter(ip=ip)
.values('host', 'port')
.distinct()
)
hosts = sorted({m['host'] for m in mappings})
ports = sorted({m['port'] for m in mappings})
results.append({
'ip': ip,
'hosts': hosts,
'ports': ports,
'discovered_at': item['discovered_at'],
})
return results
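
For reference, get_ip_aggregation_by_target returns a list of plain dicts, one per IP; a sketch of the shape (all values below are made up):

[
    {
        'ip': '93.184.216.34',
        'hosts': ['example.com', 'www.example.com'],    # every host seen on this IP
        'ports': [80, 443],                             # every distinct port
        'discovered_at': datetime(2025, 12, 12, 10, 0), # earliest sighting
    },
]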

View File

@@ -0,0 +1,256 @@
import logging
from typing import List, Iterator
from django.db import transaction, IntegrityError, OperationalError, DatabaseError
from django.utils import timezone
from typing import Tuple, Dict
from apps.asset.models.asset_models import Subdomain
from apps.asset.dtos import SubdomainDTO
from apps.common.decorators import auto_ensure_db_connection
logger = logging.getLogger(__name__)
@auto_ensure_db_connection
class DjangoSubdomainRepository:
"""基于 Django ORM 的子域名仓储实现。"""
def bulk_create_ignore_conflicts(self, items: List[SubdomainDTO]) -> None:
"""
批量创建子域名,忽略冲突
Args:
items: 子域名 DTO 列表
Raises:
IntegrityError: 数据完整性错误(如唯一约束冲突)
OperationalError: 数据库操作错误(如连接失败)
DatabaseError: 其他数据库错误
"""
if not items:
return
try:
subdomain_objects = [
Subdomain(
name=item.name,
target_id=item.target_id,
)
for item in items
]
with transaction.atomic():
# 使用 ignore_conflicts 策略:
# - 新子域名INSERT 完整记录
# - 已存在子域名:忽略(不更新,因为没有探测字段数据)
# 注意ignore_conflicts 无法返回实际创建的数量
Subdomain.objects.bulk_create( # type: ignore[attr-defined]
subdomain_objects,
ignore_conflicts=True, # 忽略重复记录
)
logger.debug(f"成功处理 {len(items)} 条子域名记录")
except IntegrityError as e:
logger.error(
f"批量插入子域名失败 - 数据完整性错误: {e}, "
f"记录数: {len(items)}, "
f"示例域名: {items[0].name if items else 'N/A'}"
)
raise
except OperationalError as e:
logger.error(
f"批量插入子域名失败 - 数据库操作错误: {e}, "
f"记录数: {len(items)}"
)
raise
except DatabaseError as e:
logger.error(
f"批量插入子域名失败 - 数据库错误: {e}, "
f"记录数: {len(items)}"
)
raise
except Exception as e:
logger.error(
f"批量插入子域名失败 - 未知错误: {e}, "
f"记录数: {len(items)}, "
f"错误类型: {type(e).__name__}",
exc_info=True
)
raise
def get_or_create(self, name: str, target_id: int) -> Tuple[Subdomain, bool]:
"""
获取或创建子域名
Args:
name: 子域名名称
target_id: 目标 ID
Returns:
(Subdomain对象, 是否新创建)
"""
return Subdomain.objects.get_or_create(
name=name,
target_id=target_id,
)
def get_domains_for_export(self, target_id: int, batch_size: int = 1000) -> Iterator[str]:
"""
流式导出域名(用于生成扫描工具输入文件)
使用 iterator() 进行流式查询,避免一次性加载所有数据到内存
Args:
target_id: 目标 ID
batch_size: 每次从数据库读取的行数
Yields:
str: 域名
"""
queryset = Subdomain.objects.filter(
target_id=target_id
).only('name').iterator(chunk_size=batch_size)
for subdomain in queryset:
yield subdomain.name
def get_by_target(self, target_id: int):
return Subdomain.objects.filter(target_id=target_id).order_by('-discovered_at')
def count_by_target(self, target_id: int) -> int:
"""
统计目标下的域名数量
Args:
target_id: 目标 ID
Returns:
int: 域名数量
"""
return Subdomain.objects.filter(target_id=target_id).count()
def get_by_names_and_target_id(self, names: set, target_id: int) -> dict:
"""
根据域名列表和目标ID批量查询 Subdomain
Args:
names: 域名集合
target_id: 目标 ID
Returns:
dict: {domain_name: Subdomain对象}
"""
subdomains = Subdomain.objects.filter(
name__in=names,
target_id=target_id
).only('id', 'name')
return {sd.name: sd for sd in subdomains}
def get_all(self):
"""
获取所有子域名
Returns:
QuerySet: 子域名查询集
"""
return Subdomain.objects.all()
def soft_delete_by_ids(self, subdomain_ids: List[int]) -> int:
"""
根据 ID 列表批量软删除子域名
Args:
subdomain_ids: 子域名 ID 列表
Returns:
软删除的记录数
Note:
- 使用软删除:只标记为已删除,不真正删除数据库记录
- 保留所有关联数据,可恢复
"""
try:
updated_count = (
Subdomain.objects
.filter(id__in=subdomain_ids)
.update(deleted_at=timezone.now())
)
logger.debug(
"批量软删除子域名成功 - Count: %s, 更新记录: %s",
len(subdomain_ids),
updated_count
)
return updated_count
except Exception as e:
logger.error(
"批量软删除子域名失败 - IDs: %s, 错误: %s",
subdomain_ids,
e
)
raise
def hard_delete_by_ids(self, subdomain_ids: List[int]) -> Tuple[int, Dict[str, int]]:
"""
根据 ID 列表硬删除子域名(使用数据库级 CASCADE
Args:
subdomain_ids: 子域名 ID 列表
Returns:
(删除的记录数, 删除详情字典)
Strategy:
使用数据库级 CASCADE 删除,性能最优
Note:
- 硬删除:从数据库中永久删除
- 数据库自动处理所有外键级联删除
- 不触发 Django 信号pre_delete/post_delete
"""
try:
batch_size = 1000 # 每批处理1000个子域名
total_deleted = 0
logger.debug(f"开始批量删除 {len(subdomain_ids)} 个子域名(数据库 CASCADE...")
# 分批处理子域名ID避免单次删除过多
for i in range(0, len(subdomain_ids), batch_size):
batch_ids = subdomain_ids[i:i + batch_size]
# 直接删除子域名,数据库自动级联删除所有关联数据
count, _ = Subdomain.all_objects.filter(id__in=batch_ids).delete()
total_deleted += count
logger.debug(f"批次删除完成: {len(batch_ids)} 个子域名,删除 {count} 条记录")
# 由于使用数据库 CASCADE无法获取详细统计
deleted_details = {
'subdomains': len(subdomain_ids),
'total': total_deleted,
'note': 'Database CASCADE - detailed stats unavailable'
}
logger.debug(
"批量硬删除成功CASCADE- 子域名数: %s, 总删除记录: %s",
len(subdomain_ids),
total_deleted
)
return total_deleted, deleted_details
except Exception as e:
logger.error(
"批量硬删除失败CASCADE- 子域名数: %s, 错误: %s",
len(subdomain_ids),
str(e),
exc_info=True
)
raise
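
Sketch of the streaming export path, e.g. for building a scanner input file (output path and target ID are hypothetical):

from apps.asset.repositories import DjangoSubdomainRepository

repo = DjangoSubdomainRepository()
with open('/tmp/subdomains.txt', 'w') as fh:
    # iterator(chunk_size=...) keeps memory flat even for very large targets
    for name in repo.get_domains_for_export(target_id=1, batch_size=1000):
        fh.write(name + '\n')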

View File

@@ -0,0 +1,260 @@
"""
Django ORM 实现的 WebSite Repository
"""
import logging
from typing import List, Generator, Tuple, Dict, Optional
from django.db import transaction, IntegrityError, OperationalError, DatabaseError
from django.utils import timezone
from apps.asset.models.asset_models import WebSite
from apps.asset.dtos import WebSiteDTO
from apps.common.decorators import auto_ensure_db_connection
logger = logging.getLogger(__name__)
@auto_ensure_db_connection
class DjangoWebSiteRepository:
"""Django ORM 实现的 WebSite Repository"""
def bulk_create_ignore_conflicts(self, items: List[WebSiteDTO]) -> None:
"""
批量创建 WebSite忽略冲突
Args:
items: WebSite DTO 列表
Raises:
IntegrityError: 数据完整性错误
OperationalError: 数据库操作错误
DatabaseError: 数据库错误
"""
if not items:
return
try:
# 转换为 Django 模型对象
website_objects = [
WebSite(
target_id=item.target_id,
url=item.url,
host=item.host,
location=item.location,
title=item.title,
webserver=item.webserver,
body_preview=item.body_preview,
content_type=item.content_type,
tech=item.tech,
status_code=item.status_code,
content_length=item.content_length,
vhost=item.vhost
)
for item in items
]
with transaction.atomic():
# Bulk insert, ignoring conflicts:
# rows whose (url, target) already exist are skipped, not updated
WebSite.objects.bulk_create(
website_objects,
ignore_conflicts=True
)
logger.debug(f"成功处理 {len(items)} 条 WebSite 记录")
except IntegrityError as e:
logger.error(
f"批量插入 WebSite 失败 - 数据完整性错误: {e}, "
f"记录数: {len(items)}"
)
raise
except OperationalError as e:
logger.error(
f"批量插入 WebSite 失败 - 数据库操作错误: {e}, "
f"记录数: {len(items)}"
)
raise
except DatabaseError as e:
logger.error(
f"批量插入 WebSite 失败 - 数据库错误: {e}, "
f"记录数: {len(items)}"
)
raise
except Exception as e:
logger.error(
f"批量插入 WebSite 失败 - 未知错误: {e}, "
f"记录数: {len(items)}, "
f"错误类型: {type(e).__name__}",
exc_info=True
)
raise
def get_urls_for_export(self, target_id: int, batch_size: int = 1000) -> Generator[str, None, None]:
"""
流式导出目标下的所有站点 URL
Args:
target_id: 目标 ID
batch_size: 批次大小
Yields:
str: 站点 URL
"""
try:
# 查询目标下的站点,只选择 URL 字段,避免不必要的数据传输
queryset = WebSite.objects.filter(
target_id=target_id
).values_list('url', flat=True).iterator(chunk_size=batch_size)
for url in queryset:
yield url
except Exception as e:
logger.error(f"流式导出站点 URL 失败 - Target ID: {target_id}, 错误: {e}")
raise
def get_by_target(self, target_id: int):
return WebSite.objects.filter(target_id=target_id).order_by('-discovered_at')
def count_by_target(self, target_id: int) -> int:
"""
统计目标下的站点总数
Args:
target_id: 目标 ID
Returns:
int: 站点总数
"""
try:
count = WebSite.objects.filter(target_id=target_id).count()
logger.debug(f"Target {target_id} 的站点总数: {count}")
return count
except Exception as e:
logger.error(f"统计站点数量失败 - Target ID: {target_id}, 错误: {e}")
raise
def count_by_scan(self, scan_id: int) -> int:
"""
统计扫描下的站点总数
"""
try:
count = WebSite.objects.filter(scan_id=scan_id).count()
logger.debug(f"Scan {scan_id} 的站点总数: {count}")
return count
except Exception as e:
logger.error(f"统计站点数量失败 - Scan ID: {scan_id}, 错误: {e}")
raise
def get_by_url(self, url: str, target_id: int) -> Optional[int]:
"""
根据 URL 和 target_id 查找站点 ID
Args:
url: 站点 URL
target_id: 目标 ID
Returns:
Optional[int]: 站点 ID如果不存在返回 None
Raises:
ValueError: 发现多个站点时
"""
try:
website = WebSite.objects.filter(url=url, target_id=target_id).first()
if website:
return website.id
return None
except Exception as e:
logger.error(f"查询站点失败 - URL: {url}, Target ID: {target_id}, 错误: {e}")
raise
def get_all(self):
"""
获取所有网站
Returns:
QuerySet: 网站查询集
"""
return WebSite.objects.all()
def soft_delete_by_ids(self, website_ids: List[int]) -> int:
"""
根据 ID 列表批量软删除WebSite
Args:
website_ids: WebSite ID 列表
Returns:
软删除的记录数
"""
try:
updated_count = (
WebSite.objects
.filter(id__in=website_ids)
.update(deleted_at=timezone.now())
)
logger.debug(
"批量软删除WebSite成功 - Count: %s, 更新记录: %s",
len(website_ids),
updated_count
)
return updated_count
except Exception as e:
logger.error(
"批量软删除WebSite失败 - IDs: %s, 错误: %s",
website_ids,
e
)
raise
def hard_delete_by_ids(self, website_ids: List[int]) -> Tuple[int, Dict[str, int]]:
"""
根据 ID 列表硬删除WebSite使用数据库级 CASCADE
Args:
website_ids: WebSite ID 列表
Returns:
(删除的记录数, 删除详情字典)
"""
try:
batch_size = 1000
total_deleted = 0
logger.debug(f"开始批量删除 {len(website_ids)} 个WebSite数据库 CASCADE...")
for i in range(0, len(website_ids), batch_size):
batch_ids = website_ids[i:i + batch_size]
count, _ = WebSite.all_objects.filter(id__in=batch_ids).delete()
total_deleted += count
logger.debug(f"批次删除完成: {len(batch_ids)} 个WebSite删除 {count} 条记录")
deleted_details = {
'websites': len(website_ids),
'total': total_deleted,
'note': 'Database CASCADE - detailed stats unavailable'
}
logger.debug(
"批量硬删除成功CASCADE- WebSite数: %s, 总删除记录: %s",
len(website_ids),
total_deleted
)
return total_deleted, deleted_details
except Exception as e:
logger.error(
"批量硬删除失败CASCADE- WebSite数: %s, 错误: %s",
len(website_ids),
str(e),
exc_info=True
)
raise
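
Sketch of the lookup helper above; note it returns a plain ID rather than a model instance (values hypothetical):

from apps.asset.repositories import DjangoWebSiteRepository

repo = DjangoWebSiteRepository()
website_id = repo.get_by_url(url="https://example.com", target_id=1)
if website_id is None:
    print("website not recorded yet")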

View File

@@ -0,0 +1,17 @@
"""Snapshot Repositories - 数据访问层"""
from .subdomain_snapshot_repository import DjangoSubdomainSnapshotRepository
from .host_port_mapping_snapshot_repository import DjangoHostPortMappingSnapshotRepository
from .website_snapshot_repository import DjangoWebsiteSnapshotRepository
from .directory_snapshot_repository import DjangoDirectorySnapshotRepository
from .endpoint_snapshot_repository import DjangoEndpointSnapshotRepository
from .vulnerability_snapshot_repository import DjangoVulnerabilitySnapshotRepository
__all__ = [
'DjangoSubdomainSnapshotRepository',
'DjangoHostPortMappingSnapshotRepository',
'DjangoWebsiteSnapshotRepository',
'DjangoDirectorySnapshotRepository',
'DjangoEndpointSnapshotRepository',
'DjangoVulnerabilitySnapshotRepository',
]

View File

@@ -0,0 +1,78 @@
"""Directory Snapshot Repository - 目录快照数据访问层"""
import logging
from typing import List
from django.db import transaction
from apps.asset.models import DirectorySnapshot
from apps.asset.dtos.snapshot import DirectorySnapshotDTO
from apps.common.decorators import auto_ensure_db_connection
logger = logging.getLogger(__name__)
@auto_ensure_db_connection
class DjangoDirectorySnapshotRepository:
"""
目录快照仓储Django ORM 实现)
负责目录快照表的数据访问操作
"""
def save_snapshots(self, items: List[DirectorySnapshotDTO]) -> None:
"""
批量保存目录快照记录
使用 ignore_conflicts 策略,如果快照已存在(相同 scan + url则跳过
Args:
items: 目录快照 DTO 列表
Raises:
ValueError: items 为空
Exception: 数据库操作失败
"""
if not items:
logger.warning("目录快照列表为空,跳过保存")
return
try:
# 转换为 Django 模型对象
snapshot_objects = [
DirectorySnapshot(
scan_id=item.scan_id,
url=item.url,
status=item.status,
content_length=item.content_length,
words=item.words,
lines=item.lines,
content_type=item.content_type,
duration=item.duration
)
for item in items
]
with transaction.atomic():
# 批量插入,忽略冲突
# 如果 scan + url 已存在,跳过
DirectorySnapshot.objects.bulk_create(
snapshot_objects,
ignore_conflicts=True
)
logger.debug("成功保存 %d 条目录快照记录", len(items))
except Exception as e:
logger.error(
"批量保存目录快照失败 - 数量: %d, 错误: %s",
len(items),
str(e),
exc_info=True
)
raise
def get_by_scan(self, scan_id: int):
return DirectorySnapshot.objects.filter(scan_id=scan_id).order_by('-discovered_at')
def get_all(self):
return DirectorySnapshot.objects.all().order_by('-discovered_at')

View File

@@ -0,0 +1,74 @@
"""EndpointSnapshot Repository - Django ORM 实现"""
import logging
from typing import List
from apps.asset.models.snapshot_models import EndpointSnapshot
from apps.asset.dtos.snapshot import EndpointSnapshotDTO
from apps.common.decorators import auto_ensure_db_connection
logger = logging.getLogger(__name__)
@auto_ensure_db_connection
class DjangoEndpointSnapshotRepository:
"""端点快照 Repository - 负责端点快照表的数据访问"""
def save_snapshots(self, items: List[EndpointSnapshotDTO]) -> None:
"""
保存端点快照
Args:
items: 端点快照 DTO 列表
Note:
- 保存完整的快照数据
- 基于唯一约束 (scan + url) 自动去重
"""
try:
logger.debug("准备保存端点快照 - 数量: %d", len(items))
if not items:
logger.debug("端点快照为空,跳过保存")
return
# 构建快照对象
snapshots = []
for item in items:
snapshots.append(EndpointSnapshot(
scan_id=item.scan_id,
url=item.url,
title=item.title,
status_code=item.status_code,
content_length=item.content_length,
location=item.location,
webserver=item.webserver,
content_type=item.content_type,
tech=item.tech if item.tech else [],
body_preview=item.body_preview,
vhost=item.vhost,
matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else []
))
# 批量创建(忽略冲突,基于唯一约束去重)
EndpointSnapshot.objects.bulk_create(
snapshots,
ignore_conflicts=True
)
logger.debug("端点快照保存成功 - 数量: %d", len(snapshots))
except Exception as e:
logger.error(
"保存端点快照失败 - 数量: %d, 错误: %s",
len(items),
str(e),
exc_info=True
)
raise
def get_by_scan(self, scan_id: int):
return EndpointSnapshot.objects.filter(scan_id=scan_id).order_by('-discovered_at')
def get_all(self):
return EndpointSnapshot.objects.all().order_by('-discovered_at')

View File

@@ -0,0 +1,145 @@
"""HostPortMappingSnapshot Repository - Django ORM 实现"""
import logging
from typing import List, Iterator
from apps.asset.models.snapshot_models import HostPortMappingSnapshot
from apps.asset.dtos.snapshot import HostPortMappingSnapshotDTO
from apps.common.decorators import auto_ensure_db_connection
logger = logging.getLogger(__name__)
@auto_ensure_db_connection
class DjangoHostPortMappingSnapshotRepository:
"""HostPortMappingSnapshot Repository - Django ORM 实现,负责主机端口映射快照表的数据访问"""
def save_snapshots(self, items: List[HostPortMappingSnapshotDTO]) -> None:
"""
保存主机端口关联快照
Args:
items: 主机端口关联快照 DTO 列表
Note:
- 保存完整的快照数据
- 基于唯一约束 (scan + host + ip + port) 自动去重
"""
try:
logger.debug("准备保存主机端口关联快照 - 数量: %d", len(items))
if not items:
logger.debug("主机端口关联快照为空,跳过保存")
return
# 构建快照对象
snapshots = []
for item in items:
snapshots.append(HostPortMappingSnapshot(
scan_id=item.scan_id,
host=item.host,
ip=item.ip,
port=item.port
))
# 批量创建(忽略冲突,基于唯一约束去重)
HostPortMappingSnapshot.objects.bulk_create(
snapshots,
ignore_conflicts=True
)
logger.debug("主机端口关联快照保存成功 - 数量: %d", len(snapshots))
except Exception as e:
logger.error(
"保存主机端口关联快照失败 - 数量: %d, 错误: %s",
len(items),
str(e),
exc_info=True
)
raise
def get_ip_aggregation_by_scan(self, scan_id: int, search: str = None):
from django.db.models import Min
qs = HostPortMappingSnapshot.objects.filter(scan_id=scan_id)
if search:
qs = qs.filter(ip__icontains=search)
ip_aggregated = (
qs
.values('ip')
.annotate(
discovered_at=Min('discovered_at')
)
.order_by('-discovered_at')
)
results = []
for item in ip_aggregated:
ip = item['ip']
mappings = (
HostPortMappingSnapshot.objects
.filter(scan_id=scan_id, ip=ip)
.values('host', 'port')
.distinct()
)
hosts = sorted({m['host'] for m in mappings})
ports = sorted({m['port'] for m in mappings})
results.append({
'ip': ip,
'hosts': hosts,
'ports': ports,
'discovered_at': item['discovered_at'],
})
return results
def get_all_ip_aggregation(self, search: str = None):
"""获取所有 IP 聚合数据"""
from django.db.models import Min
qs = HostPortMappingSnapshot.objects.all()
if search:
qs = qs.filter(ip__icontains=search)
ip_aggregated = (
qs
.values('ip')
.annotate(discovered_at=Min('discovered_at'))
.order_by('-discovered_at')
)
results = []
for item in ip_aggregated:
ip = item['ip']
mappings = (
HostPortMappingSnapshot.objects
.filter(ip=ip)
.values('host', 'port')
.distinct()
)
hosts = sorted({m['host'] for m in mappings})
ports = sorted({m['port'] for m in mappings})
results.append({
'ip': ip,
'hosts': hosts,
'ports': ports,
'discovered_at': item['discovered_at'],
})
return results
def get_ips_for_export(self, scan_id: int, batch_size: int = 1000) -> Iterator[str]:
"""流式导出扫描下的所有唯一 IP 地址。"""
queryset = (
HostPortMappingSnapshot.objects
.filter(scan_id=scan_id)
.values_list("ip", flat=True)
.distinct()
.order_by("ip")
.iterator(chunk_size=batch_size)
)
for ip in queryset:
yield ip

View File

@@ -0,0 +1,61 @@
"""Django ORM 实现的 SubdomainSnapshot Repository"""
import logging
from typing import List
from apps.asset.models.snapshot_models import SubdomainSnapshot
from apps.asset.dtos import SubdomainSnapshotDTO
from apps.common.decorators import auto_ensure_db_connection
logger = logging.getLogger(__name__)
@auto_ensure_db_connection
class DjangoSubdomainSnapshotRepository:
"""子域名快照 Repository - 负责子域名快照表的数据访问"""
def save_subdomain_snapshots(self, items: List[SubdomainSnapshotDTO]) -> None:
"""
保存子域名快照
Args:
items: 子域名快照 DTO 列表
Note:
- 保存完整的快照数据
- 基于唯一约束自动去重(忽略冲突)
"""
try:
logger.debug("准备保存子域名快照 - 数量: %d", len(items))
if not items:
logger.debug("子域名快照为空,跳过保存")
return
# 构建快照对象
snapshots = []
for item in items:
snapshots.append(SubdomainSnapshot(
scan_id=item.scan_id,
name=item.name,
))
# 批量创建(忽略冲突,基于唯一约束去重)
SubdomainSnapshot.objects.bulk_create(snapshots, ignore_conflicts=True)
logger.debug("子域名快照保存成功 - 数量: %d", len(snapshots))
except Exception as e:
logger.error(
"保存子域名快照失败 - 数量: %d, 错误: %s",
len(items),
str(e),
exc_info=True
)
raise
def get_by_scan(self, scan_id: int):
return SubdomainSnapshot.objects.filter(scan_id=scan_id).order_by('-discovered_at')
def get_all(self):
return SubdomainSnapshot.objects.all().order_by('-discovered_at')

View File

@@ -0,0 +1,66 @@
"""Vulnerability Snapshot Repository - 漏洞快照数据访问层"""
import logging
from typing import List
from django.db import transaction
from apps.asset.models import VulnerabilitySnapshot
from apps.asset.dtos.snapshot import VulnerabilitySnapshotDTO
from apps.common.decorators import auto_ensure_db_connection
logger = logging.getLogger(__name__)
@auto_ensure_db_connection
class DjangoVulnerabilitySnapshotRepository:
"""漏洞快照仓储Django ORM 实现)"""
def save_snapshots(self, items: List[VulnerabilitySnapshotDTO]) -> None:
"""批量保存漏洞快照记录。
使用 ``ignore_conflicts`` 策略,如果快照已存在则跳过。
具体唯一约束由数据库模型控制。
"""
if not items:
logger.warning("漏洞快照列表为空,跳过保存")
return
try:
snapshot_objects = [
VulnerabilitySnapshot(
scan_id=item.scan_id,
url=item.url,
vuln_type=item.vuln_type,
severity=item.severity,
source=item.source,
cvss_score=item.cvss_score,
description=item.description,
raw_output=item.raw_output,
)
for item in items
]
with transaction.atomic():
VulnerabilitySnapshot.objects.bulk_create(
snapshot_objects,
ignore_conflicts=True,
)
logger.debug("成功保存 %d 条漏洞快照记录", len(items))
except Exception as e:
logger.error(
"批量保存漏洞快照失败 - 数量: %d, 错误: %s",
len(items),
str(e),
exc_info=True,
)
raise
def get_by_scan(self, scan_id: int):
"""按扫描任务获取漏洞快照 QuerySet。"""
return VulnerabilitySnapshot.objects.filter(scan_id=scan_id).order_by("-discovered_at")
def get_all(self):
return VulnerabilitySnapshot.objects.all().order_by('-discovered_at')

View File

@@ -0,0 +1,74 @@
"""WebsiteSnapshot Repository - Django ORM 实现"""
import logging
from typing import List
from apps.asset.models.snapshot_models import WebsiteSnapshot
from apps.asset.dtos.snapshot import WebsiteSnapshotDTO
from apps.common.decorators import auto_ensure_db_connection
logger = logging.getLogger(__name__)
@auto_ensure_db_connection
class DjangoWebsiteSnapshotRepository:
"""网站快照 Repository - 负责网站快照表的数据访问"""
def save_snapshots(self, items: List[WebsiteSnapshotDTO]) -> None:
"""
保存网站快照
Args:
items: 网站快照 DTO 列表
Note:
- 保存完整的快照数据
- 基于唯一约束 (scan + subdomain + url) 自动去重
"""
try:
logger.debug("准备保存网站快照 - 数量: %d", len(items))
if not items:
logger.debug("网站快照为空,跳过保存")
return
# 构建快照对象
snapshots = []
for item in items:
snapshots.append(WebsiteSnapshot(
scan_id=item.scan_id,
url=item.url,
host=item.host,
title=item.title,
status=item.status,
content_length=item.content_length,
location=item.location,
web_server=item.web_server,
content_type=item.content_type,
tech=item.tech if item.tech else [],
body_preview=item.body_preview,
vhost=item.vhost
))
# 批量创建(忽略冲突,基于唯一约束去重)
WebsiteSnapshot.objects.bulk_create(
snapshots,
ignore_conflicts=True
)
logger.debug("网站快照保存成功 - 数量: %d", len(snapshots))
except Exception as e:
logger.error(
"保存网站快照失败 - 数量: %d, 错误: %s",
len(items),
str(e),
exc_info=True
)
raise
def get_by_scan(self, scan_id: int):
return WebsiteSnapshot.objects.filter(scan_id=scan_id).order_by('-discovered_at')
def get_all(self):
return WebsiteSnapshot.objects.all().order_by('-discovered_at')

View File

@@ -0,0 +1,128 @@
"""资产统计 Repository"""
import logging
from datetime import date, timedelta
from typing import Optional, List
from apps.asset.models import AssetStatistics, StatisticsHistory
logger = logging.getLogger(__name__)
class AssetStatisticsRepository:
"""
资产统计数据访问层
职责:
- 读取/更新预聚合的统计数据
"""
def get_statistics(self) -> Optional[AssetStatistics]:
"""
获取统计数据
Returns:
统计数据对象,不存在则返回 None
"""
return AssetStatistics.objects.first()
def get_or_create_statistics(self) -> AssetStatistics:
"""
获取或创建统计数据(单例)
Returns:
统计数据对象
"""
return AssetStatistics.get_or_create_singleton()
def update_statistics(
self,
total_targets: int,
total_subdomains: int,
total_ips: int,
total_endpoints: int,
total_websites: int,
total_vulns: int,
) -> AssetStatistics:
"""
更新统计数据
Args:
total_targets: 目标总数
total_subdomains: 子域名总数
total_ips: IP 总数
total_endpoints: 端点总数
total_websites: 网站总数
total_vulns: 漏洞总数
Returns:
更新后的统计数据对象
"""
stats = self.get_or_create_statistics()
# 1. 保存当前值到 prev_* 字段
stats.prev_targets = stats.total_targets
stats.prev_subdomains = stats.total_subdomains
stats.prev_ips = stats.total_ips
stats.prev_endpoints = stats.total_endpoints
stats.prev_websites = stats.total_websites
stats.prev_vulns = stats.total_vulns
stats.prev_assets = stats.total_assets
# 2. 更新当前值
stats.total_targets = total_targets
stats.total_subdomains = total_subdomains
stats.total_ips = total_ips
stats.total_endpoints = total_endpoints
stats.total_websites = total_websites
stats.total_vulns = total_vulns
stats.total_assets = total_subdomains + total_ips + total_endpoints + total_websites
stats.save()
logger.info(
"更新资产统计: targets=%d, subdomains=%d, ips=%d, endpoints=%d, websites=%d, vulns=%d, assets=%d",
total_targets, total_subdomains, total_ips, total_endpoints, total_websites, total_vulns, stats.total_assets
)
return stats
def save_daily_snapshot(self, stats: AssetStatistics) -> StatisticsHistory:
"""
保存每日统计快照(幂等,每天只存一条)
Args:
stats: 当前统计数据
Returns:
历史记录对象
"""
history, created = StatisticsHistory.objects.update_or_create(
date=date.today(),
defaults={
'total_targets': stats.total_targets,
'total_subdomains': stats.total_subdomains,
'total_ips': stats.total_ips,
'total_endpoints': stats.total_endpoints,
'total_websites': stats.total_websites,
'total_vulns': stats.total_vulns,
'total_assets': stats.total_assets,
}
)
action = "创建" if created else "更新"
logger.info(f"{action}统计快照: date={history.date}, assets={history.total_assets}")
return history
def get_history(self, days: int = 7) -> List[StatisticsHistory]:
"""
获取历史统计数据(用于折线图)
Args:
days: 获取最近多少天的数据,默认 7 天
Returns:
历史记录列表,按日期升序
"""
start_date = date.today() - timedelta(days=days - 1)
return list(
StatisticsHistory.objects
.filter(date__gte=start_date)
.order_by('date')
)
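
Because update_statistics copies the current totals into the prev_* fields before overwriting them, a dashboard can derive day-over-day deltas without extra queries. A minimal sketch (the counts passed in are made up):

repo = AssetStatisticsRepository()
stats = repo.update_statistics(
    total_targets=10, total_subdomains=5200, total_ips=800,
    total_endpoints=14000, total_websites=900, total_vulns=37,
)
subdomain_delta = stats.total_subdomains - stats.prev_subdomains
repo.save_daily_snapshot(stats)      # idempotent: one row per calendar day
history = repo.get_history(days=7)   # ascending by date, ready for a line chart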

View File

@@ -0,0 +1,291 @@
from rest_framework import serializers
from .models import Subdomain, WebSite, Directory, HostPortMapping, Endpoint, Vulnerability
from .models.snapshot_models import (
SubdomainSnapshot,
WebsiteSnapshot,
DirectorySnapshot,
EndpointSnapshot,
VulnerabilitySnapshot,
)
# Note: the IPAddress and Port models were refactored into HostPortMapping
# The serializers below are implemented against the new architecture
# class PortSerializer(serializers.ModelSerializer):
# """端口序列化器"""
#
# class Meta:
# model = Port
# fields = ['number', 'service_name', 'description', 'is_uncommon']
class SubdomainSerializer(serializers.ModelSerializer):
"""子域名序列化器"""
class Meta:
model = Subdomain
fields = [
'id', 'name', 'discovered_at', 'target'
]
read_only_fields = ['id', 'discovered_at']
class SubdomainListSerializer(serializers.ModelSerializer):
"""子域名列表序列化器(用于扫描详情)"""
# Note: the Subdomain model has been trimmed down to its core fields
# cname, is_cdn, cdn_name etc. moved to SubdomainSnapshot
# the ports and ip_addresses relations were refactored into HostPortMapping
class Meta:
model = Subdomain
fields = [
'id', 'name', 'discovered_at'
]
read_only_fields = ['id', 'discovered_at']
# class IPAddressListSerializer(serializers.ModelSerializer):
# """IP 地址列表序列化器"""
#
# subdomain = serializers.CharField(source='subdomain.name', allow_blank=True, default='')
# created_at = serializers.DateTimeField(read_only=True)
# ports = PortSerializer(many=True, read_only=True)
#
# class Meta:
# model = IPAddress
# fields = [
# 'id',
# 'ip',
# 'subdomain',
# 'reverse_pointer',
# 'created_at',
# 'ports',
# ]
# read_only_fields = fields
class WebSiteSerializer(serializers.ModelSerializer):
"""站点序列化器"""
subdomain = serializers.CharField(source='subdomain.name', allow_blank=True, default='')
class Meta:
model = WebSite
fields = [
'id',
'url',
'location',
'title',
'webserver',
'content_type',
'status_code',
'content_length',
'body_preview',
'tech',
'vhost',
'subdomain',
'discovered_at',
]
read_only_fields = fields
class VulnerabilitySerializer(serializers.ModelSerializer):
"""漏洞资产序列化器(按目标查看漏洞资产)。"""
class Meta:
model = Vulnerability
fields = [
'id',
'target',
'url',
'vuln_type',
'severity',
'source',
'cvss_score',
'description',
'raw_output',
'discovered_at',
]
read_only_fields = fields
class VulnerabilitySnapshotSerializer(serializers.ModelSerializer):
"""漏洞快照序列化器(用于扫描历史漏洞列表)。"""
class Meta:
model = VulnerabilitySnapshot
fields = [
'id',
'url',
'vuln_type',
'severity',
'source',
'cvss_score',
'description',
'raw_output',
'discovered_at',
]
read_only_fields = fields
class EndpointListSerializer(serializers.ModelSerializer):
"""端点列表序列化器(用于目标端点列表页)"""
# 将 GF 匹配模式映射为前端使用的 tags 字段
tags = serializers.ListField(
child=serializers.CharField(),
source='matched_gf_patterns',
read_only=True,
)
class Meta:
model = Endpoint
fields = [
'id',
'url',
'location',
'status_code',
'title',
'content_length',
'content_type',
'webserver',
'body_preview',
'tech',
'vhost',
'tags',
'discovered_at',
]
read_only_fields = fields
class DirectorySerializer(serializers.ModelSerializer):
"""目录序列化器"""
website_url = serializers.CharField(source='website.url', read_only=True)
discovered_at = serializers.DateTimeField(read_only=True)
class Meta:
model = Directory
fields = [
'id',
'url',
'status',
'content_length',
'words',
'lines',
'content_type',
'duration',
'website_url',
'discovered_at',
]
read_only_fields = fields
class IPAddressAggregatedSerializer(serializers.Serializer):
"""
IP 地址聚合序列化器
基于 HostPortMapping 模型,按 IP 聚合显示:
- ip: IP 地址
- hosts: 该 IP 关联的所有主机名列表
- ports: 该 IP 关联的所有端口列表
- discovered_at: 首次发现时间
"""
ip = serializers.IPAddressField(read_only=True)
hosts = serializers.ListField(child=serializers.CharField(), read_only=True)
ports = serializers.ListField(child=serializers.IntegerField(), read_only=True)
discovered_at = serializers.DateTimeField(read_only=True)
# ==================== 快照序列化器 ====================
class SubdomainSnapshotSerializer(serializers.ModelSerializer):
"""子域名快照序列化器(用于扫描历史)"""
class Meta:
model = SubdomainSnapshot
fields = ['id', 'name', 'discovered_at']
read_only_fields = fields
class WebsiteSnapshotSerializer(serializers.ModelSerializer):
"""网站快照序列化器(用于扫描历史)"""
subdomain_name = serializers.CharField(source='subdomain.name', read_only=True)
webserver = serializers.CharField(source='web_server', read_only=True) # 映射字段名
status_code = serializers.IntegerField(source='status', read_only=True) # 映射字段名
class Meta:
model = WebsiteSnapshot
fields = [
'id',
'url',
'location',
'title',
'webserver', # 使用映射后的字段名
'content_type',
'status_code', # 使用映射后的字段名
'content_length',
'body_preview',
'tech',
'vhost',
'subdomain_name',
'discovered_at',
]
read_only_fields = fields
class DirectorySnapshotSerializer(serializers.ModelSerializer):
"""目录快照序列化器(用于扫描历史)"""
# DirectorySnapshot no longer links to Website; website_url is temporarily mapped to the snapshot's own url to keep the field shape compatible
website_url = serializers.CharField(source='url', read_only=True)
class Meta:
model = DirectorySnapshot
fields = [
'id',
'url',
'status',
'content_length',
'words',
'lines',
'content_type',
'duration',
'website_url',
'discovered_at',
]
read_only_fields = fields
class EndpointSnapshotSerializer(serializers.ModelSerializer):
"""端点快照序列化器(用于扫描历史)"""
# 将 GF 匹配模式映射为前端使用的 tags 字段
tags = serializers.ListField(
child=serializers.CharField(),
source='matched_gf_patterns',
read_only=True,
)
class Meta:
model = EndpointSnapshot
fields = [
'id',
'url',
'host',
'location',
'title',
'webserver',
'content_type',
'status_code',
'content_length',
'body_preview',
'tech',
'vhost',
'tags',
'discovered_at',
]
read_only_fields = fields
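
Because of the source='matched_gf_patterns' mapping, serialized endpoints expose GF pattern matches under a tags key. A rough sketch of the output (field values are made up):

EndpointListSerializer(endpoint).data
# => {
#     'id': 7,
#     'url': 'https://example.com/api/user?id=1',
#     'status_code': 200,
#     'tags': ['sqli', 'idor'],   # read from matched_gf_patterns
#     ...
# }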

View File

@@ -0,0 +1,43 @@
"""Asset Services - 业务逻辑层"""
# 资产模块 Services
from .asset import (
SubdomainService,
WebSiteService,
DirectoryService,
HostPortMappingService,
EndpointService,
VulnerabilityService,
)
# 快照模块 Services
from .snapshot import (
SubdomainSnapshotsService,
HostPortMappingSnapshotsService,
WebsiteSnapshotsService,
DirectorySnapshotsService,
EndpointSnapshotsService,
VulnerabilitySnapshotsService,
)
# 统计模块 Service
from .statistics_service import AssetStatisticsService
__all__ = [
# 资产模块
'SubdomainService',
'WebSiteService',
'DirectoryService',
'HostPortMappingService',
'EndpointService',
'VulnerabilityService',
# 快照模块
'SubdomainSnapshotsService',
'HostPortMappingSnapshotsService',
'WebsiteSnapshotsService',
'DirectorySnapshotsService',
'EndpointSnapshotsService',
'VulnerabilitySnapshotsService',
# 统计模块
'AssetStatisticsService',
]

View File

@@ -0,0 +1,17 @@
"""Asset Services - 资产模块的业务逻辑层"""
from .subdomain_service import SubdomainService
from .website_service import WebSiteService
from .directory_service import DirectoryService
from .host_port_mapping_service import HostPortMappingService
from .endpoint_service import EndpointService
from .vulnerability_service import VulnerabilityService
__all__ = [
'SubdomainService',
'WebSiteService',
'DirectoryService',
'HostPortMappingService',
'EndpointService',
'VulnerabilityService',
]

View File

@@ -0,0 +1,55 @@
import logging
from typing import Tuple, Iterator
from apps.asset.models.asset_models import Directory
from apps.asset.repositories import DjangoDirectoryRepository
logger = logging.getLogger(__name__)
class DirectoryService:
"""目录业务逻辑层"""
def __init__(self, repository=None):
"""
初始化目录服务
Args:
repository: 目录仓储实例(用于依赖注入)
"""
self.repo = repository or DjangoDirectoryRepository()
# ==================== 创建操作 ====================
def bulk_create_ignore_conflicts(self, directory_dtos: list) -> int:
"""
Bulk-create directory records, ignoring conflicts (used by scan tasks)
Args:
directory_dtos: list of DirectoryDTO
Returns:
int: number of records submitted
"""
return self.repo.bulk_create_ignore_conflicts(directory_dtos)
# ==================== 查询操作 ====================
def get_all(self):
"""
获取所有目录
Returns:
QuerySet: 目录查询集
"""
logger.debug("获取所有目录")
return self.repo.get_all()
def get_directories_by_target(self, target_id: int):
logger.debug("获取目标下所有目录 - Target ID: %d", target_id)
return self.repo.get_by_target(target_id)
def iter_directory_urls_by_target(self, target_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取目标下的所有目录 URL用于导出大批量数据。"""
logger.debug("流式导出目标下目录 URL - Target ID: %d", target_id)
return self.repo.get_urls_for_export(target_id=target_id, batch_size=chunk_size)
__all__ = ['DirectoryService']

View File

@@ -0,0 +1,178 @@
"""
Endpoint 服务层
处理 URL/端点相关的业务逻辑
"""
import logging
from typing import List, Optional, Dict, Any, Iterator
from apps.asset.dtos.asset import EndpointDTO
from apps.asset.repositories.asset import DjangoEndpointRepository
logger = logging.getLogger(__name__)
class EndpointService:
"""
Endpoint 服务类
提供 EndpointURL/端点)相关的业务逻辑
"""
def __init__(self):
"""初始化 Endpoint 服务"""
self.repo = DjangoEndpointRepository()
def bulk_create_endpoints(
self,
endpoints: List[EndpointDTO],
ignore_conflicts: bool = True
) -> int:
"""
批量创建端点记录
Args:
endpoints: 端点数据列表
ignore_conflicts: 是否忽略冲突(去重)
Returns:
int: 创建的记录数
"""
if not endpoints:
return 0
try:
if ignore_conflicts:
return self.repo.bulk_create_ignore_conflicts(endpoints)
else:
# 如果需要非忽略冲突的版本,可以在 repository 中添加
return self.repo.bulk_create_ignore_conflicts(endpoints)
except Exception as e:
logger.error(f"批量创建端点失败: {e}")
raise
def get_endpoints_by_website(
self,
website_id: int,
limit: Optional[int] = None
) -> List[Dict[str, Any]]:
"""
获取网站下的端点列表
Args:
website_id: 网站 ID
limit: 返回数量限制
Returns:
List[Dict]: 端点列表
"""
endpoints_dto = self.repo.get_by_website(website_id)
if limit:
endpoints_dto = endpoints_dto[:limit]
endpoints = []
for dto in endpoints_dto:
endpoints.append({
'url': dto.url,
'title': dto.title,
'status_code': dto.status_code,
'content_length': dto.content_length,
'webserver': dto.webserver
})
return endpoints
def get_endpoints_by_target(
self,
target_id: int,
limit: Optional[int] = None
) -> List[Dict[str, Any]]:
"""
获取目标下的端点列表
Args:
target_id: 目标 ID
limit: 返回数量限制
Returns:
List[Dict]: 端点列表
"""
endpoints_dto = self.repo.get_by_target(target_id)
if limit:
endpoints_dto = endpoints_dto[:limit]
endpoints = []
for dto in endpoints_dto:
endpoints.append({
'url': dto.url,
'title': dto.title,
'status_code': dto.status_code,
'content_length': dto.content_length,
'webserver': dto.webserver
})
return endpoints
def count_endpoints_by_target(self, target_id: int) -> int:
"""
统计目标下的端点数量
Args:
target_id: 目标 ID
Returns:
int: 端点数量
"""
return self.repo.count_by_target(target_id)
def get_queryset_by_target(self, target_id: int):
return self.repo.get_queryset_by_target(target_id)
def get_all(self):
"""获取所有端点(全局查询)"""
return self.repo.get_all()
def iter_endpoint_urls_by_target(self, target_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取目标下的所有端点 URL用于导出。"""
queryset = self.repo.get_queryset_by_target(target_id)
for url in queryset.values_list('url', flat=True).iterator(chunk_size=chunk_size):
yield url
def count_endpoints_by_website(self, website_id: int) -> int:
"""
统计网站下的端点数量
Args:
website_id: 网站 ID
Returns:
int: 端点数量
"""
return self.repo.count_by_website(website_id)
def soft_delete_endpoints(self, endpoint_ids: List[int]) -> int:
"""
软删除端点
Args:
endpoint_ids: 端点 ID 列表
Returns:
int: 更新的数量
"""
return self.repo.soft_delete_by_ids(endpoint_ids)
def hard_delete_endpoints(self, endpoint_ids: List[int]) -> tuple:
"""
硬删除端点
Args:
endpoint_ids: 端点 ID 列表
Returns:
tuple: (删除总数, 详细信息)
"""
return self.repo.hard_delete_by_ids(endpoint_ids)

View File

@@ -0,0 +1,61 @@
"""HostPortMapping Service - 业务逻辑层"""
import logging
from typing import List, Iterator
from apps.asset.repositories.asset import DjangoHostPortMappingRepository
from apps.asset.dtos.asset import HostPortMappingDTO
logger = logging.getLogger(__name__)
class HostPortMappingService:
"""主机端口映射服务 - 负责主机端口映射数据的业务逻辑"""
def __init__(self):
self.repo = DjangoHostPortMappingRepository()
def bulk_create_ignore_conflicts(self, items: List[HostPortMappingDTO]) -> int:
"""
批量创建主机端口映射(忽略冲突)
Args:
items: 主机端口映射 DTO 列表
Returns:
int: 实际创建的记录数
Note:
使用数据库唯一约束 + ignore_conflicts 自动去重
"""
try:
logger.debug("Service: 准备批量创建主机端口映射 - 数量: %d", len(items))
created_count = self.repo.bulk_create_ignore_conflicts(items)
logger.info("Service: 主机端口映射创建成功 - 数量: %d", created_count)
return created_count
except Exception as e:
logger.error(
"Service: 批量创建主机端口映射失败 - 数量: %d, 错误: %s",
len(items),
str(e),
exc_info=True
)
raise
def iter_host_port_by_target(self, target_id: int, batch_size: int = 1000):
return self.repo.get_for_export(target_id=target_id, batch_size=batch_size)
def get_ip_aggregation_by_target(self, target_id: int, search: str = None):
return self.repo.get_ip_aggregation_by_target(target_id, search=search)
def get_all_ip_aggregation(self, search: str = None):
"""获取所有 IP 聚合数据(全局查询)"""
return self.repo.get_all_ip_aggregation(search=search)
def iter_ips_by_target(self, target_id: int, batch_size: int = 1000) -> Iterator[str]:
"""流式获取目标下的所有唯一 IP 地址。"""
return self.repo.get_ips_for_export(target_id=target_id, batch_size=batch_size)

View File

@@ -0,0 +1,123 @@
import logging
from typing import Any, Dict, List, Tuple
from apps.asset.repositories import DjangoSubdomainRepository
from apps.asset.dtos import SubdomainDTO
logger = logging.getLogger(__name__)
class SubdomainService:
"""子域名业务逻辑层"""
def __init__(self, repository=None):
"""
初始化子域名服务
Args:
repository: 子域名仓储实例(用于依赖注入)
"""
self.repo = repository or DjangoSubdomainRepository()
# ==================== 查询操作 ====================
def get_all(self):
"""
获取所有子域名
Returns:
QuerySet: 子域名查询集
"""
logger.debug("获取所有子域名")
return self.repo.get_all()
# ==================== 创建操作 ====================
def get_or_create(self, name: str, target_id: int) -> Tuple[Any, bool]:
"""
获取或创建子域名
Args:
name: 子域名名称
target_id: 目标 ID
Returns:
(Subdomain对象, 是否新创建)
"""
logger.debug("获取或创建子域名 - Name: %s, Target ID: %d", name, target_id)
return self.repo.get_or_create(name, target_id)
def bulk_create_ignore_conflicts(self, items: List[SubdomainDTO]) -> None:
"""
批量创建子域名,忽略冲突
Args:
items: 子域名 DTO 列表
Note:
使用 ignore_conflicts 策略,重复记录会被跳过
"""
logger.debug("批量创建子域名 - 数量: %d", len(items))
return self.repo.bulk_create_ignore_conflicts(items)
def get_by_names_and_target_id(self, names: set, target_id: int) -> dict:
"""
根据域名列表和目标ID批量查询子域名
Args:
names: 域名集合
target_id: 目标 ID
Returns:
dict: {域名: Subdomain对象}
"""
logger.debug("批量查询子域名 - 数量: %d, Target ID: %d", len(names), target_id)
return self.repo.get_by_names_and_target_id(names, target_id)
def get_subdomain_names_by_target(self, target_id: int) -> List[str]:
"""
获取目标下的所有子域名名称
Args:
target_id: 目标 ID
Returns:
List[str]: 子域名名称列表
"""
logger.debug("获取目标下所有子域名 - Target ID: %d", target_id)
# 通过仓储层统一访问数据库,内部已使用 iterator() 做流式查询
return list(self.repo.get_domains_for_export(target_id=target_id))
def get_subdomains_by_target(self, target_id: int):
return self.repo.get_by_target(target_id)
def count_subdomains_by_target(self, target_id: int) -> int:
"""
统计目标下的子域名数量
Args:
target_id: 目标 ID
Returns:
int: 子域名数量
"""
logger.debug("统计目标下子域名数量 - Target ID: %d", target_id)
return self.repo.count_by_target(target_id)
def iter_subdomain_names_by_target(self, target_id: int, chunk_size: int = 1000):
"""
流式获取目标下的所有子域名名称(内存优化)
Args:
target_id: 目标 ID
chunk_size: 批次大小
Yields:
str: 子域名名称
"""
logger.debug("流式获取目标下所有子域名 - Target ID: %d, 批次大小: %d", target_id, chunk_size)
# 通过仓储层统一访问数据库,内部已使用 iterator() 做流式查询
return self.repo.get_domains_for_export(target_id=target_id, batch_size=chunk_size)
__all__ = ['SubdomainService']
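
The two export helpers above differ only in materialization; a minimal sketch (target ID hypothetical, process() is a placeholder for caller logic):

service = SubdomainService()
# Materializes the full list in memory - fine for small targets
names = service.get_subdomain_names_by_target(target_id=1)
# Streams name by name with constant memory - preferred for large targets
for name in service.iter_subdomain_names_by_target(target_id=1, chunk_size=1000):
    process(name)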

View File

@@ -0,0 +1,81 @@
"""Vulnerability Service - 漏洞资产业务逻辑层"""
import logging
from typing import List
from apps.asset.models import Vulnerability
from apps.asset.dtos.asset import VulnerabilityDTO
from apps.common.decorators import auto_ensure_db_connection
logger = logging.getLogger(__name__)
@auto_ensure_db_connection
class VulnerabilityService:
"""漏洞资产服务
当前提供基础的批量创建能力,使用 ignore_conflicts 依赖数据库唯一约束去重。
"""
def bulk_create_ignore_conflicts(self, items: List[VulnerabilityDTO]) -> None:
"""批量创建漏洞资产记录,忽略冲突。
Note:
- 是否去重取决于模型上的唯一/部分唯一约束;
- 当前 Vulnerability 模型未定义唯一约束,因此会保留全部记录。
"""
if not items:
logger.debug("漏洞资产列表为空,跳过保存")
return
try:
vulns = [
Vulnerability(
target_id=item.target_id,
url=item.url,
vuln_type=item.vuln_type,
severity=item.severity,
source=item.source,
cvss_score=item.cvss_score,
description=item.description,
raw_output=item.raw_output,
)
for item in items
]
Vulnerability.objects.bulk_create(vulns, ignore_conflicts=True)
logger.info("漏洞资产保存成功 - 数量: %d", len(vulns))
except Exception as e:
logger.error(
"批量保存漏洞资产失败 - 数量: %d, 错误: %s",
len(items),
str(e),
exc_info=True,
)
raise
# ==================== 查询方法 ====================
def get_all(self):
"""获取所有漏洞 QuerySet用于全局漏洞列表
Returns:
QuerySet[Vulnerability]: 所有漏洞,按发现时间倒序
"""
return Vulnerability.objects.filter(deleted_at__isnull=True).order_by("-discovered_at")
def get_queryset_by_target(self, target_id: int):
"""按目标获取漏洞 QuerySet用于分页
Args:
target_id: 目标 ID
Returns:
QuerySet[Vulnerability]: 目标下的所有漏洞,按发现时间倒序
"""
return Vulnerability.objects.filter(target_id=target_id).order_by("-discovered_at")
def count_by_target(self, target_id: int) -> int:
"""统计目标下的漏洞数量。"""
return Vulnerability.objects.filter(target_id=target_id).count()

View File

@@ -0,0 +1,91 @@
import logging
from typing import List, Optional, Tuple
from apps.asset.repositories import DjangoWebSiteRepository
from apps.asset.dtos import WebSiteDTO
logger = logging.getLogger(__name__)
class WebSiteService:
"""网站业务逻辑层"""
def __init__(self, repository=None):
"""
初始化网站服务
Args:
repository: 网站仓储实例(用于依赖注入)
"""
self.repo = repository or DjangoWebSiteRepository()
# ==================== 创建操作 ====================
def bulk_create_ignore_conflicts(self, website_dtos: List[WebSiteDTO]) -> None:
"""
批量创建网站记录,忽略冲突(用于扫描任务)
Args:
website_dtos: WebSiteDTO 列表
Note:
使用 ignore_conflicts 策略,重复记录会被跳过
"""
logger.debug("批量创建网站 - 数量: %d", len(website_dtos))
return self.repo.bulk_create_ignore_conflicts(website_dtos)
# ==================== 查询操作 ====================
def get_by_url(self, url: str, target_id: int) -> Optional[int]:
"""
Look up a website ID by URL and target_id
Args:
url: website URL
target_id: target ID
Returns:
Optional[int]: website ID, or None if not found
"""
return self.repo.get_by_url(url=url, target_id=target_id)
def get_all(self):
"""
获取所有网站
Returns:
QuerySet: 网站查询集
"""
logger.debug("获取所有网站")
return self.repo.get_all()
def get_websites_by_target(self, target_id: int):
return self.repo.get_by_target(target_id)
def count_websites_by_scan(self, scan_id: int) -> int:
"""
统计扫描下的网站数量
Args:
scan_id: 扫描 ID
Returns:
int: 网站数量
"""
logger.debug("统计扫描下网站数量 - Scan ID: %d", scan_id)
return self.repo.count_by_scan(scan_id)
def iter_website_urls_by_target(self, target_id: int, chunk_size: int = 1000):
"""流式获取目标下的所有站点 URL内存优化委托给 Repository 层)"""
logger.debug(
"流式获取目标下所有站点 URL - Target ID: %d, 批次大小: %d",
target_id,
chunk_size,
)
# 通过仓储层统一访问数据库,避免 Service 直接依赖 ORM
return self.repo.get_urls_for_export(target_id=target_id, batch_size=chunk_size)
__all__ = ['WebSiteService']
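
Sketch of feeding the streamed URLs into a scan-tool input file (output path hypothetical):

service = WebSiteService()
with open('/tmp/website_urls.txt', 'w') as fh:
    for url in service.iter_website_urls_by_target(target_id=1, chunk_size=1000):
        fh.write(url + '\n')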

View File

@@ -0,0 +1,17 @@
"""Snapshot Services - 快照服务"""
from .subdomain_snapshots_service import SubdomainSnapshotsService
from .host_port_mapping_snapshots_service import HostPortMappingSnapshotsService
from .website_snapshots_service import WebsiteSnapshotsService
from .directory_snapshots_service import DirectorySnapshotsService
from .endpoint_snapshots_service import EndpointSnapshotsService
from .vulnerability_snapshots_service import VulnerabilitySnapshotsService
__all__ = [
'SubdomainSnapshotsService',
'HostPortMappingSnapshotsService',
'WebsiteSnapshotsService',
'DirectorySnapshotsService',
'EndpointSnapshotsService',
'VulnerabilitySnapshotsService',
]

View File

@@ -0,0 +1,83 @@
"""Directory Snapshots Service - 业务逻辑层"""
import logging
from typing import List, Iterator
from apps.asset.repositories.snapshot import DjangoDirectorySnapshotRepository
from apps.asset.services.asset import DirectoryService
from apps.asset.dtos.snapshot import DirectorySnapshotDTO
logger = logging.getLogger(__name__)
class DirectorySnapshotsService:
"""目录快照服务 - 统一管理快照和资产同步"""
def __init__(self):
self.snapshot_repo = DjangoDirectorySnapshotRepository()
self.asset_service = DirectoryService()
def save_and_sync(self, items: List[DirectorySnapshotDTO]) -> None:
"""
保存目录快照并同步到资产表(统一入口)
流程:
1. 保存到快照表(完整记录,包含 scan_id
2. 同步到资产表(去重,不包含 scan_id
Args:
items: 目录快照 DTO 列表(必须包含 website_id
Raises:
ValueError: 如果 items 中的 website_id 为 None
Exception: 数据库操作失败
"""
if not items:
return
# 检查 Scan 是否仍存在(防止删除后竞态写入)
scan_id = items[0].scan_id
from apps.scan.repositories import DjangoScanRepository
if not DjangoScanRepository().exists(scan_id):
logger.warning("Scan 已删除,跳过目录快照保存 - scan_id=%s, 数量=%d", scan_id, len(items))
return
try:
logger.debug("保存目录快照并同步到资产表 - 数量: %d", len(items))
# 步骤 1: 保存到快照表
logger.debug("步骤 1: 保存到快照表")
self.snapshot_repo.save_snapshots(items)
# 步骤 2: 转换为资产 DTO 并保存到资产表
# 注意:去重是通过数据库的 UNIQUE 约束 + ignore_conflicts 实现的
# - 新记录:插入资产表
# - 已存在的记录:自动跳过
logger.debug("步骤 2: 同步到资产表(通过 Service 层)")
asset_items = [item.to_asset_dto() for item in items]
self.asset_service.bulk_create_ignore_conflicts(asset_items)
logger.info("目录快照和资产数据保存成功 - 数量: %d", len(items))
except Exception as e:
logger.error(
"保存目录快照失败 - 数量: %d, 错误: %s",
len(items),
str(e),
exc_info=True
)
raise
def get_by_scan(self, scan_id: int):
return self.snapshot_repo.get_by_scan(scan_id)
def get_all(self):
"""获取所有目录快照"""
return self.snapshot_repo.get_all()
def iter_directory_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取某次扫描下的所有目录 URL。"""
queryset = self.snapshot_repo.get_by_scan(scan_id)
for snapshot in queryset.iterator(chunk_size=chunk_size):
yield snapshot.url
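
End-to-end sketch of the snapshot-then-asset flow above (DirectorySnapshotDTO is assumed to take the fields save_snapshots maps plus website_id/target_id for to_asset_dto; all values are hypothetical):

from apps.asset.dtos.snapshot import DirectorySnapshotDTO
from apps.asset.services.snapshot import DirectorySnapshotsService

service = DirectorySnapshotsService()
service.save_and_sync([
    DirectorySnapshotDTO(
        scan_id=42, website_id=1, target_id=1,
        url="https://example.com/backup/", status=200,
        content_length=512, words=60, lines=12,
        content_type="text/html", duration=0.4,
    ),
])  # writes the snapshot row, then inserts-or-skips into the asset table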

View File

@@ -0,0 +1,83 @@
"""Endpoint Snapshots Service - 业务逻辑层"""
import logging
from typing import List, Iterator
from apps.asset.repositories.snapshot import DjangoEndpointSnapshotRepository
from apps.asset.services.asset import EndpointService
from apps.asset.dtos.snapshot import EndpointSnapshotDTO
logger = logging.getLogger(__name__)
class EndpointSnapshotsService:
"""端点快照服务 - 统一管理快照和资产同步"""
def __init__(self):
self.snapshot_repo = DjangoEndpointSnapshotRepository()
self.asset_service = EndpointService()
def save_and_sync(self, items: List[EndpointSnapshotDTO]) -> None:
"""
保存端点快照并同步到资产表(统一入口)
流程:
1. 保存到快照表(完整记录)
2. 同步到资产表(去重)
Args:
items: 端点快照 DTO 列表(必须包含 target_id
Raises:
ValueError: 如果 items 中的 target_id 为 None
Exception: 数据库操作失败
"""
if not items:
return
# 检查 Scan 是否仍存在(防止删除后竞态写入)
scan_id = items[0].scan_id
from apps.scan.repositories import DjangoScanRepository
if not DjangoScanRepository().exists(scan_id):
logger.warning("Scan 已删除,跳过端点快照保存 - scan_id=%s, 数量=%d", scan_id, len(items))
return
try:
logger.debug("保存端点快照并同步到资产表 - 数量: %d", len(items))
# 步骤 1: 保存到快照表
logger.debug("步骤 1: 保存到快照表")
self.snapshot_repo.save_snapshots(items)
# 步骤 2: 转换为资产 DTO 并保存到资产表
# 注意:去重是通过数据库的 UNIQUE 约束 + ignore_conflicts 实现的
# - 新记录:插入资产表
# - 已存在的记录:自动跳过
logger.debug("步骤 2: 同步到资产表(通过 Service 层)")
asset_items = [item.to_asset_dto() for item in items]
self.asset_service.bulk_create_endpoints(asset_items)
logger.info("端点快照和资产数据保存成功 - 数量: %d", len(items))
except Exception as e:
logger.error(
"保存端点快照失败 - 数量: %d, 错误: %s",
len(items),
str(e),
exc_info=True
)
raise
def get_by_scan(self, scan_id: int):
return self.snapshot_repo.get_by_scan(scan_id)
def get_all(self):
"""获取所有端点快照"""
return self.snapshot_repo.get_all()
def iter_endpoint_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取某次扫描下的所有端点 URL。"""
queryset = self.snapshot_repo.get_by_scan(scan_id)
for snapshot in queryset.iterator(chunk_size=chunk_size):
yield snapshot.url

View File

@@ -0,0 +1,81 @@
"""HostPortMapping Snapshots Service - 业务逻辑层"""
import logging
from typing import List, Iterator
from apps.asset.repositories.snapshot import DjangoHostPortMappingSnapshotRepository
from apps.asset.services.asset import HostPortMappingService
from apps.asset.dtos.snapshot import HostPortMappingSnapshotDTO
logger = logging.getLogger(__name__)
class HostPortMappingSnapshotsService:
"""HostPortMapping Snapshots Service - 统一管理快照和资产同步"""
def __init__(self):
self.snapshot_repo = DjangoHostPortMappingSnapshotRepository()
self.asset_service = HostPortMappingService()
def save_and_sync(self, items: List[HostPortMappingSnapshotDTO]) -> None:
"""
保存主机端口关联快照并同步到资产表(统一入口)
流程:
1. 保存到快照表(完整记录,包含 scan_id
2. 同步到资产表(去重,不包含 scan_id
Args:
items: 主机端口关联快照 DTO 列表(必须包含 target_id
Note:
target_id 已经包含在 DTO 中,无需额外传参。
"""
logger.debug("保存主机端口关联快照 - 数量: %d", len(items))
if not items:
logger.debug("快照数据为空,跳过保存")
return
# 检查 Scan 是否仍存在(防止删除后竞态写入)
scan_id = items[0].scan_id
from apps.scan.repositories import DjangoScanRepository
if not DjangoScanRepository().exists(scan_id):
logger.warning("Scan 已删除,跳过主机端口快照保存 - scan_id=%s, 数量=%d", scan_id, len(items))
return
try:
# 步骤 1: 保存到快照表
logger.debug("步骤 1: 保存到快照表")
self.snapshot_repo.save_snapshots(items)
# 步骤 2: 转换为资产 DTO 并保存到资产表
# 注意:去重是通过数据库的 UNIQUE 约束 + ignore_conflicts 实现的
# - 新记录:插入资产表
# - 已存在的记录:自动跳过
logger.debug("步骤 2: 同步到资产表(通过 Service 层)")
asset_items = [item.to_asset_dto() for item in items]
self.asset_service.bulk_create_ignore_conflicts(asset_items)
logger.info("主机端口关联快照和资产数据保存成功 - 数量: %d", len(items))
except Exception as e:
logger.error(
"保存主机端口关联快照失败 - 数量: %d, 错误: %s",
len(items),
str(e),
exc_info=True
)
raise
def get_ip_aggregation_by_scan(self, scan_id: int, search: str = None):
return self.snapshot_repo.get_ip_aggregation_by_scan(scan_id, search=search)
def get_all_ip_aggregation(self, search: str = None):
"""获取所有 IP 聚合数据"""
return self.snapshot_repo.get_all_ip_aggregation(search=search)
def iter_ips_by_scan(self, scan_id: int, batch_size: int = 1000) -> Iterator[str]:
"""流式获取某次扫描下的所有唯一 IP 地址。"""
return self.snapshot_repo.get_ips_for_export(scan_id=scan_id, batch_size=batch_size)

View File

@@ -0,0 +1,79 @@
import logging
from typing import List, Iterator
from apps.asset.dtos import SubdomainSnapshotDTO
from apps.asset.repositories import DjangoSubdomainSnapshotRepository
logger = logging.getLogger(__name__)
class SubdomainSnapshotsService:
"""子域名快照服务 - 负责子域名快照数据的业务逻辑"""
def __init__(self):
self.subdomain_snapshot_repo = DjangoSubdomainSnapshotRepository()
def save_and_sync(self, items: List[SubdomainSnapshotDTO]) -> None:
"""
保存子域名快照并同步到资产表(统一入口)
流程:
1. 保存到快照表(完整记录,包含 scan_id
2. 同步到资产表(去重,不包含 scan_id
Args:
items: 子域名快照 DTO 列表(包含 target_id
Note:
target_id 已经包含在 DTO 中,无需额外传参。
"""
logger.debug("保存子域名快照 - 数量: %d", len(items))
if not items:
logger.debug("快照数据为空,跳过保存")
return
# 检查 Scan 是否仍存在(防止删除后竞态写入)
scan_id = items[0].scan_id
from apps.scan.repositories import DjangoScanRepository
if not DjangoScanRepository().exists(scan_id):
logger.warning("Scan 已删除,跳过快照保存 - scan_id=%s, 数量=%d", scan_id, len(items))
return
try:
# Step 1: save to the snapshot table
logger.debug("Step 1: save to the snapshot table")
self.subdomain_snapshot_repo.save_subdomain_snapshots(items)
# Step 2: convert to asset DTOs and save to the asset table (the DB unique constraint deduplicates)
# Note: deduplication relies on the database UNIQUE constraint + ignore_conflicts
# - new subdomains: inserted into the asset table
# - existing subdomains: skipped automatically (no update; the asset table only keeps core data)
asset_items = [item.to_asset_dto() for item in items]
from apps.asset.services import SubdomainService
subdomain_service = SubdomainService()
subdomain_service.bulk_create_ignore_conflicts(asset_items)
logger.info("Subdomain snapshots and business data saved - count: %d", len(items))
except Exception as e:
logger.error(
"保存子域名快照失败 - 数量: %d, 错误: %s",
len(items),
str(e),
exc_info=True
)
raise
def get_by_scan(self, scan_id: int):
return self.subdomain_snapshot_repo.get_by_scan(scan_id)
def get_all(self):
"""获取所有子域名快照"""
return self.subdomain_snapshot_repo.get_all()
def iter_subdomain_names_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
queryset = self.subdomain_snapshot_repo.get_by_scan(scan_id)
for snapshot in queryset.iterator(chunk_size=chunk_size):
yield snapshot.name

View File

@@ -0,0 +1,81 @@
"""Vulnerability Snapshots Service - 业务逻辑层"""
import logging
from typing import List, Iterator
from apps.asset.repositories.snapshot import DjangoVulnerabilitySnapshotRepository
from apps.asset.services.asset.vulnerability_service import VulnerabilityService
from apps.asset.dtos.snapshot import VulnerabilitySnapshotDTO
logger = logging.getLogger(__name__)
class VulnerabilitySnapshotsService:
"""漏洞快照服务 - 统一管理快照和资产同步。
流程与 Website/Directory 等保持一致:
1. 保存到 VulnerabilitySnapshot 快照表(包含 scan_id
2. 转为 VulnerabilityDTO 并同步到 Vulnerability 资产表(基于 target_id
"""
def __init__(self):
self.snapshot_repo = DjangoVulnerabilitySnapshotRepository()
self.asset_service = VulnerabilityService()
def save_and_sync(self, items: List[VulnerabilitySnapshotDTO]) -> None:
"""保存漏洞快照并同步到漏洞资产表。"""
if not items:
return
# Check that the Scan still exists (guards against racy writes after deletion)
scan_id = items[0].scan_id
from apps.scan.repositories import DjangoScanRepository
if not DjangoScanRepository().exists(scan_id):
logger.warning("Scan deleted, skipping vulnerability snapshot save - scan_id=%s, count=%d", scan_id, len(items))
return
try:
logger.debug("保存漏洞快照并同步到资产表 - 数量: %d", len(items))
# 步骤 1: 保存到快照表
logger.debug("步骤 1: 保存到漏洞快照表")
self.snapshot_repo.save_snapshots(items)
# 步骤 2: 转换为资产 DTO 并保存到资产表
logger.debug("步骤 2: 同步到漏洞资产表")
asset_items = [item.to_asset_dto() for item in items]
self.asset_service.bulk_create_ignore_conflicts(asset_items)
logger.info("漏洞快照和资产数据保存成功 - 数量: %d", len(items))
# 步骤 3: 发布漏洞保存信号(通知等模块可监听)
from apps.common.signals import vulnerabilities_saved
vulnerabilities_saved.send(
sender=self.__class__,
items=items,
scan_id=scan_id,
target_id=items[0].target_id if items else None
)
except Exception as e:
logger.error(
"保存漏洞快照失败 - 数量: %d, 错误: %s",
len(items),
str(e),
exc_info=True,
)
raise
def get_by_scan(self, scan_id: int):
"""按扫描任务获取所有漏洞快照。"""
return self.snapshot_repo.get_by_scan(scan_id)
def get_all(self):
"""获取所有漏洞快照"""
return self.snapshot_repo.get_all()
def iter_vuln_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取某次扫描下的所有漏洞 URL。"""
queryset = self.snapshot_repo.get_by_scan(scan_id)
for snapshot in queryset.iterator(chunk_size=chunk_size):
yield snapshot.url

View File

@@ -0,0 +1,83 @@
"""Website Snapshots Service - 业务逻辑层"""
import logging
from typing import List, Iterator
from apps.asset.repositories.snapshot import DjangoWebsiteSnapshotRepository
from apps.asset.services.asset import WebSiteService
from apps.asset.dtos.snapshot import WebsiteSnapshotDTO
logger = logging.getLogger(__name__)
class WebsiteSnapshotsService:
"""网站快照服务 - 统一管理快照和资产同步"""
def __init__(self):
self.snapshot_repo = DjangoWebsiteSnapshotRepository()
self.asset_service = WebSiteService()
def save_and_sync(self, items: List[WebsiteSnapshotDTO]) -> None:
"""
保存网站快照并同步到资产表(统一入口)
流程:
1. 保存到快照表(完整记录,包含 scan_id
2. 同步到资产表(去重,不包含 scan_id
Args:
items: 网站快照 DTO 列表(必须包含 target_id
Raises:
ValueError: 如果 items 中的 target_id 为 None
Exception: 数据库操作失败
"""
if not items:
return
# Check that the Scan still exists (guards against racy writes after deletion)
scan_id = items[0].scan_id
from apps.scan.repositories import DjangoScanRepository
if not DjangoScanRepository().exists(scan_id):
logger.warning("Scan deleted, skipping website snapshot save - scan_id=%s, count=%d", scan_id, len(items))
return
try:
logger.debug("保存网站快照并同步到资产表 - 数量: %d", len(items))
# 步骤 1: 保存到快照表
logger.debug("步骤 1: 保存到快照表")
self.snapshot_repo.save_snapshots(items)
# 步骤 2: 转换为资产 DTO 并保存到资产表
# 注意:去重是通过数据库的 UNIQUE 约束 + ignore_conflicts 实现的
# - 新记录:插入资产表
# - 已存在的记录:自动跳过
logger.debug("步骤 2: 同步到资产表(通过 Service 层)")
asset_items = [item.to_asset_dto() for item in items]
self.asset_service.bulk_create_ignore_conflicts(asset_items)
logger.info("网站快照和资产数据保存成功 - 数量: %d", len(items))
except Exception as e:
logger.error(
"保存网站快照失败 - 数量: %d, 错误: %s",
len(items),
str(e),
exc_info=True
)
raise
def get_by_scan(self, scan_id: int):
return self.snapshot_repo.get_by_scan(scan_id)
def get_all(self):
"""获取所有网站快照"""
return self.snapshot_repo.get_all()
def iter_website_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取某次扫描下的所有站点 URL按发现时间倒序"""
queryset = self.snapshot_repo.get_by_scan(scan_id)
for snapshot in queryset.iterator(chunk_size=chunk_size):
yield snapshot.url

View File

@@ -0,0 +1,162 @@
"""资产统计 Service"""
import logging
from typing import Optional
from django.db.models import Count
from apps.asset.repositories import AssetStatisticsRepository
from apps.asset.models import (
AssetStatistics,
StatisticsHistory,
Subdomain,
WebSite,
Endpoint,
HostPortMapping,
Vulnerability,
)
from apps.targets.models import Target
from apps.scan.models import Scan
logger = logging.getLogger(__name__)
class AssetStatisticsService:
"""
资产统计服务
职责:
- 获取统计数据
- 刷新统计数据(供定时任务调用)
"""
def __init__(self):
self.repo = AssetStatisticsRepository()
def get_statistics(self) -> dict:
"""
获取统计数据
Returns:
统计数据字典
"""
stats = self.repo.get_statistics()
if stats is None:
# 如果没有统计数据,返回默认值
return {
'total_targets': 0,
'total_subdomains': 0,
'total_ips': 0,
'total_endpoints': 0,
'total_websites': 0,
'total_vulns': 0,
'total_assets': 0,
'running_scans': Scan.objects.filter(status='running').count(),
'updated_at': None,
# Deltas
'change_targets': 0,
'change_subdomains': 0,
'change_ips': 0,
'change_endpoints': 0,
'change_websites': 0,
'change_vulns': 0,
'change_assets': 0,
'vuln_by_severity': self._get_vuln_by_severity(),
}
# Running-scan count is queried live (small result, no caching needed)
running_scans = Scan.objects.filter(status='running').count()
return {
'total_targets': stats.total_targets,
'total_subdomains': stats.total_subdomains,
'total_ips': stats.total_ips,
'total_endpoints': stats.total_endpoints,
'total_websites': stats.total_websites,
'total_vulns': stats.total_vulns,
'total_assets': stats.total_assets,
'running_scans': running_scans,
'updated_at': stats.updated_at,
# Delta = current value - previous value
'change_targets': stats.total_targets - stats.prev_targets,
'change_subdomains': stats.total_subdomains - stats.prev_subdomains,
'change_ips': stats.total_ips - stats.prev_ips,
'change_endpoints': stats.total_endpoints - stats.prev_endpoints,
'change_websites': stats.total_websites - stats.prev_websites,
'change_vulns': stats.total_vulns - stats.prev_vulns,
'change_assets': stats.total_assets - stats.prev_assets,
# Vulnerability distribution by severity
'vuln_by_severity': self._get_vuln_by_severity(),
}
def _get_vuln_by_severity(self) -> dict:
"""获取按严重程度统计的漏洞数量"""
result = Vulnerability.objects.values('severity').annotate(count=Count('id'))
severity_map = {item['severity']: item['count'] for item in result}
return {
'critical': severity_map.get('critical', 0),
'high': severity_map.get('high', 0),
'medium': severity_map.get('medium', 0),
'low': severity_map.get('low', 0),
'info': severity_map.get('info', 0),
}
def refresh_statistics(self) -> AssetStatistics:
"""
刷新统计数据(执行实际 COUNT 查询)
供定时任务调用,不建议在接口中直接调用。
Returns:
更新后的统计数据对象
"""
logger.info("开始刷新资产统计...")
# Run the COUNT queries
total_targets = Target.objects.filter(deleted_at__isnull=True).count()
total_subdomains = Subdomain.objects.count()
total_ips = HostPortMapping.objects.values('ip').distinct().count()
total_endpoints = Endpoint.objects.count()
total_websites = WebSite.objects.count()
total_vulns = Vulnerability.objects.count()
# Update the statistics table
stats = self.repo.update_statistics(
total_targets=total_targets,
total_subdomains=total_subdomains,
total_ips=total_ips,
total_endpoints=total_endpoints,
total_websites=total_websites,
total_vulns=total_vulns,
)
# Save the daily snapshot (used by the trend chart)
self.repo.save_daily_snapshot(stats)
logger.info("Asset statistics refreshed")
return stats
def get_statistics_history(self, days: int = 7) -> list[dict]:
"""
获取历史统计数据(用于折线图)
Args:
days: 获取最近多少天的数据,默认 7 天
Returns:
历史数据列表,每项包含 date 和各统计字段
"""
history = self.repo.get_history(days=days)
return [
{
'date': h.date.isoformat(),
'totalTargets': h.total_targets,
'totalSubdomains': h.total_subdomains,
'totalIps': h.total_ips,
'totalEndpoints': h.total_endpoints,
'totalWebsites': h.total_websites,
'totalVulns': h.total_vulns,
'totalAssets': h.total_assets,
}
for h in history
]
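# Illustrative sketch (an assumption, not shown in this commit): for the deltas
# above to be meaningful, the repository's update_statistics() presumably rolls
# the current totals into the prev_* columns before overwriting them.
def update_statistics(self, **totals):  # e.g. total_targets=42, total_subdomains=980
    stats = AssetStatistics.objects.first() or AssetStatistics()
    for name, value in totals.items():
        setattr(stats, name.replace("total_", "prev_"), getattr(stats, name, 0) or 0)
        setattr(stats, name, value)
    stats.save()
    return stats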

View File

@@ -0,0 +1,28 @@
"""
Asset 应用 URL 配置
"""
from django.urls import path, include
from rest_framework.routers import DefaultRouter
from .views import (
SubdomainViewSet,
WebSiteViewSet,
DirectoryViewSet,
VulnerabilityViewSet,
AssetStatisticsViewSet,
)
# Create the DRF router
router = DefaultRouter()
# Register the ViewSets
# Note: the IPAddress model was refactored into HostPortMapping; its routes were removed
router.register(r'subdomains', SubdomainViewSet, basename='subdomain')
router.register(r'websites', WebSiteViewSet, basename='website')
router.register(r'directories', DirectoryViewSet, basename='directory')
router.register(r'vulnerabilities', VulnerabilityViewSet, basename='vulnerability')
router.register(r'statistics', AssetStatisticsViewSet, basename='asset-statistics')
urlpatterns = [
path('assets/', include(router.urls)),
]
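# Illustrative sketch (an assumption; the real wiring lives elsewhere in the
# repo): nested routes such as /api/targets/{target_pk}/subdomains/ used by the
# ViewSets could be produced with drf-nested-routers, whose lookup='target'
# yields the target_pk kwarg the views read. TargetViewSet is hypothetical.
from rest_framework_nested import routers as nested_routers

parent = DefaultRouter()
parent.register(r'targets', TargetViewSet, basename='target')
targets_router = nested_routers.NestedDefaultRouter(parent, r'targets', lookup='target')
targets_router.register(r'subdomains', SubdomainViewSet, basename='target-subdomains')
urlpatterns += [
    path('', include(parent.urls)),
    path('', include(targets_router.urls)),
]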

562
backend/apps/asset/views.py Normal file
View File

@@ -0,0 +1,562 @@
import logging
from rest_framework import viewsets, status, filters
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.request import Request
from rest_framework.exceptions import NotFound, ValidationError as DRFValidationError
from django.core.exceptions import ValidationError, ObjectDoesNotExist
from django.db import DatabaseError, IntegrityError, OperationalError
from django.http import StreamingHttpResponse
from .serializers import (
SubdomainListSerializer, WebSiteSerializer, DirectorySerializer,
VulnerabilitySerializer, EndpointListSerializer, IPAddressAggregatedSerializer,
SubdomainSnapshotSerializer, WebsiteSnapshotSerializer, DirectorySnapshotSerializer,
EndpointSnapshotSerializer, VulnerabilitySnapshotSerializer
)
from .services import (
SubdomainService, WebSiteService, DirectoryService,
VulnerabilityService, AssetStatisticsService, EndpointService, HostPortMappingService
)
from .services.snapshot import (
SubdomainSnapshotsService, WebsiteSnapshotsService, DirectorySnapshotsService,
EndpointSnapshotsService, HostPortMappingSnapshotsService, VulnerabilitySnapshotsService
)
from apps.common.pagination import BasePagination
logger = logging.getLogger(__name__)
class AssetStatisticsViewSet(viewsets.ViewSet):
"""
资产统计 API
提供仪表盘所需的统计数据(预聚合,读取缓存表)
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.service = AssetStatisticsService()
def list(self, request):
"""
获取资产统计数据
GET /assets/statistics/
返回:
- totalTargets: 目标总数
- totalSubdomains: 子域名总数
- totalIps: IP 总数
- totalEndpoints: 端点总数
- totalWebsites: 网站总数
- totalVulns: 漏洞总数
- totalAssets: 总资产数
- runningScans: 运行中的扫描数
- updatedAt: 统计更新时间
"""
try:
stats = self.service.get_statistics()
return Response({
'totalTargets': stats['total_targets'],
'totalSubdomains': stats['total_subdomains'],
'totalIps': stats['total_ips'],
'totalEndpoints': stats['total_endpoints'],
'totalWebsites': stats['total_websites'],
'totalVulns': stats['total_vulns'],
'totalAssets': stats['total_assets'],
'runningScans': stats['running_scans'],
'updatedAt': stats['updated_at'],
# Deltas
'changeTargets': stats['change_targets'],
'changeSubdomains': stats['change_subdomains'],
'changeIps': stats['change_ips'],
'changeEndpoints': stats['change_endpoints'],
'changeWebsites': stats['change_websites'],
'changeVulns': stats['change_vulns'],
'changeAssets': stats['change_assets'],
# Vulnerability distribution by severity
'vulnBySeverity': stats['vuln_by_severity'],
})
except (DatabaseError, OperationalError):
logger.exception("Failed to fetch asset statistics")
return Response(
{'error': 'Failed to fetch statistics'},
status=status.HTTP_500_INTERNAL_SERVER_ERROR
)
@action(detail=False, methods=['get'], url_path='history')
def history(self, request: Request):
"""
获取统计历史数据(用于折线图)
GET /assets/statistics/history/?days=7
Query Parameters:
days: 获取最近多少天的数据,默认 7最大 90
Returns:
历史数据列表
"""
try:
days_param = request.query_params.get('days', '7')
try:
days = int(days_param)
except (ValueError, TypeError):
days = 7
days = min(max(days, 1), 90)  # clamp to 1-90 days
history = self.service.get_statistics_history(days=days)
return Response(history)
except (DatabaseError, OperationalError):
logger.exception("Failed to fetch statistics history")
return Response(
{'error': 'Failed to fetch history data'},
status=status.HTTP_500_INTERNAL_SERVER_ERROR
)
# 注意IPAddress 模型已被重构为 HostPortMapping
# IPAddressViewSet 已删除,需要根据新架构重新实现
class SubdomainViewSet(viewsets.ModelViewSet):
"""子域名管理 ViewSet
支持两种访问方式:
1. 嵌套路由GET /api/targets/{target_pk}/subdomains/
2. 独立路由GET /api/subdomains/(全局查询)
"""
serializer_class = SubdomainListSerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['name']
ordering = ['-discovered_at']
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.service = SubdomainService()
def get_queryset(self):
"""根据是否有 target_pk 参数决定查询范围"""
target_pk = self.kwargs.get('target_pk')
if target_pk:
return self.service.get_subdomains_by_target(target_pk)
return self.service.get_all()
@action(detail=False, methods=['get'], url_path='export')
def export(self, request):
"""导出子域名(纯文本,一行一个)"""
target_pk = self.kwargs.get('target_pk')
if not target_pk:
raise DRFValidationError('必须在目标下导出')
def line_iterator():
for name in self.service.iter_subdomain_names_by_target(target_pk):
yield f"{name}\n"
response = StreamingHttpResponse(
line_iterator(),
content_type='text/plain; charset=utf-8',
)
response['Content-Disposition'] = f'attachment; filename="target-{target_pk}-subdomains.txt"'
return response
class WebSiteViewSet(viewsets.ModelViewSet):
"""站点管理 ViewSet
支持两种访问方式:
1. 嵌套路由GET /api/targets/{target_pk}/websites/
2. 独立路由GET /api/websites/(全局查询)
"""
serializer_class = WebSiteSerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['host']
ordering = ['-discovered_at']
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.service = WebSiteService()
def get_queryset(self):
"""根据是否有 target_pk 参数决定查询范围"""
target_pk = self.kwargs.get('target_pk')
if target_pk:
return self.service.get_websites_by_target(target_pk)
return self.service.get_all()
@action(detail=False, methods=['get'], url_path='export')
def export(self, request):
"""导出站点 URL纯文本一行一个"""
target_pk = self.kwargs.get('target_pk')
if not target_pk:
raise DRFValidationError('必须在目标下导出')
def line_iterator():
for url in self.service.iter_website_urls_by_target(target_pk):
yield f"{url}\n"
response = StreamingHttpResponse(
line_iterator(),
content_type='text/plain; charset=utf-8',
)
response['Content-Disposition'] = f'attachment; filename="target-{target_pk}-websites.txt"'
return response
class DirectoryViewSet(viewsets.ModelViewSet):
"""目录管理 ViewSet
支持两种访问方式:
1. 嵌套路由GET /api/targets/{target_pk}/directories/
2. 独立路由GET /api/directories/(全局查询)
"""
serializer_class = DirectorySerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['url']
ordering = ['-discovered_at']
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.service = DirectoryService()
def get_queryset(self):
"""根据是否有 target_pk 参数决定查询范围"""
target_pk = self.kwargs.get('target_pk')
if target_pk:
return self.service.get_directories_by_target(target_pk)
return self.service.get_all()
@action(detail=False, methods=['get'], url_path='export')
def export(self, request):
"""导出目录 URL纯文本一行一个"""
target_pk = self.kwargs.get('target_pk')
if not target_pk:
raise DRFValidationError('必须在目标下导出')
def line_iterator():
for url in self.service.iter_directory_urls_by_target(target_pk):
yield f"{url}\n"
response = StreamingHttpResponse(
line_iterator(),
content_type='text/plain; charset=utf-8',
)
response['Content-Disposition'] = f'attachment; filename="target-{target_pk}-directories.txt"'
return response
class EndpointViewSet(viewsets.ModelViewSet):
"""端点管理 ViewSet
支持两种访问方式:
1. 嵌套路由GET /api/targets/{target_pk}/endpoints/
2. 独立路由GET /api/endpoints/(全局查询)
"""
serializer_class = EndpointListSerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['host']
ordering = ['-discovered_at']
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.service = EndpointService()
def get_queryset(self):
"""根据是否有 target_pk 参数决定查询范围"""
target_pk = self.kwargs.get('target_pk')
if target_pk:
return self.service.get_queryset_by_target(target_pk)
return self.service.get_all()
@action(detail=False, methods=['get'], url_path='export')
def export(self, request):
"""导出端点 URL纯文本一行一个"""
target_pk = self.kwargs.get('target_pk')
if not target_pk:
raise DRFValidationError('必须在目标下导出')
def line_iterator():
for url in self.service.iter_endpoint_urls_by_target(target_pk):
yield f"{url}\n"
response = StreamingHttpResponse(
line_iterator(),
content_type='text/plain; charset=utf-8',
)
response['Content-Disposition'] = f'attachment; filename="target-{target_pk}-endpoints.txt"'
return response
class HostPortMappingViewSet(viewsets.ModelViewSet):
"""主机端口映射管理 ViewSetIP 地址聚合视图)
支持两种访问方式:
1. 嵌套路由GET /api/targets/{target_pk}/ip-addresses/
2. 独立路由GET /api/ip-addresses/(全局查询)
返回按 IP 聚合的数据,每个 IP 显示其关联的所有 hosts 和 ports
注意:由于返回的是聚合数据(字典列表),不支持 DRF SearchFilter
"""
serializer_class = IPAddressAggregatedSerializer
pagination_class = BasePagination
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.service = HostPortMappingService()
def get_queryset(self):
"""根据是否有 target_pk 参数决定查询范围,返回按 IP 聚合的数据"""
target_pk = self.kwargs.get('target_pk')
search = self.request.query_params.get('search', None)
if target_pk:
return self.service.get_ip_aggregation_by_target(target_pk, search=search)
return self.service.get_all_ip_aggregation(search=search)
@action(detail=False, methods=['get'], url_path='export')
def export(self, request):
"""导出 IP 地址(纯文本,一行一个)"""
target_pk = self.kwargs.get('target_pk')
if not target_pk:
raise DRFValidationError('必须在目标下导出')
def line_iterator():
for ip in self.service.iter_ips_by_target(target_pk):
yield f"{ip}\n"
response = StreamingHttpResponse(
line_iterator(),
content_type='text/plain; charset=utf-8',
)
response['Content-Disposition'] = f'attachment; filename="target-{target_pk}-ip-addresses.txt"'
return response
class VulnerabilityViewSet(viewsets.ModelViewSet):
"""漏洞资产管理 ViewSet只读
支持两种访问方式:
1. 嵌套路由GET /api/targets/{target_pk}/vulnerabilities/
2. 独立路由GET /api/vulnerabilities/(全局查询)
"""
serializer_class = VulnerabilitySerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['vuln_type']
ordering = ['-discovered_at']
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.service = VulnerabilityService()
def get_queryset(self):
"""根据是否有 target_pk 参数决定查询范围"""
target_pk = self.kwargs.get('target_pk')
if target_pk:
return self.service.get_queryset_by_target(target_pk)
return self.service.get_all()
# ==================== Snapshot ViewSets (nested under Scan) ====================
class SubdomainSnapshotViewSet(viewsets.ModelViewSet):
"""子域名快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/subdomains/"""
serializer_class = SubdomainSnapshotSerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['name']
ordering_fields = ['name', 'discovered_at']
ordering = ['-discovered_at']
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.service = SubdomainSnapshotsService()
def get_queryset(self):
scan_pk = self.kwargs.get('scan_pk')
if scan_pk:
return self.service.get_by_scan(scan_pk)
return self.service.get_all()
@action(detail=False, methods=['get'], url_path='export')
def export(self, request):
scan_pk = self.kwargs.get('scan_pk')
if not scan_pk:
raise DRFValidationError('Export is only available under a scan')
def line_iterator():
for name in self.service.iter_subdomain_names_by_scan(scan_pk):
yield f"{name}\n"
response = StreamingHttpResponse(line_iterator(), content_type='text/plain; charset=utf-8')
response['Content-Disposition'] = f'attachment; filename="scan-{scan_pk}-subdomains.txt"'
return response
class WebsiteSnapshotViewSet(viewsets.ModelViewSet):
"""网站快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/websites/"""
serializer_class = WebsiteSnapshotSerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['host']
ordering = ['-discovered_at']
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.service = WebsiteSnapshotsService()
def get_queryset(self):
scan_pk = self.kwargs.get('scan_pk')
if scan_pk:
return self.service.get_by_scan(scan_pk)
return self.service.get_all()
@action(detail=False, methods=['get'], url_path='export')
def export(self, request):
scan_pk = self.kwargs.get('scan_pk')
if not scan_pk:
raise DRFValidationError('Export is only available under a scan')
def line_iterator():
for url in self.service.iter_website_urls_by_scan(scan_pk):
yield f"{url}\n"
response = StreamingHttpResponse(line_iterator(), content_type='text/plain; charset=utf-8')
response['Content-Disposition'] = f'attachment; filename="scan-{scan_pk}-websites.txt"'
return response
class DirectorySnapshotViewSet(viewsets.ModelViewSet):
"""目录快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/directories/"""
serializer_class = DirectorySnapshotSerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['url']
ordering = ['-discovered_at']
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.service = DirectorySnapshotsService()
def get_queryset(self):
scan_pk = self.kwargs.get('scan_pk')
if scan_pk:
return self.service.get_by_scan(scan_pk)
return self.service.get_all()
@action(detail=False, methods=['get'], url_path='export')
def export(self, request):
scan_pk = self.kwargs.get('scan_pk')
if not scan_pk:
raise DRFValidationError('Export is only available under a scan')
def line_iterator():
for url in self.service.iter_directory_urls_by_scan(scan_pk):
yield f"{url}\n"
response = StreamingHttpResponse(line_iterator(), content_type='text/plain; charset=utf-8')
response['Content-Disposition'] = f'attachment; filename="scan-{scan_pk}-directories.txt"'
return response
class EndpointSnapshotViewSet(viewsets.ModelViewSet):
"""端点快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/endpoints/"""
serializer_class = EndpointSnapshotSerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['host']
ordering = ['-discovered_at']
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.service = EndpointSnapshotsService()
def get_queryset(self):
scan_pk = self.kwargs.get('scan_pk')
if scan_pk:
return self.service.get_by_scan(scan_pk)
return self.service.get_all()
@action(detail=False, methods=['get'], url_path='export')
def export(self, request):
scan_pk = self.kwargs.get('scan_pk')
if not scan_pk:
raise DRFValidationError('Export is only available under a scan')
def line_iterator():
for url in self.service.iter_endpoint_urls_by_scan(scan_pk):
yield f"{url}\n"
response = StreamingHttpResponse(line_iterator(), content_type='text/plain; charset=utf-8')
response['Content-Disposition'] = f'attachment; filename="scan-{scan_pk}-endpoints.txt"'
return response
class HostPortMappingSnapshotViewSet(viewsets.ModelViewSet):
"""主机端口映射快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/ip-addresses/
注意:由于返回的是聚合数据(字典列表),不支持 DRF SearchFilter
"""
serializer_class = IPAddressAggregatedSerializer
pagination_class = BasePagination
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.service = HostPortMappingSnapshotsService()
def get_queryset(self):
scan_pk = self.kwargs.get('scan_pk')
search = self.request.query_params.get('search', None)
if scan_pk:
return self.service.get_ip_aggregation_by_scan(scan_pk, search=search)
return self.service.get_all_ip_aggregation(search=search)
@action(detail=False, methods=['get'], url_path='export')
def export(self, request):
scan_pk = self.kwargs.get('scan_pk')
if not scan_pk:
raise DRFValidationError('Export is only available under a scan')
def line_iterator():
for ip in self.service.iter_ips_by_scan(scan_pk):
yield f"{ip}\n"
response = StreamingHttpResponse(line_iterator(), content_type='text/plain; charset=utf-8')
response['Content-Disposition'] = f'attachment; filename="scan-{scan_pk}-ip-addresses.txt"'
return response
class VulnerabilitySnapshotViewSet(viewsets.ModelViewSet):
"""漏洞快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/vulnerabilities/"""
serializer_class = VulnerabilitySnapshotSerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['vuln_type']
ordering = ['-discovered_at']
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.service = VulnerabilitySnapshotsService()
def get_queryset(self):
scan_pk = self.kwargs.get('scan_pk')
if scan_pk:
return self.service.get_by_scan(scan_pk)
return self.service.get_all()
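# Illustrative usage (hedged; host, port and IDs are hypothetical): the export
# actions above return StreamingHttpResponse, so a client can consume the plain
# text line by line without buffering the whole body.
import requests

with requests.get(
    "http://localhost:8888/api/targets/1/subdomains/export/",
    stream=True,
    timeout=30,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        print(line)  # one subdomain per line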

View File

@@ -0,0 +1,23 @@
"""
通用工具模块
提供各种共享的工具类和函数
"""
from .normalizer import normalize_domain, normalize_ip, normalize_cidr, normalize_target
from .validators import validate_domain, validate_ip, validate_cidr, detect_target_type
__all__ = [
# Normalization helpers
'normalize_domain',
'normalize_ip',
'normalize_cidr',
'normalize_target',
# Validators
'validate_domain',
'validate_ip',
'validate_cidr',
'detect_target_type',
]

View File

@@ -0,0 +1,10 @@
from django.apps import AppConfig
class CommonConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'apps.common'  # lives under the apps/ directory
def ready(self):
"""应用就绪时调用"""
pass

View File

@@ -0,0 +1,13 @@
from rest_framework.authentication import SessionAuthentication
class CsrfExemptSessionAuthentication(SessionAuthentication):
"""
前后端分离项目使用的 Session 认证
禁用 CSRF 检查,因为 CSRF 主要防护的是同源页面表单提交
前后端分离项目通过 CORS 控制跨域访问,不需要 CSRF
"""
def enforce_csrf(self, request):
# Skip the CSRF check
return
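# Illustrative wiring (an assumption; the module path is hypothetical): the
# class above only takes effect once it replaces the default session
# authentication in the DRF settings.
REST_FRAMEWORK = {
    "DEFAULT_AUTHENTICATION_CLASSES": [
        "apps.common.authentication.CsrfExemptSessionAuthentication",
    ],
}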

View File

@@ -0,0 +1,66 @@
"""
容器启动引导模块
提供动态任务容器的通用初始化功能:
- 从 Server 配置中心获取配置
- 设置环境变量
- 初始化 Django 环境
使用方式:
from apps.common.container_bootstrap import fetch_config_and_setup_django
fetch_config_and_setup_django() # 必须在 Django 导入之前调用
"""
import os
import sys
import requests
import logging
logger = logging.getLogger(__name__)
def fetch_config_and_setup_django():
"""
从配置中心获取配置并初始化 Django
Note:
必须在 Django 导入之前调用此函数
"""
server_url = os.environ.get("SERVER_URL")
if not server_url:
print("[ERROR] 缺少 SERVER_URL 环境变量", file=sys.stderr)
sys.exit(1)
config_url = f"{server_url}/api/workers/config/"
try:
resp = requests.get(config_url, timeout=10)
resp.raise_for_status()
config = resp.json()
# Database configuration (required)
os.environ.setdefault("DB_HOST", config['db']['host'])
os.environ.setdefault("DB_PORT", config['db']['port'])
os.environ.setdefault("DB_NAME", config['db']['name'])
os.environ.setdefault("DB_USER", config['db']['user'])
os.environ.setdefault("DB_PASSWORD", config['db']['password'])
# Redis configuration
os.environ.setdefault("REDIS_URL", config['redisUrl'])
# Logging configuration
os.environ.setdefault("LOG_DIR", config['paths']['logs'])
os.environ.setdefault("LOG_LEVEL", config['logging']['level'])
os.environ.setdefault("ENABLE_COMMAND_LOGGING", str(config['logging']['enableCommandLogging']).lower())
os.environ.setdefault("DEBUG", str(config['debug']))
print(f"[CONFIG] 从配置中心获取配置成功: {config_url}")
except Exception as e:
print(f"[ERROR] 获取配置失败: {config_url} - {e}", file=sys.stderr)
sys.exit(1)
# Initialize Django
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')
import django
django.setup()
return config

View File

@@ -0,0 +1,19 @@
"""
通用装饰器模块
提供可在整个项目中复用的装饰器
"""
from .db_connection import (
ensure_db_connection,
auto_ensure_db_connection,
async_check_and_reconnect,
ensure_db_connection_async,
)
__all__ = [
'ensure_db_connection',
'auto_ensure_db_connection',
'async_check_and_reconnect',
'ensure_db_connection_async',
]

View File

@@ -0,0 +1,169 @@
"""
数据库连接装饰器
提供自动数据库连接健康检查的装饰器,确保长时间运行的任务中数据库连接不会失效。
主要功能:
- @auto_ensure_db_connection: 类装饰器,自动为所有公共方法添加连接检查
- @ensure_db_connection: 方法装饰器,单独为某个方法添加连接检查
使用场景:
- Repository 层的数据库操作
- Service 层需要确保数据库连接的操作
- 任何需要数据库连接健康检查的类或方法
"""
import logging
import functools
import time
from django.db import connection
from asgiref.sync import sync_to_async
logger = logging.getLogger(__name__)
def ensure_db_connection(method):
"""
方法装饰器:自动确保数据库连接健康
在方法执行前自动检查数据库连接,如果连接失效则自动重连。
使用场景:
- 需要单独装饰某个方法时使用
- 通常建议使用类装饰器 @auto_ensure_db_connection
示例:
@ensure_db_connection
def my_method(self):
# 会自动检查连接健康
...
"""
@functools.wraps(method)
def wrapper(self, *args, **kwargs):
_check_and_reconnect()
return method(self, *args, **kwargs)
return wrapper
def auto_ensure_db_connection(cls):
"""
类装饰器:自动给所有公共方法添加数据库连接检查
自动为类中所有公共方法(不以 _ 开头的方法)添加 @ensure_db_connection 装饰器。
特性:
- 自动装饰所有公共方法
- 跳过私有方法(以 _ 开头)
- 跳过类方法和静态方法
- 跳过已经装饰过的方法
使用方式:
@auto_ensure_db_connection
class MyRepository:
def bulk_create(self, items):
# 自动添加连接检查
...
def query(self, filters):
# 自动添加连接检查
...
def _private_method(self):
# 不会添加装饰器
...
优势:
- 无需为每个方法手动添加装饰器
- 减少代码重复
- 降低遗漏风险
"""
for attr_name in dir(cls):
# Skip private and dunder methods
if attr_name.startswith('_'):
continue
attr = getattr(cls, attr_name)
# Only decorate callable instance methods
if callable(attr) and not isinstance(attr, (staticmethod, classmethod)):
# Skip methods that are already wrapped (avoid double decoration)
if not hasattr(attr, '_db_connection_ensured'):
wrapped = ensure_db_connection(attr)
wrapped._db_connection_ensured = True
setattr(cls, attr_name, wrapped)
return cls
def _check_and_reconnect(max_retries=5):
"""
检查数据库连接健康状态,必要时使用指数退避重新连接
策略:
1. 尝试执行简单查询测试连接
2. 如果失败使用指数退避策略重试最多5次
3. 每次重试的等待时间2^attempt 秒 (1s, 2s, 4s, 8s, 16s)
异常处理:
- 连接失效时自动重连
- 记录警告日志和重试信息
- 忽略关闭连接时的错误
- 达到最大重试次数后抛出异常
"""
last_error = None
for attempt in range(max_retries):
try:
connection.ensure_connection()
with connection.cursor() as cursor:
cursor.execute("SELECT 1")
cursor.fetchone()
# Connection is healthy
if attempt > 0:
logger.info(f"Database reconnected (attempt {attempt + 1}/{max_retries})")
return
except Exception as e:
last_error = e
logger.warning(
f"Database connection check failed (attempt {attempt + 1}/{max_retries}): {e}"
)
# Close the stale connection
try:
connection.close()
except Exception:
pass  # ignore errors raised while closing
# If retries remain, back off exponentially
if attempt < max_retries - 1:
delay = 2 ** attempt  # exponential backoff: 1, 2, 4, 8, 16 seconds
logger.info(f"Retrying in {delay} seconds...")
time.sleep(delay)
else:
# The final attempt also failed; re-raise
logger.error(
f"Database reconnection failed after {max_retries} attempts"
)
raise last_error
async def async_check_and_reconnect(max_retries=5):
await sync_to_async(_check_and_reconnect, thread_sensitive=True)(max_retries=max_retries)
def ensure_db_connection_async(method):
@functools.wraps(method)
async def wrapper(*args, **kwargs):
await async_check_and_reconnect()
return await method(*args, **kwargs)
return wrapper
__all__ = [
'ensure_db_connection',
'auto_ensure_db_connection',
'async_check_and_reconnect',
'ensure_db_connection_async',
]

View File

@@ -0,0 +1,20 @@
from django.db import models
class ScanStatus(models.TextChoices):
"""扫描任务状态枚举"""
CANCELLED = 'cancelled', '已取消'
COMPLETED = 'completed', '已完成'
FAILED = 'failed', '失败'
INITIATED = 'initiated', '初始化'
RUNNING = 'running', '运行中'
class VulnSeverity(models.TextChoices):
"""漏洞严重性枚举"""
UNKNOWN = 'unknown', '未知'
INFO = 'info', '信息'
LOW = 'low', ''
MEDIUM = 'medium', ''
HIGH = 'high', ''
CRITICAL = 'critical', '危急'
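# Illustrative usage (not from this file): TextChoices members compare equal to
# their stored value and expose human-readable labels for serializers and admin.
assert ScanStatus.RUNNING == 'running'
assert ScanStatus.RUNNING.label == 'Running'
severity_choices = VulnSeverity.choices  # [('unknown', 'Unknown'), ..., ('critical', 'Critical')]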

View File

@@ -0,0 +1,101 @@
"""通用文件 hash 计算与校验工具
提供 SHA-256 哈希计算和校验功能,用于:
- 字典文件上传时计算 hash
- Worker 侧本地缓存校验
"""
import hashlib
import logging
from pathlib import Path
from typing import Optional, BinaryIO
logger = logging.getLogger(__name__)
# Default chunk size: 64KB (balances memory use and throughput)
DEFAULT_CHUNK_SIZE = 65536
def calc_file_sha256(file_path: str, chunk_size: int = DEFAULT_CHUNK_SIZE) -> str:
"""计算文件的 SHA-256 哈希值
Args:
file_path: 文件绝对路径
chunk_size: 分块读取大小(字节),默认 64KB
Returns:
str: SHA-256 十六进制字符串64 字符)
Raises:
FileNotFoundError: 文件不存在
OSError: 文件读取失败
"""
hasher = hashlib.sha256()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(chunk_size), b""):
hasher.update(chunk)
return hasher.hexdigest()
def calc_stream_sha256(stream: BinaryIO, chunk_size: int = DEFAULT_CHUNK_SIZE) -> str:
"""从二进制流计算 SHA-256用于边写边算
Args:
stream: 可读取的二进制流(如 UploadedFile.chunks()
chunk_size: 分块大小
Returns:
str: SHA-256 十六进制字符串
"""
hasher = hashlib.sha256()
for chunk in iter(lambda: stream.read(chunk_size), b""):
hasher.update(chunk)
return hasher.hexdigest()
def safe_calc_file_sha256(file_path: str) -> Optional[str]:
"""安全计算文件 SHA-256异常时返回 None
Args:
file_path: 文件绝对路径
Returns:
str | None: SHA-256 十六进制字符串,或 None文件不存在/读取失败)
"""
try:
return calc_file_sha256(file_path)
except FileNotFoundError:
logger.warning("计算 hash 失败:文件不存在 - %s", file_path)
return None
except OSError as exc:
logger.warning("计算 hash 失败:读取错误 - %s: %s", file_path, exc)
return None
def is_file_hash_match(file_path: str, expected_hash: str) -> bool:
"""校验文件 hash 是否与期望值匹配
Args:
file_path: 文件绝对路径
expected_hash: 期望的 SHA-256 十六进制字符串
Returns:
bool: True 表示匹配False 表示不匹配或计算失败
"""
if not expected_hash:
# 期望值为空,视为"无法校验",返回 False 让调用方决定是否重新下载
return False
actual_hash = safe_calc_file_sha256(file_path)
if actual_hash is None:
return False
return actual_hash.lower() == expected_hash.lower()
__all__ = [
"calc_file_sha256",
"calc_stream_sha256",
"safe_calc_file_sha256",
"is_file_hash_match",
]
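# Illustrative usage (hedged; the paths are hypothetical): hash a file once on
# upload, then verify a cached copy before trusting it.
digest = calc_file_sha256("/data/wordlists/common.txt")
if not is_file_hash_match("/cache/wordlists/common.txt", digest):
    print("cache stale or corrupt - re-download the wordlist")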

View File

@@ -0,0 +1,429 @@
"""
数据库健康检查管理命令
使用方法:
python manage.py db_health_check # 基础延迟测试5次
python manage.py db_health_check --test-count=10 # 指定测试次数
python manage.py db_health_check --reconnect # 强制重连后测试
python manage.py db_health_check --stats # 显示连接统计信息
python manage.py db_health_check --api-test # 测试实际API查询性能
python manage.py db_health_check --monitor # 监控数据库服务器性能指标
python manage.py db_health_check --db=other # 指定数据库别名
python manage.py db_health_check --api-test --test-count=10 # API性能测试10次
python manage.py db_health_check --reconnect --api-test # 重连后进行API测试
示例:
# 快速延迟检查
python manage.py db_health_check --test-count=3
# 完整性能分析
python manage.py db_health_check --api-test --stats --test-count=5
# 数据库服务器监控
python manage.py db_health_check --monitor
"""
import time
import logging
from django.core.management.base import BaseCommand
from django.db import connection, connections
from django.conf import settings
logger = logging.getLogger(__name__)
class Command(BaseCommand):
"""Django管理命令数据库健康检查"""
help = '检查数据库连接健康状态'
def add_arguments(self, parser):
parser.add_argument(
'--reconnect',
action='store_true',
help='Force a database reconnect',
)
parser.add_argument(
'--stats',
action='store_true',
help='Show connection statistics',
)
parser.add_argument(
'--db',
type=str,
default='default',
help='Database alias (default: default)',
)
parser.add_argument(
'--test-count',
type=int,
default=5,
help='Number of latency test runs (default: 5)',
)
parser.add_argument(
'--api-test',
action='store_true',
help='Measure real API query performance',
)
parser.add_argument(
'--monitor',
action='store_true',
help='Monitor database server metrics',
)
def handle(self, *args, **options):
db_alias = options['db']
test_count = options['test_count']
self.stdout.write(f"正在测试数据库 '{db_alias}' 连接...")
# 获取数据库连接
db_connection = connections[db_alias]
if options['reconnect']:
self.stdout.write("强制重新连接数据库...")
try:
db_connection.close()
db_connection.ensure_connection()
self.stdout.write(self.style.SUCCESS("✓ 重连成功"))
except Exception as e:
self.stdout.write(self.style.ERROR(f"✗ 重连失败: {e}"))
return
# Run the latency test
if options['monitor']:
self.monitor_database_performance(db_connection)
elif options['api_test']:
self.test_api_performance(test_count)
else:
self.test_database_latency(db_connection, test_count)
if options['stats']:
self.show_connection_stats(db_connection)
def test_database_latency(self, db_connection, test_count):
"""测试数据库延迟"""
self.stdout.write(f"\n开始延迟测试({test_count} 次)...")
latencies = []
successful_tests = 0
connection_times = []
query_times = []
for i in range(test_count):
try:
# Measure connection establishment time
conn_start = time.time()
db_connection.ensure_connection()
conn_end = time.time()
conn_time = (conn_end - conn_start) * 1000
connection_times.append(conn_time)
# Measure query execution time
query_start = time.time()
with db_connection.cursor() as cursor:
cursor.execute("SELECT 1")
result = cursor.fetchone()
query_end = time.time()
query_time = (query_end - query_start) * 1000
query_times.append(query_time)
total_time = conn_time + query_time
latencies.append(total_time)
successful_tests += 1
self.stdout.write(f" 测试 {i+1}: 总计{total_time:.2f}ms (连接:{conn_time:.2f}ms + 查询:{query_time:.2f}ms) ✓")
except Exception as e:
self.stdout.write(f" 测试 {i+1}: 失败 - {e}")
# 计算统计信息
if latencies:
avg_latency = sum(latencies) / len(latencies)
min_latency = min(latencies)
max_latency = max(latencies)
avg_conn_time = sum(connection_times) / len(connection_times)
avg_query_time = sum(query_times) / len(query_times)
self.stdout.write(f"\n延迟统计:")
self.stdout.write(f" 成功测试: {successful_tests}/{test_count}")
self.stdout.write(f" 平均总延迟: {avg_latency:.2f}ms")
self.stdout.write(f" 平均连接时间: {avg_conn_time:.2f}ms")
self.stdout.write(f" 平均查询时间: {avg_query_time:.2f}ms")
self.stdout.write(f" 最小延迟: {min_latency:.2f}ms")
self.stdout.write(f" 最大延迟: {max_latency:.2f}ms")
# 分析延迟来源
if avg_conn_time > avg_query_time * 2:
self.stdout.write(self.style.WARNING(" 分析: 连接建立是主要延迟来源"))
elif avg_query_time > avg_conn_time * 2:
self.stdout.write(self.style.WARNING(" 分析: 查询执行是主要延迟来源"))
else:
self.stdout.write(" 分析: 连接和查询延迟相当")
# Rate the latency
if avg_latency < 10:
self.stdout.write(self.style.SUCCESS(" rating: very low latency, excellent connection"))
elif avg_latency < 50:
self.stdout.write(self.style.SUCCESS(" rating: low latency, good connection"))
elif avg_latency < 200:
self.stdout.write(self.style.WARNING(" rating: moderate latency, acceptable connection"))
else:
self.stdout.write(self.style.ERROR(" rating: high latency, may hurt performance"))
else:
self.stdout.write(self.style.ERROR("all runs failed"))
def test_api_performance(self, test_count):
"""测试实际API查询性能"""
self.stdout.write(f"\n开始API性能测试{test_count} 次)...")
# 导入必要的模块
from apps.scan.models import Scan
from apps.engine.models import ScanEngine
from apps.targets.models import Target
from django.db.models import Count
api_latencies = []
successful_tests = 0
for i in range(test_count):
try:
start_time = time.time()
# Exercise several query shapes
# 1. simple query - engine list
engines = list(ScanEngine.objects.all()[:10])
# 2. complex query - scan list (runs the expensive JOINs even with no data)
scan_queryset = Scan.objects.select_related(
'target', 'engine'
).annotate(
subdomains_count=Count('subdomains', distinct=True),
endpoints_count=Count('endpoints', distinct=True),
ips_count=Count('ip_addresses', distinct=True)
).order_by('-id')[:10]
scan_list = list(scan_queryset)
# 3. target query
targets = list(Target.objects.all()[:10])
end_time = time.time()
latency_ms = (end_time - start_time) * 1000
api_latencies.append(latency_ms)
successful_tests += 1
self.stdout.write(f" API测试 {i+1}: {latency_ms:.2f}ms ✓ (引擎:{len(engines)}, 扫描:{len(scan_list)}, 目标:{len(targets)})")
except Exception as e:
self.stdout.write(f" API测试 {i+1}: 失败 - {e}")
# 计算API查询统计信息
if api_latencies:
avg_latency = sum(api_latencies) / len(api_latencies)
min_latency = min(api_latencies)
max_latency = max(api_latencies)
self.stdout.write(f"\nAPI查询统计:")
self.stdout.write(f" 成功测试: {successful_tests}/{test_count}")
self.stdout.write(f" 平均延迟: {avg_latency:.2f}ms")
self.stdout.write(f" 最小延迟: {min_latency:.2f}ms")
self.stdout.write(f" 最大延迟: {max_latency:.2f}ms")
# 与简单查询对比
simple_query_avg = 150 # 基于之前的测试结果
overhead = avg_latency - simple_query_avg
self.stdout.write(f" 业务逻辑开销: {overhead:.2f}ms")
# Rate the performance
if avg_latency < 500:
self.stdout.write(self.style.SUCCESS(" rating: good API response time"))
elif avg_latency < 1000:
self.stdout.write(self.style.WARNING(" rating: fair API response time"))
else:
self.stdout.write(self.style.ERROR(" rating: slow API response time, needs optimization"))
else:
self.stdout.write(self.style.ERROR("all API runs failed"))
def monitor_database_performance(self, db_connection):
"""监控数据库服务器性能指标"""
self.stdout.write(f"\n开始监控数据库性能指标...")
try:
with db_connection.cursor() as cursor:
# 1. basic database info
self.stdout.write(f"\n=== Database info ===")
cursor.execute("SELECT version();")
version = cursor.fetchone()[0]
self.stdout.write(f"PostgreSQL版本: {version}")
cursor.execute("SELECT current_database();")
db_name = cursor.fetchone()[0]
self.stdout.write(f"当前数据库: {db_name}")
# 2. 连接信息
self.stdout.write(f"\n=== 连接状态 ===")
cursor.execute("""
SELECT count(*) as total_connections,
count(*) FILTER (WHERE state = 'active') as active_connections,
count(*) FILTER (WHERE state = 'idle') as idle_connections
FROM pg_stat_activity;
""")
conn_stats = cursor.fetchone()
self.stdout.write(f"总连接数: {conn_stats[0]}")
self.stdout.write(f"活跃连接: {conn_stats[1]}")
self.stdout.write(f"空闲连接: {conn_stats[2]}")
# 3. 数据库大小
self.stdout.write(f"\n=== 数据库大小 ===")
cursor.execute("""
SELECT pg_size_pretty(pg_database_size(current_database())) as db_size;
""")
db_size = cursor.fetchone()[0]
self.stdout.write(f"数据库大小: {db_size}")
# 4. 表统计信息
self.stdout.write(f"\n=== 主要表统计 ===")
cursor.execute("""
SELECT schemaname, relname,
n_tup_ins as inserts,
n_tup_upd as updates,
n_tup_del as deletes,
n_live_tup as live_rows,
n_dead_tup as dead_rows
FROM pg_stat_user_tables
WHERE schemaname = 'public'
ORDER BY n_live_tup DESC
LIMIT 10;
""")
tables = cursor.fetchall()
if tables:
for table in tables:
self.stdout.write(f" {table[1]}: {table[5]} 行 (插入:{table[2]}, 更新:{table[3]}, 删除:{table[4]}, 死行:{table[6]})")
else:
self.stdout.write(" 暂无表统计数据")
# 5. 慢查询统计
self.stdout.write(f"\n=== 查询性能统计 ===")
cursor.execute("""
SELECT query,
calls,
total_time,
mean_time,
rows
FROM pg_stat_statements
WHERE query NOT LIKE '%pg_stat_statements%'
ORDER BY mean_time DESC
LIMIT 5;
""")
try:
slow_queries = cursor.fetchall()
if slow_queries:
for i, query in enumerate(slow_queries, 1):
self.stdout.write(f" {i}. 平均耗时: {query[3]:.2f}ms, 调用次数: {query[1]}")
self.stdout.write(f" 查询: {query[0][:100]}...")
else:
self.stdout.write(" 未找到查询统计可能未启用pg_stat_statements扩展")
except Exception:
self.stdout.write(" 查询统计不可用需要pg_stat_statements扩展")
# 6. lock info
self.stdout.write(f"\n=== Locks ===")
cursor.execute("""
SELECT mode, count(*)
FROM pg_locks
GROUP BY mode
ORDER BY count(*) DESC;
""")
locks = cursor.fetchall()
total_locks = sum(lock[1] for lock in locks)
self.stdout.write(f"总锁数量: {total_locks}")
for lock in locks:
self.stdout.write(f" {lock[0]}: {lock[1]}")
# 7. 缓存命中率
self.stdout.write(f"\n=== 缓存性能 ===")
cursor.execute("""
SELECT
sum(heap_blks_read) as heap_read,
sum(heap_blks_hit) as heap_hit,
sum(heap_blks_hit) / (sum(heap_blks_hit) + sum(heap_blks_read)) * 100 as cache_hit_ratio
FROM pg_statio_user_tables;
""")
cache_stats = cursor.fetchone()
if cache_stats[0] and cache_stats[1]:
self.stdout.write(f"缓存命中率: {cache_stats[2]:.2f}%")
self.stdout.write(f"磁盘读取: {cache_stats[0]}")
self.stdout.write(f"缓存命中: {cache_stats[1]}")
else:
self.stdout.write("缓存统计: 暂无数据")
# 8. 检查点和WAL
self.stdout.write(f"\n=== WAL和检查点 ===")
cursor.execute("""
SELECT
checkpoints_timed,
checkpoints_req,
checkpoint_write_time,
checkpoint_sync_time
FROM pg_stat_bgwriter;
""")
bgwriter = cursor.fetchone()
self.stdout.write(f"定时检查点: {bgwriter[0]}")
self.stdout.write(f"请求检查点: {bgwriter[1]}")
self.stdout.write(f"检查点写入时间: {bgwriter[2]}ms")
self.stdout.write(f"检查点同步时间: {bgwriter[3]}ms")
# 9. 当前活跃查询
self.stdout.write(f"\n=== 当前活跃查询 ===")
cursor.execute("""
SELECT pid,
application_name,
state,
query_start,
now() - query_start as duration,
left(query, 100) as query_preview
FROM pg_stat_activity
WHERE state = 'active'
AND query NOT LIKE '%pg_stat_activity%'
ORDER BY query_start;
""")
active_queries = cursor.fetchall()
if active_queries:
for query in active_queries:
self.stdout.write(f" PID {query[0]} ({query[1]}): running for {query[4]}")
self.stdout.write(f" query: {query[5]}...")
else:
self.stdout.write(" no active queries")
except Exception as e:
self.stdout.write(self.style.ERROR(f"monitoring failed: {e}"))
def show_connection_stats(self, db_connection):
"""显示连接统计信息"""
self.stdout.write(f"\n连接信息:")
# 基本连接信息
settings_dict = db_connection.settings_dict
self.stdout.write(f" 数据库类型: {db_connection.vendor}")
self.stdout.write(f" 主机: {settings_dict.get('HOST', 'localhost')}")
self.stdout.write(f" 端口: {settings_dict.get('PORT', '5432')}")
self.stdout.write(f" 数据库名: {settings_dict.get('NAME', '')}")
self.stdout.write(f" 用户: {settings_dict.get('USER', '')}")
# 连接配置
conn_max_age = settings_dict.get('CONN_MAX_AGE', 0)
self.stdout.write(f" 连接最大存活时间: {conn_max_age}")
# Query statistics
if hasattr(db_connection, 'queries'):
query_count = len(db_connection.queries)
if query_count > 0:
total_time = sum(float(q['time']) for q in db_connection.queries)
self.stdout.write(f" queries executed: {query_count}")
self.stdout.write(f" total query time: {total_time:.4f} s")
# Connection state
is_connected = hasattr(db_connection, 'connection') and db_connection.connection is not None
self.stdout.write(f" connection state: {'connected' if is_connected else 'not connected'}")

View File

@@ -0,0 +1,164 @@
"""
简化的数据库性能监控命令
专注于可能导致查询延迟的关键指标
"""
import time
from django.core.management.base import BaseCommand
from django.db import connections
class Command(BaseCommand):
"""简化的数据库性能监控"""
help = '监控数据库性能关键指标'
def add_arguments(self, parser):
parser.add_argument(
'--interval',
type=int,
default=5,
help='Monitoring interval in seconds (default: 5)',
)
parser.add_argument(
'--count',
type=int,
default=3,
help='Number of monitoring rounds (default: 3)',
)
def handle(self, *args, **options):
interval = options['interval']
count = options['count']
self.stdout.write("🔍 数据库性能监控开始...")
for i in range(count):
if i > 0:
time.sleep(interval)
self.stdout.write(f"\n=== 第 {i+1} 次监控 ===")
self.monitor_key_metrics()
def monitor_key_metrics(self):
"""监控关键性能指标"""
db_connection = connections['default']
try:
with db_connection.cursor() as cursor:
# 1. connections and activity
cursor.execute("""
SELECT
count(*) as total_connections,
count(*) FILTER (WHERE state = 'active') as active,
count(*) FILTER (WHERE state = 'idle') as idle,
count(*) FILTER (WHERE state = 'idle in transaction') as idle_in_trans,
count(*) FILTER (WHERE wait_event_type IS NOT NULL) as waiting
FROM pg_stat_activity;
""")
conn_stats = cursor.fetchone()
self.stdout.write(f"连接: 总计{conn_stats[0]} | 活跃{conn_stats[1]} | 空闲{conn_stats[2]} | 事务中{conn_stats[3]} | 等待{conn_stats[4]}")
# 2. 锁等待情况
cursor.execute("""
SELECT
count(*) as total_locks,
count(*) FILTER (WHERE NOT granted) as waiting_locks
FROM pg_locks;
""")
lock_stats = cursor.fetchone()
if lock_stats[1] > 0:
self.stdout.write(self.style.WARNING(f"🔒 Locks: total {lock_stats[0]} | waiting {lock_stats[1]}"))
else:
self.stdout.write(f"🔒 Locks: total {lock_stats[0]} | waiting {lock_stats[1]}")
# 3. long-running queries
cursor.execute("""
SELECT
pid,
application_name,
now() - query_start as duration,
state,
left(query, 60) as query_preview
FROM pg_stat_activity
WHERE state = 'active'
AND query_start < now() - interval '1 second'
AND query NOT LIKE '%pg_stat_activity%'
ORDER BY query_start;
""")
long_queries = cursor.fetchall()
if long_queries:
self.stdout.write(self.style.WARNING(f"⏱️ Long queries ({len(long_queries)}):"))
for query in long_queries:
self.stdout.write(f" PID {query[0]} ({query[1]}): {query[2]} - {query[4]}...")
else:
self.stdout.write("⏱️ Long queries: none")
# 4. cache hit ratio
cursor.execute("""
SELECT
sum(heap_blks_hit) as cache_hits,
sum(heap_blks_read) as disk_reads,
CASE
WHEN sum(heap_blks_hit) + sum(heap_blks_read) = 0 THEN 0
ELSE round(sum(heap_blks_hit) * 100.0 / (sum(heap_blks_hit) + sum(heap_blks_read)), 2)
END as hit_ratio
FROM pg_statio_user_tables;
""")
cache_stats = cursor.fetchone()
if cache_stats[0] or cache_stats[1]:
hit_ratio = cache_stats[2] or 0
if hit_ratio < 95:
self.stdout.write(self.style.WARNING(f"💾 Cache hit ratio: {hit_ratio}% (cache: {cache_stats[0]}, disk: {cache_stats[1]})"))
else:
self.stdout.write(f"💾 Cache hit ratio: {hit_ratio}% (cache: {cache_stats[0]}, disk: {cache_stats[1]})")
else:
self.stdout.write("💾 Cache: no statistics yet")
# 5. checkpoint activity (probe first; skip if unavailable)
try:
cursor.execute("SELECT * FROM pg_stat_bgwriter LIMIT 1;")
bgwriter_cols = [desc[0] for desc in cursor.description]
if 'checkpoints_timed' in bgwriter_cols:
cursor.execute("""
SELECT
checkpoints_timed,
checkpoints_req,
checkpoint_write_time,
checkpoint_sync_time
FROM pg_stat_bgwriter;
""")
bgwriter = cursor.fetchone()
total_checkpoints = bgwriter[0] + bgwriter[1]
if bgwriter[2] > 10000 or bgwriter[3] > 5000:
self.stdout.write(self.style.WARNING(f"📝 Checkpoints: total {total_checkpoints} | write {bgwriter[2]}ms | sync {bgwriter[3]}ms"))
else:
self.stdout.write(f"📝 Checkpoints: total {total_checkpoints} | write {bgwriter[2]}ms | sync {bgwriter[3]}ms")
else:
self.stdout.write("📝 Checkpoints: statistics unavailable")
except Exception:
self.stdout.write("📝 Checkpoints: statistics unavailable")
# 6. database size
cursor.execute("SELECT pg_database_size(current_database());")
db_size = cursor.fetchone()[0]
db_size_mb = round(db_size / 1024 / 1024, 2)
self.stdout.write(f"💿 数据库大小: {db_size_mb} MB")
# 7. 测试查询延迟
start_time = time.time()
cursor.execute("SELECT 1")
cursor.fetchone()
query_latency = (time.time() - start_time) * 1000
if query_latency > 500:
self.stdout.write(self.style.ERROR(f"⚡ Query latency: {query_latency:.2f}ms (high)"))
elif query_latency > 200:
self.stdout.write(self.style.WARNING(f"⚡ Query latency: {query_latency:.2f}ms (moderate)"))
else:
self.stdout.write(f"⚡ Query latency: {query_latency:.2f}ms (normal)")
except Exception as e:
self.stdout.write(self.style.ERROR(f"monitoring failed: {e}"))

View File

@@ -0,0 +1,64 @@
"""
初始化 admin 用户
用法: python manage.py init_admin [--password <password>]
"""
import os
from django.core.management.base import BaseCommand
from django.contrib.auth import get_user_model
User = get_user_model()
DEFAULT_USERNAME = 'admin'
DEFAULT_PASSWORD = 'admin'  # default password; change it after first login
class Command(BaseCommand):
help = 'Initialize the admin user'
def add_arguments(self, parser):
parser.add_argument(
'--password',
type=str,
default=os.getenv('ADMIN_PASSWORD', DEFAULT_PASSWORD),
help='admin user password (default: admin, or the ADMIN_PASSWORD environment variable)'
)
parser.add_argument(
'--force',
action='store_true',
help='Force a password reset if the user already exists'
)
def handle(self, *args, **options):
password = options['password']
force = options['force']
try:
user = User.objects.get(username=DEFAULT_USERNAME)
if force:
user.set_password(password)
user.save()
self.stdout.write(
self.style.SUCCESS('✓ admin password reset')
)
else:
self.stdout.write(
self.style.WARNING('⚠ admin user already exists, skipping creation (use --force to reset the password)')
)
except User.DoesNotExist:
User.objects.create_superuser(
username=DEFAULT_USERNAME,
email='admin@localhost',
password=password
)
self.stdout.write(
self.style.SUCCESS('✓ admin user created')
)
self.stdout.write(
self.style.WARNING(f' username: {DEFAULT_USERNAME}')
)
self.stdout.write(
self.style.WARNING(f' password: {password}')
)
self.stdout.write(
self.style.WARNING(' ⚠ change the password after first login!')
)

View File

@@ -0,0 +1,106 @@
import re
# Precompile the regex so it is not rebuilt on every call
IP_PATTERN = re.compile(r'^[\d.:]+$')
def normalize_domain(domain: str) -> str:
"""
规范化域名
- 去除首尾空格
- 转换为小写
- 移除末尾的点
Args:
domain: 原始域名
Returns:
规范化后的域名
Raises:
ValueError: 域名为空或只包含空格
"""
if not domain or not domain.strip():
raise ValueError("Domain must not be empty")
normalized = domain.strip().lower()
# Drop trailing dots
if normalized.endswith('.'):
normalized = normalized.rstrip('.')
return normalized
def normalize_ip(ip: str) -> str:
"""
规范化 IP 地址
- 去除首尾空格
- IP 地址不转小写(保持原样)
Args:
ip: 原始 IP 地址
Returns:
规范化后的 IP 地址
Raises:
ValueError: IP 地址为空或只包含空格
"""
if not ip or not ip.strip():
raise ValueError("IP 地址不能为空")
return ip.strip()
def normalize_cidr(cidr: str) -> str:
"""
规范化 CIDR
- 去除首尾空格
- CIDR 不转小写(保持原样)
Args:
cidr: 原始 CIDR
Returns:
规范化后的 CIDR
Raises:
ValueError: CIDR 为空或只包含空格
"""
if not cidr or not cidr.strip():
raise ValueError("CIDR 不能为空")
return cidr.strip()
def normalize_target(target: str) -> str:
"""
规范化目标名称(统一入口)
根据目标格式自动选择合适的规范化函数
Args:
target: 原始目标名称
Returns:
规范化后的目标名称
Raises:
ValueError: 目标为空或只包含空格
"""
if not target or not target.strip():
raise ValueError("目标名称不能为空")
# 先去除首尾空格
trimmed = target.strip()
# 如果包含 /,按 CIDR 处理
if '/' in trimmed:
return normalize_cidr(trimmed)
# 如果是纯数字、点、冒号组成,按 IP 处理
if IP_PATTERN.match(trimmed):
return normalize_ip(trimmed)
# 否则按域名处理
return normalize_domain(trimmed)
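# Illustrative usage (derived from the branches above): normalize_target
# dispatches on shape before delegating.
assert normalize_target(" Example.COM. ") == "example.com"  # domain: lowercased, trailing dot dropped
assert normalize_target(" 10.0.0.0/24 ") == "10.0.0.0/24"   # CIDR: whitespace stripped only
assert normalize_target(" 192.168.1.1 ") == "192.168.1.1"   # IP: whitespace stripped only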

View File

@@ -0,0 +1,34 @@
"""
自定义分页器,匹配前端响应格式
"""
from rest_framework.pagination import PageNumberPagination
from rest_framework.response import Response
class BasePagination(PageNumberPagination):
"""
基础分页器,统一返回格式
响应格式:
{
"results": [...],
"total": 100,
"page": 1,
"pageSize": 10,
"totalPages": 10
}
"""
page_size = 10  # default page size
page_size_query_param = 'pageSize'  # let clients set the page size
max_page_size = 1000  # upper bound on page size
def get_paginated_response(self, data):
"""自定义响应格式"""
return Response({
'results': data, # 数据列表
'total': self.page.paginator.count, # 总记录数
'page': self.page.number, # 当前页码(从 1 开始)
'page_size': self.page.paginator.per_page, # 实际使用的每页大小
'total_pages': self.page.paginator.num_pages # 总页数
})
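# Illustrative usage (the ViewSet is hypothetical): opting in via
# pagination_class, a request like GET /api/subdomains/?page=2&pageSize=50
# yields {"results": [...], "total": ..., "page": 2, "page_size": 50, "total_pages": ...}.
from rest_framework import viewsets

class ExampleViewSet(viewsets.ModelViewSet):
    pagination_class = BasePagination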

View File

@@ -0,0 +1,42 @@
"""
Prefect Flow Django 环境初始化模块
在所有 Prefect Flow 文件开头导入此模块即可自动配置 Django 环境
"""
import os
import sys
def setup_django_for_prefect():
"""
为 Prefect Flow 配置 Django 环境
此函数会:
1. 添加项目根目录到 Python 路径
2. 设置 DJANGO_SETTINGS_MODULE 环境变量
3. 调用 django.setup() 初始化 Django
使用方式:
from apps.common.prefect_django_setup import setup_django_for_prefect
setup_django_for_prefect()
"""
# Resolve the project root (the backend directory)
current_dir = os.path.dirname(os.path.abspath(__file__))
backend_dir = os.path.join(current_dir, '../..')
backend_dir = os.path.abspath(backend_dir)
# Add it to the Python path
if backend_dir not in sys.path:
sys.path.insert(0, backend_dir)
# Configure Django
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')
# Initialize Django
import django
django.setup()
# Run on import, so importing this module is enough
setup_django_for_prefect()

View File

@@ -0,0 +1,29 @@
"""通用信号定义
定义项目中使用的自定义信号,用于解耦各模块之间的通信。
使用方式:
1. 发布信号signal.send(sender=SomeClass, **kwargs)
2. 接收信号:@receiver(signal) def handler(sender, **kwargs): ...
"""
from django.dispatch import Signal
# ==================== Vulnerability signals ====================
# Fired after vulnerabilities are saved
# Arguments:
# - items: List[VulnerabilitySnapshotDTO], the saved vulnerabilities
# - scan_id: int, scan task ID
# - target_id: int, target ID
vulnerabilities_saved = Signal()
# ==================== Worker signals ====================
# Fired when deleting a worker fails (sent only on failure)
# Arguments:
# - worker_name: str, worker name
# - message: str, failure reason
worker_delete_failed = Signal()
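# Illustrative subscriber (hedged; the handler is hypothetical): a notification
# module could react to saved vulnerabilities like this.
from django.dispatch import receiver
from apps.common.signals import vulnerabilities_saved

@receiver(vulnerabilities_saved)
def notify_on_vulnerabilities(sender, items, scan_id, target_id, **kwargs):
    # e.g. fan out to a webhook; here we just log the volume
    print(f"scan {scan_id}: {len(items)} vulnerabilities saved for target {target_id}")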

View File

@@ -0,0 +1,12 @@
"""
通用模块 URL 配置
"""
from django.urls import path
from .views import LoginView, LogoutView, MeView, ChangePasswordView
urlpatterns = [
path('auth/login/', LoginView.as_view(), name='auth-login'),
path('auth/logout/', LogoutView.as_view(), name='auth-logout'),
path('auth/me/', MeView.as_view(), name='auth-me'),
path('auth/change-password/', ChangePasswordView.as_view(), name='auth-change-password'),
]

View File

@@ -0,0 +1,142 @@
"""域名、IP、端口和目标验证工具函数"""
import ipaddress
import logging
import validators
logger = logging.getLogger(__name__)
def validate_domain(domain: str) -> None:
"""
验证域名格式(使用 validators 库)
Args:
domain: 域名字符串(应该已经规范化)
Raises:
ValueError: 域名格式无效
"""
if not domain:
raise ValueError("域名不能为空")
# 使用 validators 库验证域名格式
# 支持国际化域名IDN和各种边界情况
if not validators.domain(domain):
raise ValueError(f"域名格式无效: {domain}")
def validate_ip(ip: str) -> None:
"""
验证 IP 地址格式(支持 IPv4 和 IPv6
Args:
ip: IP 地址字符串(应该已经规范化)
Raises:
ValueError: IP 地址格式无效
"""
if not ip:
raise ValueError("IP 地址不能为空")
try:
ipaddress.ip_address(ip)
except ValueError:
raise ValueError(f"IP 地址格式无效: {ip}")
def validate_cidr(cidr: str) -> None:
"""
验证 CIDR 格式(支持 IPv4 和 IPv6
Args:
cidr: CIDR 字符串(应该已经规范化)
Raises:
ValueError: CIDR 格式无效
"""
if not cidr:
raise ValueError("CIDR 不能为空")
try:
ipaddress.ip_network(cidr, strict=False)
except ValueError:
raise ValueError(f"CIDR 格式无效: {cidr}")
def detect_target_type(name: str) -> str:
"""
检测目标类型(不做规范化,只验证)
Args:
name: 目标名称(应该已经规范化)
Returns:
str: 目标类型 ('domain', 'ip', 'cidr') - 使用 Target.TargetType 枚举值
Raises:
ValueError: 如果无法识别目标类型
"""
# 在函数内部导入模型,避免 AppRegistryNotReady 错误
from apps.targets.models import Target
if not name:
raise ValueError("目标名称不能为空")
# 检查是否是 CIDR 格式(包含 /
if '/' in name:
validate_cidr(name)
return Target.TargetType.CIDR
# 检查是否是 IP 地址
try:
validate_ip(name)
return Target.TargetType.IP
except ValueError:
pass
# 检查是否是合法域名
try:
validate_domain(name)
return Target.TargetType.DOMAIN
except ValueError:
pass
# 无法识别的格式
raise ValueError(f"无法识别的目标格式: {name}必须是域名、IP地址或CIDR范围")
def validate_port(port: any) -> tuple[bool, int | None]:
"""
验证并转换端口号
Args:
port: 待验证的端口号(可能是字符串、整数或其他类型)
Returns:
tuple: (is_valid, port_number)
- is_valid: 端口是否有效
- port_number: 有效时为整数端口号,无效时为 None
验证规则:
1. 必须能转换为整数
2. 必须在 1-65535 范围内
示例:
>>> is_valid, port_num = validate_port(8080)
>>> is_valid, port_num
(True, 8080)
>>> is_valid, port_num = validate_port("invalid")
>>> is_valid, port_num
(False, None)
"""
try:
port_num = int(port)
if 1 <= port_num <= 65535:
return True, port_num
else:
logger.warning("端口号超出有效范围 (1-65535): %d", port_num)
return False, None
except (ValueError, TypeError):
logger.warning("端口号格式错误,无法转换为整数: %s", port)
return False, None
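The detection order matters: anything containing a slash is validated as CIDR first, then IP, then domain. A quick illustration, assuming the module path apps.common.validators and a configured Django context (detect_target_type imports the Target model):

from apps.common.validators import detect_target_type  # assumed module path

print(detect_target_type("10.0.0.0/24"))   # Target.TargetType.CIDR
print(detect_target_type("192.168.1.10"))  # Target.TargetType.IP
print(detect_target_type("example.com"))   # Target.TargetType.DOMAIN
print(detect_target_type("10.0.0.0/40"))   # raises ValueError: the '/' forces CIDR validation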


@@ -0,0 +1,3 @@
from .auth_views import LoginView, LogoutView, MeView, ChangePasswordView
__all__ = ['LoginView', 'LogoutView', 'MeView', 'ChangePasswordView']


@@ -0,0 +1,173 @@
"""
认证相关视图
使用 Django 内置认证系统,支持 Session 认证
"""
import logging
from django.contrib.auth import authenticate, login, logout, update_session_auth_hash
from django.views.decorators.csrf import csrf_exempt
from django.utils.decorators import method_decorator
from rest_framework import status
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework.permissions import AllowAny, IsAuthenticated
logger = logging.getLogger(__name__)
@method_decorator(csrf_exempt, name='dispatch')
class LoginView(APIView):
"""
用户登录
POST /api/auth/login/
"""
authentication_classes = [] # 禁用认证(绕过 CSRF
permission_classes = [AllowAny]
def post(self, request):
username = request.data.get('username')
password = request.data.get('password')
if not username or not password:
return Response(
{'error': '请提供用户名和密码'},
status=status.HTTP_400_BAD_REQUEST
)
user = authenticate(request, username=username, password=password)
if user is not None:
login(request, user)
logger.info(f"用户 {username} 登录成功")
return Response({
'message': '登录成功',
'user': {
'id': user.id,
'username': user.username,
'isStaff': user.is_staff,
'isSuperuser': user.is_superuser,
}
})
else:
logger.warning(f"用户 {username} 登录失败:用户名或密码错误")
return Response(
{'error': '用户名或密码错误'},
status=status.HTTP_401_UNAUTHORIZED
)
@method_decorator(csrf_exempt, name='dispatch')
class LogoutView(APIView):
"""
用户登出
POST /api/auth/logout/
"""
authentication_classes = [] # 禁用认证(绕过 CSRF
permission_classes = [AllowAny]
def post(self, request):
# 从 session 获取用户名用于日志
user_id = request.session.get('_auth_user_id')
if user_id:
from django.contrib.auth import get_user_model
User = get_user_model()
try:
user = User.objects.get(pk=user_id)
username = user.username
logout(request)
logger.info(f"用户 {username} 已登出")
except User.DoesNotExist:
logout(request)
else:
logout(request)
return Response({'message': '已登出'})
@method_decorator(csrf_exempt, name='dispatch')
class MeView(APIView):
"""
获取当前用户信息
GET /api/auth/me/
"""
authentication_classes = [] # 禁用认证(绕过 CSRF
permission_classes = [AllowAny]
def get(self, request):
# 从 session 获取用户
from django.contrib.auth import get_user_model
User = get_user_model()
user_id = request.session.get('_auth_user_id')
if user_id:
try:
user = User.objects.get(pk=user_id)
return Response({
'authenticated': True,
'user': {
'id': user.id,
'username': user.username,
'isStaff': user.is_staff,
'isSuperuser': user.is_superuser,
}
})
except User.DoesNotExist:
pass
return Response({
'authenticated': False,
'user': None
})
@method_decorator(csrf_exempt, name='dispatch')
class ChangePasswordView(APIView):
"""
修改密码
POST /api/auth/change-password/
"""
authentication_classes = [] # 禁用认证(绕过 CSRF
permission_classes = [AllowAny] # 手动检查登录状态
def post(self, request):
# 手动检查登录状态(从 session 获取用户)
from django.contrib.auth import get_user_model
User = get_user_model()
user_id = request.session.get('_auth_user_id')
if not user_id:
return Response(
{'error': '请先登录'},
status=status.HTTP_401_UNAUTHORIZED
)
try:
user = User.objects.get(pk=user_id)
except User.DoesNotExist:
return Response(
{'error': '用户不存在'},
status=status.HTTP_401_UNAUTHORIZED
)
# CamelCaseParser 将 oldPassword -> old_password
old_password = request.data.get('old_password')
new_password = request.data.get('new_password')
if not old_password or not new_password:
return Response(
{'error': '请提供旧密码和新密码'},
status=status.HTTP_400_BAD_REQUEST
)
if not user.check_password(old_password):
return Response(
{'error': '旧密码错误'},
status=status.HTTP_400_BAD_REQUEST
)
user.set_password(new_password)
user.save()
# 更新 session避免用户被登出
update_session_auth_hash(request, user)
logger.info(f"用户 {user.username} 已修改密码")
return Response({'message': '密码修改成功'})
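End to end, the session cookie issued by LoginView authenticates the follow-up calls. A sketch using requests; host, port, and credentials are assumptions for illustration (8888 matches SERVER_PORT's default elsewhere in this commit):

import requests

session = requests.Session()

# Log in; the session cookie is kept on the Session object
resp = session.post(
    "http://127.0.0.1:8888/api/auth/login/",
    json={"username": "admin", "password": "secret"},
)
print(resp.json())

# The same session is now authenticated
print(session.get("http://127.0.0.1:8888/api/auth/me/").json())  # authenticated: True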


@@ -0,0 +1,32 @@
import os

from django.apps import AppConfig


class EngineConfig(AppConfig):
    default_auto_field = 'django.db.models.BigAutoField'
    name = 'apps.engine'
    verbose_name = 'Scan engine'

    def ready(self):
        """Start the periodic scheduler when the app is ready"""
        # Only start the scheduler in the main process (avoids duplicate starts
        # under autoreload): check whether we are in runserver's autoreload
        # child process.
        if os.environ.get('RUN_MAIN') == 'true' or not self._is_runserver():
            # Only the server container starts the scheduler (worker containers
            # do not need it); worker containers have SERVER_URL set.
            if not os.environ.get('SERVER_URL'):
                self._start_scheduler()

    def _is_runserver(self):
        """Check whether the process was started via runserver"""
        import sys
        return 'runserver' in sys.argv

    def _start_scheduler(self):
        """Start the scheduler"""
        try:
            from apps.engine.scheduler import start_scheduler
            start_scheduler()
        except Exception as e:
            import logging
            logger = logging.getLogger(__name__)
            logger.warning(f"Failed to start the scheduler: {e}")


@@ -0,0 +1,6 @@
"""
Engine WebSocket Consumers
"""
from .worker_deploy_consumer import WorkerDeployConsumer
__all__ = ['WorkerDeployConsumer']


@@ -0,0 +1,454 @@
"""
WebSocket Consumer - Worker 交互式终端 (使用 PTY)
"""
import json
import logging
import asyncio
import os
from channels.generic.websocket import AsyncWebsocketConsumer
from asgiref.sync import sync_to_async
from django.conf import settings
from apps.engine.services import WorkerService
logger = logging.getLogger(__name__)
class WorkerDeployConsumer(AsyncWebsocketConsumer):
"""
Worker 交互式终端 WebSocket Consumer
使用 paramiko invoke_shell 实现真正的交互式终端
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ssh_client = None
self.shell = None
self.worker = None
self.read_task = None
self.worker_service = WorkerService()
async def connect(self):
"""连接时加入对应 Worker 的组并自动建立 SSH 连接"""
self.worker_id = self.scope['url_route']['kwargs']['worker_id']
self.group_name = f'worker_deploy_{self.worker_id}'
await self.channel_layer.group_add(self.group_name, self.channel_name)
await self.accept()
logger.info(f"终端已连接 - Worker: {self.worker_id}")
# 自动建立 SSH 连接
await self._auto_ssh_connect()
async def disconnect(self, close_code):
"""断开时清理资源"""
if self.read_task:
self.read_task.cancel()
if self.shell:
try:
self.shell.close()
except Exception:
pass
if self.ssh_client:
try:
self.ssh_client.close()
except Exception:
pass
await self.channel_layer.group_discard(self.group_name, self.channel_name)
logger.info(f"终端已断开 - Worker: {self.worker_id}")
async def receive(self, text_data=None, bytes_data=None):
"""接收客户端消息"""
if bytes_data:
# 二进制数据直接发送到 shell
if self.shell:
await asyncio.to_thread(self.shell.send, bytes_data)
return
if not text_data:
return
try:
data = json.loads(text_data)
msg_type = data.get('type')
if msg_type == 'resize':
cols = data.get('cols', 80)
rows = data.get('rows', 24)
if self.shell:
await asyncio.to_thread(self.shell.resize_pty, cols, rows)
elif msg_type == 'input':
# 终端输入
if self.shell:
text = data.get('data', '')
await asyncio.to_thread(self.shell.send, text)
elif msg_type == 'deploy':
# 执行部署脚本(后台运行)
await self._run_deploy_script()
elif msg_type == 'attach':
# 查看部署进度attach 到 tmux 会话)
await self._attach_deploy_session()
elif msg_type == 'uninstall':
# 执行卸载脚本(后台运行)
await self._run_uninstall_script()
except json.JSONDecodeError:
# 可能是普通文本输入
if self.shell and text_data:
await asyncio.to_thread(self.shell.send, text_data)
except Exception as e:
logger.error(f"处理消息错误: {e}")
async def _auto_ssh_connect(self):
"""自动从数据库读取密码并连接"""
logger.info(f"[SSH] 开始自动连接 - Worker ID: {self.worker_id}")
# 通过服务层获取 Worker 节点
# thread_sensitive=False 确保在新线程中运行,避免数据库连接问题
self.worker = await sync_to_async(self.worker_service.get_worker, thread_sensitive=False)(self.worker_id)
logger.info(f"[SSH] Worker 查询结果: {self.worker}")
if not self.worker:
await self.send(text_data=json.dumps({
'type': 'error',
'message': 'Worker 不存在'
}))
return
if not self.worker.password:
await self.send(text_data=json.dumps({
'type': 'error',
'message': '未配置 SSH 密码,请先编辑节点信息'
}))
return
# 使用默认终端大小
await self._ssh_connect(self.worker.password, 80, 24)
async def _ssh_connect(self, password: str, cols: int = 80, rows: int = 24):
"""建立 SSH 连接并启动交互式 shell (使用 tmux 持久化会话)"""
try:
import paramiko
except ImportError:
await self.send(text_data=json.dumps({
'type': 'error',
'message': '服务器缺少 paramiko 库'
}))
return
# self.worker 已在 _auto_ssh_connect 中查询
try:
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
await asyncio.to_thread(
ssh.connect,
self.worker.ip_address,
port=self.worker.ssh_port,
username=self.worker.username,
password=password,
timeout=30
)
self.ssh_client = ssh
# 启动交互式 shell连接时不做 tmux 安装,仅提供普通 shell
self.shell = await asyncio.to_thread(
ssh.invoke_shell,
term='xterm-256color',
width=cols,
height=rows
)
# 发送连接成功消息
logger.info(f"[SSH] 准备发送 connected 消息 - Worker: {self.worker_id}")
await self.send(text_data=json.dumps({
'type': 'connected'
}))
logger.info(f"[SSH] connected 消息已发送 - Worker: {self.worker_id}")
# 启动读取任务
self.read_task = asyncio.create_task(self._read_shell_output())
logger.info(f"[SSH] Shell 已连接,读取任务已启动 - Worker: {self.worker_id}")
except paramiko.AuthenticationException:
logger.error(f"[SSH] 认证失败 - Worker: {self.worker_id}")
await self.send(text_data=json.dumps({
'type': 'error',
'message': '认证失败,密码错误'
}))
except Exception as e:
logger.error(f"[SSH] 连接失败 - Worker: {self.worker_id}, Error: {e}")
await self.send(text_data=json.dumps({
'type': 'error',
'message': f'连接失败: {str(e)}'
}))
async def _read_shell_output(self):
"""持续读取 shell 输出并发送到客户端"""
try:
while self.shell and not self.shell.closed:
if self.shell.recv_ready():
data = await asyncio.to_thread(self.shell.recv, 4096)
if data:
await self.send(bytes_data=data)
else:
await asyncio.sleep(0.01)
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"读取 shell 输出错误: {e}")
async def _run_deploy_script(self):
"""运行部署脚本(在 tmux 会话中执行,支持断线续连)
流程:
1. 通过 SFTP 上传脚本到远程服务器
2. 使用 exec_command 静默执行(不在交互式终端回显)
3. 通过 WebSocket 发送结果到前端显示
"""
if not self.ssh_client:
return
from apps.engine.services.deploy_service import (
get_bootstrap_script,
get_deploy_script,
get_start_agent_script
)
# 优先使用 settings 中配置的对外访问主机PUBLIC_HOST拼接 Django URL
public_host = getattr(settings, 'PUBLIC_HOST', '').strip()
server_port = getattr(settings, 'SERVER_PORT', '8888')
if not public_host:
error_msg = (
"未配置 PUBLIC_HOST请在 docker/.env 中设置对外访问 IP/域名 "
"(PUBLIC_HOST) 并重启服务后再执行远程部署"
)
logger.error(error_msg)
await self.send(text_data=json.dumps({
'type': 'error',
'message': error_msg,
}))
return
django_host = f"{public_host}:{server_port}" # Django / 心跳上报使用
heartbeat_api_url = f"http://{django_host}" # 基础 URLagent 会加 /api/...
session_name = f'xingrin_deploy_{self.worker_id}'
remote_script_path = '/tmp/xingrin_deploy.sh'
# 获取外置脚本内容
bootstrap_script = get_bootstrap_script()
deploy_script = get_deploy_script()
start_script = get_start_agent_script(
heartbeat_api_url=heartbeat_api_url,
worker_id=self.worker_id
)
# 合并脚本
combined_script = f"""#!/bin/bash
set -e
# ==================== 阶段 1: 环境初始化 ====================
{bootstrap_script}
# ==================== 阶段 2: 安装 Docker ====================
{deploy_script}
# ==================== 阶段 3: 启动 Agent ====================
{start_script}
echo "SUCCESS"
"""
# 更新状态为 deploying
await sync_to_async(self.worker_service.update_status)(self.worker_id, 'deploying')
# 发送开始提示
start_msg = "\r\n\033[36m[XingRin] 正在准备部署...\033[0m\r\n"
await self.send(bytes_data=start_msg.encode())
try:
# 1. 上传脚本
sftp = await asyncio.to_thread(self.ssh_client.open_sftp)
with sftp.file(remote_script_path, 'w') as f:
f.write(combined_script)
sftp.chmod(remote_script_path, 0o755)
await asyncio.to_thread(sftp.close)
# 2. 静默执行部署命令(使用 exec_command不会回显到终端
deploy_cmd = f"""
# 确保 tmux 安装
if ! command -v tmux >/dev/null 2>&1; then
if command -v apt-get >/dev/null 2>&1; then
sudo apt-get update -qq && sudo apt-get install -y -qq tmux >/dev/null 2>&1
fi
fi
# 检查脚本是否存在
if [ ! -f "{remote_script_path}" ]; then
echo "SCRIPT_NOT_FOUND"
exit 1
fi
# 启动 tmux 会话
if command -v tmux >/dev/null 2>&1; then
tmux kill-session -t {session_name} 2>/dev/null || true
# 使用 bash 执行脚本,确保环境正确
tmux new-session -d -s {session_name} "bash {remote_script_path}; echo '部署完成,按回车退出'; read"
# 验证会话是否创建成功
sleep 0.5
if tmux has-session -t {session_name} 2>/dev/null; then
echo "SUCCESS"
else
echo "SESSION_CREATE_FAILED"
fi
else
echo "TMUX_NOT_FOUND"
fi
"""
stdin, stdout, stderr = await asyncio.to_thread(
self.ssh_client.exec_command, deploy_cmd
)
result = await asyncio.to_thread(stdout.read)
result = result.decode().strip()
# 3. 发送结果到前端终端显示
if "SUCCESS" in result:
# 部署任务已在后台启动,保持 deploying 状态
# 只有当心跳上报成功后才会变成 deployed通过 heartbeat API 自动更新)
success_msg = (
"\r\n\033[32m✓ 部署任务已在后台启动\033[0m\r\n"
f"\033[90m 会话: {session_name}\033[0m\r\n"
"\r\n"
"\033[36m点击 [查看进度] 按钮查看部署输出\033[0m\r\n"
f"\033[90m或手动执行: tmux attach -t {session_name}\033[0m\r\n"
"\r\n"
)
else:
# 获取更多错误信息
err = await asyncio.to_thread(stderr.read)
err_msg = err.decode().strip() if err else ""
success_msg = f"\r\n\033[31m✗ 部署启动失败\033[0m\r\n\033[90m结果: {result}\r\n错误: {err_msg}\033[0m\r\n"
await self.send(bytes_data=success_msg.encode())
except Exception as e:
error_msg = f"\033[31m✗ 部署失败: {str(e)}\033[0m\r\n"
await self.send(bytes_data=error_msg.encode())
logger.error(f"部署脚本执行失败: {e}")
async def _run_uninstall_script(self):
"""在远程主机上执行 Worker 卸载脚本
逻辑:
1. 通过服务层读取本地 worker-uninstall.sh 内容
2. 上传到远程 /tmp/xingrin_uninstall.sh 并赋予执行权限
3. 使用 exec_command 以 bash 执行脚本
4. 将执行结果摘要写回前端终端
"""
if not self.ssh_client:
return
from apps.engine.services.deploy_service import get_uninstall_script
uninstall_script = get_uninstall_script()
remote_script_path = '/tmp/xingrin_uninstall.sh'
start_msg = "\r\n\033[36m[XingRin] 正在执行 Worker 卸载...\033[0m\r\n"
await self.send(bytes_data=start_msg.encode())
try:
# 上传卸载脚本
sftp = await asyncio.to_thread(self.ssh_client.open_sftp)
with sftp.file(remote_script_path, 'w') as f:
f.write(uninstall_script)
sftp.chmod(remote_script_path, 0o755)
await asyncio.to_thread(sftp.close)
# 执行卸载脚本
cmd = f"bash {remote_script_path}"
stdin, stdout, stderr = await asyncio.to_thread(
self.ssh_client.exec_command, cmd
)
out = await asyncio.to_thread(stdout.read)
err = await asyncio.to_thread(stderr.read)
# 转换换行符为终端格式 (\n -> \r\n)
output_text = out.decode().strip().replace('\n', '\r\n') if out else ""
error_text = err.decode().strip().replace('\n', '\r\n') if err else ""
# 简单判断是否成功(退出码 + 关键字)
exit_status = stdout.channel.recv_exit_status()
if exit_status == 0:
# 卸载成功,重置状态为 pending
await sync_to_async(self.worker_service.update_status)(self.worker_id, 'pending')
# 删除 Redis 中的心跳数据
from apps.engine.services.worker_load_service import worker_load_service
worker_load_service.delete_load(self.worker_id)
# 发送状态更新到前端
await self.send(text_data=json.dumps({
'type': 'status',
'status': 'pending' # 卸载后变为待部署状态
}))
msg = "\r\n\033[32m✓ 节点卸载完成\033[0m\r\n"
if output_text:
msg += f"\033[90m{output_text}\033[0m\r\n"
else:
msg = "\r\n\033[31m✗ Worker 卸载失败\033[0m\r\n"
if output_text:
msg += f"\033[90m输出: {output_text}\033[0m\r\n"
if error_text:
msg += f"\033[90m错误: {error_text}\033[0m\r\n"
await self.send(bytes_data=msg.encode())
except Exception as e:
error_msg = f"\033[31m✗ 卸载执行异常: {str(e)}\033[0m\r\n"
await self.send(bytes_data=error_msg.encode())
logger.error(f"卸载脚本执行失败: {e}")
async def _attach_deploy_session(self):
"""Attach 到部署会话查看进度"""
if not self.shell or not self.ssh_client:
return
session_name = f'xingrin_deploy_{self.worker_id}'
# 先静默检查会话是否存在
check_cmd = f"tmux has-session -t {session_name} 2>/dev/null && echo EXISTS || echo NOT_EXISTS"
stdin, stdout, stderr = await asyncio.to_thread(
self.ssh_client.exec_command, check_cmd
)
result = await asyncio.to_thread(stdout.read)
result = result.decode().strip()
if "EXISTS" in result:
# 会话存在,直接 attach
await asyncio.to_thread(self.shell.send, f"tmux attach -t {session_name}\n")
else:
# 会话不存在,发送提示
msg = "\r\n\033[33m没有正在运行的部署任务\033[0m\r\n\033[90m请先点击 [执行部署] 按钮启动部署\033[0m\r\n\r\n"
await self.send(bytes_data=msg.encode())
# Channel Layer 消息处理
async def terminal_output(self, event):
if self.shell:
await asyncio.to_thread(self.shell.send, event['content'])
async def deploy_status(self, event):
await self.send(text_data=json.dumps({
'type': 'status',
'status': event['status'],
'message': event.get('message', '')
}))
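On the wire the consumer speaks a small JSON protocol (resize/input/deploy/attach/uninstall) over text frames and raw PTY output over binary frames. A hedged client sketch; the WebSocket route is an assumption, since the Channels routing config is not part of this diff:

import asyncio
import json

import websockets  # third-party: pip install websockets


async def watch_deploy():
    # Assumed route; check the project's routing config for the real path
    async with websockets.connect("ws://127.0.0.1:8888/ws/workers/1/deploy/") as ws:
        await ws.send(json.dumps({"type": "resize", "cols": 120, "rows": 40}))
        await ws.send(json.dumps({"type": "deploy"}))  # start the background deploy
        while True:
            frame = await ws.recv()
            # Text frames carry JSON control messages; binary frames carry raw
            # terminal output
            print(frame if isinstance(frame, str) else frame.decode(errors="replace"))


asyncio.run(watch_deploy())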


@@ -0,0 +1,112 @@
"""
初始化默认扫描引擎
用法:
python manage.py init_default_engine # 只创建不存在的引擎(不覆盖已有)
python manage.py init_default_engine --force # 强制覆盖所有引擎配置
cd /root/my-vulun-scan/docker
docker compose exec server python backend/manage.py init_default_engine --force
功能:
- 读取 engine_config_example.yaml 作为默认配置
- 创建 full scan默认引擎+ 各扫描类型的子引擎
- 默认不覆盖已有配置,加 --force 才会覆盖
"""
from django.core.management.base import BaseCommand
from pathlib import Path
import yaml
from apps.engine.models import ScanEngine
class Command(BaseCommand):
help = '初始化默认扫描引擎配置(默认不覆盖已有,加 --force 强制覆盖)'
def add_arguments(self, parser):
parser.add_argument(
'--force',
action='store_true',
help='强制覆盖已有的引擎配置',
)
def handle(self, *args, **options):
force = options.get('force', False)
# 读取默认配置文件
config_path = Path(__file__).resolve().parent.parent.parent.parent / 'scan' / 'configs' / 'engine_config_example.yaml'
if not config_path.exists():
self.stdout.write(self.style.ERROR(f'配置文件不存在: {config_path}'))
return
with open(config_path, 'r', encoding='utf-8') as f:
default_config = f.read()
# 解析 YAML 为字典,后续用于生成子引擎配置
try:
config_dict = yaml.safe_load(default_config) or {}
except yaml.YAMLError as e:
self.stdout.write(self.style.ERROR(f'引擎配置 YAML 解析失败: {e}'))
return
# 1) full scan保留完整配置
engine = ScanEngine.objects.filter(name='full scan').first()
if engine:
if force:
engine.configuration = default_config
engine.save()
self.stdout.write(self.style.SUCCESS(f'✓ 扫描引擎 full scan 配置已更新 (ID: {engine.id})'))
else:
self.stdout.write(self.style.WARNING(f' ⊘ full scan 已存在,跳过(使用 --force 覆盖)'))
else:
engine = ScanEngine.objects.create(
name='full scan',
configuration=default_config,
)
self.stdout.write(self.style.SUCCESS(f'✓ 扫描引擎 full scan 已创建 (ID: {engine.id})'))
# 2) 为每个扫描类型生成一个「单一扫描类型」的子引擎
# 例如subdomain_discovery, port_scan, ...
from apps.scan.configs.command_templates import get_supported_scan_types
supported_scan_types = set(get_supported_scan_types())
for scan_type, scan_cfg in config_dict.items():
# 只处理受支持且结构为 {tools: {...}} 的扫描类型
if scan_type not in supported_scan_types:
continue
if not isinstance(scan_cfg, dict):
continue
# subdomain_discovery 使用 4 阶段新结构(无 tools 字段),其他扫描类型仍要求有 tools
if scan_type != 'subdomain_discovery' and 'tools' not in scan_cfg:
continue
# 构造只包含当前扫描类型配置的 YAML
single_config = {scan_type: scan_cfg}
try:
single_yaml = yaml.safe_dump(
single_config,
sort_keys=False,
allow_unicode=True,
)
except yaml.YAMLError as e:
self.stdout.write(self.style.ERROR(f'生成子引擎 {scan_type} 配置失败: {e}'))
continue
engine_name = f"{scan_type}"
sub_engine = ScanEngine.objects.filter(name=engine_name).first()
if sub_engine:
if force:
sub_engine.configuration = single_yaml
sub_engine.save()
self.stdout.write(self.style.SUCCESS(f' ✓ 子引擎 {engine_name} 配置已更新 (ID: {sub_engine.id})'))
else:
self.stdout.write(self.style.WARNING(f'{engine_name} 已存在,跳过(使用 --force 覆盖)'))
else:
sub_engine = ScanEngine.objects.create(
name=engine_name,
configuration=single_yaml,
)
self.stdout.write(self.style.SUCCESS(f' ✓ 子引擎 {engine_name} 已创建 (ID: {sub_engine.id})'))
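The command can also be invoked programmatically, which is convenient in tests or provisioning scripts:

from django.core.management import call_command

call_command('init_default_engine')              # create missing engines only
call_command('init_default_engine', force=True)  # overwrite existing configurations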


@@ -0,0 +1,126 @@
"""初始化 Nuclei 模板仓库
项目安装后执行此命令,自动创建官方模板仓库记录。
使用方式:
python manage.py init_nuclei_templates # 只创建记录
python manage.py init_nuclei_templates --sync # 创建并同步git clone
"""
import logging
from django.core.management.base import BaseCommand
from apps.engine.models import NucleiTemplateRepo
from apps.engine.services import NucleiTemplateRepoService
logger = logging.getLogger(__name__)
# 默认仓库配置
DEFAULT_REPOS = [
{
"name": "nuclei-templates",
"repo_url": "https://github.com/projectdiscovery/nuclei-templates.git",
"description": "Nuclei 官方模板仓库,包含数千个漏洞检测模板",
},
]
class Command(BaseCommand):
help = "初始化 Nuclei 模板仓库(创建官方模板仓库记录)"
def add_arguments(self, parser):
parser.add_argument(
"--sync",
action="store_true",
help="创建后立即同步git clone首次需要较长时间",
)
parser.add_argument(
"--force",
action="store_true",
help="强制重新创建(删除已存在的同名仓库)",
)
def handle(self, *args, **options):
do_sync = options.get("sync", False)
force = options.get("force", False)
service = NucleiTemplateRepoService()
created = 0
skipped = 0
synced = 0
for repo_config in DEFAULT_REPOS:
name = repo_config["name"]
repo_url = repo_config["repo_url"]
# 检查是否已存在
existing = NucleiTemplateRepo.objects.filter(name=name).first()
if existing:
if force:
self.stdout.write(self.style.WARNING(
f"[{name}] 强制模式,删除已存在的仓库记录"
))
service.remove_local_path_dir(existing)
existing.delete()
else:
self.stdout.write(self.style.SUCCESS(
f"[{name}] 已存在,跳过创建"
))
skipped += 1
# 如果需要同步且已存在,也执行同步
if do_sync and existing.id:
try:
result = service.refresh_repo(existing.id)
self.stdout.write(self.style.SUCCESS(
f"[{name}] 同步完成: {result.get('action', 'unknown')}, "
f"commit={result.get('commitHash', 'N/A')[:8]}"
))
synced += 1
except Exception as e:
self.stdout.write(self.style.ERROR(
f"[{name}] 同步失败: {e}"
))
continue
# 创建新仓库记录
try:
repo = NucleiTemplateRepo.objects.create(
name=name,
repo_url=repo_url,
)
self.stdout.write(self.style.SUCCESS(
f"[{name}] 创建成功: id={repo.id}"
))
created += 1
# 初始化本地路径
service.ensure_local_path(repo)
# 如果需要同步
if do_sync:
try:
self.stdout.write(self.style.WARNING(
f"[{name}] 正在同步(首次可能需要几分钟)..."
))
result = service.refresh_repo(repo.id)
self.stdout.write(self.style.SUCCESS(
f"[{name}] 同步完成: {result.get('action', 'unknown')}, "
f"commit={result.get('commitHash', 'N/A')[:8]}"
))
synced += 1
except Exception as e:
self.stdout.write(self.style.ERROR(
f"[{name}] 同步失败: {e}"
))
except Exception as e:
self.stdout.write(self.style.ERROR(
f"[{name}] 创建失败: {e}"
))
self.stdout.write(self.style.SUCCESS(
f"\n初始化完成: 创建 {created}, 跳过 {skipped}, 同步 {synced}"
))


@@ -0,0 +1,148 @@
"""初始化所有内置字典 Wordlist 记录
- 目录扫描默认字典: dir_default.txt -> /app/backend/wordlist/dir_default.txt
- 子域名爆破默认字典: subdomains-top1million-110000.txt -> /app/backend/wordlist/subdomains-top1million-110000.txt
可重复执行:如果已存在同名记录且文件有效则跳过,只在缺失或文件丢失时创建/修复。
"""
import logging
import shutil
from pathlib import Path
from django.conf import settings
from django.core.management.base import BaseCommand
from apps.common.hash_utils import safe_calc_file_sha256
from apps.engine.models import Wordlist
logger = logging.getLogger(__name__)
DEFAULT_WORDLISTS = [
{
"name": "dir_default.txt",
"filename": "dir_default.txt",
"description": "内置默认目录字典",
},
{
"name": "subdomains-top1million-110000.txt",
"filename": "subdomains-top1million-110000.txt",
"description": "内置默认子域名字典",
},
]
class Command(BaseCommand):
help = "初始化所有内置字典 Wordlist 记录"
def handle(self, *args, **options):
project_base = Path(settings.BASE_DIR).parent # /app/backend -> /app
base_wordlist_dir = project_base / "backend" / "wordlist"
runtime_base_dir = Path(getattr(settings, "WORDLISTS_BASE_PATH", "/opt/xingrin/wordlists"))
runtime_base_dir.mkdir(parents=True, exist_ok=True)
initialized = 0
skipped = 0
failed = 0
for item in DEFAULT_WORDLISTS:
name = item["name"]
filename = item["filename"]
description = item["description"]
existing = Wordlist.objects.filter(name=name).first()
if existing:
file_path = existing.file_path or ""
file_hash = getattr(existing, 'file_hash', '') or ''
if file_path and Path(file_path).exists() and file_hash:
# 记录、文件、hash 都在,直接跳过
self.stdout.write(self.style.SUCCESS(
f"[{name}] 已存在且文件有效,跳过初始化 (file_path={file_path})"
))
skipped += 1
continue
elif file_path and Path(file_path).exists() and not file_hash:
# 文件在但 hash 缺失,需要补算
self.stdout.write(self.style.WARNING(
f"[{name}] 记录已存在但缺少 file_hash将补算并更新"
))
else:
self.stdout.write(self.style.WARNING(
f"[{name}] 记录已存在但物理文件丢失,将重新创建文件路径并修复记录"
))
src_path = base_wordlist_dir / filename
dest_path = runtime_base_dir / filename
if not src_path.exists():
self.stdout.write(self.style.WARNING(
f"[{name}] 未找到内置字典文件: {src_path},跳过"
))
failed += 1
continue
try:
shutil.copy2(src_path, dest_path)
except OSError as exc:
self.stdout.write(self.style.WARNING(
f"[{name}] 复制内置字典到运行目录失败: {exc}"
))
failed += 1
continue
# 统计文件大小和行数
try:
file_size = dest_path.stat().st_size
except OSError:
file_size = 0
line_count = 0
try:
with dest_path.open("rb") as f:
for _ in f:
line_count += 1
except OSError:
logger.warning("统计字典行数失败: %s", src_path)
# 计算文件 hash
file_hash = safe_calc_file_sha256(str(dest_path)) or ""
# 如果之前已有记录则更新,否则创建新记录
if existing:
existing.file_path = str(dest_path)
existing.file_size = file_size
existing.line_count = line_count
existing.file_hash = file_hash
existing.description = existing.description or description
existing.save(update_fields=[
"file_path",
"file_size",
"line_count",
"file_hash",
"description",
"updated_at",
])
wordlist = existing
action = "更新"
else:
wordlist = Wordlist.objects.create(
name=name,
description=description,
file_path=str(dest_path),
file_size=file_size,
line_count=line_count,
file_hash=file_hash,
)
action = "创建"
initialized += 1
hash_preview = (wordlist.file_hash[:16] + "...") if wordlist.file_hash else "N/A"
self.stdout.write(self.style.SUCCESS(
f"[{name}] {action}字典记录成功: id={wordlist.id}, size={wordlist.file_size}, lines={wordlist.line_count}, hash={hash_preview}"
))
self.stdout.write(self.style.SUCCESS(
f"初始化完成: 成功 {initialized}, 已存在跳过 {skipped}, 文件缺失 {failed}"
))
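safe_calc_file_sha256 is defined in apps.common.hash_utils, which is not part of this diff; judging by its use here (a hex digest, or a falsy value on failure), it presumably resembles a chunked SHA-256 helper like this sketch (not the actual implementation):

import hashlib
from pathlib import Path


def calc_file_sha256(path: str, chunk_size: int = 1 << 20) -> str | None:
    """Return the hex SHA-256 of a file, or None if it cannot be read."""
    try:
        digest = hashlib.sha256()
        with Path(path).open("rb") as f:
            # Read in 1 MiB chunks so large wordlists do not load into memory
            for chunk in iter(lambda: f.read(chunk_size), b""):
                digest.update(chunk)
        return digest.hexdigest()
    except OSError:
        return None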


@@ -0,0 +1,126 @@
from django.db import models


class WorkerNode(models.Model):
    """Worker node model - distributed scan executor"""

    # Status choices (shared between frontend and backend)
    STATUS_CHOICES = [
        ('pending', 'Awaiting deployment'),
        ('deploying', 'Deploying'),
        ('online', 'Online'),
        ('offline', 'Offline'),
    ]

    name = models.CharField(max_length=100, help_text='Node name')
    # Local nodes are auto-filled with 127.0.0.1 or the container IP
    ip_address = models.GenericIPAddressField(help_text='IP address (127.0.0.1 for local nodes)')
    ssh_port = models.IntegerField(default=22, help_text='SSH port')
    username = models.CharField(max_length=50, default='root', help_text='SSH username')
    password = models.CharField(max_length=200, blank=True, default='', help_text='SSH password')

    # Local node flag (worker running inside the Docker deployment)
    is_local = models.BooleanField(default=False, help_text='Whether this is a local node (inside a Docker container)')

    # Status (shared between frontend and backend)
    status = models.CharField(
        max_length=20,
        choices=STATUS_CHOICES,
        default='pending',
        help_text='Status: pending/deploying/online/offline'
    )

    # Heartbeat data lives in Redis (worker:load:{id}); no database field anymore
    created_at = models.DateTimeField(auto_now_add=True)
    updated_at = models.DateTimeField(auto_now=True)

    class Meta:
        db_table = 'worker_node'
        verbose_name = 'Worker node'
        ordering = ['-created_at']
        constraints = [
            # Remote node IPs are unique (local nodes are exempt since they are
            # all 127.0.0.1)
            models.UniqueConstraint(
                fields=['ip_address'],
                condition=models.Q(is_local=False),
                name='unique_remote_worker_ip'
            ),
            # Names are globally unique
            models.UniqueConstraint(
                fields=['name'],
                name='unique_worker_name'
            ),
        ]

    def __str__(self):
        if self.is_local:
            return f"{self.name} (local)"
        return f"{self.name} ({self.ip_address or 'unknown'})"


class ScanEngine(models.Model):
    """Scan engine model"""
    id = models.AutoField(primary_key=True)
    name = models.CharField(max_length=200, unique=True, help_text='Engine name')
    configuration = models.CharField(max_length=10000, blank=True, default='', help_text='Engine configuration (YAML)')
    created_at = models.DateTimeField(auto_now_add=True, help_text='Creation time')
    updated_at = models.DateTimeField(auto_now=True, help_text='Update time')

    class Meta:
        db_table = 'scan_engine'
        verbose_name = 'Scan engine'
        verbose_name_plural = 'Scan engines'
        ordering = ['-created_at']
        indexes = [
            models.Index(fields=['-created_at']),
        ]

    def __str__(self):
        return str(self.name or f'ScanEngine {self.id}')


class Wordlist(models.Model):
    """Wordlist file model"""
    id = models.AutoField(primary_key=True)
    name = models.CharField(max_length=200, unique=True, help_text='Wordlist name, unique')
    description = models.CharField(max_length=200, blank=True, default='', help_text='Wordlist description')
    file_path = models.CharField(max_length=500, help_text='Absolute path of the wordlist file on the backend')
    file_size = models.BigIntegerField(default=0, help_text='File size in bytes')
    line_count = models.IntegerField(default=0, help_text='Number of lines')
    file_hash = models.CharField(max_length=64, blank=True, default='', help_text='SHA-256 of the file, used for cache validation')
    created_at = models.DateTimeField(auto_now_add=True, help_text='Creation time')
    updated_at = models.DateTimeField(auto_now=True, help_text='Update time')

    class Meta:
        db_table = 'wordlist'
        verbose_name = 'Wordlist file'
        verbose_name_plural = 'Wordlist files'
        ordering = ['-created_at']
        indexes = [
            models.Index(fields=['-created_at']),
        ]

    def __str__(self) -> str:
        return self.name


class NucleiTemplateRepo(models.Model):
    """Nuclei template Git repository model (multi-repo)"""
    name = models.CharField(max_length=200, unique=True, help_text="Repository name, shown in the frontend and referenced in configs")
    repo_url = models.CharField(max_length=500, help_text="Git repository URL")
    local_path = models.CharField(max_length=500, blank=True, default='', help_text="Absolute path of the local working directory")
    commit_hash = models.CharField(max_length=40, blank=True, default='', help_text="Last synced Git commit hash, used for worker version checks")
    last_synced_at = models.DateTimeField(null=True, blank=True, help_text="Time of the last successful sync")
    created_at = models.DateTimeField(auto_now_add=True, help_text="Creation time")
    updated_at = models.DateTimeField(auto_now=True, help_text="Update time")

    class Meta:
        db_table = "nuclei_template_repo"
        verbose_name = "Nuclei template repository"
        verbose_name_plural = "Nuclei template repositories"

    def __str__(self) -> str:  # pragma: no cover - trivial representation
        return f"NucleiTemplateRepo({self.id}, {self.name})"


@@ -0,0 +1,17 @@
"""Engine Repositories 模块
提供 ScanEngine、WorkerNode、Wordlist、NucleiRepo 等数据访问层实现
"""
from .django_engine_repository import DjangoEngineRepository
from .django_worker_repository import DjangoWorkerRepository
from .django_wordlist_repository import DjangoWordlistRepository
from .nuclei_repo_repository import NucleiTemplateRepository, TemplateFileRepository
__all__ = [
"DjangoEngineRepository",
"DjangoWorkerRepository",
"DjangoWordlistRepository",
"NucleiTemplateRepository",
"TemplateFileRepository",
]

Some files were not shown because too many files have changed in this diff.