Compare commits

...

107 Commits

Author SHA1 Message Date
yyhuni
648a1888d4 增加企业微信 2026-01-10 10:16:01 +08:00
github-actions[bot]
2508268a45 chore: bump version to v1.5.4-dev 2026-01-10 02:10:05 +00:00
yyhuni
c60383940c 提供升级功能 2026-01-10 10:04:07 +08:00
yyhuni
47298c294a 性能优化 2026-01-10 09:44:49 +08:00
yyhuni
eba394e14e 优化:性能优化 2026-01-10 09:44:43 +08:00
yyhuni
592a1958c4 优化ui 2026-01-09 16:52:50 +08:00
yyhuni
38e2856c08 feat(scan): add provider abstraction layer for flexible target sourcing
- Add TargetProvider base class and ProviderContext for unified target acquisition
- Implement DatabaseTargetProvider for database-backed target queries
- Implement ListTargetProvider for in-memory target lists (fast scan phase 1)
- Implement SnapshotTargetProvider for snapshot table reads (fast scan phase 2+)
- Implement PipelineTargetProvider for pipeline stage outputs
- Add comprehensive provider tests covering common properties and individual providers
- Update screenshot_flow to support both legacy mode (target_id) and provider mode
- Add backward compatibility layer for existing task exports (directory, fingerprint, port, site, url_fetch, vuln scans)
- Add task backward compatibility tests
- Update .gitignore to exclude .hypothesis/ cache directory
- Update frontend ANSI log viewer component
- Update backend requirements.txt with new dependencies
- Enables flexible data source integration while maintaining backward compatibility with existing database-driven workflows
2026-01-09 09:02:09 +08:00
yyhuni
f5ad8e68e9 chore(backend): add hypothesis cache directory to gitignore
- Add .hypothesis/ directory to .gitignore to exclude Hypothesis property testing cache files
- Prevents test cache artifacts from being tracked in version control
- Improves repository cleanliness by ignoring generated test data
2026-01-08 11:58:49 +08:00
yyhuni
d5f91a236c Merge branch 'main' of https://github.com/yyhuni/xingrin 2026-01-08 10:37:32 +08:00
yyhuni
24ae8b5aeb docs: restructure README features section with capability tables
- Convert feature descriptions from nested lists to organized capability tables
- Add scanning capability table with tools and descriptions for each feature
- Add platform capability table highlighting core platform features
- Improve readability and scannability of feature documentation
- Maintain scanning pipeline architecture section for reference
- Simplify feature organization for better user comprehension
2026-01-08 10:35:56 +08:00
github-actions[bot]
86f43f94a0 chore: bump version to v1.5.3 2026-01-08 02:17:58 +00:00
yyhuni
53ba03d1e5 支持kali 2026-01-08 10:14:12 +08:00
github-actions[bot]
89c44ebd05 chore: bump version to v1.5.2 2026-01-08 00:20:11 +00:00
yyhuni
e0e3419edb chore(docker): improve worker dockerfile reliability with retry mechanism
- Add retry mechanism for apt-get install to handle ARM64 mirror sync delays
- Use --no-install-recommends flag to reduce image size and installation time
- Split apt-get update and install commands for better layer caching
- Add fallback installation logic for packages in case of initial failure
- Include explanatory comment about ARM64 ports.ubuntu.com potential delays
- Maintain compatibility with both ARM64 and AMD64 architectures
2026-01-08 08:14:24 +08:00
yyhuni
52ee4684a7 chore(docker): add apt-get update before playwright dependencies
- Add apt-get update before installing playwright chromium dependencies
- Ensures package lists are refreshed before installing system dependencies
- Prevents potential package installation failures in Docker builds
2026-01-08 08:09:21 +08:00
yyhuni
ce8cebf11d chore(frontend): update pnpm-lock.yaml with @radix-ui/react-hover-card
- Add @radix-ui/react-hover-card@1.1.15 package resolution entry
- Add package snapshot with all required dependencies and peer dependencies
- Update lock file to reflect new hover card component dependency
- Ensures consistent dependency management across the frontend environment
2026-01-08 07:57:58 +08:00
yyhuni
ec006d8f54 chore(frontend): add @radix-ui/react-hover-card dependency
- Add @radix-ui/react-hover-card v1.1.6 to project dependencies
- Enables hover card UI component functionality for improved user interactions
- Maintains consistency with existing Radix UI component library usage
2026-01-08 07:56:07 +08:00
yyhuni
48976a570f docs: update README with screenshot feature and sponsorship info
- Add screenshot feature documentation to features section with Playwright details
- Include WebP format compression benefits and multi-source URL support
- Add screenshot stage to scan flow architecture diagram with styling
- Add fingerprint library table with counts for public distribution
- Add sponsorship section with WeChat Pay and Alipay QR codes
- Add sponsor appreciation table
- Update frontend dependencies with @radix-ui/react-visually-hidden package
- Remove redundant installation speed note from mirror parameter documentation
- Clean up demo link formatting in online demo section
2026-01-08 07:54:31 +08:00
yyhuni
5da7229873 feat(scan-overview): add yaml configuration tab and improve logs layout
- Add yaml_configuration field to ScanHistorySerializer for backend exposure
- Implement tabbed interface with Logs and Configuration tabs in scan overview
- Add YamlEditor component to display scan configuration in read-only mode
- Refactor logs section to show status bar only when logs tab is active
- Move auto-refresh toggle to logs tab header for better UX
- Add padding to stage progress items for improved visual alignment
- Add internationalization strings for new UI elements (en and zh)
- Update ScanHistory type to include yamlConfiguration field
- Improve tab switching state management with activeTab state
2026-01-08 07:31:54 +08:00
yyhuni
8bb737a9fa feat(scan-history): add auto-refresh toggle and improve layout
- Add auto-refresh toggle switch to scan logs section for manual control
- Implement flexible polling based on auto-refresh state and scan status
- Restructure scan overview layout to use left-right split (stages + logs)
- Move stage progress to left column with vulnerability statistics
- Implement scrollable logs panel on right side with proper height constraints
- Update component imports to use Switch and Label instead of Button
- Add full-height flex layout to parent containers for proper scrolling
- Refactor grid layout from 2-column to fixed-width left + flexible right
- Update translations for new UI elements and labels
- Improve responsive design with better flex constraints and min-height handling
2026-01-07 23:30:27 +08:00
yyhuni
2d018d33f3 优化扫描历史详细页面 2026-01-07 22:44:46 +08:00
yyhuni
0c07cc8497 refactor(scan-flows): simplify logger calls by splitting multiline strings
- Split multiline logger.info() calls into separate single-line calls in initiate_scan_flow.py
- Improved log readability by removing string concatenation with newlines and separators
- Refactored 6 logger.info() calls across sequential, parallel, and completion stages
- Updated subdomain_discovery_flow.py to use consistent single-line logger pattern
- Maintains same log output while improving code maintainability and consistency
2026-01-07 22:21:50 +08:00
yyhuni
225b039985 style(system-logs): adjust log level filter dropdown width
- Increase SelectTrigger width from 100px to 130px for better label visibility
- Improve UI consistency in log toolbar component
- Prevent text truncation in log level filter dropdown
2026-01-07 22:17:07 +08:00
yyhuni
d1624627bc 一级tab加图标 2026-01-07 22:14:42 +08:00
yyhuni
7bb15e4ae4 增加:截图功能 2026-01-07 22:10:51 +08:00
github-actions[bot]
8e8cc29669 chore: bump version to v1.4.1 2026-01-07 01:33:29 +00:00
yyhuni
d6d5338acb 增加资产删除功能 2026-01-07 09:29:31 +08:00
yyhuni
c521bdb511 重构:回退逻辑 2026-01-07 08:45:27 +08:00
yyhuni
abf2d95f6f feat(targets): increase max batch size for target creation from 1000 to 5000
- Update MAX_BATCH_SIZE constant in BatchCreateTargetSerializer from 1000 to 5000
- Increase batch creation limit to support larger bulk operations
- Update documentation comment to reflect new limit
- Allows users to create up to 5000 targets in a single batch operation
2026-01-06 20:39:31 +08:00
github-actions[bot]
ab58cf0d85 chore: bump version to v1.4.0 2026-01-06 09:31:29 +00:00
yyhuni
fb0111adf2 Merge branch 'dev' 2026-01-06 17:27:35 +08:00
yyhuni
161ee9a2b1 Merge branch 'dev' 2026-01-06 17:27:16 +08:00
yyhuni
0cf75585d5 docs: 添加黑名单过滤功能说明到 README 2026-01-06 17:25:31 +08:00
yyhuni
1d8d5f51d9 feat(blacklist): add mock data and service integration for blacklist management
- Create new blacklist mock data module with global and target-specific patterns
- Add mock functions for getting and updating global blacklist rules
- Add mock functions for getting and updating target-specific blacklist rules
- Integrate mock blacklist endpoints into global-blacklist.service.ts
- Integrate mock blacklist endpoints into target.service.ts
- Export blacklist mock functions from main mock index
- Enable testing of blacklist management UI without backend API
2026-01-06 17:08:51 +08:00
github-actions[bot]
3f8de07c8c chore: bump version to v1.4.0-dev 2026-01-06 09:02:31 +00:00
yyhuni
cd5c2b9f11 chore(notifications): remove test notification endpoint
- Remove test notification route from URL patterns
- Delete notifications_test view function and associated logic
- Clean up unused test endpoint that was used for development purposes
- Simplify notification API surface by removing non-production code
2026-01-06 16:57:29 +08:00
yyhuni
54786c22dd feat(scan): increase max batch size for quick scan operations
- Increase MAX_BATCH_SIZE from 1000 to 5000 in QuickScanSerializer
- Allows processing of larger batch scans in a single operation
- Improves throughput for bulk scanning workflows
2026-01-06 16:55:28 +08:00
yyhuni
d468f975ab feat(scan): implement fallback chain for endpoint URL export
- Add fallback chain for URL data sources: Endpoint → WebSite → default generation
- Import WebSite model and Path utility for enhanced file handling
- Create output directory automatically if it doesn't exist
- Add "source" field to return value indicating data origin (endpoint/website/default)
- Update docstring to document the three-tier fallback priority system
- Implement sequential export attempts with logging at each fallback stage
- Improve error handling and data source transparency for endpoint exports
2026-01-06 16:30:42 +08:00
yyhuni
a85a12b8ad feat(asset): create incremental materialized views for asset search
- Add pg_ivm extension for incremental materialized view maintenance
- Create asset_search_view for Website model with optimized columns for full-text search
- Create endpoint_search_view for Endpoint model with matching search schema
- Add database indexes on host, url, title, status_code, and created_at columns for both views
- Enable high-performance asset search queries with automatic view maintenance
2026-01-06 16:22:24 +08:00
yyhuni
a8b0d97b7b feat(targets): update navigation routes and enhance add button UI
- Change target detail navigation route from `/website/` to `/overview/`
- Update TargetNameCell click handler to use new overview route
- Update TargetRowActions onView handler to use new overview route
- Add IconPlus icon import from @tabler/icons-react
- Add icon to create target button for improved visual clarity
- Improves navigation consistency and button affordance in targets table
2026-01-06 16:14:54 +08:00
yyhuni
b8504921c2 feat(fingerprints): add JSONL format support for Goby fingerprint imports
- Add support for JSONL format parsing in addition to standard JSON for Goby fingerprints
- Update GobyFingerprintService to validate both standard format (name/logic/rule) and JSONL format (product/rule)
- Implement _parse_json_content() method to handle both JSON and JSONL file formats with proper error handling
- Add JSONL parsing logic in frontend import dialog with per-line validation and error reporting
- Update file import endpoint documentation to indicate JSONL format support
- Improve error messages for encoding and parsing failures to aid user debugging
- Enable seamless import of Goby fingerprint data from multiple source formats
2026-01-06 16:10:14 +08:00
yyhuni
ecfc1822fb style(target): update vulnerability icon color to muted foreground
- Change ShieldAlert icon color from red-500 to muted-foreground in target overview
- Improves visual consistency with design system color palette
- Reduces visual emphasis on vulnerability section for better UI balance
2026-01-06 12:01:59 +08:00
github-actions[bot]
81633642e6 chore: bump version to v1.3.16-dev 2026-01-06 03:55:16 +00:00
yyhuni
d1ec9b7f27 feat(settings): add global blacklist management page and UI integration
- Add new global blacklist settings page with pattern management interface
- Create useGlobalBlacklist and useUpdateGlobalBlacklist React Query hooks for data fetching and mutations
- Implement global-blacklist.service.ts with API integration for blacklist operations
- Add Global Blacklist navigation item to app sidebar with Ban icon
- Add internationalization support for blacklist UI with English and Chinese translations
- Include pattern matching rules documentation (domain wildcards, keywords, IP addresses, CIDR ranges)
- Add loading states, error handling, and success/error toast notifications
- Implement textarea input with change tracking and save button state management
2026-01-06 11:50:31 +08:00
yyhuni
2a3d9b4446 feat(target): add initiate scan button and improve overview layout
- Add "Initiate Scan" button to target overview header with Play icon
- Implement InitiateScanDialog component integration for quick scan initiation
- Improve scheduled scans card layout with flexbox for better vertical spacing
- Reduce displayed scheduled scans from 3 to 2 items for better UI balance
- Enhance vulnerability statistics card styling with proper flex layout
- Add state management for scan dialog open/close functionality
- Update i18n translations (en.json, zh.json) with "initiateScan" label
- Refactor target info section to accommodate new action button with justify-between layout
- Improve empty state centering in scheduled scans card using flex layout
2026-01-06 11:10:47 +08:00
yyhuni
9b63203b5a refactor(migrations,frontend,backend): reorganize app structure and enhance target management UI
- Consolidate common migrations into dedicated common app module
- Remove asset search materialized view migration (0002) and simplify migration structure
- Reorganize target detail page with new overview and settings sub-routes
- Add target overview component displaying key asset information
- Add target settings component for configuration management
- Enhance scan history UI with improved data table and column definitions
- Update scheduled scan dialog with better form handling
- Refactor target service with improved API integration
- Update scan hooks (use-scans, use-scheduled-scans) with better state management
- Add internationalization strings for new target management features
- Update Docker initialization and startup scripts for new app structure
- Bump Django to 5.2.7 and update dependencies in requirements.txt
- Add WeChat group contact information to README
- Improve UI tabs component with better accessibility and styling
2026-01-06 10:42:38 +08:00
yyhuni
6ff86e14ec Update README.md 2026-01-06 09:59:55 +08:00
yyhuni
4c1282e9bb 完成黑名单后端逻辑 2026-01-05 23:26:50 +08:00
yyhuni
ba3a9b709d feat(system-logs): enhance ANSI log viewer with log level colorization
- Add LOG_LEVEL_COLORS configuration mapping for DEBUG, INFO, WARNING, WARN, ERROR, and CRITICAL levels
- Implement hasAnsiCodes() function to detect presence of ANSI escape sequences in log content
- Add colorizeLogContent() function to parse plain text logs and apply color styling based on log levels
- Support dual-mode log parsing: ANSI color codes and plain text log level detection
- Rename converter to ansiConverter for clarity and consistency
- Change newline handling from true to false for manual line break control
- Apply color-coded styling to timestamps (gray), log levels (level-specific colors), and messages
- Add bold font-weight styling for CRITICAL level logs for better visibility
2026-01-05 16:27:31 +08:00
github-actions[bot]
283b28b46a chore: bump version to v1.3.15-dev 2026-01-05 02:05:29 +00:00
yyhuni
1269e5a314 refactor(scan): reorganize models and serializers into modular structure
- Split monolithic models.py into separate model files (scan_models.py, scan_log_model.py, scheduled_scan_model.py, subfinder_provider_settings_model.py)
- Split monolithic serializers.py into separate serializer files with dedicated modules for each domain
- Add SubfinderProviderSettings model to store API key configurations for subfinder data sources
- Create SubfinderProviderConfigService to generate provider configuration files dynamically
- Add subfinder_provider_settings views and serializers for API key management
- Update subdomain_discovery_flow to support provider configuration file generation and passing to subfinder
- Update command templates to use provider config file and remove recursive flag for better source coverage
- Add frontend settings page for managing API keys at /settings/api-keys
- Add frontend hooks and services for API key settings management
- Update sidebar navigation to include API keys settings link
- Add internationalization support for new API keys settings UI (English and Chinese)
- Improves code maintainability by organizing related models and serializers into logical modules
2026-01-05 10:00:19 +08:00
yyhuni
802e967906 docs: add online demo link to README
- Add new "🌐 在线 Demo" section with live demo URL
- Include disclaimer note that demo is UI-only without backend database
- Improve documentation to help users quickly access and test the application
2026-01-04 19:19:33 +08:00
github-actions[bot]
e446326416 chore: bump version to v1.3.14 2026-01-04 11:02:14 +00:00
yyhuni
e0abb3ce7b Merge branch 'dev' 2026-01-04 18:57:49 +08:00
yyhuni
d418baaf79 feat(mock,scan): add comprehensive mock data and improve system load management
- Add mock data files for directories, fingerprints, IP addresses, notification settings, nuclei templates, search, system logs, tools, and wordlists
- Update mock index to export new mock data modules
- Increase SCAN_LOAD_CHECK_INTERVAL from 30 to 180 seconds for better system stability
- Improve load check logging message to clarify OOM prevention strategy
- Enhance mock data infrastructure to support frontend development and testing
2026-01-04 18:52:08 +08:00
github-actions[bot]
f8da408580 chore: bump version to v1.3.13-dev 2026-01-04 10:24:10 +00:00
yyhuni
7cd4354d8f feat(scan,asset): add scan logging system and improve search view architecture
- Add user_logger utility for structured scan operation logging
- Create scan log views and API endpoints for retrieving scan execution logs
- Add scan-log-list component and use-scan-logs hook for frontend log display
- Refactor asset search views to remove ArrayField support from pg_ivm IMMV
- Update search_service.py to JOIN original tables for array field retrieval
- Add system architecture requirements (AMD64/ARM64) to README
- Update scan flow handlers to integrate logging system
- Enhance scan progress dialog with log viewer integration
- Add ANSI log viewer component for formatted log display
- Update scan service API to support log retrieval endpoints
- Migrate database schema to support new logging infrastructure
- Add internationalization strings for scan logs (en/zh)
This change improves observability of scan operations and resolves pg_ivm limitations with ArrayField types by fetching array data from original tables via JOIN operations.
2026-01-04 18:19:45 +08:00
yyhuni
6bf35a760f chore(docker): configure Prefect home directory in worker image
- Add PREFECT_HOME environment variable pointing to /app/.prefect
- Create Prefect configuration directory to prevent home directory warnings
- Update step numbering in Dockerfile comments for clarity
- Ensures Prefect can properly initialize configuration without relying on user home directory
2026-01-04 10:39:11 +08:00
github-actions[bot]
be9ecadffb chore: bump version to v1.3.12-dev 2026-01-04 01:05:00 +00:00
yyhuni
adb53c9f85 feat(asset,scan): add configurable statement timeout and improve CSV export
- Add statement_timeout_ms parameter to search_service count() and stream_search() methods for long-running exports
- Replace server-side cursors with OFFSET/LIMIT batching for better Django compatibility
- Introduce create_csv_export_response() utility function to standardize CSV export handling
- Add engine-preset-selector and scan-config-editor components for enhanced scan configuration UI
- Update YAML editor component with improved styling and functionality
- Add i18n translations for new scan configuration features in English and Chinese
- Refactor CSV export endpoints to use new utility function instead of manual StreamingHttpResponse
- Remove unused uuid import from search_service.py
- Update nginx configuration for improved performance
- Enhance search service with configurable timeout support for large dataset exports
2026-01-04 08:58:31 +08:00
yyhuni
7b7bbed634 Update README.md 2026-01-03 22:15:35 +08:00
github-actions[bot]
8dd3f0536e chore: bump version to v1.3.11-dev 2026-01-03 11:54:31 +00:00
yyhuni
8a8062a12d refactor(scan): rename merged_configuration to yaml_configuration
- Rename `merged_configuration` field to `yaml_configuration` in Scan and ScheduledScan models for clarity
- Update all references across scan repositories, services, views, and serializers
- Update database migration to reflect field name change with improved help text
- Update frontend components to use new field naming convention
- Add YAML editor component for improved configuration editing in UI
- Update engine configuration retrieval in initiate_scan_flow to use new field name
- Remove unused asset tasks __init__.py module
- Simplify README feedback section for better clarity
- Update frontend type definitions and internationalization messages for consistency
2026-01-03 19:50:20 +08:00
yyhuni
55908a2da5 fix(asset,scan): improve decorator usage and dialog layout
- Fix transaction.non_atomic_requests decorator usage in AssetSearchExportView by wrapping with method_decorator for proper class-based view compatibility
- Update scan progress dialog to use flexible width (sm:max-w-fit sm:min-w-[450px]) instead of fixed width for better responsiveness
- Refactor engine names display from single Badge to grid layout with multiple badges for improved readability when multiple engines are present
- Add proper spacing and alignment adjustments (gap-4, items-start) to accommodate multi-line engine badge display
- Add text-xs and whitespace-nowrap to engine badges for consistent styling in grid layout
2026-01-03 18:46:44 +08:00
github-actions[bot]
22a7d4f091 chore: bump version to v1.3.10-dev 2026-01-03 10:45:32 +00:00
yyhuni
f287f18134 更新锁定镜像 2026-01-03 18:33:25 +08:00
yyhuni
de27230b7a 更新构建ci 2026-01-03 18:28:57 +08:00
github-actions[bot]
15a6295189 chore: bump version to v1.3.8-dev 2026-01-03 10:24:17 +00:00
yyhuni
674acdac66 refactor(asset): move database extension initialization to migrations
- Remove pg_trgm and pg_ivm extension setup from AssetConfig.ready() method
- Move extension creation to migration 0002 using RunSQL operations
- Add pg_trgm extension creation for text search index support
- Add pg_ivm extension creation for IMMV incremental maintenance
- Generate unique cursor names in search_service to prevent concurrent request conflicts
- Add @transaction.non_atomic_requests decorator to export view for server-side cursor compatibility
- Simplify app initialization by delegating extension setup to database migrations
- Improve thread safety and concurrency handling for streaming exports
2026-01-03 18:20:27 +08:00
github-actions[bot]
c59152bedf chore: bump version to v1.3.7-dev 2026-01-03 09:56:39 +00:00
yyhuni
b4037202dc feat: use registry cache for faster builds 2026-01-03 17:35:54 +08:00
yyhuni
4b4f9862bf ci(docker): add postgres image build configuration and update image tags
- Add xingrin-postgres image build job to docker-build workflow for multi-platform support (linux/amd64,linux/arm64)
- Update docker-compose.dev.yml to use IMAGE_TAG variable with dev as default fallback
- Update docker-compose.yml to use IMAGE_TAG variable with required validation
- Replace hardcoded postgres image tag (15) with dynamic IMAGE_TAG for better version management
- Enable flexible image tagging across development and production environments
2026-01-03 17:26:34 +08:00
github-actions[bot]
1c42e4978f chore: bump version to v1.3.5-dev 2026-01-03 08:44:06 +00:00
github-actions[bot]
57bab63997 chore: bump version to v1.3.3-dev 2026-01-03 05:55:07 +00:00
github-actions[bot]
b1f0f18ac0 chore: bump version to v1.3.4-dev 2026-01-03 05:54:50 +00:00
yyhuni
ccee5471b8 docs(readme): add notification push service documentation
- Add notification push service feature to visualization interface section
- Document support for real-time WeChat Work, Telegram, and Discord message push
- Enhance feature list clarity for notification capabilities
2026-01-03 13:34:36 +08:00
yyhuni
0ccd362535 优化下载逻辑 2026-01-03 13:32:58 +08:00
yyhuni
7f2af7f7e2 feat(search): add result export functionality and pagination limit support
- Add optional limit parameter to AssetSearchService.search() method for controlling result set size
- Implement AssetSearchExportView for exporting search results as CSV files with UTF-8 BOM encoding
- Add CSV export endpoint at GET /api/assets/search/export/ with configurable MAX_EXPORT_ROWS limit (10000)
- Support both website and endpoint asset types with type-specific column mappings in CSV export
- Format array fields (tech, matched_gf_patterns) and dates appropriately in exported CSV
- Update URL routing to include new search export endpoint
- Update views __init__.py to export AssetSearchExportView
- Add CSV generation with streaming response for efficient memory usage on large exports
- Update frontend search service to support export functionality
- Add internationalization strings for export feature in en.json and zh.json
- Update smart-filter-input and search-results-table components to support export UI
- Update installation and Docker startup scripts for deployment compatibility
2026-01-03 13:22:21 +08:00
yyhuni
4bd0f9e8c1 feat(search): implement dual-view IMMV architecture for website and endpoint assets
- Add incremental materialized view (IMMV) support for both Website and Endpoint asset types using pg_ivm extension
- Create asset_search_view IMMV with optimized indexes for host, title, url, headers, body, tech, status_code, and created_at fields
- Create endpoint_search_view IMMV with identical field structure and indexing strategy for endpoint-specific searches
- Extend search_service.py to support asset type routing with VIEW_MAPPING and VALID_ASSET_TYPES configuration
- Add comprehensive field mapping and array field definitions for both asset types
- Implement dual-query execution path in search views to handle website and endpoint searches independently
- Update frontend search components to support asset type filtering and result display
- Add search results table component with improved data presentation and filtering capabilities
- Update installation scripts and Docker configuration for pg_ivm extension deployment
- Add internationalization strings for new search UI elements in English and Chinese
- Consolidate index creation and cleanup logic in migrations for maintainability
- Enable automatic incremental updates on data changes without manual view refresh
2026-01-03 12:41:20 +08:00
yyhuni
68cc996e3b refactor(asset): standardize snapshot and asset model field naming and types
- Rename `status` to `status_code` in WebsiteSnapshotDTO for consistency
- Rename `web_server` to `webserver` in WebsiteSnapshotDTO for consistency
- Make `target_id` required field in EndpointSnapshotDTO and WebsiteSnapshotDTO
- Remove optional validation check for `target_id` in EndpointSnapshotDTO
- Convert CharField to TextField for url, location, title, webserver, and content_type fields in Endpoint and EndpointSnapshot models to support longer values
- Update migration 0001_initial.py to reflect field type changes from CharField to TextField
- Update all related services and repositories to use standardized field names
- Update serializers to map renamed fields correctly
- Ensure consistent field naming across DTOs, models, and database schema
2026-01-03 09:08:25 +08:00
github-actions[bot]
f1e79d638e chore: bump version to v1.3.2-dev 2026-01-03 00:33:26 +00:00
yyhuni
d484133e4c chore(docker): optimize server dockerfile with docker-ce-cli installation
- Replace full docker.io package with lightweight docker-ce-cli to reduce image size
- Add ca-certificates and gnupg dependencies for secure package management
- Improve Docker installation process for local Worker task distribution
- Reduce unnecessary dependencies in server container build
2026-01-03 08:09:03 +08:00
yyhuni
fc977ae029 chore(docker,frontend): optimize docker installation and add auth bypass config
- Replace docker.io installation script with apt-get package manager for better reliability
- Add NEXT_PUBLIC_SKIP_AUTH environment variable to Vercel config for development
- Improve Docker build layer caching by using native package manager instead of curl script
- Simplify frontend deployment configuration for local development workflows
2026-01-03 08:08:40 +08:00
yyhuni
f328474404 feat(frontend): add comprehensive mock data infrastructure for services
- Add mock data modules for auth, engines, notifications, scheduled-scans, and workers
- Implement mock authentication data with user profiles and login/logout responses
- Create mock scan engine configurations with multiple predefined scanning profiles
- Add mock notification system with various severity levels and categories
- Implement mock scheduled scan data with cron expressions and run history
- Add mock worker node data with status and performance metrics
- Update service layer to integrate with new mock data infrastructure
- Provide helper functions for filtering and paginating mock data
- Enable frontend development and testing without backend API dependency
2026-01-03 07:59:20 +08:00
yyhuni
68e726a066 chore(docker): update base image to python 3.10-slim-bookworm
- Update Python base image from 3.10-slim to 3.10-slim-bookworm
- Ensures compatibility with latest Debian stable release
- Improves security with updated system packages and dependencies
2026-01-02 23:19:09 +08:00
yyhuni
77a6f45909 fix:搜索的楼栋统计问题 2026-01-02 23:12:55 +08:00
yyhuni
49d1f1f1bb 采用ivm增量更新方案进行搜索 2026-01-02 22:46:40 +08:00
yyhuni
db8ecb1644 feat(search): add mock data infrastructure and vulnerability detail integration
- Add comprehensive mock data configuration for all major entities (dashboard, endpoints, organizations, scans, subdomains, targets, vulnerabilities, websites)
- Implement mock service layer with centralized config for development and testing
- Add vulnerability detail dialog integration to search results with lazy loading
- Enhance search result card with vulnerability viewing capability
- Update search materialized view migration to include vulnerability name field
- Implement default host fuzzy search fallback for bare text queries without operators
- Add vulnerability data formatting in search view for consistent API response structure
- Configure Vercel deployment settings and update Next.js configuration
- Update all service layers to support mock data injection for development environment
- Extend search types with improved vulnerability data structure
- Add internationalization strings for vulnerability loading errors
- Enable rapid frontend development and testing without backend API dependency
2026-01-02 19:06:09 +08:00
yyhuni
18cc016268 feat(search): implement advanced query parser with expression syntax support
- Add SearchQueryParser class to parse complex search expressions with operators (=, ==, !=)
- Support logical operators && (AND) and || (OR) for combining multiple conditions
- Implement field mapping for frontend to database field translation
- Add support for array field searching (tech stack) with unnest and ANY operators
- Support fuzzy matching (=), exact matching (==), and negation (!=) operators
- Add proper SQL injection prevention through parameterized queries
- Refactor search service to use expression-based filtering instead of simple filters
- Update search views to integrate new query parser
- Enhance frontend search hook and service to support new expression syntax
- Update search types to reflect new query structure
- Improve search page UI to display expression syntax examples and help text
- Enable complex multi-condition searches like: host="api" && tech="nginx" || status=="200"
2026-01-02 17:46:31 +08:00
yyhuni
23bc463283 feat(search): improve technology stack filtering with fuzzy matching
- Replace exact array matching with fuzzy search using ILIKE operator
- Update tech filter to search within array elements using unnest() and EXISTS
- Support partial technology name matching (e.g., "node" matches "nodejs")
- Apply consistent fuzzy matching logic across both search methods
- Enhance user experience by allowing flexible technology stack queries
2026-01-02 17:01:24 +08:00
yyhuni
7b903b91b2 feat(search): implement comprehensive search infrastructure with materialized views and pagination
- Add asset search service with materialized view support for optimized queries
- Implement search refresh service for maintaining up-to-date search indexes
- Create database migrations for AssetStatistics, StatisticsHistory, Directory, and DirectorySnapshot models
- Add PostgreSQL GIN indexes with trigram operators for full-text search capabilities
- Implement search pagination component with configurable page size and navigation
- Add search result card component with enhanced asset display formatting
- Create search API views with filtering and sorting capabilities
- Add use-search hook for client-side search state management
- Implement search service client for API communication
- Update search types with pagination metadata and result structures
- Add English and Chinese translations for search UI components
- Enhance scheduler to support search index refresh tasks
- Refactor asset views into modular search_views and asset_views
- Update URL routing to support new search endpoints
- Improve scan flow handlers for better search index integration
2026-01-02 16:57:54 +08:00
yyhuni
b3136d51b9 搜索页面前端UI设计完成 2026-01-02 10:07:26 +08:00
github-actions[bot]
08372588a4 chore: bump version to v1.2.15 2026-01-01 15:44:15 +00:00
yyhuni
236c828041 chore(fingerprints): remove deprecated ARL fingerprint rules
- Remove obsolete fingerprint detection rules from ARL.yaml
- Clean up legacy device and service signatures that are no longer maintained
- Reduce fingerprint database size by eliminating unused detection patterns
- Improve maintainability by removing outdated vendor-specific rules
2026-01-01 22:45:08 +08:00
yyhuni
fb13bb74d8 feat(filter): add array fuzzy search support with PostgreSQL array_to_string
- Add ArrayToString custom PostgreSQL function for converting arrays to delimited strings
- Implement array field annotation in QueryBuilder to support fuzzy matching on JSON array fields
- Enhance _build_single_q to handle three operators for JSON arrays: exact match (==), negation (!=), and fuzzy search (=)
- Update target navigation routes from subdomain to website view for consistency
- Enable fuzzy search on array fields by converting them to text during query building
2026-01-01 22:41:57 +08:00
yyhuni
f076c682b6 feat(scan): add multi-engine support and config merging with enhanced indexing
- Add multi-engine support to Scan model with engine_ids and engine_names fields
- Implement config_merger utility for merging multiple engine configurations
- Add merged_configuration property to Scan model for unified config access
- Update scan creation and scheduling services to handle multiple engines
- Add pg_trgm GIN indexes to asset and snapshot models for fuzzy search on url, title, and name fields
- Update scan views and serializers to support multi-engine selection and display
- Enhance frontend components for multi-engine scan initiation and scheduling
- Update test data generation script for multi-engine scan scenarios
- Add internationalization strings for multi-engine UI elements
- Refactor scan flow to use merged configuration instead of single engine config
- Update Docker compose files with latest configuration
2026-01-01 22:35:05 +08:00
yyhuni
9eda2caceb feat(asset): add response headers and body tracking with pg_trgm indexing
- Rename body_preview to response_body across endpoint and website models for consistency
- Change response_headers from Dict to string type for efficient text indexing
- Add pg_trgm PostgreSQL extension initialization in AssetConfig for GIN index support
- Update all DTOs to reflect response_body and response_headers field changes
- Modify repositories to handle new response_body and response_headers formats
- Update serializers and views to work with string-based response headers
- Add response_headers and response_body columns to frontend endpoint and website tables
- Update command templates and scan tasks to populate response_body and response_headers
- Add database initialization script for pg_trgm extension in PostgreSQL setup
- Update frontend types and translations for new field names
- Enable efficient full-text search on response headers and body content through GIN indexes
2026-01-01 19:34:11 +08:00
yyhuni
b1c9e202dd feat(sidebar): add feedback link to secondary navigation menu
- Import IconMessageReport icon from tabler/icons-react for feedback menu item
- Add feedback navigation item linking to GitHub issues page
- Add "feedback" translation key to English messages (en.json)
- Add "feedback" translation key to Chinese messages (zh.json) as "反馈建议"
- Improves user engagement by providing direct access to issue reporting
2026-01-01 18:31:34 +08:00
yyhuni
918669bc29 style(ui): update expandable cell whitespace handling for better formatting
- Change whitespace class from `whitespace-normal` to `whitespace-pre-wrap` in expandable cell component
- Improves text rendering by preserving whitespace and line breaks in cell content
- Ensures consistent formatting display across different content types (mono, url, muted variants)
2026-01-01 16:41:47 +08:00
yyhuni
fd70b0544d docs(frontend): update Chinese translations to English for consistency
- Change "响应头" to "Response Headers" in endpoint messages
- Change "响应头" to "Response Headers" in website messages
- Maintain consistency across frontend message translations
- Improve clarity for international users by standardizing field labels
2026-01-01 16:23:03 +08:00
github-actions[bot]
0f2df7a5f3 chore: bump version to v1.2.14-dev 2026-01-01 05:13:25 +00:00
yyhuni
857ab737b5 feat(fingerprint): enhance xingfinger task with snapshot tracking and field merging
- Replace `not_found_count` with `created_count` and `snapshot_count` metrics in fingerprint detect flow
- Initialize and aggregate `snapshot_count` across tool statistics
- Refactor `parse_xingfinger_line()` to return structured dict with url, techs, server, title, status_code, and content_length
- Replace `bulk_merge_tech_field()` with `bulk_merge_website_fields()` to support merging multiple WebSite fields
- Implement smart merge strategy: arrays deduplicated, scalar fields only updated when empty/NULL
- Remove dynamic model loading via importlib in favor of direct WebSite model import
- Add WebsiteSnapshotDTO and DjangoWebsiteSnapshotRepository imports for snapshot handling
- Improve xingfinger output parsing to capture server, title, and HTTP metadata alongside technology detection
2026-01-01 12:40:49 +08:00
yyhuni
ee2d99edda feat(asset): add response headers tracking to endpoints and websites
- Add response_headers field to Endpoint and WebSite models as JSONField
- Add response_headers field to EndpointSnapshot and WebsiteSnapshot models
- Update all related DTOs to include response_headers with Dict[str, Any] type
- Add GIN indexes on response_headers fields for optimized JSON queries
- Update endpoint and website repositories to handle response_headers data
- Update serializers to include response_headers in API responses
- Update frontend components to display response headers in detail views
- Add response_headers to fingerprint detection and site scan tasks
- Update command templates and engine config to support header extraction
- Add internationalization strings for response headers in en.json and zh.json
- Update TypeScript types for endpoint and website to include response_headers
- Enhance scan history and target detail pages to show response header information
2026-01-01 12:25:22 +08:00
github-actions[bot]
db6ce16aca chore: bump version to v1.2.13-dev 2026-01-01 02:24:08 +00:00
yyhuni
ab800eca06 feat(frontend): reorder navigation tabs for improved UX
- Move "Websites" tab to first position in scan history and target layouts
- Reposition "IP Addresses" tab before "Ports" for better logical flow
- Maintain consistent tab ordering across both scan history and target pages
- Improve navigation hierarchy by placing primary discovery results first
2026-01-01 09:47:30 +08:00
yyhuni
e8e5572339 perf(asset): add GIN indexes for tech array fields and improve query parser
- Add GinIndex for tech array field in Endpoint model to optimize __contains queries
- Add GinIndex for tech array field in WebSite model to optimize __contains queries
- Import GinIndex from django.contrib.postgres.indexes
- Refactor QueryParser to protect quoted filter values during tokenization
- Implement placeholder-based filter extraction to preserve spaces within quoted values
- Replace filter tokens with placeholders before logical operator normalization
- Restore original filter conditions from placeholders during parsing
- Fix spacing in comments for consistency (add space after "从")
- Improves query performance for technology stack filtering on large datasets
2026-01-01 08:58:03 +08:00
github-actions[bot]
d48d4bbcad chore: bump version to v1.2.12-dev 2025-12-31 16:01:48 +00:00
306 changed files with 24744 additions and 16733 deletions

View File

@@ -19,7 +19,8 @@ permissions:
contents: write
jobs:
build:
# AMD64 构建(原生 x64 runner
build-amd64:
runs-on: ubuntu-latest
strategy:
matrix:
@@ -27,39 +28,30 @@ jobs:
- image: xingrin-server
dockerfile: docker/server/Dockerfile
context: .
platforms: linux/amd64,linux/arm64
- image: xingrin-frontend
dockerfile: docker/frontend/Dockerfile
context: .
platforms: linux/amd64 # ARM64 构建时 Next.js 在 QEMU 下会崩溃
- image: xingrin-worker
dockerfile: docker/worker/Dockerfile
context: .
platforms: linux/amd64,linux/arm64
- image: xingrin-nginx
dockerfile: docker/nginx/Dockerfile
context: .
platforms: linux/amd64,linux/arm64
- image: xingrin-agent
dockerfile: docker/agent/Dockerfile
context: .
platforms: linux/amd64,linux/arm64
- image: xingrin-postgres
dockerfile: docker/postgres/Dockerfile
context: docker/postgres
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Free disk space (for large builds like worker)
- name: Free disk space
run: |
echo "=== Before cleanup ==="
df -h
sudo rm -rf /usr/share/dotnet
sudo rm -rf /usr/local/lib/android
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache/CodeQL
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
sudo docker image prune -af
echo "=== After cleanup ==="
df -h
- name: Generate SSL certificates for nginx build
if: matrix.image == 'xingrin-nginx'
@@ -69,10 +61,6 @@ jobs:
-keyout docker/nginx/ssl/privkey.pem \
-out docker/nginx/ssl/fullchain.pem \
-subj "/CN=localhost"
echo "SSL certificates generated for CI build"
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -83,7 +71,120 @@ jobs:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Get version from git tag
- name: Get version
id: version
run: |
if [[ $GITHUB_REF == refs/tags/* ]]; then
echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
else
echo "VERSION=dev-$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
fi
- name: Build and push AMD64
uses: docker/build-push-action@v5
with:
context: ${{ matrix.context }}
file: ${{ matrix.dockerfile }}
platforms: linux/amd64
push: true
tags: ${{ env.IMAGE_PREFIX }}/${{ matrix.image }}:${{ steps.version.outputs.VERSION }}-amd64
build-args: IMAGE_TAG=${{ steps.version.outputs.VERSION }}
cache-from: type=registry,ref=${{ env.IMAGE_PREFIX }}/${{ matrix.image }}:cache-amd64
cache-to: type=registry,ref=${{ env.IMAGE_PREFIX }}/${{ matrix.image }}:cache-amd64,mode=max
provenance: false
sbom: false
# ARM64 构建(原生 ARM64 runner
build-arm64:
runs-on: ubuntu-22.04-arm
strategy:
matrix:
include:
- image: xingrin-server
dockerfile: docker/server/Dockerfile
context: .
- image: xingrin-frontend
dockerfile: docker/frontend/Dockerfile
context: .
- image: xingrin-worker
dockerfile: docker/worker/Dockerfile
context: .
- image: xingrin-nginx
dockerfile: docker/nginx/Dockerfile
context: .
- image: xingrin-agent
dockerfile: docker/agent/Dockerfile
context: .
- image: xingrin-postgres
dockerfile: docker/postgres/Dockerfile
context: docker/postgres
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Generate SSL certificates for nginx build
if: matrix.image == 'xingrin-nginx'
run: |
mkdir -p docker/nginx/ssl
openssl req -x509 -nodes -days 365 -newkey rsa:2048 \
-keyout docker/nginx/ssl/privkey.pem \
-out docker/nginx/ssl/fullchain.pem \
-subj "/CN=localhost"
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Get version
id: version
run: |
if [[ $GITHUB_REF == refs/tags/* ]]; then
echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
else
echo "VERSION=dev-$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
fi
- name: Build and push ARM64
uses: docker/build-push-action@v5
with:
context: ${{ matrix.context }}
file: ${{ matrix.dockerfile }}
platforms: linux/arm64
push: true
tags: ${{ env.IMAGE_PREFIX }}/${{ matrix.image }}:${{ steps.version.outputs.VERSION }}-arm64
build-args: IMAGE_TAG=${{ steps.version.outputs.VERSION }}
cache-from: type=registry,ref=${{ env.IMAGE_PREFIX }}/${{ matrix.image }}:cache-arm64
cache-to: type=registry,ref=${{ env.IMAGE_PREFIX }}/${{ matrix.image }}:cache-arm64,mode=max
provenance: false
sbom: false
# 合并多架构 manifest
merge-manifests:
runs-on: ubuntu-latest
needs: [build-amd64, build-arm64]
strategy:
matrix:
image:
- xingrin-server
- xingrin-frontend
- xingrin-worker
- xingrin-nginx
- xingrin-agent
- xingrin-postgres
steps:
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Get version
id: version
run: |
if [[ $GITHUB_REF == refs/tags/* ]]; then
@@ -94,28 +195,27 @@ jobs:
echo "IS_RELEASE=false" >> $GITHUB_OUTPUT
fi
- name: Build and push
uses: docker/build-push-action@v5
with:
context: ${{ matrix.context }}
file: ${{ matrix.dockerfile }}
platforms: ${{ matrix.platforms }}
push: true
tags: |
${{ env.IMAGE_PREFIX }}/${{ matrix.image }}:${{ steps.version.outputs.VERSION }}
${{ steps.version.outputs.IS_RELEASE == 'true' && format('{0}/{1}:latest', env.IMAGE_PREFIX, matrix.image) || '' }}
build-args: |
IMAGE_TAG=${{ steps.version.outputs.VERSION }}
cache-from: type=gha,scope=${{ matrix.image }}
cache-to: type=gha,mode=max,scope=${{ matrix.image }}
provenance: false
sbom: false
- name: Create and push multi-arch manifest
run: |
VERSION=${{ steps.version.outputs.VERSION }}
IMAGE=${{ env.IMAGE_PREFIX }}/${{ matrix.image }}
docker manifest create ${IMAGE}:${VERSION} \
${IMAGE}:${VERSION}-amd64 \
${IMAGE}:${VERSION}-arm64
docker manifest push ${IMAGE}:${VERSION}
if [[ "${{ steps.version.outputs.IS_RELEASE }}" == "true" ]]; then
docker manifest create ${IMAGE}:latest \
${IMAGE}:${VERSION}-amd64 \
${IMAGE}:${VERSION}-arm64
docker manifest push ${IMAGE}:latest
fi
# 所有镜像构建成功后,更新 VERSION 文件
# 根据 tag 所在的分支更新对应分支的 VERSION 文件
# 更新 VERSION 文件
update-version:
runs-on: ubuntu-latest
needs: build
needs: merge-manifests
if: startsWith(github.ref, 'refs/tags/v')
steps:
- name: Checkout repository

1
.gitignore vendored
View File

@@ -64,6 +64,7 @@ backend/.env.local
.coverage
htmlcov/
*.cover
.hypothesis/
# ============================
# 后端 (Go) 相关

114
README.md
View File

@@ -13,18 +13,25 @@
<p align="center">
<a href="#-功能特性">功能特性</a> •
<a href="#-全局资产搜索">资产搜索</a> •
<a href="#-快速开始">快速开始</a> •
<a href="#-文档">文档</a> •
<a href="#-技术栈">技术栈</a> •
<a href="#-反馈与贡献">反馈与贡献</a>
</p>
<p align="center">
<sub>🔍 关键词: ASM | 攻击面管理 | 漏洞扫描 | 资产发现 | Bug Bounty | 渗透测试 | Nuclei | 子域名枚举 | EASM</sub>
<sub>🔍 关键词: ASM | 攻击面管理 | 漏洞扫描 | 资产发现 | 资产搜索 | Bug Bounty | 渗透测试 | Nuclei | 子域名枚举 | EASM</sub>
</p>
---
## 🌐 在线 Demo
**[https://xingrin.vercel.app/](https://xingrin.vercel.app/)**
> ⚠️ 仅用于 UI 展示,未接入后端数据库
---
<p align="center">
<b>🎨 现代化 UI </b>
@@ -51,23 +58,33 @@
## ✨ 功能特性
### 🎯 目标与资产管理
- **组织管理** - 多层级目标组织,灵活分组
- **目标管理** - 支持域名、IP目标类型
- **资产发现** - 子域名、网站、端点、目录自动发现
- **资产快照** - 扫描结果快照对比,追踪资产变化
### 扫描能力
### 🔍 漏洞扫描
- **多引擎支持** - 集成 Nuclei 等主流扫描引擎
- **自定义流程** - YAML 配置扫描流程,灵活编排
- **定时扫描** - Cron 表达式配置,自动化周期扫描
| 功能 | 状态 | 工具 | 说明 |
|------|------|------|------|
| 子域名扫描 | ✅ | Subfinder, Amass, PureDNS | 被动收集 + 主动爆破,聚合 50+ 数据源 |
| 端口扫描 | ✅ | Naabu | 自定义端口范围 |
| 站点发现 | ✅ | HTTPX | HTTP 探测,自动获取标题、状态码、技术栈 |
| 指纹识别 | ✅ | XingFinger | 2.7W+ 指纹规则,多源指纹库 |
| URL 收集 | ✅ | Waymore, Katana | 历史数据 + 主动爬取 |
| 目录扫描 | ✅ | FFUF | 高速爆破,智能字典 |
| 漏洞扫描 | ✅ | Nuclei, Dalfox | 9000+ POC 模板XSS 检测 |
| 站点截图 | ✅ | Playwright | WebP 高压缩存储 |
### 🔖 指纹识别
- **多源指纹库** - 内置 EHole、Goby、Wappalyzer、Fingers、FingerPrintHub、ARL 等 2.7W+ 指纹规则
- **自动识别** - 扫描流程自动执行,识别 Web 应用技术栈
- **指纹管理** - 支持查询、导入、导出指纹规则
### 平台能力
#### 扫描流程架构
| 功能 | 状态 | 说明 |
|------|------|------|
| 目标管理 | ✅ | 多层级组织,支持域名/IP 目标 |
| 资产快照 | ✅ | 扫描结果对比,追踪资产变化 |
| 黑名单过滤 | ✅ | 全局 + Target 级,支持通配符/CIDR |
| 定时任务 | ✅ | Cron 表达式,自动化周期扫描 |
| 分布式扫描 | ✅ | 多 Worker 节点,负载感知调度 |
| 全局搜索 | ✅ | 表达式语法,多字段组合查询 |
| 通知推送 | ✅ | 企业微信、Telegram、Discord |
| API 密钥管理 | ✅ | 可视化配置各数据源 API Key |
### 扫描流程架构
完整的扫描流程包括子域名发现、端口扫描、站点发现、指纹识别、URL 收集、目录扫描、漏洞扫描等阶段
@@ -88,6 +105,7 @@ flowchart LR
direction TB
URL["URL 收集<br/>waymore, katana"]
DIR["目录扫描<br/>ffuf"]
SCREENSHOT["站点截图<br/>playwright"]
end
subgraph STAGE3["阶段 3: 漏洞检测"]
@@ -112,6 +130,7 @@ flowchart LR
style FINGER fill:#5dade2,stroke:#3498db,stroke-width:1px,color:#fff
style URL fill:#bb8fce,stroke:#9b59b6,stroke-width:1px,color:#fff
style DIR fill:#bb8fce,stroke:#9b59b6,stroke-width:1px,color:#fff
style SCREENSHOT fill:#bb8fce,stroke:#9b59b6,stroke-width:1px,color:#fff
style VULN fill:#f0b27a,stroke:#e67e22,stroke-width:1px,color:#fff
```
@@ -162,9 +181,34 @@ flowchart TB
W3 -.心跳上报.-> REDIS
```
### 🔎 全局资产搜索
- **多类型搜索** - 支持 Website 和 Endpoint 两种资产类型
- **表达式语法** - 支持 `=`(模糊)、`==`(精确)、`!=`(不等于)操作符
- **逻辑组合** - 支持 `&&` (AND) 和 `||` (OR) 逻辑组合
- **多字段查询** - 支持 host、url、title、tech、status、body、header 字段
- **CSV 导出** - 流式导出全部搜索结果,无数量限制
#### 搜索语法示例
```bash
# 基础搜索
host="api" # host 包含 "api"
status=="200" # 状态码精确等于 200
tech="nginx" # 技术栈包含 nginx
# 组合搜索
host="api" && status=="200" # host 包含 api 且状态码为 200
tech="vue" || tech="react" # 技术栈包含 vue 或 react
# 复杂查询
host="admin" && tech="php" && status=="200"
url="/api/v1" && status!="404"
```
### 📊 可视化界面
- **数据统计** - 资产/漏洞统计仪表盘
- **实时通知** - WebSocket 消息推送
- **通知推送** - 实时企业微信tgdiscard消息推送服务
---
@@ -172,7 +216,8 @@ flowchart TB
### 环境要求
- **操作系统**: Ubuntu 20.04+ / Debian 11+ (推荐)
- **操作系统**: Ubuntu 20.04+ / Debian 11+
- **系统架构**: AMD64 (x86_64) / ARM64 (aarch64)
- **硬件**: 2核 4G 内存起步20GB+ 磁盘空间
### 一键安装
@@ -192,11 +237,11 @@ sudo ./install.sh --mirror
> **💡 --mirror 参数说明**
> - 自动配置 Docker 镜像加速(国内镜像源)
> - 加速 Git 仓库克隆Nuclei 模板等)
> - 大幅提升安装速度,避免网络超时
### 访问服务
- **Web 界面**: `https://ip:8083`
- **默认账号**: admin / admin首次登录后请修改密码
### 常用命令
@@ -216,17 +261,40 @@ sudo ./uninstall.sh
## 🤝 反馈与贡献
- 🐛 **如果发现 Bug** 可以点击右边链接进行提交 [Issue](https://github.com/yyhuni/xingrin/issues)
- 💡 **有新想法比如UI设计功能设计等** 欢迎点击右边链接进行提交建议 [Issue](https://github.com/yyhuni/xingrin/issues)
- 💡 **发现 Bug有新想法比如UI设计功能设计等** 欢迎点击右边链接进行提交建议 [Issue](https://github.com/yyhuni/xingrin/issues) 或者公众号私信
## 📧 联系
- 目前版本就我个人使用,可能会有很多边界问题
- 如有问题,建议,其他,优先提交[Issue](https://github.com/yyhuni/xingrin/issues),也可以直接给我的公众号发消息,我都会回复的
- 微信公众号: **塔罗安全学苑**
- 微信群去公众号底下的菜单,有个交流群,点击就可以看到了,链接过期可以私信我拉你
<img src="docs/wechat-qrcode.png" alt="微信公众号" width="200">
### 🎁 关注公众号免费领取指纹库
| 指纹库 | 数量 |
|--------|------|
| ehole.json | 21,977 |
| ARL.yaml | 9,264 |
| goby.json | 7,086 |
| FingerprintHub.json | 3,147 |
> 💡 关注公众号回复「指纹」即可获取
## ☕ 赞助支持
如果这个项目对你有帮助谢谢请我能喝杯蜜雪冰城你的star和赞助是我免费更新的动力
<p>
<img src="docs/wx_pay.jpg" alt="微信支付" width="200">
<img src="docs/zfb_pay.jpg" alt="支付宝" width="200">
</p>
### 🙏 感谢以下赞助
| 昵称 | 金额 |
|------|------|
| X闭关中 | ¥88 |
## ⚠️ 免责声明

View File

@@ -1 +1 @@
v1.2.9-dev
v1.5.4-dev

1
backend/.gitignore vendored
View File

@@ -7,6 +7,7 @@ __pycache__/
*.egg-info/
dist/
build/
.hypothesis/ # Hypothesis 属性测试缓存
# 虚拟环境
venv/

View File

@@ -4,7 +4,3 @@ from django.apps import AppConfig
class AssetConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'apps.asset'
def ready(self):
# 导入所有模型以确保Django发现并注册
from . import models

View File

@@ -14,12 +14,13 @@ class EndpointDTO:
status_code: Optional[int] = None
content_length: Optional[int] = None
webserver: Optional[str] = None
body_preview: Optional[str] = None
response_body: Optional[str] = None
content_type: Optional[str] = None
tech: Optional[List[str]] = None
vhost: Optional[bool] = None
location: Optional[str] = None
matched_gf_patterns: Optional[List[str]] = None
response_headers: Optional[str] = None
def __post_init__(self):
if self.tech is None:

View File

@@ -17,9 +17,10 @@ class WebSiteDTO:
webserver: str = ''
content_type: str = ''
tech: List[str] = None
body_preview: str = ''
response_body: str = ''
vhost: Optional[bool] = None
created_at: str = None
response_headers: str = ''
def __post_init__(self):
if self.tech is None:

View File

@@ -13,6 +13,7 @@ class EndpointSnapshotDTO:
快照只属于 scan。
"""
scan_id: int
target_id: int # 必填,用于同步到资产表
url: str
host: str = '' # 主机名域名或IP地址
title: str = ''
@@ -22,10 +23,10 @@ class EndpointSnapshotDTO:
webserver: str = ''
content_type: str = ''
tech: List[str] = None
body_preview: str = ''
response_body: str = ''
vhost: Optional[bool] = None
matched_gf_patterns: List[str] = None
target_id: Optional[int] = None # 冗余字段,用于同步到资产表
response_headers: str = ''
def __post_init__(self):
if self.tech is None:
@@ -42,9 +43,6 @@ class EndpointSnapshotDTO:
"""
from apps.asset.dtos.asset import EndpointDTO
if self.target_id is None:
raise ValueError("target_id 不能为 None无法同步到资产表")
return EndpointDTO(
target_id=self.target_id,
url=self.url,
@@ -53,10 +51,11 @@ class EndpointSnapshotDTO:
status_code=self.status_code,
content_length=self.content_length,
webserver=self.webserver,
body_preview=self.body_preview,
response_body=self.response_body,
content_type=self.content_type,
tech=self.tech if self.tech else [],
vhost=self.vhost,
location=self.location,
matched_gf_patterns=self.matched_gf_patterns if self.matched_gf_patterns else []
matched_gf_patterns=self.matched_gf_patterns if self.matched_gf_patterns else [],
response_headers=self.response_headers,
)

View File

@@ -13,18 +13,19 @@ class WebsiteSnapshotDTO:
快照只属于 scantarget 信息通过 scan.target 获取。
"""
scan_id: int
target_id: int # 仅用于传递数据,不保存到数据库
target_id: int # 必填,用于同步到资产表
url: str
host: str
title: str = ''
status: Optional[int] = None
status_code: Optional[int] = None # 统一命名status -> status_code
content_length: Optional[int] = None
location: str = ''
web_server: str = ''
webserver: str = '' # 统一命名web_server -> webserver
content_type: str = ''
tech: List[str] = None
body_preview: str = ''
response_body: str = ''
vhost: Optional[bool] = None
response_headers: str = ''
def __post_init__(self):
if self.tech is None:
@@ -44,12 +45,13 @@ class WebsiteSnapshotDTO:
url=self.url,
host=self.host,
title=self.title,
status_code=self.status,
status_code=self.status_code,
content_length=self.content_length,
location=self.location,
webserver=self.web_server,
webserver=self.webserver,
content_type=self.content_type,
tech=self.tech if self.tech else [],
body_preview=self.body_preview,
vhost=self.vhost
response_body=self.response_body,
vhost=self.vhost,
response_headers=self.response_headers,
)

View File

@@ -0,0 +1,345 @@
# Generated by Django 5.2.7 on 2026-01-06 00:55
import django.contrib.postgres.fields
import django.contrib.postgres.indexes
import django.core.validators
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
('scan', '0001_initial'),
('targets', '0001_initial'),
]
operations = [
migrations.CreateModel(
name='AssetStatistics',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('total_targets', models.IntegerField(default=0, help_text='目标总数')),
('total_subdomains', models.IntegerField(default=0, help_text='子域名总数')),
('total_ips', models.IntegerField(default=0, help_text='IP地址总数')),
('total_endpoints', models.IntegerField(default=0, help_text='端点总数')),
('total_websites', models.IntegerField(default=0, help_text='网站总数')),
('total_vulns', models.IntegerField(default=0, help_text='漏洞总数')),
('total_assets', models.IntegerField(default=0, help_text='总资产数(子域名+IP+端点+网站)')),
('prev_targets', models.IntegerField(default=0, help_text='上次目标总数')),
('prev_subdomains', models.IntegerField(default=0, help_text='上次子域名总数')),
('prev_ips', models.IntegerField(default=0, help_text='上次IP地址总数')),
('prev_endpoints', models.IntegerField(default=0, help_text='上次端点总数')),
('prev_websites', models.IntegerField(default=0, help_text='上次网站总数')),
('prev_vulns', models.IntegerField(default=0, help_text='上次漏洞总数')),
('prev_assets', models.IntegerField(default=0, help_text='上次总资产数')),
('updated_at', models.DateTimeField(auto_now=True, help_text='最后更新时间')),
],
options={
'verbose_name': '资产统计',
'verbose_name_plural': '资产统计',
'db_table': 'asset_statistics',
},
),
migrations.CreateModel(
name='StatisticsHistory',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('date', models.DateField(help_text='统计日期', unique=True)),
('total_targets', models.IntegerField(default=0, help_text='目标总数')),
('total_subdomains', models.IntegerField(default=0, help_text='子域名总数')),
('total_ips', models.IntegerField(default=0, help_text='IP地址总数')),
('total_endpoints', models.IntegerField(default=0, help_text='端点总数')),
('total_websites', models.IntegerField(default=0, help_text='网站总数')),
('total_vulns', models.IntegerField(default=0, help_text='漏洞总数')),
('total_assets', models.IntegerField(default=0, help_text='总资产数')),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('updated_at', models.DateTimeField(auto_now=True, help_text='更新时间')),
],
options={
'verbose_name': '统计历史',
'verbose_name_plural': '统计历史',
'db_table': 'statistics_history',
'ordering': ['-date'],
'indexes': [models.Index(fields=['date'], name='statistics__date_1d29cd_idx')],
},
),
migrations.CreateModel(
name='Directory',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('url', models.CharField(help_text='完整请求 URL', max_length=2000)),
('status', models.IntegerField(blank=True, help_text='HTTP 响应状态码', null=True)),
('content_length', models.BigIntegerField(blank=True, help_text='响应体字节大小Content-Length 或实际长度)', null=True)),
('words', models.IntegerField(blank=True, help_text='响应体中单词数量(按空格分割)', null=True)),
('lines', models.IntegerField(blank=True, help_text='响应体行数(按换行符分割)', null=True)),
('content_type', models.CharField(blank=True, default='', help_text='响应头 Content-Type 值', max_length=200)),
('duration', models.BigIntegerField(blank=True, help_text='请求耗时(单位:纳秒)', null=True)),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('target', models.ForeignKey(help_text='所属的扫描目标', on_delete=django.db.models.deletion.CASCADE, related_name='directories', to='targets.target')),
],
options={
'verbose_name': '目录',
'verbose_name_plural': '目录',
'db_table': 'directory',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['-created_at'], name='directory_created_2cef03_idx'), models.Index(fields=['target'], name='directory_target__e310c8_idx'), models.Index(fields=['url'], name='directory_url_ba40cd_idx'), models.Index(fields=['status'], name='directory_status_40bbe6_idx'), django.contrib.postgres.indexes.GinIndex(fields=['url'], name='directory_url_trgm_idx', opclasses=['gin_trgm_ops'])],
'constraints': [models.UniqueConstraint(fields=('target', 'url'), name='unique_directory_url_target')],
},
),
migrations.CreateModel(
name='DirectorySnapshot',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('url', models.CharField(help_text='目录URL', max_length=2000)),
('status', models.IntegerField(blank=True, help_text='HTTP状态码', null=True)),
('content_length', models.BigIntegerField(blank=True, help_text='内容长度', null=True)),
('words', models.IntegerField(blank=True, help_text='响应体中单词数量(按空格分割)', null=True)),
('lines', models.IntegerField(blank=True, help_text='响应体行数(按换行符分割)', null=True)),
('content_type', models.CharField(blank=True, default='', help_text='响应头 Content-Type 值', max_length=200)),
('duration', models.BigIntegerField(blank=True, help_text='请求耗时(单位:纳秒)', null=True)),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('scan', models.ForeignKey(help_text='所属的扫描任务', on_delete=django.db.models.deletion.CASCADE, related_name='directory_snapshots', to='scan.scan')),
],
options={
'verbose_name': '目录快照',
'verbose_name_plural': '目录快照',
'db_table': 'directory_snapshot',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['scan'], name='directory_s_scan_id_c45900_idx'), models.Index(fields=['url'], name='directory_s_url_b4b72b_idx'), models.Index(fields=['status'], name='directory_s_status_e9f57e_idx'), models.Index(fields=['content_type'], name='directory_s_content_45e864_idx'), models.Index(fields=['-created_at'], name='directory_s_created_eb9d27_idx'), django.contrib.postgres.indexes.GinIndex(fields=['url'], name='dir_snap_url_trgm', opclasses=['gin_trgm_ops'])],
'constraints': [models.UniqueConstraint(fields=('scan', 'url'), name='unique_directory_per_scan_snapshot')],
},
),
migrations.CreateModel(
name='Endpoint',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('url', models.TextField(help_text='最终访问的完整URL')),
('host', models.CharField(blank=True, default='', help_text='主机名域名或IP地址', max_length=253)),
('location', models.TextField(blank=True, default='', help_text='重定向地址HTTP 3xx 响应头 Location')),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('title', models.TextField(blank=True, default='', help_text='网页标题HTML <title> 标签内容)')),
('webserver', models.TextField(blank=True, default='', help_text='服务器类型HTTP 响应头 Server 值)')),
('response_body', models.TextField(blank=True, default='', help_text='HTTP响应体')),
('content_type', models.TextField(blank=True, default='', help_text='响应类型HTTP Content-Type 响应头)')),
('tech', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='技术栈(服务器/框架/语言等)', size=None)),
('status_code', models.IntegerField(blank=True, help_text='HTTP状态码', null=True)),
('content_length', models.IntegerField(blank=True, help_text='响应体大小(单位字节)', null=True)),
('vhost', models.BooleanField(blank=True, help_text='是否支持虚拟主机', null=True)),
('matched_gf_patterns', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='匹配的GF模式列表用于识别敏感端点如api, debug, config等', size=None)),
('response_headers', models.TextField(blank=True, default='', help_text='原始HTTP响应头')),
('target', models.ForeignKey(help_text='所属的扫描目标(主关联字段,表示所属关系,不能为空)', on_delete=django.db.models.deletion.CASCADE, related_name='endpoints', to='targets.target')),
],
options={
'verbose_name': '端点',
'verbose_name_plural': '端点',
'db_table': 'endpoint',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['-created_at'], name='endpoint_created_44fe9c_idx'), models.Index(fields=['target'], name='endpoint_target__7f9065_idx'), models.Index(fields=['url'], name='endpoint_url_30f66e_idx'), models.Index(fields=['host'], name='endpoint_host_5b4cc8_idx'), models.Index(fields=['status_code'], name='endpoint_status__5d4fdd_idx'), models.Index(fields=['title'], name='endpoint_title_29e26c_idx'), django.contrib.postgres.indexes.GinIndex(fields=['tech'], name='endpoint_tech_2bfa7c_gin'), django.contrib.postgres.indexes.GinIndex(fields=['response_headers'], name='endpoint_resp_headers_trgm_idx', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['url'], name='endpoint_url_trgm_idx', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['title'], name='endpoint_title_trgm_idx', opclasses=['gin_trgm_ops'])],
'constraints': [models.UniqueConstraint(fields=('url', 'target'), name='unique_endpoint_url_target')],
},
),
migrations.CreateModel(
name='EndpointSnapshot',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('url', models.TextField(help_text='端点URL')),
('host', models.CharField(blank=True, default='', help_text='主机名域名或IP地址', max_length=253)),
('title', models.TextField(blank=True, default='', help_text='页面标题')),
('status_code', models.IntegerField(blank=True, help_text='HTTP状态码', null=True)),
('content_length', models.IntegerField(blank=True, help_text='内容长度', null=True)),
('location', models.TextField(blank=True, default='', help_text='重定向位置')),
('webserver', models.TextField(blank=True, default='', help_text='Web服务器')),
('content_type', models.TextField(blank=True, default='', help_text='内容类型')),
('tech', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='技术栈', size=None)),
('response_body', models.TextField(blank=True, default='', help_text='HTTP响应体')),
('vhost', models.BooleanField(blank=True, help_text='虚拟主机标志', null=True)),
('matched_gf_patterns', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='匹配的GF模式列表', size=None)),
('response_headers', models.TextField(blank=True, default='', help_text='原始HTTP响应头')),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('scan', models.ForeignKey(help_text='所属的扫描任务', on_delete=django.db.models.deletion.CASCADE, related_name='endpoint_snapshots', to='scan.scan')),
],
options={
'verbose_name': '端点快照',
'verbose_name_plural': '端点快照',
'db_table': 'endpoint_snapshot',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['scan'], name='endpoint_sn_scan_id_6ac9a7_idx'), models.Index(fields=['url'], name='endpoint_sn_url_205160_idx'), models.Index(fields=['host'], name='endpoint_sn_host_577bfd_idx'), models.Index(fields=['title'], name='endpoint_sn_title_516a05_idx'), models.Index(fields=['status_code'], name='endpoint_sn_status__83efb0_idx'), models.Index(fields=['webserver'], name='endpoint_sn_webserv_66be83_idx'), models.Index(fields=['-created_at'], name='endpoint_sn_created_21fb5b_idx'), django.contrib.postgres.indexes.GinIndex(fields=['tech'], name='endpoint_sn_tech_0d0752_gin'), django.contrib.postgres.indexes.GinIndex(fields=['response_headers'], name='ep_snap_resp_hdr_trgm', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['url'], name='ep_snap_url_trgm', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['title'], name='ep_snap_title_trgm', opclasses=['gin_trgm_ops'])],
'constraints': [models.UniqueConstraint(fields=('scan', 'url'), name='unique_endpoint_per_scan_snapshot')],
},
),
migrations.CreateModel(
name='HostPortMapping',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('host', models.CharField(help_text='主机名域名或IP', max_length=1000)),
('ip', models.GenericIPAddressField(help_text='IP地址')),
('port', models.IntegerField(help_text='端口号1-65535', validators=[django.core.validators.MinValueValidator(1, message='端口号必须大于等于1'), django.core.validators.MaxValueValidator(65535, message='端口号必须小于等于65535')])),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('target', models.ForeignKey(help_text='所属的扫描目标', on_delete=django.db.models.deletion.CASCADE, related_name='host_port_mappings', to='targets.target')),
],
options={
'verbose_name': '主机端口映射',
'verbose_name_plural': '主机端口映射',
'db_table': 'host_port_mapping',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['target'], name='host_port_m_target__943e9b_idx'), models.Index(fields=['host'], name='host_port_m_host_f78363_idx'), models.Index(fields=['ip'], name='host_port_m_ip_2e6f02_idx'), models.Index(fields=['port'], name='host_port_m_port_9fb9ff_idx'), models.Index(fields=['host', 'ip'], name='host_port_m_host_3ce245_idx'), models.Index(fields=['-created_at'], name='host_port_m_created_11cd22_idx')],
'constraints': [models.UniqueConstraint(fields=('target', 'host', 'ip', 'port'), name='unique_target_host_ip_port')],
},
),
migrations.CreateModel(
name='HostPortMappingSnapshot',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('host', models.CharField(help_text='主机名域名或IP', max_length=1000)),
('ip', models.GenericIPAddressField(help_text='IP地址')),
('port', models.IntegerField(help_text='端口号1-65535', validators=[django.core.validators.MinValueValidator(1, message='端口号必须大于等于1'), django.core.validators.MaxValueValidator(65535, message='端口号必须小于等于65535')])),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('scan', models.ForeignKey(help_text='所属的扫描任务(主关联)', on_delete=django.db.models.deletion.CASCADE, related_name='host_port_mapping_snapshots', to='scan.scan')),
],
options={
'verbose_name': '主机端口映射快照',
'verbose_name_plural': '主机端口映射快照',
'db_table': 'host_port_mapping_snapshot',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['scan'], name='host_port_m_scan_id_50ba0b_idx'), models.Index(fields=['host'], name='host_port_m_host_e99054_idx'), models.Index(fields=['ip'], name='host_port_m_ip_54818c_idx'), models.Index(fields=['port'], name='host_port_m_port_ed7b48_idx'), models.Index(fields=['host', 'ip'], name='host_port_m_host_8a463a_idx'), models.Index(fields=['scan', 'host'], name='host_port_m_scan_id_426fdb_idx'), models.Index(fields=['-created_at'], name='host_port_m_created_fb28b8_idx')],
'constraints': [models.UniqueConstraint(fields=('scan', 'host', 'ip', 'port'), name='unique_scan_host_ip_port_snapshot')],
},
),
migrations.CreateModel(
name='Subdomain',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('name', models.CharField(help_text='子域名名称', max_length=1000)),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('target', models.ForeignKey(help_text='所属的扫描目标(主关联字段,表示所属关系,不能为空)', on_delete=django.db.models.deletion.CASCADE, related_name='subdomains', to='targets.target')),
],
options={
'verbose_name': '子域名',
'verbose_name_plural': '子域名',
'db_table': 'subdomain',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['-created_at'], name='subdomain_created_e187a8_idx'), models.Index(fields=['name', 'target'], name='subdomain_name_60e1d0_idx'), models.Index(fields=['target'], name='subdomain_target__e409f0_idx'), models.Index(fields=['name'], name='subdomain_name_d40ba7_idx'), django.contrib.postgres.indexes.GinIndex(fields=['name'], name='subdomain_name_trgm_idx', opclasses=['gin_trgm_ops'])],
'constraints': [models.UniqueConstraint(fields=('name', 'target'), name='unique_subdomain_name_target')],
},
),
migrations.CreateModel(
name='SubdomainSnapshot',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('name', models.CharField(help_text='子域名名称', max_length=1000)),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('scan', models.ForeignKey(help_text='所属的扫描任务', on_delete=django.db.models.deletion.CASCADE, related_name='subdomain_snapshots', to='scan.scan')),
],
options={
'verbose_name': '子域名快照',
'verbose_name_plural': '子域名快照',
'db_table': 'subdomain_snapshot',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['scan'], name='subdomain_s_scan_id_68c253_idx'), models.Index(fields=['name'], name='subdomain_s_name_2da42b_idx'), models.Index(fields=['-created_at'], name='subdomain_s_created_d2b48e_idx'), django.contrib.postgres.indexes.GinIndex(fields=['name'], name='subdomain_snap_name_trgm', opclasses=['gin_trgm_ops'])],
'constraints': [models.UniqueConstraint(fields=('scan', 'name'), name='unique_subdomain_per_scan_snapshot')],
},
),
migrations.CreateModel(
name='Vulnerability',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('url', models.CharField(help_text='漏洞所在的URL', max_length=2000)),
('vuln_type', models.CharField(help_text='漏洞类型(如 xss, sqli', max_length=100)),
('severity', models.CharField(choices=[('unknown', '未知'), ('info', '信息'), ('low', ''), ('medium', ''), ('high', ''), ('critical', '危急')], default='unknown', help_text='严重性(未知/信息/低/中/高/危急)', max_length=20)),
('source', models.CharField(blank=True, default='', help_text='来源工具(如 dalfox, nuclei, crlfuzz', max_length=50)),
('cvss_score', models.DecimalField(blank=True, decimal_places=1, help_text='CVSS 评分0.0-10.0', max_digits=3, null=True)),
('description', models.TextField(blank=True, default='', help_text='漏洞描述')),
('raw_output', models.JSONField(blank=True, default=dict, help_text='工具原始输出')),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('target', models.ForeignKey(help_text='所属的扫描目标', on_delete=django.db.models.deletion.CASCADE, related_name='vulnerabilities', to='targets.target')),
],
options={
'verbose_name': '漏洞',
'verbose_name_plural': '漏洞',
'db_table': 'vulnerability',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['target'], name='vulnerabili_target__755a02_idx'), models.Index(fields=['vuln_type'], name='vulnerabili_vuln_ty_3010cd_idx'), models.Index(fields=['severity'], name='vulnerabili_severit_1a798b_idx'), models.Index(fields=['source'], name='vulnerabili_source_7c7552_idx'), models.Index(fields=['url'], name='vulnerabili_url_4dcc4d_idx'), models.Index(fields=['-created_at'], name='vulnerabili_created_e25ff7_idx')],
},
),
migrations.CreateModel(
name='VulnerabilitySnapshot',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('url', models.CharField(help_text='漏洞所在的URL', max_length=2000)),
('vuln_type', models.CharField(help_text='漏洞类型(如 xss, sqli', max_length=100)),
('severity', models.CharField(choices=[('unknown', '未知'), ('info', '信息'), ('low', ''), ('medium', ''), ('high', ''), ('critical', '危急')], default='unknown', help_text='严重性(未知/信息/低/中/高/危急)', max_length=20)),
('source', models.CharField(blank=True, default='', help_text='来源工具(如 dalfox, nuclei, crlfuzz', max_length=50)),
('cvss_score', models.DecimalField(blank=True, decimal_places=1, help_text='CVSS 评分0.0-10.0', max_digits=3, null=True)),
('description', models.TextField(blank=True, default='', help_text='漏洞描述')),
('raw_output', models.JSONField(blank=True, default=dict, help_text='工具原始输出')),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('scan', models.ForeignKey(help_text='所属的扫描任务', on_delete=django.db.models.deletion.CASCADE, related_name='vulnerability_snapshots', to='scan.scan')),
],
options={
'verbose_name': '漏洞快照',
'verbose_name_plural': '漏洞快照',
'db_table': 'vulnerability_snapshot',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['scan'], name='vulnerabili_scan_id_7b81c9_idx'), models.Index(fields=['url'], name='vulnerabili_url_11a707_idx'), models.Index(fields=['vuln_type'], name='vulnerabili_vuln_ty_6b90ee_idx'), models.Index(fields=['severity'], name='vulnerabili_severit_4eae0d_idx'), models.Index(fields=['source'], name='vulnerabili_source_968b1f_idx'), models.Index(fields=['-created_at'], name='vulnerabili_created_53a12e_idx')],
},
),
migrations.CreateModel(
name='WebSite',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('url', models.TextField(help_text='最终访问的完整URL')),
('host', models.CharField(blank=True, default='', help_text='主机名域名或IP地址', max_length=253)),
('location', models.TextField(blank=True, default='', help_text='重定向地址HTTP 3xx 响应头 Location')),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('title', models.TextField(blank=True, default='', help_text='网页标题HTML <title> 标签内容)')),
('webserver', models.TextField(blank=True, default='', help_text='服务器类型HTTP 响应头 Server 值)')),
('response_body', models.TextField(blank=True, default='', help_text='HTTP响应体')),
('content_type', models.TextField(blank=True, default='', help_text='响应类型HTTP Content-Type 响应头)')),
('tech', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='技术栈(服务器/框架/语言等)', size=None)),
('status_code', models.IntegerField(blank=True, help_text='HTTP状态码', null=True)),
('content_length', models.IntegerField(blank=True, help_text='响应体大小(单位字节)', null=True)),
('vhost', models.BooleanField(blank=True, help_text='是否支持虚拟主机', null=True)),
('response_headers', models.TextField(blank=True, default='', help_text='原始HTTP响应头')),
('target', models.ForeignKey(help_text='所属的扫描目标(主关联字段,表示所属关系,不能为空)', on_delete=django.db.models.deletion.CASCADE, related_name='websites', to='targets.target')),
],
options={
'verbose_name': '站点',
'verbose_name_plural': '站点',
'db_table': 'website',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['-created_at'], name='website_created_c9cfd2_idx'), models.Index(fields=['url'], name='website_url_b18883_idx'), models.Index(fields=['host'], name='website_host_996b50_idx'), models.Index(fields=['target'], name='website_target__2a353b_idx'), models.Index(fields=['title'], name='website_title_c2775b_idx'), models.Index(fields=['status_code'], name='website_status__51663d_idx'), django.contrib.postgres.indexes.GinIndex(fields=['tech'], name='website_tech_e3f0cb_gin'), django.contrib.postgres.indexes.GinIndex(fields=['response_headers'], name='website_resp_headers_trgm_idx', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['url'], name='website_url_trgm_idx', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['title'], name='website_title_trgm_idx', opclasses=['gin_trgm_ops'])],
'constraints': [models.UniqueConstraint(fields=('url', 'target'), name='unique_website_url_target')],
},
),
migrations.CreateModel(
name='WebsiteSnapshot',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('url', models.TextField(help_text='站点URL')),
('host', models.CharField(blank=True, default='', help_text='主机名域名或IP地址', max_length=253)),
('title', models.TextField(blank=True, default='', help_text='页面标题')),
('status_code', models.IntegerField(blank=True, help_text='HTTP状态码', null=True)),
('content_length', models.BigIntegerField(blank=True, help_text='内容长度', null=True)),
('location', models.TextField(blank=True, default='', help_text='重定向位置')),
('webserver', models.TextField(blank=True, default='', help_text='Web服务器')),
('content_type', models.TextField(blank=True, default='', help_text='内容类型')),
('tech', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='技术栈', size=None)),
('response_body', models.TextField(blank=True, default='', help_text='HTTP响应体')),
('vhost', models.BooleanField(blank=True, help_text='虚拟主机标志', null=True)),
('response_headers', models.TextField(blank=True, default='', help_text='原始HTTP响应头')),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('scan', models.ForeignKey(help_text='所属的扫描任务', on_delete=django.db.models.deletion.CASCADE, related_name='website_snapshots', to='scan.scan')),
],
options={
'verbose_name': '网站快照',
'verbose_name_plural': '网站快照',
'db_table': 'website_snapshot',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['scan'], name='website_sna_scan_id_26b6dc_idx'), models.Index(fields=['url'], name='website_sna_url_801a70_idx'), models.Index(fields=['host'], name='website_sna_host_348fe1_idx'), models.Index(fields=['title'], name='website_sna_title_b1a5ee_idx'), models.Index(fields=['-created_at'], name='website_sna_created_2c149a_idx'), django.contrib.postgres.indexes.GinIndex(fields=['tech'], name='website_sna_tech_3d6d2f_gin'), django.contrib.postgres.indexes.GinIndex(fields=['response_headers'], name='ws_snap_resp_hdr_trgm', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['url'], name='ws_snap_url_trgm', opclasses=['gin_trgm_ops']), django.contrib.postgres.indexes.GinIndex(fields=['title'], name='ws_snap_title_trgm', opclasses=['gin_trgm_ops'])],
'constraints': [models.UniqueConstraint(fields=('scan', 'url'), name='unique_website_per_scan_snapshot')],
},
),
]

View File

@@ -0,0 +1,104 @@
"""
创建资产搜索物化视图(使用 pg_ivm 增量维护)
这些视图用于资产搜索功能,提供高性能的全文搜索能力。
"""
from django.db import migrations
class Migration(migrations.Migration):
"""创建资产搜索所需的增量物化视图"""
dependencies = [
('asset', '0001_initial'),
]
operations = [
# 1. 确保 pg_ivm 扩展已安装
migrations.RunSQL(
sql="CREATE EXTENSION IF NOT EXISTS pg_ivm;",
reverse_sql="DROP EXTENSION IF EXISTS pg_ivm;",
),
# 2. 创建 Website 搜索视图
# 注意pg_ivm 不支持 ArrayField所以 tech 字段需要从原表 JOIN 获取
migrations.RunSQL(
sql="""
SELECT pgivm.create_immv('asset_search_view', $$
SELECT
w.id,
w.url,
w.host,
w.title,
w.status_code,
w.response_headers,
w.response_body,
w.content_type,
w.content_length,
w.webserver,
w.location,
w.vhost,
w.created_at,
w.target_id
FROM website w
$$);
""",
reverse_sql="DROP TABLE IF EXISTS asset_search_view CASCADE;",
),
# 3. 创建 Endpoint 搜索视图
migrations.RunSQL(
sql="""
SELECT pgivm.create_immv('endpoint_search_view', $$
SELECT
e.id,
e.url,
e.host,
e.title,
e.status_code,
e.response_headers,
e.response_body,
e.content_type,
e.content_length,
e.webserver,
e.location,
e.vhost,
e.created_at,
e.target_id
FROM endpoint e
$$);
""",
reverse_sql="DROP TABLE IF EXISTS endpoint_search_view CASCADE;",
),
# 4. 为搜索视图创建索引(加速查询)
migrations.RunSQL(
sql=[
# Website 搜索视图索引
"CREATE INDEX IF NOT EXISTS asset_search_view_host_idx ON asset_search_view (host);",
"CREATE INDEX IF NOT EXISTS asset_search_view_url_idx ON asset_search_view (url);",
"CREATE INDEX IF NOT EXISTS asset_search_view_title_idx ON asset_search_view (title);",
"CREATE INDEX IF NOT EXISTS asset_search_view_status_idx ON asset_search_view (status_code);",
"CREATE INDEX IF NOT EXISTS asset_search_view_created_idx ON asset_search_view (created_at DESC);",
# Endpoint 搜索视图索引
"CREATE INDEX IF NOT EXISTS endpoint_search_view_host_idx ON endpoint_search_view (host);",
"CREATE INDEX IF NOT EXISTS endpoint_search_view_url_idx ON endpoint_search_view (url);",
"CREATE INDEX IF NOT EXISTS endpoint_search_view_title_idx ON endpoint_search_view (title);",
"CREATE INDEX IF NOT EXISTS endpoint_search_view_status_idx ON endpoint_search_view (status_code);",
"CREATE INDEX IF NOT EXISTS endpoint_search_view_created_idx ON endpoint_search_view (created_at DESC);",
],
reverse_sql=[
"DROP INDEX IF EXISTS asset_search_view_host_idx;",
"DROP INDEX IF EXISTS asset_search_view_url_idx;",
"DROP INDEX IF EXISTS asset_search_view_title_idx;",
"DROP INDEX IF EXISTS asset_search_view_status_idx;",
"DROP INDEX IF EXISTS asset_search_view_created_idx;",
"DROP INDEX IF EXISTS endpoint_search_view_host_idx;",
"DROP INDEX IF EXISTS endpoint_search_view_url_idx;",
"DROP INDEX IF EXISTS endpoint_search_view_title_idx;",
"DROP INDEX IF EXISTS endpoint_search_view_status_idx;",
"DROP INDEX IF EXISTS endpoint_search_view_created_idx;",
],
),
]

View File

@@ -0,0 +1,53 @@
# Generated by Django 5.2.7 on 2026-01-07 02:21
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('asset', '0002_create_search_views'),
('scan', '0001_initial'),
('targets', '0001_initial'),
]
operations = [
migrations.CreateModel(
name='Screenshot',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('url', models.TextField(help_text='截图对应的 URL')),
('image', models.BinaryField(help_text='截图 WebP 二进制数据(压缩后)')),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('updated_at', models.DateTimeField(auto_now=True, help_text='更新时间')),
('target', models.ForeignKey(help_text='所属目标', on_delete=django.db.models.deletion.CASCADE, related_name='screenshots', to='targets.target')),
],
options={
'verbose_name': '截图',
'verbose_name_plural': '截图',
'db_table': 'screenshot',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['target'], name='screenshot_target__2f01f6_idx'), models.Index(fields=['-created_at'], name='screenshot_created_c0ad4b_idx')],
'constraints': [models.UniqueConstraint(fields=('target', 'url'), name='unique_screenshot_per_target')],
},
),
migrations.CreateModel(
name='ScreenshotSnapshot',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('url', models.TextField(help_text='截图对应的 URL')),
('image', models.BinaryField(help_text='截图 WebP 二进制数据(压缩后)')),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('scan', models.ForeignKey(help_text='所属的扫描任务', on_delete=django.db.models.deletion.CASCADE, related_name='screenshot_snapshots', to='scan.scan')),
],
options={
'verbose_name': '截图快照',
'verbose_name_plural': '截图快照',
'db_table': 'screenshot_snapshot',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['scan'], name='screenshot__scan_id_fb8c4d_idx'), models.Index(fields=['-created_at'], name='screenshot__created_804117_idx')],
'constraints': [models.UniqueConstraint(fields=('scan', 'url'), name='unique_screenshot_per_scan_snapshot')],
},
),
]

View File

@@ -0,0 +1,23 @@
# Generated by Django 5.2.7 on 2026-01-07 13:29
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('asset', '0003_add_screenshot_models'),
]
operations = [
migrations.AddField(
model_name='screenshot',
name='status_code',
field=models.SmallIntegerField(blank=True, help_text='HTTP 响应状态码', null=True),
),
migrations.AddField(
model_name='screenshotsnapshot',
name='status_code',
field=models.SmallIntegerField(blank=True, help_text='HTTP 响应状态码', null=True),
),
]

View File

@@ -20,6 +20,12 @@ from .snapshot_models import (
VulnerabilitySnapshot,
)
# 截图模型
from .screenshot_models import (
Screenshot,
ScreenshotSnapshot,
)
# 统计模型
from .statistics_models import AssetStatistics, StatisticsHistory
@@ -39,6 +45,9 @@ __all__ = [
'HostPortMappingSnapshot',
'EndpointSnapshot',
'VulnerabilitySnapshot',
# 截图模型
'Screenshot',
'ScreenshotSnapshot',
# 统计模型
'AssetStatistics',
'StatisticsHistory',

View File

@@ -1,6 +1,7 @@
from django.db import models
from django.contrib.postgres.fields import ArrayField
from django.contrib.postgres.indexes import GinIndex
from django.core.validators import MinValueValidator, MaxValueValidator
@@ -34,6 +35,12 @@ class Subdomain(models.Model):
models.Index(fields=['name', 'target']), # 复合索引,优化 get_by_names_and_target_id 批量查询
models.Index(fields=['target']), # 优化从target_id快速查找下面的子域名
models.Index(fields=['name']), # 优化从name快速查找子域名搜索场景
# pg_trgm GIN 索引,支持 LIKE '%keyword%' 模糊搜索
GinIndex(
name='subdomain_name_trgm_idx',
fields=['name'],
opclasses=['gin_trgm_ops']
),
]
constraints = [
# 普通唯一约束name + target 组合唯一
@@ -58,40 +65,35 @@ class Endpoint(models.Model):
help_text='所属的扫描目标(主关联字段,表示所属关系,不能为空)'
)
url = models.CharField(max_length=2000, help_text='最终访问的完整URL')
url = models.TextField(help_text='最终访问的完整URL')
host = models.CharField(
max_length=253,
blank=True,
default='',
help_text='主机名域名或IP地址'
)
location = models.CharField(
max_length=1000,
location = models.TextField(
blank=True,
default='',
help_text='重定向地址HTTP 3xx 响应头 Location'
)
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
title = models.CharField(
max_length=1000,
title = models.TextField(
blank=True,
default='',
help_text='网页标题HTML <title> 标签内容)'
)
webserver = models.CharField(
max_length=200,
webserver = models.TextField(
blank=True,
default='',
help_text='服务器类型HTTP 响应头 Server 值)'
)
body_preview = models.CharField(
max_length=1000,
response_body = models.TextField(
blank=True,
default='',
help_text='响应正文前N个字符默认100个字符'
help_text='HTTP响应体'
)
content_type = models.CharField(
max_length=200,
content_type = models.TextField(
blank=True,
default='',
help_text='响应类型HTTP Content-Type 响应头)'
@@ -123,6 +125,11 @@ class Endpoint(models.Model):
default=list,
help_text='匹配的GF模式列表用于识别敏感端点如api, debug, config等'
)
response_headers = models.TextField(
blank=True,
default='',
help_text='原始HTTP响应头'
)
class Meta:
db_table = 'endpoint'
@@ -131,11 +138,28 @@ class Endpoint(models.Model):
ordering = ['-created_at']
indexes = [
models.Index(fields=['-created_at']),
models.Index(fields=['target']), # 优化从target_id快速查找下面的端点主关联字段
models.Index(fields=['target']), # 优化从 target_id快速查找下面的端点主关联字段
models.Index(fields=['url']), # URL索引优化查询性能
models.Index(fields=['host']), # host索引优化根据主机名查询
models.Index(fields=['status_code']), # 状态码索引,优化筛选
models.Index(fields=['title']), # title索引优化智能过滤搜索
GinIndex(fields=['tech']), # GIN索引优化 tech 数组字段的 __contains 查询
# pg_trgm GIN 索引,支持 LIKE '%keyword%' 模糊搜索
GinIndex(
name='endpoint_resp_headers_trgm_idx',
fields=['response_headers'],
opclasses=['gin_trgm_ops']
),
GinIndex(
name='endpoint_url_trgm_idx',
fields=['url'],
opclasses=['gin_trgm_ops']
),
GinIndex(
name='endpoint_title_trgm_idx',
fields=['title'],
opclasses=['gin_trgm_ops']
),
]
constraints = [
# 普通唯一约束url + target 组合唯一
@@ -160,40 +184,35 @@ class WebSite(models.Model):
help_text='所属的扫描目标(主关联字段,表示所属关系,不能为空)'
)
url = models.CharField(max_length=2000, help_text='最终访问的完整URL')
url = models.TextField(help_text='最终访问的完整URL')
host = models.CharField(
max_length=253,
blank=True,
default='',
help_text='主机名域名或IP地址'
)
location = models.CharField(
max_length=1000,
location = models.TextField(
blank=True,
default='',
help_text='重定向地址HTTP 3xx 响应头 Location'
)
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
title = models.CharField(
max_length=1000,
title = models.TextField(
blank=True,
default='',
help_text='网页标题HTML <title> 标签内容)'
)
webserver = models.CharField(
max_length=200,
webserver = models.TextField(
blank=True,
default='',
help_text='服务器类型HTTP 响应头 Server 值)'
)
body_preview = models.CharField(
max_length=1000,
response_body = models.TextField(
blank=True,
default='',
help_text='响应正文前N个字符默认100个字符'
help_text='HTTP响应体'
)
content_type = models.CharField(
max_length=200,
content_type = models.TextField(
blank=True,
default='',
help_text='响应类型HTTP Content-Type 响应头)'
@@ -219,6 +238,11 @@ class WebSite(models.Model):
blank=True,
help_text='是否支持虚拟主机'
)
response_headers = models.TextField(
blank=True,
default='',
help_text='原始HTTP响应头'
)
class Meta:
db_table = 'website'
@@ -229,9 +253,26 @@ class WebSite(models.Model):
models.Index(fields=['-created_at']),
models.Index(fields=['url']), # URL索引优化查询性能
models.Index(fields=['host']), # host索引优化根据主机名查询
models.Index(fields=['target']), # 优化从target_id快速查找下面的站点
models.Index(fields=['target']), # 优化从 target_id快速查找下面的站点
models.Index(fields=['title']), # title索引优化智能过滤搜索
models.Index(fields=['status_code']), # 状态码索引,优化智能过滤搜索
GinIndex(fields=['tech']), # GIN索引优化 tech 数组字段的 __contains 查询
# pg_trgm GIN 索引,支持 LIKE '%keyword%' 模糊搜索
GinIndex(
name='website_resp_headers_trgm_idx',
fields=['response_headers'],
opclasses=['gin_trgm_ops']
),
GinIndex(
name='website_url_trgm_idx',
fields=['url'],
opclasses=['gin_trgm_ops']
),
GinIndex(
name='website_title_trgm_idx',
fields=['title'],
opclasses=['gin_trgm_ops']
),
]
constraints = [
# 普通唯一约束url + target 组合唯一
@@ -308,6 +349,12 @@ class Directory(models.Model):
models.Index(fields=['target']), # 优化从target_id快速查找下面的目录
models.Index(fields=['url']), # URL索引优化搜索和唯一约束
models.Index(fields=['status']), # 状态码索引,优化筛选
# pg_trgm GIN 索引,支持 LIKE '%keyword%' 模糊搜索
GinIndex(
name='directory_url_trgm_idx',
fields=['url'],
opclasses=['gin_trgm_ops']
),
]
constraints = [
# 普通唯一约束target + url 组合唯一

View File

@@ -0,0 +1,80 @@
from django.db import models
class ScreenshotSnapshot(models.Model):
"""
截图快照
记录:某次扫描中捕获的网站截图
"""
id = models.AutoField(primary_key=True)
scan = models.ForeignKey(
'scan.Scan',
on_delete=models.CASCADE,
related_name='screenshot_snapshots',
help_text='所属的扫描任务'
)
url = models.TextField(help_text='截图对应的 URL')
status_code = models.SmallIntegerField(null=True, blank=True, help_text='HTTP 响应状态码')
image = models.BinaryField(help_text='截图 WebP 二进制数据(压缩后)')
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
class Meta:
db_table = 'screenshot_snapshot'
verbose_name = '截图快照'
verbose_name_plural = '截图快照'
ordering = ['-created_at']
indexes = [
models.Index(fields=['scan']),
models.Index(fields=['-created_at']),
]
constraints = [
models.UniqueConstraint(
fields=['scan', 'url'],
name='unique_screenshot_per_scan_snapshot'
),
]
def __str__(self):
return f'{self.url} (Scan #{self.scan_id})'
class Screenshot(models.Model):
"""
截图资产
存储:目标的最新截图(从快照同步)
"""
id = models.AutoField(primary_key=True)
target = models.ForeignKey(
'targets.Target',
on_delete=models.CASCADE,
related_name='screenshots',
help_text='所属目标'
)
url = models.TextField(help_text='截图对应的 URL')
status_code = models.SmallIntegerField(null=True, blank=True, help_text='HTTP 响应状态码')
image = models.BinaryField(help_text='截图 WebP 二进制数据(压缩后)')
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
updated_at = models.DateTimeField(auto_now=True, help_text='更新时间')
class Meta:
db_table = 'screenshot'
verbose_name = '截图'
verbose_name_plural = '截图'
ordering = ['-created_at']
indexes = [
models.Index(fields=['target']),
models.Index(fields=['-created_at']),
]
constraints = [
models.UniqueConstraint(
fields=['target', 'url'],
name='unique_screenshot_per_target'
),
]
def __str__(self):
return f'{self.url} (Target #{self.target_id})'

View File

@@ -1,5 +1,6 @@
from django.db import models
from django.contrib.postgres.fields import ArrayField
from django.contrib.postgres.indexes import GinIndex
from django.core.validators import MinValueValidator, MaxValueValidator
@@ -26,6 +27,12 @@ class SubdomainSnapshot(models.Model):
models.Index(fields=['scan']),
models.Index(fields=['name']),
models.Index(fields=['-created_at']),
# pg_trgm GIN 索引,支持 LIKE '%keyword%' 模糊搜索
GinIndex(
name='subdomain_snap_name_trgm',
fields=['name'],
opclasses=['gin_trgm_ops']
),
]
constraints = [
# 唯一约束:同一次扫描中,同一个子域名只能记录一次
@@ -54,22 +61,27 @@ class WebsiteSnapshot(models.Model):
)
# 扫描结果数据
url = models.CharField(max_length=2000, help_text='站点URL')
url = models.TextField(help_text='站点URL')
host = models.CharField(max_length=253, blank=True, default='', help_text='主机名域名或IP地址')
title = models.CharField(max_length=500, blank=True, default='', help_text='页面标题')
status = models.IntegerField(null=True, blank=True, help_text='HTTP状态码')
title = models.TextField(blank=True, default='', help_text='页面标题')
status_code = models.IntegerField(null=True, blank=True, help_text='HTTP状态码')
content_length = models.BigIntegerField(null=True, blank=True, help_text='内容长度')
location = models.CharField(max_length=1000, blank=True, default='', help_text='重定向位置')
web_server = models.CharField(max_length=200, blank=True, default='', help_text='Web服务器')
content_type = models.CharField(max_length=200, blank=True, default='', help_text='内容类型')
location = models.TextField(blank=True, default='', help_text='重定向位置')
webserver = models.TextField(blank=True, default='', help_text='Web服务器')
content_type = models.TextField(blank=True, default='', help_text='内容类型')
tech = ArrayField(
models.CharField(max_length=100),
blank=True,
default=list,
help_text='技术栈'
)
body_preview = models.TextField(blank=True, default='', help_text='响应体预览')
response_body = models.TextField(blank=True, default='', help_text='HTTP响应体')
vhost = models.BooleanField(null=True, blank=True, help_text='虚拟主机标志')
response_headers = models.TextField(
blank=True,
default='',
help_text='原始HTTP响应头'
)
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
class Meta:
@@ -83,6 +95,23 @@ class WebsiteSnapshot(models.Model):
models.Index(fields=['host']), # host索引优化根据主机名查询
models.Index(fields=['title']), # title索引优化标题搜索
models.Index(fields=['-created_at']),
GinIndex(fields=['tech']), # GIN索引优化数组字段查询
# pg_trgm GIN 索引,支持 LIKE '%keyword%' 模糊搜索
GinIndex(
name='ws_snap_resp_hdr_trgm',
fields=['response_headers'],
opclasses=['gin_trgm_ops']
),
GinIndex(
name='ws_snap_url_trgm',
fields=['url'],
opclasses=['gin_trgm_ops']
),
GinIndex(
name='ws_snap_title_trgm',
fields=['title'],
opclasses=['gin_trgm_ops']
),
]
constraints = [
# 唯一约束同一次扫描中同一个URL只能记录一次
@@ -132,6 +161,12 @@ class DirectorySnapshot(models.Model):
models.Index(fields=['status']), # 状态码索引,优化筛选
models.Index(fields=['content_type']), # content_type索引优化内容类型搜索
models.Index(fields=['-created_at']),
# pg_trgm GIN 索引,支持 LIKE '%keyword%' 模糊搜索
GinIndex(
name='dir_snap_url_trgm',
fields=['url'],
opclasses=['gin_trgm_ops']
),
]
constraints = [
# 唯一约束同一次扫描中同一个目录URL只能记录一次
@@ -232,26 +267,26 @@ class EndpointSnapshot(models.Model):
)
# 扫描结果数据
url = models.CharField(max_length=2000, help_text='端点URL')
url = models.TextField(help_text='端点URL')
host = models.CharField(
max_length=253,
blank=True,
default='',
help_text='主机名域名或IP地址'
)
title = models.CharField(max_length=1000, blank=True, default='', help_text='页面标题')
title = models.TextField(blank=True, default='', help_text='页面标题')
status_code = models.IntegerField(null=True, blank=True, help_text='HTTP状态码')
content_length = models.IntegerField(null=True, blank=True, help_text='内容长度')
location = models.CharField(max_length=1000, blank=True, default='', help_text='重定向位置')
webserver = models.CharField(max_length=200, blank=True, default='', help_text='Web服务器')
content_type = models.CharField(max_length=200, blank=True, default='', help_text='内容类型')
location = models.TextField(blank=True, default='', help_text='重定向位置')
webserver = models.TextField(blank=True, default='', help_text='Web服务器')
content_type = models.TextField(blank=True, default='', help_text='内容类型')
tech = ArrayField(
models.CharField(max_length=100),
blank=True,
default=list,
help_text='技术栈'
)
body_preview = models.CharField(max_length=1000, blank=True, default='', help_text='响应体预览')
response_body = models.TextField(blank=True, default='', help_text='HTTP响应体')
vhost = models.BooleanField(null=True, blank=True, help_text='虚拟主机标志')
matched_gf_patterns = ArrayField(
models.CharField(max_length=100),
@@ -259,6 +294,11 @@ class EndpointSnapshot(models.Model):
default=list,
help_text='匹配的GF模式列表'
)
response_headers = models.TextField(
blank=True,
default='',
help_text='原始HTTP响应头'
)
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
class Meta:
@@ -274,6 +314,23 @@ class EndpointSnapshot(models.Model):
models.Index(fields=['status_code']), # 状态码索引,优化筛选
models.Index(fields=['webserver']), # webserver索引优化服务器搜索
models.Index(fields=['-created_at']),
GinIndex(fields=['tech']), # GIN索引优化数组字段查询
# pg_trgm GIN 索引,支持 LIKE '%keyword%' 模糊搜索
GinIndex(
name='ep_snap_resp_hdr_trgm',
fields=['response_headers'],
opclasses=['gin_trgm_ops']
),
GinIndex(
name='ep_snap_url_trgm',
fields=['url'],
opclasses=['gin_trgm_ops']
),
GinIndex(
name='ep_snap_title_trgm',
fields=['title'],
opclasses=['gin_trgm_ops']
),
]
constraints = [
# 唯一约束同一次扫描中同一个URL只能记录一次

View File

@@ -48,12 +48,13 @@ class DjangoEndpointRepository:
status_code=item.status_code,
content_length=item.content_length,
webserver=item.webserver or '',
body_preview=item.body_preview or '',
response_body=item.response_body or '',
content_type=item.content_type or '',
tech=item.tech if item.tech else [],
vhost=item.vhost,
location=item.location or '',
matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else []
matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else [],
response_headers=item.response_headers if item.response_headers else ''
)
for item in unique_items
]
@@ -65,8 +66,8 @@ class DjangoEndpointRepository:
unique_fields=['url', 'target'],
update_fields=[
'host', 'title', 'status_code', 'content_length',
'webserver', 'body_preview', 'content_type', 'tech',
'vhost', 'location', 'matched_gf_patterns'
'webserver', 'response_body', 'content_type', 'tech',
'vhost', 'location', 'matched_gf_patterns', 'response_headers'
],
batch_size=1000
)
@@ -138,12 +139,13 @@ class DjangoEndpointRepository:
status_code=item.status_code,
content_length=item.content_length,
webserver=item.webserver or '',
body_preview=item.body_preview or '',
response_body=item.response_body or '',
content_type=item.content_type or '',
tech=item.tech if item.tech else [],
vhost=item.vhost,
location=item.location or '',
matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else []
matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else [],
response_headers=item.response_headers if item.response_headers else ''
)
for item in unique_items
]
@@ -183,7 +185,7 @@ class DjangoEndpointRepository:
.values(
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'body_preview', 'vhost', 'matched_gf_patterns', 'created_at'
'response_body', 'response_headers', 'vhost', 'matched_gf_patterns', 'created_at'
)
.order_by('url')
)

View File

@@ -49,12 +49,13 @@ class DjangoWebSiteRepository:
location=item.location or '',
title=item.title or '',
webserver=item.webserver or '',
body_preview=item.body_preview or '',
response_body=item.response_body or '',
content_type=item.content_type or '',
tech=item.tech if item.tech else [],
status_code=item.status_code,
content_length=item.content_length,
vhost=item.vhost
vhost=item.vhost,
response_headers=item.response_headers if item.response_headers else ''
)
for item in unique_items
]
@@ -66,8 +67,8 @@ class DjangoWebSiteRepository:
unique_fields=['url', 'target'],
update_fields=[
'host', 'location', 'title', 'webserver',
'body_preview', 'content_type', 'tech',
'status_code', 'content_length', 'vhost'
'response_body', 'content_type', 'tech',
'status_code', 'content_length', 'vhost', 'response_headers'
],
batch_size=1000
)
@@ -132,12 +133,13 @@ class DjangoWebSiteRepository:
location=item.location or '',
title=item.title or '',
webserver=item.webserver or '',
body_preview=item.body_preview or '',
response_body=item.response_body or '',
content_type=item.content_type or '',
tech=item.tech if item.tech else [],
status_code=item.status_code,
content_length=item.content_length,
vhost=item.vhost
vhost=item.vhost,
response_headers=item.response_headers if item.response_headers else ''
)
for item in unique_items
]
@@ -177,7 +179,7 @@ class DjangoWebSiteRepository:
.values(
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'body_preview', 'vhost', 'created_at'
'response_body', 'response_headers', 'vhost', 'created_at'
)
.order_by('url')
)

View File

@@ -44,6 +44,7 @@ class DjangoEndpointSnapshotRepository:
snapshots.append(EndpointSnapshot(
scan_id=item.scan_id,
url=item.url,
host=item.host if item.host else '',
title=item.title,
status_code=item.status_code,
content_length=item.content_length,
@@ -51,9 +52,10 @@ class DjangoEndpointSnapshotRepository:
webserver=item.webserver,
content_type=item.content_type,
tech=item.tech if item.tech else [],
body_preview=item.body_preview,
response_body=item.response_body,
vhost=item.vhost,
matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else []
matched_gf_patterns=item.matched_gf_patterns if item.matched_gf_patterns else [],
response_headers=item.response_headers if item.response_headers else ''
))
# 批量创建(忽略冲突,基于唯一约束去重)
@@ -100,7 +102,7 @@ class DjangoEndpointSnapshotRepository:
.values(
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'body_preview', 'vhost', 'matched_gf_patterns', 'created_at'
'response_body', 'response_headers', 'vhost', 'matched_gf_patterns', 'created_at'
)
.order_by('url')
)

View File

@@ -46,14 +46,15 @@ class DjangoWebsiteSnapshotRepository:
url=item.url,
host=item.host,
title=item.title,
status=item.status,
status_code=item.status_code,
content_length=item.content_length,
location=item.location,
web_server=item.web_server,
webserver=item.webserver,
content_type=item.content_type,
tech=item.tech if item.tech else [],
body_preview=item.body_preview,
vhost=item.vhost
response_body=item.response_body,
vhost=item.vhost,
response_headers=item.response_headers if item.response_headers else ''
))
# 批量创建(忽略冲突,基于唯一约束去重)
@@ -98,26 +99,12 @@ class DjangoWebsiteSnapshotRepository:
WebsiteSnapshot.objects
.filter(scan_id=scan_id)
.values(
'url', 'host', 'location', 'title', 'status',
'content_length', 'content_type', 'web_server', 'tech',
'body_preview', 'vhost', 'created_at'
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'response_body', 'response_headers', 'vhost', 'created_at'
)
.order_by('url')
)
for row in qs.iterator(chunk_size=batch_size):
# 重命名字段以匹配 CSV 表头
yield {
'url': row['url'],
'host': row['host'],
'location': row['location'],
'title': row['title'],
'status_code': row['status'],
'content_length': row['content_length'],
'content_type': row['content_type'],
'webserver': row['web_server'],
'tech': row['tech'],
'body_preview': row['body_preview'],
'vhost': row['vhost'],
'created_at': row['created_at'],
}
yield row

View File

@@ -7,6 +7,7 @@ from .models.snapshot_models import (
EndpointSnapshot,
VulnerabilitySnapshot,
)
from .models.screenshot_models import Screenshot, ScreenshotSnapshot
# 注意IPAddress 和 Port 模型已被重构为 HostPortMapping
@@ -67,9 +68,10 @@ class SubdomainListSerializer(serializers.ModelSerializer):
class WebSiteSerializer(serializers.ModelSerializer):
"""站点序列化器"""
"""站点序列化器(目标详情页)"""
subdomain = serializers.CharField(source='subdomain.name', allow_blank=True, default='')
responseHeaders = serializers.CharField(source='response_headers', read_only=True) # 原始HTTP响应头
class Meta:
model = WebSite
@@ -83,9 +85,10 @@ class WebSiteSerializer(serializers.ModelSerializer):
'content_type',
'status_code',
'content_length',
'body_preview',
'response_body',
'tech',
'vhost',
'responseHeaders', # HTTP响应头
'subdomain',
'created_at',
]
@@ -140,6 +143,7 @@ class EndpointListSerializer(serializers.ModelSerializer):
source='matched_gf_patterns',
read_only=True,
)
responseHeaders = serializers.CharField(source='response_headers', read_only=True) # 原始HTTP响应头
class Meta:
model = Endpoint
@@ -152,9 +156,10 @@ class EndpointListSerializer(serializers.ModelSerializer):
'content_length',
'content_type',
'webserver',
'body_preview',
'response_body',
'tech',
'vhost',
'responseHeaders', # HTTP响应头
'gfPatterns',
'created_at',
]
@@ -213,8 +218,7 @@ class WebsiteSnapshotSerializer(serializers.ModelSerializer):
"""网站快照序列化器(用于扫描历史)"""
subdomain_name = serializers.CharField(source='subdomain.name', read_only=True)
webserver = serializers.CharField(source='web_server', read_only=True) # 映射字段名
status_code = serializers.IntegerField(source='status', read_only=True) # 映射字段名
responseHeaders = serializers.CharField(source='response_headers', read_only=True) # 原始HTTP响应头
class Meta:
model = WebsiteSnapshot
@@ -223,13 +227,14 @@ class WebsiteSnapshotSerializer(serializers.ModelSerializer):
'url',
'location',
'title',
'webserver', # 使用映射后的字段名
'webserver',
'content_type',
'status_code', # 使用映射后的字段名
'status_code',
'content_length',
'body_preview',
'response_body',
'tech',
'vhost',
'responseHeaders', # HTTP响应头
'subdomain_name',
'created_at',
]
@@ -264,6 +269,7 @@ class EndpointSnapshotSerializer(serializers.ModelSerializer):
source='matched_gf_patterns',
read_only=True,
)
responseHeaders = serializers.CharField(source='response_headers', read_only=True) # 原始HTTP响应头
class Meta:
model = EndpointSnapshot
@@ -277,10 +283,31 @@ class EndpointSnapshotSerializer(serializers.ModelSerializer):
'content_type',
'status_code',
'content_length',
'body_preview',
'response_body',
'tech',
'vhost',
'responseHeaders', # HTTP响应头
'gfPatterns',
'created_at',
]
read_only_fields = fields
# ==================== 截图序列化器 ====================
class ScreenshotListSerializer(serializers.ModelSerializer):
"""截图资产列表序列化器(不包含 image 字段)"""
class Meta:
model = Screenshot
fields = ['id', 'url', 'status_code', 'created_at', 'updated_at']
read_only_fields = fields
class ScreenshotSnapshotListSerializer(serializers.ModelSerializer):
"""截图快照列表序列化器(不包含 image 字段)"""
class Meta:
model = ScreenshotSnapshot
fields = ['id', 'url', 'status_code', 'created_at']
read_only_fields = fields

View File

@@ -27,7 +27,7 @@ class EndpointService:
'url': 'url',
'host': 'host',
'title': 'title',
'status': 'status_code',
'status_code': 'status_code',
'tech': 'tech',
}

View File

@@ -19,7 +19,7 @@ class WebSiteService:
'url': 'url',
'host': 'host',
'title': 'title',
'status': 'status_code',
'status_code': 'status_code',
'tech': 'tech',
}

View File

@@ -0,0 +1,186 @@
"""
Playwright 截图服务
使用 Playwright 异步批量捕获网站截图
"""
import asyncio
import logging
from typing import Optional, AsyncGenerator
logger = logging.getLogger(__name__)
class PlaywrightScreenshotService:
"""Playwright 截图服务 - 异步多 Page 并发截图"""
# 内置默认值(不暴露给用户)
DEFAULT_VIEWPORT_WIDTH = 1920
DEFAULT_VIEWPORT_HEIGHT = 1080
DEFAULT_TIMEOUT = 30000 # 毫秒
DEFAULT_JPEG_QUALITY = 85
def __init__(
self,
viewport_width: int = DEFAULT_VIEWPORT_WIDTH,
viewport_height: int = DEFAULT_VIEWPORT_HEIGHT,
timeout: int = DEFAULT_TIMEOUT,
concurrency: int = 5
):
"""
初始化 Playwright 截图服务
Args:
viewport_width: 视口宽度(像素)
viewport_height: 视口高度(像素)
timeout: 页面加载超时时间(毫秒)
concurrency: 并发截图数
"""
self.viewport_width = viewport_width
self.viewport_height = viewport_height
self.timeout = timeout
self.concurrency = concurrency
async def capture_screenshot(self, url: str, page) -> tuple[Optional[bytes], Optional[int]]:
"""
捕获单个 URL 的截图
Args:
url: 目标 URL
page: Playwright Page 对象
Returns:
(screenshot_bytes, status_code) 元组
- screenshot_bytes: JPEG 格式的截图字节数据,失败返回 None
- status_code: HTTP 响应状态码,失败返回 None
"""
status_code = None
try:
# 尝试加载页面,即使返回错误状态码也继续截图
try:
response = await page.goto(url, timeout=self.timeout, wait_until='networkidle')
if response:
status_code = response.status
except Exception as goto_error:
# 页面加载失败4xx/5xx 或其他错误),但页面可能已部分渲染
# 仍然尝试截图以捕获错误页面
logger.debug("页面加载异常但尝试截图: %s, 错误: %s", url, str(goto_error)[:50])
# 尝试截图(即使 goto 失败)
screenshot_bytes = await page.screenshot(
type='jpeg',
quality=self.DEFAULT_JPEG_QUALITY,
full_page=False
)
return (screenshot_bytes, status_code)
except asyncio.TimeoutError:
logger.warning("截图超时: %s", url)
return (None, None)
except Exception as e:
logger.warning("截图失败: %s, 错误: %s", url, str(e)[:100])
return (None, None)
async def _capture_with_semaphore(
self,
url: str,
context,
semaphore: asyncio.Semaphore
) -> tuple[str, Optional[bytes], Optional[int]]:
"""
使用信号量控制并发的截图任务
Args:
url: 目标 URL
context: Playwright BrowserContext
semaphore: 并发控制信号量
Returns:
(url, screenshot_bytes, status_code) 元组
"""
async with semaphore:
page = await context.new_page()
try:
screenshot_bytes, status_code = await self.capture_screenshot(url, page)
return (url, screenshot_bytes, status_code)
finally:
await page.close()
async def capture_batch(
self,
urls: list[str]
) -> AsyncGenerator[tuple[str, Optional[bytes], Optional[int]], None]:
"""
批量捕获截图(异步生成器)
使用单个 BrowserContext + 多 Page 并发模式
通过 Semaphore 控制并发数
Args:
urls: URL 列表
Yields:
(url, screenshot_bytes, status_code) 元组
"""
if not urls:
return
from playwright.async_api import async_playwright
async with async_playwright() as p:
# 启动浏览器headless 模式)
browser = await p.chromium.launch(
headless=True,
args=[
'--no-sandbox',
'--disable-setuid-sandbox',
'--disable-dev-shm-usage',
'--disable-gpu'
]
)
try:
# 创建单个 context
context = await browser.new_context(
viewport={
'width': self.viewport_width,
'height': self.viewport_height
},
ignore_https_errors=True,
user_agent='Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36'
)
# 使用 Semaphore 控制并发
semaphore = asyncio.Semaphore(self.concurrency)
# 创建所有任务
tasks = [
self._capture_with_semaphore(url, context, semaphore)
for url in urls
]
# 使用 as_completed 实现流式返回
for coro in asyncio.as_completed(tasks):
result = await coro
yield result
await context.close()
finally:
await browser.close()
async def capture_batch_collect(
self,
urls: list[str]
) -> list[tuple[str, Optional[bytes], Optional[int]]]:
"""
批量捕获截图(收集所有结果)
Args:
urls: URL 列表
Returns:
[(url, screenshot_bytes, status_code), ...] 列表
"""
results = []
async for result in self.capture_batch(urls):
results.append(result)
return results

View File

@@ -0,0 +1,185 @@
"""
截图服务
负责截图的压缩、保存和同步
"""
import io
import logging
import os
from typing import Optional
from PIL import Image
logger = logging.getLogger(__name__)
class ScreenshotService:
"""截图服务 - 负责压缩、保存和同步"""
def __init__(self, max_width: int = 800, target_kb: int = 100):
"""
初始化截图服务
Args:
max_width: 最大宽度(像素)
target_kb: 目标文件大小KB
"""
self.max_width = max_width
self.target_kb = target_kb
def compress_screenshot(self, image_path: str) -> Optional[bytes]:
"""
压缩截图为 WebP 格式
Args:
image_path: PNG 截图文件路径
Returns:
压缩后的 WebP 二进制数据,失败返回 None
"""
if not os.path.exists(image_path):
logger.warning(f"截图文件不存在: {image_path}")
return None
try:
with Image.open(image_path) as img:
return self._compress_image(img)
except Exception as e:
logger.error(f"压缩截图失败: {image_path}, 错误: {e}")
return None
def compress_from_bytes(self, image_bytes: bytes) -> Optional[bytes]:
"""
从字节数据压缩截图为 WebP 格式
Args:
image_bytes: JPEG/PNG 图片字节数据
Returns:
压缩后的 WebP 二进制数据,失败返回 None
"""
if not image_bytes:
return None
try:
img = Image.open(io.BytesIO(image_bytes))
return self._compress_image(img)
except Exception as e:
logger.error(f"从字节压缩截图失败: {e}")
return None
def _compress_image(self, img: Image.Image) -> Optional[bytes]:
"""
压缩 PIL Image 对象为 WebP 格式
Args:
img: PIL Image 对象
Returns:
压缩后的 WebP 二进制数据
"""
try:
if img.mode in ('RGBA', 'P'):
img = img.convert('RGB')
width, height = img.size
if width > self.max_width:
ratio = self.max_width / width
new_size = (self.max_width, int(height * ratio))
img = img.resize(new_size, Image.Resampling.LANCZOS)
quality = 80
while quality >= 10:
buffer = io.BytesIO()
img.save(buffer, format='WEBP', quality=quality, method=6)
if len(buffer.getvalue()) <= self.target_kb * 1024:
return buffer.getvalue()
quality -= 10
return buffer.getvalue()
except Exception as e:
logger.error(f"压缩图片失败: {e}")
return None
def save_screenshot_snapshot(
self,
scan_id: int,
url: str,
image_data: bytes,
status_code: int | None = None
) -> bool:
"""
保存截图快照到 ScreenshotSnapshot 表
Args:
scan_id: 扫描 ID
url: 截图对应的 URL
image_data: 压缩后的图片二进制数据
status_code: HTTP 响应状态码
Returns:
是否保存成功
"""
from apps.asset.models import ScreenshotSnapshot
try:
ScreenshotSnapshot.objects.update_or_create(
scan_id=scan_id,
url=url,
defaults={'image': image_data, 'status_code': status_code}
)
return True
except Exception as e:
logger.error(f"保存截图快照失败: scan_id={scan_id}, url={url}, 错误: {e}")
return False
def sync_screenshots_to_asset(self, scan_id: int, target_id: int) -> int:
"""
将扫描的截图快照同步到资产表
Args:
scan_id: 扫描 ID
target_id: 目标 ID
Returns:
同步的截图数量
"""
from apps.asset.models import Screenshot, ScreenshotSnapshot
snapshots = ScreenshotSnapshot.objects.filter(scan_id=scan_id)
count = 0
for snapshot in snapshots:
try:
Screenshot.objects.update_or_create(
target_id=target_id,
url=snapshot.url,
defaults={
'image': snapshot.image,
'status_code': snapshot.status_code
}
)
count += 1
except Exception as e:
logger.error(f"同步截图到资产表失败: url={snapshot.url}, 错误: {e}")
logger.info(f"同步截图完成: scan_id={scan_id}, target_id={target_id}, 数量={count}")
return count
def process_and_save_screenshot(self, scan_id: int, url: str, image_path: str) -> bool:
"""
处理并保存截图(压缩 + 保存快照)
Args:
scan_id: 扫描 ID
url: 截图对应的 URL
image_path: PNG 截图文件路径
Returns:
是否处理成功
"""
image_data = self.compress_screenshot(image_path)
if image_data is None:
return False
return self.save_screenshot_snapshot(scan_id, url, image_data)

View File

@@ -0,0 +1,477 @@
"""
资产搜索服务
提供资产搜索的核心业务逻辑:
- 从物化视图查询数据
- 支持表达式语法解析
- 支持 =(模糊)、==(精确)、!=(不等于)操作符
- 支持 && (AND) 和 || (OR) 逻辑组合
- 支持 Website 和 Endpoint 两种资产类型
"""
import logging
import re
from typing import Optional, List, Dict, Any, Tuple, Literal, Iterator
from django.db import connection
logger = logging.getLogger(__name__)
# 支持的字段映射(前端字段名 -> 数据库字段名)
FIELD_MAPPING = {
'host': 'host',
'url': 'url',
'title': 'title',
'tech': 'tech',
'status': 'status_code',
'body': 'response_body',
'header': 'response_headers',
}
# 数组类型字段
ARRAY_FIELDS = {'tech'}
# 资产类型到视图名的映射
VIEW_MAPPING = {
'website': 'asset_search_view',
'endpoint': 'endpoint_search_view',
}
# 资产类型到原表名的映射(用于 JOIN 获取数组字段)
# ⚠️ 重要pg_ivm 不支持 ArrayField所有数组字段必须从原表 JOIN 获取
TABLE_MAPPING = {
'website': 'website',
'endpoint': 'endpoint',
}
# 有效的资产类型
VALID_ASSET_TYPES = {'website', 'endpoint'}
# Website 查询字段v=视图t=原表)
# ⚠️ 注意t.tech 从原表获取,因为 pg_ivm 不支持 ArrayField
WEBSITE_SELECT_FIELDS = """
v.id,
v.url,
v.host,
v.title,
t.tech, -- ArrayField从 website 表 JOIN 获取
v.status_code,
v.response_headers,
v.response_body,
v.content_type,
v.content_length,
v.webserver,
v.location,
v.vhost,
v.created_at,
v.target_id
"""
# Endpoint 查询字段
# ⚠️ 注意t.tech 和 t.matched_gf_patterns 从原表获取,因为 pg_ivm 不支持 ArrayField
ENDPOINT_SELECT_FIELDS = """
v.id,
v.url,
v.host,
v.title,
t.tech, -- ArrayField从 endpoint 表 JOIN 获取
v.status_code,
v.response_headers,
v.response_body,
v.content_type,
v.content_length,
v.webserver,
v.location,
v.vhost,
t.matched_gf_patterns, -- ArrayField从 endpoint 表 JOIN 获取
v.created_at,
v.target_id
"""
class SearchQueryParser:
"""
搜索查询解析器
支持语法:
- field="value" 模糊匹配ILIKE %value%
- field=="value" 精确匹配
- field!="value" 不等于
- && AND 连接
- || OR 连接
- () 分组(暂不支持嵌套)
示例:
- host="api" && tech="nginx"
- tech="vue" || tech="react"
- status=="200" && host!="test"
"""
# 匹配单个条件: field="value" 或 field=="value" 或 field!="value"
CONDITION_PATTERN = re.compile(r'(\w+)\s*(==|!=|=)\s*"([^"]*)"')
@classmethod
def parse(cls, query: str) -> Tuple[str, List[Any]]:
"""
解析查询字符串,返回 SQL WHERE 子句和参数
Args:
query: 搜索查询字符串
Returns:
(where_clause, params) 元组
"""
if not query or not query.strip():
return "1=1", []
query = query.strip()
# 检查是否包含操作符语法,如果不包含则作为 host 模糊搜索
if not cls.CONDITION_PATTERN.search(query):
# 裸文本,默认作为 host 模糊搜索v 是视图别名)
return "v.host ILIKE %s", [f"%{query}%"]
# 按 || 分割为 OR 组
or_groups = cls._split_by_or(query)
if len(or_groups) == 1:
# 没有 OR直接解析 AND 条件
return cls._parse_and_group(or_groups[0])
# 多个 OR 组
or_clauses = []
all_params = []
for group in or_groups:
clause, params = cls._parse_and_group(group)
if clause and clause != "1=1":
or_clauses.append(f"({clause})")
all_params.extend(params)
if not or_clauses:
return "1=1", []
return " OR ".join(or_clauses), all_params
@classmethod
def _split_by_or(cls, query: str) -> List[str]:
"""按 || 分割查询,但忽略引号内的 ||"""
parts = []
current = ""
in_quotes = False
i = 0
while i < len(query):
char = query[i]
if char == '"':
in_quotes = not in_quotes
current += char
elif not in_quotes and i + 1 < len(query) and query[i:i+2] == '||':
if current.strip():
parts.append(current.strip())
current = ""
i += 1 # 跳过第二个 |
else:
current += char
i += 1
if current.strip():
parts.append(current.strip())
return parts if parts else [query]
@classmethod
def _parse_and_group(cls, group: str) -> Tuple[str, List[Any]]:
"""解析 AND 组(用 && 连接的条件)"""
# 移除外层括号
group = group.strip()
if group.startswith('(') and group.endswith(')'):
group = group[1:-1].strip()
# 按 && 分割
parts = cls._split_by_and(group)
and_clauses = []
all_params = []
for part in parts:
clause, params = cls._parse_condition(part.strip())
if clause:
and_clauses.append(clause)
all_params.extend(params)
if not and_clauses:
return "1=1", []
return " AND ".join(and_clauses), all_params
@classmethod
def _split_by_and(cls, query: str) -> List[str]:
"""按 && 分割查询,但忽略引号内的 &&"""
parts = []
current = ""
in_quotes = False
i = 0
while i < len(query):
char = query[i]
if char == '"':
in_quotes = not in_quotes
current += char
elif not in_quotes and i + 1 < len(query) and query[i:i+2] == '&&':
if current.strip():
parts.append(current.strip())
current = ""
i += 1 # 跳过第二个 &
else:
current += char
i += 1
if current.strip():
parts.append(current.strip())
return parts if parts else [query]
@classmethod
def _parse_condition(cls, condition: str) -> Tuple[Optional[str], List[Any]]:
"""
解析单个条件
Returns:
(sql_clause, params) 或 (None, []) 如果解析失败
"""
# 移除括号
condition = condition.strip()
if condition.startswith('(') and condition.endswith(')'):
condition = condition[1:-1].strip()
match = cls.CONDITION_PATTERN.match(condition)
if not match:
logger.warning(f"无法解析条件: {condition}")
return None, []
field, operator, value = match.groups()
field = field.lower()
# 验证字段
if field not in FIELD_MAPPING:
logger.warning(f"未知字段: {field}")
return None, []
db_field = FIELD_MAPPING[field]
is_array = field in ARRAY_FIELDS
# 根据操作符生成 SQL
if operator == '=':
# 模糊匹配
return cls._build_like_condition(db_field, value, is_array)
elif operator == '==':
# 精确匹配
return cls._build_exact_condition(db_field, value, is_array)
elif operator == '!=':
# 不等于
return cls._build_not_equal_condition(db_field, value, is_array)
return None, []
@classmethod
def _build_like_condition(cls, field: str, value: str, is_array: bool) -> Tuple[str, List[Any]]:
"""构建模糊匹配条件"""
if is_array:
# 数组字段:检查数组中是否有元素包含该值(从原表 t 获取)
return f"EXISTS (SELECT 1 FROM unnest(t.{field}) AS elem WHERE elem ILIKE %s)", [f"%{value}%"]
elif field == 'status_code':
# 状态码是整数,模糊匹配转为精确匹配
try:
return f"v.{field} = %s", [int(value)]
except ValueError:
return f"v.{field}::text ILIKE %s", [f"%{value}%"]
else:
return f"v.{field} ILIKE %s", [f"%{value}%"]
@classmethod
def _build_exact_condition(cls, field: str, value: str, is_array: bool) -> Tuple[str, List[Any]]:
"""构建精确匹配条件"""
if is_array:
# 数组字段:检查数组中是否包含该精确值(从原表 t 获取)
return f"%s = ANY(t.{field})", [value]
elif field == 'status_code':
# 状态码是整数
try:
return f"v.{field} = %s", [int(value)]
except ValueError:
return f"v.{field}::text = %s", [value]
else:
return f"v.{field} = %s", [value]
@classmethod
def _build_not_equal_condition(cls, field: str, value: str, is_array: bool) -> Tuple[str, List[Any]]:
"""构建不等于条件"""
if is_array:
# 数组字段:检查数组中不包含该值(从原表 t 获取)
return f"NOT (%s = ANY(t.{field}))", [value]
elif field == 'status_code':
try:
return f"(v.{field} IS NULL OR v.{field} != %s)", [int(value)]
except ValueError:
return f"(v.{field} IS NULL OR v.{field}::text != %s)", [value]
else:
return f"(v.{field} IS NULL OR v.{field} != %s)", [value]
AssetType = Literal['website', 'endpoint']
class AssetSearchService:
"""资产搜索服务"""
def search(
self,
query: str,
asset_type: AssetType = 'website',
limit: Optional[int] = None
) -> List[Dict[str, Any]]:
"""
搜索资产
Args:
query: 搜索查询字符串
asset_type: 资产类型 ('website''endpoint')
limit: 最大返回数量(可选)
Returns:
List[Dict]: 搜索结果列表
"""
where_clause, params = SearchQueryParser.parse(query)
# 根据资产类型选择视图、原表和字段
view_name = VIEW_MAPPING.get(asset_type, 'asset_search_view')
table_name = TABLE_MAPPING.get(asset_type, 'website')
select_fields = ENDPOINT_SELECT_FIELDS if asset_type == 'endpoint' else WEBSITE_SELECT_FIELDS
# JOIN 原表获取数组字段tech, matched_gf_patterns
sql = f"""
SELECT {select_fields}
FROM {view_name} v
JOIN {table_name} t ON v.id = t.id
WHERE {where_clause}
ORDER BY v.created_at DESC
"""
# 添加 LIMIT
if limit is not None and limit > 0:
sql += f" LIMIT {int(limit)}"
try:
with connection.cursor() as cursor:
cursor.execute(sql, params)
columns = [col[0] for col in cursor.description]
results = []
for row in cursor.fetchall():
result = dict(zip(columns, row))
results.append(result)
return results
except Exception as e:
logger.error(f"搜索查询失败: {e}, SQL: {sql}, params: {params}")
raise
def count(self, query: str, asset_type: AssetType = 'website', statement_timeout_ms: int = 300000) -> int:
"""
统计搜索结果数量
Args:
query: 搜索查询字符串
asset_type: 资产类型 ('website''endpoint')
statement_timeout_ms: SQL 语句超时时间(毫秒),默认 5 分钟
Returns:
int: 结果总数
"""
where_clause, params = SearchQueryParser.parse(query)
# 根据资产类型选择视图和原表
view_name = VIEW_MAPPING.get(asset_type, 'asset_search_view')
table_name = TABLE_MAPPING.get(asset_type, 'website')
# JOIN 原表以支持数组字段查询
sql = f"SELECT COUNT(*) FROM {view_name} v JOIN {table_name} t ON v.id = t.id WHERE {where_clause}"
try:
with connection.cursor() as cursor:
# 为导出设置更长的超时时间(仅影响当前会话)
cursor.execute(f"SET LOCAL statement_timeout = {statement_timeout_ms}")
cursor.execute(sql, params)
return cursor.fetchone()[0]
except Exception as e:
logger.error(f"统计查询失败: {e}")
raise
def search_iter(
self,
query: str,
asset_type: AssetType = 'website',
batch_size: int = 1000,
statement_timeout_ms: int = 300000
) -> Iterator[Dict[str, Any]]:
"""
流式搜索资产(使用分批查询,内存友好)
Args:
query: 搜索查询字符串
asset_type: 资产类型 ('website''endpoint')
batch_size: 每批获取的数量
statement_timeout_ms: SQL 语句超时时间(毫秒),默认 5 分钟
Yields:
Dict: 单条搜索结果
"""
where_clause, params = SearchQueryParser.parse(query)
# 根据资产类型选择视图、原表和字段
view_name = VIEW_MAPPING.get(asset_type, 'asset_search_view')
table_name = TABLE_MAPPING.get(asset_type, 'website')
select_fields = ENDPOINT_SELECT_FIELDS if asset_type == 'endpoint' else WEBSITE_SELECT_FIELDS
# 使用 OFFSET/LIMIT 分批查询Django 不支持命名游标)
offset = 0
try:
while True:
# JOIN 原表获取数组字段
sql = f"""
SELECT {select_fields}
FROM {view_name} v
JOIN {table_name} t ON v.id = t.id
WHERE {where_clause}
ORDER BY v.created_at DESC
LIMIT {batch_size} OFFSET {offset}
"""
with connection.cursor() as cursor:
# 为导出设置更长的超时时间(仅影响当前会话)
cursor.execute(f"SET LOCAL statement_timeout = {statement_timeout_ms}")
cursor.execute(sql, params)
columns = [col[0] for col in cursor.description]
rows = cursor.fetchall()
if not rows:
break
for row in rows:
yield dict(zip(columns, row))
# 如果返回的行数少于 batch_size说明已经是最后一批
if len(rows) < batch_size:
break
offset += batch_size
except Exception as e:
logger.error(f"流式搜索查询失败: {e}, SQL: {sql}, params: {params}")
raise

View File

@@ -72,7 +72,7 @@ class EndpointSnapshotsService:
'url': 'url',
'host': 'host',
'title': 'title',
'status': 'status_code',
'status_code': 'status_code',
'webserver': 'webserver',
'tech': 'tech',
}

View File

@@ -73,8 +73,8 @@ class WebsiteSnapshotsService:
'url': 'url',
'host': 'host',
'title': 'title',
'status': 'status',
'webserver': 'web_server',
'status_code': 'status_code',
'webserver': 'webserver',
'tech': 'tech',
}

View File

@@ -10,19 +10,28 @@ from .views import (
DirectoryViewSet,
VulnerabilityViewSet,
AssetStatisticsViewSet,
AssetSearchView,
AssetSearchExportView,
EndpointViewSet,
HostPortMappingViewSet,
ScreenshotViewSet,
)
# 创建 DRF 路由器
router = DefaultRouter()
# 注册 ViewSet
# 注意IPAddress 模型已被重构为 HostPortMapping相关路由已移除
router.register(r'subdomains', SubdomainViewSet, basename='subdomain')
router.register(r'websites', WebSiteViewSet, basename='website')
router.register(r'directories', DirectoryViewSet, basename='directory')
router.register(r'endpoints', EndpointViewSet, basename='endpoint')
router.register(r'ip-addresses', HostPortMappingViewSet, basename='ip-address')
router.register(r'vulnerabilities', VulnerabilityViewSet, basename='vulnerability')
router.register(r'screenshots', ScreenshotViewSet, basename='screenshot')
router.register(r'statistics', AssetStatisticsViewSet, basename='asset-statistics')
urlpatterns = [
path('assets/', include(router.urls)),
path('assets/search/', AssetSearchView.as_view(), name='asset-search'),
path('assets/search/export/', AssetSearchExportView.as_view(), name='asset-search-export'),
]

View File

@@ -0,0 +1,44 @@
"""
Asset 应用视图模块
重新导出所有视图类以保持向后兼容
"""
from .asset_views import (
AssetStatisticsViewSet,
SubdomainViewSet,
WebSiteViewSet,
DirectoryViewSet,
EndpointViewSet,
HostPortMappingViewSet,
VulnerabilityViewSet,
SubdomainSnapshotViewSet,
WebsiteSnapshotViewSet,
DirectorySnapshotViewSet,
EndpointSnapshotViewSet,
HostPortMappingSnapshotViewSet,
VulnerabilitySnapshotViewSet,
ScreenshotViewSet,
ScreenshotSnapshotViewSet,
)
from .search_views import AssetSearchView, AssetSearchExportView
__all__ = [
'AssetStatisticsViewSet',
'SubdomainViewSet',
'WebSiteViewSet',
'DirectoryViewSet',
'EndpointViewSet',
'HostPortMappingViewSet',
'VulnerabilityViewSet',
'SubdomainSnapshotViewSet',
'WebsiteSnapshotViewSet',
'DirectorySnapshotViewSet',
'EndpointSnapshotViewSet',
'HostPortMappingSnapshotViewSet',
'VulnerabilitySnapshotViewSet',
'ScreenshotViewSet',
'ScreenshotSnapshotViewSet',
'AssetSearchView',
'AssetSearchExportView',
]

View File

@@ -8,19 +8,18 @@ from rest_framework.request import Request
from rest_framework.exceptions import NotFound, ValidationError as DRFValidationError
from django.core.exceptions import ValidationError, ObjectDoesNotExist
from django.db import DatabaseError, IntegrityError, OperationalError
from django.http import StreamingHttpResponse
from .serializers import (
from ..serializers import (
SubdomainListSerializer, WebSiteSerializer, DirectorySerializer,
VulnerabilitySerializer, EndpointListSerializer, IPAddressAggregatedSerializer,
SubdomainSnapshotSerializer, WebsiteSnapshotSerializer, DirectorySnapshotSerializer,
EndpointSnapshotSerializer, VulnerabilitySnapshotSerializer
)
from .services import (
from ..services import (
SubdomainService, WebSiteService, DirectoryService,
VulnerabilityService, AssetStatisticsService, EndpointService, HostPortMappingService
)
from .services.snapshot import (
from ..services.snapshot import (
SubdomainSnapshotsService, WebsiteSnapshotsService, DirectorySnapshotsService,
EndpointSnapshotsService, HostPortMappingSnapshotsService, VulnerabilitySnapshotsService
)
@@ -243,7 +242,7 @@ class SubdomainViewSet(viewsets.ModelViewSet):
CSV name, created_at
"""
from apps.common.utils import generate_csv_rows, format_datetime
from apps.common.utils import create_csv_export_response, format_datetime
target_pk = self.kwargs.get('target_pk')
if not target_pk:
@@ -254,12 +253,41 @@ class SubdomainViewSet(viewsets.ModelViewSet):
headers = ['name', 'created_at']
formatters = {'created_at': format_datetime}
response = StreamingHttpResponse(
generate_csv_rows(data_iterator, headers, formatters),
content_type='text/csv; charset=utf-8'
return create_csv_export_response(
data_iterator=data_iterator,
headers=headers,
filename=f"target-{target_pk}-subdomains.csv",
field_formatters=formatters
)
response['Content-Disposition'] = f'attachment; filename="target-{target_pk}-subdomains.csv"'
return response
@action(detail=False, methods=['post'], url_path='bulk-delete')
def bulk_delete(self, request, **kwargs):
"""批量删除子域名
POST /api/assets/subdomains/bulk-delete/
请求体: {"ids": [1, 2, 3]}
响应: {"deletedCount": 3}
"""
ids = request.data.get('ids', [])
if not ids or not isinstance(ids, list):
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='ids is required and must be a list',
status_code=status.HTTP_400_BAD_REQUEST
)
try:
from ..models import Subdomain
deleted_count, _ = Subdomain.objects.filter(id__in=ids).delete()
return success_response(data={'deletedCount': deleted_count})
except Exception as e:
logger.exception("批量删除子域名失败")
return error_response(
code=ErrorCodes.SERVER_ERROR,
message='Failed to delete subdomains',
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
)
class WebSiteViewSet(viewsets.ModelViewSet):
@@ -367,9 +395,9 @@ class WebSiteViewSet(viewsets.ModelViewSet):
def export(self, request, **kwargs):
"""导出网站为 CSV 格式
CSV url, host, location, title, status_code, content_length, content_type, webserver, tech, body_preview, vhost, created_at
CSV url, host, location, title, status_code, content_length, content_type, webserver, tech, response_body, response_headers, vhost, created_at
"""
from apps.common.utils import generate_csv_rows, format_datetime, format_list_field
from apps.common.utils import create_csv_export_response, format_datetime, format_list_field
target_pk = self.kwargs.get('target_pk')
if not target_pk:
@@ -380,19 +408,48 @@ class WebSiteViewSet(viewsets.ModelViewSet):
headers = [
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'body_preview', 'vhost', 'created_at'
'response_body', 'response_headers', 'vhost', 'created_at'
]
formatters = {
'created_at': format_datetime,
'tech': lambda x: format_list_field(x, separator=','),
}
response = StreamingHttpResponse(
generate_csv_rows(data_iterator, headers, formatters),
content_type='text/csv; charset=utf-8'
return create_csv_export_response(
data_iterator=data_iterator,
headers=headers,
filename=f"target-{target_pk}-websites.csv",
field_formatters=formatters
)
response['Content-Disposition'] = f'attachment; filename="target-{target_pk}-websites.csv"'
return response
@action(detail=False, methods=['post'], url_path='bulk-delete')
def bulk_delete(self, request, **kwargs):
"""批量删除网站
POST /api/assets/websites/bulk-delete/
请求体: {"ids": [1, 2, 3]}
响应: {"deletedCount": 3}
"""
ids = request.data.get('ids', [])
if not ids or not isinstance(ids, list):
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='ids is required and must be a list',
status_code=status.HTTP_400_BAD_REQUEST
)
try:
from ..models import WebSite
deleted_count, _ = WebSite.objects.filter(id__in=ids).delete()
return success_response(data={'deletedCount': deleted_count})
except Exception as e:
logger.exception("批量删除网站失败")
return error_response(
code=ErrorCodes.SERVER_ERROR,
message='Failed to delete websites',
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
)
class DirectoryViewSet(viewsets.ModelViewSet):
@@ -499,7 +556,7 @@ class DirectoryViewSet(viewsets.ModelViewSet):
CSV url, status, content_length, words, lines, content_type, duration, created_at
"""
from apps.common.utils import generate_csv_rows, format_datetime
from apps.common.utils import create_csv_export_response, format_datetime
target_pk = self.kwargs.get('target_pk')
if not target_pk:
@@ -515,12 +572,41 @@ class DirectoryViewSet(viewsets.ModelViewSet):
'created_at': format_datetime,
}
response = StreamingHttpResponse(
generate_csv_rows(data_iterator, headers, formatters),
content_type='text/csv; charset=utf-8'
return create_csv_export_response(
data_iterator=data_iterator,
headers=headers,
filename=f"target-{target_pk}-directories.csv",
field_formatters=formatters
)
response['Content-Disposition'] = f'attachment; filename="target-{target_pk}-directories.csv"'
return response
@action(detail=False, methods=['post'], url_path='bulk-delete')
def bulk_delete(self, request, **kwargs):
"""批量删除目录
POST /api/assets/directories/bulk-delete/
请求体: {"ids": [1, 2, 3]}
响应: {"deletedCount": 3}
"""
ids = request.data.get('ids', [])
if not ids or not isinstance(ids, list):
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='ids is required and must be a list',
status_code=status.HTTP_400_BAD_REQUEST
)
try:
from ..models import Directory
deleted_count, _ = Directory.objects.filter(id__in=ids).delete()
return success_response(data={'deletedCount': deleted_count})
except Exception as e:
logger.exception("批量删除目录失败")
return error_response(
code=ErrorCodes.SERVER_ERROR,
message='Failed to delete directories',
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
)
class EndpointViewSet(viewsets.ModelViewSet):
@@ -628,9 +714,9 @@ class EndpointViewSet(viewsets.ModelViewSet):
def export(self, request, **kwargs):
"""导出端点为 CSV 格式
CSV url, host, location, title, status_code, content_length, content_type, webserver, tech, body_preview, vhost, matched_gf_patterns, created_at
CSV url, host, location, title, status_code, content_length, content_type, webserver, tech, response_body, response_headers, vhost, matched_gf_patterns, created_at
"""
from apps.common.utils import generate_csv_rows, format_datetime, format_list_field
from apps.common.utils import create_csv_export_response, format_datetime, format_list_field
target_pk = self.kwargs.get('target_pk')
if not target_pk:
@@ -641,7 +727,7 @@ class EndpointViewSet(viewsets.ModelViewSet):
headers = [
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'body_preview', 'vhost', 'matched_gf_patterns', 'created_at'
'response_body', 'response_headers', 'vhost', 'matched_gf_patterns', 'created_at'
]
formatters = {
'created_at': format_datetime,
@@ -649,12 +735,41 @@ class EndpointViewSet(viewsets.ModelViewSet):
'matched_gf_patterns': lambda x: format_list_field(x, separator=','),
}
response = StreamingHttpResponse(
generate_csv_rows(data_iterator, headers, formatters),
content_type='text/csv; charset=utf-8'
return create_csv_export_response(
data_iterator=data_iterator,
headers=headers,
filename=f"target-{target_pk}-endpoints.csv",
field_formatters=formatters
)
response['Content-Disposition'] = f'attachment; filename="target-{target_pk}-endpoints.csv"'
return response
@action(detail=False, methods=['post'], url_path='bulk-delete')
def bulk_delete(self, request, **kwargs):
"""批量删除端点
POST /api/assets/endpoints/bulk-delete/
请求体: {"ids": [1, 2, 3]}
响应: {"deletedCount": 3}
"""
ids = request.data.get('ids', [])
if not ids or not isinstance(ids, list):
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='ids is required and must be a list',
status_code=status.HTTP_400_BAD_REQUEST
)
try:
from ..models import Endpoint
deleted_count, _ = Endpoint.objects.filter(id__in=ids).delete()
return success_response(data={'deletedCount': deleted_count})
except Exception as e:
logger.exception("批量删除端点失败")
return error_response(
code=ErrorCodes.SERVER_ERROR,
message='Failed to delete endpoints',
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
)
class HostPortMappingViewSet(viewsets.ModelViewSet):
@@ -707,7 +822,7 @@ class HostPortMappingViewSet(viewsets.ModelViewSet):
CSV ip, host, port, created_at
"""
from apps.common.utils import generate_csv_rows, format_datetime
from apps.common.utils import create_csv_export_response, format_datetime
target_pk = self.kwargs.get('target_pk')
if not target_pk:
@@ -722,14 +837,44 @@ class HostPortMappingViewSet(viewsets.ModelViewSet):
'created_at': format_datetime
}
# 生成流式响应
response = StreamingHttpResponse(
generate_csv_rows(data_iterator, headers, formatters),
content_type='text/csv; charset=utf-8'
return create_csv_export_response(
data_iterator=data_iterator,
headers=headers,
filename=f"target-{target_pk}-ip-addresses.csv",
field_formatters=formatters
)
response['Content-Disposition'] = f'attachment; filename="target-{target_pk}-ip-addresses.csv"'
@action(detail=False, methods=['post'], url_path='bulk-delete')
def bulk_delete(self, request, **kwargs):
"""批量删除 IP 地址映射
return response
POST /api/assets/ip-addresses/bulk-delete/
请求体: {"ips": ["192.168.1.1", "10.0.0.1"]}
响应: {"deletedCount": 3}
注意由于 IP 地址是聚合显示的删除时传入 IP 列表
会删除该 IP 下的所有 host:port 映射记录
"""
ips = request.data.get('ips', [])
if not ips or not isinstance(ips, list):
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='ips is required and must be a list',
status_code=status.HTTP_400_BAD_REQUEST
)
try:
from ..models import HostPortMapping
deleted_count, _ = HostPortMapping.objects.filter(ip__in=ips).delete()
return success_response(data={'deletedCount': deleted_count})
except Exception as e:
logger.exception("批量删除 IP 地址映射失败")
return error_response(
code=ErrorCodes.SERVER_ERROR,
message='Failed to delete ip addresses',
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
)
class VulnerabilityViewSet(viewsets.ModelViewSet):
@@ -801,7 +946,7 @@ class SubdomainSnapshotViewSet(viewsets.ModelViewSet):
CSV name, created_at
"""
from apps.common.utils import generate_csv_rows, format_datetime
from apps.common.utils import create_csv_export_response, format_datetime
scan_pk = self.kwargs.get('scan_pk')
if not scan_pk:
@@ -812,12 +957,12 @@ class SubdomainSnapshotViewSet(viewsets.ModelViewSet):
headers = ['name', 'created_at']
formatters = {'created_at': format_datetime}
response = StreamingHttpResponse(
generate_csv_rows(data_iterator, headers, formatters),
content_type='text/csv; charset=utf-8'
return create_csv_export_response(
data_iterator=data_iterator,
headers=headers,
filename=f"scan-{scan_pk}-subdomains.csv",
field_formatters=formatters
)
response['Content-Disposition'] = f'attachment; filename="scan-{scan_pk}-subdomains.csv"'
return response
class WebsiteSnapshotViewSet(viewsets.ModelViewSet):
@@ -853,9 +998,9 @@ class WebsiteSnapshotViewSet(viewsets.ModelViewSet):
def export(self, request, **kwargs):
"""导出网站快照为 CSV 格式
CSV url, host, location, title, status_code, content_length, content_type, webserver, tech, body_preview, vhost, created_at
CSV url, host, location, title, status_code, content_length, content_type, webserver, tech, response_body, response_headers, vhost, created_at
"""
from apps.common.utils import generate_csv_rows, format_datetime, format_list_field
from apps.common.utils import create_csv_export_response, format_datetime, format_list_field
scan_pk = self.kwargs.get('scan_pk')
if not scan_pk:
@@ -866,19 +1011,19 @@ class WebsiteSnapshotViewSet(viewsets.ModelViewSet):
headers = [
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'body_preview', 'vhost', 'created_at'
'response_body', 'response_headers', 'vhost', 'created_at'
]
formatters = {
'created_at': format_datetime,
'tech': lambda x: format_list_field(x, separator=','),
}
response = StreamingHttpResponse(
generate_csv_rows(data_iterator, headers, formatters),
content_type='text/csv; charset=utf-8'
return create_csv_export_response(
data_iterator=data_iterator,
headers=headers,
filename=f"scan-{scan_pk}-websites.csv",
field_formatters=formatters
)
response['Content-Disposition'] = f'attachment; filename="scan-{scan_pk}-websites.csv"'
return response
class DirectorySnapshotViewSet(viewsets.ModelViewSet):
@@ -913,7 +1058,7 @@ class DirectorySnapshotViewSet(viewsets.ModelViewSet):
CSV url, status, content_length, words, lines, content_type, duration, created_at
"""
from apps.common.utils import generate_csv_rows, format_datetime
from apps.common.utils import create_csv_export_response, format_datetime
scan_pk = self.kwargs.get('scan_pk')
if not scan_pk:
@@ -929,12 +1074,12 @@ class DirectorySnapshotViewSet(viewsets.ModelViewSet):
'created_at': format_datetime,
}
response = StreamingHttpResponse(
generate_csv_rows(data_iterator, headers, formatters),
content_type='text/csv; charset=utf-8'
return create_csv_export_response(
data_iterator=data_iterator,
headers=headers,
filename=f"scan-{scan_pk}-directories.csv",
field_formatters=formatters
)
response['Content-Disposition'] = f'attachment; filename="scan-{scan_pk}-directories.csv"'
return response
class EndpointSnapshotViewSet(viewsets.ModelViewSet):
@@ -970,9 +1115,9 @@ class EndpointSnapshotViewSet(viewsets.ModelViewSet):
def export(self, request, **kwargs):
"""导出端点快照为 CSV 格式
CSV url, host, location, title, status_code, content_length, content_type, webserver, tech, body_preview, vhost, matched_gf_patterns, created_at
CSV url, host, location, title, status_code, content_length, content_type, webserver, tech, response_body, response_headers, vhost, matched_gf_patterns, created_at
"""
from apps.common.utils import generate_csv_rows, format_datetime, format_list_field
from apps.common.utils import create_csv_export_response, format_datetime, format_list_field
scan_pk = self.kwargs.get('scan_pk')
if not scan_pk:
@@ -983,7 +1128,7 @@ class EndpointSnapshotViewSet(viewsets.ModelViewSet):
headers = [
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'body_preview', 'vhost', 'matched_gf_patterns', 'created_at'
'response_body', 'response_headers', 'vhost', 'matched_gf_patterns', 'created_at'
]
formatters = {
'created_at': format_datetime,
@@ -991,12 +1136,12 @@ class EndpointSnapshotViewSet(viewsets.ModelViewSet):
'matched_gf_patterns': lambda x: format_list_field(x, separator=','),
}
response = StreamingHttpResponse(
generate_csv_rows(data_iterator, headers, formatters),
content_type='text/csv; charset=utf-8'
return create_csv_export_response(
data_iterator=data_iterator,
headers=headers,
filename=f"scan-{scan_pk}-endpoints.csv",
field_formatters=formatters
)
response['Content-Disposition'] = f'attachment; filename="scan-{scan_pk}-endpoints.csv"'
return response
class HostPortMappingSnapshotViewSet(viewsets.ModelViewSet):
@@ -1031,7 +1176,7 @@ class HostPortMappingSnapshotViewSet(viewsets.ModelViewSet):
CSV ip, host, port, created_at
"""
from apps.common.utils import generate_csv_rows, format_datetime
from apps.common.utils import create_csv_export_response, format_datetime
scan_pk = self.kwargs.get('scan_pk')
if not scan_pk:
@@ -1046,14 +1191,12 @@ class HostPortMappingSnapshotViewSet(viewsets.ModelViewSet):
'created_at': format_datetime
}
# 生成流式响应
response = StreamingHttpResponse(
generate_csv_rows(data_iterator, headers, formatters),
content_type='text/csv; charset=utf-8'
return create_csv_export_response(
data_iterator=data_iterator,
headers=headers,
filename=f"scan-{scan_pk}-ip-addresses.csv",
field_formatters=formatters
)
response['Content-Disposition'] = f'attachment; filename="scan-{scan_pk}-ip-addresses.csv"'
return response
class VulnerabilitySnapshotViewSet(viewsets.ModelViewSet):
@@ -1082,3 +1225,162 @@ class VulnerabilitySnapshotViewSet(viewsets.ModelViewSet):
if scan_pk:
return self.service.get_by_scan(scan_pk, filter_query=filter_query)
return self.service.get_all(filter_query=filter_query)
# ==================== 截图 ViewSet ====================
class ScreenshotViewSet(viewsets.ModelViewSet):
"""截图资产 ViewSet
支持两种访问方式
1. 嵌套路由GET /api/targets/{target_pk}/screenshots/
2. 独立路由GET /api/screenshots/全局查询
支持智能过滤语法filter 参数
- url="example" URL 模糊匹配
"""
from ..serializers import ScreenshotListSerializer
serializer_class = ScreenshotListSerializer
pagination_class = BasePagination
filter_backends = [filters.OrderingFilter]
ordering = ['-created_at']
def get_queryset(self):
"""根据是否有 target_pk 参数决定查询范围"""
from ..models import Screenshot
target_pk = self.kwargs.get('target_pk')
filter_query = self.request.query_params.get('filter', None)
queryset = Screenshot.objects.all()
if target_pk:
queryset = queryset.filter(target_id=target_pk)
if filter_query:
# 简单的 URL 模糊匹配
queryset = queryset.filter(url__icontains=filter_query)
return queryset.order_by('-created_at')
@action(detail=True, methods=['get'], url_path='image')
def image(self, request, pk=None, **kwargs):
"""获取截图图片
GET /api/assets/screenshots/{id}/image/
返回 WebP 格式的图片二进制数据
"""
from django.http import HttpResponse
from ..models import Screenshot
try:
screenshot = Screenshot.objects.get(pk=pk)
if not screenshot.image:
return error_response(
code=ErrorCodes.NOT_FOUND,
message='Screenshot image not found',
status_code=status.HTTP_404_NOT_FOUND
)
response = HttpResponse(screenshot.image, content_type='image/webp')
response['Content-Disposition'] = f'inline; filename="screenshot_{pk}.webp"'
return response
except Screenshot.DoesNotExist:
return error_response(
code=ErrorCodes.NOT_FOUND,
message='Screenshot not found',
status_code=status.HTTP_404_NOT_FOUND
)
@action(detail=False, methods=['post'], url_path='bulk-delete')
def bulk_delete(self, request, **kwargs):
"""批量删除截图
POST /api/assets/screenshots/bulk-delete/
请求体: {"ids": [1, 2, 3]}
响应: {"deletedCount": 3}
"""
ids = request.data.get('ids', [])
if not ids or not isinstance(ids, list):
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='ids is required and must be a list',
status_code=status.HTTP_400_BAD_REQUEST
)
try:
from ..models import Screenshot
deleted_count, _ = Screenshot.objects.filter(id__in=ids).delete()
return success_response(data={'deletedCount': deleted_count})
except Exception as e:
logger.exception("批量删除截图失败")
return error_response(
code=ErrorCodes.SERVER_ERROR,
message='Failed to delete screenshots',
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
)
class ScreenshotSnapshotViewSet(viewsets.ModelViewSet):
"""截图快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/screenshots/
支持智能过滤语法filter 参数
- url="example" URL 模糊匹配
"""
from ..serializers import ScreenshotSnapshotListSerializer
serializer_class = ScreenshotSnapshotListSerializer
pagination_class = BasePagination
filter_backends = [filters.OrderingFilter]
ordering = ['-created_at']
def get_queryset(self):
"""根据 scan_pk 参数查询"""
from ..models import ScreenshotSnapshot
scan_pk = self.kwargs.get('scan_pk')
filter_query = self.request.query_params.get('filter', None)
queryset = ScreenshotSnapshot.objects.all()
if scan_pk:
queryset = queryset.filter(scan_id=scan_pk)
if filter_query:
# 简单的 URL 模糊匹配
queryset = queryset.filter(url__icontains=filter_query)
return queryset.order_by('-created_at')
@action(detail=True, methods=['get'], url_path='image')
def image(self, request, pk=None, **kwargs):
"""获取截图快照图片
GET /api/scans/{scan_pk}/screenshots/{id}/image/
返回 WebP 格式的图片二进制数据
"""
from django.http import HttpResponse
from ..models import ScreenshotSnapshot
try:
screenshot = ScreenshotSnapshot.objects.get(pk=pk)
if not screenshot.image:
return error_response(
code=ErrorCodes.NOT_FOUND,
message='Screenshot image not found',
status_code=status.HTTP_404_NOT_FOUND
)
response = HttpResponse(screenshot.image, content_type='image/webp')
response['Content-Disposition'] = f'inline; filename="screenshot_snapshot_{pk}.webp"'
return response
except ScreenshotSnapshot.DoesNotExist:
return error_response(
code=ErrorCodes.NOT_FOUND,
message='Screenshot snapshot not found',
status_code=status.HTTP_404_NOT_FOUND
)

View File

@@ -0,0 +1,361 @@
"""
资产搜索 API 视图
提供资产搜索的 REST API 接口:
- GET /api/assets/search/ - 搜索资产
- GET /api/assets/search/export/ - 导出搜索结果为 CSV
搜索语法:
- field="value" 模糊匹配ILIKE %value%
- field=="value" 精确匹配
- field!="value" 不等于
- && AND 连接
- || OR 连接
支持的字段:
- host: 主机名
- url: URL
- title: 标题
- tech: 技术栈
- status: 状态码
- body: 响应体
- header: 响应头
支持的资产类型:
- website: 站点(默认)
- endpoint: 端点
"""
import logging
import json
from datetime import datetime
from urllib.parse import urlparse, urlunparse
from rest_framework import status
from rest_framework.views import APIView
from rest_framework.request import Request
from django.db import connection
from apps.common.response_helpers import success_response, error_response
from apps.common.error_codes import ErrorCodes
from apps.asset.services.search_service import AssetSearchService, VALID_ASSET_TYPES
logger = logging.getLogger(__name__)
class AssetSearchView(APIView):
"""
资产搜索 API
GET /api/assets/search/
Query Parameters:
q: 搜索查询表达式
asset_type: 资产类型 ('website''endpoint',默认 'website')
page: 页码(从 1 开始,默认 1
pageSize: 每页数量(默认 10最大 100
示例查询:
?q=host="api" && tech="nginx"
?q=tech="vue" || tech="react"&asset_type=endpoint
?q=status=="200" && host!="test"
Response:
{
"results": [...],
"total": 100,
"page": 1,
"pageSize": 10,
"totalPages": 10,
"assetType": "website"
}
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.service = AssetSearchService()
def _parse_headers(self, headers_data) -> dict:
"""解析响应头为字典"""
if not headers_data:
return {}
try:
return json.loads(headers_data)
except (json.JSONDecodeError, TypeError):
result = {}
for line in str(headers_data).split('\n'):
if ':' in line:
key, value = line.split(':', 1)
result[key.strip()] = value.strip()
return result
def _format_result(self, result: dict, vulnerabilities_by_url: dict, asset_type: str) -> dict:
"""格式化单个搜索结果"""
url = result.get('url', '')
vulns = vulnerabilities_by_url.get(url, [])
# 基础字段Website 和 Endpoint 共有)
formatted = {
'id': result.get('id'),
'url': url,
'host': result.get('host', ''),
'title': result.get('title', ''),
'technologies': result.get('tech', []) or [],
'statusCode': result.get('status_code'),
'contentLength': result.get('content_length'),
'contentType': result.get('content_type', ''),
'webserver': result.get('webserver', ''),
'location': result.get('location', ''),
'vhost': result.get('vhost'),
'responseHeaders': self._parse_headers(result.get('response_headers')),
'responseBody': result.get('response_body', ''),
'createdAt': result.get('created_at').isoformat() if result.get('created_at') else None,
'targetId': result.get('target_id'),
}
# Website 特有字段:漏洞关联
if asset_type == 'website':
formatted['vulnerabilities'] = [
{
'id': v.get('id'),
'name': v.get('vuln_type', ''),
'vulnType': v.get('vuln_type', ''),
'severity': v.get('severity', 'info'),
}
for v in vulns
]
# Endpoint 特有字段
if asset_type == 'endpoint':
formatted['matchedGfPatterns'] = result.get('matched_gf_patterns', []) or []
return formatted
def _get_vulnerabilities_by_url_prefix(self, website_urls: list) -> dict:
"""
根据 URL 前缀批量查询漏洞数据
漏洞 URL 是 website URL 的子路径,使用前缀匹配:
- website.url: https://example.com/path?query=1
- vulnerability.url: https://example.com/path/api/users
Args:
website_urls: website URL 列表,格式为 [(url, target_id), ...]
Returns:
dict: {website_url: [vulnerability_list]}
"""
if not website_urls:
return {}
try:
with connection.cursor() as cursor:
# 构建 OR 条件:每个 website URL去掉查询参数作为前缀匹配
conditions = []
params = []
url_mapping = {} # base_url -> original_url
for url, target_id in website_urls:
if not url or target_id is None:
continue
# 使用 urlparse 去掉查询参数和片段,只保留 scheme://netloc/path
parsed = urlparse(url)
base_url = urlunparse((parsed.scheme, parsed.netloc, parsed.path, '', '', ''))
url_mapping[base_url] = url
conditions.append("(v.url LIKE %s AND v.target_id = %s)")
params.extend([base_url + '%', target_id])
if not conditions:
return {}
where_clause = " OR ".join(conditions)
sql = f"""
SELECT v.id, v.vuln_type, v.severity, v.url, v.target_id
FROM vulnerability v
WHERE {where_clause}
ORDER BY
CASE v.severity
WHEN 'critical' THEN 1
WHEN 'high' THEN 2
WHEN 'medium' THEN 3
WHEN 'low' THEN 4
ELSE 5
END
"""
cursor.execute(sql, params)
# 获取所有漏洞
all_vulns = []
for row in cursor.fetchall():
all_vulns.append({
'id': row[0],
'vuln_type': row[1],
'name': row[1],
'severity': row[2],
'url': row[3],
'target_id': row[4],
})
# 按原始 website URL 分组(用于返回结果)
result = {url: [] for url, _ in website_urls}
for vuln in all_vulns:
vuln_url = vuln['url']
# 找到匹配的 website URL最长前缀匹配
for website_url, target_id in website_urls:
parsed = urlparse(website_url)
base_url = urlunparse((parsed.scheme, parsed.netloc, parsed.path, '', '', ''))
if vuln_url.startswith(base_url) and vuln['target_id'] == target_id:
result[website_url].append(vuln)
break
return result
except Exception as e:
logger.error(f"批量查询漏洞失败: {e}")
return {}
def get(self, request: Request):
"""搜索资产"""
# 获取搜索查询
query = request.query_params.get('q', '').strip()
if not query:
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Search query (q) is required',
status_code=status.HTTP_400_BAD_REQUEST
)
# 获取并验证资产类型
asset_type = request.query_params.get('asset_type', 'website').strip().lower()
if asset_type not in VALID_ASSET_TYPES:
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message=f'Invalid asset_type. Must be one of: {", ".join(VALID_ASSET_TYPES)}',
status_code=status.HTTP_400_BAD_REQUEST
)
# 获取分页参数
try:
page = int(request.query_params.get('page', 1))
page_size = int(request.query_params.get('pageSize', 10))
except (ValueError, TypeError):
page = 1
page_size = 10
# 限制分页参数
page = max(1, page)
page_size = min(max(1, page_size), 100)
# 获取总数和搜索结果
total = self.service.count(query, asset_type)
total_pages = (total + page_size - 1) // page_size if total > 0 else 1
offset = (page - 1) * page_size
all_results = self.service.search(query, asset_type)
results = all_results[offset:offset + page_size]
# 批量查询漏洞数据(仅 Website 类型需要)
vulnerabilities_by_url = {}
if asset_type == 'website':
website_urls = [(r.get('url'), r.get('target_id')) for r in results if r.get('url') and r.get('target_id')]
vulnerabilities_by_url = self._get_vulnerabilities_by_url_prefix(website_urls) if website_urls else {}
# 格式化结果
formatted_results = [self._format_result(r, vulnerabilities_by_url, asset_type) for r in results]
return success_response(data={
'results': formatted_results,
'total': total,
'page': page,
'pageSize': page_size,
'totalPages': total_pages,
'assetType': asset_type,
})
class AssetSearchExportView(APIView):
"""
资产搜索导出 API
GET /api/assets/search/export/
Query Parameters:
q: 搜索查询表达式
asset_type: 资产类型 ('website''endpoint',默认 'website')
Response:
CSV 文件(带 Content-Length支持浏览器显示下载进度
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.service = AssetSearchService()
def _get_headers_and_formatters(self, asset_type: str):
"""获取 CSV 表头和格式化器"""
from apps.common.utils import format_datetime, format_list_field
if asset_type == 'website':
headers = ['url', 'host', 'title', 'status_code', 'content_type', 'content_length',
'webserver', 'location', 'tech', 'vhost', 'created_at']
else:
headers = ['url', 'host', 'title', 'status_code', 'content_type', 'content_length',
'webserver', 'location', 'tech', 'matched_gf_patterns', 'vhost', 'created_at']
formatters = {
'created_at': format_datetime,
'tech': lambda x: format_list_field(x, separator='; '),
'matched_gf_patterns': lambda x: format_list_field(x, separator='; '),
'vhost': lambda x: 'true' if x else ('false' if x is False else ''),
}
return headers, formatters
def get(self, request: Request):
"""导出搜索结果为 CSV带 Content-Length支持下载进度显示"""
from apps.common.utils import create_csv_export_response
# 获取搜索查询
query = request.query_params.get('q', '').strip()
if not query:
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Search query (q) is required',
status_code=status.HTTP_400_BAD_REQUEST
)
# 获取并验证资产类型
asset_type = request.query_params.get('asset_type', 'website').strip().lower()
if asset_type not in VALID_ASSET_TYPES:
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message=f'Invalid asset_type. Must be one of: {", ".join(VALID_ASSET_TYPES)}',
status_code=status.HTTP_400_BAD_REQUEST
)
# 检查是否有结果(快速检查,避免空导出)
total = self.service.count(query, asset_type)
if total == 0:
return error_response(
code=ErrorCodes.NOT_FOUND,
message='No results to export',
status_code=status.HTTP_404_NOT_FOUND
)
# 获取表头和格式化器
headers, formatters = self._get_headers_and_formatters(asset_type)
# 生成文件名
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
filename = f'search_{asset_type}_{timestamp}.csv'
# 使用通用导出工具
data_iterator = self.service.search_iter(query, asset_type)
return create_csv_export_response(
data_iterator=data_iterator,
headers=headers,
filename=filename,
field_formatters=formatters,
show_progress=True # 显示下载进度
)

View File

@@ -0,0 +1,34 @@
# Generated by Django 5.2.7 on 2026-01-06 00:55
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
('targets', '0001_initial'),
]
operations = [
migrations.CreateModel(
name='BlacklistRule',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('pattern', models.CharField(help_text='规则模式,如 *.gov, 10.0.0.0/8, 192.168.1.1', max_length=255)),
('rule_type', models.CharField(choices=[('domain', '域名'), ('ip', 'IP地址'), ('cidr', 'CIDR范围'), ('keyword', '关键词')], help_text='规则类型domain, ip, cidr', max_length=20)),
('scope', models.CharField(choices=[('global', '全局规则'), ('target', 'Target规则')], db_index=True, help_text='作用域global 或 target', max_length=20)),
('description', models.CharField(blank=True, default='', help_text='规则描述', max_length=500)),
('created_at', models.DateTimeField(auto_now_add=True)),
('target', models.ForeignKey(blank=True, help_text='关联的 Target仅 scope=target 时有值)', null=True, on_delete=django.db.models.deletion.CASCADE, related_name='blacklist_rules', to='targets.target')),
],
options={
'db_table': 'blacklist_rule',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['scope', 'rule_type'], name='blacklist_r_scope_6ff77f_idx'), models.Index(fields=['target', 'scope'], name='blacklist_r_target__191441_idx')],
'constraints': [models.UniqueConstraint(fields=('pattern', 'scope', 'target'), name='unique_blacklist_rule')],
},
),
]

View File

@@ -0,0 +1,4 @@
"""Common models"""
from apps.common.models.blacklist import BlacklistRule
__all__ = ['BlacklistRule']

View File

@@ -0,0 +1,71 @@
"""黑名单规则模型"""
from django.db import models
class BlacklistRule(models.Model):
"""黑名单规则模型
用于存储黑名单过滤规则支持域名、IP、CIDR 三种类型。
支持两层作用域:全局规则和 Target 级规则。
"""
class RuleType(models.TextChoices):
DOMAIN = 'domain', '域名'
IP = 'ip', 'IP地址'
CIDR = 'cidr', 'CIDR范围'
KEYWORD = 'keyword', '关键词'
class Scope(models.TextChoices):
GLOBAL = 'global', '全局规则'
TARGET = 'target', 'Target规则'
id = models.AutoField(primary_key=True)
pattern = models.CharField(
max_length=255,
help_text='规则模式,如 *.gov, 10.0.0.0/8, 192.168.1.1'
)
rule_type = models.CharField(
max_length=20,
choices=RuleType.choices,
help_text='规则类型domain, ip, cidr'
)
scope = models.CharField(
max_length=20,
choices=Scope.choices,
db_index=True,
help_text='作用域global 或 target'
)
target = models.ForeignKey(
'targets.Target',
on_delete=models.CASCADE,
null=True,
blank=True,
related_name='blacklist_rules',
help_text='关联的 Target仅 scope=target 时有值)'
)
description = models.CharField(
max_length=500,
blank=True,
default='',
help_text='规则描述'
)
created_at = models.DateTimeField(auto_now_add=True)
class Meta:
db_table = 'blacklist_rule'
indexes = [
models.Index(fields=['scope', 'rule_type']),
models.Index(fields=['target', 'scope']),
]
constraints = [
models.UniqueConstraint(
fields=['pattern', 'scope', 'target'],
name='unique_blacklist_rule'
),
]
ordering = ['-created_at']
def __str__(self):
if self.scope == self.Scope.TARGET and self.target:
return f"[{self.scope}:{self.target_id}] {self.pattern}"
return f"[{self.scope}] {self.pattern}"

View File

@@ -0,0 +1,12 @@
"""Common serializers"""
from .blacklist_serializers import (
BlacklistRuleSerializer,
GlobalBlacklistRuleSerializer,
TargetBlacklistRuleSerializer,
)
__all__ = [
'BlacklistRuleSerializer',
'GlobalBlacklistRuleSerializer',
'TargetBlacklistRuleSerializer',
]

View File

@@ -0,0 +1,68 @@
"""黑名单规则序列化器"""
from rest_framework import serializers
from apps.common.models import BlacklistRule
from apps.common.utils import detect_rule_type
class BlacklistRuleSerializer(serializers.ModelSerializer):
"""黑名单规则序列化器"""
class Meta:
model = BlacklistRule
fields = [
'id',
'pattern',
'rule_type',
'scope',
'target',
'description',
'created_at',
]
read_only_fields = ['id', 'rule_type', 'created_at']
def validate_pattern(self, value):
"""验证规则模式"""
if not value or not value.strip():
raise serializers.ValidationError("规则模式不能为空")
return value.strip()
def create(self, validated_data):
"""创建规则时自动识别规则类型"""
pattern = validated_data.get('pattern', '')
validated_data['rule_type'] = detect_rule_type(pattern)
return super().create(validated_data)
def update(self, instance, validated_data):
"""更新规则时重新识别规则类型"""
if 'pattern' in validated_data:
pattern = validated_data['pattern']
validated_data['rule_type'] = detect_rule_type(pattern)
return super().update(instance, validated_data)
class GlobalBlacklistRuleSerializer(BlacklistRuleSerializer):
"""全局黑名单规则序列化器"""
class Meta(BlacklistRuleSerializer.Meta):
fields = ['id', 'pattern', 'rule_type', 'description', 'created_at']
read_only_fields = ['id', 'rule_type', 'created_at']
def create(self, validated_data):
"""创建全局规则"""
validated_data['scope'] = BlacklistRule.Scope.GLOBAL
validated_data['target'] = None
return super().create(validated_data)
class TargetBlacklistRuleSerializer(BlacklistRuleSerializer):
"""Target 黑名单规则序列化器"""
class Meta(BlacklistRuleSerializer.Meta):
fields = ['id', 'pattern', 'rule_type', 'description', 'created_at']
read_only_fields = ['id', 'rule_type', 'created_at']
def create(self, validated_data):
"""创建 Target 规则target_id 由 view 设置)"""
validated_data['scope'] = BlacklistRule.Scope.TARGET
return super().create(validated_data)

View File

@@ -3,13 +3,16 @@
提供系统级别的公共服务,包括:
- SystemLogService: 系统日志读取服务
- BlacklistService: 黑名单过滤服务
注意FilterService 已移至 apps.common.utils.filter_utils
推荐使用: from apps.common.utils.filter_utils import apply_filters
"""
from .system_log_service import SystemLogService
from .blacklist_service import BlacklistService
__all__ = [
'SystemLogService',
'BlacklistService',
]

View File

@@ -0,0 +1,176 @@
"""
黑名单规则管理服务
负责黑名单规则的 CRUD 操作(数据库层面)。
过滤逻辑请使用 apps.common.utils.BlacklistFilter。
架构说明:
- Model: BlacklistRule (apps.common.models.blacklist)
- Service: BlacklistService (本文件) - 规则 CRUD
- Utils: BlacklistFilter (apps.common.utils.blacklist_filter) - 过滤逻辑
- View: GlobalBlacklistView, TargetViewSet.blacklist
"""
import logging
from typing import List, Dict, Any, Optional
from django.db.models import QuerySet
from apps.common.utils import detect_rule_type
logger = logging.getLogger(__name__)
def _normalize_patterns(patterns: List[str]) -> List[str]:
"""
规范化规则列表:去重 + 过滤空行
Args:
patterns: 原始规则列表
Returns:
List[str]: 去重后的规则列表(保持顺序)
"""
return list(dict.fromkeys(filter(None, (p.strip() for p in patterns))))
class BlacklistService:
"""
黑名单规则管理服务
只负责规则的 CRUD 操作,不包含过滤逻辑。
过滤逻辑请使用 BlacklistFilter 工具类。
"""
def get_global_rules(self) -> QuerySet:
"""
获取全局黑名单规则列表
Returns:
QuerySet: 全局规则查询集
"""
from apps.common.models import BlacklistRule
return BlacklistRule.objects.filter(scope=BlacklistRule.Scope.GLOBAL)
def get_target_rules(self, target_id: int) -> QuerySet:
"""
获取 Target 级黑名单规则列表
Args:
target_id: Target ID
Returns:
QuerySet: Target 级规则查询集
"""
from apps.common.models import BlacklistRule
return BlacklistRule.objects.filter(
scope=BlacklistRule.Scope.TARGET,
target_id=target_id
)
def get_rules(self, target_id: Optional[int] = None) -> List:
"""
获取黑名单规则(全局 + Target 级)
Args:
target_id: Target ID用于加载 Target 级规则
Returns:
List[BlacklistRule]: 规则列表
"""
from apps.common.models import BlacklistRule
# 加载全局规则
rules = list(BlacklistRule.objects.filter(scope=BlacklistRule.Scope.GLOBAL))
# 加载 Target 级规则
if target_id:
target_rules = BlacklistRule.objects.filter(
scope=BlacklistRule.Scope.TARGET,
target_id=target_id
)
rules.extend(target_rules)
return rules
def replace_global_rules(self, patterns: List[str]) -> Dict[str, Any]:
"""
全量替换全局黑名单规则PUT 语义)
Args:
patterns: 新的规则模式列表
Returns:
Dict: {'count': int} 最终规则数量
"""
from apps.common.models import BlacklistRule
count = self._replace_rules(
patterns=patterns,
scope=BlacklistRule.Scope.GLOBAL,
target=None
)
logger.info("全量替换全局黑名单规则: %d", count)
return {'count': count}
def replace_target_rules(self, target, patterns: List[str]) -> Dict[str, Any]:
"""
全量替换 Target 级黑名单规则PUT 语义)
Args:
target: Target 对象
patterns: 新的规则模式列表
Returns:
Dict: {'count': int} 最终规则数量
"""
from apps.common.models import BlacklistRule
count = self._replace_rules(
patterns=patterns,
scope=BlacklistRule.Scope.TARGET,
target=target
)
logger.info("全量替换 Target 黑名单规则: %d 条 (Target: %s)", count, target.name)
return {'count': count}
def _replace_rules(self, patterns: List[str], scope: str, target=None) -> int:
"""
内部方法:全量替换规则
Args:
patterns: 规则模式列表
scope: 规则作用域 (GLOBAL/TARGET)
target: Target 对象(仅 TARGET 作用域需要)
Returns:
int: 最终规则数量
"""
from apps.common.models import BlacklistRule
from django.db import transaction
patterns = _normalize_patterns(patterns)
with transaction.atomic():
# 1. 删除旧规则
delete_filter = {'scope': scope}
if target:
delete_filter['target'] = target
BlacklistRule.objects.filter(**delete_filter).delete()
# 2. 创建新规则
if patterns:
rules = [
BlacklistRule(
pattern=pattern,
rule_type=detect_rule_type(pattern),
scope=scope,
target=target
)
for pattern in patterns
]
BlacklistRule.objects.bulk_create(rules)
return len(patterns)

View File

@@ -2,13 +2,19 @@
通用模块 URL 配置
路由说明:
- /api/health/ 健康检查接口(无需认证)
- /api/auth/* 认证相关接口(登录、登出、用户信息)
- /api/system/* 系统管理接口(日志查看等)
- /api/health/ 健康检查接口(无需认证)
- /api/auth/* 认证相关接口(登录、登出、用户信息)
- /api/system/* 系统管理接口(日志查看等)
- /api/blacklist/* 黑名单管理接口
"""
from django.urls import path
from .views import LoginView, LogoutView, MeView, ChangePasswordView, SystemLogsView, SystemLogFilesView, HealthCheckView
from .views import (
LoginView, LogoutView, MeView, ChangePasswordView,
SystemLogsView, SystemLogFilesView, HealthCheckView,
GlobalBlacklistView,
)
urlpatterns = [
# 健康检查(无需认证)
@@ -23,4 +29,7 @@ urlpatterns = [
# 系统管理
path('system/logs/', SystemLogsView.as_view(), name='system-logs'),
path('system/logs/files/', SystemLogFilesView.as_view(), name='system-log-files'),
# 黑名单管理PUT 全量替换模式)
path('blacklist/rules/', GlobalBlacklistView.as_view(), name='blacklist-rules'),
]

View File

@@ -11,8 +11,14 @@ from .csv_utils import (
generate_csv_rows,
format_list_field,
format_datetime,
create_csv_export_response,
UTF8_BOM,
)
from .blacklist_filter import (
BlacklistFilter,
detect_rule_type,
extract_host,
)
__all__ = [
'deduplicate_for_bulk',
@@ -24,5 +30,9 @@ __all__ = [
'generate_csv_rows',
'format_list_field',
'format_datetime',
'create_csv_export_response',
'UTF8_BOM',
'BlacklistFilter',
'detect_rule_type',
'extract_host',
]

View File

@@ -0,0 +1,246 @@
"""
黑名单过滤工具
提供域名、IP、CIDR、关键词的黑名单匹配功能。
纯工具类,不涉及数据库操作。
支持的规则类型:
1. 域名精确匹配: example.com
- 规则: example.com
- 匹配: example.com
- 不匹配: sub.example.com, other.com
2. 域名后缀匹配: *.example.com
- 规则: *.example.com
- 匹配: sub.example.com, a.b.example.com, example.com
- 不匹配: other.com, example.com.cn
3. 关键词匹配: *cdn*
- 规则: *cdn*
- 匹配: cdn.example.com, a.cdn.b.com, mycdn123.com
- 不匹配: example.com (不包含 cdn)
4. IP 精确匹配: 192.168.1.1
- 规则: 192.168.1.1
- 匹配: 192.168.1.1
- 不匹配: 192.168.1.2
5. CIDR 范围匹配: 192.168.0.0/24
- 规则: 192.168.0.0/24
- 匹配: 192.168.0.1, 192.168.0.255
- 不匹配: 192.168.1.1
使用方式:
from apps.common.utils import BlacklistFilter
# 创建过滤器(传入规则列表)
rules = BlacklistRule.objects.filter(...)
filter = BlacklistFilter(rules)
# 检查单个目标
if filter.is_allowed('http://example.com'):
process(url)
# 流式处理
for url in urls:
if filter.is_allowed(url):
process(url)
"""
import ipaddress
import logging
from typing import List, Optional
from urllib.parse import urlparse
from apps.common.validators import is_valid_ip, validate_cidr
logger = logging.getLogger(__name__)
def detect_rule_type(pattern: str) -> str:
"""
自动识别规则类型
支持的模式:
- 域名精确匹配: example.com
- 域名后缀匹配: *.example.com
- 关键词匹配: *cdn* (匹配包含 cdn 的域名)
- IP 精确匹配: 192.168.1.1
- CIDR 范围: 192.168.0.0/24
Args:
pattern: 规则模式字符串
Returns:
str: 规则类型 ('domain', 'ip', 'cidr', 'keyword')
"""
if not pattern:
return 'domain'
pattern = pattern.strip()
# 检查关键词模式: *keyword* (前后都有星号,中间无点)
if pattern.startswith('*') and pattern.endswith('*') and len(pattern) > 2:
keyword = pattern[1:-1]
# 关键词中不能有点(否则可能是域名模式)
if '.' not in keyword:
return 'keyword'
# 检查 CIDR包含 /
if '/' in pattern:
try:
validate_cidr(pattern)
return 'cidr'
except ValueError:
pass
# 检查 IP去掉通配符前缀后验证
clean_pattern = pattern.lstrip('*').lstrip('.')
if is_valid_ip(clean_pattern):
return 'ip'
# 默认为域名
return 'domain'
def extract_host(target: str) -> str:
"""
从目标字符串中提取主机名
支持:
- 纯域名example.com
- 纯 IP192.168.1.1
- URLhttp://example.com/path
Args:
target: 目标字符串
Returns:
str: 提取的主机名
"""
if not target:
return ''
target = target.strip()
# 如果是 URL提取 hostname
if '://' in target:
try:
parsed = urlparse(target)
return parsed.hostname or target
except Exception:
return target
return target
class BlacklistFilter:
"""
黑名单过滤器
预编译规则,提供高效的匹配功能。
"""
def __init__(self, rules: List):
"""
初始化过滤器
Args:
rules: BlacklistRule 对象列表
"""
from apps.common.models import BlacklistRule
# 预解析:按类型分类 + CIDR 预编译
self._domain_rules = [] # (pattern, is_wildcard, suffix)
self._ip_rules = set() # 精确 IP 用 setO(1) 查找
self._cidr_rules = [] # (pattern, network_obj)
self._keyword_rules = [] # 关键词列表(小写)
# 去重:跨 scope 可能有重复规则
seen_patterns = set()
for rule in rules:
if rule.pattern in seen_patterns:
continue
seen_patterns.add(rule.pattern)
if rule.rule_type == BlacklistRule.RuleType.DOMAIN:
pattern = rule.pattern.lower()
if pattern.startswith('*.'):
self._domain_rules.append((pattern, True, pattern[1:]))
else:
self._domain_rules.append((pattern, False, None))
elif rule.rule_type == BlacklistRule.RuleType.IP:
self._ip_rules.add(rule.pattern)
elif rule.rule_type == BlacklistRule.RuleType.CIDR:
try:
network = ipaddress.ip_network(rule.pattern, strict=False)
self._cidr_rules.append((rule.pattern, network))
except ValueError:
pass
elif rule.rule_type == BlacklistRule.RuleType.KEYWORD:
# *cdn* -> cdn
keyword = rule.pattern[1:-1].lower()
self._keyword_rules.append(keyword)
def is_allowed(self, target: str) -> bool:
"""
检查目标是否通过过滤
Args:
target: 要检查的目标(域名/IP/URL
Returns:
bool: True 表示通过不在黑名单False 表示被过滤
"""
if not target:
return True
host = extract_host(target)
if not host:
return True
# 先判断输入类型,再走对应分支
if is_valid_ip(host):
return self._check_ip_rules(host)
else:
return self._check_domain_rules(host)
def _check_domain_rules(self, host: str) -> bool:
"""检查域名规则(精确匹配 + 后缀匹配 + 关键词匹配)"""
host_lower = host.lower()
# 1. 域名规则(精确 + 后缀)
for pattern, is_wildcard, suffix in self._domain_rules:
if is_wildcard:
if host_lower.endswith(suffix) or host_lower == pattern[2:]:
return False
else:
if host_lower == pattern:
return False
# 2. 关键词匹配(字符串 in 操作O(n*m)
for keyword in self._keyword_rules:
if keyword in host_lower:
return False
return True
def _check_ip_rules(self, host: str) -> bool:
"""检查 IP 规则(精确匹配 + CIDR"""
# 1. IP 精确匹配O(1)
if host in self._ip_rules:
return False
# 2. CIDR 匹配
if self._cidr_rules:
try:
ip_obj = ipaddress.ip_address(host)
for _, network in self._cidr_rules:
if ip_obj in network:
return False
except ValueError:
pass
return True

View File

@@ -4,13 +4,21 @@
- UTF-8 BOMExcel 兼容)
- RFC 4180 规范转义
- 流式生成(内存友好)
- 带 Content-Length 的文件响应(支持浏览器下载进度显示)
"""
import csv
import io
import os
import tempfile
import logging
from datetime import datetime
from typing import Iterator, Dict, Any, List, Callable, Optional
from django.http import FileResponse, StreamingHttpResponse
logger = logging.getLogger(__name__)
# UTF-8 BOM确保 Excel 正确识别编码
UTF8_BOM = '\ufeff'
@@ -114,3 +122,123 @@ def format_datetime(dt: Optional[datetime]) -> str:
dt = timezone.localtime(dt)
return dt.strftime('%Y-%m-%d %H:%M:%S')
def create_csv_export_response(
data_iterator: Iterator[Dict[str, Any]],
headers: List[str],
filename: str,
field_formatters: Optional[Dict[str, Callable]] = None,
show_progress: bool = True
) -> FileResponse | StreamingHttpResponse:
"""
创建 CSV 导出响应
根据 show_progress 参数选择响应类型:
- True: 使用临时文件 + FileResponse带 Content-Length浏览器显示下载进度
- False: 使用 StreamingHttpResponse内存更友好但无下载进度
Args:
data_iterator: 数据迭代器,每个元素是一个字典
headers: CSV 表头列表
filename: 下载文件名(如 "export_2024.csv"
field_formatters: 字段格式化函数字典
show_progress: 是否显示下载进度(默认 True
Returns:
FileResponse 或 StreamingHttpResponse
Example:
>>> data_iter = service.iter_data()
>>> headers = ['url', 'host', 'created_at']
>>> formatters = {'created_at': format_datetime}
>>> response = create_csv_export_response(
... data_iter, headers, 'websites.csv', formatters
... )
>>> return response
"""
if show_progress:
return _create_file_response(data_iterator, headers, filename, field_formatters)
else:
return _create_streaming_response(data_iterator, headers, filename, field_formatters)
def _create_file_response(
data_iterator: Iterator[Dict[str, Any]],
headers: List[str],
filename: str,
field_formatters: Optional[Dict[str, Callable]] = None
) -> FileResponse:
"""
创建带 Content-Length 的文件响应(支持浏览器下载进度)
实现方式:先写入临时文件,再返回 FileResponse
"""
# 创建临时文件
temp_file = tempfile.NamedTemporaryFile(
mode='w',
suffix='.csv',
delete=False,
encoding='utf-8'
)
temp_path = temp_file.name
try:
# 流式写入 CSV 数据到临时文件
for row in generate_csv_rows(data_iterator, headers, field_formatters):
temp_file.write(row)
temp_file.close()
# 获取文件大小
file_size = os.path.getsize(temp_path)
# 创建文件响应
response = FileResponse(
open(temp_path, 'rb'),
content_type='text/csv; charset=utf-8',
as_attachment=True,
filename=filename
)
response['Content-Length'] = file_size
# 设置清理回调:响应完成后删除临时文件
original_close = response.file_to_stream.close
def close_and_cleanup():
original_close()
try:
os.unlink(temp_path)
except OSError:
pass
response.file_to_stream.close = close_and_cleanup
return response
except Exception as e:
# 清理临时文件
try:
temp_file.close()
except:
pass
try:
os.unlink(temp_path)
except OSError:
pass
logger.error(f"创建 CSV 导出响应失败: {e}")
raise
def _create_streaming_response(
data_iterator: Iterator[Dict[str, Any]],
headers: List[str],
filename: str,
field_formatters: Optional[Dict[str, Callable]] = None
) -> StreamingHttpResponse:
"""
创建流式响应(无 Content-Length内存更友好
"""
response = StreamingHttpResponse(
generate_csv_rows(data_iterator, headers, field_formatters),
content_type='text/csv; charset=utf-8'
)
response['Content-Disposition'] = f'attachment; filename="{filename}"'
return response

View File

@@ -29,11 +29,19 @@ from dataclasses import dataclass
from typing import List, Dict, Optional, Union
from enum import Enum
from django.db.models import QuerySet, Q
from django.db.models import QuerySet, Q, F, Func, CharField
from django.db.models.functions import Cast
logger = logging.getLogger(__name__)
class ArrayToString(Func):
"""PostgreSQL array_to_string 函数"""
function = 'array_to_string'
template = "%(function)s(%(expressions)s, ',')"
output_field = CharField()
class LogicalOp(Enum):
"""逻辑运算符"""
AND = 'AND'
@@ -86,9 +94,21 @@ class QueryParser:
if not query_string or not query_string.strip():
return []
# 第一步:提取所有过滤条件并用占位符替换,保护引号内的空格
filters_found = []
placeholder_pattern = '__FILTER_{}__'
def replace_filter(match):
idx = len(filters_found)
filters_found.append(match.group(0))
return placeholder_pattern.format(idx)
# 先用正则提取所有 field="value" 形式的条件
protected = cls.FILTER_PATTERN.sub(replace_filter, query_string)
# 标准化逻辑运算符
# 先处理 || 和 or -> __OR__
normalized = cls.OR_PATTERN.sub(' __OR__ ', query_string)
normalized = cls.OR_PATTERN.sub(' __OR__ ', protected)
# 再处理 && 和 and -> __AND__
normalized = cls.AND_PATTERN.sub(' __AND__ ', normalized)
@@ -103,20 +123,26 @@ class QueryParser:
pending_op = LogicalOp.OR
elif token == '__AND__':
pending_op = LogicalOp.AND
else:
# 尝试解析为过滤条件
match = cls.FILTER_PATTERN.match(token)
if match:
field, operator, value = match.groups()
groups.append(FilterGroup(
filter=ParsedFilter(
field=field.lower(),
operator=operator,
value=value
),
logical_op=pending_op if groups else LogicalOp.AND # 第一个条件默认 AND
))
pending_op = LogicalOp.AND # 重置为默认 AND
elif token.startswith('__FILTER_') and token.endswith('__'):
# 还原占位符为原始过滤条件
try:
idx = int(token[9:-2]) # 提取索引
original_filter = filters_found[idx]
match = cls.FILTER_PATTERN.match(original_filter)
if match:
field, operator, value = match.groups()
groups.append(FilterGroup(
filter=ParsedFilter(
field=field.lower(),
operator=operator,
value=value
),
logical_op=pending_op if groups else LogicalOp.AND
))
pending_op = LogicalOp.AND # 重置为默认 AND
except (ValueError, IndexError):
pass
# 其他 token 忽略(无效输入)
return groups
@@ -151,6 +177,21 @@ class QueryBuilder:
json_array_fields = json_array_fields or []
# 收集需要 annotate 的数组模糊搜索字段
array_fuzzy_fields = set()
# 第一遍:检查是否有数组模糊匹配
for group in filter_groups:
f = group.filter
db_field = field_mapping.get(f.field)
if db_field and db_field in json_array_fields and f.operator == '=':
array_fuzzy_fields.add(db_field)
# 对数组模糊搜索字段做 annotate
for field in array_fuzzy_fields:
annotate_name = f'{field}_text'
queryset = queryset.annotate(**{annotate_name: ArrayToString(F(field))})
# 构建 Q 对象
combined_q = None
@@ -187,8 +228,17 @@ class QueryBuilder:
def _build_single_q(cls, field: str, operator: str, value: str, is_json_array: bool = False) -> Optional[Q]:
"""构建单个条件的 Q 对象"""
if is_json_array:
# JSON 数组字段使用 __contains 查询
return Q(**{f'{field}__contains': [value]})
if operator == '==':
# 精确匹配:数组中包含完全等于 value 的元素
return Q(**{f'{field}__contains': [value]})
elif operator == '!=':
# 不包含:数组中不包含完全等于 value 的元素
return ~Q(**{f'{field}__contains': [value]})
else: # '=' 模糊匹配
# 使用 annotate 后的字段进行模糊搜索
# 字段已在 build_query 中通过 ArrayToString 转换为文本
annotate_name = f'{field}_text'
return Q(**{f'{annotate_name}__icontains': value})
if operator == '!=':
return cls._build_not_equal_q(field, value)

View File

@@ -5,14 +5,17 @@
- 健康检查视图Docker 健康检查
- 认证相关视图:登录、登出、用户信息、修改密码
- 系统日志视图:实时日志查看
- 黑名单视图:全局黑名单规则管理
"""
from .health_views import HealthCheckView
from .auth_views import LoginView, LogoutView, MeView, ChangePasswordView
from .system_log_views import SystemLogsView, SystemLogFilesView
from .blacklist_views import GlobalBlacklistView
__all__ = [
'HealthCheckView',
'LoginView', 'LogoutView', 'MeView', 'ChangePasswordView',
'SystemLogsView', 'SystemLogFilesView',
'GlobalBlacklistView',
]

View File

@@ -0,0 +1,80 @@
"""全局黑名单 API 视图"""
import logging
from rest_framework import status
from rest_framework.views import APIView
from rest_framework.permissions import IsAuthenticated
from apps.common.response_helpers import success_response, error_response
from apps.common.services import BlacklistService
logger = logging.getLogger(__name__)
class GlobalBlacklistView(APIView):
"""
全局黑名单规则 API
Endpoints:
- GET /api/blacklist/rules/ - 获取全局黑名单列表
- PUT /api/blacklist/rules/ - 全量替换规则(文本框保存场景)
设计说明:
- 使用 PUT 全量替换模式,适合"文本框每行一个规则"的前端场景
- 用户编辑文本框 -> 点击保存 -> 后端全量替换
架构MVS 模式
- View: 参数验证、响应格式化
- Service: 业务逻辑BlacklistService
- Model: 数据持久化BlacklistRule
"""
permission_classes = [IsAuthenticated]
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.blacklist_service = BlacklistService()
def get(self, request):
"""
获取全局黑名单规则列表
返回格式:
{
"patterns": ["*.gov", "*.edu", "10.0.0.0/8"]
}
"""
rules = self.blacklist_service.get_global_rules()
patterns = list(rules.values_list('pattern', flat=True))
return success_response(data={'patterns': patterns})
def put(self, request):
"""
全量替换全局黑名单规则
请求格式:
{
"patterns": ["*.gov", "*.edu", "10.0.0.0/8"]
}
或者空数组清空所有规则:
{
"patterns": []
}
"""
patterns = request.data.get('patterns', [])
# 兼容字符串输入(换行分隔)
if isinstance(patterns, str):
patterns = [p for p in patterns.split('\n') if p.strip()]
if not isinstance(patterns, list):
return error_response(
code='VALIDATION_ERROR',
message='patterns 必须是数组'
)
# 调用 Service 层全量替换
result = self.blacklist_service.replace_global_rules(patterns)
return success_response(data=result)

View File

@@ -15,9 +15,10 @@
"""
from django.core.management.base import BaseCommand
from io import StringIO
from pathlib import Path
import yaml
from ruamel.yaml import YAML
from apps.engine.models import ScanEngine
@@ -44,10 +45,12 @@ class Command(BaseCommand):
with open(config_path, 'r', encoding='utf-8') as f:
default_config = f.read()
# 解析 YAML 为字典,后续用于生成子引擎配置
# 使用 ruamel.yaml 解析,保留注释
yaml_parser = YAML()
yaml_parser.preserve_quotes = True
try:
config_dict = yaml.safe_load(default_config) or {}
except yaml.YAMLError as e:
config_dict = yaml_parser.load(default_config) or {}
except Exception as e:
self.stdout.write(self.style.ERROR(f'引擎配置 YAML 解析失败: {e}'))
return
@@ -83,16 +86,13 @@ class Command(BaseCommand):
if scan_type != 'subdomain_discovery' and 'tools' not in scan_cfg:
continue
# 构造只包含当前扫描类型配置的 YAML
# 构造只包含当前扫描类型配置的 YAML(保留注释)
single_config = {scan_type: scan_cfg}
try:
single_yaml = yaml.safe_dump(
single_config,
sort_keys=False,
allow_unicode=True,
default_flow_style=None,
)
except yaml.YAMLError as e:
stream = StringIO()
yaml_parser.dump(single_config, stream)
single_yaml = stream.getvalue()
except Exception as e:
self.stdout.write(self.style.ERROR(f'生成子引擎 {scan_type} 配置失败: {e}'))
continue

View File

@@ -0,0 +1,213 @@
# Generated by Django 5.2.7 on 2026-01-06 00:55
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
]
operations = [
migrations.CreateModel(
name='NucleiTemplateRepo',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(help_text='仓库名称,用于前端展示和配置引用', max_length=200, unique=True)),
('repo_url', models.CharField(help_text='Git 仓库地址', max_length=500)),
('local_path', models.CharField(blank=True, default='', help_text='本地工作目录绝对路径', max_length=500)),
('commit_hash', models.CharField(blank=True, default='', help_text='最后同步的 Git commit hash用于 Worker 版本校验', max_length=40)),
('last_synced_at', models.DateTimeField(blank=True, help_text='最后一次成功同步时间', null=True)),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('updated_at', models.DateTimeField(auto_now=True, help_text='更新时间')),
],
options={
'verbose_name': 'Nuclei 模板仓库',
'verbose_name_plural': 'Nuclei 模板仓库',
'db_table': 'nuclei_template_repo',
},
),
migrations.CreateModel(
name='ARLFingerprint',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(help_text='指纹名称', max_length=300, unique=True)),
('rule', models.TextField(help_text='匹配规则表达式')),
('created_at', models.DateTimeField(auto_now_add=True)),
],
options={
'verbose_name': 'ARL 指纹',
'verbose_name_plural': 'ARL 指纹',
'db_table': 'arl_fingerprint',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['name'], name='arl_fingerp_name_c3a305_idx'), models.Index(fields=['-created_at'], name='arl_fingerp_created_ed1060_idx')],
},
),
migrations.CreateModel(
name='EholeFingerprint',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('cms', models.CharField(help_text='产品/CMS名称', max_length=200)),
('method', models.CharField(default='keyword', help_text='匹配方式', max_length=200)),
('location', models.CharField(default='body', help_text='匹配位置', max_length=200)),
('keyword', models.JSONField(default=list, help_text='关键词列表')),
('is_important', models.BooleanField(default=False, help_text='是否重点资产')),
('type', models.CharField(blank=True, default='-', help_text='分类', max_length=100)),
('created_at', models.DateTimeField(auto_now_add=True)),
],
options={
'verbose_name': 'EHole 指纹',
'verbose_name_plural': 'EHole 指纹',
'db_table': 'ehole_fingerprint',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['cms'], name='ehole_finge_cms_72ca2c_idx'), models.Index(fields=['method'], name='ehole_finge_method_17f0db_idx'), models.Index(fields=['location'], name='ehole_finge_locatio_7bb82b_idx'), models.Index(fields=['type'], name='ehole_finge_type_ca2bce_idx'), models.Index(fields=['is_important'], name='ehole_finge_is_impo_d56e64_idx'), models.Index(fields=['-created_at'], name='ehole_finge_created_d862b0_idx')],
'constraints': [models.UniqueConstraint(fields=('cms', 'method', 'location'), name='unique_ehole_fingerprint')],
},
),
migrations.CreateModel(
name='FingerPrintHubFingerprint',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('fp_id', models.CharField(help_text='指纹ID', max_length=200, unique=True)),
('name', models.CharField(help_text='指纹名称', max_length=300)),
('author', models.CharField(blank=True, default='', help_text='作者', max_length=200)),
('tags', models.CharField(blank=True, default='', help_text='标签', max_length=500)),
('severity', models.CharField(blank=True, default='info', help_text='严重程度', max_length=50)),
('metadata', models.JSONField(blank=True, default=dict, help_text='元数据')),
('http', models.JSONField(default=list, help_text='HTTP 匹配规则')),
('source_file', models.CharField(blank=True, default='', help_text='来源文件', max_length=500)),
('created_at', models.DateTimeField(auto_now_add=True)),
],
options={
'verbose_name': 'FingerPrintHub 指纹',
'verbose_name_plural': 'FingerPrintHub 指纹',
'db_table': 'fingerprinthub_fingerprint',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['fp_id'], name='fingerprint_fp_id_df467f_idx'), models.Index(fields=['name'], name='fingerprint_name_95b6fb_idx'), models.Index(fields=['author'], name='fingerprint_author_80f54b_idx'), models.Index(fields=['severity'], name='fingerprint_severit_f70422_idx'), models.Index(fields=['-created_at'], name='fingerprint_created_bec16c_idx')],
},
),
migrations.CreateModel(
name='FingersFingerprint',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(help_text='指纹名称', max_length=300, unique=True)),
('link', models.URLField(blank=True, default='', help_text='相关链接', max_length=500)),
('rule', models.JSONField(default=list, help_text='匹配规则数组')),
('tag', models.JSONField(default=list, help_text='标签数组')),
('focus', models.BooleanField(default=False, help_text='是否重点关注')),
('default_port', models.JSONField(blank=True, default=list, help_text='默认端口数组')),
('created_at', models.DateTimeField(auto_now_add=True)),
],
options={
'verbose_name': 'Fingers 指纹',
'verbose_name_plural': 'Fingers 指纹',
'db_table': 'fingers_fingerprint',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['name'], name='fingers_fin_name_952de0_idx'), models.Index(fields=['link'], name='fingers_fin_link_4c6b7f_idx'), models.Index(fields=['focus'], name='fingers_fin_focus_568c7f_idx'), models.Index(fields=['-created_at'], name='fingers_fin_created_46fc91_idx')],
},
),
migrations.CreateModel(
name='GobyFingerprint',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(help_text='产品名称', max_length=300, unique=True)),
('logic', models.CharField(help_text='逻辑表达式', max_length=500)),
('rule', models.JSONField(default=list, help_text='规则数组')),
('created_at', models.DateTimeField(auto_now_add=True)),
],
options={
'verbose_name': 'Goby 指纹',
'verbose_name_plural': 'Goby 指纹',
'db_table': 'goby_fingerprint',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['name'], name='goby_finger_name_82084c_idx'), models.Index(fields=['logic'], name='goby_finger_logic_a63226_idx'), models.Index(fields=['-created_at'], name='goby_finger_created_50e000_idx')],
},
),
migrations.CreateModel(
name='ScanEngine',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('name', models.CharField(help_text='引擎名称', max_length=200, unique=True)),
('configuration', models.CharField(blank=True, default='', help_text='引擎配置yaml 格式', max_length=10000)),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('updated_at', models.DateTimeField(auto_now=True, help_text='更新时间')),
],
options={
'verbose_name': '扫描引擎',
'verbose_name_plural': '扫描引擎',
'db_table': 'scan_engine',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['-created_at'], name='scan_engine_created_da4870_idx')],
},
),
migrations.CreateModel(
name='WappalyzerFingerprint',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(help_text='应用名称', max_length=300, unique=True)),
('cats', models.JSONField(default=list, help_text='分类 ID 数组')),
('cookies', models.JSONField(blank=True, default=dict, help_text='Cookie 检测规则')),
('headers', models.JSONField(blank=True, default=dict, help_text='HTTP Header 检测规则')),
('script_src', models.JSONField(blank=True, default=list, help_text='脚本 URL 正则数组')),
('js', models.JSONField(blank=True, default=list, help_text='JavaScript 变量检测规则')),
('implies', models.JSONField(blank=True, default=list, help_text='依赖关系数组')),
('meta', models.JSONField(blank=True, default=dict, help_text='HTML meta 标签检测规则')),
('html', models.JSONField(blank=True, default=list, help_text='HTML 内容正则数组')),
('description', models.TextField(blank=True, default='', help_text='应用描述')),
('website', models.URLField(blank=True, default='', help_text='官网链接', max_length=500)),
('cpe', models.CharField(blank=True, default='', help_text='CPE 标识符', max_length=300)),
('created_at', models.DateTimeField(auto_now_add=True)),
],
options={
'verbose_name': 'Wappalyzer 指纹',
'verbose_name_plural': 'Wappalyzer 指纹',
'db_table': 'wappalyzer_fingerprint',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['name'], name='wappalyzer__name_63c669_idx'), models.Index(fields=['website'], name='wappalyzer__website_88de1c_idx'), models.Index(fields=['cpe'], name='wappalyzer__cpe_30c761_idx'), models.Index(fields=['-created_at'], name='wappalyzer__created_8e6c21_idx')],
},
),
migrations.CreateModel(
name='Wordlist',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('name', models.CharField(help_text='字典名称,唯一', max_length=200, unique=True)),
('description', models.CharField(blank=True, default='', help_text='字典描述', max_length=200)),
('file_path', models.CharField(help_text='后端保存的字典文件绝对路径', max_length=500)),
('file_size', models.BigIntegerField(default=0, help_text='文件大小(字节)')),
('line_count', models.IntegerField(default=0, help_text='字典行数')),
('file_hash', models.CharField(blank=True, default='', help_text='文件 SHA-256 哈希,用于缓存校验', max_length=64)),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('updated_at', models.DateTimeField(auto_now=True, help_text='更新时间')),
],
options={
'verbose_name': '字典文件',
'verbose_name_plural': '字典文件',
'db_table': 'wordlist',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['-created_at'], name='wordlist_created_4afb02_idx')],
},
),
migrations.CreateModel(
name='WorkerNode',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.CharField(help_text='节点名称', max_length=100)),
('ip_address', models.GenericIPAddressField(help_text='IP 地址(本地节点为 127.0.0.1')),
('ssh_port', models.IntegerField(default=22, help_text='SSH 端口')),
('username', models.CharField(default='root', help_text='SSH 用户名', max_length=50)),
('password', models.CharField(blank=True, default='', help_text='SSH 密码', max_length=200)),
('is_local', models.BooleanField(default=False, help_text='是否为本地节点Docker 容器内)')),
('status', models.CharField(choices=[('pending', '待部署'), ('deploying', '部署中'), ('online', '在线'), ('offline', '离线'), ('updating', '更新中'), ('outdated', '版本过低')], default='pending', help_text='状态: pending/deploying/online/offline', max_length=20)),
('created_at', models.DateTimeField(auto_now_add=True)),
('updated_at', models.DateTimeField(auto_now=True)),
],
options={
'verbose_name': 'Worker 节点',
'db_table': 'worker_node',
'ordering': ['-created_at'],
'constraints': [models.UniqueConstraint(condition=models.Q(('is_local', False)), fields=('ip_address',), name='unique_remote_worker_ip'), models.UniqueConstraint(fields=('name',), name='unique_worker_name')],
},
),
]

View File

@@ -88,6 +88,8 @@ def _register_scheduled_jobs(scheduler: BackgroundScheduler):
replace_existing=True,
)
logger.info(" - 已注册: 扫描结果清理(每天 03:00")
# 注意:搜索物化视图刷新已迁移到 pg_ivm 增量维护,无需定时任务
def _trigger_scheduled_scans():

View File

@@ -16,10 +16,9 @@ class GobyFingerprintService(BaseFingerprintService):
"""
校验单条 Goby 指纹
校验规则
- name 字段必须存在且非空
- logic 字段必须存在
- rule 字段必须是数组
支持两种格式
1. 标准格式: {"name": "...", "logic": "...", "rule": [...]}
2. JSONL 格式: {"product": "...", "rule": "..."}
Args:
item: 单条指纹数据
@@ -27,25 +26,43 @@ class GobyFingerprintService(BaseFingerprintService):
Returns:
bool: 是否有效
"""
# 标准格式name + logic + rule(数组)
name = item.get('name', '')
logic = item.get('logic', '')
rule = item.get('rule')
return bool(name and str(name).strip()) and bool(logic) and isinstance(rule, list)
if name and item.get('logic') is not None and isinstance(item.get('rule'), list):
return bool(str(name).strip())
# JSONL 格式product + rule(字符串)
product = item.get('product', '')
rule = item.get('rule', '')
return bool(product and str(product).strip() and rule and str(rule).strip())
def to_model_data(self, item: dict) -> dict:
"""
转换 Goby JSON 格式为 Model 字段
支持两种输入格式:
1. 标准格式: {"name": "...", "logic": "...", "rule": [...]}
2. JSONL 格式: {"product": "...", "rule": "..."}
Args:
item: 原始 Goby JSON 数据
Returns:
dict: Model 字段数据
"""
# 标准格式
if 'name' in item and isinstance(item.get('rule'), list):
return {
'name': str(item.get('name', '')).strip(),
'logic': item.get('logic', ''),
'rule': item.get('rule', []),
}
# JSONL 格式:将 rule 字符串转为单元素数组
return {
'name': str(item.get('name', '')).strip(),
'logic': item.get('logic', ''),
'rule': item.get('rule', []),
'name': str(item.get('product', '')).strip(),
'logic': 'or', # JSONL 格式默认 or 逻辑
'rule': [item.get('rule', '')] if item.get('rule') else [],
}
def get_export_data(self) -> list:

View File

@@ -312,7 +312,11 @@ class TaskDistributor:
# - 本地 Workerinstall.sh 已预拉取镜像,直接使用本地版本
# - 远程 Workerdeploy 时已预拉取镜像,直接使用本地版本
# - 避免每次任务都检查 Docker Hub提升性能和稳定性
# OOM 优先级:--oom-score-adj=1000 让 Worker 在内存不足时优先被杀
# - 范围 -1000 到 1000值越大越容易被 OOM Killer 选中
# - 保护 server/nginx/frontend 等核心服务,确保 Web 界面可用
cmd = f'''docker run --rm -d --pull=missing {network_arg} \\
--oom-score-adj=1000 \\
{' '.join(env_vars)} \\
{' '.join(volumes)} \\
{self.docker_image} \\

View File

@@ -139,7 +139,7 @@ class BaseFingerprintViewSet(viewsets.ModelViewSet):
POST /api/engine/fingerprints/{type}/import_file/
请求格式multipart/form-data
- file: JSON 文件
- file: JSON 文件(支持标准 JSON 和 JSONL 格式)
返回:同 batch_create
"""
@@ -148,9 +148,12 @@ class BaseFingerprintViewSet(viewsets.ModelViewSet):
raise ValidationError('缺少文件')
try:
json_data = json.load(file)
content = file.read().decode('utf-8')
json_data = self._parse_json_content(content)
except json.JSONDecodeError as e:
raise ValidationError(f'无效的 JSON 格式: {e}')
except UnicodeDecodeError as e:
raise ValidationError(f'文件编码错误: {e}')
fingerprints = self.parse_import_data(json_data)
if not fingerprints:
@@ -159,6 +162,41 @@ class BaseFingerprintViewSet(viewsets.ModelViewSet):
result = self.get_service().batch_create_fingerprints(fingerprints)
return success_response(data=result, status_code=status.HTTP_201_CREATED)
def _parse_json_content(self, content: str):
"""
解析 JSON 内容,支持标准 JSON 和 JSONL 格式
Args:
content: 文件内容字符串
Returns:
解析后的数据list 或 dict
"""
content = content.strip()
# 尝试标准 JSON 解析
try:
return json.loads(content)
except json.JSONDecodeError:
pass
# 尝试 JSONL 格式(每行一个 JSON 对象)
lines = content.split('\n')
result = []
for i, line in enumerate(lines):
line = line.strip()
if not line:
continue
try:
result.append(json.loads(line))
except json.JSONDecodeError as e:
raise json.JSONDecodeError(f'{i + 1} 行解析失败: {e.msg}', e.doc, e.pos)
if not result:
raise json.JSONDecodeError('文件为空或格式无效', content, 0)
return result
@action(detail=False, methods=['post'], url_path='bulk-delete')
def bulk_delete(self, request):
"""

View File

@@ -13,27 +13,17 @@ SCAN_TOOLS_BASE_PATH = getattr(settings, 'SCAN_TOOLS_BASE_PATH', '/usr/local/bin
SUBDOMAIN_DISCOVERY_COMMANDS = {
'subfinder': {
# 默认使用所有数据源(更全面,略慢),并始终开启递归
# -all 使用所有数据源
# -recursive 对支持递归的源启用递归枚举(默认开启
'base': "subfinder -d {domain} -all -recursive -o '{output_file}' -silent",
# 使用所有数据源(包括付费源,只要配置了 API key
# -all 使用所有数据源slow 但全面)
# -v 显示详细输出,包括使用的数据源(调试用
# 注意:不要加 -recursive它会排除不支持递归的源如 fofa
'base': "subfinder -d {domain} -all -o '{output_file}' -v",
'optional': {
'threads': '-t {threads}', # 控制并发 goroutine 数
'provider_config': "-pc '{provider_config}'", # Provider 配置文件路径
}
},
'amass_passive': {
# 先执行被动枚举,将结果写入 amass 内部数据库然后从数据库中导出纯域名names到 output_file
# -silent 禁用进度条和其他输出
'base': "amass enum -passive -silent -d {domain} && amass subs -names -d {domain} > '{output_file}'"
},
'amass_active': {
# 先执行主动枚举 + 爆破,将结果写入 amass 内部数据库然后从数据库中导出纯域名names到 output_file
# -silent 禁用进度条和其他输出
'base': "amass enum -active -silent -d {domain} -brute && amass subs -names -d {domain} > '{output_file}'"
},
'sublist3r': {
'base': "python3 '/usr/local/share/Sublist3r/sublist3r.py' -d {domain} -o '{output_file}'",
'optional': {
@@ -97,9 +87,11 @@ SITE_SCAN_COMMANDS = {
'base': (
"'{scan_tools_base}/httpx' -l '{url_file}' "
'-status-code -content-type -content-length '
'-location -title -server -body-preview '
'-location -title -server '
'-tech-detect -cdn -vhost '
'-random-agent -no-color -json'
'-include-response '
'-rstr 2000 '
'-random-agent -no-color -json -silent'
),
'optional': {
'threads': '-threads {threads}',
@@ -169,9 +161,11 @@ URL_FETCH_COMMANDS = {
'base': (
"'{scan_tools_base}/httpx' -l '{url_file}' "
'-status-code -content-type -content-length '
'-location -title -server -body-preview '
'-location -title -server '
'-tech-detect -cdn -vhost '
'-random-agent -no-color -json'
'-include-response '
'-rstr 2000 '
'-random-agent -no-color -json -silent'
),
'optional': {
'threads': '-threads {threads}',
@@ -257,11 +251,16 @@ COMMAND_TEMPLATES = {
'directory_scan': DIRECTORY_SCAN_COMMANDS,
'url_fetch': URL_FETCH_COMMANDS,
'vuln_scan': VULN_SCAN_COMMANDS,
'screenshot': {}, # 使用 Python 原生库Playwright无命令模板
}
# ==================== 扫描类型配置 ====================
# 执行阶段定义(按顺序执行)
# Stage 1: 资产发现 - 子域名 → 端口 → 站点探测 → 指纹识别
# Stage 2: URL 收集 - URL 获取 + 目录扫描(并行)
# Stage 3: 截图 - 在 URL 收集完成后执行,捕获更多发现的页面
# Stage 4: 漏洞扫描 - 最后执行
EXECUTION_STAGES = [
{
'mode': 'sequential',
@@ -271,6 +270,10 @@ EXECUTION_STAGES = [
'mode': 'parallel',
'flows': ['url_fetch', 'directory_scan']
},
{
'mode': 'sequential',
'flows': ['screenshot']
},
{
'mode': 'sequential',
'flows': ['vuln_scan']

View File

@@ -4,14 +4,12 @@
# 必需参数enabled是否启用
# 可选参数timeout超时秒数默认 auto 自动计算)
# ==================== 子域名发现 ====================
#
# Stage 1: 被动收集(并行) - 必选,至少启用一个工具
# Stage 2: 字典爆破(可选) - 使用字典暴力枚举子域名
# Stage 3: 变异生成 + 验证(可选) - 基于已发现域名生成变异,流式验证存活
# Stage 4: DNS 存活验证(可选) - 验证所有候选域名是否能解析
#
subdomain_discovery:
# ==================== 子域名发现 ====================
# Stage 1: 被动收集(并行) - 必选,至少启用一个工具
# Stage 2: 字典爆破(可选) - 使用字典暴力枚举子域名
# Stage 3: 变异生成 + 验证(可选) - 基于已发现域名生成变异,流式验证存活
# Stage 4: DNS 存活验证(可选) - 验证所有候选域名是否能解析
# === Stage 1: 被动收集工具(并行执行)===
passive_tools:
subfinder:
@@ -19,14 +17,6 @@ subdomain_discovery:
timeout: 3600 # 1小时
# threads: 10 # 并发 goroutine 数
amass_passive:
enabled: true
timeout: 3600
amass_active:
enabled: true # 主动枚举 + 爆破
timeout: 3600
sublist3r:
enabled: true
timeout: 3600
@@ -55,8 +45,8 @@ subdomain_discovery:
subdomain_resolve:
timeout: auto # 自动根据候选子域数量计算
# ==================== 端口扫描 ====================
port_scan:
# ==================== 端口扫描 ====================
tools:
naabu_active:
enabled: true
@@ -64,14 +54,14 @@ port_scan:
threads: 200 # 并发连接数(默认 5
# ports: 1-65535 # 扫描端口范围(默认 1-65535
top-ports: 100 # 扫描 nmap top 100 端口
rate: 10 # 扫描速率(默认 10
rate: 50 # 扫描速率
naabu_passive:
enabled: true
# timeout: auto # 被动扫描通常较快
# ==================== 站点扫描 ====================
site_scan:
# ==================== 站点扫描 ====================
tools:
httpx:
enabled: true
@@ -81,16 +71,16 @@ site_scan:
# request-timeout: 10 # 单个请求超时秒数(默认 10
# retries: 2 # 请求失败重试次数
# ==================== 指纹识别 ====================
# 在 site_scan 后串行执行,识别 WebSite 的技术栈
fingerprint_detect:
# ==================== 指纹识别 ====================
# 在 站点扫描 后串行执行,识别 WebSite 的技术栈
tools:
xingfinger:
enabled: true
fingerprint-libs: [ehole, goby, wappalyzer, fingers, fingerprinthub, arl] # 全部指纹库
fingerprint-libs: [ehole, goby, wappalyzer, fingers, fingerprinthub, arl] # 默认启动全部指纹库
# ==================== 目录扫描 ====================
directory_scan:
# ==================== 目录扫描 ====================
tools:
ffuf:
enabled: true
@@ -103,8 +93,18 @@ directory_scan:
match-codes: 200,201,301,302,401,403 # 匹配的 HTTP 状态码
# rate: 0 # 每秒请求数(默认 0 不限制)
# ==================== URL 获取 ====================
screenshot:
# ==================== 网站截图 ====================
# 使用 Playwright 对网站进行截图,保存为 WebP 格式
# 在 Stage 2 与 url_fetch、directory_scan 并行执行
tools:
playwright:
enabled: true
concurrency: 5 # 并发截图数(默认 5
url_sources: [websites] # URL 来源当前对website截图还可以用 [websites, endpoints]
url_fetch:
# ==================== URL 获取 ====================
tools:
waymore:
enabled: true
@@ -142,8 +142,8 @@ url_fetch:
# request-timeout: 10 # 单个请求超时秒数(默认 10
# retries: 2 # 请求失败重试次数
# ==================== 漏洞扫描 ====================
vuln_scan:
# ==================== 漏洞扫描 ====================
tools:
dalfox_xss:
enabled: true

View File

@@ -10,30 +10,30 @@
- 配置由 YAML 解析
"""
# Django 环境初始化(导入即生效)
from apps.common.prefect_django_setup import setup_django_for_prefect
from prefect import flow
from prefect.task_runners import ThreadPoolTaskRunner
import hashlib
import logging
import os
import subprocess
from datetime import datetime
from pathlib import Path
from typing import List, Tuple
from apps.scan.tasks.directory_scan import (
export_sites_task,
run_and_stream_save_directories_task
)
from prefect import flow
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
on_scan_flow_running,
)
from apps.scan.tasks.directory_scan import (
export_sites_task,
run_and_stream_save_directories_task,
)
from apps.scan.utils import (
build_scan_command,
ensure_wordlist_local,
user_log,
wait_for_system_load,
)
from apps.scan.utils import config_parser, build_scan_command, ensure_wordlist_local
logger = logging.getLogger(__name__)
@@ -45,496 +45,343 @@ def calculate_directory_scan_timeout(
tool_config: dict,
base_per_word: float = 1.0,
min_timeout: int = 60,
max_timeout: int = 7200
) -> int:
"""
根据字典行数计算目录扫描超时时间
计算公式:超时时间 = 字典行数 × 每个单词基础时间
超时范围:60秒 ~ 2小时7200秒
超时范围:最小 60 秒,无上限
Args:
tool_config: 工具配置字典,包含 wordlist 路径
base_per_word: 每个单词的基础时间(秒),默认 1.0秒
min_timeout: 最小超时时间(秒),默认 60秒
max_timeout: 最大超时时间(秒),默认 7200秒2小时
Returns:
int: 计算出的超时时间(秒)范围60 ~ 7200
Example:
# 1000行字典 × 1.0秒 = 1000秒 → 限制为7200秒中的 1000秒
# 10000行字典 × 1.0秒 = 10000秒 → 限制为7200秒最大值
timeout = calculate_directory_scan_timeout(
tool_config={'wordlist': '/path/to/wordlist.txt'}
)
int: 计算出的超时时间(秒)
"""
import os
wordlist_path = tool_config.get('wordlist')
if not wordlist_path:
logger.warning("工具配置中未指定 wordlist使用默认超时: %d", min_timeout)
return min_timeout
wordlist_path = os.path.expanduser(wordlist_path)
if not os.path.exists(wordlist_path):
logger.warning("字典文件不存在: %s,使用默认超时: %d", wordlist_path, min_timeout)
return min_timeout
try:
# 从 tool_config 中获取 wordlist 路径
wordlist_path = tool_config.get('wordlist')
if not wordlist_path:
logger.warning("工具配置中未指定 wordlist使用默认超时: %d", min_timeout)
return min_timeout
# 展开用户目录(~
wordlist_path = os.path.expanduser(wordlist_path)
# 检查文件是否存在
if not os.path.exists(wordlist_path):
logger.warning("字典文件不存在: %s,使用默认超时: %d", wordlist_path, min_timeout)
return min_timeout
# 使用 wc -l 快速统计字典行数
result = subprocess.run(
['wc', '-l', wordlist_path],
capture_output=True,
text=True,
check=True
)
# wc -l 输出格式:行数 + 空格 + 文件名
line_count = int(result.stdout.strip().split()[0])
# 计算超时时间
timeout = int(line_count * base_per_word)
# 设置合理的下限(不再设置上限)
timeout = max(min_timeout, timeout)
timeout = max(min_timeout, int(line_count * base_per_word))
logger.info(
"目录扫描超时计算 - 字典: %s, 行数: %d, 基础时间: %.3f秒/词, 计算超时: %d",
wordlist_path, line_count, base_per_word, timeout
)
return timeout
except subprocess.CalledProcessError as e:
logger.error("统计字典行数失败: %s", e)
# 失败时返回默认超时
return min_timeout
except (ValueError, IndexError) as e:
logger.error("解析字典行数失败: %s", e)
return min_timeout
except Exception as e:
logger.error("计算超时时间异常: %s", e)
except (subprocess.CalledProcessError, ValueError, IndexError) as e:
logger.error("计算超时时间失败: %s", e)
return min_timeout
def _get_max_workers(tool_config: dict, default: int = DEFAULT_MAX_WORKERS) -> int:
"""
从单个工具配置中获取 max_workers 参数
Args:
tool_config: 单个工具的配置字典,如 {'max_workers': 10, 'threads': 5, ...}
default: 默认值,默认为 5
Returns:
int: max_workers 值
"""
"""从单个工具配置中获取 max_workers 参数"""
if not isinstance(tool_config, dict):
return default
# 支持 max_workers 和 max-workersYAML 中划线会被转换)
max_workers = tool_config.get('max_workers') or tool_config.get('max-workers')
if max_workers is not None and isinstance(max_workers, int) and max_workers > 0:
if isinstance(max_workers, int) and max_workers > 0:
return max_workers
return default
def _export_site_urls(target_id: int, target_name: str, directory_scan_dir: Path) -> tuple[str, int]:
def _export_site_urls(
target_id: int,
directory_scan_dir: Path
) -> Tuple[str, int]:
"""
导出目标下的所有站点 URL 到文件(支持懒加载)
导出目标下的所有站点 URL 到文件
Args:
target_id: 目标 ID
target_name: 目标名称(用于懒加载创建默认站点)
directory_scan_dir: 目录扫描目录
Returns:
tuple: (sites_file, site_count)
Raises:
ValueError: 站点数量为 0
"""
logger.info("Step 1: 导出目标的所有站点 URL")
sites_file = str(directory_scan_dir / 'sites.txt')
export_result = export_sites_task(
target_id=target_id,
output_file=sites_file,
batch_size=1000 # 每次读取 1000 条,优化内存占用
batch_size=1000
)
site_count = export_result['total_count']
logger.info(
"✓ 站点 URL 导出完成 - 文件: %s, 数量: %d",
export_result['output_file'],
site_count
)
if site_count == 0:
logger.warning("目标下没有站点,无法执行目录扫描")
# 不抛出异常,由上层决定如何处理
# raise ValueError("目标下没有站点,无法执行目录扫描")
return export_result['output_file'], site_count
def _run_scans_sequentially(
enabled_tools: dict,
sites_file: str,
directory_scan_dir: Path,
scan_id: int,
target_id: int,
site_count: int,
target_name: str
) -> tuple[int, int, list]:
"""
串行执行目录扫描任务(支持多工具)- 已废弃,保留用于兼容
Args:
enabled_tools: 启用的工具配置字典
sites_file: 站点文件路径
directory_scan_dir: 目录扫描目录
scan_id: 扫描任务 ID
target_id: 目标 ID
site_count: 站点数量
target_name: 目标名称(用于错误日志)
Returns:
tuple: (total_directories, processed_sites, failed_sites)
"""
# 读取站点列表
sites = []
with open(sites_file, 'r', encoding='utf-8') as f:
for line in f:
site_url = line.strip()
if site_url:
sites.append(site_url)
logger.info("准备扫描 %d 个站点,使用工具: %s", len(sites), ', '.join(enabled_tools.keys()))
total_directories = 0
processed_sites_set = set() # 使用 set 避免重复计数
failed_sites = []
# 遍历每个工具
for tool_name, tool_config in enabled_tools.items():
logger.info("="*60)
logger.info("使用工具: %s", tool_name)
logger.info("="*60)
# 如果配置了 wordlist_name则先确保本地存在对应的字典文件含 hash 校验)
wordlist_name = tool_config.get('wordlist_name')
if wordlist_name:
try:
local_wordlist_path = ensure_wordlist_local(wordlist_name)
tool_config['wordlist'] = local_wordlist_path
except Exception as exc:
logger.error("为工具 %s 准备字典失败: %s", tool_name, exc)
# 当前工具无法执行,将所有站点视为失败,继续下一个工具
failed_sites.extend(sites)
continue
# 逐个站点执行扫描
for idx, site_url in enumerate(sites, 1):
logger.info(
"[%d/%d] 开始扫描站点: %s (工具: %s)",
idx, len(sites), site_url, tool_name
)
# 使用统一的命令构建器
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='directory_scan',
command_params={
'url': site_url
},
tool_config=tool_config
)
except Exception as e:
logger.error(
"✗ [%d/%d] 构建 %s 命令失败: %s - 站点: %s",
idx, len(sites), tool_name, e, site_url
)
failed_sites.append(site_url)
continue
# 单个站点超时:从配置中获取(支持 'auto' 动态计算)
# ffuf 逐个站点扫描timeout 就是单个站点的超时时间
site_timeout = tool_config.get('timeout', 300)
if site_timeout == 'auto':
# 动态计算超时时间(基于字典行数)
site_timeout = calculate_directory_scan_timeout(tool_config)
logger.info(f"✓ 工具 {tool_name} 动态计算 timeout: {site_timeout}")
# 生成日志文件路径
from datetime import datetime
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = directory_scan_dir / f"{tool_name}_{timestamp}_{idx}.log"
try:
# 直接调用 task串行执行
result = run_and_stream_save_directories_task(
cmd=command,
tool_name=tool_name, # 新增:工具名称
scan_id=scan_id,
target_id=target_id,
site_url=site_url,
cwd=str(directory_scan_dir),
shell=True,
batch_size=1000,
timeout=site_timeout,
log_file=str(log_file) # 新增:日志文件路径
)
total_directories += result.get('created_directories', 0)
processed_sites_set.add(site_url) # 使用 set 记录成功的站点
logger.info(
"✓ [%d/%d] 站点扫描完成: %s - 发现 %d 个目录",
idx, len(sites), site_url,
result.get('created_directories', 0)
)
except subprocess.TimeoutExpired as exc:
# 超时异常单独处理
failed_sites.append(site_url)
logger.warning(
"⚠️ [%d/%d] 站点扫描超时: %s - 超时配置: %d\n"
"注意:超时前已解析的目录数据已保存到数据库,但扫描未完全完成。",
idx, len(sites), site_url, site_timeout
)
except Exception as exc:
# 其他异常
failed_sites.append(site_url)
logger.error(
"✗ [%d/%d] 站点扫描失败: %s - 错误: %s",
idx, len(sites), site_url, exc
)
# 每 10 个站点输出进度
if idx % 10 == 0:
logger.info(
"进度: %d/%d (%.1f%%) - 已发现 %d 个目录",
idx, len(sites), idx/len(sites)*100, total_directories
)
# 计算成功和失败的站点数
processed_count = len(processed_sites_set)
if failed_sites:
logger.warning(
"部分站点扫描失败: %d/%d",
len(failed_sites), len(sites)
)
logger.info(
"✓ 串行目录扫描执行完成 - 成功: %d/%d, 失败: %d, 总目录数: %d",
processed_count, len(sites), len(failed_sites), total_directories
)
return total_directories, processed_count, failed_sites
def _generate_log_filename(tool_name: str, site_url: str, directory_scan_dir: Path) -> Path:
"""
生成唯一的日志文件名
使用 URL 的 hash 确保并发时不会冲突
Args:
tool_name: 工具名称
site_url: 站点 URL
directory_scan_dir: 目录扫描目录
Returns:
Path: 日志文件路径
"""
url_hash = hashlib.md5(site_url.encode()).hexdigest()[:8]
def _generate_log_filename(
tool_name: str,
site_url: str,
directory_scan_dir: Path
) -> Path:
"""生成唯一的日志文件名(使用 URL 的 hash 确保并发时不会冲突)"""
url_hash = hashlib.md5(
site_url.encode(),
usedforsecurity=False
).hexdigest()[:8]
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
return directory_scan_dir / f"{tool_name}_{url_hash}_{timestamp}.log"
def _prepare_tool_wordlist(tool_name: str, tool_config: dict) -> bool:
"""准备工具的字典文件,返回是否成功"""
wordlist_name = tool_config.get('wordlist_name')
if not wordlist_name:
return True
try:
local_wordlist_path = ensure_wordlist_local(wordlist_name)
tool_config['wordlist'] = local_wordlist_path
return True
except Exception as exc:
logger.error("为工具 %s 准备字典失败: %s", tool_name, exc)
return False
def _build_scan_params(
tool_name: str,
tool_config: dict,
sites: List[str],
directory_scan_dir: Path,
site_timeout: int
) -> Tuple[List[dict], List[str]]:
"""构建所有站点的扫描参数,返回 (scan_params_list, failed_sites)"""
scan_params_list = []
failed_sites = []
for idx, site_url in enumerate(sites, 1):
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='directory_scan',
command_params={'url': site_url},
tool_config=tool_config
)
log_file = _generate_log_filename(tool_name, site_url, directory_scan_dir)
scan_params_list.append({
'idx': idx,
'site_url': site_url,
'command': command,
'log_file': str(log_file),
'timeout': site_timeout
})
except Exception as e:
logger.error(
"✗ [%d/%d] 构建 %s 命令失败: %s - 站点: %s",
idx, len(sites), tool_name, e, site_url
)
failed_sites.append(site_url)
return scan_params_list, failed_sites
def _execute_batch(
batch_params: List[dict],
tool_name: str,
scan_id: int,
target_id: int,
directory_scan_dir: Path,
total_sites: int
) -> Tuple[int, List[str]]:
"""执行一批扫描任务,返回 (directories_found, failed_sites)"""
directories_found = 0
failed_sites = []
# 提交任务
futures = []
for params in batch_params:
future = run_and_stream_save_directories_task.submit(
cmd=params['command'],
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
site_url=params['site_url'],
cwd=str(directory_scan_dir),
shell=True,
batch_size=1000,
timeout=params['timeout'],
log_file=params['log_file']
)
futures.append((params['idx'], params['site_url'], future))
# 等待结果
for idx, site_url, future in futures:
try:
result = future.result()
dirs_count = result.get('created_directories', 0)
directories_found += dirs_count
logger.info(
"✓ [%d/%d] 站点扫描完成: %s - 发现 %d 个目录",
idx, total_sites, site_url, dirs_count
)
except Exception as exc:
failed_sites.append(site_url)
if 'timeout' in str(exc).lower():
logger.warning(
"⚠️ [%d/%d] 站点扫描超时: %s - 错误: %s",
idx, total_sites, site_url, exc
)
else:
logger.error(
"✗ [%d/%d] 站点扫描失败: %s - 错误: %s",
idx, total_sites, site_url, exc
)
return directories_found, failed_sites
def _run_scans_concurrently(
enabled_tools: dict,
sites_file: str,
directory_scan_dir: Path,
scan_id: int,
target_id: int,
site_count: int,
target_name: str
) -> Tuple[int, int, List[str]]:
"""
并发执行目录扫描任务(使用 ThreadPoolTaskRunner
Args:
enabled_tools: 启用的工具配置字典
sites_file: 站点文件路径
directory_scan_dir: 目录扫描目录
scan_id: 扫描任务 ID
target_id: 目标 ID
site_count: 站点数量
target_name: 目标名称(用于错误日志)
并发执行目录扫描任务
Returns:
tuple: (total_directories, processed_sites, failed_sites)
"""
# 读取站点列表
sites: List[str] = []
with open(sites_file, 'r', encoding='utf-8') as f:
for line in f:
site_url = line.strip()
if site_url:
sites.append(site_url)
sites = [line.strip() for line in f if line.strip()]
if not sites:
logger.warning("站点列表为空")
return 0, 0, []
logger.info(
"准备并发扫描 %d 个站点,使用工具: %s",
len(sites), ', '.join(enabled_tools.keys())
)
total_directories = 0
processed_sites_count = 0
failed_sites: List[str] = []
# 遍历每个工具
for tool_name, tool_config in enabled_tools.items():
# 每个工具独立获取 max_workers 配置
max_workers = _get_max_workers(tool_config)
logger.info("="*60)
logger.info("使用工具: %s (并发模式, max_workers=%d)", tool_name, max_workers)
logger.info("="*60)
# 如果配置了 wordlist_name则先确保本地存在对应的字典文件含 hash 校验)
wordlist_name = tool_config.get('wordlist_name')
if wordlist_name:
try:
local_wordlist_path = ensure_wordlist_local(wordlist_name)
tool_config['wordlist'] = local_wordlist_path
except Exception as exc:
logger.error("为工具 %s 准备字典失败: %s", tool_name, exc)
# 当前工具无法执行,将所有站点视为失败,继续下一个工具
failed_sites.extend(sites)
continue
# 计算超时时间(所有站点共用)
for tool_name, tool_config in enabled_tools.items():
max_workers = _get_max_workers(tool_config)
logger.info("=" * 60)
logger.info("使用工具: %s (并发模式, max_workers=%d)", tool_name, max_workers)
logger.info("=" * 60)
user_log(scan_id, "directory_scan", f"Running {tool_name}")
# 准备字典文件
if not _prepare_tool_wordlist(tool_name, tool_config):
failed_sites.extend(sites)
continue
# 计算超时时间
site_timeout = tool_config.get('timeout', 300)
if site_timeout == 'auto':
site_timeout = calculate_directory_scan_timeout(tool_config)
logger.info(f"✓ 工具 {tool_name} 动态计算 timeout: {site_timeout}")
# 准备所有站点的扫描参数
scan_params_list = []
for idx, site_url in enumerate(sites, 1):
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='directory_scan',
command_params={'url': site_url},
tool_config=tool_config
)
log_file = _generate_log_filename(tool_name, site_url, directory_scan_dir)
scan_params_list.append({
'idx': idx,
'site_url': site_url,
'command': command,
'log_file': str(log_file),
'timeout': site_timeout
})
except Exception as e:
logger.error(
"✗ [%d/%d] 构建 %s 命令失败: %s - 站点: %s",
idx, len(sites), tool_name, e, site_url
)
failed_sites.append(site_url)
logger.info("✓ 工具 %s 动态计算 timeout: %d", tool_name, site_timeout)
# 构建扫描参数
scan_params_list, build_failed = _build_scan_params(
tool_name, tool_config, sites, directory_scan_dir, site_timeout
)
failed_sites.extend(build_failed)
if not scan_params_list:
logger.warning("没有有效的扫描任务")
continue
# ============================================================
# 分批执行策略:控制实际并发的 ffuf 进程数
# ============================================================
# 分批执行
total_tasks = len(scan_params_list)
logger.info("开始分批执行 %d 个扫描任务(每批 %d 个)...", total_tasks, max_workers)
batch_num = 0
last_progress_percent = 0
tool_directories = 0
tool_processed = 0
for batch_start in range(0, total_tasks, max_workers):
batch_end = min(batch_start + max_workers, total_tasks)
batch_params = scan_params_list[batch_start:batch_end]
batch_num += 1
logger.info("执行第 %d 批任务(%d-%d/%d...", batch_num, batch_start + 1, batch_end, total_tasks)
# 提交当前批次的任务(非阻塞,立即返回 future
futures = []
for params in batch_params:
future = run_and_stream_save_directories_task.submit(
cmd=params['command'],
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
site_url=params['site_url'],
cwd=str(directory_scan_dir),
shell=True,
batch_size=1000,
timeout=params['timeout'],
log_file=params['log_file']
batch_num = batch_start // max_workers + 1
logger.info(
"执行第 %d 批任务(%d-%d/%d...",
batch_num, batch_start + 1, batch_end, total_tasks
)
dirs_found, batch_failed = _execute_batch(
batch_params, tool_name, scan_id, target_id,
directory_scan_dir, len(sites)
)
total_directories += dirs_found
tool_directories += dirs_found
tool_processed += len(batch_params) - len(batch_failed)
processed_sites_count += len(batch_params) - len(batch_failed)
failed_sites.extend(batch_failed)
# 进度里程碑:每 20% 输出一次
current_progress = int((batch_end / total_tasks) * 100)
if current_progress >= last_progress_percent + 20:
user_log(
scan_id, "directory_scan",
f"Progress: {batch_end}/{total_tasks} sites scanned"
)
futures.append((params['idx'], params['site_url'], future))
# 等待当前批次所有任务完成(阻塞,确保本批完成后再启动下一批)
for idx, site_url, future in futures:
try:
result = future.result() # 阻塞等待单个任务完成
directories_found = result.get('created_directories', 0)
total_directories += directories_found
processed_sites_count += 1
logger.info(
"✓ [%d/%d] 站点扫描完成: %s - 发现 %d 个目录",
idx, len(sites), site_url, directories_found
)
except Exception as exc:
failed_sites.append(site_url)
if 'timeout' in str(exc).lower() or isinstance(exc, subprocess.TimeoutExpired):
logger.warning(
"⚠️ [%d/%d] 站点扫描超时: %s - 错误: %s",
idx, len(sites), site_url, exc
)
else:
logger.error(
"✗ [%d/%d] 站点扫描失败: %s - 错误: %s",
idx, len(sites), site_url, exc
)
# 输出汇总信息
if failed_sites:
logger.warning(
"部分站点扫描失败: %d/%d",
len(failed_sites), len(sites)
last_progress_percent = (current_progress // 20) * 20
logger.info(
"✓ 工具 %s 执行完成 - 已处理站点: %d/%d, 发现目录: %d",
tool_name, tool_processed, total_tasks, tool_directories
)
user_log(
scan_id, "directory_scan",
f"{tool_name} completed: found {tool_directories} directories"
)
if failed_sites:
logger.warning("部分站点扫描失败: %d/%d", len(failed_sites), len(sites))
logger.info(
"✓ 并发目录扫描执行完成 - 成功: %d/%d, 失败: %d, 总目录数: %d",
processed_sites_count, len(sites), len(failed_sites), total_directories
)
return total_directories, processed_sites_count, failed_sites
@flow(
name="directory_scan",
name="directory_scan",
log_prints=True,
on_running=[on_scan_flow_running],
on_completion=[on_scan_flow_completed],
@@ -549,62 +396,31 @@ def directory_scan_flow(
) -> dict:
"""
目录扫描 Flow
主要功能:
1. 从 target 获取所有站点的 URL
2. 对每个站点 URL 执行目录扫描(支持 ffuf 等工具)
3. 流式保存扫描结果到数据库 Directory 表
工作流程:
Step 0: 创建工作目录
Step 1: 导出站点 URL 列表到文件(供扫描工具使用)
Step 2: 验证工具配置
Step 3: 并发执行扫描工具并实时保存结果(使用 ThreadPoolTaskRunner
ffuf 输出字段:
- url: 发现的目录/文件 URL
- length: 响应内容长度
- status: HTTP 状态码
- words: 响应内容单词数
- lines: 响应内容行数
- content_type: 内容类型
- duration: 请求耗时
Args:
scan_id: 扫描任务 ID
target_name: 目标名称
target_id: 目标 ID
scan_workspace_dir: 扫描工作空间目录
enabled_tools: 启用的工具配置字典
Returns:
dict: {
'success': bool,
'scan_id': int,
'target': str,
'scan_workspace_dir': str,
'sites_file': str,
'site_count': int,
'total_directories': int, # 发现的总目录数
'processed_sites': int, # 成功处理的站点数
'failed_sites_count': int, # 失败的站点数
'executed_tasks': list
}
Raises:
ValueError: 参数错误
RuntimeError: 执行失败
dict: 扫描结果
"""
try:
wait_for_system_load(context="directory_scan_flow")
logger.info(
"="*60 + "\n" +
"开始目录扫描\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
"开始目录扫描 - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
)
user_log(scan_id, "directory_scan", "Starting directory scan")
# 参数验证
if scan_id is None:
raise ValueError("scan_id 不能为空")
@@ -616,16 +432,17 @@ def directory_scan_flow(
raise ValueError("scan_workspace_dir 不能为空")
if not enabled_tools:
raise ValueError("enabled_tools 不能为空")
# Step 0: 创建工作目录
from apps.scan.utils import setup_scan_directory
directory_scan_dir = setup_scan_directory(scan_workspace_dir, 'directory_scan')
# Step 1: 导出站点 URL(支持懒加载)
sites_file, site_count = _export_site_urls(target_id, target_name, directory_scan_dir)
# Step 1: 导出站点 URL
sites_file, site_count = _export_site_urls(target_id, directory_scan_dir)
if site_count == 0:
logger.warning("目标下没有站点,跳过目录扫描")
logger.warning("跳过目录扫描:没有站点可扫描 - Scan ID: %s", scan_id)
user_log(scan_id, "directory_scan", "Skipped: no sites to scan", "warning")
return {
'success': True,
'scan_id': scan_id,
@@ -638,16 +455,16 @@ def directory_scan_flow(
'failed_sites_count': 0,
'executed_tasks': ['export_sites']
}
# Step 2: 工具配置信息
logger.info("Step 2: 工具配置信息")
tool_info = []
for tool_name, tool_config in enabled_tools.items():
mw = _get_max_workers(tool_config)
tool_info.append(f"{tool_name}(max_workers={mw})")
tool_info = [
f"{name}(max_workers={_get_max_workers(cfg)})"
for name, cfg in enabled_tools.items()
]
logger.info("✓ 启用工具: %s", ', '.join(tool_info))
# Step 3: 并发执行扫描工具并实时保存结果
# Step 3: 并发执行扫描
logger.info("Step 3: 并发执行扫描工具并实时保存结果")
total_directories, processed_sites, failed_sites = _run_scans_concurrently(
enabled_tools=enabled_tools,
@@ -655,17 +472,20 @@ def directory_scan_flow(
directory_scan_dir=directory_scan_dir,
scan_id=scan_id,
target_id=target_id,
site_count=site_count,
target_name=target_name
)
# 检查是否所有站点都失败
if processed_sites == 0 and site_count > 0:
logger.warning("所有站点扫描均失败 - 总站点数: %d, 失败数: %d", site_count, len(failed_sites))
# 不抛出异常,让扫描继续
logger.info("="*60 + "\n✓ 目录扫描完成\n" + "="*60)
logger.warning(
"所有站点扫描均失败 - 总站点数: %d, 失败数: %d",
site_count, len(failed_sites)
)
logger.info("✓ 目录扫描完成 - 发现目录: %d", total_directories)
user_log(
scan_id, "directory_scan",
f"directory_scan completed: found {total_directories} directories"
)
return {
'success': True,
'scan_id': scan_id,
@@ -678,7 +498,7 @@ def directory_scan_flow(
'failed_sites_count': len(failed_sites),
'executed_tasks': ['export_sites', 'run_and_stream_save_directories']
}
except Exception as e:
logger.exception("目录扫描失败: %s", e)
raise
raise

View File

@@ -10,26 +10,22 @@
- 流式处理输出,批量更新数据库
"""
# Django 环境初始化(导入即生效)
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging
import os
from datetime import datetime
from pathlib import Path
from prefect import flow
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
on_scan_flow_running,
)
from apps.scan.tasks.fingerprint_detect import (
export_urls_for_fingerprint_task,
run_xingfinger_and_stream_update_tech_task,
)
from apps.scan.utils import build_scan_command
from apps.scan.utils import build_scan_command, user_log, wait_for_system_load
from apps.scan.utils.fingerprint_helpers import get_fingerprint_paths
logger = logging.getLogger(__name__)
@@ -42,22 +38,19 @@ def calculate_fingerprint_detect_timeout(
) -> int:
"""
根据 URL 数量计算超时时间
公式:超时时间 = URL 数量 × 每 URL 基础时间
最小值300秒
无上限
最小值300秒,无上限
Args:
url_count: URL 数量
base_per_url: 每 URL 基础时间(秒),默认 10秒
min_timeout: 最小超时时间(秒),默认 300秒
Returns:
int: 计算出的超时时间(秒)
"""
timeout = int(url_count * base_per_url)
return max(min_timeout, timeout)
return max(min_timeout, int(url_count * base_per_url))
@@ -70,17 +63,17 @@ def _export_urls(
) -> tuple[str, int]:
"""
导出 URL 到文件
Args:
target_id: 目标 ID
fingerprint_dir: 指纹识别目录
source: 数据源类型
Returns:
tuple: (urls_file, total_count)
"""
logger.info("Step 1: 导出 URL 列表 (source=%s)", source)
urls_file = str(fingerprint_dir / 'urls.txt')
export_result = export_urls_for_fingerprint_task(
target_id=target_id,
@@ -88,15 +81,14 @@ def _export_urls(
source=source,
batch_size=1000
)
total_count = export_result['total_count']
logger.info(
"✓ URL 导出完成 - 文件: %s, 数量: %d",
export_result['output_file'],
total_count
)
return export_result['output_file'], total_count
@@ -111,7 +103,7 @@ def _run_fingerprint_detect(
) -> tuple[dict, list]:
"""
执行指纹识别任务
Args:
enabled_tools: 已启用的工具配置字典
urls_file: URL 文件路径
@@ -120,55 +112,54 @@ def _run_fingerprint_detect(
scan_id: 扫描任务 ID
target_id: 目标 ID
source: 数据源类型
Returns:
tuple: (tool_stats, failed_tools)
"""
tool_stats = {}
failed_tools = []
for tool_name, tool_config in enabled_tools.items():
# 1. 获取指纹库路径
lib_names = tool_config.get('fingerprint_libs', ['ehole'])
fingerprint_paths = get_fingerprint_paths(lib_names)
if not fingerprint_paths:
reason = f"没有可用的指纹库: {lib_names}"
logger.warning(reason)
failed_tools.append({'tool': tool_name, 'reason': reason})
continue
# 2. 将指纹库路径合并到 tool_config用于命令构建
tool_config_with_paths = {**tool_config, **fingerprint_paths}
# 3. 构建命令
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='fingerprint_detect',
command_params={
'urls_file': urls_file
},
command_params={'urls_file': urls_file},
tool_config=tool_config_with_paths
)
except Exception as e:
reason = f"命令构建失败: {str(e)}"
reason = f"命令构建失败: {e}"
logger.error("构建 %s 命令失败: %s", tool_name, e)
failed_tools.append({'tool': tool_name, 'reason': reason})
continue
# 4. 计算超时时间
timeout = calculate_fingerprint_detect_timeout(url_count)
# 5. 生成日志文件路径
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = fingerprint_dir / f"{tool_name}_{timestamp}.log"
logger.info(
"开始执行 %s 指纹识别 - URL数: %d, 超时: %ds, 指纹库: %s",
tool_name, url_count, timeout, list(fingerprint_paths.keys())
)
user_log(scan_id, "fingerprint_detect", f"Running {tool_name}: {command}")
# 6. 执行扫描任务
try:
result = run_xingfinger_and_stream_update_tech_task(
@@ -182,32 +173,39 @@ def _run_fingerprint_detect(
log_file=str(log_file),
batch_size=100
)
tool_stats[tool_name] = {
'command': command,
'result': result,
'timeout': timeout,
'fingerprint_libs': list(fingerprint_paths.keys())
}
tool_updated = result.get('updated_count', 0)
logger.info(
"✓ 工具 %s 执行完成 - 处理记录: %d, 更新: %d, 未找到: %d",
tool_name,
result.get('processed_records', 0),
result.get('updated_count', 0),
tool_updated,
result.get('not_found_count', 0)
)
user_log(
scan_id, "fingerprint_detect",
f"{tool_name} completed: identified {tool_updated} fingerprints"
)
except Exception as exc:
failed_tools.append({'tool': tool_name, 'reason': str(exc)})
reason = str(exc)
failed_tools.append({'tool': tool_name, 'reason': reason})
logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True)
user_log(scan_id, "fingerprint_detect", f"{tool_name} failed: {reason}", "error")
if failed_tools:
logger.warning(
"以下指纹识别工具执行失败: %s",
', '.join([f['tool'] for f in failed_tools])
)
return tool_stats, failed_tools
@@ -227,50 +225,38 @@ def fingerprint_detect_flow(
) -> dict:
"""
指纹识别 Flow
主要功能:
1. 从数据库导出目标下所有 WebSite URL 到文件
2. 使用 xingfinger 进行技术栈识别
3. 解析结果并更新 WebSite.tech 字段(合并去重)
工作流程:
Step 0: 创建工作目录
Step 1: 导出 URL 列表
Step 2: 解析配置,获取启用的工具
Step 3: 执行 xingfinger 并解析结果
Args:
scan_id: 扫描任务 ID
target_name: 目标名称
target_id: 目标 ID
scan_workspace_dir: 扫描工作空间目录
enabled_tools: 启用的工具配置xingfinger
Returns:
dict: {
'success': bool,
'scan_id': int,
'target': str,
'scan_workspace_dir': str,
'urls_file': str,
'url_count': int,
'processed_records': int,
'updated_count': int,
'not_found_count': int,
'executed_tasks': list,
'tool_stats': dict
}
dict: 扫描结果
"""
try:
# 负载检查:等待系统资源充足
wait_for_system_load(context="fingerprint_detect_flow")
logger.info(
"="*60 + "\n" +
"开始指纹识别\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
"开始指纹识别 - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
)
user_log(scan_id, "fingerprint_detect", "Starting fingerprint detection")
# 参数验证
if scan_id is None:
raise ValueError("scan_id 不能为空")
@@ -280,44 +266,26 @@ def fingerprint_detect_flow(
raise ValueError("target_id 不能为空")
if not scan_workspace_dir:
raise ValueError("scan_workspace_dir 不能为空")
# 数据源类型(当前只支持 website
source = 'website'
# Step 0: 创建工作目录
from apps.scan.utils import setup_scan_directory
fingerprint_dir = setup_scan_directory(scan_workspace_dir, 'fingerprint_detect')
# Step 1: 导出 URL支持懒加载
urls_file, url_count = _export_urls(target_id, fingerprint_dir, source)
if url_count == 0:
logger.warning("目标下没有可用的 URL跳过指纹识别")
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'scan_workspace_dir': scan_workspace_dir,
'urls_file': urls_file,
'url_count': 0,
'processed_records': 0,
'updated_count': 0,
'created_count': 0,
'executed_tasks': ['export_urls_for_fingerprint'],
'tool_stats': {
'total': 0,
'successful': 0,
'failed': 0,
'successful_tools': [],
'failed_tools': [],
'details': {}
}
}
logger.warning("跳过指纹识别:没有 URL 可扫描 - Scan ID: %s", scan_id)
user_log(scan_id, "fingerprint_detect", "Skipped: no URLs to scan", "warning")
return _build_empty_result(scan_id, target_name, scan_workspace_dir, urls_file)
# Step 2: 工具配置信息
logger.info("Step 2: 工具配置信息")
logger.info("✓ 启用工具: %s", ', '.join(enabled_tools.keys()))
# Step 3: 执行指纹识别
logger.info("Step 3: 执行指纹识别")
tool_stats, failed_tools = _run_fingerprint_detect(
@@ -329,21 +297,37 @@ def fingerprint_detect_flow(
target_id=target_id,
source=source
)
logger.info("="*60 + "\n✓ 指纹识别完成\n" + "="*60)
# 动态生成已执行的任务列表
executed_tasks = ['export_urls_for_fingerprint']
executed_tasks.extend([f'run_xingfinger ({tool})' for tool in tool_stats.keys()])
executed_tasks.extend([f'run_xingfinger ({tool})' for tool in tool_stats])
# 汇总所有工具的结果
total_processed = sum(stats['result'].get('processed_records', 0) for stats in tool_stats.values())
total_updated = sum(stats['result'].get('updated_count', 0) for stats in tool_stats.values())
total_created = sum(stats['result'].get('created_count', 0) for stats in tool_stats.values())
successful_tools = [name for name in enabled_tools.keys()
if name not in [f['tool'] for f in failed_tools]]
total_processed = sum(
stats['result'].get('processed_records', 0) for stats in tool_stats.values()
)
total_updated = sum(
stats['result'].get('updated_count', 0) for stats in tool_stats.values()
)
total_created = sum(
stats['result'].get('created_count', 0) for stats in tool_stats.values()
)
total_snapshots = sum(
stats['result'].get('snapshot_count', 0) for stats in tool_stats.values()
)
# 记录 Flow 完成
logger.info("✓ 指纹识别完成 - 识别指纹: %d", total_updated)
user_log(
scan_id, "fingerprint_detect",
f"fingerprint_detect completed: identified {total_updated} fingerprints"
)
successful_tools = [
name for name in enabled_tools
if name not in [f['tool'] for f in failed_tools]
]
return {
'success': True,
'scan_id': scan_id,
@@ -354,6 +338,7 @@ def fingerprint_detect_flow(
'processed_records': total_processed,
'updated_count': total_updated,
'created_count': total_created,
'snapshot_count': total_snapshots,
'executed_tasks': executed_tasks,
'tool_stats': {
'total': len(enabled_tools),
@@ -364,7 +349,7 @@ def fingerprint_detect_flow(
'details': tool_stats
}
}
except ValueError as e:
logger.error("配置错误: %s", e)
raise
@@ -374,3 +359,33 @@ def fingerprint_detect_flow(
except Exception as e:
logger.exception("指纹识别失败: %s", e)
raise
def _build_empty_result(
scan_id: int,
target_name: str,
scan_workspace_dir: str,
urls_file: str
) -> dict:
"""构建空结果(无 URL 可扫描时)"""
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'scan_workspace_dir': scan_workspace_dir,
'urls_file': urls_file,
'url_count': 0,
'processed_records': 0,
'updated_count': 0,
'created_count': 0,
'snapshot_count': 0,
'executed_tasks': ['export_urls_for_fingerprint'],
'tool_stats': {
'total': 0,
'successful': 0,
'failed': 0,
'successful_tools': [],
'failed_tools': [],
'details': {}
}
}

View File

@@ -99,23 +99,24 @@ def initiate_scan_flow(
raise ValueError("engine_name is required")
logger.info(
"="*60 + "\n" +
"开始初始化扫描任务\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Engine: {engine_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
)
logger.info("="*60)
logger.info("开始初始化扫描任务")
logger.info(f"Scan ID: {scan_id}")
logger.info(f"Target: {target_name}")
logger.info(f"Engine: {engine_name}")
logger.info(f"Workspace: {scan_workspace_dir}")
logger.info("="*60)
# ==================== Task 1: 创建 Scan 工作空间 ====================
scan_workspace_path = setup_scan_workspace(scan_workspace_dir)
# ==================== Task 2: 获取引擎配置 ====================
from apps.scan.models import Scan
scan = Scan.objects.select_related('engine').get(id=scan_id)
engine_config = scan.engine.configuration
scan = Scan.objects.get(id=scan_id)
engine_config = scan.yaml_configuration
# 使用 engine_names 进行显示
display_engine_name = ', '.join(scan.engine_names) if scan.engine_names else engine_name
# ==================== Task 3: 解析配置,生成执行计划 ====================
orchestrator = FlowOrchestrator(engine_config)
@@ -123,11 +124,9 @@ def initiate_scan_flow(
# FlowOrchestrator 已经解析了所有工具配置
enabled_tools_by_type = orchestrator.enabled_tools_by_type
logger.info(
f"执行计划生成成功:\n"
f" 扫描类型: {''.join(orchestrator.scan_types)}\n"
f" 总共 {len(orchestrator.scan_types)} 个 Flow"
)
logger.info("执行计划生成成功")
logger.info(f"扫描类型: {''.join(orchestrator.scan_types)}")
logger.info(f"总共 {len(orchestrator.scan_types)} 个 Flow")
# ==================== 初始化阶段进度 ====================
# 在解析完配置后立即初始化,此时已有完整的 scan_types 列表
@@ -206,9 +205,13 @@ def initiate_scan_flow(
for mode, enabled_flows in orchestrator.get_execution_stages():
if mode == 'sequential':
# 顺序执行
logger.info(f"\n{'='*60}\n顺序执行阶段: {', '.join(enabled_flows)}\n{'='*60}")
logger.info("="*60)
logger.info(f"顺序执行阶段: {', '.join(enabled_flows)}")
logger.info("="*60)
for scan_type, flow_func, flow_specific_kwargs in get_valid_flows(enabled_flows):
logger.info(f"\n{'='*60}\n执行 Flow: {scan_type}\n{'='*60}")
logger.info("="*60)
logger.info(f"执行 Flow: {scan_type}")
logger.info("="*60)
try:
result = flow_func(**flow_specific_kwargs)
record_flow_result(scan_type, result=result)
@@ -217,12 +220,16 @@ def initiate_scan_flow(
elif mode == 'parallel':
# 并行执行阶段:通过 Task 包装子 Flow并使用 Prefect TaskRunner 并发运行
logger.info(f"\n{'='*60}\n并行执行阶段: {', '.join(enabled_flows)}\n{'='*60}")
logger.info("="*60)
logger.info(f"并行执行阶段: {', '.join(enabled_flows)}")
logger.info("="*60)
futures = []
# 提交所有并行子 Flow 任务
for scan_type, flow_func, flow_specific_kwargs in get_valid_flows(enabled_flows):
logger.info(f"\n{'='*60}\n提交并行子 Flow 任务: {scan_type}\n{'='*60}")
logger.info("="*60)
logger.info(f"提交并行子 Flow 任务: {scan_type}")
logger.info("="*60)
future = _run_subflow_task.submit(
scan_type=scan_type,
flow_func=flow_func,
@@ -243,12 +250,10 @@ def initiate_scan_flow(
record_flow_result(scan_type, error=e)
# ==================== 完成 ====================
logger.info(
"="*60 + "\n" +
"✓ 扫描任务初始化完成\n" +
f" 执行的 Flow: {', '.join(executed_flows)}\n" +
"="*60
)
logger.info("="*60)
logger.info("✓ 扫描任务初始化完成")
logger.info(f"执行的 Flow: {', '.join(executed_flows)}")
logger.info("="*60)
# ==================== 返回结果 ====================
return {

View File

@@ -1,4 +1,4 @@
"""
"""
端口扫描 Flow
负责编排端口扫描的完整流程
@@ -10,25 +10,23 @@
- 配置由 YAML 解析
"""
# Django 环境初始化(导入即生效)
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging
import os
import subprocess
from datetime import datetime
from pathlib import Path
from typing import Callable
from prefect import flow
from apps.scan.tasks.port_scan import (
export_scan_targets_task,
run_and_stream_save_ports_task
)
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
on_scan_flow_running,
)
from apps.scan.utils import config_parser, build_scan_command
from apps.scan.tasks.port_scan import (
export_hosts_task,
run_and_stream_save_ports_task,
)
from apps.scan.utils import build_scan_command, user_log, wait_for_system_load
logger = logging.getLogger(__name__)
@@ -40,28 +38,19 @@ def calculate_port_scan_timeout(
) -> int:
"""
根据目标数量和端口数量计算超时时间
计算公式:超时时间 = 目标数 × 端口数 × base_per_pair
超时范围60秒 ~ 2天172800秒
超时范围60秒 ~ 无上限
Args:
tool_config: 工具配置字典包含端口配置ports, top-ports等
file_path: 目标文件路径(域名/IP列表
base_per_pair: 每个"端口-目标对"的基础时间(秒),默认 0.5秒
Returns:
int: 计算出的超时时间(秒),范围60 ~ 172800
Example:
# 100个目标 × 100个端口 × 0.5秒 = 5000秒
# 10个目标 × 1000个端口 × 0.5秒 = 5000秒
timeout = calculate_port_scan_timeout(
tool_config={'top-ports': 100},
file_path='/path/to/domains.txt'
)
int: 计算出的超时时间(秒),最小 60 秒
"""
try:
# 1. 统计目标数量
result = subprocess.run(
['wc', '-l', file_path],
capture_output=True,
@@ -69,133 +58,116 @@ def calculate_port_scan_timeout(
check=True
)
target_count = int(result.stdout.strip().split()[0])
# 2. 解析端口数量
port_count = _parse_port_count(tool_config)
# 3. 计算超时时间
# 总工作量 = 目标数 × 端口数
total_work = target_count * port_count
timeout = int(total_work * base_per_pair)
# 4. 设置合理的下限(不再设置上限)
min_timeout = 60 # 最小 60 秒
timeout = max(min_timeout, timeout)
timeout = max(60, int(total_work * base_per_pair))
logger.info(
f"计算端口扫描 timeout - "
f"目标数: {target_count}, "
f"端口数: {port_count}, "
f"总工作量: {total_work}, "
f"超时: {timeout}"
"计算端口扫描 timeout - 目标数: %d, 端口数: %d, 总工作量: %d, 超时: %d",
target_count, port_count, total_work, timeout
)
return timeout
except Exception as e:
logger.warning(f"计算 timeout 失败: {e},使用默认值 600秒")
logger.warning("计算 timeout 失败: %s,使用默认值 600秒", e)
return 600
def _parse_port_count(tool_config: dict) -> int:
"""
从工具配置中解析端口数量
优先级:
1. top-ports: N → 返回 N
2. ports: "80,443,8080" → 返回逗号分隔的数量
3. ports: "1-1000" → 返回范围的大小
4. ports: "1-65535" → 返回 65535
5. 默认 → 返回 100naabu 默认扫描 top 100
Args:
tool_config: 工具配置字典
Returns:
int: 端口数量
"""
# 1. 检查 top-ports 配置
# 检查 top-ports 配置
if 'top-ports' in tool_config:
top_ports = tool_config['top-ports']
if isinstance(top_ports, int) and top_ports > 0:
return top_ports
logger.warning(f"top-ports 配置无效: {top_ports},使用默认值")
# 2. 检查 ports 配置
logger.warning("top-ports 配置无效: %s,使用默认值", top_ports)
# 检查 ports 配置
if 'ports' in tool_config:
ports_str = str(tool_config['ports']).strip()
# 2.1 逗号分隔的端口列表80,443,8080
# 逗号分隔的端口列表80,443,8080
if ',' in ports_str:
port_list = [p.strip() for p in ports_str.split(',') if p.strip()]
return len(port_list)
# 2.2 端口范围1-1000
return len([p.strip() for p in ports_str.split(',') if p.strip()])
# 端口范围1-1000
if '-' in ports_str:
try:
start, end = ports_str.split('-', 1)
start_port = int(start.strip())
end_port = int(end.strip())
if 1 <= start_port <= end_port <= 65535:
return end_port - start_port + 1
logger.warning(f"端口范围无效: {ports_str},使用默认值")
logger.warning("端口范围无效: %s,使用默认值", ports_str)
except ValueError:
logger.warning(f"端口范围解析失败: {ports_str},使用默认值")
# 2.3 单个端口
logger.warning("端口范围解析失败: %s,使用默认值", ports_str)
# 单个端口
try:
port = int(ports_str)
if 1 <= port <= 65535:
return 1
except ValueError:
logger.warning(f"端口配置解析失败: {ports_str},使用默认值")
# 3. 默认值naabu 默认扫描 top 100 端口
logger.warning("端口配置解析失败: %s,使用默认值", ports_str)
# 默认值naabu 默认扫描 top 100 端口
return 100
def _export_scan_targets(target_id: int, port_scan_dir: Path) -> tuple[str, int, str]:
def _export_hosts(target_id: int, port_scan_dir: Path) -> tuple[str, int, str]:
"""
导出扫描目标到文件
导出主机列表到文件
根据 Target 类型自动决定导出内容:
- DOMAIN: 从 Subdomain 表导出子域名
- IP: 直接写入 target.name
- CIDR: 展开 CIDR 范围内的所有 IP
Args:
target_id: 目标 ID
port_scan_dir: 端口扫描目录
Returns:
tuple: (targets_file, target_count, target_type)
tuple: (hosts_file, host_count, target_type)
"""
logger.info("Step 1: 导出扫描目标列表")
targets_file = str(port_scan_dir / 'targets.txt')
export_result = export_scan_targets_task(
logger.info("Step 1: 导出主机列表")
hosts_file = str(port_scan_dir / 'hosts.txt')
export_result = export_hosts_task(
target_id=target_id,
output_file=targets_file,
batch_size=1000 # 每次读取 1000 条,优化内存占用
output_file=hosts_file,
)
target_count = export_result['total_count']
host_count = export_result['total_count']
target_type = export_result.get('target_type', 'unknown')
logger.info(
"扫描目标导出完成 - 类型: %s, 文件: %s, 数量: %d",
target_type,
export_result['output_file'],
target_count
"主机列表导出完成 - 类型: %s, 文件: %s, 数量: %d",
target_type, export_result['output_file'], host_count
)
if target_count == 0:
logger.warning("目标下没有可扫描的地址,无法执行端口扫描")
return export_result['output_file'], target_count, target_type
if host_count == 0:
logger.warning("目标下没有可扫描的主机,无法执行端口扫描")
return export_result['output_file'], host_count, target_type
def _run_scans_sequentially(
@@ -208,7 +180,7 @@ def _run_scans_sequentially(
) -> tuple[dict, int, list, list]:
"""
串行执行端口扫描任务
Args:
enabled_tools: 已启用的工具配置字典
domains_file: 域名文件路径
@@ -216,125 +188,109 @@ def _run_scans_sequentially(
scan_id: 扫描任务 ID
target_id: 目标 ID
target_name: 目标名称(用于错误日志)
Returns:
tuple: (tool_stats, processed_records, successful_tool_names, failed_tools)
注意:端口扫描是流式输出,不生成结果文件
Raises:
RuntimeError: 所有工具均失败
"""
# ==================== 构建命令并串行执行 ====================
tool_stats = {}
processed_records = 0
failed_tools = [] # 记录失败的工具(含原因)
# for循环执行工具按顺序串行运行每个启用的端口扫描工具
failed_tools = []
for tool_name, tool_config in enabled_tools.items():
# 1. 构建完整命令(变量替换)
# 构建命令
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='port_scan',
command_params={
'domains_file': domains_file # 对应 {domains_file}
},
tool_config=tool_config #yaml的工具配置
command_params={'domains_file': domains_file},
tool_config=tool_config
)
except Exception as e:
reason = f"命令构建失败: {str(e)}"
logger.error(f"构建 {tool_name} 命令失败: {e}")
reason = f"命令构建失败: {e}"
logger.error("构建 %s 命令失败: %s", tool_name, e)
failed_tools.append({'tool': tool_name, 'reason': reason})
continue
# 2. 获取超时时间(支持 'auto' 动态计算)
# 获取超时时间
config_timeout = tool_config['timeout']
if config_timeout == 'auto':
# 动态计算超时时间
config_timeout = calculate_port_scan_timeout(
tool_config=tool_config,
file_path=str(domains_file)
)
logger.info(f"✓ 工具 {tool_name} 动态计算 timeout: {config_timeout}")
# 2.1 生成日志文件路径
from datetime import datetime
config_timeout = calculate_port_scan_timeout(tool_config, str(domains_file))
logger.info("✓ 工具 %s 动态计算 timeout: %d", tool_name, config_timeout)
# 生成日志文件路径
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = port_scan_dir / f"{tool_name}_{timestamp}.log"
# 3. 执行扫描任务
logger.info("开始执行 %s 扫描(超时: %d秒)...", tool_name, config_timeout)
user_log(scan_id, "port_scan", f"Running {tool_name}: {command}")
# 执行扫描任务
try:
# 直接调用 task串行执行
# 注意:端口扫描是流式输出到 stdout不使用 output_file
result = run_and_stream_save_ports_task(
cmd=command,
tool_name=tool_name, # 工具名称
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
cwd=str(port_scan_dir),
shell=True,
batch_size=1000,
timeout=config_timeout,
log_file=str(log_file) # 新增:日志文件路径
log_file=str(log_file)
)
tool_stats[tool_name] = {
'command': command,
'result': result,
'timeout': config_timeout
}
processed_records += result.get('processed_records', 0)
logger.info(
"✓ 工具 %s 流式处理完成 - 记录数: %d",
tool_name, result.get('processed_records', 0)
)
except subprocess.TimeoutExpired as exc:
# 超时异常单独处理
# 注意:流式处理任务超时时,已解析的数据已保存到数据库
reason = f"执行超时(配置: {config_timeout}秒)"
tool_records = result.get('processed_records', 0)
processed_records += tool_records
logger.info("✓ 工具 %s 流式处理完成 - 记录数: %d", tool_name, tool_records)
user_log(scan_id, "port_scan", f"{tool_name} completed: found {tool_records} ports")
except subprocess.TimeoutExpired:
reason = f"timeout after {config_timeout}s"
failed_tools.append({'tool': tool_name, 'reason': reason})
logger.warning(
"⚠️ 工具 %s 执行超时 - 超时配置: %d\n"
"注意:超时前已解析的端口数据已保存到数据库,但扫描未完全完成。",
tool_name, config_timeout
)
user_log(scan_id, "port_scan", f"{tool_name} failed: {reason}", "error")
except Exception as exc:
# 其他异常
failed_tools.append({'tool': tool_name, 'reason': str(exc)})
reason = str(exc)
failed_tools.append({'tool': tool_name, 'reason': reason})
logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True)
user_log(scan_id, "port_scan", f"{tool_name} failed: {reason}", "error")
if failed_tools:
logger.warning(
"以下扫描工具执行失败: %s",
', '.join([f['tool'] for f in failed_tools])
)
if not tool_stats:
error_details = "; ".join([f"{f['tool']}: {f['reason']}" for f in failed_tools])
logger.warning("所有端口扫描工具均失败 - 目标: %s, 失败工具: %s", target_name, error_details)
# 返回空结果,不抛出异常,让扫描继续
return {}, 0, [], failed_tools
# 动态计算成功的工具列表
successful_tool_names = [name for name in enabled_tools.keys()
if name not in [f['tool'] for f in failed_tools]]
successful_tool_names = [
name for name in enabled_tools
if name not in [f['tool'] for f in failed_tools]
]
logger.info(
"✓ 串行端口扫描执行完成 - 成功: %d/%d (成功: %s, 失败: %s)",
len(tool_stats), len(enabled_tools),
', '.join(successful_tool_names) if successful_tool_names else '',
', '.join([f['tool'] for f in failed_tools]) if failed_tools else ''
)
return tool_stats, processed_records, successful_tool_names, failed_tools
@flow(
name="port_scan",
name="port_scan",
log_prints=True,
on_running=[on_scan_flow_running],
on_completion=[on_scan_flow_completed],
@@ -349,19 +305,19 @@ def port_scan_flow(
) -> dict:
"""
端口扫描 Flow
主要功能:
1. 扫描目标域名/IP 的开放端口
2. 保存 host + ip + port 三元映射到 HostPortMapping 表
输出资产:
- HostPortMapping主机端口映射host + ip + port 三元组)
工作流程:
Step 0: 创建工作目录
Step 1: 导出域名列表到文件(供扫描工具使用)
Step 2: 解析配置,获取启用的工具
Step 3: 串行执行扫描工具,运行端口扫描工具并实时解析输出到数据库(→ HostPortMapping
Step 3: 串行执行扫描工具,运行端口扫描工具并实时解析输出到数据库
Args:
scan_id: 扫描任务 ID
@@ -371,35 +327,15 @@ def port_scan_flow(
enabled_tools: 启用的工具配置字典
Returns:
dict: {
'success': bool,
'scan_id': int,
'target': str,
'scan_workspace_dir': str,
'domains_file': str,
'domain_count': int,
'processed_records': int,
'executed_tasks': list,
'tool_stats': {
'total': int, # 总工具数
'successful': int, # 成功工具数
'failed': int, # 失败工具数
'successful_tools': list[str], # 成功工具列表 ['naabu_active']
'failed_tools': list[dict], # 失败工具列表 [{'tool': 'naabu_passive', 'reason': '超时'}]
'details': dict # 详细执行结果(保留向后兼容)
}
}
dict: 扫描结果
Raises:
ValueError: 配置错误
RuntimeError: 执行失败
Note:
端口扫描工具(如 naabu会解析域名获取 IP输出 host + ip + port 三元组。
同一 host 可能对应多个 IPCDN、负载均衡因此使用三元映射表存储。
"""
try:
# 参数验证
wait_for_system_load(context="port_scan_flow")
if scan_id is None:
raise ValueError("scan_id 不能为空")
if not target_name:
@@ -410,35 +346,33 @@ def port_scan_flow(
raise ValueError("scan_workspace_dir 不能为空")
if not enabled_tools:
raise ValueError("enabled_tools 不能为空")
logger.info(
"="*60 + "\n" +
"开始端口扫描\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
"开始端口扫描 - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
)
user_log(scan_id, "port_scan", "Starting port scan")
# Step 0: 创建工作目录
from apps.scan.utils import setup_scan_directory
port_scan_dir = setup_scan_directory(scan_workspace_dir, 'port_scan')
# Step 1: 导出扫描目标列表到文件(根据 Target 类型自动决定内容)
targets_file, target_count, target_type = _export_scan_targets(target_id, port_scan_dir)
if target_count == 0:
logger.warning("目标下没有可扫描的地址,跳过端口扫描")
# Step 1: 导出主机列表
hosts_file, host_count, target_type = _export_hosts(target_id, port_scan_dir)
if host_count == 0:
logger.warning("跳过端口扫描:没有主机可扫描 - Scan ID: %s", scan_id)
user_log(scan_id, "port_scan", "Skipped: no hosts to scan", "warning")
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'scan_workspace_dir': scan_workspace_dir,
'targets_file': targets_file,
'target_count': 0,
'hosts_file': hosts_file,
'host_count': 0,
'target_type': target_type,
'processed_records': 0,
'executed_tasks': ['export_scan_targets'],
'executed_tasks': ['export_hosts'],
'tool_stats': {
'total': 0,
'successful': 0,
@@ -448,38 +382,35 @@ def port_scan_flow(
'details': {}
}
}
# Step 2: 工具配置信息
logger.info("Step 2: 工具配置信息")
logger.info(
"✓ 启用工具: %s",
', '.join(enabled_tools.keys())
)
logger.info("✓ 启用工具: %s", ', '.join(enabled_tools.keys()))
# Step 3: 串行执行扫描工具
logger.info("Step 3: 串行执行扫描工具")
tool_stats, processed_records, successful_tool_names, failed_tools = _run_scans_sequentially(
enabled_tools=enabled_tools,
domains_file=targets_file, # 现在是 targets_file兼容原参数名
domains_file=hosts_file,
port_scan_dir=port_scan_dir,
scan_id=scan_id,
target_id=target_id,
target_name=target_name
)
logger.info("="*60 + "\n✓ 端口扫描完成\n" + "="*60)
# 动态生成已执行的任务列表
executed_tasks = ['export_scan_targets', 'parse_config']
executed_tasks.extend([f'run_and_stream_save_ports ({tool})' for tool in tool_stats.keys()])
logger.info("✓ 端口扫描完成 - 发现端口: %d", processed_records)
user_log(scan_id, "port_scan", f"port_scan completed: found {processed_records} ports")
executed_tasks = ['export_hosts', 'parse_config']
executed_tasks.extend([f'run_and_stream_save_ports ({tool})' for tool in tool_stats])
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'scan_workspace_dir': scan_workspace_dir,
'targets_file': targets_file,
'target_count': target_count,
'hosts_file': hosts_file,
'host_count': host_count,
'target_type': target_type,
'processed_records': processed_records,
'executed_tasks': executed_tasks,
@@ -488,8 +419,8 @@ def port_scan_flow(
'successful': len(successful_tool_names),
'failed': len(failed_tools),
'successful_tools': successful_tool_names,
'failed_tools': failed_tools, # [{'tool': 'naabu_active', 'reason': '超时'}]
'details': tool_stats # 详细结果(保留向后兼容)
'failed_tools': failed_tools,
'details': tool_stats
}
}

View File

@@ -0,0 +1,208 @@
"""
截图 Flow
负责编排截图的完整流程:
1. 从数据库获取 URL 列表websites 和/或 endpoints
2. 批量截图并保存快照
3. 同步到资产表
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库获取 URL
2. Provider 模式:使用 TargetProvider 从任意数据源获取 URL
"""
import logging
from typing import Optional
from prefect import flow
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_completed,
on_scan_flow_failed,
on_scan_flow_running,
)
from apps.scan.providers import TargetProvider
from apps.scan.services.target_export_service import DataSource, get_urls_with_fallback
from apps.scan.tasks.screenshot import capture_screenshots_task
from apps.scan.utils import user_log, wait_for_system_load
logger = logging.getLogger(__name__)
# URL 来源到 DataSource 的映射
_SOURCE_MAPPING = {
'websites': DataSource.WEBSITE,
'endpoints': DataSource.ENDPOINT,
}
def _parse_screenshot_config(enabled_tools: dict) -> dict:
"""解析截图配置"""
playwright_config = enabled_tools.get('playwright', {})
return {
'concurrency': playwright_config.get('concurrency', 5),
'url_sources': playwright_config.get('url_sources', ['websites'])
}
def _map_url_sources_to_data_sources(url_sources: list[str]) -> list[str]:
"""将配置中的 url_sources 映射为 DataSource 常量"""
sources = []
for source in url_sources:
if source in _SOURCE_MAPPING:
sources.append(_SOURCE_MAPPING[source])
else:
logger.warning("未知的 URL 来源: %s,跳过", source)
# 添加默认回退(从 subdomain 构造)
sources.append(DataSource.DEFAULT)
return sources
def _collect_urls_from_provider(provider: TargetProvider) -> tuple[list[str], str, list[str]]:
"""从 Provider 收集 URL"""
logger.info("使用 Provider 模式获取 URL - Provider: %s", type(provider).__name__)
urls = list(provider.iter_urls())
blacklist_filter = provider.get_blacklist_filter()
if blacklist_filter:
urls = [url for url in urls if blacklist_filter.is_allowed(url)]
return urls, 'provider', ['provider']
def _collect_urls_from_database(
target_id: int,
url_sources: list[str]
) -> tuple[list[str], str, list[str]]:
"""从数据库收集 URL带黑名单过滤和回退"""
data_sources = _map_url_sources_to_data_sources(url_sources)
result = get_urls_with_fallback(target_id, sources=data_sources)
return result['urls'], result['source'], result['tried_sources']
def _build_empty_result(scan_id: int, target_name: str) -> dict:
"""构建空结果"""
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'total_urls': 0,
'successful': 0,
'failed': 0,
'synced': 0
}
@flow(
name="screenshot",
log_prints=True,
on_running=[on_scan_flow_running],
on_completion=[on_scan_flow_completed],
on_failure=[on_scan_flow_failed],
)
def screenshot_flow(
scan_id: int,
target_name: str,
target_id: int,
scan_workspace_dir: str,
enabled_tools: dict,
provider: Optional[TargetProvider] = None
) -> dict:
"""
截图 Flow
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库获取 URL
2. Provider 模式:使用 TargetProvider 从任意数据源获取 URL
Args:
scan_id: 扫描任务 ID
target_name: 目标名称
target_id: 目标 ID
scan_workspace_dir: 扫描工作空间目录
enabled_tools: 启用的工具配置
provider: TargetProvider 实例(新模式,可选)
Returns:
截图结果字典
"""
try:
# 负载检查:等待系统资源充足
wait_for_system_load(context="screenshot_flow")
mode = 'Provider' if provider else 'Legacy'
logger.info(
"开始截图扫描 - Scan ID: %s, Target: %s, Mode: %s",
scan_id, target_name, mode
)
user_log(scan_id, "screenshot", "Starting screenshot capture")
# Step 1: 解析配置
config = _parse_screenshot_config(enabled_tools)
concurrency = config['concurrency']
logger.info("截图配置 - 并发: %d, URL来源: %s", concurrency, config['url_sources'])
# Step 2: 收集 URL 列表
if provider is not None:
urls, source_info, tried_sources = _collect_urls_from_provider(provider)
else:
urls, source_info, tried_sources = _collect_urls_from_database(
target_id, config['url_sources']
)
logger.info(
"URL 收集完成 - 来源: %s, 数量: %d, 尝试过: %s",
source_info, len(urls), tried_sources
)
if not urls:
logger.warning("没有可截图的 URL跳过截图任务")
user_log(scan_id, "screenshot", "Skipped: no URLs to capture", "warning")
return _build_empty_result(scan_id, target_name)
user_log(
scan_id, "screenshot",
f"Found {len(urls)} URLs to capture (source: {source_info})"
)
# Step 3: 批量截图
logger.info("批量截图 - %d 个 URL", len(urls))
capture_result = capture_screenshots_task(
urls=urls,
scan_id=scan_id,
target_id=target_id,
config={'concurrency': concurrency}
)
# Step 4: 同步到资产表
logger.info("同步截图到资产表")
from apps.asset.services.screenshot_service import ScreenshotService
synced = ScreenshotService().sync_screenshots_to_asset(scan_id, target_id)
total = capture_result['total']
successful = capture_result['successful']
failed = capture_result['failed']
logger.info(
"✓ 截图完成 - 总数: %d, 成功: %d, 失败: %d, 同步: %d",
total, successful, failed, synced
)
user_log(
scan_id, "screenshot",
f"Screenshot completed: {successful}/{total} captured, {synced} synced"
)
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'total_urls': total,
'successful': successful,
'failed': failed,
'synced': synced
}
except Exception:
logger.exception("截图 Flow 失败")
user_log(scan_id, "screenshot", "Screenshot failed", "error")
raise

View File

@@ -1,4 +1,3 @@
"""
站点扫描 Flow
@@ -11,296 +10,319 @@
- 配置由 YAML 解析
"""
# Django 环境初始化(导入即生效)
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging
import os
import subprocess
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Callable
from typing import Optional
from prefect import flow
from apps.scan.tasks.site_scan import export_site_urls_task, run_and_stream_save_websites_task
# Django 环境初始化(导入即生效)
from apps.common.prefect_django_setup import setup_django_for_prefect # noqa: F401
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
on_scan_flow_running,
)
from apps.scan.utils import config_parser, build_scan_command
from apps.scan.tasks.site_scan import (
export_site_urls_task,
run_and_stream_save_websites_task,
)
from apps.scan.utils import build_scan_command, user_log, wait_for_system_load
logger = logging.getLogger(__name__)
def calculate_timeout_by_line_count(
tool_config: dict,
file_path: str,
base_per_time: int = 1,
min_timeout: int = 60
) -> int:
"""
根据文件行数计算 timeout
使用 wc -l 统计文件行数,根据行数和每行基础时间计算 timeout
Args:
tool_config: 工具配置字典(此函数未使用,但保持接口一致性)
file_path: 要统计行数的文件路径
base_per_time: 每行的基础时间默认1秒
min_timeout: 最小超时时间默认60秒
Returns:
int: 计算出的超时时间(秒),不低于 min_timeout
Example:
timeout = calculate_timeout_by_line_count(
tool_config={},
file_path='/path/to/urls.txt',
base_per_time=2
)
"""
@dataclass
class ScanContext:
"""扫描上下文,封装扫描参数"""
scan_id: int
target_id: int
target_name: str
site_scan_dir: Path
urls_file: str
total_urls: int
def _count_file_lines(file_path: str) -> int:
"""使用 wc -l 统计文件行数"""
try:
# 使用 wc -l 快速统计行数
result = subprocess.run(
['wc', '-l', file_path],
capture_output=True,
text=True,
check=True
)
# wc -l 输出格式:行数 + 空格 + 文件名
line_count = int(result.stdout.strip().split()[0])
# 计算 timeout行数 × 每行基础时间,不低于最小值
timeout = max(line_count * base_per_time, min_timeout)
logger.info(
f"timeout 自动计算: 文件={file_path}, "
f"行数={line_count}, 每行时间={base_per_time}秒, 最小值={min_timeout}秒, timeout={timeout}"
)
return timeout
except Exception as e:
# 如果 wc -l 失败,使用默认值
logger.warning(f"wc -l 计算行数失败: {e},使用默认 timeout: {min_timeout}")
return min_timeout
return int(result.stdout.strip().split()[0])
except (subprocess.CalledProcessError, ValueError, IndexError) as e:
logger.warning("wc -l 计算行数失败: %s,返回 0", e)
return 0
def _calculate_timeout_by_line_count(
file_path: str,
base_per_time: int = 1,
min_timeout: int = 60
) -> int:
"""
根据文件行数计算 timeout
Args:
file_path: 要统计行数的文件路径
base_per_time: 每行的基础时间默认1秒
min_timeout: 最小超时时间默认60秒
Returns:
int: 计算出的超时时间(秒),不低于 min_timeout
"""
line_count = _count_file_lines(file_path)
timeout = max(line_count * base_per_time, min_timeout)
logger.info(
"timeout 自动计算: 文件=%s, 行数=%d, 每行时间=%d秒, timeout=%d",
file_path, line_count, base_per_time, timeout
)
return timeout
def _export_site_urls(target_id: int, site_scan_dir: Path, target_name: str = None) -> tuple[str, int, int]:
def _export_site_urls(
target_id: int,
site_scan_dir: Path
) -> tuple[str, int, int]:
"""
导出站点 URL 到文件
Args:
target_id: 目标 ID
site_scan_dir: 站点扫描目录
target_name: 目标名称(用于懒加载时写入默认值)
Returns:
tuple: (urls_file, total_urls, association_count)
Raises:
ValueError: URL 数量为 0
"""
logger.info("Step 1: 导出站点URL列表")
urls_file = str(site_scan_dir / 'site_urls.txt')
export_result = export_site_urls_task(
target_id=target_id,
output_file=urls_file,
batch_size=1000 # 每次处理1000个子域名
batch_size=1000
)
total_urls = export_result['total_urls']
association_count = export_result['association_count'] # 主机端口关联数
association_count = export_result['association_count']
logger.info(
"✓ 站点URL导出完成 - 文件: %s, URL数量: %d, 关联数: %d",
export_result['output_file'],
total_urls,
association_count
export_result['output_file'], total_urls, association_count
)
if total_urls == 0:
logger.warning("目标下没有可用的站点URL无法执行站点扫描")
# 不抛出异常,由上层决定如何处理
# raise ValueError("目标下没有可用的站点URL无法执行站点扫描")
return export_result['output_file'], total_urls, association_count
def _get_tool_timeout(tool_config: dict, urls_file: str) -> int:
"""获取工具超时时间(支持 'auto' 动态计算)"""
config_timeout = tool_config.get('timeout', 300)
if config_timeout == 'auto':
return _calculate_timeout_by_line_count(urls_file, base_per_time=1)
dynamic_timeout = _calculate_timeout_by_line_count(urls_file, base_per_time=1)
return max(dynamic_timeout, config_timeout)
def _execute_single_tool(
tool_name: str,
tool_config: dict,
ctx: ScanContext
) -> Optional[dict]:
"""
执行单个扫描工具
Returns:
成功返回结果字典,失败返回 None
"""
# 构建命令
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='site_scan',
command_params={'url_file': ctx.urls_file},
tool_config=tool_config
)
except (ValueError, KeyError) as e:
logger.error("构建 %s 命令失败: %s", tool_name, e)
return None
timeout = _get_tool_timeout(tool_config, ctx.urls_file)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = ctx.site_scan_dir / f"{tool_name}_{timestamp}.log"
logger.info(
"开始执行 %s 站点扫描 - URL数: %d, 超时: %ds",
tool_name, ctx.total_urls, timeout
)
user_log(ctx.scan_id, "site_scan", f"Running {tool_name}: {command}")
try:
result = run_and_stream_save_websites_task(
cmd=command,
tool_name=tool_name,
scan_id=ctx.scan_id,
target_id=ctx.target_id,
cwd=str(ctx.site_scan_dir),
shell=True,
timeout=timeout,
log_file=str(log_file)
)
tool_created = result.get('created_websites', 0)
skipped = result.get('skipped_no_subdomain', 0) + result.get('skipped_failed', 0)
logger.info(
"✓ 工具 %s 完成 - 处理: %d, 创建: %d, 跳过: %d",
tool_name, result.get('processed_records', 0), tool_created, skipped
)
user_log(
ctx.scan_id, "site_scan",
f"{tool_name} completed: found {tool_created} websites"
)
return {'command': command, 'result': result, 'timeout': timeout}
except subprocess.TimeoutExpired:
logger.warning(
"⚠️ 工具 %s 执行超时 - 超时配置: %d秒 (超时前数据已保存)",
tool_name, timeout
)
user_log(
ctx.scan_id, "site_scan",
f"{tool_name} failed: timeout after {timeout}s", "error"
)
except (OSError, RuntimeError) as exc:
logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True)
user_log(ctx.scan_id, "site_scan", f"{tool_name} failed: {exc}", "error")
return None
def _run_scans_sequentially(
enabled_tools: dict,
urls_file: str,
total_urls: int,
site_scan_dir: Path,
scan_id: int,
target_id: int,
target_name: str
ctx: ScanContext
) -> tuple[dict, int, list, list]:
"""
串行执行站点扫描任务
Args:
enabled_tools: 已启用的工具配置字典
urls_file: URL 文件路径
total_urls: URL 总数
site_scan_dir: 站点扫描目录
scan_id: 扫描任务 ID
target_id: 目标 ID
target_name: 目标名称(用于错误日志)
Returns:
tuple: (tool_stats, processed_records, successful_tool_names, failed_tools)
Raises:
RuntimeError: 所有工具均失败
tuple: (tool_stats, processed_records, successful_tools, failed_tools)
"""
tool_stats = {}
processed_records = 0
failed_tools = []
for tool_name, tool_config in enabled_tools.items():
# 1. 构建完整命令(变量替换)
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='site_scan',
command_params={
'url_file': urls_file
},
tool_config=tool_config
)
except Exception as e:
reason = f"命令构建失败: {str(e)}"
logger.error(f"构建 {tool_name} 命令失败: {e}")
failed_tools.append({'tool': tool_name, 'reason': reason})
continue
# 2. 获取超时时间(支持 'auto' 动态计算)
config_timeout = tool_config.get('timeout', 300)
if config_timeout == 'auto':
# 动态计算超时时间
timeout = calculate_timeout_by_line_count(tool_config, urls_file, base_per_time=1)
logger.info(f"✓ 工具 {tool_name} 动态计算 timeout: {timeout}")
result = _execute_single_tool(tool_name, tool_config, ctx)
if result:
tool_stats[tool_name] = result
processed_records += result['result'].get('processed_records', 0)
else:
# 使用配置的超时时间和动态计算的较大值
dynamic_timeout = calculate_timeout_by_line_count(tool_config, urls_file, base_per_time=1)
timeout = max(dynamic_timeout, config_timeout)
# 2.1 生成日志文件路径(类似端口扫描)
from datetime import datetime
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = site_scan_dir / f"{tool_name}_{timestamp}.log"
logger.info(
"开始执行 %s 站点扫描 - URL数: %d, 最终超时: %ds",
tool_name, total_urls, timeout
)
# 3. 执行扫描任务
try:
# 流式执行扫描并实时保存结果
result = run_and_stream_save_websites_task(
cmd=command,
tool_name=tool_name, # 新增:工具名称
scan_id=scan_id,
target_id=target_id,
cwd=str(site_scan_dir),
shell=True,
batch_size=1000,
timeout=timeout,
log_file=str(log_file) # 新增:日志文件路径
)
tool_stats[tool_name] = {
'command': command,
'result': result,
'timeout': timeout
}
processed_records += result.get('processed_records', 0)
logger.info(
"✓ 工具 %s 流式处理完成 - 处理记录: %d, 创建站点: %d, 跳过: %d",
tool_name,
result.get('processed_records', 0),
result.get('created_websites', 0),
result.get('skipped_no_subdomain', 0) + result.get('skipped_failed', 0)
)
except subprocess.TimeoutExpired as exc:
# 超时异常单独处理
reason = f"执行超时(配置: {timeout}秒)"
failed_tools.append({'tool': tool_name, 'reason': reason})
logger.warning(
"⚠️ 工具 %s 执行超时 - 超时配置: %d\n"
"注意:超时前已解析的站点数据已保存到数据库,但扫描未完全完成。",
tool_name, timeout
)
except Exception as exc:
# 其他异常
failed_tools.append({'tool': tool_name, 'reason': str(exc)})
logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True)
failed_tools.append({'tool': tool_name, 'reason': '执行失败'})
if failed_tools:
logger.warning(
"以下扫描工具执行失败: %s",
', '.join([f['tool'] for f in failed_tools])
', '.join(f['tool'] for f in failed_tools)
)
if not tool_stats:
error_details = "; ".join([f"{f['tool']}: {f['reason']}" for f in failed_tools])
logger.warning("所有站点扫描工具均失败 - 目标: %s, 失败工具: %s", target_name, error_details)
# 返回空结果,不抛出异常,让扫描继续
logger.warning(
"所有站点扫描工具均失败 - 目标: %s", ctx.target_name
)
return {}, 0, [], failed_tools
# 动态计算成功的工具列表
successful_tool_names = [name for name in enabled_tools.keys()
if name not in [f['tool'] for f in failed_tools]]
successful_tools = [
name for name in enabled_tools
if name not in {f['tool'] for f in failed_tools}
]
logger.info(
"串行站点扫描执行完成 - 成功: %d/%d (成功: %s, 失败: %s)",
len(tool_stats), len(enabled_tools),
', '.join(successful_tool_names) if successful_tool_names else '',
', '.join([f['tool'] for f in failed_tools]) if failed_tools else ''
"✓ 站点扫描执行完成 - 成功: %d/%d",
len(tool_stats), len(enabled_tools)
)
return tool_stats, processed_records, successful_tool_names, failed_tools
return tool_stats, processed_records, successful_tools, failed_tools
def calculate_timeout(url_count: int, base: int = 600, per_url: int = 1) -> int:
"""
根据 URL 数量动态计算扫描超时时间
def _build_empty_result(
scan_id: int,
target_name: str,
scan_workspace_dir: str,
urls_file: str,
association_count: int
) -> dict:
"""构建空结果(无 URL 可扫描时)"""
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'scan_workspace_dir': scan_workspace_dir,
'urls_file': urls_file,
'total_urls': 0,
'association_count': association_count,
'processed_records': 0,
'created_websites': 0,
'skipped_no_subdomain': 0,
'skipped_failed': 0,
'executed_tasks': ['export_site_urls'],
'tool_stats': {
'total': 0,
'successful': 0,
'failed': 0,
'successful_tools': [],
'failed_tools': [],
'details': {}
}
}
规则:
- 基础时间:默认 600 秒10 分钟)
- 每个 URL 额外增加:默认 1 秒
Args:
url_count: URL 数量,必须为正整数
base: 基础超时时间(秒),默认 600
per_url: 每个 URL 增加的时间(秒),默认 1
def _aggregate_tool_results(tool_stats: dict) -> tuple[int, int, int]:
"""汇总工具结果"""
total_created = sum(
s['result'].get('created_websites', 0) for s in tool_stats.values()
)
total_skipped_no_subdomain = sum(
s['result'].get('skipped_no_subdomain', 0) for s in tool_stats.values()
)
total_skipped_failed = sum(
s['result'].get('skipped_failed', 0) for s in tool_stats.values()
)
return total_created, total_skipped_no_subdomain, total_skipped_failed
Returns:
int: 计算得到的超时时间(秒),不超过 max_timeout
Raises:
ValueError: 当 url_count 为负数或 0 时抛出异常
"""
if url_count < 0:
raise ValueError(f"URL数量不能为负数: {url_count}")
if url_count == 0:
raise ValueError("URL数量不能为0")
timeout = base + int(url_count * per_url)
# 不设置上限,由调用方根据需要控制
return timeout
def _validate_flow_params(
scan_id: int,
target_name: str,
target_id: int,
scan_workspace_dir: str
) -> None:
"""验证 Flow 参数"""
if scan_id is None:
raise ValueError("scan_id 不能为空")
if not target_name:
raise ValueError("target_name 不能为空")
if target_id is None:
raise ValueError("target_id 不能为空")
if not scan_workspace_dir:
raise ValueError("scan_workspace_dir 不能为空")
@flow(
name="site_scan",
name="site_scan",
log_prints=True,
on_running=[on_scan_flow_running],
on_completion=[on_scan_flow_completed],
@@ -315,135 +337,83 @@ def site_scan_flow(
) -> dict:
"""
站点扫描 Flow
主要功能:
1. 从target获取所有子域名与其对应的端口号拼接成URL写入文件
2. 用httpx进行批量请求并实时保存到数据库流式处理
工作流程:
Step 0: 创建工作目录
Step 1: 导出站点 URL 列表
Step 2: 解析配置,获取启用的工具
Step 3: 串行执行扫描工具并实时保存结果
Args:
scan_id: 扫描任务 ID
target_name: 目标名称
target_id: 目标 ID
scan_workspace_dir: 扫描工作空间目录
enabled_tools: 启用的工具配置字典
Returns:
dict: {
'success': bool,
'scan_id': int,
'target': str,
'scan_workspace_dir': str,
'urls_file': str,
'total_urls': int,
'association_count': int,
'processed_records': int,
'created_websites': int,
'skipped_no_subdomain': int,
'skipped_failed': int,
'executed_tasks': list,
'tool_stats': {
'total': int,
'successful': int,
'failed': int,
'successful_tools': list[str],
'failed_tools': list[dict]
}
}
dict: 扫描结果
Raises:
ValueError: 配置错误
RuntimeError: 执行失败
"""
try:
wait_for_system_load(context="site_scan_flow")
logger.info(
"="*60 + "\n" +
"开始站点扫描\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
"开始站点扫描 - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
)
# 参数验证
if scan_id is None:
raise ValueError("scan_id 不能为空")
if not target_name:
raise ValueError("target_name 不能为空")
if target_id is None:
raise ValueError("target_id 不能为空")
if not scan_workspace_dir:
raise ValueError("scan_workspace_dir 不能为空")
_validate_flow_params(scan_id, target_name, target_id, scan_workspace_dir)
user_log(scan_id, "site_scan", "Starting site scan")
# Step 0: 创建工作目录
from apps.scan.utils import setup_scan_directory
site_scan_dir = setup_scan_directory(scan_workspace_dir, 'site_scan')
# Step 1: 导出站点 URL
urls_file, total_urls, association_count = _export_site_urls(
target_id, site_scan_dir, target_name
target_id, site_scan_dir
)
if total_urls == 0:
logger.warning("目标下没有可用的站点URL,跳过站点扫描")
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'scan_workspace_dir': scan_workspace_dir,
'urls_file': urls_file,
'total_urls': 0,
'association_count': association_count,
'processed_records': 0,
'created_websites': 0,
'skipped_no_subdomain': 0,
'skipped_failed': 0,
'executed_tasks': ['export_site_urls'],
'tool_stats': {
'total': 0,
'successful': 0,
'failed': 0,
'successful_tools': [],
'failed_tools': [],
'details': {}
}
}
logger.warning("跳过站点扫描:没有站点 URL 可扫描 - Scan ID: %s", scan_id)
user_log(scan_id, "site_scan", "Skipped: no site URLs to scan", "warning")
return _build_empty_result(
scan_id, target_name, scan_workspace_dir, urls_file, association_count
)
# Step 2: 工具配置信息
logger.info("Step 2: 工具配置信息")
logger.info(
"✓ 启用工具: %s",
', '.join(enabled_tools.keys())
)
logger.info("✓ 启用工具: %s", ', '.join(enabled_tools))
# Step 3: 串行执行扫描工具
logger.info("Step 3: 串行执行扫描工具并实时保存结果")
tool_stats, processed_records, successful_tool_names, failed_tools = _run_scans_sequentially(
enabled_tools=enabled_tools,
urls_file=urls_file,
total_urls=total_urls,
site_scan_dir=site_scan_dir,
ctx = ScanContext(
scan_id=scan_id,
target_id=target_id,
target_name=target_name
target_name=target_name,
site_scan_dir=site_scan_dir,
urls_file=urls_file,
total_urls=total_urls
)
logger.info("="*60 + "\n✓ 站点扫描完成\n" + "="*60)
# 动态生成已执行的任务列表
tool_stats, processed_records, successful_tools, failed_tools = \
_run_scans_sequentially(enabled_tools, ctx)
# 汇总结果
executed_tasks = ['export_site_urls', 'parse_config']
executed_tasks.extend([f'run_and_stream_save_websites ({tool})' for tool in tool_stats.keys()])
# 汇总所有工具的结果
total_created = sum(stats['result'].get('created_websites', 0) for stats in tool_stats.values())
total_skipped_no_subdomain = sum(stats['result'].get('skipped_no_subdomain', 0) for stats in tool_stats.values())
total_skipped_failed = sum(stats['result'].get('skipped_failed', 0) for stats in tool_stats.values())
executed_tasks.extend(
f'run_and_stream_save_websites ({tool})' for tool in tool_stats
)
total_created, total_skipped_no_sub, total_skipped_failed = \
_aggregate_tool_results(tool_stats)
logger.info("✓ 站点扫描完成 - 创建站点: %d", total_created)
user_log(
scan_id, "site_scan",
f"site_scan completed: found {total_created} websites"
)
return {
'success': True,
'scan_id': scan_id,
@@ -454,25 +424,20 @@ def site_scan_flow(
'association_count': association_count,
'processed_records': processed_records,
'created_websites': total_created,
'skipped_no_subdomain': total_skipped_no_subdomain,
'skipped_no_subdomain': total_skipped_no_sub,
'skipped_failed': total_skipped_failed,
'executed_tasks': executed_tasks,
'tool_stats': {
'total': len(enabled_tools),
'successful': len(successful_tool_names),
'successful': len(successful_tools),
'failed': len(failed_tools),
'successful_tools': successful_tool_names,
'successful_tools': successful_tools,
'failed_tools': failed_tools,
'details': tool_stats
}
}
except ValueError as e:
logger.error("配置错误: %s", e)
except ValueError:
raise
except RuntimeError as e:
logger.error("运行时错误: %s", e)
except RuntimeError:
raise
except Exception as e:
logger.exception("站点扫描失败: %s", e)
raise

File diff suppressed because it is too large Load Diff

View File

@@ -59,6 +59,8 @@ def domain_name_url_fetch_flow(
- IP 和 CIDR 类型会自动跳过waymore 等工具不支持)
- 工具会自动收集 *.target_name 的所有历史 URL无需遍历子域名
"""
from apps.scan.utils import user_log
try:
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
@@ -145,6 +147,9 @@ def domain_name_url_fetch_flow(
timeout,
)
# 记录工具开始执行日志
user_log(scan_id, "url_fetch", f"Running {tool_name}: {command}")
future = run_url_fetcher_task.submit(
tool_name=tool_name,
command=command,
@@ -163,22 +168,28 @@ def domain_name_url_fetch_flow(
if result and result.get("success"):
result_files.append(result["output_file"])
successful_tools.append(tool_name)
url_count = result.get("url_count", 0)
logger.info(
"✓ 工具 %s 执行成功 - 发现 URL: %d",
tool_name,
result.get("url_count", 0),
url_count,
)
user_log(scan_id, "url_fetch", f"{tool_name} completed: found {url_count} urls")
else:
reason = "未生成结果或无有效 URL"
failed_tools.append(
{
"tool": tool_name,
"reason": "未生成结果或无有效 URL",
"reason": reason,
}
)
logger.warning("⚠️ 工具 %s 未生成有效结果", tool_name)
user_log(scan_id, "url_fetch", f"{tool_name} failed: {reason}", "error")
except Exception as e:
failed_tools.append({"tool": tool_name, "reason": str(e)})
reason = str(e)
failed_tools.append({"tool": tool_name, "reason": reason})
logger.warning("⚠️ 工具 %s 执行失败: %s", tool_name, e)
user_log(scan_id, "url_fetch", f"{tool_name} failed: {reason}", "error")
logger.info(
"基于 domain_name 的 URL 获取完成 - 成功工具: %s, 失败工具: %s",

View File

@@ -10,21 +10,18 @@ URL Fetch 主 Flow
- 统一进行 httpx 验证(如果启用)
"""
# Django 环境初始化
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging
import os
from pathlib import Path
from datetime import datetime
from pathlib import Path
from prefect import flow
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
on_scan_flow_running,
)
from apps.scan.utils import user_log, wait_for_system_load
from .domain_name_url_fetch_flow import domain_name_url_fetch_flow
from .sites_url_fetch_flow import sites_url_fetch_flow
@@ -42,13 +39,10 @@ SITES_FILE_TOOLS = {'katana'}
POST_PROCESS_TOOLS = {'uro', 'httpx'}
def _classify_tools(enabled_tools: dict) -> tuple[dict, dict, dict, dict]:
"""
将启用的工具按输入类型分类
Returns:
tuple: (domain_name_tools, sites_file_tools, uro_config, httpx_config)
"""
@@ -75,23 +69,23 @@ def _classify_tools(enabled_tools: dict) -> tuple[dict, dict, dict, dict]:
def _merge_and_deduplicate_urls(result_files: list, url_fetch_dir: Path) -> tuple[str, int]:
"""合并并去重 URL"""
from apps.scan.tasks.url_fetch import merge_and_deduplicate_urls_task
merged_file = merge_and_deduplicate_urls_task(
result_files=result_files,
result_dir=str(url_fetch_dir)
)
# 统计唯一 URL 数量
unique_url_count = 0
if Path(merged_file).exists():
with open(merged_file, 'r') as f:
with open(merged_file, 'r', encoding='utf-8') as f:
unique_url_count = sum(1 for line in f if line.strip())
logger.info(
"✓ URL 合并去重完成 - 合并文件: %s, 唯一 URL 数: %d",
merged_file, unique_url_count
)
return merged_file, unique_url_count
@@ -102,12 +96,12 @@ def _clean_urls_with_uro(
) -> tuple[str, int, int]:
"""使用 uro 清理合并后的 URL 列表"""
from apps.scan.tasks.url_fetch import clean_urls_task
raw_timeout = uro_config.get('timeout', 60)
whitelist = uro_config.get('whitelist')
blacklist = uro_config.get('blacklist')
filters = uro_config.get('filters')
# 计算超时时间
if isinstance(raw_timeout, str) and raw_timeout == 'auto':
timeout = calculate_timeout_by_line_count(
@@ -123,7 +117,7 @@ def _clean_urls_with_uro(
except (TypeError, ValueError):
logger.warning("uro timeout 配置无效(%s),使用默认 60 秒", raw_timeout)
timeout = 60
result = clean_urls_task(
input_file=merged_file,
output_dir=str(url_fetch_dir),
@@ -132,12 +126,12 @@ def _clean_urls_with_uro(
blacklist=blacklist,
filters=filters
)
if result['success']:
return result['output_file'], result['output_count'], result['removed_count']
else:
logger.warning("uro 清理失败: %s,使用原始合并文件", result.get('error', '未知错误'))
return merged_file, result['input_count'], 0
logger.warning("uro 清理失败: %s,使用原始合并文件", result.get('error', '未知错误'))
return merged_file, result['input_count'], 0
def _validate_and_stream_save_urls(
@@ -150,25 +144,25 @@ def _validate_and_stream_save_urls(
"""使用 httpx 验证 URL 存活并流式保存到数据库"""
from apps.scan.utils import build_scan_command
from apps.scan.tasks.url_fetch import run_and_stream_save_urls_task
logger.info("开始使用 httpx 验证 URL 存活状态...")
# 统计待验证的 URL 数量
try:
with open(merged_file, 'r') as f:
with open(merged_file, 'r', encoding='utf-8') as f:
url_count = sum(1 for _ in f)
logger.info("待验证 URL 数量: %d", url_count)
except Exception as e:
except OSError as e:
logger.error("读取 URL 文件失败: %s", e)
return 0
if url_count == 0:
logger.warning("没有需要验证的 URL")
return 0
# 构建 httpx 命令
command_params = {'url_file': merged_file}
try:
command = build_scan_command(
tool_name='httpx',
@@ -176,21 +170,19 @@ def _validate_and_stream_save_urls(
command_params=command_params,
tool_config=httpx_config
)
except Exception as e:
except (ValueError, KeyError) as e:
logger.error("构建 httpx 命令失败: %s", e)
logger.warning("降级处理:将直接保存所有 URL不验证存活")
return _save_urls_to_database(merged_file, scan_id, target_id)
# 计算超时时间
raw_timeout = httpx_config.get('timeout', 'auto')
timeout = 3600
if isinstance(raw_timeout, str) and raw_timeout == 'auto':
# 按 URL 行数计算超时时间:每行 3 秒,最小 60 秒
timeout = max(60, url_count * 3)
logger.info(
"自动计算 httpx 超时时间(按行数,每行 3 秒,最小 60 秒): url_count=%d, timeout=%d",
url_count,
timeout,
url_count, timeout
)
else:
try:
@@ -198,50 +190,44 @@ def _validate_and_stream_save_urls(
except (TypeError, ValueError):
timeout = 3600
logger.info("使用配置的 httpx 超时时间: %d", timeout)
# 生成日志文件路径
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = url_fetch_dir / f"httpx_validation_{timestamp}.log"
# 流式执行
try:
result = run_and_stream_save_urls_task(
cmd=command,
tool_name='httpx',
scan_id=scan_id,
target_id=target_id,
cwd=str(url_fetch_dir),
shell=True,
batch_size=500,
timeout=timeout,
log_file=str(log_file)
)
saved = result.get('saved_urls', 0)
logger.info(
"✓ httpx 验证完成 - 存活 URL: %d (%.1f%%)",
saved, (saved / url_count * 100) if url_count > 0 else 0
)
return saved
except Exception as e:
logger.error("httpx 流式验证失败: %s", e, exc_info=True)
raise
result = run_and_stream_save_urls_task(
cmd=command,
tool_name='httpx',
scan_id=scan_id,
target_id=target_id,
cwd=str(url_fetch_dir),
shell=True,
timeout=timeout,
log_file=str(log_file)
)
saved = result.get('saved_urls', 0)
logger.info(
"✓ httpx 验证完成 - 存活 URL: %d (%.1f%%)",
saved, (saved / url_count * 100) if url_count > 0 else 0
)
return saved
def _save_urls_to_database(merged_file: str, scan_id: int, target_id: int) -> int:
"""保存 URL 到数据库(不验证存活)"""
from apps.scan.tasks.url_fetch import save_urls_task
result = save_urls_task(
urls_file=merged_file,
scan_id=scan_id,
target_id=target_id
)
saved_count = result.get('saved_urls', 0)
logger.info("✓ URL 保存完成 - 保存数量: %d", saved_count)
return saved_count
@@ -261,7 +247,7 @@ def url_fetch_flow(
) -> dict:
"""
URL 获取主 Flow
执行流程:
1. 准备工作目录
2. 按输入类型分类工具domain_name / sites_file / 后处理)
@@ -271,34 +257,32 @@ def url_fetch_flow(
4. 合并所有子 Flow 的结果并去重
5. uro 去重(如果启用)
6. httpx 验证(如果启用)
Args:
scan_id: 扫描 ID
target_name: 目标名称
target_id: 目标 ID
scan_workspace_dir: 扫描工作目录
enabled_tools: 启用的工具配置
Returns:
dict: 扫描结果
"""
try:
# 负载检查:等待系统资源充足
wait_for_system_load(context="url_fetch_flow")
logger.info(
"="*60 + "\n" +
"开始 URL 获取扫描\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
"开始 URL 获取扫描 - Scan ID: %s, Target: %s, Workspace: %s",
scan_id, target_name, scan_workspace_dir
)
user_log(scan_id, "url_fetch", "Starting URL fetch")
# Step 1: 准备工作目录
logger.info("Step 1: 准备工作目录")
from apps.scan.utils import setup_scan_directory
url_fetch_dir = setup_scan_directory(scan_workspace_dir, 'url_fetch')
# Step 2: 分类工具(按输入类型)
logger.info("Step 2: 分类工具")
domain_name_tools, sites_file_tools, uro_config, httpx_config = _classify_tools(enabled_tools)
logger.info(
@@ -315,15 +299,14 @@ def url_fetch_flow(
"URL Fetch 流程需要至少启用一个 URL 获取工具(如 waymore, katana"
"httpx 和 uro 仅用于后处理,不能单独使用。"
)
# Step 3: 并行执行子 Flow
# Step 3: 执行子 Flow
all_result_files = []
all_failed_tools = []
all_successful_tools = []
# 3a: 基于 domain_nametarget_name 的 URL 被动收集(如 waymore
# 3a: 基于 domain_name 的 URL 被动收集(如 waymore
if domain_name_tools:
logger.info("Step 3a: 执行基于 domain_name 的 URL 被动收集子 Flow")
tn_result = domain_name_url_fetch_flow(
scan_id=scan_id,
target_id=target_id,
@@ -334,10 +317,9 @@ def url_fetch_flow(
all_result_files.extend(tn_result.get('result_files', []))
all_failed_tools.extend(tn_result.get('failed_tools', []))
all_successful_tools.extend(tn_result.get('successful_tools', []))
# 3b: 爬虫(以 sites_file 为输入)
if sites_file_tools:
logger.info("Step 3b: 执行爬虫子 Flow")
crawl_result = sites_url_fetch_flow(
scan_id=scan_id,
target_id=target_id,
@@ -348,12 +330,13 @@ def url_fetch_flow(
all_result_files.extend(crawl_result.get('result_files', []))
all_failed_tools.extend(crawl_result.get('failed_tools', []))
all_successful_tools.extend(crawl_result.get('successful_tools', []))
# 检查是否有成功的工具
if not all_result_files:
error_details = "; ".join([f"{f['tool']}: {f['reason']}" for f in all_failed_tools])
error_details = "; ".join([
"%s: %s" % (f['tool'], f['reason']) for f in all_failed_tools
])
logger.warning("所有 URL 获取工具均失败 - 目标: %s, 失败详情: %s", target_name, error_details)
# 返回空结果,不抛出异常,让扫描继续
return {
'success': True,
'scan_id': scan_id,
@@ -364,31 +347,24 @@ def url_fetch_flow(
'successful_tools': [],
'message': '所有 URL 获取工具均无结果'
}
# Step 4: 合并并去重 URL
logger.info("Step 4: 合并并去重 URL")
merged_file, unique_url_count = _merge_and_deduplicate_urls(
merged_file, _ = _merge_and_deduplicate_urls(
result_files=all_result_files,
url_fetch_dir=url_fetch_dir
)
# Step 5: 使用 uro 清理 URL如果启用
url_file_for_validation = merged_file
uro_removed_count = 0
if uro_config and uro_config.get('enabled', False):
logger.info("Step 5: 使用 uro 清理 URL")
url_file_for_validation, cleaned_count, uro_removed_count = _clean_urls_with_uro(
url_file_for_validation, _, _ = _clean_urls_with_uro(
merged_file=merged_file,
uro_config=uro_config,
url_fetch_dir=url_fetch_dir
)
else:
logger.info("Step 5: 跳过 uro 清理(未启用)")
# Step 6: 使用 httpx 验证存活并保存(如果启用)
if httpx_config and httpx_config.get('enabled', False):
logger.info("Step 6: 使用 httpx 验证 URL 存活并流式保存")
saved_count = _validate_and_stream_save_urls(
merged_file=url_file_for_validation,
httpx_config=httpx_config,
@@ -397,15 +373,16 @@ def url_fetch_flow(
target_id=target_id
)
else:
logger.info("Step 6: 保存到数据库(未启用 httpx 验证)")
saved_count = _save_urls_to_database(
merged_file=url_file_for_validation,
scan_id=scan_id,
target_id=target_id
)
logger.info("="*60 + "\n✓ URL 获取扫描完成\n" + "="*60)
# 记录 Flow 完成
logger.info("✓ URL 获取完成 - 保存 endpoints: %d", saved_count)
user_log(scan_id, "url_fetch", "url_fetch completed: found %d endpoints" % saved_count)
# 构建已执行的任务列表
executed_tasks = ['setup_directory', 'classify_tools']
if domain_name_tools:
@@ -419,7 +396,7 @@ def url_fetch_flow(
executed_tasks.append('httpx_validation_and_save')
else:
executed_tasks.append('save_urls')
return {
'success': True,
'scan_id': scan_id,
@@ -435,7 +412,7 @@ def url_fetch_flow(
'failed_tools': [f['tool'] for f in all_failed_tools]
}
}
except Exception as e:
logger.error("URL 获取扫描失败: %s", e, exc_info=True)
raise

View File

@@ -116,7 +116,8 @@ def sites_url_fetch_flow(
tools=enabled_tools,
input_file=sites_file,
input_type="sites_file",
output_dir=output_path
output_dir=output_path,
scan_id=scan_id
)
logger.info(

View File

@@ -152,7 +152,8 @@ def run_tools_parallel(
tools: dict,
input_file: str,
input_type: str,
output_dir: Path
output_dir: Path,
scan_id: int
) -> tuple[list, list, list]:
"""
并行执行工具列表
@@ -162,11 +163,13 @@ def run_tools_parallel(
input_file: 输入文件路径
input_type: 输入类型
output_dir: 输出目录
scan_id: 扫描任务 ID用于记录日志
Returns:
tuple: (result_files, failed_tools, successful_tool_names)
"""
from apps.scan.tasks.url_fetch import run_url_fetcher_task
from apps.scan.utils import user_log
futures: dict[str, object] = {}
failed_tools: list[dict] = []
@@ -192,6 +195,9 @@ def run_tools_parallel(
exec_params["timeout"],
)
# 记录工具开始执行日志
user_log(scan_id, "url_fetch", f"Running {tool_name}: {exec_params['command']}")
# 提交并行任务
future = run_url_fetcher_task.submit(
tool_name=tool_name,
@@ -208,22 +214,28 @@ def run_tools_parallel(
result = future.result()
if result and result['success']:
result_files.append(result['output_file'])
url_count = result['url_count']
logger.info(
"✓ 工具 %s 执行成功 - 发现 URL: %d",
tool_name, result['url_count']
tool_name, url_count
)
user_log(scan_id, "url_fetch", f"{tool_name} completed: found {url_count} urls")
else:
reason = '未生成结果或无有效URL'
failed_tools.append({
'tool': tool_name,
'reason': '未生成结果或无有效URL'
'reason': reason
})
logger.warning("⚠️ 工具 %s 未生成有效结果", tool_name)
user_log(scan_id, "url_fetch", f"{tool_name} failed: {reason}", "error")
except Exception as e:
reason = str(e)
failed_tools.append({
'tool': tool_name,
'reason': str(e)
'reason': reason
})
logger.warning("⚠️ 工具 %s 执行失败: %s", tool_name, e)
user_log(scan_id, "url_fetch", f"{tool_name} failed: {reason}", "error")
# 计算成功的工具列表
failed_tool_names = [f['tool'] for f in failed_tools]

View File

@@ -12,7 +12,7 @@ from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_completed,
on_scan_flow_failed,
)
from apps.scan.utils import build_scan_command, ensure_nuclei_templates_local
from apps.scan.utils import build_scan_command, ensure_nuclei_templates_local, user_log
from apps.scan.tasks.vuln_scan import (
export_endpoints_task,
run_vuln_tool_task,
@@ -141,6 +141,7 @@ def endpoints_vuln_scan_flow(
# Dalfox XSS 使用流式任务,一边解析一边保存漏洞结果
if tool_name == "dalfox_xss":
logger.info("开始执行漏洞扫描工具 %s(流式保存漏洞结果,已提交任务)", tool_name)
user_log(scan_id, "vuln_scan", f"Running {tool_name}: {command}")
future = run_and_stream_save_dalfox_vulns_task.submit(
cmd=command,
tool_name=tool_name,
@@ -163,6 +164,7 @@ def endpoints_vuln_scan_flow(
elif tool_name == "nuclei":
# Nuclei 使用流式任务
logger.info("开始执行漏洞扫描工具 %s(流式保存漏洞结果,已提交任务)", tool_name)
user_log(scan_id, "vuln_scan", f"Running {tool_name}: {command}")
future = run_and_stream_save_nuclei_vulns_task.submit(
cmd=command,
tool_name=tool_name,
@@ -185,6 +187,7 @@ def endpoints_vuln_scan_flow(
else:
# 其他工具仍使用非流式执行逻辑
logger.info("开始执行漏洞扫描工具 %s(已提交任务)", tool_name)
user_log(scan_id, "vuln_scan", f"Running {tool_name}: {command}")
future = run_vuln_tool_task.submit(
tool_name=tool_name,
command=command,
@@ -203,24 +206,34 @@ def endpoints_vuln_scan_flow(
# 统一收集所有工具的执行结果
for tool_name, meta in tool_futures.items():
future = meta["future"]
result = future.result()
try:
result = future.result()
if meta["mode"] == "streaming":
tool_results[tool_name] = {
"command": meta["command"],
"timeout": meta["timeout"],
"processed_records": result.get("processed_records"),
"created_vulns": result.get("created_vulns"),
"command_log_file": meta["log_file"],
}
else:
tool_results[tool_name] = {
"command": meta["command"],
"timeout": meta["timeout"],
"duration": result.get("duration"),
"returncode": result.get("returncode"),
"command_log_file": result.get("command_log_file"),
}
if meta["mode"] == "streaming":
created_vulns = result.get("created_vulns", 0)
tool_results[tool_name] = {
"command": meta["command"],
"timeout": meta["timeout"],
"processed_records": result.get("processed_records"),
"created_vulns": created_vulns,
"command_log_file": meta["log_file"],
}
logger.info("✓ 工具 %s 执行完成 - 漏洞: %d", tool_name, created_vulns)
user_log(scan_id, "vuln_scan", f"{tool_name} completed: found {created_vulns} vulnerabilities")
else:
tool_results[tool_name] = {
"command": meta["command"],
"timeout": meta["timeout"],
"duration": result.get("duration"),
"returncode": result.get("returncode"),
"command_log_file": result.get("command_log_file"),
}
logger.info("✓ 工具 %s 执行完成 - returncode=%s", tool_name, result.get("returncode"))
user_log(scan_id, "vuln_scan", f"{tool_name} completed")
except Exception as e:
reason = str(e)
logger.error("工具 %s 执行失败: %s", tool_name, e, exc_info=True)
user_log(scan_id, "vuln_scan", f"{tool_name} failed: {reason}", "error")
return {
"success": True,

View File

@@ -1,5 +1,6 @@
from apps.common.prefect_django_setup import setup_django_for_prefect
"""
漏洞扫描主 Flow
"""
import logging
from typing import Dict, Tuple
@@ -11,6 +12,7 @@ from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_failed,
)
from apps.scan.configs.command_templates import get_command_template
from apps.scan.utils import user_log, wait_for_system_load
from .endpoints_vuln_scan_flow import endpoints_vuln_scan_flow
@@ -61,6 +63,9 @@ def vuln_scan_flow(
- nuclei: 通用漏洞扫描(流式保存,支持模板 commit hash 同步)
"""
try:
# 负载检查:等待系统资源充足
wait_for_system_load(context="vuln_scan_flow")
if scan_id is None:
raise ValueError("scan_id 不能为空")
if not target_name:
@@ -72,6 +77,9 @@ def vuln_scan_flow(
if not enabled_tools:
raise ValueError("enabled_tools 不能为空")
logger.info("开始漏洞扫描 - Scan ID: %s, Target: %s", scan_id, target_name)
user_log(scan_id, "vuln_scan", "Starting vulnerability scan")
# Step 1: 分类工具
endpoints_tools, other_tools = _classify_vuln_tools(enabled_tools)
@@ -99,6 +107,14 @@ def vuln_scan_flow(
enabled_tools=endpoints_tools,
)
# 记录 Flow 完成
total_vulns = sum(
r.get("created_vulns", 0)
for r in endpoint_result.get("tool_results", {}).values()
)
logger.info("✓ 漏洞扫描完成 - 新增漏洞: %d", total_vulns)
user_log(scan_id, "vuln_scan", f"vuln_scan completed: found {total_vulns} vulnerabilities")
# 目前只有一个子 Flow直接返回其结果
return endpoint_result

View File

@@ -162,6 +162,8 @@ def on_initiate_scan_flow_completed(flow: Flow, flow_run: FlowRun, state: State)
# 执行状态更新并获取统计数据
stats = _update_completed_status()
# 注意:物化视图刷新已迁移到 pg_ivm 增量维护,无需手动标记刷新
# 发送通知(包含统计摘要)
logger.info("准备发送扫描完成通知 - Scan ID: %s, Target: %s", scan_id, target_name)
try:

View File

@@ -14,6 +14,7 @@ from prefect import Flow
from prefect.client.schemas import FlowRun, State
from apps.scan.utils.performance import FlowPerformanceTracker
from apps.scan.utils import user_log
logger = logging.getLogger(__name__)
@@ -136,6 +137,7 @@ def on_scan_flow_failed(flow: Flow, flow_run: FlowRun, state: State) -> None:
- 更新阶段进度为 failed
- 发送扫描失败通知
- 记录性能指标(含错误信息)
- 写入 ScanLog 供前端显示
Args:
flow: Prefect Flow 对象
@@ -152,6 +154,11 @@ def on_scan_flow_failed(flow: Flow, flow_run: FlowRun, state: State) -> None:
# 提取错误信息
error_message = str(state.message) if state.message else "未知错误"
# 写入 ScanLog 供前端显示
stage = _get_stage_from_flow_name(flow.name)
if scan_id and stage:
user_log(scan_id, stage, f"Failed: {error_message}", "error")
# 记录性能指标(失败情况)
tracker = _flow_trackers.pop(str(flow_run.id), None)
if tracker:

View File

@@ -0,0 +1,175 @@
# Generated by Django 5.2.7 on 2026-01-06 00:55
import django.contrib.postgres.fields
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
('engine', '0001_initial'),
('targets', '0001_initial'),
]
operations = [
migrations.CreateModel(
name='NotificationSettings',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('discord_enabled', models.BooleanField(default=False, help_text='是否启用 Discord 通知')),
('discord_webhook_url', models.URLField(blank=True, default='', help_text='Discord Webhook URL')),
('categories', models.JSONField(default=dict, help_text='各分类通知开关,如 {"scan": true, "vulnerability": true, "asset": true, "system": false}')),
('created_at', models.DateTimeField(auto_now_add=True)),
('updated_at', models.DateTimeField(auto_now=True)),
],
options={
'verbose_name': '通知设置',
'verbose_name_plural': '通知设置',
'db_table': 'notification_settings',
},
),
migrations.CreateModel(
name='SubfinderProviderSettings',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('providers', models.JSONField(default=dict, help_text='各 Provider 的 API Key 配置')),
('created_at', models.DateTimeField(auto_now_add=True)),
('updated_at', models.DateTimeField(auto_now=True)),
],
options={
'verbose_name': 'Subfinder Provider 配置',
'verbose_name_plural': 'Subfinder Provider 配置',
'db_table': 'subfinder_provider_settings',
},
),
migrations.CreateModel(
name='Notification',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('category', models.CharField(choices=[('scan', '扫描任务'), ('vulnerability', '漏洞发现'), ('asset', '资产发现'), ('system', '系统消息')], db_index=True, default='system', help_text='通知分类', max_length=20)),
('level', models.CharField(choices=[('low', ''), ('medium', ''), ('high', ''), ('critical', '严重')], db_index=True, default='low', help_text='通知级别', max_length=20)),
('title', models.CharField(help_text='通知标题', max_length=200)),
('message', models.CharField(help_text='通知内容', max_length=2000)),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('is_read', models.BooleanField(default=False, help_text='是否已读')),
('read_at', models.DateTimeField(blank=True, help_text='阅读时间', null=True)),
],
options={
'verbose_name': '通知',
'verbose_name_plural': '通知',
'db_table': 'notification',
'ordering': ['-created_at'],
'indexes': [models.Index(fields=['-created_at'], name='notificatio_created_c430f0_idx'), models.Index(fields=['category', '-created_at'], name='notificatio_categor_df0584_idx'), models.Index(fields=['level', '-created_at'], name='notificatio_level_0e5d12_idx'), models.Index(fields=['is_read', '-created_at'], name='notificatio_is_read_518ce0_idx')],
},
),
migrations.CreateModel(
name='Scan',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('engine_ids', django.contrib.postgres.fields.ArrayField(base_field=models.IntegerField(), default=list, help_text='引擎 ID 列表', size=None)),
('engine_names', models.JSONField(default=list, help_text='引擎名称列表,如 ["引擎A", "引擎B"]')),
('yaml_configuration', models.TextField(default='', help_text='YAML 格式的扫描配置')),
('created_at', models.DateTimeField(auto_now_add=True, help_text='任务创建时间')),
('stopped_at', models.DateTimeField(blank=True, help_text='扫描结束时间', null=True)),
('status', models.CharField(choices=[('cancelled', '已取消'), ('completed', '已完成'), ('failed', '失败'), ('initiated', '初始化'), ('running', '运行中')], db_index=True, default='initiated', help_text='任务状态', max_length=20)),
('results_dir', models.CharField(blank=True, default='', help_text='结果存储目录', max_length=100)),
('container_ids', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='容器 ID 列表Docker Container ID', size=None)),
('error_message', models.CharField(blank=True, default='', help_text='错误信息', max_length=2000)),
('deleted_at', models.DateTimeField(blank=True, db_index=True, help_text='删除时间NULL表示未删除', null=True)),
('progress', models.IntegerField(default=0, help_text='扫描进度 0-100')),
('current_stage', models.CharField(blank=True, default='', help_text='当前扫描阶段', max_length=50)),
('stage_progress', models.JSONField(default=dict, help_text='各阶段进度详情')),
('cached_subdomains_count', models.IntegerField(default=0, help_text='缓存的子域名数量')),
('cached_websites_count', models.IntegerField(default=0, help_text='缓存的网站数量')),
('cached_endpoints_count', models.IntegerField(default=0, help_text='缓存的端点数量')),
('cached_ips_count', models.IntegerField(default=0, help_text='缓存的IP地址数量')),
('cached_directories_count', models.IntegerField(default=0, help_text='缓存的目录数量')),
('cached_vulns_total', models.IntegerField(default=0, help_text='缓存的漏洞总数')),
('cached_vulns_critical', models.IntegerField(default=0, help_text='缓存的严重漏洞数量')),
('cached_vulns_high', models.IntegerField(default=0, help_text='缓存的高危漏洞数量')),
('cached_vulns_medium', models.IntegerField(default=0, help_text='缓存的中危漏洞数量')),
('cached_vulns_low', models.IntegerField(default=0, help_text='缓存的低危漏洞数量')),
('stats_updated_at', models.DateTimeField(blank=True, help_text='统计数据最后更新时间', null=True)),
('target', models.ForeignKey(help_text='扫描目标', on_delete=django.db.models.deletion.CASCADE, related_name='scans', to='targets.target')),
('worker', models.ForeignKey(blank=True, help_text='执行扫描的 Worker 节点', null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='scans', to='engine.workernode')),
],
options={
'verbose_name': '扫描任务',
'verbose_name_plural': '扫描任务',
'db_table': 'scan',
'ordering': ['-created_at'],
},
),
migrations.CreateModel(
name='ScanLog',
fields=[
('id', models.BigAutoField(primary_key=True, serialize=False)),
('level', models.CharField(choices=[('info', 'Info'), ('warning', 'Warning'), ('error', 'Error')], default='info', help_text='日志级别', max_length=10)),
('content', models.TextField(help_text='日志内容')),
('created_at', models.DateTimeField(auto_now_add=True, db_index=True, help_text='创建时间')),
('scan', models.ForeignKey(help_text='关联的扫描任务', on_delete=django.db.models.deletion.CASCADE, related_name='logs', to='scan.scan')),
],
options={
'verbose_name': '扫描日志',
'verbose_name_plural': '扫描日志',
'db_table': 'scan_log',
'ordering': ['created_at'],
},
),
migrations.CreateModel(
name='ScheduledScan',
fields=[
('id', models.AutoField(primary_key=True, serialize=False)),
('name', models.CharField(help_text='任务名称', max_length=200)),
('engine_ids', django.contrib.postgres.fields.ArrayField(base_field=models.IntegerField(), default=list, help_text='引擎 ID 列表', size=None)),
('engine_names', models.JSONField(default=list, help_text='引擎名称列表,如 ["引擎A", "引擎B"]')),
('yaml_configuration', models.TextField(default='', help_text='YAML 格式的扫描配置')),
('cron_expression', models.CharField(default='0 2 * * *', help_text='Cron 表达式,格式:分 时 日 月 周', max_length=100)),
('is_enabled', models.BooleanField(db_index=True, default=True, help_text='是否启用')),
('run_count', models.IntegerField(default=0, help_text='已执行次数')),
('last_run_time', models.DateTimeField(blank=True, help_text='上次执行时间', null=True)),
('next_run_time', models.DateTimeField(blank=True, help_text='下次执行时间', null=True)),
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
('updated_at', models.DateTimeField(auto_now=True, help_text='更新时间')),
('organization', models.ForeignKey(blank=True, help_text='扫描组织(设置后执行时动态获取组织下所有目标)', null=True, on_delete=django.db.models.deletion.CASCADE, related_name='scheduled_scans', to='targets.organization')),
('target', models.ForeignKey(blank=True, help_text='扫描单个目标(与 organization 二选一)', null=True, on_delete=django.db.models.deletion.CASCADE, related_name='scheduled_scans', to='targets.target')),
],
options={
'verbose_name': '定时扫描任务',
'verbose_name_plural': '定时扫描任务',
'db_table': 'scheduled_scan',
'ordering': ['-created_at'],
},
),
migrations.AddIndex(
model_name='scan',
index=models.Index(fields=['-created_at'], name='scan_created_0bb6c7_idx'),
),
migrations.AddIndex(
model_name='scan',
index=models.Index(fields=['target'], name='scan_target__718b9d_idx'),
),
migrations.AddIndex(
model_name='scan',
index=models.Index(fields=['deleted_at', '-created_at'], name='scan_deleted_eb17e8_idx'),
),
migrations.AddIndex(
model_name='scanlog',
index=models.Index(fields=['scan', 'created_at'], name='scan_log_scan_id_c4814a_idx'),
),
migrations.AddIndex(
model_name='scheduledscan',
index=models.Index(fields=['-created_at'], name='scheduled_s_created_9b9c2e_idx'),
),
migrations.AddIndex(
model_name='scheduledscan',
index=models.Index(fields=['is_enabled', '-created_at'], name='scheduled_s_is_enab_23d660_idx'),
),
migrations.AddIndex(
model_name='scheduledscan',
index=models.Index(fields=['name'], name='scheduled_s_name_bf332d_idx'),
),
]

View File

@@ -0,0 +1,18 @@
# Generated by Django 5.2.7 on 2026-01-07 14:03
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('scan', '0001_initial'),
]
operations = [
migrations.AddField(
model_name='scan',
name='cached_screenshots_count',
field=models.IntegerField(default=0, help_text='缓存的截图数量'),
),
]

View File

@@ -0,0 +1,23 @@
# Generated manually for WeCom notification support
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('scan', '0002_add_cached_screenshots_count'),
]
operations = [
migrations.AddField(
model_name='notificationsettings',
name='wecom_enabled',
field=models.BooleanField(default=False, help_text='是否启用企业微信通知'),
),
migrations.AddField(
model_name='notificationsettings',
name='wecom_webhook_url',
field=models.URLField(blank=True, default='', help_text='企业微信机器人 Webhook URL'),
),
]

View File

@@ -0,0 +1,18 @@
"""Scan Models - 统一导出"""
from .scan_models import Scan, SoftDeleteManager
from .scan_log_model import ScanLog
from .scheduled_scan_model import ScheduledScan
from .subfinder_provider_settings_model import SubfinderProviderSettings
# 兼容旧名称(已废弃,请使用 SubfinderProviderSettings
ProviderSettings = SubfinderProviderSettings
__all__ = [
'Scan',
'ScanLog',
'ScheduledScan',
'SoftDeleteManager',
'SubfinderProviderSettings',
'ProviderSettings', # 兼容旧名称
]

View File

@@ -0,0 +1,41 @@
"""扫描日志模型"""
from django.db import models
class ScanLog(models.Model):
"""扫描日志模型"""
class Level(models.TextChoices):
INFO = 'info', 'Info'
WARNING = 'warning', 'Warning'
ERROR = 'error', 'Error'
id = models.BigAutoField(primary_key=True)
scan = models.ForeignKey(
'Scan',
on_delete=models.CASCADE,
related_name='logs',
db_index=True,
help_text='关联的扫描任务'
)
level = models.CharField(
max_length=10,
choices=Level.choices,
default=Level.INFO,
help_text='日志级别'
)
content = models.TextField(help_text='日志内容')
created_at = models.DateTimeField(auto_now_add=True, db_index=True, help_text='创建时间')
class Meta:
db_table = 'scan_log'
verbose_name = '扫描日志'
verbose_name_plural = '扫描日志'
ordering = ['created_at']
indexes = [
models.Index(fields=['scan', 'created_at']),
]
def __str__(self):
return f"[{self.level}] {self.content[:50]}"

View File

@@ -1,9 +1,9 @@
"""扫描相关模型"""
from django.db import models
from django.contrib.postgres.fields import ArrayField
from ..common.definitions import ScanStatus
from apps.common.definitions import ScanStatus
class SoftDeleteManager(models.Manager):
@@ -20,11 +20,19 @@ class Scan(models.Model):
target = models.ForeignKey('targets.Target', on_delete=models.CASCADE, related_name='scans', help_text='扫描目标')
engine = models.ForeignKey(
'engine.ScanEngine',
on_delete=models.CASCADE,
related_name='scans',
help_text='使用的扫描引擎'
# 多引擎支持字段
engine_ids = ArrayField(
models.IntegerField(),
default=list,
help_text='引擎 ID 列表'
)
engine_names = models.JSONField(
default=list,
help_text='引擎名称列表,如 ["引擎A", "引擎B"]'
)
yaml_configuration = models.TextField(
default='',
help_text='YAML 格式的扫描配置'
)
created_at = models.DateTimeField(auto_now_add=True, help_text='任务创建时间')
@@ -76,6 +84,7 @@ class Scan(models.Model):
cached_endpoints_count = models.IntegerField(default=0, help_text='缓存的端点数量')
cached_ips_count = models.IntegerField(default=0, help_text='缓存的IP地址数量')
cached_directories_count = models.IntegerField(default=0, help_text='缓存的目录数量')
cached_screenshots_count = models.IntegerField(default=0, help_text='缓存的截图数量')
cached_vulns_total = models.IntegerField(default=0, help_text='缓存的漏洞总数')
cached_vulns_critical = models.IntegerField(default=0, help_text='缓存的严重漏洞数量')
cached_vulns_high = models.IntegerField(default=0, help_text='缓存的高危漏洞数量')
@@ -89,92 +98,10 @@ class Scan(models.Model):
verbose_name_plural = '扫描任务'
ordering = ['-created_at']
indexes = [
models.Index(fields=['-created_at']), # 优化按创建时间降序排序list 查询的默认排序)
models.Index(fields=['target']), # 优化按目标查询扫描任务
models.Index(fields=['deleted_at', '-created_at']), # 软删除 + 时间索引
models.Index(fields=['-created_at']),
models.Index(fields=['target']),
models.Index(fields=['deleted_at', '-created_at']),
]
def __str__(self):
return f"Scan #{self.id} - {self.target.name}"
class ScheduledScan(models.Model):
"""
定时扫描任务模型
调度机制
- APScheduler 每分钟检查 next_run_time
- 到期任务通过 task_distributor 分发到 Worker 执行
- 支持 cron 表达式进行灵活调度
扫描模式二选一
- 组织扫描设置 organization执行时动态获取组织下所有目标
- 目标扫描设置 target扫描单个目标
- organization 优先级高于 target
"""
id = models.AutoField(primary_key=True)
# 基本信息
name = models.CharField(max_length=200, help_text='任务名称')
# 关联的扫描引擎
engine = models.ForeignKey(
'engine.ScanEngine',
on_delete=models.CASCADE,
related_name='scheduled_scans',
help_text='使用的扫描引擎'
)
# 关联的组织(组织扫描模式:执行时动态获取组织下所有目标)
organization = models.ForeignKey(
'targets.Organization',
on_delete=models.CASCADE,
related_name='scheduled_scans',
null=True,
blank=True,
help_text='扫描组织(设置后执行时动态获取组织下所有目标)'
)
# 关联的目标(目标扫描模式:扫描单个目标)
target = models.ForeignKey(
'targets.Target',
on_delete=models.CASCADE,
related_name='scheduled_scans',
null=True,
blank=True,
help_text='扫描单个目标(与 organization 二选一)'
)
# 调度配置 - 直接使用 Cron 表达式
cron_expression = models.CharField(
max_length=100,
default='0 2 * * *',
help_text='Cron 表达式,格式:分 时 日 月 周'
)
# 状态
is_enabled = models.BooleanField(default=True, db_index=True, help_text='是否启用')
# 执行统计
run_count = models.IntegerField(default=0, help_text='已执行次数')
last_run_time = models.DateTimeField(null=True, blank=True, help_text='上次执行时间')
next_run_time = models.DateTimeField(null=True, blank=True, help_text='下次执行时间')
# 时间戳
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
updated_at = models.DateTimeField(auto_now=True, help_text='更新时间')
class Meta:
db_table = 'scheduled_scan'
verbose_name = '定时扫描任务'
verbose_name_plural = '定时扫描任务'
ordering = ['-created_at']
indexes = [
models.Index(fields=['-created_at']),
models.Index(fields=['is_enabled', '-created_at']),
models.Index(fields=['name']), # 优化 name 搜索
]
def __str__(self):
return f"ScheduledScan #{self.id} - {self.name}"

View File

@@ -0,0 +1,73 @@
"""定时扫描任务模型"""
from django.db import models
from django.contrib.postgres.fields import ArrayField
class ScheduledScan(models.Model):
"""定时扫描任务模型"""
id = models.AutoField(primary_key=True)
name = models.CharField(max_length=200, help_text='任务名称')
engine_ids = ArrayField(
models.IntegerField(),
default=list,
help_text='引擎 ID 列表'
)
engine_names = models.JSONField(
default=list,
help_text='引擎名称列表,如 ["引擎A", "引擎B"]'
)
yaml_configuration = models.TextField(
default='',
help_text='YAML 格式的扫描配置'
)
organization = models.ForeignKey(
'targets.Organization',
on_delete=models.CASCADE,
related_name='scheduled_scans',
null=True,
blank=True,
help_text='扫描组织(设置后执行时动态获取组织下所有目标)'
)
target = models.ForeignKey(
'targets.Target',
on_delete=models.CASCADE,
related_name='scheduled_scans',
null=True,
blank=True,
help_text='扫描单个目标(与 organization 二选一)'
)
cron_expression = models.CharField(
max_length=100,
default='0 2 * * *',
help_text='Cron 表达式,格式:分 时 日 月 周'
)
is_enabled = models.BooleanField(default=True, db_index=True, help_text='是否启用')
run_count = models.IntegerField(default=0, help_text='已执行次数')
last_run_time = models.DateTimeField(null=True, blank=True, help_text='上次执行时间')
next_run_time = models.DateTimeField(null=True, blank=True, help_text='下次执行时间')
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
updated_at = models.DateTimeField(auto_now=True, help_text='更新时间')
class Meta:
db_table = 'scheduled_scan'
verbose_name = '定时扫描任务'
verbose_name_plural = '定时扫描任务'
ordering = ['-created_at']
indexes = [
models.Index(fields=['-created_at']),
models.Index(fields=['is_enabled', '-created_at']),
models.Index(fields=['name']),
]
def __str__(self):
return f"ScheduledScan #{self.id} - {self.name}"

View File

@@ -0,0 +1,64 @@
"""Subfinder Provider 配置模型(单例模式)
用于存储 subfinder 第三方数据源的 API Key 配置
"""
from django.db import models
class SubfinderProviderSettings(models.Model):
"""
Subfinder Provider 配置(单例模式)
存储第三方数据源的 API Key 配置,用于 subfinder 子域名发现
支持的 Provider:
- fofa: email + api_key (composite)
- censys: api_id + api_secret (composite)
- hunter, shodan, zoomeye, securitytrails, threatbook, quake: api_key (single)
"""
providers = models.JSONField(
default=dict,
help_text='各 Provider 的 API Key 配置'
)
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
class Meta:
db_table = 'subfinder_provider_settings'
verbose_name = 'Subfinder Provider 配置'
verbose_name_plural = 'Subfinder Provider 配置'
DEFAULT_PROVIDERS = {
'fofa': {'enabled': False, 'email': '', 'api_key': ''},
'hunter': {'enabled': False, 'api_key': ''},
'shodan': {'enabled': False, 'api_key': ''},
'censys': {'enabled': False, 'api_id': '', 'api_secret': ''},
'zoomeye': {'enabled': False, 'api_key': ''},
'securitytrails': {'enabled': False, 'api_key': ''},
'threatbook': {'enabled': False, 'api_key': ''},
'quake': {'enabled': False, 'api_key': ''},
}
def save(self, *args, **kwargs):
self.pk = 1
super().save(*args, **kwargs)
@classmethod
def get_instance(cls) -> 'SubfinderProviderSettings':
"""获取或创建单例实例"""
obj, _ = cls.objects.get_or_create(
pk=1,
defaults={'providers': cls.DEFAULT_PROVIDERS.copy()}
)
return obj
def get_provider_config(self, provider: str) -> dict:
"""获取指定 Provider 的配置"""
return self.providers.get(provider, self.DEFAULT_PROVIDERS.get(provider, {}))
def is_provider_enabled(self, provider: str) -> bool:
"""检查指定 Provider 是否启用"""
config = self.get_provider_config(provider)
return config.get('enabled', False)

View File

@@ -1,8 +1,14 @@
"""通知系统数据模型"""
from django.db import models
import logging
from datetime import timedelta
from .types import NotificationLevel, NotificationCategory
from django.db import models
from django.utils import timezone
from .types import NotificationCategory, NotificationLevel
logger = logging.getLogger(__name__)
class NotificationSettings(models.Model):
@@ -10,31 +16,34 @@ class NotificationSettings(models.Model):
通知设置(单例模型)
存储 Discord webhook 配置和各分类的通知开关
"""
# Discord 配置
discord_enabled = models.BooleanField(default=False, help_text='是否启用 Discord 通知')
discord_webhook_url = models.URLField(blank=True, default='', help_text='Discord Webhook URL')
# 企业微信配置
wecom_enabled = models.BooleanField(default=False, help_text='是否启用企业微信通知')
wecom_webhook_url = models.URLField(blank=True, default='', help_text='企业微信机器人 Webhook URL')
# 分类开关(使用 JSONField 存储)
categories = models.JSONField(
default=dict,
help_text='各分类通知开关,如 {"scan": true, "vulnerability": true, "asset": true, "system": false}'
)
# 时间信息
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
class Meta:
db_table = 'notification_settings'
verbose_name = '通知设置'
verbose_name_plural = '通知设置'
def save(self, *args, **kwargs):
# 单例模式:强制只有一条记录
self.pk = 1
self.pk = 1 # 单例模式
super().save(*args, **kwargs)
@classmethod
def get_instance(cls) -> 'NotificationSettings':
"""获取或创建单例实例"""
@@ -52,7 +61,7 @@ class NotificationSettings(models.Model):
}
)
return obj
def is_category_enabled(self, category: str) -> bool:
"""检查指定分类是否启用通知"""
return self.categories.get(category, False)
@@ -60,10 +69,9 @@ class NotificationSettings(models.Model):
class Notification(models.Model):
"""通知模型"""
id = models.AutoField(primary_key=True)
# 通知分类
category = models.CharField(
max_length=20,
choices=NotificationCategory.choices,
@@ -71,8 +79,7 @@ class Notification(models.Model):
db_index=True,
help_text='通知分类'
)
# 通知级别
level = models.CharField(
max_length=20,
choices=NotificationLevel.choices,
@@ -80,16 +87,15 @@ class Notification(models.Model):
db_index=True,
help_text='通知级别'
)
title = models.CharField(max_length=200, help_text='通知标题')
message = models.CharField(max_length=2000, help_text='通知内容')
# 时间信息
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
is_read = models.BooleanField(default=False, help_text='是否已读')
read_at = models.DateTimeField(null=True, blank=True, help_text='阅读时间')
class Meta:
db_table = 'notification'
verbose_name = '通知'
@@ -101,44 +107,26 @@ class Notification(models.Model):
models.Index(fields=['level', '-created_at']),
models.Index(fields=['is_read', '-created_at']),
]
def __str__(self):
return f"{self.get_level_display()} - {self.title}"
@classmethod
def cleanup_old_notifications(cls):
"""
清理超过15天的旧通知硬编码
Returns:
int: 删除的通知数量
"""
from datetime import timedelta
from django.utils import timezone
# 硬编码只保留最近15天的通知
def cleanup_old_notifications(cls) -> int:
"""清理超过15天的旧通知"""
cutoff_date = timezone.now() - timedelta(days=15)
delete_result = cls.objects.filter(created_at__lt=cutoff_date).delete()
return delete_result[0] if delete_result[0] else 0
deleted_count, _ = cls.objects.filter(created_at__lt=cutoff_date).delete()
return deleted_count or 0
def save(self, *args, **kwargs):
"""
重写save方法在创建新通知时自动清理旧通知
"""
"""重写save方法在创建新通知时自动清理旧通知"""
is_new = self.pk is None
super().save(*args, **kwargs)
# 只在创建新通知时执行清理自动清理超过15天的通知
if is_new:
try:
deleted_count = self.__class__.cleanup_old_notifications()
if deleted_count > 0:
import logging
logger = logging.getLogger(__name__)
logger.info(f"自动清理{deleted_count} 条超过15天的旧通知")
except Exception as e:
# 清理失败不应影响通知创建
import logging
logger = logging.getLogger(__name__)
logger.warning(f"通知自动清理失败: {e}")
logger.info("自动清理了 %d 条超过15天的旧通知", deleted_count)
except Exception:
logger.warning("通知自动清理失败", exc_info=True)

View File

@@ -1,52 +1,70 @@
"""通知系统仓储层模块"""
import logging
from typing import TypedDict
from dataclasses import dataclass
from typing import Optional
from django.db.models import QuerySet
from django.utils import timezone
from apps.common.decorators import auto_ensure_db_connection
from .models import Notification, NotificationSettings
from .models import Notification, NotificationSettings
logger = logging.getLogger(__name__)
class NotificationSettingsData(TypedDict):
"""通知设置数据结构"""
@dataclass
class NotificationSettingsData:
"""通知设置更新数据"""
discord_enabled: bool
discord_webhook_url: str
categories: dict[str, bool]
wecom_enabled: bool = False
wecom_webhook_url: str = ''
@auto_ensure_db_connection
class NotificationSettingsRepository:
"""通知设置仓储层"""
def get_settings(self) -> NotificationSettings:
"""获取通知设置单例"""
return NotificationSettings.get_instance()
def update_settings(
self,
discord_enabled: bool,
discord_webhook_url: str,
categories: dict[str, bool]
) -> NotificationSettings:
def update_settings(self, data: NotificationSettingsData) -> NotificationSettings:
"""更新通知设置"""
settings = NotificationSettings.get_instance()
settings.discord_enabled = discord_enabled
settings.discord_webhook_url = discord_webhook_url
settings.categories = categories
settings.discord_enabled = data.discord_enabled
settings.discord_webhook_url = data.discord_webhook_url
settings.wecom_enabled = data.wecom_enabled
settings.wecom_webhook_url = data.wecom_webhook_url
settings.categories = data.categories
settings.save()
return settings
def is_category_enabled(self, category: str) -> bool:
"""检查指定分类是否启用"""
settings = self.get_settings()
return settings.is_category_enabled(category)
return self.get_settings().is_category_enabled(category)
@auto_ensure_db_connection
class DjangoNotificationRepository:
def get_filtered(self, level: str | None = None, unread: bool | None = None):
"""通知数据仓储层"""
def get_filtered(
self,
level: Optional[str] = None,
unread: Optional[bool] = None
) -> QuerySet[Notification]:
"""
获取过滤后的通知列表
Args:
level: 通知级别过滤
unread: 已读状态过滤 (True=未读, False=已读, None=全部)
"""
queryset = Notification.objects.all()
if level:
@@ -60,16 +78,24 @@ class DjangoNotificationRepository:
return queryset.order_by("-created_at")
def get_unread_count(self) -> int:
"""获取未读通知数量"""
return Notification.objects.filter(is_read=False).count()
def mark_all_as_read(self) -> int:
updated = Notification.objects.filter(is_read=False).update(
"""标记所有通知为已读,返回更新数量"""
return Notification.objects.filter(is_read=False).update(
is_read=True,
read_at=timezone.now(),
)
return updated
def create(self, title: str, message: str, level: str, category: str = 'system') -> Notification:
def create(
self,
title: str,
message: str,
level: str,
category: str = 'system'
) -> Notification:
"""创建新通知"""
return Notification.objects.create(
category=category,
level=level,

View File

@@ -60,13 +60,12 @@ def push_to_external_channels(notification: Notification) -> None:
except Exception as e:
logger.warning(f"Discord 推送失败: {e}")
# 未来扩展Slack
# if settings.slack_enabled and settings.slack_webhook_url:
# _send_slack(notification, settings.slack_webhook_url)
# 未来扩展Telegram
# if settings.telegram_enabled and settings.telegram_bot_token:
# _send_telegram(notification, settings.telegram_chat_id)
# 企业微信渠道
if settings.wecom_enabled and settings.wecom_webhook_url:
try:
_send_wecom(notification, settings.wecom_webhook_url)
except Exception as e:
logger.warning(f"企业微信推送失败: {e}")
def _send_discord(notification: Notification, webhook_url: str) -> bool:
@@ -103,6 +102,41 @@ def _send_discord(notification: Notification, webhook_url: str) -> bool:
return False
def _send_wecom(notification: Notification, webhook_url: str) -> bool:
"""发送到企业微信机器人 Webhook"""
try:
emoji = CATEGORY_EMOJI.get(notification.category, '📢')
# 企业微信 Markdown 格式
content = f"""**{emoji} {notification.title}**
> 级别:{notification.get_level_display()}
> 分类:{notification.get_category_display()}
{notification.message}"""
payload = {
'msgtype': 'markdown',
'markdown': {'content': content}
}
response = requests.post(webhook_url, json=payload, timeout=10)
if response.status_code == 200:
result = response.json()
if result.get('errcode') == 0:
logger.info(f"企业微信通知发送成功 - {notification.title}")
return True
logger.warning(f"企业微信发送失败 - errcode: {result.get('errcode')}, errmsg: {result.get('errmsg')}")
return False
logger.warning(f"企业微信发送失败 - 状态码: {response.status_code}")
return False
except requests.RequestException as e:
logger.error(f"企业微信网络错误: {e}")
return False
# ============================================================
# 设置服务
# ============================================================
@@ -121,31 +155,43 @@ class NotificationSettingsService:
'enabled': settings.discord_enabled,
'webhookUrl': settings.discord_webhook_url,
},
'wecom': {
'enabled': settings.wecom_enabled,
'webhookUrl': settings.wecom_webhook_url,
},
'categories': settings.categories,
}
def update_settings(self, data: dict) -> dict:
"""更新通知设置
注意DRF CamelCaseJSONParser 会将前端的 webhookUrl 转换为 webhook_url
"""
discord_data = data.get('discord', {})
wecom_data = data.get('wecom', {})
categories = data.get('categories', {})
# CamelCaseJSONParser 转换后的字段名是 webhook_url
webhook_url = discord_data.get('webhook_url', '')
discord_webhook_url = discord_data.get('webhook_url', '')
wecom_webhook_url = wecom_data.get('webhook_url', '')
settings = self.repo.update_settings(
discord_enabled=discord_data.get('enabled', False),
discord_webhook_url=webhook_url,
discord_webhook_url=discord_webhook_url,
wecom_enabled=wecom_data.get('enabled', False),
wecom_webhook_url=wecom_webhook_url,
categories=categories,
)
return {
'discord': {
'enabled': settings.discord_enabled,
'webhookUrl': settings.discord_webhook_url,
},
'wecom': {
'enabled': settings.wecom_enabled,
'webhookUrl': settings.wecom_webhook_url,
},
'categories': settings.categories,
}

View File

@@ -21,9 +21,6 @@ urlpatterns = [
# 标记全部已读
path('mark-all-as-read/', NotificationMarkAllAsReadView.as_view(), name='mark-all-as-read'),
# 测试通知
path('test/', views.notifications_test, name='test'),
]
# WebSocket 实时通知路由在 routing.py 中定义ws://host/ws/notifications/

View File

@@ -23,45 +23,7 @@ from .services import NotificationService, NotificationSettingsService
logger = logging.getLogger(__name__)
def notifications_test(request):
"""
测试通知推送
"""
try:
from .services import create_notification
from django.http import JsonResponse
level_param = request.GET.get('level', NotificationLevel.LOW)
try:
level_choice = NotificationLevel(level_param)
except ValueError:
level_choice = NotificationLevel.LOW
title = request.GET.get('title') or "测试通知"
message = request.GET.get('message') or "这是一条测试通知消息"
# 创建测试通知
notification = create_notification(
title=title,
message=message,
level=level_choice
)
return JsonResponse({
'success': True,
'message': '测试通知已发送',
'notification_id': notification.id
})
except Exception as e:
logger.error(f"发送测试通知失败: {e}")
return JsonResponse({
'success': False,
'error': str(e)
}, status=500)
# build_api_response 已废弃,请使用 success_response/error_response
def _parse_bool(value: str | None) -> bool | None:

View File

@@ -147,10 +147,10 @@ class FlowOrchestrator:
return True
return False
# 其他扫描类型:检查 tools
# 其他扫描类型(包括 screenshot:检查 tools
tools = scan_config.get('tools', {})
for tool_config in tools.values():
if tool_config.get('enabled', False):
if isinstance(tool_config, dict) and tool_config.get('enabled', False):
return True
return False
@@ -222,6 +222,10 @@ class FlowOrchestrator:
from apps.scan.flows.vuln_scan import vuln_scan_flow
return vuln_scan_flow
elif scan_type == 'screenshot':
from apps.scan.flows.screenshot_flow import screenshot_flow
return screenshot_flow
else:
logger.warning(f"未实现的扫描类型: {scan_type}")
return None

View File

@@ -0,0 +1,56 @@
"""
扫描目标提供者模块
提供统一的目标获取接口,支持多种数据源:
- DatabaseTargetProvider: 从数据库查询(完整扫描)
- ListTargetProvider: 使用内存列表快速扫描阶段1
- SnapshotTargetProvider: 从快照表读取快速扫描阶段2+
- PipelineTargetProvider: 使用管道输出Phase 2
使用方式:
from apps.scan.providers import (
DatabaseTargetProvider,
ListTargetProvider,
SnapshotTargetProvider,
ProviderContext
)
# 数据库模式(完整扫描)
provider = DatabaseTargetProvider(target_id=123)
# 列表模式快速扫描阶段1
context = ProviderContext(target_id=1, scan_id=100)
provider = ListTargetProvider(
targets=["a.test.com"],
context=context
)
# 快照模式快速扫描阶段2+
context = ProviderContext(target_id=1, scan_id=100)
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain",
context=context
)
# 使用 Provider
for host in provider.iter_hosts():
scan(host)
"""
from .base import TargetProvider, ProviderContext
from .list_provider import ListTargetProvider
from .database_provider import DatabaseTargetProvider
from .snapshot_provider import SnapshotTargetProvider, SnapshotType
from .pipeline_provider import PipelineTargetProvider, StageOutput
__all__ = [
'TargetProvider',
'ProviderContext',
'ListTargetProvider',
'DatabaseTargetProvider',
'SnapshotTargetProvider',
'SnapshotType',
'PipelineTargetProvider',
'StageOutput',
]

View File

@@ -0,0 +1,115 @@
"""
扫描目标提供者基础模块
定义 ProviderContext 数据类和 TargetProvider 抽象基类。
"""
import ipaddress
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Iterator, Optional
if TYPE_CHECKING:
from apps.common.utils import BlacklistFilter
logger = logging.getLogger(__name__)
@dataclass
class ProviderContext:
"""
Provider 上下文,携带元数据
Attributes:
target_id: 关联的 Target ID用于结果保存None 表示临时扫描(不保存)
scan_id: 扫描任务 ID
"""
target_id: Optional[int] = None
scan_id: Optional[int] = None
class TargetProvider(ABC):
"""
扫描目标提供者抽象基类
职责:
- 提供扫描目标域名、IP、URL 等)的迭代器
- 提供黑名单过滤器
- 携带上下文信息target_id, scan_id 等)
- 自动展开 CIDR子类无需关心
使用方式:
provider = create_target_provider(target_id=123)
for host in provider.iter_hosts():
print(host)
"""
def __init__(self, context: Optional[ProviderContext] = None):
self._context = context or ProviderContext()
@property
def context(self) -> ProviderContext:
"""返回 Provider 上下文"""
return self._context
@staticmethod
def _expand_host(host: str) -> Iterator[str]:
"""
展开主机(如果是 CIDR 则展开为多个 IP否则直接返回
示例:
"192.168.1.0/30""192.168.1.1", "192.168.1.2"
"192.168.1.1""192.168.1.1"
"example.com""example.com"
"""
from apps.common.validators import detect_target_type
from apps.targets.models import Target
host = host.strip()
if not host:
return
try:
target_type = detect_target_type(host)
if target_type == Target.TargetType.CIDR:
network = ipaddress.ip_network(host, strict=False)
if network.num_addresses == 1:
yield str(network.network_address)
else:
yield from (str(ip) for ip in network.hosts())
elif target_type in (Target.TargetType.IP, Target.TargetType.DOMAIN):
yield host
except ValueError as e:
logger.warning("跳过无效的主机格式 '%s': %s", host, str(e))
def iter_hosts(self) -> Iterator[str]:
"""迭代主机列表(域名/IP自动展开 CIDR"""
for host in self._iter_raw_hosts():
yield from self._expand_host(host)
@abstractmethod
def _iter_raw_hosts(self) -> Iterator[str]:
"""迭代原始主机列表(可能包含 CIDR子类实现"""
pass
@abstractmethod
def iter_urls(self) -> Iterator[str]:
"""迭代 URL 列表"""
pass
@abstractmethod
def get_blacklist_filter(self) -> Optional['BlacklistFilter']:
"""获取黑名单过滤器,返回 None 表示不过滤"""
pass
@property
def target_id(self) -> Optional[int]:
"""返回关联的 target_id临时扫描返回 None"""
return self._context.target_id
@property
def scan_id(self) -> Optional[int]:
"""返回关联的 scan_id"""
return self._context.scan_id

View File

@@ -0,0 +1,93 @@
"""
数据库目标提供者模块
提供基于数据库查询的目标提供者实现。
"""
import logging
from typing import TYPE_CHECKING, Iterator, Optional
from .base import ProviderContext, TargetProvider
if TYPE_CHECKING:
from apps.common.utils import BlacklistFilter
logger = logging.getLogger(__name__)
class DatabaseTargetProvider(TargetProvider):
"""
数据库目标提供者 - 从 Target 表及关联资产表查询
数据来源:
- iter_hosts(): 根据 Target 类型返回域名/IP
- iter_urls(): WebSite/Endpoint 表,带回退链
使用方式:
provider = DatabaseTargetProvider(target_id=123)
for host in provider.iter_hosts():
scan(host)
"""
def __init__(self, target_id: int, context: Optional[ProviderContext] = None):
ctx = context or ProviderContext()
ctx.target_id = target_id
super().__init__(ctx)
self._blacklist_filter: Optional['BlacklistFilter'] = None
def iter_hosts(self) -> Iterator[str]:
"""从数据库查询主机列表,自动展开 CIDR 并应用黑名单过滤"""
blacklist = self.get_blacklist_filter()
for host in self._iter_raw_hosts():
for expanded_host in self._expand_host(host):
if not blacklist or blacklist.is_allowed(expanded_host):
yield expanded_host
def _iter_raw_hosts(self) -> Iterator[str]:
"""从数据库查询原始主机列表(可能包含 CIDR"""
from apps.asset.services.asset.subdomain_service import SubdomainService
from apps.targets.models import Target
from apps.targets.services import TargetService
target = TargetService().get_target(self.target_id)
if not target:
logger.warning("Target ID %d 不存在", self.target_id)
return
if target.type == Target.TargetType.DOMAIN:
yield target.name
for domain in SubdomainService().iter_subdomain_names_by_target(
target_id=self.target_id,
chunk_size=1000
):
if domain != target.name:
yield domain
elif target.type in (Target.TargetType.IP, Target.TargetType.CIDR):
yield target.name
def iter_urls(self) -> Iterator[str]:
"""从数据库查询 URL 列表使用回退链Endpoint → WebSite → Default"""
from apps.scan.services.target_export_service import (
DataSource,
_iter_urls_with_fallback,
)
blacklist = self.get_blacklist_filter()
for url, _ in _iter_urls_with_fallback(
target_id=self.target_id,
sources=[DataSource.ENDPOINT, DataSource.WEBSITE, DataSource.DEFAULT],
blacklist_filter=blacklist
):
yield url
def get_blacklist_filter(self) -> Optional['BlacklistFilter']:
"""获取黑名单过滤器(延迟加载)"""
if self._blacklist_filter is None:
from apps.common.services import BlacklistService
from apps.common.utils import BlacklistFilter
rules = BlacklistService().get_rules(self.target_id)
self._blacklist_filter = BlacklistFilter(rules)
return self._blacklist_filter

View File

@@ -0,0 +1,84 @@
"""
列表目标提供者模块
提供基于内存列表的目标提供者实现。
"""
from typing import Iterator, Optional, List
from .base import TargetProvider, ProviderContext
class ListTargetProvider(TargetProvider):
"""
列表目标提供者 - 直接使用内存中的列表
用于快速扫描、临时扫描等场景,只扫描用户指定的目标。
特点:
- 不查询数据库
- 不应用黑名单过滤(用户明确指定的目标)
- 不关联 target_id由调用方负责创建 Target
- 自动检测输入类型URL/域名/IP/CIDR
- 自动展开 CIDR
使用方式:
# 快速扫描:用户提供目标,自动识别类型
provider = ListTargetProvider(targets=[
"example.com", # 域名
"192.168.1.0/24", # CIDR自动展开
"https://api.example.com" # URL
])
for host in provider.iter_hosts():
scan(host)
"""
def __init__(
self,
targets: Optional[List[str]] = None,
context: Optional[ProviderContext] = None
):
"""
初始化列表目标提供者
Args:
targets: 目标列表自动识别类型URL/域名/IP/CIDR
context: Provider 上下文
"""
from apps.common.validators import detect_input_type
ctx = context or ProviderContext()
super().__init__(ctx)
# 自动分类目标
self._hosts = []
self._urls = []
if targets:
for target in targets:
target = target.strip()
if not target:
continue
try:
input_type = detect_input_type(target)
if input_type == 'url':
self._urls.append(target)
else:
# domain/ip/cidr 都作为 host
self._hosts.append(target)
except ValueError:
# 无法识别类型,默认作为 host
self._hosts.append(target)
def _iter_raw_hosts(self) -> Iterator[str]:
"""迭代原始主机列表(可能包含 CIDR"""
yield from self._hosts
def iter_urls(self) -> Iterator[str]:
"""迭代 URL 列表"""
yield from self._urls
def get_blacklist_filter(self) -> None:
"""列表模式不使用黑名单过滤"""
return None

View File

@@ -0,0 +1,91 @@
"""
管道目标提供者模块
提供基于管道阶段输出的目标提供者实现。
用于 Phase 2 管道模式的阶段间数据传递。
"""
from dataclasses import dataclass, field
from typing import Iterator, Optional, List, Dict, Any
from .base import TargetProvider, ProviderContext
@dataclass
class StageOutput:
"""
阶段输出数据
用于在管道阶段之间传递数据。
Attributes:
hosts: 主机列表(域名/IP
urls: URL 列表
new_targets: 新发现的目标列表
stats: 统计信息
success: 是否成功
error: 错误信息
"""
hosts: List[str] = field(default_factory=list)
urls: List[str] = field(default_factory=list)
new_targets: List[str] = field(default_factory=list)
stats: Dict[str, Any] = field(default_factory=dict)
success: bool = True
error: Optional[str] = None
class PipelineTargetProvider(TargetProvider):
"""
管道目标提供者 - 使用上一阶段的输出
用于 Phase 2 管道模式的阶段间数据传递。
特点:
- 不查询数据库
- 不应用黑名单过滤(数据已在上一阶段过滤)
- 直接使用 StageOutput 中的数据
使用方式Phase 2
stage1_output = stage1.run(input)
provider = PipelineTargetProvider(
previous_output=stage1_output,
target_id=123
)
for host in provider.iter_hosts():
stage2.scan(host)
"""
def __init__(
self,
previous_output: StageOutput,
target_id: Optional[int] = None,
context: Optional[ProviderContext] = None
):
"""
初始化管道目标提供者
Args:
previous_output: 上一阶段的输出
target_id: 可选,关联到某个 Target用于保存结果
context: Provider 上下文
"""
ctx = context or ProviderContext(target_id=target_id)
super().__init__(ctx)
self._previous_output = previous_output
def _iter_raw_hosts(self) -> Iterator[str]:
"""迭代上一阶段输出的原始主机(可能包含 CIDR"""
yield from self._previous_output.hosts
def iter_urls(self) -> Iterator[str]:
"""迭代上一阶段输出的 URL"""
yield from self._previous_output.urls
def get_blacklist_filter(self) -> None:
"""管道传递的数据已经过滤过了"""
return None
@property
def previous_output(self) -> StageOutput:
"""返回上一阶段的输出"""
return self._previous_output

View File

@@ -0,0 +1,175 @@
"""
快照目标提供者模块
提供基于快照表的目标提供者实现。
用于快速扫描的阶段间数据传递。
"""
import logging
from typing import Iterator, Optional, Literal
from .base import TargetProvider, ProviderContext
logger = logging.getLogger(__name__)
# 快照类型定义
SnapshotType = Literal["subdomain", "website", "endpoint", "host_port"]
class SnapshotTargetProvider(TargetProvider):
"""
快照目标提供者 - 从快照表读取本次扫描的数据
用于快速扫描的阶段间数据传递,解决精确扫描控制问题。
核心价值:
- 只返回本次扫描scan_id发现的资产
- 避免扫描历史数据DatabaseTargetProvider 会扫描所有历史资产)
特点:
- 通过 scan_id 过滤快照表
- 不应用黑名单过滤(数据已在上一阶段过滤)
- 支持多种快照类型subdomain/website/endpoint/host_port
使用场景:
# 快速扫描流程
用户输入: a.test.com
创建 Target: test.com (id=1)
创建 Scan: scan_id=100
# 阶段1: 子域名发现
provider = ListTargetProvider(
targets=["a.test.com"],
context=ProviderContext(target_id=1, scan_id=100)
)
# 发现: b.a.test.com, c.a.test.com
# 保存: SubdomainSnapshot(scan_id=100) + Subdomain(target_id=1)
# 阶段2: 端口扫描
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain",
context=ProviderContext(target_id=1, scan_id=100)
)
# 只返回: b.a.test.com, c.a.test.com本次扫描发现的
# 不返回: www.test.com, api.test.com历史数据
# 阶段3: 网站扫描
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="host_port",
context=ProviderContext(target_id=1, scan_id=100)
)
# 只返回本次扫描发现的 IP:Port
"""
def __init__(
self,
scan_id: int,
snapshot_type: SnapshotType,
context: Optional[ProviderContext] = None
):
"""
初始化快照目标提供者
Args:
scan_id: 扫描任务 ID必需
snapshot_type: 快照类型
- "subdomain": 子域名快照SubdomainSnapshot
- "website": 网站快照WebsiteSnapshot
- "endpoint": 端点快照EndpointSnapshot
- "host_port": 主机端口映射快照HostPortMappingSnapshot
context: Provider 上下文
"""
ctx = context or ProviderContext()
ctx.scan_id = scan_id
super().__init__(ctx)
self._scan_id = scan_id
self._snapshot_type = snapshot_type
def _iter_raw_hosts(self) -> Iterator[str]:
"""
从快照表迭代主机列表
根据 snapshot_type 选择不同的快照表:
- subdomain: SubdomainSnapshot.name
- host_port: HostPortMappingSnapshot.host (返回 host:port 格式,不经过验证)
"""
if self._snapshot_type == "subdomain":
from apps.asset.services.snapshot import SubdomainSnapshotsService
service = SubdomainSnapshotsService()
yield from service.iter_subdomain_names_by_scan(
scan_id=self._scan_id,
chunk_size=1000
)
elif self._snapshot_type == "host_port":
# host_port 类型不使用 _iter_raw_hosts直接在 iter_hosts 中处理
# 这里返回空,避免被基类的 iter_hosts 调用
return
else:
# 其他类型暂不支持 iter_hosts
logger.warning(
"快照类型 '%s' 不支持 iter_hosts返回空迭代器",
self._snapshot_type
)
return
def iter_hosts(self) -> Iterator[str]:
"""
迭代主机列表
对于 host_port 类型,返回 host:port 格式,不经过 CIDR 展开验证
"""
if self._snapshot_type == "host_port":
# host_port 类型直接返回 host:port不经过 _expand_host 验证
from apps.asset.services.snapshot import HostPortMappingSnapshotsService
service = HostPortMappingSnapshotsService()
queryset = service.get_by_scan(scan_id=self._scan_id)
for mapping in queryset.iterator(chunk_size=1000):
yield f"{mapping.host}:{mapping.port}"
else:
# 其他类型使用基类的 iter_hosts会调用 _iter_raw_hosts 并展开 CIDR
yield from super().iter_hosts()
def iter_urls(self) -> Iterator[str]:
"""
从快照表迭代 URL 列表
根据 snapshot_type 选择不同的快照表:
- website: WebsiteSnapshot.url
- endpoint: EndpointSnapshot.url
"""
if self._snapshot_type == "website":
from apps.asset.services.snapshot import WebsiteSnapshotsService
service = WebsiteSnapshotsService()
yield from service.iter_website_urls_by_scan(
scan_id=self._scan_id,
chunk_size=1000
)
elif self._snapshot_type == "endpoint":
from apps.asset.services.snapshot import EndpointSnapshotsService
service = EndpointSnapshotsService()
# 从快照表获取端点 URL
queryset = service.get_by_scan(scan_id=self._scan_id)
for endpoint in queryset.iterator(chunk_size=1000):
yield endpoint.url
else:
# 其他类型暂不支持 iter_urls
logger.warning(
"快照类型 '%s' 不支持 iter_urls返回空迭代器",
self._snapshot_type
)
return
def get_blacklist_filter(self) -> None:
"""快照数据已在上一阶段过滤过了"""
return None
@property
def snapshot_type(self) -> SnapshotType:
"""返回快照类型"""
return self._snapshot_type

View File

@@ -0,0 +1,3 @@
"""
扫描目标提供者测试模块
"""

View File

@@ -0,0 +1,256 @@
"""
通用属性测试
包含跨多个 Provider 的通用属性测试:
- Property 4: Context Propagation
- Property 5: Non-Database Provider Blacklist Filter
- Property 7: CIDR Expansion Consistency
"""
import pytest
from hypothesis import given, strategies as st, settings
from ipaddress import IPv4Network
from apps.scan.providers import (
ProviderContext,
ListTargetProvider,
DatabaseTargetProvider,
PipelineTargetProvider,
SnapshotTargetProvider
)
from apps.scan.providers.pipeline_provider import StageOutput
class TestContextPropagation:
"""
Property 4: Context Propagation
*For any* ProviderContext传入 Provider 构造函数后,
Provider 的 target_id 和 scan_id 属性应该与 context 中的值一致。
**Validates: Requirements 1.3, 1.5, 7.4, 7.5**
"""
@given(
target_id=st.integers(min_value=1, max_value=10000),
scan_id=st.integers(min_value=1, max_value=10000)
)
@settings(max_examples=100)
def test_property_4_list_provider_context_propagation(self, target_id, scan_id):
"""
Property 4: Context Propagation (ListTargetProvider)
Feature: scan-target-provider, Property 4: Context Propagation
**Validates: Requirements 1.3, 1.5, 7.4, 7.5**
"""
ctx = ProviderContext(target_id=target_id, scan_id=scan_id)
provider = ListTargetProvider(targets=["example.com"], context=ctx)
assert provider.target_id == target_id
assert provider.scan_id == scan_id
assert provider.context.target_id == target_id
assert provider.context.scan_id == scan_id
@given(
target_id=st.integers(min_value=1, max_value=10000),
scan_id=st.integers(min_value=1, max_value=10000)
)
@settings(max_examples=100)
def test_property_4_database_provider_context_propagation(self, target_id, scan_id):
"""
Property 4: Context Propagation (DatabaseTargetProvider)
Feature: scan-target-provider, Property 4: Context Propagation
**Validates: Requirements 1.3, 1.5, 7.4, 7.5**
"""
ctx = ProviderContext(target_id=999, scan_id=scan_id)
# DatabaseTargetProvider 会覆盖 context 中的 target_id
provider = DatabaseTargetProvider(target_id=target_id, context=ctx)
assert provider.target_id == target_id # 使用构造函数参数
assert provider.scan_id == scan_id # 使用 context 中的值
assert provider.context.target_id == target_id
assert provider.context.scan_id == scan_id
@given(
target_id=st.integers(min_value=1, max_value=10000),
scan_id=st.integers(min_value=1, max_value=10000)
)
@settings(max_examples=100)
def test_property_4_pipeline_provider_context_propagation(self, target_id, scan_id):
"""
Property 4: Context Propagation (PipelineTargetProvider)
Feature: scan-target-provider, Property 4: Context Propagation
**Validates: Requirements 1.3, 1.5, 7.4, 7.5**
"""
ctx = ProviderContext(target_id=target_id, scan_id=scan_id)
stage_output = StageOutput(hosts=["example.com"])
provider = PipelineTargetProvider(previous_output=stage_output, context=ctx)
assert provider.target_id == target_id
assert provider.scan_id == scan_id
assert provider.context.target_id == target_id
assert provider.context.scan_id == scan_id
@given(
target_id=st.integers(min_value=1, max_value=10000),
scan_id=st.integers(min_value=1, max_value=10000)
)
@settings(max_examples=100)
def test_property_4_snapshot_provider_context_propagation(self, target_id, scan_id):
"""
Property 4: Context Propagation (SnapshotTargetProvider)
Feature: scan-target-provider, Property 4: Context Propagation
**Validates: Requirements 1.3, 1.5, 7.4, 7.5**
"""
ctx = ProviderContext(target_id=target_id, scan_id=999)
# SnapshotTargetProvider 会覆盖 context 中的 scan_id
provider = SnapshotTargetProvider(
scan_id=scan_id,
snapshot_type="subdomain",
context=ctx
)
assert provider.target_id == target_id # 使用 context 中的值
assert provider.scan_id == scan_id # 使用构造函数参数
assert provider.context.target_id == target_id
assert provider.context.scan_id == scan_id
class TestNonDatabaseProviderBlacklistFilter:
"""
Property 5: Non-Database Provider Blacklist Filter
*For any* ListTargetProvider 或 PipelineTargetProvider 实例,
get_blacklist_filter() 方法应该返回 None。
**Validates: Requirements 3.4, 9.4, 9.5**
"""
@given(targets=st.lists(st.text(min_size=1, max_size=20), max_size=10))
@settings(max_examples=100)
def test_property_5_list_provider_no_blacklist(self, targets):
"""
Property 5: Non-Database Provider Blacklist Filter (ListTargetProvider)
Feature: scan-target-provider, Property 5: Non-Database Provider Blacklist Filter
**Validates: Requirements 3.4, 9.4, 9.5**
"""
provider = ListTargetProvider(targets=targets)
assert provider.get_blacklist_filter() is None
@given(hosts=st.lists(st.text(min_size=1, max_size=20), max_size=10))
@settings(max_examples=100)
def test_property_5_pipeline_provider_no_blacklist(self, hosts):
"""
Property 5: Non-Database Provider Blacklist Filter (PipelineTargetProvider)
Feature: scan-target-provider, Property 5: Non-Database Provider Blacklist Filter
**Validates: Requirements 3.4, 9.4, 9.5**
"""
stage_output = StageOutput(hosts=hosts)
provider = PipelineTargetProvider(previous_output=stage_output)
assert provider.get_blacklist_filter() is None
def test_property_5_snapshot_provider_no_blacklist(self):
"""
Property 5: Non-Database Provider Blacklist Filter (SnapshotTargetProvider)
Feature: scan-target-provider, Property 5: Non-Database Provider Blacklist Filter
**Validates: Requirements 3.4, 9.4, 9.5**
"""
provider = SnapshotTargetProvider(scan_id=1, snapshot_type="subdomain")
assert provider.get_blacklist_filter() is None
class TestCIDRExpansionConsistency:
"""
Property 7: CIDR Expansion Consistency
*For any* CIDR 字符串(如 "192.168.1.0/24"),所有 Provider 的 iter_hosts()
方法应该将其展开为相同的单个 IP 地址列表。
**Validates: Requirements 1.1, 3.6**
"""
@given(
# 生成小的 CIDR 范围以避免测试超时
network_prefix=st.integers(min_value=1, max_value=254),
cidr_suffix=st.integers(min_value=28, max_value=30) # /28 = 16 IPs, /30 = 4 IPs
)
@settings(max_examples=50, deadline=None)
def test_property_7_cidr_expansion_consistency(self, network_prefix, cidr_suffix):
"""
Property 7: CIDR Expansion Consistency
Feature: scan-target-provider, Property 7: CIDR Expansion Consistency
**Validates: Requirements 1.1, 3.6**
For any CIDR string, all Providers should expand it to the same IP list.
"""
cidr = f"192.168.{network_prefix}.0/{cidr_suffix}"
# 计算预期的 IP 列表
network = IPv4Network(cidr, strict=False)
# 排除网络地址和广播地址
expected_ips = [str(ip) for ip in network.hosts()]
# 如果 CIDR 太小(/31 或 /32使用所有地址
if not expected_ips:
expected_ips = [str(ip) for ip in network]
# ListTargetProvider
list_provider = ListTargetProvider(targets=[cidr])
list_result = list(list_provider.iter_hosts())
# PipelineTargetProvider
stage_output = StageOutput(hosts=[cidr])
pipeline_provider = PipelineTargetProvider(previous_output=stage_output)
pipeline_result = list(pipeline_provider.iter_hosts())
# 验证:所有 Provider 展开的结果应该一致
assert list_result == expected_ips, f"ListProvider CIDR expansion mismatch for {cidr}"
assert pipeline_result == expected_ips, f"PipelineProvider CIDR expansion mismatch for {cidr}"
assert list_result == pipeline_result, f"Providers produce different results for {cidr}"
def test_cidr_expansion_with_multiple_cidrs(self):
"""测试多个 CIDR 的展开一致性"""
cidrs = ["192.168.1.0/30", "10.0.0.0/30"]
# 计算预期结果
expected_ips = []
for cidr in cidrs:
network = IPv4Network(cidr, strict=False)
expected_ips.extend([str(ip) for ip in network.hosts()])
# ListTargetProvider
list_provider = ListTargetProvider(targets=cidrs)
list_result = list(list_provider.iter_hosts())
# PipelineTargetProvider
stage_output = StageOutput(hosts=cidrs)
pipeline_provider = PipelineTargetProvider(previous_output=stage_output)
pipeline_result = list(pipeline_provider.iter_hosts())
# 验证
assert list_result == expected_ips
assert pipeline_result == expected_ips
assert list_result == pipeline_result
def test_mixed_hosts_and_cidrs(self):
"""测试混合主机和 CIDR 的处理"""
targets = ["example.com", "192.168.1.0/30", "test.com"]
# 计算预期结果
network = IPv4Network("192.168.1.0/30", strict=False)
cidr_ips = [str(ip) for ip in network.hosts()]
expected = ["example.com"] + cidr_ips + ["test.com"]
# ListTargetProvider
list_provider = ListTargetProvider(targets=targets)
list_result = list(list_provider.iter_hosts())
# 验证
assert list_result == expected

View File

@@ -0,0 +1,158 @@
"""
DatabaseTargetProvider 属性测试
Property 7: DatabaseTargetProvider Blacklist Application
*For any* 带有黑名单规则的 target_idDatabaseTargetProvider 的 iter_hosts() 和 iter_urls()
应该过滤掉匹配黑名单规则的目标。
**Validates: Requirements 2.3, 10.1, 10.2, 10.3**
"""
import pytest
from unittest.mock import patch, MagicMock
from hypothesis import given, strategies as st, settings
from apps.scan.providers.database_provider import DatabaseTargetProvider
from apps.scan.providers.base import ProviderContext
# 生成有效域名的策略
def valid_domain_strategy():
"""生成有效的域名"""
label = st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=2,
max_size=10
)
return st.builds(
lambda a, b, c: f"{a}.{b}.{c}",
label, label, st.sampled_from(['com', 'net', 'org', 'io'])
)
class MockBlacklistFilter:
"""模拟黑名单过滤器"""
def __init__(self, blocked_patterns: list):
self.blocked_patterns = blocked_patterns
def is_allowed(self, target: str) -> bool:
"""检查目标是否被允许(不在黑名单中)"""
for pattern in self.blocked_patterns:
if pattern in target:
return False
return True
class TestDatabaseTargetProviderProperties:
"""DatabaseTargetProvider 属性测试类"""
@given(
hosts=st.lists(valid_domain_strategy(), min_size=1, max_size=20),
blocked_keyword=st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=2,
max_size=5
)
)
@settings(max_examples=100)
def test_property_7_blacklist_filters_hosts(self, hosts, blocked_keyword):
"""
Property 7: DatabaseTargetProvider Blacklist Application (hosts)
Feature: scan-target-provider, Property 7: DatabaseTargetProvider Blacklist Application
**Validates: Requirements 2.3, 10.1, 10.2, 10.3**
For any set of hosts and a blacklist keyword, the provider should filter out
all hosts containing the blocked keyword.
"""
# 创建模拟的黑名单过滤器
mock_filter = MockBlacklistFilter([blocked_keyword])
# 创建 provider 并注入模拟的黑名单过滤器
provider = DatabaseTargetProvider(target_id=1)
provider._blacklist_filter = mock_filter
# 模拟 Target 和 SubdomainService
mock_target = MagicMock()
mock_target.type = 'domain'
mock_target.name = hosts[0] if hosts else 'example.com'
with patch('apps.targets.services.TargetService') as mock_target_service, \
patch('apps.asset.services.asset.subdomain_service.SubdomainService') as mock_subdomain_service:
mock_target_service.return_value.get_target.return_value = mock_target
mock_subdomain_service.return_value.iter_subdomain_names_by_target.return_value = iter(hosts[1:] if len(hosts) > 1 else [])
# 获取结果
result = list(provider.iter_hosts())
# 验证:所有结果都不包含被阻止的关键词
for host in result:
assert blocked_keyword not in host, f"Host '{host}' should be filtered by blacklist keyword '{blocked_keyword}'"
# 验证:所有不包含关键词的主机都应该在结果中
if hosts:
all_hosts = [hosts[0]] + [h for h in hosts[1:] if h != hosts[0]]
expected_allowed = [h for h in all_hosts if blocked_keyword not in h]
else:
expected_allowed = []
assert set(result) == set(expected_allowed)
class TestDatabaseTargetProviderUnit:
"""DatabaseTargetProvider 单元测试类"""
def test_target_id_in_context(self):
"""测试 target_id 正确设置到上下文中"""
provider = DatabaseTargetProvider(target_id=123)
assert provider.target_id == 123
assert provider.context.target_id == 123
def test_context_propagation(self):
"""测试上下文传递"""
ctx = ProviderContext(scan_id=789)
provider = DatabaseTargetProvider(target_id=123, context=ctx)
assert provider.target_id == 123 # target_id 被覆盖
assert provider.scan_id == 789
def test_blacklist_filter_lazy_loading(self):
"""测试黑名单过滤器延迟加载"""
provider = DatabaseTargetProvider(target_id=123)
# 初始时 _blacklist_filter 为 None
assert provider._blacklist_filter is None
# 模拟 BlacklistService
with patch('apps.common.services.BlacklistService') as mock_service, \
patch('apps.common.utils.BlacklistFilter') as mock_filter_class:
mock_service.return_value.get_rules.return_value = []
mock_filter_instance = MagicMock()
mock_filter_class.return_value = mock_filter_instance
# 第一次调用
result1 = provider.get_blacklist_filter()
assert result1 == mock_filter_instance
# 第二次调用应该返回缓存的实例
result2 = provider.get_blacklist_filter()
assert result2 == mock_filter_instance
# BlacklistService 只应该被调用一次
mock_service.return_value.get_rules.assert_called_once_with(123)
def test_nonexistent_target_returns_empty(self):
"""测试不存在的 target 返回空迭代器"""
provider = DatabaseTargetProvider(target_id=99999)
with patch('apps.targets.services.TargetService') as mock_service, \
patch('apps.common.services.BlacklistService') as mock_blacklist_service:
mock_service.return_value.get_target.return_value = None
mock_blacklist_service.return_value.get_rules.return_value = []
result = list(provider.iter_hosts())
assert result == []

View File

@@ -0,0 +1,152 @@
"""
ListTargetProvider 属性测试
Property 1: ListTargetProvider Round-Trip
*For any* 主机列表和 URL 列表,创建 ListTargetProvider 后迭代 iter_hosts() 和 iter_urls()
应该返回与输入相同的元素(顺序相同)。
**Validates: Requirements 3.1, 3.2**
"""
import pytest
from hypothesis import given, strategies as st, settings, assume
from apps.scan.providers.list_provider import ListTargetProvider
from apps.scan.providers.base import ProviderContext
# 生成有效域名的策略
def valid_domain_strategy():
"""生成有效的域名"""
# 生成简单的域名格式: subdomain.domain.tld
label = st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=2,
max_size=10
)
return st.builds(
lambda a, b, c: f"{a}.{b}.{c}",
label, label, st.sampled_from(['com', 'net', 'org', 'io'])
)
# 生成有效 IP 地址的策略
def valid_ip_strategy():
"""生成有效的 IPv4 地址"""
octet = st.integers(min_value=1, max_value=254)
return st.builds(
lambda a, b, c, d: f"{a}.{b}.{c}.{d}",
octet, octet, octet, octet
)
# 组合策略:域名或 IP
host_strategy = st.one_of(valid_domain_strategy(), valid_ip_strategy())
# 生成有效 URL 的策略
def valid_url_strategy():
"""生成有效的 URL"""
domain = valid_domain_strategy()
return st.builds(
lambda d, path: f"https://{d}/{path}" if path else f"https://{d}",
domain,
st.one_of(
st.just(""),
st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=1,
max_size=10
)
)
)
url_strategy = valid_url_strategy()
class TestListTargetProviderProperties:
"""ListTargetProvider 属性测试类"""
@given(hosts=st.lists(host_strategy, max_size=50))
@settings(max_examples=100)
def test_property_1_hosts_round_trip(self, hosts):
"""
Property 1: ListTargetProvider Round-Trip (hosts)
Feature: scan-target-provider, Property 1: ListTargetProvider Round-Trip
**Validates: Requirements 3.1, 3.2**
For any host list, creating a ListTargetProvider and iterating iter_hosts()
should return the same elements in the same order.
"""
# ListTargetProvider 使用 targets 参数,自动分类为 hosts/urls
provider = ListTargetProvider(targets=hosts)
result = list(provider.iter_hosts())
assert result == hosts
@given(urls=st.lists(url_strategy, max_size=50))
@settings(max_examples=100)
def test_property_1_urls_round_trip(self, urls):
"""
Property 1: ListTargetProvider Round-Trip (urls)
Feature: scan-target-provider, Property 1: ListTargetProvider Round-Trip
**Validates: Requirements 3.1, 3.2**
For any URL list, creating a ListTargetProvider and iterating iter_urls()
should return the same elements in the same order.
"""
# ListTargetProvider 使用 targets 参数,自动分类为 hosts/urls
provider = ListTargetProvider(targets=urls)
result = list(provider.iter_urls())
assert result == urls
@given(
hosts=st.lists(host_strategy, max_size=30),
urls=st.lists(url_strategy, max_size=30)
)
@settings(max_examples=100)
def test_property_1_combined_round_trip(self, hosts, urls):
"""
Property 1: ListTargetProvider Round-Trip (combined)
Feature: scan-target-provider, Property 1: ListTargetProvider Round-Trip
**Validates: Requirements 3.1, 3.2**
For any combination of hosts and URLs, both should round-trip correctly.
"""
# 合并 hosts 和 urlsListTargetProvider 会自动分类
combined = hosts + urls
provider = ListTargetProvider(targets=combined)
hosts_result = list(provider.iter_hosts())
urls_result = list(provider.iter_urls())
assert hosts_result == hosts
assert urls_result == urls
class TestListTargetProviderUnit:
"""ListTargetProvider 单元测试类"""
def test_empty_lists(self):
"""测试空列表返回空迭代器 - Requirements 3.5"""
provider = ListTargetProvider()
assert list(provider.iter_hosts()) == []
assert list(provider.iter_urls()) == []
def test_blacklist_filter_returns_none(self):
"""测试黑名单过滤器返回 None - Requirements 3.4"""
provider = ListTargetProvider(targets=["example.com"])
assert provider.get_blacklist_filter() is None
def test_target_id_association(self):
"""测试 target_id 关联 - Requirements 3.3"""
ctx = ProviderContext(target_id=123)
provider = ListTargetProvider(targets=["example.com"], context=ctx)
assert provider.target_id == 123
def test_context_propagation(self):
"""测试上下文传递"""
ctx = ProviderContext(target_id=456, scan_id=789)
provider = ListTargetProvider(targets=["example.com"], context=ctx)
assert provider.target_id == 456
assert provider.scan_id == 789

View File

@@ -0,0 +1,180 @@
"""
PipelineTargetProvider 属性测试
Property 3: PipelineTargetProvider Round-Trip
*For any* StageOutput 对象PipelineTargetProvider 的 iter_hosts() 和 iter_urls()
应该返回与 StageOutput 中 hosts 和 urls 列表相同的元素。
**Validates: Requirements 5.1, 5.2**
"""
import pytest
from hypothesis import given, strategies as st, settings
from apps.scan.providers.pipeline_provider import PipelineTargetProvider, StageOutput
from apps.scan.providers.base import ProviderContext
# 生成有效域名的策略
def valid_domain_strategy():
"""生成有效的域名"""
label = st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=2,
max_size=10
)
return st.builds(
lambda a, b, c: f"{a}.{b}.{c}",
label, label, st.sampled_from(['com', 'net', 'org', 'io'])
)
# 生成有效 IP 地址的策略
def valid_ip_strategy():
"""生成有效的 IPv4 地址"""
octet = st.integers(min_value=1, max_value=254)
return st.builds(
lambda a, b, c, d: f"{a}.{b}.{c}.{d}",
octet, octet, octet, octet
)
# 组合策略:域名或 IP
host_strategy = st.one_of(valid_domain_strategy(), valid_ip_strategy())
# 生成有效 URL 的策略
def valid_url_strategy():
"""生成有效的 URL"""
domain = valid_domain_strategy()
return st.builds(
lambda d, path: f"https://{d}/{path}" if path else f"https://{d}",
domain,
st.one_of(
st.just(""),
st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=1,
max_size=10
)
)
)
url_strategy = valid_url_strategy()
class TestPipelineTargetProviderProperties:
"""PipelineTargetProvider 属性测试类"""
@given(hosts=st.lists(host_strategy, max_size=50))
@settings(max_examples=100)
def test_property_3_hosts_round_trip(self, hosts):
"""
Property 3: PipelineTargetProvider Round-Trip (hosts)
Feature: scan-target-provider, Property 3: PipelineTargetProvider Round-Trip
**Validates: Requirements 5.1, 5.2**
For any StageOutput with hosts, PipelineTargetProvider should return
the same hosts in the same order.
"""
stage_output = StageOutput(hosts=hosts)
provider = PipelineTargetProvider(previous_output=stage_output)
result = list(provider.iter_hosts())
assert result == hosts
@given(urls=st.lists(url_strategy, max_size=50))
@settings(max_examples=100)
def test_property_3_urls_round_trip(self, urls):
"""
Property 3: PipelineTargetProvider Round-Trip (urls)
Feature: scan-target-provider, Property 3: PipelineTargetProvider Round-Trip
**Validates: Requirements 5.1, 5.2**
For any StageOutput with urls, PipelineTargetProvider should return
the same urls in the same order.
"""
stage_output = StageOutput(urls=urls)
provider = PipelineTargetProvider(previous_output=stage_output)
result = list(provider.iter_urls())
assert result == urls
@given(
hosts=st.lists(host_strategy, max_size=30),
urls=st.lists(url_strategy, max_size=30)
)
@settings(max_examples=100)
def test_property_3_combined_round_trip(self, hosts, urls):
"""
Property 3: PipelineTargetProvider Round-Trip (combined)
Feature: scan-target-provider, Property 3: PipelineTargetProvider Round-Trip
**Validates: Requirements 5.1, 5.2**
For any StageOutput with both hosts and urls, both should round-trip correctly.
"""
stage_output = StageOutput(hosts=hosts, urls=urls)
provider = PipelineTargetProvider(previous_output=stage_output)
hosts_result = list(provider.iter_hosts())
urls_result = list(provider.iter_urls())
assert hosts_result == hosts
assert urls_result == urls
class TestPipelineTargetProviderUnit:
"""PipelineTargetProvider 单元测试类"""
def test_empty_stage_output(self):
"""测试空 StageOutput 返回空迭代器 - Requirements 5.5"""
stage_output = StageOutput()
provider = PipelineTargetProvider(previous_output=stage_output)
assert list(provider.iter_hosts()) == []
assert list(provider.iter_urls()) == []
def test_blacklist_filter_returns_none(self):
"""测试黑名单过滤器返回 None - Requirements 5.3"""
stage_output = StageOutput(hosts=["example.com"])
provider = PipelineTargetProvider(previous_output=stage_output)
assert provider.get_blacklist_filter() is None
def test_target_id_association(self):
"""测试 target_id 关联 - Requirements 5.4"""
stage_output = StageOutput(hosts=["example.com"])
provider = PipelineTargetProvider(previous_output=stage_output, target_id=123)
assert provider.target_id == 123
def test_context_propagation(self):
"""测试上下文传递"""
ctx = ProviderContext(target_id=456, scan_id=789)
stage_output = StageOutput(hosts=["example.com"])
provider = PipelineTargetProvider(previous_output=stage_output, context=ctx)
assert provider.target_id == 456
assert provider.scan_id == 789
def test_previous_output_property(self):
"""测试 previous_output 属性"""
stage_output = StageOutput(hosts=["example.com"], urls=["https://example.com"])
provider = PipelineTargetProvider(previous_output=stage_output)
assert provider.previous_output is stage_output
assert provider.previous_output.hosts == ["example.com"]
assert provider.previous_output.urls == ["https://example.com"]
def test_stage_output_with_metadata(self):
"""测试带元数据的 StageOutput"""
stage_output = StageOutput(
hosts=["example.com"],
urls=["https://example.com"],
new_targets=["new.example.com"],
stats={"count": 1},
success=True,
error=None
)
provider = PipelineTargetProvider(previous_output=stage_output)
assert list(provider.iter_hosts()) == ["example.com"]
assert list(provider.iter_urls()) == ["https://example.com"]
assert provider.previous_output.new_targets == ["new.example.com"]
assert provider.previous_output.stats == {"count": 1}

View File

@@ -0,0 +1,191 @@
"""
SnapshotTargetProvider 单元测试
"""
import pytest
from unittest.mock import Mock, patch
from apps.scan.providers import SnapshotTargetProvider, ProviderContext
class TestSnapshotTargetProvider:
"""SnapshotTargetProvider 测试类"""
def test_init_with_scan_id_and_type(self):
"""测试初始化"""
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain"
)
assert provider.scan_id == 100
assert provider.snapshot_type == "subdomain"
assert provider.target_id is None # 默认 context
def test_init_with_context(self):
"""测试带 context 初始化"""
ctx = ProviderContext(target_id=1, scan_id=100)
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain",
context=ctx
)
assert provider.scan_id == 100
assert provider.target_id == 1
assert provider.snapshot_type == "subdomain"
@patch('apps.asset.services.snapshot.SubdomainSnapshotsService')
def test_iter_hosts_subdomain(self, mock_service_class):
"""测试从子域名快照迭代主机"""
# Mock service
mock_service = Mock()
mock_service.iter_subdomain_names_by_scan.return_value = iter([
"a.example.com",
"b.example.com"
])
mock_service_class.return_value = mock_service
# 创建 provider
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain"
)
# 迭代主机
hosts = list(provider.iter_hosts())
assert hosts == ["a.example.com", "b.example.com"]
mock_service.iter_subdomain_names_by_scan.assert_called_once_with(
scan_id=100,
chunk_size=1000
)
@patch('apps.asset.services.snapshot.HostPortMappingSnapshotsService')
def test_iter_hosts_host_port(self, mock_service_class):
"""测试从主机端口映射快照迭代主机"""
# Mock queryset
mock_mapping1 = Mock()
mock_mapping1.host = "example.com"
mock_mapping1.port = 80
mock_mapping2 = Mock()
mock_mapping2.host = "example.com"
mock_mapping2.port = 443
mock_queryset = Mock()
mock_queryset.iterator.return_value = iter([mock_mapping1, mock_mapping2])
# Mock service
mock_service = Mock()
mock_service.get_by_scan.return_value = mock_queryset
mock_service_class.return_value = mock_service
# 创建 provider
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="host_port"
)
# 迭代主机
hosts = list(provider.iter_hosts())
assert hosts == ["example.com:80", "example.com:443"]
mock_service.get_by_scan.assert_called_once_with(scan_id=100)
@patch('apps.asset.services.snapshot.WebsiteSnapshotsService')
def test_iter_urls_website(self, mock_service_class):
"""测试从网站快照迭代 URL"""
# Mock service
mock_service = Mock()
mock_service.iter_website_urls_by_scan.return_value = iter([
"http://example.com",
"https://example.com"
])
mock_service_class.return_value = mock_service
# 创建 provider
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="website"
)
# 迭代 URL
urls = list(provider.iter_urls())
assert urls == ["http://example.com", "https://example.com"]
mock_service.iter_website_urls_by_scan.assert_called_once_with(
scan_id=100,
chunk_size=1000
)
@patch('apps.asset.services.snapshot.EndpointSnapshotsService')
def test_iter_urls_endpoint(self, mock_service_class):
"""测试从端点快照迭代 URL"""
# Mock queryset
mock_endpoint1 = Mock()
mock_endpoint1.url = "http://example.com/api/v1"
mock_endpoint2 = Mock()
mock_endpoint2.url = "http://example.com/api/v2"
mock_queryset = Mock()
mock_queryset.iterator.return_value = iter([mock_endpoint1, mock_endpoint2])
# Mock service
mock_service = Mock()
mock_service.get_by_scan.return_value = mock_queryset
mock_service_class.return_value = mock_service
# 创建 provider
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="endpoint"
)
# 迭代 URL
urls = list(provider.iter_urls())
assert urls == ["http://example.com/api/v1", "http://example.com/api/v2"]
mock_service.get_by_scan.assert_called_once_with(scan_id=100)
def test_iter_hosts_unsupported_type(self):
"""测试不支持的快照类型iter_hosts"""
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="website" # website 不支持 iter_hosts
)
hosts = list(provider.iter_hosts())
assert hosts == []
def test_iter_urls_unsupported_type(self):
"""测试不支持的快照类型iter_urls"""
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain" # subdomain 不支持 iter_urls
)
urls = list(provider.iter_urls())
assert urls == []
def test_get_blacklist_filter(self):
"""测试黑名单过滤器(快照模式不使用黑名单)"""
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain"
)
assert provider.get_blacklist_filter() is None
def test_context_propagation(self):
"""测试上下文传递"""
ctx = ProviderContext(target_id=456, scan_id=789)
provider = SnapshotTargetProvider(
scan_id=100, # 会被 context 覆盖
snapshot_type="subdomain",
context=ctx
)
assert provider.target_id == 456
assert provider.scan_id == 100 # scan_id 在 __init__ 中被设置

View File

@@ -16,7 +16,6 @@ from django.utils import timezone
from apps.scan.models import Scan
from apps.targets.models import Target
from apps.engine.models import ScanEngine
from apps.common.definitions import ScanStatus
from apps.common.decorators import auto_ensure_db_connection
@@ -40,7 +39,7 @@ class DjangoScanRepository:
Args:
scan_id: 扫描任务 ID
prefetch_relations: 是否预加载关联对象(engine, target
prefetch_relations: 是否预加载关联对象(target, worker
默认 False只在需要展示关联信息时设为 True
for_update: 是否加锁(用于更新场景)
@@ -56,7 +55,7 @@ class DjangoScanRepository:
# 预加载关联对象(性能优化:默认不加载)
if prefetch_relations:
queryset = queryset.select_related('engine', 'target')
queryset = queryset.select_related('target', 'worker')
return queryset.get(id=scan_id)
except Scan.DoesNotExist: # type: ignore # pylint: disable=no-member
@@ -79,7 +78,7 @@ class DjangoScanRepository:
Note:
- 使用默认的阻塞模式(等待锁释放)
- 不包含关联对象(engine, target),如需关联对象请使用 get_by_id()
- 不包含关联对象(target, worker),如需关联对象请使用 get_by_id()
"""
try:
return Scan.objects.select_for_update().get(id=scan_id) # type: ignore # pylint: disable=no-member
@@ -103,7 +102,9 @@ class DjangoScanRepository:
def create(self,
target: Target,
engine: ScanEngine,
engine_ids: List[int],
engine_names: List[str],
yaml_configuration: str,
results_dir: str,
status: ScanStatus = ScanStatus.INITIATED
) -> Scan:
@@ -112,7 +113,9 @@ class DjangoScanRepository:
Args:
target: 扫描目标
engine: 扫描引擎
engine_ids: 引擎 ID 列表
engine_names: 引擎名称列表
yaml_configuration: YAML 格式的扫描配置
results_dir: 结果目录
status: 初始状态
@@ -121,7 +124,9 @@ class DjangoScanRepository:
"""
scan = Scan(
target=target,
engine=engine,
engine_ids=engine_ids,
engine_names=engine_names,
yaml_configuration=yaml_configuration,
results_dir=results_dir,
status=status,
container_ids=[]
@@ -231,14 +236,14 @@ class DjangoScanRepository:
获取所有扫描任务
Args:
prefetch_relations: 是否预加载关联对象(engine, target
prefetch_relations: 是否预加载关联对象(target, worker
Returns:
Scan QuerySet
"""
queryset = Scan.objects.all() # type: ignore # pylint: disable=no-member
if prefetch_relations:
queryset = queryset.select_related('engine', 'target')
queryset = queryset.select_related('target', 'worker')
return queryset.order_by('-created_at')
@@ -459,6 +464,7 @@ class DjangoScanRepository:
'endpoints': scan.endpoint_snapshots.count(),
'ips': ips_count,
'directories': scan.directory_snapshots.count(),
'screenshots': scan.screenshot_snapshots.count(),
'vulns_total': total_vulns,
'vulns_critical': severity_stats['critical'],
'vulns_high': severity_stats['high'],
@@ -473,6 +479,7 @@ class DjangoScanRepository:
'cached_endpoints_count': stats['endpoints'],
'cached_ips_count': stats['ips'],
'cached_directories_count': stats['directories'],
'cached_screenshots_count': stats['screenshots'],
'cached_vulns_total': stats['vulns_total'],
'cached_vulns_critical': stats['vulns_critical'],
'cached_vulns_high': stats['vulns_high'],

View File

@@ -29,7 +29,9 @@ class ScheduledScanDTO:
"""
id: Optional[int] = None
name: str = ''
engine_id: int = 0
engine_ids: List[int] = None # 多引擎支持
engine_names: List[str] = None # 引擎名称列表
yaml_configuration: str = '' # YAML 格式的扫描配置
organization_id: Optional[int] = None # 组织扫描模式
target_id: Optional[int] = None # 目标扫描模式
cron_expression: Optional[str] = None
@@ -40,6 +42,11 @@ class ScheduledScanDTO:
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
def __post_init__(self):
if self.engine_ids is None:
self.engine_ids = []
if self.engine_names is None:
self.engine_names = []
@auto_ensure_db_connection
@@ -56,7 +63,7 @@ class DjangoScheduledScanRepository:
def get_by_id(self, scheduled_scan_id: int) -> Optional[ScheduledScan]:
"""根据 ID 查询定时扫描任务"""
try:
return ScheduledScan.objects.select_related('engine', 'organization', 'target').get(id=scheduled_scan_id)
return ScheduledScan.objects.select_related('organization', 'target').get(id=scheduled_scan_id)
except ScheduledScan.DoesNotExist:
return None
@@ -67,7 +74,7 @@ class DjangoScheduledScanRepository:
Returns:
QuerySet
"""
return ScheduledScan.objects.select_related('engine', 'organization', 'target').order_by('-created_at')
return ScheduledScan.objects.select_related('organization', 'target').order_by('-created_at')
def get_all(self, page: int = 1, page_size: int = 10) -> Tuple[List[ScheduledScan], int]:
"""
@@ -87,7 +94,7 @@ class DjangoScheduledScanRepository:
def get_enabled(self) -> List[ScheduledScan]:
"""获取所有启用的定时扫描任务"""
return list(
ScheduledScan.objects.select_related('engine', 'target')
ScheduledScan.objects.select_related('target')
.filter(is_enabled=True)
.order_by('-created_at')
)
@@ -105,7 +112,9 @@ class DjangoScheduledScanRepository:
with transaction.atomic():
scheduled_scan = ScheduledScan.objects.create(
name=dto.name,
engine_id=dto.engine_id,
engine_ids=dto.engine_ids,
engine_names=dto.engine_names,
yaml_configuration=dto.yaml_configuration,
organization_id=dto.organization_id, # 组织扫描模式
target_id=dto.target_id if not dto.organization_id else None, # 目标扫描模式
cron_expression=dto.cron_expression,
@@ -134,8 +143,12 @@ class DjangoScheduledScanRepository:
# 更新基本字段
if dto.name:
scheduled_scan.name = dto.name
if dto.engine_id:
scheduled_scan.engine_id = dto.engine_id
if dto.engine_ids is not None:
scheduled_scan.engine_ids = dto.engine_ids
if dto.engine_names is not None:
scheduled_scan.engine_names = dto.engine_names
if dto.yaml_configuration is not None:
scheduled_scan.yaml_configuration = dto.yaml_configuration
if dto.cron_expression is not None:
scheduled_scan.cron_expression = dto.cron_expression
if dto.is_enabled is not None:

Some files were not shown because too many files have changed in this diff Show More