diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 06f6fda..db69a16 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,7 +14,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: '1.22' + go-version: '1.26' - name: Check formatting run: test -z "$(gofmt -l .)" || (echo "Files need formatting:"; gofmt -l .; exit 1) diff --git a/README.md b/README.md index f711778..0c9bcf1 100644 --- a/README.md +++ b/README.md @@ -150,13 +150,14 @@ sqlite3 ~/.beacon/audit.db "SELECT direction, method, jsonrpc_id FROM messages O - **Minimal stderr logging**: error messages do not leak paths, session IDs, or message contents - **Encryption at rest**: sensitive columns (`raw`, `arguments`, `result`, `error`) are encrypted with AES-256-GCM when an encryption key is provided. Enable by setting the `BEACON_ENCRYPTION_KEY` environment variable. Without a key, data is stored in plaintext (redaction still applies). - **Parameterized SQL**: all database queries use parameterized statements — no SQL injection risk +- **Tamper-evident hash chain**: each audit message includes a SHA-256 hash chained to the previous message. Any modification, deletion, or reordering of existing records is detectable. Verify with `VerifyChain()` — works offline against the DB file. ### Assumptions and known limitations - **The audit DB contains sensitive data.** Raw MCP payloads may include file contents or PII. Enable encryption at rest via `BEACON_ENCRYPTION_KEY` and treat `~/.beacon/audit.db` as sensitive. - **Beacon is a local tool, not a network boundary.** It trusts the local user and the MCP client. If an attacker can modify `claude_desktop_config.json`, they can bypass the proxy entirely. - **No DB size limit.** A high-volume MCP server can grow the audit DB indefinitely. Retention policies are planned but not yet implemented. -- **No tamper protection on the audit trail.** A local attacker with file access can modify or delete audit records. Signed/append-only logging is a future consideration. +- **Tamper detection, not prevention.** The hash chain detects modifications or deletions after the fact, but cannot prevent a local attacker with DB access from rewriting the entire chain. For stronger guarantees, consider external log shipping. ## Roadmap diff --git a/go.mod b/go.mod index f3d4660..b1faf7e 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/ottojongerius/beacon -go 1.22 +go 1.26 require ( github.com/google/uuid v1.6.0 diff --git a/internal/audit/store.go b/internal/audit/store.go index c926d11..6fc9425 100644 --- a/internal/audit/store.go +++ b/internal/audit/store.go @@ -1,11 +1,15 @@ package audit import ( + "crypto/sha256" "database/sql" + "encoding/binary" + "encoding/hex" "encoding/json" "fmt" "os" "path/filepath" + "strings" "sync" "time" @@ -29,10 +33,14 @@ CREATE TABLE IF NOT EXISTS messages ( timestamp TIMESTAMP NOT NULL, jsonrpc_id TEXT, method TEXT, - raw TEXT NOT NULL + raw TEXT NOT NULL, + sequence INTEGER NOT NULL, + prev_hash TEXT NOT NULL DEFAULT '', + hash TEXT NOT NULL DEFAULT '' ); CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id, timestamp); +CREATE UNIQUE INDEX IF NOT EXISTS idx_messages_sequence ON messages(sequence); CREATE TABLE IF NOT EXISTS tool_calls ( id TEXT PRIMARY KEY, @@ -76,10 +84,54 @@ CREATE TABLE IF NOT EXISTS intent_tool_calls ( CREATE INDEX IF NOT EXISTS idx_intent_tool_calls_intent ON intent_tool_calls(intent_id); ` +// migrateHashChain adds sequence/prev_hash/hash columns to messages if they don't exist. +// This handles upgrading databases created before the hash chain feature. +func migrateHashChain(db *sql.DB) error { + var hasSequence bool + rows, err := db.Query("PRAGMA table_info(messages)") + if err != nil { + return err + } + defer rows.Close() + for rows.Next() { + var cid int + var name, typ string + var notnull int + var dflt sql.NullString + var pk int + if err := rows.Scan(&cid, &name, &typ, ¬null, &dflt, &pk); err != nil { + return err + } + if name == "sequence" { + hasSequence = true + } + } + if hasSequence { + return nil // already migrated + } + + migrations := []string{ + "ALTER TABLE messages ADD COLUMN sequence INTEGER NOT NULL DEFAULT 0", + "ALTER TABLE messages ADD COLUMN prev_hash TEXT NOT NULL DEFAULT ''", + "ALTER TABLE messages ADD COLUMN hash TEXT NOT NULL DEFAULT ''", + } + for _, m := range migrations { + if _, err := db.Exec(m); err != nil { + // Column may already exist from partial migration + if !strings.Contains(err.Error(), "duplicate column") { + return err + } + } + } + return nil +} + type Store struct { - db *sql.DB - mu sync.Mutex - enc *Encryptor + db *sql.DB + mu sync.Mutex + enc *Encryptor + lastHash string // hash chain: hash of the most recent message + sequence int64 // monotonic sequence number for messages } // Open creates or opens the SQLite database at the given path. @@ -106,6 +158,12 @@ func Open(dbPath string, encryptionKey ...string) (*Store, error) { return nil, fmt.Errorf("run schema migration: %w", err) } + // Migrate existing DBs: add hash chain columns if missing + if err := migrateHashChain(db); err != nil { + db.Close() + return nil, fmt.Errorf("migrate hash chain columns: %w", err) + } + // Restrict DB file permissions — audit data contains raw MCP payloads. // Done after schema migration so the file is guaranteed to exist. if err := os.Chmod(dbPath, 0600); err != nil { @@ -123,7 +181,19 @@ func Open(dbPath string, encryptionKey ...string) (*Store, error) { } } - return &Store{db: db, enc: enc}, nil + store := &Store{db: db, enc: enc} + + // Resume hash chain from the most recent message (supports restarts) + var lastHash sql.NullString + var seq sql.NullInt64 + err = db.QueryRow("SELECT hash, sequence FROM messages ORDER BY sequence DESC LIMIT 1").Scan(&lastHash, &seq) + if err == nil { + store.lastHash = lastHash.String + store.sequence = seq.Int64 + } + // err == sql.ErrNoRows is fine — empty DB, chain starts fresh + + return store, nil } // Close closes the database connection. @@ -163,6 +233,7 @@ const maxStoredMessageSize = 512 * 1024 // 512KB — truncate raw payloads beyon // LogMessage records a single JSON-RPC message and returns its ID. // Raw payloads exceeding 512KB are truncated to limit DB growth from large responses. +// Each message is linked to the previous via a SHA-256 hash chain for tamper detection. func (s *Store) LogMessage(sessionID, direction, jsonrpcID, method, raw string) (string, error) { id := uuid.New().String() stored := Redact(raw) @@ -173,9 +244,16 @@ func (s *Store) LogMessage(sessionID, direction, jsonrpcID, method, raw string) s.mu.Lock() defer s.mu.Unlock() + s.sequence++ + prevHash := s.lastHash + ts := time.Now().UTC() + hash := computeHash(id, sessionID, direction, jsonrpcID, method, stored, s.sequence, prevHash) + s.lastHash = hash + _, err := s.db.Exec( - "INSERT INTO messages (id, session_id, direction, timestamp, jsonrpc_id, method, raw) VALUES (?, ?, ?, ?, ?, ?, ?)", - id, sessionID, direction, time.Now().UTC(), nullIfEmpty(jsonrpcID), nullIfEmpty(method), s.encrypt(stored), + "INSERT INTO messages (id, session_id, direction, timestamp, jsonrpc_id, method, raw, sequence, prev_hash, hash) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + id, sessionID, direction, ts, nullIfEmpty(jsonrpcID), nullIfEmpty(method), s.encrypt(stored), + s.sequence, prevHash, hash, ) if err != nil { return "", err @@ -183,6 +261,24 @@ func (s *Store) LogMessage(sessionID, direction, jsonrpcID, method, raw string) return id, nil } +// computeHash creates a SHA-256 hash of the message content chained with the previous hash. +// Uses length-prefixed encoding to prevent boundary-shifting attacks. +func computeHash(id, sessionID, direction, jsonrpcID, method, raw string, sequence int64, prevHash string) string { + h := sha256.New() + // Length-prefix each field to prevent ambiguous boundaries + for _, field := range []string{prevHash, id, sessionID, direction, jsonrpcID, method, raw} { + var lenBuf [8]byte + binary.BigEndian.PutUint64(lenBuf[:], uint64(len(field))) + h.Write(lenBuf[:]) + h.Write([]byte(field)) + } + // Include sequence number + var seqBuf [8]byte + binary.BigEndian.PutUint64(seqBuf[:], uint64(sequence)) + h.Write(seqBuf[:]) + return hex.EncodeToString(h.Sum(nil)) +} + // ToolCallRecord holds the data for creating a tool call entry. type ToolCallRecord struct { ID string @@ -272,6 +368,80 @@ func (s *Store) AddToolCallToIntent(intentID, toolCallID string, sequenceOrder i return err } +// ChainStatus holds the result of a hash chain verification. +type ChainStatus struct { + Total int // total messages checked + Valid bool // true if the entire chain is intact + BrokenAt int // sequence number where the chain broke (0 if valid) + Error string // description of the break +} + +// VerifyChain walks the message hash chain and checks for tampering. +// It recomputes each hash from stored content and verifies linkage. +// If encryption is enabled, raw content is stored encrypted — the hash +// was computed on pre-encryption content, so this method decrypts before verifying. +func (s *Store) VerifyChain() (*ChainStatus, error) { + rows, err := s.db.Query( + "SELECT id, session_id, direction, jsonrpc_id, method, raw, sequence, prev_hash, hash FROM messages ORDER BY sequence ASC", + ) + if err != nil { + return nil, fmt.Errorf("query messages: %w", err) + } + defer rows.Close() + + status := &ChainStatus{Valid: true} + expectedPrevHash := "" + + for rows.Next() { + var id, sessionID, direction, raw, prevHash, hash string + var jsonrpcID, method sql.NullString + var seq int64 + if err := rows.Scan(&id, &sessionID, &direction, &jsonrpcID, &method, &raw, &seq, &prevHash, &hash); err != nil { + return nil, fmt.Errorf("scan row: %w", err) + } + status.Total++ + + // Decrypt raw if encrypted + decrypted, err := s.decryptVerify(raw) + if err != nil { + status.Valid = false + status.BrokenAt = int(seq) + status.Error = fmt.Sprintf("sequence %d: decryption failed (wrong key or corrupted data)", seq) + return status, nil + } + + // Check prev_hash linkage + if prevHash != expectedPrevHash { + status.Valid = false + status.BrokenAt = int(seq) + status.Error = fmt.Sprintf("sequence %d: prev_hash mismatch (expected %s, got %s)", seq, expectedPrevHash, prevHash) + return status, nil + } + + // Recompute and verify hash + computed := computeHash(id, sessionID, direction, jsonrpcID.String, method.String, decrypted, seq, prevHash) + if hash != computed { + status.Valid = false + status.BrokenAt = int(seq) + status.Error = fmt.Sprintf("sequence %d: hash mismatch (record may have been modified)", seq) + return status, nil + } + + expectedPrevHash = hash + } + + return status, rows.Err() +} + +// decryptVerify decrypts a string if encryption is enabled, returning an error on failure +// (unlike decrypt which silently falls back to ciphertext). +func (s *Store) decryptVerify(ciphertext string) (string, error) { + if s.enc == nil { + return ciphertext, nil + } + return s.enc.Decrypt(ciphertext) +} + // encrypt encrypts a string if encryption is enabled. Returns plaintext otherwise. func (s *Store) encrypt(plaintext string) string { if s.enc == nil { diff --git a/internal/audit/store_test.go b/internal/audit/store_test.go index 1c6ad2b..2fc4459 100644 --- a/internal/audit/store_test.go +++ b/internal/audit/store_test.go @@ -300,3 +300,197 @@ func TestCompleteToolCall(t *testing.T) { t.Errorf("duration_ms = %d, want 50", durationMs) } } + +func TestVerifyChain_ValidChain(t *testing.T) { + store, err := Open(tempDB(t)) + if err != nil { + t.Fatalf("Open: %v", err) + } + defer store.Close() + + sessionID, err := store.CreateSession("test", "cat") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + if _, err := store.LogMessage(sessionID, "client_to_server", "1", "initialize", `{"jsonrpc":"2.0"}`); err != nil { + t.Fatalf("LogMessage: %v", err) + } + if _, err := store.LogMessage(sessionID, "server_to_client", "1", "", `{"jsonrpc":"2.0","result":{}}`); err != nil { + t.Fatalf("LogMessage: %v", err) + } + if _, err := store.LogMessage(sessionID, "client_to_server", "2", "tools/call", `{"jsonrpc":"2.0"}`); err != nil { + t.Fatalf("LogMessage: %v", err) + } + + status, err := store.VerifyChain() + if err != nil { + t.Fatalf("VerifyChain: %v", err) + } + if !status.Valid { + t.Errorf("chain should be valid, got error at sequence %d: %s", status.BrokenAt, status.Error) + } + if status.Total != 3 { + t.Errorf("total = %d, want 3", status.Total) + } +} + +func TestVerifyChain_DetectsModifiedContent(t *testing.T) { + store, err := Open(tempDB(t)) + if err != nil { + t.Fatalf("Open: %v", err) + } + defer store.Close() + + sessionID, err := store.CreateSession("test", "cat") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + if _, err := store.LogMessage(sessionID, "client_to_server", "1", "initialize", `{"jsonrpc":"2.0"}`); err != nil { + t.Fatalf("LogMessage: %v", err) + } + if _, err := store.LogMessage(sessionID, "server_to_client", "1", "", `{"jsonrpc":"2.0","result":{}}`); err != nil { + t.Fatalf("LogMessage: %v", err) + } + + // Tamper with a message's raw content + _, err = store.db.Exec("UPDATE messages SET raw = 'tampered' WHERE sequence = 1") + if err != nil { + t.Fatalf("tamper: %v", err) + } + + status, err := store.VerifyChain() + if err != nil { + t.Fatalf("VerifyChain: %v", err) + } + if status.Valid { + t.Error("chain should be invalid after tampering") + } + if status.BrokenAt != 1 { + t.Errorf("BrokenAt = %d, want 1", status.BrokenAt) + } +} + +func TestVerifyChain_DetectsDeletedMessage(t *testing.T) { + store, err := Open(tempDB(t)) + if err != nil { + t.Fatalf("Open: %v", err) + } + defer store.Close() + + sessionID, err := store.CreateSession("test", "cat") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + if _, err := store.LogMessage(sessionID, "client_to_server", "1", "initialize", `{}`); err != nil { + t.Fatalf("LogMessage: %v", err) + } + if _, err := store.LogMessage(sessionID, "server_to_client", "1", "", `{}`); err != nil { + t.Fatalf("LogMessage: %v", err) + } + if _, err := store.LogMessage(sessionID, "client_to_server", "2", "tools/call", `{}`); err != nil { + t.Fatalf("LogMessage: %v", err) + } + + // Delete the middle message — breaks the prev_hash linkage + _, err = store.db.Exec("DELETE FROM messages WHERE sequence = 2") + if err != nil { + t.Fatalf("delete: %v", err) + } + + status, err := store.VerifyChain() + if err != nil { + t.Fatalf("VerifyChain: %v", err) + } + if status.Valid { + t.Error("chain should be invalid after deleting a message") + } +} + +func TestVerifyChain_EmptyDB(t *testing.T) { + store, err := Open(tempDB(t)) + if err != nil { + t.Fatalf("Open: %v", err) + } + defer store.Close() + + status, err := store.VerifyChain() + if err != nil { + t.Fatalf("VerifyChain: %v", err) + } + if !status.Valid { + t.Error("empty chain should be valid") + } + if status.Total != 0 { + t.Errorf("total = %d, want 0", status.Total) + } +} + +func TestVerifyChain_WithEncryption(t *testing.T) { + store, err := Open(tempDB(t), "test-key") + if err != nil { + t.Fatalf("Open: %v", err) + } + defer store.Close() + + sessionID, err := store.CreateSession("test", "cat") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + if _, err := store.LogMessage(sessionID, "client_to_server", "1", "initialize", `{"jsonrpc":"2.0"}`); err != nil { + t.Fatalf("LogMessage: %v", err) + } + if _, err := store.LogMessage(sessionID, "server_to_client", "1", "", `{"result":"ok"}`); err != nil { + t.Fatalf("LogMessage: %v", err) + } + + status, err := store.VerifyChain() + if err != nil { + t.Fatalf("VerifyChain: %v", err) + } + if !status.Valid { + t.Errorf("chain should be valid with encryption, got error at sequence %d: %s", status.BrokenAt, status.Error) + } +} + +func TestHashChain_SurvivesRestart(t *testing.T) { + dbPath := tempDB(t) + + // First "session" — write some messages + store1, err := Open(dbPath) + if err != nil { + t.Fatalf("Open: %v", err) + } + sessionID, err := store1.CreateSession("test", "cat") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + if _, err := store1.LogMessage(sessionID, "client_to_server", "1", "initialize", `{}`); err != nil { + t.Fatalf("LogMessage: %v", err) + } + store1.Close() + + // Second "session" — reopen and write more + store2, err := Open(dbPath) + if err != nil { + t.Fatalf("Open: %v", err) + } + defer store2.Close() + sessionID2, err := store2.CreateSession("test", "cat") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + if _, err := store2.LogMessage(sessionID2, "client_to_server", "1", "tools/call", `{}`); err != nil { + t.Fatalf("LogMessage: %v", err) + } + + status, err := store2.VerifyChain() + if err != nil { + t.Fatalf("VerifyChain: %v", err) + } + if !status.Valid { + t.Errorf("chain should survive restart, got error at sequence %d: %s", status.BrokenAt, status.Error) + } + if status.Total != 2 { + t.Errorf("total = %d, want 2", status.Total) + } +}