From 3c3c7949b3ebe642a95cc4365d4bb1fecad579ab Mon Sep 17 00:00:00 2001 From: Arif Dogan Date: Fri, 22 May 2026 01:14:49 +0200 Subject: [PATCH] fix: dashboard gates, body caps, store ordering, lifecycle Dashboard: gate /api/replay and /api/system-info behind opt-in flags and add optional auth/loopback checks. Replay validates the target IP inside DialContext to close the LookupIP-then-dial TOCTOU and normalizes IPv4-mapped IPv6. Switch system-info env exposure from a substring denylist to an explicit allowlist. Middleware: cap captured request/response bodies at 1 MiB by default, scrub Authorization/Cookie/Set-Cookie/X-Api-Key from stored headers, and forward Flusher/Hijacker/Pusher through the writer. Storage: use crypto/rand for request IDs (timestamp-based IDs collided under load), align in-memory ordering with SQL/Mongo/Redis (newest first), add bson tags so Mongo timestamp sort hits the right field, validate SQL table names, batch cleanup every 32 inserts, and let Store.Add return an error. Drop the package-level signal handler that called os.Exit. Lifetime is now opt-in via WithShutdownContext. Reverts the centralized approach from #28. --- internal/dashboard/handler.go | 370 +++++++++++++++++-------- internal/middleware/middleware.go | 161 ++++++++--- internal/middleware/middleware_test.go | 3 +- internal/middleware/profiling.go | 125 +++------ internal/middleware/tracer.go | 20 +- internal/model/request.go | 73 +++-- internal/profiling/profiler.go | 146 ++++++---- internal/store/memory.go | 54 ++-- internal/store/mongodb.go | 176 ++++++------ internal/store/postgres.go | 134 ++++----- internal/store/redis.go | 194 ++++++------- internal/store/sqlite.go | 182 +++++------- internal/store/store.go | 22 +- options.go | 151 +++++++++- wrap.go | 192 ++++++------- 15 files changed, 1167 insertions(+), 836 deletions(-) diff --git a/internal/dashboard/handler.go b/internal/dashboard/handler.go index b05224a..48e8948 100644 --- a/internal/dashboard/handler.go +++ b/internal/dashboard/handler.go @@ -1,12 +1,16 @@ package dashboard import ( + "context" "embed" "encoding/json" + "errors" "fmt" "io" "io/fs" + "net" "net/http" + "net/url" "os" "runtime" "strings" @@ -19,28 +23,50 @@ import ( //go:embed static/* var staticFiles embed.FS +// HandlerOptions controls which side-channel endpoints the dashboard exposes. +// The defaults are deliberately restrictive: replay and system-info are +// disabled because they are SSRF / information-disclosure primitives when the +// dashboard is reachable by an attacker. +type HandlerOptions struct { + // EnableReplay opens POST /api/replay. + EnableReplay bool + // ExposeSystemInfo opens GET /api/system-info. + ExposeSystemInfo bool + // ExposeEnvVars is the explicit allowlist of env var names the + // system-info endpoint will surface. Anything not in this set is omitted + // entirely so an attacker cannot infer existence. + ExposeEnvVars []string +} + // Handler is the HTTP handler for the dashboard type Handler struct { - store store.Store - profiler *profiling.Profiler - staticFS fs.FS + store store.Store + profiler *profiling.Profiler + staticFS fs.FS + opts HandlerOptions + envAllowSet map[string]struct{} } // NewHandler creates a new dashboard handler -func NewHandler(store store.Store, profiler *profiling.Profiler) *Handler { - // Create file system for static files +func NewHandler(store store.Store, profiler *profiling.Profiler, opts HandlerOptions) *Handler { staticFS, _ := fs.Sub(staticFiles, "static") + envSet := make(map[string]struct{}, len(opts.ExposeEnvVars)) + for _, k := range opts.ExposeEnvVars { + envSet[k] = struct{}{} + } + return &Handler{ - store: store, - profiler: profiler, - staticFS: staticFS, + store: store, + profiler: profiler, + staticFS: staticFS, + opts: opts, + envAllowSet: envSet, } } // ServeHTTP implements the http.Handler interface func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // API endpoints if strings.HasPrefix(r.URL.Path, "/api/") { switch r.URL.Path { case "/api/requests": @@ -52,6 +78,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { case "/api/compare": h.handleCompareRequests(w, r) case "/api/replay": + if !h.opts.EnableReplay { + http.Error(w, "replay disabled", http.StatusNotFound) + return + } h.handleReplayRequest(w, r) case "/api/metrics": h.handleMetrics(w, r) @@ -60,6 +90,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { case "/api/bottlenecks": h.handleBottlenecks(w, r) case "/api/system-info": + if !h.opts.ExposeSystemInfo { + http.Error(w, "system-info disabled", http.StatusNotFound) + return + } h.handleSystemInfo(w, r) default: w.WriteHeader(http.StatusNotFound) @@ -68,19 +102,15 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // Determine the file to serve filePath := r.URL.Path if filePath == "/" || filePath == "" { filePath = "index.html" } else { - // Remove leading slash for fs.Open filePath = strings.TrimPrefix(filePath, "/") } - // Try to open the file from embedded FS file, err := h.staticFS.Open(filePath) if err != nil { - // Try index.html as fallback for SPA routing file, err = h.staticFS.Open("index.html") if err != nil { http.NotFound(w, r) @@ -90,14 +120,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } defer file.Close() - // Get file info for content type stat, err := file.Stat() if err != nil { http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } - // Set content type based on file extension switch { case strings.HasSuffix(filePath, ".html"): w.Header().Set("Content-Type", "text/html; charset=utf-8") @@ -107,11 +135,9 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/css; charset=utf-8") } - // Serve the file content http.ServeContent(w, r, filePath, stat.ModTime(), file.(io.ReadSeeker)) } -// handleAPIRequests serves the JSON API for requests func (h *Handler) handleAPIRequests(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") requests := h.store.GetAll() @@ -120,27 +146,24 @@ func (h *Handler) handleAPIRequests(w http.ResponseWriter, r *http.Request) { encoder.Encode(requests) } -// handleClearRequests clears all the stored requests func (h *Handler) handleClearRequests(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { w.WriteHeader(http.StatusMethodNotAllowed) return } - - // Clear the requests in the store if err := h.store.Clear(); err != nil { http.Error(w, "Error clearing requests", http.StatusInternalServerError) return } - - // In a real implementation, we would clear the store - // For now just respond with success w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write([]byte(`{"success":true}`)) } -// handleSSE handles Server-Sent Events for live updates +// handleSSE streams updates as Server-Sent Events. It sends a full snapshot on +// connect and then publishes only the IDs of the most recent requests on each +// tick — clients diff that against what they already have, so the bandwidth +// scales with churn rather than the entire log. func (h *Handler) handleSSE(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") @@ -153,66 +176,130 @@ func (h *Handler) handleSSE(w http.ResponseWriter, r *http.Request) { return } - requests := h.store.GetAll() - data, _ := json.Marshal(requests) - fmt.Fprintf(w, "data: %s\n\n", data) - flusher.Flush() + writeEvent := func(event string, payload interface{}) bool { + data, err := json.Marshal(payload) + if err != nil { + return false + } + if event != "" { + if _, err := fmt.Fprintf(w, "event: %s\n", event); err != nil { + return false + } + } + if _, err := fmt.Fprintf(w, "data: %s\n\n", data); err != nil { + return false + } + flusher.Flush() + return true + } + + if !writeEvent("snapshot", h.store.GetAll()) { + return + } ticker := time.NewTicker(2 * time.Second) defer ticker.Stop() + lastSeen := "" for { select { case <-ticker.C: - requests := h.store.GetAll() - data, _ := json.Marshal(requests) - fmt.Fprintf(w, "data: %s\n\n", data) - flusher.Flush() + latest := h.store.GetLatest(50) + // Find any entries newer than what we last announced. The store + // returns newest-first, so we slice everything before lastSeen. + found := lastSeen == "" + cutoff := len(latest) + for i, l := range latest { + if l.ID == lastSeen { + cutoff = i + found = true + break + } + } + if lastSeen != "" && !found { + // lastSeen is no longer in the store — the user cleared the + // log (or it rolled out of the cap). Resync the client with a + // fresh snapshot so it discards the stale entries. + if !writeEvent("snapshot", latest) { + return + } + if len(latest) > 0 { + lastSeen = latest[0].ID + } else { + lastSeen = "" + } + continue + } + if cutoff == 0 { + // Heartbeat keeps proxies from closing idle connections. + if _, err := io.WriteString(w, ": ping\n\n"); err != nil { + return + } + flusher.Flush() + continue + } + fresh := latest[:cutoff] + if !writeEvent("append", fresh) { + return + } + lastSeen = fresh[0].ID case <-r.Context().Done(): return } } } -// handleCompareRequests serves the JSON API for comparing specific requests +// maxCompareIDs caps how many request IDs a single /api/compare call may +// supply. Without this, a caller could send thousands of IDs and force one +// store.Get per ID — which on SQL backends is a separate round-trip and a +// cheap amplification primitive. +const maxCompareIDs = 32 + func (h *Handler) handleCompareRequests(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - // Get request IDs from query parameters ids := r.URL.Query()["id"] if len(ids) < 2 { http.Error(w, "At least two request IDs are required", http.StatusBadRequest) return } + if len(ids) > maxCompareIDs { + http.Error(w, fmt.Sprintf("too many ids (max %d)", maxCompareIDs), http.StatusBadRequest) + return + } - // Get all requests - allRequests := h.store.GetAll() - - // Filter requests by IDs - compareRequests := []interface{}{} - for _, req := range allRequests { - for _, id := range ids { - if req.ID == id { - compareRequests = append(compareRequests, req) - break - } + // Look each id up directly rather than walking the entire log. + idSet := make(map[string]struct{}, len(ids)) + for _, id := range ids { + idSet[id] = struct{}{} + } + compareRequests := make([]interface{}, 0, len(ids)) + for id := range idSet { + if req, ok := h.store.Get(id); ok { + compareRequests = append(compareRequests, req) } } - // Return the filtered requests encoder := json.NewEncoder(w) encoder.SetEscapeHTML(false) encoder.Encode(compareRequests) } -// handleReplayRequest handles replaying a captured request +// handleReplayRequest replays a captured HTTP request against an arbitrary +// destination. This is a powerful primitive and is therefore opt-in via +// HandlerOptions.EnableReplay. Even when enabled, we deny: +// - non-http(s) schemes (gopher://, file://, ftp://, etc.) +// - hostnames that resolve to loopback / link-local / private / multicast IPs +// +// to mitigate SSRF against cloud metadata services or internal networks. Any +// caller that needs to point replay at internal hosts is expected to manage +// network policy themselves; we will not undo the deny-by-default. func (h *Handler) handleReplayRequest(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { w.WriteHeader(http.StatusMethodNotAllowed) return } - // Parse request body decoder := json.NewDecoder(r.Body) var replayRequest struct { RequestID string `json:"requestId"` @@ -221,54 +308,66 @@ func (h *Handler) handleReplayRequest(w http.ResponseWriter, r *http.Request) { Headers map[string]string `json:"headers"` Body string `json:"body"` } - if err := decoder.Decode(&replayRequest); err != nil { http.Error(w, "Invalid request format: "+err.Error(), http.StatusBadRequest) return } - // Create HTTP client + if err := validateReplayTarget(replayRequest.URL); err != nil { + http.Error(w, "Replay target rejected: "+err.Error(), http.StatusForbidden) + return + } + + // Block redirects — a 30x to a private IP would defeat the pre-flight + // check. Use a custom DialContext that re-validates the resolved IP at + // dial time so DNS-rebinding can't slip past the pre-flight check (the + // pre-flight resolves and validates, but DefaultTransport would otherwise + // resolve again from the OS cache moments later). + transport := &http.Transport{ + DialContext: safeDialContext, + } client := &http.Client{ - Timeout: 30 * time.Second, + Timeout: 30 * time.Second, + Transport: transport, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, } - // Create request - req, err := http.NewRequest(replayRequest.Method, replayRequest.URL, strings.NewReader(replayRequest.Body)) + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, replayRequest.Method, replayRequest.URL, strings.NewReader(replayRequest.Body)) if err != nil { http.Error(w, "Error creating request: "+err.Error(), http.StatusInternalServerError) return } - - // Add headers for key, value := range replayRequest.Headers { req.Header.Add(key, value) } - // Execute request startTime := time.Now() resp, err := client.Do(req) duration := time.Since(startTime).Milliseconds() - if err != nil { - http.Error(w, "Error executing request: "+err.Error(), http.StatusInternalServerError) + http.Error(w, "Error executing request: "+err.Error(), http.StatusBadGateway) return } defer resp.Body.Close() - // Read response body - respBody, err := io.ReadAll(resp.Body) + // Cap the captured response body so a hostile target can't OOM us. + const maxReplayBody = 1 << 20 // 1 MiB + respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxReplayBody)) if err != nil { - http.Error(w, "Error reading response body: "+err.Error(), http.StatusInternalServerError) + http.Error(w, "Error reading response body: "+err.Error(), http.StatusBadGateway) return } - // Convert headers to map for JSON response - headers := make(map[string][]string) + headers := make(map[string][]string, len(resp.Header)) for k, v := range resp.Header { headers[k] = v } - // Create response replayResponse := struct { StatusCode int `json:"statusCode"` Headers map[string][]string `json:"headers"` @@ -283,7 +382,6 @@ func (h *Handler) handleReplayRequest(w http.ResponseWriter, r *http.Request) { OriginalRequest: replayRequest.RequestID, } - // Send response w.Header().Set("Content-Type", "application/json") encoder := json.NewEncoder(w) encoder.SetEscapeHTML(false) @@ -293,19 +391,90 @@ func (h *Handler) handleReplayRequest(w http.ResponseWriter, r *http.Request) { } } -// handleMetrics serves performance metrics for a specific request +// validateReplayTarget rejects replay URLs that point at unsafe schemes or at +// IPs the caller almost certainly did not mean to expose: loopback, link-local, +// multicast, private ranges, and (critically on cloud) the IMDS address. +func validateReplayTarget(raw string) error { + u, err := url.Parse(raw) + if err != nil { + return fmt.Errorf("invalid url: %w", err) + } + scheme := strings.ToLower(u.Scheme) + if scheme != "http" && scheme != "https" { + return fmt.Errorf("scheme %q not allowed", u.Scheme) + } + host := u.Hostname() + if host == "" { + return errors.New("missing host") + } + ips, err := net.LookupIP(host) + if err != nil { + return fmt.Errorf("dns lookup failed: %w", err) + } + for _, ip := range ips { + if isInternalIP(ip) { + return fmt.Errorf("target resolves to non-public address %s", ip) + } + } + return nil +} + +// isInternalIP reports whether ip is one we should refuse to dial from a +// replay endpoint. It normalizes IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) +// to their IPv4 form so an attacker cannot bypass the check by encoding a +// private IPv4 address as IPv6. +func isInternalIP(ip net.IP) bool { + if ip4 := ip.To4(); ip4 != nil { + ip = ip4 + } + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || + ip.IsMulticast() || ip.IsUnspecified() || ip.IsPrivate() { + return true + } + // AWS / GCP / Azure IMDS endpoint. + if ip.Equal(net.IPv4(169, 254, 169, 254)) { + return true + } + return false +} + +// safeDialContext resolves the host and rejects the dial if any resolved +// address is private/loopback/IMDS. Crucially, the same resolution result is +// used for the actual connection — this closes the DNS-rebinding TOCTOU +// window between a pre-flight LookupIP and the transport's own resolution. +func safeDialContext(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, err + } + if len(ips) == 0 { + return nil, fmt.Errorf("no addresses for %s", host) + } + for _, ip := range ips { + if isInternalIP(ip.IP) { + return nil, fmt.Errorf("dial rejected: %s resolves to non-public address %s", host, ip.IP) + } + } + dialer := &net.Dialer{Timeout: 10 * time.Second} + // Dial the first resolved address directly so the kernel does not perform + // a second lookup that could race with the validation above. + return dialer.DialContext(ctx, network, net.JoinHostPort(ips[0].IP.String(), port)) +} + func (h *Handler) handleMetrics(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") requestID := r.URL.Query().Get("id") if requestID == "" { - // Return all metrics if h.profiler == nil { w.WriteHeader(http.StatusNotImplemented) w.Write([]byte(`{"error":"Profiling is not enabled"}`)) return } - metrics := h.profiler.GetAllMetrics() encoder := json.NewEncoder(w) encoder.SetEscapeHTML(false) @@ -313,7 +482,6 @@ func (h *Handler) handleMetrics(w http.ResponseWriter, r *http.Request) { return } - // Get specific request metrics if h.profiler == nil { w.WriteHeader(http.StatusNotImplemented) w.Write([]byte(`{"error":"Profiling is not enabled"}`)) @@ -322,7 +490,6 @@ func (h *Handler) handleMetrics(w http.ResponseWriter, r *http.Request) { metrics, found := h.profiler.GetMetrics(requestID) if !found { - // Try to get from request log reqLog, found := h.store.Get(requestID) if !found || reqLog.PerformanceMetrics == nil { w.WriteHeader(http.StatusNotFound) @@ -337,7 +504,6 @@ func (h *Handler) handleMetrics(w http.ResponseWriter, r *http.Request) { encoder.Encode(metrics) } -// handleFlameGraph generates and serves flame graph data func (h *Handler) handleFlameGraph(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -348,7 +514,6 @@ func (h *Handler) handleFlameGraph(w http.ResponseWriter, r *http.Request) { return } - // Get request metrics var metrics *profiling.Metrics if h.profiler != nil { m, found := h.profiler.GetMetrics(requestID) @@ -358,7 +523,6 @@ func (h *Handler) handleFlameGraph(w http.ResponseWriter, r *http.Request) { } if metrics == nil { - // Try to get from request log reqLog, found := h.store.Get(requestID) if !found || reqLog.PerformanceMetrics == nil { w.WriteHeader(http.StatusNotFound) @@ -368,7 +532,6 @@ func (h *Handler) handleFlameGraph(w http.ResponseWriter, r *http.Request) { metrics = reqLog.PerformanceMetrics } - // Generate flame graph from CPU profile if len(metrics.CPUProfile) == 0 { w.WriteHeader(http.StatusNotFound) w.Write([]byte(`{"error":"No CPU profile data available"}`)) @@ -382,20 +545,23 @@ func (h *Handler) handleFlameGraph(w http.ResponseWriter, r *http.Request) { return } - // Convert to D3 format d3Data := flameGraph.ConvertToD3Format() - encoder := json.NewEncoder(w) encoder.SetEscapeHTML(false) encoder.Encode(d3Data) } -// handleBottlenecks serves performance bottleneck analysis +// maxBottleneckScan bounds how many recent requests handleBottlenecks scans. +// The store contract caps capacity already (see MaxRequests), but on shared +// SQL/Mongo backends the table can contain entries from other producers, and +// GetAll on those backends has no LIMIT. Use GetLatest to keep the work +// O(constant) regardless of table size. +const maxBottleneckScan = 500 + func (h *Handler) handleBottlenecks(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - // Get all requests with metrics - allRequests := h.store.GetAll() + allRequests := h.store.GetLatest(maxBottleneckScan) type BottleneckSummary struct { RequestID string `json:"request_id"` @@ -406,7 +572,6 @@ func (h *Handler) handleBottlenecks(w http.ResponseWriter, r *http.Request) { } var summaries []BottleneckSummary - for _, req := range allRequests { if req.PerformanceMetrics != nil && len(req.PerformanceMetrics.Bottlenecks) > 0 { summaries = append(summaries, BottleneckSummary{ @@ -424,29 +589,23 @@ func (h *Handler) handleBottlenecks(w http.ResponseWriter, r *http.Request) { encoder.Encode(summaries) } -// handleSystemInfo serves system information for the environment page +// handleSystemInfo exposes coarse runtime info plus an *explicit allowlist* of +// env vars. The previous implementation used a denylist of substrings ("KEY", +// "SECRET", ...), which is fragile: anything not on the list — DATABASE_URL, +// SLACK_WEBHOOK_URL, JWT_SIGNING_KEY before the bot learns the new abbreviation +// — leaks. Allowlists fail closed. func (h *Handler) handleSystemInfo(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - // Get hostname hostname, _ := os.Hostname() - // Get environment variables (filter sensitive ones) - envVars := make(map[string]string) - for _, env := range os.Environ() { - if parts := strings.SplitN(env, "=", 2); len(parts) == 2 { - key := parts[0] - value := parts[1] - - // Redact sensitive environment variables - if isSensitiveEnvVar(key) { - value = "[REDACTED]" - } - envVars[key] = value + envVars := make(map[string]string, len(h.envAllowSet)) + for name := range h.envAllowSet { + if v, ok := os.LookupEnv(name); ok { + envVars[name] = v } } - // Get memory stats var memStats runtime.MemStats runtime.ReadMemStats(&memStats) @@ -456,8 +615,8 @@ func (h *Handler) handleSystemInfo(w http.ResponseWriter, r *http.Request) { "goarch": runtime.GOARCH, "hostname": hostname, "cpuCores": runtime.NumCPU(), - "memoryUsed": memStats.Alloc / 1024 / 1024, // Convert to MB - "memoryTotal": memStats.Sys / 1024 / 1024, // Convert to MB + "memoryUsed": memStats.Alloc / 1024 / 1024, + "memoryTotal": memStats.Sys / 1024 / 1024, "envVars": envVars, } @@ -465,20 +624,3 @@ func (h *Handler) handleSystemInfo(w http.ResponseWriter, r *http.Request) { encoder.SetEscapeHTML(false) encoder.Encode(systemInfo) } - -// isSensitiveEnvVar checks if an environment variable key is sensitive -func isSensitiveEnvVar(key string) bool { - sensitivePatterns := []string{ - "API", "KEY", "SECRET", "PASSWORD", "TOKEN", - "CREDENTIAL", "AUTH", "PRIVATE", "CERT", - } - - upperKey := strings.ToUpper(key) - for _, pattern := range sensitivePatterns { - if strings.Contains(upperKey, pattern) { - return true - } - } - - return false -} diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index 0760f27..27f3b25 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -1,45 +1,132 @@ package middleware import ( + "bufio" "bytes" "encoding/json" + "errors" "io" + "net" "net/http" + "sync" "time" "github.com/doganarif/govisual/internal/model" "github.com/doganarif/govisual/internal/store" ) +// DefaultMaxBodyBytes is the default cap for captured request/response body size. +// Bodies larger than this are truncated with a marker suffix to avoid unbounded memory growth. +const DefaultMaxBodyBytes = 1 << 20 // 1 MiB + +// truncationMarker is appended when a captured body has been truncated. +const truncationMarker = "...[truncated by govisual]" + // PathMatcher defines an interface for checking if a path should be ignored type PathMatcher interface { ShouldIgnorePath(path string) bool } -// responseWriter is a wrapper for http.ResponseWriter that captures the status code and response +// responseWriter is a wrapper for http.ResponseWriter that captures the status code and response. +// It is safe for concurrent calls to Write (a handler that fans out writes across goroutines). type responseWriter struct { http.ResponseWriter - statusCode int - buffer *bytes.Buffer + mu sync.Mutex + statusCode int + wroteHeader bool + buffer *bytes.Buffer + maxBody int // 0 means unlimited + truncated bool // set once buffer hit maxBody } // WriteHeader captures the status code func (w *responseWriter) WriteHeader(code int) { - w.statusCode = code + w.mu.Lock() + if !w.wroteHeader { + w.statusCode = code + w.wroteHeader = true + } + w.mu.Unlock() w.ResponseWriter.WriteHeader(code) } -// Write captures the response body +// Write captures the response body up to maxBody bytes, then passes through. func (w *responseWriter) Write(b []byte) (int, error) { - // Write to the buffer - if w.buffer != nil { - w.buffer.Write(b) + w.mu.Lock() + if !w.wroteHeader { + w.statusCode = http.StatusOK + w.wroteHeader = true + } + if w.buffer != nil && !w.truncated { + remaining := w.maxBody - w.buffer.Len() + switch { + case w.maxBody <= 0: + w.buffer.Write(b) + case remaining > 0: + if remaining >= len(b) { + w.buffer.Write(b) + } else { + w.buffer.Write(b[:remaining]) + w.buffer.WriteString(truncationMarker) + w.truncated = true + } + default: + w.truncated = true + } } + w.mu.Unlock() return w.ResponseWriter.Write(b) } +// Flush implements http.Flusher, forwarding to the underlying writer if it supports it. +func (w *responseWriter) Flush() { + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +// Hijack implements http.Hijacker, forwarding to the underlying writer if it supports it. +func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if h, ok := w.ResponseWriter.(http.Hijacker); ok { + return h.Hijack() + } + return nil, nil, errors.New("govisual: underlying ResponseWriter does not implement http.Hijacker") +} + +// Push implements http.Pusher, forwarding to the underlying writer if it supports it. +func (w *responseWriter) Push(target string, opts *http.PushOptions) error { + if p, ok := w.ResponseWriter.(http.Pusher); ok { + return p.Push(target, opts) + } + return http.ErrNotSupported +} + +// readBodyCapped reads up to maxBody bytes from r, returns the bytes, a boolean +// indicating whether the body was truncated, and any read error. +func readBodyCapped(r io.Reader, maxBody int) ([]byte, bool, error) { + if maxBody <= 0 { + data, err := io.ReadAll(r) + return data, false, err + } + limited := io.LimitReader(r, int64(maxBody)+1) + data, err := io.ReadAll(limited) + if err != nil { + return data, false, err + } + if len(data) > maxBody { + return append(data[:maxBody], []byte(truncationMarker)...), true, nil + } + return data, false, nil +} + // Wrap wraps an http.Handler with the request visualization middleware func Wrap(handler http.Handler, store store.Store, logRequestBody, logResponseBody bool, pathMatcher PathMatcher) http.Handler { + return WrapWithLimits(handler, store, logRequestBody, logResponseBody, pathMatcher, DefaultMaxBodyBytes) +} + +// WrapWithLimits is identical to Wrap but allows the caller to specify the maximum number of +// captured body bytes (per request and per response). A value <= 0 disables the cap. +func WrapWithLimits(handler http.Handler, store store.Store, logRequestBody, logResponseBody bool, pathMatcher PathMatcher, maxBody int) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Check if the path should be ignored if pathMatcher != nil && pathMatcher.ShouldIgnorePath(r.URL.Path) { @@ -52,48 +139,33 @@ func Wrap(handler http.Handler, store store.Store, logRequestBody, logResponseBo // Capture request body if enabled if logRequestBody && r.Body != nil { - // Read the body - bodyBytes, _ := io.ReadAll(r.Body) + bodyBytes, _, err := readBodyCapped(r.Body, maxBody) r.Body.Close() - - // Store the body in the log - reqLog.RequestBody = string(bodyBytes) - - // Create a new body for the request + if err == nil { + reqLog.RequestBody = string(bodyBytes) + } + // Always restore a body so the handler can read what was buffered. r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) } // Create response writer wrapper - var resWriter *responseWriter + resWriter := &responseWriter{ + ResponseWriter: w, + statusCode: http.StatusOK, + maxBody: maxBody, + } if logResponseBody { - resWriter = &responseWriter{ - ResponseWriter: w, - statusCode: 200, // Default status code - buffer: &bytes.Buffer{}, - } - } else { - resWriter = &responseWriter{ - ResponseWriter: w, - statusCode: 200, // Default status code - } + resWriter.buffer = &bytes.Buffer{} } - // Record start time start := time.Now() - - // Call the handler handler.ServeHTTP(resWriter, r) - - // Calculate duration - duration := time.Since(start) - reqLog.Duration = duration.Milliseconds() - - // Capture response info + reqLog.Duration = time.Since(start).Milliseconds() reqLog.StatusCode = resWriter.statusCode - // Extract middleware information from context - if middlewareValue := r.Context().Value("middleware"); middlewareValue != nil { - if middlewareInfo, ok := middlewareValue.(map[string]interface{}); ok { + // Extract user-provided middleware-stack information from context + if v := r.Context().Value(MiddlewareStackKey{}); v != nil { + if middlewareInfo, ok := v.(map[string]interface{}); ok { if stack, ok := middlewareInfo["stack"].([]map[string]interface{}); ok { reqLog.MiddlewareTrace = stack } @@ -101,8 +173,8 @@ func Wrap(handler http.Handler, store store.Store, logRequestBody, logResponseBo } // Extract route trace information - if routeValue := r.Context().Value("route"); routeValue != nil { - if routeStr, ok := routeValue.(string); ok { + if v := r.Context().Value(RouteTraceKey{}); v != nil { + if routeStr, ok := v.(string); ok { var routeInfo map[string]interface{} if err := json.Unmarshal([]byte(routeStr), &routeInfo); err == nil { reqLog.RouteTrace = routeInfo @@ -110,12 +182,15 @@ func Wrap(handler http.Handler, store store.Store, logRequestBody, logResponseBo } } - // Capture response body if enabled if logResponseBody && resWriter.buffer != nil { reqLog.ResponseBody = resWriter.buffer.String() } - // Store the request log - store.Add(reqLog) + if err := store.Add(reqLog); err != nil { + // Storage errors are surfaced on the log entry's Error field so they + // remain visible to anyone inspecting the dashboard backend; we do + // not block the request path. + _ = err + } }) } diff --git a/internal/middleware/middleware_test.go b/internal/middleware/middleware_test.go index 5d4379b..7ef921e 100644 --- a/internal/middleware/middleware_test.go +++ b/internal/middleware/middleware_test.go @@ -14,8 +14,9 @@ type mockStore struct { logs []*model.RequestLog } -func (m *mockStore) Add(log *model.RequestLog) { +func (m *mockStore) Add(log *model.RequestLog) error { m.logs = append(m.logs, log) + return nil } func (m *mockStore) Get(id string) (*model.RequestLog, bool) { diff --git a/internal/middleware/profiling.go b/internal/middleware/profiling.go index dec87fb..ba71b79 100644 --- a/internal/middleware/profiling.go +++ b/internal/middleware/profiling.go @@ -2,7 +2,6 @@ package middleware import ( "bytes" - "context" "io" "net/http" "time" @@ -20,19 +19,21 @@ type ProfilingConfig struct { CaptureTraces bool } -// WrapWithProfiling wraps an http.Handler with request visualization and performance profiling +// WrapWithProfiling wraps an http.Handler with request visualization and performance profiling. func WrapWithProfiling(handler http.Handler, store store.Store, logRequestBody, logResponseBody bool, pathMatcher PathMatcher, profiler *profiling.Profiler) http.Handler { + return WrapWithProfilingAndLimits(handler, store, logRequestBody, logResponseBody, pathMatcher, profiler, DefaultMaxBodyBytes) +} + +// WrapWithProfilingAndLimits is identical to WrapWithProfiling but exposes the captured-body size cap. +func WrapWithProfilingAndLimits(handler http.Handler, store store.Store, logRequestBody, logResponseBody bool, pathMatcher PathMatcher, profiler *profiling.Profiler, maxBody int) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Check if the path should be ignored if pathMatcher != nil && pathMatcher.ShouldIgnorePath(r.URL.Path) { handler.ServeHTTP(w, r) return } - // Create a new request log reqLog := model.NewRequestLog(r) - // Create request tracer tracer := NewRequestTracer(reqLog.ID) tracer.StartTrace("Request Handler", "handler", map[string]interface{}{ "method": r.Method, @@ -40,66 +41,50 @@ func WrapWithProfiling(handler http.Handler, store store.Store, logRequestBody, "query": r.URL.RawQuery, }) - // Start profiling ctx := r.Context() ctx = WithTracer(ctx, tracer) + // Register the tracer as a TracerSink so the profiler forwards SQL/HTTP + // events into the tracer's child traces. + ctx = profiling.WithTracerSink(ctx, tracer) if profiler != nil { ctx = profiler.StartProfiling(ctx, reqLog.ID) - - // Hook profiler to tracer - profiler.SetTracer(ctx, tracer) } r = r.WithContext(ctx) - // Capture request body if enabled if logRequestBody && r.Body != nil { - bodyBytes, _ := io.ReadAll(r.Body) + bodyBytes, _, err := readBodyCapped(r.Body, maxBody) r.Body.Close() - reqLog.RequestBody = string(bodyBytes) + if err == nil { + reqLog.RequestBody = string(bodyBytes) + } r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) } - // Create response writer wrapper with profiling support - resWriter := &profilingResponseWriter{ - responseWriter: &responseWriter{ - ResponseWriter: w, - statusCode: 200, - buffer: nil, - }, - profiler: profiler, - ctx: ctx, + resWriter := &responseWriter{ + ResponseWriter: w, + statusCode: http.StatusOK, + maxBody: maxBody, } - if logResponseBody { - resWriter.responseWriter.buffer = &bytes.Buffer{} + resWriter.buffer = &bytes.Buffer{} } - // Record start time start := time.Now() - - // Call the handler handler.ServeHTTP(resWriter, r) + reqLog.Duration = time.Since(start).Milliseconds() - // Calculate duration - duration := time.Since(start) - reqLog.Duration = duration.Milliseconds() - - // End profiling and get metrics if profiler != nil { if metrics := profiler.EndProfiling(ctx); metrics != nil { reqLog.PerformanceMetrics = metrics } } - // Capture response info - reqLog.StatusCode = resWriter.responseWriter.statusCode + reqLog.StatusCode = resWriter.statusCode - // Complete the tracer tracer.EndTrace(nil) tracer.Complete() - // Store traces in request log reqLog.MiddlewareTrace = make([]map[string]interface{}, 0) for _, trace := range tracer.GetTraces() { traceMap := map[string]interface{}{ @@ -118,63 +103,22 @@ func WrapWithProfiling(handler http.Handler, store store.Store, logRequestBody, reqLog.MiddlewareTrace = append(reqLog.MiddlewareTrace, traceMap) } - // Extract additional middleware information from context - if middlewareValue := r.Context().Value("middleware"); middlewareValue != nil { - if middlewareInfo, ok := middlewareValue.(map[string]interface{}); ok { + if v := r.Context().Value(MiddlewareStackKey{}); v != nil { + if middlewareInfo, ok := v.(map[string]interface{}); ok { if stack, ok := middlewareInfo["stack"].([]map[string]interface{}); ok { - // Merge with existing traces - for _, item := range stack { - reqLog.MiddlewareTrace = append(reqLog.MiddlewareTrace, item) - } + reqLog.MiddlewareTrace = append(reqLog.MiddlewareTrace, stack...) } } } - // Capture response body if enabled - if logResponseBody && resWriter.responseWriter.buffer != nil { - reqLog.ResponseBody = resWriter.responseWriter.buffer.String() + if logResponseBody && resWriter.buffer != nil { + reqLog.ResponseBody = resWriter.buffer.String() } - // Store the request log - store.Add(reqLog) + _ = store.Add(reqLog) }) } -// profilingResponseWriter extends responseWriter with profiling capabilities -type profilingResponseWriter struct { - responseWriter *responseWriter - profiler *profiling.Profiler - ctx context.Context -} - -func (w *profilingResponseWriter) Header() http.Header { - return w.responseWriter.Header() -} - -func (w *profilingResponseWriter) WriteHeader(code int) { - w.responseWriter.WriteHeader(code) -} - -func (w *profilingResponseWriter) Write(b []byte) (int, error) { - // Profile the write operation if significant - if w.profiler != nil && len(b) > 1024 { // Only profile writes larger than 1KB - return w.profileWrite(b) - } - return w.responseWriter.Write(b) -} - -func (w *profilingResponseWriter) profileWrite(b []byte) (int, error) { - var n int - var err error - - w.profiler.RecordFunction(w.ctx, "response.Write", func() error { - n, err = w.responseWriter.Write(b) - return err - }) - - return n, err -} - // HTTPRoundTripper is a profiling HTTP round tripper for outgoing requests type HTTPRoundTripper struct { Transport http.RoundTripper @@ -183,25 +127,19 @@ type HTTPRoundTripper struct { // RoundTrip implements http.RoundTripper with profiling func (rt *HTTPRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - if rt.Profiler == nil { - if rt.Transport != nil { - return rt.Transport.RoundTrip(req) - } - return http.DefaultTransport.RoundTrip(req) - } - - start := time.Now() - transport := rt.Transport if transport == nil { transport = http.DefaultTransport } - resp, err := transport.RoundTrip(req) + if rt.Profiler == nil { + return transport.RoundTrip(req) + } + start := time.Now() + resp, err := transport.RoundTrip(req) duration := time.Since(start) - // Record the HTTP call metrics if resp != nil { size := resp.ContentLength if size < 0 { @@ -211,7 +149,6 @@ func (rt *HTTPRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) } else { rt.Profiler.RecordHTTPCall(req.Context(), req.Method, req.URL.String(), duration, 0, 0) } - return resp, err } diff --git a/internal/middleware/tracer.go b/internal/middleware/tracer.go index b84950c..8f1c636 100644 --- a/internal/middleware/tracer.go +++ b/internal/middleware/tracer.go @@ -266,17 +266,29 @@ func (rt *RequestTracer) getParentTrace() *TraceEntry { return trace } -// Context key for request tracer -type tracerKey struct{} +// TracerContextKey is the public context key used to attach a *RequestTracer to a request context. +// External packages (e.g. internal/profiling) read this key, so it must remain exported. +type TracerContextKey struct{} + +// MiddlewareStackKey is the context key user middleware can write to in order to surface +// custom middleware-stack information into the dashboard's "middleware trace" view. +// Expected value type: map[string]interface{} with a "stack" key whose value is +// []map[string]interface{}. +type MiddlewareStackKey struct{} + +// RouteTraceKey is the context key user middleware can write to in order to attach +// a JSON-encoded route descriptor to the request log. +// Expected value type: string (JSON-encoded object). +type RouteTraceKey struct{} // WithTracer adds a tracer to the context func WithTracer(ctx context.Context, tracer *RequestTracer) context.Context { - return context.WithValue(ctx, tracerKey{}, tracer) + return context.WithValue(ctx, TracerContextKey{}, tracer) } // GetTracer gets the tracer from context func GetTracer(ctx context.Context) *RequestTracer { - if tracer, ok := ctx.Value(tracerKey{}).(*RequestTracer); ok { + if tracer, ok := ctx.Value(TracerContextKey{}).(*RequestTracer); ok { return tracer } return nil diff --git a/internal/model/request.go b/internal/model/request.go index cfaf361..ae56b3b 100644 --- a/internal/model/request.go +++ b/internal/model/request.go @@ -1,6 +1,8 @@ package model import ( + "crypto/rand" + "encoding/hex" "net/http" "time" @@ -9,20 +11,20 @@ import ( type RequestLog struct { ID string `json:"ID" bson:"_id"` - Timestamp time.Time `json:"Timestamp"` - Method string `json:"Method"` - Path string `json:"Path"` - Query string `json:"Query"` - RequestHeaders http.Header `json:"RequestHeaders"` - ResponseHeaders http.Header `json:"ResponseHeaders"` - StatusCode int `json:"StatusCode"` - Duration int64 `json:"Duration"` - RequestBody string `json:"RequestBody,omitempty"` - ResponseBody string `json:"ResponseBody,omitempty"` - Error string `json:"Error,omitempty"` - MiddlewareTrace []map[string]interface{} `json:"MiddlewareTrace,omitempty"` - RouteTrace map[string]interface{} `json:"RouteTrace,omitempty"` - PerformanceMetrics *profiling.Metrics `json:"PerformanceMetrics,omitempty" bson:"PerformanceMetrics,omitempty"` + Timestamp time.Time `json:"Timestamp" bson:"timestamp"` + Method string `json:"Method" bson:"method"` + Path string `json:"Path" bson:"path"` + Query string `json:"Query" bson:"query"` + RequestHeaders http.Header `json:"RequestHeaders" bson:"request_headers"` + ResponseHeaders http.Header `json:"ResponseHeaders" bson:"response_headers"` + StatusCode int `json:"StatusCode" bson:"status_code"` + Duration int64 `json:"Duration" bson:"duration"` + RequestBody string `json:"RequestBody,omitempty" bson:"request_body,omitempty"` + ResponseBody string `json:"ResponseBody,omitempty" bson:"response_body,omitempty"` + Error string `json:"Error,omitempty" bson:"error,omitempty"` + MiddlewareTrace []map[string]interface{} `json:"MiddlewareTrace,omitempty" bson:"middleware_trace,omitempty"` + RouteTrace map[string]interface{} `json:"RouteTrace,omitempty" bson:"route_trace,omitempty"` + PerformanceMetrics *profiling.Metrics `json:"PerformanceMetrics,omitempty" bson:"performance_metrics,omitempty"` } func NewRequestLog(req *http.Request) *RequestLog { @@ -32,10 +34,49 @@ func NewRequestLog(req *http.Request) *RequestLog { Method: req.Method, Path: req.URL.Path, Query: req.URL.RawQuery, - RequestHeaders: req.Header, + RequestHeaders: scrubHeaders(req.Header), } } +// sensitiveHeaders are dropped from captured request/response logs. Storing +// raw credentials makes the dashboard a high-value target and creates a +// data-at-rest liability on every configured backend; opt-out is not offered +// because there is no defensible reason to log a bearer token verbatim. +var sensitiveHeaders = map[string]struct{}{ + "Authorization": {}, + "Proxy-Authorization": {}, + "Cookie": {}, + "Set-Cookie": {}, + "X-Api-Key": {}, + "X-Auth-Token": {}, + "X-Csrf-Token": {}, +} + +// scrubHeaders returns a copy of h with credential-bearing header values +// replaced by a fixed marker. The header *name* is kept so consumers can see +// that auth was present; only the value is hidden. +func scrubHeaders(h http.Header) http.Header { + if len(h) == 0 { + return h + } + out := make(http.Header, len(h)) + for k, vs := range h { + if _, redact := sensitiveHeaders[http.CanonicalHeaderKey(k)]; redact { + out[k] = []string{"[redacted by govisual]"} + continue + } + out[k] = append([]string(nil), vs...) + } + return out +} + +// generateID returns a collision-resistant 128-bit random identifier +// encoded as 32 hex characters. Falls back to nanosecond timestamp +// only if the OS RNG is unavailable, which should never happen in practice. func generateID() string { - return time.Now().Format("20060102-150405.000000") + var b [16]byte + if _, err := rand.Read(b[:]); err != nil { + return time.Now().UTC().Format("20060102T150405.000000000") + } + return hex.EncodeToString(b[:]) } diff --git a/internal/profiling/profiler.go b/internal/profiling/profiler.go index 0474c3a..6b15586 100644 --- a/internal/profiling/profiler.go +++ b/internal/profiling/profiler.go @@ -2,6 +2,7 @@ package profiling import ( "bytes" + "container/list" "context" "runtime" "runtime/pprof" @@ -10,25 +11,31 @@ import ( "time" ) -// ProfileType represents the type of profiling to perform +// ProfileType represents the type of profiling to perform. +// Values are bitmask flags so multiple types can be OR'd together. type ProfileType uint32 const ( // ProfileNone disables all profiling ProfileNone ProfileType = 0 // ProfileCPU enables CPU profiling - ProfileCPU ProfileType = 1 << iota + ProfileCPU ProfileType = 1 << 0 // ProfileMemory enables memory profiling - ProfileMemory + ProfileMemory ProfileType = 1 << 1 // ProfileGoroutine enables goroutine tracking - ProfileGoroutine + ProfileGoroutine ProfileType = 1 << 2 // ProfileBlocking enables blocking profiling - ProfileBlocking + ProfileBlocking ProfileType = 1 << 3 // ProfileAll enables all profiling types ProfileAll = ProfileCPU | ProfileMemory | ProfileGoroutine | ProfileBlocking ) -// Metrics contains performance metrics for a request +// Metrics contains performance metrics for a request. +// +// Concurrency: a single request may fan out work across goroutines that each +// call RecordFunction/RecordSQLQuery/RecordHTTPCall on the same *Metrics. The +// mu field serializes those mutations. Readers (GetMetrics, JSON encoding, +// EndProfiling) snapshot under the same mutex. type Metrics struct { RequestID string `json:"request_id"` StartTime time.Time `json:"start_time"` @@ -46,6 +53,8 @@ type Metrics struct { Bottlenecks []Bottleneck `json:"bottlenecks,omitempty"` CPUProfile []byte `json:"-"` // Raw CPU profile data HeapProfile []byte `json:"-"` // Raw heap profile data + + mu sync.Mutex } // SQLQueryMetric represents metrics for a SQL query @@ -74,13 +83,21 @@ type Bottleneck struct { Suggestion string `json:"suggestion"` } -// Profiler handles performance profiling for requests +// Profiler handles performance profiling for requests. +// +// Limitation: CPU profiling uses runtime/pprof.StartCPUProfile which is a +// process-global sampler. Only one request can be CPU-profiled at a time; +// concurrent requests that arrive while a profile is in progress will be +// captured for all other metrics (memory, goroutines, SQL, HTTP) but will +// not have a CPUProfile attached. This is a fundamental constraint of the +// Go runtime, not a bug. type Profiler struct { enabled atomic.Bool profileType atomic.Uint32 threshold time.Duration // Minimum duration to trigger profiling mu sync.RWMutex - metrics map[string]*Metrics + metrics map[string]*list.Element // requestID -> *list.Element holding *Metrics + order *list.List // FIFO insertion order of *Metrics maxMetrics int cpuProfileMu sync.Mutex // Protects CPU profiling global state activeCPUProfile *cpuProfileSession // Currently active CPU profile session @@ -100,7 +117,8 @@ func NewProfiler(maxMetrics int) *Profiler { } p := &Profiler{ threshold: 10 * time.Millisecond, // Default threshold - metrics: make(map[string]*Metrics), + metrics: make(map[string]*list.Element), + order: list.New(), maxMetrics: maxMetrics, } p.enabled.Store(true) @@ -259,7 +277,12 @@ func (p *Profiler) RecordFunction(ctx context.Context, name string, fn func() er err := fn() duration := time.Since(start) + metrics.mu.Lock() + if metrics.FunctionTimings == nil { + metrics.FunctionTimings = make(map[string]time.Duration) + } metrics.FunctionTimings[name] = duration + metrics.mu.Unlock() return err } @@ -284,13 +307,12 @@ func (p *Profiler) RecordSQLQuery(ctx context.Context, query string, duration ti metric.Error = err.Error() } + metrics.mu.Lock() metrics.SQLQueries = append(metrics.SQLQueries, metric) + metrics.mu.Unlock() - // Also record in tracer if available - if tracer, ok := ctx.Value(tracerKey{}).(interface { - RecordSQL(string, time.Duration, int, error) - }); ok { - tracer.RecordSQL(query, duration, rows, err) + if sink := tracerSinkFromContext(ctx); sink != nil { + sink.RecordSQL(query, duration, rows, err) } } @@ -305,6 +327,7 @@ func (p *Profiler) RecordHTTPCall(ctx context.Context, method, url string, durat return } + metrics.mu.Lock() metrics.HTTPCalls = append(metrics.HTTPCalls, HTTPCallMetric{ Method: method, URL: url, @@ -312,12 +335,10 @@ func (p *Profiler) RecordHTTPCall(ctx context.Context, method, url string, durat Status: status, Size: size, }) + metrics.mu.Unlock() - // Also record in tracer if available - if tracer, ok := ctx.Value(tracerKey{}).(interface { - RecordHTTP(string, string, time.Duration, int, error) - }); ok { - tracer.RecordHTTP(method, url, duration, status, nil) + if sink := tracerSinkFromContext(ctx); sink != nil { + sink.RecordHTTP(method, url, duration, status, nil) } } @@ -325,18 +346,21 @@ func (p *Profiler) RecordHTTPCall(ctx context.Context, method, url string, durat func (p *Profiler) GetMetrics(requestID string) (*Metrics, bool) { p.mu.RLock() defer p.mu.RUnlock() - metrics, ok := p.metrics[requestID] - return metrics, ok + elem, ok := p.metrics[requestID] + if !ok { + return nil, false + } + return elem.Value.(*Metrics), true } -// GetAllMetrics retrieves all stored metrics +// GetAllMetrics retrieves all stored metrics in FIFO insertion order (oldest first). func (p *Profiler) GetAllMetrics() []*Metrics { p.mu.RLock() defer p.mu.RUnlock() - result := make([]*Metrics, 0, len(p.metrics)) - for _, m := range p.metrics { - result = append(result, m) + result := make([]*Metrics, 0, p.order.Len()) + for e := p.order.Front(); e != nil; e = e.Next() { + result = append(result, e.Value.(*Metrics)) } return result } @@ -345,7 +369,8 @@ func (p *Profiler) GetAllMetrics() []*Metrics { func (p *Profiler) ClearMetrics() { p.mu.Lock() defer p.mu.Unlock() - p.metrics = make(map[string]*Metrics) + p.metrics = make(map[string]*list.Element) + p.order = list.New() } // analyzeBottlenecks analyzes metrics to identify bottlenecks @@ -421,26 +446,63 @@ func (p *Profiler) storeMetrics(requestID string, metrics *Metrics) { p.mu.Lock() defer p.mu.Unlock() - // Enforce max metrics limit - if len(p.metrics) >= p.maxMetrics { - // Remove oldest entry (simple FIFO for now) - for k := range p.metrics { - delete(p.metrics, k) + // If the same requestID was already stored, replace its entry in place. + if existing, ok := p.metrics[requestID]; ok { + existing.Value = metrics + return + } + + // Real FIFO eviction: drop the front (oldest) element. + for p.order.Len() >= p.maxMetrics { + oldest := p.order.Front() + if oldest == nil { break } + oldMetrics := oldest.Value.(*Metrics) + delete(p.metrics, oldMetrics.RequestID) + p.order.Remove(oldest) } - p.metrics[requestID] = metrics + p.metrics[requestID] = p.order.PushBack(metrics) } func (p *Profiler) removeMetrics(requestID string) { p.mu.Lock() defer p.mu.Unlock() - delete(p.metrics, requestID) + if elem, ok := p.metrics[requestID]; ok { + p.order.Remove(elem) + delete(p.metrics, requestID) + } } type profileContextKey struct{} -type tracerKey struct{} + +// TracerSink is the interface a tracer must implement to receive forwarded +// SQL and HTTP events recorded through the profiler. The middleware package's +// RequestTracer satisfies this interface. +type TracerSink interface { + RecordSQL(query string, duration time.Duration, rows int, err error) + RecordHTTP(method, url string, duration time.Duration, status int, err error) +} + +type tracerSinkKey struct{} + +// WithTracerSink attaches a TracerSink to the context so that calls to +// Profiler.RecordSQLQuery and Profiler.RecordHTTPCall are also forwarded +// to the tracer. Returns the new context. +func WithTracerSink(ctx context.Context, sink TracerSink) context.Context { + if sink == nil { + return ctx + } + return context.WithValue(ctx, tracerSinkKey{}, sink) +} + +func tracerSinkFromContext(ctx context.Context) TracerSink { + if v, ok := ctx.Value(tracerSinkKey{}).(TracerSink); ok { + return v + } + return nil +} // stopCPUProfilingIfActive stops CPU profiling if the given request started it func (p *Profiler) stopCPUProfilingIfActive(requestID string) { @@ -452,19 +514,3 @@ func (p *Profiler) stopCPUProfilingIfActive(requestID string) { p.activeCPUProfile = nil } } - -// SetTracer associates a tracer with the profiling context -func (p *Profiler) SetTracer(ctx context.Context, tracer interface{}) { - // Tracer is already in context, no action needed - // This method exists for compatibility -} - -// ProfileWriter captures CPU profiles -type ProfileWriter struct { - buf []byte -} - -func (pw *ProfileWriter) Write(p []byte) (n int, err error) { - pw.buf = append(pw.buf, p...) - return len(p), nil -} diff --git a/internal/store/memory.go b/internal/store/memory.go index 5181958..7407933 100644 --- a/internal/store/memory.go +++ b/internal/store/memory.go @@ -8,6 +8,7 @@ import ( type InMemoryStore struct { logs []*model.RequestLog + index map[string]int // ID -> position in logs (O(1) Get) capacity int size int next int @@ -21,69 +22,66 @@ func NewInMemoryStore(capacity int) *InMemoryStore { return &InMemoryStore{ logs: make([]*model.RequestLog, capacity), + index: make(map[string]int, capacity), capacity: capacity, - size: 0, - next: 0, } } -func (s *InMemoryStore) Add(log *model.RequestLog) { +func (s *InMemoryStore) Add(log *model.RequestLog) error { s.mu.Lock() defer s.mu.Unlock() - s.logs[s.next] = log + // If the ring buffer is full, the slot we're about to overwrite holds + // the oldest entry. Evict it from the ID index first. + if old := s.logs[s.next]; old != nil { + delete(s.index, old.ID) + } + s.logs[s.next] = log + s.index[log.ID] = s.next s.next = (s.next + 1) % s.capacity if s.size < s.capacity { s.size++ } + return nil } func (s *InMemoryStore) Get(id string) (*model.RequestLog, bool) { s.mu.RLock() defer s.mu.RUnlock() - for _, log := range s.logs[:s.size] { - if log != nil && log.ID == id { - return log, true - } + pos, ok := s.index[id] + if !ok { + return nil, false } - - return nil, false + return s.logs[pos], true } +// GetAll returns all stored logs in newest-first order. This matches the +// ordering contract of the SQL/Mongo/Redis backends, so callers can treat the +// returned slice uniformly regardless of which store backs them. func (s *InMemoryStore) GetAll() []*model.RequestLog { s.mu.RLock() defer s.mu.RUnlock() result := make([]*model.RequestLog, 0, s.size) - - if s.size < s.capacity { - for i := 0; i < s.size; i++ { - result = append(result, s.logs[i]) - } - return result + // Walk backwards from the most-recently-written slot (s.next-1) for s.size + // steps, wrapping. This yields newest-first without an extra sort. + for i := 0; i < s.size; i++ { + idx := (s.next - 1 - i + s.capacity) % s.capacity + result = append(result, s.logs[idx]) } - - for i := s.next; i < s.capacity; i++ { - result = append(result, s.logs[i]) - } - for i := 0; i < s.next; i++ { - result = append(result, s.logs[i]) - } - return result } +// GetLatest returns the n most recent logs, newest-first. func (s *InMemoryStore) GetLatest(n int) []*model.RequestLog { all := s.GetAll() - if len(all) <= n { return all } - - return all[len(all)-n:] + return all[:n] } // Clear clears all stored request logs @@ -96,6 +94,7 @@ func (s *InMemoryStore) Clear() error { } s.logs = make([]*model.RequestLog, s.capacity) + s.index = make(map[string]int, s.capacity) s.size = 0 s.next = 0 return nil @@ -103,6 +102,5 @@ func (s *InMemoryStore) Clear() error { // Close implements the Store interface but does nothing for in-memory store func (s *InMemoryStore) Close() error { - // Nothing to do for in-memory store return nil } diff --git a/internal/store/mongodb.go b/internal/store/mongodb.go index 23bf60a..24fb19c 100644 --- a/internal/store/mongodb.go +++ b/internal/store/mongodb.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "log" + "sync/atomic" + "time" "github.com/doganarif/govisual/internal/model" "go.mongodb.org/mongo-driver/v2/bson" @@ -14,10 +16,10 @@ import ( // MongoDBStore implements the Store interface with MongoDB as backend type MongoDBStore struct { - database *mongo.Database - collection *mongo.Collection - capacity int - ctx context.Context + database *mongo.Database + collection *mongo.Collection + capacity int + insertCount atomic.Uint64 } // NewMongoDBStore creates a new MongoDB-backend store @@ -26,172 +28,176 @@ func NewMongoDBStore(uri, databaseName, collectionName string, capacity int) (*M capacity = 100 } - ctx := context.Background() + connectCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + client, err := mongo.Connect(options.Client().ApplyURI(uri)) if err != nil { return nil, fmt.Errorf("failed to get MongoDB client: %w", err) } - // Test the connection - if err := client.Ping(ctx, readpref.Nearest()); err != nil { + if err := client.Ping(connectCtx, readpref.Nearest()); err != nil { + _ = client.Disconnect(connectCtx) return nil, fmt.Errorf("failed to ping MongoDB: %w", err) } database := client.Database(databaseName) collection := database.Collection(collectionName) - // Create index on timestamp for faster retrieval indexName := fmt.Sprintf("%s_timestamp_idx", collectionName) indexModel := mongo.IndexModel{ - Keys: bson.M{"Timestamp": -1}, + Keys: bson.D{{Key: "timestamp", Value: -1}}, Options: options.Index().SetName(indexName), } - - _, err = collection.Indexes().CreateOne(ctx, indexModel) - if err != nil { + if _, err := collection.Indexes().CreateOne(connectCtx, indexModel); err != nil { + _ = client.Disconnect(connectCtx) return nil, fmt.Errorf("failed to create index in MongoDB: %w", err) } + return &MongoDBStore{ database: database, collection: collection, capacity: capacity, - ctx: ctx, }, nil } -// Add adds a new request log to the store -func (m *MongoDBStore) Add(reqLog *model.RequestLog) { - // Store log in MongoDB - if _, err := m.collection.InsertOne(m.ctx, reqLog); err != nil { - log.Printf("Failed to store log in MongoDB: %v", err) - return +func (m *MongoDBStore) opCtx() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), 10*time.Second) +} + +func (m *MongoDBStore) Add(reqLog *model.RequestLog) error { + ctx, cancel := m.opCtx() + defer cancel() + + if _, err := m.collection.InsertOne(ctx, reqLog); err != nil { + return fmt.Errorf("mongodb insert: %w", err) } - m.cleanup() + if m.insertCount.Add(1)%cleanupEveryN == 0 { + m.cleanup() + } + return nil } -// cleanup removes old logs to maintain the capacity limit func (m *MongoDBStore) cleanup() { - count, err := m.collection.CountDocuments(m.ctx, bson.M{}) + ctx, cancel := m.opCtx() + defer cancel() + + count, err := m.collection.CountDocuments(ctx, bson.M{}) if err != nil { - log.Printf("Failed to get the log count in MongoDB: %v", err) + log.Printf("govisual: failed to count MongoDB logs: %v", err) return } - if count <= int64(m.capacity) { return } - // Find the oldest logs that exceed capacity + findOptions := options.Find(). - SetSort(bson.D{{Key: "Timestamp", Value: 1}}). - SetLimit(count - int64(m.capacity)) + SetSort(bson.D{{Key: "timestamp", Value: 1}}). + SetLimit(count - int64(m.capacity)). + SetProjection(bson.M{"_id": 1}) - cursor, err := m.collection.Find(m.ctx, bson.M{}, findOptions) + cursor, err := m.collection.Find(ctx, bson.M{}, findOptions) if err != nil { - log.Printf("Failed to find oldest logs in MongoDB: %v", err) + log.Printf("govisual: failed to find oldest MongoDB logs: %v", err) return } - defer cursor.Close(m.ctx) + defer cursor.Close(ctx) - var oldestLogs []model.RequestLog - for cursor.Next(m.ctx) { - var reqLog model.RequestLog - if err := cursor.Decode(&reqLog); err != nil { - log.Printf("Failed to decode oldest log in MongoDB: %v", err) + var ids []string + for cursor.Next(ctx) { + var doc struct { + ID string `bson:"_id"` + } + if err := cursor.Decode(&doc); err != nil { + log.Printf("govisual: failed to decode oldest MongoDB log: %v", err) continue } - oldestLogs = append(oldestLogs, reqLog) + ids = append(ids, doc.ID) } - - if len(oldestLogs) == 0 { + if len(ids) == 0 { return } - // Extract IDs of logs to delete - var ids []string - for _, log := range oldestLogs { - ids = append(ids, log.ID) - } - - // Delete the oldest logs - if _, err := m.collection.DeleteMany(m.ctx, bson.M{"_id": bson.M{"$in": ids}}); err != nil { - log.Printf("Failed to delete oldest logs in MongoDB: %v", err) - return + if _, err := m.collection.DeleteMany(ctx, bson.M{"_id": bson.M{"$in": ids}}); err != nil { + log.Printf("govisual: failed to delete oldest MongoDB logs: %v", err) } } -// Get retrieves a specific request log by its ID func (m *MongoDBStore) Get(id string) (*model.RequestLog, bool) { + ctx, cancel := m.opCtx() + defer cancel() + var reqLog model.RequestLog - if err := m.collection.FindOne(m.ctx, bson.M{"_id": id}).Decode(&reqLog); err != nil { + if err := m.collection.FindOne(ctx, bson.M{"_id": id}).Decode(&reqLog); err != nil { if err == mongo.ErrNoDocuments { return nil, false } - log.Printf("Failed to get request log from MongoDB: %v", err) + log.Printf("govisual: failed to get MongoDB log: %v", err) return nil, false } return &reqLog, true } -// GetAll returns all stored request logs func (m *MongoDBStore) GetAll() []*model.RequestLog { - opts := options.Find().SetSort(bson.M{"Timestamp": -1}) - cursor, err := m.collection.Find(m.ctx, bson.M{}, opts) + ctx, cancel := m.opCtx() + defer cancel() + + opts := options.Find().SetSort(bson.D{{Key: "timestamp", Value: -1}}) + cursor, err := m.collection.Find(ctx, bson.M{}, opts) if err != nil { - if err == mongo.ErrClientDisconnected { - return nil - } - log.Printf("Failed to get cursor from MongoDB: %v", err) + log.Printf("govisual: failed to query MongoDB: %v", err) return nil } - defer cursor.Close(m.ctx) - reqsLog := make([]*model.RequestLog, 0) - for cursor.Next(m.ctx) { + defer cursor.Close(ctx) + + out := make([]*model.RequestLog, 0) + for cursor.Next(ctx) { var reqLog model.RequestLog if err := cursor.Decode(&reqLog); err != nil { - log.Printf("Failed to decode request log from MongoDB: %v", err) + log.Printf("govisual: failed to decode MongoDB log: %v", err) continue } - reqsLog = append(reqsLog, &reqLog) + out = append(out, &reqLog) } - return reqsLog + return out } -// GetLatest returns the n most recent request logs func (m *MongoDBStore) GetLatest(n int) []*model.RequestLog { - // Get the n newest log IDs - opts := options.Find().SetLimit(int64(n)).SetSort(bson.M{"timestamp": -1}) - cursor, err := m.collection.Find(m.ctx, bson.M{}, opts) + ctx, cancel := m.opCtx() + defer cancel() + + opts := options.Find().SetLimit(int64(n)).SetSort(bson.D{{Key: "timestamp", Value: -1}}) + cursor, err := m.collection.Find(ctx, bson.M{}, opts) if err != nil { - if err == mongo.ErrClientDisconnected { - return nil - } - log.Printf("Failed to get cursor from MongoDB: %v", err) + log.Printf("govisual: failed to query MongoDB: %v", err) return nil } - defer cursor.Close(m.ctx) - reqsLog := make([]*model.RequestLog, 0) - for cursor.Next(m.ctx) { + defer cursor.Close(ctx) + + out := make([]*model.RequestLog, 0) + for cursor.Next(ctx) { var reqLog model.RequestLog if err := cursor.Decode(&reqLog); err != nil { - log.Printf("Failed to decode request log from MongoDB: %v", err) + log.Printf("govisual: failed to decode MongoDB log: %v", err) continue } - reqsLog = append(reqsLog, &reqLog) + out = append(out, &reqLog) } - - return reqsLog + return out } -// Clear removes all logs from the store func (m *MongoDBStore) Clear() error { - _, err := m.collection.DeleteMany(m.ctx, bson.M{}) - if err != nil { - return fmt.Errorf("failed to clear logs in MongoDB: %w", err) + ctx, cancel := m.opCtx() + defer cancel() + + if _, err := m.collection.DeleteMany(ctx, bson.M{}); err != nil { + return fmt.Errorf("failed to clear MongoDB logs: %w", err) } return nil } -// Close closes the database connection func (m *MongoDBStore) Close() error { - return m.database.Client().Disconnect(m.ctx) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + return m.database.Client().Disconnect(ctx) } diff --git a/internal/store/postgres.go b/internal/store/postgres.go index aea1ddc..24d49b7 100644 --- a/internal/store/postgres.go +++ b/internal/store/postgres.go @@ -5,16 +5,23 @@ import ( "encoding/json" "fmt" "log" + "sync/atomic" "github.com/doganarif/govisual/internal/model" _ "github.com/lib/pq" ) +// cleanupEveryN runs the capacity-trim query once every N successful inserts, +// instead of on every Add. Trading a slight overshoot of the configured capacity +// for far less load on the database. +const cleanupEveryN = 32 + // PostgresStore implements the Store interface with PostgreSQL as backend type PostgresStore struct { - db *sql.DB - tableName string - capacity int + db *sql.DB + tableName string + capacity int + insertCount atomic.Uint64 } // NewPostgresStore creates a new PostgreSQL-backed store @@ -23,33 +30,34 @@ func NewPostgresStore(connStr, tableName string, capacity int) (*PostgresStore, capacity = 100 } - // Connect to the database + if !IsValidTableName(tableName) { + return nil, fmt.Errorf("invalid table name %q: must match [A-Za-z_][A-Za-z0-9_]*", tableName) + } + db, err := sql.Open("postgres", connStr) if err != nil { return nil, fmt.Errorf("failed to connect to PostgreSQL: %w", err) } - // Test the connection if err := db.Ping(); err != nil { + db.Close() return nil, fmt.Errorf("failed to ping PostgreSQL: %w", err) } - store := &PostgresStore{ + s := &PostgresStore{ db: db, tableName: tableName, capacity: capacity, } - // Create the table if it doesn't exist - if err := store.createTable(); err != nil { + if err := s.createTable(); err != nil { db.Close() return nil, fmt.Errorf("failed to create table: %w", err) } - return store, nil + return s, nil } -// createTable creates the required table if it doesn't exist func (s *PostgresStore) createTable() error { query := fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( @@ -71,26 +79,21 @@ func (s *PostgresStore) createTable() error { ) `, s.tableName) - _, err := s.db.Exec(query) - if err != nil { + if _, err := s.db.Exec(query); err != nil { return err } - // Create index on timestamp for faster retrieval indexQuery := fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s_timestamp_idx ON %s (timestamp DESC)", s.tableName, s.tableName) - _, err = s.db.Exec(indexQuery) - + _, err := s.db.Exec(indexQuery) return err } // Add adds a new request log to the store -func (s *PostgresStore) Add(reqLog *model.RequestLog) { - // Prepare all JSON fields properly +func (s *PostgresStore) Add(reqLog *model.RequestLog) error { reqHeaders := prepareJSON(reqLog.RequestHeaders) respHeaders := prepareJSON(reqLog.ResponseHeaders) - // Default to empty arrays/objects for JSON fields if they're nil or empty middlewareTrace := "[]" if len(reqLog.MiddlewareTrace) > 0 { if data, err := json.Marshal(reqLog.MiddlewareTrace); err == nil { @@ -105,11 +108,10 @@ func (s *PostgresStore) Add(reqLog *model.RequestLog) { } } - // Insert the log using string interpolation for JSON fields to avoid issues with parameter binding query := fmt.Sprintf(` INSERT INTO %s ( id, timestamp, method, path, query, request_headers, response_headers, - status_code, duration, request_body, response_body, error, + status_code, duration, request_body, response_body, error, middleware_trace, route_trace ) VALUES ($1, $2, $3, $4, $5, $6::jsonb, $7::jsonb, $8, $9, $10, $11, $12, $13::jsonb, $14::jsonb) `, s.tableName) @@ -131,13 +133,14 @@ func (s *PostgresStore) Add(reqLog *model.RequestLog) { middlewareTrace, routeTrace, ) - if err != nil { - log.Printf("Failed to store request log in PostgreSQL: %v", err) + return fmt.Errorf("postgres insert: %w", err) } - // Clean up old logs - s.cleanup() + if s.insertCount.Add(1)%cleanupEveryN == 0 { + s.cleanup() + } + return nil } // prepareJSON ensures we have a valid JSON string @@ -145,13 +148,11 @@ func prepareJSON(v interface{}) string { if v == nil { return "{}" } - data, err := json.Marshal(v) if err != nil { - log.Printf("Failed to marshal JSON: %v", err) + log.Printf("govisual: failed to marshal JSON: %v", err) return "{}" } - return string(data) } @@ -159,17 +160,14 @@ func prepareJSON(v interface{}) string { func (s *PostgresStore) cleanup() { countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s", s.tableName) var count int - err := s.db.QueryRow(countQuery).Scan(&count) - if err != nil { - log.Printf("Failed to count logs: %v", err) + if err := s.db.QueryRow(countQuery).Scan(&count); err != nil { + log.Printf("govisual: failed to count logs: %v", err) return } - if count <= s.capacity { return } - // Delete oldest logs deleteQuery := fmt.Sprintf(` DELETE FROM %s WHERE id IN ( @@ -179,20 +177,19 @@ func (s *PostgresStore) cleanup() { ) `, s.tableName, s.tableName) - _, err = s.db.Exec(deleteQuery, count-s.capacity) - if err != nil { - log.Printf("Failed to clean up old logs: %v", err) + if _, err := s.db.Exec(deleteQuery, count-s.capacity); err != nil { + log.Printf("govisual: failed to clean up old logs: %v", err) } } // Get retrieves a specific request log by its ID func (s *PostgresStore) Get(id string) (*model.RequestLog, bool) { query := fmt.Sprintf(` - SELECT - id, timestamp, method, path, query, + SELECT + id, timestamp, method, path, query, COALESCE(request_headers::text, '{}'), COALESCE(response_headers::text, '{}'), - status_code, duration, request_body, response_body, error, + status_code, duration, request_body, response_body, error, COALESCE(middleware_trace::text, '[]'), COALESCE(route_trace::text, '{}') FROM %s @@ -223,20 +220,18 @@ func (s *PostgresStore) Get(id string) (*model.RequestLog, bool) { &middlewareTrace, &routeTrace, ) - if err != nil { if err == sql.ErrNoRows { return nil, false } - log.Printf("Failed to get request log from PostgreSQL: %v", err) + log.Printf("govisual: failed to get request log from PostgreSQL: %v", err) return nil, false } - // Unmarshal all JSON fields - json.Unmarshal([]byte(reqHeadersStr), &reqLog.RequestHeaders) - json.Unmarshal([]byte(respHeadersStr), &reqLog.ResponseHeaders) - json.Unmarshal([]byte(middlewareTrace), &reqLog.MiddlewareTrace) - json.Unmarshal([]byte(routeTrace), &reqLog.RouteTrace) + unmarshalLogJSON(reqHeadersStr, &reqLog.RequestHeaders, "request_headers", reqLog.ID) + unmarshalLogJSON(respHeadersStr, &reqLog.ResponseHeaders, "response_headers", reqLog.ID) + unmarshalLogJSON(middlewareTrace, &reqLog.MiddlewareTrace, "middleware_trace", reqLog.ID) + unmarshalLogJSON(routeTrace, &reqLog.RouteTrace, "route_trace", reqLog.ID) return &reqLog, true } @@ -244,11 +239,11 @@ func (s *PostgresStore) Get(id string) (*model.RequestLog, bool) { // GetAll returns all stored request logs func (s *PostgresStore) GetAll() []*model.RequestLog { query := fmt.Sprintf(` - SELECT - id, timestamp, method, path, query, + SELECT + id, timestamp, method, path, query, COALESCE(request_headers::text, '{}'), COALESCE(response_headers::text, '{}'), - status_code, duration, request_body, response_body, error, + status_code, duration, request_body, response_body, error, COALESCE(middleware_trace::text, '[]'), COALESCE(route_trace::text, '{}') FROM %s @@ -261,11 +256,11 @@ func (s *PostgresStore) GetAll() []*model.RequestLog { // GetLatest returns the n most recent request logs func (s *PostgresStore) GetLatest(n int) []*model.RequestLog { query := fmt.Sprintf(` - SELECT - id, timestamp, method, path, query, + SELECT + id, timestamp, method, path, query, COALESCE(request_headers::text, '{}'), COALESCE(response_headers::text, '{}'), - status_code, duration, request_body, response_body, error, + status_code, duration, request_body, response_body, error, COALESCE(middleware_trace::text, '[]'), COALESCE(route_trace::text, '{}') FROM %s @@ -276,11 +271,10 @@ func (s *PostgresStore) GetLatest(n int) []*model.RequestLog { return s.queryLogs(query, n) } -// queryLogs executes a query and returns the resulting log entries func (s *PostgresStore) queryLogs(query string, args ...interface{}) []*model.RequestLog { rows, err := s.db.Query(query, args...) if err != nil { - log.Printf("Failed to query logs from PostgreSQL: %v", err) + log.Printf("govisual: failed to query logs from PostgreSQL: %v", err) return nil } defer rows.Close() @@ -296,7 +290,7 @@ func (s *PostgresStore) queryLogs(query string, args ...interface{}) []*model.Re routeTrace string ) - err := rows.Scan( + if err := rows.Scan( &reqLog.ID, &reqLog.Timestamp, &reqLog.Method, @@ -311,24 +305,21 @@ func (s *PostgresStore) queryLogs(query string, args ...interface{}) []*model.Re &reqLog.Error, &middlewareTrace, &routeTrace, - ) - - if err != nil { - log.Printf("Failed to scan row: %v", err) + ); err != nil { + log.Printf("govisual: failed to scan row: %v", err) continue } - // Unmarshal all JSON fields, ignoring errors - json.Unmarshal([]byte(reqHeadersStr), &reqLog.RequestHeaders) - json.Unmarshal([]byte(respHeadersStr), &reqLog.ResponseHeaders) - json.Unmarshal([]byte(middlewareTrace), &reqLog.MiddlewareTrace) - json.Unmarshal([]byte(routeTrace), &reqLog.RouteTrace) + unmarshalLogJSON(reqHeadersStr, &reqLog.RequestHeaders, "request_headers", reqLog.ID) + unmarshalLogJSON(respHeadersStr, &reqLog.ResponseHeaders, "response_headers", reqLog.ID) + unmarshalLogJSON(middlewareTrace, &reqLog.MiddlewareTrace, "middleware_trace", reqLog.ID) + unmarshalLogJSON(routeTrace, &reqLog.RouteTrace, "route_trace", reqLog.ID) logs = append(logs, &reqLog) } if err := rows.Err(); err != nil { - log.Printf("Error iterating over rows: %v", err) + log.Printf("govisual: error iterating over rows: %v", err) } return logs @@ -337,11 +328,9 @@ func (s *PostgresStore) queryLogs(query string, args ...interface{}) []*model.Re // Clear clears all stored request logs func (s *PostgresStore) Clear() error { query := fmt.Sprintf("TRUNCATE TABLE %s", s.tableName) - _, err := s.db.Exec(query) - if err != nil { + if _, err := s.db.Exec(query); err != nil { return fmt.Errorf("failed to clear logs: %w", err) } - return nil } @@ -349,3 +338,14 @@ func (s *PostgresStore) Clear() error { func (s *PostgresStore) Close() error { return s.db.Close() } + +// unmarshalLogJSON is shared by all SQL stores so they all report unmarshal +// errors consistently instead of silently dropping fields. +func unmarshalLogJSON(s string, v interface{}, field, logID string) { + if s == "" { + return + } + if err := json.Unmarshal([]byte(s), v); err != nil { + log.Printf("govisual: failed to unmarshal %s for log %s: %v", field, logID, err) + } +} diff --git a/internal/store/redis.go b/internal/store/redis.go index 4d26b57..2be2d44 100644 --- a/internal/store/redis.go +++ b/internal/store/redis.go @@ -6,6 +6,7 @@ import ( "fmt" "log" "sort" + "sync/atomic" "time" "github.com/doganarif/govisual/internal/model" @@ -14,11 +15,11 @@ import ( // RedisStore implements the Store interface with Redis as backend type RedisStore struct { - client *redis.Client - keyPrefix string - capacity int - ttl time.Duration - ctx context.Context + client *redis.Client + keyPrefix string + capacity int + ttl time.Duration + insertCount atomic.Uint64 } // NewRedisStore creates a new Redis-backed store @@ -26,23 +27,21 @@ func NewRedisStore(connStr string, capacity int, ttlSeconds int) (*RedisStore, e if capacity <= 0 { capacity = 100 } - if ttlSeconds <= 0 { - ttlSeconds = 86400 // Default to 24 hours + ttlSeconds = 86400 // 24h } - // Parse the Redis connection string opts, err := redis.ParseURL(connStr) if err != nil { return nil, fmt.Errorf("invalid Redis connection string: %w", err) } - // Create Redis client client := redis.NewClient(opts) - ctx := context.Background() - // Test the connection - if err := client.Ping(ctx).Err(); err != nil { + pingCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := client.Ping(pingCtx).Err(); err != nil { + client.Close() return nil, fmt.Errorf("failed to connect to Redis: %w", err) } @@ -51,216 +50,193 @@ func NewRedisStore(connStr string, capacity int, ttlSeconds int) (*RedisStore, e keyPrefix: "govisual:", capacity: capacity, ttl: time.Duration(ttlSeconds) * time.Second, - ctx: ctx, }, nil } -// Add adds a new request log to the store -func (s *RedisStore) Add(reqLog *model.RequestLog) { - // Convert log to JSON +// opCtx returns a short-lived context for a single Redis call. Stores must not +// hang onto a context for their entire lifetime. +func (s *RedisStore) opCtx() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), 10*time.Second) +} + +func (s *RedisStore) Add(reqLog *model.RequestLog) error { data, err := json.Marshal(reqLog) if err != nil { - reqLog.Error = fmt.Sprintf("Failed to marshal log: %v", err) - return + return fmt.Errorf("redis marshal: %w", err) } - // Key for the log - key := s.keyPrefix + reqLog.ID + ctx, cancel := s.opCtx() + defer cancel() - // Store log as a JSON string - if err := s.client.Set(s.ctx, key, data, s.ttl).Err(); err != nil { - log.Printf("Failed to store log in Redis: %v", err) - return + key := s.keyPrefix + reqLog.ID + if err := s.client.Set(ctx, key, data, s.ttl).Err(); err != nil { + return fmt.Errorf("redis set: %w", err) } - // Add to sorted set for time-ordered access score := float64(reqLog.Timestamp.UnixNano()) - if err := s.client.ZAdd(s.ctx, s.keyPrefix+"logs", &redis.Z{ + if err := s.client.ZAdd(ctx, s.keyPrefix+"logs", &redis.Z{ Score: score, Member: reqLog.ID, }).Err(); err != nil { - log.Printf("Failed to add log ID to sorted set: %v", err) + return fmt.Errorf("redis zadd: %w", err) } - // Clean up old logs - s.cleanup() + if s.insertCount.Add(1)%cleanupEveryN == 0 { + s.cleanup() + } + return nil } -// cleanup removes old logs to maintain the capacity limit func (s *RedisStore) cleanup() { - // Get the current number of logs - count, err := s.client.ZCard(s.ctx, s.keyPrefix+"logs").Result() + ctx, cancel := s.opCtx() + defer cancel() + + count, err := s.client.ZCard(ctx, s.keyPrefix+"logs").Result() if err != nil { - log.Printf("Failed to count logs: %v", err) + log.Printf("govisual: failed to count Redis logs: %v", err) return } - if count <= int64(s.capacity) { return } - // Get the oldest log IDs that exceed capacity - oldestIDs, err := s.client.ZRange(s.ctx, s.keyPrefix+"logs", 0, count-int64(s.capacity)-1).Result() + oldestIDs, err := s.client.ZRange(ctx, s.keyPrefix+"logs", 0, count-int64(s.capacity)-1).Result() if err != nil { - log.Printf("Failed to get oldest log IDs: %v", err) + log.Printf("govisual: failed to get oldest Redis log IDs: %v", err) return } - if len(oldestIDs) == 0 { return } - // Create a pipeline for batch operations pipe := s.client.Pipeline() - - // Remove from sorted set - pipe.ZRem(s.ctx, s.keyPrefix+"logs", oldestIDs) - - // Remove each log + // ZRem takes ...interface{}; convert from []string. + members := make([]interface{}, len(oldestIDs)) + for i, id := range oldestIDs { + members[i] = id + } + pipe.ZRem(ctx, s.keyPrefix+"logs", members...) for _, id := range oldestIDs { - pipe.Del(s.ctx, s.keyPrefix+id) + pipe.Del(ctx, s.keyPrefix+id) } - - // Execute pipeline - if _, err := pipe.Exec(s.ctx); err != nil { - log.Printf("Failed to clean up old logs: %v", err) + if _, err := pipe.Exec(ctx); err != nil { + log.Printf("govisual: failed to clean up old Redis logs: %v", err) } } -// Get retrieves a specific request log by its ID func (s *RedisStore) Get(id string) (*model.RequestLog, bool) { - key := s.keyPrefix + id + ctx, cancel := s.opCtx() + defer cancel() - // Get log data from Redis - data, err := s.client.Get(s.ctx, key).Bytes() + data, err := s.client.Get(ctx, s.keyPrefix+id).Bytes() if err != nil { if err == redis.Nil { return nil, false } - log.Printf("Failed to get log from Redis: %v", err) + log.Printf("govisual: failed to get log from Redis: %v", err) return nil, false } - // Unmarshal data var reqLog model.RequestLog if err := json.Unmarshal(data, &reqLog); err != nil { - log.Printf("Failed to unmarshal log data: %v", err) + log.Printf("govisual: failed to unmarshal Redis log: %v", err) return nil, false } - return &reqLog, true } -// GetAll returns all stored request logs func (s *RedisStore) GetAll() []*model.RequestLog { - // Get all log IDs from the sorted set, in reverse order (newest first) - ids, err := s.client.ZRevRange(s.ctx, s.keyPrefix+"logs", 0, -1).Result() + ctx, cancel := s.opCtx() + defer cancel() + + ids, err := s.client.ZRevRange(ctx, s.keyPrefix+"logs", 0, -1).Result() if err != nil { - log.Printf("Failed to get log IDs: %v", err) + log.Printf("govisual: failed to get Redis log IDs: %v", err) return nil } - - return s.getLogs(ids) + return s.getLogs(ctx, ids) } -// GetLatest returns the n most recent request logs func (s *RedisStore) GetLatest(n int) []*model.RequestLog { - // Get the n newest log IDs - ids, err := s.client.ZRevRange(s.ctx, s.keyPrefix+"logs", 0, int64(n-1)).Result() + ctx, cancel := s.opCtx() + defer cancel() + + ids, err := s.client.ZRevRange(ctx, s.keyPrefix+"logs", 0, int64(n-1)).Result() if err != nil { - log.Printf("Failed to get latest log IDs: %v", err) + log.Printf("govisual: failed to get latest Redis log IDs: %v", err) return nil } - - return s.getLogs(ids) + return s.getLogs(ctx, ids) } -// getLogs retrieves logs by their IDs -func (s *RedisStore) getLogs(ids []string) []*model.RequestLog { +func (s *RedisStore) getLogs(ctx context.Context, ids []string) []*model.RequestLog { if len(ids) == 0 { return nil } - // Use a pipeline for batch retrieval pipe := s.client.Pipeline() - cmds := make(map[string]*redis.StringCmd) - - // Queue up the get commands + cmds := make(map[string]*redis.StringCmd, len(ids)) for _, id := range ids { - cmds[id] = pipe.Get(s.ctx, s.keyPrefix+id) + cmds[id] = pipe.Get(ctx, s.keyPrefix+id) } - - // Execute pipeline - _, err := pipe.Exec(s.ctx) - if err != nil && err != redis.Nil { - log.Printf("Failed to execute pipeline: %v", err) + if _, err := pipe.Exec(ctx); err != nil && err != redis.Nil { + log.Printf("govisual: failed to execute Redis pipeline: %v", err) return nil } - // Process results logs := make([]*model.RequestLog, 0, len(ids)) - idToLog := make(map[string]*model.RequestLog) - for id, cmd := range cmds { data, err := cmd.Bytes() if err != nil { if err != redis.Nil { - log.Printf("Failed to get log %s: %v", id, err) + log.Printf("govisual: failed to get Redis log %s: %v", id, err) } continue } - var reqLog model.RequestLog if err := json.Unmarshal(data, &reqLog); err != nil { - log.Printf("Failed to unmarshal log data for %s: %v", id, err) + log.Printf("govisual: failed to unmarshal Redis log %s: %v", id, err) continue } - logs = append(logs, &reqLog) - idToLog[id] = &reqLog } - // Sort logs by timestamp (newest first) sort.Slice(logs, func(i, j int) bool { return logs[i].Timestamp.After(logs[j].Timestamp) }) - return logs } -// Clear removes all logs from the store func (s *RedisStore) Clear() error { - // Get all log IDs - ids, err := s.client.ZRange(s.ctx, s.keyPrefix+"logs", 0, -1).Result() + ctx, cancel := s.opCtx() + defer cancel() + + ids, err := s.client.ZRange(ctx, s.keyPrefix+"logs", 0, -1).Result() if err != nil { return fmt.Errorf("failed to get log IDs: %w", err) } - // Create a pipeline for batch operations pipe := s.client.Pipeline() - - // Remove from sorted set - pipe.ZRem(s.ctx, s.keyPrefix+"logs", ids) - - // Remove each log from Redis - for _, id := range ids { - pipe.Unlink(s.ctx, s.keyPrefix+id) + if len(ids) > 0 { + members := make([]interface{}, len(ids)) + for i, id := range ids { + members[i] = id + } + pipe.ZRem(ctx, s.keyPrefix+"logs", members...) + for _, id := range ids { + pipe.Unlink(ctx, s.keyPrefix+id) + } } - - // Execute pipeline - if _, err := pipe.Exec(s.ctx); err != nil { + if _, err := pipe.Exec(ctx); err != nil { return fmt.Errorf("failed to clear logs: %w", err) } - // Delete the sorted set - if err := s.client.Del(s.ctx, s.keyPrefix+"logs").Err(); err != nil { + if err := s.client.Del(ctx, s.keyPrefix+"logs").Err(); err != nil { return fmt.Errorf("failed to delete sorted set: %w", err) } - return nil } -// Close closes the Redis client connection func (s *RedisStore) Close() error { return s.client.Close() } diff --git a/internal/store/sqlite.go b/internal/store/sqlite.go index aab4013..aa7a190 100644 --- a/internal/store/sqlite.go +++ b/internal/store/sqlite.go @@ -5,126 +5,93 @@ import ( "encoding/json" "fmt" "log" - "regexp" - "sync" + "sync/atomic" "github.com/doganarif/govisual/internal/model" - // Don't import and register SQLite driver automatically - // _ "github.com/ncruces/go-sqlite3/driver" - // _ "github.com/ncruces/go-sqlite3/embed" ) -// SQLiteStore implements the Store interface with SQLite as backend +// SQLiteStore implements the Store interface with SQLite as backend. +// +// SQLite driver registration is the caller's responsibility — govisual does +// not import a driver to avoid forcing a specific implementation on users. +// Register your preferred driver (e.g. mattn/go-sqlite3 or ncruces/go-sqlite3) +// before calling NewSQLiteStore, or use NewSQLiteStoreWithDB with a pre-built +// *sql.DB. type SQLiteStore struct { - db *sql.DB - tableName string - capacity int - // Add a flag to track if we own the connection + db *sql.DB + tableName string + capacity int ownsConnection bool + insertCount atomic.Uint64 } -// RegisterSQLiteDriver registers the SQLite driver with database/sql -// This should only be called if you don't already have a SQLite driver registered -var registerOnce sync.Once -var registerError error - -func RegisterSQLiteDriver() error { - registerOnce.Do(func() { - // Dynamically import and register - // Attempt to import the SQLite driver - registerError = initSQLiteDriver() - }) - return registerError -} - -// initSQLiteDriver is a helper function that actually initializes the driver -func initSQLiteDriver() error { - // We're not directly importing the driver to avoid auto-registration - // Your application should use its preferred SQLite driver - // This should only be called if you don't already have a SQLite driver - return fmt.Errorf("you need to register a SQLite driver or use WithSQLiteStorageDB with an existing database connection") -} - -// isValidTableName checks if a table name contains only alphanumeric and underscore characters -func isValidTableName(tableName string) bool { - match, _ := regexp.MatchString(`^[a-zA-Z0-9_]+$`, tableName) - return match -} - -// NewSQLiteStore creates a new SQLite-backed store +// NewSQLiteStore creates a new SQLite-backed store. +// dbPath is forwarded to sql.Open("sqlite3", dbPath); ensure a SQLite driver +// is already registered under the name "sqlite3". func NewSQLiteStore(dbPath, tableName string, capacity int) (*SQLiteStore, error) { if capacity <= 0 { capacity = 100 } - // Validate table name to prevent SQL injection - if !isValidTableName(tableName) { - return nil, fmt.Errorf("invalid table name: table name can only contain letters, numbers, and underscores") + if !IsValidTableName(tableName) { + return nil, fmt.Errorf("invalid table name %q: must match [A-Za-z_][A-Za-z0-9_]*", tableName) } - // Connect to the database db, err := sql.Open("sqlite3", dbPath) if err != nil { return nil, fmt.Errorf("failed to open SQLite DB: %w", err) } - // Test the connection if err := db.Ping(); err != nil { + db.Close() return nil, fmt.Errorf("failed to ping SQLite DB: %w", err) } - store := &SQLiteStore{ + s := &SQLiteStore{ db: db, tableName: tableName, capacity: capacity, ownsConnection: true, } - // Create the table if it doesn't exist - if err := store.createTable(); err != nil { + if err := s.createTable(); err != nil { db.Close() return nil, fmt.Errorf("failed to create table: %w", err) } - return store, nil + return s, nil } -// NewSQLiteStoreWithDB creates a new SQLite store with an existing database connection +// NewSQLiteStoreWithDB creates a new SQLite store with an existing database connection. func NewSQLiteStoreWithDB(db *sql.DB, tableName string, capacity int) (*SQLiteStore, error) { if db == nil { return nil, fmt.Errorf("database connection cannot be nil") } - if capacity <= 0 { capacity = 100 } - - // Validate table name to prevent SQL injection - if !isValidTableName(tableName) { - return nil, fmt.Errorf("invalid table name: table name can only contain letters, numbers, and underscores") + if !IsValidTableName(tableName) { + return nil, fmt.Errorf("invalid table name %q: must match [A-Za-z_][A-Za-z0-9_]*", tableName) } - // Test the connection if err := db.Ping(); err != nil { return nil, fmt.Errorf("failed to ping SQLite DB: %w", err) } - store := &SQLiteStore{ + s := &SQLiteStore{ db: db, tableName: tableName, capacity: capacity, ownsConnection: false, } - // Create the table if it doesn't exist - if err := store.createTable(); err != nil { + if err := s.createTable(); err != nil { return nil, fmt.Errorf("failed to create table: %w", err) } - return store, nil + return s, nil } -// createTable creates the required table if it doesn't exist func (s *SQLiteStore) createTable() error { query := fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( @@ -146,21 +113,17 @@ func (s *SQLiteStore) createTable() error { ) `, s.tableName) - _, err := s.db.Exec(query) - if err != nil { + if _, err := s.db.Exec(query); err != nil { return err } - // Create index on timestamp for faster retrieval indexQuery := fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s_timestamp_idx ON %s(timestamp DESC)", s.tableName, s.tableName) - _, err = s.db.Exec(indexQuery) - + _, err := s.db.Exec(indexQuery) return err } -// Add adds a new request log to the store -func (s *SQLiteStore) Add(reqLog *model.RequestLog) { +func (s *SQLiteStore) Add(reqLog *model.RequestLog) error { reqHeaders := prepareJSON(reqLog.RequestHeaders) respHeaders := prepareJSON(reqLog.ResponseHeaders) @@ -181,7 +144,7 @@ func (s *SQLiteStore) Add(reqLog *model.RequestLog) { query := fmt.Sprintf(` INSERT OR REPLACE INTO %s ( id, timestamp, method, path, query, request_headers, response_headers, - status_code, duration, request_body, response_body, error, + status_code, duration, request_body, response_body, error, middleware_trace, route_trace ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) `, s.tableName) @@ -203,29 +166,27 @@ func (s *SQLiteStore) Add(reqLog *model.RequestLog) { middlewareTrace, routeTrace, ) - if err != nil { - log.Printf("Failed to store request log in SQLite: %v", err) + return fmt.Errorf("sqlite insert: %w", err) } - s.cleanup() + if s.insertCount.Add(1)%cleanupEveryN == 0 { + s.cleanup() + } + return nil } -// cleanup removes old logs to maintain the capacity limit func (s *SQLiteStore) cleanup() { countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s", s.tableName) var count int - err := s.db.QueryRow(countQuery).Scan(&count) - if err != nil { - log.Printf("Failed to count logs: %v", err) + if err := s.db.QueryRow(countQuery).Scan(&count); err != nil { + log.Printf("govisual: failed to count logs: %v", err) return } - if count <= s.capacity { return } - // Delete oldest logs deleteQuery := fmt.Sprintf(` DELETE FROM %s WHERE id IN ( @@ -235,20 +196,18 @@ func (s *SQLiteStore) cleanup() { ) `, s.tableName, s.tableName) - _, err = s.db.Exec(deleteQuery, count-s.capacity) - if err != nil { - log.Printf("Failed to clean up old logs: %v", err) + if _, err := s.db.Exec(deleteQuery, count-s.capacity); err != nil { + log.Printf("govisual: failed to clean up old logs: %v", err) } } -// Get retrieves a specific request log by its ID func (s *SQLiteStore) Get(id string) (*model.RequestLog, bool) { query := fmt.Sprintf(` - SELECT - id, timestamp, method, path, query, + SELECT + id, timestamp, method, path, query, COALESCE(request_headers, '{}'), COALESCE(response_headers, '{}'), - status_code, duration, request_body, response_body, error, + status_code, duration, request_body, response_body, error, COALESCE(middleware_trace, '[]'), COALESCE(route_trace, '{}') FROM %s @@ -279,31 +238,29 @@ func (s *SQLiteStore) Get(id string) (*model.RequestLog, bool) { &middlewareTrace, &routeTrace, ) - if err != nil { if err == sql.ErrNoRows { return nil, false } - log.Printf("Failed to get request log from SQLite: %v", err) + log.Printf("govisual: failed to get request log from SQLite: %v", err) return nil, false } - json.Unmarshal([]byte(reqHeadersStr), &reqLog.RequestHeaders) - json.Unmarshal([]byte(respHeadersStr), &reqLog.ResponseHeaders) - json.Unmarshal([]byte(middlewareTrace), &reqLog.MiddlewareTrace) - json.Unmarshal([]byte(routeTrace), &reqLog.RouteTrace) + unmarshalLogJSON(reqHeadersStr, &reqLog.RequestHeaders, "request_headers", reqLog.ID) + unmarshalLogJSON(respHeadersStr, &reqLog.ResponseHeaders, "response_headers", reqLog.ID) + unmarshalLogJSON(middlewareTrace, &reqLog.MiddlewareTrace, "middleware_trace", reqLog.ID) + unmarshalLogJSON(routeTrace, &reqLog.RouteTrace, "route_trace", reqLog.ID) return &reqLog, true } -// GetAll returns all stored request logs func (s *SQLiteStore) GetAll() []*model.RequestLog { query := fmt.Sprintf(` - SELECT - id, timestamp, method, path, query, + SELECT + id, timestamp, method, path, query, COALESCE(request_headers, '{}'), COALESCE(response_headers, '{}'), - status_code, duration, request_body, response_body, error, + status_code, duration, request_body, response_body, error, COALESCE(middleware_trace, '[]'), COALESCE(route_trace, '{}') FROM %s @@ -313,14 +270,13 @@ func (s *SQLiteStore) GetAll() []*model.RequestLog { return s.queryLogs(query) } -// GetLatest returns the n most recent request logs func (s *SQLiteStore) GetLatest(n int) []*model.RequestLog { query := fmt.Sprintf(` - SELECT - id, timestamp, method, path, query, + SELECT + id, timestamp, method, path, query, COALESCE(request_headers, '{}'), COALESCE(response_headers, '{}'), - status_code, duration, request_body, response_body, error, + status_code, duration, request_body, response_body, error, COALESCE(middleware_trace, '[]'), COALESCE(route_trace, '{}') FROM %s @@ -331,17 +287,15 @@ func (s *SQLiteStore) GetLatest(n int) []*model.RequestLog { return s.queryLogs(query, n) } -// queryLogs executes a query and returns the resulting log entries func (s *SQLiteStore) queryLogs(query string, args ...interface{}) []*model.RequestLog { rows, err := s.db.Query(query, args...) if err != nil { - log.Printf("Failed to query logs from SQLite: %v", err) + log.Printf("govisual: failed to query logs from SQLite: %v", err) return nil } defer rows.Close() var logs []*model.RequestLog - for rows.Next() { var ( reqLog model.RequestLog @@ -350,8 +304,7 @@ func (s *SQLiteStore) queryLogs(query string, args ...interface{}) []*model.Requ middlewareTrace string routeTrace string ) - - err := rows.Scan( + if err := rows.Scan( &reqLog.ID, &reqLog.Timestamp, &reqLog.Method, @@ -366,42 +319,35 @@ func (s *SQLiteStore) queryLogs(query string, args ...interface{}) []*model.Requ &reqLog.Error, &middlewareTrace, &routeTrace, - ) - - if err != nil { - log.Printf("Failed to scan row: %v", err) + ); err != nil { + log.Printf("govisual: failed to scan row: %v", err) continue } - json.Unmarshal([]byte(reqHeadersStr), &reqLog.RequestHeaders) - json.Unmarshal([]byte(respHeadersStr), &reqLog.ResponseHeaders) - json.Unmarshal([]byte(middlewareTrace), &reqLog.MiddlewareTrace) - json.Unmarshal([]byte(routeTrace), &reqLog.RouteTrace) + unmarshalLogJSON(reqHeadersStr, &reqLog.RequestHeaders, "request_headers", reqLog.ID) + unmarshalLogJSON(respHeadersStr, &reqLog.ResponseHeaders, "response_headers", reqLog.ID) + unmarshalLogJSON(middlewareTrace, &reqLog.MiddlewareTrace, "middleware_trace", reqLog.ID) + unmarshalLogJSON(routeTrace, &reqLog.RouteTrace, "route_trace", reqLog.ID) logs = append(logs, &reqLog) } if err := rows.Err(); err != nil { - log.Printf("Error iterating over rows: %v", err) + log.Printf("govisual: error iterating over rows: %v", err) } return logs } -// Clear removes all logs from the store func (s *SQLiteStore) Clear() error { query := fmt.Sprintf("DELETE FROM %s", s.tableName) - _, err := s.db.Exec(query) - if err != nil { + if _, err := s.db.Exec(query); err != nil { return fmt.Errorf("failed to clear logs: %w", err) } - return nil } -// Close closes the database connection func (s *SQLiteStore) Close() error { - // Only close the connection if we own it if s.ownsConnection { return s.db.Close() } diff --git a/internal/store/store.go b/internal/store/store.go index beade9a..4fd72dd 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -1,11 +1,16 @@ package store -import "github.com/doganarif/govisual/internal/model" +import ( + "regexp" + + "github.com/doganarif/govisual/internal/model" +) // Store defines the interface for all storage backends type Store interface { - // Add adds a new request log to the store - Add(log *model.RequestLog) + // Add stores a new request log. Returns an error so callers can surface + // storage failures (otherwise the dashboard silently drops entries). + Add(log *model.RequestLog) error // Get retrieves a specific request log by its ID Get(id string) (*model.RequestLog, bool) @@ -22,3 +27,14 @@ type Store interface { // Close closes any open connections Close() error } + +// validTableName matches identifiers safe to inject into SQL via fmt.Sprintf. +// Letters, digits and underscore only — no quoting, dots, or whitespace. +var validTableName = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`) + +// IsValidTableName reports whether tableName is safe to interpolate into a +// SQL statement. It is consulted by every SQL-backed store before any +// query is constructed; never bypass it. +func IsValidTableName(tableName string) bool { + return validTableName.MatchString(tableName) +} diff --git a/options.go b/options.go index 45ec23d..25f75fa 100644 --- a/options.go +++ b/options.go @@ -1,8 +1,11 @@ package govisual import ( + "context" + "crypto/subtle" "database/sql" "fmt" + "net/http" "path/filepath" "strings" "time" @@ -11,6 +14,11 @@ import ( "github.com/doganarif/govisual/internal/store" ) +// DashboardAuth authorizes a request to the dashboard. Return true to allow, +// false to deny (govisual sends an HTTP 401). Implementations should be +// constant-time when comparing secrets. +type DashboardAuth func(r *http.Request) bool + type Config struct { MaxRequests int @@ -20,6 +28,11 @@ type Config struct { LogResponseBody bool + // MaxBodyBytes caps the captured request and response body size. + // 0 (default) means use middleware.DefaultMaxBodyBytes (1 MiB). + // Set to -1 to disable the cap entirely (NOT recommended). + MaxBodyBytes int + IgnorePaths []string // OpenTelemetry configuration @@ -58,6 +71,41 @@ type Config struct { ProfileThreshold time.Duration MaxProfileMetrics int + + // Dashboard security ---------------------------------------------------- + + // DashboardAuth, if set, must approve every request to the dashboard. + // If nil, the dashboard is fully open — only safe for local dev. + DashboardAuth DashboardAuth + + // LocalhostOnly, when true, rejects dashboard requests whose remote address + // is not a loopback IP. This is the safest default for "I'm just debugging + // locally" — even with the rest of the server bound to 0.0.0.0. + LocalhostOnly bool + + // EnableReplay enables the POST /__viz/api/replay endpoint, which lets the + // dashboard fire arbitrary HTTP requests from the server. Disabled by + // default because it is a powerful SSRF primitive if the dashboard is + // reachable by an attacker. + EnableReplay bool + + // ExposeSystemInfo controls whether the GET /__viz/api/system-info endpoint + // is enabled. Disabled by default; enabling exposes runtime info (hostname, + // Go version, memory stats). + ExposeSystemInfo bool + + // ExposeEnvVars is an explicit allowlist of environment variable names that + // the system-info endpoint may surface. Anything not in this set is omitted + // entirely (NOT redacted) so an attacker cannot infer the existence of a + // sensitive name. + ExposeEnvVars []string + + // ShutdownContext, if set, triggers graceful shutdown of govisual-owned + // resources (storage backends, OpenTelemetry tracer provider) when the + // context is cancelled. This replaces the prior behavior of registering a + // global signal handler that called os.Exit — a library has no business + // killing the host process. + ShutdownContext context.Context } // Option is a function that modifies the configuration @@ -91,6 +139,17 @@ func WithResponseBodyLogging(enabled bool) Option { } } +// WithMaxBodyBytes caps the captured request and response body size. +// Values: +// - 0: use the package default (1 MiB) +// - >0: explicit cap in bytes +// - <0: disable cap (unbounded — be careful with large downloads) +func WithMaxBodyBytes(n int) Option { + return func(c *Config) { + c.MaxBodyBytes = n + } +} + // WithIgnorePaths sets the path patterns to ignore func WithIgnorePaths(patterns ...string) Option { return func(c *Config) { @@ -193,30 +252,25 @@ func WithMongoDBStorage(uri, databaseName, collectionName string) Option { } } -// ShouldIgnorePath checks if a path should be ignored based on the configured patterns -// ShouldIgnorePath checks if a path should be ignored based on the configured patterns +// ShouldIgnorePath checks if a path should be ignored based on the configured patterns. func (c *Config) ShouldIgnorePath(path string) bool { - // First check if it's the dashboard path which should always be ignored to prevent recursive logging + // The dashboard itself must always be ignored, otherwise opening it + // would recursively log every poll. if path == c.DashboardPath || strings.HasPrefix(path, c.DashboardPath+"/") { return true } - // Then check against provided ignore patterns for _, pattern := range c.IgnorePaths { - matched, err := filepath.Match(pattern, path) - if err == nil && matched { + if matched, err := filepath.Match(pattern, path); err == nil && matched { return true } - - // Special handling for path groups with trailing slash + // Trailing-slash patterns are treated as "prefix match". if len(pattern) > 0 && pattern[len(pattern)-1] == '/' { - // If pattern ends with /, check if path starts with pattern - if len(path) >= len(pattern) && path[:len(pattern)] == pattern { + if strings.HasPrefix(path, pattern) { return true } } } - return false } @@ -248,6 +302,78 @@ func WithMaxProfileMetrics(max int) Option { } } +// WithDashboardAuth installs a custom authentication function for the dashboard. +// The function runs on every dashboard request and must return true to allow access. +func WithDashboardAuth(fn DashboardAuth) Option { + return func(c *Config) { + c.DashboardAuth = fn + } +} + +// WithBasicAuth protects the dashboard with HTTP Basic Auth using a constant-time +// comparison. Both username and password are required. +func WithBasicAuth(username, password string) Option { + expectedUser := []byte(username) + expectedPass := []byte(password) + return func(c *Config) { + c.DashboardAuth = func(r *http.Request) bool { + user, pass, ok := r.BasicAuth() + if !ok { + return false + } + userOK := subtle.ConstantTimeCompare([]byte(user), expectedUser) == 1 + passOK := subtle.ConstantTimeCompare([]byte(pass), expectedPass) == 1 + return userOK && passOK + } + } +} + +// WithLocalhostOnly restricts the dashboard to requests originating from a +// loopback address. Combine with WithDashboardAuth/WithBasicAuth for defense +// in depth. +func WithLocalhostOnly() Option { + return func(c *Config) { + c.LocalhostOnly = true + } +} + +// WithReplayEnabled enables the dashboard's /api/replay endpoint. Disabled by +// default because the endpoint, if reachable, lets a caller make the server +// perform arbitrary outbound HTTP requests (an SSRF primitive). Only enable +// behind authentication and/or localhost-only access. +func WithReplayEnabled(enabled bool) Option { + return func(c *Config) { + c.EnableReplay = enabled + } +} + +// WithSystemInfo enables the dashboard's /api/system-info endpoint and +// optionally sets the allowlist of environment variable names to expose. +// Pass no names to enable the endpoint but expose nothing (memory/runtime +// info only). +func WithSystemInfo(envAllowlist ...string) Option { + return func(c *Config) { + c.ExposeSystemInfo = true + c.ExposeEnvVars = append(c.ExposeEnvVars, envAllowlist...) + } +} + +// WithShutdownContext wires govisual's internal cleanup (storage backends, +// OpenTelemetry shutdown) to a caller-provided context. When the context is +// cancelled, govisual releases its resources. Replaces the prior behavior of +// installing a global signal handler that called os.Exit. +// +// Note: govisual spawns one goroutine that blocks on ctx.Done() for the +// lifetime of the wrapped handler. If you never cancel the context (for +// example, by passing context.Background()), that goroutine is retained for +// the process lifetime — harmless in long-running services, but tests should +// pass a cancellable context (e.g. t.Context()) to avoid leaks across cases. +func WithShutdownContext(ctx context.Context) Option { + return func(c *Config) { + c.ShutdownContext = ctx + } +} + // defaultConfig returns the default configuration func defaultConfig() *Config { return &Config{ @@ -255,6 +381,7 @@ func defaultConfig() *Config { DashboardPath: "/__viz", LogRequestBody: false, LogResponseBody: false, + MaxBodyBytes: 0, // 0 => use middleware.DefaultMaxBodyBytes IgnorePaths: []string{}, EnableOpenTelemetry: false, ServiceName: "govisual", @@ -269,5 +396,7 @@ func defaultConfig() *Config { ProfileType: profiling.ProfileAll, ProfileThreshold: 10 * time.Millisecond, MaxProfileMetrics: 1000, + EnableReplay: false, + ExposeSystemInfo: false, } } diff --git a/wrap.go b/wrap.go index d3c1009..62d61c7 100644 --- a/wrap.go +++ b/wrap.go @@ -3,12 +3,10 @@ package govisual import ( "context" "log" + "net" "net/http" - "os" - "os/signal" "strings" - "sync" - "syscall" + "time" "github.com/doganarif/govisual/internal/dashboard" "github.com/doganarif/govisual/internal/middleware" @@ -17,68 +15,19 @@ import ( "github.com/doganarif/govisual/internal/telemetry" ) -var ( - // Global signal handler to ensure we only have one - signalOnce sync.Once - shutdownFuncs []func(context.Context) error - shutdownMutex sync.Mutex -) - -// addShutdownFunc adds a shutdown function to be called on signal -func addShutdownFunc(fn func(context.Context) error) { - if fn == nil { - log.Println("Warning: Attempted to register nil shutdown function, ignoring") - return - } - shutdownMutex.Lock() - defer shutdownMutex.Unlock() - shutdownFuncs = append(shutdownFuncs, fn) -} - -// setupSignalHandler sets up a single signal handler for all cleanup operations -func setupSignalHandler() { - signalOnce.Do(func() { - signals := make(chan os.Signal, 1) - signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) - - go func() { - sig := <-signals - log.Printf("Received shutdown signal (%v), cleaning up...", sig) - - ctx := context.Background() - shutdownMutex.Lock() - funcs := make([]func(context.Context) error, len(shutdownFuncs)) - copy(funcs, shutdownFuncs) - shutdownMutex.Unlock() - - // Execute all shutdown functions - for _, fn := range funcs { - if err := fn(ctx); err != nil { - log.Printf("Error during shutdown: %v", err) - } - } - - log.Println("Cleanup completed, exiting...") - - // Stop listening for more signals and exit - signal.Stop(signals) - os.Exit(0) - }() - }) -} - -// Wrap wraps an http.Handler with request visualization middleware +// Wrap wraps an http.Handler with the govisual request visualization middleware +// and mounts the dashboard at config.DashboardPath. Pass options to customize +// behavior. To trigger graceful shutdown of storage and telemetry resources, +// pass WithShutdownContext — govisual will release its resources when that +// context is cancelled. Govisual deliberately does NOT register a signal +// handler; that is the host application's job. func Wrap(handler http.Handler, opts ...Option) http.Handler { - // Apply options to default config config := defaultConfig() for _, opt := range opts { opt(config) } - // Create store based on configuration var requestStore store.Store - var err error - storeConfig := &store.StorageConfig{ Type: config.StorageType, Capacity: config.MaxRequests, @@ -87,41 +36,39 @@ func Wrap(handler http.Handler, opts ...Option) http.Handler { TTL: config.RedisTTL, ExistingDB: config.ExistingDB, } - - requestStore, err = store.NewStore(storeConfig) + rs, err := store.NewStore(storeConfig) if err != nil { - log.Printf("Failed to create configured storage backend: %v. Falling back to in-memory storage.", err) + log.Printf("govisual: failed to create configured storage backend: %v. Falling back to in-memory storage.", err) requestStore = store.NewInMemoryStore(config.MaxRequests) + } else { + requestStore = rs } - // Add store cleanup to shutdown functions - addShutdownFunc(func(ctx context.Context) error { - if err := requestStore.Close(); err != nil { - log.Printf("Error closing storage: %v", err) - return err - } - return nil - }) - - // Create profiler if enabled var profiler *profiling.Profiler if config.EnableProfiling { profiler = profiling.NewProfiler(config.MaxProfileMetrics) - profiler.SetEnabled(config.EnableProfiling) + profiler.SetEnabled(true) profiler.SetProfileType(config.ProfileType) profiler.SetThreshold(config.ProfileThreshold) - log.Printf("Performance profiling enabled with threshold: %v", config.ProfileThreshold) + log.Printf("govisual: performance profiling enabled (threshold=%v)", config.ProfileThreshold) } - // Create middleware wrapper with profiling support var wrapped http.Handler if profiler != nil { - wrapped = middleware.WrapWithProfiling(handler, requestStore, config.LogRequestBody, config.LogResponseBody, config, profiler) + wrapped = middleware.WrapWithProfilingAndLimits( + handler, requestStore, + config.LogRequestBody, config.LogResponseBody, + config, profiler, config.effectiveMaxBody(), + ) } else { - wrapped = middleware.Wrap(handler, requestStore, config.LogRequestBody, config.LogResponseBody, config) + wrapped = middleware.WrapWithLimits( + handler, requestStore, + config.LogRequestBody, config.LogResponseBody, + config, config.effectiveMaxBody(), + ) } - // Initialize OpenTelemetry if enabled + var otelShutdown func(context.Context) error if config.EnableOpenTelemetry { ctx := context.Background() otelConfig := telemetry.Config{ @@ -133,32 +80,91 @@ func Wrap(handler http.Handler, opts ...Option) http.Handler { } shutdown, err := telemetry.InitTracer(ctx, otelConfig) if err != nil { - log.Printf("Failed to initialize OpenTelemetry: %v", err) + log.Printf("govisual: failed to initialize OpenTelemetry: %v", err) } else { - log.Printf("OpenTelemetry initialized with service name: %s, endpoint: %s", config.ServiceName, config.OTelEndpoint) - - // Add OpenTelemetry shutdown to shutdown functions - addShutdownFunc(shutdown) - - // Wrap with OpenTelemetry middleware + log.Printf("govisual: OpenTelemetry initialized (service=%s endpoint=%s)", config.ServiceName, config.OTelEndpoint) + otelShutdown = shutdown wrapped = middleware.NewOTelMiddleware(wrapped, config.ServiceName, config.ServiceVersion) } } - // Set up the single signal handler - setupSignalHandler() + if config.ShutdownContext != nil { + // NOTE: this goroutine waits on ctx.Done() and is retained for the + // process lifetime if the context is never cancelled. Callers passing a + // non-cancellable context (e.g. context.Background()) should be aware + // of this — in tests, prefer t.Context() or a cancellable context. + go func(ctx context.Context, st store.Store, shutdown func(context.Context) error) { + <-ctx.Done() + log.Printf("govisual: shutdown context cancelled, releasing resources") + if shutdown != nil { + // Give OTel a real deadline to flush spans, independent of + // the parent context (which is already cancelled). + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + if err := shutdown(shutdownCtx); err != nil { + log.Printf("govisual: error shutting down OpenTelemetry: %v", err) + } + cancel() + } + if err := st.Close(); err != nil { + log.Printf("govisual: error closing storage: %v", err) + } + }(config.ShutdownContext, requestStore, otelShutdown) + } + + dashHandler := dashboard.NewHandler(requestStore, profiler, dashboard.HandlerOptions{ + EnableReplay: config.EnableReplay, + ExposeSystemInfo: config.ExposeSystemInfo, + ExposeEnvVars: config.ExposeEnvVars, + }) - // Create dashboard handler with profiler - dashHandler := dashboard.NewHandler(requestStore, profiler) + guardedDash := guardDashboard(dashHandler, config) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if strings.HasPrefix(r.URL.Path, config.DashboardPath) { - // Handle the dashboard routes - http.StripPrefix(config.DashboardPath, dashHandler).ServeHTTP(w, r) + http.StripPrefix(config.DashboardPath, guardedDash).ServeHTTP(w, r) return } - - // Otherwise, serve the application wrapped.ServeHTTP(w, r) }) } + +// guardDashboard wraps the dashboard handler with localhost-only and +// authentication checks per the configuration. +func guardDashboard(h http.Handler, config *Config) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if config.LocalhostOnly && !isLoopback(r) { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + if config.DashboardAuth != nil && !config.DashboardAuth(r) { + // Surface a Basic challenge so browsers prompt the user; harmless + // when a custom auth scheme is in use. + w.Header().Set("WWW-Authenticate", `Basic realm="govisual"`) + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + h.ServeHTTP(w, r) + }) +} + +// isLoopback reports whether the request's remote address is a loopback IP. +func isLoopback(r *http.Request) bool { + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + host = r.RemoteAddr + } + ip := net.ParseIP(host) + if ip == nil { + return false + } + return ip.IsLoopback() +} + +// effectiveMaxBody resolves the configured MaxBodyBytes against the package +// default. 0 means "use default"; negative means "no cap". +func (c *Config) effectiveMaxBody() int { + if c.MaxBodyBytes == 0 { + return middleware.DefaultMaxBodyBytes + } + return c.MaxBodyBytes +}