Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions pkg/cmd/deregister/deregister.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ type deregisterDeps struct {
platform externalnode.PlatformChecker
prompter terminal.Selector
netbird register.NetBirdManager
sshd register.ManagedSSHDaemon
nodeClients externalnode.NodeClientFactory
registrationStore register.RegistrationStore
sshKeys SSHKeyRemover
Expand All @@ -58,6 +59,7 @@ func defaultDeregisterDeps(brevHome string) deregisterDeps {
platform: register.LinuxPlatform{},
prompter: register.TerminalPrompter{},
netbird: register.Netbird{},
sshd: register.BrevSSHD{},
nodeClients: register.DefaultNodeClientFactory{},
registrationStore: register.NewFileRegistrationStore(brevHome),
sshKeys: brevSSHKeyRemover{},
Expand Down Expand Up @@ -158,6 +160,15 @@ func runDeregister(ctx context.Context, t *terminal.Terminal, s DeregisterStore,
}
t.Vprint("")

// Remove brev-managed sshd (non-fatal on failure).
t.Vprint("Removing managed SSH daemon...")
if err := deps.sshd.Uninstall(); err != nil {
t.Vprintf(" Warning: failed to remove managed SSH daemon: %v\n", err)
} else {
t.Vprint(t.Green(" Managed SSH daemon removed."))
}
t.Vprint("")

t.Vprint("Removing Brev tunnel...")
if err := deps.netbird.Uninstall(); err != nil {
t.Vprintf(" Warning: failed to remove Brev tunnel: %v\n", err)
Expand Down
55 changes: 55 additions & 0 deletions pkg/cmd/deregister/deregister_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,17 @@ type mockNetBirdManager struct {
func (m *mockNetBirdManager) Install() error { return m.err }
func (m *mockNetBirdManager) Uninstall() error { m.called = true; return m.err }

type mockManagedSSHDaemon struct {
uninstallCalled bool
uninstallErr error
}

func (m *mockManagedSSHDaemon) Install() error { return nil }
func (m *mockManagedSSHDaemon) Uninstall() error {
m.uninstallCalled = true
return m.uninstallErr
}

type mockNodeClientFactory struct {
serverURL string
}
Expand Down Expand Up @@ -133,6 +144,7 @@ func testDeregisterDeps(t *testing.T, svc *fakeNodeService, regStore register.Re
return ""
}},
netbird: &mockNetBirdManager{},
sshd: &mockManagedSSHDaemon{},
nodeClients: mockNodeClientFactory{serverURL: server.URL},
registrationStore: regStore,
sshKeys: &mockSSHKeyRemover{},
Expand Down Expand Up @@ -379,3 +391,46 @@ func Test_runDeregister_RemoveBrevKeysHandling(t *testing.T) {
})
}
}

func Test_runDeregister_SSHDUninstallFailureIsNonFatal(t *testing.T) {
regStore := &mockRegistrationStore{
reg: &register.DeviceRegistration{
ExternalNodeID: "unode_abc",
DisplayName: "My Spark",
OrgID: "org_123",
},
}

store := &mockDeregisterStore{
user: &entity.User{ID: "user_1"},
home: "/home/testuser/.brev",
token: "tok",
}

svc := &fakeNodeService{
removeNodeFn: func(_ *nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) {
return &nodev1.RemoveNodeResponse{}, nil
},
}

sshdMock := &mockManagedSSHDaemon{uninstallErr: fmt.Errorf("permission denied")}
deps, server := testDeregisterDeps(t, svc, regStore)
defer server.Close()
deps.sshd = sshdMock

term := terminal.New()
err := runDeregister(context.Background(), term, store, deps)
if err != nil {
t.Fatalf("expected nil error (sshd failure should be non-fatal), got: %v", err)
}

if !sshdMock.uninstallCalled {
t.Error("expected sshd Uninstall to be called")
}

// Registration should still be cleaned up despite sshd failure.
exists, _ := regStore.Exists()
if exists {
t.Error("expected registration to be deleted")
}
}
39 changes: 23 additions & 16 deletions pkg/cmd/enablessh/enablessh.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"os/exec"
"os/user"
"strings"

"github.com/brevdev/brev-cli/pkg/cmd/register"
"github.com/brevdev/brev-cli/pkg/entity"
Expand All @@ -28,13 +29,15 @@ type EnableSSHStore interface {
// can be replaced in tests.
type enableSSHDeps struct {
platform externalnode.PlatformChecker
sshd register.ManagedSSHDaemon
nodeClients externalnode.NodeClientFactory
registrationStore register.RegistrationStore
}

func defaultEnableSSHDeps(brevHome string) enableSSHDeps {
return enableSSHDeps{
platform: register.LinuxPlatform{},
sshd: register.BrevSSHD{},
nodeClients: register.DefaultNodeClientFactory{},
registrationStore: register.NewFileRegistrationStore(brevHome),
}
Expand Down Expand Up @@ -83,15 +86,15 @@ func runEnableSSH(ctx context.Context, t *terminal.Terminal, s EnableSSHStore, d
return breverrors.WrapAndTrace(err)
}

return enableSSH(ctx, t, deps.nodeClients, s, reg, brevUser)
return enableSSH(ctx, t, deps, s, reg, brevUser)
}

// enableSSH grants SSH access to the given node for the current Brev user.
// This is the "reflexive grant" — granting yourself SSH access to the device.
// It ensures the managed sshd is installed before granting access.
func enableSSH(
ctx context.Context,
t *terminal.Terminal,
nodeClients externalnode.NodeClientFactory,
deps enableSSHDeps,
tokenProvider externalnode.TokenProvider,
reg *register.DeviceRegistration,
brevUser *entity.User,
Expand All @@ -101,8 +104,6 @@ func enableSSH(
return fmt.Errorf("failed to determine current Linux user: %w", err)
}

checkSSHDaemon(t)

t.Vprint("")
t.Vprint(t.Green("Enabling SSH access on this device"))
t.Vprint("")
Expand All @@ -111,22 +112,28 @@ func enableSSH(
t.Vprintf(" Linux user: %s\n", u.Username)
t.Vprint("")

if err := register.GrantSSHAccessToNode(ctx, t, nodeClients, tokenProvider, reg, brevUser, u); err != nil {
if !brevSSHDRunning() {
t.Vprint("This will:")
t.Vprint(" 1. Install or upgrade openssh-server")
t.Vprint(" 2. Set up a secure SSH server on port 2222")
t.Vprint("")
if err := deps.sshd.Install(); err != nil {
return fmt.Errorf("managed sshd setup failed: %w", err)
}
t.Vprint(t.Green(" Managed SSH daemon ready (port 2222)."))
t.Vprint("")
}

if err := register.GrantSSHAccessToNode(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, u); err != nil {
return fmt.Errorf("enable SSH failed: %w", err)
}

t.Vprint(t.Green(fmt.Sprintf("SSH access enabled. You can now SSH to this device via: brev shell %s", reg.DisplayName)))
return nil
}

// checkSSHDaemon prints a warning if neither "ssh" nor "sshd" systemd services
// appear to be active. It never returns an error — it is best-effort.
func checkSSHDaemon(t *terminal.Terminal) {
for _, svc := range []string{"ssh", "sshd"} {
out, err := exec.Command("systemctl", "is-active", svc).Output() //nolint:gosec // fixed service names
if err == nil && len(out) > 0 && string(out[:len(out)-1]) == "active" {
return
}
}
t.Vprintf(" %s\n", t.Yellow("Warning: SSH daemon does not appear to be running. SSH access may not work until sshd is started."))
// brevSSHDRunning returns true if the brev-sshd systemd service is active.
func brevSSHDRunning() bool {
out, err := exec.Command("systemctl", "is-active", "brev-sshd").Output() //nolint:gosec // fixed service name
return err == nil && strings.TrimSpace(string(out)) == "active"
}
6 changes: 6 additions & 0 deletions pkg/cmd/register/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ type Netbird struct{}
func (Netbird) Install() error { return InstallNetbird() }
func (Netbird) Uninstall() error { return UninstallNetbird() }

// BrevSSHD manages the brev-managed sshd instance on port 2222.
type BrevSSHD struct{}

func (BrevSSHD) Install() error { return InstallBrevSSHD() }
func (BrevSSHD) Uninstall() error { return UninstallBrevSSHD() }

// ShellSetupRunner runs setup scripts via shell.
type ShellSetupRunner struct{}

Expand Down
28 changes: 24 additions & 4 deletions pkg/cmd/register/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ type NetBirdManager interface {
Uninstall() error
}

// ManagedSSHDaemon installs and uninstalls a brev-managed sshd instance.
type ManagedSSHDaemon interface {
Install() error
Uninstall() error
}

// SetupRunner runs a setup script on the local machine.
type SetupRunner interface {
RunSetup(script string) error
Expand All @@ -58,8 +64,8 @@ type SetupRunner interface {
type registerDeps struct {
platform externalnode.PlatformChecker
prompter terminal.Confirmer
netbird NetBirdManager
setupRunner SetupRunner
netbird NetBirdManager
setupRunner SetupRunner
nodeClients externalnode.NodeClientFactory
commandRunner CommandRunner
fileReader FileReader
Expand All @@ -70,8 +76,8 @@ func defaultRegisterDeps(brevHome string) registerDeps {
return registerDeps{
platform: LinuxPlatform{},
prompter: TerminalPrompter{},
netbird: Netbird{},
setupRunner: ShellSetupRunner{},
netbird: Netbird{},
setupRunner: ShellSetupRunner{},
nodeClients: DefaultNodeClientFactory{},
commandRunner: ExecCommandRunner{},
fileReader: OSFileReader{},
Expand Down Expand Up @@ -207,6 +213,12 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam

runSetup(node, t, deps)

t.Vprint("")
t.Vprint("SSH access allows you to connect to this device remotely.")
t.Vprint("This will:")
t.Vprint(" 1. Install or upgrade openssh-server")
t.Vprint(" 2. Set up a secure SSH server on port 2222")
t.Vprint("")
if deps.prompter.ConfirmYesNo("Would you like to enable SSH access to this device?") {
grantSSHAccess(ctx, t, deps, s, reg, brevUser, osUser)
}
Expand Down Expand Up @@ -339,6 +351,14 @@ func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps
t.Vprintf(" Linux user: %s\n", osUser.Username)
t.Vprint("")

t.Vprint("Setting up managed SSH daemon...")
if err := InstallBrevSSHD(); err != nil {
t.Vprintf(" Warning: managed sshd setup failed: %v\n", err)
} else {
t.Vprint(t.Green(" Managed SSH daemon ready (port 2222)."))
}
t.Vprint("")

err := GrantSSHAccessToNode(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, osUser)
if err != nil {
t.Vprint(" Retrying in 3 seconds...")
Expand Down
2 changes: 2 additions & 0 deletions pkg/cmd/register/register_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ type mockNetBirdManager struct{ err error }
func (m mockNetBirdManager) Install() error { return m.err }
func (m mockNetBirdManager) Uninstall() error { return m.err }


type mockSetupRunner struct {
called bool
cmd string
Expand Down Expand Up @@ -419,6 +420,7 @@ func Test_runRegister_NoSetupCommand(t *testing.T) {
}
}


func Test_runSetupCommand_Validation(t *testing.T) {
tests := []struct {
name string
Expand Down
Loading