Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions internal/configstore/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,83 @@ func boolPtr(v bool) *bool {
return &b
}

// Merge returns a new Config that combines base with overlay. Overlay values
// take precedence: non-empty scalars replace base values, and maps are merged
// with overlay keys overriding base keys.
func Merge(base, overlay Config) Config {
out := base.Clone()

if strings.TrimSpace(overlay.TargetImage) != "" {
out.TargetImage = overlay.TargetImage
}

for cmd, value := range overlay.CommandVolumes {
if value != nil {
out.CommandVolumes[cmd] = boolPtr(*value)
}
}

for host, spec := range overlay.CustomVolumes {
out.CustomVolumes[host] = spec
}

for key, value := range overlay.EnvVars {
out.EnvVars[key] = value
}

for key, image := range overlay.ProjectTargetImages {
out.ProjectTargetImages[key] = image
}

for projectKey, settings := range overlay.ProjectCommandVolumes {
existing := out.ProjectCommandVolumes[projectKey]
if existing == nil {
existing = make(map[string]*bool)
}
for cmd, value := range settings {
if value != nil {
existing[cmd] = boolPtr(*value)
}
}
out.ProjectCommandVolumes[projectKey] = existing
}

for projectKey, specs := range overlay.ProjectCustomVolumes {
existing := out.ProjectCustomVolumes[projectKey]
if existing == nil {
existing = make(map[string]string)
}
for key, value := range specs {
existing[key] = value
}
out.ProjectCustomVolumes[projectKey] = existing
}

for projectKey, disables := range overlay.ProjectVolumeDisables {
existing := out.ProjectVolumeDisables[projectKey]
if existing == nil {
existing = make(map[string]bool)
}
for key, value := range disables {
existing[key] = value
}
out.ProjectVolumeDisables[projectKey] = existing
}

for projectKey, envs := range overlay.ProjectEnvVars {
existing := out.ProjectEnvVars[projectKey]
if existing == nil {
existing = make(map[string]string)
}
for key, value := range envs {
existing[key] = value
}
out.ProjectEnvVars[projectKey] = existing
}

return out
}

// SetGlobalTargetImage records the default container image for leash-managed sessions.
func (c *Config) SetGlobalTargetImage(image string) {
c.TargetImage = strings.TrimSpace(image)
Expand Down
26 changes: 26 additions & 0 deletions internal/configstore/loadsave.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,32 @@ func decodeConfig(data []byte, path string, cfg *Config) error {
return nil
}

// LoadWithOverlay loads the global XDG configuration and then, if a
// .leash.toml file exists in dir, merges it on top. The local file values
// take precedence over the global config.
func LoadWithOverlay(dir string) (Config, error) {
base, err := Load()
if err != nil {
return base, err
}
if strings.TrimSpace(dir) == "" {
return base, nil
}
localPath := GetLocalConfigPath(dir)
data, err := os.ReadFile(localPath)
if errors.Is(err, os.ErrNotExist) {
return base, nil
}
if err != nil {
return base, fmt.Errorf("read local config %s: %w", localPath, err)
}
overlay := New()
if err := decodeConfig(data, localPath, &overlay); err != nil {
return base, err
}
return Merge(base, overlay), nil
}

// Save atomically writes the configuration to disk.
func Save(cfg Config) error {
cfg.ensureInitialized()
Expand Down
178 changes: 178 additions & 0 deletions internal/configstore/loadsave_overlay_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
package configstore

import (
"os"
"path/filepath"
"testing"
)

// These tests override XDG_CONFIG_HOME and HOME; run serially.

func TestLoadWithOverlayNoLocalFile(t *testing.T) {
testSetEnv(t, "LEASH_HOME", "")
base := t.TempDir()
testSetEnv(t, "XDG_CONFIG_HOME", base)
setHome(t, filepath.Join(base, "home"))

// Write a global config with a target image.
cfgDir := filepath.Join(base, "leash")
if err := os.MkdirAll(cfgDir, 0o700); err != nil {
t.Fatal(err)
}
globalTOML := `[leash]
target_image = "global-image"

[leash.envvars]
GH_CONFIG_DIR = "/root/.config/gh"
`
if err := os.WriteFile(filepath.Join(cfgDir, configFileName), []byte(globalTOML), 0o600); err != nil {
t.Fatal(err)
}

// Load with a directory that has no .leash.toml.
projectDir := t.TempDir()
cfg, err := LoadWithOverlay(projectDir)
if err != nil {
t.Fatalf("LoadWithOverlay: %v", err)
}

if cfg.TargetImage != "global-image" {
t.Fatalf("TargetImage = %q, want %q", cfg.TargetImage, "global-image")
}
if cfg.EnvVars["GH_CONFIG_DIR"] != "/root/.config/gh" {
t.Fatalf("GH_CONFIG_DIR = %q, want %q", cfg.EnvVars["GH_CONFIG_DIR"], "/root/.config/gh")
}
}

func TestLoadWithOverlayMergesLocalFile(t *testing.T) {
testSetEnv(t, "LEASH_HOME", "")
base := t.TempDir()
testSetEnv(t, "XDG_CONFIG_HOME", base)
setHome(t, filepath.Join(base, "home"))

// Write global config.
cfgDir := filepath.Join(base, "leash")
if err := os.MkdirAll(cfgDir, 0o700); err != nil {
t.Fatal(err)
}
globalTOML := `[leash]
target_image = "global-image"

[leash.envvars]
GH_CONFIG_DIR = "/root/.config/gh"
SHARED_KEY = "global"
`
if err := os.WriteFile(filepath.Join(cfgDir, configFileName), []byte(globalTOML), 0o600); err != nil {
t.Fatal(err)
}

// Write local override.
projectDir := t.TempDir()
localTOML := `[leash.envvars]
ATLASSIAN_USER = "secret-user"
SHARED_KEY = "local-override"
`
if err := os.WriteFile(filepath.Join(projectDir, LocalConfigFileName), []byte(localTOML), 0o600); err != nil {
t.Fatal(err)
}

cfg, err := LoadWithOverlay(projectDir)
if err != nil {
t.Fatalf("LoadWithOverlay: %v", err)
}

// Global values preserved.
if cfg.TargetImage != "global-image" {
t.Fatalf("TargetImage = %q, want %q", cfg.TargetImage, "global-image")
}
if cfg.EnvVars["GH_CONFIG_DIR"] != "/root/.config/gh" {
t.Fatalf("GH_CONFIG_DIR = %q, want %q", cfg.EnvVars["GH_CONFIG_DIR"], "/root/.config/gh")
}

// Local values added.
if cfg.EnvVars["ATLASSIAN_USER"] != "secret-user" {
t.Fatalf("ATLASSIAN_USER = %q, want %q", cfg.EnvVars["ATLASSIAN_USER"], "secret-user")
}

// Local overrides global.
if cfg.EnvVars["SHARED_KEY"] != "local-override" {
t.Fatalf("SHARED_KEY = %q, want %q", cfg.EnvVars["SHARED_KEY"], "local-override")
}
}

func TestLoadWithOverlayLocalOverridesTargetImage(t *testing.T) {
testSetEnv(t, "LEASH_HOME", "")
base := t.TempDir()
testSetEnv(t, "XDG_CONFIG_HOME", base)
setHome(t, filepath.Join(base, "home"))

cfgDir := filepath.Join(base, "leash")
if err := os.MkdirAll(cfgDir, 0o700); err != nil {
t.Fatal(err)
}
globalTOML := `[leash]
target_image = "global-image"
`
if err := os.WriteFile(filepath.Join(cfgDir, configFileName), []byte(globalTOML), 0o600); err != nil {
t.Fatal(err)
}

projectDir := t.TempDir()
localTOML := `[leash]
target_image = "local-image"
`
if err := os.WriteFile(filepath.Join(projectDir, LocalConfigFileName), []byte(localTOML), 0o600); err != nil {
t.Fatal(err)
}

cfg, err := LoadWithOverlay(projectDir)
if err != nil {
t.Fatalf("LoadWithOverlay: %v", err)
}

if cfg.TargetImage != "local-image" {
t.Fatalf("TargetImage = %q, want %q", cfg.TargetImage, "local-image")
}
}

func TestLoadWithOverlayEmptyDirSkipsLocal(t *testing.T) {
testSetEnv(t, "LEASH_HOME", "")
base := t.TempDir()
testSetEnv(t, "XDG_CONFIG_HOME", base)
setHome(t, filepath.Join(base, "home"))

cfg, err := LoadWithOverlay("")
if err != nil {
t.Fatalf("LoadWithOverlay: %v", err)
}
// Should just return defaults with no error.
if cfg.TargetImage != "" {
t.Fatalf("expected empty TargetImage, got %q", cfg.TargetImage)
}
}

func TestLoadWithOverlayBadLocalTOMLReturnsError(t *testing.T) {
testSetEnv(t, "LEASH_HOME", "")
base := t.TempDir()
testSetEnv(t, "XDG_CONFIG_HOME", base)
setHome(t, filepath.Join(base, "home"))

projectDir := t.TempDir()
if err := os.WriteFile(filepath.Join(projectDir, LocalConfigFileName), []byte("not valid {{{ toml"), 0o600); err != nil {
t.Fatal(err)
}

_, err := LoadWithOverlay(projectDir)
if err == nil {
t.Fatal("expected error for invalid local TOML")
}
}

func TestGetLocalConfigPath(t *testing.T) {
t.Parallel()
got := GetLocalConfigPath("/some/project")
want := filepath.Join("/some/project", LocalConfigFileName)
if got != want {
t.Fatalf("GetLocalConfigPath = %q, want %q", got, want)
}
}
Loading
Loading