diff --git a/cmd/ja4monitor/main.go b/cmd/ja4monitor/main.go index 4fd08a5..79a511e 100644 --- a/cmd/ja4monitor/main.go +++ b/cmd/ja4monitor/main.go @@ -801,7 +801,8 @@ func runTUI(src capture.PacketSource, eng *engine.Engine, firstSeen *engine.Firs evaluator.SuppressedCount() } - router := tui.NewRouter(connChan, evaluator.Subscribe(128), statsFn, engineFn, store, name) + router := tui.NewRouter(connChan, evaluator.Subscribe(128), statsFn, engineFn, store, name). + WithRuleStatsFn(evaluator.CustomRuleStats) evaluator.Start() p := tea.NewProgram(router, tea.WithAltScreen()) diff --git a/internal/anomaly/evaluator.go b/internal/anomaly/evaluator.go index 0be3dca..2a45c17 100644 --- a/internal/anomaly/evaluator.go +++ b/internal/anomaly/evaluator.go @@ -297,3 +297,25 @@ func (e *Evaluator) RuleCount() int { } return count } + +// RuleStat is a snapshot of one custom rule's activity for the summary dashboard. +type RuleStat struct { + Name string + FireCount int64 +} + +// CustomRuleStats returns a fire-count snapshot for every active custom rule. +// Returns nil when no custom rules are loaded. Safe to call from any goroutine. +func (e *Evaluator) CustomRuleStats() []RuleStat { + cr := e.customRules.Load() + if cr == nil || len(*cr) == 0 { + return nil + } + stats := make([]RuleStat, 0, len(*cr)) + for _, rule := range *cr { + if tr, ok := rule.(*ThresholdRule); ok { + stats = append(stats, RuleStat{Name: rule.Name(), FireCount: tr.FireCount()}) + } + } + return stats +} diff --git a/internal/anomaly/evaluator_test.go b/internal/anomaly/evaluator_test.go index 8936461..0eb82c6 100644 --- a/internal/anomaly/evaluator_test.go +++ b/internal/anomaly/evaluator_test.go @@ -312,3 +312,171 @@ func TestReloadConfigSwapsAllowlist(t *testing.T) { // correct } } + +// ── ReloadCustomRules ──────────────────────────────────────────────────────── + +func TestReloadCustomRules_LoadsAndEvaluates(t *testing.T) { + ev := newTestEvaluator() + sub := ev.Subscribe(16) + ev.Start() + + cfgs := []config.CustomRuleConfig{{ + Name: "rapid_conns", + Enabled: true, + Severity: "high", + Type: "threshold", + Aggregation: "count", + Field: "connection", + GroupBy: "src_ip", + Threshold: 1, + Window: "60s", + Condition: "gt", + }} + if err := ev.ReloadCustomRules(cfgs); err != nil { + t.Fatalf("ReloadCustomRules: %v", err) + } + + // Two evaluations exceed the gt-1 threshold → alert should fire. + conn := tracker.NewConnection("10.0.0.1", 1234, "1.1.1.1", 443, "tcp", time.Now()) + ev.Evaluate(conn) // count = 1 + ev.Evaluate(conn) // count = 2 > 1 → fire + time.Sleep(20 * time.Millisecond) + + select { + case a := <-sub: + if a.Rule != "rapid_conns" { + t.Errorf("unexpected rule name: %q", a.Rule) + } + default: + t.Error("expected custom rule alert, got none") + } +} + +func TestReloadCustomRules_InvalidConfigKeepsOldRules(t *testing.T) { + ev := newTestEvaluator() + + // Load a valid rule first. + valid := []config.CustomRuleConfig{{ + Name: "valid", + Enabled: true, + Severity: "high", + Type: "threshold", + Aggregation: "count", + Field: "connection", + GroupBy: "src_ip", + Threshold: 1, + Window: "60s", + Condition: "gt", + }} + if err := ev.ReloadCustomRules(valid); err != nil { + t.Fatalf("initial load failed: %v", err) + } + stats := ev.CustomRuleStats() + if len(stats) != 1 || stats[0].Name != "valid" { + t.Fatalf("expected 1 rule 'valid', got %v", stats) + } + + // Attempt reload with an invalid window duration — must fail without + // replacing the current rule set. + bad := []config.CustomRuleConfig{{ + Name: "bad", + Enabled: true, + Severity: "high", + Type: "threshold", + Aggregation: "count", + Field: "connection", + GroupBy: "src_ip", + Threshold: 1, + Window: "not-a-duration", // invalid + Condition: "gt", + }} + if err := ev.ReloadCustomRules(bad); err == nil { + t.Fatal("expected error for invalid window, got nil") + } + + // Old rule set must still be intact. + stats = ev.CustomRuleStats() + if len(stats) != 1 || stats[0].Name != "valid" { + t.Errorf("expected old rule 'valid' preserved after failed reload, got %v", stats) + } +} + +func TestReloadCustomRules_ClearRules(t *testing.T) { + ev := newTestEvaluator() + + cfgs := []config.CustomRuleConfig{{ + Name: "r", + Enabled: true, + Severity: "high", + Type: "threshold", + Aggregation: "count", + Field: "connection", + GroupBy: "src_ip", + Threshold: 1, + Window: "60s", + Condition: "gt", + }} + if err := ev.ReloadCustomRules(cfgs); err != nil { + t.Fatalf("load: %v", err) + } + + // Passing nil clears all custom rules. + if err := ev.ReloadCustomRules(nil); err != nil { + t.Fatalf("clear: %v", err) + } + + if stats := ev.CustomRuleStats(); stats != nil { + t.Errorf("after clear, expected nil stats, got %v", stats) + } +} + +// ── CustomRuleStats ────────────────────────────────────────────────────────── + +func TestCustomRuleStats_NilWhenNoRules(t *testing.T) { + ev := newTestEvaluator() + if stats := ev.CustomRuleStats(); stats != nil { + t.Errorf("expected nil with no custom rules, got %v", stats) + } +} + +func TestCustomRuleStats_ReturnsFireCounts(t *testing.T) { + ev := newTestEvaluator() + ev.Start() + + cfgs := []config.CustomRuleConfig{ + { + Name: "rule_a", Enabled: true, Severity: "high", Type: "threshold", + Aggregation: "count", Field: "connection", GroupBy: "src_ip", + Threshold: 1, Window: "60s", Condition: "gt", + }, + { + Name: "rule_b", Enabled: true, Severity: "medium", Type: "threshold", + Aggregation: "count", Field: "connection", GroupBy: "src_ip", + Threshold: 5, Window: "60s", Condition: "gt", + }, + } + if err := ev.ReloadCustomRules(cfgs); err != nil { + t.Fatalf("ReloadCustomRules: %v", err) + } + + // Fire rule_a (threshold 1) but not rule_b (threshold 5). + conn := tracker.NewConnection("10.0.0.1", 1234, "1.1.1.1", 443, "tcp", time.Now()) + ev.Evaluate(conn) // count=1 + ev.Evaluate(conn) // count=2 > 1 → rule_a fires + + stats := ev.CustomRuleStats() + if len(stats) != 2 { + t.Fatalf("expected 2 stats entries, got %d", len(stats)) + } + + byName := make(map[string]RuleStat, len(stats)) + for _, s := range stats { + byName[s.Name] = s + } + if byName["rule_a"].FireCount != 1 { + t.Errorf("rule_a FireCount = %d, want 1", byName["rule_a"].FireCount) + } + if byName["rule_b"].FireCount != 0 { + t.Errorf("rule_b FireCount = %d, want 0", byName["rule_b"].FireCount) + } +} diff --git a/internal/anomaly/threshold.go b/internal/anomaly/threshold.go index e399c7e..29adbc5 100644 --- a/internal/anomaly/threshold.go +++ b/internal/anomaly/threshold.go @@ -5,6 +5,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/Crank-Git/ja4monitor/internal/config" @@ -60,10 +61,10 @@ type ThresholdRule struct { mu sync.Mutex buckets map[string]*windowBucket - // fireCount is incremented without the lock every time this rule fires. - // The summary dashboard reads it with a relaxed load — display staleness - // of a single increment is acceptable. - fireCount int64 // accessed via sync/atomic in FireCount + // fireCount is incremented every time this rule fires. atomic.Int64 because + // the summary dashboard reads it from the bubbletea goroutine concurrently + // with engine shards calling Evaluate (and potentially incrementing it). + fireCount atomic.Int64 } // newThresholdRule constructs a ThresholdRule from a validated CustomRuleConfig. @@ -106,12 +107,7 @@ func (r *ThresholdRule) Name() string { return r.cfg.Name } // FireCount returns the total number of times this rule has fired. // Safe to call from any goroutine without holding mu. -func (r *ThresholdRule) FireCount() int64 { - // Direct load is safe on all supported architectures (64-bit aligned field). - // We avoid importing sync/atomic here to keep the method simple; callers - // that need strict ordering should use their own fence. - return r.fireCount -} +func (r *ThresholdRule) FireCount() int64 { return r.fireCount.Load() } // Evaluate implements Rule. Called on every packet for a connection. // @@ -169,7 +165,7 @@ func (r *ThresholdRule) Evaluate(conn *tracker.Connection) []Alert { return nil } - r.fireCount++ // relaxed increment; see FireCount comment + r.fireCount.Add(1) return []Alert{r.buildAlert(conn, groupKey, count)} } diff --git a/internal/anomaly/threshold_test.go b/internal/anomaly/threshold_test.go index be76af2..57a855d 100644 --- a/internal/anomaly/threshold_test.go +++ b/internal/anomaly/threshold_test.go @@ -98,6 +98,40 @@ func TestThresholdRule_CountEQ_FiresExactly(t *testing.T) { } } +func TestThresholdRule_CountLT_FiresWhenBelowThreshold(t *testing.T) { + // lt fires when count < threshold. With threshold=3 and count_distinct + // over a compound group, the simplest test: one evaluation produces + // count=1, which is < 3. + rule := makeRule(t, "r", "count", "connection", "src_ip", "lt", 3, "60s") + conn := makeConn("10.0.0.1", "1.2.3.4", 80) + + // First call: count = 1, which is < 3 → should fire + alerts := rule.Evaluate(conn) + if len(alerts) != 1 { + t.Fatalf("lt: expected 1 alert when count(1) < threshold(3), got %d", len(alerts)) + } +} + +func TestThresholdRule_CountLTE_FiresAtThreshold(t *testing.T) { + // lte fires when count <= threshold. With threshold=1, the first + // evaluation (count=1) should fire; subsequent calls within the window + // must not (once-per-window dedup via lastFired). + rule := makeRule(t, "r", "count", "connection", "src_ip", "lte", 1, "60s") + conn := makeConn("10.0.0.1", "1.2.3.4", 80) + + // First call: count = 1 ≤ 1 → should fire + alerts := rule.Evaluate(conn) + if len(alerts) != 1 { + t.Fatalf("lte: expected 1 alert when count(1) <= threshold(1), got %d", len(alerts)) + } + + // Second call within same window: count = 2 > 1 → must not fire + alerts = rule.Evaluate(conn) + if len(alerts) != 0 { + t.Fatalf("lte: should not fire when count(2) > threshold(1), got %d alerts", len(alerts)) + } +} + func TestThresholdRule_OncePerWindow(t *testing.T) { rule := makeRule(t, "r", "count", "connection", "src_ip", "gt", 1, "60s") conn := makeConn("10.0.0.1", "1.2.3.4", 80) diff --git a/internal/storage/sqlite.go b/internal/storage/sqlite.go index 405f03e..e81b2d7 100644 --- a/internal/storage/sqlite.go +++ b/internal/storage/sqlite.go @@ -413,6 +413,136 @@ func (s *Store) Query(f filter.Filter, timeout time.Duration) (results []*tracke return results, false, rows.Err() } +// QueryTimeWindow returns connections whose lifetimes overlap [start, end]: +// any connection where first_seen <= end AND last_seen >= start. Results are +// ordered by last_seen DESC, capped at 500. Uses the read-only handle so +// concurrent main-handle writes are not blocked. +// +// A 5-second default timeout applies (caller may override). timedOut=true +// means the result set is partial. +func (s *Store) QueryTimeWindow(start, end time.Time, timeout time.Duration) (results []*tracker.Connection, timedOut bool, err error) { + if timeout <= 0 { + timeout = 5 * time.Second + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + const query = `SELECT c.id, c.src_ip, c.src_port, c.dst_ip, c.dst_port, c.protocol, + c.first_seen, c.last_seen, c.packet_count, c.byte_count, + c.fingerprints, c.identified_app, c.is_known_bad + FROM connections c + WHERE c.first_seen <= ? AND c.last_seen >= ? + ORDER BY c.last_seen DESC LIMIT 500` + + rows, err := s.readOnly.QueryContext(ctx, query, + end.UTC().Format(time.RFC3339), + start.UTC().Format(time.RFC3339), + ) + if err != nil { + if ctx.Err() == context.DeadlineExceeded { + return nil, true, nil + } + return nil, false, err + } + defer rows.Close() + + for rows.Next() { + if ctx.Err() == context.DeadlineExceeded { + return results, true, nil + } + c, err := scanConnection(rows) + if err != nil { + s.reporter.IncFlushFailure() + continue + } + results = append(results, c) + } + if ctx.Err() == context.DeadlineExceeded { + return results, true, nil + } + return results, false, rows.Err() +} + +// DiffSpikeMultiplier and DiffMinDelta are the thresholds for spike/drop +// detection in DiffWindows. A fingerprint "spikes" in B when B count is at +// least DiffSpikeMultiplier times A count AND the absolute delta exceeds +// DiffMinDelta (to filter low-frequency noise). +const ( + DiffSpikeMultiplier = 3 + DiffMinDelta = 10 +) + +// DiffWindows counts fingerprint occurrences in two adjacent time windows and +// returns both maps so the caller can compute new/gone/spiked/dropped entries. +// +// Window layout (non-overlapping, A precedes B): +// +// [aStart ─────── aEnd][bStart ─────── bEnd] +// └── typically aEnd == bStart +// +// Map structure: [fpType][fpValue] = connection count. +// Uses the read-only handle; a 5-second default timeout applies. +func (s *Store) DiffWindows( + aStart, aEnd time.Time, + bStart, bEnd time.Time, + timeout time.Duration, +) (aMap, bMap map[string]map[string]int, err error) { + if timeout <= 0 { + timeout = 5 * time.Second + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + aMap, err = s.countFingerprintsInWindow(ctx, aStart, aEnd) + if err != nil { + return nil, nil, err + } + bMap, err = s.countFingerprintsInWindow(ctx, bStart, bEnd) + if err != nil { + return nil, nil, err + } + return aMap, bMap, nil +} + +// countFingerprintsInWindow returns a [fpType][fpValue]=count map for all +// fingerprints seen in connections that overlap the given time window. +func (s *Store) countFingerprintsInWindow(ctx context.Context, start, end time.Time) (map[string]map[string]int, error) { + const query = `SELECT fi.fp_type, fi.fp_value, COUNT(*) AS cnt + FROM fingerprint_index fi + JOIN connections c ON c.id = fi.conn_id + WHERE c.first_seen <= ? AND c.last_seen >= ? + GROUP BY fi.fp_type, fi.fp_value` + + rows, err := s.readOnly.QueryContext(ctx, query, + end.UTC().Format(time.RFC3339), + start.UTC().Format(time.RFC3339), + ) + if err != nil { + if ctx.Err() == context.DeadlineExceeded { + return nil, context.DeadlineExceeded + } + return nil, err + } + defer rows.Close() + + out := make(map[string]map[string]int) + for rows.Next() { + if ctx.Err() == context.DeadlineExceeded { + return nil, context.DeadlineExceeded + } + var fpType, fpValue string + var cnt int + if err := rows.Scan(&fpType, &fpValue, &cnt); err != nil { + return nil, err + } + if out[fpType] == nil { + out[fpType] = make(map[string]int) + } + out[fpType][fpValue] = cnt + } + return out, rows.Err() +} + // scanConnection reads a single row from either the read-only or main // handle into a Connection. Handles both v1-legacy and v2 fingerprints // JSON formats (v1 backfill might not have finished yet). diff --git a/internal/storage/sqlite_test.go b/internal/storage/sqlite_test.go index cfd60a3..9afc079 100644 --- a/internal/storage/sqlite_test.go +++ b/internal/storage/sqlite_test.go @@ -1,6 +1,7 @@ package storage import ( + "fmt" "os" "path/filepath" "testing" @@ -213,3 +214,233 @@ func TestStore_BookmarkPersistsAcrossReopen(t *testing.T) { t.Error("bookmark did not survive store close+reopen") } } + +// ── QueryTimeWindow ────────────────────────────────────────────────────────── + +// insertConn is a test helper that queues and flushes a single connection. +func insertConn(t *testing.T, s *Store, srcIP string, firstSeen, lastSeen time.Time) *tracker.Connection { + t.Helper() + conn := tracker.NewConnection(srcIP, 12345, "1.1.1.1", 443, "tcp", firstSeen) + conn.LastSeen = lastSeen + s.QueueConnection(conn) + s.Flush() + return conn +} + +func TestQueryTimeWindow_MatchesOverlappingConnections(t *testing.T) { + store, err := NewStore(tempDB(t)) + if err != nil { + t.Fatal(err) + } + defer store.Close() + + now := time.Now().UTC().Truncate(time.Second) + // inside: first_seen < window start, last_seen inside window → overlaps + inside := insertConn(t, store, "10.0.0.1", now.Add(-2*time.Hour), now.Add(-30*time.Minute)) + // outside-before: entirely before the window + insertConn(t, store, "10.0.0.2", now.Add(-4*time.Hour), now.Add(-2*time.Hour).Add(-time.Second)) + // outside-after: entirely after the window + insertConn(t, store, "10.0.0.3", now.Add(time.Hour), now.Add(2*time.Hour)) + // straddles: started before, ended after — still overlaps + straddles := insertConn(t, store, "10.0.0.4", now.Add(-3*time.Hour), now) + + windowStart := now.Add(-90 * time.Minute) + windowEnd := now.Add(-15 * time.Minute) + + results, timedOut, err := store.QueryTimeWindow(windowStart, windowEnd, 5*time.Second) + if err != nil { + t.Fatalf("QueryTimeWindow error: %v", err) + } + if timedOut { + t.Error("QueryTimeWindow should not time out on a tiny DB") + } + + ids := make(map[string]bool, len(results)) + for _, c := range results { + ids[c.ID] = true + } + if !ids[inside.ID] { + t.Errorf("expected conn %s (inside/overlapping) in results", inside.ID) + } + if !ids[straddles.ID] { + t.Errorf("expected conn %s (straddles window) in results", straddles.ID) + } + if len(results) != 2 { + t.Errorf("expected 2 results, got %d (IDs: %v)", len(results), results) + } +} + +func TestQueryTimeWindow_EmptyWindowReturnsEmpty(t *testing.T) { + store, err := NewStore(tempDB(t)) + if err != nil { + t.Fatal(err) + } + defer store.Close() + + now := time.Now().UTC() + // Insert a connection entirely outside the query window. + insertConn(t, store, "10.0.0.1", now.Add(-3*time.Hour), now.Add(-2*time.Hour)) + + results, _, err := store.QueryTimeWindow(now.Add(-time.Hour), now, 5*time.Second) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(results) != 0 { + t.Errorf("expected 0 results for empty window, got %d", len(results)) + } +} + +func TestQueryTimeWindow_EmptyDBReturnsEmpty(t *testing.T) { + store, err := NewStore(tempDB(t)) + if err != nil { + t.Fatal(err) + } + defer store.Close() + + now := time.Now().UTC() + results, timedOut, err := store.QueryTimeWindow(now.Add(-time.Hour), now, 5*time.Second) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if timedOut { + t.Error("should not time out on empty DB") + } + if len(results) != 0 { + t.Errorf("expected 0 results, got %d", len(results)) + } +} + +func TestQueryTimeWindow_DefaultTimeout(t *testing.T) { + store, err := NewStore(tempDB(t)) + if err != nil { + t.Fatal(err) + } + defer store.Close() + + now := time.Now().UTC() + // Passing timeout=0 should use the 5s default — no error. + _, _, err = store.QueryTimeWindow(now.Add(-time.Hour), now, 0) + if err != nil { + t.Fatalf("default-timeout query returned error: %v", err) + } +} + +// ── DiffWindows ────────────────────────────────────────────────────────────── + +// insertConnWithFP queues and flushes a connection with a single fingerprint. +func insertConnWithFP(t *testing.T, s *Store, srcIP, fpType, fpValue string, firstSeen, lastSeen time.Time) { + t.Helper() + conn := tracker.NewConnection(srcIP, 12345, "1.1.1.1", 443, "tcp", firstSeen) + conn.LastSeen = lastSeen + conn.AddFingerprint(fpType, fpValue, firstSeen, 1) + s.QueueConnection(conn) + s.Flush() +} + +func TestDiffWindows_NewAndGone(t *testing.T) { + store, err := NewStore(tempDB(t)) + if err != nil { + t.Fatal(err) + } + defer store.Close() + + now := time.Now().UTC().Truncate(time.Second) + // Connection in window A only. + insertConnWithFP(t, store, "10.0.0.1", "ja4", "t13d_baseline", + now.Add(-2*time.Hour), now.Add(-90*time.Minute)) + // Connection in window B only. + insertConnWithFP(t, store, "10.0.0.2", "ja4", "t13d_new", + now.Add(-30*time.Minute), now.Add(-10*time.Minute)) + + aStart := now.Add(-2 * time.Hour) + aEnd := now.Add(-time.Hour) + bStart := now.Add(-time.Hour) + bEnd := now + + aMap, bMap, err := store.DiffWindows(aStart, aEnd, bStart, bEnd, 5*time.Second) + if err != nil { + t.Fatalf("DiffWindows error: %v", err) + } + + if aMap["ja4"]["t13d_baseline"] != 1 { + t.Errorf("baseline fp should appear in A with count 1, got %d", aMap["ja4"]["t13d_baseline"]) + } + if bMap["ja4"]["t13d_baseline"] != 0 { + t.Errorf("baseline fp should not appear in B, got %d", bMap["ja4"]["t13d_baseline"]) + } + if bMap["ja4"]["t13d_new"] != 1 { + t.Errorf("new fp should appear in B with count 1, got %d", bMap["ja4"]["t13d_new"]) + } + if aMap["ja4"]["t13d_new"] != 0 { + t.Errorf("new fp should not appear in A, got %d", aMap["ja4"]["t13d_new"]) + } +} + +func TestDiffWindows_CountsMultipleConns(t *testing.T) { + store, err := NewStore(tempDB(t)) + if err != nil { + t.Fatal(err) + } + defer store.Close() + + now := time.Now().UTC().Truncate(time.Second) + bStart := now.Add(-time.Hour) + bEnd := now + + // 3 connections with the same fingerprint in window B. + for i := 0; i < 3; i++ { + insertConnWithFP(t, store, fmt.Sprintf("10.0.0.%d", i+1), "ja4", "t13d_common", + now.Add(-30*time.Minute), now.Add(-10*time.Minute)) + } + + _, bMap, err := store.DiffWindows( + now.Add(-2*time.Hour), now.Add(-time.Hour), + bStart, bEnd, + 5*time.Second, + ) + if err != nil { + t.Fatalf("DiffWindows error: %v", err) + } + if bMap["ja4"]["t13d_common"] != 3 { + t.Errorf("expected count=3 for 3 connections, got %d", bMap["ja4"]["t13d_common"]) + } +} + +func TestDiffWindows_EmptyDB(t *testing.T) { + store, err := NewStore(tempDB(t)) + if err != nil { + t.Fatal(err) + } + defer store.Close() + + now := time.Now().UTC() + aMap, bMap, err := store.DiffWindows( + now.Add(-2*time.Hour), now.Add(-time.Hour), + now.Add(-time.Hour), now, + 5*time.Second, + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(aMap) != 0 || len(bMap) != 0 { + t.Errorf("empty DB: expected empty maps, got aMap=%v bMap=%v", aMap, bMap) + } +} + +func TestDiffWindows_DefaultTimeout(t *testing.T) { + store, err := NewStore(tempDB(t)) + if err != nil { + t.Fatal(err) + } + defer store.Close() + + now := time.Now().UTC() + _, _, err = store.DiffWindows( + now.Add(-2*time.Hour), now.Add(-time.Hour), + now.Add(-time.Hour), now, + 0, // 0 → default 5s timeout + ) + if err != nil { + t.Fatalf("default-timeout DiffWindows returned error: %v", err) + } +} diff --git a/internal/tui/diff.go b/internal/tui/diff.go new file mode 100644 index 0000000..aa9cf16 --- /dev/null +++ b/internal/tui/diff.go @@ -0,0 +1,402 @@ +package tui + +import ( + "fmt" + "sort" + "strings" + "time" + + "github.com/charmbracelet/bubbles/key" + tea "github.com/charmbracelet/bubbletea" + + "github.com/Crank-Git/ja4monitor/internal/storage" +) + +// diffSubState tracks which step of the two-window input the user is in. +type diffSubState int + +const ( + diffInputB diffSubState = iota // user enters comparison window B (recent) + diffInputA // user enters baseline window A (older) + diffResults // showing diff results +) + +// diffChange classifies how a fingerprint changed between windows A and B. +type diffChange int + +const ( + diffNew diffChange = iota // in B, not in A + diffGone // in A, not in B + diffSpiked // B count >> A count + diffDropped // A count >> B count +) + +// diffEntry is one row in the rendered diff table. +type diffEntry struct { + fpType string + fpValue string + aCount int + bCount int + change diffChange +} + +// DiffModel renders the fingerprint diff view (D key). +// +// The user types two adjacent, non-overlapping windows: +// +// [aStart ──── aEnd][bStart ──── bEnd=now] +// └── aEnd == bStart +// +// Step 1: enter B duration (comparison, ending at now). +// Step 2: enter A duration (baseline, immediately preceding B). +// Results show new/gone/spiked/dropped fingerprints between the windows. +type DiffModel struct { + subState diffSubState + inputBuf string + inputErr string + bDur time.Duration // comparison window size + pending bool + entries []diffEntry + queryErr error + aStart time.Time + aEnd time.Time + bStart time.Time + bEnd time.Time + scrollOff int +} + +func newDiffModel() DiffModel { return DiffModel{} } + +// diffResultMsg delivers the async DiffWindows query results. +type diffResultMsg struct { + aMap map[string]map[string]int + bMap map[string]map[string]int + err error + aStart time.Time + aEnd time.Time + bStart time.Time + bEnd time.Time +} + +// Update handles diff mode input and result delivery. +func (d DiffModel) Update(msg tea.Msg, r *Router) (DiffModel, tea.Cmd) { + switch m := msg.(type) { + case diffResultMsg: + d.pending = false + d.queryErr = m.err + d.aStart = m.aStart + d.aEnd = m.aEnd + d.bStart = m.bStart + d.bEnd = m.bEnd + if m.err == nil { + d.entries = computeDiff(m.aMap, m.bMap) + } + d.subState = diffResults + d.scrollOff = 0 + return d, nil + + case tea.KeyMsg: + switch d.subState { + case diffInputB: + return d.updateInputB(m, r) + case diffInputA: + return d.updateInputA(m, r) + case diffResults: + return d.updateResults(m, r) + } + } + return d, nil +} + +func (d DiffModel) updateInputB(m tea.KeyMsg, r *Router) (DiffModel, tea.Cmd) { + switch { + case key.Matches(m, keys.Quit): + return d, tea.Quit + case key.Matches(m, keys.Back): + r.state = stateTable + case m.Type == tea.KeyEnter: + dur, err := time.ParseDuration(strings.TrimSpace(d.inputBuf)) + if err != nil || dur <= 0 { + d.inputErr = "invalid duration — try 1h, 30m, 24h" + return d, nil + } + d.bDur = dur + d.inputErr = "" + d.inputBuf = "" + d.subState = diffInputA + case m.Type == tea.KeyBackspace || m.Type == tea.KeyDelete: + if len(d.inputBuf) > 0 { + d.inputBuf = d.inputBuf[:len(d.inputBuf)-1] + } + d.inputErr = "" + case m.Type == tea.KeyRunes: + d.inputBuf += string(m.Runes) + d.inputErr = "" + } + return d, nil +} + +func (d DiffModel) updateInputA(m tea.KeyMsg, r *Router) (DiffModel, tea.Cmd) { + switch { + case key.Matches(m, keys.Quit): + return d, tea.Quit + case key.Matches(m, keys.Back): + // Back from step 2 → return to step 1. + d.subState = diffInputB + d.inputBuf = "" + d.inputErr = "" + case m.Type == tea.KeyEnter: + dur, err := time.ParseDuration(strings.TrimSpace(d.inputBuf)) + if err != nil || dur <= 0 { + d.inputErr = "invalid duration — try 1h, 30m, 24h" + return d, nil + } + d.inputErr = "" + now := time.Now() + bEnd := now + bStart := now.Add(-d.bDur) + aEnd := bStart + aStart := aEnd.Add(-dur) + d.pending = true + d.subState = diffResults + return d, d.queryCmd(r, aStart, aEnd, bStart, bEnd) + case m.Type == tea.KeyBackspace || m.Type == tea.KeyDelete: + if len(d.inputBuf) > 0 { + d.inputBuf = d.inputBuf[:len(d.inputBuf)-1] + } + d.inputErr = "" + case m.Type == tea.KeyRunes: + d.inputBuf += string(m.Runes) + d.inputErr = "" + } + return d, nil +} + +func (d DiffModel) updateResults(m tea.KeyMsg, r *Router) (DiffModel, tea.Cmd) { + switch { + case key.Matches(m, keys.Quit): + return d, tea.Quit + case key.Matches(m, keys.Back): + r.state = stateTable + case key.Matches(m, keys.Up): + if d.scrollOff > 0 { + d.scrollOff-- + } + case key.Matches(m, keys.Down): + if d.scrollOff < len(d.entries)-1 { + d.scrollOff++ + } + case m.Type == tea.KeyRunes && string(m.Runes) == "r": + // r → start a new diff from step 1 + d = newDiffModel() + } + return d, nil +} + +// queryCmd fires DiffWindows on a background goroutine and returns a +// diffResultMsg. Falls back to an empty result when no diffFn is wired +// (daemon attach mode with no local store). +func (d DiffModel) queryCmd(r *Router, aStart, aEnd, bStart, bEnd time.Time) tea.Cmd { + fn := r.diffFn + return func() tea.Msg { + if fn == nil { + return diffResultMsg{aStart: aStart, aEnd: aEnd, bStart: bStart, bEnd: bEnd} + } + aMap, bMap, err := fn(aStart, aEnd, bStart, bEnd, 5*time.Second) + return diffResultMsg{ + aMap: aMap, + bMap: bMap, + err: err, + aStart: aStart, + aEnd: aEnd, + bStart: bStart, + bEnd: bEnd, + } + } +} + +// computeDiff classifies fingerprints into new/gone/spiked/dropped entries. +// Only entries that changed are included — unchanged fingerprints are skipped. +func computeDiff(aMap, bMap map[string]map[string]int) []diffEntry { + // Collect all unique fp_types. + allTypes := make(map[string]bool) + for t := range aMap { + allTypes[t] = true + } + for t := range bMap { + allTypes[t] = true + } + + var entries []diffEntry + for fpType := range allTypes { + aVals := aMap[fpType] + bVals := bMap[fpType] + + // Union of all fp_values for this type. + allVals := make(map[string]bool) + for v := range aVals { + allVals[v] = true + } + for v := range bVals { + allVals[v] = true + } + + for fpValue := range allVals { + ac := aVals[fpValue] // 0 if absent + bc := bVals[fpValue] // 0 if absent + + var change diffChange + switch { + case ac == 0 && bc > 0: + change = diffNew + case ac > 0 && bc == 0: + change = diffGone + case ac > 0 && bc > storage.DiffSpikeMultiplier*ac && (bc-ac) > storage.DiffMinDelta: + change = diffSpiked + case ac > 0 && ac > storage.DiffSpikeMultiplier*bc && (ac-bc) > storage.DiffMinDelta: + change = diffDropped + default: + continue // unchanged — skip + } + entries = append(entries, diffEntry{ + fpType: fpType, + fpValue: fpValue, + aCount: ac, + bCount: bc, + change: change, + }) + } + } + + // Sort: by change category, then fp_type, then fp_value. + sort.Slice(entries, func(i, j int) bool { + if entries[i].change != entries[j].change { + return entries[i].change < entries[j].change + } + if entries[i].fpType != entries[j].fpType { + return entries[i].fpType < entries[j].fpType + } + return entries[i].fpValue < entries[j].fpValue + }) + return entries +} + +// changeLabel returns the short color-coded label for a diff category. +func changeLabel(c diffChange) string { + switch c { + case diffNew: + return diffNewStyle.Render("NEW ") + case diffGone: + return diffGoneStyle.Render("GONE ") + case diffSpiked: + return diffSpikedStyle.Render("SPIKED ") + case diffDropped: + return diffDroppedStyle.Render("DROPPED") + } + return " " +} + +// View renders the full-screen fingerprint diff dashboard. +func (d DiffModel) View(r *Router) string { + var b strings.Builder + + hdr := fmt.Sprintf(" DIFF %s", r.ifaceName) + b.WriteString(headerStyle.Width(r.width).Render(hdr)) + b.WriteString("\n") + + switch d.subState { + case diffInputB: + b.WriteString("\n") + b.WriteString(" Step 1 of 2 — Comparison window (recent, ending now)\n") + b.WriteString(" Duration (e.g. 1h, 30m, 24h):\n\n") + b.WriteString(" > ") + b.WriteString(d.inputBuf) + b.WriteString("█\n") + if d.inputErr != "" { + b.WriteString("\n ") + b.WriteString(severityHighStyle.Render(d.inputErr)) + b.WriteString("\n") + } + b.WriteString("\n [Enter]next [Esc]back to table\n") + + case diffInputA: + b.WriteString("\n") + b.WriteString(fmt.Sprintf(" Comparison B: last %s ✓\n", d.bDur)) + b.WriteString("\n") + b.WriteString(" Step 2 of 2 — Baseline window (immediately before B)\n") + b.WriteString(" Duration (e.g. 1h, 30m, 24h):\n\n") + b.WriteString(" > ") + b.WriteString(d.inputBuf) + b.WriteString("█\n") + if d.inputErr != "" { + b.WriteString("\n ") + b.WriteString(severityHighStyle.Render(d.inputErr)) + b.WriteString("\n") + } + b.WriteString("\n [Enter]run diff [Esc]back to step 1\n") + + case diffResults: + if d.pending { + b.WriteString("\n Computing diff") + b.WriteString(fmt.Sprintf(" A: %s → %s", + d.aStart.Format("15:04:05"), d.aEnd.Format("15:04:05"))) + b.WriteString(fmt.Sprintf(" B: %s → %s ...\n", + d.bStart.Format("15:04:05"), d.bEnd.Format("15:04:05"))) + } else if d.queryErr != nil { + b.WriteString("\n ") + b.WriteString(severityHighStyle.Render("Query error: " + d.queryErr.Error())) + b.WriteString("\n") + } else { + b.WriteString("\n") + b.WriteString(fmt.Sprintf(" A (baseline): %s → %s\n", + d.aStart.Format("2006-01-02 15:04:05"), + d.aEnd.Format("2006-01-02 15:04:05"))) + b.WriteString(fmt.Sprintf(" B (comparison): %s → %s\n", + d.bStart.Format("2006-01-02 15:04:05"), + d.bEnd.Format("2006-01-02 15:04:05"))) + b.WriteString("\n") + + if len(d.entries) == 0 { + b.WriteString(" No fingerprint changes between windows.\n") + } else { + b.WriteString(fmt.Sprintf(" %d changed fingerprints\n\n", len(d.entries))) + + // Column header. + b.WriteString(sectionHeaderStyle.Render(fmt.Sprintf( + " %-7s %-6s %-14s %-5s %-5s %s", + "CHANGE", "TYPE", "FINGERPRINT", "A", "B", ""))) + b.WriteString("\n") + + maxRows := r.height - 10 + if maxRows < 1 { + maxRows = 1 + } + end := d.scrollOff + maxRows + if end > len(d.entries) { + end = len(d.entries) + } + for _, e := range d.entries[d.scrollOff:end] { + b.WriteString(fmt.Sprintf(" %s %-6s %-14s %-5d %-5d\n", + changeLabel(e.change), + truncate(e.fpType, 6), + truncate(e.fpValue, 14), + e.aCount, e.bCount, + )) + } + if len(d.entries) > maxRows { + b.WriteString(fmt.Sprintf("\n row %d–%d of %d ↑/k ↓/j\n", + d.scrollOff+1, end, len(d.entries))) + } + } + } + } + + footer := " [Esc]back [r]new diff [D]close [q]quit" + if d.subState == diffInputB || d.subState == diffInputA { + footer = " [Esc]back [Enter]next/run [q]quit" + } + if status := r.currentStatus(); status != "" { + return finalize(r, b.String(), learningStyle.Render(" "+status)) + } + return finalize(r, b.String(), footerStyle.Render(footer)) +} diff --git a/internal/tui/diff_test.go b/internal/tui/diff_test.go new file mode 100644 index 0000000..f362dcd --- /dev/null +++ b/internal/tui/diff_test.go @@ -0,0 +1,558 @@ +package tui + +import ( + "testing" + "time" + + tea "github.com/charmbracelet/bubbletea" +) + +// ── computeDiff ────────────────────────────────────────────────────────────── + +func TestComputeDiff_NewFingerprint(t *testing.T) { + aMap := map[string]map[string]int{} + bMap := map[string]map[string]int{ + "ja4": {"t13d_new_fp": 5}, + } + entries := computeDiff(aMap, bMap) + if len(entries) != 1 { + t.Fatalf("expected 1 entry, got %d", len(entries)) + } + if entries[0].change != diffNew { + t.Errorf("change = %v, want diffNew", entries[0].change) + } + if entries[0].fpValue != "t13d_new_fp" { + t.Errorf("fpValue = %q, want %q", entries[0].fpValue, "t13d_new_fp") + } + if entries[0].aCount != 0 || entries[0].bCount != 5 { + t.Errorf("counts = %d/%d, want 0/5", entries[0].aCount, entries[0].bCount) + } +} + +func TestComputeDiff_GoneFingerprint(t *testing.T) { + aMap := map[string]map[string]int{ + "ja4": {"t13d_old_fp": 3}, + } + bMap := map[string]map[string]int{} + entries := computeDiff(aMap, bMap) + if len(entries) != 1 { + t.Fatalf("expected 1 entry, got %d", len(entries)) + } + if entries[0].change != diffGone { + t.Errorf("change = %v, want diffGone", entries[0].change) + } +} + +func TestComputeDiff_SpikedFingerprint(t *testing.T) { + // B count is >3x A count AND delta > 10. + aMap := map[string]map[string]int{"ja4": {"t13d_fp": 5}} + bMap := map[string]map[string]int{"ja4": {"t13d_fp": 20}} // 4x, delta=15 + entries := computeDiff(aMap, bMap) + if len(entries) != 1 { + t.Fatalf("expected 1 entry, got %d", len(entries)) + } + if entries[0].change != diffSpiked { + t.Errorf("change = %v, want diffSpiked", entries[0].change) + } +} + +func TestComputeDiff_DroppedFingerprint(t *testing.T) { + // A count is >3x B count AND delta > 10. + aMap := map[string]map[string]int{"ja4": {"t13d_fp": 20}} + bMap := map[string]map[string]int{"ja4": {"t13d_fp": 5}} // 4x, delta=15 + entries := computeDiff(aMap, bMap) + if len(entries) != 1 { + t.Fatalf("expected 1 entry, got %d", len(entries)) + } + if entries[0].change != diffDropped { + t.Errorf("change = %v, want diffDropped", entries[0].change) + } +} + +func TestComputeDiff_UnchangedSkipped(t *testing.T) { + // Same count in both windows — should produce 0 entries. + aMap := map[string]map[string]int{"ja4": {"t13d_same": 5}} + bMap := map[string]map[string]int{"ja4": {"t13d_same": 5}} + entries := computeDiff(aMap, bMap) + if len(entries) != 0 { + t.Errorf("unchanged fingerprint should be skipped, got %d entries", len(entries)) + } +} + +func TestComputeDiff_SpikeRequiresMinDelta(t *testing.T) { + // 4x ratio but absolute delta = 3 (below threshold 10) → not a spike. + aMap := map[string]map[string]int{"ja4": {"t13d_fp": 1}} + bMap := map[string]map[string]int{"ja4": {"t13d_fp": 4}} + entries := computeDiff(aMap, bMap) + // 4 > 3*1 but delta=3 < 10 → not spiked; 4 != 0 so not new; not gone. + if len(entries) != 0 { + t.Errorf("small-delta spike should be skipped, got %d entries", len(entries)) + } +} + +func TestComputeDiff_EmptyBothMaps(t *testing.T) { + entries := computeDiff(nil, nil) + if len(entries) != 0 { + t.Errorf("empty maps: expected 0 entries, got %d", len(entries)) + } +} + +func TestComputeDiff_MultipleFPTypes(t *testing.T) { + aMap := map[string]map[string]int{ + "ja4": {"old_fp": 5}, + "ja4h": {"http_fp": 2}, + } + bMap := map[string]map[string]int{ + "ja4": {"new_fp": 3}, + "ja4h": {"http_fp": 2}, // unchanged + } + entries := computeDiff(aMap, bMap) + // old_fp: gone, new_fp: new; http_fp: unchanged (skipped) + if len(entries) != 2 { + t.Fatalf("expected 2 entries (gone + new), got %d", len(entries)) + } + changes := make(map[diffChange]int) + for _, e := range entries { + changes[e.change]++ + } + if changes[diffNew] != 1 || changes[diffGone] != 1 { + t.Errorf("changes = %v, want 1 new + 1 gone", changes) + } +} + +func TestComputeDiff_SortOrder(t *testing.T) { + // new < gone < spiked < dropped — new entries should come first. + aMap := map[string]map[string]int{ + "ja4": {"gone_fp": 5}, + } + bMap := map[string]map[string]int{ + "ja4": {"new_fp": 3}, + } + entries := computeDiff(aMap, bMap) + if len(entries) != 2 { + t.Fatalf("expected 2 entries, got %d", len(entries)) + } + if entries[0].change != diffNew { + t.Errorf("first entry should be diffNew, got %v", entries[0].change) + } + if entries[1].change != diffGone { + t.Errorf("second entry should be diffGone, got %v", entries[1].change) + } +} + +// ── DiffModel.Update — step 1 (inputB) ────────────────────────────────────── + +func TestDiffModel_Step1_ValidDuration_AdvancesToStep2(t *testing.T) { + r := newTestRouter() + d := newDiffModel() + + d.inputBuf = "1h" + d, _ = d.Update(tea.KeyMsg{Type: tea.KeyEnter}, &r) + + if d.subState != diffInputA { + t.Errorf("after valid B input: expected diffInputA, got %v", d.subState) + } + if d.bDur != time.Hour { + t.Errorf("bDur = %v, want 1h", d.bDur) + } + if d.inputBuf != "" { + t.Error("inputBuf should clear when advancing to step 2") + } +} + +func TestDiffModel_Step1_InvalidDuration_SetsError(t *testing.T) { + r := newTestRouter() + d := newDiffModel() + d.inputBuf = "notadur" + d, _ = d.Update(tea.KeyMsg{Type: tea.KeyEnter}, &r) + + if d.inputErr == "" { + t.Error("expected inputErr for invalid duration") + } + if d.subState != diffInputB { + t.Error("should stay at diffInputB on invalid input") + } +} + +func TestDiffModel_Step1_EscReturnsToTable(t *testing.T) { + r := newTestRouter() + r.state = stateDiff + d := newDiffModel() + d, _ = d.Update(tea.KeyMsg{Type: tea.KeyEsc}, &r) + if r.state != stateTable { + t.Errorf("Esc at step1: expected stateTable, got %v", r.state) + } + _ = d +} + +func TestDiffModel_Step1_Typing(t *testing.T) { + r := newTestRouter() + d := newDiffModel() + d, _ = d.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("2")}, &r) + d, _ = d.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("h")}, &r) + if d.inputBuf != "2h" { + t.Errorf("inputBuf = %q, want %q", d.inputBuf, "2h") + } +} + +func TestDiffModel_Step1_Backspace(t *testing.T) { + r := newTestRouter() + d := newDiffModel() + d.inputBuf = "2h" + d, _ = d.Update(tea.KeyMsg{Type: tea.KeyBackspace}, &r) + if d.inputBuf != "2" { + t.Errorf("inputBuf = %q, want %q", d.inputBuf, "2") + } +} + +// ── DiffModel.Update — step 2 (inputA) ────────────────────────────────────── + +func TestDiffModel_Step2_ValidDuration_FiresQuery(t *testing.T) { + r := newTestRouter() + d := newDiffModel() + d.subState = diffInputA + d.bDur = time.Hour + d.inputBuf = "1h" + + d, cmd := d.Update(tea.KeyMsg{Type: tea.KeyEnter}, &r) + + if !d.pending { + t.Error("pending should be true while query in flight") + } + if d.subState != diffResults { + t.Errorf("subState = %v, want diffResults", d.subState) + } + if cmd == nil { + t.Error("should return a query command") + } +} + +func TestDiffModel_Step2_InvalidDuration_SetsError(t *testing.T) { + r := newTestRouter() + d := newDiffModel() + d.subState = diffInputA + d.bDur = time.Hour + d.inputBuf = "bad" + d, _ = d.Update(tea.KeyMsg{Type: tea.KeyEnter}, &r) + + if d.inputErr == "" { + t.Error("expected inputErr for invalid duration at step 2") + } + if d.subState != diffInputA { + t.Errorf("should stay at diffInputA, got %v", d.subState) + } +} + +func TestDiffModel_Step2_EscReturnsToStep1(t *testing.T) { + r := newTestRouter() + d := newDiffModel() + d.subState = diffInputA + d.bDur = time.Hour + d.inputBuf = "30m" + d, _ = d.Update(tea.KeyMsg{Type: tea.KeyEsc}, &r) + + if d.subState != diffInputB { + t.Errorf("Esc at step 2: expected diffInputB, got %v", d.subState) + } + if d.inputBuf != "" { + t.Error("inputBuf should clear on back to step 1") + } +} + +// ── DiffModel.Update — results state ──────────────────────────────────────── + +func TestDiffModel_ResultMsg_PopulatesEntries(t *testing.T) { + r := newTestRouter() + d := newDiffModel() + d.pending = true + d.subState = diffResults + + now := time.Now() + msg := diffResultMsg{ + aMap: map[string]map[string]int{"ja4": {"old_fp": 5}}, + bMap: map[string]map[string]int{"ja4": {"new_fp": 3}}, + aStart: now.Add(-2 * time.Hour), + aEnd: now.Add(-time.Hour), + bStart: now.Add(-time.Hour), + bEnd: now, + } + d, _ = d.Update(msg, &r) + + if d.pending { + t.Error("pending should clear after result") + } + if len(d.entries) != 2 { + t.Fatalf("expected 2 entries, got %d", len(d.entries)) + } + if d.scrollOff != 0 { + t.Error("scrollOff should reset to 0") + } +} + +func TestDiffModel_ResultMsg_ErrorPreserved(t *testing.T) { + r := newTestRouter() + d := newDiffModel() + d.subState = diffResults + + d, _ = d.Update(diffResultMsg{err: errTest}, &r) + if d.queryErr == nil { + t.Error("queryErr should be set from result message") + } + if len(d.entries) != 0 { + t.Error("entries should remain nil on error") + } +} + +func TestDiffModel_RKeyResetsToStart(t *testing.T) { + r := newTestRouter() + now := time.Now() + d := newDiffModel() + d.subState = diffResults + d.bDur = time.Hour + d.entries = []diffEntry{{fpType: "ja4", fpValue: "t13d_fp", change: diffNew}} + d.aStart = now.Add(-2 * time.Hour) + d.bEnd = now + + d, _ = d.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("r")}, &r) + if d.subState != diffInputB { + t.Errorf("r key: expected diffInputB, got %v", d.subState) + } + if len(d.entries) != 0 { + t.Error("entries should clear on reset") + } + if d.bDur != 0 { + t.Error("bDur should clear on reset") + } +} + +func TestDiffModel_ScrollDown(t *testing.T) { + r := newTestRouter() + d := newDiffModel() + d.subState = diffResults + for i := 0; i < 5; i++ { + d.entries = append(d.entries, diffEntry{fpType: "ja4", fpValue: "fp", change: diffNew}) + } + + d, _ = d.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("j")}, &r) + if d.scrollOff != 1 { + t.Errorf("scrollOff = %d, want 1", d.scrollOff) + } +} + +func TestDiffModel_ScrollClampsAtBounds(t *testing.T) { + r := newTestRouter() + d := newDiffModel() + d.subState = diffResults + d.entries = []diffEntry{{fpType: "ja4", fpValue: "fp", change: diffNew}} + + for i := 0; i < 5; i++ { + d, _ = d.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("j")}, &r) + } + if d.scrollOff != 0 { + t.Errorf("scrollOff = %d, want 0 (clamped for 1-entry list)", d.scrollOff) + } +} + +// ── State machine: D key ───────────────────────────────────────────────────── + +func TestRouter_DKey_EntersDiff(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + + updated, _ := r.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("D")}) + r2 := updated.(Router) + if r2.state != stateDiff { + t.Errorf("D key: expected stateDiff, got %v", r2.state) + } +} + +func TestRouter_DKey_TogglesBackToTable(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + r.state = stateDiff + // Must be in results state: global D is suppressed in input states so + // the user can type duration characters (e.g. "1h30s") freely. + r.diff.subState = diffResults + + updated, _ := r.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("D")}) + r2 := updated.(Router) + if r2.state != stateTable { + t.Errorf("D key in diff results: expected stateTable, got %v", r2.state) + } +} + +func TestRouter_DKey_FromSummaryMode(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + r.state = stateSummary + + updated, _ := r.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("D")}) + r2 := updated.(Router) + if r2.state != stateDiff { + t.Errorf("D from summary: expected stateDiff, got %v", r2.state) + } +} + +func TestRouter_DKey_ResetsModelOnEntry(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + // Pre-populate diff with stale state. + r.diff.subState = diffResults + r.diff.bDur = time.Hour + + updated, _ := r.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("D")}) + r2 := updated.(Router) + if r2.diff.subState != diffInputB { + t.Errorf("diff model should reset to diffInputB on re-entry, got %v", r2.diff.subState) + } + if r2.diff.bDur != 0 { + t.Error("bDur should reset on entry") + } +} + +// ── View smoke tests ───────────────────────────────────────────────────────── + +func TestDiffModel_View_Step1(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + d := newDiffModel() + + view := d.View(&r) + if view == "" { + t.Fatal("View returned empty string") + } + if !contains(view, "DIFF") { + t.Error("View should show DIFF header") + } + if !contains(view, "Step 1") { + t.Error("View should show step 1 prompt") + } +} + +func TestDiffModel_View_Step2(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + d := newDiffModel() + d.subState = diffInputA + d.bDur = time.Hour + + view := d.View(&r) + if !contains(view, "Step 2") { + t.Error("View should show step 2 prompt") + } + if !contains(view, "1h") { + t.Error("View should confirm the B window duration") + } +} + +func TestDiffModel_View_PendingState(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + d := newDiffModel() + d.subState = diffResults + d.pending = true + d.aStart = time.Now().Add(-2 * time.Hour) + d.aEnd = time.Now().Add(-time.Hour) + d.bStart = time.Now().Add(-time.Hour) + d.bEnd = time.Now() + + view := d.View(&r) + if !contains(view, "Computing diff") { + t.Error("View should show pending message while query in flight") + } +} + +func TestDiffModel_View_NoChanges(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + d := newDiffModel() + d.subState = diffResults + d.aStart = time.Now().Add(-2 * time.Hour) + d.aEnd = time.Now().Add(-time.Hour) + d.bStart = time.Now().Add(-time.Hour) + d.bEnd = time.Now() + + view := d.View(&r) + if !contains(view, "No fingerprint changes") { + t.Error("View should show empty-state when no changes") + } +} + +func TestDiffModel_View_WithEntries(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + now := time.Now() + d := newDiffModel() + d.subState = diffResults + d.aStart = now.Add(-2 * time.Hour) + d.aEnd = now.Add(-time.Hour) + d.bStart = now.Add(-time.Hour) + d.bEnd = now + d.entries = []diffEntry{ + {fpType: "ja4", fpValue: "t13d_new_fp", aCount: 0, bCount: 5, change: diffNew}, + {fpType: "ja4", fpValue: "t13d_old_fp", aCount: 3, bCount: 0, change: diffGone}, + } + + view := d.View(&r) + if !contains(view, "t13d_new_fp") { + t.Error("View should show new fingerprint value") + } + if !contains(view, "t13d_old_fp") { + t.Error("View should show gone fingerprint value") + } + if !contains(view, "2 changed") { + t.Error("View should show entry count") + } +} + +func TestRouter_ViewDiff(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + r.state = stateDiff + + view := r.View() + if !contains(view, "DIFF") { + t.Error("Router.View() in stateDiff should render diff view") + } +} + +// ── diffFn wiring ───────────────────────────────────────────────────────────── + +func TestRouter_DiffFn_NilWhenNoStore(t *testing.T) { + r := newTestRouter() // no store + if r.diffFn != nil { + t.Error("diffFn should be nil when no store is wired") + } +} + +func TestDiffModel_QueryCmd_NilFn_ReturnsEmptyResult(t *testing.T) { + r := newTestRouter() // no store + d := newDiffModel() + now := time.Now() + + cmd := d.queryCmd(&r, now.Add(-2*time.Hour), now.Add(-time.Hour), now.Add(-time.Hour), now) + if cmd == nil { + t.Fatal("queryCmd should return a Cmd even with nil diffFn") + } + msg := cmd() + result, ok := msg.(diffResultMsg) + if !ok { + t.Fatalf("expected diffResultMsg, got %T", msg) + } + if result.err != nil { + t.Errorf("nil diffFn: unexpected error %v", result.err) + } + if len(result.aMap) != 0 || len(result.bMap) != 0 { + t.Error("nil diffFn: expected empty maps") + } +} diff --git a/internal/tui/investigate.go b/internal/tui/investigate.go new file mode 100644 index 0000000..69640f9 --- /dev/null +++ b/internal/tui/investigate.go @@ -0,0 +1,268 @@ +package tui + +import ( + "fmt" + "strings" + "time" + + "github.com/charmbracelet/bubbles/key" + tea "github.com/charmbracelet/bubbletea" + + "github.com/Crank-Git/ja4monitor/internal/tracker" +) + +// investigateSubState tracks whether the investigation view is waiting +// for user input or showing query results. +type investigateSubState int + +const ( + investigateInput investigateSubState = iota // user is typing a duration + investigateResults // displaying query results +) + +// InvestigateModel renders the time-window investigation view (I key). +// +// The user types a relative duration (e.g. "1h", "30m", "24h"); the model +// fires an async SQLite QueryTimeWindow query and merges the results with +// any matching in-memory connections. In-memory wins on dedup so the live +// state for active connections is always shown. +type InvestigateModel struct { + subState investigateSubState + inputBuf string + inputErr string + pending bool + results []*tracker.Connection + timedOut bool + queryErr error + windowStart time.Time + windowEnd time.Time + scrollOff int +} + +func newInvestigateModel() InvestigateModel { return InvestigateModel{} } + +// investigateResultMsg delivers the async time-window query results. +type investigateResultMsg struct { + results []*tracker.Connection + timedOut bool + err error + windowStart time.Time + windowEnd time.Time +} + +// Update handles investigate mode input and result delivery. +func (inv InvestigateModel) Update(msg tea.Msg, r *Router) (InvestigateModel, tea.Cmd) { + switch m := msg.(type) { + case investigateResultMsg: + inv.pending = false + inv.timedOut = m.timedOut + inv.queryErr = m.err + inv.windowStart = m.windowStart + inv.windowEnd = m.windowEnd + if m.err == nil { + inv.results = mergeInvestigateResults(m.results, r, m.windowStart, m.windowEnd) + } + inv.subState = investigateResults + inv.scrollOff = 0 + return inv, nil + + case tea.KeyMsg: + switch inv.subState { + case investigateInput: + return inv.updateInput(m, r) + case investigateResults: + return inv.updateResults(m, r) + } + } + return inv, nil +} + +func (inv InvestigateModel) updateInput(m tea.KeyMsg, r *Router) (InvestigateModel, tea.Cmd) { + switch { + case key.Matches(m, keys.Quit): + return inv, tea.Quit + case key.Matches(m, keys.Back): + r.state = stateTable + case m.Type == tea.KeyEnter: + dur, err := time.ParseDuration(strings.TrimSpace(inv.inputBuf)) + if err != nil || dur <= 0 { + inv.inputErr = "invalid duration — try 1h, 30m, 24h, 2h30m" + return inv, nil + } + inv.inputErr = "" + end := time.Now() + start := end.Add(-dur) + inv.pending = true + inv.subState = investigateResults + return inv, inv.queryCmd(r, start, end) + case m.Type == tea.KeyBackspace || m.Type == tea.KeyDelete: + if len(inv.inputBuf) > 0 { + inv.inputBuf = inv.inputBuf[:len(inv.inputBuf)-1] + } + inv.inputErr = "" + case m.Type == tea.KeyRunes: + inv.inputBuf += string(m.Runes) + inv.inputErr = "" + } + return inv, nil +} + +func (inv InvestigateModel) updateResults(m tea.KeyMsg, r *Router) (InvestigateModel, tea.Cmd) { + switch { + case key.Matches(m, keys.Quit): + return inv, tea.Quit + case key.Matches(m, keys.Back): + r.state = stateTable + case key.Matches(m, keys.Up): + if inv.scrollOff > 0 { + inv.scrollOff-- + } + case key.Matches(m, keys.Down): + if inv.scrollOff < len(inv.results)-1 { + inv.scrollOff++ + } + case m.Type == tea.KeyRunes && string(m.Runes) == "r": + // r → clear results and start a fresh query from input state + inv.subState = investigateInput + inv.inputBuf = "" + inv.inputErr = "" + inv.results = nil + } + return inv, nil +} + +// queryCmd fires QueryTimeWindow on a background goroutine and delivers an +// investigateResultMsg. Falls back to empty results when no timeWindowFn is +// wired (daemon attach mode has no local SQLite store). +func (inv InvestigateModel) queryCmd(r *Router, start, end time.Time) tea.Cmd { + fn := r.timeWindowFn + return func() tea.Msg { + if fn == nil { + return investigateResultMsg{windowStart: start, windowEnd: end} + } + results, timedOut, err := fn(start, end, 5*time.Second) + return investigateResultMsg{ + results: results, + timedOut: timedOut, + err: err, + windowStart: start, + windowEnd: end, + } + } +} + +// mergeInvestigateResults merges SQLite results with in-memory connections +// that overlap [start, end]. In-memory wins on dedup (same ID → fresher state). +func mergeInvestigateResults(sqliteResults []*tracker.Connection, r *Router, start, end time.Time) []*tracker.Connection { + // Index SQLite results; in-memory will overwrite on collision. + byID := make(map[string]int, len(sqliteResults)) + out := make([]*tracker.Connection, 0, len(sqliteResults)) + for _, c := range sqliteResults { + byID[c.ID] = len(out) + out = append(out, c) + } + // Overlay in-memory connections whose lifetime overlaps the window. + // Overlap: first_seen <= end AND last_seen >= start. + for _, c := range r.connMap { + if c.FirstSeen.After(end) || c.LastSeen.Before(start) { + continue + } + if idx, exists := byID[c.ID]; exists { + out[idx] = c // replace SQLite copy with live state + } else { + byID[c.ID] = len(out) + out = append(out, c) + } + } + return out +} + +// View renders the full-screen investigation dashboard. +func (inv InvestigateModel) View(r *Router) string { + var b strings.Builder + + hdr := fmt.Sprintf(" INVESTIGATE %s", r.ifaceName) + b.WriteString(headerStyle.Width(r.width).Render(hdr)) + b.WriteString("\n") + + switch inv.subState { + case investigateInput: + b.WriteString("\n") + b.WriteString(" Time window (relative duration, e.g. 1h 30m 24h 2h30m):\n\n") + b.WriteString(" > ") + b.WriteString(inv.inputBuf) + b.WriteString("█\n") + if inv.inputErr != "" { + b.WriteString("\n ") + b.WriteString(severityHighStyle.Render(inv.inputErr)) + b.WriteString("\n") + } + + case investigateResults: + if inv.pending { + b.WriteString("\n Querying ") + b.WriteString(inv.windowStart.Format("15:04:05")) + b.WriteString(" → ") + b.WriteString(inv.windowEnd.Format("15:04:05")) + b.WriteString(" ...\n") + } else if inv.queryErr != nil { + b.WriteString("\n ") + b.WriteString(severityHighStyle.Render("Query error: " + inv.queryErr.Error())) + b.WriteString("\n") + } else { + b.WriteString("\n") + windowLine := fmt.Sprintf(" Window: %s → %s", + inv.windowStart.Format("2006-01-02 15:04:05"), + inv.windowEnd.Format("2006-01-02 15:04:05")) + if inv.timedOut { + windowLine += " " + learningStyle.Render("[TIMEOUT — partial results]") + } + b.WriteString(windowLine) + b.WriteString("\n") + b.WriteString(fmt.Sprintf(" Results: %d connections\n\n", len(inv.results))) + + if len(inv.results) == 0 { + b.WriteString(" (no connections in this window)\n") + } else { + // Reserve rows for: header(1) + window(1) + count(1) + blank(1) + + // col-header(1) + footer(2) = 7 lines overhead. + maxRows := r.height - 7 + if maxRows < 1 { + maxRows = 1 + } + end := inv.scrollOff + maxRows + if end > len(inv.results) { + end = len(inv.results) + } + b.WriteString(sectionHeaderStyle.Render(fmt.Sprintf( + " %-15s %-5s %-15s %-5s %-16s %s", + "SRC IP", "SPORT", "DST IP", "DPORT", "JA4", "PACKETS"))) + b.WriteString("\n") + for _, c := range inv.results[inv.scrollOff:end] { + ja4 := truncate(c.LatestFingerprint("ja4"), 16) + if ja4 == "" { + ja4 = "-" + } + b.WriteString(fmt.Sprintf(" %-15s %-5d %-15s %-5d %-16s %d\n", + c.SrcIP, c.SrcPort, + c.DstIP, c.DstPort, + ja4, c.PacketCount, + )) + } + if len(inv.results) > maxRows { + b.WriteString(fmt.Sprintf("\n row %d–%d of %d ↑/k ↓/j\n", + inv.scrollOff+1, end, len(inv.results))) + } + } + } + } + + footer := " [Esc]back [r]new query [I]close [q]quit" + if inv.subState == investigateInput { + footer = " [Esc]back [Enter]query [q]quit" + } + if status := r.currentStatus(); status != "" { + return finalize(r, b.String(), learningStyle.Render(" "+status)) + } + return finalize(r, b.String(), footerStyle.Render(footer)) +} diff --git a/internal/tui/investigate_test.go b/internal/tui/investigate_test.go new file mode 100644 index 0000000..e99326d --- /dev/null +++ b/internal/tui/investigate_test.go @@ -0,0 +1,472 @@ +package tui + +import ( + "errors" + "testing" + "time" + + tea "github.com/charmbracelet/bubbletea" + + "github.com/Crank-Git/ja4monitor/internal/tracker" +) + +// ── mergeInvestigateResults ────────────────────────────────────────────────── + +func TestMergeInvestigateResults_SQLiteOnlyNoConnMap(t *testing.T) { + r := newTestRouter() + now := time.Now() + conn := tracker.NewConnection("10.0.0.1", 1234, "1.1.1.1", 443, "tcp", now.Add(-30*time.Minute)) + conn.LastSeen = now.Add(-10 * time.Minute) + + results := mergeInvestigateResults([]*tracker.Connection{conn}, &r, now.Add(-time.Hour), now) + if len(results) != 1 { + t.Fatalf("expected 1 result, got %d", len(results)) + } + if results[0].ID != conn.ID { + t.Errorf("expected conn %s, got %s", conn.ID, results[0].ID) + } +} + +func TestMergeInvestigateResults_InMemoryWinsOnDedup(t *testing.T) { + now := time.Now() + start := now.Add(-time.Hour) + end := now + + // SQLite version has old packet count. + sqliteConn := tracker.NewConnection("10.0.0.1", 1234, "1.1.1.1", 443, "tcp", now.Add(-30*time.Minute)) + sqliteConn.LastSeen = now.Add(-5 * time.Minute) + sqliteConn.PacketCount = 5 + + // In-memory version has newer packet count (same ID). + liveConn := tracker.NewConnection("10.0.0.1", 1234, "1.1.1.1", 443, "tcp", now.Add(-30*time.Minute)) + liveConn.ID = sqliteConn.ID // same ID + liveConn.LastSeen = now.Add(-1 * time.Minute) + liveConn.PacketCount = 50 + + r := newTestRouter() + r.connMap[liveConn.ID] = liveConn + + results := mergeInvestigateResults([]*tracker.Connection{sqliteConn}, &r, start, end) + if len(results) != 1 { + t.Fatalf("expected 1 deduplicated result, got %d", len(results)) + } + if results[0].PacketCount != 50 { + t.Errorf("in-memory should win: PacketCount = %d, want 50", results[0].PacketCount) + } +} + +func TestMergeInvestigateResults_InMemoryAddedWhenNotInSQLite(t *testing.T) { + now := time.Now() + start := now.Add(-time.Hour) + end := now + + // In-memory connection not in SQLite results. + liveConn := tracker.NewConnection("192.168.1.5", 9999, "8.8.8.8", 53, "udp", now.Add(-20*time.Minute)) + liveConn.LastSeen = now.Add(-5 * time.Minute) + + r := newTestRouter() + r.connMap[liveConn.ID] = liveConn + + results := mergeInvestigateResults(nil, &r, start, end) + if len(results) != 1 { + t.Fatalf("expected 1 result from in-memory, got %d", len(results)) + } + if results[0].ID != liveConn.ID { + t.Errorf("expected live conn %s, got %s", liveConn.ID, results[0].ID) + } +} + +func TestMergeInvestigateResults_SkipsOutOfWindowInMemory(t *testing.T) { + now := time.Now() + start := now.Add(-time.Hour) + end := now.Add(-30 * time.Minute) + + // Connection entirely after the window. + futureConn := tracker.NewConnection("192.168.1.99", 1111, "2.2.2.2", 80, "tcp", now.Add(-10*time.Minute)) + futureConn.LastSeen = now.Add(-5 * time.Minute) + + r := newTestRouter() + r.connMap[futureConn.ID] = futureConn + + results := mergeInvestigateResults(nil, &r, start, end) + if len(results) != 0 { + t.Errorf("out-of-window in-memory conn should be excluded, got %d results", len(results)) + } +} + +// ── InvestigateModel.Update — input state ─────────────────────────────────── + +func TestInvestigateModel_ValidDuration_TriggersQuery(t *testing.T) { + r := newTestRouter() + inv := newInvestigateModel() + + // Type "1h" then Enter. + inv, _ = inv.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("1")}, &r) + inv, _ = inv.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("h")}, &r) + inv, cmd := inv.Update(tea.KeyMsg{Type: tea.KeyEnter}, &r) + + if !inv.pending { + t.Error("pending should be true while query is in flight") + } + if inv.subState != investigateResults { + t.Errorf("subState should switch to investigateResults, got %v", inv.subState) + } + if cmd == nil { + t.Error("should return a query command") + } +} + +func TestInvestigateModel_InvalidDuration_SetsError(t *testing.T) { + r := newTestRouter() + inv := newInvestigateModel() + + // Type "notaduration" then Enter. + inv.inputBuf = "notaduration" + inv, _ = inv.Update(tea.KeyMsg{Type: tea.KeyEnter}, &r) + + if inv.inputErr == "" { + t.Error("expected inputErr to be set for invalid duration") + } + if inv.pending { + t.Error("pending should remain false for invalid input") + } + if inv.subState != investigateInput { + t.Error("subState should remain investigateInput on parse error") + } +} + +func TestInvestigateModel_NegativeDuration_SetsError(t *testing.T) { + r := newTestRouter() + inv := newInvestigateModel() + inv.inputBuf = "-1h" + inv, _ = inv.Update(tea.KeyMsg{Type: tea.KeyEnter}, &r) + if inv.inputErr == "" { + t.Error("negative duration should set an error") + } +} + +func TestInvestigateModel_BackspaceReducesInput(t *testing.T) { + r := newTestRouter() + inv := newInvestigateModel() + inv.inputBuf = "1h" + inv, _ = inv.Update(tea.KeyMsg{Type: tea.KeyBackspace}, &r) + if inv.inputBuf != "1" { + t.Errorf("inputBuf = %q, want %q", inv.inputBuf, "1") + } +} + +func TestInvestigateModel_BackspaceOnEmpty_DoesNotPanic(t *testing.T) { + r := newTestRouter() + inv := newInvestigateModel() + // Should not panic on empty buf. + inv, _ = inv.Update(tea.KeyMsg{Type: tea.KeyBackspace}, &r) + if inv.inputBuf != "" { + t.Errorf("inputBuf should remain empty, got %q", inv.inputBuf) + } +} + +func TestInvestigateModel_EscReturnsToTable(t *testing.T) { + r := newTestRouter() + r.state = stateInvestigate + inv := newInvestigateModel() + inv, _ = inv.Update(tea.KeyMsg{Type: tea.KeyEsc}, &r) + if r.state != stateTable { + t.Errorf("Esc in investigate input: expected stateTable, got %v", r.state) + } + _ = inv +} + +// ── InvestigateModel.Update — results state ────────────────────────────────── + +func TestInvestigateModel_ResultMsg_PopulatesResults(t *testing.T) { + r := newTestRouter() + now := time.Now() + conn := tracker.NewConnection("10.0.0.1", 1234, "1.1.1.1", 443, "tcp", now.Add(-30*time.Minute)) + + inv := newInvestigateModel() + inv.pending = true + inv.subState = investigateResults + + msg := investigateResultMsg{ + results: []*tracker.Connection{conn}, + windowStart: now.Add(-time.Hour), + windowEnd: now, + } + inv, _ = inv.Update(msg, &r) + + if inv.pending { + t.Error("pending should clear after result arrives") + } + if len(inv.results) != 1 { + t.Fatalf("expected 1 result, got %d", len(inv.results)) + } + if inv.scrollOff != 0 { + t.Error("scrollOff should reset to 0 on new results") + } +} + +func TestInvestigateModel_ResultMsg_ErrorPreserved(t *testing.T) { + r := newTestRouter() + inv := newInvestigateModel() + inv.subState = investigateResults + + inv, _ = inv.Update(investigateResultMsg{ + err: errTest, + windowStart: time.Now().Add(-time.Hour), + windowEnd: time.Now(), + }, &r) + + if inv.queryErr == nil { + t.Error("queryErr should be set from result message") + } + if len(inv.results) != 0 { + t.Error("results should remain nil on error") + } +} + +func TestInvestigateModel_RKeyResetsToInput(t *testing.T) { + r := newTestRouter() + now := time.Now() + inv := newInvestigateModel() + inv.subState = investigateResults + inv.results = []*tracker.Connection{ + tracker.NewConnection("10.0.0.1", 1234, "1.1.1.1", 443, "tcp", now), + } + + inv, _ = inv.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("r")}, &r) + if inv.subState != investigateInput { + t.Errorf("r key: expected investigateInput, got %v", inv.subState) + } + if len(inv.results) != 0 { + t.Error("results should be cleared on reset") + } +} + +func TestInvestigateModel_ScrollDown(t *testing.T) { + r := newTestRouter() + now := time.Now() + inv := newInvestigateModel() + inv.subState = investigateResults + for i := 0; i < 5; i++ { + inv.results = append(inv.results, + tracker.NewConnection("10.0.0.1", uint16(i), "1.1.1.1", 443, "tcp", now)) + } + + inv, _ = inv.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("j")}, &r) + if inv.scrollOff != 1 { + t.Errorf("scrollOff = %d, want 1 after one ↓", inv.scrollOff) + } +} + +func TestInvestigateModel_ScrollDoesNotExceedBounds(t *testing.T) { + r := newTestRouter() + now := time.Now() + inv := newInvestigateModel() + inv.subState = investigateResults + inv.results = []*tracker.Connection{ + tracker.NewConnection("10.0.0.1", 1234, "1.1.1.1", 443, "tcp", now), + } + inv.scrollOff = 0 + + // Scroll down past the end — should clamp. + for i := 0; i < 5; i++ { + inv, _ = inv.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("j")}, &r) + } + if inv.scrollOff != 0 { + t.Errorf("scrollOff = %d, want 0 (clamped for single-entry list)", inv.scrollOff) + } +} + +// ── State machine: I key ───────────────────────────────────────────────────── + +func TestRouter_IKey_EntersInvestigate(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + + updated, _ := r.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("I")}) + r2 := updated.(Router) + if r2.state != stateInvestigate { + t.Errorf("I key: expected stateInvestigate, got %v", r2.state) + } +} + +func TestRouter_IKey_TogglesBackToTable(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + r.state = stateInvestigate + // Must be in results state: global I is suppressed in input state so + // the user can type duration characters (e.g. "1h") freely. + r.investigate.subState = investigateResults + + updated, _ := r.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("I")}) + r2 := updated.(Router) + if r2.state != stateTable { + t.Errorf("I key in investigate results: expected stateTable, got %v", r2.state) + } +} + +func TestRouter_IKey_FromSummaryMode(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + r.state = stateSummary + + updated, _ := r.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("I")}) + r2 := updated.(Router) + if r2.state != stateInvestigate { + t.Errorf("I key from summary: expected stateInvestigate, got %v", r2.state) + } +} + +func TestRouter_IKey_ResetsModelOnEntry(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + // Pre-populate investigate with stale state. + r.investigate.inputBuf = "stale" + r.investigate.subState = investigateResults + + updated, _ := r.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("I")}) + r2 := updated.(Router) + if r2.investigate.inputBuf != "" { + t.Error("investigate model should reset on re-entry") + } + if r2.investigate.subState != investigateInput { + t.Errorf("subState should reset to investigateInput, got %v", r2.investigate.subState) + } +} + +// ── View smoke tests ───────────────────────────────────────────────────────── + +func TestInvestigateModel_View_InputState(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + inv := newInvestigateModel() + + view := inv.View(&r) + if view == "" { + t.Fatal("View returned empty string") + } + if !contains(view, "INVESTIGATE") { + t.Error("View should show INVESTIGATE header") + } + if !contains(view, "Time window") { + t.Error("View should show time window prompt") + } +} + +func TestInvestigateModel_View_InputError(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + inv := newInvestigateModel() + inv.inputErr = "invalid duration — try 1h, 30m, 24h, 2h30m" + + view := inv.View(&r) + if !contains(view, "invalid duration") { + t.Error("View should show input error message") + } +} + +func TestInvestigateModel_View_PendingState(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + inv := newInvestigateModel() + inv.subState = investigateResults + inv.pending = true + inv.windowStart = time.Now().Add(-time.Hour) + inv.windowEnd = time.Now() + + view := inv.View(&r) + if !contains(view, "Querying") { + t.Error("View should show querying state while pending") + } +} + +func TestInvestigateModel_View_NoResults(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + inv := newInvestigateModel() + inv.subState = investigateResults + inv.windowStart = time.Now().Add(-time.Hour) + inv.windowEnd = time.Now() + + view := inv.View(&r) + if !contains(view, "no connections") { + t.Error("View should show empty-state message when no results") + } +} + +func TestInvestigateModel_View_WithResults(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + now := time.Now() + inv := newInvestigateModel() + inv.subState = investigateResults + inv.windowStart = now.Add(-time.Hour) + inv.windowEnd = now + inv.results = []*tracker.Connection{ + tracker.NewConnection("10.0.0.42", 12345, "8.8.8.8", 443, "tcp", now.Add(-30*time.Minute)), + } + + view := inv.View(&r) + if !contains(view, "10.0.0.42") { + t.Error("View should display connection src IP") + } + if !contains(view, "1 connections") { + t.Error("View should show result count") + } +} + +func TestRouter_ViewInvestigate(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + r.state = stateInvestigate + + view := r.View() + if !contains(view, "INVESTIGATE") { + t.Error("Router.View() in stateInvestigate should render investigate view") + } +} + +// ── timeWindowFn wiring ────────────────────────────────────────────────────── + +func TestRouter_TimeWindowFn_NilWhenNoStore(t *testing.T) { + r := newTestRouter() // no store + if r.timeWindowFn != nil { + t.Error("timeWindowFn should be nil when no store is wired") + } +} + +func TestInvestigateModel_QueryCmd_NilFn_ReturnsEmptyResult(t *testing.T) { + r := newTestRouter() // no store → timeWindowFn is nil + inv := newInvestigateModel() + start := time.Now().Add(-time.Hour) + end := time.Now() + + cmd := inv.queryCmd(&r, start, end) + if cmd == nil { + t.Fatal("queryCmd should return a Cmd even with nil timeWindowFn") + } + msg := cmd() + result, ok := msg.(investigateResultMsg) + if !ok { + t.Fatalf("expected investigateResultMsg, got %T", msg) + } + if result.err != nil { + t.Errorf("nil timeWindowFn: unexpected error %v", result.err) + } + if len(result.results) != 0 { + t.Errorf("nil timeWindowFn: expected empty results, got %d", len(result.results)) + } +} + +// errTest is a sentinel error for tests that need a non-nil error. +var errTest = errors.New("test error") diff --git a/internal/tui/keys.go b/internal/tui/keys.go index 457a530..63e6835 100644 --- a/internal/tui/keys.go +++ b/internal/tui/keys.go @@ -9,28 +9,34 @@ import "github.com/charmbracelet/bubbles/key" // // TUI Vocabulary: // -// Global: q=quit, Esc=back/cancel +// Global: q=quit, Esc=back/cancel, S=summary, I=investigate, D=diff // Table mode: p=pause, f or /=filter, a=toggle all/fp-only, // A=alert list, E=export, Enter=detail, j/k=scroll, // G=cycle group, B=toggle bookmark // Detail mode: Esc=back, j/k=scroll, r=toggle raw FPs, E=export // Alert mode: Esc=back, j/k=scroll, Enter=investigate (jump to detail) +// Summary mode: Esc=back to table +// Investigate: Esc=back, Enter=run query, r=new query, j/k=scroll +// Diff mode: Esc=back/prev-step, Enter=next/run, r=new diff, j/k=scroll type keyMap struct { - Quit key.Binding - Filter key.Binding - Search key.Binding - Pause key.Binding - Up key.Binding - Down key.Binding - ToggleAll key.Binding - Enter key.Binding - Back key.Binding - AlertList key.Binding - Export key.Binding - ToggleRaw key.Binding - HistoryTab key.Binding - Group key.Binding - Bookmark key.Binding + Quit key.Binding + Filter key.Binding + Search key.Binding + Pause key.Binding + Up key.Binding + Down key.Binding + ToggleAll key.Binding + Enter key.Binding + Back key.Binding + AlertList key.Binding + Export key.Binding + ToggleRaw key.Binding + HistoryTab key.Binding + Group key.Binding + Bookmark key.Binding + Summary key.Binding + Investigate key.Binding + Diff key.Binding } var keys = keyMap{ @@ -94,4 +100,16 @@ var keys = keyMap{ key.WithKeys("B"), key.WithHelp("B", "bookmark"), ), + Summary: key.NewBinding( + key.WithKeys("S"), + key.WithHelp("S", "summary"), + ), + Investigate: key.NewBinding( + key.WithKeys("I"), + key.WithHelp("I", "investigate"), + ), + Diff: key.NewBinding( + key.WithKeys("D"), + key.WithHelp("D", "diff"), + ), } diff --git a/internal/tui/router.go b/internal/tui/router.go index 6b81aa8..9f04aaa 100644 --- a/internal/tui/router.go +++ b/internal/tui/router.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/Crank-Git/ja4monitor/internal/anomaly" @@ -21,24 +22,50 @@ import ( // // ┌─────────────┐ A ┌─────────────┐ // │ │ ─────► │ │ -// │ TableMode │ │ AlertMode │ +// │ TableMode │ │ AlertMode │ // │ │ ◄───── │ │ // └─────────────┘ Esc └─────────────┘ // │ ▲ │ -// │ │ │ -// ↵ │ │ Esc │ ↵ +// │ │ │ ↵ (Enter on selected alert) +// ↵ │ │ Esc │ // ▼ │ ▼ // ┌─────────────┐ ┌─────────────┐ // │ │ │ │ -// │ DetailMode │ ◄───── │ DetailMode │ +// │ DetailMode │ ◄──────│ DetailMode │ // │ │ │ │ // └─────────────┘ └─────────────┘ +// │ Esc │ Esc +// ▼ ▼ +// (TableMode) (TableMode) +// +// S/I/D are global shortcuts (suppressed in text-input states): +// +// S (global) +// ──────────────────────────► ┌──────────────┐ +// │ SummaryMode │ +// ◄────────────────────────── │ │ +// Esc / S └──────────────┘ +// +// I (global) +// ──────────────────────────► ┌─────────────────┐ +// │ InvestigateMode │ +// ◄────────────────────────── │ (input→results) │ +// Esc / I └─────────────────┘ +// +// D (global) +// ──────────────────────────► ┌──────────────────┐ +// │ DiffMode │ +// ◄────────────────────────── │ (B→A→results) │ +// Esc / D └──────────────────┘ type sessionState int const ( stateTable sessionState = iota stateDetail stateAlertList + stateSummary + stateInvestigate + stateDiff ) // ticker fires every 100ms to drive PPS updates and header stats. @@ -75,22 +102,35 @@ type EngineFunc func() (connCount int, fpCount int, isLearning bool, learnRemain // nil if no store is attached (daemon attach mode). type StoreQueryFunc func(f filter.Filter, timeout time.Duration) (conns []*tracker.Connection, timedOut bool, err error) +// StoreTimeWindowFunc queries connections whose lifetimes overlap [start, end]. +// nil if no store is attached (daemon attach mode). +type StoreTimeWindowFunc func(start, end time.Time, timeout time.Duration) (conns []*tracker.Connection, timedOut bool, err error) + +// StoreDiffFunc counts fingerprint occurrences in two adjacent time windows. +// nil if no store is attached (daemon attach mode). +type StoreDiffFunc func(aStart, aEnd, bStart, bEnd time.Time, timeout time.Duration) (aMap, bMap map[string]map[string]int, err error) + // Router is the top-level Bubbletea model. It holds shared state and // delegates Update/View to the active mode's child model. type Router struct { // Child models - table TableModel - detail DetailModel - alertList AlertListModel - state sessionState + table TableModel + detail DetailModel + alertList AlertListModel + summary SummaryModel + investigate InvestigateModel + diff DiffModel + state sessionState // Shared state (reads and writes funneled through router.Update) - connChan chan *tracker.Connection - alertChan <-chan anomaly.Alert - statsFn StatsFunc - engineFn EngineFunc - queryFn StoreQueryFunc // may be nil - ifaceName string + connChan chan *tracker.Connection + alertChan <-chan anomaly.Alert + statsFn StatsFunc + engineFn EngineFunc + queryFn StoreQueryFunc // may be nil + timeWindowFn StoreTimeWindowFunc // may be nil + diffFn StoreDiffFunc // may be nil + ifaceName string // Connection map: canonical live view. Only holds live connections. // Historical query results live in a separate slice so toggling @@ -128,6 +168,11 @@ type Router struct { setBookmarkFn func(id string, on bool) error loadBookmarksFn func() (map[string]bool, error) + // ruleStatsFn returns fire counts for active custom rules. Used by the + // summary dashboard. nil when no custom rules are configured (e.g. daemon + // attach mode). Set via WithRuleStatsFn after construction. + ruleStatsFn func() []anomaly.RuleStat + // Status message and expiration (for export success/error feedback) statusMsg string statusExpiry time.Time @@ -159,12 +204,25 @@ func NewRouter( } if store != nil { r.queryFn = store.Query + r.timeWindowFn = store.QueryTimeWindow + r.diffFn = store.DiffWindows r.setBookmarkFn = store.SetBookmark r.loadBookmarksFn = store.GetBookmarks } r.table = newTableModel(&r) r.detail = newDetailModel(&r) r.alertList = newAlertListModel(&r) + r.summary = newSummaryModel() + r.investigate = newInvestigateModel() + r.diff = newDiffModel() + return r +} + +// WithRuleStatsFn wires the custom rules stats source for the summary dashboard. +// Returns the modified router — use as: router = tui.NewRouter(...).WithRuleStatsFn(fn) +// Passing nil is valid: the summary view omits the custom rules section. +func (r Router) WithRuleStatsFn(fn func() []anomaly.RuleStat) Router { + r.ruleStatsFn = fn return r } @@ -274,6 +332,9 @@ func (r Router) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case tickMsg: r.table.updateSparks() r.table.updatePPS(r.statsFn()) + if r.state == stateSummary { + r.summary.onTick(&r) + } return r, tickCmd() case connUpdateMsg: @@ -303,6 +364,14 @@ func (r Router) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } return r, r.listenForAlerts() + case investigateResultMsg: + r.investigate, _ = r.investigate.Update(m, &r) + return r, nil + + case diffResultMsg: + r.diff, _ = r.diff.Update(m, &r) + return r, nil + case bookmarksMsg: r.bookmarks = m.set return r, nil @@ -330,6 +399,42 @@ func (r Router) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if m.String() == "ctrl+c" { return r, tea.Quit } + // S/I/D are global shortcuts but must not fire when the active mode + // is in raw text-input state (filter bar, diff duration prompts, + // investigate window prompt). In those contexts, S/I/D should be + // appended to the input buffer, not treated as navigation commands. + if !r.inTextInput() { + // S: switch to summary from any mode, or back to table. + if key.Matches(m, keys.Summary) { + if r.state == stateSummary { + r.state = stateTable + } else { + r.summary.onTick(&r) // seed cache immediately on entry + r.state = stateSummary + } + return r, nil + } + // I: open the time-window investigation view, or close it. + if key.Matches(m, keys.Investigate) { + if r.state == stateInvestigate { + r.state = stateTable + } else { + r.investigate = newInvestigateModel() // fresh input state on entry + r.state = stateInvestigate + } + return r, nil + } + // D: open the fingerprint diff view, or close it. + if key.Matches(m, keys.Diff) { + if r.state == stateDiff { + r.state = stateTable + } else { + r.diff = newDiffModel() // fresh input state on entry + r.state = stateDiff + } + return r, nil + } + } // Delegate key handling to the active mode's child model. var cmd tea.Cmd switch r.state { @@ -339,6 +444,12 @@ func (r Router) Update(msg tea.Msg) (tea.Model, tea.Cmd) { r.detail, cmd = r.detail.Update(m, &r) case stateAlertList: r.alertList, cmd = r.alertList.Update(m, &r) + case stateSummary: + r.summary, cmd = r.summary.Update(m, &r) + case stateInvestigate: + r.investigate, cmd = r.investigate.Update(m, &r) + case stateDiff: + r.diff, cmd = r.diff.Update(m, &r) } if cmd != nil { cmds = append(cmds, cmd) @@ -358,11 +469,35 @@ func (r Router) View() string { return r.detail.View(&r) case stateAlertList: return r.alertList.View(&r) + case stateSummary: + return r.summary.View(&r) + case stateInvestigate: + return r.investigate.View(&r) + case stateDiff: + return r.diff.View(&r) default: return r.table.View(&r) } } +// inTextInput returns true when the active mode is currently accepting +// raw text input (filter bar, export prompt, diff duration inputs, +// investigate time-window input). Global key shortcuts (S/I/D) are +// suppressed in this state so the user can type those characters freely +// without accidentally switching TUI modes. +func (r *Router) inTextInput() bool { + if r.filterMode || r.table.exportPrompt { + return true + } + if r.state == stateDiff { + return r.diff.subState == diffInputB || r.diff.subState == diffInputA + } + if r.state == stateInvestigate { + return r.investigate.subState == investigateInput + } + return false +} + // setStatus sets a temporary status message shown in the footer. After // 3 seconds it clears automatically (checked during rendering). func (r *Router) setStatus(msg string) { diff --git a/internal/tui/router_test.go b/internal/tui/router_test.go index 4e45ef6..ff9916a 100644 --- a/internal/tui/router_test.go +++ b/internal/tui/router_test.go @@ -148,3 +148,85 @@ func TestRouter_LoadBookmarksCmdNilWhenNoStore(t *testing.T) { t.Error("loadBookmarksCmd should return nil when no store is wired") } } + +// ── inTextInput guard for global S/I/D shortcuts ───────────────────────────── + +// TestRouter_GlobalSKeyBlockedInFilterMode verifies that pressing 'S' while +// r.filterMode is true does not switch to stateSummary. Before the fix, the +// global S handler ran before checking filterMode, so the character was eaten +// by mode-switching instead of being appended to the filter buffer. +func TestRouter_GlobalSKeyBlockedInFilterMode(t *testing.T) { + r := newTestRouter() + r.filterMode = true // simulates open filter bar + + msg := tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("S")} + updated, _ := r.Update(msg) + r2 := updated.(Router) + + if r2.state != stateTable { + t.Errorf("S in filterMode: state = %v, want stateTable (global key must be suppressed)", r2.state) + } + if !r2.filterMode { + t.Error("S in filterMode: filterMode was cleared, expected it to remain open") + } +} + +// TestRouter_GlobalIKeyBlockedInFilterMode verifies the same for 'I'. +func TestRouter_GlobalIKeyBlockedInFilterMode(t *testing.T) { + r := newTestRouter() + r.filterMode = true + + msg := tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("I")} + updated, _ := r.Update(msg) + r2 := updated.(Router) + + if r2.state == stateInvestigate { + t.Error("I in filterMode: switched to stateInvestigate, expected state unchanged") + } +} + +// TestRouter_GlobalDKeyBlockedInFilterMode verifies the same for 'D'. +func TestRouter_GlobalDKeyBlockedInFilterMode(t *testing.T) { + r := newTestRouter() + r.filterMode = true + + msg := tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("D")} + updated, _ := r.Update(msg) + r2 := updated.(Router) + + if r2.state == stateDiff { + t.Error("D in filterMode: switched to stateDiff, expected state unchanged") + } +} + +// TestRouter_GlobalSKeyBlockedInDiffInput verifies that 'S' is also suppressed +// when the diff model is in its duration-input step (diffInputB), so the user +// can type 'S' as part of e.g. "30s" without accidentally switching modes. +func TestRouter_GlobalSKeyBlockedInDiffInput(t *testing.T) { + r := newTestRouter() + r.state = stateDiff + // diff starts in diffInputB by default (zero value) + + msg := tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("S")} + updated, _ := r.Update(msg) + r2 := updated.(Router) + + if r2.state != stateDiff { + t.Errorf("S in diff input: state = %v, want stateDiff", r2.state) + } +} + +// TestRouter_GlobalSKeyWorksWhenNotInTextInput confirms the global 'S' handler +// fires normally when neither filterMode nor any input subState is active. +func TestRouter_GlobalSKeyWorksWhenNotInTextInput(t *testing.T) { + r := newTestRouter() + // default state: stateTable, filterMode=false, diff/investigate at step 0 but state ≠ those modes + + msg := tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("S")} + updated, _ := r.Update(msg) + r2 := updated.(Router) + + if r2.state != stateSummary { + t.Errorf("S in table mode: state = %v, want stateSummary", r2.state) + } +} diff --git a/internal/tui/styles.go b/internal/tui/styles.go index 80fc48c..a6fea70 100644 --- a/internal/tui/styles.go +++ b/internal/tui/styles.go @@ -76,6 +76,19 @@ var ( Foreground(lipgloss.AdaptiveColor{Light: "#000000", Dark: "#dadada"}). Underline(true) + // Diff mode change-category styles. + diffNewStyle = lipgloss.NewStyle(). + Bold(true). + Foreground(lipgloss.AdaptiveColor{Light: "#008700", Dark: "#87d787"}) + diffGoneStyle = lipgloss.NewStyle(). + Bold(true). + Foreground(lipgloss.AdaptiveColor{Light: "#af0000", Dark: "#ff5f5f"}) + diffSpikedStyle = lipgloss.NewStyle(). + Bold(true). + Foreground(lipgloss.AdaptiveColor{Light: "#af8700", Dark: "#ffd75f"}) + diffDroppedStyle = lipgloss.NewStyle(). + Foreground(lipgloss.AdaptiveColor{Light: "#005faf", Dark: "#5fafff"}) + // severityStyles maps anomaly severity strings to rendered colors. severityCriticalStyle = lipgloss.NewStyle(). Bold(true). diff --git a/internal/tui/summary.go b/internal/tui/summary.go new file mode 100644 index 0000000..a03ff79 --- /dev/null +++ b/internal/tui/summary.go @@ -0,0 +1,228 @@ +package tui + +import ( + "fmt" + "sort" + "strings" + "time" + + "github.com/charmbracelet/bubbles/key" + tea "github.com/charmbracelet/bubbletea" + + "github.com/Crank-Git/ja4monitor/internal/anomaly" +) + +// rankEntry is one entry in a top-N aggregation table. +type rankEntry struct { + label string + count int +} + +// SummaryModel renders the network overview dashboard (S key). +// +// Aggregation design: +// +// Top-5 source IPs and top-5 JA4 fingerprints are computed from the +// router's live connMap at most once per second (1s cache). The bubbletea +// event loop is single-threaded, so iterating connMap here is safe — no +// additional locking needed. +// +// Alert severity counts, recent alerts, and engine stats (fp count, +// suppressed count) are derived from router fields on every View call +// (cheap linear scans of small slices), so they are always fresh. +// +// Custom rule fire counts are read from the evaluator via ruleStatsFn, +// also on each View call (atomic loads, negligible cost). +type SummaryModel struct { + lastRefresh time.Time + topSrcIPs []rankEntry + topJA4s []rankEntry +} + +func newSummaryModel() SummaryModel { return SummaryModel{} } + +// onTick is called by the router's tickMsg handler when the summary view is +// active. Refreshes the cached top-5 aggregations if more than 1 second has +// elapsed since the last refresh. +func (s *SummaryModel) onTick(r *Router) { + if time.Since(s.lastRefresh) < time.Second { + return + } + s.lastRefresh = time.Now() + s.topSrcIPs = buildTopSrcIPs(r, 5) + s.topJA4s = buildTopJA4s(r, 5) +} + +// buildTopSrcIPs counts connections per source IP and returns the top n. +func buildTopSrcIPs(r *Router, n int) []rankEntry { + counts := make(map[string]int, len(r.connMap)) + for _, conn := range r.connMap { + counts[conn.SrcIP]++ + } + return topEntries(counts, n) +} + +// buildTopJA4s counts connections per latest JA4 fingerprint and returns top n. +// Connections with no JA4 fingerprint are skipped. +func buildTopJA4s(r *Router, n int) []rankEntry { + counts := make(map[string]int) + for _, conn := range r.connMap { + if fp := conn.LatestFingerprint("ja4"); fp != "" { + counts[fp]++ + } + } + return topEntries(counts, n) +} + +// topEntries converts a frequency map to a sorted []rankEntry, top n. +func topEntries(counts map[string]int, n int) []rankEntry { + entries := make([]rankEntry, 0, len(counts)) + for label, count := range counts { + entries = append(entries, rankEntry{label: label, count: count}) + } + sort.Slice(entries, func(i, j int) bool { + if entries[i].count != entries[j].count { + return entries[i].count > entries[j].count + } + return entries[i].label < entries[j].label + }) + if len(entries) > n { + entries = entries[:n] + } + return entries +} + +// Update handles summary mode keys and drives the 1s cache refresh. +func (s SummaryModel) Update(msg tea.Msg, r *Router) (SummaryModel, tea.Cmd) { + switch m := msg.(type) { + case tickMsg: + s.onTick(r) + case tea.KeyMsg: + switch { + case key.Matches(m, keys.Quit): + return s, tea.Quit + case key.Matches(m, keys.Back), key.Matches(m, keys.Summary): + r.state = stateTable + } + } + return s, nil +} + +// View renders the full-screen summary dashboard. +func (s SummaryModel) View(r *Router) string { + var b strings.Builder + + // Header + _, fpCount, isLearning, learnRemaining, _, suppressedCount := r.engineFn() + hdr := fmt.Sprintf(" SUMMARY %s", r.ifaceName) + if isLearning { + hdr += " " + learningStyle.Render(fmt.Sprintf("LEARNING (%dm)", int(learnRemaining.Minutes()))) + } + b.WriteString(headerStyle.Width(r.width).Render(hdr)) + b.WriteString("\n") + + // ── Network stats bar ────────────────────────────────────── + connCount := len(r.connMap) + alertCount := len(r.alerts) + b.WriteString("\n") + statsLine := fmt.Sprintf(" Connections: %d Fingerprints: %d Alerts: %d", + connCount, fpCount, alertCount) + if suppressedCount > 0 { + statsLine += " " + learningStyle.Render(fmt.Sprintf("Suppressed: %d", suppressedCount)) + } + b.WriteString(statsLine) + b.WriteString("\n\n") + + // ── Alert severity breakdown ──────────────────────────────── + b.WriteString(sectionHeaderStyle.Render("ALERTS BY SEVERITY")) + b.WriteString("\n") + var crit, high, med, low int + for _, a := range r.alerts { + switch a.Severity { + case anomaly.SeverityCritical: + crit++ + case anomaly.SeverityHigh: + high++ + case anomaly.SeverityMedium: + med++ + case anomaly.SeverityLow: + low++ + } + } + b.WriteString(fmt.Sprintf(" %s %s %s %s\n", + severityCriticalStyle.Render(fmt.Sprintf("Critical: %d", crit)), + severityHighStyle.Render(fmt.Sprintf("High: %d", high)), + severityMediumStyle.Render(fmt.Sprintf("Medium: %d", med)), + severityLowStyle.Render(fmt.Sprintf("Low: %d", low)), + )) + b.WriteString("\n") + + // ── Top source IPs ────────────────────────────────────────── + b.WriteString(sectionHeaderStyle.Render("TOP SOURCE IPs")) + b.WriteString("\n") + if len(s.topSrcIPs) == 0 { + b.WriteString(" (no connections yet)\n") + } else { + for i, e := range s.topSrcIPs { + b.WriteString(fmt.Sprintf(" %d. %-40s %d\n", i+1, e.label, e.count)) + } + } + b.WriteString("\n") + + // ── Top JA4 fingerprints ──────────────────────────────────── + b.WriteString(sectionHeaderStyle.Render("TOP JA4 FINGERPRINTS")) + b.WriteString("\n") + if len(s.topJA4s) == 0 { + b.WriteString(" (no JA4 fingerprints seen)\n") + } else { + for i, e := range s.topJA4s { + b.WriteString(fmt.Sprintf(" %d. %-50s %d\n", i+1, truncate(e.label, 50), e.count)) + } + } + b.WriteString("\n") + + // ── Custom rules ──────────────────────────────────────────── + if r.ruleStatsFn != nil { + if stats := r.ruleStatsFn(); len(stats) > 0 { + b.WriteString(sectionHeaderStyle.Render("CUSTOM RULES (since start)")) + b.WriteString("\n") + for _, rs := range stats { + b.WriteString(fmt.Sprintf(" %-40s fired: %d\n", rs.Name, rs.FireCount)) + } + b.WriteString("\n") + } + } + + // ── Recent alerts ─────────────────────────────────────────── + b.WriteString(sectionHeaderStyle.Render("RECENT ALERTS")) + b.WriteString("\n") + recent := r.alerts + if len(recent) == 0 { + b.WriteString(" All clear.\n") + } else { + // Show last 10 in reverse-chron order (newest first). + start := len(recent) - 10 + if start < 0 { + start = 0 + } + for i := len(recent) - 1; i >= start; i-- { + a := recent[i] + row := fmt.Sprintf(" %s %s %-28s %s:%d → %s:%d", + a.Timestamp.Format("15:04:05"), + renderSeverity(a.Severity), + truncate(a.Rule, 28), + a.SrcIP, a.SrcPort, + a.DstIP, a.DstPort, + ) + b.WriteString(row) + b.WriteString("\n") + } + } + + footer := " [S/Esc]back to table [q]quit" + if status := r.currentStatus(); status != "" { + footer = " " + status + return finalize(r, b.String(), learningStyle.Render(footer)) + } + return finalize(r, b.String(), footerStyle.Render(footer)) +} diff --git a/internal/tui/summary_test.go b/internal/tui/summary_test.go new file mode 100644 index 0000000..615ffce --- /dev/null +++ b/internal/tui/summary_test.go @@ -0,0 +1,337 @@ +package tui + +import ( + "testing" + "time" + + tea "github.com/charmbracelet/bubbletea" + + "github.com/Crank-Git/ja4monitor/internal/anomaly" + "github.com/Crank-Git/ja4monitor/internal/tracker" +) + +// newTestRouterWithConns returns a router pre-populated with connections. +func newTestRouterWithConns(conns []*tracker.Connection) Router { + r := newTestRouter() + for _, c := range conns { + r.connMap[c.ID] = c + } + return r +} + +// makeTestConn creates a minimal connection for tests. +func makeTestConn(id, srcIP, dstIP string, dstPort uint16, ja4 string) *tracker.Connection { + conn := tracker.NewConnection(srcIP, 12345, dstIP, dstPort, "tcp", time.Now()) + if ja4 != "" { + conn.AddFingerprint("ja4", ja4, time.Now(), 1) + } + return conn +} + +// ── topEntries ────────────────────────────────────────────────────── + +func TestTopEntries_BasicOrder(t *testing.T) { + counts := map[string]int{ + "b": 3, + "a": 10, + "c": 1, + } + entries := topEntries(counts, 5) + if len(entries) != 3 { + t.Fatalf("expected 3 entries, got %d", len(entries)) + } + if entries[0].label != "a" || entries[0].count != 10 { + t.Errorf("first entry should be a/10, got %s/%d", entries[0].label, entries[0].count) + } + if entries[1].label != "b" { + t.Errorf("second entry should be b, got %s", entries[1].label) + } +} + +func TestTopEntries_CappedAtN(t *testing.T) { + counts := map[string]int{"a": 5, "b": 4, "c": 3, "d": 2, "e": 1, "f": 6} + entries := topEntries(counts, 3) + if len(entries) != 3 { + t.Fatalf("expected 3 entries (capped), got %d", len(entries)) + } +} + +func TestTopEntries_TieBreakAlphabetical(t *testing.T) { + counts := map[string]int{"z": 5, "a": 5} + entries := topEntries(counts, 5) + // Same count → alphabetical: "a" before "z" + if entries[0].label != "a" { + t.Errorf("tie-break: expected 'a' first, got %q", entries[0].label) + } +} + +func TestTopEntries_Empty(t *testing.T) { + entries := topEntries(map[string]int{}, 5) + if len(entries) != 0 { + t.Fatalf("expected 0 entries, got %d", len(entries)) + } +} + +// ── buildTopSrcIPs ────────────────────────────────────────────────── + +func TestBuildTopSrcIPs(t *testing.T) { + // Two connections from 10.0.0.1, one from 10.0.0.2. + conns := []*tracker.Connection{ + makeTestConn("c1", "10.0.0.1", "1.1.1.1", 443, ""), + makeTestConn("c2", "10.0.0.1", "2.2.2.2", 443, ""), + makeTestConn("c3", "10.0.0.2", "3.3.3.3", 80, ""), + } + r := newTestRouterWithConns(conns) + + entries := buildTopSrcIPs(&r, 5) + if len(entries) != 2 { + t.Fatalf("expected 2 unique src IPs, got %d", len(entries)) + } + if entries[0].label != "10.0.0.1" || entries[0].count != 2 { + t.Errorf("top IP: got %s/%d, want 10.0.0.1/2", entries[0].label, entries[0].count) + } +} + +func TestBuildTopSrcIPs_EmptyConnMap(t *testing.T) { + r := newTestRouter() + entries := buildTopSrcIPs(&r, 5) + if len(entries) != 0 { + t.Fatalf("expected 0 entries for empty connMap, got %d", len(entries)) + } +} + +// ── buildTopJA4s ──────────────────────────────────────────────────── + +func TestBuildTopJA4s(t *testing.T) { + conns := []*tracker.Connection{ + makeTestConn("c1", "10.0.0.1", "1.1.1.1", 443, "t13d_aaa"), + makeTestConn("c2", "10.0.0.2", "1.1.1.1", 443, "t13d_aaa"), + makeTestConn("c3", "10.0.0.3", "1.1.1.1", 443, "t13d_bbb"), + } + r := newTestRouterWithConns(conns) + + entries := buildTopJA4s(&r, 5) + if entries[0].label != "t13d_aaa" || entries[0].count != 2 { + t.Errorf("top JA4: got %s/%d, want t13d_aaa/2", entries[0].label, entries[0].count) + } +} + +func TestBuildTopJA4s_SkipsNoFingerprint(t *testing.T) { + // Connections without JA4 fingerprints should not appear. + conns := []*tracker.Connection{ + makeTestConn("c1", "10.0.0.1", "1.1.1.1", 443, ""), + makeTestConn("c2", "10.0.0.2", "1.1.1.1", 443, "t13d_aaa"), + } + r := newTestRouterWithConns(conns) + + entries := buildTopJA4s(&r, 5) + if len(entries) != 1 { + t.Fatalf("expected 1 JA4 entry (no-FP conn skipped), got %d", len(entries)) + } + if entries[0].label != "t13d_aaa" { + t.Errorf("expected t13d_aaa, got %q", entries[0].label) + } +} + +// ── SummaryModel.onTick ───────────────────────────────────────────── + +func TestSummaryModel_OnTick_RefreshesAfter1s(t *testing.T) { + conns := []*tracker.Connection{ + makeTestConn("c1", "10.0.0.1", "1.1.1.1", 443, "t13d_test"), + } + r := newTestRouterWithConns(conns) + s := newSummaryModel() + + // First tick with zero lastRefresh should always refresh. + s.onTick(&r) + if len(s.topSrcIPs) != 1 { + t.Fatalf("expected topSrcIPs populated after first tick, got %d", len(s.topSrcIPs)) + } + if s.topSrcIPs[0].label != "10.0.0.1" { + t.Errorf("expected 10.0.0.1, got %q", s.topSrcIPs[0].label) + } +} + +func TestSummaryModel_OnTick_NoRefreshWithin1s(t *testing.T) { + r := newTestRouter() + s := newSummaryModel() + s.lastRefresh = time.Now() // just refreshed + + // Add a connection after setting lastRefresh — it should NOT appear yet. + conn := makeTestConn("c1", "192.168.1.1", "1.1.1.1", 443, "") + r.connMap[conn.ID] = conn + + s.onTick(&r) + // topSrcIPs should still be empty (cache not expired) + if len(s.topSrcIPs) != 0 { + t.Fatalf("expected no refresh within 1s, but got %d entries", len(s.topSrcIPs)) + } +} + +// ── View smoke test ───────────────────────────────────────────────── + +func TestSummaryModel_View_NoConnections(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + s := newSummaryModel() + + view := s.View(&r) + if view == "" { + t.Fatal("View returned empty string") + } + if !contains(view, "SUMMARY") { + t.Error("View should contain 'SUMMARY' header") + } + if !contains(view, "no connections") { + t.Error("View should show empty-state message for top IPs") + } +} + +func TestSummaryModel_View_WithAlerts(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + r.alerts = []anomaly.Alert{ + {Rule: "new_fingerprint", Severity: anomaly.SeverityHigh, Timestamp: time.Now(), + SrcIP: "10.0.0.1", DstIP: "1.2.3.4", DstPort: 443}, + {Rule: "ua_mismatch", Severity: anomaly.SeverityMedium, Timestamp: time.Now(), + SrcIP: "10.0.0.2", DstIP: "5.6.7.8", DstPort: 443}, + } + s := newSummaryModel() + view := s.View(&r) + + if !contains(view, "new_fingerprint") { + t.Error("View should show recent alerts") + } +} + +func TestSummaryModel_View_WithRuleStats(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + r.ruleStatsFn = func() []anomaly.RuleStat { + return []anomaly.RuleStat{ + {Name: "port_scan_detector", FireCount: 3}, + {Name: "beaconing_detector", FireCount: 0}, + } + } + s := newSummaryModel() + view := s.View(&r) + + if !contains(view, "port_scan_detector") { + t.Error("View should show custom rule names") + } + if !contains(view, "fired: 3") { + t.Error("View should show fire count") + } +} + +func TestSummaryModel_View_NilRuleStatsFn(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + // ruleStatsFn is nil (attach mode, no evaluator) + s := newSummaryModel() + view := s.View(&r) + // Should render without panicking, just omit the CUSTOM RULES section. + if view == "" { + t.Fatal("View should not be empty even with nil ruleStatsFn") + } +} + +// ── State machine: S key ──────────────────────────────────────────── + +func TestRouter_SKey_EntersSummary(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + + updated, _ := r.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("S")}) + r2 := updated.(Router) + if r2.state != stateSummary { + t.Errorf("S key: expected stateSummary, got %v", r2.state) + } +} + +func TestRouter_SKey_TogglesBackToTable(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + r.state = stateSummary + + updated, _ := r.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("S")}) + r2 := updated.(Router) + if r2.state != stateTable { + t.Errorf("S key in summary: expected stateTable, got %v", r2.state) + } +} + +func TestRouter_SKey_FromAlertMode(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + r.state = stateAlertList + + updated, _ := r.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("S")}) + r2 := updated.(Router) + if r2.state != stateSummary { + t.Errorf("S key from alert list: expected stateSummary, got %v", r2.state) + } +} + +func TestRouter_SummaryEscReturnsToTable(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + r.state = stateSummary + + updated, _ := r.Update(tea.KeyMsg{Type: tea.KeyEsc}) + r2 := updated.(Router) + if r2.state != stateTable { + t.Errorf("Esc in summary: expected stateTable, got %v", r2.state) + } +} + +func TestRouter_WithRuleStatsFn(t *testing.T) { + called := false + fn := func() []anomaly.RuleStat { + called = true + return nil + } + r := newTestRouter().WithRuleStatsFn(fn) + if r.ruleStatsFn == nil { + t.Fatal("ruleStatsFn should be set") + } + r.ruleStatsFn() + if !called { + t.Error("ruleStatsFn was not the function we set") + } +} + +// ── View renders for stateSummary ─────────────────────────────────── + +func TestRouter_ViewSummary(t *testing.T) { + r := newTestRouter() + r.width = 120 + r.height = 40 + r.state = stateSummary + + view := r.View() + if !contains(view, "SUMMARY") { + t.Error("Router.View() in summary state should render summary view") + } +} + +// contains is a helper for substring checks in view tests. +func contains(s, sub string) bool { + return len(s) >= len(sub) && (s == sub || len(sub) == 0 || + func() bool { + for i := 0; i <= len(s)-len(sub); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false + }()) +}