mirror of
https://github.com/yyhuni/xingrin.git
synced 2026-01-31 11:46:16 +08:00
refactor(asset): standardize snapshot and asset model field naming and types
- Rename `status` to `status_code` in WebsiteSnapshotDTO for consistency - Rename `web_server` to `webserver` in WebsiteSnapshotDTO for consistency - Make `target_id` required field in EndpointSnapshotDTO and WebsiteSnapshotDTO - Remove optional validation check for `target_id` in EndpointSnapshotDTO - Convert CharField to TextField for url, location, title, webserver, and content_type fields in Endpoint and EndpointSnapshot models to support longer values - Update migration 0001_initial.py to reflect field type changes from CharField to TextField - Update all related services and repositories to use standardized field names - Update serializers to map renamed fields correctly - Ensure consistent field naming across DTOs, models, and database schema
This commit is contained in:
@@ -13,6 +13,7 @@ class EndpointSnapshotDTO:
|
||||
快照只属于 scan。
|
||||
"""
|
||||
scan_id: int
|
||||
target_id: int # 必填,用于同步到资产表
|
||||
url: str
|
||||
host: str = '' # 主机名(域名或IP地址)
|
||||
title: str = ''
|
||||
@@ -25,7 +26,6 @@ class EndpointSnapshotDTO:
|
||||
response_body: str = ''
|
||||
vhost: Optional[bool] = None
|
||||
matched_gf_patterns: List[str] = None
|
||||
target_id: Optional[int] = None # 冗余字段,用于同步到资产表
|
||||
response_headers: str = ''
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -43,9 +43,6 @@ class EndpointSnapshotDTO:
|
||||
"""
|
||||
from apps.asset.dtos.asset import EndpointDTO
|
||||
|
||||
if self.target_id is None:
|
||||
raise ValueError("target_id 不能为 None,无法同步到资产表")
|
||||
|
||||
return EndpointDTO(
|
||||
target_id=self.target_id,
|
||||
url=self.url,
|
||||
|
||||
@@ -13,14 +13,14 @@ class WebsiteSnapshotDTO:
|
||||
快照只属于 scan,target 信息通过 scan.target 获取。
|
||||
"""
|
||||
scan_id: int
|
||||
target_id: int # 仅用于传递数据,不保存到数据库
|
||||
target_id: int # 必填,用于同步到资产表
|
||||
url: str
|
||||
host: str
|
||||
title: str = ''
|
||||
status: Optional[int] = None
|
||||
status_code: Optional[int] = None # 统一命名:status -> status_code
|
||||
content_length: Optional[int] = None
|
||||
location: str = ''
|
||||
web_server: str = ''
|
||||
webserver: str = '' # 统一命名:web_server -> webserver
|
||||
content_type: str = ''
|
||||
tech: List[str] = None
|
||||
response_body: str = ''
|
||||
@@ -45,10 +45,10 @@ class WebsiteSnapshotDTO:
|
||||
url=self.url,
|
||||
host=self.host,
|
||||
title=self.title,
|
||||
status_code=self.status,
|
||||
status_code=self.status_code,
|
||||
content_length=self.content_length,
|
||||
location=self.location,
|
||||
webserver=self.web_server,
|
||||
webserver=self.webserver,
|
||||
content_type=self.content_type,
|
||||
tech=self.tech if self.tech else [],
|
||||
response_body=self.response_body,
|
||||
|
||||
@@ -116,14 +116,14 @@ class Migration(migrations.Migration):
|
||||
name='Endpoint',
|
||||
fields=[
|
||||
('id', models.AutoField(primary_key=True, serialize=False)),
|
||||
('url', models.CharField(help_text='最终访问的完整URL', max_length=2000)),
|
||||
('url', models.TextField(help_text='最终访问的完整URL')),
|
||||
('host', models.CharField(blank=True, default='', help_text='主机名(域名或IP地址)', max_length=253)),
|
||||
('location', models.CharField(blank=True, default='', help_text='重定向地址(HTTP 3xx 响应头 Location)', max_length=1000)),
|
||||
('location', models.TextField(blank=True, default='', help_text='重定向地址(HTTP 3xx 响应头 Location)')),
|
||||
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
|
||||
('title', models.CharField(blank=True, default='', help_text='网页标题(HTML <title> 标签内容)', max_length=1000)),
|
||||
('webserver', models.CharField(blank=True, default='', help_text='服务器类型(HTTP 响应头 Server 值)', max_length=200)),
|
||||
('title', models.TextField(blank=True, default='', help_text='网页标题(HTML <title> 标签内容)')),
|
||||
('webserver', models.TextField(blank=True, default='', help_text='服务器类型(HTTP 响应头 Server 值)')),
|
||||
('response_body', models.TextField(blank=True, default='', help_text='HTTP响应体')),
|
||||
('content_type', models.CharField(blank=True, default='', help_text='响应类型(HTTP Content-Type 响应头)', max_length=200)),
|
||||
('content_type', models.TextField(blank=True, default='', help_text='响应类型(HTTP Content-Type 响应头)')),
|
||||
('tech', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='技术栈(服务器/框架/语言等)', size=None)),
|
||||
('status_code', models.IntegerField(blank=True, help_text='HTTP状态码', null=True)),
|
||||
('content_length', models.IntegerField(blank=True, help_text='响应体大小(单位字节)', null=True)),
|
||||
@@ -145,14 +145,14 @@ class Migration(migrations.Migration):
|
||||
name='EndpointSnapshot',
|
||||
fields=[
|
||||
('id', models.AutoField(primary_key=True, serialize=False)),
|
||||
('url', models.CharField(help_text='端点URL', max_length=2000)),
|
||||
('url', models.TextField(help_text='端点URL')),
|
||||
('host', models.CharField(blank=True, default='', help_text='主机名(域名或IP地址)', max_length=253)),
|
||||
('title', models.CharField(blank=True, default='', help_text='页面标题', max_length=1000)),
|
||||
('title', models.TextField(blank=True, default='', help_text='页面标题')),
|
||||
('status_code', models.IntegerField(blank=True, help_text='HTTP状态码', null=True)),
|
||||
('content_length', models.IntegerField(blank=True, help_text='内容长度', null=True)),
|
||||
('location', models.CharField(blank=True, default='', help_text='重定向位置', max_length=1000)),
|
||||
('webserver', models.CharField(blank=True, default='', help_text='Web服务器', max_length=200)),
|
||||
('content_type', models.CharField(blank=True, default='', help_text='内容类型', max_length=200)),
|
||||
('location', models.TextField(blank=True, default='', help_text='重定向位置')),
|
||||
('webserver', models.TextField(blank=True, default='', help_text='Web服务器')),
|
||||
('content_type', models.TextField(blank=True, default='', help_text='内容类型')),
|
||||
('tech', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='技术栈', size=None)),
|
||||
('response_body', models.TextField(blank=True, default='', help_text='HTTP响应体')),
|
||||
('vhost', models.BooleanField(blank=True, help_text='虚拟主机标志', null=True)),
|
||||
@@ -290,14 +290,14 @@ class Migration(migrations.Migration):
|
||||
name='WebSite',
|
||||
fields=[
|
||||
('id', models.AutoField(primary_key=True, serialize=False)),
|
||||
('url', models.CharField(help_text='最终访问的完整URL', max_length=2000)),
|
||||
('url', models.TextField(help_text='最终访问的完整URL')),
|
||||
('host', models.CharField(blank=True, default='', help_text='主机名(域名或IP地址)', max_length=253)),
|
||||
('location', models.CharField(blank=True, default='', help_text='重定向地址(HTTP 3xx 响应头 Location)', max_length=1000)),
|
||||
('location', models.TextField(blank=True, default='', help_text='重定向地址(HTTP 3xx 响应头 Location)')),
|
||||
('created_at', models.DateTimeField(auto_now_add=True, help_text='创建时间')),
|
||||
('title', models.CharField(blank=True, default='', help_text='网页标题(HTML <title> 标签内容)', max_length=1000)),
|
||||
('webserver', models.CharField(blank=True, default='', help_text='服务器类型(HTTP 响应头 Server 值)', max_length=200)),
|
||||
('title', models.TextField(blank=True, default='', help_text='网页标题(HTML <title> 标签内容)')),
|
||||
('webserver', models.TextField(blank=True, default='', help_text='服务器类型(HTTP 响应头 Server 值)')),
|
||||
('response_body', models.TextField(blank=True, default='', help_text='HTTP响应体')),
|
||||
('content_type', models.CharField(blank=True, default='', help_text='响应类型(HTTP Content-Type 响应头)', max_length=200)),
|
||||
('content_type', models.TextField(blank=True, default='', help_text='响应类型(HTTP Content-Type 响应头)')),
|
||||
('tech', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='技术栈(服务器/框架/语言等)', size=None)),
|
||||
('status_code', models.IntegerField(blank=True, help_text='HTTP状态码', null=True)),
|
||||
('content_length', models.IntegerField(blank=True, help_text='响应体大小(单位字节)', null=True)),
|
||||
@@ -318,14 +318,14 @@ class Migration(migrations.Migration):
|
||||
name='WebsiteSnapshot',
|
||||
fields=[
|
||||
('id', models.AutoField(primary_key=True, serialize=False)),
|
||||
('url', models.CharField(help_text='站点URL', max_length=2000)),
|
||||
('url', models.TextField(help_text='站点URL')),
|
||||
('host', models.CharField(blank=True, default='', help_text='主机名(域名或IP地址)', max_length=253)),
|
||||
('title', models.CharField(blank=True, default='', help_text='页面标题', max_length=500)),
|
||||
('status', models.IntegerField(blank=True, help_text='HTTP状态码', null=True)),
|
||||
('title', models.TextField(blank=True, default='', help_text='页面标题')),
|
||||
('status_code', models.IntegerField(blank=True, help_text='HTTP状态码', null=True)),
|
||||
('content_length', models.BigIntegerField(blank=True, help_text='内容长度', null=True)),
|
||||
('location', models.CharField(blank=True, default='', help_text='重定向位置', max_length=1000)),
|
||||
('web_server', models.CharField(blank=True, default='', help_text='Web服务器', max_length=200)),
|
||||
('content_type', models.CharField(blank=True, default='', help_text='内容类型', max_length=200)),
|
||||
('location', models.TextField(blank=True, default='', help_text='重定向位置')),
|
||||
('webserver', models.TextField(blank=True, default='', help_text='Web服务器')),
|
||||
('content_type', models.TextField(blank=True, default='', help_text='内容类型')),
|
||||
('tech', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=100), blank=True, default=list, help_text='技术栈', size=None)),
|
||||
('response_body', models.TextField(blank=True, default='', help_text='HTTP响应体')),
|
||||
('vhost', models.BooleanField(blank=True, help_text='虚拟主机标志', null=True)),
|
||||
|
||||
@@ -65,28 +65,25 @@ class Endpoint(models.Model):
|
||||
help_text='所属的扫描目标(主关联字段,表示所属关系,不能为空)'
|
||||
)
|
||||
|
||||
url = models.CharField(max_length=2000, help_text='最终访问的完整URL')
|
||||
url = models.TextField(help_text='最终访问的完整URL')
|
||||
host = models.CharField(
|
||||
max_length=253,
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='主机名(域名或IP地址)'
|
||||
)
|
||||
location = models.CharField(
|
||||
max_length=1000,
|
||||
location = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='重定向地址(HTTP 3xx 响应头 Location)'
|
||||
)
|
||||
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
|
||||
title = models.CharField(
|
||||
max_length=1000,
|
||||
title = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='网页标题(HTML <title> 标签内容)'
|
||||
)
|
||||
webserver = models.CharField(
|
||||
max_length=200,
|
||||
webserver = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='服务器类型(HTTP 响应头 Server 值)'
|
||||
@@ -96,8 +93,7 @@ class Endpoint(models.Model):
|
||||
default='',
|
||||
help_text='HTTP响应体'
|
||||
)
|
||||
content_type = models.CharField(
|
||||
max_length=200,
|
||||
content_type = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='响应类型(HTTP Content-Type 响应头)'
|
||||
@@ -188,28 +184,25 @@ class WebSite(models.Model):
|
||||
help_text='所属的扫描目标(主关联字段,表示所属关系,不能为空)'
|
||||
)
|
||||
|
||||
url = models.CharField(max_length=2000, help_text='最终访问的完整URL')
|
||||
url = models.TextField(help_text='最终访问的完整URL')
|
||||
host = models.CharField(
|
||||
max_length=253,
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='主机名(域名或IP地址)'
|
||||
)
|
||||
location = models.CharField(
|
||||
max_length=1000,
|
||||
location = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='重定向地址(HTTP 3xx 响应头 Location)'
|
||||
)
|
||||
created_at = models.DateTimeField(auto_now_add=True, help_text='创建时间')
|
||||
title = models.CharField(
|
||||
max_length=1000,
|
||||
title = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='网页标题(HTML <title> 标签内容)'
|
||||
)
|
||||
webserver = models.CharField(
|
||||
max_length=200,
|
||||
webserver = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='服务器类型(HTTP 响应头 Server 值)'
|
||||
@@ -219,8 +212,7 @@ class WebSite(models.Model):
|
||||
default='',
|
||||
help_text='HTTP响应体'
|
||||
)
|
||||
content_type = models.CharField(
|
||||
max_length=200,
|
||||
content_type = models.TextField(
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='响应类型(HTTP Content-Type 响应头)'
|
||||
|
||||
@@ -61,14 +61,14 @@ class WebsiteSnapshot(models.Model):
|
||||
)
|
||||
|
||||
# 扫描结果数据
|
||||
url = models.CharField(max_length=2000, help_text='站点URL')
|
||||
url = models.TextField(help_text='站点URL')
|
||||
host = models.CharField(max_length=253, blank=True, default='', help_text='主机名(域名或IP地址)')
|
||||
title = models.CharField(max_length=500, blank=True, default='', help_text='页面标题')
|
||||
status = models.IntegerField(null=True, blank=True, help_text='HTTP状态码')
|
||||
title = models.TextField(blank=True, default='', help_text='页面标题')
|
||||
status_code = models.IntegerField(null=True, blank=True, help_text='HTTP状态码')
|
||||
content_length = models.BigIntegerField(null=True, blank=True, help_text='内容长度')
|
||||
location = models.CharField(max_length=1000, blank=True, default='', help_text='重定向位置')
|
||||
web_server = models.CharField(max_length=200, blank=True, default='', help_text='Web服务器')
|
||||
content_type = models.CharField(max_length=200, blank=True, default='', help_text='内容类型')
|
||||
location = models.TextField(blank=True, default='', help_text='重定向位置')
|
||||
webserver = models.TextField(blank=True, default='', help_text='Web服务器')
|
||||
content_type = models.TextField(blank=True, default='', help_text='内容类型')
|
||||
tech = ArrayField(
|
||||
models.CharField(max_length=100),
|
||||
blank=True,
|
||||
@@ -267,19 +267,19 @@ class EndpointSnapshot(models.Model):
|
||||
)
|
||||
|
||||
# 扫描结果数据
|
||||
url = models.CharField(max_length=2000, help_text='端点URL')
|
||||
url = models.TextField(help_text='端点URL')
|
||||
host = models.CharField(
|
||||
max_length=253,
|
||||
blank=True,
|
||||
default='',
|
||||
help_text='主机名(域名或IP地址)'
|
||||
)
|
||||
title = models.CharField(max_length=1000, blank=True, default='', help_text='页面标题')
|
||||
title = models.TextField(blank=True, default='', help_text='页面标题')
|
||||
status_code = models.IntegerField(null=True, blank=True, help_text='HTTP状态码')
|
||||
content_length = models.IntegerField(null=True, blank=True, help_text='内容长度')
|
||||
location = models.CharField(max_length=1000, blank=True, default='', help_text='重定向位置')
|
||||
webserver = models.CharField(max_length=200, blank=True, default='', help_text='Web服务器')
|
||||
content_type = models.CharField(max_length=200, blank=True, default='', help_text='内容类型')
|
||||
location = models.TextField(blank=True, default='', help_text='重定向位置')
|
||||
webserver = models.TextField(blank=True, default='', help_text='Web服务器')
|
||||
content_type = models.TextField(blank=True, default='', help_text='内容类型')
|
||||
tech = ArrayField(
|
||||
models.CharField(max_length=100),
|
||||
blank=True,
|
||||
|
||||
@@ -46,10 +46,10 @@ class DjangoWebsiteSnapshotRepository:
|
||||
url=item.url,
|
||||
host=item.host,
|
||||
title=item.title,
|
||||
status=item.status,
|
||||
status_code=item.status_code,
|
||||
content_length=item.content_length,
|
||||
location=item.location,
|
||||
web_server=item.web_server,
|
||||
webserver=item.webserver,
|
||||
content_type=item.content_type,
|
||||
tech=item.tech if item.tech else [],
|
||||
response_body=item.response_body,
|
||||
@@ -99,27 +99,12 @@ class DjangoWebsiteSnapshotRepository:
|
||||
WebsiteSnapshot.objects
|
||||
.filter(scan_id=scan_id)
|
||||
.values(
|
||||
'url', 'host', 'location', 'title', 'status',
|
||||
'content_length', 'content_type', 'web_server', 'tech',
|
||||
'url', 'host', 'location', 'title', 'status_code',
|
||||
'content_length', 'content_type', 'webserver', 'tech',
|
||||
'response_body', 'response_headers', 'vhost', 'created_at'
|
||||
)
|
||||
.order_by('url')
|
||||
)
|
||||
|
||||
for row in qs.iterator(chunk_size=batch_size):
|
||||
# 重命名字段以匹配 CSV 表头
|
||||
yield {
|
||||
'url': row['url'],
|
||||
'host': row['host'],
|
||||
'location': row['location'],
|
||||
'title': row['title'],
|
||||
'status_code': row['status'],
|
||||
'content_length': row['content_length'],
|
||||
'content_type': row['content_type'],
|
||||
'webserver': row['web_server'],
|
||||
'tech': row['tech'],
|
||||
'response_body': row['response_body'],
|
||||
'response_headers': row['response_headers'],
|
||||
'vhost': row['vhost'],
|
||||
'created_at': row['created_at'],
|
||||
}
|
||||
yield row
|
||||
|
||||
@@ -217,8 +217,6 @@ class WebsiteSnapshotSerializer(serializers.ModelSerializer):
|
||||
"""网站快照序列化器(用于扫描历史)"""
|
||||
|
||||
subdomain_name = serializers.CharField(source='subdomain.name', read_only=True)
|
||||
webserver = serializers.CharField(source='web_server', read_only=True) # 映射字段名
|
||||
status_code = serializers.IntegerField(source='status', read_only=True) # 映射字段名
|
||||
responseHeaders = serializers.CharField(source='response_headers', read_only=True) # 原始HTTP响应头
|
||||
|
||||
class Meta:
|
||||
@@ -228,9 +226,9 @@ class WebsiteSnapshotSerializer(serializers.ModelSerializer):
|
||||
'url',
|
||||
'location',
|
||||
'title',
|
||||
'webserver', # 使用映射后的字段名
|
||||
'webserver',
|
||||
'content_type',
|
||||
'status_code', # 使用映射后的字段名
|
||||
'status_code',
|
||||
'content_length',
|
||||
'response_body',
|
||||
'tech',
|
||||
|
||||
@@ -27,7 +27,7 @@ class EndpointService:
|
||||
'url': 'url',
|
||||
'host': 'host',
|
||||
'title': 'title',
|
||||
'status': 'status_code',
|
||||
'status_code': 'status_code',
|
||||
'tech': 'tech',
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ class WebSiteService:
|
||||
'url': 'url',
|
||||
'host': 'host',
|
||||
'title': 'title',
|
||||
'status': 'status_code',
|
||||
'status_code': 'status_code',
|
||||
'tech': 'tech',
|
||||
}
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ class EndpointSnapshotsService:
|
||||
'url': 'url',
|
||||
'host': 'host',
|
||||
'title': 'title',
|
||||
'status': 'status_code',
|
||||
'status_code': 'status_code',
|
||||
'webserver': 'webserver',
|
||||
'tech': 'tech',
|
||||
}
|
||||
|
||||
@@ -73,8 +73,8 @@ class WebsiteSnapshotsService:
|
||||
'url': 'url',
|
||||
'host': 'host',
|
||||
'title': 'title',
|
||||
'status': 'status',
|
||||
'webserver': 'web_server',
|
||||
'status_code': 'status_code',
|
||||
'webserver': 'webserver',
|
||||
'tech': 'tech',
|
||||
}
|
||||
|
||||
|
||||
@@ -204,14 +204,13 @@ def _run_scans_sequentially(
|
||||
# 流式执行扫描并实时保存结果
|
||||
result = run_and_stream_save_websites_task(
|
||||
cmd=command,
|
||||
tool_name=tool_name, # 新增:工具名称
|
||||
tool_name=tool_name,
|
||||
scan_id=scan_id,
|
||||
target_id=target_id,
|
||||
cwd=str(site_scan_dir),
|
||||
shell=True,
|
||||
batch_size=1000,
|
||||
timeout=timeout,
|
||||
log_file=str(log_file) # 新增:日志文件路径
|
||||
log_file=str(log_file)
|
||||
)
|
||||
|
||||
tool_stats[tool_name] = {
|
||||
|
||||
@@ -212,7 +212,6 @@ def _validate_and_stream_save_urls(
|
||||
target_id=target_id,
|
||||
cwd=str(url_fetch_dir),
|
||||
shell=True,
|
||||
batch_size=500,
|
||||
timeout=timeout,
|
||||
log_file=str(log_file)
|
||||
)
|
||||
|
||||
@@ -341,9 +341,9 @@ def _process_batch(
|
||||
url=record['url'],
|
||||
host=host,
|
||||
title=record.get('title', '') or '',
|
||||
status=record.get('status_code'),
|
||||
status_code=record.get('status_code'),
|
||||
content_length=record.get('content_length'),
|
||||
web_server=record.get('server', '') or '',
|
||||
webserver=record.get('server', '') or '',
|
||||
tech=record.get('techs', []),
|
||||
)
|
||||
snapshot_dtos.append(dto)
|
||||
|
||||
@@ -30,7 +30,6 @@ from typing import Generator, Optional, Dict, Any, TYPE_CHECKING
|
||||
from django.db import IntegrityError, OperationalError, DatabaseError
|
||||
from dataclasses import dataclass
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
from psycopg2 import InterfaceError
|
||||
|
||||
from apps.asset.dtos.snapshot import WebsiteSnapshotDTO
|
||||
@@ -62,6 +61,18 @@ class ServiceSet:
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_string(value: str) -> str:
|
||||
"""
|
||||
清理字符串中的 NUL 字符和其他不可打印字符
|
||||
|
||||
PostgreSQL 不允许字符串字段包含 NUL (0x00) 字符
|
||||
"""
|
||||
if not value:
|
||||
return value
|
||||
# 移除 NUL 字符
|
||||
return value.replace('\x00', '')
|
||||
|
||||
|
||||
def normalize_url(url: str) -> str:
|
||||
"""
|
||||
标准化 URL,移除默认端口号
|
||||
@@ -117,70 +128,50 @@ def normalize_url(url: str) -> str:
|
||||
return url
|
||||
|
||||
|
||||
def _extract_hostname(url: str) -> str:
|
||||
"""
|
||||
从 URL 提取主机名
|
||||
|
||||
Args:
|
||||
url: URL 字符串
|
||||
|
||||
Returns:
|
||||
str: 提取的主机名(小写)
|
||||
"""
|
||||
try:
|
||||
if url:
|
||||
parsed = urlparse(url)
|
||||
if parsed.hostname:
|
||||
return parsed.hostname
|
||||
# 降级方案:手动提取
|
||||
return url.replace('http://', '').replace('https://', '').split('/')[0].split(':')[0]
|
||||
return ''
|
||||
except Exception as e:
|
||||
logger.debug("提取主机名失败: %s", e)
|
||||
return ''
|
||||
|
||||
|
||||
class HttpxRecord:
|
||||
"""httpx 扫描记录数据类"""
|
||||
|
||||
def __init__(self, data: Dict[str, Any]):
|
||||
self.url = data.get('url', '')
|
||||
self.input = data.get('input', '')
|
||||
self.title = data.get('title', '')
|
||||
self.status_code = data.get('status_code')
|
||||
self.content_length = data.get('content_length')
|
||||
self.content_type = data.get('content_type', '')
|
||||
self.location = data.get('location', '')
|
||||
self.webserver = data.get('webserver', '')
|
||||
self.response_body = data.get('body', '') # 从 body 字段获取完整响应体
|
||||
self.tech = data.get('tech', [])
|
||||
self.vhost = data.get('vhost')
|
||||
self.failed = data.get('failed', False)
|
||||
self.timestamp = data.get('timestamp')
|
||||
self.response_headers = data.get('raw_header', '') # 从 raw_header 字段获取原始响应头字符串
|
||||
self.url = _sanitize_string(data.get('url', ''))
|
||||
self.input = _sanitize_string(data.get('input', ''))
|
||||
self.title = _sanitize_string(data.get('title', ''))
|
||||
self.status_code = data.get('status_code') # int,不需要清理
|
||||
self.content_length = data.get('content_length') # int,不需要清理
|
||||
self.content_type = _sanitize_string(data.get('content_type', ''))
|
||||
self.location = _sanitize_string(data.get('location', ''))
|
||||
self.webserver = _sanitize_string(data.get('webserver', ''))
|
||||
self.response_body = _sanitize_string(data.get('body', ''))
|
||||
self.tech = [_sanitize_string(t) for t in data.get('tech', []) if isinstance(t, str)] # 列表中的字符串也需要清理
|
||||
self.vhost = data.get('vhost') # bool,不需要清理
|
||||
self.failed = data.get('failed', False) # bool,不需要清理
|
||||
self.response_headers = _sanitize_string(data.get('raw_header', ''))
|
||||
|
||||
# 从 URL 中提取主机名
|
||||
self.host = self._extract_hostname()
|
||||
|
||||
def _extract_hostname(self) -> str:
|
||||
"""
|
||||
从 URL 或 input 字段提取主机名
|
||||
|
||||
优先级:
|
||||
1. 使用 urlparse 解析 URL 获取 hostname
|
||||
2. 从 input 字段提取(处理可能包含协议的情况)
|
||||
3. 从 URL 字段手动提取(降级方案)
|
||||
|
||||
Returns:
|
||||
str: 提取的主机名(小写)
|
||||
"""
|
||||
try:
|
||||
# 方法 1: 使用 urlparse 解析 URL
|
||||
if self.url:
|
||||
parsed = urlparse(self.url)
|
||||
if parsed.hostname:
|
||||
return parsed.hostname
|
||||
|
||||
# 方法 2: 从 input 字段提取
|
||||
if self.input:
|
||||
host = self.input.strip().lower()
|
||||
# 移除协议前缀
|
||||
if host.startswith(('http://', 'https://')):
|
||||
host = host.split('//', 1)[1].split('/')[0]
|
||||
return host
|
||||
|
||||
# 方法 3: 从 URL 手动提取(降级方案)
|
||||
if self.url:
|
||||
return self.url.replace('http://', '').replace('https://', '').split('/')[0]
|
||||
|
||||
# 兜底:返回空字符串
|
||||
return ''
|
||||
|
||||
except Exception as e:
|
||||
# 异常处理:尽力从 input 或 URL 提取
|
||||
logger.debug("提取主机名失败: %s,使用降级方案", e)
|
||||
if self.input:
|
||||
return self.input.strip().lower()
|
||||
if self.url:
|
||||
return self.url.replace('http://', '').replace('https://', '').split('/')[0]
|
||||
return ''
|
||||
# 从 URL 中提取主机名(优先使用 httpx 返回的 host,否则自动提取)
|
||||
httpx_host = _sanitize_string(data.get('host', ''))
|
||||
self.host = httpx_host if httpx_host else _extract_hostname(self.url)
|
||||
|
||||
|
||||
def _save_batch_with_retry(
|
||||
@@ -228,39 +219,31 @@ def _save_batch_with_retry(
|
||||
}
|
||||
|
||||
except (OperationalError, DatabaseError, InterfaceError) as e:
|
||||
# 数据库连接/操作错误,可重试
|
||||
# 数据库级错误(连接中断、表结构不匹配等):按指数退避重试,最终失败时抛出异常让 Flow 失败
|
||||
if attempt < max_retries - 1:
|
||||
wait_time = 2 ** attempt # 指数退避: 1s, 2s, 4s
|
||||
wait_time = 2 ** attempt
|
||||
logger.warning(
|
||||
"批次 %d 保存失败(第 %d 次尝试),%d秒后重试: %s",
|
||||
batch_num, attempt + 1, wait_time, str(e)[:100]
|
||||
)
|
||||
time.sleep(wait_time)
|
||||
else:
|
||||
logger.error("批次 %d 保存失败(已重试 %d 次): %s", batch_num, max_retries, e)
|
||||
return {
|
||||
'success': False,
|
||||
'created_websites': 0,
|
||||
'skipped_failed': 0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# 其他未知错误 - 检查是否为连接问题
|
||||
error_str = str(e).lower()
|
||||
if 'connection' in error_str and attempt < max_retries - 1:
|
||||
logger.warning(
|
||||
"批次 %d 连接相关错误(尝试 %d/%d): %s,Repository 装饰器会自动重连",
|
||||
batch_num, attempt + 1, max_retries, str(e)
|
||||
logger.error(
|
||||
"批次 %d 保存失败(已重试 %d 次),将终止任务: %s",
|
||||
batch_num,
|
||||
max_retries,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
time.sleep(2)
|
||||
else:
|
||||
logger.error("批次 %d 未知错误: %s", batch_num, e, exc_info=True)
|
||||
return {
|
||||
'success': False,
|
||||
'created_websites': 0,
|
||||
'skipped_failed': 0
|
||||
}
|
||||
|
||||
# 让上层 Task 感知失败,从而标记整个扫描为失败
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
# 其他未知异常也不再吞掉,直接抛出以便 Flow 标记为失败
|
||||
logger.error("批次 %d 未知错误: %s", batch_num, e, exc_info=True)
|
||||
raise
|
||||
|
||||
# 理论上不会走到这里,保留兜底返回值以满足类型约束
|
||||
return {
|
||||
'success': False,
|
||||
'created_websites': 0,
|
||||
@@ -328,43 +311,39 @@ def _save_batch(
|
||||
skipped_failed += 1
|
||||
continue
|
||||
|
||||
# 解析时间戳
|
||||
created_at = None
|
||||
if hasattr(record, 'timestamp') and record.timestamp:
|
||||
try:
|
||||
created_at = parse_datetime(record.timestamp)
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning(f"无法解析时间戳 {record.timestamp}: {e}")
|
||||
|
||||
# 使用 input 字段(原始扫描的 URL)而不是 url 字段(重定向后的 URL)
|
||||
# 原因:避免多个不同的输入 URL 重定向到同一个 URL 时产生唯一约束冲突
|
||||
# 例如:http://example.com 和 https://example.com 都重定向到 https://example.com
|
||||
# 如果使用 record.url,两条记录会有相同的 url,导致数据库冲突
|
||||
# 如果使用 record.input,两条记录保留原始输入,不会冲突
|
||||
normalized_url = normalize_url(record.input)
|
||||
|
||||
# 提取 host 字段(域名或IP地址)
|
||||
host = record.host if record.host else ''
|
||||
|
||||
# 创建 WebsiteSnapshot DTO
|
||||
snapshot_dto = WebsiteSnapshotDTO(
|
||||
scan_id=scan_id,
|
||||
target_id=target_id, # 主关联字段
|
||||
url=normalized_url, # 保存原始输入 URL(归一化后)
|
||||
host=host, # 主机名(域名或IP地址)
|
||||
location=record.location, # location 字段保存重定向信息
|
||||
title=record.title[:1000] if record.title else '',
|
||||
web_server=record.webserver[:200] if record.webserver else '',
|
||||
response_body=record.response_body if record.response_body else '',
|
||||
content_type=record.content_type[:200] if record.content_type else '',
|
||||
tech=record.tech if isinstance(record.tech, list) else [],
|
||||
status=record.status_code,
|
||||
content_length=record.content_length,
|
||||
vhost=record.vhost,
|
||||
response_headers=record.response_headers if record.response_headers else '',
|
||||
)
|
||||
|
||||
snapshot_items.append(snapshot_dto)
|
||||
try:
|
||||
# 使用 input 字段(原始扫描的 URL)而不是 url 字段(重定向后的 URL)
|
||||
# 原因:避免多个不同的输入 URL 重定向到同一个 URL 时产生唯一约束冲突
|
||||
# 例如:http://example.com 和 https://example.com 都重定向到 https://example.com
|
||||
# 如果使用 record.url,两条记录会有相同的 url,导致数据库冲突
|
||||
# 如果使用 record.input,两条记录保留原始输入,不会冲突
|
||||
normalized_url = normalize_url(record.input) if record.input else normalize_url(record.url)
|
||||
|
||||
# 提取 host 字段(域名或IP地址)
|
||||
host = record.host if record.host else ''
|
||||
|
||||
# 创建 WebsiteSnapshot DTO
|
||||
snapshot_dto = WebsiteSnapshotDTO(
|
||||
scan_id=scan_id,
|
||||
target_id=target_id, # 主关联字段
|
||||
url=normalized_url, # 保存原始输入 URL(归一化后)
|
||||
host=host, # 主机名(域名或IP地址)
|
||||
location=record.location if record.location else '',
|
||||
title=record.title if record.title else '',
|
||||
webserver=record.webserver if record.webserver else '',
|
||||
response_body=record.response_body if record.response_body else '',
|
||||
content_type=record.content_type if record.content_type else '',
|
||||
tech=record.tech if isinstance(record.tech, list) else [],
|
||||
status_code=record.status_code,
|
||||
content_length=record.content_length,
|
||||
vhost=record.vhost,
|
||||
response_headers=record.response_headers if record.response_headers else '',
|
||||
)
|
||||
|
||||
snapshot_items.append(snapshot_dto)
|
||||
except Exception as e:
|
||||
logger.error("处理记录失败: %s,错误: %s", record.url, e)
|
||||
continue
|
||||
|
||||
# ========== Step 3: 保存快照并同步到资产表(通过快照 Service)==========
|
||||
if snapshot_items:
|
||||
@@ -386,28 +365,31 @@ def _parse_and_validate_line(line: str) -> Optional[HttpxRecord]:
|
||||
Optional[HttpxRecord]: 有效的 httpx 扫描记录,或 None 如果验证失败
|
||||
|
||||
验证步骤:
|
||||
1. 解析 JSON 格式
|
||||
2. 验证数据类型为字典
|
||||
3. 创建 HttpxRecord 对象
|
||||
4. 验证必要字段(url)
|
||||
1. 清理 NUL 字符
|
||||
2. 解析 JSON 格式
|
||||
3. 验证数据类型为字典
|
||||
4. 创建 HttpxRecord 对象
|
||||
5. 验证必要字段(url)
|
||||
"""
|
||||
try:
|
||||
# 步骤 1: 解析 JSON
|
||||
# 步骤 1: 清理 NUL 字符后再解析 JSON
|
||||
line = _sanitize_string(line)
|
||||
|
||||
# 步骤 2: 解析 JSON
|
||||
try:
|
||||
line_data = json.loads(line, strict=False)
|
||||
except json.JSONDecodeError:
|
||||
# logger.info("跳过非 JSON 行: %s", line)
|
||||
return None
|
||||
|
||||
# 步骤 2: 验证数据类型
|
||||
# 步骤 3: 验证数据类型
|
||||
if not isinstance(line_data, dict):
|
||||
logger.info("跳过非字典数据")
|
||||
return None
|
||||
|
||||
# 步骤 3: 创建记录
|
||||
# 步骤 4: 创建记录
|
||||
record = HttpxRecord(line_data)
|
||||
|
||||
# 步骤 4: 验证必要字段
|
||||
# 步骤 5: 验证必要字段
|
||||
if not record.url:
|
||||
logger.info("URL 为空,跳过 - 数据: %s", str(line_data)[:200])
|
||||
return None
|
||||
@@ -416,7 +398,7 @@ def _parse_and_validate_line(line: str) -> Optional[HttpxRecord]:
|
||||
return record
|
||||
|
||||
except Exception:
|
||||
logger.info("跳过无法解析的行: %s", line[:100])
|
||||
logger.info("跳过无法解析的行: %s", line[:100] if line else 'empty')
|
||||
return None
|
||||
|
||||
|
||||
@@ -464,8 +446,8 @@ def _parse_httpx_stream_output(
|
||||
# yield 一条有效记录
|
||||
yield record
|
||||
|
||||
# 每处理 1000 条记录输出一次进度
|
||||
if valid_records % 1000 == 0:
|
||||
# 每处理 5 条记录输出一次进度
|
||||
if valid_records % 5 == 0:
|
||||
logger.info("已解析 %d 条有效记录...", valid_records)
|
||||
|
||||
except subprocess.TimeoutExpired as e:
|
||||
@@ -604,8 +586,8 @@ def _process_records_in_batches(
|
||||
_process_batch(batch, scan_id, target_id, batch_num, total_stats, failed_batches, services)
|
||||
batch = [] # 清空批次
|
||||
|
||||
# 每20个批次输出进度
|
||||
if batch_num % 20 == 0:
|
||||
# 每 2 个批次输出进度
|
||||
if batch_num % 2 == 0:
|
||||
logger.info("进度: 已处理 %d 批次,%d 条记录", batch_num, total_records)
|
||||
|
||||
# 保存最后一批
|
||||
@@ -676,11 +658,7 @@ def _cleanup_resources(data_generator) -> None:
|
||||
logger.error("关闭生成器时出错: %s", gen_close_error)
|
||||
|
||||
|
||||
@task(
|
||||
name='run_and_stream_save_websites',
|
||||
retries=0,
|
||||
log_prints=True
|
||||
)
|
||||
@task(name='run_and_stream_save_websites', retries=0)
|
||||
def run_and_stream_save_websites_task(
|
||||
cmd: str,
|
||||
tool_name: str,
|
||||
@@ -688,7 +666,7 @@ def run_and_stream_save_websites_task(
|
||||
target_id: int,
|
||||
cwd: Optional[str] = None,
|
||||
shell: bool = False,
|
||||
batch_size: int = 1000,
|
||||
batch_size: int = 10,
|
||||
timeout: Optional[int] = None,
|
||||
log_file: Optional[str] = None
|
||||
) -> dict:
|
||||
|
||||
@@ -23,10 +23,11 @@ import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
from prefect import task
|
||||
from typing import Generator, Optional
|
||||
from typing import Generator, Optional, Dict, Any
|
||||
from django.db import IntegrityError, OperationalError, DatabaseError
|
||||
from psycopg2 import InterfaceError
|
||||
from dataclasses import dataclass
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from apps.asset.services.snapshot import EndpointSnapshotsService
|
||||
from apps.scan.utils import execute_stream
|
||||
@@ -63,7 +64,53 @@ def _sanitize_string(value: str) -> str:
|
||||
return value.replace('\x00', '')
|
||||
|
||||
|
||||
def _parse_and_validate_line(line: str) -> Optional[dict]:
|
||||
def _extract_hostname(url: str) -> str:
|
||||
"""
|
||||
从 URL 提取主机名
|
||||
|
||||
Args:
|
||||
url: URL 字符串
|
||||
|
||||
Returns:
|
||||
str: 提取的主机名(小写)
|
||||
"""
|
||||
try:
|
||||
if url:
|
||||
parsed = urlparse(url)
|
||||
if parsed.hostname:
|
||||
return parsed.hostname
|
||||
# 降级方案:手动提取
|
||||
return url.replace('http://', '').replace('https://', '').split('/')[0].split(':')[0]
|
||||
return ''
|
||||
except Exception as e:
|
||||
logger.debug("提取主机名失败: %s", e)
|
||||
return ''
|
||||
|
||||
|
||||
class HttpxRecord:
|
||||
"""httpx 扫描记录数据类"""
|
||||
|
||||
def __init__(self, data: Dict[str, Any]):
|
||||
self.url = _sanitize_string(data.get('url', ''))
|
||||
self.input = _sanitize_string(data.get('input', ''))
|
||||
self.title = _sanitize_string(data.get('title', ''))
|
||||
self.status_code = data.get('status_code') # int,不需要清理
|
||||
self.content_length = data.get('content_length') # int,不需要清理
|
||||
self.content_type = _sanitize_string(data.get('content_type', ''))
|
||||
self.location = _sanitize_string(data.get('location', ''))
|
||||
self.webserver = _sanitize_string(data.get('webserver', ''))
|
||||
self.response_body = _sanitize_string(data.get('body', ''))
|
||||
self.tech = [_sanitize_string(t) for t in data.get('tech', []) if isinstance(t, str)] # 列表中的字符串也需要清理
|
||||
self.vhost = data.get('vhost') # bool,不需要清理
|
||||
self.failed = data.get('failed', False) # bool,不需要清理
|
||||
self.response_headers = _sanitize_string(data.get('raw_header', ''))
|
||||
|
||||
# 从 URL 中提取主机名(优先使用 httpx 返回的 host,否则自动提取)
|
||||
httpx_host = _sanitize_string(data.get('host', ''))
|
||||
self.host = httpx_host if httpx_host else _extract_hostname(self.url)
|
||||
|
||||
|
||||
def _parse_and_validate_line(line: str) -> Optional[HttpxRecord]:
|
||||
"""
|
||||
解析并验证单行 httpx JSON 输出
|
||||
|
||||
@@ -71,9 +118,7 @@ def _parse_and_validate_line(line: str) -> Optional[dict]:
|
||||
line: 单行输出数据
|
||||
|
||||
Returns:
|
||||
Optional[dict]: 有效的 httpx 记录,或 None 如果验证失败
|
||||
|
||||
保存所有有效 URL(不再过滤状态码,安全扫描中 403/404/500 等也有分析价值)
|
||||
Optional[HttpxRecord]: 有效的 httpx 记录,或 None 如果验证失败
|
||||
"""
|
||||
try:
|
||||
# 清理 NUL 字符后再解析 JSON
|
||||
@@ -83,7 +128,6 @@ def _parse_and_validate_line(line: str) -> Optional[dict]:
|
||||
try:
|
||||
line_data = json.loads(line, strict=False)
|
||||
except json.JSONDecodeError:
|
||||
# logger.info("跳过非 JSON 行: %s", line)
|
||||
return None
|
||||
|
||||
# 验证数据类型
|
||||
@@ -91,29 +135,15 @@ def _parse_and_validate_line(line: str) -> Optional[dict]:
|
||||
logger.info("跳过非字典数据")
|
||||
return None
|
||||
|
||||
# 获取必要字段
|
||||
url = line_data.get('url', '').strip()
|
||||
status_code = line_data.get('status_code')
|
||||
# 创建记录
|
||||
record = HttpxRecord(line_data)
|
||||
|
||||
if not url:
|
||||
# 验证必要字段
|
||||
if not record.url:
|
||||
logger.info("URL 为空,跳过 - 数据: %s", str(line_data)[:200])
|
||||
return None
|
||||
|
||||
# 保存所有有效 URL(不再过滤状态码)
|
||||
return {
|
||||
'url': _sanitize_string(url),
|
||||
'host': _sanitize_string(line_data.get('host', '')),
|
||||
'status_code': status_code,
|
||||
'title': _sanitize_string(line_data.get('title', '')),
|
||||
'content_length': line_data.get('content_length', 0),
|
||||
'content_type': _sanitize_string(line_data.get('content_type', '')),
|
||||
'webserver': _sanitize_string(line_data.get('webserver', '')),
|
||||
'location': _sanitize_string(line_data.get('location', '')),
|
||||
'tech': line_data.get('tech', []),
|
||||
'response_body': _sanitize_string(line_data.get('body', '')),
|
||||
'vhost': line_data.get('vhost', False),
|
||||
'response_headers': _sanitize_string(line_data.get('raw_header', '')),
|
||||
}
|
||||
return record
|
||||
|
||||
except Exception:
|
||||
logger.info("跳过无法解析的行: %s", line[:100] if line else 'empty')
|
||||
@@ -127,7 +157,7 @@ def _parse_httpx_stream_output(
|
||||
shell: bool = False,
|
||||
timeout: Optional[int] = None,
|
||||
log_file: Optional[str] = None
|
||||
) -> Generator[dict, None, None]:
|
||||
) -> Generator[HttpxRecord, None, None]:
|
||||
"""
|
||||
流式解析 httpx 命令输出
|
||||
|
||||
@@ -140,7 +170,7 @@ def _parse_httpx_stream_output(
|
||||
log_file: 日志文件路径
|
||||
|
||||
Yields:
|
||||
dict: 每次 yield 一条存活的 URL 记录
|
||||
HttpxRecord: 每次 yield 一条存活的 URL 记录
|
||||
"""
|
||||
logger.info("开始流式解析 httpx 输出 - 命令: %s", cmd)
|
||||
|
||||
@@ -170,8 +200,8 @@ def _parse_httpx_stream_output(
|
||||
# yield 一条有效记录(存活的 URL)
|
||||
yield record
|
||||
|
||||
# 每处理 500 条记录输出一次进度
|
||||
if valid_records % 500 == 0:
|
||||
# 每处理 100 条记录输出一次进度
|
||||
if valid_records % 100 == 0:
|
||||
logger.info("已解析 %d 条存活的 URL...", valid_records)
|
||||
|
||||
except subprocess.TimeoutExpired as e:
|
||||
@@ -188,6 +218,78 @@ def _parse_httpx_stream_output(
|
||||
)
|
||||
|
||||
|
||||
def _validate_task_parameters(cmd: str, target_id: int, scan_id: int, cwd: Optional[str]) -> None:
|
||||
"""
|
||||
验证任务参数的有效性
|
||||
|
||||
Args:
|
||||
cmd: 扫描命令
|
||||
target_id: 目标ID
|
||||
scan_id: 扫描ID
|
||||
cwd: 工作目录
|
||||
|
||||
Raises:
|
||||
ValueError: 参数验证失败
|
||||
"""
|
||||
if not cmd or not cmd.strip():
|
||||
raise ValueError("扫描命令不能为空")
|
||||
|
||||
if target_id is None:
|
||||
raise ValueError("target_id 不能为 None,必须指定目标ID")
|
||||
|
||||
if scan_id is None:
|
||||
raise ValueError("scan_id 不能为 None,必须指定扫描ID")
|
||||
|
||||
# 验证工作目录(如果指定)
|
||||
if cwd and not Path(cwd).exists():
|
||||
raise ValueError(f"工作目录不存在: {cwd}")
|
||||
|
||||
|
||||
def _build_final_result(stats: dict) -> dict:
|
||||
"""
|
||||
构建最终结果并输出日志
|
||||
|
||||
Args:
|
||||
stats: 处理统计信息
|
||||
|
||||
Returns:
|
||||
dict: 最终结果
|
||||
"""
|
||||
logger.info(
|
||||
"✓ URL 验证任务完成 - 处理记录: %d(%d 批次),创建端点: %d,跳过(失败): %d",
|
||||
stats['processed_records'], stats['batch_count'], stats['created_endpoints'],
|
||||
stats['skipped_failed']
|
||||
)
|
||||
|
||||
# 如果没有创建任何记录,给出明确提示
|
||||
if stats['created_endpoints'] == 0:
|
||||
logger.warning(
|
||||
"⚠️ 没有创建任何端点记录!可能原因:1) 命令输出格式问题 2) 重复数据被忽略 3) 所有请求都失败"
|
||||
)
|
||||
|
||||
return {
|
||||
'processed_records': stats['processed_records'],
|
||||
'created_endpoints': stats['created_endpoints'],
|
||||
'skipped_failed': stats['skipped_failed']
|
||||
}
|
||||
|
||||
|
||||
def _cleanup_resources(data_generator) -> None:
|
||||
"""
|
||||
清理任务资源
|
||||
|
||||
Args:
|
||||
data_generator: 数据生成器
|
||||
"""
|
||||
# 确保生成器被正确关闭
|
||||
if data_generator is not None:
|
||||
try:
|
||||
data_generator.close()
|
||||
logger.debug("已关闭数据生成器")
|
||||
except Exception as gen_close_error:
|
||||
logger.error("关闭生成器时出错: %s", gen_close_error)
|
||||
|
||||
|
||||
def _save_batch_with_retry(
|
||||
batch: list,
|
||||
scan_id: int,
|
||||
@@ -208,14 +310,19 @@ def _save_batch_with_retry(
|
||||
max_retries: 最大重试次数
|
||||
|
||||
Returns:
|
||||
dict: {'success': bool, 'saved_count': int}
|
||||
dict: {
|
||||
'success': bool,
|
||||
'created_endpoints': int,
|
||||
'skipped_failed': int
|
||||
}
|
||||
"""
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
count = _save_batch(batch, scan_id, target_id, batch_num, services)
|
||||
stats = _save_batch(batch, scan_id, target_id, batch_num, services)
|
||||
return {
|
||||
'success': True,
|
||||
'saved_count': count
|
||||
'created_endpoints': stats.get('created_endpoints', 0),
|
||||
'skipped_failed': stats.get('skipped_failed', 0)
|
||||
}
|
||||
|
||||
except IntegrityError as e:
|
||||
@@ -223,7 +330,8 @@ def _save_batch_with_retry(
|
||||
logger.error("批次 %d 数据完整性错误,跳过: %s", batch_num, str(e)[:100])
|
||||
return {
|
||||
'success': False,
|
||||
'saved_count': 0
|
||||
'created_endpoints': 0,
|
||||
'skipped_failed': 0
|
||||
}
|
||||
|
||||
except (OperationalError, DatabaseError, InterfaceError) as e:
|
||||
@@ -254,7 +362,8 @@ def _save_batch_with_retry(
|
||||
# 理论上不会走到这里,保留兜底返回值以满足类型约束
|
||||
return {
|
||||
'success': False,
|
||||
'saved_count': 0
|
||||
'created_endpoints': 0,
|
||||
'skipped_failed': 0
|
||||
}
|
||||
|
||||
|
||||
@@ -264,50 +373,72 @@ def _save_batch(
|
||||
target_id: int,
|
||||
batch_num: int,
|
||||
services: ServiceSet
|
||||
) -> int:
|
||||
) -> dict:
|
||||
"""
|
||||
保存一个批次的数据到数据库
|
||||
|
||||
Args:
|
||||
batch: 数据批次,list of dict
|
||||
batch: 数据批次,list of HttpxRecord
|
||||
scan_id: 扫描任务 ID
|
||||
target_id: 目标 ID
|
||||
batch_num: 批次编号
|
||||
services: Service 集合
|
||||
|
||||
Returns:
|
||||
int: 创建的记录数
|
||||
dict: 包含创建和跳过记录的统计信息
|
||||
"""
|
||||
# 参数验证
|
||||
if not isinstance(batch, list):
|
||||
raise TypeError(f"batch 必须是 list 类型,实际: {type(batch).__name__}")
|
||||
|
||||
if not batch:
|
||||
logger.debug("批次 %d 为空,跳过处理", batch_num)
|
||||
return 0
|
||||
return {
|
||||
'created_endpoints': 0,
|
||||
'skipped_failed': 0
|
||||
}
|
||||
|
||||
# 统计变量
|
||||
skipped_failed = 0
|
||||
|
||||
# 批量构造 Endpoint 快照 DTO
|
||||
from apps.asset.dtos.snapshot import EndpointSnapshotDTO
|
||||
|
||||
snapshots = []
|
||||
for record in batch:
|
||||
# 跳过失败的请求
|
||||
if record.failed:
|
||||
skipped_failed += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
# Endpoint URL 直接使用原始值,不做标准化
|
||||
# 原因:Endpoint URL 来自 waymore/katana,包含路径和参数,标准化可能改变含义
|
||||
url = record.input if record.input else record.url
|
||||
|
||||
# 提取 host 字段(域名或IP地址)
|
||||
host = record.host if record.host else ''
|
||||
|
||||
dto = EndpointSnapshotDTO(
|
||||
scan_id=scan_id,
|
||||
url=record['url'],
|
||||
host=record.get('host', ''),
|
||||
title=record.get('title', ''),
|
||||
status_code=record.get('status_code'),
|
||||
content_length=record.get('content_length', 0),
|
||||
location=record.get('location', ''),
|
||||
webserver=record.get('webserver', ''),
|
||||
content_type=record.get('content_type', ''),
|
||||
tech=record.get('tech', []),
|
||||
response_body=record.get('response_body', ''),
|
||||
vhost=record.get('vhost', False),
|
||||
matched_gf_patterns=[],
|
||||
target_id=target_id,
|
||||
response_headers=record.get('response_headers', ''),
|
||||
url=url,
|
||||
host=host,
|
||||
title=record.title if record.title else '',
|
||||
status_code=record.status_code,
|
||||
content_length=record.content_length,
|
||||
location=record.location if record.location else '',
|
||||
webserver=record.webserver if record.webserver else '',
|
||||
content_type=record.content_type if record.content_type else '',
|
||||
tech=record.tech if isinstance(record.tech, list) else [],
|
||||
response_body=record.response_body if record.response_body else '',
|
||||
vhost=record.vhost if record.vhost else False,
|
||||
matched_gf_patterns=[],
|
||||
response_headers=record.response_headers if record.response_headers else '',
|
||||
)
|
||||
snapshots.append(dto)
|
||||
except Exception as e:
|
||||
logger.error("处理记录失败: %s,错误: %s", record.get('url', 'Unknown'), e)
|
||||
logger.error("处理记录失败: %s,错误: %s", record.url, e)
|
||||
continue
|
||||
|
||||
if snapshots:
|
||||
@@ -316,15 +447,69 @@ def _save_batch(
|
||||
services.snapshot.save_and_sync(snapshots)
|
||||
count = len(snapshots)
|
||||
logger.info(
|
||||
"批次 %d: 保存了 %d 个存活的 URL(共 %d 个)",
|
||||
batch_num, count, len(batch)
|
||||
"批次 %d: 保存了 %d 个存活的 URL(共 %d 个,跳过失败: %d)",
|
||||
batch_num, count, len(batch), skipped_failed
|
||||
)
|
||||
return count
|
||||
return {
|
||||
'created_endpoints': count,
|
||||
'skipped_failed': skipped_failed
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error("批次 %d 批量保存失败: %s", batch_num, e)
|
||||
raise
|
||||
|
||||
return 0
|
||||
return {
|
||||
'created_endpoints': 0,
|
||||
'skipped_failed': skipped_failed
|
||||
}
|
||||
|
||||
|
||||
def _accumulate_batch_stats(total_stats: dict, batch_result: dict) -> None:
|
||||
"""
|
||||
累加批次统计信息
|
||||
|
||||
Args:
|
||||
total_stats: 总统计信息字典
|
||||
batch_result: 批次结果字典
|
||||
"""
|
||||
total_stats['created_endpoints'] += batch_result.get('created_endpoints', 0)
|
||||
total_stats['skipped_failed'] += batch_result.get('skipped_failed', 0)
|
||||
|
||||
|
||||
def _process_batch(
|
||||
batch: list,
|
||||
scan_id: int,
|
||||
target_id: int,
|
||||
batch_num: int,
|
||||
total_stats: dict,
|
||||
failed_batches: list,
|
||||
services: ServiceSet
|
||||
) -> None:
|
||||
"""
|
||||
处理单个批次
|
||||
|
||||
Args:
|
||||
batch: 数据批次
|
||||
scan_id: 扫描ID
|
||||
target_id: 目标ID
|
||||
batch_num: 批次编号
|
||||
total_stats: 总统计信息
|
||||
failed_batches: 失败批次列表
|
||||
services: Service 集合(必须,依赖注入)
|
||||
"""
|
||||
result = _save_batch_with_retry(
|
||||
batch, scan_id, target_id, batch_num, services
|
||||
)
|
||||
|
||||
# 累计统计信息(失败时可能有部分数据已保存)
|
||||
_accumulate_batch_stats(total_stats, result)
|
||||
|
||||
if not result['success']:
|
||||
failed_batches.append(batch_num)
|
||||
logger.warning(
|
||||
"批次 %d 保存失败,但已累计统计信息:创建端点=%d",
|
||||
batch_num, result.get('created_endpoints', 0)
|
||||
)
|
||||
|
||||
|
||||
def _process_records_in_batches(
|
||||
@@ -335,7 +520,7 @@ def _process_records_in_batches(
|
||||
services: ServiceSet
|
||||
) -> dict:
|
||||
"""
|
||||
分批处理记录并保存到数据库
|
||||
流式处理记录并分批保存
|
||||
|
||||
Args:
|
||||
data_generator: 数据生成器
|
||||
@@ -345,14 +530,23 @@ def _process_records_in_batches(
|
||||
services: Service 集合
|
||||
|
||||
Returns:
|
||||
dict: 处理统计结果
|
||||
dict: 处理统计信息
|
||||
|
||||
Raises:
|
||||
RuntimeError: 存在失败批次时抛出
|
||||
"""
|
||||
batch = []
|
||||
batch_num = 0
|
||||
total_records = 0
|
||||
total_saved = 0
|
||||
batch_num = 0
|
||||
failed_batches = []
|
||||
batch = []
|
||||
|
||||
# 统计信息
|
||||
total_stats = {
|
||||
'created_endpoints': 0,
|
||||
'skipped_failed': 0
|
||||
}
|
||||
|
||||
# 流式读取生成器并分批保存
|
||||
for record in data_generator:
|
||||
batch.append(record)
|
||||
total_records += 1
|
||||
@@ -360,46 +554,35 @@ def _process_records_in_batches(
|
||||
# 达到批次大小,执行保存
|
||||
if len(batch) >= batch_size:
|
||||
batch_num += 1
|
||||
result = _save_batch_with_retry(
|
||||
batch, scan_id, target_id, batch_num, services
|
||||
)
|
||||
|
||||
if result['success']:
|
||||
total_saved += result['saved_count']
|
||||
else:
|
||||
failed_batches.append(batch_num)
|
||||
|
||||
_process_batch(batch, scan_id, target_id, batch_num, total_stats, failed_batches, services)
|
||||
batch = [] # 清空批次
|
||||
|
||||
# 每 10 个批次输出进度
|
||||
if batch_num % 10 == 0:
|
||||
logger.info(
|
||||
"进度: 已处理 %d 批次,%d 条记录,保存 %d 条",
|
||||
batch_num, total_records, total_saved
|
||||
)
|
||||
logger.info("进度: 已处理 %d 批次,%d 条记录", batch_num, total_records)
|
||||
|
||||
# 保存最后一批
|
||||
if batch:
|
||||
batch_num += 1
|
||||
result = _save_batch_with_retry(
|
||||
batch, scan_id, target_id, batch_num, services
|
||||
_process_batch(batch, scan_id, target_id, batch_num, total_stats, failed_batches, services)
|
||||
|
||||
# 检查失败批次
|
||||
if failed_batches:
|
||||
error_msg = (
|
||||
f"流式保存 URL 验证结果时出现失败批次,处理记录: {total_records},"
|
||||
f"失败批次: {failed_batches}"
|
||||
)
|
||||
|
||||
if result['success']:
|
||||
total_saved += result['saved_count']
|
||||
else:
|
||||
failed_batches.append(batch_num)
|
||||
logger.warning(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
return {
|
||||
'processed_records': total_records,
|
||||
'saved_urls': total_saved,
|
||||
'failed_urls': total_records - total_saved,
|
||||
'batch_count': batch_num,
|
||||
'failed_batches': failed_batches
|
||||
**total_stats
|
||||
}
|
||||
|
||||
|
||||
@task(name="run_and_stream_save_urls", retries=3, retry_delay_seconds=10)
|
||||
@task(name="run_and_stream_save_urls", retries=0)
|
||||
def run_and_stream_save_urls_task(
|
||||
cmd: str,
|
||||
tool_name: str,
|
||||
@@ -407,7 +590,7 @@ def run_and_stream_save_urls_task(
|
||||
target_id: int,
|
||||
cwd: Optional[str] = None,
|
||||
shell: bool = False,
|
||||
batch_size: int = 500,
|
||||
batch_size: int = 100,
|
||||
timeout: Optional[int] = None,
|
||||
log_file: Optional[str] = None
|
||||
) -> dict:
|
||||
@@ -415,17 +598,18 @@ def run_and_stream_save_urls_task(
|
||||
执行 httpx 验证并流式保存存活的 URL
|
||||
|
||||
该任务将:
|
||||
1. 执行 httpx 命令验证 URL 存活
|
||||
2. 流式处理输出,实时解析
|
||||
3. 批量保存存活的 URL 到 Endpoint 表
|
||||
1. 验证输入参数
|
||||
2. 初始化资源(缓存、生成器)
|
||||
3. 流式处理记录并分批保存
|
||||
4. 构建并返回结果统计
|
||||
|
||||
Args:
|
||||
cmd: httpx 命令
|
||||
tool_name: 工具名称('httpx')
|
||||
scan_id: 扫描任务 ID
|
||||
target_id: 目标 ID
|
||||
cwd: 工作目录
|
||||
shell: 是否使用 shell 执行
|
||||
cwd: 工作目录(可选)
|
||||
shell: 是否使用 shell 执行(默认 False)
|
||||
batch_size: 批次大小(默认 500)
|
||||
timeout: 超时时间(秒)
|
||||
log_file: 日志文件路径
|
||||
@@ -433,11 +617,14 @@ def run_and_stream_save_urls_task(
|
||||
Returns:
|
||||
dict: {
|
||||
'processed_records': int, # 处理的记录总数
|
||||
'saved_urls': int, # 保存的存活 URL 数
|
||||
'failed_urls': int, # 失败/死链数
|
||||
'batch_count': int, # 批次数
|
||||
'failed_batches': list # 失败的批次号
|
||||
'created_endpoints': int, # 创建的端点记录数
|
||||
'skipped_failed': int, # 因请求失败跳过的记录数
|
||||
}
|
||||
|
||||
Raises:
|
||||
ValueError: 参数验证失败
|
||||
RuntimeError: 命令执行或数据库操作失败
|
||||
subprocess.TimeoutExpired: 命令执行超时
|
||||
"""
|
||||
logger.info(
|
||||
"开始执行流式 URL 验证任务 - target_id=%s, 超时=%s秒, 命令: %s",
|
||||
@@ -447,33 +634,30 @@ def run_and_stream_save_urls_task(
|
||||
data_generator = None
|
||||
|
||||
try:
|
||||
# 1. 初始化资源
|
||||
# 1. 验证参数
|
||||
_validate_task_parameters(cmd, target_id, scan_id, cwd)
|
||||
|
||||
# 2. 初始化资源
|
||||
data_generator = _parse_httpx_stream_output(
|
||||
cmd, tool_name, cwd, shell, timeout, log_file
|
||||
)
|
||||
services = ServiceSet.create_default()
|
||||
|
||||
# 2. 流式处理记录并分批保存
|
||||
# 3. 流式处理记录并分批保存
|
||||
stats = _process_records_in_batches(
|
||||
data_generator, scan_id, target_id, batch_size, services
|
||||
)
|
||||
|
||||
# 3. 输出最终统计
|
||||
logger.info(
|
||||
"✓ URL 验证任务完成 - 处理: %d, 存活: %d, 失败: %d",
|
||||
stats['processed_records'],
|
||||
stats['saved_urls'],
|
||||
stats['failed_urls']
|
||||
)
|
||||
|
||||
return stats
|
||||
# 4. 构建最终结果
|
||||
return _build_final_result(stats)
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
# 超时异常直接向上传播,保留异常类型
|
||||
logger.warning(
|
||||
"⚠️ URL 验证任务超时 - target_id=%s, 超时=%s秒",
|
||||
target_id, timeout
|
||||
)
|
||||
raise
|
||||
raise # 直接重新抛出,不包装
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"流式执行 URL 验证任务失败: {e}"
|
||||
@@ -481,12 +665,5 @@ def run_and_stream_save_urls_task(
|
||||
raise RuntimeError(error_msg) from e
|
||||
|
||||
finally:
|
||||
# 清理资源
|
||||
if data_generator is not None:
|
||||
try:
|
||||
# 确保生成器被正确关闭
|
||||
data_generator.close()
|
||||
except (GeneratorExit, StopIteration):
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning("关闭数据生成器时出错: %s", e)
|
||||
# 5. 清理资源
|
||||
_cleanup_resources(data_generator)
|
||||
|
||||
@@ -1516,7 +1516,7 @@ class TestDataGenerator:
|
||||
if batch_data:
|
||||
execute_values(cur, """
|
||||
INSERT INTO website_snapshot (
|
||||
scan_id, url, host, title, web_server, tech, status,
|
||||
scan_id, url, host, title, webserver, tech, status_code,
|
||||
content_length, content_type, location, response_body,
|
||||
response_headers, created_at
|
||||
) VALUES %s
|
||||
|
||||
Reference in New Issue
Block a user