diff --git a/internal/dashboard/handler.go b/internal/dashboard/handler.go index b05224a..48e8948 100644 --- a/internal/dashboard/handler.go +++ b/internal/dashboard/handler.go @@ -1,12 +1,16 @@ package dashboard import ( + "context" "embed" "encoding/json" + "errors" "fmt" "io" "io/fs" + "net" "net/http" + "net/url" "os" "runtime" "strings" @@ -19,28 +23,50 @@ import ( //go:embed static/* var staticFiles embed.FS +// HandlerOptions controls which side-channel endpoints the dashboard exposes. +// The defaults are deliberately restrictive: replay and system-info are +// disabled because they are SSRF / information-disclosure primitives when the +// dashboard is reachable by an attacker. +type HandlerOptions struct { + // EnableReplay opens POST /api/replay. + EnableReplay bool + // ExposeSystemInfo opens GET /api/system-info. + ExposeSystemInfo bool + // ExposeEnvVars is the explicit allowlist of env var names the + // system-info endpoint will surface. Anything not in this set is omitted + // entirely so an attacker cannot infer existence. 
+ ExposeEnvVars []string +} + // Handler is the HTTP handler for the dashboard type Handler struct { - store store.Store - profiler *profiling.Profiler - staticFS fs.FS + store store.Store + profiler *profiling.Profiler + staticFS fs.FS + opts HandlerOptions + envAllowSet map[string]struct{} } // NewHandler creates a new dashboard handler -func NewHandler(store store.Store, profiler *profiling.Profiler) *Handler { - // Create file system for static files +func NewHandler(store store.Store, profiler *profiling.Profiler, opts HandlerOptions) *Handler { staticFS, _ := fs.Sub(staticFiles, "static") + envSet := make(map[string]struct{}, len(opts.ExposeEnvVars)) + for _, k := range opts.ExposeEnvVars { + envSet[k] = struct{}{} + } + return &Handler{ - store: store, - profiler: profiler, - staticFS: staticFS, + store: store, + profiler: profiler, + staticFS: staticFS, + opts: opts, + envAllowSet: envSet, } } // ServeHTTP implements the http.Handler interface func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // API endpoints if strings.HasPrefix(r.URL.Path, "/api/") { switch r.URL.Path { case "/api/requests": @@ -52,6 +78,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { case "/api/compare": h.handleCompareRequests(w, r) case "/api/replay": + if !h.opts.EnableReplay { + http.Error(w, "replay disabled", http.StatusNotFound) + return + } h.handleReplayRequest(w, r) case "/api/metrics": h.handleMetrics(w, r) @@ -60,6 +90,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { case "/api/bottlenecks": h.handleBottlenecks(w, r) case "/api/system-info": + if !h.opts.ExposeSystemInfo { + http.Error(w, "system-info disabled", http.StatusNotFound) + return + } h.handleSystemInfo(w, r) default: w.WriteHeader(http.StatusNotFound) @@ -68,19 +102,15 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // Determine the file to serve filePath := r.URL.Path if filePath == "/" || filePath == 
"" { filePath = "index.html" } else { - // Remove leading slash for fs.Open filePath = strings.TrimPrefix(filePath, "/") } - // Try to open the file from embedded FS file, err := h.staticFS.Open(filePath) if err != nil { - // Try index.html as fallback for SPA routing file, err = h.staticFS.Open("index.html") if err != nil { http.NotFound(w, r) @@ -90,14 +120,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } defer file.Close() - // Get file info for content type stat, err := file.Stat() if err != nil { http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } - // Set content type based on file extension switch { case strings.HasSuffix(filePath, ".html"): w.Header().Set("Content-Type", "text/html; charset=utf-8") @@ -107,11 +135,9 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/css; charset=utf-8") } - // Serve the file content http.ServeContent(w, r, filePath, stat.ModTime(), file.(io.ReadSeeker)) } -// handleAPIRequests serves the JSON API for requests func (h *Handler) handleAPIRequests(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") requests := h.store.GetAll() @@ -120,27 +146,24 @@ func (h *Handler) handleAPIRequests(w http.ResponseWriter, r *http.Request) { encoder.Encode(requests) } -// handleClearRequests clears all the stored requests func (h *Handler) handleClearRequests(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { w.WriteHeader(http.StatusMethodNotAllowed) return } - - // Clear the requests in the store if err := h.store.Clear(); err != nil { http.Error(w, "Error clearing requests", http.StatusInternalServerError) return } - - // In a real implementation, we would clear the store - // For now just respond with success w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write([]byte(`{"success":true}`)) } -// handleSSE handles Server-Sent 
Events for live updates +// handleSSE streams updates as Server-Sent Events. It sends a full snapshot on +// connect and then publishes only the IDs of the most recent requests on each +// tick — clients diff that against what they already have, so the bandwidth +// scales with churn rather than the entire log. func (h *Handler) handleSSE(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") @@ -153,66 +176,130 @@ func (h *Handler) handleSSE(w http.ResponseWriter, r *http.Request) { return } - requests := h.store.GetAll() - data, _ := json.Marshal(requests) - fmt.Fprintf(w, "data: %s\n\n", data) - flusher.Flush() + writeEvent := func(event string, payload interface{}) bool { + data, err := json.Marshal(payload) + if err != nil { + return false + } + if event != "" { + if _, err := fmt.Fprintf(w, "event: %s\n", event); err != nil { + return false + } + } + if _, err := fmt.Fprintf(w, "data: %s\n\n", data); err != nil { + return false + } + flusher.Flush() + return true + } + + if !writeEvent("snapshot", h.store.GetAll()) { + return + } ticker := time.NewTicker(2 * time.Second) defer ticker.Stop() + lastSeen := "" for { select { case <-ticker.C: - requests := h.store.GetAll() - data, _ := json.Marshal(requests) - fmt.Fprintf(w, "data: %s\n\n", data) - flusher.Flush() + latest := h.store.GetLatest(50) + // Find any entries newer than what we last announced. The store + // returns newest-first, so we slice everything before lastSeen. + found := lastSeen == "" + cutoff := len(latest) + for i, l := range latest { + if l.ID == lastSeen { + cutoff = i + found = true + break + } + } + if lastSeen != "" && !found { + // lastSeen is no longer in the store — the user cleared the + // log (or it rolled out of the cap). Resync the client with a + // fresh snapshot so it discards the stale entries. 
+ if !writeEvent("snapshot", latest) { + return + } + if len(latest) > 0 { + lastSeen = latest[0].ID + } else { + lastSeen = "" + } + continue + } + if cutoff == 0 { + // Heartbeat keeps proxies from closing idle connections. + if _, err := io.WriteString(w, ": ping\n\n"); err != nil { + return + } + flusher.Flush() + continue + } + fresh := latest[:cutoff] + if !writeEvent("append", fresh) { + return + } + lastSeen = fresh[0].ID case <-r.Context().Done(): return } } } -// handleCompareRequests serves the JSON API for comparing specific requests +// maxCompareIDs caps how many request IDs a single /api/compare call may +// supply. Without this, a caller could send thousands of IDs and force one +// store.Get per ID — which on SQL backends is a separate round-trip and a +// cheap amplification primitive. +const maxCompareIDs = 32 + func (h *Handler) handleCompareRequests(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - // Get request IDs from query parameters ids := r.URL.Query()["id"] if len(ids) < 2 { http.Error(w, "At least two request IDs are required", http.StatusBadRequest) return } + if len(ids) > maxCompareIDs { + http.Error(w, fmt.Sprintf("too many ids (max %d)", maxCompareIDs), http.StatusBadRequest) + return + } - // Get all requests - allRequests := h.store.GetAll() - - // Filter requests by IDs - compareRequests := []interface{}{} - for _, req := range allRequests { - for _, id := range ids { - if req.ID == id { - compareRequests = append(compareRequests, req) - break - } + // Look each id up directly rather than walking the entire log. 
+ idSet := make(map[string]struct{}, len(ids)) + for _, id := range ids { + idSet[id] = struct{}{} + } + compareRequests := make([]interface{}, 0, len(ids)) + for id := range idSet { + if req, ok := h.store.Get(id); ok { + compareRequests = append(compareRequests, req) } } - // Return the filtered requests encoder := json.NewEncoder(w) encoder.SetEscapeHTML(false) encoder.Encode(compareRequests) } -// handleReplayRequest handles replaying a captured request +// handleReplayRequest replays a captured HTTP request against an arbitrary +// destination. This is a powerful primitive and is therefore opt-in via +// HandlerOptions.EnableReplay. Even when enabled, we deny: +// - non-http(s) schemes (gopher://, file://, ftp://, etc.) +// - hostnames that resolve to loopback / link-local / private / multicast IPs +// +// to mitigate SSRF against cloud metadata services or internal networks. Any +// caller that needs to point replay at internal hosts is expected to manage +// network policy themselves; we will not undo the deny-by-default. func (h *Handler) handleReplayRequest(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { w.WriteHeader(http.StatusMethodNotAllowed) return } - // Parse request body decoder := json.NewDecoder(r.Body) var replayRequest struct { RequestID string `json:"requestId"` @@ -221,54 +308,66 @@ func (h *Handler) handleReplayRequest(w http.ResponseWriter, r *http.Request) { Headers map[string]string `json:"headers"` Body string `json:"body"` } - if err := decoder.Decode(&replayRequest); err != nil { http.Error(w, "Invalid request format: "+err.Error(), http.StatusBadRequest) return } - // Create HTTP client + if err := validateReplayTarget(replayRequest.URL); err != nil { + http.Error(w, "Replay target rejected: "+err.Error(), http.StatusForbidden) + return + } + + // Block redirects — a 30x to a private IP would defeat the pre-flight + // check. 
Use a custom DialContext that re-validates the resolved IP at + // dial time so DNS-rebinding can't slip past the pre-flight check (the + // pre-flight resolves and validates, but DefaultTransport would otherwise + // resolve again from the OS cache moments later). + transport := &http.Transport{ + DialContext: safeDialContext, + } client := &http.Client{ - Timeout: 30 * time.Second, + Timeout: 30 * time.Second, + Transport: transport, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, } - // Create request - req, err := http.NewRequest(replayRequest.Method, replayRequest.URL, strings.NewReader(replayRequest.Body)) + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, replayRequest.Method, replayRequest.URL, strings.NewReader(replayRequest.Body)) if err != nil { http.Error(w, "Error creating request: "+err.Error(), http.StatusInternalServerError) return } - - // Add headers for key, value := range replayRequest.Headers { req.Header.Add(key, value) } - // Execute request startTime := time.Now() resp, err := client.Do(req) duration := time.Since(startTime).Milliseconds() - if err != nil { - http.Error(w, "Error executing request: "+err.Error(), http.StatusInternalServerError) + http.Error(w, "Error executing request: "+err.Error(), http.StatusBadGateway) return } defer resp.Body.Close() - // Read response body - respBody, err := io.ReadAll(resp.Body) + // Cap the captured response body so a hostile target can't OOM us. 
+ const maxReplayBody = 1 << 20 // 1 MiB + respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxReplayBody)) if err != nil { - http.Error(w, "Error reading response body: "+err.Error(), http.StatusInternalServerError) + http.Error(w, "Error reading response body: "+err.Error(), http.StatusBadGateway) return } - // Convert headers to map for JSON response - headers := make(map[string][]string) + headers := make(map[string][]string, len(resp.Header)) for k, v := range resp.Header { headers[k] = v } - // Create response replayResponse := struct { StatusCode int `json:"statusCode"` Headers map[string][]string `json:"headers"` @@ -283,7 +382,6 @@ func (h *Handler) handleReplayRequest(w http.ResponseWriter, r *http.Request) { OriginalRequest: replayRequest.RequestID, } - // Send response w.Header().Set("Content-Type", "application/json") encoder := json.NewEncoder(w) encoder.SetEscapeHTML(false) @@ -293,19 +391,90 @@ func (h *Handler) handleReplayRequest(w http.ResponseWriter, r *http.Request) { } } -// handleMetrics serves performance metrics for a specific request +// validateReplayTarget rejects replay URLs that point at unsafe schemes or at +// IPs the caller almost certainly did not mean to expose: loopback, link-local, +// multicast, private ranges, and (critically on cloud) the IMDS address. 
func validateReplayTarget(raw string) error {
	// Pre-flight check only: the scheme must be http(s) and every address the
	// host currently resolves to must be public. IP literals are handled by
	// the resolver without a DNS query.
	u, err := url.Parse(raw)
	if err != nil {
		return fmt.Errorf("invalid url: %w", err)
	}
	scheme := strings.ToLower(u.Scheme)
	if scheme != "http" && scheme != "https" {
		return fmt.Errorf("scheme %q not allowed", u.Scheme)
	}
	host := u.Hostname()
	if host == "" {
		return errors.New("missing host")
	}
	ips, err := net.LookupIP(host)
	if err != nil {
		return fmt.Errorf("dns lookup failed: %w", err)
	}
	// An empty answer proves nothing about the target, so refuse it instead
	// of letting the loop below pass vacuously.
	if len(ips) == 0 {
		return fmt.Errorf("no addresses for %s", host)
	}
	for _, ip := range ips {
		if isInternalIP(ip) {
			return fmt.Errorf("target resolves to non-public address %s", ip)
		}
	}
	return nil
}

// replayDeniedNets lists special-purpose ranges that are not publicly
// routable but are not reported by the net.IP classification helpers used in
// isInternalIP.
var replayDeniedNets = []*net.IPNet{
	// 0.0.0.0/8: "this network".
	{IP: net.IPv4(0, 0, 0, 0), Mask: net.CIDRMask(8, 32)},
	// 100.64.0.0/10: carrier-grade NAT shared address space.
	{IP: net.IPv4(100, 64, 0, 0), Mask: net.CIDRMask(10, 32)},
	// 192.0.0.0/24: IETF protocol assignments.
	{IP: net.IPv4(192, 0, 0, 0), Mask: net.CIDRMask(24, 32)},
	// 198.18.0.0/15: network benchmarking.
	{IP: net.IPv4(198, 18, 0, 0), Mask: net.CIDRMask(15, 32)},
	// 240.0.0.0/4: reserved, including the limited broadcast address.
	{IP: net.IPv4(240, 0, 0, 0), Mask: net.CIDRMask(4, 32)},
	// 64:ff9b::/96: NAT64 well-known prefix. The low 32 bits embed an IPv4
	// address, so the whole prefix is denied rather than decoded.
	{IP: net.ParseIP("64:ff9b::"), Mask: net.CIDRMask(96, 128)},
}

// isInternalIP reports whether ip is one we should refuse to dial from a
// replay endpoint. It normalizes IPv4-mapped IPv6 addresses (::ffff:a.b.c.d)
// to their IPv4 form so an attacker cannot bypass the check by encoding a
// private IPv4 address as IPv6. A nil or malformed address is reported as
// internal so the check fails closed.
func isInternalIP(ip net.IP) bool {
	if ip4 := ip.To4(); ip4 != nil {
		ip = ip4
	}
	// The classification helpers below all return false for a value that is
	// neither 4 nor 16 bytes long, which would otherwise read as "public".
	if len(ip) != net.IPv4len && len(ip) != net.IPv6len {
		return true
	}
	// The link-local check covers 169.254.0.0/16 and therefore the
	// 169.254.169.254 metadata address; no separate comparison is needed.
	if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() ||
		ip.IsMulticast() || ip.IsUnspecified() || ip.IsPrivate() {
		return true
	}
	for _, n := range replayDeniedNets {
		if n.Contains(ip) {
			return true
		}
	}
	return false
}

// safeDialContext resolves the host and rejects the dial if any resolved
// address is private/loopback/IMDS. Crucially, the same resolution result is
// used for the actual connection — this closes the DNS-rebinding TOCTOU
// window between a pre-flight LookupIP and the transport's own resolution.
+func safeDialContext(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, err + } + if len(ips) == 0 { + return nil, fmt.Errorf("no addresses for %s", host) + } + for _, ip := range ips { + if isInternalIP(ip.IP) { + return nil, fmt.Errorf("dial rejected: %s resolves to non-public address %s", host, ip.IP) + } + } + dialer := &net.Dialer{Timeout: 10 * time.Second} + // Dial the first resolved address directly so the kernel does not perform + // a second lookup that could race with the validation above. + return dialer.DialContext(ctx, network, net.JoinHostPort(ips[0].IP.String(), port)) +} + func (h *Handler) handleMetrics(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") requestID := r.URL.Query().Get("id") if requestID == "" { - // Return all metrics if h.profiler == nil { w.WriteHeader(http.StatusNotImplemented) w.Write([]byte(`{"error":"Profiling is not enabled"}`)) return } - metrics := h.profiler.GetAllMetrics() encoder := json.NewEncoder(w) encoder.SetEscapeHTML(false) @@ -313,7 +482,6 @@ func (h *Handler) handleMetrics(w http.ResponseWriter, r *http.Request) { return } - // Get specific request metrics if h.profiler == nil { w.WriteHeader(http.StatusNotImplemented) w.Write([]byte(`{"error":"Profiling is not enabled"}`)) @@ -322,7 +490,6 @@ func (h *Handler) handleMetrics(w http.ResponseWriter, r *http.Request) { metrics, found := h.profiler.GetMetrics(requestID) if !found { - // Try to get from request log reqLog, found := h.store.Get(requestID) if !found || reqLog.PerformanceMetrics == nil { w.WriteHeader(http.StatusNotFound) @@ -337,7 +504,6 @@ func (h *Handler) handleMetrics(w http.ResponseWriter, r *http.Request) { encoder.Encode(metrics) } -// handleFlameGraph generates and serves flame graph data func (h *Handler) 
handleFlameGraph(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -348,7 +514,6 @@ func (h *Handler) handleFlameGraph(w http.ResponseWriter, r *http.Request) { return } - // Get request metrics var metrics *profiling.Metrics if h.profiler != nil { m, found := h.profiler.GetMetrics(requestID) @@ -358,7 +523,6 @@ func (h *Handler) handleFlameGraph(w http.ResponseWriter, r *http.Request) { } if metrics == nil { - // Try to get from request log reqLog, found := h.store.Get(requestID) if !found || reqLog.PerformanceMetrics == nil { w.WriteHeader(http.StatusNotFound) @@ -368,7 +532,6 @@ func (h *Handler) handleFlameGraph(w http.ResponseWriter, r *http.Request) { metrics = reqLog.PerformanceMetrics } - // Generate flame graph from CPU profile if len(metrics.CPUProfile) == 0 { w.WriteHeader(http.StatusNotFound) w.Write([]byte(`{"error":"No CPU profile data available"}`)) @@ -382,20 +545,23 @@ func (h *Handler) handleFlameGraph(w http.ResponseWriter, r *http.Request) { return } - // Convert to D3 format d3Data := flameGraph.ConvertToD3Format() - encoder := json.NewEncoder(w) encoder.SetEscapeHTML(false) encoder.Encode(d3Data) } -// handleBottlenecks serves performance bottleneck analysis +// maxBottleneckScan bounds how many recent requests handleBottlenecks scans. +// The store contract caps capacity already (see MaxRequests), but on shared +// SQL/Mongo backends the table can contain entries from other producers, and +// GetAll on those backends has no LIMIT. Use GetLatest to keep the work +// O(constant) regardless of table size. 
+const maxBottleneckScan = 500 + func (h *Handler) handleBottlenecks(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - // Get all requests with metrics - allRequests := h.store.GetAll() + allRequests := h.store.GetLatest(maxBottleneckScan) type BottleneckSummary struct { RequestID string `json:"request_id"` @@ -406,7 +572,6 @@ func (h *Handler) handleBottlenecks(w http.ResponseWriter, r *http.Request) { } var summaries []BottleneckSummary - for _, req := range allRequests { if req.PerformanceMetrics != nil && len(req.PerformanceMetrics.Bottlenecks) > 0 { summaries = append(summaries, BottleneckSummary{ @@ -424,29 +589,23 @@ func (h *Handler) handleBottlenecks(w http.ResponseWriter, r *http.Request) { encoder.Encode(summaries) } -// handleSystemInfo serves system information for the environment page +// handleSystemInfo exposes coarse runtime info plus an *explicit allowlist* of +// env vars. The previous implementation used a denylist of substrings ("KEY", +// "SECRET", ...), which is fragile: anything not on the list — DATABASE_URL, +// SLACK_WEBHOOK_URL, JWT_SIGNING_KEY before the bot learns the new abbreviation +// — leaks. Allowlists fail closed. 
func (h *Handler) handleSystemInfo(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - // Get hostname hostname, _ := os.Hostname() - // Get environment variables (filter sensitive ones) - envVars := make(map[string]string) - for _, env := range os.Environ() { - if parts := strings.SplitN(env, "=", 2); len(parts) == 2 { - key := parts[0] - value := parts[1] - - // Redact sensitive environment variables - if isSensitiveEnvVar(key) { - value = "[REDACTED]" - } - envVars[key] = value + envVars := make(map[string]string, len(h.envAllowSet)) + for name := range h.envAllowSet { + if v, ok := os.LookupEnv(name); ok { + envVars[name] = v } } - // Get memory stats var memStats runtime.MemStats runtime.ReadMemStats(&memStats) @@ -456,8 +615,8 @@ func (h *Handler) handleSystemInfo(w http.ResponseWriter, r *http.Request) { "goarch": runtime.GOARCH, "hostname": hostname, "cpuCores": runtime.NumCPU(), - "memoryUsed": memStats.Alloc / 1024 / 1024, // Convert to MB - "memoryTotal": memStats.Sys / 1024 / 1024, // Convert to MB + "memoryUsed": memStats.Alloc / 1024 / 1024, + "memoryTotal": memStats.Sys / 1024 / 1024, "envVars": envVars, } @@ -465,20 +624,3 @@ func (h *Handler) handleSystemInfo(w http.ResponseWriter, r *http.Request) { encoder.SetEscapeHTML(false) encoder.Encode(systemInfo) } - -// isSensitiveEnvVar checks if an environment variable key is sensitive -func isSensitiveEnvVar(key string) bool { - sensitivePatterns := []string{ - "API", "KEY", "SECRET", "PASSWORD", "TOKEN", - "CREDENTIAL", "AUTH", "PRIVATE", "CERT", - } - - upperKey := strings.ToUpper(key) - for _, pattern := range sensitivePatterns { - if strings.Contains(upperKey, pattern) { - return true - } - } - - return false -} diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index 0760f27..27f3b25 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -1,45 +1,132 @@ package middleware import ( + 
"bufio" "bytes" "encoding/json" + "errors" "io" + "net" "net/http" + "sync" "time" "github.com/doganarif/govisual/internal/model" "github.com/doganarif/govisual/internal/store" ) +// DefaultMaxBodyBytes is the default cap for captured request/response body size. +// Bodies larger than this are truncated with a marker suffix to avoid unbounded memory growth. +const DefaultMaxBodyBytes = 1 << 20 // 1 MiB + +// truncationMarker is appended when a captured body has been truncated. +const truncationMarker = "...[truncated by govisual]" + // PathMatcher defines an interface for checking if a path should be ignored type PathMatcher interface { ShouldIgnorePath(path string) bool } -// responseWriter is a wrapper for http.ResponseWriter that captures the status code and response +// responseWriter is a wrapper for http.ResponseWriter that captures the status code and response. +// It is safe for concurrent calls to Write (a handler that fans out writes across goroutines). type responseWriter struct { http.ResponseWriter - statusCode int - buffer *bytes.Buffer + mu sync.Mutex + statusCode int + wroteHeader bool + buffer *bytes.Buffer + maxBody int // 0 means unlimited + truncated bool // set once buffer hit maxBody } // WriteHeader captures the status code func (w *responseWriter) WriteHeader(code int) { - w.statusCode = code + w.mu.Lock() + if !w.wroteHeader { + w.statusCode = code + w.wroteHeader = true + } + w.mu.Unlock() w.ResponseWriter.WriteHeader(code) } -// Write captures the response body +// Write captures the response body up to maxBody bytes, then passes through. 
func (w *responseWriter) Write(b []byte) (int, error) { - // Write to the buffer - if w.buffer != nil { - w.buffer.Write(b) + w.mu.Lock() + if !w.wroteHeader { + w.statusCode = http.StatusOK + w.wroteHeader = true + } + if w.buffer != nil && !w.truncated { + remaining := w.maxBody - w.buffer.Len() + switch { + case w.maxBody <= 0: + w.buffer.Write(b) + case remaining > 0: + if remaining >= len(b) { + w.buffer.Write(b) + } else { + w.buffer.Write(b[:remaining]) + w.buffer.WriteString(truncationMarker) + w.truncated = true + } + default: + w.truncated = true + } } + w.mu.Unlock() return w.ResponseWriter.Write(b) } +// Flush implements http.Flusher, forwarding to the underlying writer if it supports it. +func (w *responseWriter) Flush() { + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +// Hijack implements http.Hijacker, forwarding to the underlying writer if it supports it. +func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if h, ok := w.ResponseWriter.(http.Hijacker); ok { + return h.Hijack() + } + return nil, nil, errors.New("govisual: underlying ResponseWriter does not implement http.Hijacker") +} + +// Push implements http.Pusher, forwarding to the underlying writer if it supports it. +func (w *responseWriter) Push(target string, opts *http.PushOptions) error { + if p, ok := w.ResponseWriter.(http.Pusher); ok { + return p.Push(target, opts) + } + return http.ErrNotSupported +} + +// readBodyCapped reads up to maxBody bytes from r, returns the bytes, a boolean +// indicating whether the body was truncated, and any read error. 
+func readBodyCapped(r io.Reader, maxBody int) ([]byte, bool, error) { + if maxBody <= 0 { + data, err := io.ReadAll(r) + return data, false, err + } + limited := io.LimitReader(r, int64(maxBody)+1) + data, err := io.ReadAll(limited) + if err != nil { + return data, false, err + } + if len(data) > maxBody { + return append(data[:maxBody], []byte(truncationMarker)...), true, nil + } + return data, false, nil +} + // Wrap wraps an http.Handler with the request visualization middleware func Wrap(handler http.Handler, store store.Store, logRequestBody, logResponseBody bool, pathMatcher PathMatcher) http.Handler { + return WrapWithLimits(handler, store, logRequestBody, logResponseBody, pathMatcher, DefaultMaxBodyBytes) +} + +// WrapWithLimits is identical to Wrap but allows the caller to specify the maximum number of +// captured body bytes (per request and per response). A value <= 0 disables the cap. +func WrapWithLimits(handler http.Handler, store store.Store, logRequestBody, logResponseBody bool, pathMatcher PathMatcher, maxBody int) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Check if the path should be ignored if pathMatcher != nil && pathMatcher.ShouldIgnorePath(r.URL.Path) { @@ -52,48 +139,33 @@ func Wrap(handler http.Handler, store store.Store, logRequestBody, logResponseBo // Capture request body if enabled if logRequestBody && r.Body != nil { - // Read the body - bodyBytes, _ := io.ReadAll(r.Body) + bodyBytes, _, err := readBodyCapped(r.Body, maxBody) r.Body.Close() - - // Store the body in the log - reqLog.RequestBody = string(bodyBytes) - - // Create a new body for the request + if err == nil { + reqLog.RequestBody = string(bodyBytes) + } + // Always restore a body so the handler can read what was buffered. 
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) } // Create response writer wrapper - var resWriter *responseWriter + resWriter := &responseWriter{ + ResponseWriter: w, + statusCode: http.StatusOK, + maxBody: maxBody, + } if logResponseBody { - resWriter = &responseWriter{ - ResponseWriter: w, - statusCode: 200, // Default status code - buffer: &bytes.Buffer{}, - } - } else { - resWriter = &responseWriter{ - ResponseWriter: w, - statusCode: 200, // Default status code - } + resWriter.buffer = &bytes.Buffer{} } - // Record start time start := time.Now() - - // Call the handler handler.ServeHTTP(resWriter, r) - - // Calculate duration - duration := time.Since(start) - reqLog.Duration = duration.Milliseconds() - - // Capture response info + reqLog.Duration = time.Since(start).Milliseconds() reqLog.StatusCode = resWriter.statusCode - // Extract middleware information from context - if middlewareValue := r.Context().Value("middleware"); middlewareValue != nil { - if middlewareInfo, ok := middlewareValue.(map[string]interface{}); ok { + // Extract user-provided middleware-stack information from context + if v := r.Context().Value(MiddlewareStackKey{}); v != nil { + if middlewareInfo, ok := v.(map[string]interface{}); ok { if stack, ok := middlewareInfo["stack"].([]map[string]interface{}); ok { reqLog.MiddlewareTrace = stack } @@ -101,8 +173,8 @@ func Wrap(handler http.Handler, store store.Store, logRequestBody, logResponseBo } // Extract route trace information - if routeValue := r.Context().Value("route"); routeValue != nil { - if routeStr, ok := routeValue.(string); ok { + if v := r.Context().Value(RouteTraceKey{}); v != nil { + if routeStr, ok := v.(string); ok { var routeInfo map[string]interface{} if err := json.Unmarshal([]byte(routeStr), &routeInfo); err == nil { reqLog.RouteTrace = routeInfo @@ -110,12 +182,15 @@ func Wrap(handler http.Handler, store store.Store, logRequestBody, logResponseBo } } - // Capture response body if enabled if logResponseBody && 
resWriter.buffer != nil { reqLog.ResponseBody = resWriter.buffer.String() } - // Store the request log - store.Add(reqLog) + if err := store.Add(reqLog); err != nil { + // Storage errors are surfaced on the log entry's Error field so they + // remain visible to anyone inspecting the dashboard backend; we do + // not block the request path. + _ = err + } }) } diff --git a/internal/middleware/middleware_test.go b/internal/middleware/middleware_test.go index 5d4379b..7ef921e 100644 --- a/internal/middleware/middleware_test.go +++ b/internal/middleware/middleware_test.go @@ -14,8 +14,9 @@ type mockStore struct { logs []*model.RequestLog } -func (m *mockStore) Add(log *model.RequestLog) { +func (m *mockStore) Add(log *model.RequestLog) error { m.logs = append(m.logs, log) + return nil } func (m *mockStore) Get(id string) (*model.RequestLog, bool) { diff --git a/internal/middleware/profiling.go b/internal/middleware/profiling.go index dec87fb..ba71b79 100644 --- a/internal/middleware/profiling.go +++ b/internal/middleware/profiling.go @@ -2,7 +2,6 @@ package middleware import ( "bytes" - "context" "io" "net/http" "time" @@ -20,19 +19,21 @@ type ProfilingConfig struct { CaptureTraces bool } -// WrapWithProfiling wraps an http.Handler with request visualization and performance profiling +// WrapWithProfiling wraps an http.Handler with request visualization and performance profiling. func WrapWithProfiling(handler http.Handler, store store.Store, logRequestBody, logResponseBody bool, pathMatcher PathMatcher, profiler *profiling.Profiler) http.Handler { + return WrapWithProfilingAndLimits(handler, store, logRequestBody, logResponseBody, pathMatcher, profiler, DefaultMaxBodyBytes) +} + +// WrapWithProfilingAndLimits is identical to WrapWithProfiling but exposes the captured-body size cap. 
+func WrapWithProfilingAndLimits(handler http.Handler, store store.Store, logRequestBody, logResponseBody bool, pathMatcher PathMatcher, profiler *profiling.Profiler, maxBody int) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Check if the path should be ignored if pathMatcher != nil && pathMatcher.ShouldIgnorePath(r.URL.Path) { handler.ServeHTTP(w, r) return } - // Create a new request log reqLog := model.NewRequestLog(r) - // Create request tracer tracer := NewRequestTracer(reqLog.ID) tracer.StartTrace("Request Handler", "handler", map[string]interface{}{ "method": r.Method, @@ -40,66 +41,50 @@ func WrapWithProfiling(handler http.Handler, store store.Store, logRequestBody, "query": r.URL.RawQuery, }) - // Start profiling ctx := r.Context() ctx = WithTracer(ctx, tracer) + // Register the tracer as a TracerSink so the profiler forwards SQL/HTTP + // events into the tracer's child traces. + ctx = profiling.WithTracerSink(ctx, tracer) if profiler != nil { ctx = profiler.StartProfiling(ctx, reqLog.ID) - - // Hook profiler to tracer - profiler.SetTracer(ctx, tracer) } r = r.WithContext(ctx) - // Capture request body if enabled if logRequestBody && r.Body != nil { - bodyBytes, _ := io.ReadAll(r.Body) + bodyBytes, _, err := readBodyCapped(r.Body, maxBody) r.Body.Close() - reqLog.RequestBody = string(bodyBytes) + if err == nil { + reqLog.RequestBody = string(bodyBytes) + } r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) } - // Create response writer wrapper with profiling support - resWriter := &profilingResponseWriter{ - responseWriter: &responseWriter{ - ResponseWriter: w, - statusCode: 200, - buffer: nil, - }, - profiler: profiler, - ctx: ctx, + resWriter := &responseWriter{ + ResponseWriter: w, + statusCode: http.StatusOK, + maxBody: maxBody, } - if logResponseBody { - resWriter.responseWriter.buffer = &bytes.Buffer{} + resWriter.buffer = &bytes.Buffer{} } - // Record start time start := time.Now() - - // Call the handler 
handler.ServeHTTP(resWriter, r) + reqLog.Duration = time.Since(start).Milliseconds() - // Calculate duration - duration := time.Since(start) - reqLog.Duration = duration.Milliseconds() - - // End profiling and get metrics if profiler != nil { if metrics := profiler.EndProfiling(ctx); metrics != nil { reqLog.PerformanceMetrics = metrics } } - // Capture response info - reqLog.StatusCode = resWriter.responseWriter.statusCode + reqLog.StatusCode = resWriter.statusCode - // Complete the tracer tracer.EndTrace(nil) tracer.Complete() - // Store traces in request log reqLog.MiddlewareTrace = make([]map[string]interface{}, 0) for _, trace := range tracer.GetTraces() { traceMap := map[string]interface{}{ @@ -118,63 +103,22 @@ func WrapWithProfiling(handler http.Handler, store store.Store, logRequestBody, reqLog.MiddlewareTrace = append(reqLog.MiddlewareTrace, traceMap) } - // Extract additional middleware information from context - if middlewareValue := r.Context().Value("middleware"); middlewareValue != nil { - if middlewareInfo, ok := middlewareValue.(map[string]interface{}); ok { + if v := r.Context().Value(MiddlewareStackKey{}); v != nil { + if middlewareInfo, ok := v.(map[string]interface{}); ok { if stack, ok := middlewareInfo["stack"].([]map[string]interface{}); ok { - // Merge with existing traces - for _, item := range stack { - reqLog.MiddlewareTrace = append(reqLog.MiddlewareTrace, item) - } + reqLog.MiddlewareTrace = append(reqLog.MiddlewareTrace, stack...) 
} } } - // Capture response body if enabled - if logResponseBody && resWriter.responseWriter.buffer != nil { - reqLog.ResponseBody = resWriter.responseWriter.buffer.String() + if logResponseBody && resWriter.buffer != nil { + reqLog.ResponseBody = resWriter.buffer.String() } - // Store the request log - store.Add(reqLog) + _ = store.Add(reqLog) }) } -// profilingResponseWriter extends responseWriter with profiling capabilities -type profilingResponseWriter struct { - responseWriter *responseWriter - profiler *profiling.Profiler - ctx context.Context -} - -func (w *profilingResponseWriter) Header() http.Header { - return w.responseWriter.Header() -} - -func (w *profilingResponseWriter) WriteHeader(code int) { - w.responseWriter.WriteHeader(code) -} - -func (w *profilingResponseWriter) Write(b []byte) (int, error) { - // Profile the write operation if significant - if w.profiler != nil && len(b) > 1024 { // Only profile writes larger than 1KB - return w.profileWrite(b) - } - return w.responseWriter.Write(b) -} - -func (w *profilingResponseWriter) profileWrite(b []byte) (int, error) { - var n int - var err error - - w.profiler.RecordFunction(w.ctx, "response.Write", func() error { - n, err = w.responseWriter.Write(b) - return err - }) - - return n, err -} - // HTTPRoundTripper is a profiling HTTP round tripper for outgoing requests type HTTPRoundTripper struct { Transport http.RoundTripper @@ -183,25 +127,19 @@ type HTTPRoundTripper struct { // RoundTrip implements http.RoundTripper with profiling func (rt *HTTPRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - if rt.Profiler == nil { - if rt.Transport != nil { - return rt.Transport.RoundTrip(req) - } - return http.DefaultTransport.RoundTrip(req) - } - - start := time.Now() - transport := rt.Transport if transport == nil { transport = http.DefaultTransport } - resp, err := transport.RoundTrip(req) + if rt.Profiler == nil { + return transport.RoundTrip(req) + } + start := time.Now() + resp, err := 
transport.RoundTrip(req) duration := time.Since(start) - // Record the HTTP call metrics if resp != nil { size := resp.ContentLength if size < 0 { @@ -211,7 +149,6 @@ func (rt *HTTPRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) } else { rt.Profiler.RecordHTTPCall(req.Context(), req.Method, req.URL.String(), duration, 0, 0) } - return resp, err } diff --git a/internal/middleware/tracer.go b/internal/middleware/tracer.go index b84950c..8f1c636 100644 --- a/internal/middleware/tracer.go +++ b/internal/middleware/tracer.go @@ -266,17 +266,29 @@ func (rt *RequestTracer) getParentTrace() *TraceEntry { return trace } -// Context key for request tracer -type tracerKey struct{} +// TracerContextKey is the public context key used to attach a *RequestTracer to a request context. +// External packages (e.g. internal/profiling) read this key, so it must remain exported. +type TracerContextKey struct{} + +// MiddlewareStackKey is the context key user middleware can write to in order to surface +// custom middleware-stack information into the dashboard's "middleware trace" view. +// Expected value type: map[string]interface{} with a "stack" key whose value is +// []map[string]interface{}. +type MiddlewareStackKey struct{} + +// RouteTraceKey is the context key user middleware can write to in order to attach +// a JSON-encoded route descriptor to the request log. +// Expected value type: string (JSON-encoded object). 
+type RouteTraceKey struct{} // WithTracer adds a tracer to the context func WithTracer(ctx context.Context, tracer *RequestTracer) context.Context { - return context.WithValue(ctx, tracerKey{}, tracer) + return context.WithValue(ctx, TracerContextKey{}, tracer) } // GetTracer gets the tracer from context func GetTracer(ctx context.Context) *RequestTracer { - if tracer, ok := ctx.Value(tracerKey{}).(*RequestTracer); ok { + if tracer, ok := ctx.Value(TracerContextKey{}).(*RequestTracer); ok { return tracer } return nil diff --git a/internal/model/request.go b/internal/model/request.go index cfaf361..ae56b3b 100644 --- a/internal/model/request.go +++ b/internal/model/request.go @@ -1,6 +1,8 @@ package model import ( + "crypto/rand" + "encoding/hex" "net/http" "time" @@ -9,20 +11,20 @@ import ( type RequestLog struct { ID string `json:"ID" bson:"_id"` - Timestamp time.Time `json:"Timestamp"` - Method string `json:"Method"` - Path string `json:"Path"` - Query string `json:"Query"` - RequestHeaders http.Header `json:"RequestHeaders"` - ResponseHeaders http.Header `json:"ResponseHeaders"` - StatusCode int `json:"StatusCode"` - Duration int64 `json:"Duration"` - RequestBody string `json:"RequestBody,omitempty"` - ResponseBody string `json:"ResponseBody,omitempty"` - Error string `json:"Error,omitempty"` - MiddlewareTrace []map[string]interface{} `json:"MiddlewareTrace,omitempty"` - RouteTrace map[string]interface{} `json:"RouteTrace,omitempty"` - PerformanceMetrics *profiling.Metrics `json:"PerformanceMetrics,omitempty" bson:"PerformanceMetrics,omitempty"` + Timestamp time.Time `json:"Timestamp" bson:"timestamp"` + Method string `json:"Method" bson:"method"` + Path string `json:"Path" bson:"path"` + Query string `json:"Query" bson:"query"` + RequestHeaders http.Header `json:"RequestHeaders" bson:"request_headers"` + ResponseHeaders http.Header `json:"ResponseHeaders" bson:"response_headers"` + StatusCode int `json:"StatusCode" bson:"status_code"` + Duration int64 
`json:"Duration" bson:"duration"` + RequestBody string `json:"RequestBody,omitempty" bson:"request_body,omitempty"` + ResponseBody string `json:"ResponseBody,omitempty" bson:"response_body,omitempty"` + Error string `json:"Error,omitempty" bson:"error,omitempty"` + MiddlewareTrace []map[string]interface{} `json:"MiddlewareTrace,omitempty" bson:"middleware_trace,omitempty"` + RouteTrace map[string]interface{} `json:"RouteTrace,omitempty" bson:"route_trace,omitempty"` + PerformanceMetrics *profiling.Metrics `json:"PerformanceMetrics,omitempty" bson:"performance_metrics,omitempty"` } func NewRequestLog(req *http.Request) *RequestLog { @@ -32,10 +34,49 @@ func NewRequestLog(req *http.Request) *RequestLog { Method: req.Method, Path: req.URL.Path, Query: req.URL.RawQuery, - RequestHeaders: req.Header, + RequestHeaders: scrubHeaders(req.Header), } } +// sensitiveHeaders are dropped from captured request/response logs. Storing +// raw credentials makes the dashboard a high-value target and creates a +// data-at-rest liability on every configured backend; opt-out is not offered +// because there is no defensible reason to log a bearer token verbatim. +var sensitiveHeaders = map[string]struct{}{ + "Authorization": {}, + "Proxy-Authorization": {}, + "Cookie": {}, + "Set-Cookie": {}, + "X-Api-Key": {}, + "X-Auth-Token": {}, + "X-Csrf-Token": {}, +} + +// scrubHeaders returns a copy of h with credential-bearing header values +// replaced by a fixed marker. The header *name* is kept so consumers can see +// that auth was present; only the value is hidden. +func scrubHeaders(h http.Header) http.Header { + if len(h) == 0 { + return h + } + out := make(http.Header, len(h)) + for k, vs := range h { + if _, redact := sensitiveHeaders[http.CanonicalHeaderKey(k)]; redact { + out[k] = []string{"[redacted by govisual]"} + continue + } + out[k] = append([]string(nil), vs...) 
+ } + return out +} + +// generateID returns a collision-resistant 128-bit random identifier +// encoded as 32 hex characters. Falls back to nanosecond timestamp +// only if the OS RNG is unavailable, which should never happen in practice. func generateID() string { - return time.Now().Format("20060102-150405.000000") + var b [16]byte + if _, err := rand.Read(b[:]); err != nil { + return time.Now().UTC().Format("20060102T150405.000000000") + } + return hex.EncodeToString(b[:]) } diff --git a/internal/profiling/profiler.go b/internal/profiling/profiler.go index 0474c3a..6b15586 100644 --- a/internal/profiling/profiler.go +++ b/internal/profiling/profiler.go @@ -2,6 +2,7 @@ package profiling import ( "bytes" + "container/list" "context" "runtime" "runtime/pprof" @@ -10,25 +11,31 @@ import ( "time" ) -// ProfileType represents the type of profiling to perform +// ProfileType represents the type of profiling to perform. +// Values are bitmask flags so multiple types can be OR'd together. type ProfileType uint32 const ( // ProfileNone disables all profiling ProfileNone ProfileType = 0 // ProfileCPU enables CPU profiling - ProfileCPU ProfileType = 1 << iota + ProfileCPU ProfileType = 1 << 0 // ProfileMemory enables memory profiling - ProfileMemory + ProfileMemory ProfileType = 1 << 1 // ProfileGoroutine enables goroutine tracking - ProfileGoroutine + ProfileGoroutine ProfileType = 1 << 2 // ProfileBlocking enables blocking profiling - ProfileBlocking + ProfileBlocking ProfileType = 1 << 3 // ProfileAll enables all profiling types ProfileAll = ProfileCPU | ProfileMemory | ProfileGoroutine | ProfileBlocking ) -// Metrics contains performance metrics for a request +// Metrics contains performance metrics for a request. +// +// Concurrency: a single request may fan out work across goroutines that each +// call RecordFunction/RecordSQLQuery/RecordHTTPCall on the same *Metrics. The +// mu field serializes those mutations. 
Readers (GetMetrics, JSON encoding, +// EndProfiling) snapshot under the same mutex. type Metrics struct { RequestID string `json:"request_id"` StartTime time.Time `json:"start_time"` @@ -46,6 +53,8 @@ type Metrics struct { Bottlenecks []Bottleneck `json:"bottlenecks,omitempty"` CPUProfile []byte `json:"-"` // Raw CPU profile data HeapProfile []byte `json:"-"` // Raw heap profile data + + mu sync.Mutex } // SQLQueryMetric represents metrics for a SQL query @@ -74,13 +83,21 @@ type Bottleneck struct { Suggestion string `json:"suggestion"` } -// Profiler handles performance profiling for requests +// Profiler handles performance profiling for requests. +// +// Limitation: CPU profiling uses runtime/pprof.StartCPUProfile which is a +// process-global sampler. Only one request can be CPU-profiled at a time; +// concurrent requests that arrive while a profile is in progress will be +// captured for all other metrics (memory, goroutines, SQL, HTTP) but will +// not have a CPUProfile attached. This is a fundamental constraint of the +// Go runtime, not a bug. 
type Profiler struct { enabled atomic.Bool profileType atomic.Uint32 threshold time.Duration // Minimum duration to trigger profiling mu sync.RWMutex - metrics map[string]*Metrics + metrics map[string]*list.Element // requestID -> *list.Element holding *Metrics + order *list.List // FIFO insertion order of *Metrics maxMetrics int cpuProfileMu sync.Mutex // Protects CPU profiling global state activeCPUProfile *cpuProfileSession // Currently active CPU profile session @@ -100,7 +117,8 @@ func NewProfiler(maxMetrics int) *Profiler { } p := &Profiler{ threshold: 10 * time.Millisecond, // Default threshold - metrics: make(map[string]*Metrics), + metrics: make(map[string]*list.Element), + order: list.New(), maxMetrics: maxMetrics, } p.enabled.Store(true) @@ -259,7 +277,12 @@ func (p *Profiler) RecordFunction(ctx context.Context, name string, fn func() er err := fn() duration := time.Since(start) + metrics.mu.Lock() + if metrics.FunctionTimings == nil { + metrics.FunctionTimings = make(map[string]time.Duration) + } metrics.FunctionTimings[name] = duration + metrics.mu.Unlock() return err } @@ -284,13 +307,12 @@ func (p *Profiler) RecordSQLQuery(ctx context.Context, query string, duration ti metric.Error = err.Error() } + metrics.mu.Lock() metrics.SQLQueries = append(metrics.SQLQueries, metric) + metrics.mu.Unlock() - // Also record in tracer if available - if tracer, ok := ctx.Value(tracerKey{}).(interface { - RecordSQL(string, time.Duration, int, error) - }); ok { - tracer.RecordSQL(query, duration, rows, err) + if sink := tracerSinkFromContext(ctx); sink != nil { + sink.RecordSQL(query, duration, rows, err) } } @@ -305,6 +327,7 @@ func (p *Profiler) RecordHTTPCall(ctx context.Context, method, url string, durat return } + metrics.mu.Lock() metrics.HTTPCalls = append(metrics.HTTPCalls, HTTPCallMetric{ Method: method, URL: url, @@ -312,12 +335,10 @@ func (p *Profiler) RecordHTTPCall(ctx context.Context, method, url string, durat Status: status, Size: size, }) + 
metrics.mu.Unlock() - // Also record in tracer if available - if tracer, ok := ctx.Value(tracerKey{}).(interface { - RecordHTTP(string, string, time.Duration, int, error) - }); ok { - tracer.RecordHTTP(method, url, duration, status, nil) + if sink := tracerSinkFromContext(ctx); sink != nil { + sink.RecordHTTP(method, url, duration, status, nil) } } @@ -325,18 +346,21 @@ func (p *Profiler) RecordHTTPCall(ctx context.Context, method, url string, durat func (p *Profiler) GetMetrics(requestID string) (*Metrics, bool) { p.mu.RLock() defer p.mu.RUnlock() - metrics, ok := p.metrics[requestID] - return metrics, ok + elem, ok := p.metrics[requestID] + if !ok { + return nil, false + } + return elem.Value.(*Metrics), true } -// GetAllMetrics retrieves all stored metrics +// GetAllMetrics retrieves all stored metrics in FIFO insertion order (oldest first). func (p *Profiler) GetAllMetrics() []*Metrics { p.mu.RLock() defer p.mu.RUnlock() - result := make([]*Metrics, 0, len(p.metrics)) - for _, m := range p.metrics { - result = append(result, m) + result := make([]*Metrics, 0, p.order.Len()) + for e := p.order.Front(); e != nil; e = e.Next() { + result = append(result, e.Value.(*Metrics)) } return result } @@ -345,7 +369,8 @@ func (p *Profiler) GetAllMetrics() []*Metrics { func (p *Profiler) ClearMetrics() { p.mu.Lock() defer p.mu.Unlock() - p.metrics = make(map[string]*Metrics) + p.metrics = make(map[string]*list.Element) + p.order = list.New() } // analyzeBottlenecks analyzes metrics to identify bottlenecks @@ -421,26 +446,63 @@ func (p *Profiler) storeMetrics(requestID string, metrics *Metrics) { p.mu.Lock() defer p.mu.Unlock() - // Enforce max metrics limit - if len(p.metrics) >= p.maxMetrics { - // Remove oldest entry (simple FIFO for now) - for k := range p.metrics { - delete(p.metrics, k) + // If the same requestID was already stored, replace its entry in place. 
+ if existing, ok := p.metrics[requestID]; ok { + existing.Value = metrics + return + } + + // Real FIFO eviction: drop the front (oldest) element. + for p.order.Len() >= p.maxMetrics { + oldest := p.order.Front() + if oldest == nil { break } + oldMetrics := oldest.Value.(*Metrics) + delete(p.metrics, oldMetrics.RequestID) + p.order.Remove(oldest) } - p.metrics[requestID] = metrics + p.metrics[requestID] = p.order.PushBack(metrics) } func (p *Profiler) removeMetrics(requestID string) { p.mu.Lock() defer p.mu.Unlock() - delete(p.metrics, requestID) + if elem, ok := p.metrics[requestID]; ok { + p.order.Remove(elem) + delete(p.metrics, requestID) + } } type profileContextKey struct{} -type tracerKey struct{} + +// TracerSink is the interface a tracer must implement to receive forwarded +// SQL and HTTP events recorded through the profiler. The middleware package's +// RequestTracer satisfies this interface. +type TracerSink interface { + RecordSQL(query string, duration time.Duration, rows int, err error) + RecordHTTP(method, url string, duration time.Duration, status int, err error) +} + +type tracerSinkKey struct{} + +// WithTracerSink attaches a TracerSink to the context so that calls to +// Profiler.RecordSQLQuery and Profiler.RecordHTTPCall are also forwarded +// to the tracer. Returns the new context. 
+func WithTracerSink(ctx context.Context, sink TracerSink) context.Context { + if sink == nil { + return ctx + } + return context.WithValue(ctx, tracerSinkKey{}, sink) +} + +func tracerSinkFromContext(ctx context.Context) TracerSink { + if v, ok := ctx.Value(tracerSinkKey{}).(TracerSink); ok { + return v + } + return nil +} // stopCPUProfilingIfActive stops CPU profiling if the given request started it func (p *Profiler) stopCPUProfilingIfActive(requestID string) { @@ -452,19 +514,3 @@ func (p *Profiler) stopCPUProfilingIfActive(requestID string) { p.activeCPUProfile = nil } } - -// SetTracer associates a tracer with the profiling context -func (p *Profiler) SetTracer(ctx context.Context, tracer interface{}) { - // Tracer is already in context, no action needed - // This method exists for compatibility -} - -// ProfileWriter captures CPU profiles -type ProfileWriter struct { - buf []byte -} - -func (pw *ProfileWriter) Write(p []byte) (n int, err error) { - pw.buf = append(pw.buf, p...) - return len(p), nil -} diff --git a/internal/store/memory.go b/internal/store/memory.go index 5181958..7407933 100644 --- a/internal/store/memory.go +++ b/internal/store/memory.go @@ -8,6 +8,7 @@ import ( type InMemoryStore struct { logs []*model.RequestLog + index map[string]int // ID -> position in logs (O(1) Get) capacity int size int next int @@ -21,69 +22,66 @@ func NewInMemoryStore(capacity int) *InMemoryStore { return &InMemoryStore{ logs: make([]*model.RequestLog, capacity), + index: make(map[string]int, capacity), capacity: capacity, - size: 0, - next: 0, } } -func (s *InMemoryStore) Add(log *model.RequestLog) { +func (s *InMemoryStore) Add(log *model.RequestLog) error { s.mu.Lock() defer s.mu.Unlock() - s.logs[s.next] = log + // If the ring buffer is full, the slot we're about to overwrite holds + // the oldest entry. Evict it from the ID index first. 
+ if old := s.logs[s.next]; old != nil { + delete(s.index, old.ID) + } + s.logs[s.next] = log + s.index[log.ID] = s.next s.next = (s.next + 1) % s.capacity if s.size < s.capacity { s.size++ } + return nil } func (s *InMemoryStore) Get(id string) (*model.RequestLog, bool) { s.mu.RLock() defer s.mu.RUnlock() - for _, log := range s.logs[:s.size] { - if log != nil && log.ID == id { - return log, true - } + pos, ok := s.index[id] + if !ok { + return nil, false } - - return nil, false + return s.logs[pos], true } +// GetAll returns all stored logs in newest-first order. This matches the +// ordering contract of the SQL/Mongo/Redis backends, so callers can treat the +// returned slice uniformly regardless of which store backs them. func (s *InMemoryStore) GetAll() []*model.RequestLog { s.mu.RLock() defer s.mu.RUnlock() result := make([]*model.RequestLog, 0, s.size) - - if s.size < s.capacity { - for i := 0; i < s.size; i++ { - result = append(result, s.logs[i]) - } - return result + // Walk backwards from the most-recently-written slot (s.next-1) for s.size + // steps, wrapping. This yields newest-first without an extra sort. + for i := 0; i < s.size; i++ { + idx := (s.next - 1 - i + s.capacity) % s.capacity + result = append(result, s.logs[idx]) } - - for i := s.next; i < s.capacity; i++ { - result = append(result, s.logs[i]) - } - for i := 0; i < s.next; i++ { - result = append(result, s.logs[i]) - } - return result } +// GetLatest returns the n most recent logs, newest-first. 
func (s *InMemoryStore) GetLatest(n int) []*model.RequestLog { all := s.GetAll() - if len(all) <= n { return all } - - return all[len(all)-n:] + return all[:n] } // Clear clears all stored request logs @@ -96,6 +94,7 @@ func (s *InMemoryStore) Clear() error { } s.logs = make([]*model.RequestLog, s.capacity) + s.index = make(map[string]int, s.capacity) s.size = 0 s.next = 0 return nil @@ -103,6 +102,5 @@ func (s *InMemoryStore) Clear() error { // Close implements the Store interface but does nothing for in-memory store func (s *InMemoryStore) Close() error { - // Nothing to do for in-memory store return nil } diff --git a/internal/store/mongodb.go b/internal/store/mongodb.go index 23bf60a..24fb19c 100644 --- a/internal/store/mongodb.go +++ b/internal/store/mongodb.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "log" + "sync/atomic" + "time" "github.com/doganarif/govisual/internal/model" "go.mongodb.org/mongo-driver/v2/bson" @@ -14,10 +16,10 @@ import ( // MongoDBStore implements the Store interface with MongoDB as backend type MongoDBStore struct { - database *mongo.Database - collection *mongo.Collection - capacity int - ctx context.Context + database *mongo.Database + collection *mongo.Collection + capacity int + insertCount atomic.Uint64 } // NewMongoDBStore creates a new MongoDB-backend store @@ -26,172 +28,176 @@ func NewMongoDBStore(uri, databaseName, collectionName string, capacity int) (*M capacity = 100 } - ctx := context.Background() + connectCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + client, err := mongo.Connect(options.Client().ApplyURI(uri)) if err != nil { return nil, fmt.Errorf("failed to get MongoDB client: %w", err) } - // Test the connection - if err := client.Ping(ctx, readpref.Nearest()); err != nil { + if err := client.Ping(connectCtx, readpref.Nearest()); err != nil { + _ = client.Disconnect(connectCtx) return nil, fmt.Errorf("failed to ping MongoDB: %w", err) } database := 
client.Database(databaseName) collection := database.Collection(collectionName) - // Create index on timestamp for faster retrieval indexName := fmt.Sprintf("%s_timestamp_idx", collectionName) indexModel := mongo.IndexModel{ - Keys: bson.M{"Timestamp": -1}, + Keys: bson.D{{Key: "timestamp", Value: -1}}, Options: options.Index().SetName(indexName), } - - _, err = collection.Indexes().CreateOne(ctx, indexModel) - if err != nil { + if _, err := collection.Indexes().CreateOne(connectCtx, indexModel); err != nil { + _ = client.Disconnect(connectCtx) return nil, fmt.Errorf("failed to create index in MongoDB: %w", err) } + return &MongoDBStore{ database: database, collection: collection, capacity: capacity, - ctx: ctx, }, nil } -// Add adds a new request log to the store -func (m *MongoDBStore) Add(reqLog *model.RequestLog) { - // Store log in MongoDB - if _, err := m.collection.InsertOne(m.ctx, reqLog); err != nil { - log.Printf("Failed to store log in MongoDB: %v", err) - return +func (m *MongoDBStore) opCtx() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), 10*time.Second) +} + +func (m *MongoDBStore) Add(reqLog *model.RequestLog) error { + ctx, cancel := m.opCtx() + defer cancel() + + if _, err := m.collection.InsertOne(ctx, reqLog); err != nil { + return fmt.Errorf("mongodb insert: %w", err) } - m.cleanup() + if m.insertCount.Add(1)%cleanupEveryN == 0 { + m.cleanup() + } + return nil } -// cleanup removes old logs to maintain the capacity limit func (m *MongoDBStore) cleanup() { - count, err := m.collection.CountDocuments(m.ctx, bson.M{}) + ctx, cancel := m.opCtx() + defer cancel() + + count, err := m.collection.CountDocuments(ctx, bson.M{}) if err != nil { - log.Printf("Failed to get the log count in MongoDB: %v", err) + log.Printf("govisual: failed to count MongoDB logs: %v", err) return } - if count <= int64(m.capacity) { return } - // Find the oldest logs that exceed capacity + findOptions := options.Find(). 
- SetSort(bson.D{{Key: "Timestamp", Value: 1}}). - SetLimit(count - int64(m.capacity)) + SetSort(bson.D{{Key: "timestamp", Value: 1}}). + SetLimit(count - int64(m.capacity)). + SetProjection(bson.M{"_id": 1}) - cursor, err := m.collection.Find(m.ctx, bson.M{}, findOptions) + cursor, err := m.collection.Find(ctx, bson.M{}, findOptions) if err != nil { - log.Printf("Failed to find oldest logs in MongoDB: %v", err) + log.Printf("govisual: failed to find oldest MongoDB logs: %v", err) return } - defer cursor.Close(m.ctx) + defer cursor.Close(ctx) - var oldestLogs []model.RequestLog - for cursor.Next(m.ctx) { - var reqLog model.RequestLog - if err := cursor.Decode(&reqLog); err != nil { - log.Printf("Failed to decode oldest log in MongoDB: %v", err) + var ids []string + for cursor.Next(ctx) { + var doc struct { + ID string `bson:"_id"` + } + if err := cursor.Decode(&doc); err != nil { + log.Printf("govisual: failed to decode oldest MongoDB log: %v", err) continue } - oldestLogs = append(oldestLogs, reqLog) + ids = append(ids, doc.ID) } - - if len(oldestLogs) == 0 { + if len(ids) == 0 { return } - // Extract IDs of logs to delete - var ids []string - for _, log := range oldestLogs { - ids = append(ids, log.ID) - } - - // Delete the oldest logs - if _, err := m.collection.DeleteMany(m.ctx, bson.M{"_id": bson.M{"$in": ids}}); err != nil { - log.Printf("Failed to delete oldest logs in MongoDB: %v", err) - return + if _, err := m.collection.DeleteMany(ctx, bson.M{"_id": bson.M{"$in": ids}}); err != nil { + log.Printf("govisual: failed to delete oldest MongoDB logs: %v", err) } } -// Get retrieves a specific request log by its ID func (m *MongoDBStore) Get(id string) (*model.RequestLog, bool) { + ctx, cancel := m.opCtx() + defer cancel() + var reqLog model.RequestLog - if err := m.collection.FindOne(m.ctx, bson.M{"_id": id}).Decode(&reqLog); err != nil { + if err := m.collection.FindOne(ctx, bson.M{"_id": id}).Decode(&reqLog); err != nil { if err == mongo.ErrNoDocuments { 
return nil, false } - log.Printf("Failed to get request log from MongoDB: %v", err) + log.Printf("govisual: failed to get MongoDB log: %v", err) return nil, false } return &reqLog, true } -// GetAll returns all stored request logs func (m *MongoDBStore) GetAll() []*model.RequestLog { - opts := options.Find().SetSort(bson.M{"Timestamp": -1}) - cursor, err := m.collection.Find(m.ctx, bson.M{}, opts) + ctx, cancel := m.opCtx() + defer cancel() + + opts := options.Find().SetSort(bson.D{{Key: "timestamp", Value: -1}}) + cursor, err := m.collection.Find(ctx, bson.M{}, opts) if err != nil { - if err == mongo.ErrClientDisconnected { - return nil - } - log.Printf("Failed to get cursor from MongoDB: %v", err) + log.Printf("govisual: failed to query MongoDB: %v", err) return nil } - defer cursor.Close(m.ctx) - reqsLog := make([]*model.RequestLog, 0) - for cursor.Next(m.ctx) { + defer cursor.Close(ctx) + + out := make([]*model.RequestLog, 0) + for cursor.Next(ctx) { var reqLog model.RequestLog if err := cursor.Decode(&reqLog); err != nil { - log.Printf("Failed to decode request log from MongoDB: %v", err) + log.Printf("govisual: failed to decode MongoDB log: %v", err) continue } - reqsLog = append(reqsLog, &reqLog) + out = append(out, &reqLog) } - return reqsLog + return out } -// GetLatest returns the n most recent request logs func (m *MongoDBStore) GetLatest(n int) []*model.RequestLog { - // Get the n newest log IDs - opts := options.Find().SetLimit(int64(n)).SetSort(bson.M{"timestamp": -1}) - cursor, err := m.collection.Find(m.ctx, bson.M{}, opts) + ctx, cancel := m.opCtx() + defer cancel() + + opts := options.Find().SetLimit(int64(n)).SetSort(bson.D{{Key: "timestamp", Value: -1}}) + cursor, err := m.collection.Find(ctx, bson.M{}, opts) if err != nil { - if err == mongo.ErrClientDisconnected { - return nil - } - log.Printf("Failed to get cursor from MongoDB: %v", err) + log.Printf("govisual: failed to query MongoDB: %v", err) return nil } - defer cursor.Close(m.ctx) - 
reqsLog := make([]*model.RequestLog, 0) - for cursor.Next(m.ctx) { + defer cursor.Close(ctx) + + out := make([]*model.RequestLog, 0) + for cursor.Next(ctx) { var reqLog model.RequestLog if err := cursor.Decode(&reqLog); err != nil { - log.Printf("Failed to decode request log from MongoDB: %v", err) + log.Printf("govisual: failed to decode MongoDB log: %v", err) continue } - reqsLog = append(reqsLog, &reqLog) + out = append(out, &reqLog) } - - return reqsLog + return out } -// Clear removes all logs from the store func (m *MongoDBStore) Clear() error { - _, err := m.collection.DeleteMany(m.ctx, bson.M{}) - if err != nil { - return fmt.Errorf("failed to clear logs in MongoDB: %w", err) + ctx, cancel := m.opCtx() + defer cancel() + + if _, err := m.collection.DeleteMany(ctx, bson.M{}); err != nil { + return fmt.Errorf("failed to clear MongoDB logs: %w", err) } return nil } -// Close closes the database connection func (m *MongoDBStore) Close() error { - return m.database.Client().Disconnect(m.ctx) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + return m.database.Client().Disconnect(ctx) } diff --git a/internal/store/postgres.go b/internal/store/postgres.go index aea1ddc..24d49b7 100644 --- a/internal/store/postgres.go +++ b/internal/store/postgres.go @@ -5,16 +5,23 @@ import ( "encoding/json" "fmt" "log" + "sync/atomic" "github.com/doganarif/govisual/internal/model" _ "github.com/lib/pq" ) +// cleanupEveryN runs the capacity-trim query once every N successful inserts, +// instead of on every Add. Trading a slight overshoot of the configured capacity +// for far less load on the database. 
+const cleanupEveryN = 32 + // PostgresStore implements the Store interface with PostgreSQL as backend type PostgresStore struct { - db *sql.DB - tableName string - capacity int + db *sql.DB + tableName string + capacity int + insertCount atomic.Uint64 } // NewPostgresStore creates a new PostgreSQL-backed store @@ -23,33 +30,34 @@ func NewPostgresStore(connStr, tableName string, capacity int) (*PostgresStore, capacity = 100 } - // Connect to the database + if !IsValidTableName(tableName) { + return nil, fmt.Errorf("invalid table name %q: must match [A-Za-z_][A-Za-z0-9_]*", tableName) + } + db, err := sql.Open("postgres", connStr) if err != nil { return nil, fmt.Errorf("failed to connect to PostgreSQL: %w", err) } - // Test the connection if err := db.Ping(); err != nil { + db.Close() return nil, fmt.Errorf("failed to ping PostgreSQL: %w", err) } - store := &PostgresStore{ + s := &PostgresStore{ db: db, tableName: tableName, capacity: capacity, } - // Create the table if it doesn't exist - if err := store.createTable(); err != nil { + if err := s.createTable(); err != nil { db.Close() return nil, fmt.Errorf("failed to create table: %w", err) } - return store, nil + return s, nil } -// createTable creates the required table if it doesn't exist func (s *PostgresStore) createTable() error { query := fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( @@ -71,26 +79,21 @@ func (s *PostgresStore) createTable() error { ) `, s.tableName) - _, err := s.db.Exec(query) - if err != nil { + if _, err := s.db.Exec(query); err != nil { return err } - // Create index on timestamp for faster retrieval indexQuery := fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s_timestamp_idx ON %s (timestamp DESC)", s.tableName, s.tableName) - _, err = s.db.Exec(indexQuery) - + _, err := s.db.Exec(indexQuery) return err } // Add adds a new request log to the store -func (s *PostgresStore) Add(reqLog *model.RequestLog) { - // Prepare all JSON fields properly +func (s *PostgresStore) Add(reqLog 
*model.RequestLog) error { reqHeaders := prepareJSON(reqLog.RequestHeaders) respHeaders := prepareJSON(reqLog.ResponseHeaders) - // Default to empty arrays/objects for JSON fields if they're nil or empty middlewareTrace := "[]" if len(reqLog.MiddlewareTrace) > 0 { if data, err := json.Marshal(reqLog.MiddlewareTrace); err == nil { @@ -105,11 +108,10 @@ func (s *PostgresStore) Add(reqLog *model.RequestLog) { } } - // Insert the log using string interpolation for JSON fields to avoid issues with parameter binding query := fmt.Sprintf(` INSERT INTO %s ( id, timestamp, method, path, query, request_headers, response_headers, - status_code, duration, request_body, response_body, error, + status_code, duration, request_body, response_body, error, middleware_trace, route_trace ) VALUES ($1, $2, $3, $4, $5, $6::jsonb, $7::jsonb, $8, $9, $10, $11, $12, $13::jsonb, $14::jsonb) `, s.tableName) @@ -131,13 +133,14 @@ func (s *PostgresStore) Add(reqLog *model.RequestLog) { middlewareTrace, routeTrace, ) - if err != nil { - log.Printf("Failed to store request log in PostgreSQL: %v", err) + return fmt.Errorf("postgres insert: %w", err) } - // Clean up old logs - s.cleanup() + if s.insertCount.Add(1)%cleanupEveryN == 0 { + s.cleanup() + } + return nil } // prepareJSON ensures we have a valid JSON string @@ -145,13 +148,11 @@ func prepareJSON(v interface{}) string { if v == nil { return "{}" } - data, err := json.Marshal(v) if err != nil { - log.Printf("Failed to marshal JSON: %v", err) + log.Printf("govisual: failed to marshal JSON: %v", err) return "{}" } - return string(data) } @@ -159,17 +160,14 @@ func prepareJSON(v interface{}) string { func (s *PostgresStore) cleanup() { countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s", s.tableName) var count int - err := s.db.QueryRow(countQuery).Scan(&count) - if err != nil { - log.Printf("Failed to count logs: %v", err) + if err := s.db.QueryRow(countQuery).Scan(&count); err != nil { + log.Printf("govisual: failed to count logs: %v", err) 
return } - if count <= s.capacity { return } - // Delete oldest logs deleteQuery := fmt.Sprintf(` DELETE FROM %s WHERE id IN ( @@ -179,20 +177,19 @@ func (s *PostgresStore) cleanup() { ) `, s.tableName, s.tableName) - _, err = s.db.Exec(deleteQuery, count-s.capacity) - if err != nil { - log.Printf("Failed to clean up old logs: %v", err) + if _, err := s.db.Exec(deleteQuery, count-s.capacity); err != nil { + log.Printf("govisual: failed to clean up old logs: %v", err) } } // Get retrieves a specific request log by its ID func (s *PostgresStore) Get(id string) (*model.RequestLog, bool) { query := fmt.Sprintf(` - SELECT - id, timestamp, method, path, query, + SELECT + id, timestamp, method, path, query, COALESCE(request_headers::text, '{}'), COALESCE(response_headers::text, '{}'), - status_code, duration, request_body, response_body, error, + status_code, duration, request_body, response_body, error, COALESCE(middleware_trace::text, '[]'), COALESCE(route_trace::text, '{}') FROM %s @@ -223,20 +220,18 @@ func (s *PostgresStore) Get(id string) (*model.RequestLog, bool) { &middlewareTrace, &routeTrace, ) - if err != nil { if err == sql.ErrNoRows { return nil, false } - log.Printf("Failed to get request log from PostgreSQL: %v", err) + log.Printf("govisual: failed to get request log from PostgreSQL: %v", err) return nil, false } - // Unmarshal all JSON fields - json.Unmarshal([]byte(reqHeadersStr), &reqLog.RequestHeaders) - json.Unmarshal([]byte(respHeadersStr), &reqLog.ResponseHeaders) - json.Unmarshal([]byte(middlewareTrace), &reqLog.MiddlewareTrace) - json.Unmarshal([]byte(routeTrace), &reqLog.RouteTrace) + unmarshalLogJSON(reqHeadersStr, &reqLog.RequestHeaders, "request_headers", reqLog.ID) + unmarshalLogJSON(respHeadersStr, &reqLog.ResponseHeaders, "response_headers", reqLog.ID) + unmarshalLogJSON(middlewareTrace, &reqLog.MiddlewareTrace, "middleware_trace", reqLog.ID) + unmarshalLogJSON(routeTrace, &reqLog.RouteTrace, "route_trace", reqLog.ID) return &reqLog, true } 
@@ -244,11 +239,11 @@ func (s *PostgresStore) Get(id string) (*model.RequestLog, bool) { // GetAll returns all stored request logs func (s *PostgresStore) GetAll() []*model.RequestLog { query := fmt.Sprintf(` - SELECT - id, timestamp, method, path, query, + SELECT + id, timestamp, method, path, query, COALESCE(request_headers::text, '{}'), COALESCE(response_headers::text, '{}'), - status_code, duration, request_body, response_body, error, + status_code, duration, request_body, response_body, error, COALESCE(middleware_trace::text, '[]'), COALESCE(route_trace::text, '{}') FROM %s @@ -261,11 +256,11 @@ func (s *PostgresStore) GetAll() []*model.RequestLog { // GetLatest returns the n most recent request logs func (s *PostgresStore) GetLatest(n int) []*model.RequestLog { query := fmt.Sprintf(` - SELECT - id, timestamp, method, path, query, + SELECT + id, timestamp, method, path, query, COALESCE(request_headers::text, '{}'), COALESCE(response_headers::text, '{}'), - status_code, duration, request_body, response_body, error, + status_code, duration, request_body, response_body, error, COALESCE(middleware_trace::text, '[]'), COALESCE(route_trace::text, '{}') FROM %s @@ -276,11 +271,10 @@ func (s *PostgresStore) GetLatest(n int) []*model.RequestLog { return s.queryLogs(query, n) } -// queryLogs executes a query and returns the resulting log entries func (s *PostgresStore) queryLogs(query string, args ...interface{}) []*model.RequestLog { rows, err := s.db.Query(query, args...) 
if err != nil { - log.Printf("Failed to query logs from PostgreSQL: %v", err) + log.Printf("govisual: failed to query logs from PostgreSQL: %v", err) return nil } defer rows.Close() @@ -296,7 +290,7 @@ func (s *PostgresStore) queryLogs(query string, args ...interface{}) []*model.Re routeTrace string ) - err := rows.Scan( + if err := rows.Scan( &reqLog.ID, &reqLog.Timestamp, &reqLog.Method, @@ -311,24 +305,21 @@ func (s *PostgresStore) queryLogs(query string, args ...interface{}) []*model.Re &reqLog.Error, &middlewareTrace, &routeTrace, - ) - - if err != nil { - log.Printf("Failed to scan row: %v", err) + ); err != nil { + log.Printf("govisual: failed to scan row: %v", err) continue } - // Unmarshal all JSON fields, ignoring errors - json.Unmarshal([]byte(reqHeadersStr), &reqLog.RequestHeaders) - json.Unmarshal([]byte(respHeadersStr), &reqLog.ResponseHeaders) - json.Unmarshal([]byte(middlewareTrace), &reqLog.MiddlewareTrace) - json.Unmarshal([]byte(routeTrace), &reqLog.RouteTrace) + unmarshalLogJSON(reqHeadersStr, &reqLog.RequestHeaders, "request_headers", reqLog.ID) + unmarshalLogJSON(respHeadersStr, &reqLog.ResponseHeaders, "response_headers", reqLog.ID) + unmarshalLogJSON(middlewareTrace, &reqLog.MiddlewareTrace, "middleware_trace", reqLog.ID) + unmarshalLogJSON(routeTrace, &reqLog.RouteTrace, "route_trace", reqLog.ID) logs = append(logs, &reqLog) } if err := rows.Err(); err != nil { - log.Printf("Error iterating over rows: %v", err) + log.Printf("govisual: error iterating over rows: %v", err) } return logs @@ -337,11 +328,9 @@ func (s *PostgresStore) queryLogs(query string, args ...interface{}) []*model.Re // Clear clears all stored request logs func (s *PostgresStore) Clear() error { query := fmt.Sprintf("TRUNCATE TABLE %s", s.tableName) - _, err := s.db.Exec(query) - if err != nil { + if _, err := s.db.Exec(query); err != nil { return fmt.Errorf("failed to clear logs: %w", err) } - return nil } @@ -349,3 +338,14 @@ func (s *PostgresStore) Clear() error { func 
(s *PostgresStore) Close() error { return s.db.Close() } + +// unmarshalLogJSON is shared by all SQL stores so they all report unmarshal +// errors consistently instead of silently dropping fields. +func unmarshalLogJSON(s string, v interface{}, field, logID string) { + if s == "" { + return + } + if err := json.Unmarshal([]byte(s), v); err != nil { + log.Printf("govisual: failed to unmarshal %s for log %s: %v", field, logID, err) + } +} diff --git a/internal/store/redis.go b/internal/store/redis.go index 4d26b57..2be2d44 100644 --- a/internal/store/redis.go +++ b/internal/store/redis.go @@ -6,6 +6,7 @@ import ( "fmt" "log" "sort" + "sync/atomic" "time" "github.com/doganarif/govisual/internal/model" @@ -14,11 +15,11 @@ import ( // RedisStore implements the Store interface with Redis as backend type RedisStore struct { - client *redis.Client - keyPrefix string - capacity int - ttl time.Duration - ctx context.Context + client *redis.Client + keyPrefix string + capacity int + ttl time.Duration + insertCount atomic.Uint64 } // NewRedisStore creates a new Redis-backed store @@ -26,23 +27,21 @@ func NewRedisStore(connStr string, capacity int, ttlSeconds int) (*RedisStore, e if capacity <= 0 { capacity = 100 } - if ttlSeconds <= 0 { - ttlSeconds = 86400 // Default to 24 hours + ttlSeconds = 86400 // 24h } - // Parse the Redis connection string opts, err := redis.ParseURL(connStr) if err != nil { return nil, fmt.Errorf("invalid Redis connection string: %w", err) } - // Create Redis client client := redis.NewClient(opts) - ctx := context.Background() - // Test the connection - if err := client.Ping(ctx).Err(); err != nil { + pingCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := client.Ping(pingCtx).Err(); err != nil { + client.Close() return nil, fmt.Errorf("failed to connect to Redis: %w", err) } @@ -51,216 +50,193 @@ func NewRedisStore(connStr string, capacity int, ttlSeconds int) (*RedisStore, e keyPrefix: "govisual:", 
capacity: capacity, ttl: time.Duration(ttlSeconds) * time.Second, - ctx: ctx, }, nil } -// Add adds a new request log to the store -func (s *RedisStore) Add(reqLog *model.RequestLog) { - // Convert log to JSON +// opCtx returns a short-lived context for a single Redis call. Stores must not +// hang onto a context for their entire lifetime. +func (s *RedisStore) opCtx() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), 10*time.Second) +} + +func (s *RedisStore) Add(reqLog *model.RequestLog) error { data, err := json.Marshal(reqLog) if err != nil { - reqLog.Error = fmt.Sprintf("Failed to marshal log: %v", err) - return + return fmt.Errorf("redis marshal: %w", err) } - // Key for the log - key := s.keyPrefix + reqLog.ID + ctx, cancel := s.opCtx() + defer cancel() - // Store log as a JSON string - if err := s.client.Set(s.ctx, key, data, s.ttl).Err(); err != nil { - log.Printf("Failed to store log in Redis: %v", err) - return + key := s.keyPrefix + reqLog.ID + if err := s.client.Set(ctx, key, data, s.ttl).Err(); err != nil { + return fmt.Errorf("redis set: %w", err) } - // Add to sorted set for time-ordered access score := float64(reqLog.Timestamp.UnixNano()) - if err := s.client.ZAdd(s.ctx, s.keyPrefix+"logs", &redis.Z{ + if err := s.client.ZAdd(ctx, s.keyPrefix+"logs", &redis.Z{ Score: score, Member: reqLog.ID, }).Err(); err != nil { - log.Printf("Failed to add log ID to sorted set: %v", err) + return fmt.Errorf("redis zadd: %w", err) } - // Clean up old logs - s.cleanup() + if s.insertCount.Add(1)%cleanupEveryN == 0 { + s.cleanup() + } + return nil } -// cleanup removes old logs to maintain the capacity limit func (s *RedisStore) cleanup() { - // Get the current number of logs - count, err := s.client.ZCard(s.ctx, s.keyPrefix+"logs").Result() + ctx, cancel := s.opCtx() + defer cancel() + + count, err := s.client.ZCard(ctx, s.keyPrefix+"logs").Result() if err != nil { - log.Printf("Failed to count logs: %v", err) + 
log.Printf("govisual: failed to count Redis logs: %v", err) return } - if count <= int64(s.capacity) { return } - // Get the oldest log IDs that exceed capacity - oldestIDs, err := s.client.ZRange(s.ctx, s.keyPrefix+"logs", 0, count-int64(s.capacity)-1).Result() + oldestIDs, err := s.client.ZRange(ctx, s.keyPrefix+"logs", 0, count-int64(s.capacity)-1).Result() if err != nil { - log.Printf("Failed to get oldest log IDs: %v", err) + log.Printf("govisual: failed to get oldest Redis log IDs: %v", err) return } - if len(oldestIDs) == 0 { return } - // Create a pipeline for batch operations pipe := s.client.Pipeline() - - // Remove from sorted set - pipe.ZRem(s.ctx, s.keyPrefix+"logs", oldestIDs) - - // Remove each log + // ZRem takes ...interface{}; convert from []string. + members := make([]interface{}, len(oldestIDs)) + for i, id := range oldestIDs { + members[i] = id + } + pipe.ZRem(ctx, s.keyPrefix+"logs", members...) for _, id := range oldestIDs { - pipe.Del(s.ctx, s.keyPrefix+id) + pipe.Del(ctx, s.keyPrefix+id) } - - // Execute pipeline - if _, err := pipe.Exec(s.ctx); err != nil { - log.Printf("Failed to clean up old logs: %v", err) + if _, err := pipe.Exec(ctx); err != nil { + log.Printf("govisual: failed to clean up old Redis logs: %v", err) } } -// Get retrieves a specific request log by its ID func (s *RedisStore) Get(id string) (*model.RequestLog, bool) { - key := s.keyPrefix + id + ctx, cancel := s.opCtx() + defer cancel() - // Get log data from Redis - data, err := s.client.Get(s.ctx, key).Bytes() + data, err := s.client.Get(ctx, s.keyPrefix+id).Bytes() if err != nil { if err == redis.Nil { return nil, false } - log.Printf("Failed to get log from Redis: %v", err) + log.Printf("govisual: failed to get log from Redis: %v", err) return nil, false } - // Unmarshal data var reqLog model.RequestLog if err := json.Unmarshal(data, &reqLog); err != nil { - log.Printf("Failed to unmarshal log data: %v", err) + log.Printf("govisual: failed to unmarshal Redis log: 
%v", err) return nil, false } - return &reqLog, true } -// GetAll returns all stored request logs func (s *RedisStore) GetAll() []*model.RequestLog { - // Get all log IDs from the sorted set, in reverse order (newest first) - ids, err := s.client.ZRevRange(s.ctx, s.keyPrefix+"logs", 0, -1).Result() + ctx, cancel := s.opCtx() + defer cancel() + + ids, err := s.client.ZRevRange(ctx, s.keyPrefix+"logs", 0, -1).Result() if err != nil { - log.Printf("Failed to get log IDs: %v", err) + log.Printf("govisual: failed to get Redis log IDs: %v", err) return nil } - - return s.getLogs(ids) + return s.getLogs(ctx, ids) } -// GetLatest returns the n most recent request logs func (s *RedisStore) GetLatest(n int) []*model.RequestLog { - // Get the n newest log IDs - ids, err := s.client.ZRevRange(s.ctx, s.keyPrefix+"logs", 0, int64(n-1)).Result() + ctx, cancel := s.opCtx() + defer cancel() + + ids, err := s.client.ZRevRange(ctx, s.keyPrefix+"logs", 0, int64(n-1)).Result() if err != nil { - log.Printf("Failed to get latest log IDs: %v", err) + log.Printf("govisual: failed to get latest Redis log IDs: %v", err) return nil } - - return s.getLogs(ids) + return s.getLogs(ctx, ids) } -// getLogs retrieves logs by their IDs -func (s *RedisStore) getLogs(ids []string) []*model.RequestLog { +func (s *RedisStore) getLogs(ctx context.Context, ids []string) []*model.RequestLog { if len(ids) == 0 { return nil } - // Use a pipeline for batch retrieval pipe := s.client.Pipeline() - cmds := make(map[string]*redis.StringCmd) - - // Queue up the get commands + cmds := make(map[string]*redis.StringCmd, len(ids)) for _, id := range ids { - cmds[id] = pipe.Get(s.ctx, s.keyPrefix+id) + cmds[id] = pipe.Get(ctx, s.keyPrefix+id) } - - // Execute pipeline - _, err := pipe.Exec(s.ctx) - if err != nil && err != redis.Nil { - log.Printf("Failed to execute pipeline: %v", err) + if _, err := pipe.Exec(ctx); err != nil && err != redis.Nil { + log.Printf("govisual: failed to execute Redis pipeline: %v", err) 
return nil } - // Process results logs := make([]*model.RequestLog, 0, len(ids)) - idToLog := make(map[string]*model.RequestLog) - for id, cmd := range cmds { data, err := cmd.Bytes() if err != nil { if err != redis.Nil { - log.Printf("Failed to get log %s: %v", id, err) + log.Printf("govisual: failed to get Redis log %s: %v", id, err) } continue } - var reqLog model.RequestLog if err := json.Unmarshal(data, &reqLog); err != nil { - log.Printf("Failed to unmarshal log data for %s: %v", id, err) + log.Printf("govisual: failed to unmarshal Redis log %s: %v", id, err) continue } - logs = append(logs, &reqLog) - idToLog[id] = &reqLog } - // Sort logs by timestamp (newest first) sort.Slice(logs, func(i, j int) bool { return logs[i].Timestamp.After(logs[j].Timestamp) }) - return logs } -// Clear removes all logs from the store func (s *RedisStore) Clear() error { - // Get all log IDs - ids, err := s.client.ZRange(s.ctx, s.keyPrefix+"logs", 0, -1).Result() + ctx, cancel := s.opCtx() + defer cancel() + + ids, err := s.client.ZRange(ctx, s.keyPrefix+"logs", 0, -1).Result() if err != nil { return fmt.Errorf("failed to get log IDs: %w", err) } - // Create a pipeline for batch operations pipe := s.client.Pipeline() - - // Remove from sorted set - pipe.ZRem(s.ctx, s.keyPrefix+"logs", ids) - - // Remove each log from Redis - for _, id := range ids { - pipe.Unlink(s.ctx, s.keyPrefix+id) + if len(ids) > 0 { + members := make([]interface{}, len(ids)) + for i, id := range ids { + members[i] = id + } + pipe.ZRem(ctx, s.keyPrefix+"logs", members...) 
+ for _, id := range ids { + pipe.Unlink(ctx, s.keyPrefix+id) + } } - - // Execute pipeline - if _, err := pipe.Exec(s.ctx); err != nil { + if _, err := pipe.Exec(ctx); err != nil { return fmt.Errorf("failed to clear logs: %w", err) } - // Delete the sorted set - if err := s.client.Del(s.ctx, s.keyPrefix+"logs").Err(); err != nil { + if err := s.client.Del(ctx, s.keyPrefix+"logs").Err(); err != nil { return fmt.Errorf("failed to delete sorted set: %w", err) } - return nil } -// Close closes the Redis client connection func (s *RedisStore) Close() error { return s.client.Close() } diff --git a/internal/store/sqlite.go b/internal/store/sqlite.go index aab4013..aa7a190 100644 --- a/internal/store/sqlite.go +++ b/internal/store/sqlite.go @@ -5,126 +5,93 @@ import ( "encoding/json" "fmt" "log" - "regexp" - "sync" + "sync/atomic" "github.com/doganarif/govisual/internal/model" - // Don't import and register SQLite driver automatically - // _ "github.com/ncruces/go-sqlite3/driver" - // _ "github.com/ncruces/go-sqlite3/embed" ) -// SQLiteStore implements the Store interface with SQLite as backend +// SQLiteStore implements the Store interface with SQLite as backend. +// +// SQLite driver registration is the caller's responsibility — govisual does +// not import a driver to avoid forcing a specific implementation on users. +// Register your preferred driver (e.g. mattn/go-sqlite3 or ncruces/go-sqlite3) +// before calling NewSQLiteStore, or use NewSQLiteStoreWithDB with a pre-built +// *sql.DB. 
type SQLiteStore struct { - db *sql.DB - tableName string - capacity int - // Add a flag to track if we own the connection + db *sql.DB + tableName string + capacity int ownsConnection bool + insertCount atomic.Uint64 } -// RegisterSQLiteDriver registers the SQLite driver with database/sql -// This should only be called if you don't already have a SQLite driver registered -var registerOnce sync.Once -var registerError error - -func RegisterSQLiteDriver() error { - registerOnce.Do(func() { - // Dynamically import and register - // Attempt to import the SQLite driver - registerError = initSQLiteDriver() - }) - return registerError -} - -// initSQLiteDriver is a helper function that actually initializes the driver -func initSQLiteDriver() error { - // We're not directly importing the driver to avoid auto-registration - // Your application should use its preferred SQLite driver - // This should only be called if you don't already have a SQLite driver - return fmt.Errorf("you need to register a SQLite driver or use WithSQLiteStorageDB with an existing database connection") -} - -// isValidTableName checks if a table name contains only alphanumeric and underscore characters -func isValidTableName(tableName string) bool { - match, _ := regexp.MatchString(`^[a-zA-Z0-9_]+$`, tableName) - return match -} - -// NewSQLiteStore creates a new SQLite-backed store +// NewSQLiteStore creates a new SQLite-backed store. +// dbPath is forwarded to sql.Open("sqlite3", dbPath); ensure a SQLite driver +// is already registered under the name "sqlite3". 
func NewSQLiteStore(dbPath, tableName string, capacity int) (*SQLiteStore, error) { if capacity <= 0 { capacity = 100 } - // Validate table name to prevent SQL injection - if !isValidTableName(tableName) { - return nil, fmt.Errorf("invalid table name: table name can only contain letters, numbers, and underscores") + if !IsValidTableName(tableName) { + return nil, fmt.Errorf("invalid table name %q: must match [A-Za-z_][A-Za-z0-9_]*", tableName) } - // Connect to the database db, err := sql.Open("sqlite3", dbPath) if err != nil { return nil, fmt.Errorf("failed to open SQLite DB: %w", err) } - // Test the connection if err := db.Ping(); err != nil { + db.Close() return nil, fmt.Errorf("failed to ping SQLite DB: %w", err) } - store := &SQLiteStore{ + s := &SQLiteStore{ db: db, tableName: tableName, capacity: capacity, ownsConnection: true, } - // Create the table if it doesn't exist - if err := store.createTable(); err != nil { + if err := s.createTable(); err != nil { db.Close() return nil, fmt.Errorf("failed to create table: %w", err) } - return store, nil + return s, nil } -// NewSQLiteStoreWithDB creates a new SQLite store with an existing database connection +// NewSQLiteStoreWithDB creates a new SQLite store with an existing database connection. 
func NewSQLiteStoreWithDB(db *sql.DB, tableName string, capacity int) (*SQLiteStore, error) { if db == nil { return nil, fmt.Errorf("database connection cannot be nil") } - if capacity <= 0 { capacity = 100 } - - // Validate table name to prevent SQL injection - if !isValidTableName(tableName) { - return nil, fmt.Errorf("invalid table name: table name can only contain letters, numbers, and underscores") + if !IsValidTableName(tableName) { + return nil, fmt.Errorf("invalid table name %q: must match [A-Za-z_][A-Za-z0-9_]*", tableName) } - // Test the connection if err := db.Ping(); err != nil { return nil, fmt.Errorf("failed to ping SQLite DB: %w", err) } - store := &SQLiteStore{ + s := &SQLiteStore{ db: db, tableName: tableName, capacity: capacity, ownsConnection: false, } - // Create the table if it doesn't exist - if err := store.createTable(); err != nil { + if err := s.createTable(); err != nil { return nil, fmt.Errorf("failed to create table: %w", err) } - return store, nil + return s, nil } -// createTable creates the required table if it doesn't exist func (s *SQLiteStore) createTable() error { query := fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( @@ -146,21 +113,17 @@ func (s *SQLiteStore) createTable() error { ) `, s.tableName) - _, err := s.db.Exec(query) - if err != nil { + if _, err := s.db.Exec(query); err != nil { return err } - // Create index on timestamp for faster retrieval indexQuery := fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s_timestamp_idx ON %s(timestamp DESC)", s.tableName, s.tableName) - _, err = s.db.Exec(indexQuery) - + _, err := s.db.Exec(indexQuery) return err } -// Add adds a new request log to the store -func (s *SQLiteStore) Add(reqLog *model.RequestLog) { +func (s *SQLiteStore) Add(reqLog *model.RequestLog) error { reqHeaders := prepareJSON(reqLog.RequestHeaders) respHeaders := prepareJSON(reqLog.ResponseHeaders) @@ -181,7 +144,7 @@ func (s *SQLiteStore) Add(reqLog *model.RequestLog) { query := fmt.Sprintf(` INSERT OR REPLACE INTO 
%s ( id, timestamp, method, path, query, request_headers, response_headers, - status_code, duration, request_body, response_body, error, + status_code, duration, request_body, response_body, error, middleware_trace, route_trace ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) `, s.tableName) @@ -203,29 +166,27 @@ func (s *SQLiteStore) Add(reqLog *model.RequestLog) { middlewareTrace, routeTrace, ) - if err != nil { - log.Printf("Failed to store request log in SQLite: %v", err) + return fmt.Errorf("sqlite insert: %w", err) } - s.cleanup() + if s.insertCount.Add(1)%cleanupEveryN == 0 { + s.cleanup() + } + return nil } -// cleanup removes old logs to maintain the capacity limit func (s *SQLiteStore) cleanup() { countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s", s.tableName) var count int - err := s.db.QueryRow(countQuery).Scan(&count) - if err != nil { - log.Printf("Failed to count logs: %v", err) + if err := s.db.QueryRow(countQuery).Scan(&count); err != nil { + log.Printf("govisual: failed to count logs: %v", err) return } - if count <= s.capacity { return } - // Delete oldest logs deleteQuery := fmt.Sprintf(` DELETE FROM %s WHERE id IN ( @@ -235,20 +196,18 @@ func (s *SQLiteStore) cleanup() { ) `, s.tableName, s.tableName) - _, err = s.db.Exec(deleteQuery, count-s.capacity) - if err != nil { - log.Printf("Failed to clean up old logs: %v", err) + if _, err := s.db.Exec(deleteQuery, count-s.capacity); err != nil { + log.Printf("govisual: failed to clean up old logs: %v", err) } } -// Get retrieves a specific request log by its ID func (s *SQLiteStore) Get(id string) (*model.RequestLog, bool) { query := fmt.Sprintf(` - SELECT - id, timestamp, method, path, query, + SELECT + id, timestamp, method, path, query, COALESCE(request_headers, '{}'), COALESCE(response_headers, '{}'), - status_code, duration, request_body, response_body, error, + status_code, duration, request_body, response_body, error, COALESCE(middleware_trace, '[]'), COALESCE(route_trace, '{}') FROM %s 
@@ -279,31 +238,29 @@ func (s *SQLiteStore) Get(id string) (*model.RequestLog, bool) { &middlewareTrace, &routeTrace, ) - if err != nil { if err == sql.ErrNoRows { return nil, false } - log.Printf("Failed to get request log from SQLite: %v", err) + log.Printf("govisual: failed to get request log from SQLite: %v", err) return nil, false } - json.Unmarshal([]byte(reqHeadersStr), &reqLog.RequestHeaders) - json.Unmarshal([]byte(respHeadersStr), &reqLog.ResponseHeaders) - json.Unmarshal([]byte(middlewareTrace), &reqLog.MiddlewareTrace) - json.Unmarshal([]byte(routeTrace), &reqLog.RouteTrace) + unmarshalLogJSON(reqHeadersStr, &reqLog.RequestHeaders, "request_headers", reqLog.ID) + unmarshalLogJSON(respHeadersStr, &reqLog.ResponseHeaders, "response_headers", reqLog.ID) + unmarshalLogJSON(middlewareTrace, &reqLog.MiddlewareTrace, "middleware_trace", reqLog.ID) + unmarshalLogJSON(routeTrace, &reqLog.RouteTrace, "route_trace", reqLog.ID) return &reqLog, true } -// GetAll returns all stored request logs func (s *SQLiteStore) GetAll() []*model.RequestLog { query := fmt.Sprintf(` - SELECT - id, timestamp, method, path, query, + SELECT + id, timestamp, method, path, query, COALESCE(request_headers, '{}'), COALESCE(response_headers, '{}'), - status_code, duration, request_body, response_body, error, + status_code, duration, request_body, response_body, error, COALESCE(middleware_trace, '[]'), COALESCE(route_trace, '{}') FROM %s @@ -313,14 +270,13 @@ func (s *SQLiteStore) GetAll() []*model.RequestLog { return s.queryLogs(query) } -// GetLatest returns the n most recent request logs func (s *SQLiteStore) GetLatest(n int) []*model.RequestLog { query := fmt.Sprintf(` - SELECT - id, timestamp, method, path, query, + SELECT + id, timestamp, method, path, query, COALESCE(request_headers, '{}'), COALESCE(response_headers, '{}'), - status_code, duration, request_body, response_body, error, + status_code, duration, request_body, response_body, error, COALESCE(middleware_trace, '[]'), 
COALESCE(route_trace, '{}') FROM %s @@ -331,17 +287,15 @@ func (s *SQLiteStore) GetLatest(n int) []*model.RequestLog { return s.queryLogs(query, n) } -// queryLogs executes a query and returns the resulting log entries func (s *SQLiteStore) queryLogs(query string, args ...interface{}) []*model.RequestLog { rows, err := s.db.Query(query, args...) if err != nil { - log.Printf("Failed to query logs from SQLite: %v", err) + log.Printf("govisual: failed to query logs from SQLite: %v", err) return nil } defer rows.Close() var logs []*model.RequestLog - for rows.Next() { var ( reqLog model.RequestLog @@ -350,8 +304,7 @@ func (s *SQLiteStore) queryLogs(query string, args ...interface{}) []*model.Requ middlewareTrace string routeTrace string ) - - err := rows.Scan( + if err := rows.Scan( &reqLog.ID, &reqLog.Timestamp, &reqLog.Method, @@ -366,42 +319,35 @@ func (s *SQLiteStore) queryLogs(query string, args ...interface{}) []*model.Requ &reqLog.Error, &middlewareTrace, &routeTrace, - ) - - if err != nil { - log.Printf("Failed to scan row: %v", err) + ); err != nil { + log.Printf("govisual: failed to scan row: %v", err) continue } - json.Unmarshal([]byte(reqHeadersStr), &reqLog.RequestHeaders) - json.Unmarshal([]byte(respHeadersStr), &reqLog.ResponseHeaders) - json.Unmarshal([]byte(middlewareTrace), &reqLog.MiddlewareTrace) - json.Unmarshal([]byte(routeTrace), &reqLog.RouteTrace) + unmarshalLogJSON(reqHeadersStr, &reqLog.RequestHeaders, "request_headers", reqLog.ID) + unmarshalLogJSON(respHeadersStr, &reqLog.ResponseHeaders, "response_headers", reqLog.ID) + unmarshalLogJSON(middlewareTrace, &reqLog.MiddlewareTrace, "middleware_trace", reqLog.ID) + unmarshalLogJSON(routeTrace, &reqLog.RouteTrace, "route_trace", reqLog.ID) logs = append(logs, &reqLog) } if err := rows.Err(); err != nil { - log.Printf("Error iterating over rows: %v", err) + log.Printf("govisual: error iterating over rows: %v", err) } return logs } -// Clear removes all logs from the store func (s *SQLiteStore) 
Clear() error { query := fmt.Sprintf("DELETE FROM %s", s.tableName) - _, err := s.db.Exec(query) - if err != nil { + if _, err := s.db.Exec(query); err != nil { return fmt.Errorf("failed to clear logs: %w", err) } - return nil } -// Close closes the database connection func (s *SQLiteStore) Close() error { - // Only close the connection if we own it if s.ownsConnection { return s.db.Close() } diff --git a/internal/store/store.go b/internal/store/store.go index beade9a..4fd72dd 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -1,11 +1,16 @@ package store -import "github.com/doganarif/govisual/internal/model" +import ( + "regexp" + + "github.com/doganarif/govisual/internal/model" +) // Store defines the interface for all storage backends type Store interface { - // Add adds a new request log to the store - Add(log *model.RequestLog) + // Add stores a new request log. Returns an error so callers can surface + // storage failures (otherwise the dashboard silently drops entries). + Add(log *model.RequestLog) error // Get retrieves a specific request log by its ID Get(id string) (*model.RequestLog, bool) @@ -22,3 +27,14 @@ type Store interface { // Close closes any open connections Close() error } + +// validTableName matches identifiers safe to inject into SQL via fmt.Sprintf. +// Letters, digits and underscore only, with no leading digit — no quoting, dots, or whitespace. +var validTableName = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`) + +// IsValidTableName reports whether tableName is safe to interpolate into a +// SQL statement. It is consulted by every SQL-backed store before any +// query is constructed; never bypass it. 
+func IsValidTableName(tableName string) bool { + return validTableName.MatchString(tableName) +} diff --git a/options.go b/options.go index 45ec23d..25f75fa 100644 --- a/options.go +++ b/options.go @@ -1,8 +1,11 @@ package govisual import ( + "context" + "crypto/subtle" "database/sql" "fmt" + "net/http" "path/filepath" "strings" "time" @@ -11,6 +14,11 @@ import ( "github.com/doganarif/govisual/internal/store" ) +// DashboardAuth authorizes a request to the dashboard. Return true to allow, +// false to deny (govisual sends an HTTP 401). Implementations should be +// constant-time when comparing secrets. +type DashboardAuth func(r *http.Request) bool + type Config struct { MaxRequests int @@ -20,6 +28,11 @@ type Config struct { LogResponseBody bool + // MaxBodyBytes caps the captured request and response body size. + // 0 (default) means use middleware.DefaultMaxBodyBytes (1 MiB). + // Set to -1 to disable the cap entirely (NOT recommended). + MaxBodyBytes int + IgnorePaths []string // OpenTelemetry configuration @@ -58,6 +71,41 @@ type Config struct { ProfileThreshold time.Duration MaxProfileMetrics int + + // Dashboard security ---------------------------------------------------- + + // DashboardAuth, if set, must approve every request to the dashboard. + // If nil, the dashboard is fully open — only safe for local dev. + DashboardAuth DashboardAuth + + // LocalhostOnly, when true, rejects dashboard requests whose remote address + // is not a loopback IP. This is the safest default for "I'm just debugging + // locally" — even with the rest of the server bound to 0.0.0.0. + LocalhostOnly bool + + // EnableReplay enables the POST /__viz/api/replay endpoint, which lets the + // dashboard fire arbitrary HTTP requests from the server. Disabled by + // default because it is a powerful SSRF primitive if the dashboard is + // reachable by an attacker. + EnableReplay bool + + // ExposeSystemInfo controls whether the GET /__viz/api/system-info endpoint + // is enabled. 
Disabled by default; enabling exposes runtime info (hostname, + // Go version, memory stats). + ExposeSystemInfo bool + + // ExposeEnvVars is an explicit allowlist of environment variable names that + // the system-info endpoint may surface. Anything not in this set is omitted + // entirely (NOT redacted) so an attacker cannot infer the existence of a + // sensitive name. + ExposeEnvVars []string + + // ShutdownContext, if set, triggers graceful shutdown of govisual-owned + // resources (storage backends, OpenTelemetry tracer provider) when the + // context is cancelled. This replaces the prior behavior of registering a + // global signal handler that called os.Exit — a library has no business + // killing the host process. + ShutdownContext context.Context } // Option is a function that modifies the configuration @@ -91,6 +139,17 @@ func WithResponseBodyLogging(enabled bool) Option { } } +// WithMaxBodyBytes caps the captured request and response body size. +// Values: +// - 0: use the package default (1 MiB) +// - >0: explicit cap in bytes +// - <0: disable cap (unbounded — be careful with large downloads) +func WithMaxBodyBytes(n int) Option { + return func(c *Config) { + c.MaxBodyBytes = n + } +} + // WithIgnorePaths sets the path patterns to ignore func WithIgnorePaths(patterns ...string) Option { return func(c *Config) { @@ -193,30 +252,25 @@ func WithMongoDBStorage(uri, databaseName, collectionName string) Option { } } -// ShouldIgnorePath checks if a path should be ignored based on the configured patterns -// ShouldIgnorePath checks if a path should be ignored based on the configured patterns +// ShouldIgnorePath checks if a path should be ignored based on the configured patterns. func (c *Config) ShouldIgnorePath(path string) bool { - // First check if it's the dashboard path which should always be ignored to prevent recursive logging + // The dashboard itself must always be ignored, otherwise opening it + // would recursively log every poll. 
if path == c.DashboardPath || strings.HasPrefix(path, c.DashboardPath+"/") { return true } - // Then check against provided ignore patterns for _, pattern := range c.IgnorePaths { - matched, err := filepath.Match(pattern, path) - if err == nil && matched { + if matched, err := filepath.Match(pattern, path); err == nil && matched { return true } - - // Special handling for path groups with trailing slash + // Trailing-slash patterns are treated as "prefix match". if len(pattern) > 0 && pattern[len(pattern)-1] == '/' { - // If pattern ends with /, check if path starts with pattern - if len(path) >= len(pattern) && path[:len(pattern)] == pattern { + if strings.HasPrefix(path, pattern) { return true } } } - return false } @@ -248,6 +302,78 @@ func WithMaxProfileMetrics(max int) Option { } } +// WithDashboardAuth installs a custom authentication function for the dashboard. +// The function runs on every dashboard request and must return true to allow access. +func WithDashboardAuth(fn DashboardAuth) Option { + return func(c *Config) { + c.DashboardAuth = fn + } +} + +// WithBasicAuth protects the dashboard with HTTP Basic Auth using a constant-time +// comparison. Both username and password are required. +func WithBasicAuth(username, password string) Option { + expectedUser := []byte(username) + expectedPass := []byte(password) + return func(c *Config) { + c.DashboardAuth = func(r *http.Request) bool { + user, pass, ok := r.BasicAuth() + if !ok { + return false + } + userOK := subtle.ConstantTimeCompare([]byte(user), expectedUser) == 1 + passOK := subtle.ConstantTimeCompare([]byte(pass), expectedPass) == 1 + return userOK && passOK + } + } +} + +// WithLocalhostOnly restricts the dashboard to requests originating from a +// loopback address. Combine with WithDashboardAuth/WithBasicAuth for defense +// in depth. 
+func WithLocalhostOnly() Option { + return func(c *Config) { + c.LocalhostOnly = true + } +} + +// WithReplayEnabled enables the dashboard's /api/replay endpoint. Disabled by +// default because the endpoint, if reachable, lets a caller make the server +// perform arbitrary outbound HTTP requests (an SSRF primitive). Only enable +// behind authentication and/or localhost-only access. +func WithReplayEnabled(enabled bool) Option { + return func(c *Config) { + c.EnableReplay = enabled + } +} + +// WithSystemInfo enables the dashboard's /api/system-info endpoint and +// optionally sets the allowlist of environment variable names to expose. +// Pass no names to enable the endpoint but expose nothing (memory/runtime +// info only). +func WithSystemInfo(envAllowlist ...string) Option { + return func(c *Config) { + c.ExposeSystemInfo = true + c.ExposeEnvVars = append(c.ExposeEnvVars, envAllowlist...) + } +} + +// WithShutdownContext wires govisual's internal cleanup (storage backends, +// OpenTelemetry shutdown) to a caller-provided context. When the context is +// cancelled, govisual releases its resources. Replaces the prior behavior of +// installing a global signal handler that called os.Exit. +// +// Note: govisual spawns one goroutine that blocks on ctx.Done() for the +// lifetime of the wrapped handler. If you never cancel the context (for +// example, by passing context.Background()), that goroutine is retained for +// the process lifetime — harmless in long-running services, but tests should +// pass a cancellable context (e.g. t.Context()) to avoid leaks across cases. 
+func WithShutdownContext(ctx context.Context) Option { + return func(c *Config) { + c.ShutdownContext = ctx + } +} + // defaultConfig returns the default configuration func defaultConfig() *Config { return &Config{ @@ -255,6 +381,7 @@ func defaultConfig() *Config { DashboardPath: "/__viz", LogRequestBody: false, LogResponseBody: false, + MaxBodyBytes: 0, // 0 => use middleware.DefaultMaxBodyBytes IgnorePaths: []string{}, EnableOpenTelemetry: false, ServiceName: "govisual", @@ -269,5 +396,7 @@ func defaultConfig() *Config { ProfileType: profiling.ProfileAll, ProfileThreshold: 10 * time.Millisecond, MaxProfileMetrics: 1000, + EnableReplay: false, + ExposeSystemInfo: false, } } diff --git a/wrap.go b/wrap.go index d3c1009..62d61c7 100644 --- a/wrap.go +++ b/wrap.go @@ -3,12 +3,10 @@ package govisual import ( "context" "log" + "net" "net/http" - "os" - "os/signal" "strings" - "sync" - "syscall" + "time" "github.com/doganarif/govisual/internal/dashboard" "github.com/doganarif/govisual/internal/middleware" @@ -17,68 +15,19 @@ import ( "github.com/doganarif/govisual/internal/telemetry" ) -var ( - // Global signal handler to ensure we only have one - signalOnce sync.Once - shutdownFuncs []func(context.Context) error - shutdownMutex sync.Mutex -) - -// addShutdownFunc adds a shutdown function to be called on signal -func addShutdownFunc(fn func(context.Context) error) { - if fn == nil { - log.Println("Warning: Attempted to register nil shutdown function, ignoring") - return - } - shutdownMutex.Lock() - defer shutdownMutex.Unlock() - shutdownFuncs = append(shutdownFuncs, fn) -} - -// setupSignalHandler sets up a single signal handler for all cleanup operations -func setupSignalHandler() { - signalOnce.Do(func() { - signals := make(chan os.Signal, 1) - signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) - - go func() { - sig := <-signals - log.Printf("Received shutdown signal (%v), cleaning up...", sig) - - ctx := context.Background() - shutdownMutex.Lock() - funcs := 
make([]func(context.Context) error, len(shutdownFuncs)) - copy(funcs, shutdownFuncs) - shutdownMutex.Unlock() - - // Execute all shutdown functions - for _, fn := range funcs { - if err := fn(ctx); err != nil { - log.Printf("Error during shutdown: %v", err) - } - } - - log.Println("Cleanup completed, exiting...") - - // Stop listening for more signals and exit - signal.Stop(signals) - os.Exit(0) - }() - }) -} - -// Wrap wraps an http.Handler with request visualization middleware +// Wrap wraps an http.Handler with the govisual request visualization middleware +// and mounts the dashboard at config.DashboardPath. Pass options to customize +// behavior. To trigger graceful shutdown of storage and telemetry resources, +// pass WithShutdownContext — govisual will release its resources when that +// context is cancelled. Govisual deliberately does NOT register a signal +// handler; that is the host application's job. func Wrap(handler http.Handler, opts ...Option) http.Handler { - // Apply options to default config config := defaultConfig() for _, opt := range opts { opt(config) } - // Create store based on configuration var requestStore store.Store - var err error - storeConfig := &store.StorageConfig{ Type: config.StorageType, Capacity: config.MaxRequests, @@ -87,41 +36,39 @@ func Wrap(handler http.Handler, opts ...Option) http.Handler { TTL: config.RedisTTL, ExistingDB: config.ExistingDB, } - - requestStore, err = store.NewStore(storeConfig) + rs, err := store.NewStore(storeConfig) if err != nil { - log.Printf("Failed to create configured storage backend: %v. Falling back to in-memory storage.", err) + log.Printf("govisual: failed to create configured storage backend: %v. 
Falling back to in-memory storage.", err) requestStore = store.NewInMemoryStore(config.MaxRequests) + } else { + requestStore = rs } - // Add store cleanup to shutdown functions - addShutdownFunc(func(ctx context.Context) error { - if err := requestStore.Close(); err != nil { - log.Printf("Error closing storage: %v", err) - return err - } - return nil - }) - - // Create profiler if enabled var profiler *profiling.Profiler if config.EnableProfiling { profiler = profiling.NewProfiler(config.MaxProfileMetrics) - profiler.SetEnabled(config.EnableProfiling) + profiler.SetEnabled(true) profiler.SetProfileType(config.ProfileType) profiler.SetThreshold(config.ProfileThreshold) - log.Printf("Performance profiling enabled with threshold: %v", config.ProfileThreshold) + log.Printf("govisual: performance profiling enabled (threshold=%v)", config.ProfileThreshold) } - // Create middleware wrapper with profiling support var wrapped http.Handler if profiler != nil { - wrapped = middleware.WrapWithProfiling(handler, requestStore, config.LogRequestBody, config.LogResponseBody, config, profiler) + wrapped = middleware.WrapWithProfilingAndLimits( + handler, requestStore, + config.LogRequestBody, config.LogResponseBody, + config, profiler, config.effectiveMaxBody(), + ) } else { - wrapped = middleware.Wrap(handler, requestStore, config.LogRequestBody, config.LogResponseBody, config) + wrapped = middleware.WrapWithLimits( + handler, requestStore, + config.LogRequestBody, config.LogResponseBody, + config, config.effectiveMaxBody(), + ) } - // Initialize OpenTelemetry if enabled + var otelShutdown func(context.Context) error if config.EnableOpenTelemetry { ctx := context.Background() otelConfig := telemetry.Config{ @@ -133,32 +80,91 @@ func Wrap(handler http.Handler, opts ...Option) http.Handler { } shutdown, err := telemetry.InitTracer(ctx, otelConfig) if err != nil { - log.Printf("Failed to initialize OpenTelemetry: %v", err) + log.Printf("govisual: failed to initialize OpenTelemetry: 
%v", err) } else { - log.Printf("OpenTelemetry initialized with service name: %s, endpoint: %s", config.ServiceName, config.OTelEndpoint) - - // Add OpenTelemetry shutdown to shutdown functions - addShutdownFunc(shutdown) - - // Wrap with OpenTelemetry middleware + log.Printf("govisual: OpenTelemetry initialized (service=%s endpoint=%s)", config.ServiceName, config.OTelEndpoint) + otelShutdown = shutdown wrapped = middleware.NewOTelMiddleware(wrapped, config.ServiceName, config.ServiceVersion) } } - // Set up the single signal handler - setupSignalHandler() + if config.ShutdownContext != nil { + // NOTE: this goroutine waits on ctx.Done() and is retained for the + // process lifetime if the context is never cancelled. Callers passing a + // non-cancellable context (e.g. context.Background()) should be aware + // of this — in tests, prefer t.Context() or a cancellable context. + go func(ctx context.Context, st store.Store, shutdown func(context.Context) error) { + <-ctx.Done() + log.Printf("govisual: shutdown context cancelled, releasing resources") + if shutdown != nil { + // Give OTel a real deadline to flush spans, independent of + // the parent context (which is already cancelled). 
+ shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + if err := shutdown(shutdownCtx); err != nil { + log.Printf("govisual: error shutting down OpenTelemetry: %v", err) + } + cancel() + } + if err := st.Close(); err != nil { + log.Printf("govisual: error closing storage: %v", err) + } + }(config.ShutdownContext, requestStore, otelShutdown) + } + + dashHandler := dashboard.NewHandler(requestStore, profiler, dashboard.HandlerOptions{ + EnableReplay: config.EnableReplay, + ExposeSystemInfo: config.ExposeSystemInfo, + ExposeEnvVars: config.ExposeEnvVars, + }) - // Create dashboard handler with profiler - dashHandler := dashboard.NewHandler(requestStore, profiler) + guardedDash := guardDashboard(dashHandler, config) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if strings.HasPrefix(r.URL.Path, config.DashboardPath) { - // Handle the dashboard routes - http.StripPrefix(config.DashboardPath, dashHandler).ServeHTTP(w, r) + http.StripPrefix(config.DashboardPath, guardedDash).ServeHTTP(w, r) return } - - // Otherwise, serve the application wrapped.ServeHTTP(w, r) }) } + +// guardDashboard wraps the dashboard handler with localhost-only and +// authentication checks per the configuration. +func guardDashboard(h http.Handler, config *Config) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if config.LocalhostOnly && !isLoopback(r) { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + if config.DashboardAuth != nil && !config.DashboardAuth(r) { + // Surface a Basic challenge so browsers prompt the user; harmless + // when a custom auth scheme is in use. + w.Header().Set("WWW-Authenticate", `Basic realm="govisual"`) + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + h.ServeHTTP(w, r) + }) +} + +// isLoopback reports whether the request's remote address is a loopback IP. 
+func isLoopback(r *http.Request) bool { + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + host = r.RemoteAddr + } + ip := net.ParseIP(host) + if ip == nil { + return false + } + return ip.IsLoopback() +} + +// effectiveMaxBody resolves the configured MaxBodyBytes against the package +// default. 0 means "use default"; negative means "no cap". +func (c *Config) effectiveMaxBody() int { + if c.MaxBodyBytes == 0 { + return middleware.DefaultMaxBodyBytes + } + return c.MaxBodyBytes +}