Compare commits

...

125 Commits

Author SHA1 Message Date
yyhuni
d7599b8599 feat(fingerprints): Add database indexes and expand test data generation
- Add database indexes on 'link' field in FingersFingerprint model for improved query performance
- Add database index on 'author' field in FingerPrintHubFingerprint model for filtering optimization
- Expand test data generation to include Fingers, FingerPrintHub, and ARL fingerprint types
- Add comprehensive fingerprint data generation methods with realistic templates and patterns
- Update test data cleanup to include all fingerprint table types
- Add i18n translations for fingerprint-related UI components and labels
- Optimize route prefetching hook for better performance
- Improve fingerprint data table columns and vulnerability columns display consistencyzxc
2025-12-31 10:04:15 +08:00
yyhuni
8eff298293 更新镜像加速逻辑 2025-12-31 08:56:55 +08:00
yyhuni
3634101c5b 添加灯塔等指纹 2025-12-31 08:55:37 +08:00
yyhuni
163973a7df feat(i18n): Add internationalization support to dropzone component
- Add useTranslations hook to DropzoneContent component for multi-language support
- Add useTranslations hook to DropzoneEmptyState component for multi-language support
- Replace hardcoded English strings with i18n translation keys in dropzone UI
- Add comprehensive translation keys for dropzone messages in en.json:
* uploadFile, uploadFiles, dragOrClick, dragOrClickReplace
* moreFiles, supports, minimum, maximum, sizeBetween
- Add corresponding Chinese translations in zh.json for all dropzone messages
- Support dynamic content in translations using parameterized keys (files count, size ranges)
- Ensure consistent user experience across English and Chinese interfaces
2025-12-30 21:19:37 +08:00
yyhuni
80ffecba3e feat(i18n): Add UI component i18n provider and standardize translation keys
- Add UiI18nProvider component to wrap UI library translations globally
- Integrate UiI18nProvider into root layout for consistent i18n support
- Standardize download action translation keys (allEndpoints → all, selectedEndpoints → selected)
- Update ExpandableTagList component prop from maxVisible to maxLines for better layout control
- Fix color scheme in dashboard stop scan button (chart-2 → primary)
- Add DOCKER_API_VERSION configuration to backend settings for Docker client compatibility
- Update task distributor to use configurable Docker API version (default 1.40)
- Add environment variable support for Docker API version in task execution commands
- Update i18n configuration and message files with standardized keys
- Ensure UI components respect application locale settings across all data tables and dialogs
2025-12-30 21:19:28 +08:00
yyhuni
3c21ac940c 恢复ssh docker 2025-12-30 20:35:51 +08:00
yyhuni
5c9f484d70 fix(frontend): Fix i18n translation key references and add missing labels
- Change "nav" translation namespace to "navigation" in scan engine and wordlists pages
- Replace parameterized translation calls with raw translation strings for cron schedule options in scheduled scan page and dashboard component
- Cast raw translation results to string type for proper TypeScript typing
- Add missing "name" and "type" labels to fingerprint section in English and Chinese message files
- Ensure consistent translation key usage across components for better maintainability
2025-12-30 18:21:16 +08:00
yyhuni
7567f6c25b 更新文字描述 2025-12-30 18:08:39 +08:00
yyhuni
0599a0b298 ansi-to-html加入 2025-12-30 18:01:29 +08:00
yyhuni
f7557fe90c ansi-to-html替代log显示 2025-12-30 18:01:22 +08:00
yyhuni
13571b9772 fix(frontend): Fix xterm SSR initialization error
- Add 100ms delay for terminal initialization to ensure DOM is mounted
- Use requestAnimationFrame for fit() to avoid dimensions error
- Add try-catch for all xterm operations
- Proper cleanup on unmount

Fixes: Cannot read properties of undefined (reading 'dimensions')
2025-12-30 17:41:38 +08:00
yyhuni
8ee76eef69 feat(frontend): Add ANSI color support for system logs
- Create AnsiLogViewer component using xterm.js
- Replace Monaco Editor with xterm for log viewing
- Native ANSI escape code rendering (colors, bold, etc.)
- Auto-scroll to bottom, clickable URLs support

Benefits:
- Colorized logs for better readability
- No more escape codes like [32m[0m in UI
- Professional terminal-like experience
2025-12-30 17:39:12 +08:00
yyhuni
2a31e29aa2 fix: Add shell quoting for command arguments
- Use shlex.quote() to escape special characters in argument values
- Fixes: 'unrecognized arguments' error when values contain spaces
- Example: target_name='example.com scan' now correctly quoted
2025-12-30 17:32:09 +08:00
yyhuni
81abc59961 Refactor: Migrate TaskDistributor to Docker SDK
- Replace CLI subprocess with Python Docker SDK
- Add DockerClientManager for unified container management
- Remove 300+ lines of shell command building code
- Enable future features: container status monitoring, log streaming

Breaking changes: None (backward compatible with existing scans)
Rollback: git reset --hard v1.0-before-docker-sdk
2025-12-30 17:23:18 +08:00
yyhuni
ffbfec6dd5 feat(stage2): Refactor TaskDistributor to use Docker SDK
- Replace CLI subprocess calls with DockerClientManager.run_container()
- Add helper methods: _build_container_command, _build_container_environment, _build_container_volumes
- Refactor execute_scan_flow() and execute_cleanup_on_all_workers() to use SDK
- Remove old CLI methods: _build_docker_command, _execute_docker_command, _execute_local_docker, _execute_ssh_docker
- Remove paramiko import (no longer needed for local workers)

Benefits:
- 300+ lines removed (CLI string building complexity)
- Type-safe container configuration (no more shlex.quote errors)
- Structured error handling (ImageNotFound, APIError)
- Ready for container status monitoring and log streaming
2025-12-30 17:20:26 +08:00
yyhuni
a0091636a8 feat(stage1): Add DockerClientManager
- Create docker_client_manager.py with local Docker client support
- Add container lifecycle management (run, status, logs, stop, remove)
- Implement structured error handling (ImageNotFound, APIError)
- Add client connection caching and reuse
- Set Docker API version to 1.40 (compatible with Docker 19.03+)
- Add dependencies: docker>=6.0.0, packaging>=21.0

TODO: Remote worker support (Docker Context or SSH tunnel)
2025-12-30 17:17:17 +08:00
yyhuni
69490ab396 feat: Add DockerClientManager for unified Docker client management
- Create docker_client_manager.py with local Docker client support
- Add container lifecycle management (run, status, logs, stop, remove)
- Implement structured error handling (ImageNotFound, APIError)
- Add client connection caching and reuse
- Set Docker API version to 1.40 (compatible with Docker 19.03+)
- Add docker>=6.0.0 and packaging>=21.0 dependencies

TODO: Remote worker support (Docker Context or SSH tunnel)
2025-12-30 17:15:29 +08:00
yyhuni
7306964abf 更新readme 2025-12-30 16:44:08 +08:00
yyhuni
cb6b0259e3 fix:响应不匹配 2025-12-30 16:40:17 +08:00
yyhuni
e1b4618e58 refactor(worker): isolate scan tools to dedicated directory
- Move scan tools base path from `/usr/local/bin` to `/opt/xingrin-tools/bin` to avoid conflicts with system tools and Python packages
- Create dedicated `/opt/xingrin-tools/bin` directory in worker Dockerfile following FHS standards
- Update PATH environment variable to prioritize project-specific tools directory
- Add `SCAN_TOOLS_PATH` environment variable to `.env.example` with documentation
- Update settings.py to use new default path with explanatory comments
- Fix TypeScript type annotation in system-logs-view.tsx for better maintainability
- Remove frontend package-lock.json to reduce repository size
- Update task distributor comment to reflect new tool location
This change improves tool isolation and prevents naming conflicts while maintaining FHS compliance.
2025-12-30 11:42:09 +08:00
yyhuni
556dcf5f62 重构日志ui功能 2025-12-30 11:13:38 +08:00
yyhuni
0628eef025 重构响应为标准响应格式 2025-12-30 10:56:26 +08:00
yyhuni
38ed8bc642 fix(scan): improve config parser validation and enable subdomain resolve timeout
- Uncomment timeout: auto setting in subdomain discovery config example
- Add validation to reject None or non-dict configuration values
- Raise ValueError with descriptive message when config is None
- Raise ValueError when config is not a dictionary type
- Update docstring to document Raises section for error conditions
- Prevent silent failures from malformed YAML configurations
2025-12-30 08:54:02 +08:00
yyhuni
2f4d6a2168 统一工具挂载为/usr/local/bin 2025-12-30 08:45:36 +08:00
yyhuni
c25cb9e06b fix:工具挂载 2025-12-30 08:39:17 +08:00
yyhuni
b14ab71c7f fix:auth frontend 2025-12-30 08:12:04 +08:00
github-actions[bot]
8b5060e2d3 chore: bump version to v1.2.2-dev 2025-12-29 17:08:05 +00:00
yyhuni
3c9335febf refactor: determine target branch by tag location instead of naming
- Check which branch contains the tag (main or dev)
- Update VERSION file on the source branch
- Only tags from main branch update 'latest' Docker tag
- More flexible and follows standard Git workflow
2025-12-29 23:34:05 +08:00
yyhuni
1b95e4f2c3 feat: update VERSION file for dev tags on dev branch
- Dev tags (v*-dev) now update VERSION file on dev branch
- Release tags (v* without suffix) update VERSION file on main branch
- Keeps main and dev branches independent
2025-12-29 23:30:17 +08:00
yyhuni
d20a600afc refactor: use settings.GIT_MIRROR instead of os.getenv in worker_views 2025-12-29 23:13:35 +08:00
yyhuni
c29b11fd37 feat: add GIT_MIRROR to worker config center
- Add gitMirror field to worker configuration API
- Container bootstrap reads gitMirror and sets GIT_MIRROR env var
- Remove redundant GIT_MIRROR injection from task_distributor
- All environment variables are managed through config center
2025-12-29 23:11:31 +08:00
yyhuni
6caf707072 refactor: replace Chinese comments with English in frontend components
- Replace all Chinese inline comments with English equivalents across 24 frontend component files
- Update JSDoc comments to use English for better code documentation
- Improve code readability and maintainability for international development team
- Standardize comment style across directories, endpoints, ip-addresses, subdomains, and websites components
- Ensure consistency with previous frontend refactoring efforts
2025-12-29 23:01:16 +08:00
yyhuni
2627b1fc40 refactor: replace Chinese comments with English across frontend components
- Replace Chinese comments with English in fingerprint components (ehole, goby, wappalyzer)
- Update comments in scan engine, history, and scheduled scan modules
- Translate comments in worker deployment and configuration dialogs
- Update comments in subdomain management and target components
- Translate comments in tools configuration and command modules
- Replace Chinese comments in vulnerability components
- Improve code maintainability and consistency with English documentation standards
- Update Docker build workflow cache configuration with image-specific scopes for better cache isolation
2025-12-29 22:14:12 +08:00
yyhuni
ec6712b9b4 fix: add null coalescing to prevent undefined values in i18n translations
- Add null coalescing operator (?? "") to all i18n translation parameters across components
- Fix scheduled scan deletion dialog to handle undefined scheduled scan name
- Fix nuclei page to pass locale parameter to formatDateTime function
- Fix organization detail view unlink target dialog to handle undefined target name
- Fix organization list deletion dialog to handle undefined organization name
- Fix organization targets detail view unlink dialog to handle undefined target name
- Fix engine edit dialog to handle undefined engine name
- Fix scan history list deletion and stop dialogs to handle undefined target names
- Fix worker list deletion dialog to handle undefined worker name
- Fix all targets detail view deletion dialog to handle undefined target name
- Fix custom tools and opensource tools lists to handle undefined tool names
- Fix vulnerabilities detail view to handle undefined vulnerability names
- Prevents runtime errors when translation parameters are undefined or null
2025-12-29 21:03:47 +08:00
yyhuni
9d5e4d5408 fix(scan/engine): handle undefined engine name in delete confirmation
- Add nullish coalescing operator to prevent undefined value in delete confirmation message
- Ensure engineToDelete?.name defaults to empty string when undefined
- Improve robustness of alert dialog description rendering
2025-12-29 20:54:00 +08:00
yyhuni
c5d5b24c8f 更新github action dev版本不更新version 2025-12-29 20:48:42 +08:00
yyhuni
671cb56b62 fix:nuclei模板加速同步,模板下载到宿主机同步更新 2025-12-29 20:43:49 +08:00
yyhuni
51025f69a8 fix:大陆加速修复 2025-12-29 20:15:25 +08:00
yyhuni
b2403b29c4 删除update.sh 2025-12-29 20:08:40 +08:00
yyhuni
18ef01a47b fix:cn加速 2025-12-29 20:03:14 +08:00
yyhuni
0bf8108fb3 fix:镜像加速 2025-12-29 19:51:33 +08:00
yyhuni
837ad19131 fix:镜像加速问题 2025-12-29 19:48:48 +08:00
yyhuni
d7de9a7129 fix:镜像加速问题 2025-12-29 19:39:59 +08:00
yyhuni
22b4e51b42 feat(xget): add Git URL acceleration support via Xget proxy
- Add xget_proxy utility module to convert Git repository URLs to Xget proxy format
- Support domain mapping for GitHub, GitLab, Gitea, and Codeberg repositories
- Integrate Xget proxy into Nuclei template repository cloning process
- Add XGET_MIRROR environment variable configuration in container bootstrap
- Export XGET_MIRROR setting to worker node configuration endpoint
- Add --mirror flag to install.sh for easy Xget acceleration setup
- Add configure_docker_mirror function to install.sh for Docker registry mirror configuration
- Enable Git clone acceleration for faster template repository downloads in air-gapped or bandwidth-limited environments
2025-12-29 19:32:05 +08:00
yyhuni
d03628ee45 feat(i18n): translate Chinese comments to English in scan history component
- Replace Chinese console error messages with English equivalents
- Translate all inline code comments from Chinese to English
- Update dialog and section comments for consistency
- Improve code readability and maintainability for international development team
2025-12-29 18:42:13 +08:00
yyhuni
0baabe0753 feat(i18n): internationalize frontend components with English translations
- Replace Chinese comments with English equivalents across auth, dashboard, and scan components
- Update UI text labels and descriptions from Chinese to English in bulk-add-urls-dialog
- Translate placeholder text and dialog titles in asset management components
- Update column headers and data table labels to English in organization and engine modules
- Standardize English documentation strings in auth-guard and auth-layout components
- Improve code maintainability and accessibility for international users
- Align with existing internationalization efforts across the frontend codebase
2025-12-29 18:39:25 +08:00
yyhuni
e1191d7abf 国际化前端ui 2025-12-29 18:10:05 +08:00
yyhuni
82a2e9a0e7 国际化前端 2025-12-29 18:09:57 +08:00
yyhuni
1ccd1bc338 更新gfPatterns 2025-12-28 20:26:32 +08:00
yyhuni
b4d42f5372 更新指纹管理搜索 2025-12-28 20:18:26 +08:00
yyhuni
2c66450756 统一ui 2025-12-28 20:10:46 +08:00
yyhuni
119d82dc89 更新ui 2025-12-28 20:06:17 +08:00
yyhuni
fba7f7c508 更新ui 2025-12-28 19:55:57 +08:00
yyhuni
99d384ce29 修复前端列宽 2025-12-28 16:37:35 +08:00
yyhuni
07f36718ab 重构前端 2025-12-28 16:27:01 +08:00
yyhuni
7e3f69c208 重构前端组件 2025-12-28 12:05:47 +08:00
yyhuni
5f90473c3c fix:ui 2025-12-28 08:48:25 +08:00
yyhuni
e2a815b96a 增加:goby wappalyzer指纹 2025-12-28 08:42:37 +08:00
yyhuni
f86a1a9d47 优化ui 2025-12-27 22:01:40 +08:00
yyhuni
d5945679aa 增加日志 2025-12-27 21:50:43 +08:00
yyhuni
51e2c51748 fix:目录创建挂载 2025-12-27 21:44:47 +08:00
yyhuni
e2cbf98dda fix:target name已去除的bug 2025-12-27 21:27:05 +08:00
yyhuni
cd72bdf7c3 指纹接入 2025-12-27 20:19:25 +08:00
yyhuni
35abcf7e39 加入黑名单逻辑 2025-12-27 20:12:01 +08:00
yyhuni
09f2d343a4 新增:重构导出逻辑代码,加入黑名单过滤 2025-12-27 20:11:50 +08:00
yyhuni
54d1f86bde fix:安装报错 2025-12-27 17:51:32 +08:00
yyhuni
a3997c9676 更新yaml 2025-12-27 12:52:49 +08:00
yyhuni
c90a55f85e 更新负载逻辑 2025-12-27 12:49:14 +08:00
yyhuni
2eab88b452 chore(install): Add banner display and update confirmation
- Add show_banner() function to display XingRin ASCII art logo
- Call show_banner() before header in install.sh initialization
- Add experimental feature warning in update.sh with user confirmation
- Prompt user to confirm before proceeding with update operation
- Suggest full reinstall via uninstall.sh and install.sh as alternative
- Improve user experience with visual feedback and safety checks
2025-12-27 12:41:04 +08:00
yyhuni
1baf0eb5e1 fix:指纹扫描命令 2025-12-27 12:29:50 +08:00
yyhuni
b61e73f7be fix:json输出 2025-12-27 12:14:35 +08:00
yyhuni
e896734dfc feat(scan-engine): Add fingerprint detection feature flag
- Add fingerprint_detect feature flag to engine configuration parser
- Enable fingerprint detection capability in scan engine features
- Integrate fingerprint detection into existing feature detection logic
2025-12-27 11:59:51 +08:00
yyhuni
cd83f52f35 新增指纹识别 2025-12-27 11:39:26 +08:00
yyhuni
3e29554c36 新增:指纹识别 2025-12-27 11:39:19 +08:00
yyhuni
18e02b536e 加入:指纹识别 2025-12-27 10:06:23 +08:00
yyhuni
4c1c6f70ab 更新指纹 2025-12-26 21:50:38 +08:00
yyhuni
a72e7675f5 更新ui 2025-12-26 21:40:56 +08:00
yyhuni
93c2163764 新增:ehole指纹的导入 2025-12-26 21:34:36 +08:00
yyhuni
de72c91561 更新ui 2025-12-25 18:31:09 +08:00
github-actions[bot]
3e6d060b75 chore: bump version to v1.1.14 2025-12-25 10:11:08 +00:00
yyhuni
766f045904 fix:ffuf并发问题 2025-12-25 18:02:25 +08:00
yyhuni
8acfe1cc33 调整日志级别 2025-12-25 17:44:31 +08:00
github-actions[bot]
7aec3eabb2 chore: bump version to v1.1.13 2025-12-25 08:29:39 +00:00
yyhuni
b1f11c36a4 fix:字典下载端口 2025-12-25 16:21:32 +08:00
yyhuni
d97fb5245a 修复:提示 2025-12-25 16:18:46 +08:00
github-actions[bot]
ddf9a1f5a4 chore: bump version to v1.1.12 2025-12-25 08:10:57 +00:00
yyhuni
47f9f96a4b 更新文档 2025-12-25 16:07:30 +08:00
yyhuni
6f43e73162 readme up 2025-12-25 16:06:01 +08:00
yyhuni
9b7d496f3e 更新:端口号为8083 2025-12-25 16:02:55 +08:00
github-actions[bot]
6390849d52 chore: bump version to v1.1.11 2025-12-25 03:58:05 +00:00
yyhuni
7a6d2054f6 更新:ui 2025-12-25 11:50:21 +08:00
yyhuni
73ebaab232 更新:ui 2025-12-25 11:31:25 +08:00
github-actions[bot]
11899b29c2 chore: bump version to v1.1.10 2025-12-25 03:20:57 +00:00
github-actions[bot]
877d2a56d1 chore: bump version to v1.1.9 2025-12-25 03:13:58 +00:00
yyhuni
dc1e94f038 更新:ui 2025-12-25 11:12:51 +08:00
yyhuni
9c3833d13d 更新:ui 2025-12-25 11:06:00 +08:00
github-actions[bot]
92f3b722ef chore: bump version to v1.1.8 2025-12-25 02:16:12 +00:00
yyhuni
9ef503c666 更新:ui 2025-12-25 10:12:06 +08:00
yyhuni
c3a43e94fa 修复:ui 2025-12-25 10:08:25 +08:00
github-actions[bot]
d6d94355fb chore: bump version to v1.1.7 2025-12-25 02:02:27 +00:00
yyhuni
bc638eabf4 更新:ui 2025-12-25 10:02:13 +08:00
yyhuni
5acaada7ab 新增:支持多字段搜索功能 2025-12-25 09:54:50 +08:00
github-actions[bot]
aaad3f29cf chore: bump version to v1.1.6 2025-12-24 12:19:12 +00:00
yyhuni
f13eb2d9b2 更新:ui风格 2025-12-24 20:10:12 +08:00
yyhuni
f1b3b60382 新增:EVA主题 2025-12-24 19:57:26 +08:00
yyhuni
e249056289 Update README.md 2025-12-24 19:14:22 +08:00
yyhuni
dba195b83a 更新readme 2025-12-24 17:28:08 +08:00
github-actions[bot]
9b494e6c67 chore: bump version to v1.1.5 2025-12-24 09:23:21 +00:00
yyhuni
2841157747 优化:字体显示 2025-12-24 17:14:45 +08:00
yyhuni
f6c1fef1a6 修复:仪表盘页面删除问题 2025-12-24 17:10:48 +08:00
yyhuni
6ec0adf9dd 优化:日志打印 2025-12-24 16:39:13 +08:00
yyhuni
22c6661567 更新:ui 2025-12-24 16:25:41 +08:00
github-actions[bot]
d9ed004e35 chore: bump version to v1.1.4 2025-12-24 08:23:12 +00:00
yyhuni
a0d9d1f29d 新增:批量添加资产 2025-12-24 16:15:33 +08:00
yyhuni
8aa9ed2a97 新增:新增功能,目标详细页面批量添加资产 2025-12-24 16:15:22 +08:00
yyhuni
8baf29d1c3 新增:子域名添加功能 2025-12-24 11:27:48 +08:00
yyhuni
248e48353a 更新:数据库字段为create at 2025-12-24 10:35:55 +08:00
yyhuni
0d210be50b 更新:subdomain的字段,discovered_at TO created_at 2025-12-24 10:19:01 +08:00
github-actions[bot]
f7c0d0b215 chore: bump version to v1.1.3 2025-12-24 02:11:23 +00:00
github-actions[bot]
d83428f27b chore: bump version to v1.1.2 2025-12-24 02:08:28 +00:00
yyhuni
45a09b8173 优化:增强数据库连接稳定性 2025-12-24 10:03:24 +08:00
yyhuni
11dfdee6fd 更新ui 2025-12-24 09:57:39 +08:00
yyhuni
e53a884d13 更新:ui 2025-12-24 09:54:48 +08:00
yyhuni
3b318c89e3 fix:主题ui 2025-12-24 09:46:51 +08:00
github-actions[bot]
e564bc116a chore: bump version to v1.1.1 2025-12-23 12:17:52 +00:00
450 changed files with 207825 additions and 23540 deletions

View File

@@ -106,33 +106,65 @@ jobs:
${{ steps.version.outputs.IS_RELEASE == 'true' && format('{0}/{1}:latest', env.IMAGE_PREFIX, matrix.image) || '' }}
build-args: |
IMAGE_TAG=${{ steps.version.outputs.VERSION }}
cache-from: type=gha
cache-to: type=gha,mode=max
cache-from: type=gha,scope=${{ matrix.image }}
cache-to: type=gha,mode=max,scope=${{ matrix.image }}
provenance: false
sbom: false
# 所有镜像构建成功后,更新 VERSION 文件
# 根据 tag 所在的分支更新对应分支的 VERSION 文件
update-version:
runs-on: ubuntu-latest
needs: build
if: startsWith(github.ref, 'refs/tags/v')
steps:
- name: Checkout
- name: Checkout repository
uses: actions/checkout@v4
with:
ref: main
fetch-depth: 0 # 获取完整历史,用于判断 tag 所在分支
token: ${{ secrets.GITHUB_TOKEN }}
- name: Determine source branch and version
id: branch
run: |
VERSION="${GITHUB_REF#refs/tags/}"
# 查找包含此 tag 的分支
BRANCHES=$(git branch -r --contains ${{ github.ref_name }})
echo "Branches containing tag: $BRANCHES"
# 判断 tag 来自哪个分支
if echo "$BRANCHES" | grep -q "origin/main"; then
TARGET_BRANCH="main"
UPDATE_LATEST="true"
elif echo "$BRANCHES" | grep -q "origin/dev"; then
TARGET_BRANCH="dev"
UPDATE_LATEST="false"
else
echo "Warning: Tag not found in main or dev branch, defaulting to main"
TARGET_BRANCH="main"
UPDATE_LATEST="false"
fi
echo "BRANCH=$TARGET_BRANCH" >> $GITHUB_OUTPUT
echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
echo "UPDATE_LATEST=$UPDATE_LATEST" >> $GITHUB_OUTPUT
echo "Will update VERSION on branch: $TARGET_BRANCH"
- name: Checkout target branch
run: |
git checkout ${{ steps.branch.outputs.BRANCH }}
- name: Update VERSION file
run: |
VERSION="${GITHUB_REF#refs/tags/}"
VERSION="${{ steps.branch.outputs.VERSION }}"
echo "$VERSION" > VERSION
echo "Updated VERSION to $VERSION"
echo "Updated VERSION to $VERSION on branch ${{ steps.branch.outputs.BRANCH }}"
- name: Commit and push
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
git add VERSION
git diff --staged --quiet || git commit -m "chore: bump version to ${GITHUB_REF#refs/tags/}"
git push
git diff --staged --quiet || git commit -m "chore: bump version to ${{ steps.branch.outputs.VERSION }}"
git push origin ${{ steps.branch.outputs.BRANCH }}

1
.gitignore vendored
View File

@@ -133,3 +133,4 @@ temp/
HGETALL
KEYS
vuln_scan/input_endpoints.txt
open-in-v0

View File

@@ -25,23 +25,16 @@
---
<p align="center">
<b>🌗 明暗模式切换</b>
<b>🎨 现代化 UI </b>
</p>
<p align="center">
<img src="docs/screenshots/light.png" alt="Light Mode" width="49%">
<img src="docs/screenshots/dark.png" alt="Dark Mode" width="49%">
</p>
<p align="center">
<b>🎨 多种 UI 主题</b>
</p>
<p align="center">
<img src="docs/screenshots/bubblegum.png" alt="Bubblegum" width="32%">
<img src="docs/screenshots/cosmic-night.png" alt="Cosmic Night" width="32%">
<img src="docs/screenshots/quantum-rose.png" alt="Quantum Rose" width="32%">
<img src="docs/screenshots/light.png" alt="Light Mode" width="24%">
<img src="docs/screenshots/bubblegum.png" alt="Bubblegum" width="24%">
<img src="docs/screenshots/cosmic-night.png" alt="Cosmic Night" width="24%">
<img src="docs/screenshots/quantum-rose.png" alt="Quantum Rose" width="24%">
</p>
## 📚 文档
@@ -184,11 +177,19 @@ cd xingrin
# 安装并启动(生产模式)
sudo ./install.sh
# 🇨🇳 中国大陆用户推荐使用镜像加速(第三方加速服务可能会失效,不保证长期可用)
sudo ./install.sh --mirror
```
> **💡 --mirror 参数说明**
> - 自动配置 Docker 镜像加速(国内镜像源)
> - 加速 Git 仓库克隆Nuclei 模板等)
> - 大幅提升安装速度,避免网络超时
### 访问服务
- **Web 界面**: `https://localhost`
- **Web 界面**: `https://ip:8083`
### 常用命令
@@ -204,16 +205,12 @@ sudo ./restart.sh
# 卸载
sudo ./uninstall.sh
# 更新
sudo ./update.sh
```
## 🤝 反馈与贡献
- 🐛 **如果发现 Bug** 可以点击右边链接进行提交 [Issue](https://github.com/yyhuni/xingrin/issues)
- 💡 **有新想法比如UI设计功能设计等** 欢迎点击右边链接进行提交建议 [Issue](https://github.com/yyhuni/xingrin/issues)
- 🔧 **想参与开发?** 关注我公众号与我个人联系
## 📧 联系
- 目前版本就我个人使用,可能会有很多边界问题

View File

@@ -1 +1 @@
v1.1.0
v1.2.2-dev

View File

@@ -9,7 +9,7 @@ class WebSiteDTO:
"""网站数据传输对象"""
target_id: int
url: str
host: str
host: str = ''
title: str = ''
status_code: Optional[int] = None
content_length: Optional[int] = None

View File

@@ -22,15 +22,15 @@ class Subdomain(models.Model):
help_text='所属的扫描目标(主关联字段,表示所属关系,不能为空)'
)
name = models.CharField(max_length=1000, help_text='子域名名称')
discovered_at = models.DateTimeField(auto_now_add=True, help_text='首次发现时间')
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
class Meta:
db_table = 'subdomain'
verbose_name = '子域名'
verbose_name_plural = '子域名'
ordering = ['-discovered_at']
ordering = ['-created_at']
indexes = [
models.Index(fields=['-discovered_at']),
models.Index(fields=['-created_at']),
models.Index(fields=['name', 'target']), # 复合索引,优化 get_by_names_and_target_id 批量查询
models.Index(fields=['target']), # 优化从target_id快速查找下面的子域名
models.Index(fields=['name']), # 优化从name快速查找子域名搜索场景
@@ -71,7 +71,7 @@ class Endpoint(models.Model):
default='',
help_text='重定向地址HTTP 3xx 响应头 Location'
)
discovered_at = models.DateTimeField(auto_now_add=True, help_text='发现时间')
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
title = models.CharField(
max_length=1000,
blank=True,
@@ -128,13 +128,14 @@ class Endpoint(models.Model):
db_table = 'endpoint'
verbose_name = '端点'
verbose_name_plural = '端点'
ordering = ['-discovered_at']
ordering = ['-created_at']
indexes = [
models.Index(fields=['-discovered_at']),
models.Index(fields=['-created_at']),
models.Index(fields=['target']), # 优化从target_id快速查找下面的端点主关联字段
models.Index(fields=['url']), # URL索引优化查询性能
models.Index(fields=['host']), # host索引优化根据主机名查询
models.Index(fields=['status_code']), # 状态码索引,优化筛选
models.Index(fields=['title']), # title索引优化智能过滤搜索
]
constraints = [
# 普通唯一约束url + target 组合唯一
@@ -172,7 +173,7 @@ class WebSite(models.Model):
default='',
help_text='重定向地址HTTP 3xx 响应头 Location'
)
discovered_at = models.DateTimeField(auto_now_add=True, help_text='发现时间')
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
title = models.CharField(
max_length=1000,
blank=True,
@@ -223,12 +224,14 @@ class WebSite(models.Model):
db_table = 'website'
verbose_name = '站点'
verbose_name_plural = '站点'
ordering = ['-discovered_at']
ordering = ['-created_at']
indexes = [
models.Index(fields=['-discovered_at']),
models.Index(fields=['-created_at']),
models.Index(fields=['url']), # URL索引优化查询性能
models.Index(fields=['host']), # host索引优化根据主机名查询
models.Index(fields=['target']), # 优化从target_id快速查找下面的站点
models.Index(fields=['title']), # title索引优化智能过滤搜索
models.Index(fields=['status_code']), # 状态码索引,优化智能过滤搜索
]
constraints = [
# 普通唯一约束url + target 组合唯一
@@ -293,15 +296,15 @@ class Directory(models.Model):
help_text='请求耗时(单位:纳秒)'
)
discovered_at = models.DateTimeField(auto_now_add=True, help_text='发现时间')
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
class Meta:
db_table = 'directory'
verbose_name = '目录'
verbose_name_plural = '目录'
ordering = ['-discovered_at']
ordering = ['-created_at']
indexes = [
models.Index(fields=['-discovered_at']),
models.Index(fields=['-created_at']),
models.Index(fields=['target']), # 优化从target_id快速查找下面的目录
models.Index(fields=['url']), # URL索引优化搜索和唯一约束
models.Index(fields=['status']), # 状态码索引,优化筛选
@@ -358,23 +361,23 @@ class HostPortMapping(models.Model):
)
# ==================== 时间字段 ====================
discovered_at = models.DateTimeField(
created_at = models.DateTimeField(
auto_now_add=True,
help_text='发现时间'
help_text='创建时间'
)
class Meta:
db_table = 'host_port_mapping'
verbose_name = '主机端口映射'
verbose_name_plural = '主机端口映射'
ordering = ['-discovered_at']
ordering = ['-created_at']
indexes = [
models.Index(fields=['target']), # 优化按目标查询
models.Index(fields=['host']), # 优化按主机名查询
models.Index(fields=['ip']), # 优化按IP查询
models.Index(fields=['port']), # 优化按端口查询
models.Index(fields=['host', 'ip']), # 优化组合查询
models.Index(fields=['-discovered_at']), # 优化时间排序
models.Index(fields=['-created_at']), # 优化时间排序
]
constraints = [
# 复合唯一约束target + host + ip + port 组合唯一
@@ -408,7 +411,7 @@ class Vulnerability(models.Model):
)
# ==================== 核心字段 ====================
url = models.TextField(help_text='漏洞所在的URL')
url = models.CharField(max_length=2000, help_text='漏洞所在的URL')
vuln_type = models.CharField(max_length=100, help_text='漏洞类型(如 xss, sqli')
severity = models.CharField(
max_length=20,
@@ -422,19 +425,20 @@ class Vulnerability(models.Model):
raw_output = models.JSONField(blank=True, default=dict, help_text='工具原始输出')
# ==================== 时间字段 ====================
discovered_at = models.DateTimeField(auto_now_add=True, help_text='首次发现时间')
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
class Meta:
db_table = 'vulnerability'
verbose_name = '漏洞'
verbose_name_plural = '漏洞'
ordering = ['-discovered_at']
ordering = ['-created_at']
indexes = [
models.Index(fields=['target']),
models.Index(fields=['vuln_type']),
models.Index(fields=['severity']),
models.Index(fields=['source']),
models.Index(fields=['-discovered_at']),
models.Index(fields=['url']), # url索引优化智能过滤搜索
models.Index(fields=['-created_at']),
]
def __str__(self):

View File

@@ -15,17 +15,17 @@ class SubdomainSnapshot(models.Model):
)
name = models.CharField(max_length=1000, help_text='子域名名称')
discovered_at = models.DateTimeField(auto_now_add=True, help_text='发现时间')
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
class Meta:
db_table = 'subdomain_snapshot'
verbose_name = '子域名快照'
verbose_name_plural = '子域名快照'
ordering = ['-discovered_at']
ordering = ['-created_at']
indexes = [
models.Index(fields=['scan']),
models.Index(fields=['name']),
models.Index(fields=['-discovered_at']),
models.Index(fields=['-created_at']),
]
constraints = [
# 唯一约束:同一次扫描中,同一个子域名只能记录一次
@@ -70,18 +70,19 @@ class WebsiteSnapshot(models.Model):
)
body_preview = models.TextField(blank=True, default='', help_text='响应体预览')
vhost = models.BooleanField(null=True, blank=True, help_text='虚拟主机标志')
discovered_at = models.DateTimeField(auto_now_add=True, help_text='发现时间')
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
class Meta:
db_table = 'website_snapshot'
verbose_name = '网站快照'
verbose_name_plural = '网站快照'
ordering = ['-discovered_at']
ordering = ['-created_at']
indexes = [
models.Index(fields=['scan']),
models.Index(fields=['url']),
models.Index(fields=['host']), # host索引优化根据主机名查询
models.Index(fields=['-discovered_at']),
models.Index(fields=['title']), # title索引优化标题搜索
models.Index(fields=['-created_at']),
]
constraints = [
# 唯一约束同一次扫描中同一个URL只能记录一次
@@ -118,18 +119,19 @@ class DirectorySnapshot(models.Model):
lines = models.IntegerField(null=True, blank=True, help_text='响应体行数(按换行符分割)')
content_type = models.CharField(max_length=200, blank=True, default='', help_text='响应头 Content-Type 值')
duration = models.BigIntegerField(null=True, blank=True, help_text='请求耗时(单位:纳秒)')
discovered_at = models.DateTimeField(auto_now_add=True, help_text='发现时间')
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
class Meta:
db_table = 'directory_snapshot'
verbose_name = '目录快照'
verbose_name_plural = '目录快照'
ordering = ['-discovered_at']
ordering = ['-created_at']
indexes = [
models.Index(fields=['scan']),
models.Index(fields=['url']),
models.Index(fields=['status']), # 状态码索引,优化筛选
models.Index(fields=['-discovered_at']),
models.Index(fields=['content_type']), # content_type索引优化内容类型搜索
models.Index(fields=['-created_at']),
]
constraints = [
# 唯一约束同一次扫描中同一个目录URL只能记录一次
@@ -183,16 +185,16 @@ class HostPortMappingSnapshot(models.Model):
)
# ==================== 时间字段 ====================
discovered_at = models.DateTimeField(
created_at = models.DateTimeField(
auto_now_add=True,
help_text='发现时间'
help_text='创建时间'
)
class Meta:
db_table = 'host_port_mapping_snapshot'
verbose_name = '主机端口映射快照'
verbose_name_plural = '主机端口映射快照'
ordering = ['-discovered_at']
ordering = ['-created_at']
indexes = [
models.Index(fields=['scan']), # 优化按扫描查询
models.Index(fields=['host']), # 优化按主机名查询
@@ -200,7 +202,7 @@ class HostPortMappingSnapshot(models.Model):
models.Index(fields=['port']), # 优化按端口查询
models.Index(fields=['host', 'ip']), # 优化组合查询
models.Index(fields=['scan', 'host']), # 优化扫描+主机查询
models.Index(fields=['-discovered_at']), # 优化时间排序
models.Index(fields=['-created_at']), # 优化时间排序
]
constraints = [
# 复合唯一约束同一次扫描中scan + host + ip + port 组合唯一
@@ -257,19 +259,21 @@ class EndpointSnapshot(models.Model):
default=list,
help_text='匹配的GF模式列表'
)
discovered_at = models.DateTimeField(auto_now_add=True, help_text='发现时间')
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
class Meta:
db_table = 'endpoint_snapshot'
verbose_name = '端点快照'
verbose_name_plural = '端点快照'
ordering = ['-discovered_at']
ordering = ['-created_at']
indexes = [
models.Index(fields=['scan']),
models.Index(fields=['url']),
models.Index(fields=['host']), # host索引优化根据主机名查询
models.Index(fields=['title']), # title索引优化标题搜索
models.Index(fields=['status_code']), # 状态码索引,优化筛选
models.Index(fields=['-discovered_at']),
models.Index(fields=['webserver']), # webserver索引优化服务器搜索
models.Index(fields=['-created_at']),
]
constraints = [
# 唯一约束同一次扫描中同一个URL只能记录一次
@@ -302,7 +306,7 @@ class VulnerabilitySnapshot(models.Model):
)
# ==================== 核心字段 ====================
url = models.TextField(help_text='漏洞所在的URL')
url = models.CharField(max_length=2000, help_text='漏洞所在的URL')
vuln_type = models.CharField(max_length=100, help_text='漏洞类型(如 xss, sqli')
severity = models.CharField(
max_length=20,
@@ -316,19 +320,20 @@ class VulnerabilitySnapshot(models.Model):
raw_output = models.JSONField(blank=True, default=dict, help_text='工具原始输出')
# ==================== 时间字段 ====================
discovered_at = models.DateTimeField(auto_now_add=True, help_text='发现时间')
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
class Meta:
db_table = 'vulnerability_snapshot'
verbose_name = '漏洞快照'
verbose_name_plural = '漏洞快照'
ordering = ['-discovered_at']
ordering = ['-created_at']
indexes = [
models.Index(fields=['scan']),
models.Index(fields=['url']), # url索引优化URL搜索
models.Index(fields=['vuln_type']),
models.Index(fields=['severity']),
models.Index(fields=['source']),
models.Index(fields=['-discovered_at']),
models.Index(fields=['-created_at']),
]
def __str__(self):

View File

@@ -74,13 +74,67 @@ class DjangoDirectoryRepository:
logger.error(f"批量 upsert Directory 失败: {e}")
raise
def bulk_create_ignore_conflicts(self, items: List[DirectoryDTO]) -> int:
"""
批量创建 Directory存在即跳过
与 bulk_upsert 不同,此方法不会更新已存在的记录。
适用于批量添加场景,只提供 URL没有其他字段数据。
注意:自动按模型唯一约束去重,保留最后一条记录。
Args:
items: Directory DTO 列表
Returns:
int: 处理的记录数
"""
if not items:
return 0
try:
# 自动按模型唯一约束去重
unique_items = deduplicate_for_bulk(items, Directory)
directories = [
Directory(
target_id=item.target_id,
url=item.url,
status=item.status,
content_length=item.content_length,
words=item.words,
lines=item.lines,
content_type=item.content_type or '',
duration=item.duration
)
for item in unique_items
]
with transaction.atomic():
Directory.objects.bulk_create(
directories,
ignore_conflicts=True,
batch_size=1000
)
logger.debug(f"批量创建 Directory 成功ignore_conflicts: {len(unique_items)}")
return len(unique_items)
except Exception as e:
logger.error(f"批量创建 Directory 失败: {e}")
raise
def count_by_target(self, target_id: int) -> int:
"""统计目标下的目录总数"""
return Directory.objects.filter(target_id=target_id).count()
def get_all(self):
"""获取所有目录"""
return Directory.objects.all().order_by('-discovered_at')
return Directory.objects.all().order_by('-created_at')
def get_by_target(self, target_id: int):
"""获取目标下的所有目录"""
return Directory.objects.filter(target_id=target_id).order_by('-discovered_at')
return Directory.objects.filter(target_id=target_id).order_by('-created_at')
def get_urls_for_export(self, target_id: int, batch_size: int = 1000) -> Iterator[str]:
"""流式导出目标下的所有目录 URL"""
@@ -118,7 +172,7 @@ class DjangoDirectoryRepository:
.filter(target_id=target_id)
.values(
'url', 'status', 'content_length', 'words',
'lines', 'content_type', 'duration', 'discovered_at'
'lines', 'content_type', 'duration', 'created_at'
)
.order_by('url')
)

View File

@@ -80,7 +80,7 @@ class DjangoEndpointRepository:
def get_all(self):
"""获取所有端点(全局查询)"""
return Endpoint.objects.all().order_by('-discovered_at')
return Endpoint.objects.all().order_by('-created_at')
def get_by_target(self, target_id: int):
"""
@@ -92,7 +92,7 @@ class DjangoEndpointRepository:
Returns:
QuerySet: 端点查询集
"""
return Endpoint.objects.filter(target_id=target_id).order_by('-discovered_at')
return Endpoint.objects.filter(target_id=target_id).order_by('-created_at')
def count_by_target(self, target_id: int) -> int:
"""
@@ -183,7 +183,7 @@ class DjangoEndpointRepository:
.values(
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'body_preview', 'vhost', 'matched_gf_patterns', 'discovered_at'
'body_preview', 'vhost', 'matched_gf_patterns', 'created_at'
)
.order_by('url')
)

View File

@@ -1,7 +1,9 @@
"""HostPortMapping Repository - Django ORM 实现"""
import logging
from typing import List, Iterator
from typing import List, Iterator, Dict, Optional
from django.db.models import QuerySet, Min
from apps.asset.models.asset_models import HostPortMapping
from apps.asset.dtos.asset import HostPortMappingDTO
@@ -13,7 +15,10 @@ logger = logging.getLogger(__name__)
@auto_ensure_db_connection
class DjangoHostPortMappingRepository:
"""HostPortMapping Repository - Django ORM 实现"""
"""HostPortMapping Repository - Django ORM 实现
职责:纯数据访问,不包含业务逻辑
"""
def bulk_create_ignore_conflicts(self, items: List[HostPortMappingDTO]) -> int:
"""
@@ -90,72 +95,20 @@ class DjangoHostPortMappingRepository:
for ip in queryset:
yield ip
def get_ip_aggregation_by_target(self, target_id: int, search: str = None):
from django.db.models import Min
def get_queryset_by_target(self, target_id: int) -> QuerySet:
"""获取目标下的 QuerySet"""
return HostPortMapping.objects.filter(target_id=target_id)
qs = HostPortMapping.objects.filter(target_id=target_id)
if search:
qs = qs.filter(ip__icontains=search)
def get_all_queryset(self) -> QuerySet:
"""获取所有记录的 QuerySet"""
return HostPortMapping.objects.all()
ip_aggregated = (
qs
.values('ip')
.annotate(discovered_at=Min('discovered_at'))
.order_by('-discovered_at')
)
results = []
for item in ip_aggregated:
ip = item['ip']
mappings = (
HostPortMapping.objects
.filter(target_id=target_id, ip=ip)
.values('host', 'port')
.distinct()
)
hosts = sorted({m['host'] for m in mappings})
ports = sorted({m['port'] for m in mappings})
results.append({
'ip': ip,
'hosts': hosts,
'ports': ports,
'discovered_at': item['discovered_at'],
})
return results
def get_all_ip_aggregation(self, search: str = None):
"""获取所有 IP 聚合数据(全局查询)"""
from django.db.models import Min
qs = HostPortMapping.objects.all()
if search:
qs = qs.filter(ip__icontains=search)
ip_aggregated = (
qs
.values('ip')
.annotate(discovered_at=Min('discovered_at'))
.order_by('-discovered_at')
)
results = []
for item in ip_aggregated:
ip = item['ip']
mappings = (
HostPortMapping.objects
.filter(ip=ip)
.values('host', 'port')
.distinct()
)
hosts = sorted({m['host'] for m in mappings})
ports = sorted({m['port'] for m in mappings})
results.append({
'ip': ip,
'hosts': hosts,
'ports': ports,
'discovered_at': item['discovered_at'],
})
return results
def get_queryset_by_ip(self, ip: str, target_id: Optional[int] = None) -> QuerySet:
"""获取指定 IP 的 QuerySet"""
qs = HostPortMapping.objects.filter(ip=ip)
if target_id:
qs = qs.filter(target_id=target_id)
return qs
def iter_raw_data_for_export(
self,
@@ -174,13 +127,13 @@ class DjangoHostPortMappingRepository:
'ip': '192.168.1.1',
'host': 'example.com',
'port': 80,
'discovered_at': datetime
'created_at': datetime
}
"""
qs = (
HostPortMapping.objects
.filter(target_id=target_id)
.values('ip', 'host', 'port', 'discovered_at')
.values('ip', 'host', 'port', 'created_at')
.order_by('ip', 'host', 'port')
)

View File

@@ -55,11 +55,11 @@ class DjangoSubdomainRepository:
def get_all(self):
"""获取所有子域名"""
return Subdomain.objects.all().order_by('-discovered_at')
return Subdomain.objects.all().order_by('-created_at')
def get_by_target(self, target_id: int):
"""获取目标下的所有子域名"""
return Subdomain.objects.filter(target_id=target_id).order_by('-discovered_at')
return Subdomain.objects.filter(target_id=target_id).order_by('-created_at')
def count_by_target(self, target_id: int) -> int:
"""统计目标下的域名数量"""
@@ -96,12 +96,12 @@ class DjangoSubdomainRepository:
batch_size: 每批数据量
Yields:
{'name': 'sub.example.com', 'discovered_at': datetime}
{'name': 'sub.example.com', 'created_at': datetime}
"""
qs = (
Subdomain.objects
.filter(target_id=target_id)
.values('name', 'discovered_at')
.values('name', 'created_at')
.order_by('name')
)

View File

@@ -96,11 +96,11 @@ class DjangoWebSiteRepository:
def get_all(self):
"""获取所有网站"""
return WebSite.objects.all().order_by('-discovered_at')
return WebSite.objects.all().order_by('-created_at')
def get_by_target(self, target_id: int):
"""获取目标下的所有网站"""
return WebSite.objects.filter(target_id=target_id).order_by('-discovered_at')
return WebSite.objects.filter(target_id=target_id).order_by('-created_at')
def count_by_target(self, target_id: int) -> int:
"""统计目标下的站点总数"""
@@ -177,7 +177,7 @@ class DjangoWebSiteRepository:
.values(
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'body_preview', 'vhost', 'discovered_at'
'body_preview', 'vhost', 'created_at'
)
.order_by('url')
)

View File

@@ -78,10 +78,10 @@ class DjangoDirectorySnapshotRepository:
raise
def get_by_scan(self, scan_id: int):
return DirectorySnapshot.objects.filter(scan_id=scan_id).order_by('-discovered_at')
return DirectorySnapshot.objects.filter(scan_id=scan_id).order_by('-created_at')
def get_all(self):
return DirectorySnapshot.objects.all().order_by('-discovered_at')
return DirectorySnapshot.objects.all().order_by('-created_at')
def iter_raw_data_for_export(
self,
@@ -103,7 +103,7 @@ class DjangoDirectorySnapshotRepository:
.filter(scan_id=scan_id)
.values(
'url', 'status', 'content_length', 'words',
'lines', 'content_type', 'duration', 'discovered_at'
'lines', 'content_type', 'duration', 'created_at'
)
.order_by('url')
)

View File

@@ -74,10 +74,10 @@ class DjangoEndpointSnapshotRepository:
raise
def get_by_scan(self, scan_id: int):
return EndpointSnapshot.objects.filter(scan_id=scan_id).order_by('-discovered_at')
return EndpointSnapshot.objects.filter(scan_id=scan_id).order_by('-created_at')
def get_all(self):
return EndpointSnapshot.objects.all().order_by('-discovered_at')
return EndpointSnapshot.objects.all().order_by('-created_at')
def iter_raw_data_for_export(
self,
@@ -100,7 +100,7 @@ class DjangoEndpointSnapshotRepository:
.values(
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'body_preview', 'vhost', 'matched_gf_patterns', 'discovered_at'
'body_preview', 'vhost', 'matched_gf_patterns', 'created_at'
)
.order_by('url')
)

View File

@@ -65,20 +65,28 @@ class DjangoHostPortMappingSnapshotRepository:
)
raise
def get_ip_aggregation_by_scan(self, scan_id: int, search: str = None):
def get_ip_aggregation_by_scan(self, scan_id: int, filter_query: str = None):
from django.db.models import Min
from apps.common.utils.filter_utils import apply_filters
qs = HostPortMappingSnapshot.objects.filter(scan_id=scan_id)
if search:
qs = qs.filter(ip__icontains=search)
# 应用智能过滤
if filter_query:
field_mapping = {
'ip': 'ip',
'port': 'port',
'host': 'host',
}
qs = apply_filters(qs, filter_query, field_mapping)
ip_aggregated = (
qs
.values('ip')
.annotate(
discovered_at=Min('discovered_at')
created_at=Min('created_at')
)
.order_by('-discovered_at')
.order_by('-created_at')
)
results = []
@@ -98,24 +106,32 @@ class DjangoHostPortMappingSnapshotRepository:
'ip': ip,
'hosts': hosts,
'ports': ports,
'discovered_at': item['discovered_at'],
'created_at': item['created_at'],
})
return results
def get_all_ip_aggregation(self, search: str = None):
def get_all_ip_aggregation(self, filter_query: str = None):
"""获取所有 IP 聚合数据"""
from django.db.models import Min
from apps.common.utils.filter_utils import apply_filters
qs = HostPortMappingSnapshot.objects.all()
if search:
qs = qs.filter(ip__icontains=search)
# 应用智能过滤
if filter_query:
field_mapping = {
'ip': 'ip',
'port': 'port',
'host': 'host',
}
qs = apply_filters(qs, filter_query, field_mapping)
ip_aggregated = (
qs
.values('ip')
.annotate(discovered_at=Min('discovered_at'))
.order_by('-discovered_at')
.annotate(created_at=Min('created_at'))
.order_by('-created_at')
)
results = []
@@ -133,7 +149,7 @@ class DjangoHostPortMappingSnapshotRepository:
'ip': ip,
'hosts': hosts,
'ports': ports,
'discovered_at': item['discovered_at'],
'created_at': item['created_at'],
})
return results
@@ -167,13 +183,13 @@ class DjangoHostPortMappingSnapshotRepository:
'ip': '192.168.1.1',
'host': 'example.com',
'port': 80,
'discovered_at': datetime
'created_at': datetime
}
"""
qs = (
HostPortMappingSnapshot.objects
.filter(scan_id=scan_id)
.values('ip', 'host', 'port', 'discovered_at')
.values('ip', 'host', 'port', 'created_at')
.order_by('ip', 'host', 'port')
)

View File

@@ -61,10 +61,10 @@ class DjangoSubdomainSnapshotRepository:
raise
def get_by_scan(self, scan_id: int):
return SubdomainSnapshot.objects.filter(scan_id=scan_id).order_by('-discovered_at')
return SubdomainSnapshot.objects.filter(scan_id=scan_id).order_by('-created_at')
def get_all(self):
return SubdomainSnapshot.objects.all().order_by('-discovered_at')
return SubdomainSnapshot.objects.all().order_by('-created_at')
def iter_raw_data_for_export(
self,
@@ -79,12 +79,12 @@ class DjangoSubdomainSnapshotRepository:
batch_size: 每批数据量
Yields:
{'name': 'sub.example.com', 'discovered_at': datetime}
{'name': 'sub.example.com', 'created_at': datetime}
"""
qs = (
SubdomainSnapshot.objects
.filter(scan_id=scan_id)
.values('name', 'discovered_at')
.values('name', 'created_at')
.order_by('name')
)

View File

@@ -66,7 +66,7 @@ class DjangoVulnerabilitySnapshotRepository:
def get_by_scan(self, scan_id: int):
"""按扫描任务获取漏洞快照 QuerySet。"""
return VulnerabilitySnapshot.objects.filter(scan_id=scan_id).order_by("-discovered_at")
return VulnerabilitySnapshot.objects.filter(scan_id=scan_id).order_by("-created_at")
def get_all(self):
return VulnerabilitySnapshot.objects.all().order_by('-discovered_at')
return VulnerabilitySnapshot.objects.all().order_by('-created_at')

View File

@@ -74,10 +74,10 @@ class DjangoWebsiteSnapshotRepository:
raise
def get_by_scan(self, scan_id: int):
return WebsiteSnapshot.objects.filter(scan_id=scan_id).order_by('-discovered_at')
return WebsiteSnapshot.objects.filter(scan_id=scan_id).order_by('-created_at')
def get_all(self):
return WebsiteSnapshot.objects.all().order_by('-discovered_at')
return WebsiteSnapshot.objects.all().order_by('-created_at')
def iter_raw_data_for_export(
self,
@@ -100,7 +100,7 @@ class DjangoWebsiteSnapshotRepository:
.values(
'url', 'host', 'location', 'title', 'status',
'content_length', 'content_type', 'web_server', 'tech',
'body_preview', 'vhost', 'discovered_at'
'body_preview', 'vhost', 'created_at'
)
.order_by('url')
)
@@ -119,5 +119,5 @@ class DjangoWebsiteSnapshotRepository:
'tech': row['tech'],
'body_preview': row['body_preview'],
'vhost': row['vhost'],
'discovered_at': row['discovered_at'],
'created_at': row['created_at'],
}

View File

@@ -26,9 +26,9 @@ class SubdomainSerializer(serializers.ModelSerializer):
class Meta:
model = Subdomain
fields = [
'id', 'name', 'discovered_at', 'target'
'id', 'name', 'created_at', 'target'
]
read_only_fields = ['id', 'discovered_at']
read_only_fields = ['id', 'created_at']
class SubdomainListSerializer(serializers.ModelSerializer):
@@ -41,9 +41,9 @@ class SubdomainListSerializer(serializers.ModelSerializer):
class Meta:
model = Subdomain
fields = [
'id', 'name', 'discovered_at'
'id', 'name', 'created_at'
]
read_only_fields = ['id', 'discovered_at']
read_only_fields = ['id', 'created_at']
# class IPAddressListSerializer(serializers.ModelSerializer):
@@ -87,7 +87,7 @@ class WebSiteSerializer(serializers.ModelSerializer):
'tech',
'vhost',
'subdomain',
'discovered_at',
'created_at',
]
read_only_fields = fields
@@ -107,7 +107,7 @@ class VulnerabilitySerializer(serializers.ModelSerializer):
'cvss_score',
'description',
'raw_output',
'discovered_at',
'created_at',
]
read_only_fields = fields
@@ -126,7 +126,7 @@ class VulnerabilitySnapshotSerializer(serializers.ModelSerializer):
'cvss_score',
'description',
'raw_output',
'discovered_at',
'created_at',
]
read_only_fields = fields
@@ -134,8 +134,8 @@ class VulnerabilitySnapshotSerializer(serializers.ModelSerializer):
class EndpointListSerializer(serializers.ModelSerializer):
"""端点列表序列化器(用于目标端点列表页)"""
# GF 匹配模式映射为前端使用的 tags 字段
tags = serializers.ListField(
# GF 匹配模式gf-patterns 工具匹配的敏感 URL 模式)
gfPatterns = serializers.ListField(
child=serializers.CharField(),
source='matched_gf_patterns',
read_only=True,
@@ -155,8 +155,8 @@ class EndpointListSerializer(serializers.ModelSerializer):
'body_preview',
'tech',
'vhost',
'tags',
'discovered_at',
'gfPatterns',
'created_at',
]
read_only_fields = fields
@@ -164,7 +164,7 @@ class EndpointListSerializer(serializers.ModelSerializer):
class DirectorySerializer(serializers.ModelSerializer):
"""目录序列化器"""
discovered_at = serializers.DateTimeField(read_only=True)
created_at = serializers.DateTimeField(read_only=True)
class Meta:
model = Directory
@@ -177,7 +177,7 @@ class DirectorySerializer(serializers.ModelSerializer):
'lines',
'content_type',
'duration',
'discovered_at',
'created_at',
]
read_only_fields = fields
@@ -190,12 +190,12 @@ class IPAddressAggregatedSerializer(serializers.Serializer):
- ip: IP 地址
- hosts: 该 IP 关联的所有主机名列表
- ports: 该 IP 关联的所有端口列表
- discovered_at: 首次发现时间
- created_at: 创建时间
"""
ip = serializers.IPAddressField(read_only=True)
hosts = serializers.ListField(child=serializers.CharField(), read_only=True)
ports = serializers.ListField(child=serializers.IntegerField(), read_only=True)
discovered_at = serializers.DateTimeField(read_only=True)
created_at = serializers.DateTimeField(read_only=True)
# ==================== 快照序列化器 ====================
@@ -205,7 +205,7 @@ class SubdomainSnapshotSerializer(serializers.ModelSerializer):
class Meta:
model = SubdomainSnapshot
fields = ['id', 'name', 'discovered_at']
fields = ['id', 'name', 'created_at']
read_only_fields = fields
@@ -231,7 +231,7 @@ class WebsiteSnapshotSerializer(serializers.ModelSerializer):
'tech',
'vhost',
'subdomain_name',
'discovered_at',
'created_at',
]
read_only_fields = fields
@@ -250,7 +250,7 @@ class DirectorySnapshotSerializer(serializers.ModelSerializer):
'lines',
'content_type',
'duration',
'discovered_at',
'created_at',
]
read_only_fields = fields
@@ -258,8 +258,8 @@ class DirectorySnapshotSerializer(serializers.ModelSerializer):
class EndpointSnapshotSerializer(serializers.ModelSerializer):
"""端点快照序列化器(用于扫描历史)"""
# GF 匹配模式映射为前端使用的 tags 字段
tags = serializers.ListField(
# GF 匹配模式gf-patterns 工具匹配的敏感 URL 模式)
gfPatterns = serializers.ListField(
child=serializers.CharField(),
source='matched_gf_patterns',
read_only=True,
@@ -280,7 +280,7 @@ class EndpointSnapshotSerializer(serializers.ModelSerializer):
'body_preview',
'tech',
'vhost',
'tags',
'discovered_at',
'gfPatterns',
'created_at',
]
read_only_fields = fields

View File

@@ -1,10 +1,12 @@
"""Directory Service - 目录业务逻辑层"""
import logging
from typing import List, Iterator
from typing import List, Iterator, Optional
from apps.asset.repositories import DjangoDirectoryRepository
from apps.asset.dtos import DirectoryDTO
from apps.common.validators import is_valid_url, is_url_match_target
from apps.common.utils.filter_utils import apply_filters
logger = logging.getLogger(__name__)
@@ -12,6 +14,12 @@ logger = logging.getLogger(__name__)
class DirectoryService:
"""目录业务逻辑层"""
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'url': 'url',
'status': 'status',
}
def __init__(self, repository=None):
"""初始化目录服务"""
self.repo = repository or DjangoDirectoryRepository()
@@ -37,13 +45,75 @@ class DirectoryService:
logger.error(f"批量 upsert 目录失败: {e}")
raise
def get_directories_by_target(self, target_id: int):
"""获取目标下的所有目录"""
return self.repo.get_by_target(target_id)
def bulk_create_urls(self, target_id: int, target_name: str, target_type: str, urls: List[str]) -> int:
"""
批量创建目录(仅 URL使用 ignore_conflicts
验证 URL 格式和匹配,过滤无效/不匹配 URL去重后批量创建。
已存在的记录会被跳过。
Args:
target_id: 目标 ID
target_name: 目标名称(用于匹配验证)
target_type: 目标类型 ('domain', 'ip', 'cidr')
urls: URL 列表
Returns:
int: 实际创建的记录数
"""
if not urls:
return 0
# 过滤有效 URL 并去重
valid_urls = []
seen = set()
for url in urls:
if not isinstance(url, str):
continue
url = url.strip()
if not url or url in seen:
continue
if not is_valid_url(url):
continue
# 匹配验证(前端已阻止不匹配的提交,后端作为双重保障)
if not is_url_match_target(url, target_name, target_type):
continue
seen.add(url)
valid_urls.append(url)
if not valid_urls:
return 0
# 获取创建前的数量
count_before = self.repo.count_by_target(target_id)
# 创建 DTO 列表并批量创建
directory_dtos = [
DirectoryDTO(url=url, target_id=target_id)
for url in valid_urls
]
self.repo.bulk_create_ignore_conflicts(directory_dtos)
# 获取创建后的数量
count_after = self.repo.count_by_target(target_id)
return count_after - count_before
def get_all(self):
def get_directories_by_target(self, target_id: int, filter_query: Optional[str] = None):
"""获取目标下的所有目录"""
queryset = self.repo.get_by_target(target_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_all(self, filter_query: Optional[str] = None):
"""获取所有目录"""
return self.repo.get_all()
queryset = self.repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def iter_directory_urls_by_target(self, target_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取目标下的所有目录 URL"""

View File

@@ -5,10 +5,12 @@ Endpoint 服务层
"""
import logging
from typing import List, Iterator
from typing import List, Iterator, Optional
from apps.asset.dtos.asset import EndpointDTO
from apps.asset.repositories.asset import DjangoEndpointRepository
from apps.common.validators import is_valid_url, is_url_match_target
from apps.common.utils.filter_utils import apply_filters
logger = logging.getLogger(__name__)
@@ -20,6 +22,14 @@ class EndpointService:
提供 EndpointURL/端点)相关的业务逻辑
"""
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'url': 'url',
'host': 'host',
'title': 'title',
'status': 'status_code',
}
def __init__(self):
"""初始化 Endpoint 服务"""
self.repo = DjangoEndpointRepository()
@@ -45,9 +55,68 @@ class EndpointService:
logger.error(f"批量 upsert 端点失败: {e}")
raise
def get_endpoints_by_target(self, target_id: int):
def bulk_create_urls(self, target_id: int, target_name: str, target_type: str, urls: List[str]) -> int:
"""
批量创建端点(仅 URL使用 ignore_conflicts
验证 URL 格式和匹配,过滤无效/不匹配 URL去重后批量创建。
已存在的记录会被跳过。
Args:
target_id: 目标 ID
target_name: 目标名称(用于匹配验证)
target_type: 目标类型 ('domain', 'ip', 'cidr')
urls: URL 列表
Returns:
int: 实际创建的记录数
"""
if not urls:
return 0
# 过滤有效 URL 并去重
valid_urls = []
seen = set()
for url in urls:
if not isinstance(url, str):
continue
url = url.strip()
if not url or url in seen:
continue
if not is_valid_url(url):
continue
# 匹配验证(前端已阻止不匹配的提交,后端作为双重保障)
if not is_url_match_target(url, target_name, target_type):
continue
seen.add(url)
valid_urls.append(url)
if not valid_urls:
return 0
# 获取创建前的数量
count_before = self.repo.count_by_target(target_id)
# 创建 DTO 列表并批量创建
endpoint_dtos = [
EndpointDTO(url=url, target_id=target_id)
for url in valid_urls
]
self.repo.bulk_create_ignore_conflicts(endpoint_dtos)
# 获取创建后的数量
count_after = self.repo.count_by_target(target_id)
return count_after - count_before
def get_endpoints_by_target(self, target_id: int, filter_query: Optional[str] = None):
"""获取目标下的所有端点"""
return self.repo.get_by_target(target_id)
queryset = self.repo.get_by_target(target_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def count_endpoints_by_target(self, target_id: int) -> int:
"""
@@ -61,9 +130,12 @@ class EndpointService:
"""
return self.repo.count_by_target(target_id)
def get_all(self):
def get_all(self, filter_query: Optional[str] = None):
"""获取所有端点(全局查询)"""
return self.repo.get_all()
queryset = self.repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def iter_endpoint_urls_by_target(self, target_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取目标下的所有端点 URL用于导出。"""

View File

@@ -1,16 +1,31 @@
"""HostPortMapping Service - 业务逻辑层"""
import logging
from typing import List, Iterator
from typing import List, Iterator, Optional, Dict
from django.db.models import Min
from apps.asset.repositories.asset import DjangoHostPortMappingRepository
from apps.asset.dtos.asset import HostPortMappingDTO
from apps.common.utils.filter_utils import apply_filters
logger = logging.getLogger(__name__)
class HostPortMappingService:
"""主机端口映射服务 - 负责主机端口映射数据的业务逻辑"""
"""主机端口映射服务 - 负责主机端口映射数据的业务逻辑
职责:
- 业务逻辑处理(过滤、聚合)
- 调用 Repository 进行数据访问
"""
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'ip': 'ip',
'port': 'port',
'host': 'host',
}
def __init__(self):
self.repo = DjangoHostPortMappingRepository()
@@ -49,12 +64,93 @@ class HostPortMappingService:
def iter_host_port_by_target(self, target_id: int, batch_size: int = 1000):
return self.repo.get_for_export(target_id=target_id, batch_size=batch_size)
def get_ip_aggregation_by_target(self, target_id: int, search: str = None):
return self.repo.get_ip_aggregation_by_target(target_id, search=search)
def get_ip_aggregation_by_target(
self,
target_id: int,
filter_query: Optional[str] = None
) -> List[Dict]:
"""获取目标下的 IP 聚合数据
Args:
target_id: 目标 ID
filter_query: 智能过滤语法字符串
Returns:
聚合后的 IP 数据列表
"""
# 从 Repository 获取基础 QuerySet
qs = self.repo.get_queryset_by_target(target_id)
# Service 层应用过滤逻辑
if filter_query:
qs = apply_filters(qs, filter_query, self.FILTER_FIELD_MAPPING)
# Service 层处理聚合逻辑
return self._aggregate_by_ip(qs, filter_query, target_id=target_id)
def get_all_ip_aggregation(self, search: str = None):
"""获取所有 IP 聚合数据(全局查询)"""
return self.repo.get_all_ip_aggregation(search=search)
def get_all_ip_aggregation(self, filter_query: Optional[str] = None) -> List[Dict]:
"""获取所有 IP 聚合数据(全局查询)
Args:
filter_query: 智能过滤语法字符串
Returns:
聚合后的 IP 数据列表
"""
# 从 Repository 获取基础 QuerySet
qs = self.repo.get_all_queryset()
# Service 层应用过滤逻辑
if filter_query:
qs = apply_filters(qs, filter_query, self.FILTER_FIELD_MAPPING)
# Service 层处理聚合逻辑
return self._aggregate_by_ip(qs, filter_query)
def _aggregate_by_ip(
self,
qs,
filter_query: Optional[str] = None,
target_id: Optional[int] = None
) -> List[Dict]:
"""按 IP 聚合数据
Args:
qs: 已过滤的 QuerySet
filter_query: 过滤条件(用于子查询)
target_id: 目标 ID用于子查询限定范围
Returns:
聚合后的数据列表
"""
ip_aggregated = (
qs
.values('ip')
.annotate(created_at=Min('created_at'))
.order_by('-created_at')
)
results = []
for item in ip_aggregated:
ip = item['ip']
# 获取该 IP 的所有 host 和 port也需要应用过滤条件
mappings_qs = self.repo.get_queryset_by_ip(ip, target_id=target_id)
if filter_query:
mappings_qs = apply_filters(mappings_qs, filter_query, self.FILTER_FIELD_MAPPING)
mappings = mappings_qs.values('host', 'port').distinct()
hosts = sorted({m['host'] for m in mappings})
ports = sorted({m['port'] for m in mappings})
results.append({
'ip': ip,
'hosts': hosts,
'ports': ports,
'created_at': item['created_at'],
})
return results
def iter_ips_by_target(self, target_id: int, batch_size: int = 1000) -> Iterator[str]:
"""流式获取目标下的所有唯一 IP 地址。"""
@@ -68,6 +164,6 @@ class HostPortMappingService:
target_id: 目标 ID
Yields:
原始数据字典 {ip, host, port, discovered_at}
原始数据字典 {ip, host, port, created_at}
"""
return self.repo.iter_raw_data_for_export(target_id=target_id)

View File

@@ -1,15 +1,33 @@
import logging
from typing import Tuple, List, Dict
from typing import List, Dict, Optional
from dataclasses import dataclass
from apps.asset.repositories import DjangoSubdomainRepository
from apps.asset.dtos import SubdomainDTO
from apps.common.validators import is_valid_domain
from apps.common.utils.filter_utils import apply_filters
logger = logging.getLogger(__name__)
@dataclass
class BulkCreateResult:
"""批量创建结果"""
created_count: int
skipped_count: int
invalid_count: int
mismatched_count: int
total_received: int
class SubdomainService:
"""子域名业务逻辑层"""
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'name': 'name',
}
def __init__(self, repository=None):
"""
初始化子域名服务
@@ -21,30 +39,50 @@ class SubdomainService:
# ==================== 查询操作 ====================
def get_all(self):
def get_all(self, filter_query: Optional[str] = None):
"""
获取所有子域名
Args:
filter_query: 智能过滤语法字符串
Returns:
QuerySet: 子域名查询集
"""
logger.debug("获取所有子域名")
return self.repo.get_all()
queryset = self.repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
# ==================== 创建操作 ====================
def bulk_create_ignore_conflicts(self, items: List[SubdomainDTO]) -> None:
def get_subdomains_by_target(self, target_id: int, filter_query: Optional[str] = None):
"""
批量创建子域名,忽略冲突
获取目标下的子域名
Args:
items: 子域名 DTO 列表
target_id: 目标 ID
filter_query: 智能过滤语法字符串
Note:
使用 ignore_conflicts 策略,重复记录会被跳过
Returns:
QuerySet: 子域名查询集
"""
logger.debug("批量创建子域名 - 数量: %d", len(items))
return self.repo.bulk_create_ignore_conflicts(items)
queryset = self.repo.get_by_target(target_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def count_subdomains_by_target(self, target_id: int) -> int:
"""
统计目标下的子域名数量
Args:
target_id: 目标 ID
Returns:
int: 子域名数量
"""
logger.debug("统计目标下子域名数量 - Target ID: %d", target_id)
return self.repo.count_by_target(target_id)
def get_by_names_and_target_id(self, names: set, target_id: int) -> dict:
"""
@@ -71,25 +109,8 @@ class SubdomainService:
List[str]: 子域名名称列表
"""
logger.debug("获取目标下所有子域名 - Target ID: %d", target_id)
# 通过仓储层统一访问数据库,内部已使用 iterator() 做流式查询
return list(self.repo.get_domains_for_export(target_id=target_id))
def get_subdomains_by_target(self, target_id: int):
return self.repo.get_by_target(target_id)
def count_subdomains_by_target(self, target_id: int) -> int:
"""
统计目标下的子域名数量
Args:
target_id: 目标 ID
Returns:
int: 子域名数量
"""
logger.debug("统计目标下子域名数量 - Target ID: %d", target_id)
return self.repo.count_by_target(target_id)
def iter_subdomain_names_by_target(self, target_id: int, chunk_size: int = 1000):
"""
流式获取目标下的所有子域名名称(内存优化)
@@ -102,7 +123,6 @@ class SubdomainService:
str: 子域名名称
"""
logger.debug("流式获取目标下所有子域名 - Target ID: %d, 批次大小: %d", target_id, chunk_size)
# 通过仓储层统一访问数据库,内部已使用 iterator() 做流式查询
return self.repo.get_domains_for_export(target_id=target_id, batch_size=chunk_size)
def iter_raw_data_for_csv_export(self, target_id: int):
@@ -113,9 +133,113 @@ class SubdomainService:
target_id: 目标 ID
Yields:
原始数据字典 {name, discovered_at}
原始数据字典 {name, created_at}
"""
return self.repo.iter_raw_data_for_export(target_id=target_id)
# ==================== 创建操作 ====================
__all__ = ['SubdomainService']
def bulk_create_ignore_conflicts(self, items: List[SubdomainDTO]) -> None:
"""
批量创建子域名,忽略冲突
Args:
items: 子域名 DTO 列表
Note:
使用 ignore_conflicts 策略,重复记录会被跳过
"""
logger.debug("批量创建子域名 - 数量: %d", len(items))
return self.repo.bulk_create_ignore_conflicts(items)
def bulk_create_subdomains(
self,
target_id: int,
target_name: str,
subdomains: List[str]
) -> BulkCreateResult:
"""
批量创建子域名(带验证)
Args:
target_id: 目标 ID
target_name: 目标域名(用于匹配验证)
subdomains: 子域名列表
Returns:
BulkCreateResult: 创建结果统计
"""
total_received = len(subdomains)
target_name = target_name.lower().strip()
def is_subdomain_match(subdomain: str) -> bool:
"""验证子域名是否匹配目标域名"""
if subdomain == target_name:
return True
if subdomain.endswith('.' + target_name):
return True
return False
# 过滤有效的子域名
valid_subdomains = []
invalid_count = 0
mismatched_count = 0
for subdomain in subdomains:
if not isinstance(subdomain, str) or not subdomain.strip():
continue
subdomain = subdomain.lower().strip()
# 验证格式
if not is_valid_domain(subdomain):
invalid_count += 1
continue
# 验证匹配
if not is_subdomain_match(subdomain):
mismatched_count += 1
continue
valid_subdomains.append(subdomain)
# 去重
unique_subdomains = list(set(valid_subdomains))
duplicate_count = len(valid_subdomains) - len(unique_subdomains)
if not unique_subdomains:
return BulkCreateResult(
created_count=0,
skipped_count=duplicate_count,
invalid_count=invalid_count,
mismatched_count=mismatched_count,
total_received=total_received,
)
# 获取创建前的数量
count_before = self.repo.count_by_target(target_id)
# 创建 DTO 列表并批量创建
subdomain_dtos = [
SubdomainDTO(name=name, target_id=target_id)
for name in unique_subdomains
]
self.repo.bulk_create_ignore_conflicts(subdomain_dtos)
# 获取创建后的数量
count_after = self.repo.count_by_target(target_id)
created_count = count_after - count_before
# 计算因数据库冲突跳过的数量
db_skipped = len(unique_subdomains) - created_count
return BulkCreateResult(
created_count=created_count,
skipped_count=duplicate_count + db_skipped,
invalid_count=invalid_count,
mismatched_count=mismatched_count,
total_received=total_received,
)
__all__ = ['SubdomainService', 'BulkCreateResult']

View File

@@ -1,12 +1,13 @@
"""Vulnerability Service - 漏洞资产业务逻辑层"""
import logging
from typing import List
from typing import List, Optional
from apps.asset.models import Vulnerability
from apps.asset.dtos.asset import VulnerabilityDTO
from apps.common.decorators import auto_ensure_db_connection
from apps.common.utils import deduplicate_for_bulk
from apps.common.utils.filter_utils import apply_filters
logger = logging.getLogger(__name__)
@@ -17,6 +18,14 @@ class VulnerabilityService:
当前提供基础的批量创建能力,使用 ignore_conflicts 依赖数据库唯一约束去重。
"""
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'type': 'vuln_type',
'severity': 'severity',
'source': 'source',
'url': 'url',
}
def bulk_create_ignore_conflicts(self, items: List[VulnerabilityDTO]) -> None:
"""批量创建漏洞资产记录,忽略冲突。
@@ -63,24 +72,34 @@ class VulnerabilityService:
# ==================== 查询方法 ====================
def get_all(self):
def get_all(self, filter_query: Optional[str] = None):
"""获取所有漏洞 QuerySet用于全局漏洞列表
Returns:
QuerySet[Vulnerability]: 所有漏洞,按发现时间倒序
"""
return Vulnerability.objects.all().order_by("-discovered_at")
Args:
filter_query: 智能过滤语法字符串
def get_vulnerabilities_by_target(self, target_id: int):
Returns:
QuerySet[Vulnerability]: 所有漏洞,按创建时间倒序
"""
queryset = Vulnerability.objects.all().order_by("-created_at")
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_vulnerabilities_by_target(self, target_id: int, filter_query: Optional[str] = None):
"""按目标获取漏洞 QuerySet用于分页
Args:
target_id: 目标 ID
filter_query: 智能过滤语法字符串
Returns:
QuerySet[Vulnerability]: 目标下的所有漏洞,按发现时间倒序
QuerySet[Vulnerability]: 目标下的所有漏洞,按创建时间倒序
"""
return Vulnerability.objects.filter(target_id=target_id).order_by("-discovered_at")
queryset = Vulnerability.objects.filter(target_id=target_id).order_by("-created_at")
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def count_by_target(self, target_id: int) -> int:
"""统计目标下的漏洞数量。"""

View File

@@ -1,10 +1,12 @@
"""WebSite Service - 网站业务逻辑层"""
import logging
from typing import List, Iterator
from typing import List, Iterator, Optional
from apps.asset.repositories import DjangoWebSiteRepository
from apps.asset.dtos import WebSiteDTO
from apps.common.validators import is_valid_url, is_url_match_target
from apps.common.utils.filter_utils import apply_filters
logger = logging.getLogger(__name__)
@@ -12,6 +14,14 @@ logger = logging.getLogger(__name__)
class WebSiteService:
"""网站业务逻辑层"""
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'url': 'url',
'host': 'host',
'title': 'title',
'status': 'status_code',
}
def __init__(self, repository=None):
"""初始化网站服务"""
self.repo = repository or DjangoWebSiteRepository()
@@ -37,13 +47,75 @@ class WebSiteService:
logger.error(f"批量 upsert 网站失败: {e}")
raise
def get_websites_by_target(self, target_id: int):
"""获取目标下的所有网站"""
return self.repo.get_by_target(target_id)
def bulk_create_urls(self, target_id: int, target_name: str, target_type: str, urls: List[str]) -> int:
"""
批量创建网站(仅 URL使用 ignore_conflicts
验证 URL 格式和匹配,过滤无效/不匹配 URL去重后批量创建。
已存在的记录会被跳过。
Args:
target_id: 目标 ID
target_name: 目标名称(用于匹配验证)
target_type: 目标类型 ('domain', 'ip', 'cidr')
urls: URL 列表
Returns:
int: 实际创建的记录数
"""
if not urls:
return 0
# 过滤有效 URL 并去重
valid_urls = []
seen = set()
for url in urls:
if not isinstance(url, str):
continue
url = url.strip()
if not url or url in seen:
continue
if not is_valid_url(url):
continue
# 匹配验证(前端已阻止不匹配的提交,后端作为双重保障)
if not is_url_match_target(url, target_name, target_type):
continue
seen.add(url)
valid_urls.append(url)
if not valid_urls:
return 0
# 获取创建前的数量
count_before = self.repo.count_by_target(target_id)
# 创建 DTO 列表并批量创建
website_dtos = [
WebSiteDTO(url=url, target_id=target_id)
for url in valid_urls
]
self.repo.bulk_create_ignore_conflicts(website_dtos)
# 获取创建后的数量
count_after = self.repo.count_by_target(target_id)
return count_after - count_before
def get_all(self):
def get_websites_by_target(self, target_id: int, filter_query: Optional[str] = None):
"""获取目标下的所有网站"""
queryset = self.repo.get_by_target(target_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_all(self, filter_query: Optional[str] = None):
"""获取所有网站"""
return self.repo.get_all()
queryset = self.repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_by_url(self, url: str, target_id: int) -> int:
"""根据 URL 和 target_id 查找网站 ID"""

View File

@@ -50,7 +50,7 @@ class DirectorySnapshotsService:
# 步骤 2: 转换为资产 DTO 并保存到资产表upsert
# - 新记录:插入资产表
# - 已存在的记录:更新字段(discovered_at 不更新,保留首次发现时间)
# - 已存在的记录:更新字段(created_at 不更新,保留创建时间)
logger.debug("步骤 2: 同步到资产表(通过 Service 层upsert")
asset_items = [item.to_asset_dto() for item in items]
@@ -67,12 +67,29 @@ class DirectorySnapshotsService:
)
raise
def get_by_scan(self, scan_id: int):
return self.snapshot_repo.get_by_scan(scan_id)
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'url': 'url',
'status': 'status',
'content_type': 'content_type',
}
def get_by_scan(self, scan_id: int, filter_query: str = None):
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_by_scan(scan_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_all(self):
def get_all(self, filter_query: str = None):
"""获取所有目录快照"""
return self.snapshot_repo.get_all()
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def iter_directory_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取某次扫描下的所有目录 URL。"""

View File

@@ -67,12 +67,32 @@ class EndpointSnapshotsService:
)
raise
def get_by_scan(self, scan_id: int):
return self.snapshot_repo.get_by_scan(scan_id)
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'url': 'url',
'host': 'host',
'title': 'title',
'status': 'status_code',
'webserver': 'webserver',
'tech': 'tech',
}
def get_by_scan(self, scan_id: int, filter_query: str = None):
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_by_scan(scan_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_all(self):
def get_all(self, filter_query: str = None):
"""获取所有端点快照"""
return self.snapshot_repo.get_all()
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def iter_endpoint_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取某次扫描下的所有端点 URL。"""

View File

@@ -69,12 +69,12 @@ class HostPortMappingSnapshotsService:
)
raise
def get_ip_aggregation_by_scan(self, scan_id: int, search: str = None):
return self.snapshot_repo.get_ip_aggregation_by_scan(scan_id, search=search)
def get_ip_aggregation_by_scan(self, scan_id: int, filter_query: str = None):
return self.snapshot_repo.get_ip_aggregation_by_scan(scan_id, filter_query=filter_query)
def get_all_ip_aggregation(self, search: str = None):
def get_all_ip_aggregation(self, filter_query: str = None):
"""获取所有 IP 聚合数据"""
return self.snapshot_repo.get_all_ip_aggregation(search=search)
return self.snapshot_repo.get_all_ip_aggregation(filter_query=filter_query)
def iter_ips_by_scan(self, scan_id: int, batch_size: int = 1000) -> Iterator[str]:
"""流式获取某次扫描下的所有唯一 IP 地址。"""
@@ -88,6 +88,6 @@ class HostPortMappingSnapshotsService:
scan_id: 扫描 ID
Yields:
原始数据字典 {ip, host, port, discovered_at}
原始数据字典 {ip, host, port, created_at}
"""
return self.snapshot_repo.iter_raw_data_for_export(scan_id=scan_id)

View File

@@ -66,12 +66,27 @@ class SubdomainSnapshotsService:
)
raise
def get_by_scan(self, scan_id: int):
return self.subdomain_snapshot_repo.get_by_scan(scan_id)
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'name': 'name',
}
def get_by_scan(self, scan_id: int, filter_query: str = None):
from apps.common.utils.filter_utils import apply_filters
queryset = self.subdomain_snapshot_repo.get_by_scan(scan_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_all(self):
def get_all(self, filter_query: str = None):
"""获取所有子域名快照"""
return self.subdomain_snapshot_repo.get_all()
from apps.common.utils.filter_utils import apply_filters
queryset = self.subdomain_snapshot_repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def iter_subdomain_names_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
queryset = self.subdomain_snapshot_repo.get_by_scan(scan_id)
@@ -86,6 +101,6 @@ class SubdomainSnapshotsService:
scan_id: 扫描 ID
Yields:
原始数据字典 {name, discovered_at}
原始数据字典 {name, created_at}
"""
return self.subdomain_snapshot_repo.iter_raw_data_for_export(scan_id=scan_id)

View File

@@ -66,13 +66,31 @@ class VulnerabilitySnapshotsService:
)
raise
def get_by_scan(self, scan_id: int):
"""按扫描任务获取所有漏洞快照。"""
return self.snapshot_repo.get_by_scan(scan_id)
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'type': 'vuln_type',
'url': 'url',
'severity': 'severity',
'source': 'source',
}
def get_all(self):
def get_by_scan(self, scan_id: int, filter_query: str = None):
"""按扫描任务获取所有漏洞快照。"""
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_by_scan(scan_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_all(self, filter_query: str = None):
"""获取所有漏洞快照"""
return self.snapshot_repo.get_all()
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def iter_vuln_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取某次扫描下的所有漏洞 URL。"""

View File

@@ -51,7 +51,7 @@ class WebsiteSnapshotsService:
# 步骤 2: 转换为资产 DTO 并保存到资产表upsert
# - 新记录:插入资产表
# - 已存在的记录:更新字段(discovered_at 不更新,保留首次发现时间)
# - 已存在的记录:更新字段(created_at 不更新,保留创建时间)
logger.debug("步骤 2: 同步到资产表(通过 Service 层upsert")
asset_items = [item.to_asset_dto() for item in items]
@@ -68,15 +68,35 @@ class WebsiteSnapshotsService:
)
raise
def get_by_scan(self, scan_id: int):
return self.snapshot_repo.get_by_scan(scan_id)
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'url': 'url',
'host': 'host',
'title': 'title',
'status': 'status',
'webserver': 'web_server',
'tech': 'tech',
}
def get_by_scan(self, scan_id: int, filter_query: str = None):
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_by_scan(scan_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_all(self):
def get_all(self, filter_query: str = None):
"""获取所有网站快照"""
return self.snapshot_repo.get_all()
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def iter_website_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取某次扫描下的所有站点 URL发现时间倒序)。"""
"""流式获取某次扫描下的所有站点 URL创建时间倒序)。"""
queryset = self.snapshot_repo.get_by_scan(scan_id)
for snapshot in queryset.iterator(chunk_size=chunk_size):
yield snapshot.url

View File

@@ -2,6 +2,8 @@ import logging
from rest_framework import viewsets, status, filters
from rest_framework.decorators import action
from rest_framework.response import Response
from apps.common.response_helpers import success_response, error_response
from apps.common.error_codes import ErrorCodes
from rest_framework.request import Request
from rest_framework.exceptions import NotFound, ValidationError as DRFValidationError
from django.core.exceptions import ValidationError, ObjectDoesNotExist
@@ -57,7 +59,7 @@ class AssetStatisticsViewSet(viewsets.ViewSet):
"""
try:
stats = self.service.get_statistics()
return Response({
return success_response(data={
'totalTargets': stats['total_targets'],
'totalSubdomains': stats['total_subdomains'],
'totalIps': stats['total_ips'],
@@ -80,9 +82,10 @@ class AssetStatisticsViewSet(viewsets.ViewSet):
})
except (DatabaseError, OperationalError) as e:
logger.exception("获取资产统计数据失败")
return Response(
{'error': '获取统计数据失败'},
status=status.HTTP_500_INTERNAL_SERVER_ERROR
return error_response(
code=ErrorCodes.SERVER_ERROR,
message='Failed to get statistics',
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
)
@action(detail=False, methods=['get'], url_path='history')
@@ -107,12 +110,13 @@ class AssetStatisticsViewSet(viewsets.ViewSet):
days = min(max(days, 1), 90) # 限制在 1-90 天
history = self.service.get_statistics_history(days=days)
return Response(history)
return success_response(data=history)
except (DatabaseError, OperationalError) as e:
logger.exception("获取统计历史数据失败")
return Response(
{'error': '获取历史数据失败'},
status=status.HTTP_500_INTERNAL_SERVER_ERROR
return error_response(
code=ErrorCodes.SERVER_ERROR,
message='Failed to get history data',
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
)
@@ -126,30 +130,118 @@ class SubdomainViewSet(viewsets.ModelViewSet):
支持两种访问方式:
1. 嵌套路由GET /api/targets/{target_pk}/subdomains/
2. 独立路由GET /api/subdomains/(全局查询)
支持智能过滤语法filter 参数):
- name="api" 子域名模糊匹配
- name=="api.example.com" 精确匹配
- 多条件空格分隔 AND 关系
"""
serializer_class = SubdomainListSerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['name']
ordering = ['-discovered_at']
filter_backends = [filters.OrderingFilter]
ordering = ['-created_at']
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.service = SubdomainService()
def get_queryset(self):
"""根据是否有 target_pk 参数决定查询范围"""
"""根据是否有 target_pk 参数决定查询范围,支持智能过滤"""
target_pk = self.kwargs.get('target_pk')
filter_query = self.request.query_params.get('filter', None)
if target_pk:
return self.service.get_subdomains_by_target(target_pk)
return self.service.get_all()
return self.service.get_subdomains_by_target(target_pk, filter_query=filter_query)
return self.service.get_all(filter_query=filter_query)
@action(detail=False, methods=['post'], url_path='bulk-create')
def bulk_create(self, request, **kwargs):
"""批量创建子域名
POST /api/targets/{target_pk}/subdomains/bulk-create/
请求体:
{
"subdomains": ["sub1.example.com", "sub2.example.com"]
}
响应:
{
"data": {
"createdCount": 10,
"skippedCount": 2,
"invalidCount": 1,
"mismatchedCount": 1,
"totalReceived": 14
}
}
"""
from apps.targets.models import Target
target_pk = self.kwargs.get('target_pk')
if not target_pk:
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Must create subdomains under a target',
status_code=status.HTTP_400_BAD_REQUEST
)
# 获取目标
try:
target = Target.objects.get(pk=target_pk)
except Target.DoesNotExist:
return error_response(
code=ErrorCodes.NOT_FOUND,
message='Target not found',
status_code=status.HTTP_404_NOT_FOUND
)
# 验证目标类型必须为域名
if target.type != Target.TargetType.DOMAIN:
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Only domain type targets support subdomain import',
status_code=status.HTTP_400_BAD_REQUEST
)
# 获取请求体中的子域名列表
subdomains = request.data.get('subdomains', [])
if not subdomains or not isinstance(subdomains, list):
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Request body cannot be empty or invalid format',
status_code=status.HTTP_400_BAD_REQUEST
)
# 调用 service 层处理
try:
result = self.service.bulk_create_subdomains(
target_id=int(target_pk),
target_name=target.name,
subdomains=subdomains
)
except Exception as e:
logger.exception("批量创建子域名失败")
return error_response(
code=ErrorCodes.SERVER_ERROR,
message='Server internal error',
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
)
return success_response(data={
'createdCount': result.created_count,
'skippedCount': result.skipped_count,
'invalidCount': result.invalid_count,
'mismatchedCount': result.mismatched_count,
'totalReceived': result.total_received,
})
@action(detail=False, methods=['get'], url_path='export')
def export(self, request, **kwargs):
"""导出子域名为 CSV 格式
CSV 列name, discovered_at
CSV 列name, created_at
"""
from apps.common.utils import generate_csv_rows, format_datetime
@@ -159,8 +251,8 @@ class SubdomainViewSet(viewsets.ModelViewSet):
data_iterator = self.service.iter_raw_data_for_csv_export(target_id=target_pk)
headers = ['name', 'discovered_at']
formatters = {'discovered_at': format_datetime}
headers = ['name', 'created_at']
formatters = {'created_at': format_datetime}
response = StreamingHttpResponse(
generate_csv_rows(data_iterator, headers, formatters),
@@ -176,30 +268,105 @@ class WebSiteViewSet(viewsets.ModelViewSet):
支持两种访问方式:
1. 嵌套路由GET /api/targets/{target_pk}/websites/
2. 独立路由GET /api/websites/(全局查询)
支持智能过滤语法filter 参数):
- url="api" URL 模糊匹配
- host="example" 主机名模糊匹配
- title="login" 标题模糊匹配
- status="200,301" 状态码多值匹配
- 多条件空格分隔 AND 关系
"""
serializer_class = WebSiteSerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['host']
ordering = ['-discovered_at']
filter_backends = [filters.OrderingFilter]
ordering = ['-created_at']
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.service = WebSiteService()
def get_queryset(self):
"""根据是否有 target_pk 参数决定查询范围"""
"""根据是否有 target_pk 参数决定查询范围,支持智能过滤"""
target_pk = self.kwargs.get('target_pk')
filter_query = self.request.query_params.get('filter', None)
if target_pk:
return self.service.get_websites_by_target(target_pk)
return self.service.get_all()
return self.service.get_websites_by_target(target_pk, filter_query=filter_query)
return self.service.get_all(filter_query=filter_query)
@action(detail=False, methods=['post'], url_path='bulk-create')
def bulk_create(self, request, **kwargs):
"""批量创建网站
POST /api/targets/{target_pk}/websites/bulk-create/
请求体:
{
"urls": ["https://example.com", "https://test.com"]
}
响应:
{
"data": {
"createdCount": 10
}
}
"""
from apps.targets.models import Target
target_pk = self.kwargs.get('target_pk')
if not target_pk:
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Must create websites under a target',
status_code=status.HTTP_400_BAD_REQUEST
)
# 获取目标
try:
target = Target.objects.get(pk=target_pk)
except Target.DoesNotExist:
return error_response(
code=ErrorCodes.NOT_FOUND,
message='Target not found',
status_code=status.HTTP_404_NOT_FOUND
)
# 获取请求体中的 URL 列表
urls = request.data.get('urls', [])
if not urls or not isinstance(urls, list):
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Request body cannot be empty or invalid format',
status_code=status.HTTP_400_BAD_REQUEST
)
# 调用 service 层处理
try:
created_count = self.service.bulk_create_urls(
target_id=int(target_pk),
target_name=target.name,
target_type=target.type,
urls=urls
)
except Exception as e:
logger.exception("批量创建网站失败")
return error_response(
code=ErrorCodes.SERVER_ERROR,
message='Server internal error',
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
)
return success_response(data={
'createdCount': created_count,
})
@action(detail=False, methods=['get'], url_path='export')
def export(self, request, **kwargs):
"""导出网站为 CSV 格式
CSV 列url, host, location, title, status_code, content_length, content_type, webserver, tech, body_preview, vhost, discovered_at
CSV 列url, host, location, title, status_code, content_length, content_type, webserver, tech, body_preview, vhost, created_at
"""
from apps.common.utils import generate_csv_rows, format_datetime, format_list_field
@@ -212,10 +379,10 @@ class WebSiteViewSet(viewsets.ModelViewSet):
headers = [
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'body_preview', 'vhost', 'discovered_at'
'body_preview', 'vhost', 'created_at'
]
formatters = {
'discovered_at': format_datetime,
'created_at': format_datetime,
'tech': lambda x: format_list_field(x, separator=','),
}
@@ -233,30 +400,103 @@ class DirectoryViewSet(viewsets.ModelViewSet):
支持两种访问方式:
1. 嵌套路由GET /api/targets/{target_pk}/directories/
2. 独立路由GET /api/directories/(全局查询)
支持智能过滤语法filter 参数):
- url="admin" URL 模糊匹配
- status="200,301" 状态码多值匹配
- 多条件空格分隔 AND 关系
"""
serializer_class = DirectorySerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['url']
ordering = ['-discovered_at']
filter_backends = [filters.OrderingFilter]
ordering = ['-created_at']
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.service = DirectoryService()
def get_queryset(self):
"""根据是否有 target_pk 参数决定查询范围"""
"""根据是否有 target_pk 参数决定查询范围,支持智能过滤"""
target_pk = self.kwargs.get('target_pk')
filter_query = self.request.query_params.get('filter', None)
if target_pk:
return self.service.get_directories_by_target(target_pk)
return self.service.get_all()
return self.service.get_directories_by_target(target_pk, filter_query=filter_query)
return self.service.get_all(filter_query=filter_query)
@action(detail=False, methods=['post'], url_path='bulk-create')
def bulk_create(self, request, **kwargs):
"""批量创建目录
POST /api/targets/{target_pk}/directories/bulk-create/
请求体:
{
"urls": ["https://example.com/admin", "https://example.com/api"]
}
响应:
{
"data": {
"createdCount": 10
}
}
"""
from apps.targets.models import Target
target_pk = self.kwargs.get('target_pk')
if not target_pk:
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Must create directories under a target',
status_code=status.HTTP_400_BAD_REQUEST
)
# 获取目标
try:
target = Target.objects.get(pk=target_pk)
except Target.DoesNotExist:
return error_response(
code=ErrorCodes.NOT_FOUND,
message='Target not found',
status_code=status.HTTP_404_NOT_FOUND
)
# 获取请求体中的 URL 列表
urls = request.data.get('urls', [])
if not urls or not isinstance(urls, list):
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Request body cannot be empty or invalid format',
status_code=status.HTTP_400_BAD_REQUEST
)
# 调用 service 层处理
try:
created_count = self.service.bulk_create_urls(
target_id=int(target_pk),
target_name=target.name,
target_type=target.type,
urls=urls
)
except Exception as e:
logger.exception("批量创建目录失败")
return error_response(
code=ErrorCodes.SERVER_ERROR,
message='Server internal error',
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
)
return success_response(data={
'createdCount': created_count,
})
@action(detail=False, methods=['get'], url_path='export')
def export(self, request, **kwargs):
"""导出目录为 CSV 格式
CSV 列url, status, content_length, words, lines, content_type, duration, discovered_at
CSV 列url, status, content_length, words, lines, content_type, duration, created_at
"""
from apps.common.utils import generate_csv_rows, format_datetime
@@ -268,10 +508,10 @@ class DirectoryViewSet(viewsets.ModelViewSet):
headers = [
'url', 'status', 'content_length', 'words',
'lines', 'content_type', 'duration', 'discovered_at'
'lines', 'content_type', 'duration', 'created_at'
]
formatters = {
'discovered_at': format_datetime,
'created_at': format_datetime,
}
response = StreamingHttpResponse(
@@ -288,30 +528,105 @@ class EndpointViewSet(viewsets.ModelViewSet):
支持两种访问方式:
1. 嵌套路由GET /api/targets/{target_pk}/endpoints/
2. 独立路由GET /api/endpoints/(全局查询)
支持智能过滤语法filter 参数):
- url="api" URL 模糊匹配
- host="example" 主机名模糊匹配
- title="login" 标题模糊匹配
- status="200,301" 状态码多值匹配
- 多条件空格分隔 AND 关系
"""
serializer_class = EndpointListSerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['host']
ordering = ['-discovered_at']
filter_backends = [filters.OrderingFilter]
ordering = ['-created_at']
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.service = EndpointService()
def get_queryset(self):
"""根据是否有 target_pk 参数决定查询范围"""
"""根据是否有 target_pk 参数决定查询范围,支持智能过滤"""
target_pk = self.kwargs.get('target_pk')
filter_query = self.request.query_params.get('filter', None)
if target_pk:
return self.service.get_endpoints_by_target(target_pk)
return self.service.get_all()
return self.service.get_endpoints_by_target(target_pk, filter_query=filter_query)
return self.service.get_all(filter_query=filter_query)
@action(detail=False, methods=['post'], url_path='bulk-create')
def bulk_create(self, request, **kwargs):
"""批量创建端点
POST /api/targets/{target_pk}/endpoints/bulk-create/
请求体:
{
"urls": ["https://example.com/api/v1", "https://example.com/api/v2"]
}
响应:
{
"data": {
"createdCount": 10
}
}
"""
from apps.targets.models import Target
target_pk = self.kwargs.get('target_pk')
if not target_pk:
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Must create endpoints under a target',
status_code=status.HTTP_400_BAD_REQUEST
)
# 获取目标
try:
target = Target.objects.get(pk=target_pk)
except Target.DoesNotExist:
return error_response(
code=ErrorCodes.NOT_FOUND,
message='Target not found',
status_code=status.HTTP_404_NOT_FOUND
)
# 获取请求体中的 URL 列表
urls = request.data.get('urls', [])
if not urls or not isinstance(urls, list):
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Request body cannot be empty or invalid format',
status_code=status.HTTP_400_BAD_REQUEST
)
# 调用 service 层处理
try:
created_count = self.service.bulk_create_urls(
target_id=int(target_pk),
target_name=target.name,
target_type=target.type,
urls=urls
)
except Exception as e:
logger.exception("批量创建端点失败")
return error_response(
code=ErrorCodes.SERVER_ERROR,
message='Server internal error',
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
)
return success_response(data={
'createdCount': created_count,
})
@action(detail=False, methods=['get'], url_path='export')
def export(self, request, **kwargs):
"""导出端点为 CSV 格式
CSV 列url, host, location, title, status_code, content_length, content_type, webserver, tech, body_preview, vhost, matched_gf_patterns, discovered_at
CSV 列url, host, location, title, status_code, content_length, content_type, webserver, tech, body_preview, vhost, matched_gf_patterns, created_at
"""
from apps.common.utils import generate_csv_rows, format_datetime, format_list_field
@@ -324,10 +639,10 @@ class EndpointViewSet(viewsets.ModelViewSet):
headers = [
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'body_preview', 'vhost', 'matched_gf_patterns', 'discovered_at'
'body_preview', 'vhost', 'matched_gf_patterns', 'created_at'
]
formatters = {
'discovered_at': format_datetime,
'created_at': format_datetime,
'tech': lambda x: format_list_field(x, separator=','),
'matched_gf_patterns': lambda x: format_list_field(x, separator=','),
}
@@ -349,29 +664,46 @@ class HostPortMappingViewSet(viewsets.ModelViewSet):
返回按 IP 聚合的数据,每个 IP 显示其关联的所有 hosts 和 ports
支持智能过滤语法filter 参数):
- ip="192.168" IP 模糊匹配
- port="80,443" 端口多值匹配
- host="api" 主机名模糊匹配
- 多条件空格分隔 AND 关系
注意:由于返回的是聚合数据(字典列表),不支持 DRF SearchFilter
"""
serializer_class = IPAddressAggregatedSerializer
pagination_class = BasePagination
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'ip': 'ip',
'port': 'port',
'host': 'host',
}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.service = HostPortMappingService()
def get_queryset(self):
"""根据是否有 target_pk 参数决定查询范围,返回按 IP 聚合的数据"""
"""根据是否有 target_pk 参数决定查询范围,返回按 IP 聚合的数据
支持智能过滤语法filter 参数)
"""
target_pk = self.kwargs.get('target_pk')
search = self.request.query_params.get('search', None)
filter_query = self.request.query_params.get('filter', None)
if target_pk:
return self.service.get_ip_aggregation_by_target(target_pk, search=search)
return self.service.get_all_ip_aggregation(search=search)
return self.service.get_ip_aggregation_by_target(target_pk, filter_query=filter_query)
return self.service.get_all_ip_aggregation(filter_query=filter_query)
@action(detail=False, methods=['get'], url_path='export')
def export(self, request, **kwargs):
"""导出 IP 地址为 CSV 格式
CSV 列ip, host, port, discovered_at
CSV 列ip, host, port, created_at
"""
from apps.common.utils import generate_csv_rows, format_datetime
@@ -383,9 +715,9 @@ class HostPortMappingViewSet(viewsets.ModelViewSet):
data_iterator = self.service.iter_raw_data_for_csv_export(target_id=target_pk)
# CSV 表头和格式化器
headers = ['ip', 'host', 'port', 'discovered_at']
headers = ['ip', 'host', 'port', 'created_at']
formatters = {
'discovered_at': format_datetime
'created_at': format_datetime
}
# 生成流式响应
@@ -404,37 +736,50 @@ class VulnerabilityViewSet(viewsets.ModelViewSet):
支持两种访问方式:
1. 嵌套路由GET /api/targets/{target_pk}/vulnerabilities/
2. 独立路由GET /api/vulnerabilities/(全局查询)
支持智能过滤语法filter 参数):
- type="xss" 漏洞类型模糊匹配
- severity="high" 严重程度匹配
- source="nuclei" 来源工具匹配
- url="api" URL 模糊匹配
- 多条件空格分隔 AND 关系
"""
serializer_class = VulnerabilitySerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['vuln_type']
ordering = ['-discovered_at']
filter_backends = [filters.OrderingFilter]
ordering = ['-created_at']
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.service = VulnerabilityService()
def get_queryset(self):
"""根据是否有 target_pk 参数决定查询范围"""
"""根据是否有 target_pk 参数决定查询范围,支持智能过滤"""
target_pk = self.kwargs.get('target_pk')
filter_query = self.request.query_params.get('filter', None)
if target_pk:
return self.service.get_vulnerabilities_by_target(target_pk)
return self.service.get_all()
return self.service.get_vulnerabilities_by_target(target_pk, filter_query=filter_query)
return self.service.get_all(filter_query=filter_query)
# ==================== 快照 ViewSetScan 嵌套路由) ====================
class SubdomainSnapshotViewSet(viewsets.ModelViewSet):
"""子域名快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/subdomains/"""
"""子域名快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/subdomains/
支持智能过滤语法filter 参数):
- name="api" 子域名模糊匹配
- name=="api.example.com" 精确匹配
- name!="test" 排除匹配
"""
serializer_class = SubdomainSnapshotSerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['name']
ordering_fields = ['name', 'discovered_at']
ordering = ['-discovered_at']
filter_backends = [filters.OrderingFilter]
ordering_fields = ['name', 'created_at']
ordering = ['-created_at']
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -442,15 +787,17 @@ class SubdomainSnapshotViewSet(viewsets.ModelViewSet):
def get_queryset(self):
scan_pk = self.kwargs.get('scan_pk')
filter_query = self.request.query_params.get('filter', None)
if scan_pk:
return self.service.get_by_scan(scan_pk)
return self.service.get_all()
return self.service.get_by_scan(scan_pk, filter_query=filter_query)
return self.service.get_all(filter_query=filter_query)
@action(detail=False, methods=['get'], url_path='export')
def export(self, request, **kwargs):
"""导出子域名快照为 CSV 格式
CSV 列name, discovered_at
CSV 列name, created_at
"""
from apps.common.utils import generate_csv_rows, format_datetime
@@ -460,8 +807,8 @@ class SubdomainSnapshotViewSet(viewsets.ModelViewSet):
data_iterator = self.service.iter_raw_data_for_csv_export(scan_id=scan_pk)
headers = ['name', 'discovered_at']
formatters = {'discovered_at': format_datetime}
headers = ['name', 'created_at']
formatters = {'created_at': format_datetime}
response = StreamingHttpResponse(
generate_csv_rows(data_iterator, headers, formatters),
@@ -472,13 +819,21 @@ class SubdomainSnapshotViewSet(viewsets.ModelViewSet):
class WebsiteSnapshotViewSet(viewsets.ModelViewSet):
"""网站快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/websites/"""
"""网站快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/websites/
支持智能过滤语法filter 参数):
- url="api" URL 模糊匹配
- host="example" 主机名模糊匹配
- title="login" 标题模糊匹配
- status="200" 状态码匹配
- webserver="nginx" 服务器类型匹配
- tech="php" 技术栈匹配
"""
serializer_class = WebsiteSnapshotSerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['host']
ordering = ['-discovered_at']
filter_backends = [filters.OrderingFilter]
ordering = ['-created_at']
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -486,15 +841,17 @@ class WebsiteSnapshotViewSet(viewsets.ModelViewSet):
def get_queryset(self):
scan_pk = self.kwargs.get('scan_pk')
filter_query = self.request.query_params.get('filter', None)
if scan_pk:
return self.service.get_by_scan(scan_pk)
return self.service.get_all()
return self.service.get_by_scan(scan_pk, filter_query=filter_query)
return self.service.get_all(filter_query=filter_query)
@action(detail=False, methods=['get'], url_path='export')
def export(self, request, **kwargs):
"""导出网站快照为 CSV 格式
CSV 列url, host, location, title, status_code, content_length, content_type, webserver, tech, body_preview, vhost, discovered_at
CSV 列url, host, location, title, status_code, content_length, content_type, webserver, tech, body_preview, vhost, created_at
"""
from apps.common.utils import generate_csv_rows, format_datetime, format_list_field
@@ -507,10 +864,10 @@ class WebsiteSnapshotViewSet(viewsets.ModelViewSet):
headers = [
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'body_preview', 'vhost', 'discovered_at'
'body_preview', 'vhost', 'created_at'
]
formatters = {
'discovered_at': format_datetime,
'created_at': format_datetime,
'tech': lambda x: format_list_field(x, separator=','),
}
@@ -523,13 +880,18 @@ class WebsiteSnapshotViewSet(viewsets.ModelViewSet):
class DirectorySnapshotViewSet(viewsets.ModelViewSet):
"""目录快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/directories/"""
"""目录快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/directories/
支持智能过滤语法filter 参数):
- url="admin" URL 模糊匹配
- status="200" 状态码匹配
- content_type="html" 内容类型匹配
"""
serializer_class = DirectorySnapshotSerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['url']
ordering = ['-discovered_at']
filter_backends = [filters.OrderingFilter]
ordering = ['-created_at']
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -537,15 +899,17 @@ class DirectorySnapshotViewSet(viewsets.ModelViewSet):
def get_queryset(self):
scan_pk = self.kwargs.get('scan_pk')
filter_query = self.request.query_params.get('filter', None)
if scan_pk:
return self.service.get_by_scan(scan_pk)
return self.service.get_all()
return self.service.get_by_scan(scan_pk, filter_query=filter_query)
return self.service.get_all(filter_query=filter_query)
@action(detail=False, methods=['get'], url_path='export')
def export(self, request, **kwargs):
"""导出目录快照为 CSV 格式
CSV 列url, status, content_length, words, lines, content_type, duration, discovered_at
CSV 列url, status, content_length, words, lines, content_type, duration, created_at
"""
from apps.common.utils import generate_csv_rows, format_datetime
@@ -557,10 +921,10 @@ class DirectorySnapshotViewSet(viewsets.ModelViewSet):
headers = [
'url', 'status', 'content_length', 'words',
'lines', 'content_type', 'duration', 'discovered_at'
'lines', 'content_type', 'duration', 'created_at'
]
formatters = {
'discovered_at': format_datetime,
'created_at': format_datetime,
}
response = StreamingHttpResponse(
@@ -572,13 +936,21 @@ class DirectorySnapshotViewSet(viewsets.ModelViewSet):
class EndpointSnapshotViewSet(viewsets.ModelViewSet):
"""端点快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/endpoints/"""
"""端点快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/endpoints/
支持智能过滤语法filter 参数):
- url="api" URL 模糊匹配
- host="example" 主机名模糊匹配
- title="login" 标题模糊匹配
- status="200" 状态码匹配
- webserver="nginx" 服务器类型匹配
- tech="php" 技术栈匹配
"""
serializer_class = EndpointSnapshotSerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['host']
ordering = ['-discovered_at']
filter_backends = [filters.OrderingFilter]
ordering = ['-created_at']
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -586,15 +958,17 @@ class EndpointSnapshotViewSet(viewsets.ModelViewSet):
def get_queryset(self):
scan_pk = self.kwargs.get('scan_pk')
filter_query = self.request.query_params.get('filter', None)
if scan_pk:
return self.service.get_by_scan(scan_pk)
return self.service.get_all()
return self.service.get_by_scan(scan_pk, filter_query=filter_query)
return self.service.get_all(filter_query=filter_query)
@action(detail=False, methods=['get'], url_path='export')
def export(self, request, **kwargs):
"""导出端点快照为 CSV 格式
CSV 列url, host, location, title, status_code, content_length, content_type, webserver, tech, body_preview, vhost, matched_gf_patterns, discovered_at
CSV 列url, host, location, title, status_code, content_length, content_type, webserver, tech, body_preview, vhost, matched_gf_patterns, created_at
"""
from apps.common.utils import generate_csv_rows, format_datetime, format_list_field
@@ -607,10 +981,10 @@ class EndpointSnapshotViewSet(viewsets.ModelViewSet):
headers = [
'url', 'host', 'location', 'title', 'status_code',
'content_length', 'content_type', 'webserver', 'tech',
'body_preview', 'vhost', 'matched_gf_patterns', 'discovered_at'
'body_preview', 'vhost', 'matched_gf_patterns', 'created_at'
]
formatters = {
'discovered_at': format_datetime,
'created_at': format_datetime,
'tech': lambda x: format_list_field(x, separator=','),
'matched_gf_patterns': lambda x: format_list_field(x, separator=','),
}
@@ -626,7 +1000,12 @@ class EndpointSnapshotViewSet(viewsets.ModelViewSet):
class HostPortMappingSnapshotViewSet(viewsets.ModelViewSet):
"""主机端口映射快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/ip-addresses/
注意:由于返回的是聚合数据(字典列表),不支持 DRF SearchFilter
支持智能过滤语法filter 参数):
- ip="192.168" IP 模糊匹配
- port="80" 端口匹配
- host="api" 主机名模糊匹配
注意:由于返回的是聚合数据(字典列表),过滤在 Service 层处理
"""
serializer_class = IPAddressAggregatedSerializer
@@ -638,16 +1017,17 @@ class HostPortMappingSnapshotViewSet(viewsets.ModelViewSet):
def get_queryset(self):
scan_pk = self.kwargs.get('scan_pk')
search = self.request.query_params.get('search', None)
filter_query = self.request.query_params.get('filter', None)
if scan_pk:
return self.service.get_ip_aggregation_by_scan(scan_pk, search=search)
return self.service.get_all_ip_aggregation(search=search)
return self.service.get_ip_aggregation_by_scan(scan_pk, filter_query=filter_query)
return self.service.get_all_ip_aggregation(filter_query=filter_query)
@action(detail=False, methods=['get'], url_path='export')
def export(self, request, **kwargs):
"""导出 IP 地址为 CSV 格式
CSV 列ip, host, port, discovered_at
CSV 列ip, host, port, created_at
"""
from apps.common.utils import generate_csv_rows, format_datetime
@@ -659,9 +1039,9 @@ class HostPortMappingSnapshotViewSet(viewsets.ModelViewSet):
data_iterator = self.service.iter_raw_data_for_csv_export(scan_id=scan_pk)
# CSV 表头和格式化器
headers = ['ip', 'host', 'port', 'discovered_at']
headers = ['ip', 'host', 'port', 'created_at']
formatters = {
'discovered_at': format_datetime
'created_at': format_datetime
}
# 生成流式响应
@@ -675,13 +1055,19 @@ class HostPortMappingSnapshotViewSet(viewsets.ModelViewSet):
class VulnerabilitySnapshotViewSet(viewsets.ModelViewSet):
"""漏洞快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/vulnerabilities/"""
"""漏洞快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/vulnerabilities/
支持智能过滤语法filter 参数):
- type="xss" 漏洞类型模糊匹配
- url="api" URL 模糊匹配
- severity="high" 严重程度匹配
- source="nuclei" 来源工具匹配
"""
serializer_class = VulnerabilitySnapshotSerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['vuln_type']
ordering = ['-discovered_at']
filter_backends = [filters.OrderingFilter]
ordering = ['-created_at']
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -689,6 +1075,8 @@ class VulnerabilitySnapshotViewSet(viewsets.ModelViewSet):
def get_queryset(self):
scan_pk = self.kwargs.get('scan_pk')
filter_query = self.request.query_params.get('filter', None)
if scan_pk:
return self.service.get_by_scan(scan_pk)
return self.service.get_all()
return self.service.get_by_scan(scan_pk, filter_query=filter_query)
return self.service.get_all(filter_query=filter_query)

View File

@@ -0,0 +1,31 @@
"""
标准化错误码定义
采用简化方案(参考 Stripe、GitHub 等大厂做法):
- 只定义 5-10 个通用错误码
- 未知错误使用通用错误码
- 错误码格式:大写字母和下划线组成
"""
class ErrorCodes:
"""标准化错误码
只定义通用错误码,其他错误使用通用消息。
这是 Stripe、GitHub 等大厂的标准做法。
错误码格式规范:
- 使用大写字母和下划线
- 简洁明了,易于理解
- 前端通过错误码映射到 i18n 键
"""
# 通用错误码8 个)
VALIDATION_ERROR = 'VALIDATION_ERROR' # 输入验证失败
NOT_FOUND = 'NOT_FOUND' # 资源未找到
PERMISSION_DENIED = 'PERMISSION_DENIED' # 权限不足
SERVER_ERROR = 'SERVER_ERROR' # 服务器内部错误
BAD_REQUEST = 'BAD_REQUEST' # 请求格式错误
CONFLICT = 'CONFLICT' # 资源冲突(如重复创建)
UNAUTHORIZED = 'UNAUTHORIZED' # 未认证
RATE_LIMITED = 'RATE_LIMITED' # 请求过于频繁

View File

@@ -16,6 +16,7 @@ def setup_django_for_prefect():
1. 添加项目根目录到 Python 路径
2. 设置 DJANGO_SETTINGS_MODULE 环境变量
3. 调用 django.setup() 初始化 Django
4. 关闭旧的数据库连接,确保使用新连接
使用方式:
from apps.common.prefect_django_setup import setup_django_for_prefect
@@ -36,6 +37,25 @@ def setup_django_for_prefect():
# 初始化 Django
import django
django.setup()
# 关闭所有旧的数据库连接,确保 Worker 进程使用新连接
# 解决 "server closed the connection unexpectedly" 问题
from django.db import connections
connections.close_all()
def close_old_db_connections():
"""
关闭旧的数据库连接
在长时间运行的任务中调用此函数,可以确保使用有效的数据库连接。
适用于:
- Flow 开始前
- Task 开始前
- 长时间空闲后恢复操作前
"""
from django.db import connections
connections.close_all()
# 自动执行初始化(导入即生效)

View File

@@ -0,0 +1,88 @@
"""
标准化 API 响应辅助函数
遵循行业标准RFC 9457 Problem Details和大厂实践Google、Stripe、GitHub
- 成功响应只包含数据,不包含 message 字段
- 错误响应使用机器可读的错误码,前端映射到 i18n 消息
"""
from typing import Any, Dict, List, Optional, Union
from rest_framework import status
from rest_framework.response import Response
def success_response(
data: Optional[Union[Dict[str, Any], List[Any]]] = None,
status_code: int = status.HTTP_200_OK
) -> Response:
"""
标准化成功响应
直接返回数据,不做包装,符合 Stripe/GitHub 等大厂标准。
Args:
data: 响应数据dict 或 list
status_code: HTTP 状态码,默认 200
Returns:
Response: DRF Response 对象
Examples:
# 单个资源
>>> success_response(data={'id': 1, 'name': 'Test'})
{'id': 1, 'name': 'Test'}
# 操作结果
>>> success_response(data={'count': 3, 'scans': [...]})
{'count': 3, 'scans': [...]}
# 创建资源
>>> success_response(data={'id': 1}, status_code=201)
"""
# 注意:不能使用 data or {},因为空列表 [] 会被转换为 {}
if data is None:
data = {}
return Response(data, status=status_code)
def error_response(
code: str,
message: Optional[str] = None,
details: Optional[List[Dict[str, Any]]] = None,
status_code: int = status.HTTP_400_BAD_REQUEST
) -> Response:
"""
标准化错误响应
Args:
code: 错误码(如 'VALIDATION_ERROR', 'NOT_FOUND'
格式:大写字母和下划线组成
message: 开发者调试信息(非用户显示)
details: 详细错误信息(如字段级验证错误)
status_code: HTTP 状态码,默认 400
Returns:
Response: DRF Response 对象
Examples:
# 简单错误
>>> error_response(code='NOT_FOUND', status_code=404)
{'error': {'code': 'NOT_FOUND'}}
# 带调试信息
>>> error_response(
... code='VALIDATION_ERROR',
... message='Invalid input data',
... details=[{'field': 'name', 'message': 'Required'}]
... )
{'error': {'code': 'VALIDATION_ERROR', 'message': '...', 'details': [...]}}
"""
error_body: Dict[str, Any] = {'code': code}
if message:
error_body['message'] = message
if details:
error_body['details'] = details
return Response({'error': error_body}, status=status_code)

View File

@@ -3,8 +3,13 @@
提供系统级别的公共服务,包括:
- SystemLogService: 系统日志读取服务
注意FilterService 已移至 apps.common.utils.filter_utils
推荐使用: from apps.common.utils.filter_utils import apply_filters
"""
from .system_log_service import SystemLogService
__all__ = ['SystemLogService']
__all__ = [
'SystemLogService',
]

View File

@@ -4,15 +4,28 @@
提供系统日志的读取功能,支持:
- 从日志目录读取日志文件
- 限制返回行数,防止内存溢出
- 列出可用的日志文件
"""
import fnmatch
import logging
import os
import subprocess
from datetime import datetime, timezone
from typing import TypedDict
logger = logging.getLogger(__name__)
class LogFileInfo(TypedDict):
"""日志文件信息"""
filename: str
category: str # 'system' | 'error' | 'performance' | 'container'
size: int
modifiedAt: str # ISO 8601 格式
class SystemLogService:
"""
系统日志服务类
@@ -20,23 +33,131 @@ class SystemLogService:
负责读取系统日志文件,支持从容器内路径或宿主机挂载路径读取日志。
"""
# 日志文件分类规则
CATEGORY_RULES = [
('xingrin.log', 'system'),
('xingrin_error.log', 'error'),
('performance.log', 'performance'),
('container_*.log', 'container'),
]
def __init__(self):
# 日志文件路径(容器内路径,通过 volume 挂载到宿主机 /opt/xingrin/logs
self.log_file = "/app/backend/logs/xingrin.log"
self.default_lines = 200 # 默认返回行数
self.max_lines = 10000 # 最大返回行数限制
self.timeout_seconds = 3 # tail 命令超时时间
# 日志目录路径
self.log_dir = "/opt/xingrin/logs"
self.default_file = "xingrin.log" # 默认日志文件
self.default_lines = 200 # 默认返回行数
self.max_lines = 10000 # 最大返回行数限制
self.timeout_seconds = 3 # tail 命令超时时间
def get_logs_content(self, lines: int | None = None) -> str:
def _categorize_file(self, filename: str) -> str | None:
"""
根据文件名判断日志分类
Returns:
分类名称,如果不是日志文件则返回 None
"""
for pattern, category in self.CATEGORY_RULES:
if fnmatch.fnmatch(filename, pattern):
return category
return None
def _validate_filename(self, filename: str) -> bool:
"""
验证文件名是否合法(防止路径遍历攻击)
Args:
filename: 要验证的文件名
Returns:
bool: 文件名是否合法
"""
# 不允许包含路径分隔符
if '/' in filename or '\\' in filename:
return False
# 不允许 .. 路径遍历
if '..' in filename:
return False
# 必须是已知的日志文件类型
return self._categorize_file(filename) is not None
def get_log_files(self) -> list[LogFileInfo]:
"""
获取所有可用的日志文件列表
Returns:
日志文件信息列表,按分类和文件名排序
"""
files: list[LogFileInfo] = []
if not os.path.isdir(self.log_dir):
logger.warning("日志目录不存在: %s", self.log_dir)
return files
for filename in os.listdir(self.log_dir):
filepath = os.path.join(self.log_dir, filename)
# 只处理文件,跳过目录
if not os.path.isfile(filepath):
continue
# 判断分类
category = self._categorize_file(filename)
if category is None:
continue
# 获取文件信息
try:
stat = os.stat(filepath)
modified_at = datetime.fromtimestamp(
stat.st_mtime, tz=timezone.utc
).isoformat()
files.append({
'filename': filename,
'category': category,
'size': stat.st_size,
'modifiedAt': modified_at,
})
except OSError as e:
logger.warning("获取文件信息失败 %s: %s", filepath, e)
continue
# 排序按分类优先级system > error > performance > container然后按文件名
category_order = {'system': 0, 'error': 1, 'performance': 2, 'container': 3}
files.sort(key=lambda f: (category_order.get(f['category'], 99), f['filename']))
return files
def get_logs_content(self, filename: str | None = None, lines: int | None = None) -> str:
"""
获取系统日志内容
Args:
filename: 日志文件名,默认为 xingrin.log
lines: 返回的日志行数,默认 200 行,最大 10000 行
Returns:
str: 日志内容,每行以换行符分隔,保持原始顺序
Raises:
ValueError: 文件名不合法
FileNotFoundError: 日志文件不存在
"""
# 文件名处理
if filename is None:
filename = self.default_file
# 验证文件名
if not self._validate_filename(filename):
raise ValueError(f"无效的文件名: {filename}")
# 构建完整路径
log_file = os.path.join(self.log_dir, filename)
# 检查文件是否存在
if not os.path.isfile(log_file):
raise FileNotFoundError(f"日志文件不存在: {filename}")
# 参数校验和默认值处理
if lines is None:
lines = self.default_lines
@@ -48,7 +169,7 @@ class SystemLogService:
lines = self.max_lines
# 使用 tail 命令读取日志文件末尾内容
cmd = ["tail", "-n", str(lines), self.log_file]
cmd = ["tail", "-n", str(lines), log_file]
result = subprocess.run(
cmd,

View File

@@ -7,7 +7,7 @@
"""
from django.urls import path
from .views import LoginView, LogoutView, MeView, ChangePasswordView, SystemLogsView
from .views import LoginView, LogoutView, MeView, ChangePasswordView, SystemLogsView, SystemLogFilesView
urlpatterns = [
# 认证相关
@@ -18,4 +18,5 @@ urlpatterns = [
# 系统管理
path('system/logs/', SystemLogsView.as_view(), name='system-logs'),
path('system/logs/files/', SystemLogFilesView.as_view(), name='system-log-files'),
]

View File

@@ -0,0 +1,281 @@
"""智能过滤工具 - 通用查询语法解析和 Django ORM 查询构建
支持的语法:
- field="value" 模糊匹配(包含)
- field=="value" 精确匹配
- field!="value" 不等于
逻辑运算符:
- AND: && 或 and 或 空格(默认)
- OR: || 或 or
示例:
type="xss" || type="sqli" # OR
type="xss" or type="sqli" # OR等价
severity="high" && source="nuclei" # AND
severity="high" source="nuclei" # AND空格默认为 AND
severity="high" and source="nuclei" # AND等价
使用示例:
from apps.common.utils.filter_utils import apply_filters
field_mapping = {'ip': 'ip', 'port': 'port', 'host': 'host'}
queryset = apply_filters(queryset, 'ip="192" || port="80"', field_mapping)
"""
import re
import logging
from dataclasses import dataclass
from typing import List, Dict, Optional, Union
from enum import Enum
from django.db.models import QuerySet, Q
logger = logging.getLogger(__name__)
class LogicalOp(Enum):
"""逻辑运算符"""
AND = 'AND'
OR = 'OR'
@dataclass
class ParsedFilter:
"""解析后的过滤条件"""
field: str # 字段名
operator: str # 操作符: '=', '==', '!='
value: str # 原始值
@dataclass
class FilterGroup:
"""过滤条件组(带逻辑运算符)"""
filter: ParsedFilter
logical_op: LogicalOp # 与前一个条件的逻辑关系
class QueryParser:
"""查询语法解析器
支持 ||/or (OR) 和 &&/and/空格 (AND) 逻辑运算符
"""
# 正则匹配: field="value", field=="value", field!="value"
FILTER_PATTERN = re.compile(r'(\w+)(==|!=|=)"([^"]*)"')
# 逻辑运算符模式(带空格)
OR_PATTERN = re.compile(r'\s*(\|\||(?<![a-zA-Z])or(?![a-zA-Z]))\s*', re.IGNORECASE)
AND_PATTERN = re.compile(r'\s*(&&|(?<![a-zA-Z])and(?![a-zA-Z]))\s*', re.IGNORECASE)
@classmethod
def parse(cls, query_string: str) -> List[FilterGroup]:
"""解析查询语法字符串
Args:
query_string: 查询语法字符串
Returns:
解析后的过滤条件组列表
Examples:
>>> QueryParser.parse('type="xss" || type="sqli"')
[FilterGroup(filter=..., logical_op=AND), # 第一个默认 AND
FilterGroup(filter=..., logical_op=OR)]
"""
if not query_string or not query_string.strip():
return []
# 标准化逻辑运算符
# 先处理 || 和 or -> __OR__
normalized = cls.OR_PATTERN.sub(' __OR__ ', query_string)
# 再处理 && 和 and -> __AND__
normalized = cls.AND_PATTERN.sub(' __AND__ ', normalized)
# 分词:按空格分割,保留逻辑运算符标记
tokens = normalized.split()
groups = []
pending_op = LogicalOp.AND # 默认 AND
for token in tokens:
if token == '__OR__':
pending_op = LogicalOp.OR
elif token == '__AND__':
pending_op = LogicalOp.AND
else:
# 尝试解析为过滤条件
match = cls.FILTER_PATTERN.match(token)
if match:
field, operator, value = match.groups()
groups.append(FilterGroup(
filter=ParsedFilter(
field=field.lower(),
operator=operator,
value=value
),
logical_op=pending_op if groups else LogicalOp.AND # 第一个条件默认 AND
))
pending_op = LogicalOp.AND # 重置为默认 AND
return groups
class QueryBuilder:
"""Django ORM 查询构建器
将解析后的过滤条件转换为 Django ORM 查询,支持 AND/OR 逻辑
"""
@classmethod
def build_query(
cls,
queryset: QuerySet,
filter_groups: List[FilterGroup],
field_mapping: Dict[str, str],
json_array_fields: List[str] = None
) -> QuerySet:
"""构建 Django ORM 查询
Args:
queryset: Django QuerySet
filter_groups: 解析后的过滤条件组列表
field_mapping: 字段映射
json_array_fields: JSON 数组字段列表(使用 __contains 查询)
Returns:
过滤后的 QuerySet
"""
if not filter_groups:
return queryset
json_array_fields = json_array_fields or []
# 构建 Q 对象
combined_q = None
for group in filter_groups:
f = group.filter
# 字段映射
db_field = field_mapping.get(f.field)
if not db_field:
logger.debug(f"忽略未知字段: {f.field}")
continue
# 判断是否为 JSON 数组字段
is_json_array = db_field in json_array_fields
# 构建单个条件的 Q 对象
q = cls._build_single_q(db_field, f.operator, f.value, is_json_array)
if q is None:
continue
# 组合 Q 对象
if combined_q is None:
combined_q = q
elif group.logical_op == LogicalOp.OR:
combined_q = combined_q | q
else: # AND
combined_q = combined_q & q
if combined_q is not None:
return queryset.filter(combined_q)
return queryset
@classmethod
def _build_single_q(cls, field: str, operator: str, value: str, is_json_array: bool = False) -> Optional[Q]:
"""构建单个条件的 Q 对象"""
if is_json_array:
# JSON 数组字段使用 __contains 查询
return Q(**{f'{field}__contains': [value]})
if operator == '!=':
return cls._build_not_equal_q(field, value)
elif operator == '==':
return cls._build_exact_q(field, value)
else: # '='
return cls._build_fuzzy_q(field, value)
@classmethod
def _try_convert_to_int(cls, value: str) -> Optional[int]:
"""尝试将值转换为整数"""
try:
return int(value.strip())
except (ValueError, TypeError):
return None
@classmethod
def _build_fuzzy_q(cls, field: str, value: str) -> Q:
"""模糊匹配: 包含"""
return Q(**{f'{field}__icontains': value})
@classmethod
def _build_exact_q(cls, field: str, value: str) -> Q:
"""精确匹配"""
int_val = cls._try_convert_to_int(value)
if int_val is not None:
return Q(**{f'{field}__exact': int_val})
return Q(**{f'{field}__exact': value})
@classmethod
def _build_not_equal_q(cls, field: str, value: str) -> Q:
"""不等于"""
int_val = cls._try_convert_to_int(value)
if int_val is not None:
return ~Q(**{f'{field}__exact': int_val})
return ~Q(**{f'{field}__exact': value})
def apply_filters(
queryset: QuerySet,
query_string: str,
field_mapping: Dict[str, str],
json_array_fields: List[str] = None
) -> QuerySet:
"""应用过滤条件到 QuerySet
Args:
queryset: Django QuerySet
query_string: 查询语法字符串
field_mapping: 字段映射
json_array_fields: JSON 数组字段列表(使用 __contains 查询)
Returns:
过滤后的 QuerySet
Examples:
# OR 查询
apply_filters(qs, 'type="xss" || type="sqli"', mapping)
apply_filters(qs, 'type="xss" or type="sqli"', mapping)
# AND 查询
apply_filters(qs, 'severity="high" && source="nuclei"', mapping)
apply_filters(qs, 'severity="high" source="nuclei"', mapping)
# 混合查询
apply_filters(qs, 'type="xss" || type="sqli" && severity="high"', mapping)
# JSON 数组字段查询
apply_filters(qs, 'implies="PHP"', mapping, json_array_fields=['implies'])
"""
if not query_string or not query_string.strip():
return queryset
try:
filter_groups = QueryParser.parse(query_string)
if not filter_groups:
logger.debug(f"未解析到有效过滤条件: {query_string}")
return queryset
logger.debug(f"解析过滤条件: {filter_groups}")
return QueryBuilder.build_query(
queryset,
filter_groups,
field_mapping,
json_array_fields=json_array_fields
)
except Exception as e:
logger.warning(f"过滤解析错误: {e}, query: {query_string}")
return queryset # 静默降级

View File

@@ -27,6 +27,21 @@ def validate_domain(domain: str) -> None:
raise ValueError(f"域名格式无效: {domain}")
def is_valid_domain(domain: str) -> bool:
"""
判断是否为有效域名(不抛异常)
Args:
domain: 域名字符串
Returns:
bool: 是否为有效域名
"""
if not domain or len(domain) > 253:
return False
return bool(validators.domain(domain))
def validate_ip(ip: str) -> None:
"""
验证 IP 地址格式(支持 IPv4 和 IPv6
@@ -190,6 +205,70 @@ def validate_url(url: str) -> None:
raise ValueError(f"URL 格式无效: {url}")
def is_valid_url(url: str, max_length: int = 2000) -> bool:
"""
判断是否为有效 URL不抛异常
Args:
url: URL 字符串
max_length: URL 最大长度,默认 2000
Returns:
bool: 是否为有效 URL
"""
if not url or len(url) > max_length:
return False
try:
validate_url(url)
return True
except ValueError:
return False
def is_url_match_target(url: str, target_name: str, target_type: str) -> bool:
"""
判断 URL 是否匹配目标
Args:
url: URL 字符串
target_name: 目标名称域名、IP 或 CIDR
target_type: 目标类型 ('domain', 'ip', 'cidr')
Returns:
bool: 是否匹配
"""
try:
parsed = urlparse(url)
hostname = parsed.hostname
if not hostname:
return False
hostname = hostname.lower()
target_name = target_name.lower()
if target_type == 'domain':
# 域名类型hostname 等于 target_name 或以 .target_name 结尾
return hostname == target_name or hostname.endswith('.' + target_name)
elif target_type == 'ip':
# IP 类型hostname 必须完全等于 target_name
return hostname == target_name
elif target_type == 'cidr':
# CIDR 类型hostname 必须是 IP 且在 CIDR 范围内
try:
ip = ipaddress.ip_address(hostname)
network = ipaddress.ip_network(target_name, strict=False)
return ip in network
except ValueError:
# hostname 不是有效 IP
return False
return False
except Exception:
return False
def detect_input_type(input_str: str) -> str:
"""
检测输入类型(用于快速扫描输入解析)

View File

@@ -7,6 +7,6 @@
"""
from .auth_views import LoginView, LogoutView, MeView, ChangePasswordView
from .system_log_views import SystemLogsView
from .system_log_views import SystemLogsView, SystemLogFilesView
__all__ = ['LoginView', 'LogoutView', 'MeView', 'ChangePasswordView', 'SystemLogsView']
__all__ = ['LoginView', 'LogoutView', 'MeView', 'ChangePasswordView', 'SystemLogsView', 'SystemLogFilesView']

View File

@@ -11,6 +11,9 @@ from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework.permissions import AllowAny, IsAuthenticated
from apps.common.response_helpers import success_response, error_response
from apps.common.error_codes import ErrorCodes
logger = logging.getLogger(__name__)
@@ -28,9 +31,10 @@ class LoginView(APIView):
password = request.data.get('password')
if not username or not password:
return Response(
{'error': '请提供用户名和密码'},
status=status.HTTP_400_BAD_REQUEST
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Username and password are required',
status_code=status.HTTP_400_BAD_REQUEST
)
user = authenticate(request, username=username, password=password)
@@ -38,20 +42,22 @@ class LoginView(APIView):
if user is not None:
login(request, user)
logger.info(f"用户 {username} 登录成功")
return Response({
'message': '登录成功',
'user': {
'id': user.id,
'username': user.username,
'isStaff': user.is_staff,
'isSuperuser': user.is_superuser,
return success_response(
data={
'user': {
'id': user.id,
'username': user.username,
'isStaff': user.is_staff,
'isSuperuser': user.is_superuser,
}
}
})
)
else:
logger.warning(f"用户 {username} 登录失败:用户名或密码错误")
return Response(
{'error': '用户名或密码错误'},
status=status.HTTP_401_UNAUTHORIZED
return error_response(
code=ErrorCodes.UNAUTHORIZED,
message='Invalid username or password',
status_code=status.HTTP_401_UNAUTHORIZED
)
@@ -79,7 +85,7 @@ class LogoutView(APIView):
logout(request)
else:
logout(request)
return Response({'message': '已登出'})
return success_response()
@method_decorator(csrf_exempt, name='dispatch')
@@ -100,22 +106,26 @@ class MeView(APIView):
if user_id:
try:
user = User.objects.get(pk=user_id)
return Response({
'authenticated': True,
'user': {
'id': user.id,
'username': user.username,
'isStaff': user.is_staff,
'isSuperuser': user.is_superuser,
return success_response(
data={
'authenticated': True,
'user': {
'id': user.id,
'username': user.username,
'isStaff': user.is_staff,
'isSuperuser': user.is_superuser,
}
}
})
)
except User.DoesNotExist:
pass
return Response({
'authenticated': False,
'user': None
})
return success_response(
data={
'authenticated': False,
'user': None
}
)
@method_decorator(csrf_exempt, name='dispatch')
@@ -134,17 +144,19 @@ class ChangePasswordView(APIView):
user_id = request.session.get('_auth_user_id')
if not user_id:
return Response(
{'error': '请先登录'},
status=status.HTTP_401_UNAUTHORIZED
return error_response(
code=ErrorCodes.UNAUTHORIZED,
message='Please login first',
status_code=status.HTTP_401_UNAUTHORIZED
)
try:
user = User.objects.get(pk=user_id)
except User.DoesNotExist:
return Response(
{'error': '用户不存在'},
status=status.HTTP_401_UNAUTHORIZED
return error_response(
code=ErrorCodes.UNAUTHORIZED,
message='User does not exist',
status_code=status.HTTP_401_UNAUTHORIZED
)
# CamelCaseParser 将 oldPassword -> old_password
@@ -152,15 +164,17 @@ class ChangePasswordView(APIView):
new_password = request.data.get('new_password')
if not old_password or not new_password:
return Response(
{'error': '请提供旧密码和新密码'},
status=status.HTTP_400_BAD_REQUEST
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Old password and new password are required',
status_code=status.HTTP_400_BAD_REQUEST
)
if not user.check_password(old_password):
return Response(
{'error': '旧密码错误'},
status=status.HTTP_400_BAD_REQUEST
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Old password is incorrect',
status_code=status.HTTP_400_BAD_REQUEST
)
user.set_password(new_password)
@@ -170,4 +184,4 @@ class ChangePasswordView(APIView):
update_session_auth_hash(request, user)
logger.info(f"用户 {user.username} 已修改密码")
return Response({'message': '密码修改成功'})
return success_response()

View File

@@ -13,12 +13,57 @@ from rest_framework.permissions import AllowAny
from rest_framework.response import Response
from rest_framework.views import APIView
from apps.common.response_helpers import success_response, error_response
from apps.common.error_codes import ErrorCodes
from apps.common.services.system_log_service import SystemLogService
logger = logging.getLogger(__name__)
@method_decorator(csrf_exempt, name="dispatch")
class SystemLogFilesView(APIView):
"""
日志文件列表 API 视图
GET /api/system/logs/files/
获取所有可用的日志文件列表
Response:
{
"files": [
{
"filename": "xingrin.log",
"category": "system",
"size": 1048576,
"modifiedAt": "2025-01-15T10:30:00+00:00"
},
...
]
}
"""
authentication_classes = []
permission_classes = [AllowAny]
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.service = SystemLogService()
def get(self, request):
"""获取日志文件列表"""
try:
files = self.service.get_log_files()
return success_response(data={"files": files})
except Exception:
logger.exception("获取日志文件列表失败")
return error_response(
code=ErrorCodes.SERVER_ERROR,
message='Failed to get log files',
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
)
@method_decorator(csrf_exempt, name="dispatch")
class SystemLogsView(APIView):
"""
@@ -28,6 +73,7 @@ class SystemLogsView(APIView):
获取系统日志内容
Query Parameters:
file (str, optional): 日志文件名,默认 xingrin.log
lines (int, optional): 返回的日志行数,默认 200最大 10000
Response:
@@ -52,18 +98,33 @@ class SystemLogsView(APIView):
"""
获取系统日志
支持通过 lines 参数控制返回行数,用于前端分页或实时刷新场景
支持通过 file 和 lines 参数控制返回内容
"""
try:
# 解析 lines 参数
# 解析参数
filename = request.query_params.get("file")
lines_raw = request.query_params.get("lines")
lines = int(lines_raw) if lines_raw is not None else None
# 调用服务获取日志内容
content = self.service.get_logs_content(lines=lines)
return Response({"content": content})
except ValueError:
return Response({"error": "lines 参数必须是整数"}, status=status.HTTP_400_BAD_REQUEST)
content = self.service.get_logs_content(filename=filename, lines=lines)
return success_response(data={"content": content})
except ValueError as e:
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message=str(e) if 'file' in str(e).lower() else 'lines must be an integer',
status_code=status.HTTP_400_BAD_REQUEST
)
except FileNotFoundError as e:
return error_response(
code=ErrorCodes.NOT_FOUND,
message=str(e),
status_code=status.HTTP_404_NOT_FOUND
)
except Exception:
logger.exception("获取系统日志失败")
return Response({"error": "获取系统日志失败"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
return error_response(
code=ErrorCodes.SERVER_ERROR,
message='Failed to get system logs',
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
)

View File

@@ -242,8 +242,9 @@ class WorkerDeployConsumer(AsyncWebsocketConsumer):
return
# 远程 Worker 通过 nginx HTTPS 访问nginx 反代到后端 8888
# 使用 https://{PUBLIC_HOST} 而不是直连 8888 端口
heartbeat_api_url = f"https://{public_host}" # 基础 URLagent 会加 /api/...
# 使用 https://{PUBLIC_HOST}:{PUBLIC_PORT} 而不是直连 8888 端口
public_port = getattr(settings, 'PUBLIC_PORT', '8083')
heartbeat_api_url = f"https://{public_host}:{public_port}"
session_name = f'xingrin_deploy_{self.worker_id}'
remote_script_path = '/tmp/xingrin_deploy.sh'

View File

@@ -0,0 +1,205 @@
"""初始化内置指纹库
- EHole 指纹: ehole.json -> 导入到数据库
- Goby 指纹: goby.json -> 导入到数据库
- Wappalyzer 指纹: wappalyzer.json -> 导入到数据库
- Fingers 指纹: fingers_http.json -> 导入到数据库
- FingerPrintHub 指纹: fingerprinthub_web.json -> 导入到数据库
- ARL 指纹: ARL.yaml -> 导入到数据库
可重复执行:如果数据库已有数据则跳过,只在空库时导入。
"""
import json
import logging
from pathlib import Path
import yaml
from django.conf import settings
from django.core.management.base import BaseCommand
from apps.engine.models import (
EholeFingerprint,
GobyFingerprint,
WappalyzerFingerprint,
FingersFingerprint,
FingerPrintHubFingerprint,
ARLFingerprint,
)
from apps.engine.services.fingerprints import (
EholeFingerprintService,
GobyFingerprintService,
WappalyzerFingerprintService,
FingersFingerprintService,
FingerPrintHubFingerprintService,
ARLFingerprintService,
)
logger = logging.getLogger(__name__)
# 内置指纹配置
DEFAULT_FINGERPRINTS = [
{
"type": "ehole",
"filename": "ehole.json",
"model": EholeFingerprint,
"service": EholeFingerprintService,
"data_key": "fingerprint", # JSON 中指纹数组的 key
"file_format": "json",
},
{
"type": "goby",
"filename": "goby.json",
"model": GobyFingerprint,
"service": GobyFingerprintService,
"data_key": None, # Goby 是数组格式,直接使用整个 JSON
"file_format": "json",
},
{
"type": "wappalyzer",
"filename": "wappalyzer.json",
"model": WappalyzerFingerprint,
"service": WappalyzerFingerprintService,
"data_key": "apps", # Wappalyzer 使用 apps 对象
"file_format": "json",
},
{
"type": "fingers",
"filename": "fingers_http.json",
"model": FingersFingerprint,
"service": FingersFingerprintService,
"data_key": None, # Fingers 是数组格式
"file_format": "json",
},
{
"type": "fingerprinthub",
"filename": "fingerprinthub_web.json",
"model": FingerPrintHubFingerprint,
"service": FingerPrintHubFingerprintService,
"data_key": None, # FingerPrintHub 是数组格式
"file_format": "json",
},
{
"type": "arl",
"filename": "ARL.yaml",
"model": ARLFingerprint,
"service": ARLFingerprintService,
"data_key": None, # ARL 是 YAML 数组格式
"file_format": "yaml",
},
]
class Command(BaseCommand):
help = "初始化内置指纹库"
def handle(self, *args, **options):
project_base = Path(settings.BASE_DIR).parent # /app/backend -> /app
fingerprints_dir = project_base / "backend" / "fingerprints"
initialized = 0
skipped = 0
failed = 0
for item in DEFAULT_FINGERPRINTS:
fp_type = item["type"]
filename = item["filename"]
model = item["model"]
service_class = item["service"]
data_key = item["data_key"]
file_format = item.get("file_format", "json")
# 检查数据库是否已有数据
existing_count = model.objects.count()
if existing_count > 0:
self.stdout.write(self.style.SUCCESS(
f"[{fp_type}] 数据库已有 {existing_count} 条记录,跳过初始化"
))
skipped += 1
continue
# 查找源文件
src_path = fingerprints_dir / filename
if not src_path.exists():
self.stdout.write(self.style.WARNING(
f"[{fp_type}] 未找到内置指纹文件: {src_path},跳过"
))
failed += 1
continue
# 读取并解析文件(支持 JSON 和 YAML
try:
with open(src_path, "r", encoding="utf-8") as f:
if file_format == "yaml":
file_data = yaml.safe_load(f)
else:
file_data = json.load(f)
except (json.JSONDecodeError, yaml.YAMLError, OSError) as exc:
self.stdout.write(self.style.ERROR(
f"[{fp_type}] 读取指纹文件失败: {exc}"
))
failed += 1
continue
# 提取指纹数据(根据不同格式处理)
fingerprints = self._extract_fingerprints(file_data, data_key, fp_type)
if not fingerprints:
self.stdout.write(self.style.WARNING(
f"[{fp_type}] 指纹文件中没有有效数据,跳过"
))
failed += 1
continue
# 使用 Service 批量导入
try:
service = service_class()
result = service.batch_create_fingerprints(fingerprints)
created = result.get("created", 0)
failed_count = result.get("failed", 0)
self.stdout.write(self.style.SUCCESS(
f"[{fp_type}] 导入成功: 创建 {created} 条,失败 {failed_count}"
))
initialized += 1
except Exception as exc:
self.stdout.write(self.style.ERROR(
f"[{fp_type}] 导入失败: {exc}"
))
failed += 1
continue
self.stdout.write(self.style.SUCCESS(
f"指纹初始化完成: 成功 {initialized}, 已存在跳过 {skipped}, 失败 {failed}"
))
def _extract_fingerprints(self, json_data, data_key, fp_type):
"""
根据不同格式提取指纹数据,兼容数组和对象两种格式
支持的格式:
- 数组格式: [...] 或 {"key": [...]}
- 对象格式: {...} 或 {"key": {...}} -> 转换为 [{"name": k, ...v}]
"""
# 获取目标数据
if data_key is None:
# 直接使用整个 JSON
target = json_data
else:
# 从指定 key 获取,支持多个可能的 key如 apps/technologies
if data_key == "apps":
target = json_data.get("apps") or json_data.get("technologies") or {}
else:
target = json_data.get(data_key, [])
# 根据数据类型处理
if isinstance(target, list):
# 已经是数组格式,直接返回
return target
elif isinstance(target, dict):
# 对象格式,转换为数组 [{"name": key, ...value}]
return [{"name": name, **data} if isinstance(data, dict) else {"name": name}
for name, data in target.items()]
return []

View File

@@ -3,12 +3,17 @@
项目安装后执行此命令,自动创建官方模板仓库记录。
使用方式:
python manage.py init_nuclei_templates # 只创建记录
python manage.py init_nuclei_templates # 只创建记录(检测本地已有仓库)
python manage.py init_nuclei_templates --sync # 创建并同步git clone
"""
import logging
import subprocess
from pathlib import Path
from django.conf import settings
from django.core.management.base import BaseCommand
from django.utils import timezone
from apps.engine.models import NucleiTemplateRepo
from apps.engine.services import NucleiTemplateRepoService
@@ -26,6 +31,20 @@ DEFAULT_REPOS = [
]
def get_local_commit_hash(local_path: Path) -> str:
"""获取本地 Git 仓库的 commit hash"""
if not (local_path / ".git").is_dir():
return ""
result = subprocess.run(
["git", "-C", str(local_path), "rev-parse", "HEAD"],
check=False,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
return result.stdout.strip() if result.returncode == 0 else ""
class Command(BaseCommand):
help = "初始化 Nuclei 模板仓库(创建官方模板仓库记录)"
@@ -46,6 +65,8 @@ class Command(BaseCommand):
force = options.get("force", False)
service = NucleiTemplateRepoService()
base_dir = Path(getattr(settings, "NUCLEI_TEMPLATES_REPOS_BASE_DIR", "/opt/xingrin/nuclei-repos"))
created = 0
skipped = 0
synced = 0
@@ -87,20 +108,30 @@ class Command(BaseCommand):
# 创建新仓库记录
try:
# 检查本地是否已有仓库(由 install.sh 预下载)
local_path = base_dir / name
local_commit = get_local_commit_hash(local_path)
repo = NucleiTemplateRepo.objects.create(
name=name,
repo_url=repo_url,
local_path=str(local_path) if local_commit else "",
commit_hash=local_commit,
last_synced_at=timezone.now() if local_commit else None,
)
self.stdout.write(self.style.SUCCESS(
f"[{name}] 创建成功: id={repo.id}"
))
if local_commit:
self.stdout.write(self.style.SUCCESS(
f"[{name}] 创建成功(检测到本地仓库): commit={local_commit[:8]}"
))
else:
self.stdout.write(self.style.SUCCESS(
f"[{name}] 创建成功: id={repo.id}"
))
created += 1
# 初始化本地路径
service.ensure_local_path(repo)
# 如果需要同步
if do_sync:
# 如果本地没有仓库且需要同步
if not local_commit and do_sync:
try:
self.stdout.write(self.style.WARNING(
f"[{name}] 正在同步(首次可能需要几分钟)..."

View File

@@ -1,7 +1,8 @@
"""初始化所有内置字典 Wordlist 记录
- 目录扫描默认字典: dir_default.txt -> /app/backend/wordlist/dir_default.txt
- 子域名爆破默认字典: subdomains-top1million-110000.txt -> /app/backend/wordlist/subdomains-top1million-110000.txt
内置字典从镜像内 /app/backend/wordlist/ 复制到运行时目录 /opt/xingrin/wordlists/
- 目录扫描默认字典: dir_default.txt
- 子域名爆破默认字典: subdomains-top1million-110000.txt
可重复执行:如果已存在同名记录且文件有效则跳过,只在缺失或文件丢失时创建/修复。
"""

View File

@@ -0,0 +1,29 @@
"""Engine Models
导出所有 Engine 模块的 Models
"""
from .engine import WorkerNode, ScanEngine, Wordlist, NucleiTemplateRepo
from .fingerprints import (
EholeFingerprint,
GobyFingerprint,
WappalyzerFingerprint,
FingersFingerprint,
FingerPrintHubFingerprint,
ARLFingerprint,
)
__all__ = [
# 核心 Models
"WorkerNode",
"ScanEngine",
"Wordlist",
"NucleiTemplateRepo",
# 指纹 Models
"EholeFingerprint",
"GobyFingerprint",
"WappalyzerFingerprint",
"FingersFingerprint",
"FingerPrintHubFingerprint",
"ARLFingerprint",
]

View File

@@ -1,3 +1,8 @@
"""Engine 模块核心 Models
包含 WorkerNode, ScanEngine, Wordlist, NucleiTemplateRepo
"""
from django.db import models
@@ -78,6 +83,7 @@ class ScanEngine(models.Model):
indexes = [
models.Index(fields=['-created_at']),
]
def __str__(self):
return str(self.name or f'ScanEngine {self.id}')

View File

@@ -0,0 +1,195 @@
"""指纹相关 Models
包含 EHole、Goby、Wappalyzer 等指纹格式的数据模型
"""
from django.db import models
class GobyFingerprint(models.Model):
"""Goby 格式指纹规则
Goby 使用逻辑表达式和规则数组进行匹配:
- logic: 逻辑表达式,如 "a||b", "(a&&b)||c"
- rule: 规则数组,每条规则包含 label, feature, is_equal
"""
name = models.CharField(max_length=300, unique=True, help_text='产品名称')
logic = models.CharField(max_length=500, help_text='逻辑表达式')
rule = models.JSONField(default=list, help_text='规则数组')
created_at = models.DateTimeField(auto_now_add=True)
class Meta:
db_table = 'goby_fingerprint'
verbose_name = 'Goby 指纹'
verbose_name_plural = 'Goby 指纹'
ordering = ['-created_at']
indexes = [
models.Index(fields=['name']),
models.Index(fields=['logic']),
models.Index(fields=['-created_at']),
]
def __str__(self) -> str:
return f"{self.name} ({self.logic})"
class EholeFingerprint(models.Model):
"""EHole 格式指纹规则(字段与 ehole.json 一致)"""
cms = models.CharField(max_length=200, help_text='产品/CMS名称')
method = models.CharField(max_length=200, default='keyword', help_text='匹配方式')
location = models.CharField(max_length=200, default='body', help_text='匹配位置')
keyword = models.JSONField(default=list, help_text='关键词列表')
is_important = models.BooleanField(default=False, help_text='是否重点资产')
type = models.CharField(max_length=100, blank=True, default='-', help_text='分类')
created_at = models.DateTimeField(auto_now_add=True)
class Meta:
db_table = 'ehole_fingerprint'
verbose_name = 'EHole 指纹'
verbose_name_plural = 'EHole 指纹'
ordering = ['-created_at']
indexes = [
# 搜索过滤字段索引
models.Index(fields=['cms']),
models.Index(fields=['method']),
models.Index(fields=['location']),
models.Index(fields=['type']),
models.Index(fields=['is_important']),
# 排序字段索引
models.Index(fields=['-created_at']),
]
constraints = [
# 唯一约束cms + method + location 组合不能重复
models.UniqueConstraint(
fields=['cms', 'method', 'location'],
name='unique_ehole_fingerprint'
),
]
def __str__(self) -> str:
return f"{self.cms} ({self.method}@{self.location})"
class WappalyzerFingerprint(models.Model):
"""Wappalyzer 格式指纹规则
Wappalyzer 支持多种检测方式cookies, headers, scriptSrc, js, meta, html 等
"""
name = models.CharField(max_length=300, unique=True, help_text='应用名称')
cats = models.JSONField(default=list, help_text='分类 ID 数组')
cookies = models.JSONField(default=dict, blank=True, help_text='Cookie 检测规则')
headers = models.JSONField(default=dict, blank=True, help_text='HTTP Header 检测规则')
script_src = models.JSONField(default=list, blank=True, help_text='脚本 URL 正则数组')
js = models.JSONField(default=list, blank=True, help_text='JavaScript 变量检测规则')
implies = models.JSONField(default=list, blank=True, help_text='依赖关系数组')
meta = models.JSONField(default=dict, blank=True, help_text='HTML meta 标签检测规则')
html = models.JSONField(default=list, blank=True, help_text='HTML 内容正则数组')
description = models.TextField(blank=True, default='', help_text='应用描述')
website = models.URLField(max_length=500, blank=True, default='', help_text='官网链接')
cpe = models.CharField(max_length=300, blank=True, default='', help_text='CPE 标识符')
created_at = models.DateTimeField(auto_now_add=True)
class Meta:
db_table = 'wappalyzer_fingerprint'
verbose_name = 'Wappalyzer 指纹'
verbose_name_plural = 'Wappalyzer 指纹'
ordering = ['-created_at']
indexes = [
models.Index(fields=['name']),
models.Index(fields=['website']),
models.Index(fields=['cpe']),
models.Index(fields=['-created_at']),
]
def __str__(self) -> str:
return f"{self.name}"
class FingersFingerprint(models.Model):
"""Fingers 格式指纹规则 (fingers_http.json)
使用正则表达式和标签进行匹配,支持 favicon hash、header、body 等多种检测方式
"""
name = models.CharField(max_length=300, unique=True, help_text='指纹名称')
link = models.URLField(max_length=500, blank=True, default='', help_text='相关链接')
rule = models.JSONField(default=list, help_text='匹配规则数组')
tag = models.JSONField(default=list, help_text='标签数组')
focus = models.BooleanField(default=False, help_text='是否重点关注')
default_port = models.JSONField(default=list, blank=True, help_text='默认端口数组')
created_at = models.DateTimeField(auto_now_add=True)
class Meta:
db_table = 'fingers_fingerprint'
verbose_name = 'Fingers 指纹'
verbose_name_plural = 'Fingers 指纹'
ordering = ['-created_at']
indexes = [
models.Index(fields=['name']),
models.Index(fields=['link']),
models.Index(fields=['focus']),
models.Index(fields=['-created_at']),
]
def __str__(self) -> str:
return f"{self.name}"
class FingerPrintHubFingerprint(models.Model):
"""FingerPrintHub 格式指纹规则 (fingerprinthub_web.json)
基于 nuclei 模板格式,使用 HTTP 请求和响应特征进行匹配
"""
fp_id = models.CharField(max_length=200, unique=True, help_text='指纹ID')
name = models.CharField(max_length=300, help_text='指纹名称')
author = models.CharField(max_length=200, blank=True, default='', help_text='作者')
tags = models.CharField(max_length=500, blank=True, default='', help_text='标签')
severity = models.CharField(max_length=50, blank=True, default='info', help_text='严重程度')
metadata = models.JSONField(default=dict, blank=True, help_text='元数据')
http = models.JSONField(default=list, help_text='HTTP 匹配规则')
source_file = models.CharField(max_length=500, blank=True, default='', help_text='来源文件')
created_at = models.DateTimeField(auto_now_add=True)
class Meta:
db_table = 'fingerprinthub_fingerprint'
verbose_name = 'FingerPrintHub 指纹'
verbose_name_plural = 'FingerPrintHub 指纹'
ordering = ['-created_at']
indexes = [
models.Index(fields=['fp_id']),
models.Index(fields=['name']),
models.Index(fields=['author']),
models.Index(fields=['severity']),
models.Index(fields=['-created_at']),
]
def __str__(self) -> str:
return f"{self.name} ({self.fp_id})"
class ARLFingerprint(models.Model):
"""ARL 格式指纹规则 (ARL.yaml)
使用简单的 name + rule 表达式格式
"""
name = models.CharField(max_length=300, unique=True, help_text='指纹名称')
rule = models.TextField(help_text='匹配规则表达式')
created_at = models.DateTimeField(auto_now_add=True)
class Meta:
db_table = 'arl_fingerprint'
verbose_name = 'ARL 指纹'
verbose_name_plural = 'ARL 指纹'
ordering = ['-created_at']
indexes = [
models.Index(fields=['name']),
models.Index(fields=['-created_at']),
]
def __str__(self) -> str:
return f"{self.name}"

View File

@@ -0,0 +1,20 @@
"""指纹管理 Serializers
导出所有指纹相关的 Serializer 类
"""
from .ehole import EholeFingerprintSerializer
from .goby import GobyFingerprintSerializer
from .wappalyzer import WappalyzerFingerprintSerializer
from .fingers import FingersFingerprintSerializer
from .fingerprinthub import FingerPrintHubFingerprintSerializer
from .arl import ARLFingerprintSerializer
__all__ = [
"EholeFingerprintSerializer",
"GobyFingerprintSerializer",
"WappalyzerFingerprintSerializer",
"FingersFingerprintSerializer",
"FingerPrintHubFingerprintSerializer",
"ARLFingerprintSerializer",
]

View File

@@ -0,0 +1,31 @@
"""ARL 指纹 Serializer"""
from rest_framework import serializers
from apps.engine.models import ARLFingerprint
class ARLFingerprintSerializer(serializers.ModelSerializer):
"""ARL 指纹序列化器
字段映射:
- name: 指纹名称 (必填, 唯一)
- rule: 匹配规则表达式 (必填)
"""
class Meta:
model = ARLFingerprint
fields = ['id', 'name', 'rule', 'created_at']
read_only_fields = ['id', 'created_at']
def validate_name(self, value):
"""校验 name 字段"""
if not value or not value.strip():
raise serializers.ValidationError("name 字段不能为空")
return value.strip()
def validate_rule(self, value):
"""校验 rule 字段"""
if not value or not value.strip():
raise serializers.ValidationError("rule 字段不能为空")
return value.strip()

View File

@@ -0,0 +1,27 @@
"""EHole 指纹 Serializer"""
from rest_framework import serializers
from apps.engine.models import EholeFingerprint
class EholeFingerprintSerializer(serializers.ModelSerializer):
"""EHole 指纹序列化器"""
class Meta:
model = EholeFingerprint
fields = ['id', 'cms', 'method', 'location', 'keyword',
'is_important', 'type', 'created_at']
read_only_fields = ['id', 'created_at']
def validate_cms(self, value):
"""校验 cms 字段"""
if not value or not value.strip():
raise serializers.ValidationError("cms 字段不能为空")
return value.strip()
def validate_keyword(self, value):
"""校验 keyword 字段"""
if not isinstance(value, list):
raise serializers.ValidationError("keyword 必须是数组")
return value

View File

@@ -0,0 +1,50 @@
"""FingerPrintHub 指纹 Serializer"""
from rest_framework import serializers
from apps.engine.models import FingerPrintHubFingerprint
class FingerPrintHubFingerprintSerializer(serializers.ModelSerializer):
"""FingerPrintHub 指纹序列化器
字段映射:
- fp_id: 指纹ID (必填, 唯一)
- name: 指纹名称 (必填)
- author: 作者 (可选)
- tags: 标签字符串 (可选)
- severity: 严重程度 (可选, 默认 'info')
- metadata: 元数据 JSON (可选)
- http: HTTP 匹配规则数组 (必填)
- source_file: 来源文件 (可选)
"""
class Meta:
model = FingerPrintHubFingerprint
fields = ['id', 'fp_id', 'name', 'author', 'tags', 'severity',
'metadata', 'http', 'source_file', 'created_at']
read_only_fields = ['id', 'created_at']
def validate_fp_id(self, value):
"""校验 fp_id 字段"""
if not value or not value.strip():
raise serializers.ValidationError("fp_id 字段不能为空")
return value.strip()
def validate_name(self, value):
"""校验 name 字段"""
if not value or not value.strip():
raise serializers.ValidationError("name 字段不能为空")
return value.strip()
def validate_http(self, value):
"""校验 http 字段"""
if not isinstance(value, list):
raise serializers.ValidationError("http 必须是数组")
return value
def validate_metadata(self, value):
"""校验 metadata 字段"""
if not isinstance(value, dict):
raise serializers.ValidationError("metadata 必须是对象")
return value

View File

@@ -0,0 +1,48 @@
"""Fingers 指纹 Serializer"""
from rest_framework import serializers
from apps.engine.models import FingersFingerprint
class FingersFingerprintSerializer(serializers.ModelSerializer):
"""Fingers 指纹序列化器
字段映射:
- name: 指纹名称 (必填, 唯一)
- link: 相关链接 (可选)
- rule: 匹配规则数组 (必填)
- tag: 标签数组 (可选)
- focus: 是否重点关注 (可选, 默认 False)
- default_port: 默认端口数组 (可选)
"""
class Meta:
model = FingersFingerprint
fields = ['id', 'name', 'link', 'rule', 'tag', 'focus',
'default_port', 'created_at']
read_only_fields = ['id', 'created_at']
def validate_name(self, value):
"""校验 name 字段"""
if not value or not value.strip():
raise serializers.ValidationError("name 字段不能为空")
return value.strip()
def validate_rule(self, value):
"""校验 rule 字段"""
if not isinstance(value, list):
raise serializers.ValidationError("rule 必须是数组")
return value
def validate_tag(self, value):
"""校验 tag 字段"""
if not isinstance(value, list):
raise serializers.ValidationError("tag 必须是数组")
return value
def validate_default_port(self, value):
"""校验 default_port 字段"""
if not isinstance(value, list):
raise serializers.ValidationError("default_port 必须是数组")
return value

View File

@@ -0,0 +1,26 @@
"""Goby 指纹 Serializer"""
from rest_framework import serializers
from apps.engine.models import GobyFingerprint
class GobyFingerprintSerializer(serializers.ModelSerializer):
"""Goby 指纹序列化器"""
class Meta:
model = GobyFingerprint
fields = ['id', 'name', 'logic', 'rule', 'created_at']
read_only_fields = ['id', 'created_at']
def validate_name(self, value):
"""校验 name 字段"""
if not value or not value.strip():
raise serializers.ValidationError("name 字段不能为空")
return value.strip()
def validate_rule(self, value):
"""校验 rule 字段"""
if not isinstance(value, list):
raise serializers.ValidationError("rule 必须是数组")
return value

View File

@@ -0,0 +1,24 @@
"""Wappalyzer 指纹 Serializer"""
from rest_framework import serializers
from apps.engine.models import WappalyzerFingerprint
class WappalyzerFingerprintSerializer(serializers.ModelSerializer):
"""Wappalyzer 指纹序列化器"""
class Meta:
model = WappalyzerFingerprint
fields = [
'id', 'name', 'cats', 'cookies', 'headers', 'script_src',
'js', 'implies', 'meta', 'html', 'description', 'website',
'cpe', 'created_at'
]
read_only_fields = ['id', 'created_at']
def validate_name(self, value):
"""校验 name 字段"""
if not value or not value.strip():
raise serializers.ValidationError("name 字段不能为空")
return value.strip()

View File

@@ -0,0 +1,22 @@
"""指纹管理 Services
导出所有指纹相关的 Service 类
"""
from .base import BaseFingerprintService
from .ehole import EholeFingerprintService
from .goby import GobyFingerprintService
from .wappalyzer import WappalyzerFingerprintService
from .fingers_service import FingersFingerprintService
from .fingerprinthub_service import FingerPrintHubFingerprintService
from .arl_service import ARLFingerprintService
__all__ = [
"BaseFingerprintService",
"EholeFingerprintService",
"GobyFingerprintService",
"WappalyzerFingerprintService",
"FingersFingerprintService",
"FingerPrintHubFingerprintService",
"ARLFingerprintService",
]

View File

@@ -0,0 +1,110 @@
"""ARL 指纹管理 Service
实现 ARL 格式指纹的校验、转换和导出逻辑
支持 YAML 格式的导入导出
"""
import logging
import yaml
from apps.engine.models import ARLFingerprint
from .base import BaseFingerprintService
logger = logging.getLogger(__name__)
class ARLFingerprintService(BaseFingerprintService):
"""ARL 指纹管理服务(继承基类,实现 ARL 特定逻辑)"""
model = ARLFingerprint
def validate_fingerprint(self, item: dict) -> bool:
"""
校验单条 ARL 指纹
校验规则:
- name 字段必须存在且非空
- rule 字段必须存在且非空
Args:
item: 单条指纹数据
Returns:
bool: 是否有效
"""
name = item.get('name', '')
rule = item.get('rule', '')
return bool(name and str(name).strip()) and bool(rule and str(rule).strip())
def to_model_data(self, item: dict) -> dict:
"""
转换 ARL YAML 格式为 Model 字段
Args:
item: 原始 ARL YAML 数据
Returns:
dict: Model 字段数据
"""
return {
'name': str(item.get('name', '')).strip(),
'rule': str(item.get('rule', '')).strip(),
}
def get_export_data(self) -> list:
"""
获取导出数据ARL 格式 - 数组,用于 YAML 导出)
Returns:
list: ARL 格式的数据(数组格式)
[
{"name": "...", "rule": "..."},
...
]
"""
fingerprints = self.model.objects.all()
return [
{
'name': fp.name,
'rule': fp.rule,
}
for fp in fingerprints
]
def export_to_yaml(self, output_path: str) -> int:
"""
导出所有指纹到 YAML 文件
Args:
output_path: 输出文件路径
Returns:
int: 导出的指纹数量
"""
data = self.get_export_data()
with open(output_path, 'w', encoding='utf-8') as f:
yaml.dump(data, f, allow_unicode=True, default_flow_style=False, sort_keys=False)
count = len(data)
logger.info("导出 ARL 指纹文件: %s, 数量: %d", output_path, count)
return count
def parse_yaml_import(self, yaml_content: str) -> list:
"""
解析 YAML 格式的导入内容
Args:
yaml_content: YAML 格式的字符串内容
Returns:
list: 解析后的指纹数据列表
Raises:
ValueError: 当 YAML 格式无效时
"""
try:
data = yaml.safe_load(yaml_content)
if not isinstance(data, list):
raise ValueError("ARL YAML 文件必须是数组格式")
return data
except yaml.YAMLError as e:
raise ValueError(f"无效的 YAML 格式: {e}")

View File

@@ -0,0 +1,144 @@
"""指纹管理基类 Service
提供通用的批量操作和缓存逻辑,供 EHole/Goby/Wappalyzer 等子类继承
"""
import json
import logging
from typing import Any
logger = logging.getLogger(__name__)
class BaseFingerprintService:
"""指纹管理基类 Service提供通用的批量操作和缓存逻辑"""
model = None # 子类必须指定
BATCH_SIZE = 1000 # 每批处理数量
def validate_fingerprint(self, item: dict) -> bool:
"""
校验单条指纹,子类必须实现
Args:
item: 单条指纹数据
Returns:
bool: 是否有效
"""
raise NotImplementedError("子类必须实现 validate_fingerprint 方法")
def validate_fingerprints(self, raw_data: list) -> tuple[list, list]:
"""
批量校验指纹数据
Args:
raw_data: 原始指纹数据列表
Returns:
tuple: (valid_items, invalid_items)
"""
valid, invalid = [], []
for item in raw_data:
if self.validate_fingerprint(item):
valid.append(item)
else:
invalid.append(item)
return valid, invalid
def to_model_data(self, item: dict) -> dict:
"""
转换为 Model 字段,子类必须实现
Args:
item: 原始指纹数据
Returns:
dict: Model 字段数据
"""
raise NotImplementedError("子类必须实现 to_model_data 方法")
def bulk_create(self, fingerprints: list) -> int:
"""
批量创建指纹记录(已校验的数据)
Args:
fingerprints: 已校验的指纹数据列表
Returns:
int: 成功创建数量
"""
if not fingerprints:
return 0
objects = [self.model(**self.to_model_data(item)) for item in fingerprints]
created = self.model.objects.bulk_create(objects, ignore_conflicts=True)
return len(created)
def batch_create_fingerprints(self, raw_data: list) -> dict:
"""
完整流程:分批校验 + 批量创建
Args:
raw_data: 原始指纹数据列表
Returns:
dict: {'created': int, 'failed': int}
"""
total_created = 0
total_failed = 0
for i in range(0, len(raw_data), self.BATCH_SIZE):
batch = raw_data[i:i + self.BATCH_SIZE]
valid, invalid = self.validate_fingerprints(batch)
total_created += self.bulk_create(valid)
total_failed += len(invalid)
logger.info(
"批量创建指纹完成: created=%d, failed=%d, total=%d",
total_created, total_failed, len(raw_data)
)
return {'created': total_created, 'failed': total_failed}
def get_export_data(self) -> dict:
"""
获取导出数据,子类必须实现
Returns:
dict: 导出的 JSON 数据
"""
raise NotImplementedError("子类必须实现 get_export_data 方法")
def export_to_file(self, output_path: str) -> int:
"""
导出所有指纹到 JSON 文件
Args:
output_path: 输出文件路径
Returns:
int: 导出的指纹数量
"""
data = self.get_export_data()
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False)
count = len(data.get('fingerprint', []))
logger.info("导出指纹文件: %s, 数量: %d", output_path, count)
return count
def get_fingerprint_version(self) -> str:
"""
获取指纹库版本标识(用于缓存校验)
Returns:
str: 版本标识,格式 "{count}_{latest_timestamp}"
版本变化场景:
- 新增记录 → count 变化
- 删除记录 → count 变化
- 清空全部 → count 变为 0
"""
count = self.model.objects.count()
latest = self.model.objects.order_by('-created_at').first()
latest_ts = int(latest.created_at.timestamp()) if latest else 0
return f"{count}_{latest_ts}"

View File

@@ -0,0 +1,84 @@
"""EHole 指纹管理 Service
实现 EHole 格式指纹的校验、转换和导出逻辑
"""
from apps.engine.models import EholeFingerprint
from .base import BaseFingerprintService
class EholeFingerprintService(BaseFingerprintService):
"""EHole 指纹管理服务(继承基类,实现 EHole 特定逻辑)"""
model = EholeFingerprint
def validate_fingerprint(self, item: dict) -> bool:
"""
校验单条 EHole 指纹
校验规则:
- cms 字段必须存在且非空
- keyword 字段必须是数组
Args:
item: 单条指纹数据
Returns:
bool: 是否有效
"""
cms = item.get('cms', '')
keyword = item.get('keyword')
return bool(cms and str(cms).strip()) and isinstance(keyword, list)
def to_model_data(self, item: dict) -> dict:
"""
转换 EHole JSON 格式为 Model 字段
字段映射:
- isImportant (JSON) → is_important (Model)
Args:
item: 原始 EHole JSON 数据
Returns:
dict: Model 字段数据
"""
return {
'cms': str(item.get('cms', '')).strip(),
'method': item.get('method', 'keyword'),
'location': item.get('location', 'body'),
'keyword': item.get('keyword', []),
'is_important': item.get('isImportant', False),
'type': item.get('type', '-'),
}
def get_export_data(self) -> dict:
"""
获取导出数据EHole JSON 格式)
Returns:
dict: EHole 格式的 JSON 数据
{
"fingerprint": [
{"cms": "...", "method": "...", "location": "...",
"keyword": [...], "isImportant": false, "type": "..."},
...
],
"version": "1000_1703836800"
}
"""
fingerprints = self.model.objects.all()
data = []
for fp in fingerprints:
data.append({
'cms': fp.cms,
'method': fp.method,
'location': fp.location,
'keyword': fp.keyword,
'isImportant': fp.is_important, # 转回 JSON 格式
'type': fp.type,
})
return {
'fingerprint': data,
'version': self.get_fingerprint_version(),
}

View File

@@ -0,0 +1,110 @@
"""FingerPrintHub 指纹管理 Service
实现 FingerPrintHub 格式指纹的校验、转换和导出逻辑
"""
from apps.engine.models import FingerPrintHubFingerprint
from .base import BaseFingerprintService
class FingerPrintHubFingerprintService(BaseFingerprintService):
"""FingerPrintHub 指纹管理服务(继承基类,实现 FingerPrintHub 特定逻辑)"""
model = FingerPrintHubFingerprint
def validate_fingerprint(self, item: dict) -> bool:
"""
校验单条 FingerPrintHub 指纹
校验规则:
- id 字段必须存在且非空
- info 字段必须存在且包含 name
- http 字段必须是数组
Args:
item: 单条指纹数据
Returns:
bool: 是否有效
"""
fp_id = item.get('id', '')
info = item.get('info', {})
http = item.get('http')
if not fp_id or not str(fp_id).strip():
return False
if not isinstance(info, dict) or not info.get('name'):
return False
if not isinstance(http, list):
return False
return True
def to_model_data(self, item: dict) -> dict:
"""
转换 FingerPrintHub JSON 格式为 Model 字段
字段映射(嵌套结构转扁平):
- id (JSON) → fp_id (Model)
- info.name (JSON) → name (Model)
- info.author (JSON) → author (Model)
- info.tags (JSON) → tags (Model)
- info.severity (JSON) → severity (Model)
- info.metadata (JSON) → metadata (Model)
- http (JSON) → http (Model)
- _source_file (JSON) → source_file (Model)
Args:
item: 原始 FingerPrintHub JSON 数据
Returns:
dict: Model 字段数据
"""
info = item.get('info', {})
return {
'fp_id': str(item.get('id', '')).strip(),
'name': str(info.get('name', '')).strip(),
'author': info.get('author', ''),
'tags': info.get('tags', ''),
'severity': info.get('severity', 'info'),
'metadata': info.get('metadata', {}),
'http': item.get('http', []),
'source_file': item.get('_source_file', ''),
}
def get_export_data(self) -> list:
"""
获取导出数据FingerPrintHub JSON 格式 - 数组)
Returns:
list: FingerPrintHub 格式的 JSON 数据(数组格式)
[
{
"id": "...",
"info": {"name": "...", "author": "...", "tags": "...",
"severity": "...", "metadata": {...}},
"http": [...],
"_source_file": "..."
},
...
]
"""
fingerprints = self.model.objects.all()
data = []
for fp in fingerprints:
item = {
'id': fp.fp_id,
'info': {
'name': fp.name,
'author': fp.author,
'tags': fp.tags,
'severity': fp.severity,
'metadata': fp.metadata,
},
'http': fp.http,
}
# 只有当 source_file 非空时才添加该字段
if fp.source_file:
item['_source_file'] = fp.source_file
data.append(item)
return data

View File

@@ -0,0 +1,83 @@
"""Fingers 指纹管理 Service
实现 Fingers 格式指纹的校验、转换和导出逻辑
"""
from apps.engine.models import FingersFingerprint
from .base import BaseFingerprintService
class FingersFingerprintService(BaseFingerprintService):
"""Fingers 指纹管理服务(继承基类,实现 Fingers 特定逻辑)"""
model = FingersFingerprint
def validate_fingerprint(self, item: dict) -> bool:
"""
校验单条 Fingers 指纹
校验规则:
- name 字段必须存在且非空
- rule 字段必须是数组
Args:
item: 单条指纹数据
Returns:
bool: 是否有效
"""
name = item.get('name', '')
rule = item.get('rule')
return bool(name and str(name).strip()) and isinstance(rule, list)
def to_model_data(self, item: dict) -> dict:
"""
转换 Fingers JSON 格式为 Model 字段
字段映射:
- default_port (JSON) → default_port (Model)
Args:
item: 原始 Fingers JSON 数据
Returns:
dict: Model 字段数据
"""
return {
'name': str(item.get('name', '')).strip(),
'link': item.get('link', ''),
'rule': item.get('rule', []),
'tag': item.get('tag', []),
'focus': item.get('focus', False),
'default_port': item.get('default_port', []),
}
def get_export_data(self) -> list:
"""
获取导出数据Fingers JSON 格式 - 数组)
Returns:
list: Fingers 格式的 JSON 数据(数组格式)
[
{"name": "...", "link": "...", "rule": [...], "tag": [...],
"focus": false, "default_port": [...]},
...
]
"""
fingerprints = self.model.objects.all()
data = []
for fp in fingerprints:
item = {
'name': fp.name,
'link': fp.link,
'rule': fp.rule,
'tag': fp.tag,
}
# 只有当 focus 为 True 时才添加该字段(保持与原始格式一致)
if fp.focus:
item['focus'] = fp.focus
# 只有当 default_port 非空时才添加该字段
if fp.default_port:
item['default_port'] = fp.default_port
data.append(item)
return data

View File

@@ -0,0 +1,70 @@
"""Goby 指纹管理 Service
实现 Goby 格式指纹的校验、转换和导出逻辑
"""
from apps.engine.models import GobyFingerprint
from .base import BaseFingerprintService
class GobyFingerprintService(BaseFingerprintService):
"""Goby 指纹管理服务(继承基类,实现 Goby 特定逻辑)"""
model = GobyFingerprint
def validate_fingerprint(self, item: dict) -> bool:
"""
校验单条 Goby 指纹
校验规则:
- name 字段必须存在且非空
- logic 字段必须存在
- rule 字段必须是数组
Args:
item: 单条指纹数据
Returns:
bool: 是否有效
"""
name = item.get('name', '')
logic = item.get('logic', '')
rule = item.get('rule')
return bool(name and str(name).strip()) and bool(logic) and isinstance(rule, list)
def to_model_data(self, item: dict) -> dict:
"""
转换 Goby JSON 格式为 Model 字段
Args:
item: 原始 Goby JSON 数据
Returns:
dict: Model 字段数据
"""
return {
'name': str(item.get('name', '')).strip(),
'logic': item.get('logic', ''),
'rule': item.get('rule', []),
}
def get_export_data(self) -> list:
"""
获取导出数据Goby JSON 格式 - 数组)
Returns:
list: Goby 格式的 JSON 数据(数组格式)
[
{"name": "...", "logic": "...", "rule": [...]},
...
]
"""
fingerprints = self.model.objects.all()
return [
{
'name': fp.name,
'logic': fp.logic,
'rule': fp.rule,
}
for fp in fingerprints
]

View File

@@ -0,0 +1,99 @@
"""Wappalyzer 指纹管理 Service
实现 Wappalyzer 格式指纹的校验、转换和导出逻辑
"""
from apps.engine.models import WappalyzerFingerprint
from .base import BaseFingerprintService
class WappalyzerFingerprintService(BaseFingerprintService):
"""Wappalyzer 指纹管理服务(继承基类,实现 Wappalyzer 特定逻辑)"""
model = WappalyzerFingerprint
def validate_fingerprint(self, item: dict) -> bool:
"""
校验单条 Wappalyzer 指纹
校验规则:
- name 字段必须存在且非空(从 apps 对象的 key 传入)
Args:
item: 单条指纹数据
Returns:
bool: 是否有效
"""
name = item.get('name', '')
return bool(name and str(name).strip())
def to_model_data(self, item: dict) -> dict:
"""
转换 Wappalyzer JSON 格式为 Model 字段
字段映射:
- scriptSrc (JSON) → script_src (Model)
Args:
item: 原始 Wappalyzer JSON 数据
Returns:
dict: Model 字段数据
"""
return {
'name': str(item.get('name', '')).strip(),
'cats': item.get('cats', []),
'cookies': item.get('cookies', {}),
'headers': item.get('headers', {}),
'script_src': item.get('scriptSrc', []), # JSON: scriptSrc -> Model: script_src
'js': item.get('js', []),
'implies': item.get('implies', []),
'meta': item.get('meta', {}),
'html': item.get('html', []),
'description': item.get('description', ''),
'website': item.get('website', ''),
'cpe': item.get('cpe', ''),
}
def get_export_data(self) -> dict:
"""
获取导出数据Wappalyzer JSON 格式)
Returns:
dict: Wappalyzer 格式的 JSON 数据
{
"apps": {
"AppName": {"cats": [...], "cookies": {...}, ...},
...
}
}
"""
fingerprints = self.model.objects.all()
apps = {}
for fp in fingerprints:
app_data = {}
if fp.cats:
app_data['cats'] = fp.cats
if fp.cookies:
app_data['cookies'] = fp.cookies
if fp.headers:
app_data['headers'] = fp.headers
if fp.script_src:
app_data['scriptSrc'] = fp.script_src # Model: script_src -> JSON: scriptSrc
if fp.js:
app_data['js'] = fp.js
if fp.implies:
app_data['implies'] = fp.implies
if fp.meta:
app_data['meta'] = fp.meta
if fp.html:
app_data['html'] = fp.html
if fp.description:
app_data['description'] = fp.description
if fp.website:
app_data['website'] = fp.website
if fp.cpe:
app_data['cpe'] = fp.cpe
apps[fp.name] = app_data
return {'apps': apps}

View File

@@ -196,6 +196,9 @@ class NucleiTemplateRepoService:
cmd: List[str]
action: str
# 直接使用原始 URL不再使用 Git 加速)
repo_url = obj.repo_url
# 判断是 clone 还是 pull
if git_dir.is_dir():
# 检查远程地址是否变化
@@ -208,12 +211,13 @@ class NucleiTemplateRepoService:
)
current_url = current_remote.stdout.strip() if current_remote.returncode == 0 else ""
if current_url != obj.repo_url:
# 检查是否需要重新 clone
if current_url != repo_url:
# 远程地址变化,删除旧目录重新 clone
logger.info("nuclei 模板仓库 %s 远程地址变化,重新 clone: %s -> %s", obj.id, current_url, obj.repo_url)
logger.info("nuclei 模板仓库 %s 远程地址变化,重新 clone: %s -> %s", obj.id, current_url, repo_url)
shutil.rmtree(local_path)
local_path.mkdir(parents=True, exist_ok=True)
cmd = ["git", "clone", "--depth", "1", obj.repo_url, str(local_path)]
cmd = ["git", "clone", "--depth", "1", repo_url, str(local_path)]
action = "clone"
else:
# 已有仓库且地址未变,执行 pull
@@ -224,7 +228,7 @@ class NucleiTemplateRepoService:
if local_path.exists() and not local_path.is_dir():
raise RuntimeError(f"本地路径已存在且不是目录: {local_path}")
# --depth 1 浅克隆,只获取最新提交,节省空间和时间
cmd = ["git", "clone", "--depth", "1", obj.repo_url, str(local_path)]
cmd = ["git", "clone", "--depth", "1", repo_url, str(local_path)]
action = "clone"
# 执行 Git 命令

View File

@@ -76,8 +76,8 @@ class TaskDistributor:
self.docker_image = settings.TASK_EXECUTOR_IMAGE
if not self.docker_image:
raise ValueError("TASK_EXECUTOR_IMAGE 未配置,请确保 IMAGE_TAG 环境变量已设置")
self.results_mount = getattr(settings, 'CONTAINER_RESULTS_MOUNT', '/app/backend/results')
self.logs_mount = getattr(settings, 'CONTAINER_LOGS_MOUNT', '/app/backend/logs')
# 统一使用 /opt/xingrin 下的路径
self.logs_mount = "/opt/xingrin/logs"
self.submit_interval = getattr(settings, 'TASK_SUBMIT_INTERVAL', 5)
def get_online_workers(self) -> list[WorkerNode]:
@@ -153,30 +153,68 @@ class TaskDistributor:
else:
scored_workers.append((worker, score, cpu, mem))
# 降级策略:如果没有正常负载的,等待后重新选择
# 降级策略:如果没有正常负载的,循环等待后重新检测
if not scored_workers:
if high_load_workers:
# 高负载时先等待,给系统喘息时间(默认 60 秒)
# 高负载等待参数(默认 60 秒检测一次,最多 10 次
high_load_wait = getattr(settings, 'HIGH_LOAD_WAIT_SECONDS', 60)
logger.warning("所有 Worker 高负载,等待 %d 秒后重试...", high_load_wait)
time.sleep(high_load_wait)
high_load_max_retries = getattr(settings, 'HIGH_LOAD_MAX_RETRIES', 10)
# 重新选择(递归调用,可能负载已降下来)
# 为避免无限递归,这里直接使用高负载中最低的
# 开始等待前发送高负载通知
high_load_workers.sort(key=lambda x: x[1])
best_worker, _, cpu, mem = high_load_workers[0]
# 发送高负载通知
_, _, first_cpu, first_mem = high_load_workers[0]
from apps.common.signals import all_workers_high_load
all_workers_high_load.send(
sender=self.__class__,
worker_name=best_worker.name,
cpu=cpu,
mem=mem
worker_name="所有节点",
cpu=first_cpu,
mem=first_mem
)
logger.info("选择 Worker: %s (CPU: %.1f%%, MEM: %.1f%%)", best_worker.name, cpu, mem)
return best_worker
for retry in range(high_load_max_retries):
logger.warning(
"所有 Worker 高负载,等待 %d 秒后重试... (%d/%d)",
high_load_wait, retry + 1, high_load_max_retries
)
time.sleep(high_load_wait)
# 重新获取负载数据
loads = worker_load_service.get_all_loads(worker_ids)
# 重新评估
scored_workers = []
high_load_workers = []
for worker in workers:
load = loads.get(worker.id)
if not load:
continue
cpu = load.get('cpu', 0)
mem = load.get('mem', 0)
score = cpu * 0.7 + mem * 0.3
if cpu > 85 or mem > 85:
high_load_workers.append((worker, score, cpu, mem))
else:
scored_workers.append((worker, score, cpu, mem))
# 如果有正常负载的 Worker跳出循环
if scored_workers:
logger.info("检测到正常负载 Worker结束等待")
break
# 超时或仍然高负载,选择负载最低的
if not scored_workers and high_load_workers:
high_load_workers.sort(key=lambda x: x[1])
best_worker, _, cpu, mem = high_load_workers[0]
logger.warning(
"等待超时,强制分发到高负载 Worker: %s (CPU: %.1f%%, MEM: %.1f%%)",
best_worker.name, cpu, mem
)
return best_worker
return best_worker
else:
logger.warning("没有可用的 Worker")
return None
@@ -226,6 +264,10 @@ class TaskDistributor:
"""
import shlex
# Docker API 版本配置(避免版本不兼容问题)
# 默认使用 1.40 以获得最大兼容性(支持 Docker 19.03+
api_version = getattr(settings, 'DOCKER_API_VERSION', '1.40')
# 根据 Worker 类型确定网络和 Server 地址
if worker.is_local:
# 本地:加入 Docker 网络,使用内部服务名
@@ -234,11 +276,10 @@ class TaskDistributor:
else:
# 远程:通过 Nginx 反向代理访问HTTPS不直连 8888 端口)
network_arg = ""
server_url = f"https://{settings.PUBLIC_HOST}"
server_url = f"https://{settings.PUBLIC_HOST}:{settings.PUBLIC_PORT}"
# 挂载路径(所有节点统一使用固定路径
host_results_dir = settings.HOST_RESULTS_DIR # /opt/xingrin/results
host_logs_dir = settings.HOST_LOGS_DIR # /opt/xingrin/logs
# 挂载路径(统一挂载 /opt/xingrin扫描工具在 /opt/xingrin-tools/bin 不受影响
host_xingrin_dir = "/opt/xingrin"
# 环境变量SERVER_URL + IS_LOCAL其他配置容器启动时从配置中心获取
# IS_LOCAL 用于 Worker 向配置中心声明身份,决定返回的数据库地址
@@ -251,15 +292,12 @@ class TaskDistributor:
"-e PREFECT_SERVER_EPHEMERAL_ENABLED=true", # 启用 ephemeral server本地临时服务器
"-e PREFECT_SERVER_EPHEMERAL_STARTUP_TIMEOUT_SECONDS=120", # 增加启动超时时间
"-e PREFECT_SERVER_DATABASE_CONNECTION_URL=sqlite+aiosqlite:////tmp/.prefect/prefect.db", # 使用 /tmp 下的 SQLite
"-e PREFECT_LOGGING_LEVEL=DEBUG", # 启用 DEBUG 级别日志
"-e PREFECT_LOGGING_SERVER_LEVEL=DEBUG", # Server 日志级别
"-e PREFECT_DEBUG_MODE=true", # 启用调试模式
"-e PREFECT_LOGGING_LEVEL=WARNING", # 日志级别(减少 DEBUG 噪音)
]
# 挂载卷
# 挂载卷(统一挂载整个 /opt/xingrin 目录)
volumes = [
f"-v {host_results_dir}:{self.results_mount}",
f"-v {host_logs_dir}:{self.logs_mount}",
f"-v {host_xingrin_dir}:{host_xingrin_dir}",
]
# 构建命令行参数
@@ -278,10 +316,11 @@ class TaskDistributor:
# - 远程 Workerdeploy 时已预拉取镜像,直接使用本地版本
# - 避免每次任务都检查 Docker Hub提升性能和稳定性
# 使用双引号包裹 sh -c 命令,内部 shlex.quote 生成的单引号参数可正确解析
cmd = f'''docker run --rm -d --pull=missing {network_arg} \
{' '.join(env_vars)} \
{' '.join(volumes)} \
{self.docker_image} \
# DOCKER_API_VERSION 环境变量确保客户端和服务端 API 版本兼容
cmd = f'''DOCKER_API_VERSION={api_version} docker run --rm -d --pull=missing {network_arg} \\
{' '.join(env_vars)} \\
{' '.join(volumes)} \\
{self.docker_image} \\
sh -c "{inner_cmd}"'''
return cmd
@@ -520,7 +559,7 @@ class TaskDistributor:
try:
# 构建 docker run 命令(清理过期扫描结果目录)
script_args = {
'results_dir': '/app/backend/results',
'results_dir': '/opt/xingrin/results',
'retention_days': retention_days,
}

View File

@@ -7,6 +7,14 @@ from .views import (
WordlistViewSet,
NucleiTemplateRepoViewSet,
)
from .views.fingerprints import (
EholeFingerprintViewSet,
GobyFingerprintViewSet,
WappalyzerFingerprintViewSet,
FingersFingerprintViewSet,
FingerPrintHubFingerprintViewSet,
ARLFingerprintViewSet,
)
# 创建路由器
@@ -15,6 +23,13 @@ router.register(r"engines", ScanEngineViewSet, basename="engine")
router.register(r"workers", WorkerNodeViewSet, basename="worker")
router.register(r"wordlists", WordlistViewSet, basename="wordlist")
router.register(r"nuclei/repos", NucleiTemplateRepoViewSet, basename="nuclei-repos")
# 指纹管理
router.register(r"fingerprints/ehole", EholeFingerprintViewSet, basename="ehole-fingerprint")
router.register(r"fingerprints/goby", GobyFingerprintViewSet, basename="goby-fingerprint")
router.register(r"fingerprints/wappalyzer", WappalyzerFingerprintViewSet, basename="wappalyzer-fingerprint")
router.register(r"fingerprints/fingers", FingersFingerprintViewSet, basename="fingers-fingerprint")
router.register(r"fingerprints/fingerprinthub", FingerPrintHubFingerprintViewSet, basename="fingerprinthub-fingerprint")
router.register(r"fingerprints/arl", ARLFingerprintViewSet, basename="arl-fingerprint")
urlpatterns = [
path("", include(router.urls)),

View File

@@ -0,0 +1,22 @@
"""指纹管理 ViewSets
导出所有指纹相关的 ViewSet 类
"""
from .base import BaseFingerprintViewSet
from .ehole import EholeFingerprintViewSet
from .goby import GobyFingerprintViewSet
from .wappalyzer import WappalyzerFingerprintViewSet
from .fingers import FingersFingerprintViewSet
from .fingerprinthub import FingerPrintHubFingerprintViewSet
from .arl import ARLFingerprintViewSet
__all__ = [
"BaseFingerprintViewSet",
"EholeFingerprintViewSet",
"GobyFingerprintViewSet",
"WappalyzerFingerprintViewSet",
"FingersFingerprintViewSet",
"FingerPrintHubFingerprintViewSet",
"ARLFingerprintViewSet",
]

View File

@@ -0,0 +1,122 @@
"""ARL 指纹管理 ViewSet"""
import yaml
from django.http import HttpResponse
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
from apps.common.pagination import BasePagination
from apps.common.response_helpers import success_response
from apps.engine.models import ARLFingerprint
from apps.engine.serializers.fingerprints import ARLFingerprintSerializer
from apps.engine.services.fingerprints import ARLFingerprintService
from .base import BaseFingerprintViewSet
class ARLFingerprintViewSet(BaseFingerprintViewSet):
"""ARL 指纹管理 ViewSet
继承自 BaseFingerprintViewSet提供以下 API
标准 CRUDModelViewSet
- GET / 列表查询(分页)
- POST / 创建单条
- GET /{id}/ 获取详情
- PUT /{id}/ 更新
- DELETE /{id}/ 删除
批量操作(继承自基类):
- POST /batch_create/ 批量创建JSON body
- POST /import_file/ 文件导入multipart/form-data支持 YAML
- POST /bulk-delete/ 批量删除
- POST /delete-all/ 删除所有
- GET /export/ 导出下载YAML 格式)
智能过滤语法filter 参数):
- name="word" 模糊匹配 name 字段
- name=="WordPress" 精确匹配
- rule="body=" 按规则内容筛选
"""
queryset = ARLFingerprint.objects.all()
serializer_class = ARLFingerprintSerializer
pagination_class = BasePagination
service_class = ARLFingerprintService
# 排序配置
ordering_fields = ['created_at', 'name']
ordering = ['-created_at']
# ARL 过滤字段映射
FILTER_FIELD_MAPPING = {
'name': 'name',
'rule': 'rule',
}
def parse_import_data(self, json_data) -> list:
"""
解析 ARL 格式的导入数据JSON 格式)
输入格式:[{...}, {...}] 数组格式
返回:指纹列表
"""
if isinstance(json_data, list):
return json_data
return []
def get_export_filename(self) -> str:
"""导出文件名"""
return 'ARL.yaml'
@action(detail=False, methods=['post'])
def import_file(self, request):
"""
文件导入(支持 YAML 和 JSON 格式)
POST /api/engine/fingerprints/arl/import_file/
请求格式multipart/form-data
- file: YAML 或 JSON 文件
返回:同 batch_create
"""
file = request.FILES.get('file')
if not file:
raise ValidationError('缺少文件')
filename = file.name.lower()
content = file.read().decode('utf-8')
try:
if filename.endswith('.yaml') or filename.endswith('.yml'):
# YAML 格式
fingerprints = yaml.safe_load(content)
else:
# JSON 格式
import json
fingerprints = json.loads(content)
except (yaml.YAMLError, json.JSONDecodeError) as e:
raise ValidationError(f'无效的文件格式: {e}')
if not isinstance(fingerprints, list):
raise ValidationError('文件内容必须是数组格式')
if not fingerprints:
raise ValidationError('文件中没有有效的指纹数据')
result = self.get_service().batch_create_fingerprints(fingerprints)
return success_response(data=result)
@action(detail=False, methods=['get'])
def export(self, request):
"""
导出指纹YAML 格式)
GET /api/engine/fingerprints/arl/export/
返回YAML 文件下载
"""
data = self.get_service().get_export_data()
content = yaml.dump(data, allow_unicode=True, default_flow_style=False, sort_keys=False)
response = HttpResponse(content, content_type='application/x-yaml')
response['Content-Disposition'] = f'attachment; filename="{self.get_export_filename()}"'
return response

View File

@@ -0,0 +1,203 @@
"""指纹管理基类 ViewSet
提供通用的 CRUD 和批量操作,供 EHole/Goby/Wappalyzer 等子类继承
"""
import json
import logging
from django.http import HttpResponse
from rest_framework import viewsets, status, filters
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.exceptions import ValidationError
from apps.common.pagination import BasePagination
from apps.common.response_helpers import success_response
from apps.common.utils.filter_utils import apply_filters
logger = logging.getLogger(__name__)
class BaseFingerprintViewSet(viewsets.ModelViewSet):
"""指纹管理基类 ViewSet供 EHole/Goby/Wappalyzer 等子类继承
提供的 API
标准 CRUD继承自 ModelViewSet
- GET / 列表查询(分页 + 智能过滤)
- POST / 创建单条
- GET /{id}/ 获取详情
- PUT /{id}/ 更新
- DELETE /{id}/ 删除
批量操作(本类实现):
- POST /batch_create/ 批量创建JSON body
- POST /import_file/ 文件导入multipart/form-data适合 10MB+ 大文件)
- POST /bulk-delete/ 批量删除
- POST /delete-all/ 删除所有
- GET /export/ 导出下载
智能过滤语法filter 参数):
- field="value" 模糊匹配(包含)
- field=="value" 精确匹配
- 多条件空格分隔 AND 关系
- || 或 or OR 关系
子类必须实现:
- service_class Service 类
- parse_import_data 解析导入数据格式
- get_export_filename 导出文件名
"""
pagination_class = BasePagination
filter_backends = [filters.OrderingFilter]
ordering = ['-created_at']
# 子类必须指定
service_class = None # Service 类
# 智能过滤字段映射,子类必须覆盖
FILTER_FIELD_MAPPING = {}
# JSON 数组字段列表(使用 __contains 查询),子类可覆盖
JSON_ARRAY_FIELDS = []
def get_queryset(self):
"""支持智能过滤语法"""
queryset = super().get_queryset()
filter_query = self.request.query_params.get('filter', None)
if filter_query:
queryset = apply_filters(
queryset,
filter_query,
self.FILTER_FIELD_MAPPING,
json_array_fields=getattr(self, 'JSON_ARRAY_FIELDS', [])
)
return queryset
def get_service(self):
"""获取 Service 实例"""
if self.service_class is None:
raise NotImplementedError("子类必须指定 service_class")
return self.service_class()
def parse_import_data(self, json_data: dict) -> list:
"""
解析导入数据,子类必须实现
Args:
json_data: 解析后的 JSON 数据
Returns:
list: 指纹数据列表
"""
raise NotImplementedError("子类必须实现 parse_import_data 方法")
def get_export_filename(self) -> str:
"""
导出文件名,子类必须实现
Returns:
str: 文件名
"""
raise NotImplementedError("子类必须实现 get_export_filename 方法")
@action(detail=False, methods=['post'])
def batch_create(self, request):
"""
批量创建指纹规则
POST /api/engine/fingerprints/{type}/batch_create/
请求格式:
{
"fingerprints": [
{"cms": "WordPress", "method": "keyword", ...},
...
]
}
返回:
{
"created": 2,
"failed": 0
}
"""
fingerprints = request.data.get('fingerprints', [])
if not fingerprints:
raise ValidationError('fingerprints 不能为空')
if not isinstance(fingerprints, list):
raise ValidationError('fingerprints 必须是数组')
result = self.get_service().batch_create_fingerprints(fingerprints)
return success_response(data=result, status_code=status.HTTP_201_CREATED)
@action(detail=False, methods=['post'])
def import_file(self, request):
"""
文件导入适合大文件10MB+
POST /api/engine/fingerprints/{type}/import_file/
请求格式multipart/form-data
- file: JSON 文件
返回:同 batch_create
"""
file = request.FILES.get('file')
if not file:
raise ValidationError('缺少文件')
try:
json_data = json.load(file)
except json.JSONDecodeError as e:
raise ValidationError(f'无效的 JSON 格式: {e}')
fingerprints = self.parse_import_data(json_data)
if not fingerprints:
raise ValidationError('文件中没有有效的指纹数据')
result = self.get_service().batch_create_fingerprints(fingerprints)
return success_response(data=result, status_code=status.HTTP_201_CREATED)
@action(detail=False, methods=['post'], url_path='bulk-delete')
def bulk_delete(self, request):
"""
批量删除
POST /api/engine/fingerprints/{type}/bulk-delete/
请求格式:{"ids": [1, 2, 3]}
返回:{"deleted": 3}
"""
ids = request.data.get('ids', [])
if not ids:
raise ValidationError('ids 不能为空')
if not isinstance(ids, list):
raise ValidationError('ids 必须是数组')
deleted_count = self.queryset.model.objects.filter(id__in=ids).delete()[0]
return success_response(data={'deleted': deleted_count})
@action(detail=False, methods=['post'], url_path='delete-all')
def delete_all(self, request):
"""
删除所有指纹
POST /api/engine/fingerprints/{type}/delete-all/
返回:{"deleted": 1000}
"""
deleted_count = self.queryset.model.objects.all().delete()[0]
return success_response(data={'deleted': deleted_count})
@action(detail=False, methods=['get'])
def export(self, request):
"""
导出指纹(前端下载)
GET /api/engine/fingerprints/{type}/export/
返回JSON 文件下载
"""
data = self.get_service().get_export_data()
content = json.dumps(data, ensure_ascii=False, indent=2)
response = HttpResponse(content, content_type='application/json')
response['Content-Disposition'] = f'attachment; filename="{self.get_export_filename()}"'
return response

View File

@@ -0,0 +1,67 @@
"""EHole 指纹管理 ViewSet"""
from apps.common.pagination import BasePagination
from apps.engine.models import EholeFingerprint
from apps.engine.serializers.fingerprints import EholeFingerprintSerializer
from apps.engine.services.fingerprints import EholeFingerprintService
from .base import BaseFingerprintViewSet
class EholeFingerprintViewSet(BaseFingerprintViewSet):
"""EHole 指纹管理 ViewSet
继承自 BaseFingerprintViewSet提供以下 API
标准 CRUDModelViewSet
- GET / 列表查询(分页)
- POST / 创建单条
- GET /{id}/ 获取详情
- PUT /{id}/ 更新
- DELETE /{id}/ 删除
批量操作(继承自基类):
- POST /batch_create/ 批量创建JSON body
- POST /import_file/ 文件导入multipart/form-data
- POST /bulk-delete/ 批量删除
- POST /delete-all/ 删除所有
- GET /export/ 导出下载
智能过滤语法filter 参数):
- cms="word" 模糊匹配 cms 字段
- cms=="WordPress" 精确匹配
- type="CMS" 按类型筛选
- method="keyword" 按匹配方式筛选
- location="body" 按匹配位置筛选
"""
queryset = EholeFingerprint.objects.all()
serializer_class = EholeFingerprintSerializer
pagination_class = BasePagination
service_class = EholeFingerprintService
# 排序配置
ordering_fields = ['created_at', 'cms']
ordering = ['-created_at']
# EHole 过滤字段映射
FILTER_FIELD_MAPPING = {
'cms': 'cms',
'method': 'method',
'location': 'location',
'type': 'type',
'isImportant': 'is_important',
}
def parse_import_data(self, json_data: dict) -> list:
"""
解析 EHole JSON 格式的导入数据
输入格式:{"fingerprint": [...]}
返回:指纹列表
"""
return json_data.get('fingerprint', [])
def get_export_filename(self) -> str:
"""导出文件名"""
return 'ehole.json'

View File

@@ -0,0 +1,73 @@
"""FingerPrintHub 指纹管理 ViewSet"""
from apps.common.pagination import BasePagination
from apps.engine.models import FingerPrintHubFingerprint
from apps.engine.serializers.fingerprints import FingerPrintHubFingerprintSerializer
from apps.engine.services.fingerprints import FingerPrintHubFingerprintService
from .base import BaseFingerprintViewSet
class FingerPrintHubFingerprintViewSet(BaseFingerprintViewSet):
"""FingerPrintHub 指纹管理 ViewSet
继承自 BaseFingerprintViewSet提供以下 API
标准 CRUDModelViewSet
- GET / 列表查询(分页)
- POST / 创建单条
- GET /{id}/ 获取详情
- PUT /{id}/ 更新
- DELETE /{id}/ 删除
批量操作(继承自基类):
- POST /batch_create/ 批量创建JSON body
- POST /import_file/ 文件导入multipart/form-data
- POST /bulk-delete/ 批量删除
- POST /delete-all/ 删除所有
- GET /export/ 导出下载
智能过滤语法filter 参数):
- name="word" 模糊匹配 name 字段
- fp_id=="xxx" 精确匹配指纹ID
- author="xxx" 按作者筛选
- severity="info" 按严重程度筛选
- tags="cms" 按标签筛选
"""
queryset = FingerPrintHubFingerprint.objects.all()
serializer_class = FingerPrintHubFingerprintSerializer
pagination_class = BasePagination
service_class = FingerPrintHubFingerprintService
# 排序配置
ordering_fields = ['created_at', 'name', 'severity']
ordering = ['-created_at']
# FingerPrintHub 过滤字段映射
FILTER_FIELD_MAPPING = {
'fp_id': 'fp_id',
'name': 'name',
'author': 'author',
'tags': 'tags',
'severity': 'severity',
'source_file': 'source_file',
}
# JSON 数组字段(使用 __contains 查询)
JSON_ARRAY_FIELDS = ['http']
def parse_import_data(self, json_data) -> list:
"""
解析 FingerPrintHub JSON 格式的导入数据
输入格式:[{...}, {...}] 数组格式
返回:指纹列表
"""
if isinstance(json_data, list):
return json_data
return []
def get_export_filename(self) -> str:
"""导出文件名"""
return 'fingerprinthub_web.json'

View File

@@ -0,0 +1,69 @@
"""Fingers 指纹管理 ViewSet"""
from apps.common.pagination import BasePagination
from apps.engine.models import FingersFingerprint
from apps.engine.serializers.fingerprints import FingersFingerprintSerializer
from apps.engine.services.fingerprints import FingersFingerprintService
from .base import BaseFingerprintViewSet
class FingersFingerprintViewSet(BaseFingerprintViewSet):
"""Fingers 指纹管理 ViewSet
继承自 BaseFingerprintViewSet提供以下 API
标准 CRUDModelViewSet
- GET / 列表查询(分页)
- POST / 创建单条
- GET /{id}/ 获取详情
- PUT /{id}/ 更新
- DELETE /{id}/ 删除
批量操作(继承自基类):
- POST /batch_create/ 批量创建JSON body
- POST /import_file/ 文件导入multipart/form-data
- POST /bulk-delete/ 批量删除
- POST /delete-all/ 删除所有
- GET /export/ 导出下载
智能过滤语法filter 参数):
- name="word" 模糊匹配 name 字段
- name=="WordPress" 精确匹配
- tag="cms" 按标签筛选
- focus="true" 按重点关注筛选
"""
queryset = FingersFingerprint.objects.all()
serializer_class = FingersFingerprintSerializer
pagination_class = BasePagination
service_class = FingersFingerprintService
# 排序配置
ordering_fields = ['created_at', 'name']
ordering = ['-created_at']
# Fingers 过滤字段映射
FILTER_FIELD_MAPPING = {
'name': 'name',
'link': 'link',
'focus': 'focus',
}
# JSON 数组字段(使用 __contains 查询)
JSON_ARRAY_FIELDS = ['tag', 'rule', 'default_port']
def parse_import_data(self, json_data) -> list:
"""
解析 Fingers JSON 格式的导入数据
输入格式:[{...}, {...}] 数组格式
返回:指纹列表
"""
if isinstance(json_data, list):
return json_data
return []
def get_export_filename(self) -> str:
"""导出文件名"""
return 'fingers_http.json'

View File

@@ -0,0 +1,65 @@
"""Goby 指纹管理 ViewSet"""
from apps.common.pagination import BasePagination
from apps.engine.models import GobyFingerprint
from apps.engine.serializers.fingerprints import GobyFingerprintSerializer
from apps.engine.services.fingerprints import GobyFingerprintService
from .base import BaseFingerprintViewSet
class GobyFingerprintViewSet(BaseFingerprintViewSet):
"""Goby 指纹管理 ViewSet
继承自 BaseFingerprintViewSet提供以下 API
标准 CRUDModelViewSet
- GET / 列表查询(分页)
- POST / 创建单条
- GET /{id}/ 获取详情
- PUT /{id}/ 更新
- DELETE /{id}/ 删除
批量操作(继承自基类):
- POST /batch_create/ 批量创建JSON body
- POST /import_file/ 文件导入multipart/form-data
- POST /bulk-delete/ 批量删除
- POST /delete-all/ 删除所有
- GET /export/ 导出下载
智能过滤语法filter 参数):
- name="word" 模糊匹配 name 字段
- name=="ProductName" 精确匹配
"""
queryset = GobyFingerprint.objects.all()
serializer_class = GobyFingerprintSerializer
pagination_class = BasePagination
service_class = GobyFingerprintService
# 排序配置
ordering_fields = ['created_at', 'name']
ordering = ['-created_at']
# Goby 过滤字段映射
FILTER_FIELD_MAPPING = {
'name': 'name',
'logic': 'logic',
}
def parse_import_data(self, json_data) -> list:
"""
解析 Goby JSON 格式的导入数据
Goby 格式是数组格式:[{...}, {...}, ...]
输入格式:[{"name": "...", "logic": "...", "rule": [...]}, ...]
返回:指纹列表
"""
if isinstance(json_data, list):
return json_data
return []
def get_export_filename(self) -> str:
"""导出文件名"""
return 'goby.json'

View File

@@ -0,0 +1,75 @@
"""Wappalyzer 指纹管理 ViewSet"""
from apps.common.pagination import BasePagination
from apps.engine.models import WappalyzerFingerprint
from apps.engine.serializers.fingerprints import WappalyzerFingerprintSerializer
from apps.engine.services.fingerprints import WappalyzerFingerprintService
from .base import BaseFingerprintViewSet
class WappalyzerFingerprintViewSet(BaseFingerprintViewSet):
"""Wappalyzer 指纹管理 ViewSet
继承自 BaseFingerprintViewSet提供以下 API
标准 CRUDModelViewSet
- GET / 列表查询(分页)
- POST / 创建单条
- GET /{id}/ 获取详情
- PUT /{id}/ 更新
- DELETE /{id}/ 删除
批量操作(继承自基类):
- POST /batch_create/ 批量创建JSON body
- POST /import_file/ 文件导入multipart/form-data
- POST /bulk-delete/ 批量删除
- POST /delete-all/ 删除所有
- GET /export/ 导出下载
智能过滤语法filter 参数):
- name="word" 模糊匹配 name 字段
- name=="AppName" 精确匹配
"""
queryset = WappalyzerFingerprint.objects.all()
serializer_class = WappalyzerFingerprintSerializer
pagination_class = BasePagination
service_class = WappalyzerFingerprintService
# 排序配置
ordering_fields = ['created_at', 'name']
ordering = ['-created_at']
# Wappalyzer 过滤字段映射
# 注意implies 是 JSON 数组字段,使用 __contains 查询
FILTER_FIELD_MAPPING = {
'name': 'name',
'description': 'description',
'website': 'website',
'cpe': 'cpe',
'implies': 'implies', # JSON 数组字段
}
# JSON 数组字段列表(使用 __contains 查询)
JSON_ARRAY_FIELDS = ['implies']
def parse_import_data(self, json_data: dict) -> list:
"""
解析 Wappalyzer JSON 格式的导入数据
Wappalyzer 格式是 apps 对象格式:{"apps": {"AppName": {...}, ...}}
输入格式:{"apps": {"1C-Bitrix": {"cats": [...], ...}, ...}}
返回:指纹列表(每个 app 转换为带 name 字段的 dict
"""
apps = json_data.get('apps', {})
fingerprints = []
for name, data in apps.items():
item = {'name': name, **data}
fingerprints.append(item)
return fingerprints
def get_export_filename(self) -> str:
"""导出文件名"""
return 'wappalyzer.json'

View File

@@ -31,6 +31,8 @@ from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.response import Response
from apps.common.response_helpers import success_response, error_response
from apps.common.error_codes import ErrorCodes
from apps.engine.models import NucleiTemplateRepo
from apps.engine.serializers import NucleiTemplateRepoSerializer
from apps.engine.services import NucleiTemplateRepoService
@@ -107,18 +109,30 @@ class NucleiTemplateRepoViewSet(viewsets.ModelViewSet):
try:
repo_id = int(pk) if pk is not None else None
except (TypeError, ValueError):
return Response({"message": "无效的仓库 ID"}, status=status.HTTP_400_BAD_REQUEST)
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Invalid repository ID',
status_code=status.HTTP_400_BAD_REQUEST
)
# 调用 Service 层
try:
result = self.service.refresh_repo(repo_id)
except ValidationError as exc:
return Response({"message": str(exc)}, status=status.HTTP_400_BAD_REQUEST)
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message=str(exc),
status_code=status.HTTP_400_BAD_REQUEST
)
except Exception as exc: # noqa: BLE001
logger.error("刷新 Nuclei 模板仓库失败: %s", exc, exc_info=True)
return Response({"message": f"刷新仓库失败: {exc}"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
return error_response(
code=ErrorCodes.SERVER_ERROR,
message=f'Refresh failed: {exc}',
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
)
return Response({"message": "刷新成功", "result": result}, status=status.HTTP_200_OK)
return success_response(data={'result': result})
# ==================== 自定义 Action: 模板只读浏览 ====================
@@ -142,18 +156,30 @@ class NucleiTemplateRepoViewSet(viewsets.ModelViewSet):
try:
repo_id = int(pk) if pk is not None else None
except (TypeError, ValueError):
return Response({"message": "无效的仓库 ID"}, status=status.HTTP_400_BAD_REQUEST)
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Invalid repository ID',
status_code=status.HTTP_400_BAD_REQUEST
)
# 调用 Service 层,仅从当前本地目录读取目录树
try:
roots = self.service.get_template_tree(repo_id)
except ValidationError as exc:
return Response({"message": str(exc)}, status=status.HTTP_400_BAD_REQUEST)
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message=str(exc),
status_code=status.HTTP_400_BAD_REQUEST
)
except Exception as exc: # noqa: BLE001
logger.error("获取 Nuclei 模板目录树失败: %s", exc, exc_info=True)
return Response({"message": "获取模板目录树失败"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
return error_response(
code=ErrorCodes.SERVER_ERROR,
message='Failed to get template tree',
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
)
return Response({"roots": roots})
return success_response(data={'roots': roots})
@action(detail=True, methods=["get"], url_path="templates/content")
def templates_content(self, request: Request, pk: str | None = None) -> Response:
@@ -174,23 +200,43 @@ class NucleiTemplateRepoViewSet(viewsets.ModelViewSet):
try:
repo_id = int(pk) if pk is not None else None
except (TypeError, ValueError):
return Response({"message": "无效的仓库 ID"}, status=status.HTTP_400_BAD_REQUEST)
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Invalid repository ID',
status_code=status.HTTP_400_BAD_REQUEST
)
# 解析 path 参数
rel_path = (request.query_params.get("path", "") or "").strip()
if not rel_path:
return Response({"message": "缺少 path 参数"}, status=status.HTTP_400_BAD_REQUEST)
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Missing path parameter',
status_code=status.HTTP_400_BAD_REQUEST
)
# 调用 Service 层
try:
result = self.service.get_template_content(repo_id, rel_path)
except ValidationError as exc:
return Response({"message": str(exc)}, status=status.HTTP_400_BAD_REQUEST)
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message=str(exc),
status_code=status.HTTP_400_BAD_REQUEST
)
except Exception as exc: # noqa: BLE001
logger.error("获取 Nuclei 模板内容失败: %s", exc, exc_info=True)
return Response({"message": "获取模板内容失败"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
return error_response(
code=ErrorCodes.SERVER_ERROR,
message='Failed to get template content',
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
)
# 文件不存在
if result is None:
return Response({"message": "模板不存在或无法读取"}, status=status.HTTP_404_NOT_FOUND)
return Response(result)
return error_response(
code=ErrorCodes.NOT_FOUND,
message='Template not found or unreadable',
status_code=status.HTTP_404_NOT_FOUND
)
return success_response(data=result)

View File

@@ -9,6 +9,8 @@ from rest_framework.decorators import action
from rest_framework.response import Response
from apps.common.pagination import BasePagination
from apps.common.response_helpers import success_response, error_response
from apps.common.error_codes import ErrorCodes
from apps.engine.serializers.wordlist_serializer import WordlistSerializer
from apps.engine.services.wordlist_service import WordlistService
@@ -46,7 +48,11 @@ class WordlistViewSet(viewsets.ViewSet):
uploaded_file = request.FILES.get("file")
if not uploaded_file:
return Response({"error": "缺少字典文件"}, status=status.HTTP_400_BAD_REQUEST)
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Missing wordlist file',
status_code=status.HTTP_400_BAD_REQUEST
)
try:
wordlist = self.service.create_wordlist(
@@ -55,21 +61,32 @@ class WordlistViewSet(viewsets.ViewSet):
uploaded_file=uploaded_file,
)
except ValidationError as exc:
return Response({"error": str(exc)}, status=status.HTTP_400_BAD_REQUEST)
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message=str(exc),
status_code=status.HTTP_400_BAD_REQUEST
)
serializer = WordlistSerializer(wordlist)
return Response(serializer.data, status=status.HTTP_201_CREATED)
return success_response(data=serializer.data, status_code=status.HTTP_201_CREATED)
def destroy(self, request, pk=None):
"""删除字典记录"""
try:
wordlist_id = int(pk)
except (TypeError, ValueError):
return Response({"error": "无效的 ID"}, status=status.HTTP_400_BAD_REQUEST)
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Invalid ID',
status_code=status.HTTP_400_BAD_REQUEST
)
success = self.service.delete_wordlist(wordlist_id)
if not success:
return Response({"error": "字典不存在"}, status=status.HTTP_404_NOT_FOUND)
return error_response(
code=ErrorCodes.NOT_FOUND,
status_code=status.HTTP_404_NOT_FOUND
)
return Response(status=status.HTTP_204_NO_CONTENT)
@@ -82,15 +99,27 @@ class WordlistViewSet(viewsets.ViewSet):
"""
name = (request.query_params.get("wordlist", "") or "").strip()
if not name:
return Response({"error": "缺少参数 wordlist"}, status=status.HTTP_400_BAD_REQUEST)
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Missing parameter: wordlist',
status_code=status.HTTP_400_BAD_REQUEST
)
wordlist = self.service.get_wordlist_by_name(name)
if not wordlist:
return Response({"error": "字典不存在"}, status=status.HTTP_404_NOT_FOUND)
return error_response(
code=ErrorCodes.NOT_FOUND,
message='Wordlist not found',
status_code=status.HTTP_404_NOT_FOUND
)
file_path = wordlist.file_path
if not file_path or not os.path.exists(file_path):
return Response({"error": "字典文件不存在"}, status=status.HTTP_404_NOT_FOUND)
return error_response(
code=ErrorCodes.NOT_FOUND,
message='Wordlist file not found',
status_code=status.HTTP_404_NOT_FOUND
)
filename = os.path.basename(file_path)
response = FileResponse(open(file_path, "rb"), as_attachment=True, filename=filename)
@@ -106,22 +135,38 @@ class WordlistViewSet(viewsets.ViewSet):
try:
wordlist_id = int(pk)
except (TypeError, ValueError):
return Response({"error": "无效的 ID"}, status=status.HTTP_400_BAD_REQUEST)
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Invalid ID',
status_code=status.HTTP_400_BAD_REQUEST
)
if request.method == "GET":
content = self.service.get_wordlist_content(wordlist_id)
if content is None:
return Response({"error": "字典不存在或文件无法读取"}, status=status.HTTP_404_NOT_FOUND)
return Response({"content": content})
return error_response(
code=ErrorCodes.NOT_FOUND,
message='Wordlist not found or file unreadable',
status_code=status.HTTP_404_NOT_FOUND
)
return success_response(data={"content": content})
elif request.method == "PUT":
content = request.data.get("content")
if content is None:
return Response({"error": "缺少 content 参数"}, status=status.HTTP_400_BAD_REQUEST)
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Missing content parameter',
status_code=status.HTTP_400_BAD_REQUEST
)
wordlist = self.service.update_wordlist_content(wordlist_id, content)
if not wordlist:
return Response({"error": "字典不存在或更新失败"}, status=status.HTTP_404_NOT_FOUND)
return error_response(
code=ErrorCodes.NOT_FOUND,
message='Wordlist not found or update failed',
status_code=status.HTTP_404_NOT_FOUND
)
serializer = WordlistSerializer(wordlist)
return Response(serializer.data)
return success_response(data=serializer.data)

View File

@@ -9,6 +9,8 @@ from rest_framework import viewsets, status
from rest_framework.decorators import action
from rest_framework.response import Response
from apps.common.response_helpers import success_response, error_response
from apps.common.error_codes import ErrorCodes
from apps.engine.serializers import WorkerNodeSerializer
from apps.engine.services import WorkerService
from apps.common.signals import worker_delete_failed
@@ -111,9 +113,8 @@ class WorkerNodeViewSet(viewsets.ModelViewSet):
threading.Thread(target=_async_remote_uninstall, daemon=True).start()
# 3. 立即返回成功
return Response(
{"message": f"节点 {worker_name} 已删除"},
status=status.HTTP_200_OK
return success_response(
data={'name': worker_name}
)
@action(detail=True, methods=['post'])
@@ -190,11 +191,13 @@ class WorkerNodeViewSet(viewsets.ModelViewSet):
worker.status = 'online'
worker.save(update_fields=['status'])
return Response({
'status': 'ok',
'need_update': need_update,
'server_version': server_version
})
return success_response(
data={
'status': 'ok',
'needUpdate': need_update,
'serverVersion': server_version
}
)
def _trigger_remote_agent_update(self, worker, target_version: str):
"""
@@ -238,7 +241,7 @@ class WorkerNodeViewSet(viewsets.ModelViewSet):
docker run -d --pull=always \
--name xingrin-agent \
--restart always \
-e HEARTBEAT_API_URL="https://{django_settings.PUBLIC_HOST}" \
-e HEARTBEAT_API_URL="https://{django_settings.PUBLIC_HOST}:{getattr(django_settings, 'PUBLIC_PORT', '8083')}" \
-e WORKER_ID="{worker_id}" \
-e IMAGE_TAG="{target_version}" \
-v /proc:/host/proc:ro \
@@ -304,9 +307,10 @@ class WorkerNodeViewSet(viewsets.ModelViewSet):
is_local = request.data.get('is_local', True)
if not name:
return Response(
{'error': '缺少 name 参数'},
status=status.HTTP_400_BAD_REQUEST
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message='Missing name parameter',
status_code=status.HTTP_400_BAD_REQUEST
)
worker, created = self.worker_service.register_worker(
@@ -314,11 +318,13 @@ class WorkerNodeViewSet(viewsets.ModelViewSet):
is_local=is_local
)
return Response({
'worker_id': worker.id,
'name': worker.name,
'created': created
})
return success_response(
data={
'workerId': worker.id,
'name': worker.name,
'created': created
}
)
@action(detail=False, methods=['get'])
def config(self, request):
@@ -380,22 +386,24 @@ class WorkerNodeViewSet(viewsets.ModelViewSet):
logger.info("返回 Worker 配置 - db_host: %s, redis_url: %s", worker_db_host, worker_redis_url)
return Response({
'db': {
'host': worker_db_host,
'port': str(settings.DATABASES['default']['PORT']),
'name': settings.DATABASES['default']['NAME'],
'user': settings.DATABASES['default']['USER'],
'password': settings.DATABASES['default']['PASSWORD'],
},
'redisUrl': worker_redis_url,
'paths': {
'results': getattr(settings, 'CONTAINER_RESULTS_MOUNT', '/app/backend/results'),
'logs': getattr(settings, 'CONTAINER_LOGS_MOUNT', '/app/backend/logs'),
},
'logging': {
'level': os.getenv('LOG_LEVEL', 'INFO'),
'enableCommandLogging': os.getenv('ENABLE_COMMAND_LOGGING', 'true').lower() == 'true',
},
'debug': settings.DEBUG
})
return success_response(
data={
'db': {
'host': worker_db_host,
'port': str(settings.DATABASES['default']['PORT']),
'name': settings.DATABASES['default']['NAME'],
'user': settings.DATABASES['default']['USER'],
'password': settings.DATABASES['default']['PASSWORD'],
},
'redisUrl': worker_redis_url,
'paths': {
'results': getattr(settings, 'CONTAINER_RESULTS_MOUNT', '/opt/xingrin/results'),
'logs': getattr(settings, 'CONTAINER_LOGS_MOUNT', '/opt/xingrin/logs'),
},
'logging': {
'level': os.getenv('LOG_LEVEL', 'INFO'),
'enableCommandLogging': os.getenv('ENABLE_COMMAND_LOGGING', 'true').lower() == 'true',
},
'debug': settings.DEBUG,
}
)

View File

@@ -7,7 +7,7 @@
from django.conf import settings
# ==================== 路径配置 ====================
SCAN_TOOLS_BASE_PATH = getattr(settings, 'SCAN_TOOLS_BASE_PATH', '/opt/xingrin/tools')
SCAN_TOOLS_BASE_PATH = getattr(settings, 'SCAN_TOOLS_BASE_PATH', '/usr/local/bin')
# ==================== 子域名发现 ====================
@@ -35,7 +35,7 @@ SUBDOMAIN_DISCOVERY_COMMANDS = {
},
'sublist3r': {
'base': "python3 '{scan_tools_base}/Sublist3r/sublist3r.py' -d {domain} -o '{output_file}'",
'base': "python3 '/usr/local/share/Sublist3r/sublist3r.py' -d {domain} -o '{output_file}'",
'optional': {
'threads': '-t {threads}'
}
@@ -115,7 +115,7 @@ SITE_SCAN_COMMANDS = {
DIRECTORY_SCAN_COMMANDS = {
'ffuf': {
'base': "ffuf -u '{url}FUZZ' -se -ac -sf -json -w '{wordlist}'",
'base': "'{scan_tools_base}/ffuf' -u '{url}FUZZ' -se -ac -sf -json -w '{wordlist}'",
'optional': {
'delay': '-p {delay}',
'threads': '-t {threads}',
@@ -225,12 +225,35 @@ VULN_SCAN_COMMANDS = {
}
# ==================== 指纹识别 ====================
FINGERPRINT_DETECT_COMMANDS = {
'xingfinger': {
# 流式输出模式(不使用 -o输出到 stdout
# -l: URL 列表文件输入
# -s: 静默模式,只输出命中结果
# --json: JSON 格式输出(每行一条)
'base': "xingfinger -l '{urls_file}' -s --json",
'optional': {
# 自定义指纹库路径
'ehole': '--ehole {ehole}',
'goby': '--goby {goby}',
'wappalyzer': '--wappalyzer {wappalyzer}',
'fingers': '--fingers {fingers}',
'fingerprinthub': '--fingerprint {fingerprinthub}',
'arl': '--arl {arl}',
}
},
}
# ==================== 工具映射 ====================
COMMAND_TEMPLATES = {
'subdomain_discovery': SUBDOMAIN_DISCOVERY_COMMANDS,
'port_scan': PORT_SCAN_COMMANDS,
'site_scan': SITE_SCAN_COMMANDS,
'fingerprint_detect': FINGERPRINT_DETECT_COMMANDS,
'directory_scan': DIRECTORY_SCAN_COMMANDS,
'url_fetch': URL_FETCH_COMMANDS,
'vuln_scan': VULN_SCAN_COMMANDS,
@@ -242,7 +265,7 @@ COMMAND_TEMPLATES = {
EXECUTION_STAGES = [
{
'mode': 'sequential',
'flows': ['subdomain_discovery', 'port_scan', 'site_scan']
'flows': ['subdomain_discovery', 'port_scan', 'site_scan', 'fingerprint_detect']
},
{
'mode': 'parallel',

View File

@@ -1,7 +1,8 @@
# 引擎配置
#
# 参数命名:统一用中划线(如 rate-limit系统自动转换为下划线
# 必需参数enabled是否启用、timeout超时秒数auto 表示自动计算)
# 必需参数enabled是否启用
# 可选参数timeout超时秒数默认 auto 自动计算)
# ==================== 子域名发现 ====================
#
@@ -39,7 +40,7 @@ subdomain_discovery:
bruteforce:
enabled: false
subdomain_bruteforce:
timeout: auto # 自动根据字典行数计算
# timeout: auto # 自动根据字典行数计算
wordlist-name: subdomains-top1million-110000.txt # 对应「字典管理」中的 Wordlist.name
# === Stage 3: 变异生成 + 存活验证(可选)===
@@ -52,14 +53,14 @@ subdomain_discovery:
resolve:
enabled: true
subdomain_resolve:
timeout: auto # 自动根据候选子域数量计算
timeout: auto # 自动根据候选子域数量计算
# ==================== 端口扫描 ====================
port_scan:
tools:
naabu_active:
enabled: true
timeout: auto # 自动计算(目标数 × 端口数 × 0.5秒),范围 60秒 ~ 2天
# timeout: auto # 自动计算(目标数 × 端口数 × 0.5秒),范围 60秒 ~ 2天
threads: 200 # 并发连接数(默认 5
# ports: 1-65535 # 扫描端口范围(默认 1-65535
top-ports: 100 # 扫描 nmap top 100 端口
@@ -67,25 +68,33 @@ port_scan:
naabu_passive:
enabled: true
timeout: auto # 被动扫描通常较快
# timeout: auto # 被动扫描通常较快
# ==================== 站点扫描 ====================
site_scan:
tools:
httpx:
enabled: true
timeout: auto # 自动计算(每个 URL 约 1 秒)
# timeout: auto # 自动计算(每个 URL 约 1 秒)
# threads: 50 # 并发线程数(默认 50
# rate-limit: 150 # 每秒请求数(默认 150
# request-timeout: 10 # 单个请求超时秒数(默认 10
# retries: 2 # 请求失败重试次数
# ==================== 指纹识别 ====================
# 在 site_scan 后串行执行,识别 WebSite 的技术栈
fingerprint_detect:
tools:
xingfinger:
enabled: true
fingerprint-libs: [ehole, goby, wappalyzer] # 启用的指纹库ehole, goby, wappalyzer, fingers, fingerprinthub, arl
# ==================== 目录扫描 ====================
directory_scan:
tools:
ffuf:
enabled: true
timeout: auto # 自动计算(字典行数 × 0.02秒),范围 60秒 ~ 2小时
# timeout: auto # 自动计算(字典行数 × 0.02秒),范围 60秒 ~ 2小时
max-workers: 5 # 并发扫描站点数(默认 5
wordlist-name: dir_default.txt # 对应「字典管理」中的 Wordlist.name
delay: 0.1-2.0 # 请求间隔,支持范围随机(如 "0.1-2.0"
@@ -103,7 +112,7 @@ url_fetch:
katana:
enabled: true
timeout: auto # 自动计算(根据站点数量)
# timeout: auto # 自动计算(根据站点数量)
depth: 5 # 爬取最大深度(默认 3
threads: 10 # 全局并发数
rate-limit: 30 # 每秒最多请求数
@@ -113,7 +122,7 @@ url_fetch:
uro:
enabled: true
timeout: auto # 自动计算(每 100 个 URL 约 1 秒),范围 30 ~ 300 秒
# timeout: auto # 自动计算(每 100 个 URL 约 1 秒),范围 30 ~ 300 秒
# whitelist: # 只保留指定扩展名
# - php
# - asp
@@ -127,7 +136,7 @@ url_fetch:
httpx:
enabled: true
timeout: auto # 自动计算(每个 URL 约 1 秒)
# timeout: auto # 自动计算(每个 URL 约 1 秒)
# threads: 50 # 并发线程数(默认 50
# rate-limit: 150 # 每秒请求数(默认 150
# request-timeout: 10 # 单个请求超时秒数(默认 10
@@ -138,7 +147,7 @@ vuln_scan:
tools:
dalfox_xss:
enabled: true
timeout: auto # 自动计算endpoints 行数 × 100 秒)
# timeout: auto # 自动计算endpoints 行数 × 100 秒)
request-timeout: 10 # 单个请求超时秒数
only-poc: r # 只输出 POC 结果r: 反射型)
ignore-return: "302,404,403" # 忽略的返回码
@@ -149,7 +158,7 @@ vuln_scan:
nuclei:
enabled: true
timeout: auto # 自动计算(根据 endpoints 行数)
# timeout: auto # 自动计算(根据 endpoints 行数)
template-repo-names: # 模板仓库列表对应「Nuclei 模板」中的仓库名
- nuclei-templates
# - nuclei-custom # 可追加自定义仓库

View File

@@ -5,8 +5,10 @@
from .initiate_scan_flow import initiate_scan_flow
from .subdomain_discovery_flow import subdomain_discovery_flow
from .fingerprint_detect_flow import fingerprint_detect_flow
__all__ = [
'initiate_scan_flow',
'subdomain_discovery_flow',
'fingerprint_detect_flow',
]

View File

@@ -140,28 +140,7 @@ def _get_max_workers(tool_config: dict, default: int = DEFAULT_MAX_WORKERS) -> i
return default
def _setup_directory_scan_directory(scan_workspace_dir: str) -> Path:
"""
创建并验证目录扫描工作目录
Args:
scan_workspace_dir: 扫描工作空间目录
Returns:
Path: 目录扫描目录路径
Raises:
RuntimeError: 目录创建或验证失败
"""
directory_scan_dir = Path(scan_workspace_dir) / 'directory_scan'
directory_scan_dir.mkdir(parents=True, exist_ok=True)
if not directory_scan_dir.is_dir():
raise RuntimeError(f"目录扫描目录创建失败: {directory_scan_dir}")
if not os.access(directory_scan_dir, os.W_OK):
raise RuntimeError(f"目录扫描目录不可写: {directory_scan_dir}")
return directory_scan_dir
def _export_site_urls(target_id: int, target_name: str, directory_scan_dir: Path) -> tuple[str, int]:
@@ -185,8 +164,7 @@ def _export_site_urls(target_id: int, target_name: str, directory_scan_dir: Path
export_result = export_sites_task(
target_id=target_id,
output_file=sites_file,
batch_size=1000, # 每次读取 1000 条,优化内存占用
target_name=target_name # 传入 target_name 用于懒加载
batch_size=1000 # 每次读取 1000 条,优化内存占用
)
site_count = export_result['total_count']
@@ -483,13 +461,23 @@ def _run_scans_concurrently(
logger.warning("没有有效的扫描任务")
continue
# 使用 ThreadPoolTaskRunner 并发执行
logger.info("开始并发提交 %d 个扫描任务...", len(scan_params_list))
# ============================================================
# 分批执行策略:控制实际并发的 ffuf 进程数
# ============================================================
total_tasks = len(scan_params_list)
logger.info("开始分批执行 %d 个扫描任务(每批 %d 个)...", total_tasks, max_workers)
with ThreadPoolTaskRunner(max_workers=max_workers) as task_runner:
# 提交所有任务
batch_num = 0
for batch_start in range(0, total_tasks, max_workers):
batch_end = min(batch_start + max_workers, total_tasks)
batch_params = scan_params_list[batch_start:batch_end]
batch_num += 1
logger.info("执行第 %d 批任务(%d-%d/%d...", batch_num, batch_start + 1, batch_end, total_tasks)
# 提交当前批次的任务(非阻塞,立即返回 future
futures = []
for params in scan_params_list:
for params in batch_params:
future = run_and_stream_save_directories_task.submit(
cmd=params['command'],
tool_name=tool_name,
@@ -504,12 +492,10 @@ def _run_scans_concurrently(
)
futures.append((params['idx'], params['site_url'], future))
logger.info("✓ 已提交 %d 个扫描任务,等待完成...", len(futures))
# 等待所有任务完成并聚合结果
# 等待当前批次所有任务完成(阻塞,确保本批完成后再启动下一批)
for idx, site_url, future in futures:
try:
result = future.result()
result = future.result() # 阻塞等待单个任务完成
directories_found = result.get('created_directories', 0)
total_directories += directories_found
processed_sites_count += 1
@@ -521,7 +507,6 @@ def _run_scans_concurrently(
except Exception as exc:
failed_sites.append(site_url)
# 判断是否为超时异常
if 'timeout' in str(exc).lower() or isinstance(exc, subprocess.TimeoutExpired):
logger.warning(
"⚠️ [%d/%d] 站点扫描超时: %s - 错误: %s",
@@ -633,7 +618,8 @@ def directory_scan_flow(
raise ValueError("enabled_tools 不能为空")
# Step 0: 创建工作目录
directory_scan_dir = _setup_directory_scan_directory(scan_workspace_dir)
from apps.scan.utils import setup_scan_directory
directory_scan_dir = setup_scan_directory(scan_workspace_dir, 'directory_scan')
# Step 1: 导出站点 URL支持懒加载
sites_file, site_count = _export_site_urls(target_id, target_name, directory_scan_dir)

View File

@@ -0,0 +1,380 @@
"""
指纹识别 Flow
负责编排指纹识别的完整流程
架构:
- Flow 负责编排多个原子 Task
- 在 site_scan 后串行执行
- 使用 xingfinger 工具识别技术栈
- 流式处理输出,批量更新数据库
"""
# Django 环境初始化(导入即生效)
from apps.common.prefect_django_setup import setup_django_for_prefect
import logging
import os
from datetime import datetime
from pathlib import Path
from prefect import flow
from apps.scan.handlers.scan_flow_handlers import (
on_scan_flow_running,
on_scan_flow_completed,
on_scan_flow_failed,
)
from apps.scan.tasks.fingerprint_detect import (
export_urls_for_fingerprint_task,
run_xingfinger_and_stream_update_tech_task,
)
from apps.scan.utils import build_scan_command
from apps.scan.utils.fingerprint_helpers import get_fingerprint_paths
logger = logging.getLogger(__name__)
def calculate_fingerprint_detect_timeout(
url_count: int,
base_per_url: float = 3.0,
min_timeout: int = 60
) -> int:
"""
根据 URL 数量计算超时时间
公式:超时时间 = URL 数量 × 每 URL 基础时间
最小值60秒
无上限
Args:
url_count: URL 数量
base_per_url: 每 URL 基础时间(秒),默认 3秒
min_timeout: 最小超时时间(秒),默认 60秒
Returns:
int: 计算出的超时时间(秒)
示例:
100 URL × 3秒 = 300秒
1000 URL × 3秒 = 3000秒50分钟
10000 URL × 3秒 = 30000秒8.3小时)
"""
timeout = int(url_count * base_per_url)
return max(min_timeout, timeout)
def _export_urls(
target_id: int,
fingerprint_dir: Path,
source: str = 'website'
) -> tuple[str, int]:
"""
导出 URL 到文件
Args:
target_id: 目标 ID
fingerprint_dir: 指纹识别目录
source: 数据源类型
Returns:
tuple: (urls_file, total_count)
"""
logger.info("Step 1: 导出 URL 列表 (source=%s)", source)
urls_file = str(fingerprint_dir / 'urls.txt')
export_result = export_urls_for_fingerprint_task(
target_id=target_id,
output_file=urls_file,
source=source,
batch_size=1000
)
total_count = export_result['total_count']
logger.info(
"✓ URL 导出完成 - 文件: %s, 数量: %d",
export_result['output_file'],
total_count
)
return export_result['output_file'], total_count
def _run_fingerprint_detect(
enabled_tools: dict,
urls_file: str,
url_count: int,
fingerprint_dir: Path,
scan_id: int,
target_id: int,
source: str
) -> tuple[dict, list]:
"""
执行指纹识别任务
Args:
enabled_tools: 已启用的工具配置字典
urls_file: URL 文件路径
url_count: URL 总数
fingerprint_dir: 指纹识别目录
scan_id: 扫描任务 ID
target_id: 目标 ID
source: 数据源类型
Returns:
tuple: (tool_stats, failed_tools)
"""
tool_stats = {}
failed_tools = []
for tool_name, tool_config in enabled_tools.items():
# 1. 获取指纹库路径
lib_names = tool_config.get('fingerprint_libs', ['ehole'])
fingerprint_paths = get_fingerprint_paths(lib_names)
if not fingerprint_paths:
reason = f"没有可用的指纹库: {lib_names}"
logger.warning(reason)
failed_tools.append({'tool': tool_name, 'reason': reason})
continue
# 2. 将指纹库路径合并到 tool_config用于命令构建
tool_config_with_paths = {**tool_config, **fingerprint_paths}
# 3. 构建命令
try:
command = build_scan_command(
tool_name=tool_name,
scan_type='fingerprint_detect',
command_params={
'urls_file': urls_file
},
tool_config=tool_config_with_paths
)
except Exception as e:
reason = f"命令构建失败: {str(e)}"
logger.error("构建 %s 命令失败: %s", tool_name, e)
failed_tools.append({'tool': tool_name, 'reason': reason})
continue
# 4. 计算超时时间
timeout = calculate_fingerprint_detect_timeout(url_count)
# 5. 生成日志文件路径
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = fingerprint_dir / f"{tool_name}_{timestamp}.log"
logger.info(
"开始执行 %s 指纹识别 - URL数: %d, 超时: %ds, 指纹库: %s",
tool_name, url_count, timeout, list(fingerprint_paths.keys())
)
# 6. 执行扫描任务
try:
result = run_xingfinger_and_stream_update_tech_task(
cmd=command,
tool_name=tool_name,
scan_id=scan_id,
target_id=target_id,
source=source,
cwd=str(fingerprint_dir),
timeout=timeout,
log_file=str(log_file),
batch_size=100
)
tool_stats[tool_name] = {
'command': command,
'result': result,
'timeout': timeout,
'fingerprint_libs': list(fingerprint_paths.keys())
}
logger.info(
"✓ 工具 %s 执行完成 - 处理记录: %d, 更新: %d, 未找到: %d",
tool_name,
result.get('processed_records', 0),
result.get('updated_count', 0),
result.get('not_found_count', 0)
)
except Exception as exc:
failed_tools.append({'tool': tool_name, 'reason': str(exc)})
logger.error("工具 %s 执行失败: %s", tool_name, exc, exc_info=True)
if failed_tools:
logger.warning(
"以下指纹识别工具执行失败: %s",
', '.join([f['tool'] for f in failed_tools])
)
return tool_stats, failed_tools
@flow(
name="fingerprint_detect",
log_prints=True,
on_running=[on_scan_flow_running],
on_completion=[on_scan_flow_completed],
on_failure=[on_scan_flow_failed],
)
def fingerprint_detect_flow(
scan_id: int,
target_name: str,
target_id: int,
scan_workspace_dir: str,
enabled_tools: dict
) -> dict:
"""
指纹识别 Flow
主要功能:
1. 从数据库导出目标下所有 WebSite URL 到文件
2. 使用 xingfinger 进行技术栈识别
3. 解析结果并更新 WebSite.tech 字段(合并去重)
工作流程:
Step 0: 创建工作目录
Step 1: 导出 URL 列表
Step 2: 解析配置,获取启用的工具
Step 3: 执行 xingfinger 并解析结果
Args:
scan_id: 扫描任务 ID
target_name: 目标名称
target_id: 目标 ID
scan_workspace_dir: 扫描工作空间目录
enabled_tools: 启用的工具配置xingfinger
Returns:
dict: {
'success': bool,
'scan_id': int,
'target': str,
'scan_workspace_dir': str,
'urls_file': str,
'url_count': int,
'processed_records': int,
'updated_count': int,
'not_found_count': int,
'executed_tasks': list,
'tool_stats': dict
}
"""
try:
logger.info(
"="*60 + "\n" +
"开始指纹识别\n" +
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
"="*60
)
# 参数验证
if scan_id is None:
raise ValueError("scan_id 不能为空")
if not target_name:
raise ValueError("target_name 不能为空")
if target_id is None:
raise ValueError("target_id 不能为空")
if not scan_workspace_dir:
raise ValueError("scan_workspace_dir 不能为空")
# 数据源类型(当前只支持 website
source = 'website'
# Step 0: 创建工作目录
from apps.scan.utils import setup_scan_directory
fingerprint_dir = setup_scan_directory(scan_workspace_dir, 'fingerprint_detect')
# Step 1: 导出 URL支持懒加载
urls_file, url_count = _export_urls(target_id, fingerprint_dir, source)
if url_count == 0:
logger.warning("目标下没有可用的 URL跳过指纹识别")
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'scan_workspace_dir': scan_workspace_dir,
'urls_file': urls_file,
'url_count': 0,
'processed_records': 0,
'updated_count': 0,
'created_count': 0,
'executed_tasks': ['export_urls_for_fingerprint'],
'tool_stats': {
'total': 0,
'successful': 0,
'failed': 0,
'successful_tools': [],
'failed_tools': [],
'details': {}
}
}
# Step 2: 工具配置信息
logger.info("Step 2: 工具配置信息")
logger.info("✓ 启用工具: %s", ', '.join(enabled_tools.keys()))
# Step 3: 执行指纹识别
logger.info("Step 3: 执行指纹识别")
tool_stats, failed_tools = _run_fingerprint_detect(
enabled_tools=enabled_tools,
urls_file=urls_file,
url_count=url_count,
fingerprint_dir=fingerprint_dir,
scan_id=scan_id,
target_id=target_id,
source=source
)
logger.info("="*60 + "\n✓ 指纹识别完成\n" + "="*60)
# 动态生成已执行的任务列表
executed_tasks = ['export_urls_for_fingerprint']
executed_tasks.extend([f'run_xingfinger ({tool})' for tool in tool_stats.keys()])
# 汇总所有工具的结果
total_processed = sum(stats['result'].get('processed_records', 0) for stats in tool_stats.values())
total_updated = sum(stats['result'].get('updated_count', 0) for stats in tool_stats.values())
total_created = sum(stats['result'].get('created_count', 0) for stats in tool_stats.values())
successful_tools = [name for name in enabled_tools.keys()
if name not in [f['tool'] for f in failed_tools]]
return {
'success': True,
'scan_id': scan_id,
'target': target_name,
'scan_workspace_dir': scan_workspace_dir,
'urls_file': urls_file,
'url_count': url_count,
'processed_records': total_processed,
'updated_count': total_updated,
'created_count': total_created,
'executed_tasks': executed_tasks,
'tool_stats': {
'total': len(enabled_tools),
'successful': len(successful_tools),
'failed': len(failed_tools),
'successful_tools': successful_tools,
'failed_tools': failed_tools,
'details': tool_stats
}
}
except ValueError as e:
logger.error("配置错误: %s", e)
raise
except RuntimeError as e:
logger.error("运行时错误: %s", e)
raise
except Exception as e:
logger.exception("指纹识别失败: %s", e)
raise

View File

@@ -30,7 +30,7 @@ from apps.scan.handlers import (
on_initiate_scan_flow_failed,
)
from prefect.futures import wait
from apps.scan.tasks.workspace_tasks import create_scan_workspace_task
from apps.scan.utils import setup_scan_workspace
from apps.scan.orchestrators import FlowOrchestrator
logger = logging.getLogger(__name__)
@@ -110,7 +110,7 @@ def initiate_scan_flow(
)
# ==================== Task 1: 创建 Scan 工作空间 ====================
scan_workspace_path = create_scan_workspace_task(scan_workspace_dir)
scan_workspace_path = setup_scan_workspace(scan_workspace_dir)
# ==================== Task 2: 获取引擎配置 ====================
from apps.scan.models import Scan

View File

@@ -154,28 +154,7 @@ def _parse_port_count(tool_config: dict) -> int:
return 100
def _setup_port_scan_directory(scan_workspace_dir: str) -> Path:
"""
创建并验证端口扫描工作目录
Args:
scan_workspace_dir: 扫描工作空间目录
Returns:
Path: 端口扫描目录路径
Raises:
RuntimeError: 目录创建或验证失败
"""
port_scan_dir = Path(scan_workspace_dir) / 'port_scan'
port_scan_dir.mkdir(parents=True, exist_ok=True)
if not port_scan_dir.is_dir():
raise RuntimeError(f"端口扫描目录创建失败: {port_scan_dir}")
if not os.access(port_scan_dir, os.W_OK):
raise RuntimeError(f"端口扫描目录不可写: {port_scan_dir}")
return port_scan_dir
def _export_scan_targets(target_id: int, port_scan_dir: Path) -> tuple[str, int, str]:
@@ -442,7 +421,8 @@ def port_scan_flow(
)
# Step 0: 创建工作目录
port_scan_dir = _setup_port_scan_directory(scan_workspace_dir)
from apps.scan.utils import setup_scan_directory
port_scan_dir = setup_scan_directory(scan_workspace_dir, 'port_scan')
# Step 1: 导出扫描目标列表到文件(根据 Target 类型自动决定内容)
targets_file, target_count, target_type = _export_scan_targets(target_id, port_scan_dir)

View File

@@ -85,28 +85,7 @@ def calculate_timeout_by_line_count(
return min_timeout
def _setup_site_scan_directory(scan_workspace_dir: str) -> Path:
"""
创建并验证站点扫描工作目录
Args:
scan_workspace_dir: 扫描工作空间目录
Returns:
Path: 站点扫描目录路径
Raises:
RuntimeError: 目录创建或验证失败
"""
site_scan_dir = Path(scan_workspace_dir) / 'site_scan'
site_scan_dir.mkdir(parents=True, exist_ok=True)
if not site_scan_dir.is_dir():
raise RuntimeError(f"站点扫描目录创建失败: {site_scan_dir}")
if not os.access(site_scan_dir, os.W_OK):
raise RuntimeError(f"站点扫描目录不可写: {site_scan_dir}")
return site_scan_dir
def _export_site_urls(target_id: int, site_scan_dir: Path, target_name: str = None) -> tuple[str, int, int]:
@@ -130,7 +109,6 @@ def _export_site_urls(target_id: int, site_scan_dir: Path, target_name: str = No
export_result = export_site_urls_task(
target_id=target_id,
output_file=urls_file,
target_name=target_name,
batch_size=1000 # 每次处理1000个子域名
)
@@ -403,7 +381,8 @@ def site_scan_flow(
raise ValueError("scan_workspace_dir 不能为空")
# Step 0: 创建工作目录
site_scan_dir = _setup_site_scan_directory(scan_workspace_dir)
from apps.scan.utils import setup_scan_directory
site_scan_dir = setup_scan_directory(scan_workspace_dir, 'site_scan')
# Step 1: 导出站点 URL
urls_file, total_urls, association_count = _export_site_urls(

View File

@@ -41,28 +41,7 @@ import subprocess
logger = logging.getLogger(__name__)
def _setup_subdomain_directory(scan_workspace_dir: str) -> Path:
"""
创建并验证子域名扫描工作目录
Args:
scan_workspace_dir: 扫描工作空间目录
Returns:
Path: 子域名扫描目录路径
Raises:
RuntimeError: 目录创建或验证失败
"""
result_dir = Path(scan_workspace_dir) / 'subdomain_discovery'
result_dir.mkdir(parents=True, exist_ok=True)
if not result_dir.is_dir():
raise RuntimeError(f"子域名扫描目录创建失败: {result_dir}")
if not os.access(result_dir, os.W_OK):
raise RuntimeError(f"子域名扫描目录不可写: {result_dir}")
return result_dir
def _validate_and_normalize_target(target_name: str) -> str:
@@ -119,12 +98,7 @@ def _run_scans_parallel(
# 生成时间戳(所有工具共用)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
# TODO: 接入代理池管理系统
# from apps.proxy.services import proxy_pool
# proxy_stats = proxy_pool.get_stats()
# logger.info(f"代理池状态: {proxy_stats['healthy']}/{proxy_stats['total']} 可用")
failures = [] # 记录命令构建失败的工具
futures = {}
@@ -417,7 +391,8 @@ def subdomain_discovery_flow(
)
# Step 0: 准备工作
result_dir = _setup_subdomain_directory(scan_workspace_dir)
from apps.scan.utils import setup_scan_directory
result_dir = setup_scan_directory(scan_workspace_dir, 'subdomain_discovery')
# 验证并规范化目标域名
try:

View File

@@ -42,17 +42,7 @@ SITES_FILE_TOOLS = {'katana'}
POST_PROCESS_TOOLS = {'uro', 'httpx'}
def _setup_url_fetch_directory(scan_workspace_dir: str) -> Path:
"""创建并验证 URL 获取工作目录"""
url_fetch_dir = Path(scan_workspace_dir) / 'url_fetch'
url_fetch_dir.mkdir(parents=True, exist_ok=True)
if not url_fetch_dir.is_dir():
raise RuntimeError(f"URL 获取目录创建失败: {url_fetch_dir}")
if not os.access(url_fetch_dir, os.W_OK):
raise RuntimeError(f"URL 获取目录不可写: {url_fetch_dir}")
return url_fetch_dir
def _classify_tools(enabled_tools: dict) -> tuple[dict, dict, dict, dict]:
@@ -304,7 +294,8 @@ def url_fetch_flow(
# Step 1: 准备工作目录
logger.info("Step 1: 准备工作目录")
url_fetch_dir = _setup_url_fetch_directory(scan_workspace_dir)
from apps.scan.utils import setup_scan_directory
url_fetch_dir = setup_scan_directory(scan_workspace_dir, 'url_fetch')
# Step 2: 分类工具(按输入类型)
logger.info("Step 2: 分类工具")

View File

@@ -40,8 +40,7 @@ def _export_sites_file(target_id: int, scan_id: int, target_name: str, output_di
result = export_sites_task(
output_file=output_file,
target_id=target_id,
scan_id=scan_id,
target_name=target_name
scan_id=scan_id
)
count = result['asset_count']

View File

@@ -25,10 +25,7 @@ from .utils import calculate_timeout_by_line_count
logger = logging.getLogger(__name__)
def _setup_vuln_scan_directory(scan_workspace_dir: str) -> Path:
vuln_scan_dir = Path(scan_workspace_dir) / "vuln_scan"
vuln_scan_dir.mkdir(parents=True, exist_ok=True)
return vuln_scan_dir
@flow(
@@ -55,14 +52,14 @@ def endpoints_vuln_scan_flow(
if not enabled_tools:
raise ValueError("enabled_tools 不能为空")
vuln_scan_dir = _setup_vuln_scan_directory(scan_workspace_dir)
from apps.scan.utils import setup_scan_directory
vuln_scan_dir = setup_scan_directory(scan_workspace_dir, 'vuln_scan')
endpoints_file = vuln_scan_dir / "input_endpoints.txt"
# Step 1: 导出 Endpoint URL
export_result = export_endpoints_task(
target_id=target_id,
output_file=str(endpoints_file),
target_name=target_name, # 传入 target_name 用于生成默认端点
)
total_endpoints = export_result.get("total_count", 0)

View File

@@ -87,8 +87,8 @@ def on_all_workers_high_load(sender, worker_name, cpu, mem, **kwargs):
"""所有 Worker 高负载时的通知处理"""
create_notification(
title="系统负载较高",
message=f"所有节点负载较高,已选择负载最低的节点 {worker_name}CPU: {cpu:.1f}%, 内存: {mem:.1f}%执行任务,扫描速度可能受影响",
message=f"所有节点负载较高(最低负载节点 CPU: {cpu:.1f}%, 内存: {mem:.1f}%,系统将等待最多 10 分钟后分发任务,扫描速度可能受影响",
level=NotificationLevel.MEDIUM,
category=NotificationCategory.SYSTEM
)
logger.warning("高负载通知已发送 - worker=%s, cpu=%.1f%%, mem=%.1f%%", worker_name, cpu, mem)
logger.warning("高负载通知已发送 - cpu=%.1f%%, mem=%.1f%%", cpu, mem)

View File

@@ -14,6 +14,8 @@ from rest_framework.response import Response
from rest_framework.views import APIView
from apps.common.pagination import BasePagination
from apps.common.response_helpers import success_response, error_response
from apps.common.error_codes import ErrorCodes
from .models import Notification
from .serializers import NotificationSerializer
from .types import NotificationLevel
@@ -60,34 +62,7 @@ def notifications_test(request):
}, status=500)
def build_api_response(
data: Any = None,
*,
message: str = '操作成功',
code: str = '200',
state: str = 'success',
status_code: int = status.HTTP_200_OK
) -> Response:
"""构建统一的 API 响应格式
Args:
data: 响应数据体(可选)
message: 响应消息
code: 响应代码
state: 响应状态success/error
status_code: HTTP 状态码
Returns:
DRF Response 对象
"""
payload = {
'code': code,
'state': state,
'message': message,
}
if data is not None:
payload['data'] = data
return Response(payload, status=status_code)
# build_api_response 已废弃,请使用 success_response/error_response
def _parse_bool(value: str | None) -> bool | None:
@@ -172,7 +147,7 @@ class NotificationUnreadCountView(APIView):
"""获取未读通知数量"""
service = NotificationService()
count = service.get_unread_count()
return build_api_response({'count': count}, message='获取未读数量成功')
return success_response(data={'count': count})
class NotificationMarkAllAsReadView(APIView):
@@ -192,7 +167,7 @@ class NotificationMarkAllAsReadView(APIView):
"""标记全部通知为已读"""
service = NotificationService()
updated = service.mark_all_as_read()
return build_api_response({'updated': updated}, message='全部标记已读成功')
return success_response(data={'updated': updated})
class NotificationSettingsView(APIView):
@@ -209,13 +184,13 @@ class NotificationSettingsView(APIView):
"""获取通知设置"""
service = NotificationSettingsService()
settings = service.get_settings()
return Response(settings)
return success_response(data=settings)
def put(self, request: Request) -> Response:
"""更新通知设置"""
service = NotificationSettingsService()
settings = service.update_settings(request.data)
return Response({'message': '已保存通知设置', **settings})
return success_response(data=settings)
# ============================================
@@ -247,22 +222,24 @@ def notification_callback(request):
required_fields = ['id', 'category', 'title', 'message', 'level', 'created_at']
for field in required_fields:
if field not in data:
return Response(
{'error': f'缺少字段: {field}'},
status=status.HTTP_400_BAD_REQUEST
return error_response(
code=ErrorCodes.VALIDATION_ERROR,
message=f'Missing field: {field}',
status_code=status.HTTP_400_BAD_REQUEST
)
# 推送到 WebSocket
_push_notification_to_websocket(data)
logger.debug(f"回调通知推送成功 - ID: {data['id']}, Title: {data['title']}")
return Response({'status': 'ok'})
return success_response(data={'status': 'ok'})
except Exception as e:
logger.error(f"回调通知处理失败: {e}", exc_info=True)
return Response(
{'error': str(e)},
status=status.HTTP_500_INTERNAL_SERVER_ERROR
return error_response(
code=ErrorCodes.SERVER_ERROR,
message=str(e),
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
)

View File

@@ -206,6 +206,10 @@ class FlowOrchestrator:
from apps.scan.flows.site_scan_flow import site_scan_flow
return site_scan_flow
elif scan_type == 'fingerprint_detect':
from apps.scan.flows.fingerprint_detect_flow import fingerprint_detect_flow
return fingerprint_detect_flow
elif scan_type == 'directory_scan':
from apps.scan.flows.directory_scan_flow import directory_scan_flow
return directory_scan_flow

View File

@@ -83,7 +83,7 @@ def cleanup_results(results_dir: str, retention_days: int) -> dict:
def main():
parser = argparse.ArgumentParser(description="清理任务")
parser.add_argument("--results_dir", type=str, default="/app/backend/results", help="扫描结果目录")
parser.add_argument("--results_dir", type=str, default="/opt/xingrin/results", help="扫描结果目录")
parser.add_argument("--retention_days", type=int, default=7, help="保留天数")
args = parser.parse_args()

View File

@@ -17,6 +17,8 @@ from .scan_state_service import ScanStateService
from .scan_control_service import ScanControlService
from .scan_stats_service import ScanStatsService
from .scheduled_scan_service import ScheduledScanService
from .blacklist_service import BlacklistService
from .target_export_service import TargetExportService
__all__ = [
'ScanService', # 主入口(向后兼容)
@@ -25,5 +27,7 @@ __all__ = [
'ScanControlService',
'ScanStatsService',
'ScheduledScanService',
'BlacklistService', # 黑名单过滤服务
'TargetExportService', # 目标导出服务
]

View File

@@ -0,0 +1,82 @@
"""
黑名单过滤服务
过滤敏感域名(如 .gov、.edu、.mil 等)
当前版本使用默认规则,后续将支持从前端配置加载。
"""
from typing import List, Optional
from django.db.models import QuerySet
import re
import logging
logger = logging.getLogger(__name__)
class BlacklistService:
"""
黑名单过滤服务 - 过滤敏感域名
TODO: 后续版本支持从前端配置加载黑名单规则
- 用户在开始扫描时配置黑名单 URL、域名、IP
- 黑名单规则存储在数据库中,与 Scan 或 Engine 关联
"""
# 默认黑名单正则规则
DEFAULT_PATTERNS = [
r'\.gov$', # .gov 结尾
r'\.gov\.[a-z]{2}$', # .gov.cn, .gov.uk 等
]
def __init__(self, patterns: Optional[List[str]] = None):
"""
初始化黑名单服务
Args:
patterns: 正则表达式列表None 使用默认规则
"""
self.patterns = patterns or self.DEFAULT_PATTERNS
self._compiled_patterns = [re.compile(p) for p in self.patterns]
def filter_queryset(
self,
queryset: QuerySet,
url_field: str = 'url'
) -> QuerySet:
"""
数据库层面过滤 queryset
使用 PostgreSQL 正则表达式排除黑名单 URL
Args:
queryset: 原始 queryset
url_field: URL 字段名
Returns:
QuerySet: 过滤后的 queryset
"""
for pattern in self.patterns:
queryset = queryset.exclude(**{f'{url_field}__regex': pattern})
return queryset
def filter_url(self, url: str) -> bool:
"""
检查单个 URL 是否通过黑名单过滤
Args:
url: 要检查的 URL
Returns:
bool: True 表示通过不在黑名单False 表示被过滤
"""
for pattern in self._compiled_patterns:
if pattern.search(url):
return False
return True
# TODO: 后续版本实现
# @classmethod
# def from_scan(cls, scan_id: int) -> 'BlacklistService':
# """从数据库加载扫描配置的黑名单规则"""
# pass

View File

@@ -0,0 +1,364 @@
"""
目标导出服务
提供统一的目标提取和文件导出功能,支持:
- URL 导出(流式写入 + 默认值回退)
- 域名/IP 导出(用于端口扫描)
- 黑名单过滤集成
"""
import ipaddress
import logging
from pathlib import Path
from typing import Dict, Any, Optional, Iterator
from django.db.models import QuerySet
from .blacklist_service import BlacklistService
logger = logging.getLogger(__name__)
class TargetExportService:
"""
目标导出服务 - 提供统一的目标提取和文件导出功能
使用方式:
# Task 层决定数据源
queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)
# 使用导出服务
blacklist_service = BlacklistService()
export_service = TargetExportService(blacklist_service=blacklist_service)
result = export_service.export_urls(target_id, output_path, queryset)
"""
def __init__(self, blacklist_service: Optional[BlacklistService] = None):
"""
初始化导出服务
Args:
blacklist_service: 黑名单过滤服务None 表示禁用过滤
"""
self.blacklist_service = blacklist_service
def export_urls(
self,
target_id: int,
output_path: str,
queryset: QuerySet,
url_field: str = 'url',
batch_size: int = 1000
) -> Dict[str, Any]:
"""
统一 URL 导出函数
自动判断数据库有无数据:
- 有数据:流式写入数据库数据到文件
- 无数据:调用默认值生成器生成 URL
Args:
target_id: 目标 ID
output_path: 输出文件路径
queryset: 数据源 queryset由 Task 层构建,应为 values_list flat=True
url_field: URL 字段名(用于黑名单过滤)
batch_size: 批次大小
Returns:
dict: {
'success': bool,
'output_file': str,
'total_count': int
}
Raises:
IOError: 文件写入失败
"""
output_file = Path(output_path)
output_file.parent.mkdir(parents=True, exist_ok=True)
logger.info("开始导出 URL - target_id=%s, output=%s", target_id, output_path)
# 应用黑名单过滤(数据库层面)
if self.blacklist_service:
# 注意queryset 应该是原始 queryset不是 values_list
# 这里假设 Task 层传入的是 values_list需要在 Task 层处理过滤
pass
total_count = 0
try:
with open(output_file, 'w', encoding='utf-8', buffering=8192) as f:
for url in queryset.iterator(chunk_size=batch_size):
if url:
# Python 层面黑名单过滤
if self.blacklist_service and not self.blacklist_service.filter_url(url):
continue
f.write(f"{url}\n")
total_count += 1
if total_count % 10000 == 0:
logger.info("已导出 %d 个 URL...", total_count)
except IOError as e:
logger.error("文件写入失败: %s - %s", output_path, e)
raise
# 默认值回退模式
if total_count == 0:
total_count = self._generate_default_urls(target_id, output_file)
logger.info("✓ URL 导出完成 - 数量: %d, 文件: %s", total_count, output_path)
return {
'success': True,
'output_file': str(output_file),
'total_count': total_count
}
def _generate_default_urls(
self,
target_id: int,
output_path: Path
) -> int:
"""
默认值生成器(内部函数)
根据 Target 类型生成默认 URL
- DOMAIN: http(s)://domain
- IP: http(s)://ip
- CIDR: 展开为所有 IP 的 http(s)://ip
- URL: 直接使用目标 URL
Args:
target_id: 目标 ID
output_path: 输出文件路径
Returns:
int: 写入的 URL 总数
"""
from apps.targets.services import TargetService
from apps.targets.models import Target
target_service = TargetService()
target = target_service.get_target(target_id)
if not target:
logger.warning("Target ID %d 不存在,无法生成默认 URL", target_id)
return 0
target_name = target.name
target_type = target.type
logger.info("懒加载模式Target 类型=%s, 名称=%s", target_type, target_name)
total_urls = 0
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
if target_type == Target.TargetType.DOMAIN:
urls = [f"http://{target_name}", f"https://{target_name}"]
for url in urls:
if self._should_write_url(url):
f.write(f"{url}\n")
total_urls += 1
elif target_type == Target.TargetType.IP:
urls = [f"http://{target_name}", f"https://{target_name}"]
for url in urls:
if self._should_write_url(url):
f.write(f"{url}\n")
total_urls += 1
elif target_type == Target.TargetType.CIDR:
try:
network = ipaddress.ip_network(target_name, strict=False)
for ip in network.hosts():
urls = [f"http://{ip}", f"https://{ip}"]
for url in urls:
if self._should_write_url(url):
f.write(f"{url}\n")
total_urls += 1
if total_urls % 10000 == 0:
logger.info("已生成 %d 个 URL...", total_urls)
# /32 或 /128 特殊处理
if total_urls == 0:
ip = str(network.network_address)
urls = [f"http://{ip}", f"https://{ip}"]
for url in urls:
if self._should_write_url(url):
f.write(f"{url}\n")
total_urls += 1
except ValueError as e:
logger.error("CIDR 解析失败: %s - %s", target_name, e)
raise ValueError(f"无效的 CIDR: {target_name}") from e
elif target_type == Target.TargetType.URL:
if self._should_write_url(target_name):
f.write(f"{target_name}\n")
total_urls = 1
else:
logger.warning("不支持的 Target 类型: %s", target_type)
logger.info("✓ 懒加载生成默认 URL - 数量: %d", total_urls)
return total_urls
def _should_write_url(self, url: str) -> bool:
"""检查 URL 是否应该写入(通过黑名单过滤)"""
if self.blacklist_service:
return self.blacklist_service.filter_url(url)
return True
def export_targets(
self,
target_id: int,
output_path: str,
batch_size: int = 1000
) -> Dict[str, Any]:
"""
域名/IP 导出函数(用于端口扫描)
根据 Target 类型选择导出逻辑:
- DOMAIN: 从 Subdomain 表流式导出子域名
- IP: 直接写入 IP 地址
- CIDR: 展开为所有主机 IP
Args:
target_id: 目标 ID
output_path: 输出文件路径
batch_size: 批次大小
Returns:
dict: {
'success': bool,
'output_file': str,
'total_count': int,
'target_type': str
}
"""
from apps.targets.services import TargetService
from apps.targets.models import Target
from apps.asset.services.asset.subdomain_service import SubdomainService
output_file = Path(output_path)
output_file.parent.mkdir(parents=True, exist_ok=True)
# 获取 Target 信息
target_service = TargetService()
target = target_service.get_target(target_id)
if not target:
raise ValueError(f"Target ID {target_id} 不存在")
target_type = target.type
target_name = target.name
logger.info(
"开始导出扫描目标 - Target ID: %d, Name: %s, Type: %s, 输出文件: %s",
target_id, target_name, target_type, output_path
)
total_count = 0
if target_type == Target.TargetType.DOMAIN:
total_count = self._export_domains(target_id, target_name, output_file, batch_size)
type_desc = "域名"
elif target_type == Target.TargetType.IP:
total_count = self._export_ip(target_name, output_file)
type_desc = "IP"
elif target_type == Target.TargetType.CIDR:
total_count = self._export_cidr(target_name, output_file)
type_desc = "CIDR IP"
else:
raise ValueError(f"不支持的目标类型: {target_type}")
logger.info(
"✓ 扫描目标导出完成 - 类型: %s, 总数: %d, 文件: %s",
type_desc, total_count, output_path
)
return {
'success': True,
'output_file': str(output_file),
'total_count': total_count,
'target_type': target_type
}
def _export_domains(
self,
target_id: int,
target_name: str,
output_path: Path,
batch_size: int
) -> int:
"""导出域名类型目标的子域名"""
from apps.asset.services.asset.subdomain_service import SubdomainService
subdomain_service = SubdomainService()
domain_iterator = subdomain_service.iter_subdomain_names_by_target(
target_id=target_id,
chunk_size=batch_size
)
total_count = 0
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for domain_name in domain_iterator:
if self._should_write_target(domain_name):
f.write(f"{domain_name}\n")
total_count += 1
if total_count % 10000 == 0:
logger.info("已导出 %d 个域名...", total_count)
# 默认值模式:如果没有子域名,使用根域名
if total_count == 0:
logger.info("采用默认域名:%s (target_id=%d)", target_name, target_id)
if self._should_write_target(target_name):
with open(output_path, 'w', encoding='utf-8') as f:
f.write(f"{target_name}\n")
total_count = 1
return total_count
def _export_ip(self, target_name: str, output_path: Path) -> int:
"""导出 IP 类型目标"""
if self._should_write_target(target_name):
with open(output_path, 'w', encoding='utf-8') as f:
f.write(f"{target_name}\n")
return 1
return 0
def _export_cidr(self, target_name: str, output_path: Path) -> int:
"""导出 CIDR 类型目标,展开为每个 IP"""
network = ipaddress.ip_network(target_name, strict=False)
total_count = 0
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for ip in network.hosts():
ip_str = str(ip)
if self._should_write_target(ip_str):
f.write(f"{ip_str}\n")
total_count += 1
if total_count % 10000 == 0:
logger.info("已导出 %d 个 IP...", total_count)
# /32 或 /128 特殊处理
if total_count == 0:
ip_str = str(network.network_address)
if self._should_write_target(ip_str):
with open(output_path, 'w', encoding='utf-8') as f:
f.write(f"{ip_str}\n")
total_count = 1
return total_count
def _should_write_target(self, target: str) -> bool:
"""检查目标是否应该写入(通过黑名单过滤)"""
if self.blacklist_service:
return self.blacklist_service.filter_url(target)
return True

View File

@@ -9,9 +9,6 @@
- Tasks 负责具体操作Flow 负责编排
"""
# Prefect Tasks
from .workspace_tasks import create_scan_workspace_task
# 子域名发现任务(已重构为多个子任务)
from .subdomain_discovery import (
run_subdomain_discovery_task,
@@ -19,17 +16,25 @@ from .subdomain_discovery import (
save_domains_task,
)
# 指纹识别任务
from .fingerprint_detect import (
export_urls_for_fingerprint_task,
run_xingfinger_and_stream_update_tech_task,
)
# 注意:
# - subdomain_discovery_task 已重构为多个子任务subdomain_discovery/
# - finalize_scan_task 已废弃Handler 统一管理状态)
# - initiate_scan_task 已迁移到 flows/initiate_scan_flow.py
# - cleanup_old_scans_task 已迁移到 flowscleanup_old_scans_flow
# - create_scan_workspace_task 已删除,直接使用 setup_scan_workspace()
__all__ = [
# Prefect Tasks
'create_scan_workspace_task',
# 子域名发现任务
'run_subdomain_discovery_task',
'merge_and_validate_task',
'save_domains_task',
# 指纹识别任务
'export_urls_for_fingerprint_task',
'run_xingfinger_and_stream_update_tech_task',
]

View File

@@ -1,20 +1,14 @@
"""
导出站点 URL 到 TXT 文件的 Task
使用流式处理,避免大量站点导致内存溢出
支持默认值模式:如果没有站点,根据 Target 类型生成默认 URL
- DOMAIN: http(s)://target_name
- IP: http(s)://ip
- CIDR: 展开为所有 IP 的 http(s)://ip
使用 TargetExportService 统一处理导出逻辑和默认值回退
数据源: WebSite.url
"""
import logging
import ipaddress
from pathlib import Path
from prefect import task
from apps.asset.repositories import DjangoWebSiteRepository
from apps.targets.services import TargetService
from apps.targets.models import Target
from apps.asset.models import WebSite
from apps.scan.services import TargetExportService, BlacklistService
logger = logging.getLogger(__name__)
@@ -24,19 +18,22 @@ def export_sites_task(
target_id: int,
output_file: str,
batch_size: int = 1000,
target_name: str = None
) -> dict:
"""
导出目标下的所有站点 URL 到 TXT 文件
使用流式处理支持大规模数据导出10万+站点)
支持默认值模式:如果没有站点,自动使用默认站点 URLhttp(s)://target_name
数据源: WebSite.url
懒加载模式:
- 如果数据库为空,根据 Target 类型生成默认 URL
- DOMAIN: http(s)://domain
- IP: http(s)://ip
- CIDR: 展开为所有 IP 的 URL
Args:
target_id: 目标 ID
output_file: 输出文件路径(绝对路径)
batch_size: 每次读取的批次大小,默认 1000
target_name: 目标名称(用于默认值模式)
Returns:
dict: {
@@ -49,134 +46,26 @@ def export_sites_task(
ValueError: 参数错误
IOError: 文件写入失败
"""
try:
# 初始化 Repository
repository = DjangoWebSiteRepository()
logger.info("开始导出站点 URL - Target ID: %d, 输出文件: %s", target_id, output_file)
# 确保输出目录存在
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
# 使用 Repository 流式查询站点 URL
url_iterator = repository.get_urls_for_export(
target_id=target_id,
batch_size=batch_size
)
# 流式写入文件
total_count = 0
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for url in url_iterator:
# 每次只处理一个 URL边读边写
f.write(f"{url}\n")
total_count += 1
# 每写入 10000 条记录打印一次进度
if total_count % 10000 == 0:
logger.info("已导出 %d 个站点 URL...", total_count)
# ==================== 懒加载模式:根据 Target 类型生成默认 URL ====================
if total_count == 0:
total_count = _write_default_urls(target_id, target_name, output_path)
logger.info(
"✓ 站点 URL 导出完成 - 总数: %d, 文件: %s (%.2f KB)",
total_count,
str(output_path), # 使用绝对路径
output_path.stat().st_size / 1024
)
return {
'success': True,
'output_file': str(output_path),
'total_count': total_count
}
except FileNotFoundError as e:
logger.error("输出目录不存在: %s", e)
raise
except PermissionError as e:
logger.error("文件写入权限不足: %s", e)
raise
except Exception as e:
logger.exception("导出站点 URL 失败: %s", e)
raise
def _write_default_urls(target_id: int, target_name: str, output_path: Path) -> int:
"""
懒加载模式:根据 Target 类型生成默认 URL
# 构建数据源 querysetTask 层决定数据源)
queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)
Args:
target_id: 目标 ID
target_name: 目标名称(可选,如果为空则从数据库查询)
output_path: 输出文件路径
Returns:
int: 生成的 URL 数量
"""
# 获取 Target 信息
target_service = TargetService()
target = target_service.get_target(target_id)
# 使用 TargetExportService 处理导出
blacklist_service = BlacklistService()
export_service = TargetExportService(blacklist_service=blacklist_service)
if not target:
logger.warning("Target ID %d 不存在,无法生成默认 URL", target_id)
return 0
result = export_service.export_urls(
target_id=target_id,
output_path=output_file,
queryset=queryset,
batch_size=batch_size
)
target_name = target.name
target_type = target.type
logger.info("懒加载模式Target 类型=%s, 名称=%s", target_type, target_name)
total_urls = 0
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
if target_type == Target.TargetType.DOMAIN:
# 域名类型:生成 http(s)://domain
f.write(f"http://{target_name}\n")
f.write(f"https://{target_name}\n")
total_urls = 2
logger.info("✓ 域名默认 URL 已写入: http(s)://%s", target_name)
elif target_type == Target.TargetType.IP:
# IP 类型:生成 http(s)://ip
f.write(f"http://{target_name}\n")
f.write(f"https://{target_name}\n")
total_urls = 2
logger.info("✓ IP 默认 URL 已写入: http(s)://%s", target_name)
elif target_type == Target.TargetType.CIDR:
# CIDR 类型:展开为所有 IP 的 URL
try:
network = ipaddress.ip_network(target_name, strict=False)
for ip in network.hosts(): # 排除网络地址和广播地址
f.write(f"http://{ip}\n")
f.write(f"https://{ip}\n")
total_urls += 2
if total_urls % 10000 == 0:
logger.info("已生成 %d 个 URL...", total_urls)
# 如果是 /32 或 /128单个 IPhosts() 会为空
if total_urls == 0:
ip = str(network.network_address)
f.write(f"http://{ip}\n")
f.write(f"https://{ip}\n")
total_urls = 2
logger.info("✓ CIDR 默认 URL 已写入: %d 个 URL (来自 %s)", total_urls, target_name)
except ValueError as e:
logger.error("CIDR 解析失败: %s - %s", target_name, e)
return 0
else:
logger.warning("不支持的 Target 类型: %s", target_type)
return 0
return total_urls
# 保持返回值格式不变(向后兼容)
return {
'success': result['success'],
'output_file': result['output_file'],
'total_count': result['total_count']
}

Some files were not shown because too many files have changed in this diff Show More