From cda3f768d2d0b5db41a23c29c52314125c37a3be Mon Sep 17 00:00:00 2001 From: David Gageot Date: Sat, 25 Apr 2026 11:29:06 +0200 Subject: [PATCH 1/3] refactor(agent): separate toolset notices from warnings Recovery messages (a previously-failed toolset is now available again) were going through the same warning queue as real failures. After lazy- init OAuth completed and the toolset reconnected, the user saw Some toolsets failed to initialize for agent 'root'. Details: - mcp(remote host=mcp.slack.com transport=streamable) is now available The body says 'is now available' (success) but the framing says 'failed' (failure), which reads like 'is NOT available'. This change splits the agent's queue in two: - AddToolWarning / DrainWarnings: real failures (start failed, list failed, ...). Exported because the runtime emits these from the startup tool-loading path (added in a follow-up commit). - AddToolNotice / DrainNotices: positive, informational status updates. The 'now available' recovery path now uses this queue. emitAgentWarnings drains both, formatting failures with the existing 'Some toolsets failed to initialize' banner and notices with a neutral 'Toolset status update' header so each reads correctly on its own. Assisted-By: docker-agent --- pkg/agent/agent.go | 44 ++++++++++++++++++----- pkg/agent/agent_test.go | 14 ++++---- pkg/agent/opts.go | 2 +- pkg/runtime/loop.go | 23 ++++++++++-- pkg/runtime/runtime_test.go | 70 +++++++++++++++++++++++++++++++++++++ 5 files changed, 135 insertions(+), 18 deletions(-) diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 154f104fd..f94f711e8 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -42,11 +42,13 @@ type Agent struct { commands types.Commands hooks *latest.HooksConfig - // warningsMu guards pendingWarnings. addToolWarning and DrainWarnings - // may be called concurrently from the runtime loop, the MCP server, - // the TUI and session manager. + // warningsMu guards pendingWarnings and pendingNotices. AddToolWarning, + // AddToolNotice, DrainWarnings and DrainNotices may be called + // concurrently from the runtime loop, the MCP server, the TUI and + // session manager. warningsMu sync.Mutex pendingWarnings []string + pendingNotices []string } // New creates a new agent @@ -238,7 +240,7 @@ func (a *Agent) collectTools(ctx context.Context) ([]tools.Tool, error) { if err != nil { desc := tools.DescribeToolSet(toolSet) slog.Warn("Toolset listing failed; skipping", "agent", a.Name(), "toolset", desc, "error", err) - a.addToolWarning(fmt.Sprintf("%s list failed: %v", desc, err)) + a.AddToolWarning(fmt.Sprintf("%s list failed: %v", desc, err)) continue } agentTools = append(agentTools, ta...) @@ -271,7 +273,7 @@ func (a *Agent) ensureToolSetsAreStarted(ctx context.Context) { if toolSet.ShouldReportFailure() { desc := tools.DescribeToolSet(toolSet) slog.Warn("Toolset start failed; will retry on next turn", "agent", a.Name(), "toolset", desc, "error", err) - a.addToolWarning(fmt.Sprintf("%s start failed: %v", desc, err)) + a.AddToolWarning(fmt.Sprintf("%s start failed: %v", desc, err)) } else { desc := tools.DescribeToolSet(toolSet) slog.Debug("Toolset still unavailable; retrying next turn", "agent", a.Name(), "toolset", desc, "error", err) @@ -282,13 +284,17 @@ func (a *Agent) ensureToolSetsAreStarted(ctx context.Context) { if toolSet.ConsumeRecovery() { desc := tools.DescribeToolSet(toolSet) slog.Info("Toolset now available", "agent", a.Name(), "toolset", desc) - a.addToolWarning(desc + " is now available") + a.AddToolNotice(desc + " is now available") } } } -// addToolWarning records a warning generated while loading or starting toolsets. -func (a *Agent) addToolWarning(msg string) { +// AddToolWarning records a warning generated while loading or starting toolsets. +// Warnings represent real failures the user should know about (a remote MCP +// server returning 4xx, an MCP binary missing, ...). For positive notices +// (a previously-failed toolset becoming available again) use AddToolNotice +// instead so the message isn't framed as a failure. +func (a *Agent) AddToolWarning(msg string) { if msg == "" { return } @@ -297,6 +303,19 @@ func (a *Agent) addToolWarning(msg string) { a.warningsMu.Unlock() } +// AddToolNotice records a positive, informational notice about a toolset +// (typically: a previously-failed toolset is now available). Notices are +// surfaced to the user separately from warnings so the framing doesn't +// say "failed to initialize" for a recovery message. +func (a *Agent) AddToolNotice(msg string) { + if msg == "" { + return + } + a.warningsMu.Lock() + a.pendingNotices = append(a.pendingNotices, msg) + a.warningsMu.Unlock() +} + // DrainWarnings returns pending warnings and clears them. func (a *Agent) DrainWarnings() []string { a.warningsMu.Lock() @@ -306,6 +325,15 @@ func (a *Agent) DrainWarnings() []string { return warnings } +// DrainNotices returns pending notices and clears them. +func (a *Agent) DrainNotices() []string { + a.warningsMu.Lock() + defer a.warningsMu.Unlock() + notices := a.pendingNotices + a.pendingNotices = nil + return notices +} + func (a *Agent) StopToolSets(ctx context.Context) error { for _, toolSet := range a.toolsets { // Only stop toolsets that were successfully started diff --git a/pkg/agent/agent_test.go b/pkg/agent/agent_test.go index 63790795b..66d3d2e3e 100644 --- a/pkg/agent/agent_test.go +++ b/pkg/agent/agent_test.go @@ -253,21 +253,23 @@ func TestAgentReProbeEmitsWarningThenNotice(t *testing.T) { } a := New("root", "test", WithToolSets(stub)) - // Turn 1: start fails → 1 warning, 0 tools. + // Turn 1: start fails → 1 warning, 0 tools, no notice yet. got, err := a.Tools(t.Context()) require.NoError(t, err) assert.Empty(t, got, "turn 1: no tools while toolset is unavailable") warnings := a.DrainWarnings() require.Len(t, warnings, 1, "turn 1: exactly one warning expected") assert.Contains(t, warnings[0], "start failed") + assert.Empty(t, a.DrainNotices(), "turn 1: no recovery notice while still failing") - // Turn 2: start succeeds → 1 recovery warning, tools available. + // Turn 2: start succeeds → 1 recovery NOTICE (not warning), tools available. got, err = a.Tools(t.Context()) require.NoError(t, err) assert.Len(t, got, 1, "turn 2: tool should be available after recovery") - recovery := a.DrainWarnings() - require.Len(t, recovery, 1, "turn 2: exactly one recovery warning expected") - assert.Contains(t, recovery[0], "now available", "turn 2: recovery warning must mention availability") + assert.Empty(t, a.DrainWarnings(), "turn 2: recovery is a notice, not a failure warning") + notices := a.DrainNotices() + require.Len(t, notices, 1, "turn 2: exactly one recovery notice expected") + assert.Contains(t, notices[0], "now available", "turn 2: recovery notice must mention availability") } // TestAgentNoDuplicateStartWarnings verifies that repeated failures generate @@ -318,7 +320,7 @@ func TestAgentWarningsConcurrentAccess(t *testing.T) { go func() { defer wg.Done() for range perWriter { - a.addToolWarning("boom") + a.AddToolWarning("boom") } }() } diff --git a/pkg/agent/opts.go b/pkg/agent/opts.go index e66a598c6..e2b4aa344 100644 --- a/pkg/agent/opts.go +++ b/pkg/agent/opts.go @@ -162,7 +162,7 @@ func WithCommands(commands types.Commands) Opt { func WithLoadTimeWarnings(warnings []string) Opt { return func(a *Agent) { for _, w := range warnings { - a.addToolWarning(w) + a.AddToolWarning(w) } } } diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go index 03edea4e5..a067abf56 100644 --- a/pkg/runtime/loop.go +++ b/pkg/runtime/loop.go @@ -643,15 +643,23 @@ func (r *LocalRuntime) configureToolsetHandlers(a *agent.Agent, events chan Even } } -// emitAgentWarnings drains and emits any pending toolset warnings as persistent -// TUI notifications. Both start failures and recovery notices are emitted as -// warnings so they remain visible until the user dismisses them. +// emitAgentWarnings drains and emits any pending toolset warnings and notices +// as persistent TUI notifications. Failures ("start failed", "list failed") +// and recoveries ("is now available") flow through separate queues on the +// agent so each can be framed correctly — a single mixed message saying +// "Some toolsets failed to initialize ... is now available" reads as a +// contradiction. func (r *LocalRuntime) emitAgentWarnings(a *agent.Agent, send func(Event)) { warnings := a.DrainWarnings() if len(warnings) > 0 { slog.Warn("Tool setup partially failed; continuing", "agent", a.Name(), "warnings", warnings) send(Warning(formatToolWarning(a, warnings), a.Name())) } + notices := a.DrainNotices() + if len(notices) > 0 { + slog.Info("Toolset status update", "agent", a.Name(), "notices", notices) + send(Warning(formatToolNotice(a, notices), a.Name())) + } } func formatToolWarning(a *agent.Agent, warnings []string) string { @@ -663,6 +671,15 @@ func formatToolWarning(a *agent.Agent, warnings []string) string { return strings.TrimSuffix(builder.String(), "\n") } +func formatToolNotice(a *agent.Agent, notices []string) string { + var builder strings.Builder + fmt.Fprintf(&builder, "Toolset status update for agent '%s':\n\n", a.Name()) + for _, notice := range notices { + fmt.Fprintf(&builder, "- %s\n", notice) + } + return strings.TrimSuffix(builder.String(), "\n") +} + // filterExcludedTools removes tools whose names appear in the excluded list. // This is used by skill sub-sessions to prevent recursive run_skill calls. func filterExcludedTools(agentTools []tools.Tool, excluded []string) []tools.Tool { diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index a72c4a14c..7f4f51b67 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -836,6 +836,76 @@ func TestProcessToolCalls_UnknownTool_ReturnsErrorResponse(t *testing.T) { assert.Contains(t, toolContent, "not available") } +// TestEmitAgentWarnings_RecoveryNoticeIsNotFramedAsFailure verifies that +// when a previously-failed toolset recovers ("is now available"), the +// emitted WarningEvent is framed neutrally rather than wrapped in the +// "Some toolsets failed to initialize" framing used for real failures. +// +// Regression test for: after lazy-init OAuth completes and the toolset +// reconnects, the user saw a notification that read +// +// "Some toolsets failed to initialize for agent 'root'. +// Details: +// - mcp(remote host=mcp.slack.com transport=streamable) is now available" +// +// The body says "is now available" (success) but the framing says "failed" +// (failure), which the user reasonably read as "is NOT available". +func TestEmitAgentWarnings_RecoveryNoticeIsNotFramedAsFailure(t *testing.T) { + prov := &mockProvider{id: "test/m", stream: &mockStream{}} + root := agent.New("root", "agent", agent.WithModel(prov)) + tm := team.New(team.WithAgents(root)) + rt, err := NewLocalRuntime(tm, WithCurrentAgent("root"), WithModelStore(mockModelStore{})) + require.NoError(t, err) + + root.AddToolNotice("mcp(remote host=mcp.slack.com transport=streamable) is now available") + + var emitted []Event + rt.emitAgentWarnings(root, func(e Event) { emitted = append(emitted, e) }) + + require.Len(t, emitted, 1, "expected exactly one event from a single notice") + w, ok := emitted[0].(*WarningEvent) + require.True(t, ok, "expected a *WarningEvent, got %T", emitted[0]) + + assert.NotContains(t, strings.ToLower(w.Message), "failed", + "recovery notice must not be framed as a failure; got: %q", w.Message) + assert.Contains(t, w.Message, "is now available", + "recovery notice must preserve the actual content; got: %q", w.Message) +} + +// TestEmitAgentWarnings_FailureAndRecoveryAreEmittedSeparately verifies that +// when both a real failure and a recovery notice are pending, they are +// emitted as two separate events with appropriate framing each — not +// merged into a single message that conflates them. +func TestEmitAgentWarnings_FailureAndRecoveryAreEmittedSeparately(t *testing.T) { + prov := &mockProvider{id: "test/m", stream: &mockStream{}} + root := agent.New("root", "agent", agent.WithModel(prov)) + tm := team.New(team.WithAgents(root)) + rt, err := NewLocalRuntime(tm, WithCurrentAgent("root"), WithModelStore(mockModelStore{})) + require.NoError(t, err) + + root.AddToolWarning("toolset_a start failed: connection refused") + root.AddToolNotice("toolset_b is now available") + + var emitted []*WarningEvent + rt.emitAgentWarnings(root, func(e Event) { + if w, ok := e.(*WarningEvent); ok { + emitted = append(emitted, w) + } + }) + + require.Len(t, emitted, 2, "failures and notices must be emitted as separate events") + + // Identify which is which (order is failure first, then notice). + failure, notice := emitted[0], emitted[1] + assert.Contains(t, strings.ToLower(failure.Message), "failed", + "failure event must use the failure framing; got: %q", failure.Message) + assert.Contains(t, failure.Message, "toolset_a start failed: connection refused") + + assert.NotContains(t, strings.ToLower(notice.Message), "failed", + "recovery event must NOT use the failure framing; got: %q", notice.Message) + assert.Contains(t, notice.Message, "toolset_b is now available") +} + func TestEmitStartupInfo(t *testing.T) { // Create a simple agent with mock provider prov := &mockProvider{id: "test/startup-model", stream: &mockStream{}} From 68372ee5cd36ac9d8bd62dd37f3282dead11e5a2 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Sat, 25 Apr 2026 11:32:46 +0200 Subject: [PATCH 2/3] feat(mcp): defer OAuth elicitation outside interactive context Running 'docker agent run ./examples/slack.yaml' could hang during startup before the TUI was even ready: the runtime was eagerly starting toolsets to populate sidebar tool counts, the remote MCP client got back 401 + WWW-Authenticate, and the OAuth flow tried to surface an elicitation dialog. The TUI hadn't rendered yet, the elicitation goroutine was blocked on a synchronous channel send, and Ctrl-C couldn't reach it. This change introduces a per-context flag that lets callers opt out of interactive flows: - WithoutInteractivePrompts(ctx) marks a context as non-interactive. - The oauthTransport checks the flag before triggering the OAuth elicitation on a 401 and short-circuits with a recognisable AuthorizationRequiredError instead of blocking. - IsAuthorizationRequired(err) lets callers distinguish 'OAuth was deferred' from a real failure. Because the MCP SDK wraps transport errors with %v (not %w), enrichConnectError reads the deferred-auth flag back off the transport in remote.go and re-emits a clean AuthorizationRequiredError. The runtime wraps the startup tool-loading context with this helper, so toolsets that need OAuth are detected immediately and silently deferred to the first RunStream where the user is interacting and a dialog can be rendered. Real start failures still surface as warnings via the agent's warning channel (a.AddToolWarning), and freshFailure is preserved for the deferred-OAuth case so the *real* failure on the eventual interactive retry isn't suppressed by the once-per-streak guard. Assisted-By: docker-agent --- pkg/runtime/runtime.go | 51 ++++++- pkg/runtime/runtime_test.go | 253 +++++++++++++++++++++++++++++++++++ pkg/tools/mcp/interactive.go | 60 +++++++++ pkg/tools/mcp/oauth.go | 35 ++++- pkg/tools/mcp/oauth_test.go | 52 +++++++ pkg/tools/mcp/remote.go | 43 ++++-- pkg/tools/mcp/remote_test.go | 46 +++++++ 7 files changed, 523 insertions(+), 17 deletions(-) create mode 100644 pkg/tools/mcp/interactive.go diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 711607615..0025d2631 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -926,12 +926,23 @@ func (r *LocalRuntime) EmitStartupInfo(ctx context.Context, sess *session.Sessio send(NewTokenUsageEvent(sess.ID, r.CurrentAgentName(), usage)) } - // Emit agent warnings (if any) - these are quick + // Tool loading can be slow (MCP servers need to start). Mark the + // context as non-interactive so toolsets that require user-driven + // flows (e.g. an OAuth elicitation for a remote MCP server) fail + // fast with a recognisable error rather than blocking on a dialog + // the TUI is not yet ready to render. The actual prompt happens on + // the first RunStream when the user is interacting with the agent. + nonInteractiveCtx := mcptools.WithoutInteractivePrompts(ctx) + r.emitToolsProgressively(nonInteractiveCtx, a, send) + + // Flush any agent warnings: load-time warnings recorded at agent + // construction (WithLoadTimeWarnings) and per-toolset warnings recorded + // during startup above (e.g. a remote MCP server returning 4xx during + // initialize). Surfacing them as WarningEvents lets the TUI show a + // persistent notice with the actual server-side explanation — otherwise + // the user only sees the toolset disappear from the sidebar with no clue + // as to why. r.emitAgentWarnings(a, func(e Event) { send(e) }) - - // Tool loading can be slow (MCP servers need to start) - // Emit progressive updates as each toolset loads - r.emitToolsProgressively(ctx, a, send) } // emitToolsProgressively loads tools from each toolset and emits progress updates. @@ -966,7 +977,35 @@ func (r *LocalRuntime) emitToolsProgressively(ctx context.Context, a *agent.Agen if startable, ok := toolset.(*tools.StartableToolSet); ok { if !startable.IsStarted() { if err := startable.Start(ctx); err != nil { - slog.Warn("Toolset start failed; skipping", "agent", a.Name(), "toolset", fmt.Sprintf("%T", startable.ToolSet), "error", err) + desc := tools.DescribeToolSet(startable.ToolSet) + // IsAuthorizationRequired must be checked BEFORE + // ShouldReportFailure: this is the first — expected — + // failure of a deferred-OAuth toolset, and consuming + // freshFailure here would suppress the *real* + // failure (e.g. server 4xx on the eventual interactive + // retry) that the user actually needs to see. + if mcptools.IsAuthorizationRequired(err) { + // The toolset just needs an OAuth approval that we + // deliberately deferred until the user is interacting + // with the agent. The dialog will appear naturally on + // the first RunStream — no need to pre-announce it. + slog.Debug("Toolset deferred until first message", "agent", a.Name(), "toolset", desc, "reason", err) + continue + } + // Route real failures through the agent's warning + // channel so the TUI surfaces a persistent, + // user-visible notice that includes the actual + // server-side cause (threaded through by + // remoteMCPClient.Initialize). Use the same + // once-per-streak guard as ensureToolSetsAreStarted + // so a failing toolset doesn't flood the UI with a + // new warning every time the agent is restarted. + if !startable.ShouldReportFailure() { + slog.Debug("Toolset still unavailable; skipping", "agent", a.Name(), "toolset", desc, "error", err) + continue + } + slog.Warn("Toolset start failed; skipping", "agent", a.Name(), "toolset", desc, "error", err) + a.AddToolWarning(fmt.Sprintf("%s start failed: %v", desc, err)) continue } } diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index 7f4f51b67..0bbcd16ef 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -8,6 +8,7 @@ import ( "strings" "sync" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -23,6 +24,7 @@ import ( "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/team" "github.com/docker/docker-agent/pkg/tools" + mcptools "github.com/docker/docker-agent/pkg/tools/mcp" ) type stubToolSet struct { @@ -836,6 +838,257 @@ func TestProcessToolCalls_UnknownTool_ReturnsErrorResponse(t *testing.T) { assert.Contains(t, toolContent, "not available") } +// oauthAwareToolSet simulates a remote MCP toolset that needs an elicitation +// handler and the managed-OAuth flag configured before Start() runs. The +// Slack-MCP bug reported by users shows up exactly when Start() triggers an +// OAuth flow with neither handler installed, so this test captures the +// handler state at the moment Start() is entered. +type oauthAwareToolSet struct { + mu sync.Mutex + elicitationHandler tools.ElicitationHandler + managedOAuth bool + managedOAuthSet bool + started bool + startHandlerCaptured tools.ElicitationHandler + startManagedCaptured bool + startManagedWasSet bool +} + +// Verify interface compliance +var ( + _ tools.ToolSet = (*oauthAwareToolSet)(nil) + _ tools.Startable = (*oauthAwareToolSet)(nil) + _ tools.Elicitable = (*oauthAwareToolSet)(nil) + _ tools.OAuthCapable = (*oauthAwareToolSet)(nil) +) + +func (s *oauthAwareToolSet) Start(context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + s.started = true + // Snapshot the handler state at the moment Start runs — this is what + // the OAuth flow would see when it tries to prompt the user. + s.startHandlerCaptured = s.elicitationHandler + s.startManagedCaptured = s.managedOAuth + s.startManagedWasSet = s.managedOAuthSet + return nil +} + +func (s *oauthAwareToolSet) Stop(context.Context) error { return nil } + +func (s *oauthAwareToolSet) Tools(context.Context) ([]tools.Tool, error) { + return nil, nil +} + +func (s *oauthAwareToolSet) SetElicitationHandler(h tools.ElicitationHandler) { + s.mu.Lock() + defer s.mu.Unlock() + s.elicitationHandler = h +} + +func (s *oauthAwareToolSet) SetOAuthSuccessHandler(func()) {} + +func (s *oauthAwareToolSet) SetManagedOAuth(managed bool) { + s.mu.Lock() + defer s.mu.Unlock() + s.managedOAuth = managed + s.managedOAuthSet = true +} + +// TestEmitStartupInfo_DoesNotBlockOnInteractiveOAuth verifies that the +// startup path does NOT trigger interactive flows on toolsets. In particular: +// +// - EmitStartupInfo must complete promptly even when a toolset's Start() +// would normally prompt the user (e.g. an OAuth elicitation for a remote +// MCP server). +// - The runtime's elicitation/OAuth handlers must not be wired into the +// toolset during startup; the OAuth dialog only makes sense once the +// user is interacting with the agent. +// +// Regression test for: "docker agent run ./examples/slack.yaml" hanging +// before the TUI was even ready, with Ctrl-C unable to interrupt because +// the OAuth elicitation was synchronously blocked on a TUI dialog that the +// app hadn't started yet. The fix marks the startup context with +// mcptools.WithoutInteractivePrompts and defers OAuth to the first +// RunStream call. +func TestEmitStartupInfo_DoesNotBlockOnInteractiveOAuth(t *testing.T) { + prov := &mockProvider{id: "test/startup-model", stream: &mockStream{}} + + oauthTS := &oauthAwareToolSet{} + + root := agent.New("root", "agent", + agent.WithModel(prov), + agent.WithToolSets(oauthTS), + ) + tm := team.New(team.WithAgents(root)) + + rt, err := NewLocalRuntime(tm, WithCurrentAgent("root"), WithModelStore(mockModelStore{})) + require.NoError(t, err) + + events := make(chan Event, 20) + + done := make(chan struct{}) + go func() { + rt.EmitStartupInfo(t.Context(), nil, events) + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("EmitStartupInfo blocked: it must complete promptly even for toolsets that need OAuth") + } + close(events) + for range events { + } + + oauthTS.mu.Lock() + defer oauthTS.mu.Unlock() + + require.True(t, oauthTS.started, "toolset should still be started during EmitStartupInfo (just not interactively)") + + // During startup, no interactive plumbing should be wired up. OAuth and + // elicitation are deferred to the first RunStream call where the user + // is actively interacting with the agent. + require.Nil(t, oauthTS.startHandlerCaptured, + "elicitation handler must NOT be set during startup; OAuth is deferred until the user sends a message") + require.False(t, oauthTS.startManagedWasSet, + "managed-OAuth flag must NOT be set during startup") +} + +// TestEmitStartupInfo_SurfacesToolsetStartFailureAsWarning verifies that +// when a toolset fails to start during EmitStartupInfo, the failure is +// emitted as a WarningEvent on the events channel so the TUI can show +// the user the actual cause — not just silently drop the toolset. +// +// Without this, a remote MCP server returning a 4xx during initialize +// (e.g. Slack's "App is not enabled for Slack MCP server access") +// disappears from the sidebar with only a debug-log trace, leaving the +// user with no hint about what went wrong. +func TestEmitStartupInfo_SurfacesToolsetStartFailureAsWarning(t *testing.T) { + prov := &mockProvider{id: "test/startup-model", stream: &mockStream{}} + + // A toolset whose Start() always fails with a rich, provider-specific + // message — mimicking the error returned by remoteMCPClient.Initialize + // after it has been enriched with the server's own explanation. + failingTS := newStubToolSet( + errors.New("failed to initialize MCP client: failed to connect to MCP server: sending \"initialize\": Bad Request (server responded 400: App is not enabled for Slack MCP server access.)"), + nil, + nil, + ) + + root := agent.New("root", "agent", + agent.WithModel(prov), + agent.WithToolSets(failingTS), + ) + tm := team.New(team.WithAgents(root)) + + rt, err := NewLocalRuntime(tm, WithCurrentAgent("root"), WithModelStore(mockModelStore{})) + require.NoError(t, err) + + events := make(chan Event, 32) + rt.EmitStartupInfo(t.Context(), nil, events) + close(events) + + var warning *WarningEvent + for e := range events { + if w, ok := e.(*WarningEvent); ok { + warning = w + } + } + + require.NotNil(t, warning, "EmitStartupInfo should emit a WarningEvent when a toolset fails to start") + assert.Contains(t, warning.Message, "App is not enabled for Slack MCP server access.", + "warning should include the toolset's actual error message so the user can see the real cause") +} + +// TestEmitStartupInfo_AuthRequiredIsSilent verifies that when a toolset's +// Start() returns an mcptools.IsAuthorizationRequired error — the runtime +// deliberately deferred OAuth until the user is interacting — the user +// sees no warning event for it. The OAuth dialog will appear naturally on +// the first RunStream, so a pre-announcement would just be noise. +func TestEmitStartupInfo_AuthRequiredIsSilent(t *testing.T) { + prov := &mockProvider{id: "test/startup-model", stream: &mockStream{}} + + deferralErr := &mcptools.AuthorizationRequiredError{URL: "https://example.test/mcp"} + require.True(t, mcptools.IsAuthorizationRequired(deferralErr), + "sanity: AuthorizationRequiredError must be detected by IsAuthorizationRequired") + + failingTS := newStubToolSet(deferralErr, nil, nil) + + root := agent.New("root", "agent", + agent.WithModel(prov), + agent.WithToolSets(failingTS), + ) + tm := team.New(team.WithAgents(root)) + + rt, err := NewLocalRuntime(tm, WithCurrentAgent("root"), WithModelStore(mockModelStore{})) + require.NoError(t, err) + + events := make(chan Event, 32) + rt.EmitStartupInfo(t.Context(), nil, events) + close(events) + + for e := range events { + if w, ok := e.(*WarningEvent); ok { + t.Fatalf("deferred-OAuth must not produce a WarningEvent (would be redundant noise); got: %q", w.Message) + } + } +} + +// TestEmitStartupInfo_DeferredAuthPreservesFreshFailureFlag verifies that +// when a toolset's Start fails with AuthorizationRequiredError during the +// non-interactive startup phase, the StartableToolSet's freshFailure flag is +// LEFT INTACT — not silently consumed by the "is this the first failure?" +// check. +// +// Why this matters: the deferred-OAuth case is an *expected*, transient +// failure. The first user-visible failure that should produce a warning is +// whatever happens on the eventual interactive retry (e.g. "server +// responded 400: App is not enabled for Slack MCP server access"). If the +// flag is consumed during startup, the StartableToolSet's once-per-streak +// guard fires for the deferred case and silently swallows the real cause, +// leaving the user staring at "0 tools" with nothing in the UI explaining +// why. +func TestEmitStartupInfo_DeferredAuthPreservesFreshFailureFlag(t *testing.T) { + prov := &mockProvider{id: "test/startup-model", stream: &mockStream{}} + + deferralErr := &mcptools.AuthorizationRequiredError{URL: "https://example.test/mcp"} + failingTS := newStubToolSet(deferralErr, nil, nil) + + root := agent.New("root", "agent", + agent.WithModel(prov), + agent.WithToolSets(failingTS), + ) + tm := team.New(team.WithAgents(root)) + + rt, err := NewLocalRuntime(tm, WithCurrentAgent("root"), WithModelStore(mockModelStore{})) + require.NoError(t, err) + + events := make(chan Event, 32) + rt.EmitStartupInfo(t.Context(), nil, events) + close(events) + for range events { + } + + // Locate the StartableToolSet wrapping our stub so we can probe its + // internal state (the public API uses ShouldReportFailure as both the + // query and the consume operation). + var wrapped *tools.StartableToolSet + for _, ts := range root.ToolSets() { + if s, ok := ts.(*tools.StartableToolSet); ok { + wrapped = s + break + } + } + require.NotNil(t, wrapped, "agent.ToolSets() should return a *StartableToolSet wrapper") + + require.True(t, wrapped.ShouldReportFailure(), + "deferred-OAuth must NOT consume freshFailure during EmitStartupInfo: "+ + "otherwise the next real failure (Slack 4xx after OAuth, etc.) is silently dropped "+ + "and the user sees zero tools with no explanation") +} + // TestEmitAgentWarnings_RecoveryNoticeIsNotFramedAsFailure verifies that // when a previously-failed toolset recovers ("is now available"), the // emitted WarningEvent is framed neutrally rather than wrapped in the diff --git a/pkg/tools/mcp/interactive.go b/pkg/tools/mcp/interactive.go new file mode 100644 index 000000000..4dd145c94 --- /dev/null +++ b/pkg/tools/mcp/interactive.go @@ -0,0 +1,60 @@ +package mcp + +import ( + "context" + "errors" +) + +// noInteractivePromptsKey is the unexported key used to attach the +// "skip interactive prompts" flag to a context. +type noInteractivePromptsKey struct{} + +// WithoutInteractivePrompts returns a context that asks the MCP transport +// stack to refuse any flow that would require user input. The canonical +// example is OAuth: a remote MCP server's first contact is typically a 401 +// Unauthorized that triggers an interactive elicitation flow ("approve OAuth +// authorization?"). During startup the TUI is not yet ready to surface that +// dialog, the user has no input field, and Ctrl-C cannot reach the elicitation +// goroutine because it is blocked on a synchronous send/receive. +// +// Callers that prepare data eagerly (sidebar tool counts, dry-runs, health +// checks) should wrap their context with this helper so toolset Start() +// returns a meaningful error immediately instead of hanging the process. +// +// Once a real user interaction is in progress (RunStream), the context +// should NOT carry this value so the user can complete OAuth normally. +func WithoutInteractivePrompts(ctx context.Context) context.Context { + return context.WithValue(ctx, noInteractivePromptsKey{}, true) +} + +// interactivePromptsAllowed reports whether the context allows blocking on +// user-driven flows. The default is true so existing callers (RunStream, +// tests) keep working without changes. +func interactivePromptsAllowed(ctx context.Context) bool { + v, _ := ctx.Value(noInteractivePromptsKey{}).(bool) + return !v +} + +// AuthorizationRequiredError is returned by the transport when an OAuth +// elicitation would be needed but the context disallows interactive prompts +// (see WithoutInteractivePrompts). Callers can detect it with +// IsAuthorizationRequired and decide how (or whether) to surface it. +// +// The exported type is also useful in tests that want to simulate the +// deferred-OAuth path without spinning up a real HTTP server. +type AuthorizationRequiredError struct { + URL string +} + +func (e *AuthorizationRequiredError) Error() string { + return e.URL + " requires interactive OAuth authorization" +} + +// IsAuthorizationRequired reports whether err (or any error wrapped by it) +// signals that the toolset failed to start because OAuth is needed and the +// caller chose to defer the prompt. Callers can use this to render a softer, +// "needs auth" notice instead of a red error. +func IsAuthorizationRequired(err error) bool { + var target *AuthorizationRequiredError + return errors.As(err, &target) +} diff --git a/pkg/tools/mcp/oauth.go b/pkg/tools/mcp/oauth.go index 2b2bdfe7f..af66936e8 100644 --- a/pkg/tools/mcp/oauth.go +++ b/pkg/tools/mcp/oauth.go @@ -184,11 +184,18 @@ type oauthTransport struct { managed bool oauthConfig *latest.RemoteOAuthConfig - // mu protects refreshFailedAt from concurrent access. + // mu protects refreshFailedAt and lastAuthRequired from concurrent access. mu sync.Mutex // refreshFailedAt tracks the last time a silent token refresh failed, // so we avoid retrying on every request. refreshFailedAt time.Time + // lastAuthRequired records when the transport short-circuited an + // interactive OAuth flow because the request context disallowed + // prompts (see WithoutInteractivePrompts). The MCP SDK wraps transport + // errors with %v, breaking errors.As, so callers must use this field + // instead of unwrapping to know that OAuth was deferred rather than + // failed for some other reason. + lastAuthRequired bool } func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) { @@ -221,6 +228,22 @@ func (t *oauthTransport) roundTrip(req *http.Request, isRetry bool) (*http.Respo if resp.StatusCode == http.StatusUnauthorized && !isRetry { wwwAuth := resp.Header.Get("WWW-Authenticate") if wwwAuth != "" { + // If the caller asked for non-interactive operation (e.g. the + // runtime is populating sidebar tool counts during startup), + // don't block on an OAuth elicitation that the TUI is not yet + // ready to surface. Surface a recognisable error instead so + // the toolset can be flagged "needs auth" without freezing + // the agent and without making Ctrl-C wait for a user response + // that will never come. + if !interactivePromptsAllowed(req.Context()) { + slog.Debug("Skipping OAuth elicitation in non-interactive context", "url", t.baseURL) + resp.Body.Close() + t.mu.Lock() + t.lastAuthRequired = true + t.mu.Unlock() + return nil, &AuthorizationRequiredError{URL: t.baseURL} + } + resp.Body.Close() authServer := req.URL.Scheme + "://" + req.URL.Host @@ -239,6 +262,16 @@ func (t *oauthTransport) roundTrip(req *http.Request, isRetry bool) (*http.Respo return resp, nil } +// authorizationRequired reports whether the transport short-circuited an +// interactive OAuth flow because the request context disallowed prompts. +// Callers can use this to recognise the deferred-OAuth case even though +// the MCP SDK destroys the underlying error chain by wrapping with %v. +func (t *oauthTransport) authorizationRequired() bool { + t.mu.Lock() + defer t.mu.Unlock() + return t.lastAuthRequired +} + // getValidToken returns a non-expired token for the server, silently refreshing // an expired token when a refresh token is available. Returns nil if no usable // token can be obtained. diff --git a/pkg/tools/mcp/oauth_test.go b/pkg/tools/mcp/oauth_test.go index 573203898..604751310 100644 --- a/pkg/tools/mcp/oauth_test.go +++ b/pkg/tools/mcp/oauth_test.go @@ -5,6 +5,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" "time" ) @@ -365,3 +366,54 @@ func TestCallbackServer_RejectsCallbackBeforeStateSet(t *testing.T) { t.Errorf("expected 400, got %d — callback accepted without expected state set", resp.StatusCode) } } + +// TestOAuthTransport_NonInteractiveCtxSkipsElicitation verifies that when +// the request context is marked non-interactive (via WithoutInteractivePrompts), +// a 401 with WWW-Authenticate does NOT trigger the OAuth flow. Instead the +// transport returns a recognisable AuthorizationRequiredError, so callers can +// surface a deferred-auth notice without the goroutine getting stuck on a +// dialog the UI is not yet ready to show. +// +// We deliberately leave the transport's `client` field nil: in non-interactive +// mode the short-circuit must happen before anything in the OAuth flow +// (which would dereference `client` to send an elicitation) is reached. A +// nil-pointer panic here would be a clear, loud signal that the contract +// is broken. +// +// Regression test for: "docker agent run ./examples/slack.yaml" hanging +// during startup, with Ctrl-C unable to interrupt because the OAuth +// elicitation was synchronously waiting on a TUI prompt that hadn't been +// rendered yet. +func TestOAuthTransport_NonInteractiveCtxSkipsElicitation(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("WWW-Authenticate", `Bearer resource="https://example.test/.well-known/oauth-protected-resource"`) + w.WriteHeader(http.StatusUnauthorized) + })) + defer srv.Close() + + transport := &oauthTransport{ + base: http.DefaultTransport, + tokenStore: NewInMemoryTokenStore(), + baseURL: srv.URL, + // client intentionally left nil — see test comment above. + } + + ctx := WithoutInteractivePrompts(t.Context()) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, srv.URL, strings.NewReader("{}")) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + + resp, gotErr := transport.RoundTrip(req) + if resp != nil { + resp.Body.Close() + } + + if gotErr == nil { + t.Fatalf("expected an error in non-interactive mode, got resp=%v err=nil", resp) + } + if !IsAuthorizationRequired(gotErr) { + t.Errorf("expected IsAuthorizationRequired(err)=true, got err=%v", gotErr) + } +} diff --git a/pkg/tools/mcp/remote.go b/pkg/tools/mcp/remote.go index 53abc858b..de5269f47 100644 --- a/pkg/tools/mcp/remote.go +++ b/pkg/tools/mcp/remote.go @@ -40,8 +40,12 @@ func newRemoteClient(url, transportType string, headers map[string]string, token } func (c *remoteMCPClient) Initialize(ctx context.Context, _ *gomcp.InitializeRequest) (*gomcp.InitializeResult, error) { - // Create HTTP client with OAuth support - httpClient := c.createHTTPClient() + // Create HTTP client with OAuth support. We keep a reference to the + // oauthTransport so we can recognise the deferred-OAuth case (the + // transport returned an AuthorizationRequiredError because the request + // context disallowed prompts) and re-emit a clean + // AuthorizationRequiredError that callers can detect with errors.As. + httpClient, oauthT := c.createHTTPClient() var transport gomcp.Transport @@ -80,7 +84,7 @@ func (c *remoteMCPClient) Initialize(ctx context.Context, _ *gomcp.InitializeReq // Connect to the MCP server session, err := client.Connect(ctx, transport, nil) if err != nil { - return nil, fmt.Errorf("failed to connect to MCP server: %w", err) + return nil, enrichConnectError(err, oauthT) } c.setSession(session) @@ -89,6 +93,23 @@ func (c *remoteMCPClient) Initialize(ctx context.Context, _ *gomcp.InitializeReq return session.InitializeResult(), nil } +// enrichConnectError wraps the error returned by client.Connect so callers +// can distinguish the deferred-OAuth case from a real failure. +// +// The MCP SDK uses fmt.Errorf("%w: %v", …) when it surfaces transport errors, +// which means the original error is included as text only — not in the unwrap +// chain — so we can't rely on errors.As against the SDK-wrapped error. +// Instead we read the deferred-auth flag back off the transport and re-emit +// a clean AuthorizationRequiredError. +// +// Pre: err != nil and t != nil; only called from the Connect failure path. +func enrichConnectError(err error, t *oauthTransport) error { + if t.authorizationRequired() { + return &AuthorizationRequiredError{URL: t.baseURL} + } + return fmt.Errorf("failed to connect to MCP server: %w", err) +} + // SetManagedOAuth sets whether OAuth should be handled in managed mode. // In managed mode, the client handles the OAuth flow instead of the server. func (c *remoteMCPClient) SetManagedOAuth(managed bool) { @@ -100,12 +121,16 @@ func (c *remoteMCPClient) SetManagedOAuth(managed bool) { // createHTTPClient creates an HTTP client with custom headers and OAuth support. // Header values may contain ${headers.NAME} placeholders that are resolved // at request time from upstream headers stored in the request context. -func (c *remoteMCPClient) createHTTPClient() *http.Client { - transport := c.headerTransport() +// +// The oauthTransport is returned alongside the client so callers can inspect +// the transport's state (e.g. whether OAuth was deferred) when Connect() +// returns and we need to surface the actual cause of the failure. +func (c *remoteMCPClient) createHTTPClient() (*http.Client, *oauthTransport) { + base := c.headerTransport() // Then wrap with OAuth support - transport = &oauthTransport{ - base: transport, + oauthT := &oauthTransport{ + base: base, client: c, tokenStore: c.tokenStore, baseURL: c.url, @@ -113,9 +138,7 @@ func (c *remoteMCPClient) createHTTPClient() *http.Client { oauthConfig: c.oauthConfig, } - return &http.Client{ - Transport: transport, - } + return &http.Client{Transport: oauthT}, oauthT } func (c *remoteMCPClient) headerTransport() http.RoundTripper { diff --git a/pkg/tools/mcp/remote_test.go b/pkg/tools/mcp/remote_test.go index f41a54143..3f54dc603 100644 --- a/pkg/tools/mcp/remote_test.go +++ b/pkg/tools/mcp/remote_test.go @@ -1,6 +1,7 @@ package mcp import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -180,3 +181,48 @@ func TestRemoteClientEmptyHeaders(t *testing.T) { t.Fatal("Server did not receive request within timeout") } } + +// TestInitialize_NonInteractiveCtxDefersOAuthAndDoesNotBlock verifies that +// when Initialize runs against a server that requires OAuth (responds with +// 401 + WWW-Authenticate) under a context flagged with +// WithoutInteractivePrompts, the call: +// +// - returns promptly, +// - returns an error that satisfies IsAuthorizationRequired, +// - never opens a callback HTTP server (i.e. doesn't try to bind a port). +// +// Regression test for: "docker agent run ./examples/slack.yaml" hanging +// during startup. The TUI was not yet ready to render the OAuth dialog, +// the elicitation goroutine was blocked on a synchronous channel send, +// and Ctrl-C couldn't reach it. +func TestInitialize_NonInteractiveCtxDefersOAuthAndDoesNotBlock(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("WWW-Authenticate", `Bearer resource="https://example.test/.well-known/oauth-protected-resource"`) + w.WriteHeader(http.StatusUnauthorized) + })) + defer server.Close() + + client := newRemoteClient(server.URL, "streamable", nil, NewInMemoryTokenStore(), nil) + + ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second) + defer cancel() + + nonInteractiveCtx := WithoutInteractivePrompts(ctx) + + done := make(chan error, 1) + go func() { + _, err := client.Initialize(nonInteractiveCtx, nil) + done <- err + }() + + select { + case err := <-done: + require.Error(t, err, "Initialize should fail with a deferred-auth error in non-interactive ctx") + assert.True(t, IsAuthorizationRequired(err), + "non-interactive Initialize should return IsAuthorizationRequired, got: %v", err) + case <-ctx.Done(): + t.Fatalf("Initialize blocked for too long; non-interactive ctx must short-circuit OAuth: %v", ctx.Err()) + } +} From 01fad8aff49effa9046d91259a5769b07b1ed113 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Sat, 25 Apr 2026 11:33:32 +0200 Subject: [PATCH 3/3] fix(mcp/oauth): support Slack token responses, surface server errors, re-auth on scope changes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three related fixes to the remote MCP OAuth flow, motivated by the Slack MCP server (https://mcp.slack.com/mcp). - Accept Slack's oauth.v2.user.access response shape in addition to the standard RFC 6749 §5.1 flat shape. Slack nests the user token inside an authed_user object and signals application-level failures via HTTP 200 + {"ok":false}. Previously the flat decoder silently produced an empty bearer, and every subsequent MCP request was rejected with 401. Decoding is factored out into parseTokenResponse, used by both ExchangeCodeForToken and RefreshAccessToken, and we now error out explicitly when no access_token can be located. - Track the configured scope list on each stored token as a new RequestedScopes field. In getValidToken, if the configured scopes are no longer covered by the stored token, purge it from the store and return nil so the next request triggers a fresh OAuth flow. The scope list is preserved across silent refreshes, and legacy tokens without RequestedScopes are left untouched to avoid forcing re-auth on upgrade. - Log the response body (capped at 2 KiB) when an authenticated retry after a successful OAuth flow is still rejected with a non-2xx. The MCP SDK only surfaces http.StatusText, which hides provider-specific error detail such as insufficient_scope. The body remains readable by downstream consumers, and enrichConnectError attaches the extracted message to the Connect error so callers see the real server-side cause (e.g. 'App is not enabled for Slack MCP server access'). Covered by new unit tests for Slack's nested and ok:false payloads, post-OAuth retry body preservation, and the scope-invalidation path. Assisted-By: docker-agent --- pkg/tools/mcp/oauth.go | 201 +++++++++++++++++- pkg/tools/mcp/oauth_helpers.go | 108 ++++++++-- pkg/tools/mcp/oauth_test.go | 360 +++++++++++++++++++++++++++++++++ pkg/tools/mcp/remote.go | 33 +-- pkg/tools/mcp/remote_test.go | 41 ++++ pkg/tools/mcp/tokenstore.go | 7 + 6 files changed, 723 insertions(+), 27 deletions(-) diff --git a/pkg/tools/mcp/oauth.go b/pkg/tools/mcp/oauth.go index af66936e8..570175c7a 100644 --- a/pkg/tools/mcp/oauth.go +++ b/pkg/tools/mcp/oauth.go @@ -1,6 +1,7 @@ package mcp import ( + "bytes" "cmp" "context" "encoding/json" @@ -184,11 +185,18 @@ type oauthTransport struct { managed bool oauthConfig *latest.RemoteOAuthConfig - // mu protects refreshFailedAt and lastAuthRequired from concurrent access. + // mu protects refreshFailedAt and lastErr* from concurrent access. mu sync.Mutex // refreshFailedAt tracks the last time a silent token refresh failed, // so we avoid retrying on every request. refreshFailedAt time.Time + // lastErrStatus and lastErrBody capture the status code and (truncated) + // response body of the most recent non-2xx HTTP response received by the + // transport. They're read by callers of Initialize() to enrich bubbled-up + // errors with the server's own explanation, which the MCP SDK otherwise + // swallows in favor of a bare http.StatusText. + lastErrStatus int + lastErrBody []byte // lastAuthRequired records when the transport short-circuited an // interactive OAuth flow because the request context disallowed // prompts (see WithoutInteractivePrompts). The MCP SDK wraps transport @@ -259,9 +267,91 @@ func (t *oauthTransport) roundTrip(req *http.Request, isRetry bool) (*http.Respo } } + // On the authenticated retry, log the response body when the server + // rejects us with an error status. Otherwise the failure bubbles up as + // a generic "Bad Request" / "Forbidden" / ... with no detail, making it + // very hard to understand why the server refused the token we just + // obtained (e.g. a scope mismatch, insufficient permissions, or + // provider-specific payload complaints). + // + // We also log on the first attempt when the status is something other + // than a plain 401 we're going to handle via OAuth. In particular, some + // servers return a non-standard 400 instead of 401 when a stored token + // is no longer accepted (for example, Slack's MCP endpoint answers + // `400 Bad Request` with a JSON-RPC error payload when the app has + // lost access — "App is not enabled for Slack MCP server access"), + // and surfacing the body is the only way to see the real cause. + if resp.StatusCode >= 400 { + t.logErrorResponse(req, resp) + } + return resp, nil } +// logErrorResponse peeks at an error response body (up to a reasonable cap) +// and logs it so the user can see what the server is actually complaining +// about, without preventing the caller from reading the body themselves. +// +// Many MCP server failures come back as short JSON-RPC error envelopes +// (e.g. `{"error":{"code":-32000,"message":"insufficient_scope"}}`) that +// are invaluable for diagnosis but are otherwise swallowed by the MCP SDK +// which only surfaces `http.StatusText(resp.StatusCode)`. +func (t *oauthTransport) logErrorResponse(req *http.Request, resp *http.Response) { + const maxBody = 2048 + + body, err := io.ReadAll(io.LimitReader(resp.Body, maxBody)) + if err != nil { + slog.Warn("Authenticated MCP request failed; could not read response body", + "url", req.URL.String(), + "status", resp.StatusCode, + "error", err, + ) + // Ensure the body reader is in a usable state for the caller. + resp.Body = io.NopCloser(bytes.NewReader(nil)) + return + } + + // Drain and replace the body so downstream consumers can still read it. + _, _ = io.Copy(io.Discard, resp.Body) + resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(body)) + + // Remember the last server-side failure so higher layers (Initialize / + // doStart) can enrich their error with a human-readable explanation + // rather than the SDK's bare "Bad Request". + t.mu.Lock() + t.lastErrStatus = resp.StatusCode + t.lastErrBody = body + t.mu.Unlock() + + slog.Warn("Authenticated MCP request was rejected by the server", + "url", req.URL.String(), + "status", resp.StatusCode, + "www_authenticate", resp.Header.Get("WWW-Authenticate"), + "content_type", resp.Header.Get("Content-Type"), + "body", string(body), + ) +} + +// lastServerError returns the status code and a short, human-readable +// explanation drawn from the most recent non-2xx response seen by this +// transport. The string is empty when no such response has been captured +// or when the body yielded no useful text. +// +// This is how the transport surfaces provider-specific errors (e.g. Slack's +// "App is not enabled for Slack MCP server access") that would otherwise +// be hidden behind the MCP SDK's generic http.StatusText-derived messages. +func (t *oauthTransport) lastServerError() (int, string) { + t.mu.Lock() + status := t.lastErrStatus + body := t.lastErrBody + t.mu.Unlock() + if status == 0 { + return 0, "" + } + return status, extractServerMessage(body) +} + // authorizationRequired reports whether the transport short-circuited an // interactive OAuth flow because the request context disallowed prompts. // Callers can use this to recognise the deferred-OAuth case even though @@ -272,15 +362,80 @@ func (t *oauthTransport) authorizationRequired() bool { return t.lastAuthRequired } +// extractServerMessage extracts a short, human-readable message from a +// server response body. It tries, in order: +// +// 1. A JSON-RPC envelope: {"error":{"message":"..."}} +// 2. A Slack-style envelope: {"error":"some_code"} +// 3. A top-level {"message":"..."} +// 4. The raw body, trimmed and collapsed to a single line. +// +// Returns "" when the body is empty or contains only whitespace. +func extractServerMessage(body []byte) string { + body = bytes.TrimSpace(body) + if len(body) == 0 { + return "" + } + + var envelope struct { + Error json.RawMessage `json:"error"` + Message string `json:"message"` + } + if err := json.Unmarshal(body, &envelope); err == nil { + // Nested object: {"error":{"message":"..."}}. + var nested struct { + Message string `json:"message"` + } + if json.Unmarshal(envelope.Error, &nested) == nil && nested.Message != "" { + return nested.Message + } + // Plain string: {"error":"some_code"}. + var s string + if json.Unmarshal(envelope.Error, &s) == nil && s != "" { + return s + } + if envelope.Message != "" { + return envelope.Message + } + } + + // Fall back to the raw body, collapsed to a single line so it's safe + // to embed in an error message. Caps the length conservatively. + const maxLen = 400 + text := strings.Join(strings.Fields(string(body)), " ") + if len(text) > maxLen { + text = text[:maxLen] + "\u2026" + } + return text +} + // getValidToken returns a non-expired token for the server, silently refreshing // an expired token when a refresh token is available. Returns nil if no usable // token can be obtained. +// +// If the stored token's recorded RequestedScopes no longer cover the scopes +// currently requested by the config, the stored token is discarded so that +// the next request triggers a fresh OAuth flow. This keeps us from reusing a +// token that was provisioned with too-narrow (or entirely wrong) scopes +// — typically after the user edits the scope list in their agent config. func (t *oauthTransport) getValidToken(ctx context.Context) *OAuthToken { token, err := t.tokenStore.GetToken(t.baseURL) if err != nil { return nil } + if !t.tokenCoversConfiguredScopes(token) { + slog.Debug("Stored token scopes no longer cover configured scopes; discarding to force re-auth", + "url", t.baseURL, + "stored", token.RequestedScopes, + "configured", configuredScopes(t.oauthConfig), + ) + if err := t.tokenStore.RemoveToken(t.baseURL); err != nil { + slog.Debug("Failed to remove stale token", "url", t.baseURL, "error", err) + } + return nil + } + if !token.IsExpired() { return token } @@ -317,6 +472,7 @@ func (t *oauthTransport) getValidToken(ctx context.Context) *OAuthToken { return nil } newToken.AuthServer = authServer + newToken.RequestedScopes = token.RequestedScopes t.mu.Lock() t.refreshFailedAt = time.Time{} // reset on success @@ -330,6 +486,45 @@ func (t *oauthTransport) getValidToken(ctx context.Context) *OAuthToken { return newToken } +// tokenCoversConfiguredScopes reports whether the stored token was obtained +// with a scope set that still satisfies the config. +// +// Scoping rules (kept deliberately simple): +// - If the config declares no scopes, any token is considered sufficient. +// - If the stored token has no RequestedScopes (legacy tokens stored before +// this field was introduced), it is treated as sufficient to avoid +// forcing a re-auth on upgrade. +// - Otherwise, every configured scope must appear in the token's +// RequestedScopes. +func (t *oauthTransport) tokenCoversConfiguredScopes(token *OAuthToken) bool { + configured := configuredScopes(t.oauthConfig) + if len(configured) == 0 { + return true + } + if len(token.RequestedScopes) == 0 { + return true + } + stored := make(map[string]struct{}, len(token.RequestedScopes)) + for _, s := range token.RequestedScopes { + stored[s] = struct{}{} + } + for _, s := range configured { + if _, ok := stored[s]; !ok { + return false + } + } + return true +} + +// configuredScopes is a nil-safe accessor for the Scopes slice on the +// optional RemoteOAuthConfig. +func configuredScopes(c *latest.RemoteOAuthConfig) []string { + if c == nil { + return nil + } + return c.Scopes +} + // handleOAuthFlow performs the OAuth flow when a 401 response is received func (t *oauthTransport) handleOAuthFlow(ctx context.Context, authServer, wwwAuth string) error { if t.managed { @@ -488,6 +683,7 @@ func (t *oauthTransport) handleManagedOAuthFlow(ctx context.Context, authServer, token.ClientID = clientID token.ClientSecret = clientSecret token.AuthServer = resourceMetadata.AuthorizationServers[0] + token.RequestedScopes = scopes if err := t.tokenStore.StoreToken(t.baseURL, token); err != nil { return fmt.Errorf("failed to store token: %w", err) @@ -589,6 +785,9 @@ func (t *oauthTransport) handleUnmanagedOAuthFlow(ctx context.Context, authServe if refreshToken, ok := tokenData["refresh_token"].(string); ok { token.RefreshToken = refreshToken } + if t.oauthConfig != nil { + token.RequestedScopes = t.oauthConfig.Scopes + } if err := t.tokenStore.StoreToken(t.baseURL, token); err != nil { return fmt.Errorf("failed to store token: %w", err) } diff --git a/pkg/tools/mcp/oauth_helpers.go b/pkg/tools/mcp/oauth_helpers.go index ff194c391..ca9e862c8 100644 --- a/pkg/tools/mcp/oauth_helpers.go +++ b/pkg/tools/mcp/oauth_helpers.go @@ -73,19 +73,105 @@ func ExchangeCodeForToken(ctx context.Context, tokenEndpoint, code, codeVerifier return nil, fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body)) } - var token OAuthToken - if err := json.NewDecoder(resp.Body).Decode(&token); err != nil { + token, err := parseTokenResponse(resp.Body) + if err != nil { return nil, fmt.Errorf("failed to decode token response: %w", err) } + token.ClientID = clientID + token.ClientSecret = clientSecret + + return token, nil +} + +// tokenResponse is the on-the-wire shape of an OAuth 2.0 token response. +// +// It accepts both: +// +// - the standard flat shape defined by RFC 6749 §5.1 (access_token, token_type, +// expires_in, refresh_token at the top level); and +// +// - Slack's user-token flow (`oauth.v2.user.access`), which returns the user +// token nested inside an `authed_user` object and signals application-level +// success/failure with an `ok` boolean and `error` string rather than via +// HTTP status alone. +// +// Fields that do not exist in one variant are simply left at their zero value. +type tokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + Scope string `json:"scope,omitempty"` + + // Slack application-level status. OK is a pointer so we can distinguish + // "field absent" (standard OAuth response) from "ok:false" (Slack error). + OK *bool `json:"ok,omitempty"` + Error string `json:"error,omitempty"` + + // Slack user-token flow nests the actual token under authed_user. + AuthedUser *struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + Scope string `json:"scope,omitempty"` + } `json:"authed_user,omitempty"` +} + +// parseTokenResponse decodes a JSON token response body and normalizes it to +// an OAuthToken, supporting both the standard flat OAuth 2.0 shape and +// Slack's nested `authed_user` shape. It returns an error when the response +// signals `ok:false` or contains no usable access token. +func parseTokenResponse(body io.Reader) (*OAuthToken, error) { + var resp tokenResponse + if err := json.NewDecoder(body).Decode(&resp); err != nil { + return nil, err + } + + // Slack surfaces application-level failures with HTTP 200 + ok:false. + if resp.OK != nil && !*resp.OK { + if resp.Error != "" { + return nil, fmt.Errorf("token endpoint returned error: %s", resp.Error) + } + return nil, errors.New("token endpoint returned ok:false with no error details") + } + + token := &OAuthToken{ + AccessToken: resp.AccessToken, + TokenType: resp.TokenType, + ExpiresIn: resp.ExpiresIn, + RefreshToken: resp.RefreshToken, + Scope: resp.Scope, + } + + // Fall back to authed_user for providers that nest the user token there + // (notably Slack's oauth.v2.user.access endpoint). + if token.AccessToken == "" && resp.AuthedUser != nil && resp.AuthedUser.AccessToken != "" { + token.AccessToken = resp.AuthedUser.AccessToken + if token.TokenType == "" { + token.TokenType = resp.AuthedUser.TokenType + } + if token.ExpiresIn == 0 { + token.ExpiresIn = resp.AuthedUser.ExpiresIn + } + if token.RefreshToken == "" { + token.RefreshToken = resp.AuthedUser.RefreshToken + } + if token.Scope == "" { + token.Scope = resp.AuthedUser.Scope + } + } + + if token.AccessToken == "" { + return nil, errors.New("token response did not contain an access_token") + } + if token.ExpiresIn > 0 { token.ExpiresAt = time.Now().Add(time.Duration(token.ExpiresIn) * time.Second) } - token.ClientID = clientID - token.ClientSecret = clientSecret - - return &token, nil + return token, nil } // RequestAuthorizationCode requests the user to open the authorization URL and waits for the callback @@ -194,15 +280,11 @@ func RefreshAccessToken(ctx context.Context, tokenEndpoint, refreshToken, client return nil, fmt.Errorf("token refresh failed with status %d: %s", resp.StatusCode, string(body)) } - var token OAuthToken - if err := json.NewDecoder(resp.Body).Decode(&token); err != nil { + token, err := parseTokenResponse(resp.Body) + if err != nil { return nil, fmt.Errorf("failed to decode refresh response: %w", err) } - if token.ExpiresIn > 0 { - token.ExpiresAt = time.Now().Add(time.Duration(token.ExpiresIn) * time.Second) - } - // Preserve the refresh token if the server didn't issue a new one if token.RefreshToken == "" { token.RefreshToken = refreshToken @@ -212,5 +294,5 @@ func RefreshAccessToken(ctx context.Context, tokenEndpoint, refreshToken, client token.ClientID = clientID token.ClientSecret = clientSecret - return &token, nil + return token, nil } diff --git a/pkg/tools/mcp/oauth_test.go b/pkg/tools/mcp/oauth_test.go index 604751310..633354481 100644 --- a/pkg/tools/mcp/oauth_test.go +++ b/pkg/tools/mcp/oauth_test.go @@ -2,12 +2,15 @@ package mcp import ( "encoding/json" + "io" "net/http" "net/http/httptest" "net/url" "strings" "testing" "time" + + "github.com/docker/docker-agent/pkg/config/latest" ) // TestExchangeCodeForToken_PreservesClientCredentials verifies that @@ -367,6 +370,311 @@ func TestCallbackServer_RejectsCallbackBeforeStateSet(t *testing.T) { } } +// TestExchangeCodeForToken_SlackNestedResponse verifies that Slack's +// oauth.v2.user.access response shape (where the user access_token is +// nested inside an `authed_user` object) is decoded correctly. Before +// this was supported, we would silently store an empty bearer token and +// every subsequent request to the MCP server would be rejected with 401. +func TestExchangeCodeForToken_SlackNestedResponse(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "ok": true, + "app_id": "A12345678", + "authed_user": map[string]any{ + "id": "U12345678", + "scope": "search:read.public,users:read", + "token_type": "user", + "access_token": "xoxp-slack-user-token", + "expires_in": 43200, + "refresh_token": "xoxe-1-slack-refresh", + }, + "team": map[string]any{"id": "T12345678", "name": "My Workspace"}, + }) + })) + defer srv.Close() + + token, err := ExchangeCodeForToken(t.Context(), srv.URL, "code", "verifier", "cid", "csec", "http://localhost/callback") + if err != nil { + t.Fatalf("ExchangeCodeForToken: %v", err) + } + + if token.AccessToken != "xoxp-slack-user-token" { + t.Errorf("AccessToken = %q, want %q", token.AccessToken, "xoxp-slack-user-token") + } + if token.TokenType != "user" { + t.Errorf("TokenType = %q, want %q", token.TokenType, "user") + } + if token.RefreshToken != "xoxe-1-slack-refresh" { + t.Errorf("RefreshToken = %q, want %q", token.RefreshToken, "xoxe-1-slack-refresh") + } + if token.ExpiresIn != 43200 { + t.Errorf("ExpiresIn = %d, want 43200", token.ExpiresIn) + } + if token.ExpiresAt.IsZero() { + t.Error("ExpiresAt should be set when expires_in is non-zero") + } + if token.Scope != "search:read.public,users:read" { + t.Errorf("Scope = %q, want %q", token.Scope, "search:read.public,users:read") + } +} + +// TestExchangeCodeForToken_SlackOkFalse verifies that a Slack-style +// {"ok":false,"error":"..."} payload — returned with HTTP 200 — surfaces +// as a meaningful error rather than being silently accepted as an empty +// token. +func TestExchangeCodeForToken_SlackOkFalse(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "ok": false, + "error": "invalid_code", + }) + })) + defer srv.Close() + + _, err := ExchangeCodeForToken(t.Context(), srv.URL, "code", "verifier", "cid", "csec", "http://localhost/callback") + if err == nil { + t.Fatal("expected an error for ok:false response, got nil") + } + if !strings.Contains(err.Error(), "invalid_code") { + t.Errorf("error = %q, want it to contain the Slack error code %q", err.Error(), "invalid_code") + } +} + +// TestExchangeCodeForToken_MissingAccessToken verifies that a 200 response +// missing any access_token (top-level or nested) is rejected with an +// explicit error instead of silently producing an empty bearer token. +func TestExchangeCodeForToken_MissingAccessToken(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "token_type": "Bearer", + }) + })) + defer srv.Close() + + _, err := ExchangeCodeForToken(t.Context(), srv.URL, "code", "verifier", "cid", "csec", "http://localhost/callback") + if err == nil { + t.Fatal("expected an error for response with no access_token, got nil") + } + if !strings.Contains(err.Error(), "access_token") { + t.Errorf("error = %q, want it to mention access_token", err.Error()) + } +} + +// TestOAuthTransport_RetryFailureExposesResponseBody verifies that when +// the authenticated retry after a successful OAuth flow still fails with +// a non-2xx status, the response body is logged and preserved for the +// caller. Without this, diagnosing post-OAuth server errors is limited +// to the generic HTTP status text, which hides useful provider-specific +// detail such as scope mismatches or payload complaints. +func TestOAuthTransport_RetryFailureExposesResponseBody(t *testing.T) { + const errBody = `{"jsonrpc":"2.0","id":null,"error":{"code":-32000,"message":"insufficient_scope: missing users:read"}}` + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") == "Bearer stored-at" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(errBody)) + return + } + w.WriteHeader(http.StatusUnauthorized) + })) + defer srv.Close() + + store := NewInMemoryTokenStore() + if err := store.StoreToken(srv.URL, &OAuthToken{ + AccessToken: "stored-at", + TokenType: "Bearer", + }); err != nil { + t.Fatal(err) + } + + transport := &oauthTransport{ + base: http.DefaultTransport, + tokenStore: store, + baseURL: srv.URL, + } + + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, srv.URL, strings.NewReader("{}")) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := transport.roundTrip(req, true) + if err != nil { + t.Fatalf("roundTrip: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", resp.StatusCode) + } + + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("reading body after retry: %v", err) + } + if string(got) != errBody { + t.Errorf("response body = %q, want %q", string(got), errBody) + } +} + +// TestOAuthTransport_NonRetryFailureExposesResponseBody verifies that when +// the *first* request fails with a non-2xx that we cannot retry via OAuth +// (e.g. a 400 Bad Request rather than a 401), the response body is still +// preserved and made available to the caller. +// +// Regression test for: Slack's MCP endpoint answering +// +// 400 Bad Request +// {"jsonrpc":"2.0","id":null,"error":{"code":-32600, +// "message":"App is not enabled for Slack MCP server access. ..."}} +// +// where the user previously saw only "Bad Request" bubbled up from the +// MCP SDK because our transport was swallowing the body. We couldn't +// surface the single line that actually tells the user what to do. +func TestOAuthTransport_NonRetryFailureExposesResponseBody(t *testing.T) { + const errBody = `{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"App is not enabled for Slack MCP server access."}}` + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(errBody)) + })) + defer srv.Close() + + store := NewInMemoryTokenStore() + if err := store.StoreToken(srv.URL, &OAuthToken{ + AccessToken: "stored-at", + TokenType: "Bearer", + }); err != nil { + t.Fatal(err) + } + + transport := &oauthTransport{ + base: http.DefaultTransport, + tokenStore: store, + baseURL: srv.URL, + } + + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, srv.URL, strings.NewReader("{}")) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + + // Use the exported RoundTrip, which is always called with isRetry=false. + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want 400", resp.StatusCode) + } + + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("reading body on first attempt: %v", err) + } + if string(got) != errBody { + t.Errorf("response body = %q, want %q", string(got), errBody) + } +} + +// TestGetValidToken_DiscardsTokenWhenScopesNoLongerCovered verifies that +// a stored token whose RequestedScopes do not cover the config's current +// scopes is discarded (removed from the store and not returned), so the +// next authenticated request triggers a fresh OAuth flow. +func TestGetValidToken_DiscardsTokenWhenScopesNoLongerCovered(t *testing.T) { + store := NewInMemoryTokenStore() + if err := store.StoreToken("https://mcp.example.com", &OAuthToken{ + AccessToken: "stale-at", + TokenType: "Bearer", + RequestedScopes: []string{"users:read"}, + }); err != nil { + t.Fatal(err) + } + + transport := &oauthTransport{ + base: http.DefaultTransport, + tokenStore: store, + baseURL: "https://mcp.example.com", + oauthConfig: &latest.RemoteOAuthConfig{ + Scopes: []string{"users:read", "channels:history"}, + }, + } + + if got := transport.getValidToken(t.Context()); got != nil { + t.Fatalf("expected nil when configured scopes exceed stored scopes, got %+v", got) + } + + // Token must have been purged so the next call doesn't keep returning it. + if _, err := store.GetToken("https://mcp.example.com"); err == nil { + t.Error("expected token to be removed from the store") + } +} + +// TestGetValidToken_ReturnsTokenWhenScopesSatisfied verifies the happy +// path: a stored token whose RequestedScopes cover every configured scope +// is returned unchanged (no re-auth, no refresh). +func TestGetValidToken_ReturnsTokenWhenScopesSatisfied(t *testing.T) { + store := NewInMemoryTokenStore() + if err := store.StoreToken("https://mcp.example.com", &OAuthToken{ + AccessToken: "good-at", + TokenType: "Bearer", + RequestedScopes: []string{"users:read", "channels:history", "extra:scope"}, + }); err != nil { + t.Fatal(err) + } + + transport := &oauthTransport{ + base: http.DefaultTransport, + tokenStore: store, + baseURL: "https://mcp.example.com", + oauthConfig: &latest.RemoteOAuthConfig{ + Scopes: []string{"users:read", "channels:history"}, + }, + } + + got := transport.getValidToken(t.Context()) + if got == nil || got.AccessToken != "good-at" { + t.Fatalf("expected stored token to be returned, got %+v", got) + } +} + +// TestGetValidToken_LeavesLegacyTokenAlone verifies that stored tokens +// that predate the RequestedScopes field (empty slice) are treated as +// sufficient, so an upgrade doesn't forcibly invalidate every existing +// user's session. +func TestGetValidToken_LeavesLegacyTokenAlone(t *testing.T) { + store := NewInMemoryTokenStore() + if err := store.StoreToken("https://mcp.example.com", &OAuthToken{ + AccessToken: "legacy-at", + TokenType: "Bearer", + // RequestedScopes intentionally nil (legacy). + }); err != nil { + t.Fatal(err) + } + + transport := &oauthTransport{ + base: http.DefaultTransport, + tokenStore: store, + baseURL: "https://mcp.example.com", + oauthConfig: &latest.RemoteOAuthConfig{ + Scopes: []string{"users:read"}, + }, + } + + if got := transport.getValidToken(t.Context()); got == nil { + t.Fatal("legacy token without RequestedScopes should not be invalidated on scope mismatch") + } +} + // TestOAuthTransport_NonInteractiveCtxSkipsElicitation verifies that when // the request context is marked non-interactive (via WithoutInteractivePrompts), // a 401 with WWW-Authenticate does NOT trigger the OAuth flow. Instead the @@ -417,3 +725,55 @@ func TestOAuthTransport_NonInteractiveCtxSkipsElicitation(t *testing.T) { t.Errorf("expected IsAuthorizationRequired(err)=true, got err=%v", gotErr) } } + +// TestExtractServerMessage covers the body-to-string conversion used when +// wrapping Initialize errors. The goal is to pick the most human-readable +// string out of whatever the server returns so it can be shown as a TUI +// warning, falling back gracefully instead of leaking opaque JSON. +func TestExtractServerMessage(t *testing.T) { + tests := []struct { + name string + body string + want string + }{ + { + name: "jsonrpc_error_message", + body: `{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"App is not enabled."}}`, + want: "App is not enabled.", + }, + { + name: "top_level_message", + body: `{"message":"rate limited"}`, + want: "rate limited", + }, + { + name: "slack_style_error_string", + body: `{"ok":false,"error":"invalid_auth"}`, + want: "invalid_auth", + }, + { + name: "plain_text", + body: "Service Unavailable\n\n", + want: "Service Unavailable", + }, + { + name: "empty_body", + body: " ", + want: "", + }, + { + name: "very_long_plaintext_is_capped", + body: strings.Repeat("A", 1000), + want: strings.Repeat("A", 400) + "\u2026", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := extractServerMessage([]byte(tc.body)) + if got != tc.want { + t.Errorf("extractServerMessage(%q) = %q, want %q", tc.body, got, tc.want) + } + }) + } +} diff --git a/pkg/tools/mcp/remote.go b/pkg/tools/mcp/remote.go index de5269f47..805c3fe1a 100644 --- a/pkg/tools/mcp/remote.go +++ b/pkg/tools/mcp/remote.go @@ -41,10 +41,9 @@ func newRemoteClient(url, transportType string, headers map[string]string, token func (c *remoteMCPClient) Initialize(ctx context.Context, _ *gomcp.InitializeRequest) (*gomcp.InitializeResult, error) { // Create HTTP client with OAuth support. We keep a reference to the - // oauthTransport so we can recognise the deferred-OAuth case (the - // transport returned an AuthorizationRequiredError because the request - // context disallowed prompts) and re-emit a clean - // AuthorizationRequiredError that callers can detect with errors.As. + // oauthTransport so we can enrich Connect errors with the server's own + // explanation — without this, a plain `Bad Request` bubbles up and the + // user has no idea that, say, the Slack app hasn't been enabled for MCP. httpClient, oauthT := c.createHTTPClient() var transport gomcp.Transport @@ -93,20 +92,28 @@ func (c *remoteMCPClient) Initialize(ctx context.Context, _ *gomcp.InitializeReq return session.InitializeResult(), nil } -// enrichConnectError wraps the error returned by client.Connect so callers -// can distinguish the deferred-OAuth case from a real failure. +// enrichConnectError wraps the error returned by client.Connect with any +// server-side failure message captured by the transport. The MCP SDK +// surfaces only http.StatusText ("Bad Request", "Forbidden", ...) even when +// the server included a useful JSON-RPC error payload, so we append the +// extracted message here so callers — and ultimately the user — can see it. // -// The MCP SDK uses fmt.Errorf("%w: %v", …) when it surfaces transport errors, -// which means the original error is included as text only — not in the unwrap -// chain — so we can't rely on errors.As against the SDK-wrapped error. -// Instead we read the deferred-auth flag back off the transport and re-emit -// a clean AuthorizationRequiredError. +// It also recognises the deferred-OAuth case (the transport returned an +// AuthorizationRequiredError because the request context disallowed prompts) +// and re-emits a clean AuthorizationRequiredError so callers can distinguish +// it from a real failure with errors.As. We can't rely on the SDK's own +// wrapping for this because the SDK uses fmt.Errorf("%w: %v", …) when it +// surfaces transport errors — the original error is included as text only, +// not in the unwrap chain. // // Pre: err != nil and t != nil; only called from the Connect failure path. func enrichConnectError(err error, t *oauthTransport) error { if t.authorizationRequired() { return &AuthorizationRequiredError{URL: t.baseURL} } + if status, msg := t.lastServerError(); status != 0 && msg != "" { + return fmt.Errorf("failed to connect to MCP server: %w (server responded %d: %s)", err, status, msg) + } return fmt.Errorf("failed to connect to MCP server: %w", err) } @@ -123,8 +130,8 @@ func (c *remoteMCPClient) SetManagedOAuth(managed bool) { // at request time from upstream headers stored in the request context. // // The oauthTransport is returned alongside the client so callers can inspect -// the transport's state (e.g. whether OAuth was deferred) when Connect() -// returns and we need to surface the actual cause of the failure. +// the most recent server-side failure (via lastServerError) when Connect() +// returns a bare HTTP-status error and we need to surface the actual cause. func (c *remoteMCPClient) createHTTPClient() (*http.Client, *oauthTransport) { base := c.headerTransport() diff --git a/pkg/tools/mcp/remote_test.go b/pkg/tools/mcp/remote_test.go index 3f54dc603..81e266f1f 100644 --- a/pkg/tools/mcp/remote_test.go +++ b/pkg/tools/mcp/remote_test.go @@ -182,6 +182,47 @@ func TestRemoteClientEmptyHeaders(t *testing.T) { } } +// TestInitialize_SurfacesServerErrorInReturnedError verifies that when an +// MCP server rejects the initialize call with a 4xx carrying a JSON-RPC +// error body, the error returned by Initialize contains the server's own +// explanation — not just the generic "Bad Request" from http.StatusText. +// +// Regression test for: Slack's MCP endpoint answering +// +// 400 Bad Request +// {"jsonrpc":"2.0","id":null,"error":{"code":-32600, +// "message":"App is not enabled for Slack MCP server access. ..."}} +// +// where the bubbled-up error previously read only "...: Bad Request" and +// the user had no way to learn what was actually wrong. +func TestInitialize_SurfacesServerErrorInReturnedError(t *testing.T) { + t.Parallel() + + const msg = "App is not enabled for Slack MCP server access." + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _, _ = fmt.Fprintf(w, `{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":%q}}`, msg) + })) + defer server.Close() + + // Pre-populate a token so the transport doesn't try to trigger OAuth on + // the 401 path — we want to exercise the "server rejected us with a + // non-auth failure" code path. + store := NewInMemoryTokenStore() + require.NoError(t, store.StoreToken(server.URL, &OAuthToken{AccessToken: "at", TokenType: "Bearer"})) + + client := newRemoteClient(server.URL, "streamable", nil, store, nil) + + _, err := client.Initialize(t.Context(), nil) + require.Error(t, err, "Initialize should fail against a server that rejects initialize") + assert.Contains(t, err.Error(), msg, + "Initialize error must surface the server's JSON-RPC error message (%q), got: %v", msg, err) + assert.Contains(t, err.Error(), "400", + "Initialize error should include the HTTP status code so the user knows it was a server rejection, got: %v", err) +} + // TestInitialize_NonInteractiveCtxDefersOAuthAndDoesNotBlock verifies that // when Initialize runs against a server that requires OAuth (responds with // 401 + WWW-Authenticate) under a context flagged with diff --git a/pkg/tools/mcp/tokenstore.go b/pkg/tools/mcp/tokenstore.go index 74a3b89f8..78eadb0cc 100644 --- a/pkg/tools/mcp/tokenstore.go +++ b/pkg/tools/mcp/tokenstore.go @@ -27,6 +27,13 @@ type OAuthToken struct { ClientID string `json:"client_id,omitempty"` ClientSecret string `json:"client_secret,omitempty"` AuthServer string `json:"auth_server,omitempty"` + + // RequestedScopes records the scope list the config asked for when this + // token was obtained. Unlike Scope (which is whatever the authorization + // server chose to return, sometimes empty, sometimes comma/space + // separated), RequestedScopes reflects our intent and is used to detect + // when the config has changed and a new OAuth flow is required. + RequestedScopes []string `json:"requested_scopes,omitempty"` } // IsExpired checks if the token is expired