Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions browser/browser.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,33 @@ import (
var webpage string

func LogSocketViewHandler(w http.ResponseWriter, r *http.Request) {
wsResource := websocketScheme(r) + r.Host + r.URL.Path
wsResource := websocketScheme(r) + websocketHost(r) + r.URL.Path
wsResource = strings.TrimSuffix(wsResource, "/") + "/ws"
homeTemplate.Execute(w, wsResource)
}

func websocketScheme(r *http.Request) string {
if forwardedProto := r.Header.Get("X-Forwarded-Proto"); forwardedProto != "" {
protocol := strings.ToLower(strings.TrimSpace(strings.Split(forwardedProto, ",")[0]))
if protocol == "https" {
return "wss://"
}
if protocol == "http" {
return "ws://"
}
switch strings.ToLower(forwardedHeaderValue(r, "X-Forwarded-Proto")) {
case "https":
return "wss://"
case "http":
return "ws://"
}
if r.TLS != nil {
return "wss://"
}
return "ws://"
}

func websocketHost(r *http.Request) string {
if host := forwardedHeaderValue(r, "X-Forwarded-Host"); host != "" {
return host
}
return r.Host
}

func forwardedHeaderValue(r *http.Request, key string) string {
return strings.TrimSpace(strings.Split(r.Header.Get(key), ",")[0])
}

var homeTemplate = template.Must(template.New("").Parse(webpage))
30 changes: 30 additions & 0 deletions browser/browser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,33 @@ func TestLogSocketViewHandler_ForwardedHTTPOverridesTLS(t *testing.T) {
t.Error("expected escaped ws://example.com/logs/ws in body")
}
}

func TestLogSocketViewHandler_ForwardedHostUsedForWebsocketURL(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "https://internal:8080/logs/", nil)
req.TLS = &tls.ConnectionState{}
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("X-Forwarded-Host", "logs.example.com")
w := httptest.NewRecorder()
LogSocketViewHandler(w, req)

body := w.Body.String()
if !strings.Contains(body, `wss:\/\/logs.example.com\/logs\/ws`) {
t.Error("expected escaped wss://logs.example.com/logs/ws in body")
}
if strings.Contains(body, `internal:8080`) {
t.Error("response should not leak the internal host when X-Forwarded-Host is set")
}
}

func TestLogSocketViewHandler_ForwardedHeadersUseFirstValue(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://internal:8080/logs/", nil)
req.Header.Set("X-Forwarded-Proto", "HTTPS, http")
req.Header.Set("X-Forwarded-Host", "logs.example.com, internal:8080")
w := httptest.NewRecorder()
LogSocketViewHandler(w, req)

body := w.Body.String()
if !strings.Contains(body, `wss:\/\/logs.example.com\/logs\/ws`) {
t.Error("expected escaped wss://logs.example.com/logs/ws in body")
}
}
17 changes: 15 additions & 2 deletions ws/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,31 @@ import (
"encoding/json"
"net/http"
"strings"
"sync"

"github.com/gorilla/websocket"
logger "github.com/taigrr/log-socket/v2/log"
)

var upgrader = websocket.Upgrader{} // use default options
var (
upgrader = websocket.Upgrader{} // use default options
upgraderMux sync.RWMutex
)

// SetUpgrader replaces the default [websocket.Upgrader] used by
// [LogSocketHandler].
func SetUpgrader(u websocket.Upgrader) {
upgraderMux.Lock()
defer upgraderMux.Unlock()
upgrader = u
}

func getUpgrader() websocket.Upgrader {
upgraderMux.RLock()
defer upgraderMux.RUnlock()
return upgrader
}

func parseNamespaces(raw string) []string {
if raw == "" {
return nil
Expand Down Expand Up @@ -46,7 +58,8 @@ func LogSocketHandler(w http.ResponseWriter, r *http.Request) {
// Empty or missing means all namespaces.
namespaces := parseNamespaces(r.URL.Query().Get("namespaces"))

conn, err := upgrader.Upgrade(w, r, nil)
currentUpgrader := getUpgrader()
conn, err := currentUpgrader.Upgrade(w, r, nil)
if err != nil {
logger.Error("upgrade:", err)
return
Expand Down
89 changes: 51 additions & 38 deletions ws/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ws

import (
"encoding/json"
"net"
"net/http"
"net/http/httptest"
"strings"
Expand All @@ -18,11 +19,12 @@ func TestSetUpgrader(t *testing.T) {
WriteBufferSize: 2048,
}
SetUpgrader(custom)
if upgrader.ReadBufferSize != 2048 {
t.Errorf("ReadBufferSize = %d, want 2048", upgrader.ReadBufferSize)
current := getUpgrader()
if current.ReadBufferSize != 2048 {
t.Errorf("ReadBufferSize = %d, want 2048", current.ReadBufferSize)
}
if upgrader.WriteBufferSize != 2048 {
t.Errorf("WriteBufferSize = %d, want 2048", upgrader.WriteBufferSize)
if current.WriteBufferSize != 2048 {
t.Errorf("WriteBufferSize = %d, want 2048", current.WriteBufferSize)
}
// Reset to default
SetUpgrader(websocket.Upgrader{})
Expand Down Expand Up @@ -56,30 +58,20 @@ func TestLogSocketHandler_WebSocket(t *testing.T) {
}
defer conn.Close()

// Send a log entry and verify it arrives over the websocket
waitForWebSocketEntry(t, conn, func(entry logger.Entry) bool {
return entry.Namespace == logger.DefaultNamespace && entry.Output == "Websocket client attached."
})

// Send a log entry and verify it arrives over the websocket after the
// client is fully attached.
testLogger := logger.NewLogger("ws-test")
testLogger.Info("test message for websocket")

// Read messages until we find our test entry (the handler itself
// logs "Websocket client attached." which may arrive first)
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
var found bool
for i := 0; i < 10; i++ {
_, message, err := conn.ReadMessage()
if err != nil {
t.Fatalf("failed to read message: %v", err)
}
var entry logger.Entry
if err := json.Unmarshal(message, &entry); err != nil {
t.Fatalf("failed to unmarshal entry: %v", err)
}
if entry.Namespace == "ws-test" && entry.Level == "INFO" {
found = true
break
}
}
if !found {
t.Error("did not receive expected log entry with namespace ws-test")
entry := waitForWebSocketEntry(t, conn, func(entry logger.Entry) bool {
return entry.Namespace == "ws-test" && entry.Level == "INFO"
})
if !strings.Contains(entry.Output, "test message for websocket") {
t.Errorf("output = %q, want to contain test message", entry.Output)
}
}

Expand All @@ -100,32 +92,53 @@ func TestLogSocketHandler_NamespaceFilter(t *testing.T) {
}
defer conn.Close()

// Send a log to a different namespace — it should NOT be received
// Send a log to a different namespace — it should NOT be received.
otherLogger := logger.NewLogger("other-ns")
otherLogger.Info("should not arrive")

// Send a log to the filtered namespace — it SHOULD be received
// Retry the matching namespace log a few times to avoid racing the server
// goroutine that registers the filtered client after the websocket dial completes.
filteredLogger := logger.NewLogger("filtered-ns")
filteredLogger.Info("should arrive")

conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, message, err := conn.ReadMessage()
if err != nil {
t.Fatalf("failed to read message: %v", err)
}

var entry logger.Entry
if err := json.Unmarshal(message, &entry); err != nil {
t.Fatalf("failed to unmarshal entry: %v", err)
for i := 0; i < 5; i++ {
filteredLogger.Infof("should arrive %d", i)
entry = waitForWebSocketEntry(t, conn, func(entry logger.Entry) bool {
return entry.Namespace == "filtered-ns"
})
if entry.Namespace == "filtered-ns" {
break
}
}
if entry.Namespace != "filtered-ns" {
t.Errorf("namespace = %q, want filtered-ns", entry.Namespace)
t.Fatal("did not receive expected filtered namespace entry")
}
if !strings.Contains(entry.Output, "should arrive") {
t.Errorf("output = %q, want to contain 'should arrive'", entry.Output)
}
}

func waitForWebSocketEntry(t *testing.T, conn *websocket.Conn, match func(logger.Entry) bool) logger.Entry {
t.Helper()
conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
for i := 0; i < 10; i++ {
_, message, err := conn.ReadMessage()
if err != nil {
if ne, ok := err.(net.Error); ok && ne.Timeout() {
return logger.Entry{}
}
t.Fatalf("failed to read message: %v", err)
}
var entry logger.Entry
if err := json.Unmarshal(message, &entry); err != nil {
t.Fatalf("failed to unmarshal entry: %v", err)
}
if match(entry) {
return entry
}
}
return logger.Entry{}
}

func TestParseNamespaces(t *testing.T) {
tests := []struct {
name string
Expand Down
Loading