diff --git a/internal/env/patch.go b/internal/env/patch.go index 7ab5ba1..cb986ff 100644 --- a/internal/env/patch.go +++ b/internal/env/patch.go @@ -2,20 +2,16 @@ package env import ( "bufio" + "fmt" "os" + "path/filepath" "strings" "github.com/joho/godotenv" ) func Patch(path, key, value string) error { - existing, err := godotenv.Read(path) - if err != nil { - // If file doesn't exist, start fresh - existing = make(map[string]string) - } - existing[key] = value - return godotenv.Write(existing, path) + return PatchAll(path, map[string]string{key: value}, nil) } // ScanComments reads path and returns full-line comment lines per mode. @@ -65,35 +61,71 @@ func ScanComments(path, mode string) ([]string, error) { return collected, scanner.Err() } -// PatchAll reads existing env from path, merges patches, writes once, then appends comments. -// If path does not exist, starts from an empty env. comments may be nil (no-op). +// PatchAll reads existing env from path, merges patches, and writes atomically via a +// temp file + rename. If path does not exist, starts from an empty env. +// comments may be nil (no-op). The original file's permissions are preserved. func PatchAll(path string, patches map[string]string, comments []string) error { + var originalMode os.FileMode = 0600 + if fi, err := os.Stat(path); err == nil { + originalMode = fi.Mode().Perm() + } else if !os.IsNotExist(err) { + return err + } + existing, err := godotenv.Read(path) if err != nil { + if !os.IsNotExist(err) { + return err + } existing = make(map[string]string) } + for k, v := range patches { existing[k] = v } - if err := godotenv.Write(existing, path); err != nil { + + content, err := godotenv.Marshal(existing) + if err != nil { return err } - return appendComments(path, comments) -} -// appendComments appends collected comment lines to path, preceded by a blank line. -// No-op if comments is empty. -func appendComments(path string, comments []string) error { - if len(comments) == 0 { - return nil + dir := filepath.Dir(path) + base := filepath.Base(path) + tmp, err := os.CreateTemp(dir, fmt.Sprintf(".%s.*.tmp", base)) + if err != nil { + return err + } + tmpName := tmp.Name() + committed := false + defer func() { + if !committed { + tmp.Close() + os.Remove(tmpName) + } + }() + + if _, err := tmp.WriteString(content + "\n"); err != nil { + return err + } + if len(comments) > 0 { + if _, err := tmp.WriteString("\n" + strings.Join(comments, "\n") + "\n"); err != nil { + return err + } + } + if err := tmp.Sync(); err != nil { + return err + } + if err := tmp.Close(); err != nil { + return err } - f, err := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0) - if err != nil { + if err := os.Chmod(tmpName, originalMode); err != nil { return err } - defer f.Close() - _, err = f.WriteString("\n" + strings.Join(comments, "\n") + "\n") - return err + if err := os.Rename(tmpName, path); err != nil { + return err + } + committed = true + return nil } diff --git a/internal/env/patch_test.go b/internal/env/patch_test.go index 942f5a4..df9eafe 100644 --- a/internal/env/patch_test.go +++ b/internal/env/patch_test.go @@ -2,6 +2,7 @@ package env import ( "os" + "path/filepath" "strings" "testing" @@ -209,3 +210,75 @@ func TestPatchAll_NonExistentFile(t *testing.T) { t.Errorf("SECRET = %q, want %q", env["SECRET"], "abc123") } } + +func TestPatchAll_PermissionsPreserved(t *testing.T) { + path := writeTempEnv(t, "KEY=old\n") + if err := os.Chmod(path, 0640); err != nil { + t.Fatal(err) + } + + if err := PatchAll(path, map[string]string{"KEY": "new"}, nil); err != nil { + t.Fatalf("PatchAll: %v", err) + } + + fi, err := os.Stat(path) + if err != nil { + t.Fatal(err) + } + if fi.Mode().Perm() != 0640 { + t.Errorf("mode = %04o, want 0640", fi.Mode().Perm()) + } +} + +func TestPatchAll_DefaultPermissionsForNewFile(t *testing.T) { + path := filepath.Join(t.TempDir(), ".env.new") + + if err := PatchAll(path, map[string]string{"X": "1"}, nil); err != nil { + t.Fatalf("PatchAll: %v", err) + } + + fi, err := os.Stat(path) + if err != nil { + t.Fatal(err) + } + if fi.Mode().Perm() != 0600 { + t.Errorf("mode = %04o, want 0600", fi.Mode().Perm()) + } +} + +func TestPatchAll_TempFileRemovedOnError(t *testing.T) { + dir := t.TempDir() + // Point at a non-existent subdirectory so CreateTemp fails. + path := filepath.Join(dir, "nonexistent-subdir", ".env") + + err := PatchAll(path, map[string]string{"KEY": "val"}, nil) + if err == nil { + t.Fatal("expected error, got nil") + } + + entries, _ := os.ReadDir(dir) + for _, e := range entries { + if strings.HasSuffix(e.Name(), ".tmp") { + t.Errorf("stale temp file found: %s", e.Name()) + } + } +} + +func TestPatch_IsTransactional(t *testing.T) { + path := writeTempEnv(t, "A=1\nB=2\n") + + if err := Patch(path, "A", "updated"); err != nil { + t.Fatalf("Patch: %v", err) + } + + env, err := godotenv.Read(path) + if err != nil { + t.Fatal(err) + } + if env["A"] != "updated" { + t.Errorf("A = %q, want %q", env["A"], "updated") + } + if env["B"] != "2" { + t.Errorf("B = %q, want %q", env["B"], "2") + } +}