Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ require (
github.com/alessio/shellescape v1.4.1
github.com/brevdev/parse v0.0.11
github.com/briandowns/spinner v1.16.0
github.com/cenkalti/backoff/v4 v4.3.0
github.com/fatih/color v1.13.0
github.com/getsentry/sentry-go v0.14.0
github.com/gin-gonic/gin v1.10.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/chzyer/logex v1.1.10 h1:Swpa1K6QvQznwJRcfTfQJmTE72DqScAa40E+fbHEXEE=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
Expand Down
26 changes: 13 additions & 13 deletions pkg/cmd/enablessh/enablessh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func readAuthorizedKeys(t *testing.T, u *user.User) string {
func Test_InstallAuthorizedKey_TagsKeyWithBrevComment(t *testing.T) {
u := tempUser(t)

if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
t.Fatalf("InstallAuthorizedKey: %v", err)
}

Expand All @@ -44,10 +44,10 @@ func Test_InstallAuthorizedKey_TagsKeyWithBrevComment(t *testing.T) {
func Test_InstallAuthorizedKey_SkipsDuplicate(t *testing.T) {
u := tempUser(t)

if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
t.Fatalf("first install: %v", err)
}
if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
t.Fatalf("second install: %v", err)
}

Expand All @@ -70,7 +70,7 @@ func Test_InstallAuthorizedKey_SkipsDuplicateEvenIfAlreadyTagged(t *testing.T) {
t.Fatal(err)
}

if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
t.Fatalf("InstallAuthorizedKey: %v", err)
}

Expand All @@ -84,10 +84,10 @@ func Test_InstallAuthorizedKey_SkipsDuplicateEvenIfAlreadyTagged(t *testing.T) {
func Test_InstallAuthorizedKey_EmptyKeyIsNoop(t *testing.T) {
u := tempUser(t)

if err := register.InstallAuthorizedKey(u, ""); err != nil {
if _, err := register.InstallAuthorizedKey(u, ""); err != nil {
t.Fatalf("InstallAuthorizedKey: %v", err)
}
if err := register.InstallAuthorizedKey(u, " "); err != nil {
if _, err := register.InstallAuthorizedKey(u, " "); err != nil {
t.Fatalf("InstallAuthorizedKey (whitespace): %v", err)
}

Expand All @@ -101,7 +101,7 @@ func Test_InstallAuthorizedKey_EmptyKeyIsNoop(t *testing.T) {
func Test_InstallAuthorizedKey_CreatesSSHDir(t *testing.T) {
u := tempUser(t)

if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
t.Fatalf("InstallAuthorizedKey: %v", err)
}

Expand All @@ -126,7 +126,7 @@ func Test_InstallAuthorizedKey_PreservesExistingKeys(t *testing.T) {
t.Fatal(err)
}

if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
t.Fatalf("InstallAuthorizedKey: %v", err)
}

Expand All @@ -152,7 +152,7 @@ func Test_InstallAuthorizedKey_TagsExistingUntaggedKey(t *testing.T) {
}

// InstallAuthorizedKey should tag the existing key rather than adding a duplicate.
if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil {
t.Fatalf("InstallAuthorizedKey: %v", err)
}

Expand Down Expand Up @@ -366,10 +366,10 @@ func Test_InstallThenRemove_RoundTrip(t *testing.T) {
}

// Install two brev keys.
if err := register.InstallAuthorizedKey(u, "ssh-rsa KEY1"); err != nil {
if _, err := register.InstallAuthorizedKey(u, "ssh-rsa KEY1"); err != nil {
t.Fatal(err)
}
if err := register.InstallAuthorizedKey(u, "ssh-rsa KEY2"); err != nil {
if _, err := register.InstallAuthorizedKey(u, "ssh-rsa KEY2"); err != nil {
t.Fatal(err)
}

Expand All @@ -393,10 +393,10 @@ func Test_InstallThenRemoveSpecificKey_RollbackScenario(t *testing.T) {
u := tempUser(t)

// Install two brev keys (simulating two users granted access).
if err := register.InstallAuthorizedKey(u, "ssh-rsa ALICE"); err != nil {
if _, err := register.InstallAuthorizedKey(u, "ssh-rsa ALICE"); err != nil {
t.Fatal(err)
}
if err := register.InstallAuthorizedKey(u, "ssh-rsa BOB"); err != nil {
if _, err := register.InstallAuthorizedKey(u, "ssh-rsa BOB"); err != nil {
t.Fatal(err)
}

Expand Down
15 changes: 6 additions & 9 deletions pkg/cmd/register/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,9 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam
runSetup(node, t, deps)

if deps.prompter.ConfirmYesNo("Would you like to enable SSH access to this device?") {
grantSSHAccess(ctx, t, deps, s, reg, brevUser, osUser)
if err := grantSSHAccess(ctx, t, deps, s, reg, brevUser, osUser); err != nil {
t.Vprintf(" Warning: SSH access not granted: %v\n", err)
}
}

return nil
Expand Down Expand Up @@ -330,7 +332,7 @@ func runSetup(node *nodev1.ExternalNode, t *terminal.Terminal, deps registerDeps
}
}

func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps, tokenProvider externalnode.TokenProvider, reg *DeviceRegistration, brevUser *entity.User, osUser *user.User) {
func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps, tokenProvider externalnode.TokenProvider, reg *DeviceRegistration, brevUser *entity.User, osUser *user.User) error {
t.Vprint("")
t.Vprint(t.Green("Enabling SSH access on this device"))
t.Vprint("")
Expand All @@ -341,14 +343,9 @@ func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps

err := GrantSSHAccessToNode(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, osUser)
if err != nil {
t.Vprint(" Retrying in 3 seconds...")
time.Sleep(3 * time.Second)
err = GrantSSHAccessToNode(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, osUser)
}
if err != nil {
t.Vprintf(" Warning: %v\n", err)
return
return fmt.Errorf("grant SSH failed: %w", err)
}

t.Vprint(t.Green(fmt.Sprintf("SSH access enabled. You can now SSH to this device via: brev shell %s", reg.DisplayName)))
return nil
}
97 changes: 97 additions & 0 deletions pkg/cmd/register/register_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,3 +499,100 @@ Peers count: 0/0 Connected`
})
}
}

func Test_runRegister_GrantSSH_retries_on_connection_error_then_succeeds(t *testing.T) {
regStore := &mockRegistrationStore{}

store := &mockRegisterStore{
user: &entity.User{ID: "user_1"},
org: &entity.Organization{ID: "org_123", Name: "TestOrg"},
home: "/home/testuser/.brev",
token: "tok",
}

var grantCalls int
svc := &fakeNodeService{
addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) {
return &nodev1.AddNodeResponse{
ExternalNode: &nodev1.ExternalNode{
ExternalNodeId: "unode_abc",
OrganizationId: "org_123",
Name: req.GetName(),
DeviceId: req.GetDeviceId(),
ConnectivityInfo: &nodev1.ConnectivityInfo{
RegistrationCommand: "netbird up --key abc",
},
},
}, nil
},
grantNodeSSHAccessFn: func(_ *nodev1.GrantNodeSSHAccessRequest) (*nodev1.GrantNodeSSHAccessResponse, error) {
grantCalls++
if grantCalls < 2 {
return nil, connect.NewError(connect.CodeInternal, nil)
}
return &nodev1.GrantNodeSSHAccessResponse{}, nil
},
}

deps, server := testRegisterDeps(t, svc, regStore)
defer server.Close()

deps.prompter = mockConfirmer{confirm: true}

term := terminal.New()
err := runRegister(context.Background(), term, store, "My Spark", deps)
if err != nil {
t.Fatalf("runRegister failed: %v", err)
}

if grantCalls != 2 {
t.Errorf("expected GrantNodeSSHAccess to be called 2 times (retry once), got %d", grantCalls)
}
}

func Test_runRegister_GrantSSH_no_retry_on_permanent_error(t *testing.T) {
regStore := &mockRegistrationStore{}

store := &mockRegisterStore{
user: &entity.User{ID: "user_1"},
org: &entity.Organization{ID: "org_123", Name: "TestOrg"},
home: "/home/testuser/.brev",
token: "tok",
}

var grantCalls int
svc := &fakeNodeService{
addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) {
return &nodev1.AddNodeResponse{
ExternalNode: &nodev1.ExternalNode{
ExternalNodeId: "unode_abc",
OrganizationId: "org_123",
Name: req.GetName(),
DeviceId: req.GetDeviceId(),
ConnectivityInfo: &nodev1.ConnectivityInfo{
RegistrationCommand: "netbird up --key abc",
},
},
}, nil
},
grantNodeSSHAccessFn: func(_ *nodev1.GrantNodeSSHAccessRequest) (*nodev1.GrantNodeSSHAccessResponse, error) {
grantCalls++
return nil, connect.NewError(connect.CodePermissionDenied, nil)
},
}

deps, server := testRegisterDeps(t, svc, regStore)
defer server.Close()

deps.prompter = mockConfirmer{confirm: true}

term := terminal.New()
err := runRegister(context.Background(), term, store, "My Spark", deps)
if err != nil {
t.Fatalf("runRegister should not fail the overall flow when SSH grant fails: %v", err)
}

if grantCalls != 1 {
t.Errorf("expected GrantNodeSSHAccess to be called once (no retry on permanent error), got %d", grantCalls)
}
}
18 changes: 15 additions & 3 deletions pkg/cmd/register/rpcclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,10 @@ func Test_toProtoNodeSpec_MinimalFields(t *testing.T) {
// fakeNodeService implements the server side of ExternalNodeService for testing.
type fakeNodeService struct {
nodev1connect.UnimplementedExternalNodeServiceHandler
addNodeFn func(*nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error)
removeNodeFn func(*nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error)
getNodeFn func(*nodev1.GetNodeRequest) (*nodev1.GetNodeResponse, error)
addNodeFn func(*nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error)
removeNodeFn func(*nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error)
getNodeFn func(*nodev1.GetNodeRequest) (*nodev1.GetNodeResponse, error)
grantNodeSSHAccessFn func(*nodev1.GrantNodeSSHAccessRequest) (*nodev1.GrantNodeSSHAccessResponse, error)
}

func (f *fakeNodeService) AddNode(_ context.Context, req *connect.Request[nodev1.AddNodeRequest]) (*connect.Response[nodev1.AddNodeResponse], error) {
Expand Down Expand Up @@ -175,6 +176,17 @@ func (f *fakeNodeService) GetNode(_ context.Context, req *connect.Request[nodev1
return connect.NewResponse(resp), nil
}

func (f *fakeNodeService) GrantNodeSSHAccess(_ context.Context, req *connect.Request[nodev1.GrantNodeSSHAccessRequest]) (*connect.Response[nodev1.GrantNodeSSHAccessResponse], error) {
if f.grantNodeSSHAccessFn == nil {
return nil, connect.NewError(connect.CodeUnimplemented, nil)
}
resp, err := f.grantNodeSSHAccessFn(req.Msg)
if err != nil {
return nil, err
}
return connect.NewResponse(resp), nil
}

func Test_NewNodeServiceClient_AddNode(t *testing.T) {
svc := &fakeNodeService{
addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) {
Expand Down
Loading
Loading