diff --git a/go.mod b/go.mod index 982f8daf..8b44aa62 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/alessio/shellescape v1.4.1 github.com/brevdev/parse v0.0.11 github.com/briandowns/spinner v1.16.0 + github.com/cenkalti/backoff/v4 v4.3.0 github.com/fatih/color v1.13.0 github.com/getsentry/sentry-go v0.14.0 github.com/gin-gonic/gin v1.10.0 diff --git a/go.sum b/go.sum index 78a1436c..d9bb3161 100644 --- a/go.sum +++ b/go.sum @@ -73,6 +73,8 @@ github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/chzyer/logex v1.1.10 h1:Swpa1K6QvQznwJRcfTfQJmTE72DqScAa40E+fbHEXEE= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= diff --git a/pkg/cmd/enablessh/enablessh_test.go b/pkg/cmd/enablessh/enablessh_test.go index ba5ec30f..8c26b88f 100644 --- a/pkg/cmd/enablessh/enablessh_test.go +++ b/pkg/cmd/enablessh/enablessh_test.go @@ -31,7 +31,7 @@ func readAuthorizedKeys(t *testing.T, u *user.User) string { func Test_InstallAuthorizedKey_TagsKeyWithBrevComment(t *testing.T) { u := tempUser(t) - if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { t.Fatalf("InstallAuthorizedKey: %v", err) } @@ -44,10 +44,10 @@ func Test_InstallAuthorizedKey_TagsKeyWithBrevComment(t *testing.T) { func Test_InstallAuthorizedKey_SkipsDuplicate(t *testing.T) { u := tempUser(t) - if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { t.Fatalf("first install: %v", err) } - if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { t.Fatalf("second install: %v", err) } @@ -70,7 +70,7 @@ func Test_InstallAuthorizedKey_SkipsDuplicateEvenIfAlreadyTagged(t *testing.T) { t.Fatal(err) } - if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { t.Fatalf("InstallAuthorizedKey: %v", err) } @@ -84,10 +84,10 @@ func Test_InstallAuthorizedKey_SkipsDuplicateEvenIfAlreadyTagged(t *testing.T) { func Test_InstallAuthorizedKey_EmptyKeyIsNoop(t *testing.T) { u := tempUser(t) - if err := register.InstallAuthorizedKey(u, ""); err != nil { + if _, err := register.InstallAuthorizedKey(u, ""); err != nil { t.Fatalf("InstallAuthorizedKey: %v", err) } - if err := register.InstallAuthorizedKey(u, " "); err != nil { + if _, err := register.InstallAuthorizedKey(u, " "); err != nil { t.Fatalf("InstallAuthorizedKey (whitespace): %v", err) } @@ -101,7 +101,7 @@ func Test_InstallAuthorizedKey_EmptyKeyIsNoop(t *testing.T) { func Test_InstallAuthorizedKey_CreatesSSHDir(t *testing.T) { u := tempUser(t) - if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { t.Fatalf("InstallAuthorizedKey: %v", err) } @@ -126,7 +126,7 @@ func Test_InstallAuthorizedKey_PreservesExistingKeys(t *testing.T) { t.Fatal(err) } - if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { t.Fatalf("InstallAuthorizedKey: %v", err) } @@ -152,7 +152,7 @@ func Test_InstallAuthorizedKey_TagsExistingUntaggedKey(t *testing.T) { } // InstallAuthorizedKey should tag the existing key rather than adding a duplicate. - if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { t.Fatalf("InstallAuthorizedKey: %v", err) } @@ -366,10 +366,10 @@ func Test_InstallThenRemove_RoundTrip(t *testing.T) { } // Install two brev keys. - if err := register.InstallAuthorizedKey(u, "ssh-rsa KEY1"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa KEY1"); err != nil { t.Fatal(err) } - if err := register.InstallAuthorizedKey(u, "ssh-rsa KEY2"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa KEY2"); err != nil { t.Fatal(err) } @@ -393,10 +393,10 @@ func Test_InstallThenRemoveSpecificKey_RollbackScenario(t *testing.T) { u := tempUser(t) // Install two brev keys (simulating two users granted access). - if err := register.InstallAuthorizedKey(u, "ssh-rsa ALICE"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa ALICE"); err != nil { t.Fatal(err) } - if err := register.InstallAuthorizedKey(u, "ssh-rsa BOB"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa BOB"); err != nil { t.Fatal(err) } diff --git a/pkg/cmd/register/register.go b/pkg/cmd/register/register.go index e56c0e6c..5049ff3f 100644 --- a/pkg/cmd/register/register.go +++ b/pkg/cmd/register/register.go @@ -208,7 +208,9 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam runSetup(node, t, deps) if deps.prompter.ConfirmYesNo("Would you like to enable SSH access to this device?") { - grantSSHAccess(ctx, t, deps, s, reg, brevUser, osUser) + if err := grantSSHAccess(ctx, t, deps, s, reg, brevUser, osUser); err != nil { + t.Vprintf(" Warning: SSH access not granted: %v\n", err) + } } return nil @@ -330,7 +332,7 @@ func runSetup(node *nodev1.ExternalNode, t *terminal.Terminal, deps registerDeps } } -func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps, tokenProvider externalnode.TokenProvider, reg *DeviceRegistration, brevUser *entity.User, osUser *user.User) { +func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps, tokenProvider externalnode.TokenProvider, reg *DeviceRegistration, brevUser *entity.User, osUser *user.User) error { t.Vprint("") t.Vprint(t.Green("Enabling SSH access on this device")) t.Vprint("") @@ -341,14 +343,9 @@ func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps err := GrantSSHAccessToNode(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, osUser) if err != nil { - t.Vprint(" Retrying in 3 seconds...") - time.Sleep(3 * time.Second) - err = GrantSSHAccessToNode(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, osUser) - } - if err != nil { - t.Vprintf(" Warning: %v\n", err) - return + return fmt.Errorf("grant SSH failed: %w", err) } t.Vprint(t.Green(fmt.Sprintf("SSH access enabled. You can now SSH to this device via: brev shell %s", reg.DisplayName))) + return nil } diff --git a/pkg/cmd/register/register_test.go b/pkg/cmd/register/register_test.go index 5aad1b80..450ae720 100644 --- a/pkg/cmd/register/register_test.go +++ b/pkg/cmd/register/register_test.go @@ -499,3 +499,100 @@ Peers count: 0/0 Connected` }) } } + +func Test_runRegister_GrantSSH_retries_on_connection_error_then_succeeds(t *testing.T) { + regStore := &mockRegistrationStore{} + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + home: "/home/testuser/.brev", + token: "tok", + } + + var grantCalls int + svc := &fakeNodeService{ + addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) { + return &nodev1.AddNodeResponse{ + ExternalNode: &nodev1.ExternalNode{ + ExternalNodeId: "unode_abc", + OrganizationId: "org_123", + Name: req.GetName(), + DeviceId: req.GetDeviceId(), + ConnectivityInfo: &nodev1.ConnectivityInfo{ + RegistrationCommand: "netbird up --key abc", + }, + }, + }, nil + }, + grantNodeSSHAccessFn: func(_ *nodev1.GrantNodeSSHAccessRequest) (*nodev1.GrantNodeSSHAccessResponse, error) { + grantCalls++ + if grantCalls < 2 { + return nil, connect.NewError(connect.CodeInternal, nil) + } + return &nodev1.GrantNodeSSHAccessResponse{}, nil + }, + } + + deps, server := testRegisterDeps(t, svc, regStore) + defer server.Close() + + deps.prompter = mockConfirmer{confirm: true} + + term := terminal.New() + err := runRegister(context.Background(), term, store, "My Spark", deps) + if err != nil { + t.Fatalf("runRegister failed: %v", err) + } + + if grantCalls != 2 { + t.Errorf("expected GrantNodeSSHAccess to be called 2 times (retry once), got %d", grantCalls) + } +} + +func Test_runRegister_GrantSSH_no_retry_on_permanent_error(t *testing.T) { + regStore := &mockRegistrationStore{} + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + home: "/home/testuser/.brev", + token: "tok", + } + + var grantCalls int + svc := &fakeNodeService{ + addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) { + return &nodev1.AddNodeResponse{ + ExternalNode: &nodev1.ExternalNode{ + ExternalNodeId: "unode_abc", + OrganizationId: "org_123", + Name: req.GetName(), + DeviceId: req.GetDeviceId(), + ConnectivityInfo: &nodev1.ConnectivityInfo{ + RegistrationCommand: "netbird up --key abc", + }, + }, + }, nil + }, + grantNodeSSHAccessFn: func(_ *nodev1.GrantNodeSSHAccessRequest) (*nodev1.GrantNodeSSHAccessResponse, error) { + grantCalls++ + return nil, connect.NewError(connect.CodePermissionDenied, nil) + }, + } + + deps, server := testRegisterDeps(t, svc, regStore) + defer server.Close() + + deps.prompter = mockConfirmer{confirm: true} + + term := terminal.New() + err := runRegister(context.Background(), term, store, "My Spark", deps) + if err != nil { + t.Fatalf("runRegister should not fail the overall flow when SSH grant fails: %v", err) + } + + if grantCalls != 1 { + t.Errorf("expected GrantNodeSSHAccess to be called once (no retry on permanent error), got %d", grantCalls) + } +} diff --git a/pkg/cmd/register/rpcclient_test.go b/pkg/cmd/register/rpcclient_test.go index 69ea4bb2..ce45c079 100644 --- a/pkg/cmd/register/rpcclient_test.go +++ b/pkg/cmd/register/rpcclient_test.go @@ -143,9 +143,10 @@ func Test_toProtoNodeSpec_MinimalFields(t *testing.T) { // fakeNodeService implements the server side of ExternalNodeService for testing. type fakeNodeService struct { nodev1connect.UnimplementedExternalNodeServiceHandler - addNodeFn func(*nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) - removeNodeFn func(*nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) - getNodeFn func(*nodev1.GetNodeRequest) (*nodev1.GetNodeResponse, error) + addNodeFn func(*nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) + removeNodeFn func(*nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) + getNodeFn func(*nodev1.GetNodeRequest) (*nodev1.GetNodeResponse, error) + grantNodeSSHAccessFn func(*nodev1.GrantNodeSSHAccessRequest) (*nodev1.GrantNodeSSHAccessResponse, error) } func (f *fakeNodeService) AddNode(_ context.Context, req *connect.Request[nodev1.AddNodeRequest]) (*connect.Response[nodev1.AddNodeResponse], error) { @@ -175,6 +176,17 @@ func (f *fakeNodeService) GetNode(_ context.Context, req *connect.Request[nodev1 return connect.NewResponse(resp), nil } +func (f *fakeNodeService) GrantNodeSSHAccess(_ context.Context, req *connect.Request[nodev1.GrantNodeSSHAccessRequest]) (*connect.Response[nodev1.GrantNodeSSHAccessResponse], error) { + if f.grantNodeSSHAccessFn == nil { + return nil, connect.NewError(connect.CodeUnimplemented, nil) + } + resp, err := f.grantNodeSSHAccessFn(req.Msg) + if err != nil { + return nil, err + } + return connect.NewResponse(resp), nil +} + func Test_NewNodeServiceClient_AddNode(t *testing.T) { svc := &fakeNodeService{ addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) { diff --git a/pkg/cmd/register/sshkeys.go b/pkg/cmd/register/sshkeys.go index db36f01b..aad47e95 100644 --- a/pkg/cmd/register/sshkeys.go +++ b/pkg/cmd/register/sshkeys.go @@ -2,11 +2,13 @@ package register import ( "context" + "errors" "fmt" "os" "os/user" "path/filepath" "strings" + "time" nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" "connectrpc.com/connect" @@ -15,6 +17,15 @@ import ( "github.com/brevdev/brev-cli/pkg/entity" "github.com/brevdev/brev-cli/pkg/externalnode" "github.com/brevdev/brev-cli/pkg/terminal" + "github.com/cenkalti/backoff/v4" +) + +const ( + backoffInitialInterval = 1 * time.Second + backoffMaxInterval = 10 * time.Second + backoffMaxElapsedTime = 1 * time.Minute + + backoffPrintRound = 500 * time.Millisecond ) // BrevKeyComment is the marker appended to every SSH key that Brev installs. @@ -34,56 +45,81 @@ func GrantSSHAccessToNode( osUser *user.User, ) error { if targetUser.PublicKey != "" { - if err := InstallAuthorizedKey(osUser, targetUser.PublicKey); err != nil { + if added, err := InstallAuthorizedKey(osUser, targetUser.PublicKey); err != nil { t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to install SSH public key: %v", err))) - } else { + } else if added { t.Vprint(" Brev public key added to authorized_keys.") } } client := nodeClients.NewNodeClient(tokenProvider, config.GlobalConfig.GetBrevPublicAPIURL()) - _, err := client.GrantNodeSSHAccess(ctx, connect.NewRequest(&nodev1.GrantNodeSSHAccessRequest{ - ExternalNodeId: reg.ExternalNodeID, - UserId: targetUser.ID, - LinuxUser: osUser.Username, - })) - if err != nil { - if targetUser.PublicKey != "" { - if rerr := RemoveAuthorizedKey(osUser, targetUser.PublicKey); rerr != nil { - t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to remove SSH key after failed grant: %v", rerr))) + + backoffCtx := backoff.WithContext(backoff.NewExponentialBackOff( + backoff.WithInitialInterval(backoffInitialInterval), + backoff.WithMaxInterval(backoffMaxInterval), + backoff.WithMaxElapsedTime(backoffMaxElapsedTime), + ), ctx) + + opToTry := func() error { + _, err := client.GrantNodeSSHAccess(ctx, connect.NewRequest(&nodev1.GrantNodeSSHAccessRequest{ + ExternalNodeId: reg.ExternalNodeID, + UserId: targetUser.ID, + LinuxUser: osUser.Username, + })) + if err != nil { + // Retryable error + var connectErr *connect.Error + if errors.As(err, &connectErr) && connectErr.Code() == connect.CodeInternal { + return fmt.Errorf("failed to grant SSH access (transient): %w", err) } + + // Permanent error — roll back the key so we don't leave an unrecorded entry and abort the backoff retry + if targetUser.PublicKey != "" { + if rerr := RemoveAuthorizedKey(osUser, targetUser.PublicKey); rerr != nil { + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to remove SSH key after failed grant: %v", rerr))) + } + } + return backoff.Permanent(fmt.Errorf("failed to grant SSH access: %w", err)) } + + return nil + } + onOpErr := func(err error, d time.Duration) { + t.Vprintf(" SSH access not yet granted; retrying in: %s...\n", d.Round(backoffPrintRound)) + } + err := backoff.RetryNotify(opToTry, backoffCtx, onOpErr) + if err != nil { return fmt.Errorf("failed to grant SSH access: %w", err) } - return nil } // InstallAuthorizedKey appends the given public key to the user's // ~/.ssh/authorized_keys if it isn't already present. The key is tagged with // a brev-cli comment so it can be removed later by RemoveBrevAuthorizedKeys. -func InstallAuthorizedKey(u *user.User, pubKey string) error { +// Returns true if the key was newly written, false if it was already present. +func InstallAuthorizedKey(u *user.User, pubKey string) (bool, error) { pubKey = strings.TrimSpace(pubKey) if pubKey == "" { - return nil + return false, nil } sshDir := filepath.Join(u.HomeDir, ".ssh") if err := os.MkdirAll(sshDir, 0o700); err != nil { - return fmt.Errorf("creating .ssh directory: %w", err) + return false, fmt.Errorf("creating .ssh directory: %w", err) } authKeysPath := filepath.Join(sshDir, "authorized_keys") existing, err := os.ReadFile(authKeysPath) // #nosec G304 if err != nil && !os.IsNotExist(err) { - return fmt.Errorf("reading authorized_keys: %w", err) + return false, fmt.Errorf("reading authorized_keys: %w", err) } taggedKey := pubKey + " " + BrevKeyComment if strings.Contains(string(existing), taggedKey) { - return nil // already present with tag + return false, nil // already present with tag } // If the key exists but isn't tagged, replace it with the tagged version @@ -91,9 +127,9 @@ func InstallAuthorizedKey(u *user.User, pubKey string) error { if strings.Contains(string(existing), pubKey) { updated := strings.ReplaceAll(string(existing), pubKey, taggedKey) if err := os.WriteFile(authKeysPath, []byte(updated), 0o600); err != nil { - return fmt.Errorf("writing authorized_keys: %w", err) + return false, fmt.Errorf("writing authorized_keys: %w", err) } - return nil + return true, nil } // Ensure existing content ends with a newline before appending. @@ -104,10 +140,10 @@ func InstallAuthorizedKey(u *user.User, pubKey string) error { content += taggedKey + "\n" if err := os.WriteFile(authKeysPath, []byte(content), 0o600); err != nil { - return fmt.Errorf("writing authorized_keys: %w", err) + return false, fmt.Errorf("writing authorized_keys: %w", err) } - return nil + return true, nil } // RemoveAuthorizedKey removes a specific public key from the user's