diff --git a/internal/llm/claude_stream_driver.go b/internal/llm/claude_stream_driver.go index 9817a09..220dde4 100644 --- a/internal/llm/claude_stream_driver.go +++ b/internal/llm/claude_stream_driver.go @@ -106,8 +106,7 @@ func (d *claudeStreamDriver) Close() error { _ = stdin.Close() } if cmd != nil && cmd.Process != nil { - _ = cmd.Process.Kill() - err = cmd.Wait() + err = shared.WaitOrKill(cmd) } close(d.events) }) @@ -159,8 +158,7 @@ func (d *claudeStreamDriver) ensureStarted(ctx context.Context, req RunRequest) if d.closed { d.mu.Unlock() _ = stdin.Close() - _ = cmd.Process.Kill() - _ = cmd.Wait() + _ = shared.WaitOrKill(cmd) return ErrInteractiveClosed } d.cmd = cmd diff --git a/internal/llm/internal/shared/shared.go b/internal/llm/internal/shared/shared.go index effc575..03efb78 100644 --- a/internal/llm/internal/shared/shared.go +++ b/internal/llm/internal/shared/shared.go @@ -1,7 +1,11 @@ -// Package shared provides scanner buffer constants used across LLM providers. +// Package shared provides utilities and constants used across LLM providers. package shared -import "time" +import ( + "os/exec" + "syscall" + "time" +) // Default scanner buffer and token size constants used across providers. const ( @@ -12,3 +16,37 @@ const ( // AuthCheckTimeout is the default timeout for CLI login/auth status checks. const AuthCheckTimeout = 15 * time.Second + +var ( + // SubprocessGracePeriod is how long to wait for a subprocess to exit + // naturally after its stdin is closed before escalating to SIGTERM. + SubprocessGracePeriod = 10 * time.Second + // SubprocessTermGracePeriod is how long to wait after SIGTERM before + // escalating to SIGKILL. + SubprocessTermGracePeriod = 10 * time.Second +) + +// WaitOrKill closes the subprocess gracefully: wait for natural exit after +// stdin close, then SIGTERM, then SIGKILL as last resort. +func WaitOrKill(cmd *exec.Cmd) error { + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + }() + select { + case err := <-done: + return err + case <-time.After(SubprocessGracePeriod): + } + if err := cmd.Process.Signal(syscall.SIGTERM); err != nil { + _ = cmd.Process.Kill() + return <-done + } + select { + case err := <-done: + return err + case <-time.After(SubprocessTermGracePeriod): + } + _ = cmd.Process.Kill() + return <-done +} diff --git a/internal/llm/internal/shared/shared_test.go b/internal/llm/internal/shared/shared_test.go new file mode 100644 index 0000000..e0059c4 --- /dev/null +++ b/internal/llm/internal/shared/shared_test.go @@ -0,0 +1,130 @@ +package shared + +import ( + "os/exec" + "syscall" + "testing" + "time" +) + +func TestWaitOrKill_NormalExit(t *testing.T) { + cmd := exec.Command("true") + if err := cmd.Start(); err != nil { + t.Fatalf("start: %v", err) + } + if err := WaitOrKill(cmd); err != nil { + t.Fatalf("expected nil for exit 0, got %v", err) + } +} + +func TestWaitOrKill_NonZeroExit(t *testing.T) { + cmd := exec.Command("sh", "-c", "exit 42") + if err := cmd.Start(); err != nil { + t.Fatalf("start: %v", err) + } + err := WaitOrKill(cmd) + if err == nil { + t.Fatal("expected non-nil error for exit 42") + } + exitErr, ok := err.(*exec.ExitError) + if !ok { + t.Fatalf("expected *exec.ExitError, got %T: %v", err, err) + } + if status := exitErr.Sys().(syscall.WaitStatus); status.ExitStatus() != 42 { + t.Fatalf("expected exit code 42, got %d", status.ExitStatus()) + } +} + +func TestWaitOrKill_AlreadyExited(t *testing.T) { + cmd := exec.Command("true") + if err := cmd.Run(); err != nil { + t.Fatalf("run: %v", err) + } + // WaitOrKill on already-waited process should return an error + // because cmd.Wait() can only be called once. + err := WaitOrKill(cmd) + if err == nil { + t.Fatal("expected error for already-waited process") + } +} + +func TestWaitOrKill_SIGTERM_Escalation(t *testing.T) { + // Restore original durations after test. + oldGrace := SubprocessGracePeriod + oldTerm := SubprocessTermGracePeriod + SubprocessGracePeriod = 100 * time.Millisecond + SubprocessTermGracePeriod = 100 * time.Millisecond + defer func() { + SubprocessGracePeriod = oldGrace + SubprocessTermGracePeriod = oldTerm + }() + + // Start a process that ignores SIGTERM. + // We use a shell script that traps SIGTERM and sleeps. + cmd := exec.Command("sh", "-c", "trap '' TERM; sleep 60") + if err := cmd.Start(); err != nil { + t.Fatalf("start: %v", err) + } + + start := time.Now() + err := WaitOrKill(cmd) + elapsed := time.Since(start) + + if err == nil { + t.Fatal("expected non-nil error from killed process") + } + + // Should have been killed after both grace periods expire + // (wait 100ms + SIGTERM 100ms ≈ 200ms, then SIGKILL). + // Allow some slack for scheduling. + if elapsed < 150*time.Millisecond { + t.Fatalf("expected at least ~200ms before kill, took %v", elapsed) + } + if elapsed > 2*time.Second { + t.Fatalf("kill took too long: %v", elapsed) + } + + exitErr, ok := err.(*exec.ExitError) + if !ok { + t.Fatalf("expected *exec.ExitError, got %T: %v", err, err) + } + status := exitErr.Sys().(syscall.WaitStatus) + if !status.Signaled() { + t.Fatalf("expected process killed by signal, got status %v", status) + } +} + +func TestWaitOrKill_SIGTERM_Accepted(t *testing.T) { + oldGrace := SubprocessGracePeriod + oldTerm := SubprocessTermGracePeriod + SubprocessGracePeriod = 100 * time.Millisecond + SubprocessTermGracePeriod = 100 * time.Millisecond + defer func() { + SubprocessGracePeriod = oldGrace + SubprocessTermGracePeriod = oldTerm + }() + + // Start a process that accepts SIGTERM (the default). + cmd := exec.Command("sleep", "60") + if err := cmd.Start(); err != nil { + t.Fatalf("start: %v", err) + } + + err := WaitOrKill(cmd) + if err == nil { + t.Fatal("expected non-nil error from signaled process") + } + + exitErr, ok := err.(*exec.ExitError) + if !ok { + t.Fatalf("expected *exec.ExitError, got %T: %v", err, err) + } + status := exitErr.Sys().(syscall.WaitStatus) + if !status.Signaled() { + t.Fatalf("expected process killed by signal, got status %v", status) + } + // SIGTERM (15) is the expected signal. + if status.Signal() != syscall.SIGTERM { + t.Fatalf("expected SIGTERM (15), got signal %d", status.Signal()) + } +} diff --git a/internal/llm/jsonrpc_line_client.go b/internal/llm/jsonrpc_line_client.go index c6a97aa..0749e4f 100644 --- a/internal/llm/jsonrpc_line_client.go +++ b/internal/llm/jsonrpc_line_client.go @@ -176,9 +176,9 @@ func (c *lineRPCClient) Close() error { } _ = c.stdin.Close() if c.cmd.Process != nil { - _ = c.cmd.Process.Kill() + return shared.WaitOrKill(c.cmd) } - return c.cmd.Wait() + return nil } func (c *lineRPCClient) write(payload any) error { diff --git a/internal/llm/opencode_appserver_driver.go b/internal/llm/opencode_appserver_driver.go index f03cf94..20113dd 100644 --- a/internal/llm/opencode_appserver_driver.go +++ b/internal/llm/opencode_appserver_driver.go @@ -117,8 +117,7 @@ func (d *openCodeAppServerDriver) Close() error { d.eventCancel = nil d.mu.Unlock() if cmd != nil && cmd.Process != nil { - _ = cmd.Process.Kill() - err = cmd.Wait() + err = shared.WaitOrKill(cmd) } if cancelEvents != nil { cancelEvents() @@ -176,15 +175,13 @@ func (d *openCodeAppServerDriver) ensureServer(ctx context.Context, req RunReque select { case serverURL := <-urlCh: if serverURL == "" { - _ = cmd.Process.Kill() - _ = cmd.Wait() + _ = shared.WaitOrKill(cmd) return fmt.Errorf("opencode serve exited before reporting URL: %s", strings.TrimSpace(d.stderr.String())) } d.mu.Lock() if d.closed { d.mu.Unlock() - _ = cmd.Process.Kill() - _ = cmd.Wait() + _ = shared.WaitOrKill(cmd) return ErrInteractiveClosed } d.cmd = cmd @@ -193,8 +190,7 @@ func (d *openCodeAppServerDriver) ensureServer(ctx context.Context, req RunReque d.ensureEventStream() return nil case <-ctx.Done(): - _ = cmd.Process.Kill() - _ = cmd.Wait() + _ = shared.WaitOrKill(cmd) return ctx.Err() } } @@ -698,7 +694,6 @@ func (d *openCodeAppServerDriver) hasOpenCodeAssistantText() bool { func (d *openCodeAppServerDriver) resetServerForNextRequest() { d.mu.Lock() - defer d.mu.Unlock() d.baseURL = "" d.sessionID = "" d.activeID = "" @@ -709,11 +704,13 @@ func (d *openCodeAppServerDriver) resetServerForNextRequest() { d.eventCancel() d.eventCancel = nil } - if d.cmd != nil && d.cmd.Process != nil { - _ = d.cmd.Process.Kill() - _ = d.cmd.Wait() - } + cmd := d.cmd d.cmd = nil + d.mu.Unlock() + + if cmd != nil && cmd.Process != nil { + _ = shared.WaitOrKill(cmd) + } } func (d *openCodeAppServerDriver) emit(event TurnEvent) {