From eef77dc0bb7b65956cabd2de2e6d9fffc18f4055 Mon Sep 17 00:00:00 2001 From: Ivan Diachenko Date: Tue, 12 May 2026 23:26:51 +0300 Subject: [PATCH 1/2] fix: uniform observability on all rejection paths + RejectInvalidArgs + reconnect on initial connect failure --- CHANGELOG.md | 3 ++- caller.go | 16 ++++++++++++++-- caller_test.go | 33 +++++++++++++++++++++++++++++++++ hooks.go | 1 + multiplexer.go | 22 ++++++++++++++++++---- 5 files changed, 68 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f75836b..50f0145 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,8 +26,9 @@ - Per-tool TTL via `ToolInfo.Custom["cache_ttl"]` (parses `time.Duration` format). - Warn-once (per Multiplexer lifetime) when a cacheable call has no scope set. - **`(*CallResult).Clone() *CallResult`** — exported deep-copy method; `Data` and `Raw` byte slices are independently allocated. Nil receiver returns nil. -- **`RejectReason`** type and four constants: `RejectUnknownServer`, `RejectUnknownTool`, `RejectServerDown`, `RejectBeforeHookAbort`. +- **`RejectReason`** type and five constants: `RejectUnknownServer`, `RejectUnknownTool`, `RejectServerDown`, `RejectBeforeHookAbort`, `RejectInvalidArgs`. - **`OnRejectedCallFunc`** + `WithOnRejectedCall` option — fires before AfterCall on every rejection path; panics recovered. +- AfterCall, OnRejectedCall, and `Metrics.RecordCall` now fire consistently on **all** rejection paths, including invalid-args JSON, placeholder bad fields, schema-validation failures, server-not-found, and server-down. Observability is no longer asymmetric across rejection reasons. - **`OnConnectFunc`** + `WithOnConnect` option — fires once per server after initial successful connect (before `New` returns); tools list is post-MetaEnricher; panics recovered. ### Migrating from v0.3.x diff --git a/caller.go b/caller.go index fd378a7..9d884b1 100644 --- a/caller.go +++ b/caller.go @@ -60,6 +60,7 @@ func (mx *Multiplexer) CallTool(ctx context.Context, server, toolName string, ar ErrServerNotFound, server, strings.Join(mx.ServerNames(), ", ")) mx.fireRejected(ctx, server, toolName, RejectUnknownServer, err) mx.runAfterCall(ctx, server, toolName, ToolInfo{}, argsJSON, nil, err, time.Since(start)) + safeRecordCall(mx.opts.metrics, server, toolName, time.Since(start), err) return nil, err } @@ -90,6 +91,7 @@ func (mx *Multiplexer) CallTool(ctx context.Context, server, toolName string, ar err := fmt.Errorf("%w: %q", ErrServerDown, server) mx.fireRejected(ctx, server, toolName, RejectServerDown, err) mx.runAfterCall(ctx, server, toolName, toolMeta, argsJSON, nil, err, time.Since(start)) + safeRecordCall(mx.opts.metrics, server, toolName, time.Since(start), err) return nil, err } @@ -100,10 +102,18 @@ func (mx *Multiplexer) CallTool(ctx context.Context, server, toolName string, ar if len(argsJSON) > 0 { var rawArgs map[string]any if err := json.Unmarshal(argsJSON, &rawArgs); err != nil { - return nil, fmt.Errorf("mcpx: invalid args json: %w", err) + wrapped := fmt.Errorf("mcpx: invalid args json: %w", err) + mx.fireRejected(ctx, server, toolName, RejectInvalidArgs, wrapped) + mx.runAfterCall(ctx, server, toolName, toolMeta, argsJSON, nil, wrapped, time.Since(start)) + safeRecordCall(mx.opts.metrics, server, toolName, time.Since(start), wrapped) + return nil, wrapped } if bad := findInvalidArgs(rawArgs); len(bad) > 0 { - return nil, &ErrInvalidArgs{BadFields: bad} + ivErr := &ErrInvalidArgs{BadFields: bad} + mx.fireRejected(ctx, server, toolName, RejectInvalidArgs, ivErr) + mx.runAfterCall(ctx, server, toolName, toolMeta, argsJSON, nil, ivErr, time.Since(start)) + safeRecordCall(mx.opts.metrics, server, toolName, time.Since(start), ivErr) + return nil, ivErr } singularMap := mergedSingularMap(mx.opts.resourceSingular, entry.config.ResourceSingular) transformed := entry.config.ArgsTransformers.applyAll(rawArgs, mx.opts.customTransformers, singularMap) @@ -120,6 +130,8 @@ func (mx *Multiplexer) CallTool(ctx context.Context, server, toolName string, ar if mx.opts.schemaValidation { if errs := validateSchema(toolMeta.InputSchema, finalArgs); len(errs) > 0 { ivErr := &ErrInvalidArgs{SchemaErrors: errs} + mx.fireRejected(ctx, server, toolName, RejectInvalidArgs, ivErr) + mx.runAfterCall(ctx, server, toolName, toolMeta, finalArgs, nil, ivErr, time.Since(start)) safeRecordCall(mx.opts.metrics, server, toolName, time.Since(start), ivErr) return nil, ivErr } diff --git a/caller_test.go b/caller_test.go index 9995ad8..02a2518 100644 --- a/caller_test.go +++ b/caller_test.go @@ -839,6 +839,39 @@ func TestOnRejectedCall_PanicRecovered(t *testing.T) { }) } +func TestOnRejectedCall_InvalidArgs(t *testing.T) { + srv := mcptest.NewServer(echoToolSpec("t")) + ts := httptest.NewServer(srv.HTTPHandler()) + defer ts.Close() + defer srv.Close() + + var gotReason mcpx.RejectReason + var afterFired atomic.Bool + var afterErr error + mx, err := mcpx.New(t.Context(), mcpx.MultiplexerConfig{ + Servers: []mcpx.ServerConfig{{Name: "s", Transport: mcpx.TransportHTTP, URL: ts.URL}}, + }, + mcpx.WithHTTPRetryMax(0), + mcpx.WithoutCache(), + mcpx.WithOnRejectedCall(func(_ context.Context, _, _ string, reason mcpx.RejectReason, _ error) { + gotReason = reason + }), + mcpx.WithAfterCall(func(_ context.Context, _, _ string, _ mcpx.ToolInfo, _ json.RawMessage, _ *mcpx.CallResult, callErr error, _ time.Duration) { + afterFired.Store(true) + afterErr = callErr + }), + ) + require.NoError(t, err) + defer mx.Close() + + _, err = mx.CallTool(t.Context(), "s", "t", json.RawMessage(`{"x":"undefined"}`)) + var ivErr *mcpx.ErrInvalidArgs + require.ErrorAs(t, err, &ivErr) + require.Equal(t, mcpx.RejectInvalidArgs, gotReason) + require.True(t, afterFired.Load(), "AfterCall must fire on RejectInvalidArgs") + require.ErrorAs(t, afterErr, &ivErr) +} + // === v0.4.0: Cache === func buildCacheableSrv(t *testing.T) (url string, calls *atomic.Int64) { diff --git a/hooks.go b/hooks.go index 4950906..5c2d25a 100644 --- a/hooks.go +++ b/hooks.go @@ -50,6 +50,7 @@ const ( RejectUnknownTool RejectReason = "unknown_tool" RejectServerDown RejectReason = "server_down" RejectBeforeHookAbort RejectReason = "before_hook_abort" + RejectInvalidArgs RejectReason = "invalid_args" ) // OnRejectedCallFunc is called when CallTool is rejected before dispatch. diff --git a/multiplexer.go b/multiplexer.go index 9b78ee8..78bb211 100644 --- a/multiplexer.go +++ b/multiplexer.go @@ -98,9 +98,17 @@ func New(ctx context.Context, cfg MultiplexerConfig, opts ...Option) (*Multiplex sc = sc.withKindDefaults(ks) } wg.Go(func() { - entry, err := mx.connect(ctx, sc, nil) + refreshCh := make(chan struct{}, 1) + entry, err := mx.connect(ctx, sc, refreshCh) if err != nil { o.logger.Error("mcpx: failed to connect", F("server", sc.Name), F("error", err.Error())) + // Register a down entry so the supervisor can retry. + stub := &serverEntry{config: sc, state: ServerStateDown, refreshCh: refreshCh} + stub.reconnecting.Store(true) + mu.Lock() + mx.servers[sc.Name] = stub + mu.Unlock() + go mx.reconnectServer(ctx, sc.Name) return } mu.Lock() @@ -138,13 +146,19 @@ type KindGroup struct { // ConfigHints returns the kind_hints map from MultiplexerConfig (may be nil). func (mx *Multiplexer) ConfigHints() map[string][]string { return mx.kindHints } -// ServerNames returns the sorted list of registered MCP server names. +// ServerNames returns the sorted list of live (connected) MCP server names. +// Servers that are currently down (initial connect failed or lost) are excluded. func (mx *Multiplexer) ServerNames() []string { mx.mu.RLock() defer mx.mu.RUnlock() names := make([]string, 0, len(mx.servers)) - for n := range mx.servers { - names = append(names, n) + for n, e := range mx.servers { + e.mu.RLock() + st := e.state + e.mu.RUnlock() + if st != ServerStateDown { + names = append(names, n) + } } slices.Sort(names) return names From faab7af8e13400cad75038e6f21036ba644c77fd Mon Sep 17 00:00:00 2001 From: Ivan Diachenko Date: Tue, 12 May 2026 23:30:39 +0300 Subject: [PATCH 2/2] fix: start reconnect goroutines after wg.Wait to avoid data race on mx.servers --- multiplexer.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/multiplexer.go b/multiplexer.go index 78bb211..e15cfe6 100644 --- a/multiplexer.go +++ b/multiplexer.go @@ -92,6 +92,7 @@ func New(ctx context.Context, cfg MultiplexerConfig, opts ...Option) (*Multiplex var wg sync.WaitGroup var mu sync.Mutex + var failedServers []string for _, sc := range cfg.Servers { if ks, ok := cfg.KindSettings[sc.Kind]; ok { @@ -102,13 +103,12 @@ func New(ctx context.Context, cfg MultiplexerConfig, opts ...Option) (*Multiplex entry, err := mx.connect(ctx, sc, refreshCh) if err != nil { o.logger.Error("mcpx: failed to connect", F("server", sc.Name), F("error", err.Error())) - // Register a down entry so the supervisor can retry. stub := &serverEntry{config: sc, state: ServerStateDown, refreshCh: refreshCh} stub.reconnecting.Store(true) mu.Lock() mx.servers[sc.Name] = stub + failedServers = append(failedServers, sc.Name) mu.Unlock() - go mx.reconnectServer(ctx, sc.Name) return } mu.Lock() @@ -124,6 +124,12 @@ func New(ctx context.Context, cfg MultiplexerConfig, opts ...Option) (*Multiplex } wg.Wait() + // Start reconnect goroutines only after wg.Wait() so mx.servers is fully + // populated and all subsequent access goes through mx.mu (not local mu). + for _, name := range failedServers { + go mx.reconnectServer(ctx, name) + } + for name, entry := range mx.servers { go mx.runToolRefresh(ctx, name, entry) }