Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

# Claude Code
.claude/
/.mcp.json

# Test output
/*.out
Expand Down
2 changes: 1 addition & 1 deletion cmd/stackrox-mcp/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
func getToolsets(cfg *config.Config, c *client.Client) []toolsets.Toolset {
return []toolsets.Toolset{
toolsetConfig.NewToolset(cfg, c),
toolsetVulnerability.NewToolset(cfg),
toolsetVulnerability.NewToolset(cfg, c),
}
}

Expand Down
13 changes: 13 additions & 0 deletions internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/tls"
"fmt"
"sync"
"testing"
"time"

"github.com/pkg/errors"
Expand Down Expand Up @@ -143,6 +144,18 @@ func (c *Client) ReadyConn(ctx context.Context) (*grpc.ClientConn, error) {
return c.conn, nil
}

// SetConnForTesting sets a gRPC connection for testing purposes.
// This should only be used in tests.
func (c *Client) SetConnForTesting(t *testing.T, conn *grpc.ClientConn) {
t.Helper()

c.mu.Lock()
defer c.mu.Unlock()

c.conn = conn
c.connected = true
}

func (c *Client) shouldRedialNoLock() bool {
if !c.connected || c.conn == nil {
return true
Expand Down
127 changes: 96 additions & 31 deletions internal/toolsets/config/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,43 @@ package config

import (
"context"
"fmt"

"github.com/google/jsonschema-go/jsonschema"
"github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/pkg/errors"
v1 "github.com/stackrox/rox/generated/api/v1"
"github.com/stackrox/stackrox-mcp/internal/client"
"github.com/stackrox/stackrox-mcp/internal/client/auth"
"github.com/stackrox/stackrox-mcp/internal/logging"
"github.com/stackrox/stackrox-mcp/internal/toolsets"
)

const (
defaultOffset = 0

// 0 = no limit.
defaultLimit = 0
)

// listClustersInput defines the input parameters for list_clusters tool.
type listClustersInput struct{}
type listClustersInput struct {
Offset int `json:"offset,omitempty"`
Limit int `json:"limit,omitempty"`
}

// ClusterInfo represents information about a single cluster.
type ClusterInfo struct {
ID string `json:"id"`
Name string `json:"name"`
Type string `json:"type"`
}

// listClustersOutput defines the output structure for list_clusters tool.
type listClustersOutput struct {
Clusters []string `json:"clusters"`
Clusters []ClusterInfo `json:"clusters"`
TotalCount int `json:"totalCount"`
Offset int `json:"offset"`
Limit int `json:"limit"`
}

// listClustersTool implements the list_clusters tool.
Expand Down Expand Up @@ -48,63 +69,107 @@ func (t *listClustersTool) GetName() string {
func (t *listClustersTool) GetTool() *mcp.Tool {
return &mcp.Tool{
Name: t.name,
Description: "List all clusters managed by StackRox Central with their IDs, names, and types",
Description: "List all clusters managed by StackRox with their IDs, names, and types",
InputSchema: listClustersInputSchema(),
}
}

func listClustersInputSchema() *jsonschema.Schema {
schema, err := jsonschema.For[listClustersInput](nil)
if err != nil {
logging.Fatal("Could not get jsonschema for list_clusters input", err)

return nil
}

schema.Properties["offset"].Minimum = jsonschema.Ptr(0.0)
schema.Properties["offset"].Default = toolsets.MustJSONMarshal(defaultOffset)
schema.Properties["offset"].Description = "Starting index for pagination (0-based)"

schema.Properties["limit"].Minimum = jsonschema.Ptr(0.0)
schema.Properties["limit"].Default = toolsets.MustJSONMarshal(defaultLimit)
schema.Properties["limit"].Description = "Maximum number of clusters to return (default: 0 - unlimited)"

return schema
}

// RegisterWith registers the list_clusters tool handler with the MCP server.
func (t *listClustersTool) RegisterWith(server *mcp.Server) {
mcp.AddTool(server, t.GetTool(), t.handle)
}

// handle is the placeholder handler for list_clusters tool.
func (t *listClustersTool) handle(
ctx context.Context,
req *mcp.CallToolRequest,
_ listClustersInput,
) (*mcp.CallToolResult, *listClustersOutput, error) {
func (t *listClustersTool) getClusters(ctx context.Context, req *mcp.CallToolRequest) ([]ClusterInfo, error) {
conn, err := t.client.ReadyConn(ctx)
if err != nil {
return nil, nil, errors.Wrap(err, "unable to connect to server")
return nil, errors.Wrap(err, "unable to connect to server")
}

callCtx := auth.WithMCPRequestContext(ctx, req)

// Create ClustersService client
clustersClient := v1.NewClustersServiceClient(conn)

// Call GetClusters
// Call GetClusters to fetch all clusters
resp, err := clustersClient.GetClusters(callCtx, &v1.GetClustersRequest{})
if err != nil {
// Convert gRPC error to client error
clientErr := client.NewError(err, "GetClusters")

return nil, nil, clientErr
return nil, clientErr
}

// Extract cluster information
clusters := make([]string, 0, len(resp.GetClusters()))
// Convert all clusters to ClusterInfo objects
allClusters := make([]ClusterInfo, 0, len(resp.GetClusters()))
for _, cluster := range resp.GetClusters() {
// Format: "ID: <id>, Name: <name>, Type: <type>"
clusterInfo := fmt.Sprintf("ID: %s, Name: %s, Type: %s",
cluster.GetId(),
cluster.GetName(),
cluster.GetType().String())
clusters = append(clusters, clusterInfo)
clusterInfo := ClusterInfo{
ID: cluster.GetId(),
Name: cluster.GetName(),
Type: cluster.GetType().String(),
}
allClusters = append(allClusters, clusterInfo)
}

output := &listClustersOutput{
Clusters: clusters,
return allClusters, nil
}

// handle is the handler for list_clusters tool.
func (t *listClustersTool) handle(
ctx context.Context,
req *mcp.CallToolRequest,
input listClustersInput,
) (*mcp.CallToolResult, *listClustersOutput, error) {
clusters, err := t.getClusters(ctx, req)
if err != nil {
return nil, nil, err
}

totalCount := len(clusters)

// 0 = unlimited.
limit := input.Limit
if limit == 0 {
limit = totalCount
}

// Apply client-side pagination.
var paginatedClusters []ClusterInfo
if input.Offset >= totalCount {
paginatedClusters = []ClusterInfo{}
} else {
end := min(input.Offset+limit, totalCount)
if end < 0 {
end = totalCount
}

paginatedClusters = clusters[input.Offset:end]
}

// Return result with text content
result := &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{
Text: fmt.Sprintf("Found %d cluster(s)", len(clusters)),
},
},
output := &listClustersOutput{
Clusters: paginatedClusters,
TotalCount: totalCount,
Offset: input.Offset,
Limit: input.Limit,
}

return result, output, nil
return nil, output, nil
}
Loading
Loading