diff --git a/CLAUDE.md b/CLAUDE.md index 8af8941..bce5963 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -15,8 +15,8 @@ make lint # runs go vet ./... - `cmd/beacon-proxy/main.go` — CLI entry point - `internal/proxy/` — core proxy (child process, stdin/stdout piping, JSON-RPC parsing) - `internal/audit/` — SQLite audit store (sessions, messages) -- `internal/policy/` — policy engine (not yet implemented) -- `internal/web/` — dashboard web UI (not yet implemented) +- `internal/policy/` — policy engine (YAML rules, pause/approve flow) +- `internal/web/` — dashboard web UI (embedded HTML, SSE live stream, API handlers) ## Conventions diff --git a/README.md b/README.md index 2eb7a52..d282d94 100644 --- a/README.md +++ b/README.md @@ -127,9 +127,34 @@ beacon-proxy --server-name filesystem -- npx -y @modelcontextprotocol/server-fil } ``` +## Dashboard + +Beacon includes a real-time web dashboard at `http://localhost:8080` (configurable via `--port`). + +![Beacon Dashboard](docs/dashboard.png) + +- **Live tool call stream** — see every MCP action as it happens via SSE +- **Session overview** — browse sessions, see tool call counts and max risk per server +- **Risk & policy views** — filter by operation type, policy action, or risk level +- **Intent chains** — visualize temporal groupings of related tool calls +- **Approve/deny** — handle paused tool calls directly from the browser +- **Hash chain verification** — one-click integrity check of the entire audit trail + +No build step — the dashboard is a single HTML file embedded in the binary. + +### Generate Demo Traffic + +To populate the dashboard with realistic sample data across multiple servers: + +```bash +go run ./cmd/beacon-traffic/ +``` + +This creates sessions for GitHub, filesystem, and PostgreSQL MCP servers with a mix of read/write/delete/execute operations, risk scores, and policy actions (flag, pause, block). + ## Inspecting the Audit Trail -All messages are logged to SQLite. 
Query directly: +The dashboard provides the primary UI. You can also query SQLite directly: ```bash # List all sessions @@ -173,8 +198,9 @@ sqlite3 ~/.beacon/audit.db "SELECT direction, method, jsonrpc_id FROM messages O - [x] HTTP approval endpoints for paused tool calls - [x] Temporal intent grouping (5s gap threshold) +- [x] Real-time web dashboard with live SSE stream + ### Next -- [ ] Real-time web dashboard ### Vision - [ ] Cross-system intent chains — one human request, multiple MCP servers, one narrative diff --git a/cmd/beacon-proxy/main.go b/cmd/beacon-proxy/main.go index 14c4949..5545683 100644 --- a/cmd/beacon-proxy/main.go +++ b/cmd/beacon-proxy/main.go @@ -2,7 +2,6 @@ package main import ( "context" - "encoding/json" "flag" "fmt" "log" @@ -17,13 +16,14 @@ import ( "github.com/ottojongerius/beacon/internal/audit" "github.com/ottojongerius/beacon/internal/policy" "github.com/ottojongerius/beacon/internal/proxy" + "github.com/ottojongerius/beacon/internal/web" ) func main() { serverName := flag.String("server-name", "unknown", "label for this MCP server") dbPath := flag.String("db", "~/.beacon/audit.db", "path to SQLite audit database") rulesPath := flag.String("rules", "", "path to YAML policy rules file (uses defaults if not set)") - port := flag.Int("port", 8080, "HTTP port for approval endpoints") + port := flag.Int("port", 8080, "HTTP port for dashboard and approval endpoints") retentionDays := flag.Int("retention-days", 0, "auto-delete sessions older than N days on startup (0 = keep forever)") // Find "--" separator — everything after it is the server command @@ -82,8 +82,9 @@ func main() { } engine := policy.NewEngine(rules) - // Start HTTP server for approvals - go startHTTPServer(*port, engine) + // Start dashboard web server + dashboard := web.NewServer(store, engine) + go startHTTPServer(*port, dashboard) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -103,6 +104,7 @@ func main() { Store: store, Policy: engine, 
Intent: audit.NewIntentTracker(store, 5*time.Second), + Dashboard: dashboard, } if err := p.Run(ctx); err != nil { @@ -110,38 +112,10 @@ func main() { } } -func startHTTPServer(port int, engine *policy.Engine) { - mux := http.NewServeMux() - - mux.HandleFunc("POST /api/tool-calls/{id}/approve", func(w http.ResponseWriter, r *http.Request) { - id := r.PathValue("id") - if err := engine.Approve(id, "http-user"); err != nil { - http.Error(w, err.Error(), http.StatusNotFound) - return - } - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, `{"status":"approved"}`) - }) - - mux.HandleFunc("POST /api/tool-calls/{id}/deny", func(w http.ResponseWriter, r *http.Request) { - id := r.PathValue("id") - if err := engine.Deny(id); err != nil { - http.Error(w, err.Error(), http.StatusNotFound) - return - } - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, `{"status":"denied"}`) - }) - - mux.HandleFunc("GET /api/tool-calls/pending", func(w http.ResponseWriter, r *http.Request) { - ids := engine.PendingApprovals() - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]any{"pending": ids}) - }) - +func startHTTPServer(port int, dashboard *web.Server) { addr := fmt.Sprintf("127.0.0.1:%d", port) - log.Printf("beacon: approval server listening on %s", addr) - if err := http.ListenAndServe(addr, mux); err != nil { + log.Printf("beacon: dashboard listening on http://%s", addr) + if err := http.ListenAndServe(addr, dashboard.Handler()); err != nil { log.Printf("beacon: HTTP server error: %v", err) } } diff --git a/cmd/beacon-traffic/main.go b/cmd/beacon-traffic/main.go new file mode 100644 index 0000000..6d4d5b5 --- /dev/null +++ b/cmd/beacon-traffic/main.go @@ -0,0 +1,200 @@ +// beacon-traffic generates realistic MCP audit traffic for dashboard demos. 
+package main + +import ( + "encoding/json" + "fmt" + "log" + "math/rand" + "os" + "path/filepath" + "strings" + "time" + + "github.com/ottojongerius/beacon/internal/audit" +) + +type scenario struct { + server string + command string + toolCalls []fakeCall +} + +type fakeCall struct { + tool string + args map[string]any + result map[string]any + delayMs int // delay before this call + policySet string +} + +func main() { + home, _ := os.UserHomeDir() + dbPath := filepath.Join(home, ".beacon", "audit.db") + + store, err := audit.Open(dbPath) + if err != nil { + log.Fatalf("failed to open db: %v", err) + } + defer store.Close() + + scenarios := []scenario{ + { + server: "github", + command: "npx -y @modelcontextprotocol/server-github", + toolCalls: []fakeCall{ + {tool: "list_repos", args: m("owner", "ojongerius"), result: m("count", 12), delayMs: 200}, + {tool: "search_issues", args: m("repo", "beacon", "query", "audit"), result: m("count", 3), delayMs: 400}, + {tool: "read_file", args: m("repo", "beacon", "path", "README.md"), result: m("size", 4096), delayMs: 150}, + {tool: "create_issue", args: m("repo", "beacon", "title", "Add SIEM export", "body", "Stream events to Datadog"), result: m("number", 9), delayMs: 800, policySet: "flag"}, + {tool: "update_issue", args: m("repo", "beacon", "number", 5, "labels", []string{"enhancement"}), result: m("updated", true), delayMs: 300, policySet: "flag"}, + }, + }, + { + server: "filesystem", + command: "npx -y @modelcontextprotocol/server-filesystem /Users/otto/projects", + toolCalls: []fakeCall{ + {tool: "list_directory", args: m("path", "/Users/otto/projects"), result: m("entries", 8), delayMs: 100}, + {tool: "read_file", args: m("path", "/Users/otto/projects/beacon/go.mod"), result: m("size", 512), delayMs: 120}, + {tool: "read_file", args: m("path", "/Users/otto/projects/beacon/internal/web/server.go"), result: m("size", 6144), delayMs: 180}, + {tool: "write_file", args: m("path", "/Users/otto/projects/beacon/TODO.md", 
"content", "# TODO\n- SIEM export\n- Multi-server"), result: m("written", true), delayMs: 500, policySet: "flag"}, + {tool: "delete_file", args: m("path", "/Users/otto/projects/beacon/tmp/old-cache.json"), result: m("deleted", true), delayMs: 400, policySet: "pause"}, + }, + }, + { + server: "postgres", + command: "npx -y @modelcontextprotocol/server-postgres postgres://localhost/app", + toolCalls: []fakeCall{ + {tool: "list_tables", args: m("schema", "public"), result: m("tables", []string{"users", "orders", "products"}), delayMs: 150}, + {tool: "describe_table", args: m("table", "users"), result: m("columns", 8), delayMs: 200}, + {tool: "read_query", args: m("sql", "SELECT COUNT(*) FROM users"), result: m("count", 1847), delayMs: 250}, + {tool: "read_query", args: m("sql", "SELECT email, created_at FROM users ORDER BY created_at DESC LIMIT 5"), result: m("rows", 5), delayMs: 300}, + {tool: "exec_query", args: m("sql", "UPDATE users SET status = 'active' WHERE last_login > NOW() - INTERVAL '30 days'"), result: m("affected", 423), delayMs: 1200, policySet: "pause"}, + {tool: "exec_query", args: m("sql", "DELETE FROM sessions WHERE expired_at < NOW()"), result: m("affected", 89), delayMs: 600, policySet: "block"}, + }, + }, + } + + fmt.Println("🔦 Beacon traffic generator") + fmt.Println(" Generating realistic MCP audit traffic...") + fmt.Println() + + for _, sc := range scenarios { + fmt.Printf(" ▸ %s (%s)\n", sc.server, sc.command) + sessionID, err := store.CreateSession(sc.server, sc.command) + if err != nil { + log.Printf(" ✗ failed to create session: %v", err) + continue + } + + intentID := "" + var lastCallTime time.Time + + for i, tc := range sc.toolCalls { + time.Sleep(time.Duration(tc.delayMs) * time.Millisecond) + + now := time.Now().UTC() + argsJSON, _ := json.Marshal(tc.args) + + // Log request message + reqRaw := fmt.Sprintf(`{"jsonrpc":"2.0","id":%d,"method":"tools/call","params":{"name":"%s","arguments":%s}}`, + i+1, tc.tool, argsJSON) + msgID, err := 
store.LogMessage(sessionID, "client_to_server", fmt.Sprintf("%d", i+1), "tools/call", reqRaw) + if err != nil { + log.Printf(" ✗ log request: %v", err) + continue + } + + // Classify and score + opType := audit.ClassifyOperation(tc.tool) + score, reasons := audit.ScoreRisk(tc.tool, opType, string(argsJSON)) + + tcID := fmt.Sprintf("tc-%s-%d", sc.server, i+1) + err = store.CreateToolCall(audit.ToolCallRecord{ + ID: tcID, + SessionID: sessionID, + RequestMsgID: msgID, + ToolName: tc.tool, + Arguments: string(argsJSON), + OperationType: opType, + RiskScore: score, + RiskReasons: reasons, + RequestedAt: now, + }) + if err != nil { + log.Printf(" ✗ create tool call: %v", err) + continue + } + + // Set policy if specified + if tc.policySet != "" { + store.UpdateToolCallPolicy(tcID, tc.policySet, nil) + } + + // Simulate response after a brief delay + respDelay := time.Duration(50+rand.Intn(200)) * time.Millisecond + time.Sleep(respDelay) + respTime := time.Now().UTC() + + resultJSON, _ := json.Marshal(tc.result) + respRaw := fmt.Sprintf(`{"jsonrpc":"2.0","id":%d,"result":%s}`, i+1, resultJSON) + respMsgID, err := store.LogMessage(sessionID, "server_to_client", fmt.Sprintf("%d", i+1), "", respRaw) + if err != nil { + log.Printf(" ✗ log response: %v", err) + continue + } + + durationMs := respTime.Sub(now).Milliseconds() + result := string(resultJSON) + store.CompleteToolCall(tcID, respMsgID, &result, nil, respTime, durationMs) + + // Intent grouping: new intent if gap > 2s + if intentID == "" || now.Sub(lastCallTime) > 2*time.Second { + intentID2, err := store.CreateIntentContext(sessionID) + if err == nil { + intentID = intentID2 + } + } + if intentID != "" { + store.AddToolCallToIntent(intentID, tcID, i+1) + } + lastCallTime = now + + policyTag := "" + if tc.policySet != "" { + policyTag = fmt.Sprintf(" [%s]", strings.ToUpper(tc.policySet)) + } + riskBar := riskIndicator(score) + fmt.Printf(" %s %-20s %s %-7s risk:%2d %s%s\n", + "✓", tc.tool, riskBar, opType, score, 
reasons, policyTag) + } + + // End some sessions (leave last one "live") + if sc.server != "postgres" { + store.EndSession(sessionID) + } + } + + fmt.Println("\n ✅ Done! Refresh the dashboard at http://localhost:8080") +} + +func m(kvs ...any) map[string]any { + out := make(map[string]any) + for i := 0; i < len(kvs)-1; i += 2 { + out[kvs[i].(string)] = kvs[i+1] + } + return out +} + +func riskIndicator(score int) string { + switch { + case score <= 30: + return "🟢" + case score <= 60: + return "🟡" + case score <= 80: + return "🟠" + default: + return "🔴" + } +} diff --git a/docs/dashboard.png b/docs/dashboard.png new file mode 100644 index 0000000..28dcbb2 Binary files /dev/null and b/docs/dashboard.png differ diff --git a/internal/audit/queries.go b/internal/audit/queries.go new file mode 100644 index 0000000..0a6ddd9 --- /dev/null +++ b/internal/audit/queries.go @@ -0,0 +1,252 @@ +package audit + +import ( + "database/sql" + "encoding/json" + "time" +) + +// SessionSummary is a session with aggregate stats for the dashboard. +type SessionSummary struct { + ID string `json:"id"` + ServerName string `json:"server_name"` + ServerCommand string `json:"server_command"` + StartedAt time.Time `json:"started_at"` + EndedAt *time.Time `json:"ended_at"` + MessageCount int `json:"message_count"` + ToolCallCount int `json:"tool_call_count"` + MaxRisk int `json:"max_risk"` +} + +// ListSessions returns all sessions with aggregate counts, newest first. 
+func (s *Store) ListSessions() ([]SessionSummary, error) { + rows, err := s.db.Query(` + SELECT s.id, s.server_name, s.server_command, s.started_at, s.ended_at, + COALESCE((SELECT COUNT(*) FROM messages WHERE session_id = s.id), 0) AS msg_count, + COALESCE((SELECT COUNT(*) FROM tool_calls WHERE session_id = s.id), 0) AS tc_count, + COALESCE((SELECT MAX(risk_score) FROM tool_calls WHERE session_id = s.id), 0) AS max_risk + FROM sessions s ORDER BY s.started_at DESC + `) + if err != nil { + return nil, err + } + defer rows.Close() + + var out []SessionSummary + for rows.Next() { + var ss SessionSummary + var endedAt sql.NullTime + if err := rows.Scan(&ss.ID, &ss.ServerName, &ss.ServerCommand, &ss.StartedAt, &endedAt, + &ss.MessageCount, &ss.ToolCallCount, &ss.MaxRisk); err != nil { + return nil, err + } + if endedAt.Valid { + ss.EndedAt = &endedAt.Time + } + out = append(out, ss) + } + return out, rows.Err() +} + +// ToolCallSummary is a tool call row for the dashboard. +type ToolCallSummary struct { + ID string `json:"id"` + SessionID string `json:"session_id"` + ToolName string `json:"tool_name"` + Arguments string `json:"arguments"` + Result *string `json:"result"` + Error *string `json:"error"` + OperationType string `json:"operation_type"` + RiskScore int `json:"risk_score"` + RiskReasons []string `json:"risk_reasons"` + PolicyAction string `json:"policy_action"` + ApprovedBy *string `json:"approved_by"` + RequestedAt time.Time `json:"requested_at"` + RespondedAt *time.Time `json:"responded_at"` + DurationMs *int64 `json:"duration_ms"` + ServerName string `json:"server_name"` +} + +// ListToolCalls returns tool calls with optional filters, newest first. 
+func (s *Store) ListToolCalls(sessionID, policyFilter, opFilter string, limit int) ([]ToolCallSummary, error) { + query := ` + SELECT tc.id, tc.session_id, tc.tool_name, tc.arguments, tc.result, tc.error, + tc.operation_type, tc.risk_score, tc.risk_reasons, tc.policy_action, + tc.approved_by, tc.requested_at, tc.responded_at, tc.duration_ms, + s.server_name + FROM tool_calls tc + JOIN sessions s ON s.id = tc.session_id + WHERE 1=1 + ` + var args []any + + if sessionID != "" { + query += " AND tc.session_id = ?" + args = append(args, sessionID) + } + if policyFilter != "" { + query += " AND tc.policy_action = ?" + args = append(args, policyFilter) + } + if opFilter != "" { + query += " AND tc.operation_type = ?" + args = append(args, opFilter) + } + query += " ORDER BY tc.requested_at DESC" + if limit > 0 { + query += " LIMIT ?" + args = append(args, limit) + } + + rows, err := s.db.Query(query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var out []ToolCallSummary + for rows.Next() { + var tc ToolCallSummary + var reasons string + var result, errJSON, approvedBy sql.NullString + var respondedAt sql.NullTime + var durationMs sql.NullInt64 + if err := rows.Scan(&tc.ID, &tc.SessionID, &tc.ToolName, &tc.Arguments, + &result, &errJSON, &tc.OperationType, &tc.RiskScore, &reasons, + &tc.PolicyAction, &approvedBy, &tc.RequestedAt, &respondedAt, &durationMs, + &tc.ServerName); err != nil { + return nil, err + } + if result.Valid { + v := s.decrypt(result.String) + tc.Result = &v + } + if errJSON.Valid { + v := s.decrypt(errJSON.String) + tc.Error = &v + } + if approvedBy.Valid { + tc.ApprovedBy = &approvedBy.String + } + if respondedAt.Valid { + tc.RespondedAt = &respondedAt.Time + } + if durationMs.Valid { + tc.DurationMs = &durationMs.Int64 + } + tc.Arguments = s.decrypt(tc.Arguments) + json.Unmarshal([]byte(reasons), &tc.RiskReasons) + if tc.RiskReasons == nil { + tc.RiskReasons = []string{} + } + out = append(out, tc) + } + return out, 
rows.Err() +} + +// IntentSummary is an intent context with its tool calls for the dashboard. +type IntentSummary struct { + ID string `json:"id"` + SessionID string `json:"session_id"` + CreatedAt time.Time `json:"created_at"` + ToolCalls []IntentToolEntry `json:"tool_calls"` +} + +// IntentToolEntry is a single tool call within an intent chain. +type IntentToolEntry struct { + SequenceOrder int `json:"sequence_order"` + ToolCallID string `json:"tool_call_id"` + ToolName string `json:"tool_name"` + OperationType string `json:"operation_type"` + RiskScore int `json:"risk_score"` + PolicyAction string `json:"policy_action"` +} + +// ListIntents returns intent contexts with their tool calls for a session. +func (s *Store) ListIntents(sessionID string) ([]IntentSummary, error) { + rows, err := s.db.Query(` + SELECT ic.id, ic.session_id, ic.created_at, + itc.sequence_order, itc.tool_call_id, + tc.tool_name, tc.operation_type, tc.risk_score, tc.policy_action + FROM intent_contexts ic + JOIN intent_tool_calls itc ON itc.intent_id = ic.id + JOIN tool_calls tc ON tc.id = itc.tool_call_id + WHERE ic.session_id = ? 
// DashboardStats holds the aggregate figures shown in the dashboard header.
// It is populated by Store.GetStats.
type DashboardStats struct {
	TotalSessions  int `json:"total_sessions"`   // number of rows in sessions
	ActiveSessions int `json:"active_sessions"`  // sessions whose ended_at is NULL
	TotalToolCalls int `json:"total_tool_calls"` // number of rows in tool_calls
	FlaggedCalls   int `json:"flagged_calls"`    // tool calls with policy_action = 'flag'
	PausedCalls    int `json:"paused_calls"`     // tool calls with policy_action = 'pause'
	BlockedCalls   int `json:"blocked_calls"`    // tool calls with policy_action = 'block'
	AvgRiskScore   int `json:"avg_risk_score"`   // average risk_score over all tool calls (0 when there are none)
	MaxRiskScore   int `json:"max_risk_score"`   // highest risk_score over all tool calls (0 when there are none)
}
+func (s *Store) GetStats() (*DashboardStats, error) { + stats := &DashboardStats{} + err := s.db.QueryRow("SELECT COUNT(*) FROM sessions").Scan(&stats.TotalSessions) + if err != nil { + return nil, err + } + s.db.QueryRow("SELECT COUNT(*) FROM sessions WHERE ended_at IS NULL").Scan(&stats.ActiveSessions) + s.db.QueryRow("SELECT COUNT(*) FROM tool_calls").Scan(&stats.TotalToolCalls) + s.db.QueryRow("SELECT COUNT(*) FROM tool_calls WHERE policy_action = 'flag'").Scan(&stats.FlaggedCalls) + s.db.QueryRow("SELECT COUNT(*) FROM tool_calls WHERE policy_action = 'pause'").Scan(&stats.PausedCalls) + s.db.QueryRow("SELECT COUNT(*) FROM tool_calls WHERE policy_action = 'block'").Scan(&stats.BlockedCalls) + s.db.QueryRow("SELECT COALESCE(AVG(risk_score), 0) FROM tool_calls").Scan(&stats.AvgRiskScore) + s.db.QueryRow("SELECT COALESCE(MAX(risk_score), 0) FROM tool_calls").Scan(&stats.MaxRiskScore) + return stats, nil +} + +// decrypt is a helper that decrypts a string if encryption is enabled, +// falling back to the input on error (for dashboard display). 
+func (s *Store) decrypt(ciphertext string) string { + if s.enc == nil { + return ciphertext + } + plain, err := s.enc.Decrypt(ciphertext) + if err != nil { + return ciphertext // show as-is if decryption fails + } + return plain +} diff --git a/internal/audit/queries_test.go b/internal/audit/queries_test.go new file mode 100644 index 0000000..d9d38a1 --- /dev/null +++ b/internal/audit/queries_test.go @@ -0,0 +1,295 @@ +package audit + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestListSessions(t *testing.T) { + db := filepath.Join(t.TempDir(), "test.db") + store, err := Open(db) + if err != nil { + t.Fatalf("Open: %v", err) + } + defer store.Close() + + // Empty DB + sessions, err := store.ListSessions() + if err != nil { + t.Fatalf("ListSessions: %v", err) + } + if len(sessions) != 0 { + t.Errorf("expected 0 sessions, got %d", len(sessions)) + } + + // Create sessions with tool calls + id1, err := store.CreateSession("github", "npx github-mcp") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + + msgID, err := store.LogMessage(id1, "client_to_server", "1", "tools/call", `{"method":"tools/call"}`) + if err != nil { + t.Fatalf("LogMessage: %v", err) + } + + err = store.CreateToolCall(ToolCallRecord{ + ID: "tc-1", + SessionID: id1, + RequestMsgID: msgID, + ToolName: "read_file", + Arguments: `{"path":"test.go"}`, + OperationType: "read", + RiskScore: 10, + RiskReasons: []string{}, + RequestedAt: time.Now().UTC(), + }) + if err != nil { + t.Fatalf("CreateToolCall: %v", err) + } + + sessions, err = store.ListSessions() + if err != nil { + t.Fatalf("ListSessions: %v", err) + } + if len(sessions) != 1 { + t.Fatalf("expected 1 session, got %d", len(sessions)) + } + if sessions[0].ServerName != "github" { + t.Errorf("expected server name 'github', got %q", sessions[0].ServerName) + } + if sessions[0].MessageCount != 1 { + t.Errorf("expected 1 message, got %d", sessions[0].MessageCount) + } + if sessions[0].ToolCallCount != 1 { + 
t.Errorf("expected 1 tool call, got %d", sessions[0].ToolCallCount) + } +} + +func TestListToolCalls(t *testing.T) { + db := filepath.Join(t.TempDir(), "test.db") + store, err := Open(db) + if err != nil { + t.Fatalf("Open: %v", err) + } + defer store.Close() + + id, err := store.CreateSession("test-server", "echo test") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + + msgID, err := store.LogMessage(id, "client_to_server", "1", "tools/call", `{}`) + if err != nil { + t.Fatalf("LogMessage: %v", err) + } + + err = store.CreateToolCall(ToolCallRecord{ + ID: "tc-1", + SessionID: id, + RequestMsgID: msgID, + ToolName: "delete_issue", + Arguments: `{"id":42}`, + OperationType: "delete", + RiskScore: 60, + RiskReasons: []string{"delete_operation"}, + RequestedAt: time.Now().UTC(), + }) + if err != nil { + t.Fatalf("CreateToolCall: %v", err) + } + + if err := store.UpdateToolCallPolicy("tc-1", "flag", nil); err != nil { + t.Fatalf("UpdateToolCallPolicy: %v", err) + } + + // No filter + calls, err := store.ListToolCalls("", "", "", 100) + if err != nil { + t.Fatalf("ListToolCalls: %v", err) + } + if len(calls) != 1 { + t.Fatalf("expected 1 call, got %d", len(calls)) + } + if calls[0].ToolName != "delete_issue" { + t.Errorf("tool name = %q, want delete_issue", calls[0].ToolName) + } + if calls[0].PolicyAction != "flag" { + t.Errorf("policy = %q, want flag", calls[0].PolicyAction) + } + if calls[0].ServerName != "test-server" { + t.Errorf("server = %q, want test-server", calls[0].ServerName) + } + if len(calls[0].RiskReasons) != 1 || calls[0].RiskReasons[0] != "delete_operation" { + t.Errorf("risk reasons = %v, want [delete_operation]", calls[0].RiskReasons) + } + + // Filter by policy + calls, err = store.ListToolCalls("", "block", "", 100) + if err != nil { + t.Fatalf("ListToolCalls with filter: %v", err) + } + if len(calls) != 0 { + t.Errorf("expected 0 calls with policy=block, got %d", len(calls)) + } + + // Filter by operation + calls, err = 
store.ListToolCalls("", "", "delete", 100) + if err != nil { + t.Fatalf("ListToolCalls op filter: %v", err) + } + if len(calls) != 1 { + t.Errorf("expected 1 delete call, got %d", len(calls)) + } +} + +func TestListIntents(t *testing.T) { + db := filepath.Join(t.TempDir(), "test.db") + store, err := Open(db) + if err != nil { + t.Fatalf("Open: %v", err) + } + defer store.Close() + + sessionID, err := store.CreateSession("test", "echo test") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + + msgID, err := store.LogMessage(sessionID, "client_to_server", "1", "tools/call", `{}`) + if err != nil { + t.Fatalf("LogMessage: %v", err) + } + + err = store.CreateToolCall(ToolCallRecord{ + ID: "tc-1", SessionID: sessionID, RequestMsgID: msgID, + ToolName: "read_file", Arguments: "{}", OperationType: "read", + RiskScore: 0, RiskReasons: []string{}, RequestedAt: time.Now().UTC(), + }) + if err != nil { + t.Fatalf("CreateToolCall: %v", err) + } + + intentID, err := store.CreateIntentContext(sessionID) + if err != nil { + t.Fatalf("CreateIntentContext: %v", err) + } + if err := store.AddToolCallToIntent(intentID, "tc-1", 1); err != nil { + t.Fatalf("AddToolCallToIntent: %v", err) + } + + intents, err := store.ListIntents(sessionID) + if err != nil { + t.Fatalf("ListIntents: %v", err) + } + if len(intents) != 1 { + t.Fatalf("expected 1 intent, got %d", len(intents)) + } + if len(intents[0].ToolCalls) != 1 { + t.Fatalf("expected 1 tool call in intent, got %d", len(intents[0].ToolCalls)) + } + if intents[0].ToolCalls[0].ToolName != "read_file" { + t.Errorf("tool name = %q, want read_file", intents[0].ToolCalls[0].ToolName) + } +} + +func TestGetStats(t *testing.T) { + db := filepath.Join(t.TempDir(), "test.db") + store, err := Open(db) + if err != nil { + t.Fatalf("Open: %v", err) + } + defer store.Close() + + stats, err := store.GetStats() + if err != nil { + t.Fatalf("GetStats: %v", err) + } + if stats.TotalSessions != 0 || stats.TotalToolCalls != 0 { + 
t.Errorf("expected zeros, got sessions=%d calls=%d", stats.TotalSessions, stats.TotalToolCalls) + } + + id, err := store.CreateSession("s1", "echo") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + + msgID, err := store.LogMessage(id, "client_to_server", "1", "tools/call", `{}`) + if err != nil { + t.Fatalf("LogMessage: %v", err) + } + + err = store.CreateToolCall(ToolCallRecord{ + ID: "tc-1", SessionID: id, RequestMsgID: msgID, + ToolName: "write_file", Arguments: "{}", OperationType: "write", + RiskScore: 40, RiskReasons: []string{"write_operation"}, RequestedAt: time.Now().UTC(), + }) + if err != nil { + t.Fatalf("CreateToolCall: %v", err) + } + + stats, err = store.GetStats() + if err != nil { + t.Fatalf("GetStats: %v", err) + } + if stats.TotalSessions != 1 { + t.Errorf("sessions = %d, want 1", stats.TotalSessions) + } + if stats.ActiveSessions != 1 { + t.Errorf("active = %d, want 1 (not ended)", stats.ActiveSessions) + } + if stats.TotalToolCalls != 1 { + t.Errorf("tool calls = %d, want 1", stats.TotalToolCalls) + } + if stats.AvgRiskScore != 40 { + t.Errorf("avg risk = %d, want 40", stats.AvgRiskScore) + } +} + +func TestListToolCalls_WithEncryption(t *testing.T) { + db := filepath.Join(t.TempDir(), "test.db") + // Use a 32-byte hex key for AES-256 + key := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + os.Setenv("BEACON_ENCRYPTION_KEY", key) + defer os.Unsetenv("BEACON_ENCRYPTION_KEY") + + store, err := Open(db, key) + if err != nil { + t.Fatalf("Open: %v", err) + } + defer store.Close() + + id, err := store.CreateSession("encrypted-server", "echo test") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + + msgID, err := store.LogMessage(id, "client_to_server", "1", "tools/call", `{}`) + if err != nil { + t.Fatalf("LogMessage: %v", err) + } + + err = store.CreateToolCall(ToolCallRecord{ + ID: "tc-1", SessionID: id, RequestMsgID: msgID, + ToolName: "read_file", Arguments: `{"path":"secret.txt"}`, OperationType: 
"read", + RiskScore: 10, RiskReasons: []string{}, RequestedAt: time.Now().UTC(), + }) + if err != nil { + t.Fatalf("CreateToolCall: %v", err) + } + + // Should decrypt transparently + calls, err := store.ListToolCalls("", "", "", 100) + if err != nil { + t.Fatalf("ListToolCalls: %v", err) + } + if len(calls) != 1 { + t.Fatalf("expected 1 call, got %d", len(calls)) + } + // Arguments should be decrypted (redacted but readable) + if calls[0].Arguments == "" { + t.Error("expected non-empty arguments after decryption") + } +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index c11f85e..eec38cb 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -15,6 +15,7 @@ import ( "github.com/google/uuid" "github.com/ottojongerius/beacon/internal/audit" "github.com/ottojongerius/beacon/internal/policy" + "github.com/ottojongerius/beacon/internal/web" ) const maxMessageSize = 1 << 20 // 1MB @@ -33,6 +34,7 @@ type Proxy struct { Store *audit.Store Policy *policy.Engine Intent *audit.IntentTracker + Dashboard *web.Server pendingCalls map[string]*pendingCall // keyed by JSON-RPC id pendingMu sync.Mutex @@ -181,6 +183,18 @@ func (p *Proxy) handleToolCallRequest(sessionID, msgID, jsonrpcID string, msg *M return "pass" } + // Broadcast to dashboard + if p.Dashboard != nil { + p.Dashboard.Broadcast(map[string]any{ + "type": "tool_call", + "id": tcID, + "tool_name": params.Name, + "operation_type": opType, + "risk_score": score, + "server_name": p.ServerName, + }) + } + // Track intent grouping if p.Intent != nil { p.Intent.Track(sessionID, tcID, now) diff --git a/internal/web/server.go b/internal/web/server.go new file mode 100644 index 0000000..aced385 --- /dev/null +++ b/internal/web/server.go @@ -0,0 +1,231 @@ +package web + +import ( + "embed" + "encoding/json" + "fmt" + "io/fs" + "log" + "net/http" + "sync" + "time" + + "github.com/ottojongerius/beacon/internal/audit" + "github.com/ottojongerius/beacon/internal/policy" +) + +//go:embed static +var 
staticFiles embed.FS + +// Server is the dashboard web server. +type Server struct { + store *audit.Store + engine *policy.Engine + hub *wsHub +} + +// NewServer creates a new dashboard server. +func NewServer(store *audit.Store, engine *policy.Engine) *Server { + return &Server{ + store: store, + engine: engine, + hub: newHub(), + } +} + +// Handler returns an http.Handler with all dashboard routes. +func (s *Server) Handler() http.Handler { + mux := http.NewServeMux() + + // API routes + mux.HandleFunc("GET /api/stats", s.handleStats) + mux.HandleFunc("GET /api/sessions", s.handleSessions) + mux.HandleFunc("GET /api/tool-calls", s.handleToolCalls) + mux.HandleFunc("GET /api/sessions/{id}/intents", s.handleIntents) + mux.HandleFunc("GET /api/chain/verify", s.handleVerifyChain) + mux.HandleFunc("GET /api/tool-calls/pending", s.handlePending) + mux.HandleFunc("POST /api/tool-calls/{id}/approve", s.handleApprove) + mux.HandleFunc("POST /api/tool-calls/{id}/deny", s.handleDeny) + mux.HandleFunc("GET /ws/live", s.handleWebSocket) + + // Serve embedded static files (dashboard HTML) + staticSub, _ := fs.Sub(staticFiles, "static") + mux.Handle("GET /", http.FileServer(http.FS(staticSub))) + + return mux +} + +// Broadcast sends a tool call event to all connected WebSocket clients. 
+func (s *Server) Broadcast(event any) {
+	data, err := json.Marshal(event)
+	if err != nil {
+		// An event that cannot be serialized is dropped, but the failure is
+		// logged so a bad event type does not disappear without a trace.
+		log.Printf("beacon: dropping unserializable event: %v", err)
+		return
+	}
+	s.hub.broadcast(data)
+}
+
+// handleStats writes the store's aggregate statistics as JSON. A store error
+// is reported as 500 with the error text as the body.
+func (s *Server) handleStats(w http.ResponseWriter, r *http.Request) {
+	stats, err := s.store.GetStats()
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+	writeJSON(w, stats)
+}
+
+// handleSessions writes the list of sessions as a JSON array. A store error
+// is reported as 500.
+func (s *Server) handleSessions(w http.ResponseWriter, r *http.Request) {
+	sessions, err := s.store.ListSessions()
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+	// A nil slice would encode as JSON null; substitute an empty slice so the
+	// response is always an array.
+	if sessions == nil {
+		sessions = []audit.SessionSummary{}
+	}
+	writeJSON(w, sessions)
+}
+
+// handleToolCalls writes tool calls as a JSON array, optionally narrowed by
+// the "session_id", "policy" and "operation" query parameters. The values are
+// passed to the store unchanged together with a fixed limit of 200.
+func (s *Server) handleToolCalls(w http.ResponseWriter, r *http.Request) {
+	query := r.URL.Query()
+	sessionID := query.Get("session_id")
+	policyFilter := query.Get("policy")
+	opFilter := query.Get("operation")
+	limit := 200
+
+	calls, err := s.store.ListToolCalls(sessionID, policyFilter, opFilter, limit)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+	// Keep the response an array even when nothing matched.
+	if calls == nil {
+		calls = []audit.ToolCallSummary{}
+	}
+	writeJSON(w, calls)
+}
+
+// handleIntents writes the intents of the session named by the "id" path
+// value as a JSON array. A store error is reported as 500.
+func (s *Server) handleIntents(w http.ResponseWriter, r *http.Request) {
+	sessionID := r.PathValue("id")
+	intents, err := s.store.ListIntents(sessionID)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+	// Keep the response an array even when the session has no intents.
+	if intents == nil {
+		intents = []audit.IntentSummary{}
+	}
+	writeJSON(w, intents)
+}
+
+// handleVerifyChain runs the store's chain verification and writes the
+// resulting status as JSON. A store error is reported as 500.
+func (s *Server) handleVerifyChain(w http.ResponseWriter, r *http.Request) {
+	status, err := s.store.VerifyChain()
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+	writeJSON(w, status)
+}
+
+// handlePending writes the policy engine's pending approvals as a JSON object
+// with a single "pending" key.
+func (s *Server) handlePending(w http.ResponseWriter, r *http.Request) {
+	ids := s.engine.PendingApprovals()
+	writeJSON(w, map[string]any{"pending": ids})
+}
+
+// handleApprove approves the item named by the "id" path value, recording
+// "dashboard-user" as the approver. Any error from the engine is reported as
+// 404.
+// NOTE(review): presumably Approve only fails for unknown IDs; confirm that
+// 404 is right for every error it can return.
+func (s *Server) handleApprove(w http.ResponseWriter, r *http.Request) {
+	id := r.PathValue("id")
+	if err := s.engine.Approve(id, "dashboard-user"); err != nil {
+		http.Error(w, err.Error(), http.StatusNotFound)
+		return
+	}
+	writeJSON(w, map[string]string{"status": "approved"})
+}
+
+// handleDeny denies the item named by the "id" path value. Any error from the
+// engine is reported as 404.
+func (s *Server) handleDeny(w http.ResponseWriter, r *http.Request) {
+	id := r.PathValue("id")
+	if err := s.engine.Deny(id); err != nil {
+		http.Error(w, err.Error(), http.StatusNotFound)
+		return
+	}
+	writeJSON(w, map[string]string{"status": "denied"})
+}
+
+// writeJSON sets the JSON content type and encodes v as the response body.
+func writeJSON(w http.ResponseWriter, v any) {
+	w.Header().Set("Content-Type", "application/json")
+	if err := json.NewEncoder(w).Encode(v); err != nil {
+		// The header (and possibly part of the body) has already been
+		// written, so the status can no longer change; log the failure
+		// instead of discarding it.
+		log.Printf("beacon: writing JSON response: %v", err)
+	}
+}
+
+// --- Live event stream (minimal, no external dependencies) ---
+// Uses Server-Sent Events (SSE) as a simpler alternative that requires no
+// third-party WebSocket library. The endpoint is still mounted at /ws/live
+// for compatibility with the spec, but uses SSE under the hood.
+
+// handleWebSocket streams hub events to the client as SSE "data:" frames. It
+// sends one "connected" frame immediately, then forwards every broadcast
+// message until the subscription channel is closed, the request context is
+// cancelled, or a write to the client fails.
+func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
+	// SSE needs each frame pushed out as it is written.
+	flusher, ok := w.(http.Flusher)
+	if !ok {
+		http.Error(w, "streaming not supported", http.StatusInternalServerError)
+		return
+	}
+
+	w.Header().Set("Content-Type", "text/event-stream")
+	w.Header().Set("Cache-Control", "no-cache")
+	w.Header().Set("Connection", "keep-alive")
+	w.Header().Set("Access-Control-Allow-Origin", "*")
+
+	ch := s.hub.subscribe()
+	defer s.hub.unsubscribe(ch)
+
+	// Send initial ping
+	if _, err := fmt.Fprintf(w, "data: {\"type\":\"connected\",\"time\":\"%s\"}\n\n", time.Now().UTC().Format(time.RFC3339)); err != nil {
+		return
+	}
+	flusher.Flush()
+
+	for {
+		select {
+		case msg, ok := <-ch:
+			if !ok {
+				return
+			}
+			// Stop on the first failed write rather than looping on a
+			// connection that can no longer be written to.
+			if _, err := fmt.Fprintf(w, "data: %s\n\n", msg); err != nil {
+				return
+			}
+			flusher.Flush()
+		case <-r.Context().Done():
+			return
+		}
+	}
+}
+
+// wsHub manages SSE subscribers.
+type wsHub struct {
+	mu      sync.RWMutex             // guards clients
+	clients map[chan []byte]struct{} // set of subscriber channels
+}
+
+// newHub returns a hub with an empty, ready-to-use subscriber set.
+func newHub() *wsHub {
+	return &wsHub{
+		clients: make(map[chan []byte]struct{}),
+	}
+}
+
+// subscribe registers a new subscriber and returns its channel. The channel
+// is buffered (64 messages); broadcast drops messages once it is full.
+func (h *wsHub) subscribe() chan []byte {
+	ch := make(chan []byte, 64)
+	h.mu.Lock()
+	h.clients[ch] = struct{}{}
+	h.mu.Unlock()
+	return ch
+}
+
+// unsubscribe removes ch from the subscriber set and closes it. Both steps
+// happen under the write lock, while broadcast sends under the read lock, so
+// a send and the close cannot overlap.
+func (h *wsHub) unsubscribe(ch chan []byte) {
+	h.mu.Lock()
+	delete(h.clients, ch)
+	close(ch)
+	h.mu.Unlock()
+}
+
+// broadcast offers data to every subscriber without blocking. A subscriber
+// whose channel cannot accept the message right now is skipped for this
+// message and the drop is logged.
+func (h *wsHub) broadcast(data []byte) {
+	h.mu.RLock()
+	defer h.mu.RUnlock()
+
+	for ch := range h.clients {
+		select {
+		case ch <- data:
+		default:
+			log.Printf("beacon: dropping event for slow SSE client")
+		}
+	}
+}
diff --git a/internal/web/server_test.go b/internal/web/server_test.go
new file mode 100644
index 0000000..2f99fad
--- /dev/null
+++ b/internal/web/server_test.go
@@ -0,0 +1,228 @@
+package web
+
+import (
+	"context"
+	"encoding/json"
+	"net/http/httptest"
+	"os"
+	"path/filepath"
+	"testing"
+	"time"
+
+	"github.com/ottojongerius/beacon/internal/audit"
+	"github.com/ottojongerius/beacon/internal/policy"
+)
+
+// setupTestServer opens an audit store backed by a database file in the
+// test's temp directory, builds a policy engine from the default rules, and
+// returns a Server wired to both along with the store. The store is closed
+// through t.Cleanup.
+func setupTestServer(t *testing.T) (*Server, *audit.Store) {
+	t.Helper()
+	db := filepath.Join(t.TempDir(), "test.db")
+	store, err := audit.Open(db)
+	if err != nil {
+		t.Fatalf("Open: %v", err)
+	}
+	t.Cleanup(func() { store.Close() })
+
+	engine := policy.NewEngine(policy.DefaultRules())
+	srv := NewServer(store, engine)
+	return srv, store
+}
+
+// TestHandleStats_Empty checks that handleStats answers 200 and reports zero
+// sessions when nothing has been written to the store.
+func TestHandleStats_Empty(t *testing.T) {
+	srv, _ := setupTestServer(t)
+	req := httptest.NewRequest("GET", "/api/stats", nil)
+	w := httptest.NewRecorder()
+
+	srv.handleStats(w, req)
+
+	if w.Code != 200 {
+		t.Fatalf("status = %d, want 200", w.Code)
+	}
+
+	var stats audit.DashboardStats
+	if err := json.NewDecoder(w.Body).Decode(&stats); err != nil {
+		t.Fatalf("decode: %v", err)
+	}
+	if stats.TotalSessions != 0 {
+		t.Errorf("sessions = %d, want 0", stats.TotalSessions)
+	}
+}
+
+// TestHandleSessions creates one session and checks that handleSessions
+// returns exactly that session with the expected server name.
+func TestHandleSessions(t *testing.T) {
+	srv, store := setupTestServer(t)
+
+	_, err := store.CreateSession("github", "npx github-mcp")
+	if err != nil {
+		t.Fatalf("CreateSession: %v", err)
+	}
+
+	req := httptest.NewRequest("GET", "/api/sessions", nil)
+	w := httptest.NewRecorder()
+	srv.handleSessions(w, req)
+
+	if w.Code != 200 {
+		t.Fatalf("status = %d", w.Code)
+	}
+
+	var sessions []audit.SessionSummary
+	if err := json.NewDecoder(w.Body).Decode(&sessions); err != nil {
+		t.Fatalf("decode: %v", err)
+	}
+	if len(sessions) != 1 {
+		t.Fatalf("expected 1 session, got %d", len(sessions))
+	}
+	if sessions[0].ServerName != "github" {
+		t.Errorf("server name = %q", sessions[0].ServerName)
+	}
+}
+
+// TestHandleToolCalls_WithFilters records a single "delete" tool call and
+// checks that it is returned without filters and excluded by operation=write.
+func TestHandleToolCalls_WithFilters(t *testing.T) {
+	srv, store := setupTestServer(t)
+
+	id, err := store.CreateSession("test", "echo")
+	if err != nil {
+		t.Fatalf("CreateSession: %v", err)
+	}
+
+	msgID, err := store.LogMessage(id, "client_to_server", "1", "tools/call", `{}`)
+	if err != nil {
+		t.Fatalf("LogMessage: %v", err)
+	}
+
+	err = store.CreateToolCall(audit.ToolCallRecord{
+		ID: "tc-1", SessionID: id, RequestMsgID: msgID,
+		ToolName: "delete_repo", Arguments: "{}", OperationType: "delete",
+		RiskScore: 80, RiskReasons: []string{"delete_operation"}, RequestedAt: time.Now().UTC(),
+	})
+	if err != nil {
+		t.Fatalf("CreateToolCall: %v", err)
+	}
+
+	// No filter — should return the call
+	req := httptest.NewRequest("GET", "/api/tool-calls", nil)
+	w := httptest.NewRecorder()
+	srv.handleToolCalls(w, req)
+
+	var calls []audit.ToolCallSummary
+	if err := json.NewDecoder(w.Body).Decode(&calls); err != nil {
+		t.Fatalf("decode: %v", err)
+	}
+	if len(calls) != 1 {
+		t.Fatalf("expected 1 call, got %d", len(calls))
+	}
+
+	// Filter by op=write — should return 0
+	req = httptest.NewRequest("GET", "/api/tool-calls?operation=write", nil)
+	w = httptest.NewRecorder()
+	srv.handleToolCalls(w, req)
+
+	// Reset before decoding so the earlier result cannot mask an empty reply.
+	calls = nil
+	if err := json.NewDecoder(w.Body).Decode(&calls); err != nil {
+		t.Fatalf("decode: %v", err)
+	}
+	if len(calls) != 0 {
+		t.Errorf("expected 0 calls with op=write, got %d", len(calls))
+	}
+}
+
+// TestHandleVerifyChain logs one message and checks that handleVerifyChain
+// reports a valid chain with a total of 1.
+func TestHandleVerifyChain(t *testing.T) {
+	srv, store := setupTestServer(t)
+
+	id, err := store.CreateSession("test", "echo")
+	if err != nil {
+		t.Fatalf("CreateSession: %v", err)
+	}
+	if _, err := store.LogMessage(id, "client_to_server", "1", "tools/call", `{"test":"msg"}`); err != nil {
+		t.Fatalf("LogMessage: %v", err)
+	}
+
+	req := httptest.NewRequest("GET", "/api/chain/verify", nil)
+	w := httptest.NewRecorder()
+	srv.handleVerifyChain(w, req)
+
+	if w.Code != 200 {
+		t.Fatalf("status = %d", w.Code)
+	}
+
+	var status audit.ChainStatus
+	if err := json.NewDecoder(w.Body).Decode(&status); err != nil {
+		t.Fatalf("decode: %v", err)
+	}
+	if !status.Valid {
+		t.Errorf("chain should be valid, got error: %s", status.Error)
+	}
+	if status.Total != 1 {
+		t.Errorf("total = %d, want 1", status.Total)
+	}
+}
+
+// TestDashboardServed requests "/" through the full handler and checks for a
+// 200 response with a body of at least 100 bytes.
+func TestDashboardServed(t *testing.T) {
+	srv, _ := setupTestServer(t)
+	handler := srv.Handler()
+
+	req := httptest.NewRequest("GET", "/", nil)
+	w := httptest.NewRecorder()
+	handler.ServeHTTP(w, req)
+
+	if w.Code != 200 {
+		t.Fatalf("status = %d, want 200", w.Code)
+	}
+
+	body := w.Body.String()
+	if len(body) < 100 {
+		t.Error("expected HTML content from embedded dashboard")
+	}
+}
+
+// TestSSEEndpoint opens /ws/live, cancels the request context after a short
+// delay, and checks that the handler exits and the recorded status is 200.
+func TestSSEEndpoint(t *testing.T) {
+	srv, _ := setupTestServer(t)
+	handler := srv.Handler()
+
+	// Use a cancellable context so we can cleanly stop the SSE handler
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	req := httptest.NewRequest("GET", "/ws/live", nil).WithContext(ctx)
+	w := httptest.NewRecorder()
+
+	// Run in goroutine since SSE blocks
+	done := make(chan struct{})
+	go func() {
+		handler.ServeHTTP(w, req)
+		close(done)
+	}()
+
+	// Cancel the context to stop the handler, then wait for it to finish
+	// before reading the response (avoids data race on ResponseRecorder)
+	time.Sleep(50 * time.Millisecond)
+	cancel()
+	<-done
+
+	if w.Code != 200 {
+		t.Errorf("SSE endpoint returned %d, want 200", w.Code)
+	}
+}
+
+// TestBroadcast subscribes directly to the hub and checks that Broadcast
+// delivers the event as JSON within one second.
+func TestBroadcast(t *testing.T) {
+	srv, _ := setupTestServer(t)
+
+	ch := srv.hub.subscribe()
+	defer srv.hub.unsubscribe(ch)
+
+	srv.Broadcast(map[string]string{"type": "test", "data": "hello"})
+
+	select {
+	case msg := <-ch:
+		var data map[string]string
+		if err := json.Unmarshal(msg, &data); err != nil {
+			t.Fatalf("unmarshal: %v", err)
+		}
+		if data["type"] != "test" {
+			t.Errorf("type = %q, want test", data["type"])
+		}
+	case <-time.After(time.Second):
+		t.Fatal("timeout waiting for broadcast")
+	}
+}
+
+// Keep the "os" import referenced; nothing else in this file uses it.
+var _ = os.DevNull
diff --git a/internal/web/static/index.html b/internal/web/static/index.html
new file mode 100644
index 0000000..dfd89be
--- /dev/null
+++ b/internal/web/static/index.html
@@ -0,0 +1,860 @@
+ + + + + +Beacon — MCP Audit Dashboard + + + +
+

Beacon

+
+ Connecting... +
+
+ +
+
Sessions
-
+
Active
-
+
Tool Calls
-
+
Flagged
-
+
Blocked
-
+
Avg Risk
-
+
+ +
+ + +
+
+ + + + LIVE +
+ +
+ +
+
+ + + + + + + + + +
+ + + + + + + + + + + + + + +
TimeServerToolTypeRiskPolicyDurationActions
+ +
+ + + + +
+
+ + + +