diff --git a/internal/tools/barbican/barbican.go b/internal/tools/barbican/barbican.go index a09ac96..1cba219 100644 --- a/internal/tools/barbican/barbican.go +++ b/internal/tools/barbican/barbican.go @@ -91,6 +91,9 @@ func getSecretHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc { if secretID == "" { return shared.ToolError("secret_id is required"), nil } + if errResult := shared.ValidateUUID(secretID, "secret_id"); errResult != nil { + return errResult, nil + } secret, err := secrets.Get(ctx, client, secretID).Extract() if err != nil { diff --git a/internal/tools/castellum/castellum.go b/internal/tools/castellum/castellum.go index 0c1f15f..02275ad 100644 --- a/internal/tools/castellum/castellum.go +++ b/internal/tools/castellum/castellum.go @@ -11,7 +11,6 @@ import ( "errors" "net/http" "net/url" - "regexp" "github.com/gophercloud/gophercloud/v2" "github.com/mark3labs/mcp-go/mcp" @@ -49,9 +48,6 @@ var listRecentlyFailedOperationsTool = mcp.NewTool("castellum_list_recently_fail mcp.WithString("max_age", mcp.Description("Time window for results (e.g., '12h', '7d'). Default: '1d'")), ) -// uuidPattern validates that a string is a proper UUID to prevent path traversal. -var uuidPattern = regexp.MustCompile(`(?i)^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`) - func getProjectResourcesHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc { return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { client, err := provider.CastellumClient() @@ -63,8 +59,8 @@ func getProjectResourcesHandler(provider *auth.Provider) mcpserver.ToolHandlerFu if projectID == "" { return shared.ToolError("project_id is required"), nil } - if !uuidPattern.MatchString(projectID) { - return shared.ToolError("project_id must be a valid UUID"), nil + if errResult := shared.ValidateUUID(projectID, "project_id"); errResult != nil { + return errResult, nil } reqURL := client.Endpoint + "v1/projects/" + projectID @@ -161,7 +157,7 @@ func buildOperationsQuery(request mcp.CallToolRequest) (string, error) { params := url.Values{} if projectID := shared.StringParam(request, "project_id"); projectID != "" { - if !uuidPattern.MatchString(projectID) { + if errResult := shared.ValidateUUID(projectID, "project_id"); errResult != nil { return "", errors.New("project_id must be a valid UUID") } params.Set("project", projectID) diff --git a/internal/tools/cinder/cinder.go b/internal/tools/cinder/cinder.go index 5375785..6ff264a 100644 --- a/internal/tools/cinder/cinder.go +++ b/internal/tools/cinder/cinder.go @@ -94,6 +94,9 @@ func getVolumeHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc { if volumeID == "" { return shared.ToolError("volume_id is required"), nil } + if errResult := shared.ValidateUUID(volumeID, "volume_id"); errResult != nil { + return errResult, nil + } vol, err := volumes.Get(ctx, client, volumeID).Extract() if err != nil { diff --git a/internal/tools/designate/designate.go b/internal/tools/designate/designate.go index 95710bb..75c57a9 100644 --- a/internal/tools/designate/designate.go +++ b/internal/tools/designate/designate.go @@ -105,6 +105,9 @@ func getZoneHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc { if zoneID == "" { return shared.ToolError("zone_id is required"), nil } + if errResult := shared.ValidateUUID(zoneID, "zone_id"); errResult != nil { + return errResult, nil + } zone, err := zones.Get(ctx, client, zoneID).Extract() if err != nil { @@ -130,6 +133,9 @@ func listRecordsetsHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc { if zoneID == "" { return shared.ToolError("zone_id is required"), nil } + if errResult := shared.ValidateUUID(zoneID, "zone_id"); errResult != nil { + return errResult, nil + } opts := recordsets.ListOpts{ Name: shared.StringParam(request, "name"), diff --git a/internal/tools/glance/glance.go b/internal/tools/glance/glance.go index 56236d0..3632519 100644 --- a/internal/tools/glance/glance.go +++ b/internal/tools/glance/glance.go @@ -104,6 +104,9 @@ func getImageHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc { if imageID == "" { return shared.ToolError("image_id is required"), nil } + if errResult := shared.ValidateUUID(imageID, "image_id"); errResult != nil { + return errResult, nil + } img, err := images.Get(ctx, client, imageID).Extract() if err != nil { diff --git a/internal/tools/hermes/hermes.go b/internal/tools/hermes/hermes.go index 3ac0a97..9db0bf7 100644 --- a/internal/tools/hermes/hermes.go +++ b/internal/tools/hermes/hermes.go @@ -9,8 +9,8 @@ import ( "context" "encoding/json" "net/http" + "net/url" "strconv" - "strings" "github.com/gophercloud/gophercloud/v2" "github.com/mark3labs/mcp-go/mcp" @@ -63,45 +63,38 @@ func listEventsHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc { return shared.ToolError("failed to get hermes client: %v", err), nil } - query := make(map[string]string) + params := url.Values{} for _, key := range []string{"target_type", "target_id", "initiator_name", "initiator_id", "action", "outcome", "observer_type", "sort"} { if v := shared.StringParam(request, key); v != "" { - query[key] = v + params.Set(key, v) } } if v := shared.StringParam(request, "time_gte"); v != "" { - query["time"] = "gte:" + v + timeFilter := "gte:" + v if lte := shared.StringParam(request, "time_lte"); lte != "" { - query["time"] += ",lte:" + lte + timeFilter += ",lte:" + lte } + params.Set("time", timeFilter) } limit := int(shared.NumberParam(request, "limit")) if limit <= 0 { limit = 50 } - query["limit"] = strconv.Itoa(limit) + params.Set("limit", strconv.Itoa(limit)) if offset := int(shared.NumberParam(request, "offset")); offset > 0 { - query["offset"] = strconv.Itoa(offset) + params.Set("offset", strconv.Itoa(offset)) } - var buf strings.Builder - buf.WriteString(client.ResourceBase) - buf.WriteString("events") - sep := "?" - for k, v := range query { - buf.WriteString(sep) - buf.WriteString(k) - buf.WriteString("=") - buf.WriteString(v) - sep = "&" + reqURL := client.ResourceBase + "events" + if encoded := params.Encode(); encoded != "" { + reqURL += "?" + encoded } - url := buf.String() var body any //nolint:bodyclose - _, err = client.Get(ctx, url, &body, &gophercloud.RequestOpts{ + _, err = client.Get(ctx, reqURL, &body, &gophercloud.RequestOpts{ OkCodes: []int{http.StatusOK}, }) if err != nil { @@ -127,12 +120,15 @@ func getEventHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc { if eventID == "" { return shared.ToolError("event_id is required"), nil } + if errResult := shared.ValidateUUID(eventID, "event_id"); errResult != nil { + return errResult, nil + } - url := client.ResourceBase + "v1/events/" + eventID + reqURL := client.ResourceBase + "v1/events/" + eventID var body any //nolint:bodyclose - _, err = client.Get(ctx, url, &body, &gophercloud.RequestOpts{ + _, err = client.Get(ctx, reqURL, &body, &gophercloud.RequestOpts{ OkCodes: []int{http.StatusOK}, }) if err != nil { @@ -159,11 +155,23 @@ func listAttributesHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc { return shared.ToolError("attribute is required"), nil } - url := client.ResourceBase + "attributes/" + attr + // Allowlist of valid attribute names to prevent path traversal. + validAttributes := map[string]bool{ + "target_type": true, + "action": true, + "outcome": true, + "observer_type": true, + "initiator_type": true, + } + if !validAttributes[attr] { + return shared.ToolError("attribute must be one of: target_type, action, outcome, observer_type, initiator_type (got: %q)", attr), nil + } + + reqURL := client.ResourceBase + "attributes/" + attr var body any //nolint:bodyclose - _, err = client.Get(ctx, url, &body, &gophercloud.RequestOpts{ + _, err = client.Get(ctx, reqURL, &body, &gophercloud.RequestOpts{ OkCodes: []int{http.StatusOK}, }) if err != nil { diff --git a/internal/tools/ironic/ironic.go b/internal/tools/ironic/ironic.go index 9b53552..8e7e828 100644 --- a/internal/tools/ironic/ironic.go +++ b/internal/tools/ironic/ironic.go @@ -107,6 +107,9 @@ func getNodeHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc { if nodeID == "" { return shared.ToolError("node_id is required"), nil } + if errResult := shared.ValidatePathSegment(nodeID, "node_id"); errResult != nil { + return errResult, nil + } node, err := nodes.Get(ctx, client, nodeID).Extract() if err != nil { diff --git a/internal/tools/keystone/keystone.go b/internal/tools/keystone/keystone.go index 1705c53..9bc824c 100644 --- a/internal/tools/keystone/keystone.go +++ b/internal/tools/keystone/keystone.go @@ -307,6 +307,9 @@ func deleteAppCredentialHandler(provider *auth.Provider) mcpserver.ToolHandlerFu if id == "" { return shared.ToolError("id is required"), nil } + if errResult := shared.ValidateUUID(id, "id"); errResult != nil { + return errResult, nil + } err = applicationcredentials.Delete(ctx, client, userID, id).ExtractErr() if err != nil { diff --git a/internal/tools/manila/manila.go b/internal/tools/manila/manila.go index 98c5c90..7b3fd69 100644 --- a/internal/tools/manila/manila.go +++ b/internal/tools/manila/manila.go @@ -103,6 +103,9 @@ func getShareHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc { if shareID == "" { return shared.ToolError("share_id is required"), nil } + if errResult := shared.ValidateUUID(shareID, "share_id"); errResult != nil { + return errResult, nil + } share, err := shares.Get(ctx, client, shareID).Extract() if err != nil { diff --git a/internal/tools/nova/nova.go b/internal/tools/nova/nova.go index 634f2cb..4dd7a44 100644 --- a/internal/tools/nova/nova.go +++ b/internal/tools/nova/nova.go @@ -134,6 +134,9 @@ func getServerHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc { if serverID == "" { return shared.ToolError("server_id is required"), nil } + if errResult := shared.ValidateUUID(serverID, "server_id"); errResult != nil { + return errResult, nil + } srv, err := servers.Get(ctx, client, serverID).Extract() if err != nil { @@ -221,6 +224,9 @@ func serverActionHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc { if serverID == "" || action == "" { return shared.ToolError("server_id and action are required"), nil } + if errResult := shared.ValidateUUID(serverID, "server_id"); errResult != nil { + return errResult, nil + } switch action { case "start": diff --git a/internal/tools/octavia/octavia.go b/internal/tools/octavia/octavia.go index 2d9a947..f2454f5 100644 --- a/internal/tools/octavia/octavia.go +++ b/internal/tools/octavia/octavia.go @@ -116,6 +116,9 @@ func getLoadbalancerHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc { if lbID == "" { return shared.ToolError("loadbalancer_id is required"), nil } + if errResult := shared.ValidateUUID(lbID, "loadbalancer_id"); errResult != nil { + return errResult, nil + } lb, err := loadbalancers.Get(ctx, client, lbID).Extract() if err != nil { diff --git a/internal/tools/shared/helpers_test.go b/internal/tools/shared/helpers_test.go index b49c0b0..9d9f8c4 100644 --- a/internal/tools/shared/helpers_test.go +++ b/internal/tools/shared/helpers_test.go @@ -119,6 +119,16 @@ func TestSafeQueryParams_EncodesValues(t *testing.T) { params: map[string]string{"name": "my server"}, want: "?name=my+server", }, + { + name: "multiple params joined", + params: map[string]string{"action": "create", "outcome": "success"}, + want: "?action=create&outcome=success", + }, + { + name: "path traversal in value encoded", + params: map[string]string{"name": "../../etc/passwd"}, + want: "?name=..%2F..%2Fetc%2Fpasswd", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {