Compare commits

..

56 Commits

Author SHA1 Message Date
yyhuni
b4037202dc feat: use registry cache for faster builds 2026-01-03 17:35:54 +08:00
yyhuni
4b4f9862bf ci(docker): add postgres image build configuration and update image tags
- Add xingrin-postgres image build job to docker-build workflow for multi-platform support (linux/amd64,linux/arm64)
- Update docker-compose.dev.yml to use IMAGE_TAG variable with dev as default fallback
- Update docker-compose.yml to use IMAGE_TAG variable with required validation
- Replace hardcoded postgres image tag (15) with dynamic IMAGE_TAG for better version management
- Enable flexible image tagging across development and production environments
2026-01-03 17:26:34 +08:00
github-actions[bot]
1c42e4978f chore: bump version to v1.3.5-dev 2026-01-03 08:44:06 +00:00
github-actions[bot]
57bab63997 chore: bump version to v1.3.3-dev 2026-01-03 05:55:07 +00:00
github-actions[bot]
b1f0f18ac0 chore: bump version to v1.3.4-dev 2026-01-03 05:54:50 +00:00
yyhuni
ccee5471b8 docs(readme): add notification push service documentation
- Add notification push service feature to visualization interface section
- Document support for real-time WeChat Work, Telegram, and Discord message push
- Enhance feature list clarity for notification capabilities
2026-01-03 13:34:36 +08:00
yyhuni
0ccd362535 Improve download logic 2026-01-03 13:32:58 +08:00
yyhuni
7f2af7f7e2 feat(search): add result export functionality and pagination limit support
- Add optional limit parameter to AssetSearchService.search() method for controlling result set size
- Implement AssetSearchExportView for exporting search results as CSV files with UTF-8 BOM encoding
- Add CSV export endpoint at GET /api/assets/search/export/ with configurable MAX_EXPORT_ROWS limit (10000)
- Support both website and endpoint asset types with type-specific column mappings in CSV export
- Format array fields (tech, matched_gf_patterns) and dates appropriately in exported CSV
- Update URL routing to include new search export endpoint
- Update views __init__.py to export AssetSearchExportView
- Add CSV generation with streaming response for efficient memory usage on large exports
- Update frontend search service to support export functionality
- Add internationalization strings for export feature in en.json and zh.json
- Update smart-filter-input and search-results-table components to support export UI
- Update installation and Docker startup scripts for deployment compatibility
2026-01-03 13:22:21 +08:00
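A streaming UTF-8-BOM CSV export like the one described in this commit is usually a short recipe in Django. The sketch below is an assumption-laden outline, not the project's `AssetSearchExportView`: `MAX_EXPORT_ROWS` comes from the commit message, the `Echo` pseudo-buffer is the standard Django streaming-CSV pattern, and `stream_search_csv` with its `rows`/`columns` handling is hypothetical.

```python
import csv

from django.http import StreamingHttpResponse

MAX_EXPORT_ROWS = 10000  # cap named in the commit message


class Echo:
    """Pseudo-buffer: csv.writer "writes" each row and we yield it directly."""

    def write(self, value):
        return value


def stream_search_csv(rows, columns, filename="search_results.csv"):
    """Stream search results as CSV without building the file in memory."""
    writer = csv.writer(Echo())

    def generate():
        yield "\ufeff"  # UTF-8 BOM so Excel detects the encoding
        yield writer.writerow(columns)
        for i, row in enumerate(rows):
            if i >= MAX_EXPORT_ROWS:
                break
            yield writer.writerow([row.get(col, "") for col in columns])

    response = StreamingHttpResponse(generate(), content_type="text/csv")
    response["Content-Disposition"] = f'attachment; filename="{filename}"'
    return response
```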
yyhuni
4bd0f9e8c1 feat(search): implement dual-view IMMV architecture for website and endpoint assets
- Add incremental materialized view (IMMV) support for both Website and Endpoint asset types using pg_ivm extension
- Create asset_search_view IMMV with optimized indexes for host, title, url, headers, body, tech, status_code, and created_at fields
- Create endpoint_search_view IMMV with identical field structure and indexing strategy for endpoint-specific searches
- Extend search_service.py to support asset type routing with VIEW_MAPPING and VALID_ASSET_TYPES configuration
- Add comprehensive field mapping and array field definitions for both asset types
- Implement dual-query execution path in search views to handle website and endpoint searches independently
- Update frontend search components to support asset type filtering and result display
- Add search results table component with improved data presentation and filtering capabilities
- Update installation scripts and Docker configuration for pg_ivm extension deployment
- Add internationalization strings for new search UI elements in English and Chinese
- Consolidate index creation and cleanup logic in migrations for maintainability
- Enable automatic incremental updates on data changes without manual view refresh
2026-01-03 12:41:20 +08:00
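The asset-type routing this commit names is small enough to sketch. Only `VIEW_MAPPING` and `VALID_ASSET_TYPES` come from the commit message; the function and error handling below are assumptions.

```python
# Assumed shape of the asset-type routing in search_service.py.
VIEW_MAPPING = {
    "website": "asset_search_view",
    "endpoint": "endpoint_search_view",
}
VALID_ASSET_TYPES = frozenset(VIEW_MAPPING)


def resolve_search_view(asset_type: str) -> str:
    """Map a requested asset type onto its IMMV, rejecting unknown types."""
    if asset_type not in VALID_ASSET_TYPES:
        raise ValueError(f"unknown asset type: {asset_type!r}")
    return VIEW_MAPPING[asset_type]
```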
yyhuni
68cc996e3b refactor(asset): standardize snapshot and asset model field naming and types
- Rename `status` to `status_code` in WebsiteSnapshotDTO for consistency
- Rename `web_server` to `webserver` in WebsiteSnapshotDTO for consistency
- Make `target_id` required field in EndpointSnapshotDTO and WebsiteSnapshotDTO
- Remove optional validation check for `target_id` in EndpointSnapshotDTO
- Convert CharField to TextField for url, location, title, webserver, and content_type fields in Endpoint and EndpointSnapshot models to support longer values
- Update migration 0001_initial.py to reflect field type changes from CharField to TextField
- Update all related services and repositories to use standardized field names
- Update serializers to map renamed fields correctly
- Ensure consistent field naming across DTOs, models, and database schema
2026-01-03 09:08:25 +08:00
github-actions[bot]
f1e79d638e chore: bump version to v1.3.2-dev 2026-01-03 00:33:26 +00:00
yyhuni
d484133e4c chore(docker): optimize server dockerfile with docker-ce-cli installation
- Replace full docker.io package with lightweight docker-ce-cli to reduce image size
- Add ca-certificates and gnupg dependencies for secure package management
- Improve Docker installation process for local Worker task distribution
- Reduce unnecessary dependencies in server container build
2026-01-03 08:09:03 +08:00
yyhuni
fc977ae029 chore(docker,frontend): optimize docker installation and add auth bypass config
- Replace docker.io installation script with apt-get package manager for better reliability
- Add NEXT_PUBLIC_SKIP_AUTH environment variable to Vercel config for development
- Improve Docker build layer caching by using native package manager instead of curl script
- Simplify frontend deployment configuration for local development workflows
2026-01-03 08:08:40 +08:00
yyhuni
f328474404 feat(frontend): add comprehensive mock data infrastructure for services
- Add mock data modules for auth, engines, notifications, scheduled-scans, and workers
- Implement mock authentication data with user profiles and login/logout responses
- Create mock scan engine configurations with multiple predefined scanning profiles
- Add mock notification system with various severity levels and categories
- Implement mock scheduled scan data with cron expressions and run history
- Add mock worker node data with status and performance metrics
- Update service layer to integrate with new mock data infrastructure
- Provide helper functions for filtering and paginating mock data
- Enable frontend development and testing without backend API dependency
2026-01-03 07:59:20 +08:00
yyhuni
68e726a066 chore(docker): update base image to python 3.10-slim-bookworm
- Update Python base image from 3.10-slim to 3.10-slim-bookworm
- Ensures compatibility with latest Debian stable release
- Improves security with updated system packages and dependencies
2026-01-02 23:19:09 +08:00
yyhuni
77a6f45909 fix: search statistics count issue 2026-01-02 23:12:55 +08:00
yyhuni
49d1f1f1bb Adopt the IVM incremental-update approach for search 2026-01-02 22:46:40 +08:00
yyhuni
db8ecb1644 feat(search): add mock data infrastructure and vulnerability detail integration
- Add comprehensive mock data configuration for all major entities (dashboard, endpoints, organizations, scans, subdomains, targets, vulnerabilities, websites)
- Implement mock service layer with centralized config for development and testing
- Add vulnerability detail dialog integration to search results with lazy loading
- Enhance search result card with vulnerability viewing capability
- Update search materialized view migration to include vulnerability name field
- Implement default host fuzzy search fallback for bare text queries without operators
- Add vulnerability data formatting in search view for consistent API response structure
- Configure Vercel deployment settings and update Next.js configuration
- Update all service layers to support mock data injection for development environment
- Extend search types with improved vulnerability data structure
- Add internationalization strings for vulnerability loading errors
- Enable rapid frontend development and testing without backend API dependency
2026-01-02 19:06:09 +08:00
yyhuni
18cc016268 feat(search): implement advanced query parser with expression syntax support
- Add SearchQueryParser class to parse complex search expressions with operators (=, ==, !=)
- Support logical operators && (AND) and || (OR) for combining multiple conditions
- Implement field mapping for frontend to database field translation
- Add support for array field searching (tech stack) with unnest and ANY operators
- Support fuzzy matching (=), exact matching (==), and negation (!=) operators
- Add proper SQL injection prevention through parameterized queries
- Refactor search service to use expression-based filtering instead of simple filters
- Update search views to integrate new query parser
- Enhance frontend search hook and service to support new expression syntax
- Update search types to reflect new query structure
- Improve search page UI to display expression syntax examples and help text
- Enable complex multi-condition searches like: host="api" && tech="nginx" || status=="200"
2026-01-02 17:46:31 +08:00
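A minimal sketch of how such an expression parser can compile a query into a parameterized WHERE clause. This is not the project's `SearchQueryParser`: the `FIELD_MAP` entries are an assumed subset, quoted values containing `&&`/`||` and the array/tech handling are ignored, and `&&` simply binds tighter than `||`.

```python
import re

# Assumed frontend-to-column mapping; the real FIELD_MAP lives in the parser.
FIELD_MAP = {"host": "host", "url": "url", "title": "title", "status": "status_code"}

COND_RE = re.compile(r'(\w+)\s*(==|!=|=)\s*"([^"]*)"')


def compile_condition(expr):
    m = COND_RE.fullmatch(expr.strip())
    if m is None:
        raise ValueError(f"bad condition: {expr!r}")
    field, op, value = m.groups()
    column = FIELD_MAP[field]
    if op == "=":          # fuzzy match
        return f"{column} ILIKE %s", [f"%{value}%"]
    if op == "==":         # exact match
        return f"{column} = %s", [value]
    return f"{column} != %s", [value]   # negation


def compile_query(query):
    """Compile `a && b || c` into (where_sql, params); values stay parameterized."""
    or_parts, params = [], []
    for disjunct in query.split("||"):
        and_parts = []
        for cond in disjunct.split("&&"):
            sql, p = compile_condition(cond)
            and_parts.append(sql)
            params.extend(p)
        or_parts.append("(" + " AND ".join(and_parts) + ")")
    return " OR ".join(or_parts), params
```

For example, `compile_query('host="api" && status=="200"')` yields `('(host ILIKE %s AND status_code = %s)', ['%api%', '200'])`, keeping every user-supplied value out of the SQL string.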
yyhuni
23bc463283 feat(search): improve technology stack filtering with fuzzy matching
- Replace exact array matching with fuzzy search using ILIKE operator
- Update tech filter to search within array elements using unnest() and EXISTS
- Support partial technology name matching (e.g., "node" matches "nodejs")
- Apply consistent fuzzy matching logic across both search methods
- Enhance user experience by allowing flexible technology stack queries
2026-01-02 17:01:24 +08:00
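The unnest-based fuzzy filter described above boils down to one EXISTS subquery; a hedged sketch, with the column name and parameter plumbing assumed:

```python
# Assumed shape of the tech-stack fuzzy filter: unnest the array and
# ILIKE each element, so "node" matches "nodejs".
TECH_FUZZY_SQL = """
    EXISTS (
        SELECT 1
        FROM unnest(tech) AS t(name)
        WHERE t.name ILIKE %s
    )
"""


def tech_fuzzy_clause(value):
    """Return (sql_fragment, params) for a fuzzy tech-stack condition."""
    return TECH_FUZZY_SQL, [f"%{value}%"]
```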
yyhuni
7b903b91b2 feat(search): implement comprehensive search infrastructure with materialized views and pagination
- Add asset search service with materialized view support for optimized queries
- Implement search refresh service for maintaining up-to-date search indexes
- Create database migrations for AssetStatistics, StatisticsHistory, Directory, and DirectorySnapshot models
- Add PostgreSQL GIN indexes with trigram operators for full-text search capabilities
- Implement search pagination component with configurable page size and navigation
- Add search result card component with enhanced asset display formatting
- Create search API views with filtering and sorting capabilities
- Add use-search hook for client-side search state management
- Implement search service client for API communication
- Update search types with pagination metadata and result structures
- Add English and Chinese translations for search UI components
- Enhance scheduler to support search index refresh tasks
- Refactor asset views into modular search_views and asset_views
- Update URL routing to support new search endpoints
- Improve scan flow handlers for better search index integration
2026-01-02 16:57:54 +08:00
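Before the later pg_ivm commits, keeping the materialized view fresh means an explicit refresh; a minimal sketch of what such a refresh service boils down to (function and view names assumed):

```python
from django.db import connection


def refresh_search_view(view_name="asset_search_view"):
    """Rebuild the search materialized view. CONCURRENTLY requires a unique
    index on the view but keeps it readable during the refresh.
    view_name is an internal constant, never user input."""
    with connection.cursor() as cursor:
        cursor.execute(f"REFRESH MATERIALIZED VIEW CONCURRENTLY {view_name}")
```

Per the log, the later commits (49d1f1f1bb, 4bd0f9e8c1) replace this manual refresh with pg_ivm incremental maintenance.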
yyhuni
b3136d51b9 Complete frontend UI design for the search page 2026-01-02 10:07:26 +08:00
yyhuni
236c828041 chore(fingerprints): remove deprecated ARL fingerprint rules
- Remove obsolete fingerprint detection rules from ARL.yaml
- Clean up legacy device and service signatures that are no longer maintained
- Reduce fingerprint database size by eliminating unused detection patterns
- Improve maintainability by removing outdated vendor-specific rules
2026-01-01 22:45:08 +08:00
yyhuni
fb13bb74d8 feat(filter): add array fuzzy search support with PostgreSQL array_to_string
- Add ArrayToString custom PostgreSQL function for converting arrays to delimited strings
- Implement array field annotation in QueryBuilder to support fuzzy matching on JSON array fields
- Enhance _build_single_q to handle three operators for JSON arrays: exact match (==), negation (!=), and fuzzy search (=)
- Update target navigation routes from subdomain to website view for consistency
- Enable fuzzy search on array fields by converting them to text during query building
2026-01-01 22:41:57 +08:00
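A custom `Func` wrapping `array_to_string` is the usual Django way to do this; a sketch under the assumption that the project's `ArrayToString` follows the standard pattern:

```python
from django.db.models import Func, TextField


class ArrayToString(Func):
    """Expose PostgreSQL array_to_string(arr, ',') so array columns
    can be fuzzy-matched as plain text."""

    function = "ARRAY_TO_STRING"
    template = "%(function)s(%(expressions)s, ',')"
    output_field = TextField()


# Assumed usage: annotate, then filter with icontains (ILIKE under the hood).
# WebSite.objects.annotate(tech_text=ArrayToString("tech"))
#                .filter(tech_text__icontains="node")
```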
yyhuni
f076c682b6 feat(scan): add multi-engine support and config merging with enhanced indexing
- Add multi-engine support to Scan model with engine_ids and engine_names fields
- Implement config_merger utility for merging multiple engine configurations
- Add merged_configuration property to Scan model for unified config access
- Update scan creation and scheduling services to handle multiple engines
- Add pg_trgm GIN indexes to asset and snapshot models for fuzzy search on url, title, and name fields
- Update scan views and serializers to support multi-engine selection and display
- Enhance frontend components for multi-engine scan initiation and scheduling
- Update test data generation script for multi-engine scan scenarios
- Add internationalization strings for multi-engine UI elements
- Refactor scan flow to use merged configuration instead of single engine config
- Update Docker compose files with latest configuration
2026-01-01 22:35:05 +08:00
yyhuni
9eda2caceb feat(asset): add response headers and body tracking with pg_trgm indexing
- Rename body_preview to response_body across endpoint and website models for consistency
- Change response_headers from Dict to string type for efficient text indexing
- Add pg_trgm PostgreSQL extension initialization in AssetConfig for GIN index support
- Update all DTOs to reflect response_body and response_headers field changes
- Modify repositories to handle new response_body and response_headers formats
- Update serializers and views to work with string-based response headers
- Add response_headers and response_body columns to frontend endpoint and website tables
- Update command templates and scan tasks to populate response_body and response_headers
- Add database initialization script for pg_trgm extension in PostgreSQL setup
- Update frontend types and translations for new field names
- Enable efficient full-text search on response headers and body content through GIN indexes
2026-01-01 19:34:11 +08:00
yyhuni
b1c9e202dd feat(sidebar): add feedback link to secondary navigation menu
- Import IconMessageReport icon from tabler/icons-react for feedback menu item
- Add feedback navigation item linking to GitHub issues page
- Add "feedback" translation key to English messages (en.json)
- Add "feedback" translation key to Chinese messages (zh.json) as "反馈建议"
- Improves user engagement by providing direct access to issue reporting
2026-01-01 18:31:34 +08:00
yyhuni
918669bc29 style(ui): update expandable cell whitespace handling for better formatting
- Change whitespace class from `whitespace-normal` to `whitespace-pre-wrap` in expandable cell component
- Improves text rendering by preserving whitespace and line breaks in cell content
- Ensures consistent formatting display across different content types (mono, url, muted variants)
2026-01-01 16:41:47 +08:00
yyhuni
fd70b0544d docs(frontend): update Chinese translations to English for consistency
- Change "响应头" to "Response Headers" in endpoint messages
- Change "响应头" to "Response Headers" in website messages
- Maintain consistency across frontend message translations
- Improve clarity for international users by standardizing field labels
2026-01-01 16:23:03 +08:00
github-actions[bot]
0f2df7a5f3 chore: bump version to v1.2.14-dev 2026-01-01 05:13:25 +00:00
yyhuni
857ab737b5 feat(fingerprint): enhance xingfinger task with snapshot tracking and field merging
- Replace `not_found_count` with `created_count` and `snapshot_count` metrics in fingerprint detect flow
- Initialize and aggregate `snapshot_count` across tool statistics
- Refactor `parse_xingfinger_line()` to return structured dict with url, techs, server, title, status_code, and content_length
- Replace `bulk_merge_tech_field()` with `bulk_merge_website_fields()` to support merging multiple WebSite fields
- Implement smart merge strategy: arrays deduplicated, scalar fields only updated when empty/NULL
- Remove dynamic model loading via importlib in favor of direct WebSite model import
- Add WebsiteSnapshotDTO and DjangoWebsiteSnapshotRepository imports for snapshot handling
- Improve xingfinger output parsing to capture server, title, and HTTP metadata alongside technology detection
2026-01-01 12:40:49 +08:00
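The smart merge strategy is concrete enough to illustrate. This is a sketch of the stated rules only, not the actual `bulk_merge_website_fields()` (which operates on model instances in bulk); the dict-based field handling is an assumption.

```python
def merge_website_fields(existing: dict, incoming: dict) -> dict:
    """Assumed merge rules from the commit message: array fields are unioned
    and deduplicated; scalar fields are only filled when currently empty/NULL."""
    merged = dict(existing)
    for field, new_value in incoming.items():
        old_value = merged.get(field)
        if isinstance(new_value, list):
            # Deduplicate while preserving first-seen order.
            merged[field] = list(dict.fromkeys((old_value or []) + new_value))
        elif old_value in (None, ""):
            merged[field] = new_value
    return merged
```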
yyhuni
ee2d99edda feat(asset): add response headers tracking to endpoints and websites
- Add response_headers field to Endpoint and WebSite models as JSONField
- Add response_headers field to EndpointSnapshot and WebsiteSnapshot models
- Update all related DTOs to include response_headers with Dict[str, Any] type
- Add GIN indexes on response_headers fields for optimized JSON queries
- Update endpoint and website repositories to handle response_headers data
- Update serializers to include response_headers in API responses
- Update frontend components to display response headers in detail views
- Add response_headers to fingerprint detection and site scan tasks
- Update command templates and engine config to support header extraction
- Add internationalization strings for response headers in en.json and zh.json
- Update TypeScript types for endpoint and website to include response_headers
- Enhance scan history and target detail pages to show response header information
2026-01-01 12:25:22 +08:00
github-actions[bot]
db6ce16aca chore: bump version to v1.2.13-dev 2026-01-01 02:24:08 +00:00
yyhuni
ab800eca06 feat(frontend): reorder navigation tabs for improved UX
- Move "Websites" tab to first position in scan history and target layouts
- Reposition "IP Addresses" tab before "Ports" for better logical flow
- Maintain consistent tab ordering across both scan history and target pages
- Improve navigation hierarchy by placing primary discovery results first
2026-01-01 09:47:30 +08:00
yyhuni
e8e5572339 perf(asset): add GIN indexes for tech array fields and improve query parser
- Add GinIndex for tech array field in Endpoint model to optimize __contains queries
- Add GinIndex for tech array field in WebSite model to optimize __contains queries
- Import GinIndex from django.contrib.postgres.indexes
- Refactor QueryParser to protect quoted filter values during tokenization
- Implement placeholder-based filter extraction to preserve spaces within quoted values
- Replace filter tokens with placeholders before logical operator normalization
- Restore original filter conditions from placeholders during parsing
- Fix spacing in comments for consistency (add space after "从")
- Improves query performance for technology stack filtering on large datasets
2026-01-01 08:58:03 +08:00
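The placeholder technique described above fits in a few lines; the regex and placeholder format here are assumptions, not the project's `QueryParser`:

```python
import re

# Assumed pattern: quoted filter conditions like host="a b" or status=="200".
FILTER_RE = re.compile(r'\w+\s*(?:==|!=|=)\s*"[^"]*"')


def protect_quoted_filters(query):
    """Swap quoted conditions for placeholders so spaces inside quotes
    survive whitespace-based tokenization."""
    stash = []

    def _swap(match):
        stash.append(match.group(0))
        return f"__FILTER_{len(stash) - 1}__"

    return FILTER_RE.sub(_swap, query), stash


def restore_quoted_filters(tokens, stash):
    """Replace placeholders with the original filter conditions."""
    return [stash[int(t[9:-2])] if t.startswith("__FILTER_") else t for t in tokens]
```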
github-actions[bot]
d48d4bbcad chore: bump version to v1.2.12-dev 2025-12-31 16:01:48 +00:00
yyhuni
d1cca4c083 Set base timeout to 10s 2025-12-31 23:27:02 +08:00
yyhuni
df0810c863 feat: add fingerprint recognition feature and update documentation
- Add fingerprint recognition section to README with support for 2.7W+ rules from multiple sources (EHole, Goby, Wappalyzer, Fingers, FingerPrintHub, ARL)
- Update scanning pipeline architecture diagram to include fingerprint recognition stage between site identification and deep analysis
- Add fingerprint recognition styling to mermaid diagram for visual consistency
- Include WORKER_API_KEY environment variable in task distributor for worker authentication
- Update WeChat QR code image and public account name from "洋洋的小黑屋" to "塔罗安全学苑"
- Fix import statements in nav-system.tsx to use i18n navigation utilities instead of next/link and next/navigation
- Enhance scanning workflow documentation to reflect complete pipeline: subdomain discovery → port scanning → site identification → fingerprint recognition → URL collection → directory scanning → vulnerability scanning
2025-12-31 23:09:25 +08:00
yyhuni
d33e54c440 docs: simplify quick-start guide
- Remove alternative ZIP download method, keep only Git clone approach
- Remove update.sh script reference from service management section
- Remove dedicated "定期更新" (periodic updates) section
- Streamline documentation to focus on primary installation and usage paths
2025-12-31 22:50:08 +08:00
yyhuni
35a306fe8b fix: dev environment 2025-12-31 22:46:42 +08:00
yyhuni
724df82931 chore: pin Docker base image digests and add worker API key generation
- Pin golang:1.24 base image to specific digest to prevent upstream cache invalidation
- Pin ubuntu:24.04 base image to specific digest to prevent upstream cache invalidation
- Add WORKER_API_KEY generation in install.sh auto_fill_docker_env_secrets function
- Generate random 32-character string for WORKER_API_KEY during installation
- Update installation info message to include WORKER_API_KEY in generated secrets list
- Improve build reproducibility and security by using immutable image references
2025-12-31 22:40:38 +08:00
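install.sh generates the key in shell; the same idea in Python, for illustration only (function name hypothetical):

```python
import secrets
import string

ALPHABET = string.ascii_letters + string.digits


def generate_worker_api_key(length: int = 32) -> str:
    """Random 32-character key, as generated for WORKER_API_KEY during install."""
    return "".join(secrets.choice(ALPHABET) for _ in range(length))
```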
yyhuni
8dfffdf802 fix: authentication 2025-12-31 22:21:40 +08:00
github-actions[bot]
b8cb85ce0b chore: bump version to v1.2.9-dev 2025-12-31 13:48:44 +00:00
yyhuni
da96d437a4 Add authorization and authentication 2025-12-31 20:18:34 +08:00
github-actions[bot]
feaf8062e5 chore: bump version to v1.2.8-dev 2025-12-31 11:33:14 +00:00
yyhuni
4bab76f233 fix: organization deletion issue 2025-12-31 17:50:37 +08:00
yyhuni
09416b4615 fix: Redis port 2025-12-31 17:45:25 +08:00
github-actions[bot]
bc1c5f6b0e chore: bump version to v1.2.7-dev 2025-12-31 06:16:42 +00:00
github-actions[bot]
2f2742e6fe chore: bump version to v1.2.6-dev 2025-12-31 05:29:36 +00:00
yyhuni
be3c346a74 Add search fields 2025-12-31 12:40:21 +08:00
yyhuni
0c7a6fff12 Add search support for the tech field 2025-12-31 12:37:02 +08:00
yyhuni
3b4f0e3147 fix: fingerprint recognition 2025-12-31 12:30:31 +08:00
yyhuni
51212a2a0c fix: fingerprint recognition 2025-12-31 12:17:23 +08:00
yyhuni
58533bbaf6 fix: Docker API 2025-12-31 12:03:08 +08:00
github-actions[bot]
6ccca1602d chore: bump version to v1.2.5-dev 2025-12-31 03:48:32 +00:00
yyhuni
6389b0f672 feat(fingerprints): Add type annotation to getAcceptConfig function
- Add explicit return type annotation `Record<string, string[]>` to getAcceptConfig function
- Improve type safety and IDE autocomplete for file type configuration
- Enhance code clarity for accepted file types mapping in import dialog
2025-12-31 10:17:25 +08:00
167 changed files with 9158 additions and 13196 deletions

View File

@@ -44,6 +44,10 @@ jobs:
dockerfile: docker/agent/Dockerfile
context: .
platforms: linux/amd64,linux/arm64
+ - image: xingrin-postgres
+ dockerfile: docker/postgres/Dockerfile
+ context: docker/postgres
+ platforms: linux/amd64,linux/arm64
steps:
- name: Checkout
@@ -106,8 +110,8 @@ jobs:
${{ steps.version.outputs.IS_RELEASE == 'true' && format('{0}/{1}:latest', env.IMAGE_PREFIX, matrix.image) || '' }}
build-args: |
IMAGE_TAG=${{ steps.version.outputs.VERSION }}
- cache-from: type=gha,scope=${{ matrix.image }}
- cache-to: type=gha,mode=max,scope=${{ matrix.image }}
+ cache-from: type=registry,ref=${{ env.IMAGE_PREFIX }}/${{ matrix.image }}:cache
+ cache-to: type=registry,ref=${{ env.IMAGE_PREFIX }}/${{ matrix.image }}:cache,mode=max
provenance: false
sbom: false

View File

@@ -13,14 +13,14 @@
<p align="center">
<a href="#-功能特性">功能特性</a> •
<a href="#-全局资产搜索">资产搜索</a> •
<a href="#-快速开始">快速开始</a> •
<a href="#-文档">文档</a> •
<a href="#-技术栈">技术栈</a> •
<a href="#-反馈与贡献">反馈与贡献</a>
</p>
<p align="center">
- <sub>🔍 关键词: ASM | 攻击面管理 | 漏洞扫描 | 资产发现 | Bug Bounty | 渗透测试 | Nuclei | 子域名枚举 | EASM</sub>
+ <sub>🔍 关键词: ASM | 攻击面管理 | 漏洞扫描 | 资产发现 | 资产搜索 | Bug Bounty | 渗透测试 | Nuclei | 子域名枚举 | EASM</sub>
</p>
---
@@ -62,9 +62,14 @@
- **自定义流程** - YAML 配置扫描流程,灵活编排
- **定时扫描** - Cron 表达式配置,自动化周期扫描
+ ### 🔖 指纹识别
+ - **多源指纹库** - 内置 EHole、Goby、Wappalyzer、Fingers、FingerPrintHub、ARL 等 2.7W+ 指纹规则
+ - **自动识别** - 扫描流程自动执行,识别 Web 应用技术栈
+ - **指纹管理** - 支持查询、导入、导出指纹规则
#### 扫描流程架构
- 完整的扫描流程包括子域名发现、端口扫描、站点发现、URL 收集、目录扫描、漏洞扫描等阶段
+ 完整的扫描流程包括:子域名发现、端口扫描、站点发现、指纹识别、URL 收集、目录扫描、漏洞扫描等阶段
```mermaid
flowchart LR
@@ -75,7 +80,8 @@ flowchart LR
SUB["子域名发现<br/>subfinder, amass, puredns"]
PORT["端口扫描<br/>naabu"]
SITE["站点识别<br/>httpx"]
- SUB --> PORT --> SITE
+ FINGER["指纹识别<br/>xingfinger"]
+ SUB --> PORT --> SITE --> FINGER
end
subgraph STAGE2["阶段 2: 深度分析"]
@@ -91,7 +97,7 @@ flowchart LR
FINISH["扫描完成"]
START --> STAGE1
- SITE --> STAGE2
+ FINGER --> STAGE2
STAGE2 --> STAGE3
STAGE3 --> FINISH
@@ -103,6 +109,7 @@ flowchart LR
style SUB fill:#5dade2,stroke:#3498db,stroke-width:1px,color:#fff
style PORT fill:#5dade2,stroke:#3498db,stroke-width:1px,color:#fff
style SITE fill:#5dade2,stroke:#3498db,stroke-width:1px,color:#fff
+ style FINGER fill:#5dade2,stroke:#3498db,stroke-width:1px,color:#fff
style URL fill:#bb8fce,stroke:#9b59b6,stroke-width:1px,color:#fff
style DIR fill:#bb8fce,stroke:#9b59b6,stroke-width:1px,color:#fff
style VULN fill:#f0b27a,stroke:#e67e22,stroke-width:1px,color:#fff
@@ -155,9 +162,34 @@ flowchart TB
W3 -.心跳上报.-> REDIS
```
### 🔎 全局资产搜索
- **多类型搜索** - 支持 Website 和 Endpoint 两种资产类型
- **表达式语法** - 支持 `=`(模糊)、`==`(精确)、`!=`(不等于)操作符
- **逻辑组合** - 支持 `&&` (AND) 和 `||` (OR) 逻辑组合
- **多字段查询** - 支持 host、url、title、tech、status、body、header 字段
- **CSV 导出** - 流式导出全部搜索结果,无数量限制
#### 搜索语法示例
```bash
# 基础搜索
host="api" # host 包含 "api"
status=="200" # 状态码精确等于 200
tech="nginx" # 技术栈包含 nginx
# 组合搜索
host="api" && status=="200" # host 包含 api 且状态码为 200
tech="vue" || tech="react" # 技术栈包含 vue 或 react
# 复杂查询
host="admin" && tech="php" && status=="200"
url="/api/v1" && status!="404"
```
### 📊 可视化界面
- **数据统计** - 资产/漏洞统计仪表盘
- **实时通知** - WebSocket 消息推送
+ - **通知推送** - 实时企业微信、Telegram、Discord 消息推送服务
---
@@ -165,7 +197,7 @@ flowchart TB
### 环境要求
- - **操作系统**: Ubuntu 20.04+ / Debian 11+ (推荐)
+ - **操作系统**: Ubuntu 20.04+ / Debian 11+
- **硬件**: 2核 4G 内存起步,20GB+ 磁盘空间
### 一键安装
@@ -190,6 +222,7 @@ sudo ./install.sh --mirror
### 访问服务
- **Web 界面**: `https://ip:8083`
+ - **默认账号**: admin / admin(首次登录后请修改密码)
### 常用命令
@@ -216,7 +249,7 @@ sudo ./uninstall.sh
- 目前版本就我个人使用,可能会有很多边界问题
- 如有问题、建议或其他,优先提交[Issue](https://github.com/yyhuni/xingrin/issues),也可以直接给我的公众号发消息,我都会回复的
- - 微信公众号: **洋洋的小黑屋**
+ - 微信公众号: **塔罗安全学苑**
<img src="docs/wechat-qrcode.png" alt="微信公众号" width="200">

View File

@@ -1 +1 @@
- v1.2.2-dev
+ v1.3.5-dev

View File

@@ -1,5 +1,10 @@
import logging
import sys
from django.apps import AppConfig
logger = logging.getLogger(__name__)
class AssetConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
@@ -8,3 +13,94 @@ class AssetConfig(AppConfig):
def ready(self):
# 导入所有模型以确保Django发现并注册
from . import models
# 启用 pg_trgm 扩展(用于文本模糊搜索索引)
# 用于已有数据库升级场景
self._ensure_pg_trgm_extension()
# 验证 pg_ivm 扩展是否可用(用于 IMMV 增量维护)
self._verify_pg_ivm_extension()
def _ensure_pg_trgm_extension(self):
"""
确保 pg_trgm 扩展已启用。
该扩展用于 response_body 和 response_headers 字段的 GIN 索引,
支持高效的文本模糊搜索。
"""
from django.db import connection
# 检查是否为 PostgreSQL 数据库
if connection.vendor != 'postgresql':
logger.debug("跳过 pg_trgm 扩展:当前数据库不是 PostgreSQL")
return
try:
with connection.cursor() as cursor:
cursor.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
logger.debug("pg_trgm 扩展已启用")
except Exception as e:
# 记录错误但不阻止应用启动
# 常见原因:权限不足(需要超级用户权限)
logger.warning(
"无法创建 pg_trgm 扩展: %s"
"这可能导致 response_body 和 response_headers 字段的 GIN 索引无法正常工作。"
"请手动执行: CREATE EXTENSION IF NOT EXISTS pg_trgm;",
str(e)
)
def _verify_pg_ivm_extension(self):
"""
验证 pg_ivm 扩展是否可用。
pg_ivm 用于 IMMV(增量维护物化视图),是系统必需的扩展。
如果不可用,将记录错误并退出。
"""
from django.db import connection
# 检查是否为 PostgreSQL 数据库
if connection.vendor != 'postgresql':
logger.debug("跳过 pg_ivm 验证:当前数据库不是 PostgreSQL")
return
# 跳过某些管理命令(如 migrate、makemigrations)
import sys
if len(sys.argv) > 1 and sys.argv[1] in ('migrate', 'makemigrations', 'collectstatic', 'check'):
logger.debug("跳过 pg_ivm 验证:当前为管理命令")
return
try:
with connection.cursor() as cursor:
# 检查 pg_ivm 扩展是否已安装
cursor.execute("""
SELECT COUNT(*) FROM pg_extension WHERE extname = 'pg_ivm'
""")
count = cursor.fetchone()[0]
if count > 0:
logger.info("✓ pg_ivm 扩展已启用")
else:
# 尝试创建扩展
try:
cursor.execute("CREATE EXTENSION IF NOT EXISTS pg_ivm;")
logger.info("✓ pg_ivm 扩展已创建并启用")
except Exception as create_error:
logger.error(
"=" * 60 + "\n"
"错误: pg_ivm 扩展未安装\n"
"=" * 60 + "\n"
"pg_ivm 是系统必需的扩展,用于增量维护物化视图。\n\n"
"请在 PostgreSQL 服务器上安装 pg_ivm\n"
" curl -sSL https://raw.githubusercontent.com/yyhuni/xingrin/main/docker/scripts/install-pg-ivm.sh | sudo bash\n\n"
"或手动安装:\n"
" 1. apt install build-essential postgresql-server-dev-15 git\n"
" 2. git clone https://github.com/sraoss/pg_ivm.git && cd pg_ivm && make && make install\n"
" 3. 在 postgresql.conf 中添加: shared_preload_libraries = 'pg_ivm'\n"
" 4. 重启 PostgreSQL\n"
"=" * 60
)
# 在生产环境中退出,开发环境中仅警告
from django.conf import settings
if not settings.DEBUG:
sys.exit(1)
except Exception as e:
logger.error(f"pg_ivm 扩展验证失败: {e}")

View File

@@ -14,12 +14,13 @@ class EndpointDTO:
status_code: Optional[int] = None
content_length: Optional[int] = None
webserver: Optional[str] = None
- body_preview: Optional[str] = None
+ response_body: Optional[str] = None
content_type: Optional[str] = None
tech: Optional[List[str]] = None
vhost: Optional[bool] = None
location: Optional[str] = None
matched_gf_patterns: Optional[List[str]] = None
+ response_headers: Optional[str] = None
def __post_init__(self):
if self.tech is None:

View File

@@ -17,9 +17,10 @@ class WebSiteDTO:
webserver: str = ''
content_type: str = ''
tech: List[str] = None
- body_preview: str = ''
+ response_body: str = ''
vhost: Optional[bool] = None
created_at: str = None
+ response_headers: str = ''
def __post_init__(self):
if self.tech is None:

View File

@@ -13,6 +13,7 @@ class EndpointSnapshotDTO:
快照只属于 scan。
"""
scan_id: int
+ target_id: int # 必填,用于同步到资产表
url: str
host: str = '' # 主机名(域名或IP地址)
title: str = ''
@@ -22,10 +23,10 @@ class EndpointSnapshotDTO:
webserver: str = ''
content_type: str = ''
tech: List[str] = None
- body_preview: str = ''
+ response_body: str = ''
vhost: Optional[bool] = None
matched_gf_patterns: List[str] = None
- target_id: Optional[int] = None # 冗余字段,用于同步到资产表
+ response_headers: str = ''
def __post_init__(self):
if self.tech is None:
@@ -42,9 +43,6 @@ class EndpointSnapshotDTO:
"""
from apps.asset.dtos.asset import EndpointDTO
- if self.target_id is None:
- raise ValueError("target_id 不能为 None,无法同步到资产表")
return EndpointDTO(
target_id=self.target_id,
url=self.url,
@@ -53,10 +51,11 @@ class EndpointSnapshotDTO:
status_code=self.status_code,
content_length=self.content_length,
webserver=self.webserver,
- body_preview=self.body_preview,
+ response_body=self.response_body,
content_type=self.content_type,
tech=self.tech if self.tech else [],
vhost=self.vhost,
location=self.location,
- matched_gf_patterns=self.matched_gf_patterns if self.matched_gf_patterns else []
+ matched_gf_patterns=self.matched_gf_patterns if self.matched_gf_patterns else [],
+ response_headers=self.response_headers,
)

View File

@@ -13,18 +13,19 @@ class WebsiteSnapshotDTO:
快照只属于 scan,target 信息通过 scan.target 获取。
"""
scan_id: int
- target_id: int # 仅用于传递数据,不保存到数据库
+ target_id: int # 必填,用于同步到资产表
url: str
host: str
title: str = ''
- status: Optional[int] = None
+ status_code: Optional[int] = None # 统一命名:status -> status_code
content_length: Optional[int] = None
location: str = ''
- web_server: str = ''
+ webserver: str = '' # 统一命名:web_server -> webserver
content_type: str = ''
tech: List[str] = None
- body_preview: str = ''
+ response_body: str = ''
vhost: Optional[bool] = None
+ response_headers: str = ''
def __post_init__(self):
if self.tech is None:
@@ -44,12 +45,13 @@ class WebsiteSnapshotDTO:
url=self.url,
host=self.host,
title=self.title,
- status_code=self.status,
+ status_code=self.status_code,
content_length=self.content_length,
location=self.location,
- webserver=self.web_server,
+ webserver=self.webserver,
content_type=self.content_type,
tech=self.tech if self.tech else [],
- body_preview=self.body_preview,
- vhost=self.vhost
+ response_body=self.response_body,
+ vhost=self.vhost,
+ response_headers=self.response_headers,
)

View File

@@ -0,0 +1,345 @@
# Generated by Django 5.2.7 on 2026-01-02 04:45
import django.contrib.postgres.fields
import django.contrib.postgres.indexes
import django.core.validators
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
('scan', '0001_initial'),
('targets', '0001_initial'),
]
operations = [
migrations.CreateModel(
name='AssetStatistics',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('total_targets', models.IntegerField(default=0, help_text='目标总数')),
('total_subdomains', models.IntegerField(default=0, help_text='子域名总数')),
('total_ips', models.IntegerField(default=0, help_text='IP地址总数')),
('total_endpoints', models.IntegerField(default=0, help_text='端点总数')),
('total_websites', models.IntegerField(default=0, help_text='网站总数')),
('total_vulns', models.IntegerField(default=0, help_text='漏洞总数')),
('total_assets', models.IntegerField(default=0, help_text='总资产数(子域名+IP+端点+网站)')),
('prev_targets', models.IntegerField(default=0, help_text='上次目标总数')),
('prev_subdomains', models.IntegerField(default=0, help_text='上次子域名总数')),
('prev_ips', models.IntegerField(default=0, help_text='上次IP地址总数')),
('prev_endpoints', models.IntegerField(default=0, help_text='上次端点总数')),
('prev_websites', models.IntegerField(default=0, help_text='上次网站总数')),
('prev_vulns', models.IntegerField(default=0, help_text='上次漏洞总数')),
('prev_assets', models.IntegerField(default=0, help_text='上次总资产数')),
('updated_at', models.DateTimeField(auto_now=True, help_text='最后更新时间')),
],
options={
'verbose_name': '资产统计',
'verbose_name_plural': '资产统计',
'db_table': 'asset_statistics',
},
),
migrations.CreateModel(
name='StatisticsHistory',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('date', models.DateField(help_text='统计日期', unique=True)),
('total_targets', models.IntegerField(default=0, help_text='目标总数')),
('total_subdomains', models.IntegerField(default=0, help_text='子域名总数')),
('total_ips', models.IntegerField(default=0, help_text='IP地址总数')),
('total_endpoints', models.IntegerField(default=0, help_text='端点总数')),
('total_websites', models.IntegerField(default=0, help_text='网站总数')),
('total_vulns', models.IntegerField(default=0, help_text='漏洞总数')),
('total_assets', models.IntegerField(default=0, help_text='总资产数')),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('updated_at', models.DateTimeField(auto_now=True, help_text='更新时间')),
],
options={
'verbose_name': '统计历史',
'verbose_name_plural': '统计历史',
'db_table': 'statistics_history',
'ordering': ['-date'],
'indexes': [models.Index(fields=['date'], name='statistics__date_1d29cd_idx')],
},
),
migrations.CreateModel(
name='Directory',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('url', models.CharField(help_text='完整请求 URL', max_length=2000)),
('status', models.IntegerField(blank=True, help_text='HTTP 响应状态码', null=True)),
('content_length', models.BigIntegerField(blank=True, help_text='响应体字节大小(Content-Length 或实际长度)', null=True)),
('words', models.IntegerField(blank=True, help_text='响应体中单词数量(按空格分割)', null=True)),
('lines', models.IntegerField(blank=True, help_text='响应体行数(按换行符分割)', null=True)),
('content_type', models.CharField(blank=True, default='', help_text='响应头 Content-Type 值', max_length=200)),
('duration', models.BigIntegerField(blank=True, help_text='请求耗时(单位:纳秒)', null=True)),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('target', models.ForeignKey(help_text='所属的扫描目标', on_delete=django.db.models.deletion.CASCADE, related_name='directories', to='targets.target')),
],
options={
'verbose_name': '目录',
'verbose_name_plural': '目录',
'db_table': 'directory',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['-created_at'], name='directory_created_2cef03_idx'), models.Index(fields=['target'], name='directory_target__e310c8_idx'), models.Index(fields=['url'], name='directory_url_ba40cd_idx'), models.Index(fields=['status'], name='directory_status_40bbe6_idx'), django.contrib.postgres.indexes.GinIndex(fields=['url'], name='directory_url_trgm_idx', opclasses=['gin_trgm_ops'])],
'constraints': [models.UniqueConstraint(fields=('target', 'url'), name='unique_directory_url_target')],
},
),
migrations.CreateModel(
name='DirectorySnapshot',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('url', models.CharField(help_text='目录URL', max_length=2000)),
('status', models.IntegerField(blank=True, help_text='HTTP状态码', null=True)),
('content_length', models.BigIntegerField(blank=True, help_text='内容长度', null=True)),
('words', models.IntegerField(blank=True, help_text='响应体中单词数量(按空格分割)', null=True)),
('lines', models.IntegerField(blank=True, help_text='响应体行数(按换行符分割)', null=True)),
('content_type', models.CharField(blank=True, default='', help_text='响应头 Content-Type 值', max_length=200)),
('duration', models.BigIntegerField(blank=True, help_text='请求耗时(单位:纳秒)', null=True)),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('scan', models.ForeignKey(help_text='所属的扫描任务', on_delete=django.db.models.deletion.CASCADE, related_name='directory_snapshots', to='scan.scan')),
],
options={
'verbose_name': '目录快照',
'verbose_name_plural': '目录快照',
'db_table': 'directory_snapshot',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['scan'], name='directory_s_scan_id_c45900_idx'), models.Index(fields=['url'], name='directory_s_url_b4b72b_idx'), models.Index(fields=['status'], name='directory_s_status_e9f57e_idx'), models.Index(fields=['content_type'], name='directory_s_content_45e864_idx'), models.Index(fields=['-created_at'], name='directory_s_created_eb9d27_idx'), django.contrib.postgres.indexes.GinIndex(fields=['url'], name='dir_snap_url_trgm', opclasses=['gin_trgm_ops'])],
'constraints': [models.UniqueConstraint(fields=('scan', 'url'), name='unique_directory_per_scan_snapshot')],
},
),
migrations.CreateModel(
name='Endpoint',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('url', models.TextField(help_text='最终访问的完整URL')),
('host', models.CharField(blank=True, default='', help_text='主机名(域名或IP地址)', max_length=253)),
('location', models.TextField(blank=True, default='', help_text='重定向地址(HTTP 3xx 响应头 Location)')),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('title', models.TextField(blank=True, default='', help_text='网页标题(HTML <title> 标签内容)')),
('webserver', models.TextField(blank=True, default='', help_text='服务器类型(HTTP 响应头 Server 值)')),
('response_body', models.TextField(blank=True, default='', help_text='HTTP响应体')),
('content_type', models.TextField(blank=True, default='', help_text='响应类型(HTTP Content-Type 响应头)')),
('tech', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='技术栈(服务器/框架/语言等)', size=None)),
('status_code', models.IntegerField(blank=True, help_text='HTTP状态码', null=True)),
('content_length', models.IntegerField(blank=True, help_text='响应体大小(单位字节)', null=True)),
('vhost', models.BooleanField(blank=True, help_text='是否支持虚拟主机', null=True)),
('matched_gf_patterns', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='匹配的GF模式列表(用于识别敏感端点,如api, debug, config等)', size=None)),
('response_headers', models.TextField(blank=True, default='', help_text='原始HTTP响应头')),
('target', models.ForeignKey(help_text='所属的扫描目标(主关联字段,表示所属关系,不能为空)', on_delete=django.db.models.deletion.CASCADE, related_name='endpoints', to='targets.target')),
],
options={
'verbose_name': '端点',
'verbose_name_plural': '端点',
'db_table': 'endpoint',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['-created_at'], name='endpoint_created_44fe9c_idx'), models.Index(fields=['target'], name='endpoint_target__7f9065_idx'), models.Index(fields=['url'], name='endpoint_url_30f66e_idx'), models.Index(fields=['host'], name='endpoint_host_5b4cc8_idx'), models.Index(fields=['status_code'], name='endpoint_status__5d4fdd_idx'), models.Index(fields=['title'], name='endpoint_title_29e26c_idx'), django.contrib.postgres.indexes.GinIndex(fields=['tech'], name='endpoint_tech_2bfa7c_gin'), django.contrib.postgres.indexes.GinIndex(fields=['response_headers'], name='endpoint_resp_headers_trgm_idx', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['url'], name='endpoint_url_trgm_idx', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['title'], name='endpoint_title_trgm_idx', opclasses=['gin_trgm_ops'])],
'constraints': [models.UniqueConstraint(fields=('url', 'target'), name='unique_endpoint_url_target')],
},
),
migrations.CreateModel(
name='EndpointSnapshot',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('url', models.TextField(help_text='端点URL')),
('host', models.CharField(blank=True, default='', help_text='主机名(域名或IP地址)', max_length=253)),
('title', models.TextField(blank=True, default='', help_text='页面标题')),
('status_code', models.IntegerField(blank=True, help_text='HTTP状态码', null=True)),
('content_length', models.IntegerField(blank=True, help_text='内容长度', null=True)),
('location', models.TextField(blank=True, default='', help_text='重定向位置')),
('webserver', models.TextField(blank=True, default='', help_text='Web服务器')),
('content_type', models.TextField(blank=True, default='', help_text='内容类型')),
('tech', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='技术栈', size=None)),
('response_body', models.TextField(blank=True, default='', help_text='HTTP响应体')),
('vhost', models.BooleanField(blank=True, help_text='虚拟主机标志', null=True)),
('matched_gf_patterns', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='匹配的GF模式列表', size=None)),
('response_headers', models.TextField(blank=True, default='', help_text='原始HTTP响应头')),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('scan', models.ForeignKey(help_text='所属的扫描任务', on_delete=django.db.models.deletion.CASCADE, related_name='endpoint_snapshots', to='scan.scan')),
],
options={
'verbose_name': '端点快照',
'verbose_name_plural': '端点快照',
'db_table': 'endpoint_snapshot',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['scan'], name='endpoint_sn_scan_id_6ac9a7_idx'), models.Index(fields=['url'], name='endpoint_sn_url_205160_idx'), models.Index(fields=['host'], name='endpoint_sn_host_577bfd_idx'), models.Index(fields=['title'], name='endpoint_sn_title_516a05_idx'), models.Index(fields=['status_code'], name='endpoint_sn_status__83efb0_idx'), models.Index(fields=['webserver'], name='endpoint_sn_webserv_66be83_idx'), models.Index(fields=['-created_at'], name='endpoint_sn_created_21fb5b_idx'), django.contrib.postgres.indexes.GinIndex(fields=['tech'], name='endpoint_sn_tech_0d0752_gin'), django.contrib.postgres.indexes.GinIndex(fields=['response_headers'], name='ep_snap_resp_hdr_trgm', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['url'], name='ep_snap_url_trgm', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['title'], name='ep_snap_title_trgm', opclasses=['gin_trgm_ops'])],
'constraints': [models.UniqueConstraint(fields=('scan', 'url'), name='unique_endpoint_per_scan_snapshot')],
},
),
migrations.CreateModel(
name='HostPortMapping',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('host', models.CharField(help_text='主机名(域名或IP)', max_length=1000)),
('ip', models.GenericIPAddressField(help_text='IP地址')),
('port', models.IntegerField(help_text='端口号(1-65535)', validators=[django.core.validators.MinValueValidator(1, message='端口号必须大于等于1'), django.core.validators.MaxValueValidator(65535, message='端口号必须小于等于65535')])),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('target', models.ForeignKey(help_text='所属的扫描目标', on_delete=django.db.models.deletion.CASCADE, related_name='host_port_mappings', to='targets.target')),
],
options={
'verbose_name': '主机端口映射',
'verbose_name_plural': '主机端口映射',
'db_table': 'host_port_mapping',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['target'], name='host_port_m_target__943e9b_idx'), models.Index(fields=['host'], name='host_port_m_host_f78363_idx'), models.Index(fields=['ip'], name='host_port_m_ip_2e6f02_idx'), models.Index(fields=['port'], name='host_port_m_port_9fb9ff_idx'), models.Index(fields=['host', 'ip'], name='host_port_m_host_3ce245_idx'), models.Index(fields=['-created_at'], name='host_port_m_created_11cd22_idx')],
'constraints': [models.UniqueConstraint(fields=('target', 'host', 'ip', 'port'), name='unique_target_host_ip_port')],
},
),
migrations.CreateModel(
name='HostPortMappingSnapshot',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('host', models.CharField(help_text='主机名(域名或IP)', max_length=1000)),
('ip', models.GenericIPAddressField(help_text='IP地址')),
('port', models.IntegerField(help_text='端口号(1-65535)', validators=[django.core.validators.MinValueValidator(1, message='端口号必须大于等于1'), django.core.validators.MaxValueValidator(65535, message='端口号必须小于等于65535')])),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('scan', models.ForeignKey(help_text='所属的扫描任务(主关联)', on_delete=django.db.models.deletion.CASCADE, related_name='host_port_mapping_snapshots', to='scan.scan')),
],
options={
'verbose_name': '主机端口映射快照',
'verbose_name_plural': '主机端口映射快照',
'db_table': 'host_port_mapping_snapshot',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['scan'], name='host_port_m_scan_id_50ba0b_idx'), models.Index(fields=['host'], name='host_port_m_host_e99054_idx'), models.Index(fields=['ip'], name='host_port_m_ip_54818c_idx'), models.Index(fields=['port'], name='host_port_m_port_ed7b48_idx'), models.Index(fields=['host', 'ip'], name='host_port_m_host_8a463a_idx'), models.Index(fields=['scan', 'host'], name='host_port_m_scan_id_426fdb_idx'), models.Index(fields=['-created_at'], name='host_port_m_created_fb28b8_idx')],
'constraints': [models.UniqueConstraint(fields=('scan', 'host', 'ip', 'port'), name='unique_scan_host_ip_port_snapshot')],
},
),
migrations.CreateModel(
name='Subdomain',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('name', models.CharField(help_text='子域名名称', max_length=1000)),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('target', models.ForeignKey(help_text='所属的扫描目标(主关联字段,表示所属关系,不能为空)', on_delete=django.db.models.deletion.CASCADE, related_name='subdomains', to='targets.target')),
],
options={
'verbose_name': '子域名',
'verbose_name_plural': '子域名',
'db_table': 'subdomain',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['-created_at'], name='subdomain_created_e187a8_idx'), models.Index(fields=['name', 'target'], name='subdomain_name_60e1d0_idx'), models.Index(fields=['target'], name='subdomain_target__e409f0_idx'), models.Index(fields=['name'], name='subdomain_name_d40ba7_idx'), django.contrib.postgres.indexes.GinIndex(fields=['name'], name='subdomain_name_trgm_idx', opclasses=['gin_trgm_ops'])],
'constraints': [models.UniqueConstraint(fields=('name', 'target'), name='unique_subdomain_name_target')],
},
),
migrations.CreateModel(
name='SubdomainSnapshot',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('name', models.CharField(help_text='子域名名称', max_length=1000)),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('scan', models.ForeignKey(help_text='所属的扫描任务', on_delete=django.db.models.deletion.CASCADE, related_name='subdomain_snapshots', to='scan.scan')),
],
options={
'verbose_name': '子域名快照',
'verbose_name_plural': '子域名快照',
'db_table': 'subdomain_snapshot',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['scan'], name='subdomain_s_scan_id_68c253_idx'), models.Index(fields=['name'], name='subdomain_s_name_2da42b_idx'), models.Index(fields=['-created_at'], name='subdomain_s_created_d2b48e_idx'), django.contrib.postgres.indexes.GinIndex(fields=['name'], name='subdomain_snap_name_trgm', opclasses=['gin_trgm_ops'])],
'constraints': [models.UniqueConstraint(fields=('scan', 'name'), name='unique_subdomain_per_scan_snapshot')],
},
),
migrations.CreateModel(
name='Vulnerability',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('url', models.CharField(help_text='漏洞所在的URL', max_length=2000)),
('vuln_type', models.CharField(help_text='漏洞类型(如 xss, sqli)', max_length=100)),
('severity', models.CharField(choices=[('unknown', '未知'), ('info', '信息'), ('low', '低'), ('medium', '中'), ('high', '高'), ('critical', '危急')], default='unknown', help_text='严重性(未知/信息/低/中/高/危急)', max_length=20)),
('source', models.CharField(blank=True, default='', help_text='来源工具(如 dalfox, nuclei, crlfuzz)', max_length=50)),
('cvss_score', models.DecimalField(blank=True, decimal_places=1, help_text='CVSS 评分(0.0-10.0)', max_digits=3, null=True)),
('description', models.TextField(blank=True, default='', help_text='漏洞描述')),
('raw_output', models.JSONField(blank=True, default=dict, help_text='工具原始输出')),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('target', models.ForeignKey(help_text='所属的扫描目标', on_delete=django.db.models.deletion.CASCADE, related_name='vulnerabilities', to='targets.target')),
],
options={
'verbose_name': '漏洞',
'verbose_name_plural': '漏洞',
'db_table': 'vulnerability',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['target'], name='vulnerabili_target__755a02_idx'), models.Index(fields=['vuln_type'], name='vulnerabili_vuln_ty_3010cd_idx'), models.Index(fields=['severity'], name='vulnerabili_severit_1a798b_idx'), models.Index(fields=['source'], name='vulnerabili_source_7c7552_idx'), models.Index(fields=['url'], name='vulnerabili_url_4dcc4d_idx'), models.Index(fields=['-created_at'], name='vulnerabili_created_e25ff7_idx')],
},
),
migrations.CreateModel(
name='VulnerabilitySnapshot',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('url', models.CharField(help_text='漏洞所在的URL', max_length=2000)),
('vuln_type', models.CharField(help_text='漏洞类型(如 xss, sqli)', max_length=100)),
('severity', models.CharField(choices=[('unknown', '未知'), ('info', '信息'), ('low', '低'), ('medium', '中'), ('high', '高'), ('critical', '危急')], default='unknown', help_text='严重性(未知/信息/低/中/高/危急)', max_length=20)),
('source', models.CharField(blank=True, default='', help_text='来源工具(如 dalfox, nuclei, crlfuzz)', max_length=50)),
('cvss_score', models.DecimalField(blank=True, decimal_places=1, help_text='CVSS 评分(0.0-10.0)', max_digits=3, null=True)),
('description', models.TextField(blank=True, default='', help_text='漏洞描述')),
('raw_output', models.JSONField(blank=True, default=dict, help_text='工具原始输出')),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('scan', models.ForeignKey(help_text='所属的扫描任务', on_delete=django.db.models.deletion.CASCADE, related_name='vulnerability_snapshots', to='scan.scan')),
],
options={
'verbose_name': '漏洞快照',
'verbose_name_plural': '漏洞快照',
'db_table': 'vulnerability_snapshot',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['scan'], name='vulnerabili_scan_id_7b81c9_idx'), models.Index(fields=['url'], name='vulnerabili_url_11a707_idx'), models.Index(fields=['vuln_type'], name='vulnerabili_vuln_ty_6b90ee_idx'), models.Index(fields=['severity'], name='vulnerabili_severit_4eae0d_idx'), models.Index(fields=['source'], name='vulnerabili_source_968b1f_idx'), models.Index(fields=['-created_at'], name='vulnerabili_created_53a12e_idx')],
},
),
migrations.CreateModel(
name='WebSite',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('url', models.TextField(help_text='最终访问的完整URL')),
('host', models.CharField(blank=True, default='', help_text='主机名(域名或IP地址)', max_length=253)),
('location', models.TextField(blank=True, default='', help_text='重定向地址(HTTP 3xx 响应头 Location)')),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('title', models.TextField(blank=True, default='', help_text='网页标题(HTML <title> 标签内容)')),
('webserver', models.TextField(blank=True, default='', help_text='服务器类型(HTTP 响应头 Server 值)')),
('response_body', models.TextField(blank=True, default='', help_text='HTTP响应体')),
('content_type', models.TextField(blank=True, default='', help_text='响应类型(HTTP Content-Type 响应头)')),
('tech', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='技术栈(服务器/框架/语言等)', size=None)),
('status_code', models.IntegerField(blank=True, help_text='HTTP状态码', null=True)),
('content_length', models.IntegerField(blank=True, help_text='响应体大小(单位字节)', null=True)),
('vhost', models.BooleanField(blank=True, help_text='是否支持虚拟主机', null=True)),
('response_headers', models.TextField(blank=True, default='', help_text='原始HTTP响应头')),
('target', models.ForeignKey(help_text='所属的扫描目标(主关联字段,表示所属关系,不能为空)', on_delete=django.db.models.deletion.CASCADE, related_name='websites', to='targets.target')),
],
options={
'verbose_name': '站点',
'verbose_name_plural': '站点',
'db_table': 'website',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['-created_at'], name='website_created_c9cfd2_idx'), models.Index(fields=['url'], name='website_url_b18883_idx'), models.Index(fields=['host'], name='website_host_996b50_idx'), models.Index(fields=['target'], name='website_target__2a353b_idx'), models.Index(fields=['title'], name='website_title_c2775b_idx'), models.Index(fields=['status_code'], name='website_status__51663d_idx'), django.contrib.postgres.indexes.GinIndex(fields=['tech'], name='website_tech_e3f0cb_gin'), django.contrib.postgres.indexes.GinIndex(fields=['response_headers'], name='website_resp_headers_trgm_idx', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['url'], name='website_url_trgm_idx', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['title'], name='website_title_trgm_idx', opclasses=['gin_trgm_ops'])],
'constraints': [models.UniqueConstraint(fields=('url', 'target'), name='unique_website_url_target')],
},
),
migrations.CreateModel(
name='WebsiteSnapshot',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('url', models.TextField(help_text='站点URL')),
('host', models.CharField(blank=True, default='', help_text='主机名(域名或IP地址)', max_length=253)),
('title', models.TextField(blank=True, default='', help_text='页面标题')),
('status_code', models.IntegerField(blank=True, help_text='HTTP状态码', null=True)),
('content_length', models.BigIntegerField(blank=True, help_text='内容长度', null=True)),
('location', models.TextField(blank=True, default='', help_text='重定向位置')),
('webserver', models.TextField(blank=True, default='', help_text='Web服务器')),
('content_type', models.TextField(blank=True, default='', help_text='内容类型')),
('tech', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='技术栈', size=None)),
('response_body', models.TextField(blank=True, default='', help_text='HTTP响应体')),
('vhost', models.BooleanField(blank=True, help_text='虚拟主机标志', null=True)),
('response_headers', models.TextField(blank=True, default='', help_text='原始HTTP响应头')),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('scan', models.ForeignKey(help_text='所属的扫描任务', on_delete=django.db.models.deletion.CASCADE, related_name='website_snapshots', to='scan.scan')),
],
options={
'verbose_name': '网站快照',
'verbose_name_plural': '网站快照',
'db_table': 'website_snapshot',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['scan'], name='website_sna_scan_id_26b6dc_idx'), models.Index(fields=['url'], name='website_sna_url_801a70_idx'), models.Index(fields=['host'], name='website_sna_host_348fe1_idx'), models.Index(fields=['title'], name='website_sna_title_b1a5ee_idx'), models.Index(fields=['-created_at'], name='website_sna_created_2c149a_idx'), django.contrib.postgres.indexes.GinIndex(fields=['tech'], name='website_sna_tech_3d6d2f_gin'), django.contrib.postgres.indexes.GinIndex(fields=['response_headers'], name='ws_snap_resp_hdr_trgm', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['url'], name='ws_snap_url_trgm', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['title'], name='ws_snap_title_trgm', opclasses=['gin_trgm_ops'])],
'constraints': [models.UniqueConstraint(fields=('scan', 'url'), name='unique_website_per_scan_snapshot')],
},
),
]

View File

@@ -0,0 +1,187 @@
"""
创建资产搜索 IMMV(增量维护物化视图)
使用 pg_ivm 扩展创建 IMMV,数据变更时自动增量更新,无需手动刷新。
包含:
1. asset_search_view - Website 搜索视图
2. endpoint_search_view - Endpoint 搜索视图
"""
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('asset', '0001_initial'),
]
operations = [
# 1. 确保 pg_ivm 扩展已启用
migrations.RunSQL(
sql="CREATE EXTENSION IF NOT EXISTS pg_ivm;",
reverse_sql="-- pg_ivm extension kept for other uses"
),
# ==================== Website IMMV ====================
# 2. 创建 asset_search_view IMMV
migrations.RunSQL(
sql="""
SELECT pgivm.create_immv('asset_search_view', $$
SELECT
w.id,
w.url,
w.host,
w.title,
w.tech,
w.status_code,
w.response_headers,
w.response_body,
w.content_type,
w.content_length,
w.webserver,
w.location,
w.vhost,
w.created_at,
w.target_id
FROM website w
$$);
""",
reverse_sql="SELECT pgivm.drop_immv('asset_search_view');"
),
# 3. 创建 asset_search_view 索引
migrations.RunSQL(
sql="""
-- Unique index
CREATE UNIQUE INDEX IF NOT EXISTS asset_search_view_id_idx
ON asset_search_view (id);
-- Fuzzy-search index on host
CREATE INDEX IF NOT EXISTS asset_search_view_host_trgm_idx
ON asset_search_view USING gin (host gin_trgm_ops);
-- Fuzzy-search index on title
CREATE INDEX IF NOT EXISTS asset_search_view_title_trgm_idx
ON asset_search_view USING gin (title gin_trgm_ops);
-- Fuzzy-search index on url
CREATE INDEX IF NOT EXISTS asset_search_view_url_trgm_idx
ON asset_search_view USING gin (url gin_trgm_ops);
-- Fuzzy-search index on response_headers
CREATE INDEX IF NOT EXISTS asset_search_view_headers_trgm_idx
ON asset_search_view USING gin (response_headers gin_trgm_ops);
-- Fuzzy-search index on response_body
CREATE INDEX IF NOT EXISTS asset_search_view_body_trgm_idx
ON asset_search_view USING gin (response_body gin_trgm_ops);
-- GIN index on the tech array
CREATE INDEX IF NOT EXISTS asset_search_view_tech_idx
ON asset_search_view USING gin (tech);
-- status_code index
CREATE INDEX IF NOT EXISTS asset_search_view_status_idx
ON asset_search_view (status_code);
-- created_at ordering index
CREATE INDEX IF NOT EXISTS asset_search_view_created_idx
ON asset_search_view (created_at DESC);
""",
reverse_sql="""
DROP INDEX IF EXISTS asset_search_view_id_idx;
DROP INDEX IF EXISTS asset_search_view_host_trgm_idx;
DROP INDEX IF EXISTS asset_search_view_title_trgm_idx;
DROP INDEX IF EXISTS asset_search_view_url_trgm_idx;
DROP INDEX IF EXISTS asset_search_view_headers_trgm_idx;
DROP INDEX IF EXISTS asset_search_view_body_trgm_idx;
DROP INDEX IF EXISTS asset_search_view_tech_idx;
DROP INDEX IF EXISTS asset_search_view_status_idx;
DROP INDEX IF EXISTS asset_search_view_created_idx;
"""
),
# ==================== Endpoint IMMV ====================
# 4. Create the endpoint_search_view IMMV
migrations.RunSQL(
sql="""
SELECT pgivm.create_immv('endpoint_search_view', $$
SELECT
e.id,
e.url,
e.host,
e.title,
e.tech,
e.status_code,
e.response_headers,
e.response_body,
e.content_type,
e.content_length,
e.webserver,
e.location,
e.vhost,
e.matched_gf_patterns,
e.created_at,
e.target_id
FROM endpoint e
$$);
""",
reverse_sql="SELECT pgivm.drop_immv('endpoint_search_view');"
),
# 5. Create indexes on endpoint_search_view
migrations.RunSQL(
sql="""
-- Unique index
CREATE UNIQUE INDEX IF NOT EXISTS endpoint_search_view_id_idx
ON endpoint_search_view (id);
-- Fuzzy-search index on host
CREATE INDEX IF NOT EXISTS endpoint_search_view_host_trgm_idx
ON endpoint_search_view USING gin (host gin_trgm_ops);
-- Fuzzy-search index on title
CREATE INDEX IF NOT EXISTS endpoint_search_view_title_trgm_idx
ON endpoint_search_view USING gin (title gin_trgm_ops);
-- Fuzzy-search index on url
CREATE INDEX IF NOT EXISTS endpoint_search_view_url_trgm_idx
ON endpoint_search_view USING gin (url gin_trgm_ops);
-- Fuzzy-search index on response_headers
CREATE INDEX IF NOT EXISTS endpoint_search_view_headers_trgm_idx
ON endpoint_search_view USING gin (response_headers gin_trgm_ops);
-- Fuzzy-search index on response_body
CREATE INDEX IF NOT EXISTS endpoint_search_view_body_trgm_idx
ON endpoint_search_view USING gin (response_body gin_trgm_ops);
-- GIN index on the tech array
CREATE INDEX IF NOT EXISTS endpoint_search_view_tech_idx
ON endpoint_search_view USING gin (tech);
-- status_code index
CREATE INDEX IF NOT EXISTS endpoint_search_view_status_idx
ON endpoint_search_view (status_code);
-- created_at ordering index
CREATE INDEX IF NOT EXISTS endpoint_search_view_created_idx
ON endpoint_search_view (created_at DESC);
""",
reverse_sql="""
DROP INDEX IF EXISTS endpoint_search_view_id_idx;
DROP INDEX IF EXISTS endpoint_search_view_host_trgm_idx;
DROP INDEX IF EXISTS endpoint_search_view_title_trgm_idx;
DROP INDEX IF EXISTS endpoint_search_view_url_trgm_idx;
DROP INDEX IF EXISTS endpoint_search_view_headers_trgm_idx;
DROP INDEX IF EXISTS endpoint_search_view_body_trgm_idx;
DROP INDEX IF EXISTS endpoint_search_view_tech_idx;
DROP INDEX IF EXISTS endpoint_search_view_status_idx;
DROP INDEX IF EXISTS endpoint_search_view_created_idx;
"""
),
]
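Since pg_ivm maintains these views incrementally, writes to the base tables become visible in the views immediately, with no REFRESH MATERIALIZED VIEW step. A minimal sanity check from a Django shell, assuming this migration has been applied (table contents are whatever your instance holds):

from django.db import connection

# Row counts of the base table and the IMMV should always agree, because
# pg_ivm applies every INSERT/UPDATE/DELETE to the view as part of the same
# transaction -- no manual refresh is involved.
with connection.cursor() as cursor:
    cursor.execute("SELECT count(*) FROM website")
    base_count = cursor.fetchone()[0]
    cursor.execute("SELECT count(*) FROM asset_search_view")
    view_count = cursor.fetchone()[0]
assert base_count == view_count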


@@ -1,6 +1,7 @@
from django.db import models
from django.contrib.postgres.fields import ArrayField
from django.contrib.postgres.indexes import GinIndex
from django.core.validators import MinValueValidator, MaxValueValidator
@@ -34,6 +35,12 @@ class Subdomain(models.Model):
models.Index(fields=['name', 'target']), # Composite index; speeds up get_by_names_and_target_id batch queries
models.Index(fields=['target']), # Speeds up finding subdomains under a target_id
models.Index(fields=['name']), # Speeds up finding subdomains by name (search scenarios)
# pg_trgm GIN index; supports LIKE '%keyword%' fuzzy search
GinIndex(
name='subdomain_name_trgm_idx',
fields=['name'],
opclasses=['gin_trgm_ops']
),
]
constraints = [
# Plain unique constraint: the (name, target) combination must be unique
@@ -58,40 +65,35 @@ class Endpoint(models.Model):
help_text='所属的扫描目标(主关联字段,表示所属关系,不能为空)'
)
url = models.CharField(max_length=2000, help_text='最终访问的完整URL')
url = models.TextField(help_text='最终访问的完整URL')
host = models.CharField(
max_length=253,
blank=True,
default='',
help_text='主机名域名或IP地址'
)
location = models.CharField(
max_length=1000,
location = models.TextField(
blank=True,
default='',
help_text='重定向地址HTTP 3xx 响应头 Location'
)
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
title = models.CharField(
max_length=1000,
title = models.TextField(
blank=True,
default='',
help_text='网页标题HTML <title> 标签内容)'
)
webserver = models.CharField(
max_length=200,
webserver = models.TextField(
blank=True,
default='',
help_text='服务器类型HTTP 响应头 Server 值)'
)
body_preview = models.CharField(
max_length=1000,
response_body = models.TextField(
blank=True,
default='',
help_text='响应正文前N个字符默认100个字符'
help_text='HTTP响应体'
)
content_type = models.CharField(
max_length=200,
content_type = models.TextField(
blank=True,
default='',
help_text='响应类型HTTP Content-Type 响应头)'
@@ -123,6 +125,11 @@ class Endpoint(models.Model):
default=list,
help_text='匹配的GF模式列表用于识别敏感端点如api, debug, config等'
)
response_headers = models.TextField(
blank=True,
default='',
help_text='原始HTTP响应头'
)
class Meta:
db_table = 'endpoint'
@@ -131,11 +138,28 @@ class Endpoint(models.Model):
ordering = ['-created_at']
indexes = [
models.Index(fields=['-created_at']),
models.Index(fields=['target']), # 优化从target_id快速查找下面的端点主关联字段
models.Index(fields=['target']), # 优化从 target_id快速查找下面的端点主关联字段
models.Index(fields=['url']), # URL index; improves query performance
models.Index(fields=['host']), # host index; speeds up lookups by hostname
models.Index(fields=['status_code']), # Status-code index; speeds up filtering
models.Index(fields=['title']), # title index; speeds up smart-filter search
GinIndex(fields=['tech']), # GIN index; speeds up __contains queries on the tech array field
# pg_trgm GIN index; supports LIKE '%keyword%' fuzzy search
GinIndex(
name='endpoint_resp_headers_trgm_idx',
fields=['response_headers'],
opclasses=['gin_trgm_ops']
),
GinIndex(
name='endpoint_url_trgm_idx',
fields=['url'],
opclasses=['gin_trgm_ops']
),
GinIndex(
name='endpoint_title_trgm_idx',
fields=['title'],
opclasses=['gin_trgm_ops']
),
]
constraints = [
# Plain unique constraint: the (url, target) combination must be unique
@@ -160,40 +184,35 @@ class WebSite(models.Model):
help_text='所属的扫描目标(主关联字段,表示所属关系,不能为空)'
)
url = models.CharField(max_length=2000, help_text='最终访问的完整URL')
url = models.TextField(help_text='最终访问的完整URL')
host = models.CharField(
max_length=253,
blank=True,
default='',
help_text='主机名域名或IP地址'
)
location = models.CharField(
max_length=1000,
location = models.TextField(
blank=True,
default='',
help_text='重定向地址HTTP 3xx 响应头 Location'
)
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
title = models.CharField(
max_length=1000,
title = models.TextField(
blank=True,
default='',
help_text='网页标题HTML <title> 标签内容)'
)
webserver = models.CharField(
max_length=200,
webserver = models.TextField(
blank=True,
default='',
help_text='服务器类型HTTP 响应头 Server 值)'
)
body_preview = models.CharField(
max_length=1000,
response_body = models.TextField(
blank=True,
default='',
help_text='响应正文前N个字符默认100个字符'
help_text='HTTP响应体'
)
content_type = models.CharField(
max_length=200,
content_type = models.TextField(
blank=True,
default='',
help_text='响应类型HTTP Content-Type 响应头)'
@@ -219,6 +238,11 @@ class WebSite(models.Model):
blank=True,
help_text='是否支持虚拟主机'
)
response_headers = models.TextField(
blank=True,
default='',
help_text='原始HTTP响应头'
)
class Meta:
db_table = 'website'
@@ -229,9 +253,26 @@ class WebSite(models.Model):
models.Index(fields=['-created_at']),
models.Index(fields=['url']), # URL index; improves query performance
models.Index(fields=['host']), # host index; speeds up lookups by hostname
models.Index(fields=['target']), # 优化从target_id快速查找下面的站点
models.Index(fields=['target']), # 优化从 target_id快速查找下面的站点
models.Index(fields=['title']), # title index; speeds up smart-filter search
models.Index(fields=['status_code']), # Status-code index; speeds up smart-filter search
GinIndex(fields=['tech']), # GIN index; speeds up __contains queries on the tech array field
# pg_trgm GIN index; supports LIKE '%keyword%' fuzzy search
GinIndex(
name='website_resp_headers_trgm_idx',
fields=['response_headers'],
opclasses=['gin_trgm_ops']
),
GinIndex(
name='website_url_trgm_idx',
fields=['url'],
opclasses=['gin_trgm_ops']
),
GinIndex(
name='website_title_trgm_idx',
fields=['title'],
opclasses=['gin_trgm_ops']
),
]
constraints = [
# Plain unique constraint: the (url, target) combination must be unique
@@ -308,6 +349,12 @@ class Directory(models.Model):
models.Index(fields=['target']), # Speeds up finding directories under a target_id
models.Index(fields=['url']), # URL index; speeds up search and the unique constraint
models.Index(fields=['status']), # Status-code index; speeds up filtering
# pg_trgm GIN index; supports LIKE '%keyword%' fuzzy search
GinIndex(
name='directory_url_trgm_idx',
fields=['url'],
opclasses=['gin_trgm_ops']
),
]
constraints = [
# Plain unique constraint: the (target, url) combination must be unique
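The trigram indexes above only pay off if the planner actually chooses them for LIKE '%keyword%' predicates. A quick check from a Django shell (the keyword is illustrative; on very small tables the planner may legitimately prefer a sequential scan):

from django.db import connection

# Ask PostgreSQL how it would execute a fuzzy URL search; with enough rows
# the plan should contain "Bitmap Index Scan using website_url_trgm_idx".
with connection.cursor() as cursor:
    cursor.execute(
        "EXPLAIN ANALYZE SELECT id, url FROM website WHERE url ILIKE %s",
        ['%admin%'],
    )
    for (line,) in cursor.fetchall():
        print(line)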


@@ -1,5 +1,6 @@
from django.db import models
from django.contrib.postgres.fields import ArrayField
from django.contrib.postgres.indexes import GinIndex
from django.core.validators import MinValueValidator, MaxValueValidator
@@ -26,6 +27,12 @@ class SubdomainSnapshot(models.Model):
models.Index(fields=['scan']),
models.Index(fields=['name']),
models.Index(fields=['-created_at']),
# pg_trgm GIN index; supports LIKE '%keyword%' fuzzy search
GinIndex(
name='subdomain_snap_name_trgm',
fields=['name'],
opclasses=['gin_trgm_ops']
),
]
constraints = [
# Unique constraint: a subdomain is recorded at most once per scan
@@ -54,22 +61,27 @@ class WebsiteSnapshot(models.Model):
)
# Scan result data
url = models.CharField(max_length=2000, help_text='站点URL')
url = models.TextField(help_text='站点URL')
host = models.CharField(max_length=253, blank=True, default='', help_text='主机名域名或IP地址')
title = models.CharField(max_length=500, blank=True, default='', help_text='页面标题')
status = models.IntegerField(null=True, blank=True, help_text='HTTP状态码')
title = models.TextField(blank=True, default='', help_text='页面标题')
status_code = models.IntegerField(null=True, blank=True, help_text='HTTP状态码')
content_length = models.BigIntegerField(null=True, blank=True, help_text='内容长度')
location = models.CharField(max_length=1000, blank=True, default='', help_text='重定向位置')
web_server = models.CharField(max_length=200, blank=True, default='', help_text='Web服务器')
content_type = models.CharField(max_length=200, blank=True, default='', help_text='内容类型')
location = models.TextField(blank=True, default='', help_text='重定向位置')
webserver = models.TextField(blank=True, default='', help_text='Web服务器')
content_type = models.TextField(blank=True, default='', help_text='内容类型')
tech = ArrayField(
models.CharField(max_length=100),
blank=True,
default=list,
help_text='技术栈'
)
body_preview = models.TextField(blank=True, default='', help_text='响应体预览')
response_body = models.TextField(blank=True, default='', help_text='HTTP响应体')
vhost = models.BooleanField(null=True, blank=True, help_text='虚拟主机标志')
response_headers = models.TextField(
blank=True,
default='',
help_text='原始HTTP响应头'
)
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
class Meta:
@@ -83,6 +95,23 @@ class WebsiteSnapshot(models.Model):
models.Index(fields=['host']), # host index; speeds up lookups by hostname
models.Index(fields=['title']), # title index; speeds up title search
models.Index(fields=['-created_at']),
GinIndex(fields=['tech']), # GIN index; speeds up array-field queries
# pg_trgm GIN index; supports LIKE '%keyword%' fuzzy search
GinIndex(
name='ws_snap_resp_hdr_trgm',
fields=['response_headers'],
opclasses=['gin_trgm_ops']
),
GinIndex(
name='ws_snap_url_trgm',
fields=['url'],
opclasses=['gin_trgm_ops']
),
GinIndex(
name='ws_snap_title_trgm',
fields=['title'],
opclasses=['gin_trgm_ops']
),
]
constraints = [
# Unique constraint: a URL is recorded at most once per scan
@@ -132,6 +161,12 @@ class DirectorySnapshot(models.Model):
models.Index(fields=['status']), # Status-code index; speeds up filtering
models.Index(fields=['content_type']), # content_type index; speeds up content-type search
models.Index(fields=['-created_at']),
# pg_trgm GIN index; supports LIKE '%keyword%' fuzzy search
GinIndex(
name='dir_snap_url_trgm',
fields=['url'],
opclasses=['gin_trgm_ops']
),
]
constraints = [
# Unique constraint: a directory URL is recorded at most once per scan
@@ -232,26 +267,26 @@ class EndpointSnapshot(models.Model):
)
# Scan result data
url = models.CharField(max_length=2000, help_text='端点URL')
url = models.TextField(help_text='端点URL')
host = models.CharField(
max_length=253,
blank=True,
default='',
help_text='主机名域名或IP地址'
)
title = models.CharField(max_length=1000, blank=True, default='', help_text='页面标题')
title = models.TextField(blank=True, default='', help_text='页面标题')
status_code = models.IntegerField(null=True, blank=True, help_text='HTTP状态码')
content_length = models.IntegerField(null=True, blank=True, help_text='内容长度')
location = models.CharField(max_length=1000, blank=True, default='', help_text='重定向位置')
webserver = models.CharField(max_length=200, blank=True, default='', help_text='Web服务器')
content_type = models.CharField(max_length=200, blank=True, default='', help_text='内容类型')
location = models.TextField(blank=True, default='', help_text='重定向位置')
webserver = models.TextField(blank=True, default='', help_text='Web服务器')
content_type = models.TextField(blank=True, default='', help_text='内容类型')
tech = ArrayField(
models.CharField(max_length=100),
blank=True,
default=list,
help_text='技术栈'
)
body_preview = models.CharField(max_length=1000, blank=True, default='', help_text='响应体预览')
response_body = models.TextField(blank=True, default='', help_text='HTTP响应体')
vhost = models.BooleanField(null=True, blank=True, help_text='虚拟主机标志')
matched_gf_patterns = ArrayField(
models.CharField(max_length=100),
@@ -259,6 +294,11 @@ class EndpointSnapshot(models.Model):
default=list,
help_text='匹配的GF模式列表'
)
response_headers = models.TextField(
blank=True,
default='',
help_text='原始HTTP响应头'
)
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
class Meta:
@@ -274,6 +314,23 @@ class EndpointSnapshot(models.Model):
models.Index(fields=['status_code']), # Status-code index; speeds up filtering
models.Index(fields=['webserver']), # webserver index; speeds up server search
models.Index(fields=['-created_at']),
GinIndex(fields=['tech']), # GIN index; speeds up array-field queries
# pg_trgm GIN index; supports LIKE '%keyword%' fuzzy search
GinIndex(
name='ep_snap_resp_hdr_trgm',
fields=['response_headers'],
opclasses=['gin_trgm_ops']
),
GinIndex(
name='ep_snap_url_trgm',
fields=['url'],
opclasses=['gin_trgm_ops']
),
GinIndex(
name='ep_snap_title_trgm',
fields=['title'],
opclasses=['gin_trgm_ops']
),
]
constraints = [
# Unique constraint: a URL is recorded at most once per scan


@@ -48,12 +48,13 @@ class DjangoEndpointRepository:
status_code=item.status_code,
content_length=item.content_length,
webserver=item.webserver or '',
body_preview=item.body_preview or '',
response_body=item.response_body or '',
content_type=item.content_type or '',
tech=item.tech if item.tech else [],
vhost=item.vhost,
location=item.location or '',
matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else []
matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else [],
response_headers=item.response_headers if item.response_headers else ''
)
for item in unique_items
]
@@ -65,8 +66,8 @@ class DjangoEndpointRepository:
unique_fields=['url', 'target'],
update_fields=[
'host', 'title', 'status_code', 'content_length',
'webserver', 'body_preview', 'content_type', 'tech',
'vhost', 'location', 'matched_gf_patterns'
'webserver', 'response_body', 'content_type', 'tech',
'vhost', 'location', 'matched_gf_patterns', 'response_headers'
],
batch_size=1000
)
@@ -138,12 +139,13 @@ class DjangoEndpointRepository:
status_code=item.status_code,
content_length=item.content_length,
webserver=item.webserver or '',
body_preview=item.body_preview or '',
response_body=item.response_body or '',
content_type=item.content_type or '',
tech=item.tech if item.tech else [],
vhost=item.vhost,
location=item.location or '',
matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else []
matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else [],
response_headers=item.response_headers if item.response_headers else ''
)
for item in unique_items
]
@@ -183,7 +185,7 @@ class DjangoEndpointRepository:
.values(
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'body_preview', 'vhost', 'matched_gf_patterns', 'created_at'
'response_body', 'response_headers', 'vhost', 'matched_gf_patterns', 'created_at'
)
.order_by('url')
)
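Both repositories lean on Django's conflict-aware bulk_create (Django 4.1+ on PostgreSQL) rather than per-row save() calls. A condensed sketch of the pattern, assuming `items` is a list of unsaved Endpoint instances and that the model lives at the import path below:

from apps.asset.models import Endpoint

# On a (url, target) conflict the existing row is updated in place instead
# of raising IntegrityError, so re-scans upsert rather than duplicate.
Endpoint.objects.bulk_create(
    items,
    update_conflicts=True,
    unique_fields=['url', 'target'],
    update_fields=['host', 'title', 'status_code', 'response_body', 'response_headers'],
    batch_size=1000,
)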


@@ -49,12 +49,13 @@ class DjangoWebSiteRepository:
location=item.location or '',
title=item.title or '',
webserver=item.webserver or '',
body_preview=item.body_preview or '',
response_body=item.response_body or '',
content_type=item.content_type or '',
tech=item.tech if item.tech else [],
status_code=item.status_code,
content_length=item.content_length,
vhost=item.vhost
vhost=item.vhost,
response_headers=item.response_headers if item.response_headers else ''
)
for item in unique_items
]
@@ -66,8 +67,8 @@ class DjangoWebSiteRepository:
unique_fields=['url', 'target'],
update_fields=[
'host', 'location', 'title', 'webserver',
'body_preview', 'content_type', 'tech',
'status_code', 'content_length', 'vhost'
'response_body', 'content_type', 'tech',
'status_code', 'content_length', 'vhost', 'response_headers'
],
batch_size=1000
)
@@ -132,12 +133,13 @@ class DjangoWebSiteRepository:
location=item.location or '',
title=item.title or '',
webserver=item.webserver or '',
body_preview=item.body_preview or '',
response_body=item.response_body or '',
content_type=item.content_type or '',
tech=item.tech if item.tech else [],
status_code=item.status_code,
content_length=item.content_length,
vhost=item.vhost
vhost=item.vhost,
response_headers=item.response_headers if item.response_headers else ''
)
for item in unique_items
]
@@ -177,7 +179,7 @@ class DjangoWebSiteRepository:
.values(
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'body_preview', 'vhost', 'created_at'
'response_body', 'response_headers', 'vhost', 'created_at'
)
.order_by('url')
)


@@ -44,6 +44,7 @@ class DjangoEndpointSnapshotRepository:
snapshots.append(EndpointSnapshot(
scan_id=item.scan_id,
url=item.url,
host=item.host if item.host else '',
title=item.title,
status_code=item.status_code,
content_length=item.content_length,
@@ -51,9 +52,10 @@ class DjangoEndpointSnapshotRepository:
webserver=item.webserver,
content_type=item.content_type,
tech=item.tech if item.tech else [],
body_preview=item.body_preview,
response_body=item.response_body,
vhost=item.vhost,
matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else []
matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else [],
response_headers=item.response_headers if item.response_headers else ''
))
# Bulk create (ignore conflicts; dedupe via the unique constraint)
@@ -100,7 +102,7 @@ class DjangoEndpointSnapshotRepository:
.values(
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'body_preview', 'vhost', 'matched_gf_patterns', 'created_at'
'response_body', 'response_headers', 'vhost', 'matched_gf_patterns', 'created_at'
)
.order_by('url')
)


@@ -46,14 +46,15 @@ class DjangoWebsiteSnapshotRepository:
url=item.url,
host=item.host,
title=item.title,
status=item.status,
status_code=item.status_code,
content_length=item.content_length,
location=item.location,
web_server=item.web_server,
webserver=item.webserver,
content_type=item.content_type,
tech=item.tech if item.tech else [],
body_preview=item.body_preview,
vhost=item.vhost
response_body=item.response_body,
vhost=item.vhost,
response_headers=item.response_headers if item.response_headers else ''
))
# Bulk create (ignore conflicts; dedupe via the unique constraint)
@@ -98,26 +99,12 @@ class DjangoWebsiteSnapshotRepository:
WebsiteSnapshot.objects
.filter(scan_id=scan_id)
.values(
'url', 'host', 'location', 'title', 'status',
'content_length', 'content_type', 'web_server', 'tech',
'body_preview', 'vhost', 'created_at'
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'response_body', 'response_headers', 'vhost', 'created_at'
)
.order_by('url')
)
for row in qs.iterator(chunk_size=batch_size):
# Rename fields to match the CSV header
yield {
'url': row['url'],
'host': row['host'],
'location': row['location'],
'title': row['title'],
'status_code': row['status'],
'content_length': row['content_length'],
'content_type': row['content_type'],
'webserver': row['web_server'],
'tech': row['tech'],
'body_preview': row['body_preview'],
'vhost': row['vhost'],
'created_at': row['created_at'],
}
yield row


@@ -67,9 +67,10 @@ class SubdomainListSerializer(serializers.ModelSerializer):
class WebSiteSerializer(serializers.ModelSerializer):
"""Website serializer"""
"""Website serializer (target detail page)"""
subdomain = serializers.CharField(source='subdomain.name', allow_blank=True, default='')
responseHeaders = serializers.CharField(source='response_headers', read_only=True) # Raw HTTP response headers
class Meta:
model = WebSite
@@ -83,9 +84,10 @@ class WebSiteSerializer(serializers.ModelSerializer):
'content_type',
'status_code',
'content_length',
'body_preview',
'response_body',
'tech',
'vhost',
'responseHeaders', # HTTP response headers
'subdomain',
'created_at',
]
@@ -140,6 +142,7 @@ class EndpointListSerializer(serializers.ModelSerializer):
source='matched_gf_patterns',
read_only=True,
)
responseHeaders = serializers.CharField(source='response_headers', read_only=True) # Raw HTTP response headers
class Meta:
model = Endpoint
@@ -152,9 +155,10 @@ class EndpointListSerializer(serializers.ModelSerializer):
'content_length',
'content_type',
'webserver',
'body_preview',
'response_body',
'tech',
'vhost',
'responseHeaders', # HTTP response headers
'gfPatterns',
'created_at',
]
@@ -213,8 +217,7 @@ class WebsiteSnapshotSerializer(serializers.ModelSerializer):
"""Website snapshot serializer (for scan history)"""
subdomain_name = serializers.CharField(source='subdomain.name', read_only=True)
webserver = serializers.CharField(source='web_server', read_only=True) # Mapped field name
status_code = serializers.IntegerField(source='status', read_only=True) # Mapped field name
responseHeaders = serializers.CharField(source='response_headers', read_only=True) # Raw HTTP response headers
class Meta:
model = WebsiteSnapshot
@@ -223,13 +226,14 @@ class WebsiteSnapshotSerializer(serializers.ModelSerializer):
'url',
'location',
'title',
'webserver', # Uses the mapped field name
'webserver',
'content_type',
'status_code', # Uses the mapped field name
'status_code',
'content_length',
'body_preview',
'response_body',
'tech',
'vhost',
'responseHeaders', # HTTP response headers
'subdomain_name',
'created_at',
]
@@ -264,6 +268,7 @@ class EndpointSnapshotSerializer(serializers.ModelSerializer):
source='matched_gf_patterns',
read_only=True,
)
responseHeaders = serializers.CharField(source='response_headers', read_only=True) # Raw HTTP response headers
class Meta:
model = EndpointSnapshot
@@ -277,9 +282,10 @@ class EndpointSnapshotSerializer(serializers.ModelSerializer):
'content_type',
'status_code',
'content_length',
'body_preview',
'response_body',
'tech',
'vhost',
'responseHeaders', # HTTP response headers
'gfPatterns',
'created_at',
]


@@ -27,7 +27,8 @@ class EndpointService:
'url': 'url',
'host': 'host',
'title': 'title',
'status': 'status_code',
'status_code': 'status_code',
'tech': 'tech',
}
def __init__(self):
@@ -115,7 +116,7 @@ class EndpointService:
"""Get all endpoints under a target"""
queryset = self.repo.get_by_target(target_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING, json_array_fields=['tech'])
return queryset
def count_endpoints_by_target(self, target_id: int) -> int:
@@ -134,7 +135,7 @@ class EndpointService:
"""Get all endpoints (global query)"""
queryset = self.repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING, json_array_fields=['tech'])
return queryset
def iter_endpoint_urls_by_target(self, target_id: int, chunk_size: int = 1000) -> Iterator[str]:


@@ -19,7 +19,8 @@ class WebSiteService:
'url': 'url',
'host': 'host',
'title': 'title',
'status': 'status_code',
'status_code': 'status_code',
'tech': 'tech',
}
def __init__(self, repository=None):
@@ -107,14 +108,14 @@ class WebSiteService:
"""Get all websites under a target"""
queryset = self.repo.get_by_target(target_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING, json_array_fields=['tech'])
return queryset
def get_all(self, filter_query: Optional[str] = None):
"""Get all websites"""
queryset = self.repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING, json_array_fields=['tech'])
return queryset
def get_by_url(self, url: str, target_id: int) -> int:
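With tech registered as a JSON-array field, the smart filter treats the three operators differently on it. A usage sketch (the values are illustrative):

service = WebSiteService()

# '=' on tech is a fuzzy, case-insensitive match against any array element;
# '==' would instead require an element exactly equal to "nginx".
qs = service.get_all(filter_query='tech="nginx" && status_code="200"')
print(qs.count())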


@@ -0,0 +1,439 @@
"""
Asset search service.
Provides the core business logic for asset search:
- Queries data from the materialized views
- Parses the expression syntax
- Supports the = (fuzzy), == (exact), and != (not-equal) operators
- Supports && (AND) and || (OR) logical combinations
- Supports both Website and Endpoint asset types
"""
import logging
import re
from typing import Optional, List, Dict, Any, Tuple, Literal
from django.db import connection
logger = logging.getLogger(__name__)
# Supported field mapping (frontend field name -> database field name)
FIELD_MAPPING = {
'host': 'host',
'url': 'url',
'title': 'title',
'tech': 'tech',
'status': 'status_code',
'body': 'response_body',
'header': 'response_headers',
}
# Array-typed fields
ARRAY_FIELDS = {'tech'}
# Mapping from asset type to view name
VIEW_MAPPING = {
'website': 'asset_search_view',
'endpoint': 'endpoint_search_view',
}
# Valid asset types
VALID_ASSET_TYPES = {'website', 'endpoint'}
# Website select fields
WEBSITE_SELECT_FIELDS = """
id,
url,
host,
title,
tech,
status_code,
response_headers,
response_body,
content_type,
content_length,
webserver,
location,
vhost,
created_at,
target_id
"""
# Endpoint select fields (includes matched_gf_patterns)
ENDPOINT_SELECT_FIELDS = """
id,
url,
host,
title,
tech,
status_code,
response_headers,
response_body,
content_type,
content_length,
webserver,
location,
vhost,
matched_gf_patterns,
created_at,
target_id
"""
class SearchQueryParser:
"""
Search query parser.
Supported syntax:
- field="value" fuzzy match (ILIKE %value%)
- field=="value" exact match
- field!="value" not equal
- && AND connector
- || OR connector
- () grouping (nesting not yet supported)
Examples:
- host="api" && tech="nginx"
- tech="vue" || tech="react"
- status=="200" && host!="test"
"""
# Matches a single condition: field="value", field=="value", or field!="value"
CONDITION_PATTERN = re.compile(r'(\w+)\s*(==|!=|=)\s*"([^"]*)"')
@classmethod
def parse(cls, query: str) -> Tuple[str, List[Any]]:
"""
Parse the query string into a SQL WHERE clause and parameters.
Args:
query: the search query string
Returns:
a (where_clause, params) tuple
"""
if not query or not query.strip():
return "1=1", []
query = query.strip()
# If the operator syntax is absent, treat the input as a fuzzy host search
if not cls.CONDITION_PATTERN.search(query):
# Bare text: default to a fuzzy host search
return "host ILIKE %s", [f"%{query}%"]
# Split into OR groups on ||
or_groups = cls._split_by_or(query)
if len(or_groups) == 1:
# No OR: parse the AND conditions directly
return cls._parse_and_group(or_groups[0])
# Multiple OR groups
or_clauses = []
all_params = []
for group in or_groups:
clause, params = cls._parse_and_group(group)
if clause and clause != "1=1":
or_clauses.append(f"({clause})")
all_params.extend(params)
if not or_clauses:
return "1=1", []
return " OR ".join(or_clauses), all_params
@classmethod
def _split_by_or(cls, query: str) -> List[str]:
"""Split the query on ||, ignoring any || inside quotes"""
parts = []
current = ""
in_quotes = False
i = 0
while i < len(query):
char = query[i]
if char == '"':
in_quotes = not in_quotes
current += char
elif not in_quotes and i + 1 < len(query) and query[i:i+2] == '||':
if current.strip():
parts.append(current.strip())
current = ""
i += 1 # Skip the second |
else:
current += char
i += 1
if current.strip():
parts.append(current.strip())
return parts if parts else [query]
@classmethod
def _parse_and_group(cls, group: str) -> Tuple[str, List[Any]]:
"""Parse an AND group (conditions joined with &&)"""
# Strip the outer parentheses
group = group.strip()
if group.startswith('(') and group.endswith(')'):
group = group[1:-1].strip()
# Split on &&
parts = cls._split_by_and(group)
and_clauses = []
all_params = []
for part in parts:
clause, params = cls._parse_condition(part.strip())
if clause:
and_clauses.append(clause)
all_params.extend(params)
if not and_clauses:
return "1=1", []
return " AND ".join(and_clauses), all_params
@classmethod
def _split_by_and(cls, query: str) -> List[str]:
"""Split the query on &&, ignoring any && inside quotes"""
parts = []
current = ""
in_quotes = False
i = 0
while i < len(query):
char = query[i]
if char == '"':
in_quotes = not in_quotes
current += char
elif not in_quotes and i + 1 < len(query) and query[i:i+2] == '&&':
if current.strip():
parts.append(current.strip())
current = ""
i += 1 # Skip the second &
else:
current += char
i += 1
if current.strip():
parts.append(current.strip())
return parts if parts else [query]
@classmethod
def _parse_condition(cls, condition: str) -> Tuple[Optional[str], List[Any]]:
"""
Parse a single condition.
Returns:
(sql_clause, params), or (None, []) if parsing fails
"""
# Strip the parentheses
condition = condition.strip()
if condition.startswith('(') and condition.endswith(')'):
condition = condition[1:-1].strip()
match = cls.CONDITION_PATTERN.match(condition)
if not match:
logger.warning(f"无法解析条件: {condition}")
return None, []
field, operator, value = match.groups()
field = field.lower()
# Validate the field
if field not in FIELD_MAPPING:
logger.warning(f"未知字段: {field}")
return None, []
db_field = FIELD_MAPPING[field]
is_array = field in ARRAY_FIELDS
# Generate SQL according to the operator
if operator == '=':
# Fuzzy match
return cls._build_like_condition(db_field, value, is_array)
elif operator == '==':
# Exact match
return cls._build_exact_condition(db_field, value, is_array)
elif operator == '!=':
# Not equal
return cls._build_not_equal_condition(db_field, value, is_array)
return None, []
@classmethod
def _build_like_condition(cls, field: str, value: str, is_array: bool) -> Tuple[str, List[Any]]:
"""Build a fuzzy-match condition"""
if is_array:
# Array field: check whether any element contains the value
return f"EXISTS (SELECT 1 FROM unnest({field}) AS t WHERE t ILIKE %s)", [f"%{value}%"]
elif field == 'status_code':
# status_code is an integer, so a fuzzy match falls back to an exact match
try:
return f"{field} = %s", [int(value)]
except ValueError:
return f"{field}::text ILIKE %s", [f"%{value}%"]
else:
return f"{field} ILIKE %s", [f"%{value}%"]
@classmethod
def _build_exact_condition(cls, field: str, value: str, is_array: bool) -> Tuple[str, List[Any]]:
"""Build an exact-match condition"""
if is_array:
# Array field: check whether the array contains this exact value
return f"%s = ANY({field})", [value]
elif field == 'status_code':
# status_code is an integer
try:
return f"{field} = %s", [int(value)]
except ValueError:
return f"{field}::text = %s", [value]
else:
return f"{field} = %s", [value]
@classmethod
def _build_not_equal_condition(cls, field: str, value: str, is_array: bool) -> Tuple[str, List[Any]]:
"""Build a not-equal condition"""
if is_array:
# Array field: check that the array does not contain the value
return f"NOT (%s = ANY({field}))", [value]
elif field == 'status_code':
try:
return f"({field} IS NULL OR {field} != %s)", [int(value)]
except ValueError:
return f"({field} IS NULL OR {field}::text != %s)", [value]
else:
return f"({field} IS NULL OR {field} != %s)", [value]
AssetType = Literal['website', 'endpoint']
class AssetSearchService:
"""Asset search service"""
def search(
self,
query: str,
asset_type: AssetType = 'website',
limit: Optional[int] = None
) -> List[Dict[str, Any]]:
"""
Search assets.
Args:
query: the search query string
asset_type: asset type ('website' or 'endpoint')
limit: maximum number of rows to return (optional)
Returns:
List[Dict]: the list of search results
"""
where_clause, params = SearchQueryParser.parse(query)
# Pick the view and select fields for the asset type
view_name = VIEW_MAPPING.get(asset_type, 'asset_search_view')
select_fields = ENDPOINT_SELECT_FIELDS if asset_type == 'endpoint' else WEBSITE_SELECT_FIELDS
sql = f"""
SELECT {select_fields}
FROM {view_name}
WHERE {where_clause}
ORDER BY created_at DESC
"""
# Append the LIMIT
if limit is not None and limit > 0:
sql += f" LIMIT {int(limit)}"
try:
with connection.cursor() as cursor:
cursor.execute(sql, params)
columns = [col[0] for col in cursor.description]
results = []
for row in cursor.fetchall():
result = dict(zip(columns, row))
results.append(result)
return results
except Exception as e:
logger.error(f"搜索查询失败: {e}, SQL: {sql}, params: {params}")
raise
def count(self, query: str, asset_type: AssetType = 'website') -> int:
"""
Count the search results.
Args:
query: the search query string
asset_type: asset type ('website' or 'endpoint')
Returns:
int: the total number of results
"""
where_clause, params = SearchQueryParser.parse(query)
# Pick the view for the asset type
view_name = VIEW_MAPPING.get(asset_type, 'asset_search_view')
sql = f"SELECT COUNT(*) FROM {view_name} WHERE {where_clause}"
try:
with connection.cursor() as cursor:
cursor.execute(sql, params)
return cursor.fetchone()[0]
except Exception as e:
logger.error(f"统计查询失败: {e}")
raise
def search_iter(
self,
query: str,
asset_type: AssetType = 'website',
batch_size: int = 1000
):
"""
Stream search results (server-side cursor; memory-friendly).
Args:
query: the search query string
asset_type: asset type ('website' or 'endpoint')
batch_size: number of rows fetched per batch
Yields:
Dict: a single search result
"""
where_clause, params = SearchQueryParser.parse(query)
# Pick the view and select fields for the asset type
view_name = VIEW_MAPPING.get(asset_type, 'asset_search_view')
select_fields = ENDPOINT_SELECT_FIELDS if asset_type == 'endpoint' else WEBSITE_SELECT_FIELDS
sql = f"""
SELECT {select_fields}
FROM {view_name}
WHERE {where_clause}
ORDER BY created_at DESC
"""
try:
# Use a server-side cursor so the full result set is never loaded into memory at once
with connection.cursor(name='export_cursor') as cursor:
cursor.itersize = batch_size
cursor.execute(sql, params)
columns = [col[0] for col in cursor.description]
for row in cursor:
yield dict(zip(columns, row))
except Exception as e:
logger.error(f"流式搜索查询失败: {e}, SQL: {sql}, params: {params}")
raise
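For reference, the parser emits a parameterized WHERE fragment, so user-supplied values are never interpolated into the SQL string. A quick illustration of the output (results shown as comments):

where, params = SearchQueryParser.parse('host="api" && tech="nginx"')
# where  -> host ILIKE %s AND EXISTS (SELECT 1 FROM unnest(tech) AS t WHERE t ILIKE %s)
# params -> ['%api%', '%nginx%']

# And a service call that exercises OR plus the optional limit:
results = AssetSearchService().search('status=="200" || status=="301"', asset_type='endpoint', limit=50)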


@@ -72,7 +72,7 @@ class EndpointSnapshotsService:
'url': 'url',
'host': 'host',
'title': 'title',
'status': 'status_code',
'status_code': 'status_code',
'webserver': 'webserver',
'tech': 'tech',
}


@@ -73,8 +73,8 @@ class WebsiteSnapshotsService:
'url': 'url',
'host': 'host',
'title': 'title',
'status': 'status',
'webserver': 'web_server',
'status_code': 'status_code',
'webserver': 'webserver',
'tech': 'tech',
}


@@ -0,0 +1,7 @@
"""
Task module for the Asset app.
Note: materialized-view refresh has moved to the APScheduler scheduled tasks (apps.engine.scheduler).
"""
__all__ = []


@@ -10,6 +10,8 @@ from .views import (
DirectoryViewSet,
VulnerabilityViewSet,
AssetStatisticsViewSet,
AssetSearchView,
AssetSearchExportView,
)
# Create the DRF router
@@ -25,4 +27,6 @@ router.register(r'statistics', AssetStatisticsViewSet, basename='asset-statistic
urlpatterns = [
path('assets/', include(router.urls)),
path('assets/search/', AssetSearchView.as_view(), name='asset-search'),
path('assets/search/export/', AssetSearchExportView.as_view(), name='asset-search-export'),
]


@@ -0,0 +1,40 @@
"""
View module for the Asset app.
Re-exports all view classes to preserve backward compatibility.
"""
from .asset_views import (
AssetStatisticsViewSet,
SubdomainViewSet,
WebSiteViewSet,
DirectoryViewSet,
EndpointViewSet,
HostPortMappingViewSet,
VulnerabilityViewSet,
SubdomainSnapshotViewSet,
WebsiteSnapshotViewSet,
DirectorySnapshotViewSet,
EndpointSnapshotViewSet,
HostPortMappingSnapshotViewSet,
VulnerabilitySnapshotViewSet,
)
from .search_views import AssetSearchView, AssetSearchExportView
__all__ = [
'AssetStatisticsViewSet',
'SubdomainViewSet',
'WebSiteViewSet',
'DirectoryViewSet',
'EndpointViewSet',
'HostPortMappingViewSet',
'VulnerabilityViewSet',
'SubdomainSnapshotViewSet',
'WebsiteSnapshotViewSet',
'DirectorySnapshotViewSet',
'EndpointSnapshotViewSet',
'HostPortMappingSnapshotViewSet',
'VulnerabilitySnapshotViewSet',
'AssetSearchView',
'AssetSearchExportView',
]


@@ -10,17 +10,17 @@ from django.core.exceptions import ValidationError, ObjectDoesNotExist
from django.db import DatabaseError, IntegrityError, OperationalError
from django.http import StreamingHttpResponse
from .serializers import (
from ..serializers import (
SubdomainListSerializer, WebSiteSerializer, DirectorySerializer,
VulnerabilitySerializer, EndpointListSerializer, IPAddressAggregatedSerializer,
SubdomainSnapshotSerializer, WebsiteSnapshotSerializer, DirectorySnapshotSerializer,
EndpointSnapshotSerializer, VulnerabilitySnapshotSerializer
)
from .services import (
from ..services import (
SubdomainService, WebSiteService, DirectoryService,
VulnerabilityService, AssetStatisticsService, EndpointService, HostPortMappingService
)
from .services.snapshot import (
from ..services.snapshot import (
SubdomainSnapshotsService, WebsiteSnapshotsService, DirectorySnapshotsService,
EndpointSnapshotsService, HostPortMappingSnapshotsService, VulnerabilitySnapshotsService
)
@@ -274,6 +274,7 @@ class WebSiteViewSet(viewsets.ModelViewSet):
- host="example" fuzzy hostname match
- title="login" fuzzy title match
- status="200,301" multi-value status-code match
- tech="nginx" technology-stack match (array field)
- Multiple space-separated conditions combine with AND
"""
@@ -366,7 +367,7 @@ class WebSiteViewSet(viewsets.ModelViewSet):
def export(self, request, **kwargs):
"""Export websites as CSV
CSV columns: url, host, location, title, status_code, content_length, content_type, webserver, tech, body_preview, vhost, created_at
CSV columns: url, host, location, title, status_code, content_length, content_type, webserver, tech, response_body, response_headers, vhost, created_at
"""
from apps.common.utils import generate_csv_rows, format_datetime, format_list_field
@@ -379,7 +380,7 @@ class WebSiteViewSet(viewsets.ModelViewSet):
headers = [
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'body_preview', 'vhost', 'created_at'
'response_body', 'response_headers', 'vhost', 'created_at'
]
formatters = {
'created_at': format_datetime,
@@ -534,6 +535,7 @@ class EndpointViewSet(viewsets.ModelViewSet):
- host="example" fuzzy hostname match
- title="login" fuzzy title match
- status="200,301" multi-value status-code match
- tech="nginx" technology-stack match (array field)
- Multiple space-separated conditions combine with AND
"""
@@ -626,7 +628,7 @@ class EndpointViewSet(viewsets.ModelViewSet):
def export(self, request, **kwargs):
"""Export endpoints as CSV
CSV columns: url, host, location, title, status_code, content_length, content_type, webserver, tech, body_preview, vhost, matched_gf_patterns, created_at
CSV columns: url, host, location, title, status_code, content_length, content_type, webserver, tech, response_body, response_headers, vhost, matched_gf_patterns, created_at
"""
from apps.common.utils import generate_csv_rows, format_datetime, format_list_field
@@ -639,7 +641,7 @@ class EndpointViewSet(viewsets.ModelViewSet):
headers = [
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'body_preview', 'vhost', 'matched_gf_patterns', 'created_at'
'response_body', 'response_headers', 'vhost', 'matched_gf_patterns', 'created_at'
]
formatters = {
'created_at': format_datetime,
@@ -851,7 +853,7 @@ class WebsiteSnapshotViewSet(viewsets.ModelViewSet):
def export(self, request, **kwargs):
"""Export website snapshots as CSV
CSV columns: url, host, location, title, status_code, content_length, content_type, webserver, tech, body_preview, vhost, created_at
CSV columns: url, host, location, title, status_code, content_length, content_type, webserver, tech, response_body, response_headers, vhost, created_at
"""
from apps.common.utils import generate_csv_rows, format_datetime, format_list_field
@@ -864,7 +866,7 @@ class WebsiteSnapshotViewSet(viewsets.ModelViewSet):
headers = [
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'body_preview', 'vhost', 'created_at'
'response_body', 'response_headers', 'vhost', 'created_at'
]
formatters = {
'created_at': format_datetime,
@@ -968,7 +970,7 @@ class EndpointSnapshotViewSet(viewsets.ModelViewSet):
def export(self, request, **kwargs):
"""Export endpoint snapshots as CSV
CSV columns: url, host, location, title, status_code, content_length, content_type, webserver, tech, body_preview, vhost, matched_gf_patterns, created_at
CSV columns: url, host, location, title, status_code, content_length, content_type, webserver, tech, response_body, response_headers, vhost, matched_gf_patterns, created_at
"""
from apps.common.utils import generate_csv_rows, format_datetime, format_list_field
@@ -981,7 +983,7 @@ class EndpointSnapshotViewSet(viewsets.ModelViewSet):
headers = [
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'body_preview', 'vhost', 'matched_gf_patterns', 'created_at'
'response_body', 'response_headers', 'vhost', 'matched_gf_patterns', 'created_at'
]
formatters = {
'created_at': format_datetime,


@@ -0,0 +1,364 @@
"""
Asset search API views.
REST API endpoints for asset search:
- GET /api/assets/search/ - search assets
- GET /api/assets/search/export/ - export search results as CSV
Search syntax:
- field="value" fuzzy match (ILIKE %value%)
- field=="value" exact match
- field!="value" not equal
- && AND connector
- || OR connector
Supported fields:
- host: hostname
- url: URL
- title: page title
- tech: technology stack
- status: status code
- body: response body
- header: response headers
Supported asset types:
- website: websites (default)
- endpoint: endpoints
"""
import logging
import json
from datetime import datetime
from urllib.parse import urlparse, urlunparse
from rest_framework import status
from rest_framework.views import APIView
from rest_framework.request import Request
from django.http import StreamingHttpResponse
from django.db import connection
from apps.common.response_helpers import success_response, error_response
from apps.common.error_codes import ErrorCodes
from apps.asset.services.search_service import AssetSearchService, VALID_ASSET_TYPES
logger = logging.getLogger(__name__)
class AssetSearchView(APIView):
"""
Asset search API
GET /api/assets/search/
Query Parameters:
q: the search query expression
asset_type: asset type ('website' or 'endpoint'; default 'website')
page: page number (1-based; default 1)
pageSize: page size (default 10; max 100)
Example queries:
?q=host="api" && tech="nginx"
?q=tech="vue" || tech="react"&asset_type=endpoint
?q=status=="200" && host!="test"
Response:
{
"results": [...],
"total": 100,
"page": 1,
"pageSize": 10,
"totalPages": 10,
"assetType": "website"
}
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.service = AssetSearchService()
def _parse_headers(self, headers_data) -> dict:
"""Parse response headers into a dict"""
if not headers_data:
return {}
try:
return json.loads(headers_data)
except (json.JSONDecodeError, TypeError):
result = {}
for line in str(headers_data).split('\n'):
if ':' in line:
key, value = line.split(':', 1)
result[key.strip()] = value.strip()
return result
def _format_result(self, result: dict, vulnerabilities_by_url: dict, asset_type: str) -> dict:
"""Format a single search result"""
url = result.get('url', '')
vulns = vulnerabilities_by_url.get(url, [])
# Base fields (shared by Website and Endpoint)
formatted = {
'id': result.get('id'),
'url': url,
'host': result.get('host', ''),
'title': result.get('title', ''),
'technologies': result.get('tech', []) or [],
'statusCode': result.get('status_code'),
'contentLength': result.get('content_length'),
'contentType': result.get('content_type', ''),
'webserver': result.get('webserver', ''),
'location': result.get('location', ''),
'vhost': result.get('vhost'),
'responseHeaders': self._parse_headers(result.get('response_headers')),
'responseBody': result.get('response_body', ''),
'createdAt': result.get('created_at').isoformat() if result.get('created_at') else None,
'targetId': result.get('target_id'),
}
# Website-specific field: associated vulnerabilities
if asset_type == 'website':
formatted['vulnerabilities'] = [
{
'id': v.get('id'),
'name': v.get('vuln_type', ''),
'vulnType': v.get('vuln_type', ''),
'severity': v.get('severity', 'info'),
}
for v in vulns
]
# Endpoint-specific fields
if asset_type == 'endpoint':
formatted['matchedGfPatterns'] = result.get('matched_gf_patterns', []) or []
return formatted
def _get_vulnerabilities_by_url_prefix(self, website_urls: list) -> dict:
"""
Batch-query vulnerability data by URL prefix.
Vulnerability URLs are sub-paths of a website URL, so prefix matching is used:
- website.url: https://example.com/path?query=1
- vulnerability.url: https://example.com/path/api/users
Args:
website_urls: list of website URLs, as [(url, target_id), ...]
Returns:
dict: {website_url: [vulnerability_list]}
"""
if not website_urls:
return {}
try:
with connection.cursor() as cursor:
# Build OR conditions: each website URL (query params stripped) becomes a prefix match
conditions = []
params = []
url_mapping = {} # base_url -> original_url
for url, target_id in website_urls:
if not url or target_id is None:
continue
# Use urlparse to strip query params and fragment, keeping only scheme://netloc/path
parsed = urlparse(url)
base_url = urlunparse((parsed.scheme, parsed.netloc, parsed.path, '', '', ''))
url_mapping[base_url] = url
conditions.append("(v.url LIKE %s AND v.target_id = %s)")
params.extend([base_url + '%', target_id])
if not conditions:
return {}
where_clause = " OR ".join(conditions)
sql = f"""
SELECT v.id, v.vuln_type, v.severity, v.url, v.target_id
FROM vulnerability v
WHERE {where_clause}
ORDER BY
CASE v.severity
WHEN 'critical' THEN 1
WHEN 'high' THEN 2
WHEN 'medium' THEN 3
WHEN 'low' THEN 4
ELSE 5
END
"""
cursor.execute(sql, params)
# Collect all vulnerabilities
all_vulns = []
for row in cursor.fetchall():
all_vulns.append({
'id': row[0],
'vuln_type': row[1],
'name': row[1],
'severity': row[2],
'url': row[3],
'target_id': row[4],
})
# Group by the original website URL (for the returned result)
result = {url: [] for url, _ in website_urls}
for vuln in all_vulns:
vuln_url = vuln['url']
# Find the matching website URL (longest-prefix match)
for website_url, target_id in website_urls:
parsed = urlparse(website_url)
base_url = urlunparse((parsed.scheme, parsed.netloc, parsed.path, '', '', ''))
if vuln_url.startswith(base_url) and vuln['target_id'] == target_id:
result[website_url].append(vuln)
break
return result
except Exception as e:
logger.error(f"批量查询漏洞失败: {e}")
return {}
def get(self, request: Request):
"""Search assets"""
# Read the search query
query = request.query_params.get('q', '').strip()
if not query:
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Search query (q) is required',
status_code=status.HTTP_400_BAD_REQUEST
)
# Read and validate the asset type
asset_type = request.query_params.get('asset_type', 'website').strip().lower()
if asset_type not in VALID_ASSET_TYPES:
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message=f'Invalid asset_type. Must be one of: {", ".join(VALID_ASSET_TYPES)}',
status_code=status.HTTP_400_BAD_REQUEST
)
# Read the pagination parameters
try:
page = int(request.query_params.get('page', 1))
page_size = int(request.query_params.get('pageSize', 10))
except (ValueError, TypeError):
page = 1
page_size = 10
# Clamp the pagination parameters
page = max(1, page)
page_size = min(max(1, page_size), 100)
# Fetch the total count and the search results
total = self.service.count(query, asset_type)
total_pages = (total + page_size - 1) // page_size if total > 0 else 1
offset = (page - 1) * page_size
all_results = self.service.search(query, asset_type, limit=offset + page_size)  # only fetch up to the current page
results = all_results[offset:offset + page_size]
# Batch-query vulnerability data (Website type only)
vulnerabilities_by_url = {}
if asset_type == 'website':
website_urls = [(r.get('url'), r.get('target_id')) for r in results if r.get('url') and r.get('target_id')]
vulnerabilities_by_url = self._get_vulnerabilities_by_url_prefix(website_urls) if website_urls else {}
# Format the results
formatted_results = [self._format_result(r, vulnerabilities_by_url, asset_type) for r in results]
return success_response(data={
'results': formatted_results,
'total': total,
'page': page,
'pageSize': page_size,
'totalPages': total_pages,
'assetType': asset_type,
})
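A request sketch against this endpoint (host and session cookie are placeholders; `requests` URL-encodes the expression, including the `&&`):

import requests

resp = requests.get(
    'http://localhost:8000/api/assets/search/',
    params={'q': 'host="api" && tech="nginx"', 'asset_type': 'website', 'page': 1, 'pageSize': 10},
    cookies={'sessionid': '<your-session>'},  # session auth, per the permission class in this changeset
)
print(resp.json())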
class AssetSearchExportView(APIView):
"""
Asset search export API
GET /api/assets/search/export/
Query Parameters:
q: the search query expression
asset_type: asset type ('website' or 'endpoint'; default 'website')
Response:
A streamed CSV file (server-side cursor, so large exports are supported)
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.service = AssetSearchService()
def _get_headers_and_formatters(self, asset_type: str):
"""Return the CSV headers and formatters"""
from apps.common.utils import format_datetime, format_list_field
if asset_type == 'website':
headers = ['url', 'host', 'title', 'status_code', 'content_type', 'content_length',
'webserver', 'location', 'tech', 'vhost', 'created_at']
else:
headers = ['url', 'host', 'title', 'status_code', 'content_type', 'content_length',
'webserver', 'location', 'tech', 'matched_gf_patterns', 'vhost', 'created_at']
formatters = {
'created_at': format_datetime,
'tech': lambda x: format_list_field(x, separator='; '),
'matched_gf_patterns': lambda x: format_list_field(x, separator='; '),
'vhost': lambda x: 'true' if x else ('false' if x is False else ''),
}
return headers, formatters
def get(self, request: Request):
"""Export search results as CSV (streamed; no row limit)"""
from apps.common.utils import generate_csv_rows
# Read the search query
query = request.query_params.get('q', '').strip()
if not query:
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Search query (q) is required',
status_code=status.HTTP_400_BAD_REQUEST
)
# Read and validate the asset type
asset_type = request.query_params.get('asset_type', 'website').strip().lower()
if asset_type not in VALID_ASSET_TYPES:
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message=f'Invalid asset_type. Must be one of: {", ".join(VALID_ASSET_TYPES)}',
status_code=status.HTTP_400_BAD_REQUEST
)
# Quick existence check to avoid an empty export
total = self.service.count(query, asset_type)
if total == 0:
return error_response(
code=ErrorCodes.NOT_FOUND,
message='No results to export',
status_code=status.HTTP_404_NOT_FOUND
)
# Get the headers and formatters
headers, formatters = self._get_headers_and_formatters(asset_type)
# Get the streaming data iterator
data_iterator = self.service.search_iter(query, asset_type)
# Build the filename
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
filename = f'search_{asset_type}_{timestamp}.csv'
# Return a streaming response
response = StreamingHttpResponse(
generate_csv_rows(data_iterator, headers, formatters),
content_type='text/csv; charset=utf-8'
)
response['Content-Disposition'] = f'attachment; filename="{filename}"'
return response
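And the matching export call, streamed to disk (again a sketch with placeholder host and cookie):

import requests

resp = requests.get(
    'http://localhost:8000/api/assets/search/export/',
    params={'q': 'status=="200"', 'asset_type': 'endpoint'},
    cookies={'sessionid': '<your-session>'},
    stream=True,  # the server streams CSV rows, so don't buffer the whole body
)
resp.raise_for_status()
with open('search_results.csv', 'wb') as fh:
    for chunk in resp.iter_content(chunk_size=8192):
        fh.write(chunk)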


@@ -40,8 +40,14 @@ def fetch_config_and_setup_django():
print(f"[CONFIG] 正在从配置中心获取配置: {config_url}")
print(f"[CONFIG] IS_LOCAL={is_local}")
try:
# Build the request headers (including the Worker API Key)
headers = {}
worker_api_key = os.environ.get("WORKER_API_KEY", "")
if worker_api_key:
headers["X-Worker-API-Key"] = worker_api_key
# verify=False: remote workers reaching the server over HTTPS may face a self-signed certificate
resp = requests.get(config_url, timeout=10, verify=False)
resp = requests.get(config_url, headers=headers, timeout=10, verify=False)
resp.raise_for_status()
config = resp.json()
@@ -57,9 +63,6 @@ def fetch_config_and_setup_django():
os.environ.setdefault("DB_USER", db_user)
os.environ.setdefault("DB_PASSWORD", config['db']['password'])
# Redis configuration
os.environ.setdefault("REDIS_URL", config['redisUrl'])
# Logging configuration
os.environ.setdefault("LOG_DIR", config['paths']['logs'])
os.environ.setdefault("LOG_LEVEL", config['logging']['level'])
@@ -71,7 +74,6 @@ def fetch_config_and_setup_django():
print(f"[CONFIG] DB_PORT: {db_port}")
print(f"[CONFIG] DB_NAME: {db_name}")
print(f"[CONFIG] DB_USER: {db_user}")
print(f"[CONFIG] REDIS_URL: {config['redisUrl']}")
except Exception as e:
print(f"[ERROR] 获取配置失败: {config_url} - {e}", file=sys.stderr)


@@ -0,0 +1,49 @@
"""
Custom exception handler.
Handles DRF exceptions uniformly so error responses share one format.
"""
from rest_framework.views import exception_handler
from rest_framework import status
from rest_framework.exceptions import AuthenticationFailed, NotAuthenticated
from apps.common.response_helpers import error_response
from apps.common.error_codes import ErrorCodes
def custom_exception_handler(exc, context):
"""
Custom exception handler.
Handles authentication-related exceptions and returns errors in the unified format.
"""
# Call DRF's default exception handler first
response = exception_handler(exc, context)
if response is not None:
# Handle 401 unauthenticated errors
if response.status_code == status.HTTP_401_UNAUTHORIZED:
return error_response(
code=ErrorCodes.UNAUTHORIZED,
message='Authentication required',
status_code=status.HTTP_401_UNAUTHORIZED
)
# Handle 403 permission-denied errors
if response.status_code == status.HTTP_403_FORBIDDEN:
return error_response(
code=ErrorCodes.PERMISSION_DENIED,
message='Permission denied',
status_code=status.HTTP_403_FORBIDDEN
)
# Handle NotAuthenticated and AuthenticationFailed exceptions
if isinstance(exc, (NotAuthenticated, AuthenticationFailed)):
return error_response(
code=ErrorCodes.UNAUTHORIZED,
message='Authentication required',
status_code=status.HTTP_401_UNAUTHORIZED
)
return response
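DRF only invokes a custom handler once it is registered in settings. A sketch of the expected wiring; the dotted path is an assumption based on this file living under apps/common/:

# settings.py (sketch; module path assumed)
REST_FRAMEWORK = {
    # ...existing DRF settings...
    'EXCEPTION_HANDLER': 'apps.common.exception_handler.custom_exception_handler',
}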


@@ -0,0 +1,80 @@
"""
Centralized permission management.
Implements authentication logic for three classes of endpoints:
1. Public endpoints (no auth): login, logout, current-user status
2. Worker endpoints (API-key auth): registration, config, heartbeat, callbacks, resource sync
3. Business endpoints (session auth): every other API
"""
import re
import logging
from django.conf import settings
from rest_framework.permissions import BasePermission
logger = logging.getLogger(__name__)
# Public endpoint whitelist (no authentication required)
PUBLIC_ENDPOINTS = [
r'^/api/auth/login/$',
r'^/api/auth/logout/$',
r'^/api/auth/me/$',
]
# Worker API endpoints (require API-key authentication)
# Includes: registration, config, heartbeat, callbacks, resource sync (wordlist downloads)
WORKER_ENDPOINTS = [
r'^/api/workers/register/$',
r'^/api/workers/config/$',
r'^/api/workers/\d+/heartbeat/$',
r'^/api/callbacks/',
# Resource-sync endpoint (workers need to download wordlist files)
r'^/api/wordlists/download/$',
# Note: the fingerprint export API uses session auth (it serves frontend user exports);
# workers read fingerprint data straight from the database, so no HTTP API is needed
]
class IsAuthenticatedOrPublic(BasePermission):
"""
Custom permission class:
- Whitelisted endpoints are publicly accessible
- Worker endpoints require API-key authentication
- All other endpoints require session authentication
"""
def has_permission(self, request, view):
path = request.path
# Check the public whitelist first
for pattern in PUBLIC_ENDPOINTS:
if re.match(pattern, path):
return True
# Then check whether this is a worker endpoint
for pattern in WORKER_ENDPOINTS:
if re.match(pattern, path):
return self._check_worker_api_key(request)
# Everything else requires session authentication
return request.user and request.user.is_authenticated
def _check_worker_api_key(self, request):
"""Validate the Worker API Key"""
api_key = request.headers.get('X-Worker-API-Key')
expected_key = getattr(settings, 'WORKER_API_KEY', None)
if not expected_key:
# If no API key is configured, reject all worker requests
logger.warning("WORKER_API_KEY 未配置,拒绝 Worker 请求")
return False
if not api_key:
logger.warning(f"Worker 请求缺少 X-Worker-API-Key Header: {request.path}")
return False
if api_key != expected_key:
logger.warning(f"Worker API Key 无效: {request.path}")
return False
return True
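From the worker's side the contract is a single header, as the config-fetch change earlier in this diff already shows. A minimal sketch (URL and key are placeholders):

import requests

resp = requests.get(
    'https://server.example:8000/api/workers/config/',
    headers={'X-Worker-API-Key': '<WORKER_API_KEY>'},
    timeout=10,
)
resp.raise_for_status()  # a missing or wrong key is rejected by has_permission()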


@@ -2,14 +2,18 @@
URL configuration for the common module.
Routes:
- /api/health/ health check (no authentication)
- /api/auth/* authentication (login, logout, user info)
- /api/system/* system administration (log viewing, etc.)
"""
from django.urls import path
from .views import LoginView, LogoutView, MeView, ChangePasswordView, SystemLogsView, SystemLogFilesView
from .views import LoginView, LogoutView, MeView, ChangePasswordView, SystemLogsView, SystemLogFilesView, HealthCheckView
urlpatterns = [
# Health check (no authentication)
path('health/', HealthCheckView.as_view(), name='health-check'),
# Authentication
path('auth/login/', LoginView.as_view(), name='auth-login'),
path('auth/logout/', LogoutView.as_view(), name='auth-logout'),


@@ -29,11 +29,19 @@ from dataclasses import dataclass
from typing import List, Dict, Optional, Union
from enum import Enum
from django.db.models import QuerySet, Q
from django.db.models import QuerySet, Q, F, Func, CharField
from django.db.models.functions import Cast
logger = logging.getLogger(__name__)
class ArrayToString(Func):
"""The PostgreSQL array_to_string function"""
function = 'array_to_string'
template = "%(function)s(%(expressions)s, ',')"
output_field = CharField()
class LogicalOp(Enum):
"""Logical operators"""
AND = 'AND'
@@ -86,9 +94,21 @@ class QueryParser:
if not query_string or not query_string.strip():
return []
# Step 1: extract all filter conditions and swap in placeholders, protecting spaces inside quotes
filters_found = []
placeholder_pattern = '__FILTER_{}__'
def replace_filter(match):
idx = len(filters_found)
filters_found.append(match.group(0))
return placeholder_pattern.format(idx)
# First pull out every field="value" condition with the regex
protected = cls.FILTER_PATTERN.sub(replace_filter, query_string)
# Normalize the logical operators
# First map || and or -> __OR__
normalized = cls.OR_PATTERN.sub(' __OR__ ', query_string)
normalized = cls.OR_PATTERN.sub(' __OR__ ', protected)
# Then map && and and -> __AND__
normalized = cls.AND_PATTERN.sub(' __AND__ ', normalized)
@@ -103,20 +123,26 @@ class QueryParser:
pending_op = LogicalOp.OR
elif token == '__AND__':
pending_op = LogicalOp.AND
else:
# Try to parse the token as a filter condition
match = cls.FILTER_PATTERN.match(token)
if match:
field, operator, value = match.groups()
groups.append(FilterGroup(
filter=ParsedFilter(
field=field.lower(),
operator=operator,
value=value
),
logical_op=pending_op if groups else LogicalOp.AND # The first condition defaults to AND
))
pending_op = LogicalOp.AND # Reset to the default AND
elif token.startswith('__FILTER_') and token.endswith('__'):
# Restore the placeholder to its original filter condition
try:
idx = int(token[9:-2]) # Extract the index
original_filter = filters_found[idx]
match = cls.FILTER_PATTERN.match(original_filter)
if match:
field, operator, value = match.groups()
groups.append(FilterGroup(
filter=ParsedFilter(
field=field.lower(),
operator=operator,
value=value
),
logical_op=pending_op if groups else LogicalOp.AND
))
pending_op = LogicalOp.AND # Reset to the default AND
except (ValueError, IndexError):
pass
# Any other token is ignored (invalid input)
return groups
@@ -151,6 +177,21 @@ class QueryBuilder:
json_array_fields = json_array_fields or []
# 收集需要 annotate 的数组模糊搜索字段
array_fuzzy_fields = set()
# 第一遍:检查是否有数组模糊匹配
for group in filter_groups:
f = group.filter
db_field = field_mapping.get(f.field)
if db_field and db_field in json_array_fields and f.operator == '=':
array_fuzzy_fields.add(db_field)
# Annotate each array fuzzy-search field
for field in array_fuzzy_fields:
annotate_name = f'{field}_text'
queryset = queryset.annotate(**{annotate_name: ArrayToString(F(field))})
# Build the Q objects
combined_q = None
@@ -187,8 +228,17 @@ class QueryBuilder:
def _build_single_q(cls, field: str, operator: str, value: str, is_json_array: bool = False) -> Optional[Q]:
"""Build the Q object for a single condition"""
if is_json_array:
# JSON array fields use a __contains query
return Q(**{f'{field}__contains': [value]})
if operator == '==':
# Exact match: the array contains an element exactly equal to value
return Q(**{f'{field}__contains': [value]})
elif operator == '!=':
# Not-equal: the array contains no element exactly equal to value
return ~Q(**{f'{field}__contains': [value]})
else: # '=' fuzzy match
# Fuzzy-search against the annotated field; build_query has already
# converted it to text via ArrayToString
annotate_name = f'{field}_text'
return Q(**{f'{annotate_name}__icontains': value})
if operator == '!=':
return cls._build_not_equal_q(field, value)
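Taken together, the three array operators map onto the ORM roughly as follows (an illustrative sketch, assuming a WebSite model whose tech column is an ArrayField of CharField; tech_text is the annotation added in build_query above):

from django.db.models import F, Q
# ArrayToString is the Func defined earlier in this file; WebSite is assumed
# to be the asset model with tech = ArrayField(CharField(...)).
from apps.asset.models import WebSite

# tech=="nginx"  -> array contains an element exactly equal to "nginx"
q_exact = Q(tech__contains=["nginx"])
# tech!="nginx"  -> array contains no element exactly equal to "nginx"
q_not = ~Q(tech__contains=["nginx"])
# tech="ngin"    -> fuzzy match over the comma-joined array text
qs = WebSite.objects.annotate(tech_text=ArrayToString(F("tech")))
q_fuzzy = Q(tech_text__icontains="ngin")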

View File

@@ -2,11 +2,17 @@
Common module view exports
Includes:
- Health check view (Docker health check)
- Authentication views: login, logout, user info, change password
- System log views: real-time log viewing
"""
from .health_views import HealthCheckView
from .auth_views import LoginView, LogoutView, MeView, ChangePasswordView
from .system_log_views import SystemLogsView, SystemLogFilesView
__all__ = ['LoginView', 'LogoutView', 'MeView', 'ChangePasswordView', 'SystemLogsView', 'SystemLogFilesView']
__all__ = [
'HealthCheckView',
'LoginView', 'LogoutView', 'MeView', 'ChangePasswordView',
'SystemLogsView', 'SystemLogFilesView',
]

View File

@@ -9,7 +9,7 @@ from django.utils.decorators import method_decorator
from rest_framework import status
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework.permissions import AllowAny
from apps.common.response_helpers import success_response, error_response
from apps.common.error_codes import ErrorCodes
@@ -134,30 +134,10 @@ class ChangePasswordView(APIView):
Change password
POST /api/auth/change-password/
"""
authentication_classes = [] # disable authentication (bypasses CSRF)
permission_classes = [AllowAny] # login state checked manually
def post(self, request):
# Check the login state manually (fetch the user from the session)
from django.contrib.auth import get_user_model
User = get_user_model()
user_id = request.session.get('_auth_user_id')
if not user_id:
return error_response(
code=ErrorCodes.UNAUTHORIZED,
message='Please login first',
status_code=status.HTTP_401_UNAUTHORIZED
)
try:
user = User.objects.get(pk=user_id)
except User.DoesNotExist:
return error_response(
code=ErrorCodes.UNAUTHORIZED,
message='User does not exist',
status_code=status.HTTP_401_UNAUTHORIZED
)
# Validated by the global permission class; request.user is already an authenticated user
user = request.user
# CamelCaseParser maps oldPassword -> old_password
old_password = request.data.get('old_password')

View File

@@ -0,0 +1,24 @@
"""
Health check view
Provides the Docker health check endpoint; no authentication required.
"""
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework.permissions import AllowAny
class HealthCheckView(APIView):
"""
Health check endpoint
GET /api/health/
Returns the service status, used by the Docker health check.
This endpoint requires no authentication.
"""
permission_classes = [AllowAny]
def get(self, request):
return Response({'status': 'ok'})
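A quick smoke test of the endpoint, sketched with DRF's test client and assuming the health/ route registered above:

from rest_framework.test import APIClient

client = APIClient()
resp = client.get('/api/health/')
assert resp.status_code == 200
assert resp.data == {'status': 'ok'}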

View File

@@ -9,7 +9,6 @@ import logging
from django.utils.decorators import method_decorator
from django.views.decorators.csrf import csrf_exempt
from rest_framework import status
from rest_framework.permissions import AllowAny
from rest_framework.response import Response
from rest_framework.views import APIView
@@ -42,9 +41,6 @@ class SystemLogFilesView(APIView):
]
}
"""
authentication_classes = []
permission_classes = [AllowAny]
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -80,15 +76,7 @@ class SystemLogsView(APIView):
{
"content": "日志内容字符串..."
}
Note:
- Currently in development; anonymous access is temporarily allowed
- Production should add admin permission verification
"""
# TODO: switch to the IsAdminUser permission in production
authentication_classes = []
permission_classes = [AllowAny]
def __init__(self, **kwargs):
super().__init__(**kwargs)

View File

@@ -0,0 +1,44 @@
"""
WebSocket authentication base class
Provides a base class for WebSocket consumers that require authentication
"""
import logging
from channels.generic.websocket import AsyncWebsocketConsumer
logger = logging.getLogger(__name__)
class AuthenticatedWebsocketConsumer(AsyncWebsocketConsumer):
"""
Base class for WebSocket consumers that require authentication
Subclasses should override on_connect() to implement their connection logic
"""
async def connect(self):
"""
Verify the user's authentication state on connect
Unauthenticated connections are rejected with close(code=4001)
"""
user = self.scope.get('user')
if not user or not user.is_authenticated:
logger.warning(
f"WebSocket 连接被拒绝:用户未认证 - Path: {self.scope.get('path')}"
)
await self.close(code=4001)
return
# Hand off to the subclass connection logic
await self.on_connect()
async def on_connect(self):
"""
Subclasses implement their connection logic here
Default implementation: accept the connection
"""
await self.accept()
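For reference, a minimal subclass of this base might look like the following (illustrative only; the real consumers are migrated in later hunks):

class EchoConsumer(AuthenticatedWebsocketConsumer):
    async def on_connect(self):
        # connect() has already rejected unauthenticated users with 4001
        await self.accept()

    async def receive(self, text_data=None, bytes_data=None):
        # Echo the payload back to the authenticated client
        await self.send(text_data=text_data)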

View File

@@ -6,17 +6,17 @@ import json
import logging
import asyncio
import os
from channels.generic.websocket import AsyncWebsocketConsumer
from asgiref.sync import sync_to_async
from django.conf import settings
from apps.common.websocket_auth import AuthenticatedWebsocketConsumer
from apps.engine.services import WorkerService
logger = logging.getLogger(__name__)
class WorkerDeployConsumer(AsyncWebsocketConsumer):
class WorkerDeployConsumer(AuthenticatedWebsocketConsumer):
"""
Worker interactive terminal WebSocket Consumer
@@ -31,8 +31,8 @@ class WorkerDeployConsumer(AsyncWebsocketConsumer):
self.read_task = None
self.worker_service = WorkerService()
async def connect(self):
"""连接时加入对应 Worker 的组并自动建立 SSH 连接"""
async def on_connect(self):
"""连接时加入对应 Worker 的组并自动建立 SSH 连接(已通过认证)"""
self.worker_id = self.scope['url_route']['kwargs']['worker_id']
self.group_name = f'worker_deploy_{self.worker_id}'

View File

@@ -15,9 +15,10 @@
"""
from django.core.management.base import BaseCommand
from io import StringIO
from pathlib import Path
import yaml
from ruamel.yaml import YAML
from apps.engine.models import ScanEngine
@@ -44,10 +45,12 @@ class Command(BaseCommand):
with open(config_path, 'r', encoding='utf-8') as f:
default_config = f.read()
# Parse the YAML into a dict, later used to generate the per-stage engine configs
# Parse with ruamel.yaml to preserve comments
yaml_parser = YAML()
yaml_parser.preserve_quotes = True
try:
config_dict = yaml.safe_load(default_config) or {}
except yaml.YAMLError as e:
config_dict = yaml_parser.load(default_config) or {}
except Exception as e:
self.stdout.write(self.style.ERROR(f'引擎配置 YAML 解析失败: {e}'))
return
@@ -83,15 +86,13 @@ class Command(BaseCommand):
if scan_type != 'subdomain_discovery' and 'tools' not in scan_cfg:
continue
# Build a YAML doc containing only the current scan type's config
# Build a YAML doc containing only the current scan type's config (comments preserved)
single_config = {scan_type: scan_cfg}
try:
single_yaml = yaml.safe_dump(
single_config,
sort_keys=False,
allow_unicode=True,
)
except yaml.YAMLError as e:
stream = StringIO()
yaml_parser.dump(single_config, stream)
single_yaml = stream.getvalue()
except Exception as e:
self.stdout.write(self.style.ERROR(f'生成子引擎 {scan_type} 配置失败: {e}'))
continue
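The practical difference is easy to demonstrate: yaml.safe_dump discards comments, while a ruamel.yaml round trip keeps them. A minimal sketch:

from io import StringIO
from ruamel.yaml import YAML

doc = (
    "site_scan:\n"
    "  tools:\n"
    "    httpx:\n"
    "      enabled: true  # keep this inline comment\n"
)

yaml_parser = YAML()
yaml_parser.preserve_quotes = True
data = yaml_parser.load(doc)

stream = StringIO()
yaml_parser.dump(data, stream)
print(stream.getvalue())  # the inline comment survives the round trip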

View File

@@ -0,0 +1,213 @@
# Generated by Django 5.2.7 on 2026-01-02 04:45
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
]
operations = [
migrations.CreateModel(
name='NucleiTemplateRepo',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(help_text='仓库名称,用于前端展示和配置引用', max_length=200, unique=True)),
('repo_url', models.CharField(help_text='Git 仓库地址', max_length=500)),
('local_path', models.CharField(blank=True, default='', help_text='本地工作目录绝对路径', max_length=500)),
('commit_hash', models.CharField(blank=True, default='', help_text='最后同步的 Git commit hash用于 Worker 版本校验', max_length=40)),
('last_synced_at', models.DateTimeField(blank=True, help_text='最后一次成功同步时间', null=True)),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('updated_at', models.DateTimeField(auto_now=True, help_text='更新时间')),
],
options={
'verbose_name': 'Nuclei 模板仓库',
'verbose_name_plural': 'Nuclei 模板仓库',
'db_table': 'nuclei_template_repo',
},
),
migrations.CreateModel(
name='ARLFingerprint',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(help_text='指纹名称', max_length=300, unique=True)),
('rule', models.TextField(help_text='匹配规则表达式')),
('created_at', models.DateTimeField(auto_now_add=True)),
],
options={
'verbose_name': 'ARL 指纹',
'verbose_name_plural': 'ARL 指纹',
'db_table': 'arl_fingerprint',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['name'], name='arl_fingerp_name_c3a305_idx'), models.Index(fields=['-created_at'], name='arl_fingerp_created_ed1060_idx')],
},
),
migrations.CreateModel(
name='EholeFingerprint',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('cms', models.CharField(help_text='产品/CMS名称', max_length=200)),
('method', models.CharField(default='keyword', help_text='匹配方式', max_length=200)),
('location', models.CharField(default='body', help_text='匹配位置', max_length=200)),
('keyword', models.JSONField(default=list, help_text='关键词列表')),
('is_important', models.BooleanField(default=False, help_text='是否重点资产')),
('type', models.CharField(blank=True, default='-', help_text='分类', max_length=100)),
('created_at', models.DateTimeField(auto_now_add=True)),
],
options={
'verbose_name': 'EHole 指纹',
'verbose_name_plural': 'EHole 指纹',
'db_table': 'ehole_fingerprint',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['cms'], name='ehole_finge_cms_72ca2c_idx'), models.Index(fields=['method'], name='ehole_finge_method_17f0db_idx'), models.Index(fields=['location'], name='ehole_finge_locatio_7bb82b_idx'), models.Index(fields=['type'], name='ehole_finge_type_ca2bce_idx'), models.Index(fields=['is_important'], name='ehole_finge_is_impo_d56e64_idx'), models.Index(fields=['-created_at'], name='ehole_finge_created_d862b0_idx')],
'constraints': [models.UniqueConstraint(fields=('cms', 'method', 'location'), name='unique_ehole_fingerprint')],
},
),
migrations.CreateModel(
name='FingerPrintHubFingerprint',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('fp_id', models.CharField(help_text='指纹ID', max_length=200, unique=True)),
('name', models.CharField(help_text='指纹名称', max_length=300)),
('author', models.CharField(blank=True, default='', help_text='作者', max_length=200)),
('tags', models.CharField(blank=True, default='', help_text='标签', max_length=500)),
('severity', models.CharField(blank=True, default='info', help_text='严重程度', max_length=50)),
('metadata', models.JSONField(blank=True, default=dict, help_text='元数据')),
('http', models.JSONField(default=list, help_text='HTTP 匹配规则')),
('source_file', models.CharField(blank=True, default='', help_text='来源文件', max_length=500)),
('created_at', models.DateTimeField(auto_now_add=True)),
],
options={
'verbose_name': 'FingerPrintHub 指纹',
'verbose_name_plural': 'FingerPrintHub 指纹',
'db_table': 'fingerprinthub_fingerprint',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['fp_id'], name='fingerprint_fp_id_df467f_idx'), models.Index(fields=['name'], name='fingerprint_name_95b6fb_idx'), models.Index(fields=['author'], name='fingerprint_author_80f54b_idx'), models.Index(fields=['severity'], name='fingerprint_severit_f70422_idx'), models.Index(fields=['-created_at'], name='fingerprint_created_bec16c_idx')],
},
),
migrations.CreateModel(
name='FingersFingerprint',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(help_text='指纹名称', max_length=300, unique=True)),
('link', models.URLField(blank=True, default='', help_text='相关链接', max_length=500)),
('rule', models.JSONField(default=list, help_text='匹配规则数组')),
('tag', models.JSONField(default=list, help_text='标签数组')),
('focus', models.BooleanField(default=False, help_text='是否重点关注')),
('default_port', models.JSONField(blank=True, default=list, help_text='默认端口数组')),
('created_at', models.DateTimeField(auto_now_add=True)),
],
options={
'verbose_name': 'Fingers 指纹',
'verbose_name_plural': 'Fingers 指纹',
'db_table': 'fingers_fingerprint',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['name'], name='fingers_fin_name_952de0_idx'), models.Index(fields=['link'], name='fingers_fin_link_4c6b7f_idx'), models.Index(fields=['focus'], name='fingers_fin_focus_568c7f_idx'), models.Index(fields=['-created_at'], name='fingers_fin_created_46fc91_idx')],
},
),
migrations.CreateModel(
name='GobyFingerprint',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(help_text='产品名称', max_length=300, unique=True)),
('logic', models.CharField(help_text='逻辑表达式', max_length=500)),
('rule', models.JSONField(default=list, help_text='规则数组')),
('created_at', models.DateTimeField(auto_now_add=True)),
],
options={
'verbose_name': 'Goby 指纹',
'verbose_name_plural': 'Goby 指纹',
'db_table': 'goby_fingerprint',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['name'], name='goby_finger_name_82084c_idx'), models.Index(fields=['logic'], name='goby_finger_logic_a63226_idx'), models.Index(fields=['-created_at'], name='goby_finger_created_50e000_idx')],
},
),
migrations.CreateModel(
name='ScanEngine',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('name', models.CharField(help_text='引擎名称', max_length=200, unique=True)),
('configuration', models.CharField(blank=True, default='', help_text='引擎配置(yaml 格式)', max_length=10000)),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('updated_at', models.DateTimeField(auto_now=True, help_text='更新时间')),
],
options={
'verbose_name': '扫描引擎',
'verbose_name_plural': '扫描引擎',
'db_table': 'scan_engine',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['-created_at'], name='scan_engine_created_da4870_idx')],
},
),
migrations.CreateModel(
name='WappalyzerFingerprint',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(help_text='应用名称', max_length=300, unique=True)),
('cats', models.JSONField(default=list, help_text='分类 ID 数组')),
('cookies', models.JSONField(blank=True, default=dict, help_text='Cookie 检测规则')),
('headers', models.JSONField(blank=True, default=dict, help_text='HTTP Header 检测规则')),
('script_src', models.JSONField(blank=True, default=list, help_text='脚本 URL 正则数组')),
('js', models.JSONField(blank=True, default=list, help_text='JavaScript 变量检测规则')),
('implies', models.JSONField(blank=True, default=list, help_text='依赖关系数组')),
('meta', models.JSONField(blank=True, default=dict, help_text='HTML meta 标签检测规则')),
('html', models.JSONField(blank=True, default=list, help_text='HTML 内容正则数组')),
('description', models.TextField(blank=True, default='', help_text='应用描述')),
('website', models.URLField(blank=True, default='', help_text='官网链接', max_length=500)),
('cpe', models.CharField(blank=True, default='', help_text='CPE 标识符', max_length=300)),
('created_at', models.DateTimeField(auto_now_add=True)),
],
options={
'verbose_name': 'Wappalyzer 指纹',
'verbose_name_plural': 'Wappalyzer 指纹',
'db_table': 'wappalyzer_fingerprint',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['name'], name='wappalyzer__name_63c669_idx'), models.Index(fields=['website'], name='wappalyzer__website_88de1c_idx'), models.Index(fields=['cpe'], name='wappalyzer__cpe_30c761_idx'), models.Index(fields=['-created_at'], name='wappalyzer__created_8e6c21_idx')],
},
),
migrations.CreateModel(
name='Wordlist',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('name', models.CharField(help_text='字典名称,唯一', max_length=200, unique=True)),
('description', models.CharField(blank=True, default='', help_text='字典描述', max_length=200)),
('file_path', models.CharField(help_text='后端保存的字典文件绝对路径', max_length=500)),
('file_size', models.BigIntegerField(default=0, help_text='文件大小(字节)')),
('line_count', models.IntegerField(default=0, help_text='字典行数')),
('file_hash', models.CharField(blank=True, default='', help_text='文件 SHA-256 哈希,用于缓存校验', max_length=64)),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('updated_at', models.DateTimeField(auto_now=True, help_text='更新时间')),
],
options={
'verbose_name': '字典文件',
'verbose_name_plural': '字典文件',
'db_table': 'wordlist',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['-created_at'], name='wordlist_created_4afb02_idx')],
},
),
migrations.CreateModel(
name='WorkerNode',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(help_text='节点名称', max_length=100)),
('ip_address', models.GenericIPAddressField(help_text='IP 地址(本地节点为 127.0.0.1)')),
('ssh_port', models.IntegerField(default=22, help_text='SSH 端口')),
('username', models.CharField(default='root', help_text='SSH 用户名', max_length=50)),
('password', models.CharField(blank=True, default='', help_text='SSH 密码', max_length=200)),
('is_local', models.BooleanField(default=False, help_text='是否为本地节点(Docker 容器内)')),
('status', models.CharField(choices=[('pending', '待部署'), ('deploying', '部署中'), ('online', '在线'), ('offline', '离线'), ('updating', '更新中'), ('outdated', '版本过低')], default='pending', help_text='状态: pending/deploying/online/offline', max_length=20)),
('created_at', models.DateTimeField(auto_now_add=True)),
('updated_at', models.DateTimeField(auto_now=True)),
],
options={
'verbose_name': 'Worker 节点',
'db_table': 'worker_node',
'ordering': ['-created_at'],
'constraints': [models.UniqueConstraint(condition=models.Q(('is_local', False)), fields=('ip_address',), name='unique_remote_worker_ip'), models.UniqueConstraint(fields=('name',), name='unique_worker_name')],
},
),
]

View File

@@ -88,6 +88,8 @@ def _register_scheduled_jobs(scheduler: BackgroundScheduler):
replace_existing=True,
)
logger.info(" - 已注册: 扫描结果清理(每天 03:00)")
# Note: the search materialized view refresh has moved to pg_ivm incremental maintenance; no scheduled job is needed
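For context, a pg_ivm view is created once with create_immv and then maintained incrementally by the extension itself, which is why no refresh job remains. A rough sketch (the view name and column list here are assumptions; the actual definition lives in the search-view migration):

from django.db import connection

with connection.cursor() as cursor:
    cursor.execute("""
        SELECT create_immv(
            'asset_search_view',
            'SELECT id, host, title, url, tech, status_code, created_at
               FROM website'
        )
    """)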
def _trigger_scheduled_scans():

View File

@@ -66,6 +66,7 @@ def get_start_agent_script(
# Substitute the variables
script = script.replace("{{HEARTBEAT_API_URL}}", heartbeat_api_url or '')
script = script.replace("{{WORKER_ID}}", str(worker_id) if worker_id else '')
script = script.replace("{{WORKER_API_KEY}}", getattr(settings, 'WORKER_API_KEY', ''))
# Inject the image version config (ensures remote nodes use the same version)
docker_user = getattr(settings, 'DOCKER_USER', 'yyhuni')

View File

@@ -264,10 +264,6 @@ class TaskDistributor:
"""
import shlex
# Docker API version configuration (avoids version incompatibilities)
# Default to 1.40 for maximum compatibility (supports Docker 19.03+)
api_version = getattr(settings, 'DOCKER_API_VERSION', '1.40')
# Pick the network and Server address based on the worker type
if worker.is_local:
# Local: join the Docker network and use the internal service name
@@ -288,6 +284,7 @@ class TaskDistributor:
env_vars = [
f"-e SERVER_URL={shlex.quote(server_url)}",
f"-e IS_LOCAL={is_local_str}",
f"-e WORKER_API_KEY={shlex.quote(settings.WORKER_API_KEY)}", # Worker API 认证密钥
"-e PREFECT_HOME=/tmp/.prefect", # 设置 Prefect 数据目录到可写位置
"-e PREFECT_SERVER_EPHEMERAL_ENABLED=true", # 启用 ephemeral server本地临时服务器
"-e PREFECT_SERVER_EPHEMERAL_STARTUP_TIMEOUT_SECONDS=120", # 增加启动超时时间
@@ -315,9 +312,7 @@ class TaskDistributor:
# - Local workers: install.sh already pre-pulled the image, so the local copy is used
# - Remote workers: deploy already pre-pulled the image, so the local copy is used
# - Avoids checking Docker Hub on every task, improving performance and stability
# Wrap the sh -c command in double quotes so the single-quoted args produced by shlex.quote parse correctly
# The DOCKER_API_VERSION env var keeps the client and server API versions compatible
cmd = f'''DOCKER_API_VERSION={api_version} docker run --rm -d --pull=missing {network_arg} \\
cmd = f'''docker run --rm -d --pull=missing {network_arg} \\
{' '.join(env_vars)} \\
{' '.join(volumes)} \\
{self.docker_image} \\

View File

@@ -340,13 +340,12 @@ class WorkerNodeViewSet(viewsets.ModelViewSet):
Returns:
{
"db": {"host": "...", "port": "...", ...},
"redisUrl": "...",
"paths": {"results": "...", "logs": "..."}
}
Configuration logic:
- Local worker (is_local=true): db_host=postgres, redis=redis:6379
- Remote worker (is_local=false): db_host=PUBLIC_HOST, redis=PUBLIC_HOST:6379
- Local worker (is_local=true): db_host=postgres
- Remote worker (is_local=false): db_host=PUBLIC_HOST
"""
from django.conf import settings
import logging
@@ -371,20 +370,17 @@ class WorkerNodeViewSet(viewsets.ModelViewSet):
if is_local_worker:
# Local worker: use the Docker-internal service name directly
worker_db_host = 'postgres'
worker_redis_url = 'redis://redis:6379/0'
else:
# Remote worker: access via the public IP
public_host = settings.PUBLIC_HOST
if public_host in ('server', 'localhost', '127.0.0.1'):
logger.warning("远程 Worker 请求配置,但 PUBLIC_HOST=%s 不是有效的公网地址", public_host)
worker_db_host = public_host
worker_redis_url = f'redis://{public_host}:6379/0'
else:
# Remote database scenario: every worker uses DB_HOST
worker_db_host = db_host
worker_redis_url = getattr(settings, 'WORKER_REDIS_URL', 'redis://redis:6379/0')
logger.info("返回 Worker 配置 - db_host: %s, redis_url: %s", worker_db_host, worker_redis_url)
logger.info("返回 Worker 配置 - db_host: %s", worker_db_host)
return success_response(
data={
@@ -395,7 +391,6 @@ class WorkerNodeViewSet(viewsets.ModelViewSet):
'user': settings.DATABASES['default']['USER'],
'password': settings.DATABASES['default']['PASSWORD'],
},
'redisUrl': worker_redis_url,
'paths': {
'results': getattr(settings, 'CONTAINER_RESULTS_MOUNT', '/opt/xingrin/results'),
'logs': getattr(settings, 'CONTAINER_LOGS_MOUNT', '/opt/xingrin/logs'),

View File

@@ -97,9 +97,11 @@ SITE_SCAN_COMMANDS = {
'base': (
"'{scan_tools_base}/httpx' -l '{url_file}' "
'-status-code -content-type -content-length '
'-location -title -server -body-preview '
'-location -title -server '
'-tech-detect -cdn -vhost '
'-random-agent -no-color -json'
'-include-response '
'-rstr 2000 '
'-random-agent -no-color -json -silent'
),
'optional': {
'threads': '-threads {threads}',
@@ -169,9 +171,11 @@ URL_FETCH_COMMANDS = {
'base': (
"'{scan_tools_base}/httpx' -l '{url_file}' "
'-status-code -content-type -content-length '
'-location -title -server -body-preview '
'-location -title -server '
'-tech-detect -cdn -vhost '
'-random-agent -no-color -json'
'-include-response '
'-rstr 2000 '
'-random-agent -no-color -json -silent'
),
'optional': {
'threads': '-threads {threads}',

View File

@@ -4,14 +4,12 @@
# Required parameter: enabled (whether the tool is on)
# Optional parameter: timeout (timeout in seconds; default auto = computed automatically)
# ==================== Subdomain Discovery ====================
#
# Stage 1: passive collection (parallel) - required; enable at least one tool
# Stage 2: wordlist brute force (optional) - brute-force enumerate subdomains with a wordlist
# Stage 3: permutation + validation (optional) - generate mutations from discovered domains and stream-validate liveness
# Stage 4: DNS resolution (optional) - verify that every candidate domain resolves
#
subdomain_discovery:
# ==================== Subdomain Discovery ====================
# Stage 1: passive collection (parallel) - required; enable at least one tool
# Stage 2: wordlist brute force (optional) - brute-force enumerate subdomains with a wordlist
# Stage 3: permutation + validation (optional) - generate mutations from discovered domains and stream-validate liveness
# Stage 4: DNS resolution (optional) - verify that every candidate domain resolves
# === Stage 1: passive collection tools (run in parallel) ===
passive_tools:
subfinder:
@@ -55,8 +53,8 @@ subdomain_discovery:
subdomain_resolve:
timeout: auto # computed automatically from the candidate subdomain count
# ==================== Port Scan ====================
port_scan:
# ==================== Port Scan ====================
tools:
naabu_active:
enabled: true
@@ -70,8 +68,8 @@ port_scan:
enabled: true
# timeout: auto # passive scans are usually fast
# ==================== Site Scan ====================
site_scan:
# ==================== Site Scan ====================
tools:
httpx:
enabled: true
@@ -81,16 +79,16 @@ site_scan:
# request-timeout: 10 # per-request timeout in seconds (default 10)
# retries: 2 # retry count for failed requests
# ==================== Fingerprint Detection ====================
# Runs serially after site_scan to identify each WebSite's tech stack
fingerprint_detect:
# ==================== Fingerprint Detection ====================
# Runs serially after the site scan stage to identify each WebSite's tech stack
tools:
xingfinger:
enabled: true
fingerprint-libs: [ehole, goby, wappalyzer] # enabled fingerprint libraries: ehole, goby, wappalyzer, fingers, fingerprinthub, arl
fingerprint-libs: [ehole, goby, wappalyzer, fingers, fingerprinthub, arl] # all fingerprint libraries enabled by default
# ==================== Directory Scan ====================
directory_scan:
# ==================== Directory Scan ====================
tools:
ffuf:
enabled: true
@@ -103,8 +101,8 @@ directory_scan:
match-codes: 200,201,301,302,401,403 # HTTP status codes to match
# rate: 0 # requests per second (default 0 = unlimited)
# ==================== URL Fetch ====================
url_fetch:
# ==================== URL Fetch ====================
tools:
waymore:
enabled: true
@@ -142,8 +140,8 @@ url_fetch:
# request-timeout: 10 # per-request timeout in seconds (default 10)
# retries: 2 # retry count for failed requests
# ==================== Vulnerability Scan ====================
vuln_scan:
# ==================== Vulnerability Scan ====================
tools:
dalfox_xss:
enabled: true

View File

@@ -37,28 +37,24 @@ logger = logging.getLogger(__name__)
def calculate_fingerprint_detect_timeout(
url_count: int,
base_per_url: float = 3.0,
min_timeout: int = 60
base_per_url: float = 10.0,
min_timeout: int = 300
) -> int:
"""
Compute the timeout from the URL count
Formula: timeout = URL count × base time per URL
Minimum: 60 seconds
Minimum: 300 seconds
No upper bound
Args:
url_count: number of URLs
base_per_url: base time per URL in seconds, default 3
min_timeout: minimum timeout in seconds, default 60
base_per_url: base time per URL in seconds, default 10
min_timeout: minimum timeout in seconds, default 300
Returns:
int: the computed timeout in seconds
Examples:
100 URLs × 3s = 300s
1000 URLs × 3s = 3000s (50 minutes)
10000 URLs × 3s = 30000s (8.3 hours)
"""
timeout = int(url_count * base_per_url)
return max(min_timeout, timeout)
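With the new defaults this works out to:

calculate_fingerprint_detect_timeout(10)     # 300   (10 × 10 = 100, the 300s floor wins)
calculate_fingerprint_detect_timeout(100)    # 1000  (100 × 10)
calculate_fingerprint_detect_timeout(1000)   # 10000 (~2.8 hours)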
@@ -260,7 +256,8 @@ def fingerprint_detect_flow(
'url_count': int,
'processed_records': int,
'updated_count': int,
'not_found_count': int,
'created_count': int,
'snapshot_count': int,
'executed_tasks': list,
'tool_stats': dict
}
@@ -307,6 +304,7 @@ def fingerprint_detect_flow(
'processed_records': 0,
'updated_count': 0,
'created_count': 0,
'snapshot_count': 0,
'executed_tasks': ['export_urls_for_fingerprint'],
'tool_stats': {
'total': 0,
@@ -344,6 +342,7 @@ def fingerprint_detect_flow(
total_processed = sum(stats['result'].get('processed_records', 0) for stats in tool_stats.values())
total_updated = sum(stats['result'].get('updated_count', 0) for stats in tool_stats.values())
total_created = sum(stats['result'].get('created_count', 0) for stats in tool_stats.values())
total_snapshots = sum(stats['result'].get('snapshot_count', 0) for stats in tool_stats.values())
successful_tools = [name for name in enabled_tools.keys()
if name not in [f['tool'] for f in failed_tools]]
@@ -358,6 +357,7 @@ def fingerprint_detect_flow(
'processed_records': total_processed,
'updated_count': total_updated,
'created_count': total_created,
'snapshot_count': total_snapshots,
'executed_tasks': executed_tasks,
'tool_stats': {
'total': len(enabled_tools),

View File

@@ -114,8 +114,11 @@ def initiate_scan_flow(
# ==================== Task 2: fetch the engine configuration ====================
from apps.scan.models import Scan
scan = Scan.objects.select_related('engine').get(id=scan_id)
engine_config = scan.engine.configuration
scan = Scan.objects.get(id=scan_id)
engine_config = scan.merged_configuration
# Use engine_names for display
display_engine_name = ', '.join(scan.engine_names) if scan.engine_names else engine_name
# ==================== Task 3: parse the config and build the execution plan ====================
orchestrator = FlowOrchestrator(engine_config)

View File

@@ -204,14 +204,13 @@ def _run_scans_sequentially(
# Stream the scan and save results in real time
result = run_and_stream_save_websites_task(
cmd=command,
tool_name=tool_name, # new: tool name
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
cwd=str(site_scan_dir),
shell=True,
batch_size=1000,
timeout=timeout,
log_file=str(log_file) # new: log file path
log_file=str(log_file)
)
tool_stats[tool_name] = {

View File

@@ -212,7 +212,6 @@ def _validate_and_stream_save_urls(
target_id=target_id,
cwd=str(url_fetch_dir),
shell=True,
batch_size=500,
timeout=timeout,
log_file=str(log_file)
)

View File

@@ -162,6 +162,8 @@ def on_initiate_scan_flow_completed(flow: Flow, flow_run: FlowRun, state: State)
# Update the status and collect the statistics
stats = _update_completed_status()
# Note: the materialized view refresh has moved to pg_ivm incremental maintenance; no manual refresh flag is needed
# Send the notification (includes the stats summary)
logger.info("准备发送扫描完成通知 - Scan ID: %s, Target: %s", scan_id, target_name)
try:

View File

@@ -0,0 +1,119 @@
# Generated by Django 5.2.7 on 2026-01-02 04:45
import django.contrib.postgres.fields
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
('engine', '0001_initial'),
('targets', '0001_initial'),
]
operations = [
migrations.CreateModel(
name='NotificationSettings',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('discord_enabled', models.BooleanField(default=False, help_text='是否启用 Discord 通知')),
('discord_webhook_url', models.URLField(blank=True, default='', help_text='Discord Webhook URL')),
('categories', models.JSONField(default=dict, help_text='各分类通知开关,如 {"scan": true, "vulnerability": true, "asset": true, "system": false}')),
('created_at', models.DateTimeField(auto_now_add=True)),
('updated_at', models.DateTimeField(auto_now=True)),
],
options={
'verbose_name': '通知设置',
'verbose_name_plural': '通知设置',
'db_table': 'notification_settings',
},
),
migrations.CreateModel(
name='Notification',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('category', models.CharField(choices=[('scan', '扫描任务'), ('vulnerability', '漏洞发现'), ('asset', '资产发现'), ('system', '系统消息')], db_index=True, default='system', help_text='通知分类', max_length=20)),
('level', models.CharField(choices=[('low', '低'), ('medium', '中'), ('high', '高'), ('critical', '严重')], db_index=True, default='low', help_text='通知级别', max_length=20)),
('title', models.CharField(help_text='通知标题', max_length=200)),
('message', models.CharField(help_text='通知内容', max_length=2000)),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('is_read', models.BooleanField(default=False, help_text='是否已读')),
('read_at', models.DateTimeField(blank=True, help_text='阅读时间', null=True)),
],
options={
'verbose_name': '通知',
'verbose_name_plural': '通知',
'db_table': 'notification',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['-created_at'], name='notificatio_created_c430f0_idx'), models.Index(fields=['category', '-created_at'], name='notificatio_categor_df0584_idx'), models.Index(fields=['level', '-created_at'], name='notificatio_level_0e5d12_idx'), models.Index(fields=['is_read', '-created_at'], name='notificatio_is_read_518ce0_idx')],
},
),
migrations.CreateModel(
name='Scan',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('engine_ids', django.contrib.postgres.fields.ArrayField(base_field=models.IntegerField(), default=list, help_text='引擎 ID 列表', size=None)),
('engine_names', models.JSONField(default=list, help_text='引擎名称列表,如 ["引擎A", "引擎B"]')),
('merged_configuration', models.TextField(default='', help_text='合并后的 YAML 配置')),
('created_at', models.DateTimeField(auto_now_add=True, help_text='任务创建时间')),
('stopped_at', models.DateTimeField(blank=True, help_text='扫描结束时间', null=True)),
('status', models.CharField(choices=[('cancelled', '已取消'), ('completed', '已完成'), ('failed', '失败'), ('initiated', '初始化'), ('running', '运行中')], db_index=True, default='initiated', help_text='任务状态', max_length=20)),
('results_dir', models.CharField(blank=True, default='', help_text='结果存储目录', max_length=100)),
('container_ids', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='容器 ID 列表Docker Container ID', size=None)),
('error_message', models.CharField(blank=True, default='', help_text='错误信息', max_length=2000)),
('deleted_at', models.DateTimeField(blank=True, db_index=True, help_text='删除时间(NULL 表示未删除)', null=True)),
('progress', models.IntegerField(default=0, help_text='扫描进度 0-100')),
('current_stage', models.CharField(blank=True, default='', help_text='当前扫描阶段', max_length=50)),
('stage_progress', models.JSONField(default=dict, help_text='各阶段进度详情')),
('cached_subdomains_count', models.IntegerField(default=0, help_text='缓存的子域名数量')),
('cached_websites_count', models.IntegerField(default=0, help_text='缓存的网站数量')),
('cached_endpoints_count', models.IntegerField(default=0, help_text='缓存的端点数量')),
('cached_ips_count', models.IntegerField(default=0, help_text='缓存的IP地址数量')),
('cached_directories_count', models.IntegerField(default=0, help_text='缓存的目录数量')),
('cached_vulns_total', models.IntegerField(default=0, help_text='缓存的漏洞总数')),
('cached_vulns_critical', models.IntegerField(default=0, help_text='缓存的严重漏洞数量')),
('cached_vulns_high', models.IntegerField(default=0, help_text='缓存的高危漏洞数量')),
('cached_vulns_medium', models.IntegerField(default=0, help_text='缓存的中危漏洞数量')),
('cached_vulns_low', models.IntegerField(default=0, help_text='缓存的低危漏洞数量')),
('stats_updated_at', models.DateTimeField(blank=True, help_text='统计数据最后更新时间', null=True)),
('target', models.ForeignKey(help_text='扫描目标', on_delete=django.db.models.deletion.CASCADE, related_name='scans', to='targets.target')),
('worker', models.ForeignKey(blank=True, help_text='执行扫描的 Worker 节点', null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='scans', to='engine.workernode')),
],
options={
'verbose_name': '扫描任务',
'verbose_name_plural': '扫描任务',
'db_table': 'scan',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['-created_at'], name='scan_created_0bb6c7_idx'), models.Index(fields=['target'], name='scan_target__718b9d_idx'), models.Index(fields=['deleted_at', '-created_at'], name='scan_deleted_eb17e8_idx')],
},
),
migrations.CreateModel(
name='ScheduledScan',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('name', models.CharField(help_text='任务名称', max_length=200)),
('engine_ids', django.contrib.postgres.fields.ArrayField(base_field=models.IntegerField(), default=list, help_text='引擎 ID 列表', size=None)),
('engine_names', models.JSONField(default=list, help_text='引擎名称列表,如 ["引擎A", "引擎B"]')),
('merged_configuration', models.TextField(default='', help_text='合并后的 YAML 配置')),
('cron_expression', models.CharField(default='0 2 * * *', help_text='Cron 表达式,格式:分 时 日 月 周', max_length=100)),
('is_enabled', models.BooleanField(db_index=True, default=True, help_text='是否启用')),
('run_count', models.IntegerField(default=0, help_text='已执行次数')),
('last_run_time', models.DateTimeField(blank=True, help_text='上次执行时间', null=True)),
('next_run_time', models.DateTimeField(blank=True, help_text='下次执行时间', null=True)),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('updated_at', models.DateTimeField(auto_now=True, help_text='更新时间')),
('organization', models.ForeignKey(blank=True, help_text='扫描组织(设置后执行时动态获取组织下所有目标)', null=True, on_delete=django.db.models.deletion.CASCADE, related_name='scheduled_scans', to='targets.organization')),
('target', models.ForeignKey(blank=True, help_text='扫描单个目标(与 organization 二选一)', null=True, on_delete=django.db.models.deletion.CASCADE, related_name='scheduled_scans', to='targets.target')),
],
options={
'verbose_name': '定时扫描任务',
'verbose_name_plural': '定时扫描任务',
'db_table': 'scheduled_scan',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['-created_at'], name='scheduled_s_created_9b9c2e_idx'), models.Index(fields=['is_enabled', '-created_at'], name='scheduled_s_is_enab_23d660_idx'), models.Index(fields=['name'], name='scheduled_s_name_bf332d_idx')],
},
),
]

View File

@@ -20,11 +20,19 @@ class Scan(models.Model):
target = models.ForeignKey('targets.Target', on_delete=models.CASCADE, related_name='scans', help_text='扫描目标')
engine = models.ForeignKey(
'engine.ScanEngine',
on_delete=models.CASCADE,
related_name='scans',
help_text='使用的扫描引擎'
# Multi-engine support fields
engine_ids = ArrayField(
models.IntegerField(),
default=list,
help_text='引擎 ID 列表'
)
engine_names = models.JSONField(
default=list,
help_text='引擎名称列表,如 ["引擎A", "引擎B"]'
)
merged_configuration = models.TextField(
default='',
help_text='合并后的 YAML 配置'
)
created_at = models.DateTimeField(auto_now_add=True, help_text='任务创建时间')
@@ -118,12 +126,19 @@ class ScheduledScan(models.Model):
# Basic info
name = models.CharField(max_length=200, help_text='任务名称')
# Associated scan engine
engine = models.ForeignKey(
'engine.ScanEngine',
on_delete=models.CASCADE,
related_name='scheduled_scans',
help_text='使用的扫描引擎'
# Multi-engine support fields
engine_ids = ArrayField(
models.IntegerField(),
default=list,
help_text='引擎 ID 列表'
)
engine_names = models.JSONField(
default=list,
help_text='引擎名称列表,如 ["引擎A", "引擎B"]'
)
merged_configuration = models.TextField(
default='',
help_text='合并后的 YAML 配置'
)
# Associated organization (organization scan mode: all targets under the org are fetched dynamically at run time)

View File

@@ -5,12 +5,13 @@ WebSocket Consumer - real-time notification push
import json
import logging
import asyncio
from channels.generic.websocket import AsyncWebsocketConsumer
from apps.common.websocket_auth import AuthenticatedWebsocketConsumer
logger = logging.getLogger(__name__)
class NotificationConsumer(AsyncWebsocketConsumer):
class NotificationConsumer(AuthenticatedWebsocketConsumer):
"""
Notification WebSocket Consumer
@@ -23,9 +24,9 @@ class NotificationConsumer(AsyncWebsocketConsumer):
super().__init__(*args, **kwargs)
self.heartbeat_task = None # heartbeat task
async def connect(self):
async def on_connect(self):
"""
Called when a client connects
Called when a client connects (already authenticated)
Join the notification broadcast group
"""
# Notification group name (shared by all clients)

View File

@@ -305,6 +305,7 @@ def _push_via_api_callback(notification: Notification, server_url: str) -> None:
Calls the Server container's /api/callbacks/notification/ endpoint over HTTP.
Workers cannot reach Redis directly, so the Server pushes the WebSocket message on their behalf.
"""
import os
import requests
try:
@@ -318,8 +319,14 @@ def _push_via_api_callback(notification: Notification, server_url: str) -> None:
'created_at': notification.created_at.isoformat()
}
# Build the request headers (including the Worker API Key)
headers = {'Content-Type': 'application/json'}
worker_api_key = os.environ.get("WORKER_API_KEY", "")
if worker_api_key:
headers["X-Worker-API-Key"] = worker_api_key
# verify=False: remote workers may call back to a Server using a self-signed certificate
resp = requests.post(callback_url, json=data, timeout=5, verify=False)
resp = requests.post(callback_url, json=data, headers=headers, timeout=5, verify=False)
resp.raise_for_status()
logger.debug(f"通知回调推送成功 - ID: {notification.id}")
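On the receiving side, the Server validates this header. A minimal sketch of such a check (illustrative; per the callback-view hunk below, the real enforcement lives in the global IsAuthenticatedOrPublic permission, which this diff does not show):

import os
from rest_framework.permissions import BasePermission

class HasWorkerAPIKey(BasePermission):
    """Allow requests carrying the shared X-Worker-API-Key header."""

    def has_permission(self, request, view):
        expected = os.environ.get('WORKER_API_KEY', '')
        provided = request.headers.get('X-Worker-API-Key', '')
        return bool(expected) and provided == expected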

View File

@@ -7,8 +7,7 @@ from typing import Any
from django.http import JsonResponse
from django.utils import timezone
from rest_framework import status
from rest_framework.decorators import api_view, permission_classes
from rest_framework.permissions import AllowAny
from rest_framework.decorators import api_view
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.views import APIView
@@ -198,12 +197,13 @@ class NotificationSettingsView(APIView):
# ============================================
@api_view(['POST'])
@permission_classes([AllowAny]) # Worker containers are unauthenticated; consider adding token verification
# Permissions handled by the global IsAuthenticatedOrPublic; /api/callbacks/* requires Worker API Key authentication
def notification_callback(request):
"""
Receive notification push requests from workers
Worker containers cannot reach Redis directly; this API callback lets the Server do the WebSocket push.
Requires Worker API Key authentication (X-Worker-API-Key header)
POST /api/callbacks/notification/
{

View File

@@ -16,7 +16,6 @@ from django.utils import timezone
from apps.scan.models import Scan
from apps.targets.models import Target
from apps.engine.models import ScanEngine
from apps.common.definitions import ScanStatus
from apps.common.decorators import auto_ensure_db_connection
@@ -40,7 +39,7 @@ class DjangoScanRepository:
Args:
scan_id: scan task ID
prefetch_relations: whether to preload related objects (engine, target)
prefetch_relations: whether to preload related objects (target, worker)
Defaults to False; set it to True only when related info must be displayed
for_update: whether to lock the row (for update scenarios)
@@ -56,7 +55,7 @@ class DjangoScanRepository:
# Preload related objects (performance optimization: off by default)
if prefetch_relations:
queryset = queryset.select_related('engine', 'target')
queryset = queryset.select_related('target', 'worker')
return queryset.get(id=scan_id)
except Scan.DoesNotExist: # type: ignore # pylint: disable=no-member
@@ -79,7 +78,7 @@ class DjangoScanRepository:
Note:
- Uses the default blocking mode (waits for the lock to be released)
- Does not include related objects (engine, target); use get_by_id() if you need them
- Does not include related objects (target, worker); use get_by_id() if you need them
"""
try:
return Scan.objects.select_for_update().get(id=scan_id) # type: ignore # pylint: disable=no-member
@@ -103,7 +102,9 @@ class DjangoScanRepository:
def create(self,
target: Target,
engine: ScanEngine,
engine_ids: List[int],
engine_names: List[str],
merged_configuration: str,
results_dir: str,
status: ScanStatus = ScanStatus.INITIATED
) -> Scan:
@@ -112,7 +113,9 @@ class DjangoScanRepository:
Args:
target: scan target
engine: scan engine
engine_ids: list of engine IDs
engine_names: list of engine names
merged_configuration: the merged YAML configuration
results_dir: results directory
status: initial status
@@ -121,7 +124,9 @@ class DjangoScanRepository:
"""
scan = Scan(
target=target,
engine=engine,
engine_ids=engine_ids,
engine_names=engine_names,
merged_configuration=merged_configuration,
results_dir=results_dir,
status=status,
container_ids=[]
@@ -231,14 +236,14 @@ class DjangoScanRepository:
Get all scan tasks
Args:
prefetch_relations: whether to preload related objects (engine, target)
prefetch_relations: whether to preload related objects (target, worker)
Returns:
Scan QuerySet
"""
queryset = Scan.objects.all() # type: ignore # pylint: disable=no-member
if prefetch_relations:
queryset = queryset.select_related('engine', 'target')
queryset = queryset.select_related('target', 'worker')
return queryset.order_by('-created_at')

View File

@@ -29,7 +29,9 @@ class ScheduledScanDTO:
"""
id: Optional[int] = None
name: str = ''
engine_id: int = 0
engine_ids: List[int] = None # multi-engine support
engine_names: List[str] = None # list of engine names
merged_configuration: str = '' # the merged configuration
organization_id: Optional[int] = None # organization scan mode
target_id: Optional[int] = None # target scan mode
cron_expression: Optional[str] = None
@@ -40,6 +42,11 @@ class ScheduledScanDTO:
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
def __post_init__(self):
if self.engine_ids is None:
self.engine_ids = []
if self.engine_names is None:
self.engine_names = []
@auto_ensure_db_connection
@@ -56,7 +63,7 @@ class DjangoScheduledScanRepository:
def get_by_id(self, scheduled_scan_id: int) -> Optional[ScheduledScan]:
"""根据 ID 查询定时扫描任务"""
try:
return ScheduledScan.objects.select_related('engine', 'organization', 'target').get(id=scheduled_scan_id)
return ScheduledScan.objects.select_related('organization', 'target').get(id=scheduled_scan_id)
except ScheduledScan.DoesNotExist:
return None
@@ -67,7 +74,7 @@ class DjangoScheduledScanRepository:
Returns:
QuerySet
"""
return ScheduledScan.objects.select_related('engine', 'organization', 'target').order_by('-created_at')
return ScheduledScan.objects.select_related('organization', 'target').order_by('-created_at')
def get_all(self, page: int = 1, page_size: int = 10) -> Tuple[List[ScheduledScan], int]:
"""
@@ -87,7 +94,7 @@ class DjangoScheduledScanRepository:
def get_enabled(self) -> List[ScheduledScan]:
"""获取所有启用的定时扫描任务"""
return list(
ScheduledScan.objects.select_related('engine', 'target')
ScheduledScan.objects.select_related('target')
.filter(is_enabled=True)
.order_by('-created_at')
)
@@ -105,7 +112,9 @@ class DjangoScheduledScanRepository:
with transaction.atomic():
scheduled_scan = ScheduledScan.objects.create(
name=dto.name,
engine_id=dto.engine_id,
engine_ids=dto.engine_ids,
engine_names=dto.engine_names,
merged_configuration=dto.merged_configuration,
organization_id=dto.organization_id, # organization scan mode
target_id=dto.target_id if not dto.organization_id else None, # target scan mode
cron_expression=dto.cron_expression,
@@ -134,8 +143,12 @@ class DjangoScheduledScanRepository:
# Update the basic fields
if dto.name:
scheduled_scan.name = dto.name
if dto.engine_id:
scheduled_scan.engine_id = dto.engine_id
if dto.engine_ids is not None:
scheduled_scan.engine_ids = dto.engine_ids
if dto.engine_names is not None:
scheduled_scan.engine_names = dto.engine_names
if dto.merged_configuration is not None:
scheduled_scan.merged_configuration = dto.merged_configuration
if dto.cron_expression is not None:
scheduled_scan.cron_expression = dto.cron_expression
if dto.is_enabled is not None:

View File

@@ -7,12 +7,11 @@ from .models import Scan, ScheduledScan
class ScanSerializer(serializers.ModelSerializer):
"""扫描任务序列化器"""
target_name = serializers.SerializerMethodField()
engine_name = serializers.SerializerMethodField()
class Meta:
model = Scan
fields = [
'id', 'target', 'target_name', 'engine', 'engine_name',
'id', 'target', 'target_name', 'engine_ids', 'engine_names',
'created_at', 'stopped_at', 'status', 'results_dir',
'container_ids', 'error_message'
]
@@ -24,10 +23,6 @@ class ScanSerializer(serializers.ModelSerializer):
def get_target_name(self, obj):
"""获取目标名称"""
return obj.target.name if obj.target else None
def get_engine_name(self, obj):
"""获取引擎名称"""
return obj.engine.name if obj.engine else None
class ScanHistorySerializer(serializers.ModelSerializer):
@@ -36,11 +31,12 @@ class ScanHistorySerializer(serializers.ModelSerializer):
Provides an optimized data format for the frontend scan-history page, including:
- Scan summary statistics (subdomain, endpoint, and vulnerability counts)
- Progress percentage and current stage
- Executing worker node info
"""
# Field mappings
target_name = serializers.CharField(source='target.name', read_only=True)
engine_name = serializers.CharField(source='engine.name', read_only=True)
worker_name = serializers.CharField(source='worker.name', read_only=True, allow_null=True)
# Computed fields
summary = serializers.SerializerMethodField()
@@ -53,9 +49,9 @@ class ScanHistorySerializer(serializers.ModelSerializer):
class Meta:
model = Scan
fields = [
'id', 'target', 'target_name', 'engine', 'engine_name',
'created_at', 'status', 'error_message', 'summary', 'progress',
'current_stage', 'stage_progress'
'id', 'target', 'target_name', 'engine_ids', 'engine_names',
'worker_name', 'created_at', 'status', 'error_message', 'summary',
'progress', 'current_stage', 'stage_progress'
]
def get_summary(self, obj):
@@ -105,10 +101,11 @@ class QuickScanSerializer(serializers.Serializer):
help_text='目标列表,每个目标包含 name 字段'
)
# Scan engine ID
engine_id = serializers.IntegerField(
# List of scan engine IDs
engine_ids = serializers.ListField(
child=serializers.IntegerField(),
required=True,
help_text='使用的扫描引擎 ID (必填)'
help_text='使用的扫描引擎 ID 列表 (必填)'
)
def validate_targets(self, value):
@@ -130,6 +127,12 @@ class QuickScanSerializer(serializers.Serializer):
raise serializers.ValidationError(f"{idx + 1} 个目标的 name 不能为空")
return value
def validate_engine_ids(self, value):
"""验证引擎 ID 列表"""
if not value:
raise serializers.ValidationError("engine_ids 不能为空")
return value
# ==================== Scheduled scan serializers ====================
@@ -138,7 +141,6 @@ class ScheduledScanSerializer(serializers.ModelSerializer):
"""定时扫描任务序列化器(用于列表和详情)"""
# Related fields
engine_name = serializers.CharField(source='engine.name', read_only=True)
organization_id = serializers.IntegerField(source='organization.id', read_only=True, allow_null=True)
organization_name = serializers.CharField(source='organization.name', read_only=True, allow_null=True)
target_id = serializers.IntegerField(source='target.id', read_only=True, allow_null=True)
@@ -149,7 +151,7 @@ class ScheduledScanSerializer(serializers.ModelSerializer):
model = ScheduledScan
fields = [
'id', 'name',
'engine', 'engine_name',
'engine_ids', 'engine_names',
'organization_id', 'organization_name',
'target_id', 'target_name',
'scan_mode',
@@ -178,7 +180,10 @@ class CreateScheduledScanSerializer(serializers.Serializer):
"""
name = serializers.CharField(max_length=200, help_text='任务名称')
engine_id = serializers.IntegerField(help_text='扫描引擎 ID')
engine_ids = serializers.ListField(
child=serializers.IntegerField(),
help_text='扫描引擎 ID 列表'
)
# Organization scan mode
organization_id = serializers.IntegerField(
@@ -201,6 +206,12 @@ class CreateScheduledScanSerializer(serializers.Serializer):
)
is_enabled = serializers.BooleanField(default=True, help_text='是否立即启用')
def validate_engine_ids(self, value):
"""验证引擎 ID 列表"""
if not value:
raise serializers.ValidationError("engine_ids 不能为空")
return value
def validate(self, data):
"""验证 organization_id 和 target_id 互斥"""
organization_id = data.get('organization_id')
@@ -219,7 +230,11 @@ class UpdateScheduledScanSerializer(serializers.Serializer):
"""更新定时扫描任务序列化器"""
name = serializers.CharField(max_length=200, required=False, help_text='任务名称')
engine_id = serializers.IntegerField(required=False, help_text='扫描引擎 ID')
engine_ids = serializers.ListField(
child=serializers.IntegerField(),
required=False,
help_text='扫描引擎 ID 列表'
)
# Organization scan mode
organization_id = serializers.IntegerField(
@@ -237,6 +252,12 @@ class UpdateScheduledScanSerializer(serializers.Serializer):
cron_expression = serializers.CharField(max_length=100, required=False, help_text='Cron 表达式')
is_enabled = serializers.BooleanField(required=False, help_text='是否启用')
def validate_engine_ids(self, value):
"""验证引擎 ID 列表"""
if value is not None and not value:
raise serializers.ValidationError("engine_ids 不能为空")
return value
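A request body the updated create serializer would accept (a sketch with illustrative values, using the serializer-level snake_case field names before any camelCase translation on the wire):

payload = {
    "name": "nightly-sweep",
    "engine_ids": [1, 2],
    "target_id": 7,
    "cron_expression": "0 2 * * *",
    "is_enabled": True,
}
serializer = CreateScheduledScanSerializer(data=payload)
assert serializer.is_valid(), serializer.errors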
class ToggleScheduledScanSerializer(serializers.Serializer):

View File

@@ -10,7 +10,7 @@
import uuid
import logging
import threading
from typing import List
from typing import List, Tuple
from datetime import datetime
from pathlib import Path
from django.conf import settings
@@ -20,6 +20,7 @@ from django.core.exceptions import ValidationError, ObjectDoesNotExist
from apps.scan.models import Scan
from apps.scan.repositories import DjangoScanRepository
from apps.scan.utils.config_merger import merge_engine_configs, ConfigConflictError
from apps.targets.repositories import DjangoTargetRepository, DjangoOrganizationRepository
from apps.engine.repositories import DjangoEngineRepository
from apps.targets.models import Target
@@ -142,6 +143,106 @@ class ScanCreationService:
return targets, engine
def prepare_initiate_scan_multi_engine(
self,
organization_id: int | None = None,
target_id: int | None = None,
engine_ids: List[int] | None = None
) -> Tuple[List[Target], str, List[str], List[int]]:
"""
Prepare the data needed for a multi-engine scan task
Responsibilities:
1. Parameter validation (required fields, mutually exclusive parameters)
2. Resource lookup (Engines, Organization, Target)
3. Merge the engine configurations (detect conflicts)
4. Return the prepared target list, merged configuration, and engine info
Args:
organization_id: organization ID (optional)
target_id: target ID (optional)
engine_ids: list of scan engine IDs (required)
Returns:
(targets, merged configuration, engine names, engine IDs) - consumed by create_scans
Raises:
ValidationError: parameter validation failed or a business rule was violated
ObjectDoesNotExist: a resource does not exist (Organization/Target/ScanEngine)
ConfigConflictError: the engine configurations conflict
Note:
- Exactly one of organization_id and target_id must be provided
- With organization_id, all targets under that organization are returned
- With target_id, a single-target list is returned
"""
# 1. Validate the parameters
if not engine_ids:
raise ValidationError('缺少必填参数: engine_ids')
if not organization_id and not target_id:
raise ValidationError('必须提供 organization_id 或 target_id 其中之一')
if organization_id and target_id:
raise ValidationError('organization_id 和 target_id 只能提供其中之一')
# 2. Look up all scan engines
engines = []
for engine_id in engine_ids:
engine = self.engine_repo.get_by_id(engine_id)
if not engine:
logger.error("扫描引擎不存在 - Engine ID: %s", engine_id)
raise ObjectDoesNotExist(f'ScanEngine ID {engine_id} 不存在')
engines.append(engine)
# 3. Merge the engine configurations (may raise ConfigConflictError)
engine_configs = [(e.name, e.configuration or '') for e in engines]
merged_configuration = merge_engine_configs(engine_configs)
engine_names = [e.name for e in engines]
logger.debug(
"引擎配置合并成功 - 引擎: %s",
', '.join(engine_names)
)
# 4. Resolve the target list from the parameters
targets = []
if organization_id:
# Fetch all targets under the organization
organization = self.organization_repo.get_by_id(organization_id)
if not organization:
logger.error("组织不存在 - Organization ID: %s", organization_id)
raise ObjectDoesNotExist(f'Organization ID {organization_id} 不存在')
targets = self.organization_repo.get_targets(organization_id)
if not targets:
raise ValidationError(f'组织 ID {organization_id} 下没有目标')
logger.debug(
"准备发起扫描 - 组织: %s, 目标数量: %d, 引擎: %s",
organization.name,
len(targets),
', '.join(engine_names)
)
else:
# Fetch the single target by ID
target = self.target_repo.get_by_id(target_id)
if not target:
logger.error("目标不存在 - Target ID: %s", target_id)
raise ObjectDoesNotExist(f'Target ID {target_id} 不存在')
targets = [target]
logger.debug(
"准备发起扫描 - 目标: %s, 引擎: %s",
target.name,
', '.join(engine_names)
)
return targets, merged_configuration, engine_names, engine_ids
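merge_engine_configs itself is not part of this diff; a rough sketch of merging with conflict detection over parsed YAML, where the top-level-key conflict rule and the error message are assumptions:

import yaml

class ConfigConflictError(Exception):
    pass

def merge_engine_configs(engine_configs: list[tuple[str, str]]) -> str:
    merged: dict = {}
    owner: dict = {}  # top-level key -> name of the engine that set it
    for name, raw in engine_configs:
        for key, value in (yaml.safe_load(raw) or {}).items():
            if key in merged and merged[key] != value:
                raise ConfigConflictError(
                    f"'{key}' conflicts between {owner[key]} and {name}"
                )
            merged.setdefault(key, value)
            owner.setdefault(key, name)
    return yaml.safe_dump(merged, sort_keys=False, allow_unicode=True)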
def _generate_scan_workspace_dir(self) -> str:
"""
Generate the Scan workspace directory path
@@ -179,7 +280,9 @@ class ScanCreationService:
def create_scans(
self,
targets: List[Target],
engine: ScanEngine,
engine_ids: List[int],
engine_names: List[str],
merged_configuration: str,
scheduled_scan_name: str | None = None
) -> List[Scan]:
"""
@@ -187,7 +290,9 @@ class ScanCreationService:
Args:
targets: list of targets
engine: the scan engine object
engine_ids: list of engine IDs
engine_names: list of engine names
merged_configuration: the merged YAML configuration
scheduled_scan_name: scheduled scan task name (optional; used in notifications)
Returns:
@@ -205,7 +310,9 @@ class ScanCreationService:
scan_workspace_dir = self._generate_scan_workspace_dir()
scan = Scan(
target=target,
engine=engine,
engine_ids=engine_ids,
engine_names=engine_names,
merged_configuration=merged_configuration,
results_dir=scan_workspace_dir,
status=ScanStatus.INITIATED,
container_ids=[],
@@ -236,13 +343,15 @@ class ScanCreationService:
return []
# Step 3: dispatch the tasks to the workers
# Join all engine names into a single display name
display_engine_name = ', '.join(engine_names) if engine_names else ''
scan_data = [
{
'scan_id': scan.id,
'target_name': scan.target.name,
'target_id': scan.target.id,
'results_dir': scan.results_dir,
'engine_name': scan.engine.name,
'engine_name': display_engine_name,
'scheduled_scan_name': scheduled_scan_name,
}
for scan in created_scans

View File

@@ -96,14 +96,34 @@ class ScanService:
organization_id, target_id, engine_id
)
def prepare_initiate_scan_multi_engine(
self,
organization_id: int | None = None,
target_id: int | None = None,
engine_ids: List[int] | None = None
) -> tuple[List[Target], str, List[str], List[int]]:
"""
Prepare for creating multi-engine scan tasks
Returns:
(targets, merged configuration, engine names, engine IDs)
"""
return self.creation_service.prepare_initiate_scan_multi_engine(
organization_id, target_id, engine_ids
)
def create_scans(
self,
targets: List[Target],
engine: ScanEngine,
engine_ids: List[int],
engine_names: List[str],
merged_configuration: str,
scheduled_scan_name: str | None = None
) -> List[Scan]:
"""批量创建扫描任务(委托给 ScanCreationService"""
return self.creation_service.create_scans(targets, engine, scheduled_scan_name)
return self.creation_service.create_scans(
targets, engine_ids, engine_names, merged_configuration, scheduled_scan_name
)
# ==================== State management methods (delegated to ScanStateService) ====================

View File

@@ -14,6 +14,7 @@ from django.core.exceptions import ValidationError
from apps.scan.models import ScheduledScan
from apps.scan.repositories import DjangoScheduledScanRepository, ScheduledScanDTO
from apps.scan.utils.config_merger import merge_engine_configs, ConfigConflictError
from apps.engine.repositories import DjangoEngineRepository
from apps.targets.services import TargetService
@@ -57,8 +58,9 @@ class ScheduledScanService:
Flow:
1. Validate the parameters
2. Create the database record
3. Compute and set next_run_time
2. Merge the engine configurations
3. Create the database record
4. Compute and set next_run_time
Args:
dto: scheduled scan DTO
@@ -68,14 +70,30 @@ class ScheduledScanService:
Raises:
ValidationError: parameter validation failed
ConfigConflictError: the engine configurations conflict
"""
# 1. Validate the parameters
self._validate_create_dto(dto)
# 2. Create the database record
# 2. Merge the engine configurations
engines = []
engine_names = []
for engine_id in dto.engine_ids:
engine = self.engine_repo.get_by_id(engine_id)
if engine:
engines.append((engine.name, engine.configuration or ''))
engine_names.append(engine.name)
merged_configuration = merge_engine_configs(engines)
# Store the merged configuration and engine names on the DTO
dto.engine_names = engine_names
dto.merged_configuration = merged_configuration
# 3. Create the database record
scheduled_scan = self.repo.create(dto)
# 3. If a cron expression is set and the task is enabled, compute the next run time
# 4. If a cron expression is set and the task is enabled, compute the next run time
if scheduled_scan.cron_expression and scheduled_scan.is_enabled:
next_run_time = self._calculate_next_run_time(scheduled_scan)
if next_run_time:
@@ -96,11 +114,13 @@ class ScheduledScanService:
if not dto.name:
raise ValidationError('任务名称不能为空')
if not dto.engine_id:
if not dto.engine_ids:
raise ValidationError('必须选择扫描引擎')
if not self.engine_repo.get_by_id(dto.engine_id):
raise ValidationError(f'扫描引擎 ID {dto.engine_id} 不存在')
# Verify that every engine exists
for engine_id in dto.engine_ids:
if not self.engine_repo.get_by_id(engine_id):
raise ValidationError(f'扫描引擎 ID {engine_id} 不存在')
# Validate the scan mode (organization_id and target_id are mutually exclusive)
if not dto.organization_id and not dto.target_id:
@@ -138,11 +158,28 @@ class ScheduledScanService:
Returns:
The updated ScheduledScan object
Raises:
ConfigConflictError: the engine configurations conflict
"""
existing = self.repo.get_by_id(scheduled_scan_id)
if not existing:
return None
# If the engines changed, re-merge the configuration
if dto.engine_ids is not None:
engines = []
engine_names = []
for engine_id in dto.engine_ids:
engine = self.engine_repo.get_by_id(engine_id)
if engine:
engines.append((engine.name, engine.configuration or ''))
engine_names.append(engine.name)
merged_configuration = merge_engine_configs(engines)
dto.engine_names = engine_names
dto.merged_configuration = merged_configuration
# Update the database record
scheduled_scan = self.repo.update(scheduled_scan_id, dto)
if not scheduled_scan:
@@ -292,21 +329,25 @@ class ScheduledScanService:
Trigger a scan immediately (supports both organization and target scan modes)
Reuses the ScanService logic so behavior matches the API call.
Uses the stored merged_configuration instead of re-merging.
"""
from apps.scan.services.scan_service import ScanService
scan_service = ScanService()
# 1. Prepare the scan data (reusing the API logic)
targets, engine = scan_service.prepare_initiate_scan(
# 1. Prepare the scan data (using the stored multi-engine configuration)
targets, _, _, _ = scan_service.prepare_initiate_scan_multi_engine(
organization_id=scheduled_scan.organization_id,
target_id=scheduled_scan.target_id,
engine_id=scheduled_scan.engine_id
engine_ids=scheduled_scan.engine_ids
)
# 2. Create the scan tasks, passing the scheduled scan name for notification display
# 2. Create the scan tasks using the stored merged configuration
created_scans = scan_service.create_scans(
targets, engine,
targets=targets,
engine_ids=scheduled_scan.engine_ids,
engine_names=scheduled_scan.engine_names,
merged_configuration=scheduled_scan.merged_configuration,
scheduled_scan_name=scheduled_scan.name
)

View File

@@ -4,7 +4,6 @@ xingfinger execution task
Streams the xingfinger command and updates the tech field in real time
"""
import importlib
import json
import logging
import subprocess
@@ -15,93 +14,97 @@ from django.db import connection
from prefect import task
from apps.scan.utils import execute_stream
from apps.asset.dtos.snapshot import WebsiteSnapshotDTO
from apps.asset.repositories.snapshot import DjangoWebsiteSnapshotRepository
logger = logging.getLogger(__name__)
# Data source mapping: source → (module_path, model_name, url_field)
SOURCE_MODEL_MAP = {
'website': ('apps.asset.models', 'WebSite', 'url'),
# Future extensions:
# 'endpoint': ('apps.asset.models', 'Endpoint', 'url'),
# 'directory': ('apps.asset.models', 'Directory', 'url'),
}
def _get_model_class(source: str):
"""根据数据源类型获取 Model 类"""
if source not in SOURCE_MODEL_MAP:
raise ValueError(f"不支持的数据源: {source}")
module_path, model_name, _ = SOURCE_MODEL_MAP[source]
module = importlib.import_module(module_path)
return getattr(module, model_name)
def parse_xingfinger_line(line: str) -> tuple[str, list[str]] | None:
def parse_xingfinger_line(line: str) -> dict | None:
"""
Parse a single line of xingfinger JSON output
xingfinger silent-mode output format:
{"url": "https://example.com", "cms": "WordPress,PHP,nginx", ...}
xingfinger output format:
{"url": "...", "cms": "...", "server": "BWS/1.1", "status_code": 200, "length": 642831, "title": "..."}
Returns:
tuple: (url, tech_list), or None on parse failure
dict: with url, techs, server, title, status_code, and content_length keys
None: on parse failure or when the URL is empty
"""
try:
item = json.loads(line)
url = item.get('url', '').strip()
if not url:
return None
# Split the cms field on commas and strip whitespace
cms = item.get('cms', '')
techs = [t.strip() for t in cms.split(',') if t.strip()] if cms else []
return {
'url': url,
'techs': techs,
'server': item.get('server', ''),
'title': item.get('title', ''),
'status_code': item.get('status_code'),
'content_length': item.get('length'),
}
except json.JSONDecodeError:
return None
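A quick sketch (editor's illustration, not part of the diff) of what the new parser returns; the sample JSON values are invented:

sample = '{"url": "https://example.com", "cms": "WordPress,PHP", "server": "nginx", "title": "Home", "status_code": 200, "length": 1024}'
parse_xingfinger_line(sample)
# -> {'url': 'https://example.com', 'techs': ['WordPress', 'PHP'], 'server': 'nginx',
#    'title': 'Home', 'status_code': 200, 'content_length': 1024}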
def bulk_merge_website_fields(
records: list[dict],
target_id: int
) -> dict:
"""
Bulk merge-update WebSite fields (native PostgreSQL SQL)
Merge strategy:
- tech: array merge with deduplication
- title, webserver, status_code, content_length: updated only when the existing value is empty/NULL
If no record exists for a URL, a new one is created automatically.
Args:
records: list of parsed records, each containing {url, techs, server, title, status_code, content_length}
target_id: target ID
Returns:
dict: {'updated_count': int, 'created_count': int}
"""
from apps.asset.models import WebSite
table_name = WebSite._meta.db_table
updated_count = 0
created_count = 0
with connection.cursor() as cursor:
for record in records:
url = record['url']
techs = record.get('techs', [])
server = record.get('server', '') or ''
title = record.get('title', '') or ''
status_code = record.get('status_code')
content_length = record.get('content_length')
# Try an UPDATE first (merge strategy)
update_sql = f"""
UPDATE {table_name}
SET
tech = (SELECT ARRAY(SELECT DISTINCT unnest(
COALESCE(tech, ARRAY[]::varchar[]) || %s::varchar[]
))),
title = CASE WHEN title = '' OR title IS NULL THEN %s ELSE title END,
webserver = CASE WHEN webserver = '' OR webserver IS NULL THEN %s ELSE webserver END,
status_code = CASE WHEN status_code IS NULL THEN %s ELSE status_code END,
content_length = CASE WHEN content_length IS NULL THEN %s ELSE content_length END
WHERE url = %s AND target_id = %s
"""
cursor.execute(update_sql, [techs, title, server, status_code, content_length, url, target_id])
if cursor.rowcount > 0:
updated_count += cursor.rowcount
@@ -114,20 +117,26 @@ def bulk_merge_tech_field(
# Insert a new record (with conflict handling)
insert_sql = f"""
INSERT INTO {table_name} (
target_id, url, host, location, title, webserver,
response_body, content_type, tech, status_code, content_length,
response_headers, created_at
)
VALUES (%s, %s, %s, '', %s, %s, '', '', %s::varchar[], %s, %s, '', NOW())
ON CONFLICT (target_id, url) DO UPDATE SET
tech = (SELECT ARRAY(SELECT DISTINCT unnest(
COALESCE({table_name}.tech, ARRAY[]::varchar[]) || EXCLUDED.tech
))),
title = CASE WHEN {table_name}.title = '' OR {table_name}.title IS NULL THEN EXCLUDED.title ELSE {table_name}.title END,
webserver = CASE WHEN {table_name}.webserver = '' OR {table_name}.webserver IS NULL THEN EXCLUDED.webserver ELSE {table_name}.webserver END,
status_code = CASE WHEN {table_name}.status_code IS NULL THEN EXCLUDED.status_code ELSE {table_name}.status_code END,
content_length = CASE WHEN {table_name}.content_length IS NULL THEN EXCLUDED.content_length ELSE {table_name}.content_length END
"""
cursor.execute(insert_sql, [target_id, url, host, title, server, techs, status_code, content_length])
created_count += 1
except Exception as e:
logger.warning("创建 WebSite 记录失败 (url=%s): %s", url, e)
return {
'updated_count': updated_count,
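To make the merge semantics concrete, a self-contained sketch (editor's illustration; the DSN and input arrays are invented) evaluating the same merge/dedup expression that update_sql uses:

import psycopg2  # assumed available, since the tasks already talk to PostgreSQL via psycopg2

conn = psycopg2.connect('postgresql://user:pass@localhost/xingrin')  # hypothetical DSN
with conn.cursor() as cur:
    # COALESCE the existing array with empty, append the new techs, then DISTINCT-unnest to dedupe
    cur.execute(
        "SELECT ARRAY(SELECT DISTINCT unnest(COALESCE(%s::varchar[], ARRAY[]::varchar[]) || %s::varchar[]))",
        (['nginx', 'PHP'], ['PHP', 'WordPress']),
    )
    print(cur.fetchone()[0])  # e.g. ['PHP', 'WordPress', 'nginx'] (element order is not guaranteed)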
@@ -141,12 +150,12 @@ def _parse_xingfinger_stream_output(
cwd: Optional[str] = None,
timeout: Optional[int] = None,
log_file: Optional[str] = None
) -> Generator[dict, None, None]:
"""
Stream-parse xingfinger command output
Uses execute_stream to process the command's stdout in real time, converting each JSON
output line into a full-field dict
"""
logger.info("开始流式解析 xingfinger 命令输出 - 命令: %s", cmd)
@@ -193,43 +202,46 @@ def run_xingfinger_and_stream_update_tech_task(
batch_size: int = 100
) -> dict:
"""
Stream-execute the xingfinger command, save snapshots, and merge-update the asset table
Processing flow:
1. Stream-execute the xingfinger command
2. Parse the JSON output in real time (full fields)
3. After accumulating batch_size records, process them in bulk:
- save snapshots (WebsiteSnapshot)
- merge-update the asset table (WebSite)
Merge strategy:
- tech: array merge with deduplication
- title, webserver, status_code, content_length: updated only when the existing value is empty
Returns:
dict: {
'processed_records': int,
'updated_count': int,
'created_count': int,
'snapshot_count': int,
'batch_count': int
}
"""
logger.info(
"开始执行 xingfinger - scan_id=%s, target_id=%s, timeout=%s",
scan_id, target_id, timeout
)
data_generator = None
snapshot_repo = DjangoWebsiteSnapshotRepository()
try:
# Initialize statistics
processed_records = 0
updated_count = 0
created_count = 0
snapshot_count = 0
batch_count = 0
# Records in the current batch
batch_records = []
# Stream processing
data_generator = _parse_xingfinger_stream_output(
@@ -240,47 +252,43 @@ def run_xingfinger_and_stream_update_tech_task(
log_file=log_file
)
for record in data_generator:
processed_records += 1
batch_records.append(record)
# Batch size reached; process the batch
if len(batch_records) >= batch_size:
batch_count += 1
result = _process_batch(
batch_records, scan_id, target_id, batch_count, snapshot_repo
)
updated_count += result['updated_count']
created_count += result['created_count']
snapshot_count += result['snapshot_count']
# Reset the batch
batch_records = []
# Process the final batch
if batch_records:
batch_count += 1
result = _process_batch(
batch_records, scan_id, target_id, batch_count, snapshot_repo
)
updated_count += result['updated_count']
created_count += result['created_count']
snapshot_count += result['snapshot_count']
logger.info(
"✓ xingfinger 执行完成 - 处理: %d, 更新: %d, 创建: %d, 快照: %d, 批次: %d",
processed_records, updated_count, created_count, snapshot_count, batch_count
)
return {
'processed_records': processed_records,
'updated_count': updated_count,
'created_count': created_count,
'snapshot_count': snapshot_count,
'batch_count': batch_count
}
@@ -298,3 +306,67 @@ def run_xingfinger_and_stream_update_tech_task(
data_generator.close()
except Exception as e:
logger.debug("关闭生成器时出错: %s", e)
def _process_batch(
records: list[dict],
scan_id: int,
target_id: int,
batch_num: int,
snapshot_repo: DjangoWebsiteSnapshotRepository
) -> dict:
"""
Process one batch of data: save snapshots + merge-update the asset table
Args:
records: list of parsed records
scan_id: scan task ID
target_id: target ID
batch_num: batch number
snapshot_repo: snapshot repository
Returns:
dict: {'updated_count': int, 'created_count': int, 'snapshot_count': int}
"""
# 1. Build the snapshot DTO list
snapshot_dtos = []
for record in records:
# Extract the host from the URL
parsed = urlparse(record['url'])
host = parsed.hostname or ''
dto = WebsiteSnapshotDTO(
scan_id=scan_id,
target_id=target_id,
url=record['url'],
host=host,
title=record.get('title', '') or '',
status_code=record.get('status_code'),
content_length=record.get('content_length'),
webserver=record.get('server', '') or '',
tech=record.get('techs', []),
)
snapshot_dtos.append(dto)
# 2. Save the snapshots
snapshot_count = 0
if snapshot_dtos:
try:
snapshot_repo.save_snapshots(snapshot_dtos)
snapshot_count = len(snapshot_dtos)
except Exception as e:
logger.warning("批次 %d 保存快照失败: %s", batch_num, e)
# 3. Merge-update the asset table
merge_result = bulk_merge_website_fields(records, target_id)
logger.debug(
"批次 %d 完成 - 更新: %d, 创建: %d, 快照: %d",
batch_num, merge_result['updated_count'], merge_result['created_count'], snapshot_count
)
return {
'updated_count': merge_result['updated_count'],
'created_count': merge_result['created_count'],
'snapshot_count': snapshot_count
}

View File

@@ -30,7 +30,6 @@ from typing import Generator, Optional, Dict, Any, TYPE_CHECKING
from django.db import IntegrityError, OperationalError, DatabaseError
from dataclasses import dataclass
from urllib.parse import urlparse, urlunparse
from dateutil.parser import parse as parse_datetime
from psycopg2 import InterfaceError
from apps.asset.dtos.snapshot import WebsiteSnapshotDTO
@@ -62,6 +61,18 @@ class ServiceSet:
)
def _sanitize_string(value: str) -> str:
"""
Strip NUL characters and other unprintable characters from a string
PostgreSQL does not allow NUL (0x00) characters in string fields
"""
if not value:
return value
# Remove NUL characters
return value.replace('\x00', '')
def normalize_url(url: str) -> str:
"""
Normalize a URL: strip default port numbers
@@ -117,69 +128,50 @@ def normalize_url(url: str) -> str:
return url
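The diff truncates the body of normalize_url; a minimal sketch of the default-port stripping it describes, assuming only :80 on http and :443 on https are removed (the helper name is hypothetical):

from urllib.parse import urlparse, urlunparse

def _strip_default_port(url: str) -> str:  # illustrative stand-in for the elided logic
    parsed = urlparse(url)
    default = {'http': 80, 'https': 443}.get(parsed.scheme)
    if parsed.port is not None and parsed.port == default:
        # Rebuild netloc without the redundant port (userinfo is ignored in this sketch)
        parsed = parsed._replace(netloc=parsed.hostname or '')
    return urlunparse(parsed)

# _strip_default_port('https://example.com:443/a') -> 'https://example.com/a'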
def _extract_hostname(url: str) -> str:
"""
Extract the hostname from a URL
Args:
url: URL string
Returns:
str: the extracted hostname (lowercase)
"""
try:
if url:
parsed = urlparse(url)
if parsed.hostname:
return parsed.hostname
# Fallback: manual extraction
return url.replace('http://', '').replace('https://', '').split('/')[0].split(':')[0]
return ''
except Exception as e:
logger.debug("提取主机名失败: %s", e)
return ''
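Behavior sketch for _extract_hostname (inputs invented for illustration):

_extract_hostname('https://sub.example.com:8443/path')  # urlparse path -> 'sub.example.com'
_extract_hostname('example.com/login')  # no scheme, manual fallback -> 'example.com'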
class HttpxRecord:
"""httpx scan record data class"""
def __init__(self, data: Dict[str, Any]):
self.url = _sanitize_string(data.get('url', ''))
self.input = _sanitize_string(data.get('input', ''))
self.title = _sanitize_string(data.get('title', ''))
self.status_code = data.get('status_code')  # int, no sanitizing needed
self.content_length = data.get('content_length')  # int, no sanitizing needed
self.content_type = _sanitize_string(data.get('content_type', ''))
self.location = _sanitize_string(data.get('location', ''))
self.webserver = _sanitize_string(data.get('webserver', ''))
self.response_body = _sanitize_string(data.get('body', ''))
self.tech = [_sanitize_string(t) for t in data.get('tech', []) if isinstance(t, str)]  # strings in the list need sanitizing too
self.vhost = data.get('vhost')  # bool, no sanitizing needed
self.failed = data.get('failed', False)  # bool, no sanitizing needed
self.response_headers = _sanitize_string(data.get('raw_header', ''))
# Extract the hostname (prefer the host returned by httpx; otherwise derive it from the URL)
httpx_host = _sanitize_string(data.get('host', ''))
self.host = httpx_host if httpx_host else _extract_hostname(self.url)
def _save_batch_with_retry(
@@ -227,39 +219,31 @@ def _save_batch_with_retry(
}
except (OperationalError, DatabaseError, InterfaceError) as e:
# Database-level errors (dropped connection, schema mismatch, etc.): retry with exponential backoff; on final failure raise so the Flow fails
if attempt < max_retries - 1:
wait_time = 2 ** attempt  # exponential backoff: 1s, 2s, 4s
logger.warning(
"批次 %d 保存失败(第 %d 次尝试),%d秒后重试: %s",
batch_num, attempt + 1, wait_time, str(e)[:100]
)
time.sleep(wait_time)
else:
logger.error(
"批次 %d 保存失败(已重试 %d 次),将终止任务: %s",
batch_num,
max_retries,
e,
exc_info=True,
)
# Let the upstream Task see the failure so the entire scan is marked as failed
raise
except Exception as e:
# Other unknown exceptions are no longer swallowed either; re-raise so the Flow is marked failed
logger.error("批次 %d 未知错误: %s", batch_num, e, exc_info=True)
raise
# In theory unreachable; fallback return kept to satisfy the type contract
return {
'success': False,
'created_websites': 0,
@@ -327,42 +311,39 @@ def _save_batch(
skipped_failed += 1
continue
try:
# Use the input field (the originally scanned URL) rather than url (the post-redirect URL)
# Reason: avoids unique-constraint conflicts when several different input URLs redirect to the same URL
# e.g. http://example.com and https://example.com both redirect to https://example.com
# With record.url both records would share the same url and collide in the database
# With record.input each record keeps its original input and no conflict occurs
normalized_url = normalize_url(record.input) if record.input else normalize_url(record.url)
# Extract the host field (domain or IP address)
host = record.host if record.host else ''
# Build the WebsiteSnapshot DTO
snapshot_dto = WebsiteSnapshotDTO(
scan_id=scan_id,
target_id=target_id,  # primary association field
url=normalized_url,  # the original input URL (normalized)
host=host,  # hostname (domain or IP address)
location=record.location if record.location else '',
title=record.title if record.title else '',
webserver=record.webserver if record.webserver else '',
response_body=record.response_body if record.response_body else '',
content_type=record.content_type if record.content_type else '',
tech=record.tech if isinstance(record.tech, list) else [],
status_code=record.status_code,
content_length=record.content_length,
vhost=record.vhost,
response_headers=record.response_headers if record.response_headers else '',
)
snapshot_items.append(snapshot_dto)
except Exception as e:
logger.error("处理记录失败: %s,错误: %s", record.url, e)
continue
# ========== Step 3: Save snapshots and sync to the asset table (via the snapshot Service) ==========
if snapshot_items:
@@ -384,28 +365,31 @@ def _parse_and_validate_line(line: str) -> Optional[HttpxRecord]:
Optional[HttpxRecord]: a valid httpx scan record, or None if validation fails
Validation steps:
1. Strip NUL characters
2. Parse the JSON
3. Verify the data is a dict
4. Create the HttpxRecord object
5. Verify required fields (url)
"""
try:
# Step 1: strip NUL characters before parsing the JSON
line = _sanitize_string(line)
# Step 2: parse the JSON
try:
line_data = json.loads(line, strict=False)
except json.JSONDecodeError:
# logger.info("跳过非 JSON 行: %s", line)
return None
# Step 3: verify the data type
if not isinstance(line_data, dict):
logger.info("跳过非字典数据")
return None
# Step 4: create the record
record = HttpxRecord(line_data)
# Step 5: verify required fields
if not record.url:
logger.info("URL 为空,跳过 - 数据: %s", str(line_data)[:200])
return None
@@ -414,7 +398,7 @@ def _parse_and_validate_line(line: str) -> Optional[HttpxRecord]:
return record
except Exception:
logger.info("跳过无法解析的行: %s", line[:100] if line else 'empty')
return None
@@ -462,8 +446,8 @@ def _parse_httpx_stream_output(
# yield one valid record
yield record
# Log progress every 5 valid records
if valid_records % 5 == 0:
logger.info("已解析 %d 条有效记录...", valid_records)
except subprocess.TimeoutExpired as e:
@@ -602,8 +586,8 @@ def _process_records_in_batches(
_process_batch(batch, scan_id, target_id, batch_num, total_stats, failed_batches, services)
batch = []  # reset the batch
# Log progress every 2 batches
if batch_num % 2 == 0:
logger.info("进度: 已处理 %d 批次,%d 条记录", batch_num, total_records)
# Save the final batch
@@ -674,11 +658,7 @@ def _cleanup_resources(data_generator) -> None:
logger.error("关闭生成器时出错: %s", gen_close_error)
@task(name='run_and_stream_save_websites', retries=0)
def run_and_stream_save_websites_task(
cmd: str,
tool_name: str,
@@ -686,7 +666,7 @@ def run_and_stream_save_websites_task(
target_id: int,
cwd: Optional[str] = None,
shell: bool = False,
batch_size: int = 10,
timeout: Optional[int] = None,
log_file: Optional[str] = None
) -> dict:

View File

@@ -2,8 +2,8 @@
Streaming URL validation task built on execute_stream
Main functions:
1. Execute the httpx command in real time to validate URLs
2. Stream-process the command output and parse URL information
3. Bulk-save to the database (Endpoint table)
4. Avoid loading all URLs into memory at once
@@ -14,7 +14,7 @@
- Uses execute_stream to process output in real time
- Streaming avoids memory exhaustion
- Batched operations reduce database round-trips
- Saves all valid URLs (including 4xx/5xx, useful for security analysis)
"""
import logging
@@ -23,10 +23,11 @@ import subprocess
import time
from pathlib import Path
from prefect import task
from typing import Generator, Optional
from typing import Generator, Optional, Dict, Any
from django.db import IntegrityError, OperationalError, DatabaseError
from psycopg2 import InterfaceError
from dataclasses import dataclass
from urllib.parse import urlparse
from apps.asset.services.snapshot import EndpointSnapshotsService
from apps.scan.utils import execute_stream
@@ -63,7 +64,53 @@ def _sanitize_string(value: str) -> str:
return value.replace('\x00', '')
def _extract_hostname(url: str) -> str:
"""
Extract the hostname from a URL
Args:
url: URL string
Returns:
str: the extracted hostname (lowercase)
"""
try:
if url:
parsed = urlparse(url)
if parsed.hostname:
return parsed.hostname
# Fallback: manual extraction
return url.replace('http://', '').replace('https://', '').split('/')[0].split(':')[0]
return ''
except Exception as e:
logger.debug("提取主机名失败: %s", e)
return ''
class HttpxRecord:
"""httpx scan record data class"""
def __init__(self, data: Dict[str, Any]):
self.url = _sanitize_string(data.get('url', ''))
self.input = _sanitize_string(data.get('input', ''))
self.title = _sanitize_string(data.get('title', ''))
self.status_code = data.get('status_code')  # int, no sanitizing needed
self.content_length = data.get('content_length')  # int, no sanitizing needed
self.content_type = _sanitize_string(data.get('content_type', ''))
self.location = _sanitize_string(data.get('location', ''))
self.webserver = _sanitize_string(data.get('webserver', ''))
self.response_body = _sanitize_string(data.get('body', ''))
self.tech = [_sanitize_string(t) for t in data.get('tech', []) if isinstance(t, str)]  # strings in the list need sanitizing too
self.vhost = data.get('vhost')  # bool, no sanitizing needed
self.failed = data.get('failed', False)  # bool, no sanitizing needed
self.response_headers = _sanitize_string(data.get('raw_header', ''))
# Extract the hostname (prefer the host returned by httpx; otherwise derive it from the URL)
httpx_host = _sanitize_string(data.get('host', ''))
self.host = httpx_host if httpx_host else _extract_hostname(self.url)
def _parse_and_validate_line(line: str) -> Optional[HttpxRecord]:
"""
Parse and validate a single line of httpx JSON output
@@ -71,9 +118,7 @@ def _parse_and_validate_line(line: str) -> Optional[dict]:
line: a single line of output data
Returns:
Optional[HttpxRecord]: a valid httpx record, or None if validation fails
"""
try:
# Strip NUL characters before parsing the JSON
@@ -83,7 +128,6 @@ def _parse_and_validate_line(line: str) -> Optional[dict]:
try:
line_data = json.loads(line, strict=False)
except json.JSONDecodeError:
return None
# Verify the data type
@@ -91,32 +135,15 @@ def _parse_and_validate_line(line: str) -> Optional[dict]:
logger.info("跳过非字典数据")
return None
# Create the record
record = HttpxRecord(line_data)
# Verify required fields
if not record.url:
logger.info("URL 为空,跳过 - 数据: %s", str(line_data)[:200])
return None
return record
except Exception:
logger.info("跳过无法解析的行: %s", line[:100] if line else 'empty')
@@ -130,7 +157,7 @@ def _parse_httpx_stream_output(
shell: bool = False,
timeout: Optional[int] = None,
log_file: Optional[str] = None
) -> Generator[HttpxRecord, None, None]:
"""
Stream-parse httpx command output
@@ -143,7 +170,7 @@ def _parse_httpx_stream_output(
log_file: log file path
Yields:
HttpxRecord: yields one valid URL record at a time
"""
logger.info("开始流式解析 httpx 输出 - 命令: %s", cmd)
@@ -173,8 +200,8 @@ def _parse_httpx_stream_output(
# yield one valid record
yield record
# Log progress every 100 valid records
if valid_records % 100 == 0:
logger.info("已解析 %d 条存活的 URL...", valid_records)
except subprocess.TimeoutExpired as e:
@@ -191,6 +218,78 @@ def _parse_httpx_stream_output(
)
def _validate_task_parameters(cmd: str, target_id: int, scan_id: int, cwd: Optional[str]) -> None:
"""
Validate task parameters
Args:
cmd: scan command
target_id: target ID
scan_id: scan ID
cwd: working directory
Raises:
ValueError: when parameter validation fails
"""
if not cmd or not cmd.strip():
raise ValueError("扫描命令不能为空")
if target_id is None:
raise ValueError("target_id 不能为 None必须指定目标ID")
if scan_id is None:
raise ValueError("scan_id 不能为 None必须指定扫描ID")
# Validate the working directory (if specified)
if cwd and not Path(cwd).exists():
raise ValueError(f"工作目录不存在: {cwd}")
def _build_final_result(stats: dict) -> dict:
"""
Build the final result and log it
Args:
stats: processing statistics
Returns:
dict: the final result
"""
logger.info(
"✓ URL 验证任务完成 - 处理记录: %d%d 批次),创建端点: %d,跳过(失败): %d",
stats['processed_records'], stats['batch_count'], stats['created_endpoints'],
stats['skipped_failed']
)
# If no records were created, log an explicit hint
if stats['created_endpoints'] == 0:
logger.warning(
"⚠️ 没有创建任何端点记录可能原因1) 命令输出格式问题 2) 重复数据被忽略 3) 所有请求都失败"
)
return {
'processed_records': stats['processed_records'],
'created_endpoints': stats['created_endpoints'],
'skipped_failed': stats['skipped_failed']
}
def _cleanup_resources(data_generator) -> None:
"""
Clean up task resources
Args:
data_generator: the data generator
"""
# Ensure the generator is closed properly
if data_generator is not None:
try:
data_generator.close()
logger.debug("已关闭数据生成器")
except Exception as gen_close_error:
logger.error("关闭生成器时出错: %s", gen_close_error)
def _save_batch_with_retry(
batch: list,
scan_id: int,
@@ -211,14 +310,19 @@ def _save_batch_with_retry(
max_retries: maximum number of retries
Returns:
dict: {
'success': bool,
'created_endpoints': int,
'skipped_failed': int
}
"""
for attempt in range(max_retries):
try:
stats = _save_batch(batch, scan_id, target_id, batch_num, services)
return {
'success': True,
'created_endpoints': stats.get('created_endpoints', 0),
'skipped_failed': stats.get('skipped_failed', 0)
}
except IntegrityError as e:
@@ -226,7 +330,8 @@ def _save_batch_with_retry(
logger.error("批次 %d 数据完整性错误,跳过: %s", batch_num, str(e)[:100])
return {
'success': False,
'created_endpoints': 0,
'skipped_failed': 0
}
except (OperationalError, DatabaseError, InterfaceError) as e:
@@ -257,7 +362,8 @@ def _save_batch_with_retry(
# In theory unreachable; fallback return kept to satisfy the type contract
return {
'success': False,
'created_endpoints': 0,
'skipped_failed': 0
}
@@ -267,49 +373,72 @@ def _save_batch(
target_id: int,
batch_num: int,
services: ServiceSet
) -> dict:
"""
Save one batch of data to the database
Args:
batch: data batch (list of HttpxRecord)
scan_id: scan task ID
target_id: target ID
batch_num: batch number
services: Service set
Returns:
dict: statistics of created and skipped records
"""
# Parameter validation
if not isinstance(batch, list):
raise TypeError(f"batch 必须是 list 类型,实际: {type(batch).__name__}")
if not batch:
logger.debug("批次 %d 为空,跳过处理", batch_num)
return {
'created_endpoints': 0,
'skipped_failed': 0
}
# Statistics
skipped_failed = 0
# Bulk-build Endpoint snapshot DTOs
from apps.asset.dtos.snapshot import EndpointSnapshotDTO
snapshots = []
for record in batch:
# Skip failed requests
if record.failed:
skipped_failed += 1
continue
try:
# Use the raw Endpoint URL without normalization
# Reason: endpoint URLs come from waymore/katana and include paths and parameters; normalizing could change their meaning
url = record.input if record.input else record.url
# Extract the host field (domain or IP address)
host = record.host if record.host else ''
dto = EndpointSnapshotDTO(
scan_id=scan_id,
target_id=target_id,
url=url,
host=host,
title=record.title if record.title else '',
status_code=record.status_code,
content_length=record.content_length,
location=record.location if record.location else '',
webserver=record.webserver if record.webserver else '',
content_type=record.content_type if record.content_type else '',
tech=record.tech if isinstance(record.tech, list) else [],
response_body=record.response_body if record.response_body else '',
vhost=record.vhost if record.vhost else False,
matched_gf_patterns=[],
response_headers=record.response_headers if record.response_headers else '',
)
snapshots.append(dto)
except Exception as e:
logger.error("处理记录失败: %s,错误: %s", record.url, e)
continue
if snapshots:
@@ -318,15 +447,69 @@ def _save_batch(
services.snapshot.save_and_sync(snapshots)
count = len(snapshots)
logger.info(
"批次 %d: 保存了 %d 个存活的 URL%d,跳过失败: %d",
batch_num, count, len(batch), skipped_failed
)
return {
'created_endpoints': count,
'skipped_failed': skipped_failed
}
except Exception as e:
logger.error("批次 %d 批量保存失败: %s", batch_num, e)
raise
return {
'created_endpoints': 0,
'skipped_failed': skipped_failed
}
def _accumulate_batch_stats(total_stats: dict, batch_result: dict) -> None:
"""
Accumulate batch statistics
Args:
total_stats: running totals dict
batch_result: batch result dict
"""
total_stats['created_endpoints'] += batch_result.get('created_endpoints', 0)
total_stats['skipped_failed'] += batch_result.get('skipped_failed', 0)
def _process_batch(
batch: list,
scan_id: int,
target_id: int,
batch_num: int,
total_stats: dict,
failed_batches: list,
services: ServiceSet
) -> None:
"""
Process a single batch
Args:
batch: data batch
scan_id: scan ID
target_id: target ID
batch_num: batch number
total_stats: running totals
failed_batches: list of failed batch numbers
services: Service set (required, dependency-injected)
"""
result = _save_batch_with_retry(
batch, scan_id, target_id, batch_num, services
)
# Accumulate statistics (some data may already have been saved even on failure)
_accumulate_batch_stats(total_stats, result)
if not result['success']:
failed_batches.append(batch_num)
logger.warning(
"批次 %d 保存失败,但已累计统计信息:创建端点=%d",
batch_num, result.get('created_endpoints', 0)
)
def _process_records_in_batches(
@@ -337,7 +520,7 @@ def _process_records_in_batches(
services: ServiceSet
) -> dict:
"""
Stream-process records and save them in batches
Args:
data_generator: the data generator
@@ -347,14 +530,23 @@ def _process_records_in_batches(
services: Service set
Returns:
dict: processing statistics
Raises:
RuntimeError: raised when any batch fails
"""
total_records = 0
batch_num = 0
failed_batches = []
batch = []
# Statistics
total_stats = {
'created_endpoints': 0,
'skipped_failed': 0
}
# Stream the generator and save in batches
for record in data_generator:
batch.append(record)
total_records += 1
@@ -362,46 +554,35 @@ def _process_records_in_batches(
# Batch size reached; save
if len(batch) >= batch_size:
batch_num += 1
_process_batch(batch, scan_id, target_id, batch_num, total_stats, failed_batches, services)
batch = []  # reset the batch
# Log progress every 10 batches
if batch_num % 10 == 0:
logger.info("进度: 已处理 %d 批次,%d 条记录", batch_num, total_records)
# Save the final batch
if batch:
batch_num += 1
_process_batch(batch, scan_id, target_id, batch_num, total_stats, failed_batches, services)
# Check for failed batches
if failed_batches:
error_msg = (
f"流式保存 URL 验证结果时出现失败批次,处理记录: {total_records}"
f"失败批次: {failed_batches}"
)
logger.warning(error_msg)
raise RuntimeError(error_msg)
return {
'processed_records': total_records,
'batch_count': batch_num,
**total_stats
}
@task(name="run_and_stream_save_urls", retries=0)
def run_and_stream_save_urls_task(
cmd: str,
tool_name: str,
@@ -409,7 +590,7 @@ def run_and_stream_save_urls_task(
target_id: int,
cwd: Optional[str] = None,
shell: bool = False,
batch_size: int = 100,
timeout: Optional[int] = None,
log_file: Optional[str] = None
) -> dict:
@@ -417,17 +598,18 @@ def run_and_stream_save_urls_task(
Execute httpx validation and stream-save the resulting URLs
This task will:
1. Validate input parameters
2. Initialize resources (cache, generator)
3. Stream-process records and save them in batches
4. Build and return result statistics
Args:
cmd: the httpx command
tool_name: tool name ('httpx')
scan_id: scan task ID
target_id: target ID
cwd: working directory (optional)
shell: whether to execute via shell (default False)
batch_size: batch size (default 100)
timeout: timeout in seconds
log_file: log file path
@@ -435,11 +617,14 @@ def run_and_stream_save_urls_task(
Returns:
dict: {
'processed_records': int,  # total records processed
'created_endpoints': int,  # endpoint records created
'skipped_failed': int,     # records skipped due to failed requests
}
Raises:
ValueError: parameter validation failed
RuntimeError: command execution or database operation failed
subprocess.TimeoutExpired: the command timed out
"""
logger.info(
"开始执行流式 URL 验证任务 - target_id=%s, 超时=%s秒, 命令: %s",
@@ -449,33 +634,30 @@ def run_and_stream_save_urls_task(
data_generator = None
try:
# 1. Validate parameters
_validate_task_parameters(cmd, target_id, scan_id, cwd)
# 2. Initialize resources
data_generator = _parse_httpx_stream_output(
cmd, tool_name, cwd, shell, timeout, log_file
)
services = ServiceSet.create_default()
# 3. Stream-process records and save in batches
stats = _process_records_in_batches(
data_generator, scan_id, target_id, batch_size, services
)
# 4. Build the final result
return _build_final_result(stats)
except subprocess.TimeoutExpired:
# Propagate timeout exceptions upward, preserving the exception type
logger.warning(
"⚠️ URL 验证任务超时 - target_id=%s, 超时=%s",
target_id, timeout
)
raise  # re-raise directly, without wrapping
except Exception as e:
error_msg = f"流式执行 URL 验证任务失败: {e}"
@@ -483,12 +665,5 @@ def run_and_stream_save_urls_task(
raise RuntimeError(error_msg) from e
finally:
# 5. Clean up resources
_cleanup_resources(data_generator)

View File

@@ -0,0 +1,80 @@
"""
Configuration merge utilities
Conflict detection and merging for multi-engine YAML configurations.
"""
from typing import List, Tuple
import yaml
class ConfigConflictError(Exception):
"""Configuration conflict error
Raised when two or more engines define the same top-level scan-type key.
"""
def __init__(self, conflicts: List[Tuple[str, str, str]]):
"""
Args:
conflicts: list of (key, engine 1 name, engine 2 name) tuples
"""
self.conflicts = conflicts
msg = "; ".join([f"{k} 同时存在于「{e1}」和「{e2}」" for k, e1, e2 in conflicts])
super().__init__(f"扫描类型冲突: {msg}")
def merge_engine_configs(engines: List[Tuple[str, str]]) -> str:
"""
合并多个引擎的 YAML 配置。
参数:
engines: (引擎名称, 配置YAML) 元组列表
返回:
合并后的 YAML 字符串
异常:
ConfigConflictError: 当顶层键冲突时
"""
if not engines:
return ""
if len(engines) == 1:
return engines[0][1]
# Track which engine owns each top-level key
key_to_engine: dict[str, str] = {}
conflicts: List[Tuple[str, str, str]] = []
for engine_name, config_yaml in engines:
if not config_yaml or not config_yaml.strip():
continue
try:
parsed = yaml.safe_load(config_yaml)
except yaml.YAMLError:
# Skip invalid YAML
continue
if not isinstance(parsed, dict):
continue
# Check for top-level key conflicts
for key in parsed.keys():
if key in key_to_engine:
conflicts.append((key, key_to_engine[key], engine_name))
else:
key_to_engine[key] = engine_name
if conflicts:
raise ConfigConflictError(conflicts)
# No conflicts; join the configs with double newlines
configs = []
for _, config_yaml in engines:
if config_yaml and config_yaml.strip():
configs.append(config_yaml.strip())
return "\n\n".join(configs)
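A usage sketch for the merge helper (engine names and YAML bodies invented for illustration):

subfinder = ('Subdomain engine', 'subdomain_discovery:\n  enabled: true\n')
portscan = ('Port engine', 'port_scan:\n  top_ports: 100\n')
merge_engine_configs([subfinder, portscan])
# -> 'subdomain_discovery:\n  enabled: true\n\nport_scan:\n  top_ports: 100'

duplicate = ('Another subdomain engine', 'subdomain_discovery:\n  enabled: false\n')
try:
    merge_engine_configs([subfinder, duplicate])
except ConfigConflictError as e:
    print(e.conflicts)  # [('subdomain_discovery', 'Subdomain engine', 'Another subdomain engine')]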

View File

@@ -96,7 +96,13 @@ def ensure_wordlist_local(wordlist_name: str) -> str:
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
# Build a request carrying the API key
req = urllib_request.Request(download_url)
worker_api_key = os.getenv('WORKER_API_KEY', '')
if worker_api_key:
req.add_header('X-Worker-API-Key', worker_api_key)
with urllib_request.urlopen(req, context=ssl_context) as resp:
if resp.status != 200:
raise RuntimeError(f"下载字典失败HTTP {resp.status}")
data = resp.read()

View File

@@ -9,6 +9,7 @@ import logging
from apps.common.response_helpers import success_response, error_response
from apps.common.error_codes import ErrorCodes
from apps.scan.utils.config_merger import ConfigConflictError
logger = logging.getLogger(__name__)
@@ -118,7 +119,7 @@ class ScanViewSet(viewsets.ModelViewSet):
Request parameters:
{
"targets": [{"name": "example.com"}, {"name": "https://example.com/api"}],
"engine_ids": [1, 2]
}
Supported input formats:
@@ -133,7 +134,7 @@ class ScanViewSet(viewsets.ModelViewSet):
serializer.is_valid(raise_exception=True)
targets_data = serializer.validated_data['targets']
engine_ids = serializer.validated_data.get('engine_ids')
try:
# Extract the list of input strings
@@ -141,7 +142,7 @@ class ScanViewSet(viewsets.ModelViewSet):
# 1. Use QuickScanService to parse the inputs and create assets
quick_scan_service = QuickScanService()
result = quick_scan_service.process_quick_scan(inputs, engine_ids[0] if engine_ids else None)
targets = result['targets']
@@ -153,17 +154,19 @@ class ScanViewSet(viewsets.ModelViewSet):
status_code=status.HTTP_400_BAD_REQUEST
)
# 2. Prepare the multi-engine scan
scan_service = ScanService()
_, merged_configuration, engine_names, engine_ids = scan_service.prepare_initiate_scan_multi_engine(
target_id=targets[0].id,  # use the first target to validate the engines
engine_ids=engine_ids
)
# 3. Launch scans in bulk
created_scans = scan_service.create_scans(
targets=targets,
engine_ids=engine_ids,
engine_names=engine_names,
merged_configuration=merged_configuration
)
# Check whether scan tasks were created successfully
@@ -192,6 +195,17 @@ class ScanViewSet(viewsets.ModelViewSet):
},
status_code=status.HTTP_201_CREATED
)
except ConfigConflictError as e:
return error_response(
code='CONFIG_CONFLICT',
message=str(e),
details=[
{'key': k, 'engines': [e1, e2]}
for k, e1, e2 in e.conflicts
],
status_code=status.HTTP_400_BAD_REQUEST
)
except ValidationError as e:
return error_response(
@@ -214,7 +228,7 @@ class ScanViewSet(viewsets.ModelViewSet):
Request parameters:
- organization_id: organization ID (int, optional)
- target_id: target ID (int, optional)
- engine_ids: list of scan engine IDs (list[int], required)
Note: provide exactly one of organization_id and target_id
@@ -224,21 +238,38 @@ class ScanViewSet(viewsets.ModelViewSet):
# Read the request data
organization_id = request.data.get('organization_id')
target_id = request.data.get('target_id')
engine_ids = request.data.get('engine_ids')
# Validate engine_ids
if not engine_ids:
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='缺少必填参数: engine_ids',
status_code=status.HTTP_400_BAD_REQUEST
)
if not isinstance(engine_ids, list):
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='engine_ids 必须是数组',
status_code=status.HTTP_400_BAD_REQUEST
)
try:
# Step 1: prepare the data for the multi-engine scan
scan_service = ScanService()
targets, merged_configuration, engine_names, engine_ids = scan_service.prepare_initiate_scan_multi_engine(
organization_id=organization_id,
target_id=target_id,
engine_ids=engine_ids
)
# Step 2: bulk-create scan records and dispatch scan tasks
created_scans = scan_service.create_scans(
targets=targets,
engine_ids=engine_ids,
engine_names=engine_names,
merged_configuration=merged_configuration
)
# Check whether scan tasks were created successfully
@@ -259,6 +290,17 @@ class ScanViewSet(viewsets.ModelViewSet):
},
status_code=status.HTTP_201_CREATED
)
except ConfigConflictError as e:
return error_response(
code='CONFIG_CONFLICT',
message=str(e),
details=[
{'key': k, 'engines': [e1, e2]}
for k, e1, e2 in e.conflicts
],
status_code=status.HTTP_400_BAD_REQUEST
)
except ObjectDoesNotExist as e:
# Resource-not-found error (raised by the service layer)

View File

@@ -17,6 +17,7 @@ from ..serializers import (
)
from ..services.scheduled_scan_service import ScheduledScanService
from ..repositories import ScheduledScanDTO
from ..utils.config_merger import ConfigConflictError
from apps.common.pagination import BasePagination
from apps.common.response_helpers import success_response, error_response
from apps.common.error_codes import ErrorCodes
@@ -67,7 +68,7 @@ class ScheduledScanViewSet(viewsets.ModelViewSet):
data = serializer.validated_data
dto = ScheduledScanDTO(
name=data['name'],
engine_ids=data['engine_ids'],
organization_id=data.get('organization_id'),
target_id=data.get('target_id'),
cron_expression=data.get('cron_expression', '0 2 * * *'),
@@ -81,6 +82,16 @@ class ScheduledScanViewSet(viewsets.ModelViewSet):
data=response_serializer.data,
status_code=status.HTTP_201_CREATED
)
except ConfigConflictError as e:
return error_response(
code='CONFIG_CONFLICT',
message=str(e),
details=[
{'key': k, 'engines': [e1, e2]}
for k, e1, e2 in e.conflicts
],
status_code=status.HTTP_400_BAD_REQUEST
)
except ValidationError as e:
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
@@ -98,7 +109,7 @@ class ScheduledScanViewSet(viewsets.ModelViewSet):
data = serializer.validated_data
dto = ScheduledScanDTO(
name=data.get('name'),
engine_ids=data.get('engine_ids'),
organization_id=data.get('organization_id'),
target_id=data.get('target_id'),
cron_expression=data.get('cron_expression'),
@@ -109,6 +120,16 @@ class ScheduledScanViewSet(viewsets.ModelViewSet):
response_serializer = ScheduledScanSerializer(scheduled_scan)
return success_response(data=response_serializer.data)
except ConfigConflictError as e:
return error_response(
code='CONFIG_CONFLICT',
message=str(e),
details=[
{'key': k, 'engines': [e1, e2]}
for k, e1, e2 in e.conflicts
],
status_code=status.HTTP_400_BAD_REQUEST
)
except ValidationError as e:
return error_response(
code=ErrorCodes.VALIDATION_ERROR,

View File

@@ -0,0 +1,52 @@
# Generated by Django 5.2.7 on 2026-01-02 04:45
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
]
operations = [
migrations.CreateModel(
name='Target',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('name', models.CharField(blank=True, default='', help_text='目标标识(域名/IP/CIDR)', max_length=300)),
('type', models.CharField(choices=[('domain', '域名'), ('ip', 'IP地址'), ('cidr', 'CIDR范围')], db_index=True, default='domain', help_text='目标类型', max_length=20)),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('last_scanned_at', models.DateTimeField(blank=True, help_text='最后扫描时间', null=True)),
('deleted_at', models.DateTimeField(blank=True, db_index=True, help_text='删除时间(NULL表示未删除)', null=True)),
],
options={
'verbose_name': '扫描目标',
'verbose_name_plural': '扫描目标',
'db_table': 'target',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['type'], name='target_type_36a73c_idx'), models.Index(fields=['-created_at'], name='target_created_67f489_idx'), models.Index(fields=['deleted_at', '-created_at'], name='target_deleted_9fc9da_idx'), models.Index(fields=['deleted_at', 'type'], name='target_deleted_306a89_idx'), models.Index(fields=['name'], name='target_name_f1c641_idx')],
'constraints': [models.UniqueConstraint(condition=models.Q(('deleted_at__isnull', True)), fields=('name',), name='unique_target_name_active')],
},
),
migrations.CreateModel(
name='Organization',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('name', models.CharField(blank=True, default='', help_text='组织名称', max_length=300)),
('description', models.CharField(blank=True, default='', help_text='组织描述', max_length=1000)),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('deleted_at', models.DateTimeField(blank=True, db_index=True, help_text='删除时间(NULL表示未删除)', null=True)),
('targets', models.ManyToManyField(blank=True, help_text='所属目标列表', related_name='organizations', to='targets.target')),
],
options={
'verbose_name': '组织',
'verbose_name_plural': '组织',
'db_table': 'organization',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['-created_at'], name='organizatio_created_012eac_idx'), models.Index(fields=['deleted_at', '-created_at'], name='organizatio_deleted_2c604f_idx'), models.Index(fields=['name'], name='organizatio_name_bcc2ee_idx')],
'constraints': [models.UniqueConstraint(condition=models.Q(('deleted_at__isnull', True)), fields=('name',), name='unique_organization_name_active')],
},
),
]

View File

@@ -177,6 +177,10 @@ STATIC_URL = 'static/'
DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
# ==================== Worker API Key configuration ====================
# Worker node authentication key (read from environment variables)
WORKER_API_KEY = os.environ.get('WORKER_API_KEY', '')
# ==================== REST Framework configuration ====================
REST_FRAMEWORK = {
'DEFAULT_PAGINATION_CLASS': 'apps.common.pagination.BasePagination',  # use the base paginator
@@ -186,6 +190,14 @@ REST_FRAMEWORK = {
'apps.common.authentication.CsrfExemptSessionAuthentication',
],
# Global permissions: authentication required by default; public and Worker endpoints are handled separately in the permission classes
'DEFAULT_PERMISSION_CLASSES': [
'apps.common.permissions.IsAuthenticatedOrPublic',
],
# Custom exception handler: unified 401/403 error response format
'EXCEPTION_HANDLER': 'apps.common.exception_handlers.custom_exception_handler',
# JSON naming conversion: backend snake_case ↔ frontend camelCase
'DEFAULT_RENDERER_CLASSES': (
'djangorestframework_camel_case.render.CamelCaseJSONRenderer',  # responses converted to camelCase
@@ -345,12 +357,6 @@ TASK_SUBMIT_INTERVAL = int(os.getenv('TASK_SUBMIT_INTERVAL', '6'))
# Local Worker Docker network name (must match docker-compose.yml)
DOCKER_NETWORK_NAME = os.getenv('DOCKER_NETWORK_NAME', 'xingrin_network')
# Host mount source paths (all nodes use the same fixed paths)
# Create before deployment: mkdir -p /opt/xingrin
HOST_RESULTS_DIR = '/opt/xingrin/results'
@@ -361,25 +367,16 @@ HOST_WORDLISTS_DIR = '/opt/xingrin/wordlists'
# ============================================
# Worker 配置中心(任务容器从 /api/workers/config/ 获取)
# ============================================
# The Worker database address is returned dynamically by the config API in worker_views.py
# Different configs are returned depending on the request origin (local/remote):
# - Local Worker (inside the Docker network): uses the internal service name postgres
# - Remote Worker (public access): uses PUBLIC_HOST
#
# The variables below are fallback/compatibility only; actual values are generated dynamically by the API
# Note: Redis is used only inside the Server container; Workers never connect to Redis directly
_db_host = DATABASES['default']['HOST']
_is_internal_db = _db_host in ('postgres', 'localhost', '127.0.0.1')
WORKER_DB_HOST = os.getenv('WORKER_DB_HOST', _db_host)
# Container mount target paths (uniformly /opt/xingrin)
CONTAINER_RESULTS_MOUNT = '/opt/xingrin/results'
CONTAINER_LOGS_MOUNT = '/opt/xingrin/logs'

View File

@@ -16,7 +16,6 @@ Including another URLconf
"""
from django.contrib import admin
from django.urls import path, include
from drf_yasg.views import get_schema_view
from drf_yasg import openapi
@@ -30,7 +29,6 @@ schema_view = get_schema_view(
description="Web 应用侦察工具 API 文档",
),
public=True,
)
urlpatterns = [

File diff suppressed because it is too large

View File

@@ -41,6 +41,7 @@ python-dateutil==2.9.0
pytz==2024.1
validators==0.22.0
PyYAML==6.0.1
ruamel.yaml>=0.18.0 # comment-preserving YAML parsing
colorlog==6.8.2 # colored log output
python-json-logger==2.0.7 # JSON structured logging
Jinja2>=3.1.6 # command template engine

View File

@@ -180,6 +180,28 @@ def get_db_config() -> dict:
}
def generate_raw_response_headers(headers_dict: dict) -> str:
"""
Convert a response-header dict into a raw HTTP response-header string
Args:
headers_dict: response-header dict
Returns:
a raw HTTP response-header string, formatted like:
HTTP/1.1 200 OK
Server: nginx
Content-Type: text/html
...
"""
lines = ['HTTP/1.1 200 OK']
for key, value in headers_dict.items():
# Convert underscores to hyphens and title-case the header name
header_name = key.replace('_', '-').title()
lines.append(f'{header_name}: {value}')
return '\r\n'.join(lines)
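Example (header values invented): keys are rewritten underscore→hyphen and title-cased, joined with CRLF after a fixed status line:

generate_raw_response_headers({'server': 'nginx', 'content_type': 'text/html'})
# -> 'HTTP/1.1 200 OK\r\nServer: nginx\r\nContent-Type: text/html'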
DB_CONFIG = get_db_config()
@@ -238,6 +260,12 @@ class TestDataGenerator:
def clear_data(self):
"""Clear all test data"""
cur = self.conn.cursor()
# Drop the IMMV first (avoids the pg_ivm anyarray bug)
print(" 删除 IMMV...")
cur.execute("DROP TABLE IF EXISTS asset_search_view CASCADE")
self.conn.commit()
tables = [
# fingerprint tables
'ehole_fingerprint', 'goby_fingerprint', 'wappalyzer_fingerprint',
@@ -254,6 +282,26 @@ class TestDataGenerator:
for table in tables:
cur.execute(f"DELETE FROM {table}")
self.conn.commit()
# Rebuild the IMMV
print(" 重建 IMMV...")
cur.execute("""
SELECT pgivm.create_immv('asset_search_view', $$
SELECT
w.id,
w.url,
w.host,
w.title,
w.tech,
w.status_code,
w.response_headers,
w.response_body,
w.created_at,
w.target_id
FROM website w
$$)
""")
self.conn.commit()
print(" ✓ 数据清除完成\n")
def create_workers(self) -> list:
@@ -548,6 +596,10 @@ class TestDataGenerator:
'Authentication failed for protected resources.',
]
# Build the engine-name map
cur.execute("SELECT id, name FROM scan_engine WHERE id = ANY(%s)", (engine_ids,))
engine_name_map = {row[0]: row[1] for row in cur.fetchall()}
ids = []
# Randomly choose the number of targets - increased to 80-120
num_targets = min(random.randint(80, 120), len(target_ids))
@@ -558,7 +610,10 @@ class TestDataGenerator:
num_scans = random.randint(3, 15)
for _ in range(num_scans):
status = random.choices(statuses, weights=status_weights)[0]
# Randomly select 1-3 engines
num_engines = random.randint(1, min(3, len(engine_ids)))
selected_engine_ids = random.sample(engine_ids, num_engines)
selected_engine_names = [engine_name_map.get(eid, f'Engine-{eid}') for eid in selected_engine_ids]
worker_id = random.choice(worker_ids) if worker_ids else None
progress = random.randint(10, 95) if status == 'running' else (100 if status == 'completed' else random.randint(0, 50))
@@ -581,20 +636,20 @@ class TestDataGenerator:
cur.execute("""
INSERT INTO scan (
target_id, engine_ids, engine_names, merged_configuration, status, worker_id, progress, current_stage,
results_dir, error_message, container_ids, stage_progress,
cached_subdomains_count, cached_websites_count, cached_endpoints_count,
cached_ips_count, cached_directories_count, cached_vulns_total,
cached_vulns_critical, cached_vulns_high, cached_vulns_medium, cached_vulns_low,
created_at, stopped_at, deleted_at
) VALUES (
%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s,
%s, %s, %s, %s, %s, %s, %s, %s, %s, %s,
NOW() - INTERVAL '%s days', %s, NULL
)
RETURNING id
""", (
target_id, selected_engine_ids, json.dumps(selected_engine_names), '', status, worker_id, progress, stage,
f'/app/results/scan_{target_id}_{random.randint(1000, 9999)}', error_msg, '{}', '{}',
subdomains, websites, endpoints, ips, directories, vulns_total,
vulns_critical, vulns_high, vulns_medium, vulns_low,
@@ -651,6 +706,10 @@ class TestDataGenerator:
num_schedules = random.randint(40, 50)
selected = random.sample(schedule_templates, min(num_schedules, len(schedule_templates)))
# Build the engine-name map
cur.execute("SELECT id, name FROM scan_engine WHERE id = ANY(%s)", (engine_ids,))
engine_name_map = {row[0]: row[1] for row in cur.fetchall()}
count = 0
for name_base, cron_template in selected:
name = f'{name_base}-{suffix}-{count:02d}'
@@ -662,7 +721,11 @@ class TestDataGenerator:
)
enabled = random.random() > 0.3  # 70% enabled
# Randomly select 1-3 engines
num_engines = random.randint(1, min(3, len(engine_ids)))
selected_engine_ids = random.sample(engine_ids, num_engines)
selected_engine_names = [engine_name_map.get(eid, f'Engine-{eid}') for eid in selected_engine_ids]
# Randomly associate with either an organization or a target
if org_ids and target_ids:
if random.random() > 0.5:
@@ -686,12 +749,12 @@ class TestDataGenerator:
cur.execute("""
INSERT INTO scheduled_scan (
name, engine_ids, engine_names, merged_configuration, organization_id, target_id, cron_expression, is_enabled,
run_count, last_run_time, next_run_time, created_at, updated_at
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW() - INTERVAL '%s days', NOW())
ON CONFLICT DO NOTHING
""", (
name, selected_engine_ids, json.dumps(selected_engine_names), '', org_id, target_id, cron, enabled,
run_count if has_run else 0,
datetime.now() - timedelta(days=random.randint(0, 14), hours=random.randint(0, 23)) if has_run else None,
datetime.now() + timedelta(hours=random.randint(1, 336))  # up to 2 weeks out
@@ -812,7 +875,7 @@ class TestDataGenerator:
]
# Realistic response body content
response_bodies = [
'<!DOCTYPE html><html lang="en"><head><meta charset="UTF-8"><meta name="viewport" content="width=device-width, initial-scale=1.0"><title>Login - Enterprise Portal</title><link rel="stylesheet" href="/assets/css/main.css"></head><body><div id="app"></div><script src="/assets/js/bundle.js"></script></body></html>',
'<!DOCTYPE html><html><head><title>Dashboard</title><meta name="description" content="Enterprise management dashboard for monitoring and analytics"><link rel="icon" href="/favicon.ico"></head><body><noscript>You need to enable JavaScript to run this app.</noscript><div id="root"></div></body></html>',
'{"status":"ok","version":"2.4.1","environment":"production","timestamp":"2024-12-22T10:30:00Z","services":{"database":"healthy","cache":"healthy","queue":"healthy"},"uptime":864000}',
@@ -843,14 +906,27 @@ class TestDataGenerator:
# Generate a URL with a fixed length of 245
url = generate_fixed_length_url(target_name, length=245, path_hint=f'website/{i:04d}')
# Generate mock response-header data
response_headers = {
'server': random.choice(['nginx', 'Apache', 'cloudflare', 'Microsoft-IIS/10.0']),
'content_type': 'text/html; charset=utf-8',
'x_powered_by': random.choice(['PHP/8.2', 'ASP.NET', 'Express', None]),
'x_frame_options': random.choice(['DENY', 'SAMEORIGIN', None]),
'strict_transport_security': 'max-age=31536000; includeSubDomains' if random.choice([True, False]) else None,
'set_cookie': f'session={random.randint(100000, 999999)}; HttpOnly; Secure' if random.choice([True, False]) else None,
}
# Drop None values
response_headers = {k: v for k, v in response_headers.items() if v is not None}
batch_data.append((
url, target_id, target_name, random.choice(titles),
random.choice(webservers), random.choice(tech_stacks),
random.choice([200, 301, 302, 403, 404]),
random.randint(1000, 500000), 'text/html; charset=utf-8',
f'https://{target_name}/login' if random.choice([True, False]) else '',
random.choice(response_bodies),
random.choice([True, False, None]),
generate_raw_response_headers(response_headers)
))
# Bulk insert
@@ -859,12 +935,12 @@ class TestDataGenerator:
execute_values(cur, """
INSERT INTO website (
url, target_id, host, title, webserver, tech, status_code,
content_length, content_type, location, response_body, vhost,
response_headers, created_at
) VALUES %s
ON CONFLICT DO NOTHING
RETURNING id
""", batch_data, template="(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW())")
ids = [row[0] for row in cur.fetchall()]
print(f" ✓ 创建了 {len(batch_data)} 个网站\n")
@@ -965,7 +1041,7 @@ class TestDataGenerator:
]
# Realistic API response bodies
response_bodies = [
'{"status":"success","data":{"user_id":12345,"username":"john_doe","email":"john@example.com","role":"user","created_at":"2024-01-15T10:30:00Z","last_login":"2024-12-22T08:45:00Z"}}',
'{"success":true,"message":"Authentication successful","token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c","expires_in":3600}',
'{"error":"Unauthorized","code":"AUTH_FAILED","message":"Invalid credentials provided. Please check your username and password.","timestamp":"2024-12-22T15:30:45.123Z","request_id":"req_abc123xyz"}',
@@ -1017,14 +1093,27 @@ class TestDataGenerator:
# Generate 10-20 tags (gf_patterns)
tags = random.choice(gf_patterns)
# Generate mock response-header data
response_headers = {
'server': random.choice(['nginx', 'gunicorn', 'uvicorn', 'Apache']),
'content_type': 'application/json',
'x_request_id': f'req_{random.randint(100000, 999999)}',
'x_ratelimit_limit': str(random.choice([100, 1000, 5000])),
'x_ratelimit_remaining': str(random.randint(0, 1000)),
'cache_control': random.choice(['no-cache', 'max-age=3600', 'private', None]),
}
# Drop None values
response_headers = {k: v for k, v in response_headers.items() if v is not None}
batch_data.append((
url, target_id, target_name, title,
random.choice(['nginx/1.24.0', 'gunicorn/21.2.0']),
random.choice([200, 201, 301, 400, 401, 403, 404, 500]),
random.randint(100, 50000), 'application/json',
tech_list,
'', random.choice(body_previews),
random.choice([True, False, None]), tags
'', random.choice(response_bodies),
random.choice([True, False, None]), tags,
generate_raw_response_headers(response_headers)
))
count += 1
@@ -1033,11 +1122,11 @@ class TestDataGenerator:
execute_values(cur, """
INSERT INTO endpoint (
url, target_id, host, title, webserver, status_code, content_length,
content_type, tech, location, body_preview, vhost, matched_gf_patterns,
created_at
content_type, tech, location, response_body, vhost, matched_gf_patterns,
response_headers, created_at
) VALUES %s
ON CONFLICT DO NOTHING
""", batch_data, template="(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW())")
""", batch_data, template="(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW())")
print(f" ✓ 创建了 {count} 个端点\n")
@@ -1185,77 +1274,79 @@ class TestDataGenerator:
print(f" ✓ 创建了 {count} 个主机端口映射\n")
def create_vulnerabilities(self, target_ids: list):
"""创建漏洞"""
"""创建漏洞(基于 website URL 前缀)"""
print("🐛 创建漏洞...")
cur = self.conn.cursor()
vuln_types = [
'sql-injection-authentication-bypass-vulnerability-', # 50 chars
'cross-site-scripting-xss-stored-persistent-attack-', # 50 chars
'cross-site-request-forgery-csrf-token-validation--', # 50 chars
'server-side-request-forgery-ssrf-internal-access--', # 50 chars
'xml-external-entity-xxe-injection-vulnerability---', # 50 chars
'remote-code-execution-rce-command-injection-flaw--', # 50 chars
'local-file-inclusion-lfi-path-traversal-exploit---', # 50 chars
'directory-traversal-arbitrary-file-read-access----', # 50 chars
'authentication-bypass-session-management-flaw-----', # 50 chars
'insecure-direct-object-reference-idor-access-ctrl-', # 50 chars
'sensitive-data-exposure-information-disclosure----', # 50 chars
'security-misconfiguration-default-credentials-----', # 50 chars
'broken-access-control-privilege-escalation-vuln---', # 50 chars
'cors-misconfiguration-cross-origin-data-leakage---', # 50 chars
'subdomain-takeover-dns-misconfiguration-exploit---', # 50 chars
'exposed-admin-panel-unauthorized-access-control---', # 50 chars
'default-credentials-weak-authentication-bypass----', # 50 chars
'information-disclosure-sensitive-data-exposure----', # 50 chars
'command-injection-os-command-execution-exploit----', # 50 chars
'ldap-injection-directory-service-manipulation-----', # 50 chars
'xpath-injection-xml-query-manipulation-attack-----', # 50 chars
'nosql-injection-mongodb-query-manipulation--------', # 50 chars
'template-injection-ssti-server-side-execution-----', # 50 chars
'deserialization-vulnerability-object-injection----', # 50 chars
'jwt-vulnerability-token-forgery-authentication----', # 50 chars
'open-redirect-url-redirection-phishing-attack-----', # 50 chars
'http-request-smuggling-cache-poisoning-attack-----', # 50 chars
'host-header-injection-password-reset-poisoning----', # 50 chars
'clickjacking-ui-redressing-frame-injection--------', # 50 chars
'session-fixation-authentication-session-attack----', # 50 chars
'sql-injection-authentication-bypass-vulnerability-',
'cross-site-scripting-xss-stored-persistent-attack-',
'cross-site-request-forgery-csrf-token-validation--',
'server-side-request-forgery-ssrf-internal-access--',
'xml-external-entity-xxe-injection-vulnerability---',
'remote-code-execution-rce-command-injection-flaw--',
'local-file-inclusion-lfi-path-traversal-exploit---',
'directory-traversal-arbitrary-file-read-access----',
'authentication-bypass-session-management-flaw-----',
'insecure-direct-object-reference-idor-access-ctrl-',
'sensitive-data-exposure-information-disclosure----',
'security-misconfiguration-default-credentials-----',
'broken-access-control-privilege-escalation-vuln---',
'cors-misconfiguration-cross-origin-data-leakage---',
'subdomain-takeover-dns-misconfiguration-exploit---',
'exposed-admin-panel-unauthorized-access-control---',
'default-credentials-weak-authentication-bypass----',
'information-disclosure-sensitive-data-exposure----',
'command-injection-os-command-execution-exploit----',
'ldap-injection-directory-service-manipulation-----',
]
sources = [
'nuclei-vulnerability-scanner--', # 30 chars
'dalfox-xss-parameter-analysis-', # 30 chars
'sqlmap-sql-injection-testing--', # 30 chars
'crlfuzz-crlf-injection-finder-', # 30 chars
'httpx-web-probe-fingerprint---', # 30 chars
'manual-penetration-testing----', # 30 chars
'burp-suite-professional-scan--', # 30 chars
'owasp-zap-security-scanner----', # 30 chars
'nmap-network-service-scanner--', # 30 chars
'nikto-web-server-scanner------', # 30 chars
'wpscan-wordpress-vuln-scan----', # 30 chars
'dirsearch-directory-brute-----', # 30 chars
'ffuf-web-fuzzer-content-disc--', # 30 chars
'amass-subdomain-enumeration---', # 30 chars
'subfinder-passive-subdomain---', # 30 chars
'masscan-port-scanner-fast-----', # 30 chars
'nessus-vulnerability-assess---', # 30 chars
'qualys-cloud-security-scan----', # 30 chars
'acunetix-web-vuln-scanner-----', # 30 chars
'semgrep-static-code-analysis--', # 30 chars
'nuclei-vulnerability-scanner--',
'dalfox-xss-parameter-analysis-',
'sqlmap-sql-injection-testing--',
'crlfuzz-crlf-injection-finder-',
'httpx-web-probe-fingerprint---',
'manual-penetration-testing----',
'burp-suite-professional-scan--',
'owasp-zap-security-scanner----',
]
severities = ['unknown', 'info', 'low', 'medium', 'high', 'critical']
# Fetch domain targets
cur.execute("SELECT id, name FROM target WHERE type = 'domain' AND deleted_at IS NULL LIMIT 80")
domain_targets = cur.fetchall()
# Vulnerability path suffixes (appended to website URLs)
vuln_paths = [
'/api/users?id=1',
'/api/admin/config',
'/api/v1/auth/login',
'/api/v2/data/export',
'/admin/settings',
'/debug/console',
'/backup/db.sql',
'/.env',
'/.git/config',
'/wp-admin/',
'/phpmyadmin/',
'/api/graphql',
'/swagger.json',
'/actuator/health',
'/metrics',
]
# Fetch URL and target_id for all websites
cur.execute("SELECT id, url, target_id FROM website LIMIT 500")
websites = cur.fetchall()
if not websites:
print(" ⚠ 没有 website 数据,跳过漏洞生成\n")
return
count = 0
batch_data = []
for target_id, target_name in domain_targets:
num = random.randint(30, 80)
for website_id, website_url, target_id in websites:
# Generate 1-5 vulnerabilities per website
num_vulns = random.randint(1, 5)
for idx in range(num):
for idx in range(num_vulns):
severity = random.choice(severities)
cvss_ranges = {
'critical': (9.0, 10.0), 'high': (7.0, 8.9), 'medium': (4.0, 6.9),
@@ -1264,22 +1355,22 @@ class TestDataGenerator:
cvss_range = cvss_ranges.get(severity, (0.0, 10.0))
cvss_score = round(random.uniform(*cvss_range), 1)
# Generate a fixed 245-character URL
url = generate_fixed_length_url(target_name, length=245, path_hint=f'vuln/{idx:04d}')
# Vulnerability URL = website URL + vulnerability path
# Strip query parameters from the website URL first
base_url = website_url.split('?')[0]
vuln_url = base_url + random.choice(vuln_paths)
# Generate a fixed 300-character description
description = generate_fixed_length_text(length=300, text_type='description')
raw_output = json.dumps({
'template': f'CVE-2024-{random.randint(10000, 99999)}',
'matcher_name': 'default',
'severity': severity,
'host': target_name,
'matched_at': url,
'matched_at': vuln_url,
})
batch_data.append((
target_id, url, random.choice(vuln_types), severity,
target_id, vuln_url, random.choice(vuln_types), severity,
random.choice(sources), cvss_score, description, raw_output
))
count += 1
@@ -1401,13 +1492,23 @@ class TestDataGenerator:
# Generate a fixed 245-character URL
url = generate_fixed_length_url(target_name, length=245, path_hint=f'website-snap/{i:04d}')
# Generate mock response header data
response_headers = {
'server': random.choice(['nginx', 'Apache', 'cloudflare']),
'content_type': 'text/html; charset=utf-8',
'x_frame_options': random.choice(['DENY', 'SAMEORIGIN', None]),
}
# Drop None values
response_headers = {k: v for k, v in response_headers.items() if v is not None}
batch_data.append((
scan_id, url, target_name, random.choice(titles),
random.choice(webservers), random.choice(tech_stacks),
random.choice([200, 301, 403]),
random.randint(1000, 50000), 'text/html; charset=utf-8',
'', # location field
'<!DOCTYPE html><html><head><title>Test</title></head><body>Content</body></html>'
'<!DOCTYPE html><html><head><title>Test</title></head><body>Content</body></html>',
generate_raw_response_headers(response_headers)
))
count += 1
@@ -1415,11 +1516,12 @@ class TestDataGenerator:
if batch_data:
execute_values(cur, """
INSERT INTO website_snapshot (
scan_id, url, host, title, web_server, tech, status,
content_length, content_type, location, body_preview, created_at
scan_id, url, host, title, webserver, tech, status_code,
content_length, content_type, location, response_body,
response_headers, created_at
) VALUES %s
ON CONFLICT DO NOTHING
""", batch_data, template="(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW())")
""", batch_data, template="(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW())")
print(f" ✓ 创建了 {count} 个网站快照\n")
@@ -1498,6 +1600,13 @@ class TestDataGenerator:
num_tags = random.randint(10, 20)
tags = random.sample(all_tags, min(num_tags, len(all_tags)))
# Generate mock response header data
response_headers = {
'server': 'nginx/1.24.0',
'content_type': 'application/json',
'x_request_id': f'req_{random.randint(100000, 999999)}',
}
batch_data.append((
scan_id, url, target_name, title,
random.choice([200, 201, 401, 403, 404]),
@@ -1506,7 +1615,8 @@ class TestDataGenerator:
'nginx/1.24.0',
'application/json', tech_list,
'{"status":"ok","data":{}}',
tags
tags,
generate_raw_response_headers(response_headers)
))
count += 1
@@ -1515,11 +1625,11 @@ class TestDataGenerator:
execute_values(cur, """
INSERT INTO endpoint_snapshot (
scan_id, url, host, title, status_code, content_length,
location, webserver, content_type, tech, body_preview,
matched_gf_patterns, created_at
location, webserver, content_type, tech, response_body,
matched_gf_patterns, response_headers, created_at
) VALUES %s
ON CONFLICT DO NOTHING
""", batch_data, template="(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW())")
""", batch_data, template="(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW())")
print(f" ✓ 创建了 {count} 个端点快照\n")
@@ -2543,9 +2653,10 @@ class MillionDataGenerator:
if len(batch_data) >= batch_size:
execute_values(cur, """
INSERT INTO website (url, target_id, host, title, webserver, tech,
status_code, content_length, content_type, location, body_preview, created_at)
status_code, content_length, content_type, location, response_body,
vhost, response_headers, created_at)
VALUES %s ON CONFLICT DO NOTHING
""", batch_data, template="(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW())")
""", batch_data, template="(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NULL, '', NOW())")
self.conn.commit()
batch_data = []
print(f"{count:,} / {target_count:,}")
@@ -2555,9 +2666,10 @@ class MillionDataGenerator:
if batch_data:
execute_values(cur, """
INSERT INTO website (url, target_id, host, title, webserver, tech,
status_code, content_length, content_type, location, body_preview, created_at)
status_code, content_length, content_type, location, response_body,
vhost, response_headers, created_at)
VALUES %s ON CONFLICT DO NOTHING
""", batch_data, template="(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW())")
""", batch_data, template="(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NULL, '', NOW())")
self.conn.commit()
print(f" ✓ 创建了 {count:,} 个网站\n")
@@ -2631,10 +2743,10 @@ class MillionDataGenerator:
if len(batch_data) >= batch_size:
execute_values(cur, """
INSERT INTO endpoint (url, target_id, host, title, webserver, status_code,
content_length, content_type, tech, location, body_preview, vhost,
matched_gf_patterns, created_at)
content_length, content_type, tech, location, response_body, vhost,
matched_gf_patterns, response_headers, created_at)
VALUES %s ON CONFLICT DO NOTHING
""", batch_data, template="(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW())")
""", batch_data, template="(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, '', NOW())")
self.conn.commit()
batch_data = []
print(f"{count:,} / {target_count:,}")
@@ -2644,10 +2756,10 @@ class MillionDataGenerator:
if batch_data:
execute_values(cur, """
INSERT INTO endpoint (url, target_id, host, title, webserver, status_code,
content_length, content_type, tech, location, body_preview, vhost,
matched_gf_patterns, created_at)
content_length, content_type, tech, location, response_body, vhost,
matched_gf_patterns, response_headers, created_at)
VALUES %s ON CONFLICT DO NOTHING
""", batch_data, template="(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW())")
""", batch_data, template="(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, '', NOW())")
self.conn.commit()
print(f" ✓ 创建了 {count:,} 个端点\n")

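The hunks above repeatedly call `generate_raw_response_headers(response_headers)`, whose definition is not part of this diff. A minimal sketch of what such a helper could look like, assuming it renders the underscore-keyed dicts built above into raw HTTP header text for the new `response_headers` column; the repository's actual implementation may differ:

```python
def generate_raw_response_headers(headers: dict) -> str:
    """Render a header dict into raw HTTP header text (hypothetical sketch).

    Keys like 'content_type' are assumed to map to their wire form
    ('Content-Type'); the real helper may format differently.
    """
    lines = []
    for key, value in headers.items():
        wire_key = "-".join(part.capitalize() for part in key.split("_"))
        lines.append(f"{wire_key}: {value}")
    return "\r\n".join(lines)

# Example: {'server': 'nginx'} -> 'Server: nginx'
```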
View File

@@ -95,6 +95,7 @@ EOF
RESPONSE=$(curl -k -s -X POST \
-H "Content-Type: application/json" \
-H "X-Worker-API-Key: ${WORKER_API_KEY}" \
-d "$REGISTER_DATA" \
"${API_URL}/api/workers/register/" 2>/dev/null)
@@ -116,7 +117,7 @@ if [ -z "$WORKER_ID" ]; then
# Wait for the server to become ready
log "Waiting for the server to become ready..."
for i in $(seq 1 30); do
if curl -k -s "${API_URL}/api/" > /dev/null 2>&1; then
if curl -k -s -H "X-Worker-API-Key: ${WORKER_API_KEY}" "${API_URL}/api/workers/config/?is_local=${IS_LOCAL}" > /dev/null 2>&1; then
log "${GREEN}Server 已就绪${NC}"
break
fi
@@ -189,6 +190,7 @@ EOF
RESPONSE_FILE=$(mktemp)
HTTP_CODE=$(curl -k -s -o "$RESPONSE_FILE" -w "%{http_code}" -X POST \
-H "Content-Type: application/json" \
-H "X-Worker-API-Key: ${WORKER_API_KEY}" \
-d "$JSON_DATA" \
"${API_URL}/api/workers/${WORKER_ID}/heartbeat/" 2>/dev/null || echo "000")
RESPONSE_BODY=$(cat "$RESPONSE_FILE" 2>/dev/null)

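Both requests above now carry the `X-Worker-API-Key` header. For reference, the register/heartbeat flow the script drives with curl could look like this in Python; the endpoints and header come straight from the script, while the payload shape and response field names are assumptions:

```python
import os
import requests

API_URL = os.environ["SERVER_URL"]        # e.g. https://server:8083
API_KEY = os.environ["WORKER_API_KEY"]    # shared secret from .env
HEADERS = {
    "Content-Type": "application/json",
    "X-Worker-API-Key": API_KEY,          # required since this change
}

def register_worker(payload: dict) -> str:
    # Mirrors: curl -k -X POST -d "$REGISTER_DATA" ${API_URL}/api/workers/register/
    resp = requests.post(f"{API_URL}/api/workers/register/",
                         json=payload, headers=HEADERS, verify=False, timeout=10)
    resp.raise_for_status()
    return resp.json()["id"]  # response field name assumed

def send_heartbeat(worker_id: str, payload: dict) -> int:
    # Mirrors: curl -k -X POST ${API_URL}/api/workers/${WORKER_ID}/heartbeat/
    resp = requests.post(f"{API_URL}/api/workers/{worker_id}/heartbeat/",
                         json=payload, headers=HEADERS, verify=False, timeout=10)
    return resp.status_code
```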
View File

@@ -27,10 +27,50 @@ BLUE='\033[0;34m'
RED='\033[0;31m'
NC='\033[0m'
log_info() { echo -e "${BLUE}[XingRin]${NC} $1"; }
log_success() { echo -e "${GREEN}[XingRin]${NC} $1"; }
log_warn() { echo -e "${YELLOW}[XingRin]${NC} $1"; }
log_error() { echo -e "${RED}[XingRin]${NC} $1"; }
# Gradient color definitions
CYAN='\033[0;36m'
MAGENTA='\033[0;35m'
BOLD='\033[1m'
DIM='\033[2m'
log_info() { echo -e "${CYAN}${NC} $1"; }
log_success() { echo -e "${GREEN}${NC} $1"; }
log_warn() { echo -e "${YELLOW}${NC} $1"; }
log_error() { echo -e "${RED}${NC} $1"; }
# Fancy banner
show_banner() {
echo -e ""
echo -e "${CYAN}${BOLD} ██╗ ██╗██╗███╗ ██╗ ██████╗ ██████╗ ██╗███╗ ██╗${NC}"
echo -e "${CYAN} ╚██╗██╔╝██║████╗ ██║██╔════╝ ██╔══██╗██║████╗ ██║${NC}"
echo -e "${BLUE}${BOLD} ╚███╔╝ ██║██╔██╗ ██║██║ ███╗██████╔╝██║██╔██╗ ██║${NC}"
echo -e "${BLUE} ██╔██╗ ██║██║╚██╗██║██║ ██║██╔══██╗██║██║╚██╗██║${NC}"
echo -e "${MAGENTA}${BOLD} ██╔╝ ██╗██║██║ ╚████║╚██████╔╝██║ ██║██║██║ ╚████║${NC}"
echo -e "${MAGENTA} ╚═╝ ╚═╝╚═╝╚═╝ ╚═══╝ ╚═════╝ ╚═╝ ╚═╝╚═╝╚═╝ ╚═══╝${NC}"
echo -e ""
echo -e "${DIM} ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
echo -e "${BOLD} 🚀 分布式安全扫描平台 │ Worker 节点部署${NC}"
echo -e "${DIM} ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
echo -e ""
}
# Completion banner
show_complete() {
echo -e ""
echo -e "${GREEN}${BOLD} ╔═══════════════════════════════════════════════════╗${NC}"
echo -e "${GREEN}${BOLD} ║ ║${NC}"
echo -e "${GREEN}${BOLD} ║ ██████╗ ██████╗ ███╗ ██╗███████╗██╗ ║${NC}"
echo -e "${GREEN}${BOLD} ║ ██╔══██╗██╔═══██╗████╗ ██║██╔════╝██║ ║${NC}"
echo -e "${GREEN}${BOLD} ║ ██║ ██║██║ ██║██╔██╗ ██║█████╗ ██║ ║${NC}"
echo -e "${GREEN}${BOLD} ║ ██║ ██║██║ ██║██║╚██╗██║██╔══╝ ╚═╝ ║${NC}"
echo -e "${GREEN}${BOLD} ║ ██████╔╝╚██████╔╝██║ ╚████║███████╗██╗ ║${NC}"
echo -e "${GREEN}${BOLD} ║ ╚═════╝ ╚═════╝ ╚═╝ ╚═══╝╚══════╝╚═╝ ║${NC}"
echo -e "${GREEN}${BOLD} ║ ║${NC}"
echo -e "${GREEN}${BOLD} ║ ✨ XingRin Worker 节点部署完成! ║${NC}"
echo -e "${GREEN}${BOLD} ║ ║${NC}"
echo -e "${GREEN}${BOLD} ╚═══════════════════════════════════════════════════╝${NC}"
echo -e ""
}
# Wait for the apt lock to be released
wait_for_apt_lock() {
@@ -150,9 +190,7 @@ pull_image() {
# Main flow
main() {
log_info "=========================================="
log_info " XingRin 节点安装"
log_info "=========================================="
show_banner
detect_os
install_docker
@@ -162,9 +200,7 @@ main() {
touch "$DOCKER_MARKER"
log_success "=========================================="
log_success " ✓ 安装完成"
log_success "=========================================="
show_complete
}
main "$@"

View File

@@ -30,6 +30,7 @@ IMAGE="${DOCKER_USER}/xingrin-agent:${IMAGE_TAG}"
# Preset variables (replaced by deploy_service.py during remote deployment)
PRESET_SERVER_URL="{{HEARTBEAT_API_URL}}"
PRESET_WORKER_ID="{{WORKER_ID}}"
PRESET_API_KEY="{{WORKER_API_KEY}}"
# Color definitions
GREEN='\033[0;32m'
@@ -68,6 +69,7 @@ start_agent() {
-e SERVER_URL="${PRESET_SERVER_URL}" \
-e WORKER_ID="${PRESET_WORKER_ID}" \
-e IMAGE_TAG="${IMAGE_TAG}" \
-e WORKER_API_KEY="${PRESET_API_KEY}" \
-v /proc:/host/proc:ro \
${IMAGE}

View File

@@ -9,9 +9,8 @@ DB_USER=postgres
DB_PASSWORD=123.com
# ==================== Redis Configuration ====================
# Inside the Docker network, the Redis service name is redis
# Redis is used only on the internal Docker network and exposes no public port
REDIS_HOST=redis
REDIS_PORT=6379
REDIS_DB=0
# ==================== Service Port Configuration ====================
@@ -51,6 +50,12 @@ LOG_LEVEL=INFO
# Whether to log command execution (heavy scanning increases disk usage)
ENABLE_COMMAND_LOGGING=true
# ==================== Worker API Key Configuration ====================
# Worker node authentication key (API auth between workers and the main server)
# In production, be sure to replace this with a strong random key (32+ random characters recommended)
# Generate with: openssl rand -hex 32
WORKER_API_KEY=change-me-to-a-secure-random-key
# ==================== Docker Hub Configuration (production mode) ====================
# Used when pulling images from Docker Hub in production mode
DOCKER_USER=yyhuni

View File

@@ -2,9 +2,13 @@ services:
# PostgreSQL (optional): not started when using a remote database
# Local mode: docker compose --profile local-db up -d
# Remote mode: docker compose up -d (set DB_HOST to the remote address)
# Uses a custom image with the pg_ivm extension preinstalled
postgres:
profiles: ["local-db"]
image: postgres:15
build:
context: ./postgres
dockerfile: Dockerfile
image: ${DOCKER_USER:-yyhuni}/xingrin-postgres:${IMAGE_TAG:-dev}
restart: always
environment:
POSTGRES_DB: ${DB_NAME}
@@ -15,6 +19,9 @@ services:
- ./postgres/init-user-db.sh:/docker-entrypoint-initdb.d/init-user-db.sh
ports:
- "${DB_PORT}:5432"
command: >
postgres
-c shared_preload_libraries=pg_ivm
healthcheck:
test: ["CMD-SHELL", "pg_isready -U ${DB_USER}"]
interval: 5s
@@ -24,8 +31,6 @@ services:
redis:
image: redis:7-alpine
restart: always
ports:
- "${REDIS_PORT}:6379"
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 5s
@@ -49,7 +54,8 @@ services:
- /opt/xingrin:/opt/xingrin
- /var/run/docker.sock:/var/run/docker.sock
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8888/api/"]
# Use the dedicated health check endpoint (no authentication required)
test: ["CMD", "curl", "-f", "http://localhost:8888/api/health/"]
interval: 30s
timeout: 10s
retries: 3
@@ -65,9 +71,10 @@ services:
restart: always
environment:
- SERVER_URL=http://server:8888
- WORKER_NAME=本地节点
- WORKER_NAME=Local-Worker
- IS_LOCAL=true
- IMAGE_TAG=${IMAGE_TAG:-dev}
- WORKER_API_KEY=${WORKER_API_KEY}
depends_on:
server:
condition: service_healthy

View File

@@ -8,9 +8,13 @@
services:
# PostgreSQL (optional): not started when using a remote database
# Uses a custom image with the pg_ivm extension preinstalled
postgres:
profiles: ["local-db"]
image: postgres:15
build:
context: ./postgres
dockerfile: Dockerfile
image: ${DOCKER_USER:-yyhuni}/xingrin-postgres:${IMAGE_TAG:?IMAGE_TAG is required}
restart: always
environment:
POSTGRES_DB: ${DB_NAME}
@@ -21,6 +25,9 @@ services:
- ./postgres/init-user-db.sh:/docker-entrypoint-initdb.d/init-user-db.sh
ports:
- "${DB_PORT}:5432"
command: >
postgres
-c shared_preload_libraries=pg_ivm
healthcheck:
test: ["CMD-SHELL", "pg_isready -U ${DB_USER}"]
interval: 5s
@@ -30,8 +37,6 @@ services:
redis:
image: redis:7-alpine
restart: always
ports:
- "${REDIS_PORT}:6379"
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 5s
@@ -52,7 +57,8 @@ services:
# Docker socket mount: lets the Django server run local docker commands (for local worker task dispatch)
- /var/run/docker.sock:/var/run/docker.sock
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8888/api/"]
# Use the dedicated health check endpoint (no authentication required)
test: ["CMD", "curl", "-f", "http://localhost:8888/api/health/"]
interval: 30s
timeout: 10s
retries: 3
@@ -69,9 +75,10 @@ services:
restart: always
environment:
- SERVER_URL=http://server:8888
- WORKER_NAME=本地节点
- WORKER_NAME=Local-Worker
- IS_LOCAL=true
- IMAGE_TAG=${IMAGE_TAG}
- WORKER_API_KEY=${WORKER_API_KEY}
depends_on:
server:
condition: service_healthy

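Both compose files move the container healthcheck from the authenticated `/api/` root to a dedicated `/api/health/` endpoint. The backing view is not part of this diff; a minimal sketch of what an unauthenticated Django health view could look like (all names illustrative):

```python
# views.py (illustrative sketch, not the repository's actual view)
from django.http import JsonResponse

def health(request):
    # No authentication: the endpoint only confirms the process is
    # serving requests, which is all `curl -f .../api/health/` needs.
    return JsonResponse({"status": "ok"})

# urls.py (illustrative): path("api/health/", health)
```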
View File

@@ -0,0 +1,19 @@
FROM postgres:15
# Install build dependencies
RUN apt-get update && apt-get install -y \
build-essential \
postgresql-server-dev-15 \
git \
&& rm -rf /var/lib/apt/lists/*
# Build and install pg_ivm
RUN git clone https://github.com/sraoss/pg_ivm.git /tmp/pg_ivm \
&& cd /tmp/pg_ivm \
&& make \
&& make install \
&& rm -rf /tmp/pg_ivm
# Configure shared_preload_libraries
# Note: this setting is applied when the container starts
RUN echo "shared_preload_libraries = 'pg_ivm'" >> /usr/share/postgresql/postgresql.conf.sample

View File

@@ -9,3 +9,12 @@ psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "postgres" <<-EOSQL
GRANT ALL PRIVILEGES ON DATABASE xingrin TO "$POSTGRES_USER";
GRANT ALL PRIVILEGES ON DATABASE xingrin_dev TO "$POSTGRES_USER";
EOSQL
# Enable the pg_trgm extension (for fuzzy text search indexes)
psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "xingrin" <<-EOSQL
CREATE EXTENSION IF NOT EXISTS pg_trgm;
EOSQL
psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "xingrin_dev" <<-EOSQL
CREATE EXTENSION IF NOT EXISTS pg_trgm;
EOSQL

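The `pg_trgm` extension enabled here backs trigram indexes for fuzzy text search. A minimal sketch of the usual pattern, a GIN trigram index plus an `ILIKE` query, shown against the `website.title` column that appears earlier in this diff (the index name is illustrative):

```python
import psycopg2

conn = psycopg2.connect(dbname="xingrin", user="postgres")
with conn, conn.cursor() as cur:
    # A GIN trigram index lets substring searches on large text columns
    # use the index instead of a sequential scan.
    cur.execute("""
        CREATE INDEX IF NOT EXISTS idx_website_title_trgm
        ON website USING gin (title gin_trgm_ops)
    """)
    # ILIKE '%...%' can now be served by the trigram index.
    cur.execute("SELECT id, title FROM website WHERE title ILIKE %s", ("%login%",))
    rows = cur.fetchall()
```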
129
docker/scripts/install-pg-ivm.sh Executable file
View File

@@ -0,0 +1,129 @@
#!/bin/bash
# One-click pg_ivm install script (for remote self-hosted PostgreSQL servers)
# Requires: PostgreSQL 13+
set -e
echo "=========================================="
echo "pg_ivm 一键安装脚本"
echo "要求: PostgreSQL 13+ 版本"
echo "=========================================="
echo ""
# Check that we are running as root
if [ "$EUID" -ne 0 ]; then
echo "错误: 请使用 sudo 运行此脚本"
exit 1
fi
# Detect PostgreSQL version
detect_pg_version() {
if command -v psql &> /dev/null; then
psql --version | grep -oP '\d+' | head -1
elif [ -n "$PG_VERSION" ]; then
echo "$PG_VERSION"
else
echo "15"
fi
}
PG_VERSION=${PG_VERSION:-$(detect_pg_version)}
# Check for PostgreSQL
if ! command -v psql &> /dev/null; then
echo "错误: 未检测到 PostgreSQL请先安装 PostgreSQL"
exit 1
fi
echo "检测到 PostgreSQL 版本: $PG_VERSION"
# 检查版本要求
if [ "$PG_VERSION" -lt 13 ]; then
echo "错误: pg_ivm 要求 PostgreSQL 13+ 版本,当前版本: $PG_VERSION"
exit 1
fi
# Install build dependencies
echo ""
echo "[1/4] 安装编译依赖..."
if command -v apt-get &> /dev/null; then
apt-get update -qq
apt-get install -y -qq build-essential postgresql-server-dev-${PG_VERSION} git
elif command -v yum &> /dev/null; then
yum install -y gcc make git postgresql${PG_VERSION}-devel
else
echo "错误: 不支持的包管理器,请手动安装编译依赖"
exit 1
fi
echo "✓ 编译依赖安装完成"
# 编译安装 pg_ivm
echo ""
echo "[2/4] 编译安装 pg_ivm..."
rm -rf /tmp/pg_ivm
git clone --quiet https://github.com/sraoss/pg_ivm.git /tmp/pg_ivm
cd /tmp/pg_ivm
make -s
make install -s
rm -rf /tmp/pg_ivm
echo "✓ pg_ivm 编译安装完成"
# 配置 shared_preload_libraries
echo ""
echo "[3/4] 配置 shared_preload_libraries..."
PG_CONF_DIRS=(
"/etc/postgresql/${PG_VERSION}/main"
"/var/lib/pgsql/${PG_VERSION}/data"
"/var/lib/postgresql/data"
)
PG_CONF_DIR=""
for dir in "${PG_CONF_DIRS[@]}"; do
if [ -d "$dir" ]; then
PG_CONF_DIR="$dir"
break
fi
done
if [ -z "$PG_CONF_DIR" ]; then
echo "警告: 未找到 PostgreSQL 配置目录,请手动配置 shared_preload_libraries"
echo "在 postgresql.conf 中添加: shared_preload_libraries = 'pg_ivm'"
else
if grep -q "shared_preload_libraries.*pg_ivm" "$PG_CONF_DIR/postgresql.conf" 2>/dev/null; then
echo "✓ shared_preload_libraries 已配置"
else
if [ -d "$PG_CONF_DIR/conf.d" ]; then
echo "shared_preload_libraries = 'pg_ivm'" > "$PG_CONF_DIR/conf.d/pg_ivm.conf"
echo "✓ 配置已写入 $PG_CONF_DIR/conf.d/pg_ivm.conf"
else
if grep -q "^shared_preload_libraries" "$PG_CONF_DIR/postgresql.conf"; then
sed -i "s/^shared_preload_libraries = '\(.*\)'/shared_preload_libraries = '\1,pg_ivm'/" "$PG_CONF_DIR/postgresql.conf"
else
echo "shared_preload_libraries = 'pg_ivm'" >> "$PG_CONF_DIR/postgresql.conf"
fi
echo "✓ 配置已写入 $PG_CONF_DIR/postgresql.conf"
fi
fi
fi
# Restart PostgreSQL
echo ""
echo "[4/4] 重启 PostgreSQL..."
if systemctl is-active --quiet postgresql; then
systemctl restart postgresql
echo "✓ PostgreSQL 已重启"
elif systemctl is-active --quiet postgresql-${PG_VERSION}; then
systemctl restart postgresql-${PG_VERSION}
echo "✓ PostgreSQL 已重启"
else
echo "警告: 无法自动重启 PostgreSQL请手动重启"
fi
echo ""
echo "=========================================="
echo "✓ pg_ivm 安装完成"
echo "=========================================="
echo ""
echo "验证安装:"
echo " psql -U postgres -c \"CREATE EXTENSION IF NOT EXISTS pg_ivm;\""
echo ""

126
docker/scripts/test-pg-ivm.sh Executable file
View File

@@ -0,0 +1,126 @@
#!/bin/bash
# pg_ivm installation verification test
# Tests the install flow of install-pg-ivm.sh in a Docker container
set -e
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
CONTAINER_NAME="pg_ivm_test_$$"
IMAGE_NAME="postgres:15"
echo "=========================================="
echo "pg_ivm 安装验证测试"
echo "=========================================="
# Cleanup function
cleanup() {
echo ""
echo "[清理] 删除测试容器..."
docker rm -f "$CONTAINER_NAME" 2>/dev/null || true
}
trap cleanup EXIT
# 1. Start a temporary container
echo ""
echo "[1/5] 启动临时 PostgreSQL 容器..."
docker run -d --name "$CONTAINER_NAME" \
-e POSTGRES_PASSWORD=test \
-e POSTGRES_USER=postgres \
-e POSTGRES_DB=testdb \
-e PG_VERSION=15 \
"$IMAGE_NAME"
echo "等待 PostgreSQL 启动..."
sleep 10
if ! docker ps | grep -q "$CONTAINER_NAME"; then
echo "错误: 容器启动失败"
exit 1
fi
# 2. Copy in and run the install script
echo ""
echo "[2/5] 执行 pg_ivm 安装脚本..."
docker cp "$SCRIPT_DIR/install-pg-ivm.sh" "$CONTAINER_NAME:/tmp/install-pg-ivm.sh"
# Simulate the install inside the container (skip the systemctl restart; restart the container manually)
docker exec "$CONTAINER_NAME" bash -c "
set -e
export PG_VERSION=15
echo 'Installing build dependencies...'
apt-get update -qq
apt-get install -y -qq build-essential postgresql-server-dev-15 git
echo 'Building and installing pg_ivm...'
rm -rf /tmp/pg_ivm
git clone --quiet https://github.com/sraoss/pg_ivm.git /tmp/pg_ivm
cd /tmp/pg_ivm
make -s
make install -s
rm -rf /tmp/pg_ivm
echo '✓ pg_ivm built and installed'
"
# 3. Configure shared_preload_libraries and restart
echo ""
echo "[3/5] 配置 shared_preload_libraries..."
docker exec "$CONTAINER_NAME" bash -c "
echo \"shared_preload_libraries = 'pg_ivm'\" >> /var/lib/postgresql/data/postgresql.conf
"
echo "重启 PostgreSQL..."
docker restart "$CONTAINER_NAME"
sleep 8
# 4. Verify the extension is available
echo ""
echo "[4/5] 验证 pg_ivm 扩展..."
docker exec "$CONTAINER_NAME" psql -U postgres -d testdb -c "CREATE EXTENSION IF NOT EXISTS pg_ivm;" > /dev/null 2>&1
EXTENSION_EXISTS=$(docker exec "$CONTAINER_NAME" psql -U postgres -d testdb -t -c "SELECT COUNT(*) FROM pg_extension WHERE extname = 'pg_ivm';")
if [ "$(echo $EXTENSION_EXISTS | tr -d ' ')" != "1" ]; then
echo "错误: pg_ivm 扩展未正确加载"
exit 1
fi
echo "✓ pg_ivm 扩展已加载"
# 5. 测试 IMMV 功能
echo ""
echo "[5/5] 测试 IMMV 增量更新功能..."
docker exec "$CONTAINER_NAME" psql -U postgres -d testdb -c "
CREATE TABLE test_table (id SERIAL PRIMARY KEY, name TEXT, value INTEGER);
SELECT pgivm.create_immv('test_immv', 'SELECT id, name, value FROM test_table');
INSERT INTO test_table (name, value) VALUES ('test1', 100);
INSERT INTO test_table (name, value) VALUES ('test2', 200);
" > /dev/null 2>&1
IMMV_COUNT=$(docker exec "$CONTAINER_NAME" psql -U postgres -d testdb -t -c "SELECT COUNT(*) FROM test_immv;")
if [ "$(echo $IMMV_COUNT | tr -d ' ')" != "2" ]; then
echo "错误: IMMV 增量更新失败,期望 2 行,实际 $(echo $IMMV_COUNT | tr -d ' ')"
exit 1
fi
echo "✓ IMMV 增量更新正常 (2 行数据)"
# 测试更新
docker exec "$CONTAINER_NAME" psql -U postgres -d testdb -c "UPDATE test_table SET value = 150 WHERE name = 'test1';" > /dev/null 2>&1
UPDATED_VALUE=$(docker exec "$CONTAINER_NAME" psql -U postgres -d testdb -t -c "SELECT value FROM test_immv WHERE name = 'test1';")
if [ "$(echo $UPDATED_VALUE | tr -d ' ')" != "150" ]; then
echo "错误: IMMV 更新同步失败"
exit 1
fi
echo "✓ IMMV 更新同步正常"
# 测试删除
docker exec "$CONTAINER_NAME" psql -U postgres -d testdb -c "DELETE FROM test_table WHERE name = 'test2';" > /dev/null 2>&1
IMMV_COUNT_AFTER=$(docker exec "$CONTAINER_NAME" psql -U postgres -d testdb -t -c "SELECT COUNT(*) FROM test_immv;")
if [ "$(echo $IMMV_COUNT_AFTER | tr -d ' ')" != "1" ]; then
echo "错误: IMMV 删除同步失败"
exit 1
fi
echo "✓ IMMV 删除同步正常"
echo ""
echo "=========================================="
echo "✓ 所有测试通过"
echo "=========================================="
echo ""
echo "pg_ivm 安装验证成功,可以继续构建自定义 PostgreSQL 镜像"

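The same smoke test can be driven from application code. A minimal sketch with psycopg2, using the `pgivm.create_immv` call exercised by the script above (connection parameters match the test container; adjust host and port as needed):

```python
import psycopg2

conn = psycopg2.connect(dbname="testdb", user="postgres", password="test")
conn.autocommit = True
with conn.cursor() as cur:
    cur.execute("CREATE EXTENSION IF NOT EXISTS pg_ivm")
    cur.execute("CREATE TABLE IF NOT EXISTS t (id serial PRIMARY KEY, v int)")
    # Create an incrementally maintained materialized view over t.
    cur.execute("SELECT pgivm.create_immv('t_immv', 'SELECT id, v FROM t')")
    cur.execute("INSERT INTO t (v) VALUES (1), (2)")
    # Unlike a plain materialized view, no REFRESH is needed: pg_ivm
    # applies the delta as part of the writing transaction.
    cur.execute("SELECT count(*) FROM t_immv")
    assert cur.fetchone()[0] == 2
```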
View File

@@ -1,4 +1,4 @@
FROM python:3.10-slim
FROM python:3.10-slim-bookworm
WORKDIR /app
@@ -11,7 +11,16 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/*
# Install the Docker CLI (for local worker task dispatch)
RUN curl -fsSL https://get.docker.com | sh
# Install only docker-ce-cli to avoid pulling in the full Docker engine
RUN apt-get update && \
apt-get install -y ca-certificates gnupg && \
install -m 0755 -d /etc/apt/keyrings && \
curl -fsSL https://download.docker.com/linux/debian/gpg | gpg --dearmor -o /etc/apt/keyrings/docker.gpg && \
chmod a+r /etc/apt/keyrings/docker.gpg && \
echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/debian bookworm stable" > /etc/apt/sources.list.d/docker.list && \
apt-get update && \
apt-get install -y docker-ce-cli && \
rm -rf /var/lib/apt/lists/*
# Install uv (an extremely fast Python package manager)
RUN pip install uv

View File

@@ -15,10 +15,12 @@ NC='\033[0m'
# Parse arguments
WITH_FRONTEND=true
DEV_MODE=false
QUIET_MODE=false
for arg in "$@"; do
case $arg in
--no-frontend) WITH_FRONTEND=false ;;
--dev) DEV_MODE=true ;;
--quiet) QUIET_MODE=true ;;
esac
done
@@ -155,6 +157,11 @@ echo -e "${GREEN}[OK]${NC} 服务已启动"
# 数据初始化
./scripts/init-data.sh
# In quiet mode, don't print results (the caller prints them)
if [ "$QUIET_MODE" = true ]; then
exit 0
fi
# Get the access address
PUBLIC_HOST=$(grep "^PUBLIC_HOST=" .env 2>/dev/null | cut -d= -f2)
if [ -n "$PUBLIC_HOST" ] && [ "$PUBLIC_HOST" != "server" ]; then

View File

@@ -1,5 +1,6 @@
# Stage 1: build the tools with the official Go image
FROM golang:1.24 AS go-builder
# Pin the digest so upstream updates don't invalidate the cache
FROM golang:1.24@sha256:7e050c14ae9ca5ae56408a288336545b18632f51402ab0ec8e7be0e649a1fc42 AS go-builder
ENV GOPROXY=https://goproxy.cn,direct
# Naabu needs CGO and libpcap
@@ -36,7 +37,8 @@ RUN CGO_ENABLED=0 go install -v github.com/owasp-amass/amass/v5/cmd/amass@main
RUN go install github.com/hahwul/dalfox/v2@latest
# Stage 2: runtime image
FROM ubuntu:24.04
# Pin the digest so upstream updates don't invalidate the cache
FROM ubuntu:24.04@sha256:4fdf0125919d24aec972544669dcd7d6a26a8ad7e6561c73d5549bd6db258ac2
# Avoid interactive prompts
ENV DEBIAN_FRONTEND=noninteractive

View File

@@ -13,21 +13,16 @@
- **Privileges**: sudo administrator access
- **Port requirements**: the following ports must be open
- `8083` - HTTPS access (primary access port)
- `5432` - PostgreSQL database (if using the local database)
- `6379` - Redis cache service
- `5432` - PostgreSQL database (if using the local database with remote workers)
- The backend API listens on 8888 only inside the container and is reverse-proxied by nginx to 8083; port 8888 need not be opened to the public internet
- Redis is used only on the internal Docker network and needs no external exposure
## One-Click Installation
### 1. Download the Project
```bash
# Option 1: Git clone (recommended)
git clone https://github.com/your-username/xingrin.git
cd xingrin
# Option 2: download the ZIP
wget https://github.com/your-username/xingrin/archive/main.zip
unzip main.zip && cd xingrin-main
```
### 2. Run the Installer
@@ -60,8 +55,7 @@ sudo ./install.sh --no-frontend
#### Ports That Must Be Open
```
8083 - HTTPS access (primary access port)
5432 - PostgreSQL (if using the local database)
6379 - Redis cache
5432 - PostgreSQL (if using the local database with remote workers)
```
#### Recommended Setup
@@ -110,9 +104,6 @@ graph TD
# Restart services
./restart.sh
# Update the system
./update.sh
# Uninstall the system
./uninstall.sh
```
@@ -234,11 +225,6 @@ docker logs --tail 100 xingrin-agent
tail -f /opt/xingrin/logs/*.log
```
### 3. Regular Updates
```bash
# Run system updates regularly
./update.sh
```
## Next Steps

Binary file not shown (image updated: 95 KiB before, 112 KiB after).

View File

@@ -64,6 +64,16 @@ export default function ScanHistoryLayout({
<div className="flex items-center justify-between px-4 lg:px-6">
<Tabs value={getActiveTab()} className="w-full">
<TabsList>
<TabsTrigger value="websites" asChild>
<Link href={tabPaths.websites} className="flex items-center gap-0.5">
Websites
{counts.websites > 0 && (
<Badge variant="secondary" className="ml-1.5 h-5 min-w-5 rounded-full px-1.5 text-xs">
{counts.websites}
</Badge>
)}
</Link>
</TabsTrigger>
<TabsTrigger value="subdomain" asChild>
<Link href={tabPaths.subdomain} className="flex items-center gap-0.5">
Subdomains
@@ -74,12 +84,12 @@ export default function ScanHistoryLayout({
)}
</Link>
</TabsTrigger>
<TabsTrigger value="websites" asChild>
<Link href={tabPaths.websites} className="flex items-center gap-0.5">
Websites
{counts.websites > 0 && (
<TabsTrigger value="ip-addresses" asChild>
<Link href={tabPaths["ip-addresses"]} className="flex items-center gap-0.5">
IP Addresses
{counts["ip-addresses"] > 0 && (
<Badge variant="secondary" className="ml-1.5 h-5 min-w-5 rounded-full px-1.5 text-xs">
{counts.websites}
{counts["ip-addresses"]}
</Badge>
)}
</Link>
@@ -104,16 +114,6 @@ export default function ScanHistoryLayout({
)}
</Link>
</TabsTrigger>
<TabsTrigger value="ip-addresses" asChild>
<Link href={tabPaths["ip-addresses"]} className="flex items-center gap-0.5">
IP Addresses
{counts["ip-addresses"] > 0 && (
<Badge variant="secondary" className="ml-1.5 h-5 min-w-5 rounded-full px-1.5 text-xs">
{counts["ip-addresses"]}
</Badge>
)}
</Link>
</TabsTrigger>
<TabsTrigger value="vulnerabilities" asChild>
<Link href={tabPaths.vulnerabilities} className="flex items-center gap-0.5">
Vulnerabilities

View File

@@ -8,7 +8,7 @@ export default function ScanHistoryDetailPage() {
const router = useRouter()
useEffect(() => {
router.replace(`/scan/history/${id}/subdomain/`)
router.replace(`/scan/history/${id}/websites/`)
}, [id, router])
return null

View File

@@ -0,0 +1,5 @@
import { SearchPage } from "@/components/search"
export default function Search() {
return <SearchPage />
}

View File

@@ -5,15 +5,15 @@ import { useEffect } from "react"
/**
* Target detail page (compatible with old routes)
* Automatically redirects to subdomain page
* Automatically redirects to websites page
*/
export default function TargetDetailsPage() {
const { id } = useParams<{ id: string }>()
const router = useRouter()
useEffect(() => {
// Redirect to subdomain page
router.replace(`/target/${id}/subdomain/`)
// Redirect to websites page
router.replace(`/target/${id}/websites/`)
}, [id, router])
return null

View File

@@ -138,6 +138,16 @@ export default function TargetLayout({
<div className="flex items-center justify-between px-4 lg:px-6">
<Tabs value={getActiveTab()} className="w-full">
<TabsList>
<TabsTrigger value="websites" asChild>
<Link href={tabPaths.websites} className="flex items-center gap-0.5">
Websites
{counts.websites > 0 && (
<Badge variant="secondary" className="ml-1.5 h-5 min-w-5 rounded-full px-1.5 text-xs">
{counts.websites}
</Badge>
)}
</Link>
</TabsTrigger>
<TabsTrigger value="subdomain" asChild>
<Link href={tabPaths.subdomain} className="flex items-center gap-0.5">
Subdomains
@@ -148,12 +158,12 @@ export default function TargetLayout({
)}
</Link>
</TabsTrigger>
<TabsTrigger value="websites" asChild>
<Link href={tabPaths.websites} className="flex items-center gap-0.5">
Websites
{counts.websites > 0 && (
<TabsTrigger value="ip-addresses" asChild>
<Link href={tabPaths["ip-addresses"]} className="flex items-center gap-0.5">
IP Addresses
{counts["ip-addresses"] > 0 && (
<Badge variant="secondary" className="ml-1.5 h-5 min-w-5 rounded-full px-1.5 text-xs">
{counts.websites}
{counts["ip-addresses"]}
</Badge>
)}
</Link>
@@ -178,16 +188,6 @@ export default function TargetLayout({
)}
</Link>
</TabsTrigger>
<TabsTrigger value="ip-addresses" asChild>
<Link href={tabPaths["ip-addresses"]} className="flex items-center gap-0.5">
IP Addresses
{counts["ip-addresses"] > 0 && (
<Badge variant="secondary" className="ml-1.5 h-5 min-w-5 rounded-full px-1.5 text-xs">
{counts["ip-addresses"]}
</Badge>
)}
</Link>
</TabsTrigger>
<TabsTrigger value="vulnerabilities" asChild>
<Link href={tabPaths.vulnerabilities} className="flex items-center gap-0.5">
Vulnerabilities

View File

@@ -5,15 +5,15 @@ import { useEffect } from "react"
/**
* Target detail default page
* Automatically redirects to subdomain page
* Automatically redirects to websites page
*/
export default function TargetDetailPage() {
const { id } = useParams<{ id: string }>()
const router = useRouter()
useEffect(() => {
// Redirect to subdomain page
router.replace(`/target/${id}/subdomain/`)
// Redirect to websites page
router.replace(`/target/${id}/websites/`)
}, [id, router])
return null

View File

@@ -44,7 +44,6 @@
--font-sans: 'Noto Sans SC', system-ui, -apple-system, PingFang SC, sans-serif;
--font-mono: 'JetBrains Mono', 'Fira Code', Consolas, monospace;
--font-serif: Georgia, 'Noto Serif SC', serif;
--radius: 0.625rem;
--tracking-tighter: calc(var(--tracking-normal) - 0.05em);
--tracking-tight: calc(var(--tracking-normal) - 0.025em);
--tracking-wide: calc(var(--tracking-normal) + 0.025em);

View File

@@ -15,6 +15,8 @@ import {
IconServer, // Server icon
IconTerminal2, // Terminal icon
IconBug, // Vulnerability icon
IconMessageReport, // Feedback icon
IconSearch, // Search icon
} from "@tabler/icons-react"
// Import internationalization hook
import { useTranslations } from 'next-intl'
@@ -75,6 +77,11 @@ export function AppSidebar({ ...props }: React.ComponentProps<typeof Sidebar>) {
url: "/dashboard/",
icon: IconDashboard,
},
{
title: t('search'),
url: "/search/",
icon: IconSearch,
},
{
title: t('organization'),
url: "/organization/",
@@ -132,6 +139,11 @@ export function AppSidebar({ ...props }: React.ComponentProps<typeof Sidebar>) {
// Secondary navigation menu items
const navSecondary = [
{
title: t('feedback'),
url: "https://github.com/yyhuni/xingrin/issues",
icon: IconMessageReport,
},
{
title: t('help'),
url: "https://github.com/yyhuni/xingrin",

View File

@@ -67,6 +67,45 @@ const DEFAULT_FIELDS: FilterField[] = [
PREDEFINED_FIELDS.host,
]
// History storage key
const FILTER_HISTORY_KEY = 'smart_filter_history'
const MAX_HISTORY_PER_FIELD = 10
// Get history values for a field
function getFieldHistory(field: string): string[] {
if (typeof window === 'undefined') return []
try {
const history = JSON.parse(localStorage.getItem(FILTER_HISTORY_KEY) || '{}')
return history[field] || []
} catch {
return []
}
}
// Save a value to field history
function saveFieldHistory(field: string, value: string) {
if (typeof window === 'undefined' || !value.trim()) return
try {
const history = JSON.parse(localStorage.getItem(FILTER_HISTORY_KEY) || '{}')
const fieldHistory = (history[field] || []).filter((v: string) => v !== value)
fieldHistory.unshift(value)
history[field] = fieldHistory.slice(0, MAX_HISTORY_PER_FIELD)
localStorage.setItem(FILTER_HISTORY_KEY, JSON.stringify(history))
} catch {
// ignore
}
}
// Extract field-value pairs from query and save to history
function saveQueryHistory(query: string) {
const regex = /(\w+)(==|!=|=)"([^"]+)"/g
let match
while ((match = regex.exec(query)) !== null) {
const [, field, , value] = match
saveFieldHistory(field, value)
}
}
// Parse filter expression (FOFA style)
interface ParsedFilter {
field: string
@@ -115,10 +154,114 @@ export function SmartFilterInput({
const [open, setOpen] = React.useState(false)
const [inputValue, setInputValue] = React.useState(value ?? "")
const inputRef = React.useRef<HTMLInputElement>(null)
const ghostRef = React.useRef<HTMLSpanElement>(null)
const listRef = React.useRef<HTMLDivElement>(null)
const savedScrollTop = React.useRef<number | null>(null)
const hasInitialized = React.useRef(false)
// Calculate ghost text suggestion
const ghostText = React.useMemo(() => {
if (!inputValue) return ""
// Get the last word/token being typed
const lastSpaceIndex = inputValue.lastIndexOf(' ')
const currentToken = lastSpaceIndex === -1 ? inputValue : inputValue.slice(lastSpaceIndex + 1)
const lowerToken = currentToken.toLowerCase()
// If empty token after space, check if previous expression is complete
if (!currentToken && inputValue.trim()) {
// Check if last expression is complete (ends with ")
if (inputValue.trimEnd().endsWith('"')) {
return '&& '
}
return ""
}
if (!currentToken) return ""
// Priority 1: Field name completion (no = in token)
if (!currentToken.includes('=') && !currentToken.includes('!')) {
// Find matching field first
const matchingField = fields.find(f =>
f.key.toLowerCase().startsWith(lowerToken) &&
f.key.toLowerCase() !== lowerToken
)
if (matchingField) {
return matchingField.key.slice(currentToken.length) + '="'
}
// If exact match of field name, suggest ="
const exactField = fields.find(f => f.key.toLowerCase() === lowerToken)
if (exactField) {
return '="'
}
// Priority 2: Logical operators (only if no field matches)
if ('&&'.startsWith(currentToken) && currentToken.startsWith('&')) {
return '&&'.slice(currentToken.length) + ' '
}
if ('||'.startsWith(currentToken) && currentToken.startsWith('|')) {
return '||'.slice(currentToken.length) + ' '
}
// 'and' / 'or' only if no field name starts with these
if (!matchingField) {
if ('and'.startsWith(lowerToken) && lowerToken.length > 0 && !fields.some(f => f.key.toLowerCase().startsWith(lowerToken))) {
return 'and'.slice(lowerToken.length) + ' '
}
if ('or'.startsWith(lowerToken) && lowerToken.length > 0 && !fields.some(f => f.key.toLowerCase().startsWith(lowerToken))) {
return 'or'.slice(lowerToken.length) + ' '
}
}
return ""
}
// Check if typing ! for != operator
if (currentToken.match(/^(\w+)!$/)) {
return '="'
}
// Check if typing = and might want ==
const singleEqMatch = currentToken.match(/^(\w+)=$/)
if (singleEqMatch) {
// Suggest " for fuzzy match (most common)
return '"'
}
// Check if typed == or != (no opening quote yet)
const doubleOpMatch = currentToken.match(/^(\w+)(==|!=)$/)
if (doubleOpMatch) {
return '"'
}
// Check if typing a value (has = and opening quote)
const eqMatch = currentToken.match(/^(\w+)(==|!=|=)"([^"]*)$/)
if (eqMatch) {
const [, field, , partialValue] = eqMatch
// Get history for this field
const history = getFieldHistory(field)
// Find matching history value
const matchingValue = history.find(v =>
v.toLowerCase().startsWith(partialValue.toLowerCase()) &&
v.toLowerCase() !== partialValue.toLowerCase()
)
if (matchingValue) {
return matchingValue.slice(partialValue.length) + '"'
}
// If value has content but no closing quote, suggest closing quote
if (partialValue.length > 0) {
return '"'
}
}
// Check if a complete expression just finished (ends with ")
if (currentToken.match(/^\w+(==|!=|=)"[^"]+"$/)) {
return ' && '
}
return ""
}, [inputValue, fields])
// Synchronize external value changes
React.useEffect(() => {
if (value !== undefined) {
@@ -189,12 +332,27 @@ export function SmartFilterInput({
// Handle search
const handleSearch = () => {
// Save query values to history
saveQueryHistory(inputValue)
onSearch?.(parsedFilters, inputValue)
setOpen(false)
}
// Accept ghost text suggestion
const acceptGhostText = () => {
if (ghostText) {
setInputValue(inputValue + ghostText)
return true
}
return false
}
// Handle keyboard events
const handleKeyDown = (e: React.KeyboardEvent) => {
if (e.key === "Tab" && ghostText) {
e.preventDefault()
acceptGhostText()
}
if (e.key === "Enter" && !e.shiftKey) {
e.preventDefault()
handleSearch()
@@ -202,6 +360,14 @@ export function SmartFilterInput({
if (e.key === "Escape") {
setOpen(false)
}
// Right arrow at end of input accepts ghost text
if (e.key === "ArrowRight" && ghostText) {
const input = inputRef.current
if (input && input.selectionStart === input.value.length) {
e.preventDefault()
acceptGhostText()
}
}
}
// Append example to input box (not overwrite), then close popover
@@ -215,36 +381,46 @@ export function SmartFilterInput({
return (
<div className={className}>
<Popover open={open} onOpenChange={setOpen} modal={false}>
<PopoverAnchor asChild>
<div className="flex items-center gap-2">
<Input
ref={inputRef}
type="text"
value={inputValue}
onChange={(e) => {
setInputValue(e.target.value)
if (!open) setOpen(true)
}}
onFocus={() => setOpen(true)}
onBlur={(e) => {
// If focus moves to inside Popover or input itself, don't close
const relatedTarget = e.relatedTarget as HTMLElement | null
if (relatedTarget?.closest('[data-radix-popper-content-wrapper]')) {
return
}
// Delay close to let CommandItem's onSelect execute first
setTimeout(() => setOpen(false), 150)
}}
onKeyDown={handleKeyDown}
placeholder={placeholder || defaultPlaceholder}
className="h-8 w-full"
/>
<Button variant="outline" size="sm" onClick={handleSearch}>
<IconSearch className="h-4 w-4" />
</Button>
</div>
</PopoverAnchor>
<div className="flex items-center gap-2">
<Popover open={open} onOpenChange={setOpen} modal={false}>
<PopoverAnchor asChild>
<div className="relative flex-1">
<Input
ref={inputRef}
type="text"
value={inputValue}
onChange={(e) => {
setInputValue(e.target.value)
if (!open) setOpen(true)
}}
onFocus={() => setOpen(true)}
onBlur={(e) => {
// If focus moves to inside Popover or input itself, don't close
const relatedTarget = e.relatedTarget as HTMLElement | null
if (relatedTarget?.closest('[data-radix-popper-content-wrapper]')) {
return
}
// Delay close to let CommandItem's onSelect execute first
setTimeout(() => setOpen(false), 150)
}}
onKeyDown={handleKeyDown}
placeholder={placeholder || defaultPlaceholder}
className="h-8 w-full font-mono text-sm"
/>
{/* Ghost text overlay */}
{ghostText && (
<div
className="absolute inset-0 flex items-center pointer-events-none overflow-hidden px-3"
aria-hidden="true"
>
<span className="font-mono text-sm">
<span className="invisible">{inputValue}</span>
<span ref={ghostRef} className="text-muted-foreground/40">{ghostText}</span>
</span>
</div>
)}
</div>
</PopoverAnchor>
<PopoverContent
className="w-[var(--radix-popover-trigger-width)] p-0"
align="start"
@@ -343,6 +519,10 @@ export function SmartFilterInput({
</Command>
</PopoverContent>
</Popover>
<Button variant="outline" size="sm" onClick={handleSearch}>
<IconSearch className="h-4 w-4" />
</Button>
</div>
</div>
)
}

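The history capture above extracts field/operator/value triples with the regex `(\w+)(==|!=|=)"([^"]+)"`. A backend consuming the same FOFA-style syntax could tokenize queries the same way; a minimal Python sketch (the actual parser in `search_service.py` is not shown in this diff):

```python
import re

# Same token shape the frontend saves to history: field, operator, quoted value.
TOKEN = re.compile(r'(\w+)(==|!=|=)"([^"]+)"')

def parse_filters(query: str) -> list[tuple[str, str, str]]:
    """Extract (field, operator, value) triples from a FOFA-style query.

    Operator semantics assumed from the component's completions above:
    '=' fuzzy match, '==' exact match, '!=' negation.
    """
    return [m.groups() for m in TOKEN.finditer(query)]

print(parse_filters('title="login" && status_code=="200"'))
# [('title', '=', 'login'), ('status_code', '==', '200')]
```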
View File

@@ -209,6 +209,7 @@ export function DashboardDataTable() {
target: t('columns.scanHistory.target'),
summary: t('columns.scanHistory.summary'),
engineName: t('columns.scanHistory.engineName'),
workerName: t('columns.scanHistory.workerName'),
createdAt: t('columns.common.createdAt'),
status: t('columns.common.status'),
progress: t('columns.scanHistory.progress'),

Some files were not shown because too many files have changed in this diff.