Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
- Per-tool TTL via `ToolInfo.Custom["cache_ttl"]` (parses `time.Duration` format).
- Warn-once (per Multiplexer lifetime) when a cacheable call has no scope set.
- **`(*CallResult).Clone() *CallResult`** — exported deep-copy method; `Data` and `Raw` byte slices are independently allocated. Nil receiver returns nil.
- **`RejectReason`** type and four constants: `RejectUnknownServer`, `RejectUnknownTool`, `RejectServerDown`, `RejectBeforeHookAbort`.
- **`RejectReason`** type and five constants: `RejectUnknownServer`, `RejectUnknownTool`, `RejectServerDown`, `RejectBeforeHookAbort`, `RejectInvalidArgs`.
- **`OnRejectedCallFunc`** + `WithOnRejectedCall` option — fires before AfterCall on every rejection path; panics recovered.
- AfterCall, OnRejectedCall, and `Metrics.RecordCall` now fire consistently on **all** rejection paths, including invalid-args JSON, placeholder bad fields, schema-validation failures, server-not-found, and server-down. Observability is no longer asymmetric across rejection reasons.
- **`OnConnectFunc`** + `WithOnConnect` option — fires once per server after initial successful connect (before `New` returns); tools list is post-MetaEnricher; panics recovered.

### Migrating from v0.3.x
Expand Down
16 changes: 14 additions & 2 deletions caller.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ func (mx *Multiplexer) CallTool(ctx context.Context, server, toolName string, ar
ErrServerNotFound, server, strings.Join(mx.ServerNames(), ", "))
mx.fireRejected(ctx, server, toolName, RejectUnknownServer, err)
mx.runAfterCall(ctx, server, toolName, ToolInfo{}, argsJSON, nil, err, time.Since(start))
safeRecordCall(mx.opts.metrics, server, toolName, time.Since(start), err)
return nil, err
}

Expand Down Expand Up @@ -90,6 +91,7 @@ func (mx *Multiplexer) CallTool(ctx context.Context, server, toolName string, ar
err := fmt.Errorf("%w: %q", ErrServerDown, server)
mx.fireRejected(ctx, server, toolName, RejectServerDown, err)
mx.runAfterCall(ctx, server, toolName, toolMeta, argsJSON, nil, err, time.Since(start))
safeRecordCall(mx.opts.metrics, server, toolName, time.Since(start), err)
return nil, err
}

Expand All @@ -100,10 +102,18 @@ func (mx *Multiplexer) CallTool(ctx context.Context, server, toolName string, ar
if len(argsJSON) > 0 {
var rawArgs map[string]any
if err := json.Unmarshal(argsJSON, &rawArgs); err != nil {
return nil, fmt.Errorf("mcpx: invalid args json: %w", err)
wrapped := fmt.Errorf("mcpx: invalid args json: %w", err)
mx.fireRejected(ctx, server, toolName, RejectInvalidArgs, wrapped)
mx.runAfterCall(ctx, server, toolName, toolMeta, argsJSON, nil, wrapped, time.Since(start))
safeRecordCall(mx.opts.metrics, server, toolName, time.Since(start), wrapped)
return nil, wrapped
}
if bad := findInvalidArgs(rawArgs); len(bad) > 0 {
return nil, &ErrInvalidArgs{BadFields: bad}
ivErr := &ErrInvalidArgs{BadFields: bad}
mx.fireRejected(ctx, server, toolName, RejectInvalidArgs, ivErr)
mx.runAfterCall(ctx, server, toolName, toolMeta, argsJSON, nil, ivErr, time.Since(start))
safeRecordCall(mx.opts.metrics, server, toolName, time.Since(start), ivErr)
return nil, ivErr
}
singularMap := mergedSingularMap(mx.opts.resourceSingular, entry.config.ResourceSingular)
transformed := entry.config.ArgsTransformers.applyAll(rawArgs, mx.opts.customTransformers, singularMap)
Expand All @@ -120,6 +130,8 @@ func (mx *Multiplexer) CallTool(ctx context.Context, server, toolName string, ar
if mx.opts.schemaValidation {
if errs := validateSchema(toolMeta.InputSchema, finalArgs); len(errs) > 0 {
ivErr := &ErrInvalidArgs{SchemaErrors: errs}
mx.fireRejected(ctx, server, toolName, RejectInvalidArgs, ivErr)
mx.runAfterCall(ctx, server, toolName, toolMeta, finalArgs, nil, ivErr, time.Since(start))
safeRecordCall(mx.opts.metrics, server, toolName, time.Since(start), ivErr)
return nil, ivErr
}
Expand Down
33 changes: 33 additions & 0 deletions caller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,39 @@ func TestOnRejectedCall_PanicRecovered(t *testing.T) {
})
}

func TestOnRejectedCall_InvalidArgs(t *testing.T) {
srv := mcptest.NewServer(echoToolSpec("t"))
ts := httptest.NewServer(srv.HTTPHandler())
defer ts.Close()
defer srv.Close()

var gotReason mcpx.RejectReason
var afterFired atomic.Bool
var afterErr error
mx, err := mcpx.New(t.Context(), mcpx.MultiplexerConfig{
Servers: []mcpx.ServerConfig{{Name: "s", Transport: mcpx.TransportHTTP, URL: ts.URL}},
},
mcpx.WithHTTPRetryMax(0),
mcpx.WithoutCache(),
mcpx.WithOnRejectedCall(func(_ context.Context, _, _ string, reason mcpx.RejectReason, _ error) {
gotReason = reason
}),
mcpx.WithAfterCall(func(_ context.Context, _, _ string, _ mcpx.ToolInfo, _ json.RawMessage, _ *mcpx.CallResult, callErr error, _ time.Duration) {
afterFired.Store(true)
afterErr = callErr
}),
)
require.NoError(t, err)
defer mx.Close()

_, err = mx.CallTool(t.Context(), "s", "t", json.RawMessage(`{"x":"undefined"}`))
var ivErr *mcpx.ErrInvalidArgs
require.ErrorAs(t, err, &ivErr)
require.Equal(t, mcpx.RejectInvalidArgs, gotReason)
require.True(t, afterFired.Load(), "AfterCall must fire on RejectInvalidArgs")
require.ErrorAs(t, afterErr, &ivErr)
}

// === v0.4.0: Cache ===

func buildCacheableSrv(t *testing.T) (url string, calls *atomic.Int64) {
Expand Down
1 change: 1 addition & 0 deletions hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ const (
RejectUnknownTool RejectReason = "unknown_tool"
RejectServerDown RejectReason = "server_down"
RejectBeforeHookAbort RejectReason = "before_hook_abort"
RejectInvalidArgs RejectReason = "invalid_args"
)

// OnRejectedCallFunc is called when CallTool is rejected before dispatch.
Expand Down
28 changes: 24 additions & 4 deletions multiplexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,23 @@ func New(ctx context.Context, cfg MultiplexerConfig, opts ...Option) (*Multiplex

var wg sync.WaitGroup
var mu sync.Mutex
var failedServers []string

for _, sc := range cfg.Servers {
if ks, ok := cfg.KindSettings[sc.Kind]; ok {
sc = sc.withKindDefaults(ks)
}
wg.Go(func() {
entry, err := mx.connect(ctx, sc, nil)
refreshCh := make(chan struct{}, 1)
entry, err := mx.connect(ctx, sc, refreshCh)
if err != nil {
o.logger.Error("mcpx: failed to connect", F("server", sc.Name), F("error", err.Error()))
stub := &serverEntry{config: sc, state: ServerStateDown, refreshCh: refreshCh}
stub.reconnecting.Store(true)
mu.Lock()
mx.servers[sc.Name] = stub
failedServers = append(failedServers, sc.Name)
mu.Unlock()
return
}
mu.Lock()
Expand All @@ -116,6 +124,12 @@ func New(ctx context.Context, cfg MultiplexerConfig, opts ...Option) (*Multiplex
}
wg.Wait()

// Start reconnect goroutines only after wg.Wait() so mx.servers is fully
// populated and all subsequent access goes through mx.mu (not local mu).
for _, name := range failedServers {
go mx.reconnectServer(ctx, name)
}

for name, entry := range mx.servers {
go mx.runToolRefresh(ctx, name, entry)
}
Expand All @@ -138,13 +152,19 @@ type KindGroup struct {
// ConfigHints returns the kind_hints map from MultiplexerConfig (may be nil).
func (mx *Multiplexer) ConfigHints() map[string][]string { return mx.kindHints }

// ServerNames returns the sorted list of registered MCP server names.
// ServerNames returns the sorted list of live (connected) MCP server names.
// Servers that are currently down (initial connect failed or lost) are excluded.
func (mx *Multiplexer) ServerNames() []string {
mx.mu.RLock()
defer mx.mu.RUnlock()
names := make([]string, 0, len(mx.servers))
for n := range mx.servers {
names = append(names, n)
for n, e := range mx.servers {
e.mu.RLock()
st := e.state
e.mu.RUnlock()
if st != ServerStateDown {
names = append(names, n)
}
}
slices.Sort(names)
return names
Expand Down
Loading