diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 5860814..1bd155b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -23,6 +23,12 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} HOMEBREW_TAP_TOKEN: ${{ secrets.HOMEBREW_TAP_TOKEN }} + - name: Create SDK Go submodule tag + if: success() + run: | + TAG="${GITHUB_REF#refs/tags/}" + git tag -f "sdk/go/${TAG}" + git push origin "sdk/go/${TAG}" --force pypi: runs-on: ubuntu-latest diff --git a/.gitignore b/.gitignore index 33c2046..3bcbc62 100644 --- a/.gitignore +++ b/.gitignore @@ -7,5 +7,7 @@ __pycache__/ .superpowers/ docs/superpowers/ sdk/rust/target/ +go.work +go.work.sum /protomcp docs/.astro/ diff --git a/.goreleaser.yml b/.goreleaser.yml index f725a5a..5e6fbcf 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -2,6 +2,8 @@ version: 2 builds: - main: ./cmd/protomcp binary: pmcp + ldflags: + - -s -w -X main.version={{.Version}} goos: - linux - darwin diff --git a/README.md b/README.md index 75df321..b6506d0 100644 --- a/README.md +++ b/README.md @@ -207,7 +207,7 @@ See the [full documentation](https://msilverblatt.github.io/protomcp/) for detai ## Tool Groups -Real-world MCP tools tend to accumulate dozens of parameters behind a single endpoint. Tool groups let you split actions into clean, per-action schemas while exposing a single tool with a discriminated union (`oneOf`) to the LLM. +Real-world MCP tools tend to accumulate dozens of parameters behind a single endpoint. Tool groups let you split actions into clean, per-action schemas. By default, each action becomes its own tool (e.g. `db.query`, `db.insert`) — the **separate** strategy. For clients that support `oneOf` schemas, the **union** strategy is available as an opt-in, exposing all actions as a single tool with a discriminated union. 
**Before** -- one tool, 20+ params: @@ -407,7 +407,7 @@ See [`examples/`](examples/) for runnable demos: - **Basic** — minimal tool examples in all four languages - **Resources & Prompts** — resources, prompts, completions, and tools together - **Full showcase** — structured output, progress, cancellation, dynamic tool lists, error handling -- **Tool Groups** — per-action schemas with union and separate strategies +- **Tool Groups** — per-action schemas with separate (default) and union strategies - **Advanced Server** — middleware, telemetry, server context working together - **Workflows** — deployment pipeline as a server-defined state machine diff --git a/cmd/protomcp/main.go b/cmd/protomcp/main.go index 4252a13..d72884d 100644 --- a/cmd/protomcp/main.go +++ b/cmd/protomcp/main.go @@ -7,6 +7,7 @@ import ( "net/http" "os" "os/signal" + "path/filepath" "sync" "time" @@ -23,6 +24,8 @@ import ( "github.com/msilverblatt/protomcp/internal/validate" ) +var version = "dev" + func main() { cfg, err := config.Parse(os.Args[1:]) if err != nil { @@ -109,7 +112,7 @@ func main() { backend := &toolBackend{pm: pm, tlm: tlm, allTools: tools} // 5. Create bridge (replaces custom mcp.NewHandler) - b := bridge.New(backend, logger) + b := bridge.New(backend, logger, version) b.SetToolListMutationHandler(func(enable, disable []string) { if len(enable) > 0 { tlm.Enable(enable) @@ -121,9 +124,7 @@ func main() { }) // 6. Sync tools, resources, and prompts from backend into the official mcp.Server - b.SyncTools() - b.SyncResources() - b.SyncPrompts() + b.SyncAll() // 7. Wire process manager callbacks pm.OnProgress(func(msg *pb.ProgressNotification) { @@ -145,7 +146,8 @@ func main() { // 8. 
Start file watcher (dev mode only) if cfg.Command == "dev" { - w, err := reload.NewWatcher(cfg.File, nil, func() { + ext := filepath.Ext(cfg.File) + w, err := reload.NewWatcher(filepath.Dir(cfg.File), []string{ext}, func() { slog.Info("file changed, reloading...") newTools, err := pm.Reload(ctx) if err != nil { @@ -163,9 +165,7 @@ func main() { if !slicesEqual(oldActive, newActive) { slog.Info("tool list changed, syncing tools") } - b.SyncTools() - b.SyncResources() - b.SyncPrompts() + b.SyncAll() }) if err != nil { slog.Error("failed to create file watcher", "error", err) diff --git a/docs/src/content/docs/guides/hot-reload.mdx b/docs/src/content/docs/guides/hot-reload.mdx index 195cf04..0934570 100644 --- a/docs/src/content/docs/guides/hot-reload.mdx +++ b/docs/src/content/docs/guides/hot-reload.mdx @@ -6,15 +6,14 @@ description: How protomcp reloads tool processes when source files change. ## How it works -protomcp watches your tool file for changes. When a change is detected: +`pmcp dev` watches the directory containing your entry-point file, recursively. When any matching file changes: -1. A `ReloadRequest` is sent to the tool process over the unix socket -2. The tool process handles the reload (re-registers tools, re-reads config, etc.) -3. The tool process replies with `ReloadResponse { success: true }` -4. protomcp sends `notifications/tools/list_changed` to the MCP host -5. The MCP host re-fetches the tool list - -The file is the same process — no restart, no re-connection. +1. The tool process is killed +2. Internal state (tools, resources, prompts) is reset +3. A fresh process is spawned +4. The new process performs a full handshake, registering its tools from scratch +5. pmcp diffs the new tool list against the old one and sends `notifications/tools/list_changed` if anything changed +6. 
Stale tools, resources, and prompts that no longer exist are automatically removed --- @@ -32,84 +31,28 @@ pmcp dev tools.ts It is disabled in `run` mode, which is intended for production. -The file watcher monitors the tool file path specified on the command line. Changes to imported modules are not automatically detected — only the entry file. +The watcher monitors the entire directory tree rooted at the entry file's parent directory. Changes to imported modules are detected automatically as long as they live under that directory. --- -## Reload modes - -### Default: graceful reload +## File watching details -```sh -# Python -pmcp dev tools.py - -# TypeScript -pmcp dev tools.ts -``` - -Sends `ReloadRequest` to the running process. The SDK handles this by re-executing the module-level code and re-registering all tools. - -### Immediate: process restart - -```sh -# Python -pmcp dev tools.py --hot-reload immediate - -# TypeScript -pmcp dev tools.ts --hot-reload immediate -``` - -Kills and restarts the tool process. Useful when: -- The language or runtime doesn't support graceful reload -- You want a clean slate every time (no stale module cache) -- You're using `.go` or `.rs` files that need recompilation +- **Debounce**: there is a 100ms debounce — rapid saves (e.g. editor auto-format on write) trigger only one reload. +- **New directories**: subdirectories created while dev mode is running are automatically watched. +- **Skipped directories**: `.git`, `node_modules`, `__pycache__`, `target`, and `venv` are ignored. --- ## In-flight calls -Calls that are in flight when a reload is triggered are not interrupted. protomcp waits for in-flight calls to complete before applying the reload. If a reload is requested while calls are running, it is queued. - ---- - -## SDK behavior on reload - -The Python and TypeScript SDKs re-run the tool registration code automatically. 
Since `@tool()` and `tool()` append to a global registry, the SDK clears the registry before re-running. - -If you have initialization code that should only run once (e.g. loading a model, connecting to a database), gate it with a module-level flag: - -```python -from protomcp import tool, ToolResult - -_db = None - -def _get_db(): - global _db - if _db is None: - _db = connect_to_db() - return _db - -@tool("Query the database") -def query(sql: str) -> ToolResult: - results = _get_db().execute(sql) - return ToolResult(result=str(results)) -``` - -The database connection is created on first use and reused across reloads. +Calls that are in flight when a change is detected are interrupted — the process is killed immediately. Waiting for an in-flight call to complete before the restart is not supported in dev mode. --- ## Gotchas -**Module-level side effects**: Code at module level runs on every reload. Avoid expensive operations (network calls, model loading) at module level without caching. - -**File-level watch only**: Only the entry file is watched. If you change an imported module, touch the entry file to trigger reload: - -```sh -touch tools.py -``` +**State is lost on every reload**: Because the process is killed and restarted, all in-memory state (caches, sessions, open connections) is lost. Re-initialize lazily if you need to survive across reloads. -**Immediate mode loses state**: With `--hot-reload immediate`, the process is killed and restarted. Any in-memory state (caches, sessions, etc.) is lost. +**Whole directory is watched**: Every file under the entry file's parent directory (filtered by extension) can trigger a reload. If your project directory contains generated or frequently-updated files, move them outside the watched tree, or place the entry file in its own subdirectory so that only that subdirectory is watched. -**Syntax errors**: If the tool file has a syntax error after reload, the `ReloadResponse` will contain `success: false` and an error message. 
The previous tool list remains active. +**Syntax errors at startup**: If the tool process fails to start or crashes during the handshake, pmcp will log the error and leave the previous tool list active until the next successful reload. diff --git a/examples/go/resources_and_prompts/go.mod b/examples/go/resources_and_prompts/go.mod index 3b078e4..5227b21 100644 --- a/examples/go/resources_and_prompts/go.mod +++ b/examples/go/resources_and_prompts/go.mod @@ -2,15 +2,6 @@ module resources-and-prompts-example go 1.25.6 -require github.com/msilverblatt/protomcp/sdk/go v0.0.0 - require ( - github.com/klauspost/compress v1.18.4 // indirect - github.com/msilverblatt/protomcp v0.0.0 // indirect - google.golang.org/protobuf v1.36.11 // indirect -) - -replace ( - github.com/msilverblatt/protomcp => ../../../ - github.com/msilverblatt/protomcp/sdk/go => ../../../sdk/go + github.com/msilverblatt/protomcp/sdk/go v0.2.0 ) diff --git a/go.mod b/go.mod index d33ea1f..de86a2c 100644 --- a/go.mod +++ b/go.mod @@ -1,14 +1,13 @@ module github.com/msilverblatt/protomcp -go 1.25.0 +go 1.25.6 require ( github.com/fsnotify/fsnotify v1.9.0 github.com/klauspost/compress v1.18.4 github.com/modelcontextprotocol/go-sdk v1.4.1 - golang.org/x/net v0.52.0 - google.golang.org/grpc v1.79.2 google.golang.org/protobuf v1.36.11 + nhooyr.io/websocket v1.8.17 ) require ( @@ -18,7 +17,5 @@ require ( github.com/yosida95/uritemplate/v3 v3.0.2 // indirect golang.org/x/oauth2 v0.34.0 // indirect golang.org/x/sys v0.42.0 // indirect - golang.org/x/text v0.35.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect - nhooyr.io/websocket v1.8.17 // indirect + golang.org/x/tools v0.42.0 // indirect ) diff --git a/go.sum b/go.sum index a0b53b1..9ebc152 100644 --- a/go.sum +++ b/go.sum @@ -1,21 +1,11 @@ -github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= -github.com/cespare/xxhash/v2 v2.3.0/go.mod 
h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= -github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= -github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= -github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= -github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= -github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= -github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8= github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= -github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= -github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c= github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/modelcontextprotocol/go-sdk v1.4.1 h1:M4x9GyIPj+HoIlHNGpK2hq5o3BFhC+78PkEaldQRphc= @@ -26,34 +16,12 @@ github.com/segmentio/encoding v0.5.4 h1:OW1VRern8Nw6ITAtwSZ7Idrl3MXCFwXHPgqESYfv github.com/segmentio/encoding v0.5.4/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= 
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= -go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= -go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= -go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48= -go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8= -go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0= -go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs= -go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18= -go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE= -go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8= -go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew= -go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI= -go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= -golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= -golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= -golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= -golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= -gonum.org/v1/gonum v0.16.0 
h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= -gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= -google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww= -google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= -google.golang.org/grpc v1.79.2 h1:fRMD94s2tITpyJGtBBn7MkMseNpOZU8ZxgC3MMBaXRU= -google.golang.org/grpc v1.79.2/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= nhooyr.io/websocket v1.8.17 h1:KEVeLJkUywCKVsnLIDlD/5gtayKp8VoCkksHCGGfT9Y= diff --git a/internal/bridge/bridge.go b/internal/bridge/bridge.go index a58a070..0035817 100644 --- a/internal/bridge/bridge.go +++ b/internal/bridge/bridge.go @@ -3,6 +3,7 @@ package bridge import ( "context" "log/slog" + "sync" "github.com/modelcontextprotocol/go-sdk/mcp" pb "github.com/msilverblatt/protomcp/gen/proto/protomcp" @@ -35,14 +36,20 @@ type ToolListMutationHandler func(enable, disable []string) // Bridge connects an mcp.Server to a FullBackend. // It registers proxy handlers that forward MCP requests to the SDK process. type Bridge struct { - Server *mcp.Server - backend FullBackend - logger *slog.Logger + Server *mcp.Server + backend FullBackend + logger *slog.Logger onToolListMutation ToolListMutationHandler + + mu sync.Mutex + registeredTools map[string]bool + registeredResources map[string]bool + registeredTemplates map[string]bool + registeredPrompts map[string]bool } // New creates a Bridge with an mcp.Server that proxies to the given backend. 
-func New(backend FullBackend, logger *slog.Logger) *Bridge { +func New(backend FullBackend, logger *slog.Logger, version string) *Bridge { opts := &mcp.ServerOptions{ Logger: logger, CompletionHandler: func(ctx context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) { @@ -74,14 +81,18 @@ func New(backend FullBackend, logger *slog.Logger) *Bridge { } server := mcp.NewServer( - &mcp.Implementation{Name: "protomcp", Version: "1.0.0"}, + &mcp.Implementation{Name: "protomcp", Version: version}, opts, ) b := &Bridge{ - Server: server, - backend: backend, - logger: logger, + Server: server, + backend: backend, + logger: logger, + registeredTools: make(map[string]bool), + registeredResources: make(map[string]bool), + registeredTemplates: make(map[string]bool), + registeredPrompts: make(map[string]bool), } // Wire reverse-request callbacks from the SDK process. @@ -103,17 +114,32 @@ func (b *Bridge) SetToolListMutationHandler(fn ToolListMutationHandler) { // SyncTools reads tool definitions from the backend and registers them // with the mcp.Server. Called on startup and after hot reload. func (b *Bridge) SyncTools() { - syncTools(b.Server, b.backend, b.onToolListMutation) + b.mu.Lock() + defer b.mu.Unlock() + syncTools(b.Server, b.backend, b.onToolListMutation, b.registeredTools) } // SyncResources reads resource and resource template definitions from the // backend and registers them with the mcp.Server. func (b *Bridge) SyncResources() { - syncResources(b.Server, b.backend) + b.mu.Lock() + defer b.mu.Unlock() + syncResources(b.Server, b.backend, b.registeredResources, b.registeredTemplates) } // SyncPrompts reads prompt definitions from the backend and registers them // with the mcp.Server. func (b *Bridge) SyncPrompts() { - syncPrompts(b.Server, b.backend) + b.mu.Lock() + defer b.mu.Unlock() + syncPrompts(b.Server, b.backend, b.registeredPrompts) +} + +// SyncAll atomically syncs tools, resources, and prompts under a single lock. 
+func (b *Bridge) SyncAll() { + b.mu.Lock() + defer b.mu.Unlock() + syncTools(b.Server, b.backend, b.onToolListMutation, b.registeredTools) + syncResources(b.Server, b.backend, b.registeredResources, b.registeredTemplates) + syncPrompts(b.Server, b.backend, b.registeredPrompts) } diff --git a/internal/bridge/bridge_test.go b/internal/bridge/bridge_test.go index bc7f0d4..26c80d4 100644 --- a/internal/bridge/bridge_test.go +++ b/internal/bridge/bridge_test.go @@ -65,7 +65,7 @@ func TestBridgeNew(t *testing.T) { {Name: "echo", Description: "echoes input", InputSchemaJson: `{"type":"object"}`}, }, } - b := New(backend, nil) + b := New(backend, nil, "dev") if b.Server == nil { t.Fatal("expected non-nil server") } @@ -78,7 +78,7 @@ func TestSyncTools(t *testing.T) { {Name: "add", Description: "adds numbers", InputSchemaJson: `{"type":"object"}`}, }, } - b := New(backend, nil) + b := New(backend, nil, "dev") b.SyncTools() // Server should now have 2 tools registered // Verified via tool handler invocation in TestMakeToolHandler diff --git a/internal/bridge/prompts.go b/internal/bridge/prompts.go index eaadd97..5b7d046 100644 --- a/internal/bridge/prompts.go +++ b/internal/bridge/prompts.go @@ -14,13 +14,15 @@ type PromptBackend interface { GetPrompt(ctx context.Context, name, argsJSON string) (*pb.GetPromptResponse, error) } -func syncPrompts(server *mcp.Server, backend PromptBackend) { +func syncPrompts(server *mcp.Server, backend PromptBackend, registered map[string]bool) { ctx := context.Background() prompts, err := backend.ListPrompts(ctx) if err != nil { return } + current := make(map[string]bool, len(prompts)) for _, p := range prompts { + current[p.Name] = true prompt := &mcp.Prompt{ Name: p.Name, Description: p.Description, @@ -35,6 +37,25 @@ func syncPrompts(server *mcp.Server, backend PromptBackend) { handler := makePromptHandler(backend, p.Name) server.AddPrompt(prompt, handler) } + + // Remove prompts that were previously registered but are no longer present. 
+ var stale []string + for name := range registered { + if !current[name] { + stale = append(stale, name) + } + } + if len(stale) > 0 { + server.RemovePrompts(stale...) + } + + // Update the registered set to match current. + for name := range registered { + delete(registered, name) + } + for name := range current { + registered[name] = true + } } func makePromptHandler(backend PromptBackend, name string) mcp.PromptHandler { diff --git a/internal/bridge/resources.go b/internal/bridge/resources.go index f399cb1..1e90bcb 100644 --- a/internal/bridge/resources.go +++ b/internal/bridge/resources.go @@ -2,6 +2,7 @@ package bridge import ( "context" + "log/slog" "github.com/modelcontextprotocol/go-sdk/mcp" pb "github.com/msilverblatt/protomcp/gen/proto/protomcp" @@ -14,28 +15,63 @@ type ResourceBackend interface { ReadResource(ctx context.Context, uri string) (*pb.ReadResourceResponse, error) } -func syncResources(server *mcp.Server, backend ResourceBackend) { +func syncResources(server *mcp.Server, backend ResourceBackend, registeredRes map[string]bool, registeredTmpl map[string]bool) { ctx := context.Background() + + // Sync resources. resources, err := backend.ListResources(ctx) if err != nil { - return - } - for _, r := range resources { - res := &mcp.Resource{ - URI: r.Uri, - Name: r.Name, - Description: r.Description, - MIMEType: r.MimeType, + slog.Warn("failed to list resources", "error", err) + } else { + currentRes := make(map[string]bool, len(resources)) + for _, r := range resources { + currentRes[r.Uri] = true + res := &mcp.Resource{ + URI: r.Uri, + Name: r.Name, + Description: r.Description, + MIMEType: r.MimeType, + } + handler := makeResourceHandler(backend, r.Uri) + server.AddResource(res, handler) + } + var staleRes []string + for uri := range registeredRes { + if !currentRes[uri] { + staleRes = append(staleRes, uri) + } + } + if len(staleRes) > 0 { + server.RemoveResources(staleRes...) 
+ } + for uri := range registeredRes { + delete(registeredRes, uri) + } + for uri := range currentRes { + registeredRes[uri] = true } - handler := makeResourceHandler(backend, r.Uri) - server.AddResource(res, handler) } + // Sync resource templates. templates, err := backend.ListResourceTemplates(ctx) if err != nil { + slog.Warn("failed to list resource templates", "error", err) + // Still remove stale templates since we can't verify them + var staleTmpl []string + for uri := range registeredTmpl { + staleTmpl = append(staleTmpl, uri) + } + if len(staleTmpl) > 0 { + server.RemoveResourceTemplates(staleTmpl...) + } + for k := range registeredTmpl { + delete(registeredTmpl, k) + } return } + currentTmpl := make(map[string]bool, len(templates)) for _, t := range templates { + currentTmpl[t.UriTemplate] = true tmpl := &mcp.ResourceTemplate{ URITemplate: t.UriTemplate, Name: t.Name, @@ -45,6 +81,21 @@ func syncResources(server *mcp.Server, backend ResourceBackend) { handler := makeResourceHandler(backend, t.UriTemplate) server.AddResourceTemplate(tmpl, handler) } + var staleTmpl []string + for uri := range registeredTmpl { + if !currentTmpl[uri] { + staleTmpl = append(staleTmpl, uri) + } + } + if len(staleTmpl) > 0 { + server.RemoveResourceTemplates(staleTmpl...) + } + for uri := range registeredTmpl { + delete(registeredTmpl, uri) + } + for uri := range currentTmpl { + registeredTmpl[uri] = true + } } func makeResourceHandler(backend ResourceBackend, uri string) mcp.ResourceHandler { diff --git a/internal/bridge/tools.go b/internal/bridge/tools.go index 7f1343f..6211408 100644 --- a/internal/bridge/tools.go +++ b/internal/bridge/tools.go @@ -9,15 +9,37 @@ import ( pb "github.com/msilverblatt/protomcp/gen/proto/protomcp" ) -// syncTools clears existing tools and re-registers them from the backend. 
-func syncTools(server *mcp.Server, backend ProcessBackend, onMutation ToolListMutationHandler) { +// syncTools registers tools from the backend and removes any stale tools +// that are no longer present. The registered map tracks currently known tools. +func syncTools(server *mcp.Server, backend ProcessBackend, onMutation ToolListMutationHandler, registered map[string]bool) { tools := backend.ActiveTools() + current := make(map[string]bool, len(tools)) for _, t := range tools { + current[t.Name] = true tool := convertToolDef(t) handler := makeToolHandler(backend, t.Name, onMutation) server.AddTool(tool, handler) } + + // Remove tools that were previously registered but are no longer present. + var stale []string + for name := range registered { + if !current[name] { + stale = append(stale, name) + } + } + if len(stale) > 0 { + server.RemoveTools(stale...) + } + + // Update the registered set to match current. + for name := range registered { + delete(registered, name) + } + for name := range current { + registered[name] = true + } } func convertToolDef(t *pb.ToolDefinition) *mcp.Tool { diff --git a/internal/process/manager.go b/internal/process/manager.go index dd5d725..11f0381 100644 --- a/internal/process/manager.go +++ b/internal/process/manager.go @@ -621,69 +621,22 @@ func (m *Manager) CallToolStream(ctx context.Context, name, argsJSON string) (<- // Reload sends a ReloadRequest, waits for ReloadResponse, then receives the // updated ToolListResponse. func (m *Manager) Reload(ctx context.Context) ([]*pb.ToolDefinition, error) { - reqID := m.nextRequestID() + // Stop the old process and clean up resources. + m.cleanup() - env := &pb.Envelope{ - RequestId: reqID, - Msg: &pb.Envelope_Reload{ - Reload: &pb.ReloadRequest{}, - }, - } + // Wait for the read loop to finish. + m.readWg.Wait() - respCh := make(chan *pb.Envelope, 1) + // Reset internal state for a fresh start. 
m.mu.Lock() - m.pending[reqID] = respCh + m.pending = make(map[string]chan *pb.Envelope) + m.streams = make(map[string]*streamAssembly) + m.streamChs = make(map[string]chan StreamEvent) + m.handshakeCh = make(chan *pb.Envelope, 4) m.mu.Unlock() - defer func() { - m.mu.Lock() - delete(m.pending, reqID) - m.mu.Unlock() - }() - - m.writeMu.Lock() - err := envelope.Write(m.conn, env) - m.writeMu.Unlock() - if err != nil { - return nil, fmt.Errorf("write ReloadRequest: %w", err) - } - - // Wait for ReloadResponse (matched by request_id). - timeout := m.cfg.CallTimeout - timer := time.NewTimer(timeout) - defer timer.Stop() - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-timer.C: - return nil, fmt.Errorf("reload timed out after %v", timeout) - case resp := <-respCh: - reloadResp := resp.GetReloadResponse() - if reloadResp == nil { - return nil, fmt.Errorf("unexpected response type for Reload") - } - if !reloadResp.Success { - return nil, fmt.Errorf("reload failed: %s", reloadResp.Error) - } - } - - // Wait for the unsolicited ToolListResponse (no request_id). - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-timer.C: - return nil, fmt.Errorf("waiting for tool list after reload timed out") - case toolEnv := <-m.handshakeCh: - toolList := toolEnv.GetToolList() - if toolList == nil { - return nil, fmt.Errorf("unexpected message type after reload") - } - m.mu.Lock() - m.tools = toolList.Tools - m.mu.Unlock() - return toolList.Tools, nil - } + // Restart: same logic as Start but reuses existing socket path. 
+ return m.Start(ctx) } // OnCrash returns a channel that receives an error when the child process diff --git a/internal/reload/watcher.go b/internal/reload/watcher.go index e96bc2b..4d6d867 100644 --- a/internal/reload/watcher.go +++ b/internal/reload/watcher.go @@ -2,7 +2,9 @@ package reload import ( "context" + "os" "path/filepath" + "strings" "sync" "time" @@ -27,13 +29,37 @@ func NewWatcher(path string, extensions []string, onChange func()) (*Watcher, er return nil, err } - if err := fsw.Add(path); err != nil { + info, err := os.Stat(path) + if err != nil { + fsw.Close() + return nil, err + } + + watchDir := path + if !info.IsDir() { + watchDir = filepath.Dir(path) + } + + err = filepath.WalkDir(watchDir, func(p string, d os.DirEntry, err error) error { + if err != nil { + return nil + } + if d.IsDir() { + name := d.Name() + if name != "." && (strings.HasPrefix(name, ".") || name == "node_modules" || name == "__pycache__" || name == "target" || name == "venv") { + return filepath.SkipDir + } + return fsw.Add(p) + } + return nil + }) + if err != nil { fsw.Close() return nil, err } return &Watcher{ - path: path, + path: watchDir, extensions: extensions, onChange: onChange, watcher: fsw, @@ -59,6 +85,28 @@ func (w *Watcher) Start(ctx context.Context) error { continue } + // Auto-watch newly created directories + if event.Has(fsnotify.Create) { + if info, err := os.Stat(event.Name); err == nil && info.IsDir() { + name := filepath.Base(event.Name) + if !strings.HasPrefix(name, ".") && name != "node_modules" && name != "__pycache__" && name != "target" && name != "venv" { + filepath.WalkDir(event.Name, func(p string, d os.DirEntry, err error) error { + if err != nil { + return nil + } + if d.IsDir() { + n := d.Name() + if n != "." 
&& (strings.HasPrefix(n, ".") || n == "node_modules" || n == "__pycache__" || n == "target" || n == "venv") { + return filepath.SkipDir + } + w.watcher.Add(p) + } + return nil + }) + } + } + } + if !w.matchesExtension(event.Name) { continue } diff --git a/internal/reload/watcher_test.go b/internal/reload/watcher_test.go index 06de244..08da16f 100644 --- a/internal/reload/watcher_test.go +++ b/internal/reload/watcher_test.go @@ -50,6 +50,83 @@ func TestWatcherFileChange(t *testing.T) { } } +func TestWatcherDirectoryWatch(t *testing.T) { + dir := t.TempDir() + f1 := filepath.Join(dir, "main.py") + os.WriteFile(f1, []byte("v1"), 0644) + + called := make(chan struct{}, 1) + w, err := reload.NewWatcher(dir, []string{".py"}, func() { + select { + case called <- struct{}{}: + default: + } + }) + if err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go w.Start(ctx) + defer w.Stop() + + time.Sleep(50 * time.Millisecond) + f2 := filepath.Join(dir, "helper.py") + os.WriteFile(f2, []byte("v1"), 0644) + + select { + case <-called: + // success + case <-time.After(2 * time.Second): + t.Error("expected callback for new file in watched directory") + } +} + +func TestWatcherNewSubdirectory(t *testing.T) { + dir := t.TempDir() + + called := make(chan struct{}, 1) + w, err := reload.NewWatcher(dir, []string{".py"}, func() { + select { + case called <- struct{}{}: + default: + } + }) + if err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go w.Start(ctx) + defer w.Stop() + + // Give the watcher time to set up + time.Sleep(50 * time.Millisecond) + + // Create a new subdirectory + subdir := filepath.Join(dir, "subpkg") + if err := os.Mkdir(subdir, 0o755); err != nil { + t.Fatal(err) + } + + // Give the watcher time to register the new directory + time.Sleep(100 * time.Millisecond) + + // Create a .py file inside the new subdirectory + if err := 
os.WriteFile(filepath.Join(subdir, "module.py"), []byte("x = 1"), 0o644); err != nil { + t.Fatal(err) + } + + select { + case <-called: + // success + case <-time.After(2 * time.Second): + t.Error("expected callback for new file in newly created subdirectory") + } +} + func TestWatcherDebounce(t *testing.T) { dir := t.TempDir() testFile := filepath.Join(dir, "test.py") diff --git a/internal/testengine/engine.go b/internal/testengine/engine.go index e9d2928..742ec05 100644 --- a/internal/testengine/engine.go +++ b/internal/testengine/engine.go @@ -158,7 +158,7 @@ func (e *Engine) Start(ctx context.Context) error { e.be = newBackend(e.pm, e.tlm, tools) // Create bridge - e.br = bridge.New(e.be, e.cfg.logger) + e.br = bridge.New(e.be, e.cfg.logger, "dev") e.br.SetToolListMutationHandler(func(enable, disable []string) { if len(enable) > 0 { e.tlm.Enable(enable) diff --git a/sdk/go/go.mod b/sdk/go/go.mod index c353fe7..98c2aac 100644 --- a/sdk/go/go.mod +++ b/sdk/go/go.mod @@ -4,8 +4,6 @@ go 1.25.6 require ( github.com/klauspost/compress v1.18.4 - github.com/msilverblatt/protomcp v0.0.0 + github.com/msilverblatt/protomcp v0.2.0 google.golang.org/protobuf v1.36.11 ) - -replace github.com/msilverblatt/protomcp => ../.. 
diff --git a/sdk/go/go.sum b/sdk/go/go.sum index 4a6f7a8..a7d0c3c 100644 --- a/sdk/go/go.sum +++ b/sdk/go/go.sum @@ -2,5 +2,7 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c= github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/msilverblatt/protomcp v0.2.0 h1:0G4SWMSwuL+sBUZVK6tTcLxzaLHPCIRxa8qs7hkRPK4= +github.com/msilverblatt/protomcp v0.2.0/go.mod h1:4mxRk7EOVB2YyeFh9ZTDLobi2TbU3ABnm7FB/af3VaU= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= diff --git a/sdk/go/protomcp/group.go b/sdk/go/protomcp/group.go index c1de5d8..0864d27 100644 --- a/sdk/go/protomcp/group.go +++ b/sdk/go/protomcp/group.go @@ -37,7 +37,7 @@ var groupRegistry []GroupDef func ToolGroup(name string, opts ...GroupOption) { gd := GroupDef{ Name: name, - Strategy: "union", + Strategy: "separate", } for _, opt := range opts { opt(&gd) diff --git a/sdk/go/protomcp/group_test.go b/sdk/go/protomcp/group_test.go index 80c1485..51c699d 100644 --- a/sdk/go/protomcp/group_test.go +++ b/sdk/go/protomcp/group_test.go @@ -53,6 +53,7 @@ func TestUnionStrategySchema(t *testing.T) { ToolGroup("db", GroupDescription("DB ops"), + GroupStrategy("union"), Action("query", ActionDescription("Run query"), ActionArgs(StrArg("sql")), @@ -231,13 +232,13 @@ func TestGroupsInGetRegisteredTools(t *testing.T) { tools := GetRegisteredTools() found := false for _, td := range tools { - if td.Name == "tools_test" { + if td.Name == "tools_test.ping" { found = true break } } if !found { - t.Error("expected 'tools_test' in registered tools") + t.Error("expected 'tools_test.ping' in registered tools") } } @@ -246,6 +247,7 @@ func 
TestUnionHandlerDispatch(t *testing.T) { defer ClearGroupRegistry() ToolGroup("handler_test", + GroupStrategy("union"), Action("greet", ActionArgs(StrArg("name")), ActionHandler(func(ctx ToolContext, args map[string]interface{}) ToolResult { diff --git a/sdk/python/pyproject.toml b/sdk/python/pyproject.toml index a85f405..d604f1d 100644 --- a/sdk/python/pyproject.toml +++ b/sdk/python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "protomcp" -version = "0.2.0" +version = "0.3.0" description = "Write MCP tools in Python. No MCP knowledge required." readme = "README.md" license = "MIT" diff --git a/sdk/python/src/protomcp/group.py b/sdk/python/src/protomcp/group.py index bda1a87..fae13d4 100644 --- a/sdk/python/src/protomcp/group.py +++ b/sdk/python/src/protomcp/group.py @@ -28,7 +28,7 @@ class GroupDef: description: str actions: list[ActionDef] instance: Any - strategy: str = "union" + strategy: str = "separate" title: str = "" destructive_hint: bool = False idempotent_hint: bool = False @@ -81,7 +81,7 @@ def _generate_action_schema(method: Callable) -> dict: def tool_group( name: str, description: str = "", - strategy: str = "union", + strategy: str = "separate", title: str = "", destructive: bool = False, idempotent: bool = False, diff --git a/sdk/python/tests/test_e2e.py b/sdk/python/tests/test_e2e.py index 8fbf597..2099d04 100644 --- a/sdk/python/tests/test_e2e.py +++ b/sdk/python/tests/test_e2e.py @@ -73,15 +73,17 @@ def multiply(self, a: int, b: int): return ToolResult(result=str(a * b)) tools = get_registered_tools() - assert len(tools) == 1 - assert tools[0].name == "math" + assert len(tools) == 2 + tool_names = [t.name for t in tools] + assert "math.add" in tool_names - handler = tools[0].handler - chain = build_middleware_chain("math", handler) + add_tool = next(t for t in tools if t.name == "math.add") + handler = add_tool.handler + chain = build_middleware_chain("math.add", handler) - 
emit_telemetry(ToolCallEvent(tool_name="math", action="add", phase="start", args={"a": 2, "b": 3})) - result = chain(None, {"action": "add", "a": 2, "b": 3}) - emit_telemetry(ToolCallEvent(tool_name="math", action="add", phase="success", args={}, result=str(result))) + emit_telemetry(ToolCallEvent(tool_name="math.add", action="add", phase="start", args={"a": 2, "b": 3})) + result = chain(None, {"a": 2, "b": 3}) + emit_telemetry(ToolCallEvent(tool_name="math.add", action="add", phase="success", args={}, result=str(result))) # Handler was called correctly — result contains the sum assert "5" in str(result) @@ -262,8 +264,8 @@ def standalone(): tools = get_registered_tools() tool_names = [t.name for t in tools] - assert "alpha" in tool_names - assert "beta" in tool_names + assert "alpha.do_a" in tool_names + assert "beta.do_b" in tool_names assert "standalone" in tool_names # Dispatch each independently @@ -301,17 +303,17 @@ def explode(self): raise ValueError("boom!") tools = get_registered_tools() - risky_tool = next(t for t in tools if t.name == "risky") + risky_tool = next(t for t in tools if t.name == "risky.explode") - chain = build_middleware_chain("risky", risky_tool.handler) - emit_telemetry(ToolCallEvent(tool_name="risky", action="explode", phase="start", args={})) + chain = build_middleware_chain("risky.explode", risky_tool.handler) + emit_telemetry(ToolCallEvent(tool_name="risky.explode", action="explode", phase="start", args={})) error_caught = None try: - chain(None, {"action": "explode"}) + chain(None, {}) except ValueError as e: error_caught = e - emit_telemetry(ToolCallEvent(tool_name="risky", action="explode", phase="error", args={}, error=e)) + emit_telemetry(ToolCallEvent(tool_name="risky.explode", action="explode", phase="error", args={}, error=e)) assert error_caught is not None assert str(error_caught) == "boom!" 
@@ -352,8 +354,8 @@ def ping(self): tools = get_registered_tools() tool_names = [t.name for t in tools] - assert "discovered" in tool_names + assert "discovered.ping" in tool_names - discovered_tool = next(t for t in tools if t.name == "discovered") - result = discovered_tool.handler(action="ping") + discovered_tool = next(t for t in tools if t.name == "discovered.ping") + result = discovered_tool.handler() assert str(result) == "pong" diff --git a/sdk/python/tests/test_group.py b/sdk/python/tests/test_group.py index e212af5..d3ef813 100644 --- a/sdk/python/tests/test_group.py +++ b/sdk/python/tests/test_group.py @@ -60,7 +60,7 @@ def do_thing(self, name: str, count: int = 5) -> str: def test_union_strategy_schema(): - @tool_group(name="db", description="DB ops") + @tool_group(name="db", description="DB ops", strategy="union") class DbGroup: @action("query", description="Run query") def query(self, sql: str) -> str: @@ -155,11 +155,11 @@ def ping(self) -> str: tools = get_registered_tools() names = [t.name for t in tools] - assert "tools_test" in names + assert "tools_test.ping" in names def test_union_handler_dispatch(): - @tool_group(name="handler_test", description="Handler test") + @tool_group(name="handler_test", description="Handler test", strategy="union") class HandlerTestGroup: @action("greet", description="Greet") def greet(self, name: str) -> str: diff --git a/sdk/rust/Cargo.toml b/sdk/rust/Cargo.toml index b7e1578..dde8fbe 100644 --- a/sdk/rust/Cargo.toml +++ b/sdk/rust/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "protomcp" -version = "0.2.0" +version = "0.3.0" edition = "2021" description = "Rust SDK for protomcp — write MCP tools in Rust" license = "MIT" diff --git a/sdk/rust/src/group.rs b/sdk/rust/src/group.rs index 63b1396..76499e7 100644 --- a/sdk/rust/src/group.rs +++ b/sdk/rust/src/group.rs @@ -17,7 +17,7 @@ pub struct ActionDef { pub name: String, pub description: String, pub args: Vec, - pub handler: Box ToolResult + Send + Sync>, + pub 
handler: Arc ToolResult + Send + Sync>, pub requires: Vec, pub enum_fields: Vec<(String, Vec)>, pub cross_rules: Vec, @@ -41,7 +41,7 @@ pub struct ActionBuilder { name: String, description: String, args: Vec, - handler: Option ToolResult + Send + Sync>>, + handler: Option ToolResult + Send + Sync>>, requires: Vec, enum_fields: Vec<(String, Vec)>, cross_rules: Vec, @@ -51,7 +51,7 @@ pub fn tool_group(name: &str) -> GroupBuilder { GroupBuilder { name: name.to_string(), description: String::new(), - strategy: "union".to_string(), + strategy: "separate".to_string(), actions: Vec::new(), } } @@ -82,7 +82,7 @@ impl GroupBuilder { name: built.name, description: built.description, args: built.args, - handler: built.handler.unwrap_or_else(|| Box::new(|_, _| ToolResult::new(""))), + handler: built.handler.unwrap_or_else(|| Arc::new(|_, _| ToolResult::new(""))), requires: built.requires, enum_fields: built.enum_fields, cross_rules: built.cross_rules, @@ -131,7 +131,7 @@ impl ActionBuilder { where F: Fn(ToolContext, Value) -> ToolResult + Send + Sync + 'static, { - self.handler = Some(Box::new(f)); + self.handler = Some(Arc::new(f)); self } @@ -272,47 +272,96 @@ fn group_to_separate_defs(group: &GroupDef) -> Vec { } fn dispatch_group_action_by_name(group_name: &str, ctx: ToolContext, args: Value) -> ToolResult { - let guard = GROUP_REGISTRY.lock().unwrap_or_else(|e| e.into_inner()); - let group = guard.iter().find(|g| g.name == group_name); - match group { - Some(g) => dispatch_group_action(g, ctx, args), - None => ToolResult::error( - format!("Group '{}' not found", group_name), - "GROUP_NOT_FOUND", - "", - false, - ), - } + let action_name = match args.get("action").and_then(|v| v.as_str()) { + Some(name) => name.to_string(), + None => { + let guard = GROUP_REGISTRY.lock().unwrap_or_else(|e| e.into_inner()); + let names: Vec = guard.iter() + .find(|g| g.name == group_name) + .map(|g| g.actions.iter().map(|a| a.name.clone()).collect()) + .unwrap_or_default(); + return 
ToolResult::error( + format!("Missing 'action' field. Available actions: {}", names.join(", ")), + "MISSING_ACTION", + "", + false, + ); + } + }; + + let handler = { + let guard = GROUP_REGISTRY.lock().unwrap_or_else(|e| e.into_inner()); + let group = guard.iter().find(|g| g.name == group_name); + match group { + None => return ToolResult::error( + format!("Group '{}' not found", group_name), + "GROUP_NOT_FOUND", + "", + false, + ), + Some(g) => { + match g.actions.iter().find(|a| a.name == action_name) { + None => { + let names: Vec<&str> = g.actions.iter().map(|a| a.name.as_str()).collect(); + let suggestion = fuzzy_match(&action_name, &names); + let mut msg = format!("Unknown action '{}'.", action_name); + if let Some(s) = &suggestion { + msg.push_str(&format!(" Did you mean '{}'?", s)); + } + msg.push_str(&format!(" Available actions: {}", names.join(", "))); + return ToolResult::error( + msg, + "UNKNOWN_ACTION", + suggestion.unwrap_or_default(), + false, + ); + } + Some(a) => { + if let Some(err) = validate_action(a, &args) { + return err; + } + Arc::clone(&a.handler) + } + } + } + } + }; + // Lock is dropped here before calling the handler + handler(ctx, args) } fn dispatch_specific_action(group_name: &str, action_name: &str, ctx: ToolContext, args: Value) -> ToolResult { - let guard = GROUP_REGISTRY.lock().unwrap_or_else(|e| e.into_inner()); - let group = guard.iter().find(|g| g.name == group_name); - match group { - Some(g) => { - let act = g.actions.iter().find(|a| a.name == action_name); - match act { - Some(a) => { - if let Some(err) = validate_action(a, &args) { - return err; + let handler = { + let guard = GROUP_REGISTRY.lock().unwrap_or_else(|e| e.into_inner()); + let group = guard.iter().find(|g| g.name == group_name); + match group { + None => return ToolResult::error( + format!("Group '{}' not found", group_name), + "GROUP_NOT_FOUND", + "", + false, + ), + Some(g) => { + let act = g.actions.iter().find(|a| a.name == action_name); + match act { + 
None => return ToolResult::error( + format!("Action '{}' not found in group '{}'", action_name, group_name), + "UNKNOWN_ACTION", + "", + false, + ), + Some(a) => { + if let Some(err) = validate_action(a, &args) { + return err; + } + Arc::clone(&a.handler) } - (a.handler)(ctx, args) } - None => ToolResult::error( - format!("Action '{}' not found in group '{}'", action_name, group_name), - "UNKNOWN_ACTION", - "", - false, - ), } } - None => ToolResult::error( - format!("Group '{}' not found", group_name), - "GROUP_NOT_FOUND", - "", - false, - ), - } + }; + // Lock is dropped here before calling the handler + handler(ctx, args) } fn validate_action(action: &ActionDef, args: &Value) -> Option { @@ -373,6 +422,7 @@ fn validate_action(action: &ActionDef, args: &Value) -> Option { None } +#[cfg(test)] fn dispatch_group_action(group: &GroupDef, ctx: ToolContext, args: Value) -> ToolResult { let action_name = match args.get("action").and_then(|v| v.as_str()) { Some(name) => name.to_string(), @@ -483,7 +533,8 @@ mod tests { #[test] fn test_group_registration() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); tool_group("math") .description("Math operations") .action("add", |a| { @@ -513,9 +564,11 @@ mod tests { #[test] fn test_union_strategy_schema() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); tool_group("db") .description("DB ops") + .strategy("union") .action("query", |a| { a.description("Run query").arg(ArgDef::string("sql")) }) @@ -550,7 +603,8 @@ mod tests { #[test] fn test_separate_strategy_schema() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); tool_group("files") .description("File ops") .strategy("separate") @@ -576,7 +630,8 @@ mod tests { #[test] fn test_dispatch_correct_action() { - let 
_lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); tool_group("calc") .action("add", |a| { a.arg(ArgDef::int("a")) @@ -605,7 +660,8 @@ mod tests { #[test] fn test_dispatch_unknown_action() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); tool_group("calc2") .action("add", |a| a.handler(|_, _| ToolResult::new("ok"))) .register(); @@ -625,7 +681,8 @@ mod tests { #[test] fn test_dispatch_missing_action() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); tool_group("calc3") .action("add", |a| a.handler(|_, _| ToolResult::new("ok"))) .register(); @@ -642,14 +699,15 @@ mod tests { #[test] fn test_groups_in_tool_defs() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); tool_group("tools_test") .description("Test group") .action("ping", |a| a.handler(|_, _| ToolResult::new("pong"))) .register(); crate::tool::with_registry(|tools| { - let found = tools.iter().any(|d| d.name == "tools_test"); + let found = tools.iter().any(|d| d.name == "tools_test.ping"); assert!(found); }); @@ -659,7 +717,8 @@ mod tests { #[test] fn test_validation_requires() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); tool_group("val_req") .action("create", |a| { a.requires(&["name", "email"]) @@ -692,7 +751,8 @@ mod tests { #[test] fn test_validation_enum_field() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); tool_group("val_enum") .action("set_mode", |a| { a.enum_field("mode", &["fast", "slow", "balanced"]) @@ -724,7 +784,8 @@ mod tests { #[test] fn 
test_validation_cross_rule() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); tool_group("val_cross") .action("transfer", |a| { a.cross_rule( @@ -766,7 +827,8 @@ mod tests { #[test] fn test_validation_enum_fuzzy_suggestion() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); tool_group("val_fuzzy") .action("color", |a| { a.enum_field("color", &["red", "green", "blue"]) diff --git a/sdk/rust/src/lib.rs b/sdk/rust/src/lib.rs index b5b1ca2..3bb2058 100644 --- a/sdk/rust/src/lib.rs +++ b/sdk/rust/src/lib.rs @@ -29,10 +29,30 @@ pub use runner::run; pub use log::ServerLogger; pub use manager::ToolManager; pub use middleware::{middleware, clear_middleware_registry}; -pub use resource::{register_resource, register_resource_template, ResourceDef, ResourceTemplateDef, ResourceContent}; -pub use prompt::{register_prompt, PromptDef, PromptArg, PromptMessage}; +pub use resource::{register_resource, register_resource_template, ResourceDef, ResourceTemplateDef, ResourceContent, clear_resource_registry, clear_resource_template_registry}; +pub use prompt::{register_prompt, PromptDef, PromptArg, PromptMessage, clear_prompt_registry}; pub use completion::{register_completion, CompletionResult}; pub use server_context::{server_context, resolve_contexts, clear_context_registry}; pub use local_middleware::{local_middleware, build_middleware_chain, clear_local_middleware}; pub use telemetry::{telemetry_sink, emit_telemetry, clear_telemetry_sinks, ToolCallEvent}; pub use sidecar::{sidecar, start_sidecars, stop_all_sidecars, clear_sidecar_registry, SidecarDef}; + +/// Clears all global registries. Call at the start of every test to ensure isolation. 
+pub fn clear_all_registries() { + clear_registry(); + clear_group_registry(); + clear_workflow_registry(); + clear_middleware_registry(); + clear_local_middleware(); + clear_telemetry_sinks(); + clear_context_registry(); + clear_sidecar_registry(); + clear_prompt_registry(); + clear_resource_registry(); + clear_resource_template_registry(); +} + +/// A process-wide mutex to serialize tests that share global registry state. +/// Acquire this at the top of any test that reads or writes global registries. +#[cfg(test)] +pub static TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(()); diff --git a/sdk/rust/src/local_middleware.rs b/sdk/rust/src/local_middleware.rs index e6732c5..30f072c 100644 --- a/sdk/rust/src/local_middleware.rs +++ b/sdk/rust/src/local_middleware.rs @@ -106,7 +106,8 @@ mod tests { #[test] fn test_no_middleware() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); let handler: Box ToolResult + Send + Sync> = Box::new(|_, _| ToolResult::new("direct")); let chain = build_middleware_chain("test_tool", handler); @@ -117,7 +118,8 @@ mod tests { #[test] fn test_single_middleware() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); local_middleware(10, |ctx, _tool_name, args, next| { let mut result = next(ctx, args); @@ -136,25 +138,26 @@ mod tests { #[test] fn test_priority_ordering() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); let call_order = Arc::new(Mutex::new(Vec::::new())); let co1 = call_order.clone(); local_middleware(20, move |ctx, _tn, args, next| { - co1.lock().unwrap().push(20); + co1.lock().unwrap_or_else(|e| e.into_inner()).push(20); next(ctx, args) }); let co2 = call_order.clone(); local_middleware(5, move |ctx, _tn, args, next| { - co2.lock().unwrap().push(5); 
+ co2.lock().unwrap_or_else(|e| e.into_inner()).push(5); next(ctx, args) }); let co3 = call_order.clone(); local_middleware(10, move |ctx, _tn, args, next| { - co3.lock().unwrap().push(10); + co3.lock().unwrap_or_else(|e| e.into_inner()).push(10); next(ctx, args) }); @@ -164,7 +167,7 @@ mod tests { let result = chain(dummy_ctx(), serde_json::json!({})); assert_eq!(result.result_text, "done"); - let order = call_order.lock().unwrap(); + let order = call_order.lock().unwrap_or_else(|e| e.into_inner()); assert_eq!(*order, vec![5, 10, 20]); clear_local_middleware(); @@ -172,7 +175,8 @@ mod tests { #[test] fn test_middleware_can_short_circuit() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); local_middleware(1, |_ctx, _tn, _args, _next| { ToolResult::error("blocked", "BLOCKED", "", false) diff --git a/sdk/rust/src/middleware.rs b/sdk/rust/src/middleware.rs index 126d4a5..faee892 100644 --- a/sdk/rust/src/middleware.rs +++ b/sdk/rust/src/middleware.rs @@ -45,8 +45,8 @@ mod tests { #[test] fn test_middleware_registration() { - let _lock = TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); - clear_middleware_registry(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); middleware("audit", 10, |_phase, _tool, _args, _result, _err| { HashMap::new() }); diff --git a/sdk/rust/src/prompt.rs b/sdk/rust/src/prompt.rs index 2dfd1cb..d22bed0 100644 --- a/sdk/rust/src/prompt.rs +++ b/sdk/rust/src/prompt.rs @@ -29,3 +29,7 @@ where F: FnOnce(&[PromptDef]) -> R { let guard = PROMPT_REGISTRY.lock().unwrap_or_else(|e| e.into_inner()); f(&guard) } + +pub fn clear_prompt_registry() { + PROMPT_REGISTRY.lock().unwrap_or_else(|e| e.into_inner()).clear(); +} diff --git a/sdk/rust/src/resource.rs b/sdk/rust/src/resource.rs index e96b2f5..dd37969 100644 --- a/sdk/rust/src/resource.rs +++ b/sdk/rust/src/resource.rs @@ -58,3 +58,11 @@ where F: 
FnOnce(&[ResourceTemplateDef]) -> R { let guard = TEMPLATE_REGISTRY.lock().unwrap_or_else(|e| e.into_inner()); f(&guard) } + +pub fn clear_resource_registry() { + RESOURCE_REGISTRY.lock().unwrap_or_else(|e| e.into_inner()).clear(); +} + +pub fn clear_resource_template_registry() { + TEMPLATE_REGISTRY.lock().unwrap_or_else(|e| e.into_inner()).clear(); +} diff --git a/sdk/rust/src/server_context.rs b/sdk/rust/src/server_context.rs index 9aeaf67..c7ef48e 100644 --- a/sdk/rust/src/server_context.rs +++ b/sdk/rust/src/server_context.rs @@ -55,7 +55,8 @@ mod tests { #[test] fn test_register_and_resolve() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); server_context("user_id", |args| { let token = args.get("auth_token").and_then(|v| v.as_str()).unwrap_or(""); @@ -81,7 +82,8 @@ mod tests { #[test] fn test_multiple_contexts() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); server_context("ctx_a", |_args| Value::String("val_a".to_string())); server_context("ctx_b", |_args| Value::Number(serde_json::Number::from(42))); @@ -98,7 +100,8 @@ mod tests { #[test] fn test_clear_registry() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); server_context("x", |_| Value::Null); { let guard = CONTEXT_REGISTRY.lock().unwrap_or_else(|e| e.into_inner()); diff --git a/sdk/rust/src/sidecar.rs b/sdk/rust/src/sidecar.rs index 1fede89..66b8533 100644 --- a/sdk/rust/src/sidecar.rs +++ b/sdk/rust/src/sidecar.rs @@ -319,7 +319,8 @@ mod tests { #[test] fn test_sidecar_registration() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); sidecar("redis", &["redis-server", "--port", "6380"]) .health_check("http://localhost:6380") @@ -345,7 
+346,8 @@ mod tests { #[test] fn test_builder_defaults() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); sidecar("test", &["echo", "hi"]).register(); @@ -367,7 +369,8 @@ mod tests { #[test] fn test_start_sidecars_filter() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); // Register with a trigger that won't actually start a real process // (command doesn't exist, which is fine — we just test filtering) @@ -390,7 +393,8 @@ mod tests { #[test] fn test_clear_sidecar_registry() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); sidecar("x", &["true"]).register(); { let guard = SIDECAR_REGISTRY.lock().unwrap_or_else(|e| e.into_inner()); @@ -405,7 +409,8 @@ mod tests { #[test] fn test_stop_all_no_crash() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); sidecar("phantom", &["sleep", "999"]).register(); // No processes actually running, stop should be a no-op stop_all_sidecars(); diff --git a/sdk/rust/src/telemetry.rs b/sdk/rust/src/telemetry.rs index 3fbdaf2..e7f4c67 100644 --- a/sdk/rust/src/telemetry.rs +++ b/sdk/rust/src/telemetry.rs @@ -67,7 +67,8 @@ mod tests { #[test] fn test_register_and_emit() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); let count = Arc::new(AtomicI32::new(0)); let count2 = count.clone(); @@ -85,7 +86,8 @@ mod tests { #[test] fn test_multiple_sinks() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); let count_a = Arc::new(AtomicI32::new(0)); let count_b = Arc::new(AtomicI32::new(0)); @@ -104,7 +106,8 @@ 
mod tests { #[test] fn test_panic_in_sink_does_not_crash() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); let reached = Arc::new(AtomicI32::new(0)); let reached2 = reached.clone(); @@ -121,13 +124,14 @@ mod tests { #[test] fn test_event_fields() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); let captured_name = Arc::new(Mutex::new(String::new())); let cn = captured_name.clone(); telemetry_sink(move |event| { - *cn.lock().unwrap() = format!( + *cn.lock().unwrap_or_else(|e| e.into_inner()) = format!( "{}:{}:{}ms:err={}", event.tool_name, event.phase, event.duration_ms, event.is_error ); @@ -138,14 +142,15 @@ mod tests { event.is_error = false; emit_telemetry(event); - assert_eq!(*captured_name.lock().unwrap(), "calc:success:42ms:err=false"); + assert_eq!(*captured_name.lock().unwrap_or_else(|e| e.into_inner()), "calc:success:42ms:err=false"); clear_telemetry_sinks(); } #[test] fn test_clear() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); telemetry_sink(|_| {}); telemetry_sink(|_| {}); { diff --git a/sdk/rust/src/tool.rs b/sdk/rust/src/tool.rs index 7cf4ee7..80b89ea 100644 --- a/sdk/rust/src/tool.rs +++ b/sdk/rust/src/tool.rs @@ -202,7 +202,8 @@ mod tests { #[test] fn test_tool_registration() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); tool("add") .description("Add two numbers") .arg(ArgDef::int("a")) @@ -220,7 +221,8 @@ mod tests { #[test] fn test_array_arg() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); tool("list_items") .description("List items") .arg(ArgDef::array("tags", "string")) @@ -239,7 
+241,8 @@ mod tests { #[test] fn test_object_arg() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); tool("set_config") .description("Set config") .arg(ArgDef::object("config")) @@ -257,7 +260,8 @@ mod tests { #[test] fn test_union_arg() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); tool("process") .description("Process data") .arg(ArgDef::union("data", &["string", "object"])) @@ -278,7 +282,8 @@ mod tests { #[test] fn test_literal_arg() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); tool("set_mode") .description("Set mode") .arg(ArgDef::literal("mode", &["fast", "slow", "balanced"])) @@ -301,7 +306,8 @@ mod tests { #[test] fn test_tool_metadata() { - let _lock = lock_and_clear(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + crate::clear_all_registries(); tool("delete_user") .description("Delete a user") .destructive_hint(true) diff --git a/sdk/rust/src/workflow.rs b/sdk/rust/src/workflow.rs index 2f8bee7..8bbdffd 100644 --- a/sdk/rust/src/workflow.rs +++ b/sdk/rust/src/workflow.rs @@ -718,8 +718,7 @@ mod tests { } fn cleanup() { - clear_workflow_registry(); - clear_registry(); + crate::clear_all_registries(); } fn cleanup_and_lock() -> std::sync::MutexGuard<'static, ()> { @@ -730,7 +729,8 @@ mod tests { #[test] fn test_basic_workflow_registration() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); workflow("deploy") .description("Deploy workflow") .step("start", |s| { @@ -756,7 +756,8 @@ mod tests { #[test] fn test_tool_defs_generated() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); workflow("order") 
.step("create", |s| { s.description("Create order") @@ -795,7 +796,8 @@ mod tests { #[test] fn test_step_dispatch_initial() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); workflow("flow1") .step("begin", |s| { s.initial() @@ -821,7 +823,8 @@ mod tests { #[test] fn test_step_dispatch_terminal() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); workflow("flow2") .step("begin", |s| { s.initial() @@ -851,7 +854,8 @@ mod tests { #[test] fn test_cancel() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); workflow("flow3") .step("begin", |s| { s.initial() @@ -876,7 +880,8 @@ mod tests { #[test] fn test_cancel_no_active() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); workflow("flow4") .step("begin", |s| { s.initial() @@ -897,7 +902,8 @@ mod tests { #[test] fn test_non_initial_without_active_workflow() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); workflow("flow5") .step("begin", |s| { s.initial() @@ -918,7 +924,8 @@ mod tests { #[test] fn test_dynamic_next_narrowing() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); workflow("branch") .step("start", |s| { s.initial() @@ -945,7 +952,8 @@ mod tests { #[test] fn test_dynamic_next_invalid() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); workflow("bad_next") .step("start", |s| { s.initial() @@ -966,7 +974,8 @@ mod tests { #[test] fn test_error_handling_with_on_error() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); workflow("errable") 
.step("risky", |s| { s.initial() @@ -991,7 +1000,8 @@ mod tests { #[test] fn test_error_no_match_stays_in_state() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); workflow("errable2") .step("risky", |s| { s.initial() @@ -1016,7 +1026,8 @@ mod tests { #[test] fn test_no_cancel_step() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); workflow("strict") .step("start", |s| { s.initial() @@ -1039,7 +1050,8 @@ mod tests { #[test] fn test_on_cancel_callback() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); use std::sync::{Arc, Mutex}; let called = Arc::new(Mutex::new(false)); let called_clone = called.clone(); @@ -1055,20 +1067,21 @@ mod tests { .handler(|_, _| StepResult::new("done")) }) .on_cancel(move |_step, _history| { - *called_clone.lock().unwrap() = true; + *called_clone.lock().unwrap_or_else(|e| e.into_inner()) = true; "cancelled".to_string() }) .register(); handle_step_call("cbflow", "begin", dummy_ctx(), serde_json::json!({})); handle_cancel("cbflow"); - assert!(*called.lock().unwrap()); + assert!(*called.lock().unwrap_or_else(|e| e.into_inner())); cleanup(); } #[test] fn test_on_complete_callback() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); use std::sync::{Arc, Mutex}; let called = Arc::new(Mutex::new(false)); let called_clone = called.clone(); @@ -1084,19 +1097,20 @@ mod tests { .handler(|_, _| StepResult::new("done")) }) .on_complete(move |_history| { - *called_clone.lock().unwrap() = true; + *called_clone.lock().unwrap_or_else(|e| e.into_inner()) = true; }) .register(); handle_step_call("compflow", "begin", dummy_ctx(), serde_json::json!({})); handle_step_call("compflow", "end", dummy_ctx(), serde_json::json!({})); - assert!(*called.lock().unwrap()); + 
assert!(*called.lock().unwrap_or_else(|e| e.into_inner())); cleanup(); } #[test] fn test_history_tracking() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); use std::sync::{Arc, Mutex}; let history_len = Arc::new(Mutex::new(0usize)); let hl = history_len.clone(); @@ -1112,13 +1126,13 @@ mod tests { .handler(|_, _| StepResult::new("b done")) }) .on_complete(move |history| { - *hl.lock().unwrap() = history.len(); + *hl.lock().unwrap_or_else(|e| e.into_inner()) = history.len(); }) .register(); handle_step_call("hist", "a", dummy_ctx(), serde_json::json!({})); handle_step_call("hist", "b", dummy_ctx(), serde_json::json!({})); - assert_eq!(*history_len.lock().unwrap(), 2); + assert_eq!(*history_len.lock().unwrap_or_else(|e| e.into_inner()), 2); cleanup(); } @@ -1127,7 +1141,8 @@ mod tests { #[test] #[should_panic(expected = "no initial step")] fn test_validation_no_initial() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); workflow("bad") .step("a", |s| { s.terminal() @@ -1139,7 +1154,8 @@ mod tests { #[test] #[should_panic(expected = "multiple initial steps")] fn test_validation_multiple_initial() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); workflow("bad2") .step("a", |s| { s.initial() @@ -1157,7 +1173,8 @@ mod tests { #[test] #[should_panic(expected = "terminal step")] fn test_validation_terminal_with_next() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); workflow("bad3") .step("a", |s| { s.initial() @@ -1175,7 +1192,8 @@ mod tests { #[test] #[should_panic(expected = "dead end")] fn test_validation_dead_end() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); workflow("bad4") .step("a", |s| { 
s.initial() @@ -1192,7 +1210,8 @@ mod tests { #[test] #[should_panic(expected = "nonexistent step")] fn test_validation_bad_next_ref() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); workflow("bad5") .step("a", |s| { s.initial() @@ -1205,7 +1224,8 @@ mod tests { #[test] #[should_panic(expected = "on_error references nonexistent")] fn test_validation_bad_on_error_ref() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); workflow("bad6") .step("a", |s| { s.initial() @@ -1224,6 +1244,7 @@ mod tests { #[test] fn test_glob_match() { + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); assert!(glob_match("foo*", "foobar")); assert!(glob_match("foo*", "foo")); assert!(!glob_match("foo*", "barfoo")); @@ -1237,6 +1258,7 @@ mod tests { #[test] fn test_visibility_matching() { + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); // Neither allow nor block -> false assert!(!matches_visibility("tool", &None, &None)); @@ -1259,7 +1281,8 @@ mod tests { #[test] fn test_step_level_visibility_override() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); workflow("vis") .allow_during(&["global_*"]) .step("s1", |s| { @@ -1291,7 +1314,8 @@ mod tests { #[test] fn test_workflow_level_visibility() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); workflow("vis2") .allow_during(&["ext_*"]) .step("s1", |s| { @@ -1322,7 +1346,8 @@ mod tests { #[test] fn test_workflows_to_tool_defs() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); workflow("wf1") .step("init", |s| { s.initial().next(&["done"]).handler(|_, _| StepResult::new("ok")) @@ -1342,7 +1367,8 @@ mod tests { #[test] fn 
test_all_no_cancel_no_cancel_tool() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); workflow("nc") .step("start", |s| { s.initial().no_cancel().next(&["end"]).handler(|_, _| StepResult::new("ok")) @@ -1364,7 +1390,8 @@ mod tests { #[test] fn test_step_with_args_schema() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); workflow("argflow") .step("input", |s| { s.initial() @@ -1391,7 +1418,8 @@ mod tests { #[test] fn test_handler_receives_args() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); workflow("argtest") .step("greet", |s| { s.initial() @@ -1412,7 +1440,8 @@ mod tests { #[test] fn test_multi_step_flow() { - let _lock = cleanup_and_lock(); + let _lock = crate::TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + cleanup(); workflow("pipeline") .step("step1", |s| { s.initial() diff --git a/sdk/typescript/package.json b/sdk/typescript/package.json index f712386..ef7df35 100644 --- a/sdk/typescript/package.json +++ b/sdk/typescript/package.json @@ -1,6 +1,6 @@ { "name": "protomcp", - "version": "0.2.0", + "version": "0.3.0", "description": "Write MCP tools in TypeScript. No MCP knowledge required.", "type": "module", "main": "dist/index.js", diff --git a/sdk/typescript/src/group.ts b/sdk/typescript/src/group.ts index d255484..7c8e8f2 100644 --- a/sdk/typescript/src/group.ts +++ b/sdk/typescript/src/group.ts @@ -46,7 +46,7 @@ export function toolGroup(options: GroupOptions): GroupDef { name: options.name, description: options.description, actions: options.actions, - strategy: options.strategy ?? 'union', + strategy: options.strategy ?? 
'separate', }; groupRegistry.push(def); return def; diff --git a/sdk/typescript/tests/group.test.ts b/sdk/typescript/tests/group.test.ts index 1165e71..5577932 100644 --- a/sdk/typescript/tests/group.test.ts +++ b/sdk/typescript/tests/group.test.ts @@ -46,6 +46,7 @@ describe('toolGroup', () => { toolGroup({ name: 'db', description: 'DB ops', + strategy: 'union', actions: { query: { description: 'Run query', @@ -112,6 +113,7 @@ describe('toolGroup', () => { toolGroup({ name: 'calc', description: 'Calculator', + strategy: 'union', actions: { add: { description: 'Add', @@ -130,6 +132,7 @@ describe('toolGroup', () => { toolGroup({ name: 'calc2', description: 'Calculator', + strategy: 'union', actions: { add: { description: 'Add', @@ -151,6 +154,7 @@ describe('toolGroup', () => { toolGroup({ name: 'calc3', description: 'Calculator', + strategy: 'union', actions: { add: { description: 'Add', @@ -182,7 +186,7 @@ describe('toolGroup', () => { const tools = getRegisteredTools(); const names = tools.map((t) => t.name); - expect(names).toContain('tools_test'); + expect(names).toContain('tools_test.ping'); }); it('dispatches separate strategy handlers', () => { @@ -210,6 +214,7 @@ describe('declarative validation', () => { toolGroup({ name: 'val_req', description: 'Validation test', + strategy: 'union', actions: { doIt: { description: 'Do something', @@ -232,6 +237,7 @@ describe('declarative validation', () => { toolGroup({ name: 'val_enum', description: 'Enum test', + strategy: 'union', actions: { setColor: { description: 'Set color', @@ -255,6 +261,7 @@ describe('declarative validation', () => { toolGroup({ name: 'val_enum_ok', description: 'Enum ok test', + strategy: 'union', actions: { setColor: { description: 'Set color', @@ -274,6 +281,7 @@ describe('declarative validation', () => { toolGroup({ name: 'val_cross', description: 'Cross rules test', + strategy: 'union', actions: { range: { description: 'Set range', @@ -298,6 +306,7 @@ describe('declarative validation', () => 
{ toolGroup({ name: 'val_cross_ok', description: 'Cross ok', + strategy: 'union', actions: { range: { description: 'Set range', @@ -319,6 +328,7 @@ describe('declarative validation', () => { toolGroup({ name: 'val_hints', description: 'Hints test', + strategy: 'union', actions: { deploy: { description: 'Deploy', @@ -345,6 +355,7 @@ describe('declarative validation', () => { toolGroup({ name: 'val_hints_none', description: 'No hints', + strategy: 'union', actions: { deploy: { description: 'Deploy', diff --git a/test/e2e/e2e_test.go b/test/e2e/e2e_test.go index 309c85d..024c451 100644 --- a/test/e2e/e2e_test.go +++ b/test/e2e/e2e_test.go @@ -5,7 +5,9 @@ import ( "os" "path/filepath" "runtime" + "strings" "testing" + "time" "github.com/msilverblatt/protomcp/tests/testutil" ) @@ -85,6 +87,52 @@ func TestE2E_ToolsCall(t *testing.T) { } } +func TestE2E_ToolGroupSeparate(t *testing.T) { + w, r, cleanup := StartProtomcp(t, "dev", fixture("tool_group_separate.py")) + defer cleanup() + + InitializeSession(t, w, r) + + resp := SendRequest(t, w, r, "tools/list", nil) + if resp.Error != nil { + t.Fatalf("tools/list error: %v", resp.Error) + } + + var result testutil.ToolsListResult + if err := json.Unmarshal(resp.Result, &result); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + names := map[string]bool{} + for _, tool := range result.Tools { + names[tool.Name] = true + } + if !names["db.query"] { + t.Error("expected tool 'db.query'") + } + if !names["db.insert"] { + t.Error("expected tool 'db.insert'") + } + if names["db"] { + t.Error("should NOT have a single 'db' tool in separate strategy") + } + + callResp := SendRequest(t, w, r, "tools/call", map[string]interface{}{ + "name": "db.query", + "arguments": map[string]string{"sql": "SELECT * FROM users"}, + }) + if callResp.Error != nil { + t.Fatalf("tools/call error: %v", callResp.Error) + } + var callResult testutil.ToolsCallResult + if err := json.Unmarshal(callResp.Result, &callResult); err != nil { + 
t.Fatalf("unmarshal: %v", err) + } + if callResult.IsError { + t.Error("tool call should not be an error") + } +} + func TestE2E_DynamicToolList(t *testing.T) { w, r, cleanup := StartProtomcp(t, "dev", fixture("dynamic_tool.py")) defer cleanup() @@ -119,6 +167,197 @@ func TestE2E_DynamicToolList(t *testing.T) { } } +func TestE2E_WorkflowBasic(t *testing.T) { + w, r, cleanup := StartProtomcp(t, "dev", fixture("workflow_basic.py")) + defer cleanup() + + InitializeSession(t, w, r) + + // List tools — should see deploy.review (initial step) and status + resp := SendRequestSkipNotifications(t, w, r, "tools/list", nil) + if resp.Error != nil { + t.Fatalf("tools/list error: %v", resp.Error) + } + var toolsList testutil.ToolsListResult + json.Unmarshal(resp.Result, &toolsList) + + names := map[string]bool{} + for _, tool := range toolsList.Tools { + names[tool.Name] = true + } + if !names["status"] { + t.Error("expected 'status' tool to be visible") + } + if !names["deploy.review"] { + t.Error("expected 'deploy.review' (initial step) to be visible") + } + + // Call the initial step + callResp := SendRequestSkipNotifications(t, w, r, "tools/call", map[string]interface{}{ + "name": "deploy.review", + "arguments": map[string]string{"pr_number": "42"}, + }) + if callResp.Error != nil { + t.Fatalf("deploy.review call error: %v", callResp.Error) + } + + // After calling review, the next steps (approve, reject) should become available + time.Sleep(200 * time.Millisecond) + + listResp2 := SendRequestSkipNotifications(t, w, r, "tools/list", nil) + var toolsList2 testutil.ToolsListResult + json.Unmarshal(listResp2.Result, &toolsList2) + + names2 := map[string]bool{} + for _, tool := range toolsList2.Tools { + names2[tool.Name] = true + } + if !names2["deploy.approve"] { + t.Error("expected 'deploy.approve' to be visible after review step") + } + + // Call approve, then execute + approveResp := SendRequestSkipNotifications(t, w, r, "tools/call", map[string]interface{}{ + "name": 
"deploy.approve", + "arguments": map[string]interface{}{}, + }) + if approveResp.Error != nil { + t.Fatalf("deploy.approve call error: %v", approveResp.Error) + } + + time.Sleep(200 * time.Millisecond) + + executeResp := SendRequestSkipNotifications(t, w, r, "tools/call", map[string]interface{}{ + "name": "deploy.execute", + "arguments": map[string]interface{}{}, + }) + if executeResp.Error != nil { + t.Fatalf("deploy.execute call error: %v", executeResp.Error) + } + var execResult testutil.ToolsCallResult + json.Unmarshal(executeResp.Result, &execResult) + if execResult.IsError { + t.Error("deploy.execute should succeed after approval") + } +} + +func TestE2E_Sidecar(t *testing.T) { + w, r, cleanup := StartProtomcp(t, "dev", fixture("sidecar_basic.py")) + defer cleanup() + + InitializeSession(t, w, r) + + resp := SendRequestSkipNotifications(t, w, r, "tools/call", map[string]interface{}{ + "name": "check_sidecar", + "arguments": map[string]interface{}{}, + }) + if resp.Error != nil { + t.Fatalf("check_sidecar error: %v", resp.Error) + } + + var result testutil.ToolsCallResult + json.Unmarshal(resp.Result, &result) + + text := extractText(result) + if !strings.Contains(text, "200") { + t.Errorf("expected sidecar to be reachable with status 200, got: %s", text) + } +} + +func TestE2E_HotReload(t *testing.T) { + // Copy fixture to a temp dir so we can modify it + tmpDir := t.TempDir() + srcFile := filepath.Join(tmpDir, "server.py") + v1Content, _ := os.ReadFile(fixture("hot_reload_v1.py")) + os.WriteFile(srcFile, v1Content, 0644) + + w, r, cleanup := StartProtomcp(t, "dev", srcFile) + defer cleanup() + + InitializeSession(t, w, r) + + // Verify v1 tools + resp := SendRequestSkipNotifications(t, w, r, "tools/list", nil) + var list1 testutil.ToolsListResult + json.Unmarshal(resp.Result, &list1) + + foundOriginal := false + for _, tool := range list1.Tools { + if tool.Name == "original" { + foundOriginal = true + } + } + if !foundOriginal { + t.Fatal("expected 
'original' tool in v1") + } + + // Overwrite file with v2 content (different tool) + v2Content := []byte("from protomcp import tool\nfrom protomcp.runner import run\n\n@tool(description=\"New tool added in v2\")\ndef new_tool() -> str:\n return \"v2\"\n\nif __name__ == \"__main__\":\n run()\n") + os.WriteFile(srcFile, v2Content, 0644) + + // Poll tools/list until we see the new tool (reload: debounce 100ms + process restart) + deadline := time.Now().Add(15 * time.Second) + var names map[string]bool + for time.Now().Before(deadline) { + time.Sleep(500 * time.Millisecond) + resp2 := SendRequestSkipNotifications(t, w, r, "tools/list", nil) + var list2 testutil.ToolsListResult + json.Unmarshal(resp2.Result, &list2) + + names = map[string]bool{} + for _, tool := range list2.Tools { + names[tool.Name] = true + } + if names["new_tool"] { + break + } + } + + if !names["new_tool"] { + t.Error("expected 'new_tool' after hot reload") + } + if names["original"] { + t.Error("'original' should have been removed after hot reload") + } +} + +func TestE2E_Middleware(t *testing.T) { + w, r, cleanup := StartProtomcp(t, "dev", fixture("middleware_basic.py")) + defer cleanup() + + InitializeSession(t, w, r) + + resp := SendRequestSkipNotifications(t, w, r, "tools/call", map[string]interface{}{ + "name": "echo_args", + "arguments": map[string]string{"message": "hello"}, + }) + if resp.Error != nil { + t.Fatalf("echo_args error: %v", resp.Error) + } + + var result testutil.ToolsCallResult + json.Unmarshal(resp.Result, &result) + if result.IsError { + t.Fatalf("echo_args returned error: %s", string(resp.Result)) + } + + resultText := extractText(result) + if !strings.Contains(resultText, `"source": "middleware"`) { + t.Errorf("expected middleware-injected 'source' field in result, got: %s", resultText) + } + + logResp := SendRequestSkipNotifications(t, w, r, "tools/call", map[string]interface{}{ + "name": "get_call_log", + "arguments": map[string]interface{}{}, + }) + var logResult 
testutil.ToolsCallResult + json.Unmarshal(logResp.Result, &logResult) + + logText := extractText(logResult) + if !strings.Contains(logText, "echo_args") { + t.Errorf("expected 'echo_args' in call log, got: %s", logText) + } +} // --------------------------------------------------------------------------- // Python E2E: resource read // --------------------------------------------------------------------------- diff --git a/test/e2e/fixtures/hot_reload_v1.py b/test/e2e/fixtures/hot_reload_v1.py new file mode 100644 index 0000000..fc2f495 --- /dev/null +++ b/test/e2e/fixtures/hot_reload_v1.py @@ -0,0 +1,9 @@ +from protomcp import tool +from protomcp.runner import run + +@tool(description="Original tool") +def original() -> str: + return "v1" + +if __name__ == "__main__": + run() diff --git a/test/e2e/fixtures/middleware_basic.py b/test/e2e/fixtures/middleware_basic.py new file mode 100644 index 0000000..a5a1eb0 --- /dev/null +++ b/test/e2e/fixtures/middleware_basic.py @@ -0,0 +1,30 @@ +from protomcp import tool, ToolResult +from protomcp.local_middleware import local_middleware +from protomcp.runner import run + +_call_log = [] + +@local_middleware(priority=10) +def audit_logger(ctx, tool_name, args, next_handler): + """Logs every tool call and passes through.""" + _call_log.append({"tool": tool_name, "args": dict(args)}) + return next_handler(ctx, args) + +@local_middleware(priority=20) +def arg_injector(ctx, tool_name, args, next_handler): + """Injects a 'source' field into all tool args.""" + args["source"] = "middleware" + return next_handler(ctx, args) + +@tool(description="Echo args back as JSON") +def echo_args(**kwargs) -> str: + import json + return json.dumps(kwargs, sort_keys=True) + +@tool(description="Get the call log") +def get_call_log(**kwargs) -> str: + import json + return json.dumps(_call_log, sort_keys=True) + +if __name__ == "__main__": + run() diff --git a/test/e2e/fixtures/sidecar_basic.py b/test/e2e/fixtures/sidecar_basic.py new file mode 
100644 index 0000000..3f425a5 --- /dev/null +++ b/test/e2e/fixtures/sidecar_basic.py @@ -0,0 +1,28 @@ +import random +import urllib.request +from protomcp import tool +from protomcp.sidecar import sidecar +from protomcp.runner import run + +_PORT = random.randint(49152, 65535) + +@sidecar( + name="test_http", + command=["python3", "-m", "http.server", str(_PORT)], + health_check=f"http://localhost:{_PORT}/", + start_on="server_start", + health_timeout=10, +) +class TestHTTPSidecar: + pass + +@tool(description="Check if sidecar is reachable") +def check_sidecar() -> str: + try: + resp = urllib.request.urlopen(f"http://localhost:{_PORT}/", timeout=5) + return f"sidecar status: {resp.status}" + except Exception as e: + return f"sidecar unreachable: {e}" + +if __name__ == "__main__": + run() diff --git a/test/e2e/fixtures/tool_group_separate.py b/test/e2e/fixtures/tool_group_separate.py new file mode 100644 index 0000000..23e8585 --- /dev/null +++ b/test/e2e/fixtures/tool_group_separate.py @@ -0,0 +1,15 @@ +from protomcp import tool_group, action, ToolResult +from protomcp.runner import run + +@tool_group("db", description="Database operations") +class DatabaseTools: + @action("query", description="Run a SQL query") + def query(self, sql: str) -> str: + return f"Results for: {sql}" + + @action("insert", description="Insert a record") + def insert(self, table: str, data: str) -> str: + return f"Inserted into {table}: {data}" + +if __name__ == "__main__": + run() diff --git a/test/e2e/fixtures/workflow_basic.py b/test/e2e/fixtures/workflow_basic.py new file mode 100644 index 0000000..4f49728 --- /dev/null +++ b/test/e2e/fixtures/workflow_basic.py @@ -0,0 +1,36 @@ +from protomcp import tool, ToolResult +from protomcp.workflow import workflow, step +from protomcp.runner import run + +@workflow("deploy", description="Deployment workflow") +class DeployWorkflow: + def __init__(self): + self.approved = False + self.reviewed = False + + @step("review", description="Review the 
deployment", initial=True, next=["approve", "reject"]) + def review(self, pr_number: str) -> str: + self.reviewed = True + return f"Reviewed PR #{pr_number}" + + @step("approve", description="Approve the deployment", next=["execute"]) + def approve(self) -> str: + self.approved = True + return "Deployment approved" + + @step("reject", description="Reject the deployment", terminal=True) + def reject(self, reason: str) -> str: + return f"Deployment rejected: {reason}" + + @step("execute", description="Execute the deployment", terminal=True) + def execute(self) -> str: + if not self.approved: + return "ERROR: not approved" + return "Deployment executed successfully" + +@tool(description="Check server status") +def status() -> str: + return "all systems operational" + +if __name__ == "__main__": + run() diff --git a/test/e2e/helpers.go b/test/e2e/helpers.go index 56e3bcc..abb42de 100644 --- a/test/e2e/helpers.go +++ b/test/e2e/helpers.go @@ -3,18 +3,21 @@ package e2e import ( "bufio" "encoding/json" + "fmt" "io" "os/exec" "path/filepath" "sync" + "sync/atomic" "testing" "github.com/msilverblatt/protomcp/tests/testutil" ) var ( - pmcpBinary string - pmcpBinaryOnce sync.Once + pmcpBinary string + pmcpBinaryOnce sync.Once + requestIDCounter int64 ) func getPMCPBinary(t *testing.T) string { @@ -58,7 +61,8 @@ func StartProtomcp(t *testing.T, args ...string) (io.Writer, *bufio.Scanner, fun // SendRequest sends a JSON-RPC request and reads the response. 
func SendRequest(t *testing.T, w io.Writer, r *bufio.Scanner, method string, params interface{}) testutil.JSONRPCResponse { t.Helper() - id := json.RawMessage(`1`) + idVal := atomic.AddInt64(&requestIDCounter, 1) + id := json.RawMessage(fmt.Sprintf("%d", idVal)) req := testutil.JSONRPCRequest{ JSONRPC: "2.0", ID: id, @@ -96,6 +100,65 @@ func SendNotification(t *testing.T, w io.Writer, method string, params interface w.Write(append(data, '\n')) } +// SendRequestSkipNotifications sends a request and reads the response, +// skipping any JSON-RPC notifications (messages without an "id" field) and +// responses whose ID does not match the request ID. +func SendRequestSkipNotifications(t *testing.T, w io.Writer, r *bufio.Scanner, method string, params interface{}) testutil.JSONRPCResponse { + t.Helper() + reqID := atomic.AddInt64(&requestIDCounter, 1) + id := json.RawMessage(fmt.Sprintf("%d", reqID)) + req := testutil.JSONRPCRequest{ + JSONRPC: "2.0", + ID: id, + Method: method, + } + if params != nil { + p, _ := json.Marshal(params) + req.Params = p + } + data, _ := json.Marshal(req) + w.Write(append(data, '\n')) + + for { + if !r.Scan() { + t.Fatalf("no response from protomcp for method %q: %v", method, r.Err()) + } + line := r.Bytes() + + var check map[string]json.RawMessage + if json.Unmarshal(line, &check) == nil { + rawID, hasID := check["id"] + if !hasID { + continue // skip notification + } + // Skip responses with non-matching IDs + expectedID := fmt.Sprintf("%d", reqID) + if string(rawID) != expectedID { + continue + } + } + + var resp testutil.JSONRPCResponse + if err := json.Unmarshal(line, &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + return resp + } +} + +// extractText returns the text from the first text-type content item. 
+func extractText(result testutil.ToolsCallResult) string { + if len(result.Content) == 0 { + return "" + } + for _, c := range result.Content { + if c.Type == "text" { + return c.Text + } + } + return "" +} + // InitializeSession sends a proper MCP initialize handshake. func InitializeSession(t *testing.T, w io.Writer, r *bufio.Scanner) testutil.JSONRPCResponse { t.Helper()