# Patch: add an optional "hard" prove timeout to the server.
#
# What the patch does:
#   * crates/dockerized/src/prover.rs: forwards the ERE_PROVE_TIMEOUT_MS and
#     ERE_PROVE_TIMEOUT_SEC environment variables into the server container.
#   * crates/server/cli/src/commands/server.rs: `run` takes a new
#     `prove_hard_timeout` argument; when it is set, a background task checks
#     `ProveState::started_at` every 5 seconds and calls `std::process::exit(1)`
#     once a prove has been running longer than the hard timeout. Unit tests for
#     `ProveState` / `ProveInFlight` are added.
#   * crates/server/cli/src/main.rs: new `--prove-hard-timeout-sec` argument
#     (env ERE_PROVE_TIMEOUT_SEC), converted to a `Duration` and passed to `run`.
#
# Restoration notes:
#   * The line structure was rebuilt from the hunk headers; every hunk's
#     old/new line counts match the headers.
#   * Generic arguments that had been stripped (`Option,`) were restored as
#     `Option<Duration>` / `Option<u64>`, as required by the
#     `Duration::from_millis` / `Duration::from_secs` call sites in main.rs.
#   * NOTE(review): leading indentation was reconstructed following rustfmt
#     conventions. If context lines fail to match, apply with
#     `git apply --ignore-whitespace`.
#
# Review notes on the change itself:
#   * ERE_PROVE_TIMEOUT_SEC differs from ERE_PROVE_TIMEOUT_MS only by its unit
#     suffix, yet it configures a different (process-killing) timeout, and it
#     does not match the field name `prove_hard_timeout_sec`. Consider
#     ERE_PROVE_HARD_TIMEOUT_SEC (prover.rs and main.rs must change together).
#   * The watchdog polls on a fixed 5-second interval, so termination can lag
#     the configured hard timeout by up to roughly one poll period.
#   * NOTE(review): the watchdog is a tokio task. If a stuck prove blocks the
#     runtime's worker threads, the task may never be polled; confirm how prove
#     is executed, or run the watchdog on a dedicated OS thread.
#   * `std::process::exit` does not run destructors; that looks intentional
#     here, since the goal is to force a container restart.
#   * `test_prove_state_timeout_detection` depends on a real 150 ms sleep
#     against a 100 ms timeout.
diff --git a/crates/dockerized/src/prover.rs b/crates/dockerized/src/prover.rs
index f3e5bde1..556f1084 100644
--- a/crates/dockerized/src/prover.rs
+++ b/crates/dockerized/src/prover.rs
@@ -205,6 +205,8 @@ impl ServerContainer {
             .inherit_env("RUST_LOG")
             .inherit_env("RUST_BACKTRACE")
             .inherit_env("NO_COLOR")
+            .inherit_env("ERE_PROVE_TIMEOUT_MS")
+            .inherit_env("ERE_PROVE_TIMEOUT_SEC")
             .publish(port.to_string(), port.to_string())
             .name(&name);
 
diff --git a/crates/server/cli/src/commands/server.rs b/crates/server/cli/src/commands/server.rs
index edcaf328..e59afea2 100644
--- a/crates/server/cli/src/commands/server.rs
+++ b/crates/server/cli/src/commands/server.rs
@@ -24,7 +24,7 @@ use tokio::{
 };
 use tower::ServiceBuilder;
 use tower_http::{catch_panic::CatchPanicLayer, trace::TraceLayer};
-use tracing::info;
+use tracing::{error, info};
 use twirp::{
     Request, Response, Router, TwirpErrorResponse,
     async_trait::async_trait,
@@ -41,6 +41,7 @@ pub async fn run(
     elf: Elf,
     resource: ProverResource,
     prove_timeout: Option<Duration>,
+    prove_hard_timeout: Option<Duration>,
 ) -> Result<(), Error> {
     let resource_kind = resource.kind();
     let zkvm = crate::construct_zkvm(elf, resource)?;
@@ -53,6 +54,32 @@ pub async fn run(
     let prove_state = Arc::new(ProveState::new(prove_timeout));
     let server = Arc::new(zkVMServer::new(zkvm, Arc::clone(&prove_state)));
 
+    // Spawn global watchdog for hard timeout if configured
+    if let Some(hard_timeout) = prove_hard_timeout {
+        let prove_state_clone = Arc::clone(&prove_state);
+        tokio::spawn(async move {
+            let mut interval = tokio::time::interval(Duration::from_secs(5));
+            loop {
+                interval.tick().await;
+
+                if let Some(started) = *prove_state_clone.started_at.lock() {
+                    let elapsed = started.elapsed();
+                    if elapsed > hard_timeout {
+                        error!(
+                            "Prove exceeded hard timeout of {:?} (elapsed: {:?}), terminating server",
+                            hard_timeout, elapsed
+                        );
+                        std::process::exit(1);
+                    }
+                }
+            }
+        });
+        info!(
+            "Global prove watchdog enabled with hard timeout of {:?}",
+            hard_timeout
+        );
+    }
+
     let api_middleware = ServiceBuilder::new()
         .layer(
             TraceLayer::new_for_http()
@@ -307,3 +334,78 @@ async fn shutdown_signal() {
 fn serialize_report_err(err: bincode::error::EncodeError) -> TwirpErrorResponse {
     internal(format!("failed to serialize report: {err}"))
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::thread;
+
+    #[test]
+    fn test_prove_state_timeout_detection() {
+        // Test soft timeout detection (used by health check)
+        let timeout = Duration::from_millis(100);
+        let prove_state = ProveState::new(Some(timeout));
+
+        // Initially, no prove is running
+        assert!(!prove_state.is_timeout());
+
+        // Simulate prove start
+        {
+            *prove_state.started_at.lock() = Some(Instant::now());
+        }
+
+        // Immediately after start, should not timeout
+        assert!(!prove_state.is_timeout());
+
+        // Wait for timeout to expire
+        thread::sleep(Duration::from_millis(150));
+
+        // Now it should timeout
+        assert!(prove_state.is_timeout());
+
+        // Clear the started_at (simulating prove completion)
+        {
+            *prove_state.started_at.lock() = None;
+        }
+
+        // After clearing, should not timeout
+        assert!(!prove_state.is_timeout());
+    }
+
+    #[test]
+    fn test_prove_state_no_timeout_configured() {
+        // When no timeout is configured, is_timeout should always return false
+        let prove_state = ProveState::new(None);
+
+        assert!(!prove_state.is_timeout());
+
+        // Even with a prove running
+        {
+            *prove_state.started_at.lock() = Some(Instant::now());
+        }
+
+        assert!(!prove_state.is_timeout());
+    }
+
+    #[test]
+    fn test_prove_in_flight_guard() {
+        // Test that ProveInFlight guard correctly sets and clears started_at
+        let prove_state = Arc::new(ProveState::new(None));
+
+        // Initially no prove running
+        assert!(prove_state.started_at.lock().is_none());
+
+        {
+            // Create guard (simulates prove start)
+            let _guard = ProveInFlight::new(Arc::clone(&prove_state));
+
+            // started_at should be set
+            assert!(prove_state.started_at.lock().is_some());
+
+            // Guard is dropped here
+        }
+
+        // After guard drop, started_at should be cleared
+        assert!(prove_state.started_at.lock().is_none());
+    }
+}
diff --git a/crates/server/cli/src/main.rs b/crates/server/cli/src/main.rs
index c07adf0b..4a03fe00 100644
--- a/crates/server/cli/src/main.rs
+++ b/crates/server/cli/src/main.rs
@@ -44,6 +44,11 @@ struct Args {
     /// milliseconds. Disabled when not set.
     #[arg(long, env = "ERE_PROVE_TIMEOUT_MS")]
     prove_timeout_ms: Option<u64>,
+    /// Hard timeout: terminate server process if a single prove has been running longer than
+    /// this many seconds. Disabled when not set. This forces container restart to recover from
+    /// deadlocked provers.
+    #[arg(long, env = "ERE_PROVE_TIMEOUT_SEC")]
+    prove_hard_timeout_sec: Option<u64>,
     #[command(
         flatten,
         next_help_heading = "ELF source (read from stdin if none set)"
@@ -99,7 +104,9 @@ async fn main() -> Result<(), Error> {
     match args.command {
         Command::Server(resource) => {
             let prove_timeout = args.prove_timeout_ms.map(Duration::from_millis);
-            commands::server::run(args.port, elf, resource, prove_timeout).await?
+            let prove_hard_timeout = args.prove_hard_timeout_sec.map(Duration::from_secs);
+            commands::server::run(args.port, elf, resource, prove_timeout, prove_hard_timeout)
+                .await?
         }
         Command::Keygen { program_vk_path } => commands::keygen::run(elf, &program_vk_path)?,
     }