Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/go/protomcp/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ func handleReload(tp *Transport, reqID string) {
})
handleListTools(tp, "")
sendMiddlewareRegistrations(tp)
sendDisableHiddenTools(tp)
}

func uriMatchesTemplate(template, uri string) bool {
Expand Down
5 changes: 3 additions & 2 deletions sdk/python/src/protomcp/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def _handle_reload(transport, env, mw_handlers):
fake_env = pb.Envelope() # empty envelope (request_id defaults to "")
_handle_list_tools(transport, fake_env)
_send_middleware_registrations(transport, mw_handlers)
_disable_hidden_tools(transport)

def _send_middleware_registrations(transport, mw_handlers):
mw_defs = get_registered_middleware()
Expand Down Expand Up @@ -265,8 +266,8 @@ def _handle_middleware_intercept(transport, env, mw_handlers):

resp = pb.Envelope(
middleware_intercept_response=pb.MiddlewareInterceptResponse(
arguments_json=resp_fields.get("arguments_json", ""),
result_json=resp_fields.get("result_json", ""),
arguments_json=resp_fields.get("arguments_json", req.arguments_json),
result_json=resp_fields.get("result_json", req.result_json),
reject=resp_fields.get("reject", False),
reject_reason=resp_fields.get("reject_reason", ""),
),
Expand Down
1 change: 1 addition & 0 deletions sdk/rust/src/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ async fn handle_reload(transport: &Transport, request_id: &str) {
let _ = transport.send(&resp).await;
handle_list_tools(transport, "").await;
send_middleware_registrations(transport).await;
send_disable_hidden_tools(transport).await;
}

async fn send_disable_hidden_tools(transport: &Transport) {
Expand Down
28 changes: 19 additions & 9 deletions sdk/rust/src/workflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ struct WorkflowState {
workflow_name: String,
current_step: String,
history: Vec<StepHistoryEntry>,
pre_workflow_tools: Vec<String>,
}

// ── Builders ──
Expand Down Expand Up @@ -422,12 +423,23 @@ fn handle_step_call(workflow_name: &str, step_name: &str, ctx: ToolContext, args

let mut state_guard = ACTIVE_WORKFLOW.lock().unwrap_or_else(|e| e.into_inner());

// Collect all tool names for visibility computation (must be before initial state creation)
let all_tool_names: Vec<String> = crate::tool::with_registry(|tools| {
tools.iter().map(|t| t.name.clone()).collect()
});

if step_def.initial {
// Snapshot tools that aren't part of this workflow
let pre_tools: Vec<String> = all_tool_names.iter()
.filter(|t| !t.starts_with(&format!("{}.", workflow_name)))
.cloned()
.collect();
// Start new workflow
*state_guard = Some(WorkflowState {
workflow_name: workflow_name.to_string(),
current_step: step_name.to_string(),
history: Vec::new(),
pre_workflow_tools: pre_tools,
});
} else {
// Must have active workflow
Expand All @@ -442,11 +454,6 @@ fn handle_step_call(workflow_name: &str, step_name: &str, ctx: ToolContext, args
}
}

// Collect all tool names for visibility computation
let all_tool_names: Vec<String> = crate::tool::with_registry(|tools| {
tools.iter().map(|t| t.name.clone()).collect()
});

// Run handler, catch panics as errors
let handler_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
(step_def.handler)(ctx, args)
Expand Down Expand Up @@ -537,10 +544,10 @@ fn handle_step_call(workflow_name: &str, step_name: &str, ctx: ToolContext, args
let cancel_tool = format!("{}.cancel", wf.name);
result.disable_tools = non_initial_tools;
result.disable_tools.push(cancel_tool);
// Don't re-enable initial — it was never disabled (it's always visible)
// But we do need to re-enable any external tools that were disabled
// Since we don't track pre-workflow state in the Rust version (no manager),
// we just clear the active state and let the result carry enable/disable
// Re-enable pre-workflow tools
if let Some(ref state) = *state_guard {
result.enable_tools = state.pre_workflow_tools.clone();
}
*state_guard = None;
result
} else {
Expand Down Expand Up @@ -596,6 +603,9 @@ fn handle_cancel(workflow_name: &str) -> ToolResult {
let mut result = ToolResult::new(format!("Workflow '{}' cancelled", workflow_name));
result.disable_tools = non_initial_tools;
result.disable_tools.push(cancel_tool);
if let Some(ref state) = *state_guard {
result.enable_tools = state.pre_workflow_tools.clone();
}

*state_guard = None;
result
Expand Down
11 changes: 9 additions & 2 deletions sdk/typescript/src/runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ export async function run(): Promise<void> {

const resp = Envelope.create({
middlewareInterceptResponse: MiddlewareInterceptResponse.create({
argumentsJson: respFields['argumentsJson'] ?? respFields['arguments_json'] ?? '',
resultJson: respFields['resultJson'] ?? respFields['result_json'] ?? '',
argumentsJson: respFields['argumentsJson'] ?? respFields['arguments_json'] ?? req['argumentsJson'] ?? '',
resultJson: respFields['resultJson'] ?? respFields['result_json'] ?? req['resultJson'] ?? '',
reject: respFields['reject'] ?? false,
rejectReason: respFields['rejectReason'] ?? respFields['reject_reason'] ?? '',
}),
Expand Down Expand Up @@ -324,6 +324,13 @@ export async function run(): Promise<void> {
await transport.send(reloadResp);
await sendListTools('');
await sendMiddlewareRegistrations();
const reloadHiddenNames = getHiddenToolNames();
if (reloadHiddenNames.length > 0) {
const disableResp = Envelope.create({
disableTools: { toolNames: reloadHiddenNames },
});
await transport.send(disableResp);
}
}
}
}
5 changes: 4 additions & 1 deletion sdk/typescript/src/workflow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -292,10 +292,13 @@ async function handleStepCall(workflowName: string, stepName: string, kwargs: Re
if (errMsg.includes(substring)) {
state!.currentStep = targetStep;
const allowedTools = transitionToSteps(wf, state!, [targetStep]);
const allToolNames = getRegisteredTools().map(t => t.name);
const allowedSet = new Set(allowedTools);
const onErrorDisableTools = allToolNames.filter(t => !allowedSet.has(t));
return new ToolResult({
result: `Error caught (${errMsg}), transitioning to '${targetStep}'`,
enableTools: allowedTools,
disableTools: [],
disableTools: onErrorDisableTools,
});
}
}
Expand Down
Loading