diff --git a/mcp_go/internal/api/analyze/get_event_list.go b/mcp_go/internal/api/analyze/get_event_list.go new file mode 100644 index 0000000..29eac72 --- /dev/null +++ b/mcp_go/internal/api/analyze/get_event_list.go @@ -0,0 +1,47 @@ +package analyze + +import ( + "context" + "fmt" + + "github.com/chaitin/SafeLine/mcp_server/internal/api" +) + +type GetEventListRequest struct { + Page int `json:"page"` + PageSize int `json:"page_size"` + IP string `json:"ip"` + Start int64 `json:"start"` + End int64 `json:"end"` +} + +type GetEventListResponse struct { + Nodes []Event `json:"nodes"` + Total int64 `json:"total"` +} + +type Event struct { + ID uint `json:"id"` + IP string `json:"ip"` + Protocol int `json:"protocol"` + Host string `json:"host"` + DstPort uint64 `json:"dst_port"` + UpdatedAt int64 `json:"updated_at"` + StartAt int64 `json:"start_at"` + EndAt int64 `json:"end_at"` + DenyCount int64 `json:"deny_count"` + PassCount int64 `json:"pass_count"` + Finished bool `json:"finished"` + Country string `json:"country"` + Province string `json:"province"` + City string `json:"city"` +} + +func GetEventList(ctx context.Context, req *GetEventListRequest) (*GetEventListResponse, error) { + var resp api.Response[GetEventListResponse] + err := api.Service().Get(ctx, fmt.Sprintf("/api/open/events?page=%d&page_size=%d&ip=%s&start=%d&end=%d", req.Page, req.PageSize, req.IP, req.Start, req.End), &resp) + if err != nil { + return nil, err + } + return &resp.Data, nil +} diff --git a/mcp_go/internal/api/client.go b/mcp_go/internal/api/client.go index dbaf226..7978172 100644 --- a/mcp_go/internal/api/client.go +++ b/mcp_go/internal/api/client.go @@ -8,10 +8,10 @@ import ( "fmt" "io" "net/http" - "net/url" "time" "github.com/chaitin/SafeLine/mcp_server/pkg/errors" + "github.com/chaitin/SafeLine/mcp_server/pkg/logger" ) // Client API client @@ -80,10 +80,7 @@ func NewClient(opts ...ClientOption) *Client { // Request Send request func (c *Client) Request(ctx context.Context, method, path string, body interface{}, result interface{}) error { - reqURL, err := url.JoinPath(c.baseURL, path) - if err != nil { - return errors.Wrap(err, "invalid URL path") - } + reqURL := fmt.Sprintf("%s%s", c.baseURL, path) var bodyReader io.Reader if body != nil { @@ -93,7 +90,7 @@ func (c *Client) Request(ctx context.Context, method, path string, body interfac } bodyReader = bytes.NewReader(bodyBytes) } - + logger.With("url", reqURL).Debug("request url") req, err := http.NewRequestWithContext(ctx, method, reqURL, bodyReader) if err != nil { return errors.Wrap(err, "create request failed") diff --git a/mcp_go/internal/tools/analyze/get_atttack_events.go b/mcp_go/internal/tools/analyze/get_atttack_events.go new file mode 100644 index 0000000..bc68cd6 --- /dev/null +++ b/mcp_go/internal/tools/analyze/get_atttack_events.go @@ -0,0 +1,45 @@ +package analyze + +import ( + "context" + + "github.com/chaitin/SafeLine/mcp_server/internal/api/analyze" + "github.com/chaitin/SafeLine/mcp_server/pkg/logger" +) + +type GetAttackEventsParams struct { + IP string `json:"ip" desc:"ip" required:"false"` + Page int `json:"page" desc:"page" required:"false" default:"1"` + PageSize int `json:"page_size" desc:"page size" required:"false" default:"10"` + Start int64 `json:"start" desc:"start unix timestamp" required:"false"` + End int64 `json:"end" desc:"end unix timestamp" required:"false"` +} + +type GetAttackEvents struct{} + +func (t *GetAttackEvents) Name() string { + return "get_attack_events" +} + +func (t *GetAttackEvents) Description() string { + return "get attack events" +} + +func (t *GetAttackEvents) Validate(params GetAttackEventsParams) error { + return nil +} + +func (t *GetAttackEvents) Execute(ctx context.Context, params GetAttackEventsParams) (analyze.GetEventListResponse, error) { + resp, err := analyze.GetEventList(ctx, &analyze.GetEventListRequest{ + IP: params.IP, + PageSize: params.PageSize, + Page: params.Page, + Start: params.Start, + End: params.End, + }) + if err != nil { + return analyze.GetEventListResponse{}, err + } + logger.With("total", resp.Total).Info("get attack events") + return *resp, nil +} diff --git a/mcp_go/internal/tools/init.go b/mcp_go/internal/tools/init.go index 8c10199..03f850c 100644 --- a/mcp_go/internal/tools/init.go +++ b/mcp_go/internal/tools/init.go @@ -1,12 +1,19 @@ package tools import ( + "github.com/chaitin/SafeLine/mcp_server/internal/tools/analyze" "github.com/chaitin/SafeLine/mcp_server/internal/tools/app" "github.com/chaitin/SafeLine/mcp_server/internal/tools/rule" ) func init() { + // app AppendTool(&app.CreateApp{}) + + // rule AppendTool(&rule.CreateBlacklistRule{}) AppendTool(&rule.CreateWhitelistRule{}) + + // analyze + AppendTool(&analyze.GetAttackEvents{}) } diff --git a/mcp_go/pkg/mcp/schema.go b/mcp_go/pkg/mcp/schema.go index df90a91..22f05d8 100644 --- a/mcp_go/pkg/mcp/schema.go +++ b/mcp_go/pkg/mcp/schema.go @@ -3,6 +3,7 @@ package mcp import ( "encoding/json" "reflect" + "strconv" "strings" "github.com/mark3labs/mcp-go/mcp" @@ -26,6 +27,9 @@ func SchemaToOptions(schema any) ([]mcp.ToolOption, error) { desc := field.Tag.Get("desc") required := field.Tag.Get("required") == "true" enumTag := field.Tag.Get("enum") + defaultTag := field.Tag.Get("default") + minTag := field.Tag.Get("min") + maxTag := field.Tag.Get("max") opts := []mcp.PropertyOption{} if desc != "" { @@ -41,10 +45,33 @@ func SchemaToOptions(schema any) ([]mcp.ToolOption, error) { switch field.Type.Kind() { case reflect.Int: + if defaultTag != "" { + if defaultValue, err := strconv.Atoi(defaultTag); err == nil { + opts = append(opts, mcp.DefaultNumber(float64(defaultValue))) + } + } + if minTag != "" { + if minValue, err := strconv.Atoi(minTag); err == nil { + opts = append(opts, mcp.Min(float64(minValue))) + } + } + if maxTag != "" { + if maxValue, err := strconv.Atoi(maxTag); err == nil { + opts = append(opts, mcp.Max(float64(maxValue))) + } + } options = append(options, mcp.WithNumber(jsonTag, opts...)) case reflect.Bool: + if defaultTag != "" { + if defaultValue, err := strconv.ParseBool(defaultTag); err == nil { + opts = append(opts, mcp.DefaultBool(defaultValue)) + } + } options = append(options, mcp.WithBoolean(jsonTag, opts...)) case reflect.String: + if defaultTag != "" { + opts = append(opts, mcp.DefaultString(defaultTag)) + } options = append(options, mcp.WithString(jsonTag, opts...)) case reflect.Struct: subSchema := reflect.New(field.Type).Interface() diff --git a/mcp_go/pkg/mcp/schema_test.go b/mcp_go/pkg/mcp/schema_test.go index 1fd23ea..b165819 100644 --- a/mcp_go/pkg/mcp/schema_test.go +++ b/mcp_go/pkg/mcp/schema_test.go @@ -22,6 +22,24 @@ func TestSchemaToOptions(t *testing.T) { mcp.WithNumber("a", mcp.Required(), mcp.Description("number a")), ), }, + { + name: "test number default", + args: struct { + A int `json:"a" desc:"number a" required:"true" default:"10"` + }{}, + want: mcp.NewTool("test number default", + mcp.WithNumber("a", mcp.Required(), mcp.Description("number a"), mcp.DefaultNumber(10)), + ), + }, + { + name: "test number min max", + args: struct { + A int `json:"a" desc:"number a" required:"true" min:"10" max:"20"` + }{}, + want: mcp.NewTool("test number min max", + mcp.WithNumber("a", mcp.Required(), mcp.Description("number a"), mcp.Min(10), mcp.Max(20)), + ), + }, { name: "test number optional", args: struct { @@ -49,6 +67,15 @@ func TestSchemaToOptions(t *testing.T) { mcp.WithString("a", mcp.Required(), mcp.Description("string a")), ), }, + { + name: "test string default", + args: struct { + A string `json:"a" desc:"string a" required:"true" default:"hello"` + }{}, + want: mcp.NewTool("test string default", + mcp.WithString("a", mcp.Required(), mcp.Description("string a"), mcp.DefaultString("hello")), + ), + }, { name: "test string enum", args: struct {