diff --git a/pkg/cmd/register/register.go b/pkg/cmd/register/register.go index e56c0e6c..a38e3bb0 100644 --- a/pkg/cmd/register/register.go +++ b/pkg/cmd/register/register.go @@ -330,6 +330,20 @@ func runSetup(node *nodev1.ExternalNode, t *terminal.Terminal, deps registerDeps } } +// waitForNetbirdConnected polls "netbird status" until the management server +// reports Connected or the timeout expires. Returns true if connected. +func waitForNetbirdConnected(timeout time.Duration) bool { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + out, err := exec.Command("netbird", "status").Output() //nolint:gosec // fixed command + if err == nil && netbirdManagementConnected(string(out)) { + return true + } + time.Sleep(2 * time.Second) + } + return false +} + func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps, tokenProvider externalnode.TokenProvider, reg *DeviceRegistration, brevUser *entity.User, osUser *user.User) { t.Vprint("") t.Vprint(t.Green("Enabling SSH access on this device")) @@ -339,14 +353,36 @@ func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps t.Vprintf(" Linux user: %s\n", osUser.Username) t.Vprint("") - err := GrantSSHAccessToNode(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, osUser) - if err != nil { - t.Vprint(" Retrying in 3 seconds...") - time.Sleep(3 * time.Second) + t.Vprint(" Waiting for Brev tunnel to connect...") + if !waitForNetbirdConnected(60 * time.Second) { + t.Vprint(t.Yellow(" Tunnel did not connect within 60s.")) + t.Vprint(t.Yellow(" Run 'brev enable-ssh' once the tunnel is established.")) + return + } + t.Vprint(t.Green(" Tunnel connected.")) + t.Vprint("") + + // Peer routes finish propagating after the management handshake. Retry + // with increasing delays to give the routing up to ~90s to settle. + retryDelays := []time.Duration{10 * time.Second, 20 * time.Second, 30 * time.Second} + var err error + for i, delay := range append([]time.Duration{0}, retryDelays...) { + if delay > 0 { + t.Vprintf(" Retrying in %s...\n", delay) + time.Sleep(delay) + } err = GrantSSHAccessToNode(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, osUser) + if err == nil { + break + } + if i < len(retryDelays) { + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("(%d/%d) %v", i+1, len(retryDelays)+1, err))) + } } if err != nil { - t.Vprintf(" Warning: %v\n", err) + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: %v", err))) + t.Vprint(t.Yellow(" Your SSH public key is already installed locally on this device.")) + t.Vprint(t.Yellow(" Run 'brev enable-ssh' in ~1 minute to complete the server-side record.")) return } diff --git a/pkg/cmd/register/sshkeys.go b/pkg/cmd/register/sshkeys.go index db36f01b..4fc46e9c 100644 --- a/pkg/cmd/register/sshkeys.go +++ b/pkg/cmd/register/sshkeys.go @@ -2,6 +2,7 @@ package register import ( "context" + "errors" "fmt" "os" "os/user" @@ -22,8 +23,12 @@ import ( const BrevKeyComment = "# brev-cli" // GrantSSHAccessToNode installs the user's public key in authorized_keys and -// calls GrantNodeSSHAccess to record access server-side. If the RPC fails, -// the installed key is rolled back. +// calls GrantNodeSSHAccess to record access server-side. +// +// On a transient transport error (connect.CodeInternal, e.g. connection reset), +// the key is left in authorized_keys so the caller can retry without +// re-installing it. On a permanent application error (auth, not found, etc.) +// the key is rolled back. func GrantSSHAccessToNode( ctx context.Context, t *terminal.Terminal, @@ -34,9 +39,10 @@ func GrantSSHAccessToNode( osUser *user.User, ) error { if targetUser.PublicKey != "" { - if err := InstallAuthorizedKey(osUser, targetUser.PublicKey); err != nil { + added, err := InstallAuthorizedKey(osUser, targetUser.PublicKey) + if err != nil { t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to install SSH public key: %v", err))) - } else { + } else if added { t.Vprint(" Brev public key added to authorized_keys.") } } @@ -48,6 +54,14 @@ func GrantSSHAccessToNode( LinuxUser: osUser.Username, })) if err != nil { + // Transport errors (connection reset, EOF) are transient — leave the key + // installed so retries don't need to reinstall it, and signal the caller + // with a distinct error type. + var connectErr *connect.Error + if errors.As(err, &connectErr) && connectErr.Code() == connect.CodeInternal { + return fmt.Errorf("failed to grant SSH access (transient): %w", err) + } + // Permanent error — roll back the key so we don't leave an unrecorded entry. if targetUser.PublicKey != "" { if rerr := RemoveAuthorizedKey(osUser, targetUser.PublicKey); rerr != nil { t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to remove SSH key after failed grant: %v", rerr))) @@ -62,28 +76,29 @@ func GrantSSHAccessToNode( // InstallAuthorizedKey appends the given public key to the user's // ~/.ssh/authorized_keys if it isn't already present. The key is tagged with // a brev-cli comment so it can be removed later by RemoveBrevAuthorizedKeys. -func InstallAuthorizedKey(u *user.User, pubKey string) error { +// Returns true if the key was newly written, false if it was already present. +func InstallAuthorizedKey(u *user.User, pubKey string) (bool, error) { pubKey = strings.TrimSpace(pubKey) if pubKey == "" { - return nil + return false, nil } sshDir := filepath.Join(u.HomeDir, ".ssh") if err := os.MkdirAll(sshDir, 0o700); err != nil { - return fmt.Errorf("creating .ssh directory: %w", err) + return false, fmt.Errorf("creating .ssh directory: %w", err) } authKeysPath := filepath.Join(sshDir, "authorized_keys") existing, err := os.ReadFile(authKeysPath) // #nosec G304 if err != nil && !os.IsNotExist(err) { - return fmt.Errorf("reading authorized_keys: %w", err) + return false, fmt.Errorf("reading authorized_keys: %w", err) } taggedKey := pubKey + " " + BrevKeyComment if strings.Contains(string(existing), taggedKey) { - return nil // already present with tag + return false, nil // already present with tag } // If the key exists but isn't tagged, replace it with the tagged version @@ -91,9 +106,9 @@ func InstallAuthorizedKey(u *user.User, pubKey string) error { if strings.Contains(string(existing), pubKey) { updated := strings.ReplaceAll(string(existing), pubKey, taggedKey) if err := os.WriteFile(authKeysPath, []byte(updated), 0o600); err != nil { - return fmt.Errorf("writing authorized_keys: %w", err) + return false, fmt.Errorf("writing authorized_keys: %w", err) } - return nil + return true, nil } // Ensure existing content ends with a newline before appending. @@ -104,10 +119,10 @@ func InstallAuthorizedKey(u *user.User, pubKey string) error { content += taggedKey + "\n" if err := os.WriteFile(authKeysPath, []byte(content), 0o600); err != nil { - return fmt.Errorf("writing authorized_keys: %w", err) + return false, fmt.Errorf("writing authorized_keys: %w", err) } - return nil + return true, nil } // RemoveAuthorizedKey removes a specific public key from the user's