diff --git a/Makefile b/Makefile index 4d59f8d..de6bf3a 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ PKG := ./cmd/beacon-proxy export PATH := /opt/homebrew/bin:/usr/local/go/bin:$(HOME)/go/bin:$(PATH) -.PHONY: build build-small test test-race test-verbose bench coverage coverage-html lint fmt tidy vulncheck generate install check clean run +.PHONY: build build-small test test-race test-verbose bench coverage coverage-html lint fmt tidy vulncheck generate install check clean run demo demo-fake verify build: go build -o $(BINARY) $(PKG) @@ -59,3 +59,21 @@ clean: run: build ./$(BINARY) $(ARGS) + +demo: build + @echo "🔦 Beacon demo — resetting audit DB and generating real e2e traffic..." + @rm -f ~/.beacon/audit.db ~/.beacon/audit.db-wal ~/.beacon/audit.db-shm + @open http://localhost:8080 2>/dev/null || xdg-open http://localhost:8080 2>/dev/null || true + @go run ./cmd/beacon-e2e/ + +demo-fake: build + @echo "🔦 Beacon demo — resetting audit DB and generating fake traffic..." + @rm -f ~/.beacon/audit.db ~/.beacon/audit.db-wal ~/.beacon/audit.db-shm + @go run ./cmd/beacon-traffic/ + @echo "" + @echo "Starting dashboard..." + @open http://localhost:8080 2>/dev/null || xdg-open http://localhost:8080 2>/dev/null || echo "Open http://localhost:8080" + @./$(BINARY) --server-name demo --port 8080 -- sh -c 'echo "beacon-demo-server"; exec cat' + +verify: build + ./$(BINARY) verify $(ARGS) diff --git a/README.md b/README.md index d282d94..9176c7f 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,7 @@ make build ``` beacon-proxy [flags] -- [server-args...] +beacon-proxy verify [--db path] ``` ### Flags @@ -131,7 +132,7 @@ beacon-proxy --server-name filesystem -- npx -y @modelcontextprotocol/server-fil Beacon includes a real-time web dashboard at `http://localhost:8080` (configurable via `--port`). -![Beacon Dashboard](docs/dashboard.png) +![Beacon Dashboard — Tool Calls](docs/dashboard-toolcalls.png) - **Live tool call stream** — see every MCP action as it happens via SSE - **Session overview** — browse sessions, see tool call counts and max risk per server @@ -140,17 +141,33 @@ Beacon includes a real-time web dashboard at `http://localhost:8080` (configurab - **Approve/deny** — handle paused tool calls directly from the browser - **Hash chain verification** — one-click integrity check of the entire audit trail +![Beacon Dashboard — Hash Chain Verification](docs/dashboard.png) + No build step — the dashboard is a single HTML file embedded in the binary. -### Generate Demo Traffic +### Quick Demo + +One command to build, generate traffic, and open the dashboard: + +```bash +make demo +``` -To populate the dashboard with realistic sample data across multiple servers: +Or step-by-step: ```bash -go run ./cmd/beacon-traffic/ +go run ./cmd/beacon-traffic/ # populate with realistic sample data +open http://localhost:8080 # open dashboard ``` -This creates sessions for GitHub, filesystem, and PostgreSQL MCP servers with a mix of read/write/delete/execute operations, risk scores, and policy actions (flag, pause, block). +### Verify Hash Chain + +Check the integrity of the audit trail from the CLI: + +```bash +beacon-proxy verify +beacon-proxy verify --db /path/to/audit.db +``` ## Inspecting the Audit Trail diff --git a/cmd/beacon-e2e/main.go b/cmd/beacon-e2e/main.go new file mode 100644 index 0000000..55a2d89 --- /dev/null +++ b/cmd/beacon-e2e/main.go @@ -0,0 +1,398 @@ +// beacon-e2e generates real MCP traffic by driving beacon-proxy end-to-end. +// It spawns beacon-proxy wrapping a real MCP server (filesystem by default), +// sends JSON-RPC requests, and reads real responses — producing genuine audit trail data. +package main + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "log" + "os" + "os/exec" + "path/filepath" + "sync" + "time" +) + +type request struct { + JSONRPC string `json:"jsonrpc"` + ID int `json:"id"` + Method string `json:"method"` + Params any `json:"params,omitempty"` +} + +type response struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error json.RawMessage `json:"error,omitempty"` +} + +func main() { + // Use a temp dir for the filesystem server to browse + workDir, err := os.MkdirTemp("", "beacon-e2e-*") + if err != nil { + log.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(workDir) + + // Create some realistic files for the filesystem server to find + seedFiles(workDir) + + // Build beacon-proxy path + binaryPath := findBinary() + + fmt.Println("🔦 Beacon E2E Traffic Generator") + fmt.Println(" Driving real MCP traffic through beacon-proxy...") + fmt.Printf(" Workspace: %s\n", workDir) + fmt.Println() + + // Spawn beacon-proxy wrapping the filesystem MCP server + cmd := exec.Command(binaryPath, + "--server-name", "filesystem", + "--port", "8080", + "--", + "npx", "-y", "@modelcontextprotocol/server-filesystem", workDir, + ) + cmd.Stderr = os.Stderr + + stdin, err := cmd.StdinPipe() + if err != nil { + log.Fatalf("stdin pipe: %v", err) + } + stdout, err := cmd.StdoutPipe() + if err != nil { + log.Fatalf("stdout pipe: %v", err) + } + + if err := cmd.Start(); err != nil { + log.Fatalf("failed to start beacon-proxy: %v", err) + } + + reader := bufio.NewReader(stdout) + var mu sync.Mutex // protects stdin writes + + send := func(req request) *response { + mu.Lock() + data, err := json.Marshal(req) + if err != nil { + mu.Unlock() + log.Printf("failed to marshal request: %v", err) + return nil + } + if _, err := stdin.Write(data); err != nil { + mu.Unlock() + log.Printf("failed to write request: %v", err) + return nil + } + if _, err := stdin.Write([]byte("\n")); err != nil { + mu.Unlock() + log.Printf("failed to write newline: %v", err) + return nil + } + mu.Unlock() + + // Read response (line-delimited JSON) + for { + line, err := reader.ReadBytes('\n') + if err != nil { + if err == io.EOF { + return nil + } + log.Printf("read error: %v", err) + return nil + } + var resp response + if err := json.Unmarshal(line, &resp); err != nil { + continue // skip non-JSON lines (e.g. server stderr leaks) + } + // Check if this is a response (has id) or a notification (no id) + if resp.ID != nil && string(resp.ID) != "null" { + return &resp + } + // Skip notifications, keep reading + } + } + + // ===== Act 1: Initialize ===== + fmt.Println(" ▸ Initializing MCP connection...") + resp := send(request{ + JSONRPC: "2.0", + ID: 1, + Method: "initialize", + Params: map[string]any{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]any{}, + "clientInfo": map[string]any{ + "name": "beacon-e2e", + "version": "1.0.0", + }, + }, + }) + if resp != nil && resp.Error == nil { + fmt.Println(" ✓ MCP session initialized") + } else { + fmt.Println(" ✗ Initialize failed") + if resp != nil { + fmt.Printf(" Error: %s\n", string(resp.Error)) + } + } + + // Send initialized notification + notif, _ := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "method": "notifications/initialized", + }) + mu.Lock() + stdin.Write(notif) + stdin.Write([]byte("\n")) + mu.Unlock() + time.Sleep(500 * time.Millisecond) + + // ===== Act 2: List tools ===== + fmt.Println(" ▸ Discovering available tools...") + resp = send(request{ + JSONRPC: "2.0", + ID: 2, + Method: "tools/list", + }) + if resp != nil && resp.Error == nil { + var result struct { + Tools []struct { + Name string `json:"name"` + } `json:"tools"` + } + json.Unmarshal(resp.Result, &result) + fmt.Printf(" ✓ Found %d tools:", len(result.Tools)) + for _, t := range result.Tools { + fmt.Printf(" %s", t.Name) + } + fmt.Println() + } + + time.Sleep(1 * time.Second) + + // ===== Act 3: Explore — read operations (intent group 1) ===== + fmt.Println() + fmt.Println(" ┌ Intent: Explore project structure") + + calls := []struct { + tool string + args map[string]any + desc string + }{ + {"list_directory", map[string]any{"path": workDir}, "listing directory"}, + {"read_file", map[string]any{"path": filepath.Join(workDir, "README.md")}, "reading README.md"}, + {"read_file", map[string]any{"path": filepath.Join(workDir, "src/main.go")}, "reading src/main.go"}, + {"read_file", map[string]any{"path": filepath.Join(workDir, "config.yaml")}, "reading config.yaml"}, + } + + id := 10 + for i, c := range calls { + time.Sleep(300 * time.Millisecond) // fast calls within intent group + id++ + resp = send(request{ + JSONRPC: "2.0", + ID: id, + Method: "tools/call", + Params: map[string]any{ + "name": c.tool, + "arguments": c.args, + }, + }) + connector := "│" + if i == len(calls)-1 { + connector = "└" + } + if resp != nil && resp.Error == nil { + fmt.Printf(" %s ✓ %s (%s)\n", connector, c.tool, c.desc) + } else { + fmt.Printf(" %s ✗ %s failed\n", connector, c.tool) + } + } + + // Gap between intent groups (>5s to create separate intent) + fmt.Println() + fmt.Println(" ⏳ (6s gap — new intent group)") + time.Sleep(6 * time.Second) + + // ===== Act 4: Modify — write operations (intent group 2) ===== + fmt.Println() + fmt.Println(" ┌ Intent: Create new files") + + writeCalls := []struct { + tool string + args map[string]any + desc string + }{ + {"write_file", map[string]any{ + "path": filepath.Join(workDir, "CHANGELOG.md"), + "content": "# Changelog\n\n## v0.2.0 - Dashboard\n- Real-time web UI with SSE\n- Intent chain visualization\n- Hash chain verification\n", + }, "writing CHANGELOG.md"}, + {"write_file", map[string]any{ + "path": filepath.Join(workDir, "TODO.md"), + "content": "# TODO\n\n- [ ] SIEM export (Datadog, Splunk)\n- [ ] Multi-server intent chains\n- [ ] Role-based access control\n", + }, "writing TODO.md"}, + {"create_directory", map[string]any{ + "path": filepath.Join(workDir, "docs"), + }, "creating docs/"}, + {"write_file", map[string]any{ + "path": filepath.Join(workDir, "docs/architecture.md"), + "content": "# Architecture\n\nBeacon Proxy sits between MCP client and server.\n\nClient → Beacon → Server\n ↓\n SQLite audit trail\n ↓\n Dashboard (SSE)\n", + }, "writing docs/architecture.md"}, + } + + for i, c := range writeCalls { + time.Sleep(400 * time.Millisecond) + id++ + resp = send(request{ + JSONRPC: "2.0", + ID: id, + Method: "tools/call", + Params: map[string]any{ + "name": c.tool, + "arguments": c.args, + }, + }) + connector := "│" + if i == len(writeCalls)-1 { + connector = "└" + } + if resp != nil && resp.Error == nil { + fmt.Printf(" %s ✓ %s (%s)\n", connector, c.tool, c.desc) + } else { + errMsg := "" + if resp != nil && resp.Error != nil { + errMsg = string(resp.Error) + } + fmt.Printf(" %s ⚠ %s — %s\n", connector, c.tool, errMsg) + } + } + + // Another gap + fmt.Println() + fmt.Println(" ⏳ (6s gap — new intent group)") + time.Sleep(6 * time.Second) + + // ===== Act 5: Verify — read back what we wrote (intent group 3) ===== + fmt.Println() + fmt.Println(" ┌ Intent: Verify written files") + + verifyCalls := []struct { + tool string + args map[string]any + desc string + }{ + {"list_directory", map[string]any{"path": workDir}, "listing updated directory"}, + {"read_file", map[string]any{"path": filepath.Join(workDir, "CHANGELOG.md")}, "reading CHANGELOG.md"}, + {"read_file", map[string]any{"path": filepath.Join(workDir, "TODO.md")}, "reading TODO.md"}, + {"list_directory", map[string]any{"path": filepath.Join(workDir, "docs")}, "listing docs/"}, + {"read_file", map[string]any{"path": filepath.Join(workDir, "docs/architecture.md")}, "reading docs/architecture.md"}, + } + + for i, c := range verifyCalls { + time.Sleep(250 * time.Millisecond) + id++ + resp = send(request{ + JSONRPC: "2.0", + ID: id, + Method: "tools/call", + Params: map[string]any{ + "name": c.tool, + "arguments": c.args, + }, + }) + connector := "│" + if i == len(verifyCalls)-1 { + connector = "└" + } + if resp != nil && resp.Error == nil { + fmt.Printf(" %s ✓ %s (%s)\n", connector, c.tool, c.desc) + } else { + fmt.Printf(" %s ✗ %s failed\n", connector, c.tool) + } + } + + // Clean shutdown + fmt.Println() + fmt.Println(" ✅ Done! 3 intent chains generated with real MCP traffic") + fmt.Println(" Open http://localhost:8080 to see the dashboard") + fmt.Println() + + stdin.Close() + cmd.Wait() +} + +// seedFiles creates realistic project files in the work directory. +func seedFiles(dir string) { + os.MkdirAll(filepath.Join(dir, "src"), 0o755) + + os.WriteFile(filepath.Join(dir, "README.md"), []byte(`# My Project + +A Go application demonstrating MCP audit capabilities. + +## Quick Start +`+"`"+`bash +go build -o app ./cmd/app +./app serve --port 3000 +`+"`"+` +`), 0o644) + + os.WriteFile(filepath.Join(dir, "config.yaml"), []byte(`server: + port: 3000 + host: localhost +database: + driver: postgres + dsn: postgres://localhost/myapp + max_connections: 25 +logging: + level: info + format: json +`), 0o644) + + os.WriteFile(filepath.Join(dir, "src/main.go"), []byte(`package main + +import ( + "fmt" + "net/http" +) + +func main() { + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "Hello from the app!") + }) + http.ListenAndServe(":3000", nil) +} +`), 0o644) + + os.WriteFile(filepath.Join(dir, "go.mod"), []byte(`module github.com/example/myproject + +go 1.26 +`), 0o644) +} + +// findBinary looks for the beacon-proxy binary. +func findBinary() string { + // Try local build first + if _, err := os.Stat("./beacon-proxy"); err == nil { + return "./beacon-proxy" + } + // Try $GOPATH/bin + if gopath := os.Getenv("GOPATH"); gopath != "" { + p := filepath.Join(gopath, "bin", "beacon-proxy") + if _, err := os.Stat(p); err == nil { + return p + } + } + // Fallback: build it + fmt.Println(" Building beacon-proxy...") + cmd := exec.Command("go", "build", "-o", "./beacon-proxy", "./cmd/beacon-proxy") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + log.Fatalf("failed to build beacon-proxy: %v", err) + } + return "./beacon-proxy" +} diff --git a/cmd/beacon-proxy/main.go b/cmd/beacon-proxy/main.go index 5545683..65c5734 100644 --- a/cmd/beacon-proxy/main.go +++ b/cmd/beacon-proxy/main.go @@ -2,6 +2,8 @@ package main import ( "context" + "crypto/rand" + "encoding/hex" "flag" "fmt" "log" @@ -20,6 +22,12 @@ import ( ) func main() { + // Check for subcommands before flag parsing + if len(os.Args) > 1 && os.Args[1] == "verify" { + runVerify(os.Args[2:]) + return + } + serverName := flag.String("server-name", "unknown", "label for this MCP server") dbPath := flag.String("db", "~/.beacon/audit.db", "path to SQLite audit database") rulesPath := flag.String("rules", "", "path to YAML policy rules file (uses defaults if not set)") @@ -38,7 +46,8 @@ func main() { } if len(cmdArgs) == 0 { - fmt.Fprintf(os.Stderr, "Usage: beacon-proxy [flags] -- [server-args...]\n\n") + fmt.Fprintf(os.Stderr, "Usage: beacon-proxy [flags] -- [server-args...]\n") + fmt.Fprintf(os.Stderr, " beacon-proxy verify [--db path]\n\n") fmt.Fprintf(os.Stderr, "Flags:\n") flag.PrintDefaults() os.Exit(1) @@ -82,8 +91,15 @@ func main() { } engine := policy.NewEngine(rules) + // Generate auth token for approval endpoints + authToken := os.Getenv("BEACON_AUTH_TOKEN") + if authToken == "" { + authToken = generateToken() + } + log.Printf("beacon: approval token: %s", authToken) + // Start dashboard web server - dashboard := web.NewServer(store, engine) + dashboard := web.NewServer(store, engine, authToken) go startHTTPServer(*port, dashboard) ctx, cancel := context.WithCancel(context.Background()) @@ -112,6 +128,38 @@ func main() { } } +// runVerify verifies the hash chain integrity of the audit database. +func runVerify(args []string) { + verifyFlags := flag.NewFlagSet("verify", flag.ExitOnError) + dbPath := verifyFlags.String("db", "~/.beacon/audit.db", "path to SQLite audit database") + verifyFlags.Parse(args) + + db := expandHome(*dbPath) + encryptionKey := os.Getenv("BEACON_ENCRYPTION_KEY") + + store, err := audit.Open(db, encryptionKey) + if err != nil { + fmt.Fprintf(os.Stderr, "Failed to open database: %v\n", err) + os.Exit(1) + } + defer store.Close() + + fmt.Printf("Verifying hash chain in %s...\n", db) + + status, err := store.VerifyChain() + if err != nil { + fmt.Fprintf(os.Stderr, "Verification error: %v\n", err) + os.Exit(1) + } + + if status.Valid { + fmt.Printf("✅ Chain intact — %d message(s) verified, no tampering detected.\n", status.Total) + } else { + fmt.Printf("❌ Chain broken at sequence %d: %s\n", status.BrokenAt, status.Error) + os.Exit(1) + } +} + func startHTTPServer(port int, dashboard *web.Server) { addr := fmt.Sprintf("127.0.0.1:%d", port) log.Printf("beacon: dashboard listening on http://%s", addr) @@ -120,6 +168,14 @@ func startHTTPServer(port int, dashboard *web.Server) { } } +func generateToken() string { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + log.Fatalf("beacon: failed to generate auth token: %v", err) + } + return hex.EncodeToString(b) +} + func expandHome(path string) string { if strings.HasPrefix(path, "~/") { home, err := os.UserHomeDir() diff --git a/cmd/beacon-traffic/main.go b/cmd/beacon-traffic/main.go index 6d4d5b5..7a2b723 100644 --- a/cmd/beacon-traffic/main.go +++ b/cmd/beacon-traffic/main.go @@ -15,17 +15,23 @@ import ( ) type scenario struct { - server string - command string - toolCalls []fakeCall + server string + command string + // Each burst becomes a separate intent chain (separated by >2s gaps) + bursts []burst +} + +type burst struct { + label string // human-readable description of what this intent is + calls []fakeCall // tool calls in this burst (fired rapidly) } type fakeCall struct { tool string args map[string]any result map[string]any - delayMs int // delay before this call - policySet string + delayMs int // delay before this call (within burst, keep <1s) + policySet string // override policy action } func main() { @@ -42,35 +48,74 @@ func main() { { server: "github", command: "npx -y @modelcontextprotocol/server-github", - toolCalls: []fakeCall{ - {tool: "list_repos", args: m("owner", "ojongerius"), result: m("count", 12), delayMs: 200}, - {tool: "search_issues", args: m("repo", "beacon", "query", "audit"), result: m("count", 3), delayMs: 400}, - {tool: "read_file", args: m("repo", "beacon", "path", "README.md"), result: m("size", 4096), delayMs: 150}, - {tool: "create_issue", args: m("repo", "beacon", "title", "Add SIEM export", "body", "Stream events to Datadog"), result: m("number", 9), delayMs: 800, policySet: "flag"}, - {tool: "update_issue", args: m("repo", "beacon", "number", 5, "labels", []string{"enhancement"}), result: m("updated", true), delayMs: 300, policySet: "flag"}, + bursts: []burst{ + { + label: "Research: find open issues about audit", + calls: []fakeCall{ + {tool: "list_repos", args: m("owner", "ojongerius"), result: m("count", 12), delayMs: 100}, + {tool: "search_issues", args: m("repo", "beacon", "query", "audit trail"), result: m("count", 3), delayMs: 200}, + {tool: "read_file", args: m("repo", "beacon", "path", "README.md"), result: m("size", 4096), delayMs: 150}, + }, + }, + { + label: "Action: create issue and update labels", + calls: []fakeCall{ + {tool: "create_issue", args: m("repo", "beacon", "title", "Add SIEM export", "body", "Stream audit events to Datadog for SOC team visibility"), result: m("number", 9), delayMs: 300, policySet: "flag"}, + {tool: "update_issue", args: m("repo", "beacon", "number", 9, "labels", []string{"enhancement", "security"}), result: m("updated", true), delayMs: 200, policySet: "flag"}, + {tool: "create_issue", args: m("repo", "beacon", "title", "Multi-server intent chains", "body", "Track intent across GitHub + Slack + Postgres"), result: m("number", 10), delayMs: 250, policySet: "flag"}, + }, + }, }, }, { server: "filesystem", command: "npx -y @modelcontextprotocol/server-filesystem /Users/otto/projects", - toolCalls: []fakeCall{ - {tool: "list_directory", args: m("path", "/Users/otto/projects"), result: m("entries", 8), delayMs: 100}, - {tool: "read_file", args: m("path", "/Users/otto/projects/beacon/go.mod"), result: m("size", 512), delayMs: 120}, - {tool: "read_file", args: m("path", "/Users/otto/projects/beacon/internal/web/server.go"), result: m("size", 6144), delayMs: 180}, - {tool: "write_file", args: m("path", "/Users/otto/projects/beacon/TODO.md", "content", "# TODO\n- SIEM export\n- Multi-server"), result: m("written", true), delayMs: 500, policySet: "flag"}, - {tool: "delete_file", args: m("path", "/Users/otto/projects/beacon/tmp/old-cache.json"), result: m("deleted", true), delayMs: 400, policySet: "pause"}, + bursts: []burst{ + { + label: "Explore: understand project structure", + calls: []fakeCall{ + {tool: "list_directory", args: m("path", "/Users/otto/projects/beacon"), result: m("entries", 12), delayMs: 80}, + {tool: "read_file", args: m("path", "/Users/otto/projects/beacon/go.mod"), result: m("size", 512), delayMs: 100}, + {tool: "read_file", args: m("path", "/Users/otto/projects/beacon/internal/web/server.go"), result: m("size", 6144), delayMs: 120}, + {tool: "read_file", args: m("path", "/Users/otto/projects/beacon/internal/audit/store.go"), result: m("size", 8192), delayMs: 130}, + }, + }, + { + label: "Modify: write TODO and clean up temp files", + calls: []fakeCall{ + {tool: "write_file", args: m("path", "/Users/otto/projects/beacon/TODO.md", "content", "# TODO\n- SIEM export\n- Multi-server intent chains\n- Dashboard filters"), result: m("written", true), delayMs: 200, policySet: "flag"}, + {tool: "write_file", args: m("path", "/Users/otto/projects/beacon/CHANGELOG.md", "content", "# Changelog\n## v0.2 - Dashboard\n- Real-time web UI\n- Intent chain visualization"), result: m("written", true), delayMs: 180, policySet: "flag"}, + {tool: "delete_file", args: m("path", "/Users/otto/projects/beacon/tmp/old-cache.json"), result: m("deleted", true), delayMs: 250, policySet: "pause"}, + }, + }, }, }, { server: "postgres", command: "npx -y @modelcontextprotocol/server-postgres postgres://localhost/app", - toolCalls: []fakeCall{ - {tool: "list_tables", args: m("schema", "public"), result: m("tables", []string{"users", "orders", "products"}), delayMs: 150}, - {tool: "describe_table", args: m("table", "users"), result: m("columns", 8), delayMs: 200}, - {tool: "read_query", args: m("sql", "SELECT COUNT(*) FROM users"), result: m("count", 1847), delayMs: 250}, - {tool: "read_query", args: m("sql", "SELECT email, created_at FROM users ORDER BY created_at DESC LIMIT 5"), result: m("rows", 5), delayMs: 300}, - {tool: "exec_query", args: m("sql", "UPDATE users SET status = 'active' WHERE last_login > NOW() - INTERVAL '30 days'"), result: m("affected", 423), delayMs: 1200, policySet: "pause"}, - {tool: "exec_query", args: m("sql", "DELETE FROM sessions WHERE expired_at < NOW()"), result: m("affected", 89), delayMs: 600, policySet: "block"}, + bursts: []burst{ + { + label: "Investigate: check user table structure and counts", + calls: []fakeCall{ + {tool: "list_tables", args: m("schema", "public"), result: m("tables", []string{"users", "orders", "products", "sessions", "audit_log"}), delayMs: 100}, + {tool: "describe_table", args: m("table", "users"), result: m("columns", 8, "rows_estimate", 1847), delayMs: 120}, + {tool: "read_query", args: m("sql", "SELECT COUNT(*) FROM users"), result: m("count", 1847), delayMs: 150}, + }, + }, + { + label: "Analyze: query recent user activity", + calls: []fakeCall{ + {tool: "read_query", args: m("sql", "SELECT email, created_at FROM users ORDER BY created_at DESC LIMIT 5"), result: m("rows", 5), delayMs: 200}, + {tool: "read_query", args: m("sql", "SELECT status, COUNT(*) FROM users GROUP BY status"), result: m("rows", 3, "data", []string{"active: 1204", "inactive: 521", "suspended: 122"}), delayMs: 180}, + }, + }, + { + label: "Mutate: bulk update user status and clean expired sessions", + calls: []fakeCall{ + {tool: "exec_query", args: m("sql", "UPDATE users SET status = 'active' WHERE last_login > NOW() - INTERVAL '30 days'"), result: m("affected", 423), delayMs: 400, policySet: "pause"}, + {tool: "exec_query", args: m("sql", "DELETE FROM sessions WHERE expired_at < NOW()"), result: m("affected", 89), delayMs: 300, policySet: "block"}, + }, + }, }, }, } @@ -87,95 +132,106 @@ func main() { continue } - intentID := "" - var lastCallTime time.Time + msgSeq := 0 - for i, tc := range sc.toolCalls { - time.Sleep(time.Duration(tc.delayMs) * time.Millisecond) + for bi, b := range sc.bursts { + // Gap between bursts to create separate intents + if bi > 0 { + gap := 3000 + rand.Intn(2000) // 3-5s gap between bursts + time.Sleep(time.Duration(gap) * time.Millisecond) + } - now := time.Now().UTC() - argsJSON, _ := json.Marshal(tc.args) + fmt.Printf(" ┌ Intent: %s\n", b.label) - // Log request message - reqRaw := fmt.Sprintf(`{"jsonrpc":"2.0","id":%d,"method":"tools/call","params":{"name":"%s","arguments":%s}}`, - i+1, tc.tool, argsJSON) - msgID, err := store.LogMessage(sessionID, "client_to_server", fmt.Sprintf("%d", i+1), "tools/call", reqRaw) + intentID, err := store.CreateIntentContext(sessionID) if err != nil { - log.Printf(" ✗ log request: %v", err) + log.Printf(" ✗ failed to create intent: %v", err) continue } - // Classify and score - opType := audit.ClassifyOperation(tc.tool) - score, reasons := audit.ScoreRisk(tc.tool, opType, string(argsJSON)) - - tcID := fmt.Sprintf("tc-%s-%d", sc.server, i+1) - err = store.CreateToolCall(audit.ToolCallRecord{ - ID: tcID, - SessionID: sessionID, - RequestMsgID: msgID, - ToolName: tc.tool, - Arguments: string(argsJSON), - OperationType: opType, - RiskScore: score, - RiskReasons: reasons, - RequestedAt: now, - }) - if err != nil { - log.Printf(" ✗ create tool call: %v", err) - continue - } + for ci, tc := range b.calls { + time.Sleep(time.Duration(tc.delayMs) * time.Millisecond) - // Set policy if specified - if tc.policySet != "" { - store.UpdateToolCallPolicy(tcID, tc.policySet, nil) - } + now := time.Now().UTC() + argsJSON, _ := json.Marshal(tc.args) + msgSeq++ - // Simulate response after a brief delay - respDelay := time.Duration(50+rand.Intn(200)) * time.Millisecond - time.Sleep(respDelay) - respTime := time.Now().UTC() + // Log request message + reqRaw := fmt.Sprintf(`{"jsonrpc":"2.0","id":%d,"method":"tools/call","params":{"name":"%s","arguments":%s}}`, + msgSeq, tc.tool, argsJSON) + msgID, err := store.LogMessage(sessionID, "client_to_server", fmt.Sprintf("%d", msgSeq), "tools/call", reqRaw) + if err != nil { + log.Printf(" ✗ log request: %v", err) + continue + } - resultJSON, _ := json.Marshal(tc.result) - respRaw := fmt.Sprintf(`{"jsonrpc":"2.0","id":%d,"result":%s}`, i+1, resultJSON) - respMsgID, err := store.LogMessage(sessionID, "server_to_client", fmt.Sprintf("%d", i+1), "", respRaw) - if err != nil { - log.Printf(" ✗ log response: %v", err) - continue - } + // Classify and score + opType := audit.ClassifyOperation(tc.tool) + score, reasons := audit.ScoreRisk(tc.tool, opType, string(argsJSON)) - durationMs := respTime.Sub(now).Milliseconds() - result := string(resultJSON) - store.CompleteToolCall(tcID, respMsgID, &result, nil, respTime, durationMs) + tcID := fmt.Sprintf("tc-%s-%d-%d", sc.server, bi+1, ci+1) + err = store.CreateToolCall(audit.ToolCallRecord{ + ID: tcID, + SessionID: sessionID, + RequestMsgID: msgID, + ToolName: tc.tool, + Arguments: string(argsJSON), + OperationType: opType, + RiskScore: score, + RiskReasons: reasons, + RequestedAt: now, + }) + if err != nil { + log.Printf(" ✗ create tool call: %v", err) + continue + } - // Intent grouping: new intent if gap > 2s - if intentID == "" || now.Sub(lastCallTime) > 2*time.Second { - intentID2, err := store.CreateIntentContext(sessionID) - if err == nil { - intentID = intentID2 + // Set policy if specified + if tc.policySet != "" { + store.UpdateToolCallPolicy(tcID, tc.policySet, nil) + } + + // Simulate response + respDelay := time.Duration(30+rand.Intn(150)) * time.Millisecond + time.Sleep(respDelay) + respTime := time.Now().UTC() + + resultJSON, _ := json.Marshal(tc.result) + respRaw := fmt.Sprintf(`{"jsonrpc":"2.0","id":%d,"result":%s}`, msgSeq, resultJSON) + respMsgID, err := store.LogMessage(sessionID, "server_to_client", fmt.Sprintf("%d", msgSeq), "", respRaw) + if err != nil { + log.Printf(" ✗ log response: %v", err) + continue } - } - if intentID != "" { - store.AddToolCallToIntent(intentID, tcID, i+1) - } - lastCallTime = now - policyTag := "" - if tc.policySet != "" { - policyTag = fmt.Sprintf(" [%s]", strings.ToUpper(tc.policySet)) + durationMs := respTime.Sub(now).Milliseconds() + result := string(resultJSON) + store.CompleteToolCall(tcID, respMsgID, &result, nil, respTime, durationMs) + + // Link to intent + store.AddToolCallToIntent(intentID, tcID, ci+1) + + policyTag := "" + if tc.policySet != "" { + policyTag = fmt.Sprintf(" [%s]", strings.ToUpper(tc.policySet)) + } + connector := "│" + if ci == len(b.calls)-1 { + connector = "└" + } + fmt.Printf(" %s %s %-20s %s %-7s risk:%2d%s\n", + connector, "✓", tc.tool, riskIndicator(score), opType, score, policyTag) } - riskBar := riskIndicator(score) - fmt.Printf(" %s %-20s %s %-7s risk:%2d %s%s\n", - "✓", tc.tool, riskBar, opType, score, reasons, policyTag) } - // End some sessions (leave last one "live") + // Leave postgres session "live" if sc.server != "postgres" { store.EndSession(sessionID) } } - fmt.Println("\n ✅ Done! Refresh the dashboard at http://localhost:8080") + fmt.Println() + fmt.Println(" ✅ Done! Open http://localhost:8080 → click a session → Intent Chains tab") } func m(kvs ...any) map[string]any { diff --git a/configs/example_claude.json b/configs/example_claude.json new file mode 100644 index 0000000..2229437 --- /dev/null +++ b/configs/example_claude.json @@ -0,0 +1,40 @@ +{ + "mcpServers": { + "github": { + "command": "beacon-proxy", + "args": [ + "--server-name", "github", + "--db", "~/.beacon/audit.db", + "--port", "8080", + "--", "npx", "-y", "@modelcontextprotocol/server-github" + ], + "env": { + "GITHUB_TOKEN": "YOUR_GITHUB_TOKEN" + } + }, + "filesystem": { + "command": "beacon-proxy", + "args": [ + "--server-name", "filesystem", + "--db", "~/.beacon/audit.db", + "--port", "8081", + "--", "npx", "-y", "@modelcontextprotocol/server-filesystem", + "/Users/you/projects", + "/Users/you/Documents" + ] + }, + "slack": { + "command": "beacon-proxy", + "args": [ + "--server-name", "slack", + "--db", "~/.beacon/audit.db", + "--port", "8082", + "--", "npx", "-y", "@modelcontextprotocol/server-slack" + ], + "env": { + "SLACK_BOT_TOKEN": "YOUR_SLACK_BOT_TOKEN", + "SLACK_TEAM_ID": "T00000000" + } + } + } +} diff --git a/docs/dashboard-toolcalls.png b/docs/dashboard-toolcalls.png new file mode 100644 index 0000000..cbef8b2 Binary files /dev/null and b/docs/dashboard-toolcalls.png differ diff --git a/docs/dashboard.png b/docs/dashboard.png index 28dcbb2..64582cb 100644 Binary files a/docs/dashboard.png and b/docs/dashboard.png differ diff --git a/internal/audit/queries.go b/internal/audit/queries.go index 0a6ddd9..3231809 100644 --- a/internal/audit/queries.go +++ b/internal/audit/queries.go @@ -157,23 +157,35 @@ type IntentToolEntry struct { SequenceOrder int `json:"sequence_order"` ToolCallID string `json:"tool_call_id"` ToolName string `json:"tool_name"` + Arguments string `json:"arguments"` OperationType string `json:"operation_type"` RiskScore int `json:"risk_score"` PolicyAction string `json:"policy_action"` + DurationMs *int64 `json:"duration_ms"` + ServerName string `json:"server_name"` } // ListIntents returns intent contexts with their tool calls for a session. +// If sessionID is empty, returns intents across all sessions. func (s *Store) ListIntents(sessionID string) ([]IntentSummary, error) { - rows, err := s.db.Query(` + query := ` SELECT ic.id, ic.session_id, ic.created_at, itc.sequence_order, itc.tool_call_id, - tc.tool_name, tc.operation_type, tc.risk_score, tc.policy_action + tc.tool_name, tc.arguments, tc.operation_type, tc.risk_score, + tc.policy_action, tc.duration_ms, se.server_name FROM intent_contexts ic JOIN intent_tool_calls itc ON itc.intent_id = ic.id JOIN tool_calls tc ON tc.id = itc.tool_call_id - WHERE ic.session_id = ? - ORDER BY ic.created_at ASC, itc.sequence_order ASC - `, sessionID) + JOIN sessions se ON se.id = ic.session_id + ` + var args []any + if sessionID != "" { + query += " WHERE ic.session_id = ?" + args = append(args, sessionID) + } + query += " ORDER BY ic.created_at ASC, itc.sequence_order ASC" + + rows, err := s.db.Query(query, args...) if err != nil { return nil, err } @@ -183,18 +195,24 @@ func (s *Store) ListIntents(sessionID string) ([]IntentSummary, error) { var order []string for rows.Next() { - var intentID, sessionID string + var intentID, sid string var createdAt time.Time var entry IntentToolEntry - if err := rows.Scan(&intentID, &sessionID, &createdAt, + var durationMs sql.NullInt64 + if err := rows.Scan(&intentID, &sid, &createdAt, &entry.SequenceOrder, &entry.ToolCallID, - &entry.ToolName, &entry.OperationType, &entry.RiskScore, &entry.PolicyAction); err != nil { + &entry.ToolName, &entry.Arguments, &entry.OperationType, + &entry.RiskScore, &entry.PolicyAction, &durationMs, &entry.ServerName); err != nil { return nil, err } + if durationMs.Valid { + entry.DurationMs = &durationMs.Int64 + } + entry.Arguments = s.decrypt(entry.Arguments) if _, ok := intentMap[intentID]; !ok { intentMap[intentID] = &IntentSummary{ ID: intentID, - SessionID: sessionID, + SessionID: sid, CreatedAt: createdAt, } order = append(order, intentID) diff --git a/internal/audit/store.go b/internal/audit/store.go index 3e09572..b8156e3 100644 --- a/internal/audit/store.go +++ b/internal/audit/store.go @@ -127,11 +127,9 @@ func migrateHashChain(db *sql.DB) error { } type Store struct { - db *sql.DB - mu sync.Mutex - enc *Encryptor - lastHash string // hash chain: hash of the most recent message - sequence int64 // monotonic sequence number for messages + db *sql.DB + mu sync.RWMutex // protects writes; reads use RLock + enc *Encryptor } // Open creates or opens the SQLite database at the given path. @@ -182,17 +180,6 @@ func Open(dbPath string, encryptionKey ...string) (*Store, error) { } store := &Store{db: db, enc: enc} - - // Resume hash chain from the most recent message (supports restarts) - var lastHash sql.NullString - var seq sql.NullInt64 - err = db.QueryRow("SELECT hash, sequence FROM messages ORDER BY sequence DESC LIMIT 1").Scan(&lastHash, &seq) - if err == nil { - store.lastHash = lastHash.String - store.sequence = seq.Int64 - } - // err == sql.ErrNoRows is fine — empty DB, chain starts fresh - return store, nil } @@ -234,6 +221,8 @@ const maxStoredMessageSize = 512 * 1024 // 512KB — truncate raw payloads beyon // LogMessage records a single JSON-RPC message and returns its ID. // Raw payloads exceeding 512KB are truncated to limit DB growth from large responses. // Each message is linked to the previous via a SHA-256 hash chain for tamper detection. +// Sequence and prev_hash are read from the DB within a transaction, so multiple +// processes sharing the same database will produce a correct chain. func (s *Store) LogMessage(sessionID, direction, jsonrpcID, method, raw string) (string, error) { id := uuid.New().String() stored := Redact(raw) @@ -244,20 +233,41 @@ func (s *Store) LogMessage(sessionID, direction, jsonrpcID, method, raw string) s.mu.Lock() defer s.mu.Unlock() - s.sequence++ - prevHash := s.lastHash + tx, err := s.db.Begin() + if err != nil { + return "", fmt.Errorf("begin transaction: %w", err) + } + defer tx.Rollback() + + // Read chain state from DB — authoritative across processes + var prevHash string + var seq int64 + var lastHash sql.NullString + var lastSeq sql.NullInt64 + err = tx.QueryRow("SELECT hash, sequence FROM messages ORDER BY sequence DESC LIMIT 1").Scan(&lastHash, &lastSeq) + if err == nil { + prevHash = lastHash.String + seq = lastSeq.Int64 + } + // sql.ErrNoRows → empty DB, chain starts fresh + seq++ + ts := time.Now().UTC() - hash := computeHash(id, sessionID, direction, jsonrpcID, method, stored, s.sequence, prevHash) - s.lastHash = hash + hash := computeHash(id, sessionID, direction, jsonrpcID, method, stored, seq, prevHash) - _, err := s.db.Exec( + _, err = tx.Exec( "INSERT INTO messages (id, session_id, direction, timestamp, jsonrpc_id, method, raw, sequence, prev_hash, hash) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", id, sessionID, direction, ts, nullIfEmpty(jsonrpcID), nullIfEmpty(method), s.encrypt(stored), - s.sequence, prevHash, hash, + seq, prevHash, hash, ) if err != nil { return "", err } + + if err := tx.Commit(); err != nil { + return "", fmt.Errorf("commit: %w", err) + } + return id, nil } @@ -358,6 +368,7 @@ func (s *Store) CreateIntentContext(sessionID string) (string, error) { // DeleteSessionsBefore removes sessions (and all related data) older than the given time. // Returns the number of sessions deleted. +// After deletion, the hash chain is re-anchored so that VerifyChain continues to work. func (s *Store) DeleteSessionsBefore(cutoff time.Time) (int64, error) { s.mu.Lock() defer s.mu.Unlock() @@ -390,6 +401,12 @@ func (s *Store) DeleteSessionsBefore(cutoff time.Time) (int64, error) { return 0, fmt.Errorf("delete sessions: %w", err) } + // Re-anchor the hash chain: recompute hashes for remaining messages + // so the first message chains from "" and verification still passes. + if err := reanchorChain(tx, s.enc); err != nil { + return 0, fmt.Errorf("re-anchor hash chain: %w", err) + } + if err := tx.Commit(); err != nil { return 0, fmt.Errorf("commit: %w", err) } @@ -397,6 +414,56 @@ func (s *Store) DeleteSessionsBefore(cutoff time.Time) (int64, error) { return result.RowsAffected() } +// reanchorChain recomputes prev_hash and hash for all remaining messages +// so the chain starts from "" after pruning. Must run within a transaction. +func reanchorChain(tx *sql.Tx, enc *Encryptor) error { + rows, err := tx.Query( + "SELECT id, session_id, direction, jsonrpc_id, method, raw, sequence FROM messages ORDER BY sequence ASC", + ) + if err != nil { + return fmt.Errorf("query messages: %w", err) + } + defer rows.Close() + + type msgRow struct { + id, sessionID, direction, raw string + jsonrpcID, method sql.NullString + seq int64 + } + var msgs []msgRow + for rows.Next() { + var m msgRow + if err := rows.Scan(&m.id, &m.sessionID, &m.direction, &m.jsonrpcID, &m.method, &m.raw, &m.seq); err != nil { + return fmt.Errorf("scan row: %w", err) + } + msgs = append(msgs, m) + } + if err := rows.Err(); err != nil { + return err + } + + prevHash := "" + for _, m := range msgs { + // Decrypt raw if encrypted — hash was computed on plaintext + raw := m.raw + if enc != nil { + decrypted, err := enc.Decrypt(raw) + if err != nil { + return fmt.Errorf("decrypt sequence %d: %w", m.seq, err) + } + raw = decrypted + } + + hash := computeHash(m.id, m.sessionID, m.direction, m.jsonrpcID.String, m.method.String, raw, m.seq, prevHash) + if _, err := tx.Exec("UPDATE messages SET prev_hash = ?, hash = ? WHERE id = ?", prevHash, hash, m.id); err != nil { + return fmt.Errorf("update sequence %d: %w", m.seq, err) + } + prevHash = hash + } + + return nil +} + // AddToolCallToIntent links a tool call to an intent context. func (s *Store) AddToolCallToIntent(intentID, toolCallID string, sequenceOrder int) error { s.mu.Lock() @@ -411,10 +478,10 @@ func (s *Store) AddToolCallToIntent(intentID, toolCallID string, sequenceOrder i // ChainStatus holds the result of a hash chain verification. type ChainStatus struct { - Total int // total messages checked - Valid bool // true if the entire chain is intact - BrokenAt int // sequence number where the chain broke (0 if valid) - Error string // description of the break + Total int `json:"total"` // total messages checked + Valid bool `json:"valid"` // true if the entire chain is intact + BrokenAt int `json:"broken_at"` // sequence number where the chain broke (0 if valid) + Error string `json:"error"` // description of the break } // VerifyChain walks the message hash chain and checks for tampering. @@ -422,6 +489,9 @@ type ChainStatus struct { // If encryption is enabled, raw content is stored encrypted — the hash // was computed on pre-encryption content, so this method decrypts before verifying. func (s *Store) VerifyChain() (*ChainStatus, error) { + s.mu.RLock() + defer s.mu.RUnlock() + rows, err := s.db.Query( "SELECT id, session_id, direction, jsonrpc_id, method, raw, sequence, prev_hash, hash FROM messages ORDER BY sequence ASC", ) @@ -474,6 +544,87 @@ func (s *Store) VerifyChain() (*ChainStatus, error) { return status, rows.Err() } +// ChainEntry represents a single message in the hash chain verification result. +type ChainEntry struct { + Sequence int `json:"sequence"` + Direction string `json:"direction"` + Method string `json:"method"` + Hash string `json:"hash"` + PrevHash string `json:"prev_hash"` + Valid bool `json:"valid"` +} + +// VerifyChainDetail walks the hash chain and returns per-message verification results. +func (s *Store) VerifyChainDetail() (*ChainStatus, []ChainEntry, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + rows, err := s.db.Query( + "SELECT id, session_id, direction, jsonrpc_id, method, raw, sequence, prev_hash, hash FROM messages ORDER BY sequence ASC", + ) + if err != nil { + return nil, nil, fmt.Errorf("query messages: %w", err) + } + defer rows.Close() + + status := &ChainStatus{Valid: true} + var entries []ChainEntry + expectedPrevHash := "" + + for rows.Next() { + var id, sessionID, direction, raw, prevHash, hash string + var jsonrpcID, method sql.NullString + var seq int64 + if err := rows.Scan(&id, &sessionID, &direction, &jsonrpcID, &method, &raw, &seq, &prevHash, &hash); err != nil { + return nil, nil, fmt.Errorf("scan row: %w", err) + } + status.Total++ + + entry := ChainEntry{ + Sequence: int(seq), + Direction: direction, + Method: method.String, + Hash: hash, + PrevHash: prevHash, + Valid: true, + } + + decrypted, err := s.decryptVerify(raw) + if err != nil { + entry.Valid = false + status.Valid = false + status.BrokenAt = int(seq) + status.Error = fmt.Sprintf("sequence %d: decryption failed", seq) + entries = append(entries, entry) + return status, entries, nil + } + + if prevHash != expectedPrevHash { + entry.Valid = false + status.Valid = false + status.BrokenAt = int(seq) + status.Error = fmt.Sprintf("sequence %d: prev_hash mismatch", seq) + entries = append(entries, entry) + return status, entries, nil + } + + computed := computeHash(id, sessionID, direction, jsonrpcID.String, method.String, decrypted, seq, prevHash) + if hash != computed { + entry.Valid = false + status.Valid = false + status.BrokenAt = int(seq) + status.Error = fmt.Sprintf("sequence %d: hash mismatch (record modified)", seq) + entries = append(entries, entry) + return status, entries, nil + } + + expectedPrevHash = hash + entries = append(entries, entry) + } + + return status, entries, rows.Err() +} + // decryptVerify decrypts a string if encryption is enabled, returning an error on failure // (unlike decrypt which silently falls back to ciphertext). func (s *Store) decryptVerify(ciphertext string) (string, error) { diff --git a/internal/audit/store_test.go b/internal/audit/store_test.go index 893c4c6..1302080 100644 --- a/internal/audit/store_test.go +++ b/internal/audit/store_test.go @@ -588,6 +588,108 @@ func TestDeleteSessionsBefore(t *testing.T) { } } +func TestDeleteSessionsBefore_ReanchorsHashChain(t *testing.T) { + store, err := Open(tempDB(t)) + if err != nil { + t.Fatalf("Open: %v", err) + } + defer store.Close() + + // Create old session with messages + oldID, err := store.CreateSession("old", "cat") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + _, err = store.db.Exec("UPDATE sessions SET started_at = ? WHERE id = ?", + time.Now().UTC().AddDate(0, 0, -60), oldID) + if err != nil { + t.Fatalf("backdate: %v", err) + } + if _, err := store.LogMessage(oldID, "client_to_server", "1", "init", `{}`); err != nil { + t.Fatalf("LogMessage: %v", err) + } + if _, err := store.LogMessage(oldID, "server_to_client", "1", "", `{}`); err != nil { + t.Fatalf("LogMessage: %v", err) + } + + // Create new session with messages + newID, err := store.CreateSession("new", "cat") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + if _, err := store.LogMessage(newID, "client_to_server", "2", "tools/call", `{}`); err != nil { + t.Fatalf("LogMessage: %v", err) + } + if _, err := store.LogMessage(newID, "server_to_client", "2", "", `{}`); err != nil { + t.Fatalf("LogMessage: %v", err) + } + + // Delete old session + cutoff := time.Now().UTC().AddDate(0, 0, -30) + deleted, err := store.DeleteSessionsBefore(cutoff) + if err != nil { + t.Fatalf("DeleteSessionsBefore: %v", err) + } + if deleted != 1 { + t.Errorf("deleted = %d, want 1", deleted) + } + + // Hash chain should still verify after pruning + status, err := store.VerifyChain() + if err != nil { + t.Fatalf("VerifyChain: %v", err) + } + if !status.Valid { + t.Errorf("chain should be valid after retention pruning, got error at sequence %d: %s", + status.BrokenAt, status.Error) + } + if status.Total != 2 { + t.Errorf("total = %d, want 2 (remaining messages)", status.Total) + } +} + +func TestDeleteSessionsBefore_ReanchorsWithEncryption(t *testing.T) { + store, err := Open(tempDB(t), "test-key") + if err != nil { + t.Fatalf("Open: %v", err) + } + defer store.Close() + + oldID, err := store.CreateSession("old", "cat") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + _, err = store.db.Exec("UPDATE sessions SET started_at = ? WHERE id = ?", + time.Now().UTC().AddDate(0, 0, -60), oldID) + if err != nil { + t.Fatalf("backdate: %v", err) + } + if _, err := store.LogMessage(oldID, "client_to_server", "1", "init", `{"secret":"data"}`); err != nil { + t.Fatalf("LogMessage: %v", err) + } + + newID, err := store.CreateSession("new", "cat") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + if _, err := store.LogMessage(newID, "client_to_server", "2", "tools/call", `{"hello":"world"}`); err != nil { + t.Fatalf("LogMessage: %v", err) + } + + cutoff := time.Now().UTC().AddDate(0, 0, -30) + if _, err := store.DeleteSessionsBefore(cutoff); err != nil { + t.Fatalf("DeleteSessionsBefore: %v", err) + } + + status, err := store.VerifyChain() + if err != nil { + t.Fatalf("VerifyChain: %v", err) + } + if !status.Valid { + t.Errorf("chain should be valid after encrypted retention pruning: %s", status.Error) + } +} + func TestDeleteSessionsBefore_NothingToDelete(t *testing.T) { store, err := Open(tempDB(t)) if err != nil { diff --git a/internal/proxy/jsonrpc_test.go b/internal/proxy/jsonrpc_test.go index 10722e4..a3d22f9 100644 --- a/internal/proxy/jsonrpc_test.go +++ b/internal/proxy/jsonrpc_test.go @@ -1,6 +1,7 @@ package proxy import ( + "encoding/json" "testing" ) @@ -117,6 +118,79 @@ func TestMessage_IsResponse(t *testing.T) { } } +func TestBlockErrorJSON_StringID(t *testing.T) { + // Verify that using raw ID bytes produces valid JSON for string IDs + line := []byte(`{"jsonrpc":"2.0","id":"request-42","method":"tools/call","params":{"name":"delete_file"}}`) + msg := ParseMessage(line) + if msg == nil { + t.Fatal("expected message") + } + + // Simulate what sendBlockError does with the raw ID + errResp := map[string]any{ + "jsonrpc": "2.0", + "id": msg.ID, // json.RawMessage preserves the original bytes + "error": map[string]any{ + "code": -32000, + "message": "blocked", + }, + } + data, err := json.Marshal(errResp) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + // Verify it's valid JSON by round-tripping + var parsed map[string]any + if err := json.Unmarshal(data, &parsed); err != nil { + t.Fatalf("error response is invalid JSON: %v\ndata: %s", err, data) + } + + // Verify the ID is preserved as a string + id, ok := parsed["id"].(string) + if !ok { + t.Fatalf("id should be a string, got %T: %v", parsed["id"], parsed["id"]) + } + if id != "request-42" { + t.Errorf("id = %q, want %q", id, "request-42") + } +} + +func TestBlockErrorJSON_NumericID(t *testing.T) { + line := []byte(`{"jsonrpc":"2.0","id":5,"method":"tools/call","params":{"name":"delete_file"}}`) + msg := ParseMessage(line) + if msg == nil { + t.Fatal("expected message") + } + + errResp := map[string]any{ + "jsonrpc": "2.0", + "id": msg.ID, + "error": map[string]any{ + "code": -32000, + "message": "blocked", + }, + } + data, err := json.Marshal(errResp) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var parsed map[string]any + if err := json.Unmarshal(data, &parsed); err != nil { + t.Fatalf("error response is invalid JSON: %v\ndata: %s", err, data) + } + + // JSON numbers unmarshal as float64 + id, ok := parsed["id"].(float64) + if !ok { + t.Fatalf("id should be a number, got %T: %v", parsed["id"], parsed["id"]) + } + if id != 5 { + t.Errorf("id = %v, want 5", id) + } +} + func TestParseToolCallParams(t *testing.T) { tests := []struct { name string diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index eec38cb..25972f7 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -18,11 +18,12 @@ import ( "github.com/ottojongerius/beacon/internal/web" ) -const maxMessageSize = 1 << 20 // 1MB +const maxMessageSize = 10 << 20 // 10MB — large enough for big file contents const pauseTimeout = 60 * time.Second type pendingCall struct { toolCallID string + rawID json.RawMessage // original JSON-RPC ID bytes for constructing responses msgID string requestedAt time.Time } @@ -36,10 +37,16 @@ type Proxy struct { Intent *audit.IntentTracker Dashboard *web.Server - pendingCalls map[string]*pendingCall // keyed by JSON-RPC id + pendingCalls map[string]*pendingCall // keyed by JSON-RPC id string pendingMu sync.Mutex - clientWriter io.Writer // for sending block/deny errors back to client - clientMu sync.Mutex + + clientWriter io.Writer // for sending messages to the MCP client + clientMu sync.Mutex // serializes all writes to clientWriter + + serverWriter io.Writer // for sending messages to the MCP server + serverMu sync.Mutex // serializes all writes to serverWriter + + pauseWg sync.WaitGroup // tracks in-flight pause goroutines } func (p *Proxy) Run(ctx context.Context) error { @@ -70,6 +77,7 @@ func (p *Proxy) Run(ctx context.Context) error { stdout := bufio.NewWriter(os.Stdout) p.clientWriter = stdout + p.serverWriter = serverIn var wg sync.WaitGroup wg.Add(2) @@ -77,14 +85,17 @@ func (p *Proxy) Run(ctx context.Context) error { // Client → Server go func() { defer wg.Done() - defer serverIn.Close() - p.pipe(os.Stdin, serverIn, "client_to_server", sessionID) + defer func() { + p.pauseWg.Wait() // wait for pending pause goroutines before closing + serverIn.Close() + }() + p.pipeClientToServer(os.Stdin, "client_to_server", sessionID) }() // Server → Client go func() { defer wg.Done() - p.pipe(serverOut, stdout, "server_to_client", sessionID) + p.pipeServerToClient(serverOut, "server_to_client", sessionID) }() wg.Wait() @@ -92,7 +103,9 @@ func (p *Proxy) Run(ctx context.Context) error { return cmd.Wait() } -func (p *Proxy) pipe(src io.Reader, dst io.Writer, direction, sessionID string) { +// pipeClientToServer reads from the MCP client, applies policy, and forwards to the server. +// Pause decisions are handled asynchronously to avoid blocking other messages. +func (p *Proxy) pipeClientToServer(src io.Reader, direction, sessionID string) { scanner := bufio.NewScanner(src) scanner.Buffer(make([]byte, 0, maxMessageSize), maxMessageSize) @@ -109,48 +122,110 @@ func (p *Proxy) pipe(src io.Reader, dst io.Writer, direction, sessionID string) msgID, err := p.Store.LogMessage(sessionID, direction, jsonrpcID, method, raw) if err != nil { - log.Printf("beacon: failed to log message") + log.Printf("beacon: failed to log message: %v", err) } // Policy + pairing for tool calls - if msg != nil { - if direction == "client_to_server" && method == "tools/call" { - action := p.handleToolCallRequest(sessionID, msgID, jsonrpcID, msg) - if action == "block" || action == "deny" { - // Don't forward to server — error already sent to client - continue - } - } else if direction == "server_to_client" && msg.IsResponse() { - p.handleToolCallResponse(msgID, jsonrpcID, msg) + if msg != nil && method == "tools/call" { + action := p.handleToolCallRequest(sessionID, msgID, jsonrpcID, msg, line) + switch action { + case "block", "deny": + // Don't forward — error already sent to client + continue + case "pause": + // Message will be forwarded asynchronously if approved + continue } } - if _, err := dst.Write(line); err != nil { - log.Printf("beacon: write error (%s)", direction) - return + // Forward to server + p.writeToServer(line) + } + + if err := scanner.Err(); err != nil { + log.Printf("beacon: scan error (%s): %v", direction, err) + } +} + +// pipeServerToClient reads from the MCP server and forwards to the client. +// All writes go through clientMu to prevent interleaving with block/deny errors. +func (p *Proxy) pipeServerToClient(src io.Reader, direction, sessionID string) { + scanner := bufio.NewScanner(src) + scanner.Buffer(make([]byte, 0, maxMessageSize), maxMessageSize) + + for scanner.Scan() { + line := scanner.Bytes() + raw := string(line) + + var jsonrpcID, method string + var msg *Message + if msg = ParseMessage(line); msg != nil { + jsonrpcID = msg.IDString() + method = msg.Method } - if _, err := dst.Write([]byte("\n")); err != nil { - log.Printf("beacon: write error (%s)", direction) - return + + if _, err := p.Store.LogMessage(sessionID, direction, jsonrpcID, method, raw); err != nil { + log.Printf("beacon: failed to log message: %v", err) } - // Flush after every message for low latency - if f, ok := dst.(interface{ Flush() error }); ok { - if err := f.Flush(); err != nil { - log.Printf("beacon: flush error (%s)", direction) - return - } + // Pair responses with pending tool calls + if msg != nil && msg.IsResponse() { + p.handleToolCallResponse(jsonrpcID, msg) } + + // Forward to client (under lock to prevent interleaving with sendBlockError) + p.writeToClient(line) } if err := scanner.Err(); err != nil { - log.Printf("beacon: scan error (%s)", direction) + log.Printf("beacon: scan error (%s): %v", direction, err) + } +} + +// writeToClient sends a message line to the MCP client, serialized with clientMu. +func (p *Proxy) writeToClient(line []byte) { + p.clientMu.Lock() + defer p.clientMu.Unlock() + + if _, err := p.clientWriter.Write(line); err != nil { + log.Printf("beacon: write error (to client): %v", err) + return + } + if _, err := p.clientWriter.Write([]byte("\n")); err != nil { + log.Printf("beacon: write error (to client): %v", err) + return + } + if f, ok := p.clientWriter.(interface{ Flush() error }); ok { + if err := f.Flush(); err != nil { + log.Printf("beacon: flush error (to client): %v", err) + } + } +} + +// writeToServer sends a message line to the MCP server, serialized with serverMu. +func (p *Proxy) writeToServer(line []byte) { + p.serverMu.Lock() + defer p.serverMu.Unlock() + + if _, err := p.serverWriter.Write(line); err != nil { + log.Printf("beacon: write error (to server): %v", err) + return + } + if _, err := p.serverWriter.Write([]byte("\n")); err != nil { + log.Printf("beacon: write error (to server): %v", err) + return + } + if f, ok := p.serverWriter.(interface{ Flush() error }); ok { + if err := f.Flush(); err != nil { + log.Printf("beacon: flush error (to server): %v", err) + } } } // handleToolCallRequest classifies, scores, evaluates policy, and returns the action taken. -// Returns "pass", "flag", "block", or "deny" (pause that was denied/timed out). -func (p *Proxy) handleToolCallRequest(sessionID, msgID, jsonrpcID string, msg *Message) string { +// Returns "pass", "flag", "block", "deny", or "pause". +// For "pause", the message is forwarded asynchronously after approval. +func (p *Proxy) handleToolCallRequest(sessionID, msgID, jsonrpcID string, msg *Message, rawLine []byte) string { params := ParseToolCallParams(msg.Params) if params == nil { return "pass" @@ -179,7 +254,7 @@ func (p *Proxy) handleToolCallRequest(sessionID, msgID, jsonrpcID string, msg *M } if err := p.Store.CreateToolCall(tc); err != nil { - log.Printf("beacon: failed to create tool call") + log.Printf("beacon: failed to create tool call: %v", err) return "pass" } @@ -200,9 +275,14 @@ func (p *Proxy) handleToolCallRequest(sessionID, msgID, jsonrpcID string, msg *M p.Intent.Track(sessionID, tcID, now) } + // Store raw ID for constructing JSON-RPC responses + rawID := make(json.RawMessage, len(msg.ID)) + copy(rawID, msg.ID) + p.pendingMu.Lock() p.pendingCalls[jsonrpcID] = &pendingCall{ toolCallID: tcID, + rawID: rawID, msgID: msgID, requestedAt: now, } @@ -227,23 +307,39 @@ func (p *Proxy) handleToolCallRequest(sessionID, msgID, jsonrpcID string, msg *M if decision.Rule != nil { reason = decision.Rule.Name } - p.sendBlockError(jsonrpcID, fmt.Sprintf("Blocked by policy: %s", reason)) + p.sendBlockError(rawID, fmt.Sprintf("Blocked by policy: %s", reason)) return "block" case "pause": p.Store.UpdateToolCallPolicy(tcID, "pause", nil) log.Printf("beacon: tool call %s (%s) paused, waiting for approval", tcID, params.Name) - result := p.Policy.WaitForApproval(tcID, pauseTimeout) - if result.Approved { - p.Store.UpdateToolCallPolicy(tcID, "pass", &result.ApprovedBy) - return "pass" - } - reason := "denied" + + // Copy the message line for async forwarding + lineCopy := make([]byte, len(rawLine)) + copy(lineCopy, rawLine) + + ruleName := "" if decision.Rule != nil { - reason = decision.Rule.Name + " (denied/timeout)" + ruleName = decision.Rule.Name } - p.sendBlockError(jsonrpcID, fmt.Sprintf("Denied by policy: %s", reason)) - return "deny" + + // Handle approval asynchronously — don't block the pipe + p.pauseWg.Add(1) + go func() { + defer p.pauseWg.Done() + result := p.Policy.WaitForApproval(tcID, pauseTimeout) + if result.Approved { + p.Store.UpdateToolCallPolicy(tcID, "pass", &result.ApprovedBy) + p.writeToServer(lineCopy) + } else { + reason := "denied" + if ruleName != "" { + reason = ruleName + " (denied/timeout)" + } + p.sendBlockError(rawID, fmt.Sprintf("Denied by policy: %s", reason)) + } + }() + return "pause" case "flag": p.Store.UpdateToolCallPolicy(tcID, "flag", nil) @@ -254,7 +350,7 @@ func (p *Proxy) handleToolCallRequest(sessionID, msgID, jsonrpcID string, msg *M } } -func (p *Proxy) handleToolCallResponse(msgID, jsonrpcID string, msg *Message) { +func (p *Proxy) handleToolCallResponse(jsonrpcID string, msg *Message) { p.pendingMu.Lock() pending, ok := p.pendingCalls[jsonrpcID] if ok { @@ -279,16 +375,17 @@ func (p *Proxy) handleToolCallResponse(msgID, jsonrpcID string, msg *Message) { errJSON = &s } - if err := p.Store.CompleteToolCall(pending.toolCallID, msgID, result, errJSON, now, durationMs); err != nil { - log.Printf("beacon: failed to complete tool call") + if err := p.Store.CompleteToolCall(pending.toolCallID, pending.msgID, result, errJSON, now, durationMs); err != nil { + log.Printf("beacon: failed to complete tool call: %v", err) } } // sendBlockError writes a JSON-RPC error response back to the client. -func (p *Proxy) sendBlockError(jsonrpcID, message string) { +// Uses the raw JSON-RPC ID bytes to produce valid JSON for both string and numeric IDs. +func (p *Proxy) sendBlockError(rawID json.RawMessage, message string) { errResp := map[string]any{ "jsonrpc": "2.0", - "id": json.RawMessage(jsonrpcID), + "id": rawID, "error": map[string]any{ "code": -32000, "message": message, @@ -297,24 +394,11 @@ func (p *Proxy) sendBlockError(jsonrpcID, message string) { data, err := json.Marshal(errResp) if err != nil { - log.Printf("beacon: failed to marshal block error") + log.Printf("beacon: failed to marshal block error: %v", err) return } - p.clientMu.Lock() - defer p.clientMu.Unlock() - - if _, err := p.clientWriter.Write(data); err != nil { - log.Printf("beacon: failed to send block error") - return - } - if _, err := p.clientWriter.Write([]byte("\n")); err != nil { - log.Printf("beacon: failed to send block error newline") - return - } - if f, ok := p.clientWriter.(interface{ Flush() error }); ok { - f.Flush() - } + p.writeToClient(data) } func (p *Proxy) fullCommand() string { diff --git a/internal/web/server.go b/internal/web/server.go index aced385..d173972 100644 --- a/internal/web/server.go +++ b/internal/web/server.go @@ -19,17 +19,21 @@ var staticFiles embed.FS // Server is the dashboard web server. type Server struct { - store *audit.Store - engine *policy.Engine - hub *wsHub + store *audit.Store + engine *policy.Engine + hub *wsHub + authToken string // required for approve/deny endpoints; empty disables auth } // NewServer creates a new dashboard server. -func NewServer(store *audit.Store, engine *policy.Engine) *Server { +// If authToken is non-empty, approve/deny endpoints require it via +// X-Beacon-Token header or ?token= query parameter. +func NewServer(store *audit.Store, engine *policy.Engine, authToken string) *Server { return &Server{ - store: store, - engine: engine, - hub: newHub(), + store: store, + engine: engine, + hub: newHub(), + authToken: authToken, } } @@ -41,8 +45,10 @@ func (s *Server) Handler() http.Handler { mux.HandleFunc("GET /api/stats", s.handleStats) mux.HandleFunc("GET /api/sessions", s.handleSessions) mux.HandleFunc("GET /api/tool-calls", s.handleToolCalls) + mux.HandleFunc("GET /api/intents", s.handleAllIntents) mux.HandleFunc("GET /api/sessions/{id}/intents", s.handleIntents) mux.HandleFunc("GET /api/chain/verify", s.handleVerifyChain) + mux.HandleFunc("GET /api/chain/detail", s.handleVerifyChainDetail) mux.HandleFunc("GET /api/tool-calls/pending", s.handlePending) mux.HandleFunc("POST /api/tool-calls/{id}/approve", s.handleApprove) mux.HandleFunc("POST /api/tool-calls/{id}/deny", s.handleDeny) @@ -102,6 +108,18 @@ func (s *Server) handleToolCalls(w http.ResponseWriter, r *http.Request) { writeJSON(w, calls) } +func (s *Server) handleAllIntents(w http.ResponseWriter, r *http.Request) { + intents, err := s.store.ListIntents("") + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if intents == nil { + intents = []audit.IntentSummary{} + } + writeJSON(w, intents) +} + func (s *Server) handleIntents(w http.ResponseWriter, r *http.Request) { sessionID := r.PathValue("id") intents, err := s.store.ListIntents(sessionID) @@ -124,12 +142,27 @@ func (s *Server) handleVerifyChain(w http.ResponseWriter, r *http.Request) { writeJSON(w, status) } +func (s *Server) handleVerifyChainDetail(w http.ResponseWriter, r *http.Request) { + status, entries, err := s.store.VerifyChainDetail() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, map[string]any{ + "status": status, + "entries": entries, + }) +} + func (s *Server) handlePending(w http.ResponseWriter, r *http.Request) { ids := s.engine.PendingApprovals() writeJSON(w, map[string]any{"pending": ids}) } func (s *Server) handleApprove(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(w, r) { + return + } id := r.PathValue("id") if err := s.engine.Approve(id, "dashboard-user"); err != nil { http.Error(w, err.Error(), http.StatusNotFound) @@ -139,6 +172,9 @@ func (s *Server) handleApprove(w http.ResponseWriter, r *http.Request) { } func (s *Server) handleDeny(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(w, r) { + return + } id := r.PathValue("id") if err := s.engine.Deny(id); err != nil { http.Error(w, err.Error(), http.StatusNotFound) @@ -147,6 +183,24 @@ func (s *Server) handleDeny(w http.ResponseWriter, r *http.Request) { writeJSON(w, map[string]string{"status": "denied"}) } +// checkAuth validates the request has a valid auth token. +// Returns true if authorized. Writes a 401 response and returns false otherwise. +// If no authToken is configured, all requests are allowed. +func (s *Server) checkAuth(w http.ResponseWriter, r *http.Request) bool { + if s.authToken == "" { + return true // auth disabled + } + token := r.Header.Get("X-Beacon-Token") + if token == "" { + token = r.URL.Query().Get("token") + } + if token != s.authToken { + http.Error(w, "unauthorized: provide token via X-Beacon-Token header or ?token= query param", http.StatusUnauthorized) + return false + } + return true +} + func writeJSON(w http.ResponseWriter, v any) { w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(v) diff --git a/internal/web/server_test.go b/internal/web/server_test.go index 2f99fad..78dc053 100644 --- a/internal/web/server_test.go +++ b/internal/web/server_test.go @@ -3,6 +3,7 @@ package web import ( "context" "encoding/json" + "net/http" "net/http/httptest" "os" "path/filepath" @@ -23,7 +24,7 @@ func setupTestServer(t *testing.T) (*Server, *audit.Store) { t.Cleanup(func() { store.Close() }) engine := policy.NewEngine(policy.DefaultRules()) - srv := NewServer(store, engine) + srv := NewServer(store, engine, "test-token") return srv, store } @@ -175,30 +176,40 @@ func TestDashboardServed(t *testing.T) { func TestSSEEndpoint(t *testing.T) { srv, _ := setupTestServer(t) - handler := srv.Handler() - // Use a cancellable context so we can cleanly stop the SSE handler - ctx, cancel := context.WithCancel(context.Background()) + // Use a real httptest.Server to avoid data race on ResponseRecorder — + // the SSE handler writes asynchronously and ResponseRecorder is not thread-safe. + ts := httptest.NewServer(srv.Handler()) + defer ts.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - req := httptest.NewRequest("GET", "/ws/live", nil).WithContext(ctx) - w := httptest.NewRecorder() + req, _ := http.NewRequestWithContext(ctx, "GET", ts.URL+"/ws/live", nil) + resp, err := http.DefaultClient.Do(req) + if err != nil { + // Context cancellation is expected for SSE + if ctx.Err() == nil { + t.Fatalf("SSE request failed: %v", err) + } + return + } + defer resp.Body.Close() - // Run in goroutine since SSE blocks - done := make(chan struct{}) - go func() { - handler.ServeHTTP(w, req) - close(done) - }() + if resp.StatusCode != 200 { + t.Errorf("SSE endpoint returned %d, want 200", resp.StatusCode) + } - // Cancel the context to stop the handler, then wait for it to finish - // before reading the response (avoids data race on ResponseRecorder) - time.Sleep(50 * time.Millisecond) - cancel() - <-done + ct := resp.Header.Get("Content-Type") + if ct != "text/event-stream" { + t.Errorf("Content-Type = %q, want text/event-stream", ct) + } - if w.Code != 200 { - t.Errorf("SSE endpoint returned %d, want 200", w.Code) + // Read at least one SSE event + buf := make([]byte, 4096) + n, _ := resp.Body.Read(buf) + if n == 0 { + t.Error("expected at least one SSE event") } } @@ -224,5 +235,77 @@ func TestBroadcast(t *testing.T) { } } +func TestApproveRequiresAuth(t *testing.T) { + srv, _ := setupTestServer(t) + handler := srv.Handler() + + // No token — should get 401 + req := httptest.NewRequest("POST", "/api/tool-calls/tc-1/approve", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + if w.Code != http.StatusUnauthorized { + t.Errorf("no token: status = %d, want 401", w.Code) + } + + // Wrong token — should get 401 + req = httptest.NewRequest("POST", "/api/tool-calls/tc-1/approve", nil) + req.Header.Set("X-Beacon-Token", "wrong-token") + w = httptest.NewRecorder() + handler.ServeHTTP(w, req) + if w.Code != http.StatusUnauthorized { + t.Errorf("wrong token: status = %d, want 401", w.Code) + } + + // Correct token via header — should pass auth (404 because no pending approval) + req = httptest.NewRequest("POST", "/api/tool-calls/tc-1/approve", nil) + req.Header.Set("X-Beacon-Token", "test-token") + w = httptest.NewRecorder() + handler.ServeHTTP(w, req) + if w.Code == http.StatusUnauthorized { + t.Error("correct token should not get 401") + } + + // Correct token via query param + req = httptest.NewRequest("POST", "/api/tool-calls/tc-1/approve?token=test-token", nil) + w = httptest.NewRecorder() + handler.ServeHTTP(w, req) + if w.Code == http.StatusUnauthorized { + t.Error("correct query token should not get 401") + } +} + +func TestDenyRequiresAuth(t *testing.T) { + srv, _ := setupTestServer(t) + handler := srv.Handler() + + req := httptest.NewRequest("POST", "/api/tool-calls/tc-1/deny", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + if w.Code != http.StatusUnauthorized { + t.Errorf("no token: status = %d, want 401", w.Code) + } +} + +func TestAuthDisabledWhenEmpty(t *testing.T) { + db := filepath.Join(t.TempDir(), "test.db") + store, err := audit.Open(db) + if err != nil { + t.Fatalf("Open: %v", err) + } + defer store.Close() + + engine := policy.NewEngine(policy.DefaultRules()) + srv := NewServer(store, engine, "") // empty token = auth disabled + handler := srv.Handler() + + // Should pass without token (404 because no pending approval, but not 401) + req := httptest.NewRequest("POST", "/api/tool-calls/tc-1/approve", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + if w.Code == http.StatusUnauthorized { + t.Error("empty auth token should disable auth") + } +} + // Ensure unused import doesn't cause issues var _ = os.DevNull diff --git a/internal/web/static/index.html b/internal/web/static/index.html index dfd89be..53c09ec 100644 --- a/internal/web/static/index.html +++ b/internal/web/static/index.html @@ -318,16 +318,31 @@ background: var(--surface); border: 1px solid var(--border); border-radius: 8px; - padding: 16px; + padding: 16px 20px; + margin-bottom: 16px; + } + + .intent-group.has-risk { border-left: 3px solid var(--yellow); } + .intent-group.has-danger { border-left: 3px solid var(--red); } + + .intent-header { + display: flex; + justify-content: space-between; + align-items: flex-start; margin-bottom: 12px; } - .intent-group .intent-header { - font-size: 12px; + .intent-label { + font-size: 14px; + font-weight: 600; + } + + .intent-meta { + font-size: 11px; color: var(--text-dim); - margin-bottom: 10px; + margin-top: 2px; display: flex; - justify-content: space-between; + gap: 12px; } .intent-chain { @@ -339,16 +354,16 @@ .intent-step { display: flex; align-items: center; - gap: 12px; - padding: 6px 0; + gap: 10px; + padding: 7px 0; position: relative; - padding-left: 24px; + padding-left: 28px; } .intent-step::before { content: ''; position: absolute; - left: 7px; + left: 8px; top: 0; bottom: 0; width: 2px; @@ -362,20 +377,58 @@ .intent-step::after { content: ''; position: absolute; - left: 3px; + left: 4px; top: 50%; transform: translateY(-50%); width: 10px; height: 10px; border-radius: 50%; background: var(--border); - border: 2px solid var(--bg); + border: 2px solid var(--surface); z-index: 1; } + .intent-step.step-risk-low::after { background: var(--green); } + .intent-step.step-risk-med::after { background: var(--yellow); } + .intent-step.step-risk-high::after { background: var(--orange); } + .intent-step.step-risk-crit::after { background: var(--red); } + .intent-step .step-name { font-family: 'SF Mono', 'Fira Code', monospace; font-size: 13px; + min-width: 160px; + } + + .step-args { + font-size: 11px; + color: var(--text-dim); + max-width: 400px; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + } + + .step-duration { + font-size: 11px; + color: var(--text-dim); + font-family: 'SF Mono', 'Fira Code', monospace; + margin-left: auto; + } + + .intent-summary { + margin-top: 10px; + padding-top: 10px; + border-top: 1px solid var(--border); + font-size: 12px; + color: var(--text-dim); + display: flex; + gap: 16px; + } + + .intent-summary .summary-item { + display: flex; + align-items: center; + gap: 4px; } /* Chain verification */ @@ -394,6 +447,94 @@ .chain-status .title { font-size: 16px; font-weight: 600; margin-bottom: 4px; } .chain-status .detail { font-size: 13px; color: var(--text-dim); } + .chain-entries { + margin-top: 16px; + display: flex; + flex-direction: column; + gap: 0; + } + + .chain-entry { + display: grid; + grid-template-columns: 32px 60px 130px 1fr; + align-items: center; + gap: 10px; + padding: 8px 12px; + background: var(--surface); + border: 1px solid var(--border); + border-bottom: none; + font-size: 12px; + font-family: 'SF Mono', 'Fira Code', monospace; + } + + .chain-entry:first-child { border-radius: 8px 8px 0 0; } + .chain-entry:last-child { border-radius: 0 0 8px 8px; border-bottom: 1px solid var(--border); } + .chain-entry:only-child { border-radius: 8px; border-bottom: 1px solid var(--border); } + + .chain-entry.valid { border-left: 3px solid var(--green); } + .chain-entry.invalid { border-left: 3px solid var(--red); background: rgba(239,68,68,0.05); } + + .chain-entry .seq { + color: var(--text-dim); + font-size: 11px; + text-align: center; + } + + .chain-entry .direction { + font-size: 11px; + padding: 2px 6px; + border-radius: 4px; + text-align: center; + } + + .chain-entry .direction.req { background: rgba(59,130,246,0.15); color: #60a5fa; } + .chain-entry .direction.res { background: rgba(16,185,129,0.15); color: #34d399; } + + .chain-entry .method-name { + color: var(--text); + font-weight: 500; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + } + + .chain-entry .hash-value { + color: var(--text-dim); + font-size: 11px; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + } + + .chain-link { + display: flex; + align-items: center; + justify-content: center; + height: 20px; + color: var(--text-dim); + font-size: 10px; + opacity: 0.5; + } + + .chain-link.broken { color: var(--red); opacity: 1; } + + .chain-summary { + display: flex; + gap: 16px; + justify-content: center; + margin-bottom: 16px; + } + + .chain-summary .stat { + display: flex; + align-items: center; + gap: 6px; + font-size: 13px; + color: var(--text-dim); + } + + .chain-summary .stat .num { color: var(--text); font-weight: 600; } + /* Live feed */ .live-indicator { display: inline-flex; @@ -521,8 +662,8 @@

No tool calls yet

🔗
-

Select a session

-

Intent chains group tool calls by temporal proximity.

+

Intent Chains

+

Tool calls grouped by temporal proximity — showing what the agent did and why.

@@ -556,7 +697,7 @@

Select a session

document.getElementById('tab-tool-calls').style.display = target === 'tool-calls' ? '' : 'none'; document.getElementById('tab-intents').style.display = target === 'intents' ? '' : 'none'; document.getElementById('tab-chain').style.display = target === 'chain' ? '' : 'none'; - if (target === 'intents' && currentSession) loadIntents(currentSession); + if (target === 'intents') loadIntents(currentSession); }); }); @@ -640,7 +781,7 @@

Select a session

currentSession = id; loadSessions(); loadToolCalls(); - if (document.querySelector('.tab[data-tab="intents"]').classList.contains('active') && id) { + if (document.querySelector('.tab[data-tab="intents"]').classList.contains('active')) { loadIntents(id); } } @@ -722,66 +863,196 @@

Select a session

// --- Intents --- async function loadIntents(sessionId) { - if (!sessionId) return; try { - const res = await fetch(API + `/api/sessions/${sessionId}/intents`); + let url; + if (sessionId) { + url = API + `/api/sessions/${sessionId}/intents`; + } else { + url = API + '/api/intents'; + } + const res = await fetch(url); const intents = await res.json(); const container = document.getElementById('intents-container'); if (!intents.length) { container.innerHTML = `
🔗
-

No intent chains

+

No intent chains yet

Intent chains will appear when tool calls are grouped by temporal proximity.

`; return; } - container.innerHTML = intents.map((intent, i) => ` -
+ container.innerHTML = intents.map((intent, i) => { + const maxRisk = Math.max(...intent.tool_calls.map(tc => tc.risk_score)); + const hasBlock = intent.tool_calls.some(tc => tc.policy_action === 'block'); + const hasPause = intent.tool_calls.some(tc => tc.policy_action === 'pause'); + const hasFlag = intent.tool_calls.some(tc => tc.policy_action === 'flag'); + const dangerClass = hasBlock ? 'has-danger' : (hasPause || hasFlag) ? 'has-risk' : ''; + const serverName = intent.tool_calls[0]?.server_name || ''; + const totalDuration = intent.tool_calls.reduce((sum, tc) => sum + (tc.duration_ms || 0), 0); + const reads = intent.tool_calls.filter(tc => tc.operation_type === 'read').length; + const writes = intent.tool_calls.filter(tc => tc.operation_type === 'write').length; + const deletes = intent.tool_calls.filter(tc => tc.operation_type === 'delete').length; + const execs = intent.tool_calls.filter(tc => tc.operation_type === 'execute').length; + + // Build a narrative summary + const narrative = buildNarrative(intent.tool_calls); + + return ` +
- Intent #${i + 1} — ${intent.tool_calls.length} action(s) - ${formatTime(intent.created_at)} +
+
Intent #${i + 1} — ${esc(narrative)}
+
+ ${esc(serverName)} + ${intent.tool_calls.length} action(s) + ${totalDuration}ms total + ${formatTime(intent.created_at)} +
+
+ peak ${maxRisk}
- ${intent.tool_calls.map(tc => ` -
+ ${intent.tool_calls.map(tc => { + const stepClass = riskStepClass(tc.risk_score); + const argsSummary = summarizeArgs(tc.arguments); + return ` +
${esc(tc.tool_name)} ${tc.operation_type} ${tc.risk_score} - ${tc.policy_action} -
- `).join('')} + ${tc.policy_action !== 'pass' ? `${tc.policy_action}` : ''} + ${esc(argsSummary)} + ${tc.duration_ms != null ? tc.duration_ms + 'ms' : ''} +
`; + }).join('')}
-
- `).join(''); +
+ ${reads ? `read ×${reads}` : ''} + ${writes ? `write ×${writes}` : ''} + ${deletes ? `delete ×${deletes}` : ''} + ${execs ? `execute ×${execs}` : ''} +
+
`; + }).join(''); } catch(e) { console.error('Intents load failed:', e); } } +function buildNarrative(toolCalls) { + const ops = toolCalls.map(tc => tc.operation_type); + const uniqueOps = [...new Set(ops)]; + + if (uniqueOps.length === 1 && uniqueOps[0] === 'read') { + return 'Read-only exploration'; + } + if (ops.includes('delete') || ops.includes('execute')) { + const dangerTools = toolCalls.filter(tc => tc.operation_type === 'delete' || tc.operation_type === 'execute'); + return `${ops.filter(o => o === 'read').length} reads → ${dangerTools.map(tc => tc.tool_name).join(', ')}`; + } + if (ops.includes('write')) { + const writeTools = toolCalls.filter(tc => tc.operation_type === 'write'); + return `${ops.filter(o => o === 'read').length} reads → ${writeTools.map(tc => tc.tool_name).join(', ')}`; + } + return `${toolCalls.length} actions`; +} + +function summarizeArgs(argsStr) { + try { + const args = JSON.parse(argsStr); + const parts = []; + for (const [k, v] of Object.entries(args)) { + const val = typeof v === 'string' ? v : JSON.stringify(v); + if (val.length > 60) { + parts.push(k + '=' + val.substring(0, 57) + '...'); + } else { + parts.push(k + '=' + val); + } + } + return parts.join(' '); + } catch(e) { return ''; } +} + +function riskStepClass(score) { + if (score <= 30) return 'step-risk-low'; + if (score <= 60) return 'step-risk-med'; + if (score <= 80) return 'step-risk-high'; + return 'step-risk-crit'; +} + // --- Chain verification --- async function verifyChain() { const container = document.getElementById('chain-container'); container.innerHTML = `
-
Verifying...
+
Verifying hash chain...
+
Recomputing SHA-256 hashes for every message
`; try { - const res = await fetch(API + '/api/chain/verify'); - const status = await res.json(); + const res = await fetch(API + '/api/chain/detail'); + const data = await res.json(); + const status = data.status; + const entries = data.entries || []; + + let html = ''; + + // Summary banner if (status.valid) { - container.innerHTML = `
-
+ html += `
+
🔒
Hash Chain Intact
${status.total} message(s) verified — no tampering detected
`; } else { - container.innerHTML = `
-
+ html += `
+
🔓
Chain Broken
Tamper detected at sequence ${status.broken_at}: ${esc(status.error)}
`; } + + // Stats + const reqCount = entries.filter(e => e.direction === 'client_to_server').length; + const resCount = entries.filter(e => e.direction === 'server_to_client').length; + html += `
+
${entries.length} messages
+
${reqCount} requests
+
${resCount} responses
+
`; + + // Chain entries + html += '
'; + for (let i = 0; i < entries.length; i++) { + const e = entries[i]; + const isReq = e.direction === 'client_to_server'; + const dirLabel = isReq ? 'REQ' : 'RES'; + const dirClass = isReq ? 'req' : 'res'; + const validClass = e.valid ? 'valid' : 'invalid'; + const icon = e.valid ? '✓' : '✗'; + const methodDisplay = e.method || '(response)'; + const shortHash = e.hash ? e.hash.substring(0, 16) + '…' : '—'; + const prevShort = e.prev_hash ? e.prev_hash.substring(0, 8) + '…' : 'genesis'; + + // Link arrow between entries + if (i > 0) { + const linkClass = e.valid ? '' : 'broken'; + html += ``; + } + + html += `
+ #${e.sequence} + ${dirLabel} + ${icon} ${esc(methodDisplay)} + ${esc(shortHash)} +
`; + } + html += '
'; + + container.innerHTML = html; + } catch(e) { container.innerHTML = `
@@ -847,11 +1118,23 @@

No intent chains

} // --- Init --- -loadStats(); -loadSessions(); -loadToolCalls(); -loadPending(); -connectStream(); +async function init() { + await loadStats(); + await loadSessions(); + await loadToolCalls(); + await loadPending(); + connectStream(); + + // Auto-select first session with tool calls + try { + const res = await fetch(API + '/api/sessions'); + const sessions = await res.json(); + const first = sessions.find(s => s.tool_call_count > 0); + if (first) selectSession(first.id); + } catch(e) { /* ignore */ } +} + +init(); // Poll for updates every 5s as backup setInterval(() => { loadStats(); loadPending(); }, 5000);