diff --git a/pkg/cmd/copy/copy.go b/pkg/cmd/copy/copy.go index e0abb17e..5aadf4ef 100644 --- a/pkg/cmd/copy/copy.go +++ b/pkg/cmd/copy/copy.go @@ -8,6 +8,8 @@ import ( "strings" "time" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "github.com/brevdev/brev-cli/pkg/cmd/cmderrors" "github.com/brevdev/brev-cli/pkg/cmd/completions" "github.com/brevdev/brev-cli/pkg/cmd/refresh" @@ -75,7 +77,15 @@ func runCopyCommand(t *terminal.Terminal, cstore CopyStore, source, dest string, } } - workspace, err := prepareWorkspace(t, cstore, workspaceNameOrID) + target, err := util.ResolveWorkspaceOrNode(cstore, workspaceNameOrID) + if err != nil { + return breverrors.WrapAndTrace(err) + } + if target.Node != nil { + return copyExternalNode(t, cstore, target.Node, localPath, remotePath, isUpload) + } + + workspace, err := prepareWorkspace(t, cstore, target.Workspace) if err != nil { return breverrors.WrapAndTrace(err) } @@ -116,26 +126,22 @@ func parseCopyArguments(source, dest string) (workspaceNameOrID, remotePath, loc return destWorkspace, destPath, source, true, nil } -func prepareWorkspace(t *terminal.Terminal, cstore CopyStore, workspaceNameOrID string) (*entity.Workspace, error) { +func prepareWorkspace(t *terminal.Terminal, cstore CopyStore, workspace *entity.Workspace) (*entity.Workspace, error) { s := t.NewSpinner() - workspace, err := util.GetUserWorkspaceByNameOrIDErr(cstore, workspaceNameOrID) - if err != nil { - return nil, breverrors.WrapAndTrace(err) - } if workspace.Status == "STOPPED" { - err = startWorkspaceIfStopped(t, s, cstore, workspaceNameOrID, workspace) + err := startWorkspaceIfStopped(t, s, cstore, workspace.Name, workspace) if err != nil { return nil, breverrors.WrapAndTrace(err) } } - err = pollUntil(s, workspace.ID, "RUNNING", cstore, " waiting for instance to be ready...") + err := pollUntil(s, workspace.ID, "RUNNING", cstore, " waiting for instance to be ready...") if err != nil { return nil, breverrors.WrapAndTrace(err) } - workspace, err = util.GetUserWorkspaceByNameOrIDErr(cstore, workspaceNameOrID) + workspace, err = util.GetUserWorkspaceByNameOrIDErr(cstore, workspace.Name) if err != nil { return nil, breverrors.WrapAndTrace(err) } @@ -287,6 +293,28 @@ func startWorkspaceIfStopped(t *terminal.Terminal, s *spinner.Spinner, tstore Co return nil } +func copyExternalNode(t *terminal.Terminal, cstore CopyStore, node *nodev1.ExternalNode, localPath, remotePath string, isUpload bool) error { + info, err := util.ResolveExternalNodeSSH(cstore, node) + if err != nil { + return breverrors.WrapAndTrace(err) + } + alias := info.SSHAlias() + + // Ensure SSH config is up to date so the alias resolves. + refreshRes := refresh.RunRefreshAsync(cstore) + if err := refreshRes.Await(); err != nil { + return breverrors.WrapAndTrace(err) + } + + s := t.NewSpinner() + err = waitForSSHToBeAvailable(alias, s) + if err != nil { + return breverrors.WrapAndTrace(err) + } + + return runSCP(t, alias, localPath, remotePath, isUpload) +} + func pollUntil(s *spinner.Spinner, wsid string, state string, copyStore CopyStore, waitMsg string) error { isReady := false s.Suffix = waitMsg diff --git a/pkg/cmd/copy/copy_test.go b/pkg/cmd/copy/copy_test.go new file mode 100644 index 00000000..483aa31b --- /dev/null +++ b/pkg/cmd/copy/copy_test.go @@ -0,0 +1,90 @@ +package copy + +import ( + "testing" +) + +func TestParseCopyArguments_Upload(t *testing.T) { + ws, remotePath, localPath, isUpload, err := parseCopyArguments("./local.txt", "my-node:/tmp/dest") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ws != "my-node" { + t.Errorf("expected workspace my-node, got %s", ws) + } + if remotePath != "/tmp/dest" { + t.Errorf("expected remotePath /tmp/dest, got %s", remotePath) + } + if localPath != "./local.txt" { + t.Errorf("expected localPath ./local.txt, got %s", localPath) + } + if !isUpload { + t.Error("expected isUpload=true") + } +} + +func TestParseCopyArguments_Download(t *testing.T) { + ws, remotePath, localPath, isUpload, err := parseCopyArguments("my-node:/tmp/file", "./local.txt") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ws != "my-node" { + t.Errorf("expected workspace my-node, got %s", ws) + } + if remotePath != "/tmp/file" { + t.Errorf("expected remotePath /tmp/file, got %s", remotePath) + } + if localPath != "./local.txt" { + t.Errorf("expected localPath ./local.txt, got %s", localPath) + } + if isUpload { + t.Error("expected isUpload=false") + } +} + +func TestParseCopyArguments_BothLocal(t *testing.T) { + _, _, _, _, err := parseCopyArguments("./a", "./b") + if err == nil { + t.Fatal("expected error when both paths are local") + } +} + +func TestParseCopyArguments_BothRemote(t *testing.T) { + _, _, _, _, err := parseCopyArguments("ws1:/a", "ws2:/b") + if err == nil { + t.Fatal("expected error when both paths are remote") + } +} + +func TestParseWorkspacePath_Local(t *testing.T) { + ws, fp, err := parseWorkspacePath("/tmp/local/file") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ws != "" { + t.Errorf("expected empty workspace, got %s", ws) + } + if fp != "/tmp/local/file" { + t.Errorf("expected /tmp/local/file, got %s", fp) + } +} + +func TestParseWorkspacePath_Remote(t *testing.T) { + ws, fp, err := parseWorkspacePath("my-instance:/remote/path") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ws != "my-instance" { + t.Errorf("expected my-instance, got %s", ws) + } + if fp != "/remote/path" { + t.Errorf("expected /remote/path, got %s", fp) + } +} + +func TestParseWorkspacePath_InvalidMultipleColons(t *testing.T) { + _, _, err := parseWorkspacePath("ws:path:extra") + if err == nil { + t.Fatal("expected error for multiple colons") + } +} diff --git a/pkg/cmd/notebook/notebook.go b/pkg/cmd/notebook/notebook.go index 3a3bfbae..46d221da 100644 --- a/pkg/cmd/notebook/notebook.go +++ b/pkg/cmd/notebook/notebook.go @@ -25,7 +25,7 @@ type WorkspaceResult struct { Err error } -func NewCmdNotebook(store NotebookStore, _ *terminal.Terminal) *cobra.Command { +func NewCmdNotebook(store NotebookStore, t *terminal.Terminal) *cobra.Command { cmd := &cobra.Command{ Use: "notebook", Short: "Open a notebook on your Brev machine", @@ -66,7 +66,7 @@ func NewCmdNotebook(store NotebookStore, _ *terminal.Terminal) *cobra.Command { hello.TypeItToMeUnskippable27("\nClick here to go to your Jupyter notebook:\n\t πŸ‘‰" + urlType("http://localhost:8888") + "πŸ‘ˆ\n\n\n") // Port forward on 8888 - err2 := portforward.RunPortforward(store, args[0], "8888:8888", false) + err2 := portforward.RunPortforward(t, store, args[0], "8888:8888", false) if err2 != nil { return breverrors.WrapAndTrace(err2) } diff --git a/pkg/cmd/open/open.go b/pkg/cmd/open/open.go index 0ad063a0..e7378092 100644 --- a/pkg/cmd/open/open.go +++ b/pkg/cmd/open/open.go @@ -10,6 +10,8 @@ import ( "strings" "time" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "github.com/alessio/shellescape" "github.com/brevdev/brev-cli/pkg/analytics" "github.com/brevdev/brev-cli/pkg/cmd/cmderrors" @@ -281,10 +283,18 @@ func runOpenCommand(t *terminal.Terminal, tstore OpenStore, wsIDOrName string, s // todo check if workspace is stopped and start if it if it is stopped fmt.Println("finding your instance...") res := refresh.RunRefreshAsync(tstore) - workspace, err := util.GetUserWorkspaceByNameOrIDErr(tstore, wsIDOrName) + target, err := util.ResolveWorkspaceOrNode(tstore, wsIDOrName) if err != nil { return breverrors.WrapAndTrace(err) } + if target.Node != nil { + // Await refresh so SSH config entries are written for the node. + if awaitErr := res.Await(); awaitErr != nil { + return breverrors.WrapAndTrace(awaitErr) + } + return openExternalNode(t, tstore, target.Node, directory, editorType) + } + workspace := target.Workspace if workspace.Status == "STOPPED" { // we start the env for the user err = startWorkspaceIfStopped(t, tstore, wsIDOrName, workspace) if err != nil { @@ -356,6 +366,36 @@ func runOpenCommand(t *terminal.Terminal, tstore OpenStore, wsIDOrName string, s return nil } +func openExternalNode(t *terminal.Terminal, tstore OpenStore, node *nodev1.ExternalNode, directory string, editorType string) error { + info, err := util.ResolveExternalNodeSSH(tstore, node) + if err != nil { + return breverrors.WrapAndTrace(err) + } + alias := info.SSHAlias() + path := info.HomePath() + if directory != "" { + path = directory + } + + _ = hello.SetHasRunOpen(true) + + s := t.NewSpinner() + s.Start() + s.Suffix = " checking if your node is ready..." + err = waitForSSHToBeAvailable(t, s, alias) + if err != nil { + return breverrors.WrapAndTrace(err) + } + + editorName := getEditorName(editorType) + s.Suffix = fmt.Sprintf(" Node is ready. Opening %s", editorName) + time.Sleep(250 * time.Millisecond) + s.Stop() + t.Vprintf("\n") + + return openEditorByType(t, editorType, alias, path, tstore) +} + func pushOpenAnalytics(tstore OpenStore, workspace *entity.Workspace) error { userID := "" user, err := tstore.GetCurrentUser() diff --git a/pkg/cmd/open/open_test.go b/pkg/cmd/open/open_test.go index a1ed6d51..c35c8bc4 100644 --- a/pkg/cmd/open/open_test.go +++ b/pkg/cmd/open/open_test.go @@ -1 +1,43 @@ package open + +import ( + "testing" +) + +func TestIsEditorType(t *testing.T) { + valid := []string{"code", "cursor", "windsurf", "terminal", "tmux"} + for _, v := range valid { + if !isEditorType(v) { + t.Errorf("expected %q to be valid editor type", v) + } + } + + invalid := []string{"vim", "emacs", "vscode", "Code", "", "ssh"} + for _, v := range invalid { + if isEditorType(v) { + t.Errorf("expected %q to NOT be valid editor type", v) + } + } +} + +func TestGetEditorName(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"code", "VSCode"}, + {"cursor", "Cursor"}, + {"windsurf", "Windsurf"}, + {"terminal", "Terminal"}, + {"tmux", "tmux"}, + {"unknown", "VSCode"}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := getEditorName(tt.input) + if got != tt.want { + t.Errorf("getEditorName(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} diff --git a/pkg/cmd/portforward/portforward.go b/pkg/cmd/portforward/portforward.go index 70bfbd26..d0a6fa97 100644 --- a/pkg/cmd/portforward/portforward.go +++ b/pkg/cmd/portforward/portforward.go @@ -6,8 +6,11 @@ import ( "os/exec" "os/signal" "path/filepath" + "strconv" "strings" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "github.com/brevdev/brev-cli/pkg/cmd/cmderrors" "github.com/brevdev/brev-cli/pkg/cmd/completions" "github.com/brevdev/brev-cli/pkg/cmd/refresh" @@ -20,7 +23,7 @@ import ( var ( sshLinkLong = "Port forward your Brev machine's port to your local port" - sshLinkExample = "brev port-forward -p local_port:remote_port" + sshLinkExample = "brev port-forward -p local_port:remote_port" ) type PortforwardStore interface { @@ -46,14 +49,14 @@ func NewCmdPortForwardSSH(pfStore PortforwardStore, t *terminal.Terminal) *cobra if port == "" { port = startInput(t) } - err := RunPortforward(pfStore, args[0], port, useHost) + err := RunPortforward(t, pfStore, args[0], port, useHost) if err != nil { return breverrors.WrapAndTrace(err) } return nil }, } - cmd.Flags().StringVarP(&port, "port", "p", "", "port forward flag describe me better") + cmd.Flags().StringVarP(&port, "port", "p", "", "port forward string, local_port:remote_port") cmd.Flags().BoolVar(&useHost, "host", false, "Use the -host version of the instance") err := cmd.RegisterFlagCompletionFunc("port", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { return nil, cobra.ShellCompDirectiveNoSpace @@ -66,46 +69,97 @@ func NewCmdPortForwardSSH(pfStore PortforwardStore, t *terminal.Terminal) *cobra return cmd } -func RunPortforward(pfStore PortforwardStore, nameOrID string, portString string, useHost bool) error { - var portSplit []string - if strings.Contains(portString, ":") { - portSplit = strings.Split(portString, ":") - if len(portSplit) != 2 { - return breverrors.NewValidationError("port format invalid, use local_port:remote_port") - } - } else { - return breverrors.NewValidationError("port format invalid, use local_port:remote_port") +// parsePortString validates and splits a "local:remote" port string. +func parsePortString(portString string) (localPort, remotePort string, err error) { + if !strings.Contains(portString, ":") { + return "", "", breverrors.NewValidationError("port format invalid, use local_port:remote_port") + } + parts := strings.Split(portString, ":") + if len(parts) != 2 { + return "", "", breverrors.NewValidationError("port format invalid, use local_port:remote_port") + } + return parts[0], parts[1], nil +} + +// isPortAlreadyAllocatedError returns true if the error indicates the port is already open. +func isPortAlreadyAllocatedError(err error) bool { + return err != nil && strings.Contains(err.Error(), "already allocated") +} + +func RunPortforward(t *terminal.Terminal, pfStore PortforwardStore, nameOrID string, portString string, useHost bool) error { + localPort, remotePort, err := parsePortString(portString) + if err != nil { + return err } res := refresh.RunRefreshAsync(pfStore) - sshName, err := ConvertNametoSSHName(pfStore, nameOrID, useHost) + target, err := util.ResolveWorkspaceOrNode(pfStore, nameOrID) if err != nil { return breverrors.WrapAndTrace(err) } + if target.Node != nil { + return portForwardExternalNode(t, pfStore, res, target.Node, localPort, remotePort) + } + sshName := string(target.Workspace.GetLocalIdentifier()) + if useHost { + sshName += "-host" + } err = res.Await() if err != nil { return breverrors.WrapAndTrace(err) } - _, err = RunSSHPortForward("-L", portSplit[0], portSplit[1], sshName) + t.Vprintf("Port forwarding...\n") + t.Vprintf("localhost:%s -> %s:%s\n", localPort, sshName, remotePort) + _, err = RunSSHPortForward("-L", localPort, remotePort, sshName) if err != nil { return breverrors.WrapAndTrace(err) } return nil } -func ConvertNametoSSHName(store PortforwardStore, workspaceNameOrID string, useHost bool) (string, error) { - workspace, err := util.GetUserWorkspaceByNameOrIDErr(store, workspaceNameOrID) +func portForwardExternalNode(t *terminal.Terminal, pfStore PortforwardStore, res *refresh.RefreshRes, node *nodev1.ExternalNode, localPort, remotePort string) error { + info, err := util.ResolveExternalNodeSSH(pfStore, node) if err != nil { - return "", breverrors.WrapAndTrace(err) + return breverrors.WrapAndTrace(err) } - sshName := string(workspace.GetLocalIdentifier()) - if useHost { - sshName += "-host" + + // Parse the remote port so we can open it via the OpenPort RPC. + remotePortNum, err := strconv.ParseInt(remotePort, 10, 32) + if err != nil { + return breverrors.WrapAndTrace(fmt.Errorf("invalid remote port %q: %w", remotePort, err)) + } + + // Open the port on the netbird side so it's accessible. + // This binding persists after the CLI exits β€” it won't be closed on Ctrl+C. + t.Vprintf("Opening port %s on node %q...\n", remotePort, node.GetName()) + _, err = util.OpenPort(pfStore, node.GetExternalNodeId(), int32(remotePortNum), nodev1.PortProtocol_PORT_PROTOCOL_TCP) + if err != nil { + // Port already allocated is not a real error β€” it's already open. + if isPortAlreadyAllocatedError(err) { + t.Vprintf("Port %s is already open on the remote node.\n", remotePort) + } else { + return breverrors.WrapAndTrace(err) + } + } else { + t.Vprintf("Port %s is now bound on the remote node. Note: this binding persists after this command exits.\n", remotePort) + } + + if err := res.Await(); err != nil { + return breverrors.WrapAndTrace(err) + } + + // The SSH tunnel forwards local traffic through the SSH connection to the actual port on the box. + // TODO there isn't support for killing the port forward in either case, and no ClosePort for external node + alias := info.SSHAlias() + t.Vprintf("Setting up local forward: localhost:%s -> %s:%s\n", localPort, alias, remotePort) + _, err = RunSSHPortForward("-L", localPort, remotePort, alias) + if err != nil { + return breverrors.WrapAndTrace(err) } - return sshName, nil + return nil } func RunSSHPortForward(forwardType string, localPort string, remotePort string, sshName string) (*os.Process, error) { @@ -131,9 +185,6 @@ func RunSSHPortForward(forwardType string, localPort string, remotePort string, cmdSHH.Stdout = os.Stdout cmdSHH.Stderr = os.Stderr - fmt.Println("Port forwarding...") - fmt.Printf("localhost:%s -> %s:%s\n", localPort, sshName, remotePort) - err = cmdSHH.Start() if err != nil { return nil, breverrors.Wrap(err, "Failed to start SSH command") diff --git a/pkg/cmd/portforward/portforward_test.go b/pkg/cmd/portforward/portforward_test.go index 42e087ae..f17dd3b2 100644 --- a/pkg/cmd/portforward/portforward_test.go +++ b/pkg/cmd/portforward/portforward_test.go @@ -1 +1,70 @@ package portforward + +import ( + "fmt" + "testing" +) + +func TestParsePortString_Valid(t *testing.T) { + local, remote, err := parsePortString("8080:3000") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if local != "8080" { + t.Errorf("expected local 8080, got %s", local) + } + if remote != "3000" { + t.Errorf("expected remote 3000, got %s", remote) + } +} + +func TestParsePortString_SamePort(t *testing.T) { + local, remote, err := parsePortString("8080:8080") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if local != "8080" || remote != "8080" { + t.Errorf("expected 8080:8080, got %s:%s", local, remote) + } +} + +func TestParsePortString_NoColon(t *testing.T) { + _, _, err := parsePortString("8080") + if err == nil { + t.Fatal("expected error for missing colon") + } +} + +func TestParsePortString_TooManyColons(t *testing.T) { + _, _, err := parsePortString("8080:3000:443") + if err == nil { + t.Fatal("expected error for too many colons") + } +} + +func TestParsePortString_Empty(t *testing.T) { + _, _, err := parsePortString("") + if err == nil { + t.Fatal("expected error for empty string") + } +} + +func TestIsPortAlreadyAllocatedError_True(t *testing.T) { + err := fmt.Errorf("skybridge API error: 400, body: Port 8080 is already allocated for this client") + if !isPortAlreadyAllocatedError(err) { + t.Error("expected true for 'already allocated' error") + } +} + +func TestIsPortAlreadyAllocatedError_False(t *testing.T) { + err := fmt.Errorf("connection refused") + if isPortAlreadyAllocatedError(err) { + t.Error("expected false for unrelated error") + } +} + +func TestIsPortAlreadyAllocatedError_Nil(t *testing.T) { + if isPortAlreadyAllocatedError(nil) { + t.Error("expected false for nil error") + } +} diff --git a/pkg/cmd/refresh/refresh.go b/pkg/cmd/refresh/refresh.go index 8499156a..24df0986 100644 --- a/pkg/cmd/refresh/refresh.go +++ b/pkg/cmd/refresh/refresh.go @@ -2,12 +2,20 @@ package refresh import ( + "context" "fmt" "io" "io/fs" + "log" "sync" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" + + "github.com/brevdev/brev-cli/pkg/cmd/register" + "github.com/brevdev/brev-cli/pkg/cmd/util" "github.com/brevdev/brev-cli/pkg/cmdcontext" + "github.com/brevdev/brev-cli/pkg/config" "github.com/brevdev/brev-cli/pkg/entity" breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/brevdev/brev-cli/pkg/ssh" @@ -22,6 +30,8 @@ type RefreshStore interface { ssh.SSHConfigurerV2Store GetCurrentUser() (*entity.User, error) GetCurrentUserKeys() (*entity.UserKeys, error) + GetActiveOrganizationOrDefault() (*entity.Organization, error) + GetAccessToken() (string, error) Chmod(string, fs.FileMode) error MkdirAll(string, fs.FileMode) error GetBrevCloudflaredBinaryPath() (string, error) @@ -151,10 +161,46 @@ func GetConfigUpdater(store RefreshStore) (*ssh.ConfigUpdater, error) { } cu := ssh.NewConfigUpdater(store, configs, keys.PrivateKey) + cu.ExternalNodes = getExternalNodeSSHEntries(store) return cu, nil } +// getExternalNodeSSHEntries fetches external nodes and resolves their SSH details. +// This is best-effort: if anything fails, it returns nil so workspace SSH config is unaffected. +func getExternalNodeSSHEntries(store RefreshStore) []ssh.ExternalNodeSSHEntry { + org, err := store.GetActiveOrganizationOrDefault() + if err != nil { + log.Printf("external nodes: skipping (no org): %v", err) + return nil + } + + user, err := store.GetCurrentUser() + if err != nil { + log.Printf("external nodes: skipping (no user): %v", err) + return nil + } + + client := register.NewNodeServiceClient(store, config.GlobalConfig.GetBrevPublicAPIURL()) + resp, err := client.ListNodes(context.Background(), connect.NewRequest(&nodev1.ListNodesRequest{ + OrganizationId: org.ID, + })) + if err != nil { + log.Printf("external nodes: skipping (list failed): %v", err) + return nil + } + + var entries []ssh.ExternalNodeSSHEntry + for _, node := range resp.Msg.GetItems() { + entry := util.ResolveNodeSSHEntry(user.ID, node) + if entry != nil { + entries = append(entries, *entry) + } + } + + return entries +} + func GetCloudflare(refreshStore RefreshStore) store.Cloudflared { cl := store.NewCloudflare(refreshStore) return cl diff --git a/pkg/cmd/refresh/refresh_test.go b/pkg/cmd/refresh/refresh_test.go index d8c53677..bf13a84a 100644 --- a/pkg/cmd/refresh/refresh_test.go +++ b/pkg/cmd/refresh/refresh_test.go @@ -1 +1,170 @@ package refresh + +import ( + "testing" + + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + + "github.com/brevdev/brev-cli/pkg/cmd/util" +) + +func strPtr(s string) *string { return &s } + +func TestResolveNodeSSHEntry_HappyPath(t *testing.T) { + node := &nodev1.ExternalNode{ + Name: "My GPU Box", + SshAccess: []*nodev1.SSHAccess{ + {UserId: "user_1", LinuxUser: "ec2-user"}, + }, + Ports: []*nodev1.Port{ + { + Protocol: nodev1.PortProtocol_PORT_PROTOCOL_SSH, + PortNumber: 22, + ServerPort: 41920, + Hostname: strPtr("10.0.0.5"), + }, + }, + } + + entry := util.ResolveNodeSSHEntry("user_1", node) + if entry == nil { + t.Fatal("expected non-nil entry") + } + if entry.Alias != "my-gpu-box" { + t.Errorf("expected alias my-gpu-box, got %s", entry.Alias) + } + if entry.Hostname != "10.0.0.5" { + t.Errorf("expected hostname 10.0.0.5, got %s", entry.Hostname) + } + if entry.Port != 41920 { + t.Errorf("expected port 41920 (ServerPort), got %d", entry.Port) + } + if entry.User != "ec2-user" { + t.Errorf("expected user ec2-user, got %s", entry.User) + } +} + +func TestResolveNodeSSHEntry_UsesServerPortNotPortNumber(t *testing.T) { + node := &nodev1.ExternalNode{ + Name: "test-node", + SshAccess: []*nodev1.SSHAccess{ + {UserId: "user_1", LinuxUser: "ubuntu"}, + }, + Ports: []*nodev1.Port{ + { + Protocol: nodev1.PortProtocol_PORT_PROTOCOL_SSH, + PortNumber: 22, // well-known port β€” NOT what we should connect to + ServerPort: 51234, // netbird-assigned port β€” correct + Hostname: strPtr("gateway.example.com"), + }, + }, + } + + entry := util.ResolveNodeSSHEntry("user_1", node) + if entry == nil { + t.Fatal("expected non-nil entry") + } + if entry.Port != 51234 { + t.Errorf("expected ServerPort 51234, got %d (should not use PortNumber 22)", entry.Port) + } +} + +func TestResolveNodeSSHEntry_SkipsNoAccess(t *testing.T) { + node := &nodev1.ExternalNode{ + Name: "box", + SshAccess: []*nodev1.SSHAccess{ + {UserId: "other_user", LinuxUser: "ubuntu"}, + }, + Ports: []*nodev1.Port{ + {Protocol: nodev1.PortProtocol_PORT_PROTOCOL_SSH, ServerPort: 22, Hostname: strPtr("h")}, + }, + } + + entry := util.ResolveNodeSSHEntry("user_1", node) + if entry != nil { + t.Errorf("expected nil for no access, got %+v", entry) + } +} + +func TestResolveNodeSSHEntry_SkipsNoSSHPort(t *testing.T) { + node := &nodev1.ExternalNode{ + Name: "box", + SshAccess: []*nodev1.SSHAccess{ + {UserId: "user_1", LinuxUser: "ubuntu"}, + }, + Ports: []*nodev1.Port{ + {Protocol: nodev1.PortProtocol_PORT_PROTOCOL_TCP, ServerPort: 8080, Hostname: strPtr("h")}, + }, + } + + entry := util.ResolveNodeSSHEntry("user_1", node) + if entry != nil { + t.Errorf("expected nil for no SSH port, got %+v", entry) + } +} + +func TestResolveNodeSSHEntry_SkipsEmptyHostname(t *testing.T) { + node := &nodev1.ExternalNode{ + Name: "box", + SshAccess: []*nodev1.SSHAccess{ + {UserId: "user_1", LinuxUser: "ubuntu"}, + }, + Ports: []*nodev1.Port{ + {Protocol: nodev1.PortProtocol_PORT_PROTOCOL_SSH, ServerPort: 22}, + }, + } + + entry := util.ResolveNodeSSHEntry("user_1", node) + if entry != nil { + t.Errorf("expected nil for empty hostname, got %+v", entry) + } +} + +func TestResolveNodeSSHEntry_MultipleNodes(t *testing.T) { + nodes := []*nodev1.ExternalNode{ + { + Name: "Node A", + SshAccess: []*nodev1.SSHAccess{ + {UserId: "user_1", LinuxUser: "ubuntu"}, + }, + Ports: []*nodev1.Port{ + {Protocol: nodev1.PortProtocol_PORT_PROTOCOL_SSH, ServerPort: 41000, Hostname: strPtr("10.0.0.1")}, + }, + }, + { + Name: "Node B", + SshAccess: []*nodev1.SSHAccess{ + {UserId: "other_user", LinuxUser: "root"}, + }, + Ports: []*nodev1.Port{ + {Protocol: nodev1.PortProtocol_PORT_PROTOCOL_SSH, ServerPort: 42000, Hostname: strPtr("10.0.0.2")}, + }, + }, + { + Name: "Node C", + SshAccess: []*nodev1.SSHAccess{ + {UserId: "user_1", LinuxUser: "admin"}, + }, + Ports: []*nodev1.Port{ + {Protocol: nodev1.PortProtocol_PORT_PROTOCOL_SSH, ServerPort: 43000, Hostname: strPtr("10.0.0.3")}, + }, + }, + } + + var entries []string + for _, node := range nodes { + entry := util.ResolveNodeSSHEntry("user_1", node) + if entry != nil { + entries = append(entries, entry.Alias) + } + } + if len(entries) != 2 { + t.Fatalf("expected 2 entries (skipping Node B), got %d", len(entries)) + } + if entries[0] != "node-a" { + t.Errorf("expected alias node-a, got %s", entries[0]) + } + if entries[1] != "node-c" { + t.Errorf("expected alias node-c, got %s", entries[1]) + } +} diff --git a/pkg/cmd/register/register.go b/pkg/cmd/register/register.go index 63cc9550..7da520e2 100644 --- a/pkg/cmd/register/register.go +++ b/pkg/cmd/register/register.go @@ -241,7 +241,7 @@ func checkExistingRegistration(ctx context.Context, t *terminal.Terminal, s Regi t.Vprintf("This machine is already registered as %q.\n", reg.DisplayName) t.Vprint("Run 'brev deregister' first if you want to re-register with a different name.") t.Vprint("") - t.Vprintf("If you are having tunnel issues, run 'brev register %q' to reconnect.", reg.DisplayName) + t.Vprintf("If you are having tunnel issues, run 'brev register %q' to reconnect.\n", reg.DisplayName) return nil } diff --git a/pkg/cmd/shell/shell.go b/pkg/cmd/shell/shell.go index 22db1ac7..3998dbe5 100644 --- a/pkg/cmd/shell/shell.go +++ b/pkg/cmd/shell/shell.go @@ -6,6 +6,8 @@ import ( "os/exec" "time" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "github.com/brevdev/brev-cli/pkg/analytics" "github.com/brevdev/brev-cli/pkg/cmd/completions" "github.com/brevdev/brev-cli/pkg/cmd/hello" @@ -76,10 +78,14 @@ const pollTimeout = 10 * time.Minute func runShellCommand(t *terminal.Terminal, sstore ShellStore, workspaceNameOrID string, host bool) error { s := t.NewSpinner() - workspace, err := util.GetUserWorkspaceByNameOrIDErr(sstore, workspaceNameOrID) + target, err := util.ResolveWorkspaceOrNode(sstore, workspaceNameOrID) if err != nil { return breverrors.WrapAndTrace(err) } + if target.Node != nil { + return shellIntoExternalNode(t, sstore, target.Node) + } + workspace := target.Workspace if workspace.Status == "STOPPED" { // we start the env for the user err = util.StartWorkspaceIfStopped(t, s, sstore, workspaceNameOrID, workspace, pollTimeout) @@ -144,6 +150,42 @@ func runShellCommand(t *terminal.Terminal, sstore ShellStore, workspaceNameOrID return nil } +func shellIntoExternalNode(t *terminal.Terminal, sstore ShellStore, node *nodev1.ExternalNode) error { + info, err := util.ResolveExternalNodeSSH(sstore, node) + if err != nil { + return breverrors.WrapAndTrace(err) + } + + privateKeyPath, err := sstore.GetPrivateKeyPath() + if err != nil { + return breverrors.WrapAndTrace(err) + } + + t.Vprintf("Connecting to external node %q as %s on port %d (key: %s)...\n", node.GetName(), info.LinuxUser, info.Port, privateKeyPath) + return runSSHWithPort(info.SSHTarget(), info.Port, privateKeyPath) +} + +func runSSHWithPort(target string, port int32, identityFile string) error { + sshAgentEval := "eval $(ssh-agent -s)" + cmd := fmt.Sprintf("%s && ssh -i %q -o StrictHostKeyChecking=no -p %d %s", sshAgentEval, identityFile, port, target) + + sshCmd := exec.Command("bash", "-c", cmd) //nolint:gosec //cmd is constructed from API data + sshCmd.Stderr = os.Stderr + sshCmd.Stdout = os.Stdout + sshCmd.Stdin = os.Stdin + + err := hello.SetHasRunShell(true) + if err != nil { + return breverrors.WrapAndTrace(err) + } + + err = sshCmd.Run() + if err != nil { + return breverrors.WrapAndTrace(err) + } + return nil +} + func runSSH(sshAlias string) error { sshAgentEval := "eval $(ssh-agent -s)" cmd := fmt.Sprintf("%s && ssh %s", sshAgentEval, sshAlias) diff --git a/pkg/cmd/shell/shell_test.go b/pkg/cmd/shell/shell_test.go index b1f847e4..172e3b3c 100644 --- a/pkg/cmd/shell/shell_test.go +++ b/pkg/cmd/shell/shell_test.go @@ -1 +1,102 @@ package shell + +import ( + "testing" + + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + + "github.com/brevdev/brev-cli/pkg/cmd/util" +) + +func strPtr(s string) *string { return &s } + +// TestResolveExternalNodeSSH_BuildsCorrectInfo tests that the SSH info +// returned by ResolveNodeSSHEntry has the correct target, alias, and home path β€” +// the same values shellIntoExternalNode uses to build its SSH command. +func TestResolveExternalNodeSSH_BuildsCorrectInfo(t *testing.T) { + node := &nodev1.ExternalNode{ + Name: "My GPU Box", + SshAccess: []*nodev1.SSHAccess{ + {UserId: "user_1", LinuxUser: "ec2-user"}, + }, + Ports: []*nodev1.Port{ + { + Protocol: nodev1.PortProtocol_PORT_PROTOCOL_SSH, + PortNumber: 22, + ServerPort: 41920, + Hostname: strPtr("10.0.0.5"), + }, + }, + } + + entry := util.ResolveNodeSSHEntry("user_1", node) + if entry == nil { + t.Fatal("expected non-nil entry") + } + + // Build ExternalNodeSSHInfo the same way ResolveExternalNodeSSH does. + info := &util.ExternalNodeSSHInfo{ + Node: node, + LinuxUser: entry.User, + Hostname: entry.Hostname, + Port: entry.Port, + } + + // SSHTarget β€” used by runSSHWithPort + if got := info.SSHTarget(); got != "ec2-user@10.0.0.5" { + t.Errorf("SSHTarget() = %q, want %q", got, "ec2-user@10.0.0.5") + } + + // SSHAlias β€” used by SSH config-based commands (open, copy, port-forward) + if got := info.SSHAlias(); got != "my-gpu-box" { + t.Errorf("SSHAlias() = %q, want %q", got, "my-gpu-box") + } + + // HomePath β€” used by open to set the remote directory + if got := info.HomePath(); got != "/home/ec2-user" { + t.Errorf("HomePath() = %q, want %q", got, "/home/ec2-user") + } + + // Port β€” shellIntoExternalNode passes this to runSSHWithPort + if info.Port != 41920 { + t.Errorf("Port = %d, want 41920", info.Port) + } +} + +// TestResolveExternalNodeSSH_NoAccess verifies that a node without SSH access +// for the given user returns nil, which shellIntoExternalNode would treat as an error. +func TestResolveExternalNodeSSH_NoAccess(t *testing.T) { + node := &nodev1.ExternalNode{ + Name: "locked-box", + SshAccess: []*nodev1.SSHAccess{ + {UserId: "other_user", LinuxUser: "root"}, + }, + Ports: []*nodev1.Port{ + {Protocol: nodev1.PortProtocol_PORT_PROTOCOL_SSH, ServerPort: 22, Hostname: strPtr("10.0.0.1")}, + }, + } + + entry := util.ResolveNodeSSHEntry("user_1", node) + if entry != nil { + t.Errorf("expected nil for user without access, got %+v", entry) + } +} + +// TestResolveExternalNodeSSH_NoSSHPort verifies that a node with no SSH port +// returns nil even when the user has access. +func TestResolveExternalNodeSSH_NoSSHPort(t *testing.T) { + node := &nodev1.ExternalNode{ + Name: "no-ssh", + SshAccess: []*nodev1.SSHAccess{ + {UserId: "user_1", LinuxUser: "ubuntu"}, + }, + Ports: []*nodev1.Port{ + {Protocol: nodev1.PortProtocol_PORT_PROTOCOL_TCP, ServerPort: 8080, Hostname: strPtr("10.0.0.1")}, + }, + } + + entry := util.ResolveNodeSSHEntry("user_1", node) + if entry != nil { + t.Errorf("expected nil for node without SSH port, got %+v", entry) + } +} diff --git a/pkg/cmd/util/externalnode.go b/pkg/cmd/util/externalnode.go new file mode 100644 index 00000000..bbf29f47 --- /dev/null +++ b/pkg/cmd/util/externalnode.go @@ -0,0 +1,165 @@ +package util + +import ( + "context" + "fmt" + "strings" + + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" + + "github.com/brevdev/brev-cli/pkg/cmd/register" + "github.com/brevdev/brev-cli/pkg/config" + "github.com/brevdev/brev-cli/pkg/entity" + breverrors "github.com/brevdev/brev-cli/pkg/errors" + "github.com/brevdev/brev-cli/pkg/ssh" +) + +// ExternalNodeStore is the minimal interface needed for external node lookup and SSH resolution. +type ExternalNodeStore interface { + GetActiveOrganizationOrDefault() (*entity.Organization, error) + GetAccessToken() (string, error) + GetCurrentUser() (*entity.User, error) +} + +type WorkspaceOrNodeResolver interface { + GetWorkspaceByNameOrIDErrStore + ExternalNodeStore +} + +// WorkspaceOrNode is returned by ResolveWorkspaceOrNode. Exactly one field is non-nil. +type WorkspaceOrNode struct { + Workspace *entity.Workspace + Node *nodev1.ExternalNode +} + +// ResolveWorkspaceOrNode looks up a workspace first; if not found, falls back to external nodes. +// The store must satisfy both GetWorkspaceByNameOrIDErrStore and ExternalNodeStore. +func ResolveWorkspaceOrNode(store WorkspaceOrNodeResolver, nameOrID string, +) (*WorkspaceOrNode, error) { + workspace, wsErr := GetUserWorkspaceByNameOrIDErr(store, nameOrID) + if wsErr == nil { + return &WorkspaceOrNode{Workspace: workspace}, nil + } + node, nodeErr := FindExternalNode(store, nameOrID) + if nodeErr != nil || node == nil { + return nil, wsErr // return original workspace error + } + return &WorkspaceOrNode{Node: node}, nil +} + +// ExternalNodeSSHInfo holds resolved SSH connection details for an external node. +type ExternalNodeSSHInfo struct { + Node *nodev1.ExternalNode + LinuxUser string + Hostname string + Port int32 +} + +// SSHTarget returns the "user@host" string for direct SSH. +func (info *ExternalNodeSSHInfo) SSHTarget() string { + return fmt.Sprintf("%s@%s", info.LinuxUser, info.Hostname) +} + +// SSHAlias returns a sanitized node name suitable for use as an SSH config Host alias. +func (info *ExternalNodeSSHInfo) SSHAlias() string { + return ssh.SanitizeNodeName(info.Node.GetName()) +} + +// HomePath returns the home directory path for the linux user. +func (info *ExternalNodeSSHInfo) HomePath() string { + return fmt.Sprintf("/home/%s", info.LinuxUser) +} + +// ResolveNodeSSHEntry is a pure data function that extracts the SSH config entry +// for a given user from a node. Returns nil if the user has no access or the node +// has no SSH port. This is the single source of truth for nodeβ†’SSHEntry conversion, +// used by both ResolveExternalNodeSSH (for commands) and refresh (for SSH config generation). +func ResolveNodeSSHEntry(userID string, node *nodev1.ExternalNode) *ssh.ExternalNodeSSHEntry { + var linuxUser string + for _, access := range node.GetSshAccess() { + if access.GetUserId() == userID { + linuxUser = access.GetLinuxUser() + break + } + } + if linuxUser == "" { + return nil + } + + var sshPort *nodev1.Port + for _, p := range node.GetPorts() { + if p.GetProtocol() == nodev1.PortProtocol_PORT_PROTOCOL_SSH { + sshPort = p + break + } + } + if sshPort == nil || sshPort.GetHostname() == "" { + return nil + } + + return &ssh.ExternalNodeSSHEntry{ + Alias: ssh.SanitizeNodeName(node.GetName()), + Hostname: sshPort.GetHostname(), + Port: sshPort.GetServerPort(), + User: linuxUser, + } +} + +// OpenPort calls the OpenPort RPC to open a port on an external node via netbird. +// This must be called before attempting to connect to a non-SSH port on a node. +func OpenPort(store ExternalNodeStore, nodeID string, portNumber int32, protocol nodev1.PortProtocol) (*nodev1.Port, error) { + client := register.NewNodeServiceClient(store, config.GlobalConfig.GetBrevPublicAPIURL()) + resp, err := client.OpenPort(context.Background(), connect.NewRequest(&nodev1.OpenPortRequest{ + ExternalNodeId: nodeID, + Protocol: protocol, + PortNumber: portNumber, + })) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + return resp.Msg.GetPort(), nil +} + +// FindExternalNode searches for an external node by name in the user's active organization. +// Returns (nil, nil) if no matching node is found. +func FindExternalNode(store ExternalNodeStore, name string) (*nodev1.ExternalNode, error) { + org, err := store.GetActiveOrganizationOrDefault() + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + client := register.NewNodeServiceClient(store, config.GlobalConfig.GetBrevPublicAPIURL()) + resp, err := client.ListNodes(context.Background(), connect.NewRequest(&nodev1.ListNodesRequest{ + OrganizationId: org.ID, + })) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + for _, node := range resp.Msg.GetItems() { + if strings.EqualFold(node.GetName(), name) { + return node, nil + } + } + return nil, nil +} + +// ResolveExternalNodeSSH resolves the SSH connection details for an external node +// by finding the current user's SSH access and the node's SSH port. +func ResolveExternalNodeSSH(store ExternalNodeStore, node *nodev1.ExternalNode) (*ExternalNodeSSHInfo, error) { + user, err := store.GetCurrentUser() + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + + entry := ResolveNodeSSHEntry(user.ID, node) + if entry == nil { + return nil, breverrors.New(fmt.Sprintf("cannot resolve SSH for node %q β€” no access, no SSH port, or no hostname", node.GetName())) + } + + return &ExternalNodeSSHInfo{ + Node: node, + LinuxUser: entry.User, + Hostname: entry.Hostname, + Port: entry.Port, + }, nil +} diff --git a/pkg/cmd/util/externalnode_test.go b/pkg/cmd/util/externalnode_test.go new file mode 100644 index 00000000..c8521b83 --- /dev/null +++ b/pkg/cmd/util/externalnode_test.go @@ -0,0 +1,187 @@ +package util + +import ( + "fmt" + "testing" + + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + + "github.com/brevdev/brev-cli/pkg/entity" +) + +// mockExternalNodeStore satisfies ExternalNodeStore for unit tests that +// only exercise ResolveExternalNodeSSH (no RPC calls). +type mockExternalNodeStore struct { + user *entity.User + org *entity.Organization + err error +} + +func (m *mockExternalNodeStore) GetActiveOrganizationOrDefault() (*entity.Organization, error) { + return m.org, nil +} + +func (m *mockExternalNodeStore) GetAccessToken() (string, error) { return "tok", nil } + +func (m *mockExternalNodeStore) GetCurrentUser() (*entity.User, error) { + if m.err != nil { + return nil, m.err + } + return m.user, nil +} + +func strPtr(s string) *string { return &s } + +func makeTestNode(name, userID, linuxUser, hostname string, serverPort int32) *nodev1.ExternalNode { + return &nodev1.ExternalNode{ + ExternalNodeId: "unode_test", + Name: name, + SshAccess: []*nodev1.SSHAccess{ + {UserId: userID, LinuxUser: linuxUser}, + }, + Ports: []*nodev1.Port{ + { + Protocol: nodev1.PortProtocol_PORT_PROTOCOL_SSH, + PortNumber: 22, + ServerPort: serverPort, + Hostname: &hostname, + }, + }, + } +} + +func TestResolveExternalNodeSSH_HappyPath(t *testing.T) { + store := &mockExternalNodeStore{ + user: &entity.User{ID: "user_1"}, + } + node := makeTestNode("My GPU Box", "user_1", "ec2-user", "10.0.0.5", 51234) + + info, err := ResolveExternalNodeSSH(store, node) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if info.LinuxUser != "ec2-user" { + t.Errorf("expected ec2-user, got %s", info.LinuxUser) + } + if info.Hostname != "10.0.0.5" { + t.Errorf("expected 10.0.0.5, got %s", info.Hostname) + } + if info.Port != 51234 { + t.Errorf("expected port 51234, got %d", info.Port) + } +} + +func TestResolveExternalNodeSSH_UsesServerPortNotPortNumber(t *testing.T) { + store := &mockExternalNodeStore{ + user: &entity.User{ID: "user_1"}, + } + node := &nodev1.ExternalNode{ + Name: "test-node", + SshAccess: []*nodev1.SSHAccess{ + {UserId: "user_1", LinuxUser: "ubuntu"}, + }, + Ports: []*nodev1.Port{ + { + Protocol: nodev1.PortProtocol_PORT_PROTOCOL_SSH, + PortNumber: 22, // well-known port β€” NOT what we connect to + ServerPort: 41920, // netbird-assigned port β€” this is correct + Hostname: strPtr("gateway.example.com"), + }, + }, + } + + info, err := ResolveExternalNodeSSH(store, node) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if info.Port != 41920 { + t.Errorf("expected ServerPort 41920, got %d (should not use PortNumber 22)", info.Port) + } +} + +func TestResolveExternalNodeSSH_NoAccess(t *testing.T) { + store := &mockExternalNodeStore{ + user: &entity.User{ID: "user_1"}, + } + node := makeTestNode("box", "other_user", "ubuntu", "10.0.0.5", 22) + + _, err := ResolveExternalNodeSSH(store, node) + if err == nil { + t.Fatal("expected error for no SSH access") + } +} + +func TestResolveExternalNodeSSH_NoSSHPort(t *testing.T) { + store := &mockExternalNodeStore{ + user: &entity.User{ID: "user_1"}, + } + node := &nodev1.ExternalNode{ + Name: "box", + SshAccess: []*nodev1.SSHAccess{ + {UserId: "user_1", LinuxUser: "ubuntu"}, + }, + Ports: []*nodev1.Port{ + {Protocol: nodev1.PortProtocol_PORT_PROTOCOL_TCP, Hostname: strPtr("h")}, + }, + } + + _, err := ResolveExternalNodeSSH(store, node) + if err == nil { + t.Fatal("expected error for no SSH port") + } +} + +func TestResolveExternalNodeSSH_EmptyHostname(t *testing.T) { + store := &mockExternalNodeStore{ + user: &entity.User{ID: "user_1"}, + } + node := &nodev1.ExternalNode{ + Name: "box", + SshAccess: []*nodev1.SSHAccess{ + {UserId: "user_1", LinuxUser: "ubuntu"}, + }, + Ports: []*nodev1.Port{ + {Protocol: nodev1.PortProtocol_PORT_PROTOCOL_SSH, ServerPort: 22}, + }, + } + + _, err := ResolveExternalNodeSSH(store, node) + if err == nil { + t.Fatal("expected error for empty hostname") + } +} + +func TestResolveExternalNodeSSH_GetUserError(t *testing.T) { + store := &mockExternalNodeStore{ + err: fmt.Errorf("auth failed"), + } + node := makeTestNode("box", "user_1", "ubuntu", "h", 22) + + _, err := ResolveExternalNodeSSH(store, node) + if err == nil { + t.Fatal("expected error when GetCurrentUser fails") + } +} + +func TestExternalNodeSSHInfo_SSHTarget(t *testing.T) { + info := &ExternalNodeSSHInfo{LinuxUser: "ec2-user", Hostname: "10.0.0.5"} + if got := info.SSHTarget(); got != "ec2-user@10.0.0.5" { + t.Errorf("expected ec2-user@10.0.0.5, got %s", got) + } +} + +func TestExternalNodeSSHInfo_SSHAlias(t *testing.T) { + info := &ExternalNodeSSHInfo{ + Node: &nodev1.ExternalNode{Name: "My GPU Box"}, + } + if got := info.SSHAlias(); got != "my-gpu-box" { + t.Errorf("expected my-gpu-box, got %s", got) + } +} + +func TestExternalNodeSSHInfo_HomePath(t *testing.T) { + info := &ExternalNodeSSHInfo{LinuxUser: "ec2-user"} + if got := info.HomePath(); got != "/home/ec2-user" { + t.Errorf("expected /home/ec2-user, got %s", got) + } +} diff --git a/pkg/ssh/sshconfigurer.go b/pkg/ssh/sshconfigurer.go index 7539cba3..8c5f88d3 100644 --- a/pkg/ssh/sshconfigurer.go +++ b/pkg/ssh/sshconfigurer.go @@ -5,6 +5,7 @@ import ( "encoding/xml" "fmt" "log" + "regexp" "strings" "text/template" @@ -16,6 +17,66 @@ import ( "github.com/hashicorp/go-multierror" ) +// ExternalNodeSSHEntry holds pre-resolved SSH details for an external node. +type ExternalNodeSSHEntry struct { + Alias string + Hostname string + Port int32 + User string +} + +var ( + sanitizeNodeNameRe = regexp.MustCompile(`[^a-z0-9-]+`) + collapseHyphensRe = regexp.MustCompile(`-{2,}`) +) + +// SanitizeNodeName converts a node display name into a valid SSH Host alias. +func SanitizeNodeName(name string) string { + s := strings.ToLower(name) + s = sanitizeNodeNameRe.ReplaceAllString(s, "-") + s = collapseHyphensRe.ReplaceAllString(s, "-") + s = strings.Trim(s, "-") + if s == "" { + s = "node" + } + return s +} + +const SSHConfigEntryTemplateNode = `Host {{ .Alias }} + HostName {{ .Hostname }} + User {{ .User }} + Port {{ .Port }} + IdentityFile {{ .IdentityFile }} + StrictHostKeyChecking no + UserKnownHostsFile /dev/null + ServerAliveInterval 30 + ForwardAgent yes + +` + +type externalNodeSSHConfigEntry struct { + Alias string + Hostname string + User string + Port int32 + IdentityFile string +} + +func makeSSHConfigEntryForNode(node ExternalNodeSSHEntry, privateKeyPath string) (string, error) { + entry := externalNodeSSHConfigEntry{ + Alias: node.Alias, + Hostname: node.Hostname, + User: node.User, + Port: node.Port, + IdentityFile: "\"" + privateKeyPath + "\"", + } + tmpl, err := template.New(node.Alias).Parse(SSHConfigEntryTemplateNode) + if err != nil { + return "", breverrors.WrapAndTrace(err) + } + return tmplAndValToString(tmpl, entry) +} + type ConfigUpdaterStore interface { autostartconf.AutoStartStore GetContextWorkspaces() ([]entity.Workspace, error) @@ -23,13 +84,14 @@ type ConfigUpdaterStore interface { } type Config interface { - Update(workspaces []entity.Workspace) error + Update(workspaces []entity.Workspace, nodes []ExternalNodeSSHEntry) error } type ConfigUpdater struct { - Store ConfigUpdaterStore - Configs []Config - PrivateKey string + Store ConfigUpdaterStore + Configs []Config + PrivateKey string + ExternalNodes []ExternalNodeSSHEntry } func NewConfigUpdater(store ConfigUpdaterStore, configs []Config, privateKey string) *ConfigUpdater { @@ -59,8 +121,8 @@ func (c ConfigUpdater) Run() error { } var res error - for _, c := range c.Configs { - err := c.Update(runningWorkspaces) + for _, cfg := range c.Configs { + err := cfg.Update(runningWorkspaces, c.ExternalNodes) if err != nil { res = multierror.Append(res, err) } @@ -121,8 +183,8 @@ func NewSSHConfigurerV2(store SSHConfigurerV2Store) *SSHConfigurerV2 { } } -func (s SSHConfigurerV2) Update(workspaces []entity.Workspace) error { - newConfig, err := s.CreateNewSSHConfig(workspaces) +func (s SSHConfigurerV2) Update(workspaces []entity.Workspace, nodes []ExternalNodeSSHEntry) error { + newConfig, err := s.CreateNewSSHConfig(workspaces, nodes) if err != nil { return breverrors.WrapAndTrace(err) } @@ -183,7 +245,7 @@ func (s SSHConfigurerV2) CreateWSLConfig(workspaces []entity.Workspace) (string, return sshConfig, nil } -func (s SSHConfigurerV2) CreateNewSSHConfig(workspaces []entity.Workspace) (string, error) { +func (s SSHConfigurerV2) CreateNewSSHConfig(workspaces []entity.Workspace, nodes []ExternalNodeSSHEntry) (string, error) { configPath, err := s.store.GetUserSSHConfigPath() if err != nil { return "", breverrors.WrapAndTrace(err) @@ -203,6 +265,15 @@ func (s SSHConfigurerV2) CreateNewSSHConfig(workspaces []entity.Workspace) (stri if err != nil { return "", breverrors.WrapAndTrace(err) } + + for _, node := range nodes { + entry, err := makeSSHConfigEntryForNode(node, pkPath) + if err != nil { + return "", breverrors.WrapAndTrace(err) + } + sshConfig += entry + } + return sshConfig, nil } @@ -570,7 +641,7 @@ func NewSSHConfigurerJetBrains(store SSHConfigurerV2Store) (*SSHConfigurerJetBra }, nil } -func (s SSHConfigurerJetBrains) Update(workspaces []entity.Workspace) error { +func (s SSHConfigurerJetBrains) Update(workspaces []entity.Workspace, _ []ExternalNodeSSHEntry) error { doesJbPathExist, err := s.store.DoesJetbrainsFilePathExist() if err != nil { return breverrors.WrapAndTrace(err) diff --git a/pkg/ssh/sshconfigurer_test.go b/pkg/ssh/sshconfigurer_test.go index f38d16af..5b242dcf 100644 --- a/pkg/ssh/sshconfigurer_test.go +++ b/pkg/ssh/sshconfigurer_test.go @@ -129,7 +129,7 @@ func (d DummySSHConfigurerV2Store) GetBrevCloudflaredBinaryPath() (string, error func TestCreateNewSSHConfig(t *testing.T) { c := NewSSHConfigurerV2(DummySSHConfigurerV2Store{}) - cStr, err := c.CreateNewSSHConfig(somePlainWorkspaces) + cStr, err := c.CreateNewSSHConfig(somePlainWorkspaces, nil) assert.Nil(t, err) // sometimes vs code is not happy with the formatting @@ -199,7 +199,7 @@ Host %s-host ) assert.Equal(t, correct, cStr) - cStr, err = c.CreateNewSSHConfig([]entity.Workspace{}) + cStr, err = c.CreateNewSSHConfig([]entity.Workspace{}, nil) assert.Nil(t, err) correct = `# included in /my/user/config ` @@ -559,6 +559,115 @@ Host testName2-host } } +func TestSanitizeNodeName(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"My GPU Box", "my-gpu-box"}, + {"pratik-ec2", "pratik-ec2"}, + {"already-clean", "already-clean"}, + {"UPPER CASE", "upper-case"}, + {"special!@#chars", "special-chars"}, + {" leading/trailing ", "leading-trailing"}, + {"multiple spaces", "multiple-spaces"}, + {"", "node"}, + {"!!!!", "node"}, + {"a", "a"}, + {"node-with--double-dash", "node-with-double-dash"}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := SanitizeNodeName(tt.input) + if got != tt.want { + t.Errorf("SanitizeNodeName(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestMakeSSHConfigEntryForNode(t *testing.T) { + entry := ExternalNodeSSHEntry{ + Alias: "my-gpu-box", + Hostname: "10.0.0.5", + Port: 41920, + User: "ec2-user", + } + + got, err := makeSSHConfigEntryForNode(entry, "/home/test/.brev/brev.pem") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + want := `Host my-gpu-box + HostName 10.0.0.5 + User ec2-user + Port 41920 + IdentityFile "/home/test/.brev/brev.pem" + StrictHostKeyChecking no + UserKnownHostsFile /dev/null + ServerAliveInterval 30 + ForwardAgent yes + +` + if got != want { + t.Errorf("makeSSHConfigEntryForNode() mismatch:\ngot:\n%s\nwant:\n%s", got, want) + } +} + +func TestCreateNewSSHConfig_WithNodes(t *testing.T) { + c := NewSSHConfigurerV2(DummySSHConfigurerV2Store{}) + + nodes := []ExternalNodeSSHEntry{ + {Alias: "gpu-box", Hostname: "10.0.0.5", Port: 41920, User: "ec2-user"}, + } + + cStr, err := c.CreateNewSSHConfig([]entity.Workspace{}, nodes) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + want := `# included in /my/user/config +Host gpu-box + HostName 10.0.0.5 + User ec2-user + Port 41920 + IdentityFile "/my/priv/key.pem" + StrictHostKeyChecking no + UserKnownHostsFile /dev/null + ServerAliveInterval 30 + ForwardAgent yes + +` + if cStr != want { + t.Errorf("CreateNewSSHConfig with nodes mismatch:\ngot:\n%s\nwant:\n%s", cStr, want) + } +} + +func TestCreateNewSSHConfig_WorkspacesAndNodes(t *testing.T) { + c := NewSSHConfigurerV2(DummySSHConfigurerV2Store{}) + + nodes := []ExternalNodeSSHEntry{ + {Alias: "my-node", Hostname: "192.168.1.100", Port: 33000, User: "ubuntu"}, + } + + cStr, err := c.CreateNewSSHConfig(somePlainWorkspaces[:1], nodes) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Should contain both workspace entry and node entry + if !assert.Contains(t, cStr, "Host testName1\n") { + return + } + if !assert.Contains(t, cStr, "Host my-node\n") { + return + } + if !assert.Contains(t, cStr, "Port 33000\n") { + return + } +} + func makeMockFS() SSHConfigurerV2Store { bs := store.NewBasicStore().WithEnvGetter( func(s string) string { @@ -775,7 +884,7 @@ Host testName1-host s := SSHConfigurerV2{ store: tt.fields.store, } - if err := s.Update(tt.args.workspaces); (err != nil) != tt.wantErr { + if err := s.Update(tt.args.workspaces, nil); (err != nil) != tt.wantErr { t.Errorf("SSHConfigurerV2.Update() error = %v, wantErr %v", err, tt.wantErr) } // make sure the linux config is correct