Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions sdk/go/protomcp/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,10 +366,18 @@ func HandleStepCall(workflowName, stepName string, ctx ToolContext, args map[str
state := getActiveState()

if stepDef.Initial {
// Start a new workflow
// Start a new workflow — snapshot pre-workflow tools
var preTools []string
if ToolManagerAdapter.GetActiveTools != nil {
preTools = ToolManagerAdapter.GetActiveTools()
} else {
// Compute from registered tools, excluding this workflow's own tools
prefix := workflowName + "."
for _, t := range GetRegisteredTools() {
if !strings.HasPrefix(t.Name, prefix) {
preTools = append(preTools, t.Name)
}
}
}
newState := WorkflowState{
WorkflowName: workflowName,
Expand Down Expand Up @@ -485,7 +493,16 @@ func HandleStepCall(workflowName, stepName string, ctx ToolContext, args map[str
}
r := Result(resultText)
r.EnableTools = state.PreWorkflowTools
r.DisableTools = []string{}
// Disable everything not in pre-workflow set
preSet := map[string]bool{}
for _, t := range state.PreWorkflowTools {
preSet[t] = true
}
for _, t := range GetRegisteredTools() {
if !preSet[t.Name] {
r.DisableTools = append(r.DisableTools, t.Name)
}
}
return r
}

Expand Down Expand Up @@ -531,14 +548,22 @@ func HandleCancel(workflowName string) ToolResult {
wf.OnCancelFn(state.CurrentStep, state.History)
}

// Restore pre-workflow tools
// Restore pre-workflow tools — disable everything not in pre-workflow set
if ToolManagerAdapter.SetAllowed != nil {
ToolManagerAdapter.SetAllowed(state.PreWorkflowTools)
}
activeWorkflowStack = activeWorkflowStack[:len(activeWorkflowStack)-1]
r := Result(fmt.Sprintf("Workflow '%s' cancelled", workflowName))
r.EnableTools = state.PreWorkflowTools
r.DisableTools = []string{}
preSet := map[string]bool{}
for _, t := range state.PreWorkflowTools {
preSet[t] = true
}
for _, t := range GetRegisteredTools() {
if !preSet[t.Name] {
r.DisableTools = append(r.DisableTools, t.Name)
}
}
return r
}

Expand Down
27 changes: 14 additions & 13 deletions sdk/rust/src/workflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,19 +386,12 @@ fn compute_transition(
}

// enable_tools = allowed_tools
// disable_tools = all workflow tools NOT in allowed (to hide them)
let all_wf_tools: Vec<String> = wf.steps.iter().map(|s| format!("{}.{}", wf.name, s.name)).collect();
let cancel_tool = format!("{}.cancel", wf.name);

let mut disable_tools: Vec<String> = Vec::new();
for t in &all_wf_tools {
if !allowed_tools.contains(t) {
disable_tools.push(t.clone());
}
}
if !allowed_tools.contains(&cancel_tool) {
disable_tools.push(cancel_tool);
}
// disable_tools = all registered tools NOT in allowed set
let allowed_set: std::collections::HashSet<&String> = allowed_tools.iter().collect();
let disable_tools: Vec<String> = all_tool_names.iter()
.filter(|t| !allowed_set.contains(t))
.cloned()
.collect();

(allowed_tools, disable_tools)
}
Expand Down Expand Up @@ -454,11 +447,19 @@ fn handle_step_call(workflow_name: &str, step_name: &str, ctx: ToolContext, args
}
}

// Drop state lock before calling handler to avoid deadlock if handler
// accesses workflow state. Registry lock (guard) is still held since
// handler lives in the registry, but that's unavoidable without Arc.
drop(state_guard);

// Run handler, catch panics as errors
let handler_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
(step_def.handler)(ctx, args)
}));

// Re-acquire state lock after handler call
let mut state_guard = ACTIVE_WORKFLOW.lock().unwrap_or_else(|e| e.into_inner());

match handler_result {
Err(panic_info) => {
// Handler panicked — treat as error
Expand Down
14 changes: 10 additions & 4 deletions sdk/typescript/src/workflow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -344,12 +344,15 @@ async function handleStepCall(workflowName: string, stepName: string, kwargs: Re
if (wf.onComplete) {
wf.onComplete(state!.history);
}
// Restore pre-workflow tools
// Restore pre-workflow tools — disable everything not in pre-workflow set
const allToolNames = getRegisteredTools().map(t => t.name);
const preSet = new Set(state!.preWorkflowTools);
const terminalDisableTools = allToolNames.filter(t => !preSet.has(t));
activeWorkflowStack.pop();
return new ToolResult({
result: result.result || 'Workflow complete',
enableTools: state!.preWorkflowTools,
disableTools: [],
disableTools: terminalDisableTools,
});
} else {
// Transition to next steps
Expand Down Expand Up @@ -387,12 +390,15 @@ function handleCancel(workflowName: string): ToolResult {
wf.onCancel(state.currentStep, state.history);
}

// Restore pre-workflow tools
// Restore pre-workflow tools — disable everything not in pre-workflow set
const allToolNames = getRegisteredTools().map(t => t.name);
const preSet = new Set(state.preWorkflowTools);
const cancelDisableTools = allToolNames.filter(t => !preSet.has(t));
activeWorkflowStack.pop();
return new ToolResult({
result: `Workflow '${workflowName}' cancelled`,
enableTools: state.preWorkflowTools,
disableTools: [],
disableTools: cancelDisableTools,
});
}

Expand Down
Loading