Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cli/cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cli
import (
"fmt"
"github.com/arran4/gocdm/auth"
"github.com/arran4/gocdm/session"
"github.com/arran4/gocdm/x11"
"io"
"os"
Expand Down Expand Up @@ -102,6 +103,12 @@ func TestRunDryRunNoSessions(t *testing.T) {
})
t.Setenv("HOME", tmpHome)
t.Setenv("XDG_CONFIG_HOME", tmpHome)

origShellsFile := session.ShellsFile
session.ShellsFile = tmpHome + "/empty-shells-file-doesnt-exist"
t.Cleanup(func() {
session.ShellsFile = origShellsFile
})
exited := false
code := -1
mockExit := func(c int) {
Expand Down
37 changes: 37 additions & 0 deletions session/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ var (
X11SessionsDir = "/etc/X11/Sessions"
XSessionsDir = "/usr/share/xsessions"
WaylandSessionsDir = "/usr/share/wayland-sessions"
ShellsFile = "/etc/shells"
)

type Session struct {
Expand Down Expand Up @@ -121,6 +122,14 @@ func (d *Discoverer) Discover(userHome string) ([]Session, error) {
}
}

// Try /etc/shells for console logins
shellSessions, err := d.discoverShellSessions()
if err == nil {
for _, s := range shellSessions {
addSession(s)
}
}

// Sort sessions by name
sort.SliceStable(sessions, func(i, j int) bool {
return sessions[i].Name < sessions[j].Name
Expand Down Expand Up @@ -214,6 +223,34 @@ func (d *Discoverer) discoverWaylandSessions() ([]Session, error) {
return d.discoverCustomSessions(WaylandSessionsDir, "W")
}

func (d *Discoverer) discoverShellSessions() ([]Session, error) {
file, err := os.Open(ShellsFile)
if err != nil {
return nil, err
}
defer file.Close()

var sessions []Session
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}

if _, err := d.ExecLookPath(line); err == nil {
name := filepath.Base(line) + " shell"
sessions = append(sessions, Session{
Name: name,
Exec: line,
Type: "C",
Path: ShellsFile,
})
}
}
Comment on lines +235 to +250
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The bufio.Scanner loop finishes without checking for errors. It is a best practice to check scanner.Err() after the loop to ensure that the file was read completely and no I/O errors occurred during scanning.

	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}

		if _, err := d.ExecLookPath(line); err == nil {
			name := filepath.Base(line) + " shell"
			sessions = append(sessions, Session{
				Name: name,
				Exec: line,
				Type: "C",
				Path: ShellsFile,
			})
		}
	}

	if err := scanner.Err(); err != nil {
		return nil, err
	}

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The bufio.Scanner loop finishes without checking for errors. It is a best practice to check scanner.Err() after the loop to ensure that the file was read completely and no I/O errors occurred during scanning.

	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}

		if _, err := d.ExecLookPath(line); err == nil {
			name := filepath.Base(line) + " shell"
			sessions = append(sessions, Session{
				Name: name,
				Exec: line,
				Type: "C",
				Path: ShellsFile,
			})
		}
	}

	if err := scanner.Err(); err != nil {
		return nil, err
	}

@jules

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I've added the error check using scanner.Err() after the bufio.Scanner loop in discoverShellSessions to catch any I/O errors that might occur during the file reading.

return sessions, nil
}

// stripFreedesktopExecVariables removes Freedesktop Exec field codes from the command line.
// Some window managers or display managers specify %f, %u, etc., in their .desktop files.
// For a display manager, we generally want to remove these or ignore them.
Expand Down
19 changes: 17 additions & 2 deletions session/discovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ func TestDiscoverSessions(t *testing.T) {
waylandSessionsDir := filepath.Join(tmpDir, "usr", "share", "wayland-sessions")
userHome := filepath.Join(tmpDir, "home", "user")
userConfigWayland := filepath.Join(userHome, ".config", "wayland-sessions")
etcDir := filepath.Join(tmpDir, "etc")

dirs := []string{x11Dir, xsessionsDir, waylandSessionsDir, userHome, userConfigWayland}
dirs := []string{x11Dir, xsessionsDir, waylandSessionsDir, userHome, userConfigWayland, etcDir}
for _, d := range dirs {
if err := os.MkdirAll(d, 0755); err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -72,19 +73,31 @@ func TestDiscoverSessions(t *testing.T) {
t.Fatal(err)
}

// 6. Shells file
shellsContent := "# /etc/shells\n" +
"/not/a/real/path/to/shell\n" + // will not be found by exec.LookPath
testCommand + "\n"
shellsPath := filepath.Join(etcDir, "shells")
if err := os.WriteFile(shellsPath, []byte(shellsContent), 0644); err != nil {
t.Fatal(err)
}

// Save original vars
origX11 := X11SessionsDir
origXSessions := XSessionsDir
origWaylandSessions := WaylandSessionsDir
origShellsFile := ShellsFile
defer func() {
X11SessionsDir = origX11
XSessionsDir = origXSessions
WaylandSessionsDir = origWaylandSessions
ShellsFile = origShellsFile
}()

X11SessionsDir = x11Dir
XSessionsDir = xsessionsDir
WaylandSessionsDir = waylandSessionsDir
ShellsFile = shellsPath

// Test Discovery
sessions, err := DiscoverSessions(userHome)
Expand All @@ -98,8 +111,9 @@ func TestDiscoverSessions(t *testing.T) {
// 3. legacy_x11 (Type X)
// 4. Standard X Session (Type X)
// 5. Wayland Session (Type W)
// 6. Shell Session (Type C)

expectedCount := 5
expectedCount := 6
if len(sessions) != expectedCount {
t.Errorf("Expected %d sessions, got %d", expectedCount, len(sessions))
for _, s := range sessions {
Expand All @@ -119,6 +133,7 @@ func TestDiscoverSessions(t *testing.T) {
"legacy_x11": "X",
"Standard X Session": "X",
"Wayland Session": "W",
filepath.Base(testCommand) + " shell": "C",
}

for name, typ := range expectations {
Expand Down
Loading