Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions api_surface_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,8 @@ func TestAPISurface_WithOnToolsChanged(t *testing.T) {
opt := mcpx.WithOnToolsChanged(func(_ string, _, _ []mcpx.ToolInfo) {})
require.NotNil(t, opt)
}

func TestAPISurface_WithSchemaValidation(t *testing.T) {
opt := mcpx.WithSchemaValidation()
require.NotNil(t, opt)
}
30 changes: 24 additions & 6 deletions caller.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,25 @@ var ErrServerNotFound = errors.New("mcpx: server not found")
// ErrToolNotFound is returned by CallTool when the named tool is not exposed by the server.
var ErrToolNotFound = errors.New("mcpx: tool not found")

// ErrInvalidArgs is returned by CallTool when arguments contain unresolved
// placeholder values (empty strings, "undefined", "null", etc.).
// ErrInvalidArgs is returned by CallTool when arguments fail validation.
// BadFields lists argument paths that contain unresolved placeholder values.
// SchemaErrors lists JSON Schema violations; populated only when
// [WithSchemaValidation] is enabled and args do not conform to the tool schema.
type ErrInvalidArgs struct {
BadFields []string
BadFields []string
SchemaErrors []string
}

func (e *ErrInvalidArgs) Error() string {
return fmt.Sprintf("mcpx: argument(s) have invalid placeholder values: %s",
strings.Join(e.BadFields, ", "))
if len(e.SchemaErrors) == 0 {
return "mcpx: argument(s) have invalid placeholder values: " + strings.Join(e.BadFields, ", ")
}
parts := make([]string, 0, 2)
if len(e.BadFields) > 0 {
parts = append(parts, "placeholder values: "+strings.Join(e.BadFields, ", "))
}
parts = append(parts, "schema violations: "+strings.Join(e.SchemaErrors, "; "))
return "mcpx: invalid arguments: " + strings.Join(parts, "; ")
}

// CallTool invokes the named tool on the named server with the given JSON
Expand All @@ -34,7 +44,7 @@ func (e *ErrInvalidArgs) Error() string {
//
// Errors returned:
// - ErrServerNotFound, ErrToolNotFound — caller mistake.
// - *ErrInvalidArgs — args contain unresolved placeholders.
// - *ErrInvalidArgs — args contain unresolved placeholders or schema violations.
// - errors from BeforeCallHook are propagated as-is.
// - upstream MCP errors are wrapped via fmt.Errorf("server %s: %w", ...).
func (mx *Multiplexer) CallTool(ctx context.Context, server, toolName string, argsJSON json.RawMessage) (*CallResult, error) {
Expand Down Expand Up @@ -98,6 +108,14 @@ func (mx *Multiplexer) CallTool(ctx context.Context, server, toolName string, ar
F("server", server), F("tool", toolName), F("args", transformed))
}

if mx.opts.schemaValidation {
if errs := validateSchema(toolMeta.InputSchema, finalArgs); len(errs) > 0 {
ivErr := &ErrInvalidArgs{SchemaErrors: errs}
safeRecordCall(mx.opts.metrics, server, toolName, time.Since(start), ivErr)
return nil, ivErr
}
}

for _, hook := range mx.opts.beforeCall {
if err := hook(ctx, server, toolMeta, finalArgs); err != nil {
mx.runAfterCall(ctx, server, toolMeta, finalArgs, nil, err)
Expand Down
121 changes: 121 additions & 0 deletions caller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"testing"
"time"

"github.com/google/jsonschema-go/jsonschema"
"github.com/stretchr/testify/require"

mcpx "github.com/inhuman/mcp-multiplexer"
Expand Down Expand Up @@ -413,3 +414,123 @@ func withMultiplexerCustom(t *testing.T, opts ...mcptest.Option) (callFn func()
}
return callFn, cleanup, nil, addAfter
}

// singleWithSchemaAndMxOpts builds a multiplexer with a single server, the
// given tool schema, and additional multiplexer options.
func singleWithSchemaAndMxOpts(t *testing.T, schema *jsonschema.Schema, mxOpts ...mcpx.Option) (*mcpx.Multiplexer, func()) {
t.Helper()
srv := mcptest.NewServer(
mcptest.WithTool(mcptest.ToolSpec{
Name: "tool",
Handler: func(_ context.Context, args map[string]any) (string, error) {
if v, ok := args["name"]; ok {
return fmt.Sprint(v), nil
}
return "ok", nil
},
InputSchema: schema,
}),
)
ts := httptest.NewServer(srv.HTTPHandler())
allOpts := append([]mcpx.Option{mcpx.WithHTTPRetryMax(0)}, mxOpts...)
mx, err := mcpx.New(t.Context(), mcpx.MultiplexerConfig{
Servers: []mcpx.ServerConfig{{Name: "s", Transport: mcpx.TransportHTTP, URL: ts.URL}},
}, allOpts...)
require.NoError(t, err)
return mx, func() { mx.Close(); ts.Close(); srv.Close() }
}

func TestCallTool_SchemaValidation_ValidArgs(t *testing.T) {
schema := &jsonschema.Schema{
Type: "object",
Required: []string{"name"},
Properties: map[string]*jsonschema.Schema{
"name": {Type: "string"},
},
}
mx, cleanup := singleWithSchemaAndMxOpts(t, schema, mcpx.WithSchemaValidation())
defer cleanup()

res, err := mx.CallTool(t.Context(), "s", "tool", []byte(`{"name":"alice"}`))
require.NoError(t, err)
require.Equal(t, "alice", res.Text)
}

func TestCallTool_SchemaValidation_MissingRequired(t *testing.T) {
schema := &jsonschema.Schema{
Type: "object",
Required: []string{"name"},
Properties: map[string]*jsonschema.Schema{
"name": {Type: "string"},
},
}
mx, cleanup := singleWithSchemaAndMxOpts(t, schema, mcpx.WithSchemaValidation())
defer cleanup()

_, err := mx.CallTool(t.Context(), "s", "tool", []byte(`{}`))
require.Error(t, err)
var ivErr *mcpx.ErrInvalidArgs
require.True(t, errors.As(err, &ivErr))
require.NotEmpty(t, ivErr.SchemaErrors)
require.Contains(t, err.Error(), "schema violations")
}

func TestCallTool_SchemaValidation_WrongType(t *testing.T) {
schema := &jsonschema.Schema{
Type: "object",
Properties: map[string]*jsonschema.Schema{
"count": {Type: "integer"},
},
}
mx, cleanup := singleWithSchemaAndMxOpts(t, schema, mcpx.WithSchemaValidation())
defer cleanup()

_, err := mx.CallTool(t.Context(), "s", "tool", []byte(`{"count":"not-a-number"}`))
require.Error(t, err)
var ivErr *mcpx.ErrInvalidArgs
require.True(t, errors.As(err, &ivErr))
require.NotEmpty(t, ivErr.SchemaErrors)
}

func TestCallTool_SchemaValidation_EmptyArgs_RequiredField(t *testing.T) {
schema := &jsonschema.Schema{
Type: "object",
Required: []string{"name"},
}
mx, cleanup := singleWithSchemaAndMxOpts(t, schema, mcpx.WithSchemaValidation())
defer cleanup()

_, err := mx.CallTool(t.Context(), "s", "tool", nil)
require.Error(t, err)
var ivErr *mcpx.ErrInvalidArgs
require.True(t, errors.As(err, &ivErr))
require.NotEmpty(t, ivErr.SchemaErrors)
}

func TestCallTool_SchemaValidation_Disabled_ByDefault(t *testing.T) {
schema := &jsonschema.Schema{
Type: "object",
Required: []string{"name"},
}
mx, cleanup := singleWithSchemaAndMxOpts(t, schema)
defer cleanup()

_, err := mx.CallTool(t.Context(), "s", "tool", []byte(`{}`))
require.NoError(t, err)
}

func TestCallTool_SchemaValidation_NoSchema_Skips(t *testing.T) {
srv := mcptest.NewServer(echoTool("tool"))
ts := httptest.NewServer(srv.HTTPHandler())
defer ts.Close()
defer srv.Close()

mx, err := mcpx.New(t.Context(), mcpx.MultiplexerConfig{
Servers: []mcpx.ServerConfig{{Name: "s", Transport: mcpx.TransportHTTP, URL: ts.URL}},
}, mcpx.WithHTTPRetryMax(0), mcpx.WithSchemaValidation())
require.NoError(t, err)
defer mx.Close()

_, err = mx.CallTool(t.Context(), "s", "tool", []byte(`{"msg":"hi"}`))
require.NoError(t, err)
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ require (
github.com/modelcontextprotocol/go-sdk v0.6.0
github.com/ory/dockertest/v3 v3.12.0
github.com/stretchr/testify v1.11.1
github.com/xeipuuv/gojsonschema v1.2.0
go.uber.org/zap v1.27.0
)

Expand Down Expand Up @@ -38,7 +39,6 @@ require (
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
github.com/xeipuuv/gojsonschema v1.2.0 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
go.uber.org/multierr v1.10.0 // indirect
golang.org/x/sys v0.28.0 // indirect
Expand Down
9 changes: 9 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type options struct {
healthCheckSet bool // true when WithHealthCheck was called
onReconnect OnReconnectFunc
onToolsChanged OnToolsChangedFunc
schemaValidation bool
}

func defaultOptions() *options {
Expand Down Expand Up @@ -207,3 +208,11 @@ func WithClientIdentity(name, version string) Option {
}
}
}

// WithSchemaValidation enables JSON Schema validation of tool arguments
// against each tool's InputSchema before the call is dispatched. When a
// tool declares no InputSchema the check is skipped. Violations are returned
// as *ErrInvalidArgs with SchemaErrors populated.
func WithSchemaValidation() Option {
return func(o *options) { o.schemaValidation = true }
}
35 changes: 34 additions & 1 deletion validate.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package mcpx

import "strings"
import (
"encoding/json"
"strings"

"github.com/xeipuuv/gojsonschema"
)

// invalidArgValues are placeholder strings indicating the model did not
// resolve the actual value before invoking the tool.
Expand All @@ -21,6 +26,34 @@ func findInvalidArgs(args map[string]any) []string {
return bad
}

// validateSchema validates args against a raw JSON Schema. Returns nil when
// args are valid or schema is empty. Returns a non-empty slice of human-
// readable violation strings otherwise. Empty or nil args are treated as
// an empty object ({}) so required-field checks fire correctly.
func validateSchema(schema, args json.RawMessage) []string {
if len(schema) == 0 {
return nil
}
if len(args) == 0 {
args = json.RawMessage("{}")
}
result, err := gojsonschema.Validate(
gojsonschema.NewBytesLoader(schema),
gojsonschema.NewBytesLoader(args),
)
if err != nil {
return []string{err.Error()}
}
if result.Valid() {
return nil
}
out := make([]string, len(result.Errors()))
for i, e := range result.Errors() {
out[i] = e.String()
}
return out
}

func walkArgs(v any, path string, bad *[]string) {
switch val := v.(type) {
case map[string]any:
Expand Down
49 changes: 49 additions & 0 deletions validate_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package mcpx

import (
"encoding/json"
"sort"
"testing"

Expand Down Expand Up @@ -67,3 +68,51 @@ func TestFindInvalidArgs(t *testing.T) {
require.Equal(t, []string{"x"}, bad)
})
}

func TestValidateSchema_EmptySchema_Skips(t *testing.T) {
errs := validateSchema(nil, json.RawMessage(`{"anything":"goes"}`))
require.Empty(t, errs)
}

func TestValidateSchema_ValidArgs(t *testing.T) {
schema := json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}},"required":["name"]}`)
errs := validateSchema(schema, json.RawMessage(`{"name":"alice"}`))
require.Empty(t, errs)
}

func TestValidateSchema_MissingRequired(t *testing.T) {
schema := json.RawMessage(`{"type":"object","required":["name"]}`)
errs := validateSchema(schema, json.RawMessage(`{}`))
require.NotEmpty(t, errs)
require.Contains(t, errs[0], "name")
}

func TestValidateSchema_WrongType(t *testing.T) {
schema := json.RawMessage(`{"type":"object","properties":{"count":{"type":"integer"}}}`)
errs := validateSchema(schema, json.RawMessage(`{"count":"not-a-number"}`))
require.NotEmpty(t, errs)
}

func TestValidateSchema_EmptyArgs_ChecksRequired(t *testing.T) {
schema := json.RawMessage(`{"type":"object","required":["name"]}`)
errs := validateSchema(schema, nil)
require.NotEmpty(t, errs)
}

func TestValidateSchema_EmptyArgs_NoRequired_Valid(t *testing.T) {
schema := json.RawMessage(`{"type":"object"}`)
errs := validateSchema(schema, nil)
require.Empty(t, errs)
}

func TestValidateSchema_AdditionalConstraints(t *testing.T) {
schema := json.RawMessage(`{"type":"object","properties":{"age":{"type":"integer","minimum":0,"maximum":150}}}`)
t.Run("valid", func(t *testing.T) {
errs := validateSchema(schema, json.RawMessage(`{"age":25}`))
require.Empty(t, errs)
})
t.Run("below_minimum", func(t *testing.T) {
errs := validateSchema(schema, json.RawMessage(`{"age":-1}`))
require.NotEmpty(t, errs)
})
}
Loading