mirror of
https://github.com/yyhuni/xingrin.git
synced 2026-01-31 19:53:11 +08:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5acaada7ab | ||
|
|
aaad3f29cf |
@@ -135,6 +135,7 @@ class Endpoint(models.Model):
|
||||
models.Index(fields=['url']), # URL索引,优化查询性能
|
||||
models.Index(fields=['host']), # host索引,优化根据主机名查询
|
||||
models.Index(fields=['status_code']), # 状态码索引,优化筛选
|
||||
models.Index(fields=['title']), # title索引,优化智能过滤搜索
|
||||
]
|
||||
constraints = [
|
||||
# 普通唯一约束:url + target 组合唯一
|
||||
@@ -229,6 +230,8 @@ class WebSite(models.Model):
|
||||
models.Index(fields=['url']), # URL索引,优化查询性能
|
||||
models.Index(fields=['host']), # host索引,优化根据主机名查询
|
||||
models.Index(fields=['target']), # 优化从target_id快速查找下面的站点
|
||||
models.Index(fields=['title']), # title索引,优化智能过滤搜索
|
||||
models.Index(fields=['status_code']), # 状态码索引,优化智能过滤搜索
|
||||
]
|
||||
constraints = [
|
||||
# 普通唯一约束:url + target 组合唯一
|
||||
@@ -408,7 +411,7 @@ class Vulnerability(models.Model):
|
||||
)
|
||||
|
||||
# ==================== 核心字段 ====================
|
||||
url = models.TextField(help_text='漏洞所在的URL')
|
||||
url = models.CharField(max_length=2000, help_text='漏洞所在的URL')
|
||||
vuln_type = models.CharField(max_length=100, help_text='漏洞类型(如 xss, sqli)')
|
||||
severity = models.CharField(
|
||||
max_length=20,
|
||||
@@ -434,6 +437,7 @@ class Vulnerability(models.Model):
|
||||
models.Index(fields=['vuln_type']),
|
||||
models.Index(fields=['severity']),
|
||||
models.Index(fields=['source']),
|
||||
models.Index(fields=['url']), # url索引,优化智能过滤搜索
|
||||
models.Index(fields=['-created_at']),
|
||||
]
|
||||
|
||||
|
||||
@@ -81,6 +81,7 @@ class WebsiteSnapshot(models.Model):
|
||||
models.Index(fields=['scan']),
|
||||
models.Index(fields=['url']),
|
||||
models.Index(fields=['host']), # host索引,优化根据主机名查询
|
||||
models.Index(fields=['title']), # title索引,优化标题搜索
|
||||
models.Index(fields=['-created_at']),
|
||||
]
|
||||
constraints = [
|
||||
@@ -129,6 +130,7 @@ class DirectorySnapshot(models.Model):
|
||||
models.Index(fields=['scan']),
|
||||
models.Index(fields=['url']),
|
||||
models.Index(fields=['status']), # 状态码索引,优化筛选
|
||||
models.Index(fields=['content_type']), # content_type索引,优化内容类型搜索
|
||||
models.Index(fields=['-created_at']),
|
||||
]
|
||||
constraints = [
|
||||
@@ -268,7 +270,9 @@ class EndpointSnapshot(models.Model):
|
||||
models.Index(fields=['scan']),
|
||||
models.Index(fields=['url']),
|
||||
models.Index(fields=['host']), # host索引,优化根据主机名查询
|
||||
models.Index(fields=['title']), # title索引,优化标题搜索
|
||||
models.Index(fields=['status_code']), # 状态码索引,优化筛选
|
||||
models.Index(fields=['webserver']), # webserver索引,优化服务器搜索
|
||||
models.Index(fields=['-created_at']),
|
||||
]
|
||||
constraints = [
|
||||
@@ -302,7 +306,7 @@ class VulnerabilitySnapshot(models.Model):
|
||||
)
|
||||
|
||||
# ==================== 核心字段 ====================
|
||||
url = models.TextField(help_text='漏洞所在的URL')
|
||||
url = models.CharField(max_length=2000, help_text='漏洞所在的URL')
|
||||
vuln_type = models.CharField(max_length=100, help_text='漏洞类型(如 xss, sqli)')
|
||||
severity = models.CharField(
|
||||
max_length=20,
|
||||
@@ -325,6 +329,7 @@ class VulnerabilitySnapshot(models.Model):
|
||||
ordering = ['-created_at']
|
||||
indexes = [
|
||||
models.Index(fields=['scan']),
|
||||
models.Index(fields=['url']), # url索引,优化URL搜索
|
||||
models.Index(fields=['vuln_type']),
|
||||
models.Index(fields=['severity']),
|
||||
models.Index(fields=['source']),
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""HostPortMapping Repository - Django ORM 实现"""
|
||||
|
||||
import logging
|
||||
from typing import List, Iterator
|
||||
from typing import List, Iterator, Dict, Optional
|
||||
|
||||
from django.db.models import QuerySet, Min
|
||||
|
||||
from apps.asset.models.asset_models import HostPortMapping
|
||||
from apps.asset.dtos.asset import HostPortMappingDTO
|
||||
@@ -13,7 +15,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@auto_ensure_db_connection
|
||||
class DjangoHostPortMappingRepository:
|
||||
"""HostPortMapping Repository - Django ORM 实现"""
|
||||
"""HostPortMapping Repository - Django ORM 实现
|
||||
|
||||
职责:纯数据访问,不包含业务逻辑
|
||||
"""
|
||||
|
||||
def bulk_create_ignore_conflicts(self, items: List[HostPortMappingDTO]) -> int:
|
||||
"""
|
||||
@@ -90,72 +95,20 @@ class DjangoHostPortMappingRepository:
|
||||
for ip in queryset:
|
||||
yield ip
|
||||
|
||||
def get_ip_aggregation_by_target(self, target_id: int, search: str = None):
|
||||
from django.db.models import Min
|
||||
def get_queryset_by_target(self, target_id: int) -> QuerySet:
|
||||
"""获取目标下的 QuerySet"""
|
||||
return HostPortMapping.objects.filter(target_id=target_id)
|
||||
|
||||
qs = HostPortMapping.objects.filter(target_id=target_id)
|
||||
if search:
|
||||
qs = qs.filter(ip__icontains=search)
|
||||
def get_all_queryset(self) -> QuerySet:
|
||||
"""获取所有记录的 QuerySet"""
|
||||
return HostPortMapping.objects.all()
|
||||
|
||||
ip_aggregated = (
|
||||
qs
|
||||
.values('ip')
|
||||
.annotate(created_at=Min('created_at'))
|
||||
.order_by('-created_at')
|
||||
)
|
||||
|
||||
results = []
|
||||
for item in ip_aggregated:
|
||||
ip = item['ip']
|
||||
mappings = (
|
||||
HostPortMapping.objects
|
||||
.filter(target_id=target_id, ip=ip)
|
||||
.values('host', 'port')
|
||||
.distinct()
|
||||
)
|
||||
hosts = sorted({m['host'] for m in mappings})
|
||||
ports = sorted({m['port'] for m in mappings})
|
||||
results.append({
|
||||
'ip': ip,
|
||||
'hosts': hosts,
|
||||
'ports': ports,
|
||||
'created_at': item['created_at'],
|
||||
})
|
||||
return results
|
||||
|
||||
def get_all_ip_aggregation(self, search: str = None):
|
||||
"""获取所有 IP 聚合数据(全局查询)"""
|
||||
from django.db.models import Min
|
||||
|
||||
qs = HostPortMapping.objects.all()
|
||||
if search:
|
||||
qs = qs.filter(ip__icontains=search)
|
||||
|
||||
ip_aggregated = (
|
||||
qs
|
||||
.values('ip')
|
||||
.annotate(created_at=Min('created_at'))
|
||||
.order_by('-created_at')
|
||||
)
|
||||
|
||||
results = []
|
||||
for item in ip_aggregated:
|
||||
ip = item['ip']
|
||||
mappings = (
|
||||
HostPortMapping.objects
|
||||
.filter(ip=ip)
|
||||
.values('host', 'port')
|
||||
.distinct()
|
||||
)
|
||||
hosts = sorted({m['host'] for m in mappings})
|
||||
ports = sorted({m['port'] for m in mappings})
|
||||
results.append({
|
||||
'ip': ip,
|
||||
'hosts': hosts,
|
||||
'ports': ports,
|
||||
'created_at': item['created_at'],
|
||||
})
|
||||
return results
|
||||
def get_queryset_by_ip(self, ip: str, target_id: Optional[int] = None) -> QuerySet:
|
||||
"""获取指定 IP 的 QuerySet"""
|
||||
qs = HostPortMapping.objects.filter(ip=ip)
|
||||
if target_id:
|
||||
qs = qs.filter(target_id=target_id)
|
||||
return qs
|
||||
|
||||
def iter_raw_data_for_export(
|
||||
self,
|
||||
|
||||
@@ -65,12 +65,20 @@ class DjangoHostPortMappingSnapshotRepository:
|
||||
)
|
||||
raise
|
||||
|
||||
def get_ip_aggregation_by_scan(self, scan_id: int, search: str = None):
|
||||
def get_ip_aggregation_by_scan(self, scan_id: int, filter_query: str = None):
|
||||
from django.db.models import Min
|
||||
from apps.common.utils.filter_utils import apply_filters
|
||||
|
||||
qs = HostPortMappingSnapshot.objects.filter(scan_id=scan_id)
|
||||
if search:
|
||||
qs = qs.filter(ip__icontains=search)
|
||||
|
||||
# 应用智能过滤
|
||||
if filter_query:
|
||||
field_mapping = {
|
||||
'ip': 'ip',
|
||||
'port': 'port',
|
||||
'host': 'host',
|
||||
}
|
||||
qs = apply_filters(qs, filter_query, field_mapping)
|
||||
|
||||
ip_aggregated = (
|
||||
qs
|
||||
@@ -103,13 +111,21 @@ class DjangoHostPortMappingSnapshotRepository:
|
||||
|
||||
return results
|
||||
|
||||
def get_all_ip_aggregation(self, search: str = None):
|
||||
def get_all_ip_aggregation(self, filter_query: str = None):
|
||||
"""获取所有 IP 聚合数据"""
|
||||
from django.db.models import Min
|
||||
from apps.common.utils.filter_utils import apply_filters
|
||||
|
||||
qs = HostPortMappingSnapshot.objects.all()
|
||||
if search:
|
||||
qs = qs.filter(ip__icontains=search)
|
||||
|
||||
# 应用智能过滤
|
||||
if filter_query:
|
||||
field_mapping = {
|
||||
'ip': 'ip',
|
||||
'port': 'port',
|
||||
'host': 'host',
|
||||
}
|
||||
qs = apply_filters(qs, filter_query, field_mapping)
|
||||
|
||||
ip_aggregated = (
|
||||
qs
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"""Directory Service - 目录业务逻辑层"""
|
||||
|
||||
import logging
|
||||
from typing import List, Iterator
|
||||
from typing import List, Iterator, Optional
|
||||
|
||||
from apps.asset.repositories import DjangoDirectoryRepository
|
||||
from apps.asset.dtos import DirectoryDTO
|
||||
from apps.common.validators import is_valid_url, is_url_match_target
|
||||
from apps.common.utils.filter_utils import apply_filters
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -13,6 +14,12 @@ logger = logging.getLogger(__name__)
|
||||
class DirectoryService:
|
||||
"""目录业务逻辑层"""
|
||||
|
||||
# 智能过滤字段映射
|
||||
FILTER_FIELD_MAPPING = {
|
||||
'url': 'url',
|
||||
'status': 'status',
|
||||
}
|
||||
|
||||
def __init__(self, repository=None):
|
||||
"""初始化目录服务"""
|
||||
self.repo = repository or DjangoDirectoryRepository()
|
||||
@@ -94,13 +101,19 @@ class DirectoryService:
|
||||
count_after = self.repo.count_by_target(target_id)
|
||||
return count_after - count_before
|
||||
|
||||
def get_directories_by_target(self, target_id: int):
|
||||
def get_directories_by_target(self, target_id: int, filter_query: Optional[str] = None):
|
||||
"""获取目标下的所有目录"""
|
||||
return self.repo.get_by_target(target_id)
|
||||
queryset = self.repo.get_by_target(target_id)
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
return queryset
|
||||
|
||||
def get_all(self):
|
||||
def get_all(self, filter_query: Optional[str] = None):
|
||||
"""获取所有目录"""
|
||||
return self.repo.get_all()
|
||||
queryset = self.repo.get_all()
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
return queryset
|
||||
|
||||
def iter_directory_urls_by_target(self, target_id: int, chunk_size: int = 1000) -> Iterator[str]:
|
||||
"""流式获取目标下的所有目录 URL"""
|
||||
|
||||
@@ -5,11 +5,12 @@ Endpoint 服务层
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Iterator
|
||||
from typing import List, Iterator, Optional
|
||||
|
||||
from apps.asset.dtos.asset import EndpointDTO
|
||||
from apps.asset.repositories.asset import DjangoEndpointRepository
|
||||
from apps.common.validators import is_valid_url, is_url_match_target
|
||||
from apps.common.utils.filter_utils import apply_filters
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -21,6 +22,14 @@ class EndpointService:
|
||||
提供 Endpoint(URL/端点)相关的业务逻辑
|
||||
"""
|
||||
|
||||
# 智能过滤字段映射
|
||||
FILTER_FIELD_MAPPING = {
|
||||
'url': 'url',
|
||||
'host': 'host',
|
||||
'title': 'title',
|
||||
'status': 'status_code',
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
"""初始化 Endpoint 服务"""
|
||||
self.repo = DjangoEndpointRepository()
|
||||
@@ -102,9 +111,12 @@ class EndpointService:
|
||||
count_after = self.repo.count_by_target(target_id)
|
||||
return count_after - count_before
|
||||
|
||||
def get_endpoints_by_target(self, target_id: int):
|
||||
def get_endpoints_by_target(self, target_id: int, filter_query: Optional[str] = None):
|
||||
"""获取目标下的所有端点"""
|
||||
return self.repo.get_by_target(target_id)
|
||||
queryset = self.repo.get_by_target(target_id)
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
return queryset
|
||||
|
||||
def count_endpoints_by_target(self, target_id: int) -> int:
|
||||
"""
|
||||
@@ -118,9 +130,12 @@ class EndpointService:
|
||||
"""
|
||||
return self.repo.count_by_target(target_id)
|
||||
|
||||
def get_all(self):
|
||||
def get_all(self, filter_query: Optional[str] = None):
|
||||
"""获取所有端点(全局查询)"""
|
||||
return self.repo.get_all()
|
||||
queryset = self.repo.get_all()
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
return queryset
|
||||
|
||||
def iter_endpoint_urls_by_target(self, target_id: int, chunk_size: int = 1000) -> Iterator[str]:
|
||||
"""流式获取目标下的所有端点 URL,用于导出。"""
|
||||
|
||||
@@ -1,16 +1,31 @@
|
||||
"""HostPortMapping Service - 业务逻辑层"""
|
||||
|
||||
import logging
|
||||
from typing import List, Iterator
|
||||
from typing import List, Iterator, Optional, Dict
|
||||
|
||||
from django.db.models import Min
|
||||
|
||||
from apps.asset.repositories.asset import DjangoHostPortMappingRepository
|
||||
from apps.asset.dtos.asset import HostPortMappingDTO
|
||||
from apps.common.utils.filter_utils import apply_filters
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HostPortMappingService:
|
||||
"""主机端口映射服务 - 负责主机端口映射数据的业务逻辑"""
|
||||
"""主机端口映射服务 - 负责主机端口映射数据的业务逻辑
|
||||
|
||||
职责:
|
||||
- 业务逻辑处理(过滤、聚合)
|
||||
- 调用 Repository 进行数据访问
|
||||
"""
|
||||
|
||||
# 智能过滤字段映射
|
||||
FILTER_FIELD_MAPPING = {
|
||||
'ip': 'ip',
|
||||
'port': 'port',
|
||||
'host': 'host',
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self.repo = DjangoHostPortMappingRepository()
|
||||
@@ -49,12 +64,93 @@ class HostPortMappingService:
|
||||
def iter_host_port_by_target(self, target_id: int, batch_size: int = 1000):
|
||||
return self.repo.get_for_export(target_id=target_id, batch_size=batch_size)
|
||||
|
||||
def get_ip_aggregation_by_target(self, target_id: int, search: str = None):
|
||||
return self.repo.get_ip_aggregation_by_target(target_id, search=search)
|
||||
def get_ip_aggregation_by_target(
|
||||
self,
|
||||
target_id: int,
|
||||
filter_query: Optional[str] = None
|
||||
) -> List[Dict]:
|
||||
"""获取目标下的 IP 聚合数据
|
||||
|
||||
Args:
|
||||
target_id: 目标 ID
|
||||
filter_query: 智能过滤语法字符串
|
||||
|
||||
Returns:
|
||||
聚合后的 IP 数据列表
|
||||
"""
|
||||
# 从 Repository 获取基础 QuerySet
|
||||
qs = self.repo.get_queryset_by_target(target_id)
|
||||
|
||||
# Service 层应用过滤逻辑
|
||||
if filter_query:
|
||||
qs = apply_filters(qs, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
|
||||
# Service 层处理聚合逻辑
|
||||
return self._aggregate_by_ip(qs, filter_query, target_id=target_id)
|
||||
|
||||
def get_all_ip_aggregation(self, search: str = None):
|
||||
"""获取所有 IP 聚合数据(全局查询)"""
|
||||
return self.repo.get_all_ip_aggregation(search=search)
|
||||
def get_all_ip_aggregation(self, filter_query: Optional[str] = None) -> List[Dict]:
|
||||
"""获取所有 IP 聚合数据(全局查询)
|
||||
|
||||
Args:
|
||||
filter_query: 智能过滤语法字符串
|
||||
|
||||
Returns:
|
||||
聚合后的 IP 数据列表
|
||||
"""
|
||||
# 从 Repository 获取基础 QuerySet
|
||||
qs = self.repo.get_all_queryset()
|
||||
|
||||
# Service 层应用过滤逻辑
|
||||
if filter_query:
|
||||
qs = apply_filters(qs, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
|
||||
# Service 层处理聚合逻辑
|
||||
return self._aggregate_by_ip(qs, filter_query)
|
||||
|
||||
def _aggregate_by_ip(
|
||||
self,
|
||||
qs,
|
||||
filter_query: Optional[str] = None,
|
||||
target_id: Optional[int] = None
|
||||
) -> List[Dict]:
|
||||
"""按 IP 聚合数据
|
||||
|
||||
Args:
|
||||
qs: 已过滤的 QuerySet
|
||||
filter_query: 过滤条件(用于子查询)
|
||||
target_id: 目标 ID(用于子查询限定范围)
|
||||
|
||||
Returns:
|
||||
聚合后的数据列表
|
||||
"""
|
||||
ip_aggregated = (
|
||||
qs
|
||||
.values('ip')
|
||||
.annotate(created_at=Min('created_at'))
|
||||
.order_by('-created_at')
|
||||
)
|
||||
|
||||
results = []
|
||||
for item in ip_aggregated:
|
||||
ip = item['ip']
|
||||
|
||||
# 获取该 IP 的所有 host 和 port(也需要应用过滤条件)
|
||||
mappings_qs = self.repo.get_queryset_by_ip(ip, target_id=target_id)
|
||||
if filter_query:
|
||||
mappings_qs = apply_filters(mappings_qs, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
|
||||
mappings = mappings_qs.values('host', 'port').distinct()
|
||||
hosts = sorted({m['host'] for m in mappings})
|
||||
ports = sorted({m['port'] for m in mappings})
|
||||
|
||||
results.append({
|
||||
'ip': ip,
|
||||
'hosts': hosts,
|
||||
'ports': ports,
|
||||
'created_at': item['created_at'],
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
def iter_ips_by_target(self, target_id: int, batch_size: int = 1000) -> Iterator[str]:
|
||||
"""流式获取目标下的所有唯一 IP 地址。"""
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import logging
|
||||
from typing import Tuple, List, Dict
|
||||
from typing import List, Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
from apps.asset.repositories import DjangoSubdomainRepository
|
||||
from apps.asset.dtos import SubdomainDTO
|
||||
from apps.common.validators import is_valid_domain
|
||||
from apps.common.utils.filter_utils import apply_filters
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -22,6 +23,11 @@ class BulkCreateResult:
|
||||
class SubdomainService:
|
||||
"""子域名业务逻辑层"""
|
||||
|
||||
# 智能过滤字段映射
|
||||
FILTER_FIELD_MAPPING = {
|
||||
'name': 'name',
|
||||
}
|
||||
|
||||
def __init__(self, repository=None):
|
||||
"""
|
||||
初始化子域名服务
|
||||
@@ -33,30 +39,50 @@ class SubdomainService:
|
||||
|
||||
# ==================== 查询操作 ====================
|
||||
|
||||
def get_all(self):
|
||||
def get_all(self, filter_query: Optional[str] = None):
|
||||
"""
|
||||
获取所有子域名
|
||||
|
||||
Args:
|
||||
filter_query: 智能过滤语法字符串
|
||||
|
||||
Returns:
|
||||
QuerySet: 子域名查询集
|
||||
"""
|
||||
logger.debug("获取所有子域名")
|
||||
return self.repo.get_all()
|
||||
queryset = self.repo.get_all()
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
return queryset
|
||||
|
||||
# ==================== 创建操作 ====================
|
||||
|
||||
def bulk_create_ignore_conflicts(self, items: List[SubdomainDTO]) -> None:
|
||||
def get_subdomains_by_target(self, target_id: int, filter_query: Optional[str] = None):
|
||||
"""
|
||||
批量创建子域名,忽略冲突
|
||||
获取目标下的子域名
|
||||
|
||||
Args:
|
||||
items: 子域名 DTO 列表
|
||||
target_id: 目标 ID
|
||||
filter_query: 智能过滤语法字符串
|
||||
|
||||
Note:
|
||||
使用 ignore_conflicts 策略,重复记录会被跳过
|
||||
Returns:
|
||||
QuerySet: 子域名查询集
|
||||
"""
|
||||
logger.debug("批量创建子域名 - 数量: %d", len(items))
|
||||
return self.repo.bulk_create_ignore_conflicts(items)
|
||||
queryset = self.repo.get_by_target(target_id)
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
return queryset
|
||||
|
||||
def count_subdomains_by_target(self, target_id: int) -> int:
|
||||
"""
|
||||
统计目标下的子域名数量
|
||||
|
||||
Args:
|
||||
target_id: 目标 ID
|
||||
|
||||
Returns:
|
||||
int: 子域名数量
|
||||
"""
|
||||
logger.debug("统计目标下子域名数量 - Target ID: %d", target_id)
|
||||
return self.repo.count_by_target(target_id)
|
||||
|
||||
def get_by_names_and_target_id(self, names: set, target_id: int) -> dict:
|
||||
"""
|
||||
@@ -83,25 +109,8 @@ class SubdomainService:
|
||||
List[str]: 子域名名称列表
|
||||
"""
|
||||
logger.debug("获取目标下所有子域名 - Target ID: %d", target_id)
|
||||
# 通过仓储层统一访问数据库,内部已使用 iterator() 做流式查询
|
||||
return list(self.repo.get_domains_for_export(target_id=target_id))
|
||||
|
||||
def get_subdomains_by_target(self, target_id: int):
|
||||
return self.repo.get_by_target(target_id)
|
||||
|
||||
def count_subdomains_by_target(self, target_id: int) -> int:
|
||||
"""
|
||||
统计目标下的子域名数量
|
||||
|
||||
Args:
|
||||
target_id: 目标 ID
|
||||
|
||||
Returns:
|
||||
int: 子域名数量
|
||||
"""
|
||||
logger.debug("统计目标下子域名数量 - Target ID: %d", target_id)
|
||||
return self.repo.count_by_target(target_id)
|
||||
|
||||
def iter_subdomain_names_by_target(self, target_id: int, chunk_size: int = 1000):
|
||||
"""
|
||||
流式获取目标下的所有子域名名称(内存优化)
|
||||
@@ -114,7 +123,6 @@ class SubdomainService:
|
||||
str: 子域名名称
|
||||
"""
|
||||
logger.debug("流式获取目标下所有子域名 - Target ID: %d, 批次大小: %d", target_id, chunk_size)
|
||||
# 通过仓储层统一访问数据库,内部已使用 iterator() 做流式查询
|
||||
return self.repo.get_domains_for_export(target_id=target_id, batch_size=chunk_size)
|
||||
|
||||
def iter_raw_data_for_csv_export(self, target_id: int):
|
||||
@@ -129,6 +137,21 @@ class SubdomainService:
|
||||
"""
|
||||
return self.repo.iter_raw_data_for_export(target_id=target_id)
|
||||
|
||||
# ==================== 创建操作 ====================
|
||||
|
||||
def bulk_create_ignore_conflicts(self, items: List[SubdomainDTO]) -> None:
|
||||
"""
|
||||
批量创建子域名,忽略冲突
|
||||
|
||||
Args:
|
||||
items: 子域名 DTO 列表
|
||||
|
||||
Note:
|
||||
使用 ignore_conflicts 策略,重复记录会被跳过
|
||||
"""
|
||||
logger.debug("批量创建子域名 - 数量: %d", len(items))
|
||||
return self.repo.bulk_create_ignore_conflicts(items)
|
||||
|
||||
def bulk_create_subdomains(
|
||||
self,
|
||||
target_id: int,
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
"""Vulnerability Service - 漏洞资产业务逻辑层"""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from apps.asset.models import Vulnerability
|
||||
from apps.asset.dtos.asset import VulnerabilityDTO
|
||||
from apps.common.decorators import auto_ensure_db_connection
|
||||
from apps.common.utils import deduplicate_for_bulk
|
||||
from apps.common.utils.filter_utils import apply_filters
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -17,6 +18,14 @@ class VulnerabilityService:
|
||||
|
||||
当前提供基础的批量创建能力,使用 ignore_conflicts 依赖数据库唯一约束去重。
|
||||
"""
|
||||
|
||||
# 智能过滤字段映射
|
||||
FILTER_FIELD_MAPPING = {
|
||||
'type': 'vuln_type',
|
||||
'severity': 'severity',
|
||||
'source': 'source',
|
||||
'url': 'url',
|
||||
}
|
||||
|
||||
def bulk_create_ignore_conflicts(self, items: List[VulnerabilityDTO]) -> None:
|
||||
"""批量创建漏洞资产记录,忽略冲突。
|
||||
@@ -63,24 +72,34 @@ class VulnerabilityService:
|
||||
|
||||
# ==================== 查询方法 ====================
|
||||
|
||||
def get_all(self):
|
||||
def get_all(self, filter_query: Optional[str] = None):
|
||||
"""获取所有漏洞 QuerySet(用于全局漏洞列表)。
|
||||
|
||||
Args:
|
||||
filter_query: 智能过滤语法字符串
|
||||
|
||||
Returns:
|
||||
QuerySet[Vulnerability]: 所有漏洞,按创建时间倒序
|
||||
"""
|
||||
return Vulnerability.objects.all().order_by("-created_at")
|
||||
queryset = Vulnerability.objects.all().order_by("-created_at")
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
return queryset
|
||||
|
||||
def get_vulnerabilities_by_target(self, target_id: int):
|
||||
def get_vulnerabilities_by_target(self, target_id: int, filter_query: Optional[str] = None):
|
||||
"""按目标获取漏洞 QuerySet(用于分页)。
|
||||
|
||||
Args:
|
||||
target_id: 目标 ID
|
||||
filter_query: 智能过滤语法字符串
|
||||
|
||||
Returns:
|
||||
QuerySet[Vulnerability]: 目标下的所有漏洞,按创建时间倒序
|
||||
"""
|
||||
return Vulnerability.objects.filter(target_id=target_id).order_by("-created_at")
|
||||
queryset = Vulnerability.objects.filter(target_id=target_id).order_by("-created_at")
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
return queryset
|
||||
|
||||
def count_by_target(self, target_id: int) -> int:
|
||||
"""统计目标下的漏洞数量。"""
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"""WebSite Service - 网站业务逻辑层"""
|
||||
|
||||
import logging
|
||||
from typing import List, Iterator
|
||||
from typing import List, Iterator, Optional
|
||||
|
||||
from apps.asset.repositories import DjangoWebSiteRepository
|
||||
from apps.asset.dtos import WebSiteDTO
|
||||
from apps.common.validators import is_valid_url, is_url_match_target
|
||||
from apps.common.utils.filter_utils import apply_filters
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -13,6 +14,14 @@ logger = logging.getLogger(__name__)
|
||||
class WebSiteService:
|
||||
"""网站业务逻辑层"""
|
||||
|
||||
# 智能过滤字段映射
|
||||
FILTER_FIELD_MAPPING = {
|
||||
'url': 'url',
|
||||
'host': 'host',
|
||||
'title': 'title',
|
||||
'status': 'status_code',
|
||||
}
|
||||
|
||||
def __init__(self, repository=None):
|
||||
"""初始化网站服务"""
|
||||
self.repo = repository or DjangoWebSiteRepository()
|
||||
@@ -94,13 +103,19 @@ class WebSiteService:
|
||||
count_after = self.repo.count_by_target(target_id)
|
||||
return count_after - count_before
|
||||
|
||||
def get_websites_by_target(self, target_id: int):
|
||||
def get_websites_by_target(self, target_id: int, filter_query: Optional[str] = None):
|
||||
"""获取目标下的所有网站"""
|
||||
return self.repo.get_by_target(target_id)
|
||||
queryset = self.repo.get_by_target(target_id)
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
return queryset
|
||||
|
||||
def get_all(self):
|
||||
def get_all(self, filter_query: Optional[str] = None):
|
||||
"""获取所有网站"""
|
||||
return self.repo.get_all()
|
||||
queryset = self.repo.get_all()
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
return queryset
|
||||
|
||||
def get_by_url(self, url: str, target_id: int) -> int:
|
||||
"""根据 URL 和 target_id 查找网站 ID"""
|
||||
|
||||
@@ -67,12 +67,29 @@ class DirectorySnapshotsService:
|
||||
)
|
||||
raise
|
||||
|
||||
def get_by_scan(self, scan_id: int):
|
||||
return self.snapshot_repo.get_by_scan(scan_id)
|
||||
# 智能过滤字段映射
|
||||
FILTER_FIELD_MAPPING = {
|
||||
'url': 'url',
|
||||
'status': 'status',
|
||||
'content_type': 'content_type',
|
||||
}
|
||||
|
||||
def get_by_scan(self, scan_id: int, filter_query: str = None):
|
||||
from apps.common.utils.filter_utils import apply_filters
|
||||
|
||||
queryset = self.snapshot_repo.get_by_scan(scan_id)
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
return queryset
|
||||
|
||||
def get_all(self):
|
||||
def get_all(self, filter_query: str = None):
|
||||
"""获取所有目录快照"""
|
||||
return self.snapshot_repo.get_all()
|
||||
from apps.common.utils.filter_utils import apply_filters
|
||||
|
||||
queryset = self.snapshot_repo.get_all()
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
return queryset
|
||||
|
||||
def iter_directory_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
|
||||
"""流式获取某次扫描下的所有目录 URL。"""
|
||||
|
||||
@@ -67,12 +67,32 @@ class EndpointSnapshotsService:
|
||||
)
|
||||
raise
|
||||
|
||||
def get_by_scan(self, scan_id: int):
|
||||
return self.snapshot_repo.get_by_scan(scan_id)
|
||||
# 智能过滤字段映射
|
||||
FILTER_FIELD_MAPPING = {
|
||||
'url': 'url',
|
||||
'host': 'host',
|
||||
'title': 'title',
|
||||
'status': 'status_code',
|
||||
'webserver': 'webserver',
|
||||
'tech': 'tech',
|
||||
}
|
||||
|
||||
def get_by_scan(self, scan_id: int, filter_query: str = None):
|
||||
from apps.common.utils.filter_utils import apply_filters
|
||||
|
||||
queryset = self.snapshot_repo.get_by_scan(scan_id)
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
return queryset
|
||||
|
||||
def get_all(self):
|
||||
def get_all(self, filter_query: str = None):
|
||||
"""获取所有端点快照"""
|
||||
return self.snapshot_repo.get_all()
|
||||
from apps.common.utils.filter_utils import apply_filters
|
||||
|
||||
queryset = self.snapshot_repo.get_all()
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
return queryset
|
||||
|
||||
def iter_endpoint_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
|
||||
"""流式获取某次扫描下的所有端点 URL。"""
|
||||
|
||||
@@ -69,12 +69,12 @@ class HostPortMappingSnapshotsService:
|
||||
)
|
||||
raise
|
||||
|
||||
def get_ip_aggregation_by_scan(self, scan_id: int, search: str = None):
|
||||
return self.snapshot_repo.get_ip_aggregation_by_scan(scan_id, search=search)
|
||||
def get_ip_aggregation_by_scan(self, scan_id: int, filter_query: str = None):
|
||||
return self.snapshot_repo.get_ip_aggregation_by_scan(scan_id, filter_query=filter_query)
|
||||
|
||||
def get_all_ip_aggregation(self, search: str = None):
|
||||
def get_all_ip_aggregation(self, filter_query: str = None):
|
||||
"""获取所有 IP 聚合数据"""
|
||||
return self.snapshot_repo.get_all_ip_aggregation(search=search)
|
||||
return self.snapshot_repo.get_all_ip_aggregation(filter_query=filter_query)
|
||||
|
||||
def iter_ips_by_scan(self, scan_id: int, batch_size: int = 1000) -> Iterator[str]:
|
||||
"""流式获取某次扫描下的所有唯一 IP 地址。"""
|
||||
|
||||
@@ -66,12 +66,27 @@ class SubdomainSnapshotsService:
|
||||
)
|
||||
raise
|
||||
|
||||
def get_by_scan(self, scan_id: int):
|
||||
return self.subdomain_snapshot_repo.get_by_scan(scan_id)
|
||||
# 智能过滤字段映射
|
||||
FILTER_FIELD_MAPPING = {
|
||||
'name': 'name',
|
||||
}
|
||||
|
||||
def get_by_scan(self, scan_id: int, filter_query: str = None):
|
||||
from apps.common.utils.filter_utils import apply_filters
|
||||
|
||||
queryset = self.subdomain_snapshot_repo.get_by_scan(scan_id)
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
return queryset
|
||||
|
||||
def get_all(self):
|
||||
def get_all(self, filter_query: str = None):
|
||||
"""获取所有子域名快照"""
|
||||
return self.subdomain_snapshot_repo.get_all()
|
||||
from apps.common.utils.filter_utils import apply_filters
|
||||
|
||||
queryset = self.subdomain_snapshot_repo.get_all()
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
return queryset
|
||||
|
||||
def iter_subdomain_names_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
|
||||
queryset = self.subdomain_snapshot_repo.get_by_scan(scan_id)
|
||||
|
||||
@@ -66,13 +66,31 @@ class VulnerabilitySnapshotsService:
|
||||
)
|
||||
raise
|
||||
|
||||
def get_by_scan(self, scan_id: int):
|
||||
"""按扫描任务获取所有漏洞快照。"""
|
||||
return self.snapshot_repo.get_by_scan(scan_id)
|
||||
# 智能过滤字段映射
|
||||
FILTER_FIELD_MAPPING = {
|
||||
'type': 'vuln_type',
|
||||
'url': 'url',
|
||||
'severity': 'severity',
|
||||
'source': 'source',
|
||||
}
|
||||
|
||||
def get_all(self):
|
||||
def get_by_scan(self, scan_id: int, filter_query: str = None):
|
||||
"""按扫描任务获取所有漏洞快照。"""
|
||||
from apps.common.utils.filter_utils import apply_filters
|
||||
|
||||
queryset = self.snapshot_repo.get_by_scan(scan_id)
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
return queryset
|
||||
|
||||
def get_all(self, filter_query: str = None):
|
||||
"""获取所有漏洞快照"""
|
||||
return self.snapshot_repo.get_all()
|
||||
from apps.common.utils.filter_utils import apply_filters
|
||||
|
||||
queryset = self.snapshot_repo.get_all()
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
return queryset
|
||||
|
||||
def iter_vuln_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
|
||||
"""流式获取某次扫描下的所有漏洞 URL。"""
|
||||
|
||||
@@ -68,12 +68,32 @@ class WebsiteSnapshotsService:
|
||||
)
|
||||
raise
|
||||
|
||||
def get_by_scan(self, scan_id: int):
|
||||
return self.snapshot_repo.get_by_scan(scan_id)
|
||||
# 智能过滤字段映射
|
||||
FILTER_FIELD_MAPPING = {
|
||||
'url': 'url',
|
||||
'host': 'host',
|
||||
'title': 'title',
|
||||
'status': 'status',
|
||||
'webserver': 'web_server',
|
||||
'tech': 'tech',
|
||||
}
|
||||
|
||||
def get_by_scan(self, scan_id: int, filter_query: str = None):
|
||||
from apps.common.utils.filter_utils import apply_filters
|
||||
|
||||
queryset = self.snapshot_repo.get_by_scan(scan_id)
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
return queryset
|
||||
|
||||
def get_all(self):
|
||||
def get_all(self, filter_query: str = None):
|
||||
"""获取所有网站快照"""
|
||||
return self.snapshot_repo.get_all()
|
||||
from apps.common.utils.filter_utils import apply_filters
|
||||
|
||||
queryset = self.snapshot_repo.get_all()
|
||||
if filter_query:
|
||||
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
|
||||
return queryset
|
||||
|
||||
def iter_website_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
|
||||
"""流式获取某次扫描下的所有站点 URL(按创建时间倒序)。"""
|
||||
|
||||
@@ -126,12 +126,16 @@ class SubdomainViewSet(viewsets.ModelViewSet):
|
||||
支持两种访问方式:
|
||||
1. 嵌套路由:GET /api/targets/{target_pk}/subdomains/
|
||||
2. 独立路由:GET /api/subdomains/(全局查询)
|
||||
|
||||
支持智能过滤语法(filter 参数):
|
||||
- name="api" 子域名模糊匹配
|
||||
- name=="api.example.com" 精确匹配
|
||||
- 多条件空格分隔 AND 关系
|
||||
"""
|
||||
|
||||
serializer_class = SubdomainListSerializer
|
||||
pagination_class = BasePagination
|
||||
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
|
||||
search_fields = ['name']
|
||||
filter_backends = [filters.OrderingFilter]
|
||||
ordering = ['-created_at']
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -139,11 +143,13 @@ class SubdomainViewSet(viewsets.ModelViewSet):
|
||||
self.service = SubdomainService()
|
||||
|
||||
def get_queryset(self):
|
||||
"""根据是否有 target_pk 参数决定查询范围"""
|
||||
"""根据是否有 target_pk 参数决定查询范围,支持智能过滤"""
|
||||
target_pk = self.kwargs.get('target_pk')
|
||||
filter_query = self.request.query_params.get('filter', None)
|
||||
|
||||
if target_pk:
|
||||
return self.service.get_subdomains_by_target(target_pk)
|
||||
return self.service.get_all()
|
||||
return self.service.get_subdomains_by_target(target_pk, filter_query=filter_query)
|
||||
return self.service.get_all(filter_query=filter_query)
|
||||
|
||||
@action(detail=False, methods=['post'], url_path='bulk-create')
|
||||
def bulk_create(self, request, **kwargs):
|
||||
@@ -253,12 +259,18 @@ class WebSiteViewSet(viewsets.ModelViewSet):
|
||||
支持两种访问方式:
|
||||
1. 嵌套路由:GET /api/targets/{target_pk}/websites/
|
||||
2. 独立路由:GET /api/websites/(全局查询)
|
||||
|
||||
支持智能过滤语法(filter 参数):
|
||||
- url="api" URL 模糊匹配
|
||||
- host="example" 主机名模糊匹配
|
||||
- title="login" 标题模糊匹配
|
||||
- status="200,301" 状态码多值匹配
|
||||
- 多条件空格分隔 AND 关系
|
||||
"""
|
||||
|
||||
serializer_class = WebSiteSerializer
|
||||
pagination_class = BasePagination
|
||||
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
|
||||
search_fields = ['host']
|
||||
filter_backends = [filters.OrderingFilter]
|
||||
ordering = ['-created_at']
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -266,11 +278,13 @@ class WebSiteViewSet(viewsets.ModelViewSet):
|
||||
self.service = WebSiteService()
|
||||
|
||||
def get_queryset(self):
|
||||
"""根据是否有 target_pk 参数决定查询范围"""
|
||||
"""根据是否有 target_pk 参数决定查询范围,支持智能过滤"""
|
||||
target_pk = self.kwargs.get('target_pk')
|
||||
filter_query = self.request.query_params.get('filter', None)
|
||||
|
||||
if target_pk:
|
||||
return self.service.get_websites_by_target(target_pk)
|
||||
return self.service.get_all()
|
||||
return self.service.get_websites_by_target(target_pk, filter_query=filter_query)
|
||||
return self.service.get_all(filter_query=filter_query)
|
||||
|
||||
@action(detail=False, methods=['post'], url_path='bulk-create')
|
||||
def bulk_create(self, request, **kwargs):
|
||||
@@ -374,12 +388,16 @@ class DirectoryViewSet(viewsets.ModelViewSet):
|
||||
支持两种访问方式:
|
||||
1. 嵌套路由:GET /api/targets/{target_pk}/directories/
|
||||
2. 独立路由:GET /api/directories/(全局查询)
|
||||
|
||||
支持智能过滤语法(filter 参数):
|
||||
- url="admin" URL 模糊匹配
|
||||
- status="200,301" 状态码多值匹配
|
||||
- 多条件空格分隔 AND 关系
|
||||
"""
|
||||
|
||||
serializer_class = DirectorySerializer
|
||||
pagination_class = BasePagination
|
||||
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
|
||||
search_fields = ['url']
|
||||
filter_backends = [filters.OrderingFilter]
|
||||
ordering = ['-created_at']
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -387,11 +405,13 @@ class DirectoryViewSet(viewsets.ModelViewSet):
|
||||
self.service = DirectoryService()
|
||||
|
||||
def get_queryset(self):
|
||||
"""根据是否有 target_pk 参数决定查询范围"""
|
||||
"""根据是否有 target_pk 参数决定查询范围,支持智能过滤"""
|
||||
target_pk = self.kwargs.get('target_pk')
|
||||
filter_query = self.request.query_params.get('filter', None)
|
||||
|
||||
if target_pk:
|
||||
return self.service.get_directories_by_target(target_pk)
|
||||
return self.service.get_all()
|
||||
return self.service.get_directories_by_target(target_pk, filter_query=filter_query)
|
||||
return self.service.get_all(filter_query=filter_query)
|
||||
|
||||
@action(detail=False, methods=['post'], url_path='bulk-create')
|
||||
def bulk_create(self, request, **kwargs):
|
||||
@@ -493,12 +513,18 @@ class EndpointViewSet(viewsets.ModelViewSet):
|
||||
支持两种访问方式:
|
||||
1. 嵌套路由:GET /api/targets/{target_pk}/endpoints/
|
||||
2. 独立路由:GET /api/endpoints/(全局查询)
|
||||
|
||||
支持智能过滤语法(filter 参数):
|
||||
- url="api" URL 模糊匹配
|
||||
- host="example" 主机名模糊匹配
|
||||
- title="login" 标题模糊匹配
|
||||
- status="200,301" 状态码多值匹配
|
||||
- 多条件空格分隔 AND 关系
|
||||
"""
|
||||
|
||||
serializer_class = EndpointListSerializer
|
||||
pagination_class = BasePagination
|
||||
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
|
||||
search_fields = ['host']
|
||||
filter_backends = [filters.OrderingFilter]
|
||||
ordering = ['-created_at']
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -506,11 +532,13 @@ class EndpointViewSet(viewsets.ModelViewSet):
|
||||
self.service = EndpointService()
|
||||
|
||||
def get_queryset(self):
|
||||
"""根据是否有 target_pk 参数决定查询范围"""
|
||||
"""根据是否有 target_pk 参数决定查询范围,支持智能过滤"""
|
||||
target_pk = self.kwargs.get('target_pk')
|
||||
filter_query = self.request.query_params.get('filter', None)
|
||||
|
||||
if target_pk:
|
||||
return self.service.get_endpoints_by_target(target_pk)
|
||||
return self.service.get_all()
|
||||
return self.service.get_endpoints_by_target(target_pk, filter_query=filter_query)
|
||||
return self.service.get_all(filter_query=filter_query)
|
||||
|
||||
@action(detail=False, methods=['post'], url_path='bulk-create')
|
||||
def bulk_create(self, request, **kwargs):
|
||||
@@ -618,23 +646,40 @@ class HostPortMappingViewSet(viewsets.ModelViewSet):
|
||||
|
||||
返回按 IP 聚合的数据,每个 IP 显示其关联的所有 hosts 和 ports
|
||||
|
||||
支持智能过滤语法(filter 参数):
|
||||
- ip="192.168" IP 模糊匹配
|
||||
- port="80,443" 端口多值匹配
|
||||
- host="api" 主机名模糊匹配
|
||||
- 多条件空格分隔 AND 关系
|
||||
|
||||
注意:由于返回的是聚合数据(字典列表),不支持 DRF SearchFilter
|
||||
"""
|
||||
|
||||
serializer_class = IPAddressAggregatedSerializer
|
||||
pagination_class = BasePagination
|
||||
|
||||
# 智能过滤字段映射
|
||||
FILTER_FIELD_MAPPING = {
|
||||
'ip': 'ip',
|
||||
'port': 'port',
|
||||
'host': 'host',
|
||||
}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.service = HostPortMappingService()
|
||||
|
||||
def get_queryset(self):
|
||||
"""根据是否有 target_pk 参数决定查询范围,返回按 IP 聚合的数据"""
|
||||
"""根据是否有 target_pk 参数决定查询范围,返回按 IP 聚合的数据
|
||||
|
||||
支持智能过滤语法(filter 参数)
|
||||
"""
|
||||
target_pk = self.kwargs.get('target_pk')
|
||||
search = self.request.query_params.get('search', None)
|
||||
filter_query = self.request.query_params.get('filter', None)
|
||||
|
||||
if target_pk:
|
||||
return self.service.get_ip_aggregation_by_target(target_pk, search=search)
|
||||
return self.service.get_all_ip_aggregation(search=search)
|
||||
return self.service.get_ip_aggregation_by_target(target_pk, filter_query=filter_query)
|
||||
return self.service.get_all_ip_aggregation(filter_query=filter_query)
|
||||
|
||||
@action(detail=False, methods=['get'], url_path='export')
|
||||
def export(self, request, **kwargs):
|
||||
@@ -673,12 +718,18 @@ class VulnerabilityViewSet(viewsets.ModelViewSet):
|
||||
支持两种访问方式:
|
||||
1. 嵌套路由:GET /api/targets/{target_pk}/vulnerabilities/
|
||||
2. 独立路由:GET /api/vulnerabilities/(全局查询)
|
||||
|
||||
支持智能过滤语法(filter 参数):
|
||||
- type="xss" 漏洞类型模糊匹配
|
||||
- severity="high" 严重程度匹配
|
||||
- source="nuclei" 来源工具匹配
|
||||
- url="api" URL 模糊匹配
|
||||
- 多条件空格分隔 AND 关系
|
||||
"""
|
||||
|
||||
serializer_class = VulnerabilitySerializer
|
||||
pagination_class = BasePagination
|
||||
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
|
||||
search_fields = ['vuln_type']
|
||||
filter_backends = [filters.OrderingFilter]
|
||||
ordering = ['-created_at']
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -686,22 +737,29 @@ class VulnerabilityViewSet(viewsets.ModelViewSet):
|
||||
self.service = VulnerabilityService()
|
||||
|
||||
def get_queryset(self):
|
||||
"""根据是否有 target_pk 参数决定查询范围"""
|
||||
"""根据是否有 target_pk 参数决定查询范围,支持智能过滤"""
|
||||
target_pk = self.kwargs.get('target_pk')
|
||||
filter_query = self.request.query_params.get('filter', None)
|
||||
|
||||
if target_pk:
|
||||
return self.service.get_vulnerabilities_by_target(target_pk)
|
||||
return self.service.get_all()
|
||||
return self.service.get_vulnerabilities_by_target(target_pk, filter_query=filter_query)
|
||||
return self.service.get_all(filter_query=filter_query)
|
||||
|
||||
|
||||
# ==================== 快照 ViewSet(Scan 嵌套路由) ====================
|
||||
|
||||
class SubdomainSnapshotViewSet(viewsets.ModelViewSet):
|
||||
"""子域名快照 ViewSet - 嵌套路由:GET /api/scans/{scan_pk}/subdomains/"""
|
||||
"""子域名快照 ViewSet - 嵌套路由:GET /api/scans/{scan_pk}/subdomains/
|
||||
|
||||
支持智能过滤语法(filter 参数):
|
||||
- name="api" 子域名模糊匹配
|
||||
- name=="api.example.com" 精确匹配
|
||||
- name!="test" 排除匹配
|
||||
"""
|
||||
|
||||
serializer_class = SubdomainSnapshotSerializer
|
||||
pagination_class = BasePagination
|
||||
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
|
||||
search_fields = ['name']
|
||||
filter_backends = [filters.OrderingFilter]
|
||||
ordering_fields = ['name', 'created_at']
|
||||
ordering = ['-created_at']
|
||||
|
||||
@@ -711,9 +769,11 @@ class SubdomainSnapshotViewSet(viewsets.ModelViewSet):
|
||||
|
||||
def get_queryset(self):
|
||||
scan_pk = self.kwargs.get('scan_pk')
|
||||
filter_query = self.request.query_params.get('filter', None)
|
||||
|
||||
if scan_pk:
|
||||
return self.service.get_by_scan(scan_pk)
|
||||
return self.service.get_all()
|
||||
return self.service.get_by_scan(scan_pk, filter_query=filter_query)
|
||||
return self.service.get_all(filter_query=filter_query)
|
||||
|
||||
@action(detail=False, methods=['get'], url_path='export')
|
||||
def export(self, request, **kwargs):
|
||||
@@ -741,12 +801,20 @@ class SubdomainSnapshotViewSet(viewsets.ModelViewSet):
|
||||
|
||||
|
||||
class WebsiteSnapshotViewSet(viewsets.ModelViewSet):
|
||||
"""网站快照 ViewSet - 嵌套路由:GET /api/scans/{scan_pk}/websites/"""
|
||||
"""网站快照 ViewSet - 嵌套路由:GET /api/scans/{scan_pk}/websites/
|
||||
|
||||
支持智能过滤语法(filter 参数):
|
||||
- url="api" URL 模糊匹配
|
||||
- host="example" 主机名模糊匹配
|
||||
- title="login" 标题模糊匹配
|
||||
- status="200" 状态码匹配
|
||||
- webserver="nginx" 服务器类型匹配
|
||||
- tech="php" 技术栈匹配
|
||||
"""
|
||||
|
||||
serializer_class = WebsiteSnapshotSerializer
|
||||
pagination_class = BasePagination
|
||||
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
|
||||
search_fields = ['host']
|
||||
filter_backends = [filters.OrderingFilter]
|
||||
ordering = ['-created_at']
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -755,9 +823,11 @@ class WebsiteSnapshotViewSet(viewsets.ModelViewSet):
|
||||
|
||||
def get_queryset(self):
|
||||
scan_pk = self.kwargs.get('scan_pk')
|
||||
filter_query = self.request.query_params.get('filter', None)
|
||||
|
||||
if scan_pk:
|
||||
return self.service.get_by_scan(scan_pk)
|
||||
return self.service.get_all()
|
||||
return self.service.get_by_scan(scan_pk, filter_query=filter_query)
|
||||
return self.service.get_all(filter_query=filter_query)
|
||||
|
||||
@action(detail=False, methods=['get'], url_path='export')
|
||||
def export(self, request, **kwargs):
|
||||
@@ -792,12 +862,17 @@ class WebsiteSnapshotViewSet(viewsets.ModelViewSet):
|
||||
|
||||
|
||||
class DirectorySnapshotViewSet(viewsets.ModelViewSet):
|
||||
"""目录快照 ViewSet - 嵌套路由:GET /api/scans/{scan_pk}/directories/"""
|
||||
"""目录快照 ViewSet - 嵌套路由:GET /api/scans/{scan_pk}/directories/
|
||||
|
||||
支持智能过滤语法(filter 参数):
|
||||
- url="admin" URL 模糊匹配
|
||||
- status="200" 状态码匹配
|
||||
- content_type="html" 内容类型匹配
|
||||
"""
|
||||
|
||||
serializer_class = DirectorySnapshotSerializer
|
||||
pagination_class = BasePagination
|
||||
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
|
||||
search_fields = ['url']
|
||||
filter_backends = [filters.OrderingFilter]
|
||||
ordering = ['-created_at']
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -806,9 +881,11 @@ class DirectorySnapshotViewSet(viewsets.ModelViewSet):
|
||||
|
||||
def get_queryset(self):
|
||||
scan_pk = self.kwargs.get('scan_pk')
|
||||
filter_query = self.request.query_params.get('filter', None)
|
||||
|
||||
if scan_pk:
|
||||
return self.service.get_by_scan(scan_pk)
|
||||
return self.service.get_all()
|
||||
return self.service.get_by_scan(scan_pk, filter_query=filter_query)
|
||||
return self.service.get_all(filter_query=filter_query)
|
||||
|
||||
@action(detail=False, methods=['get'], url_path='export')
|
||||
def export(self, request, **kwargs):
|
||||
@@ -841,12 +918,20 @@ class DirectorySnapshotViewSet(viewsets.ModelViewSet):
|
||||
|
||||
|
||||
class EndpointSnapshotViewSet(viewsets.ModelViewSet):
|
||||
"""端点快照 ViewSet - 嵌套路由:GET /api/scans/{scan_pk}/endpoints/"""
|
||||
"""端点快照 ViewSet - 嵌套路由:GET /api/scans/{scan_pk}/endpoints/
|
||||
|
||||
支持智能过滤语法(filter 参数):
|
||||
- url="api" URL 模糊匹配
|
||||
- host="example" 主机名模糊匹配
|
||||
- title="login" 标题模糊匹配
|
||||
- status="200" 状态码匹配
|
||||
- webserver="nginx" 服务器类型匹配
|
||||
- tech="php" 技术栈匹配
|
||||
"""
|
||||
|
||||
serializer_class = EndpointSnapshotSerializer
|
||||
pagination_class = BasePagination
|
||||
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
|
||||
search_fields = ['host']
|
||||
filter_backends = [filters.OrderingFilter]
|
||||
ordering = ['-created_at']
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -855,9 +940,11 @@ class EndpointSnapshotViewSet(viewsets.ModelViewSet):
|
||||
|
||||
def get_queryset(self):
|
||||
scan_pk = self.kwargs.get('scan_pk')
|
||||
filter_query = self.request.query_params.get('filter', None)
|
||||
|
||||
if scan_pk:
|
||||
return self.service.get_by_scan(scan_pk)
|
||||
return self.service.get_all()
|
||||
return self.service.get_by_scan(scan_pk, filter_query=filter_query)
|
||||
return self.service.get_all(filter_query=filter_query)
|
||||
|
||||
@action(detail=False, methods=['get'], url_path='export')
|
||||
def export(self, request, **kwargs):
|
||||
@@ -895,7 +982,12 @@ class EndpointSnapshotViewSet(viewsets.ModelViewSet):
|
||||
class HostPortMappingSnapshotViewSet(viewsets.ModelViewSet):
|
||||
"""主机端口映射快照 ViewSet - 嵌套路由:GET /api/scans/{scan_pk}/ip-addresses/
|
||||
|
||||
注意:由于返回的是聚合数据(字典列表),不支持 DRF SearchFilter
|
||||
支持智能过滤语法(filter 参数):
|
||||
- ip="192.168" IP 模糊匹配
|
||||
- port="80" 端口匹配
|
||||
- host="api" 主机名模糊匹配
|
||||
|
||||
注意:由于返回的是聚合数据(字典列表),过滤在 Service 层处理
|
||||
"""
|
||||
|
||||
serializer_class = IPAddressAggregatedSerializer
|
||||
@@ -907,10 +999,11 @@ class HostPortMappingSnapshotViewSet(viewsets.ModelViewSet):
|
||||
|
||||
def get_queryset(self):
|
||||
scan_pk = self.kwargs.get('scan_pk')
|
||||
search = self.request.query_params.get('search', None)
|
||||
filter_query = self.request.query_params.get('filter', None)
|
||||
|
||||
if scan_pk:
|
||||
return self.service.get_ip_aggregation_by_scan(scan_pk, search=search)
|
||||
return self.service.get_all_ip_aggregation(search=search)
|
||||
return self.service.get_ip_aggregation_by_scan(scan_pk, filter_query=filter_query)
|
||||
return self.service.get_all_ip_aggregation(filter_query=filter_query)
|
||||
|
||||
@action(detail=False, methods=['get'], url_path='export')
|
||||
def export(self, request, **kwargs):
|
||||
@@ -944,12 +1037,18 @@ class HostPortMappingSnapshotViewSet(viewsets.ModelViewSet):
|
||||
|
||||
|
||||
class VulnerabilitySnapshotViewSet(viewsets.ModelViewSet):
|
||||
"""漏洞快照 ViewSet - 嵌套路由:GET /api/scans/{scan_pk}/vulnerabilities/"""
|
||||
"""漏洞快照 ViewSet - 嵌套路由:GET /api/scans/{scan_pk}/vulnerabilities/
|
||||
|
||||
支持智能过滤语法(filter 参数):
|
||||
- type="xss" 漏洞类型模糊匹配
|
||||
- url="api" URL 模糊匹配
|
||||
- severity="high" 严重程度匹配
|
||||
- source="nuclei" 来源工具匹配
|
||||
"""
|
||||
|
||||
serializer_class = VulnerabilitySnapshotSerializer
|
||||
pagination_class = BasePagination
|
||||
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
|
||||
search_fields = ['vuln_type']
|
||||
filter_backends = [filters.OrderingFilter]
|
||||
ordering = ['-created_at']
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -958,6 +1057,8 @@ class VulnerabilitySnapshotViewSet(viewsets.ModelViewSet):
|
||||
|
||||
def get_queryset(self):
|
||||
scan_pk = self.kwargs.get('scan_pk')
|
||||
filter_query = self.request.query_params.get('filter', None)
|
||||
|
||||
if scan_pk:
|
||||
return self.service.get_by_scan(scan_pk)
|
||||
return self.service.get_all()
|
||||
return self.service.get_by_scan(scan_pk, filter_query=filter_query)
|
||||
return self.service.get_all(filter_query=filter_query)
|
||||
|
||||
@@ -3,8 +3,13 @@
|
||||
|
||||
提供系统级别的公共服务,包括:
|
||||
- SystemLogService: 系统日志读取服务
|
||||
|
||||
注意:FilterService 已移至 apps.common.utils.filter_utils
|
||||
推荐使用: from apps.common.utils.filter_utils import apply_filters
|
||||
"""
|
||||
|
||||
from .system_log_service import SystemLogService
|
||||
|
||||
__all__ = ['SystemLogService']
|
||||
__all__ = [
|
||||
'SystemLogService',
|
||||
]
|
||||
|
||||
260
backend/apps/common/utils/filter_utils.py
Normal file
260
backend/apps/common/utils/filter_utils.py
Normal file
@@ -0,0 +1,260 @@
|
||||
"""智能过滤工具 - 通用查询语法解析和 Django ORM 查询构建
|
||||
|
||||
支持的语法:
|
||||
- field="value" 模糊匹配(包含)
|
||||
- field=="value" 精确匹配
|
||||
- field!="value" 不等于
|
||||
|
||||
逻辑运算符:
|
||||
- AND: && 或 and 或 空格(默认)
|
||||
- OR: || 或 or
|
||||
|
||||
示例:
|
||||
type="xss" || type="sqli" # OR
|
||||
type="xss" or type="sqli" # OR(等价)
|
||||
severity="high" && source="nuclei" # AND
|
||||
severity="high" source="nuclei" # AND(空格默认为 AND)
|
||||
severity="high" and source="nuclei" # AND(等价)
|
||||
|
||||
使用示例:
|
||||
from apps.common.utils.filter_utils import apply_filters
|
||||
|
||||
field_mapping = {'ip': 'ip', 'port': 'port', 'host': 'host'}
|
||||
queryset = apply_filters(queryset, 'ip="192" || port="80"', field_mapping)
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Optional, Union
|
||||
from enum import Enum
|
||||
|
||||
from django.db.models import QuerySet, Q
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LogicalOp(Enum):
|
||||
"""逻辑运算符"""
|
||||
AND = 'AND'
|
||||
OR = 'OR'
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedFilter:
|
||||
"""解析后的过滤条件"""
|
||||
field: str # 字段名
|
||||
operator: str # 操作符: '=', '==', '!='
|
||||
value: str # 原始值
|
||||
|
||||
|
||||
@dataclass
|
||||
class FilterGroup:
|
||||
"""过滤条件组(带逻辑运算符)"""
|
||||
filter: ParsedFilter
|
||||
logical_op: LogicalOp # 与前一个条件的逻辑关系
|
||||
|
||||
|
||||
class QueryParser:
|
||||
"""查询语法解析器
|
||||
|
||||
支持 ||/or (OR) 和 &&/and/空格 (AND) 逻辑运算符
|
||||
"""
|
||||
|
||||
# 正则匹配: field="value", field=="value", field!="value"
|
||||
FILTER_PATTERN = re.compile(r'(\w+)(==|!=|=)"([^"]*)"')
|
||||
|
||||
# 逻辑运算符模式(带空格)
|
||||
OR_PATTERN = re.compile(r'\s*(\|\||(?<![a-zA-Z])or(?![a-zA-Z]))\s*', re.IGNORECASE)
|
||||
AND_PATTERN = re.compile(r'\s*(&&|(?<![a-zA-Z])and(?![a-zA-Z]))\s*', re.IGNORECASE)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, query_string: str) -> List[FilterGroup]:
|
||||
"""解析查询语法字符串
|
||||
|
||||
Args:
|
||||
query_string: 查询语法字符串
|
||||
|
||||
Returns:
|
||||
解析后的过滤条件组列表
|
||||
|
||||
Examples:
|
||||
>>> QueryParser.parse('type="xss" || type="sqli"')
|
||||
[FilterGroup(filter=..., logical_op=AND), # 第一个默认 AND
|
||||
FilterGroup(filter=..., logical_op=OR)]
|
||||
"""
|
||||
if not query_string or not query_string.strip():
|
||||
return []
|
||||
|
||||
# 标准化逻辑运算符
|
||||
# 先处理 || 和 or -> __OR__
|
||||
normalized = cls.OR_PATTERN.sub(' __OR__ ', query_string)
|
||||
# 再处理 && 和 and -> __AND__
|
||||
normalized = cls.AND_PATTERN.sub(' __AND__ ', normalized)
|
||||
|
||||
# 分词:按空格分割,保留逻辑运算符标记
|
||||
tokens = normalized.split()
|
||||
|
||||
groups = []
|
||||
pending_op = LogicalOp.AND # 默认 AND
|
||||
|
||||
for token in tokens:
|
||||
if token == '__OR__':
|
||||
pending_op = LogicalOp.OR
|
||||
elif token == '__AND__':
|
||||
pending_op = LogicalOp.AND
|
||||
else:
|
||||
# 尝试解析为过滤条件
|
||||
match = cls.FILTER_PATTERN.match(token)
|
||||
if match:
|
||||
field, operator, value = match.groups()
|
||||
groups.append(FilterGroup(
|
||||
filter=ParsedFilter(
|
||||
field=field.lower(),
|
||||
operator=operator,
|
||||
value=value
|
||||
),
|
||||
logical_op=pending_op if groups else LogicalOp.AND # 第一个条件默认 AND
|
||||
))
|
||||
pending_op = LogicalOp.AND # 重置为默认 AND
|
||||
|
||||
return groups
|
||||
|
||||
|
||||
class QueryBuilder:
|
||||
"""Django ORM 查询构建器
|
||||
|
||||
将解析后的过滤条件转换为 Django ORM 查询,支持 AND/OR 逻辑
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def build_query(
|
||||
cls,
|
||||
queryset: QuerySet,
|
||||
filter_groups: List[FilterGroup],
|
||||
field_mapping: Dict[str, str]
|
||||
) -> QuerySet:
|
||||
"""构建 Django ORM 查询
|
||||
|
||||
Args:
|
||||
queryset: Django QuerySet
|
||||
filter_groups: 解析后的过滤条件组列表
|
||||
field_mapping: 字段映射
|
||||
|
||||
Returns:
|
||||
过滤后的 QuerySet
|
||||
"""
|
||||
if not filter_groups:
|
||||
return queryset
|
||||
|
||||
# 构建 Q 对象
|
||||
combined_q = None
|
||||
|
||||
for group in filter_groups:
|
||||
f = group.filter
|
||||
|
||||
# 字段映射
|
||||
db_field = field_mapping.get(f.field)
|
||||
if not db_field:
|
||||
logger.debug(f"忽略未知字段: {f.field}")
|
||||
continue
|
||||
|
||||
# 构建单个条件的 Q 对象
|
||||
q = cls._build_single_q(db_field, f.operator, f.value)
|
||||
if q is None:
|
||||
continue
|
||||
|
||||
# 组合 Q 对象
|
||||
if combined_q is None:
|
||||
combined_q = q
|
||||
elif group.logical_op == LogicalOp.OR:
|
||||
combined_q = combined_q | q
|
||||
else: # AND
|
||||
combined_q = combined_q & q
|
||||
|
||||
if combined_q is not None:
|
||||
return queryset.filter(combined_q)
|
||||
return queryset
|
||||
|
||||
@classmethod
|
||||
def _build_single_q(cls, field: str, operator: str, value: str) -> Optional[Q]:
|
||||
"""构建单个条件的 Q 对象"""
|
||||
if operator == '!=':
|
||||
return cls._build_not_equal_q(field, value)
|
||||
elif operator == '==':
|
||||
return cls._build_exact_q(field, value)
|
||||
else: # '='
|
||||
return cls._build_fuzzy_q(field, value)
|
||||
|
||||
@classmethod
|
||||
def _try_convert_to_int(cls, value: str) -> Optional[int]:
|
||||
"""尝试将值转换为整数"""
|
||||
try:
|
||||
return int(value.strip())
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _build_fuzzy_q(cls, field: str, value: str) -> Q:
|
||||
"""模糊匹配: 包含"""
|
||||
return Q(**{f'{field}__icontains': value})
|
||||
|
||||
@classmethod
|
||||
def _build_exact_q(cls, field: str, value: str) -> Q:
|
||||
"""精确匹配"""
|
||||
int_val = cls._try_convert_to_int(value)
|
||||
if int_val is not None:
|
||||
return Q(**{f'{field}__exact': int_val})
|
||||
return Q(**{f'{field}__exact': value})
|
||||
|
||||
@classmethod
|
||||
def _build_not_equal_q(cls, field: str, value: str) -> Q:
|
||||
"""不等于"""
|
||||
int_val = cls._try_convert_to_int(value)
|
||||
if int_val is not None:
|
||||
return ~Q(**{f'{field}__exact': int_val})
|
||||
return ~Q(**{f'{field}__exact': value})
|
||||
|
||||
|
||||
def apply_filters(
|
||||
queryset: QuerySet,
|
||||
query_string: str,
|
||||
field_mapping: Dict[str, str]
|
||||
) -> QuerySet:
|
||||
"""应用过滤条件到 QuerySet
|
||||
|
||||
Args:
|
||||
queryset: Django QuerySet
|
||||
query_string: 查询语法字符串
|
||||
field_mapping: 字段映射
|
||||
|
||||
Returns:
|
||||
过滤后的 QuerySet
|
||||
|
||||
Examples:
|
||||
# OR 查询
|
||||
apply_filters(qs, 'type="xss" || type="sqli"', mapping)
|
||||
apply_filters(qs, 'type="xss" or type="sqli"', mapping)
|
||||
|
||||
# AND 查询
|
||||
apply_filters(qs, 'severity="high" && source="nuclei"', mapping)
|
||||
apply_filters(qs, 'severity="high" source="nuclei"', mapping)
|
||||
|
||||
# 混合查询
|
||||
apply_filters(qs, 'type="xss" || type="sqli" && severity="high"', mapping)
|
||||
"""
|
||||
if not query_string or not query_string.strip():
|
||||
return queryset
|
||||
|
||||
try:
|
||||
filter_groups = QueryParser.parse(query_string)
|
||||
if not filter_groups:
|
||||
logger.debug(f"未解析到有效过滤条件: {query_string}")
|
||||
return queryset
|
||||
|
||||
logger.debug(f"解析过滤条件: {filter_groups}")
|
||||
return QueryBuilder.build_query(queryset, filter_groups, field_mapping)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"过滤解析错误: {e}, query: {query_string}")
|
||||
return queryset # 静默降级
|
||||
@@ -8,7 +8,7 @@ export default function ScanHistoryVulnerabilitiesPage() {
|
||||
const { id } = useParams<{ id: string }>()
|
||||
|
||||
return (
|
||||
<div className="relative flex flex-col gap-4 overflow-auto px-4 lg:px-6">
|
||||
<div className="px-4 lg:px-6">
|
||||
<VulnerabilitiesDetailView scanId={Number(id)} />
|
||||
</div>
|
||||
)
|
||||
|
||||
@@ -12,7 +12,7 @@ export default function TargetVulnerabilitiesPage() {
|
||||
const { id } = useParams<{ id: string }>()
|
||||
|
||||
return (
|
||||
<div className="relative flex flex-col gap-4 overflow-auto px-4 lg:px-6">
|
||||
<div className="px-4 lg:px-6">
|
||||
<VulnerabilitiesDetailView targetId={parseInt(id)} />
|
||||
</div>
|
||||
)
|
||||
|
||||
330
frontend/components/common/smart-filter-input.tsx
Normal file
330
frontend/components/common/smart-filter-input.tsx
Normal file
@@ -0,0 +1,330 @@
|
||||
"use client"
|
||||
|
||||
import * as React from "react"
|
||||
import { IconSearch } from "@tabler/icons-react"
|
||||
import {
|
||||
Command,
|
||||
CommandEmpty,
|
||||
CommandGroup,
|
||||
CommandItem,
|
||||
CommandList,
|
||||
} from "@/components/ui/command"
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverAnchor,
|
||||
} from "@/components/ui/popover"
|
||||
import { Button } from "@/components/ui/button"
|
||||
import { Badge } from "@/components/ui/badge"
|
||||
import { Input } from "@/components/ui/input"
|
||||
|
||||
// 可用的筛选字段定义
|
||||
export interface FilterField {
|
||||
key: string
|
||||
label: string
|
||||
description: string
|
||||
}
|
||||
|
||||
// 预定义的字段配置,各页面可以选择使用
|
||||
export const PREDEFINED_FIELDS: Record<string, FilterField> = {
|
||||
ip: { key: "ip", label: "IP", description: "IP address" },
|
||||
port: { key: "port", label: "Port", description: "Port number" },
|
||||
host: { key: "host", label: "Host", description: "Hostname" },
|
||||
domain: { key: "domain", label: "Domain", description: "Domain name" },
|
||||
url: { key: "url", label: "URL", description: "Full URL" },
|
||||
status: { key: "status", label: "Status", description: "HTTP status code" },
|
||||
title: { key: "title", label: "Title", description: "Page title" },
|
||||
source: { key: "source", label: "Source", description: "Data source" },
|
||||
path: { key: "path", label: "Path", description: "URL path" },
|
||||
severity: { key: "severity", label: "Severity", description: "Vulnerability severity" },
|
||||
name: { key: "name", label: "Name", description: "Name" },
|
||||
type: { key: "type", label: "Type", description: "Type" },
|
||||
}
|
||||
|
||||
// 默认字段(IP Addresses 页面)
|
||||
const DEFAULT_FIELDS: FilterField[] = [
|
||||
PREDEFINED_FIELDS.ip,
|
||||
PREDEFINED_FIELDS.port,
|
||||
PREDEFINED_FIELDS.host,
|
||||
]
|
||||
|
||||
// 解析筛选表达式 (FOFA 风格)
|
||||
interface ParsedFilter {
|
||||
field: string
|
||||
operator: string
|
||||
value: string
|
||||
raw: string
|
||||
}
|
||||
|
||||
function parseFilterExpression(input: string): ParsedFilter[] {
|
||||
const filters: ParsedFilter[] = []
|
||||
// 匹配 FOFA 风格: field="value", field=="value", field!="value"
|
||||
// == 精确匹配, = 模糊匹配, != 不等于
|
||||
// 支持逗号分隔多值: port="80,443,8080"
|
||||
const regex = /(\w+)(==|!=|=)"([^"]+)"/g
|
||||
let match
|
||||
|
||||
while ((match = regex.exec(input)) !== null) {
|
||||
const [raw, field, operator, value] = match
|
||||
filters.push({ field, operator, value, raw })
|
||||
}
|
||||
|
||||
return filters
|
||||
}
|
||||
|
||||
interface SmartFilterInputProps {
|
||||
/** 可用的筛选字段,不传则使用默认字段 */
|
||||
fields?: FilterField[]
|
||||
/** 组合示例(使用逻辑运算符拼接的完整示例) */
|
||||
examples?: string[]
|
||||
placeholder?: string
|
||||
/** 受控模式:当前过滤值 */
|
||||
value?: string
|
||||
onSearch?: (filters: ParsedFilter[], rawQuery: string) => void
|
||||
className?: string
|
||||
}
|
||||
|
||||
export function SmartFilterInput({
|
||||
fields = DEFAULT_FIELDS,
|
||||
examples,
|
||||
placeholder,
|
||||
value,
|
||||
onSearch,
|
||||
className,
|
||||
}: SmartFilterInputProps) {
|
||||
const [open, setOpen] = React.useState(false)
|
||||
const [inputValue, setInputValue] = React.useState(value ?? "")
|
||||
const inputRef = React.useRef<HTMLInputElement>(null)
|
||||
const listRef = React.useRef<HTMLDivElement>(null)
|
||||
const savedScrollTop = React.useRef<number | null>(null)
|
||||
const hasInitialized = React.useRef(false)
|
||||
|
||||
// 同步外部 value 变化
|
||||
React.useEffect(() => {
|
||||
if (value !== undefined) {
|
||||
setInputValue(value)
|
||||
}
|
||||
}, [value])
|
||||
|
||||
// 当 Popover 打开时,恢复滚动位置(首次打开时滚动到顶部)
|
||||
React.useEffect(() => {
|
||||
if (open) {
|
||||
const restoreScroll = () => {
|
||||
if (listRef.current) {
|
||||
if (!hasInitialized.current) {
|
||||
// 首次打开,滚动到顶部
|
||||
listRef.current.scrollTop = 0
|
||||
hasInitialized.current = true
|
||||
} else if (savedScrollTop.current !== null) {
|
||||
// 之后恢复保存的滚动位置
|
||||
listRef.current.scrollTop = savedScrollTop.current
|
||||
}
|
||||
}
|
||||
}
|
||||
// 立即执行一次
|
||||
restoreScroll()
|
||||
// 延迟执行确保 Popover 动画完成
|
||||
const timer = setTimeout(restoreScroll, 50)
|
||||
return () => clearTimeout(timer)
|
||||
} else {
|
||||
// Popover 关闭时保存滚动位置
|
||||
if (listRef.current) {
|
||||
savedScrollTop.current = listRef.current.scrollTop
|
||||
}
|
||||
}
|
||||
}, [open])
|
||||
|
||||
// 生成默认 placeholder(使用第一个示例或字段组合)
|
||||
const defaultPlaceholder = React.useMemo(() => {
|
||||
if (examples && examples.length > 0) {
|
||||
return examples[0]
|
||||
}
|
||||
// 使用字段生成简单示例
|
||||
return fields.slice(0, 2).map(f => `${f.key}="..."`).join(" && ")
|
||||
}, [fields, examples])
|
||||
|
||||
// 解析当前输入
|
||||
const parsedFilters = parseFilterExpression(inputValue)
|
||||
|
||||
// 获取当前正在输入的词
|
||||
const getCurrentWord = () => {
|
||||
const words = inputValue.split(/\s+/)
|
||||
return words[words.length - 1] || ""
|
||||
}
|
||||
|
||||
const currentWord = getCurrentWord()
|
||||
|
||||
// 判断是否显示字段建议 (FOFA 风格用 = 而不是 :)
|
||||
const showFieldSuggestions = !currentWord.includes("=")
|
||||
|
||||
// 处理选择建议 (FOFA 风格: field=""),然后关闭弹窗
|
||||
const handleSelectSuggestion = (suggestion: string) => {
|
||||
const words = inputValue.split(/\s+/)
|
||||
words[words.length - 1] = suggestion
|
||||
const newValue = words.join(" ")
|
||||
setInputValue(newValue)
|
||||
setOpen(false)
|
||||
inputRef.current?.blur()
|
||||
}
|
||||
|
||||
// 处理搜索
|
||||
const handleSearch = () => {
|
||||
onSearch?.(parsedFilters, inputValue)
|
||||
setOpen(false)
|
||||
}
|
||||
|
||||
// 处理键盘事件
|
||||
const handleKeyDown = (e: React.KeyboardEvent) => {
|
||||
if (e.key === "Enter" && !e.shiftKey) {
|
||||
e.preventDefault()
|
||||
handleSearch()
|
||||
}
|
||||
if (e.key === "Escape") {
|
||||
setOpen(false)
|
||||
}
|
||||
}
|
||||
|
||||
// 附加示例到输入框(而非覆盖),然后关闭弹窗
|
||||
const handleAppendExample = (example: string) => {
|
||||
const trimmed = inputValue.trim()
|
||||
const newValue = trimmed ? `${trimmed} ${example}` : example
|
||||
setInputValue(newValue)
|
||||
setOpen(false)
|
||||
inputRef.current?.blur()
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={className}>
|
||||
<Popover open={open} onOpenChange={setOpen} modal={false}>
|
||||
<PopoverAnchor asChild>
|
||||
<div className="flex items-center gap-2">
|
||||
<Input
|
||||
ref={inputRef}
|
||||
type="text"
|
||||
value={inputValue}
|
||||
onChange={(e) => {
|
||||
setInputValue(e.target.value)
|
||||
if (!open) setOpen(true)
|
||||
}}
|
||||
onFocus={() => setOpen(true)}
|
||||
onBlur={(e) => {
|
||||
// 如果焦点转移到 Popover 内部或输入框自身,不关闭
|
||||
const relatedTarget = e.relatedTarget as HTMLElement | null
|
||||
if (relatedTarget?.closest('[data-radix-popper-content-wrapper]')) {
|
||||
return
|
||||
}
|
||||
// 延迟关闭,让 CommandItem 的 onSelect 先执行
|
||||
setTimeout(() => setOpen(false), 150)
|
||||
}}
|
||||
onKeyDown={handleKeyDown}
|
||||
placeholder={placeholder || defaultPlaceholder}
|
||||
className="h-8 w-full"
|
||||
/>
|
||||
<Button variant="outline" size="sm" onClick={handleSearch}>
|
||||
<IconSearch className="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
</PopoverAnchor>
|
||||
<PopoverContent
|
||||
className="w-[var(--radix-popover-trigger-width)] p-0"
|
||||
align="start"
|
||||
side="bottom"
|
||||
sideOffset={4}
|
||||
collisionPadding={16}
|
||||
onOpenAutoFocus={(e) => e.preventDefault()}
|
||||
onCloseAutoFocus={(e) => e.preventDefault()}
|
||||
onPointerDownOutside={(e) => {
|
||||
// 如果点击的是输入框,不关闭弹窗
|
||||
if (inputRef.current?.contains(e.target as Node)) {
|
||||
e.preventDefault()
|
||||
}
|
||||
}}
|
||||
>
|
||||
<Command>
|
||||
<CommandList ref={listRef}>
|
||||
{/* 已解析的筛选条件预览 */}
|
||||
{parsedFilters.length > 0 && (
|
||||
<CommandGroup heading="Active filters">
|
||||
<div className="flex flex-wrap gap-1 px-2 py-1">
|
||||
{parsedFilters.map((filter, i) => (
|
||||
<Badge key={i} variant="secondary" className="text-xs font-mono">
|
||||
{filter.raw}
|
||||
</Badge>
|
||||
))}
|
||||
</div>
|
||||
</CommandGroup>
|
||||
)}
|
||||
|
||||
{/* 可用字段 */}
|
||||
{showFieldSuggestions && (
|
||||
<CommandGroup heading="Available fields">
|
||||
<div className="flex flex-wrap gap-1 px-2 py-1">
|
||||
{fields.filter(
|
||||
(f) => !currentWord || f.key.startsWith(currentWord.toLowerCase())
|
||||
).map((field) => (
|
||||
<Badge
|
||||
key={field.key}
|
||||
variant="outline"
|
||||
className="text-xs font-mono cursor-pointer hover:bg-accent"
|
||||
onClick={() => handleSelectSuggestion(`${field.key}="`)}
|
||||
>
|
||||
{field.key}
|
||||
</Badge>
|
||||
))}
|
||||
</div>
|
||||
</CommandGroup>
|
||||
)}
|
||||
|
||||
{/* 语法帮助 */}
|
||||
<CommandGroup heading="Syntax">
|
||||
<div className="px-2 py-1.5 text-xs text-muted-foreground space-y-2">
|
||||
{/* 匹配操作符 */}
|
||||
<div className="space-y-1">
|
||||
<div className="font-medium text-foreground/80">Operators</div>
|
||||
<div className="grid grid-cols-[auto_1fr] gap-x-3 gap-y-0.5">
|
||||
<code className="bg-muted px-1 rounded">=</code>
|
||||
<span>contains (fuzzy)</span>
|
||||
<code className="bg-muted px-1 rounded">==</code>
|
||||
<span>exact match</span>
|
||||
<code className="bg-muted px-1 rounded">!=</code>
|
||||
<span>not equals</span>
|
||||
</div>
|
||||
</div>
|
||||
{/* 逻辑运算符 */}
|
||||
<div className="space-y-1 pt-1 border-t border-muted">
|
||||
<div className="font-medium text-foreground/80">Logic</div>
|
||||
<div className="grid grid-cols-[auto_1fr] gap-x-3 gap-y-0.5">
|
||||
<span><code className="bg-muted px-1 rounded">||</code> <code className="bg-muted px-1 rounded">or</code></span>
|
||||
<span>match any</span>
|
||||
<span><code className="bg-muted px-1 rounded">&&</code> <code className="bg-muted px-1 rounded">and</code> <code className="bg-muted px-1 rounded">space</code></span>
|
||||
<span>match all</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</CommandGroup>
|
||||
|
||||
{/* 示例 */}
|
||||
{examples && examples.length > 0 && (
|
||||
<CommandGroup heading="Examples">
|
||||
{examples.map((example, i) => (
|
||||
<CommandItem
|
||||
key={i}
|
||||
value={example}
|
||||
onSelect={() => handleAppendExample(example)}
|
||||
>
|
||||
<code className="text-xs">{example}</code>
|
||||
</CommandItem>
|
||||
))}
|
||||
</CommandGroup>
|
||||
)}
|
||||
|
||||
<CommandEmpty>Type to filter...</CommandEmpty>
|
||||
</CommandList>
|
||||
</Command>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export { parseFilterExpression, DEFAULT_FIELDS, type ParsedFilter }
|
||||
@@ -25,8 +25,6 @@ import {
|
||||
IconLayoutColumns,
|
||||
IconTrash,
|
||||
IconDownload,
|
||||
IconSearch,
|
||||
IconLoader2,
|
||||
IconPlus,
|
||||
} from "@tabler/icons-react"
|
||||
|
||||
@@ -40,7 +38,6 @@ import {
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu"
|
||||
import { Input } from "@/components/ui/input"
|
||||
import { Label } from "@/components/ui/label"
|
||||
import {
|
||||
Select,
|
||||
@@ -57,17 +54,30 @@ import {
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table"
|
||||
import { SmartFilterInput, type FilterField } from "@/components/common/smart-filter-input"
|
||||
|
||||
import type { Directory } from "@/types/directory.types"
|
||||
import type { PaginationInfo } from "@/types/common.types"
|
||||
|
||||
// 目录页面的过滤字段配置
|
||||
const DIRECTORY_FILTER_FIELDS: FilterField[] = [
|
||||
{ key: "url", label: "URL", description: "Directory URL" },
|
||||
{ key: "status", label: "Status", description: "HTTP status code" },
|
||||
]
|
||||
|
||||
// 目录页面的示例
|
||||
const DIRECTORY_FILTER_EXAMPLES = [
|
||||
'url="/admin" && status="200"',
|
||||
'url="/api/*" || url="/config/*"',
|
||||
'status="200" && url!="/index.html"',
|
||||
]
|
||||
|
||||
interface DirectoriesDataTableProps {
|
||||
data: Directory[]
|
||||
columns: ColumnDef<Directory>[]
|
||||
searchPlaceholder?: string
|
||||
searchColumn?: string
|
||||
searchValue?: string
|
||||
onSearch?: (value: string) => void
|
||||
// 智能过滤
|
||||
filterValue?: string
|
||||
onFilterChange?: (value: string) => void
|
||||
isSearching?: boolean
|
||||
pagination?: { pageIndex: number; pageSize: number }
|
||||
setPagination?: React.Dispatch<React.SetStateAction<{ pageIndex: number; pageSize: number }>>
|
||||
@@ -84,10 +94,8 @@ interface DirectoriesDataTableProps {
|
||||
export function DirectoriesDataTable({
|
||||
data = [],
|
||||
columns,
|
||||
searchPlaceholder = "搜索URL...",
|
||||
searchColumn = "url",
|
||||
searchValue,
|
||||
onSearch,
|
||||
filterValue,
|
||||
onFilterChange,
|
||||
isSearching = false,
|
||||
pagination,
|
||||
setPagination,
|
||||
@@ -109,24 +117,10 @@ export function DirectoriesDataTable({
|
||||
pageSize: 10,
|
||||
})
|
||||
|
||||
// 本地搜索输入状态(只在回车或点击按钮时触发搜索)
|
||||
const [localSearchValue, setLocalSearchValue] = React.useState(searchValue ?? "")
|
||||
|
||||
React.useEffect(() => {
|
||||
setLocalSearchValue(searchValue ?? "")
|
||||
}, [searchValue])
|
||||
|
||||
const handleSearchSubmit = () => {
|
||||
if (onSearch) {
|
||||
onSearch(localSearchValue)
|
||||
} else {
|
||||
table.getColumn(searchColumn)?.setFilterValue(localSearchValue)
|
||||
}
|
||||
}
|
||||
|
||||
const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
|
||||
if (e.key === 'Enter') {
|
||||
handleSearchSubmit()
|
||||
// 处理智能过滤搜索
|
||||
const handleSmartSearch = (_filters: any[], rawQuery: string) => {
|
||||
if (onFilterChange) {
|
||||
onFilterChange(rawQuery)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -190,22 +184,14 @@ export function DirectoriesDataTable({
|
||||
<div className="flex flex-col gap-4 @container/toolbar">
|
||||
{/* 第一行:搜索和列控制 */}
|
||||
<div className="flex flex-col gap-4 @xl/toolbar:flex-row @xl/toolbar:items-center @xl/toolbar:justify-between">
|
||||
<div className="flex flex-1 items-center gap-2">
|
||||
<Input
|
||||
placeholder={searchPlaceholder}
|
||||
value={localSearchValue}
|
||||
onChange={(e) => setLocalSearchValue(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
className="h-8 w-full @xl/toolbar:max-w-sm"
|
||||
/>
|
||||
<Button variant="outline" size="sm" onClick={handleSearchSubmit} disabled={isSearching}>
|
||||
{isSearching ? (
|
||||
<IconLoader2 className="h-4 w-4 animate-spin" />
|
||||
) : (
|
||||
<IconSearch className="h-4 w-4" />
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
{/* 智能过滤搜索框 */}
|
||||
<SmartFilterInput
|
||||
fields={DIRECTORY_FILTER_FIELDS}
|
||||
examples={DIRECTORY_FILTER_EXAMPLES}
|
||||
value={filterValue}
|
||||
onSearch={handleSmartSearch}
|
||||
className="flex-1 max-w-xl"
|
||||
/>
|
||||
|
||||
<div className="flex items-center gap-2">
|
||||
{/* 列可见性控制 */}
|
||||
|
||||
@@ -28,15 +28,15 @@ export function DirectoriesView({
|
||||
const [selectedDirectories, setSelectedDirectories] = useState<Directory[]>([])
|
||||
const [bulkAddDialogOpen, setBulkAddDialogOpen] = useState(false)
|
||||
|
||||
const [searchQuery, setSearchQuery] = useState("")
|
||||
const [filterQuery, setFilterQuery] = useState("")
|
||||
const [isSearching, setIsSearching] = useState(false)
|
||||
|
||||
// 获取目标信息(用于 URL 匹配校验)
|
||||
const { data: target } = useTarget(targetId || 0, { enabled: !!targetId })
|
||||
|
||||
const handleSearchChange = (value: string) => {
|
||||
const handleFilterChange = (value: string) => {
|
||||
setIsSearching(true)
|
||||
setSearchQuery(value)
|
||||
setFilterQuery(value)
|
||||
setPagination((prev) => ({ ...prev, pageIndex: 0 }))
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ export function DirectoriesView({
|
||||
{
|
||||
page: pagination.pageIndex + 1,
|
||||
pageSize: pagination.pageSize,
|
||||
search: searchQuery || undefined,
|
||||
filter: filterQuery || undefined,
|
||||
},
|
||||
{ enabled: !!targetId }
|
||||
)
|
||||
@@ -55,7 +55,7 @@ export function DirectoriesView({
|
||||
{
|
||||
page: pagination.pageIndex + 1,
|
||||
pageSize: pagination.pageSize,
|
||||
search: searchQuery || undefined,
|
||||
filter: filterQuery || undefined,
|
||||
},
|
||||
{ enabled: !!scanId }
|
||||
)
|
||||
@@ -242,10 +242,8 @@ export function DirectoriesView({
|
||||
<DirectoriesDataTable
|
||||
data={directories}
|
||||
columns={columns}
|
||||
searchPlaceholder="搜索URL..."
|
||||
searchColumn="url"
|
||||
searchValue={searchQuery}
|
||||
onSearch={handleSearchChange}
|
||||
filterValue={filterQuery}
|
||||
onFilterChange={handleFilterChange}
|
||||
isSearching={isSearching}
|
||||
pagination={pagination}
|
||||
setPagination={setPagination}
|
||||
|
||||
@@ -25,8 +25,6 @@ import {
|
||||
IconLayoutColumns,
|
||||
IconPlus,
|
||||
IconDownload,
|
||||
IconSearch,
|
||||
IconLoader2,
|
||||
} from "@tabler/icons-react"
|
||||
import { Button } from "@/components/ui/button"
|
||||
import {
|
||||
@@ -38,7 +36,6 @@ import {
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu"
|
||||
import { Input } from "@/components/ui/input"
|
||||
import { Label } from "@/components/ui/label"
|
||||
import {
|
||||
Select,
|
||||
@@ -55,15 +52,30 @@ import {
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table"
|
||||
import { SmartFilterInput, type FilterField } from "@/components/common/smart-filter-input"
|
||||
import type { Endpoint } from "@/types/endpoint.types"
|
||||
|
||||
// 端点页面的过滤字段配置
|
||||
const ENDPOINT_FILTER_FIELDS: FilterField[] = [
|
||||
{ key: "url", label: "URL", description: "Endpoint URL" },
|
||||
{ key: "host", label: "Host", description: "Hostname" },
|
||||
{ key: "title", label: "Title", description: "Page title" },
|
||||
{ key: "status", label: "Status", description: "HTTP status code" },
|
||||
]
|
||||
|
||||
// 端点页面的示例
|
||||
const ENDPOINT_FILTER_EXAMPLES = [
|
||||
'url="/api/*" && status="200"',
|
||||
'host="api.example.com" || host="admin.example.com"',
|
||||
'title="Dashboard" && status!="404"',
|
||||
]
|
||||
|
||||
interface EndpointsDataTableProps<TData extends { id: number | string }, TValue> {
|
||||
columns: ColumnDef<TData, TValue>[]
|
||||
data: TData[]
|
||||
searchPlaceholder?: string
|
||||
searchColumn?: string
|
||||
searchValue?: string
|
||||
onSearch?: (value: string) => void
|
||||
// 智能过滤
|
||||
filterValue?: string
|
||||
onFilterChange?: (value: string) => void
|
||||
isSearching?: boolean
|
||||
onAddNew?: () => void
|
||||
addButtonText?: string
|
||||
@@ -80,10 +92,8 @@ interface EndpointsDataTableProps<TData extends { id: number | string }, TValue>
|
||||
export function EndpointsDataTable<TData extends { id: number | string }, TValue>({
|
||||
columns,
|
||||
data,
|
||||
searchPlaceholder = "搜索主机名...",
|
||||
searchColumn = "url",
|
||||
searchValue,
|
||||
onSearch,
|
||||
filterValue,
|
||||
onFilterChange,
|
||||
isSearching = false,
|
||||
onAddNew,
|
||||
addButtonText = "Add",
|
||||
@@ -106,24 +116,10 @@ export function EndpointsDataTable<TData extends { id: number | string }, TValue
|
||||
pageSize: 10,
|
||||
})
|
||||
|
||||
// 本地搜索输入状态(只在回车或点击按钮时触发搜索)
|
||||
const [localSearchValue, setLocalSearchValue] = React.useState(searchValue ?? "")
|
||||
|
||||
React.useEffect(() => {
|
||||
setLocalSearchValue(searchValue ?? "")
|
||||
}, [searchValue])
|
||||
|
||||
const handleSearchSubmit = () => {
|
||||
if (onSearch) {
|
||||
onSearch(localSearchValue)
|
||||
} else {
|
||||
table.getColumn(searchColumn)?.setFilterValue(localSearchValue)
|
||||
}
|
||||
}
|
||||
|
||||
const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
|
||||
if (e.key === 'Enter') {
|
||||
handleSearchSubmit()
|
||||
// 处理智能过滤搜索
|
||||
const handleSmartSearch = (_filters: any[], rawQuery: string) => {
|
||||
if (onFilterChange) {
|
||||
onFilterChange(rawQuery)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -183,22 +179,14 @@ export function EndpointsDataTable<TData extends { id: number | string }, TValue
|
||||
return (
|
||||
<div className="w-full space-y-4">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center space-x-2">
|
||||
<Input
|
||||
placeholder={searchPlaceholder}
|
||||
value={localSearchValue}
|
||||
onChange={(e) => setLocalSearchValue(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
className="h-8 max-w-sm"
|
||||
/>
|
||||
<Button variant="outline" size="sm" onClick={handleSearchSubmit} disabled={isSearching}>
|
||||
{isSearching ? (
|
||||
<IconLoader2 className="h-4 w-4 animate-spin" />
|
||||
) : (
|
||||
<IconSearch className="h-4 w-4" />
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
{/* 智能过滤搜索框 */}
|
||||
<SmartFilterInput
|
||||
fields={ENDPOINT_FILTER_FIELDS}
|
||||
examples={ENDPOINT_FILTER_EXAMPLES}
|
||||
value={filterValue}
|
||||
onSearch={handleSmartSearch}
|
||||
className="flex-1 max-w-xl"
|
||||
/>
|
||||
|
||||
<div className="flex items-center space-x-2">
|
||||
<DropdownMenu>
|
||||
|
||||
@@ -46,15 +46,15 @@ export function EndpointsDetailView({
|
||||
pageSize: 10
|
||||
})
|
||||
|
||||
const [searchQuery, setSearchQuery] = useState("")
|
||||
const [filterQuery, setFilterQuery] = useState("")
|
||||
const [isSearching, setIsSearching] = useState(false)
|
||||
|
||||
// 获取目标信息(用于 URL 匹配校验)
|
||||
const { data: target } = useTarget(targetId || 0, { enabled: !!targetId })
|
||||
|
||||
const handleSearchChange = (value: string) => {
|
||||
const handleFilterChange = (value: string) => {
|
||||
setIsSearching(true)
|
||||
setSearchQuery(value)
|
||||
setFilterQuery(value)
|
||||
setPagination((prev) => ({ ...prev, pageIndex: 0 }))
|
||||
}
|
||||
|
||||
@@ -65,14 +65,18 @@ export function EndpointsDetailView({
|
||||
const targetEndpointsQuery = useTargetEndpoints(targetId || 0, {
|
||||
page: pagination.pageIndex + 1,
|
||||
pageSize: pagination.pageSize,
|
||||
search: searchQuery || undefined,
|
||||
filter: filterQuery || undefined,
|
||||
}, { enabled: !!targetId })
|
||||
|
||||
const scanEndpointsQuery = useScanEndpoints(scanId || 0, {
|
||||
page: pagination.pageIndex + 1,
|
||||
pageSize: pagination.pageSize,
|
||||
search: searchQuery || undefined,
|
||||
}, { enabled: !!scanId })
|
||||
const scanEndpointsQuery = useScanEndpoints(
|
||||
scanId || 0,
|
||||
{
|
||||
page: pagination.pageIndex + 1,
|
||||
pageSize: pagination.pageSize,
|
||||
},
|
||||
{ enabled: !!scanId },
|
||||
filterQuery || undefined,
|
||||
)
|
||||
|
||||
const {
|
||||
data,
|
||||
@@ -279,9 +283,8 @@ export function EndpointsDetailView({
|
||||
<EndpointsDataTable
|
||||
data={data?.endpoints || []}
|
||||
columns={endpointColumns}
|
||||
searchPlaceholder="搜索主机名..."
|
||||
searchValue={searchQuery}
|
||||
onSearch={handleSearchChange}
|
||||
filterValue={filterQuery}
|
||||
onFilterChange={handleFilterChange}
|
||||
isSearching={isSearching}
|
||||
pagination={pagination}
|
||||
onPaginationChange={handlePaginationChange}
|
||||
|
||||
@@ -23,8 +23,6 @@ import {
|
||||
IconLayoutColumns,
|
||||
IconTrash,
|
||||
IconDownload,
|
||||
IconSearch,
|
||||
IconLoader2,
|
||||
} from "@tabler/icons-react"
|
||||
import { Button } from "@/components/ui/button"
|
||||
import {
|
||||
@@ -36,7 +34,6 @@ import {
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu"
|
||||
import { Input } from "@/components/ui/input"
|
||||
import { Label } from "@/components/ui/label"
|
||||
import {
|
||||
Select,
|
||||
@@ -53,36 +50,45 @@ import {
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table"
|
||||
import { SmartFilterInput, PREDEFINED_FIELDS, type ParsedFilter, type FilterField } from "@/components/common/smart-filter-input"
|
||||
import type { IPAddress } from "@/types/ip-address.types"
|
||||
import type { PaginationInfo } from "@/types/common.types"
|
||||
|
||||
// IP 地址页面的过滤字段配置
|
||||
const IP_ADDRESS_FILTER_FIELDS: FilterField[] = [
|
||||
PREDEFINED_FIELDS.ip,
|
||||
PREDEFINED_FIELDS.port,
|
||||
PREDEFINED_FIELDS.host,
|
||||
]
|
||||
|
||||
// IP 地址页面的示例
|
||||
const IP_ADDRESS_FILTER_EXAMPLES = [
|
||||
'ip="192.168.1.*" && port="80"',
|
||||
'port="443" || port="8443"',
|
||||
'host="api.example.com" && port!="22"',
|
||||
]
|
||||
|
||||
|
||||
interface IPAddressesDataTableProps {
|
||||
data: IPAddress[]
|
||||
columns: ColumnDef<IPAddress>[]
|
||||
searchPlaceholder?: string
|
||||
searchColumn?: string
|
||||
searchValue?: string
|
||||
onSearch?: (value: string) => void
|
||||
isSearching?: boolean
|
||||
filterValue?: string
|
||||
onFilterChange?: (value: string) => void
|
||||
pagination?: { pageIndex: number; pageSize: number }
|
||||
setPagination?: React.Dispatch<React.SetStateAction<{ pageIndex: number; pageSize: number }>>
|
||||
paginationInfo?: PaginationInfo
|
||||
onPaginationChange?: (pagination: { pageIndex: number; pageSize: number }) => void
|
||||
onBulkDelete?: () => void // 批量删除回调函数
|
||||
onSelectionChange?: (selectedRows: IPAddress[]) => void // 选中行变化回调
|
||||
// 下载回调函数
|
||||
onDownloadAll?: () => void // 下载所有 IP 地址
|
||||
onDownloadSelected?: () => void // 下载选中的 IP 地址
|
||||
onBulkDelete?: () => void
|
||||
onSelectionChange?: (selectedRows: IPAddress[]) => void
|
||||
onDownloadAll?: () => void
|
||||
onDownloadSelected?: () => void
|
||||
}
|
||||
|
||||
export function IPAddressesDataTable({
|
||||
data = [],
|
||||
columns,
|
||||
searchPlaceholder = "搜索 IP 地址...",
|
||||
searchColumn = "ip",
|
||||
searchValue,
|
||||
onSearch,
|
||||
isSearching = false,
|
||||
filterValue = "",
|
||||
onFilterChange,
|
||||
pagination,
|
||||
setPagination,
|
||||
paginationInfo,
|
||||
@@ -102,25 +108,9 @@ export function IPAddressesDataTable({
|
||||
pageSize: 10,
|
||||
})
|
||||
|
||||
// 本地搜索输入状态(只在回车或点击按钮时触发搜索)
|
||||
const [localSearchValue, setLocalSearchValue] = React.useState(searchValue ?? "")
|
||||
|
||||
React.useEffect(() => {
|
||||
setLocalSearchValue(searchValue ?? "")
|
||||
}, [searchValue])
|
||||
|
||||
const handleSearchSubmit = () => {
|
||||
if (onSearch) {
|
||||
onSearch(localSearchValue)
|
||||
} else {
|
||||
table.getColumn(searchColumn)?.setFilterValue(localSearchValue)
|
||||
}
|
||||
}
|
||||
|
||||
const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
|
||||
if (e.key === 'Enter') {
|
||||
handleSearchSubmit()
|
||||
}
|
||||
// 智能搜索处理
|
||||
const handleSmartSearch = (filters: ParsedFilter[], rawQuery: string) => {
|
||||
onFilterChange?.(rawQuery)
|
||||
}
|
||||
|
||||
const useServerPagination = !!paginationInfo && !!pagination && !!setPagination
|
||||
@@ -141,7 +131,7 @@ export function IPAddressesDataTable({
|
||||
enableColumnResizing: true,
|
||||
columnResizeMode: 'onChange',
|
||||
onColumnSizingChange: setColumnSizing,
|
||||
getRowId: (row) => row.ip, // IP 地址本身就是唯一标识
|
||||
getRowId: (row) => row.ip,
|
||||
enableRowSelection: true,
|
||||
onRowSelectionChange: setRowSelection,
|
||||
onSortingChange: setSorting,
|
||||
@@ -163,10 +153,6 @@ export function IPAddressesDataTable({
|
||||
: Math.ceil(data.length / tablePagination.pageSize) || 1,
|
||||
})
|
||||
|
||||
const totalItems = useServerPagination
|
||||
? paginationInfo?.total ?? data.length
|
||||
: table.getFilteredRowModel().rows.length
|
||||
|
||||
// 处理选中行变化
|
||||
React.useEffect(() => {
|
||||
if (onSelectionChange) {
|
||||
@@ -177,25 +163,16 @@ export function IPAddressesDataTable({
|
||||
|
||||
return (
|
||||
<div className="w-full space-y-4">
|
||||
{/* 工具栏 */}
|
||||
<div className="flex items-center justify-between">
|
||||
{/* 搜索框 */}
|
||||
<div className="flex items-center space-x-2">
|
||||
<Input
|
||||
placeholder={searchPlaceholder}
|
||||
value={localSearchValue}
|
||||
onChange={(e) => setLocalSearchValue(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
className="h-8 max-w-sm"
|
||||
/>
|
||||
<Button variant="outline" size="sm" onClick={handleSearchSubmit} disabled={isSearching}>
|
||||
{isSearching ? (
|
||||
<IconLoader2 className="h-4 w-4 animate-spin" />
|
||||
) : (
|
||||
<IconSearch className="h-4 w-4" />
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
{/* 工具栏 - 方案 D:智能搜索框 */}
|
||||
<div className="flex items-center justify-between gap-4 flex-wrap">
|
||||
{/* 左侧:智能搜索框 */}
|
||||
<SmartFilterInput
|
||||
fields={IP_ADDRESS_FILTER_FIELDS}
|
||||
examples={IP_ADDRESS_FILTER_EXAMPLES}
|
||||
value={filterValue}
|
||||
onSearch={handleSmartSearch}
|
||||
className="flex-1 max-w-xl"
|
||||
/>
|
||||
|
||||
{/* 右侧操作按钮 */}
|
||||
<div className="flex items-center space-x-2">
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"use client"
|
||||
|
||||
import React, { useCallback, useMemo, useState, useEffect } from "react"
|
||||
import React, { useCallback, useMemo, useState } from "react"
|
||||
import { AlertTriangle } from "lucide-react"
|
||||
import { IPAddressesDataTable } from "./ip-addresses-data-table"
|
||||
import { createIPAddressColumns } from "./ip-addresses-columns"
|
||||
@@ -23,13 +23,10 @@ export function IPAddressesView({
|
||||
pageSize: 10,
|
||||
})
|
||||
const [selectedIPAddresses, setSelectedIPAddresses] = useState<IPAddress[]>([])
|
||||
const [filterQuery, setFilterQuery] = useState("")
|
||||
|
||||
const [searchQuery, setSearchQuery] = useState("")
|
||||
const [isSearching, setIsSearching] = useState(false)
|
||||
|
||||
const handleSearchChange = (value: string) => {
|
||||
setIsSearching(true)
|
||||
setSearchQuery(value)
|
||||
const handleFilterChange = (value: string) => {
|
||||
setFilterQuery(value)
|
||||
setPagination((prev) => ({ ...prev, pageIndex: 0 }))
|
||||
}
|
||||
|
||||
@@ -38,7 +35,7 @@ export function IPAddressesView({
|
||||
{
|
||||
page: pagination.pageIndex + 1,
|
||||
pageSize: pagination.pageSize,
|
||||
search: searchQuery || undefined,
|
||||
filter: filterQuery || undefined,
|
||||
},
|
||||
{ enabled: !!targetId }
|
||||
)
|
||||
@@ -48,19 +45,13 @@ export function IPAddressesView({
|
||||
{
|
||||
page: pagination.pageIndex + 1,
|
||||
pageSize: pagination.pageSize,
|
||||
search: searchQuery || undefined,
|
||||
filter: filterQuery || undefined,
|
||||
},
|
||||
{ enabled: !!scanId }
|
||||
)
|
||||
|
||||
const activeQuery = targetId ? targetQuery : scanQuery
|
||||
const { data, isLoading, isFetching, error, refetch } = activeQuery
|
||||
|
||||
useEffect(() => {
|
||||
if (!isFetching && isSearching) {
|
||||
setIsSearching(false)
|
||||
}
|
||||
}, [isFetching, isSearching])
|
||||
const { data, isLoading, error, refetch } = activeQuery
|
||||
|
||||
const formatDate = useCallback((dateString: string) => {
|
||||
return new Date(dateString).toLocaleString("zh-CN", {
|
||||
@@ -225,11 +216,8 @@ export function IPAddressesView({
|
||||
<IPAddressesDataTable
|
||||
data={ipAddresses}
|
||||
columns={columns}
|
||||
searchPlaceholder="搜索IP地址..."
|
||||
searchColumn="ip"
|
||||
searchValue={searchQuery}
|
||||
onSearch={handleSearchChange}
|
||||
isSearching={isSearching}
|
||||
filterValue={filterQuery}
|
||||
onFilterChange={handleFilterChange}
|
||||
pagination={pagination}
|
||||
setPagination={setPagination}
|
||||
paginationInfo={paginationInfo}
|
||||
|
||||
@@ -29,8 +29,6 @@ import {
|
||||
IconPlus,
|
||||
IconTrash,
|
||||
IconDownload,
|
||||
IconSearch,
|
||||
IconLoader2,
|
||||
} from "@tabler/icons-react"
|
||||
|
||||
// 导入 UI 组件
|
||||
@@ -44,7 +42,6 @@ import {
|
||||
DropdownMenuSeparator,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu"
|
||||
import { Input } from "@/components/ui/input"
|
||||
import { Label } from "@/components/ui/label"
|
||||
import {
|
||||
Select,
|
||||
@@ -61,11 +58,23 @@ import {
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table"
|
||||
import { SmartFilterInput, PREDEFINED_FIELDS, type FilterField } from "@/components/common/smart-filter-input"
|
||||
|
||||
// 导入子域名类型定义
|
||||
import type { Subdomain } from "@/types/subdomain.types"
|
||||
import type { PaginationInfo } from "@/types/common.types"
|
||||
|
||||
// 子域名页面的过滤字段配置
|
||||
const SUBDOMAIN_FILTER_FIELDS: FilterField[] = [
|
||||
{ key: "name", label: "Name", description: "Subdomain name" },
|
||||
]
|
||||
|
||||
// 子域名页面的示例
|
||||
const SUBDOMAIN_FILTER_EXAMPLES = [
|
||||
'name="api.example.com"',
|
||||
'name="*.test.com"',
|
||||
]
|
||||
|
||||
// 组件属性类型定义
|
||||
interface SubdomainsDataTableProps {
|
||||
data: Subdomain[] // 子域名数据数组
|
||||
@@ -74,10 +83,9 @@ interface SubdomainsDataTableProps {
|
||||
onBulkAdd?: () => void // 批量添加回调函数
|
||||
onBulkDelete?: () => void // 批量删除回调函数
|
||||
onSelectionChange?: (selectedRows: Subdomain[]) => void // 选中行变化回调
|
||||
searchPlaceholder?: string // 搜索框占位符
|
||||
searchColumn?: string // 搜索的列名
|
||||
searchValue?: string // 受控:搜索框当前值(服务端搜索)
|
||||
onSearch?: (value: string) => void // 受控:搜索框变更回调(服务端搜索)
|
||||
// 智能过滤
|
||||
filterValue?: string // 受控:过滤值
|
||||
onFilterChange?: (value: string) => void // 受控:过滤变更回调
|
||||
isSearching?: boolean // 搜索中状态(显示加载动画)
|
||||
addButtonText?: string // 添加按钮文本
|
||||
// 下载回调函数
|
||||
@@ -104,10 +112,8 @@ export function SubdomainsDataTable({
|
||||
onBulkAdd,
|
||||
onBulkDelete,
|
||||
onSelectionChange,
|
||||
searchPlaceholder = "搜索子域名...",
|
||||
searchColumn = "name",
|
||||
searchValue,
|
||||
onSearch,
|
||||
filterValue,
|
||||
onFilterChange,
|
||||
isSearching = false,
|
||||
addButtonText = "Add",
|
||||
onDownloadAll,
|
||||
@@ -140,25 +146,10 @@ export function SubdomainsDataTable({
|
||||
const pagination = externalPagination || internalPagination
|
||||
const setPagination = setExternalPagination || setInternalPagination
|
||||
|
||||
// 本地搜索输入状态(只在回车或点击按钮时触发搜索)
|
||||
const [localSearchValue, setLocalSearchValue] = React.useState(searchValue ?? "")
|
||||
|
||||
// 同步外部 searchValue 到本地状态
|
||||
React.useEffect(() => {
|
||||
setLocalSearchValue(searchValue ?? "")
|
||||
}, [searchValue])
|
||||
|
||||
const handleSearchSubmit = () => {
|
||||
if (onSearch) {
|
||||
onSearch(localSearchValue)
|
||||
} else {
|
||||
table.getColumn(searchColumn)?.setFilterValue(localSearchValue)
|
||||
}
|
||||
}
|
||||
|
||||
const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
|
||||
if (e.key === 'Enter') {
|
||||
handleSearchSubmit()
|
||||
// 处理智能过滤搜索
|
||||
const handleSmartSearch = (_filters: any[], rawQuery: string) => {
|
||||
if (onFilterChange) {
|
||||
onFilterChange(rawQuery)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -223,23 +214,14 @@ export function SubdomainsDataTable({
|
||||
<div className="w-full space-y-4">
|
||||
{/* 工具栏 */}
|
||||
<div className="flex items-center justify-between">
|
||||
{/* 搜索框 */}
|
||||
<div className="flex items-center space-x-2">
|
||||
<Input
|
||||
placeholder={searchPlaceholder}
|
||||
value={localSearchValue}
|
||||
onChange={(e) => setLocalSearchValue(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
className="h-8 max-w-sm"
|
||||
/>
|
||||
<Button variant="outline" size="sm" onClick={handleSearchSubmit} disabled={isSearching}>
|
||||
{isSearching ? (
|
||||
<IconLoader2 className="h-4 w-4 animate-spin" />
|
||||
) : (
|
||||
<IconSearch className="h-4 w-4" />
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
{/* 智能过滤搜索框 */}
|
||||
<SmartFilterInput
|
||||
fields={SUBDOMAIN_FILTER_FIELDS}
|
||||
examples={SUBDOMAIN_FILTER_EXAMPLES}
|
||||
value={filterValue}
|
||||
onSearch={handleSmartSearch}
|
||||
className="flex-1 max-w-xl"
|
||||
/>
|
||||
|
||||
{/* 右侧操作按钮 */}
|
||||
<div className="flex items-center space-x-2">
|
||||
|
||||
@@ -10,7 +10,6 @@ import {
|
||||
} from "@/hooks/use-subdomains"
|
||||
import { SubdomainsDataTable } from "./subdomains-data-table"
|
||||
import { createSubdomainColumns } from "./subdomains-columns"
|
||||
import { LoadingSpinner } from "@/components/loading-spinner"
|
||||
import { DataTableSkeleton } from "@/components/ui/data-table-skeleton"
|
||||
import { SubdomainService } from "@/services/subdomain.service"
|
||||
import { BulkAddSubdomainsDialog } from "./bulk-add-subdomains-dialog"
|
||||
@@ -40,23 +39,23 @@ export function SubdomainsDetailView({
|
||||
pageSize: 10,
|
||||
})
|
||||
|
||||
// 搜索状态(服务端搜索)
|
||||
const [searchQuery, setSearchQuery] = useState("")
|
||||
// 过滤状态(智能过滤语法)
|
||||
const [filterQuery, setFilterQuery] = useState("")
|
||||
const [isSearching, setIsSearching] = useState(false)
|
||||
|
||||
const handleSearchChange = (value: string) => {
|
||||
const handleFilterChange = (value: string) => {
|
||||
setIsSearching(true)
|
||||
setSearchQuery(value)
|
||||
setFilterQuery(value)
|
||||
setPagination((prev) => ({ ...prev, pageIndex: 0 }))
|
||||
}
|
||||
|
||||
// 根据 targetId 或 scanId 获取子域名数据(传入分页和搜索参数)
|
||||
// 根据 targetId 或 scanId 获取子域名数据(传入分页和过滤参数)
|
||||
const targetSubdomainsQuery = useTargetSubdomains(
|
||||
targetId || 0,
|
||||
{
|
||||
page: pagination.pageIndex + 1, // 转换为 1-based
|
||||
pageSize: pagination.pageSize,
|
||||
search: searchQuery || undefined,
|
||||
filter: filterQuery || undefined,
|
||||
},
|
||||
{ enabled: !!targetId }
|
||||
)
|
||||
@@ -65,7 +64,7 @@ export function SubdomainsDetailView({
|
||||
{
|
||||
page: pagination.pageIndex + 1, // 转换为 1-based
|
||||
pageSize: pagination.pageSize,
|
||||
search: searchQuery || undefined,
|
||||
filter: filterQuery || undefined,
|
||||
},
|
||||
{ enabled: !!scanId }
|
||||
)
|
||||
@@ -254,10 +253,8 @@ export function SubdomainsDetailView({
|
||||
data={subdomains}
|
||||
columns={subdomainColumns}
|
||||
onSelectionChange={setSelectedSubdomains}
|
||||
searchPlaceholder="搜索子域名..."
|
||||
searchColumn="name"
|
||||
searchValue={searchQuery}
|
||||
onSearch={handleSearchChange}
|
||||
filterValue={filterQuery}
|
||||
onFilterChange={handleFilterChange}
|
||||
isSearching={isSearching}
|
||||
onDownloadAll={handleDownloadAll}
|
||||
onDownloadSelected={handleDownloadSelected}
|
||||
|
||||
@@ -25,8 +25,6 @@ import {
|
||||
IconLayoutColumns,
|
||||
IconTrash,
|
||||
IconDownload,
|
||||
IconSearch,
|
||||
IconLoader2,
|
||||
} from "@tabler/icons-react"
|
||||
|
||||
import { Button } from "@/components/ui/button"
|
||||
@@ -39,7 +37,6 @@ import {
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu"
|
||||
import { Input } from "@/components/ui/input"
|
||||
import { Label } from "@/components/ui/label"
|
||||
import {
|
||||
Select,
|
||||
@@ -56,37 +53,47 @@ import {
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table"
|
||||
import { SmartFilterInput, PREDEFINED_FIELDS, type FilterField } from "@/components/common/smart-filter-input"
|
||||
|
||||
import type { Vulnerability } from "@/types/vulnerability.types"
|
||||
import type { PaginationInfo } from "@/types/common.types"
|
||||
|
||||
// 漏洞页面的过滤字段
|
||||
const VULNERABILITY_FILTER_FIELDS: FilterField[] = [
|
||||
{ key: "type", label: "Type", description: "Vulnerability type" },
|
||||
PREDEFINED_FIELDS.severity,
|
||||
{ key: "source", label: "Source", description: "Scanner source" },
|
||||
PREDEFINED_FIELDS.url,
|
||||
]
|
||||
|
||||
// 漏洞页面的示例
|
||||
const VULNERABILITY_FILTER_EXAMPLES = [
|
||||
'type="xss" || type="sqli"',
|
||||
'severity="critical" || severity="high"',
|
||||
'source="nuclei" && severity="high"',
|
||||
'type="xss" && url="/api/*"',
|
||||
]
|
||||
|
||||
interface VulnerabilitiesDataTableProps {
|
||||
data: Vulnerability[]
|
||||
columns: ColumnDef<Vulnerability>[]
|
||||
searchPlaceholder?: string
|
||||
searchColumn?: string
|
||||
searchValue?: string
|
||||
onSearch?: (value: string) => void
|
||||
isSearching?: boolean
|
||||
filterValue?: string
|
||||
onFilterChange?: (value: string) => void
|
||||
pagination?: { pageIndex: number; pageSize: number }
|
||||
setPagination?: React.Dispatch<React.SetStateAction<{ pageIndex: number; pageSize: number }>>
|
||||
paginationInfo?: PaginationInfo
|
||||
onPaginationChange?: (pagination: { pageIndex: number; pageSize: number }) => void
|
||||
onBulkDelete?: () => void // 批量删除回调函数
|
||||
onSelectionChange?: (selectedRows: Vulnerability[]) => void // 选中行变化回调
|
||||
// 下载回调函数
|
||||
onDownloadAll?: () => void // 下载所有漏洞
|
||||
onDownloadSelected?: () => void // 下载选中的漏洞
|
||||
onBulkDelete?: () => void
|
||||
onSelectionChange?: (selectedRows: Vulnerability[]) => void
|
||||
onDownloadAll?: () => void
|
||||
onDownloadSelected?: () => void
|
||||
}
|
||||
|
||||
export function VulnerabilitiesDataTable({
|
||||
data = [],
|
||||
columns,
|
||||
searchPlaceholder = "搜索漏洞类型...",
|
||||
searchColumn = "title",
|
||||
searchValue,
|
||||
onSearch,
|
||||
isSearching = false,
|
||||
filterValue,
|
||||
onFilterChange,
|
||||
pagination,
|
||||
setPagination,
|
||||
paginationInfo,
|
||||
@@ -107,27 +114,6 @@ export function VulnerabilitiesDataTable({
|
||||
pageSize: 10,
|
||||
})
|
||||
|
||||
// 本地搜索输入状态(只在回车或点击按钮时触发搜索)
|
||||
const [localSearchValue, setLocalSearchValue] = React.useState(searchValue ?? "")
|
||||
|
||||
React.useEffect(() => {
|
||||
setLocalSearchValue(searchValue ?? "")
|
||||
}, [searchValue])
|
||||
|
||||
const handleSearchSubmit = () => {
|
||||
if (onSearch) {
|
||||
onSearch(localSearchValue)
|
||||
} else {
|
||||
table.getColumn(searchColumn)?.setFilterValue(localSearchValue)
|
||||
}
|
||||
}
|
||||
|
||||
const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
|
||||
if (e.key === 'Enter') {
|
||||
handleSearchSubmit()
|
||||
}
|
||||
}
|
||||
|
||||
const useServerPagination = !!paginationInfo && !!pagination && !!setPagination
|
||||
const tablePagination = useServerPagination ? pagination : internalPagination
|
||||
const setTablePagination = useServerPagination ? setPagination : setInternalPagination
|
||||
@@ -182,27 +168,23 @@ export function VulnerabilitiesDataTable({
|
||||
}
|
||||
}, [rowSelection, onSelectionChange, table])
|
||||
|
||||
// 处理智能过滤搜索
|
||||
const handleFilterSearch = (_filters: any[], rawQuery: string) => {
|
||||
onFilterChange?.(rawQuery)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="w-full space-y-4">
|
||||
{/* 工具栏 */}
|
||||
<div className="flex items-center justify-between">
|
||||
{/* 搜索框 */}
|
||||
<div className="flex items-center space-x-2">
|
||||
<Input
|
||||
placeholder={searchPlaceholder}
|
||||
value={localSearchValue}
|
||||
onChange={(e) => setLocalSearchValue(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
className="h-8 max-w-sm"
|
||||
/>
|
||||
<Button variant="outline" size="sm" onClick={handleSearchSubmit} disabled={isSearching}>
|
||||
{isSearching ? (
|
||||
<IconLoader2 className="h-4 w-4 animate-spin" />
|
||||
) : (
|
||||
<IconSearch className="h-4 w-4" />
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
{/* 智能过滤输入框 */}
|
||||
<SmartFilterInput
|
||||
fields={VULNERABILITY_FILTER_FIELDS}
|
||||
examples={VULNERABILITY_FILTER_EXAMPLES}
|
||||
value={filterValue}
|
||||
onSearch={handleFilterSearch}
|
||||
className="flex-1 max-w-xl"
|
||||
/>
|
||||
|
||||
{/* 右侧操作按钮 */}
|
||||
<div className="flex items-center space-x-2">
|
||||
|
||||
@@ -42,19 +42,17 @@ export function VulnerabilitiesDetailView({
|
||||
pageSize: 10,
|
||||
})
|
||||
|
||||
const [searchQuery, setSearchQuery] = useState("")
|
||||
const [isSearching, setIsSearching] = useState(false)
|
||||
// 智能过滤查询
|
||||
const [filterQuery, setFilterQuery] = useState("")
|
||||
|
||||
const handleSearchChange = (value: string) => {
|
||||
setIsSearching(true)
|
||||
setSearchQuery(value)
|
||||
const handleFilterChange = (value: string) => {
|
||||
setFilterQuery(value)
|
||||
setPagination((prev) => ({ ...prev, pageIndex: 0 }))
|
||||
}
|
||||
|
||||
const paginationParams = {
|
||||
page: pagination.pageIndex + 1,
|
||||
pageSize: pagination.pageSize,
|
||||
search: searchQuery || undefined,
|
||||
}
|
||||
|
||||
// 按 scan 维度加载(扫描历史页面)
|
||||
@@ -62,6 +60,7 @@ export function VulnerabilitiesDetailView({
|
||||
scanId ?? 0,
|
||||
paginationParams,
|
||||
{ enabled: !!scanId },
|
||||
filterQuery || undefined,
|
||||
)
|
||||
|
||||
// 按 target 维度加载(目标详情页面)
|
||||
@@ -69,25 +68,20 @@ export function VulnerabilitiesDetailView({
|
||||
targetId ?? 0,
|
||||
paginationParams,
|
||||
{ enabled: !!targetId && !scanId },
|
||||
filterQuery || undefined,
|
||||
)
|
||||
|
||||
// 获取所有漏洞(全局漏洞页面)
|
||||
const allQuery = useAllVulnerabilities(
|
||||
paginationParams,
|
||||
{ enabled: !scanId && !targetId },
|
||||
filterQuery || undefined,
|
||||
)
|
||||
|
||||
// 根据参数选择使用哪个 query
|
||||
const activeQuery = scanId ? scanQuery : targetId ? targetQuery : allQuery
|
||||
const isQueryLoading = activeQuery.isLoading
|
||||
|
||||
// 当请求完成时重置搜索状态
|
||||
React.useEffect(() => {
|
||||
if (!activeQuery.isFetching && isSearching) {
|
||||
setIsSearching(false)
|
||||
}
|
||||
}, [activeQuery.isFetching, isSearching])
|
||||
|
||||
const vulnerabilities = activeQuery.data?.vulnerabilities ?? []
|
||||
const paginationInfo = activeQuery.data?.pagination ?? {
|
||||
total: vulnerabilities.length,
|
||||
@@ -206,11 +200,8 @@ export function VulnerabilitiesDetailView({
|
||||
<VulnerabilitiesDataTable
|
||||
data={vulnerabilities}
|
||||
columns={vulnerabilityColumns}
|
||||
searchPlaceholder="搜索漏洞类型..."
|
||||
searchColumn="vulnType"
|
||||
searchValue={searchQuery}
|
||||
onSearch={handleSearchChange}
|
||||
isSearching={isSearching}
|
||||
filterValue={filterQuery}
|
||||
onFilterChange={handleFilterChange}
|
||||
pagination={pagination}
|
||||
setPagination={setPagination}
|
||||
paginationInfo={{
|
||||
|
||||
@@ -25,8 +25,6 @@ import {
|
||||
IconLayoutColumns,
|
||||
IconTrash,
|
||||
IconDownload,
|
||||
IconSearch,
|
||||
IconLoader2,
|
||||
IconPlus,
|
||||
} from "@tabler/icons-react"
|
||||
|
||||
@@ -40,7 +38,6 @@ import {
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu"
|
||||
import { Input } from "@/components/ui/input"
|
||||
import { Label } from "@/components/ui/label"
|
||||
import {
|
||||
Select,
|
||||
@@ -57,17 +54,32 @@ import {
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table"
|
||||
import { SmartFilterInput, type FilterField } from "@/components/common/smart-filter-input"
|
||||
|
||||
import type { WebSite } from "@/types/website.types"
|
||||
import type { PaginationInfo } from "@/types/common.types"
|
||||
|
||||
// 网站页面的过滤字段配置
|
||||
const WEBSITE_FILTER_FIELDS: FilterField[] = [
|
||||
{ key: "url", label: "URL", description: "Full URL" },
|
||||
{ key: "host", label: "Host", description: "Hostname" },
|
||||
{ key: "title", label: "Title", description: "Page title" },
|
||||
{ key: "status", label: "Status", description: "HTTP status code" },
|
||||
]
|
||||
|
||||
// 网站页面的示例
|
||||
const WEBSITE_FILTER_EXAMPLES = [
|
||||
'host="api.example.com" && status="200"',
|
||||
'title="Login" || title="Admin"',
|
||||
'url="/api/*" && status!="404"',
|
||||
]
|
||||
|
||||
interface WebSitesDataTableProps {
|
||||
data: WebSite[]
|
||||
columns: ColumnDef<WebSite>[]
|
||||
searchPlaceholder?: string
|
||||
searchColumn?: string
|
||||
searchValue?: string
|
||||
onSearch?: (value: string) => void
|
||||
// 智能过滤
|
||||
filterValue?: string
|
||||
onFilterChange?: (value: string) => void
|
||||
isSearching?: boolean
|
||||
pagination?: { pageIndex: number; pageSize: number }
|
||||
setPagination?: React.Dispatch<React.SetStateAction<{ pageIndex: number; pageSize: number }>>
|
||||
@@ -84,10 +96,8 @@ interface WebSitesDataTableProps {
|
||||
export function WebSitesDataTable({
|
||||
data = [],
|
||||
columns,
|
||||
searchPlaceholder = "搜索主机名...",
|
||||
searchColumn = "url",
|
||||
searchValue,
|
||||
onSearch,
|
||||
filterValue,
|
||||
onFilterChange,
|
||||
isSearching = false,
|
||||
pagination,
|
||||
setPagination,
|
||||
@@ -109,24 +119,10 @@ export function WebSitesDataTable({
|
||||
pageSize: 10,
|
||||
})
|
||||
|
||||
// 本地搜索输入状态(只在回车或点击按钮时触发搜索)
|
||||
const [localSearchValue, setLocalSearchValue] = React.useState(searchValue ?? "")
|
||||
|
||||
React.useEffect(() => {
|
||||
setLocalSearchValue(searchValue ?? "")
|
||||
}, [searchValue])
|
||||
|
||||
const handleSearchSubmit = () => {
|
||||
if (onSearch) {
|
||||
onSearch(localSearchValue)
|
||||
} else {
|
||||
table.getColumn(searchColumn)?.setFilterValue(localSearchValue)
|
||||
}
|
||||
}
|
||||
|
||||
const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
|
||||
if (e.key === 'Enter') {
|
||||
handleSearchSubmit()
|
||||
// 处理智能过滤搜索
|
||||
const handleSmartSearch = (_filters: any[], rawQuery: string) => {
|
||||
if (onFilterChange) {
|
||||
onFilterChange(rawQuery)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,23 +184,14 @@ export function WebSitesDataTable({
|
||||
<div className="w-full space-y-4">
|
||||
{/* 工具栏 */}
|
||||
<div className="flex items-center justify-between">
|
||||
{/* 搜索框 */}
|
||||
<div className="flex items-center space-x-2">
|
||||
<Input
|
||||
placeholder={searchPlaceholder}
|
||||
value={localSearchValue}
|
||||
onChange={(e) => setLocalSearchValue(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
className="h-8 max-w-sm"
|
||||
/>
|
||||
<Button variant="outline" size="sm" onClick={handleSearchSubmit} disabled={isSearching}>
|
||||
{isSearching ? (
|
||||
<IconLoader2 className="h-4 w-4 animate-spin" />
|
||||
) : (
|
||||
<IconSearch className="h-4 w-4" />
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
{/* 智能过滤搜索框 */}
|
||||
<SmartFilterInput
|
||||
fields={WEBSITE_FILTER_FIELDS}
|
||||
examples={WEBSITE_FILTER_EXAMPLES}
|
||||
value={filterValue}
|
||||
onSearch={handleSmartSearch}
|
||||
className="flex-1 max-w-xl"
|
||||
/>
|
||||
|
||||
{/* 右侧操作按钮 */}
|
||||
<div className="flex items-center space-x-2">
|
||||
|
||||
@@ -28,15 +28,15 @@ export function WebSitesView({
|
||||
const [selectedWebSites, setSelectedWebSites] = useState<WebSite[]>([])
|
||||
const [bulkAddDialogOpen, setBulkAddDialogOpen] = useState(false)
|
||||
|
||||
const [searchQuery, setSearchQuery] = useState("")
|
||||
const [filterQuery, setFilterQuery] = useState("")
|
||||
const [isSearching, setIsSearching] = useState(false)
|
||||
|
||||
// 获取目标信息(用于 URL 匹配校验)
|
||||
const { data: target } = useTarget(targetId || 0, { enabled: !!targetId })
|
||||
|
||||
const handleSearchChange = (value: string) => {
|
||||
const handleFilterChange = (value: string) => {
|
||||
setIsSearching(true)
|
||||
setSearchQuery(value)
|
||||
setFilterQuery(value)
|
||||
setPagination((prev) => ({ ...prev, pageIndex: 0 }))
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ export function WebSitesView({
|
||||
{
|
||||
page: pagination.pageIndex + 1,
|
||||
pageSize: pagination.pageSize,
|
||||
search: searchQuery || undefined,
|
||||
filter: filterQuery || undefined,
|
||||
},
|
||||
{ enabled: !!targetId }
|
||||
)
|
||||
@@ -55,7 +55,7 @@ export function WebSitesView({
|
||||
{
|
||||
page: pagination.pageIndex + 1,
|
||||
pageSize: pagination.pageSize,
|
||||
search: searchQuery || undefined,
|
||||
filter: filterQuery || undefined,
|
||||
},
|
||||
{ enabled: !!scanId }
|
||||
)
|
||||
@@ -248,10 +248,8 @@ export function WebSitesView({
|
||||
<WebSitesDataTable
|
||||
data={websites}
|
||||
columns={columns}
|
||||
searchPlaceholder="搜索主机名..."
|
||||
searchColumn="url"
|
||||
searchValue={searchQuery}
|
||||
onSearch={handleSearchChange}
|
||||
filterValue={filterQuery}
|
||||
onFilterChange={handleFilterChange}
|
||||
isSearching={isSearching}
|
||||
pagination={pagination}
|
||||
setPagination={setPagination}
|
||||
|
||||
@@ -8,11 +8,11 @@ const directoryService = {
|
||||
// 获取目标的目录列表
|
||||
getTargetDirectories: async (
|
||||
targetId: number,
|
||||
params: { page: number; pageSize: number; search?: string }
|
||||
params: { page: number; pageSize: number; filter?: string }
|
||||
): Promise<DirectoryListResponse> => {
|
||||
const searchParam = params.search ? `&search=${encodeURIComponent(params.search)}` : ''
|
||||
const filterParam = params.filter ? `&filter=${encodeURIComponent(params.filter)}` : ''
|
||||
const response = await fetch(
|
||||
`/api/targets/${targetId}/directories/?page=${params.page}&pageSize=${params.pageSize}${searchParam}`
|
||||
`/api/targets/${targetId}/directories/?page=${params.page}&pageSize=${params.pageSize}${filterParam}`
|
||||
)
|
||||
if (!response.ok) {
|
||||
throw new Error('获取目录列表失败')
|
||||
@@ -23,11 +23,11 @@ const directoryService = {
|
||||
// 获取扫描的目录列表
|
||||
getScanDirectories: async (
|
||||
scanId: number,
|
||||
params: { page: number; pageSize: number; search?: string }
|
||||
params: { page: number; pageSize: number; filter?: string }
|
||||
): Promise<DirectoryListResponse> => {
|
||||
const searchParam = params.search ? `&search=${encodeURIComponent(params.search)}` : ''
|
||||
const filterParam = params.filter ? `&filter=${encodeURIComponent(params.filter)}` : ''
|
||||
const response = await fetch(
|
||||
`/api/scans/${scanId}/directories/?page=${params.page}&pageSize=${params.pageSize}${searchParam}`
|
||||
`/api/scans/${scanId}/directories/?page=${params.page}&pageSize=${params.pageSize}${filterParam}`
|
||||
)
|
||||
if (!response.ok) {
|
||||
throw new Error('获取目录列表失败')
|
||||
@@ -80,7 +80,7 @@ const directoryService = {
|
||||
// 获取目标的目录列表
|
||||
export function useTargetDirectories(
|
||||
targetId: number,
|
||||
params: { page: number; pageSize: number; search?: string },
|
||||
params: { page: number; pageSize: number; filter?: string },
|
||||
options?: { enabled?: boolean }
|
||||
) {
|
||||
return useQuery({
|
||||
@@ -94,7 +94,7 @@ export function useTargetDirectories(
|
||||
// 获取扫描的目录列表
|
||||
export function useScanDirectories(
|
||||
scanId: number,
|
||||
params: { page: number; pageSize: number; search?: string },
|
||||
params: { page: number; pageSize: number; filter?: string },
|
||||
options?: { enabled?: boolean }
|
||||
) {
|
||||
return useQuery({
|
||||
|
||||
@@ -61,7 +61,7 @@ export function useEndpoints(params?: GetEndpointsRequest) {
|
||||
}
|
||||
|
||||
// 根据目标ID获取 Endpoint 列表(使用专用路由)
|
||||
export function useEndpointsByTarget(targetId: number, params?: Omit<GetEndpointsRequest, 'targetId'>) {
|
||||
export function useEndpointsByTarget(targetId: number, params?: Omit<GetEndpointsRequest, 'targetId'>, filter?: string) {
|
||||
const defaultParams: GetEndpointsRequest = {
|
||||
page: 1,
|
||||
pageSize: 10,
|
||||
@@ -69,8 +69,8 @@ export function useEndpointsByTarget(targetId: number, params?: Omit<GetEndpoint
|
||||
}
|
||||
|
||||
return useQuery({
|
||||
queryKey: endpointKeys.byTarget(targetId, defaultParams),
|
||||
queryFn: () => EndpointService.getEndpointsByTargetId(targetId, defaultParams),
|
||||
queryKey: [...endpointKeys.byTarget(targetId, defaultParams), filter],
|
||||
queryFn: () => EndpointService.getEndpointsByTargetId(targetId, defaultParams, filter),
|
||||
select: (response) => {
|
||||
// RESTful 标准:直接返回数据
|
||||
return response as GetEndpointsResponse
|
||||
@@ -81,7 +81,7 @@ export function useEndpointsByTarget(targetId: number, params?: Omit<GetEndpoint
|
||||
}
|
||||
|
||||
// 根据扫描ID获取 Endpoint 列表(历史快照)
|
||||
export function useScanEndpoints(scanId: number, params?: Omit<GetEndpointsRequest, 'targetId'>, options?: { enabled?: boolean }) {
|
||||
export function useScanEndpoints(scanId: number, params?: Omit<GetEndpointsRequest, 'targetId'>, options?: { enabled?: boolean }, filter?: string) {
|
||||
const defaultParams: GetEndpointsRequest = {
|
||||
page: 1,
|
||||
pageSize: 10,
|
||||
@@ -89,8 +89,8 @@ export function useScanEndpoints(scanId: number, params?: Omit<GetEndpointsReque
|
||||
}
|
||||
|
||||
return useQuery({
|
||||
queryKey: endpointKeys.byScan(scanId, defaultParams),
|
||||
queryFn: () => EndpointService.getEndpointsByScanId(scanId, defaultParams),
|
||||
queryKey: [...endpointKeys.byScan(scanId, defaultParams), filter],
|
||||
queryFn: () => EndpointService.getEndpointsByScanId(scanId, defaultParams, filter),
|
||||
enabled: options?.enabled !== undefined ? options.enabled : !!scanId,
|
||||
select: (response: any) => {
|
||||
// 后端使用通用分页格式:results/total/page/pageSize/totalPages
|
||||
|
||||
@@ -16,7 +16,7 @@ function normalizeParams(params?: GetIPAddressesParams): Required<GetIPAddresses
|
||||
return {
|
||||
page: params?.page ?? 1,
|
||||
pageSize: params?.pageSize ?? 10,
|
||||
search: params?.search ?? "",
|
||||
filter: params?.filter ?? "",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -267,11 +267,11 @@ export function useAllSubdomains(
|
||||
// 获取目标的子域名列表
|
||||
export function useTargetSubdomains(
|
||||
targetId: number,
|
||||
params?: { page?: number; pageSize?: number; search?: string },
|
||||
params?: { page?: number; pageSize?: number; filter?: string },
|
||||
options?: { enabled?: boolean }
|
||||
) {
|
||||
return useQuery({
|
||||
queryKey: ['targets', targetId, 'subdomains', { page: params?.page, pageSize: params?.pageSize, search: params?.search }],
|
||||
queryKey: ['targets', targetId, 'subdomains', { page: params?.page, pageSize: params?.pageSize, filter: params?.filter }],
|
||||
queryFn: () => SubdomainService.getSubdomainsByTargetId(targetId, params),
|
||||
enabled: options?.enabled !== undefined ? options.enabled : !!targetId,
|
||||
placeholderData: keepPreviousData,
|
||||
@@ -281,11 +281,11 @@ export function useTargetSubdomains(
|
||||
// 获取扫描的子域名列表
|
||||
export function useScanSubdomains(
|
||||
scanId: number,
|
||||
params?: { page?: number; pageSize?: number; search?: string },
|
||||
params?: { page?: number; pageSize?: number; filter?: string },
|
||||
options?: { enabled?: boolean }
|
||||
) {
|
||||
return useQuery({
|
||||
queryKey: ['scans', scanId, 'subdomains', { page: params?.page, pageSize: params?.pageSize, search: params?.search }],
|
||||
queryKey: ['scans', scanId, 'subdomains', { page: params?.page, pageSize: params?.pageSize, filter: params?.filter }],
|
||||
queryFn: () => SubdomainService.getSubdomainsByScanId(scanId, params),
|
||||
enabled: options?.enabled !== undefined ? options.enabled : !!scanId,
|
||||
placeholderData: keepPreviousData,
|
||||
|
||||
@@ -267,7 +267,7 @@ export function useTargetEndpoints(
|
||||
params?: {
|
||||
page?: number
|
||||
pageSize?: number
|
||||
search?: string
|
||||
filter?: string
|
||||
},
|
||||
options?: {
|
||||
enabled?: boolean
|
||||
@@ -277,9 +277,9 @@ export function useTargetEndpoints(
|
||||
queryKey: ['targets', 'detail', targetId, 'endpoints', {
|
||||
page: params?.page,
|
||||
pageSize: params?.pageSize,
|
||||
search: params?.search,
|
||||
filter: params?.filter,
|
||||
}],
|
||||
queryFn: () => getTargetEndpoints(targetId, params?.page || 1, params?.pageSize || 10, params?.search),
|
||||
queryFn: () => getTargetEndpoints(targetId, params?.page || 1, params?.pageSize || 10, params?.filter),
|
||||
enabled: options?.enabled !== undefined ? options.enabled : !!targetId,
|
||||
select: (response: any) => {
|
||||
// 后端使用通用分页格式:results/total/page/pageSize/totalPages
|
||||
|
||||
@@ -12,18 +12,19 @@ import type { PaginationInfo } from "@/types/common.types"
|
||||
|
||||
export const vulnerabilityKeys = {
|
||||
all: ["vulnerabilities"] as const,
|
||||
list: (params: GetVulnerabilitiesParams) =>
|
||||
[...vulnerabilityKeys.all, "list", params] as const,
|
||||
byScan: (scanId: number, params: GetVulnerabilitiesParams) =>
|
||||
[...vulnerabilityKeys.all, "scan", scanId, params] as const,
|
||||
byTarget: (targetId: number, params: GetVulnerabilitiesParams) =>
|
||||
[...vulnerabilityKeys.all, "target", targetId, params] as const,
|
||||
list: (params: GetVulnerabilitiesParams, filter?: string) =>
|
||||
[...vulnerabilityKeys.all, "list", params, filter] as const,
|
||||
byScan: (scanId: number, params: GetVulnerabilitiesParams, filter?: string) =>
|
||||
[...vulnerabilityKeys.all, "scan", scanId, params, filter] as const,
|
||||
byTarget: (targetId: number, params: GetVulnerabilitiesParams, filter?: string) =>
|
||||
[...vulnerabilityKeys.all, "target", targetId, params, filter] as const,
|
||||
}
|
||||
|
||||
/** 获取所有漏洞 */
|
||||
export function useAllVulnerabilities(
|
||||
params?: GetVulnerabilitiesParams,
|
||||
options?: { enabled?: boolean },
|
||||
filter?: string,
|
||||
) {
|
||||
const defaultParams: GetVulnerabilitiesParams = {
|
||||
page: 1,
|
||||
@@ -32,8 +33,8 @@ export function useAllVulnerabilities(
|
||||
}
|
||||
|
||||
return useQuery({
|
||||
queryKey: vulnerabilityKeys.list(defaultParams),
|
||||
queryFn: () => VulnerabilityService.getAllVulnerabilities(defaultParams),
|
||||
queryKey: vulnerabilityKeys.list(defaultParams, filter),
|
||||
queryFn: () => VulnerabilityService.getAllVulnerabilities(defaultParams, filter),
|
||||
enabled: options?.enabled ?? true,
|
||||
select: (response: any) => {
|
||||
const items = (response?.results ?? []) as any[]
|
||||
@@ -93,6 +94,7 @@ export function useScanVulnerabilities(
|
||||
scanId: number,
|
||||
params?: GetVulnerabilitiesParams,
|
||||
options?: { enabled?: boolean },
|
||||
filter?: string,
|
||||
) {
|
||||
const defaultParams: GetVulnerabilitiesParams = {
|
||||
page: 1,
|
||||
@@ -101,9 +103,9 @@ export function useScanVulnerabilities(
|
||||
}
|
||||
|
||||
return useQuery({
|
||||
queryKey: vulnerabilityKeys.byScan(scanId, defaultParams),
|
||||
queryKey: vulnerabilityKeys.byScan(scanId, defaultParams, filter),
|
||||
queryFn: () =>
|
||||
VulnerabilityService.getVulnerabilitiesByScanId(scanId, defaultParams),
|
||||
VulnerabilityService.getVulnerabilitiesByScanId(scanId, defaultParams, filter),
|
||||
enabled: options?.enabled !== undefined ? options.enabled : !!scanId,
|
||||
select: (response: any) => {
|
||||
const items = (response?.results ?? []) as any[]
|
||||
@@ -163,6 +165,7 @@ export function useTargetVulnerabilities(
|
||||
targetId: number,
|
||||
params?: GetVulnerabilitiesParams,
|
||||
options?: { enabled?: boolean },
|
||||
filter?: string,
|
||||
) {
|
||||
const defaultParams: GetVulnerabilitiesParams = {
|
||||
page: 1,
|
||||
@@ -171,9 +174,9 @@ export function useTargetVulnerabilities(
|
||||
}
|
||||
|
||||
return useQuery({
|
||||
queryKey: vulnerabilityKeys.byTarget(targetId, defaultParams),
|
||||
queryKey: vulnerabilityKeys.byTarget(targetId, defaultParams, filter),
|
||||
queryFn: () =>
|
||||
VulnerabilityService.getVulnerabilitiesByTargetId(targetId, defaultParams),
|
||||
VulnerabilityService.getVulnerabilitiesByTargetId(targetId, defaultParams, filter),
|
||||
enabled: options?.enabled !== undefined ? options.enabled : !!targetId,
|
||||
select: (response: any) => {
|
||||
const items = (response?.results ?? []) as any[]
|
||||
|
||||
@@ -8,11 +8,11 @@ const websiteService = {
|
||||
// 获取目标的网站列表
|
||||
getTargetWebSites: async (
|
||||
targetId: number,
|
||||
params: { page: number; pageSize: number; search?: string }
|
||||
params: { page: number; pageSize: number; filter?: string }
|
||||
): Promise<WebSiteListResponse> => {
|
||||
const searchParam = params.search ? `&search=${encodeURIComponent(params.search)}` : ''
|
||||
const filterParam = params.filter ? `&filter=${encodeURIComponent(params.filter)}` : ''
|
||||
const response = await fetch(
|
||||
`/api/targets/${targetId}/websites/?page=${params.page}&pageSize=${params.pageSize}${searchParam}`
|
||||
`/api/targets/${targetId}/websites/?page=${params.page}&pageSize=${params.pageSize}${filterParam}`
|
||||
)
|
||||
if (!response.ok) {
|
||||
throw new Error('获取网站列表失败')
|
||||
@@ -23,11 +23,11 @@ const websiteService = {
|
||||
// 获取扫描的网站列表
|
||||
getScanWebSites: async (
|
||||
scanId: number,
|
||||
params: { page: number; pageSize: number; search?: string }
|
||||
params: { page: number; pageSize: number; filter?: string }
|
||||
): Promise<WebSiteListResponse> => {
|
||||
const searchParam = params.search ? `&search=${encodeURIComponent(params.search)}` : ''
|
||||
const filterParam = params.filter ? `&filter=${encodeURIComponent(params.filter)}` : ''
|
||||
const response = await fetch(
|
||||
`/api/scans/${scanId}/websites/?page=${params.page}&pageSize=${params.pageSize}${searchParam}`
|
||||
`/api/scans/${scanId}/websites/?page=${params.page}&pageSize=${params.pageSize}${filterParam}`
|
||||
)
|
||||
if (!response.ok) {
|
||||
throw new Error('获取网站列表失败')
|
||||
@@ -80,7 +80,7 @@ const websiteService = {
|
||||
// 获取目标的网站列表
|
||||
export function useTargetWebSites(
|
||||
targetId: number,
|
||||
params: { page: number; pageSize: number; search?: string },
|
||||
params: { page: number; pageSize: number; filter?: string },
|
||||
options?: { enabled?: boolean }
|
||||
) {
|
||||
return useQuery({
|
||||
@@ -94,7 +94,7 @@ export function useTargetWebSites(
|
||||
// 获取扫描的网站列表
|
||||
export function useScanWebSites(
|
||||
scanId: number,
|
||||
params: { page: number; pageSize: number; search?: string },
|
||||
params: { page: number; pageSize: number; filter?: string },
|
||||
options?: { enabled?: boolean }
|
||||
) {
|
||||
return useQuery({
|
||||
|
||||
@@ -59,12 +59,13 @@ export class EndpointService {
|
||||
* 根据目标ID获取 Endpoint 列表(专用路由)
|
||||
* @param targetId - 目标ID
|
||||
* @param params - 其他查询参数
|
||||
* @param filter - 智能过滤查询字符串
|
||||
* @returns Promise<GetEndpointsResponse>
|
||||
*/
|
||||
static async getEndpointsByTargetId(targetId: number, params: GetEndpointsRequest): Promise<GetEndpointsResponse> {
|
||||
static async getEndpointsByTargetId(targetId: number, params: GetEndpointsRequest, filter?: string): Promise<GetEndpointsResponse> {
|
||||
// api-client.ts 会自动将 params 对象的驼峰转换为下划线
|
||||
const response = await api.get<GetEndpointsResponse>(`/targets/${targetId}/endpoints/`, {
|
||||
params
|
||||
params: { ...params, filter }
|
||||
})
|
||||
return response.data
|
||||
}
|
||||
@@ -73,13 +74,15 @@ export class EndpointService {
|
||||
* 根据扫描ID获取 Endpoint 列表(历史快照)
|
||||
* @param scanId - 扫描任务 ID
|
||||
* @param params - 分页等查询参数
|
||||
* @param filter - 智能过滤查询字符串
|
||||
*/
|
||||
static async getEndpointsByScanId(
|
||||
scanId: number,
|
||||
params: GetEndpointsRequest,
|
||||
filter?: string,
|
||||
): Promise<any> {
|
||||
const response = await api.get(`/scans/${scanId}/endpoints/`, {
|
||||
params,
|
||||
params: { ...params, filter },
|
||||
})
|
||||
return response.data
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ export class IPAddressService {
|
||||
params: {
|
||||
page: params?.page || 1,
|
||||
pageSize: params?.pageSize || 10,
|
||||
...(params?.search && { search: params.search }),
|
||||
...(params?.filter && { filter: params.filter }),
|
||||
},
|
||||
})
|
||||
return response.data
|
||||
@@ -24,7 +24,7 @@ export class IPAddressService {
|
||||
params: {
|
||||
page: params?.page || 1,
|
||||
pageSize: params?.pageSize || 10,
|
||||
...(params?.search && { search: params.search }),
|
||||
...(params?.filter && { filter: params.filter }),
|
||||
},
|
||||
})
|
||||
return response.data
|
||||
|
||||
@@ -173,32 +173,32 @@ export class SubdomainService {
|
||||
return response.data
|
||||
}
|
||||
|
||||
/** 获取目标的子域名列表(支持分页和搜索) */
|
||||
/** 获取目标的子域名列表(支持分页和过滤) */
|
||||
static async getSubdomainsByTargetId(
|
||||
targetId: number,
|
||||
params?: {
|
||||
page?: number
|
||||
pageSize?: number
|
||||
search?: string
|
||||
filter?: string
|
||||
}
|
||||
): Promise<any> {
|
||||
const response = await api.get(`/targets/${targetId}/subdomains/`, {
|
||||
params: {
|
||||
page: params?.page || 1,
|
||||
pageSize: params?.pageSize || 10,
|
||||
...(params?.search && { search: params.search }),
|
||||
...(params?.filter && { filter: params.filter }),
|
||||
}
|
||||
})
|
||||
return response.data
|
||||
}
|
||||
|
||||
/** 获取扫描的子域名列表(支持分页) */
|
||||
/** 获取扫描的子域名列表(支持分页和过滤) */
|
||||
static async getSubdomainsByScanId(
|
||||
scanId: number,
|
||||
params?: {
|
||||
page?: number
|
||||
pageSize?: number
|
||||
search?: string
|
||||
filter?: string
|
||||
}
|
||||
): Promise<{
|
||||
results: Array<{
|
||||
@@ -225,7 +225,7 @@ export class SubdomainService {
|
||||
params: {
|
||||
page: params?.page || 1,
|
||||
pageSize: params?.pageSize || 10,
|
||||
...(params?.search && { search: params.search }),
|
||||
...(params?.filter && { filter: params.filter }),
|
||||
}
|
||||
})
|
||||
return response.data as any
|
||||
|
||||
@@ -136,13 +136,13 @@ export async function getTargetEndpoints(
|
||||
id: number,
|
||||
page = 1,
|
||||
pageSize = 10,
|
||||
search?: string
|
||||
filter?: string
|
||||
): Promise<any> {
|
||||
const response = await api.get(`/targets/${id}/endpoints/`, {
|
||||
params: {
|
||||
page,
|
||||
pageSize,
|
||||
...(search && { search }),
|
||||
...(filter && { filter }),
|
||||
},
|
||||
})
|
||||
return response.data
|
||||
|
||||
@@ -5,9 +5,10 @@ export class VulnerabilityService {
|
||||
/** 获取所有漏洞列表(全局漏洞页使用) */
|
||||
static async getAllVulnerabilities(
|
||||
params: GetVulnerabilitiesParams,
|
||||
filter?: string,
|
||||
): Promise<any> {
|
||||
const response = await api.get(`/assets/vulnerabilities/`, {
|
||||
params,
|
||||
params: { ...params, filter },
|
||||
})
|
||||
return response.data
|
||||
}
|
||||
@@ -16,9 +17,10 @@ export class VulnerabilityService {
|
||||
static async getVulnerabilitiesByScanId(
|
||||
scanId: number,
|
||||
params: GetVulnerabilitiesParams,
|
||||
filter?: string,
|
||||
): Promise<any> {
|
||||
const response = await api.get(`/scans/${scanId}/vulnerabilities/`, {
|
||||
params,
|
||||
params: { ...params, filter },
|
||||
})
|
||||
return response.data
|
||||
}
|
||||
@@ -27,9 +29,10 @@ export class VulnerabilityService {
|
||||
static async getVulnerabilitiesByTargetId(
|
||||
targetId: number,
|
||||
params: GetVulnerabilitiesParams,
|
||||
filter?: string,
|
||||
): Promise<any> {
|
||||
const response = await api.get(`/targets/${targetId}/vulnerabilities/`, {
|
||||
params,
|
||||
params: { ...params, filter },
|
||||
})
|
||||
return response.data
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ export interface IPAddress {
|
||||
export interface GetIPAddressesParams {
|
||||
page?: number
|
||||
pageSize?: number
|
||||
search?: string
|
||||
filter?: string // 智能过滤语法字符串
|
||||
}
|
||||
|
||||
export interface GetIPAddressesResponse {
|
||||
|
||||
@@ -69,6 +69,7 @@ export interface GetVulnerabilitiesParams extends PaginationParams {
|
||||
endpointId?: number
|
||||
severity?: VulnerabilitySeverity
|
||||
status?: VulnerabilityStatus
|
||||
filter?: string // 智能过滤语法
|
||||
}
|
||||
|
||||
// 获取漏洞列表响应
|
||||
|
||||
Reference in New Issue
Block a user