diff --git a/cmd/server/server.go b/cmd/server/server.go index 6a7fa7f0..561877e8 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -8,6 +8,7 @@ import ( "log/slog" "net/http" "os" + "path/filepath" "sort" "strings" @@ -103,6 +104,35 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er } } + // Get the variables related to state management + stateFile := viper.GetString(StateFile) + loadState := true + saveState := true + if stateFile != "" { + if !viper.IsSet(LoadState) { + loadState = true + } else { + loadState = viper.GetBool(LoadState) + } + + if !viper.IsSet(SaveState) { + saveState = true + } else { + saveState = viper.GetBool(SaveState) + } + } + + pidFile := viper.GetString(PidFile) + + // Write PID file if configured + if pidFile != "" { + if err := writePIDFile(pidFile, logger); err != nil { + return xerrors.Errorf("failed to write PID file: %w", err) + } + // Ensure PID file is cleaned up on exit + defer cleanupPIDFile(pidFile, logger) + } + printOpenAPI := viper.GetBool(FlagPrintOpenAPI) var process *termexec.Process if printOpenAPI { @@ -128,7 +158,13 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er AllowedHosts: viper.GetStringSlice(FlagAllowedHosts), AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins), InitialPrompt: initialPrompt, + StatePersistenceConfig: httpapi.StatePersistenceConfig{ + StateFile: stateFile, + LoadState: loadState, + SaveState: saveState, + }, }) + if err != nil { return xerrors.Errorf("failed to create server: %w", err) } @@ -137,6 +173,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er return nil } srv.StartSnapshotLoop(ctx) + srv.HandleSignals(ctx, process) logger.Info("Starting server on port", "port", port) processExitCh := make(chan error, 1) go func() { @@ -152,7 +189,7 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er logger.Error("Failed to stop server", "error", err) 
} }() - if err := srv.Start(); err != nil && err != context.Canceled && err != http.ErrServerClosed { + if err := srv.Start(); err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, http.ErrServerClosed) { return xerrors.Errorf("failed to start server: %w", err) } select { @@ -172,6 +209,35 @@ var agentNames = (func() []string { return names })() +// writePIDFile writes the current process ID to the specified file +func writePIDFile(pidFile string, logger *slog.Logger) error { + pid := os.Getpid() + pidContent := fmt.Sprintf("%d\n", pid) + + // Create directory if it doesn't exist + dir := filepath.Dir(pidFile) + if err := os.MkdirAll(dir, 0o755); err != nil { + return xerrors.Errorf("failed to create PID file directory: %w", err) + } + + // Write PID file + if err := os.WriteFile(pidFile, []byte(pidContent), 0o644); err != nil { + return xerrors.Errorf("failed to write PID file: %w", err) + } + + logger.Info("Wrote PID file", "pidFile", pidFile, "pid", pid) + return nil +} + +// cleanupPIDFile removes the PID file if it exists +func cleanupPIDFile(pidFile string, logger *slog.Logger) { + if err := os.Remove(pidFile); err != nil && !os.IsNotExist(err) { + logger.Error("Failed to remove PID file", "pidFile", pidFile, "error", err) + } else if err == nil { + logger.Info("Removed PID file", "pidFile", pidFile) + } +} + type flagSpec struct { name string shorthand string @@ -191,6 +257,10 @@ const ( FlagAllowedOrigins = "allowed-origins" FlagExit = "exit" FlagInitialPrompt = "initial-prompt" + StateFile = "state-file" + LoadState = "load-state" + SaveState = "save-state" + PidFile = "pid-file" ) func CreateServerCmd() *cobra.Command { @@ -229,6 +299,10 @@ func CreateServerCmd() *cobra.Command { // localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development. 
{FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"}, {FlagInitialPrompt, "I", "", "Initial prompt for the agent. Recommended only if the agent doesn't support initial prompt in interaction mode. Will be read from stdin if piped (e.g., echo 'prompt' | agentapi server -- my-agent)", "string"}, + {StateFile, "s", "", "Path to file for saving/loading server state", "string"}, + {LoadState, "", false, "Load state from state-file on startup (defaults to true when state-file is set)", "bool"}, + {SaveState, "", false, "Save state to state-file on shutdown (defaults to true when state-file is set)", "bool"}, + {PidFile, "", "", "Path to file where the server process ID will be written for shutdown scripts", "string"}, } for _, spec := range flagSpecs { diff --git a/lib/httpapi/events.go b/lib/httpapi/events.go index 73eff07b..e8cabab6 100644 --- a/lib/httpapi/events.go +++ b/lib/httpapi/events.go @@ -120,7 +120,7 @@ func (e *EventEmitter) notifyChannels(eventType EventType, payload any) { } } -// Assumes that only the last message can change or new messages can be added. +// UpdateMessagesAndEmitChanges assumes that only the last message can change or new messages can be added. // If a new message is injected between existing messages (identified by Id), the behavior is undefined. 
func (e *EventEmitter) UpdateMessagesAndEmitChanges(newMessages []st.ConversationMessage) { e.mu.Lock() diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index c94abafb..d234bad0 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -34,18 +34,20 @@ import ( // Server represents the HTTP server type Server struct { - router chi.Router - api huma.API - port int - srv *http.Server - mu sync.RWMutex - logger *slog.Logger - conversation *st.Conversation - agentio *termexec.Process - agentType mf.AgentType - emitter *EventEmitter - chatBasePath string - tempDir string + router chi.Router + api huma.API + port int + srv *http.Server + mu sync.RWMutex + logger *slog.Logger + conversation *st.PTYConversation + agentio *termexec.Process + agentType mf.AgentType + emitter *EventEmitter + chatBasePath string + tempDir string + statePersistenceConfig StatePersistenceConfig + stateLoadComplete bool } func (s *Server) NormalizeSchema(schema any) any { @@ -94,14 +96,21 @@ func (s *Server) GetOpenAPI() string { // because the action of taking a snapshot takes time too. const snapshotInterval = 25 * time.Millisecond +type StatePersistenceConfig struct { + StateFile string + LoadState bool + SaveState bool +} + type ServerConfig struct { - AgentType mf.AgentType - Process *termexec.Process - Port int - ChatBasePath string - AllowedHosts []string - AllowedOrigins []string - InitialPrompt string + AgentType mf.AgentType + Process *termexec.Process + Port int + ChatBasePath string + AllowedHosts []string + AllowedOrigins []string + InitialPrompt string + StatePersistenceConfig StatePersistenceConfig } // Validate allowed hosts don't contain whitespace, commas, schemes, or ports. 
@@ -237,7 +246,7 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { return mf.FormatToolCall(config.AgentType, message) } - conversation := st.NewConversation(ctx, st.ConversationConfig{ + conversation := st.NewPTY(ctx, st.PTYConversationConfig{ AgentType: config.AgentType, AgentIO: config.Process, GetTime: func() time.Time { @@ -260,16 +269,18 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { logger.Info("Created temporary directory for uploads", "tempDir", tempDir) s := &Server{ - router: router, - api: api, - port: config.Port, - conversation: conversation, - logger: logger, - agentio: config.Process, - agentType: config.AgentType, - emitter: emitter, - chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"), - tempDir: tempDir, + router: router, + api: api, + port: config.Port, + conversation: conversation, + logger: logger, + agentio: config.Process, + agentType: config.AgentType, + emitter: emitter, + chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"), + tempDir: tempDir, + statePersistenceConfig: config.StatePersistenceConfig, + stateLoadComplete: false, } // Register API routes @@ -331,26 +342,32 @@ func sseMiddleware(ctx huma.Context, next func(huma.Context)) { } func (s *Server) StartSnapshotLoop(ctx context.Context) { - s.conversation.StartSnapshotLoop(ctx) + s.conversation.Start(ctx) go func() { for { currentStatus := s.conversation.Status() - // Send initial prompt when agent becomes stable for the first time - if !s.conversation.InitialPromptSent && convertStatus(currentStatus) == AgentStatusStable { + // Send initial prompt & load state when agent becomes stable for the first time + if convertStatus(currentStatus) == AgentStatusStable { - if err := s.conversation.SendMessage(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil { - s.logger.Error("Failed to send initial prompt", "error", err) - } else { - s.conversation.InitialPromptSent = true - 
s.conversation.ReadyForInitialPrompt = false - currentStatus = st.ConversationStatusChanging - s.logger.Info("Initial prompt sent successfully") + if !s.stateLoadComplete && s.statePersistenceConfig.LoadState { + _, _ = s.conversation.LoadState(s.statePersistenceConfig.StateFile) + s.stateLoadComplete = true + } + if !s.conversation.InitialPromptSent { + if err := s.conversation.Send(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil { + s.logger.Error("Failed to send initial prompt", "error", err) + } else { + s.conversation.InitialPromptSent = true + s.conversation.ReadyForInitialPrompt = false + currentStatus = st.ConversationStatusChanging + s.logger.Info("Initial prompt sent successfully") + } } } s.emitter.UpdateStatusAndEmitChanges(currentStatus, s.agentType) s.emitter.UpdateMessagesAndEmitChanges(s.conversation.Messages()) - s.emitter.UpdateScreenAndEmitChanges(s.conversation.Screen()) + s.emitter.UpdateScreenAndEmitChanges(s.conversation.String()) time.Sleep(snapshotInterval) } }() @@ -449,7 +466,7 @@ func (s *Server) createMessage(ctx context.Context, input *MessageRequest) (*Mes switch input.Body.Type { case MessageTypeUser: - if err := s.conversation.SendMessage(FormatMessage(s.agentType, input.Body.Content)...); err != nil { + if err := s.conversation.Send(FormatMessage(s.agentType, input.Body.Content)...); err != nil { return nil, xerrors.Errorf("failed to send message: %w", err) } case MessageTypeRaw: @@ -610,6 +627,30 @@ func (s *Server) cleanupTempDir() { } } +// saveAndCleanup saves the conversation state and cleans up before shutdown +func (s *Server) saveAndCleanup(sig os.Signal, process *termexec.Process) { + // Save conversation state if configured (synchronously before closing process) + s.saveStateIfConfigured(sig.String()) + + // Now close the process + if err := process.Close(s.logger, 5*time.Second); err != nil { + s.logger.Error("Error closing process", "signal", sig, "error", err) + } +} + +// saveStateIfConfigured 
saves the conversation state if configured +func (s *Server) saveStateIfConfigured(source string) { + if s.statePersistenceConfig.SaveState && s.statePersistenceConfig.StateFile != "" { + if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceConfig.StateFile); err != nil { + s.logger.Error("Failed to save conversation state", "source", source, "error", err) + } else { + s.logger.Info("Saved conversation state", "source", source, "stateFile", s.statePersistenceConfig.StateFile) + } + } else { + s.logger.Warn("Save requested but state saving is not configured", "source", source) + } +} + // registerStaticFileRoutes sets up routes for serving static files func (s *Server) registerStaticFileRoutes() { chatHandler := FileServerWithIndexFallback(s.chatBasePath) diff --git a/lib/httpapi/server_signals_unix.go b/lib/httpapi/server_signals_unix.go new file mode 100644 index 00000000..837db86c --- /dev/null +++ b/lib/httpapi/server_signals_unix.go @@ -0,0 +1,44 @@ +//go:build unix + +package httpapi + +import ( + "context" + "os" + "os/signal" + "syscall" + + "github.com/coder/agentapi/lib/termexec" +) + +// HandleSignals sets up signal handlers for: +// - SIGTERM, SIGINT, SIGHUP: save conversation state, then close the process +// - SIGUSR1: save conversation state without exiting +func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { + // Handle shutdown signals (SIGTERM, SIGINT, SIGHUP) + shutdownCh := make(chan os.Signal, 1) + signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP) + go func() { + defer signal.Stop(shutdownCh) + sig := <-shutdownCh + s.logger.Info("Received shutdown signal, saving state before closing process", "signal", sig) + + s.saveAndCleanup(sig, process) + }() + + // Handle SIGUSR1 for save without exit + saveOnlyCh := make(chan os.Signal, 1) + signal.Notify(saveOnlyCh, syscall.SIGUSR1) + go func() { + defer signal.Stop(saveOnlyCh) + for { + select { + case <-saveOnlyCh: + 
s.logger.Info("Received SIGUSR1, saving state without exiting") + s.saveStateIfConfigured("SIGUSR1") + case <-ctx.Done(): + return + } + } + }() +} diff --git a/lib/httpapi/server_signals_windows.go b/lib/httpapi/server_signals_windows.go new file mode 100644 index 00000000..503e56a9 --- /dev/null +++ b/lib/httpapi/server_signals_windows.go @@ -0,0 +1,27 @@ +//go:build windows + +package httpapi + +import ( + "context" + "os" + "os/signal" + "syscall" + + "github.com/coder/agentapi/lib/termexec" +) + +// HandleSignals sets up signal handlers for Windows. +// Only handles SIGTERM and SIGINT (SIGHUP and SIGUSR1 don't exist on Windows). +func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { + // Handle shutdown signals (SIGTERM, SIGINT only on Windows) + shutdownCh := make(chan os.Signal, 1) + signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM) + go func() { + defer signal.Stop(shutdownCh) + sig := <-shutdownCh + s.logger.Info("Received shutdown signal, saving state before closing process", "signal", sig) + + s.saveAndCleanup(sig, process) + }() +} diff --git a/lib/httpapi/setup.go b/lib/httpapi/setup.go index 16203041..c8d95b6e 100644 --- a/lib/httpapi/setup.go +++ b/lib/httpapi/setup.go @@ -4,10 +4,7 @@ import ( "context" "fmt" "os" - "os/signal" "strings" - "syscall" - "time" "github.com/coder/agentapi/lib/logctx" mf "github.com/coder/agentapi/lib/msgfmt" @@ -45,16 +42,5 @@ func SetupProcess(ctx context.Context, config SetupProcessConfig) (*termexec.Pro return nil, err } } - - // Handle SIGINT (Ctrl+C) and send it to the process - signalCh := make(chan os.Signal, 1) - signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM) - go func() { - <-signalCh - if err := process.Close(logger, 5*time.Second); err != nil { - logger.Error("Error closing process", "error", err) - } - }() - return process, nil } diff --git a/lib/screentracker/conversation.go b/lib/screentracker/conversation.go index 97a74722..daf129a1 100644 --- 
a/lib/screentracker/conversation.go +++ b/lib/screentracker/conversation.go @@ -2,55 +2,27 @@ package screentracker import ( "context" - "fmt" - "log/slog" - "strings" - "sync" "time" - "github.com/coder/agentapi/lib/msgfmt" "github.com/coder/agentapi/lib/util" "github.com/danielgtaylor/huma/v2" "golang.org/x/xerrors" ) -type screenSnapshot struct { - timestamp time.Time - screen string -} - -type AgentIO interface { - Write(data []byte) (int, error) - ReadScreen() string -} +type ConversationStatus string -type ConversationConfig struct { - AgentType msgfmt.AgentType - AgentIO AgentIO - // GetTime returns the current time - GetTime func() time.Time - // How often to take a snapshot for the stability check - SnapshotInterval time.Duration - // How long the screen should not change to be considered stable - ScreenStabilityLength time.Duration - // Function to format the messages received from the agent - // userInput is the last user message - FormatMessage func(message string, userInput string) string - // SkipWritingMessage skips the writing of a message to the agent. - // This is used in tests - SkipWritingMessage bool - // SkipSendMessageStatusCheck skips the check for whether the message can be sent. 
- // This is used in tests - SkipSendMessageStatusCheck bool - // ReadyForInitialPrompt detects whether the agent has initialized and is ready to accept the initial prompt - ReadyForInitialPrompt func(message string) bool - // FormatToolCall removes the coder report_task tool call from the agent message and also returns the array of removed tool calls - FormatToolCall func(message string) (string, []string) - Logger *slog.Logger -} +const ( + ConversationStatusChanging ConversationStatus = "changing" + ConversationStatusStable ConversationStatus = "stable" + ConversationStatusInitializing ConversationStatus = "initializing" +) type ConversationRole string +func (c ConversationRole) Schema(r huma.Registry) *huma.Schema { + return util.OpenAPISchema(r, "ConversationRole", ConversationRoleValues) +} + const ( ConversationRoleUser ConversationRole = "user" ConversationRoleAgent ConversationRole = "agent" @@ -61,207 +33,15 @@ var ConversationRoleValues = []ConversationRole{ ConversationRoleAgent, } -func (c ConversationRole) Schema(r huma.Registry) *huma.Schema { - return util.OpenAPISchema(r, "ConversationRole", ConversationRoleValues) -} - -type ConversationMessage struct { - Id int - Message string - Role ConversationRole - Time time.Time -} - -type Conversation struct { - cfg ConversationConfig - // How many stable snapshots are required to consider the screen stable - stableSnapshotsThreshold int - snapshotBuffer *RingBuffer[screenSnapshot] - messages []ConversationMessage - screenBeforeLastUserMessage string - lock sync.Mutex - // InitialPrompt is the initial prompt passed to the agent - InitialPrompt string - // InitialPromptSent keeps track if the InitialPrompt has been successfully sent to the agents - InitialPromptSent bool - // ReadyForInitialPrompt keeps track if the agent is ready to accept the initial prompt - ReadyForInitialPrompt bool - // toolCallMessageSet keeps track of the tool calls that have been detected & logged in the current agent message - 
toolCallMessageSet map[string]bool -} - -type ConversationStatus string - -const ( - ConversationStatusChanging ConversationStatus = "changing" - ConversationStatusStable ConversationStatus = "stable" - ConversationStatusInitializing ConversationStatus = "initializing" +var ( + MessageValidationErrorWhitespace = xerrors.New("message must be trimmed of leading and trailing whitespace") + MessageValidationErrorEmpty = xerrors.New("message must not be empty") + MessageValidationErrorChanging = xerrors.New("message can only be sent when the agent is waiting for user input") ) -func getStableSnapshotsThreshold(cfg ConversationConfig) int { - length := cfg.ScreenStabilityLength.Milliseconds() - interval := cfg.SnapshotInterval.Milliseconds() - threshold := int(length / interval) - if length%interval != 0 { - threshold++ - } - return threshold + 1 -} - -func NewConversation(ctx context.Context, cfg ConversationConfig, initialPrompt string) *Conversation { - threshold := getStableSnapshotsThreshold(cfg) - c := &Conversation{ - cfg: cfg, - stableSnapshotsThreshold: threshold, - snapshotBuffer: NewRingBuffer[screenSnapshot](threshold), - messages: []ConversationMessage{ - { - Message: "", - Role: ConversationRoleAgent, - Time: cfg.GetTime(), - }, - }, - InitialPrompt: initialPrompt, - InitialPromptSent: len(initialPrompt) == 0, - toolCallMessageSet: make(map[string]bool), - } - return c -} - -func (c *Conversation) StartSnapshotLoop(ctx context.Context) { - go func() { - for { - select { - case <-ctx.Done(): - return - case <-time.After(c.cfg.SnapshotInterval): - // It's important that we hold the lock while reading the screen. - // There's a race condition that occurs without it: - // 1. The screen is read - // 2. Independently, SendMessage is called and takes the lock. - // 3. AddSnapshot is called and waits on the lock. - // 4. SendMessage modifies the terminal state, releases the lock - // 5. 
AddSnapshot adds a snapshot from a stale screen - c.lock.Lock() - screen := c.cfg.AgentIO.ReadScreen() - c.addSnapshotInner(screen) - c.lock.Unlock() - } - } - }() -} - -func FindNewMessage(oldScreen, newScreen string, agentType msgfmt.AgentType) string { - oldLines := strings.Split(oldScreen, "\n") - newLines := strings.Split(newScreen, "\n") - oldLinesMap := make(map[string]bool) - - // -1 indicates no header - dynamicHeaderEnd := -1 - - // Skip header lines for Opencode agent type to avoid false positives - // The header contains dynamic content (token count, context percentage, cost) - // that changes between screens, causing line comparison mismatches: - // - // ┃ # Getting Started with Claude CLI ┃ - // ┃ /share to create a shareable link 12.6K/6% ($0.05) ┃ - if len(newLines) >= 2 && agentType == msgfmt.AgentTypeOpencode { - dynamicHeaderEnd = 2 - } - - for _, line := range oldLines { - oldLinesMap[line] = true - } - firstNonMatchingLine := len(newLines) - for i, line := range newLines[dynamicHeaderEnd+1:] { - if !oldLinesMap[line] { - firstNonMatchingLine = i - break - } - } - newSectionLines := newLines[firstNonMatchingLine:] - - // remove leading and trailing lines which are empty or have only whitespace - startLine := 0 - endLine := len(newSectionLines) - 1 - for i := 0; i < len(newSectionLines); i++ { - if strings.TrimSpace(newSectionLines[i]) != "" { - startLine = i - break - } - } - for i := len(newSectionLines) - 1; i >= 0; i-- { - if strings.TrimSpace(newSectionLines[i]) != "" { - endLine = i - break - } - } - return strings.Join(newSectionLines[startLine:endLine+1], "\n") -} - -func (c *Conversation) lastMessage(role ConversationRole) ConversationMessage { - for i := len(c.messages) - 1; i >= 0; i-- { - if c.messages[i].Role == role { - return c.messages[i] - } - } - return ConversationMessage{} -} - -// This function assumes that the caller holds the lock -func (c *Conversation) updateLastAgentMessage(screen string, timestamp time.Time) { - 
agentMessage := FindNewMessage(c.screenBeforeLastUserMessage, screen, c.cfg.AgentType) - lastUserMessage := c.lastMessage(ConversationRoleUser) - var toolCalls []string - if c.cfg.FormatMessage != nil { - agentMessage = c.cfg.FormatMessage(agentMessage, lastUserMessage.Message) - } - if c.cfg.FormatToolCall != nil { - agentMessage, toolCalls = c.cfg.FormatToolCall(agentMessage) - } - for _, toolCall := range toolCalls { - if c.toolCallMessageSet[toolCall] == false { - c.toolCallMessageSet[toolCall] = true - c.cfg.Logger.Info("Tool call detected", "toolCall", toolCall) - } - } - shouldCreateNewMessage := len(c.messages) == 0 || c.messages[len(c.messages)-1].Role == ConversationRoleUser - lastAgentMessage := c.lastMessage(ConversationRoleAgent) - if lastAgentMessage.Message == agentMessage { - return - } - conversationMessage := ConversationMessage{ - Message: agentMessage, - Role: ConversationRoleAgent, - Time: timestamp, - } - if shouldCreateNewMessage { - c.messages = append(c.messages, conversationMessage) - - // Cleanup - c.toolCallMessageSet = make(map[string]bool) - - } else { - c.messages[len(c.messages)-1] = conversationMessage - } - c.messages[len(c.messages)-1].Id = len(c.messages) - 1 -} - -// assumes the caller holds the lock -func (c *Conversation) addSnapshotInner(screen string) { - snapshot := screenSnapshot{ - timestamp: c.cfg.GetTime(), - screen: screen, - } - c.snapshotBuffer.Add(snapshot) - c.updateLastAgentMessage(screen, snapshot.timestamp) -} - -func (c *Conversation) AddSnapshot(screen string) { - c.lock.Lock() - defer c.lock.Unlock() - - c.addSnapshotInner(screen) +type AgentIO interface { + Write(data []byte) (int, error) + ReadScreen() string } type MessagePart interface { @@ -269,198 +49,27 @@ type MessagePart interface { String() string } -type MessagePartText struct { - Content string - Alias string - Hidden bool -} - -func (p MessagePartText) Do(writer AgentIO) error { - _, err := writer.Write([]byte(p.Content)) - return err -} - -func 
(p MessagePartText) String() string { - if p.Hidden { - return "" - } - if p.Alias != "" { - return p.Alias - } - return p.Content -} - -func PartsToString(parts ...MessagePart) string { - var sb strings.Builder - for _, part := range parts { - sb.WriteString(part.String()) - } - return sb.String() -} - -func ExecuteParts(writer AgentIO, parts ...MessagePart) error { - for _, part := range parts { - if err := part.Do(writer); err != nil { - return xerrors.Errorf("failed to write message part: %w", err) - } - } - return nil -} - -func (c *Conversation) writeMessageWithConfirmation(ctx context.Context, messageParts ...MessagePart) error { - if c.cfg.SkipWritingMessage { - return nil - } - screenBeforeMessage := c.cfg.AgentIO.ReadScreen() - if err := ExecuteParts(c.cfg.AgentIO, messageParts...); err != nil { - return xerrors.Errorf("failed to write message: %w", err) - } - // wait for the screen to stabilize after the message is written - if err := util.WaitFor(ctx, util.WaitTimeout{ - Timeout: 15 * time.Second, - MinInterval: 50 * time.Millisecond, - InitialWait: true, - }, func() (bool, error) { - screen := c.cfg.AgentIO.ReadScreen() - if screen != screenBeforeMessage { - time.Sleep(1 * time.Second) - newScreen := c.cfg.AgentIO.ReadScreen() - return newScreen == screen, nil - } - return false, nil - }); err != nil { - return xerrors.Errorf("failed to wait for screen to stabilize: %w", err) - } - - // wait for the screen to change after the carriage return is written - screenBeforeCarriageReturn := c.cfg.AgentIO.ReadScreen() - lastCarriageReturnTime := time.Time{} - if err := util.WaitFor(ctx, util.WaitTimeout{ - Timeout: 15 * time.Second, - MinInterval: 25 * time.Millisecond, - }, func() (bool, error) { - // we don't want to spam additional carriage returns because the agent may process them - // (aider does this), but we do want to retry sending one if nothing's - // happening for a while - if time.Since(lastCarriageReturnTime) >= 3*time.Second { - 
lastCarriageReturnTime = time.Now() - if _, err := c.cfg.AgentIO.Write([]byte("\r")); err != nil { - return false, xerrors.Errorf("failed to write carriage return: %w", err) - } - } - time.Sleep(25 * time.Millisecond) - screen := c.cfg.AgentIO.ReadScreen() - - return screen != screenBeforeCarriageReturn, nil - }); err != nil { - return xerrors.Errorf("failed to wait for processing to start: %w", err) - } - - return nil -} - -var MessageValidationErrorWhitespace = xerrors.New("message must be trimmed of leading and trailing whitespace") -var MessageValidationErrorEmpty = xerrors.New("message must not be empty") -var MessageValidationErrorChanging = xerrors.New("message can only be sent when the agent is waiting for user input") - -func (c *Conversation) SendMessage(messageParts ...MessagePart) error { - c.lock.Lock() - defer c.lock.Unlock() - - if !c.cfg.SkipSendMessageStatusCheck && c.statusInner() != ConversationStatusStable { - return MessageValidationErrorChanging - } - - message := PartsToString(messageParts...) 
- if message != msgfmt.TrimWhitespace(message) { - // msgfmt formatting functions assume this - return MessageValidationErrorWhitespace - } - if message == "" { - // writeMessageWithConfirmation requires a non-empty message - return MessageValidationErrorEmpty - } - - screenBeforeMessage := c.cfg.AgentIO.ReadScreen() - now := c.cfg.GetTime() - c.updateLastAgentMessage(screenBeforeMessage, now) - - if err := c.writeMessageWithConfirmation(context.Background(), messageParts...); err != nil { - return xerrors.Errorf("failed to send message: %w", err) - } - - c.screenBeforeLastUserMessage = screenBeforeMessage - c.messages = append(c.messages, ConversationMessage{ - Id: len(c.messages), - Message: message, - Role: ConversationRoleUser, - Time: now, - }) - return nil -} - -// Assumes that the caller holds the lock -func (c *Conversation) statusInner() ConversationStatus { - // sanity checks - if c.snapshotBuffer.Capacity() != c.stableSnapshotsThreshold { - panic(fmt.Sprintf("snapshot buffer capacity %d is not equal to snapshot threshold %d. can't check stability", c.snapshotBuffer.Capacity(), c.stableSnapshotsThreshold)) - } - if c.stableSnapshotsThreshold == 0 { - panic("stable snapshots threshold is 0. 
can't check stability") - } - - snapshots := c.snapshotBuffer.GetAll() - if len(c.messages) > 0 && c.messages[len(c.messages)-1].Role == ConversationRoleUser { - // if the last message is a user message then the snapshot loop hasn't - // been triggered since the last user message, and we should assume - // the screen is changing - return ConversationStatusChanging - } - - if len(snapshots) != c.stableSnapshotsThreshold { - return ConversationStatusInitializing - } - - for i := 1; i < len(snapshots); i++ { - if snapshots[0].screen != snapshots[i].screen { - return ConversationStatusChanging - } - } - - if !c.InitialPromptSent && !c.ReadyForInitialPrompt { - if len(snapshots) > 0 && c.cfg.ReadyForInitialPrompt(snapshots[len(snapshots)-1].screen) { - c.ReadyForInitialPrompt = true - return ConversationStatusStable - } - return ConversationStatusChanging - } - - return ConversationStatusStable -} - -func (c *Conversation) Status() ConversationStatus { - c.lock.Lock() - defer c.lock.Unlock() - - return c.statusInner() +// Conversation allows tracking of a conversation between a user and an agent. 
+type Conversation interface { + Messages() []ConversationMessage + SaveState([]ConversationMessage, string) error + LoadState(string) ([]ConversationMessage, error) + Snapshot(string) + Start(context.Context) + Status() ConversationStatus + String() string } -func (c *Conversation) Messages() []ConversationMessage { - c.lock.Lock() - defer c.lock.Unlock() - - result := make([]ConversationMessage, len(c.messages)) - copy(result, c.messages) - return result +type ConversationMessage struct { + Id int + Message string + Role ConversationRole + Time time.Time } -func (c *Conversation) Screen() string { - c.lock.Lock() - defer c.lock.Unlock() - - snapshots := c.snapshotBuffer.GetAll() - if len(snapshots) == 0 { - return "" - } - return snapshots[len(snapshots)-1].screen +type AgentState struct { + Version int `json:"version"` + Messages []ConversationMessage `json:"messages"` + InitialPrompt string `json:"initial_prompt"` + InitialPromptSent bool `json:"initial_prompt_sent"` } diff --git a/lib/screentracker/diff.go b/lib/screentracker/diff.go new file mode 100644 index 00000000..47c5b78c --- /dev/null +++ b/lib/screentracker/diff.go @@ -0,0 +1,56 @@ +package screentracker + +import ( + "strings" + + "github.com/coder/agentapi/lib/msgfmt" +) + +// screenDiff compares two screen states and attempts to find latest message of the given agent type. 
+func screenDiff(oldScreen, newScreen string, agentType msgfmt.AgentType) string { + oldLines := strings.Split(oldScreen, "\n") + newLines := strings.Split(newScreen, "\n") + oldLinesMap := make(map[string]bool) + + // -1 indicates no header; otherwise the index of the last header line + dynamicHeaderEnd := -1 + + // Skip header lines for Opencode agent type to avoid false positives + // The header contains dynamic content (token count, context percentage, cost) + // that changes between screens, causing line comparison mismatches: + // + // ┃ # Getting Started with Claude CLI ┃ + // ┃ /share to create a shareable link 12.6K/6% ($0.05) ┃ + if len(newLines) >= 3 && agentType == msgfmt.AgentTypeOpencode { + dynamicHeaderEnd = 2 + } + + for _, line := range oldLines { + oldLinesMap[line] = true + } + firstNonMatchingLine := len(newLines) + for i, line := range newLines[dynamicHeaderEnd+1:] { + if !oldLinesMap[line] { + firstNonMatchingLine = i + dynamicHeaderEnd + 1 + break + } + } + newSectionLines := newLines[firstNonMatchingLine:] + + // remove leading and trailing lines which are empty or have only whitespace + startLine := 0 + endLine := len(newSectionLines) - 1 + for i := range newSectionLines { + if strings.TrimSpace(newSectionLines[i]) != "" { + startLine = i + break + } + } + for i := len(newSectionLines) - 1; i >= 0; i-- { + if strings.TrimSpace(newSectionLines[i]) != "" { + endLine = i + break + } + } + return strings.Join(newSectionLines[startLine:endLine+1], "\n") +} diff --git a/lib/screentracker/diff_internal_test.go b/lib/screentracker/diff_internal_test.go new file mode 100644 index 00000000..d68bc36c --- /dev/null +++ b/lib/screentracker/diff_internal_test.go @@ -0,0 +1,39 @@ +package screentracker + +import ( + "embed" + "path" + "testing" + + "github.com/coder/agentapi/lib/msgfmt" + "github.com/stretchr/testify/assert" +) + +//go:embed testdata +var testdataDir embed.FS + +func TestScreenDiff(t *testing.T) { + t.Run("simple", func(t *testing.T) { + assert.Equal(t, "", screenDiff("123456", "123456", 
msgfmt.AgentTypeCustom)) + assert.Equal(t, "1234567", screenDiff("123456", "1234567", msgfmt.AgentTypeCustom)) + assert.Equal(t, "42", screenDiff("123", "123\n \n \n \n42", msgfmt.AgentTypeCustom)) + assert.Equal(t, "12342", screenDiff("123", "12342\n \n \n \n", msgfmt.AgentTypeCustom)) + assert.Equal(t, "42", screenDiff("123", "123\n \n \n \n42\n \n \n \n", msgfmt.AgentTypeCustom)) + assert.Equal(t, "42", screenDiff("89", "42", msgfmt.AgentTypeCustom)) + }) + + dir := "testdata/diff" + cases, err := testdataDir.ReadDir(dir) + assert.NoError(t, err) + for _, c := range cases { + t.Run(c.Name(), func(t *testing.T) { + before, err := testdataDir.ReadFile(path.Join(dir, c.Name(), "before.txt")) + assert.NoError(t, err) + after, err := testdataDir.ReadFile(path.Join(dir, c.Name(), "after.txt")) + assert.NoError(t, err) + expected, err := testdataDir.ReadFile(path.Join(dir, c.Name(), "expected.txt")) + assert.NoError(t, err) + assert.Equal(t, string(expected), screenDiff(string(before), string(after), msgfmt.AgentTypeCustom)) + }) + } +} diff --git a/lib/screentracker/pty_conversation.go b/lib/screentracker/pty_conversation.go new file mode 100644 index 00000000..a8f73c0a --- /dev/null +++ b/lib/screentracker/pty_conversation.go @@ -0,0 +1,500 @@ +package screentracker + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/coder/agentapi/lib/msgfmt" + "github.com/coder/agentapi/lib/util" + "golang.org/x/xerrors" +) + +// A screenSnapshot represents a snapshot of the PTY at a specific time. 
+type screenSnapshot struct { + timestamp time.Time + screen string +} + +type MessagePartText struct { + Content string + Alias string + Hidden bool +} + +var _ MessagePart = &MessagePartText{} + +func (p MessagePartText) Do(writer AgentIO) error { + _, err := writer.Write([]byte(p.Content)) + return err +} + +func (p MessagePartText) String() string { + if p.Hidden { + return "" + } + if p.Alias != "" { + return p.Alias + } + return p.Content +} + +// PTYConversationConfig is the configuration for a PTYConversation. +type PTYConversationConfig struct { + AgentType msgfmt.AgentType + AgentIO AgentIO + // GetTime returns the current time + GetTime func() time.Time + // How often to take a snapshot for the stability check + SnapshotInterval time.Duration + // How long the screen should not change to be considered stable + ScreenStabilityLength time.Duration + // Function to format the messages received from the agent + // userInput is the last user message + FormatMessage func(message string, userInput string) string + // SkipWritingMessage skips the writing of a message to the agent. + // This is used in tests + SkipWritingMessage bool + // SkipSendMessageStatusCheck skips the check for whether the message can be sent. 
+ // This is used in tests + SkipSendMessageStatusCheck bool + // ReadyForInitialPrompt detects whether the agent has initialized and is ready to accept the initial prompt + ReadyForInitialPrompt func(message string) bool + // FormatToolCall removes the coder report_task tool call from the agent message and also returns the array of removed tool calls + FormatToolCall func(message string) (string, []string) + Logger *slog.Logger +} + +func (cfg PTYConversationConfig) getStableSnapshotsThreshold() int { + length := cfg.ScreenStabilityLength.Milliseconds() + interval := cfg.SnapshotInterval.Milliseconds() + threshold := int(length / interval) + if length%interval != 0 { + threshold++ + } + return threshold + 1 +} + +// PTYConversation is a conversation that uses a pseudo-terminal (PTY) for communication. +// It uses a combination of polling and diffs to detect changes in the screen. +type PTYConversation struct { + cfg PTYConversationConfig + // How many stable snapshots are required to consider the screen stable + stableSnapshotsThreshold int + snapshotBuffer *RingBuffer[screenSnapshot] + messages []ConversationMessage + screenBeforeLastUserMessage string + lock sync.Mutex + + // InitialPrompt is the initial prompt passed to the agent + InitialPrompt string + // InitialPromptSent keeps track if the InitialPrompt has been successfully sent to the agents + InitialPromptSent bool + // ReadyForInitialPrompt keeps track if the agent is ready to accept the initial prompt + ReadyForInitialPrompt bool + // toolCallMessageSet keeps track of the tool calls that have been detected & logged in the current agent message + toolCallMessageSet map[string]bool + // dirty tracks whether the conversation state has changed since the last save + dirty bool + // firstStableSnapshot is the conversation history rolled out by the agent in case of a resume (given that the agent supports it) + firstStableSnapshot string + // userSentMessageAfterLoadState tracks if the user has sent their 
first message after we load the state + userSentMessageAfterLoadState bool + // loadStateSuccessful indicates whether conversation state was successfully restored from file. + loadStateSuccessful bool +} + +var _ Conversation = &PTYConversation{} + +func NewPTY(ctx context.Context, cfg PTYConversationConfig, initialPrompt string) *PTYConversation { + threshold := cfg.getStableSnapshotsThreshold() + c := &PTYConversation{ + cfg: cfg, + stableSnapshotsThreshold: threshold, + snapshotBuffer: NewRingBuffer[screenSnapshot](threshold), + messages: []ConversationMessage{ + { + Message: "", + Role: ConversationRoleAgent, + Time: cfg.GetTime(), + }, + }, + InitialPrompt: initialPrompt, + InitialPromptSent: len(initialPrompt) == 0, + toolCallMessageSet: make(map[string]bool), + dirty: false, + firstStableSnapshot: "", + userSentMessageAfterLoadState: false, + loadStateSuccessful: false, + } + return c +} + +func (c *PTYConversation) Start(ctx context.Context) { + go func() { + for { + select { + case <-ctx.Done(): + return + case <-time.After(c.cfg.SnapshotInterval): + // It's important that we hold the lock while reading the screen. + // There's a race condition that occurs without it: + // 1. The screen is read + // 2. Independently, SendMessage is called and takes the lock. + // 3. AddSnapshot is called and waits on the lock. + // 4. SendMessage modifies the terminal state, releases the lock + // 5. 
AddSnapshot adds a snapshot from a stale screen + c.lock.Lock() + screen := c.cfg.AgentIO.ReadScreen() + c.snapshotLocked(screen) + c.lock.Unlock() + } + } + }() +} + +func (c *PTYConversation) lastMessage(role ConversationRole) ConversationMessage { + for i := len(c.messages) - 1; i >= 0; i-- { + if c.messages[i].Role == role { + return c.messages[i] + } + } + return ConversationMessage{} +} + +// caller MUST hold c.lock +func (c *PTYConversation) updateLastAgentMessageLocked(screen string, timestamp time.Time) { + agentMessage := screenDiff(c.screenBeforeLastUserMessage, screen, c.cfg.AgentType) + lastUserMessage := c.lastMessage(ConversationRoleUser) + var toolCalls []string + if c.cfg.FormatMessage != nil { + agentMessage = c.cfg.FormatMessage(agentMessage, lastUserMessage.Message) + } + if c.loadStateSuccessful { + agentMessage = c.adjustScreenAfterStateLoad(agentMessage) + } + if c.cfg.FormatToolCall != nil { + agentMessage, toolCalls = c.cfg.FormatToolCall(agentMessage) + } + for _, toolCall := range toolCalls { + if c.toolCallMessageSet[toolCall] == false { + c.toolCallMessageSet[toolCall] = true + c.cfg.Logger.Info("Tool call detected", "toolCall", toolCall) + } + } + shouldCreateNewMessage := len(c.messages) == 0 || c.messages[len(c.messages)-1].Role == ConversationRoleUser + lastAgentMessage := c.lastMessage(ConversationRoleAgent) + if lastAgentMessage.Message == agentMessage { + return + } + conversationMessage := ConversationMessage{ + Message: agentMessage, + Role: ConversationRoleAgent, + Time: timestamp, + } + if shouldCreateNewMessage { + c.messages = append(c.messages, conversationMessage) + + // Cleanup + c.toolCallMessageSet = make(map[string]bool) + + } else { + c.messages[len(c.messages)-1] = conversationMessage + } + c.messages[len(c.messages)-1].Id = len(c.messages) - 1 + c.dirty = true +} + +func (c *PTYConversation) Snapshot(screen string) { + c.lock.Lock() + defer c.lock.Unlock() + + c.snapshotLocked(screen) +} + +// caller MUST hold 
c.lock +func (c *PTYConversation) snapshotLocked(screen string) { + snapshot := screenSnapshot{ + timestamp: c.cfg.GetTime(), + screen: screen, + } + c.snapshotBuffer.Add(snapshot) + c.updateLastAgentMessageLocked(screen, snapshot.timestamp) +} + +func (c *PTYConversation) Send(messageParts ...MessagePart) error { + c.lock.Lock() + defer c.lock.Unlock() + + if !c.cfg.SkipSendMessageStatusCheck && c.statusLocked() != ConversationStatusStable { + return MessageValidationErrorChanging + } + + var sb strings.Builder + for _, part := range messageParts { + sb.WriteString(part.String()) + } + message := sb.String() + if message != msgfmt.TrimWhitespace(message) { + // msgfmt formatting functions assume this + return MessageValidationErrorWhitespace + } + if message == "" { + // writeMessageWithConfirmation requires a non-empty message + return MessageValidationErrorEmpty + } + + screenBeforeMessage := c.cfg.AgentIO.ReadScreen() + now := c.cfg.GetTime() + c.updateLastAgentMessageLocked(screenBeforeMessage, now) + + if err := c.writeStabilize(context.Background(), messageParts...); err != nil { + return xerrors.Errorf("failed to send message: %w", err) + } + + c.screenBeforeLastUserMessage = screenBeforeMessage + c.messages = append(c.messages, ConversationMessage{ + Id: len(c.messages), + Message: message, + Role: ConversationRoleUser, + Time: now, + }) + c.userSentMessageAfterLoadState = true + + return nil +} + +// writeStabilize writes messageParts to the screen and waits for the screen to stabilize after the message is written. 
// writeStabilize writes messageParts to the screen and waits for the screen to stabilize after the message is written.
//
// It works in two phases:
//  1. write every part, then wait until the screen differs from its
//     pre-write contents and stays unchanged across a one second pause;
//  2. send a carriage return and wait until the screen changes, which is
//     taken as the sign that the agent started processing the message.
//
// It is a no-op when cfg.SkipWritingMessage is set. Write and wait failures
// are returned wrapped.
func (c *PTYConversation) writeStabilize(ctx context.Context, messageParts ...MessagePart) error {
	if c.cfg.SkipWritingMessage {
		return nil
	}
	// Baseline used to detect that the written text showed up on screen.
	screenBeforeMessage := c.cfg.AgentIO.ReadScreen()
	for _, part := range messageParts {
		if err := part.Do(c.cfg.AgentIO); err != nil {
			return xerrors.Errorf("failed to write message part: %w", err)
		}
	}
	// wait for the screen to stabilize after the message is written
	// NOTE(review): the exact polling semantics (Timeout, MinInterval,
	// InitialWait) are defined by util.WaitFor, which is not visible here.
	if err := util.WaitFor(ctx, util.WaitTimeout{
		Timeout:     15 * time.Second,
		MinInterval: 50 * time.Millisecond,
		InitialWait: true,
	}, func() (bool, error) {
		screen := c.cfg.AgentIO.ReadScreen()
		if screen != screenBeforeMessage {
			// The screen changed; treat it as stable only if it is still
			// the same after one second.
			time.Sleep(1 * time.Second)
			newScreen := c.cfg.AgentIO.ReadScreen()
			return newScreen == screen, nil
		}
		return false, nil
	}); err != nil {
		return xerrors.Errorf("failed to wait for screen to stabilize: %w", err)
	}

	// wait for the screen to change after the carriage return is written
	screenBeforeCarriageReturn := c.cfg.AgentIO.ReadScreen()
	// The zero time makes the first check below send a carriage return
	// immediately.
	lastCarriageReturnTime := time.Time{}
	if err := util.WaitFor(ctx, util.WaitTimeout{
		Timeout:     15 * time.Second,
		MinInterval: 25 * time.Millisecond,
	}, func() (bool, error) {
		// we don't want to spam additional carriage returns because the agent may process them
		// (aider does this), but we do want to retry sending one if nothing's
		// happening for a while
		if time.Since(lastCarriageReturnTime) >= 3*time.Second {
			lastCarriageReturnTime = time.Now()
			if _, err := c.cfg.AgentIO.Write([]byte("\r")); err != nil {
				return false, xerrors.Errorf("failed to write carriage return: %w", err)
			}
		}
		// Give the agent a moment to react before reading the screen.
		time.Sleep(25 * time.Millisecond)
		screen := c.cfg.AgentIO.ReadScreen()

		return screen != screenBeforeCarriageReturn, nil
	}); err != nil {
		return xerrors.Errorf("failed to wait for processing to start: %w", err)
	}

	return nil
}
c.statusLocked() +} + +// caller MUST hold c.lock +func (c *PTYConversation) statusLocked() ConversationStatus { + // sanity checks + if c.snapshotBuffer.Capacity() != c.stableSnapshotsThreshold { + panic(fmt.Sprintf("snapshot buffer capacity %d is not equal to snapshot threshold %d. can't check stability", c.snapshotBuffer.Capacity(), c.stableSnapshotsThreshold)) + } + if c.stableSnapshotsThreshold == 0 { + panic("stable snapshots threshold is 0. can't check stability") + } + + snapshots := c.snapshotBuffer.GetAll() + if len(c.messages) > 0 && c.messages[len(c.messages)-1].Role == ConversationRoleUser { + // if the last message is a user message then the snapshot loop hasn't + // been triggered since the last user message, and we should assume + // the screen is changing + return ConversationStatusChanging + } + + if len(snapshots) != c.stableSnapshotsThreshold { + return ConversationStatusInitializing + } + + for i := 1; i < len(snapshots); i++ { + if snapshots[0].screen != snapshots[i].screen { + return ConversationStatusChanging + } + } + + if !c.InitialPromptSent && !c.ReadyForInitialPrompt { + if len(snapshots) > 0 && c.cfg.ReadyForInitialPrompt(snapshots[len(snapshots)-1].screen) { + c.ReadyForInitialPrompt = true + return ConversationStatusStable + } + return ConversationStatusChanging + } + + return ConversationStatusStable +} + +func (c *PTYConversation) Messages() []ConversationMessage { + c.lock.Lock() + defer c.lock.Unlock() + + result := make([]ConversationMessage, len(c.messages)) + copy(result, c.messages) + return result +} + +func (c *PTYConversation) String() string { + c.lock.Lock() + defer c.lock.Unlock() + + snapshots := c.snapshotBuffer.GetAll() + if len(snapshots) == 0 { + return "" + } + return snapshots[len(snapshots)-1].screen +} + +func (c *PTYConversation) SaveState(conversation []ConversationMessage, stateFile string) error { + c.lock.Lock() + defer c.lock.Unlock() + + // Skip if state file is not configured + if stateFile == "" { + 
return nil + } + + // Skip if not dirty + if !c.dirty { + return nil + } + + // Use atomic write: write to temp file, then rename to target path + data, err := json.MarshalIndent(AgentState{ + Version: 1, + Messages: conversation, + InitialPrompt: c.InitialPrompt, + InitialPromptSent: c.InitialPromptSent, + }, "", " ") + if err != nil { + return xerrors.Errorf("failed to marshal state: %w", err) + } + + // Create directory if it doesn't exist + dir := filepath.Dir(stateFile) + if err := os.MkdirAll(dir, 0o755); err != nil { + return xerrors.Errorf("failed to create state directory: %w", err) + } + + // Write to temp file + tempFile := stateFile + ".tmp" + if err := os.WriteFile(tempFile, data, 0o644); err != nil { + return xerrors.Errorf("failed to write temp state file: %w", err) + } + + // Atomic rename + if err := os.Rename(tempFile, stateFile); err != nil { + return xerrors.Errorf("failed to rename state file: %w", err) + } + + // Clear dirty flag after successful save + c.dirty = false + return nil +} + +func (c *PTYConversation) LoadState(stateFile string) ([]ConversationMessage, error) { + c.lock.Lock() + defer c.lock.Unlock() + + // Skip if state file is not configured + if stateFile == "" { + return nil, nil + } + + // Check if file exists + if _, err := os.Stat(stateFile); os.IsNotExist(err) { + c.cfg.Logger.Info("No previous state to load (file does not exist)", "path", stateFile) + return nil, nil + } + + // Read state file + data, err := os.ReadFile(stateFile) + if err != nil { + c.cfg.Logger.Warn("Failed to load state file", "path", stateFile, "err", err) + return nil, xerrors.Errorf("failed to read state file: %w", err) + } + + if len(data) == 0 { + c.cfg.Logger.Info("No previous state to load (file is empty)", "path", stateFile) + return nil, nil + } + + var agentState AgentState + if err := json.Unmarshal(data, &agentState); err != nil { + c.cfg.Logger.Warn("Failed to load state file (corrupted or invalid JSON)", "path", stateFile, "err", err) + 
return nil, xerrors.Errorf("failed to unmarshal state (corrupted or invalid JSON): %w", err) + } + + c.InitialPromptSent = agentState.InitialPromptSent + c.InitialPrompt = agentState.InitialPrompt + c.messages = agentState.Messages + + // Store the first stable snapshot for filtering later + snapshots := c.snapshotBuffer.GetAll() + if len(snapshots) > 0 { + c.firstStableSnapshot = c.cfg.FormatMessage(strings.TrimSpace(snapshots[len(snapshots)-1].screen), "") + } + + c.loadStateSuccessful = true + c.cfg.Logger.Info("Successfully loaded state", "path", stateFile, "messages", len(c.messages)) + return c.messages, nil +} + +func (c *PTYConversation) adjustScreenAfterStateLoad(screen string) string { + newScreen := strings.Replace(screen, c.firstStableSnapshot, "", 1) + + // Before the first user message after loading state, return the last message from the loaded state. + // This prevents computing incorrect diffs from the restored screen, as the agent's message should + // remain stable until the user continues the conversation. 
+ if c.userSentMessageAfterLoadState == false { + newScreen = "\n" + c.messages[len(c.messages)-1].Message + } + + return newScreen +} diff --git a/lib/screentracker/conversation_test.go b/lib/screentracker/pty_conversation_test.go similarity index 75% rename from lib/screentracker/conversation_test.go rename to lib/screentracker/pty_conversation_test.go index 9b888813..6798de4d 100644 --- a/lib/screentracker/conversation_test.go +++ b/lib/screentracker/pty_conversation_test.go @@ -2,13 +2,10 @@ package screentracker_test import ( "context" - "embed" "fmt" - "path" "testing" "time" - "github.com/coder/agentapi/lib/msgfmt" "github.com/stretchr/testify/assert" st "github.com/coder/agentapi/lib/screentracker" @@ -19,7 +16,7 @@ type statusTestStep struct { status st.ConversationStatus } type statusTestParams struct { - cfg st.ConversationConfig + cfg st.PTYConversationConfig steps []statusTestStep } @@ -42,11 +39,11 @@ func statusTest(t *testing.T, params statusTestParams) { if params.cfg.GetTime == nil { params.cfg.GetTime = func() time.Time { return time.Now() } } - c := st.NewConversation(ctx, params.cfg, "") + c := st.NewPTY(ctx, params.cfg, "") assert.Equal(t, st.ConversationStatusInitializing, c.Status()) for i, step := range params.steps { - c.AddSnapshot(step.snapshot) + c.Snapshot(step.snapshot) assert.Equal(t, step.status, c.Status(), "step %d", i) } }) @@ -58,7 +55,7 @@ func TestConversation(t *testing.T) { initializing := st.ConversationStatusInitializing statusTest(t, statusTestParams{ - cfg: st.ConversationConfig{ + cfg: st.PTYConversationConfig{ SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 2 * time.Second, // stability threshold: 3 @@ -76,7 +73,7 @@ func TestConversation(t *testing.T) { }) statusTest(t, statusTestParams{ - cfg: st.ConversationConfig{ + cfg: st.PTYConversationConfig{ SnapshotInterval: 2 * time.Second, ScreenStabilityLength: 3 * time.Second, // stability threshold: 3 @@ -95,7 +92,7 @@ func TestConversation(t *testing.T) { }) 
statusTest(t, statusTestParams{ - cfg: st.ConversationConfig{ + cfg: st.PTYConversationConfig{ SnapshotInterval: 6 * time.Second, ScreenStabilityLength: 14 * time.Second, // stability threshold: 4 @@ -133,11 +130,11 @@ func TestMessages(t *testing.T) { Time: now, } } - sendMsg := func(c *st.Conversation, msg string) error { - return c.SendMessage(st.MessagePartText{Content: msg}) + sendMsg := func(c *st.PTYConversation, msg string) error { + return c.Send(st.MessagePartText{Content: msg}) } - newConversation := func(opts ...func(*st.ConversationConfig)) *st.Conversation { - cfg := st.ConversationConfig{ + newConversation := func(opts ...func(*st.PTYConversationConfig)) *st.PTYConversation { + cfg := st.PTYConversationConfig{ GetTime: func() time.Time { return now }, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 2 * time.Second, @@ -147,7 +144,7 @@ func TestMessages(t *testing.T) { for _, opt := range opts { opt(&cfg) } - return st.NewConversation(context.Background(), cfg, "") + return st.NewPTY(context.Background(), cfg, "") } t.Run("messages are copied", func(t *testing.T) { @@ -167,7 +164,7 @@ func TestMessages(t *testing.T) { t.Run("whitespace-padding", func(t *testing.T) { c := newConversation() for _, msg := range []string{"123 ", " 123", "123\t\t", "\n123", "123\n\t", " \t123\n\t"} { - err := c.SendMessage(st.MessagePartText{Content: msg}) + err := c.Send(st.MessagePartText{Content: msg}) assert.Error(t, err, st.MessageValidationErrorWhitespace) } }) @@ -178,33 +175,33 @@ func TestMessages(t *testing.T) { }{ Time: now, } - c := newConversation(func(cfg *st.ConversationConfig) { + c := newConversation(func(cfg *st.PTYConversationConfig) { cfg.GetTime = func() time.Time { return nowWrapper.Time } }) - c.AddSnapshot("1") + c.Snapshot("1") msgs := c.Messages() assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "1"), }, msgs) nowWrapper.Time = nowWrapper.Add(1 * time.Second) - c.AddSnapshot("1") + c.Snapshot("1") assert.Equal(t, msgs, c.Messages()) 
}) t.Run("tracking messages", func(t *testing.T) { agent := &testAgent{} - c := newConversation(func(cfg *st.ConversationConfig) { + c := newConversation(func(cfg *st.PTYConversationConfig) { cfg.AgentIO = agent }) // agent message is recorded when the first snapshot is added - c.AddSnapshot("1") + c.Snapshot("1") assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "1"), }, c.Messages()) // agent message is updated when the screen changes - c.AddSnapshot("2") + c.Snapshot("2") assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "2"), }, c.Messages()) @@ -218,7 +215,7 @@ func TestMessages(t *testing.T) { }, c.Messages()) // agent message is added after a user message - c.AddSnapshot("4") + c.Snapshot("4") assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "2"), userMsg(1, "3"), @@ -236,9 +233,9 @@ func TestMessages(t *testing.T) { }, c.Messages()) // conversation status is changing right after a user message - c.AddSnapshot("7") - c.AddSnapshot("7") - c.AddSnapshot("7") + c.Snapshot("7") + c.Snapshot("7") + c.Snapshot("7") assert.Equal(t, st.ConversationStatusStable, c.Status()) agent.screen = "7" assert.NoError(t, sendMsg(c, "8")) @@ -254,21 +251,21 @@ func TestMessages(t *testing.T) { // conversation status is back to stable after a snapshot that // doesn't change the screen - c.AddSnapshot("7") + c.Snapshot("7") assert.Equal(t, st.ConversationStatusStable, c.Status()) }) t.Run("tracking messages overlap", func(t *testing.T) { agent := &testAgent{} - c := newConversation(func(cfg *st.ConversationConfig) { + c := newConversation(func(cfg *st.PTYConversationConfig) { cfg.AgentIO = agent }) // common overlap between screens is removed after a user message - c.AddSnapshot("1") + c.Snapshot("1") agent.screen = "1" assert.NoError(t, sendMsg(c, "2")) - c.AddSnapshot("1\n3") + c.Snapshot("1\n3") assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "1"), userMsg(1, "2"), @@ -277,7 +274,7 @@ func TestMessages(t *testing.T) { agent.screen = "1\n3x" assert.NoError(t, 
sendMsg(c, "4")) - c.AddSnapshot("1\n3x\n5") + c.Snapshot("1\n3x\n5") assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "1"), userMsg(1, "2"), @@ -289,7 +286,7 @@ func TestMessages(t *testing.T) { t.Run("format-message", func(t *testing.T) { agent := &testAgent{} - c := newConversation(func(cfg *st.ConversationConfig) { + c := newConversation(func(cfg *st.PTYConversationConfig) { cfg.AgentIO = agent cfg.FormatMessage = func(message string, userInput string) string { return message + " " + userInput @@ -302,7 +299,7 @@ func TestMessages(t *testing.T) { userMsg(1, "2"), }, c.Messages()) agent.screen = "x" - c.AddSnapshot("x") + c.Snapshot("x") assert.Equal(t, []st.ConversationMessage{ agentMsg(0, "1 "), userMsg(1, "2"), @@ -312,7 +309,7 @@ func TestMessages(t *testing.T) { t.Run("format-message", func(t *testing.T) { agent := &testAgent{} - c := newConversation(func(cfg *st.ConversationConfig) { + c := newConversation(func(cfg *st.PTYConversationConfig) { cfg.AgentIO = agent cfg.FormatMessage = func(message string, userInput string) string { return "formatted" @@ -329,7 +326,7 @@ func TestMessages(t *testing.T) { }) t.Run("send-message-status-check", func(t *testing.T) { - c := newConversation(func(cfg *st.ConversationConfig) { + c := newConversation(func(cfg *st.PTYConversationConfig) { cfg.SkipSendMessageStatusCheck = false cfg.SnapshotInterval = 1 * time.Second cfg.ScreenStabilityLength = 2 * time.Second @@ -337,10 +334,10 @@ func TestMessages(t *testing.T) { }) assert.Error(t, sendMsg(c, "1"), st.MessageValidationErrorChanging) for range 3 { - c.AddSnapshot("1") + c.Snapshot("1") } assert.NoError(t, sendMsg(c, "4")) - c.AddSnapshot("2") + c.Snapshot("2") assert.Error(t, sendMsg(c, "5"), st.MessageValidationErrorChanging) }) @@ -350,68 +347,11 @@ func TestMessages(t *testing.T) { }) } -//go:embed testdata -var testdataDir embed.FS - -func TestFindNewMessage(t *testing.T) { - assert.Equal(t, "", st.FindNewMessage("123456", "123456", msgfmt.AgentTypeCustom)) - 
assert.Equal(t, "1234567", st.FindNewMessage("123456", "1234567", msgfmt.AgentTypeCustom)) - assert.Equal(t, "42", st.FindNewMessage("123", "123\n \n \n \n42", msgfmt.AgentTypeCustom)) - assert.Equal(t, "12342", st.FindNewMessage("123", "12342\n \n \n \n", msgfmt.AgentTypeCustom)) - assert.Equal(t, "42", st.FindNewMessage("123", "123\n \n \n \n42\n \n \n \n", msgfmt.AgentTypeCustom)) - assert.Equal(t, "42", st.FindNewMessage("89", "42", msgfmt.AgentTypeCustom)) - - dir := "testdata/diff" - cases, err := testdataDir.ReadDir(dir) - assert.NoError(t, err) - for _, c := range cases { - t.Run(c.Name(), func(t *testing.T) { - before, err := testdataDir.ReadFile(path.Join(dir, c.Name(), "before.txt")) - assert.NoError(t, err) - after, err := testdataDir.ReadFile(path.Join(dir, c.Name(), "after.txt")) - assert.NoError(t, err) - expected, err := testdataDir.ReadFile(path.Join(dir, c.Name(), "expected.txt")) - assert.NoError(t, err) - assert.Equal(t, string(expected), st.FindNewMessage(string(before), string(after), msgfmt.AgentTypeCustom)) - }) - } -} - -func TestPartsToString(t *testing.T) { - assert.Equal(t, "123", st.PartsToString(st.MessagePartText{Content: "123"})) - assert.Equal(t, - "123", - st.PartsToString( - st.MessagePartText{Content: "1"}, - st.MessagePartText{Content: "2"}, - st.MessagePartText{Content: "3"}, - ), - ) - assert.Equal(t, - "123", - st.PartsToString( - st.MessagePartText{Content: "1"}, - st.MessagePartText{Content: "x", Hidden: true}, - st.MessagePartText{Content: "2"}, - st.MessagePartText{Content: "3"}, - st.MessagePartText{Content: "y", Hidden: true}, - ), - ) - assert.Equal(t, - "ab", - st.PartsToString( - st.MessagePartText{Content: "1", Alias: "a"}, - st.MessagePartText{Content: "2", Alias: "b"}, - st.MessagePartText{Content: "3", Alias: "c", Hidden: true}, - ), - ) -} - func TestInitialPromptReadiness(t *testing.T) { now := time.Now() t.Run("agent not ready - status remains changing", func(t *testing.T) { - cfg := st.ConversationConfig{ + 
cfg := st.PTYConversationConfig{ GetTime: func() time.Time { return now }, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 0, @@ -420,10 +360,10 @@ func TestInitialPromptReadiness(t *testing.T) { return message == "ready" }, } - c := st.NewConversation(context.Background(), cfg, "initial prompt here") + c := st.NewPTY(context.Background(), cfg, "initial prompt here") // Fill buffer with stable snapshots, but agent is not ready - c.AddSnapshot("loading...") + c.Snapshot("loading...") // Even though screen is stable, status should be changing because agent is not ready assert.Equal(t, st.ConversationStatusChanging, c.Status()) @@ -432,7 +372,7 @@ func TestInitialPromptReadiness(t *testing.T) { }) t.Run("agent becomes ready - status changes to stable", func(t *testing.T) { - cfg := st.ConversationConfig{ + cfg := st.PTYConversationConfig{ GetTime: func() time.Time { return now }, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 0, @@ -441,14 +381,14 @@ func TestInitialPromptReadiness(t *testing.T) { return message == "ready" }, } - c := st.NewConversation(context.Background(), cfg, "initial prompt here") + c := st.NewPTY(context.Background(), cfg, "initial prompt here") // Agent not ready initially - c.AddSnapshot("loading...") + c.Snapshot("loading...") assert.Equal(t, st.ConversationStatusChanging, c.Status()) // Agent becomes ready - c.AddSnapshot("ready") + c.Snapshot("ready") assert.Equal(t, st.ConversationStatusStable, c.Status()) assert.True(t, c.ReadyForInitialPrompt) assert.False(t, c.InitialPromptSent) @@ -456,7 +396,7 @@ func TestInitialPromptReadiness(t *testing.T) { t.Run("ready for initial prompt lifecycle: false -> true -> false", func(t *testing.T) { agent := &testAgent{screen: "loading..."} - cfg := st.ConversationConfig{ + cfg := st.PTYConversationConfig{ GetTime: func() time.Time { return now }, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 0, @@ -467,23 +407,23 @@ func TestInitialPromptReadiness(t *testing.T) { 
SkipWritingMessage: true, SkipSendMessageStatusCheck: true, } - c := st.NewConversation(context.Background(), cfg, "initial prompt here") + c := st.NewPTY(context.Background(), cfg, "initial prompt here") // Initial state: ReadyForInitialPrompt should be false - c.AddSnapshot("loading...") + c.Snapshot("loading...") assert.False(t, c.ReadyForInitialPrompt, "should start as false") assert.False(t, c.InitialPromptSent) assert.Equal(t, st.ConversationStatusChanging, c.Status()) // Agent becomes ready: ReadyForInitialPrompt should become true agent.screen = "ready" - c.AddSnapshot("ready") + c.Snapshot("ready") assert.Equal(t, st.ConversationStatusStable, c.Status()) assert.True(t, c.ReadyForInitialPrompt, "should become true when ready") assert.False(t, c.InitialPromptSent) // Send the initial prompt - assert.NoError(t, c.SendMessage(st.MessagePartText{Content: "initial prompt here"})) + assert.NoError(t, c.Send(st.MessagePartText{Content: "initial prompt here"})) // After sending initial prompt: ReadyForInitialPrompt should be set back to false // (simulating what happens in the actual server code) @@ -496,7 +436,7 @@ func TestInitialPromptReadiness(t *testing.T) { }) t.Run("no initial prompt - normal status logic applies", func(t *testing.T) { - cfg := st.ConversationConfig{ + cfg := st.PTYConversationConfig{ GetTime: func() time.Time { return now }, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 0, @@ -506,9 +446,9 @@ func TestInitialPromptReadiness(t *testing.T) { }, } // Empty initial prompt means no need to wait for readiness - c := st.NewConversation(context.Background(), cfg, "") + c := st.NewPTY(context.Background(), cfg, "") - c.AddSnapshot("loading...") + c.Snapshot("loading...") // Status should be stable because no initial prompt to wait for assert.Equal(t, st.ConversationStatusStable, c.Status()) @@ -518,7 +458,7 @@ func TestInitialPromptReadiness(t *testing.T) { t.Run("initial prompt sent - normal status logic applies", func(t *testing.T) { 
agent := &testAgent{screen: "ready"} - cfg := st.ConversationConfig{ + cfg := st.PTYConversationConfig{ GetTime: func() time.Time { return now }, SnapshotInterval: 1 * time.Second, ScreenStabilityLength: 0, @@ -529,24 +469,24 @@ func TestInitialPromptReadiness(t *testing.T) { SkipWritingMessage: true, SkipSendMessageStatusCheck: true, } - c := st.NewConversation(context.Background(), cfg, "initial prompt here") + c := st.NewPTY(context.Background(), cfg, "initial prompt here") // First, agent becomes ready - c.AddSnapshot("ready") + c.Snapshot("ready") assert.Equal(t, st.ConversationStatusStable, c.Status()) assert.True(t, c.ReadyForInitialPrompt) assert.False(t, c.InitialPromptSent) // Send the initial prompt agent.screen = "processing..." - assert.NoError(t, c.SendMessage(st.MessagePartText{Content: "initial prompt here"})) + assert.NoError(t, c.Send(st.MessagePartText{Content: "initial prompt here"})) // Mark initial prompt as sent (simulating what the server does) c.InitialPromptSent = true c.ReadyForInitialPrompt = false // Now test that status logic works normally after initial prompt is sent - c.AddSnapshot("processing...") + c.Snapshot("processing...") // Status should be stable because initial prompt was already sent // and the readiness check is bypassed