mirror of
https://github.com/yyhuni/xingrin.git
synced 2026-02-02 04:33:10 +08:00
Compare commits
114 Commits
v1.1.7
...
v1.2.9-dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
da96d437a4 | ||
|
|
feaf8062e5 | ||
|
|
4bab76f233 | ||
|
|
09416b4615 | ||
|
|
bc1c5f6b0e | ||
|
|
2f2742e6fe | ||
|
|
be3c346a74 | ||
|
|
0c7a6fff12 | ||
|
|
3b4f0e3147 | ||
|
|
51212a2a0c | ||
|
|
58533bbaf6 | ||
|
|
6ccca1602d | ||
|
|
6389b0f672 | ||
|
|
d7599b8599 | ||
|
|
8eff298293 | ||
|
|
3634101c5b | ||
|
|
163973a7df | ||
|
|
80ffecba3e | ||
|
|
3c21ac940c | ||
|
|
5c9f484d70 | ||
|
|
7567f6c25b | ||
|
|
0599a0b298 | ||
|
|
f7557fe90c | ||
|
|
13571b9772 | ||
|
|
8ee76eef69 | ||
|
|
2a31e29aa2 | ||
|
|
81abc59961 | ||
|
|
ffbfec6dd5 | ||
|
|
a0091636a8 | ||
|
|
69490ab396 | ||
|
|
7306964abf | ||
|
|
cb6b0259e3 | ||
|
|
e1b4618e58 | ||
|
|
556dcf5f62 | ||
|
|
0628eef025 | ||
|
|
38ed8bc642 | ||
|
|
2f4d6a2168 | ||
|
|
c25cb9e06b | ||
|
|
b14ab71c7f | ||
|
|
8b5060e2d3 | ||
|
|
3c9335febf | ||
|
|
1b95e4f2c3 | ||
|
|
d20a600afc | ||
|
|
c29b11fd37 | ||
|
|
6caf707072 | ||
|
|
2627b1fc40 | ||
|
|
ec6712b9b4 | ||
|
|
9d5e4d5408 | ||
|
|
c5d5b24c8f | ||
|
|
671cb56b62 | ||
|
|
51025f69a8 | ||
|
|
b2403b29c4 | ||
|
|
18ef01a47b | ||
|
|
0bf8108fb3 | ||
|
|
837ad19131 | ||
|
|
d7de9a7129 | ||
|
|
22b4e51b42 | ||
|
|
d03628ee45 | ||
|
|
0baabe0753 | ||
|
|
e1191d7abf | ||
|
|
82a2e9a0e7 | ||
|
|
1ccd1bc338 | ||
|
|
b4d42f5372 | ||
|
|
2c66450756 | ||
|
|
119d82dc89 | ||
|
|
fba7f7c508 | ||
|
|
99d384ce29 | ||
|
|
07f36718ab | ||
|
|
7e3f69c208 | ||
|
|
5f90473c3c | ||
|
|
e2a815b96a | ||
|
|
f86a1a9d47 | ||
|
|
d5945679aa | ||
|
|
51e2c51748 | ||
|
|
e2cbf98dda | ||
|
|
cd72bdf7c3 | ||
|
|
35abcf7e39 | ||
|
|
09f2d343a4 | ||
|
|
54d1f86bde | ||
|
|
a3997c9676 | ||
|
|
c90a55f85e | ||
|
|
2eab88b452 | ||
|
|
1baf0eb5e1 | ||
|
|
b61e73f7be | ||
|
|
e896734dfc | ||
|
|
cd83f52f35 | ||
|
|
3e29554c36 | ||
|
|
18e02b536e | ||
|
|
4c1c6f70ab | ||
|
|
a72e7675f5 | ||
|
|
93c2163764 | ||
|
|
de72c91561 | ||
|
|
3e6d060b75 | ||
|
|
766f045904 | ||
|
|
8acfe1cc33 | ||
|
|
7aec3eabb2 | ||
|
|
b1f11c36a4 | ||
|
|
d97fb5245a | ||
|
|
ddf9a1f5a4 | ||
|
|
47f9f96a4b | ||
|
|
6f43e73162 | ||
|
|
9b7d496f3e | ||
|
|
6390849d52 | ||
|
|
7a6d2054f6 | ||
|
|
73ebaab232 | ||
|
|
11899b29c2 | ||
|
|
877d2a56d1 | ||
|
|
dc1e94f038 | ||
|
|
9c3833d13d | ||
|
|
92f3b722ef | ||
|
|
9ef503c666 | ||
|
|
c3a43e94fa | ||
|
|
d6d94355fb | ||
|
|
bc638eabf4 |
48
.github/workflows/docker-build.yml
vendored
48
.github/workflows/docker-build.yml
vendored
@@ -106,33 +106,65 @@ jobs:
|
||||
${{ steps.version.outputs.IS_RELEASE == 'true' && format('{0}/{1}:latest', env.IMAGE_PREFIX, matrix.image) || '' }}
|
||||
build-args: |
|
||||
IMAGE_TAG=${{ steps.version.outputs.VERSION }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
cache-from: type=gha,scope=${{ matrix.image }}
|
||||
cache-to: type=gha,mode=max,scope=${{ matrix.image }}
|
||||
provenance: false
|
||||
sbom: false
|
||||
|
||||
# 所有镜像构建成功后,更新 VERSION 文件
|
||||
# 根据 tag 所在的分支更新对应分支的 VERSION 文件
|
||||
update-version:
|
||||
runs-on: ubuntu-latest
|
||||
needs: build
|
||||
if: startsWith(github.ref, 'refs/tags/v')
|
||||
steps:
|
||||
- name: Checkout
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: main
|
||||
fetch-depth: 0 # 获取完整历史,用于判断 tag 所在分支
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Determine source branch and version
|
||||
id: branch
|
||||
run: |
|
||||
VERSION="${GITHUB_REF#refs/tags/}"
|
||||
|
||||
# 查找包含此 tag 的分支
|
||||
BRANCHES=$(git branch -r --contains ${{ github.ref_name }})
|
||||
echo "Branches containing tag: $BRANCHES"
|
||||
|
||||
# 判断 tag 来自哪个分支
|
||||
if echo "$BRANCHES" | grep -q "origin/main"; then
|
||||
TARGET_BRANCH="main"
|
||||
UPDATE_LATEST="true"
|
||||
elif echo "$BRANCHES" | grep -q "origin/dev"; then
|
||||
TARGET_BRANCH="dev"
|
||||
UPDATE_LATEST="false"
|
||||
else
|
||||
echo "Warning: Tag not found in main or dev branch, defaulting to main"
|
||||
TARGET_BRANCH="main"
|
||||
UPDATE_LATEST="false"
|
||||
fi
|
||||
|
||||
echo "BRANCH=$TARGET_BRANCH" >> $GITHUB_OUTPUT
|
||||
echo "VERSION=$VERSION" >> $GITHUB_OUTPUT
|
||||
echo "UPDATE_LATEST=$UPDATE_LATEST" >> $GITHUB_OUTPUT
|
||||
echo "Will update VERSION on branch: $TARGET_BRANCH"
|
||||
|
||||
- name: Checkout target branch
|
||||
run: |
|
||||
git checkout ${{ steps.branch.outputs.BRANCH }}
|
||||
|
||||
- name: Update VERSION file
|
||||
run: |
|
||||
VERSION="${GITHUB_REF#refs/tags/}"
|
||||
VERSION="${{ steps.branch.outputs.VERSION }}"
|
||||
echo "$VERSION" > VERSION
|
||||
echo "Updated VERSION to $VERSION"
|
||||
echo "Updated VERSION to $VERSION on branch ${{ steps.branch.outputs.BRANCH }}"
|
||||
|
||||
- name: Commit and push
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
git add VERSION
|
||||
git diff --staged --quiet || git commit -m "chore: bump version to ${GITHUB_REF#refs/tags/}"
|
||||
git push
|
||||
git diff --staged --quiet || git commit -m "chore: bump version to ${{ steps.branch.outputs.VERSION }}"
|
||||
git push origin ${{ steps.branch.outputs.BRANCH }}
|
||||
|
||||
14
README.md
14
README.md
@@ -177,11 +177,19 @@ cd xingrin
|
||||
|
||||
# 安装并启动(生产模式)
|
||||
sudo ./install.sh
|
||||
|
||||
# 🇨🇳 中国大陆用户推荐使用镜像加速(第三方加速服务可能会失效,不保证长期可用)
|
||||
sudo ./install.sh --mirror
|
||||
```
|
||||
|
||||
> **💡 --mirror 参数说明**
|
||||
> - 自动配置 Docker 镜像加速(国内镜像源)
|
||||
> - 加速 Git 仓库克隆(Nuclei 模板等)
|
||||
> - 大幅提升安装速度,避免网络超时
|
||||
|
||||
### 访问服务
|
||||
|
||||
- **Web 界面**: `https://localhost`
|
||||
- **Web 界面**: `https://ip:8083`
|
||||
|
||||
### 常用命令
|
||||
|
||||
@@ -197,16 +205,12 @@ sudo ./restart.sh
|
||||
|
||||
# 卸载
|
||||
sudo ./uninstall.sh
|
||||
|
||||
# 更新
|
||||
sudo ./update.sh
|
||||
```
|
||||
|
||||
## 🤝 反馈与贡献
|
||||
|
||||
- 🐛 **如果发现 Bug** 可以点击右边链接进行提交 [Issue](https://github.com/yyhuni/xingrin/issues)
|
||||
- 💡 **有新想法,比如UI设计,功能设计等** 欢迎点击右边链接进行提交建议 [Issue](https://github.com/yyhuni/xingrin/issues)
|
||||
- 🔧 **想参与开发?** 关注我公众号与我个人联系
|
||||
|
||||
## 📧 联系
|
||||
- 目前版本就我个人使用,可能会有很多边界问题
|
||||
|
||||
@@ -134,8 +134,8 @@ class VulnerabilitySnapshotSerializer(serializers.ModelSerializer):
|
||||
class EndpointListSerializer(serializers.ModelSerializer):
|
||||
"""端点列表序列化器(用于目标端点列表页)"""
|
||||
|
||||
# 将 GF 匹配模式映射为前端使用的 tags 字段
|
||||
tags = serializers.ListField(
|
||||
# GF 匹配模式(gf-patterns 工具匹配的敏感 URL 模式)
|
||||
gfPatterns = serializers.ListField(
|
||||
child=serializers.CharField(),
|
||||
source='matched_gf_patterns',
|
||||
read_only=True,
|
||||
@@ -155,7 +155,7 @@ class EndpointListSerializer(serializers.ModelSerializer):
|
||||
'body_preview',
|
||||
'tech',
|
||||
'vhost',
|
||||
'tags',
|
||||
'gfPatterns',
|
||||
'created_at',
|
||||
]
|
||||
read_only_fields = fields
|
||||
@@ -258,8 +258,8 @@ class DirectorySnapshotSerializer(serializers.ModelSerializer):
|
||||
class EndpointSnapshotSerializer(serializers.ModelSerializer):
|
||||
"""端点快照序列化器(用于扫描历史)"""
|
||||
|
||||
# 将 GF 匹配模式映射为前端使用的 tags 字段
|
||||
tags = serializers.ListField(
|
||||
# GF 匹配模式(gf-patterns 工具匹配的敏感 URL 模式)
|
||||
gfPatterns = serializers.ListField(
|
||||
child=serializers.CharField(),
|
||||
source='matched_gf_patterns',
|
||||
read_only=True,
|
||||
@@ -280,7 +280,7 @@ class EndpointSnapshotSerializer(serializers.ModelSerializer):
|
||||
'body_preview',
|
||||
'tech',
|
||||
'vhost',
|
||||
'tags',
|
||||
'gfPatterns',
|
||||
'created_at',
|
||||
]
|
||||
read_only_fields = fields
|
||||
|
||||
@@ -28,6 +28,7 @@ class EndpointService:
|
||||
'host': 'host',
|
||||
'title': 'title',
|
||||
'status': 'status_code',
|
||||
'tech': 'tech',
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
@@ -115,7 +116,7 @@ class EndpointService:
|
||||
"""获取目标下的所有端点"""
|
||||
queryset = self.repo.get_by_target(target_id)
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING, json_array_fields=['tech'])
|
||||
return queryset
|
||||
|
||||
def count_endpoints_by_target(self, target_id: int) -> int:
|
||||
@@ -134,7 +135,7 @@ class EndpointService:
|
||||
"""获取所有端点(全局查询)"""
|
||||
queryset = self.repo.get_all()
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING, json_array_fields=['tech'])
|
||||
return queryset
|
||||
|
||||
def iter_endpoint_urls_by_target(self, target_id: int, chunk_size: int = 1000) -> Iterator[str]:
|
||||
|
||||
@@ -20,6 +20,7 @@ class WebSiteService:
|
||||
'host': 'host',
|
||||
'title': 'title',
|
||||
'status': 'status_code',
|
||||
'tech': 'tech',
|
||||
}
|
||||
|
||||
def __init__(self, repository=None):
|
||||
@@ -107,14 +108,14 @@ class WebSiteService:
|
||||
"""获取目标下的所有网站"""
|
||||
queryset = self.repo.get_by_target(target_id)
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING, json_array_fields=['tech'])
|
||||
return queryset
|
||||
|
||||
def get_all(self, filter_query: Optional[str] = None):
|
||||
"""获取所有网站"""
|
||||
queryset = self.repo.get_all()
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING, json_array_fields=['tech'])
|
||||
return queryset
|
||||
|
||||
def get_by_url(self, url: str, target_id: int) -> int:
|
||||
|
||||
@@ -2,6 +2,8 @@ import logging
|
||||
from rest_framework import viewsets, status, filters
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.response import Response
|
||||
from apps.common.response_helpers import success_response, error_response
|
||||
from apps.common.error_codes import ErrorCodes
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.exceptions import NotFound, ValidationError as DRFValidationError
|
||||
from django.core.exceptions import ValidationError, ObjectDoesNotExist
|
||||
@@ -57,7 +59,7 @@ class AssetStatisticsViewSet(viewsets.ViewSet):
|
||||
"""
|
||||
try:
|
||||
stats = self.service.get_statistics()
|
||||
return Response({
|
||||
return success_response(data={
|
||||
'totalTargets': stats['total_targets'],
|
||||
'totalSubdomains': stats['total_subdomains'],
|
||||
'totalIps': stats['total_ips'],
|
||||
@@ -80,9 +82,10 @@ class AssetStatisticsViewSet(viewsets.ViewSet):
|
||||
})
|
||||
except (DatabaseError, OperationalError) as e:
|
||||
logger.exception("获取资产统计数据失败")
|
||||
return Response(
|
||||
{'error': '获取统计数据失败'},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
return error_response(
|
||||
code=ErrorCodes.SERVER_ERROR,
|
||||
message='Failed to get statistics',
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
@action(detail=False, methods=['get'], url_path='history')
|
||||
@@ -107,12 +110,13 @@ class AssetStatisticsViewSet(viewsets.ViewSet):
|
||||
days = min(max(days, 1), 90) # 限制在 1-90 天
|
||||
|
||||
history = self.service.get_statistics_history(days=days)
|
||||
return Response(history)
|
||||
return success_response(data=history)
|
||||
except (DatabaseError, OperationalError) as e:
|
||||
logger.exception("获取统计历史数据失败")
|
||||
return Response(
|
||||
{'error': '获取历史数据失败'},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
return error_response(
|
||||
code=ErrorCodes.SERVER_ERROR,
|
||||
message='Failed to get history data',
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
|
||||
@@ -164,45 +168,50 @@ class SubdomainViewSet(viewsets.ModelViewSet):
|
||||
|
||||
响应:
|
||||
{
|
||||
"message": "批量创建完成",
|
||||
"createdCount": 10,
|
||||
"skippedCount": 2,
|
||||
"invalidCount": 1,
|
||||
"mismatchedCount": 1,
|
||||
"totalReceived": 14
|
||||
"data": {
|
||||
"createdCount": 10,
|
||||
"skippedCount": 2,
|
||||
"invalidCount": 1,
|
||||
"mismatchedCount": 1,
|
||||
"totalReceived": 14
|
||||
}
|
||||
}
|
||||
"""
|
||||
from apps.targets.models import Target
|
||||
|
||||
target_pk = self.kwargs.get('target_pk')
|
||||
if not target_pk:
|
||||
return Response(
|
||||
{'error': '必须在目标下批量创建子域名'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='Must create subdomains under a target',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# 获取目标
|
||||
try:
|
||||
target = Target.objects.get(pk=target_pk)
|
||||
except Target.DoesNotExist:
|
||||
return Response(
|
||||
{'error': '目标不存在'},
|
||||
status=status.HTTP_404_NOT_FOUND
|
||||
return error_response(
|
||||
code=ErrorCodes.NOT_FOUND,
|
||||
message='Target not found',
|
||||
status_code=status.HTTP_404_NOT_FOUND
|
||||
)
|
||||
|
||||
# 验证目标类型必须为域名
|
||||
if target.type != Target.TargetType.DOMAIN:
|
||||
return Response(
|
||||
{'error': '只有域名类型的目标支持导入子域名'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='Only domain type targets support subdomain import',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# 获取请求体中的子域名列表
|
||||
subdomains = request.data.get('subdomains', [])
|
||||
if not subdomains or not isinstance(subdomains, list):
|
||||
return Response(
|
||||
{'error': '请求体不能为空或格式错误'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='Request body cannot be empty or invalid format',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# 调用 service 层处理
|
||||
@@ -214,19 +223,19 @@ class SubdomainViewSet(viewsets.ModelViewSet):
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("批量创建子域名失败")
|
||||
return Response(
|
||||
{'error': '服务器内部错误'},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
return error_response(
|
||||
code=ErrorCodes.SERVER_ERROR,
|
||||
message='Server internal error',
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
return Response({
|
||||
'message': '批量创建完成',
|
||||
return success_response(data={
|
||||
'createdCount': result.created_count,
|
||||
'skippedCount': result.skipped_count,
|
||||
'invalidCount': result.invalid_count,
|
||||
'mismatchedCount': result.mismatched_count,
|
||||
'totalReceived': result.total_received,
|
||||
}, status=status.HTTP_200_OK)
|
||||
})
|
||||
|
||||
@action(detail=False, methods=['get'], url_path='export')
|
||||
def export(self, request, **kwargs):
|
||||
@@ -265,6 +274,7 @@ class WebSiteViewSet(viewsets.ModelViewSet):
|
||||
- host="example" 主机名模糊匹配
|
||||
- title="login" 标题模糊匹配
|
||||
- status="200,301" 状态码多值匹配
|
||||
- tech="nginx" 技术栈匹配(数组字段)
|
||||
- 多条件空格分隔 AND 关系
|
||||
"""
|
||||
|
||||
@@ -299,35 +309,38 @@ class WebSiteViewSet(viewsets.ModelViewSet):
|
||||
|
||||
响应:
|
||||
{
|
||||
"message": "批量创建完成",
|
||||
"createdCount": 10,
|
||||
"mismatchedCount": 2
|
||||
"data": {
|
||||
"createdCount": 10
|
||||
}
|
||||
}
|
||||
"""
|
||||
from apps.targets.models import Target
|
||||
|
||||
target_pk = self.kwargs.get('target_pk')
|
||||
if not target_pk:
|
||||
return Response(
|
||||
{'error': '必须在目标下批量创建网站'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='Must create websites under a target',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# 获取目标
|
||||
try:
|
||||
target = Target.objects.get(pk=target_pk)
|
||||
except Target.DoesNotExist:
|
||||
return Response(
|
||||
{'error': '目标不存在'},
|
||||
status=status.HTTP_404_NOT_FOUND
|
||||
return error_response(
|
||||
code=ErrorCodes.NOT_FOUND,
|
||||
message='Target not found',
|
||||
status_code=status.HTTP_404_NOT_FOUND
|
||||
)
|
||||
|
||||
# 获取请求体中的 URL 列表
|
||||
urls = request.data.get('urls', [])
|
||||
if not urls or not isinstance(urls, list):
|
||||
return Response(
|
||||
{'error': '请求体不能为空或格式错误'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='Request body cannot be empty or invalid format',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# 调用 service 层处理
|
||||
@@ -340,15 +353,15 @@ class WebSiteViewSet(viewsets.ModelViewSet):
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("批量创建网站失败")
|
||||
return Response(
|
||||
{'error': '服务器内部错误'},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
return error_response(
|
||||
code=ErrorCodes.SERVER_ERROR,
|
||||
message='Server internal error',
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
return Response({
|
||||
'message': '批量创建完成',
|
||||
return success_response(data={
|
||||
'createdCount': created_count,
|
||||
}, status=status.HTTP_200_OK)
|
||||
})
|
||||
|
||||
@action(detail=False, methods=['get'], url_path='export')
|
||||
def export(self, request, **kwargs):
|
||||
@@ -426,35 +439,38 @@ class DirectoryViewSet(viewsets.ModelViewSet):
|
||||
|
||||
响应:
|
||||
{
|
||||
"message": "批量创建完成",
|
||||
"createdCount": 10,
|
||||
"mismatchedCount": 2
|
||||
"data": {
|
||||
"createdCount": 10
|
||||
}
|
||||
}
|
||||
"""
|
||||
from apps.targets.models import Target
|
||||
|
||||
target_pk = self.kwargs.get('target_pk')
|
||||
if not target_pk:
|
||||
return Response(
|
||||
{'error': '必须在目标下批量创建目录'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='Must create directories under a target',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# 获取目标
|
||||
try:
|
||||
target = Target.objects.get(pk=target_pk)
|
||||
except Target.DoesNotExist:
|
||||
return Response(
|
||||
{'error': '目标不存在'},
|
||||
status=status.HTTP_404_NOT_FOUND
|
||||
return error_response(
|
||||
code=ErrorCodes.NOT_FOUND,
|
||||
message='Target not found',
|
||||
status_code=status.HTTP_404_NOT_FOUND
|
||||
)
|
||||
|
||||
# 获取请求体中的 URL 列表
|
||||
urls = request.data.get('urls', [])
|
||||
if not urls or not isinstance(urls, list):
|
||||
return Response(
|
||||
{'error': '请求体不能为空或格式错误'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='Request body cannot be empty or invalid format',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# 调用 service 层处理
|
||||
@@ -467,15 +483,15 @@ class DirectoryViewSet(viewsets.ModelViewSet):
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("批量创建目录失败")
|
||||
return Response(
|
||||
{'error': '服务器内部错误'},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
return error_response(
|
||||
code=ErrorCodes.SERVER_ERROR,
|
||||
message='Server internal error',
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
return Response({
|
||||
'message': '批量创建完成',
|
||||
return success_response(data={
|
||||
'createdCount': created_count,
|
||||
}, status=status.HTTP_200_OK)
|
||||
})
|
||||
|
||||
@action(detail=False, methods=['get'], url_path='export')
|
||||
def export(self, request, **kwargs):
|
||||
@@ -519,6 +535,7 @@ class EndpointViewSet(viewsets.ModelViewSet):
|
||||
- host="example" 主机名模糊匹配
|
||||
- title="login" 标题模糊匹配
|
||||
- status="200,301" 状态码多值匹配
|
||||
- tech="nginx" 技术栈匹配(数组字段)
|
||||
- 多条件空格分隔 AND 关系
|
||||
"""
|
||||
|
||||
@@ -553,35 +570,38 @@ class EndpointViewSet(viewsets.ModelViewSet):
|
||||
|
||||
响应:
|
||||
{
|
||||
"message": "批量创建完成",
|
||||
"createdCount": 10,
|
||||
"mismatchedCount": 2
|
||||
"data": {
|
||||
"createdCount": 10
|
||||
}
|
||||
}
|
||||
"""
|
||||
from apps.targets.models import Target
|
||||
|
||||
target_pk = self.kwargs.get('target_pk')
|
||||
if not target_pk:
|
||||
return Response(
|
||||
{'error': '必须在目标下批量创建端点'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='Must create endpoints under a target',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# 获取目标
|
||||
try:
|
||||
target = Target.objects.get(pk=target_pk)
|
||||
except Target.DoesNotExist:
|
||||
return Response(
|
||||
{'error': '目标不存在'},
|
||||
status=status.HTTP_404_NOT_FOUND
|
||||
return error_response(
|
||||
code=ErrorCodes.NOT_FOUND,
|
||||
message='Target not found',
|
||||
status_code=status.HTTP_404_NOT_FOUND
|
||||
)
|
||||
|
||||
# 获取请求体中的 URL 列表
|
||||
urls = request.data.get('urls', [])
|
||||
if not urls or not isinstance(urls, list):
|
||||
return Response(
|
||||
{'error': '请求体不能为空或格式错误'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='Request body cannot be empty or invalid format',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# 调用 service 层处理
|
||||
@@ -594,15 +614,15 @@ class EndpointViewSet(viewsets.ModelViewSet):
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("批量创建端点失败")
|
||||
return Response(
|
||||
{'error': '服务器内部错误'},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
return error_response(
|
||||
code=ErrorCodes.SERVER_ERROR,
|
||||
message='Server internal error',
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
return Response({
|
||||
'message': '批量创建完成',
|
||||
return success_response(data={
|
||||
'createdCount': created_count,
|
||||
}, status=status.HTTP_200_OK)
|
||||
})
|
||||
|
||||
@action(detail=False, methods=['get'], url_path='export')
|
||||
def export(self, request, **kwargs):
|
||||
|
||||
@@ -40,8 +40,14 @@ def fetch_config_and_setup_django():
|
||||
print(f"[CONFIG] 正在从配置中心获取配置: {config_url}")
|
||||
print(f"[CONFIG] IS_LOCAL={is_local}")
|
||||
try:
|
||||
# 构建请求头(包含 Worker API Key)
|
||||
headers = {}
|
||||
worker_api_key = os.environ.get("WORKER_API_KEY", "")
|
||||
if worker_api_key:
|
||||
headers["X-Worker-API-Key"] = worker_api_key
|
||||
|
||||
# verify=False: 远程 Worker 通过 HTTPS 访问时可能使用自签名证书
|
||||
resp = requests.get(config_url, timeout=10, verify=False)
|
||||
resp = requests.get(config_url, headers=headers, timeout=10, verify=False)
|
||||
resp.raise_for_status()
|
||||
config = resp.json()
|
||||
|
||||
@@ -57,9 +63,6 @@ def fetch_config_and_setup_django():
|
||||
os.environ.setdefault("DB_USER", db_user)
|
||||
os.environ.setdefault("DB_PASSWORD", config['db']['password'])
|
||||
|
||||
# Redis 配置
|
||||
os.environ.setdefault("REDIS_URL", config['redisUrl'])
|
||||
|
||||
# 日志配置
|
||||
os.environ.setdefault("LOG_DIR", config['paths']['logs'])
|
||||
os.environ.setdefault("LOG_LEVEL", config['logging']['level'])
|
||||
@@ -71,7 +74,6 @@ def fetch_config_and_setup_django():
|
||||
print(f"[CONFIG] DB_PORT: {db_port}")
|
||||
print(f"[CONFIG] DB_NAME: {db_name}")
|
||||
print(f"[CONFIG] DB_USER: {db_user}")
|
||||
print(f"[CONFIG] REDIS_URL: {config['redisUrl']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR] 获取配置失败: {config_url} - {e}", file=sys.stderr)
|
||||
|
||||
31
backend/apps/common/error_codes.py
Normal file
31
backend/apps/common/error_codes.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
标准化错误码定义
|
||||
|
||||
采用简化方案(参考 Stripe、GitHub 等大厂做法):
|
||||
- 只定义 5-10 个通用错误码
|
||||
- 未知错误使用通用错误码
|
||||
- 错误码格式:大写字母和下划线组成
|
||||
"""
|
||||
|
||||
|
||||
class ErrorCodes:
    """Standardized, machine-readable API error codes.

    Only a small set of generic codes is defined; any error without a
    dedicated code is reported with one of these generic codes.

    Format conventions:
    - Upper-case letters and underscores only.
    - Short and self-explanatory.
    - The frontend maps each code to an i18n key.
    """

    # Generic error codes (8 in total)
    VALIDATION_ERROR = 'VALIDATION_ERROR'    # input validation failed
    NOT_FOUND = 'NOT_FOUND'                  # resource not found
    PERMISSION_DENIED = 'PERMISSION_DENIED'  # insufficient permissions
    SERVER_ERROR = 'SERVER_ERROR'            # internal server error
    BAD_REQUEST = 'BAD_REQUEST'              # malformed request
    CONFLICT = 'CONFLICT'                    # resource conflict (e.g. duplicate creation)
    UNAUTHORIZED = 'UNAUTHORIZED'            # not authenticated
    RATE_LIMITED = 'RATE_LIMITED'            # too many requests
|
||||
49
backend/apps/common/exception_handlers.py
Normal file
49
backend/apps/common/exception_handlers.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
自定义异常处理器
|
||||
|
||||
统一处理 DRF 异常,确保错误响应格式一致
|
||||
"""
|
||||
|
||||
from rest_framework.views import exception_handler
|
||||
from rest_framework import status
|
||||
from rest_framework.exceptions import AuthenticationFailed, NotAuthenticated
|
||||
|
||||
from apps.common.response_helpers import error_response
|
||||
from apps.common.error_codes import ErrorCodes
|
||||
|
||||
|
||||
def custom_exception_handler(exc, context):
    """Return authentication/permission failures in the unified error format.

    Args:
        exc: The exception raised while handling the request.
        context: The DRF handler context, passed through unchanged to the
            default ``exception_handler``.

    Returns:
        A unified ``error_response`` for authentication failures (401) and
        permission failures (403); otherwise whatever DRF's default handler
        produced, which is ``None`` for exceptions it does not handle.
    """
    # Run DRF's default handler first so every other exception keeps its
    # standard treatment.
    response = exception_handler(exc, context)

    # Classify by exception type BEFORE looking at the status code. DRF
    # rewrites NotAuthenticated/AuthenticationFailed to 403 when the
    # authenticator supplies no WWW-Authenticate header (e.g. session
    # authentication), so a status-first check would report an
    # unauthenticated request as PERMISSION_DENIED and never reach here.
    if isinstance(exc, (NotAuthenticated, AuthenticationFailed)):
        return error_response(
            code=ErrorCodes.UNAUTHORIZED,
            message='Authentication required',
            status_code=status.HTTP_401_UNAUTHORIZED
        )

    if response is None:
        # Not a DRF-handled exception; let the framework deal with it.
        return None

    # 401 produced by any other exception type.
    if response.status_code == status.HTTP_401_UNAUTHORIZED:
        return error_response(
            code=ErrorCodes.UNAUTHORIZED,
            message='Authentication required',
            status_code=status.HTTP_401_UNAUTHORIZED
        )

    # Genuine permission failure for an authenticated caller.
    if response.status_code == status.HTTP_403_FORBIDDEN:
        return error_response(
            code=ErrorCodes.PERMISSION_DENIED,
            message='Permission denied',
            status_code=status.HTTP_403_FORBIDDEN
        )

    return response
|
||||
80
backend/apps/common/permissions.py
Normal file
80
backend/apps/common/permissions.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
集中式权限管理
|
||||
|
||||
实现三类端点的认证逻辑:
|
||||
1. 公开端点(无需认证):登录、登出、获取当前用户状态
|
||||
2. Worker 端点(API Key 认证):注册、配置、心跳、回调、资源同步
|
||||
3. 业务端点(Session 认证):其他所有 API
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from django.conf import settings
|
||||
from rest_framework.permissions import BasePermission
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Public endpoint whitelist (no authentication of any kind required).
PUBLIC_ENDPOINTS = [
    r'^/api/auth/login/$',
    r'^/api/auth/logout/$',
    r'^/api/auth/me/$',
]

# Worker API endpoints (require API-key authentication).
# Covers: registration, config, heartbeat, callbacks, resource sync
# (wordlist download).
WORKER_ENDPOINTS = [
    r'^/api/workers/register/$',
    r'^/api/workers/config/$',
    r'^/api/workers/\d+/heartbeat/$',
    # Prefix match on purpose: no trailing `$`, so every callback sub-path matches.
    r'^/api/callbacks/',
    # Resource sync endpoint (workers need to download wordlist files).
    r'^/api/wordlists/download/$',
    # Note: the fingerprint export API uses session authentication (it is for
    # frontend users). Workers read fingerprint data directly from the
    # database and do not need an HTTP API for it.
]
|
||||
|
||||
|
||||
class IsAuthenticatedOrPublic(BasePermission):
    """Route each request to one of three access policies by URL path.

    - Paths matching ``PUBLIC_ENDPOINTS`` are allowed unconditionally.
    - Paths matching ``WORKER_ENDPOINTS`` require a valid
      ``X-Worker-API-Key`` header.
    - Every other path requires an authenticated ``request.user``.
    """

    def has_permission(self, request, view):
        """Return True when the request may proceed under the policy for its path."""
        path = request.path

        # Public whitelist: no credentials needed.
        if any(re.match(pattern, path) for pattern in PUBLIC_ENDPOINTS):
            return True

        # Worker endpoints: API key only.
        if any(re.match(pattern, path) for pattern in WORKER_ENDPOINTS):
            return self._check_worker_api_key(request)

        # Everything else needs an authenticated user. Coerce to bool so a
        # falsy user object is never returned as the permission result.
        return bool(request.user and request.user.is_authenticated)

    def _check_worker_api_key(self, request):
        """Validate the ``X-Worker-API-Key`` header against ``settings.WORKER_API_KEY``.

        Returns:
            bool: True only when a key is configured, the header is present,
            and both values are equal.
        """
        # Local import keeps this block self-contained without touching the
        # module's import section.
        import hmac

        api_key = request.headers.get('X-Worker-API-Key')
        expected_key = getattr(settings, 'WORKER_API_KEY', None)

        if not expected_key:
            # Fail closed: with no key configured, reject every worker request.
            logger.warning("WORKER_API_KEY 未配置,拒绝 Worker 请求")
            return False

        if not api_key:
            logger.warning("Worker 请求缺少 X-Worker-API-Key Header: %s", request.path)
            return False

        # Constant-time comparison so response timing does not leak how much
        # of the key matched. Compare bytes because compare_digest only
        # accepts ASCII-only str values.
        keys_match = hmac.compare_digest(
            str(api_key).encode('utf-8'),
            str(expected_key).encode('utf-8'),
        )
        if not keys_match:
            logger.warning("Worker API Key 无效: %s", request.path)
            return False

        return True
|
||||
88
backend/apps/common/response_helpers.py
Normal file
88
backend/apps/common/response_helpers.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""
|
||||
标准化 API 响应辅助函数
|
||||
|
||||
遵循行业标准(RFC 9457 Problem Details)和大厂实践(Google、Stripe、GitHub):
|
||||
- 成功响应只包含数据,不包含 message 字段
|
||||
- 错误响应使用机器可读的错误码,前端映射到 i18n 消息
|
||||
"""
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from rest_framework import status
|
||||
from rest_framework.response import Response
|
||||
|
||||
|
||||
def success_response(
    data: Optional[Union[Dict[str, Any], List[Any]]] = None,
    status_code: int = status.HTTP_200_OK
) -> Response:
    """Build a standardized success response.

    The payload is returned as-is, without an envelope or ``message`` field.

    Args:
        data: Response payload (dict or list). ``None`` is sent as ``{}``.
        status_code: HTTP status code, 200 by default.

    Returns:
        Response: A DRF ``Response`` whose ``.data`` is the payload.

    Examples:
        # Single resource
        >>> success_response(data={'id': 1, 'name': 'Test'}).data
        {'id': 1, 'name': 'Test'}

        # Empty list stays a list
        >>> success_response(data=[]).data
        []

        # Resource creation
        >>> success_response(data={'id': 1}, status_code=201).status_code
        201
    """
    # Do not use `data or {}`: an empty list [] would be turned into {}.
    if data is None:
        data = {}
    return Response(data, status=status_code)
|
||||
|
||||
|
||||
def error_response(
    code: str,
    message: Optional[str] = None,
    details: Optional[List[Dict[str, Any]]] = None,
    status_code: int = status.HTTP_400_BAD_REQUEST
) -> Response:
    """Build a standardized error response.

    The body has the shape ``{'error': {'code': ..., 'message': ...,
    'details': ...}}``; ``message`` and ``details`` are included only when
    they are truthy.

    Args:
        code: Error code such as 'VALIDATION_ERROR' or 'NOT_FOUND'
            (upper-case letters and underscores).
        message: Debugging information for developers (not shown to users).
        details: Detailed error information (e.g. field-level validation
            errors).
        status_code: HTTP status code, 400 by default.

    Returns:
        Response: A DRF ``Response`` carrying the error body.

    Examples:
        # Simple error
        >>> error_response(code='NOT_FOUND', status_code=404).data
        {'error': {'code': 'NOT_FOUND'}}

        # With debugging information
        >>> error_response(
        ...     code='VALIDATION_ERROR',
        ...     message='Invalid input data',
        ...     details=[{'field': 'name', 'message': 'Required'}]
        ... ).data['error']['code']
        'VALIDATION_ERROR'
    """
    error_body: Dict[str, Any] = {'code': code}

    if message:
        error_body['message'] = message

    if details:
        error_body['details'] = details

    return Response({'error': error_body}, status=status_code)
|
||||
@@ -4,15 +4,28 @@
|
||||
提供系统日志的读取功能,支持:
|
||||
- 从日志目录读取日志文件
|
||||
- 限制返回行数,防止内存溢出
|
||||
- 列出可用的日志文件
|
||||
"""
|
||||
|
||||
import fnmatch
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from datetime import datetime, timezone
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LogFileInfo(TypedDict):
    """Metadata describing one log file."""

    filename: str
    category: str  # 'system' | 'error' | 'performance' | 'container'
    size: int
    modifiedAt: str  # ISO 8601 format
|
||||
|
||||
|
||||
class SystemLogService:
|
||||
"""
|
||||
系统日志服务类
|
||||
@@ -20,23 +33,131 @@ class SystemLogService:
|
||||
负责读取系统日志文件,支持从容器内路径或宿主机挂载路径读取日志。
|
||||
"""
|
||||
|
||||
# 日志文件分类规则
|
||||
CATEGORY_RULES = [
|
||||
('xingrin.log', 'system'),
|
||||
('xingrin_error.log', 'error'),
|
||||
('performance.log', 'performance'),
|
||||
('container_*.log', 'container'),
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
# 日志文件路径(容器内路径,通过 volume 挂载到宿主机 /opt/xingrin/logs)
|
||||
self.log_file = "/app/backend/logs/xingrin.log"
|
||||
self.default_lines = 200 # 默认返回行数
|
||||
self.max_lines = 10000 # 最大返回行数限制
|
||||
self.timeout_seconds = 3 # tail 命令超时时间
|
||||
# 日志目录路径
|
||||
self.log_dir = "/opt/xingrin/logs"
|
||||
self.default_file = "xingrin.log" # 默认日志文件
|
||||
self.default_lines = 200 # 默认返回行数
|
||||
self.max_lines = 10000 # 最大返回行数限制
|
||||
self.timeout_seconds = 3 # tail 命令超时时间
|
||||
|
||||
def get_logs_content(self, lines: int | None = None) -> str:
|
||||
def _categorize_file(self, filename: str) -> str | None:
|
||||
"""
|
||||
根据文件名判断日志分类
|
||||
|
||||
Returns:
|
||||
分类名称,如果不是日志文件则返回 None
|
||||
"""
|
||||
for pattern, category in self.CATEGORY_RULES:
|
||||
if fnmatch.fnmatch(filename, pattern):
|
||||
return category
|
||||
return None
|
||||
|
||||
def _validate_filename(self, filename: str) -> bool:
    """
    Check that a requested file name is safe to read (guards against path traversal).

    Args:
        filename: The file name to validate.

    Returns:
        bool: True only when the name contains no path separator, no ``..``
        sequence, and maps to a known log category.
    """
    # Path separators and parent-directory sequences are never allowed.
    forbidden_tokens = ('/', '\\', '..')
    if any(token in filename for token in forbidden_tokens):
        return False

    # The name must also belong to one of the known log file types.
    return self._categorize_file(filename) is not None
|
||||
|
||||
def get_log_files(self) -> list[LogFileInfo]:
    """
    List the recognised log files found in the log directory.

    Returns:
        One entry per regular file in ``self.log_dir`` whose name maps to a
        known category, sorted by category priority and then by file name.
        The list is empty when the log directory does not exist.
    """
    if not os.path.isdir(self.log_dir):
        logger.warning("日志目录不存在: %s", self.log_dir)
        return []

    # Category priority: system > error > performance > container;
    # any other category sorts after these.
    priority = {'system': 0, 'error': 1, 'performance': 2, 'container': 3}

    def sort_key(entry):
        return priority.get(entry['category'], 99), entry['filename']

    found: list[LogFileInfo] = []
    for name in os.listdir(self.log_dir):
        full_path = os.path.join(self.log_dir, name)

        # Only regular files are listed; directories are skipped.
        if not os.path.isfile(full_path):
            continue

        # Files that match no category rule are not log files.
        kind = self._categorize_file(name)
        if kind is None:
            continue

        try:
            info = os.stat(full_path)
            mtime_iso = datetime.fromtimestamp(
                info.st_mtime, tz=timezone.utc
            ).isoformat()
        except OSError as e:
            # A file that cannot be inspected is logged and left out.
            logger.warning("获取文件信息失败 %s: %s", full_path, e)
            continue

        found.append({
            'filename': name,
            'category': kind,
            'size': info.st_size,
            'modifiedAt': mtime_iso,
        })

    found.sort(key=sort_key)
    return found
|
||||
|
||||
def get_logs_content(self, filename: str | None = None, lines: int | None = None) -> str:
|
||||
"""
|
||||
获取系统日志内容
|
||||
|
||||
Args:
|
||||
filename: 日志文件名,默认为 xingrin.log
|
||||
lines: 返回的日志行数,默认 200 行,最大 10000 行
|
||||
|
||||
Returns:
|
||||
str: 日志内容,每行以换行符分隔,保持原始顺序
|
||||
|
||||
Raises:
|
||||
ValueError: 文件名不合法
|
||||
FileNotFoundError: 日志文件不存在
|
||||
"""
|
||||
# 文件名处理
|
||||
if filename is None:
|
||||
filename = self.default_file
|
||||
|
||||
# 验证文件名
|
||||
if not self._validate_filename(filename):
|
||||
raise ValueError(f"无效的文件名: {filename}")
|
||||
|
||||
# 构建完整路径
|
||||
log_file = os.path.join(self.log_dir, filename)
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.isfile(log_file):
|
||||
raise FileNotFoundError(f"日志文件不存在: {filename}")
|
||||
|
||||
# 参数校验和默认值处理
|
||||
if lines is None:
|
||||
lines = self.default_lines
|
||||
@@ -48,7 +169,7 @@ class SystemLogService:
|
||||
lines = self.max_lines
|
||||
|
||||
# 使用 tail 命令读取日志文件末尾内容
|
||||
cmd = ["tail", "-n", str(lines), self.log_file]
|
||||
cmd = ["tail", "-n", str(lines), log_file]
|
||||
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
"""
|
||||
|
||||
from django.urls import path
|
||||
from .views import LoginView, LogoutView, MeView, ChangePasswordView, SystemLogsView
|
||||
from .views import LoginView, LogoutView, MeView, ChangePasswordView, SystemLogsView, SystemLogFilesView
|
||||
|
||||
urlpatterns = [
|
||||
# 认证相关
|
||||
@@ -18,4 +18,5 @@ urlpatterns = [
|
||||
|
||||
# 系统管理
|
||||
path('system/logs/', SystemLogsView.as_view(), name='system-logs'),
|
||||
path('system/logs/files/', SystemLogFilesView.as_view(), name='system-log-files'),
|
||||
]
|
||||
|
||||
@@ -132,7 +132,8 @@ class QueryBuilder:
|
||||
cls,
|
||||
queryset: QuerySet,
|
||||
filter_groups: List[FilterGroup],
|
||||
field_mapping: Dict[str, str]
|
||||
field_mapping: Dict[str, str],
|
||||
json_array_fields: List[str] = None
|
||||
) -> QuerySet:
|
||||
"""构建 Django ORM 查询
|
||||
|
||||
@@ -140,6 +141,7 @@ class QueryBuilder:
|
||||
queryset: Django QuerySet
|
||||
filter_groups: 解析后的过滤条件组列表
|
||||
field_mapping: 字段映射
|
||||
json_array_fields: JSON 数组字段列表(使用 __contains 查询)
|
||||
|
||||
Returns:
|
||||
过滤后的 QuerySet
|
||||
@@ -147,6 +149,8 @@ class QueryBuilder:
|
||||
if not filter_groups:
|
||||
return queryset
|
||||
|
||||
json_array_fields = json_array_fields or []
|
||||
|
||||
# 构建 Q 对象
|
||||
combined_q = None
|
||||
|
||||
@@ -159,8 +163,11 @@ class QueryBuilder:
|
||||
logger.debug(f"忽略未知字段: {f.field}")
|
||||
continue
|
||||
|
||||
# 判断是否为 JSON 数组字段
|
||||
is_json_array = db_field in json_array_fields
|
||||
|
||||
# 构建单个条件的 Q 对象
|
||||
q = cls._build_single_q(db_field, f.operator, f.value)
|
||||
q = cls._build_single_q(db_field, f.operator, f.value, is_json_array)
|
||||
if q is None:
|
||||
continue
|
||||
|
||||
@@ -177,8 +184,12 @@ class QueryBuilder:
|
||||
return queryset
|
||||
|
||||
@classmethod
|
||||
def _build_single_q(cls, field: str, operator: str, value: str) -> Optional[Q]:
|
||||
def _build_single_q(cls, field: str, operator: str, value: str, is_json_array: bool = False) -> Optional[Q]:
|
||||
"""构建单个条件的 Q 对象"""
|
||||
if is_json_array:
|
||||
# JSON 数组字段使用 __contains 查询
|
||||
return Q(**{f'{field}__contains': [value]})
|
||||
|
||||
if operator == '!=':
|
||||
return cls._build_not_equal_q(field, value)
|
||||
elif operator == '==':
|
||||
@@ -219,7 +230,8 @@ class QueryBuilder:
|
||||
def apply_filters(
|
||||
queryset: QuerySet,
|
||||
query_string: str,
|
||||
field_mapping: Dict[str, str]
|
||||
field_mapping: Dict[str, str],
|
||||
json_array_fields: List[str] = None
|
||||
) -> QuerySet:
|
||||
"""应用过滤条件到 QuerySet
|
||||
|
||||
@@ -227,6 +239,7 @@ def apply_filters(
|
||||
queryset: Django QuerySet
|
||||
query_string: 查询语法字符串
|
||||
field_mapping: 字段映射
|
||||
json_array_fields: JSON 数组字段列表(使用 __contains 查询)
|
||||
|
||||
Returns:
|
||||
过滤后的 QuerySet
|
||||
@@ -242,6 +255,9 @@ def apply_filters(
|
||||
|
||||
# 混合查询
|
||||
apply_filters(qs, 'type="xss" || type="sqli" && severity="high"', mapping)
|
||||
|
||||
# JSON 数组字段查询
|
||||
apply_filters(qs, 'implies="PHP"', mapping, json_array_fields=['implies'])
|
||||
"""
|
||||
if not query_string or not query_string.strip():
|
||||
return queryset
|
||||
@@ -253,7 +269,12 @@ def apply_filters(
|
||||
return queryset
|
||||
|
||||
logger.debug(f"解析过滤条件: {filter_groups}")
|
||||
return QueryBuilder.build_query(queryset, filter_groups, field_mapping)
|
||||
return QueryBuilder.build_query(
|
||||
queryset,
|
||||
filter_groups,
|
||||
field_mapping,
|
||||
json_array_fields=json_array_fields
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"过滤解析错误: {e}, query: {query_string}")
|
||||
|
||||
@@ -7,6 +7,6 @@
|
||||
"""
|
||||
|
||||
from .auth_views import LoginView, LogoutView, MeView, ChangePasswordView
|
||||
from .system_log_views import SystemLogsView
|
||||
from .system_log_views import SystemLogsView, SystemLogFilesView
|
||||
|
||||
__all__ = ['LoginView', 'LogoutView', 'MeView', 'ChangePasswordView', 'SystemLogsView']
|
||||
__all__ = ['LoginView', 'LogoutView', 'MeView', 'ChangePasswordView', 'SystemLogsView', 'SystemLogFilesView']
|
||||
|
||||
@@ -9,7 +9,10 @@ from django.utils.decorators import method_decorator
|
||||
from rest_framework import status
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.permissions import AllowAny, IsAuthenticated
|
||||
from rest_framework.permissions import AllowAny
|
||||
|
||||
from apps.common.response_helpers import success_response, error_response
|
||||
from apps.common.error_codes import ErrorCodes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -28,9 +31,10 @@ class LoginView(APIView):
|
||||
password = request.data.get('password')
|
||||
|
||||
if not username or not password:
|
||||
return Response(
|
||||
{'error': '请提供用户名和密码'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='Username and password are required',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
user = authenticate(request, username=username, password=password)
|
||||
@@ -38,20 +42,22 @@ class LoginView(APIView):
|
||||
if user is not None:
|
||||
login(request, user)
|
||||
logger.info(f"用户 {username} 登录成功")
|
||||
return Response({
|
||||
'message': '登录成功',
|
||||
'user': {
|
||||
'id': user.id,
|
||||
'username': user.username,
|
||||
'isStaff': user.is_staff,
|
||||
'isSuperuser': user.is_superuser,
|
||||
return success_response(
|
||||
data={
|
||||
'user': {
|
||||
'id': user.id,
|
||||
'username': user.username,
|
||||
'isStaff': user.is_staff,
|
||||
'isSuperuser': user.is_superuser,
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
else:
|
||||
logger.warning(f"用户 {username} 登录失败:用户名或密码错误")
|
||||
return Response(
|
||||
{'error': '用户名或密码错误'},
|
||||
status=status.HTTP_401_UNAUTHORIZED
|
||||
return error_response(
|
||||
code=ErrorCodes.UNAUTHORIZED,
|
||||
message='Invalid username or password',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED
|
||||
)
|
||||
|
||||
|
||||
@@ -79,7 +85,7 @@ class LogoutView(APIView):
|
||||
logout(request)
|
||||
else:
|
||||
logout(request)
|
||||
return Response({'message': '已登出'})
|
||||
return success_response()
|
||||
|
||||
|
||||
@method_decorator(csrf_exempt, name='dispatch')
|
||||
@@ -100,22 +106,26 @@ class MeView(APIView):
|
||||
if user_id:
|
||||
try:
|
||||
user = User.objects.get(pk=user_id)
|
||||
return Response({
|
||||
'authenticated': True,
|
||||
'user': {
|
||||
'id': user.id,
|
||||
'username': user.username,
|
||||
'isStaff': user.is_staff,
|
||||
'isSuperuser': user.is_superuser,
|
||||
return success_response(
|
||||
data={
|
||||
'authenticated': True,
|
||||
'user': {
|
||||
'id': user.id,
|
||||
'username': user.username,
|
||||
'isStaff': user.is_staff,
|
||||
'isSuperuser': user.is_superuser,
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
except User.DoesNotExist:
|
||||
pass
|
||||
|
||||
return Response({
|
||||
'authenticated': False,
|
||||
'user': None
|
||||
})
|
||||
return success_response(
|
||||
data={
|
||||
'authenticated': False,
|
||||
'user': None
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@method_decorator(csrf_exempt, name='dispatch')
|
||||
@@ -124,43 +134,27 @@ class ChangePasswordView(APIView):
|
||||
修改密码
|
||||
POST /api/auth/change-password/
|
||||
"""
|
||||
authentication_classes = [] # 禁用认证(绕过 CSRF)
|
||||
permission_classes = [AllowAny] # 手动检查登录状态
|
||||
|
||||
def post(self, request):
|
||||
# 手动检查登录状态(从 session 获取用户)
|
||||
from django.contrib.auth import get_user_model
|
||||
User = get_user_model()
|
||||
|
||||
user_id = request.session.get('_auth_user_id')
|
||||
if not user_id:
|
||||
return Response(
|
||||
{'error': '请先登录'},
|
||||
status=status.HTTP_401_UNAUTHORIZED
|
||||
)
|
||||
|
||||
try:
|
||||
user = User.objects.get(pk=user_id)
|
||||
except User.DoesNotExist:
|
||||
return Response(
|
||||
{'error': '用户不存在'},
|
||||
status=status.HTTP_401_UNAUTHORIZED
|
||||
)
|
||||
# 使用全局权限类验证,request.user 已经是认证用户
|
||||
user = request.user
|
||||
|
||||
# CamelCaseParser 将 oldPassword -> old_password
|
||||
old_password = request.data.get('old_password')
|
||||
new_password = request.data.get('new_password')
|
||||
|
||||
if not old_password or not new_password:
|
||||
return Response(
|
||||
{'error': '请提供旧密码和新密码'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='Old password and new password are required',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
if not user.check_password(old_password):
|
||||
return Response(
|
||||
{'error': '旧密码错误'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='Old password is incorrect',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
user.set_password(new_password)
|
||||
@@ -170,4 +164,4 @@ class ChangePasswordView(APIView):
|
||||
update_session_auth_hash(request, user)
|
||||
|
||||
logger.info(f"用户 {user.username} 已修改密码")
|
||||
return Response({'message': '密码修改成功'})
|
||||
return success_response()
|
||||
|
||||
@@ -9,16 +9,57 @@ import logging
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
from rest_framework import status
|
||||
from rest_framework.permissions import AllowAny
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from apps.common.response_helpers import success_response, error_response
|
||||
from apps.common.error_codes import ErrorCodes
|
||||
from apps.common.services.system_log_service import SystemLogService
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@method_decorator(csrf_exempt, name="dispatch")
class SystemLogFilesView(APIView):
    """
    API view that lists the available log files.

    GET /api/system/logs/files/
    Returns every log file known to SystemLogService.

    Response:
    {
        "files": [
            {
                "filename": "xingrin.log",
                "category": "system",
                "size": 1048576,
                "modifiedAt": "2025-01-15T10:30:00+00:00"
            },
            ...
        ]
    }
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Each view instance gets its own service object.
        self.service = SystemLogService()

    def get(self, request):
        """Return the list of log files, or a 500 error response on any failure."""
        try:
            payload = {"files": self.service.get_log_files()}
            return success_response(data=payload)
        except Exception:
            # Log the traceback and answer with a generic server error.
            logger.exception("获取日志文件列表失败")
            return error_response(
                code=ErrorCodes.SERVER_ERROR,
                message='Failed to get log files',
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
|
||||
|
||||
|
||||
@method_decorator(csrf_exempt, name="dispatch")
|
||||
class SystemLogsView(APIView):
|
||||
"""
|
||||
@@ -28,21 +69,14 @@ class SystemLogsView(APIView):
|
||||
获取系统日志内容
|
||||
|
||||
Query Parameters:
|
||||
file (str, optional): 日志文件名,默认 xingrin.log
|
||||
lines (int, optional): 返回的日志行数,默认 200,最大 10000
|
||||
|
||||
Response:
|
||||
{
|
||||
"content": "日志内容字符串..."
|
||||
}
|
||||
|
||||
Note:
|
||||
- 当前为开发阶段,暂时允许匿名访问
|
||||
- 生产环境应添加管理员权限验证
|
||||
"""
|
||||
|
||||
# TODO: 生产环境应改为 IsAdminUser 权限
|
||||
authentication_classes = []
|
||||
permission_classes = [AllowAny]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -52,18 +86,33 @@ class SystemLogsView(APIView):
|
||||
"""
|
||||
获取系统日志
|
||||
|
||||
支持通过 lines 参数控制返回行数,用于前端分页或实时刷新场景。
|
||||
支持通过 file 和 lines 参数控制返回内容。
|
||||
"""
|
||||
try:
|
||||
# 解析 lines 参数
|
||||
# 解析参数
|
||||
filename = request.query_params.get("file")
|
||||
lines_raw = request.query_params.get("lines")
|
||||
lines = int(lines_raw) if lines_raw is not None else None
|
||||
|
||||
# 调用服务获取日志内容
|
||||
content = self.service.get_logs_content(lines=lines)
|
||||
return Response({"content": content})
|
||||
except ValueError:
|
||||
return Response({"error": "lines 参数必须是整数"}, status=status.HTTP_400_BAD_REQUEST)
|
||||
content = self.service.get_logs_content(filename=filename, lines=lines)
|
||||
return success_response(data={"content": content})
|
||||
except ValueError as e:
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message=str(e) if 'file' in str(e).lower() else 'lines must be an integer',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
except FileNotFoundError as e:
|
||||
return error_response(
|
||||
code=ErrorCodes.NOT_FOUND,
|
||||
message=str(e),
|
||||
status_code=status.HTTP_404_NOT_FOUND
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("获取系统日志失败")
|
||||
return Response({"error": "获取系统日志失败"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
return error_response(
|
||||
code=ErrorCodes.SERVER_ERROR,
|
||||
message='Failed to get system logs',
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
44
backend/apps/common/websocket_auth.py
Normal file
44
backend/apps/common/websocket_auth.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""
|
||||
WebSocket 认证基类
|
||||
|
||||
提供需要认证的 WebSocket Consumer 基类
|
||||
"""
|
||||
|
||||
import logging
|
||||
from channels.generic.websocket import AsyncWebsocketConsumer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthenticatedWebsocketConsumer(AsyncWebsocketConsumer):
    """
    Base class for WebSocket consumers that require an authenticated user.

    Subclasses should override on_connect() with their own connection logic.
    """

    async def connect(self):
        """
        Check the user's authentication state when a connection is opened.

        Authenticated users are handed to on_connect(); everyone else is
        rejected with close(code=4001).
        """
        user = self.scope.get('user')

        if user and user.is_authenticated:
            # Authenticated: run the subclass connection logic.
            await self.on_connect()
            return

        logger.warning(
            f"WebSocket 连接被拒绝:用户未认证 - Path: {self.scope.get('path')}"
        )
        await self.close(code=4001)

    async def on_connect(self):
        """
        Hook for subclass-specific connection logic.

        The default implementation simply accepts the connection.
        """
        await self.accept()
|
||||
@@ -6,17 +6,17 @@ import json
|
||||
import logging
|
||||
import asyncio
|
||||
import os
|
||||
from channels.generic.websocket import AsyncWebsocketConsumer
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
from apps.common.websocket_auth import AuthenticatedWebsocketConsumer
|
||||
from apps.engine.services import WorkerService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkerDeployConsumer(AsyncWebsocketConsumer):
|
||||
class WorkerDeployConsumer(AuthenticatedWebsocketConsumer):
|
||||
"""
|
||||
Worker 交互式终端 WebSocket Consumer
|
||||
|
||||
@@ -31,8 +31,8 @@ class WorkerDeployConsumer(AsyncWebsocketConsumer):
|
||||
self.read_task = None
|
||||
self.worker_service = WorkerService()
|
||||
|
||||
async def connect(self):
|
||||
"""连接时加入对应 Worker 的组并自动建立 SSH 连接"""
|
||||
async def on_connect(self):
|
||||
"""连接时加入对应 Worker 的组并自动建立 SSH 连接(已通过认证)"""
|
||||
self.worker_id = self.scope['url_route']['kwargs']['worker_id']
|
||||
self.group_name = f'worker_deploy_{self.worker_id}'
|
||||
|
||||
@@ -242,8 +242,9 @@ class WorkerDeployConsumer(AsyncWebsocketConsumer):
|
||||
return
|
||||
|
||||
# 远程 Worker 通过 nginx HTTPS 访问(nginx 反代到后端 8888)
|
||||
# 使用 https://{PUBLIC_HOST} 而不是直连 8888 端口
|
||||
heartbeat_api_url = f"https://{public_host}" # 基础 URL,agent 会加 /api/...
|
||||
# 使用 https://{PUBLIC_HOST}:{PUBLIC_PORT} 而不是直连 8888 端口
|
||||
public_port = getattr(settings, 'PUBLIC_PORT', '8083')
|
||||
heartbeat_api_url = f"https://{public_host}:{public_port}"
|
||||
|
||||
session_name = f'xingrin_deploy_{self.worker_id}'
|
||||
remote_script_path = '/tmp/xingrin_deploy.sh'
|
||||
|
||||
@@ -90,6 +90,7 @@ class Command(BaseCommand):
|
||||
single_config,
|
||||
sort_keys=False,
|
||||
allow_unicode=True,
|
||||
default_flow_style=None,
|
||||
)
|
||||
except yaml.YAMLError as e:
|
||||
self.stdout.write(self.style.ERROR(f'生成子引擎 {scan_type} 配置失败: {e}'))
|
||||
|
||||
205
backend/apps/engine/management/commands/init_fingerprints.py
Normal file
205
backend/apps/engine/management/commands/init_fingerprints.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""初始化内置指纹库
|
||||
|
||||
- EHole 指纹: ehole.json -> 导入到数据库
|
||||
- Goby 指纹: goby.json -> 导入到数据库
|
||||
- Wappalyzer 指纹: wappalyzer.json -> 导入到数据库
|
||||
- Fingers 指纹: fingers_http.json -> 导入到数据库
|
||||
- FingerPrintHub 指纹: fingerprinthub_web.json -> 导入到数据库
|
||||
- ARL 指纹: ARL.yaml -> 导入到数据库
|
||||
|
||||
可重复执行:如果数据库已有数据则跳过,只在空库时导入。
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from django.conf import settings
|
||||
from django.core.management.base import BaseCommand
|
||||
|
||||
from apps.engine.models import (
|
||||
EholeFingerprint,
|
||||
GobyFingerprint,
|
||||
WappalyzerFingerprint,
|
||||
FingersFingerprint,
|
||||
FingerPrintHubFingerprint,
|
||||
ARLFingerprint,
|
||||
)
|
||||
from apps.engine.services.fingerprints import (
|
||||
EholeFingerprintService,
|
||||
GobyFingerprintService,
|
||||
WappalyzerFingerprintService,
|
||||
FingersFingerprintService,
|
||||
FingerPrintHubFingerprintService,
|
||||
ARLFingerprintService,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Built-in fingerprint sources.
# Each row is: (type, filename, model, service, data_key, file_format).
# data_key is the top-level key that holds the fingerprints, or None when the
# whole document is used directly.
DEFAULT_FINGERPRINTS = [
    {
        "type": fp_type,
        "filename": filename,
        "model": model,
        "service": service,
        "data_key": data_key,
        "file_format": file_format,
    }
    for fp_type, filename, model, service, data_key, file_format in (
        # EHole: fingerprints live in the "fingerprint" array of the JSON.
        ("ehole", "ehole.json", EholeFingerprint, EholeFingerprintService, "fingerprint", "json"),
        # Goby: array format, the whole JSON document is used.
        ("goby", "goby.json", GobyFingerprint, GobyFingerprintService, None, "json"),
        # Wappalyzer: fingerprints live under the "apps" object.
        ("wappalyzer", "wappalyzer.json", WappalyzerFingerprint, WappalyzerFingerprintService, "apps", "json"),
        # Fingers: array format.
        ("fingers", "fingers_http.json", FingersFingerprint, FingersFingerprintService, None, "json"),
        # FingerPrintHub: array format.
        ("fingerprinthub", "fingerprinthub_web.json", FingerPrintHubFingerprint, FingerPrintHubFingerprintService, None, "json"),
        # ARL: YAML array format.
        ("arl", "ARL.yaml", ARLFingerprint, ARLFingerprintService, None, "yaml"),
    )
]
|
||||
|
||||
|
||||
class Command(BaseCommand):
    """Management command that seeds the built-in fingerprint libraries.

    For every entry in DEFAULT_FINGERPRINTS the command skips the import when
    the corresponding table already has rows; otherwise it loads the bundled
    JSON/YAML file and bulk-creates the fingerprints through the service.
    """

    help = "初始化内置指纹库"

    def handle(self, *args, **options):
        """Import each built-in fingerprint file and print a summary."""
        project_base = Path(settings.BASE_DIR).parent  # /app/backend -> /app
        fingerprints_dir = project_base / "backend" / "fingerprints"

        initialized = 0
        skipped = 0
        failed = 0

        for item in DEFAULT_FINGERPRINTS:
            fp_type = item["type"]
            filename = item["filename"]
            model = item["model"]
            service_class = item["service"]
            data_key = item["data_key"]
            file_format = item.get("file_format", "json")

            # Only seed empty tables so the command can be re-run safely.
            existing_count = model.objects.count()
            if existing_count > 0:
                self.stdout.write(self.style.SUCCESS(
                    f"[{fp_type}] 数据库已有 {existing_count} 条记录,跳过初始化"
                ))
                skipped += 1
                continue

            # Locate the bundled source file.
            src_path = fingerprints_dir / filename
            if not src_path.exists():
                self.stdout.write(self.style.WARNING(
                    f"[{fp_type}] 未找到内置指纹文件: {src_path},跳过"
                ))
                failed += 1
                continue

            # Read and parse the file (JSON or YAML).
            try:
                with open(src_path, "r", encoding="utf-8") as f:
                    if file_format == "yaml":
                        file_data = yaml.safe_load(f)
                    else:
                        file_data = json.load(f)
            except (json.JSONDecodeError, yaml.YAMLError, OSError) as exc:
                self.stdout.write(self.style.ERROR(
                    f"[{fp_type}] 读取指纹文件失败: {exc}"
                ))
                failed += 1
                continue

            # Normalise the parsed document into a list of fingerprints.
            fingerprints = self._extract_fingerprints(file_data, data_key, fp_type)
            if not fingerprints:
                self.stdout.write(self.style.WARNING(
                    f"[{fp_type}] 指纹文件中没有有效数据,跳过"
                ))
                failed += 1
                continue

            # Bulk-import through the service layer.
            try:
                service = service_class()
                result = service.batch_create_fingerprints(fingerprints)
                created = result.get("created", 0)
                failed_count = result.get("failed", 0)

                self.stdout.write(self.style.SUCCESS(
                    f"[{fp_type}] 导入成功: 创建 {created} 条,失败 {failed_count} 条"
                ))
                initialized += 1
            except Exception as exc:
                # One failing source must not stop the remaining imports.
                self.stdout.write(self.style.ERROR(
                    f"[{fp_type}] 导入失败: {exc}"
                ))
                failed += 1
                continue

        self.stdout.write(self.style.SUCCESS(
            f"指纹初始化完成: 成功 {initialized}, 已存在跳过 {skipped}, 失败 {failed}"
        ))

    def _extract_fingerprints(self, json_data, data_key, fp_type):
        """
        Extract the fingerprint list from a parsed document.

        Supported shapes:
        - array:  [...] or {"key": [...]}
        - object: {...} or {"key": {...}} -> converted to [{"name": k, **v}]

        Args:
            json_data: The parsed JSON/YAML document.
            data_key: Top-level key holding the fingerprints, or None to use
                the whole document. The key "apps" also falls back to
                "technologies".
            fp_type: Fingerprint type label (not used by the extraction).

        Returns:
            A list of fingerprints; empty when the document does not have a
            usable shape.
        """
        if data_key is None:
            # Use the whole document.
            target = json_data
        elif not isinstance(json_data, dict):
            # A keyed lookup needs a mapping at the top level. A document
            # whose root is a list or None (e.g. an empty file) has no such
            # key, so report "no data" instead of failing on `.get`.
            return []
        elif data_key == "apps":
            # Two possible key names are accepted: apps / technologies.
            target = json_data.get("apps") or json_data.get("technologies") or {}
        else:
            target = json_data.get(data_key, [])

        if isinstance(target, list):
            # Already in array form.
            return target
        if isinstance(target, dict):
            # Object form: turn each entry into {"name": key, **value}.
            return [
                {"name": name, **data} if isinstance(data, dict) else {"name": name}
                for name, data in target.items()
            ]

        return []
|
||||
@@ -3,12 +3,17 @@
|
||||
项目安装后执行此命令,自动创建官方模板仓库记录。
|
||||
|
||||
使用方式:
|
||||
python manage.py init_nuclei_templates # 只创建记录
|
||||
python manage.py init_nuclei_templates # 只创建记录(检测本地已有仓库)
|
||||
python manage.py init_nuclei_templates --sync # 创建并同步(git clone)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.utils import timezone
|
||||
|
||||
from apps.engine.models import NucleiTemplateRepo
|
||||
from apps.engine.services import NucleiTemplateRepoService
|
||||
@@ -26,6 +31,20 @@ DEFAULT_REPOS = [
|
||||
]
|
||||
|
||||
|
||||
def get_local_commit_hash(local_path: Path) -> str:
    """Return the HEAD commit hash of the Git repository at *local_path*.

    Args:
        local_path: Directory expected to contain a ``.git`` directory.

    Returns:
        The output of ``git rev-parse HEAD`` with surrounding whitespace
        removed, or an empty string when *local_path* has no ``.git``
        directory, the git executable cannot be started, or the command
        exits with a non-zero status.
    """
    if not (local_path / ".git").is_dir():
        return ""
    try:
        result = subprocess.run(
            ["git", "-C", str(local_path), "rev-parse", "HEAD"],
            check=False,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
    except OSError:
        # git is missing or not executable: treat it like "no local
        # repository" so the caller's best-effort detection keeps working.
        return ""
    return result.stdout.strip() if result.returncode == 0 else ""
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = "初始化 Nuclei 模板仓库(创建官方模板仓库记录)"
|
||||
|
||||
@@ -46,6 +65,8 @@ class Command(BaseCommand):
|
||||
force = options.get("force", False)
|
||||
|
||||
service = NucleiTemplateRepoService()
|
||||
base_dir = Path(getattr(settings, "NUCLEI_TEMPLATES_REPOS_BASE_DIR", "/opt/xingrin/nuclei-repos"))
|
||||
|
||||
created = 0
|
||||
skipped = 0
|
||||
synced = 0
|
||||
@@ -87,20 +108,30 @@ class Command(BaseCommand):
|
||||
|
||||
# 创建新仓库记录
|
||||
try:
|
||||
# 检查本地是否已有仓库(由 install.sh 预下载)
|
||||
local_path = base_dir / name
|
||||
local_commit = get_local_commit_hash(local_path)
|
||||
|
||||
repo = NucleiTemplateRepo.objects.create(
|
||||
name=name,
|
||||
repo_url=repo_url,
|
||||
local_path=str(local_path) if local_commit else "",
|
||||
commit_hash=local_commit,
|
||||
last_synced_at=timezone.now() if local_commit else None,
|
||||
)
|
||||
self.stdout.write(self.style.SUCCESS(
|
||||
f"[{name}] 创建成功: id={repo.id}"
|
||||
))
|
||||
|
||||
if local_commit:
|
||||
self.stdout.write(self.style.SUCCESS(
|
||||
f"[{name}] 创建成功(检测到本地仓库): commit={local_commit[:8]}"
|
||||
))
|
||||
else:
|
||||
self.stdout.write(self.style.SUCCESS(
|
||||
f"[{name}] 创建成功: id={repo.id}"
|
||||
))
|
||||
created += 1
|
||||
|
||||
# 初始化本地路径
|
||||
service.ensure_local_path(repo)
|
||||
|
||||
# 如果需要同步
|
||||
if do_sync:
|
||||
# 如果本地没有仓库且需要同步
|
||||
if not local_commit and do_sync:
|
||||
try:
|
||||
self.stdout.write(self.style.WARNING(
|
||||
f"[{name}] 正在同步(首次可能需要几分钟)..."
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""初始化所有内置字典 Wordlist 记录
|
||||
|
||||
- 目录扫描默认字典: dir_default.txt -> /app/backend/wordlist/dir_default.txt
|
||||
- 子域名爆破默认字典: subdomains-top1million-110000.txt -> /app/backend/wordlist/subdomains-top1million-110000.txt
|
||||
内置字典从镜像内 /app/backend/wordlist/ 复制到运行时目录 /opt/xingrin/wordlists/:
|
||||
- 目录扫描默认字典: dir_default.txt
|
||||
- 子域名爆破默认字典: subdomains-top1million-110000.txt
|
||||
|
||||
可重复执行:如果已存在同名记录且文件有效则跳过,只在缺失或文件丢失时创建/修复。
|
||||
"""
|
||||
|
||||
29
backend/apps/engine/models/__init__.py
Normal file
29
backend/apps/engine/models/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Engine Models
|
||||
|
||||
导出所有 Engine 模块的 Models
|
||||
"""
|
||||
|
||||
from .engine import WorkerNode, ScanEngine, Wordlist, NucleiTemplateRepo
|
||||
from .fingerprints import (
|
||||
EholeFingerprint,
|
||||
GobyFingerprint,
|
||||
WappalyzerFingerprint,
|
||||
FingersFingerprint,
|
||||
FingerPrintHubFingerprint,
|
||||
ARLFingerprint,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 核心 Models
|
||||
"WorkerNode",
|
||||
"ScanEngine",
|
||||
"Wordlist",
|
||||
"NucleiTemplateRepo",
|
||||
# 指纹 Models
|
||||
"EholeFingerprint",
|
||||
"GobyFingerprint",
|
||||
"WappalyzerFingerprint",
|
||||
"FingersFingerprint",
|
||||
"FingerPrintHubFingerprint",
|
||||
"ARLFingerprint",
|
||||
]
|
||||
@@ -1,3 +1,8 @@
|
||||
"""Engine 模块核心 Models
|
||||
|
||||
包含 WorkerNode, ScanEngine, Wordlist, NucleiTemplateRepo
|
||||
"""
|
||||
|
||||
from django.db import models
|
||||
|
||||
|
||||
@@ -78,6 +83,7 @@ class ScanEngine(models.Model):
|
||||
indexes = [
|
||||
models.Index(fields=['-created_at']),
|
||||
]
|
||||
|
||||
def __str__(self):
    """Return the engine name, or a placeholder built from the id when the name is empty."""
    label = self.name if self.name else f'ScanEngine {self.id}'
    return str(label)
|
||||
|
||||
195
backend/apps/engine/models/fingerprints.py
Normal file
195
backend/apps/engine/models/fingerprints.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""指纹相关 Models
|
||||
|
||||
包含 EHole、Goby、Wappalyzer 等指纹格式的数据模型
|
||||
"""
|
||||
|
||||
from django.db import models
|
||||
|
||||
|
||||
class GobyFingerprint(models.Model):
    """Fingerprint rule in Goby format.

    Goby matches with a logic expression evaluated over an array of rules:
    - logic: the logic expression, e.g. "a||b", "(a&&b)||c"
    - rule: array of rules, each carrying label, feature, is_equal
    """

    # Product name.
    name = models.CharField(
        max_length=300,
        unique=True,
        help_text='产品名称',
    )
    # Logic expression.
    logic = models.CharField(
        max_length=500,
        help_text='逻辑表达式',
    )
    # Array of rules.
    rule = models.JSONField(
        default=list,
        help_text='规则数组',
    )
    created_at = models.DateTimeField(auto_now_add=True)

    class Meta:
        db_table = 'goby_fingerprint'
        verbose_name = 'Goby 指纹'
        verbose_name_plural = 'Goby 指纹'
        ordering = ['-created_at']
        indexes = [
            models.Index(fields=['name']),
            models.Index(fields=['logic']),
            models.Index(fields=['-created_at']),
        ]

    def __str__(self) -> str:
        """Return the rule as "<name> (<logic>)"."""
        return f"{self.name} ({self.logic})"
|
||||
|
||||
|
||||
class EholeFingerprint(models.Model):
    """Fingerprint rule in EHole format (fields mirror ehole.json)."""

    # Product / CMS name.
    cms = models.CharField(
        max_length=200,
        help_text='产品/CMS名称',
    )
    # Matching method.
    method = models.CharField(
        max_length=200,
        default='keyword',
        help_text='匹配方式',
    )
    # Matching location.
    location = models.CharField(
        max_length=200,
        default='body',
        help_text='匹配位置',
    )
    # List of keywords.
    keyword = models.JSONField(
        default=list,
        help_text='关键词列表',
    )
    # Whether this is a key asset.
    is_important = models.BooleanField(
        default=False,
        help_text='是否重点资产',
    )
    # Classification.
    type = models.CharField(
        max_length=100,
        blank=True,
        default='-',
        help_text='分类',
    )
    created_at = models.DateTimeField(auto_now_add=True)

    class Meta:
        db_table = 'ehole_fingerprint'
        verbose_name = 'EHole 指纹'
        verbose_name_plural = 'EHole 指纹'
        ordering = ['-created_at']
        indexes = [
            # Indexes for the search/filter fields.
            models.Index(fields=['cms']),
            models.Index(fields=['method']),
            models.Index(fields=['location']),
            models.Index(fields=['type']),
            models.Index(fields=['is_important']),
            # Index for the sort field.
            models.Index(fields=['-created_at']),
        ]
        constraints = [
            # The combination cms + method + location must be unique.
            models.UniqueConstraint(
                fields=['cms', 'method', 'location'],
                name='unique_ehole_fingerprint',
            ),
        ]

    def __str__(self) -> str:
        """Return the rule as "<cms> (<method>@<location>)"."""
        return f"{self.cms} ({self.method}@{self.location})"
|
||||
|
||||
|
||||
class WappalyzerFingerprint(models.Model):
    """Fingerprint rule in Wappalyzer format.

    Wappalyzer supports several detection methods: cookies, headers,
    scriptSrc, js, meta, html, etc.
    """

    # Application name; unique across the table.
    name = models.CharField(max_length=300, unique=True, help_text='应用名称')
    # Category ID array.
    cats = models.JSONField(default=list, help_text='分类 ID 数组')
    # Cookie detection rules.
    cookies = models.JSONField(default=dict, blank=True, help_text='Cookie 检测规则')
    # HTTP header detection rules.
    headers = models.JSONField(default=dict, blank=True, help_text='HTTP Header 检测规则')
    # Array of script-URL regular expressions.
    script_src = models.JSONField(default=list, blank=True, help_text='脚本 URL 正则数组')
    # JavaScript variable detection rules.
    js = models.JSONField(default=list, blank=True, help_text='JavaScript 变量检测规则')
    # Array of implied (dependency) applications.
    implies = models.JSONField(default=list, blank=True, help_text='依赖关系数组')
    # HTML meta tag detection rules.
    meta = models.JSONField(default=dict, blank=True, help_text='HTML meta 标签检测规则')
    # Array of HTML content regular expressions.
    html = models.JSONField(default=list, blank=True, help_text='HTML 内容正则数组')
    # Free-text application description.
    description = models.TextField(blank=True, default='', help_text='应用描述')
    # Official website link.
    website = models.URLField(max_length=500, blank=True, default='', help_text='官网链接')
    # CPE identifier.
    cpe = models.CharField(max_length=300, blank=True, default='', help_text='CPE 标识符')
    # Set once when the row is inserted.
    created_at = models.DateTimeField(auto_now_add=True)

    class Meta:
        db_table = 'wappalyzer_fingerprint'
        verbose_name = 'Wappalyzer 指纹'
        verbose_name_plural = 'Wappalyzer 指纹'
        # Newest rows first by default.
        ordering = ['-created_at']
        indexes = [
            models.Index(fields=['name']),
            models.Index(fields=['website']),
            models.Index(fields=['cpe']),
            models.Index(fields=['-created_at']),
        ]

    def __str__(self) -> str:
        """Return the application name."""
        return f"{self.name}"
|
||||
|
||||
|
||||
class FingersFingerprint(models.Model):
    """Fingerprint rule in Fingers format (fingers_http.json).

    Matches with regular expressions and tags; supports several detection
    methods such as favicon hash, header and body.
    """

    # Fingerprint name; unique across the table.
    name = models.CharField(max_length=300, unique=True, help_text='指纹名称')
    # Related link.
    link = models.URLField(max_length=500, blank=True, default='', help_text='相关链接')
    # Match-rule array stored as JSON.
    rule = models.JSONField(default=list, help_text='匹配规则数组')
    # Tag array stored as JSON.
    tag = models.JSONField(default=list, help_text='标签数组')
    # Whether the fingerprint is flagged for special attention.
    focus = models.BooleanField(default=False, help_text='是否重点关注')
    # Default port array stored as JSON.
    default_port = models.JSONField(default=list, blank=True, help_text='默认端口数组')
    # Set once when the row is inserted.
    created_at = models.DateTimeField(auto_now_add=True)

    class Meta:
        db_table = 'fingers_fingerprint'
        verbose_name = 'Fingers 指纹'
        verbose_name_plural = 'Fingers 指纹'
        # Newest rows first by default.
        ordering = ['-created_at']
        indexes = [
            models.Index(fields=['name']),
            models.Index(fields=['link']),
            models.Index(fields=['focus']),
            models.Index(fields=['-created_at']),
        ]

    def __str__(self) -> str:
        """Return the fingerprint name."""
        return f"{self.name}"
|
||||
|
||||
|
||||
class FingerPrintHubFingerprint(models.Model):
    """Fingerprint rule in FingerPrintHub format (fingerprinthub_web.json).

    Based on the nuclei template format; matches on HTTP request and
    response characteristics.
    """

    # Fingerprint ID; unique across the table.
    fp_id = models.CharField(max_length=200, unique=True, help_text='指纹ID')
    # Fingerprint name.
    name = models.CharField(max_length=300, help_text='指纹名称')
    # Author.
    author = models.CharField(max_length=200, blank=True, default='', help_text='作者')
    # Tags, stored as a single string.
    tags = models.CharField(max_length=500, blank=True, default='', help_text='标签')
    # Severity; defaults to 'info'.
    severity = models.CharField(max_length=50, blank=True, default='info', help_text='严重程度')
    # Metadata stored as a JSON object.
    metadata = models.JSONField(default=dict, blank=True, help_text='元数据')
    # HTTP match rules stored as a JSON array.
    http = models.JSONField(default=list, help_text='HTTP 匹配规则')
    # Source file the fingerprint came from.
    source_file = models.CharField(max_length=500, blank=True, default='', help_text='来源文件')
    # Set once when the row is inserted.
    created_at = models.DateTimeField(auto_now_add=True)

    class Meta:
        db_table = 'fingerprinthub_fingerprint'
        verbose_name = 'FingerPrintHub 指纹'
        verbose_name_plural = 'FingerPrintHub 指纹'
        # Newest rows first by default.
        ordering = ['-created_at']
        indexes = [
            models.Index(fields=['fp_id']),
            models.Index(fields=['name']),
            models.Index(fields=['author']),
            models.Index(fields=['severity']),
            models.Index(fields=['-created_at']),
        ]

    def __str__(self) -> str:
        """Return "<name> (<fp_id>)"."""
        return f"{self.name} ({self.fp_id})"
|
||||
|
||||
|
||||
class ARLFingerprint(models.Model):
    """Fingerprint rule in ARL format (ARL.yaml).

    Uses a simple name + rule-expression layout.
    """

    # Fingerprint name; unique across the table.
    name = models.CharField(max_length=300, unique=True, help_text='指纹名称')
    # Match-rule expression.
    rule = models.TextField(help_text='匹配规则表达式')
    # Set once when the row is inserted.
    created_at = models.DateTimeField(auto_now_add=True)

    class Meta:
        db_table = 'arl_fingerprint'
        verbose_name = 'ARL 指纹'
        verbose_name_plural = 'ARL 指纹'
        # Newest rows first by default.
        ordering = ['-created_at']
        indexes = [
            models.Index(fields=['name']),
            models.Index(fields=['-created_at']),
        ]

    def __str__(self) -> str:
        """Return the fingerprint name."""
        return f"{self.name}"
|
||||
20
backend/apps/engine/serializers/fingerprints/__init__.py
Normal file
20
backend/apps/engine/serializers/fingerprints/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""指纹管理 Serializers
|
||||
|
||||
导出所有指纹相关的 Serializer 类
|
||||
"""
|
||||
|
||||
from .ehole import EholeFingerprintSerializer
|
||||
from .goby import GobyFingerprintSerializer
|
||||
from .wappalyzer import WappalyzerFingerprintSerializer
|
||||
from .fingers import FingersFingerprintSerializer
|
||||
from .fingerprinthub import FingerPrintHubFingerprintSerializer
|
||||
from .arl import ARLFingerprintSerializer
|
||||
|
||||
__all__ = [
|
||||
"EholeFingerprintSerializer",
|
||||
"GobyFingerprintSerializer",
|
||||
"WappalyzerFingerprintSerializer",
|
||||
"FingersFingerprintSerializer",
|
||||
"FingerPrintHubFingerprintSerializer",
|
||||
"ARLFingerprintSerializer",
|
||||
]
|
||||
31
backend/apps/engine/serializers/fingerprints/arl.py
Normal file
31
backend/apps/engine/serializers/fingerprints/arl.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""ARL 指纹 Serializer"""
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from apps.engine.models import ARLFingerprint
|
||||
|
||||
|
||||
class ARLFingerprintSerializer(serializers.ModelSerializer):
    """Serializer for ARL fingerprints.

    Fields:
    - name: fingerprint name (required, unique)
    - rule: match-rule expression (required)
    """

    class Meta:
        model = ARLFingerprint
        fields = ['id', 'name', 'rule', 'created_at']
        read_only_fields = ['id', 'created_at']

    def validate_name(self, value):
        """Reject a blank name; return it with surrounding whitespace removed."""
        cleaned = value.strip() if value else ''
        if not cleaned:
            raise serializers.ValidationError("name 字段不能为空")
        return cleaned

    def validate_rule(self, value):
        """Reject a blank rule; return it with surrounding whitespace removed."""
        cleaned = value.strip() if value else ''
        if not cleaned:
            raise serializers.ValidationError("rule 字段不能为空")
        return cleaned
|
||||
27
backend/apps/engine/serializers/fingerprints/ehole.py
Normal file
27
backend/apps/engine/serializers/fingerprints/ehole.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""EHole 指纹 Serializer"""
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from apps.engine.models import EholeFingerprint
|
||||
|
||||
|
||||
class EholeFingerprintSerializer(serializers.ModelSerializer):
    """Serializer for EHole fingerprints."""

    class Meta:
        model = EholeFingerprint
        fields = [
            'id', 'cms', 'method', 'location', 'keyword',
            'is_important', 'type', 'created_at',
        ]
        read_only_fields = ['id', 'created_at']

    def validate_cms(self, value):
        """Reject a blank cms; return it with surrounding whitespace removed."""
        cleaned = value.strip() if value else ''
        if not cleaned:
            raise serializers.ValidationError("cms 字段不能为空")
        return cleaned

    def validate_keyword(self, value):
        """Accept keyword only when it is a list."""
        if isinstance(value, list):
            return value
        raise serializers.ValidationError("keyword 必须是数组")
|
||||
@@ -0,0 +1,50 @@
|
||||
"""FingerPrintHub 指纹 Serializer"""
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from apps.engine.models import FingerPrintHubFingerprint
|
||||
|
||||
|
||||
class FingerPrintHubFingerprintSerializer(serializers.ModelSerializer):
    """Serializer for FingerPrintHub fingerprints.

    Fields:
    - fp_id: fingerprint ID (required, unique)
    - name: fingerprint name (required)
    - author: author (optional)
    - tags: tag string (optional)
    - severity: severity (optional, defaults to 'info')
    - metadata: metadata JSON object (optional)
    - http: array of HTTP match rules (required)
    - source_file: source file (optional)
    """

    class Meta:
        model = FingerPrintHubFingerprint
        fields = [
            'id', 'fp_id', 'name', 'author', 'tags', 'severity',
            'metadata', 'http', 'source_file', 'created_at',
        ]
        read_only_fields = ['id', 'created_at']

    def validate_fp_id(self, value):
        """Reject a blank fp_id; return it with surrounding whitespace removed."""
        cleaned = value.strip() if value else ''
        if not cleaned:
            raise serializers.ValidationError("fp_id 字段不能为空")
        return cleaned

    def validate_name(self, value):
        """Reject a blank name; return it with surrounding whitespace removed."""
        cleaned = value.strip() if value else ''
        if not cleaned:
            raise serializers.ValidationError("name 字段不能为空")
        return cleaned

    def validate_http(self, value):
        """Accept http only when it is a list."""
        if isinstance(value, list):
            return value
        raise serializers.ValidationError("http 必须是数组")

    def validate_metadata(self, value):
        """Accept metadata only when it is a dict."""
        if isinstance(value, dict):
            return value
        raise serializers.ValidationError("metadata 必须是对象")
|
||||
48
backend/apps/engine/serializers/fingerprints/fingers.py
Normal file
48
backend/apps/engine/serializers/fingerprints/fingers.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Fingers 指纹 Serializer"""
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from apps.engine.models import FingersFingerprint
|
||||
|
||||
|
||||
class FingersFingerprintSerializer(serializers.ModelSerializer):
    """Serializer for Fingers fingerprints.

    Fields:
    - name: fingerprint name (required, unique)
    - link: related link (optional)
    - rule: array of match rules (required)
    - tag: array of tags (optional)
    - focus: attention flag (optional, defaults to False)
    - default_port: array of default ports (optional)
    """

    class Meta:
        model = FingersFingerprint
        fields = [
            'id', 'name', 'link', 'rule', 'tag', 'focus',
            'default_port', 'created_at',
        ]
        read_only_fields = ['id', 'created_at']

    def validate_name(self, value):
        """Reject a blank name; return it with surrounding whitespace removed."""
        cleaned = value.strip() if value else ''
        if not cleaned:
            raise serializers.ValidationError("name 字段不能为空")
        return cleaned

    def validate_rule(self, value):
        """Accept rule only when it is a list."""
        if isinstance(value, list):
            return value
        raise serializers.ValidationError("rule 必须是数组")

    def validate_tag(self, value):
        """Accept tag only when it is a list."""
        if isinstance(value, list):
            return value
        raise serializers.ValidationError("tag 必须是数组")

    def validate_default_port(self, value):
        """Accept default_port only when it is a list."""
        if isinstance(value, list):
            return value
        raise serializers.ValidationError("default_port 必须是数组")
|
||||
26
backend/apps/engine/serializers/fingerprints/goby.py
Normal file
26
backend/apps/engine/serializers/fingerprints/goby.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Goby 指纹 Serializer"""
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from apps.engine.models import GobyFingerprint
|
||||
|
||||
|
||||
class GobyFingerprintSerializer(serializers.ModelSerializer):
    """Serializer for Goby fingerprints."""

    class Meta:
        model = GobyFingerprint
        fields = ['id', 'name', 'logic', 'rule', 'created_at']
        read_only_fields = ['id', 'created_at']

    def validate_name(self, value):
        """Reject a blank name; return it with surrounding whitespace removed."""
        cleaned = value.strip() if value else ''
        if not cleaned:
            raise serializers.ValidationError("name 字段不能为空")
        return cleaned

    def validate_rule(self, value):
        """Accept rule only when it is a list."""
        if isinstance(value, list):
            return value
        raise serializers.ValidationError("rule 必须是数组")
|
||||
24
backend/apps/engine/serializers/fingerprints/wappalyzer.py
Normal file
24
backend/apps/engine/serializers/fingerprints/wappalyzer.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Wappalyzer 指纹 Serializer"""
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from apps.engine.models import WappalyzerFingerprint
|
||||
|
||||
|
||||
class WappalyzerFingerprintSerializer(serializers.ModelSerializer):
    """Serializer for Wappalyzer fingerprints."""

    class Meta:
        model = WappalyzerFingerprint
        fields = [
            'id',
            'name',
            'cats',
            'cookies',
            'headers',
            'script_src',
            'js',
            'implies',
            'meta',
            'html',
            'description',
            'website',
            'cpe',
            'created_at',
        ]
        read_only_fields = ['id', 'created_at']

    def validate_name(self, value):
        """Reject a blank name; return it with surrounding whitespace removed."""
        cleaned = value.strip() if value else ''
        if not cleaned:
            raise serializers.ValidationError("name 字段不能为空")
        return cleaned
|
||||
@@ -66,6 +66,7 @@ def get_start_agent_script(
|
||||
# 替换变量
|
||||
script = script.replace("{{HEARTBEAT_API_URL}}", heartbeat_api_url or '')
|
||||
script = script.replace("{{WORKER_ID}}", str(worker_id) if worker_id else '')
|
||||
script = script.replace("{{WORKER_API_KEY}}", getattr(settings, 'WORKER_API_KEY', ''))
|
||||
|
||||
# 注入镜像版本配置(确保远程节点使用相同版本)
|
||||
docker_user = getattr(settings, 'DOCKER_USER', 'yyhuni')
|
||||
|
||||
22
backend/apps/engine/services/fingerprints/__init__.py
Normal file
22
backend/apps/engine/services/fingerprints/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""指纹管理 Services
|
||||
|
||||
导出所有指纹相关的 Service 类
|
||||
"""
|
||||
|
||||
from .base import BaseFingerprintService
|
||||
from .ehole import EholeFingerprintService
|
||||
from .goby import GobyFingerprintService
|
||||
from .wappalyzer import WappalyzerFingerprintService
|
||||
from .fingers_service import FingersFingerprintService
|
||||
from .fingerprinthub_service import FingerPrintHubFingerprintService
|
||||
from .arl_service import ARLFingerprintService
|
||||
|
||||
__all__ = [
|
||||
"BaseFingerprintService",
|
||||
"EholeFingerprintService",
|
||||
"GobyFingerprintService",
|
||||
"WappalyzerFingerprintService",
|
||||
"FingersFingerprintService",
|
||||
"FingerPrintHubFingerprintService",
|
||||
"ARLFingerprintService",
|
||||
]
|
||||
110
backend/apps/engine/services/fingerprints/arl_service.py
Normal file
110
backend/apps/engine/services/fingerprints/arl_service.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""ARL 指纹管理 Service
|
||||
|
||||
实现 ARL 格式指纹的校验、转换和导出逻辑
|
||||
支持 YAML 格式的导入导出
|
||||
"""
|
||||
|
||||
import logging
|
||||
import yaml
|
||||
|
||||
from apps.engine.models import ARLFingerprint
|
||||
from .base import BaseFingerprintService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ARLFingerprintService(BaseFingerprintService):
    """Service for ARL fingerprints (ARL-specific logic on top of the base class).

    Handles validation, conversion to model fields, and YAML import/export.
    """

    model = ARLFingerprint

    def validate_fingerprint(self, item: dict) -> bool:
        """Validate a single ARL fingerprint entry.

        Rules:
        - the entry must be a mapping
        - ``name`` must be present and non-blank
        - ``rule`` must be present and non-blank

        Args:
            item: One fingerprint entry.

        Returns:
            bool: True when the entry is valid.
        """
        # parse_yaml_import only guarantees that the top level is a list, so
        # an entry may be a scalar or a nested list. Report those as invalid
        # instead of failing on ``.get``.
        if not isinstance(item, dict):
            return False
        name = item.get('name', '')
        rule = item.get('rule', '')
        return bool(name and str(name).strip()) and bool(rule and str(rule).strip())

    def to_model_data(self, item: dict) -> dict:
        """Convert an ARL YAML entry into model field values.

        Args:
            item: Raw ARL YAML entry.

        Returns:
            dict: ``name`` and ``rule`` as stripped strings.
        """
        return {
            'name': str(item.get('name', '')).strip(),
            'rule': str(item.get('rule', '')).strip(),
        }

    def get_export_data(self) -> list:
        """Build the export payload (ARL format: an array, used for YAML export).

        Returns:
            list: One ``{"name": ..., "rule": ...}`` dict per stored fingerprint.
        """
        return [
            {'name': fp.name, 'rule': fp.rule}
            for fp in self.model.objects.all()
        ]

    def export_to_yaml(self, output_path: str) -> int:
        """Export every fingerprint to a YAML file.

        Args:
            output_path: Destination file path.

        Returns:
            int: Number of fingerprints written.
        """
        data = self.get_export_data()
        with open(output_path, 'w', encoding='utf-8') as f:
            yaml.dump(data, f, allow_unicode=True, default_flow_style=False, sort_keys=False)
        count = len(data)
        logger.info("导出 ARL 指纹文件: %s, 数量: %d", output_path, count)
        return count

    def parse_yaml_import(self, yaml_content: str) -> list:
        """Parse YAML import content.

        Args:
            yaml_content: YAML document as a string.

        Returns:
            list: The parsed top-level list of fingerprint entries.

        Raises:
            ValueError: If the content is not valid YAML, or its top level is
                not a list.
        """
        # Keep the try body limited to the call that can raise YAMLError.
        try:
            data = yaml.safe_load(yaml_content)
        except yaml.YAMLError as e:
            raise ValueError(f"无效的 YAML 格式: {e}") from e
        if not isinstance(data, list):
            raise ValueError("ARL YAML 文件必须是数组格式")
        return data
|
||||
144
backend/apps/engine/services/fingerprints/base.py
Normal file
144
backend/apps/engine/services/fingerprints/base.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""指纹管理基类 Service
|
||||
|
||||
提供通用的批量操作和缓存逻辑,供 EHole/Goby/Wappalyzer 等子类继承
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)


class BaseFingerprintService:
    """Base service for fingerprint management.

    Provides the shared batch validation/creation, export and version
    logic; format-specific subclasses supply ``model`` and the three hooks
    ``validate_fingerprint``, ``to_model_data`` and ``get_export_data``.
    """

    model = None  # must be set by subclasses
    BATCH_SIZE = 1000  # number of items processed per batch

    def validate_fingerprint(self, item: dict) -> bool:
        """Validate a single fingerprint entry; subclasses must implement.

        Args:
            item: One fingerprint entry.

        Returns:
            bool: True when the entry is valid.
        """
        raise NotImplementedError("子类必须实现 validate_fingerprint 方法")

    def validate_fingerprints(self, raw_data: list) -> tuple[list, list]:
        """Validate fingerprint entries in bulk.

        Args:
            raw_data: Raw fingerprint entries.

        Returns:
            tuple: ``(valid_items, invalid_items)``.
        """
        valid, invalid = [], []
        for item in raw_data:
            # Imported files may contain non-mapping entries (strings, null,
            # nested lists). Count them as invalid rather than letting the
            # subclass hook fail on ``item.get``.
            if isinstance(item, dict) and self.validate_fingerprint(item):
                valid.append(item)
            else:
                invalid.append(item)
        return valid, invalid

    def to_model_data(self, item: dict) -> dict:
        """Convert a raw entry into model field values; subclasses must implement.

        Args:
            item: Raw fingerprint entry.

        Returns:
            dict: Model field values.
        """
        raise NotImplementedError("子类必须实现 to_model_data 方法")

    def bulk_create(self, fingerprints: list) -> int:
        """Bulk-create fingerprint rows from already-validated entries.

        Args:
            fingerprints: Validated fingerprint entries.

        Returns:
            int: Number of rows actually inserted.
        """
        if not fingerprints:
            return 0

        objects = [self.model(**self.to_model_data(item)) for item in fingerprints]
        manager = self.model.objects
        # With ignore_conflicts=True, bulk_create returns every object that
        # was passed in, including rows skipped on a conflict, so len() of
        # its result over-counts. Compare the table size before and after.
        # NOTE: concurrent writers can skew the delta; clamp at zero.
        before = manager.count()
        manager.bulk_create(objects, ignore_conflicts=True)
        return max(0, manager.count() - before)

    def batch_create_fingerprints(self, raw_data: list) -> dict:
        """Full import flow: validate and create in batches of BATCH_SIZE.

        Args:
            raw_data: Raw fingerprint entries.

        Returns:
            dict: ``{'created': int, 'failed': int}``.
        """
        total_created = 0
        total_failed = 0

        for i in range(0, len(raw_data), self.BATCH_SIZE):
            batch = raw_data[i:i + self.BATCH_SIZE]
            valid, invalid = self.validate_fingerprints(batch)
            total_created += self.bulk_create(valid)
            total_failed += len(invalid)

        logger.info(
            "批量创建指纹完成: created=%d, failed=%d, total=%d",
            total_created, total_failed, len(raw_data)
        )
        return {'created': total_created, 'failed': total_failed}

    def get_export_data(self) -> Any:
        """Build the export payload; subclasses must implement.

        Returns:
            The JSON-serialisable payload: a dict for some formats and a
            list for others.
        """
        raise NotImplementedError("子类必须实现 get_export_data 方法")

    @staticmethod
    def _count_exported(data: Any) -> int:
        """Return how many fingerprints an export payload contains."""
        if isinstance(data, dict):
            # Wrapped payloads keep their entries under 'fingerprint'
            # (EHole) or 'apps' (Wappalyzer).
            for key in ('fingerprint', 'apps'):
                if key in data:
                    return len(data[key])
            return 0
        # Array-style payloads (Goby, Fingers, FingerPrintHub, ARL).
        return len(data)

    def export_to_file(self, output_path: str) -> int:
        """Export every fingerprint to a JSON file.

        Args:
            output_path: Destination file path.

        Returns:
            int: Number of fingerprints written.
        """
        data = self.get_export_data()
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False)
        count = self._count_exported(data)
        logger.info("导出指纹文件: %s, 数量: %d", output_path, count)
        return count

    def get_fingerprint_version(self) -> str:
        """Return a version tag for the fingerprint set (used for cache checks).

        Returns:
            str: ``"{count}_{latest_timestamp}"``, where the timestamp is the
            newest ``created_at`` as integer seconds, or 0 for an empty table.

        The tag changes when:
        - rows are added (count changes)
        - rows are deleted (count changes)
        - the table is cleared (count becomes 0)
        """
        count = self.model.objects.count()
        latest = self.model.objects.order_by('-created_at').first()
        latest_ts = int(latest.created_at.timestamp()) if latest else 0
        return f"{count}_{latest_ts}"
|
||||
84
backend/apps/engine/services/fingerprints/ehole.py
Normal file
84
backend/apps/engine/services/fingerprints/ehole.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""EHole 指纹管理 Service
|
||||
|
||||
实现 EHole 格式指纹的校验、转换和导出逻辑
|
||||
"""
|
||||
|
||||
from apps.engine.models import EholeFingerprint
|
||||
from .base import BaseFingerprintService
|
||||
|
||||
|
||||
class EholeFingerprintService(BaseFingerprintService):
    """Service for EHole fingerprints (EHole-specific logic on top of the base class)."""

    model = EholeFingerprint

    def validate_fingerprint(self, item: dict) -> bool:
        """Validate a single EHole fingerprint entry.

        Rules:
        - ``cms`` must be present and non-blank
        - ``keyword`` must be a list

        Args:
            item: One fingerprint entry.

        Returns:
            bool: True when the entry is valid.
        """
        cms = item.get('cms', '')
        has_cms = bool(cms and str(cms).strip())
        return has_cms and isinstance(item.get('keyword'), list)

    def to_model_data(self, item: dict) -> dict:
        """Convert an EHole JSON entry into model field values.

        Field mapping:
        - isImportant (JSON) -> is_important (model)

        Args:
            item: Raw EHole JSON entry.

        Returns:
            dict: Model field values.
        """
        get = item.get
        return dict(
            cms=str(get('cms', '')).strip(),
            method=get('method', 'keyword'),
            location=get('location', 'body'),
            keyword=get('keyword', []),
            is_important=get('isImportant', False),
            type=get('type', '-'),
        )

    def get_export_data(self) -> dict:
        """Build the export payload in EHole JSON format.

        Returns:
            dict: ``{"fingerprint": [...], "version": "<count>_<timestamp>"}``
            where each entry carries cms, method, location, keyword,
            isImportant and type.
        """
        entries = [
            {
                'cms': fp.cms,
                'method': fp.method,
                'location': fp.location,
                'keyword': fp.keyword,
                'isImportant': fp.is_important,  # back to the JSON key name
                'type': fp.type,
            }
            for fp in self.model.objects.all()
        ]
        return {
            'fingerprint': entries,
            'version': self.get_fingerprint_version(),
        }
|
||||
@@ -0,0 +1,110 @@
|
||||
"""FingerPrintHub 指纹管理 Service
|
||||
|
||||
实现 FingerPrintHub 格式指纹的校验、转换和导出逻辑
|
||||
"""
|
||||
|
||||
from apps.engine.models import FingerPrintHubFingerprint
|
||||
from .base import BaseFingerprintService
|
||||
|
||||
|
||||
class FingerPrintHubFingerprintService(BaseFingerprintService):
    """Service for FingerPrintHub fingerprints (format-specific logic on top of the base class)."""

    model = FingerPrintHubFingerprint

    def validate_fingerprint(self, item: dict) -> bool:
        """Validate a single FingerPrintHub fingerprint entry.

        Rules:
        - ``id`` must be present and non-blank
        - ``info`` must be a mapping with a non-blank ``name``
        - ``http`` must be a list

        Args:
            item: One fingerprint entry.

        Returns:
            bool: True when the entry is valid.
        """
        fp_id = item.get('id', '')
        if not fp_id or not str(fp_id).strip():
            return False

        info = item.get('info', {})
        if not isinstance(info, dict):
            return False
        # to_model_data strips the name, so a whitespace-only name would be
        # stored as an empty string. Reject it here, as is done for ``id``.
        name = info.get('name')
        if not name or not str(name).strip():
            return False

        return isinstance(item.get('http'), list)

    def to_model_data(self, item: dict) -> dict:
        """Convert a FingerPrintHub JSON entry into model field values.

        Field mapping (nested structure flattened):
        - id (JSON) -> fp_id (model)
        - info.name (JSON) -> name (model)
        - info.author (JSON) -> author (model)
        - info.tags (JSON) -> tags (model)
        - info.severity (JSON) -> severity (model)
        - info.metadata (JSON) -> metadata (model)
        - http (JSON) -> http (model)
        - _source_file (JSON) -> source_file (model)

        Args:
            item: Raw FingerPrintHub JSON entry.

        Returns:
            dict: Model field values.
        """
        info = item.get('info', {})
        return {
            'fp_id': str(item.get('id', '')).strip(),
            'name': str(info.get('name', '')).strip(),
            'author': info.get('author', ''),
            'tags': info.get('tags', ''),
            'severity': info.get('severity', 'info'),
            'metadata': info.get('metadata', {}),
            'http': item.get('http', []),
            'source_file': item.get('_source_file', ''),
        }

    def get_export_data(self) -> list:
        """Build the export payload in FingerPrintHub JSON format (an array).

        Returns:
            list: One dict per fingerprint with ``id``, ``info`` (name,
            author, tags, severity, metadata) and ``http``; ``_source_file``
            is included only when non-empty.
        """
        data = []
        for fp in self.model.objects.all():
            item = {
                'id': fp.fp_id,
                'info': {
                    'name': fp.name,
                    'author': fp.author,
                    'tags': fp.tags,
                    'severity': fp.severity,
                    'metadata': fp.metadata,
                },
                'http': fp.http,
            }
            # Emit _source_file only when it has a value.
            if fp.source_file:
                item['_source_file'] = fp.source_file
            data.append(item)
        return data
|
||||
83
backend/apps/engine/services/fingerprints/fingers_service.py
Normal file
83
backend/apps/engine/services/fingerprints/fingers_service.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""Fingers 指纹管理 Service
|
||||
|
||||
实现 Fingers 格式指纹的校验、转换和导出逻辑
|
||||
"""
|
||||
|
||||
from apps.engine.models import FingersFingerprint
|
||||
from .base import BaseFingerprintService
|
||||
|
||||
|
||||
class FingersFingerprintService(BaseFingerprintService):
    """Service for Fingers fingerprints (Fingers-specific logic on top of the base class)."""

    model = FingersFingerprint

    def validate_fingerprint(self, item: dict) -> bool:
        """Validate a single Fingers fingerprint entry.

        Rules:
        - ``name`` must be present and non-blank
        - ``rule`` must be a list

        Args:
            item: One fingerprint entry.

        Returns:
            bool: True when the entry is valid.
        """
        name = item.get('name', '')
        has_name = bool(name and str(name).strip())
        return has_name and isinstance(item.get('rule'), list)

    def to_model_data(self, item: dict) -> dict:
        """Convert a Fingers JSON entry into model field values.

        Field mapping:
        - default_port (JSON) -> default_port (model)

        Args:
            item: Raw Fingers JSON entry.

        Returns:
            dict: Model field values.
        """
        get = item.get
        return dict(
            name=str(get('name', '')).strip(),
            link=get('link', ''),
            rule=get('rule', []),
            tag=get('tag', []),
            focus=get('focus', False),
            default_port=get('default_port', []),
        )

    def get_export_data(self) -> list:
        """Build the export payload in Fingers JSON format (an array).

        Returns:
            list: One dict per fingerprint with name, link, rule and tag;
            ``focus`` and ``default_port`` are included only when truthy.
        """
        exported = []
        for fp in self.model.objects.all():
            entry = {
                'name': fp.name,
                'link': fp.link,
                'rule': fp.rule,
                'tag': fp.tag,
            }
            # focus is emitted only when True (keeps the original file format)
            if fp.focus:
                entry['focus'] = fp.focus
            # default_port is emitted only when non-empty
            if fp.default_port:
                entry['default_port'] = fp.default_port
            exported.append(entry)
        return exported
|
||||
70
backend/apps/engine/services/fingerprints/goby.py
Normal file
70
backend/apps/engine/services/fingerprints/goby.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Goby 指纹管理 Service
|
||||
|
||||
实现 Goby 格式指纹的校验、转换和导出逻辑
|
||||
"""
|
||||
|
||||
from apps.engine.models import GobyFingerprint
|
||||
from .base import BaseFingerprintService
|
||||
|
||||
|
||||
class GobyFingerprintService(BaseFingerprintService):
    """Service for Goby fingerprints (Goby-specific logic on top of the base class)."""

    model = GobyFingerprint

    def validate_fingerprint(self, item: dict) -> bool:
        """Validate a single Goby fingerprint entry.

        Rules:
        - ``name`` must be present and non-blank
        - ``logic`` must be present
        - ``rule`` must be a list

        Args:
            item: One fingerprint entry.

        Returns:
            bool: True when the entry is valid.
        """
        name = item.get('name', '')
        if not (name and str(name).strip()):
            return False
        if not item.get('logic', ''):
            return False
        return isinstance(item.get('rule'), list)

    def to_model_data(self, item: dict) -> dict:
        """Convert a Goby JSON entry into model field values.

        Args:
            item: Raw Goby JSON entry.

        Returns:
            dict: Model field values.
        """
        get = item.get
        return dict(
            name=str(get('name', '')).strip(),
            logic=get('logic', ''),
            rule=get('rule', []),
        )

    def get_export_data(self) -> list:
        """Build the export payload in Goby JSON format (an array).

        Returns:
            list: One ``{"name": ..., "logic": ..., "rule": [...]}`` dict per
            stored fingerprint.
        """
        exported = []
        for fp in self.model.objects.all():
            exported.append({'name': fp.name, 'logic': fp.logic, 'rule': fp.rule})
        return exported
|
||||
99
backend/apps/engine/services/fingerprints/wappalyzer.py
Normal file
99
backend/apps/engine/services/fingerprints/wappalyzer.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Wappalyzer 指纹管理 Service
|
||||
|
||||
实现 Wappalyzer 格式指纹的校验、转换和导出逻辑
|
||||
"""
|
||||
|
||||
from apps.engine.models import WappalyzerFingerprint
|
||||
from .base import BaseFingerprintService
|
||||
|
||||
|
||||
class WappalyzerFingerprintService(BaseFingerprintService):
    """Service for Wappalyzer fingerprints (Wappalyzer-specific logic on top of the base class)."""

    model = WappalyzerFingerprint

    def validate_fingerprint(self, item: dict) -> bool:
        """Validate a single Wappalyzer fingerprint entry.

        Rules:
        - ``name`` must be present and non-blank (it is passed in from the
          key of the ``apps`` object)

        Args:
            item: One fingerprint entry.

        Returns:
            bool: True when the entry is valid.
        """
        name = item.get('name', '')
        if not name:
            return False
        return bool(str(name).strip())

    def to_model_data(self, item: dict) -> dict:
        """Convert a Wappalyzer JSON entry into model field values.

        Field mapping:
        - scriptSrc (JSON) -> script_src (model)

        Args:
            item: Raw Wappalyzer JSON entry.

        Returns:
            dict: Model field values.
        """
        get = item.get
        return dict(
            name=str(get('name', '')).strip(),
            cats=get('cats', []),
            cookies=get('cookies', {}),
            headers=get('headers', {}),
            script_src=get('scriptSrc', []),  # JSON: scriptSrc -> model: script_src
            js=get('js', []),
            implies=get('implies', []),
            meta=get('meta', {}),
            html=get('html', []),
            description=get('description', ''),
            website=get('website', ''),
            cpe=get('cpe', ''),
        )

    def get_export_data(self) -> dict:
        """Build the export payload in Wappalyzer JSON format.

        Returns:
            dict: ``{"apps": {"AppName": {...}, ...}}``; each app dict holds
            only the fields whose stored value is truthy.
        """
        # (JSON key, model attribute) pairs, in output order.
        optional_fields = (
            ('cats', 'cats'),
            ('cookies', 'cookies'),
            ('headers', 'headers'),
            ('scriptSrc', 'script_src'),  # model: script_src -> JSON: scriptSrc
            ('js', 'js'),
            ('implies', 'implies'),
            ('meta', 'meta'),
            ('html', 'html'),
            ('description', 'description'),
            ('website', 'website'),
            ('cpe', 'cpe'),
        )
        apps = {}
        for fp in self.model.objects.all():
            entry = {}
            for json_key, attr in optional_fields:
                stored = getattr(fp, attr)
                if stored:
                    entry[json_key] = stored
            apps[fp.name] = entry
        return {'apps': apps}
|
||||
@@ -196,6 +196,9 @@ class NucleiTemplateRepoService:
|
||||
cmd: List[str]
|
||||
action: str
|
||||
|
||||
# 直接使用原始 URL(不再使用 Git 加速)
|
||||
repo_url = obj.repo_url
|
||||
|
||||
# 判断是 clone 还是 pull
|
||||
if git_dir.is_dir():
|
||||
# 检查远程地址是否变化
|
||||
@@ -208,12 +211,13 @@ class NucleiTemplateRepoService:
|
||||
)
|
||||
current_url = current_remote.stdout.strip() if current_remote.returncode == 0 else ""
|
||||
|
||||
if current_url != obj.repo_url:
|
||||
# 检查是否需要重新 clone
|
||||
if current_url != repo_url:
|
||||
# 远程地址变化,删除旧目录重新 clone
|
||||
logger.info("nuclei 模板仓库 %s 远程地址变化,重新 clone: %s -> %s", obj.id, current_url, obj.repo_url)
|
||||
logger.info("nuclei 模板仓库 %s 远程地址变化,重新 clone: %s -> %s", obj.id, current_url, repo_url)
|
||||
shutil.rmtree(local_path)
|
||||
local_path.mkdir(parents=True, exist_ok=True)
|
||||
cmd = ["git", "clone", "--depth", "1", obj.repo_url, str(local_path)]
|
||||
cmd = ["git", "clone", "--depth", "1", repo_url, str(local_path)]
|
||||
action = "clone"
|
||||
else:
|
||||
# 已有仓库且地址未变,执行 pull
|
||||
@@ -224,7 +228,7 @@ class NucleiTemplateRepoService:
|
||||
if local_path.exists() and not local_path.is_dir():
|
||||
raise RuntimeError(f"本地路径已存在且不是目录: {local_path}")
|
||||
# --depth 1 浅克隆,只获取最新提交,节省空间和时间
|
||||
cmd = ["git", "clone", "--depth", "1", obj.repo_url, str(local_path)]
|
||||
cmd = ["git", "clone", "--depth", "1", repo_url, str(local_path)]
|
||||
action = "clone"
|
||||
|
||||
# 执行 Git 命令
|
||||
|
||||
@@ -76,8 +76,8 @@ class TaskDistributor:
|
||||
self.docker_image = settings.TASK_EXECUTOR_IMAGE
|
||||
if not self.docker_image:
|
||||
raise ValueError("TASK_EXECUTOR_IMAGE 未配置,请确保 IMAGE_TAG 环境变量已设置")
|
||||
self.results_mount = getattr(settings, 'CONTAINER_RESULTS_MOUNT', '/app/backend/results')
|
||||
self.logs_mount = getattr(settings, 'CONTAINER_LOGS_MOUNT', '/app/backend/logs')
|
||||
# 统一使用 /opt/xingrin 下的路径
|
||||
self.logs_mount = "/opt/xingrin/logs"
|
||||
self.submit_interval = getattr(settings, 'TASK_SUBMIT_INTERVAL', 5)
|
||||
|
||||
def get_online_workers(self) -> list[WorkerNode]:
|
||||
@@ -153,30 +153,68 @@ class TaskDistributor:
|
||||
else:
|
||||
scored_workers.append((worker, score, cpu, mem))
|
||||
|
||||
# 降级策略:如果没有正常负载的,等待后重新选择
|
||||
# 降级策略:如果没有正常负载的,循环等待后重新检测
|
||||
if not scored_workers:
|
||||
if high_load_workers:
|
||||
# 高负载时先等待,给系统喘息时间(默认 60 秒)
|
||||
# 高负载等待参数(默认每 60 秒检测一次,最多 10 次)
|
||||
high_load_wait = getattr(settings, 'HIGH_LOAD_WAIT_SECONDS', 60)
|
||||
logger.warning("所有 Worker 高负载,等待 %d 秒后重试...", high_load_wait)
|
||||
time.sleep(high_load_wait)
|
||||
high_load_max_retries = getattr(settings, 'HIGH_LOAD_MAX_RETRIES', 10)
|
||||
|
||||
# 重新选择(递归调用,可能负载已降下来)
|
||||
# 为避免无限递归,这里直接使用高负载中最低的
|
||||
# 开始等待前发送高负载通知
|
||||
high_load_workers.sort(key=lambda x: x[1])
|
||||
best_worker, _, cpu, mem = high_load_workers[0]
|
||||
|
||||
# 发送高负载通知
|
||||
_, _, first_cpu, first_mem = high_load_workers[0]
|
||||
from apps.common.signals import all_workers_high_load
|
||||
all_workers_high_load.send(
|
||||
sender=self.__class__,
|
||||
worker_name=best_worker.name,
|
||||
cpu=cpu,
|
||||
mem=mem
|
||||
worker_name="所有节点",
|
||||
cpu=first_cpu,
|
||||
mem=first_mem
|
||||
)
|
||||
|
||||
logger.info("选择 Worker: %s (CPU: %.1f%%, MEM: %.1f%%)", best_worker.name, cpu, mem)
|
||||
return best_worker
|
||||
for retry in range(high_load_max_retries):
|
||||
logger.warning(
|
||||
"所有 Worker 高负载,等待 %d 秒后重试... (%d/%d)",
|
||||
high_load_wait, retry + 1, high_load_max_retries
|
||||
)
|
||||
time.sleep(high_load_wait)
|
||||
|
||||
# 重新获取负载数据
|
||||
loads = worker_load_service.get_all_loads(worker_ids)
|
||||
|
||||
# 重新评估
|
||||
scored_workers = []
|
||||
high_load_workers = []
|
||||
|
||||
for worker in workers:
|
||||
load = loads.get(worker.id)
|
||||
if not load:
|
||||
continue
|
||||
|
||||
cpu = load.get('cpu', 0)
|
||||
mem = load.get('mem', 0)
|
||||
score = cpu * 0.7 + mem * 0.3
|
||||
|
||||
if cpu > 85 or mem > 85:
|
||||
high_load_workers.append((worker, score, cpu, mem))
|
||||
else:
|
||||
scored_workers.append((worker, score, cpu, mem))
|
||||
|
||||
# 如果有正常负载的 Worker,跳出循环
|
||||
if scored_workers:
|
||||
logger.info("检测到正常负载 Worker,结束等待")
|
||||
break
|
||||
|
||||
# 超时或仍然高负载,选择负载最低的
|
||||
if not scored_workers and high_load_workers:
|
||||
high_load_workers.sort(key=lambda x: x[1])
|
||||
best_worker, _, cpu, mem = high_load_workers[0]
|
||||
|
||||
logger.warning(
|
||||
"等待超时,强制分发到高负载 Worker: %s (CPU: %.1f%%, MEM: %.1f%%)",
|
||||
best_worker.name, cpu, mem
|
||||
)
|
||||
return best_worker
|
||||
return best_worker
|
||||
else:
|
||||
logger.warning("没有可用的 Worker")
|
||||
return None
|
||||
@@ -234,11 +272,10 @@ class TaskDistributor:
|
||||
else:
|
||||
# 远程:通过 Nginx 反向代理访问(HTTPS,不直连 8888 端口)
|
||||
network_arg = ""
|
||||
server_url = f"https://{settings.PUBLIC_HOST}"
|
||||
server_url = f"https://{settings.PUBLIC_HOST}:{settings.PUBLIC_PORT}"
|
||||
|
||||
# 挂载路径(所有节点统一使用固定路径)
|
||||
host_results_dir = settings.HOST_RESULTS_DIR # /opt/xingrin/results
|
||||
host_logs_dir = settings.HOST_LOGS_DIR # /opt/xingrin/logs
|
||||
# 挂载路径(统一挂载 /opt/xingrin,扫描工具在 /opt/xingrin-tools/bin 不受影响)
|
||||
host_xingrin_dir = "/opt/xingrin"
|
||||
|
||||
# 环境变量:SERVER_URL + IS_LOCAL,其他配置容器启动时从配置中心获取
|
||||
# IS_LOCAL 用于 Worker 向配置中心声明身份,决定返回的数据库地址
|
||||
@@ -251,15 +288,12 @@ class TaskDistributor:
|
||||
"-e PREFECT_SERVER_EPHEMERAL_ENABLED=true", # 启用 ephemeral server(本地临时服务器)
|
||||
"-e PREFECT_SERVER_EPHEMERAL_STARTUP_TIMEOUT_SECONDS=120", # 增加启动超时时间
|
||||
"-e PREFECT_SERVER_DATABASE_CONNECTION_URL=sqlite+aiosqlite:////tmp/.prefect/prefect.db", # 使用 /tmp 下的 SQLite
|
||||
"-e PREFECT_LOGGING_LEVEL=DEBUG", # 启用 DEBUG 级别日志
|
||||
"-e PREFECT_LOGGING_SERVER_LEVEL=DEBUG", # Server 日志级别
|
||||
"-e PREFECT_DEBUG_MODE=true", # 启用调试模式
|
||||
"-e PREFECT_LOGGING_LEVEL=WARNING", # 日志级别(减少 DEBUG 噪音)
|
||||
]
|
||||
|
||||
# 挂载卷
|
||||
# 挂载卷(统一挂载整个 /opt/xingrin 目录)
|
||||
volumes = [
|
||||
f"-v {host_results_dir}:{self.results_mount}",
|
||||
f"-v {host_logs_dir}:{self.logs_mount}",
|
||||
f"-v {host_xingrin_dir}:{host_xingrin_dir}",
|
||||
]
|
||||
|
||||
# 构建命令行参数
|
||||
@@ -277,11 +311,10 @@ class TaskDistributor:
|
||||
# - 本地 Worker:install.sh 已预拉取镜像,直接使用本地版本
|
||||
# - 远程 Worker:deploy 时已预拉取镜像,直接使用本地版本
|
||||
# - 避免每次任务都检查 Docker Hub,提升性能和稳定性
|
||||
# 使用双引号包裹 sh -c 命令,内部 shlex.quote 生成的单引号参数可正确解析
|
||||
cmd = f'''docker run --rm -d --pull=missing {network_arg} \
|
||||
{' '.join(env_vars)} \
|
||||
{' '.join(volumes)} \
|
||||
{self.docker_image} \
|
||||
cmd = f'''docker run --rm -d --pull=missing {network_arg} \\
|
||||
{' '.join(env_vars)} \\
|
||||
{' '.join(volumes)} \\
|
||||
{self.docker_image} \\
|
||||
sh -c "{inner_cmd}"'''
|
||||
|
||||
return cmd
|
||||
@@ -520,7 +553,7 @@ class TaskDistributor:
|
||||
try:
|
||||
# 构建 docker run 命令(清理过期扫描结果目录)
|
||||
script_args = {
|
||||
'results_dir': '/app/backend/results',
|
||||
'results_dir': '/opt/xingrin/results',
|
||||
'retention_days': retention_days,
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,14 @@ from .views import (
|
||||
WordlistViewSet,
|
||||
NucleiTemplateRepoViewSet,
|
||||
)
|
||||
from .views.fingerprints import (
|
||||
EholeFingerprintViewSet,
|
||||
GobyFingerprintViewSet,
|
||||
WappalyzerFingerprintViewSet,
|
||||
FingersFingerprintViewSet,
|
||||
FingerPrintHubFingerprintViewSet,
|
||||
ARLFingerprintViewSet,
|
||||
)
|
||||
|
||||
|
||||
# 创建路由器
|
||||
@@ -15,6 +23,13 @@ router.register(r"engines", ScanEngineViewSet, basename="engine")
|
||||
router.register(r"workers", WorkerNodeViewSet, basename="worker")
|
||||
router.register(r"wordlists", WordlistViewSet, basename="wordlist")
|
||||
router.register(r"nuclei/repos", NucleiTemplateRepoViewSet, basename="nuclei-repos")
|
||||
# 指纹管理
|
||||
router.register(r"fingerprints/ehole", EholeFingerprintViewSet, basename="ehole-fingerprint")
|
||||
router.register(r"fingerprints/goby", GobyFingerprintViewSet, basename="goby-fingerprint")
|
||||
router.register(r"fingerprints/wappalyzer", WappalyzerFingerprintViewSet, basename="wappalyzer-fingerprint")
|
||||
router.register(r"fingerprints/fingers", FingersFingerprintViewSet, basename="fingers-fingerprint")
|
||||
router.register(r"fingerprints/fingerprinthub", FingerPrintHubFingerprintViewSet, basename="fingerprinthub-fingerprint")
|
||||
router.register(r"fingerprints/arl", ARLFingerprintViewSet, basename="arl-fingerprint")
|
||||
|
||||
urlpatterns = [
|
||||
path("", include(router.urls)),
|
||||
|
||||
22
backend/apps/engine/views/fingerprints/__init__.py
Normal file
22
backend/apps/engine/views/fingerprints/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""指纹管理 ViewSets
|
||||
|
||||
导出所有指纹相关的 ViewSet 类
|
||||
"""
|
||||
|
||||
from .base import BaseFingerprintViewSet
|
||||
from .ehole import EholeFingerprintViewSet
|
||||
from .goby import GobyFingerprintViewSet
|
||||
from .wappalyzer import WappalyzerFingerprintViewSet
|
||||
from .fingers import FingersFingerprintViewSet
|
||||
from .fingerprinthub import FingerPrintHubFingerprintViewSet
|
||||
from .arl import ARLFingerprintViewSet
|
||||
|
||||
__all__ = [
|
||||
"BaseFingerprintViewSet",
|
||||
"EholeFingerprintViewSet",
|
||||
"GobyFingerprintViewSet",
|
||||
"WappalyzerFingerprintViewSet",
|
||||
"FingersFingerprintViewSet",
|
||||
"FingerPrintHubFingerprintViewSet",
|
||||
"ARLFingerprintViewSet",
|
||||
]
|
||||
122
backend/apps/engine/views/fingerprints/arl.py
Normal file
122
backend/apps/engine/views/fingerprints/arl.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""ARL 指纹管理 ViewSet"""
|
||||
|
||||
import yaml
|
||||
from django.http import HttpResponse
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import ValidationError
|
||||
|
||||
from apps.common.pagination import BasePagination
|
||||
from apps.common.response_helpers import success_response
|
||||
from apps.engine.models import ARLFingerprint
|
||||
from apps.engine.serializers.fingerprints import ARLFingerprintSerializer
|
||||
from apps.engine.services.fingerprints import ARLFingerprintService
|
||||
|
||||
from .base import BaseFingerprintViewSet
|
||||
|
||||
|
||||
class ARLFingerprintViewSet(BaseFingerprintViewSet):
    """ViewSet for managing ARL fingerprints.

    Extends ``BaseFingerprintViewSet`` and exposes:

    Standard CRUD (ModelViewSet):
    - GET    /              paginated list
    - POST   /              create one record
    - GET    /{id}/         retrieve
    - PUT    /{id}/         update
    - DELETE /{id}/         delete

    Bulk operations:
    - POST /batch_create/   bulk create (JSON body)
    - POST /import_file/    file import (multipart/form-data, YAML or JSON)
    - POST /bulk-delete/    bulk delete
    - POST /delete-all/     delete everything
    - GET  /export/         export download (YAML)

    Smart filter syntax (``filter`` query parameter):
    - name="word"           fuzzy match on name
    - name=="WordPress"     exact match
    - rule="body="          filter by rule content
    """

    queryset = ARLFingerprint.objects.all()
    serializer_class = ARLFingerprintSerializer
    pagination_class = BasePagination
    service_class = ARLFingerprintService

    # Ordering configuration
    ordering_fields = ['created_at', 'name']
    ordering = ['-created_at']

    # Filter-name -> model-field mapping for ARL
    FILTER_FIELD_MAPPING = {
        'name': 'name',
        'rule': 'rule',
    }

    def parse_import_data(self, json_data) -> list:
        """Return the fingerprint list from already-parsed import data.

        ARL data is a plain array ``[{...}, {...}]``; any other top-level
        shape yields an empty list.
        """
        if isinstance(json_data, list):
            return json_data
        return []

    def get_export_filename(self) -> str:
        """Return the download file name used by ``export``."""
        return 'ARL.yaml'

    @action(detail=False, methods=['post'])
    def import_file(self, request):
        """Import fingerprints from an uploaded YAML or JSON file.

        POST /api/engine/fingerprints/arl/import_file/

        Request: multipart/form-data with a ``file`` part. Names ending in
        ``.yaml`` / ``.yml`` are parsed as YAML, everything else as JSON.

        Returns:
            The same payload as ``batch_create``.

        Raises:
            ValidationError: if the file is missing, cannot be decoded or
                parsed, is not a top-level array, or is empty.
        """
        # Bound unconditionally at the top of the function: the original
        # imported json only inside the JSON branch, so on the YAML path the
        # ``except`` tuple below raised UnboundLocalError instead of
        # reporting the parse error.
        import json

        file = request.FILES.get('file')
        if not file:
            raise ValidationError('缺少文件')

        filename = file.name.lower()

        try:
            # Decoding is inside the try so a non-UTF-8 upload is reported
            # as a validation error rather than an unhandled exception.
            content = file.read().decode('utf-8')
            if filename.endswith(('.yaml', '.yml')):
                fingerprints = yaml.safe_load(content)
            else:
                fingerprints = json.loads(content)
        except (UnicodeDecodeError, yaml.YAMLError, json.JSONDecodeError) as e:
            raise ValidationError(f'无效的文件格式: {e}') from e

        if not isinstance(fingerprints, list):
            raise ValidationError('文件内容必须是数组格式')

        if not fingerprints:
            raise ValidationError('文件中没有有效的指纹数据')

        result = self.get_service().batch_create_fingerprints(fingerprints)
        return success_response(data=result)

    @action(detail=False, methods=['get'])
    def export(self, request):
        """Export all fingerprints as a YAML file download.

        GET /api/engine/fingerprints/arl/export/
        """
        data = self.get_service().get_export_data()
        content = yaml.dump(data, allow_unicode=True, default_flow_style=False, sort_keys=False)
        response = HttpResponse(content, content_type='application/x-yaml')
        response['Content-Disposition'] = f'attachment; filename="{self.get_export_filename()}"'
        return response
|
||||
203
backend/apps/engine/views/fingerprints/base.py
Normal file
203
backend/apps/engine/views/fingerprints/base.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""指纹管理基类 ViewSet
|
||||
|
||||
提供通用的 CRUD 和批量操作,供 EHole/Goby/Wappalyzer 等子类继承
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from django.http import HttpResponse
|
||||
from rest_framework import viewsets, status, filters
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.exceptions import ValidationError
|
||||
|
||||
from apps.common.pagination import BasePagination
|
||||
from apps.common.response_helpers import success_response
|
||||
from apps.common.utils.filter_utils import apply_filters
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseFingerprintViewSet(viewsets.ModelViewSet):
    """Base ViewSet for fingerprint management (EHole/Goby/Wappalyzer, ...).

    APIs provided:

    Standard CRUD (inherited from ModelViewSet):
    - GET    /              list (pagination + smart filtering)
    - POST   /              create one record
    - GET    /{id}/         retrieve
    - PUT    /{id}/         update
    - DELETE /{id}/         delete

    Bulk operations (implemented here):
    - POST /batch_create/   bulk create (JSON body)
    - POST /import_file/    file import (multipart/form-data, suited to
                            large 10MB+ files)
    - POST /bulk-delete/    bulk delete
    - POST /delete-all/     delete everything
    - GET  /export/         export download

    Smart filter syntax (``filter`` query parameter):
    - field="value"         fuzzy match (contains)
    - field=="value"        exact match
    - space-separated conditions are combined with AND
    - ``||`` or ``or`` combine conditions with OR

    Subclasses must provide:
    - service_class         the Service class
    - parse_import_data     how to extract the list from imported data
    - get_export_filename   the export file name
    """

    pagination_class = BasePagination
    filter_backends = [filters.OrderingFilter]
    ordering = ['-created_at']

    # Must be set by subclasses.
    service_class = None  # Service class

    # Smart-filter field mapping; subclasses must override.
    FILTER_FIELD_MAPPING = {}

    # JSON array fields (queried with __contains); subclasses may override.
    JSON_ARRAY_FIELDS = []

    def get_queryset(self):
        """Return the base queryset, narrowed by the ``filter`` parameter if given."""
        queryset = super().get_queryset()
        filter_query = self.request.query_params.get('filter', None)
        if filter_query:
            queryset = apply_filters(
                queryset,
                filter_query,
                self.FILTER_FIELD_MAPPING,
                json_array_fields=getattr(self, 'JSON_ARRAY_FIELDS', [])
            )
        return queryset

    def get_service(self):
        """Instantiate and return ``service_class``.

        Raises:
            NotImplementedError: if the subclass did not set ``service_class``.
        """
        if self.service_class is None:
            raise NotImplementedError("子类必须指定 service_class")
        return self.service_class()

    def parse_import_data(self, json_data: dict) -> list:
        """Extract the fingerprint list from parsed import data.

        Must be implemented by subclasses.

        Args:
            json_data: The parsed JSON document.

        Returns:
            list: Fingerprint entries.
        """
        raise NotImplementedError("子类必须实现 parse_import_data 方法")

    def get_export_filename(self) -> str:
        """Return the export file name. Must be implemented by subclasses."""
        raise NotImplementedError("子类必须实现 get_export_filename 方法")

    @action(detail=False, methods=['post'])
    def batch_create(self, request):
        """Bulk-create fingerprint rules.

        POST /api/engine/fingerprints/{type}/batch_create/

        Request body:
        {
            "fingerprints": [
                {"cms": "WordPress", "method": "keyword", ...},
                ...
            ]
        }

        Response data:
        {
            "created": 2,
            "failed": 0
        }

        Raises:
            ValidationError: if ``fingerprints`` is empty or not an array.
        """
        fingerprints = request.data.get('fingerprints', [])
        if not fingerprints:
            raise ValidationError('fingerprints 不能为空')
        if not isinstance(fingerprints, list):
            raise ValidationError('fingerprints 必须是数组')

        result = self.get_service().batch_create_fingerprints(fingerprints)
        return success_response(data=result, status_code=status.HTTP_201_CREATED)

    @action(detail=False, methods=['post'])
    def import_file(self, request):
        """Import fingerprints from an uploaded JSON file (suited to 10MB+ files).

        POST /api/engine/fingerprints/{type}/import_file/

        Request: multipart/form-data with a ``file`` part containing JSON.

        Returns:
            The same payload as ``batch_create``.

        Raises:
            ValidationError: if the file is missing, is not valid JSON, or
                yields no fingerprints.
        """
        file = request.FILES.get('file')
        if not file:
            raise ValidationError('缺少文件')

        try:
            json_data = json.load(file)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            # json.load on byte input raises UnicodeDecodeError (not
            # JSONDecodeError) for wrongly encoded data; report both as a
            # client error instead of letting one escape as a 500.
            raise ValidationError(f'无效的 JSON 格式: {e}')

        fingerprints = self.parse_import_data(json_data)
        if not fingerprints:
            raise ValidationError('文件中没有有效的指纹数据')

        result = self.get_service().batch_create_fingerprints(fingerprints)
        return success_response(data=result, status_code=status.HTTP_201_CREATED)

    @action(detail=False, methods=['post'], url_path='bulk-delete')
    def bulk_delete(self, request):
        """Delete the records whose ids are listed in the request.

        POST /api/engine/fingerprints/{type}/bulk-delete/

        Request body: {"ids": [1, 2, 3]}
        Response data: {"deleted": 3}

        Raises:
            ValidationError: if ``ids`` is empty or not an array.
        """
        ids = request.data.get('ids', [])
        if not ids:
            raise ValidationError('ids 不能为空')
        if not isinstance(ids, list):
            raise ValidationError('ids 必须是数组')

        # QuerySet.delete() returns (total_deleted, per_model_counts).
        deleted_count = self.queryset.model.objects.filter(id__in=ids).delete()[0]
        return success_response(data={'deleted': deleted_count})

    @action(detail=False, methods=['post'], url_path='delete-all')
    def delete_all(self, request):
        """Delete every fingerprint of this type.

        POST /api/engine/fingerprints/{type}/delete-all/

        Response data: {"deleted": 1000}
        """
        deleted_count = self.queryset.model.objects.all().delete()[0]
        return success_response(data={'deleted': deleted_count})

    @action(detail=False, methods=['get'])
    def export(self, request):
        """Export fingerprints as a JSON file download.

        GET /api/engine/fingerprints/{type}/export/
        """
        data = self.get_service().get_export_data()
        content = json.dumps(data, ensure_ascii=False, indent=2)
        response = HttpResponse(content, content_type='application/json')
        response['Content-Disposition'] = f'attachment; filename="{self.get_export_filename()}"'
        return response
|
||||
67
backend/apps/engine/views/fingerprints/ehole.py
Normal file
67
backend/apps/engine/views/fingerprints/ehole.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""EHole 指纹管理 ViewSet"""
|
||||
|
||||
from apps.common.pagination import BasePagination
|
||||
from apps.engine.models import EholeFingerprint
|
||||
from apps.engine.serializers.fingerprints import EholeFingerprintSerializer
|
||||
from apps.engine.services.fingerprints import EholeFingerprintService
|
||||
|
||||
from .base import BaseFingerprintViewSet
|
||||
|
||||
|
||||
class EholeFingerprintViewSet(BaseFingerprintViewSet):
    """ViewSet for managing EHole fingerprints.

    Extends ``BaseFingerprintViewSet`` and exposes:

    Standard CRUD (ModelViewSet):
    - GET    /              paginated list
    - POST   /              create one record
    - GET    /{id}/         retrieve
    - PUT    /{id}/         update
    - DELETE /{id}/         delete

    Bulk operations (inherited from the base class):
    - POST /batch_create/   bulk create (JSON body)
    - POST /import_file/    file import (multipart/form-data)
    - POST /bulk-delete/    bulk delete
    - POST /delete-all/     delete everything
    - GET  /export/         export download

    Smart filter syntax (``filter`` query parameter):
    - cms="word"            fuzzy match on cms
    - cms=="WordPress"      exact match
    - type="CMS"            filter by type
    - method="keyword"      filter by match method
    - location="body"       filter by match location
    """

    queryset = EholeFingerprint.objects.all()
    serializer_class = EholeFingerprintSerializer
    pagination_class = BasePagination
    service_class = EholeFingerprintService

    # Ordering configuration
    ordering_fields = ['created_at', 'cms']
    ordering = ['-created_at']

    # Filter-name -> model-field mapping for EHole
    FILTER_FIELD_MAPPING = {
        'cms': 'cms',
        'method': 'method',
        'location': 'location',
        'type': 'type',
        'isImportant': 'is_important',
    }

    def parse_import_data(self, json_data: dict) -> list:
        """Extract the fingerprint list from an EHole JSON document.

        Expected input: ``{"fingerprint": [...]}``.

        Returns:
            list: The value of ``fingerprint``. An empty list is returned
            when the document is not an object or the value is not an
            array, so a wrong-format upload is rejected as "no valid data"
            by the caller instead of raising ``AttributeError``. This
            matches the array-format sibling viewsets.
        """
        if not isinstance(json_data, dict):
            return []
        fingerprints = json_data.get('fingerprint', [])
        return fingerprints if isinstance(fingerprints, list) else []

    def get_export_filename(self) -> str:
        """Return the download file name used by ``export``."""
        return 'ehole.json'
|
||||
73
backend/apps/engine/views/fingerprints/fingerprinthub.py
Normal file
73
backend/apps/engine/views/fingerprints/fingerprinthub.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""FingerPrintHub 指纹管理 ViewSet"""
|
||||
|
||||
from apps.common.pagination import BasePagination
|
||||
from apps.engine.models import FingerPrintHubFingerprint
|
||||
from apps.engine.serializers.fingerprints import FingerPrintHubFingerprintSerializer
|
||||
from apps.engine.services.fingerprints import FingerPrintHubFingerprintService
|
||||
|
||||
from .base import BaseFingerprintViewSet
|
||||
|
||||
|
||||
class FingerPrintHubFingerprintViewSet(BaseFingerprintViewSet):
    """ViewSet for managing FingerPrintHub fingerprints.

    Extends ``BaseFingerprintViewSet`` and exposes:

    Standard CRUD (ModelViewSet):
    - GET    /              paginated list
    - POST   /              create one record
    - GET    /{id}/         retrieve
    - PUT    /{id}/         update
    - DELETE /{id}/         delete

    Bulk operations (inherited from the base class):
    - POST /batch_create/   bulk create (JSON body)
    - POST /import_file/    file import (multipart/form-data)
    - POST /bulk-delete/    bulk delete
    - POST /delete-all/     delete everything
    - GET  /export/         export download

    Smart filter syntax (``filter`` query parameter):
    - name="word"           fuzzy match on name
    - fp_id=="xxx"          exact match on fingerprint id
    - author="xxx"          filter by author
    - severity="info"       filter by severity
    - tags="cms"            filter by tag
    """

    service_class = FingerPrintHubFingerprintService
    serializer_class = FingerPrintHubFingerprintSerializer
    queryset = FingerPrintHubFingerprint.objects.all()
    pagination_class = BasePagination

    # Ordering configuration
    ordering = ['-created_at']
    ordering_fields = ['created_at', 'name', 'severity']

    # Every filter name maps to the model field of the same name.
    FILTER_FIELD_MAPPING = {
        field: field
        for field in ('fp_id', 'name', 'author', 'tags', 'severity', 'source_file')
    }

    # JSON array fields (queried with __contains)
    JSON_ARRAY_FIELDS = ['http']

    def parse_import_data(self, json_data) -> list:
        """Return the fingerprint list from a FingerPrintHub JSON document.

        The expected document is a plain array ``[{...}, {...}]``; any other
        top-level shape yields an empty list.
        """
        return json_data if isinstance(json_data, list) else []

    def get_export_filename(self) -> str:
        """Return the download file name used by ``export``."""
        return 'fingerprinthub_web.json'
|
||||
69
backend/apps/engine/views/fingerprints/fingers.py
Normal file
69
backend/apps/engine/views/fingerprints/fingers.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Fingers 指纹管理 ViewSet"""
|
||||
|
||||
from apps.common.pagination import BasePagination
|
||||
from apps.engine.models import FingersFingerprint
|
||||
from apps.engine.serializers.fingerprints import FingersFingerprintSerializer
|
||||
from apps.engine.services.fingerprints import FingersFingerprintService
|
||||
|
||||
from .base import BaseFingerprintViewSet
|
||||
|
||||
|
||||
class FingersFingerprintViewSet(BaseFingerprintViewSet):
    """ViewSet for managing Fingers fingerprints.

    Extends ``BaseFingerprintViewSet`` and exposes:

    Standard CRUD (ModelViewSet):
    - GET    /              paginated list
    - POST   /              create one record
    - GET    /{id}/         retrieve
    - PUT    /{id}/         update
    - DELETE /{id}/         delete

    Bulk operations (inherited from the base class):
    - POST /batch_create/   bulk create (JSON body)
    - POST /import_file/    file import (multipart/form-data)
    - POST /bulk-delete/    bulk delete
    - POST /delete-all/     delete everything
    - GET  /export/         export download

    Smart filter syntax (``filter`` query parameter):
    - name="word"           fuzzy match on name
    - name=="WordPress"     exact match
    - tag="cms"             filter by tag
    - focus="true"          filter by the focus flag
    """

    service_class = FingersFingerprintService
    serializer_class = FingersFingerprintSerializer
    queryset = FingersFingerprint.objects.all()
    pagination_class = BasePagination

    # Ordering configuration
    ordering = ['-created_at']
    ordering_fields = ['created_at', 'name']

    # Every filter name maps to the model field of the same name.
    # NOTE(review): the docstring advertises a ``tag`` filter, but ``tag``
    # appears only in JSON_ARRAY_FIELDS, not here. Confirm against
    # apply_filters whether it is actually reachable.
    FILTER_FIELD_MAPPING = {field: field for field in ('name', 'link', 'focus')}

    # JSON array fields (queried with __contains)
    JSON_ARRAY_FIELDS = ['tag', 'rule', 'default_port']

    def parse_import_data(self, json_data) -> list:
        """Return the fingerprint list from a Fingers JSON document.

        The expected document is a plain array ``[{...}, {...}]``; any other
        top-level shape yields an empty list.
        """
        return json_data if isinstance(json_data, list) else []

    def get_export_filename(self) -> str:
        """Return the download file name used by ``export``."""
        return 'fingers_http.json'
|
||||
65
backend/apps/engine/views/fingerprints/goby.py
Normal file
65
backend/apps/engine/views/fingerprints/goby.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Goby 指纹管理 ViewSet"""
|
||||
|
||||
from apps.common.pagination import BasePagination
|
||||
from apps.engine.models import GobyFingerprint
|
||||
from apps.engine.serializers.fingerprints import GobyFingerprintSerializer
|
||||
from apps.engine.services.fingerprints import GobyFingerprintService
|
||||
|
||||
from .base import BaseFingerprintViewSet
|
||||
|
||||
|
||||
class GobyFingerprintViewSet(BaseFingerprintViewSet):
    """ViewSet for managing Goby fingerprints.

    Extends ``BaseFingerprintViewSet`` and exposes:

    Standard CRUD (ModelViewSet):
    - GET    /              paginated list
    - POST   /              create one record
    - GET    /{id}/         retrieve
    - PUT    /{id}/         update
    - DELETE /{id}/         delete

    Bulk operations (inherited from the base class):
    - POST /batch_create/   bulk create (JSON body)
    - POST /import_file/    file import (multipart/form-data)
    - POST /bulk-delete/    bulk delete
    - POST /delete-all/     delete everything
    - GET  /export/         export download

    Smart filter syntax (``filter`` query parameter):
    - name="word"           fuzzy match on name
    - name=="ProductName"   exact match
    """

    service_class = GobyFingerprintService
    serializer_class = GobyFingerprintSerializer
    queryset = GobyFingerprint.objects.all()
    pagination_class = BasePagination

    # Ordering configuration
    ordering = ['-created_at']
    ordering_fields = ['created_at', 'name']

    # Every filter name maps to the model field of the same name.
    FILTER_FIELD_MAPPING = {field: field for field in ('name', 'logic')}

    def parse_import_data(self, json_data) -> list:
        """Return the fingerprint list from a Goby JSON document.

        Goby data is a plain array:
        ``[{"name": "...", "logic": "...", "rule": [...]}, ...]``.
        Any other top-level shape yields an empty list.
        """
        return json_data if isinstance(json_data, list) else []

    def get_export_filename(self) -> str:
        """Return the download file name used by ``export``."""
        return 'goby.json'
|
||||
75
backend/apps/engine/views/fingerprints/wappalyzer.py
Normal file
75
backend/apps/engine/views/fingerprints/wappalyzer.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Wappalyzer 指纹管理 ViewSet"""
|
||||
|
||||
from apps.common.pagination import BasePagination
|
||||
from apps.engine.models import WappalyzerFingerprint
|
||||
from apps.engine.serializers.fingerprints import WappalyzerFingerprintSerializer
|
||||
from apps.engine.services.fingerprints import WappalyzerFingerprintService
|
||||
|
||||
from .base import BaseFingerprintViewSet
|
||||
|
||||
|
||||
class WappalyzerFingerprintViewSet(BaseFingerprintViewSet):
    """ViewSet for managing Wappalyzer fingerprints.

    Extends ``BaseFingerprintViewSet`` and exposes:

    Standard CRUD (ModelViewSet):
    - GET    /              paginated list
    - POST   /              create one record
    - GET    /{id}/         retrieve
    - PUT    /{id}/         update
    - DELETE /{id}/         delete

    Bulk operations (inherited from the base class):
    - POST /batch_create/   bulk create (JSON body)
    - POST /import_file/    file import (multipart/form-data)
    - POST /bulk-delete/    bulk delete
    - POST /delete-all/     delete everything
    - GET  /export/         export download

    Smart filter syntax (``filter`` query parameter):
    - name="word"           fuzzy match on name
    - name=="AppName"       exact match
    """

    queryset = WappalyzerFingerprint.objects.all()
    serializer_class = WappalyzerFingerprintSerializer
    pagination_class = BasePagination
    service_class = WappalyzerFingerprintService

    # Ordering configuration
    ordering_fields = ['created_at', 'name']
    ordering = ['-created_at']

    # Filter-name -> model-field mapping for Wappalyzer.
    # Note: implies is a JSON array field and is queried with __contains.
    FILTER_FIELD_MAPPING = {
        'name': 'name',
        'description': 'description',
        'website': 'website',
        'cpe': 'cpe',
        'implies': 'implies',  # JSON array field
    }

    # JSON array fields (queried with __contains)
    JSON_ARRAY_FIELDS = ['implies']

    def parse_import_data(self, json_data: dict) -> list:
        """Flatten a Wappalyzer JSON document into a list of fingerprints.

        Expected input:
        ``{"apps": {"1C-Bitrix": {"cats": [...], ...}, ...}}``

        Returns:
            list: One dict per app, consisting of ``name`` (the app's key)
            merged with the app's own fields. An empty list is returned
            when the document or its ``apps`` value is not a JSON object,
            and entries whose value is not an object are skipped. This way
            a wrong-format upload is rejected by the caller as "no valid
            data" instead of raising ``AttributeError`` / ``TypeError``.
        """
        if not isinstance(json_data, dict):
            return []
        apps = json_data.get('apps', {})
        if not isinstance(apps, dict):
            return []
        return [
            {'name': name, **data}
            for name, data in apps.items()
            if isinstance(data, dict)
        ]

    def get_export_filename(self) -> str:
        """Return the download file name used by ``export``."""
        return 'wappalyzer.json'
|
||||
@@ -31,6 +31,8 @@ from rest_framework.decorators import action
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
|
||||
from apps.common.response_helpers import success_response, error_response
|
||||
from apps.common.error_codes import ErrorCodes
|
||||
from apps.engine.models import NucleiTemplateRepo
|
||||
from apps.engine.serializers import NucleiTemplateRepoSerializer
|
||||
from apps.engine.services import NucleiTemplateRepoService
|
||||
@@ -107,18 +109,30 @@ class NucleiTemplateRepoViewSet(viewsets.ModelViewSet):
|
||||
try:
|
||||
repo_id = int(pk) if pk is not None else None
|
||||
except (TypeError, ValueError):
|
||||
return Response({"message": "无效的仓库 ID"}, status=status.HTTP_400_BAD_REQUEST)
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='Invalid repository ID',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# 调用 Service 层
|
||||
try:
|
||||
result = self.service.refresh_repo(repo_id)
|
||||
except ValidationError as exc:
|
||||
return Response({"message": str(exc)}, status=status.HTTP_400_BAD_REQUEST)
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message=str(exc),
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("刷新 Nuclei 模板仓库失败: %s", exc, exc_info=True)
|
||||
return Response({"message": f"刷新仓库失败: {exc}"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
return error_response(
|
||||
code=ErrorCodes.SERVER_ERROR,
|
||||
message=f'Refresh failed: {exc}',
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
return Response({"message": "刷新成功", "result": result}, status=status.HTTP_200_OK)
|
||||
return success_response(data={'result': result})
|
||||
|
||||
# ==================== 自定义 Action: 模板只读浏览 ====================
|
||||
|
||||
@@ -142,18 +156,30 @@ class NucleiTemplateRepoViewSet(viewsets.ModelViewSet):
|
||||
try:
|
||||
repo_id = int(pk) if pk is not None else None
|
||||
except (TypeError, ValueError):
|
||||
return Response({"message": "无效的仓库 ID"}, status=status.HTTP_400_BAD_REQUEST)
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='Invalid repository ID',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# 调用 Service 层,仅从当前本地目录读取目录树
|
||||
try:
|
||||
roots = self.service.get_template_tree(repo_id)
|
||||
except ValidationError as exc:
|
||||
return Response({"message": str(exc)}, status=status.HTTP_400_BAD_REQUEST)
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message=str(exc),
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("获取 Nuclei 模板目录树失败: %s", exc, exc_info=True)
|
||||
return Response({"message": "获取模板目录树失败"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
return error_response(
|
||||
code=ErrorCodes.SERVER_ERROR,
|
||||
message='Failed to get template tree',
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
return Response({"roots": roots})
|
||||
return success_response(data={'roots': roots})
|
||||
|
||||
@action(detail=True, methods=["get"], url_path="templates/content")
|
||||
def templates_content(self, request: Request, pk: str | None = None) -> Response:
|
||||
@@ -174,23 +200,43 @@ class NucleiTemplateRepoViewSet(viewsets.ModelViewSet):
|
||||
try:
|
||||
repo_id = int(pk) if pk is not None else None
|
||||
except (TypeError, ValueError):
|
||||
return Response({"message": "无效的仓库 ID"}, status=status.HTTP_400_BAD_REQUEST)
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='Invalid repository ID',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# 解析 path 参数
|
||||
rel_path = (request.query_params.get("path", "") or "").strip()
|
||||
if not rel_path:
|
||||
return Response({"message": "缺少 path 参数"}, status=status.HTTP_400_BAD_REQUEST)
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='Missing path parameter',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# 调用 Service 层
|
||||
try:
|
||||
result = self.service.get_template_content(repo_id, rel_path)
|
||||
except ValidationError as exc:
|
||||
return Response({"message": str(exc)}, status=status.HTTP_400_BAD_REQUEST)
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message=str(exc),
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("获取 Nuclei 模板内容失败: %s", exc, exc_info=True)
|
||||
return Response({"message": "获取模板内容失败"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
return error_response(
|
||||
code=ErrorCodes.SERVER_ERROR,
|
||||
message='Failed to get template content',
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
# 文件不存在
|
||||
if result is None:
|
||||
return Response({"message": "模板不存在或无法读取"}, status=status.HTTP_404_NOT_FOUND)
|
||||
return Response(result)
|
||||
return error_response(
|
||||
code=ErrorCodes.NOT_FOUND,
|
||||
message='Template not found or unreadable',
|
||||
status_code=status.HTTP_404_NOT_FOUND
|
||||
)
|
||||
return success_response(data=result)
|
||||
|
||||
@@ -9,6 +9,8 @@ from rest_framework.decorators import action
|
||||
from rest_framework.response import Response
|
||||
|
||||
from apps.common.pagination import BasePagination
|
||||
from apps.common.response_helpers import success_response, error_response
|
||||
from apps.common.error_codes import ErrorCodes
|
||||
from apps.engine.serializers.wordlist_serializer import WordlistSerializer
|
||||
from apps.engine.services.wordlist_service import WordlistService
|
||||
|
||||
@@ -46,7 +48,11 @@ class WordlistViewSet(viewsets.ViewSet):
|
||||
uploaded_file = request.FILES.get("file")
|
||||
|
||||
if not uploaded_file:
|
||||
return Response({"error": "缺少字典文件"}, status=status.HTTP_400_BAD_REQUEST)
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='Missing wordlist file',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
try:
|
||||
wordlist = self.service.create_wordlist(
|
||||
@@ -55,21 +61,32 @@ class WordlistViewSet(viewsets.ViewSet):
|
||||
uploaded_file=uploaded_file,
|
||||
)
|
||||
except ValidationError as exc:
|
||||
return Response({"error": str(exc)}, status=status.HTTP_400_BAD_REQUEST)
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message=str(exc),
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
serializer = WordlistSerializer(wordlist)
|
||||
return Response(serializer.data, status=status.HTTP_201_CREATED)
|
||||
return success_response(data=serializer.data, status_code=status.HTTP_201_CREATED)
|
||||
|
||||
def destroy(self, request, pk=None):
|
||||
"""删除字典记录"""
|
||||
try:
|
||||
wordlist_id = int(pk)
|
||||
except (TypeError, ValueError):
|
||||
return Response({"error": "无效的 ID"}, status=status.HTTP_400_BAD_REQUEST)
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='Invalid ID',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
success = self.service.delete_wordlist(wordlist_id)
|
||||
if not success:
|
||||
return Response({"error": "字典不存在"}, status=status.HTTP_404_NOT_FOUND)
|
||||
return error_response(
|
||||
code=ErrorCodes.NOT_FOUND,
|
||||
status_code=status.HTTP_404_NOT_FOUND
|
||||
)
|
||||
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
@@ -82,15 +99,27 @@ class WordlistViewSet(viewsets.ViewSet):
|
||||
"""
|
||||
name = (request.query_params.get("wordlist", "") or "").strip()
|
||||
if not name:
|
||||
return Response({"error": "缺少参数 wordlist"}, status=status.HTTP_400_BAD_REQUEST)
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='Missing parameter: wordlist',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
wordlist = self.service.get_wordlist_by_name(name)
|
||||
if not wordlist:
|
||||
return Response({"error": "字典不存在"}, status=status.HTTP_404_NOT_FOUND)
|
||||
return error_response(
|
||||
code=ErrorCodes.NOT_FOUND,
|
||||
message='Wordlist not found',
|
||||
status_code=status.HTTP_404_NOT_FOUND
|
||||
)
|
||||
|
||||
file_path = wordlist.file_path
|
||||
if not file_path or not os.path.exists(file_path):
|
||||
return Response({"error": "字典文件不存在"}, status=status.HTTP_404_NOT_FOUND)
|
||||
return error_response(
|
||||
code=ErrorCodes.NOT_FOUND,
|
||||
message='Wordlist file not found',
|
||||
status_code=status.HTTP_404_NOT_FOUND
|
||||
)
|
||||
|
||||
filename = os.path.basename(file_path)
|
||||
response = FileResponse(open(file_path, "rb"), as_attachment=True, filename=filename)
|
||||
@@ -106,22 +135,38 @@ class WordlistViewSet(viewsets.ViewSet):
|
||||
try:
|
||||
wordlist_id = int(pk)
|
||||
except (TypeError, ValueError):
|
||||
return Response({"error": "无效的 ID"}, status=status.HTTP_400_BAD_REQUEST)
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='Invalid ID',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
if request.method == "GET":
|
||||
content = self.service.get_wordlist_content(wordlist_id)
|
||||
if content is None:
|
||||
return Response({"error": "字典不存在或文件无法读取"}, status=status.HTTP_404_NOT_FOUND)
|
||||
return Response({"content": content})
|
||||
return error_response(
|
||||
code=ErrorCodes.NOT_FOUND,
|
||||
message='Wordlist not found or file unreadable',
|
||||
status_code=status.HTTP_404_NOT_FOUND
|
||||
)
|
||||
return success_response(data={"content": content})
|
||||
|
||||
elif request.method == "PUT":
|
||||
content = request.data.get("content")
|
||||
if content is None:
|
||||
return Response({"error": "缺少 content 参数"}, status=status.HTTP_400_BAD_REQUEST)
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='Missing content parameter',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
wordlist = self.service.update_wordlist_content(wordlist_id, content)
|
||||
if not wordlist:
|
||||
return Response({"error": "字典不存在或更新失败"}, status=status.HTTP_404_NOT_FOUND)
|
||||
return error_response(
|
||||
code=ErrorCodes.NOT_FOUND,
|
||||
message='Wordlist not found or update failed',
|
||||
status_code=status.HTTP_404_NOT_FOUND
|
||||
)
|
||||
|
||||
serializer = WordlistSerializer(wordlist)
|
||||
return Response(serializer.data)
|
||||
return success_response(data=serializer.data)
|
||||
|
||||
@@ -9,6 +9,8 @@ from rest_framework import viewsets, status
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.response import Response
|
||||
|
||||
from apps.common.response_helpers import success_response, error_response
|
||||
from apps.common.error_codes import ErrorCodes
|
||||
from apps.engine.serializers import WorkerNodeSerializer
|
||||
from apps.engine.services import WorkerService
|
||||
from apps.common.signals import worker_delete_failed
|
||||
@@ -111,9 +113,8 @@ class WorkerNodeViewSet(viewsets.ModelViewSet):
|
||||
threading.Thread(target=_async_remote_uninstall, daemon=True).start()
|
||||
|
||||
# 3. 立即返回成功
|
||||
return Response(
|
||||
{"message": f"节点 {worker_name} 已删除"},
|
||||
status=status.HTTP_200_OK
|
||||
return success_response(
|
||||
data={'name': worker_name}
|
||||
)
|
||||
|
||||
@action(detail=True, methods=['post'])
|
||||
@@ -190,11 +191,13 @@ class WorkerNodeViewSet(viewsets.ModelViewSet):
|
||||
worker.status = 'online'
|
||||
worker.save(update_fields=['status'])
|
||||
|
||||
return Response({
|
||||
'status': 'ok',
|
||||
'need_update': need_update,
|
||||
'server_version': server_version
|
||||
})
|
||||
return success_response(
|
||||
data={
|
||||
'status': 'ok',
|
||||
'needUpdate': need_update,
|
||||
'serverVersion': server_version
|
||||
}
|
||||
)
|
||||
|
||||
def _trigger_remote_agent_update(self, worker, target_version: str):
|
||||
"""
|
||||
@@ -238,7 +241,7 @@ class WorkerNodeViewSet(viewsets.ModelViewSet):
|
||||
docker run -d --pull=always \
|
||||
--name xingrin-agent \
|
||||
--restart always \
|
||||
-e HEARTBEAT_API_URL="https://{django_settings.PUBLIC_HOST}" \
|
||||
-e HEARTBEAT_API_URL="https://{django_settings.PUBLIC_HOST}:{getattr(django_settings, 'PUBLIC_PORT', '8083')}" \
|
||||
-e WORKER_ID="{worker_id}" \
|
||||
-e IMAGE_TAG="{target_version}" \
|
||||
-v /proc:/host/proc:ro \
|
||||
@@ -304,9 +307,10 @@ class WorkerNodeViewSet(viewsets.ModelViewSet):
|
||||
is_local = request.data.get('is_local', True)
|
||||
|
||||
if not name:
|
||||
return Response(
|
||||
{'error': '缺少 name 参数'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='Missing name parameter',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
worker, created = self.worker_service.register_worker(
|
||||
@@ -314,11 +318,13 @@ class WorkerNodeViewSet(viewsets.ModelViewSet):
|
||||
is_local=is_local
|
||||
)
|
||||
|
||||
return Response({
|
||||
'worker_id': worker.id,
|
||||
'name': worker.name,
|
||||
'created': created
|
||||
})
|
||||
return success_response(
|
||||
data={
|
||||
'workerId': worker.id,
|
||||
'name': worker.name,
|
||||
'created': created
|
||||
}
|
||||
)
|
||||
|
||||
@action(detail=False, methods=['get'])
|
||||
def config(self, request):
|
||||
@@ -334,13 +340,12 @@ class WorkerNodeViewSet(viewsets.ModelViewSet):
|
||||
返回:
|
||||
{
|
||||
"db": {"host": "...", "port": "...", ...},
|
||||
"redisUrl": "...",
|
||||
"paths": {"results": "...", "logs": "..."}
|
||||
}
|
||||
|
||||
配置逻辑:
|
||||
- 本地 Worker (is_local=true): db_host=postgres, redis=redis:6379
|
||||
- 远程 Worker (is_local=false): db_host=PUBLIC_HOST, redis=PUBLIC_HOST:6379
|
||||
- 本地 Worker (is_local=true): db_host=postgres
|
||||
- 远程 Worker (is_local=false): db_host=PUBLIC_HOST
|
||||
"""
|
||||
from django.conf import settings
|
||||
import logging
|
||||
@@ -365,37 +370,35 @@ class WorkerNodeViewSet(viewsets.ModelViewSet):
|
||||
if is_local_worker:
|
||||
# 本地 Worker:直接用 Docker 内部服务名
|
||||
worker_db_host = 'postgres'
|
||||
worker_redis_url = 'redis://redis:6379/0'
|
||||
else:
|
||||
# 远程 Worker:通过公网 IP 访问
|
||||
public_host = settings.PUBLIC_HOST
|
||||
if public_host in ('server', 'localhost', '127.0.0.1'):
|
||||
logger.warning("远程 Worker 请求配置,但 PUBLIC_HOST=%s 不是有效的公网地址", public_host)
|
||||
worker_db_host = public_host
|
||||
worker_redis_url = f'redis://{public_host}:6379/0'
|
||||
else:
|
||||
# 远程数据库场景:所有 Worker 都用 DB_HOST
|
||||
worker_db_host = db_host
|
||||
worker_redis_url = getattr(settings, 'WORKER_REDIS_URL', 'redis://redis:6379/0')
|
||||
|
||||
logger.info("返回 Worker 配置 - db_host: %s, redis_url: %s", worker_db_host, worker_redis_url)
|
||||
logger.info("返回 Worker 配置 - db_host: %s", worker_db_host)
|
||||
|
||||
return Response({
|
||||
'db': {
|
||||
'host': worker_db_host,
|
||||
'port': str(settings.DATABASES['default']['PORT']),
|
||||
'name': settings.DATABASES['default']['NAME'],
|
||||
'user': settings.DATABASES['default']['USER'],
|
||||
'password': settings.DATABASES['default']['PASSWORD'],
|
||||
},
|
||||
'redisUrl': worker_redis_url,
|
||||
'paths': {
|
||||
'results': getattr(settings, 'CONTAINER_RESULTS_MOUNT', '/app/backend/results'),
|
||||
'logs': getattr(settings, 'CONTAINER_LOGS_MOUNT', '/app/backend/logs'),
|
||||
},
|
||||
'logging': {
|
||||
'level': os.getenv('LOG_LEVEL', 'INFO'),
|
||||
'enableCommandLogging': os.getenv('ENABLE_COMMAND_LOGGING', 'true').lower() == 'true',
|
||||
},
|
||||
'debug': settings.DEBUG
|
||||
})
|
||||
return success_response(
|
||||
data={
|
||||
'db': {
|
||||
'host': worker_db_host,
|
||||
'port': str(settings.DATABASES['default']['PORT']),
|
||||
'name': settings.DATABASES['default']['NAME'],
|
||||
'user': settings.DATABASES['default']['USER'],
|
||||
'password': settings.DATABASES['default']['PASSWORD'],
|
||||
},
|
||||
'paths': {
|
||||
'results': getattr(settings, 'CONTAINER_RESULTS_MOUNT', '/opt/xingrin/results'),
|
||||
'logs': getattr(settings, 'CONTAINER_LOGS_MOUNT', '/opt/xingrin/logs'),
|
||||
},
|
||||
'logging': {
|
||||
'level': os.getenv('LOG_LEVEL', 'INFO'),
|
||||
'enableCommandLogging': os.getenv('ENABLE_COMMAND_LOGGING', 'true').lower() == 'true',
|
||||
},
|
||||
'debug': settings.DEBUG,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
from django.conf import settings
|
||||
|
||||
# ==================== 路径配置 ====================
|
||||
SCAN_TOOLS_BASE_PATH = getattr(settings, 'SCAN_TOOLS_BASE_PATH', '/opt/xingrin/tools')
|
||||
SCAN_TOOLS_BASE_PATH = getattr(settings, 'SCAN_TOOLS_BASE_PATH', '/usr/local/bin')
|
||||
|
||||
# ==================== 子域名发现 ====================
|
||||
|
||||
@@ -35,7 +35,7 @@ SUBDOMAIN_DISCOVERY_COMMANDS = {
|
||||
},
|
||||
|
||||
'sublist3r': {
|
||||
'base': "python3 '{scan_tools_base}/Sublist3r/sublist3r.py' -d {domain} -o '{output_file}'",
|
||||
'base': "python3 '/usr/local/share/Sublist3r/sublist3r.py' -d {domain} -o '{output_file}'",
|
||||
'optional': {
|
||||
'threads': '-t {threads}'
|
||||
}
|
||||
@@ -115,7 +115,7 @@ SITE_SCAN_COMMANDS = {
|
||||
|
||||
DIRECTORY_SCAN_COMMANDS = {
|
||||
'ffuf': {
|
||||
'base': "ffuf -u '{url}FUZZ' -se -ac -sf -json -w '{wordlist}'",
|
||||
'base': "'{scan_tools_base}/ffuf' -u '{url}FUZZ' -se -ac -sf -json -w '{wordlist}'",
|
||||
'optional': {
|
||||
'delay': '-p {delay}',
|
||||
'threads': '-t {threads}',
|
||||
@@ -225,12 +225,35 @@ VULN_SCAN_COMMANDS = {
|
||||
}
|
||||
|
||||
|
||||
# ==================== 指纹识别 ====================
|
||||
|
||||
FINGERPRINT_DETECT_COMMANDS = {
|
||||
'xingfinger': {
|
||||
# Streaming output mode (no -o; results are written to stdout)
|
||||
# -l: input file containing the URL list
|
||||
# -s: silent mode, print only matched results
|
||||
# --json: JSON output (one record per line)
|
||||
'base': "xingfinger -l '{urls_file}' -s --json",
|
||||
'optional': {
|
||||
# Custom fingerprint library paths
|
||||
'ehole': '--ehole {ehole}',
|
||||
|
||||
'goby': '--goby {goby}',
|
||||
'wappalyzer': '--wappalyzer {wappalyzer}',
|
||||
'fingers': '--fingers {fingers}',
|
||||
'fingerprinthub': '--fingerprint {fingerprinthub}',
|
||||
'arl': '--arl {arl}',
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ==================== 工具映射 ====================
|
||||
|
||||
COMMAND_TEMPLATES = {
|
||||
'subdomain_discovery': SUBDOMAIN_DISCOVERY_COMMANDS,
|
||||
'port_scan': PORT_SCAN_COMMANDS,
|
||||
'site_scan': SITE_SCAN_COMMANDS,
|
||||
'fingerprint_detect': FINGERPRINT_DETECT_COMMANDS,
|
||||
'directory_scan': DIRECTORY_SCAN_COMMANDS,
|
||||
'url_fetch': URL_FETCH_COMMANDS,
|
||||
'vuln_scan': VULN_SCAN_COMMANDS,
|
||||
@@ -242,7 +265,7 @@ COMMAND_TEMPLATES = {
|
||||
EXECUTION_STAGES = [
|
||||
{
|
||||
'mode': 'sequential',
|
||||
'flows': ['subdomain_discovery', 'port_scan', 'site_scan']
|
||||
'flows': ['subdomain_discovery', 'port_scan', 'site_scan', 'fingerprint_detect']
|
||||
},
|
||||
{
|
||||
'mode': 'parallel',
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# 引擎配置
|
||||
#
|
||||
# 参数命名:统一用中划线(如 rate-limit),系统自动转换为下划线
|
||||
# 必需参数:enabled(是否启用)、timeout(超时秒数,auto 表示自动计算)
|
||||
# 必需参数:enabled(是否启用)
|
||||
# 可选参数:timeout(超时秒数,默认 auto 自动计算)
|
||||
|
||||
# ==================== 子域名发现 ====================
|
||||
#
|
||||
@@ -39,7 +40,7 @@ subdomain_discovery:
|
||||
bruteforce:
|
||||
enabled: false
|
||||
subdomain_bruteforce:
|
||||
timeout: auto # 自动根据字典行数计算
|
||||
# timeout: auto # 自动根据字典行数计算
|
||||
wordlist-name: subdomains-top1million-110000.txt # 对应「字典管理」中的 Wordlist.name
|
||||
|
||||
# === Stage 3: 变异生成 + 存活验证(可选)===
|
||||
@@ -52,14 +53,14 @@ subdomain_discovery:
|
||||
resolve:
|
||||
enabled: true
|
||||
subdomain_resolve:
|
||||
timeout: auto # 自动根据候选子域数量计算
|
||||
timeout: auto # 自动根据候选子域数量计算
|
||||
|
||||
# ==================== 端口扫描 ====================
|
||||
port_scan:
|
||||
tools:
|
||||
naabu_active:
|
||||
enabled: true
|
||||
timeout: auto # 自动计算(目标数 × 端口数 × 0.5秒),范围 60秒 ~ 2天
|
||||
# timeout: auto # 自动计算(目标数 × 端口数 × 0.5秒),范围 60秒 ~ 2天
|
||||
threads: 200 # 并发连接数(默认 5)
|
||||
# ports: 1-65535 # 扫描端口范围(默认 1-65535)
|
||||
top-ports: 100 # 扫描 nmap top 100 端口
|
||||
@@ -67,25 +68,33 @@ port_scan:
|
||||
|
||||
naabu_passive:
|
||||
enabled: true
|
||||
timeout: auto # 被动扫描通常较快
|
||||
# timeout: auto # 被动扫描通常较快
|
||||
|
||||
# ==================== 站点扫描 ====================
|
||||
site_scan:
|
||||
tools:
|
||||
httpx:
|
||||
enabled: true
|
||||
timeout: auto # 自动计算(每个 URL 约 1 秒)
|
||||
# timeout: auto # 自动计算(每个 URL 约 1 秒)
|
||||
# threads: 50 # 并发线程数(默认 50)
|
||||
# rate-limit: 150 # 每秒请求数(默认 150)
|
||||
# request-timeout: 10 # 单个请求超时秒数(默认 10)
|
||||
# retries: 2 # 请求失败重试次数
|
||||
|
||||
# ==================== 指纹识别 ====================
|
||||
# 在 site_scan 后串行执行,识别 WebSite 的技术栈
|
||||
fingerprint_detect:
|
||||
tools:
|
||||
xingfinger:
|
||||
enabled: true
|
||||
fingerprint-libs: [ehole, goby, wappalyzer, fingers, fingerprinthub, arl] # 全部指纹库
|
||||
|
||||
# ==================== 目录扫描 ====================
|
||||
directory_scan:
|
||||
tools:
|
||||
ffuf:
|
||||
enabled: true
|
||||
timeout: auto # 自动计算(字典行数 × 0.02秒),范围 60秒 ~ 2小时
|
||||
# timeout: auto # 自动计算(字典行数 × 0.02秒),范围 60秒 ~ 2小时
|
||||
max-workers: 5 # 并发扫描站点数(默认 5)
|
||||
wordlist-name: dir_default.txt # 对应「字典管理」中的 Wordlist.name
|
||||
delay: 0.1-2.0 # 请求间隔,支持范围随机(如 "0.1-2.0")
|
||||
@@ -103,7 +112,7 @@ url_fetch:
|
||||
|
||||
katana:
|
||||
enabled: true
|
||||
timeout: auto # 自动计算(根据站点数量)
|
||||
# timeout: auto # 自动计算(根据站点数量)
|
||||
depth: 5 # 爬取最大深度(默认 3)
|
||||
threads: 10 # 全局并发数
|
||||
rate-limit: 30 # 每秒最多请求数
|
||||
@@ -113,7 +122,7 @@ url_fetch:
|
||||
|
||||
uro:
|
||||
enabled: true
|
||||
timeout: auto # 自动计算(每 100 个 URL 约 1 秒),范围 30 ~ 300 秒
|
||||
# timeout: auto # 自动计算(每 100 个 URL 约 1 秒),范围 30 ~ 300 秒
|
||||
# whitelist: # 只保留指定扩展名
|
||||
# - php
|
||||
# - asp
|
||||
@@ -127,7 +136,7 @@ url_fetch:
|
||||
|
||||
httpx:
|
||||
enabled: true
|
||||
timeout: auto # 自动计算(每个 URL 约 1 秒)
|
||||
# timeout: auto # 自动计算(每个 URL 约 1 秒)
|
||||
# threads: 50 # 并发线程数(默认 50)
|
||||
# rate-limit: 150 # 每秒请求数(默认 150)
|
||||
# request-timeout: 10 # 单个请求超时秒数(默认 10)
|
||||
@@ -138,7 +147,7 @@ vuln_scan:
|
||||
tools:
|
||||
dalfox_xss:
|
||||
enabled: true
|
||||
timeout: auto # 自动计算(endpoints 行数 × 100 秒)
|
||||
# timeout: auto # 自动计算(endpoints 行数 × 100 秒)
|
||||
request-timeout: 10 # 单个请求超时秒数
|
||||
only-poc: r # 只输出 POC 结果(r: 反射型)
|
||||
ignore-return: "302,404,403" # 忽略的返回码
|
||||
@@ -149,7 +158,7 @@ vuln_scan:
|
||||
|
||||
nuclei:
|
||||
enabled: true
|
||||
timeout: auto # 自动计算(根据 endpoints 行数)
|
||||
# timeout: auto # 自动计算(根据 endpoints 行数)
|
||||
template-repo-names: # 模板仓库列表,对应「Nuclei 模板」中的仓库名
|
||||
- nuclei-templates
|
||||
# - nuclei-custom # 可追加自定义仓库
|
||||
|
||||
@@ -5,8 +5,10 @@
|
||||
|
||||
from .initiate_scan_flow import initiate_scan_flow
|
||||
from .subdomain_discovery_flow import subdomain_discovery_flow
|
||||
from .fingerprint_detect_flow import fingerprint_detect_flow
|
||||
|
||||
__all__ = [
|
||||
'initiate_scan_flow',
|
||||
'subdomain_discovery_flow',
|
||||
'fingerprint_detect_flow',
|
||||
]
|
||||
|
||||
@@ -140,28 +140,7 @@ def _get_max_workers(tool_config: dict, default: int = DEFAULT_MAX_WORKERS) -> i
|
||||
return default
|
||||
|
||||
|
||||
def _setup_directory_scan_directory(scan_workspace_dir: str) -> Path:
|
||||
"""
|
||||
创建并验证目录扫描工作目录
|
||||
|
||||
Args:
|
||||
scan_workspace_dir: 扫描工作空间目录
|
||||
|
||||
Returns:
|
||||
Path: 目录扫描目录路径
|
||||
|
||||
Raises:
|
||||
RuntimeError: 目录创建或验证失败
|
||||
"""
|
||||
directory_scan_dir = Path(scan_workspace_dir) / 'directory_scan'
|
||||
directory_scan_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if not directory_scan_dir.is_dir():
|
||||
raise RuntimeError(f"目录扫描目录创建失败: {directory_scan_dir}")
|
||||
if not os.access(directory_scan_dir, os.W_OK):
|
||||
raise RuntimeError(f"目录扫描目录不可写: {directory_scan_dir}")
|
||||
|
||||
return directory_scan_dir
|
||||
|
||||
|
||||
|
||||
def _export_site_urls(target_id: int, target_name: str, directory_scan_dir: Path) -> tuple[str, int]:
|
||||
@@ -185,8 +164,7 @@ def _export_site_urls(target_id: int, target_name: str, directory_scan_dir: Path
|
||||
export_result = export_sites_task(
|
||||
target_id=target_id,
|
||||
output_file=sites_file,
|
||||
batch_size=1000, # 每次读取 1000 条,优化内存占用
|
||||
target_name=target_name # 传入 target_name 用于懒加载
|
||||
batch_size=1000 # 每次读取 1000 条,优化内存占用
|
||||
)
|
||||
|
||||
site_count = export_result['total_count']
|
||||
@@ -483,13 +461,23 @@ def _run_scans_concurrently(
|
||||
logger.warning("没有有效的扫描任务")
|
||||
continue
|
||||
|
||||
# 使用 ThreadPoolTaskRunner 并发执行
|
||||
logger.info("开始并发提交 %d 个扫描任务...", len(scan_params_list))
|
||||
# ============================================================
|
||||
# 分批执行策略:控制实际并发的 ffuf 进程数
|
||||
# ============================================================
|
||||
total_tasks = len(scan_params_list)
|
||||
logger.info("开始分批执行 %d 个扫描任务(每批 %d 个)...", total_tasks, max_workers)
|
||||
|
||||
with ThreadPoolTaskRunner(max_workers=max_workers) as task_runner:
|
||||
# 提交所有任务
|
||||
batch_num = 0
|
||||
for batch_start in range(0, total_tasks, max_workers):
|
||||
batch_end = min(batch_start + max_workers, total_tasks)
|
||||
batch_params = scan_params_list[batch_start:batch_end]
|
||||
batch_num += 1
|
||||
|
||||
logger.info("执行第 %d 批任务(%d-%d/%d)...", batch_num, batch_start + 1, batch_end, total_tasks)
|
||||
|
||||
# 提交当前批次的任务(非阻塞,立即返回 future)
|
||||
futures = []
|
||||
for params in scan_params_list:
|
||||
for params in batch_params:
|
||||
future = run_and_stream_save_directories_task.submit(
|
||||
cmd=params['command'],
|
||||
tool_name=tool_name,
|
||||
@@ -504,12 +492,10 @@ def _run_scans_concurrently(
|
||||
)
|
||||
futures.append((params['idx'], params['site_url'], future))
|
||||
|
||||
logger.info("✓ 已提交 %d 个扫描任务,等待完成...", len(futures))
|
||||
|
||||
# 等待所有任务完成并聚合结果
|
||||
# 等待当前批次所有任务完成(阻塞,确保本批完成后再启动下一批)
|
||||
for idx, site_url, future in futures:
|
||||
try:
|
||||
result = future.result()
|
||||
result = future.result() # 阻塞等待单个任务完成
|
||||
directories_found = result.get('created_directories', 0)
|
||||
total_directories += directories_found
|
||||
processed_sites_count += 1
|
||||
@@ -521,7 +507,6 @@ def _run_scans_concurrently(
|
||||
|
||||
except Exception as exc:
|
||||
failed_sites.append(site_url)
|
||||
# 判断是否为超时异常
|
||||
if 'timeout' in str(exc).lower() or isinstance(exc, subprocess.TimeoutExpired):
|
||||
logger.warning(
|
||||
"⚠️ [%d/%d] 站点扫描超时: %s - 错误: %s",
|
||||
@@ -633,7 +618,8 @@ def directory_scan_flow(
|
||||
raise ValueError("enabled_tools 不能为空")
|
||||
|
||||
# Step 0: 创建工作目录
|
||||
directory_scan_dir = _setup_directory_scan_directory(scan_workspace_dir)
|
||||
from apps.scan.utils import setup_scan_directory
|
||||
directory_scan_dir = setup_scan_directory(scan_workspace_dir, 'directory_scan')
|
||||
|
||||
# Step 1: 导出站点 URL(支持懒加载)
|
||||
sites_file, site_count = _export_site_urls(target_id, target_name, directory_scan_dir)
|
||||
|
||||
376
backend/apps/scan/flows/fingerprint_detect_flow.py
Normal file
376
backend/apps/scan/flows/fingerprint_detect_flow.py
Normal file
@@ -0,0 +1,376 @@
|
||||
"""
|
||||
指纹识别 Flow
|
||||
|
||||
负责编排指纹识别的完整流程
|
||||
|
||||
架构:
|
||||
- Flow 负责编排多个原子 Task
|
||||
- 在 site_scan 后串行执行
|
||||
- 使用 xingfinger 工具识别技术栈
|
||||
- 流式处理输出,批量更新数据库
|
||||
"""
|
||||
|
||||
# Django 环境初始化(导入即生效)
|
||||
from apps.common.prefect_django_setup import setup_django_for_prefect
|
||||
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from prefect import flow
|
||||
|
||||
from apps.scan.handlers.scan_flow_handlers import (
|
||||
on_scan_flow_running,
|
||||
on_scan_flow_completed,
|
||||
on_scan_flow_failed,
|
||||
)
|
||||
from apps.scan.tasks.fingerprint_detect import (
|
||||
export_urls_for_fingerprint_task,
|
||||
run_xingfinger_and_stream_update_tech_task,
|
||||
)
|
||||
from apps.scan.utils import build_scan_command
|
||||
from apps.scan.utils.fingerprint_helpers import get_fingerprint_paths
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def calculate_fingerprint_detect_timeout(
    url_count: int,
    base_per_url: float = 5.0,
    min_timeout: int = 300
) -> int:
    """Derive the fingerprint-detection timeout from the number of URLs.

    The timeout is ``url_count * base_per_url`` truncated to an integer,
    raised to ``min_timeout`` when it falls below that floor. No upper
    bound is applied.

    Args:
        url_count: Number of URLs to be scanned.
        base_per_url: Seconds budgeted per URL (default 5 seconds).
        min_timeout: Lower bound for the result in seconds (default 300).

    Returns:
        int: Timeout in seconds.
    """
    scaled = int(url_count * base_per_url)
    # Only exceed the floor when the scaled value is strictly larger.
    return scaled if scaled > min_timeout else min_timeout
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def _export_urls(
    target_id: int,
    fingerprint_dir: Path,
    source: str = 'website'
) -> tuple[str, int]:
    """Export the target's URLs into ``urls.txt`` inside the work directory.

    Args:
        target_id: Target ID.
        fingerprint_dir: Fingerprint-detection working directory.
        source: Data source type.

    Returns:
        tuple: ``(urls_file, total_count)`` as reported by the export task.
    """
    logger.info("Step 1: 导出 URL 列表 (source=%s)", source)

    export_result = export_urls_for_fingerprint_task(
        target_id=target_id,
        output_file=str(fingerprint_dir / 'urls.txt'),
        source=source,
        batch_size=1000
    )

    total_count = export_result['total_count']
    exported_file = export_result['output_file']

    logger.info("✓ URL 导出完成 - 文件: %s, 数量: %d", exported_file, total_count)

    return exported_file, total_count
|
||||
|
||||
|
||||
def _run_fingerprint_detect(
    enabled_tools: dict,
    urls_file: str,
    url_count: int,
    fingerprint_dir: Path,
    scan_id: int,
    target_id: int,
    source: str
) -> tuple[dict, list]:
    """Run every enabled fingerprint-detection tool against the URL file.

    For each tool the fingerprint library paths are resolved, the command is
    built, and the streaming task is executed. A tool that has no usable
    library, whose command cannot be built, or whose task raises is recorded
    as failed; the remaining tools still run.

    Args:
        enabled_tools: Mapping of tool name to its configuration dict.
        urls_file: Path of the file containing the URLs.
        url_count: Total number of URLs.
        fingerprint_dir: Fingerprint-detection working directory.
        scan_id: Scan task ID.
        target_id: Target ID.
        source: Data source type.

    Returns:
        tuple: ``(tool_stats, failed_tools)`` where ``tool_stats`` maps tool
        name to its run details and ``failed_tools`` is a list of
        ``{'tool': ..., 'reason': ...}`` dicts.
    """
    tool_stats = {}
    failed_tools = []

    for tool_name, tool_config in enabled_tools.items():
        # 1. Resolve the fingerprint library paths
        lib_names = tool_config.get('fingerprint_libs', ['ehole'])
        fingerprint_paths = get_fingerprint_paths(lib_names)

        if not fingerprint_paths:
            reason = f"没有可用的指纹库: {lib_names}"
            logger.warning(reason)
            failed_tools.append({'tool': tool_name, 'reason': reason})
            continue

        # 2. Merge the library paths into the tool config (for command building)
        merged_config = dict(tool_config)
        merged_config.update(fingerprint_paths)

        # 3. Build the command
        try:
            command = build_scan_command(
                tool_name=tool_name,
                scan_type='fingerprint_detect',
                command_params={'urls_file': urls_file},
                tool_config=merged_config
            )
        except Exception as build_error:
            reason = f"命令构建失败: {str(build_error)}"
            logger.error("构建 %s 命令失败: %s", tool_name, build_error)
            failed_tools.append({'tool': tool_name, 'reason': reason})
            continue

        # 4. Compute the timeout
        timeout = calculate_fingerprint_detect_timeout(url_count)

        # 5. Derive the log file path
        run_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_file = fingerprint_dir / f"{tool_name}_{run_stamp}.log"

        logger.info(
            "开始执行 %s 指纹识别 - URL数: %d, 超时: %ds, 指纹库: %s",
            tool_name, url_count, timeout, list(fingerprint_paths.keys())
        )

        # 6. Execute the scan task; bookkeeping stays inside the try so any
        # error while recording the result also marks the tool as failed.
        try:
            result = run_xingfinger_and_stream_update_tech_task(
                cmd=command,
                tool_name=tool_name,
                scan_id=scan_id,
                target_id=target_id,
                source=source,
                cwd=str(fingerprint_dir),
                timeout=timeout,
                log_file=str(log_file),
                batch_size=100
            )

            tool_stats[tool_name] = {
                'command': command,
                'result': result,
                'timeout': timeout,
                'fingerprint_libs': list(fingerprint_paths.keys())
            }

            logger.info(
                "✓ 工具 %s 执行完成 - 处理记录: %d, 更新: %d, 未找到: %d",
                tool_name,
                result.get('processed_records', 0),
                result.get('updated_count', 0),
                result.get('not_found_count', 0)
            )
        except Exception as run_error:
            failed_tools.append({'tool': tool_name, 'reason': str(run_error)})
            logger.error("工具 %s 执行失败: %s", tool_name, run_error, exc_info=True)

    if failed_tools:
        failed_names = ', '.join(entry['tool'] for entry in failed_tools)
        logger.warning("以下指纹识别工具执行失败: %s", failed_names)

    return tool_stats, failed_tools
|
||||
|
||||
|
||||
@flow(
    name="fingerprint_detect",
    log_prints=True,
    on_running=[on_scan_flow_running],
    on_completion=[on_scan_flow_completed],
    on_failure=[on_scan_flow_failed],
)
def fingerprint_detect_flow(
    scan_id: int,
    target_name: str,
    target_id: int,
    scan_workspace_dir: str,
    enabled_tools: dict
) -> dict:
    """
    Fingerprint detection flow.

    Workflow:
        Step 0: create the ``fingerprint_detect`` working directory
        Step 1: export the URL list for the target
        Step 2: log the enabled tools
        Step 3: run every enabled tool and aggregate the per-tool results

    Args:
        scan_id: Scan task ID.
        target_name: Target name.
        target_id: Target ID.
        scan_workspace_dir: Scan workspace directory.
        enabled_tools: Mapping of enabled tool name to its configuration.

    Returns:
        dict: {
            'success': bool,
            'scan_id': int,
            'target': str,
            'scan_workspace_dir': str,
            'urls_file': str,
            'url_count': int,
            'processed_records': int,
            'updated_count': int,
            'created_count': int,
            'not_found_count': int,
            'executed_tasks': list,
            'tool_stats': dict
        }

    Raises:
        ValueError: A required argument is missing or empty.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    try:
        logger.info(
            "="*60 + "\n" +
            "开始指纹识别\n" +
            f" Scan ID: {scan_id}\n" +
            f" Target: {target_name}\n" +
            f" Workspace: {scan_workspace_dir}\n" +
            "="*60
        )

        # Argument validation
        if scan_id is None:
            raise ValueError("scan_id 不能为空")
        if not target_name:
            raise ValueError("target_name 不能为空")
        if target_id is None:
            raise ValueError("target_id 不能为空")
        if not scan_workspace_dir:
            raise ValueError("scan_workspace_dir 不能为空")

        # Data source type (only 'website' is supported for now)
        source = 'website'

        # Step 0: create the working directory
        from apps.scan.utils import setup_scan_directory
        fingerprint_dir = setup_scan_directory(scan_workspace_dir, 'fingerprint_detect')

        # Step 1: export URLs
        urls_file, url_count = _export_urls(target_id, fingerprint_dir, source)

        if url_count == 0:
            logger.warning("目标下没有可用的 URL,跳过指纹识别")
            return {
                'success': True,
                'scan_id': scan_id,
                'target': target_name,
                'scan_workspace_dir': scan_workspace_dir,
                'urls_file': urls_file,
                'url_count': 0,
                'processed_records': 0,
                'updated_count': 0,
                'created_count': 0,
                'not_found_count': 0,
                'executed_tasks': ['export_urls_for_fingerprint'],
                'tool_stats': {
                    'total': 0,
                    'successful': 0,
                    'failed': 0,
                    'successful_tools': [],
                    'failed_tools': [],
                    'details': {}
                }
            }

        # Step 2: tool configuration
        logger.info("Step 2: 工具配置信息")
        logger.info("✓ 启用工具: %s", ', '.join(enabled_tools.keys()))

        # Step 3: run fingerprint detection
        logger.info("Step 3: 执行指纹识别")
        tool_stats, failed_tools = _run_fingerprint_detect(
            enabled_tools=enabled_tools,
            urls_file=urls_file,
            url_count=url_count,
            fingerprint_dir=fingerprint_dir,
            scan_id=scan_id,
            target_id=target_id,
            source=source
        )

        logger.info("="*60 + "\n✓ 指纹识别完成\n" + "="*60)

        # Build the list of executed tasks from the tools that produced stats
        executed_tasks = ['export_urls_for_fingerprint']
        executed_tasks.extend(f'run_xingfinger ({tool})' for tool in tool_stats)

        # Aggregate the per-tool counters; a counter missing from a result counts as 0
        def _total(counter: str) -> int:
            return sum(stats['result'].get(counter, 0) for stats in tool_stats.values())

        total_processed = _total('processed_records')
        total_updated = _total('updated_count')
        total_created = _total('created_count')
        total_not_found = _total('not_found_count')

        # Build the failed-name set once instead of rebuilding a list per tool
        failed_names = {failure['tool'] for failure in failed_tools}
        successful_tools = [name for name in enabled_tools if name not in failed_names]

        return {
            'success': True,
            'scan_id': scan_id,
            'target': target_name,
            'scan_workspace_dir': scan_workspace_dir,
            'urls_file': urls_file,
            'url_count': url_count,
            'processed_records': total_processed,
            'updated_count': total_updated,
            'created_count': total_created,
            'not_found_count': total_not_found,
            'executed_tasks': executed_tasks,
            'tool_stats': {
                'total': len(enabled_tools),
                'successful': len(successful_tools),
                'failed': len(failed_tools),
                'successful_tools': successful_tools,
                'failed_tools': failed_tools,
                'details': tool_stats
            }
        }

    except ValueError as e:
        logger.error("配置错误: %s", e)
        raise
    except RuntimeError as e:
        logger.error("运行时错误: %s", e)
        raise
    except Exception as e:
        logger.exception("指纹识别失败: %s", e)
        raise
|
||||
@@ -30,7 +30,7 @@ from apps.scan.handlers import (
|
||||
on_initiate_scan_flow_failed,
|
||||
)
|
||||
from prefect.futures import wait
|
||||
from apps.scan.tasks.workspace_tasks import create_scan_workspace_task
|
||||
from apps.scan.utils import setup_scan_workspace
|
||||
from apps.scan.orchestrators import FlowOrchestrator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -110,7 +110,7 @@ def initiate_scan_flow(
|
||||
)
|
||||
|
||||
# ==================== Task 1: 创建 Scan 工作空间 ====================
|
||||
scan_workspace_path = create_scan_workspace_task(scan_workspace_dir)
|
||||
scan_workspace_path = setup_scan_workspace(scan_workspace_dir)
|
||||
|
||||
# ==================== Task 2: 获取引擎配置 ====================
|
||||
from apps.scan.models import Scan
|
||||
|
||||
@@ -154,28 +154,7 @@ def _parse_port_count(tool_config: dict) -> int:
|
||||
return 100
|
||||
|
||||
|
||||
def _setup_port_scan_directory(scan_workspace_dir: str) -> Path:
    """
    Create the ``port_scan`` working directory and verify it is usable.

    Args:
        scan_workspace_dir: Scan workspace directory.

    Returns:
        Path: Path of the port scan directory.

    Raises:
        RuntimeError: The directory is missing after creation or is not writable.
    """
    workdir = Path(scan_workspace_dir).joinpath('port_scan')
    workdir.mkdir(parents=True, exist_ok=True)

    exists_as_dir = workdir.is_dir()
    if not exists_as_dir:
        raise RuntimeError(f"端口扫描目录创建失败: {workdir}")

    writable = os.access(workdir, os.W_OK)
    if not writable:
        raise RuntimeError(f"端口扫描目录不可写: {workdir}")

    return workdir
|
||||
|
||||
|
||||
|
||||
def _export_scan_targets(target_id: int, port_scan_dir: Path) -> tuple[str, int, str]:
|
||||
@@ -442,7 +421,8 @@ def port_scan_flow(
|
||||
)
|
||||
|
||||
# Step 0: 创建工作目录
|
||||
port_scan_dir = _setup_port_scan_directory(scan_workspace_dir)
|
||||
from apps.scan.utils import setup_scan_directory
|
||||
port_scan_dir = setup_scan_directory(scan_workspace_dir, 'port_scan')
|
||||
|
||||
# Step 1: 导出扫描目标列表到文件(根据 Target 类型自动决定内容)
|
||||
targets_file, target_count, target_type = _export_scan_targets(target_id, port_scan_dir)
|
||||
|
||||
@@ -85,28 +85,7 @@ def calculate_timeout_by_line_count(
|
||||
return min_timeout
|
||||
|
||||
|
||||
def _setup_site_scan_directory(scan_workspace_dir: str) -> Path:
    """
    Create the ``site_scan`` working directory and verify it is usable.

    Args:
        scan_workspace_dir: Scan workspace directory.

    Returns:
        Path: Path of the site scan directory.

    Raises:
        RuntimeError: The directory is missing after creation or is not writable.
    """
    workdir = Path(scan_workspace_dir).joinpath('site_scan')
    workdir.mkdir(parents=True, exist_ok=True)

    exists_as_dir = workdir.is_dir()
    if not exists_as_dir:
        raise RuntimeError(f"站点扫描目录创建失败: {workdir}")

    writable = os.access(workdir, os.W_OK)
    if not writable:
        raise RuntimeError(f"站点扫描目录不可写: {workdir}")

    return workdir
|
||||
|
||||
|
||||
|
||||
def _export_site_urls(target_id: int, site_scan_dir: Path, target_name: str = None) -> tuple[str, int, int]:
|
||||
@@ -130,7 +109,6 @@ def _export_site_urls(target_id: int, site_scan_dir: Path, target_name: str = No
|
||||
export_result = export_site_urls_task(
|
||||
target_id=target_id,
|
||||
output_file=urls_file,
|
||||
target_name=target_name,
|
||||
batch_size=1000 # 每次处理1000个子域名
|
||||
)
|
||||
|
||||
@@ -403,7 +381,8 @@ def site_scan_flow(
|
||||
raise ValueError("scan_workspace_dir 不能为空")
|
||||
|
||||
# Step 0: 创建工作目录
|
||||
site_scan_dir = _setup_site_scan_directory(scan_workspace_dir)
|
||||
from apps.scan.utils import setup_scan_directory
|
||||
site_scan_dir = setup_scan_directory(scan_workspace_dir, 'site_scan')
|
||||
|
||||
# Step 1: 导出站点 URL
|
||||
urls_file, total_urls, association_count = _export_site_urls(
|
||||
|
||||
@@ -41,28 +41,7 @@ import subprocess
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _setup_subdomain_directory(scan_workspace_dir: str) -> Path:
    """
    Create the ``subdomain_discovery`` working directory and verify it is usable.

    Args:
        scan_workspace_dir: Scan workspace directory.

    Returns:
        Path: Path of the subdomain discovery directory.

    Raises:
        RuntimeError: The directory is missing after creation or is not writable.
    """
    workdir = Path(scan_workspace_dir).joinpath('subdomain_discovery')
    workdir.mkdir(parents=True, exist_ok=True)

    exists_as_dir = workdir.is_dir()
    if not exists_as_dir:
        raise RuntimeError(f"子域名扫描目录创建失败: {workdir}")

    writable = os.access(workdir, os.W_OK)
    if not writable:
        raise RuntimeError(f"子域名扫描目录不可写: {workdir}")

    return workdir
|
||||
|
||||
|
||||
|
||||
def _validate_and_normalize_target(target_name: str) -> str:
|
||||
@@ -119,12 +98,7 @@ def _run_scans_parallel(
|
||||
|
||||
# 生成时间戳(所有工具共用)
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
|
||||
# TODO: 接入代理池管理系统
|
||||
# from apps.proxy.services import proxy_pool
|
||||
# proxy_stats = proxy_pool.get_stats()
|
||||
# logger.info(f"代理池状态: {proxy_stats['healthy']}/{proxy_stats['total']} 可用")
|
||||
|
||||
|
||||
failures = [] # 记录命令构建失败的工具
|
||||
futures = {}
|
||||
|
||||
@@ -417,7 +391,8 @@ def subdomain_discovery_flow(
|
||||
)
|
||||
|
||||
# Step 0: 准备工作
|
||||
result_dir = _setup_subdomain_directory(scan_workspace_dir)
|
||||
from apps.scan.utils import setup_scan_directory
|
||||
result_dir = setup_scan_directory(scan_workspace_dir, 'subdomain_discovery')
|
||||
|
||||
# 验证并规范化目标域名
|
||||
try:
|
||||
|
||||
@@ -42,17 +42,7 @@ SITES_FILE_TOOLS = {'katana'}
|
||||
POST_PROCESS_TOOLS = {'uro', 'httpx'}
|
||||
|
||||
|
||||
def _setup_url_fetch_directory(scan_workspace_dir: str) -> Path:
    """Create the ``url_fetch`` working directory and verify it is a writable directory."""
    workdir = Path(scan_workspace_dir).joinpath('url_fetch')
    workdir.mkdir(parents=True, exist_ok=True)

    exists_as_dir = workdir.is_dir()
    if not exists_as_dir:
        raise RuntimeError(f"URL 获取目录创建失败: {workdir}")

    writable = os.access(workdir, os.W_OK)
    if not writable:
        raise RuntimeError(f"URL 获取目录不可写: {workdir}")

    return workdir
|
||||
|
||||
|
||||
|
||||
def _classify_tools(enabled_tools: dict) -> tuple[dict, dict, dict, dict]:
|
||||
@@ -304,7 +294,8 @@ def url_fetch_flow(
|
||||
|
||||
# Step 1: 准备工作目录
|
||||
logger.info("Step 1: 准备工作目录")
|
||||
url_fetch_dir = _setup_url_fetch_directory(scan_workspace_dir)
|
||||
from apps.scan.utils import setup_scan_directory
|
||||
url_fetch_dir = setup_scan_directory(scan_workspace_dir, 'url_fetch')
|
||||
|
||||
# Step 2: 分类工具(按输入类型)
|
||||
logger.info("Step 2: 分类工具")
|
||||
|
||||
@@ -40,8 +40,7 @@ def _export_sites_file(target_id: int, scan_id: int, target_name: str, output_di
|
||||
result = export_sites_task(
|
||||
output_file=output_file,
|
||||
target_id=target_id,
|
||||
scan_id=scan_id,
|
||||
target_name=target_name
|
||||
scan_id=scan_id
|
||||
)
|
||||
|
||||
count = result['asset_count']
|
||||
|
||||
@@ -25,10 +25,7 @@ from .utils import calculate_timeout_by_line_count
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _setup_vuln_scan_directory(scan_workspace_dir: str) -> Path:
    """
    Create the ``vuln_scan`` working directory and verify it is usable.

    Args:
        scan_workspace_dir: Scan workspace directory.

    Returns:
        Path: Path of the vulnerability scan directory.

    Raises:
        RuntimeError: The directory is missing after creation or is not writable.
    """
    # Local import: this module's top-level imports are not guaranteed to include os.
    import os

    vuln_scan_dir = Path(scan_workspace_dir) / "vuln_scan"
    vuln_scan_dir.mkdir(parents=True, exist_ok=True)

    # Same post-creation checks as the other _setup_*_directory helpers, so an
    # unusable directory fails here with a clear message instead of later on write.
    if not vuln_scan_dir.is_dir():
        raise RuntimeError(f"漏洞扫描目录创建失败: {vuln_scan_dir}")
    if not os.access(vuln_scan_dir, os.W_OK):
        raise RuntimeError(f"漏洞扫描目录不可写: {vuln_scan_dir}")

    return vuln_scan_dir
|
||||
|
||||
|
||||
|
||||
@flow(
|
||||
@@ -55,14 +52,14 @@ def endpoints_vuln_scan_flow(
|
||||
if not enabled_tools:
|
||||
raise ValueError("enabled_tools 不能为空")
|
||||
|
||||
vuln_scan_dir = _setup_vuln_scan_directory(scan_workspace_dir)
|
||||
from apps.scan.utils import setup_scan_directory
|
||||
vuln_scan_dir = setup_scan_directory(scan_workspace_dir, 'vuln_scan')
|
||||
endpoints_file = vuln_scan_dir / "input_endpoints.txt"
|
||||
|
||||
# Step 1: 导出 Endpoint URL
|
||||
export_result = export_endpoints_task(
|
||||
target_id=target_id,
|
||||
output_file=str(endpoints_file),
|
||||
target_name=target_name, # 传入 target_name 用于生成默认端点
|
||||
)
|
||||
total_endpoints = export_result.get("total_count", 0)
|
||||
|
||||
|
||||
@@ -5,12 +5,13 @@ WebSocket Consumer - 通知实时推送
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
from channels.generic.websocket import AsyncWebsocketConsumer
|
||||
|
||||
from apps.common.websocket_auth import AuthenticatedWebsocketConsumer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NotificationConsumer(AsyncWebsocketConsumer):
|
||||
class NotificationConsumer(AuthenticatedWebsocketConsumer):
|
||||
"""
|
||||
通知 WebSocket Consumer
|
||||
|
||||
@@ -23,9 +24,9 @@ class NotificationConsumer(AsyncWebsocketConsumer):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.heartbeat_task = None # 心跳任务
|
||||
|
||||
async def connect(self):
|
||||
async def on_connect(self):
|
||||
"""
|
||||
客户端连接时调用
|
||||
客户端连接时调用(已通过认证)
|
||||
加入通知广播组
|
||||
"""
|
||||
# 通知组名(所有客户端共享)
|
||||
|
||||
@@ -87,8 +87,8 @@ def on_all_workers_high_load(sender, worker_name, cpu, mem, **kwargs):
|
||||
"""所有 Worker 高负载时的通知处理"""
|
||||
create_notification(
|
||||
title="系统负载较高",
|
||||
message=f"所有节点负载较高,已选择负载最低的节点 {worker_name}(CPU: {cpu:.1f}%, 内存: {mem:.1f}%)执行任务,扫描速度可能受影响",
|
||||
message=f"所有节点负载较高(最低负载节点 CPU: {cpu:.1f}%, 内存: {mem:.1f}%),系统将等待最多 10 分钟后分发任务,扫描速度可能受影响",
|
||||
level=NotificationLevel.MEDIUM,
|
||||
category=NotificationCategory.SYSTEM
|
||||
)
|
||||
logger.warning("高负载通知已发送 - worker=%s, cpu=%.1f%%, mem=%.1f%%", worker_name, cpu, mem)
|
||||
logger.warning("高负载通知已发送 - cpu=%.1f%%, mem=%.1f%%", cpu, mem)
|
||||
|
||||
@@ -305,6 +305,7 @@ def _push_via_api_callback(notification: Notification, server_url: str) -> None:
|
||||
通过 HTTP 请求 Server 容器的 /api/callbacks/notification/ 接口。
|
||||
Worker 无法直接访问 Redis,需要由 Server 代为推送 WebSocket。
|
||||
"""
|
||||
import os
|
||||
import requests
|
||||
|
||||
try:
|
||||
@@ -318,8 +319,14 @@ def _push_via_api_callback(notification: Notification, server_url: str) -> None:
|
||||
'created_at': notification.created_at.isoformat()
|
||||
}
|
||||
|
||||
# 构建请求头(包含 Worker API Key)
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
worker_api_key = os.environ.get("WORKER_API_KEY", "")
|
||||
if worker_api_key:
|
||||
headers["X-Worker-API-Key"] = worker_api_key
|
||||
|
||||
# verify=False: 远程 Worker 回调 Server 时可能使用自签名证书
|
||||
resp = requests.post(callback_url, json=data, timeout=5, verify=False)
|
||||
resp = requests.post(callback_url, json=data, headers=headers, timeout=5, verify=False)
|
||||
resp.raise_for_status()
|
||||
|
||||
logger.debug(f"通知回调推送成功 - ID: {notification.id}")
|
||||
|
||||
@@ -14,6 +14,8 @@ from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from apps.common.pagination import BasePagination
|
||||
from apps.common.response_helpers import success_response, error_response
|
||||
from apps.common.error_codes import ErrorCodes
|
||||
from .models import Notification
|
||||
from .serializers import NotificationSerializer
|
||||
from .types import NotificationLevel
|
||||
@@ -60,34 +62,7 @@ def notifications_test(request):
|
||||
}, status=500)
|
||||
|
||||
|
||||
def build_api_response(
    data: Any = None,
    *,
    message: str = '操作成功',
    code: str = '200',
    state: str = 'success',
    status_code: int = status.HTTP_200_OK
) -> Response:
    """Build the unified API response envelope.

    Args:
        data: Response payload (optional); the ``data`` key is omitted when None.
        message: Response message.
        code: Response code.
        state: Response state (success/error).
        status_code: HTTP status code.

    Returns:
        DRF Response object.
    """
    envelope = dict(code=code, state=state, message=message)
    if data is None:
        return Response(envelope, status=status_code)
    envelope['data'] = data
    return Response(envelope, status=status_code)
|
||||
# build_api_response 已废弃,请使用 success_response/error_response
|
||||
|
||||
|
||||
def _parse_bool(value: str | None) -> bool | None:
|
||||
@@ -172,7 +147,7 @@ class NotificationUnreadCountView(APIView):
|
||||
"""获取未读通知数量"""
|
||||
service = NotificationService()
|
||||
count = service.get_unread_count()
|
||||
return build_api_response({'count': count}, message='获取未读数量成功')
|
||||
return success_response(data={'count': count})
|
||||
|
||||
|
||||
class NotificationMarkAllAsReadView(APIView):
|
||||
@@ -192,7 +167,7 @@ class NotificationMarkAllAsReadView(APIView):
|
||||
"""标记全部通知为已读"""
|
||||
service = NotificationService()
|
||||
updated = service.mark_all_as_read()
|
||||
return build_api_response({'updated': updated}, message='全部标记已读成功')
|
||||
return success_response(data={'updated': updated})
|
||||
|
||||
|
||||
class NotificationSettingsView(APIView):
|
||||
@@ -209,13 +184,13 @@ class NotificationSettingsView(APIView):
|
||||
"""获取通知设置"""
|
||||
service = NotificationSettingsService()
|
||||
settings = service.get_settings()
|
||||
return Response(settings)
|
||||
return success_response(data=settings)
|
||||
|
||||
def put(self, request: Request) -> Response:
|
||||
"""更新通知设置"""
|
||||
service = NotificationSettingsService()
|
||||
settings = service.update_settings(request.data)
|
||||
return Response({'message': '已保存通知设置', **settings})
|
||||
return success_response(data=settings)
|
||||
|
||||
|
||||
# ============================================
|
||||
@@ -247,22 +222,24 @@ def notification_callback(request):
|
||||
required_fields = ['id', 'category', 'title', 'message', 'level', 'created_at']
|
||||
for field in required_fields:
|
||||
if field not in data:
|
||||
return Response(
|
||||
{'error': f'缺少字段: {field}'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message=f'Missing field: {field}',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# 推送到 WebSocket
|
||||
_push_notification_to_websocket(data)
|
||||
|
||||
logger.debug(f"回调通知推送成功 - ID: {data['id']}, Title: {data['title']}")
|
||||
return Response({'status': 'ok'})
|
||||
return success_response(data={'status': 'ok'})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"回调通知处理失败: {e}", exc_info=True)
|
||||
return Response(
|
||||
{'error': str(e)},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
return error_response(
|
||||
code=ErrorCodes.SERVER_ERROR,
|
||||
message=str(e),
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -206,6 +206,10 @@ class FlowOrchestrator:
|
||||
from apps.scan.flows.site_scan_flow import site_scan_flow
|
||||
return site_scan_flow
|
||||
|
||||
elif scan_type == 'fingerprint_detect':
|
||||
from apps.scan.flows.fingerprint_detect_flow import fingerprint_detect_flow
|
||||
return fingerprint_detect_flow
|
||||
|
||||
elif scan_type == 'directory_scan':
|
||||
from apps.scan.flows.directory_scan_flow import directory_scan_flow
|
||||
return directory_scan_flow
|
||||
|
||||
@@ -83,7 +83,7 @@ def cleanup_results(results_dir: str, retention_days: int) -> dict:
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="清理任务")
|
||||
parser.add_argument("--results_dir", type=str, default="/app/backend/results", help="扫描结果目录")
|
||||
parser.add_argument("--results_dir", type=str, default="/opt/xingrin/results", help="扫描结果目录")
|
||||
parser.add_argument("--retention_days", type=int, default=7, help="保留天数")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -17,6 +17,8 @@ from .scan_state_service import ScanStateService
|
||||
from .scan_control_service import ScanControlService
|
||||
from .scan_stats_service import ScanStatsService
|
||||
from .scheduled_scan_service import ScheduledScanService
|
||||
from .blacklist_service import BlacklistService
|
||||
from .target_export_service import TargetExportService
|
||||
|
||||
__all__ = [
|
||||
'ScanService', # 主入口(向后兼容)
|
||||
@@ -25,5 +27,7 @@ __all__ = [
|
||||
'ScanControlService',
|
||||
'ScanStatsService',
|
||||
'ScheduledScanService',
|
||||
'BlacklistService', # 黑名单过滤服务
|
||||
'TargetExportService', # 目标导出服务
|
||||
]
|
||||
|
||||
|
||||
82
backend/apps/scan/services/blacklist_service.py
Normal file
82
backend/apps/scan/services/blacklist_service.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""
|
||||
黑名单过滤服务
|
||||
|
||||
过滤敏感域名(如 .gov、.edu、.mil 等)
|
||||
|
||||
当前版本使用默认规则,后续将支持从前端配置加载。
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from django.db.models import QuerySet
|
||||
import re
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BlacklistService:
    """
    Blacklist filter service - filters out sensitive domains.

    TODO: load blacklist rules from front-end configuration in a later version
    - users configure blacklisted URLs, domains and IPs when starting a scan
    - rules are stored in the database and associated with a Scan or Engine
    """

    # Default blacklist regex rules
    DEFAULT_PATTERNS = [
        r'\.gov$',             # ends with .gov
        r'\.gov\.[a-z]{2}$',   # .gov.cn, .gov.uk, etc.
    ]

    def __init__(self, patterns: Optional[List[str]] = None):
        """
        Initialise the blacklist service.

        Args:
            patterns: List of regular expressions. ``None`` selects the default
                rules; an explicit empty list means "no rules" (nothing is filtered).
        """
        # Compare against None so an explicit [] is not silently replaced by the defaults.
        self.patterns = self.DEFAULT_PATTERNS if patterns is None else patterns
        self._compiled_patterns = [re.compile(p) for p in self.patterns]

    def filter_queryset(
        self,
        queryset: "QuerySet",
        url_field: str = 'url'
    ) -> "QuerySet":
        """
        Filter a queryset at the database level.

        Excludes rows whose ``url_field`` matches any pattern via the ORM
        ``__regex`` lookup.

        NOTE(review): the patterns are applied to the raw field value, so
        ``$``-anchored host rules will not match URLs that carry a port or
        path - confirm whether callers rely on this method for full URLs.

        Args:
            queryset: Original queryset.
            url_field: Name of the URL field.

        Returns:
            QuerySet: The filtered queryset.
        """
        for pattern in self.patterns:
            queryset = queryset.exclude(**{f'{url_field}__regex': pattern})
        return queryset

    @staticmethod
    def _extract_host(value: str) -> Optional[str]:
        """Return the lower-cased hostname of a URL or bare host, or None if absent."""
        from urllib.parse import urlsplit

        # urlsplit only fills in the network location when a '//' authority is present.
        candidate = value if '://' in value else f'//{value}'
        try:
            return urlsplit(candidate).hostname
        except ValueError:
            # Malformed authority (e.g. unbalanced IPv6 brackets)
            return None

    def filter_url(self, url: str) -> bool:
        """
        Check whether a single URL (or bare host) passes the blacklist.

        Each pattern is tested against the raw value and, when it differs,
        against the extracted hostname. The hostname check is what lets
        ``$``-anchored rules such as ``\\.gov$`` block ``https://x.gov/path``
        or ``https://x.gov:8443``.

        Args:
            url: URL or host to check.

        Returns:
            bool: True if it passes (not blacklisted), False if it is filtered.
        """
        host = self._extract_host(url)
        candidates = (url, host) if host and host != url else (url,)
        return not any(
            pattern.search(candidate)
            for pattern in self._compiled_patterns
            for candidate in candidates
        )

    # TODO: implement in a later version - a ``from_scan(scan_id)`` classmethod
    # that loads the blacklist rules configured for a scan from the database.
|
||||
364
backend/apps/scan/services/target_export_service.py
Normal file
364
backend/apps/scan/services/target_export_service.py
Normal file
@@ -0,0 +1,364 @@
|
||||
"""
|
||||
目标导出服务
|
||||
|
||||
提供统一的目标提取和文件导出功能,支持:
|
||||
- URL 导出(流式写入 + 默认值回退)
|
||||
- 域名/IP 导出(用于端口扫描)
|
||||
- 黑名单过滤集成
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, Iterator
|
||||
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from .blacklist_service import BlacklistService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TargetExportService:
    """
    Target export service - unified target extraction and file export.

    Usage:
        # The task layer decides the data source
        queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)

        # Use the export service
        blacklist_service = BlacklistService()
        export_service = TargetExportService(blacklist_service=blacklist_service)
        result = export_service.export_urls(target_id, output_path, queryset)
    """

    def __init__(self, blacklist_service: Optional["BlacklistService"] = None):
        """
        Initialise the export service.

        Args:
            blacklist_service: Blacklist filter service; None disables filtering.
        """
        self.blacklist_service = blacklist_service

    def export_urls(
        self,
        target_id: int,
        output_path: str,
        queryset: "QuerySet",
        url_field: str = 'url',
        batch_size: int = 1000
    ) -> Dict[str, Any]:
        """
        Unified URL export.

        Streams the queryset to ``output_path``; if nothing was written, falls
        back to default URLs generated from the target itself.

        Args:
            target_id: Target ID.
            output_path: Output file path.
            queryset: Data source built by the task layer; iterated as plain
                URL strings (i.e. ``values_list(..., flat=True)``).
            url_field: URL field name. Kept for interface compatibility; not
                used here because filtering happens per URL in Python.
            batch_size: Chunk size passed to ``queryset.iterator``.

        Returns:
            dict: {
                'success': bool,
                'output_file': str,
                'total_count': int
            }

        Raises:
            IOError: Writing the file failed.
        """
        output_file = Path(output_path)
        output_file.parent.mkdir(parents=True, exist_ok=True)

        logger.info("开始导出 URL - target_id=%s, output=%s", target_id, output_path)

        total_count = 0
        try:
            with open(output_file, 'w', encoding='utf-8', buffering=8192) as f:
                for url in queryset.iterator(chunk_size=batch_size):
                    # Skip empty rows and anything rejected by the blacklist
                    if not url or not self._should_write_url(url):
                        continue
                    f.write(f"{url}\n")
                    total_count += 1

                    if total_count % 10000 == 0:
                        logger.info("已导出 %d 个 URL...", total_count)
        except IOError as e:
            logger.error("文件写入失败: %s - %s", output_path, e)
            raise

        # Fallback: nothing exported, generate default URLs from the target
        if total_count == 0:
            total_count = self._generate_default_urls(target_id, output_file)

        logger.info("✓ URL 导出完成 - 数量: %d, 文件: %s", total_count, output_path)

        return {
            'success': True,
            'output_file': str(output_file),
            'total_count': total_count
        }

    def _write_http_pair(self, f, host) -> int:
        """Write the http:// and https:// URLs for ``host`` that pass the blacklist; return how many were written."""
        written = 0
        for url in (f"http://{host}", f"https://{host}"):
            if self._should_write_url(url):
                f.write(f"{url}\n")
                written += 1
        return written

    def _generate_default_urls(
        self,
        target_id: int,
        output_path: Path
    ) -> int:
        """
        Default URL generator (internal).

        Generates default URLs according to the target type:
        - DOMAIN: http(s)://domain
        - IP: http(s)://ip
        - CIDR: http(s)://ip for every host address in the network
        - URL: the target URL itself

        Args:
            target_id: Target ID.
            output_path: Output file path (overwritten).

        Returns:
            int: Number of URLs written; 0 if the target does not exist.

        Raises:
            ValueError: The CIDR target name cannot be parsed.
        """
        from apps.targets.services import TargetService
        from apps.targets.models import Target

        target = TargetService().get_target(target_id)
        if not target:
            logger.warning("Target ID %d 不存在,无法生成默认 URL", target_id)
            return 0

        target_name = target.name
        target_type = target.type

        logger.info("懒加载模式:Target 类型=%s, 名称=%s", target_type, target_name)

        total_urls = 0

        with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
            if target_type in (Target.TargetType.DOMAIN, Target.TargetType.IP):
                total_urls += self._write_http_pair(f, target_name)

            elif target_type == Target.TargetType.CIDR:
                try:
                    network = ipaddress.ip_network(target_name, strict=False)
                except ValueError as e:
                    logger.error("CIDR 解析失败: %s - %s", target_name, e)
                    raise ValueError(f"无效的 CIDR: {target_name}") from e

                for ip in network.hosts():
                    written = self._write_http_pair(f, ip)
                    total_urls += written
                    if written and total_urls % 10000 == 0:
                        logger.info("已生成 %d 个 URL...", total_urls)

                # Single-address networks (/32, /128) may yield no hosts:
                # fall back to the network address itself.
                if total_urls == 0:
                    total_urls += self._write_http_pair(f, str(network.network_address))

            elif target_type == Target.TargetType.URL:
                if self._should_write_url(target_name):
                    f.write(f"{target_name}\n")
                    total_urls = 1
            else:
                logger.warning("不支持的 Target 类型: %s", target_type)

        logger.info("✓ 懒加载生成默认 URL - 数量: %d", total_urls)
        return total_urls

    def _should_write_url(self, url: str) -> bool:
        """Return True if ``url`` passes the blacklist (or no blacklist is configured)."""
        if self.blacklist_service:
            return self.blacklist_service.filter_url(url)
        return True

    def export_targets(
        self,
        target_id: int,
        output_path: str,
        batch_size: int = 1000
    ) -> Dict[str, Any]:
        """
        Domain/IP export (used by port scanning).

        Chooses the export logic by target type:
        - DOMAIN: stream subdomain names, falling back to the root domain
        - IP: write the IP address
        - CIDR: expand to every host IP

        Args:
            target_id: Target ID.
            output_path: Output file path.
            batch_size: Chunk size for streaming subdomains.

        Returns:
            dict: {
                'success': bool,
                'output_file': str,
                'total_count': int,
                'target_type': str
            }

        Raises:
            ValueError: The target does not exist or its type is unsupported.
        """
        from apps.targets.services import TargetService
        from apps.targets.models import Target

        output_file = Path(output_path)
        output_file.parent.mkdir(parents=True, exist_ok=True)

        # Load the target
        target = TargetService().get_target(target_id)
        if not target:
            raise ValueError(f"Target ID {target_id} 不存在")

        target_type = target.type
        target_name = target.name

        logger.info(
            "开始导出扫描目标 - Target ID: %d, Name: %s, Type: %s, 输出文件: %s",
            target_id, target_name, target_type, output_path
        )

        if target_type == Target.TargetType.DOMAIN:
            total_count = self._export_domains(target_id, target_name, output_file, batch_size)
            type_desc = "域名"
        elif target_type == Target.TargetType.IP:
            total_count = self._export_ip(target_name, output_file)
            type_desc = "IP"
        elif target_type == Target.TargetType.CIDR:
            total_count = self._export_cidr(target_name, output_file)
            type_desc = "CIDR IP"
        else:
            raise ValueError(f"不支持的目标类型: {target_type}")

        logger.info(
            "✓ 扫描目标导出完成 - 类型: %s, 总数: %d, 文件: %s",
            type_desc, total_count, output_path
        )

        return {
            'success': True,
            'output_file': str(output_file),
            'total_count': total_count,
            'target_type': target_type
        }

    def _export_domains(
        self,
        target_id: int,
        target_name: str,
        output_path: Path,
        batch_size: int
    ) -> int:
        """Export the subdomains of a DOMAIN target; fall back to the root domain when there are none."""
        from apps.asset.services.asset.subdomain_service import SubdomainService

        domain_iterator = SubdomainService().iter_subdomain_names_by_target(
            target_id=target_id,
            chunk_size=batch_size
        )

        total_count = 0
        with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
            for domain_name in domain_iterator:
                if not self._should_write_target(domain_name):
                    continue
                f.write(f"{domain_name}\n")
                total_count += 1

                if total_count % 10000 == 0:
                    logger.info("已导出 %d 个域名...", total_count)

        # Default mode: no subdomain was written, use the root domain
        if total_count == 0:
            logger.info("采用默认域名:%s (target_id=%d)", target_name, target_id)
            if self._should_write_target(target_name):
                with open(output_path, 'w', encoding='utf-8') as f:
                    f.write(f"{target_name}\n")
                total_count = 1

        return total_count

    def _export_ip(self, target_name: str, output_path: Path) -> int:
        """Export an IP target; the output file is always created, even when the IP is filtered."""
        # Open unconditionally so the returned path exists and never holds stale
        # content from an earlier run when the target is blacklisted.
        with open(output_path, 'w', encoding='utf-8') as f:
            if not self._should_write_target(target_name):
                return 0
            f.write(f"{target_name}\n")
        return 1

    def _export_cidr(self, target_name: str, output_path: Path) -> int:
        """Export a CIDR target, expanded to one host IP per line."""
        network = ipaddress.ip_network(target_name, strict=False)
        total_count = 0

        with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
            for ip in network.hosts():
                ip_str = str(ip)
                if not self._should_write_target(ip_str):
                    continue
                f.write(f"{ip_str}\n")
                total_count += 1

                if total_count % 10000 == 0:
                    logger.info("已导出 %d 个 IP...", total_count)

        # Single-address networks (/32, /128) may yield no hosts:
        # fall back to the network address itself.
        if total_count == 0:
            ip_str = str(network.network_address)
            if self._should_write_target(ip_str):
                with open(output_path, 'w', encoding='utf-8') as f:
                    f.write(f"{ip_str}\n")
                total_count = 1

        return total_count

    def _should_write_target(self, target: str) -> bool:
        """Return True if ``target`` passes the blacklist (or no blacklist is configured)."""
        return self._should_write_url(target)
|
||||
@@ -9,9 +9,6 @@
|
||||
- Tasks 负责具体操作,Flow 负责编排
|
||||
"""
|
||||
|
||||
# Prefect Tasks
|
||||
from .workspace_tasks import create_scan_workspace_task
|
||||
|
||||
# 子域名发现任务(已重构为多个子任务)
|
||||
from .subdomain_discovery import (
|
||||
run_subdomain_discovery_task,
|
||||
@@ -19,17 +16,25 @@ from .subdomain_discovery import (
|
||||
save_domains_task,
|
||||
)
|
||||
|
||||
# 指纹识别任务
|
||||
from .fingerprint_detect import (
|
||||
export_urls_for_fingerprint_task,
|
||||
run_xingfinger_and_stream_update_tech_task,
|
||||
)
|
||||
|
||||
# 注意:
|
||||
# - subdomain_discovery_task 已重构为多个子任务(subdomain_discovery/)
|
||||
# - finalize_scan_task 已废弃(Handler 统一管理状态)
|
||||
# - initiate_scan_task 已迁移到 flows/initiate_scan_flow.py
|
||||
# - cleanup_old_scans_task 已迁移到 flows(cleanup_old_scans_flow)
|
||||
# - create_scan_workspace_task 已删除,直接使用 setup_scan_workspace()
|
||||
|
||||
__all__ = [
|
||||
# Prefect Tasks
|
||||
'create_scan_workspace_task',
|
||||
# 子域名发现任务
|
||||
'run_subdomain_discovery_task',
|
||||
'merge_and_validate_task',
|
||||
'save_domains_task',
|
||||
# 指纹识别任务
|
||||
'export_urls_for_fingerprint_task',
|
||||
'run_xingfinger_and_stream_update_tech_task',
|
||||
]
|
||||
|
||||
@@ -1,20 +1,14 @@
|
||||
"""
|
||||
导出站点 URL 到 TXT 文件的 Task
|
||||
|
||||
使用流式处理,避免大量站点导致内存溢出
|
||||
支持默认值模式:如果没有站点,根据 Target 类型生成默认 URL
|
||||
- DOMAIN: http(s)://target_name
|
||||
- IP: http(s)://ip
|
||||
- CIDR: 展开为所有 IP 的 http(s)://ip
|
||||
使用 TargetExportService 统一处理导出逻辑和默认值回退
|
||||
数据源: WebSite.url
|
||||
"""
|
||||
import logging
|
||||
import ipaddress
|
||||
from pathlib import Path
|
||||
from prefect import task
|
||||
|
||||
from apps.asset.repositories import DjangoWebSiteRepository
|
||||
from apps.targets.services import TargetService
|
||||
from apps.targets.models import Target
|
||||
from apps.asset.models import WebSite
|
||||
from apps.scan.services import TargetExportService, BlacklistService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -24,19 +18,22 @@ def export_sites_task(
|
||||
target_id: int,
|
||||
output_file: str,
|
||||
batch_size: int = 1000,
|
||||
target_name: str = None
|
||||
) -> dict:
|
||||
"""
|
||||
导出目标下的所有站点 URL 到 TXT 文件
|
||||
|
||||
使用流式处理,支持大规模数据导出(10万+站点)
|
||||
支持默认值模式:如果没有站点,自动使用默认站点 URL(http(s)://target_name)
|
||||
数据源: WebSite.url
|
||||
|
||||
懒加载模式:
|
||||
- 如果数据库为空,根据 Target 类型生成默认 URL
|
||||
- DOMAIN: http(s)://domain
|
||||
- IP: http(s)://ip
|
||||
- CIDR: 展开为所有 IP 的 URL
|
||||
|
||||
Args:
|
||||
target_id: 目标 ID
|
||||
output_file: 输出文件路径(绝对路径)
|
||||
batch_size: 每次读取的批次大小,默认 1000
|
||||
target_name: 目标名称(用于默认值模式)
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
@@ -49,134 +46,26 @@ def export_sites_task(
|
||||
ValueError: 参数错误
|
||||
IOError: 文件写入失败
|
||||
"""
|
||||
try:
|
||||
# 初始化 Repository
|
||||
repository = DjangoWebSiteRepository()
|
||||
|
||||
logger.info("开始导出站点 URL - Target ID: %d, 输出文件: %s", target_id, output_file)
|
||||
|
||||
# 确保输出目录存在
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 使用 Repository 流式查询站点 URL
|
||||
url_iterator = repository.get_urls_for_export(
|
||||
target_id=target_id,
|
||||
batch_size=batch_size
|
||||
)
|
||||
|
||||
# 流式写入文件
|
||||
total_count = 0
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
for url in url_iterator:
|
||||
# 每次只处理一个 URL,边读边写
|
||||
f.write(f"{url}\n")
|
||||
total_count += 1
|
||||
|
||||
# 每写入 10000 条记录打印一次进度
|
||||
if total_count % 10000 == 0:
|
||||
logger.info("已导出 %d 个站点 URL...", total_count)
|
||||
|
||||
# ==================== 懒加载模式:根据 Target 类型生成默认 URL ====================
|
||||
if total_count == 0:
|
||||
total_count = _write_default_urls(target_id, target_name, output_path)
|
||||
|
||||
logger.info(
|
||||
"✓ 站点 URL 导出完成 - 总数: %d, 文件: %s (%.2f KB)",
|
||||
total_count,
|
||||
str(output_path), # 使用绝对路径
|
||||
output_path.stat().st_size / 1024
|
||||
)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'output_file': str(output_path),
|
||||
'total_count': total_count
|
||||
}
|
||||
|
||||
except FileNotFoundError as e:
|
||||
logger.error("输出目录不存在: %s", e)
|
||||
raise
|
||||
except PermissionError as e:
|
||||
logger.error("文件写入权限不足: %s", e)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("导出站点 URL 失败: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
def _write_default_urls(target_id: int, target_name: str, output_path: Path) -> int:
|
||||
"""
|
||||
懒加载模式:根据 Target 类型生成默认 URL
|
||||
# 构建数据源 queryset(Task 层决定数据源)
|
||||
queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)
|
||||
|
||||
Args:
|
||||
target_id: 目标 ID
|
||||
target_name: 目标名称(可选,如果为空则从数据库查询)
|
||||
output_path: 输出文件路径
|
||||
|
||||
Returns:
|
||||
int: 生成的 URL 数量
|
||||
"""
|
||||
# 获取 Target 信息
|
||||
target_service = TargetService()
|
||||
target = target_service.get_target(target_id)
|
||||
# 使用 TargetExportService 处理导出
|
||||
blacklist_service = BlacklistService()
|
||||
export_service = TargetExportService(blacklist_service=blacklist_service)
|
||||
|
||||
if not target:
|
||||
logger.warning("Target ID %d 不存在,无法生成默认 URL", target_id)
|
||||
return 0
|
||||
result = export_service.export_urls(
|
||||
target_id=target_id,
|
||||
output_path=output_file,
|
||||
queryset=queryset,
|
||||
batch_size=batch_size
|
||||
)
|
||||
|
||||
target_name = target.name
|
||||
target_type = target.type
|
||||
|
||||
logger.info("懒加载模式:Target 类型=%s, 名称=%s", target_type, target_name)
|
||||
|
||||
total_urls = 0
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
if target_type == Target.TargetType.DOMAIN:
|
||||
# 域名类型:生成 http(s)://domain
|
||||
f.write(f"http://{target_name}\n")
|
||||
f.write(f"https://{target_name}\n")
|
||||
total_urls = 2
|
||||
logger.info("✓ 域名默认 URL 已写入: http(s)://%s", target_name)
|
||||
|
||||
elif target_type == Target.TargetType.IP:
|
||||
# IP 类型:生成 http(s)://ip
|
||||
f.write(f"http://{target_name}\n")
|
||||
f.write(f"https://{target_name}\n")
|
||||
total_urls = 2
|
||||
logger.info("✓ IP 默认 URL 已写入: http(s)://%s", target_name)
|
||||
|
||||
elif target_type == Target.TargetType.CIDR:
|
||||
# CIDR 类型:展开为所有 IP 的 URL
|
||||
try:
|
||||
network = ipaddress.ip_network(target_name, strict=False)
|
||||
|
||||
for ip in network.hosts(): # 排除网络地址和广播地址
|
||||
f.write(f"http://{ip}\n")
|
||||
f.write(f"https://{ip}\n")
|
||||
total_urls += 2
|
||||
|
||||
if total_urls % 10000 == 0:
|
||||
logger.info("已生成 %d 个 URL...", total_urls)
|
||||
|
||||
# 如果是 /32 或 /128(单个 IP),hosts() 会为空
|
||||
if total_urls == 0:
|
||||
ip = str(network.network_address)
|
||||
f.write(f"http://{ip}\n")
|
||||
f.write(f"https://{ip}\n")
|
||||
total_urls = 2
|
||||
|
||||
logger.info("✓ CIDR 默认 URL 已写入: %d 个 URL (来自 %s)", total_urls, target_name)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error("CIDR 解析失败: %s - %s", target_name, e)
|
||||
return 0
|
||||
else:
|
||||
logger.warning("不支持的 Target 类型: %s", target_type)
|
||||
return 0
|
||||
|
||||
return total_urls
|
||||
# 保持返回值格式不变(向后兼容)
|
||||
return {
|
||||
'success': result['success'],
|
||||
'output_file': result['output_file'],
|
||||
'total_count': result['total_count']
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
15
backend/apps/scan/tasks/fingerprint_detect/__init__.py
Normal file
15
backend/apps/scan/tasks/fingerprint_detect/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
指纹识别任务模块
|
||||
|
||||
包含:
|
||||
- export_urls_for_fingerprint_task: 导出 URL 到文件
|
||||
- run_xingfinger_and_stream_update_tech_task: 流式执行 xingfinger 并更新 tech
|
||||
"""
|
||||
|
||||
from .export_urls_task import export_urls_for_fingerprint_task
|
||||
from .run_xingfinger_task import run_xingfinger_and_stream_update_tech_task
|
||||
|
||||
__all__ = [
|
||||
'export_urls_for_fingerprint_task',
|
||||
'run_xingfinger_and_stream_update_tech_task',
|
||||
]
|
||||
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
导出 URL 任务
|
||||
|
||||
用于指纹识别前导出目标下的 URL 到文件
|
||||
使用 TargetExportService 统一处理导出逻辑和默认值回退
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from prefect import task
|
||||
|
||||
from apps.asset.models import WebSite
|
||||
from apps.scan.services import TargetExportService, BlacklistService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@task(name="export_urls_for_fingerprint")
def export_urls_for_fingerprint_task(
    target_id: int,
    output_file: str,
    source: str = 'website',
    batch_size: int = 1000
) -> dict:
    """Export a target's URLs to a file ahead of fingerprint detection.

    The task only chooses the data source (``WebSite.url`` rows of the
    target) and delegates the writing to
    ``TargetExportService.export_urls``, which is constructed with a
    ``BlacklistService``.

    NOTE(review): the original notes say the export service falls back to
    default URLs derived from the target type (DOMAIN / IP / CIDR / URL)
    when the table is empty. That logic is not visible in this task;
    confirm it against the service.

    Args:
        target_id: ID of the target.
        output_file: Path of the file to write.
        source: Data-source label. It is not used to select data here and
            is only echoed back in the result (kept for older callers).
        batch_size: Batch size forwarded to the export service.

    Returns:
        dict: ``{'output_file': str, 'total_count': int, 'source': str}``
    """
    # The task layer decides the data source; the service does the export.
    website_urls = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)

    exporter = TargetExportService(blacklist_service=BlacklistService())
    export_result = exporter.export_urls(
        target_id=target_id,
        output_path=output_file,
        queryset=website_urls,
        batch_size=batch_size
    )

    # Return shape is kept stable for existing callers.
    return {
        'output_file': export_result['output_file'],
        'total_count': export_result['total_count'],
        'source': source
    }
|
||||
@@ -0,0 +1,301 @@
|
||||
"""
|
||||
xingfinger 执行任务
|
||||
|
||||
流式执行 xingfinger 命令并实时更新 tech 字段
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
from typing import Optional, Generator
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from django.db import connection
|
||||
from prefect import task
|
||||
|
||||
from apps.scan.utils import execute_stream
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# 数据源映射:source → (module_path, model_name, url_field)
|
||||
SOURCE_MODEL_MAP = {
|
||||
'website': ('apps.asset.models', 'WebSite', 'url'),
|
||||
# 以后扩展:
|
||||
# 'endpoint': ('apps.asset.models', 'Endpoint', 'url'),
|
||||
# 'directory': ('apps.asset.models', 'Directory', 'url'),
|
||||
}
|
||||
|
||||
|
||||
def _get_model_class(source: str):
|
||||
"""根据数据源类型获取 Model 类"""
|
||||
if source not in SOURCE_MODEL_MAP:
|
||||
raise ValueError(f"不支持的数据源: {source}")
|
||||
|
||||
module_path, model_name, _ = SOURCE_MODEL_MAP[source]
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, model_name)
|
||||
|
||||
|
||||
def parse_xingfinger_line(line: str) -> tuple[str, list[str]] | None:
|
||||
"""
|
||||
解析 xingfinger 单行 JSON 输出
|
||||
|
||||
xingfinger 静默模式输出格式:
|
||||
{"url": "https://example.com", "cms": "WordPress,PHP,nginx", ...}
|
||||
|
||||
Returns:
|
||||
tuple: (url, tech_list) 或 None(解析失败时)
|
||||
"""
|
||||
try:
|
||||
item = json.loads(line)
|
||||
url = item.get('url', '').strip()
|
||||
cms = item.get('cms', '')
|
||||
|
||||
if not url or not cms:
|
||||
return None
|
||||
|
||||
# cms 字段按逗号分割,去除空白
|
||||
techs = [t.strip() for t in cms.split(',') if t.strip()]
|
||||
|
||||
return (url, techs) if techs else None
|
||||
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
|
||||
def bulk_merge_tech_field(
    source: str,
    url_techs_map: dict[str, list[str]],
    target_id: int
) -> dict:
    """Merge detected technologies into the ``tech`` array column.

    For every URL in ``url_techs_map`` the new technologies are appended to
    the row's ``tech`` array and de-duplicated with PostgreSQL-native SQL.
    If no row matches ``(url, target_id)``, a new row is inserted through an
    upsert so a conflicting row still gets the merged array.

    Args:
        source: Data-source key resolved through ``_get_model_class``.
        url_techs_map: Mapping of URL to detected technology names.
            Entries with an empty list are skipped.
        target_id: ID of the target the rows belong to.

    Returns:
        dict: ``{'updated_count': int, 'created_count': int}``
    """
    Model = _get_model_class(source)
    table_name = Model._meta.db_table

    # Both statements depend only on the table name, so build them once.
    # First attempt: merge into an existing row (array concat + DISTINCT).
    merge_sql = f"""
        UPDATE {table_name}
        SET tech = (
            SELECT ARRAY(SELECT DISTINCT unnest(
                COALESCE(tech, ARRAY[]::varchar[]) || %s::varchar[]
            ))
        )
        WHERE url = %s AND target_id = %s
    """

    # Fallback: insert a new row, with conflict handling. Default values are
    # passed explicitly for the other columns listed in the statement.
    insert_sql = f"""
        INSERT INTO {table_name} (target_id, url, host, location, title, webserver, body_preview, content_type, tech, created_at)
        VALUES (%s, %s, %s, '', '', '', '', '', %s::varchar[], NOW())
        ON CONFLICT (target_id, url) DO UPDATE SET
            tech = (
                SELECT ARRAY(SELECT DISTINCT unnest(
                    COALESCE({table_name}.tech, ARRAY[]::varchar[]) || EXCLUDED.tech
                ))
            )
    """

    stats = {
        'updated_count': 0,
        'created_count': 0
    }

    with connection.cursor() as cursor:
        for url, techs in url_techs_map.items():
            if not techs:
                continue

            cursor.execute(merge_sql, [techs, url, target_id])
            if cursor.rowcount > 0:
                stats['updated_count'] += cursor.rowcount
                continue

            # No row matched: create one. A failure is logged and skipped.
            try:
                host = urlparse(url).hostname or ''
                cursor.execute(insert_sql, [target_id, url, host, techs])
                # NOTE(review): this is counted as a creation even if the
                # ON CONFLICT branch turned the insert into an update.
                stats['created_count'] += 1
            except Exception as e:
                logger.warning("创建 %s 记录失败 (url=%s): %s", source, url, e)

    return stats
|
||||
|
||||
|
||||
def _parse_xingfinger_stream_output(
    cmd: str,
    tool_name: str,
    cwd: Optional[str] = None,
    timeout: Optional[int] = None,
    log_file: Optional[str] = None
) -> Generator[tuple[str, list[str]], None, None]:
    """Run xingfinger through ``execute_stream`` and yield parsed records.

    Each output line is passed to ``parse_xingfinger_line``; lines for which
    it returns ``None`` are skipped, and the rest are yielded as
    ``(url, tech_list)`` tuples as soon as they arrive.

    Args:
        cmd: Command line to execute (run with ``shell=True``).
        tool_name: Tool name forwarded to ``execute_stream``.
        cwd: Working directory forwarded to ``execute_stream``.
        timeout: Timeout in seconds forwarded to ``execute_stream``.
        log_file: Log file path forwarded to ``execute_stream``.

    Yields:
        ``(url, tech_list)`` for every line that parses to a record.

    Raises:
        RuntimeError: If the command times out (wraps ``TimeoutExpired``).
        Exception: Any other failure is logged and re-raised unchanged.
    """
    logger.info("开始流式解析 xingfinger 命令输出 - 命令: %s", cmd)

    total_lines = 0
    valid_records = 0

    try:
        stream = execute_stream(cmd=cmd, tool_name=tool_name, cwd=cwd, shell=True, timeout=timeout, log_file=log_file)
        for total_lines, raw_line in enumerate(stream, start=1):
            record = parse_xingfinger_line(raw_line)
            if record is not None:
                valid_records += 1
                yield record

                # Report progress every 500 valid records.
                if valid_records % 500 == 0:
                    logger.info("已解析 %d 条有效记录...", valid_records)

    except subprocess.TimeoutExpired as e:
        error_msg = f"xingfinger 命令执行超时 - 超过 {timeout} 秒"
        logger.warning(error_msg)
        raise RuntimeError(error_msg) from e
    except Exception as e:
        logger.error("流式解析 xingfinger 输出失败: %s", e, exc_info=True)
        raise

    logger.info("流式解析完成 - 总行数: %d, 有效记录: %d", total_lines, valid_records)
|
||||
|
||||
|
||||
@task(name="run_xingfinger_and_stream_update_tech")
def run_xingfinger_and_stream_update_tech_task(
    cmd: str,
    tool_name: str,
    scan_id: int,
    target_id: int,
    source: str,
    cwd: str,
    timeout: int,
    log_file: str,
    batch_size: int = 100
) -> dict:
    """Stream xingfinger output and merge the results into ``tech``.

    Flow:
        1. Run the command and parse its output as a stream.
        2. Collect results per URL; repeated URLs are merged.
        3. Once ``batch_size`` distinct URLs are pending, flush them through
           ``bulk_merge_tech_field`` (which also creates missing rows).
        4. Flush whatever remains after the stream ends.

    Args:
        cmd: Command line to execute.
        tool_name: Tool name forwarded to the stream parser.
        scan_id: Scan ID (not used inside this function).
        target_id: ID of the target whose rows are updated.
        source: Data-source key selecting the table to update.
        cwd: Working directory for the command.
        timeout: Timeout in seconds.
        log_file: Log file path for the command.
        batch_size: Number of distinct URLs per database flush.

    Returns:
        dict: ``{'processed_records': int, 'updated_count': int,
        'created_count': int, 'batch_count': int}``

    Raises:
        RuntimeError: Wraps any failure during execution or update.
    """
    logger.info(
        "开始执行 xingfinger 并更新 tech - target_id=%s, source=%s, timeout=%s秒",
        target_id, source, timeout
    )

    records = None

    try:
        totals = {
            'processed_records': 0,
            'updated_count': 0,
            'created_count': 0,
            'batch_count': 0
        }

        # URL -> techs for the batch currently being collected.
        pending = {}

        records = _parse_xingfinger_stream_output(
            cmd=cmd,
            tool_name=tool_name,
            cwd=cwd,
            timeout=timeout,
            log_file=log_file
        )

        for url, techs in records:
            totals['processed_records'] += 1

            if url not in pending:
                pending[url] = techs
            else:
                # Same URL reported again: merge the detections.
                pending[url].extend(techs)

            if len(pending) < batch_size:
                continue

            # Batch is full: flush it to the database.
            totals['batch_count'] += 1
            outcome = bulk_merge_tech_field(source, pending, target_id)
            totals['updated_count'] += outcome['updated_count']
            totals['created_count'] += outcome.get('created_count', 0)

            logger.debug(
                "批次 %d 完成 - 更新: %d, 创建: %d",
                totals['batch_count'], outcome['updated_count'], outcome.get('created_count', 0)
            )

            pending = {}

        # Flush the final, partially filled batch.
        if pending:
            totals['batch_count'] += 1
            outcome = bulk_merge_tech_field(source, pending, target_id)
            totals['updated_count'] += outcome['updated_count']
            totals['created_count'] += outcome.get('created_count', 0)

        logger.info(
            "✓ xingfinger 执行完成 - 处理记录: %d, 更新: %d, 创建: %d, 批次: %d",
            totals['processed_records'], totals['updated_count'],
            totals['created_count'], totals['batch_count']
        )

        return totals

    except subprocess.TimeoutExpired:
        # NOTE(review): _parse_xingfinger_stream_output converts
        # TimeoutExpired into RuntimeError, so a stream timeout appears to
        # reach the generic handler below instead of this one - confirm.
        logger.warning("⚠️ xingfinger 执行超时 - target_id=%s, timeout=%s秒", target_id, timeout)
        raise
    except Exception as e:
        error_msg = f"xingfinger 执行失败: {e}"
        logger.error(error_msg, exc_info=True)
        raise RuntimeError(error_msg) from e
    finally:
        # Close the generator so the underlying stream is released.
        if records is not None:
            try:
                records.close()
            except Exception as e:
                logger.debug("关闭生成器时出错: %s", e)
|
||||
@@ -1,119 +1,21 @@
|
||||
"""
|
||||
导出扫描目标到 TXT 文件的 Task
|
||||
|
||||
使用 TargetExportService.export_targets() 统一处理导出逻辑
|
||||
|
||||
根据 Target 类型决定导出内容:
|
||||
- DOMAIN: 从 Subdomain 表导出子域名
|
||||
- IP: 直接写入 target.name
|
||||
- CIDR: 展开 CIDR 范围内的所有 IP
|
||||
|
||||
使用流式处理,避免大量数据导致内存溢出
|
||||
"""
|
||||
import logging
|
||||
import ipaddress
|
||||
from pathlib import Path
|
||||
from prefect import task
|
||||
|
||||
from apps.asset.services.asset.subdomain_service import SubdomainService
|
||||
from apps.targets.services import TargetService
|
||||
from apps.targets.models import Target # 仅用于 TargetType 常量
|
||||
from apps.scan.services import TargetExportService, BlacklistService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _export_domains(target_id: int, target_name: str, output_path: Path, batch_size: int) -> int:
    """Write the subdomain names of a domain target to ``output_path``.

    Names are streamed from
    ``SubdomainService.iter_subdomain_names_by_target`` and written one per
    line. When the target has no subdomains, the root domain is written as
    the default entry; this only touches the file.

    Args:
        target_id: ID of the target.
        target_name: Target name (the root domain).
        output_path: File to (over)write.
        batch_size: Chunk size handed to the subdomain iterator.

    Returns:
        Number of lines written.
    """
    names = SubdomainService().iter_subdomain_names_by_target(
        target_id=target_id,
        chunk_size=batch_size
    )

    exported = 0
    with open(output_path, 'w', encoding='utf-8', buffering=8192) as out:
        for exported, name in enumerate(names, start=1):
            out.write(f"{name}\n")
            if exported % 10000 == 0:
                logger.info("已导出 %d 个域名...", exported)

    if exported:
        return exported

    # Default mode: no subdomains, so the root domain is written instead.
    logger.info("采用默认域名:%s (target_id=%d)", target_name, target_id)
    with open(output_path, 'w', encoding='utf-8') as out:
        out.write(f"{target_name}\n")
    logger.info("✓ 默认域名已写入文件: %s", target_name)
    return 1
|
||||
|
||||
|
||||
def _export_ip(target_name: str, output_path: Path) -> int:
|
||||
"""
|
||||
导出 IP 类型目标
|
||||
|
||||
Args:
|
||||
target_name: IP 地址
|
||||
output_path: 输出文件路径
|
||||
|
||||
Returns:
|
||||
int: 导出的记录数(始终为 1)
|
||||
"""
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
f.write(f"{target_name}\n")
|
||||
return 1
|
||||
|
||||
|
||||
def _export_cidr(target_name: str, output_path: Path) -> int:
|
||||
"""
|
||||
导出 CIDR 类型目标,展开为每个 IP
|
||||
|
||||
Args:
|
||||
target_name: CIDR 范围(如 192.168.1.0/24)
|
||||
output_path: 输出文件路径
|
||||
|
||||
Returns:
|
||||
int: 导出的 IP 数量
|
||||
"""
|
||||
network = ipaddress.ip_network(target_name, strict=False)
|
||||
total_count = 0
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
for ip in network.hosts(): # 排除网络地址和广播地址
|
||||
f.write(f"{ip}\n")
|
||||
total_count += 1
|
||||
|
||||
if total_count % 10000 == 0:
|
||||
logger.info("已导出 %d 个 IP...", total_count)
|
||||
|
||||
# 如果是 /32 或 /128(单个 IP),hosts() 会为空,需要特殊处理
|
||||
if total_count == 0:
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
f.write(f"{network.network_address}\n")
|
||||
total_count = 1
|
||||
|
||||
return total_count
|
||||
|
||||
|
||||
@task(name="export_scan_targets")
|
||||
def export_scan_targets_task(
|
||||
target_id: int,
|
||||
@@ -145,62 +47,20 @@ def export_scan_targets_task(
|
||||
ValueError: Target 不存在
|
||||
IOError: 文件写入失败
|
||||
"""
|
||||
try:
|
||||
# 1. 通过 Service 层获取 Target
|
||||
target_service = TargetService()
|
||||
target = target_service.get_target(target_id)
|
||||
if not target:
|
||||
raise ValueError(f"Target ID {target_id} 不存在")
|
||||
|
||||
target_type = target.type
|
||||
target_name = target.name
|
||||
|
||||
logger.info(
|
||||
"开始导出扫描目标 - Target ID: %d, Name: %s, Type: %s, 输出文件: %s",
|
||||
target_id, target_name, target_type, output_file
|
||||
)
|
||||
|
||||
# 2. 确保输出目录存在
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 3. 根据类型导出
|
||||
if target_type == Target.TargetType.DOMAIN:
|
||||
total_count = _export_domains(target_id, target_name, output_path, batch_size)
|
||||
type_desc = "域名"
|
||||
elif target_type == Target.TargetType.IP:
|
||||
total_count = _export_ip(target_name, output_path)
|
||||
type_desc = "IP"
|
||||
elif target_type == Target.TargetType.CIDR:
|
||||
total_count = _export_cidr(target_name, output_path)
|
||||
type_desc = "CIDR IP"
|
||||
else:
|
||||
raise ValueError(f"不支持的目标类型: {target_type}")
|
||||
|
||||
logger.info(
|
||||
"✓ 扫描目标导出完成 - 类型: %s, 总数: %d, 文件: %s (%.2f KB)",
|
||||
type_desc,
|
||||
total_count,
|
||||
str(output_path),
|
||||
output_path.stat().st_size / 1024
|
||||
)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'output_file': str(output_path),
|
||||
'total_count': total_count,
|
||||
'target_type': target_type
|
||||
}
|
||||
|
||||
except FileNotFoundError as e:
|
||||
logger.error("输出目录不存在: %s", e)
|
||||
raise
|
||||
except PermissionError as e:
|
||||
logger.error("文件写入权限不足: %s", e)
|
||||
raise
|
||||
except ValueError as e:
|
||||
logger.error("参数错误: %s", e)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("导出扫描目标失败: %s", e)
|
||||
raise
|
||||
# 使用 TargetExportService 处理导出
|
||||
blacklist_service = BlacklistService()
|
||||
export_service = TargetExportService(blacklist_service=blacklist_service)
|
||||
|
||||
result = export_service.export_targets(
|
||||
target_id=target_id,
|
||||
output_path=output_file,
|
||||
batch_size=batch_size
|
||||
)
|
||||
|
||||
# 保持返回值格式不变(向后兼容)
|
||||
return {
|
||||
'success': result['success'],
|
||||
'output_file': result['output_file'],
|
||||
'total_count': result['total_count'],
|
||||
'target_type': result['target_type']
|
||||
}
|
||||
|
||||
@@ -2,52 +2,65 @@
|
||||
导出站点URL到文件的Task
|
||||
|
||||
直接使用 HostPortMapping 表查询 host+port 组合,拼接成URL格式写入文件
|
||||
使用 TargetExportService 处理默认值回退逻辑
|
||||
|
||||
默认值模式:
|
||||
- 如果没有 HostPortMapping 数据,写入默认 URL 到文件(不写入数据库)
|
||||
- DOMAIN: http(s)://target_name
|
||||
- IP: http(s)://ip
|
||||
- CIDR: 展开为所有 IP 的 http(s)://ip
|
||||
特殊逻辑:
|
||||
- 80 端口:只生成 HTTP URL(省略端口号)
|
||||
- 443 端口:只生成 HTTPS URL(省略端口号)
|
||||
- 其他端口:生成 HTTP 和 HTTPS 两个URL(带端口号)
|
||||
"""
|
||||
import logging
|
||||
import ipaddress
|
||||
from pathlib import Path
|
||||
from prefect import task
|
||||
from typing import Optional
|
||||
|
||||
from apps.asset.services import HostPortMappingService
|
||||
from apps.targets.services import TargetService
|
||||
from apps.targets.models import Target
|
||||
from apps.scan.services import TargetExportService, BlacklistService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _generate_urls_from_port(host: str, port: int) -> list[str]:
|
||||
"""
|
||||
根据端口生成 URL 列表
|
||||
|
||||
- 80 端口:只生成 HTTP URL(省略端口号)
|
||||
- 443 端口:只生成 HTTPS URL(省略端口号)
|
||||
- 其他端口:生成 HTTP 和 HTTPS 两个URL(带端口号)
|
||||
"""
|
||||
if port == 80:
|
||||
return [f"http://{host}"]
|
||||
elif port == 443:
|
||||
return [f"https://{host}"]
|
||||
else:
|
||||
return [f"http://{host}:{port}", f"https://{host}:{port}"]
|
||||
|
||||
|
||||
@task(name="export_site_urls")
|
||||
def export_site_urls_task(
|
||||
target_id: int,
|
||||
output_file: str,
|
||||
target_name: Optional[str] = None,
|
||||
batch_size: int = 1000
|
||||
) -> dict:
|
||||
"""
|
||||
导出目标下的所有站点URL到文件(基于 HostPortMapping 表)
|
||||
|
||||
功能:
|
||||
1. 从 HostPortMapping 表查询 target 下所有 host+port 组合
|
||||
2. 拼接成URL格式(标准端口80/443将省略端口号)
|
||||
3. 写入到指定文件中
|
||||
数据源: HostPortMapping (host + port)
|
||||
|
||||
默认值模式(懒加载):
|
||||
- 如果没有 HostPortMapping 数据,根据 Target 类型生成默认 URL
|
||||
- DOMAIN: http(s)://target_name
|
||||
特殊逻辑:
|
||||
- 80 端口:只生成 HTTP URL(省略端口号)
|
||||
- 443 端口:只生成 HTTPS URL(省略端口号)
|
||||
- 其他端口:生成 HTTP 和 HTTPS 两个URL(带端口号)
|
||||
|
||||
懒加载模式:
|
||||
- 如果数据库为空,根据 Target 类型生成默认 URL
|
||||
- DOMAIN: http(s)://domain
|
||||
- IP: http(s)://ip
|
||||
- CIDR: 展开为所有 IP 的 http(s)://ip
|
||||
- CIDR: 展开为所有 IP 的 URL
|
||||
|
||||
Args:
|
||||
target_id: 目标ID
|
||||
output_file: 输出文件路径(绝对路径)
|
||||
target_name: 目标名称(用于懒加载时写入默认值)
|
||||
batch_size: 每次处理的批次大小,默认1000(暂未使用,预留)
|
||||
batch_size: 每次处理的批次大小
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
@@ -61,155 +74,54 @@ def export_site_urls_task(
|
||||
ValueError: 参数错误
|
||||
IOError: 文件写入失败
|
||||
"""
|
||||
try:
|
||||
logger.info("开始统计站点URL - Target ID: %d, 输出文件: %s", target_id, output_file)
|
||||
|
||||
# 确保输出目录存在
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 直接查询 HostPortMapping 表,按 host 排序
|
||||
service = HostPortMappingService()
|
||||
associations = service.iter_host_port_by_target(
|
||||
target_id=target_id,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
total_urls = 0
|
||||
association_count = 0
|
||||
|
||||
# 流式写入文件
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
for assoc in associations:
|
||||
association_count += 1
|
||||
host = assoc['host']
|
||||
port = assoc['port']
|
||||
|
||||
# 根据端口号生成URL
|
||||
# 80 端口:只生成 HTTP URL(省略端口号)
|
||||
# 443 端口:只生成 HTTPS URL(省略端口号)
|
||||
# 其他端口:生成 HTTP 和 HTTPS 两个URL(带端口号)
|
||||
if port == 80:
|
||||
# HTTP 标准端口,省略端口号
|
||||
url = f"http://{host}"
|
||||
f.write(f"{url}\n")
|
||||
total_urls += 1
|
||||
elif port == 443:
|
||||
# HTTPS 标准端口,省略端口号
|
||||
url = f"https://{host}"
|
||||
f.write(f"{url}\n")
|
||||
total_urls += 1
|
||||
else:
|
||||
# 非标准端口,生成 HTTP 和 HTTPS 两个URL
|
||||
http_url = f"http://{host}:{port}"
|
||||
https_url = f"https://{host}:{port}"
|
||||
f.write(f"{http_url}\n")
|
||||
f.write(f"{https_url}\n")
|
||||
total_urls += 2
|
||||
|
||||
# 每处理1000条记录打印一次进度
|
||||
if association_count % 1000 == 0:
|
||||
logger.info("已处理 %d 条关联,生成 %d 个URL...", association_count, total_urls)
|
||||
|
||||
logger.info(
|
||||
"✓ 站点URL导出完成 - 关联数: %d, 总URL数: %d, 文件: %s (%.2f KB)",
|
||||
association_count,
|
||||
total_urls,
|
||||
str(output_path),
|
||||
output_path.stat().st_size / 1024
|
||||
)
|
||||
|
||||
# ==================== 懒加载模式:根据 Target 类型生成默认 URL ====================
|
||||
if total_urls == 0:
|
||||
total_urls = _write_default_urls(target_id, target_name, output_path)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'output_file': str(output_path),
|
||||
'total_urls': total_urls,
|
||||
'association_count': association_count
|
||||
}
|
||||
|
||||
except FileNotFoundError as e:
|
||||
logger.error("输出目录不存在: %s", e)
|
||||
raise
|
||||
except PermissionError as e:
|
||||
logger.error("文件写入权限不足: %s", e)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("导出站点URL失败: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
def _write_default_urls(target_id: int, target_name: Optional[str], output_path: Path) -> int:
|
||||
"""
|
||||
懒加载模式:根据 Target 类型生成默认 URL
|
||||
logger.info("开始统计站点URL - Target ID: %d, 输出文件: %s", target_id, output_file)
|
||||
|
||||
Args:
|
||||
target_id: 目标 ID
|
||||
target_name: 目标名称(可选,如果为空则从数据库查询)
|
||||
output_path: 输出文件路径
|
||||
|
||||
Returns:
|
||||
int: 生成的 URL 数量
|
||||
"""
|
||||
# 获取 Target 信息
|
||||
target_service = TargetService()
|
||||
target = target_service.get_target(target_id)
|
||||
# 确保输出目录存在
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if not target:
|
||||
logger.warning("Target ID %d 不存在,无法生成默认 URL", target_id)
|
||||
return 0
|
||||
# 初始化黑名单服务
|
||||
blacklist_service = BlacklistService()
|
||||
|
||||
target_name = target.name
|
||||
target_type = target.type
|
||||
|
||||
logger.info("懒加载模式:Target 类型=%s, 名称=%s", target_type, target_name)
|
||||
# 直接查询 HostPortMapping 表,按 host 排序
|
||||
service = HostPortMappingService()
|
||||
associations = service.iter_host_port_by_target(
|
||||
target_id=target_id,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
total_urls = 0
|
||||
association_count = 0
|
||||
|
||||
# 流式写入文件(特殊端口逻辑)
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
if target_type == Target.TargetType.DOMAIN:
|
||||
# 域名类型:生成 http(s)://domain
|
||||
f.write(f"http://{target_name}\n")
|
||||
f.write(f"https://{target_name}\n")
|
||||
total_urls = 2
|
||||
logger.info("✓ 域名默认 URL 已写入: http(s)://%s", target_name)
|
||||
for assoc in associations:
|
||||
association_count += 1
|
||||
host = assoc['host']
|
||||
port = assoc['port']
|
||||
|
||||
elif target_type == Target.TargetType.IP:
|
||||
# IP 类型:生成 http(s)://ip
|
||||
f.write(f"http://{target_name}\n")
|
||||
f.write(f"https://{target_name}\n")
|
||||
total_urls = 2
|
||||
logger.info("✓ IP 默认 URL 已写入: http(s)://%s", target_name)
|
||||
# 根据端口号生成URL
|
||||
for url in _generate_urls_from_port(host, port):
|
||||
if blacklist_service.filter_url(url):
|
||||
f.write(f"{url}\n")
|
||||
total_urls += 1
|
||||
|
||||
elif target_type == Target.TargetType.CIDR:
|
||||
# CIDR 类型:展开为所有 IP 的 URL
|
||||
try:
|
||||
network = ipaddress.ip_network(target_name, strict=False)
|
||||
|
||||
for ip in network.hosts(): # 排除网络地址和广播地址
|
||||
f.write(f"http://{ip}\n")
|
||||
f.write(f"https://{ip}\n")
|
||||
total_urls += 2
|
||||
|
||||
if total_urls % 10000 == 0:
|
||||
logger.info("已生成 %d 个 URL...", total_urls)
|
||||
|
||||
# 如果是 /32 或 /128(单个 IP),hosts() 会为空
|
||||
if total_urls == 0:
|
||||
ip = str(network.network_address)
|
||||
f.write(f"http://{ip}\n")
|
||||
f.write(f"https://{ip}\n")
|
||||
total_urls = 2
|
||||
|
||||
logger.info("✓ CIDR 默认 URL 已写入: %d 个 URL (来自 %s)", total_urls, target_name)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error("CIDR 解析失败: %s - %s", target_name, e)
|
||||
return 0
|
||||
else:
|
||||
logger.warning("不支持的 Target 类型: %s", target_type)
|
||||
return 0
|
||||
if association_count % 1000 == 0:
|
||||
logger.info("已处理 %d 条关联,生成 %d 个URL...", association_count, total_urls)
|
||||
|
||||
return total_urls
|
||||
logger.info(
|
||||
"✓ 站点URL导出完成 - 关联数: %d, 总URL数: %d, 文件: %s",
|
||||
association_count, total_urls, str(output_path)
|
||||
)
|
||||
|
||||
# 默认值回退模式:使用 TargetExportService
|
||||
if total_urls == 0:
|
||||
export_service = TargetExportService(blacklist_service=blacklist_service)
|
||||
total_urls = export_service._generate_default_urls(target_id, output_path)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'output_file': str(output_path),
|
||||
'total_urls': total_urls,
|
||||
'association_count': association_count
|
||||
}
|
||||
|
||||
@@ -1,25 +1,16 @@
|
||||
"""
|
||||
导出站点 URL 列表任务
|
||||
|
||||
从 WebSite 表导出站点 URL 列表到文件(用于 katana 等爬虫工具)
|
||||
|
||||
使用流式写入,避免内存溢出
|
||||
|
||||
懒加载模式:
|
||||
- 如果 WebSite 表为空,根据 Target 类型生成默认 URL
|
||||
- DOMAIN: 写入 http(s)://domain
|
||||
- IP: 写入 http(s)://ip
|
||||
- CIDR: 展开为所有 IP
|
||||
使用 TargetExportService 统一处理导出逻辑和默认值回退
|
||||
数据源: WebSite.url(用于 katana 等爬虫工具)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import ipaddress
|
||||
from pathlib import Path
|
||||
from prefect import task
|
||||
from typing import Optional
|
||||
|
||||
from apps.targets.services import TargetService
|
||||
from apps.targets.models import Target
|
||||
from apps.asset.models import WebSite
|
||||
from apps.scan.services import TargetExportService, BlacklistService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -33,21 +24,23 @@ def export_sites_task(
|
||||
output_file: str,
|
||||
target_id: int,
|
||||
scan_id: int,
|
||||
target_name: Optional[str] = None,
|
||||
batch_size: int = 1000
|
||||
) -> dict:
|
||||
"""
|
||||
导出站点 URL 列表到文件(用于 katana 等爬虫工具)
|
||||
|
||||
数据源: WebSite.url
|
||||
|
||||
懒加载模式:
|
||||
- 如果 WebSite 表为空,根据 Target 类型生成默认 URL
|
||||
- 数据库只存储"真实发现"的资产
|
||||
- 如果数据库为空,根据 Target 类型生成默认 URL
|
||||
- DOMAIN: http(s)://domain
|
||||
- IP: http(s)://ip
|
||||
- CIDR: 展开为所有 IP 的 URL
|
||||
|
||||
Args:
|
||||
output_file: 输出文件路径
|
||||
target_id: 目标 ID
|
||||
scan_id: 扫描 ID
|
||||
target_name: 目标名称(用于懒加载时写入默认值)
|
||||
scan_id: 扫描 ID(保留参数,兼容旧调用)
|
||||
batch_size: 批次大小(内存优化)
|
||||
|
||||
Returns:
|
||||
@@ -60,109 +53,22 @@ def export_sites_task(
|
||||
ValueError: 参数错误
|
||||
RuntimeError: 执行失败
|
||||
"""
|
||||
try:
|
||||
logger.info("开始导出站点 URL 列表 - Target ID: %d", target_id)
|
||||
|
||||
# 确保输出目录存在
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 从 WebSite 表导出站点 URL
|
||||
from apps.asset.services import WebSiteService
|
||||
|
||||
website_service = WebSiteService()
|
||||
|
||||
# 流式写入文件
|
||||
asset_count = 0
|
||||
with open(output_path, 'w') as f:
|
||||
for url in website_service.iter_website_urls_by_target(target_id, batch_size):
|
||||
f.write(f"{url}\n")
|
||||
asset_count += 1
|
||||
|
||||
if asset_count % batch_size == 0:
|
||||
f.flush()
|
||||
|
||||
# ==================== 懒加载模式:根据 Target 类型生成默认 URL ====================
|
||||
if asset_count == 0:
|
||||
asset_count = _write_default_urls(target_id, target_name, output_path)
|
||||
|
||||
logger.info("✓ 站点 URL 导出完成 - 文件: %s, 数量: %d", output_file, asset_count)
|
||||
|
||||
return {
|
||||
'output_file': output_file,
|
||||
'asset_count': asset_count,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("导出站点 URL 失败: %s", e, exc_info=True)
|
||||
raise RuntimeError(f"导出站点 URL 失败: {e}") from e
|
||||
|
||||
|
||||
def _write_default_urls(target_id: int, target_name: Optional[str], output_path: Path) -> int:
|
||||
"""
|
||||
懒加载模式:根据 Target 类型生成默认 URL 列表
|
||||
# 构建数据源 queryset(Task 层决定数据源)
|
||||
queryset = WebSite.objects.filter(target_id=target_id).values_list('url', flat=True)
|
||||
|
||||
Args:
|
||||
target_id: 目标 ID
|
||||
target_name: 目标名称
|
||||
output_path: 输出文件路径
|
||||
|
||||
Returns:
|
||||
int: 生成的 URL 数量
|
||||
"""
|
||||
target_service = TargetService()
|
||||
target = target_service.get_target(target_id)
|
||||
# 使用 TargetExportService 处理导出
|
||||
blacklist_service = BlacklistService()
|
||||
export_service = TargetExportService(blacklist_service=blacklist_service)
|
||||
|
||||
if not target:
|
||||
logger.warning("Target ID %d 不存在,无法生成默认 URL", target_id)
|
||||
return 0
|
||||
result = export_service.export_urls(
|
||||
target_id=target_id,
|
||||
output_path=output_file,
|
||||
queryset=queryset,
|
||||
batch_size=batch_size
|
||||
)
|
||||
|
||||
target_name = target.name
|
||||
target_type = target.type
|
||||
|
||||
logger.info("懒加载模式:Target 类型=%s, 名称=%s", target_type, target_name)
|
||||
|
||||
total_urls = 0
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
if target_type == Target.TargetType.DOMAIN:
|
||||
f.write(f"http://{target_name}\n")
|
||||
f.write(f"https://{target_name}\n")
|
||||
total_urls = 2
|
||||
logger.info("✓ 域名默认 URL 已写入: http(s)://%s", target_name)
|
||||
|
||||
elif target_type == Target.TargetType.IP:
|
||||
f.write(f"http://{target_name}\n")
|
||||
f.write(f"https://{target_name}\n")
|
||||
total_urls = 2
|
||||
logger.info("✓ IP 默认 URL 已写入: http(s)://%s", target_name)
|
||||
|
||||
elif target_type == Target.TargetType.CIDR:
|
||||
try:
|
||||
network = ipaddress.ip_network(target_name, strict=False)
|
||||
|
||||
for ip in network.hosts():
|
||||
f.write(f"http://{ip}\n")
|
||||
f.write(f"https://{ip}\n")
|
||||
total_urls += 2
|
||||
|
||||
if total_urls % 10000 == 0:
|
||||
logger.info("已生成 %d 个 URL...", total_urls)
|
||||
|
||||
# /32 或 /128 特殊处理
|
||||
if total_urls == 0:
|
||||
ip = str(network.network_address)
|
||||
f.write(f"http://{ip}\n")
|
||||
f.write(f"https://{ip}\n")
|
||||
total_urls = 2
|
||||
|
||||
logger.info("✓ CIDR 默认 URL 已写入: %d 个 URL (来自 %s)", total_urls, target_name)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error("CIDR 解析失败: %s - %s", target_name, e)
|
||||
return 0
|
||||
else:
|
||||
logger.warning("不支持的 Target 类型: %s", target_type)
|
||||
return 0
|
||||
|
||||
return total_urls
|
||||
# 保持返回值格式不变(向后兼容)
|
||||
return {
|
||||
'output_file': result['output_file'],
|
||||
'asset_count': result['total_count'],
|
||||
}
|
||||
|
||||
@@ -1,25 +1,16 @@
|
||||
"""导出 Endpoint URL 到文件的 Task
|
||||
|
||||
基于 EndpointService.iter_endpoint_urls_by_target 按目标流式导出端点 URL,
|
||||
用于漏洞扫描(如 Dalfox XSS)的输入文件生成。
|
||||
|
||||
默认值模式:
|
||||
- 如果没有 Endpoint,根据 Target 类型生成默认 URL
|
||||
- DOMAIN: http(s)://target_name
|
||||
- IP: http(s)://ip
|
||||
- CIDR: 展开为所有 IP 的 http(s)://ip
|
||||
使用 TargetExportService 统一处理导出逻辑和默认值回退
|
||||
数据源: Endpoint.url
|
||||
"""
|
||||
|
||||
import logging
|
||||
import ipaddress
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
from prefect import task
|
||||
|
||||
from apps.asset.services import EndpointService
|
||||
from apps.targets.services import TargetService
|
||||
from apps.targets.models import Target
|
||||
from apps.asset.models import Endpoint
|
||||
from apps.scan.services import TargetExportService, BlacklistService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -29,17 +20,21 @@ def export_endpoints_task(
|
||||
target_id: int,
|
||||
output_file: str,
|
||||
batch_size: int = 1000,
|
||||
target_name: Optional[str] = None,
|
||||
) -> Dict[str, object]:
|
||||
"""导出目标下的所有 Endpoint URL 到文本文件。
|
||||
|
||||
默认值模式:如果没有 Endpoint,根据 Target 类型生成默认 URL
|
||||
数据源: Endpoint.url
|
||||
|
||||
懒加载模式:
|
||||
- 如果数据库为空,根据 Target 类型生成默认 URL
|
||||
- DOMAIN: http(s)://domain
|
||||
- IP: http(s)://ip
|
||||
- CIDR: 展开为所有 IP 的 URL
|
||||
|
||||
Args:
|
||||
target_id: 目标 ID
|
||||
output_file: 输出文件路径(绝对路径)
|
||||
batch_size: 每次从数据库迭代的批大小
|
||||
target_name: 目标名称(用于默认值模式)
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
@@ -48,117 +43,23 @@ def export_endpoints_task(
|
||||
"total_count": int,
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info("开始导出 Endpoint URL - Target ID: %d, 输出文件: %s", target_id, output_file)
|
||||
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
service = EndpointService()
|
||||
url_iterator = service.iter_endpoint_urls_by_target(target_id, chunk_size=batch_size)
|
||||
|
||||
total_count = 0
|
||||
with open(output_path, "w", encoding="utf-8", buffering=8192) as f:
|
||||
for url in url_iterator:
|
||||
f.write(f"{url}\n")
|
||||
total_count += 1
|
||||
|
||||
if total_count % 10000 == 0:
|
||||
logger.info("已导出 %d 个 Endpoint URL...", total_count)
|
||||
|
||||
# ==================== 懒加载模式:根据 Target 类型生成默认 URL ====================
|
||||
if total_count == 0:
|
||||
total_count = _write_default_urls(target_id, target_name, output_path)
|
||||
|
||||
logger.info(
|
||||
"✓ Endpoint URL 导出完成 - 总数: %d, 文件: %s (%.2f KB)",
|
||||
total_count,
|
||||
str(output_path),
|
||||
output_path.stat().st_size / 1024,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"output_file": str(output_path),
|
||||
"total_count": total_count,
|
||||
}
|
||||
|
||||
except FileNotFoundError as e:
|
||||
logger.error("输出目录不存在: %s", e)
|
||||
raise
|
||||
except PermissionError as e:
|
||||
logger.error("文件写入权限不足: %s", e)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("导出 Endpoint URL 失败: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
def _write_default_urls(target_id: int, target_name: Optional[str], output_path: Path) -> int:
|
||||
"""
|
||||
懒加载模式:根据 Target 类型生成默认 URL
|
||||
# 构建数据源 queryset(Task 层决定数据源)
|
||||
queryset = Endpoint.objects.filter(target_id=target_id).values_list('url', flat=True)
|
||||
|
||||
Args:
|
||||
target_id: 目标 ID
|
||||
target_name: 目标名称(可选,如果为空则从数据库查询)
|
||||
output_path: 输出文件路径
|
||||
|
||||
Returns:
|
||||
int: 生成的 URL 数量
|
||||
"""
|
||||
target_service = TargetService()
|
||||
target = target_service.get_target(target_id)
|
||||
# 使用 TargetExportService 处理导出
|
||||
blacklist_service = BlacklistService()
|
||||
export_service = TargetExportService(blacklist_service=blacklist_service)
|
||||
|
||||
if not target:
|
||||
logger.warning("Target ID %d 不存在,无法生成默认 URL", target_id)
|
||||
return 0
|
||||
result = export_service.export_urls(
|
||||
target_id=target_id,
|
||||
output_path=output_file,
|
||||
queryset=queryset,
|
||||
batch_size=batch_size
|
||||
)
|
||||
|
||||
target_name = target.name
|
||||
target_type = target.type
|
||||
|
||||
logger.info("懒加载模式:Target 类型=%s, 名称=%s", target_type, target_name)
|
||||
|
||||
total_urls = 0
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
if target_type == Target.TargetType.DOMAIN:
|
||||
f.write(f"http://{target_name}\n")
|
||||
f.write(f"https://{target_name}\n")
|
||||
total_urls = 2
|
||||
logger.info("✓ 域名默认 URL 已写入: http(s)://%s", target_name)
|
||||
|
||||
elif target_type == Target.TargetType.IP:
|
||||
f.write(f"http://{target_name}\n")
|
||||
f.write(f"https://{target_name}\n")
|
||||
total_urls = 2
|
||||
logger.info("✓ IP 默认 URL 已写入: http(s)://%s", target_name)
|
||||
|
||||
elif target_type == Target.TargetType.CIDR:
|
||||
try:
|
||||
network = ipaddress.ip_network(target_name, strict=False)
|
||||
|
||||
for ip in network.hosts():
|
||||
f.write(f"http://{ip}\n")
|
||||
f.write(f"https://{ip}\n")
|
||||
total_urls += 2
|
||||
|
||||
if total_urls % 10000 == 0:
|
||||
logger.info("已生成 %d 个 URL...", total_urls)
|
||||
|
||||
# /32 或 /128 特殊处理
|
||||
if total_urls == 0:
|
||||
ip = str(network.network_address)
|
||||
f.write(f"http://{ip}\n")
|
||||
f.write(f"https://{ip}\n")
|
||||
total_urls = 2
|
||||
|
||||
logger.info("✓ CIDR 默认 URL 已写入: %d 个 URL (来自 %s)", total_urls, target_name)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error("CIDR 解析失败: %s - %s", target_name, e)
|
||||
return 0
|
||||
else:
|
||||
logger.warning("不支持的 Target 类型: %s", target_type)
|
||||
return 0
|
||||
|
||||
return total_urls
|
||||
# 保持返回值格式不变(向后兼容)
|
||||
return {
|
||||
"success": result['success'],
|
||||
"output_file": result['output_file'],
|
||||
"total_count": result['total_count'],
|
||||
}
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
"""
|
||||
工作空间相关的 Prefect Tasks
|
||||
|
||||
负责扫描工作空间的创建、验证和管理
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from prefect import task
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@task(
    name="create_scan_workspace",
    description="创建并验证 Scan 工作空间目录",
    retries=2,
    retry_delay_seconds=5
)
def create_scan_workspace_task(scan_workspace_dir: str) -> Path:
    """
    Create the scan workspace directory and confirm that it is writable.

    The directory (and any missing parents) is created first; afterwards a
    throw-away probe file is created and removed inside it to prove that the
    current process can write there.

    Args:
        scan_workspace_dir: Path of the scan workspace directory.

    Returns:
        Path: The workspace directory as a ``Path`` object.

    Raises:
        OSError: The directory could not be created, or the probe file could
            not be created/removed (i.e. the directory is not writable).
    """
    workspace = Path(scan_workspace_dir)

    # Step 1: make sure the directory exists.
    try:
        workspace.mkdir(parents=True, exist_ok=True)
        logger.info("✓ Scan 工作空间已创建: %s", workspace)
    except OSError as exc:
        logger.error("创建 Scan 工作空间失败: %s - %s", scan_workspace_dir, exc)
        raise

    # Step 2: prove writability by creating and deleting a probe file.
    probe = workspace / ".test_write"
    try:
        probe.touch()
        probe.unlink()
        logger.info("✓ Scan 工作空间验证通过(可写): %s", workspace)
    except OSError as exc:
        message = f"Scan 工作空间不可写: {workspace}"
        logger.error(message)
        # Re-raise with a clearer message while keeping the original cause.
        raise OSError(message) from exc

    return workspace
|
||||
@@ -10,11 +10,15 @@ from .command_executor import execute_and_wait, execute_stream
|
||||
from .wordlist_helpers import ensure_wordlist_local
|
||||
from .nuclei_helpers import ensure_nuclei_templates_local
|
||||
from .performance import FlowPerformanceTracker, CommandPerformanceTracker
|
||||
from .workspace_utils import setup_scan_workspace, setup_scan_directory
|
||||
from . import config_parser
|
||||
|
||||
__all__ = [
|
||||
# 目录清理
|
||||
'remove_directory',
|
||||
# 工作空间
|
||||
'setup_scan_workspace', # 创建 Scan 根工作空间
|
||||
'setup_scan_directory', # 创建扫描子目录
|
||||
# 命令构建
|
||||
'build_scan_command', # 扫描工具命令构建(基于 f-string)
|
||||
# 命令执行
|
||||
|
||||
@@ -36,7 +36,14 @@ def _normalize_config_keys(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
Returns:
|
||||
key 已转换的新字典
|
||||
|
||||
Raises:
|
||||
ValueError: 配置为 None 或非字典类型时抛出
|
||||
"""
|
||||
if config is None:
|
||||
raise ValueError("配置不能为空(None),请检查 YAML 格式,确保冒号后有配置内容或使用 {} 表示空配置")
|
||||
if not isinstance(config, dict):
|
||||
raise ValueError(f"配置格式错误:期望 dict,实际 {type(config).__name__}")
|
||||
return {
|
||||
k.replace('-', '_') if isinstance(k, str) else k: v
|
||||
for k, v in config.items()
|
||||
@@ -169,26 +176,23 @@ def parse_enabled_tools_from_dict(
|
||||
)
|
||||
|
||||
if enabled_value:
|
||||
# 检查 timeout 必需参数
|
||||
if 'timeout' not in config:
|
||||
raise ValueError(f"工具 {name} 缺少必需参数 'timeout'")
|
||||
# timeout 默认为 'auto',由具体 Flow 自动计算
|
||||
timeout_value = config.get('timeout', 'auto')
|
||||
|
||||
# 验证 timeout 值的有效性
|
||||
timeout_value = config['timeout']
|
||||
|
||||
if timeout_value == 'auto':
|
||||
# 允许 'auto',由具体 Flow 处理
|
||||
pass
|
||||
elif isinstance(timeout_value, int):
|
||||
if timeout_value <= 0:
|
||||
raise ValueError(f"工具 {name} 的 timeout 参数无效({timeout_value}),必须大于0")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"工具 {name} 的 timeout 参数类型错误:期望 int 或 'auto',实际 {type(timeout_value).__name__}"
|
||||
)
|
||||
if timeout_value != 'auto':
|
||||
if isinstance(timeout_value, int):
|
||||
if timeout_value <= 0:
|
||||
raise ValueError(f"工具 {name} 的 timeout 参数无效({timeout_value}),必须大于0")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"工具 {name} 的 timeout 参数类型错误:期望 int 或 'auto',实际 {type(timeout_value).__name__}"
|
||||
)
|
||||
|
||||
# 将配置 key 中划线转为下划线,统一给下游代码使用
|
||||
enabled_tools[name] = _normalize_config_keys(config)
|
||||
normalized_config = _normalize_config_keys(config)
|
||||
normalized_config['timeout'] = timeout_value # 确保 timeout 存在
|
||||
enabled_tools[name] = normalized_config
|
||||
|
||||
logger.info(f"扫描类型: {scan_type}, 启用工具: {len(enabled_tools)}/{len(tools)}")
|
||||
|
||||
|
||||
393
backend/apps/scan/utils/fingerprint_helpers.py
Normal file
393
backend/apps/scan/utils/fingerprint_helpers.py
Normal file
@@ -0,0 +1,393 @@
|
||||
"""指纹文件本地缓存工具
|
||||
|
||||
提供 Worker 侧的指纹文件缓存和版本校验功能,用于:
|
||||
- 指纹识别扫描 (fingerprint_detect_flow)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# 指纹库映射:lib_name → ensure_func_name
|
||||
FINGERPRINT_LIB_MAP = {
|
||||
'ehole': 'ensure_ehole_fingerprint_local',
|
||||
'goby': 'ensure_goby_fingerprint_local',
|
||||
'wappalyzer': 'ensure_wappalyzer_fingerprint_local',
|
||||
'fingers': 'ensure_fingers_fingerprint_local',
|
||||
'fingerprinthub': 'ensure_fingerprinthub_fingerprint_local',
|
||||
'arl': 'ensure_arl_fingerprint_local',
|
||||
}
|
||||
|
||||
|
||||
def _dump_json(data, fh) -> None:
    """Write *data* to the open text file *fh* as JSON without ASCII-escaping."""
    json.dump(data, fh, ensure_ascii=False)


def _ensure_fingerprint_local(service, lib_name: str, label: str,
                              extension: str = 'json', dump=None) -> str:
    """
    Shared implementation of the version-checked local fingerprint cache.

    Steps:
        1. Ask *service* for the current fingerprint version.
        2. Compare it with the version recorded in ``<lib_name>.version``.
        3. If the versions match and the cache file exists, reuse the cache.
        4. Otherwise re-export the fingerprints and record the new version.

    Args:
        service: Fingerprint service instance. It must provide
            ``get_fingerprint_version()`` and either ``export_to_file(path)``
            (used when *dump* is None) or ``get_export_data()``.
        lib_name: Base file name of the cache/version files (e.g. ``'goby'``).
        label: Human-readable library name used in log messages.
        extension: Extension of the cache file (``'json'`` or ``'yaml'``).
        dump: Optional callable ``dump(data, file_obj)`` that serializes the
            service's export data. When None the service writes the cache
            file itself through ``export_to_file``.

    Returns:
        str: Path of the local fingerprint cache file.
    """
    current_version = service.get_fingerprint_version()

    # Cache directory and files.
    base_dir = getattr(settings, 'FINGERPRINTS_BASE_PATH', '/opt/xingrin/fingerprints')
    os.makedirs(base_dir, exist_ok=True)
    cache_file = os.path.join(base_dir, f'{lib_name}.{extension}')
    version_file = os.path.join(base_dir, f'{lib_name}.version')

    # Read the cached version; an unreadable version file just forces a refresh.
    cached_version = None
    if os.path.exists(version_file):
        try:
            with open(version_file, 'r') as f:
                cached_version = f.read().strip()
        except OSError as e:
            logger.warning("读取 %s 版本文件失败: %s", label, e)

    # Version matches: reuse the cached file.
    if cached_version == current_version and os.path.exists(cache_file):
        logger.info("%s 指纹文件缓存有效(版本匹配): %s", label, cache_file)
        return cache_file

    # Version mismatch (or missing cache): export again.
    logger.info(
        "%s 指纹文件需要更新: cached=%s, current=%s",
        label, cached_version, current_version
    )
    if dump is None:
        service.export_to_file(cache_file)
    else:
        # Fetch the data before opening the file so that a failing export
        # does not truncate an existing cache file.
        data = service.get_export_data()
        with open(cache_file, 'w', encoding='utf-8') as f:
            dump(data, f)

    # Record the version. Failure is only logged: the cache file itself is
    # valid, the next call will simply export again.
    try:
        with open(version_file, 'w') as f:
            f.write(current_version)
    except OSError as e:
        logger.warning("写入 %s 版本文件失败: %s", label, e)

    logger.info("%s 指纹文件已更新: %s", label, cache_file)
    return cache_file


def ensure_ehole_fingerprint_local() -> str:
    """
    Ensure an up-to-date local EHole fingerprint file exists (cached).

    The cache is refreshed only when the stored version differs from the
    service's current fingerprint version or the cache file is missing.

    Returns:
        str: Path of the local fingerprint file (``ehole.json``).
    """
    # Imported lazily, as in the rest of this module.
    from apps.engine.services.fingerprints import EholeFingerprintService

    # EHole exports through the service's own export_to_file().
    return _ensure_fingerprint_local(EholeFingerprintService(), 'ehole', 'EHole')


def ensure_goby_fingerprint_local() -> str:
    """
    Ensure an up-to-date local Goby fingerprint file exists (cached).

    Returns:
        str: Path of the local fingerprint file (``goby.json``).
    """
    from apps.engine.services.fingerprints import GobyFingerprintService

    # The Goby export data is dumped directly as JSON.
    return _ensure_fingerprint_local(
        GobyFingerprintService(), 'goby', 'Goby', dump=_dump_json
    )


def ensure_wappalyzer_fingerprint_local() -> str:
    """
    Ensure an up-to-date local Wappalyzer fingerprint file exists (cached).

    Returns:
        str: Path of the local fingerprint file (``wappalyzer.json``).
    """
    from apps.engine.services.fingerprints import WappalyzerFingerprintService

    # The Wappalyzer export data is dumped directly as JSON.
    return _ensure_fingerprint_local(
        WappalyzerFingerprintService(), 'wappalyzer', 'Wappalyzer', dump=_dump_json
    )


def ensure_fingers_fingerprint_local() -> str:
    """
    Ensure an up-to-date local Fingers fingerprint file exists (cached).

    Returns:
        str: Path of the local fingerprint file (``fingers.json``).
    """
    from apps.engine.services.fingerprints import FingersFingerprintService

    return _ensure_fingerprint_local(
        FingersFingerprintService(), 'fingers', 'Fingers', dump=_dump_json
    )


def ensure_fingerprinthub_fingerprint_local() -> str:
    """
    Ensure an up-to-date local FingerPrintHub fingerprint file exists (cached).

    Returns:
        str: Path of the local fingerprint file (``fingerprinthub.json``).
    """
    from apps.engine.services.fingerprints import FingerPrintHubFingerprintService

    return _ensure_fingerprint_local(
        FingerPrintHubFingerprintService(), 'fingerprinthub', 'FingerPrintHub',
        dump=_dump_json
    )


def ensure_arl_fingerprint_local() -> str:
    """
    Ensure an up-to-date local ARL fingerprint file exists (cached).

    Returns:
        str: Path of the local fingerprint file (``arl.yaml``, YAML format).
    """
    import yaml
    from apps.engine.services.fingerprints import ARLFingerprintService

    def _dump_yaml(data, fh) -> None:
        # ARL fingerprints are stored as block-style YAML with raw unicode.
        yaml.dump(data, fh, allow_unicode=True, default_flow_style=False)

    return _ensure_fingerprint_local(
        ARLFingerprintService(), 'arl', 'ARL', extension='yaml', dump=_dump_yaml
    )


def get_fingerprint_paths(lib_names: list) -> dict:
    """
    Resolve the local cache paths for several fingerprint libraries.

    Libraries that are unknown, have no ensure function in this module, or
    whose ensure function raises are logged and left out of the result.

    Args:
        lib_names: Fingerprint library names, e.g. ``['ehole', 'goby']``.

    Returns:
        dict: ``{lib_name: local_path}``, e.g.
            ``{'ehole': '/opt/xingrin/fingerprints/ehole.json'}``.

    Example:
        paths = get_fingerprint_paths(['ehole'])
        # {'ehole': '/opt/xingrin/fingerprints/ehole.json'}
    """
    paths = {}
    for lib_name in lib_names:
        if lib_name not in FINGERPRINT_LIB_MAP:
            logger.warning("不支持的指纹库: %s,跳过", lib_name)
            continue

        ensure_func_name = FINGERPRINT_LIB_MAP[lib_name]
        # Look the ensure function up by name in this module.
        ensure_func = globals().get(ensure_func_name)
        if ensure_func is None:
            logger.warning("指纹库 %s 的导出函数 %s 未实现,跳过", lib_name, ensure_func_name)
            continue

        try:
            paths[lib_name] = ensure_func()
        except Exception as e:
            # Best effort: one failing library must not block the others.
            logger.error("获取指纹库 %s 路径失败: %s", lib_name, e)
            continue

    return paths
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ensure_ehole_fingerprint_local",
|
||||
"ensure_goby_fingerprint_local",
|
||||
"ensure_wappalyzer_fingerprint_local",
|
||||
"ensure_fingers_fingerprint_local",
|
||||
"ensure_fingerprinthub_fingerprint_local",
|
||||
"ensure_arl_fingerprint_local",
|
||||
"get_fingerprint_paths",
|
||||
"FINGERPRINT_LIB_MAP",
|
||||
]
|
||||
@@ -83,7 +83,8 @@ def ensure_wordlist_local(wordlist_name: str) -> str:
|
||||
"无法确定 Django API 地址:请配置 SERVER_URL 或 PUBLIC_HOST 环境变量"
|
||||
)
|
||||
# 远程 Worker 通过 nginx HTTPS 访问,不再直连 8888
|
||||
api_base = f"https://{public_host}/api"
|
||||
public_port = getattr(settings, 'PUBLIC_PORT', '8083')
|
||||
api_base = f"https://{public_host}:{public_port}/api"
|
||||
query = urllib_parse.urlencode({'wordlist': wordlist_name})
|
||||
download_url = f"{api_base.rstrip('/')}/wordlists/download/?{query}"
|
||||
|
||||
@@ -95,7 +96,13 @@ def ensure_wordlist_local(wordlist_name: str) -> str:
|
||||
ssl_context.check_hostname = False
|
||||
ssl_context.verify_mode = ssl.CERT_NONE
|
||||
|
||||
with urllib_request.urlopen(download_url, context=ssl_context) as resp:
|
||||
# 创建带 API Key 的请求
|
||||
req = urllib_request.Request(download_url)
|
||||
worker_api_key = os.getenv('WORKER_API_KEY', '')
|
||||
if worker_api_key:
|
||||
req.add_header('X-Worker-API-Key', worker_api_key)
|
||||
|
||||
with urllib_request.urlopen(req, context=ssl_context) as resp:
|
||||
if resp.status != 200:
|
||||
raise RuntimeError(f"下载字典失败,HTTP {resp.status}")
|
||||
data = resp.read()
|
||||
|
||||
83
backend/apps/scan/utils/workspace_utils.py
Normal file
83
backend/apps/scan/utils/workspace_utils.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
工作空间工具模块
|
||||
|
||||
提供统一的扫描工作目录创建和验证功能
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_scan_workspace(scan_workspace_dir: str) -> Path:
    """
    Create the root workspace directory for a scan.

    Missing parent directories are created as needed, and an already
    existing directory is accepted.

    Args:
        scan_workspace_dir: Path of the workspace directory.

    Returns:
        Path: The workspace directory.

    Raises:
        RuntimeError: If the directory cannot be created or is not writable.
    """
    root = Path(scan_workspace_dir)

    try:
        root.mkdir(parents=True, exist_ok=True)
    except OSError as exc:
        raise RuntimeError(f"创建工作空间失败: {scan_workspace_dir} - {exc}") from exc

    # Fail early if the directory exists but cannot be written to.
    _verify_writable(root)

    logger.info("✓ Scan 工作空间已创建: %s", root)
    return root
|
||||
|
||||
|
||||
def setup_scan_directory(scan_workspace_dir: str, subdir: str) -> Path:
    """
    Create a scan sub-directory inside the root workspace.

    Missing parent directories are created as needed, and an already
    existing directory is accepted.

    Args:
        scan_workspace_dir: Root workspace directory.
        subdir: Sub-directory name (e.g. 'fingerprint_detect', 'site_scan').

    Returns:
        Path: The sub-directory.

    Raises:
        RuntimeError: If the directory cannot be created or is not writable.
    """
    target = Path(scan_workspace_dir).joinpath(subdir)

    try:
        target.mkdir(parents=True, exist_ok=True)
    except OSError as exc:
        raise RuntimeError(f"创建扫描目录失败: {target} - {exc}") from exc

    # Fail early if the directory exists but cannot be written to.
    _verify_writable(target)

    logger.info("✓ 扫描目录已创建: %s", target)
    return target
|
||||
|
||||
|
||||
def _verify_writable(path: Path) -> None:
    """
    Verify that a directory is writable by creating a throwaway file in it.

    The probe file gets a unique name rather than a fixed one, so concurrent
    checks of the same directory cannot delete each other's probe file and
    wrongly report a writable directory as unwritable.

    Args:
        path: Directory to check.

    Raises:
        RuntimeError: If a file cannot be created in, or removed from, the
            directory (including when the directory does not exist).
    """
    import tempfile  # local import: leaves the module's import block untouched

    try:
        # The probe file is deleted automatically when the context exits.
        with tempfile.NamedTemporaryFile(dir=path, prefix=".test_write_"):
            pass
    except OSError as e:
        raise RuntimeError(f"目录不可写: {path} - {e}") from e
|
||||
@@ -7,6 +7,9 @@ from django.core.exceptions import ObjectDoesNotExist, ValidationError
|
||||
from django.db.utils import DatabaseError, IntegrityError, OperationalError
|
||||
import logging
|
||||
|
||||
from apps.common.response_helpers import success_response, error_response
|
||||
from apps.common.error_codes import ErrorCodes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from ..models import Scan, ScheduledScan
|
||||
@@ -75,20 +78,31 @@ class ScanViewSet(viewsets.ModelViewSet):
|
||||
scan_service = ScanService()
|
||||
result = scan_service.delete_scans_two_phase([scan.id])
|
||||
|
||||
return Response({
|
||||
'message': f'已删除扫描任务: Scan #{scan.id}',
|
||||
'scanId': scan.id,
|
||||
'deletedCount': result['soft_deleted_count'],
|
||||
'deletedScans': result['scan_names']
|
||||
}, status=status.HTTP_200_OK)
|
||||
return success_response(
|
||||
data={
|
||||
'scanId': scan.id,
|
||||
'deletedCount': result['soft_deleted_count'],
|
||||
'deletedScans': result['scan_names']
|
||||
}
|
||||
)
|
||||
|
||||
except Scan.DoesNotExist:
|
||||
raise NotFound('扫描任务不存在')
|
||||
return error_response(
|
||||
code=ErrorCodes.NOT_FOUND,
|
||||
status_code=status.HTTP_404_NOT_FOUND
|
||||
)
|
||||
except ValueError as e:
|
||||
raise NotFound(str(e))
|
||||
return error_response(
|
||||
code=ErrorCodes.NOT_FOUND,
|
||||
message=str(e),
|
||||
status_code=status.HTTP_404_NOT_FOUND
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("删除扫描任务时发生错误")
|
||||
raise APIException('服务器错误,请稍后重试')
|
||||
return error_response(
|
||||
code=ErrorCodes.SERVER_ERROR,
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
@action(detail=False, methods=['post'])
|
||||
def quick(self, request):
|
||||
@@ -132,10 +146,12 @@ class ScanViewSet(viewsets.ModelViewSet):
|
||||
targets = result['targets']
|
||||
|
||||
if not targets:
|
||||
return Response({
|
||||
'error': '没有有效的目标可供扫描',
|
||||
'errors': result.get('errors', [])
|
||||
}, status=status.HTTP_400_BAD_REQUEST)
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='No valid targets for scanning',
|
||||
details=result.get('errors', []),
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# 2. 获取扫描引擎
|
||||
engine_service = EngineService()
|
||||
@@ -150,24 +166,44 @@ class ScanViewSet(viewsets.ModelViewSet):
|
||||
engine=engine
|
||||
)
|
||||
|
||||
# 检查是否成功创建扫描任务
|
||||
if not created_scans:
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='No scan tasks were created. All targets may already have active scans.',
|
||||
details={
|
||||
'targetStats': result['target_stats'],
|
||||
'assetStats': result['asset_stats'],
|
||||
'errors': result.get('errors', [])
|
||||
},
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
)
|
||||
|
||||
# 序列化返回结果
|
||||
scan_serializer = ScanSerializer(created_scans, many=True)
|
||||
|
||||
return Response({
|
||||
'message': f'快速扫描已启动:{len(created_scans)} 个任务',
|
||||
'target_stats': result['target_stats'],
|
||||
'asset_stats': result['asset_stats'],
|
||||
'errors': result.get('errors', []),
|
||||
'scans': scan_serializer.data
|
||||
}, status=status.HTTP_201_CREATED)
|
||||
return success_response(
|
||||
data={
|
||||
'count': len(created_scans),
|
||||
'targetStats': result['target_stats'],
|
||||
'assetStats': result['asset_stats'],
|
||||
'errors': result.get('errors', []),
|
||||
'scans': scan_serializer.data
|
||||
},
|
||||
status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
return Response({'error': str(e)}, status=status.HTTP_400_BAD_REQUEST)
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message=str(e),
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("快速扫描启动失败")
|
||||
return Response(
|
||||
{'error': '服务器内部错误,请稍后重试'},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
return error_response(
|
||||
code=ErrorCodes.SERVER_ERROR,
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
@action(detail=False, methods=['post'])
|
||||
@@ -205,38 +241,47 @@ class ScanViewSet(viewsets.ModelViewSet):
|
||||
engine=engine
|
||||
)
|
||||
|
||||
# 检查是否成功创建扫描任务
|
||||
if not created_scans:
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='No scan tasks were created. All targets may already have active scans.',
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
)
|
||||
|
||||
# 序列化返回结果
|
||||
scan_serializer = ScanSerializer(created_scans, many=True)
|
||||
|
||||
return Response(
|
||||
{
|
||||
'message': f'已成功发起 {len(created_scans)} 个扫描任务',
|
||||
return success_response(
|
||||
data={
|
||||
'count': len(created_scans),
|
||||
'scans': scan_serializer.data
|
||||
},
|
||||
status=status.HTTP_201_CREATED
|
||||
status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
|
||||
except ObjectDoesNotExist as e:
|
||||
# 资源不存在错误(由 service 层抛出)
|
||||
error_msg = str(e)
|
||||
return Response(
|
||||
{'error': error_msg},
|
||||
status=status.HTTP_404_NOT_FOUND
|
||||
return error_response(
|
||||
code=ErrorCodes.NOT_FOUND,
|
||||
message=str(e),
|
||||
status_code=status.HTTP_404_NOT_FOUND
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
# 参数验证错误(由 service 层抛出)
|
||||
return Response(
|
||||
{'error': str(e)},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message=str(e),
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
except (DatabaseError, IntegrityError, OperationalError):
|
||||
# 数据库错误
|
||||
return Response(
|
||||
{'error': '数据库错误,请稍后重试'},
|
||||
status=status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
return error_response(
|
||||
code=ErrorCodes.SERVER_ERROR,
|
||||
message='Database error',
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
)
|
||||
|
||||
# 所有快照相关的 action 和 export 已迁移到 asset/views.py 中的快照 ViewSet
|
||||
@@ -278,21 +323,24 @@ class ScanViewSet(viewsets.ModelViewSet):
|
||||
|
||||
# 参数验证
|
||||
if not ids:
|
||||
return Response(
|
||||
{'error': '缺少必填参数: ids'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='Missing required parameter: ids',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
if not isinstance(ids, list):
|
||||
return Response(
|
||||
{'error': 'ids 必须是数组'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='ids must be an array',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
if not all(isinstance(i, int) for i in ids):
|
||||
return Response(
|
||||
{'error': 'ids 数组中的所有元素必须是整数'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message='All elements in ids array must be integers',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -300,19 +348,27 @@ class ScanViewSet(viewsets.ModelViewSet):
|
||||
scan_service = ScanService()
|
||||
result = scan_service.delete_scans_two_phase(ids)
|
||||
|
||||
return Response({
|
||||
'message': f"已删除 {result['soft_deleted_count']} 个扫描任务",
|
||||
'deletedCount': result['soft_deleted_count'],
|
||||
'deletedScans': result['scan_names']
|
||||
}, status=status.HTTP_200_OK)
|
||||
return success_response(
|
||||
data={
|
||||
'deletedCount': result['soft_deleted_count'],
|
||||
'deletedScans': result['scan_names']
|
||||
}
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
# 未找到记录
|
||||
raise NotFound(str(e))
|
||||
return error_response(
|
||||
code=ErrorCodes.NOT_FOUND,
|
||||
message=str(e),
|
||||
status_code=status.HTTP_404_NOT_FOUND
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("批量删除扫描任务时发生错误")
|
||||
raise APIException('服务器错误,请稍后重试')
|
||||
return error_response(
|
||||
code=ErrorCodes.SERVER_ERROR,
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
@action(detail=False, methods=['get'])
|
||||
def statistics(self, request):
|
||||
@@ -337,22 +393,25 @@ class ScanViewSet(viewsets.ModelViewSet):
|
||||
scan_service = ScanService()
|
||||
stats = scan_service.get_statistics()
|
||||
|
||||
return Response({
|
||||
'total': stats['total'],
|
||||
'running': stats['running'],
|
||||
'completed': stats['completed'],
|
||||
'failed': stats['failed'],
|
||||
'totalVulns': stats['total_vulns'],
|
||||
'totalSubdomains': stats['total_subdomains'],
|
||||
'totalEndpoints': stats['total_endpoints'],
|
||||
'totalWebsites': stats['total_websites'],
|
||||
'totalAssets': stats['total_assets'],
|
||||
})
|
||||
return success_response(
|
||||
data={
|
||||
'total': stats['total'],
|
||||
'running': stats['running'],
|
||||
'completed': stats['completed'],
|
||||
'failed': stats['failed'],
|
||||
'totalVulns': stats['total_vulns'],
|
||||
'totalSubdomains': stats['total_subdomains'],
|
||||
'totalEndpoints': stats['total_endpoints'],
|
||||
'totalWebsites': stats['total_websites'],
|
||||
'totalAssets': stats['total_assets'],
|
||||
}
|
||||
)
|
||||
|
||||
except (DatabaseError, OperationalError):
|
||||
return Response(
|
||||
{'error': '数据库错误,请稍后重试'},
|
||||
status=status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
return error_response(
|
||||
code=ErrorCodes.SERVER_ERROR,
|
||||
message='Database error',
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
)
|
||||
|
||||
@action(detail=True, methods=['post'])
|
||||
@@ -383,35 +442,31 @@ class ScanViewSet(viewsets.ModelViewSet):
|
||||
# 检查是否是状态不允许的问题
|
||||
scan = scan_service.get_scan(scan_id=pk, prefetch_relations=False)
|
||||
if scan and scan.status not in [ScanStatus.RUNNING, ScanStatus.INITIATED]:
|
||||
return Response(
|
||||
{
|
||||
'error': f'无法停止扫描:当前状态为 {ScanStatus(scan.status).label}',
|
||||
'detail': '只能停止运行中或初始化状态的扫描'
|
||||
},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
return error_response(
|
||||
code=ErrorCodes.BAD_REQUEST,
|
||||
message=f'Cannot stop scan: current status is {ScanStatus(scan.status).label}',
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
# 其他失败原因
|
||||
return Response(
|
||||
{'error': '停止扫描失败'},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
return error_response(
|
||||
code=ErrorCodes.SERVER_ERROR,
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
return Response(
|
||||
{
|
||||
'message': f'扫描已停止,已撤销 {revoked_count} 个任务',
|
||||
'revokedTaskCount': revoked_count
|
||||
},
|
||||
status=status.HTTP_200_OK
|
||||
return success_response(
|
||||
data={'revokedTaskCount': revoked_count}
|
||||
)
|
||||
|
||||
except ObjectDoesNotExist:
|
||||
return Response(
|
||||
{'error': f'扫描 ID {pk} 不存在'},
|
||||
status=status.HTTP_404_NOT_FOUND
|
||||
return error_response(
|
||||
code=ErrorCodes.NOT_FOUND,
|
||||
message=f'Scan ID {pk} not found',
|
||||
status_code=status.HTTP_404_NOT_FOUND
|
||||
)
|
||||
|
||||
except (DatabaseError, IntegrityError, OperationalError):
|
||||
return Response(
|
||||
{'error': '数据库错误,请稍后重试'},
|
||||
status=status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
return error_response(
|
||||
code=ErrorCodes.SERVER_ERROR,
|
||||
message='Database error',
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
)
|
||||
|
||||
@@ -18,6 +18,8 @@ from ..serializers import (
|
||||
from ..services.scheduled_scan_service import ScheduledScanService
|
||||
from ..repositories import ScheduledScanDTO
|
||||
from apps.common.pagination import BasePagination
|
||||
from apps.common.response_helpers import success_response, error_response
|
||||
from apps.common.error_codes import ErrorCodes
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -75,15 +77,16 @@ class ScheduledScanViewSet(viewsets.ModelViewSet):
|
||||
scheduled_scan = self.service.create(dto)
|
||||
response_serializer = ScheduledScanSerializer(scheduled_scan)
|
||||
|
||||
return Response(
|
||||
{
|
||||
'message': f'创建定时扫描任务成功: {scheduled_scan.name}',
|
||||
'scheduled_scan': response_serializer.data
|
||||
},
|
||||
status=status.HTTP_201_CREATED
|
||||
return success_response(
|
||||
data=response_serializer.data,
|
||||
status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
except ValidationError as e:
|
||||
return Response({'error': str(e)}, status=status.HTTP_400_BAD_REQUEST)
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message=str(e),
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
def update(self, request, *args, **kwargs):
|
||||
"""更新定时扫描任务"""
|
||||
@@ -105,24 +108,27 @@ class ScheduledScanViewSet(viewsets.ModelViewSet):
|
||||
scheduled_scan = self.service.update(instance.id, dto)
|
||||
response_serializer = ScheduledScanSerializer(scheduled_scan)
|
||||
|
||||
return Response({
|
||||
'message': f'更新定时扫描任务成功: {scheduled_scan.name}',
|
||||
'scheduled_scan': response_serializer.data
|
||||
})
|
||||
return success_response(data=response_serializer.data)
|
||||
except ValidationError as e:
|
||||
return Response({'error': str(e)}, status=status.HTTP_400_BAD_REQUEST)
|
||||
return error_response(
|
||||
code=ErrorCodes.VALIDATION_ERROR,
|
||||
message=str(e),
|
||||
status_code=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
def destroy(self, request, *args, **kwargs):
|
||||
"""删除定时扫描任务"""
|
||||
instance = self.get_object()
|
||||
scan_id = instance.id
|
||||
name = instance.name
|
||||
|
||||
if self.service.delete(instance.id):
|
||||
return Response({
|
||||
'message': f'删除定时扫描任务成功: {name}',
|
||||
'id': instance.id
|
||||
})
|
||||
return Response({'error': '删除失败'}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
if self.service.delete(scan_id):
|
||||
return success_response(data={'id': scan_id, 'name': name})
|
||||
return error_response(
|
||||
code=ErrorCodes.SERVER_ERROR,
|
||||
message='Failed to delete scheduled scan',
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
@action(detail=True, methods=['post'])
|
||||
def toggle(self, request, pk=None):
|
||||
@@ -136,14 +142,11 @@ class ScheduledScanViewSet(viewsets.ModelViewSet):
|
||||
scheduled_scan = self.get_object()
|
||||
response_serializer = ScheduledScanSerializer(scheduled_scan)
|
||||
|
||||
status_text = '启用' if is_enabled else '禁用'
|
||||
return Response({
|
||||
'message': f'已{status_text}定时扫描任务',
|
||||
'scheduled_scan': response_serializer.data
|
||||
})
|
||||
return success_response(data=response_serializer.data)
|
||||
|
||||
return Response(
|
||||
{'error': f'定时扫描任务 ID {pk} 不存在或操作失败'},
|
||||
status=status.HTTP_404_NOT_FOUND
|
||||
return error_response(
|
||||
code=ErrorCodes.NOT_FOUND,
|
||||
message=f'Scheduled scan with ID {pk} not found or operation failed',
|
||||
status_code=status.HTTP_404_NOT_FOUND
|
||||
)
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from .serializers import OrganizationSerializer, TargetSerializer, TargetDetailS
|
||||
from .services.target_service import TargetService
|
||||
from .services.organization_service import OrganizationService
|
||||
from apps.common.pagination import BasePagination
|
||||
from apps.common.response_helpers import success_response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -94,9 +95,8 @@ class OrganizationViewSet(viewsets.ModelViewSet):
|
||||
# 批量解除关联(直接使用 ID,避免查询对象)
|
||||
organization.targets.remove(*existing_target_ids)
|
||||
|
||||
return Response({
|
||||
'unlinked_count': existing_count,
|
||||
'message': f'成功解除 {existing_count} 个目标的关联'
|
||||
return success_response(data={
|
||||
'unlinkedCount': existing_count
|
||||
})
|
||||
|
||||
def destroy(self, request, *args, **kwargs):
|
||||
@@ -124,13 +124,12 @@ class OrganizationViewSet(viewsets.ModelViewSet):
|
||||
# 直接调用 Service 层的业务方法(软删除 + 分发硬删除任务)
|
||||
result = self.org_service.delete_organizations_two_phase([organization.id])
|
||||
|
||||
return Response({
|
||||
'message': f'已删除组织: {organization.name}',
|
||||
return success_response(data={
|
||||
'organizationId': organization.id,
|
||||
'organizationName': organization.name,
|
||||
'deletedCount': result['soft_deleted_count'],
|
||||
'deletedOrganizations': result['organization_names']
|
||||
}, status=200)
|
||||
})
|
||||
|
||||
except Organization.DoesNotExist:
|
||||
raise NotFound('组织不存在')
|
||||
@@ -181,11 +180,10 @@ class OrganizationViewSet(viewsets.ModelViewSet):
|
||||
# 调用 Service 层的业务方法(软删除 + 分发硬删除任务)
|
||||
result = self.org_service.delete_organizations_two_phase(ids)
|
||||
|
||||
return Response({
|
||||
'message': f"已删除 {result['soft_deleted_count']} 个组织",
|
||||
return success_response(data={
|
||||
'deletedCount': result['soft_deleted_count'],
|
||||
'deletedOrganizations': result['organization_names']
|
||||
}, status=200)
|
||||
})
|
||||
|
||||
except ValueError as e:
|
||||
raise NotFound(str(e))
|
||||
@@ -271,12 +269,11 @@ class TargetViewSet(viewsets.ModelViewSet):
|
||||
# 直接调用 Service 层的业务方法(软删除 + 分发硬删除任务)
|
||||
result = self.target_service.delete_targets_two_phase([target.id])
|
||||
|
||||
return Response({
|
||||
'message': f'已删除目标: {target.name}',
|
||||
return success_response(data={
|
||||
'targetId': target.id,
|
||||
'targetName': target.name,
|
||||
'deletedCount': result['soft_deleted_count']
|
||||
}, status=200)
|
||||
})
|
||||
|
||||
except Target.DoesNotExist:
|
||||
raise NotFound('目标不存在')
|
||||
@@ -330,11 +327,10 @@ class TargetViewSet(viewsets.ModelViewSet):
|
||||
# 调用 Service 层的业务方法(软删除 + 分发硬删除任务)
|
||||
result = self.target_service.delete_targets_two_phase(ids)
|
||||
|
||||
return Response({
|
||||
'message': f"已删除 {result['soft_deleted_count']} 个目标",
|
||||
return success_response(data={
|
||||
'deletedCount': result['soft_deleted_count'],
|
||||
'deletedTargets': result['target_names']
|
||||
}, status=200)
|
||||
})
|
||||
|
||||
except ValueError as e:
|
||||
raise NotFound(str(e))
|
||||
@@ -389,7 +385,7 @@ class TargetViewSet(viewsets.ModelViewSet):
|
||||
raise ValidationError(str(e))
|
||||
|
||||
# 3. 返回响应
|
||||
return Response(result, status=status.HTTP_201_CREATED)
|
||||
return success_response(data=result, status_code=status.HTTP_201_CREATED)
|
||||
|
||||
# subdomains action 已迁移到 SubdomainViewSet 嵌套路由
|
||||
# GET /api/targets/{id}/subdomains/ -> SubdomainViewSet
|
||||
|
||||
@@ -177,6 +177,10 @@ STATIC_URL = 'static/'
|
||||
|
||||
DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
|
||||
|
||||
# ==================== Worker API Key 配置 ====================
|
||||
# Worker 节点认证密钥(从环境变量读取)
|
||||
WORKER_API_KEY = os.environ.get('WORKER_API_KEY', '')
|
||||
|
||||
# ==================== REST Framework 配置 ====================
|
||||
REST_FRAMEWORK = {
|
||||
'DEFAULT_PAGINATION_CLASS': 'apps.common.pagination.BasePagination', # 使用基础分页器
|
||||
@@ -186,6 +190,14 @@ REST_FRAMEWORK = {
|
||||
'apps.common.authentication.CsrfExemptSessionAuthentication',
|
||||
],
|
||||
|
||||
# 全局权限配置:默认需要认证,公开端点和 Worker 端点在权限类中单独处理
|
||||
'DEFAULT_PERMISSION_CLASSES': [
|
||||
'apps.common.permissions.IsAuthenticatedOrPublic',
|
||||
],
|
||||
|
||||
# 自定义异常处理器:统一 401/403 错误响应格式
|
||||
'EXCEPTION_HANDLER': 'apps.common.exception_handlers.custom_exception_handler',
|
||||
|
||||
# JSON 命名格式转换:后端 snake_case ↔ 前端 camelCase
|
||||
'DEFAULT_RENDERER_CLASSES': (
|
||||
'djangorestframework_camel_case.render.CamelCaseJSONRenderer', # 响应数据转换为 camelCase
|
||||
@@ -275,13 +287,23 @@ LOGGING = get_logging_config(debug=DEBUG)
|
||||
# 命令执行日志开关(供 apps.scan.utils.command_executor 使用)
|
||||
ENABLE_COMMAND_LOGGING = get_bool_env('ENABLE_COMMAND_LOGGING', True)
|
||||
|
||||
# 扫描工具基础路径(后端和 Worker 统一使用该路径前缀存放三方工具等文件)
|
||||
SCAN_TOOLS_BASE_PATH = os.getenv('SCAN_TOOLS_PATH', '/opt/xingrin/tools')
|
||||
# ==================== 数据目录配置(统一使用 /opt/xingrin) ====================
|
||||
# 所有数据目录统一挂载到 /opt/xingrin,便于管理和备份
|
||||
|
||||
# 字典文件基础路径(后端和 Worker 统一使用该路径前缀存放字典文件)
|
||||
# 扫描工具基础路径(worker 容器内,符合 FHS 标准)
|
||||
# 使用 /opt/xingrin-tools/bin 隔离项目专用扫描工具,避免与系统工具或 Python 包冲突
|
||||
SCAN_TOOLS_BASE_PATH = os.getenv('SCAN_TOOLS_PATH', '/opt/xingrin-tools/bin')
|
||||
|
||||
# 字典文件基础路径
|
||||
WORDLISTS_BASE_PATH = os.getenv('WORDLISTS_PATH', '/opt/xingrin/wordlists')
|
||||
|
||||
# Nuclei 模板基础路径(custom / public 两类模板目录)
|
||||
# 指纹库基础路径
|
||||
FINGERPRINTS_BASE_PATH = os.getenv('FINGERPRINTS_PATH', '/opt/xingrin/fingerprints')
|
||||
|
||||
# Nuclei 模板仓库根目录(存放 git clone 的仓库)
|
||||
NUCLEI_TEMPLATES_REPOS_BASE_DIR = os.getenv('NUCLEI_TEMPLATES_REPOS_DIR', '/opt/xingrin/nuclei-repos')
|
||||
|
||||
# Nuclei 模板基础路径(custom / public 两类模板目录,已废弃,保留兼容)
|
||||
NUCLEI_CUSTOM_TEMPLATES_DIR = os.getenv('NUCLEI_CUSTOM_TEMPLATES_DIR', '/opt/xingrin/nuclei-templates/custom')
|
||||
NUCLEI_PUBLIC_TEMPLATES_DIR = os.getenv('NUCLEI_PUBLIC_TEMPLATES_DIR', '/opt/xingrin/nuclei-templates/public')
|
||||
|
||||
@@ -290,6 +312,7 @@ NUCLEI_TEMPLATES_REPO_URL = os.getenv('NUCLEI_TEMPLATES_REPO_URL', 'https://gith
|
||||
|
||||
# 对外访问主机与端口(供 Worker 访问 Django 使用)
|
||||
PUBLIC_HOST = os.getenv('PUBLIC_HOST', 'localhost').strip()
|
||||
PUBLIC_PORT = os.getenv('PUBLIC_PORT', '8083').strip() # 对外 HTTPS 端口
|
||||
SERVER_PORT = os.getenv('SERVER_PORT', '8888')
|
||||
|
||||
# ============================================
|
||||
@@ -335,32 +358,25 @@ TASK_SUBMIT_INTERVAL = int(os.getenv('TASK_SUBMIT_INTERVAL', '6'))
|
||||
DOCKER_NETWORK_NAME = os.getenv('DOCKER_NETWORK_NAME', 'xingrin_network')
|
||||
|
||||
# 宿主机挂载源路径(所有节点统一使用固定路径)
|
||||
# 部署前需创建:mkdir -p /opt/xingrin/{results,logs}
|
||||
# 部署前需创建:mkdir -p /opt/xingrin
|
||||
HOST_RESULTS_DIR = '/opt/xingrin/results'
|
||||
HOST_LOGS_DIR = '/opt/xingrin/logs'
|
||||
HOST_FINGERPRINTS_DIR = '/opt/xingrin/fingerprints'
|
||||
HOST_WORDLISTS_DIR = '/opt/xingrin/wordlists'
|
||||
|
||||
# ============================================
|
||||
# Worker 配置中心(任务容器从 /api/workers/config/ 获取)
|
||||
# ============================================
|
||||
# Worker 数据库/Redis 地址由 worker_views.py 的 config API 动态返回
|
||||
# Worker 数据库地址由 worker_views.py 的 config API 动态返回
|
||||
# 根据请求来源(本地/远程)返回不同的配置:
|
||||
# - 本地 Worker(Docker 网络内):使用内部服务名(postgres, redis)
|
||||
# - 本地 Worker(Docker 网络内):使用内部服务名 postgres
|
||||
# - 远程 Worker(公网访问):使用 PUBLIC_HOST
|
||||
#
|
||||
# 以下变量仅作为备用/兼容配置,实际配置由 API 动态生成
|
||||
# 注意:Redis 仅在 Server 容器内使用,Worker 不需要直接连接 Redis
|
||||
_db_host = DATABASES['default']['HOST']
|
||||
_is_internal_db = _db_host in ('postgres', 'localhost', '127.0.0.1')
|
||||
WORKER_DB_HOST = os.getenv('WORKER_DB_HOST', _db_host)
|
||||
|
||||
# 远程 Worker 访问 Redis 的地址(自动推导)
|
||||
# - 如果 PUBLIC_HOST 是外部 IP → 使用 PUBLIC_HOST
|
||||
# - 如果 PUBLIC_HOST 是 Docker 内部名 → 使用 redis(本地部署)
|
||||
_is_internal_public = PUBLIC_HOST in ('server', 'localhost', '127.0.0.1')
|
||||
WORKER_REDIS_URL = os.getenv(
|
||||
'WORKER_REDIS_URL',
|
||||
'redis://redis:6379/0' if _is_internal_public else f'redis://{PUBLIC_HOST}:6379/0'
|
||||
)
|
||||
|
||||
# 容器内挂载目标路径(固定值,不需要修改)
|
||||
CONTAINER_RESULTS_MOUNT = '/app/backend/results'
|
||||
CONTAINER_LOGS_MOUNT = '/app/backend/logs'
|
||||
# 容器内挂载目标路径(统一使用 /opt/xingrin)
|
||||
CONTAINER_RESULTS_MOUNT = '/opt/xingrin/results'
|
||||
CONTAINER_LOGS_MOUNT = '/opt/xingrin/logs'
|
||||
|
||||
@@ -16,7 +16,6 @@ Including another URLconf
|
||||
"""
|
||||
from django.contrib import admin
|
||||
from django.urls import path, include
|
||||
from rest_framework import permissions
|
||||
from drf_yasg.views import get_schema_view
|
||||
from drf_yasg import openapi
|
||||
|
||||
@@ -30,7 +29,6 @@ schema_view = get_schema_view(
|
||||
description="Web 应用侦察工具 API 文档",
|
||||
),
|
||||
public=True,
|
||||
permission_classes=(permissions.AllowAny,),
|
||||
)
|
||||
|
||||
urlpatterns = [
|
||||
|
||||
19519
backend/fingerprints/ARL.yaml
Normal file
19519
backend/fingerprints/ARL.yaml
Normal file
File diff suppressed because it is too large
Load Diff
4793
backend/fingerprints/ehole.json
Normal file
4793
backend/fingerprints/ehole.json
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user