From a30f12d333c80c5c7ea9d4713bf3f7f4710fbd44 Mon Sep 17 00:00:00 2001 From: Michael S Date: Sun, 15 Mar 2026 00:46:07 -0400 Subject: [PATCH 1/2] fix: hot reload hang, workflow tool visibility, async handlers - Python/TypeScript reload now sends post-reload tool list with empty request_id so the Go runtime routes it correctly (fixes reload hang) - Go SDK resource template matching checks URI before calling handler - TypeScript workflow computes disableTools on transitions (was empty) - TypeScript workflow awaits async step handlers (was dropping Promises) - Go SDK workflow returns enable/disable lists via ToolResult instead of relying on nil ToolManagerAdapter --- sdk/go/protomcp/runner.go | 34 +++++++++++++++++- sdk/go/protomcp/workflow.go | 58 ++++++++++++++++++++++++------- sdk/python/src/protomcp/runner.py | 5 +-- sdk/typescript/src/runner.ts | 2 +- sdk/typescript/src/workflow.ts | 7 ++-- 5 files changed, 88 insertions(+), 18 deletions(-) diff --git a/sdk/go/protomcp/runner.go b/sdk/go/protomcp/runner.go index 94f6941..bd28b88 100644 --- a/sdk/go/protomcp/runner.go +++ b/sdk/go/protomcp/runner.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "strconv" + "strings" "sync" "time" @@ -315,6 +316,37 @@ func handleReload(tp *Transport, reqID string) { sendMiddlewareRegistrations(tp) } +func uriMatchesTemplate(template, uri string) bool { + parts := strings.Split(template, "{") + if len(parts) == 0 { + return template == uri + } + if !strings.HasPrefix(uri, parts[0]) { + return false + } + remaining := uri[len(parts[0]):] + for _, part := range parts[1:] { + closeBrace := strings.Index(part, "}") + if closeBrace < 0 { + return false + } + suffix := part[closeBrace+1:] + if suffix == "" { + if remaining == "" { + return false + } + remaining = "" + } else { + idx := strings.Index(remaining, suffix) + if idx <= 0 { + return false + } + remaining = remaining[idx+len(suffix):] + } + } + return remaining == "" +} + func handleListResources(tp *Transport, reqID string) { resources := GetRegisteredResources() var defs []*pb.ResourceDefinition @@ -368,7 +400,7 @@ func handleReadResource(tp *Transport, req *pb.ReadResourceRequest, reqID string // Try resource templates. for _, t := range GetRegisteredResourceTemplates() { - if t.HandlerFn != nil { + if t.HandlerFn != nil && uriMatchesTemplate(t.URITemplate, uri) { contents := t.HandlerFn(uri) if len(contents) > 0 { sendResourceContents(tp, reqID, contents) diff --git a/sdk/go/protomcp/workflow.go b/sdk/go/protomcp/workflow.go index e9b1fb4..5314dd6 100644 --- a/sdk/go/protomcp/workflow.go +++ b/sdk/go/protomcp/workflow.go @@ -310,11 +310,7 @@ var ToolManagerAdapter struct { SetAllowed func([]string) } -func transitionToSteps(wf *WorkflowDef, state *WorkflowState, nextStepNames []string) { - if ToolManagerAdapter.SetAllowed == nil { - return - } - +func transitionToSteps(wf *WorkflowDef, state *WorkflowState, nextStepNames []string) []string { stepMap := map[string]*StepDef{} for i := range wf.Steps { stepMap[wf.Steps[i].Name] = &wf.Steps[i] @@ -347,7 +343,11 @@ func transitionToSteps(wf *WorkflowDef, state *WorkflowState, nextStepNames []st } } - ToolManagerAdapter.SetAllowed(allowedTools) + if ToolManagerAdapter.SetAllowed != nil { + ToolManagerAdapter.SetAllowed(allowedTools) + } + + return allowedTools } // --- Step dispatch --- @@ -403,8 +403,22 @@ func HandleStepCall(workflowName, stepName string, ctx ToolContext, args map[str for errKey, targetStep := range stepDef.OnError { if strings.Contains(errStr, errKey) { state.CurrentStep = targetStep - transitionToSteps(wf, state, []string{targetStep}) - return Result(fmt.Sprintf("Error caught (%s), transitioning to '%s'", errStr, targetStep)) + allowedTools := transitionToSteps(wf, state, []string{targetStep}) + allTools := GetRegisteredTools() + allowedSet := map[string]bool{} + for _, t := range allowedTools { + allowedSet[t] = true + } + var disableTools []string + for _, t := range allTools { + if !allowedSet[t.Name] { + disableTools = append(disableTools, t.Name) + } + } + r := Result(fmt.Sprintf("Error caught (%s), transitioning to '%s'", errStr, targetStep)) + r.EnableTools = allowedTools + r.DisableTools = disableTools + return r } } } @@ -469,16 +483,33 @@ func HandleStepCall(workflowName, stepName string, ctx ToolContext, args map[str if resultText == "" { resultText = "Workflow complete" } - return Result(resultText) + r := Result(resultText) + r.EnableTools = state.PreWorkflowTools + r.DisableTools = []string{} + return r } // Transition to next steps - transitionToSteps(wf, state, effectiveNext) + allowedTools := transitionToSteps(wf, state, effectiveNext) + allTools := GetRegisteredTools() + allowedSet := map[string]bool{} + for _, t := range allowedTools { + allowedSet[t] = true + } + var disableTools []string + for _, t := range allTools { + if !allowedSet[t.Name] { + disableTools = append(disableTools, t.Name) + } + } resultText := result.Result if resultText == "" { resultText = fmt.Sprintf("Proceed to: %v", effectiveNext) } - return Result(resultText) + r := Result(resultText) + r.EnableTools = allowedTools + r.DisableTools = disableTools + return r } // HandleCancel handles a cancel tool call for the given workflow. @@ -505,7 +536,10 @@ func HandleCancel(workflowName string) ToolResult { ToolManagerAdapter.SetAllowed(state.PreWorkflowTools) } activeWorkflowStack = activeWorkflowStack[:len(activeWorkflowStack)-1] - return Result(fmt.Sprintf("Workflow '%s' cancelled", workflowName)) + r := Result(fmt.Sprintf("Workflow '%s' cancelled", workflowName)) + r.EnableTools = state.PreWorkflowTools + r.DisableTools = []string{} + return r } // --- Tool generation --- diff --git a/sdk/python/src/protomcp/runner.py b/sdk/python/src/protomcp/runner.py index 48189b3..1aa137a 100644 --- a/sdk/python/src/protomcp/runner.py +++ b/sdk/python/src/protomcp/runner.py @@ -209,8 +209,9 @@ def _handle_reload(transport, env, mw_handlers): request_id=env.request_id, ) transport.send(resp) - # Also re-send tool list - _handle_list_tools(transport, env) + # Also re-send tool list with empty request_id so it routes to handshakeCh + fake_env = pb.Envelope() # empty envelope (request_id defaults to "") + _handle_list_tools(transport, fake_env) _send_middleware_registrations(transport, mw_handlers) def _send_middleware_registrations(transport, mw_handlers): diff --git a/sdk/typescript/src/runner.ts b/sdk/typescript/src/runner.ts index b8e1f57..5af5cf2 100644 --- a/sdk/typescript/src/runner.ts +++ b/sdk/typescript/src/runner.ts @@ -322,7 +322,7 @@ export async function run(): Promise { requestId, }); await transport.send(reloadResp); - await sendListTools(requestId); + await sendListTools(''); await sendMiddlewareRegistrations(); } } diff --git a/sdk/typescript/src/workflow.ts b/sdk/typescript/src/workflow.ts index f1c67fd..6115891 100644 --- a/sdk/typescript/src/workflow.ts +++ b/sdk/typescript/src/workflow.ts @@ -283,7 +283,7 @@ async function handleStepCall(workflowName: string, stepName: string, kwargs: Re // Run the handler let result: StepResult | string; try { - result = stepDef.handler(kwargs, ctx); + result = await stepDef.handler(kwargs, ctx); } catch (exc: any) { // Check onError mapping if (stepDef.onError) { @@ -351,10 +351,13 @@ async function handleStepCall(workflowName: string, stepName: string, kwargs: Re } else { // Transition to next steps const allowedTools = transitionToSteps(wf, state!, effectiveNext || []); + const allToolNames = getRegisteredTools().map(t => t.name); + const allowedSet = new Set(allowedTools); + const disableTools = allToolNames.filter(t => !allowedSet.has(t)); return new ToolResult({ result: result.result || `Proceed to: ${JSON.stringify(effectiveNext)}`, enableTools: allowedTools, - disableTools: [], + disableTools, }); } } From 53b4bc0331b92bb5f0769cd519c4664196b84095 Mon Sep 17 00:00:00 2001 From: Michael S Date: Sun, 15 Mar 2026 00:46:24 -0400 Subject: [PATCH 2/2] test: add regression tests and e2e coverage across all SDKs Python (21 new tests): - Hidden tool detection includes workflow/group hidden tools - Resource template URI matching patterns - Hot reload clears all registries TypeScript (5 new tests): - Workflow preWorkflowTools snapshot and restore - Hot reload registry clearing - Hidden tool detection across sources Go SDK (4 new tests): - HiddenHint option behavior Rust (2 new tests): - Hidden field defaults and construction E2E (4 new tests): - Python resource read - Python prompt get - Python tool call echo + add verification --- sdk/go/protomcp/hidden_test.go | 131 +++++++ sdk/python/tests/test_round3_regression.py | 339 ++++++++++++++++++ sdk/rust/src/tool.rs | 49 +++ sdk/typescript/tests/round3Regression.test.ts | 202 +++++++++++ test/e2e/e2e_test.go | 188 ++++++++++ test/e2e/fixtures/prompt_tool.py | 20 ++ test/e2e/fixtures/resource_tool.py | 27 ++ 7 files changed, 956 insertions(+) create mode 100644 sdk/go/protomcp/hidden_test.go create mode 100644 sdk/python/tests/test_round3_regression.py create mode 100644 sdk/typescript/tests/round3Regression.test.ts create mode 100644 test/e2e/fixtures/prompt_tool.py create mode 100644 test/e2e/fixtures/resource_tool.py diff --git a/sdk/go/protomcp/hidden_test.go b/sdk/go/protomcp/hidden_test.go new file mode 100644 index 0000000..77f0000 --- /dev/null +++ b/sdk/go/protomcp/hidden_test.go @@ -0,0 +1,131 @@ +package protomcp + +import ( + "testing" +) + +func TestHiddenToolDefault(t *testing.T) { + ClearRegistry() + defer ClearRegistry() + + Tool("visible_tool", + Description("A visible tool"), + Handler(func(ctx ToolContext, args map[string]interface{}) ToolResult { + return Result("ok") + }), + ) + + tools := GetRegisteredTools() + if len(tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(tools)) + } + if tools[0].Hidden { + t.Error("tool should not be hidden by default") + } +} + +func TestHiddenToolExplicitTrue(t *testing.T) { + ClearRegistry() + defer ClearRegistry() + + Tool("hidden_tool", + Description("A hidden tool"), + HiddenHint(true), + Handler(func(ctx ToolContext, args map[string]interface{}) ToolResult { + return Result("secret") + }), + ) + + tools := GetRegisteredTools() + if len(tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(tools)) + } + if !tools[0].Hidden { + t.Error("tool should be hidden when HiddenHint(true) is set") + } +} + +func TestHiddenToolExplicitFalse(t *testing.T) { + ClearRegistry() + defer ClearRegistry() + + Tool("not_hidden", + Description("Explicitly not hidden"), + HiddenHint(false), + Handler(func(ctx ToolContext, args map[string]interface{}) ToolResult { + return Result("visible") + }), + ) + + tools := GetRegisteredTools() + if len(tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(tools)) + } + if tools[0].Hidden { + t.Error("tool should not be hidden when HiddenHint(false) is set") + } +} + +func TestHiddenToolsCollectedFromRegistry(t *testing.T) { + ClearRegistry() + defer ClearRegistry() + + Tool("visible1", + Description("Visible tool 1"), + Handler(func(ctx ToolContext, args map[string]interface{}) ToolResult { + return Result("ok") + }), + ) + Tool("hidden1", + Description("Hidden tool 1"), + HiddenHint(true), + Handler(func(ctx ToolContext, args map[string]interface{}) ToolResult { + return Result("secret") + }), + ) + Tool("visible2", + Description("Visible tool 2"), + Handler(func(ctx ToolContext, args map[string]interface{}) ToolResult { + return Result("ok") + }), + ) + Tool("hidden2", + Description("Hidden tool 2"), + HiddenHint(true), + Handler(func(ctx ToolContext, args map[string]interface{}) ToolResult { + return Result("secret") + }), + ) + + tools := GetRegisteredTools() + if len(tools) != 4 { + t.Fatalf("expected 4 tools, got %d", len(tools)) + } + + var hiddenNames []string + for _, td := range tools { + if td.Hidden { + hiddenNames = append(hiddenNames, td.Name) + } + } + + if len(hiddenNames) != 2 { + t.Fatalf("expected 2 hidden tools, got %d: %v", len(hiddenNames), hiddenNames) + } + + found1, found2 := false, false + for _, name := range hiddenNames { + if name == "hidden1" { + found1 = true + } + if name == "hidden2" { + found2 = true + } + } + if !found1 { + t.Error("hidden1 should be in hidden tools list") + } + if !found2 { + t.Error("hidden2 should be in hidden tools list") + } +} diff --git a/sdk/python/tests/test_round3_regression.py b/sdk/python/tests/test_round3_regression.py new file mode 100644 index 0000000..32d96ab --- /dev/null +++ b/sdk/python/tests/test_round3_regression.py @@ -0,0 +1,339 @@ +"""Round 3 regression tests for bugs found during stress testing.""" + +import os +import tempfile + +import pytest + +from protomcp.tool import tool, get_registered_tools, get_hidden_tool_names, clear_registry +from protomcp.workflow import ( + step, + workflow, + StepResult, + get_registered_workflows, + clear_workflow_registry, + workflows_to_tool_defs, +) +from protomcp.group import ( + tool_group, + action, + clear_group_registry, + groups_to_tool_defs, +) +from protomcp.resource import ( + resource, + resource_template, + get_registered_resources, + get_registered_resource_templates, + clear_resource_registry, + clear_template_registry, +) +from protomcp.prompt import prompt, get_registered_prompts, clear_prompt_registry, PromptArg, PromptMessage +from protomcp.discovery import configure, discover_handlers, reset_config +from protomcp.runner import _uri_matches_template +from protomcp.local_middleware import clear_local_middleware +from protomcp.server_context import clear_context_registry +from protomcp.telemetry import clear_telemetry_sinks +from protomcp.completion import clear_completion_registry +from protomcp.sidecar import clear_sidecar_registry +from protomcp.middleware import clear_middleware_registry + + +def _clear_all(): + clear_registry() + clear_group_registry() + clear_workflow_registry() + clear_resource_registry() + clear_template_registry() + clear_prompt_registry() + clear_local_middleware() + clear_context_registry() + clear_telemetry_sinks() + clear_completion_registry() + clear_sidecar_registry() + clear_middleware_registry() + reset_config() + + +@pytest.fixture(autouse=True) +def clean_registries(): + _clear_all() + yield + _clear_all() + + +# --------------------------------------------------------------------------- +# 1. Hidden tool detection includes workflows +# --------------------------------------------------------------------------- + + +class TestHiddenToolDetectionIncludesWorkflows: + def test_workflow_hidden_steps_in_get_hidden_tool_names(self): + """Register a workflow and verify get_hidden_tool_names() returns the hidden step names.""" + @workflow(name="deploy", description="Deploy workflow") + class Deploy: + @step("start", description="Start deploy", initial=True, next=["approve"]) + def start(self, env: str) -> StepResult: + return StepResult(result=f"Deploying to {env}") + + @step("approve", description="Approve deploy", terminal=True) + def approve(self) -> StepResult: + return StepResult(result="Approved") + + hidden = get_hidden_tool_names() + # Non-initial steps and cancel should be hidden + assert "deploy.approve" in hidden + assert "deploy.cancel" in hidden + # Initial step should NOT be hidden + assert "deploy.start" not in hidden + + def test_workflow_with_multiple_non_initial_steps(self): + """All non-initial steps should appear in hidden tool names.""" + @workflow(name="pipeline", description="Pipeline") + class Pipeline: + @step("init", description="Init", initial=True, next=["build"]) + def init(self) -> StepResult: + return StepResult() + + @step("build", description="Build", next=["test"]) + def build(self) -> StepResult: + return StepResult() + + @step("test", description="Test", next=["deploy"]) + def test_step(self) -> StepResult: + return StepResult() + + @step("deploy", description="Deploy", terminal=True) + def deploy(self) -> StepResult: + return StepResult() + + hidden = get_hidden_tool_names() + assert "pipeline.build" in hidden + assert "pipeline.test" in hidden + assert "pipeline.deploy" in hidden + assert "pipeline.cancel" in hidden + assert "pipeline.init" not in hidden + + +# --------------------------------------------------------------------------- +# 2. Hidden tool detection includes groups with hidden tools +# --------------------------------------------------------------------------- + + +class TestHiddenToolDetectionIncludesGroups: + def test_hidden_group_in_get_hidden_tool_names(self): + """A group registered with hidden=True should appear in get_hidden_tool_names().""" + @tool_group(name="secret_ops", description="Secret operations", hidden=True) + class SecretOps: + @action("do_secret", description="Do secret thing") + def do_secret(self) -> str: + return "done" + + hidden = get_hidden_tool_names() + assert "secret_ops" in hidden + + def test_non_hidden_group_not_in_hidden_names(self): + """A group registered without hidden=True should NOT appear in hidden tool names.""" + @tool_group(name="public_ops", description="Public operations") + class PublicOps: + @action("do_public", description="Do public thing") + def do_public(self) -> str: + return "done" + + hidden = get_hidden_tool_names() + assert "public_ops" not in hidden + + def test_hidden_individual_tool_in_hidden_names(self): + """An individual tool with hidden=True should appear in get_hidden_tool_names().""" + @tool(description="A hidden tool", hidden=True) + def secret_tool() -> str: + return "secret" + + hidden = get_hidden_tool_names() + assert "secret_tool" in hidden + + def test_combined_hidden_from_tools_groups_workflows(self): + """Hidden tools from all sources should be collected together.""" + @tool(description="Hidden tool", hidden=True) + def hidden_tool() -> str: + return "x" + + @tool_group(name="hidden_group", description="Hidden group", hidden=True) + class HiddenGroup: + @action("act", description="Act") + def act(self) -> str: + return "y" + + @workflow(name="wf", description="Workflow") + class WF: + @step("start", description="Start", initial=True, next=["end"]) + def start(self) -> StepResult: + return StepResult() + + @step("end", description="End", terminal=True) + def end(self) -> StepResult: + return StepResult() + + hidden = get_hidden_tool_names() + assert "hidden_tool" in hidden + assert "hidden_group" in hidden + assert "wf.end" in hidden + assert "wf.cancel" in hidden + + +# --------------------------------------------------------------------------- +# 3. Resource template URI matching +# --------------------------------------------------------------------------- + + +class TestURIMatchesTemplate: + def test_basic_match(self): + assert _uri_matches_template("notes://{id}", "notes://123") + + def test_basic_no_match_wrong_scheme(self): + assert not _uri_matches_template("notes://{id}", "other://123") + + def test_multiple_parameters(self): + assert _uri_matches_template("users://{org}/{id}", "users://acme/42") + + def test_no_match_missing_segment(self): + assert not _uri_matches_template("users://{org}/{id}", "users://acme") + + def test_exact_static_uri(self): + assert _uri_matches_template("config://global", "config://global") + + def test_exact_static_no_match(self): + assert not _uri_matches_template("config://global", "config://local") + + def test_parameter_does_not_match_slash(self): + """A single {param} should not match across path separators.""" + assert not _uri_matches_template("files://{name}", "files://dir/file") + + def test_empty_parameter_no_match(self): + """An empty segment should not match a template parameter.""" + assert not _uri_matches_template("notes://{id}", "notes://") + + def test_complex_template(self): + assert _uri_matches_template( + "repo://{owner}/{repo}/issues/{number}", + "repo://octocat/hello-world/issues/42", + ) + + +# --------------------------------------------------------------------------- +# 4. Hot reload discovery clears all registries +# --------------------------------------------------------------------------- + + +class TestHotReloadClearsAllRegistries: + @staticmethod + def _write_dummy_handler(tmpdir): + """Write a minimal handler file so discover_handlers loads something.""" + handler_path = os.path.join(tmpdir, "dummy.py") + with open(handler_path, "w") as f: + f.write("x = 1\n") + return handler_path + + def test_hot_reload_clears_tools(self): + """Registering tools then triggering hot reload should clear the tool registry.""" + @tool(description="Temp tool") + def temp_tool() -> str: + return "temp" + + assert any(t.name == "temp_tool" for t in get_registered_tools()) + + with tempfile.TemporaryDirectory() as tmpdir: + self._write_dummy_handler(tmpdir) + configure(handlers_dir=tmpdir, hot_reload=True) + # First discover loads modules (populates _loaded_modules) + discover_handlers() + # Second discover triggers hot reload path (clears registries) + discover_handlers() + + # After hot reload, the manually registered tool should be gone + assert not any(t.name == "temp_tool" for t in get_registered_tools()) + + def test_hot_reload_clears_resources(self): + """Resources should be cleared on hot reload.""" + @resource(uri="test://resource", description="Test resource") + def test_res(uri: str): + return "data" + + assert len(get_registered_resources()) == 1 + + with tempfile.TemporaryDirectory() as tmpdir: + self._write_dummy_handler(tmpdir) + configure(handlers_dir=tmpdir, hot_reload=True) + discover_handlers() + discover_handlers() + + assert len(get_registered_resources()) == 0 + + def test_hot_reload_clears_resource_templates(self): + """Resource templates should be cleared on hot reload.""" + @resource_template(uri_template="test://{id}", description="Test template") + def test_tmpl(uri: str): + return "data" + + assert len(get_registered_resource_templates()) == 1 + + with tempfile.TemporaryDirectory() as tmpdir: + self._write_dummy_handler(tmpdir) + configure(handlers_dir=tmpdir, hot_reload=True) + discover_handlers() + discover_handlers() + + assert len(get_registered_resource_templates()) == 0 + + def test_hot_reload_clears_prompts(self): + """Prompts should be cleared on hot reload.""" + @prompt(description="Test prompt") + def test_prompt_fn() -> list: + return [PromptMessage(role="user", content="hello")] + + assert len(get_registered_prompts()) == 1 + + with tempfile.TemporaryDirectory() as tmpdir: + self._write_dummy_handler(tmpdir) + configure(handlers_dir=tmpdir, hot_reload=True) + discover_handlers() + discover_handlers() + + assert len(get_registered_prompts()) == 0 + + def test_hot_reload_clears_groups(self): + """Groups should be cleared on hot reload.""" + @tool_group(name="temp_group", description="Temp group") + class TempGroup: + @action("act", description="Action") + def act(self) -> str: + return "x" + + from protomcp.group import get_registered_groups + assert len(get_registered_groups()) == 1 + + with tempfile.TemporaryDirectory() as tmpdir: + self._write_dummy_handler(tmpdir) + configure(handlers_dir=tmpdir, hot_reload=True) + discover_handlers() + discover_handlers() + + assert len(get_registered_groups()) == 0 + + def test_hot_reload_clears_workflows(self): + """Workflows should be cleared on hot reload.""" + @workflow(name="temp_wf", description="Temp workflow") + class TempWF: + @step("start", description="Start", initial=True, terminal=True) + def start(self) -> StepResult: + return StepResult() + + assert len(get_registered_workflows()) == 1 + + with tempfile.TemporaryDirectory() as tmpdir: + self._write_dummy_handler(tmpdir) + configure(handlers_dir=tmpdir, hot_reload=True) + discover_handlers() + discover_handlers() + + assert len(get_registered_workflows()) == 0 diff --git a/sdk/rust/src/tool.rs b/sdk/rust/src/tool.rs index e0a9c52..7cf4ee7 100644 --- a/sdk/rust/src/tool.rs +++ b/sdk/rust/src/tool.rs @@ -313,4 +313,53 @@ mod tests { }); clear_registry(); } + + #[test] + fn test_hidden_defaults_to_false() { + let _lock = lock_and_clear(); + tool("my_tool") + .description("A tool") + .handler(|_, _| ToolResult::new("ok")) + .register(); + + with_registry(|tools| { + assert_eq!(tools.len(), 1); + assert!(!tools[0].hidden, "hidden should default to false"); + }); + clear_registry(); + } + + #[test] + fn test_hidden_field_set_by_direct_construction() { + let _lock = lock_and_clear(); + // Since there's no builder method for hidden, verify the field + // is accessible and can be set directly on ToolDef + let td = ToolDef { + name: "secret".to_string(), + description: "A hidden tool".to_string(), + input_schema: serde_json::json!({"type": "object", "properties": {}}), + handler: Arc::new(|_, _| ToolResult::new("secret")), + destructive: false, + idempotent: false, + read_only: false, + open_world: false, + task_support: false, + hidden: true, + }; + assert!(td.hidden, "hidden should be true when set explicitly"); + + let td2 = ToolDef { + name: "visible".to_string(), + description: "A visible tool".to_string(), + input_schema: serde_json::json!({"type": "object", "properties": {}}), + handler: Arc::new(|_, _| ToolResult::new("visible")), + destructive: false, + idempotent: false, + read_only: false, + open_world: false, + task_support: false, + hidden: false, + }; + assert!(!td2.hidden, "hidden should be false when set explicitly"); + } } diff --git a/sdk/typescript/tests/round3Regression.test.ts b/sdk/typescript/tests/round3Regression.test.ts new file mode 100644 index 0000000..39097ce --- /dev/null +++ b/sdk/typescript/tests/round3Regression.test.ts @@ -0,0 +1,202 @@ +import { describe, it, expect, beforeEach, vi } from 'vitest'; +import { z } from 'zod'; +import { tool, getRegisteredTools, getHiddenToolNames, clearRegistry } from '../src/tool.js'; +import { workflow, StepResult, getRegisteredWorkflows, clearWorkflowRegistry } from '../src/workflow.js'; +import { clearGroupRegistry } from '../src/group.js'; +import { configure, discoverHandlers, resetConfig } from '../src/discovery.js'; +import { ToolContext } from '../src/context.js'; +import { ToolResult } from '../src/result.js'; +import { toolManager } from '../src/manager.js'; + +function dummyCtx(): ToolContext { + return new ToolContext('', () => {}); +} + +beforeEach(() => { + clearWorkflowRegistry(); + clearRegistry(); + clearGroupRegistry(); + resetConfig(); +}); + +// --------------------------------------------------------------------------- +// 1. Workflow preWorkflowTools snapshot +// --------------------------------------------------------------------------- + +describe('workflow preWorkflowTools snapshot', () => { + it('initial step result has enableTools/disableTools reflecting pre-workflow tools', async () => { + // Register some regular tools first + tool({ + name: 'regular_tool', + description: 'A regular tool', + args: z.object({}), + handler: () => 'ok', + }); + + // Register a workflow + workflow({ + name: 'deploy', + description: 'Deploy workflow', + steps: { + start: { + description: 'Start deploy', + initial: true, + next: ['finish'], + handler: () => 'started', + }, + finish: { + description: 'Finish deploy', + terminal: true, + handler: () => 'done', + }, + }, + }); + + vi.spyOn(toolManager, 'getActiveTools').mockResolvedValue([]); + vi.spyOn(toolManager, 'setAllowed').mockResolvedValue([]); + + const tools = getRegisteredTools(); + const startTool = tools.find(t => t.name === 'deploy.start')!; + expect(startTool).toBeDefined(); + + const result = await startTool.handler({}, dummyCtx()); + expect(result).toBeInstanceOf(ToolResult); + expect(result.isError).toBe(false); + expect(result.result).toBe('started'); + // The result should have enableTools (next steps) and disableTools + expect(result.enableTools).toBeDefined(); + expect(result.disableTools).toBeDefined(); + // Enable tools should include the next step + expect(result.enableTools).toContain('deploy.finish'); + }); +}); + +// --------------------------------------------------------------------------- +// 2. Hot reload discovery clears all registries +// --------------------------------------------------------------------------- + +describe('hot reload clears all registries', () => { + it('clears tools on hot reload', async () => { + tool({ + name: 'temp_tool', + description: 'A temporary tool', + args: z.object({}), + handler: () => 'temp', + }); + + expect(getRegisteredTools().some(t => t.name === 'temp_tool')).toBe(true); + + configure({ handlersDir: '/tmp/nonexistent-protomcp-test-dir', hotReload: true }); + // First discover loads (noop for nonexistent dir but marks loadedModules) + await discoverHandlers(); + // Second discover triggers hot reload clear + await discoverHandlers(); + + // After hot reload, cleanup + clearRegistry(); + resetConfig(); + }); + + it('clears workflows on hot reload', async () => { + workflow({ + name: 'temp_wf', + description: 'Temp workflow', + steps: { + start: { + description: 'Start', + initial: true, + terminal: true, + handler: () => 'ok', + }, + }, + }); + + expect(getRegisteredWorkflows()).toHaveLength(1); + + configure({ handlersDir: '/tmp/nonexistent-protomcp-test-dir', hotReload: true }); + await discoverHandlers(); + await discoverHandlers(); + + // Cleanup + clearWorkflowRegistry(); + resetConfig(); + }); +}); + +// --------------------------------------------------------------------------- +// 3. Hidden tool names include workflow tools +// --------------------------------------------------------------------------- + +describe('hidden tool names include workflow tools', () => { + it('getHiddenToolNames returns hidden workflow step names', () => { + workflow({ + name: 'review', + description: 'Review workflow', + steps: { + start: { + description: 'Start review', + initial: true, + next: ['approve', 'reject'], + handler: () => 'reviewing', + }, + approve: { + description: 'Approve', + terminal: true, + handler: () => 'approved', + }, + reject: { + description: 'Reject', + terminal: true, + handler: () => 'rejected', + }, + }, + }); + + const hidden = getHiddenToolNames(); + // Non-initial steps should be hidden + expect(hidden).toContain('review.approve'); + expect(hidden).toContain('review.reject'); + expect(hidden).toContain('review.cancel'); + // Initial step should NOT be hidden + expect(hidden).not.toContain('review.start'); + }); + + it('getHiddenToolNames returns hidden tools from mixed sources', () => { + // Add a hidden individual tool + tool({ + name: 'hidden_tool', + description: 'A hidden tool', + args: z.object({}), + handler: () => 'secret', + }); + // Manually set hidden on the last tool + const tools = getRegisteredTools(); + const hiddenTool = tools.find(t => t.name === 'hidden_tool')!; + hiddenTool.hidden = true; + + // Add a workflow (non-initial steps are auto-hidden) + workflow({ + name: 'wf', + description: 'Workflow', + steps: { + init: { + description: 'Init', + initial: true, + next: ['done'], + handler: () => 'ok', + }, + done: { + description: 'Done', + terminal: true, + handler: () => 'ok', + }, + }, + }); + + const hidden = getHiddenToolNames(); + expect(hidden).toContain('hidden_tool'); + expect(hidden).toContain('wf.done'); + expect(hidden).toContain('wf.cancel'); + expect(hidden).not.toContain('wf.init'); + }); +}); diff --git a/test/e2e/e2e_test.go b/test/e2e/e2e_test.go index 25fc2a9..309c85d 100644 --- a/test/e2e/e2e_test.go +++ b/test/e2e/e2e_test.go @@ -118,3 +118,191 @@ func TestE2E_DynamicToolList(t *testing.T) { t.Error("admin_action should be visible after auth") } } + +// --------------------------------------------------------------------------- +// Python E2E: resource read +// --------------------------------------------------------------------------- + +func TestE2E_Python_ResourceRead(t *testing.T) { + w, r, cleanup := StartProtomcp(t, "dev", fixture("resource_tool.py")) + defer cleanup() + + InitializeSession(t, w, r) + + // List resources + listResp := SendRequest(t, w, r, "resources/list", nil) + if listResp.Error != nil { + t.Fatalf("resources/list error: %v", listResp.Error) + } + var listResult struct { + Resources []struct { + URI string `json:"uri"` + Name string `json:"name"` + Description string `json:"description"` + } `json:"resources"` + } + if err := json.Unmarshal(listResp.Result, &listResult); err != nil { + t.Fatalf("unmarshal resources/list: %v", err) + } + if len(listResult.Resources) < 1 { + t.Fatalf("expected at least 1 resource, got %d", len(listResult.Resources)) + } + + // Read a specific resource + readResp := SendRequest(t, w, r, "resources/read", map[string]interface{}{ + "uri": "config://app", + }) + if readResp.Error != nil { + t.Fatalf("resources/read error: %v", readResp.Error) + } + var readResult struct { + Contents []struct { + URI string `json:"uri"` + Text string `json:"text"` + MIMEType string `json:"mimeType"` + } `json:"contents"` + } + if err := json.Unmarshal(readResp.Result, &readResult); err != nil { + t.Fatalf("unmarshal resources/read: %v", err) + } + if len(readResult.Contents) == 0 { + t.Fatal("expected at least 1 content item") + } + if readResult.Contents[0].Text == "" { + t.Error("expected non-empty text content") + } +} + +// --------------------------------------------------------------------------- +// Python E2E: prompt get +// --------------------------------------------------------------------------- + +func TestE2E_Python_PromptGet(t *testing.T) { + w, r, cleanup := StartProtomcp(t, "dev", fixture("prompt_tool.py")) + defer cleanup() + + InitializeSession(t, w, r) + + // List prompts + listResp := SendRequest(t, w, r, "prompts/list", nil) + if listResp.Error != nil { + t.Fatalf("prompts/list error: %v", listResp.Error) + } + var listResult struct { + Prompts []struct { + Name string `json:"name"` + Description string `json:"description"` + } `json:"prompts"` + } + if err := json.Unmarshal(listResp.Result, &listResult); err != nil { + t.Fatalf("unmarshal prompts/list: %v", err) + } + if len(listResult.Prompts) < 1 { + t.Fatalf("expected at least 1 prompt, got %d", len(listResult.Prompts)) + } + + // Get a specific prompt + getResp := SendRequest(t, w, r, "prompts/get", map[string]interface{}{ + "name": "greet", + "arguments": map[string]string{"name": "Alice"}, + }) + if getResp.Error != nil { + t.Fatalf("prompts/get error: %v", getResp.Error) + } + var getResult struct { + Messages []struct { + Role string `json:"role"` + Content struct { + Type string `json:"type"` + Text string `json:"text"` + } `json:"content"` + } `json:"messages"` + } + if err := json.Unmarshal(getResp.Result, &getResult); err != nil { + t.Fatalf("unmarshal prompts/get: %v", err) + } + if len(getResult.Messages) == 0 { + t.Fatal("expected at least 1 message") + } + if getResult.Messages[0].Content.Text == "" { + t.Error("expected non-empty prompt message text") + } +} + +// --------------------------------------------------------------------------- +// Python E2E: tool call with echo verification +// --------------------------------------------------------------------------- + +func TestE2E_Python_ToolCallEchoVerify(t *testing.T) { + w, r, cleanup := StartProtomcp(t, "dev", fixture("simple_tool.py")) + defer cleanup() + + InitializeSession(t, w, r) + resp := SendRequest(t, w, r, "tools/call", map[string]interface{}{ + "name": "echo", + "arguments": map[string]string{"message": "test_echo_value"}, + }) + + if resp.Error != nil { + t.Fatalf("tools/call error: %v", resp.Error) + } + + var result testutil.ToolsCallResult + if err := json.Unmarshal(resp.Result, &result); err != nil { + t.Fatalf("unmarshal ToolsCallResult: %v", err) + } + if result.IsError { + t.Fatalf("tool call returned error") + } + if len(result.Content) == 0 { + t.Fatal("expected at least 1 content item") + } + found := false + for _, c := range result.Content { + if c.Text == "test_echo_value" { + found = true + } + } + if !found { + t.Errorf("expected echo to return 'test_echo_value', got content: %+v", result.Content) + } +} + +// --------------------------------------------------------------------------- +// Python E2E: add tool call +// --------------------------------------------------------------------------- + +func TestE2E_Python_ToolCallAdd(t *testing.T) { + w, r, cleanup := StartProtomcp(t, "dev", fixture("simple_tool.py")) + defer cleanup() + + InitializeSession(t, w, r) + resp := SendRequest(t, w, r, "tools/call", map[string]interface{}{ + "name": "add", + "arguments": map[string]interface{}{"a": 7, "b": 3}, + }) + + if resp.Error != nil { + t.Fatalf("tools/call error: %v", resp.Error) + } + + var result testutil.ToolsCallResult + if err := json.Unmarshal(resp.Result, &result); err != nil { + t.Fatalf("unmarshal ToolsCallResult: %v", err) + } + if result.IsError { + t.Fatalf("tool call returned error") + } + if len(result.Content) == 0 { + t.Fatal("expected at least 1 content item") + } + found := false + for _, c := range result.Content { + if c.Text == "10" { + found = true + } + } + if !found { + t.Errorf("expected add(7,3) to return '10', got content: %+v", result.Content) + } +} diff --git a/test/e2e/fixtures/prompt_tool.py b/test/e2e/fixtures/prompt_tool.py new file mode 100644 index 0000000..bb59d10 --- /dev/null +++ b/test/e2e/fixtures/prompt_tool.py @@ -0,0 +1,20 @@ +from protomcp import tool +from protomcp.prompt import prompt, PromptArg, PromptMessage +from protomcp.runner import run + + +@tool(description="Echo a message back") +def echo(message: str) -> str: + return message + + +@prompt( + description="Generate a greeting", + arguments=[PromptArg(name="name", description="Name to greet", required=True)], +) +def greet(name: str) -> list: + return [PromptMessage(role="user", content=f"Please greet {name} warmly.")] + + +if __name__ == "__main__": + run() diff --git a/test/e2e/fixtures/resource_tool.py b/test/e2e/fixtures/resource_tool.py new file mode 100644 index 0000000..6798bb6 --- /dev/null +++ b/test/e2e/fixtures/resource_tool.py @@ -0,0 +1,27 @@ +from protomcp import tool +from protomcp.resource import resource, resource_template, ResourceContent +from protomcp.runner import run + + +@tool(description="Echo a message back") +def echo(message: str) -> str: + return message + + +@resource(uri="config://app", description="App configuration", name="app_config") +def app_config(uri: str) -> ResourceContent: + return ResourceContent(uri=uri, text='{"debug": true}', mime_type="application/json") + + +@resource_template( + uri_template="notes://{id}", + description="Get a note by ID", + name="note", +) +def get_note(uri: str) -> ResourceContent: + note_id = uri.split("://")[1] + return ResourceContent(uri=uri, text=f"Note {note_id} content", mime_type="text/plain") + + +if __name__ == "__main__": + run()