From a28eb3bd996ab24dbfcbf1060294085416ab21a1 Mon Sep 17 00:00:00 2001 From: Shayan SalehiRad Date: Mon, 18 May 2026 05:23:08 +0330 Subject: [PATCH] carrier/exit: bound relay body reads --- internal/carrier/client.go | 26 +++++++++++++++++++-- internal/carrier/client_test.go | 28 +++++++++++++++++++++++ internal/exit/exit.go | 38 +++++++++++++++++++++++++++++-- internal/exit/exit_timing_test.go | 32 ++++++++++++++++++++++++++ 4 files changed, 120 insertions(+), 4 deletions(-) diff --git a/internal/carrier/client.go b/internal/carrier/client.go index b719a4e..b1dbf91 100644 --- a/internal/carrier/client.go +++ b/internal/carrier/client.go @@ -67,6 +67,28 @@ const ( endpointBlacklistMaxTTL = 1 * time.Hour ) +func readRelayResponseBody(r io.Reader, contentLength int64, limit int) ([]byte, error) { + if contentLength > int64(limit) { + return nil, fmt.Errorf("relay response too large (%d bytes > %d)", contentLength, limit) + } + if contentLength >= 0 { + body := make([]byte, int(contentLength)) + if _, err := io.ReadFull(r, body); err != nil { + return nil, err + } + return body, nil + } + lr := &io.LimitedReader{R: r, N: int64(limit) + 1} + body, err := io.ReadAll(lr) + if err != nil { + return nil, err + } + if len(body) > limit { + return nil, fmt.Errorf("relay response too large (%d bytes > %d)", len(body), limit) + } + return body, nil +} + // Config bundles everything the carrier needs to talk to the relay. type Config struct { ScriptURLs []string // one or more full https://script.google.com/macros/s/.../exec URLs @@ -163,7 +185,7 @@ type Client struct { numWorkers int // (workersPerEndpoint + idleSlotsPerBucket - 1) × bucketCount bucketCount int // distinct account labels in endpoints; unlabeled all share one bucket idleSlotsPerBucket int // resolved from Config.IdleSlotsPerBucket, default 1 - clientVersion string + clientVersion string // clientID is a random 16-byte identifier minted once per process. It is // embedded in every encrypted batch so the server can route downstream @@ -587,7 +609,7 @@ func (c *Client) pollOnce(ctx context.Context) bool { return false } - respBody, readErr := io.ReadAll(resp.Body) + respBody, readErr := readRelayResponseBody(resp.Body, resp.ContentLength, maxRelayResponseBodyBytes) _ = resp.Body.Close() if readErr != nil { c.markEndpointFailure(endpointIdx) diff --git a/internal/carrier/client_test.go b/internal/carrier/client_test.go index 26e6abc..b613498 100644 --- a/internal/carrier/client_test.go +++ b/internal/carrier/client_test.go @@ -1,7 +1,9 @@ package carrier import ( + "bytes" "context" + "errors" "io" "net/http" "net/http/httptest" @@ -14,6 +16,32 @@ import ( const testKeyHex = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" +type carrierErrReader struct{} + +func (carrierErrReader) Read([]byte) (int, error) { + return 0, errors.New("reader should not be called") +} + +func TestReadRelayResponseBodyBoundsAndPreallocates(t *testing.T) { + _, err := readRelayResponseBody(bytes.NewReader([]byte("abcdef")), -1, 5) + if err == nil { + t.Fatal("readRelayResponseBody succeeded for over-limit unknown-length body") + } + + _, err = readRelayResponseBody(carrierErrReader{}, 6, 5) + if err == nil { + t.Fatal("readRelayResponseBody succeeded for over-limit Content-Length") + } + + got, err := readRelayResponseBody(bytes.NewReader([]byte("abcde")), int64(len("abcde")), 5) + if err != nil { + t.Fatalf("readRelayResponseBody: %v", err) + } + if string(got) != "abcde" { + t.Fatalf("readRelayResponseBody = %q, want abcde", got) + } +} + // echoServer decodes the incoming batch, echoes each frame's payload back // (with the SYN bit cleared and seq reset per session), and returns it. func echoServer(t *testing.T, aead *frame.Crypto) (*httptest.Server, *int) { diff --git a/internal/exit/exit.go b/internal/exit/exit.go index 930f8d4..0570bae 100644 --- a/internal/exit/exit.go +++ b/internal/exit/exit.go @@ -8,6 +8,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "io" "log" "net" @@ -112,8 +113,37 @@ const ( // idleGCInterval is how often the cleanup loop scans for orphaned sessions. idleGCInterval = 60 * time.Second + + // maxRequestBodyBytes caps the encrypted/base64 POST body accepted by + // /tunnel. Keep this above the largest current client batch envelope while + // still rejecting accidental or hostile unbounded uploads before decoding. + maxRequestBodyBytes = 64 * 1024 * 1024 ) +var errRequestTooLarge = errors.New("tunnel request too large") + +func readTunnelRequestBody(r io.Reader, contentLength int64, limit int) ([]byte, error) { + if contentLength > int64(limit) { + return nil, fmt.Errorf("%w (%d bytes > %d)", errRequestTooLarge, contentLength, limit) + } + if contentLength >= 0 { + body := make([]byte, int(contentLength)) + if _, err := io.ReadFull(r, body); err != nil { + return nil, err + } + return body, nil + } + lr := &io.LimitedReader{R: r, N: int64(limit) + 1} + body, err := io.ReadAll(lr) + if err != nil { + return nil, err + } + if len(body) > limit { + return nil, fmt.Errorf("%w (%d bytes > %d)", errRequestTooLarge, len(body), limit) + } + return body, nil +} + // Config is the VPS server's configuration. type Config struct { ListenAddr string // "0.0.0.0:8443" @@ -274,10 +304,14 @@ func (s *Server) handleTunnel(w http.ResponseWriter, r *http.Request) { return } s.stats.requests.Add(1) - body, err := io.ReadAll(r.Body) + body, err := readTunnelRequestBody(r.Body, r.ContentLength, maxRequestBodyBytes) if err != nil { log.Printf("[exit] read body: %v", err) - w.WriteHeader(http.StatusBadRequest) + if errors.Is(err, errRequestTooLarge) { + w.WriteHeader(http.StatusRequestEntityTooLarge) + } else { + w.WriteHeader(http.StatusBadRequest) + } return } diff --git a/internal/exit/exit_timing_test.go b/internal/exit/exit_timing_test.go index 57060d6..b45d99a 100644 --- a/internal/exit/exit_timing_test.go +++ b/internal/exit/exit_timing_test.go @@ -20,6 +20,38 @@ import ( const exitTimingTestKeyHex = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" +type errReader struct{} + +func (errReader) Read([]byte) (int, error) { + return 0, errors.New("reader should not be called") +} + +func TestReadTunnelRequestBodyBoundsAndPreallocates(t *testing.T) { + _, err := readTunnelRequestBody(bytes.NewReader([]byte("abcdef")), -1, 5) + if err == nil { + t.Fatal("readTunnelRequestBody succeeded for over-limit unknown-length body") + } + if !errors.Is(err, errRequestTooLarge) { + t.Fatalf("readTunnelRequestBody err = %v, want errRequestTooLarge", err) + } + + _, err = readTunnelRequestBody(errReader{}, 6, 5) + if err == nil { + t.Fatal("readTunnelRequestBody succeeded for over-limit Content-Length") + } + if !errors.Is(err, errRequestTooLarge) { + t.Fatalf("readTunnelRequestBody err = %v, want errRequestTooLarge", err) + } + + got, err := readTunnelRequestBody(bytes.NewReader([]byte("abcde")), int64(len("abcde")), 5) + if err != nil { + t.Fatalf("readTunnelRequestBody: %v", err) + } + if string(got) != "abcde" { + t.Fatalf("readTunnelRequestBody = %q, want abcde", got) + } +} + func mustExitTimingServer(tb testing.TB) *Server { tb.Helper() s, err := New(Config{ListenAddr: "127.0.0.1:0", AESKeyHex: exitTimingTestKeyHex})