diff --git a/mcp_server/tools/__init__.py b/mcp_server/tools/__init__.py index d0d4183..55cd153 100644 --- a/mcp_server/tools/__init__.py +++ b/mcp_server/tools/__init__.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod import os import importlib import logging + class ABCTool(ABC): @classmethod @abstractmethod diff --git a/mcp_server/tools/create_http_application.py b/mcp_server/tools/create_http_application.py index 4827d40..83d9f04 100644 --- a/mcp_server/tools/create_http_application.py +++ b/mcp_server/tools/create_http_application.py @@ -36,7 +36,7 @@ class CreateHttpApplication(BaseModel, ABCTool): @classmethod def tool(self) -> Tool: return Tool( - name="create_http_application", + name="waf_ create_http_application", description="在雷池 WAF 上创建一个站点应用", inputSchema=self.model_json_schema() ) diff --git a/mcp_server/tools/create_ip_custom_rule.py b/mcp_server/tools/create_ip_custom_rule.py index 1913a71..706631f 100644 --- a/mcp_server/tools/create_ip_custom_rule.py +++ b/mcp_server/tools/create_ip_custom_rule.py @@ -49,7 +49,7 @@ class CreateIPCustomRule(BaseModel, ABCTool): @classmethod def tool(self) -> Tool: return Tool( - name="create_ip_custom_rule", - description="在雷池 WAF 上创建一个 ip 的自定义黑名单或者自定义白名单", + name="waf_create_ip_custom_rule", + description="以 客户端 IP 地址为条件,在雷池 WAF 上创建一个黑/白名单", inputSchema=self.model_json_schema() ) \ No newline at end of file diff --git a/mcp_server/tools/create_path_custom_rule.py b/mcp_server/tools/create_path_custom_rule.py index 583c9c5..5c0b3b0 100644 --- a/mcp_server/tools/create_path_custom_rule.py +++ b/mcp_server/tools/create_path_custom_rule.py @@ -47,7 +47,7 @@ class CreatePathCustomRule(BaseModel, ABCTool): @classmethod def tool(self) -> Tool: return Tool( - name="create_path_custom_rule", - description="在雷池 WAF 上创建一个 url 路径的自定义黑名单或者自定义白名单", + name="waf_create_path_custom_rule", + description="以 URL Path 为条件,在雷池 WAF 上创建一个黑/白名单", inputSchema=self.model_json_schema() ) \ No newline at end of file diff --git a/mcp_server/tools/get_attack_events.py b/mcp_server/tools/get_attack_events.py new file mode 100644 index 0000000..e5044bb --- /dev/null +++ b/mcp_server/tools/get_attack_events.py @@ -0,0 +1,33 @@ +from pydantic import BaseModel, Field +from utils.request import get_slce_api +from tools import Tool, ABCTool, tools +from urllib.parse import urlparse +@tools.register +class CreateHttpApplication(BaseModel, ABCTool): + ip: str = Field(default="", description="the attacker's client IP address") + size: int = Field(default=10, min=1, max=100, description="the number of results to return") + start: str = Field(default="", description="start time, millisecond timestamp") + end: str = Field(default="", description="end time, millisecond timestamp") + + @classmethod + async def run(self, arguments:dict) -> str: + try: + req = CreateHttpApplication.model_validate(arguments) + parsed_upstream = urlparse(req.upstream) + if parsed_upstream.scheme not in ["https", "http"]: + return "invalid upstream scheme" + + if parsed_upstream.hostname == "": + return "invalid upstream host" + except Exception as e: + return str(e) + + return await get_slce_api(f"api/open/events?page=1&page_size={req.size}&ip={req.ip}&start={req.start}&end={req.end}") + + @classmethod + def tool(self) -> Tool: + return Tool( + name="waf_get_attack_events", + description="获取雷池 WAF 所记录的攻击事件", + inputSchema=self.model_json_schema() + ) diff --git a/mcp_server/utils/request.py b/mcp_server/utils/request.py index 506d04f..54f1d32 100644 --- a/mcp_server/utils/request.py +++ b/mcp_server/utils/request.py @@ -15,7 +15,21 @@ def check_slce_response(response: httpx.Response) -> str: return "success" -async def post_slce_api(path: str,req_body: dict) -> str: +async def get_slce_api(path: str) -> str: + if not path.startswith("/"): + path = f"/{path}" + + try: + async with AsyncClient(verify=False) as client: + response = await client.get(f"{GLOBAL_CONFIG.SAFELINE_ADDRESS}{path}", json=req_body, headers={ + "X-SLCE-API-TOKEN": f"{GLOBAL_CONFIG.SAFELINE_API_TOKEN}" + }) + return check_slce_response(response) + except Exception as e: + return str(e) + + +async def post_slce_api(path: str, req_body: dict) -> str: if not path.startswith("/"): path = f"/{path}"