diff --git a/sdk/go/protomcp/runner.go b/sdk/go/protomcp/runner.go index bd28b88..0d9b5a9 100644 --- a/sdk/go/protomcp/runner.go +++ b/sdk/go/protomcp/runner.go @@ -314,6 +314,7 @@ func handleReload(tp *Transport, reqID string) { }) handleListTools(tp, "") sendMiddlewareRegistrations(tp) + sendDisableHiddenTools(tp) } func uriMatchesTemplate(template, uri string) bool { diff --git a/sdk/python/src/protomcp/runner.py b/sdk/python/src/protomcp/runner.py index 1aa137a..d7a7323 100644 --- a/sdk/python/src/protomcp/runner.py +++ b/sdk/python/src/protomcp/runner.py @@ -213,6 +213,7 @@ def _handle_reload(transport, env, mw_handlers): fake_env = pb.Envelope() # empty envelope (request_id defaults to "") _handle_list_tools(transport, fake_env) _send_middleware_registrations(transport, mw_handlers) + _disable_hidden_tools(transport) def _send_middleware_registrations(transport, mw_handlers): mw_defs = get_registered_middleware() @@ -265,8 +266,8 @@ def _handle_middleware_intercept(transport, env, mw_handlers): resp = pb.Envelope( middleware_intercept_response=pb.MiddlewareInterceptResponse( - arguments_json=resp_fields.get("arguments_json", ""), - result_json=resp_fields.get("result_json", ""), + arguments_json=resp_fields.get("arguments_json", req.arguments_json), + result_json=resp_fields.get("result_json", req.result_json), reject=resp_fields.get("reject", False), reject_reason=resp_fields.get("reject_reason", ""), ), diff --git a/sdk/rust/src/runner.rs b/sdk/rust/src/runner.rs index b54deef..4803999 100644 --- a/sdk/rust/src/runner.rs +++ b/sdk/rust/src/runner.rs @@ -234,6 +234,7 @@ async fn handle_reload(transport: &Transport, request_id: &str) { let _ = transport.send(&resp).await; handle_list_tools(transport, "").await; send_middleware_registrations(transport).await; + send_disable_hidden_tools(transport).await; } async fn send_disable_hidden_tools(transport: &Transport) { diff --git a/sdk/rust/src/workflow.rs b/sdk/rust/src/workflow.rs index b76f212..b19e771 100644 --- a/sdk/rust/src/workflow.rs +++ b/sdk/rust/src/workflow.rs @@ -64,6 +64,7 @@ struct WorkflowState { workflow_name: String, current_step: String, history: Vec, + pre_workflow_tools: Vec, } // ── Builders ── @@ -422,12 +423,23 @@ fn handle_step_call(workflow_name: &str, step_name: &str, ctx: ToolContext, args let mut state_guard = ACTIVE_WORKFLOW.lock().unwrap_or_else(|e| e.into_inner()); + // Collect all tool names for visibility computation (must be before initial state creation) + let all_tool_names: Vec = crate::tool::with_registry(|tools| { + tools.iter().map(|t| t.name.clone()).collect() + }); + if step_def.initial { + // Snapshot tools that aren't part of this workflow + let pre_tools: Vec = all_tool_names.iter() + .filter(|t| !t.starts_with(&format!("{}.", workflow_name))) + .cloned() + .collect(); // Start new workflow *state_guard = Some(WorkflowState { workflow_name: workflow_name.to_string(), current_step: step_name.to_string(), history: Vec::new(), + pre_workflow_tools: pre_tools, }); } else { // Must have active workflow @@ -442,11 +454,6 @@ fn handle_step_call(workflow_name: &str, step_name: &str, ctx: ToolContext, args } } - // Collect all tool names for visibility computation - let all_tool_names: Vec = crate::tool::with_registry(|tools| { - tools.iter().map(|t| t.name.clone()).collect() - }); - // Run handler, catch panics as errors let handler_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { (step_def.handler)(ctx, args) @@ -537,10 +544,10 @@ fn handle_step_call(workflow_name: &str, step_name: &str, ctx: ToolContext, args let cancel_tool = format!("{}.cancel", wf.name); result.disable_tools = non_initial_tools; result.disable_tools.push(cancel_tool); - // Don't re-enable initial — it was never disabled (it's always visible) - // But we do need to re-enable any external tools that were disabled - // Since we don't track pre-workflow state in the Rust version (no manager), - // we just clear the active state and let the result carry enable/disable + // Re-enable pre-workflow tools + if let Some(ref state) = *state_guard { + result.enable_tools = state.pre_workflow_tools.clone(); + } *state_guard = None; result } else { @@ -596,6 +603,9 @@ fn handle_cancel(workflow_name: &str) -> ToolResult { let mut result = ToolResult::new(format!("Workflow '{}' cancelled", workflow_name)); result.disable_tools = non_initial_tools; result.disable_tools.push(cancel_tool); + if let Some(ref state) = *state_guard { + result.enable_tools = state.pre_workflow_tools.clone(); + } *state_guard = None; result diff --git a/sdk/typescript/src/runner.ts b/sdk/typescript/src/runner.ts index 5af5cf2..0391f89 100644 --- a/sdk/typescript/src/runner.ts +++ b/sdk/typescript/src/runner.ts @@ -140,8 +140,8 @@ export async function run(): Promise { const resp = Envelope.create({ middlewareInterceptResponse: MiddlewareInterceptResponse.create({ - argumentsJson: respFields['argumentsJson'] ?? respFields['arguments_json'] ?? '', - resultJson: respFields['resultJson'] ?? respFields['result_json'] ?? '', + argumentsJson: respFields['argumentsJson'] ?? respFields['arguments_json'] ?? req['argumentsJson'] ?? '', + resultJson: respFields['resultJson'] ?? respFields['result_json'] ?? req['resultJson'] ?? '', reject: respFields['reject'] ?? false, rejectReason: respFields['rejectReason'] ?? respFields['reject_reason'] ?? '', }), @@ -324,6 +324,13 @@ export async function run(): Promise { await transport.send(reloadResp); await sendListTools(''); await sendMiddlewareRegistrations(); + const reloadHiddenNames = getHiddenToolNames(); + if (reloadHiddenNames.length > 0) { + const disableResp = Envelope.create({ + disableTools: { toolNames: reloadHiddenNames }, + }); + await transport.send(disableResp); + } } } } diff --git a/sdk/typescript/src/workflow.ts b/sdk/typescript/src/workflow.ts index 6115891..d9d2055 100644 --- a/sdk/typescript/src/workflow.ts +++ b/sdk/typescript/src/workflow.ts @@ -292,10 +292,13 @@ async function handleStepCall(workflowName: string, stepName: string, kwargs: Re if (errMsg.includes(substring)) { state!.currentStep = targetStep; const allowedTools = transitionToSteps(wf, state!, [targetStep]); + const allToolNames = getRegisteredTools().map(t => t.name); + const allowedSet = new Set(allowedTools); + const onErrorDisableTools = allToolNames.filter(t => !allowedSet.has(t)); return new ToolResult({ result: `Error caught (${errMsg}), transitioning to '${targetStep}'`, enableTools: allowedTools, - disableTools: [], + disableTools: onErrorDisableTools, }); } }