Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 54 additions & 22 deletions internal/env/patch.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,16 @@ package env

import (
"bufio"
"fmt"
"os"
"path/filepath"
"strings"

"github.com/joho/godotenv"
)

func Patch(path, key, value string) error {
existing, err := godotenv.Read(path)
if err != nil {
// If file doesn't exist, start fresh
existing = make(map[string]string)
}
existing[key] = value
return godotenv.Write(existing, path)
return PatchAll(path, map[string]string{key: value}, nil)
}

// ScanComments reads path and returns full-line comment lines per mode.
Expand Down Expand Up @@ -65,35 +61,71 @@ func ScanComments(path, mode string) ([]string, error) {
return collected, scanner.Err()
}

// PatchAll reads existing env from path, merges patches, writes once, then appends comments.
// If path does not exist, starts from an empty env. comments may be nil (no-op).
// PatchAll reads existing env from path, merges patches, and writes atomically via a
// temp file + rename. If path does not exist, starts from an empty env.
// comments may be nil (no-op). The original file's permissions are preserved.
func PatchAll(path string, patches map[string]string, comments []string) error {
var originalMode os.FileMode = 0600
if fi, err := os.Stat(path); err == nil {
originalMode = fi.Mode().Perm()
} else if !os.IsNotExist(err) {
return err
}

existing, err := godotenv.Read(path)
if err != nil {
if !os.IsNotExist(err) {
return err
}
existing = make(map[string]string)
}

for k, v := range patches {
existing[k] = v
}
if err := godotenv.Write(existing, path); err != nil {

content, err := godotenv.Marshal(existing)
if err != nil {
return err
}
return appendComments(path, comments)
}

// appendComments appends collected comment lines to path, preceded by a blank line.
// No-op if comments is empty.
func appendComments(path string, comments []string) error {
if len(comments) == 0 {
return nil
dir := filepath.Dir(path)
base := filepath.Base(path)
tmp, err := os.CreateTemp(dir, fmt.Sprintf(".%s.*.tmp", base))
if err != nil {
return err
}
tmpName := tmp.Name()
committed := false
defer func() {
if !committed {
tmp.Close()
os.Remove(tmpName)
}
}()

if _, err := tmp.WriteString(content + "\n"); err != nil {
return err
}
if len(comments) > 0 {
if _, err := tmp.WriteString("\n" + strings.Join(comments, "\n") + "\n"); err != nil {
return err
}
}
if err := tmp.Sync(); err != nil {
return err
}
if err := tmp.Close(); err != nil {
return err
}

f, err := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0)
if err != nil {
if err := os.Chmod(tmpName, originalMode); err != nil {
return err
}
defer f.Close()

_, err = f.WriteString("\n" + strings.Join(comments, "\n") + "\n")
return err
if err := os.Rename(tmpName, path); err != nil {
return err
}
committed = true
return nil
}
73 changes: 73 additions & 0 deletions internal/env/patch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package env

import (
"os"
"path/filepath"
"strings"
"testing"

Expand Down Expand Up @@ -209,3 +210,75 @@ func TestPatchAll_NonExistentFile(t *testing.T) {
t.Errorf("SECRET = %q, want %q", env["SECRET"], "abc123")
}
}

func TestPatchAll_PermissionsPreserved(t *testing.T) {
path := writeTempEnv(t, "KEY=old\n")
if err := os.Chmod(path, 0640); err != nil {
t.Fatal(err)
}

if err := PatchAll(path, map[string]string{"KEY": "new"}, nil); err != nil {
t.Fatalf("PatchAll: %v", err)
}

fi, err := os.Stat(path)
if err != nil {
t.Fatal(err)
}
if fi.Mode().Perm() != 0640 {
t.Errorf("mode = %04o, want 0640", fi.Mode().Perm())
}
}

func TestPatchAll_DefaultPermissionsForNewFile(t *testing.T) {
path := filepath.Join(t.TempDir(), ".env.new")

if err := PatchAll(path, map[string]string{"X": "1"}, nil); err != nil {
t.Fatalf("PatchAll: %v", err)
}

fi, err := os.Stat(path)
if err != nil {
t.Fatal(err)
}
if fi.Mode().Perm() != 0600 {
t.Errorf("mode = %04o, want 0600", fi.Mode().Perm())
}
}

func TestPatchAll_TempFileRemovedOnError(t *testing.T) {
dir := t.TempDir()
// Point at a non-existent subdirectory so CreateTemp fails.
path := filepath.Join(dir, "nonexistent-subdir", ".env")

err := PatchAll(path, map[string]string{"KEY": "val"}, nil)
if err == nil {
t.Fatal("expected error, got nil")
}

entries, _ := os.ReadDir(dir)
for _, e := range entries {
if strings.HasSuffix(e.Name(), ".tmp") {
t.Errorf("stale temp file found: %s", e.Name())
}
}
}

func TestPatch_IsTransactional(t *testing.T) {
path := writeTempEnv(t, "A=1\nB=2\n")

if err := Patch(path, "A", "updated"); err != nil {
t.Fatalf("Patch: %v", err)
}

env, err := godotenv.Read(path)
if err != nil {
t.Fatal(err)
}
if env["A"] != "updated" {
t.Errorf("A = %q, want %q", env["A"], "updated")
}
if env["B"] != "2" {
t.Errorf("B = %q, want %q", env["B"], "2")
}
}