-
Notifications
You must be signed in to change notification settings - Fork 104
feat: implement state persistence #177
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
e3bd936
e5f1bda
a0f8bb5
ca3cdff
1c224e9
30f82d7
12bed1c
e366e8b
26fdf81
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,18 +34,20 @@ import ( | |
|
|
||
| // Server represents the HTTP server | ||
| type Server struct { | ||
| router chi.Router | ||
| api huma.API | ||
| port int | ||
| srv *http.Server | ||
| mu sync.RWMutex | ||
| logger *slog.Logger | ||
| conversation *st.Conversation | ||
| agentio *termexec.Process | ||
| agentType mf.AgentType | ||
| emitter *EventEmitter | ||
| chatBasePath string | ||
| tempDir string | ||
| router chi.Router | ||
| api huma.API | ||
| port int | ||
| srv *http.Server | ||
| mu sync.RWMutex | ||
| logger *slog.Logger | ||
| conversation *st.PTYConversation | ||
| agentio *termexec.Process | ||
| agentType mf.AgentType | ||
| emitter *EventEmitter | ||
| chatBasePath string | ||
| tempDir string | ||
| statePersistenceConfig StatePersistenceConfig | ||
| stateLoadComplete bool | ||
| } | ||
|
|
||
| func (s *Server) NormalizeSchema(schema any) any { | ||
|
|
@@ -94,14 +96,21 @@ func (s *Server) GetOpenAPI() string { | |
| // because the action of taking a snapshot takes time too. | ||
| const snapshotInterval = 25 * time.Millisecond | ||
|
|
||
| type StatePersistenceConfig struct { | ||
| StateFile string | ||
| LoadState bool | ||
| SaveState bool | ||
| } | ||
|
|
||
| type ServerConfig struct { | ||
| AgentType mf.AgentType | ||
| Process *termexec.Process | ||
| Port int | ||
| ChatBasePath string | ||
| AllowedHosts []string | ||
| AllowedOrigins []string | ||
| InitialPrompt string | ||
| AgentType mf.AgentType | ||
| Process *termexec.Process | ||
| Port int | ||
| ChatBasePath string | ||
| AllowedHosts []string | ||
| AllowedOrigins []string | ||
| InitialPrompt string | ||
| StatePersistenceConfig StatePersistenceConfig | ||
| } | ||
|
|
||
| // Validate allowed hosts don't contain whitespace, commas, schemes, or ports. | ||
|
|
@@ -237,7 +246,7 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { | |
| return mf.FormatToolCall(config.AgentType, message) | ||
| } | ||
|
|
||
| conversation := st.NewConversation(ctx, st.ConversationConfig{ | ||
| conversation := st.NewPTY(ctx, st.PTYConversationConfig{ | ||
| AgentType: config.AgentType, | ||
| AgentIO: config.Process, | ||
| GetTime: func() time.Time { | ||
|
|
@@ -260,16 +269,18 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { | |
| logger.Info("Created temporary directory for uploads", "tempDir", tempDir) | ||
|
|
||
| s := &Server{ | ||
| router: router, | ||
| api: api, | ||
| port: config.Port, | ||
| conversation: conversation, | ||
| logger: logger, | ||
| agentio: config.Process, | ||
| agentType: config.AgentType, | ||
| emitter: emitter, | ||
| chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"), | ||
| tempDir: tempDir, | ||
| router: router, | ||
| api: api, | ||
| port: config.Port, | ||
| conversation: conversation, | ||
| logger: logger, | ||
| agentio: config.Process, | ||
| agentType: config.AgentType, | ||
| emitter: emitter, | ||
| chatBasePath: strings.TrimSuffix(config.ChatBasePath, "/"), | ||
| tempDir: tempDir, | ||
| statePersistenceConfig: config.StatePersistenceConfig, | ||
| stateLoadComplete: false, | ||
| } | ||
|
|
||
| // Register API routes | ||
|
|
@@ -331,26 +342,32 @@ func sseMiddleware(ctx huma.Context, next func(huma.Context)) { | |
| } | ||
|
|
||
| func (s *Server) StartSnapshotLoop(ctx context.Context) { | ||
| s.conversation.StartSnapshotLoop(ctx) | ||
| s.conversation.Start(ctx) | ||
| go func() { | ||
| for { | ||
| currentStatus := s.conversation.Status() | ||
|
|
||
| // Send initial prompt when agent becomes stable for the first time | ||
| if !s.conversation.InitialPromptSent && convertStatus(currentStatus) == AgentStatusStable { | ||
| // Send initial prompt & load state when agent becomes stable for the first time | ||
| if convertStatus(currentStatus) == AgentStatusStable { | ||
|
|
||
| if err := s.conversation.SendMessage(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil { | ||
| s.logger.Error("Failed to send initial prompt", "error", err) | ||
| } else { | ||
| s.conversation.InitialPromptSent = true | ||
| s.conversation.ReadyForInitialPrompt = false | ||
| currentStatus = st.ConversationStatusChanging | ||
| s.logger.Info("Initial prompt sent successfully") | ||
| if !s.stateLoadComplete && s.statePersistenceConfig.LoadState { | ||
| _, _ = s.conversation.LoadState(s.statePersistenceConfig.StateFile) | ||
| s.stateLoadComplete = true | ||
| } | ||
| if !s.conversation.InitialPromptSent { | ||
| if err := s.conversation.Send(FormatMessage(s.agentType, s.conversation.InitialPrompt)...); err != nil { | ||
| s.logger.Error("Failed to send initial prompt", "error", err) | ||
| } else { | ||
| s.conversation.InitialPromptSent = true | ||
| s.conversation.ReadyForInitialPrompt = false | ||
| currentStatus = st.ConversationStatusChanging | ||
| s.logger.Info("Initial prompt sent successfully") | ||
| } | ||
| } | ||
| } | ||
| s.emitter.UpdateStatusAndEmitChanges(currentStatus, s.agentType) | ||
| s.emitter.UpdateMessagesAndEmitChanges(s.conversation.Messages()) | ||
| s.emitter.UpdateScreenAndEmitChanges(s.conversation.Screen()) | ||
| s.emitter.UpdateScreenAndEmitChanges(s.conversation.String()) | ||
| time.Sleep(snapshotInterval) | ||
| } | ||
| }() | ||
|
|
@@ -449,7 +466,7 @@ func (s *Server) createMessage(ctx context.Context, input *MessageRequest) (*Mes | |
|
|
||
| switch input.Body.Type { | ||
| case MessageTypeUser: | ||
| if err := s.conversation.SendMessage(FormatMessage(s.agentType, input.Body.Content)...); err != nil { | ||
| if err := s.conversation.Send(FormatMessage(s.agentType, input.Body.Content)...); err != nil { | ||
| return nil, xerrors.Errorf("failed to send message: %w", err) | ||
| } | ||
| case MessageTypeRaw: | ||
|
|
@@ -610,6 +627,30 @@ func (s *Server) cleanupTempDir() { | |
| } | ||
| } | ||
|
|
||
| // saveAndCleanup saves the conversation state and cleans up before shutdown | ||
| func (s *Server) saveAndCleanup(sig os.Signal, process *termexec.Process) { | ||
| // Save conversation state if configured (synchronously before closing process) | ||
| s.saveStateIfConfigured(sig.String()) | ||
|
|
||
| // Now close the process | ||
| if err := process.Close(s.logger, 5*time.Second); err != nil { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It feels a bit strange to control the process from this "far in". To me it would make sense to invert some of this control, i.e. change how things are wired up in If we close here, won't we likely be logging an error in |
||
| s.logger.Error("Error closing process", "signal", sig, "error", err) | ||
| } | ||
| } | ||
|
|
||
| // saveStateIfConfigured saves the conversation state if configured | ||
| func (s *Server) saveStateIfConfigured(source string) { | ||
| if s.statePersistenceConfig.SaveState && s.statePersistenceConfig.StateFile != "" { | ||
| if err := s.conversation.SaveState(s.conversation.Messages(), s.statePersistenceConfig.StateFile); err != nil { | ||
| s.logger.Error("Failed to save conversation state", "source", source, "error", err) | ||
| } else { | ||
| s.logger.Info("Saved conversation state", "source", source, "stateFile", s.statePersistenceConfig.StateFile) | ||
| } | ||
| } else { | ||
| s.logger.Warn("Save requested but state saving is not configured", "source", source) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Won't this print save requested for regular stop signals like SIGTERM? I think this log is only applicable for USR1. |
||
| } | ||
| } | ||
|
|
||
| // registerStaticFileRoutes sets up routes for serving static files | ||
| func (s *Server) registerStaticFileRoutes() { | ||
| chatHandler := FileServerWithIndexFallback(s.chatBasePath) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,44 @@ | ||
| //go:build unix | ||
|
|
||
| package httpapi | ||
|
|
||
| import ( | ||
| "context" | ||
| "os" | ||
| "os/signal" | ||
| "syscall" | ||
|
|
||
| "github.com/coder/agentapi/lib/termexec" | ||
| ) | ||
|
|
||
| // HandleSignals sets up signal handlers for: | ||
| // - SIGTERM, SIGINT, SIGHUP: save conversation state, then close the process | ||
| // - SIGUSR1: save conversation state without exiting | ||
| func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { | ||
| // Handle shutdown signals (SIGTERM, SIGINT, SIGHUP) | ||
| shutdownCh := make(chan os.Signal, 1) | ||
| signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP) | ||
| go func() { | ||
| defer signal.Stop(shutdownCh) | ||
| sig := <-shutdownCh | ||
| s.logger.Info("Received shutdown signal, saving state before closing process", "signal", sig) | ||
|
|
||
| s.saveAndCleanup(sig, process) | ||
| }() | ||
|
|
||
| // Handle SIGUSR1 for save without exit | ||
| saveOnlyCh := make(chan os.Signal, 1) | ||
| signal.Notify(saveOnlyCh, syscall.SIGUSR1) | ||
| go func() { | ||
35C4n0r marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| defer signal.Stop(saveOnlyCh) | ||
| for { | ||
| select { | ||
| case <-saveOnlyCh: | ||
| s.logger.Info("Received SIGUSR1, saving state without exiting") | ||
| s.saveStateIfConfigured("SIGUSR1") | ||
| case <-ctx.Done(): | ||
| return | ||
| } | ||
| } | ||
| }() | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,27 @@ | ||
| //go:build windows | ||
|
|
||
| package httpapi | ||
|
|
||
| import ( | ||
| "context" | ||
| "os" | ||
| "os/signal" | ||
| "syscall" | ||
|
|
||
| "github.com/coder/agentapi/lib/termexec" | ||
| ) | ||
|
|
||
| // HandleSignals sets up signal handlers for Windows. | ||
| // Only handles SIGTERM and SIGINT (SIGHUP and SIGUSR1 don't exist on Windows). | ||
| func (s *Server) HandleSignals(ctx context.Context, process *termexec.Process) { | ||
| // Handle shutdown signals (SIGTERM, SIGINT only on Windows) | ||
| shutdownCh := make(chan os.Signal, 1) | ||
| signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this compile on Windows? IIRC we can only support |
||
| go func() { | ||
| defer signal.Stop(shutdownCh) | ||
| sig := <-shutdownCh | ||
| s.logger.Info("Received shutdown signal, saving state before closing process", "signal", sig) | ||
|
|
||
| s.saveAndCleanup(sig, process) | ||
| }() | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.