diff --git a/cmd/server/main.go b/cmd/server/main.go index 36f0681..f85f178 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -31,11 +31,12 @@ func main() { } srv, err := exit.New(exit.Config{ - ListenAddr: cfg.ListenAddr, - AESKeyHex: cfg.AESKeyHex, - DebugTiming: cfg.DebugTiming, - UpstreamProxy: cfg.UpstreamProxy, - Version: version, + ListenAddr: cfg.ListenAddr, + AESKeyHex: cfg.AESKeyHex, + DebugTiming: cfg.DebugTiming, + UpstreamProxy: cfg.UpstreamProxy, + InitialResponseBytesPreEncode: cfg.InitialResponseBytesPreEncode, + Version: version, }) if err != nil { log.Fatalf("exit: %v", err) diff --git a/internal/config/server.go b/internal/config/server.go index fe2664e..aff4c85 100644 --- a/internal/config/server.go +++ b/internal/config/server.go @@ -15,10 +15,11 @@ import ( // Server is the VPS exit server config. type Server struct { - ListenAddr string - AESKeyHex string - DebugTiming bool - UpstreamProxy string // optional socks5://host:port; when set, all outbound dials go through this proxy + ListenAddr string + AESKeyHex string + DebugTiming bool + UpstreamProxy string // optional socks5://host:port; when set, all outbound dials go through this proxy + InitialResponseBytesPreEncode int } type serverFile struct { @@ -36,6 +37,10 @@ type serverFile struct { // datacenter IP is blocked by certain sites. UpstreamProxy string `json:"upstream_proxy"` + // Optional: cap the first downstream response for a newly opened session. + // 0 uses the server default. + InitialResponseBytesPreEncode int `json:"initial_response_bytes_pre_encode"` + // Legacy keys kept as fallback for existing deployments. ListenAddr string `json:"listen_addr"` AESKeyHex string `json:"aes_key_hex"` @@ -104,12 +109,16 @@ func LoadServer(path string) (*Server, error) { } upstreamProxy = u.Host } + if f.InitialResponseBytesPreEncode < 0 { + return nil, fmt.Errorf("initial_response_bytes_pre_encode must be >= 0") + } c := Server{ - ListenAddr: net.JoinHostPort(listenHost, strconv.Itoa(listenPort)), - AESKeyHex: key, - DebugTiming: f.DebugTiming, - UpstreamProxy: upstreamProxy, + ListenAddr: net.JoinHostPort(listenHost, strconv.Itoa(listenPort)), + AESKeyHex: key, + DebugTiming: f.DebugTiming, + UpstreamProxy: upstreamProxy, + InitialResponseBytesPreEncode: f.InitialResponseBytesPreEncode, } return &c, nil } diff --git a/internal/config/server_test.go b/internal/config/server_test.go new file mode 100644 index 0000000..fe224b0 --- /dev/null +++ b/internal/config/server_test.go @@ -0,0 +1,50 @@ +package config + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +const testServerKeyHex = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + +func TestLoadServerInitialResponseBytesPreEncode(t *testing.T) { + path := filepath.Join(t.TempDir(), "server.json") + body := `{ + "server_host": "127.0.0.1", + "server_port": 8443, + "tunnel_key": "` + testServerKeyHex + `", + "initial_response_bytes_pre_encode": 131072 + }` + if err := os.WriteFile(path, []byte(body), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + cfg, err := LoadServer(path) + if err != nil { + t.Fatalf("LoadServer: %v", err) + } + if cfg.InitialResponseBytesPreEncode != 131072 { + t.Fatalf("InitialResponseBytesPreEncode = %d, want 131072", cfg.InitialResponseBytesPreEncode) + } +} + +func TestLoadServerRejectsNegativeInitialResponseBytes(t *testing.T) { + path := filepath.Join(t.TempDir(), "server.json") + body := `{ + "server_host": "127.0.0.1", + "server_port": 8443, + "tunnel_key": "` + testServerKeyHex + `", + "initial_response_bytes_pre_encode": -1 + }` + if err := os.WriteFile(path, []byte(body), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + _, err := LoadServer(path) + if err == nil { + t.Fatal("LoadServer succeeded with negative initial_response_bytes_pre_encode") + } + if !strings.Contains(err.Error(), "initial_response_bytes_pre_encode") { + t.Fatalf("LoadServer err = %v, want initial_response_bytes_pre_encode validation", err) + } +} diff --git a/internal/exit/exit.go b/internal/exit/exit.go index 930f8d4..738a472 100644 --- a/internal/exit/exit.go +++ b/internal/exit/exit.go @@ -96,6 +96,12 @@ const ( // last drained session. maxResponseBytesPreEncode = 22 * 1024 * 1024 + // initialResponseBytesPreEncode caps the first downstream response for a + // newly-opened session. Apps Script buffers full HTTP responses, so keeping + // the first file-download/header burst small improves browser-visible start + // time without reducing later bulk throughput. + initialResponseBytesPreEncode = 512 * 1024 + // dialFailureBackoff is how long we suppress repeated SYN dial attempts to a // target after a structural network/DNS failure. dialFailureBackoff = 2 * time.Second @@ -116,21 +122,23 @@ const ( // Config is the VPS server's configuration. type Config struct { - ListenAddr string // "0.0.0.0:8443" - AESKeyHex string // 64-char hex - DebugTiming bool // when true, log per-session dial breakdown and first-read latency - UpstreamProxy string // optional "host:port" of a local SOCKS5 proxy (e.g. WARP on 127.0.0.1:40000) - Version string // build version string (exposed in /healthz and version probe) + ListenAddr string // "0.0.0.0:8443" + AESKeyHex string // 64-char hex + DebugTiming bool // when true, log per-session dial breakdown and first-read latency + UpstreamProxy string // optional "host:port" of a local SOCKS5 proxy (e.g. WARP on 127.0.0.1:40000) + InitialResponseBytesPreEncode int // optional cap for first downstream response; <=0 uses default + Version string // build version string (exposed in /healthz and version probe) } // Server holds the per-process session state. type Server struct { - cfg Config - aead *frame.Crypto - dial func(network, address string, timeout time.Duration) (net.Conn, error) - dns *dnsCache - debugTiming bool - version string + cfg Config + aead *frame.Crypto + dial func(network, address string, timeout time.Duration) (net.Conn, error) + dns *dnsCache + debugTiming bool + version string + initialResponseBytesPreEncode int mu sync.Mutex sessions map[[frame.SessionIDLen]byte]*session.Session @@ -178,23 +186,28 @@ func New(cfg Config) (*Server, error) { return nil, err } dialFn := dialFunc(cfg.UpstreamProxy) + initialResponseCap := cfg.InitialResponseBytesPreEncode + if initialResponseCap <= 0 { + initialResponseCap = initialResponseBytesPreEncode + } s := &Server{ - cfg: cfg, - aead: aead, - dial: dialFn, - dns: newDNSCache(), - debugTiming: cfg.DebugTiming, - version: cfg.Version, - sessions: make(map[[frame.SessionIDLen]byte]*session.Session), - sessionOwners: make(map[[frame.SessionIDLen]byte][frame.ClientIDLen]byte), - txReady: make(map[[frame.SessionIDLen]byte]struct{}), - firstReply: make(map[[frame.SessionIDLen]byte]struct{}), - upstreams: make(map[[frame.SessionIDLen]byte]net.Conn), - lastActivity: make(map[[frame.SessionIDLen]byte]time.Time), - dialFail: make(map[string]time.Time), - pendingRSTs: make(map[[frame.ClientIDLen]byte][]*frame.Frame), - pendingCtrl: make(map[[frame.ClientIDLen]byte][]*frame.Frame), - activity: make(map[[frame.ClientIDLen]byte]chan struct{}), + cfg: cfg, + aead: aead, + dial: dialFn, + dns: newDNSCache(), + debugTiming: cfg.DebugTiming, + version: cfg.Version, + initialResponseBytesPreEncode: initialResponseCap, + sessions: make(map[[frame.SessionIDLen]byte]*session.Session), + sessionOwners: make(map[[frame.SessionIDLen]byte][frame.ClientIDLen]byte), + txReady: make(map[[frame.SessionIDLen]byte]struct{}), + firstReply: make(map[[frame.SessionIDLen]byte]struct{}), + upstreams: make(map[[frame.SessionIDLen]byte]net.Conn), + lastActivity: make(map[[frame.SessionIDLen]byte]time.Time), + dialFail: make(map[string]time.Time), + pendingRSTs: make(map[[frame.ClientIDLen]byte][]*frame.Frame), + pendingCtrl: make(map[[frame.ClientIDLen]byte][]*frame.Frame), + activity: make(map[[frame.ClientIDLen]byte]chan struct{}), } s.upstreamReadPool.New = func() interface{} { buf := make([]byte, upstreamReadBuf) @@ -695,7 +708,14 @@ func (s *Server) drainAll(owner [frame.ClientIDLen]byte, byteBudget int) ([]*fra if remaining < perSessionCap { perSessionCap = remaining } - frames := sess.DrainTxLimited(MaxFramePayload, perSessionCap) + maxPayload := MaxFramePayload + if _, isFirst := s.firstReply[id]; isFirst && perSessionCap > 0 { + firstPayload := (s.initialResponseBytesPreEncode + perSessionCap - 1) / perSessionCap + if firstPayload > 0 && firstPayload < maxPayload { + maxPayload = firstPayload + } + } + frames := sess.DrainTxLimited(maxPayload, perSessionCap) // Only clear from txReady when fully drained. A partial drain (cap // hit before all data + a trailing FIN could be emitted) needs to // stay queued, otherwise the session is stranded with no path back diff --git a/internal/exit/exit_timing_test.go b/internal/exit/exit_timing_test.go index 57060d6..9432dc8 100644 --- a/internal/exit/exit_timing_test.go +++ b/internal/exit/exit_timing_test.go @@ -524,6 +524,56 @@ func TestDrainAll_RespectsByteBudget(t *testing.T) { } } +func TestDrainAll_CapsInitialResponseOnly(t *testing.T) { + s := mustExitTimingServer(t) + id := benchSessionID(777) + var owner [frame.ClientIDLen]byte + owner[0] = 0x77 + + payload := bytes.Repeat([]byte("x"), maxResponseBytesPreEncode) + sess := session.New(id, "x:1", false) + sess.EnqueueTx(payload) + s.sessions[id] = sess + s.sessionOwners[id] = owner + s.firstReply[id] = struct{}{} + s.txReady[id] = struct{}{} + + frames, urgent := s.drainAll(owner, maxResponseBytesPreEncode) + if !urgent { + t.Fatal("first downstream response should be urgent") + } + firstBytes := sumFramePayloadBytes(frames) + if firstBytes == 0 { + t.Fatal("first drain returned no payload") + } + if firstBytes > s.initialResponseBytesPreEncode { + t.Fatalf("first response bytes = %d, want <= %d", firstBytes, s.initialResponseBytesPreEncode) + } + if _, stillFirst := s.firstReply[id]; stillFirst { + t.Fatal("firstReply marker was not cleared after first downstream drain") + } + if !sess.HasPendingTx() { + t.Fatal("test setup expected payload to remain after capped first response") + } + + frames, urgent = s.drainAll(owner, maxResponseBytesPreEncode) + if urgent { + t.Fatal("second downstream response should not be urgent after firstReply is cleared") + } + secondBytes := sumFramePayloadBytes(frames) + if secondBytes <= s.initialResponseBytesPreEncode { + t.Fatalf("second response bytes = %d, want normal larger drain above %d", secondBytes, s.initialResponseBytesPreEncode) + } +} + +func sumFramePayloadBytes(frames []*frame.Frame) int { + var total int + for _, f := range frames { + total += len(f.Payload) + } + return total +} + // BenchmarkExitRouteIncoming_NSessions measures the cost of routing a data // frame to one of N already-open sessions on the server. This surfaces any // regression in lock contention or per-frame routing work as session fan-out