diff --git a/Cargo.lock b/Cargo.lock index 32da2cc26..af1a2d6ae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1450,6 +1450,26 @@ dependencies = [ "tiny-keccak", ] +[[package]] +name = "const_format" +version = "0.2.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7faa7469a93a566e9ccc1c73fe783b4a65c274c5ace346038dca9c39fe0030ad" +dependencies = [ + "const_format_proc_macros", +] + +[[package]] +name = "const_format_proc_macros" +version = "0.2.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d57c2eccfb16dbac1f4e61e206105db5820c9d26c3c472bc17c774259ef7744" +dependencies = [ + "proc-macro2", + "quote", + "unicode-xid", +] + [[package]] name = "constant_time_eq" version = "0.4.2" @@ -3388,6 +3408,7 @@ dependencies = [ "chrono-tz", "clap", "clap_complete", + "const_format", "cron", "crossterm 0.28.1", "deadpool-postgres", diff --git a/Cargo.toml b/Cargo.toml index c3f21efe8..c517139d8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -176,6 +176,7 @@ hex = "0.4.3" rusqlite = { version = "0.32", optional = true } json5 = { version = "0.4", optional = true } tempfile = { version = "3", optional = true } +const_format = { version = "0.2.35", default-features = false } # macOS keychain [target.'cfg(target_os = "macos")'.dependencies] diff --git a/docs/contents.md b/docs/contents.md index 60cc641cb..2ac61acaf 100644 --- a/docs/contents.md +++ b/docs/contents.md @@ -43,6 +43,9 @@ provider interfaces, adapters, and the places embeddings are used. - [Jobs and routines](jobs-and-routines.md) covers the scheduler, background jobs, routines engine, touchpoints, and extension seams. +- [Worker-orchestrator contract](worker-orchestrator-contract.md) documents the + sandbox worker HTTP boundary, the shared route constants, and the reporting + split between authoritative status and best-effort events. - [Agent skills support](agent-skills-support.md) explains how skills are discovered, installed, selected, and injected into model context. - [Smart routing spec](smart-routing-spec.md) captures the current design for diff --git a/docs/worker-orchestrator-contract.md b/docs/worker-orchestrator-contract.md new file mode 100644 index 000000000..87a1403cb --- /dev/null +++ b/docs/worker-orchestrator-contract.md @@ -0,0 +1,156 @@ +# Worker-orchestrator contract + +This document is for maintainers who need to change the hosted worker path +without breaking the orchestrator boundary. It explains the current transport +contract, the dependency-injection seams used for tests and production, and the +event-reporting split between authoritative state and best-effort visibility. + +## 1. Scope and source of truth + +The worker runtime runs inside a container and talks back to the orchestrator +over HTTP. The shared transport contract lives in +[the worker API module](../src/worker/api.rs) and +[its shared types](../src/worker/api/types.rs). The orchestrator side imports +the same route constants and payload types from that module instead of +re-declaring them. + +This document is descriptive for the current implementation. The code remains +the authoritative source of truth for the wire format. + +## 2. Boundary model + +The worker and orchestrator have distinct responsibilities: + +- The orchestrator owns job lifecycle, credential issuance, event ingestion, + and proxied external access. +- The worker owns local reasoning, tool execution inside the sandbox, and + periodic reporting back to the orchestrator. +- The shared HTTP boundary exists so the worker can stay isolated from the + host process while still using the host's approved network, credential, and + observability surfaces. + +Figure 1. Worker-orchestrator boundary and reporting channels. + +```mermaid +flowchart LR + WorkerRuntime[WorkerRuntime in container] + Delegate[ContainerDelegate] + Client[WorkerHttpClient] + Orchestrator[Orchestrator worker API] + Timeline[Job event timeline] + Status[Authoritative status store] + Completion[Authoritative completion store] + + WorkerRuntime --> Delegate + WorkerRuntime --> Client + Delegate --> Client + Client --> Orchestrator + Orchestrator --> Timeline + Orchestrator --> Status + Orchestrator --> Completion +``` + +## 3. Shared route constants + +All worker endpoints are declared once in `src/worker/api/types.rs` as paired +`*_PATH` and `*_ROUTE` constants. The design intent is: + +- `*_PATH` is the relative suffix used by `WorkerHttpClient`. +- `*_ROUTE` is the fully scoped Axum route used by the orchestrator router. +- Both sides derive their concrete URLs from the same source strings. + +The current contract includes: + +Worker-orchestrator HTTP endpoints and their purposes. + +| Endpoint | Purpose | +| --- | --- | +| `GET /worker/{job_id}/job` | Fetch the sandboxed job description | +| `GET /worker/{job_id}/credentials` | Deliver job-scoped credentials for child-process injection | +| `POST /worker/{job_id}/status` | Persist authoritative progress state | +| `POST /worker/{job_id}/complete` | Persist authoritative terminal outcome | +| `POST /worker/{job_id}/event` | Append user-visible timeline events | +| `GET /worker/{job_id}/prompt` | Poll orchestrator-injected follow-up prompts | +| `POST /worker/{job_id}/llm/complete` | Proxy plain language model (LLM) completion | +| `POST /worker/{job_id}/llm/complete_with_tools` | Proxy tool-capable language model (LLM) completion | +| `GET /worker/{job_id}/tools/catalog` | Fetch hosted-visible remote tool definitions | +| `POST /worker/{job_id}/tools/execute` | Execute a hosted remote tool through the orchestrator | + +Compile-time assertions in the worker API tests lock the canonical route values +so accidental path drift fails the build before runtime tests execute. + +## 4. Dependency injection and construction + +`WorkerRuntime` uses two constructors with distinct roles: + +- `WorkerRuntime::new(config, client)` is the primary constructor. It is used + by tests and by any caller that already owns a prepared `WorkerHttpClient`. +- `WorkerRuntime::from_env(config)` is the production convenience wrapper. It + reads `IRONCLAW_WORKER_TOKEN` and then delegates to `new`, which builds the + HTTP client with the shared timeout and error mapping. + +This split exists so tests can validate runtime behaviour without relying on +ambient environment state. It also gives construction-time validation one +obvious home: `new` checks that `WorkerConfig` and `WorkerHttpClient` agree on +job identity and orchestrator base URL before the runtime starts. + +`WorkerHttpClient::new(...)` follows the same pattern for tests, while +`WorkerHttpClient::from_env(...)` is reserved for production bootstrap. + +## 5. Authoritative reports versus best-effort events + +The worker emits two classes of outbound signal: + +- Authoritative reports: + - `report_status` + - `report_complete` +- Best-effort timeline events: + - `post_event` + - `report_status_lossy` + +The distinction matters: + +- Status and completion calls define the durable job record. If they fail at a + point where correctness depends on them, the worker treats that as a real + error. +- Event posting exists for operator visibility. It enriches the browser and + audit timeline, but it must not be allowed to block or invalidate terminal + completion reporting. + +`ContainerDelegate` therefore, uses a background task and bounded queue for +event posting. `shutdown()` closes the queue and waits for the event worker, so +buffered events flush before the delegate disappears. + +`WorkerRuntime::post_event(...)` also uses a bounded timeout around terminal +event publication, so the final `report_complete(...)` call remains the +authoritative acknowledgement path. + +## 6. Credential handling + +Credentials are fetched through `GET /worker/{job_id}/credentials` and +deserialized into `CredentialResponse`. The worker runtime does not write them +into global process environment variables. Instead: + +1. `WorkerRuntime::hydrate_credentials()` fetches the granted credentials. +2. The runtime stores them in `extra_env`. +3. Tool execution passes `extra_env` through `JobContext` into child processes. + +This keeps credential scope limited to the worker execution path and avoids +cross-test or cross-job global environment mutation. + +## 7. Prompt polling and hosted tool context + +The worker loop polls `GET /worker/{job_id}/prompt` before LLM calls. The +orchestrator can use that channel to inject operator prompts or follow-up work +without restarting the worker process. + +Hosted remote tools use a parallel mechanism: + +1. The worker fetches the hosted tool catalogue from the orchestrator. +2. The worker registers local proxy wrappers using the orchestrator-provided + canonical `ToolDefinition` values. +3. The runtime merges those definitions into the reasoning context alongside + container-local tools. + +The shared route constants and transport types are what keep that hosted tool +surface consistent across the sandbox boundary. diff --git a/src/error/mod.rs b/src/error/mod.rs index 473b017df..c4a1087db 100644 --- a/src/error/mod.rs +++ b/src/error/mod.rs @@ -25,7 +25,7 @@ pub use self::repair::RepairError; pub use self::routine::RoutineError; pub use self::safety::SafetyError; pub use self::tool::ToolError; -pub use self::worker::WorkerError; +pub use self::worker::{ConfigMismatchField, WorkerError}; pub use self::workspace::WorkspaceError; pub use crate::llm::error::LlmError; diff --git a/src/error/worker.rs b/src/error/worker.rs index 918b241eb..615b81e39 100644 --- a/src/error/worker.rs +++ b/src/error/worker.rs @@ -4,6 +4,24 @@ use std::time::Duration; use uuid::Uuid; +/// Configuration field that mismatched between worker config and HTTP client. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ConfigMismatchField { + /// The job_id field mismatched. + JobId, + /// The orchestrator_url field mismatched. + OrchestratorUrl, +} + +impl std::fmt::Display for ConfigMismatchField { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::JobId => write!(f, "job_id"), + Self::OrchestratorUrl => write!(f, "orchestrator_url"), + } + } +} + /// Worker errors (container-side execution). #[derive(Debug, thiserror::Error)] pub enum WorkerError { @@ -76,4 +94,14 @@ pub enum WorkerError { /// The worker token environment variable was not available at startup. #[error("Missing worker token (IRONCLAW_WORKER_TOKEN not set)")] MissingToken, + + /// The worker configuration does not match the provided HTTP client. + /// + /// `field` identifies which configuration field mismatched, and `reason` + /// describes the mismatch. + #[error("Worker configuration mismatch for {field}: {reason}")] + ConfigMismatch { + field: ConfigMismatchField, + reason: String, + }, } diff --git a/src/orchestrator/api.rs b/src/orchestrator/api.rs index 3b5838f9d..02d77c19f 100644 --- a/src/orchestrator/api.rs +++ b/src/orchestrator/api.rs @@ -24,7 +24,11 @@ mod handler_support; mod handlers; mod remote_tools; -use crate::worker::api::{REMOTE_TOOL_CATALOG_ROUTE, REMOTE_TOOL_EXECUTE_ROUTE}; +use crate::worker::api::{ + COMPLETE_ROUTE, CREDENTIALS_ROUTE, EVENT_ROUTE, JOB_ROUTE, LLM_COMPLETE_ROUTE, + LLM_COMPLETE_WITH_TOOLS_ROUTE, PROMPT_ROUTE, REMOTE_TOOL_CATALOG_ROUTE, + REMOTE_TOOL_EXECUTE_ROUTE, STATUS_ROUTE, WORKER_HEALTH_ROUTE, +}; use handler_support::{get_credentials_handler, get_prompt_handler}; use handlers::{ execute_remote_tool, get_job, get_remote_tool_catalog, health_check, job_event_handler, @@ -65,25 +69,22 @@ impl OrchestratorApi { pub fn router(state: OrchestratorState) -> Router { Router::new() // Worker routes: authenticated via route_layer middleware. - .route("/worker/{job_id}/job", get(get_job)) - .route("/worker/{job_id}/llm/complete", post(llm_complete)) - .route( - "/worker/{job_id}/llm/complete_with_tools", - post(llm_complete_with_tools), - ) + .route(JOB_ROUTE, get(get_job)) + .route(LLM_COMPLETE_ROUTE, post(llm_complete)) + .route(LLM_COMPLETE_WITH_TOOLS_ROUTE, post(llm_complete_with_tools)) .route(REMOTE_TOOL_CATALOG_ROUTE, get(get_remote_tool_catalog)) .route(REMOTE_TOOL_EXECUTE_ROUTE, post(execute_remote_tool)) - .route("/worker/{job_id}/status", post(report_status)) - .route("/worker/{job_id}/complete", post(report_complete)) - .route("/worker/{job_id}/event", post(job_event_handler)) - .route("/worker/{job_id}/prompt", get(get_prompt_handler)) - .route("/worker/{job_id}/credentials", get(get_credentials_handler)) + .route(STATUS_ROUTE, post(report_status)) + .route(COMPLETE_ROUTE, post(report_complete)) + .route(EVENT_ROUTE, post(job_event_handler)) + .route(PROMPT_ROUTE, get(get_prompt_handler)) + .route(CREDENTIALS_ROUTE, get(get_credentials_handler)) .route_layer(axum::middleware::from_fn_with_state( state.token_store.clone(), worker_auth_middleware, )) // Unauthenticated routes (added after the layer). - .route("/health", get(health_check)) + .route(WORKER_HEALTH_ROUTE, get(health_check)) .with_state(state) } diff --git a/src/tools/builtin/worker_remote_tool_proxy/tests.rs b/src/tools/builtin/worker_remote_tool_proxy/tests.rs index 0e7eb7053..889f11ea2 100644 --- a/src/tools/builtin/worker_remote_tool_proxy/tests.rs +++ b/src/tools/builtin/worker_remote_tool_proxy/tests.rs @@ -65,11 +65,10 @@ async fn proxy_test_server() -> anyhow::Result { let _ = axum::serve(listener, router).await; }); let job_id = Uuid::new_v4(); - let client = Arc::new(WorkerHttpClient::new( - format!("http://{}", addr), - job_id, - "test-token".to_string(), - )); + let client = Arc::new( + WorkerHttpClient::new(format!("http://{}", addr), job_id, "test-token".to_string()) + .context("test client should build")?, + ); Ok(ProxyTestServer { client, job_id, @@ -183,11 +182,14 @@ async fn worker_remote_tool_proxy_preserves_full_tool_definition_fields() { "Complex tool for proxy definition fidelity testing", ); - let client = Arc::new(WorkerHttpClient::new( - "http://127.0.0.1:0".to_string(), - Uuid::new_v4(), - "test-token".to_string(), - )); + let client = Arc::new( + WorkerHttpClient::new( + "http://127.0.0.1:0".to_string(), + Uuid::new_v4(), + "test-token".to_string(), + ) + .expect("test client should build"), + ); let proxy = WorkerRemoteToolProxy::new(complex_definition.clone(), client); let reconstructed = ToolDefinition { @@ -243,11 +245,10 @@ async fn worker_remote_tool_proxy_routes_execution_through_orchestrator_endpoint }); let job_id = Uuid::new_v4(); - let client = Arc::new(WorkerHttpClient::new( - format!("http://{}", addr), - job_id, - "test-token".to_string(), - )); + let client = Arc::new( + WorkerHttpClient::new(format!("http://{}", addr), job_id, "test-token".to_string()) + .context("test client should build")?, + ); let proxy = WorkerRemoteToolProxy::new( ToolDefinition { name: "route_test_tool".to_string(), @@ -272,7 +273,7 @@ async fn worker_remote_tool_proxy_routes_execution_through_orchestrator_endpoint let (route_path, received_job_id, tool_name) = &requests[0]; assert_eq!( route_path, - &format!("/worker/{}/tools/execute", job_id), + &REMOTE_TOOL_EXECUTE_ROUTE.replace("{job_id}", &job_id.to_string()), "proxy must route execution through the correct orchestrator endpoint" ); assert_eq!(received_job_id, &job_id); diff --git a/src/worker/api.rs b/src/worker/api.rs index c2646ec58..83fb89d17 100644 --- a/src/worker/api.rs +++ b/src/worker/api.rs @@ -12,19 +12,27 @@ use crate::llm::{ }; use crate::tools::ToolOutput; +mod client_methods; mod types; -use error_mapping::map_remote_tool_status; - pub use types::{ - CompletionReport, CredentialResponse, FinishReason as ProxyFinishReason, JobDescription, - JobEventPayload, JobEventType, PromptResponse, ProxyCompletionRequest, ProxyCompletionResponse, + COMPLETE_PATH, COMPLETE_ROUTE, CREDENTIALS_PATH, CREDENTIALS_ROUTE, CompletionReport, + CredentialResponse, EVENT_PATH, EVENT_ROUTE, FinishReason as ProxyFinishReason, JOB_PATH, + JOB_ROUTE, JobDescription, JobEventPayload, JobEventType, LLM_COMPLETE_PATH, + LLM_COMPLETE_ROUTE, LLM_COMPLETE_WITH_TOOLS_PATH, LLM_COMPLETE_WITH_TOOLS_ROUTE, PROMPT_PATH, + PROMPT_ROUTE, PromptResponse, ProxyCompletionRequest, ProxyCompletionResponse, ProxyToolCompletionRequest, ProxyToolCompletionResponse, REMOTE_TOOL_CATALOG_PATH, REMOTE_TOOL_CATALOG_ROUTE, REMOTE_TOOL_EXECUTE_PATH, REMOTE_TOOL_EXECUTE_ROUTE, RemoteToolCatalogResponse, RemoteToolExecutionRequest, RemoteToolExecutionResponse, - StatusUpdate, WorkerState, + STATUS_PATH, STATUS_ROUTE, StatusUpdate, TerminalResult, WORKER_HEALTH_PATH, + WORKER_HEALTH_ROUTE, WorkerState, job_scoped_path, worker_job_url, }; /// HTTP client that a container worker uses to talk to the orchestrator. +/// +/// This client is the worker-side transport boundary for authoritative job +/// state. It owns the per-job base URL, attaches the worker bearer token on +/// every request, and keeps the worker implementation aligned with the +/// orchestrator's shared route constants and JSON payloads. pub struct WorkerHttpClient { client: reqwest::Client, orchestrator_url: String, @@ -40,41 +48,54 @@ impl WorkerHttpClient { pub fn from_env(orchestrator_url: String, job_id: Uuid) -> Result { let token = std::env::var("IRONCLAW_WORKER_TOKEN").map_err(|_| WorkerError::MissingToken)?; - - Ok(Self { - client: reqwest::Client::builder() - .timeout(REQUEST_TIMEOUT) - .build() - .map_err(|e| WorkerError::ConnectionFailed { - url: orchestrator_url.clone(), - reason: format!("failed to build HTTP client: {}", e), - })?, - orchestrator_url: orchestrator_url.trim_end_matches('/').to_string(), - job_id, - token, - }) + Self::new(orchestrator_url, job_id, token) } /// Create with an explicit token (for testing). - pub fn new(orchestrator_url: String, job_id: Uuid, token: String) -> Self { - Self { - client: reqwest::Client::builder() - .timeout(REQUEST_TIMEOUT) - .build() - .unwrap_or_default(), + /// + /// This constructor exists so tests and injected runtimes can avoid + /// ambient environment reads while still exercising the same request path + /// and route construction as production workers. + pub fn new(orchestrator_url: String, job_id: Uuid, token: String) -> Result { + let client = reqwest::Client::builder() + .timeout(REQUEST_TIMEOUT) + .build() + .map_err(|e| WorkerError::ConnectionFailed { + url: orchestrator_url.clone(), + reason: format!("failed to build HTTP client: {}", e), + })?; + Ok(Self { + client, orchestrator_url: orchestrator_url.trim_end_matches('/').to_string(), job_id, token, - } + }) } /// Get the base orchestrator URL. + /// + /// Returns the normalized base URL after trailing-slash trimming, which is + /// the canonical prefix used for all job-scoped worker endpoints. pub fn orchestrator_url(&self) -> &str { &self.orchestrator_url } + /// Get the job ID. + /// + /// The worker and orchestrator treat this as part of the transport + /// identity. Every request path is derived from this value. + pub fn job_id(&self) -> Uuid { + self.job_id + } + fn url(&self, path: &str) -> String { - format!("{}/worker/{}/{}", self.orchestrator_url, self.job_id, path) + let base = self.orchestrator_url.trim_end_matches('/'); + format!( + "{}/{}", + base, + crate::worker::api::job_scoped_path(&self.job_id.to_string(), path) + .trim_start_matches('/') + ) } async fn send_post_json( @@ -100,14 +121,15 @@ impl WorkerHttpClient { path: &str, context: &str, ) -> Result { + let url = self.url(path); let resp = self .client - .get(self.url(path)) + .get(&url) .bearer_auth(&self.token) .send() .await .map_err(|e| WorkerError::ConnectionFailed { - url: self.orchestrator_url.clone(), + url, reason: e.to_string(), })?; @@ -147,7 +169,7 @@ impl WorkerHttpClient { /// Fetch the job description from the orchestrator. pub async fn get_job(&self) -> Result { - self.get_json("job", "GET /job").await + self.get_json(JOB_PATH, "GET /job").await } /// Proxy an LLM completion request through the orchestrator. @@ -164,7 +186,7 @@ impl WorkerHttpClient { }; let proxy_resp: ProxyCompletionResponse = self - .post_json("llm/complete", &proxy_req, "LLM complete") + .post_json(LLM_COMPLETE_PATH, &proxy_req, "LLM complete") .await?; Ok(CompletionResponse { @@ -192,7 +214,11 @@ impl WorkerHttpClient { }; let proxy_resp: ProxyToolCompletionResponse = self - .post_json("llm/complete_with_tools", &proxy_req, "LLM tool complete") + .post_json( + LLM_COMPLETE_WITH_TOOLS_PATH, + &proxy_req, + "LLM tool complete", + ) .await?; Ok(ToolCompletionResponse { @@ -242,159 +268,36 @@ impl WorkerHttpClient { Ok(proxy_resp.output) } +} - /// Report status to the orchestrator. - pub async fn report_status(&self, update: &StatusUpdate) -> Result<(), WorkerError> { - let resp = self - .client - .post(self.url("status")) - .bearer_auth(&self.token) - .json(update) - .send() - .await - .map_err(|e| WorkerError::ConnectionFailed { - url: self.orchestrator_url.clone(), - reason: e.to_string(), - })?; - - if !resp.status().is_success() { - let status = resp.status(); - let body = resp.text().await.unwrap_or_default(); - return Err(WorkerError::OrchestratorRejected { - job_id: self.job_id, - reason: format!("status endpoint returned {}: {}", status, body), - }); - } - - Ok(()) - } - - /// Report a non-terminal status update without failing the worker on rejection. - pub async fn report_status_lossy(&self, update: &StatusUpdate) { - if let Err(error) = self.report_status(update).await { - tracing::warn!( - job_id = %self.job_id, - state = %update.state, - iteration = update.iteration, - error = %error, - "Worker status report failed" - ); - } - } - - /// Post a job event to the orchestrator (fire-and-forget style, logs on failure). - pub async fn post_event(&self, payload: &JobEventPayload) { - let resp = self - .client - .post(self.url("event")) - .bearer_auth(&self.token) - .json(payload) - .send() - .await; - - match resp { - Ok(r) if !r.status().is_success() => { - tracing::debug!( - job_id = %self.job_id, - event_type = %payload.event_type, - status = %r.status(), - "Job event POST rejected" - ); - } - Err(e) => { - tracing::debug!( - job_id = %self.job_id, - event_type = %payload.event_type, - "Job event POST failed: {}", e - ); - } - _ => {} - } - } - - /// Poll the orchestrator for a follow-up prompt. - /// - /// Returns `None` if no prompt is available (204 No Content). - pub async fn poll_prompt(&self) -> Result, WorkerError> { - let resp = self - .client - .get(self.url("prompt")) - .bearer_auth(&self.token) - .send() - .await - .map_err(|e| WorkerError::ConnectionFailed { - url: self.orchestrator_url.clone(), - reason: e.to_string(), - })?; - - if resp.status() == reqwest::StatusCode::NO_CONTENT { - return Ok(None); - } - - if !resp.status().is_success() { - return Err(WorkerError::OrchestratorRejected { - job_id: self.job_id, - reason: format!("prompt endpoint returned {}", resp.status()), - }); - } - - let prompt: PromptResponse = - resp.json().await.map_err(|e| WorkerError::LlmProxyFailed { - reason: format!("failed to parse prompt response: {}", e), - })?; - - Ok(Some(prompt)) - } - - /// Fetch credentials granted to this job from the orchestrator. - /// - /// Returns an empty vec if no credentials are granted (204 No Content) - /// or if the endpoint returns 404. The caller should set each credential - /// as an environment variable before starting the execution loop. - pub async fn fetch_credentials(&self) -> Result, WorkerError> { - let resp = self - .client - .get(self.url("credentials")) - .bearer_auth(&self.token) - .send() - .await - .map_err(|e| WorkerError::ConnectionFailed { - url: self.orchestrator_url.clone(), - reason: e.to_string(), - })?; - - // 204 or 404 means no credentials granted, not an error - if resp.status() == reqwest::StatusCode::NO_CONTENT - || resp.status() == reqwest::StatusCode::NOT_FOUND - { - return Ok(vec![]); - } - - if !resp.status().is_success() { - return Err(WorkerError::SecretResolveFailed { - secret_name: "(all)".to_string(), - reason: format!("credentials endpoint returned {}", resp.status()), - }); +/// Map HTTP response status to appropriate WorkerError for remote tool execution. +async fn map_remote_tool_status(resp: reqwest::Response) -> WorkerError { + let status = resp.status(); + let retry_after = resp + .headers() + .get("retry-after") + .and_then(|value| value.to_str().ok()) + .and_then(|value| value.parse::().ok()) + .map(std::time::Duration::from_secs); + let body = resp.text().await.unwrap_or_default(); + let reason = format!( + "Remote tool execution: orchestrator returned {}: {}", + status, body + ); + + match status { + reqwest::StatusCode::BAD_REQUEST => WorkerError::BadRequest { reason }, + reqwest::StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => { + WorkerError::Unauthorized { reason } } - - resp.json() - .await - .map_err(|e| WorkerError::SecretResolveFailed { - secret_name: "(all)".to_string(), - reason: format!("failed to parse credentials response: {}", e), - }) - } - - /// Signal job completion to the orchestrator. - pub async fn report_complete(&self, report: &CompletionReport) -> Result<(), WorkerError> { - let _: serde_json::Value = self - .post_json("complete", report, "report complete") - .await?; - Ok(()) + reqwest::StatusCode::TOO_MANY_REQUESTS => WorkerError::RateLimited { + reason, + retry_after, + }, + reqwest::StatusCode::BAD_GATEWAY => WorkerError::BadGateway { reason }, + _ => WorkerError::RemoteToolFailed { reason }, } } #[cfg(test)] mod tests; - -mod error_mapping; diff --git a/src/worker/api/client_methods.rs b/src/worker/api/client_methods.rs new file mode 100644 index 000000000..91d89aa5a --- /dev/null +++ b/src/worker/api/client_methods.rs @@ -0,0 +1,181 @@ +//! Additional WorkerHttpClient methods for status reporting, events, and credentials. + +use serde::Serialize; + +use crate::error::WorkerError; +use crate::worker::api::{ + COMPLETE_PATH, CREDENTIALS_PATH, CompletionReport, CredentialResponse, EVENT_PATH, + JobEventPayload, PROMPT_PATH, PromptResponse, STATUS_PATH, StatusUpdate, +}; + +use super::WorkerHttpClient; + +impl WorkerHttpClient { + /// Send a POST request with a JSON payload and require a 2xx response. + /// + /// Maps transport failures to `WorkerError::ConnectionFailed` and + /// non-success HTTP responses to `WorkerError::OrchestratorRejected`. + async fn post_and_require_success( + &self, + path: &str, + payload: &T, + ) -> Result<(), WorkerError> { + let url = self.url(path); + let resp = self + .client + .post(&url) + .bearer_auth(&self.token) + .json(payload) + .send() + .await + .map_err(|e| WorkerError::ConnectionFailed { + url: url.clone(), + reason: e.to_string(), + })?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(WorkerError::OrchestratorRejected { + job_id: self.job_id, + reason: format!("{} endpoint returned {}: {}", path, status, body), + }); + } + + Ok(()) + } + + /// Report status to the orchestrator. + /// + /// Status updates are the authoritative progress signal for the hosted + /// worker lifecycle. Callers should use this when rejection should abort + /// execution, such as startup and terminal reporting. + pub async fn report_status(&self, update: &StatusUpdate) -> Result<(), WorkerError> { + self.post_and_require_success(STATUS_PATH, update).await + } + + /// Report a non-terminal status update without failing the worker on rejection. + /// + /// This is intended for opportunistic progress updates during long-running + /// loops. It preserves observability when the orchestrator accepts the + /// update, but it intentionally does not let transient reporting failures + /// derail the worker's primary execution flow. + pub async fn report_status_lossy(&self, update: &StatusUpdate) { + if let Err(error) = self.report_status(update).await { + tracing::warn!( + job_id = %self.job_id, + state = %update.state, + iteration = update.iteration, + error = %error, + "Worker status report failed" + ); + } + } + + /// Post a job event to the orchestrator. + /// + /// Returns `Ok(())` on success, `WorkerError::ConnectionFailed` if the + /// request fails at the transport layer, or `WorkerError::OrchestratorRejected` + /// if the endpoint returns a non-2xx status. These events feed the + /// orchestrator's user-visible job timeline and are separate from the + /// authoritative status and completion reports. + pub async fn post_event(&self, payload: &JobEventPayload) -> Result<(), WorkerError> { + self.post_and_require_success(EVENT_PATH, payload).await + } + + /// Poll the orchestrator for a follow-up prompt. + /// + /// Returns `None` if no prompt is available (204 No Content). The worker + /// loop uses this to merge orchestrator-provided operator nudges into the + /// local reasoning context without treating "no prompt" as an error. + pub async fn poll_prompt(&self) -> Result, WorkerError> { + let url = self.url(PROMPT_PATH); + let resp = self + .client + .get(&url) + .bearer_auth(&self.token) + .send() + .await + .map_err(|e| WorkerError::ConnectionFailed { + url: url.clone(), + reason: e.to_string(), + })?; + + if resp.status() == reqwest::StatusCode::NO_CONTENT { + return Ok(None); + } + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(WorkerError::OrchestratorRejected { + job_id: self.job_id, + reason: format!("prompt endpoint returned {}: {}", status, body), + }); + } + + let prompt: PromptResponse = + resp.json() + .await + .map_err(|e| WorkerError::OrchestratorRejected { + job_id: self.job_id, + reason: format!("failed to parse prompt response: {}", e), + })?; + + Ok(Some(prompt)) + } + + /// Fetch credentials granted to this job from the orchestrator. + /// + /// Returns an empty vec if no credentials are granted (204 No Content). + /// Fetched credentials should be handed off to the + /// [`WorkerRuntime`](crate::worker::container::WorkerRuntime) credential + /// hydration path, which stores them in `extra_env` and injects them into + /// child processes. Callers should use that runtime hydrate/injection + /// pathway rather than setting global environment variables directly. This + /// keeps credential scope local to the worker execution context rather than + /// mutating global process state. + pub async fn fetch_credentials(&self) -> Result, WorkerError> { + let url = self.url(CREDENTIALS_PATH); + let resp = self + .client + .get(&url) + .bearer_auth(&self.token) + .send() + .await + .map_err(|e| WorkerError::ConnectionFailed { + url: url.clone(), + reason: e.to_string(), + })?; + + // 204 means no credentials granted, not an error + if resp.status() == reqwest::StatusCode::NO_CONTENT { + return Ok(vec![]); + } + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(WorkerError::SecretResolveFailed { + secret_name: "(all)".to_string(), + reason: format!("credentials endpoint returned {}: {}", status, body), + }); + } + + resp.json() + .await + .map_err(|e| WorkerError::SecretResolveFailed { + secret_name: "(all)".to_string(), + reason: format!("failed to parse credentials response: {}", e), + }) + } + + /// Signal job completion to the orchestrator. + /// + /// Completion reports are the authoritative terminal record for worker + /// execution. Event posting is intentionally separate and best-effort so a + /// slow event sink cannot block this terminal acknowledgement. + pub async fn report_complete(&self, report: &CompletionReport) -> Result<(), WorkerError> { + self.post_and_require_success(COMPLETE_PATH, report).await + } +} diff --git a/src/worker/api/error_mapping.rs b/src/worker/api/error_mapping.rs deleted file mode 100644 index 9e3633e21..000000000 --- a/src/worker/api/error_mapping.rs +++ /dev/null @@ -1,29 +0,0 @@ -//! Error mapping helpers for worker-side HTTP transport. - -use crate::error::WorkerError; - -pub(super) async fn map_remote_tool_status(resp: reqwest::Response) -> WorkerError { - let status = resp.status(); - let retry_after = resp - .headers() - .get("retry-after") - .and_then(|value| value.to_str().ok()) - .and_then(|value| value.parse::().ok()) - .map(std::time::Duration::from_secs); - let body = resp.text().await.unwrap_or_default(); - let reason = format!( - "Remote tool execution: orchestrator returned {}: {}", - status, body - ); - - match status { - reqwest::StatusCode::BAD_REQUEST => WorkerError::BadRequest { reason }, - reqwest::StatusCode::FORBIDDEN => WorkerError::Unauthorized { reason }, - reqwest::StatusCode::TOO_MANY_REQUESTS => WorkerError::RateLimited { - reason, - retry_after, - }, - reqwest::StatusCode::BAD_GATEWAY => WorkerError::BadGateway { reason }, - _ => WorkerError::RemoteToolFailed { reason }, - } -} diff --git a/src/worker/api/proxy_types.rs b/src/worker/api/proxy_types.rs new file mode 100644 index 000000000..6b0a78d17 --- /dev/null +++ b/src/worker/api/proxy_types.rs @@ -0,0 +1,118 @@ +//! Proxy transport types used by the worker-orchestrator boundary. +//! +//! This module defines the serializable request and response payloads used for +//! proxied completions, including shapes built around [`ChatMessage`], +//! [`ToolCall`], and [`ToolDefinition`]. + +use serde::{Deserialize, Serialize}; + +use crate::llm::{ChatMessage, ToolCall, ToolDefinition}; + +/// Provider finish reason transported between orchestrator and worker. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum FinishReason { + Stop, + Length, + #[serde(alias = "tool_calls")] + ToolUse, + ContentFilter, + #[serde(other)] + Unknown, +} + +impl From for FinishReason { + fn from(value: crate::llm::FinishReason) -> Self { + match value { + crate::llm::FinishReason::Stop => Self::Stop, + crate::llm::FinishReason::Length => Self::Length, + crate::llm::FinishReason::ToolUse => Self::ToolUse, + crate::llm::FinishReason::ContentFilter => Self::ContentFilter, + crate::llm::FinishReason::Unknown => Self::Unknown, + } + } +} + +impl From for crate::llm::FinishReason { + fn from(value: FinishReason) -> Self { + match value { + FinishReason::Stop => Self::Stop, + FinishReason::Length => Self::Length, + FinishReason::ToolUse => Self::ToolUse, + FinishReason::ContentFilter => Self::ContentFilter, + FinishReason::Unknown => Self::Unknown, + } + } +} + +/// Request payload for a completion proxied through the orchestrator. +#[derive(Debug, Serialize, Deserialize)] +pub struct ProxyCompletionRequest { + /// Conversation history forwarded to the orchestrator-backed provider. + pub messages: Vec, + /// Optional model override requested by the worker. + pub model: Option, + /// Optional token ceiling for the completion. + pub max_tokens: Option, + /// Optional sampling temperature for the completion. + pub temperature: Option, + /// Optional stop-sequence list forwarded unchanged to the provider. + pub stop_sequences: Option>, +} + +/// Completion result returned by the orchestrator-backed provider. +#[derive(Debug, Serialize, Deserialize)] +pub struct ProxyCompletionResponse { + /// Assistant text produced by the proxied completion call. + pub content: String, + /// Provider-reported prompt token usage. + pub input_tokens: u32, + /// Provider-reported completion token usage. + pub output_tokens: u32, + /// Provider finish reason normalized into a transport enum. + pub finish_reason: FinishReason, + /// Tokens served from cache when the provider exposes that metric. + #[serde(default)] + pub cache_read_input_tokens: u32, + /// Tokens written into cache when the provider exposes that metric. + #[serde(default)] + pub cache_creation_input_tokens: u32, +} + +/// Tool-capable completion request forwarded to the orchestrator. +#[derive(Debug, Serialize, Deserialize)] +pub struct ProxyToolCompletionRequest { + /// Conversation history forwarded to the orchestrator-backed provider. + pub messages: Vec, + /// Tool definitions currently visible to the worker. + pub tools: Vec, + /// Optional model override requested by the worker. + pub model: Option, + /// Optional token ceiling for the completion. + pub max_tokens: Option, + /// Optional sampling temperature for the completion. + pub temperature: Option, + /// Optional provider-specific tool-choice override. + pub tool_choice: Option, +} + +/// Tool-capable completion result returned by the orchestrator. +#[derive(Debug, Serialize, Deserialize)] +pub struct ProxyToolCompletionResponse { + /// Optional assistant text returned alongside tool calls. + pub content: Option, + /// Tool calls selected by the orchestrator-backed provider. + pub tool_calls: Vec, + /// Provider-reported prompt token usage. + pub input_tokens: u32, + /// Provider-reported completion token usage. + pub output_tokens: u32, + /// Provider finish reason normalized into a transport enum. + pub finish_reason: FinishReason, + /// Tokens served from cache when the provider exposes that metric. + #[serde(default)] + pub cache_read_input_tokens: u32, + /// Tokens written into cache when the provider exposes that metric. + #[serde(default)] + pub cache_creation_input_tokens: u32, +} diff --git a/src/worker/api/remote_tool_types.rs b/src/worker/api/remote_tool_types.rs new file mode 100644 index 000000000..b25ef2899 --- /dev/null +++ b/src/worker/api/remote_tool_types.rs @@ -0,0 +1,46 @@ +//! Remote-tool transport types shared between worker and orchestrator. +//! +//! This module defines the serializable payloads used for hosted remote tool +//! definitions and execution outputs, where [`ToolDefinition`] describes +//! visible tools and [`ToolOutput`] carries execution results. + +use serde::{Deserialize, Serialize}; + +use crate::llm::ToolDefinition; +use crate::tools::ToolOutput; + +/// Request sent from a worker to the orchestrator for hosted remote-tool execution. +/// +/// `tool_name` is the orchestrator tool identifier. `params` must match that +/// tool's JSON Schema because the orchestrator validates and executes the call. +#[derive(Debug, Serialize, Deserialize, PartialEq)] +pub struct RemoteToolExecutionRequest { + /// Stable hosted remote-tool identifier known to both worker and orchestrator. + pub tool_name: String, + /// JSON parameters passed through to the tool implementation. + pub params: serde_json::Value, +} + +/// Response returned after the orchestrator executes a hosted remote tool. +/// +/// `output` is the tool's `ToolOutput`, including its result payload and +/// reported side-effect metadata such as duration and optional cost. +#[derive(Debug, Serialize, Deserialize, PartialEq)] +pub struct RemoteToolExecutionResponse { + /// Tool execution output returned by the orchestrator. + pub output: ToolOutput, +} + +/// Catalogue payload returned to workers for hosted-visible remote tools. +/// +/// `tools` is the current model-facing tool list. `toolset_instructions` is +/// optional human-readable guidance and defaults to an empty list. +/// `catalog_version` is a deterministic content version derived from the +/// serialized catalogue payload. +#[derive(Debug, Serialize, Deserialize, PartialEq)] +pub struct RemoteToolCatalogResponse { + pub tools: Vec, + #[serde(default)] + pub toolset_instructions: Vec, + pub catalog_version: u64, +} diff --git a/src/worker/api/tests/client_construction.rs b/src/worker/api/tests/client_construction.rs new file mode 100644 index 000000000..f86ba34f5 --- /dev/null +++ b/src/worker/api/tests/client_construction.rs @@ -0,0 +1,53 @@ +//! Tests for WorkerHttpClient construction and error handling. + +use rstest::rstest; +use uuid::Uuid; + +use crate::testing::credentials::TEST_BEARER_TOKEN; +use crate::worker::api::WorkerHttpClient; + +/// Regression test: WorkerHttpClient::new succeeds with valid URLs. +/// +/// This test verifies that the fallible constructor properly constructs +/// a WorkerHttpClient with valid URLs without panicking or using `unwrap`. +#[rstest] +#[case("http://localhost:50051")] +#[case("http://localhost:50051/")] +#[case("http://example.com")] +#[case("https://api.example.com")] +fn worker_http_client_new_succeeds_with_valid_url(#[case] url: &str) { + let result = WorkerHttpClient::new( + url.to_string(), + Uuid::new_v4(), + TEST_BEARER_TOKEN.to_string(), + ); + + assert!( + result.is_ok(), + "WorkerHttpClient::new should succeed with valid URL, got error" + ); + + let client = result.expect("client should be built"); + assert_eq!(client.orchestrator_url(), url.trim_end_matches('/')); +} + +/// Regression test: Verify that the new() constructor returns Result and +/// can be properly constructed in test contexts. +#[test] +fn worker_http_client_new_returns_ok_for_test_token() { + let job_id = Uuid::new_v4(); + let result = WorkerHttpClient::new( + "http://host.docker.internal:50051".to_string(), + job_id, + TEST_BEARER_TOKEN.to_string(), + ); + + assert!(result.is_ok(), "expected Ok, got error"); + + let client = result.expect("client should be built"); + assert_eq!(client.job_id(), job_id); + assert_eq!( + client.orchestrator_url(), + "http://host.docker.internal:50051" + ); +} diff --git a/src/worker/api/tests/client_methods.rs b/src/worker/api/tests/client_methods.rs new file mode 100644 index 000000000..4cd5f1085 --- /dev/null +++ b/src/worker/api/tests/client_methods.rs @@ -0,0 +1,331 @@ +//! Tests for `WorkerHttpClient` status, event, prompt, credential, and completion methods. + +use std::sync::Arc; + +use axum::extract::{Path, State}; +use axum::http::{HeaderMap, StatusCode}; +use axum::response::IntoResponse; +use axum::routing::{get, post}; +use axum::{Json, Router}; +use tokio::net::TcpListener; +use tokio::sync::Mutex; +use uuid::Uuid; + +use crate::testing::credentials::TEST_BEARER_TOKEN; +use crate::worker::api::{ + COMPLETE_ROUTE, CREDENTIALS_ROUTE, CompletionReport, CredentialResponse, EVENT_ROUTE, + JobEventPayload, PROMPT_ROUTE, STATUS_ROUTE, StatusUpdate, WorkerHttpClient, WorkerState, +}; + +#[derive(Default)] +struct ClientMethodTestState { + status_updates: Mutex>, + event_payloads: Mutex>, + completion_reports: Mutex>, + auth_headers: Mutex>, +} + +async fn record_auth(headers: &HeaderMap, state: &ClientMethodTestState) { + let auth = headers + .get(axum::http::header::AUTHORIZATION) + .and_then(|value| value.to_str().ok()) + .unwrap_or_default() + .to_string(); + state.auth_headers.lock().await.push(auth); +} + +async fn spawn_test_server( + router: Router, +) -> anyhow::Result<(String, tokio::task::JoinHandle<()>)> { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + let handle = tokio::spawn(async move { + axum::serve(listener, router) + .await + .expect("client method test server should run"); + }); + Ok((format!("http://{addr}"), handle)) +} + +/// Spin up a test server using the provided router and return a connected client. +/// +/// The caller supplies a closure that receives the shared `ClientMethodTestState` +/// and returns a fully configured `Router`. This avoids repeating state +/// construction and client initialisation in every test body. +async fn setup_for_test( + make_router: impl FnOnce(Arc) -> Router, +) -> anyhow::Result<( + Arc, + WorkerHttpClient, + tokio::task::JoinHandle<()>, +)> { + let state = Arc::new(ClientMethodTestState::default()); + let (base_url, handle) = spawn_test_server(make_router(Arc::clone(&state))).await?; + let client = WorkerHttpClient::new(base_url, Uuid::nil(), TEST_BEARER_TOKEN.to_string())?; + Ok((state, client, handle)) +} + +#[tokio::test] +async fn worker_http_client_report_status_posts_status_payload() -> anyhow::Result<()> { + async fn handler( + State(state): State>, + Path(_job_id): Path, + headers: HeaderMap, + Json(update): Json, + ) -> StatusCode { + record_auth(&headers, &state).await; + state.status_updates.lock().await.push(update); + StatusCode::OK + } + + let (state, client, handle) = setup_for_test(|state| { + Router::new() + .route(STATUS_ROUTE, post(handler)) + .with_state(state) + }) + .await?; + + let update = StatusUpdate::new(WorkerState::InProgress, Some("working".to_string()), 3); + client.report_status(&update).await?; + + let updates = state.status_updates.lock().await; + assert_eq!(updates.len(), 1); + assert_eq!(updates[0].state, update.state); + assert_eq!(updates[0].message, update.message); + assert_eq!(updates[0].iteration, update.iteration); + drop(updates); + assert_eq!( + state.auth_headers.lock().await.as_slice(), + &[format!("Bearer {TEST_BEARER_TOKEN}")] + ); + + handle.abort(); + let _ = handle.await; + Ok(()) +} + +#[tokio::test] +async fn worker_http_client_report_status_lossy_swallows_rejections() -> anyhow::Result<()> { + async fn handler( + State(state): State>, + Path(_job_id): Path, + headers: HeaderMap, + Json(update): Json, + ) -> (StatusCode, &'static str) { + record_auth(&headers, &state).await; + state.status_updates.lock().await.push(update); + (StatusCode::INTERNAL_SERVER_ERROR, "nope") + } + + let (state, client, handle) = setup_for_test(|state| { + Router::new() + .route(STATUS_ROUTE, post(handler)) + .with_state(state) + }) + .await?; + + let update = StatusUpdate::new(WorkerState::Running, Some("still going".to_string()), 5); + client.report_status_lossy(&update).await; + + let updates = state.status_updates.lock().await; + assert_eq!(updates.len(), 1); + assert_eq!(updates[0].state, update.state); + assert_eq!(updates[0].message, update.message); + assert_eq!(updates[0].iteration, update.iteration); + + handle.abort(); + let _ = handle.await; + Ok(()) +} + +#[tokio::test] +async fn worker_http_client_post_event_posts_event_payload() -> anyhow::Result<()> { + async fn handler( + State(state): State>, + Path(_job_id): Path, + headers: HeaderMap, + Json(payload): Json, + ) -> StatusCode { + record_auth(&headers, &state).await; + state.event_payloads.lock().await.push(payload); + StatusCode::OK + } + + let (state, client, handle) = setup_for_test(|state| { + Router::new() + .route(EVENT_ROUTE, post(handler)) + .with_state(state) + }) + .await?; + let payload = JobEventPayload { + event_type: crate::worker::api::JobEventType::Message, + data: serde_json::json!({"role": "assistant", "content": "hello"}), + }; + + client.post_event(&payload).await?; + + assert_eq!(state.event_payloads.lock().await.len(), 1); + assert_eq!( + state.event_payloads.lock().await[0].data["content"], + serde_json::json!("hello") + ); + + handle.abort(); + let _ = handle.await; + Ok(()) +} + +#[tokio::test] +async fn worker_http_client_poll_prompt_returns_prompt_response() -> anyhow::Result<()> { + async fn handler( + State(state): State>, + Path(_job_id): Path, + headers: HeaderMap, + ) -> impl IntoResponse { + record_auth(&headers, &state).await; + ( + StatusCode::OK, + [(axum::http::header::CONTENT_TYPE, "application/json")], + r#"{"content":"follow up","done":false}"#, + ) + } + + let (_, client, handle) = setup_for_test(|state| { + Router::new() + .route(PROMPT_ROUTE, get(handler)) + .with_state(state) + }) + .await?; + + let prompt = client + .poll_prompt() + .await? + .expect("prompt should be present"); + + assert_eq!(prompt.content, "follow up"); + assert!(!prompt.done); + + handle.abort(); + let _ = handle.await; + Ok(()) +} + +#[tokio::test] +async fn worker_http_client_poll_prompt_returns_none_for_no_content() -> anyhow::Result<()> { + async fn handler( + State(state): State>, + Path(_job_id): Path, + headers: HeaderMap, + ) -> StatusCode { + record_auth(&headers, &state).await; + StatusCode::NO_CONTENT + } + + let (_, client, handle) = setup_for_test(|state| { + Router::new() + .route(PROMPT_ROUTE, get(handler)) + .with_state(state) + }) + .await?; + + assert!(client.poll_prompt().await?.is_none()); + + handle.abort(); + let _ = handle.await; + Ok(()) +} + +#[tokio::test] +async fn worker_http_client_fetch_credentials_returns_payload() -> anyhow::Result<()> { + async fn handler( + State(state): State>, + Path(_job_id): Path, + headers: HeaderMap, + ) -> Json> { + record_auth(&headers, &state).await; + Json(vec![CredentialResponse { + env_var: "API_TOKEN".to_string(), + value: "secret".to_string(), + }]) + } + + let (_, client, handle) = setup_for_test(|state| { + Router::new() + .route(CREDENTIALS_ROUTE, get(handler)) + .with_state(state) + }) + .await?; + + let credentials = client.fetch_credentials().await?; + + assert_eq!(credentials.len(), 1); + assert_eq!(credentials[0].env_var, "API_TOKEN"); + assert_eq!(credentials[0].value, "secret"); + + handle.abort(); + let _ = handle.await; + Ok(()) +} + +#[tokio::test] +async fn worker_http_client_fetch_credentials_returns_empty_for_no_content() -> anyhow::Result<()> { + async fn handler( + State(state): State>, + Path(_job_id): Path, + headers: HeaderMap, + ) -> StatusCode { + record_auth(&headers, &state).await; + StatusCode::NO_CONTENT + } + + let (_, client, handle) = setup_for_test(|state| { + Router::new() + .route(CREDENTIALS_ROUTE, get(handler)) + .with_state(state) + }) + .await?; + + assert!(client.fetch_credentials().await?.is_empty()); + + handle.abort(); + let _ = handle.await; + Ok(()) +} + +#[tokio::test] +async fn worker_http_client_report_complete_posts_completion_report() -> anyhow::Result<()> { + async fn handler( + State(state): State>, + Path(_job_id): Path, + headers: HeaderMap, + Json(report): Json, + ) -> StatusCode { + record_auth(&headers, &state).await; + state.completion_reports.lock().await.push(report); + StatusCode::OK + } + + let (state, client, handle) = setup_for_test(|state| { + Router::new() + .route(COMPLETE_ROUTE, post(handler)) + .with_state(state) + }) + .await?; + let report = CompletionReport { + success: true, + message: Some("done".to_string()), + iterations: 9, + }; + + client.report_complete(&report).await?; + + let reports = state.completion_reports.lock().await; + assert_eq!(reports.len(), 1); + assert_eq!(reports[0].success, report.success); + assert_eq!(reports[0].message, report.message); + assert_eq!(reports[0].iterations, report.iterations); + + handle.abort(); + let _ = handle.await; + Ok(()) +} diff --git a/src/worker/api/tests/mod.rs b/src/worker/api/tests/mod.rs index 93562ddb2..1ff4904d3 100644 --- a/src/worker/api/tests/mod.rs +++ b/src/worker/api/tests/mod.rs @@ -1,5 +1,7 @@ //! Tests for the worker HTTP client and its shared API type conversions. +mod client_construction; +mod client_methods; mod finish_reason; mod fixtures; mod remote_tool_catalog; diff --git a/src/worker/api/tests/remote_tool_catalog.rs b/src/worker/api/tests/remote_tool_catalog.rs index bdcf49dea..7a154a8f5 100644 --- a/src/worker/api/tests/remote_tool_catalog.rs +++ b/src/worker/api/tests/remote_tool_catalog.rs @@ -22,7 +22,8 @@ async fn remote_tool_catalog_reports_non_success_statuses( server.base_url.clone(), Uuid::new_v4(), TEST_BEARER_TOKEN.to_string(), - ); + ) + .expect("test client should build"); let err = client .get_remote_tool_catalog() diff --git a/src/worker/api/tests/remote_tool_execute.rs b/src/worker/api/tests/remote_tool_execute.rs index e0a021941..c380a6305 100644 --- a/src/worker/api/tests/remote_tool_execute.rs +++ b/src/worker/api/tests/remote_tool_execute.rs @@ -29,7 +29,8 @@ async fn remote_tool_execute_preserves_non_success_statuses( server.base_url.clone(), Uuid::new_v4(), TEST_BEARER_TOKEN.to_string(), - ); + ) + .expect("test client should build"); let err = client .execute_remote_tool("github_search", &serde_json::json!({"query": 7})) diff --git a/src/worker/api/tests/transport_types.rs b/src/worker/api/tests/transport_types.rs index 2de5d2519..3b31ca5cd 100644 --- a/src/worker/api/tests/transport_types.rs +++ b/src/worker/api/tests/transport_types.rs @@ -1,16 +1,69 @@ //! Transport type serialisation fidelity tests. use rstest::rstest; +use serde::Serialize; +use serde::de::DeserializeOwned; +use std::fmt::Debug; use crate::worker::api::{ + COMPLETE_ROUTE, CREDENTIALS_ROUTE, EVENT_ROUTE, JOB_ROUTE, PROMPT_ROUTE, REMOTE_TOOL_CATALOG_ROUTE, REMOTE_TOOL_EXECUTE_ROUTE, RemoteToolCatalogResponse, - RemoteToolExecutionRequest, RemoteToolExecutionResponse, + RemoteToolExecutionRequest, RemoteToolExecutionResponse, STATUS_ROUTE, TerminalResult, }; use super::fixtures::{ sample_catalog_response, sample_execution_request, sample_execution_response, }; +/// Serialise `value` to JSON and immediately deserialise it back, asserting that the round-trip produces an equal value without field loss. +fn assert_round_trips(value: T) +where + T: Serialize + DeserializeOwned + Debug + PartialEq, +{ + let serialized = serde_json::to_string(&value).expect("serialise"); + let deserialized: T = serde_json::from_str(&serialized).expect("deserialise"); + assert_eq!( + deserialized, value, + "value must round-trip without field loss" + ); +} + +const fn const_str_eq(left: &str, right: &str) -> bool { + let left = left.as_bytes(); + let right = right.as_bytes(); + if left.len() != right.len() { + return false; + } + + let mut index = 0; + while index < left.len() { + if left[index] != right[index] { + return false; + } + index += 1; + } + + true +} + +const _: () = assert!(const_str_eq(JOB_ROUTE, "/worker/{job_id}/job")); +const _: () = assert!(const_str_eq( + CREDENTIALS_ROUTE, + "/worker/{job_id}/credentials" +)); +const _: () = assert!(const_str_eq(STATUS_ROUTE, "/worker/{job_id}/status")); +const _: () = assert!(const_str_eq(COMPLETE_ROUTE, "/worker/{job_id}/complete")); +const _: () = assert!(const_str_eq(EVENT_ROUTE, "/worker/{job_id}/event")); +const _: () = assert!(const_str_eq(PROMPT_ROUTE, "/worker/{job_id}/prompt")); +const _: () = assert!(const_str_eq( + REMOTE_TOOL_CATALOG_ROUTE, + "/worker/{job_id}/tools/catalog" +)); +const _: () = assert!(const_str_eq( + REMOTE_TOOL_EXECUTE_ROUTE, + "/worker/{job_id}/tools/execute" +)); + #[test] fn worker_and_orchestrator_share_remote_tool_route_constants() { // These constants are declared in worker::api::types but are shared with the @@ -65,43 +118,39 @@ fn worker_and_orchestrator_share_remote_tool_route_constants() { fn remote_tool_catalog_response_round_trip_without_field_loss( sample_catalog_response: RemoteToolCatalogResponse, ) { - let serialized = serde_json::to_string(&sample_catalog_response) - .expect("serialize RemoteToolCatalogResponse"); - let deserialized: RemoteToolCatalogResponse = - serde_json::from_str(&serialized).expect("deserialize RemoteToolCatalogResponse"); - - assert_eq!( - deserialized, sample_catalog_response, - "catalog response must round-trip without field loss" - ); + assert_round_trips(sample_catalog_response); } #[rstest] fn remote_tool_execution_request_round_trip_without_field_loss( sample_execution_request: RemoteToolExecutionRequest, ) { - let serialized = serde_json::to_string(&sample_execution_request) - .expect("serialize RemoteToolExecutionRequest"); - let deserialized: RemoteToolExecutionRequest = - serde_json::from_str(&serialized).expect("deserialize RemoteToolExecutionRequest"); - - assert_eq!( - deserialized, sample_execution_request, - "execution request must round-trip without field loss" - ); + assert_round_trips(sample_execution_request); } #[rstest] fn remote_tool_execution_response_round_trip_without_field_loss( sample_execution_response: RemoteToolExecutionResponse, ) { - let serialized = serde_json::to_string(&sample_execution_response) - .expect("serialize RemoteToolExecutionResponse"); - let deserialized: RemoteToolExecutionResponse = - serde_json::from_str(&serialized).expect("deserialize RemoteToolExecutionResponse"); + assert_round_trips(sample_execution_response); +} - assert_eq!( - deserialized, sample_execution_response, - "execution response must round-trip without field loss" +#[test] +fn terminal_result_round_trip_preserves_all_fields() { + let result = TerminalResult::success("completed", Some(11)); + + assert_round_trips(result); +} + +#[test] +fn terminal_result_omits_iterations_when_absent() { + let serialized = serde_json::to_value(TerminalResult::failure("failed", None)) + .expect("serialize TerminalResult"); + + assert_eq!(serialized["success"], false); + assert_eq!(serialized["message"], "failed"); + assert!( + serialized.get("iterations").is_none(), + "iterations should be omitted when absent" ); } diff --git a/src/worker/api/tests/url_construction.rs b/src/worker/api/tests/url_construction.rs index 57a37c701..ca5be797e 100644 --- a/src/worker/api/tests/url_construction.rs +++ b/src/worker/api/tests/url_construction.rs @@ -14,7 +14,8 @@ fn test_url_construction(#[case] path: &str) { "http://host.docker.internal:50051".to_string(), Uuid::nil(), TEST_BEARER_TOKEN.to_string(), - ); + ) + .expect("test client should build"); assert_eq!( client.url(path), @@ -34,7 +35,8 @@ fn test_url_construction_with_trailing_slash(#[case] path: &str) { "http://host.docker.internal:50051/".to_string(), Uuid::nil(), TEST_BEARER_TOKEN.to_string(), - ); + ) + .expect("test client should build"); assert_eq!( client.url(path), @@ -53,7 +55,8 @@ fn remote_tool_catalog_url_construction() { "http://host.docker.internal:50051".to_string(), Uuid::nil(), TEST_BEARER_TOKEN.to_string(), - ); + ) + .expect("test client should build"); assert_eq!( client.url(REMOTE_TOOL_CATALOG_PATH), diff --git a/src/worker/api/types.rs b/src/worker/api/types.rs index b5f8dced6..3a2bd9eb8 100644 --- a/src/worker/api/types.rs +++ b/src/worker/api/types.rs @@ -5,10 +5,21 @@ //! status updates, and credential delivery, including shared types such as //! [`ChatMessage`], [`ToolCall`], [`ToolDefinition`], and [`ToolOutput`]. +#[path = "proxy_types.rs"] +mod proxy_types; +#[path = "remote_tool_types.rs"] +mod remote_tool_types; + +use const_format::concatcp; use serde::{Deserialize, Serialize}; -use crate::llm::{ChatMessage, ToolCall, ToolDefinition}; -use crate::tools::ToolOutput; +pub use proxy_types::{ + FinishReason, ProxyCompletionRequest, ProxyCompletionResponse, ProxyToolCompletionRequest, + ProxyToolCompletionResponse, +}; +pub use remote_tool_types::{ + RemoteToolCatalogResponse, RemoteToolExecutionRequest, RemoteToolExecutionResponse, +}; /// Worker lifecycle state sent to the orchestrator. #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] @@ -40,14 +51,82 @@ impl std::fmt::Display for WorkerState { } } +/// Route prefix for all per-job worker endpoints. +const WORKER_PREFIX: &str = "/worker/{job_id}/"; + /// Relative worker path for the hosted remote-tool catalogue endpoint. pub const REMOTE_TOOL_CATALOG_PATH: &str = "tools/catalog"; /// Relative worker path for hosted remote-tool execution. pub const REMOTE_TOOL_EXECUTE_PATH: &str = "tools/execute"; /// Axum route for the hosted remote-tool catalogue endpoint. -pub const REMOTE_TOOL_CATALOG_ROUTE: &str = "/worker/{job_id}/tools/catalog"; +pub const REMOTE_TOOL_CATALOG_ROUTE: &str = concatcp!(WORKER_PREFIX, REMOTE_TOOL_CATALOG_PATH); /// Axum route for hosted remote-tool execution. -pub const REMOTE_TOOL_EXECUTE_ROUTE: &str = "/worker/{job_id}/tools/execute"; +pub const REMOTE_TOOL_EXECUTE_ROUTE: &str = concatcp!(WORKER_PREFIX, REMOTE_TOOL_EXECUTE_PATH); + +/// Relative worker path for job description endpoint. +pub const JOB_PATH: &str = "job"; +/// Axum route for job description endpoint. +pub const JOB_ROUTE: &str = concatcp!(WORKER_PREFIX, JOB_PATH); + +/// Relative worker path for credentials endpoint. +pub const CREDENTIALS_PATH: &str = "credentials"; +/// Axum route for credentials endpoint. +pub const CREDENTIALS_ROUTE: &str = concatcp!(WORKER_PREFIX, CREDENTIALS_PATH); + +/// Relative worker path for status update endpoint. +pub const STATUS_PATH: &str = "status"; +/// Axum route for status update endpoint. +pub const STATUS_ROUTE: &str = concatcp!(WORKER_PREFIX, STATUS_PATH); + +/// Relative worker path for completion report endpoint. +pub const COMPLETE_PATH: &str = "complete"; +/// Axum route for completion report endpoint. +pub const COMPLETE_ROUTE: &str = concatcp!(WORKER_PREFIX, COMPLETE_PATH); + +/// Relative worker path for job event endpoint. +pub const EVENT_PATH: &str = "event"; +/// Axum route for job event endpoint. +pub const EVENT_ROUTE: &str = concatcp!(WORKER_PREFIX, EVENT_PATH); + +/// Relative worker path for prompt polling endpoint. +pub const PROMPT_PATH: &str = "prompt"; +/// Axum route for prompt polling endpoint. +pub const PROMPT_ROUTE: &str = concatcp!(WORKER_PREFIX, PROMPT_PATH); + +/// Relative worker path for LLM completion endpoint. +pub const LLM_COMPLETE_PATH: &str = "llm/complete"; +/// Axum route for LLM completion endpoint. +pub const LLM_COMPLETE_ROUTE: &str = concatcp!(WORKER_PREFIX, LLM_COMPLETE_PATH); + +/// Relative worker path for LLM tool completion endpoint. +pub const LLM_COMPLETE_WITH_TOOLS_PATH: &str = "llm/complete_with_tools"; +/// Axum route for LLM tool completion endpoint. +pub const LLM_COMPLETE_WITH_TOOLS_ROUTE: &str = + concatcp!(WORKER_PREFIX, LLM_COMPLETE_WITH_TOOLS_PATH); + +/// Relative path for health check endpoint (no job_id path component). +pub const WORKER_HEALTH_PATH: &str = "health"; +/// Axum route for health check endpoint (no job_id path component). +pub const WORKER_HEALTH_ROUTE: &str = concatcp!("/", WORKER_HEALTH_PATH); + +/// Build a concrete job-scoped path from a job ID and relative suffix. +/// +/// Uses the canonical `WORKER_PREFIX` pattern so route registration and +/// client URL construction share the same source of truth. +pub fn job_scoped_path(job_id: &str, relative: &str) -> String { + WORKER_PREFIX.replace("{job_id}", job_id) + relative +} + +/// Build a worker job URL path from the orchestrator URL, job ID, and path suffix. +/// +/// Returns a canonical URL of the form `{orchestrator_url}/worker/{job_id}/{path}`. +pub fn worker_job_url(orchestrator_url: &str, job_id: &str, path: &str) -> String { + let base = orchestrator_url.trim_end_matches('/'); + let scoped_path = job_scoped_path(job_id, ""); + let scoped = scoped_path.trim_start_matches('/').trim_end_matches('/'); + let path = path.trim_start_matches('/'); + format!("{}/{}/{}", base, scoped, path) +} /// Status update sent from worker to orchestrator. #[derive(Debug, Serialize, Deserialize)] @@ -58,6 +137,11 @@ pub struct StatusUpdate { } impl StatusUpdate { + /// Build a canonical worker status payload for the orchestrator API. + /// + /// Using this constructor keeps call sites aligned with the shared + /// transport type and makes iteration counts explicit at the reporting + /// boundary. pub fn new(state: WorkerState, message: Option, iteration: u32) -> Self { Self { state, @@ -75,151 +159,6 @@ pub struct JobDescription { pub project_dir: Option, } -/// Provider finish reason transported between orchestrator and worker. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum FinishReason { - Stop, - Length, - #[serde(alias = "tool_calls")] - ToolUse, - ContentFilter, - #[serde(other)] - Unknown, -} - -impl From for FinishReason { - fn from(value: crate::llm::FinishReason) -> Self { - match value { - crate::llm::FinishReason::Stop => Self::Stop, - crate::llm::FinishReason::Length => Self::Length, - crate::llm::FinishReason::ToolUse => Self::ToolUse, - crate::llm::FinishReason::ContentFilter => Self::ContentFilter, - crate::llm::FinishReason::Unknown => Self::Unknown, - } - } -} - -impl From for crate::llm::FinishReason { - fn from(value: FinishReason) -> Self { - match value { - FinishReason::Stop => Self::Stop, - FinishReason::Length => Self::Length, - FinishReason::ToolUse => Self::ToolUse, - FinishReason::ContentFilter => Self::ContentFilter, - FinishReason::Unknown => Self::Unknown, - } - } -} - -/// Completion result from the orchestrator (proxied from the real LLM). -#[derive(Debug, Serialize, Deserialize)] -pub struct ProxyCompletionRequest { - /// Conversation history forwarded to the orchestrator-backed provider. - pub messages: Vec, - /// Optional model override requested by the worker. - pub model: Option, - /// Optional token ceiling for the completion. - pub max_tokens: Option, - /// Optional sampling temperature for the completion. - pub temperature: Option, - /// Optional stop-sequence list forwarded unchanged to the provider. - pub stop_sequences: Option>, -} - -/// Completion result returned by the orchestrator-backed provider. -#[derive(Debug, Serialize, Deserialize)] -pub struct ProxyCompletionResponse { - /// Assistant text produced by the proxied completion call. - pub content: String, - /// Provider-reported prompt token usage. - pub input_tokens: u32, - /// Provider-reported completion token usage. - pub output_tokens: u32, - /// Provider finish reason normalised into a transport enum. - pub finish_reason: FinishReason, - /// Tokens served from cache when the provider exposes that metric. - #[serde(default)] - pub cache_read_input_tokens: u32, - /// Tokens written into cache when the provider exposes that metric. - #[serde(default)] - pub cache_creation_input_tokens: u32, -} - -/// Tool-capable completion request forwarded to the orchestrator. -#[derive(Debug, Serialize, Deserialize)] -pub struct ProxyToolCompletionRequest { - /// Conversation history forwarded to the orchestrator-backed provider. - pub messages: Vec, - /// Tool definitions currently visible to the worker. - pub tools: Vec, - /// Optional model override requested by the worker. - pub model: Option, - /// Optional token ceiling for the completion. - pub max_tokens: Option, - /// Optional sampling temperature for the completion. - pub temperature: Option, - /// Optional provider-specific tool-choice override. - pub tool_choice: Option, -} - -/// Tool-capable completion result returned by the orchestrator. -#[derive(Debug, Serialize, Deserialize)] -pub struct ProxyToolCompletionResponse { - /// Optional assistant text returned alongside tool calls. - pub content: Option, - /// Tool calls selected by the orchestrator-backed provider. - pub tool_calls: Vec, - /// Provider-reported prompt token usage. - pub input_tokens: u32, - /// Provider-reported completion token usage. - pub output_tokens: u32, - /// Provider finish reason normalised into a transport enum. - pub finish_reason: FinishReason, - /// Tokens served from cache when the provider exposes that metric. - #[serde(default)] - pub cache_read_input_tokens: u32, - /// Tokens written into cache when the provider exposes that metric. - #[serde(default)] - pub cache_creation_input_tokens: u32, -} - -/// Request sent from a worker to the orchestrator for hosted remote-tool execution. -/// -/// `tool_name` is the orchestrator tool identifier. `params` must match that -/// tool's JSON Schema because the orchestrator validates and executes the call. -#[derive(Debug, Serialize, Deserialize, PartialEq)] -pub struct RemoteToolExecutionRequest { - /// Stable hosted remote-tool identifier known to both worker and orchestrator. - pub tool_name: String, - /// JSON parameters passed through to the tool implementation. - pub params: serde_json::Value, -} - -/// Response returned after the orchestrator executes a hosted remote tool. -/// -/// `output` is the tool's `ToolOutput`, including its result payload and -/// reported side-effect metadata such as duration and optional cost. -#[derive(Debug, Serialize, Deserialize, PartialEq)] -pub struct RemoteToolExecutionResponse { - /// Tool execution output returned by the orchestrator. - pub output: ToolOutput, -} - -/// Catalogue payload returned to workers for hosted-visible remote tools. -/// -/// `tools` is the current model-facing tool list. `toolset_instructions` is -/// optional human-readable guidance and defaults to an empty list. -/// `catalog_version` is a deterministic content version derived from the -/// serialized catalogue payload. -#[derive(Debug, Serialize, Deserialize, PartialEq)] -pub struct RemoteToolCatalogResponse { - pub tools: Vec, - #[serde(default)] - pub toolset_instructions: Vec, - pub catalog_version: u64, -} - #[derive(Debug, Serialize, Deserialize)] pub struct CompletionReport { /// Whether the worker completed the job successfully. @@ -296,3 +235,44 @@ impl std::fmt::Debug for CredentialResponse { .finish() } } + +/// Terminal result payload emitted with [`JobEventType::Result`]. +/// +/// Provides a consistent serialized shape for job completion events, +/// whether successful or failed. +#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct TerminalResult { + /// Whether the job completed successfully. + pub success: bool, + /// Human-readable completion summary or failure message. + pub message: String, + /// Number of iterations completed before exit, if applicable. + #[serde(skip_serializing_if = "Option::is_none")] + pub iterations: Option, +} + +impl TerminalResult { + /// Create a new terminal result for a successful job. + /// + /// This is the result payload carried in `JobEventType::Result`, which + /// complements but does not replace the authoritative completion report. + pub fn success(message: impl Into, iterations: Option) -> Self { + Self { + success: true, + message: message.into(), + iterations, + } + } + + /// Create a new terminal result for a failed job. + /// + /// Failure payloads intentionally carry a sanitized, user-facing summary + /// rather than arbitrary internal error detail. + pub fn failure(message: impl Into, iterations: Option) -> Self { + Self { + success: false, + message: message.into(), + iterations, + } + } +} diff --git a/src/worker/claude_bridge/reporting.rs b/src/worker/claude_bridge/reporting.rs index 999058b11..c233e3d3b 100644 --- a/src/worker/claude_bridge/reporting.rs +++ b/src/worker/claude_bridge/reporting.rs @@ -51,7 +51,9 @@ impl ClaudeBridgeRuntime { event_type, data: data.clone(), }; - self.client.post_event(&payload).await; + if let Err(e) = self.client.post_event(&payload).await { + tracing::debug!(error = %e, "Failed to report event"); + } } pub(super) async fn poll_for_prompt(&self) -> Result, WorkerError> { diff --git a/src/worker/claude_bridge/session.rs b/src/worker/claude_bridge/session.rs index 121e5d030..f17a3b35a 100644 --- a/src/worker/claude_bridge/session.rs +++ b/src/worker/claude_bridge/session.rs @@ -144,7 +144,9 @@ impl ClaudeBridgeRuntime { event_type: JobEventType::Status, data: serde_json::json!({ "message": line }), }; - client_for_stderr.post_event(&payload).await; + if let Err(e) = client_for_stderr.post_event(&payload).await { + tracing::debug!(job_id = %job_id, error = %e, "Failed to post stderr event"); + } } Ok(None) => break, Err(error) => { diff --git a/src/worker/container.rs b/src/worker/container.rs index 25413e54d..2be9f8eb5 100644 --- a/src/worker/container.rs +++ b/src/worker/container.rs @@ -21,20 +21,24 @@ use crate::llm::{ChatMessage, LlmProvider, Reasoning, ReasoningContext}; use crate::safety::SafetyLayer; use crate::tools::ToolRegistry; use crate::tools::builtin::worker_remote_tool_proxy::register_worker_remote_tool_proxies; -use crate::worker::api::{ - CompletionReport, JobEventPayload, JobEventType, StatusUpdate, WorkerHttpClient, WorkerState, -}; +use crate::worker::api::{WorkerHttpClient, WorkerState}; use crate::worker::proxy_llm::ProxyLlmProvider; mod delegate; +mod reporting; use delegate::ContainerDelegate; +use reporting::WorkerExecutionResult; /// Configuration for the worker runtime. pub struct WorkerConfig { + /// Job identifier used to scope all worker-orchestrator requests. pub job_id: Uuid, + /// Base orchestrator URL that owns the per-job worker endpoints. pub orchestrator_url: String, + /// Maximum number of LLM/tool iterations before the worker aborts. pub max_iterations: u32, + /// Hard wall-clock timeout for the full worker execution. pub timeout: Duration, } @@ -68,12 +72,6 @@ pub struct WorkerRuntime { extra_env: Arc>, } -enum WorkerExecutionResult { - Outcome(LoopOutcome), - Failed(crate::error::Error), - TimedOut, -} - /// Heading used to tag hosted remote-tool guidance messages in reasoning context. pub(crate) const HOSTED_GUIDANCE_HEADING: &str = "Hosted remote-tool guidance"; @@ -88,26 +86,41 @@ async fn available_tool_definitions(tools: &ToolRegistry) -> Vec Result { - let client = Arc::new(WorkerHttpClient::from_env( - config.orchestrator_url.clone(), - config.job_id, - )?); - - Ok(Self::from_client(config, client)) - } - - /// Construct a worker runtime from a pre-validated [`WorkerHttpClient`]. + /// This is the primary constructor for `WorkerRuntime`. It takes a + /// pre-constructed client, making dependency injection straightforward + /// and allowing tests to provide mock clients without environment setup. /// - /// Unlike [`Self::new`], this path performs no fallible initialization: - /// `new` returns `Result` because it builds the client - /// with [`WorkerHttpClient::from_env`] using the supplied [`WorkerConfig`], - /// while `from_client` takes an `Arc` that has already - /// completed that validation and therefore returns `Self` directly. - fn from_client(config: WorkerConfig, client: Arc) -> Self { + /// # Errors + /// + /// Returns `WorkerError::ConfigMismatch` if the provided `client` job_id or + /// orchestrator_url doesn't match the values in `config`. This validates + /// configuration consistency at construction time to prevent subtle runtime + /// errors. + pub fn new(config: WorkerConfig, client: Arc) -> Result { + // Validate that config and client are consistent + if config.job_id != client.job_id() { + return Err(WorkerError::ConfigMismatch { + field: crate::error::ConfigMismatchField::JobId, + reason: format!( + "WorkerConfig job_id ({}) must match WorkerHttpClient job_id ({})", + config.job_id, + client.job_id() + ), + }); + } + if config.orchestrator_url.trim_end_matches('/') != client.orchestrator_url() { + return Err(WorkerError::ConfigMismatch { + field: crate::error::ConfigMismatchField::OrchestratorUrl, + reason: format!( + "WorkerConfig orchestrator_url ({}) must match WorkerHttpClient orchestrator_url ({})", + config.orchestrator_url.trim_end_matches('/'), + client.orchestrator_url() + ), + }); + } + let llm: Arc = Arc::new(ProxyLlmProvider::new( Arc::clone(&client), "proxied".to_string(), @@ -120,7 +133,7 @@ impl WorkerRuntime { let tools = Self::build_tools(); - Self { + Ok(Self { config, client, llm, @@ -128,7 +141,21 @@ impl WorkerRuntime { tools, toolset_instructions: Vec::new(), extra_env: Arc::new(HashMap::new()), - } + }) + } + + /// Create a new worker runtime from environment variables. + /// + /// Reads `IRONCLAW_WORKER_TOKEN` from the environment for auth. + /// This is a convenience constructor for production use; tests should + /// prefer [`Self::new`] with an explicit client. + pub fn from_env(config: WorkerConfig) -> Result { + let client = Arc::new(WorkerHttpClient::from_env( + config.orchestrator_url.clone(), + config.job_id, + )?); + + Self::new(config, client) } fn build_tools() -> Arc { @@ -177,15 +204,13 @@ impl WorkerRuntime { } let iteration_tracker = Arc::new(Mutex::new(0u32)); - let execution = match tokio::time::timeout( - self.config.timeout, - self.run_job_loop(&job, Arc::clone(&iteration_tracker)), - ) - .await + let execution = match self + .run_job_loop(&job, Arc::clone(&iteration_tracker)) + .await { - Ok(Ok(outcome)) => WorkerExecutionResult::Outcome(outcome), - Ok(Err(error)) => WorkerExecutionResult::Failed(error), - Err(_) => WorkerExecutionResult::TimedOut, + Ok(Some(outcome)) => WorkerExecutionResult::Outcome(outcome), + Ok(None) => WorkerExecutionResult::TimedOut, + Err(error) => WorkerExecutionResult::Failed(error), }; let iterations = *iteration_tracker.lock().await; @@ -212,69 +237,21 @@ impl WorkerRuntime { Ok(()) } - async fn fail_pre_loop(&self, stage: &str, error: WorkerError) -> Result { - tracing::error!( - job_id = %self.config.job_id, - stage, - error = %error, - "Worker failed before the execution loop started" - ); - - if let Err(report_error) = self - .report_worker_status( - WorkerState::Failed, - Some("pre-loop failure".to_string()), - 100, - ) - .await - { - tracing::warn!( - job_id = %self.config.job_id, - stage, - error = %report_error, - "Failed to emit terminal pre-loop worker status" - ); - } - - if let Err(report_error) = self.report_failure(0, "Worker failed during startup").await { - tracing::warn!( - job_id = %self.config.job_id, - stage, - error = %report_error, - "Failed to emit terminal pre-loop completion" - ); - } - - Err(error) - } - - async fn report_worker_status( - &self, - state: WorkerState, - message: Option, - iteration: u32, - ) -> Result<(), WorkerError> { - self.client - .report_status(&StatusUpdate::new(state, message, iteration)) - .await - } - async fn run_job_loop( &self, job: &crate::worker::api::JobDescription, iteration_tracker: Arc>, - ) -> Result { + ) -> Result, crate::error::Error> { let reasoning = Reasoning::new(Arc::clone(&self.llm)); let mut reason_ctx = self.build_reasoning_context(job).await; - let delegate = ContainerDelegate { - client: Arc::clone(&self.client), - safety: Arc::clone(&self.safety), - tools: Arc::clone(&self.tools), - extra_env: Arc::clone(&self.extra_env), - last_output: Mutex::new(String::new()), + let delegate = ContainerDelegate::new( + Arc::clone(&self.client), + Arc::clone(&self.safety), + Arc::clone(&self.tools), + Arc::clone(&self.extra_env), iteration_tracker, - }; + ); let config = AgenticLoopConfig { max_iterations: self.config.max_iterations as usize, @@ -282,13 +259,22 @@ impl WorkerRuntime { max_tool_intent_nudges: 2, }; - crate::agent::agentic_loop::run_agentic_loop( - &delegate, - &reasoning, - &mut reason_ctx, - &config, + let outcome = tokio::time::timeout( + self.config.timeout, + crate::agent::agentic_loop::run_agentic_loop( + &delegate, + &reasoning, + &mut reason_ctx, + &config, + ), ) - .await + .await; + + delegate.shutdown().await; + match outcome { + Ok(result) => result.map(Some), + Err(_) => Ok(None), + } } async fn build_reasoning_context( @@ -318,91 +304,6 @@ Work independently to complete this job. Report when done."#, reason_ctx } - async fn report_completion( - &self, - execution: WorkerExecutionResult, - iterations: u32, - ) -> Result<(), WorkerError> { - match execution { - WorkerExecutionResult::Outcome(LoopOutcome::Response(output)) => { - tracing::info!("Worker completed job {} successfully", self.config.job_id); - self.post_event( - JobEventType::Result, - serde_json::json!({ - "success": true, - "message": truncate_for_preview(&output, 2000), - }), - ) - .await; - self.client - .report_complete(&CompletionReport { - success: true, - message: Some(output), - iterations, - }) - .await - } - WorkerExecutionResult::Outcome(LoopOutcome::MaxIterations) => { - let msg = format!("max iterations ({}) exceeded", self.config.max_iterations); - tracing::warn!("Worker failed for job {}: {}", self.config.job_id, msg); - self.report_failure(iterations, &format!("Execution failed: {}", msg)) - .await - } - WorkerExecutionResult::Outcome(LoopOutcome::Stopped | LoopOutcome::NeedApproval(_)) => { - tracing::info!("Worker for job {} stopped", self.config.job_id); - self.post_event( - JobEventType::Result, - serde_json::json!({ - "success": false, - "message": "Execution stopped", - "iterations": iterations, - }), - ) - .await; - self.client - .report_complete(&CompletionReport { - success: false, - message: Some("Execution stopped".to_string()), - iterations, - }) - .await - } - WorkerExecutionResult::Failed(error) => { - tracing::error!("Worker failed for job {}: {}", self.config.job_id, error); - self.report_failure(iterations, "Execution failed").await - } - WorkerExecutionResult::TimedOut => { - tracing::warn!("Worker timed out for job {}", self.config.job_id); - self.report_failure(iterations, "Execution timed out").await - } - } - } - - async fn report_failure(&self, iterations: u32, message: &str) -> Result<(), WorkerError> { - self.post_event( - JobEventType::Result, - serde_json::json!({ - "success": false, - "message": message, - }), - ) - .await; - self.client - .report_complete(&CompletionReport { - success: false, - message: Some(message.to_string()), - iterations, - }) - .await - } - - /// Post a job event to the orchestrator (fire-and-forget). - async fn post_event(&self, event_type: JobEventType, data: serde_json::Value) { - self.client - .post_event(&JobEventPayload { event_type, data }) - .await; - } - async fn register_remote_tools(&self) -> Result, WorkerError> { let remote_catalog = self.client.get_remote_tool_catalog().await?; let remote_tool_count = remote_catalog.tools.len(); diff --git a/src/worker/container/delegate.rs b/src/worker/container/delegate.rs index 0efd34b9f..3ea259d27 100644 --- a/src/worker/container/delegate.rs +++ b/src/worker/container/delegate.rs @@ -8,7 +8,7 @@ use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; -use tokio::sync::Mutex; +use tokio::sync::{Mutex, mpsc}; use super::{WorkerHttpClient, available_tool_definitions}; use crate::agent::agentic_loop::{ @@ -21,6 +21,10 @@ use crate::tools::ToolRegistry; use crate::tools::execute::{execute_tool_simple, process_tool_result}; use crate::worker::api::{JobEventPayload, JobEventType, StatusUpdate, WorkerState}; +/// Capacity for the event channel; bounds memory growth if the orchestrator +/// becomes slow or unresponsive. +const EVENT_CHANNEL_CAPACITY: usize = 256; + /// Container delegate: implements `LoopDelegate` for the Docker container context. /// /// Tools execute sequentially. Events are posted to the orchestrator via HTTP. @@ -34,13 +38,62 @@ pub(super) struct ContainerDelegate { pub(super) last_output: Mutex, /// Tracks the current iteration so `CompletionReport` can include accurate counts. pub(super) iteration_tracker: Arc>, + /// Sender for fire-and-forget event posting to the background worker. + pub(super) event_sender: mpsc::Sender, + /// Handle for the background event-posting task. + event_handle: tokio::task::JoinHandle<()>, } impl ContainerDelegate { - pub(super) async fn post_event(&self, event_type: JobEventType, data: serde_json::Value) { - self.client - .post_event(&JobEventPayload { event_type, data }) - .await; + /// Create a new [`ContainerDelegate`] with a background event-sender task. + pub(super) fn new( + client: Arc, + safety: Arc, + tools: Arc, + extra_env: Arc>, + iteration_tracker: Arc>, + ) -> Self { + let (event_sender, mut event_receiver) = + mpsc::channel::(EVENT_CHANNEL_CAPACITY); + + // Spawn background task to handle event POSTs asynchronously + let bg_client = Arc::clone(&client); + let event_handle = tokio::spawn(async move { + while let Some(payload) = event_receiver.recv().await { + if let Err(e) = bg_client.post_event(&payload).await { + tracing::warn!(error = %e, "Failed to post event"); + } + } + }); + + Self { + client, + safety, + tools, + extra_env, + last_output: Mutex::new(String::new()), + iteration_tracker, + event_sender, + event_handle, + } + } + + /// Shut down the delegate, draining any buffered events. + /// + /// Closes the event channel and awaits the background worker so + /// in-flight events are flushed before the delegate is dropped. + pub(super) async fn shutdown(self) { + drop(self.event_sender); + if let Err(e) = self.event_handle.await { + tracing::warn!(error = %e, "Event worker task panicked"); + } + } + + pub(super) fn post_event(&self, event_type: JobEventType, data: serde_json::Value) { + let payload = JobEventPayload { event_type, data }; + if let Err(e) = self.event_sender.try_send(payload) { + tracing::warn!(error = %e, "Failed to enqueue event for posting"); + } } async fn poll_and_inject_prompt(&self, reason_ctx: &mut ReasoningContext) { @@ -56,8 +109,7 @@ impl ContainerDelegate { "role": "user", "content": truncate_for_preview(&prompt.content, 2000), }), - ) - .await; + ); reason_ctx.messages.push(ChatMessage::user(&prompt.content)); } Ok(None) => {} @@ -120,8 +172,7 @@ impl NativeLoopDelegate for ContainerDelegate { "role": "assistant", "content": truncate_for_preview(text, 2000), }), - ) - .await; + ); if crate::util::llm_signals_completion(text) { let last = self.last_output.lock().await; @@ -150,8 +201,7 @@ impl NativeLoopDelegate for ContainerDelegate { "role": "assistant", "content": truncate_for_preview(text, 2000), }), - ) - .await; + ); } reason_ctx @@ -168,8 +218,7 @@ impl NativeLoopDelegate for ContainerDelegate { "tool_name": tc.name, "input": truncate_for_preview(&tc.arguments.to_string(), 500), }), - ) - .await; + ); let job_ctx = JobContext { extra_env: self.extra_env.clone(), @@ -189,8 +238,7 @@ impl NativeLoopDelegate for ContainerDelegate { "output": truncate_for_preview(&tool_result_content, 2000), "success": result.is_ok(), }), - ) - .await; + ); if let Ok(ref output) = result { *self.last_output.lock().await = output.clone(); @@ -210,8 +258,7 @@ impl NativeLoopDelegate for ContainerDelegate { "content": truncate_for_preview(text, 2000), "nudge": true, }), - ) - .await; + ); } async fn after_iteration(&self, _iteration: usize) { diff --git a/src/worker/container/reporting.rs b/src/worker/container/reporting.rs new file mode 100644 index 000000000..fa5612490 --- /dev/null +++ b/src/worker/container/reporting.rs @@ -0,0 +1,210 @@ +//! Worker completion and status reporting logic. +//! +//! This module handles all interactions with the orchestrator for status updates, +//! completion reports, and job events. It encapsulates the reporting protocol +//! and provides a clean interface for the main worker loop. + +use std::sync::Arc; + +use crate::agent::agentic_loop::{LoopOutcome, truncate_for_preview}; +use crate::error::WorkerError; +use crate::worker::api::{ + CompletionReport, JobEventPayload, JobEventType, StatusUpdate, TerminalResult, WorkerState, +}; +use crate::worker::container::WorkerRuntime; + +/// Execution result discriminator used internally by the worker loop. +pub(super) enum WorkerExecutionResult { + Outcome(LoopOutcome), + Failed(crate::error::Error), + TimedOut, +} + +impl WorkerRuntime { + /// Report a pre-loop failure to the orchestrator and return an error. + /// + /// This is called when the worker fails during initialization (e.g., fetching + /// the job description or hydrating credentials) before the main execution + /// loop starts. + pub(super) async fn fail_pre_loop( + &self, + stage: &str, + error: WorkerError, + ) -> Result { + tracing::error!( + job_id = %self.config.job_id, + stage, + error = %error, + "Worker failed before the execution loop started" + ); + + if let Err(report_error) = self + .report_worker_status(WorkerState::Failed, Some("pre-loop failure".to_string()), 0) + .await + { + tracing::warn!( + job_id = %self.config.job_id, + stage, + error = %report_error, + "Failed to emit terminal pre-loop worker status" + ); + } + + if let Err(report_error) = self.report_failure(0, "Worker failed during startup").await { + tracing::warn!( + job_id = %self.config.job_id, + stage, + error = %report_error, + "Failed to emit terminal pre-loop completion" + ); + } + + Err(error) + } + + /// Report worker status to the orchestrator. + pub(super) async fn report_worker_status( + &self, + state: WorkerState, + message: Option, + iteration: u32, + ) -> Result<(), WorkerError> { + self.client + .report_status(&StatusUpdate::new(state, message, iteration)) + .await + } + + /// Report the final completion state to the orchestrator based on execution result. + pub(super) async fn report_completion( + &self, + execution: WorkerExecutionResult, + iterations: u32, + ) -> Result<(), WorkerError> { + match execution { + WorkerExecutionResult::Outcome(LoopOutcome::Response(output)) => { + tracing::info!("Worker completed job {} successfully", self.config.job_id); + self.post_event( + JobEventType::Result, + serde_json::to_value(TerminalResult::success( + truncate_for_preview(&output, 2000), + Some(iterations), + )) + .unwrap_or_default(), + ); + self.client + .report_complete(&CompletionReport { + success: true, + message: Some(output), + iterations, + }) + .await + } + WorkerExecutionResult::Outcome(LoopOutcome::MaxIterations) => { + let msg = format!("max iterations ({}) exceeded", self.config.max_iterations); + tracing::warn!("Worker failed for job {}: {}", self.config.job_id, msg); + self.report_failure(iterations, &format!("Execution failed: {}", msg)) + .await + } + WorkerExecutionResult::Outcome(LoopOutcome::Stopped) => { + tracing::info!("Worker for job {} stopped", self.config.job_id); + self.report_stopped_outcome(iterations).await + } + WorkerExecutionResult::Outcome(LoopOutcome::NeedApproval(_)) => { + tracing::warn!( + "Worker for job {} reached unexpected NeedApproval state", + self.config.job_id + ); + self.report_stopped_outcome(iterations).await + } + WorkerExecutionResult::Failed(error) => { + tracing::error!("Worker failed for job {}: {}", self.config.job_id, error); + self.report_failure(iterations, "Execution failed").await + } + WorkerExecutionResult::TimedOut => { + tracing::warn!("Worker timed out for job {}", self.config.job_id); + self.report_failure(iterations, "Execution timed out").await + } + } + } + + /// Report a stopped outcome to the orchestrator. + /// + /// This helper is shared by the `Stopped` and `NeedApproval` arms of + /// `report_completion`, which use identical reporting behaviour. + async fn report_stopped_outcome(&self, iterations: u32) -> Result<(), WorkerError> { + self.post_event( + JobEventType::Result, + serde_json::to_value(TerminalResult::failure( + "Execution stopped", + Some(iterations), + )) + .unwrap_or_default(), + ); + self.client + .report_complete(&CompletionReport { + success: false, + message: Some("Execution stopped".to_string()), + iterations, + }) + .await + } + + /// Report a failure to the orchestrator. + pub(super) async fn report_failure( + &self, + iterations: u32, + message: &str, + ) -> Result<(), WorkerError> { + self.post_event( + JobEventType::Result, + serde_json::to_value(TerminalResult::failure(message, Some(iterations))) + .unwrap_or_default(), + ); + self.client + .report_complete(&CompletionReport { + success: false, + message: Some(message.to_string()), + iterations, + }) + .await + } + + /// Post a job event to the orchestrator (fire-and-forget). + /// + /// Spawns a background task with a bounded timeout to ensure slow event + /// endpoints cannot delay authoritative completion reports. + pub(super) fn post_event(&self, event_type: JobEventType, data: serde_json::Value) { + let client = Arc::clone(&self.client); + let job_id = self.config.job_id; + + tokio::spawn(async move { + let payload = JobEventPayload { event_type, data }; + let result = tokio::time::timeout( + std::time::Duration::from_secs(5), + client.post_event(&payload), + ) + .await; + + match result { + Ok(Ok(())) => { + tracing::debug!(job_id = %job_id, ?event_type, "Posted job event"); + } + Ok(Err(e)) => { + tracing::warn!( + job_id = %job_id, + ?event_type, + error = %e, + "Job event post failed" + ); + } + Err(_) => { + tracing::warn!( + job_id = %job_id, + ?event_type, + "Job event post timed out after 5s" + ); + } + } + }); + } +} diff --git a/src/worker/container/tests/hosted_fidelity.rs b/src/worker/container/tests/hosted_fidelity.rs index 9092c692b..68a912330 100644 --- a/src/worker/container/tests/hosted_fidelity.rs +++ b/src/worker/container/tests/hosted_fidelity.rs @@ -6,6 +6,7 @@ use std::sync::Arc; +use anyhow::Context as _; use axum::Json; use axum::extract::{Path, State}; use rstest::{fixture, rstest}; @@ -55,19 +56,19 @@ async fn remote_tool_catalog_with_complex_tool( async fn hosted_catalog_harness() -> Result> { let (base_url, server) = spawn_test_server(remote_tool_catalog_with_complex_tool).await?; - let client = Arc::new(WorkerHttpClient::new( - base_url.clone(), - Uuid::nil(), - "test".to_string(), - )); - let runtime = WorkerRuntime::from_client( + let client = Arc::new( + WorkerHttpClient::new(base_url.clone(), Uuid::nil(), "test".to_string()) + .context("test client should build")?, + ); + let runtime = WorkerRuntime::new( WorkerConfig { job_id: Uuid::nil(), orchestrator_url: base_url, ..WorkerConfig::default() }, client, - ); + ) + .context("test runtime should build")?; Ok(HostedCatalogHarness { runtime, server }) } diff --git a/src/worker/container/tests/mod.rs b/src/worker/container/tests/mod.rs index eb2bd77f6..4a2143d20 100644 --- a/src/worker/container/tests/mod.rs +++ b/src/worker/container/tests/mod.rs @@ -7,6 +7,7 @@ use super::*; mod hosted_fidelity; mod pre_loop; mod remote_tools; +mod shutdown; pub mod test_support; #[rstest] diff --git a/src/worker/container/tests/pre_loop.rs b/src/worker/container/tests/pre_loop.rs index a58231c75..d206fad19 100644 --- a/src/worker/container/tests/pre_loop.rs +++ b/src/worker/container/tests/pre_loop.rs @@ -7,9 +7,62 @@ use rstest::rstest; use uuid::Uuid; use super::test_support::{RuntimeTestState, setup_runtime_test}; -use crate::error::{Error, ToolError}; -use crate::worker::api::WorkerState; -use crate::worker::container::{WorkerError, WorkerExecutionResult}; +use crate::error::{ConfigMismatchField, Error, ToolError}; +use crate::testing::test_utils::EnvVarsGuard; +use crate::worker::api::{WorkerHttpClient, WorkerState}; +use crate::worker::container::{WorkerConfig, WorkerError, WorkerExecutionResult, WorkerRuntime}; + +/// Regression test: WorkerRuntime::new should return ConfigMismatch error +/// when config fields don't match the client. +#[rstest] +#[case(ConfigMismatchField::JobId, Uuid::new_v4(), "http://localhost:50051")] +#[case( + ConfigMismatchField::OrchestratorUrl, + Uuid::nil(), + "http://different-host:50052" +)] +fn worker_runtime_new_returns_error_on_config_mismatch( + #[case] expected_field: ConfigMismatchField, + #[case] job_id: Uuid, + #[case] orchestrator_url: &str, +) { + // Client is created with Uuid::nil() and "http://localhost:50051" + let client = Arc::new( + WorkerHttpClient::new( + "http://localhost:50051".to_string(), + Uuid::nil(), + "test".to_string(), + ) + .expect("test client should build"), + ); + + let result = WorkerRuntime::new( + WorkerConfig { + job_id, + orchestrator_url: orchestrator_url.to_string(), + ..WorkerConfig::default() + }, + client, + ); + + match result { + Err(WorkerError::ConfigMismatch { field, .. }) => { + assert_eq!( + field, expected_field, + "expected ConfigMismatch for {:?}", + expected_field + ) + } + Ok(_) => panic!( + "expected ConfigMismatch error for {:?}, got Ok", + expected_field + ), + Err(other) => panic!( + "expected ConfigMismatch error for {:?}, got {:?}", + expected_field, other + ), + } +} #[derive(Clone, Copy, Debug)] enum PreLoopFailureCase { @@ -17,6 +70,22 @@ enum PreLoopFailureCase { HydrateCredentials, } +async fn poll_result_event(state: &RuntimeTestState) -> serde_json::Value { + let mut attempts = 0; + loop { + let result_events = state.result_events.lock().await; + if !result_events.is_empty() { + return result_events[0].clone(); + } + drop(result_events); + attempts += 1; + if attempts > 50 { + panic!("expected a terminal result event within 500ms"); + } + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + } +} + async fn assert_startup_failure_completions(state: &RuntimeTestState) { let completions = state.completions.lock().await; assert_eq!( @@ -30,10 +99,9 @@ async fn assert_startup_failure_completions(state: &RuntimeTestState) { ); drop(completions); - let result_events = state.result_events.lock().await; - assert_eq!(result_events.len(), 1, "expected a terminal result event"); - assert_eq!(result_events[0]["message"], "Worker failed during startup"); - assert_eq!(result_events[0]["success"], false); + let result_event = poll_result_event(state).await; + assert_eq!(result_event["message"], "Worker failed during startup"); + assert_eq!(result_event["success"], false); } async fn assert_startup_failure(state: &RuntimeTestState) { @@ -47,7 +115,7 @@ async fn assert_startup_failure(state: &RuntimeTestState) { .first() .filter(|status| status.state == WorkerState::Failed) .expect("expected a terminal failed status update"); - assert_eq!(failed_status.iteration, 100); + assert_eq!(failed_status.iteration, 0); assert_eq!( failed_status.message.as_deref(), Some("pre-loop failure"), @@ -139,7 +207,7 @@ async fn worker_runtime_emits_failed_status_for_initial_status_rejections() -> a ); assert_eq!(statuses[0].state, WorkerState::InProgress); assert_eq!(statuses[1].state, WorkerState::Failed); - assert_eq!(statuses[1].iteration, 100); + assert_eq!(statuses[1].iteration, 0); assert_eq!( statuses[1].message.as_deref(), Some("pre-loop failure"), @@ -184,18 +252,59 @@ async fn worker_runtime_sanitizes_failure_messages( assert_eq!(completions[0].iterations, 7); drop(completions); - let result_events = state.result_events.lock().await; - assert_eq!(result_events.len(), 1); - assert_eq!(result_events[0]["message"], expected_message); - assert_eq!(result_events[0]["success"], false); + let result_event = poll_result_event(&state).await; + + assert_eq!(result_event["message"], expected_message); + assert_eq!(result_event["success"], false); assert!( - result_events[0].to_string().contains(expected_message), + result_event.to_string().contains(expected_message), "expected result payload to contain the sanitised message" ); assert!( - !result_events[0].to_string().contains("secret-123"), + !result_event.to_string().contains("secret-123"), "result payload should not leak the detailed error text" ); Ok(()) } + +#[test] +fn worker_runtime_from_env_reads_worker_token() { + let mut env = EnvVarsGuard::new(&["IRONCLAW_WORKER_TOKEN"]); + env.set("IRONCLAW_WORKER_TOKEN", "token-from-env"); + + let runtime = WorkerRuntime::from_env(WorkerConfig { + job_id: Uuid::new_v4(), + orchestrator_url: "http://localhost:50051/".to_string(), + ..WorkerConfig::default() + }) + .expect("from_env should succeed when the worker token is present"); + + assert_eq!( + runtime.client.orchestrator_url(), + "http://localhost:50051", + "from_env should preserve the client URL normalization rules" + ); + assert_eq!( + runtime.config.max_iterations, + WorkerConfig::default().max_iterations + ); + assert_eq!(runtime.config.timeout, WorkerConfig::default().timeout); +} + +#[test] +fn worker_runtime_from_env_returns_missing_token_without_worker_env() { + let mut env = EnvVarsGuard::new(&["IRONCLAW_WORKER_TOKEN"]); + env.remove("IRONCLAW_WORKER_TOKEN"); + + let result = WorkerRuntime::from_env(WorkerConfig { + job_id: Uuid::new_v4(), + orchestrator_url: "http://localhost:50051".to_string(), + ..WorkerConfig::default() + }); + + assert!( + matches!(result, Err(WorkerError::MissingToken)), + "expected MissingToken when IRONCLAW_WORKER_TOKEN is absent" + ); +} diff --git a/src/worker/container/tests/remote_tools.rs b/src/worker/container/tests/remote_tools.rs index 7e2b98a32..0ee0a0559 100644 --- a/src/worker/container/tests/remote_tools.rs +++ b/src/worker/container/tests/remote_tools.rs @@ -2,6 +2,7 @@ use std::sync::Arc; +use anyhow::Context; use axum::extract::{Path, State}; use axum::routing::get; use axum::{Json, Router}; @@ -81,7 +82,7 @@ pub(super) struct TestState; /// handle for the background server task. pub(super) async fn spawn_test_server( handler: H, -) -> Result<(String, tokio::task::JoinHandle<()>), Box> +) -> Result<(String, tokio::task::JoinHandle<()>), anyhow::Error> where H: axum::handler::Handler + Clone + Send + 'static, T: 'static, @@ -100,26 +101,26 @@ where } async fn spawn_hosted_guidance_catalog_server() --> Result<(String, tokio::task::JoinHandle<()>), Box> { +-> Result<(String, tokio::task::JoinHandle<()>), anyhow::Error> { spawn_test_server(remote_tool_catalog).await } async fn build_runtime_with_remote_tools( base_url: &str, -) -> Result<(WorkerRuntime, Arc), Box> { +) -> Result<(WorkerRuntime, Arc), anyhow::Error> { let client = Arc::new(WorkerHttpClient::new( base_url.to_string(), Uuid::nil(), "test".to_string(), - )); - let mut runtime = WorkerRuntime::from_client( + )?); + let mut runtime = WorkerRuntime::new( WorkerConfig { job_id: Uuid::nil(), orchestrator_url: base_url.to_string(), ..WorkerConfig::default() }, Arc::clone(&client), - ); + )?; runtime.toolset_instructions = runtime.register_remote_tools().await?; Ok((runtime, client)) } @@ -129,39 +130,17 @@ async fn build_runtime_with_remote_tools( async fn hosted_worker_remote_tool_catalog_registers_remote_tools() -> Result<(), Box> { let (base_url, server) = spawn_hosted_guidance_catalog_server().await?; + let (runtime, _client) = build_runtime_with_remote_tools(&base_url).await?; - let client = Arc::new(WorkerHttpClient::new( - base_url.clone(), - Uuid::nil(), - "test".to_string(), - )); - let runtime = WorkerRuntime::from_client( - WorkerConfig { - job_id: Uuid::nil(), - orchestrator_url: base_url, - ..WorkerConfig::default() - }, - client, - ); - - runtime.register_remote_tools().await?; - - let mut names: Vec = runtime - .tools - .tool_definitions() - .await - .into_iter() - .map(|def| def.name) - .collect(); + let definitions: Vec = runtime.tools.tool_definitions().await; + let mut names: Vec = definitions.into_iter().map(|def| def.name).collect(); names.sort(); assert_eq!(names, expected_merged_tool_names()); - let remote_tool = runtime - .tools - .get("hosted_worker_remote_tool_fixture") - .await - .expect("hosted remote tool should be registered"); + let remote_tool: Option> = + runtime.tools.get("hosted_worker_remote_tool_fixture").await; + let remote_tool = remote_tool.expect("hosted remote tool should be registered"); let expected = expected_remote_tool_definition(); assert_eq!(remote_tool.name(), expected.name); assert_eq!(remote_tool.description(), expected.description); @@ -248,14 +227,13 @@ async fn worker_runtime_refresh_keeps_merged_tools_without_duplicate_guidance() "expected one guidance message before refresh" ); - let delegate = ContainerDelegate { + let delegate = ContainerDelegate::new( client, - safety: Arc::clone(&runtime.safety), - tools: Arc::clone(&runtime.tools), - extra_env: Arc::clone(&runtime.extra_env), - last_output: Mutex::new(String::new()), - iteration_tracker: Arc::new(Mutex::new(0)), - }; + Arc::clone(&runtime.safety), + Arc::clone(&runtime.tools), + Arc::clone(&runtime.extra_env), + Arc::new(Mutex::new(0)), + ); let outcome = delegate.before_llm_call(&mut reason_ctx, 1).await; assert!( @@ -290,21 +268,23 @@ async fn worker_runtime_refresh_keeps_merged_tools_without_duplicate_guidance() #[tokio::test] async fn hosted_worker_remote_tool_catalog_degraded_startup_keeps_local_tools() -> Result<(), Box> { - let (base_url, server) = spawn_test_server(remote_tool_catalog_error).await?; + let (base_url, server) = spawn_test_server(remote_tool_catalog_error) + .await + .context("spawning test server in hosted_worker_remote_tool_catalog_degraded_startup_keeps_local_tools")?; - let client = Arc::new(WorkerHttpClient::new( - base_url.clone(), - Uuid::nil(), - "test".to_string(), - )); - let runtime = WorkerRuntime::from_client( + let client = Arc::new( + WorkerHttpClient::new(base_url.clone(), Uuid::nil(), "test".to_string()) + .context("building test WorkerHttpClient")?, + ); + let runtime = WorkerRuntime::new( WorkerConfig { job_id: Uuid::nil(), orchestrator_url: base_url, ..WorkerConfig::default() }, client, - ); + ) + .context("building WorkerRuntime")?; runtime.register_remote_tools_with_degraded_startup().await; diff --git a/src/worker/container/tests/shutdown.rs b/src/worker/container/tests/shutdown.rs new file mode 100644 index 000000000..a33a7545a --- /dev/null +++ b/src/worker/container/tests/shutdown.rs @@ -0,0 +1,92 @@ +//! Tests for container delegate shutdown behaviour. + +use std::sync::Arc; + +use anyhow::Result; +use axum::extract::{Path, State}; +use axum::http::StatusCode; +use axum::routing::post; +use axum::{Json, Router}; +use tokio::net::TcpListener; +use tokio::sync::{Mutex, Notify}; +use uuid::Uuid; + +use crate::worker::api::{EVENT_ROUTE, JobEventPayload, WorkerHttpClient}; +use crate::worker::container::delegate::ContainerDelegate; + +use super::test_support::build_test_runtime; + +#[derive(Default)] +struct EventState { + events: Mutex>, + notify: Notify, +} + +async fn event_handler( + State(state): State>, + Path(_job_id): Path, + Json(payload): Json, +) -> StatusCode { + state.events.lock().await.push(payload); + state.notify.notify_waiters(); + StatusCode::OK +} + +async fn spawn_event_server( + state: Arc, +) -> Result<(String, tokio::task::JoinHandle<()>)> { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + let app = Router::new() + .route(EVENT_ROUTE, post(event_handler)) + .with_state(state); + let handle = tokio::spawn(async move { + axum::serve(listener, app) + .await + .expect("event test server should run"); + }); + Ok((format!("http://{addr}"), handle)) +} + +#[tokio::test] +async fn container_delegate_shutdown_drains_buffered_events() -> Result<()> { + let state = Arc::new(EventState::default()); + let (base_url, handle) = spawn_event_server(Arc::clone(&state)).await?; + let runtime = build_test_runtime(base_url.clone(), Uuid::nil())?; + let client = Arc::new(WorkerHttpClient::new( + base_url, + Uuid::nil(), + "test-token".to_string(), + )?); + + let delegate = ContainerDelegate::new( + client, + Arc::clone(&runtime.safety), + Arc::clone(&runtime.tools), + Arc::clone(&runtime.extra_env), + Arc::new(Mutex::new(0)), + ); + + delegate.post_event( + crate::worker::api::JobEventType::Message, + serde_json::json!({"content": "queued"}), + ); + delegate.shutdown().await; + + tokio::time::timeout(std::time::Duration::from_secs(1), async { + loop { + if state.events.lock().await.len() == 1 { + break; + } + state.notify.notified().await; + } + }) + .await?; + let events = state.events.lock().await; + assert_eq!(events.len(), 1, "shutdown should flush the queued event"); + assert_eq!(events[0].data["content"], "queued"); + + handle.abort(); + let _ = handle.await; + Ok(()) +} diff --git a/src/worker/container/tests/test_support.rs b/src/worker/container/tests/test_support.rs index 1e8df3e29..2ebdab528 100644 --- a/src/worker/container/tests/test_support.rs +++ b/src/worker/container/tests/test_support.rs @@ -13,7 +13,10 @@ use tokio::sync::Mutex; use tokio::sync::oneshot; use uuid::Uuid; -use crate::worker::api::{CompletionReport, CredentialResponse, JobDescription, StatusUpdate}; +use crate::worker::api::{ + COMPLETE_ROUTE, CREDENTIALS_ROUTE, CompletionReport, CredentialResponse, EVENT_ROUTE, + JOB_ROUTE, JobDescription, PROMPT_ROUTE, STATUS_ROUTE, StatusUpdate, +}; use crate::worker::container::{WorkerConfig, WorkerHttpClient, WorkerRuntime}; /// Shared state for recording HTTP interactions from the worker runtime during tests. @@ -110,12 +113,12 @@ pub async fn spawn_runtime_test_server( let addr = listener.local_addr()?; let app = Router::new() - .route("/worker/{job_id}/job", get(job_handler)) - .route("/worker/{job_id}/credentials", get(credentials_handler)) - .route("/worker/{job_id}/prompt", get(prompt_handler)) - .route("/worker/{job_id}/status", post(status_handler)) - .route("/worker/{job_id}/complete", post(complete_handler)) - .route("/worker/{job_id}/event", post(event_handler)) + .route(JOB_ROUTE, get(job_handler)) + .route(CREDENTIALS_ROUTE, get(credentials_handler)) + .route(PROMPT_ROUTE, get(prompt_handler)) + .route(STATUS_ROUTE, post(status_handler)) + .route(COMPLETE_ROUTE, post(complete_handler)) + .route(EVENT_ROUTE, post(event_handler)) .with_state(state); let (shutdown_tx, shutdown_rx) = oneshot::channel(); @@ -133,20 +136,23 @@ pub async fn spawn_runtime_test_server( /// orchestrator URL and job ID. /// /// Uses a fixed test token (`"test-token"`) and default configuration suitable for unit tests. -pub fn build_test_runtime(orchestrator_url: String, job_id: Uuid) -> WorkerRuntime { +pub fn build_test_runtime( + orchestrator_url: String, + job_id: Uuid, +) -> Result { let client = Arc::new(WorkerHttpClient::new( orchestrator_url.clone(), job_id, "test-token".to_string(), - )); - WorkerRuntime::from_client( + )?); + Ok(WorkerRuntime::new( WorkerConfig { job_id, orchestrator_url, ..WorkerConfig::default() }, client, - ) + )?) } /// Test harness that owns the `WorkerRuntime` under test and coordinates graceful shutdown @@ -209,10 +215,10 @@ impl Drop for RuntimeTestHarness { pub async fn setup_runtime_test( state: Arc, job_id: Uuid, -) -> std::io::Result { +) -> anyhow::Result { let (orchestrator_url, shutdown_tx, handle) = spawn_runtime_test_server(Arc::clone(&state)).await?; - let runtime = build_test_runtime(orchestrator_url, job_id); + let runtime = build_test_runtime(orchestrator_url, job_id)?; Ok(RuntimeTestHarness { runtime: Some(runtime), shutdown_tx: Some(shutdown_tx), diff --git a/src/worker/mod.rs b/src/worker/mod.rs index e815496f5..abed8b202 100644 --- a/src/worker/mod.rs +++ b/src/worker/mod.rs @@ -38,6 +38,8 @@ pub use container::WorkerRuntime; pub use job::{Worker, WorkerDeps}; pub use proxy_llm::ProxyLlmProvider; +use anyhow::Context as _; + /// Run the Worker subcommand (inside Docker containers). pub async fn run_worker( job_id: uuid::Uuid, @@ -57,12 +59,9 @@ pub async fn run_worker( timeout: std::time::Duration::from_secs(600), }; - let rt = - WorkerRuntime::new(config).map_err(|e| anyhow::anyhow!("Worker init failed: {}", e))?; + let rt = WorkerRuntime::from_env(config).context("Worker init failed")?; - rt.run() - .await - .map_err(|e| anyhow::anyhow!("Worker failed: {}", e)) + rt.run().await.context("Worker failed") } /// Run the Claude Code bridge subcommand (inside Docker containers). @@ -88,10 +87,7 @@ pub async fn run_claude_bridge( allowed_tools: crate::config::ClaudeCodeConfig::from_env().allowed_tools, }; - let rt = ClaudeBridgeRuntime::new(config) - .map_err(|e| anyhow::anyhow!("Claude bridge init failed: {}", e))?; + let rt = ClaudeBridgeRuntime::new(config).context("Claude bridge init failed")?; - rt.run() - .await - .map_err(|e| anyhow::anyhow!("Claude bridge failed: {}", e)) + rt.run().await.context("Claude bridge failed") } diff --git a/src/worker/proxy_llm.rs b/src/worker/proxy_llm.rs index e853b7201..f1ff298c9 100644 --- a/src/worker/proxy_llm.rs +++ b/src/worker/proxy_llm.rs @@ -67,25 +67,26 @@ impl NativeLlmProvider for ProxyLlmProvider { mod tests { use super::*; + fn test_client() -> Arc { + Arc::new( + WorkerHttpClient::new( + "http://localhost:50051".to_string(), + uuid::Uuid::nil(), + "test".to_string(), + ) + .expect("test client should build"), + ) + } + #[test] fn test_proxy_model_name() { - let client = Arc::new(WorkerHttpClient::new( - "http://localhost:50051".to_string(), - uuid::Uuid::nil(), - "test".to_string(), - )); - let provider = ProxyLlmProvider::new(client, "test-model".to_string()); + let provider = ProxyLlmProvider::new(test_client(), "test-model".to_string()); assert_eq!(provider.model_name(), "test-model"); } #[test] fn test_proxy_cost_is_zero() { - let client = Arc::new(WorkerHttpClient::new( - "http://localhost:50051".to_string(), - uuid::Uuid::nil(), - "test".to_string(), - )); - let provider = ProxyLlmProvider::new(client, "test-model".to_string()); + let provider = ProxyLlmProvider::new(test_client(), "test-model".to_string()); let (input, output) = provider.cost_per_token(); assert_eq!(input, Decimal::ZERO); assert_eq!(output, Decimal::ZERO);