diff --git a/ws/server.go b/ws/server.go index 9c914b2..a969d06 100644 --- a/ws/server.go +++ b/ws/server.go @@ -18,17 +18,33 @@ func SetUpgrader(u websocket.Upgrader) { upgrader = u } +func parseNamespaces(raw string) []string { + if raw == "" { + return nil + } + + parts := strings.Split(raw, ",") + namespaces := make([]string, 0, len(parts)) + for _, part := range parts { + namespace := strings.TrimSpace(part) + if namespace == "" { + continue + } + namespaces = append(namespaces, namespace) + } + if len(namespaces) == 0 { + return nil + } + return namespaces +} + // LogSocketHandler upgrades the HTTP connection to a WebSocket and streams // log entries to the client. An optional "namespaces" query parameter // (comma-separated) filters which namespaces the client receives. func LogSocketHandler(w http.ResponseWriter, r *http.Request) { // Get namespaces from query parameter, comma-separated. // Empty or missing means all namespaces. - namespacesParam := r.URL.Query().Get("namespaces") - var namespaces []string - if namespacesParam != "" { - namespaces = strings.Split(namespacesParam, ",") - } + namespaces := parseNamespaces(r.URL.Query().Get("namespaces")) conn, err := upgrader.Upgrade(w, r, nil) if err != nil { diff --git a/ws/server_test.go b/ws/server_test.go index 1421fe0..759c3a9 100644 --- a/ws/server_test.go +++ b/ws/server_test.go @@ -125,3 +125,67 @@ func TestLogSocketHandler_NamespaceFilter(t *testing.T) { t.Errorf("output = %q, want to contain 'should arrive'", entry.Output) } } + +func TestParseNamespaces(t *testing.T) { + tests := []struct { + name string + raw string + want []string + }{ + {name: "empty", raw: "", want: nil}, + {name: "comma separated", raw: "api,worker", want: []string{"api", "worker"}}, + {name: "trims whitespace", raw: " api, worker ,jobs ", want: []string{"api", "worker", "jobs"}}, + {name: "skips empty values", raw: "api,, worker, ", want: []string{"api", "worker"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := parseNamespaces(tt.raw) + if len(got) != len(tt.want) { + t.Fatalf("len(parseNamespaces(%q)) = %d, want %d (%v)", tt.raw, len(got), len(tt.want), got) + } + for i := range tt.want { + if got[i] != tt.want[i] { + t.Fatalf("parseNamespaces(%q)[%d] = %q, want %q", tt.raw, i, got[i], tt.want[i]) + } + } + }) + } +} + +func TestLogSocketHandler_NamespaceFilter_TrimsWhitespace(t *testing.T) { + SetUpgrader(websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + }) + defer SetUpgrader(websocket.Upgrader{}) + + server := httptest.NewServer(http.HandlerFunc(LogSocketHandler)) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws?namespaces=%20filtered-ns%20,%20other-ns%20" + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("failed to connect: %v", err) + } + defer conn.Close() + + filteredLogger := logger.NewLogger("filtered-ns") + filteredLogger.Info("trimmed namespace should arrive") + + conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, message, err := conn.ReadMessage() + if err != nil { + t.Fatalf("failed to read message: %v", err) + } + + var entry logger.Entry + if err := json.Unmarshal(message, &entry); err != nil { + t.Fatalf("failed to unmarshal entry: %v", err) + } + if entry.Namespace != "filtered-ns" { + t.Fatalf("namespace = %q, want filtered-ns", entry.Namespace) + } + if !strings.Contains(entry.Output, "trimmed namespace should arrive") { + t.Fatalf("output = %q, want trimmed namespace message", entry.Output) + } +}