diff --git a/sdk/go/protomcp/workflow.go b/sdk/go/protomcp/workflow.go index 5314dd6..f8c680a 100644 --- a/sdk/go/protomcp/workflow.go +++ b/sdk/go/protomcp/workflow.go @@ -366,10 +366,18 @@ func HandleStepCall(workflowName, stepName string, ctx ToolContext, args map[str state := getActiveState() if stepDef.Initial { - // Start a new workflow + // Start a new workflow — snapshot pre-workflow tools var preTools []string if ToolManagerAdapter.GetActiveTools != nil { preTools = ToolManagerAdapter.GetActiveTools() + } else { + // Compute from registered tools, excluding this workflow's own tools + prefix := workflowName + "." + for _, t := range GetRegisteredTools() { + if !strings.HasPrefix(t.Name, prefix) { + preTools = append(preTools, t.Name) + } + } } newState := WorkflowState{ WorkflowName: workflowName, @@ -485,7 +493,16 @@ func HandleStepCall(workflowName, stepName string, ctx ToolContext, args map[str } r := Result(resultText) r.EnableTools = state.PreWorkflowTools - r.DisableTools = []string{} + // Disable everything not in pre-workflow set + preSet := map[string]bool{} + for _, t := range state.PreWorkflowTools { + preSet[t] = true + } + for _, t := range GetRegisteredTools() { + if !preSet[t.Name] { + r.DisableTools = append(r.DisableTools, t.Name) + } + } return r } @@ -531,14 +548,22 @@ func HandleCancel(workflowName string) ToolResult { wf.OnCancelFn(state.CurrentStep, state.History) } - // Restore pre-workflow tools + // Restore pre-workflow tools — disable everything not in pre-workflow set if ToolManagerAdapter.SetAllowed != nil { ToolManagerAdapter.SetAllowed(state.PreWorkflowTools) } activeWorkflowStack = activeWorkflowStack[:len(activeWorkflowStack)-1] r := Result(fmt.Sprintf("Workflow '%s' cancelled", workflowName)) r.EnableTools = state.PreWorkflowTools - r.DisableTools = []string{} + preSet := map[string]bool{} + for _, t := range state.PreWorkflowTools { + preSet[t] = true + } + for _, t := range GetRegisteredTools() { + if !preSet[t.Name] { + r.DisableTools = append(r.DisableTools, t.Name) + } + } return r } diff --git a/sdk/rust/src/workflow.rs b/sdk/rust/src/workflow.rs index b19e771..2f8bee7 100644 --- a/sdk/rust/src/workflow.rs +++ b/sdk/rust/src/workflow.rs @@ -386,19 +386,12 @@ fn compute_transition( } // enable_tools = allowed_tools - // disable_tools = all workflow tools NOT in allowed (to hide them) - let all_wf_tools: Vec = wf.steps.iter().map(|s| format!("{}.{}", wf.name, s.name)).collect(); - let cancel_tool = format!("{}.cancel", wf.name); - - let mut disable_tools: Vec = Vec::new(); - for t in &all_wf_tools { - if !allowed_tools.contains(t) { - disable_tools.push(t.clone()); - } - } - if !allowed_tools.contains(&cancel_tool) { - disable_tools.push(cancel_tool); - } + // disable_tools = all registered tools NOT in allowed set + let allowed_set: std::collections::HashSet<&String> = allowed_tools.iter().collect(); + let disable_tools: Vec = all_tool_names.iter() + .filter(|t| !allowed_set.contains(t)) + .cloned() + .collect(); (allowed_tools, disable_tools) } @@ -454,11 +447,19 @@ fn handle_step_call(workflow_name: &str, step_name: &str, ctx: ToolContext, args } } + // Drop state lock before calling handler to avoid deadlock if handler + // accesses workflow state. Registry lock (guard) is still held since + // handler lives in the registry, but that's unavoidable without Arc. + drop(state_guard); + // Run handler, catch panics as errors let handler_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { (step_def.handler)(ctx, args) })); + // Re-acquire state lock after handler call + let mut state_guard = ACTIVE_WORKFLOW.lock().unwrap_or_else(|e| e.into_inner()); + match handler_result { Err(panic_info) => { // Handler panicked — treat as error diff --git a/sdk/typescript/src/workflow.ts b/sdk/typescript/src/workflow.ts index d9d2055..5957b23 100644 --- a/sdk/typescript/src/workflow.ts +++ b/sdk/typescript/src/workflow.ts @@ -344,12 +344,15 @@ async function handleStepCall(workflowName: string, stepName: string, kwargs: Re if (wf.onComplete) { wf.onComplete(state!.history); } - // Restore pre-workflow tools + // Restore pre-workflow tools — disable everything not in pre-workflow set + const allToolNames = getRegisteredTools().map(t => t.name); + const preSet = new Set(state!.preWorkflowTools); + const terminalDisableTools = allToolNames.filter(t => !preSet.has(t)); activeWorkflowStack.pop(); return new ToolResult({ result: result.result || 'Workflow complete', enableTools: state!.preWorkflowTools, - disableTools: [], + disableTools: terminalDisableTools, }); } else { // Transition to next steps @@ -387,12 +390,15 @@ function handleCancel(workflowName: string): ToolResult { wf.onCancel(state.currentStep, state.history); } - // Restore pre-workflow tools + // Restore pre-workflow tools — disable everything not in pre-workflow set + const allToolNames = getRegisteredTools().map(t => t.name); + const preSet = new Set(state.preWorkflowTools); + const cancelDisableTools = allToolNames.filter(t => !preSet.has(t)); activeWorkflowStack.pop(); return new ToolResult({ result: `Workflow '${workflowName}' cancelled`, enableTools: state.preWorkflowTools, - disableTools: [], + disableTools: cancelDisableTools, }); }