Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions internal/tools/barbican/barbican.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ func getSecretHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc {
if secretID == "" {
return shared.ToolError("secret_id is required"), nil
}
if errResult := shared.ValidateUUID(secretID, "secret_id"); errResult != nil {
return errResult, nil
}

secret, err := secrets.Get(ctx, client, secretID).Extract()
if err != nil {
Expand Down
10 changes: 3 additions & 7 deletions internal/tools/castellum/castellum.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"errors"
"net/http"
"net/url"
"regexp"

"github.com/gophercloud/gophercloud/v2"
"github.com/mark3labs/mcp-go/mcp"
Expand Down Expand Up @@ -49,9 +48,6 @@ var listRecentlyFailedOperationsTool = mcp.NewTool("castellum_list_recently_fail
mcp.WithString("max_age", mcp.Description("Time window for results (e.g., '12h', '7d'). Default: '1d'")),
)

// uuidPattern validates that a string is a proper UUID to prevent path traversal.
var uuidPattern = regexp.MustCompile(`(?i)^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)

func getProjectResourcesHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc {
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
client, err := provider.CastellumClient()
Expand All @@ -63,8 +59,8 @@ func getProjectResourcesHandler(provider *auth.Provider) mcpserver.ToolHandlerFu
if projectID == "" {
return shared.ToolError("project_id is required"), nil
}
if !uuidPattern.MatchString(projectID) {
return shared.ToolError("project_id must be a valid UUID"), nil
if errResult := shared.ValidateUUID(projectID, "project_id"); errResult != nil {
return errResult, nil
}

reqURL := client.Endpoint + "v1/projects/" + projectID
Expand Down Expand Up @@ -161,7 +157,7 @@ func buildOperationsQuery(request mcp.CallToolRequest) (string, error) {
params := url.Values{}

if projectID := shared.StringParam(request, "project_id"); projectID != "" {
if !uuidPattern.MatchString(projectID) {
if errResult := shared.ValidateUUID(projectID, "project_id"); errResult != nil {
return "", errors.New("project_id must be a valid UUID")
}
params.Set("project", projectID)
Expand Down
3 changes: 3 additions & 0 deletions internal/tools/cinder/cinder.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ func getVolumeHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc {
if volumeID == "" {
return shared.ToolError("volume_id is required"), nil
}
if errResult := shared.ValidateUUID(volumeID, "volume_id"); errResult != nil {
return errResult, nil
}

vol, err := volumes.Get(ctx, client, volumeID).Extract()
if err != nil {
Expand Down
6 changes: 6 additions & 0 deletions internal/tools/designate/designate.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ func getZoneHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc {
if zoneID == "" {
return shared.ToolError("zone_id is required"), nil
}
if errResult := shared.ValidateUUID(zoneID, "zone_id"); errResult != nil {
return errResult, nil
}

zone, err := zones.Get(ctx, client, zoneID).Extract()
if err != nil {
Expand All @@ -130,6 +133,9 @@ func listRecordsetsHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc {
if zoneID == "" {
return shared.ToolError("zone_id is required"), nil
}
if errResult := shared.ValidateUUID(zoneID, "zone_id"); errResult != nil {
return errResult, nil
}

opts := recordsets.ListOpts{
Name: shared.StringParam(request, "name"),
Expand Down
3 changes: 3 additions & 0 deletions internal/tools/glance/glance.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ func getImageHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc {
if imageID == "" {
return shared.ToolError("image_id is required"), nil
}
if errResult := shared.ValidateUUID(imageID, "image_id"); errResult != nil {
return errResult, nil
}

img, err := images.Get(ctx, client, imageID).Extract()
if err != nil {
Expand Down
54 changes: 31 additions & 23 deletions internal/tools/hermes/hermes.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
"context"
"encoding/json"
"net/http"
"net/url"
"strconv"
"strings"

"github.com/gophercloud/gophercloud/v2"
"github.com/mark3labs/mcp-go/mcp"
Expand Down Expand Up @@ -63,45 +63,38 @@ func listEventsHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc {
return shared.ToolError("failed to get hermes client: %v", err), nil
}

query := make(map[string]string)
params := url.Values{}
for _, key := range []string{"target_type", "target_id", "initiator_name", "initiator_id", "action", "outcome", "observer_type", "sort"} {
if v := shared.StringParam(request, key); v != "" {
query[key] = v
params.Set(key, v)
}
}
if v := shared.StringParam(request, "time_gte"); v != "" {
query["time"] = "gte:" + v
timeFilter := "gte:" + v
if lte := shared.StringParam(request, "time_lte"); lte != "" {
query["time"] += ",lte:" + lte
timeFilter += ",lte:" + lte
}
params.Set("time", timeFilter)
}

limit := int(shared.NumberParam(request, "limit"))
if limit <= 0 {
limit = 50
}
query["limit"] = strconv.Itoa(limit)
params.Set("limit", strconv.Itoa(limit))

if offset := int(shared.NumberParam(request, "offset")); offset > 0 {
query["offset"] = strconv.Itoa(offset)
params.Set("offset", strconv.Itoa(offset))
}

var buf strings.Builder
buf.WriteString(client.ResourceBase)
buf.WriteString("events")
sep := "?"
for k, v := range query {
buf.WriteString(sep)
buf.WriteString(k)
buf.WriteString("=")
buf.WriteString(v)
sep = "&"
reqURL := client.ResourceBase + "events"
if encoded := params.Encode(); encoded != "" {
reqURL += "?" + encoded
}
url := buf.String()

var body any
//nolint:bodyclose
_, err = client.Get(ctx, url, &body, &gophercloud.RequestOpts{
_, err = client.Get(ctx, reqURL, &body, &gophercloud.RequestOpts{
OkCodes: []int{http.StatusOK},
})
if err != nil {
Expand All @@ -127,12 +120,15 @@ func getEventHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc {
if eventID == "" {
return shared.ToolError("event_id is required"), nil
}
if errResult := shared.ValidateUUID(eventID, "event_id"); errResult != nil {
return errResult, nil
}

url := client.ResourceBase + "v1/events/" + eventID
reqURL := client.ResourceBase + "v1/events/" + eventID

var body any
//nolint:bodyclose
_, err = client.Get(ctx, url, &body, &gophercloud.RequestOpts{
_, err = client.Get(ctx, reqURL, &body, &gophercloud.RequestOpts{
OkCodes: []int{http.StatusOK},
})
if err != nil {
Expand All @@ -159,11 +155,23 @@ func listAttributesHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc {
return shared.ToolError("attribute is required"), nil
}

url := client.ResourceBase + "attributes/" + attr
// Allowlist of valid attribute names to prevent path traversal.
validAttributes := map[string]bool{
"target_type": true,
"action": true,
"outcome": true,
"observer_type": true,
"initiator_type": true,
}
if !validAttributes[attr] {
return shared.ToolError("attribute must be one of: target_type, action, outcome, observer_type, initiator_type (got: %q)", attr), nil
}

reqURL := client.ResourceBase + "attributes/" + attr

var body any
//nolint:bodyclose
_, err = client.Get(ctx, url, &body, &gophercloud.RequestOpts{
_, err = client.Get(ctx, reqURL, &body, &gophercloud.RequestOpts{
OkCodes: []int{http.StatusOK},
})
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions internal/tools/ironic/ironic.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ func getNodeHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc {
if nodeID == "" {
return shared.ToolError("node_id is required"), nil
}
if errResult := shared.ValidatePathSegment(nodeID, "node_id"); errResult != nil {
return errResult, nil
}

node, err := nodes.Get(ctx, client, nodeID).Extract()
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions internal/tools/keystone/keystone.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,9 @@ func deleteAppCredentialHandler(provider *auth.Provider) mcpserver.ToolHandlerFu
if id == "" {
return shared.ToolError("id is required"), nil
}
if errResult := shared.ValidateUUID(id, "id"); errResult != nil {
return errResult, nil
}

err = applicationcredentials.Delete(ctx, client, userID, id).ExtractErr()
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions internal/tools/manila/manila.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ func getShareHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc {
if shareID == "" {
return shared.ToolError("share_id is required"), nil
}
if errResult := shared.ValidateUUID(shareID, "share_id"); errResult != nil {
return errResult, nil
}

share, err := shares.Get(ctx, client, shareID).Extract()
if err != nil {
Expand Down
6 changes: 6 additions & 0 deletions internal/tools/nova/nova.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ func getServerHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc {
if serverID == "" {
return shared.ToolError("server_id is required"), nil
}
if errResult := shared.ValidateUUID(serverID, "server_id"); errResult != nil {
return errResult, nil
}

srv, err := servers.Get(ctx, client, serverID).Extract()
if err != nil {
Expand Down Expand Up @@ -221,6 +224,9 @@ func serverActionHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc {
if serverID == "" || action == "" {
return shared.ToolError("server_id and action are required"), nil
}
if errResult := shared.ValidateUUID(serverID, "server_id"); errResult != nil {
return errResult, nil
}

switch action {
case "start":
Expand Down
3 changes: 3 additions & 0 deletions internal/tools/octavia/octavia.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ func getLoadbalancerHandler(provider *auth.Provider) mcpserver.ToolHandlerFunc {
if lbID == "" {
return shared.ToolError("loadbalancer_id is required"), nil
}
if errResult := shared.ValidateUUID(lbID, "loadbalancer_id"); errResult != nil {
return errResult, nil
}

lb, err := loadbalancers.Get(ctx, client, lbID).Extract()
if err != nil {
Expand Down
10 changes: 10 additions & 0 deletions internal/tools/shared/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,16 @@ func TestSafeQueryParams_EncodesValues(t *testing.T) {
params: map[string]string{"name": "my server"},
want: "?name=my+server",
},
{
name: "multiple params joined",
params: map[string]string{"action": "create", "outcome": "success"},
want: "?action=create&outcome=success",
},
{
name: "path traversal in value encoded",
params: map[string]string{"name": "../../etc/passwd"},
want: "?name=..%2F..%2Fetc%2Fpasswd",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
Loading