diff --git a/.gitignore b/.gitignore index f054949..186b7c9 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ # Claude Code .claude/ +/.mcp.json # Test output /*.out diff --git a/cmd/stackrox-mcp/main.go b/cmd/stackrox-mcp/main.go index 4c67ce9..d1feecd 100644 --- a/cmd/stackrox-mcp/main.go +++ b/cmd/stackrox-mcp/main.go @@ -22,7 +22,7 @@ import ( func getToolsets(cfg *config.Config, c *client.Client) []toolsets.Toolset { return []toolsets.Toolset{ toolsetConfig.NewToolset(cfg, c), - toolsetVulnerability.NewToolset(cfg), + toolsetVulnerability.NewToolset(cfg, c), } } diff --git a/internal/client/client.go b/internal/client/client.go index 00cd160..04c8c61 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "fmt" "sync" + "testing" "time" "github.com/pkg/errors" @@ -143,6 +144,18 @@ func (c *Client) ReadyConn(ctx context.Context) (*grpc.ClientConn, error) { return c.conn, nil } +// SetConnForTesting sets a gRPC connection for testing purposes. +// This should only be used in tests. +func (c *Client) SetConnForTesting(t *testing.T, conn *grpc.ClientConn) { + t.Helper() + + c.mu.Lock() + defer c.mu.Unlock() + + c.conn = conn + c.connected = true +} + func (c *Client) shouldRedialNoLock() bool { if !c.connected || c.conn == nil { return true diff --git a/internal/toolsets/config/tools.go b/internal/toolsets/config/tools.go index 2b0c45f..456b6b7 100644 --- a/internal/toolsets/config/tools.go +++ b/internal/toolsets/config/tools.go @@ -2,22 +2,43 @@ package config import ( "context" - "fmt" + "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/pkg/errors" v1 "github.com/stackrox/rox/generated/api/v1" "github.com/stackrox/stackrox-mcp/internal/client" "github.com/stackrox/stackrox-mcp/internal/client/auth" + "github.com/stackrox/stackrox-mcp/internal/logging" "github.com/stackrox/stackrox-mcp/internal/toolsets" ) +const ( + defaultOffset = 0 + + // 0 = no limit. + defaultLimit = 0 +) + // listClustersInput defines the input parameters for list_clusters tool. -type listClustersInput struct{} +type listClustersInput struct { + Offset int `json:"offset,omitempty"` + Limit int `json:"limit,omitempty"` +} + +// ClusterInfo represents information about a single cluster. +type ClusterInfo struct { + ID string `json:"id"` + Name string `json:"name"` + Type string `json:"type"` +} // listClustersOutput defines the output structure for list_clusters tool. type listClustersOutput struct { - Clusters []string `json:"clusters"` + Clusters []ClusterInfo `json:"clusters"` + TotalCount int `json:"totalCount"` + Offset int `json:"offset"` + Limit int `json:"limit"` } // listClustersTool implements the list_clusters tool. @@ -48,8 +69,28 @@ func (t *listClustersTool) GetName() string { func (t *listClustersTool) GetTool() *mcp.Tool { return &mcp.Tool{ Name: t.name, - Description: "List all clusters managed by StackRox Central with their IDs, names, and types", + Description: "List all clusters managed by StackRox with their IDs, names, and types", + InputSchema: listClustersInputSchema(), + } +} + +func listClustersInputSchema() *jsonschema.Schema { + schema, err := jsonschema.For[listClustersInput](nil) + if err != nil { + logging.Fatal("Could not get jsonschema for list_clusters input", err) + + return nil } + + schema.Properties["offset"].Minimum = jsonschema.Ptr(0.0) + schema.Properties["offset"].Default = toolsets.MustJSONMarshal(defaultOffset) + schema.Properties["offset"].Description = "Starting index for pagination (0-based)" + + schema.Properties["limit"].Minimum = jsonschema.Ptr(0.0) + schema.Properties["limit"].Default = toolsets.MustJSONMarshal(defaultLimit) + schema.Properties["limit"].Description = "Maximum number of clusters to return (default: 0 - unlimited)" + + return schema } // RegisterWith registers the list_clusters tool handler with the MCP server. @@ -57,15 +98,10 @@ func (t *listClustersTool) RegisterWith(server *mcp.Server) { mcp.AddTool(server, t.GetTool(), t.handle) } -// handle is the placeholder handler for list_clusters tool. -func (t *listClustersTool) handle( - ctx context.Context, - req *mcp.CallToolRequest, - _ listClustersInput, -) (*mcp.CallToolResult, *listClustersOutput, error) { +func (t *listClustersTool) getClusters(ctx context.Context, req *mcp.CallToolRequest) ([]ClusterInfo, error) { conn, err := t.client.ReadyConn(ctx) if err != nil { - return nil, nil, errors.Wrap(err, "unable to connect to server") + return nil, errors.Wrap(err, "unable to connect to server") } callCtx := auth.WithMCPRequestContext(ctx, req) @@ -73,38 +109,67 @@ func (t *listClustersTool) handle( // Create ClustersService client clustersClient := v1.NewClustersServiceClient(conn) - // Call GetClusters + // Call GetClusters to fetch all clusters resp, err := clustersClient.GetClusters(callCtx, &v1.GetClustersRequest{}) if err != nil { // Convert gRPC error to client error clientErr := client.NewError(err, "GetClusters") - return nil, nil, clientErr + return nil, clientErr } - // Extract cluster information - clusters := make([]string, 0, len(resp.GetClusters())) + // Convert all clusters to ClusterInfo objects + allClusters := make([]ClusterInfo, 0, len(resp.GetClusters())) for _, cluster := range resp.GetClusters() { - // Format: "ID: , Name: , Type: " - clusterInfo := fmt.Sprintf("ID: %s, Name: %s, Type: %s", - cluster.GetId(), - cluster.GetName(), - cluster.GetType().String()) - clusters = append(clusters, clusterInfo) + clusterInfo := ClusterInfo{ + ID: cluster.GetId(), + Name: cluster.GetName(), + Type: cluster.GetType().String(), + } + allClusters = append(allClusters, clusterInfo) } - output := &listClustersOutput{ - Clusters: clusters, + return allClusters, nil +} + +// handle is the handler for list_clusters tool. +func (t *listClustersTool) handle( + ctx context.Context, + req *mcp.CallToolRequest, + input listClustersInput, +) (*mcp.CallToolResult, *listClustersOutput, error) { + clusters, err := t.getClusters(ctx, req) + if err != nil { + return nil, nil, err + } + + totalCount := len(clusters) + + // 0 = unlimited. + limit := input.Limit + if limit == 0 { + limit = totalCount + } + + // Apply client-side pagination. + var paginatedClusters []ClusterInfo + if input.Offset >= totalCount { + paginatedClusters = []ClusterInfo{} + } else { + end := min(input.Offset+limit, totalCount) + if end < 0 { + end = totalCount + } + + paginatedClusters = clusters[input.Offset:end] } - // Return result with text content - result := &mcp.CallToolResult{ - Content: []mcp.Content{ - &mcp.TextContent{ - Text: fmt.Sprintf("Found %d cluster(s)", len(clusters)), - }, - }, + output := &listClustersOutput{ + Clusters: paginatedClusters, + TotalCount: totalCount, + Offset: input.Offset, + Limit: input.Limit, } - return result, output, nil + return nil, output, nil } diff --git a/internal/toolsets/config/tools_test.go b/internal/toolsets/config/tools_test.go index f6475ca..72c4854 100644 --- a/internal/toolsets/config/tools_test.go +++ b/internal/toolsets/config/tools_test.go @@ -1,12 +1,23 @@ package config import ( + "context" + "fmt" + "net" "testing" "github.com/modelcontextprotocol/go-sdk/mcp" + v1 "github.com/stackrox/rox/generated/api/v1" + "github.com/stackrox/rox/generated/storage" "github.com/stackrox/stackrox-mcp/internal/client" + "github.com/stackrox/stackrox-mcp/internal/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" + "google.golang.org/grpc/test/bufconn" ) func TestNewListClustersTool(t *testing.T) { @@ -30,6 +41,7 @@ func TestListClustersTool_GetTool(t *testing.T) { require.NotNil(t, mcpTool) assert.Equal(t, "list_clusters", mcpTool.Name) assert.NotEmpty(t, mcpTool.Description) + require.NotNil(t, mcpTool.InputSchema, "InputSchema should be defined") } func TestListClustersTool_RegisterWith(t *testing.T) { @@ -47,3 +59,224 @@ func TestListClustersTool_RegisterWith(t *testing.T) { tool.RegisterWith(server) }) } + +// Mock infrastructure for gRPC testing. + +// mockClustersService implements v1.ClustersServiceServer for testing. +type mockClustersService struct { + v1.UnimplementedClustersServiceServer + + clusters []*storage.Cluster + err error +} + +func (m *mockClustersService) GetClusters( + _ context.Context, + _ *v1.GetClustersRequest, +) (*v1.ClustersList, error) { + if m.err != nil { + return nil, m.err + } + + return &v1.ClustersList{ + Clusters: m.clusters, + }, nil +} + +// setupMockServer creates an in-memory gRPC server using bufconn. +func setupMockServer(mockService *mockClustersService) (*grpc.Server, *bufconn.Listener) { + buffer := 1024 * 1024 + listener := bufconn.Listen(buffer) + + grpcServer := grpc.NewServer() + v1.RegisterClustersServiceServer(grpcServer, mockService) + + go func() { + _ = grpcServer.Serve(listener) + }() + + return grpcServer, listener +} + +// bufDialer creates a dialer function for bufconn. +func bufDialer(listener *bufconn.Listener) func(context.Context, string) (net.Conn, error) { + return func(_ context.Context, _ string) (net.Conn, error) { + return listener.Dial() + } +} + +// createTestClient creates a client connected to the mock server. +func createTestClient(t *testing.T, listener *bufconn.Listener) *client.Client { + t.Helper() + + conn, err := grpc.NewClient( + "passthrough://buffer", + grpc.WithLocalDNSResolution(), + grpc.WithContextDialer(bufDialer(listener)), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + + stackroxClient, err := client.NewClient(&config.CentralConfig{ + URL: "buffer", + }) + require.NoError(t, err) + + // Inject mock connection for testing. + stackroxClient.SetConnForTesting(t, conn) + + return stackroxClient +} + +func TestHandle_DefaultLimit(t *testing.T) { + mockService := &mockClustersService{ + clusters: []*storage.Cluster{ + {Id: "c1", Name: "Cluster 1", Type: storage.ClusterType_KUBERNETES_CLUSTER}, + {Id: "c2", Name: "Cluster 2", Type: storage.ClusterType_KUBERNETES_CLUSTER}, + {Id: "c3", Name: "Cluster 3", Type: storage.ClusterType_KUBERNETES_CLUSTER}, + {Id: "c4", Name: "Cluster 4", Type: storage.ClusterType_KUBERNETES_CLUSTER}, + {Id: "c5", Name: "Cluster 5", Type: storage.ClusterType_KUBERNETES_CLUSTER}, + }, + } + + grpcServer, listener := setupMockServer(mockService) + defer grpcServer.Stop() + + testClient := createTestClient(t, listener) + tool, ok := NewListClustersTool(testClient).(*listClustersTool) + require.True(t, ok) + + ctx := context.Background() + req := &mcp.CallToolRequest{} + input := listClustersInput{ + Offset: defaultOffset, + Limit: defaultOffset, + } + + result, output, err := tool.handle(ctx, req, input) + + require.NoError(t, err) + require.NotNil(t, output) + assert.Nil(t, result) // MCP SDK handles result creation + + require.Len(t, output.Clusters, 5) + assert.Equal(t, 5, output.TotalCount) + assert.Equal(t, 0, output.Offset) + assert.Equal(t, 0, output.Limit) + assert.Equal(t, "Cluster 1", output.Clusters[0].Name) + assert.Equal(t, "Cluster 5", output.Clusters[4].Name) +} + +//nolint:funlen +func TestHandle_WithPagination(t *testing.T) { + totalClusters := 10 + + clusters := make([]*storage.Cluster, totalClusters) + for i := range totalClusters { + clusters[i] = &storage.Cluster{ + Id: fmt.Sprintf("cluster-%d", i), + Name: fmt.Sprintf("Cluster %d", i), + Type: storage.ClusterType_KUBERNETES_CLUSTER, + } + } + + mockService := &mockClustersService{ + clusters: clusters, + } + + grpcServer, listener := setupMockServer(mockService) + defer grpcServer.Stop() + + testClient := createTestClient(t, listener) + tool, ok := NewListClustersTool(testClient).(*listClustersTool) + require.True(t, ok) + + tests := map[string]struct { + offset int + limit int + expectedCount int + expectedFirst string + expectedLast string + }{ + "first page": { + offset: 0, + limit: 3, + expectedCount: 3, + expectedFirst: "cluster-0", + expectedLast: "cluster-2", + }, + "middle page": { + offset: 2, + limit: 3, + expectedCount: 3, + expectedFirst: "cluster-2", + expectedLast: "cluster-4", + }, + "partial page": { + offset: 8, + limit: 10, + expectedCount: 2, + expectedFirst: "cluster-8", + expectedLast: "cluster-9", + }, + "offset beyond total": { + offset: 100, + limit: 10, + expectedCount: 0, + }, + } + + for testName, testCase := range tests { + t.Run(testName, func(t *testing.T) { + ctx := context.Background() + req := &mcp.CallToolRequest{} + input := listClustersInput{ + Offset: testCase.offset, + Limit: testCase.limit, + } + + result, output, err := tool.handle(ctx, req, input) + + require.NoError(t, err) + require.NotNil(t, output) + assert.Nil(t, result) // MCP SDK handles result creation. + + assert.Len(t, output.Clusters, testCase.expectedCount) + assert.Equal(t, totalClusters, output.TotalCount) + assert.Equal(t, testCase.offset, output.Offset) + assert.Equal(t, testCase.limit, output.Limit) + + if testCase.expectedCount > 0 { + assert.Equal(t, testCase.expectedFirst, output.Clusters[0].ID) + assert.Equal(t, testCase.expectedLast, output.Clusters[testCase.expectedCount-1].ID) + } + }) + } +} + +func TestHandle_GetClustersError(t *testing.T) { + mockService := &mockClustersService{ + err: status.Error(codes.Internal, "test"), + } + + grpcServer, listener := setupMockServer(mockService) + defer grpcServer.Stop() + + testClient := createTestClient(t, listener) + tool, ok := NewListClustersTool(testClient).(*listClustersTool) + require.True(t, ok) + + ctx := context.Background() + req := &mcp.CallToolRequest{} + input := listClustersInput{ + Offset: 0, + Limit: 10, + } + + result, output, err := tool.handle(ctx, req, input) + + require.Error(t, err) + assert.Nil(t, result) + assert.Nil(t, output) + assert.Contains(t, err.Error(), "Internal server error") +} diff --git a/internal/toolsets/utils.go b/internal/toolsets/utils.go new file mode 100644 index 0000000..62f7f72 --- /dev/null +++ b/internal/toolsets/utils.go @@ -0,0 +1,18 @@ +package toolsets + +import ( + "encoding/json" + "fmt" + + "github.com/stackrox/stackrox-mcp/internal/logging" +) + +// MustJSONMarshal marshals value into a raw encoded JSON value or crashes. +func MustJSONMarshal(value any) json.RawMessage { + marshaledValue, err := json.Marshal(value) + if err != nil { + logging.Fatal(fmt.Sprintf("marshaling failed for value: %v", value), err) + } + + return marshaledValue +} diff --git a/internal/toolsets/utils_test.go b/internal/toolsets/utils_test.go new file mode 100644 index 0000000..7987dd6 --- /dev/null +++ b/internal/toolsets/utils_test.go @@ -0,0 +1,182 @@ +package toolsets + +import ( + "errors" + "os" + "os/exec" + "testing" + + "github.com/google/jsonschema-go/jsonschema" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestMustMarshal_SimpleTypes tests marshaling of primitive types. +func TestMustMarshal_SimpleTypes(t *testing.T) { + tests := map[string]struct { + input any + expected string + }{ + "integer": {42, "42"}, + "negative int": {-100, "-100"}, + "zero": {0, "0"}, + "string": {"hello", `"hello"`}, + "empty string": {"", `""`}, + "boolean true": {true, "true"}, + "boolean false": {false, "false"}, + "float": {3.14, "3.14"}, + "negative float": {-2.5, "-2.5"}, + } + + for testName, testCase := range tests { + t.Run(testName, func(t *testing.T) { + result := MustJSONMarshal(testCase.input) + + require.NotNil(t, result) + assert.JSONEq(t, testCase.expected, string(result)) + }) + } +} + +// TestMustMarshal_Structs tests marshaling of structs with JSON tags. +func TestMustMarshal_Structs(t *testing.T) { + type SimpleStruct struct { + Name string `json:"name"` + Value int `json:"value"` + } + + type NestedStruct struct { + ID string `json:"id"` + Simple SimpleStruct `json:"simple"` + } + + tests := map[string]struct { + input any + expected string + }{ + "simple struct": { + SimpleStruct{Name: "test", Value: 123}, + `{"name":"test","value":123}`, + }, + "struct with empty values": { + SimpleStruct{}, + `{"name":"","value":0}`, + }, + "nested struct": { + NestedStruct{ + ID: "nested-1", + Simple: SimpleStruct{Name: "inner", Value: 456}, + }, + `{"id":"nested-1","simple":{"name":"inner","value":456}}`, + }, + } + + for testName, testCase := range tests { + t.Run(testName, func(t *testing.T) { + result := MustJSONMarshal(testCase.input) + + require.NotNil(t, result) + assert.JSONEq(t, testCase.expected, string(result)) + }) + } +} + +// TestMustMarshal_Collections tests marshaling of slices, arrays, and maps. +func TestMustMarshal_Collections(t *testing.T) { + tests := map[string]struct { + input any + expected string + }{ + "int slice": { + []int{1, 2, 3}, + `[1,2,3]`, + }, + "empty slice": { + []string{}, + `[]`, + }, + "string array": { + [3]string{"a", "b", "c"}, + `["a","b","c"]`, + }, + "map string to int": { + map[string]int{"one": 1, "two": 2}, + `{"one":1,"two":2}`, + }, + "empty map": { + map[string]string{}, + `{}`, + }, + "nested slice": { + [][]int{{1, 2}, {3, 4}}, + `[[1,2],[3,4]]`, + }, + } + + for testName, testCase := range tests { + t.Run(testName, func(t *testing.T) { + result := MustJSONMarshal(testCase.input) + + require.NotNil(t, result) + assert.JSONEq(t, testCase.expected, string(result)) + }) + } +} + +// TestMustMarshal_SpecialValues tests marshaling of nil, zero values, and edge cases. +func TestMustMarshal_SpecialValues(t *testing.T) { + tests := map[string]struct { + input any + expected string + }{ + "nil slice": { + ([]int)(nil), + `null`, + }, + "nil map": { + (map[string]int)(nil), + `null`, + }, + "nil pointer": { + (*string)(nil), + `null`, + }, + "pointer to string": { + jsonschema.Ptr("test"), + `"test"`, + }, + "pointer to int": { + jsonschema.Ptr(42), + `42`, + }, + } + + for testName, testCase := range tests { + t.Run(testName, func(t *testing.T) { + result := MustJSONMarshal(testCase.input) + + require.NotNil(t, result) + assert.JSONEq(t, testCase.expected, string(result)) + }) + } +} + +func TestMustMarshal_Failure(t *testing.T) { + if os.Getenv("CRASH_MustMarshal") == "true" { + MustJSONMarshal(func() {}) + + return + } + + // Run the test in a subprocess. + cmd := exec.CommandContext(t.Context(), "go", "test", "-test.run=TestMustMarshal_Failure") + + cmd.Env = append(os.Environ(), "CRASH_MustMarshal=true") + err := cmd.Run() + require.Error(t, err) + + exitState := &exec.ExitError{} + correctType := errors.As(err, &exitState) + require.True(t, correctType) + assert.Equal(t, 1, exitState.ExitCode()) +} diff --git a/internal/toolsets/vulnerability/tools.go b/internal/toolsets/vulnerability/tools.go index 03f2bc3..70b1660 100644 --- a/internal/toolsets/vulnerability/tools.go +++ b/internal/toolsets/vulnerability/tools.go @@ -2,63 +2,209 @@ package vulnerability import ( "context" - "errors" + "fmt" + "strings" + "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/pkg/errors" + v1 "github.com/stackrox/rox/generated/api/v1" + "github.com/stackrox/stackrox-mcp/internal/client" + "github.com/stackrox/stackrox-mcp/internal/client/auth" + "github.com/stackrox/stackrox-mcp/internal/logging" "github.com/stackrox/stackrox-mcp/internal/toolsets" ) -// listClusterCVEsInput defines the input parameters for list_cluster_cves tool. -type listClusterCVEsInput struct { - ClusterID string `json:"clusterId,omitempty"` +const ( + defaultLimit = 50 + maximumLimit = 200.0 +) + +type filterPlatformType string + +const ( + filterPlatformNoFilter filterPlatformType = "NO_FILTER" + filterPlatformUserWorkload filterPlatformType = "USER_WORKLOAD" + filterPlatformPlatform filterPlatformType = "PLATFORM" +) + +// getDeploymentsForCVEInput defines the input parameters for get_deployments_for_cve tool. +type getDeploymentsForCVEInput struct { + CVEName string `json:"cveName"` + FilterClusterID string `json:"filterClusterId,omitempty"` + FilterNamespace string `json:"filterNamespace,omitempty"` + FilterPlatform filterPlatformType `json:"filterPlatform,omitempty"` + Offset int32 `json:"offset,omitempty"` + Limit int32 `json:"limit,omitempty"` +} + +func (input *getDeploymentsForCVEInput) validate() error { + if input.CVEName == "" { + return errors.New("CVE name is required") + } + + return nil +} + +// DeploymentResult contains deployment information. +type DeploymentResult struct { + Name string `json:"name"` + Namespace string `json:"namespace"` + ClusterID string `json:"clusterId"` + ClusterName string `json:"clusterName"` } -// listClusterCVEsOutput defines the output structure for list_cluster_cves tool. -type listClusterCVEsOutput struct { - CVEs []string `json:"cves"` +// getDeploymentsForCVEOutput defines the output structure for get_deployments_for_cve tool. +type getDeploymentsForCVEOutput struct { + Deployments []DeploymentResult `json:"deployments"` } -// listClusterCVEsTool implements the list_cluster_cves tool. -type listClusterCVEsTool struct { - name string +// getDeploymentsForCVETool implements the get_deployments_for_cve tool. +type getDeploymentsForCVETool struct { + name string + client *client.Client } -// NewListClusterCVEsTool creates a new list_cluster_cves tool. -func NewListClusterCVEsTool() toolsets.Tool { - return &listClusterCVEsTool{ - name: "list_cluster_cves", +// NewGetDeploymentsForCVETool creates a new get_deployments_for_cve tool. +func NewGetDeploymentsForCVETool(c *client.Client) toolsets.Tool { + return &getDeploymentsForCVETool{ + name: "get_deployments_for_cve", + client: c, } } // IsReadOnly returns true as this tool only reads data. -func (t *listClusterCVEsTool) IsReadOnly() bool { +func (t *getDeploymentsForCVETool) IsReadOnly() bool { return true } // GetName returns the tool name. -func (t *listClusterCVEsTool) GetName() string { +func (t *getDeploymentsForCVETool) GetName() string { return t.name } // GetTool returns the MCP Tool definition. -func (t *listClusterCVEsTool) GetTool() *mcp.Tool { +func (t *getDeploymentsForCVETool) GetTool() *mcp.Tool { return &mcp.Tool{ - Name: t.name, - //nolint:lll - Description: "List CVEs affecting a specific cluster or all clusters in StackRox Central with CVE names, scores, affected images, and deployments", + Name: t.name, + Description: "Get list of deployments affected by a specific CVE", + InputSchema: getDeploymentsForCVEInputSchema(), } } -// RegisterWith registers the list_cluster_cves tool handler with the MCP server. -func (t *listClusterCVEsTool) RegisterWith(server *mcp.Server) { +// getDeploymentsForCVEInputSchema returns the JSON schema for input validation. +func getDeploymentsForCVEInputSchema() *jsonschema.Schema { + schema, err := jsonschema.For[getDeploymentsForCVEInput](nil) + if err != nil { + logging.Fatal("Could not get jsonschema for get_deployments_for_cve input", err) + + return nil + } + + // CVE name is required. + schema.Required = []string{"cveName"} + + schema.Properties["cveName"].Description = "CVE name to filter deployments (e.g., CVE-2021-44228)" + schema.Properties["filterClusterId"].Description = "Optional cluster ID to filter deployments" + schema.Properties["filterNamespace"].Description = "Optional namespace to filter deployments" + + schema.Properties["filterPlatform"].Description = + fmt.Sprintf("Optional platform filter: %s=no filter, %s=user workload deployments, %s=platform deployments", + filterPlatformNoFilter, filterPlatformUserWorkload, filterPlatformPlatform) + schema.Properties["filterPlatform"].Default = toolsets.MustJSONMarshal(filterPlatformNoFilter) + schema.Properties["filterPlatform"].Enum = []any{ + filterPlatformNoFilter, + filterPlatformUserWorkload, + filterPlatformPlatform, + } + + schema.Properties["offset"].Description = "Pagination offset (default: 0)" + schema.Properties["offset"].Default = toolsets.MustJSONMarshal(0) + schema.Properties["limit"].Minimum = jsonschema.Ptr(0.0) + + schema.Properties["limit"].Description = "Pagination limit: minimum: 1, maximum: 200 (default: 50)" + schema.Properties["limit"].Default = toolsets.MustJSONMarshal(defaultLimit) + schema.Properties["limit"].Minimum = jsonschema.Ptr(1.0) + schema.Properties["limit"].Maximum = jsonschema.Ptr(maximumLimit) + + return schema +} + +// RegisterWith registers the get_deployments_for_cve tool handler with the MCP server. +func (t *getDeploymentsForCVETool) RegisterWith(server *mcp.Server) { mcp.AddTool(server, t.GetTool(), t.handle) } -// handle is the placeholder handler for list_cluster_cves tool. -func (t *listClusterCVEsTool) handle( - _ context.Context, - _ *mcp.CallToolRequest, - _ listClusterCVEsInput, -) (*mcp.CallToolResult, *listClusterCVEsOutput, error) { - return nil, nil, errors.New("list_cluster_cves tool is not yet implemented") +// buildQuery builds query used to search deployments in StackRox Central. +// We will quote values to have strict match. Without quote: CVE-2025-10, would match CVE-2025-101. +func buildQuery(input getDeploymentsForCVEInput) string { + queryParts := []string{fmt.Sprintf("CVE:%q", input.CVEName)} + + if input.FilterClusterID != "" { + queryParts = append(queryParts, fmt.Sprintf("Cluster ID:%q", input.FilterClusterID)) + } + + if input.FilterNamespace != "" { + queryParts = append(queryParts, fmt.Sprintf("Namespace:%q", input.FilterNamespace)) + } + + // Add platform filter if provided. + switch input.FilterPlatform { + case filterPlatformUserWorkload: + queryParts = append(queryParts, "Platform Component:0") + case filterPlatformPlatform: + queryParts = append(queryParts, "Platform Component:1") + case filterPlatformNoFilter: + } + + return strings.Join(queryParts, "+") +} + +// handle is the handler for get_deployments_for_cve tool. +func (t *getDeploymentsForCVETool) handle( + ctx context.Context, + req *mcp.CallToolRequest, + input getDeploymentsForCVEInput, +) (*mcp.CallToolResult, *getDeploymentsForCVEOutput, error) { + err := input.validate() + if err != nil { + return nil, nil, err + } + + conn, err := t.client.ReadyConn(ctx) + if err != nil { + return nil, nil, errors.Wrap(err, "unable to connect to server") + } + + callCtx := auth.WithMCPRequestContext(ctx, req) + deploymentClient := v1.NewDeploymentServiceClient(conn) + + listReq := &v1.RawQuery{ + Query: buildQuery(input), + Pagination: &v1.Pagination{ + Offset: input.Offset, + Limit: input.Limit, + }, + } + + resp, err := deploymentClient.ListDeployments(callCtx, listReq) + if err != nil { + return nil, nil, client.NewError(err, "ListDeployments") + } + + deployments := make([]DeploymentResult, 0, len(resp.GetDeployments())) + for _, deployment := range resp.GetDeployments() { + deployments = append(deployments, DeploymentResult{ + Name: deployment.GetName(), + Namespace: deployment.GetNamespace(), + ClusterID: deployment.GetClusterId(), + ClusterName: deployment.GetCluster(), + }) + } + + output := &getDeploymentsForCVEOutput{ + Deployments: deployments, + } + + return nil, output, nil } diff --git a/internal/toolsets/vulnerability/tools_test.go b/internal/toolsets/vulnerability/tools_test.go index 9e2b0a6..beca9aa 100644 --- a/internal/toolsets/vulnerability/tools_test.go +++ b/internal/toolsets/vulnerability/tools_test.go @@ -1,38 +1,55 @@ package vulnerability import ( + "context" + "fmt" + "net" "testing" "github.com/modelcontextprotocol/go-sdk/mcp" + v1 "github.com/stackrox/rox/generated/api/v1" + "github.com/stackrox/rox/generated/storage" + "github.com/stackrox/stackrox-mcp/internal/client" + "github.com/stackrox/stackrox-mcp/internal/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" + "google.golang.org/grpc/test/bufconn" ) -func TestNewListClusterCVEsTool(t *testing.T) { - tool := NewListClusterCVEsTool() +func TestNewGetDeploymentForCVETool(t *testing.T) { + c := &client.Client{} + tool := NewGetDeploymentsForCVETool(c) require.NotNil(t, tool) - assert.Equal(t, "list_cluster_cves", tool.GetName()) + assert.Equal(t, "get_deployments_for_cve", tool.GetName()) } -func TestListClusterCVEsTool_IsReadOnly(t *testing.T) { - tool := NewListClusterCVEsTool() +func TestGetDeploymentForCVETool_IsReadOnly(t *testing.T) { + c := &client.Client{} + tool := NewGetDeploymentsForCVETool(c) - assert.True(t, tool.IsReadOnly(), "list_cluster_cves should be read-only") + assert.True(t, tool.IsReadOnly(), "get_deployments_for_cve should be read-only") } -func TestListClusterCVEsTool_GetTool(t *testing.T) { - tool := NewListClusterCVEsTool() +func TestGetDeploymentForCVETool_GetTool(t *testing.T) { + c := &client.Client{} + tool := NewGetDeploymentsForCVETool(c) mcpTool := tool.GetTool() require.NotNil(t, mcpTool) - assert.Equal(t, "list_cluster_cves", mcpTool.Name) + assert.Equal(t, "get_deployments_for_cve", mcpTool.Name) assert.NotEmpty(t, mcpTool.Description) + assert.NotNil(t, mcpTool.InputSchema) } -func TestListClusterCVEsTool_RegisterWith(t *testing.T) { - tool := NewListClusterCVEsTool() +func TestGetDeploymentForCVETool_RegisterWith(t *testing.T) { + c := &client.Client{} + tool := NewGetDeploymentsForCVETool(c) server := mcp.NewServer( &mcp.Implementation{ Name: "test-server", @@ -46,3 +63,305 @@ func TestListClusterCVEsTool_RegisterWith(t *testing.T) { tool.RegisterWith(server) }) } + +// Unit tests for input validate method. +func TestInputValidate(t *testing.T) { + tests := map[string]struct { + input getDeploymentsForCVEInput + expectError bool + errorMsg string + }{ + "valid input with CVE only": { + input: getDeploymentsForCVEInput{CVEName: "CVE-2021-44228"}, + expectError: false, + }, + "missing CVE name (empty string)": { + input: getDeploymentsForCVEInput{CVEName: ""}, + expectError: true, + errorMsg: "CVE name is required", + }, + "missing CVE name (zero value)": { + input: getDeploymentsForCVEInput{}, + expectError: true, + errorMsg: "CVE name is required", + }, + } + + for testName, testCase := range tests { + t.Run(testName, func(t *testing.T) { + err := testCase.input.validate() + + if testCase.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), testCase.errorMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +// Mock infrastructure for gRPC testing. + +// mockDeploymentService implements v1.DeploymentServiceServer for testing. +type mockDeploymentService struct { + v1.UnimplementedDeploymentServiceServer + + deployments []*storage.ListDeployment + err error + + lastCallQuery string + lastCallLimit int32 + lastCallOffset int32 +} + +func (m *mockDeploymentService) ListDeployments( + _ context.Context, + query *v1.RawQuery, +) (*v1.ListDeploymentsResponse, error) { + m.lastCallQuery = query.GetQuery() + m.lastCallLimit = query.GetPagination().GetLimit() + m.lastCallOffset = query.GetPagination().GetOffset() + + if m.err != nil { + return nil, m.err + } + + return &v1.ListDeploymentsResponse{ + Deployments: m.deployments, + }, nil +} + +// setupMockDeploymentServer creates an in-memory gRPC server using bufconn. +func setupMockDeploymentServer(mockService *mockDeploymentService) (*grpc.Server, *bufconn.Listener) { + buffer := 1024 * 1024 + listener := bufconn.Listen(buffer) + + grpcServer := grpc.NewServer() + v1.RegisterDeploymentServiceServer(grpcServer, mockService) + + go func() { + _ = grpcServer.Serve(listener) + }() + + return grpcServer, listener +} + +// bufDialer creates a dialer function for bufconn. +func bufDialer(listener *bufconn.Listener) func(context.Context, string) (net.Conn, error) { + return func(_ context.Context, _ string) (net.Conn, error) { + return listener.Dial() + } +} + +// createTestClient creates a client connected to the mock server. +func createTestClient(t *testing.T, listener *bufconn.Listener) *client.Client { + t.Helper() + + conn, err := grpc.NewClient( + "passthrough://buffer", + grpc.WithLocalDNSResolution(), + grpc.WithContextDialer(bufDialer(listener)), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + + stackroxClient, err := client.NewClient(&config.CentralConfig{ + URL: "buffer", + }) + require.NoError(t, err) + + // Inject mock connection for testing. + stackroxClient.SetConnForTesting(t, conn) + + return stackroxClient +} + +// Test helper functions. +func getTestDeployments(totalDeployments int) []*storage.ListDeployment { + deployments := make([]*storage.ListDeployment, totalDeployments) + + for i := range totalDeployments { + deployments[i] = &storage.ListDeployment{ + Name: fmt.Sprintf("deployment-%d", i), + Namespace: "default", + ClusterId: "cluster-1", + Cluster: "Production", + } + } + + return deployments +} + +// Integration tests for handle method. +func TestHandle_MissingCVE(t *testing.T) { + mockService := &mockDeploymentService{ + deployments: []*storage.ListDeployment{}, + } + + grpcServer, listener := setupMockDeploymentServer(mockService) + defer grpcServer.Stop() + + testClient := createTestClient(t, listener) + tool, ok := NewGetDeploymentsForCVETool(testClient).(*getDeploymentsForCVETool) + require.True(t, ok) + + ctx := context.Background() + req := &mcp.CallToolRequest{} + inputWithoutCVEName := getDeploymentsForCVEInput{} + + result, output, err := tool.handle(ctx, req, inputWithoutCVEName) + + require.Error(t, err) + assert.Nil(t, result) + assert.Nil(t, output) + assert.Contains(t, err.Error(), "CVE name is required") +} + +func TestHandle_WithPagination(t *testing.T) { + mockService := &mockDeploymentService{ + deployments: getTestDeployments(5), + } + + grpcServer, listener := setupMockDeploymentServer(mockService) + defer grpcServer.Stop() + + testClient := createTestClient(t, listener) + tool, ok := NewGetDeploymentsForCVETool(testClient).(*getDeploymentsForCVETool) + require.True(t, ok) + + ctx := context.Background() + req := &mcp.CallToolRequest{} + input := getDeploymentsForCVEInput{ + CVEName: "CVE-2021-44228", + Offset: 3, + Limit: 19, + } + + result, output, err := tool.handle(ctx, req, input) + + require.NoError(t, err) + require.NotNil(t, output) + assert.Nil(t, result) + + assert.Len(t, output.Deployments, 5) + assert.Equal(t, int32(3), mockService.lastCallOffset) + assert.Equal(t, int32(19), mockService.lastCallLimit) +} + +func TestHandle_EmptyResults(t *testing.T) { + mockService := &mockDeploymentService{ + deployments: []*storage.ListDeployment{}, + } + + grpcServer, listener := setupMockDeploymentServer(mockService) + defer grpcServer.Stop() + + testClient := createTestClient(t, listener) + tool, ok := NewGetDeploymentsForCVETool(testClient).(*getDeploymentsForCVETool) + require.True(t, ok) + + ctx := context.Background() + req := &mcp.CallToolRequest{} + input := getDeploymentsForCVEInput{ + CVEName: "CVE-9999-99999", + Limit: defaultLimit, + } + + result, output, err := tool.handle(ctx, req, input) + + require.NoError(t, err) + require.NotNil(t, output) + assert.Nil(t, result) + assert.Empty(t, output.Deployments) +} + +func TestHandle_ListDeploymentsError(t *testing.T) { + mockService := &mockDeploymentService{ + err: status.Error(codes.Internal, "database error"), + } + + grpcServer, listener := setupMockDeploymentServer(mockService) + defer grpcServer.Stop() + + testClient := createTestClient(t, listener) + tool, ok := NewGetDeploymentsForCVETool(testClient).(*getDeploymentsForCVETool) + require.True(t, ok) + + ctx := context.Background() + req := &mcp.CallToolRequest{} + input := getDeploymentsForCVEInput{ + CVEName: "CVE-2021-44228", + Limit: defaultLimit, + } + + result, output, err := tool.handle(ctx, req, input) + + require.Error(t, err) + assert.Nil(t, result) + assert.Nil(t, output) + assert.Contains(t, err.Error(), "Internal server error") +} + +func TestHandle_WithFilters(t *testing.T) { + mockService := &mockDeploymentService{deployments: getTestDeployments(1)} + + grpcServer, listener := setupMockDeploymentServer(mockService) + defer grpcServer.Stop() + + tool, ok := NewGetDeploymentsForCVETool(createTestClient(t, listener)).(*getDeploymentsForCVETool) + require.True(t, ok) + + tests := map[string]struct { + input getDeploymentsForCVEInput + expectedQuery string + }{ + "CVE only": { + input: getDeploymentsForCVEInput{CVEName: "CVE-2021-44228"}, + expectedQuery: `CVE:"CVE-2021-44228"`, + }, + "CVE with cluster": { + input: getDeploymentsForCVEInput{ + CVEName: "CVE-2021-44228", + FilterClusterID: "cluster-123", + }, + expectedQuery: `CVE:"CVE-2021-44228"+Cluster ID:"cluster-123"`, + }, + "CVE with namespace": { + input: getDeploymentsForCVEInput{ + CVEName: "CVE-2021-44228", + FilterNamespace: "kube-system", + FilterPlatform: filterPlatformNoFilter, + }, + expectedQuery: `CVE:"CVE-2021-44228"+Namespace:"kube-system"`, + }, + "CVE with platform filter 1 (platform)": { + input: getDeploymentsForCVEInput{ + CVEName: "CVE-2021-44228", + FilterPlatform: filterPlatformPlatform, + }, + expectedQuery: `CVE:"CVE-2021-44228"+Platform Component:1`, + }, + "CVE with all filters": { + input: getDeploymentsForCVEInput{ + CVEName: "CVE-2021-44228", + FilterClusterID: "cluster-123", + FilterNamespace: "default", + FilterPlatform: filterPlatformUserWorkload, + }, + expectedQuery: `CVE:"CVE-2021-44228"+Cluster ID:"cluster-123"+Namespace:"default"+Platform Component:0`, + }, + } + + for testName, testCase := range tests { + t.Run(testName, func(t *testing.T) { + result, output, err := tool.handle(context.Background(), &mcp.CallToolRequest{}, testCase.input) + + require.NoError(t, err) + require.NotNil(t, output) + assert.Nil(t, result) + assert.Len(t, output.Deployments, 1) + assert.Equal(t, testCase.expectedQuery, mockService.lastCallQuery) + }) + } +} diff --git a/internal/toolsets/vulnerability/toolset.go b/internal/toolsets/vulnerability/toolset.go index 70950f7..ded3c78 100644 --- a/internal/toolsets/vulnerability/toolset.go +++ b/internal/toolsets/vulnerability/toolset.go @@ -2,6 +2,7 @@ package vulnerability import ( + "github.com/stackrox/stackrox-mcp/internal/client" "github.com/stackrox/stackrox-mcp/internal/config" "github.com/stackrox/stackrox-mcp/internal/toolsets" ) @@ -13,11 +14,11 @@ type Toolset struct { } // NewToolset creates a new vulnerability management toolset. -func NewToolset(cfg *config.Config) *Toolset { +func NewToolset(cfg *config.Config, c *client.Client) *Toolset { return &Toolset{ cfg: cfg, tools: []toolsets.Tool{ - NewListClusterCVEsTool(), + NewGetDeploymentsForCVETool(c), }, } } diff --git a/internal/toolsets/vulnerability/toolset_test.go b/internal/toolsets/vulnerability/toolset_test.go index b4cbfaf..c1b58eb 100644 --- a/internal/toolsets/vulnerability/toolset_test.go +++ b/internal/toolsets/vulnerability/toolset_test.go @@ -3,6 +3,7 @@ package vulnerability import ( "testing" + "github.com/stackrox/stackrox-mcp/internal/client" "github.com/stackrox/stackrox-mcp/internal/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -17,7 +18,7 @@ func TestNewToolset(t *testing.T) { }, } - toolset := NewToolset(cfg) + toolset := NewToolset(cfg, &client.Client{}) require.NotNil(t, toolset) assert.Equal(t, "vulnerability", toolset.GetName()) @@ -32,13 +33,13 @@ func TestToolset_IsEnabled_True(t *testing.T) { }, } - toolset := NewToolset(cfg) + toolset := NewToolset(cfg, &client.Client{}) assert.True(t, toolset.IsEnabled()) tools := toolset.GetTools() require.NotEmpty(t, tools, "Should return tools when enabled") - require.Len(t, tools, 1, "Should have list_cluster_cves tool") - assert.Equal(t, "list_cluster_cves", tools[0].GetName()) + require.Len(t, tools, 1, "Should have tools") + assert.Equal(t, "get_deployments_for_cve", tools[0].GetName()) } func TestToolset_IsEnabled_False(t *testing.T) { @@ -50,7 +51,7 @@ func TestToolset_IsEnabled_False(t *testing.T) { }, } - toolset := NewToolset(cfg) + toolset := NewToolset(cfg, &client.Client{}) assert.False(t, toolset.IsEnabled()) tools := toolset.GetTools()