Mirror of https://github.com/yyhuni/xingrin.git (synced 2026-01-31 11:46:16 +08:00)
Initial commit: Xingrin v1.0.0
13
.agent/rules/project.md
Normal file
@@ -0,0 +1,13 @@
---
trigger: always_on
---

1. The backend web service runs on port 8888.
2. Activate the virtual environment before running backend commands; it lives at the project root: ~/Desktop/scanner/.venv/bin/python
3. Add a trailing slash to every frontend route to match Django DRF's URL conventions (see the sketch after this list).
4. Web endpoints can be tested with curl.
8. All frontend API calls belong in @services; all type definitions belong in @types.
10. Implement frontend loading (and similar) state with React Query, which manages it automatically.
17. Toasts for business operations belong in hooks.
19. For now, the backend project does not need dedicated security-hardening code.
23. Avoid window.location.href for frontend navigation unless necessary; use Next.js client-side routing instead.
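As background for rule 3: DRF's DefaultRouter generates URL patterns with trailing slashes by default, so a frontend request without one 404s (or triggers a redirect). A minimal sketch; the `TargetViewSet` import path is a hypothetical stand-in, not this repo's actual module layout:

```python
# urls.py (sketch): DefaultRouter emits patterns like ^targets/$ and ^targets/{pk}/$
from rest_framework.routers import DefaultRouter
from apps.targets.views import TargetViewSet  # hypothetical import path

router = DefaultRouter()  # trailing_slash=True is the default
router.register(r'targets', TargetViewSet, basename='target')
urlpatterns = router.urls
# GET /api/targets/  -> matches
# GET /api/targets   -> 404 unless Django's APPEND_SLASH redirect kicks in
```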
23
.dockerignore
Normal file
@@ -0,0 +1,23 @@
# Node modules (local frontend build artifacts; reinstalled during the Docker build)
frontend/node_modules
frontend/.next

# Python virtual environment
.venv
__pycache__
*.pyc

# Logs and temporary files
*.log
.DS_Store

# Git
.git
.gitignore

# IDE
.idea
.vscode

# Docker-related (avoid nesting)
docker/.env
84
.github/workflows/docker-build.yml
vendored
Normal file
@@ -0,0 +1,84 @@
name: Build and Push Docker Images

on:
  push:
    branches: [main]
    paths:
      - 'backend/**'
      - 'frontend/**'
      - 'docker/**'
      - '.github/workflows/**'
  workflow_dispatch: # manual trigger

# Concurrency control: keep only the latest build per branch; cancel any run already in progress
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

env:
  REGISTRY: docker.io
  IMAGE_PREFIX: yyhuni

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        include:
          - image: xingrin-server
            dockerfile: docker/server/Dockerfile
            context: .
          - image: xingrin-frontend
            dockerfile: docker/frontend/Dockerfile
            context: .
          - image: xingrin-worker
            dockerfile: docker/worker/Dockerfile
            context: .
          - image: xingrin-nginx
            dockerfile: docker/nginx/Dockerfile
            context: .
          - image: xingrin-agent
            dockerfile: docker/agent/Dockerfile
            context: .

    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Free disk space (for large builds like worker)
        run: |
          echo "=== Before cleanup ==="
          df -h
          # Remove large preinstalled packages we do not need
          sudo rm -rf /usr/share/dotnet
          sudo rm -rf /usr/local/lib/android
          sudo rm -rf /opt/ghc
          sudo rm -rf /opt/hostedtoolcache/CodeQL
          sudo docker image prune -af
          echo "=== After cleanup ==="
          df -h

      - name: Set up QEMU
        uses: docker/setup-qemu-action@v3

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}

      - name: Build and push
        uses: docker/build-push-action@v5
        with:
          context: ${{ matrix.context }}
          file: ${{ matrix.dockerfile }}
          platforms: linux/amd64,linux/arm64
          push: true
          tags: |
            ${{ env.IMAGE_PREFIX }}/${{ matrix.image }}:latest
            ${{ env.IMAGE_PREFIX }}/${{ matrix.image }}:${{ github.sha }}
          cache-from: type=gha
          cache-to: type=gha,mode=max
133
.gitignore
vendored
Normal file
@@ -0,0 +1,133 @@
# ============================
# OS files
# ============================
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db

# ============================
# Frontend (Next.js/Node.js)
# ============================
# Dependency directories
front-back/node_modules/
front-back/.pnpm-store/

# Next.js build output
front-back/.next/
front-back/out/
front-back/dist/

# Environment variable files
front-back/.env
front-back/.env.local
front-back/.env.development.local
front-back/.env.test.local
front-back/.env.production.local

# Runtime and caches
front-back/.turbo/
front-back/.swc/
front-back/.eslintcache
front-back/.tsbuildinfo

# ============================
# Backend (Python/Django)
# ============================
# Python virtual environments
.venv/
venv/
env/
ENV/

# Python bytecode
*.pyc
*.pyo
*.pyd
__pycache__/
*.py[cod]
*$py.class

# Django
backend/db.sqlite3
backend/db.sqlite3-journal
backend/media/
backend/staticfiles/
backend/.env
backend/.env.local

# Python tests and coverage
.pytest_cache/
.coverage
htmlcov/
*.cover

# ============================
# Backend (Go)
# ============================
# Build output
backend/bin/
backend/dist/
backend/*.exe
backend/*.exe~
backend/*.dll
backend/*.so
backend/*.dylib

# Tests
backend/*.test
backend/*.out
backend/*.prof

# Go workspace files
backend/go.work
backend/go.work.sum

# Go dependency management
backend/vendor/

# ============================
# IDEs and editors
# ============================
.vscode/
.idea/
.cursor/
.claude/
.playwright-mcp/
*.swp
*.swo
*~

# ============================
# Docker
# ============================
docker/.env
docker/.env.local

# SSL certificates and private keys (must not be committed)
docker/nginx/ssl/*.pem
docker/nginx/ssl/*.key
docker/nginx/ssl/*.crt

# ============================
# Log files and scan results
# ============================
*.log
logs/
results/

# Dev-script runtime files (process IDs and startup logs)
backend/scripts/dev/.pids/

# ============================
# Temporary files
# ============================
tmp/
temp/
.cache/

HGETALL
KEYS
13
.windsurf/rules/backend.md
Normal file
@@ -0,0 +1,13 @@
---
trigger: always_on
---

1. The backend web service runs on port 8888.
3. Add a trailing slash to every frontend route to match Django DRF's URL conventions.
4. Web endpoints can be tested with curl.
8. All frontend API calls belong in @services; all type definitions belong in @types.
10. Implement frontend loading (and similar) state with React Query, which manages it automatically.
17. Toasts for business operations belong in hooks.
19. For now, the backend project does not need dedicated security-hardening code.
23. Avoid window.location.href for frontend navigation unless necessary; use Next.js client-side routing instead.
24. For UI work, consult MCP first to see whether a shared, well-designed component already exists.
85
.windsurf/rules/code-preview.md
Normal file
@@ -0,0 +1,85 @@
---
trigger: manual
description: This rule must be invoked whenever a code review is performed
---

### **0. Logical correctness & bug hunting** *(highest priority; must be traced by hand)*

**Goal**: proactively find "the code runs but the result is wrong" logic errors, without relying on tests.

1. **Manually trace the key paths**:
    - Pick 2-3 representative inputs (including boundaries) and **step through the execution flow in your head or on paper**.
    - Does the output match expectations? Is every variable change correct at each step?
2. **Check for common logic bugs** (a short illustration follows this section):
    - **Off-by-one**: loops, array indexing, pagination
    - **Broken conditionals**: `and`/`or` precedence, misused short-circuit evaluation
    - **State confusion**: uninitialized variables, accidental overwrites
    - **Algorithmic drift**: sorting, searching, midpoint handling in binary search
    - **Floating-point precision**: is `==` misused to compare floats?
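Two of these bug classes in miniature, as a hypothetical Python sketch (not code from this repository):

```python
import math

# Off-by-one: range(len(xs) + 1) walks one index past the end.
def sum_buggy(xs):
    total = 0
    for i in range(len(xs) + 1):   # IndexError on the final iteration
        total += xs[i]
    return total

def sum_fixed(xs):
    return sum(xs[i] for i in range(len(xs)))

# Floating-point precision: `==` on floats is unreliable; use math.isclose.
assert 0.1 + 0.2 != 0.3                      # surprising but true in IEEE 754
assert math.isclose(0.1 + 0.2, 0.3)          # the robust comparison
```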
3. **Control-flow review**:
    - Are all `if/else` branches covered? Any unreachable code?
    - Does every `switch`/`match` have a `default`? Any missing cases?
    - What do error paths return? Is `finally` cleanup ever skipped?
4. **Business-logic consistency**:
    - Does the code match the **business rules**? (e.g. "order total = item price × quantity + shipping - discount")
    - Are implicit constraints honored? (e.g. "users may only review completed orders")

### **I. Functionality & correctness** *(blocking issues must be fixed)*

1. **Requirements coverage**: are the requirements covered 100%? Any missing or extraneous features?
2. **Boundary conditions**:
    - Inputs: `null`, empty, extreme values, malformed data
    - Collections: empty, single element, very large (e.g. 10⁶)
    - Loops: termination conditions, off-by-one
3. **Error handling**:
    - Is exception handling comprehensive? Do failure paths degrade gracefully?
    - Are error messages clear without leaking stack traces?
4. **Concurrency safety**:
    - Race or deadlock risks? Are shared resources synchronized?
    - Is `volatile`/`synchronized`/`Lock`/`atomic` used where needed?
5. **Unit tests**:
    - Coverage ≥ 80%? Positive, boundary, and failure cases included?
    - Are tests independent, with no external dependencies?

### **II. Code quality & readability**

1. **Naming**: self-explanatory? Conforms to conventions?
2. **Function design**:
    - **Single responsibility**? ≤ 4 parameters? Suggested length < 50 lines (adjust per language)
    - Can parts be extracted into utility functions?
3. **Structure & complexity**:
    - No duplicated code? Cyclomatic complexity < 10?
    - Nesting ≤ 3 levels? Use guard clauses to return early
4. **Comments**: explain **why**, not **what**; complex logic must be commented
5. **Consistent style**: auto-format with `Prettier`/`ESLint`/`Spotless`

### **III. Architecture & design**

1. **SOLID**: does the code honor single responsibility, open/closed, dependency inversion?
2. **Dependencies**: depend on interfaces rather than implementations? No circular dependencies?
3. **Testability**: does the design support dependency injection? Avoid hard-coded `new`
4. **Extensibility**: does adding a feature require changing only one place?

### **IV. Performance**

- **N+1 queries**? IO, logging, or allocation inside loops?
- Is the algorithmic complexity reasonable? (e.g. can O(n²) be improved?)
- Memory: no leaks? Large objects released promptly? Do caches expire?

### **V. Miscellaneous**

1. **Maintainability**: do logs carry context? Is the code cleaner after the change?
2. **Compatibility**: are API/database changes backward compatible?
3. **Dependency management**: is each new library necessary? Are licenses compliant?

---

### **Review best practices**

- **Review in small batches**: ≤ 200 lines per pass
- **Suggestive tone**: `"Consider splitting this function for readability"` rather than `"this function is too long"`
- **Automate first**: style / null-pointer / security scanning belongs in CI tools
- **Triage by severity**:
    - 🛑 **Blocking**: functional errors, security vulnerabilities
    - ⚠️ **Must fix**: design flaws, performance bottlenecks
    - 💡 **Suggestion**: style, naming, readability
195
.windsurf/rules/codelayering.md
Normal file
@@ -0,0 +1,195 @@
---
trigger: always_on
---

## Standard layered-architecture call order

Following **DDD (domain-driven design) and clean-architecture** principles, the call order should be:

```
HTTP request → Views → Tasks → Services → Repositories → Models
```

---

### 📊 Full call-chain diagram

```
┌─────────────────────────────────────────────────────────────┐
│                   HTTP Request (frontend)                    │
└────────────────────────┬────────────────────────────────────┘
                         ↓
┌─────────────────────────────────────────────────────────────┐
│ Views (HTTP layer)                                           │
│ - Parameter validation                                       │
│ - Permission checks                                          │
│ - Calls Tasks/Services                                       │
│ - Returns the HTTP response                                  │
└────────────────────────┬────────────────────────────────────┘
                         ↓
        ┌────────────────┴────────────────┐
        ↓ (async)                         ↓ (sync)
┌────────────────────┐          ┌──────────────────────┐
│ Tasks (task layer) │          │ Services (business)  │
│ - Async execution  │          │ - Business logic     │
│ - Background jobs  │─────────>│ - Transaction mgmt   │
│ - Notifications    │          │ - Data validation    │
└────────────────────┘          └──────────┬───────────┘
                                           ↓
                              ┌────────────────────────┐
                              │ Repositories (storage) │
                              │ - Data access          │
                              │ - Query encapsulation  │
                              │ - Bulk operations      │
                              └────────────┬───────────┘
                                           ↓
                              ┌────────────────────────┐
                              │ Models (model layer)   │
                              │ - ORM definitions      │
                              │ - Data structures      │
                              │ - Relationship mapping │
                              └────────────────────────┘
```

---

### 🔄 Concrete call examples

### **Scenario 1: synchronous delete (Views → Services → Repositories → Models)**

```python
# 1. Views layer (views.py)
def some_sync_delete(self, request):
    # Parameter validation
    target_ids = request.data.get('ids')

    # Call the Service layer
    service = TargetService()
    result = service.bulk_delete_targets(target_ids)

    # Return the response
    return Response({'message': 'deleted'})

# 2. Services layer (services/target_service.py)
class TargetService:
    def bulk_delete_targets(self, target_ids):
        # Business-logic validation
        logger.info("Preparing to delete...")

        # Call the Repository layer
        deleted_count = self.repo.bulk_delete_by_ids(target_ids)

        # Return the result
        return deleted_count

# 3. Repositories layer (repositories/django_target_repository.py)
class DjangoTargetRepository:
    def bulk_delete_by_ids(self, target_ids):
        # Data-access operation
        return Target.objects.filter(id__in=target_ids).delete()

# 4. Models layer (models.py)
class Target(models.Model):
    # ORM definition
    name = models.CharField(...)
```

---

### **Scenario 2: asynchronous delete (Views → Tasks → Services → Repositories → Models)**

```python
# 1. Views layer (views.py)
def destroy(self, request, *args, **kwargs):
    target = self.get_object()

    # Call the Tasks layer (async)
    async_bulk_delete_targets([target.id], [target.name])

    # Return 202 immediately
    return Response(status=202)

# 2. Tasks layer (tasks/target_tasks.py)
def async_bulk_delete_targets(target_ids, target_names):
    def _delete():
        # Send a notification
        create_notification("Deleting...")

        # Call the Service layer
        service = TargetService()
        result = service.bulk_delete_targets(target_ids)

        # Send a completion notification
        create_notification("Deleted successfully")

    # Run in a background thread
    threading.Thread(target=_delete).start()

# 3. Services layer (services/target_service.py)
class TargetService:
    def bulk_delete_targets(self, target_ids):
        # Business logic
        return self.repo.bulk_delete_by_ids(target_ids)

# 4. Repositories layer (repositories/django_target_repository.py)
class DjangoTargetRepository:
    def bulk_delete_by_ids(self, target_ids):
        # Data access
        return Target.objects.filter(id__in=target_ids).delete()

# 5. Models layer (models.py)
class Target(models.Model):
    # ORM definition
    ...
```

---

### 📋 Per-layer responsibilities

| Layer | Responsibilities | Must not do |
| --- | --- | --- |
| **Views** | HTTP request handling, parameter validation, permission checks | ❌ Access Models directly<br>❌ Business logic |
| **Tasks** | Async execution, background jobs, notifications | ❌ Access Models directly<br>❌ HTTP responses |
| **Services** | Business logic, transaction management, data validation | ❌ Raw SQL<br>❌ Anything HTTP-specific |
| **Repositories** | Data access, query encapsulation, bulk operations | ❌ Business logic<br>❌ Notifications |
| **Models** | ORM definitions, data structures, relationship mapping | ❌ Business logic<br>❌ Complex queries |

---

### ✅ Best-practice principles

1. **One-way dependencies**: call downward only, never upward

    ```
    Views → Tasks → Services → Repositories → Models
    (upper)                               (lower)
    ```

2. **Layer isolation**: adjacent layers interact; skipping layers is forbidden
    - ✅ Views → Services
    - ✅ Tasks → Services
    - ✅ Services → Repositories
    - ❌ Views → Repositories (skips a layer)
    - ❌ Tasks → Models (skips a layer)
3. **Dependency injection**: inject dependencies through the constructor

    ```python
    class TargetService:
        def __init__(self):
            self.repo = DjangoTargetRepository()  # injected
    ```

4. **Interface abstraction**: define interfaces with Protocol (a combined sketch follows)

    ```python
    class TargetRepository(Protocol):
        def bulk_delete_by_ids(self, ids): ...
    ```
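Taken together, principles 3 and 4 suggest accepting a Protocol-typed repository through the constructor. A minimal self-contained sketch; the parameterized constructor and the fake repository are illustrative, not this repo's actual implementation:

```python
from typing import Iterable, Protocol

class TargetRepository(Protocol):
    def bulk_delete_by_ids(self, ids: Iterable[int]) -> int: ...

class FakeTargetRepository:
    """In-memory stand-in for tests; no database required."""
    def __init__(self, ids: set[int]):
        self.ids = ids

    def bulk_delete_by_ids(self, ids: Iterable[int]) -> int:
        before = len(self.ids)
        self.ids -= set(ids)
        return before - len(self.ids)

class TargetService:
    # Accepting the repository as a parameter (instead of constructing it
    # inside) lets callers pass any object satisfying the Protocol.
    def __init__(self, repo: TargetRepository):
        self.repo = repo

    def bulk_delete_targets(self, ids: Iterable[int]) -> int:
        return self.repo.bulk_delete_by_ids(ids)

service = TargetService(FakeTargetRepository({1, 2, 3}))
assert service.bulk_delete_targets([1, 3]) == 2
```

Because the service depends only on the Protocol, swapping the real DjangoTargetRepository for the fake in tests needs no monkey-patching.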
131
LICENSE
Normal file
@@ -0,0 +1,131 @@
# PolyForm Noncommercial License 1.0.0

<https://polyformproject.org/licenses/noncommercial/1.0.0>

## Acceptance

In order to get any license under these terms, you must agree
to them as both strict obligations and conditions to all
your licenses.

## Copyright License

The licensor grants you a copyright license for the
software to do everything you might do with the software
that would otherwise infringe the licensor's copyright
in it for any permitted purpose. However, you may
only distribute the software according to [Distribution
License](#distribution-license) and make changes or new works
based on the software according to [Changes and New Works
License](#changes-and-new-works-license).

## Distribution License

The licensor grants you an additional copyright license
to distribute copies of the software. Your license
to distribute covers distributing the software with
changes and new works permitted by [Changes and New Works
License](#changes-and-new-works-license).

## Notices

You must ensure that anyone who gets a copy of any part of
the software from you also gets a copy of these terms or the
URL for them above, as well as copies of any plain-text lines
beginning with `Required Notice:` that the licensor provided
with the software. For example:

> Required Notice: Copyright Yuhang Yang (yyhuni)

## Changes and New Works License

The licensor grants you an additional copyright license to
make changes and new works based on the software for any
permitted purpose.

## Patent License

The licensor grants you a patent license for the software that
covers patent claims the licensor can license, or becomes able
to license, that you would infringe by using the software.

## Noncommercial Purposes

Any noncommercial purpose is a permitted purpose.

## Personal Uses

Personal use for research, experiment, and testing for
the benefit of public knowledge, personal study, private
entertainment, hobby projects, amateur pursuits, or religious
observance, without any anticipated commercial application,
is use for a permitted purpose.

## Noncommercial Organizations

Use by any charitable organization, educational institution,
public research organization, public safety or health
organization, environmental protection organization,
or government institution is use for a permitted purpose
regardless of the source of funding or obligations resulting
from the funding.

## Fair Use

You may have "fair use" rights for the software under the
law. These terms do not limit them.

## No Other Rights

These terms do not allow you to sublicense or transfer any of
your licenses to anyone else, or prevent the licensor from
granting licenses to anyone else. These terms do not imply
any other licenses.

## Patent Defense

If you make any written claim that the software infringes or
contributes to infringement of any patent, your patent license
for the software granted under these terms ends immediately. If
your company makes such a claim, your patent license ends
immediately for work on behalf of your company.

## Violations

The first time you are notified in writing that you have
violated any of these terms, or done anything with the software
not covered by your licenses, your licenses can nonetheless
continue if you come into full compliance with these terms,
and take practical steps to correct past violations, within
32 days of receiving notice. Otherwise, all your licenses
end immediately.

## No Liability

***As far as the law allows, the software comes as is, without
any warranty or condition, and the licensor will not be liable
to you for any damages arising out of these terms or the use
or nature of the software, under any kind of legal claim.***

## Definitions

The **licensor** is the individual or entity offering these
terms, and the **software** is the software the licensor makes
available under these terms.

**You** refers to the individual or entity agreeing to these
terms.

**Your company** is any legal entity, sole proprietorship,
or other kind of organization that you work for, plus all
organizations that have control over, are under the control of,
or are under common control with that organization. **Control**
means ownership of substantially all the assets of an entity,
or the power to direct its management and policies by vote,
contract, or otherwise. Control can be direct or indirect.

**Your licenses** are all the licenses granted to you for the
software under these terms.

**Use** means anything you do with the software requiring one
of your licenses.
140
README.md
Normal file
@@ -0,0 +1,140 @@
<h1 align="center">Xingrin - 星环</h1>

<p align="center">
  <b>A modern, enterprise-grade vulnerability-scanning and asset-management platform</b><br>
  Automated security testing, asset discovery, vulnerability management, and more
</p>

---

## ✨ Features

### 🎯 Target & asset management
- **Organization management** - multi-level target organizations with flexible grouping
- **Target management** - supports domains, IPs, URLs, and other target types
- **Asset discovery** - automatic discovery of subdomains, websites, endpoints, and directories
- **Asset snapshots** - compare scan-result snapshots to track asset changes

### 🔍 Vulnerability scanning
- **Multiple engines** - integrates Nuclei and other mainstream scanning engines
- **Custom pipelines** - scan workflows configured in YAML, flexibly orchestrated
- **Severity grading** - four levels: critical / high / medium / low
- **Scheduled scans** - cron-expression configuration for automated periodic scanning

### 🖥️ Distributed architecture
- **Worker nodes** - multi-node distributed scanning
- **Local/remote** - local Docker nodes plus SSH remote nodes
- **Load balancing** - automatic task dispatch and load monitoring
- **Live status** - scan progress pushed in real time over WebSocket

### 📊 Visual interface
- **Statistics** - asset/vulnerability dashboards
- **Real-time notifications** - WebSocket message push
- **Dark mode** - switchable light/dark themes

---

## 🛠️ Tech stack

- **Frontend**: Next.js + React + TailwindCSS
- **Backend**: Django + Django REST Framework
- **Database**: PostgreSQL + Redis
- **Deployment**: Docker + Nginx
- **Scan engine**: Nuclei

---

## 📦 Quick start

### Requirements

- Docker 20.10+
- Docker Compose 2.0+
- 2 CPU cores and 4 GB RAM recommended as a minimum
- 10 GB+ disk space

### One-command install

```bash
# Clone the project
git clone https://github.com/yyhuni/xingrin.git
cd xingrin

# Install and start (production mode)
sudo ./install.sh

# Development mode
sudo ./install.sh --dev
```

### Accessing the services

- **Web UI**: `https://localhost` or `http://localhost`
- **API**: `http://localhost:8888/api/`
- **API docs**: `http://localhost:8888/swagger/`

### Common commands

```bash
# Start services
sudo ./start.sh

# Stop services
sudo ./stop.sh

# Restart services
sudo ./restart.sh

# Uninstall
sudo ./uninstall.sh

# Update
sudo ./update.sh
```

## ⚠️ Disclaimer

**Important: read carefully before use**

1. This tool is intended solely for **authorized security testing** and **security research**
2. Users must ensure they have **legal authorization** for every target system
3. Using this tool for unauthorized penetration testing or attacks is **strictly forbidden**
4. Scanning systems without authorization is **illegal** and may carry legal liability
5. The developers accept **no responsibility for any misuse**

By using this tool you agree to:
- Use it only within the scope of legal authorization
- Comply with the laws and regulations of your jurisdiction
- Bear all consequences arising from misuse

## 📄 License

This project is licensed under the [PolyForm Noncommercial License 1.0.0](LICENSE).

### Permitted uses

- ✅ Personal study and research
- ✅ Noncommercial security testing
- ✅ Use by educational institutions
- ✅ Use by nonprofit organizations

### Prohibited uses

- ❌ **Commercial use** (including but not limited to resale, commercial services, SaaS, etc.)
- ❌ Unauthorized penetration testing
- ❌ Any illegal activity

For commercial licensing, please contact the author.

## 🤝 Feedback & contributing

- 🐛 **Found a bug?** Please open an [Issue](https://github.com/yyhuni/xingrin/issues)
- 💡 **Have an idea?** Feature suggestions are welcome
- 🔧 **Want to contribute?** Pull requests are welcome

## 📧 Contact

- GitHub: [@yyhuni](https://github.com/yyhuni)
- WeChat official account: **洋洋的小黑屋**

<img src="docs/wechat-qrcode.png" alt="WeChat official account QR code" width="200">
51
backend/.gitignore
vendored
Normal file
@@ -0,0 +1,51 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
*.egg-info/
dist/
build/

# Virtual environments
venv/
env/
ENV/

# Django
*.log
db.sqlite3
db.sqlite3-journal
/media
/staticfiles

# Runtime files (Flower, PIDs)
/var/*
!/var/.gitkeep
flower.db
pids/
script/dev/.pids/

# Scan results and logs (backend data)
/results/*
!/results/.gitkeep
/logs/*
!/logs/.gitkeep

# Environment variables (sensitive)
.env
.env.development
.env.production
.env.staging
# Only commit template files: .env.*.example

# IDE
.vscode/
.idea/
*.swp
*.swo
*~

# macOS
.DS_Store
0
backend/apps/__init__.py
Normal file
0
backend/apps/asset/__init__.py
Normal file
10
backend/apps/asset/apps.py
Normal file
@@ -0,0 +1,10 @@
from django.apps import AppConfig


class AssetConfig(AppConfig):
    default_auto_field = 'django.db.models.BigAutoField'
    name = 'apps.asset'

    def ready(self):
        # Import all models so Django discovers and registers them
        from . import models
28
backend/apps/asset/dtos/__init__.py
Normal file
@@ -0,0 +1,28 @@
"""Asset DTOs - data transfer objects"""

# Asset-module DTOs
from .asset import (
    SubdomainDTO,
    WebSiteDTO,
    IPAddressDTO,
    DirectoryDTO,
    PortDTO,
    EndpointDTO,
)

# Snapshot-module DTOs
from .snapshot import (
    SubdomainSnapshotDTO,
)

__all__ = [
    # Asset module
    'SubdomainDTO',
    'WebSiteDTO',
    'IPAddressDTO',
    'DirectoryDTO',
    'PortDTO',
    'EndpointDTO',
    # Snapshot module
    'SubdomainSnapshotDTO',
]
21
backend/apps/asset/dtos/asset/__init__.py
Normal file
@@ -0,0 +1,21 @@
"""Asset DTOs - data transfer objects"""

from .subdomain_dto import SubdomainDTO
from .ip_address_dto import IPAddressDTO
from .port_dto import PortDTO
from .website_dto import WebSiteDTO
from .directory_dto import DirectoryDTO
from .host_port_mapping_dto import HostPortMappingDTO
from .endpoint_dto import EndpointDTO
from .vulnerability_dto import VulnerabilityDTO

__all__ = [
    'SubdomainDTO',
    'IPAddressDTO',
    'PortDTO',
    'WebSiteDTO',
    'DirectoryDTO',
    'HostPortMappingDTO',
    'EndpointDTO',
    'VulnerabilityDTO',
]
18
backend/apps/asset/dtos/asset/directory_dto.py
Normal file
@@ -0,0 +1,18 @@
"""Directory DTO"""

from dataclasses import dataclass
from typing import Optional


@dataclass
class DirectoryDTO:
    """Directory data transfer object"""
    website_id: int
    target_id: int
    url: str
    status: Optional[int] = None
    content_length: Optional[int] = None
    words: Optional[int] = None
    lines: Optional[int] = None
    content_type: str = ''
    duration: Optional[int] = None
28
backend/apps/asset/dtos/asset/endpoint_dto.py
Normal file
@@ -0,0 +1,28 @@
"""Endpoint DTO"""

from dataclasses import dataclass
from typing import Optional, List


@dataclass
class EndpointDTO:
    """Endpoint DTO - data transfer object for the asset table"""
    target_id: int
    url: str
    host: Optional[str] = None
    title: Optional[str] = None
    status_code: Optional[int] = None
    content_length: Optional[int] = None
    webserver: Optional[str] = None
    body_preview: Optional[str] = None
    content_type: Optional[str] = None
    tech: Optional[List[str]] = None
    vhost: Optional[bool] = None
    location: Optional[str] = None
    matched_gf_patterns: Optional[List[str]] = None

    def __post_init__(self):
        if self.tech is None:
            self.tech = []
        if self.matched_gf_patterns is None:
            self.matched_gf_patterns = []
12
backend/apps/asset/dtos/asset/host_port_mapping_dto.py
Normal file
@@ -0,0 +1,12 @@
"""HostPortMapping DTO"""

from dataclasses import dataclass


@dataclass
class HostPortMappingDTO:
    """Host-port mapping DTO (asset table)"""
    target_id: int
    host: str
    ip: str
    port: int
17
backend/apps/asset/dtos/asset/ip_address_dto.py
Normal file
@@ -0,0 +1,17 @@
"""IPAddress DTO"""

from dataclasses import dataclass


@dataclass
class IPAddressDTO:
    """
    IP-address data transfer object

    Contains only the IP's own information, not its relationships.
    Relationships are managed via SubdomainIPAssociationDTO.
    """
    ip: str
    protocol_version: str = ''
    is_private: bool = False
    reverse_pointer: str = ''
13
backend/apps/asset/dtos/asset/port_dto.py
Normal file
@@ -0,0 +1,13 @@
"""Port DTO"""

from dataclasses import dataclass
from typing import Optional


@dataclass
class PortDTO:
    """Port data transfer object"""
    ip_address_id: int
    number: int
    service_name: str = ''
    target_id: Optional[int] = None
    scan_id: Optional[int] = None
15
backend/apps/asset/dtos/asset/subdomain_dto.py
Normal file
@@ -0,0 +1,15 @@
"""Subdomain DTO"""

from dataclasses import dataclass


@dataclass
class SubdomainDTO:
    """
    Subdomain DTO (pure asset table)

    Carries subdomain asset data; contains only the asset's own information.
    Scan-related information is stored in the snapshot tables.
    """
    name: str
    target_id: int  # Required: a subdomain must belong to a target
18
backend/apps/asset/dtos/asset/vulnerability_dto.py
Normal file
@@ -0,0 +1,18 @@
"""Vulnerability DTO"""

from dataclasses import dataclass, field
from typing import Optional, Dict, Any
from decimal import Decimal


@dataclass
class VulnerabilityDTO:
    """Vulnerability data transfer object (for the asset table)"""
    target_id: int
    url: str
    vuln_type: str
    severity: str
    source: str = ""
    cvss_score: Optional[Decimal] = None
    description: str = ""
    raw_output: Dict[str, Any] = field(default_factory=dict)
26
backend/apps/asset/dtos/asset/website_dto.py
Normal file
@@ -0,0 +1,26 @@
"""WebSite DTO"""

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class WebSiteDTO:
    """Website data transfer object"""
    target_id: int
    url: str
    host: str
    title: str = ''
    status_code: Optional[int] = None
    content_length: Optional[int] = None
    location: str = ''
    webserver: str = ''
    content_type: str = ''
    tech: Optional[List[str]] = None
    body_preview: str = ''
    vhost: Optional[bool] = None
    created_at: Optional[str] = None

    def __post_init__(self):
        if self.tech is None:
            self.tech = []
17
backend/apps/asset/dtos/snapshot/__init__.py
Normal file
@@ -0,0 +1,17 @@
"""Snapshot DTOs"""

from .subdomain_snapshot_dto import SubdomainSnapshotDTO
from .host_port_mapping_snapshot_dto import HostPortMappingSnapshotDTO
from .website_snapshot_dto import WebsiteSnapshotDTO
from .directory_snapshot_dto import DirectorySnapshotDTO
from .endpoint_snapshot_dto import EndpointSnapshotDTO
from .vulnerability_snapshot_dto import VulnerabilitySnapshotDTO

__all__ = [
    'SubdomainSnapshotDTO',
    'HostPortMappingSnapshotDTO',
    'WebsiteSnapshotDTO',
    'DirectorySnapshotDTO',
    'EndpointSnapshotDTO',
    'VulnerabilitySnapshotDTO',
]
48
backend/apps/asset/dtos/snapshot/directory_snapshot_dto.py
Normal file
@@ -0,0 +1,48 @@
"""Directory Snapshot DTO"""

from dataclasses import dataclass
from typing import Optional
from apps.asset.dtos.asset import DirectoryDTO


@dataclass
class DirectorySnapshotDTO:
    """
    Directory-snapshot data transfer object

    Used to save directories discovered during a scan into the snapshot table.

    Note: website_id and target_id are only used to carry data and to convert
    to the asset DTO; they are not saved to the snapshot table.
    A snapshot belongs only to a scan.
    """
    scan_id: int
    website_id: int  # Carried for conversion only; not persisted
    target_id: int  # Carried for conversion only; not persisted
    url: str
    status: Optional[int] = None
    content_length: Optional[int] = None
    words: Optional[int] = None
    lines: Optional[int] = None
    content_type: str = ''
    duration: Optional[int] = None

    def to_asset_dto(self) -> DirectoryDTO:
        """
        Convert to the asset DTO (for syncing into the asset table)

        Note: the scan_id field is dropped because the asset table does not need it.

        Returns:
            DirectoryDTO: asset-table DTO
        """
        return DirectoryDTO(
            website_id=self.website_id,
            target_id=self.target_id,
            url=self.url,
            status=self.status,
            content_length=self.content_length,
            words=self.words,
            lines=self.lines,
            content_type=self.content_type,
            duration=self.duration
        )
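A brief usage sketch of the snapshot-to-asset conversion pattern shared by these DTOs; the literal values are invented for illustration:

```python
# Build a snapshot row as a scan task would, then derive the asset-table DTO.
snap = DirectorySnapshotDTO(
    scan_id=42,                        # hypothetical scan
    website_id=7,
    target_id=3,
    url="https://example.com/admin/",
    status=200,
)
asset = snap.to_asset_dto()            # DirectoryDTO: same data, scan_id dropped
assert asset.target_id == 3 and not hasattr(asset, "scan_id")
```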
62
backend/apps/asset/dtos/snapshot/endpoint_snapshot_dto.py
Normal file
@@ -0,0 +1,62 @@
"""EndpointSnapshot DTO"""

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class EndpointSnapshotDTO:
    """
    Endpoint-snapshot DTO

    Note: target_id is only used to carry data and to convert to the asset DTO;
    it is not saved to the snapshot table. A snapshot belongs only to a scan.
    """
    scan_id: int
    url: str
    host: str = ''  # Hostname (domain or IP address)
    title: str = ''
    status_code: Optional[int] = None
    content_length: Optional[int] = None
    location: str = ''
    webserver: str = ''
    content_type: str = ''
    tech: Optional[List[str]] = None
    body_preview: str = ''
    vhost: Optional[bool] = None
    matched_gf_patterns: Optional[List[str]] = None
    target_id: Optional[int] = None  # Redundant field used for syncing to the asset table

    def __post_init__(self):
        if self.tech is None:
            self.tech = []
        if self.matched_gf_patterns is None:
            self.matched_gf_patterns = []

    def to_asset_dto(self):
        """
        Convert to the asset DTO (for syncing into the asset table)

        Returns:
            EndpointDTO: asset-table DTO (scan_id removed)
        """
        from apps.asset.dtos.asset import EndpointDTO

        if self.target_id is None:
            raise ValueError("target_id must not be None; cannot sync to the asset table")

        return EndpointDTO(
            target_id=self.target_id,
            url=self.url,
            host=self.host,
            title=self.title,
            status_code=self.status_code,
            content_length=self.content_length,
            webserver=self.webserver,
            body_preview=self.body_preview,
            content_type=self.content_type,
            tech=self.tech if self.tech else [],
            vhost=self.vhost,
            location=self.location,
            matched_gf_patterns=self.matched_gf_patterns if self.matched_gf_patterns else []
        )
33
backend/apps/asset/dtos/snapshot/host_port_mapping_snapshot_dto.py
Normal file
@@ -0,0 +1,33 @@
"""HostPortMappingSnapshot DTO"""

from dataclasses import dataclass
from typing import Optional


@dataclass
class HostPortMappingSnapshotDTO:
    """Host-port-mapping snapshot DTO"""
    scan_id: int
    host: str
    ip: str
    port: int
    target_id: Optional[int] = None  # Redundant field used for syncing to the asset table

    def to_asset_dto(self):
        """
        Convert to the asset DTO (for syncing into the asset table)

        Returns:
            HostPortMappingDTO: asset-table DTO (scan_id removed)
        """
        from apps.asset.dtos.asset import HostPortMappingDTO

        if self.target_id is None:
            raise ValueError("target_id must not be None; cannot sync to the asset table")

        return HostPortMappingDTO(
            target_id=self.target_id,
            host=self.host,
            ip=self.ip,
            port=self.port
        )
34
backend/apps/asset/dtos/snapshot/subdomain_snapshot_dto.py
Normal file
@@ -0,0 +1,34 @@
"""SubdomainSnapshot DTO"""

from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from apps.asset.dtos import SubdomainDTO


@dataclass
class SubdomainSnapshotDTO:
    """
    Subdomain-snapshot DTO

    Carries snapshot data together with its full business context.
    The snapshot table records the historical data of each scan.
    """
    name: str
    scan_id: int  # Required: a snapshot must be tied to a scan task
    target_id: int  # Required: target ID (used for conversion to the asset DTO)

    def to_asset_dto(self) -> 'SubdomainDTO':
        """
        Convert to the asset DTO (for saving into the asset table)

        Returns:
            SubdomainDTO: asset DTO (without scan_id)

        Note:
            The asset table stores only core data; the scan context (scan_id)
            is not saved there. target_id is already on the DTO, so no extra
            argument is needed.
        """
        from apps.asset.dtos import SubdomainDTO
        return SubdomainDTO(name=self.name, target_id=self.target_id)
42
backend/apps/asset/dtos/snapshot/vulnerability_snapshot_dto.py
Normal file
@@ -0,0 +1,42 @@
"""VulnerabilitySnapshot DTO"""

from dataclasses import dataclass, field
from typing import Optional, Dict, Any
from decimal import Decimal


@dataclass
class VulnerabilitySnapshotDTO:
    """Vulnerability-snapshot DTO

    Mirrors the VulnerabilitySnapshot model; used to pass vulnerability results
    between Services and Tasks.

    The design matches the other snapshot DTOs:
    - scan_id: belongs only to the snapshot table
    - target_id: used only for conversion to the asset DTO; not stored in the snapshot table
    """

    scan_id: int
    target_id: int  # Carried for conversion to the asset DTO only; not persisted in the snapshot table
    url: str
    vuln_type: str
    severity: str
    source: str = ""
    cvss_score: Optional[Decimal] = None
    description: str = ""
    raw_output: Dict[str, Any] = field(default_factory=dict)

    def to_asset_dto(self):
        """Convert to the vulnerability asset DTO (for syncing into the Vulnerability table)."""
        from apps.asset.dtos.asset import VulnerabilityDTO

        return VulnerabilityDTO(
            target_id=self.target_id,
            url=self.url,
            vuln_type=self.vuln_type,
            severity=self.severity,
            source=self.source,
            cvss_score=self.cvss_score,
            description=self.description,
            raw_output=self.raw_output,
        )
55
backend/apps/asset/dtos/snapshot/website_snapshot_dto.py
Normal file
@@ -0,0 +1,55 @@
"""WebsiteSnapshot DTO"""

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class WebsiteSnapshotDTO:
    """
    Website-snapshot DTO

    Note: target_id is only used to carry data and to convert to the asset DTO;
    it is not saved to the snapshot table. A snapshot belongs only to a scan;
    target information is reached via scan.target.
    """
    scan_id: int
    target_id: int  # Carried for conversion only; not persisted
    url: str
    host: str
    title: str = ''
    status: Optional[int] = None
    content_length: Optional[int] = None
    location: str = ''
    web_server: str = ''
    content_type: str = ''
    tech: Optional[List[str]] = None
    body_preview: str = ''
    vhost: Optional[bool] = None

    def __post_init__(self):
        if self.tech is None:
            self.tech = []

    def to_asset_dto(self):
        """
        Convert to the asset DTO (for syncing into the asset table)

        Returns:
            WebSiteDTO: asset-table DTO (scan_id removed)
        """
        from apps.asset.dtos.asset import WebSiteDTO

        return WebSiteDTO(
            target_id=self.target_id,
            url=self.url,
            host=self.host,
            title=self.title,
            status_code=self.status,
            content_length=self.content_length,
            location=self.location,
            webserver=self.web_server,
            content_type=self.content_type,
            tech=self.tech if self.tech else [],
            body_preview=self.body_preview,
            vhost=self.vhost
        )
0
backend/apps/asset/migrations/__init__.py
Normal file
45
backend/apps/asset/models/__init__.py
Normal file
@@ -0,0 +1,45 @@
# Import all models so Django can discover them

# Business models
from .asset_models import (
    Subdomain,
    WebSite,
    Endpoint,
    Directory,
    HostPortMapping,
    Vulnerability,
)

# Snapshot models
from .snapshot_models import (
    SubdomainSnapshot,
    WebsiteSnapshot,
    DirectorySnapshot,
    HostPortMappingSnapshot,
    EndpointSnapshot,
    VulnerabilitySnapshot,
)

# Statistics models
from .statistics_models import AssetStatistics, StatisticsHistory

# Export all models for external imports
__all__ = [
    # Business models
    'Subdomain',
    'WebSite',
    'Endpoint',
    'Directory',
    'HostPortMapping',
    'Vulnerability',
    # Snapshot models
    'SubdomainSnapshot',
    'WebsiteSnapshot',
    'DirectorySnapshot',
    'HostPortMappingSnapshot',
    'EndpointSnapshot',
    'VulnerabilitySnapshot',
    # Statistics models
    'AssetStatistics',
    'StatisticsHistory',
]
515
backend/apps/asset/models/asset_models.py
Normal file
@@ -0,0 +1,515 @@
from django.db import models
from django.contrib.postgres.fields import ArrayField
from django.core.validators import MinValueValidator, MaxValueValidator


class SoftDeleteManager(models.Manager):
    """Soft-delete manager: by default returns only records that have not been deleted"""

    def get_queryset(self):
        return super().get_queryset().filter(deleted_at__isnull=True)

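# Editor's illustrative sketch (not part of this file): how the two managers
# declared on the models below behave. `Subdomain` stands in for any of the
# soft-deletable models; `timezone` is django.utils.timezone.
#
#   sub = Subdomain.objects.first()
#   sub.deleted_at = timezone.now()                   # soft delete: mark, don't remove
#   sub.save(update_fields=["deleted_at"])
#
#   Subdomain.objects.filter(pk=sub.pk).exists()      # False - default manager hides it
#   Subdomain.all_objects.filter(pk=sub.pk).exists()  # True - full manager still sees it
#
#   Subdomain.all_objects.filter(pk=sub.pk).delete()  # hard delete goes through all_objects
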
class Subdomain(models.Model):
|
||||
"""
|
||||
子域名模型(纯资产表)
|
||||
|
||||
设计特点:
|
||||
- 只存储子域名资产信息
|
||||
- 与其他资产表(IPAddress、Port)无直接关联
|
||||
- 扫描历史记录存储在 SubdomainSnapshot 快照表中
|
||||
"""
|
||||
|
||||
id = models.AutoField(primary_key=True)
|
||||
target = models.ForeignKey(
|
||||
'targets.Target', # 使用字符串引用避免循环导入
|
||||
on_delete=models.CASCADE,
|
||||
related_name='subdomains',
|
||||
help_text='所属的扫描目标(主关联字段,表示所属关系,不能为空)'
|
||||
)
|
||||
name = models.CharField(max_length=1000, help_text='子域名名称')
|
||||
discovered_at = models.DateTimeField(auto_now_add=True, help_text='首次发现时间')
|
||||
|
||||
# ==================== 软删除字段 ====================
|
||||
deleted_at = models.DateTimeField(null=True, blank=True, db_index=True, help_text='删除时间(NULL表示未删除)')
|
||||
|
||||
# ==================== 管理器 ====================
|
||||
objects = SoftDeleteManager() # 默认管理器:只返回未删除的记录
|
||||
all_objects = models.Manager() # 全量管理器:包括已删除的记录(用于硬删除)
|
||||
|
||||
class Meta:
|
||||
db_table = 'subdomain'
|
||||
verbose_name = '子域名'
|
||||
verbose_name_plural = '子域名'
|
||||
ordering = ['-discovered_at']
|
||||
indexes = [
|
||||
models.Index(fields=['-discovered_at']),
|
||||
models.Index(fields=['name', 'target']), # 复合索引,优化 get_by_names_and_target_id 批量查询
|
||||
models.Index(fields=['target']), # 优化从target_id快速查找下面的子域名
|
||||
models.Index(fields=['name']), # 优化从name快速查找子域名,搜索场景
|
||||
models.Index(fields=['deleted_at', '-discovered_at']), # 软删除 + 时间索引
|
||||
]
|
||||
constraints = [
|
||||
# 部分唯一约束:只对未删除记录生效
|
||||
models.UniqueConstraint(
|
||||
fields=['name', 'target'],
|
||||
condition=models.Q(deleted_at__isnull=True),
|
||||
name='unique_name_target_active'
|
||||
)
|
||||
]
|
||||
|
||||
def __str__(self):
|
||||
return str(self.name or f'Subdomain {self.id}')
|
||||
|
||||
|
||||
class Endpoint(models.Model):
|
||||
"""端点模型"""
|
||||
|
||||
id = models.AutoField(primary_key=True)
|
||||
target = models.ForeignKey(
|
||||
'targets.Target', # 使用字符串引用
|
||||
on_delete=models.CASCADE,
|
||||
related_name='endpoints',
|
||||
help_text='所属的扫描目标(主关联字段,表示所属关系,不能为空)'
|
||||
)
|
||||
|
||||
url = models.CharField(max_length=2000, help_text='最终访问的完整URL')
|
||||
host = models.CharField(
|
||||
max_length=253,
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='主机名(域名或IP地址)'
|
||||
)
|
||||
location = models.CharField(
|
||||
max_length=1000,
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='重定向地址(HTTP 3xx 响应头 Location)'
|
||||
)
|
||||
discovered_at = models.DateTimeField(auto_now_add=True, help_text='发现时间')
|
||||
title = models.CharField(
|
||||
max_length=1000,
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='网页标题(HTML <title> 标签内容)'
|
||||
)
|
||||
webserver = models.CharField(
|
||||
max_length=200,
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='服务器类型(HTTP 响应头 Server 值)'
|
||||
)
|
||||
body_preview = models.CharField(
|
||||
max_length=1000,
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='响应正文前N个字符(默认100个字符)'
|
||||
)
|
||||
content_type = models.CharField(
|
||||
max_length=200,
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='响应类型(HTTP Content-Type 响应头)'
|
||||
)
|
||||
tech = ArrayField(
|
||||
models.CharField(max_length=100),
|
||||
blank=True,
|
||||
default=list,
|
||||
help_text='技术栈(服务器/框架/语言等)'
|
||||
)
|
||||
status_code = models.IntegerField(
|
||||
null=True,
|
||||
blank=True,
|
||||
help_text='HTTP状态码'
|
||||
)
|
||||
content_length = models.IntegerField(
|
||||
null=True,
|
||||
blank=True,
|
||||
help_text='响应体大小(单位字节)'
|
||||
)
|
||||
vhost = models.BooleanField(
|
||||
null=True,
|
||||
blank=True,
|
||||
help_text='是否支持虚拟主机'
|
||||
)
|
||||
matched_gf_patterns = ArrayField(
|
||||
models.CharField(max_length=100),
|
||||
blank=True,
|
||||
default=list,
|
||||
help_text='匹配的GF模式列表,用于识别敏感端点(如api, debug, config等)'
|
||||
)
|
||||
|
||||
# ==================== 软删除字段 ====================
|
||||
deleted_at = models.DateTimeField(null=True, blank=True, db_index=True, help_text='删除时间(NULL表示未删除)')
|
||||
|
||||
# ==================== 管理器 ====================
|
||||
objects = SoftDeleteManager() # 默认管理器:只返回未删除的记录
|
||||
all_objects = models.Manager() # 全量管理器:包括已删除的记录(用于硬删除)
|
||||
|
||||
class Meta:
|
||||
db_table = 'endpoint'
|
||||
verbose_name = '端点'
|
||||
verbose_name_plural = '端点'
|
||||
ordering = ['-discovered_at']
|
||||
indexes = [
|
||||
models.Index(fields=['-discovered_at']),
|
||||
models.Index(fields=['target']), # 优化从target_id快速查找下面的端点(主关联字段)
|
||||
models.Index(fields=['url']), # URL索引,优化查询性能
|
||||
models.Index(fields=['host']), # host索引,优化根据主机名查询
|
||||
models.Index(fields=['status_code']), # 状态码索引,优化筛选
|
||||
models.Index(fields=['deleted_at', '-discovered_at']), # 软删除 + 时间索引
|
||||
]
|
||||
constraints = [
|
||||
# 部分唯一约束:只对未删除记录生效
|
||||
models.UniqueConstraint(
|
||||
fields=['url', 'target'],
|
||||
condition=models.Q(deleted_at__isnull=True),
|
||||
name='unique_endpoint_url_target_active'
|
||||
)
|
||||
]
|
||||
|
||||
def __str__(self):
|
||||
return str(self.url or f'Endpoint {self.id}')
|
||||
|
||||
|
||||
class WebSite(models.Model):
|
||||
"""站点模型"""
|
||||
|
||||
id = models.AutoField(primary_key=True)
|
||||
target = models.ForeignKey(
|
||||
'targets.Target', # 使用字符串引用
|
||||
on_delete=models.CASCADE,
|
||||
related_name='websites',
|
||||
help_text='所属的扫描目标(主关联字段,表示所属关系,不能为空)'
|
||||
)
|
||||
|
||||
url = models.CharField(max_length=2000, help_text='最终访问的完整URL')
|
||||
host = models.CharField(
|
||||
max_length=253,
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='主机名(域名或IP地址)'
|
||||
)
|
||||
location = models.CharField(
|
||||
max_length=1000,
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='重定向地址(HTTP 3xx 响应头 Location)'
|
||||
)
|
||||
discovered_at = models.DateTimeField(auto_now_add=True, help_text='发现时间')
|
||||
title = models.CharField(
|
||||
max_length=1000,
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='网页标题(HTML <title> 标签内容)'
|
||||
)
|
||||
webserver = models.CharField(
|
||||
max_length=200,
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='服务器类型(HTTP 响应头 Server 值)'
|
||||
)
|
||||
body_preview = models.CharField(
|
||||
max_length=1000,
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='响应正文前N个字符(默认100个字符)'
|
||||
)
|
||||
content_type = models.CharField(
|
||||
max_length=200,
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='响应类型(HTTP Content-Type 响应头)'
|
||||
)
|
||||
tech = ArrayField(
|
||||
models.CharField(max_length=100),
|
||||
blank=True,
|
||||
default=list,
|
||||
help_text='技术栈(服务器/框架/语言等)'
|
||||
)
|
||||
status_code = models.IntegerField(
|
||||
null=True,
|
||||
blank=True,
|
||||
help_text='HTTP状态码'
|
||||
)
|
||||
content_length = models.IntegerField(
|
||||
null=True,
|
||||
blank=True,
|
||||
help_text='响应体大小(单位字节)'
|
||||
)
|
||||
vhost = models.BooleanField(
|
||||
null=True,
|
||||
blank=True,
|
||||
help_text='是否支持虚拟主机'
|
||||
)
|
||||
|
||||
# ==================== 软删除字段 ====================
|
||||
deleted_at = models.DateTimeField(null=True, blank=True, db_index=True, help_text='删除时间(NULL表示未删除)')
|
||||
|
||||
# ==================== 管理器 ====================
|
||||
objects = SoftDeleteManager() # 默认管理器:只返回未删除的记录
|
||||
all_objects = models.Manager() # 全量管理器:包括已删除的记录(用于硬删除)
|
||||
|
||||
class Meta:
|
||||
db_table = 'website'
|
||||
verbose_name = '站点'
|
||||
verbose_name_plural = '站点'
|
||||
ordering = ['-discovered_at']
|
||||
indexes = [
|
||||
models.Index(fields=['-discovered_at']),
|
||||
models.Index(fields=['url']), # URL索引,优化查询性能
|
||||
models.Index(fields=['host']), # host索引,优化根据主机名查询
|
||||
models.Index(fields=['target']), # 优化从target_id快速查找下面的站点
|
||||
models.Index(fields=['deleted_at', '-discovered_at']), # 软删除 + 时间索引
|
||||
]
|
||||
constraints = [
|
||||
# 部分唯一约束:只对未删除记录生效
|
||||
models.UniqueConstraint(
|
||||
fields=['url', 'target'],
|
||||
condition=models.Q(deleted_at__isnull=True),
|
||||
name='unique_website_url_target_active'
|
||||
)
|
||||
]
|
||||
|
||||
def __str__(self):
|
||||
return str(self.url or f'Website {self.id}')
|
||||
|
||||
|
||||
class Directory(models.Model):
|
||||
"""
|
||||
目录模型
|
||||
"""
|
||||
|
||||
id = models.AutoField(primary_key=True)
|
||||
website = models.ForeignKey(
|
||||
'Website',
|
||||
on_delete=models.CASCADE,
|
||||
related_name='directories',
|
||||
help_text='所属的站点(主关联字段,表示所属关系,不能为空)'
|
||||
)
|
||||
target = models.ForeignKey(
|
||||
'targets.Target', # 使用字符串引用
|
||||
on_delete=models.CASCADE,
|
||||
related_name='directories',
|
||||
null=True,
|
||||
blank=True,
|
||||
help_text='所属的扫描目标(冗余字段,用于快速查询)'
|
||||
)
|
||||
|
||||
url = models.CharField(
|
||||
null=False,
|
||||
blank=False,
|
||||
max_length=2000,
|
||||
help_text='完整请求 URL'
|
||||
)
|
||||
status = models.IntegerField(
|
||||
null=True,
|
||||
blank=True,
|
||||
help_text='HTTP 响应状态码'
|
||||
)
|
||||
content_length = models.BigIntegerField(
|
||||
null=True,
|
||||
blank=True,
|
||||
help_text='响应体字节大小(Content-Length 或实际长度)'
|
||||
)
|
||||
words = models.IntegerField(
|
||||
null=True,
|
||||
blank=True,
|
||||
help_text='响应体中单词数量(按空格分割)'
|
||||
)
|
||||
lines = models.IntegerField(
|
||||
null=True,
|
||||
blank=True,
|
||||
help_text='响应体行数(按换行符分割)'
|
||||
)
|
||||
content_type = models.CharField(
|
||||
max_length=200,
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='响应头 Content-Type 值'
|
||||
)
|
||||
duration = models.BigIntegerField(
|
||||
null=True,
|
||||
blank=True,
|
||||
help_text='请求耗时(单位:纳秒)'
|
||||
)
|
||||
|
||||
discovered_at = models.DateTimeField(auto_now_add=True, help_text='发现时间')
|
||||
|
||||
# ==================== 软删除字段 ====================
|
||||
deleted_at = models.DateTimeField(null=True, blank=True, db_index=True, help_text='删除时间(NULL表示未删除)')
|
||||
|
||||
# ==================== 管理器 ====================
|
||||
objects = SoftDeleteManager() # 默认管理器:只返回未删除的记录
|
||||
all_objects = models.Manager() # 全量管理器:包括已删除的记录(用于硬删除)
|
||||
|
||||
class Meta:
|
||||
db_table = 'directory'
|
||||
verbose_name = '目录'
|
||||
verbose_name_plural = '目录'
|
||||
ordering = ['-discovered_at']
|
||||
indexes = [
|
||||
models.Index(fields=['-discovered_at']),
|
||||
models.Index(fields=['target']), # 优化从target_id快速查找下面的目录
|
||||
models.Index(fields=['url']), # URL索引,优化搜索和唯一约束
|
||||
models.Index(fields=['website']), # 站点索引,优化按站点查询
|
||||
models.Index(fields=['status']), # 状态码索引,优化筛选
|
||||
models.Index(fields=['deleted_at', '-discovered_at']), # 软删除 + 时间索引
|
||||
]
|
||||
constraints = [
|
||||
# 部分唯一约束:只对未删除记录生效
|
||||
models.UniqueConstraint(
|
||||
fields=['website', 'url'],
|
||||
condition=models.Q(deleted_at__isnull=True),
|
||||
name='unique_directory_url_website_active'
|
||||
),
|
||||
]
|
||||
|
||||
def __str__(self):
|
||||
return str(self.url or f'Directory {self.id}')
|
||||
|
||||
|
||||
class HostPortMapping(models.Model):
|
||||
"""
|
||||
主机端口映射表
|
||||
|
||||
设计特点:
|
||||
- 存储主机(host)、IP、端口的三元映射关系
|
||||
- 只关联 target_id,不关联其他资产表
|
||||
- target + host + ip + port 组成复合唯一约束
|
||||
"""
|
||||
|
||||
id = models.AutoField(primary_key=True)
|
||||
|
||||
# ==================== 关联字段 ====================
|
||||
target = models.ForeignKey(
|
||||
'targets.Target',
|
||||
on_delete=models.CASCADE,
|
||||
related_name='host_port_mappings',
|
||||
help_text='所属的扫描目标'
|
||||
)
|
||||
|
||||
# ==================== 核心字段 ====================
|
||||
host = models.CharField(
|
||||
max_length=1000,
|
||||
blank=False,
|
||||
help_text='主机名(域名或IP)'
|
||||
)
|
||||
ip = models.GenericIPAddressField(
|
||||
blank=False,
|
||||
help_text='IP地址'
|
||||
)
|
||||
port = models.IntegerField(
|
||||
blank=False,
|
||||
validators=[
|
||||
MinValueValidator(1, message='端口号必须大于等于1'),
|
||||
MaxValueValidator(65535, message='端口号必须小于等于65535')
|
||||
],
|
||||
help_text='端口号(1-65535)'
|
||||
)
|
||||
|
||||
    # ==================== Timestamp fields ====================
    discovered_at = models.DateTimeField(
        auto_now_add=True,
        help_text='Discovery time'
    )

    # ==================== Soft-delete field ====================
    deleted_at = models.DateTimeField(
        null=True,
        blank=True,
        db_index=True,
        help_text='Deletion time (NULL means not deleted)'
    )

    # ==================== Managers ====================
    objects = SoftDeleteManager()  # Default manager: returns only non-deleted records
    all_objects = models.Manager()  # Unfiltered manager: includes deleted records (used for hard deletes)

    class Meta:
        db_table = 'host_port_mapping'
        verbose_name = 'Host-port mapping'
        verbose_name_plural = 'Host-port mappings'
        ordering = ['-discovered_at']
        indexes = [
            models.Index(fields=['target']),  # Speeds up lookups by target
            models.Index(fields=['host']),  # Speeds up lookups by hostname
            models.Index(fields=['ip']),  # Speeds up lookups by IP
            models.Index(fields=['port']),  # Speeds up lookups by port
            models.Index(fields=['host', 'ip']),  # Speeds up combined lookups
            models.Index(fields=['-discovered_at']),  # Speeds up time-based ordering
            models.Index(fields=['deleted_at', '-discovered_at']),  # Soft delete + time index
        ]
        constraints = [
            # Composite unique constraint: target + host + ip + port must be unique
            # (enforced only for non-deleted records)
            models.UniqueConstraint(
                fields=['target', 'host', 'ip', 'port'],
                condition=models.Q(deleted_at__isnull=True),
                name='unique_target_host_ip_port_active'
            ),
        ]

    def __str__(self):
        return f'{self.host} ({self.ip}:{self.port})'


class Vulnerability(models.Model):
    """
    Vulnerability model (asset table)

    Stores discovered vulnerability assets, linked to a Target.
    Scan history is stored in the VulnerabilitySnapshot table.
    """

    # Deferred import to avoid a circular dependency
    from apps.common.definitions import VulnSeverity

    id = models.AutoField(primary_key=True)
    target = models.ForeignKey(
        'targets.Target',
        on_delete=models.CASCADE,
        related_name='vulnerabilities',
        help_text='Owning scan target'
    )

    # ==================== Core fields ====================
    url = models.TextField(help_text='URL where the vulnerability was found')
    vuln_type = models.CharField(max_length=100, help_text='Vulnerability type (e.g. xss, sqli)')
    severity = models.CharField(
        max_length=20,
        choices=VulnSeverity.choices,
        default=VulnSeverity.UNKNOWN,
        help_text='Severity (unknown/info/low/medium/high/critical)'
    )
    source = models.CharField(max_length=50, blank=True, default='', help_text='Source tool (e.g. dalfox, nuclei, crlfuzz)')
    cvss_score = models.DecimalField(max_digits=3, decimal_places=1, null=True, blank=True, help_text='CVSS score (0.0-10.0)')
    description = models.TextField(blank=True, default='', help_text='Vulnerability description')
    raw_output = models.JSONField(blank=True, default=dict, help_text='Raw tool output')

    # ==================== Timestamp fields ====================
    discovered_at = models.DateTimeField(auto_now_add=True, help_text='First discovered at')

    # ==================== Soft-delete field ====================
    deleted_at = models.DateTimeField(null=True, blank=True, db_index=True, help_text='Deletion time (NULL means not deleted)')

    # ==================== Managers ====================
    objects = SoftDeleteManager()
    all_objects = models.Manager()

    class Meta:
        db_table = 'vulnerability'
        verbose_name = 'Vulnerability'
        verbose_name_plural = 'Vulnerabilities'
        ordering = ['-discovered_at']
        indexes = [
            models.Index(fields=['target']),
            models.Index(fields=['vuln_type']),
            models.Index(fields=['severity']),
            models.Index(fields=['source']),
            models.Index(fields=['-discovered_at']),
        ]

    def __str__(self):
        return f'{self.vuln_type} - {self.url[:50]}'
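Editor's note: the soft-delete pattern above relies on the two managers plus the conditional UniqueConstraint. A minimal sketch of how such a manager is typically implemented (the actual SoftDeleteManager is defined elsewhere in this commit; this is an assumption based on how it is used here):

# Editor's sketch (hypothetical): a manager that hides soft-deleted rows.
from django.db import models

class SoftDeleteManagerSketch(models.Manager):
    def get_queryset(self):
        # Filter out rows whose deleted_at is set, so Model.objects
        # only ever sees "live" records.
        return super().get_queryset().filter(deleted_at__isnull=True)

# With condition=Q(deleted_at__isnull=True), PostgreSQL builds a partial
# unique index: once a row is soft-deleted, the same target/host/ip/port
# combination can be inserted again without a constraint violation.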
335
backend/apps/asset/models/snapshot_models.py
Normal file
@@ -0,0 +1,335 @@
from django.db import models
from django.contrib.postgres.fields import ArrayField
from django.core.validators import MinValueValidator, MaxValueValidator


class SubdomainSnapshot(models.Model):
    """Subdomain snapshot"""

    id = models.AutoField(primary_key=True)
    scan = models.ForeignKey(
        'scan.Scan',
        on_delete=models.CASCADE,
        related_name='subdomain_snapshots',
        help_text='Owning scan task'
    )

    name = models.CharField(max_length=1000, help_text='Subdomain name')
    discovered_at = models.DateTimeField(auto_now_add=True, help_text='Discovery time')

    class Meta:
        db_table = 'subdomain_snapshot'
        verbose_name = 'Subdomain snapshot'
        verbose_name_plural = 'Subdomain snapshots'
        ordering = ['-discovered_at']
        indexes = [
            models.Index(fields=['scan']),
            models.Index(fields=['name']),
            models.Index(fields=['-discovered_at']),
        ]
        constraints = [
            # Unique constraint: within one scan, a subdomain is recorded only once
            models.UniqueConstraint(
                fields=['scan', 'name'],
                name='unique_subdomain_per_scan_snapshot'
            ),
        ]

    def __str__(self):
        return f'{self.name} (Scan #{self.scan_id})'


class WebsiteSnapshot(models.Model):
    """
    Website snapshot

    Records the websites discovered in a given scan.
    """

    id = models.AutoField(primary_key=True)
    scan = models.ForeignKey(
        'scan.Scan',
        on_delete=models.CASCADE,
        related_name='website_snapshots',
        help_text='Owning scan task'
    )

    # Scan result data
    url = models.CharField(max_length=2000, help_text='Site URL')
    host = models.CharField(max_length=253, blank=True, default='', help_text='Hostname (domain or IP address)')
    title = models.CharField(max_length=500, blank=True, default='', help_text='Page title')
    status = models.IntegerField(null=True, blank=True, help_text='HTTP status code')
    content_length = models.BigIntegerField(null=True, blank=True, help_text='Content length')
    location = models.CharField(max_length=1000, blank=True, default='', help_text='Redirect location')
    web_server = models.CharField(max_length=200, blank=True, default='', help_text='Web server')
    content_type = models.CharField(max_length=200, blank=True, default='', help_text='Content type')
    tech = ArrayField(
        models.CharField(max_length=100),
        blank=True,
        default=list,
        help_text='Technology stack'
    )
    body_preview = models.TextField(blank=True, default='', help_text='Response body preview')
    vhost = models.BooleanField(null=True, blank=True, help_text='Virtual host flag')
    discovered_at = models.DateTimeField(auto_now_add=True, help_text='Discovery time')

    class Meta:
        db_table = 'website_snapshot'
        verbose_name = 'Website snapshot'
        verbose_name_plural = 'Website snapshots'
        ordering = ['-discovered_at']
        indexes = [
            models.Index(fields=['scan']),
            models.Index(fields=['url']),
            models.Index(fields=['host']),  # Host index, speeds up lookups by hostname
            models.Index(fields=['-discovered_at']),
        ]
        constraints = [
            # Unique constraint: within one scan, a URL is recorded only once
            models.UniqueConstraint(
                fields=['scan', 'url'],
                name='unique_website_per_scan_snapshot'
            ),
        ]

    def __str__(self):
        return f'{self.url} (Scan #{self.scan_id})'


class DirectorySnapshot(models.Model):
    """
    Directory snapshot

    Records the directories discovered in a given scan.
    """

    id = models.AutoField(primary_key=True)
    scan = models.ForeignKey(
        'scan.Scan',
        on_delete=models.CASCADE,
        related_name='directory_snapshots',
        help_text='Owning scan task'
    )

    # Scan result data
    url = models.CharField(max_length=2000, help_text='Directory URL')
    status = models.IntegerField(null=True, blank=True, help_text='HTTP status code')
    content_length = models.BigIntegerField(null=True, blank=True, help_text='Content length')
    words = models.IntegerField(null=True, blank=True, help_text='Word count of the response body (split on spaces)')
    lines = models.IntegerField(null=True, blank=True, help_text='Line count of the response body (split on newlines)')
    content_type = models.CharField(max_length=200, blank=True, default='', help_text='Content-Type response header value')
    duration = models.BigIntegerField(null=True, blank=True, help_text='Request duration (in nanoseconds)')
    discovered_at = models.DateTimeField(auto_now_add=True, help_text='Discovery time')

    class Meta:
        db_table = 'directory_snapshot'
        verbose_name = 'Directory snapshot'
        verbose_name_plural = 'Directory snapshots'
        ordering = ['-discovered_at']
        indexes = [
            models.Index(fields=['scan']),
            models.Index(fields=['url']),
            models.Index(fields=['status']),  # Status-code index, speeds up filtering
            models.Index(fields=['-discovered_at']),
        ]
        constraints = [
            # Unique constraint: within one scan, a directory URL is recorded only once
            models.UniqueConstraint(
                fields=['scan', 'url'],
                name='unique_directory_per_scan_snapshot'
            ),
        ]

    def __str__(self):
        return f'{self.url} (Scan #{self.scan_id})'


class HostPortMappingSnapshot(models.Model):
    """
    Host-port mapping snapshot table

    Design notes:
    - Stores the host/IP/port triples discovered in a given scan
    - Primarily linked to scan_id, recording scan history
    - scan + host + ip + port form a composite unique constraint
    """

    id = models.AutoField(primary_key=True)

    # ==================== Relation fields ====================
    scan = models.ForeignKey(
        'scan.Scan',
        on_delete=models.CASCADE,
        related_name='host_port_mapping_snapshots',
        help_text='Owning scan task (primary relation)'
    )

    # ==================== Core fields ====================
    host = models.CharField(
        max_length=1000,
        blank=False,
        help_text='Hostname (domain or IP)'
    )
    ip = models.GenericIPAddressField(
        blank=False,
        help_text='IP address'
    )
    port = models.IntegerField(
        blank=False,
        validators=[
            MinValueValidator(1, message='Port number must be >= 1'),
            MaxValueValidator(65535, message='Port number must be <= 65535')
        ],
        help_text='Port number (1-65535)'
    )

    # ==================== Timestamp fields ====================
    discovered_at = models.DateTimeField(
        auto_now_add=True,
        help_text='Discovery time'
    )

    class Meta:
        db_table = 'host_port_mapping_snapshot'
        verbose_name = 'Host-port mapping snapshot'
        verbose_name_plural = 'Host-port mapping snapshots'
        ordering = ['-discovered_at']
        indexes = [
            models.Index(fields=['scan']),  # Speeds up lookups by scan
            models.Index(fields=['host']),  # Speeds up lookups by hostname
            models.Index(fields=['ip']),  # Speeds up lookups by IP
            models.Index(fields=['port']),  # Speeds up lookups by port
            models.Index(fields=['host', 'ip']),  # Speeds up combined lookups
            models.Index(fields=['scan', 'host']),  # Speeds up scan + host lookups
            models.Index(fields=['-discovered_at']),  # Speeds up time-based ordering
        ]
        constraints = [
            # Composite unique constraint: within one scan, scan + host + ip + port is unique
            models.UniqueConstraint(
                fields=['scan', 'host', 'ip', 'port'],
                name='unique_scan_host_ip_port_snapshot'
            ),
        ]

    def __str__(self):
        return f'{self.host} ({self.ip}:{self.port}) [Scan #{self.scan_id}]'


class EndpointSnapshot(models.Model):
    """
    Endpoint snapshot

    Records the endpoints discovered in a given scan.
    """

    id = models.AutoField(primary_key=True)
    scan = models.ForeignKey(
        'scan.Scan',
        on_delete=models.CASCADE,
        related_name='endpoint_snapshots',
        help_text='Owning scan task'
    )

    # Scan result data
    url = models.CharField(max_length=2000, help_text='Endpoint URL')
    host = models.CharField(
        max_length=253,
        blank=True,
        default='',
        help_text='Hostname (domain or IP address)'
    )
    title = models.CharField(max_length=1000, blank=True, default='', help_text='Page title')
    status_code = models.IntegerField(null=True, blank=True, help_text='HTTP status code')
    content_length = models.IntegerField(null=True, blank=True, help_text='Content length')
    location = models.CharField(max_length=1000, blank=True, default='', help_text='Redirect location')
    webserver = models.CharField(max_length=200, blank=True, default='', help_text='Web server')
    content_type = models.CharField(max_length=200, blank=True, default='', help_text='Content type')
    tech = ArrayField(
        models.CharField(max_length=100),
        blank=True,
        default=list,
        help_text='Technology stack'
    )
    body_preview = models.CharField(max_length=1000, blank=True, default='', help_text='Response body preview')
    vhost = models.BooleanField(null=True, blank=True, help_text='Virtual host flag')
    matched_gf_patterns = ArrayField(
        models.CharField(max_length=100),
        blank=True,
        default=list,
        help_text='List of matched GF patterns'
    )
    discovered_at = models.DateTimeField(auto_now_add=True, help_text='Discovery time')

    class Meta:
        db_table = 'endpoint_snapshot'
        verbose_name = 'Endpoint snapshot'
        verbose_name_plural = 'Endpoint snapshots'
        ordering = ['-discovered_at']
        indexes = [
            models.Index(fields=['scan']),
            models.Index(fields=['url']),
            models.Index(fields=['host']),  # Host index, speeds up lookups by hostname
            models.Index(fields=['status_code']),  # Status-code index, speeds up filtering
            models.Index(fields=['-discovered_at']),
        ]
        constraints = [
            # Unique constraint: within one scan, a URL is recorded only once
            models.UniqueConstraint(
                fields=['scan', 'url'],
                name='unique_endpoint_per_scan_snapshot'
            ),
        ]

    def __str__(self):
        return f'{self.url} (Scan #{self.scan_id})'


class VulnerabilitySnapshot(models.Model):
    """
    Vulnerability snapshot

    Records the vulnerabilities discovered in a given scan.
    """

    # Deferred import to avoid a circular dependency
    from apps.common.definitions import VulnSeverity

    id = models.AutoField(primary_key=True)
    scan = models.ForeignKey(
        'scan.Scan',
        on_delete=models.CASCADE,
        related_name='vulnerability_snapshots',
        help_text='Owning scan task'
    )

    # ==================== Core fields ====================
    url = models.TextField(help_text='URL where the vulnerability was found')
    vuln_type = models.CharField(max_length=100, help_text='Vulnerability type (e.g. xss, sqli)')
    severity = models.CharField(
        max_length=20,
        choices=VulnSeverity.choices,
        default=VulnSeverity.UNKNOWN,
        help_text='Severity (unknown/info/low/medium/high/critical)'
    )
    source = models.CharField(max_length=50, blank=True, default='', help_text='Source tool (e.g. dalfox, nuclei, crlfuzz)')
    cvss_score = models.DecimalField(max_digits=3, decimal_places=1, null=True, blank=True, help_text='CVSS score (0.0-10.0)')
    description = models.TextField(blank=True, default='', help_text='Vulnerability description')
    raw_output = models.JSONField(blank=True, default=dict, help_text='Raw tool output')

    # ==================== Timestamp fields ====================
    discovered_at = models.DateTimeField(auto_now_add=True, help_text='Discovery time')

    class Meta:
        db_table = 'vulnerability_snapshot'
        verbose_name = 'Vulnerability snapshot'
        verbose_name_plural = 'Vulnerability snapshots'
        ordering = ['-discovered_at']
        indexes = [
            models.Index(fields=['scan']),
            models.Index(fields=['vuln_type']),
            models.Index(fields=['severity']),
            models.Index(fields=['source']),
            models.Index(fields=['-discovered_at']),
        ]

    def __str__(self):
        return f'{self.vuln_type} - {self.url[:50]} (Scan #{self.scan_id})'
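Editor's note: the snapshot tables above mirror the asset tables but hang off scan instead of target and carry no soft-delete fields. A hedged sketch of how a scan pipeline might write to both (function and variable names are illustrative, not from this commit):

# Editor's sketch (hypothetical names): persist a scan result twice -
# once into the deduplicated asset table, once into the per-scan snapshot.
def record_subdomains(scan, target, names):
    from apps.asset.models.asset_models import Subdomain
    from apps.asset.models.snapshot_models import SubdomainSnapshot

    # Asset table: deduplicated per target, so re-discoveries are ignored.
    Subdomain.objects.bulk_create(
        [Subdomain(target=target, name=n) for n in names],
        ignore_conflicts=True,
    )
    # Snapshot table: unique per scan, preserving scan history.
    SubdomainSnapshot.objects.bulk_create(
        [SubdomainSnapshot(scan=scan, name=n) for n in names],
        ignore_conflicts=True,
    )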
82
backend/apps/asset/models/statistics_models.py
Normal file
@@ -0,0 +1,82 @@
from django.db import models


class AssetStatistics(models.Model):
    """
    Asset statistics table

    Stores pre-aggregated global statistics so the dashboard does not have
    to run real-time COUNTs over large tables.
    Refreshed periodically by a scheduled task (Prefect flow).
    """

    id = models.AutoField(primary_key=True)

    # ==================== Current statistics ====================
    total_targets = models.IntegerField(default=0, help_text='Total targets')
    total_subdomains = models.IntegerField(default=0, help_text='Total subdomains')
    total_ips = models.IntegerField(default=0, help_text='Total IP addresses')
    total_endpoints = models.IntegerField(default=0, help_text='Total endpoints')
    total_websites = models.IntegerField(default=0, help_text='Total websites')
    total_vulns = models.IntegerField(default=0, help_text='Total vulnerabilities')
    total_assets = models.IntegerField(default=0, help_text='Total assets (subdomains + IPs + endpoints + websites)')

    # ==================== Previous statistics (used to compute trends) ====================
    prev_targets = models.IntegerField(default=0, help_text='Previous total targets')
    prev_subdomains = models.IntegerField(default=0, help_text='Previous total subdomains')
    prev_ips = models.IntegerField(default=0, help_text='Previous total IP addresses')
    prev_endpoints = models.IntegerField(default=0, help_text='Previous total endpoints')
    prev_websites = models.IntegerField(default=0, help_text='Previous total websites')
    prev_vulns = models.IntegerField(default=0, help_text='Previous total vulnerabilities')
    prev_assets = models.IntegerField(default=0, help_text='Previous total assets')

    # ==================== Timestamp fields ====================
    updated_at = models.DateTimeField(auto_now=True, help_text='Last updated at')

    class Meta:
        db_table = 'asset_statistics'
        verbose_name = 'Asset statistics'
        verbose_name_plural = 'Asset statistics'

    def __str__(self):
        return f'AssetStatistics (updated: {self.updated_at})'

    @classmethod
    def get_or_create_singleton(cls) -> 'AssetStatistics':
        """Fetch or create the singleton statistics record."""
        obj, _ = cls.objects.get_or_create(pk=1)
        return obj


class StatisticsHistory(models.Model):
    """
    Statistics history table (used for line charts)

    One snapshot row is recorded per day to show trends.
    Written automatically by the scheduled task when statistics are refreshed.
    """

    date = models.DateField(unique=True, help_text='Statistics date')

    # Per-asset-type counts
    total_targets = models.IntegerField(default=0, help_text='Total targets')
    total_subdomains = models.IntegerField(default=0, help_text='Total subdomains')
    total_ips = models.IntegerField(default=0, help_text='Total IP addresses')
    total_endpoints = models.IntegerField(default=0, help_text='Total endpoints')
    total_websites = models.IntegerField(default=0, help_text='Total websites')
    total_vulns = models.IntegerField(default=0, help_text='Total vulnerabilities')
    total_assets = models.IntegerField(default=0, help_text='Total assets')

    created_at = models.DateTimeField(auto_now_add=True, help_text='Created at')
    updated_at = models.DateTimeField(auto_now=True, help_text='Updated at')

    class Meta:
        db_table = 'statistics_history'
        verbose_name = 'Statistics history'
        verbose_name_plural = 'Statistics history'
        ordering = ['-date']
        indexes = [
            models.Index(fields=['date']),
        ]

    def __str__(self):
        return f'StatisticsHistory ({self.date})'
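Editor's note: a sketch of how a refresh task might roll the current totals into the prev_* columns and append a daily history row. This is an assumption about the refresh flow, which is not part of this section; the function name and counts dict are hypothetical.

# Editor's sketch (hypothetical refresh task).
from django.utils import timezone

def refresh_statistics(counts: dict) -> None:
    """counts maps field suffixes ('targets', 'subdomains', ...) to fresh totals."""
    stats = AssetStatistics.get_or_create_singleton()
    for key, value in counts.items():
        # Shift the current value into prev_* so the dashboard can show a
        # trend, then store the new total.
        setattr(stats, f'prev_{key}', getattr(stats, f'total_{key}'))
        setattr(stats, f'total_{key}', value)
    stats.save()

    # One history row per day (date is unique), updated in place on re-runs.
    StatisticsHistory.objects.update_or_create(
        date=timezone.localdate(),
        defaults={f'total_{k}': v for k, v in counts.items()},
    )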
41
backend/apps/asset/repositories/__init__.py
Normal file
@@ -0,0 +1,41 @@
"""Asset repositories - data access layer"""

# Asset module repositories
from .asset import (
    DjangoSubdomainRepository,
    DjangoWebSiteRepository,
    DjangoDirectoryRepository,
    DjangoHostPortMappingRepository,
    DjangoEndpointRepository,
)

# Snapshot module repositories
from .snapshot import (
    DjangoSubdomainSnapshotRepository,
    DjangoHostPortMappingSnapshotRepository,
    DjangoWebsiteSnapshotRepository,
    DjangoDirectorySnapshotRepository,
    DjangoEndpointSnapshotRepository,
)

# Statistics module repository
from .statistics_repository import AssetStatisticsRepository

__all__ = [
    # Asset module
    'DjangoSubdomainRepository',
    'DjangoWebSiteRepository',
    'DjangoDirectoryRepository',
    'DjangoHostPortMappingRepository',
    'DjangoEndpointRepository',
    # Snapshot module
    'DjangoSubdomainSnapshotRepository',
    'DjangoHostPortMappingSnapshotRepository',
    'DjangoWebsiteSnapshotRepository',
    'DjangoDirectorySnapshotRepository',
    'DjangoEndpointSnapshotRepository',
    # Statistics module
    'AssetStatisticsRepository',
]
15
backend/apps/asset/repositories/asset/__init__.py
Normal file
@@ -0,0 +1,15 @@
"""Asset repositories - data access layer"""

from .subdomain_repository import DjangoSubdomainRepository
from .website_repository import DjangoWebSiteRepository
from .directory_repository import DjangoDirectoryRepository
from .host_port_mapping_repository import DjangoHostPortMappingRepository
from .endpoint_repository import DjangoEndpointRepository

__all__ = [
    'DjangoSubdomainRepository',
    'DjangoWebSiteRepository',
    'DjangoDirectoryRepository',
    'DjangoHostPortMappingRepository',
    'DjangoEndpointRepository',
]
249
backend/apps/asset/repositories/asset/directory_repository.py
Normal file
@@ -0,0 +1,249 @@
"""
Directory repository implemented with the Django ORM
"""

import logging
from typing import List, Tuple, Dict, Iterator
from django.db import transaction, IntegrityError, OperationalError, DatabaseError
from django.utils import timezone

from apps.asset.models.asset_models import Directory
from apps.asset.dtos import DirectoryDTO
from apps.common.decorators import auto_ensure_db_connection

logger = logging.getLogger(__name__)


@auto_ensure_db_connection
class DjangoDirectoryRepository:
    """Directory repository implemented with the Django ORM"""

    def bulk_create_ignore_conflicts(self, items: List[DirectoryDTO]) -> int:
        """
        Bulk-create Directory records, ignoring conflicts

        Args:
            items: list of Directory DTOs

        Returns:
            int: number of records processed (with ignore_conflicts the
            database does not report how many rows were actually inserted)

        Raises:
            IntegrityError: data integrity error
            OperationalError: database operational error
            DatabaseError: database error
        """
        if not items:
            return 0

        try:
            # Convert to Django model instances
            directory_objects = [
                Directory(
                    website_id=item.website_id,
                    target_id=item.target_id,
                    url=item.url,
                    status=item.status,
                    content_length=item.content_length,
                    words=item.words,
                    lines=item.lines,
                    content_type=item.content_type,
                    duration=item.duration
                )
                for item in items
            ]

            with transaction.atomic():
                # Bulk insert, ignoring conflicts:
                # if website + url already exists, the row is skipped
                Directory.objects.bulk_create(
                    directory_objects,
                    ignore_conflicts=True
                )

            logger.debug(f"Processed {len(items)} Directory records")
            return len(items)

        except IntegrityError as e:
            logger.error(
                f"Bulk insert of Directory failed - integrity error: {e}, "
                f"record count: {len(items)}"
            )
            raise

        except OperationalError as e:
            logger.error(
                f"Bulk insert of Directory failed - operational error: {e}, "
                f"record count: {len(items)}"
            )
            raise

        except DatabaseError as e:
            logger.error(
                f"Bulk insert of Directory failed - database error: {e}, "
                f"record count: {len(items)}"
            )
            raise

        except Exception as e:
            logger.error(
                f"Bulk insert of Directory failed - unexpected error: {e}, "
                f"record count: {len(items)}, "
                f"error type: {type(e).__name__}",
                exc_info=True
            )
            raise

    def get_by_website(self, website_id: int) -> List[DirectoryDTO]:
        """
        Get all directories of a given website

        Args:
            website_id: website ID

        Returns:
            List[DirectoryDTO]: list of directories
        """
        try:
            directories = Directory.objects.filter(website_id=website_id)
            return [
                DirectoryDTO(
                    website_id=d.website_id,
                    target_id=d.target_id,
                    url=d.url,
                    status=d.status,
                    content_length=d.content_length,
                    words=d.words,
                    lines=d.lines,
                    content_type=d.content_type,
                    duration=d.duration
                )
                for d in directories
            ]

        except Exception as e:
            logger.error(f"Failed to fetch directory list - Website ID: {website_id}, error: {e}")
            raise

    def count_by_website(self, website_id: int) -> int:
        """
        Count the directories of a given website

        Args:
            website_id: website ID

        Returns:
            int: directory count
        """
        try:
            count = Directory.objects.filter(website_id=website_id).count()
            logger.debug(f"Website {website_id} has {count} directories")
            return count

        except Exception as e:
            logger.error(f"Failed to count directories - Website ID: {website_id}, error: {e}")
            raise

    def get_all(self):
        """
        Get all directories

        Returns:
            QuerySet: directory queryset
        """
        return Directory.objects.all()

    def get_by_target(self, target_id: int):
        return Directory.objects.filter(target_id=target_id).select_related('website').order_by('-discovered_at')

    def get_urls_for_export(self, target_id: int, batch_size: int = 1000) -> Iterator[str]:
        """Stream all directory URLs of a target (selects only the url column to avoid loading extra data)."""
        try:
            queryset = (
                Directory.objects
                .filter(target_id=target_id)
                .values_list('url', flat=True)
                .order_by('url')
                .iterator(chunk_size=batch_size)
            )
            for url in queryset:
                yield url
        except Exception as e:
            logger.error("Failed to stream directory URLs - Target ID: %s, error: %s", target_id, e)
            raise

    def soft_delete_by_ids(self, directory_ids: List[int]) -> int:
        """
        Bulk soft-delete Directory records by ID list

        Args:
            directory_ids: list of Directory IDs

        Returns:
            number of soft-deleted records
        """
        try:
            updated_count = (
                Directory.objects
                .filter(id__in=directory_ids)
                .update(deleted_at=timezone.now())
            )
            logger.debug(
                "Bulk soft delete of Directory succeeded - Count: %s, updated records: %s",
                len(directory_ids),
                updated_count
            )
            return updated_count
        except Exception as e:
            logger.error(
                "Bulk soft delete of Directory failed - IDs: %s, error: %s",
                directory_ids,
                e
            )
            raise

    def hard_delete_by_ids(self, directory_ids: List[int]) -> Tuple[int, Dict[str, int]]:
        """
        Hard-delete Directory records by ID list (uses database-level CASCADE)

        Args:
            directory_ids: list of Directory IDs

        Returns:
            (number of deleted records, deletion details dict)
        """
        try:
            batch_size = 1000
            total_deleted = 0

            logger.debug(f"Starting bulk delete of {len(directory_ids)} Directory records (database CASCADE)...")

            for i in range(0, len(directory_ids), batch_size):
                batch_ids = directory_ids[i:i + batch_size]
                count, _ = Directory.all_objects.filter(id__in=batch_ids).delete()
                total_deleted += count
                logger.debug(f"Batch delete done: {len(batch_ids)} Directory records, {count} rows deleted")

            deleted_details = {
                'directories': len(directory_ids),
                'total': total_deleted,
                'note': 'Database CASCADE - detailed stats unavailable'
            }

            logger.debug(
                "Bulk hard delete succeeded (CASCADE) - Directory count: %s, total rows deleted: %s",
                len(directory_ids),
                total_deleted
            )

            return total_deleted, deleted_details

        except Exception as e:
            logger.error(
                "Bulk hard delete failed (CASCADE) - Directory count: %s, error: %s",
                len(directory_ids),
                str(e),
                exc_info=True
            )
            raise
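Editor's note: on PostgreSQL, ignore_conflicts=True makes bulk_create emit INSERT ... ON CONFLICT DO NOTHING, which is why these repositories can blindly re-submit scan results. A short usage sketch; the DirectoryDTO field values are illustrative, and the DTO constructor signature is assumed from the fields used above:

# Editor's sketch: re-submitting the same rows is safe and idempotent.
repo = DjangoDirectoryRepository()
dtos = [
    DirectoryDTO(website_id=1, target_id=1, url='https://example.com/admin/',
                 status=200, content_length=1234, words=100, lines=20,
                 content_type='text/html', duration=15_000_000),
]
repo.bulk_create_ignore_conflicts(dtos)
repo.bulk_create_ignore_conflicts(dtos)  # second call hits ON CONFLICT DO NOTHING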
192
backend/apps/asset/repositories/asset/endpoint_repository.py
Normal file
@@ -0,0 +1,192 @@
"""Endpoint Repository - Django ORM implementation"""

import logging
from typing import List, Tuple, Dict

from apps.asset.models import Endpoint
from apps.asset.dtos.asset import EndpointDTO
from apps.common.decorators import auto_ensure_db_connection
from django.db import transaction

logger = logging.getLogger(__name__)


@auto_ensure_db_connection
class DjangoEndpointRepository:
    """Endpoint repository - data access for the endpoint table"""

    def bulk_create_ignore_conflicts(self, items: List[EndpointDTO]) -> int:
        """
        Bulk-create endpoints (ignoring conflicts)

        Args:
            items: list of endpoint DTOs

        Returns:
            int: number of records submitted (with ignore_conflicts this is
            not necessarily the number of rows actually inserted)
        """
        if not items:
            return 0

        try:
            endpoints = []
            for item in items:
                # The Endpoint model is currently linked only to target and no
                # longer depends on a website foreign key; map the EndpointDTO
                # fields onto an Endpoint instance
                endpoints.append(Endpoint(
                    target_id=item.target_id,
                    url=item.url,
                    host=item.host or '',
                    title=item.title or '',
                    status_code=item.status_code,
                    content_length=item.content_length,
                    webserver=item.webserver or '',
                    body_preview=item.body_preview or '',
                    content_type=item.content_type or '',
                    tech=item.tech if item.tech else [],
                    vhost=item.vhost,
                    location=item.location or '',
                    matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else []
                ))

            with transaction.atomic():
                created = Endpoint.objects.bulk_create(
                    endpoints,
                    ignore_conflicts=True,
                    batch_size=1000
                )
            return len(created)

        except Exception as e:
            logger.error(f"Bulk creation of endpoints failed: {e}")
            raise

    def get_by_website(self, website_id: int) -> List[EndpointDTO]:
        """
        Get all endpoints of a website

        Args:
            website_id: website ID

        Returns:
            List[EndpointDTO]: list of endpoints
        """
        # Note: this filters on a website_id column even though bulk creation
        # above links endpoints only to target; this method assumes the column
        # still exists on the model.
        endpoints = Endpoint.objects.filter(
            website_id=website_id
        ).order_by('-discovered_at')

        result = []
        for endpoint in endpoints:
            result.append(EndpointDTO(
                website_id=endpoint.website_id,
                target_id=endpoint.target_id,
                url=endpoint.url,
                title=endpoint.title,
                status_code=endpoint.status_code,
                content_length=endpoint.content_length,
                webserver=endpoint.webserver,
                body_preview=endpoint.body_preview,
                content_type=endpoint.content_type,
                tech=endpoint.tech,
                vhost=endpoint.vhost,
                location=endpoint.location,
                matched_gf_patterns=endpoint.matched_gf_patterns
            ))

        return result

    def get_queryset_by_target(self, target_id: int):
        return Endpoint.objects.filter(target_id=target_id).order_by('-discovered_at')

    def get_all(self):
        """Get all endpoints (global query)"""
        return Endpoint.objects.all().order_by('-discovered_at')

    def get_by_target(self, target_id: int) -> List[EndpointDTO]:
        """
        Get all endpoints of a target

        Args:
            target_id: target ID

        Returns:
            List[EndpointDTO]: list of endpoints
        """
        endpoints = Endpoint.objects.filter(
            target_id=target_id
        ).order_by('-discovered_at')

        result = []
        for endpoint in endpoints:
            result.append(EndpointDTO(
                website_id=endpoint.website_id,
                target_id=endpoint.target_id,
                url=endpoint.url,
                title=endpoint.title,
                status_code=endpoint.status_code,
                content_length=endpoint.content_length,
                webserver=endpoint.webserver,
                body_preview=endpoint.body_preview,
                content_type=endpoint.content_type,
                tech=endpoint.tech,
                vhost=endpoint.vhost,
                location=endpoint.location,
                matched_gf_patterns=endpoint.matched_gf_patterns
            ))

        return result

    def count_by_website(self, website_id: int) -> int:
        """
        Count the endpoints of a website

        Args:
            website_id: website ID

        Returns:
            int: endpoint count
        """
        return Endpoint.objects.filter(website_id=website_id).count()

    def count_by_target(self, target_id: int) -> int:
        """
        Count the endpoints of a target

        Args:
            target_id: target ID

        Returns:
            int: endpoint count
        """
        return Endpoint.objects.filter(target_id=target_id).count()

    def soft_delete_by_ids(self, ids: List[int]) -> int:
        """
        Soft-delete endpoints (bulk)

        Args:
            ids: list of endpoint IDs

        Returns:
            int: number of updated records
        """
        from django.utils import timezone
        return Endpoint.objects.filter(
            id__in=ids
        ).update(deleted_at=timezone.now())

    def hard_delete_by_ids(self, ids: List[int]) -> Tuple[int, Dict[str, int]]:
        """
        Hard-delete endpoints (bulk)

        Args:
            ids: list of endpoint IDs

        Returns:
            Tuple[int, Dict[str, int]]: (total deleted, details)
        """
        deleted_count, details = Endpoint.all_objects.filter(
            id__in=ids
        ).delete()

        return deleted_count, details
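Editor's note: because bulk_create(..., ignore_conflicts=True) returns every submitted object whether or not it was inserted, a caller that needs the true insert count has to measure it separately. One hedged approach, not used in this commit, is a before/after count inside the same transaction:

# Editor's sketch (hypothetical): approximate the real insert count.
from django.db import transaction

def bulk_create_with_count(objs) -> int:
    with transaction.atomic():
        before = Endpoint.objects.count()
        Endpoint.objects.bulk_create(objs, ignore_conflicts=True, batch_size=1000)
        # Racy under concurrent writers; good enough for logging/metrics.
        return Endpoint.objects.count() - before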
167
backend/apps/asset/repositories/asset/host_port_mapping_repository.py
Normal file
@@ -0,0 +1,167 @@
"""HostPortMapping Repository - Django ORM implementation"""

import logging
from typing import List, Iterator

from apps.asset.models.asset_models import HostPortMapping
from apps.asset.dtos.asset import HostPortMappingDTO
from apps.common.decorators import auto_ensure_db_connection

logger = logging.getLogger(__name__)


@auto_ensure_db_connection
class DjangoHostPortMappingRepository:
    """HostPortMapping repository - Django ORM implementation"""

    def bulk_create_ignore_conflicts(self, items: List[HostPortMappingDTO]) -> int:
        """
        Bulk-create host-port mappings (ignoring conflicts)

        Args:
            items: list of host-port mapping DTOs

        Returns:
            int: number of records submitted (with ignore_conflicts, Django's
            bulk_create cannot report how many rows were actually inserted)

        Note:
            - Deduplicates on the unique constraint (target + host + ip + port)
            - Existing records are ignored, not updated
        """
        try:
            logger.debug("Preparing to bulk-create host-port mappings - count: %d", len(items))

            if not items:
                logger.debug("Host-port mapping list is empty, skipping creation")
                return 0

            # Build record objects
            records = []
            for item in items:
                records.append(HostPortMapping(
                    target_id=item.target_id,
                    host=item.host,
                    ip=item.ip,
                    port=item.port
                ))

            # Bulk create (ignore conflicts, deduplicated by the unique constraint)
            created = HostPortMapping.objects.bulk_create(
                records,
                ignore_conflicts=True
            )

            created_count = len(created) if created else 0
            logger.debug("Host-port mappings created - count: %d", created_count)

            return created_count

        except Exception as e:
            logger.error(
                "Bulk creation of host-port mappings failed - count: %d, error: %s",
                len(items),
                str(e),
                exc_info=True
            )
            raise

    def get_for_export(self, target_id: int, batch_size: int = 1000):
        queryset = (
            HostPortMapping.objects
            .filter(target_id=target_id)
            .order_by("host", "port")
            .values("host", "port")
            .iterator(chunk_size=batch_size)
        )
        for item in queryset:
            yield item

    def get_ips_for_export(self, target_id: int, batch_size: int = 1000) -> Iterator[str]:
        """Stream all unique IP addresses of a target."""
        queryset = (
            HostPortMapping.objects
            .filter(target_id=target_id)
            .values_list("ip", flat=True)
            .distinct()
            .order_by("ip")
            .iterator(chunk_size=batch_size)
        )
        for ip in queryset:
            yield ip

    def get_ip_aggregation_by_target(self, target_id: int, search: str = None):
        from django.db.models import Min

        qs = HostPortMapping.objects.filter(target_id=target_id)
        if search:
            qs = qs.filter(ip__icontains=search)

        ip_aggregated = (
            qs
            .values('ip')
            .annotate(
                discovered_at=Min('discovered_at')
            )
            .order_by('-discovered_at')
        )

        results = []
        for item in ip_aggregated:
            ip = item['ip']
            mappings = (
                HostPortMapping.objects
                .filter(target_id=target_id, ip=ip)
                .values('host', 'port')
                .distinct()
            )

            hosts = sorted({m['host'] for m in mappings})
            ports = sorted({m['port'] for m in mappings})

            results.append({
                'ip': ip,
                'hosts': hosts,
                'ports': ports,
                'discovered_at': item['discovered_at'],
            })

        return results

    def get_all_ip_aggregation(self, search: str = None):
        """Get aggregated data for all IPs (global query)"""
        from django.db.models import Min

        qs = HostPortMapping.objects.all()
        if search:
            qs = qs.filter(ip__icontains=search)

        ip_aggregated = (
            qs
            .values('ip')
            .annotate(
                discovered_at=Min('discovered_at')
            )
            .order_by('-discovered_at')
        )

        results = []
        for item in ip_aggregated:
            ip = item['ip']
            mappings = (
                HostPortMapping.objects
                .filter(ip=ip)
                .values('host', 'port')
                .distinct()
            )

            hosts = sorted({m['host'] for m in mappings})
            ports = sorted({m['port'] for m in mappings})

            results.append({
                'ip': ip,
                'hosts': hosts,
                'ports': ports,
                'discovered_at': item['discovered_at'],
            })

        return results
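Editor's note: get_ip_aggregation_by_target issues one extra query per aggregated IP. On PostgreSQL the same shape can be produced in a single query with ArrayAgg; a sketch of that alternative, equivalent up to element ordering and assuming PostgreSQL:

# Editor's sketch: single-query IP aggregation using PostgreSQL ArrayAgg.
from django.contrib.postgres.aggregates import ArrayAgg
from django.db.models import Min

def ip_aggregation_single_query(target_id: int):
    return (
        HostPortMapping.objects
        .filter(target_id=target_id)
        .values('ip')
        .annotate(
            discovered_at=Min('discovered_at'),
            hosts=ArrayAgg('host', distinct=True),
            ports=ArrayAgg('port', distinct=True),
        )
        .order_by('-discovered_at')
    )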
256
backend/apps/asset/repositories/asset/subdomain_repository.py
Normal file
@@ -0,0 +1,256 @@
import logging
from typing import List, Iterator, Tuple, Dict

from django.db import transaction, IntegrityError, OperationalError, DatabaseError
from django.utils import timezone

from apps.asset.models.asset_models import Subdomain
from apps.asset.dtos import SubdomainDTO
from apps.common.decorators import auto_ensure_db_connection

logger = logging.getLogger(__name__)


@auto_ensure_db_connection
class DjangoSubdomainRepository:
    """Subdomain repository implemented with the Django ORM."""

    def bulk_create_ignore_conflicts(self, items: List[SubdomainDTO]) -> None:
        """
        Bulk-create subdomains, ignoring conflicts

        Args:
            items: list of subdomain DTOs

        Raises:
            IntegrityError: data integrity error (e.g. unique constraint violation)
            OperationalError: database operational error (e.g. connection failure)
            DatabaseError: other database errors
        """
        if not items:
            return

        try:
            subdomain_objects = [
                Subdomain(
                    name=item.name,
                    target_id=item.target_id,
                )
                for item in items
            ]

            with transaction.atomic():
                # ignore_conflicts strategy:
                # - new subdomains: INSERT the full record
                # - existing subdomains: skipped (not updated, since there is no probe data)
                # Note: ignore_conflicts cannot report the number actually created
                Subdomain.objects.bulk_create(  # type: ignore[attr-defined]
                    subdomain_objects,
                    ignore_conflicts=True,  # skip duplicate records
                )

            logger.debug(f"Processed {len(items)} subdomain records")

        except IntegrityError as e:
            logger.error(
                f"Bulk insert of subdomains failed - integrity error: {e}, "
                f"record count: {len(items)}, "
                f"sample domain: {items[0].name if items else 'N/A'}"
            )
            raise

        except OperationalError as e:
            logger.error(
                f"Bulk insert of subdomains failed - operational error: {e}, "
                f"record count: {len(items)}"
            )
            raise

        except DatabaseError as e:
            logger.error(
                f"Bulk insert of subdomains failed - database error: {e}, "
                f"record count: {len(items)}"
            )
            raise

        except Exception as e:
            logger.error(
                f"Bulk insert of subdomains failed - unexpected error: {e}, "
                f"record count: {len(items)}, "
                f"error type: {type(e).__name__}",
                exc_info=True
            )
            raise

    def get_or_create(self, name: str, target_id: int) -> Tuple[Subdomain, bool]:
        """
        Get or create a subdomain

        Args:
            name: subdomain name
            target_id: target ID

        Returns:
            (Subdomain instance, whether it was newly created)
        """
        return Subdomain.objects.get_or_create(
            name=name,
            target_id=target_id,
        )

    def get_domains_for_export(self, target_id: int, batch_size: int = 1000) -> Iterator[str]:
        """
        Stream domains for export (used to build input files for scan tools)

        Uses iterator() for a streaming query to avoid loading all rows into memory at once

        Args:
            target_id: target ID
            batch_size: number of rows fetched from the database per read

        Yields:
            str: domain name
        """
        queryset = Subdomain.objects.filter(
            target_id=target_id
        ).only('name').iterator(chunk_size=batch_size)

        for subdomain in queryset:
            yield subdomain.name

    def get_by_target(self, target_id: int):
        return Subdomain.objects.filter(target_id=target_id).order_by('-discovered_at')

    def count_by_target(self, target_id: int) -> int:
        """
        Count the domains of a target

        Args:
            target_id: target ID

        Returns:
            int: domain count
        """
        return Subdomain.objects.filter(target_id=target_id).count()

    def get_by_names_and_target_id(self, names: set, target_id: int) -> dict:
        """
        Bulk-fetch Subdomain records by a set of names and a target ID

        Args:
            names: set of domain names
            target_id: target ID

        Returns:
            dict: {domain_name: Subdomain instance}
        """
        subdomains = Subdomain.objects.filter(
            name__in=names,
            target_id=target_id
        ).only('id', 'name')

        return {sd.name: sd for sd in subdomains}

    def get_all(self):
        """
        Get all subdomains

        Returns:
            QuerySet: subdomain queryset
        """
        return Subdomain.objects.all()

    def soft_delete_by_ids(self, subdomain_ids: List[int]) -> int:
        """
        Bulk soft-delete subdomains by ID list

        Args:
            subdomain_ids: list of subdomain IDs

        Returns:
            number of soft-deleted records

        Note:
            - Soft delete: records are only marked as deleted, not removed from the database
            - All related data is kept and the records can be restored
        """
        try:
            updated_count = (
                Subdomain.objects
                .filter(id__in=subdomain_ids)
                .update(deleted_at=timezone.now())
            )
            logger.debug(
                "Bulk soft delete of subdomains succeeded - Count: %s, updated records: %s",
                len(subdomain_ids),
                updated_count
            )
            return updated_count
        except Exception as e:
            logger.error(
                "Bulk soft delete of subdomains failed - IDs: %s, error: %s",
                subdomain_ids,
                e
            )
            raise

    def hard_delete_by_ids(self, subdomain_ids: List[int]) -> Tuple[int, Dict[str, int]]:
        """
        Hard-delete subdomains by ID list (uses database-level CASCADE)

        Args:
            subdomain_ids: list of subdomain IDs

        Returns:
            (number of deleted records, deletion details dict)

        Strategy:
            Uses database-level CASCADE deletes for best performance

        Note:
            - Hard delete: permanently removed from the database
            - The database handles all foreign-key cascade deletes automatically
            - Django signals (pre_delete/post_delete) are not triggered
        """
        try:
            batch_size = 1000  # process 1000 subdomains per batch
            total_deleted = 0

            logger.debug(f"Starting bulk delete of {len(subdomain_ids)} subdomains (database CASCADE)...")

            # Process subdomain IDs in batches to avoid deleting too much at once
            for i in range(0, len(subdomain_ids), batch_size):
                batch_ids = subdomain_ids[i:i + batch_size]

                # Delete the subdomains directly; the database cascades to all related data
                count, _ = Subdomain.all_objects.filter(id__in=batch_ids).delete()
                total_deleted += count

                logger.debug(f"Batch delete done: {len(batch_ids)} subdomains, {count} rows deleted")

            # Detailed statistics are unavailable because the database CASCADE is used
            deleted_details = {
                'subdomains': len(subdomain_ids),
                'total': total_deleted,
                'note': 'Database CASCADE - detailed stats unavailable'
            }

            logger.debug(
                "Bulk hard delete succeeded (CASCADE) - subdomain count: %s, total rows deleted: %s",
                len(subdomain_ids),
                total_deleted
            )

            return total_deleted, deleted_details

        except Exception as e:
            logger.error(
                "Bulk hard delete failed (CASCADE) - subdomain count: %s, error: %s",
                len(subdomain_ids),
                str(e),
                exc_info=True
            )
            raise
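Editor's note: a short usage sketch for the streaming export above, feeding subdomains to a scan tool without materializing the whole queryset (the file path and target_id are illustrative):

# Editor's sketch: write domains to a tool input file, one per line.
repo = DjangoSubdomainRepository()
with open('/tmp/subdomains.txt', 'w') as fh:
    for domain in repo.get_domains_for_export(target_id=1):
        fh.write(domain + '\n')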
260
backend/apps/asset/repositories/asset/website_repository.py
Normal file
@@ -0,0 +1,260 @@
"""
WebSite repository implemented with the Django ORM
"""

import logging
from typing import List, Generator, Tuple, Dict, Optional
from django.db import transaction, IntegrityError, OperationalError, DatabaseError
from django.utils import timezone

from apps.asset.models.asset_models import WebSite
from apps.asset.dtos import WebSiteDTO
from apps.common.decorators import auto_ensure_db_connection

logger = logging.getLogger(__name__)


@auto_ensure_db_connection
class DjangoWebSiteRepository:
    """WebSite repository implemented with the Django ORM"""

    def bulk_create_ignore_conflicts(self, items: List[WebSiteDTO]) -> None:
        """
        Bulk-create WebSite records, ignoring conflicts

        Args:
            items: list of WebSite DTOs

        Raises:
            IntegrityError: data integrity error
            OperationalError: database operational error
            DatabaseError: database error
        """
        if not items:
            return

        try:
            # Convert to Django model instances
            website_objects = [
                WebSite(
                    target_id=item.target_id,
                    url=item.url,
                    host=item.host,
                    location=item.location,
                    title=item.title,
                    webserver=item.webserver,
                    body_preview=item.body_preview,
                    content_type=item.content_type,
                    tech=item.tech,
                    status_code=item.status_code,
                    content_length=item.content_length,
                    vhost=item.vhost
                )
                for item in items
            ]

            with transaction.atomic():
                # Bulk insert, ignoring conflicts:
                # if the URL + target already exists, the row is skipped
                WebSite.objects.bulk_create(
                    website_objects,
                    ignore_conflicts=True
                )

            logger.debug(f"Processed {len(items)} WebSite records")

        except IntegrityError as e:
            logger.error(
                f"Bulk insert of WebSite failed - integrity error: {e}, "
                f"record count: {len(items)}"
            )
            raise

        except OperationalError as e:
            logger.error(
                f"Bulk insert of WebSite failed - operational error: {e}, "
                f"record count: {len(items)}"
            )
            raise

        except DatabaseError as e:
            logger.error(
                f"Bulk insert of WebSite failed - database error: {e}, "
                f"record count: {len(items)}"
            )
            raise

        except Exception as e:
            logger.error(
                f"Bulk insert of WebSite failed - unexpected error: {e}, "
                f"record count: {len(items)}, "
                f"error type: {type(e).__name__}",
                exc_info=True
            )
            raise

    def get_urls_for_export(self, target_id: int, batch_size: int = 1000) -> Generator[str, None, None]:
        """
        Stream all site URLs of a target

        Args:
            target_id: target ID
            batch_size: batch size

        Yields:
            str: site URL
        """
        try:
            # Select only the URL column to avoid unnecessary data transfer
            queryset = WebSite.objects.filter(
                target_id=target_id
            ).values_list('url', flat=True).iterator(chunk_size=batch_size)

            for url in queryset:
                yield url
        except Exception as e:
            logger.error(f"Failed to stream site URLs - Target ID: {target_id}, error: {e}")
            raise

    def get_by_target(self, target_id: int):
        return WebSite.objects.filter(target_id=target_id).order_by('-discovered_at')

    def count_by_target(self, target_id: int) -> int:
        """
        Count the sites of a target

        Args:
            target_id: target ID

        Returns:
            int: site count
        """
        try:
            count = WebSite.objects.filter(target_id=target_id).count()
            logger.debug(f"Target {target_id} has {count} sites")
            return count

        except Exception as e:
            logger.error(f"Failed to count sites - Target ID: {target_id}, error: {e}")
            raise

    def count_by_scan(self, scan_id: int) -> int:
        """
        Count the sites of a scan
        """
        try:
            count = WebSite.objects.filter(scan_id=scan_id).count()
            logger.debug(f"Scan {scan_id} has {count} sites")
            return count
        except Exception as e:
            logger.error(f"Failed to count sites - Scan ID: {scan_id}, error: {e}")
            raise

    def get_by_url(self, url: str, target_id: int) -> Optional[int]:
        """
        Look up a site ID by URL and target_id

        Args:
            url: site URL
            target_id: target ID

        Returns:
            Optional[int]: site ID, or None if it does not exist
        """
        try:
            website = WebSite.objects.filter(url=url, target_id=target_id).first()
            if website:
                return website.id
            return None

        except Exception as e:
            logger.error(f"Site lookup failed - URL: {url}, Target ID: {target_id}, error: {e}")
            raise

    def get_all(self):
        """
        Get all websites

        Returns:
            QuerySet: website queryset
        """
        return WebSite.objects.all()

    def soft_delete_by_ids(self, website_ids: List[int]) -> int:
        """
        Bulk soft-delete WebSite records by ID list

        Args:
            website_ids: list of WebSite IDs

        Returns:
            number of soft-deleted records
        """
        try:
            updated_count = (
                WebSite.objects
                .filter(id__in=website_ids)
                .update(deleted_at=timezone.now())
            )
            logger.debug(
                "Bulk soft delete of WebSite succeeded - Count: %s, updated records: %s",
                len(website_ids),
                updated_count
            )
            return updated_count
        except Exception as e:
            logger.error(
                "Bulk soft delete of WebSite failed - IDs: %s, error: %s",
                website_ids,
                e
            )
            raise

    def hard_delete_by_ids(self, website_ids: List[int]) -> Tuple[int, Dict[str, int]]:
        """
        Hard-delete WebSite records by ID list (uses database-level CASCADE)

        Args:
            website_ids: list of WebSite IDs

        Returns:
            (number of deleted records, deletion details dict)
        """
        try:
            batch_size = 1000
            total_deleted = 0

            logger.debug(f"Starting bulk delete of {len(website_ids)} WebSite records (database CASCADE)...")

            for i in range(0, len(website_ids), batch_size):
                batch_ids = website_ids[i:i + batch_size]
                count, _ = WebSite.all_objects.filter(id__in=batch_ids).delete()
                total_deleted += count
                logger.debug(f"Batch delete done: {len(batch_ids)} WebSite records, {count} rows deleted")

            deleted_details = {
                'websites': len(website_ids),
                'total': total_deleted,
                'note': 'Database CASCADE - detailed stats unavailable'
            }

            logger.debug(
                "Bulk hard delete succeeded (CASCADE) - WebSite count: %s, total rows deleted: %s",
                len(website_ids),
                total_deleted
            )

            return total_deleted, deleted_details

        except Exception as e:
            logger.error(
                "Bulk hard delete failed (CASCADE) - WebSite count: %s, error: %s",
                len(website_ids),
                str(e),
                exc_info=True
            )
            raise
17
backend/apps/asset/repositories/snapshot/__init__.py
Normal file
@@ -0,0 +1,17 @@
"""Snapshot repositories - data access layer"""

from .subdomain_snapshot_repository import DjangoSubdomainSnapshotRepository
from .host_port_mapping_snapshot_repository import DjangoHostPortMappingSnapshotRepository
from .website_snapshot_repository import DjangoWebsiteSnapshotRepository
from .directory_snapshot_repository import DjangoDirectorySnapshotRepository
from .endpoint_snapshot_repository import DjangoEndpointSnapshotRepository
from .vulnerability_snapshot_repository import DjangoVulnerabilitySnapshotRepository

__all__ = [
    'DjangoSubdomainSnapshotRepository',
    'DjangoHostPortMappingSnapshotRepository',
    'DjangoWebsiteSnapshotRepository',
    'DjangoDirectorySnapshotRepository',
    'DjangoEndpointSnapshotRepository',
    'DjangoVulnerabilitySnapshotRepository',
]
78
backend/apps/asset/repositories/snapshot/directory_snapshot_repository.py
Normal file
@@ -0,0 +1,78 @@
"""Directory Snapshot Repository - data access layer for directory snapshots"""

import logging
from typing import List
from django.db import transaction

from apps.asset.models import DirectorySnapshot
from apps.asset.dtos.snapshot import DirectorySnapshotDTO
from apps.common.decorators import auto_ensure_db_connection

logger = logging.getLogger(__name__)


@auto_ensure_db_connection
class DjangoDirectorySnapshotRepository:
    """
    Directory snapshot repository (Django ORM implementation)

    Handles data access for the directory snapshot table
    """

    def save_snapshots(self, items: List[DirectorySnapshotDTO]) -> None:
        """
        Bulk-save directory snapshot records

        Uses the ignore_conflicts strategy: snapshots that already exist
        (same scan + url) are skipped

        Args:
            items: list of directory snapshot DTOs

        Raises:
            Exception: database operation failure
        """
        if not items:
            logger.warning("Directory snapshot list is empty, skipping save")
            return

        try:
            # Convert to Django model instances
            snapshot_objects = [
                DirectorySnapshot(
                    scan_id=item.scan_id,
                    url=item.url,
                    status=item.status,
                    content_length=item.content_length,
                    words=item.words,
                    lines=item.lines,
                    content_type=item.content_type,
                    duration=item.duration
                )
                for item in items
            ]

            with transaction.atomic():
                # Bulk insert, ignoring conflicts:
                # if scan + url already exists, the row is skipped
                DirectorySnapshot.objects.bulk_create(
                    snapshot_objects,
                    ignore_conflicts=True
                )

            logger.debug("Saved %d directory snapshot records", len(items))

        except Exception as e:
            logger.error(
                "Bulk save of directory snapshots failed - count: %d, error: %s",
                len(items),
                str(e),
                exc_info=True
            )
            raise

    def get_by_scan(self, scan_id: int):
        return DirectorySnapshot.objects.filter(scan_id=scan_id).order_by('-discovered_at')

    def get_all(self):
        return DirectorySnapshot.objects.all().order_by('-discovered_at')
74
backend/apps/asset/repositories/snapshot/endpoint_snapshot_repository.py
Normal file
@@ -0,0 +1,74 @@
"""EndpointSnapshot Repository - Django ORM implementation"""

import logging
from typing import List

from apps.asset.models.snapshot_models import EndpointSnapshot
from apps.asset.dtos.snapshot import EndpointSnapshotDTO
from apps.common.decorators import auto_ensure_db_connection

logger = logging.getLogger(__name__)


@auto_ensure_db_connection
class DjangoEndpointSnapshotRepository:
    """Endpoint snapshot repository - data access for the endpoint snapshot table"""

    def save_snapshots(self, items: List[EndpointSnapshotDTO]) -> None:
        """
        Save endpoint snapshots

        Args:
            items: list of endpoint snapshot DTOs

        Note:
            - Saves the full snapshot data
            - Deduplicates on the unique constraint (scan + url)
        """
        try:
            logger.debug("Preparing to save endpoint snapshots - count: %d", len(items))

            if not items:
                logger.debug("Endpoint snapshot list is empty, skipping save")
                return

            # Build snapshot objects
            snapshots = []
            for item in items:
                snapshots.append(EndpointSnapshot(
                    scan_id=item.scan_id,
                    url=item.url,
                    title=item.title,
                    status_code=item.status_code,
                    content_length=item.content_length,
                    location=item.location,
                    webserver=item.webserver,
                    content_type=item.content_type,
                    tech=item.tech if item.tech else [],
                    body_preview=item.body_preview,
                    vhost=item.vhost,
                    matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else []
                ))

            # Bulk create (ignore conflicts, deduplicated by the unique constraint)
            EndpointSnapshot.objects.bulk_create(
                snapshots,
                ignore_conflicts=True
            )

            logger.debug("Endpoint snapshots saved - count: %d", len(snapshots))

        except Exception as e:
            logger.error(
                "Saving endpoint snapshots failed - count: %d, error: %s",
                len(items),
                str(e),
                exc_info=True
            )
            raise

    def get_by_scan(self, scan_id: int):
        return EndpointSnapshot.objects.filter(scan_id=scan_id).order_by('-discovered_at')

    def get_all(self):
        return EndpointSnapshot.objects.all().order_by('-discovered_at')
145
backend/apps/asset/repositories/snapshot/host_port_mapping_snapshot_repository.py
Normal file
@@ -0,0 +1,145 @@
"""HostPortMappingSnapshot Repository - Django ORM implementation"""

import logging
from typing import List, Iterator

from apps.asset.models.snapshot_models import HostPortMappingSnapshot
from apps.asset.dtos.snapshot import HostPortMappingSnapshotDTO
from apps.common.decorators import auto_ensure_db_connection

logger = logging.getLogger(__name__)


@auto_ensure_db_connection
class DjangoHostPortMappingSnapshotRepository:
    """HostPortMappingSnapshot repository - Django ORM implementation, data access for the host-port mapping snapshot table"""

    def save_snapshots(self, items: List[HostPortMappingSnapshotDTO]) -> None:
        """
        Save host-port mapping snapshots

        Args:
            items: list of host-port mapping snapshot DTOs

        Note:
            - Saves the full snapshot data
            - Deduplicates on the unique constraint (scan + host + ip + port)
        """
        try:
            logger.debug("Preparing to save host-port mapping snapshots - count: %d", len(items))

            if not items:
                logger.debug("Host-port mapping snapshot list is empty, skipping save")
                return

            # Build snapshot objects
            snapshots = []
            for item in items:
                snapshots.append(HostPortMappingSnapshot(
                    scan_id=item.scan_id,
                    host=item.host,
                    ip=item.ip,
                    port=item.port
                ))

            # Bulk create (ignore conflicts, deduplicated by the unique constraint)
            HostPortMappingSnapshot.objects.bulk_create(
                snapshots,
                ignore_conflicts=True
            )

            logger.debug("Host-port mapping snapshots saved - count: %d", len(snapshots))

        except Exception as e:
            logger.error(
                "Saving host-port mapping snapshots failed - count: %d, error: %s",
                len(items),
                str(e),
                exc_info=True
            )
            raise

    def get_ip_aggregation_by_scan(self, scan_id: int, search: str = None):
        from django.db.models import Min

        qs = HostPortMappingSnapshot.objects.filter(scan_id=scan_id)
        if search:
            qs = qs.filter(ip__icontains=search)

        ip_aggregated = (
            qs
            .values('ip')
            .annotate(
                discovered_at=Min('discovered_at')
            )
            .order_by('-discovered_at')
        )

        results = []
        for item in ip_aggregated:
            ip = item['ip']
            mappings = (
                HostPortMappingSnapshot.objects
                .filter(scan_id=scan_id, ip=ip)
                .values('host', 'port')
                .distinct()
            )

            hosts = sorted({m['host'] for m in mappings})
            ports = sorted({m['port'] for m in mappings})

            results.append({
                'ip': ip,
                'hosts': hosts,
                'ports': ports,
                'discovered_at': item['discovered_at'],
            })

        return results

    def get_all_ip_aggregation(self, search: str = None):
        """Get aggregated data for all IPs"""
        from django.db.models import Min

        qs = HostPortMappingSnapshot.objects.all()
        if search:
            qs = qs.filter(ip__icontains=search)

        ip_aggregated = (
            qs
            .values('ip')
            .annotate(discovered_at=Min('discovered_at'))
            .order_by('-discovered_at')
        )

        results = []
        for item in ip_aggregated:
            ip = item['ip']
            mappings = (
                HostPortMappingSnapshot.objects
                .filter(ip=ip)
                .values('host', 'port')
                .distinct()
            )
            hosts = sorted({m['host'] for m in mappings})
            ports = sorted({m['port'] for m in mappings})
            results.append({
                'ip': ip,
                'hosts': hosts,
                'ports': ports,
                'discovered_at': item['discovered_at'],
            })
        return results

    def get_ips_for_export(self, scan_id: int, batch_size: int = 1000) -> Iterator[str]:
        """Stream all unique IP addresses of a scan."""
        queryset = (
            HostPortMappingSnapshot.objects
            .filter(scan_id=scan_id)
            .values_list("ip", flat=True)
            .distinct()
            .order_by("ip")
            .iterator(chunk_size=batch_size)
        )
        for ip in queryset:
            yield ip
61
backend/apps/asset/repositories/snapshot/subdomain_snapshot_repository.py
Normal file
@@ -0,0 +1,61 @@
"""Django ORM implementation of the SubdomainSnapshot repository"""

import logging
from typing import List

from apps.asset.models.snapshot_models import SubdomainSnapshot
from apps.asset.dtos import SubdomainSnapshotDTO
from apps.common.decorators import auto_ensure_db_connection

logger = logging.getLogger(__name__)


@auto_ensure_db_connection
class DjangoSubdomainSnapshotRepository:
    """Subdomain snapshot repository - data access for the subdomain snapshot table"""

    def save_subdomain_snapshots(self, items: List[SubdomainSnapshotDTO]) -> None:
        """
        Save subdomain snapshots.

        Args:
            items: list of subdomain snapshot DTOs

        Note:
            - Persists the full snapshot data
            - Deduplicates via the unique constraint (conflicts are ignored)
        """
        try:
            logger.debug("Preparing to save subdomain snapshots - count: %d", len(items))

            if not items:
                logger.debug("No subdomain snapshots to save, skipping")
                return

            # Build snapshot objects
            snapshots = []
            for item in items:
                snapshots.append(SubdomainSnapshot(
                    scan_id=item.scan_id,
                    name=item.name,
                ))

            # Bulk create (ignore conflicts; deduplicated by the unique constraint)
            SubdomainSnapshot.objects.bulk_create(snapshots, ignore_conflicts=True)

            logger.debug("Subdomain snapshots saved - count: %d", len(snapshots))

        except Exception as e:
            logger.error(
                "Failed to save subdomain snapshots - count: %d, error: %s",
                len(items),
                str(e),
                exc_info=True
            )
            raise

    def get_by_scan(self, scan_id: int):
        return SubdomainSnapshot.objects.filter(scan_id=scan_id).order_by('-discovered_at')

    def get_all(self):
        return SubdomainSnapshot.objects.all().order_by('-discovered_at')
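save_subdomain_snapshots leans entirely on the database for dedup. A minimal sketch of that behaviour, assuming a unique constraint over (scan, name) and an example scan id:

# Duplicate rows are dropped by the constraint instead of raising IntegrityError.
SubdomainSnapshot.objects.bulk_create(
    [
        SubdomainSnapshot(scan_id=7, name="www.example.com"),
        SubdomainSnapshot(scan_id=7, name="www.example.com"),  # silently skipped
    ],
    ignore_conflicts=True,
)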
@@ -0,0 +1,66 @@
"""Vulnerability snapshot repository - data access layer for vulnerability snapshots"""

import logging
from typing import List

from django.db import transaction

from apps.asset.models import VulnerabilitySnapshot
from apps.asset.dtos.snapshot import VulnerabilitySnapshotDTO
from apps.common.decorators import auto_ensure_db_connection

logger = logging.getLogger(__name__)


@auto_ensure_db_connection
class DjangoVulnerabilitySnapshotRepository:
    """Vulnerability snapshot repository (Django ORM implementation)"""

    def save_snapshots(self, items: List[VulnerabilitySnapshotDTO]) -> None:
        """Bulk-save vulnerability snapshot records.

        Uses the ``ignore_conflicts`` strategy: existing snapshots are skipped.
        The exact unique constraint is defined on the database model.
        """
        if not items:
            logger.warning("Vulnerability snapshot list is empty, skipping save")
            return

        try:
            snapshot_objects = [
                VulnerabilitySnapshot(
                    scan_id=item.scan_id,
                    url=item.url,
                    vuln_type=item.vuln_type,
                    severity=item.severity,
                    source=item.source,
                    cvss_score=item.cvss_score,
                    description=item.description,
                    raw_output=item.raw_output,
                )
                for item in items
            ]

            with transaction.atomic():
                VulnerabilitySnapshot.objects.bulk_create(
                    snapshot_objects,
                    ignore_conflicts=True,
                )

            logger.debug("Saved %d vulnerability snapshot records", len(items))

        except Exception as e:
            logger.error(
                "Failed to bulk-save vulnerability snapshots - count: %d, error: %s",
                len(items),
                str(e),
                exc_info=True,
            )
            raise

    def get_by_scan(self, scan_id: int):
        """Return the vulnerability snapshot QuerySet for one scan."""
        return VulnerabilitySnapshot.objects.filter(scan_id=scan_id).order_by("-discovered_at")

    def get_all(self):
        return VulnerabilitySnapshot.objects.all().order_by('-discovered_at')
@@ -0,0 +1,74 @@
"""WebsiteSnapshot repository - Django ORM implementation"""

import logging
from typing import List

from apps.asset.models.snapshot_models import WebsiteSnapshot
from apps.asset.dtos.snapshot import WebsiteSnapshotDTO
from apps.common.decorators import auto_ensure_db_connection

logger = logging.getLogger(__name__)


@auto_ensure_db_connection
class DjangoWebsiteSnapshotRepository:
    """Website snapshot repository - data access for the website snapshot table"""

    def save_snapshots(self, items: List[WebsiteSnapshotDTO]) -> None:
        """
        Save website snapshots.

        Args:
            items: list of website snapshot DTOs

        Note:
            - Persists the full snapshot data
            - Deduplicates via the unique constraint (scan + subdomain + url)
        """
        try:
            logger.debug("Preparing to save website snapshots - count: %d", len(items))

            if not items:
                logger.debug("No website snapshots to save, skipping")
                return

            # Build snapshot objects
            snapshots = []
            for item in items:
                snapshots.append(WebsiteSnapshot(
                    scan_id=item.scan_id,
                    url=item.url,
                    host=item.host,
                    title=item.title,
                    status=item.status,
                    content_length=item.content_length,
                    location=item.location,
                    web_server=item.web_server,
                    content_type=item.content_type,
                    tech=item.tech if item.tech else [],
                    body_preview=item.body_preview,
                    vhost=item.vhost
                ))

            # Bulk create (ignore conflicts; deduplicated by the unique constraint)
            WebsiteSnapshot.objects.bulk_create(
                snapshots,
                ignore_conflicts=True
            )

            logger.debug("Website snapshots saved - count: %d", len(snapshots))

        except Exception as e:
            logger.error(
                "Failed to save website snapshots - count: %d, error: %s",
                len(items),
                str(e),
                exc_info=True
            )
            raise

    def get_by_scan(self, scan_id: int):
        return WebsiteSnapshot.objects.filter(scan_id=scan_id).order_by('-discovered_at')

    def get_all(self):
        return WebsiteSnapshot.objects.all().order_by('-discovered_at')
128
backend/apps/asset/repositories/statistics_repository.py
Normal file
@@ -0,0 +1,128 @@
"""Asset statistics repository"""
import logging
from datetime import date, timedelta
from typing import Optional, List

from apps.asset.models import AssetStatistics, StatisticsHistory

logger = logging.getLogger(__name__)


class AssetStatisticsRepository:
    """
    Data access layer for asset statistics.

    Responsibilities:
    - Read/update the pre-aggregated statistics
    """

    def get_statistics(self) -> Optional[AssetStatistics]:
        """
        Fetch the statistics row.

        Returns:
            The statistics object, or None if it does not exist.
        """
        return AssetStatistics.objects.first()

    def get_or_create_statistics(self) -> AssetStatistics:
        """
        Fetch or create the statistics row (singleton).

        Returns:
            The statistics object.
        """
        return AssetStatistics.get_or_create_singleton()

    def update_statistics(
        self,
        total_targets: int,
        total_subdomains: int,
        total_ips: int,
        total_endpoints: int,
        total_websites: int,
        total_vulns: int,
    ) -> AssetStatistics:
        """
        Update the statistics row.

        Args:
            total_targets: total number of targets
            total_subdomains: total number of subdomains
            total_ips: total number of IPs
            total_endpoints: total number of endpoints
            total_websites: total number of websites
            total_vulns: total number of vulnerabilities

        Returns:
            The updated statistics object.
        """
        stats = self.get_or_create_statistics()

        # 1. Stash the current values into the prev_* fields
        stats.prev_targets = stats.total_targets
        stats.prev_subdomains = stats.total_subdomains
        stats.prev_ips = stats.total_ips
        stats.prev_endpoints = stats.total_endpoints
        stats.prev_websites = stats.total_websites
        stats.prev_vulns = stats.total_vulns
        stats.prev_assets = stats.total_assets

        # 2. Write the new values
        stats.total_targets = total_targets
        stats.total_subdomains = total_subdomains
        stats.total_ips = total_ips
        stats.total_endpoints = total_endpoints
        stats.total_websites = total_websites
        stats.total_vulns = total_vulns
        stats.total_assets = total_subdomains + total_ips + total_endpoints + total_websites
        stats.save()

        logger.info(
            "Asset statistics updated: targets=%d, subdomains=%d, ips=%d, endpoints=%d, websites=%d, vulns=%d, assets=%d",
            total_targets, total_subdomains, total_ips, total_endpoints, total_websites, total_vulns, stats.total_assets
        )
        return stats

    def save_daily_snapshot(self, stats: AssetStatistics) -> StatisticsHistory:
        """
        Save the daily statistics snapshot (idempotent: one row per day).

        Args:
            stats: the current statistics

        Returns:
            The history record.
        """
        history, created = StatisticsHistory.objects.update_or_create(
            date=date.today(),
            defaults={
                'total_targets': stats.total_targets,
                'total_subdomains': stats.total_subdomains,
                'total_ips': stats.total_ips,
                'total_endpoints': stats.total_endpoints,
                'total_websites': stats.total_websites,
                'total_vulns': stats.total_vulns,
                'total_assets': stats.total_assets,
            }
        )
        action = "Created" if created else "Updated"
        logger.info(f"{action} statistics snapshot: date={history.date}, assets={history.total_assets}")
        return history

    def get_history(self, days: int = 7) -> List[StatisticsHistory]:
        """
        Fetch historical statistics (for the trend chart).

        Args:
            days: how many recent days to fetch, 7 by default

        Returns:
            History records in ascending date order.
        """
        start_date = date.today() - timedelta(days=days - 1)
        return list(
            StatisticsHistory.objects
            .filter(date__gte=start_date)
            .order_by('date')
        )
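update_statistics keeps exactly one previous value per counter, so every dashboard delta is simply current minus previous. A sketch with assumed counts:

# Values are assumed for illustration.
repo = AssetStatisticsRepository()
stats = repo.update_statistics(
    total_targets=12, total_subdomains=340, total_ips=87,
    total_endpoints=1500, total_websites=220, total_vulns=9,
)
# total_assets = 340 + 87 + 1500 + 220 = 2147
print(stats.total_assets - stats.prev_assets)  # the delta the service layer reports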
291
backend/apps/asset/serializers.py
Normal file
@@ -0,0 +1,291 @@
from rest_framework import serializers
from .models import Subdomain, WebSite, Directory, HostPortMapping, Endpoint, Vulnerability
from .models.snapshot_models import (
    SubdomainSnapshot,
    WebsiteSnapshot,
    DirectorySnapshot,
    EndpointSnapshot,
    VulnerabilitySnapshot,
)


# Note: the IPAddress and Port models were refactored into HostPortMapping.
# The serializers below are written against the new architecture.

# class PortSerializer(serializers.ModelSerializer):
#     """Port serializer"""
#
#     class Meta:
#         model = Port
#         fields = ['number', 'service_name', 'description', 'is_uncommon']


class SubdomainSerializer(serializers.ModelSerializer):
    """Subdomain serializer"""

    class Meta:
        model = Subdomain
        fields = [
            'id', 'name', 'discovered_at', 'target'
        ]
        read_only_fields = ['id', 'discovered_at']


class SubdomainListSerializer(serializers.ModelSerializer):
    """Subdomain list serializer (used by the scan detail view)"""

    # Note: the Subdomain model was slimmed down to its core fields.
    # cname, is_cdn, cdn_name etc. moved to SubdomainSnapshot;
    # the ports and ip_addresses relations were refactored into HostPortMapping.

    class Meta:
        model = Subdomain
        fields = [
            'id', 'name', 'discovered_at'
        ]
        read_only_fields = ['id', 'discovered_at']


# class IPAddressListSerializer(serializers.ModelSerializer):
#     """IP address list serializer"""
#
#     subdomain = serializers.CharField(source='subdomain.name', allow_blank=True, default='')
#     created_at = serializers.DateTimeField(read_only=True)
#     ports = PortSerializer(many=True, read_only=True)
#
#     class Meta:
#         model = IPAddress
#         fields = [
#             'id',
#             'ip',
#             'subdomain',
#             'reverse_pointer',
#             'created_at',
#             'ports',
#         ]
#         read_only_fields = fields


class WebSiteSerializer(serializers.ModelSerializer):
    """Website serializer"""

    subdomain = serializers.CharField(source='subdomain.name', allow_blank=True, default='')

    class Meta:
        model = WebSite
        fields = [
            'id',
            'url',
            'location',
            'title',
            'webserver',
            'content_type',
            'status_code',
            'content_length',
            'body_preview',
            'tech',
            'vhost',
            'subdomain',
            'discovered_at',
        ]
        read_only_fields = fields


class VulnerabilitySerializer(serializers.ModelSerializer):
    """Vulnerability asset serializer (per-target vulnerability view)."""

    class Meta:
        model = Vulnerability
        fields = [
            'id',
            'target',
            'url',
            'vuln_type',
            'severity',
            'source',
            'cvss_score',
            'description',
            'raw_output',
            'discovered_at',
        ]
        read_only_fields = fields


class VulnerabilitySnapshotSerializer(serializers.ModelSerializer):
    """Vulnerability snapshot serializer (scan-history vulnerability list)."""

    class Meta:
        model = VulnerabilitySnapshot
        fields = [
            'id',
            'url',
            'vuln_type',
            'severity',
            'source',
            'cvss_score',
            'description',
            'raw_output',
            'discovered_at',
        ]
        read_only_fields = fields


class EndpointListSerializer(serializers.ModelSerializer):
    """Endpoint list serializer (target endpoint list page)"""

    # Expose matched GF patterns as the tags field used by the frontend
    tags = serializers.ListField(
        child=serializers.CharField(),
        source='matched_gf_patterns',
        read_only=True,
    )

    class Meta:
        model = Endpoint
        fields = [
            'id',
            'url',
            'location',
            'status_code',
            'title',
            'content_length',
            'content_type',
            'webserver',
            'body_preview',
            'tech',
            'vhost',
            'tags',
            'discovered_at',
        ]
        read_only_fields = fields


class DirectorySerializer(serializers.ModelSerializer):
    """Directory serializer"""

    website_url = serializers.CharField(source='website.url', read_only=True)
    discovered_at = serializers.DateTimeField(read_only=True)

    class Meta:
        model = Directory
        fields = [
            'id',
            'url',
            'status',
            'content_length',
            'words',
            'lines',
            'content_type',
            'duration',
            'website_url',
            'discovered_at',
        ]
        read_only_fields = fields


class IPAddressAggregatedSerializer(serializers.Serializer):
    """
    Aggregated IP address serializer.

    Built on the HostPortMapping model, grouped by IP:
    - ip: the IP address
    - hosts: all hostnames associated with this IP
    - ports: all ports associated with this IP
    - discovered_at: when the IP was first seen
    """
    ip = serializers.IPAddressField(read_only=True)
    hosts = serializers.ListField(child=serializers.CharField(), read_only=True)
    ports = serializers.ListField(child=serializers.IntegerField(), read_only=True)
    discovered_at = serializers.DateTimeField(read_only=True)


# ==================== Snapshot serializers ====================

class SubdomainSnapshotSerializer(serializers.ModelSerializer):
    """Subdomain snapshot serializer (scan history)"""

    class Meta:
        model = SubdomainSnapshot
        fields = ['id', 'name', 'discovered_at']
        read_only_fields = fields


class WebsiteSnapshotSerializer(serializers.ModelSerializer):
    """Website snapshot serializer (scan history)"""

    subdomain_name = serializers.CharField(source='subdomain.name', read_only=True)
    webserver = serializers.CharField(source='web_server', read_only=True)  # renamed field
    status_code = serializers.IntegerField(source='status', read_only=True)  # renamed field

    class Meta:
        model = WebsiteSnapshot
        fields = [
            'id',
            'url',
            'location',
            'title',
            'webserver',  # uses the renamed field
            'content_type',
            'status_code',  # uses the renamed field
            'content_length',
            'body_preview',
            'tech',
            'vhost',
            'subdomain_name',
            'discovered_at',
        ]
        read_only_fields = fields


class DirectorySnapshotSerializer(serializers.ModelSerializer):
    """Directory snapshot serializer (scan history)"""

    # DirectorySnapshot no longer references Website; map website_url to the
    # snapshot's own url to keep the field shape compatible.
    website_url = serializers.CharField(source='url', read_only=True)

    class Meta:
        model = DirectorySnapshot
        fields = [
            'id',
            'url',
            'status',
            'content_length',
            'words',
            'lines',
            'content_type',
            'duration',
            'website_url',
            'discovered_at',
        ]
        read_only_fields = fields


class EndpointSnapshotSerializer(serializers.ModelSerializer):
    """Endpoint snapshot serializer (scan history)"""

    # Expose matched GF patterns as the tags field used by the frontend
    tags = serializers.ListField(
        child=serializers.CharField(),
        source='matched_gf_patterns',
        read_only=True,
    )

    class Meta:
        model = EndpointSnapshot
        fields = [
            'id',
            'url',
            'host',
            'location',
            'title',
            'webserver',
            'content_type',
            'status_code',
            'content_length',
            'body_preview',
            'tech',
            'vhost',
            'tags',
            'discovered_at',
        ]
        read_only_fields = fields
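Several snapshot serializers rename model columns for the frontend via source=. A quick sketch of where that mapping lives at runtime:

# The declared field name is what appears in the JSON; `source` points at the model column.
ser = WebsiteSnapshotSerializer()
print(ser.fields['status_code'].source)  # 'status'
print(ser.fields['webserver'].source)    # 'web_server'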
43
backend/apps/asset/services/__init__.py
Normal file
@@ -0,0 +1,43 @@
"""Asset services - business logic layer"""

# Asset module services
from .asset import (
    SubdomainService,
    WebSiteService,
    DirectoryService,
    HostPortMappingService,
    EndpointService,
    VulnerabilityService,
)

# Snapshot module services
from .snapshot import (
    SubdomainSnapshotsService,
    HostPortMappingSnapshotsService,
    WebsiteSnapshotsService,
    DirectorySnapshotsService,
    EndpointSnapshotsService,
    VulnerabilitySnapshotsService,
)

# Statistics service
from .statistics_service import AssetStatisticsService

__all__ = [
    # Asset module
    'SubdomainService',
    'WebSiteService',
    'DirectoryService',
    'HostPortMappingService',
    'EndpointService',
    'VulnerabilityService',
    # Snapshot module
    'SubdomainSnapshotsService',
    'HostPortMappingSnapshotsService',
    'WebsiteSnapshotsService',
    'DirectorySnapshotsService',
    'EndpointSnapshotsService',
    'VulnerabilitySnapshotsService',
    # Statistics
    'AssetStatisticsService',
]
17
backend/apps/asset/services/asset/__init__.py
Normal file
@@ -0,0 +1,17 @@
"""Asset services - business logic layer for the asset module"""

from .subdomain_service import SubdomainService
from .website_service import WebSiteService
from .directory_service import DirectoryService
from .host_port_mapping_service import HostPortMappingService
from .endpoint_service import EndpointService
from .vulnerability_service import VulnerabilityService

__all__ = [
    'SubdomainService',
    'WebSiteService',
    'DirectoryService',
    'HostPortMappingService',
    'EndpointService',
    'VulnerabilityService',
]
55
backend/apps/asset/services/asset/directory_service.py
Normal file
@@ -0,0 +1,55 @@
import logging
from typing import Tuple, Iterator

from apps.asset.models.asset_models import Directory
from apps.asset.repositories import DjangoDirectoryRepository

logger = logging.getLogger(__name__)


class DirectoryService:
    """Business logic layer for directories"""

    def __init__(self, repository=None):
        """
        Initialize the directory service.

        Args:
            repository: directory repository instance (for dependency injection)
        """
        self.repo = repository or DjangoDirectoryRepository()

    # ==================== Create operations ====================

    def bulk_create_ignore_conflicts(self, directory_dtos: list) -> None:
        """
        Bulk-create directory records, ignoring conflicts (used by scan tasks).

        Args:
            directory_dtos: list of DirectoryDTO
        """
        return self.repo.bulk_create_ignore_conflicts(directory_dtos)

    # ==================== Query operations ====================

    def get_all(self):
        """
        Fetch all directories.

        Returns:
            QuerySet: directory queryset
        """
        logger.debug("Fetching all directories")
        return self.repo.get_all()

    def get_directories_by_target(self, target_id: int):
        logger.debug("Fetching all directories for target - Target ID: %d", target_id)
        return self.repo.get_by_target(target_id)

    def iter_directory_urls_by_target(self, target_id: int, chunk_size: int = 1000) -> Iterator[str]:
        """Stream all directory URLs for a target; used for large exports."""
        logger.debug("Streaming directory URLs for target - Target ID: %d", target_id)
        return self.repo.get_urls_for_export(target_id=target_id, batch_size=chunk_size)


__all__ = ['DirectoryService']
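The optional repository argument exists for dependency injection; tests can hand in a stub instead of hitting the database. A sketch with a hypothetical fake:

# FakeDirectoryRepository is hypothetical, not part of the codebase.
class FakeDirectoryRepository:
    def get_all(self):
        return []

service = DirectoryService(repository=FakeDirectoryRepository())
assert service.get_all() == []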
178
backend/apps/asset/services/asset/endpoint_service.py
Normal file
@@ -0,0 +1,178 @@
"""
Endpoint service layer.

Handles URL/endpoint business logic.
"""

import logging
from typing import List, Optional, Dict, Any, Iterator

from apps.asset.dtos.asset import EndpointDTO
from apps.asset.repositories.asset import DjangoEndpointRepository

logger = logging.getLogger(__name__)


class EndpointService:
    """
    Endpoint service class.

    Provides business logic around Endpoint (URL/endpoint) records.
    """

    def __init__(self):
        """Initialize the endpoint service"""
        self.repo = DjangoEndpointRepository()

    def bulk_create_endpoints(
        self,
        endpoints: List[EndpointDTO],
        ignore_conflicts: bool = True
    ) -> int:
        """
        Bulk-create endpoint records.

        Args:
            endpoints: endpoint DTO list
            ignore_conflicts: whether to ignore conflicts (dedup)

        Returns:
            int: number of records created
        """
        if not endpoints:
            return 0

        try:
            if ignore_conflicts:
                return self.repo.bulk_create_ignore_conflicts(endpoints)
            else:
                # A strict (non-ignoring) variant can be added to the repository
                # if it is ever needed; fall back to the ignoring one for now.
                return self.repo.bulk_create_ignore_conflicts(endpoints)
        except Exception as e:
            logger.error(f"Bulk endpoint creation failed: {e}")
            raise

    def get_endpoints_by_website(
        self,
        website_id: int,
        limit: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Fetch the endpoints of a website.

        Args:
            website_id: website ID
            limit: cap on the number of results

        Returns:
            List[Dict]: endpoint list
        """
        endpoints_dto = self.repo.get_by_website(website_id)

        if limit:
            endpoints_dto = endpoints_dto[:limit]

        endpoints = []
        for dto in endpoints_dto:
            endpoints.append({
                'url': dto.url,
                'title': dto.title,
                'status_code': dto.status_code,
                'content_length': dto.content_length,
                'webserver': dto.webserver
            })

        return endpoints

    def get_endpoints_by_target(
        self,
        target_id: int,
        limit: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Fetch the endpoints of a target.

        Args:
            target_id: target ID
            limit: cap on the number of results

        Returns:
            List[Dict]: endpoint list
        """
        endpoints_dto = self.repo.get_by_target(target_id)

        if limit:
            endpoints_dto = endpoints_dto[:limit]

        endpoints = []
        for dto in endpoints_dto:
            endpoints.append({
                'url': dto.url,
                'title': dto.title,
                'status_code': dto.status_code,
                'content_length': dto.content_length,
                'webserver': dto.webserver
            })

        return endpoints

    def count_endpoints_by_target(self, target_id: int) -> int:
        """
        Count the endpoints of a target.

        Args:
            target_id: target ID

        Returns:
            int: endpoint count
        """
        return self.repo.count_by_target(target_id)

    def get_queryset_by_target(self, target_id: int):
        return self.repo.get_queryset_by_target(target_id)

    def get_all(self):
        """Fetch all endpoints (global query)"""
        return self.repo.get_all()

    def iter_endpoint_urls_by_target(self, target_id: int, chunk_size: int = 1000) -> Iterator[str]:
        """Stream all endpoint URLs of a target, for exports."""
        queryset = self.repo.get_queryset_by_target(target_id)
        for url in queryset.values_list('url', flat=True).iterator(chunk_size=chunk_size):
            yield url

    def count_endpoints_by_website(self, website_id: int) -> int:
        """
        Count the endpoints of a website.

        Args:
            website_id: website ID

        Returns:
            int: endpoint count
        """
        return self.repo.count_by_website(website_id)

    def soft_delete_endpoints(self, endpoint_ids: List[int]) -> int:
        """
        Soft-delete endpoints.

        Args:
            endpoint_ids: endpoint ID list

        Returns:
            int: number of rows updated
        """
        return self.repo.soft_delete_by_ids(endpoint_ids)

    def hard_delete_endpoints(self, endpoint_ids: List[int]) -> tuple:
        """
        Hard-delete endpoints.

        Args:
            endpoint_ids: endpoint ID list

        Returns:
            tuple: (total deleted, per-model details)
        """
        return self.repo.hard_delete_by_ids(endpoint_ids)
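iter_endpoint_urls_by_target pulls rows in fixed-size chunks, so exports stay flat in memory regardless of table size. A sketch (target id and output path assumed):

service = EndpointService()
with open("/tmp/endpoints.txt", "w") as fh:
    for url in service.iter_endpoint_urls_by_target(target_id=1, chunk_size=1000):
        fh.write(url + "\n")  # fetched 1000 rows at a time, never all at once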
@@ -0,0 +1,61 @@
"""HostPortMapping service - business logic layer"""

import logging
from typing import List, Iterator

from apps.asset.repositories.asset import DjangoHostPortMappingRepository
from apps.asset.dtos.asset import HostPortMappingDTO

logger = logging.getLogger(__name__)


class HostPortMappingService:
    """Host-port mapping service - business logic for host-port mapping data"""

    def __init__(self):
        self.repo = DjangoHostPortMappingRepository()

    def bulk_create_ignore_conflicts(self, items: List[HostPortMappingDTO]) -> int:
        """
        Bulk-create host-port mappings (ignoring conflicts).

        Args:
            items: host-port mapping DTO list

        Returns:
            int: number of records actually created

        Note:
            Deduplicates via the database unique constraint + ignore_conflicts.
        """
        try:
            logger.debug("Service: preparing to bulk-create host-port mappings - count: %d", len(items))

            created_count = self.repo.bulk_create_ignore_conflicts(items)

            logger.info("Service: host-port mappings created - count: %d", created_count)

            return created_count

        except Exception as e:
            logger.error(
                "Service: bulk host-port mapping creation failed - count: %d, error: %s",
                len(items),
                str(e),
                exc_info=True
            )
            raise

    def iter_host_port_by_target(self, target_id: int, batch_size: int = 1000):
        return self.repo.get_for_export(target_id=target_id, batch_size=batch_size)

    def get_ip_aggregation_by_target(self, target_id: int, search: str = None):
        return self.repo.get_ip_aggregation_by_target(target_id, search=search)

    def get_all_ip_aggregation(self, search: str = None):
        """Fetch IP aggregation data across all targets (global query)"""
        return self.repo.get_all_ip_aggregation(search=search)

    def iter_ips_by_target(self, target_id: int, batch_size: int = 1000) -> Iterator[str]:
        """Stream all unique IP addresses of a target."""
        return self.repo.get_ips_for_export(target_id=target_id, batch_size=batch_size)
123
backend/apps/asset/services/asset/subdomain_service.py
Normal file
@@ -0,0 +1,123 @@
import logging
from typing import Any, Tuple, List, Dict

from apps.asset.repositories import DjangoSubdomainRepository
from apps.asset.dtos import SubdomainDTO

logger = logging.getLogger(__name__)


class SubdomainService:
    """Business logic layer for subdomains"""

    def __init__(self, repository=None):
        """
        Initialize the subdomain service.

        Args:
            repository: subdomain repository instance (for dependency injection)
        """
        self.repo = repository or DjangoSubdomainRepository()

    # ==================== Query operations ====================

    def get_all(self):
        """
        Fetch all subdomains.

        Returns:
            QuerySet: subdomain queryset
        """
        logger.debug("Fetching all subdomains")
        return self.repo.get_all()

    # ==================== Create operations ====================

    def get_or_create(self, name: str, target_id: int) -> Tuple[Any, bool]:
        """
        Get or create a subdomain.

        Args:
            name: subdomain name
            target_id: target ID

        Returns:
            (Subdomain object, whether it was newly created)
        """
        logger.debug("Get-or-create subdomain - Name: %s, Target ID: %d", name, target_id)
        return self.repo.get_or_create(name, target_id)

    def bulk_create_ignore_conflicts(self, items: List[SubdomainDTO]) -> None:
        """
        Bulk-create subdomains, ignoring conflicts.

        Args:
            items: subdomain DTO list

        Note:
            Uses the ignore_conflicts strategy; duplicate records are skipped.
        """
        logger.debug("Bulk-creating subdomains - count: %d", len(items))
        return self.repo.bulk_create_ignore_conflicts(items)

    def get_by_names_and_target_id(self, names: set, target_id: int) -> dict:
        """
        Batch-fetch subdomains by name set and target ID.

        Args:
            names: set of domain names
            target_id: target ID

        Returns:
            dict: {name: Subdomain object}
        """
        logger.debug("Batch-fetching subdomains - count: %d, Target ID: %d", len(names), target_id)
        return self.repo.get_by_names_and_target_id(names, target_id)

    def get_subdomain_names_by_target(self, target_id: int) -> List[str]:
        """
        Fetch all subdomain names of a target.

        Args:
            target_id: target ID

        Returns:
            List[str]: subdomain names
        """
        logger.debug("Fetching all subdomains for target - Target ID: %d", target_id)
        # Goes through the repository layer, which already streams with iterator()
        return list(self.repo.get_domains_for_export(target_id=target_id))

    def get_subdomains_by_target(self, target_id: int):
        return self.repo.get_by_target(target_id)

    def count_subdomains_by_target(self, target_id: int) -> int:
        """
        Count the subdomains of a target.

        Args:
            target_id: target ID

        Returns:
            int: subdomain count
        """
        logger.debug("Counting subdomains for target - Target ID: %d", target_id)
        return self.repo.count_by_target(target_id)

    def iter_subdomain_names_by_target(self, target_id: int, chunk_size: int = 1000):
        """
        Stream all subdomain names of a target (memory-friendly).

        Args:
            target_id: target ID
            chunk_size: batch size

        Yields:
            str: subdomain name
        """
        logger.debug("Streaming subdomains for target - Target ID: %d, chunk size: %d", target_id, chunk_size)
        # Goes through the repository layer, which already streams with iterator()
        return self.repo.get_domains_for_export(target_id=target_id, batch_size=chunk_size)


__all__ = ['SubdomainService']
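get_by_names_and_target_id resolves many names in one query and returns a name-keyed dict, which makes missing entries cheap to compute. A sketch (names and target id assumed):

service = SubdomainService()
wanted = {"a.example.com", "b.example.com"}
by_name = service.get_by_names_and_target_id(wanted, target_id=1)
missing = wanted - by_name.keys()  # names not yet stored for this target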
81
backend/apps/asset/services/asset/vulnerability_service.py
Normal file
@@ -0,0 +1,81 @@
"""Vulnerability service - business logic layer for vulnerability assets"""

import logging
from typing import List

from apps.asset.models import Vulnerability
from apps.asset.dtos.asset import VulnerabilityDTO
from apps.common.decorators import auto_ensure_db_connection

logger = logging.getLogger(__name__)


@auto_ensure_db_connection
class VulnerabilityService:
    """Vulnerability asset service.

    Currently offers basic bulk creation; ignore_conflicts defers dedup to
    database unique constraints.
    """

    def bulk_create_ignore_conflicts(self, items: List[VulnerabilityDTO]) -> None:
        """Bulk-create vulnerability asset records, ignoring conflicts.

        Note:
            - Whether rows are deduplicated depends on the model's unique or
              partial-unique constraints;
            - The Vulnerability model currently defines no unique constraint,
              so every record is kept.
        """
        if not items:
            logger.debug("Vulnerability asset list is empty, skipping save")
            return

        try:
            vulns = [
                Vulnerability(
                    target_id=item.target_id,
                    url=item.url,
                    vuln_type=item.vuln_type,
                    severity=item.severity,
                    source=item.source,
                    cvss_score=item.cvss_score,
                    description=item.description,
                    raw_output=item.raw_output,
                )
                for item in items
            ]

            Vulnerability.objects.bulk_create(vulns, ignore_conflicts=True)
            logger.info("Vulnerability assets saved - count: %d", len(vulns))

        except Exception as e:
            logger.error(
                "Failed to bulk-save vulnerability assets - count: %d, error: %s",
                len(items),
                str(e),
                exc_info=True,
            )
            raise

    # ==================== Query methods ====================

    def get_all(self):
        """Return the QuerySet of all vulnerabilities (global vulnerability list).

        Returns:
            QuerySet[Vulnerability]: all vulnerabilities, newest first
        """
        return Vulnerability.objects.filter(deleted_at__isnull=True).order_by("-discovered_at")

    def get_queryset_by_target(self, target_id: int):
        """Return the vulnerability QuerySet for a target (for pagination).

        Args:
            target_id: target ID

        Returns:
            QuerySet[Vulnerability]: the target's vulnerabilities, newest first
        """
        return Vulnerability.objects.filter(target_id=target_id).order_by("-discovered_at")

    def count_by_target(self, target_id: int) -> int:
        """Count the vulnerabilities of a target."""
        return Vulnerability.objects.filter(target_id=target_id).count()
91
backend/apps/asset/services/asset/website_service.py
Normal file
@@ -0,0 +1,91 @@
import logging
from typing import Tuple, List

from apps.asset.repositories import DjangoWebSiteRepository
from apps.asset.dtos import WebSiteDTO

logger = logging.getLogger(__name__)


class WebSiteService:
    """Business logic layer for websites"""

    def __init__(self, repository=None):
        """
        Initialize the website service.

        Args:
            repository: website repository instance (for dependency injection)
        """
        self.repo = repository or DjangoWebSiteRepository()

    # ==================== Create operations ====================

    def bulk_create_ignore_conflicts(self, website_dtos: List[WebSiteDTO]) -> None:
        """
        Bulk-create website records, ignoring conflicts (used by scan tasks).

        Args:
            website_dtos: list of WebSiteDTO

        Note:
            Uses the ignore_conflicts strategy; duplicate records are skipped.
        """
        logger.debug("Bulk-creating websites - count: %d", len(website_dtos))
        return self.repo.bulk_create_ignore_conflicts(website_dtos)

    # ==================== Query operations ====================

    def get_by_url(self, url: str, target_id: int) -> int:
        """
        Look up a website ID by URL and target_id.

        Args:
            url: website URL
            target_id: target ID

        Returns:
            int: website ID, or None if not found
        """
        return self.repo.get_by_url(url=url, target_id=target_id)

    def get_all(self):
        """
        Fetch all websites.

        Returns:
            QuerySet: website queryset
        """
        logger.debug("Fetching all websites")
        return self.repo.get_all()

    def get_websites_by_target(self, target_id: int):
        return self.repo.get_by_target(target_id)

    def count_websites_by_scan(self, scan_id: int) -> int:
        """
        Count the websites of a scan.

        Args:
            scan_id: scan ID

        Returns:
            int: website count
        """
        logger.debug("Counting websites for scan - Scan ID: %d", scan_id)
        return self.repo.count_by_scan(scan_id)

    def iter_website_urls_by_target(self, target_id: int, chunk_size: int = 1000):
        """Stream all site URLs of a target (memory-friendly, delegated to the repository layer)"""
        logger.debug(
            "Streaming site URLs for target - Target ID: %d, chunk size: %d",
            target_id,
            chunk_size,
        )
        # Goes through the repository layer so the service never touches the ORM directly
        return self.repo.get_urls_for_export(target_id=target_id, batch_size=chunk_size)


__all__ = ['WebSiteService']
17
backend/apps/asset/services/snapshot/__init__.py
Normal file
@@ -0,0 +1,17 @@
"""Snapshot services"""

from .subdomain_snapshots_service import SubdomainSnapshotsService
from .host_port_mapping_snapshots_service import HostPortMappingSnapshotsService
from .website_snapshots_service import WebsiteSnapshotsService
from .directory_snapshots_service import DirectorySnapshotsService
from .endpoint_snapshots_service import EndpointSnapshotsService
from .vulnerability_snapshots_service import VulnerabilitySnapshotsService

__all__ = [
    'SubdomainSnapshotsService',
    'HostPortMappingSnapshotsService',
    'WebsiteSnapshotsService',
    'DirectorySnapshotsService',
    'EndpointSnapshotsService',
    'VulnerabilitySnapshotsService',
]
@@ -0,0 +1,83 @@
"""Directory snapshots service - business logic layer"""

import logging
from typing import List, Iterator

from apps.asset.repositories.snapshot import DjangoDirectorySnapshotRepository
from apps.asset.services.asset import DirectoryService
from apps.asset.dtos.snapshot import DirectorySnapshotDTO

logger = logging.getLogger(__name__)


class DirectorySnapshotsService:
    """Directory snapshot service - manages snapshots and asset sync in one place"""

    def __init__(self):
        self.snapshot_repo = DjangoDirectorySnapshotRepository()
        self.asset_service = DirectoryService()

    def save_and_sync(self, items: List[DirectorySnapshotDTO]) -> None:
        """
        Save directory snapshots and sync them to the asset table (single entry point).

        Flow:
        1. Save to the snapshot table (full record, includes scan_id)
        2. Sync to the asset table (deduplicated, no scan_id)

        Args:
            items: directory snapshot DTO list (must include website_id)

        Raises:
            ValueError: if an item's website_id is None
            Exception: on database failure
        """
        if not items:
            return

        # Check that the Scan still exists (guards against racy writes after deletion)
        scan_id = items[0].scan_id
        from apps.scan.repositories import DjangoScanRepository
        if not DjangoScanRepository().exists(scan_id):
            logger.warning("Scan deleted, skipping directory snapshot save - scan_id=%s, count=%d", scan_id, len(items))
            return

        try:
            logger.debug("Saving directory snapshots and syncing to asset table - count: %d", len(items))

            # Step 1: save to the snapshot table
            logger.debug("Step 1: saving to the snapshot table")
            self.snapshot_repo.save_snapshots(items)

            # Step 2: convert to asset DTOs and save to the asset table
            # Note: dedup relies on the database UNIQUE constraint + ignore_conflicts
            # - new records: inserted into the asset table
            # - existing records: skipped automatically
            logger.debug("Step 2: syncing to the asset table (through the service layer)")
            asset_items = [item.to_asset_dto() for item in items]

            self.asset_service.bulk_create_ignore_conflicts(asset_items)

            logger.info("Directory snapshots and asset data saved - count: %d", len(items))

        except Exception as e:
            logger.error(
                "Failed to save directory snapshots - count: %d, error: %s",
                len(items),
                str(e),
                exc_info=True
            )
            raise

    def get_by_scan(self, scan_id: int):
        return self.snapshot_repo.get_by_scan(scan_id)

    def get_all(self):
        """Fetch all directory snapshots"""
        return self.snapshot_repo.get_all()

    def iter_directory_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
        """Stream all directory URLs of a scan."""
        queryset = self.snapshot_repo.get_by_scan(scan_id)
        for snapshot in queryset.iterator(chunk_size=chunk_size):
            yield snapshot.url
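All of the snapshot services below repeat this same guarded, two-step write. A sketch of the calling pattern, with the DTO list assumed to come from the scan pipeline:

snapshot_dtos = [...]  # DirectorySnapshotDTO items produced by a scan task (assumed)
service = DirectorySnapshotsService()
service.save_and_sync(snapshot_dtos)  # silently no-ops if the Scan row is gone
# Step 1 keeps the per-scan history; step 2 upserts the deduplicated asset rows
# via to_asset_dto() + ignore_conflicts.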
@@ -0,0 +1,83 @@
"""Endpoint snapshots service - business logic layer"""

import logging
from typing import List, Iterator

from apps.asset.repositories.snapshot import DjangoEndpointSnapshotRepository
from apps.asset.services.asset import EndpointService
from apps.asset.dtos.snapshot import EndpointSnapshotDTO

logger = logging.getLogger(__name__)


class EndpointSnapshotsService:
    """Endpoint snapshot service - manages snapshots and asset sync in one place"""

    def __init__(self):
        self.snapshot_repo = DjangoEndpointSnapshotRepository()
        self.asset_service = EndpointService()

    def save_and_sync(self, items: List[EndpointSnapshotDTO]) -> None:
        """
        Save endpoint snapshots and sync them to the asset table (single entry point).

        Flow:
        1. Save to the snapshot table (full record)
        2. Sync to the asset table (deduplicated)

        Args:
            items: endpoint snapshot DTO list (must include target_id)

        Raises:
            ValueError: if an item's target_id is None
            Exception: on database failure
        """
        if not items:
            return

        # Check that the Scan still exists (guards against racy writes after deletion)
        scan_id = items[0].scan_id
        from apps.scan.repositories import DjangoScanRepository
        if not DjangoScanRepository().exists(scan_id):
            logger.warning("Scan deleted, skipping endpoint snapshot save - scan_id=%s, count=%d", scan_id, len(items))
            return

        try:
            logger.debug("Saving endpoint snapshots and syncing to asset table - count: %d", len(items))

            # Step 1: save to the snapshot table
            logger.debug("Step 1: saving to the snapshot table")
            self.snapshot_repo.save_snapshots(items)

            # Step 2: convert to asset DTOs and save to the asset table
            # Note: dedup relies on the database UNIQUE constraint + ignore_conflicts
            # - new records: inserted into the asset table
            # - existing records: skipped automatically
            logger.debug("Step 2: syncing to the asset table (through the service layer)")
            asset_items = [item.to_asset_dto() for item in items]

            self.asset_service.bulk_create_endpoints(asset_items)

            logger.info("Endpoint snapshots and asset data saved - count: %d", len(items))

        except Exception as e:
            logger.error(
                "Failed to save endpoint snapshots - count: %d, error: %s",
                len(items),
                str(e),
                exc_info=True
            )
            raise

    def get_by_scan(self, scan_id: int):
        return self.snapshot_repo.get_by_scan(scan_id)

    def get_all(self):
        """Fetch all endpoint snapshots"""
        return self.snapshot_repo.get_all()

    def iter_endpoint_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
        """Stream all endpoint URLs of a scan."""
        queryset = self.snapshot_repo.get_by_scan(scan_id)
        for snapshot in queryset.iterator(chunk_size=chunk_size):
            yield snapshot.url
@@ -0,0 +1,81 @@
"""HostPortMapping snapshots service - business logic layer"""

import logging
from typing import List, Iterator

from apps.asset.repositories.snapshot import DjangoHostPortMappingSnapshotRepository
from apps.asset.services.asset import HostPortMappingService
from apps.asset.dtos.snapshot import HostPortMappingSnapshotDTO

logger = logging.getLogger(__name__)


class HostPortMappingSnapshotsService:
    """HostPortMapping snapshot service - manages snapshots and asset sync in one place"""

    def __init__(self):
        self.snapshot_repo = DjangoHostPortMappingSnapshotRepository()
        self.asset_service = HostPortMappingService()

    def save_and_sync(self, items: List[HostPortMappingSnapshotDTO]) -> None:
        """
        Save host-port mapping snapshots and sync them to the asset table (single entry point).

        Flow:
        1. Save to the snapshot table (full record, includes scan_id)
        2. Sync to the asset table (deduplicated, no scan_id)

        Args:
            items: host-port mapping snapshot DTO list (must include target_id)

        Note:
            target_id is carried in the DTO; no extra argument is needed.
        """
        logger.debug("Saving host-port mapping snapshots - count: %d", len(items))

        if not items:
            logger.debug("No snapshot data, skipping save")
            return

        # Check that the Scan still exists (guards against racy writes after deletion)
        scan_id = items[0].scan_id
        from apps.scan.repositories import DjangoScanRepository
        if not DjangoScanRepository().exists(scan_id):
            logger.warning("Scan deleted, skipping host-port snapshot save - scan_id=%s, count=%d", scan_id, len(items))
            return

        try:
            # Step 1: save to the snapshot table
            logger.debug("Step 1: saving to the snapshot table")
            self.snapshot_repo.save_snapshots(items)

            # Step 2: convert to asset DTOs and save to the asset table
            # Note: dedup relies on the database UNIQUE constraint + ignore_conflicts
            # - new records: inserted into the asset table
            # - existing records: skipped automatically
            logger.debug("Step 2: syncing to the asset table (through the service layer)")
            asset_items = [item.to_asset_dto() for item in items]

            self.asset_service.bulk_create_ignore_conflicts(asset_items)

            logger.info("Host-port mapping snapshots and asset data saved - count: %d", len(items))

        except Exception as e:
            logger.error(
                "Failed to save host-port mapping snapshots - count: %d, error: %s",
                len(items),
                str(e),
                exc_info=True
            )
            raise

    def get_ip_aggregation_by_scan(self, scan_id: int, search: str = None):
        return self.snapshot_repo.get_ip_aggregation_by_scan(scan_id, search=search)

    def get_all_ip_aggregation(self, search: str = None):
        """Fetch IP aggregation data across all scans"""
        return self.snapshot_repo.get_all_ip_aggregation(search=search)

    def iter_ips_by_scan(self, scan_id: int, batch_size: int = 1000) -> Iterator[str]:
        """Stream all unique IP addresses of a scan."""
        return self.snapshot_repo.get_ips_for_export(scan_id=scan_id, batch_size=batch_size)
@@ -0,0 +1,79 @@
import logging
from typing import List, Iterator

from apps.asset.dtos import SubdomainSnapshotDTO
from apps.asset.repositories import DjangoSubdomainSnapshotRepository

logger = logging.getLogger(__name__)


class SubdomainSnapshotsService:
    """Subdomain snapshot service - business logic for subdomain snapshot data"""

    def __init__(self):
        self.subdomain_snapshot_repo = DjangoSubdomainSnapshotRepository()

    def save_and_sync(self, items: List[SubdomainSnapshotDTO]) -> None:
        """
        Save subdomain snapshots and sync them to the asset table (single entry point).

        Flow:
        1. Save to the snapshot table (full record, includes scan_id)
        2. Sync to the asset table (deduplicated, no scan_id)

        Args:
            items: subdomain snapshot DTO list (includes target_id)

        Note:
            target_id is carried in the DTO; no extra argument is needed.
        """
        logger.debug("Saving subdomain snapshots - count: %d", len(items))

        if not items:
            logger.debug("No snapshot data, skipping save")
            return

        # Check that the Scan still exists (guards against racy writes after deletion)
        scan_id = items[0].scan_id
        from apps.scan.repositories import DjangoScanRepository
        if not DjangoScanRepository().exists(scan_id):
            logger.warning("Scan deleted, skipping snapshot save - scan_id=%s, count=%d", scan_id, len(items))
            return

        try:
            # Step 1: save to the snapshot table
            logger.debug("Step 1: saving to the snapshot table")
            self.subdomain_snapshot_repo.save_subdomain_snapshots(items)

            # Step 2: convert to asset DTOs and save to the asset table
            # Note: dedup relies on the database UNIQUE constraint + ignore_conflicts
            # - new subdomains: inserted into the asset table
            # - existing subdomains: skipped automatically (no update; the asset
            #   table only keeps the core data)
            asset_items = [item.to_asset_dto() for item in items]

            from apps.asset.services import SubdomainService
            subdomain_service = SubdomainService()
            subdomain_service.bulk_create_ignore_conflicts(asset_items)

            logger.info("Subdomain snapshots and business data saved - count: %d", len(items))

        except Exception as e:
            logger.error(
                "Failed to save subdomain snapshots - count: %d, error: %s",
                len(items),
                str(e),
                exc_info=True
            )
            raise

    def get_by_scan(self, scan_id: int):
        return self.subdomain_snapshot_repo.get_by_scan(scan_id)

    def get_all(self):
        """Fetch all subdomain snapshots"""
        return self.subdomain_snapshot_repo.get_all()

    def iter_subdomain_names_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
        queryset = self.subdomain_snapshot_repo.get_by_scan(scan_id)
        for snapshot in queryset.iterator(chunk_size=chunk_size):
            yield snapshot.name
@@ -0,0 +1,81 @@
"""Vulnerability snapshots service - business logic layer"""

import logging
from typing import List, Iterator

from apps.asset.repositories.snapshot import DjangoVulnerabilitySnapshotRepository
from apps.asset.services.asset.vulnerability_service import VulnerabilityService
from apps.asset.dtos.snapshot import VulnerabilitySnapshotDTO

logger = logging.getLogger(__name__)


class VulnerabilitySnapshotsService:
    """Vulnerability snapshot service - manages snapshots and asset sync in one place.

    The flow matches Website/Directory etc.:
    1. Save to the VulnerabilitySnapshot table (includes scan_id)
    2. Convert to VulnerabilityDTO and sync to the Vulnerability asset table (keyed by target_id)
    """

    def __init__(self):
        self.snapshot_repo = DjangoVulnerabilitySnapshotRepository()
        self.asset_service = VulnerabilityService()

    def save_and_sync(self, items: List[VulnerabilitySnapshotDTO]) -> None:
        """Save vulnerability snapshots and sync them to the vulnerability asset table."""
        if not items:
            return

        # Check that the Scan still exists (guards against racy writes after deletion)
        scan_id = items[0].scan_id
        from apps.scan.repositories import DjangoScanRepository
        if not DjangoScanRepository().exists(scan_id):
            logger.warning("Scan deleted, skipping vulnerability snapshot save - scan_id=%s, count=%d", scan_id, len(items))
            return

        try:
            logger.debug("Saving vulnerability snapshots and syncing to asset table - count: %d", len(items))

            # Step 1: save to the snapshot table
            logger.debug("Step 1: saving to the vulnerability snapshot table")
            self.snapshot_repo.save_snapshots(items)

            # Step 2: convert to asset DTOs and save to the asset table
            logger.debug("Step 2: syncing to the vulnerability asset table")
            asset_items = [item.to_asset_dto() for item in items]
            self.asset_service.bulk_create_ignore_conflicts(asset_items)

            logger.info("Vulnerability snapshots and asset data saved - count: %d", len(items))

            # Step 3: emit the vulnerabilities-saved signal (notification modules may listen)
            from apps.common.signals import vulnerabilities_saved
            vulnerabilities_saved.send(
                sender=self.__class__,
                items=items,
                scan_id=scan_id,
                target_id=items[0].target_id if items else None
            )

        except Exception as e:
            logger.error(
                "Failed to save vulnerability snapshots - count: %d, error: %s",
                len(items),
                str(e),
                exc_info=True,
            )
            raise

    def get_by_scan(self, scan_id: int):
        """Fetch all vulnerability snapshots of a scan."""
        return self.snapshot_repo.get_by_scan(scan_id)

    def get_all(self):
        """Fetch all vulnerability snapshots"""
        return self.snapshot_repo.get_all()

    def iter_vuln_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
        """Stream all vulnerability URLs of a scan."""
        queryset = self.snapshot_repo.get_by_scan(scan_id)
        for snapshot in queryset.iterator(chunk_size=chunk_size):
            yield snapshot.url
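Step 3 above broadcasts through a Django signal, so notification code can react without the service knowing about it. A sketch of a hypothetical receiver:

from django.dispatch import receiver
from apps.common.signals import vulnerabilities_saved

@receiver(vulnerabilities_saved)
def notify_on_new_vulns(sender, items, scan_id, target_id, **kwargs):
    # notify_on_new_vulns is hypothetical; the kwargs match the send() call above
    print(f"scan {scan_id}: {len(items)} vulnerabilities saved")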
@@ -0,0 +1,83 @@
"""Website snapshots service - business logic layer"""

import logging
from typing import List, Iterator

from apps.asset.repositories.snapshot import DjangoWebsiteSnapshotRepository
from apps.asset.services.asset import WebSiteService
from apps.asset.dtos.snapshot import WebsiteSnapshotDTO

logger = logging.getLogger(__name__)


class WebsiteSnapshotsService:
    """Website snapshot service - manages snapshots and asset sync in one place"""

    def __init__(self):
        self.snapshot_repo = DjangoWebsiteSnapshotRepository()
        self.asset_service = WebSiteService()

    def save_and_sync(self, items: List[WebsiteSnapshotDTO]) -> None:
        """
        Save website snapshots and sync them to the asset table (single entry point).

        Flow:
        1. Save to the snapshot table (full record, includes scan_id)
        2. Sync to the asset table (deduplicated, no scan_id)

        Args:
            items: website snapshot DTO list (must include target_id)

        Raises:
            ValueError: if an item's target_id is None
            Exception: on database failure
        """
        if not items:
            return

        # Check that the Scan still exists (guards against racy writes after deletion)
        scan_id = items[0].scan_id
        from apps.scan.repositories import DjangoScanRepository
        if not DjangoScanRepository().exists(scan_id):
            logger.warning("Scan deleted, skipping website snapshot save - scan_id=%s, count=%d", scan_id, len(items))
            return

        try:
            logger.debug("Saving website snapshots and syncing to asset table - count: %d", len(items))

            # Step 1: save to the snapshot table
            logger.debug("Step 1: saving to the snapshot table")
            self.snapshot_repo.save_snapshots(items)

            # Step 2: convert to asset DTOs and save to the asset table
            # Note: dedup relies on the database UNIQUE constraint + ignore_conflicts
            # - new records: inserted into the asset table
            # - existing records: skipped automatically
            logger.debug("Step 2: syncing to the asset table (through the service layer)")
            asset_items = [item.to_asset_dto() for item in items]

            self.asset_service.bulk_create_ignore_conflicts(asset_items)

            logger.info("Website snapshots and asset data saved - count: %d", len(items))

        except Exception as e:
            logger.error(
                "Failed to save website snapshots - count: %d, error: %s",
                len(items),
                str(e),
                exc_info=True
            )
            raise

    def get_by_scan(self, scan_id: int):
        return self.snapshot_repo.get_by_scan(scan_id)

    def get_all(self):
        """Fetch all website snapshots"""
        return self.snapshot_repo.get_all()

    def iter_website_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
        """Stream all site URLs of a scan (newest first)."""
        queryset = self.snapshot_repo.get_by_scan(scan_id)
        for snapshot in queryset.iterator(chunk_size=chunk_size):
            yield snapshot.url
162
backend/apps/asset/services/statistics_service.py
Normal file
@@ -0,0 +1,162 @@
"""Asset statistics service"""
import logging
from typing import Optional

from django.db.models import Count

from apps.asset.repositories import AssetStatisticsRepository
from apps.asset.models import (
    AssetStatistics,
    StatisticsHistory,
    Subdomain,
    WebSite,
    Endpoint,
    HostPortMapping,
    Vulnerability,
)
from apps.targets.models import Target
from apps.scan.models import Scan

logger = logging.getLogger(__name__)


class AssetStatisticsService:
    """
    Asset statistics service.

    Responsibilities:
    - Serve statistics
    - Refresh statistics (called by the scheduled task)
    """

    def __init__(self):
        self.repo = AssetStatisticsRepository()

    def get_statistics(self) -> dict:
        """
        Fetch statistics.

        Returns:
            Statistics dict.
        """
        stats = self.repo.get_statistics()

        if stats is None:
            # No statistics row yet; return defaults
            return {
                'total_targets': 0,
                'total_subdomains': 0,
                'total_ips': 0,
                'total_endpoints': 0,
                'total_websites': 0,
                'total_vulns': 0,
                'total_assets': 0,
                'running_scans': Scan.objects.filter(status='running').count(),
                'updated_at': None,
                # Deltas
                'change_targets': 0,
                'change_subdomains': 0,
                'change_ips': 0,
                'change_endpoints': 0,
                'change_websites': 0,
                'change_vulns': 0,
                'change_assets': 0,
                'vuln_by_severity': self._get_vuln_by_severity(),
            }

        # Running scans are counted live (small number, no caching needed)
        running_scans = Scan.objects.filter(status='running').count()

        return {
            'total_targets': stats.total_targets,
            'total_subdomains': stats.total_subdomains,
            'total_ips': stats.total_ips,
            'total_endpoints': stats.total_endpoints,
            'total_websites': stats.total_websites,
            'total_vulns': stats.total_vulns,
            'total_assets': stats.total_assets,
            'running_scans': running_scans,
            'updated_at': stats.updated_at,
            # Delta = current value - previous value
            'change_targets': stats.total_targets - stats.prev_targets,
            'change_subdomains': stats.total_subdomains - stats.prev_subdomains,
            'change_ips': stats.total_ips - stats.prev_ips,
            'change_endpoints': stats.total_endpoints - stats.prev_endpoints,
            'change_websites': stats.total_websites - stats.prev_websites,
            'change_vulns': stats.total_vulns - stats.prev_vulns,
            'change_assets': stats.total_assets - stats.prev_assets,
            # Vulnerability severity breakdown
            'vuln_by_severity': self._get_vuln_by_severity(),
        }

    def _get_vuln_by_severity(self) -> dict:
        """Count vulnerabilities grouped by severity"""
        result = Vulnerability.objects.values('severity').annotate(count=Count('id'))
        severity_map = {item['severity']: item['count'] for item in result}
        return {
            'critical': severity_map.get('critical', 0),
            'high': severity_map.get('high', 0),
            'medium': severity_map.get('medium', 0),
            'low': severity_map.get('low', 0),
            'info': severity_map.get('info', 0),
        }

    def refresh_statistics(self) -> AssetStatistics:
        """
        Refresh statistics (runs the actual COUNT queries).

        Meant for the scheduled task; avoid calling it directly from API views.

        Returns:
            The updated statistics object.
        """
        logger.info("Refreshing asset statistics...")

        # Run the COUNT queries
        total_targets = Target.objects.filter(deleted_at__isnull=True).count()
        total_subdomains = Subdomain.objects.count()
        total_ips = HostPortMapping.objects.values('ip').distinct().count()
        total_endpoints = Endpoint.objects.count()
        total_websites = WebSite.objects.count()
        total_vulns = Vulnerability.objects.count()

        # Update the statistics table
        stats = self.repo.update_statistics(
            total_targets=total_targets,
            total_subdomains=total_subdomains,
            total_ips=total_ips,
            total_endpoints=total_endpoints,
            total_websites=total_websites,
            total_vulns=total_vulns,
        )

        # Save the daily snapshot (for the trend chart)
        self.repo.save_daily_snapshot(stats)

        logger.info("Asset statistics refresh complete")
        return stats

    def get_statistics_history(self, days: int = 7) -> list[dict]:
        """
        Fetch historical statistics (for the trend chart).

        Args:
            days: how many recent days to fetch, 7 by default

        Returns:
            History entries, each with a date and the statistics fields.
        """
        history = self.repo.get_history(days=days)
        return [
            {
                'date': h.date.isoformat(),
                'totalTargets': h.total_targets,
                'totalSubdomains': h.total_subdomains,
                'totalIps': h.total_ips,
                'totalEndpoints': h.total_endpoints,
                'totalWebsites': h.total_websites,
                'totalVulns': h.total_vulns,
                'totalAssets': h.total_assets,
            }
            for h in history
        ]
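_get_vuln_by_severity is a single GROUP BY; severities absent from the table simply default to zero. The underlying query, sketched:

from django.db.models import Count
rows = Vulnerability.objects.values('severity').annotate(count=Count('id'))
# e.g. [{'severity': 'high', 'count': 3}, {'severity': 'low', 'count': 12}]
# severity_map.get(..., 0) then fills in the missing levels.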
28
backend/apps/asset/urls.py
Normal file
@@ -0,0 +1,28 @@
"""
URL configuration for the asset app
"""

from django.urls import path, include
from rest_framework.routers import DefaultRouter
from .views import (
    SubdomainViewSet,
    WebSiteViewSet,
    DirectoryViewSet,
    VulnerabilityViewSet,
    AssetStatisticsViewSet,
)

# Create the DRF router
router = DefaultRouter()

# Register ViewSets
# Note: the IPAddress model was refactored into HostPortMapping; its routes were removed
router.register(r'subdomains', SubdomainViewSet, basename='subdomain')
router.register(r'websites', WebSiteViewSet, basename='website')
router.register(r'directories', DirectoryViewSet, basename='directory')
router.register(r'vulnerabilities', VulnerabilityViewSet, basename='vulnerability')
router.register(r'statistics', AssetStatisticsViewSet, basename='asset-statistics')

urlpatterns = [
    path('assets/', include(router.urls)),
]
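Because every route here is router-generated, the basenames above determine the reversible URL names. A small sketch (the /api/ prefix depends on the root urlconf, which is not part of this file):

    from django.urls import reverse

    print(reverse('subdomain-list'))            # e.g. /api/assets/subdomains/
    print(reverse('asset-statistics-list'))     # e.g. /api/assets/statistics/
    # The @action(url_path='history') on the statistics ViewSet reverses as:
    print(reverse('asset-statistics-history'))  # e.g. /api/assets/statistics/history/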
562
backend/apps/asset/views.py
Normal file
@@ -0,0 +1,562 @@
import logging
from rest_framework import viewsets, status, filters
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.request import Request
from rest_framework.exceptions import NotFound, ValidationError as DRFValidationError
from django.core.exceptions import ValidationError, ObjectDoesNotExist
from django.db import DatabaseError, IntegrityError, OperationalError
from django.http import StreamingHttpResponse

from .serializers import (
    SubdomainListSerializer, WebSiteSerializer, DirectorySerializer,
    VulnerabilitySerializer, EndpointListSerializer, IPAddressAggregatedSerializer,
    SubdomainSnapshotSerializer, WebsiteSnapshotSerializer, DirectorySnapshotSerializer,
    EndpointSnapshotSerializer, VulnerabilitySnapshotSerializer
)
from .services import (
    SubdomainService, WebSiteService, DirectoryService,
    VulnerabilityService, AssetStatisticsService, EndpointService, HostPortMappingService
)
from .services.snapshot import (
    SubdomainSnapshotsService, WebsiteSnapshotsService, DirectorySnapshotsService,
    EndpointSnapshotsService, HostPortMappingSnapshotsService, VulnerabilitySnapshotsService
)
from apps.common.pagination import BasePagination

logger = logging.getLogger(__name__)


class AssetStatisticsViewSet(viewsets.ViewSet):
    """
    Asset statistics API

    Serves the statistics the dashboard needs (pre-aggregated; reads the cache table)
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = AssetStatisticsService()

    def list(self, request):
        """
        Fetch asset statistics

        GET /assets/statistics/

        Returns:
        - totalTargets: total number of targets
        - totalSubdomains: total number of subdomains
        - totalIps: total number of IPs
        - totalEndpoints: total number of endpoints
        - totalWebsites: total number of websites
        - totalVulns: total number of vulnerabilities
        - totalAssets: total number of assets
        - runningScans: number of scans currently running
        - updatedAt: when the statistics were last refreshed
        """
        try:
            stats = self.service.get_statistics()
            return Response({
                'totalTargets': stats['total_targets'],
                'totalSubdomains': stats['total_subdomains'],
                'totalIps': stats['total_ips'],
                'totalEndpoints': stats['total_endpoints'],
                'totalWebsites': stats['total_websites'],
                'totalVulns': stats['total_vulns'],
                'totalAssets': stats['total_assets'],
                'runningScans': stats['running_scans'],
                'updatedAt': stats['updated_at'],
                # Change values
                'changeTargets': stats['change_targets'],
                'changeSubdomains': stats['change_subdomains'],
                'changeIps': stats['change_ips'],
                'changeEndpoints': stats['change_endpoints'],
                'changeWebsites': stats['change_websites'],
                'changeVulns': stats['change_vulns'],
                'changeAssets': stats['change_assets'],
                # Vulnerability severity distribution
                'vulnBySeverity': stats['vuln_by_severity'],
            })
        except (DatabaseError, OperationalError) as e:
            logger.exception("Failed to fetch asset statistics")
            return Response(
                {'error': 'Failed to fetch statistics'},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR
            )

    @action(detail=False, methods=['get'], url_path='history')
    def history(self, request: Request):
        """
        Fetch historical statistics (used by the trend chart)

        GET /assets/statistics/history/?days=7

        Query Parameters:
            days: how many recent days to fetch; default 7, max 90

        Returns:
            A list of history entries
        """
        try:
            days_param = request.query_params.get('days', '7')
            try:
                days = int(days_param)
            except (ValueError, TypeError):
                days = 7
            days = min(max(days, 1), 90)  # clamp to 1-90 days

            history = self.service.get_statistics_history(days=days)
            return Response(history)
        except (DatabaseError, OperationalError) as e:
            logger.exception("Failed to fetch statistics history")
            return Response(
                {'error': 'Failed to fetch history data'},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR
            )


# Note: the IPAddress model was refactored into HostPortMapping.
# IPAddressViewSet was removed and needs to be reimplemented against the new architecture.


class SubdomainViewSet(viewsets.ModelViewSet):
    """Subdomain management ViewSet

    Reachable two ways:
    1. Nested route:     GET /api/targets/{target_pk}/subdomains/
    2. Standalone route: GET /api/subdomains/ (global query)
    """

    serializer_class = SubdomainListSerializer
    pagination_class = BasePagination
    filter_backends = [filters.SearchFilter, filters.OrderingFilter]
    search_fields = ['name']
    ordering = ['-discovered_at']

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = SubdomainService()

    def get_queryset(self):
        """Scope the query depending on whether target_pk is present"""
        target_pk = self.kwargs.get('target_pk')
        if target_pk:
            return self.service.get_subdomains_by_target(target_pk)
        return self.service.get_all()

    @action(detail=False, methods=['get'], url_path='export')
    def export(self, request):
        """Export subdomains (plain text, one per line)"""
        target_pk = self.kwargs.get('target_pk')
        if not target_pk:
            raise DRFValidationError('Export must be requested under a target')

        def line_iterator():
            for name in self.service.iter_subdomain_names_by_target(target_pk):
                yield f"{name}\n"

        response = StreamingHttpResponse(
            line_iterator(),
            content_type='text/plain; charset=utf-8',
        )
        response['Content-Disposition'] = f'attachment; filename="target-{target_pk}-subdomains.txt"'
        return response


class WebSiteViewSet(viewsets.ModelViewSet):
    """Website management ViewSet

    Reachable two ways:
    1. Nested route:     GET /api/targets/{target_pk}/websites/
    2. Standalone route: GET /api/websites/ (global query)
    """

    serializer_class = WebSiteSerializer
    pagination_class = BasePagination
    filter_backends = [filters.SearchFilter, filters.OrderingFilter]
    search_fields = ['host']
    ordering = ['-discovered_at']

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = WebSiteService()

    def get_queryset(self):
        """Scope the query depending on whether target_pk is present"""
        target_pk = self.kwargs.get('target_pk')
        if target_pk:
            return self.service.get_websites_by_target(target_pk)
        return self.service.get_all()

    @action(detail=False, methods=['get'], url_path='export')
    def export(self, request):
        """Export website URLs (plain text, one per line)"""
        target_pk = self.kwargs.get('target_pk')
        if not target_pk:
            raise DRFValidationError('Export must be requested under a target')

        def line_iterator():
            for url in self.service.iter_website_urls_by_target(target_pk):
                yield f"{url}\n"

        response = StreamingHttpResponse(
            line_iterator(),
            content_type='text/plain; charset=utf-8',
        )
        response['Content-Disposition'] = f'attachment; filename="target-{target_pk}-websites.txt"'
        return response


class DirectoryViewSet(viewsets.ModelViewSet):
    """Directory management ViewSet

    Reachable two ways:
    1. Nested route:     GET /api/targets/{target_pk}/directories/
    2. Standalone route: GET /api/directories/ (global query)
    """

    serializer_class = DirectorySerializer
    pagination_class = BasePagination
    filter_backends = [filters.SearchFilter, filters.OrderingFilter]
    search_fields = ['url']
    ordering = ['-discovered_at']

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = DirectoryService()

    def get_queryset(self):
        """Scope the query depending on whether target_pk is present"""
        target_pk = self.kwargs.get('target_pk')
        if target_pk:
            return self.service.get_directories_by_target(target_pk)
        return self.service.get_all()

    @action(detail=False, methods=['get'], url_path='export')
    def export(self, request):
        """Export directory URLs (plain text, one per line)"""
        target_pk = self.kwargs.get('target_pk')
        if not target_pk:
            raise DRFValidationError('Export must be requested under a target')

        def line_iterator():
            for url in self.service.iter_directory_urls_by_target(target_pk):
                yield f"{url}\n"

        response = StreamingHttpResponse(
            line_iterator(),
            content_type='text/plain; charset=utf-8',
        )
        response['Content-Disposition'] = f'attachment; filename="target-{target_pk}-directories.txt"'
        return response


class EndpointViewSet(viewsets.ModelViewSet):
    """Endpoint management ViewSet

    Reachable two ways:
    1. Nested route:     GET /api/targets/{target_pk}/endpoints/
    2. Standalone route: GET /api/endpoints/ (global query)
    """

    serializer_class = EndpointListSerializer
    pagination_class = BasePagination
    filter_backends = [filters.SearchFilter, filters.OrderingFilter]
    search_fields = ['host']
    ordering = ['-discovered_at']

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = EndpointService()

    def get_queryset(self):
        """Scope the query depending on whether target_pk is present"""
        target_pk = self.kwargs.get('target_pk')
        if target_pk:
            return self.service.get_queryset_by_target(target_pk)
        return self.service.get_all()

    @action(detail=False, methods=['get'], url_path='export')
    def export(self, request):
        """Export endpoint URLs (plain text, one per line)"""
        target_pk = self.kwargs.get('target_pk')
        if not target_pk:
            raise DRFValidationError('Export must be requested under a target')

        def line_iterator():
            for url in self.service.iter_endpoint_urls_by_target(target_pk):
                yield f"{url}\n"

        response = StreamingHttpResponse(
            line_iterator(),
            content_type='text/plain; charset=utf-8',
        )
        response['Content-Disposition'] = f'attachment; filename="target-{target_pk}-endpoints.txt"'
        return response


class HostPortMappingViewSet(viewsets.ModelViewSet):
    """Host-port mapping management ViewSet (IP-address aggregated view)

    Reachable two ways:
    1. Nested route:     GET /api/targets/{target_pk}/ip-addresses/
    2. Standalone route: GET /api/ip-addresses/ (global query)

    Returns data aggregated per IP; each IP lists all of its associated hosts and ports.

    Note: DRF's SearchFilter is unsupported because the result is aggregated data (a list of dicts).
    """

    serializer_class = IPAddressAggregatedSerializer
    pagination_class = BasePagination

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = HostPortMappingService()

    def get_queryset(self):
        """Scope the query depending on whether target_pk is present; returns per-IP aggregates"""
        target_pk = self.kwargs.get('target_pk')
        search = self.request.query_params.get('search', None)
        if target_pk:
            return self.service.get_ip_aggregation_by_target(target_pk, search=search)
        return self.service.get_all_ip_aggregation(search=search)

    @action(detail=False, methods=['get'], url_path='export')
    def export(self, request):
        """Export IP addresses (plain text, one per line)"""
        target_pk = self.kwargs.get('target_pk')
        if not target_pk:
            raise DRFValidationError('Export must be requested under a target')

        def line_iterator():
            for ip in self.service.iter_ips_by_target(target_pk):
                yield f"{ip}\n"

        response = StreamingHttpResponse(
            line_iterator(),
            content_type='text/plain; charset=utf-8',
        )
        response['Content-Disposition'] = f'attachment; filename="target-{target_pk}-ip-addresses.txt"'
        return response


class VulnerabilityViewSet(viewsets.ModelViewSet):
    """Vulnerability asset management ViewSet (read-only)

    Reachable two ways:
    1. Nested route:     GET /api/targets/{target_pk}/vulnerabilities/
    2. Standalone route: GET /api/vulnerabilities/ (global query)
    """

    serializer_class = VulnerabilitySerializer
    pagination_class = BasePagination
    filter_backends = [filters.SearchFilter, filters.OrderingFilter]
    search_fields = ['vuln_type']
    ordering = ['-discovered_at']

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = VulnerabilityService()

    def get_queryset(self):
        """Scope the query depending on whether target_pk is present"""
        target_pk = self.kwargs.get('target_pk')
        if target_pk:
            return self.service.get_queryset_by_target(target_pk)
        return self.service.get_all()


# ==================== Snapshot ViewSets (nested under Scan) ====================

class SubdomainSnapshotViewSet(viewsets.ModelViewSet):
    """Subdomain snapshot ViewSet - nested route: GET /api/scans/{scan_pk}/subdomains/"""

    serializer_class = SubdomainSnapshotSerializer
    pagination_class = BasePagination
    filter_backends = [filters.SearchFilter, filters.OrderingFilter]
    search_fields = ['name']
    ordering_fields = ['name', 'discovered_at']
    ordering = ['-discovered_at']

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = SubdomainSnapshotsService()

    def get_queryset(self):
        scan_pk = self.kwargs.get('scan_pk')
        if scan_pk:
            return self.service.get_by_scan(scan_pk)
        return self.service.get_all()

    @action(detail=False, methods=['get'], url_path='export')
    def export(self, request):
        scan_pk = self.kwargs.get('scan_pk')
        if not scan_pk:
            raise DRFValidationError('Export must be requested under a scan')

        def line_iterator():
            for name in self.service.iter_subdomain_names_by_scan(scan_pk):
                yield f"{name}\n"

        response = StreamingHttpResponse(line_iterator(), content_type='text/plain; charset=utf-8')
        response['Content-Disposition'] = f'attachment; filename="scan-{scan_pk}-subdomains.txt"'
        return response


class WebsiteSnapshotViewSet(viewsets.ModelViewSet):
    """Website snapshot ViewSet - nested route: GET /api/scans/{scan_pk}/websites/"""

    serializer_class = WebsiteSnapshotSerializer
    pagination_class = BasePagination
    filter_backends = [filters.SearchFilter, filters.OrderingFilter]
    search_fields = ['host']
    ordering = ['-discovered_at']

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = WebsiteSnapshotsService()

    def get_queryset(self):
        scan_pk = self.kwargs.get('scan_pk')
        if scan_pk:
            return self.service.get_by_scan(scan_pk)
        return self.service.get_all()

    @action(detail=False, methods=['get'], url_path='export')
    def export(self, request):
        scan_pk = self.kwargs.get('scan_pk')
        if not scan_pk:
            raise DRFValidationError('Export must be requested under a scan')

        def line_iterator():
            for url in self.service.iter_website_urls_by_scan(scan_pk):
                yield f"{url}\n"

        response = StreamingHttpResponse(line_iterator(), content_type='text/plain; charset=utf-8')
        response['Content-Disposition'] = f'attachment; filename="scan-{scan_pk}-websites.txt"'
        return response


class DirectorySnapshotViewSet(viewsets.ModelViewSet):
    """Directory snapshot ViewSet - nested route: GET /api/scans/{scan_pk}/directories/"""

    serializer_class = DirectorySnapshotSerializer
    pagination_class = BasePagination
    filter_backends = [filters.SearchFilter, filters.OrderingFilter]
    search_fields = ['url']
    ordering = ['-discovered_at']

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = DirectorySnapshotsService()

    def get_queryset(self):
        scan_pk = self.kwargs.get('scan_pk')
        if scan_pk:
            return self.service.get_by_scan(scan_pk)
        return self.service.get_all()

    @action(detail=False, methods=['get'], url_path='export')
    def export(self, request):
        scan_pk = self.kwargs.get('scan_pk')
        if not scan_pk:
            raise DRFValidationError('Export must be requested under a scan')

        def line_iterator():
            for url in self.service.iter_directory_urls_by_scan(scan_pk):
                yield f"{url}\n"

        response = StreamingHttpResponse(line_iterator(), content_type='text/plain; charset=utf-8')
        response['Content-Disposition'] = f'attachment; filename="scan-{scan_pk}-directories.txt"'
        return response


class EndpointSnapshotViewSet(viewsets.ModelViewSet):
    """Endpoint snapshot ViewSet - nested route: GET /api/scans/{scan_pk}/endpoints/"""

    serializer_class = EndpointSnapshotSerializer
    pagination_class = BasePagination
    filter_backends = [filters.SearchFilter, filters.OrderingFilter]
    search_fields = ['host']
    ordering = ['-discovered_at']

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = EndpointSnapshotsService()

    def get_queryset(self):
        scan_pk = self.kwargs.get('scan_pk')
        if scan_pk:
            return self.service.get_by_scan(scan_pk)
        return self.service.get_all()

    @action(detail=False, methods=['get'], url_path='export')
    def export(self, request):
        scan_pk = self.kwargs.get('scan_pk')
        if not scan_pk:
            raise DRFValidationError('Export must be requested under a scan')

        def line_iterator():
            for url in self.service.iter_endpoint_urls_by_scan(scan_pk):
                yield f"{url}\n"

        response = StreamingHttpResponse(line_iterator(), content_type='text/plain; charset=utf-8')
        response['Content-Disposition'] = f'attachment; filename="scan-{scan_pk}-endpoints.txt"'
        return response


class HostPortMappingSnapshotViewSet(viewsets.ModelViewSet):
    """Host-port mapping snapshot ViewSet - nested route: GET /api/scans/{scan_pk}/ip-addresses/

    Note: DRF's SearchFilter is unsupported because the result is aggregated data (a list of dicts).
    """

    serializer_class = IPAddressAggregatedSerializer
    pagination_class = BasePagination

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = HostPortMappingSnapshotsService()

    def get_queryset(self):
        scan_pk = self.kwargs.get('scan_pk')
        search = self.request.query_params.get('search', None)
        if scan_pk:
            return self.service.get_ip_aggregation_by_scan(scan_pk, search=search)
        return self.service.get_all_ip_aggregation(search=search)

    @action(detail=False, methods=['get'], url_path='export')
    def export(self, request):
        scan_pk = self.kwargs.get('scan_pk')
        if not scan_pk:
            raise DRFValidationError('Export must be requested under a scan')

        def line_iterator():
            for ip in self.service.iter_ips_by_scan(scan_pk):
                yield f"{ip}\n"

        response = StreamingHttpResponse(line_iterator(), content_type='text/plain; charset=utf-8')
        response['Content-Disposition'] = f'attachment; filename="scan-{scan_pk}-ip-addresses.txt"'
        return response


class VulnerabilitySnapshotViewSet(viewsets.ModelViewSet):
    """Vulnerability snapshot ViewSet - nested route: GET /api/scans/{scan_pk}/vulnerabilities/"""

    serializer_class = VulnerabilitySnapshotSerializer
    pagination_class = BasePagination
    filter_backends = [filters.SearchFilter, filters.OrderingFilter]
    search_fields = ['vuln_type']
    ordering = ['-discovered_at']

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.service = VulnerabilitySnapshotsService()

    def get_queryset(self):
        scan_pk = self.kwargs.get('scan_pk')
        if scan_pk:
            return self.service.get_by_scan(scan_pk)
        return self.service.get_all()
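A hedged sketch of exercising one export action end-to-end with Django's test client (the nested /api/targets/{pk}/... registration lives elsewhere in the project, authentication may apply, and the target id is illustrative):

    from django.test import Client

    client = Client()
    resp = client.get('/api/targets/1/subdomains/export/')
    print(resp['Content-Disposition'])             # attachment; filename="target-1-subdomains.txt"
    print(b''.join(resp.streaming_content)[:200])  # first bytes of the newline-separated export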
23
backend/apps/common/__init__.py
Normal file
@@ -0,0 +1,23 @@
"""
Shared utilities module

Provides the shared helper classes and functions
"""

from .normalizer import normalize_domain, normalize_ip, normalize_cidr, normalize_target
from .validators import validate_domain, validate_ip, validate_cidr, detect_target_type

__all__ = [
    # Normalization helpers
    'normalize_domain',
    'normalize_ip',
    'normalize_cidr',
    'normalize_target',

    # Validators
    'validate_domain',
    'validate_ip',
    'validate_cidr',
    'detect_target_type',
]
10
backend/apps/common/apps.py
Normal file
@@ -0,0 +1,10 @@
from django.apps import AppConfig


class CommonConfig(AppConfig):
    default_auto_field = 'django.db.models.BigAutoField'
    name = 'apps.common'  # lives under the apps/ directory

    def ready(self):
        """Called once the app registry is fully populated"""
        pass
13
backend/apps/common/authentication.py
Normal file
@@ -0,0 +1,13 @@
from rest_framework.authentication import SessionAuthentication


class CsrfExemptSessionAuthentication(SessionAuthentication):
    """
    Session authentication for a decoupled frontend/backend project.

    Disables the CSRF check: CSRF mainly protects same-origin form submissions,
    while a decoupled project controls cross-origin access via CORS instead.
    """

    def enforce_csrf(self, request):
        # Skip the CSRF check
        return
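For this class to take effect it must be installed as a DRF authentication class. A typical wiring in settings.py would look like the following; whether this project configures it exactly this way is not shown in this diff:

    REST_FRAMEWORK = {
        'DEFAULT_AUTHENTICATION_CLASSES': [
            'apps.common.authentication.CsrfExemptSessionAuthentication',
        ],
    }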
66
backend/apps/common/container_bootstrap.py
Normal file
@@ -0,0 +1,66 @@
"""
Container startup bootstrap module

Provides the shared initialization for dynamically launched task containers:
- fetch configuration from the Server's config endpoint
- set environment variables
- initialize the Django environment

Usage:
    from apps.common.container_bootstrap import fetch_config_and_setup_django
    fetch_config_and_setup_django()  # must be called before any Django import
"""
import os
import sys
import requests
import logging

logger = logging.getLogger(__name__)


def fetch_config_and_setup_django():
    """
    Fetch configuration from the config endpoint and initialize Django

    Note:
        Must be called before any Django import
    """
    server_url = os.environ.get("SERVER_URL")
    if not server_url:
        print("[ERROR] Missing SERVER_URL environment variable", file=sys.stderr)
        sys.exit(1)

    config_url = f"{server_url}/api/workers/config/"
    try:
        resp = requests.get(config_url, timeout=10)
        resp.raise_for_status()
        config = resp.json()

        # Database configuration (required)
        os.environ.setdefault("DB_HOST", config['db']['host'])
        os.environ.setdefault("DB_PORT", config['db']['port'])
        os.environ.setdefault("DB_NAME", config['db']['name'])
        os.environ.setdefault("DB_USER", config['db']['user'])
        os.environ.setdefault("DB_PASSWORD", config['db']['password'])

        # Redis configuration
        os.environ.setdefault("REDIS_URL", config['redisUrl'])

        # Logging configuration
        os.environ.setdefault("LOG_DIR", config['paths']['logs'])
        os.environ.setdefault("LOG_LEVEL", config['logging']['level'])
        os.environ.setdefault("ENABLE_COMMAND_LOGGING", str(config['logging']['enableCommandLogging']).lower())
        os.environ.setdefault("DEBUG", str(config['debug']))

        print(f"[CONFIG] Fetched configuration from config endpoint: {config_url}")

    except Exception as e:
        print(f"[ERROR] Failed to fetch configuration: {config_url} - {e}", file=sys.stderr)
        sys.exit(1)

    # Initialize Django
    os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')
    import django
    django.setup()

    return config
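The payload shape the code above expects from GET {SERVER_URL}/api/workers/config/ can be read off the key accesses; the values below are illustrative only. Note that config['db']['port'] has to arrive as a string, since os.environ values must be str:

    EXAMPLE_CONFIG = {
        "db": {"host": "db", "port": "5432", "name": "xingrin", "user": "postgres", "password": "..."},
        "redisUrl": "redis://redis:6379/0",
        "paths": {"logs": "/app/logs"},
        "logging": {"level": "INFO", "enableCommandLogging": True},
        "debug": False,
    }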
19
backend/apps/common/decorators/__init__.py
Normal file
@@ -0,0 +1,19 @@
"""
Shared decorators module

Decorators reusable across the whole project
"""

from .db_connection import (
    ensure_db_connection,
    auto_ensure_db_connection,
    async_check_and_reconnect,
    ensure_db_connection_async,
)

__all__ = [
    'ensure_db_connection',
    'auto_ensure_db_connection',
    'async_check_and_reconnect',
    'ensure_db_connection_async',
]
169
backend/apps/common/decorators/db_connection.py
Normal file
@@ -0,0 +1,169 @@
"""
Database connection decorators

Decorators that add automatic database connection health checks, so that the
connection does not go stale inside long-running tasks.

Main entry points:
- @auto_ensure_db_connection: class decorator; adds the check to every public method
- @ensure_db_connection: method decorator; adds the check to a single method

Typical call sites:
- repository-layer database operations
- service-layer operations that must hold a healthy connection
- any class or method that needs a connection health check
"""

import logging
import functools
import inspect
import time
from django.db import connection
from asgiref.sync import sync_to_async

logger = logging.getLogger(__name__)


def ensure_db_connection(method):
    """
    Method decorator: keep the database connection healthy.

    Checks the connection before the method runs and reconnects if it is stale.

    When to use:
    - when a single method needs the check
    - usually prefer the class decorator @auto_ensure_db_connection

    Example:
        @ensure_db_connection
        def my_method(self):
            # connection health is checked automatically
            ...
    """
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        _check_and_reconnect()
        return method(self, *args, **kwargs)
    return wrapper


def auto_ensure_db_connection(cls):
    """
    Class decorator: add the connection check to every public method.

    Wraps every public method (name not starting with _) with
    @ensure_db_connection.

    Behavior:
    - decorates all public methods automatically
    - skips private methods (leading _)
    - skips classmethods and staticmethods
    - skips methods that are already wrapped

    Usage:
        @auto_ensure_db_connection
        class MyRepository:
            def bulk_create(self, items):
                # connection check added automatically
                ...

            def query(self, filters):
                # connection check added automatically
                ...

            def _private_method(self):
                # left untouched
                ...

    Why:
    - no need to decorate each method by hand
    - less repetition
    - no method accidentally left unchecked
    """
    for attr_name in dir(cls):
        # Skip private and magic methods
        if attr_name.startswith('_'):
            continue

        # Inspect the raw class attribute: getattr() unwraps staticmethod/classmethod
        # descriptors, so the isinstance check must run against the raw attribute or
        # classmethods would be wrapped (and broken) by the plain-function wrapper.
        raw_attr = inspect.getattr_static(cls, attr_name)
        if isinstance(raw_attr, (staticmethod, classmethod)):
            continue

        attr = getattr(cls, attr_name)

        # Only wrap callable instance methods
        if callable(attr):
            # Skip methods that are already wrapped (avoid double decoration)
            if not hasattr(attr, '_db_connection_ensured'):
                wrapped = ensure_db_connection(attr)
                wrapped._db_connection_ensured = True
                setattr(cls, attr_name, wrapped)

    return cls


def _check_and_reconnect(max_retries=5):
    """
    Check database connection health and reconnect with exponential backoff.

    Strategy:
    1. Run a trivial query to probe the connection
    2. On failure, retry with exponential backoff (up to max_retries attempts)
    3. Wait 2**attempt seconds between attempts (1s, 2s, 4s, 8s)

    Error handling:
    - reconnects automatically when the connection is stale
    - logs warnings and retry progress
    - ignores errors raised while closing the dead connection
    - re-raises the last error once the retry budget is exhausted
    """
    last_error = None

    for attempt in range(max_retries):
        try:
            connection.ensure_connection()
            with connection.cursor() as cursor:
                cursor.execute("SELECT 1")
                cursor.fetchone()

            # Connected successfully
            if attempt > 0:
                logger.info(f"Database reconnected (attempt {attempt + 1}/{max_retries})")
            return

        except Exception as e:
            last_error = e
            logger.warning(
                f"Database connection check failed (attempt {attempt + 1}/{max_retries}): {e}"
            )

            # Close the stale connection
            try:
                connection.close()
            except Exception:
                pass  # ignore errors while closing

            # If retries remain, back off exponentially
            if attempt < max_retries - 1:
                delay = 2 ** attempt  # exponential backoff: 1, 2, 4, 8 seconds
                logger.info(f"Waiting {delay}s before retrying...")
                time.sleep(delay)
            else:
                # Final attempt failed; give up
                logger.error(
                    f"Database reconnect failed after {max_retries} attempts"
                )
                raise last_error


async def async_check_and_reconnect(max_retries=5):
    await sync_to_async(_check_and_reconnect, thread_sensitive=True)(max_retries=max_retries)


def ensure_db_connection_async(method):
    @functools.wraps(method)
    async def wrapper(*args, **kwargs):
        await async_check_and_reconnect()
        return await method(*args, **kwargs)
    return wrapper


__all__ = [
    'ensure_db_connection',
    'auto_ensure_db_connection',
    'async_check_and_reconnect',
    'ensure_db_connection_async',
]
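The async variants exported above carry no usage docstring. A hedged sketch of decorating an async method (the repository class is illustrative; the Scan model path is taken from db_health_check.py below):

    from asgiref.sync import sync_to_async
    from apps.common.decorators import ensure_db_connection_async

    class AsyncScanRepository:
        @ensure_db_connection_async
        async def count_scans(self) -> int:
            from apps.scan.models import Scan  # lazy import; Django must already be configured
            return await sync_to_async(Scan.objects.count)()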
20
backend/apps/common/definitions.py
Normal file
@@ -0,0 +1,20 @@
from django.db import models


class ScanStatus(models.TextChoices):
    """Scan task status enum"""
    CANCELLED = 'cancelled', 'Cancelled'
    COMPLETED = 'completed', 'Completed'
    FAILED = 'failed', 'Failed'
    INITIATED = 'initiated', 'Initiated'
    RUNNING = 'running', 'Running'


class VulnSeverity(models.TextChoices):
    """Vulnerability severity enum"""
    UNKNOWN = 'unknown', 'Unknown'
    INFO = 'info', 'Info'
    LOW = 'low', 'Low'
    MEDIUM = 'medium', 'Medium'
    HIGH = 'high', 'High'
    CRITICAL = 'critical', 'Critical'
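Django TextChoices members compare equal to their stored value and expose .label and .choices, which is what makes these enums usable both as model field choices and in plain string comparisons:

    assert ScanStatus.RUNNING == 'running'
    print(VulnSeverity.CRITICAL.label)  # "Critical"
    print(ScanStatus.choices)           # [('cancelled', 'Cancelled'), ..., ('running', 'Running')]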
101
backend/apps/common/hash_utils.py
Normal file
@@ -0,0 +1,101 @@
"""Shared file hash computation and verification helpers

SHA-256 hashing and verification, used for:
- computing a hash when a wordlist file is uploaded
- validating the worker-side local cache
"""

import hashlib
import logging
from pathlib import Path
from typing import Optional, BinaryIO

logger = logging.getLogger(__name__)

# Default chunk size: 64KB (balances memory use and throughput)
DEFAULT_CHUNK_SIZE = 65536


def calc_file_sha256(file_path: str, chunk_size: int = DEFAULT_CHUNK_SIZE) -> str:
    """Compute the SHA-256 hash of a file

    Args:
        file_path: absolute path to the file
        chunk_size: read chunk size in bytes, defaults to 64KB

    Returns:
        str: SHA-256 hex digest (64 characters)

    Raises:
        FileNotFoundError: the file does not exist
        OSError: the file could not be read
    """
    hasher = hashlib.sha256()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            hasher.update(chunk)
    return hasher.hexdigest()


def calc_stream_sha256(stream: BinaryIO, chunk_size: int = DEFAULT_CHUNK_SIZE) -> str:
    """Compute SHA-256 from a binary stream (for hashing while streaming)

    Args:
        stream: a readable binary stream (e.g. UploadedFile.chunks())
        chunk_size: chunk size

    Returns:
        str: SHA-256 hex digest
    """
    hasher = hashlib.sha256()
    for chunk in iter(lambda: stream.read(chunk_size), b""):
        hasher.update(chunk)
    return hasher.hexdigest()


def safe_calc_file_sha256(file_path: str) -> Optional[str]:
    """Compute a file's SHA-256, returning None on failure

    Args:
        file_path: absolute path to the file

    Returns:
        str | None: SHA-256 hex digest, or None (file missing / unreadable)
    """
    try:
        return calc_file_sha256(file_path)
    except FileNotFoundError:
        logger.warning("Hash computation failed: file not found - %s", file_path)
        return None
    except OSError as exc:
        logger.warning("Hash computation failed: read error - %s: %s", file_path, exc)
        return None


def is_file_hash_match(file_path: str, expected_hash: str) -> bool:
    """Check whether a file's hash matches the expected value

    Args:
        file_path: absolute path to the file
        expected_hash: expected SHA-256 hex digest

    Returns:
        bool: True if they match; False on mismatch or computation failure
    """
    if not expected_hash:
        # No expected value means "cannot verify"; return False and let the
        # caller decide whether to re-download
        return False

    actual_hash = safe_calc_file_sha256(file_path)
    if actual_hash is None:
        return False

    return actual_hash.lower() == expected_hash.lower()


__all__ = [
    "calc_file_sha256",
    "calc_stream_sha256",
    "safe_calc_file_sha256",
    "is_file_hash_match",
]
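A usage sketch for the worker-side cache check described in the module docstring (the path is illustrative; the digest shown is the well-known SHA-256 of empty input, used here only as a placeholder):

    path = "/tmp/wordlists/common.txt"
    expected = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
    if not is_file_hash_match(path, expected):
        print("hash mismatch or unreadable file; re-download:", path)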
429
backend/apps/common/management/commands/db_health_check.py
Normal file
@@ -0,0 +1,429 @@
"""
Database health check management command

Usage:
    python manage.py db_health_check                      # basic latency test (5 runs)
    python manage.py db_health_check --test-count=10      # custom run count
    python manage.py db_health_check --reconnect          # force a reconnect, then test
    python manage.py db_health_check --stats              # show connection statistics
    python manage.py db_health_check --api-test           # benchmark real API queries
    python manage.py db_health_check --monitor            # monitor database server metrics
    python manage.py db_health_check --db=other           # pick a database alias
    python manage.py db_health_check --api-test --test-count=10  # 10 API benchmark runs
    python manage.py db_health_check --reconnect --api-test      # reconnect, then API test

Examples:
    # Quick latency check
    python manage.py db_health_check --test-count=3

    # Full performance analysis
    python manage.py db_health_check --api-test --stats --test-count=5

    # Database server monitoring
    python manage.py db_health_check --monitor
"""

import time
import logging
from django.core.management.base import BaseCommand
from django.db import connection, connections
from django.conf import settings

logger = logging.getLogger(__name__)


class Command(BaseCommand):
    """Django management command: database health check"""

    help = 'Check database connection health'

    def add_arguments(self, parser):
        parser.add_argument(
            '--reconnect',
            action='store_true',
            help='Force a database reconnect',
        )
        parser.add_argument(
            '--stats',
            action='store_true',
            help='Show connection statistics',
        )
        parser.add_argument(
            '--db',
            type=str,
            default='default',
            help='Database alias (default: default)',
        )
        parser.add_argument(
            '--test-count',
            type=int,
            default=5,
            help='Number of latency test runs (default: 5)',
        )
        parser.add_argument(
            '--api-test',
            action='store_true',
            help='Benchmark real API queries',
        )
        parser.add_argument(
            '--monitor',
            action='store_true',
            help='Monitor database server performance metrics',
        )

    def handle(self, *args, **options):
        db_alias = options['db']
        test_count = options['test_count']

        self.stdout.write(f"Testing connection to database '{db_alias}'...")

        # Get the database connection
        db_connection = connections[db_alias]

        if options['reconnect']:
            self.stdout.write("Forcing a database reconnect...")
            try:
                db_connection.close()
                db_connection.ensure_connection()
                self.stdout.write(self.style.SUCCESS("✓ Reconnected"))
            except Exception as e:
                self.stdout.write(self.style.ERROR(f"✗ Reconnect failed: {e}"))
                return

        # Dispatch to the requested test mode
        if options['monitor']:
            self.monitor_database_performance(db_connection)
        elif options['api_test']:
            self.test_api_performance(test_count)
        else:
            self.test_database_latency(db_connection, test_count)

        if options['stats']:
            self.show_connection_stats(db_connection)

    def test_database_latency(self, db_connection, test_count):
        """Measure database latency"""
        self.stdout.write(f"\nStarting latency test ({test_count} runs)...")

        latencies = []
        successful_tests = 0
        connection_times = []
        query_times = []

        for i in range(test_count):
            try:
                # Time connection establishment
                conn_start = time.time()
                db_connection.ensure_connection()
                conn_end = time.time()
                conn_time = (conn_end - conn_start) * 1000
                connection_times.append(conn_time)

                # Time query execution
                query_start = time.time()
                with db_connection.cursor() as cursor:
                    cursor.execute("SELECT 1")
                    result = cursor.fetchone()
                query_end = time.time()
                query_time = (query_end - query_start) * 1000
                query_times.append(query_time)

                total_time = conn_time + query_time
                latencies.append(total_time)
                successful_tests += 1

                self.stdout.write(f"  Run {i+1}: total {total_time:.2f}ms (connect: {conn_time:.2f}ms + query: {query_time:.2f}ms) ✓")

            except Exception as e:
                self.stdout.write(f"  Run {i+1}: failed - {e}")

        # Aggregate the statistics
        if latencies:
            avg_latency = sum(latencies) / len(latencies)
            min_latency = min(latencies)
            max_latency = max(latencies)

            avg_conn_time = sum(connection_times) / len(connection_times)
            avg_query_time = sum(query_times) / len(query_times)

            self.stdout.write("\nLatency statistics:")
            self.stdout.write(f"  Successful runs: {successful_tests}/{test_count}")
            self.stdout.write(f"  Average total latency: {avg_latency:.2f}ms")
            self.stdout.write(f"  Average connect time: {avg_conn_time:.2f}ms")
            self.stdout.write(f"  Average query time: {avg_query_time:.2f}ms")
            self.stdout.write(f"  Min latency: {min_latency:.2f}ms")
            self.stdout.write(f"  Max latency: {max_latency:.2f}ms")

            # Identify the dominant latency source
            if avg_conn_time > avg_query_time * 2:
                self.stdout.write(self.style.WARNING("  Analysis: connection setup dominates latency"))
            elif avg_query_time > avg_conn_time * 2:
                self.stdout.write(self.style.WARNING("  Analysis: query execution dominates latency"))
            else:
                self.stdout.write("  Analysis: connection and query latency are comparable")

            # Grade the latency
            if avg_latency < 10:
                self.stdout.write(self.style.SUCCESS("  Verdict: very low latency, excellent connection"))
            elif avg_latency < 50:
                self.stdout.write(self.style.SUCCESS("  Verdict: low latency, good connection"))
            elif avg_latency < 200:
                self.stdout.write(self.style.WARNING("  Verdict: moderate latency, acceptable connection"))
            else:
                self.stdout.write(self.style.ERROR("  Verdict: high latency, may hurt performance"))
        else:
            self.stdout.write(self.style.ERROR("All runs failed"))

    def test_api_performance(self, test_count):
        """Benchmark real API queries"""
        self.stdout.write(f"\nStarting API performance test ({test_count} runs)...")

        # Import the models we need
        from apps.scan.models import Scan
        from apps.engine.models import ScanEngine
        from apps.targets.models import Target
        from django.db.models import Count

        api_latencies = []
        successful_tests = 0

        for i in range(test_count):
            try:
                start_time = time.time()

                # Exercise several query shapes

                # 1. Simple query - engine list
                engines = list(ScanEngine.objects.all()[:10])

                # 2. Complex query - scan list (runs the expensive JOINs even with no data)
                scan_queryset = Scan.objects.select_related(
                    'target', 'engine'
                ).annotate(
                    subdomains_count=Count('subdomains', distinct=True),
                    endpoints_count=Count('endpoints', distinct=True),
                    ips_count=Count('ip_addresses', distinct=True)
                ).order_by('-id')[:10]
                scan_list = list(scan_queryset)

                # 3. Target query
                targets = list(Target.objects.all()[:10])

                end_time = time.time()
                latency_ms = (end_time - start_time) * 1000
                api_latencies.append(latency_ms)
                successful_tests += 1

                self.stdout.write(f"  API run {i+1}: {latency_ms:.2f}ms ✓ (engines: {len(engines)}, scans: {len(scan_list)}, targets: {len(targets)})")

            except Exception as e:
                self.stdout.write(f"  API run {i+1}: failed - {e}")

        # Aggregate the API query statistics
        if api_latencies:
            avg_latency = sum(api_latencies) / len(api_latencies)
            min_latency = min(api_latencies)
            max_latency = max(api_latencies)

            self.stdout.write("\nAPI query statistics:")
            self.stdout.write(f"  Successful runs: {successful_tests}/{test_count}")
            self.stdout.write(f"  Average latency: {avg_latency:.2f}ms")
            self.stdout.write(f"  Min latency: {min_latency:.2f}ms")
            self.stdout.write(f"  Max latency: {max_latency:.2f}ms")

            # Compare against the simple-query baseline
            simple_query_avg = 150  # based on earlier test runs
            overhead = avg_latency - simple_query_avg
            self.stdout.write(f"  Business-logic overhead: {overhead:.2f}ms")

            # Grade the performance
            if avg_latency < 500:
                self.stdout.write(self.style.SUCCESS("  Verdict: good API response time"))
            elif avg_latency < 1000:
                self.stdout.write(self.style.WARNING("  Verdict: mediocre API response time"))
            else:
                self.stdout.write(self.style.ERROR("  Verdict: slow API response time, needs tuning"))
        else:
            self.stdout.write(self.style.ERROR("All API runs failed"))

    def monitor_database_performance(self, db_connection):
        """Monitor database server performance metrics"""
        self.stdout.write("\nCollecting database server metrics...")

        try:
            with db_connection.cursor() as cursor:
                # 1. Basic database info
                self.stdout.write("\n=== Database basics ===")
                cursor.execute("SELECT version();")
                version = cursor.fetchone()[0]
                self.stdout.write(f"PostgreSQL version: {version}")

                cursor.execute("SELECT current_database();")
                db_name = cursor.fetchone()[0]
                self.stdout.write(f"Current database: {db_name}")

                # 2. Connection info
                self.stdout.write("\n=== Connection state ===")
                cursor.execute("""
                    SELECT count(*) as total_connections,
                           count(*) FILTER (WHERE state = 'active') as active_connections,
                           count(*) FILTER (WHERE state = 'idle') as idle_connections
                    FROM pg_stat_activity;
                """)
                conn_stats = cursor.fetchone()
                self.stdout.write(f"Total connections: {conn_stats[0]}")
                self.stdout.write(f"Active connections: {conn_stats[1]}")
                self.stdout.write(f"Idle connections: {conn_stats[2]}")

                # 3. Database size
                self.stdout.write("\n=== Database size ===")
                cursor.execute("""
                    SELECT pg_size_pretty(pg_database_size(current_database())) as db_size;
                """)
                db_size = cursor.fetchone()[0]
                self.stdout.write(f"Database size: {db_size}")

                # 4. Table statistics
                self.stdout.write("\n=== Top tables ===")
                cursor.execute("""
                    SELECT schemaname, relname,
                           n_tup_ins as inserts,
                           n_tup_upd as updates,
                           n_tup_del as deletes,
                           n_live_tup as live_rows,
                           n_dead_tup as dead_rows
                    FROM pg_stat_user_tables
                    WHERE schemaname = 'public'
                    ORDER BY n_live_tup DESC
                    LIMIT 10;
                """)
                tables = cursor.fetchall()
                if tables:
                    for table in tables:
                        self.stdout.write(f"  {table[1]}: {table[5]} rows (inserts: {table[2]}, updates: {table[3]}, deletes: {table[4]}, dead: {table[6]})")
                else:
                    self.stdout.write("  No table statistics yet")

                # 5. Slow query statistics
                # The execute() itself raises when pg_stat_statements is missing,
                # so it must sit inside the try block to keep the rest of the report alive.
                self.stdout.write("\n=== Query performance ===")
                try:
                    cursor.execute("""
                        SELECT query,
                               calls,
                               total_time,
                               mean_time,
                               rows
                        FROM pg_stat_statements
                        WHERE query NOT LIKE '%pg_stat_statements%'
                        ORDER BY mean_time DESC
                        LIMIT 5;
                    """)
                    slow_queries = cursor.fetchall()
                    if slow_queries:
                        for i, query in enumerate(slow_queries, 1):
                            self.stdout.write(f"  {i}. mean: {query[3]:.2f}ms, calls: {query[1]}")
                            self.stdout.write(f"     query: {query[0][:100]}...")
                    else:
                        self.stdout.write("  No query statistics (pg_stat_statements may be disabled)")
                except Exception:
                    self.stdout.write("  Query statistics unavailable (requires the pg_stat_statements extension)")

                # 6. Lock info
                self.stdout.write("\n=== Locks ===")
                cursor.execute("""
                    SELECT mode, count(*)
                    FROM pg_locks
                    GROUP BY mode
                    ORDER BY count(*) DESC;
                """)
                locks = cursor.fetchall()
                total_locks = sum(lock[1] for lock in locks)
                self.stdout.write(f"Total locks: {total_locks}")
                for lock in locks:
                    self.stdout.write(f"  {lock[0]}: {lock[1]}")

                # 7. Cache hit ratio
                self.stdout.write("\n=== Cache performance ===")
                cursor.execute("""
                    SELECT
                        sum(heap_blks_read) as heap_read,
                        sum(heap_blks_hit) as heap_hit,
                        sum(heap_blks_hit) / (sum(heap_blks_hit) + sum(heap_blks_read)) * 100 as cache_hit_ratio
                    FROM pg_statio_user_tables;
                """)
                cache_stats = cursor.fetchone()
                if cache_stats[0] and cache_stats[1]:
                    self.stdout.write(f"Cache hit ratio: {cache_stats[2]:.2f}%")
                    self.stdout.write(f"Disk reads: {cache_stats[0]} blocks")
                    self.stdout.write(f"Cache hits: {cache_stats[1]} blocks")
                else:
                    self.stdout.write("Cache statistics: no data yet")

                # 8. Checkpoints and WAL
                self.stdout.write("\n=== WAL and checkpoints ===")
                cursor.execute("""
                    SELECT
                        checkpoints_timed,
                        checkpoints_req,
                        checkpoint_write_time,
                        checkpoint_sync_time
                    FROM pg_stat_bgwriter;
                """)
                bgwriter = cursor.fetchone()
                self.stdout.write(f"Timed checkpoints: {bgwriter[0]}")
                self.stdout.write(f"Requested checkpoints: {bgwriter[1]}")
                self.stdout.write(f"Checkpoint write time: {bgwriter[2]}ms")
                self.stdout.write(f"Checkpoint sync time: {bgwriter[3]}ms")

                # 9. Currently active queries
                self.stdout.write("\n=== Active queries ===")
                cursor.execute("""
                    SELECT pid,
                           application_name,
                           state,
                           query_start,
                           now() - query_start as duration,
                           left(query, 100) as query_preview
                    FROM pg_stat_activity
                    WHERE state = 'active'
                      AND query NOT LIKE '%pg_stat_activity%'
                    ORDER BY query_start;
                """)
                active_queries = cursor.fetchall()
                if active_queries:
                    for query in active_queries:
                        self.stdout.write(f"  PID {query[0]} ({query[1]}): running for {query[4]}")
                        self.stdout.write(f"    query: {query[5]}...")
                else:
                    self.stdout.write("  No active queries")

        except Exception as e:
            self.stdout.write(self.style.ERROR(f"Monitoring failed: {e}"))

    def show_connection_stats(self, db_connection):
        """Show connection statistics"""
        self.stdout.write("\nConnection info:")

        # Basic connection info
        settings_dict = db_connection.settings_dict
        self.stdout.write(f"  Database vendor: {db_connection.vendor}")
        self.stdout.write(f"  Host: {settings_dict.get('HOST', 'localhost')}")
        self.stdout.write(f"  Port: {settings_dict.get('PORT', '5432')}")
        self.stdout.write(f"  Database name: {settings_dict.get('NAME', '')}")
        self.stdout.write(f"  User: {settings_dict.get('USER', '')}")

        # Connection configuration
        conn_max_age = settings_dict.get('CONN_MAX_AGE', 0)
        self.stdout.write(f"  Max connection age: {conn_max_age}s")

        # Query statistics
        if hasattr(db_connection, 'queries'):
            query_count = len(db_connection.queries)
            if query_count > 0:
                total_time = sum(float(q['time']) for q in db_connection.queries)
                self.stdout.write(f"  Query count: {query_count}")
                self.stdout.write(f"  Total query time: {total_time:.4f}s")

        # Connection state
        is_connected = hasattr(db_connection, 'connection') and db_connection.connection is not None
        self.stdout.write(f"  Connection state: {'connected' if is_connected else 'not connected'}")
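Both this command and db_monitor below can also be driven programmatically, which is handy in tests or scheduled jobs; keyword arguments map to the argparse dests:

    from django.core.management import call_command

    call_command('db_health_check', test_count=3, stats=True)
    call_command('db_health_check', api_test=True)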
164
backend/apps/common/management/commands/db_monitor.py
Normal file
@@ -0,0 +1,164 @@
"""
Lightweight database performance monitoring command

Focuses on the key metrics that can cause query latency
"""

import time
from django.core.management.base import BaseCommand
from django.db import connections


class Command(BaseCommand):
    """Lightweight database performance monitor"""

    help = 'Monitor key database performance metrics'

    def add_arguments(self, parser):
        parser.add_argument(
            '--interval',
            type=int,
            default=5,
            help='Sampling interval in seconds (default: 5)',
        )
        parser.add_argument(
            '--count',
            type=int,
            default=3,
            help='Number of samples (default: 3)',
        )

    def handle(self, *args, **options):
        interval = options['interval']
        count = options['count']

        self.stdout.write("🔍 Database performance monitoring started...")

        for i in range(count):
            if i > 0:
                time.sleep(interval)

            self.stdout.write(f"\n=== Sample {i+1} ===")
            self.monitor_key_metrics()

    def monitor_key_metrics(self):
        """Sample the key performance metrics"""
        db_connection = connections['default']

        try:
            with db_connection.cursor() as cursor:
                # 1. Connection and activity state
                cursor.execute("""
                    SELECT
                        count(*) as total_connections,
                        count(*) FILTER (WHERE state = 'active') as active,
                        count(*) FILTER (WHERE state = 'idle') as idle,
                        count(*) FILTER (WHERE state = 'idle in transaction') as idle_in_trans,
                        count(*) FILTER (WHERE wait_event_type IS NOT NULL) as waiting
                    FROM pg_stat_activity;
                """)
                conn_stats = cursor.fetchone()
                self.stdout.write(f"Connections: total {conn_stats[0]} | active {conn_stats[1]} | idle {conn_stats[2]} | in transaction {conn_stats[3]} | waiting {conn_stats[4]}")

                # 2. Lock waits
                cursor.execute("""
                    SELECT
                        count(*) as total_locks,
                        count(*) FILTER (WHERE NOT granted) as waiting_locks
                    FROM pg_locks;
                """)
                lock_stats = cursor.fetchone()
                if lock_stats[1] > 0:
                    self.stdout.write(self.style.WARNING(f"🔒 Locks: total {lock_stats[0]} | waiting {lock_stats[1]}"))
                else:
                    self.stdout.write(f"🔒 Locks: total {lock_stats[0]} | waiting {lock_stats[1]}")

                # 3. Long-running queries
                cursor.execute("""
                    SELECT
                        pid,
                        application_name,
                        now() - query_start as duration,
                        state,
                        left(query, 60) as query_preview
                    FROM pg_stat_activity
                    WHERE state = 'active'
                      AND query_start < now() - interval '1 second'
                      AND query NOT LIKE '%pg_stat_activity%'
                    ORDER BY query_start;
                """)
                long_queries = cursor.fetchall()
                if long_queries:
                    self.stdout.write(self.style.WARNING(f"⏱️ Long queries ({len(long_queries)}):"))
                    for query in long_queries:
                        self.stdout.write(f"  PID {query[0]} ({query[1]}): {query[2]} - {query[4]}...")
                else:
                    self.stdout.write("⏱️ Long queries: none")

                # 4. Cache hit ratio
                cursor.execute("""
                    SELECT
                        sum(heap_blks_hit) as cache_hits,
                        sum(heap_blks_read) as disk_reads,
                        CASE
                            WHEN sum(heap_blks_hit) + sum(heap_blks_read) = 0 THEN 0
                            ELSE round(sum(heap_blks_hit) * 100.0 / (sum(heap_blks_hit) + sum(heap_blks_read)), 2)
                        END as hit_ratio
                    FROM pg_statio_user_tables;
                """)
                cache_stats = cursor.fetchone()
                if cache_stats[0] or cache_stats[1]:
                    hit_ratio = cache_stats[2] or 0
                    if hit_ratio < 95:
                        self.stdout.write(self.style.WARNING(f"💾 Cache hit ratio: {hit_ratio}% (cache: {cache_stats[0]}, disk: {cache_stats[1]})"))
                    else:
                        self.stdout.write(f"💾 Cache hit ratio: {hit_ratio}% (cache: {cache_stats[0]}, disk: {cache_stats[1]})")
                else:
                    self.stdout.write("💾 Cache: no statistics yet")

                # 5. Checkpoint activity (best effort; skipped if unavailable)
                try:
                    cursor.execute("SELECT * FROM pg_stat_bgwriter LIMIT 1;")
                    bgwriter_cols = [desc[0] for desc in cursor.description]

                    if 'checkpoints_timed' in bgwriter_cols:
                        cursor.execute("""
                            SELECT
                                checkpoints_timed,
                                checkpoints_req,
                                checkpoint_write_time,
                                checkpoint_sync_time
                            FROM pg_stat_bgwriter;
                        """)
                        bgwriter = cursor.fetchone()
                        total_checkpoints = bgwriter[0] + bgwriter[1]
                        if bgwriter[2] > 10000 or bgwriter[3] > 5000:
                            self.stdout.write(self.style.WARNING(f"📝 Checkpoints: total {total_checkpoints} | write {bgwriter[2]}ms | sync {bgwriter[3]}ms"))
                        else:
                            self.stdout.write(f"📝 Checkpoints: total {total_checkpoints} | write {bgwriter[2]}ms | sync {bgwriter[3]}ms")
                    else:
                        self.stdout.write("📝 Checkpoints: statistics unavailable")
                except Exception:
                    self.stdout.write("📝 Checkpoints: statistics unavailable")

                # 6. Database size
                cursor.execute("SELECT pg_database_size(current_database());")
                db_size = cursor.fetchone()[0]
                db_size_mb = round(db_size / 1024 / 1024, 2)
                self.stdout.write(f"💿 Database size: {db_size_mb} MB")

                # 7. Probe query latency
                start_time = time.time()
                cursor.execute("SELECT 1")
                cursor.fetchone()
                query_latency = (time.time() - start_time) * 1000

                if query_latency > 500:
                    self.stdout.write(self.style.ERROR(f"⚡ Query latency: {query_latency:.2f}ms (high)"))
                elif query_latency > 200:
                    self.stdout.write(self.style.WARNING(f"⚡ Query latency: {query_latency:.2f}ms (moderate)"))
                else:
                    self.stdout.write(f"⚡ Query latency: {query_latency:.2f}ms (normal)")

        except Exception as e:
            self.stdout.write(self.style.ERROR(f"Monitoring failed: {e}"))
64
backend/apps/common/management/commands/init_admin.py
Normal file
@@ -0,0 +1,64 @@
"""
Initialize the admin user
Usage: python manage.py init_admin [--password <password>]
"""
import os
from django.core.management.base import BaseCommand
from django.contrib.auth import get_user_model

User = get_user_model()

DEFAULT_USERNAME = 'admin'
DEFAULT_PASSWORD = 'admin'  # default password; change it after first login


class Command(BaseCommand):
    help = 'Initialize the admin user'

    def add_arguments(self, parser):
        parser.add_argument(
            '--password',
            type=str,
            default=os.getenv('ADMIN_PASSWORD', DEFAULT_PASSWORD),
            help='admin user password (default: admin, or the ADMIN_PASSWORD environment variable)'
        )
        parser.add_argument(
            '--force',
            action='store_true',
            help='Force a password reset if the user already exists'
        )

    def handle(self, *args, **options):
        password = options['password']
        force = options['force']

        try:
            user = User.objects.get(username=DEFAULT_USERNAME)
            if force:
                user.set_password(password)
                user.save()
                self.stdout.write(
                    self.style.SUCCESS('✓ admin password has been reset')
                )
            else:
                self.stdout.write(
                    self.style.WARNING('⚠ admin user already exists, skipping creation (use --force to reset the password)')
                )
        except User.DoesNotExist:
            User.objects.create_superuser(
                username=DEFAULT_USERNAME,
                email='admin@localhost',
                password=password
            )
            self.stdout.write(
                self.style.SUCCESS('✓ admin user created')
            )
            self.stdout.write(
                self.style.WARNING(f'  Username: {DEFAULT_USERNAME}')
            )
            self.stdout.write(
                self.style.WARNING(f'  Password: {password}')
            )
            self.stdout.write(
                self.style.WARNING('  ⚠ Change the password after first login!')
            )
106
backend/apps/common/normalizer.py
Normal file
@@ -0,0 +1,106 @@
import re

# Pre-compiled regex, so each call avoids recompilation
IP_PATTERN = re.compile(r'^[\d.:]+$')


def normalize_domain(domain: str) -> str:
    """
    Normalize a domain name
    - strip surrounding whitespace
    - lowercase
    - drop the trailing dot

    Args:
        domain: the raw domain

    Returns:
        The normalized domain

    Raises:
        ValueError: the domain is empty or whitespace-only
    """
    if not domain or not domain.strip():
        raise ValueError("Domain must not be empty")

    normalized = domain.strip().lower()

    # Drop the trailing dot
    if normalized.endswith('.'):
        normalized = normalized.rstrip('.')

    return normalized


def normalize_ip(ip: str) -> str:
    """
    Normalize an IP address
    - strip surrounding whitespace
    - do not lowercase (IPs are kept as-is)

    Args:
        ip: the raw IP address

    Returns:
        The normalized IP address

    Raises:
        ValueError: the IP address is empty or whitespace-only
    """
    if not ip or not ip.strip():
        raise ValueError("IP address must not be empty")

    return ip.strip()


def normalize_cidr(cidr: str) -> str:
    """
    Normalize a CIDR
    - strip surrounding whitespace
    - do not lowercase (CIDRs are kept as-is)

    Args:
        cidr: the raw CIDR

    Returns:
        The normalized CIDR

    Raises:
        ValueError: the CIDR is empty or whitespace-only
    """
    if not cidr or not cidr.strip():
        raise ValueError("CIDR must not be empty")

    return cidr.strip()


def normalize_target(target: str) -> str:
    """
    Normalize a target name (single entry point)
    Dispatches to the right normalizer based on the target's shape

    Args:
        target: the raw target name

    Returns:
        The normalized target name

    Raises:
        ValueError: the target is empty or whitespace-only
    """
    if not target or not target.strip():
        raise ValueError("Target name must not be empty")

    # Strip surrounding whitespace first
    trimmed = target.strip()

    # Anything containing / is treated as a CIDR
    if '/' in trimmed:
        return normalize_cidr(trimmed)

    # Anything made of digits, dots, and colons only is treated as an IP
    if IP_PATTERN.match(trimmed):
        return normalize_ip(trimmed)

    # Everything else is treated as a domain
    return normalize_domain(trimmed)
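The dispatch order above means the '/' check wins over the IP pattern, and the IP pattern wins over domain handling:

    print(normalize_target("  Example.COM. "))  # -> "example.com"   (domain branch)
    print(normalize_target("10.0.0.1 "))        # -> "10.0.0.1"      (IP branch)
    print(normalize_target("10.0.0.0/24"))      # -> "10.0.0.0/24"   (CIDR branch)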
34
backend/apps/common/pagination.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""
|
||||
自定义分页器,匹配前端响应格式
|
||||
"""
|
||||
from rest_framework.pagination import PageNumberPagination
|
||||
from rest_framework.response import Response
|
||||
|
||||
|
||||
class BasePagination(PageNumberPagination):
|
||||
"""
|
||||
基础分页器,统一返回格式
|
||||
|
||||
响应格式:
|
||||
{
|
||||
"results": [...],
|
||||
"total": 100,
|
||||
"page": 1,
|
||||
"pageSize": 10,
|
||||
"totalPages": 10
|
||||
}
|
||||
"""
|
||||
page_size = 10 # 默认每页 10 条
|
||||
page_size_query_param = 'pageSize' # 允许客户端自定义每页数量
|
||||
max_page_size = 1000 # 最大每页数量限制
|
||||
|
||||
def get_paginated_response(self, data):
|
||||
"""自定义响应格式"""
|
||||
return Response({
|
||||
'results': data, # 数据列表
|
||||
'total': self.page.paginator.count, # 总记录数
|
||||
'page': self.page.number, # 当前页码(从 1 开始)
|
||||
'page_size': self.page.paginator.per_page, # 实际使用的每页大小
|
||||
'total_pages': self.page.paginator.num_pages # 总页数
|
||||
})
|
||||
|
||||
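A minimal sketch of opting a DRF view into this paginator. TargetSerializer and its module path are hypothetical; the Target model itself appears later in this diff:

from rest_framework import viewsets

from apps.common.pagination import BasePagination
from apps.targets.models import Target
from apps.targets.serializers import TargetSerializer  # hypothetical serializer module


class TargetViewSet(viewsets.ModelViewSet):
    queryset = Target.objects.all()
    serializer_class = TargetSerializer
    pagination_class = BasePagination  # GET /api/targets/?page=2&pageSize=50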
42
backend/apps/common/prefect_django_setup.py
Normal file
@@ -0,0 +1,42 @@
"""
Django environment bootstrap for Prefect flows

Import this module at the top of every Prefect flow file to configure
the Django environment automatically.
"""

import os
import sys


def setup_django_for_prefect():
    """
    Configure the Django environment for a Prefect flow.

    This function will:
    1. Add the project root to the Python path
    2. Set the DJANGO_SETTINGS_MODULE environment variable
    3. Call django.setup() to initialize Django

    Usage:
        from apps.common.prefect_django_setup import setup_django_for_prefect
        setup_django_for_prefect()
    """
    # Resolve the project root (the backend directory)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    backend_dir = os.path.join(current_dir, '../..')
    backend_dir = os.path.abspath(backend_dir)

    # Add it to the Python path
    if backend_dir not in sys.path:
        sys.path.insert(0, backend_dir)

    # Configure Django
    os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')

    # Initialize Django
    import django
    django.setup()


# Run automatically on import (importing the module is enough)
setup_django_for_prefect()
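A sketch of a flow file relying on the side-effect import above; list_engines is a hypothetical flow, but the import order is the point: this module must come before any Django model import.

import apps.common.prefect_django_setup  # noqa: F401 - side-effect import, runs django.setup()

from prefect import flow

from apps.engine.models import ScanEngine  # safe: Django is initialized by now


@flow
def list_engines():
    # ORM access works inside the flow because the environment is configured
    return list(ScanEngine.objects.values_list('name', flat=True))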
29
backend/apps/common/signals.py
Normal file
@@ -0,0 +1,29 @@
"""Shared signal definitions

Custom signals used across the project to decouple communication
between modules.

Usage:
1. Send:    signal.send(sender=SomeClass, **kwargs)
2. Receive: @receiver(signal) def handler(sender, **kwargs): ...
"""

from django.dispatch import Signal


# ==================== Vulnerability signals ====================

# Sent after vulnerabilities have been saved
# Arguments:
# - items: List[VulnerabilitySnapshotDTO]  the saved vulnerabilities
# - scan_id: int   scan task ID
# - target_id: int target ID
vulnerabilities_saved = Signal()


# ==================== Worker signals ====================

# Sent when deleting a worker fails (only sent on failure)
# Arguments:
# - worker_name: str  worker name
# - message: str      failure reason
worker_delete_failed = Signal()
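A minimal receiver sketch (the handler name is hypothetical; the keyword arguments match the ones documented above):

from django.dispatch import receiver

from apps.common.signals import vulnerabilities_saved


@receiver(vulnerabilities_saved)
def log_saved_vulnerabilities(sender, items=None, scan_id=None, target_id=None, **kwargs):
    # React to freshly saved vulnerabilities without importing the scan module
    print(f"scan {scan_id}: {len(items or [])} vulnerabilities saved for target {target_id}")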
12
backend/apps/common/urls.py
Normal file
@@ -0,0 +1,12 @@
"""
URL configuration for the common module
"""
from django.urls import path
from .views import LoginView, LogoutView, MeView, ChangePasswordView

urlpatterns = [
    path('auth/login/', LoginView.as_view(), name='auth-login'),
    path('auth/logout/', LogoutView.as_view(), name='auth-logout'),
    path('auth/me/', MeView.as_view(), name='auth-me'),
    path('auth/change-password/', ChangePasswordView.as_view(), name='auth-change-password'),
]
142
backend/apps/common/validators.py
Normal file
@@ -0,0 +1,142 @@
"""Validation helpers for domains, IPs, ports and targets"""
import ipaddress
import logging
from typing import Any

import validators

logger = logging.getLogger(__name__)


def validate_domain(domain: str) -> None:
    """
    Validate a domain name (via the validators library).

    Args:
        domain: domain string (expected to be normalized already)

    Raises:
        ValueError: if the domain format is invalid
    """
    if not domain:
        raise ValueError("Domain cannot be empty")

    # Use the validators library for the format check.
    # It handles internationalized domains (IDN) and various edge cases.
    if not validators.domain(domain):
        raise ValueError(f"Invalid domain format: {domain}")


def validate_ip(ip: str) -> None:
    """
    Validate an IP address (IPv4 and IPv6).

    Args:
        ip: IP address string (expected to be normalized already)

    Raises:
        ValueError: if the IP address format is invalid
    """
    if not ip:
        raise ValueError("IP address cannot be empty")

    try:
        ipaddress.ip_address(ip)
    except ValueError:
        raise ValueError(f"Invalid IP address format: {ip}")


def validate_cidr(cidr: str) -> None:
    """
    Validate a CIDR range (IPv4 and IPv6).

    Args:
        cidr: CIDR string (expected to be normalized already)

    Raises:
        ValueError: if the CIDR format is invalid
    """
    if not cidr:
        raise ValueError("CIDR cannot be empty")

    try:
        ipaddress.ip_network(cidr, strict=False)
    except ValueError:
        raise ValueError(f"Invalid CIDR format: {cidr}")


def detect_target_type(name: str) -> str:
    """
    Detect the target type (validation only, no normalization).

    Args:
        name: target name (expected to be normalized already)

    Returns:
        str: target type ('domain', 'ip', 'cidr') - values of the Target.TargetType enum

    Raises:
        ValueError: if the target type cannot be recognized
    """
    # Import the model inside the function to avoid AppRegistryNotReady errors
    from apps.targets.models import Target

    if not name:
        raise ValueError("Target name cannot be empty")

    # CIDR format (contains a '/')
    if '/' in name:
        validate_cidr(name)
        return Target.TargetType.CIDR

    # An IP address?
    try:
        validate_ip(name)
        return Target.TargetType.IP
    except ValueError:
        pass

    # A valid domain?
    try:
        validate_domain(name)
        return Target.TargetType.DOMAIN
    except ValueError:
        pass

    # Unrecognized format
    raise ValueError(f"Unrecognized target format: {name}; must be a domain, IP address or CIDR range")


def validate_port(port: Any) -> tuple[bool, int | None]:
    """
    Validate and convert a port number.

    Args:
        port: the value to validate (may be a string, integer or other type)

    Returns:
        tuple: (is_valid, port_number)
        - is_valid: whether the port is valid
        - port_number: the integer port when valid, otherwise None

    Rules:
    1. Must be convertible to an integer
    2. Must fall within 1-65535

    Examples:
        >>> is_valid, port_num = validate_port(8080)
        >>> is_valid, port_num
        (True, 8080)

        >>> is_valid, port_num = validate_port("invalid")
        >>> is_valid, port_num
        (False, None)
    """
    try:
        port_num = int(port)
        if 1 <= port_num <= 65535:
            return True, port_num
        else:
            logger.warning("Port number outside the valid range (1-65535): %d", port_num)
            return False, None
    except (ValueError, TypeError):
        logger.warning("Malformed port number, cannot convert to int: %s", port)
        return False, None
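A sketch of the intended pipeline: normalize first, then detect and validate. It needs a configured Django context, since detect_target_type imports the Target model; inputs are illustrative:

from apps.common.normalizer import normalize_target
from apps.common.validators import detect_target_type, validate_port

name = normalize_target("  Example.COM.  ")
print(detect_target_type(name))             # -> 'domain' (a Target.TargetType value)
print(detect_target_type("10.0.0.0/8"))     # -> 'cidr'

is_valid, port_num = validate_port("8080")  # string input is converted
assert is_valid and port_num == 8080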
3
backend/apps/common/views/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .auth_views import LoginView, LogoutView, MeView, ChangePasswordView

__all__ = ['LoginView', 'LogoutView', 'MeView', 'ChangePasswordView']
173
backend/apps/common/views/auth_views.py
Normal file
@@ -0,0 +1,173 @@
"""
Authentication views

Built on Django's built-in auth system with session authentication
"""
import logging
from django.contrib.auth import authenticate, login, logout, update_session_auth_hash
from django.views.decorators.csrf import csrf_exempt
from django.utils.decorators import method_decorator
from rest_framework import status
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework.permissions import AllowAny, IsAuthenticated

logger = logging.getLogger(__name__)


@method_decorator(csrf_exempt, name='dispatch')
class LoginView(APIView):
    """
    User login
    POST /api/auth/login/
    """
    authentication_classes = []  # disable authentication (bypasses CSRF)
    permission_classes = [AllowAny]

    def post(self, request):
        username = request.data.get('username')
        password = request.data.get('password')

        if not username or not password:
            return Response(
                {'error': 'Please provide a username and password'},
                status=status.HTTP_400_BAD_REQUEST
            )

        user = authenticate(request, username=username, password=password)

        if user is not None:
            login(request, user)
            logger.info(f"User {username} logged in")
            return Response({
                'message': 'Login successful',
                'user': {
                    'id': user.id,
                    'username': user.username,
                    'isStaff': user.is_staff,
                    'isSuperuser': user.is_superuser,
                }
            })
        else:
            logger.warning(f"Login failed for user {username}: bad username or password")
            return Response(
                {'error': 'Invalid username or password'},
                status=status.HTTP_401_UNAUTHORIZED
            )


@method_decorator(csrf_exempt, name='dispatch')
class LogoutView(APIView):
    """
    User logout
    POST /api/auth/logout/
    """
    authentication_classes = []  # disable authentication (bypasses CSRF)
    permission_classes = [AllowAny]

    def post(self, request):
        # Pull the username from the session for logging
        user_id = request.session.get('_auth_user_id')
        if user_id:
            from django.contrib.auth import get_user_model
            User = get_user_model()
            try:
                user = User.objects.get(pk=user_id)
                username = user.username
                logout(request)
                logger.info(f"User {username} logged out")
            except User.DoesNotExist:
                logout(request)
        else:
            logout(request)
        return Response({'message': 'Logged out'})


@method_decorator(csrf_exempt, name='dispatch')
class MeView(APIView):
    """
    Current user info
    GET /api/auth/me/
    """
    authentication_classes = []  # disable authentication (bypasses CSRF)
    permission_classes = [AllowAny]

    def get(self, request):
        # Look the user up from the session
        from django.contrib.auth import get_user_model
        User = get_user_model()

        user_id = request.session.get('_auth_user_id')
        if user_id:
            try:
                user = User.objects.get(pk=user_id)
                return Response({
                    'authenticated': True,
                    'user': {
                        'id': user.id,
                        'username': user.username,
                        'isStaff': user.is_staff,
                        'isSuperuser': user.is_superuser,
                    }
                })
            except User.DoesNotExist:
                pass

        return Response({
            'authenticated': False,
            'user': None
        })


@method_decorator(csrf_exempt, name='dispatch')
class ChangePasswordView(APIView):
    """
    Change password
    POST /api/auth/change-password/
    """
    authentication_classes = []  # disable authentication (bypasses CSRF)
    permission_classes = [AllowAny]  # login state is checked manually

    def post(self, request):
        # Manually check the login state (user comes from the session)
        from django.contrib.auth import get_user_model
        User = get_user_model()

        user_id = request.session.get('_auth_user_id')
        if not user_id:
            return Response(
                {'error': 'Please log in first'},
                status=status.HTTP_401_UNAUTHORIZED
            )

        try:
            user = User.objects.get(pk=user_id)
        except User.DoesNotExist:
            return Response(
                {'error': 'User does not exist'},
                status=status.HTTP_401_UNAUTHORIZED
            )

        # CamelCaseParser maps oldPassword -> old_password
        old_password = request.data.get('old_password')
        new_password = request.data.get('new_password')

        if not old_password or not new_password:
            return Response(
                {'error': 'Please provide the old and new passwords'},
                status=status.HTTP_400_BAD_REQUEST
            )

        if not user.check_password(old_password):
            return Response(
                {'error': 'Old password is incorrect'},
                status=status.HTTP_400_BAD_REQUEST
            )

        user.set_password(new_password)
        user.save()

        # Refresh the session hash so the user is not logged out
        update_session_auth_hash(request, user)

        logger.info(f"User {user.username} changed their password")
        return Response({'message': 'Password changed successfully'})
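A sketch of exercising these endpoints with requests. Host, port and credentials are placeholders (8888 matches the SERVER_PORT default used elsewhere in this diff), and the trailing slashes match the routes above:

import requests

s = requests.Session()  # the session cookie carries the login across calls
r = s.post("http://localhost:8888/api/auth/login/",
           json={"username": "admin", "password": "change-me"})  # placeholder credentials
print(r.status_code, r.json())

print(s.get("http://localhost:8888/api/auth/me/").json())  # authenticated: True after login
s.post("http://localhost:8888/api/auth/logout/")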
0
backend/apps/engine/__init__.py
Normal file
32
backend/apps/engine/apps.py
Normal file
@@ -0,0 +1,32 @@
import os
from django.apps import AppConfig


class EngineConfig(AppConfig):
    default_auto_field = 'django.db.models.BigAutoField'
    name = 'apps.engine'
    verbose_name = 'Scan engine'

    def ready(self):
        """Start the periodic scheduler once the app is ready"""
        # Only start the scheduler in the main process (avoids double-starts
        # under runserver's autoreloader).
        # Check whether we are in runserver's autoreload child process.
        if os.environ.get('RUN_MAIN') == 'true' or not self._is_runserver():
            # Only the server container runs the scheduler (worker containers don't)
            if not os.environ.get('SERVER_URL'):  # worker containers define SERVER_URL
                self._start_scheduler()

    def _is_runserver(self):
        """Whether the process was started via runserver"""
        import sys
        return 'runserver' in sys.argv

    def _start_scheduler(self):
        """Start the scheduler"""
        try:
            from apps.engine.scheduler import start_scheduler
            start_scheduler()
        except Exception as e:
            import logging
            logger = logging.getLogger(__name__)
            logger.warning(f"Failed to start the scheduler: {e}")
6
backend/apps/engine/consumers/__init__.py
Normal file
@@ -0,0 +1,6 @@
"""
Engine WebSocket Consumers
"""
from .worker_deploy_consumer import WorkerDeployConsumer

__all__ = ['WorkerDeployConsumer']
454
backend/apps/engine/consumers/worker_deploy_consumer.py
Normal file
@@ -0,0 +1,454 @@
"""
WebSocket consumer - interactive worker terminal (PTY-based)
"""

import json
import logging
import asyncio
import os
from channels.generic.websocket import AsyncWebsocketConsumer
from asgiref.sync import sync_to_async

from django.conf import settings

from apps.engine.services import WorkerService

logger = logging.getLogger(__name__)


class WorkerDeployConsumer(AsyncWebsocketConsumer):
    """
    Interactive worker terminal over WebSocket.

    Uses paramiko's invoke_shell for a real interactive terminal.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ssh_client = None
        self.shell = None
        self.worker = None
        self.read_task = None
        self.worker_service = WorkerService()

    async def connect(self):
        """Join the worker's group on connect and open the SSH session automatically"""
        self.worker_id = self.scope['url_route']['kwargs']['worker_id']
        self.group_name = f'worker_deploy_{self.worker_id}'

        await self.channel_layer.group_add(self.group_name, self.channel_name)
        await self.accept()

        logger.info(f"Terminal connected - worker: {self.worker_id}")

        # Establish the SSH connection automatically
        await self._auto_ssh_connect()

    async def disconnect(self, close_code):
        """Clean up resources on disconnect"""
        if self.read_task:
            self.read_task.cancel()
        if self.shell:
            try:
                self.shell.close()
            except Exception:
                pass
        if self.ssh_client:
            try:
                self.ssh_client.close()
            except Exception:
                pass

        await self.channel_layer.group_discard(self.group_name, self.channel_name)
        logger.info(f"Terminal disconnected - worker: {self.worker_id}")

    async def receive(self, text_data=None, bytes_data=None):
        """Handle client messages"""
        if bytes_data:
            # Forward binary data straight to the shell
            if self.shell:
                await asyncio.to_thread(self.shell.send, bytes_data)
            return

        if not text_data:
            return

        try:
            data = json.loads(text_data)
            msg_type = data.get('type')

            if msg_type == 'resize':
                cols = data.get('cols', 80)
                rows = data.get('rows', 24)
                if self.shell:
                    await asyncio.to_thread(self.shell.resize_pty, cols, rows)

            elif msg_type == 'input':
                # Terminal input
                if self.shell:
                    text = data.get('data', '')
                    await asyncio.to_thread(self.shell.send, text)

            elif msg_type == 'deploy':
                # Run the deploy script (in the background)
                await self._run_deploy_script()

            elif msg_type == 'attach':
                # Watch deploy progress (attach to the tmux session)
                await self._attach_deploy_session()

            elif msg_type == 'uninstall':
                # Run the uninstall script (in the background)
                await self._run_uninstall_script()

        except json.JSONDecodeError:
            # Probably plain text input
            if self.shell and text_data:
                await asyncio.to_thread(self.shell.send, text_data)
        except Exception as e:
            logger.error(f"Error handling message: {e}")

    async def _auto_ssh_connect(self):
        """Read the password from the database and connect automatically"""
        logger.info(f"[SSH] Starting auto-connect - worker ID: {self.worker_id}")
        # Fetch the worker node through the service layer.
        # thread_sensitive=False runs it in a fresh thread, avoiding DB connection issues.
        self.worker = await sync_to_async(self.worker_service.get_worker, thread_sensitive=False)(self.worker_id)
        logger.info(f"[SSH] Worker lookup result: {self.worker}")

        if not self.worker:
            await self.send(text_data=json.dumps({
                'type': 'error',
                'message': 'Worker does not exist'
            }))
            return

        if not self.worker.password:
            await self.send(text_data=json.dumps({
                'type': 'error',
                'message': 'No SSH password configured; edit the node details first'
            }))
            return

        # Use the default terminal size
        await self._ssh_connect(self.worker.password, 80, 24)

    async def _ssh_connect(self, password: str, cols: int = 80, rows: int = 24):
        """Open the SSH connection and start an interactive shell (tmux keeps sessions persistent)"""
        try:
            import paramiko
        except ImportError:
            await self.send(text_data=json.dumps({
                'type': 'error',
                'message': 'The server is missing the paramiko library'
            }))
            return

        # self.worker was already fetched in _auto_ssh_connect
        try:
            ssh = paramiko.SSHClient()
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())

            await asyncio.to_thread(
                ssh.connect,
                self.worker.ip_address,
                port=self.worker.ssh_port,
                username=self.worker.username,
                password=password,
                timeout=30
            )

            self.ssh_client = ssh

            # Start an interactive shell (no tmux setup at connect time; just a plain shell)
            self.shell = await asyncio.to_thread(
                ssh.invoke_shell,
                term='xterm-256color',
                width=cols,
                height=rows
            )

            # Tell the client we are connected
            logger.info(f"[SSH] About to send 'connected' - worker: {self.worker_id}")
            await self.send(text_data=json.dumps({
                'type': 'connected'
            }))
            logger.info(f"[SSH] 'connected' sent - worker: {self.worker_id}")

            # Start the output reader task
            self.read_task = asyncio.create_task(self._read_shell_output())

            logger.info(f"[SSH] Shell connected, reader task started - worker: {self.worker_id}")

        except paramiko.AuthenticationException:
            logger.error(f"[SSH] Authentication failed - worker: {self.worker_id}")
            await self.send(text_data=json.dumps({
                'type': 'error',
                'message': 'Authentication failed: wrong password'
            }))
        except Exception as e:
            logger.error(f"[SSH] Connection failed - worker: {self.worker_id}, error: {e}")
            await self.send(text_data=json.dumps({
                'type': 'error',
                'message': f'Connection failed: {str(e)}'
            }))

    async def _read_shell_output(self):
        """Continuously read shell output and stream it to the client"""
        try:
            while self.shell and not self.shell.closed:
                if self.shell.recv_ready():
                    data = await asyncio.to_thread(self.shell.recv, 4096)
                    if data:
                        await self.send(bytes_data=data)
                else:
                    await asyncio.sleep(0.01)
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"Error reading shell output: {e}")

    async def _run_deploy_script(self):
        """Run the deploy script (inside a tmux session, so it survives disconnects)

        Flow:
        1. Upload the script to the remote host over SFTP
        2. Execute it silently via exec_command (no echo in the interactive terminal)
        3. Send the result to the frontend over the WebSocket
        """
        if not self.ssh_client:
            return

        from apps.engine.services.deploy_service import (
            get_bootstrap_script,
            get_deploy_script,
            get_start_agent_script
        )

        # Prefer the externally reachable host (PUBLIC_HOST) from settings to build the Django URL
        public_host = getattr(settings, 'PUBLIC_HOST', '').strip()
        server_port = getattr(settings, 'SERVER_PORT', '8888')

        if not public_host:
            error_msg = (
                "PUBLIC_HOST is not configured; set the externally reachable "
                "IP/domain (PUBLIC_HOST) in docker/.env and restart before "
                "running a remote deployment"
            )
            logger.error(error_msg)
            await self.send(text_data=json.dumps({
                'type': 'error',
                'message': error_msg,
            }))
            return

        django_host = f"{public_host}:{server_port}"  # used by Django / heartbeat reporting
        heartbeat_api_url = f"http://{django_host}"  # base URL; the agent appends /api/...

        session_name = f'xingrin_deploy_{self.worker_id}'
        remote_script_path = '/tmp/xingrin_deploy.sh'

        # Load the external script contents
        bootstrap_script = get_bootstrap_script()
        deploy_script = get_deploy_script()
        start_script = get_start_agent_script(
            heartbeat_api_url=heartbeat_api_url,
            worker_id=self.worker_id
        )

        # Stitch the scripts together
        combined_script = f"""#!/bin/bash
set -e

# ==================== Stage 1: environment bootstrap ====================
{bootstrap_script}

# ==================== Stage 2: install Docker ====================
{deploy_script}

# ==================== Stage 3: start the agent ====================
{start_script}

echo "SUCCESS"
"""

        # Mark the worker as deploying
        await sync_to_async(self.worker_service.update_status)(self.worker_id, 'deploying')

        # Announce the start
        start_msg = "\r\n\033[36m[XingRin] Preparing deployment...\033[0m\r\n"
        await self.send(bytes_data=start_msg.encode())

        try:
            # 1. Upload the script
            sftp = await asyncio.to_thread(self.ssh_client.open_sftp)
            with sftp.file(remote_script_path, 'w') as f:
                f.write(combined_script)
            sftp.chmod(remote_script_path, 0o755)
            await asyncio.to_thread(sftp.close)

            # 2. Kick off the deployment silently (exec_command, no terminal echo)
            deploy_cmd = f"""
# Make sure tmux is installed
if ! command -v tmux >/dev/null 2>&1; then
    if command -v apt-get >/dev/null 2>&1; then
        sudo apt-get update -qq && sudo apt-get install -y -qq tmux >/dev/null 2>&1
    fi
fi

# Check that the script exists
if [ ! -f "{remote_script_path}" ]; then
    echo "SCRIPT_NOT_FOUND"
    exit 1
fi

# Start the tmux session
if command -v tmux >/dev/null 2>&1; then
    tmux kill-session -t {session_name} 2>/dev/null || true
    # Run the script under bash so the environment is sane
    tmux new-session -d -s {session_name} "bash {remote_script_path}; echo 'Deployment finished, press Enter to exit'; read"
    # Verify the session actually started
    sleep 0.5
    if tmux has-session -t {session_name} 2>/dev/null; then
        echo "SUCCESS"
    else
        echo "SESSION_CREATE_FAILED"
    fi
else
    echo "TMUX_NOT_FOUND"
fi
"""
            stdin, stdout, stderr = await asyncio.to_thread(
                self.ssh_client.exec_command, deploy_cmd
            )
            result = await asyncio.to_thread(stdout.read)
            result = result.decode().strip()

            # 3. Report the result in the frontend terminal
            if "SUCCESS" in result:
                # The deploy task is now running in the background; keep the
                # 'deploying' status. It only becomes deployed once the
                # heartbeat reports in (updated automatically via the heartbeat API).
                success_msg = (
                    "\r\n\033[32m✓ Deployment started in the background\033[0m\r\n"
                    f"\033[90m  Session: {session_name}\033[0m\r\n"
                    "\r\n"
                    "\033[36mClick [View progress] to follow the deploy output\033[0m\r\n"
                    f"\033[90mOr run manually: tmux attach -t {session_name}\033[0m\r\n"
                    "\r\n"
                )
            else:
                # Collect extra error detail
                err = await asyncio.to_thread(stderr.read)
                err_msg = err.decode().strip() if err else ""
                success_msg = f"\r\n\033[31m✗ Failed to start deployment\033[0m\r\n\033[90mResult: {result}\r\nError: {err_msg}\033[0m\r\n"

            await self.send(bytes_data=success_msg.encode())

        except Exception as e:
            error_msg = f"\033[31m✗ Deployment failed: {str(e)}\033[0m\r\n"
            await self.send(bytes_data=error_msg.encode())
            logger.error(f"Deploy script execution failed: {e}")

    async def _run_uninstall_script(self):
        """Run the worker uninstall script on the remote host

        Steps:
        1. Read the local worker-uninstall.sh through the service layer
        2. Upload it to /tmp/xingrin_uninstall.sh and make it executable
        3. Execute it with bash via exec_command
        4. Write a result summary back to the frontend terminal
        """
        if not self.ssh_client:
            return

        from apps.engine.services.deploy_service import get_uninstall_script

        uninstall_script = get_uninstall_script()
        remote_script_path = '/tmp/xingrin_uninstall.sh'

        start_msg = "\r\n\033[36m[XingRin] Uninstalling worker...\033[0m\r\n"
        await self.send(bytes_data=start_msg.encode())

        try:
            # Upload the uninstall script
            sftp = await asyncio.to_thread(self.ssh_client.open_sftp)
            with sftp.file(remote_script_path, 'w') as f:
                f.write(uninstall_script)
            sftp.chmod(remote_script_path, 0o755)
            await asyncio.to_thread(sftp.close)

            # Run the uninstall script
            cmd = f"bash {remote_script_path}"
            stdin, stdout, stderr = await asyncio.to_thread(
                self.ssh_client.exec_command, cmd
            )
            out = await asyncio.to_thread(stdout.read)
            err = await asyncio.to_thread(stderr.read)

            # Convert newlines to terminal form (\n -> \r\n)
            output_text = out.decode().strip().replace('\n', '\r\n') if out else ""
            error_text = err.decode().strip().replace('\n', '\r\n') if err else ""

            # Simple success check (exit code + keywords)
            exit_status = stdout.channel.recv_exit_status()
            if exit_status == 0:
                # Uninstall succeeded: reset the status to pending
                await sync_to_async(self.worker_service.update_status)(self.worker_id, 'pending')
                # Drop the heartbeat data from Redis
                from apps.engine.services.worker_load_service import worker_load_service
                worker_load_service.delete_load(self.worker_id)
                # Push the status update to the frontend
                await self.send(text_data=json.dumps({
                    'type': 'status',
                    'status': 'pending'  # back to "awaiting deployment" after uninstall
                }))
                msg = "\r\n\033[32m✓ Node uninstalled\033[0m\r\n"
                if output_text:
                    msg += f"\033[90m{output_text}\033[0m\r\n"
            else:
                msg = "\r\n\033[31m✗ Worker uninstall failed\033[0m\r\n"
                if output_text:
                    msg += f"\033[90mOutput: {output_text}\033[0m\r\n"
                if error_text:
                    msg += f"\033[90mError: {error_text}\033[0m\r\n"

            await self.send(bytes_data=msg.encode())

        except Exception as e:
            error_msg = f"\033[31m✗ Uninstall raised an exception: {str(e)}\033[0m\r\n"
            await self.send(bytes_data=error_msg.encode())
            logger.error(f"Uninstall script execution failed: {e}")

    async def _attach_deploy_session(self):
        """Attach to the deploy session to watch progress"""
        if not self.shell or not self.ssh_client:
            return

        session_name = f'xingrin_deploy_{self.worker_id}'

        # Quietly check whether the session exists first
        check_cmd = f"tmux has-session -t {session_name} 2>/dev/null && echo EXISTS || echo NOT_EXISTS"
        stdin, stdout, stderr = await asyncio.to_thread(
            self.ssh_client.exec_command, check_cmd
        )
        result = await asyncio.to_thread(stdout.read)
        result = result.decode().strip()

        # Exact match: a substring check would also match "NOT_EXISTS"
        if result == "EXISTS":
            # Session exists: attach directly
            await asyncio.to_thread(self.shell.send, f"tmux attach -t {session_name}\n")
        else:
            # No session: explain what to do
            msg = "\r\n\033[33mNo deployment is currently running\033[0m\r\n\033[90mClick [Deploy] first to start one\033[0m\r\n\r\n"
            await self.send(bytes_data=msg.encode())

    # Channel-layer message handlers
    async def terminal_output(self, event):
        if self.shell:
            await asyncio.to_thread(self.shell.send, event['content'])

    async def deploy_status(self, event):
        await self.send(text_data=json.dumps({
            'type': 'status',
            'status': event['status'],
            'message': event.get('message', '')
        }))
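A minimal client sketch against this consumer. The ws path is hypothetical; the actual route lives in the channels routing config, which is not part of this excerpt:

import asyncio
import json

import websockets


async def main():
    uri = "ws://localhost:8888/ws/worker/1/deploy/"  # hypothetical route for worker_id=1
    async with websockets.connect(uri) as ws:
        await ws.send(json.dumps({"type": "resize", "cols": 120, "rows": 30}))
        await ws.send(json.dumps({"type": "deploy"}))  # kicks off the background deploy
        while True:
            frame = await ws.recv()
            # Text frames are JSON control messages; binary frames are raw PTY output
            print(frame if isinstance(frame, str) else frame.decode(errors="replace"))


asyncio.run(main())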
0
backend/apps/engine/management/__init__.py
Normal file
0
backend/apps/engine/management/commands/__init__.py
Normal file
112
backend/apps/engine/management/commands/init_default_engine.py
Normal file
@@ -0,0 +1,112 @@
"""
Initialize the default scan engines

Usage:
    python manage.py init_default_engine          # only create missing engines (never overwrite)
    python manage.py init_default_engine --force  # overwrite all engine configs

    cd /root/my-vulun-scan/docker
    docker compose exec server python backend/manage.py init_default_engine --force

What it does:
- Reads engine_config_example.yaml as the default configuration
- Creates "full scan" (the default engine) plus one sub-engine per scan type
- Existing configs are left alone unless --force is given
"""

from django.core.management.base import BaseCommand
from pathlib import Path

import yaml

from apps.engine.models import ScanEngine


class Command(BaseCommand):
    help = 'Initialize default scan engine configs (existing ones are kept unless --force is given)'

    def add_arguments(self, parser):
        parser.add_argument(
            '--force',
            action='store_true',
            help='Overwrite existing engine configs',
        )

    def handle(self, *args, **options):
        force = options.get('force', False)
        # Read the default config file
        config_path = Path(__file__).resolve().parent.parent.parent.parent / 'scan' / 'configs' / 'engine_config_example.yaml'

        if not config_path.exists():
            self.stdout.write(self.style.ERROR(f'Config file not found: {config_path}'))
            return

        with open(config_path, 'r', encoding='utf-8') as f:
            default_config = f.read()

        # Parse the YAML into a dict; used below to build the sub-engine configs
        try:
            config_dict = yaml.safe_load(default_config) or {}
        except yaml.YAMLError as e:
            self.stdout.write(self.style.ERROR(f'Failed to parse engine config YAML: {e}'))
            return

        # 1) full scan: keeps the complete configuration
        engine = ScanEngine.objects.filter(name='full scan').first()
        if engine:
            if force:
                engine.configuration = default_config
                engine.save()
                self.stdout.write(self.style.SUCCESS(f'✓ Scan engine "full scan" config updated (ID: {engine.id})'))
            else:
                self.stdout.write(self.style.WARNING('  ⊘ "full scan" already exists, skipping (use --force to overwrite)'))
        else:
            engine = ScanEngine.objects.create(
                name='full scan',
                configuration=default_config,
            )
            self.stdout.write(self.style.SUCCESS(f'✓ Scan engine "full scan" created (ID: {engine.id})'))

        # 2) Create one single-scan-type sub-engine per scan type,
        #    e.g. subdomain_discovery, port_scan, ...
        from apps.scan.configs.command_templates import get_supported_scan_types

        supported_scan_types = set(get_supported_scan_types())

        for scan_type, scan_cfg in config_dict.items():
            # Only handle supported scan types shaped like {tools: {...}}
            if scan_type not in supported_scan_types:
                continue
            if not isinstance(scan_cfg, dict):
                continue
            # subdomain_discovery uses the new 4-stage structure (no 'tools' key);
            # every other scan type must still carry 'tools'
            if scan_type != 'subdomain_discovery' and 'tools' not in scan_cfg:
                continue

            # Build a YAML document containing only this scan type's config
            single_config = {scan_type: scan_cfg}
            try:
                single_yaml = yaml.safe_dump(
                    single_config,
                    sort_keys=False,
                    allow_unicode=True,
                )
            except yaml.YAMLError as e:
                self.stdout.write(self.style.ERROR(f'Failed to build sub-engine config for {scan_type}: {e}'))
                continue

            engine_name = scan_type
            sub_engine = ScanEngine.objects.filter(name=engine_name).first()
            if sub_engine:
                if force:
                    sub_engine.configuration = single_yaml
                    sub_engine.save()
                    self.stdout.write(self.style.SUCCESS(f'  ✓ Sub-engine {engine_name} config updated (ID: {sub_engine.id})'))
                else:
                    self.stdout.write(self.style.WARNING(f'  ⊘ {engine_name} already exists, skipping (use --force to overwrite)'))
            else:
                sub_engine = ScanEngine.objects.create(
                    name=engine_name,
                    configuration=single_yaml,
                )
                self.stdout.write(self.style.SUCCESS(f'  ✓ Sub-engine {engine_name} created (ID: {sub_engine.id})'))
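A sketch of the slicing step the command performs, with an illustrative config (the real keys come from engine_config_example.yaml, not shown here):

import yaml

full_config = yaml.safe_load("""
port_scan:
  tools:
    naabu:
      top_ports: 1000
""")  # illustrative snippet, not the shipped config

for scan_type, scan_cfg in full_config.items():
    # Each sub-engine gets a YAML document holding only its own scan type
    print(yaml.safe_dump({scan_type: scan_cfg}, sort_keys=False, allow_unicode=True))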
126
backend/apps/engine/management/commands/init_nuclei_templates.py
Normal file
@@ -0,0 +1,126 @@
"""Initialize the Nuclei template repository

Run after installing the project to create the official template repo record.

Usage:
    python manage.py init_nuclei_templates         # create the record only
    python manage.py init_nuclei_templates --sync  # create and sync (git clone)
"""

import logging
from django.core.management.base import BaseCommand

from apps.engine.models import NucleiTemplateRepo
from apps.engine.services import NucleiTemplateRepoService

logger = logging.getLogger(__name__)


# Default repository configuration
DEFAULT_REPOS = [
    {
        "name": "nuclei-templates",
        "repo_url": "https://github.com/projectdiscovery/nuclei-templates.git",
        "description": "Official Nuclei template repository with thousands of vulnerability detection templates",
    },
]


class Command(BaseCommand):
    help = "Initialize the Nuclei template repositories (creates the official repo record)"

    def add_arguments(self, parser):
        parser.add_argument(
            "--sync",
            action="store_true",
            help="Sync right after creating (git clone); the first run can take a while",
        )
        parser.add_argument(
            "--force",
            action="store_true",
            help="Force re-creation (deletes an existing repo with the same name)",
        )

    def handle(self, *args, **options):
        do_sync = options.get("sync", False)
        force = options.get("force", False)

        service = NucleiTemplateRepoService()
        created = 0
        skipped = 0
        synced = 0

        for repo_config in DEFAULT_REPOS:
            name = repo_config["name"]
            repo_url = repo_config["repo_url"]

            # Does it exist already?
            existing = NucleiTemplateRepo.objects.filter(name=name).first()

            if existing:
                if force:
                    self.stdout.write(self.style.WARNING(
                        f"[{name}] force mode: deleting the existing repo record"
                    ))
                    service.remove_local_path_dir(existing)
                    existing.delete()
                else:
                    self.stdout.write(self.style.SUCCESS(
                        f"[{name}] already exists, skipping creation"
                    ))
                    skipped += 1

                    # If syncing was requested, sync the existing repo too
                    if do_sync and existing.id:
                        try:
                            result = service.refresh_repo(existing.id)
                            self.stdout.write(self.style.SUCCESS(
                                f"[{name}] sync finished: {result.get('action', 'unknown')}, "
                                f"commit={result.get('commitHash', 'N/A')[:8]}"
                            ))
                            synced += 1
                        except Exception as e:
                            self.stdout.write(self.style.ERROR(
                                f"[{name}] sync failed: {e}"
                            ))
                    continue

            # Create the new repo record
            try:
                repo = NucleiTemplateRepo.objects.create(
                    name=name,
                    repo_url=repo_url,
                )
                self.stdout.write(self.style.SUCCESS(
                    f"[{name}] created: id={repo.id}"
                ))
                created += 1

                # Initialize the local path
                service.ensure_local_path(repo)

                # Sync if requested
                if do_sync:
                    try:
                        self.stdout.write(self.style.WARNING(
                            f"[{name}] syncing (the first run may take a few minutes)..."
                        ))
                        result = service.refresh_repo(repo.id)
                        self.stdout.write(self.style.SUCCESS(
                            f"[{name}] sync finished: {result.get('action', 'unknown')}, "
                            f"commit={result.get('commitHash', 'N/A')[:8]}"
                        ))
                        synced += 1
                    except Exception as e:
                        self.stdout.write(self.style.ERROR(
                            f"[{name}] sync failed: {e}"
                        ))

            except Exception as e:
                self.stdout.write(self.style.ERROR(
                    f"[{name}] creation failed: {e}"
                ))

        self.stdout.write(self.style.SUCCESS(
            f"\nInitialization done: created {created}, skipped {skipped}, synced {synced}"
        ))
148
backend/apps/engine/management/commands/init_wordlists.py
Normal file
@@ -0,0 +1,148 @@
"""Initialize all built-in Wordlist records

- Default directory-scan wordlist: dir_default.txt -> /app/backend/wordlist/dir_default.txt
- Default subdomain brute-force wordlist: subdomains-top1million-110000.txt -> /app/backend/wordlist/subdomains-top1million-110000.txt

Idempotent: records with a valid file are skipped; records are only
created/repaired when they are missing or the file is gone.
"""

import logging
import shutil
from pathlib import Path

from django.conf import settings
from django.core.management.base import BaseCommand

from apps.common.hash_utils import safe_calc_file_sha256
from apps.engine.models import Wordlist


logger = logging.getLogger(__name__)


DEFAULT_WORDLISTS = [
    {
        "name": "dir_default.txt",
        "filename": "dir_default.txt",
        "description": "Built-in default directory wordlist",
    },
    {
        "name": "subdomains-top1million-110000.txt",
        "filename": "subdomains-top1million-110000.txt",
        "description": "Built-in default subdomain wordlist",
    },
]


class Command(BaseCommand):
    help = "Initialize all built-in Wordlist records"

    def handle(self, *args, **options):
        project_base = Path(settings.BASE_DIR).parent  # /app/backend -> /app
        base_wordlist_dir = project_base / "backend" / "wordlist"
        runtime_base_dir = Path(getattr(settings, "WORDLISTS_BASE_PATH", "/opt/xingrin/wordlists"))
        runtime_base_dir.mkdir(parents=True, exist_ok=True)

        initialized = 0
        skipped = 0
        failed = 0

        for item in DEFAULT_WORDLISTS:
            name = item["name"]
            filename = item["filename"]
            description = item["description"]

            existing = Wordlist.objects.filter(name=name).first()
            if existing:
                file_path = existing.file_path or ""
                file_hash = getattr(existing, 'file_hash', '') or ''
                if file_path and Path(file_path).exists() and file_hash:
                    # Record, file and hash are all present: skip
                    self.stdout.write(self.style.SUCCESS(
                        f"[{name}] exists with a valid file, skipping (file_path={file_path})"
                    ))
                    skipped += 1
                    continue
                elif file_path and Path(file_path).exists() and not file_hash:
                    # File exists but the hash is missing: recompute and update
                    self.stdout.write(self.style.WARNING(
                        f"[{name}] record exists but file_hash is missing; it will be computed and stored"
                    ))
                else:
                    self.stdout.write(self.style.WARNING(
                        f"[{name}] record exists but the file is gone; the path will be recreated and the record repaired"
                    ))

            src_path = base_wordlist_dir / filename
            dest_path = runtime_base_dir / filename

            if not src_path.exists():
                self.stdout.write(self.style.WARNING(
                    f"[{name}] built-in wordlist file not found: {src_path}, skipping"
                ))
                failed += 1
                continue

            try:
                shutil.copy2(src_path, dest_path)
            except OSError as exc:
                self.stdout.write(self.style.WARNING(
                    f"[{name}] failed to copy the built-in wordlist to the runtime dir: {exc}"
                ))
                failed += 1
                continue

            # Collect file size and line count
            try:
                file_size = dest_path.stat().st_size
            except OSError:
                file_size = 0

            line_count = 0
            try:
                with dest_path.open("rb") as f:
                    for _ in f:
                        line_count += 1
            except OSError:
                logger.warning("Failed to count wordlist lines: %s", src_path)

            # Compute the file hash
            file_hash = safe_calc_file_sha256(str(dest_path)) or ""

            # Update the previous record if there was one, otherwise create a new one
            if existing:
                existing.file_path = str(dest_path)
                existing.file_size = file_size
                existing.line_count = line_count
                existing.file_hash = file_hash
                existing.description = existing.description or description
                existing.save(update_fields=[
                    "file_path",
                    "file_size",
                    "line_count",
                    "file_hash",
                    "description",
                    "updated_at",
                ])
                wordlist = existing
                action = "updated"
            else:
                wordlist = Wordlist.objects.create(
                    name=name,
                    description=description,
                    file_path=str(dest_path),
                    file_size=file_size,
                    line_count=line_count,
                    file_hash=file_hash,
                )
                action = "created"

            initialized += 1
            hash_preview = (wordlist.file_hash[:16] + "...") if wordlist.file_hash else "N/A"
            self.stdout.write(self.style.SUCCESS(
                f"[{name}] wordlist record {action}: id={wordlist.id}, size={wordlist.file_size}, lines={wordlist.line_count}, hash={hash_preview}"
            ))

        self.stdout.write(self.style.SUCCESS(
            f"Initialization done: {initialized} succeeded, {skipped} skipped (already present), {failed} missing files"
        ))
0
backend/apps/engine/migrations/__init__.py
Normal file
126
backend/apps/engine/models.py
Normal file
@@ -0,0 +1,126 @@
from django.db import models


class WorkerNode(models.Model):
    """Worker node model - a distributed scan executor"""

    # Status choices (shared between frontend and backend)
    STATUS_CHOICES = [
        ('pending', 'Awaiting deployment'),
        ('deploying', 'Deploying'),
        ('online', 'Online'),
        ('offline', 'Offline'),
    ]

    name = models.CharField(max_length=100, help_text='Node name')
    # Local nodes are auto-filled with 127.0.0.1 or the container IP
    ip_address = models.GenericIPAddressField(help_text='IP address (127.0.0.1 for local nodes)')
    ssh_port = models.IntegerField(default=22, help_text='SSH port')
    username = models.CharField(max_length=50, default='root', help_text='SSH username')
    password = models.CharField(max_length=200, blank=True, default='', help_text='SSH password')

    # Local-node flag (a worker inside the Docker container)
    is_local = models.BooleanField(default=False, help_text='Whether this is a local node (inside the Docker container)')

    # Status (shared between frontend and backend)
    status = models.CharField(
        max_length=20,
        choices=STATUS_CHOICES,
        default='pending',
        help_text='Status: pending/deploying/online/offline'
    )

    # Heartbeat data lives in Redis (worker:load:{id}); no database field anymore

    created_at = models.DateTimeField(auto_now_add=True)
    updated_at = models.DateTimeField(auto_now=True)

    class Meta:
        db_table = 'worker_node'
        verbose_name = 'Worker node'
        ordering = ['-created_at']
        constraints = [
            # Remote node IPs must be unique (local nodes are exempt: they are all 127.0.0.1)
            models.UniqueConstraint(
                fields=['ip_address'],
                condition=models.Q(is_local=False),
                name='unique_remote_worker_ip'
            ),
            # Names are globally unique
            models.UniqueConstraint(
                fields=['name'],
                name='unique_worker_name'
            ),
        ]

    def __str__(self):
        if self.is_local:
            return f"{self.name} (local)"
        return f"{self.name} ({self.ip_address or 'unknown'})"


class ScanEngine(models.Model):
    """Scan engine model"""

    id = models.AutoField(primary_key=True)
    name = models.CharField(max_length=200, unique=True, help_text='Engine name')
    configuration = models.CharField(max_length=10000, blank=True, default='', help_text='Engine configuration, YAML format')
    created_at = models.DateTimeField(auto_now_add=True, help_text='Created at')
    updated_at = models.DateTimeField(auto_now=True, help_text='Updated at')

    class Meta:
        db_table = 'scan_engine'
        verbose_name = 'Scan engine'
        verbose_name_plural = 'Scan engines'
        ordering = ['-created_at']
        indexes = [
            models.Index(fields=['-created_at']),
        ]

    def __str__(self):
        return str(self.name or f'ScanEngine {self.id}')


class Wordlist(models.Model):
    """Wordlist file model"""

    id = models.AutoField(primary_key=True)
    name = models.CharField(max_length=200, unique=True, help_text='Wordlist name, unique')
    description = models.CharField(max_length=200, blank=True, default='', help_text='Wordlist description')
    file_path = models.CharField(max_length=500, help_text='Absolute path of the wordlist file on the backend')
    file_size = models.BigIntegerField(default=0, help_text='File size in bytes')
    line_count = models.IntegerField(default=0, help_text='Number of lines')
    file_hash = models.CharField(max_length=64, blank=True, default='', help_text='SHA-256 of the file, used for cache validation')
    created_at = models.DateTimeField(auto_now_add=True, help_text='Created at')
    updated_at = models.DateTimeField(auto_now=True, help_text='Updated at')

    class Meta:
        db_table = 'wordlist'
        verbose_name = 'Wordlist file'
        verbose_name_plural = 'Wordlist files'
        ordering = ['-created_at']
        indexes = [
            models.Index(fields=['-created_at']),
        ]

    def __str__(self) -> str:
        return self.name


class NucleiTemplateRepo(models.Model):
    """Nuclei template Git repository model (multi-repo)"""

    name = models.CharField(max_length=200, unique=True, help_text="Repository name, shown in the frontend and referenced in configs")
    repo_url = models.CharField(max_length=500, help_text="Git repository URL")
    local_path = models.CharField(max_length=500, blank=True, default='', help_text="Absolute path of the local working directory")
    commit_hash = models.CharField(max_length=40, blank=True, default='', help_text="Last synced Git commit hash, used for worker version checks")
    last_synced_at = models.DateTimeField(null=True, blank=True, help_text="Last successful sync time")
    created_at = models.DateTimeField(auto_now_add=True, help_text="Created at")
    updated_at = models.DateTimeField(auto_now=True, help_text="Updated at")

    class Meta:
        db_table = "nuclei_template_repo"
        verbose_name = "Nuclei template repository"
        verbose_name_plural = "Nuclei template repositories"

    def __str__(self) -> str:  # pragma: no cover - trivial repr
        return f"NucleiTemplateRepo({self.id}, {self.name})"
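A sketch of the partial unique constraint on WorkerNode (names and IPs are illustrative): local nodes may share an IP, remote nodes may not.

from apps.engine.models import WorkerNode

WorkerNode.objects.create(name='local-1', ip_address='127.0.0.1', is_local=True)
WorkerNode.objects.create(name='local-2', ip_address='127.0.0.1', is_local=True)  # allowed
WorkerNode.objects.create(name='remote-1', ip_address='10.0.0.5')
# WorkerNode.objects.create(name='remote-2', ip_address='10.0.0.5')  # would raise IntegrityError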
17
backend/apps/engine/repositories/__init__.py
Normal file
@@ -0,0 +1,17 @@
"""Engine repositories module

Data-access layer implementations for ScanEngine, WorkerNode, Wordlist,
NucleiRepo and related models
"""

from .django_engine_repository import DjangoEngineRepository
from .django_worker_repository import DjangoWorkerRepository
from .django_wordlist_repository import DjangoWordlistRepository
from .nuclei_repo_repository import NucleiTemplateRepository, TemplateFileRepository

__all__ = [
    "DjangoEngineRepository",
    "DjangoWorkerRepository",
    "DjangoWordlistRepository",
    "NucleiTemplateRepository",
    "TemplateFileRepository",
]
Some files were not shown because too many files have changed in this diff.