Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions sdk/go/protomcp/hidden_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package protomcp

import (
"testing"
)

func TestHiddenToolDefault(t *testing.T) {
ClearRegistry()
defer ClearRegistry()

Tool("visible_tool",
Description("A visible tool"),
Handler(func(ctx ToolContext, args map[string]interface{}) ToolResult {
return Result("ok")
}),
)

tools := GetRegisteredTools()
if len(tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(tools))
}
if tools[0].Hidden {
t.Error("tool should not be hidden by default")
}
}

func TestHiddenToolExplicitTrue(t *testing.T) {
ClearRegistry()
defer ClearRegistry()

Tool("hidden_tool",
Description("A hidden tool"),
HiddenHint(true),
Handler(func(ctx ToolContext, args map[string]interface{}) ToolResult {
return Result("secret")
}),
)

tools := GetRegisteredTools()
if len(tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(tools))
}
if !tools[0].Hidden {
t.Error("tool should be hidden when HiddenHint(true) is set")
}
}

func TestHiddenToolExplicitFalse(t *testing.T) {
ClearRegistry()
defer ClearRegistry()

Tool("not_hidden",
Description("Explicitly not hidden"),
HiddenHint(false),
Handler(func(ctx ToolContext, args map[string]interface{}) ToolResult {
return Result("visible")
}),
)

tools := GetRegisteredTools()
if len(tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(tools))
}
if tools[0].Hidden {
t.Error("tool should not be hidden when HiddenHint(false) is set")
}
}

func TestHiddenToolsCollectedFromRegistry(t *testing.T) {
ClearRegistry()
defer ClearRegistry()

Tool("visible1",
Description("Visible tool 1"),
Handler(func(ctx ToolContext, args map[string]interface{}) ToolResult {
return Result("ok")
}),
)
Tool("hidden1",
Description("Hidden tool 1"),
HiddenHint(true),
Handler(func(ctx ToolContext, args map[string]interface{}) ToolResult {
return Result("secret")
}),
)
Tool("visible2",
Description("Visible tool 2"),
Handler(func(ctx ToolContext, args map[string]interface{}) ToolResult {
return Result("ok")
}),
)
Tool("hidden2",
Description("Hidden tool 2"),
HiddenHint(true),
Handler(func(ctx ToolContext, args map[string]interface{}) ToolResult {
return Result("secret")
}),
)

tools := GetRegisteredTools()
if len(tools) != 4 {
t.Fatalf("expected 4 tools, got %d", len(tools))
}

var hiddenNames []string
for _, td := range tools {
if td.Hidden {
hiddenNames = append(hiddenNames, td.Name)
}
}

if len(hiddenNames) != 2 {
t.Fatalf("expected 2 hidden tools, got %d: %v", len(hiddenNames), hiddenNames)
}

found1, found2 := false, false
for _, name := range hiddenNames {
if name == "hidden1" {
found1 = true
}
if name == "hidden2" {
found2 = true
}
}
if !found1 {
t.Error("hidden1 should be in hidden tools list")
}
if !found2 {
t.Error("hidden2 should be in hidden tools list")
}
}
34 changes: 33 additions & 1 deletion sdk/go/protomcp/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"os"
"strconv"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -315,6 +316,37 @@ func handleReload(tp *Transport, reqID string) {
sendMiddlewareRegistrations(tp)
}

func uriMatchesTemplate(template, uri string) bool {
parts := strings.Split(template, "{")
if len(parts) == 0 {
return template == uri
}
if !strings.HasPrefix(uri, parts[0]) {
return false
}
remaining := uri[len(parts[0]):]
for _, part := range parts[1:] {
closeBrace := strings.Index(part, "}")
if closeBrace < 0 {
return false
}
suffix := part[closeBrace+1:]
if suffix == "" {
if remaining == "" {
return false
}
remaining = ""
} else {
idx := strings.Index(remaining, suffix)
if idx <= 0 {
return false
}
remaining = remaining[idx+len(suffix):]
}
}
return remaining == ""
}

func handleListResources(tp *Transport, reqID string) {
resources := GetRegisteredResources()
var defs []*pb.ResourceDefinition
Expand Down Expand Up @@ -368,7 +400,7 @@ func handleReadResource(tp *Transport, req *pb.ReadResourceRequest, reqID string

// Try resource templates.
for _, t := range GetRegisteredResourceTemplates() {
if t.HandlerFn != nil {
if t.HandlerFn != nil && uriMatchesTemplate(t.URITemplate, uri) {
contents := t.HandlerFn(uri)
if len(contents) > 0 {
sendResourceContents(tp, reqID, contents)
Expand Down
58 changes: 46 additions & 12 deletions sdk/go/protomcp/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,11 +310,7 @@ var ToolManagerAdapter struct {
SetAllowed func([]string)
}

func transitionToSteps(wf *WorkflowDef, state *WorkflowState, nextStepNames []string) {
if ToolManagerAdapter.SetAllowed == nil {
return
}

func transitionToSteps(wf *WorkflowDef, state *WorkflowState, nextStepNames []string) []string {
stepMap := map[string]*StepDef{}
for i := range wf.Steps {
stepMap[wf.Steps[i].Name] = &wf.Steps[i]
Expand Down Expand Up @@ -347,7 +343,11 @@ func transitionToSteps(wf *WorkflowDef, state *WorkflowState, nextStepNames []st
}
}

ToolManagerAdapter.SetAllowed(allowedTools)
if ToolManagerAdapter.SetAllowed != nil {
ToolManagerAdapter.SetAllowed(allowedTools)
}

return allowedTools
}

// --- Step dispatch ---
Expand Down Expand Up @@ -403,8 +403,22 @@ func HandleStepCall(workflowName, stepName string, ctx ToolContext, args map[str
for errKey, targetStep := range stepDef.OnError {
if strings.Contains(errStr, errKey) {
state.CurrentStep = targetStep
transitionToSteps(wf, state, []string{targetStep})
return Result(fmt.Sprintf("Error caught (%s), transitioning to '%s'", errStr, targetStep))
allowedTools := transitionToSteps(wf, state, []string{targetStep})
allTools := GetRegisteredTools()
allowedSet := map[string]bool{}
for _, t := range allowedTools {
allowedSet[t] = true
}
var disableTools []string
for _, t := range allTools {
if !allowedSet[t.Name] {
disableTools = append(disableTools, t.Name)
}
}
r := Result(fmt.Sprintf("Error caught (%s), transitioning to '%s'", errStr, targetStep))
r.EnableTools = allowedTools
r.DisableTools = disableTools
return r
}
}
}
Expand Down Expand Up @@ -469,16 +483,33 @@ func HandleStepCall(workflowName, stepName string, ctx ToolContext, args map[str
if resultText == "" {
resultText = "Workflow complete"
}
return Result(resultText)
r := Result(resultText)
r.EnableTools = state.PreWorkflowTools
r.DisableTools = []string{}
return r
}

// Transition to next steps
transitionToSteps(wf, state, effectiveNext)
allowedTools := transitionToSteps(wf, state, effectiveNext)
allTools := GetRegisteredTools()
allowedSet := map[string]bool{}
for _, t := range allowedTools {
allowedSet[t] = true
}
var disableTools []string
for _, t := range allTools {
if !allowedSet[t.Name] {
disableTools = append(disableTools, t.Name)
}
}
resultText := result.Result
if resultText == "" {
resultText = fmt.Sprintf("Proceed to: %v", effectiveNext)
}
return Result(resultText)
r := Result(resultText)
r.EnableTools = allowedTools
r.DisableTools = disableTools
return r
}

// HandleCancel handles a cancel tool call for the given workflow.
Expand All @@ -505,7 +536,10 @@ func HandleCancel(workflowName string) ToolResult {
ToolManagerAdapter.SetAllowed(state.PreWorkflowTools)
}
activeWorkflowStack = activeWorkflowStack[:len(activeWorkflowStack)-1]
return Result(fmt.Sprintf("Workflow '%s' cancelled", workflowName))
r := Result(fmt.Sprintf("Workflow '%s' cancelled", workflowName))
r.EnableTools = state.PreWorkflowTools
r.DisableTools = []string{}
return r
}

// --- Tool generation ---
Expand Down
5 changes: 3 additions & 2 deletions sdk/python/src/protomcp/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,9 @@ def _handle_reload(transport, env, mw_handlers):
request_id=env.request_id,
)
transport.send(resp)
# Also re-send tool list
_handle_list_tools(transport, env)
# Also re-send tool list with empty request_id so it routes to handshakeCh
fake_env = pb.Envelope() # empty envelope (request_id defaults to "")
_handle_list_tools(transport, fake_env)
_send_middleware_registrations(transport, mw_handlers)

def _send_middleware_registrations(transport, mw_handlers):
Expand Down
Loading
Loading