diff --git a/internal/handler/cors_test.go b/internal/handler/cors_test.go index 0961e01e..05dd3941 100644 --- a/internal/handler/cors_test.go +++ b/internal/handler/cors_test.go @@ -4,6 +4,7 @@ import ( "net/http" "net/http/httptest" "testing" + "testing/fstest" ) func TestParseCORSOrigins(t *testing.T) { @@ -167,6 +168,41 @@ func TestCORSMiddlewareVaryOriginForDisallowedOrigin(t *testing.T) { } } +func TestCORSMiddlewareVaryPreservedThroughStaticHandler(t *testing.T) { + // Regression test: serveFromCache used to call w.Header().Set("Vary", + // "Accept-Encoding") which silently dropped the "Origin" entry that + // CORSMiddleware had already added. After the fix (Set → Add) both + // values must appear in the response Vary header for an allowed origin + // requesting a static file. + const index = "test" + prev := StaticFS + StaticFS = fstest.MapFS{ + "index.html": &fstest.MapFile{Data: []byte(index)}, + } + defer func() { StaticFS = prev }() + + staticHandler := NewStaticHandler() + h := CORSMiddleware(ParseCORSOrigins("https://ui.example.com"), staticHandler) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Origin", "https://ui.example.com") + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200 from static handler, got %d", rec.Code) + } + vary := rec.Header().Values("Vary") + for _, want := range []string{"Origin", "Accept-Encoding"} { + if !containsStr(vary, want) { + t.Fatalf("static handler response Vary=%v missing %q — CORSMiddleware Vary was overwritten", vary, want) + } + } + if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "https://ui.example.com" { + t.Fatalf("Allow-Origin=%q, want reflected origin", got) + } +} + func containsStr(s []string, want string) bool { for _, v := range s { if v == want { diff --git a/internal/handler/static.go b/internal/handler/static.go index b5fc6744..2130a86f 100644 --- a/internal/handler/static.go +++ b/internal/handler/static.go @@ -197,12 +197,14 @@ func serveFromCache(w http.ResponseWriter, r *http.Request, cached *staticFileCa // Set content type w.Header().Set("Content-Type", cached.contentType) - // Always set Vary header to ensure caches differentiate by Accept-Encoding - w.Header().Set("Vary", "Accept-Encoding") + // Add Accept-Encoding to Vary so caches differentiate by encoding. Use Add + // (not Set) to preserve any Vary values already written by upstream + // middleware (e.g. "Origin" from CORSMiddleware). + w.Header().Add("Vary", "Accept-Encoding") // Check if client accepts gzip and we have gzipped content if cached.gzipped != nil && strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { - w.Header().Set("Content-Encoding", "gzip") + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(cached.gzipped))) w.WriteHeader(http.StatusOK) w.Write(cached.gzipped)