Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

### Added

- **`ServerConfig.CallTimeout time.Duration`** — per-server call timeout override.
A zero or negative value inherits the multiplexer-wide default set via
`WithCallTimeout` (default 30 s). Use a shorter value for local stdio servers
and a longer value for HTTP servers that may need retries.
- **`WithHealthCheck(interval time.Duration) Option`** — opt-in liveness supervisor.
Probes each server on the given interval via `ListTools`; reconnects with
exponential backoff (1 s → 2 s → … → 60 s cap) when a server is unreachable.
Expand Down
14 changes: 12 additions & 2 deletions caller.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"strings"
"time"

"github.com/modelcontextprotocol/go-sdk/mcp"
)
Expand Down Expand Up @@ -99,14 +100,15 @@ func (mx *Multiplexer) CallTool(ctx context.Context, server, toolName string, ar
}
}

callCtx, cancel := context.WithTimeout(ctx, mx.opts.callTimeout)
timeout := effectiveTimeout(entry.config.CallTimeout, mx.opts.callTimeout)
callCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

rawResult, callErr := entrySess.CallTool(callCtx, params)
if callErr != nil {
var wrapped error
if callCtx.Err() != nil && errors.Is(callCtx.Err(), context.DeadlineExceeded) {
wrapped = fmt.Errorf("mcpx: call timeout %s/%s after %s", server, toolName, mx.opts.callTimeout)
wrapped = fmt.Errorf("mcpx: call timeout %s/%s after %s", server, toolName, timeout)
} else {
wrapped = fmt.Errorf("mcpx: server %s: %w", server, callErr)
}
Expand Down Expand Up @@ -137,6 +139,14 @@ func (mx *Multiplexer) runAfterCall(ctx context.Context, server string, tool Too
}
}

// effectiveTimeout returns perServer if positive, otherwise global.
func effectiveTimeout(perServer, global time.Duration) time.Duration {
if perServer > 0 {
return perServer
}
return global
}

func buildResult(r *mcp.CallToolResult) *CallResult {
if r == nil {
return &CallResult{}
Expand Down
13 changes: 12 additions & 1 deletion config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package mcpx

import "errors"
import (
"errors"
"time"
)

// MultiplexerConfig is the top-level config for New().
type MultiplexerConfig struct {
Expand Down Expand Up @@ -45,6 +48,14 @@ type ServerConfig struct {
// HTTP/SSE only.
URL string `json:"url,omitempty"`

// CallTimeout is the maximum duration allowed for a single tool call to
// this server. A zero or negative value inherits the multiplexer-wide
// default set via [WithCallTimeout] (default 30 s).
//
// Use a shorter value for local stdio servers and a longer value for HTTP
// servers that may need retries.
CallTimeout time.Duration `json:"call_timeout,omitempty"`

// Auth is an opaque parameter block read verbatim from the JSON "auth"
// field. The library does not interpret its shape — it is forwarded
// as-is to the AuthFunc registered via [WithAuthFunc].
Expand Down
3 changes: 3 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
// permission changes). An optional WithOnToolsChanged callback notifies the
// consumer after each refresh that produces a different tool list.
//
// Per-server call timeouts are supported via ServerConfig.CallTimeout; a zero
// value inherits the global default set via WithCallTimeout (30 s by default).
//
// The library is logger-agnostic via the Logger interface (4 methods).
// Adapters for go.uber.org/zap and log/slog are provided as separate
// packages under log/zaplog and log/sloglog so the core stays
Expand Down
121 changes: 121 additions & 0 deletions timeout_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package mcpx_test

import (
"encoding/json"
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/require"

mcpx "github.com/inhuman/mcp-multiplexer"
"github.com/inhuman/mcp-multiplexer/internal/testutil/mcptest"
)

// TestPerServerTimeout_ShortTimeoutFires verifies that a server with a short
// CallTimeout times out independently from a server with a longer timeout.
func TestPerServerTimeout_ShortTimeoutFires(t *testing.T) {
ctx := t.Context()

slowSrv := mcptest.NewServer(
echoTool("fast_tool"),
mcptest.WithToolDelay("fast_tool", 5*time.Second),
)
slowTS := httptest.NewServer(slowSrv.HTTPHandler())
t.Cleanup(func() { slowTS.Close(); slowSrv.Close() })

fastSrv := mcptest.NewServer(echoTool("fast_tool"))
fastTS := httptest.NewServer(fastSrv.HTTPHandler())
t.Cleanup(func() { fastTS.Close(); fastSrv.Close() })

mx, err := mcpx.New(ctx, mcpx.MultiplexerConfig{
Servers: []mcpx.ServerConfig{
{
Name: "slow",
Transport: mcpx.TransportHTTP,
URL: slowTS.URL,
CallTimeout: 150 * time.Millisecond,
},
{
Name: "fast",
Transport: mcpx.TransportHTTP,
URL: fastTS.URL,
// CallTimeout zero — inherits global 10s
},
},
}, mcpx.WithCallTimeout(10*time.Second), mcpx.WithHTTPRetryMax(0))
require.NoError(t, err)
t.Cleanup(mx.Close)

// Slow server must time out with its 150ms per-server limit.
_, err = mx.CallTool(ctx, "slow", "fast_tool", json.RawMessage(`{}`))
require.Error(t, err)
require.Contains(t, err.Error(), "timeout")

// Fast server should succeed (no artificial delay).
result, err := mx.CallTool(ctx, "fast", "fast_tool", json.RawMessage(`{}`))
require.NoError(t, err)
require.NotNil(t, result)
}

// TestPerServerTimeout_ZeroInheritsGlobal verifies that a server with
// CallTimeout == 0 uses the multiplexer-wide global timeout.
func TestPerServerTimeout_ZeroInheritsGlobal(t *testing.T) {
ctx := t.Context()

// Slow server: delays 200ms — longer than the global 50ms but shorter than 1s.
slowSrv := mcptest.NewServer(
echoTool("slow_tool"),
mcptest.WithToolDelay("slow_tool", 200*time.Millisecond),
)
slowTS := httptest.NewServer(slowSrv.HTTPHandler())
t.Cleanup(func() { slowTS.Close(); slowSrv.Close() })

mx, err := mcpx.New(ctx, mcpx.MultiplexerConfig{
Servers: []mcpx.ServerConfig{
{
Name: "srv",
Transport: mcpx.TransportHTTP,
URL: slowTS.URL,
CallTimeout: 0, // inherit global
},
},
}, mcpx.WithCallTimeout(50*time.Millisecond), mcpx.WithHTTPRetryMax(0))
require.NoError(t, err)
t.Cleanup(mx.Close)

// Global 50ms must fire even though per-server is zero.
_, err = mx.CallTool(ctx, "srv", "slow_tool", json.RawMessage(`{}`))
require.Error(t, err)
require.Contains(t, err.Error(), "timeout")
}

// TestPerServerTimeout_NegativeTreatedAsZero verifies that a negative
// CallTimeout falls back to the global timeout (same as zero).
func TestPerServerTimeout_NegativeTreatedAsZero(t *testing.T) {
ctx := t.Context()

slowSrv := mcptest.NewServer(
echoTool("slow_tool"),
mcptest.WithToolDelay("slow_tool", 200*time.Millisecond),
)
slowTS := httptest.NewServer(slowSrv.HTTPHandler())
t.Cleanup(func() { slowTS.Close(); slowSrv.Close() })

mx, err := mcpx.New(ctx, mcpx.MultiplexerConfig{
Servers: []mcpx.ServerConfig{
{
Name: "srv",
Transport: mcpx.TransportHTTP,
URL: slowTS.URL,
CallTimeout: -1, // must be treated as zero → inherit global
},
},
}, mcpx.WithCallTimeout(50*time.Millisecond), mcpx.WithHTTPRetryMax(0))
require.NoError(t, err)
t.Cleanup(mx.Close)

_, err = mx.CallTool(ctx, "srv", "slow_tool", json.RawMessage(`{}`))
require.Error(t, err)
require.Contains(t, err.Error(), "timeout")
}
Loading