From 8018915948bb51b09f02053537c8c27e2bc01b1e Mon Sep 17 00:00:00 2001 From: Drew Malin Date: Tue, 3 Mar 2026 18:16:46 -0800 Subject: [PATCH 1/5] backoff on ssh attempt during registration --- go.mod | 1 + go.sum | 2 ++ pkg/cmd/enablessh/enablessh_test.go | 26 +++++++++++------------ pkg/cmd/register/register.go | 31 ++++++++++++++++++++++------ pkg/cmd/register/sshkeys.go | 32 +++++++++++++++++++---------- 5 files changed, 62 insertions(+), 30 deletions(-) diff --git a/go.mod b/go.mod index 982f8daf..1499fe46 100644 --- a/go.mod +++ b/go.mod @@ -57,6 +57,7 @@ require ( github.com/blang/semver/v4 v4.0.0 // indirect github.com/bytedance/sonic v1.11.6 // indirect github.com/bytedance/sonic/loader v0.1.1 // indirect + github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cloudflare/circl v1.6.1 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect diff --git a/go.sum b/go.sum index 78a1436c..d9bb3161 100644 --- a/go.sum +++ b/go.sum @@ -73,6 +73,8 @@ github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/chzyer/logex v1.1.10 h1:Swpa1K6QvQznwJRcfTfQJmTE72DqScAa40E+fbHEXEE= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= diff --git a/pkg/cmd/enablessh/enablessh_test.go b/pkg/cmd/enablessh/enablessh_test.go index ba5ec30f..8c26b88f 100644 --- a/pkg/cmd/enablessh/enablessh_test.go +++ b/pkg/cmd/enablessh/enablessh_test.go @@ -31,7 +31,7 @@ func readAuthorizedKeys(t *testing.T, u *user.User) string { func Test_InstallAuthorizedKey_TagsKeyWithBrevComment(t *testing.T) { u := tempUser(t) - if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { t.Fatalf("InstallAuthorizedKey: %v", err) } @@ -44,10 +44,10 @@ func Test_InstallAuthorizedKey_TagsKeyWithBrevComment(t *testing.T) { func Test_InstallAuthorizedKey_SkipsDuplicate(t *testing.T) { u := tempUser(t) - if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { t.Fatalf("first install: %v", err) } - if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { t.Fatalf("second install: %v", err) } @@ -70,7 +70,7 @@ func Test_InstallAuthorizedKey_SkipsDuplicateEvenIfAlreadyTagged(t *testing.T) { t.Fatal(err) } - if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { t.Fatalf("InstallAuthorizedKey: %v", err) } @@ -84,10 +84,10 @@ func Test_InstallAuthorizedKey_SkipsDuplicateEvenIfAlreadyTagged(t *testing.T) { func Test_InstallAuthorizedKey_EmptyKeyIsNoop(t *testing.T) { u := tempUser(t) - if err := register.InstallAuthorizedKey(u, ""); err != nil { + if _, err := register.InstallAuthorizedKey(u, ""); err != nil { t.Fatalf("InstallAuthorizedKey: %v", err) } - if err := register.InstallAuthorizedKey(u, " "); err != nil { + if _, err := register.InstallAuthorizedKey(u, " "); err != nil { t.Fatalf("InstallAuthorizedKey (whitespace): %v", err) } @@ -101,7 +101,7 @@ func Test_InstallAuthorizedKey_EmptyKeyIsNoop(t *testing.T) { func Test_InstallAuthorizedKey_CreatesSSHDir(t *testing.T) { u := tempUser(t) - if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { t.Fatalf("InstallAuthorizedKey: %v", err) } @@ -126,7 +126,7 @@ func Test_InstallAuthorizedKey_PreservesExistingKeys(t *testing.T) { t.Fatal(err) } - if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { t.Fatalf("InstallAuthorizedKey: %v", err) } @@ -152,7 +152,7 @@ func Test_InstallAuthorizedKey_TagsExistingUntaggedKey(t *testing.T) { } // InstallAuthorizedKey should tag the existing key rather than adding a duplicate. - if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { t.Fatalf("InstallAuthorizedKey: %v", err) } @@ -366,10 +366,10 @@ func Test_InstallThenRemove_RoundTrip(t *testing.T) { } // Install two brev keys. - if err := register.InstallAuthorizedKey(u, "ssh-rsa KEY1"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa KEY1"); err != nil { t.Fatal(err) } - if err := register.InstallAuthorizedKey(u, "ssh-rsa KEY2"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa KEY2"); err != nil { t.Fatal(err) } @@ -393,10 +393,10 @@ func Test_InstallThenRemoveSpecificKey_RollbackScenario(t *testing.T) { u := tempUser(t) // Install two brev keys (simulating two users granted access). - if err := register.InstallAuthorizedKey(u, "ssh-rsa ALICE"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa ALICE"); err != nil { t.Fatal(err) } - if err := register.InstallAuthorizedKey(u, "ssh-rsa BOB"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa BOB"); err != nil { t.Fatal(err) } diff --git a/pkg/cmd/register/register.go b/pkg/cmd/register/register.go index e56c0e6c..69a1f99f 100644 --- a/pkg/cmd/register/register.go +++ b/pkg/cmd/register/register.go @@ -12,6 +12,7 @@ import ( nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" "connectrpc.com/connect" + "github.com/cenkalti/backoff/v4" "github.com/google/uuid" "github.com/brevdev/brev-cli/pkg/config" @@ -339,16 +340,34 @@ func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps t.Vprintf(" Linux user: %s\n", osUser.Username) t.Vprint("") - err := GrantSSHAccessToNode(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, osUser) - if err != nil { - t.Vprint(" Retrying in 3 seconds...") - time.Sleep(3 * time.Second) - err = GrantSSHAccessToNode(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, osUser) + op := func() error { + return GrantSSHAccessToNode(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, osUser) + } + b := backoff.WithContext(newBackoff(), ctx) + notify := func(err error, d time.Duration) { + t.Vprintf(" SSH access not yet granted; retrying in: %s...\n", d.Round(backoffPrintRound)) } + err := backoff.RetryNotify(op, b, notify) if err != nil { - t.Vprintf(" Warning: %v\n", err) + t.Vprintf(" Warning: SSH access not granted: %v\n", err) return } t.Vprint(t.Green(fmt.Sprintf("SSH access enabled. You can now SSH to this device via: brev shell %s", reg.DisplayName))) } + +const ( + backoffInitialInterval = 1 * time.Second + backoffMaxInterval = 10 * time.Second + backoffMaxElapsedTime = 1 * time.Minute + + backoffPrintRound = 500 * time.Millisecond +) + +func newBackoff() *backoff.ExponentialBackOff { + return backoff.NewExponentialBackOff( + backoff.WithInitialInterval(backoffInitialInterval), + backoff.WithMaxInterval(backoffMaxInterval), + backoff.WithMaxElapsedTime(backoffMaxElapsedTime), + ) +} diff --git a/pkg/cmd/register/sshkeys.go b/pkg/cmd/register/sshkeys.go index db36f01b..bef4fb34 100644 --- a/pkg/cmd/register/sshkeys.go +++ b/pkg/cmd/register/sshkeys.go @@ -2,6 +2,7 @@ package register import ( "context" + "errors" "fmt" "os" "os/user" @@ -34,9 +35,9 @@ func GrantSSHAccessToNode( osUser *user.User, ) error { if targetUser.PublicKey != "" { - if err := InstallAuthorizedKey(osUser, targetUser.PublicKey); err != nil { + if added, err := InstallAuthorizedKey(osUser, targetUser.PublicKey); err != nil { t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to install SSH public key: %v", err))) - } else { + } else if added { t.Vprint(" Brev public key added to authorized_keys.") } } @@ -48,6 +49,14 @@ func GrantSSHAccessToNode( LinuxUser: osUser.Username, })) if err != nil { + // Transport errors (connection reset, EOF) are transient — leave the key + // installed so retries don't need to reinstall it, and signal the caller + // with a distinct error type. + var connectErr *connect.Error + if errors.As(err, &connectErr) && connectErr.Code() == connect.CodeInternal { + return fmt.Errorf("failed to grant SSH access (transient): %w", err) + } + // Permanent error — roll back the key so we don't leave an unrecorded entry. if targetUser.PublicKey != "" { if rerr := RemoveAuthorizedKey(osUser, targetUser.PublicKey); rerr != nil { t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to remove SSH key after failed grant: %v", rerr))) @@ -62,28 +71,29 @@ func GrantSSHAccessToNode( // InstallAuthorizedKey appends the given public key to the user's // ~/.ssh/authorized_keys if it isn't already present. The key is tagged with // a brev-cli comment so it can be removed later by RemoveBrevAuthorizedKeys. -func InstallAuthorizedKey(u *user.User, pubKey string) error { +// Returns true if the key was newly written, false if it was already present. +func InstallAuthorizedKey(u *user.User, pubKey string) (bool, error) { pubKey = strings.TrimSpace(pubKey) if pubKey == "" { - return nil + return false, nil } sshDir := filepath.Join(u.HomeDir, ".ssh") if err := os.MkdirAll(sshDir, 0o700); err != nil { - return fmt.Errorf("creating .ssh directory: %w", err) + return false, fmt.Errorf("creating .ssh directory: %w", err) } authKeysPath := filepath.Join(sshDir, "authorized_keys") existing, err := os.ReadFile(authKeysPath) // #nosec G304 if err != nil && !os.IsNotExist(err) { - return fmt.Errorf("reading authorized_keys: %w", err) + return false, fmt.Errorf("reading authorized_keys: %w", err) } taggedKey := pubKey + " " + BrevKeyComment if strings.Contains(string(existing), taggedKey) { - return nil // already present with tag + return false, nil // already present with tag } // If the key exists but isn't tagged, replace it with the tagged version @@ -91,9 +101,9 @@ func InstallAuthorizedKey(u *user.User, pubKey string) error { if strings.Contains(string(existing), pubKey) { updated := strings.ReplaceAll(string(existing), pubKey, taggedKey) if err := os.WriteFile(authKeysPath, []byte(updated), 0o600); err != nil { - return fmt.Errorf("writing authorized_keys: %w", err) + return false, fmt.Errorf("writing authorized_keys: %w", err) } - return nil + return false, nil } // Ensure existing content ends with a newline before appending. @@ -104,10 +114,10 @@ func InstallAuthorizedKey(u *user.User, pubKey string) error { content += taggedKey + "\n" if err := os.WriteFile(authKeysPath, []byte(content), 0o600); err != nil { - return fmt.Errorf("writing authorized_keys: %w", err) + return false, fmt.Errorf("writing authorized_keys: %w", err) } - return nil + return true, nil } // RemoveAuthorizedKey removes a specific public key from the user's From d109761af9cd212595f775706555596f4b37405f Mon Sep 17 00:00:00 2001 From: Drew Malin Date: Wed, 4 Mar 2026 09:01:55 -0800 Subject: [PATCH 2/5] retry cleanup --- pkg/cmd/register/register.go | 45 +++++++++++++++++++----------------- pkg/cmd/register/sshkeys.go | 17 ++++++++++++-- 2 files changed, 39 insertions(+), 23 deletions(-) diff --git a/pkg/cmd/register/register.go b/pkg/cmd/register/register.go index 69a1f99f..ed0477b4 100644 --- a/pkg/cmd/register/register.go +++ b/pkg/cmd/register/register.go @@ -24,6 +24,14 @@ import ( "github.com/spf13/cobra" ) +const ( + backoffInitialInterval = 1 * time.Second + backoffMaxInterval = 10 * time.Second + backoffMaxElapsedTime = 1 * time.Minute + + backoffPrintRound = 500 * time.Millisecond +) + // RegisterStore defines the store methods needed by the register command. type RegisterStore interface { GetCurrentUser() (*entity.User, error) @@ -340,14 +348,25 @@ func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps t.Vprintf(" Linux user: %s\n", osUser.Username) t.Vprint("") - op := func() error { - return GrantSSHAccessToNode(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, osUser) + backoffCtx := backoff.WithContext(backoff.NewExponentialBackOff( + backoff.WithInitialInterval(backoffInitialInterval), + backoff.WithMaxInterval(backoffMaxInterval), + backoff.WithMaxElapsedTime(backoffMaxElapsedTime), + ), ctx) + + opToTry := func() error { + err := GrantSSHAccessToNode(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, osUser) + if err != nil && !IsSSHConnectionError(err) { + return backoff.Permanent(err) + } + return err } - b := backoff.WithContext(newBackoff(), ctx) - notify := func(err error, d time.Duration) { + onOpErr := func(err error, d time.Duration) { t.Vprintf(" SSH access not yet granted; retrying in: %s...\n", d.Round(backoffPrintRound)) } - err := backoff.RetryNotify(op, b, notify) + + // Retry until the operation succeeds or the context is cancelled. + err := backoff.RetryNotify(opToTry, backoffCtx, onOpErr) if err != nil { t.Vprintf(" Warning: SSH access not granted: %v\n", err) return @@ -355,19 +374,3 @@ func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps t.Vprint(t.Green(fmt.Sprintf("SSH access enabled. You can now SSH to this device via: brev shell %s", reg.DisplayName))) } - -const ( - backoffInitialInterval = 1 * time.Second - backoffMaxInterval = 10 * time.Second - backoffMaxElapsedTime = 1 * time.Minute - - backoffPrintRound = 500 * time.Millisecond -) - -func newBackoff() *backoff.ExponentialBackOff { - return backoff.NewExponentialBackOff( - backoff.WithInitialInterval(backoffInitialInterval), - backoff.WithMaxInterval(backoffMaxInterval), - backoff.WithMaxElapsedTime(backoffMaxElapsedTime), - ) -} diff --git a/pkg/cmd/register/sshkeys.go b/pkg/cmd/register/sshkeys.go index bef4fb34..08ef5b84 100644 --- a/pkg/cmd/register/sshkeys.go +++ b/pkg/cmd/register/sshkeys.go @@ -18,6 +18,19 @@ import ( "github.com/brevdev/brev-cli/pkg/terminal" ) +// sshConnectionError marks an error as being due to a transient connection/transport failure +type sshConnectionError struct{ err error } + +func (e *sshConnectionError) Error() string { return e.err.Error() } +func (e *sshConnectionError) Unwrap() error { return e.err } + +// IsSSHConnectionError reports whether err indicates a transient connection/transport +// failure that may be retried. Used by grantSSHAccess to decide whether to backoff-retry. +func IsSSHConnectionError(err error) bool { + var e *sshConnectionError + return errors.As(err, &e) +} + // BrevKeyComment is the marker appended to every SSH key that Brev installs. // It allows RemoveBrevAuthorizedKeys to identify and remove exactly those keys. const BrevKeyComment = "# brev-cli" @@ -54,7 +67,7 @@ func GrantSSHAccessToNode( // with a distinct error type. var connectErr *connect.Error if errors.As(err, &connectErr) && connectErr.Code() == connect.CodeInternal { - return fmt.Errorf("failed to grant SSH access (transient): %w", err) + return &sshConnectionError{err: fmt.Errorf("failed to grant SSH access (transient): %w", err)} } // Permanent error — roll back the key so we don't leave an unrecorded entry. if targetUser.PublicKey != "" { @@ -103,7 +116,7 @@ func InstallAuthorizedKey(u *user.User, pubKey string) (bool, error) { if err := os.WriteFile(authKeysPath, []byte(updated), 0o600); err != nil { return false, fmt.Errorf("writing authorized_keys: %w", err) } - return false, nil + return true, nil } // Ensure existing content ends with a newline before appending. From edf020d4131b31ad70a28456e767f36d4281bab0 Mon Sep 17 00:00:00 2001 From: Drew Malin Date: Wed, 4 Mar 2026 09:24:41 -0800 Subject: [PATCH 3/5] retry error --- pkg/cmd/register/register_test.go | 122 +++++++++++++++++++++++++++++ pkg/cmd/register/rpcclient_test.go | 18 ++++- 2 files changed, 137 insertions(+), 3 deletions(-) diff --git a/pkg/cmd/register/register_test.go b/pkg/cmd/register/register_test.go index 5aad1b80..59f8726c 100644 --- a/pkg/cmd/register/register_test.go +++ b/pkg/cmd/register/register_test.go @@ -499,3 +499,125 @@ Peers count: 0/0 Connected` }) } } + +func TestIsSSHConnectionError(t *testing.T) { + t.Run("nil", func(t *testing.T) { + if IsSSHConnectionError(nil) { + t.Error("IsSSHConnectionError(nil) should be false") + } + }) + t.Run("plain_error", func(t *testing.T) { + if IsSSHConnectionError(fmt.Errorf("some error")) { + t.Error("IsSSHConnectionError(plain error) should be false") + } + }) + t.Run("connection_error_type", func(t *testing.T) { + err := &sshConnectionError{err: fmt.Errorf("transient")} + if !IsSSHConnectionError(err) { + t.Error("IsSSHConnectionError(sshConnectionError) should be true") + } + }) + t.Run("wrapped_connection_error", func(t *testing.T) { + err := fmt.Errorf("wrapped: %w", &sshConnectionError{err: fmt.Errorf("transient")}) + if !IsSSHConnectionError(err) { + t.Error("IsSSHConnectionError(wrapped sshConnectionError) should be true") + } + }) +} + +func Test_runRegister_GrantSSH_retries_on_connection_error_then_succeeds(t *testing.T) { + regStore := &mockRegistrationStore{} + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + home: "/home/testuser/.brev", + token: "tok", + } + + var grantCalls int + svc := &fakeNodeService{ + addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) { + return &nodev1.AddNodeResponse{ + ExternalNode: &nodev1.ExternalNode{ + ExternalNodeId: "unode_abc", + OrganizationId: "org_123", + Name: req.GetName(), + DeviceId: req.GetDeviceId(), + ConnectivityInfo: &nodev1.ConnectivityInfo{ + RegistrationCommand: "netbird up --key abc", + }, + }, + }, nil + }, + grantNodeSSHAccessFn: func(_ *nodev1.GrantNodeSSHAccessRequest) (*nodev1.GrantNodeSSHAccessResponse, error) { + grantCalls++ + if grantCalls < 2 { + return nil, connect.NewError(connect.CodeInternal, nil) + } + return &nodev1.GrantNodeSSHAccessResponse{}, nil + }, + } + + deps, server := testRegisterDeps(t, svc, regStore) + defer server.Close() + + deps.prompter = mockConfirmer{confirm: true} + + term := terminal.New() + err := runRegister(context.Background(), term, store, "My Spark", deps) + if err != nil { + t.Fatalf("runRegister failed: %v", err) + } + + if grantCalls != 2 { + t.Errorf("expected GrantNodeSSHAccess to be called 2 times (retry once), got %d", grantCalls) + } +} + +func Test_runRegister_GrantSSH_no_retry_on_permanent_error(t *testing.T) { + regStore := &mockRegistrationStore{} + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + home: "/home/testuser/.brev", + token: "tok", + } + + var grantCalls int + svc := &fakeNodeService{ + addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) { + return &nodev1.AddNodeResponse{ + ExternalNode: &nodev1.ExternalNode{ + ExternalNodeId: "unode_abc", + OrganizationId: "org_123", + Name: req.GetName(), + DeviceId: req.GetDeviceId(), + ConnectivityInfo: &nodev1.ConnectivityInfo{ + RegistrationCommand: "netbird up --key abc", + }, + }, + }, nil + }, + grantNodeSSHAccessFn: func(_ *nodev1.GrantNodeSSHAccessRequest) (*nodev1.GrantNodeSSHAccessResponse, error) { + grantCalls++ + return nil, connect.NewError(connect.CodePermissionDenied, nil) + }, + } + + deps, server := testRegisterDeps(t, svc, regStore) + defer server.Close() + + deps.prompter = mockConfirmer{confirm: true} + + term := terminal.New() + err := runRegister(context.Background(), term, store, "My Spark", deps) + if err != nil { + t.Fatalf("runRegister should not fail the overall flow when SSH grant fails: %v", err) + } + + if grantCalls != 1 { + t.Errorf("expected GrantNodeSSHAccess to be called once (no retry on permanent error), got %d", grantCalls) + } +} diff --git a/pkg/cmd/register/rpcclient_test.go b/pkg/cmd/register/rpcclient_test.go index 69ea4bb2..ce45c079 100644 --- a/pkg/cmd/register/rpcclient_test.go +++ b/pkg/cmd/register/rpcclient_test.go @@ -143,9 +143,10 @@ func Test_toProtoNodeSpec_MinimalFields(t *testing.T) { // fakeNodeService implements the server side of ExternalNodeService for testing. type fakeNodeService struct { nodev1connect.UnimplementedExternalNodeServiceHandler - addNodeFn func(*nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) - removeNodeFn func(*nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) - getNodeFn func(*nodev1.GetNodeRequest) (*nodev1.GetNodeResponse, error) + addNodeFn func(*nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) + removeNodeFn func(*nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) + getNodeFn func(*nodev1.GetNodeRequest) (*nodev1.GetNodeResponse, error) + grantNodeSSHAccessFn func(*nodev1.GrantNodeSSHAccessRequest) (*nodev1.GrantNodeSSHAccessResponse, error) } func (f *fakeNodeService) AddNode(_ context.Context, req *connect.Request[nodev1.AddNodeRequest]) (*connect.Response[nodev1.AddNodeResponse], error) { @@ -175,6 +176,17 @@ func (f *fakeNodeService) GetNode(_ context.Context, req *connect.Request[nodev1 return connect.NewResponse(resp), nil } +func (f *fakeNodeService) GrantNodeSSHAccess(_ context.Context, req *connect.Request[nodev1.GrantNodeSSHAccessRequest]) (*connect.Response[nodev1.GrantNodeSSHAccessResponse], error) { + if f.grantNodeSSHAccessFn == nil { + return nil, connect.NewError(connect.CodeUnimplemented, nil) + } + resp, err := f.grantNodeSSHAccessFn(req.Msg) + if err != nil { + return nil, err + } + return connect.NewResponse(resp), nil +} + func Test_NewNodeServiceClient_AddNode(t *testing.T) { svc := &fakeNodeService{ addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) { From ae75746e95389ef9cf8cde15d9946145f6a1f7f0 Mon Sep 17 00:00:00 2001 From: Drew Malin Date: Wed, 4 Mar 2026 09:39:24 -0800 Subject: [PATCH 4/5] lint, mod tidy, fmt --- go.mod | 2 +- pkg/cmd/register/register.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index 1499fe46..8b44aa62 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/alessio/shellescape v1.4.1 github.com/brevdev/parse v0.0.11 github.com/briandowns/spinner v1.16.0 + github.com/cenkalti/backoff/v4 v4.3.0 github.com/fatih/color v1.13.0 github.com/getsentry/sentry-go v0.14.0 github.com/gin-gonic/gin v1.10.0 @@ -57,7 +58,6 @@ require ( github.com/blang/semver/v4 v4.0.0 // indirect github.com/bytedance/sonic v1.11.6 // indirect github.com/bytedance/sonic/loader v0.1.1 // indirect - github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cloudflare/circl v1.6.1 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect diff --git a/pkg/cmd/register/register.go b/pkg/cmd/register/register.go index ed0477b4..87f738dc 100644 --- a/pkg/cmd/register/register.go +++ b/pkg/cmd/register/register.go @@ -365,7 +365,7 @@ func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps t.Vprintf(" SSH access not yet granted; retrying in: %s...\n", d.Round(backoffPrintRound)) } - // Retry until the operation succeeds or the context is cancelled. + // Retry until the operation succeeds or the context is canceled. err := backoff.RetryNotify(opToTry, backoffCtx, onOpErr) if err != nil { t.Vprintf(" Warning: SSH access not granted: %v\n", err) From fef15d907267cd36d8a50f6604cf93ce0347a16f Mon Sep 17 00:00:00 2001 From: Drew Malin Date: Wed, 4 Mar 2026 10:03:44 -0800 Subject: [PATCH 5/5] cleanup --- pkg/cmd/register/register.go | 39 +++-------------- pkg/cmd/register/register_test.go | 25 ----------- pkg/cmd/register/sshkeys.go | 71 ++++++++++++++++++------------- 3 files changed, 49 insertions(+), 86 deletions(-) diff --git a/pkg/cmd/register/register.go b/pkg/cmd/register/register.go index 87f738dc..5049ff3f 100644 --- a/pkg/cmd/register/register.go +++ b/pkg/cmd/register/register.go @@ -12,7 +12,6 @@ import ( nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" "connectrpc.com/connect" - "github.com/cenkalti/backoff/v4" "github.com/google/uuid" "github.com/brevdev/brev-cli/pkg/config" @@ -24,14 +23,6 @@ import ( "github.com/spf13/cobra" ) -const ( - backoffInitialInterval = 1 * time.Second - backoffMaxInterval = 10 * time.Second - backoffMaxElapsedTime = 1 * time.Minute - - backoffPrintRound = 500 * time.Millisecond -) - // RegisterStore defines the store methods needed by the register command. type RegisterStore interface { GetCurrentUser() (*entity.User, error) @@ -217,7 +208,9 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam runSetup(node, t, deps) if deps.prompter.ConfirmYesNo("Would you like to enable SSH access to this device?") { - grantSSHAccess(ctx, t, deps, s, reg, brevUser, osUser) + if err := grantSSHAccess(ctx, t, deps, s, reg, brevUser, osUser); err != nil { + t.Vprintf(" Warning: SSH access not granted: %v\n", err) + } } return nil @@ -339,7 +332,7 @@ func runSetup(node *nodev1.ExternalNode, t *terminal.Terminal, deps registerDeps } } -func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps, tokenProvider externalnode.TokenProvider, reg *DeviceRegistration, brevUser *entity.User, osUser *user.User) { +func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps, tokenProvider externalnode.TokenProvider, reg *DeviceRegistration, brevUser *entity.User, osUser *user.User) error { t.Vprint("") t.Vprint(t.Green("Enabling SSH access on this device")) t.Vprint("") @@ -348,29 +341,11 @@ func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps t.Vprintf(" Linux user: %s\n", osUser.Username) t.Vprint("") - backoffCtx := backoff.WithContext(backoff.NewExponentialBackOff( - backoff.WithInitialInterval(backoffInitialInterval), - backoff.WithMaxInterval(backoffMaxInterval), - backoff.WithMaxElapsedTime(backoffMaxElapsedTime), - ), ctx) - - opToTry := func() error { - err := GrantSSHAccessToNode(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, osUser) - if err != nil && !IsSSHConnectionError(err) { - return backoff.Permanent(err) - } - return err - } - onOpErr := func(err error, d time.Duration) { - t.Vprintf(" SSH access not yet granted; retrying in: %s...\n", d.Round(backoffPrintRound)) - } - - // Retry until the operation succeeds or the context is canceled. - err := backoff.RetryNotify(opToTry, backoffCtx, onOpErr) + err := GrantSSHAccessToNode(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, osUser) if err != nil { - t.Vprintf(" Warning: SSH access not granted: %v\n", err) - return + return fmt.Errorf("grant SSH failed: %w", err) } t.Vprint(t.Green(fmt.Sprintf("SSH access enabled. You can now SSH to this device via: brev shell %s", reg.DisplayName))) + return nil } diff --git a/pkg/cmd/register/register_test.go b/pkg/cmd/register/register_test.go index 59f8726c..450ae720 100644 --- a/pkg/cmd/register/register_test.go +++ b/pkg/cmd/register/register_test.go @@ -500,31 +500,6 @@ Peers count: 0/0 Connected` } } -func TestIsSSHConnectionError(t *testing.T) { - t.Run("nil", func(t *testing.T) { - if IsSSHConnectionError(nil) { - t.Error("IsSSHConnectionError(nil) should be false") - } - }) - t.Run("plain_error", func(t *testing.T) { - if IsSSHConnectionError(fmt.Errorf("some error")) { - t.Error("IsSSHConnectionError(plain error) should be false") - } - }) - t.Run("connection_error_type", func(t *testing.T) { - err := &sshConnectionError{err: fmt.Errorf("transient")} - if !IsSSHConnectionError(err) { - t.Error("IsSSHConnectionError(sshConnectionError) should be true") - } - }) - t.Run("wrapped_connection_error", func(t *testing.T) { - err := fmt.Errorf("wrapped: %w", &sshConnectionError{err: fmt.Errorf("transient")}) - if !IsSSHConnectionError(err) { - t.Error("IsSSHConnectionError(wrapped sshConnectionError) should be true") - } - }) -} - func Test_runRegister_GrantSSH_retries_on_connection_error_then_succeeds(t *testing.T) { regStore := &mockRegistrationStore{} diff --git a/pkg/cmd/register/sshkeys.go b/pkg/cmd/register/sshkeys.go index 08ef5b84..aad47e95 100644 --- a/pkg/cmd/register/sshkeys.go +++ b/pkg/cmd/register/sshkeys.go @@ -8,6 +8,7 @@ import ( "os/user" "path/filepath" "strings" + "time" nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" "connectrpc.com/connect" @@ -16,20 +17,16 @@ import ( "github.com/brevdev/brev-cli/pkg/entity" "github.com/brevdev/brev-cli/pkg/externalnode" "github.com/brevdev/brev-cli/pkg/terminal" + "github.com/cenkalti/backoff/v4" ) -// sshConnectionError marks an error as being due to a transient connection/transport failure -type sshConnectionError struct{ err error } +const ( + backoffInitialInterval = 1 * time.Second + backoffMaxInterval = 10 * time.Second + backoffMaxElapsedTime = 1 * time.Minute -func (e *sshConnectionError) Error() string { return e.err.Error() } -func (e *sshConnectionError) Unwrap() error { return e.err } - -// IsSSHConnectionError reports whether err indicates a transient connection/transport -// failure that may be retried. Used by grantSSHAccess to decide whether to backoff-retry. -func IsSSHConnectionError(err error) bool { - var e *sshConnectionError - return errors.As(err, &e) -} + backoffPrintRound = 500 * time.Millisecond +) // BrevKeyComment is the marker appended to every SSH key that Brev installs. // It allows RemoveBrevAuthorizedKeys to identify and remove exactly those keys. @@ -56,28 +53,44 @@ func GrantSSHAccessToNode( } client := nodeClients.NewNodeClient(tokenProvider, config.GlobalConfig.GetBrevPublicAPIURL()) - _, err := client.GrantNodeSSHAccess(ctx, connect.NewRequest(&nodev1.GrantNodeSSHAccessRequest{ - ExternalNodeId: reg.ExternalNodeID, - UserId: targetUser.ID, - LinuxUser: osUser.Username, - })) - if err != nil { - // Transport errors (connection reset, EOF) are transient — leave the key - // installed so retries don't need to reinstall it, and signal the caller - // with a distinct error type. - var connectErr *connect.Error - if errors.As(err, &connectErr) && connectErr.Code() == connect.CodeInternal { - return &sshConnectionError{err: fmt.Errorf("failed to grant SSH access (transient): %w", err)} - } - // Permanent error — roll back the key so we don't leave an unrecorded entry. - if targetUser.PublicKey != "" { - if rerr := RemoveAuthorizedKey(osUser, targetUser.PublicKey); rerr != nil { - t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to remove SSH key after failed grant: %v", rerr))) + + backoffCtx := backoff.WithContext(backoff.NewExponentialBackOff( + backoff.WithInitialInterval(backoffInitialInterval), + backoff.WithMaxInterval(backoffMaxInterval), + backoff.WithMaxElapsedTime(backoffMaxElapsedTime), + ), ctx) + + opToTry := func() error { + _, err := client.GrantNodeSSHAccess(ctx, connect.NewRequest(&nodev1.GrantNodeSSHAccessRequest{ + ExternalNodeId: reg.ExternalNodeID, + UserId: targetUser.ID, + LinuxUser: osUser.Username, + })) + if err != nil { + // Retryable error + var connectErr *connect.Error + if errors.As(err, &connectErr) && connectErr.Code() == connect.CodeInternal { + return fmt.Errorf("failed to grant SSH access (transient): %w", err) + } + + // Permanent error — roll back the key so we don't leave an unrecorded entry and abort the backoff retry + if targetUser.PublicKey != "" { + if rerr := RemoveAuthorizedKey(osUser, targetUser.PublicKey); rerr != nil { + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to remove SSH key after failed grant: %v", rerr))) + } } + return backoff.Permanent(fmt.Errorf("failed to grant SSH access: %w", err)) } + + return nil + } + onOpErr := func(err error, d time.Duration) { + t.Vprintf(" SSH access not yet granted; retrying in: %s...\n", d.Round(backoffPrintRound)) + } + err := backoff.RetryNotify(opToTry, backoffCtx, onOpErr) + if err != nil { return fmt.Errorf("failed to grant SSH access: %w", err) } - return nil }