From b942329477a8c039361d5bf77e3e4b5425c7b46c Mon Sep 17 00:00:00 2001 From: Tai Groot Date: Wed, 6 May 2026 06:34:50 +0000 Subject: [PATCH] fix(browser): honor forwarded websocket host --- browser/browser.go | 26 +++++++----- browser/browser_test.go | 30 ++++++++++++++ ws/server.go | 17 +++++++- ws/server_test.go | 89 +++++++++++++++++++++++------------------ 4 files changed, 113 insertions(+), 49 deletions(-) diff --git a/browser/browser.go b/browser/browser.go index d0e3040..d9a78ad 100644 --- a/browser/browser.go +++ b/browser/browser.go @@ -11,20 +11,17 @@ import ( var webpage string func LogSocketViewHandler(w http.ResponseWriter, r *http.Request) { - wsResource := websocketScheme(r) + r.Host + r.URL.Path + wsResource := websocketScheme(r) + websocketHost(r) + r.URL.Path wsResource = strings.TrimSuffix(wsResource, "/") + "/ws" homeTemplate.Execute(w, wsResource) } func websocketScheme(r *http.Request) string { - if forwardedProto := r.Header.Get("X-Forwarded-Proto"); forwardedProto != "" { - protocol := strings.ToLower(strings.TrimSpace(strings.Split(forwardedProto, ",")[0])) - if protocol == "https" { - return "wss://" - } - if protocol == "http" { - return "ws://" - } + switch strings.ToLower(forwardedHeaderValue(r, "X-Forwarded-Proto")) { + case "https": + return "wss://" + case "http": + return "ws://" } if r.TLS != nil { return "wss://" @@ -32,4 +29,15 @@ func websocketScheme(r *http.Request) string { return "ws://" } +func websocketHost(r *http.Request) string { + if host := forwardedHeaderValue(r, "X-Forwarded-Host"); host != "" { + return host + } + return r.Host +} + +func forwardedHeaderValue(r *http.Request, key string) string { + return strings.TrimSpace(strings.Split(r.Header.Get(key), ",")[0]) +} + var homeTemplate = template.Must(template.New("").Parse(webpage)) diff --git a/browser/browser_test.go b/browser/browser_test.go index 01f9063..15b15b6 100644 --- a/browser/browser_test.go +++ b/browser/browser_test.go @@ -77,3 +77,33 @@ func TestLogSocketViewHandler_ForwardedHTTPOverridesTLS(t *testing.T) { t.Error("expected escaped ws://example.com/logs/ws in body") } } + +func TestLogSocketViewHandler_ForwardedHostUsedForWebsocketURL(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "https://internal:8080/logs/", nil) + req.TLS = &tls.ConnectionState{} + req.Header.Set("X-Forwarded-Proto", "https") + req.Header.Set("X-Forwarded-Host", "logs.example.com") + w := httptest.NewRecorder() + LogSocketViewHandler(w, req) + + body := w.Body.String() + if !strings.Contains(body, `wss:\/\/logs.example.com\/logs\/ws`) { + t.Error("expected escaped wss://logs.example.com/logs/ws in body") + } + if strings.Contains(body, `internal:8080`) { + t.Error("response should not leak the internal host when X-Forwarded-Host is set") + } +} + +func TestLogSocketViewHandler_ForwardedHeadersUseFirstValue(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://internal:8080/logs/", nil) + req.Header.Set("X-Forwarded-Proto", "HTTPS, http") + req.Header.Set("X-Forwarded-Host", "logs.example.com, internal:8080") + w := httptest.NewRecorder() + LogSocketViewHandler(w, req) + + body := w.Body.String() + if !strings.Contains(body, `wss:\/\/logs.example.com\/logs\/ws`) { + t.Error("expected escaped wss://logs.example.com/logs/ws in body") + } +} diff --git a/ws/server.go b/ws/server.go index 9c914b2..ef1f126 100644 --- a/ws/server.go +++ b/ws/server.go @@ -5,19 +5,31 @@ import ( "encoding/json" "net/http" "strings" + "sync" "github.com/gorilla/websocket" logger "github.com/taigrr/log-socket/v2/log" ) -var upgrader = websocket.Upgrader{} // use default options +var ( + upgrader = websocket.Upgrader{} // use default options + upgraderMux sync.RWMutex +) // SetUpgrader replaces the default [websocket.Upgrader] used by // [LogSocketHandler]. func SetUpgrader(u websocket.Upgrader) { + upgraderMux.Lock() + defer upgraderMux.Unlock() upgrader = u } +func getUpgrader() websocket.Upgrader { + upgraderMux.RLock() + defer upgraderMux.RUnlock() + return upgrader +} + // LogSocketHandler upgrades the HTTP connection to a WebSocket and streams // log entries to the client. An optional "namespaces" query parameter // (comma-separated) filters which namespaces the client receives. @@ -30,7 +42,8 @@ func LogSocketHandler(w http.ResponseWriter, r *http.Request) { namespaces = strings.Split(namespacesParam, ",") } - conn, err := upgrader.Upgrade(w, r, nil) + currentUpgrader := getUpgrader() + conn, err := currentUpgrader.Upgrade(w, r, nil) if err != nil { logger.Error("upgrade:", err) return diff --git a/ws/server_test.go b/ws/server_test.go index 1421fe0..d70a420 100644 --- a/ws/server_test.go +++ b/ws/server_test.go @@ -2,6 +2,7 @@ package ws import ( "encoding/json" + "net" "net/http" "net/http/httptest" "strings" @@ -18,11 +19,12 @@ func TestSetUpgrader(t *testing.T) { WriteBufferSize: 2048, } SetUpgrader(custom) - if upgrader.ReadBufferSize != 2048 { - t.Errorf("ReadBufferSize = %d, want 2048", upgrader.ReadBufferSize) + current := getUpgrader() + if current.ReadBufferSize != 2048 { + t.Errorf("ReadBufferSize = %d, want 2048", current.ReadBufferSize) } - if upgrader.WriteBufferSize != 2048 { - t.Errorf("WriteBufferSize = %d, want 2048", upgrader.WriteBufferSize) + if current.WriteBufferSize != 2048 { + t.Errorf("WriteBufferSize = %d, want 2048", current.WriteBufferSize) } // Reset to default SetUpgrader(websocket.Upgrader{}) @@ -56,30 +58,20 @@ func TestLogSocketHandler_WebSocket(t *testing.T) { } defer conn.Close() - // Send a log entry and verify it arrives over the websocket + waitForWebSocketEntry(t, conn, func(entry logger.Entry) bool { + return entry.Namespace == logger.DefaultNamespace && entry.Output == "Websocket client attached." + }) + + // Send a log entry and verify it arrives over the websocket after the + // client is fully attached. testLogger := logger.NewLogger("ws-test") testLogger.Info("test message for websocket") - // Read messages until we find our test entry (the handler itself - // logs "Websocket client attached." which may arrive first) - conn.SetReadDeadline(time.Now().Add(2 * time.Second)) - var found bool - for i := 0; i < 10; i++ { - _, message, err := conn.ReadMessage() - if err != nil { - t.Fatalf("failed to read message: %v", err) - } - var entry logger.Entry - if err := json.Unmarshal(message, &entry); err != nil { - t.Fatalf("failed to unmarshal entry: %v", err) - } - if entry.Namespace == "ws-test" && entry.Level == "INFO" { - found = true - break - } - } - if !found { - t.Error("did not receive expected log entry with namespace ws-test") + entry := waitForWebSocketEntry(t, conn, func(entry logger.Entry) bool { + return entry.Namespace == "ws-test" && entry.Level == "INFO" + }) + if !strings.Contains(entry.Output, "test message for websocket") { + t.Errorf("output = %q, want to contain test message", entry.Output) } } @@ -100,28 +92,49 @@ func TestLogSocketHandler_NamespaceFilter(t *testing.T) { } defer conn.Close() - // Send a log to a different namespace — it should NOT be received + // Send a log to a different namespace — it should NOT be received. otherLogger := logger.NewLogger("other-ns") otherLogger.Info("should not arrive") - // Send a log to the filtered namespace — it SHOULD be received + // Retry the matching namespace log a few times to avoid racing the server + // goroutine that registers the filtered client after the websocket dial completes. filteredLogger := logger.NewLogger("filtered-ns") - filteredLogger.Info("should arrive") - - conn.SetReadDeadline(time.Now().Add(2 * time.Second)) - _, message, err := conn.ReadMessage() - if err != nil { - t.Fatalf("failed to read message: %v", err) - } - var entry logger.Entry - if err := json.Unmarshal(message, &entry); err != nil { - t.Fatalf("failed to unmarshal entry: %v", err) + for i := 0; i < 5; i++ { + filteredLogger.Infof("should arrive %d", i) + entry = waitForWebSocketEntry(t, conn, func(entry logger.Entry) bool { + return entry.Namespace == "filtered-ns" + }) + if entry.Namespace == "filtered-ns" { + break + } } if entry.Namespace != "filtered-ns" { - t.Errorf("namespace = %q, want filtered-ns", entry.Namespace) + t.Fatal("did not receive expected filtered namespace entry") } if !strings.Contains(entry.Output, "should arrive") { t.Errorf("output = %q, want to contain 'should arrive'", entry.Output) } } + +func waitForWebSocketEntry(t *testing.T, conn *websocket.Conn, match func(logger.Entry) bool) logger.Entry { + t.Helper() + conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond)) + for i := 0; i < 10; i++ { + _, message, err := conn.ReadMessage() + if err != nil { + if ne, ok := err.(net.Error); ok && ne.Timeout() { + return logger.Entry{} + } + t.Fatalf("failed to read message: %v", err) + } + var entry logger.Entry + if err := json.Unmarshal(message, &entry); err != nil { + t.Fatalf("failed to unmarshal entry: %v", err) + } + if match(entry) { + return entry + } + } + return logger.Entry{} +}