Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions pkg/cmd/register/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,20 @@ func runSetup(node *nodev1.ExternalNode, t *terminal.Terminal, deps registerDeps
}
}

// waitForNetbirdConnected polls "netbird status" until the management server
// reports Connected or the timeout expires. Returns true if connected.
func waitForNetbirdConnected(timeout time.Duration) bool {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
out, err := exec.Command("netbird", "status").Output() //nolint:gosec // fixed command
if err == nil && netbirdManagementConnected(string(out)) {
return true
}
time.Sleep(2 * time.Second)
}
return false
}

func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps, tokenProvider externalnode.TokenProvider, reg *DeviceRegistration, brevUser *entity.User, osUser *user.User) {
t.Vprint("")
t.Vprint(t.Green("Enabling SSH access on this device"))
Expand All @@ -339,14 +353,36 @@ func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps
t.Vprintf(" Linux user: %s\n", osUser.Username)
t.Vprint("")

err := GrantSSHAccessToNode(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, osUser)
if err != nil {
t.Vprint(" Retrying in 3 seconds...")
time.Sleep(3 * time.Second)
t.Vprint(" Waiting for Brev tunnel to connect...")
if !waitForNetbirdConnected(60 * time.Second) {
t.Vprint(t.Yellow(" Tunnel did not connect within 60s."))
t.Vprint(t.Yellow(" Run 'brev enable-ssh' once the tunnel is established."))
return
}
t.Vprint(t.Green(" Tunnel connected."))
t.Vprint("")

// Peer routes finish propagating after the management handshake. Retry
// with increasing delays to give the routing up to ~90s to settle.
retryDelays := []time.Duration{10 * time.Second, 20 * time.Second, 30 * time.Second}
var err error
for i, delay := range append([]time.Duration{0}, retryDelays...) {
if delay > 0 {
t.Vprintf(" Retrying in %s...\n", delay)
time.Sleep(delay)
}
err = GrantSSHAccessToNode(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, osUser)
if err == nil {
break
}
if i < len(retryDelays) {
t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("(%d/%d) %v", i+1, len(retryDelays)+1, err)))
}
}
if err != nil {
t.Vprintf(" Warning: %v\n", err)
t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: %v", err)))
t.Vprint(t.Yellow(" Your SSH public key is already installed locally on this device."))
t.Vprint(t.Yellow(" Run 'brev enable-ssh' in ~1 minute to complete the server-side record."))
return
}

Expand Down
41 changes: 28 additions & 13 deletions pkg/cmd/register/sshkeys.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package register

import (
"context"
"errors"
"fmt"
"os"
"os/user"
Expand All @@ -22,8 +23,12 @@ import (
const BrevKeyComment = "# brev-cli"

// GrantSSHAccessToNode installs the user's public key in authorized_keys and
// calls GrantNodeSSHAccess to record access server-side. If the RPC fails,
// the installed key is rolled back.
// calls GrantNodeSSHAccess to record access server-side.
//
// On a transient transport error (connect.CodeInternal, e.g. connection reset),
// the key is left in authorized_keys so the caller can retry without
// re-installing it. On a permanent application error (auth, not found, etc.)
// the key is rolled back.
func GrantSSHAccessToNode(
ctx context.Context,
t *terminal.Terminal,
Expand All @@ -34,9 +39,10 @@ func GrantSSHAccessToNode(
osUser *user.User,
) error {
if targetUser.PublicKey != "" {
if err := InstallAuthorizedKey(osUser, targetUser.PublicKey); err != nil {
added, err := InstallAuthorizedKey(osUser, targetUser.PublicKey)
if err != nil {
t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to install SSH public key: %v", err)))
} else {
} else if added {
t.Vprint(" Brev public key added to authorized_keys.")
}
}
Expand All @@ -48,6 +54,14 @@ func GrantSSHAccessToNode(
LinuxUser: osUser.Username,
}))
if err != nil {
// Transport errors (connection reset, EOF) are transient — leave the key
// installed so retries don't need to reinstall it, and signal the caller
// with a distinct error type.
var connectErr *connect.Error
if errors.As(err, &connectErr) && connectErr.Code() == connect.CodeInternal {
return fmt.Errorf("failed to grant SSH access (transient): %w", err)
}
// Permanent error — roll back the key so we don't leave an unrecorded entry.
if targetUser.PublicKey != "" {
if rerr := RemoveAuthorizedKey(osUser, targetUser.PublicKey); rerr != nil {
t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to remove SSH key after failed grant: %v", rerr)))
Expand All @@ -62,38 +76,39 @@ func GrantSSHAccessToNode(
// InstallAuthorizedKey appends the given public key to the user's
// ~/.ssh/authorized_keys if it isn't already present. The key is tagged with
// a brev-cli comment so it can be removed later by RemoveBrevAuthorizedKeys.
func InstallAuthorizedKey(u *user.User, pubKey string) error {
// Returns true if the key was newly written, false if it was already present.
func InstallAuthorizedKey(u *user.User, pubKey string) (bool, error) {
pubKey = strings.TrimSpace(pubKey)
if pubKey == "" {
return nil
return false, nil
}

sshDir := filepath.Join(u.HomeDir, ".ssh")
if err := os.MkdirAll(sshDir, 0o700); err != nil {
return fmt.Errorf("creating .ssh directory: %w", err)
return false, fmt.Errorf("creating .ssh directory: %w", err)
}

authKeysPath := filepath.Join(sshDir, "authorized_keys")

existing, err := os.ReadFile(authKeysPath) // #nosec G304
if err != nil && !os.IsNotExist(err) {
return fmt.Errorf("reading authorized_keys: %w", err)
return false, fmt.Errorf("reading authorized_keys: %w", err)
}

taggedKey := pubKey + " " + BrevKeyComment

if strings.Contains(string(existing), taggedKey) {
return nil // already present with tag
return false, nil // already present with tag
}

// If the key exists but isn't tagged, replace it with the tagged version
// so that RemoveBrevAuthorizedKeys can find it later.
if strings.Contains(string(existing), pubKey) {
updated := strings.ReplaceAll(string(existing), pubKey, taggedKey)
if err := os.WriteFile(authKeysPath, []byte(updated), 0o600); err != nil {
return fmt.Errorf("writing authorized_keys: %w", err)
return false, fmt.Errorf("writing authorized_keys: %w", err)
}
return nil
return true, nil
}

// Ensure existing content ends with a newline before appending.
Expand All @@ -104,10 +119,10 @@ func InstallAuthorizedKey(u *user.User, pubKey string) error {
content += taggedKey + "\n"

if err := os.WriteFile(authKeysPath, []byte(content), 0o600); err != nil {
return fmt.Errorf("writing authorized_keys: %w", err)
return false, fmt.Errorf("writing authorized_keys: %w", err)
}

return nil
return true, nil
}

// RemoveAuthorizedKey removes a specific public key from the user's
Expand Down
Loading