diff --git a/internal/carrier/client.go b/internal/carrier/client.go index b719a4e..58a7425 100644 --- a/internal/carrier/client.go +++ b/internal/carrier/client.go @@ -4,14 +4,18 @@ import ( "bytes" "context" "crypto/rand" + "errors" "fmt" "io" "log" + "net" "net/http" + "os" "sort" "strings" "sync" "sync/atomic" + "syscall" "time" "github.com/kianmhz/GooseRelayVPN/internal/frame" @@ -65,8 +69,87 @@ const ( // or tail-latency events without changing protocol behavior. endpointBlacklistBaseTTL = 3 * time.Second endpointBlacklistMaxTTL = 1 * time.Hour + + // Local offline failures should not ramp a mobile client into the 30m/1h + // endpoint penalty box. Keep the pause long enough to avoid a tight retry + // loop while airplane mode is on, but short enough that new sessions recover + // quickly when the network returns. + localNetworkOfflineBlacklistTTL = 15 * time.Second + localNetworkRecoveryProbeEvery = 5 * time.Second + localNetworkRecoveryProbeTO = 2 * time.Second ) +func isLocalNetworkOffline(err error) bool { + if err == nil { + return false + } + var dnsErr *net.DNSError + if errors.As(err, &dnsErr) { + if dnsErr.IsTimeout || dnsErr.IsTemporary || dnsErr.IsNotFound { + return true + } + } + var opErr *net.OpError + if errors.As(err, &opErr) && strings.EqualFold(opErr.Op, "dial") { + if opErr.Timeout() || errors.Is(opErr.Err, context.DeadlineExceeded) { + return true + } + } + var syscallErr *os.SyscallError + if errors.As(err, &syscallErr) && isLocalOfflineSyscall(syscallErr.Err) { + return true + } + if isLocalOfflineSyscall(err) { + return true + } + + // Last-resort fallback for platform-specific wrapped messages, especially + // Windows WSA errors whose Errno values do not always compare cleanly after + // net/http wraps them in url.Error/net.OpError. + msg := strings.ToLower(err.Error()) + for _, needle := range []string{ + "network is unreachable", + "unreachable network", + "no route to host", + "network is down", + "host is down", + "host is unreachable", + "temporary failure in name resolution", + "no such host", + } { + if strings.Contains(msg, needle) { + return true + } + } + return false +} + +func isLocalOfflineSyscall(err error) bool { + for _, target := range []error{ + syscall.ENETUNREACH, + syscall.EHOSTUNREACH, + syscall.ENETDOWN, + syscall.EHOSTDOWN, + syscall.ENONET, + } { + if errors.Is(err, target) { + return true + } + } + return false +} + +func recoveryProbeAddress(cfg Config) string { + addr := strings.TrimSpace(cfg.Fronting.GoogleIP) + if addr == "" { + return "" + } + if _, _, err := net.SplitHostPort(addr); err == nil { + return addr + } + return net.JoinHostPort(addr, "443") +} + // Config bundles everything the carrier needs to talk to the relay. type Config struct { ScriptURLs []string // one or more full https://script.google.com/macros/s/.../exec URLs @@ -99,12 +182,13 @@ type Config struct { } type relayEndpoint struct { - url string - account string // optional human-readable Google account label, "" = unlabeled - blacklistedTill time.Time - failCount int - statsOK uint64 - statsFail uint64 + url string + account string // optional human-readable Google account label, "" = unlabeled + blacklistedTill time.Time + localNetworkOffline bool + failCount int + statsOK uint64 + statsFail uint64 // Per-quota-window counters. dailyCount is the number of HTTP responses // received from Apps Script in the current window; dailyResetAt is the @@ -163,7 +247,7 @@ type Client struct { numWorkers int // (workersPerEndpoint + idleSlotsPerBucket - 1) × bucketCount bucketCount int // distinct account labels in endpoints; unlabeled all share one bucket idleSlotsPerBucket int // resolved from Config.IdleSlotsPerBucket, default 1 - clientVersion string + clientVersion string // clientID is a random 16-byte identifier minted once per process. It is // embedded in every encrypted batch so the server can route downstream @@ -197,6 +281,8 @@ type Client struct { coalesceMu sync.Mutex coalesceTimer *time.Timer // armed during a coalesce window; nil otherwise coalesceDeadline time.Time // hard cap for the in-flight window + + recoveryProbeAddr string } // clientStats holds atomic counters surfaced periodically by statsLoop. @@ -305,6 +391,7 @@ func New(cfg Config) (*Client, error) { wake: newWaker(), coalesceStep: cfg.CoalesceStep, coalesceMax: cfg.CoalesceMax, + recoveryProbeAddr: recoveryProbeAddress(cfg), }, nil } @@ -412,6 +499,11 @@ func (c *Client) Run(ctx context.Context) error { defer wg.Done() c.runScriptStatsLoop(ctx) }() + wg.Add(1) + go func() { + defer wg.Done() + c.runEndpointRecoveryLoop(ctx) + }() wg.Wait() return ctx.Err() } @@ -446,6 +538,65 @@ func (c *Client) runWorker(ctx context.Context) { } } +func (c *Client) runEndpointRecoveryLoop(ctx context.Context) { + if c.recoveryProbeAddr == "" { + return + } + t := time.NewTicker(localNetworkRecoveryProbeEvery) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return + case <-t.C: + if c.runEndpointRecoveryProbeOnce(ctx) { + c.wake.Broadcast() + } + } + } +} + +func (c *Client) runEndpointRecoveryProbeOnce(ctx context.Context) bool { + if c.recoveryProbeAddr == "" || !c.shouldRunLocalNetworkRecoveryProbe() { + return false + } + probeCtx, cancel := context.WithTimeout(ctx, localNetworkRecoveryProbeTO) + defer cancel() + dialer := net.Dialer{Timeout: localNetworkRecoveryProbeTO} + conn, err := dialer.DialContext(probeCtx, "tcp", c.recoveryProbeAddr) + if err != nil { + return false + } + _ = conn.Close() + cleared := c.resetLocalNetworkFailures() + if cleared > 0 { + log.Printf("[carrier] local network appears reachable again; cleared %d local-offline endpoint backoff(s)", cleared) + } + return cleared > 0 +} + +func (c *Client) shouldRunLocalNetworkRecoveryProbe() bool { + c.endpointMu.Lock() + defer c.endpointMu.Unlock() + if len(c.endpoints) == 0 { + return false + } + now := time.Now() + allUnavailable := true + hasLocalOffline := false + for i := range c.endpoints { + ep := &c.endpoints[i] + if !ep.blacklistedTill.After(now) { + allUnavailable = false + break + } + if ep.localNetworkOffline && ep.blacklistedTill.After(now) { + hasLocalOffline = true + } + } + return allUnavailable && hasLocalOffline +} + // idleBackoff returns how long a worker should sleep after n consecutive // no-work polls. The wake channel is selected against this timer so any // new TX (kick) cancels the sleep immediately and any held server-side @@ -577,7 +728,11 @@ func (c *Client) pollOnce(ctx context.Context) bool { if ctx.Err() != nil { return false } - c.markEndpointFailure(endpointIdx) + if isLocalNetworkOffline(err) { + c.markEndpointLocalNetworkFailure(endpointIdx) + } else { + c.markEndpointFailure(endpointIdx) + } if attempt < maxAttempts { log.Printf("[carrier] relay request failed via %s (attempt %d/%d): %v; retrying alternate script", shortScriptKey(scriptURL), attempt, maxAttempts, err) continue @@ -745,6 +900,23 @@ func (c *Client) pickRelayEndpoint() (int, string) { return -1, "" } +func (c *Client) resetLocalNetworkFailures() int { + c.endpointMu.Lock() + defer c.endpointMu.Unlock() + cleared := 0 + for i := range c.endpoints { + ep := &c.endpoints[i] + if !ep.localNetworkOffline { + continue + } + ep.blacklistedTill = time.Time{} + ep.failCount = 0 + ep.localNetworkOffline = false + cleared++ + } + return cleared +} + func (c *Client) markEndpointSuccess(endpointIdx int) { c.endpointMu.Lock() if endpointIdx < 0 || endpointIdx >= len(c.endpoints) { @@ -757,6 +929,7 @@ func (c *Client) markEndpointSuccess(endpointIdx int) { url := ep.url ep.failCount = 0 ep.blacklistedTill = time.Time{} + ep.localNetworkOffline = false c.endpointMu.Unlock() if wasFailing { log.Printf("[carrier] endpoint %s recovered (back in rotation)", shortScriptKey(url)) @@ -769,6 +942,27 @@ func (c *Client) markEndpointFailure(endpointIdx int) { c.markEndpointFailureWith(endpointIdx, 0) } +func (c *Client) markEndpointLocalNetworkFailure(endpointIdx int) { + c.endpointMu.Lock() + if endpointIdx < 0 || endpointIdx >= len(c.endpoints) { + c.endpointMu.Unlock() + return + } + ep := &c.endpoints[endpointIdx] + wasHealthy := ep.failCount == 0 && !ep.blacklistedTill.After(time.Now()) + ep.failCount = 0 + ep.statsFail++ + ep.localNetworkOffline = true + ep.blacklistedTill = time.Now().Add(localNetworkOfflineBlacklistTTL) + url := ep.url + peerCount := len(c.endpoints) - 1 + c.endpointMu.Unlock() + if wasHealthy { + log.Printf("[carrier] endpoint %s local network offline; retrying in %s (still rotating across %d others)", + shortScriptKey(url), localNetworkOfflineBlacklistTTL.Round(time.Second), peerCount) + } +} + // markEndpoint403 handles HTTP 403 (quota exhausted or deployment misconfigured). // Quota walls don't self-heal in seconds; they persist until midnight Pacific. // Jump straight to the 5-minute tier (failCount floor = 5 → next hit → 6 → 5 min) @@ -807,6 +1001,7 @@ func (c *Client) markEndpointFailureWith(endpointIdx, minFailCount int) { } ep.failCount++ ep.statsFail++ + ep.localNetworkOffline = false ttl := endpointBlacklistTTL(ep.failCount) ep.blacklistedTill = time.Now().Add(ttl) url := ep.url diff --git a/internal/carrier/client_blacklist_test.go b/internal/carrier/client_blacklist_test.go index 9e2797e..0c5cfdb 100644 --- a/internal/carrier/client_blacklist_test.go +++ b/internal/carrier/client_blacklist_test.go @@ -2,17 +2,29 @@ package carrier import ( "context" + "errors" "io" + "net" "net/http" "net/http/httptest" + "net/url" + "os" + "strings" "sync" "sync/atomic" + "syscall" "testing" "time" "github.com/kianmhz/GooseRelayVPN/internal/frame" ) +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + // TestEndpointFullRecoveryFromHighFailCount: a single successful response must // fully clear failCount and blacklistedTill, regardless of how badly the // endpoint was previously failing. This is the load-bearing invariant for @@ -135,6 +147,163 @@ func TestPickRelayEndpointAllBlacklistedRefuses(t *testing.T) { } } +func TestLocalNetworkOfflineClassificationAndBackoff(t *testing.T) { + wrapped := &url.Error{ + Op: "Post", + URL: "https://script.google.com/macros/s/test/exec", + Err: &net.OpError{ + Op: "dial", + Err: &os.SyscallError{Syscall: "connect", Err: syscall.ENETUNREACH}, + }, + } + if !isLocalNetworkOffline(wrapped) { + t.Fatal("wrapped ENETUNREACH dial error should be classified as local offline") + } + if isLocalNetworkOffline(errors.New("relay returned HTTP 500")) { + t.Fatal("generic relay/server failure must not be classified as local offline") + } + + c, err := New(Config{ + ScriptURLs: []string{"https://example.invalid/exec"}, + AESKeyHex: testKeyHex, + }) + if err != nil { + t.Fatalf("new client: %v", err) + } + c.markEndpointLocalNetworkFailure(0) + + c.endpointMu.Lock() + ep := c.endpoints[0] + c.endpointMu.Unlock() + if ep.failCount != 0 { + t.Fatalf("local offline failCount = %d, want 0 so standard backoff tiers do not ramp", ep.failCount) + } + if !ep.localNetworkOffline { + t.Fatal("localNetworkOffline flag not set") + } + remaining := time.Until(ep.blacklistedTill) + if remaining <= 0 || remaining > localNetworkOfflineBlacklistTTL+2*time.Second { + t.Fatalf("local offline blacklist remaining = %v, want short cap around %v", remaining, localNetworkOfflineBlacklistTTL) + } +} + +func TestRecoveryProbeClearsOnlyLocalNetworkFailures(t *testing.T) { + c, err := New(Config{ + ScriptURLs: []string{ + "https://local-offline.example/exec", + "https://generic-failure.example/exec", + }, + AESKeyHex: testKeyHex, + }) + if err != nil { + t.Fatalf("new client: %v", err) + } + now := time.Now() + c.endpointMu.Lock() + c.endpoints[0].blacklistedTill = now.Add(time.Minute) + c.endpoints[0].localNetworkOffline = true + c.endpoints[1].blacklistedTill = now.Add(time.Minute) + c.endpoints[1].failCount = 7 + c.endpointMu.Unlock() + + cleared := c.resetLocalNetworkFailures() + if cleared != 1 { + t.Fatalf("resetLocalNetworkFailures cleared %d endpoint(s), want 1", cleared) + } + + c.endpointMu.Lock() + first := c.endpoints[0] + second := c.endpoints[1] + c.endpointMu.Unlock() + if !first.blacklistedTill.IsZero() || first.localNetworkOffline { + t.Fatalf("local-offline endpoint was not fully reset: %+v", first) + } + if second.blacklistedTill.IsZero() || second.failCount != 7 || second.localNetworkOffline { + t.Fatalf("generic blacklist should be preserved, got: %+v", second) + } +} + +func TestRecoveryProbeClearsLocalNetworkFailuresWhenNetworkReturns(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + defer ln.Close() + go func() { + conn, err := ln.Accept() + if err == nil { + _ = conn.Close() + } + }() + + c, err := New(Config{ + ScriptURLs: []string{"https://local-offline.example/exec"}, + AESKeyHex: testKeyHex, + }) + if err != nil { + t.Fatalf("new client: %v", err) + } + c.recoveryProbeAddr = ln.Addr().String() + c.endpointMu.Lock() + c.endpoints[0].blacklistedTill = time.Now().Add(time.Minute) + c.endpoints[0].localNetworkOffline = true + c.endpointMu.Unlock() + + if !c.runEndpointRecoveryProbeOnce(context.Background()) { + t.Fatal("recovery probe did not report a successful reset") + } + c.endpointMu.Lock() + ep := c.endpoints[0] + c.endpointMu.Unlock() + if !ep.blacklistedTill.IsZero() || ep.failCount != 0 || ep.localNetworkOffline { + t.Fatalf("local network recovery did not clear transient backoff: %+v", ep) + } +} + +func TestPollOnceMarksOnlyDoErrorsAsLocalNetworkFailures(t *testing.T) { + offlineErr := &net.OpError{ + Op: "dial", + Err: &os.SyscallError{Syscall: "connect", Err: syscall.ENETUNREACH}, + } + c, err := New(Config{ScriptURLs: []string{"http://offline.example/exec"}, AESKeyHex: testKeyHex}) + if err != nil { + t.Fatalf("new client: %v", err) + } + c.httpClients = []*http.Client{{ + Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return nil, offlineErr + }), + }} + c.pollOnce(context.Background()) + c.endpointMu.Lock() + local := c.endpoints[0] + c.endpointMu.Unlock() + if local.failCount != 0 || !local.localNetworkOffline { + t.Fatalf("Do dial error should use local offline backoff, got failCount=%d local=%v", local.failCount, local.localNetworkOffline) + } + + c2, err := New(Config{ScriptURLs: []string{"http://server-error.example/exec"}, AESKeyHex: testKeyHex}) + if err != nil { + t.Fatalf("new client: %v", err) + } + c2.httpClients = []*http.Client{{ + Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: io.NopCloser(strings.NewReader("server error")), + Header: make(http.Header), + }, nil + }), + }} + c2.pollOnce(context.Background()) + c2.endpointMu.Lock() + generic := c2.endpoints[0] + c2.endpointMu.Unlock() + if generic.failCount == 0 || generic.localNetworkOffline { + t.Fatalf("HTTP 500 should use normal endpoint failure, got failCount=%d local=%v", generic.failCount, generic.localNetworkOffline) + } +} + // TestPollOnce_AllBlacklistedSendsNoTraffic: integration check that no HTTP // request goes out when every endpoint is blacklisted. Before the fix, the // carrier kept POSTing to the soonest-expiring endpoint at the idle-backoff @@ -191,13 +360,13 @@ func TestPollOnce_AllBlacklistedSendsNoTraffic(t *testing.T) { // many decoded frames it has seen after the outage ended, which is the signal // for whether the carrier retransmitted dropped frames. type blacklistHammerServer struct { - t *testing.T - aead *frame.Crypto - hits atomic.Int64 - outage atomic.Bool - framesAfterOutage atomic.Int64 - rxSeqMu sync.Mutex - rxSeq map[[frame.SessionIDLen]byte]uint64 + t *testing.T + aead *frame.Crypto + hits atomic.Int64 + outage atomic.Bool + framesAfterOutage atomic.Int64 + rxSeqMu sync.Mutex + rxSeq map[[frame.SessionIDLen]byte]uint64 } func newBlacklistHammerServer(t *testing.T, aead *frame.Crypto) *blacklistHammerServer {