Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions internal/carrier/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,28 @@ const (
endpointBlacklistMaxTTL = 1 * time.Hour
)

func readRelayResponseBody(r io.Reader, contentLength int64, limit int) ([]byte, error) {
if contentLength > int64(limit) {
return nil, fmt.Errorf("relay response too large (%d bytes > %d)", contentLength, limit)
}
if contentLength >= 0 {
body := make([]byte, int(contentLength))
if _, err := io.ReadFull(r, body); err != nil {
return nil, err
}
return body, nil
}
lr := &io.LimitedReader{R: r, N: int64(limit) + 1}
body, err := io.ReadAll(lr)
if err != nil {
return nil, err
}
if len(body) > limit {
return nil, fmt.Errorf("relay response too large (%d bytes > %d)", len(body), limit)
}
return body, nil
}

// Config bundles everything the carrier needs to talk to the relay.
type Config struct {
ScriptURLs []string // one or more full https://script.google.com/macros/s/.../exec URLs
Expand Down Expand Up @@ -163,7 +185,7 @@ type Client struct {
numWorkers int // (workersPerEndpoint + idleSlotsPerBucket - 1) × bucketCount
bucketCount int // distinct account labels in endpoints; unlabeled all share one bucket
idleSlotsPerBucket int // resolved from Config.IdleSlotsPerBucket, default 1
clientVersion string
clientVersion string

// clientID is a random 16-byte identifier minted once per process. It is
// embedded in every encrypted batch so the server can route downstream
Expand Down Expand Up @@ -587,7 +609,7 @@ func (c *Client) pollOnce(ctx context.Context) bool {
return false
}

respBody, readErr := io.ReadAll(resp.Body)
respBody, readErr := readRelayResponseBody(resp.Body, resp.ContentLength, maxRelayResponseBodyBytes)
_ = resp.Body.Close()
if readErr != nil {
c.markEndpointFailure(endpointIdx)
Expand Down
28 changes: 28 additions & 0 deletions internal/carrier/client_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package carrier

import (
"bytes"
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
Expand All @@ -14,6 +16,32 @@ import (

const testKeyHex = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"

type carrierErrReader struct{}

func (carrierErrReader) Read([]byte) (int, error) {
return 0, errors.New("reader should not be called")
}

func TestReadRelayResponseBodyBoundsAndPreallocates(t *testing.T) {
_, err := readRelayResponseBody(bytes.NewReader([]byte("abcdef")), -1, 5)
if err == nil {
t.Fatal("readRelayResponseBody succeeded for over-limit unknown-length body")
}

_, err = readRelayResponseBody(carrierErrReader{}, 6, 5)
if err == nil {
t.Fatal("readRelayResponseBody succeeded for over-limit Content-Length")
}

got, err := readRelayResponseBody(bytes.NewReader([]byte("abcde")), int64(len("abcde")), 5)
if err != nil {
t.Fatalf("readRelayResponseBody: %v", err)
}
if string(got) != "abcde" {
t.Fatalf("readRelayResponseBody = %q, want abcde", got)
}
}

// echoServer decodes the incoming batch, echoes each frame's payload back
// (with the SYN bit cleared and seq reset per session), and returns it.
func echoServer(t *testing.T, aead *frame.Crypto) (*httptest.Server, *int) {
Expand Down
38 changes: 36 additions & 2 deletions internal/exit/exit.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net"
Expand Down Expand Up @@ -112,8 +113,37 @@ const (

// idleGCInterval is how often the cleanup loop scans for orphaned sessions.
idleGCInterval = 60 * time.Second

// maxRequestBodyBytes caps the encrypted/base64 POST body accepted by
// /tunnel. Keep this above the largest current client batch envelope while
// still rejecting accidental or hostile unbounded uploads before decoding.
maxRequestBodyBytes = 64 * 1024 * 1024
)

var errRequestTooLarge = errors.New("tunnel request too large")

func readTunnelRequestBody(r io.Reader, contentLength int64, limit int) ([]byte, error) {
if contentLength > int64(limit) {
return nil, fmt.Errorf("%w (%d bytes > %d)", errRequestTooLarge, contentLength, limit)
}
if contentLength >= 0 {
body := make([]byte, int(contentLength))
if _, err := io.ReadFull(r, body); err != nil {
return nil, err
}
return body, nil
}
lr := &io.LimitedReader{R: r, N: int64(limit) + 1}
body, err := io.ReadAll(lr)
if err != nil {
return nil, err
}
if len(body) > limit {
return nil, fmt.Errorf("%w (%d bytes > %d)", errRequestTooLarge, len(body), limit)
}
return body, nil
}

// Config is the VPS server's configuration.
type Config struct {
ListenAddr string // "0.0.0.0:8443"
Expand Down Expand Up @@ -274,10 +304,14 @@ func (s *Server) handleTunnel(w http.ResponseWriter, r *http.Request) {
return
}
s.stats.requests.Add(1)
body, err := io.ReadAll(r.Body)
body, err := readTunnelRequestBody(r.Body, r.ContentLength, maxRequestBodyBytes)
if err != nil {
log.Printf("[exit] read body: %v", err)
w.WriteHeader(http.StatusBadRequest)
if errors.Is(err, errRequestTooLarge) {
w.WriteHeader(http.StatusRequestEntityTooLarge)
} else {
w.WriteHeader(http.StatusBadRequest)
}
return
}

Expand Down
32 changes: 32 additions & 0 deletions internal/exit/exit_timing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,38 @@ import (

const exitTimingTestKeyHex = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"

type errReader struct{}

func (errReader) Read([]byte) (int, error) {
return 0, errors.New("reader should not be called")
}

func TestReadTunnelRequestBodyBoundsAndPreallocates(t *testing.T) {
_, err := readTunnelRequestBody(bytes.NewReader([]byte("abcdef")), -1, 5)
if err == nil {
t.Fatal("readTunnelRequestBody succeeded for over-limit unknown-length body")
}
if !errors.Is(err, errRequestTooLarge) {
t.Fatalf("readTunnelRequestBody err = %v, want errRequestTooLarge", err)
}

_, err = readTunnelRequestBody(errReader{}, 6, 5)
if err == nil {
t.Fatal("readTunnelRequestBody succeeded for over-limit Content-Length")
}
if !errors.Is(err, errRequestTooLarge) {
t.Fatalf("readTunnelRequestBody err = %v, want errRequestTooLarge", err)
}

got, err := readTunnelRequestBody(bytes.NewReader([]byte("abcde")), int64(len("abcde")), 5)
if err != nil {
t.Fatalf("readTunnelRequestBody: %v", err)
}
if string(got) != "abcde" {
t.Fatalf("readTunnelRequestBody = %q, want abcde", got)
}
}

func mustExitTimingServer(tb testing.TB) *Server {
tb.Helper()
s, err := New(Config{ListenAddr: "127.0.0.1:0", AESKeyHex: exitTimingTestKeyHex})
Expand Down
Loading