Compare commits

...

2 Commits

Author SHA1 Message Date
yyhuni
5acaada7ab 新增:支持多字段搜索功能 2025-12-25 09:54:50 +08:00
github-actions[bot]
aaad3f29cf chore: bump version to v1.1.6 2025-12-24 12:19:12 +00:00
49 changed files with 1490 additions and 658 deletions

View File

@@ -1 +1 @@
v1.1.5
v1.1.6

View File

@@ -135,6 +135,7 @@ class Endpoint(models.Model):
models.Index(fields=['url']), # URL索引优化查询性能
models.Index(fields=['host']), # host索引优化根据主机名查询
models.Index(fields=['status_code']), # 状态码索引,优化筛选
models.Index(fields=['title']), # title索引优化智能过滤搜索
]
constraints = [
# 普通唯一约束url + target 组合唯一
@@ -229,6 +230,8 @@ class WebSite(models.Model):
models.Index(fields=['url']), # URL索引优化查询性能
models.Index(fields=['host']), # host索引优化根据主机名查询
models.Index(fields=['target']), # 优化从target_id快速查找下面的站点
models.Index(fields=['title']), # title索引优化智能过滤搜索
models.Index(fields=['status_code']), # 状态码索引,优化智能过滤搜索
]
constraints = [
# 普通唯一约束url + target 组合唯一
@@ -408,7 +411,7 @@ class Vulnerability(models.Model):
)
# ==================== 核心字段 ====================
url = models.TextField(help_text='漏洞所在的URL')
url = models.CharField(max_length=2000, help_text='漏洞所在的URL')
vuln_type = models.CharField(max_length=100, help_text='漏洞类型(如 xss, sqli')
severity = models.CharField(
max_length=20,
@@ -434,6 +437,7 @@ class Vulnerability(models.Model):
models.Index(fields=['vuln_type']),
models.Index(fields=['severity']),
models.Index(fields=['source']),
models.Index(fields=['url']), # url索引优化智能过滤搜索
models.Index(fields=['-created_at']),
]

View File

@@ -81,6 +81,7 @@ class WebsiteSnapshot(models.Model):
models.Index(fields=['scan']),
models.Index(fields=['url']),
models.Index(fields=['host']), # host索引优化根据主机名查询
models.Index(fields=['title']), # title索引优化标题搜索
models.Index(fields=['-created_at']),
]
constraints = [
@@ -129,6 +130,7 @@ class DirectorySnapshot(models.Model):
models.Index(fields=['scan']),
models.Index(fields=['url']),
models.Index(fields=['status']), # 状态码索引,优化筛选
models.Index(fields=['content_type']), # content_type索引优化内容类型搜索
models.Index(fields=['-created_at']),
]
constraints = [
@@ -268,7 +270,9 @@ class EndpointSnapshot(models.Model):
models.Index(fields=['scan']),
models.Index(fields=['url']),
models.Index(fields=['host']), # host索引优化根据主机名查询
models.Index(fields=['title']), # title索引优化标题搜索
models.Index(fields=['status_code']), # 状态码索引,优化筛选
models.Index(fields=['webserver']), # webserver索引优化服务器搜索
models.Index(fields=['-created_at']),
]
constraints = [
@@ -302,7 +306,7 @@ class VulnerabilitySnapshot(models.Model):
)
# ==================== 核心字段 ====================
url = models.TextField(help_text='漏洞所在的URL')
url = models.CharField(max_length=2000, help_text='漏洞所在的URL')
vuln_type = models.CharField(max_length=100, help_text='漏洞类型(如 xss, sqli')
severity = models.CharField(
max_length=20,
@@ -325,6 +329,7 @@ class VulnerabilitySnapshot(models.Model):
ordering = ['-created_at']
indexes = [
models.Index(fields=['scan']),
models.Index(fields=['url']), # url索引优化URL搜索
models.Index(fields=['vuln_type']),
models.Index(fields=['severity']),
models.Index(fields=['source']),

View File

@@ -1,7 +1,9 @@
"""HostPortMapping Repository - Django ORM 实现"""
import logging
from typing import List, Iterator
from typing import List, Iterator, Dict, Optional
from django.db.models import QuerySet, Min
from apps.asset.models.asset_models import HostPortMapping
from apps.asset.dtos.asset import HostPortMappingDTO
@@ -13,7 +15,10 @@ logger = logging.getLogger(__name__)
@auto_ensure_db_connection
class DjangoHostPortMappingRepository:
"""HostPortMapping Repository - Django ORM 实现"""
"""HostPortMapping Repository - Django ORM 实现
职责:纯数据访问,不包含业务逻辑
"""
def bulk_create_ignore_conflicts(self, items: List[HostPortMappingDTO]) -> int:
"""
@@ -90,72 +95,20 @@ class DjangoHostPortMappingRepository:
for ip in queryset:
yield ip
def get_ip_aggregation_by_target(self, target_id: int, search: str = None):
from django.db.models import Min
def get_queryset_by_target(self, target_id: int) -> QuerySet:
"""获取目标下的 QuerySet"""
return HostPortMapping.objects.filter(target_id=target_id)
qs = HostPortMapping.objects.filter(target_id=target_id)
if search:
qs = qs.filter(ip__icontains=search)
def get_all_queryset(self) -> QuerySet:
"""获取所有记录的 QuerySet"""
return HostPortMapping.objects.all()
ip_aggregated = (
qs
.values('ip')
.annotate(created_at=Min('created_at'))
.order_by('-created_at')
)
results = []
for item in ip_aggregated:
ip = item['ip']
mappings = (
HostPortMapping.objects
.filter(target_id=target_id, ip=ip)
.values('host', 'port')
.distinct()
)
hosts = sorted({m['host'] for m in mappings})
ports = sorted({m['port'] for m in mappings})
results.append({
'ip': ip,
'hosts': hosts,
'ports': ports,
'created_at': item['created_at'],
})
return results
def get_all_ip_aggregation(self, search: str = None):
"""获取所有 IP 聚合数据(全局查询)"""
from django.db.models import Min
qs = HostPortMapping.objects.all()
if search:
qs = qs.filter(ip__icontains=search)
ip_aggregated = (
qs
.values('ip')
.annotate(created_at=Min('created_at'))
.order_by('-created_at')
)
results = []
for item in ip_aggregated:
ip = item['ip']
mappings = (
HostPortMapping.objects
.filter(ip=ip)
.values('host', 'port')
.distinct()
)
hosts = sorted({m['host'] for m in mappings})
ports = sorted({m['port'] for m in mappings})
results.append({
'ip': ip,
'hosts': hosts,
'ports': ports,
'created_at': item['created_at'],
})
return results
def get_queryset_by_ip(self, ip: str, target_id: Optional[int] = None) -> QuerySet:
"""获取指定 IP 的 QuerySet"""
qs = HostPortMapping.objects.filter(ip=ip)
if target_id:
qs = qs.filter(target_id=target_id)
return qs
def iter_raw_data_for_export(
self,

View File

@@ -65,12 +65,20 @@ class DjangoHostPortMappingSnapshotRepository:
)
raise
def get_ip_aggregation_by_scan(self, scan_id: int, search: str = None):
def get_ip_aggregation_by_scan(self, scan_id: int, filter_query: str = None):
from django.db.models import Min
from apps.common.utils.filter_utils import apply_filters
qs = HostPortMappingSnapshot.objects.filter(scan_id=scan_id)
if search:
qs = qs.filter(ip__icontains=search)
# 应用智能过滤
if filter_query:
field_mapping = {
'ip': 'ip',
'port': 'port',
'host': 'host',
}
qs = apply_filters(qs, filter_query, field_mapping)
ip_aggregated = (
qs
@@ -103,13 +111,21 @@ class DjangoHostPortMappingSnapshotRepository:
return results
def get_all_ip_aggregation(self, search: str = None):
def get_all_ip_aggregation(self, filter_query: str = None):
"""获取所有 IP 聚合数据"""
from django.db.models import Min
from apps.common.utils.filter_utils import apply_filters
qs = HostPortMappingSnapshot.objects.all()
if search:
qs = qs.filter(ip__icontains=search)
# 应用智能过滤
if filter_query:
field_mapping = {
'ip': 'ip',
'port': 'port',
'host': 'host',
}
qs = apply_filters(qs, filter_query, field_mapping)
ip_aggregated = (
qs

View File

@@ -1,11 +1,12 @@
"""Directory Service - 目录业务逻辑层"""
import logging
from typing import List, Iterator
from typing import List, Iterator, Optional
from apps.asset.repositories import DjangoDirectoryRepository
from apps.asset.dtos import DirectoryDTO
from apps.common.validators import is_valid_url, is_url_match_target
from apps.common.utils.filter_utils import apply_filters
logger = logging.getLogger(__name__)
@@ -13,6 +14,12 @@ logger = logging.getLogger(__name__)
class DirectoryService:
"""目录业务逻辑层"""
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'url': 'url',
'status': 'status',
}
def __init__(self, repository=None):
"""初始化目录服务"""
self.repo = repository or DjangoDirectoryRepository()
@@ -94,13 +101,19 @@ class DirectoryService:
count_after = self.repo.count_by_target(target_id)
return count_after - count_before
def get_directories_by_target(self, target_id: int):
def get_directories_by_target(self, target_id: int, filter_query: Optional[str] = None):
"""获取目标下的所有目录"""
return self.repo.get_by_target(target_id)
queryset = self.repo.get_by_target(target_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_all(self):
def get_all(self, filter_query: Optional[str] = None):
"""获取所有目录"""
return self.repo.get_all()
queryset = self.repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def iter_directory_urls_by_target(self, target_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取目标下的所有目录 URL"""

View File

@@ -5,11 +5,12 @@ Endpoint 服务层
"""
import logging
from typing import List, Iterator
from typing import List, Iterator, Optional
from apps.asset.dtos.asset import EndpointDTO
from apps.asset.repositories.asset import DjangoEndpointRepository
from apps.common.validators import is_valid_url, is_url_match_target
from apps.common.utils.filter_utils import apply_filters
logger = logging.getLogger(__name__)
@@ -21,6 +22,14 @@ class EndpointService:
提供 EndpointURL/端点)相关的业务逻辑
"""
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'url': 'url',
'host': 'host',
'title': 'title',
'status': 'status_code',
}
def __init__(self):
"""初始化 Endpoint 服务"""
self.repo = DjangoEndpointRepository()
@@ -102,9 +111,12 @@ class EndpointService:
count_after = self.repo.count_by_target(target_id)
return count_after - count_before
def get_endpoints_by_target(self, target_id: int):
def get_endpoints_by_target(self, target_id: int, filter_query: Optional[str] = None):
"""获取目标下的所有端点"""
return self.repo.get_by_target(target_id)
queryset = self.repo.get_by_target(target_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def count_endpoints_by_target(self, target_id: int) -> int:
"""
@@ -118,9 +130,12 @@ class EndpointService:
"""
return self.repo.count_by_target(target_id)
def get_all(self):
def get_all(self, filter_query: Optional[str] = None):
"""获取所有端点(全局查询)"""
return self.repo.get_all()
queryset = self.repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def iter_endpoint_urls_by_target(self, target_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取目标下的所有端点 URL用于导出。"""

View File

@@ -1,16 +1,31 @@
"""HostPortMapping Service - 业务逻辑层"""
import logging
from typing import List, Iterator
from typing import List, Iterator, Optional, Dict
from django.db.models import Min
from apps.asset.repositories.asset import DjangoHostPortMappingRepository
from apps.asset.dtos.asset import HostPortMappingDTO
from apps.common.utils.filter_utils import apply_filters
logger = logging.getLogger(__name__)
class HostPortMappingService:
"""主机端口映射服务 - 负责主机端口映射数据的业务逻辑"""
"""主机端口映射服务 - 负责主机端口映射数据的业务逻辑
职责:
- 业务逻辑处理(过滤、聚合)
- 调用 Repository 进行数据访问
"""
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'ip': 'ip',
'port': 'port',
'host': 'host',
}
def __init__(self):
self.repo = DjangoHostPortMappingRepository()
@@ -49,12 +64,93 @@ class HostPortMappingService:
def iter_host_port_by_target(self, target_id: int, batch_size: int = 1000):
return self.repo.get_for_export(target_id=target_id, batch_size=batch_size)
def get_ip_aggregation_by_target(self, target_id: int, search: str = None):
return self.repo.get_ip_aggregation_by_target(target_id, search=search)
def get_ip_aggregation_by_target(
self,
target_id: int,
filter_query: Optional[str] = None
) -> List[Dict]:
"""获取目标下的 IP 聚合数据
Args:
target_id: 目标 ID
filter_query: 智能过滤语法字符串
Returns:
聚合后的 IP 数据列表
"""
# 从 Repository 获取基础 QuerySet
qs = self.repo.get_queryset_by_target(target_id)
# Service 层应用过滤逻辑
if filter_query:
qs = apply_filters(qs, filter_query, self.FILTER_FIELD_MAPPING)
# Service 层处理聚合逻辑
return self._aggregate_by_ip(qs, filter_query, target_id=target_id)
def get_all_ip_aggregation(self, search: str = None):
"""获取所有 IP 聚合数据(全局查询)"""
return self.repo.get_all_ip_aggregation(search=search)
def get_all_ip_aggregation(self, filter_query: Optional[str] = None) -> List[Dict]:
"""获取所有 IP 聚合数据(全局查询)
Args:
filter_query: 智能过滤语法字符串
Returns:
聚合后的 IP 数据列表
"""
# 从 Repository 获取基础 QuerySet
qs = self.repo.get_all_queryset()
# Service 层应用过滤逻辑
if filter_query:
qs = apply_filters(qs, filter_query, self.FILTER_FIELD_MAPPING)
# Service 层处理聚合逻辑
return self._aggregate_by_ip(qs, filter_query)
def _aggregate_by_ip(
self,
qs,
filter_query: Optional[str] = None,
target_id: Optional[int] = None
) -> List[Dict]:
"""按 IP 聚合数据
Args:
qs: 已过滤的 QuerySet
filter_query: 过滤条件(用于子查询)
target_id: 目标 ID用于子查询限定范围
Returns:
聚合后的数据列表
"""
ip_aggregated = (
qs
.values('ip')
.annotate(created_at=Min('created_at'))
.order_by('-created_at')
)
results = []
for item in ip_aggregated:
ip = item['ip']
# 获取该 IP 的所有 host 和 port也需要应用过滤条件
mappings_qs = self.repo.get_queryset_by_ip(ip, target_id=target_id)
if filter_query:
mappings_qs = apply_filters(mappings_qs, filter_query, self.FILTER_FIELD_MAPPING)
mappings = mappings_qs.values('host', 'port').distinct()
hosts = sorted({m['host'] for m in mappings})
ports = sorted({m['port'] for m in mappings})
results.append({
'ip': ip,
'hosts': hosts,
'ports': ports,
'created_at': item['created_at'],
})
return results
def iter_ips_by_target(self, target_id: int, batch_size: int = 1000) -> Iterator[str]:
"""流式获取目标下的所有唯一 IP 地址。"""

View File

@@ -1,10 +1,11 @@
import logging
from typing import Tuple, List, Dict
from typing import List, Dict, Optional
from dataclasses import dataclass
from apps.asset.repositories import DjangoSubdomainRepository
from apps.asset.dtos import SubdomainDTO
from apps.common.validators import is_valid_domain
from apps.common.utils.filter_utils import apply_filters
logger = logging.getLogger(__name__)
@@ -22,6 +23,11 @@ class BulkCreateResult:
class SubdomainService:
"""子域名业务逻辑层"""
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'name': 'name',
}
def __init__(self, repository=None):
"""
初始化子域名服务
@@ -33,30 +39,50 @@ class SubdomainService:
# ==================== 查询操作 ====================
def get_all(self):
def get_all(self, filter_query: Optional[str] = None):
"""
获取所有子域名
Args:
filter_query: 智能过滤语法字符串
Returns:
QuerySet: 子域名查询集
"""
logger.debug("获取所有子域名")
return self.repo.get_all()
queryset = self.repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
# ==================== 创建操作 ====================
def bulk_create_ignore_conflicts(self, items: List[SubdomainDTO]) -> None:
def get_subdomains_by_target(self, target_id: int, filter_query: Optional[str] = None):
"""
批量创建子域名,忽略冲突
获取目标下的子域名
Args:
items: 子域名 DTO 列表
target_id: 目标 ID
filter_query: 智能过滤语法字符串
Note:
使用 ignore_conflicts 策略,重复记录会被跳过
Returns:
QuerySet: 子域名查询集
"""
logger.debug("批量创建子域名 - 数量: %d", len(items))
return self.repo.bulk_create_ignore_conflicts(items)
queryset = self.repo.get_by_target(target_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def count_subdomains_by_target(self, target_id: int) -> int:
"""
统计目标下的子域名数量
Args:
target_id: 目标 ID
Returns:
int: 子域名数量
"""
logger.debug("统计目标下子域名数量 - Target ID: %d", target_id)
return self.repo.count_by_target(target_id)
def get_by_names_and_target_id(self, names: set, target_id: int) -> dict:
"""
@@ -83,25 +109,8 @@ class SubdomainService:
List[str]: 子域名名称列表
"""
logger.debug("获取目标下所有子域名 - Target ID: %d", target_id)
# 通过仓储层统一访问数据库,内部已使用 iterator() 做流式查询
return list(self.repo.get_domains_for_export(target_id=target_id))
def get_subdomains_by_target(self, target_id: int):
return self.repo.get_by_target(target_id)
def count_subdomains_by_target(self, target_id: int) -> int:
"""
统计目标下的子域名数量
Args:
target_id: 目标 ID
Returns:
int: 子域名数量
"""
logger.debug("统计目标下子域名数量 - Target ID: %d", target_id)
return self.repo.count_by_target(target_id)
def iter_subdomain_names_by_target(self, target_id: int, chunk_size: int = 1000):
"""
流式获取目标下的所有子域名名称(内存优化)
@@ -114,7 +123,6 @@ class SubdomainService:
str: 子域名名称
"""
logger.debug("流式获取目标下所有子域名 - Target ID: %d, 批次大小: %d", target_id, chunk_size)
# 通过仓储层统一访问数据库,内部已使用 iterator() 做流式查询
return self.repo.get_domains_for_export(target_id=target_id, batch_size=chunk_size)
def iter_raw_data_for_csv_export(self, target_id: int):
@@ -129,6 +137,21 @@ class SubdomainService:
"""
return self.repo.iter_raw_data_for_export(target_id=target_id)
# ==================== 创建操作 ====================
def bulk_create_ignore_conflicts(self, items: List[SubdomainDTO]) -> None:
"""
批量创建子域名,忽略冲突
Args:
items: 子域名 DTO 列表
Note:
使用 ignore_conflicts 策略,重复记录会被跳过
"""
logger.debug("批量创建子域名 - 数量: %d", len(items))
return self.repo.bulk_create_ignore_conflicts(items)
def bulk_create_subdomains(
self,
target_id: int,

View File

@@ -1,12 +1,13 @@
"""Vulnerability Service - 漏洞资产业务逻辑层"""
import logging
from typing import List
from typing import List, Optional
from apps.asset.models import Vulnerability
from apps.asset.dtos.asset import VulnerabilityDTO
from apps.common.decorators import auto_ensure_db_connection
from apps.common.utils import deduplicate_for_bulk
from apps.common.utils.filter_utils import apply_filters
logger = logging.getLogger(__name__)
@@ -17,6 +18,14 @@ class VulnerabilityService:
当前提供基础的批量创建能力,使用 ignore_conflicts 依赖数据库唯一约束去重。
"""
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'type': 'vuln_type',
'severity': 'severity',
'source': 'source',
'url': 'url',
}
def bulk_create_ignore_conflicts(self, items: List[VulnerabilityDTO]) -> None:
"""批量创建漏洞资产记录,忽略冲突。
@@ -63,24 +72,34 @@ class VulnerabilityService:
# ==================== 查询方法 ====================
def get_all(self):
def get_all(self, filter_query: Optional[str] = None):
"""获取所有漏洞 QuerySet用于全局漏洞列表
Args:
filter_query: 智能过滤语法字符串
Returns:
QuerySet[Vulnerability]: 所有漏洞,按创建时间倒序
"""
return Vulnerability.objects.all().order_by("-created_at")
queryset = Vulnerability.objects.all().order_by("-created_at")
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_vulnerabilities_by_target(self, target_id: int):
def get_vulnerabilities_by_target(self, target_id: int, filter_query: Optional[str] = None):
"""按目标获取漏洞 QuerySet用于分页
Args:
target_id: 目标 ID
filter_query: 智能过滤语法字符串
Returns:
QuerySet[Vulnerability]: 目标下的所有漏洞,按创建时间倒序
"""
return Vulnerability.objects.filter(target_id=target_id).order_by("-created_at")
queryset = Vulnerability.objects.filter(target_id=target_id).order_by("-created_at")
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def count_by_target(self, target_id: int) -> int:
"""统计目标下的漏洞数量。"""

View File

@@ -1,11 +1,12 @@
"""WebSite Service - 网站业务逻辑层"""
import logging
from typing import List, Iterator
from typing import List, Iterator, Optional
from apps.asset.repositories import DjangoWebSiteRepository
from apps.asset.dtos import WebSiteDTO
from apps.common.validators import is_valid_url, is_url_match_target
from apps.common.utils.filter_utils import apply_filters
logger = logging.getLogger(__name__)
@@ -13,6 +14,14 @@ logger = logging.getLogger(__name__)
class WebSiteService:
"""网站业务逻辑层"""
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'url': 'url',
'host': 'host',
'title': 'title',
'status': 'status_code',
}
def __init__(self, repository=None):
"""初始化网站服务"""
self.repo = repository or DjangoWebSiteRepository()
@@ -94,13 +103,19 @@ class WebSiteService:
count_after = self.repo.count_by_target(target_id)
return count_after - count_before
def get_websites_by_target(self, target_id: int):
def get_websites_by_target(self, target_id: int, filter_query: Optional[str] = None):
"""获取目标下的所有网站"""
return self.repo.get_by_target(target_id)
queryset = self.repo.get_by_target(target_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_all(self):
def get_all(self, filter_query: Optional[str] = None):
"""获取所有网站"""
return self.repo.get_all()
queryset = self.repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_by_url(self, url: str, target_id: int) -> int:
"""根据 URL 和 target_id 查找网站 ID"""

View File

@@ -67,12 +67,29 @@ class DirectorySnapshotsService:
)
raise
def get_by_scan(self, scan_id: int):
return self.snapshot_repo.get_by_scan(scan_id)
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'url': 'url',
'status': 'status',
'content_type': 'content_type',
}
def get_by_scan(self, scan_id: int, filter_query: str = None):
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_by_scan(scan_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_all(self):
def get_all(self, filter_query: str = None):
"""获取所有目录快照"""
return self.snapshot_repo.get_all()
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def iter_directory_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取某次扫描下的所有目录 URL。"""

View File

@@ -67,12 +67,32 @@ class EndpointSnapshotsService:
)
raise
def get_by_scan(self, scan_id: int):
return self.snapshot_repo.get_by_scan(scan_id)
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'url': 'url',
'host': 'host',
'title': 'title',
'status': 'status_code',
'webserver': 'webserver',
'tech': 'tech',
}
def get_by_scan(self, scan_id: int, filter_query: str = None):
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_by_scan(scan_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_all(self):
def get_all(self, filter_query: str = None):
"""获取所有端点快照"""
return self.snapshot_repo.get_all()
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def iter_endpoint_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取某次扫描下的所有端点 URL。"""

View File

@@ -69,12 +69,12 @@ class HostPortMappingSnapshotsService:
)
raise
def get_ip_aggregation_by_scan(self, scan_id: int, search: str = None):
return self.snapshot_repo.get_ip_aggregation_by_scan(scan_id, search=search)
def get_ip_aggregation_by_scan(self, scan_id: int, filter_query: str = None):
return self.snapshot_repo.get_ip_aggregation_by_scan(scan_id, filter_query=filter_query)
def get_all_ip_aggregation(self, search: str = None):
def get_all_ip_aggregation(self, filter_query: str = None):
"""获取所有 IP 聚合数据"""
return self.snapshot_repo.get_all_ip_aggregation(search=search)
return self.snapshot_repo.get_all_ip_aggregation(filter_query=filter_query)
def iter_ips_by_scan(self, scan_id: int, batch_size: int = 1000) -> Iterator[str]:
"""流式获取某次扫描下的所有唯一 IP 地址。"""

View File

@@ -66,12 +66,27 @@ class SubdomainSnapshotsService:
)
raise
def get_by_scan(self, scan_id: int):
return self.subdomain_snapshot_repo.get_by_scan(scan_id)
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'name': 'name',
}
def get_by_scan(self, scan_id: int, filter_query: str = None):
from apps.common.utils.filter_utils import apply_filters
queryset = self.subdomain_snapshot_repo.get_by_scan(scan_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_all(self):
def get_all(self, filter_query: str = None):
"""获取所有子域名快照"""
return self.subdomain_snapshot_repo.get_all()
from apps.common.utils.filter_utils import apply_filters
queryset = self.subdomain_snapshot_repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def iter_subdomain_names_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
queryset = self.subdomain_snapshot_repo.get_by_scan(scan_id)

View File

@@ -66,13 +66,31 @@ class VulnerabilitySnapshotsService:
)
raise
def get_by_scan(self, scan_id: int):
"""按扫描任务获取所有漏洞快照。"""
return self.snapshot_repo.get_by_scan(scan_id)
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'type': 'vuln_type',
'url': 'url',
'severity': 'severity',
'source': 'source',
}
def get_all(self):
def get_by_scan(self, scan_id: int, filter_query: str = None):
"""按扫描任务获取所有漏洞快照。"""
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_by_scan(scan_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_all(self, filter_query: str = None):
"""获取所有漏洞快照"""
return self.snapshot_repo.get_all()
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def iter_vuln_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取某次扫描下的所有漏洞 URL。"""

View File

@@ -68,12 +68,32 @@ class WebsiteSnapshotsService:
)
raise
def get_by_scan(self, scan_id: int):
return self.snapshot_repo.get_by_scan(scan_id)
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'url': 'url',
'host': 'host',
'title': 'title',
'status': 'status',
'webserver': 'web_server',
'tech': 'tech',
}
def get_by_scan(self, scan_id: int, filter_query: str = None):
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_by_scan(scan_id)
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def get_all(self):
def get_all(self, filter_query: str = None):
"""获取所有网站快照"""
return self.snapshot_repo.get_all()
from apps.common.utils.filter_utils import apply_filters
queryset = self.snapshot_repo.get_all()
if filter_query:
queryset = apply_filters(queryset, filter_query, self.FILTER_FIELD_MAPPING)
return queryset
def iter_website_urls_by_scan(self, scan_id: int, chunk_size: int = 1000) -> Iterator[str]:
"""流式获取某次扫描下的所有站点 URL按创建时间倒序"""

View File

@@ -126,12 +126,16 @@ class SubdomainViewSet(viewsets.ModelViewSet):
支持两种访问方式:
1. 嵌套路由GET /api/targets/{target_pk}/subdomains/
2. 独立路由GET /api/subdomains/(全局查询)
支持智能过滤语法filter 参数):
- name="api" 子域名模糊匹配
- name=="api.example.com" 精确匹配
- 多条件空格分隔 AND 关系
"""
serializer_class = SubdomainListSerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['name']
filter_backends = [filters.OrderingFilter]
ordering = ['-created_at']
def __init__(self, **kwargs):
@@ -139,11 +143,13 @@ class SubdomainViewSet(viewsets.ModelViewSet):
self.service = SubdomainService()
def get_queryset(self):
"""根据是否有 target_pk 参数决定查询范围"""
"""根据是否有 target_pk 参数决定查询范围,支持智能过滤"""
target_pk = self.kwargs.get('target_pk')
filter_query = self.request.query_params.get('filter', None)
if target_pk:
return self.service.get_subdomains_by_target(target_pk)
return self.service.get_all()
return self.service.get_subdomains_by_target(target_pk, filter_query=filter_query)
return self.service.get_all(filter_query=filter_query)
@action(detail=False, methods=['post'], url_path='bulk-create')
def bulk_create(self, request, **kwargs):
@@ -253,12 +259,18 @@ class WebSiteViewSet(viewsets.ModelViewSet):
支持两种访问方式:
1. 嵌套路由GET /api/targets/{target_pk}/websites/
2. 独立路由GET /api/websites/(全局查询)
支持智能过滤语法filter 参数):
- url="api" URL 模糊匹配
- host="example" 主机名模糊匹配
- title="login" 标题模糊匹配
- status="200,301" 状态码多值匹配
- 多条件空格分隔 AND 关系
"""
serializer_class = WebSiteSerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['host']
filter_backends = [filters.OrderingFilter]
ordering = ['-created_at']
def __init__(self, **kwargs):
@@ -266,11 +278,13 @@ class WebSiteViewSet(viewsets.ModelViewSet):
self.service = WebSiteService()
def get_queryset(self):
"""根据是否有 target_pk 参数决定查询范围"""
"""根据是否有 target_pk 参数决定查询范围,支持智能过滤"""
target_pk = self.kwargs.get('target_pk')
filter_query = self.request.query_params.get('filter', None)
if target_pk:
return self.service.get_websites_by_target(target_pk)
return self.service.get_all()
return self.service.get_websites_by_target(target_pk, filter_query=filter_query)
return self.service.get_all(filter_query=filter_query)
@action(detail=False, methods=['post'], url_path='bulk-create')
def bulk_create(self, request, **kwargs):
@@ -374,12 +388,16 @@ class DirectoryViewSet(viewsets.ModelViewSet):
支持两种访问方式:
1. 嵌套路由GET /api/targets/{target_pk}/directories/
2. 独立路由GET /api/directories/(全局查询)
支持智能过滤语法filter 参数):
- url="admin" URL 模糊匹配
- status="200,301" 状态码多值匹配
- 多条件空格分隔 AND 关系
"""
serializer_class = DirectorySerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['url']
filter_backends = [filters.OrderingFilter]
ordering = ['-created_at']
def __init__(self, **kwargs):
@@ -387,11 +405,13 @@ class DirectoryViewSet(viewsets.ModelViewSet):
self.service = DirectoryService()
def get_queryset(self):
"""根据是否有 target_pk 参数决定查询范围"""
"""根据是否有 target_pk 参数决定查询范围,支持智能过滤"""
target_pk = self.kwargs.get('target_pk')
filter_query = self.request.query_params.get('filter', None)
if target_pk:
return self.service.get_directories_by_target(target_pk)
return self.service.get_all()
return self.service.get_directories_by_target(target_pk, filter_query=filter_query)
return self.service.get_all(filter_query=filter_query)
@action(detail=False, methods=['post'], url_path='bulk-create')
def bulk_create(self, request, **kwargs):
@@ -493,12 +513,18 @@ class EndpointViewSet(viewsets.ModelViewSet):
支持两种访问方式:
1. 嵌套路由GET /api/targets/{target_pk}/endpoints/
2. 独立路由GET /api/endpoints/(全局查询)
支持智能过滤语法filter 参数):
- url="api" URL 模糊匹配
- host="example" 主机名模糊匹配
- title="login" 标题模糊匹配
- status="200,301" 状态码多值匹配
- 多条件空格分隔 AND 关系
"""
serializer_class = EndpointListSerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['host']
filter_backends = [filters.OrderingFilter]
ordering = ['-created_at']
def __init__(self, **kwargs):
@@ -506,11 +532,13 @@ class EndpointViewSet(viewsets.ModelViewSet):
self.service = EndpointService()
def get_queryset(self):
"""根据是否有 target_pk 参数决定查询范围"""
"""根据是否有 target_pk 参数决定查询范围,支持智能过滤"""
target_pk = self.kwargs.get('target_pk')
filter_query = self.request.query_params.get('filter', None)
if target_pk:
return self.service.get_endpoints_by_target(target_pk)
return self.service.get_all()
return self.service.get_endpoints_by_target(target_pk, filter_query=filter_query)
return self.service.get_all(filter_query=filter_query)
@action(detail=False, methods=['post'], url_path='bulk-create')
def bulk_create(self, request, **kwargs):
@@ -618,23 +646,40 @@ class HostPortMappingViewSet(viewsets.ModelViewSet):
返回按 IP 聚合的数据,每个 IP 显示其关联的所有 hosts 和 ports
支持智能过滤语法filter 参数):
- ip="192.168" IP 模糊匹配
- port="80,443" 端口多值匹配
- host="api" 主机名模糊匹配
- 多条件空格分隔 AND 关系
注意:由于返回的是聚合数据(字典列表),不支持 DRF SearchFilter
"""
serializer_class = IPAddressAggregatedSerializer
pagination_class = BasePagination
# 智能过滤字段映射
FILTER_FIELD_MAPPING = {
'ip': 'ip',
'port': 'port',
'host': 'host',
}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.service = HostPortMappingService()
def get_queryset(self):
"""根据是否有 target_pk 参数决定查询范围,返回按 IP 聚合的数据"""
"""根据是否有 target_pk 参数决定查询范围,返回按 IP 聚合的数据
支持智能过滤语法filter 参数)
"""
target_pk = self.kwargs.get('target_pk')
search = self.request.query_params.get('search', None)
filter_query = self.request.query_params.get('filter', None)
if target_pk:
return self.service.get_ip_aggregation_by_target(target_pk, search=search)
return self.service.get_all_ip_aggregation(search=search)
return self.service.get_ip_aggregation_by_target(target_pk, filter_query=filter_query)
return self.service.get_all_ip_aggregation(filter_query=filter_query)
@action(detail=False, methods=['get'], url_path='export')
def export(self, request, **kwargs):
@@ -673,12 +718,18 @@ class VulnerabilityViewSet(viewsets.ModelViewSet):
支持两种访问方式:
1. 嵌套路由GET /api/targets/{target_pk}/vulnerabilities/
2. 独立路由GET /api/vulnerabilities/(全局查询)
支持智能过滤语法filter 参数):
- type="xss" 漏洞类型模糊匹配
- severity="high" 严重程度匹配
- source="nuclei" 来源工具匹配
- url="api" URL 模糊匹配
- 多条件空格分隔 AND 关系
"""
serializer_class = VulnerabilitySerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['vuln_type']
filter_backends = [filters.OrderingFilter]
ordering = ['-created_at']
def __init__(self, **kwargs):
@@ -686,22 +737,29 @@ class VulnerabilityViewSet(viewsets.ModelViewSet):
self.service = VulnerabilityService()
def get_queryset(self):
"""根据是否有 target_pk 参数决定查询范围"""
"""根据是否有 target_pk 参数决定查询范围,支持智能过滤"""
target_pk = self.kwargs.get('target_pk')
filter_query = self.request.query_params.get('filter', None)
if target_pk:
return self.service.get_vulnerabilities_by_target(target_pk)
return self.service.get_all()
return self.service.get_vulnerabilities_by_target(target_pk, filter_query=filter_query)
return self.service.get_all(filter_query=filter_query)
# ==================== 快照 ViewSetScan 嵌套路由) ====================
class SubdomainSnapshotViewSet(viewsets.ModelViewSet):
"""子域名快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/subdomains/"""
"""子域名快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/subdomains/
支持智能过滤语法filter 参数):
- name="api" 子域名模糊匹配
- name=="api.example.com" 精确匹配
- name!="test" 排除匹配
"""
serializer_class = SubdomainSnapshotSerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['name']
filter_backends = [filters.OrderingFilter]
ordering_fields = ['name', 'created_at']
ordering = ['-created_at']
@@ -711,9 +769,11 @@ class SubdomainSnapshotViewSet(viewsets.ModelViewSet):
def get_queryset(self):
scan_pk = self.kwargs.get('scan_pk')
filter_query = self.request.query_params.get('filter', None)
if scan_pk:
return self.service.get_by_scan(scan_pk)
return self.service.get_all()
return self.service.get_by_scan(scan_pk, filter_query=filter_query)
return self.service.get_all(filter_query=filter_query)
@action(detail=False, methods=['get'], url_path='export')
def export(self, request, **kwargs):
@@ -741,12 +801,20 @@ class SubdomainSnapshotViewSet(viewsets.ModelViewSet):
class WebsiteSnapshotViewSet(viewsets.ModelViewSet):
"""网站快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/websites/"""
"""网站快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/websites/
支持智能过滤语法filter 参数):
- url="api" URL 模糊匹配
- host="example" 主机名模糊匹配
- title="login" 标题模糊匹配
- status="200" 状态码匹配
- webserver="nginx" 服务器类型匹配
- tech="php" 技术栈匹配
"""
serializer_class = WebsiteSnapshotSerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['host']
filter_backends = [filters.OrderingFilter]
ordering = ['-created_at']
def __init__(self, **kwargs):
@@ -755,9 +823,11 @@ class WebsiteSnapshotViewSet(viewsets.ModelViewSet):
def get_queryset(self):
scan_pk = self.kwargs.get('scan_pk')
filter_query = self.request.query_params.get('filter', None)
if scan_pk:
return self.service.get_by_scan(scan_pk)
return self.service.get_all()
return self.service.get_by_scan(scan_pk, filter_query=filter_query)
return self.service.get_all(filter_query=filter_query)
@action(detail=False, methods=['get'], url_path='export')
def export(self, request, **kwargs):
@@ -792,12 +862,17 @@ class WebsiteSnapshotViewSet(viewsets.ModelViewSet):
class DirectorySnapshotViewSet(viewsets.ModelViewSet):
"""目录快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/directories/"""
"""目录快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/directories/
支持智能过滤语法filter 参数):
- url="admin" URL 模糊匹配
- status="200" 状态码匹配
- content_type="html" 内容类型匹配
"""
serializer_class = DirectorySnapshotSerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['url']
filter_backends = [filters.OrderingFilter]
ordering = ['-created_at']
def __init__(self, **kwargs):
@@ -806,9 +881,11 @@ class DirectorySnapshotViewSet(viewsets.ModelViewSet):
def get_queryset(self):
scan_pk = self.kwargs.get('scan_pk')
filter_query = self.request.query_params.get('filter', None)
if scan_pk:
return self.service.get_by_scan(scan_pk)
return self.service.get_all()
return self.service.get_by_scan(scan_pk, filter_query=filter_query)
return self.service.get_all(filter_query=filter_query)
@action(detail=False, methods=['get'], url_path='export')
def export(self, request, **kwargs):
@@ -841,12 +918,20 @@ class DirectorySnapshotViewSet(viewsets.ModelViewSet):
class EndpointSnapshotViewSet(viewsets.ModelViewSet):
"""端点快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/endpoints/"""
"""端点快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/endpoints/
支持智能过滤语法filter 参数):
- url="api" URL 模糊匹配
- host="example" 主机名模糊匹配
- title="login" 标题模糊匹配
- status="200" 状态码匹配
- webserver="nginx" 服务器类型匹配
- tech="php" 技术栈匹配
"""
serializer_class = EndpointSnapshotSerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['host']
filter_backends = [filters.OrderingFilter]
ordering = ['-created_at']
def __init__(self, **kwargs):
@@ -855,9 +940,11 @@ class EndpointSnapshotViewSet(viewsets.ModelViewSet):
def get_queryset(self):
scan_pk = self.kwargs.get('scan_pk')
filter_query = self.request.query_params.get('filter', None)
if scan_pk:
return self.service.get_by_scan(scan_pk)
return self.service.get_all()
return self.service.get_by_scan(scan_pk, filter_query=filter_query)
return self.service.get_all(filter_query=filter_query)
@action(detail=False, methods=['get'], url_path='export')
def export(self, request, **kwargs):
@@ -895,7 +982,12 @@ class EndpointSnapshotViewSet(viewsets.ModelViewSet):
class HostPortMappingSnapshotViewSet(viewsets.ModelViewSet):
"""主机端口映射快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/ip-addresses/
注意:由于返回的是聚合数据(字典列表),不支持 DRF SearchFilter
支持智能过滤语法filter 参数):
- ip="192.168" IP 模糊匹配
- port="80" 端口匹配
- host="api" 主机名模糊匹配
注意:由于返回的是聚合数据(字典列表),过滤在 Service 层处理
"""
serializer_class = IPAddressAggregatedSerializer
@@ -907,10 +999,11 @@ class HostPortMappingSnapshotViewSet(viewsets.ModelViewSet):
def get_queryset(self):
scan_pk = self.kwargs.get('scan_pk')
search = self.request.query_params.get('search', None)
filter_query = self.request.query_params.get('filter', None)
if scan_pk:
return self.service.get_ip_aggregation_by_scan(scan_pk, search=search)
return self.service.get_all_ip_aggregation(search=search)
return self.service.get_ip_aggregation_by_scan(scan_pk, filter_query=filter_query)
return self.service.get_all_ip_aggregation(filter_query=filter_query)
@action(detail=False, methods=['get'], url_path='export')
def export(self, request, **kwargs):
@@ -944,12 +1037,18 @@ class HostPortMappingSnapshotViewSet(viewsets.ModelViewSet):
class VulnerabilitySnapshotViewSet(viewsets.ModelViewSet):
"""漏洞快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/vulnerabilities/"""
"""漏洞快照 ViewSet - 嵌套路由GET /api/scans/{scan_pk}/vulnerabilities/
支持智能过滤语法filter 参数):
- type="xss" 漏洞类型模糊匹配
- url="api" URL 模糊匹配
- severity="high" 严重程度匹配
- source="nuclei" 来源工具匹配
"""
serializer_class = VulnerabilitySnapshotSerializer
pagination_class = BasePagination
filter_backends = [filters.SearchFilter, filters.OrderingFilter]
search_fields = ['vuln_type']
filter_backends = [filters.OrderingFilter]
ordering = ['-created_at']
def __init__(self, **kwargs):
@@ -958,6 +1057,8 @@ class VulnerabilitySnapshotViewSet(viewsets.ModelViewSet):
def get_queryset(self):
scan_pk = self.kwargs.get('scan_pk')
filter_query = self.request.query_params.get('filter', None)
if scan_pk:
return self.service.get_by_scan(scan_pk)
return self.service.get_all()
return self.service.get_by_scan(scan_pk, filter_query=filter_query)
return self.service.get_all(filter_query=filter_query)

View File

@@ -3,8 +3,13 @@
提供系统级别的公共服务,包括:
- SystemLogService: 系统日志读取服务
注意FilterService 已移至 apps.common.utils.filter_utils
推荐使用: from apps.common.utils.filter_utils import apply_filters
"""
from .system_log_service import SystemLogService
__all__ = ['SystemLogService']
__all__ = [
'SystemLogService',
]

View File

@@ -0,0 +1,260 @@
"""智能过滤工具 - 通用查询语法解析和 Django ORM 查询构建
支持的语法:
- field="value" 模糊匹配(包含)
- field=="value" 精确匹配
- field!="value" 不等于
逻辑运算符:
- AND: && 或 and 或 空格(默认)
- OR: || 或 or
示例:
type="xss" || type="sqli" # OR
type="xss" or type="sqli" # OR等价
severity="high" && source="nuclei" # AND
severity="high" source="nuclei" # AND空格默认为 AND
severity="high" and source="nuclei" # AND等价
使用示例:
from apps.common.utils.filter_utils import apply_filters
field_mapping = {'ip': 'ip', 'port': 'port', 'host': 'host'}
queryset = apply_filters(queryset, 'ip="192" || port="80"', field_mapping)
"""
import re
import logging
from dataclasses import dataclass
from typing import List, Dict, Optional, Union
from enum import Enum
from django.db.models import QuerySet, Q
logger = logging.getLogger(__name__)
class LogicalOp(Enum):
"""逻辑运算符"""
AND = 'AND'
OR = 'OR'
@dataclass
class ParsedFilter:
"""解析后的过滤条件"""
field: str # 字段名
operator: str # 操作符: '=', '==', '!='
value: str # 原始值
@dataclass
class FilterGroup:
"""过滤条件组(带逻辑运算符)"""
filter: ParsedFilter
logical_op: LogicalOp # 与前一个条件的逻辑关系
class QueryParser:
"""查询语法解析器
支持 ||/or (OR) 和 &&/and/空格 (AND) 逻辑运算符
"""
# 正则匹配: field="value", field=="value", field!="value"
FILTER_PATTERN = re.compile(r'(\w+)(==|!=|=)"([^"]*)"')
# 逻辑运算符模式(带空格)
OR_PATTERN = re.compile(r'\s*(\|\||(?<![a-zA-Z])or(?![a-zA-Z]))\s*', re.IGNORECASE)
AND_PATTERN = re.compile(r'\s*(&&|(?<![a-zA-Z])and(?![a-zA-Z]))\s*', re.IGNORECASE)
@classmethod
def parse(cls, query_string: str) -> List[FilterGroup]:
"""解析查询语法字符串
Args:
query_string: 查询语法字符串
Returns:
解析后的过滤条件组列表
Examples:
>>> QueryParser.parse('type="xss" || type="sqli"')
[FilterGroup(filter=..., logical_op=AND), # 第一个默认 AND
FilterGroup(filter=..., logical_op=OR)]
"""
if not query_string or not query_string.strip():
return []
# 标准化逻辑运算符
# 先处理 || 和 or -> __OR__
normalized = cls.OR_PATTERN.sub(' __OR__ ', query_string)
# 再处理 && 和 and -> __AND__
normalized = cls.AND_PATTERN.sub(' __AND__ ', normalized)
# 分词:按空格分割,保留逻辑运算符标记
tokens = normalized.split()
groups = []
pending_op = LogicalOp.AND # 默认 AND
for token in tokens:
if token == '__OR__':
pending_op = LogicalOp.OR
elif token == '__AND__':
pending_op = LogicalOp.AND
else:
# 尝试解析为过滤条件
match = cls.FILTER_PATTERN.match(token)
if match:
field, operator, value = match.groups()
groups.append(FilterGroup(
filter=ParsedFilter(
field=field.lower(),
operator=operator,
value=value
),
logical_op=pending_op if groups else LogicalOp.AND # 第一个条件默认 AND
))
pending_op = LogicalOp.AND # 重置为默认 AND
return groups
class QueryBuilder:
"""Django ORM 查询构建器
将解析后的过滤条件转换为 Django ORM 查询,支持 AND/OR 逻辑
"""
@classmethod
def build_query(
cls,
queryset: QuerySet,
filter_groups: List[FilterGroup],
field_mapping: Dict[str, str]
) -> QuerySet:
"""构建 Django ORM 查询
Args:
queryset: Django QuerySet
filter_groups: 解析后的过滤条件组列表
field_mapping: 字段映射
Returns:
过滤后的 QuerySet
"""
if not filter_groups:
return queryset
# 构建 Q 对象
combined_q = None
for group in filter_groups:
f = group.filter
# 字段映射
db_field = field_mapping.get(f.field)
if not db_field:
logger.debug(f"忽略未知字段: {f.field}")
continue
# 构建单个条件的 Q 对象
q = cls._build_single_q(db_field, f.operator, f.value)
if q is None:
continue
# 组合 Q 对象
if combined_q is None:
combined_q = q
elif group.logical_op == LogicalOp.OR:
combined_q = combined_q | q
else: # AND
combined_q = combined_q & q
if combined_q is not None:
return queryset.filter(combined_q)
return queryset
@classmethod
def _build_single_q(cls, field: str, operator: str, value: str) -> Optional[Q]:
"""构建单个条件的 Q 对象"""
if operator == '!=':
return cls._build_not_equal_q(field, value)
elif operator == '==':
return cls._build_exact_q(field, value)
else: # '='
return cls._build_fuzzy_q(field, value)
@classmethod
def _try_convert_to_int(cls, value: str) -> Optional[int]:
"""尝试将值转换为整数"""
try:
return int(value.strip())
except (ValueError, TypeError):
return None
@classmethod
def _build_fuzzy_q(cls, field: str, value: str) -> Q:
"""模糊匹配: 包含"""
return Q(**{f'{field}__icontains': value})
@classmethod
def _build_exact_q(cls, field: str, value: str) -> Q:
"""精确匹配"""
int_val = cls._try_convert_to_int(value)
if int_val is not None:
return Q(**{f'{field}__exact': int_val})
return Q(**{f'{field}__exact': value})
@classmethod
def _build_not_equal_q(cls, field: str, value: str) -> Q:
"""不等于"""
int_val = cls._try_convert_to_int(value)
if int_val is not None:
return ~Q(**{f'{field}__exact': int_val})
return ~Q(**{f'{field}__exact': value})
def apply_filters(
queryset: QuerySet,
query_string: str,
field_mapping: Dict[str, str]
) -> QuerySet:
"""应用过滤条件到 QuerySet
Args:
queryset: Django QuerySet
query_string: 查询语法字符串
field_mapping: 字段映射
Returns:
过滤后的 QuerySet
Examples:
# OR 查询
apply_filters(qs, 'type="xss" || type="sqli"', mapping)
apply_filters(qs, 'type="xss" or type="sqli"', mapping)
# AND 查询
apply_filters(qs, 'severity="high" && source="nuclei"', mapping)
apply_filters(qs, 'severity="high" source="nuclei"', mapping)
# 混合查询
apply_filters(qs, 'type="xss" || type="sqli" && severity="high"', mapping)
"""
if not query_string or not query_string.strip():
return queryset
try:
filter_groups = QueryParser.parse(query_string)
if not filter_groups:
logger.debug(f"未解析到有效过滤条件: {query_string}")
return queryset
logger.debug(f"解析过滤条件: {filter_groups}")
return QueryBuilder.build_query(queryset, filter_groups, field_mapping)
except Exception as e:
logger.warning(f"过滤解析错误: {e}, query: {query_string}")
return queryset # 静默降级

View File

@@ -8,7 +8,7 @@ export default function ScanHistoryVulnerabilitiesPage() {
const { id } = useParams<{ id: string }>()
return (
<div className="relative flex flex-col gap-4 overflow-auto px-4 lg:px-6">
<div className="px-4 lg:px-6">
<VulnerabilitiesDetailView scanId={Number(id)} />
</div>
)

View File

@@ -12,7 +12,7 @@ export default function TargetVulnerabilitiesPage() {
const { id } = useParams<{ id: string }>()
return (
<div className="relative flex flex-col gap-4 overflow-auto px-4 lg:px-6">
<div className="px-4 lg:px-6">
<VulnerabilitiesDetailView targetId={parseInt(id)} />
</div>
)

View File

@@ -0,0 +1,330 @@
"use client"
import * as React from "react"
import { IconSearch } from "@tabler/icons-react"
import {
Command,
CommandEmpty,
CommandGroup,
CommandItem,
CommandList,
} from "@/components/ui/command"
import {
Popover,
PopoverContent,
PopoverAnchor,
} from "@/components/ui/popover"
import { Button } from "@/components/ui/button"
import { Badge } from "@/components/ui/badge"
import { Input } from "@/components/ui/input"
// 可用的筛选字段定义
export interface FilterField {
key: string
label: string
description: string
}
// 预定义的字段配置,各页面可以选择使用
export const PREDEFINED_FIELDS: Record<string, FilterField> = {
ip: { key: "ip", label: "IP", description: "IP address" },
port: { key: "port", label: "Port", description: "Port number" },
host: { key: "host", label: "Host", description: "Hostname" },
domain: { key: "domain", label: "Domain", description: "Domain name" },
url: { key: "url", label: "URL", description: "Full URL" },
status: { key: "status", label: "Status", description: "HTTP status code" },
title: { key: "title", label: "Title", description: "Page title" },
source: { key: "source", label: "Source", description: "Data source" },
path: { key: "path", label: "Path", description: "URL path" },
severity: { key: "severity", label: "Severity", description: "Vulnerability severity" },
name: { key: "name", label: "Name", description: "Name" },
type: { key: "type", label: "Type", description: "Type" },
}
// 默认字段IP Addresses 页面)
const DEFAULT_FIELDS: FilterField[] = [
PREDEFINED_FIELDS.ip,
PREDEFINED_FIELDS.port,
PREDEFINED_FIELDS.host,
]
// 解析筛选表达式 (FOFA 风格)
interface ParsedFilter {
field: string
operator: string
value: string
raw: string
}
function parseFilterExpression(input: string): ParsedFilter[] {
const filters: ParsedFilter[] = []
// 匹配 FOFA 风格: field="value", field=="value", field!="value"
// == 精确匹配, = 模糊匹配, != 不等于
// 支持逗号分隔多值: port="80,443,8080"
const regex = /(\w+)(==|!=|=)"([^"]+)"/g
let match
while ((match = regex.exec(input)) !== null) {
const [raw, field, operator, value] = match
filters.push({ field, operator, value, raw })
}
return filters
}
interface SmartFilterInputProps {
/** 可用的筛选字段,不传则使用默认字段 */
fields?: FilterField[]
/** 组合示例(使用逻辑运算符拼接的完整示例) */
examples?: string[]
placeholder?: string
/** 受控模式:当前过滤值 */
value?: string
onSearch?: (filters: ParsedFilter[], rawQuery: string) => void
className?: string
}
export function SmartFilterInput({
fields = DEFAULT_FIELDS,
examples,
placeholder,
value,
onSearch,
className,
}: SmartFilterInputProps) {
const [open, setOpen] = React.useState(false)
const [inputValue, setInputValue] = React.useState(value ?? "")
const inputRef = React.useRef<HTMLInputElement>(null)
const listRef = React.useRef<HTMLDivElement>(null)
const savedScrollTop = React.useRef<number | null>(null)
const hasInitialized = React.useRef(false)
// 同步外部 value 变化
React.useEffect(() => {
if (value !== undefined) {
setInputValue(value)
}
}, [value])
// 当 Popover 打开时,恢复滚动位置(首次打开时滚动到顶部)
React.useEffect(() => {
if (open) {
const restoreScroll = () => {
if (listRef.current) {
if (!hasInitialized.current) {
// 首次打开,滚动到顶部
listRef.current.scrollTop = 0
hasInitialized.current = true
} else if (savedScrollTop.current !== null) {
// 之后恢复保存的滚动位置
listRef.current.scrollTop = savedScrollTop.current
}
}
}
// 立即执行一次
restoreScroll()
// 延迟执行确保 Popover 动画完成
const timer = setTimeout(restoreScroll, 50)
return () => clearTimeout(timer)
} else {
// Popover 关闭时保存滚动位置
if (listRef.current) {
savedScrollTop.current = listRef.current.scrollTop
}
}
}, [open])
// 生成默认 placeholder使用第一个示例或字段组合
const defaultPlaceholder = React.useMemo(() => {
if (examples && examples.length > 0) {
return examples[0]
}
// 使用字段生成简单示例
return fields.slice(0, 2).map(f => `${f.key}="..."`).join(" && ")
}, [fields, examples])
// 解析当前输入
const parsedFilters = parseFilterExpression(inputValue)
// 获取当前正在输入的词
const getCurrentWord = () => {
const words = inputValue.split(/\s+/)
return words[words.length - 1] || ""
}
const currentWord = getCurrentWord()
// 判断是否显示字段建议 (FOFA 风格用 = 而不是 :)
const showFieldSuggestions = !currentWord.includes("=")
// 处理选择建议 (FOFA 风格: field=""),然后关闭弹窗
const handleSelectSuggestion = (suggestion: string) => {
const words = inputValue.split(/\s+/)
words[words.length - 1] = suggestion
const newValue = words.join(" ")
setInputValue(newValue)
setOpen(false)
inputRef.current?.blur()
}
// 处理搜索
const handleSearch = () => {
onSearch?.(parsedFilters, inputValue)
setOpen(false)
}
// 处理键盘事件
const handleKeyDown = (e: React.KeyboardEvent) => {
if (e.key === "Enter" && !e.shiftKey) {
e.preventDefault()
handleSearch()
}
if (e.key === "Escape") {
setOpen(false)
}
}
// 附加示例到输入框(而非覆盖),然后关闭弹窗
const handleAppendExample = (example: string) => {
const trimmed = inputValue.trim()
const newValue = trimmed ? `${trimmed} ${example}` : example
setInputValue(newValue)
setOpen(false)
inputRef.current?.blur()
}
return (
<div className={className}>
<Popover open={open} onOpenChange={setOpen} modal={false}>
<PopoverAnchor asChild>
<div className="flex items-center gap-2">
<Input
ref={inputRef}
type="text"
value={inputValue}
onChange={(e) => {
setInputValue(e.target.value)
if (!open) setOpen(true)
}}
onFocus={() => setOpen(true)}
onBlur={(e) => {
// 如果焦点转移到 Popover 内部或输入框自身,不关闭
const relatedTarget = e.relatedTarget as HTMLElement | null
if (relatedTarget?.closest('[data-radix-popper-content-wrapper]')) {
return
}
// 延迟关闭,让 CommandItem 的 onSelect 先执行
setTimeout(() => setOpen(false), 150)
}}
onKeyDown={handleKeyDown}
placeholder={placeholder || defaultPlaceholder}
className="h-8 w-full"
/>
<Button variant="outline" size="sm" onClick={handleSearch}>
<IconSearch className="h-4 w-4" />
</Button>
</div>
</PopoverAnchor>
<PopoverContent
className="w-[var(--radix-popover-trigger-width)] p-0"
align="start"
side="bottom"
sideOffset={4}
collisionPadding={16}
onOpenAutoFocus={(e) => e.preventDefault()}
onCloseAutoFocus={(e) => e.preventDefault()}
onPointerDownOutside={(e) => {
// 如果点击的是输入框,不关闭弹窗
if (inputRef.current?.contains(e.target as Node)) {
e.preventDefault()
}
}}
>
<Command>
<CommandList ref={listRef}>
{/* 已解析的筛选条件预览 */}
{parsedFilters.length > 0 && (
<CommandGroup heading="Active filters">
<div className="flex flex-wrap gap-1 px-2 py-1">
{parsedFilters.map((filter, i) => (
<Badge key={i} variant="secondary" className="text-xs font-mono">
{filter.raw}
</Badge>
))}
</div>
</CommandGroup>
)}
{/* 可用字段 */}
{showFieldSuggestions && (
<CommandGroup heading="Available fields">
<div className="flex flex-wrap gap-1 px-2 py-1">
{fields.filter(
(f) => !currentWord || f.key.startsWith(currentWord.toLowerCase())
).map((field) => (
<Badge
key={field.key}
variant="outline"
className="text-xs font-mono cursor-pointer hover:bg-accent"
onClick={() => handleSelectSuggestion(`${field.key}="`)}
>
{field.key}
</Badge>
))}
</div>
</CommandGroup>
)}
{/* 语法帮助 */}
<CommandGroup heading="Syntax">
<div className="px-2 py-1.5 text-xs text-muted-foreground space-y-2">
{/* 匹配操作符 */}
<div className="space-y-1">
<div className="font-medium text-foreground/80">Operators</div>
<div className="grid grid-cols-[auto_1fr] gap-x-3 gap-y-0.5">
<code className="bg-muted px-1 rounded">=</code>
<span>contains (fuzzy)</span>
<code className="bg-muted px-1 rounded">==</code>
<span>exact match</span>
<code className="bg-muted px-1 rounded">!=</code>
<span>not equals</span>
</div>
</div>
{/* 逻辑运算符 */}
<div className="space-y-1 pt-1 border-t border-muted">
<div className="font-medium text-foreground/80">Logic</div>
<div className="grid grid-cols-[auto_1fr] gap-x-3 gap-y-0.5">
<span><code className="bg-muted px-1 rounded">||</code> <code className="bg-muted px-1 rounded">or</code></span>
<span>match any</span>
<span><code className="bg-muted px-1 rounded">&&</code> <code className="bg-muted px-1 rounded">and</code> <code className="bg-muted px-1 rounded">space</code></span>
<span>match all</span>
</div>
</div>
</div>
</CommandGroup>
{/* 示例 */}
{examples && examples.length > 0 && (
<CommandGroup heading="Examples">
{examples.map((example, i) => (
<CommandItem
key={i}
value={example}
onSelect={() => handleAppendExample(example)}
>
<code className="text-xs">{example}</code>
</CommandItem>
))}
</CommandGroup>
)}
<CommandEmpty>Type to filter...</CommandEmpty>
</CommandList>
</Command>
</PopoverContent>
</Popover>
</div>
)
}
export { parseFilterExpression, DEFAULT_FIELDS, type ParsedFilter }

View File

@@ -25,8 +25,6 @@ import {
IconLayoutColumns,
IconTrash,
IconDownload,
IconSearch,
IconLoader2,
IconPlus,
} from "@tabler/icons-react"
@@ -40,7 +38,6 @@ import {
DropdownMenuItem,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu"
import { Input } from "@/components/ui/input"
import { Label } from "@/components/ui/label"
import {
Select,
@@ -57,17 +54,30 @@ import {
TableHeader,
TableRow,
} from "@/components/ui/table"
import { SmartFilterInput, type FilterField } from "@/components/common/smart-filter-input"
import type { Directory } from "@/types/directory.types"
import type { PaginationInfo } from "@/types/common.types"
// 目录页面的过滤字段配置
const DIRECTORY_FILTER_FIELDS: FilterField[] = [
{ key: "url", label: "URL", description: "Directory URL" },
{ key: "status", label: "Status", description: "HTTP status code" },
]
// 目录页面的示例
const DIRECTORY_FILTER_EXAMPLES = [
'url="/admin" && status="200"',
'url="/api/*" || url="/config/*"',
'status="200" && url!="/index.html"',
]
interface DirectoriesDataTableProps {
data: Directory[]
columns: ColumnDef<Directory>[]
searchPlaceholder?: string
searchColumn?: string
searchValue?: string
onSearch?: (value: string) => void
// 智能过滤
filterValue?: string
onFilterChange?: (value: string) => void
isSearching?: boolean
pagination?: { pageIndex: number; pageSize: number }
setPagination?: React.Dispatch<React.SetStateAction<{ pageIndex: number; pageSize: number }>>
@@ -84,10 +94,8 @@ interface DirectoriesDataTableProps {
export function DirectoriesDataTable({
data = [],
columns,
searchPlaceholder = "搜索URL...",
searchColumn = "url",
searchValue,
onSearch,
filterValue,
onFilterChange,
isSearching = false,
pagination,
setPagination,
@@ -109,24 +117,10 @@ export function DirectoriesDataTable({
pageSize: 10,
})
// 本地搜索输入状态(只在回车或点击按钮时触发搜索
const [localSearchValue, setLocalSearchValue] = React.useState(searchValue ?? "")
React.useEffect(() => {
setLocalSearchValue(searchValue ?? "")
}, [searchValue])
const handleSearchSubmit = () => {
if (onSearch) {
onSearch(localSearchValue)
} else {
table.getColumn(searchColumn)?.setFilterValue(localSearchValue)
}
}
const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
if (e.key === 'Enter') {
handleSearchSubmit()
// 处理智能过滤搜索
const handleSmartSearch = (_filters: any[], rawQuery: string) => {
if (onFilterChange) {
onFilterChange(rawQuery)
}
}
@@ -190,22 +184,14 @@ export function DirectoriesDataTable({
<div className="flex flex-col gap-4 @container/toolbar">
{/* 第一行:搜索和列控制 */}
<div className="flex flex-col gap-4 @xl/toolbar:flex-row @xl/toolbar:items-center @xl/toolbar:justify-between">
<div className="flex flex-1 items-center gap-2">
<Input
placeholder={searchPlaceholder}
value={localSearchValue}
onChange={(e) => setLocalSearchValue(e.target.value)}
onKeyDown={handleKeyDown}
className="h-8 w-full @xl/toolbar:max-w-sm"
/>
<Button variant="outline" size="sm" onClick={handleSearchSubmit} disabled={isSearching}>
{isSearching ? (
<IconLoader2 className="h-4 w-4 animate-spin" />
) : (
<IconSearch className="h-4 w-4" />
)}
</Button>
</div>
{/* 智能过滤搜索框 */}
<SmartFilterInput
fields={DIRECTORY_FILTER_FIELDS}
examples={DIRECTORY_FILTER_EXAMPLES}
value={filterValue}
onSearch={handleSmartSearch}
className="flex-1 max-w-xl"
/>
<div className="flex items-center gap-2">
{/* 列可见性控制 */}

View File

@@ -28,15 +28,15 @@ export function DirectoriesView({
const [selectedDirectories, setSelectedDirectories] = useState<Directory[]>([])
const [bulkAddDialogOpen, setBulkAddDialogOpen] = useState(false)
const [searchQuery, setSearchQuery] = useState("")
const [filterQuery, setFilterQuery] = useState("")
const [isSearching, setIsSearching] = useState(false)
// 获取目标信息(用于 URL 匹配校验)
const { data: target } = useTarget(targetId || 0, { enabled: !!targetId })
const handleSearchChange = (value: string) => {
const handleFilterChange = (value: string) => {
setIsSearching(true)
setSearchQuery(value)
setFilterQuery(value)
setPagination((prev) => ({ ...prev, pageIndex: 0 }))
}
@@ -45,7 +45,7 @@ export function DirectoriesView({
{
page: pagination.pageIndex + 1,
pageSize: pagination.pageSize,
search: searchQuery || undefined,
filter: filterQuery || undefined,
},
{ enabled: !!targetId }
)
@@ -55,7 +55,7 @@ export function DirectoriesView({
{
page: pagination.pageIndex + 1,
pageSize: pagination.pageSize,
search: searchQuery || undefined,
filter: filterQuery || undefined,
},
{ enabled: !!scanId }
)
@@ -242,10 +242,8 @@ export function DirectoriesView({
<DirectoriesDataTable
data={directories}
columns={columns}
searchPlaceholder="搜索URL..."
searchColumn="url"
searchValue={searchQuery}
onSearch={handleSearchChange}
filterValue={filterQuery}
onFilterChange={handleFilterChange}
isSearching={isSearching}
pagination={pagination}
setPagination={setPagination}

View File

@@ -25,8 +25,6 @@ import {
IconLayoutColumns,
IconPlus,
IconDownload,
IconSearch,
IconLoader2,
} from "@tabler/icons-react"
import { Button } from "@/components/ui/button"
import {
@@ -38,7 +36,6 @@ import {
DropdownMenuItem,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu"
import { Input } from "@/components/ui/input"
import { Label } from "@/components/ui/label"
import {
Select,
@@ -55,15 +52,30 @@ import {
TableHeader,
TableRow,
} from "@/components/ui/table"
import { SmartFilterInput, type FilterField } from "@/components/common/smart-filter-input"
import type { Endpoint } from "@/types/endpoint.types"
// 端点页面的过滤字段配置
const ENDPOINT_FILTER_FIELDS: FilterField[] = [
{ key: "url", label: "URL", description: "Endpoint URL" },
{ key: "host", label: "Host", description: "Hostname" },
{ key: "title", label: "Title", description: "Page title" },
{ key: "status", label: "Status", description: "HTTP status code" },
]
// 端点页面的示例
const ENDPOINT_FILTER_EXAMPLES = [
'url="/api/*" && status="200"',
'host="api.example.com" || host="admin.example.com"',
'title="Dashboard" && status!="404"',
]
interface EndpointsDataTableProps<TData extends { id: number | string }, TValue> {
columns: ColumnDef<TData, TValue>[]
data: TData[]
searchPlaceholder?: string
searchColumn?: string
searchValue?: string
onSearch?: (value: string) => void
// 智能过滤
filterValue?: string
onFilterChange?: (value: string) => void
isSearching?: boolean
onAddNew?: () => void
addButtonText?: string
@@ -80,10 +92,8 @@ interface EndpointsDataTableProps<TData extends { id: number | string }, TValue>
export function EndpointsDataTable<TData extends { id: number | string }, TValue>({
columns,
data,
searchPlaceholder = "搜索主机名...",
searchColumn = "url",
searchValue,
onSearch,
filterValue,
onFilterChange,
isSearching = false,
onAddNew,
addButtonText = "Add",
@@ -106,24 +116,10 @@ export function EndpointsDataTable<TData extends { id: number | string }, TValue
pageSize: 10,
})
// 本地搜索输入状态(只在回车或点击按钮时触发搜索
const [localSearchValue, setLocalSearchValue] = React.useState(searchValue ?? "")
React.useEffect(() => {
setLocalSearchValue(searchValue ?? "")
}, [searchValue])
const handleSearchSubmit = () => {
if (onSearch) {
onSearch(localSearchValue)
} else {
table.getColumn(searchColumn)?.setFilterValue(localSearchValue)
}
}
const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
if (e.key === 'Enter') {
handleSearchSubmit()
// 处理智能过滤搜索
const handleSmartSearch = (_filters: any[], rawQuery: string) => {
if (onFilterChange) {
onFilterChange(rawQuery)
}
}
@@ -183,22 +179,14 @@ export function EndpointsDataTable<TData extends { id: number | string }, TValue
return (
<div className="w-full space-y-4">
<div className="flex items-center justify-between">
<div className="flex items-center space-x-2">
<Input
placeholder={searchPlaceholder}
value={localSearchValue}
onChange={(e) => setLocalSearchValue(e.target.value)}
onKeyDown={handleKeyDown}
className="h-8 max-w-sm"
/>
<Button variant="outline" size="sm" onClick={handleSearchSubmit} disabled={isSearching}>
{isSearching ? (
<IconLoader2 className="h-4 w-4 animate-spin" />
) : (
<IconSearch className="h-4 w-4" />
)}
</Button>
</div>
{/* 智能过滤搜索框 */}
<SmartFilterInput
fields={ENDPOINT_FILTER_FIELDS}
examples={ENDPOINT_FILTER_EXAMPLES}
value={filterValue}
onSearch={handleSmartSearch}
className="flex-1 max-w-xl"
/>
<div className="flex items-center space-x-2">
<DropdownMenu>

View File

@@ -46,15 +46,15 @@ export function EndpointsDetailView({
pageSize: 10
})
const [searchQuery, setSearchQuery] = useState("")
const [filterQuery, setFilterQuery] = useState("")
const [isSearching, setIsSearching] = useState(false)
// 获取目标信息(用于 URL 匹配校验)
const { data: target } = useTarget(targetId || 0, { enabled: !!targetId })
const handleSearchChange = (value: string) => {
const handleFilterChange = (value: string) => {
setIsSearching(true)
setSearchQuery(value)
setFilterQuery(value)
setPagination((prev) => ({ ...prev, pageIndex: 0 }))
}
@@ -65,14 +65,18 @@ export function EndpointsDetailView({
const targetEndpointsQuery = useTargetEndpoints(targetId || 0, {
page: pagination.pageIndex + 1,
pageSize: pagination.pageSize,
search: searchQuery || undefined,
filter: filterQuery || undefined,
}, { enabled: !!targetId })
const scanEndpointsQuery = useScanEndpoints(scanId || 0, {
page: pagination.pageIndex + 1,
pageSize: pagination.pageSize,
search: searchQuery || undefined,
}, { enabled: !!scanId })
const scanEndpointsQuery = useScanEndpoints(
scanId || 0,
{
page: pagination.pageIndex + 1,
pageSize: pagination.pageSize,
},
{ enabled: !!scanId },
filterQuery || undefined,
)
const {
data,
@@ -279,9 +283,8 @@ export function EndpointsDetailView({
<EndpointsDataTable
data={data?.endpoints || []}
columns={endpointColumns}
searchPlaceholder="搜索主机名..."
searchValue={searchQuery}
onSearch={handleSearchChange}
filterValue={filterQuery}
onFilterChange={handleFilterChange}
isSearching={isSearching}
pagination={pagination}
onPaginationChange={handlePaginationChange}

View File

@@ -23,8 +23,6 @@ import {
IconLayoutColumns,
IconTrash,
IconDownload,
IconSearch,
IconLoader2,
} from "@tabler/icons-react"
import { Button } from "@/components/ui/button"
import {
@@ -36,7 +34,6 @@ import {
DropdownMenuItem,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu"
import { Input } from "@/components/ui/input"
import { Label } from "@/components/ui/label"
import {
Select,
@@ -53,36 +50,45 @@ import {
TableHeader,
TableRow,
} from "@/components/ui/table"
import { SmartFilterInput, PREDEFINED_FIELDS, type ParsedFilter, type FilterField } from "@/components/common/smart-filter-input"
import type { IPAddress } from "@/types/ip-address.types"
import type { PaginationInfo } from "@/types/common.types"
// IP 地址页面的过滤字段配置
const IP_ADDRESS_FILTER_FIELDS: FilterField[] = [
PREDEFINED_FIELDS.ip,
PREDEFINED_FIELDS.port,
PREDEFINED_FIELDS.host,
]
// IP 地址页面的示例
const IP_ADDRESS_FILTER_EXAMPLES = [
'ip="192.168.1.*" && port="80"',
'port="443" || port="8443"',
'host="api.example.com" && port!="22"',
]
interface IPAddressesDataTableProps {
data: IPAddress[]
columns: ColumnDef<IPAddress>[]
searchPlaceholder?: string
searchColumn?: string
searchValue?: string
onSearch?: (value: string) => void
isSearching?: boolean
filterValue?: string
onFilterChange?: (value: string) => void
pagination?: { pageIndex: number; pageSize: number }
setPagination?: React.Dispatch<React.SetStateAction<{ pageIndex: number; pageSize: number }>>
paginationInfo?: PaginationInfo
onPaginationChange?: (pagination: { pageIndex: number; pageSize: number }) => void
onBulkDelete?: () => void // 批量删除回调函数
onSelectionChange?: (selectedRows: IPAddress[]) => void // 选中行变化回调
// 下载回调函数
onDownloadAll?: () => void // 下载所有 IP 地址
onDownloadSelected?: () => void // 下载选中的 IP 地址
onBulkDelete?: () => void
onSelectionChange?: (selectedRows: IPAddress[]) => void
onDownloadAll?: () => void
onDownloadSelected?: () => void
}
export function IPAddressesDataTable({
data = [],
columns,
searchPlaceholder = "搜索 IP 地址...",
searchColumn = "ip",
searchValue,
onSearch,
isSearching = false,
filterValue = "",
onFilterChange,
pagination,
setPagination,
paginationInfo,
@@ -102,25 +108,9 @@ export function IPAddressesDataTable({
pageSize: 10,
})
// 本地搜索输入状态(只在回车或点击按钮时触发搜索)
const [localSearchValue, setLocalSearchValue] = React.useState(searchValue ?? "")
React.useEffect(() => {
setLocalSearchValue(searchValue ?? "")
}, [searchValue])
const handleSearchSubmit = () => {
if (onSearch) {
onSearch(localSearchValue)
} else {
table.getColumn(searchColumn)?.setFilterValue(localSearchValue)
}
}
const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
if (e.key === 'Enter') {
handleSearchSubmit()
}
// 智能搜索处理
const handleSmartSearch = (filters: ParsedFilter[], rawQuery: string) => {
onFilterChange?.(rawQuery)
}
const useServerPagination = !!paginationInfo && !!pagination && !!setPagination
@@ -141,7 +131,7 @@ export function IPAddressesDataTable({
enableColumnResizing: true,
columnResizeMode: 'onChange',
onColumnSizingChange: setColumnSizing,
getRowId: (row) => row.ip, // IP 地址本身就是唯一标识
getRowId: (row) => row.ip,
enableRowSelection: true,
onRowSelectionChange: setRowSelection,
onSortingChange: setSorting,
@@ -163,10 +153,6 @@ export function IPAddressesDataTable({
: Math.ceil(data.length / tablePagination.pageSize) || 1,
})
const totalItems = useServerPagination
? paginationInfo?.total ?? data.length
: table.getFilteredRowModel().rows.length
// 处理选中行变化
React.useEffect(() => {
if (onSelectionChange) {
@@ -177,25 +163,16 @@ export function IPAddressesDataTable({
return (
<div className="w-full space-y-4">
{/* 工具栏 */}
<div className="flex items-center justify-between">
{/* 搜索框 */}
<div className="flex items-center space-x-2">
<Input
placeholder={searchPlaceholder}
value={localSearchValue}
onChange={(e) => setLocalSearchValue(e.target.value)}
onKeyDown={handleKeyDown}
className="h-8 max-w-sm"
/>
<Button variant="outline" size="sm" onClick={handleSearchSubmit} disabled={isSearching}>
{isSearching ? (
<IconLoader2 className="h-4 w-4 animate-spin" />
) : (
<IconSearch className="h-4 w-4" />
)}
</Button>
</div>
{/* 工具栏 - 方案 D智能搜索框 */}
<div className="flex items-center justify-between gap-4 flex-wrap">
{/* 左侧:智能搜索框 */}
<SmartFilterInput
fields={IP_ADDRESS_FILTER_FIELDS}
examples={IP_ADDRESS_FILTER_EXAMPLES}
value={filterValue}
onSearch={handleSmartSearch}
className="flex-1 max-w-xl"
/>
{/* 右侧操作按钮 */}
<div className="flex items-center space-x-2">

View File

@@ -1,6 +1,6 @@
"use client"
import React, { useCallback, useMemo, useState, useEffect } from "react"
import React, { useCallback, useMemo, useState } from "react"
import { AlertTriangle } from "lucide-react"
import { IPAddressesDataTable } from "./ip-addresses-data-table"
import { createIPAddressColumns } from "./ip-addresses-columns"
@@ -23,13 +23,10 @@ export function IPAddressesView({
pageSize: 10,
})
const [selectedIPAddresses, setSelectedIPAddresses] = useState<IPAddress[]>([])
const [filterQuery, setFilterQuery] = useState("")
const [searchQuery, setSearchQuery] = useState("")
const [isSearching, setIsSearching] = useState(false)
const handleSearchChange = (value: string) => {
setIsSearching(true)
setSearchQuery(value)
const handleFilterChange = (value: string) => {
setFilterQuery(value)
setPagination((prev) => ({ ...prev, pageIndex: 0 }))
}
@@ -38,7 +35,7 @@ export function IPAddressesView({
{
page: pagination.pageIndex + 1,
pageSize: pagination.pageSize,
search: searchQuery || undefined,
filter: filterQuery || undefined,
},
{ enabled: !!targetId }
)
@@ -48,19 +45,13 @@ export function IPAddressesView({
{
page: pagination.pageIndex + 1,
pageSize: pagination.pageSize,
search: searchQuery || undefined,
filter: filterQuery || undefined,
},
{ enabled: !!scanId }
)
const activeQuery = targetId ? targetQuery : scanQuery
const { data, isLoading, isFetching, error, refetch } = activeQuery
useEffect(() => {
if (!isFetching && isSearching) {
setIsSearching(false)
}
}, [isFetching, isSearching])
const { data, isLoading, error, refetch } = activeQuery
const formatDate = useCallback((dateString: string) => {
return new Date(dateString).toLocaleString("zh-CN", {
@@ -225,11 +216,8 @@ export function IPAddressesView({
<IPAddressesDataTable
data={ipAddresses}
columns={columns}
searchPlaceholder="搜索IP地址..."
searchColumn="ip"
searchValue={searchQuery}
onSearch={handleSearchChange}
isSearching={isSearching}
filterValue={filterQuery}
onFilterChange={handleFilterChange}
pagination={pagination}
setPagination={setPagination}
paginationInfo={paginationInfo}

View File

@@ -29,8 +29,6 @@ import {
IconPlus,
IconTrash,
IconDownload,
IconSearch,
IconLoader2,
} from "@tabler/icons-react"
// 导入 UI 组件
@@ -44,7 +42,6 @@ import {
DropdownMenuSeparator,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu"
import { Input } from "@/components/ui/input"
import { Label } from "@/components/ui/label"
import {
Select,
@@ -61,11 +58,23 @@ import {
TableHeader,
TableRow,
} from "@/components/ui/table"
import { SmartFilterInput, PREDEFINED_FIELDS, type FilterField } from "@/components/common/smart-filter-input"
// 导入子域名类型定义
import type { Subdomain } from "@/types/subdomain.types"
import type { PaginationInfo } from "@/types/common.types"
// 子域名页面的过滤字段配置
const SUBDOMAIN_FILTER_FIELDS: FilterField[] = [
{ key: "name", label: "Name", description: "Subdomain name" },
]
// 子域名页面的示例
const SUBDOMAIN_FILTER_EXAMPLES = [
'name="api.example.com"',
'name="*.test.com"',
]
// 组件属性类型定义
interface SubdomainsDataTableProps {
data: Subdomain[] // 子域名数据数组
@@ -74,10 +83,9 @@ interface SubdomainsDataTableProps {
onBulkAdd?: () => void // 批量添加回调函数
onBulkDelete?: () => void // 批量删除回调函数
onSelectionChange?: (selectedRows: Subdomain[]) => void // 选中行变化回调
searchPlaceholder?: string // 搜索框占位符
searchColumn?: string // 搜索的列名
searchValue?: string // 受控:搜索框当前值(服务端搜索)
onSearch?: (value: string) => void // 受控:搜索框变更回调(服务端搜索)
// 智能过滤
filterValue?: string // 受控:过滤值
onFilterChange?: (value: string) => void // 受控:过滤变更回调
isSearching?: boolean // 搜索中状态(显示加载动画)
addButtonText?: string // 添加按钮文本
// 下载回调函数
@@ -104,10 +112,8 @@ export function SubdomainsDataTable({
onBulkAdd,
onBulkDelete,
onSelectionChange,
searchPlaceholder = "搜索子域名...",
searchColumn = "name",
searchValue,
onSearch,
filterValue,
onFilterChange,
isSearching = false,
addButtonText = "Add",
onDownloadAll,
@@ -140,25 +146,10 @@ export function SubdomainsDataTable({
const pagination = externalPagination || internalPagination
const setPagination = setExternalPagination || setInternalPagination
// 本地搜索输入状态(只在回车或点击按钮时触发搜索
const [localSearchValue, setLocalSearchValue] = React.useState(searchValue ?? "")
// 同步外部 searchValue 到本地状态
React.useEffect(() => {
setLocalSearchValue(searchValue ?? "")
}, [searchValue])
const handleSearchSubmit = () => {
if (onSearch) {
onSearch(localSearchValue)
} else {
table.getColumn(searchColumn)?.setFilterValue(localSearchValue)
}
}
const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
if (e.key === 'Enter') {
handleSearchSubmit()
// 处理智能过滤搜索
const handleSmartSearch = (_filters: any[], rawQuery: string) => {
if (onFilterChange) {
onFilterChange(rawQuery)
}
}
@@ -223,23 +214,14 @@ export function SubdomainsDataTable({
<div className="w-full space-y-4">
{/* 工具栏 */}
<div className="flex items-center justify-between">
{/* 搜索框 */}
<div className="flex items-center space-x-2">
<Input
placeholder={searchPlaceholder}
value={localSearchValue}
onChange={(e) => setLocalSearchValue(e.target.value)}
onKeyDown={handleKeyDown}
className="h-8 max-w-sm"
/>
<Button variant="outline" size="sm" onClick={handleSearchSubmit} disabled={isSearching}>
{isSearching ? (
<IconLoader2 className="h-4 w-4 animate-spin" />
) : (
<IconSearch className="h-4 w-4" />
)}
</Button>
</div>
{/* 智能过滤搜索框 */}
<SmartFilterInput
fields={SUBDOMAIN_FILTER_FIELDS}
examples={SUBDOMAIN_FILTER_EXAMPLES}
value={filterValue}
onSearch={handleSmartSearch}
className="flex-1 max-w-xl"
/>
{/* 右侧操作按钮 */}
<div className="flex items-center space-x-2">

View File

@@ -10,7 +10,6 @@ import {
} from "@/hooks/use-subdomains"
import { SubdomainsDataTable } from "./subdomains-data-table"
import { createSubdomainColumns } from "./subdomains-columns"
import { LoadingSpinner } from "@/components/loading-spinner"
import { DataTableSkeleton } from "@/components/ui/data-table-skeleton"
import { SubdomainService } from "@/services/subdomain.service"
import { BulkAddSubdomainsDialog } from "./bulk-add-subdomains-dialog"
@@ -40,23 +39,23 @@ export function SubdomainsDetailView({
pageSize: 10,
})
// 搜索状态(服务端搜索
const [searchQuery, setSearchQuery] = useState("")
// 过滤状态(智能过滤语法
const [filterQuery, setFilterQuery] = useState("")
const [isSearching, setIsSearching] = useState(false)
const handleSearchChange = (value: string) => {
const handleFilterChange = (value: string) => {
setIsSearching(true)
setSearchQuery(value)
setFilterQuery(value)
setPagination((prev) => ({ ...prev, pageIndex: 0 }))
}
// 根据 targetId 或 scanId 获取子域名数据(传入分页和搜索参数)
// 根据 targetId 或 scanId 获取子域名数据(传入分页和过滤参数)
const targetSubdomainsQuery = useTargetSubdomains(
targetId || 0,
{
page: pagination.pageIndex + 1, // 转换为 1-based
pageSize: pagination.pageSize,
search: searchQuery || undefined,
filter: filterQuery || undefined,
},
{ enabled: !!targetId }
)
@@ -65,7 +64,7 @@ export function SubdomainsDetailView({
{
page: pagination.pageIndex + 1, // 转换为 1-based
pageSize: pagination.pageSize,
search: searchQuery || undefined,
filter: filterQuery || undefined,
},
{ enabled: !!scanId }
)
@@ -254,10 +253,8 @@ export function SubdomainsDetailView({
data={subdomains}
columns={subdomainColumns}
onSelectionChange={setSelectedSubdomains}
searchPlaceholder="搜索子域名..."
searchColumn="name"
searchValue={searchQuery}
onSearch={handleSearchChange}
filterValue={filterQuery}
onFilterChange={handleFilterChange}
isSearching={isSearching}
onDownloadAll={handleDownloadAll}
onDownloadSelected={handleDownloadSelected}

View File

@@ -25,8 +25,6 @@ import {
IconLayoutColumns,
IconTrash,
IconDownload,
IconSearch,
IconLoader2,
} from "@tabler/icons-react"
import { Button } from "@/components/ui/button"
@@ -39,7 +37,6 @@ import {
DropdownMenuItem,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu"
import { Input } from "@/components/ui/input"
import { Label } from "@/components/ui/label"
import {
Select,
@@ -56,37 +53,47 @@ import {
TableHeader,
TableRow,
} from "@/components/ui/table"
import { SmartFilterInput, PREDEFINED_FIELDS, type FilterField } from "@/components/common/smart-filter-input"
import type { Vulnerability } from "@/types/vulnerability.types"
import type { PaginationInfo } from "@/types/common.types"
// 漏洞页面的过滤字段
const VULNERABILITY_FILTER_FIELDS: FilterField[] = [
{ key: "type", label: "Type", description: "Vulnerability type" },
PREDEFINED_FIELDS.severity,
{ key: "source", label: "Source", description: "Scanner source" },
PREDEFINED_FIELDS.url,
]
// 漏洞页面的示例
const VULNERABILITY_FILTER_EXAMPLES = [
'type="xss" || type="sqli"',
'severity="critical" || severity="high"',
'source="nuclei" && severity="high"',
'type="xss" && url="/api/*"',
]
interface VulnerabilitiesDataTableProps {
data: Vulnerability[]
columns: ColumnDef<Vulnerability>[]
searchPlaceholder?: string
searchColumn?: string
searchValue?: string
onSearch?: (value: string) => void
isSearching?: boolean
filterValue?: string
onFilterChange?: (value: string) => void
pagination?: { pageIndex: number; pageSize: number }
setPagination?: React.Dispatch<React.SetStateAction<{ pageIndex: number; pageSize: number }>>
paginationInfo?: PaginationInfo
onPaginationChange?: (pagination: { pageIndex: number; pageSize: number }) => void
onBulkDelete?: () => void // 批量删除回调函数
onSelectionChange?: (selectedRows: Vulnerability[]) => void // 选中行变化回调
// 下载回调函数
onDownloadAll?: () => void // 下载所有漏洞
onDownloadSelected?: () => void // 下载选中的漏洞
onBulkDelete?: () => void
onSelectionChange?: (selectedRows: Vulnerability[]) => void
onDownloadAll?: () => void
onDownloadSelected?: () => void
}
export function VulnerabilitiesDataTable({
data = [],
columns,
searchPlaceholder = "搜索漏洞类型...",
searchColumn = "title",
searchValue,
onSearch,
isSearching = false,
filterValue,
onFilterChange,
pagination,
setPagination,
paginationInfo,
@@ -107,27 +114,6 @@ export function VulnerabilitiesDataTable({
pageSize: 10,
})
// 本地搜索输入状态(只在回车或点击按钮时触发搜索)
const [localSearchValue, setLocalSearchValue] = React.useState(searchValue ?? "")
React.useEffect(() => {
setLocalSearchValue(searchValue ?? "")
}, [searchValue])
const handleSearchSubmit = () => {
if (onSearch) {
onSearch(localSearchValue)
} else {
table.getColumn(searchColumn)?.setFilterValue(localSearchValue)
}
}
const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
if (e.key === 'Enter') {
handleSearchSubmit()
}
}
const useServerPagination = !!paginationInfo && !!pagination && !!setPagination
const tablePagination = useServerPagination ? pagination : internalPagination
const setTablePagination = useServerPagination ? setPagination : setInternalPagination
@@ -182,27 +168,23 @@ export function VulnerabilitiesDataTable({
}
}, [rowSelection, onSelectionChange, table])
// 处理智能过滤搜索
const handleFilterSearch = (_filters: any[], rawQuery: string) => {
onFilterChange?.(rawQuery)
}
return (
<div className="w-full space-y-4">
{/* 工具栏 */}
<div className="flex items-center justify-between">
{/* 搜索框 */}
<div className="flex items-center space-x-2">
<Input
placeholder={searchPlaceholder}
value={localSearchValue}
onChange={(e) => setLocalSearchValue(e.target.value)}
onKeyDown={handleKeyDown}
className="h-8 max-w-sm"
/>
<Button variant="outline" size="sm" onClick={handleSearchSubmit} disabled={isSearching}>
{isSearching ? (
<IconLoader2 className="h-4 w-4 animate-spin" />
) : (
<IconSearch className="h-4 w-4" />
)}
</Button>
</div>
{/* 智能过滤输入框 */}
<SmartFilterInput
fields={VULNERABILITY_FILTER_FIELDS}
examples={VULNERABILITY_FILTER_EXAMPLES}
value={filterValue}
onSearch={handleFilterSearch}
className="flex-1 max-w-xl"
/>
{/* 右侧操作按钮 */}
<div className="flex items-center space-x-2">

View File

@@ -42,19 +42,17 @@ export function VulnerabilitiesDetailView({
pageSize: 10,
})
const [searchQuery, setSearchQuery] = useState("")
const [isSearching, setIsSearching] = useState(false)
// 智能过滤查询
const [filterQuery, setFilterQuery] = useState("")
const handleSearchChange = (value: string) => {
setIsSearching(true)
setSearchQuery(value)
const handleFilterChange = (value: string) => {
setFilterQuery(value)
setPagination((prev) => ({ ...prev, pageIndex: 0 }))
}
const paginationParams = {
page: pagination.pageIndex + 1,
pageSize: pagination.pageSize,
search: searchQuery || undefined,
}
// 按 scan 维度加载(扫描历史页面)
@@ -62,6 +60,7 @@ export function VulnerabilitiesDetailView({
scanId ?? 0,
paginationParams,
{ enabled: !!scanId },
filterQuery || undefined,
)
// 按 target 维度加载(目标详情页面)
@@ -69,25 +68,20 @@ export function VulnerabilitiesDetailView({
targetId ?? 0,
paginationParams,
{ enabled: !!targetId && !scanId },
filterQuery || undefined,
)
// 获取所有漏洞(全局漏洞页面)
const allQuery = useAllVulnerabilities(
paginationParams,
{ enabled: !scanId && !targetId },
filterQuery || undefined,
)
// 根据参数选择使用哪个 query
const activeQuery = scanId ? scanQuery : targetId ? targetQuery : allQuery
const isQueryLoading = activeQuery.isLoading
// 当请求完成时重置搜索状态
React.useEffect(() => {
if (!activeQuery.isFetching && isSearching) {
setIsSearching(false)
}
}, [activeQuery.isFetching, isSearching])
const vulnerabilities = activeQuery.data?.vulnerabilities ?? []
const paginationInfo = activeQuery.data?.pagination ?? {
total: vulnerabilities.length,
@@ -206,11 +200,8 @@ export function VulnerabilitiesDetailView({
<VulnerabilitiesDataTable
data={vulnerabilities}
columns={vulnerabilityColumns}
searchPlaceholder="搜索漏洞类型..."
searchColumn="vulnType"
searchValue={searchQuery}
onSearch={handleSearchChange}
isSearching={isSearching}
filterValue={filterQuery}
onFilterChange={handleFilterChange}
pagination={pagination}
setPagination={setPagination}
paginationInfo={{

View File

@@ -25,8 +25,6 @@ import {
IconLayoutColumns,
IconTrash,
IconDownload,
IconSearch,
IconLoader2,
IconPlus,
} from "@tabler/icons-react"
@@ -40,7 +38,6 @@ import {
DropdownMenuItem,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu"
import { Input } from "@/components/ui/input"
import { Label } from "@/components/ui/label"
import {
Select,
@@ -57,17 +54,32 @@ import {
TableHeader,
TableRow,
} from "@/components/ui/table"
import { SmartFilterInput, type FilterField } from "@/components/common/smart-filter-input"
import type { WebSite } from "@/types/website.types"
import type { PaginationInfo } from "@/types/common.types"
// 网站页面的过滤字段配置
const WEBSITE_FILTER_FIELDS: FilterField[] = [
{ key: "url", label: "URL", description: "Full URL" },
{ key: "host", label: "Host", description: "Hostname" },
{ key: "title", label: "Title", description: "Page title" },
{ key: "status", label: "Status", description: "HTTP status code" },
]
// 网站页面的示例
const WEBSITE_FILTER_EXAMPLES = [
'host="api.example.com" && status="200"',
'title="Login" || title="Admin"',
'url="/api/*" && status!="404"',
]
interface WebSitesDataTableProps {
data: WebSite[]
columns: ColumnDef<WebSite>[]
searchPlaceholder?: string
searchColumn?: string
searchValue?: string
onSearch?: (value: string) => void
// 智能过滤
filterValue?: string
onFilterChange?: (value: string) => void
isSearching?: boolean
pagination?: { pageIndex: number; pageSize: number }
setPagination?: React.Dispatch<React.SetStateAction<{ pageIndex: number; pageSize: number }>>
@@ -84,10 +96,8 @@ interface WebSitesDataTableProps {
export function WebSitesDataTable({
data = [],
columns,
searchPlaceholder = "搜索主机名...",
searchColumn = "url",
searchValue,
onSearch,
filterValue,
onFilterChange,
isSearching = false,
pagination,
setPagination,
@@ -109,24 +119,10 @@ export function WebSitesDataTable({
pageSize: 10,
})
// 本地搜索输入状态(只在回车或点击按钮时触发搜索
const [localSearchValue, setLocalSearchValue] = React.useState(searchValue ?? "")
React.useEffect(() => {
setLocalSearchValue(searchValue ?? "")
}, [searchValue])
const handleSearchSubmit = () => {
if (onSearch) {
onSearch(localSearchValue)
} else {
table.getColumn(searchColumn)?.setFilterValue(localSearchValue)
}
}
const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
if (e.key === 'Enter') {
handleSearchSubmit()
// 处理智能过滤搜索
const handleSmartSearch = (_filters: any[], rawQuery: string) => {
if (onFilterChange) {
onFilterChange(rawQuery)
}
}
@@ -188,23 +184,14 @@ export function WebSitesDataTable({
<div className="w-full space-y-4">
{/* 工具栏 */}
<div className="flex items-center justify-between">
{/* 搜索框 */}
<div className="flex items-center space-x-2">
<Input
placeholder={searchPlaceholder}
value={localSearchValue}
onChange={(e) => setLocalSearchValue(e.target.value)}
onKeyDown={handleKeyDown}
className="h-8 max-w-sm"
/>
<Button variant="outline" size="sm" onClick={handleSearchSubmit} disabled={isSearching}>
{isSearching ? (
<IconLoader2 className="h-4 w-4 animate-spin" />
) : (
<IconSearch className="h-4 w-4" />
)}
</Button>
</div>
{/* 智能过滤搜索框 */}
<SmartFilterInput
fields={WEBSITE_FILTER_FIELDS}
examples={WEBSITE_FILTER_EXAMPLES}
value={filterValue}
onSearch={handleSmartSearch}
className="flex-1 max-w-xl"
/>
{/* 右侧操作按钮 */}
<div className="flex items-center space-x-2">

View File

@@ -28,15 +28,15 @@ export function WebSitesView({
const [selectedWebSites, setSelectedWebSites] = useState<WebSite[]>([])
const [bulkAddDialogOpen, setBulkAddDialogOpen] = useState(false)
const [searchQuery, setSearchQuery] = useState("")
const [filterQuery, setFilterQuery] = useState("")
const [isSearching, setIsSearching] = useState(false)
// 获取目标信息(用于 URL 匹配校验)
const { data: target } = useTarget(targetId || 0, { enabled: !!targetId })
const handleSearchChange = (value: string) => {
const handleFilterChange = (value: string) => {
setIsSearching(true)
setSearchQuery(value)
setFilterQuery(value)
setPagination((prev) => ({ ...prev, pageIndex: 0 }))
}
@@ -45,7 +45,7 @@ export function WebSitesView({
{
page: pagination.pageIndex + 1,
pageSize: pagination.pageSize,
search: searchQuery || undefined,
filter: filterQuery || undefined,
},
{ enabled: !!targetId }
)
@@ -55,7 +55,7 @@ export function WebSitesView({
{
page: pagination.pageIndex + 1,
pageSize: pagination.pageSize,
search: searchQuery || undefined,
filter: filterQuery || undefined,
},
{ enabled: !!scanId }
)
@@ -248,10 +248,8 @@ export function WebSitesView({
<WebSitesDataTable
data={websites}
columns={columns}
searchPlaceholder="搜索主机名..."
searchColumn="url"
searchValue={searchQuery}
onSearch={handleSearchChange}
filterValue={filterQuery}
onFilterChange={handleFilterChange}
isSearching={isSearching}
pagination={pagination}
setPagination={setPagination}

View File

@@ -8,11 +8,11 @@ const directoryService = {
// 获取目标的目录列表
getTargetDirectories: async (
targetId: number,
params: { page: number; pageSize: number; search?: string }
params: { page: number; pageSize: number; filter?: string }
): Promise<DirectoryListResponse> => {
const searchParam = params.search ? `&search=${encodeURIComponent(params.search)}` : ''
const filterParam = params.filter ? `&filter=${encodeURIComponent(params.filter)}` : ''
const response = await fetch(
`/api/targets/${targetId}/directories/?page=${params.page}&pageSize=${params.pageSize}${searchParam}`
`/api/targets/${targetId}/directories/?page=${params.page}&pageSize=${params.pageSize}${filterParam}`
)
if (!response.ok) {
throw new Error('获取目录列表失败')
@@ -23,11 +23,11 @@ const directoryService = {
// 获取扫描的目录列表
getScanDirectories: async (
scanId: number,
params: { page: number; pageSize: number; search?: string }
params: { page: number; pageSize: number; filter?: string }
): Promise<DirectoryListResponse> => {
const searchParam = params.search ? `&search=${encodeURIComponent(params.search)}` : ''
const filterParam = params.filter ? `&filter=${encodeURIComponent(params.filter)}` : ''
const response = await fetch(
`/api/scans/${scanId}/directories/?page=${params.page}&pageSize=${params.pageSize}${searchParam}`
`/api/scans/${scanId}/directories/?page=${params.page}&pageSize=${params.pageSize}${filterParam}`
)
if (!response.ok) {
throw new Error('获取目录列表失败')
@@ -80,7 +80,7 @@ const directoryService = {
// 获取目标的目录列表
export function useTargetDirectories(
targetId: number,
params: { page: number; pageSize: number; search?: string },
params: { page: number; pageSize: number; filter?: string },
options?: { enabled?: boolean }
) {
return useQuery({
@@ -94,7 +94,7 @@ export function useTargetDirectories(
// 获取扫描的目录列表
export function useScanDirectories(
scanId: number,
params: { page: number; pageSize: number; search?: string },
params: { page: number; pageSize: number; filter?: string },
options?: { enabled?: boolean }
) {
return useQuery({

View File

@@ -61,7 +61,7 @@ export function useEndpoints(params?: GetEndpointsRequest) {
}
// 根据目标ID获取 Endpoint 列表(使用专用路由)
export function useEndpointsByTarget(targetId: number, params?: Omit<GetEndpointsRequest, 'targetId'>) {
export function useEndpointsByTarget(targetId: number, params?: Omit<GetEndpointsRequest, 'targetId'>, filter?: string) {
const defaultParams: GetEndpointsRequest = {
page: 1,
pageSize: 10,
@@ -69,8 +69,8 @@ export function useEndpointsByTarget(targetId: number, params?: Omit<GetEndpoint
}
return useQuery({
queryKey: endpointKeys.byTarget(targetId, defaultParams),
queryFn: () => EndpointService.getEndpointsByTargetId(targetId, defaultParams),
queryKey: [...endpointKeys.byTarget(targetId, defaultParams), filter],
queryFn: () => EndpointService.getEndpointsByTargetId(targetId, defaultParams, filter),
select: (response) => {
// RESTful 标准:直接返回数据
return response as GetEndpointsResponse
@@ -81,7 +81,7 @@ export function useEndpointsByTarget(targetId: number, params?: Omit<GetEndpoint
}
// 根据扫描ID获取 Endpoint 列表(历史快照)
export function useScanEndpoints(scanId: number, params?: Omit<GetEndpointsRequest, 'targetId'>, options?: { enabled?: boolean }) {
export function useScanEndpoints(scanId: number, params?: Omit<GetEndpointsRequest, 'targetId'>, options?: { enabled?: boolean }, filter?: string) {
const defaultParams: GetEndpointsRequest = {
page: 1,
pageSize: 10,
@@ -89,8 +89,8 @@ export function useScanEndpoints(scanId: number, params?: Omit<GetEndpointsReque
}
return useQuery({
queryKey: endpointKeys.byScan(scanId, defaultParams),
queryFn: () => EndpointService.getEndpointsByScanId(scanId, defaultParams),
queryKey: [...endpointKeys.byScan(scanId, defaultParams), filter],
queryFn: () => EndpointService.getEndpointsByScanId(scanId, defaultParams, filter),
enabled: options?.enabled !== undefined ? options.enabled : !!scanId,
select: (response: any) => {
// 后端使用通用分页格式results/total/page/pageSize/totalPages

View File

@@ -16,7 +16,7 @@ function normalizeParams(params?: GetIPAddressesParams): Required<GetIPAddresses
return {
page: params?.page ?? 1,
pageSize: params?.pageSize ?? 10,
search: params?.search ?? "",
filter: params?.filter ?? "",
}
}

View File

@@ -267,11 +267,11 @@ export function useAllSubdomains(
// 获取目标的子域名列表
export function useTargetSubdomains(
targetId: number,
params?: { page?: number; pageSize?: number; search?: string },
params?: { page?: number; pageSize?: number; filter?: string },
options?: { enabled?: boolean }
) {
return useQuery({
queryKey: ['targets', targetId, 'subdomains', { page: params?.page, pageSize: params?.pageSize, search: params?.search }],
queryKey: ['targets', targetId, 'subdomains', { page: params?.page, pageSize: params?.pageSize, filter: params?.filter }],
queryFn: () => SubdomainService.getSubdomainsByTargetId(targetId, params),
enabled: options?.enabled !== undefined ? options.enabled : !!targetId,
placeholderData: keepPreviousData,
@@ -281,11 +281,11 @@ export function useTargetSubdomains(
// 获取扫描的子域名列表
export function useScanSubdomains(
scanId: number,
params?: { page?: number; pageSize?: number; search?: string },
params?: { page?: number; pageSize?: number; filter?: string },
options?: { enabled?: boolean }
) {
return useQuery({
queryKey: ['scans', scanId, 'subdomains', { page: params?.page, pageSize: params?.pageSize, search: params?.search }],
queryKey: ['scans', scanId, 'subdomains', { page: params?.page, pageSize: params?.pageSize, filter: params?.filter }],
queryFn: () => SubdomainService.getSubdomainsByScanId(scanId, params),
enabled: options?.enabled !== undefined ? options.enabled : !!scanId,
placeholderData: keepPreviousData,

View File

@@ -267,7 +267,7 @@ export function useTargetEndpoints(
params?: {
page?: number
pageSize?: number
search?: string
filter?: string
},
options?: {
enabled?: boolean
@@ -277,9 +277,9 @@ export function useTargetEndpoints(
queryKey: ['targets', 'detail', targetId, 'endpoints', {
page: params?.page,
pageSize: params?.pageSize,
search: params?.search,
filter: params?.filter,
}],
queryFn: () => getTargetEndpoints(targetId, params?.page || 1, params?.pageSize || 10, params?.search),
queryFn: () => getTargetEndpoints(targetId, params?.page || 1, params?.pageSize || 10, params?.filter),
enabled: options?.enabled !== undefined ? options.enabled : !!targetId,
select: (response: any) => {
// 后端使用通用分页格式results/total/page/pageSize/totalPages

View File

@@ -12,18 +12,19 @@ import type { PaginationInfo } from "@/types/common.types"
export const vulnerabilityKeys = {
all: ["vulnerabilities"] as const,
list: (params: GetVulnerabilitiesParams) =>
[...vulnerabilityKeys.all, "list", params] as const,
byScan: (scanId: number, params: GetVulnerabilitiesParams) =>
[...vulnerabilityKeys.all, "scan", scanId, params] as const,
byTarget: (targetId: number, params: GetVulnerabilitiesParams) =>
[...vulnerabilityKeys.all, "target", targetId, params] as const,
list: (params: GetVulnerabilitiesParams, filter?: string) =>
[...vulnerabilityKeys.all, "list", params, filter] as const,
byScan: (scanId: number, params: GetVulnerabilitiesParams, filter?: string) =>
[...vulnerabilityKeys.all, "scan", scanId, params, filter] as const,
byTarget: (targetId: number, params: GetVulnerabilitiesParams, filter?: string) =>
[...vulnerabilityKeys.all, "target", targetId, params, filter] as const,
}
/** 获取所有漏洞 */
export function useAllVulnerabilities(
params?: GetVulnerabilitiesParams,
options?: { enabled?: boolean },
filter?: string,
) {
const defaultParams: GetVulnerabilitiesParams = {
page: 1,
@@ -32,8 +33,8 @@ export function useAllVulnerabilities(
}
return useQuery({
queryKey: vulnerabilityKeys.list(defaultParams),
queryFn: () => VulnerabilityService.getAllVulnerabilities(defaultParams),
queryKey: vulnerabilityKeys.list(defaultParams, filter),
queryFn: () => VulnerabilityService.getAllVulnerabilities(defaultParams, filter),
enabled: options?.enabled ?? true,
select: (response: any) => {
const items = (response?.results ?? []) as any[]
@@ -93,6 +94,7 @@ export function useScanVulnerabilities(
scanId: number,
params?: GetVulnerabilitiesParams,
options?: { enabled?: boolean },
filter?: string,
) {
const defaultParams: GetVulnerabilitiesParams = {
page: 1,
@@ -101,9 +103,9 @@ export function useScanVulnerabilities(
}
return useQuery({
queryKey: vulnerabilityKeys.byScan(scanId, defaultParams),
queryKey: vulnerabilityKeys.byScan(scanId, defaultParams, filter),
queryFn: () =>
VulnerabilityService.getVulnerabilitiesByScanId(scanId, defaultParams),
VulnerabilityService.getVulnerabilitiesByScanId(scanId, defaultParams, filter),
enabled: options?.enabled !== undefined ? options.enabled : !!scanId,
select: (response: any) => {
const items = (response?.results ?? []) as any[]
@@ -163,6 +165,7 @@ export function useTargetVulnerabilities(
targetId: number,
params?: GetVulnerabilitiesParams,
options?: { enabled?: boolean },
filter?: string,
) {
const defaultParams: GetVulnerabilitiesParams = {
page: 1,
@@ -171,9 +174,9 @@ export function useTargetVulnerabilities(
}
return useQuery({
queryKey: vulnerabilityKeys.byTarget(targetId, defaultParams),
queryKey: vulnerabilityKeys.byTarget(targetId, defaultParams, filter),
queryFn: () =>
VulnerabilityService.getVulnerabilitiesByTargetId(targetId, defaultParams),
VulnerabilityService.getVulnerabilitiesByTargetId(targetId, defaultParams, filter),
enabled: options?.enabled !== undefined ? options.enabled : !!targetId,
select: (response: any) => {
const items = (response?.results ?? []) as any[]

View File

@@ -8,11 +8,11 @@ const websiteService = {
// 获取目标的网站列表
getTargetWebSites: async (
targetId: number,
params: { page: number; pageSize: number; search?: string }
params: { page: number; pageSize: number; filter?: string }
): Promise<WebSiteListResponse> => {
const searchParam = params.search ? `&search=${encodeURIComponent(params.search)}` : ''
const filterParam = params.filter ? `&filter=${encodeURIComponent(params.filter)}` : ''
const response = await fetch(
`/api/targets/${targetId}/websites/?page=${params.page}&pageSize=${params.pageSize}${searchParam}`
`/api/targets/${targetId}/websites/?page=${params.page}&pageSize=${params.pageSize}${filterParam}`
)
if (!response.ok) {
throw new Error('获取网站列表失败')
@@ -23,11 +23,11 @@ const websiteService = {
// 获取扫描的网站列表
getScanWebSites: async (
scanId: number,
params: { page: number; pageSize: number; search?: string }
params: { page: number; pageSize: number; filter?: string }
): Promise<WebSiteListResponse> => {
const searchParam = params.search ? `&search=${encodeURIComponent(params.search)}` : ''
const filterParam = params.filter ? `&filter=${encodeURIComponent(params.filter)}` : ''
const response = await fetch(
`/api/scans/${scanId}/websites/?page=${params.page}&pageSize=${params.pageSize}${searchParam}`
`/api/scans/${scanId}/websites/?page=${params.page}&pageSize=${params.pageSize}${filterParam}`
)
if (!response.ok) {
throw new Error('获取网站列表失败')
@@ -80,7 +80,7 @@ const websiteService = {
// 获取目标的网站列表
export function useTargetWebSites(
targetId: number,
params: { page: number; pageSize: number; search?: string },
params: { page: number; pageSize: number; filter?: string },
options?: { enabled?: boolean }
) {
return useQuery({
@@ -94,7 +94,7 @@ export function useTargetWebSites(
// 获取扫描的网站列表
export function useScanWebSites(
scanId: number,
params: { page: number; pageSize: number; search?: string },
params: { page: number; pageSize: number; filter?: string },
options?: { enabled?: boolean }
) {
return useQuery({

View File

@@ -59,12 +59,13 @@ export class EndpointService {
* 根据目标ID获取 Endpoint 列表(专用路由)
* @param targetId - 目标ID
* @param params - 其他查询参数
* @param filter - 智能过滤查询字符串
* @returns Promise<GetEndpointsResponse>
*/
static async getEndpointsByTargetId(targetId: number, params: GetEndpointsRequest): Promise<GetEndpointsResponse> {
static async getEndpointsByTargetId(targetId: number, params: GetEndpointsRequest, filter?: string): Promise<GetEndpointsResponse> {
// api-client.ts 会自动将 params 对象的驼峰转换为下划线
const response = await api.get<GetEndpointsResponse>(`/targets/${targetId}/endpoints/`, {
params
params: { ...params, filter }
})
return response.data
}
@@ -73,13 +74,15 @@ export class EndpointService {
* 根据扫描ID获取 Endpoint 列表(历史快照)
* @param scanId - 扫描任务 ID
* @param params - 分页等查询参数
* @param filter - 智能过滤查询字符串
*/
static async getEndpointsByScanId(
scanId: number,
params: GetEndpointsRequest,
filter?: string,
): Promise<any> {
const response = await api.get(`/scans/${scanId}/endpoints/`, {
params,
params: { ...params, filter },
})
return response.data
}

View File

@@ -10,7 +10,7 @@ export class IPAddressService {
params: {
page: params?.page || 1,
pageSize: params?.pageSize || 10,
...(params?.search && { search: params.search }),
...(params?.filter && { filter: params.filter }),
},
})
return response.data
@@ -24,7 +24,7 @@ export class IPAddressService {
params: {
page: params?.page || 1,
pageSize: params?.pageSize || 10,
...(params?.search && { search: params.search }),
...(params?.filter && { filter: params.filter }),
},
})
return response.data

View File

@@ -173,32 +173,32 @@ export class SubdomainService {
return response.data
}
/** 获取目标的子域名列表(支持分页和搜索 */
/** 获取目标的子域名列表(支持分页和过滤 */
static async getSubdomainsByTargetId(
targetId: number,
params?: {
page?: number
pageSize?: number
search?: string
filter?: string
}
): Promise<any> {
const response = await api.get(`/targets/${targetId}/subdomains/`, {
params: {
page: params?.page || 1,
pageSize: params?.pageSize || 10,
...(params?.search && { search: params.search }),
...(params?.filter && { filter: params.filter }),
}
})
return response.data
}
/** 获取扫描的子域名列表(支持分页) */
/** 获取扫描的子域名列表(支持分页和过滤 */
static async getSubdomainsByScanId(
scanId: number,
params?: {
page?: number
pageSize?: number
search?: string
filter?: string
}
): Promise<{
results: Array<{
@@ -225,7 +225,7 @@ export class SubdomainService {
params: {
page: params?.page || 1,
pageSize: params?.pageSize || 10,
...(params?.search && { search: params.search }),
...(params?.filter && { filter: params.filter }),
}
})
return response.data as any

View File

@@ -136,13 +136,13 @@ export async function getTargetEndpoints(
id: number,
page = 1,
pageSize = 10,
search?: string
filter?: string
): Promise<any> {
const response = await api.get(`/targets/${id}/endpoints/`, {
params: {
page,
pageSize,
...(search && { search }),
...(filter && { filter }),
},
})
return response.data

View File

@@ -5,9 +5,10 @@ export class VulnerabilityService {
/** 获取所有漏洞列表(全局漏洞页使用) */
static async getAllVulnerabilities(
params: GetVulnerabilitiesParams,
filter?: string,
): Promise<any> {
const response = await api.get(`/assets/vulnerabilities/`, {
params,
params: { ...params, filter },
})
return response.data
}
@@ -16,9 +17,10 @@ export class VulnerabilityService {
static async getVulnerabilitiesByScanId(
scanId: number,
params: GetVulnerabilitiesParams,
filter?: string,
): Promise<any> {
const response = await api.get(`/scans/${scanId}/vulnerabilities/`, {
params,
params: { ...params, filter },
})
return response.data
}
@@ -27,9 +29,10 @@ export class VulnerabilityService {
static async getVulnerabilitiesByTargetId(
targetId: number,
params: GetVulnerabilitiesParams,
filter?: string,
): Promise<any> {
const response = await api.get(`/targets/${targetId}/vulnerabilities/`, {
params,
params: { ...params, filter },
})
return response.data
}

View File

@@ -15,7 +15,7 @@ export interface IPAddress {
export interface GetIPAddressesParams {
page?: number
pageSize?: number
search?: string
filter?: string // 智能过滤语法字符串
}
export interface GetIPAddressesResponse {

View File

@@ -69,6 +69,7 @@ export interface GetVulnerabilitiesParams extends PaginationParams {
endpointId?: number
severity?: VulnerabilitySeverity
status?: VulnerabilityStatus
filter?: string // 智能过滤语法
}
// 获取漏洞列表响应