diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 0c584c1..47c021c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -4,7 +4,7 @@ on: workflow_dispatch: inputs: version: - description: 'Version to release (e.g., 0.3.0)' + description: 'Version to release (e.g., 0.5.0)' required: true type: string dry_run: @@ -77,7 +77,7 @@ jobs: - name: Validate version format run: | if ! [[ "${{ inputs.version }}" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then - echo "Error: Version must be in format X.Y.Z (e.g., 0.3.0)" + echo "Error: Version must be in format X.Y.Z (e.g., 0.5.0)" exit 1 fi diff --git a/CHANGELOG.md b/CHANGELOG.md index 887a06e..b35915c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.5.0] - 2026-05-08 + +### Added +- `AppRunner::create_iterative_session(...)` and `AppRunner::invoke_next(...)` for repeated graph invocations under one durable session lineage. +- `RunnerError::InvalidIterativeEntry` for invalid iterative entry nodes. +- Typed state-slot helpers: `StateKey`, `StateSnapshot::get_typed(...)`, `StateSnapshot::require_typed(...)`, `VersionedState::add_typed_extra(...)`, `VersionedStateBuilder::with_typed_extra(...)`, and `NodePartial::with_typed_extra(...)`. +- Runtime clock injection through `RuntimeConfig::with_clock(...)`, `AppRunnerBuilder::clock(...)`, and `NodeContext::now_unix_ms()`. +- Optional node event metadata for `invocation_id` and `now_unix_ms` when runtime metadata is configured. +- `INVOCATION_END_SCOPE` and `AppRunner::finish_iterative_session(...)` for long-lived iterative event streams. +- Graph and run metadata helpers: `App::graph_metadata()`, `App::graph_definition_hash()`, `RuntimeConfig::config_hash()`, and `AppRunner::run_metadata()`. +- `Reducer::definition_label(...)` so graph metadata can distinguish reducer implementations, not only reducer counts. +- Replay conformance helpers in `weavegraph::runtimes::replay` for normalized event comparison, final-state comparison, and reusable replay assertions. + +### Notes +- This feedback package ships as `0.5.0` rather than `0.4.1` because it changes the public runtime surface, adds public error enum variants/types, and extends public structs. +- New public metadata/context structs are marked `#[non_exhaustive]` where they are expected to grow before v1. + ## [0.4.0] - 2026-04-01 ### Added diff --git a/Cargo.toml b/Cargo.toml index d0b8a1e..110f541 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "weavegraph" -version = "0.4.0" +version = "0.5.0" edition = "2024" description = "Graph-driven, concurrent agent workflow framework with versioned state, deterministic barrier merges, and rich diagnostics." license = "MIT" diff --git a/README.md b/README.md index c36ff29..111126e 100644 --- a/README.md +++ b/README.md @@ -21,10 +21,13 @@ Weavegraph lets you build robust, concurrent, stateful workflows using a graph-b - Concurrent graph execution with dependency resolution - Type-safe, role-based message system - Versioned state with snapshot isolation +- Typed state slots for schema-versioned JSON payloads - Structured error handling and diagnostics - Built-in event streaming and observability -- Flexible persistence: SQLite or in-memory +- Flexible persistence: SQLite, PostgreSQL, or in-memory - Conditional routing and dynamic edges +- Iterative checkpointed sessions for repeated invocations +- Replay conformance helpers and deterministic graph/run metadata - Ergonomic APIs and comprehensive examples ## Install @@ -33,10 +36,10 @@ Add to your `Cargo.toml`: ```toml [dependencies] -weavegraph = "0.3" +weavegraph = "0.5" ``` -> **Note:** Examples and instructions in this README are current as of 0.3.x. For upgrading from 0.2.x, see [MIGRATION.md](docs/MIGRATION.md). +> **Note:** Examples and instructions in this README are current as of 0.5.x. For upgrade notes across pre-1.0 releases, see [MIGRATION.md](docs/MIGRATION.md). ## Dependency Compatibility @@ -64,10 +67,10 @@ See [Cargo.toml](Cargo.toml) for complete dependency versions and feature config ```rust use weavegraph::{ - graphs::GraphBuilder, - message::Message, - node::{Node, NodeContext, NodePartial}, - state::VersionedState, + graphs::GraphBuilder, + message::Message, + node::{Node, NodeContext, NodePartial}, + state::VersionedState, }; use async_trait::async_trait; @@ -75,29 +78,29 @@ struct HelloNode; #[async_trait] impl Node for HelloNode { - async fn run( - &self, - _snapshot: weavegraph::state::StateSnapshot, - _ctx: NodeContext, - ) -> Result { - Ok(NodePartial::new().with_messages(vec![Message::assistant("Hello, world!")])) - } + async fn run( + &self, + _snapshot: weavegraph::state::StateSnapshot, + _ctx: NodeContext, + ) -> Result { + Ok(NodePartial::new().with_messages(vec![Message::assistant("Hello, world!")])) + } } #[tokio::main] async fn main() -> Result<(), Box> { - use weavegraph::types::NodeKind; - let app = GraphBuilder::new() - .add_node(NodeKind::Custom("hello".into()), HelloNode) - .add_edge(NodeKind::Start, NodeKind::Custom("hello".into())) - .add_edge(NodeKind::Custom("hello".into()), NodeKind::End) - .compile()?; - let state = VersionedState::new_with_user_message("Hi!"); - let result = app.invoke(state).await?; - for message in result.messages.snapshot() { - println!("{}: {}", message.role, message.content); - } - Ok(()) + use weavegraph::types::NodeKind; + let app = GraphBuilder::new() + .add_node(NodeKind::Custom("hello".into()), HelloNode) + .add_edge(NodeKind::Start, NodeKind::Custom("hello".into())) + .add_edge(NodeKind::Custom("hello".into()), NodeKind::End) + .compile()?; + let state = VersionedState::new_with_user_message("Hi!"); + let result = app.invoke(state).await?; + for message in result.messages.snapshot() { + println!("{}: {}", message.role, message.content); + } + Ok(()) } ``` > NOTE: `NodeKind::Start` and `NodeKind::End` are virtual structural endpoints. @@ -109,27 +112,36 @@ async fn main() -> Result<(), Box> { For testing and ephemeral workflows use the InMemory checkpointer: ```rust +use weavegraph::runtimes::{AppRunner, CheckpointerType}; + // After compiling the graph into an `App`: let runner = AppRunner::builder() - .app(app) - .checkpointer(CheckpointerType::InMemory) - .build() - .await; + .app(app) + .checkpointer(CheckpointerType::InMemory) + .build() + .await; ``` Run the comprehensive test suite: ```bash -# All tests with output -cargo test --all -- --nocapture +# Full integration test suite +cargo nextest run + +# Documentation examples +cargo test --doc + +# Lints used by CI +cargo clippy --all-features --all-targets -- -D warnings +cargo clippy --no-default-features --lib -- -D warnings # Specific test categories -cargo test schedulers:: -- --nocapture -cargo test channels:: -- --nocapture -cargo test integration:: -- --nocapture +cargo test --test schedulers +cargo test --test event_bus +cargo test --test runtimes_runner ``` -Property-based testing with `proptest` ensures correctness across edge cases. +Property-based testing with `proptest` and fuzz harnesses under [fuzz/](fuzz/) exercise edge cases across graph routing, event serialization, replay comparison, and typed state slots. ## CI Parity @@ -151,7 +163,7 @@ Before merging or cutting a release, run full local parity checks: ## Resources -- **[Migration Guide](docs/MIGRATION.md)** - Upgrade paths between releases (0.2.x → 0.3.x and beyond) +- **[Migration Guide](docs/MIGRATION.md)** - Upgrade paths between pre-1.0 releases - **[Architecture Guide](docs/ARCHITECTURE.md)** - Deep dive into core design and internals - **[Examples Directory](examples/)** - Runnable patterns: graph execution, scheduling, streaming, persistence, and more diff --git a/docs/MIGRATION.md b/docs/MIGRATION.md index 94dc828..86964c2 100644 --- a/docs/MIGRATION.md +++ b/docs/MIGRATION.md @@ -5,6 +5,116 @@ migration guidance for upgrading your code. --- +## v0.5.0 + +### Overview + +v0.5.0 is the recommended target for the WeaveQuant production feedback work. The changes add new public runtime APIs and a public `RunnerError` variant, so they should not ship as a `0.4.1` patch. + +### New Runtime APIs + +Use `AppRunner::create_iterative_session(...)` and `AppRunner::invoke_next(...)` when one durable session should process many logical inputs: + +```rust +runner + .create_iterative_session(run_id.clone(), initial_state, NodeKind::Start) + .await?; + +runner + .invoke_next(&run_id, input_patch, NodeKind::Start) + .await?; +``` + +`NodeKind::Start` resolves to the graph's normal Start outgoing frontier. A registered custom node can be supplied for narrower re-entry. `NodeKind::End` now returns `RunnerError::InvalidIterativeEntry` when used as an iterative entry. + +When an `AppRunner` event stream is subscribed before iterative execution, each `invoke_next(...)` emits `INVOCATION_END_SCOPE` and keeps the stream open for the next logical input. Call `finish_iterative_session(...)` after the final input to emit the normal `STREAM_END_SCOPE` sentinel and close the stream. + +### Typed State Slots + +Typed state slots are a thin, JSON-compatible layer over `VersionedState.extra`. Define a reusable key in the domain crate, then read and write typed payloads without hand-rolled `serde_json` calls at every node boundary: + +```rust +use serde::{Deserialize, Serialize}; +use weavegraph::node::NodePartial; +use weavegraph::state::{StateKey, StateSnapshot}; + +#[derive(Serialize, Deserialize)] +struct PortfolioState { + cash_cents: i64, +} + +const PORTFOLIO: StateKey = StateKey::new("wq", "portfolio", 1); + +fn read(snapshot: &StateSnapshot) -> Result, weavegraph::state::StateSlotError> { + snapshot.get_typed(PORTFOLIO) +} + +fn write(value: PortfolioState) -> Result { + NodePartial::new().with_typed_extra(PORTFOLIO, value) +} +``` + +The storage key is namespaced and versioned as `namespace:name:v{schema_version}`. Untyped `extra` remains available. + +### Deterministic Runtime Clock + +Use the existing `Clock` abstraction to inject deterministic time into nodes and emitted node-event metadata: + +```rust +use std::sync::Arc; +use weavegraph::runtimes::{AppRunner, CheckpointerType}; +use weavegraph::utils::clock::MockClock; + +let runner = AppRunner::builder() + .app(app) + .checkpointer(CheckpointerType::InMemory) + .clock(Arc::new(MockClock::new(1_700_000_000))) + .build() + .await; +``` + +Inside a node, call `ctx.now_unix_ms()` and `ctx.invocation_id()`. `NodeContext::new(...)` is now the easiest way to construct contexts in tests. + +### Metadata Helpers + +Compiled graphs and runners expose deterministic metadata helpers for audit labels and replay manifests: + +```rust +let graph = app.graph_metadata(); +let graph_hash = app.graph_definition_hash(); +let run = runner.run_metadata(); +``` + +The graph hash includes node kinds, edges, conditional edge registrations, and reducer definition labels. It does not inspect closure bodies for conditional predicates. Custom reducers can override `Reducer::definition_label(...)` when a durable audit label is preferable to the default Rust type path. + +### Replay Conformance Helpers + +Replay helpers live under `weavegraph::runtimes::replay` and are re-exported from `weavegraph::runtimes`: + +```rust +use weavegraph::runtimes::{ReplayRun, compare_replay_runs}; + +let expected = ReplayRun::new(expected_state, expected_events); +let actual = ReplayRun::new(actual_state, actual_events); + +compare_replay_runs(&expected, &actual).assert_matches()?; +``` + +`normalize_event(...)` strips runtime timestamps. Use `compare_event_sequences_with(...)` or `compare_replay_runs_with(...)` when domain events need semantic normalization. + +### Compatibility Notes + +- `App::invoke(...)`, `AppRunner::create_session(...)`, and `AppRunner::run_until_complete(...)` keep their existing behavior. +- `RunnerError` is an exhaustive public enum. Code that matches every variant must handle `InvalidIterativeEntry` after upgrading. +- `GraphMetadata`, `RunMetadata`, `ReplayRun`, `NodeContext`, and `SchedulerRunContext` are `#[non_exhaustive]`; use provided constructors/builders instead of external struct literals. +- `Reducer` gains a default `definition_label(...)` method for graph metadata. Existing reducer implementations do not need to change unless they want a custom stable label. +- `RuntimeConfig` gains a public `clock` field. Code using struct literals should add `clock: None` or switch to `RuntimeConfig::default()` / builder-style methods. +- `NodeContext` gains `clock` and `invocation_id` fields. Tests should prefer `NodeContext::new(...)` over struct literals. +- Direct calls to `Scheduler::superstep(...)` must pass the optional clock and invocation ID arguments. +- Iterative sessions keep step numbers monotonic across invocations and reload checkpoints through the existing checkpointer path. + +--- + ## v0.4.0 ### Overview diff --git a/docs/OPERATIONS.md b/docs/OPERATIONS.md index 3988d7f..0679201 100644 --- a/docs/OPERATIONS.md +++ b/docs/OPERATIONS.md @@ -157,6 +157,116 @@ let runner = AppRunner::builder() .await?; ``` +### Iterative Checkpointed Workflows + +Use iterative sessions when one logical run should process many inputs while keeping one checkpoint lineage. This is useful for event-driven systems that repeatedly restore the latest durable state, apply the next input, run the graph, and checkpoint the result. + +```rust +use weavegraph::node::NodePartial; +use weavegraph::runtimes::{AppRunner, CheckpointerType}; +use weavegraph::state::VersionedState; +use weavegraph::types::NodeKind; +use weavegraph::utils::collections::new_extra_map; + +# async fn example(app: weavegraph::app::App) -> Result<(), Box> { +let mut runner = AppRunner::builder() + .app(app) + .checkpointer(CheckpointerType::SQLite) + .autosave(true) + .build() + .await; + +let run_id = "market-run-2026-05-08".to_string(); +runner + .create_iterative_session( + run_id.clone(), + VersionedState::new_with_user_message("start"), + NodeKind::Start, + ) + .await?; + +for tick in [1, 2, 3] { + let mut extra = new_extra_map(); + extra.insert("tick".to_string(), serde_json::json!(tick)); + + runner + .invoke_next(&run_id, NodePartial::new().with_extra(extra), NodeKind::Start) + .await?; +} +# Ok(()) +# } +``` + +`NodeKind::Start` means the same initial frontier as a normal session: the graph's outgoing edges from the virtual Start node. A registered custom node can be used to resume from a narrower entry point. `NodeKind::End` is rejected because it is terminal. + +The runner keeps `SessionState.step` monotonic across invocations. It also clears scheduler version-gating state for each `invoke_next` call, so the entry path runs for each logical input even when two consecutive input patches are identical. + +If you subscribe with `AppRunner::event_stream()` before an iterative run, each `invoke_next(...)` emits `INVOCATION_END_SCOPE` and leaves the stream open for the next input. After the final input, call `finish_iterative_session(...)` to emit `STREAM_END_SCOPE` and close the stream for consumers that expect the standard terminal sentinel. + +### Typed State Slots + +Use `StateKey` when checkpointed `extra` state needs a documented schema and compile-time payload type while staying JSON-compatible across backends. + +```rust +use serde::{Deserialize, Serialize}; +use weavegraph::node::NodePartial; +use weavegraph::state::{StateKey, StateSnapshot}; + +#[derive(Serialize, Deserialize)] +struct PortfolioState { + cash_cents: i64, +} + +const PORTFOLIO: StateKey = StateKey::new("wq", "portfolio", 1); + +fn load(snapshot: &StateSnapshot) -> Result { + snapshot.require_typed(PORTFOLIO) +} + +fn store(value: PortfolioState) -> Result { + NodePartial::new().with_typed_extra(PORTFOLIO, value) +} +``` + +The generated storage key is `namespace:name:v{schema_version}`, so old and new schemas can coexist during migrations. + +### Deterministic Clock And Run Metadata + +Inject a clock when simulations, replay, or tests need logical time to be independent of wall-clock time. The same clock is available from `NodeContext::now_unix_ms()` and is attached to node event metadata when present. + +```rust +use std::sync::Arc; +use weavegraph::runtimes::{AppRunner, CheckpointerType}; +use weavegraph::utils::clock::MockClock; + +let runner = AppRunner::builder() + .app(app) + .checkpointer(CheckpointerType::InMemory) + .clock(Arc::new(MockClock::new(1_700_000_000))) + .build() + .await; + +let metadata = runner.run_metadata(); +println!("graph={} runtime={} clock={}", metadata.graph_hash, metadata.runtime_config_hash, metadata.clock_mode); +``` + +`App::graph_metadata()` and `App::graph_definition_hash()` are useful for replay manifests and checkpoint labels. The graph hash covers the graph definition surface, including node kinds, edges, conditional edge registrations, and reducer definition labels. Custom reducers can override `Reducer::definition_label(...)` when a stable domain label is more appropriate than the default Rust type path. + +### Replay Conformance Checks + +Replay helpers compare normalized events and final state snapshots for uninterrupted/resumed run parity. + +```rust +use weavegraph::runtimes::{ReplayRun, compare_replay_runs}; + +let expected = ReplayRun::new(expected_state, expected_events); +let actual = ReplayRun::new(actual_state, actual_events); + +compare_replay_runs(&expected, &actual).assert_matches()?; +``` + +Use `compare_event_sequences_with(...)` or `compare_replay_runs_with(...)` when domain events need custom normalization before comparison. + ### Storage Management **InMemoryCheckpointer** stores only the latest checkpoint per session (automatic retention). No storage management needed. diff --git a/examples/advanced_patterns.rs b/examples/advanced_patterns.rs index d36dae7..34ecb26 100644 --- a/examples/advanced_patterns.rs +++ b/examples/advanced_patterns.rs @@ -399,11 +399,7 @@ async fn main() -> ExampleResult<()> { let emitter = event_bus.get_emitter(); - let ctx1 = NodeContext { - node_id: "api_call".to_string(), - step: 1, - event_emitter: Arc::clone(&emitter), - }; + let ctx1 = NodeContext::new("api_call", 1, Arc::clone(&emitter)); // Demonstrate both success and failure scenarios match api_node.run(state.snapshot(), ctx1).await { @@ -449,11 +445,7 @@ async fn main() -> ExampleResult<()> { max_retries: 2, }; - let ctx1_1 = NodeContext { - node_id: "metrics_api".to_string(), - step: 1, - event_emitter: Arc::clone(&emitter), - }; + let ctx1_1 = NodeContext::new("metrics_api", 1, Arc::clone(&emitter)); match failing_api_node.run(state.snapshot(), ctx1_1).await { Ok(result) => { @@ -487,11 +479,7 @@ async fn main() -> ExampleResult<()> { }, }; - let ctx2 = NodeContext { - node_id: "router".to_string(), - step: 2, - event_emitter: Arc::clone(&emitter), - }; + let ctx2 = NodeContext::new("router", 2, Arc::clone(&emitter)); let result2 = router_node.run(state.snapshot(), ctx2).await?; if let Some(messages) = result2.messages { @@ -529,11 +517,7 @@ async fn main() -> ExampleResult<()> { ], }; - let ctx3 = NodeContext { - node_id: "transformer".to_string(), - step: 3, - event_emitter: Arc::clone(&emitter), - }; + let ctx3 = NodeContext::new("transformer", 3, Arc::clone(&emitter)); let result3 = transformer_node.run(state.snapshot(), ctx3).await?; if let Some(messages) = result3.messages { diff --git a/examples/basic_nodes.rs b/examples/basic_nodes.rs index 105708e..f45244b 100644 --- a/examples/basic_nodes.rs +++ b/examples/basic_nodes.rs @@ -261,11 +261,7 @@ async fn main() -> ExampleResult<()> { let emitter = event_bus.get_emitter(); - let ctx1 = NodeContext { - node_id: "counter-1".to_string(), - step: 2, - event_emitter: Arc::clone(&emitter), - }; + let ctx1 = NodeContext::new("counter-1", 2, Arc::clone(&emitter)); let result1 = counter_node.run(state.snapshot(), ctx1).await?; @@ -290,11 +286,7 @@ async fn main() -> ExampleResult<()> { min_message_count: 1, }; - let ctx2 = NodeContext { - node_id: "validator-1".to_string(), - step: 3, - event_emitter: Arc::clone(&emitter), - }; + let ctx2 = NodeContext::new("validator-1", 3, Arc::clone(&emitter)); let result2 = validation_node.run(state.snapshot(), ctx2).await?; @@ -308,11 +300,7 @@ async fn main() -> ExampleResult<()> { info!("\nšŸ“ˆ Running AggregatorNode..."); let aggregator_node = AggregatorNode; - let ctx3 = NodeContext { - node_id: "aggregator-1".to_string(), - step: 4, - event_emitter: Arc::clone(&emitter), - }; + let ctx3 = NodeContext::new("aggregator-1", 4, Arc::clone(&emitter)); let result3 = aggregator_node.run(state.snapshot(), ctx3).await?; diff --git a/examples/convenience_streaming.rs b/examples/convenience_streaming.rs index bc25c0a..f896b85 100644 --- a/examples/convenience_streaming.rs +++ b/examples/convenience_streaming.rs @@ -37,6 +37,7 @@ use weavegraph::{ graphs::GraphBuilder, message::{Message, Role}, node::{Node, NodeContext, NodeError, NodePartial}, + runtimes::RuntimeConfig, state::{StateSnapshot, VersionedState}, types::NodeKind, }; @@ -101,6 +102,7 @@ async fn main() -> ExampleResult<()> { .add_node(NodeKind::Custom("progress".into()), ProgressNode::new(3)) .add_edge(NodeKind::Start, NodeKind::Custom("progress".into())) .add_edge(NodeKind::Custom("progress".into()), NodeKind::End) + .with_runtime_config(RuntimeConfig::new(None, None).with_memory_event_bus()) .compile()?; // ============================================================================ @@ -159,7 +161,7 @@ async fn main() -> ExampleResult<()> { // Example 2: invoke_with_sinks() - Multiple destinations // ============================================================================ info!("## Example 2: invoke_with_sinks()"); - info!(" Use case: Events to multiple destinations (stdout + channel + file)\n"); + info!(" Use case: Events to multiple destinations (stdout + channel)\n"); let (tx, rx) = flume::unbounded(); @@ -170,8 +172,7 @@ async fn main() -> ExampleResult<()> { // Spawn background collector for channel let channel_collector = tokio::spawn(async move { let mut events = Vec::new(); - let timeout = tokio::time::Duration::from_millis(100); - while let Ok(Ok(event)) = tokio::time::timeout(timeout, rx.recv_async()).await { + while let Ok(event) = rx.recv_async().await { events.push(event); } events diff --git a/fuzz/.gitignore b/fuzz/.gitignore new file mode 100644 index 0000000..ea9ce31 --- /dev/null +++ b/fuzz/.gitignore @@ -0,0 +1,5 @@ +artifacts/ +corpus/ +coverage/ +target/ +Cargo.lock diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml new file mode 100644 index 0000000..cc2a110 --- /dev/null +++ b/fuzz/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "weavegraph-fuzz" +version = "0.0.0" +publish = false +edition = "2024" + +[package.metadata] +cargo-fuzz = true + +[dependencies] +libfuzzer-sys = "0.4" +serde = { version = "1", features = ["derive"] } +serde_json = "1" +weavegraph = { path = "..", default-features = false } + +[[bin]] +name = "event_json" +path = "fuzz_targets/event_json.rs" +test = false +doc = false +bench = false + +[[bin]] +name = "replay_compare" +path = "fuzz_targets/replay_compare.rs" +test = false +doc = false +bench = false + +[[bin]] +name = "state_slots" +path = "fuzz_targets/state_slots.rs" +test = false +doc = false +bench = false diff --git a/fuzz/README.md b/fuzz/README.md new file mode 100644 index 0000000..2a9b681 --- /dev/null +++ b/fuzz/README.md @@ -0,0 +1,12 @@ +# Fuzz Targets + +These targets are intended for `cargo-fuzz` and are kept outside the normal crate build. + +```bash +cargo install cargo-fuzz +cargo +nightly fuzz run event_json +cargo +nightly fuzz run replay_compare +cargo +nightly fuzz run state_slots +``` + +The targets cover event JSON decoding/normalization, replay comparison helpers, and typed state slot serialization boundaries. diff --git a/fuzz/fuzz_targets/event_json.rs b/fuzz/fuzz_targets/event_json.rs new file mode 100644 index 0000000..6e3b14e --- /dev/null +++ b/fuzz/fuzz_targets/event_json.rs @@ -0,0 +1,22 @@ +#![no_main] + +use libfuzzer_sys::fuzz_target; +use weavegraph::event_bus::{Event, NodeEvent}; +use weavegraph::runtimes::normalize_event; + +fuzz_target!(|data: &[u8]| { + if let Ok(event) = serde_json::from_slice::(data) { + let _ = event.scope_label(); + let _ = event.message(); + let _ = event.to_json_value(); + let _ = event.to_json_string(); + let _ = event.to_json_pretty(); + let _ = normalize_event(&event); + } + + if let Ok(node_event) = serde_json::from_slice::(data) { + let event = Event::Node(node_event); + let _ = event.to_json_value(); + let _ = normalize_event(&event); + } +}); diff --git a/fuzz/fuzz_targets/replay_compare.rs b/fuzz/fuzz_targets/replay_compare.rs new file mode 100644 index 0000000..6f4f627 --- /dev/null +++ b/fuzz/fuzz_targets/replay_compare.rs @@ -0,0 +1,49 @@ +#![no_main] + +use libfuzzer_sys::fuzz_target; +use serde_json::json; +use weavegraph::event_bus::Event; +use weavegraph::runtimes::{ + ReplayRun, compare_event_sequences, compare_event_sequences_with, compare_replay_runs, +}; +use weavegraph::state::VersionedState; + +fn events_from_bytes(data: &[u8]) -> Vec { + data.chunks(8) + .take(32) + .enumerate() + .map(|(index, chunk)| { + Event::diagnostic( + format!("scope-{index}"), + String::from_utf8_lossy(chunk).to_string(), + ) + }) + .collect() +} + +fuzz_target!(|data: &[u8]| { + let midpoint = data.len() / 2; + let left_events = events_from_bytes(&data[..midpoint]); + let right_events = events_from_bytes(&data[midpoint..]); + + compare_event_sequences(&left_events, &left_events) + .assert_matches() + .expect("event comparison must be reflexive"); + let _ = compare_event_sequences(&left_events, &right_events); + let _ = compare_event_sequences_with(&left_events, &right_events, |_| json!("ignored")); + + let left_state = VersionedState::builder() + .with_extra("bytes", json!(data.len())) + .build(); + let right_state = VersionedState::builder() + .with_extra("bytes", json!(midpoint)) + .build(); + let left_run = ReplayRun::new(left_state.clone(), left_events.clone()); + let same_run = ReplayRun::new(left_state, left_events); + let right_run = ReplayRun::new(right_state, right_events); + + compare_replay_runs(&left_run, &same_run) + .assert_matches() + .expect("replay comparison must be reflexive"); + let _ = compare_replay_runs(&left_run, &right_run); +}); diff --git a/fuzz/fuzz_targets/state_slots.rs b/fuzz/fuzz_targets/state_slots.rs new file mode 100644 index 0000000..1bfe7a5 --- /dev/null +++ b/fuzz/fuzz_targets/state_slots.rs @@ -0,0 +1,56 @@ +#![no_main] + +use libfuzzer_sys::fuzz_target; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use weavegraph::node::NodePartial; +use weavegraph::state::{StateKey, VersionedState}; + +#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] +struct FuzzPayload { + label: String, + amount: i64, + bytes: Vec, +} + +const FUZZ_SLOT: StateKey = StateKey::new("fuzz", "payload", 1); + +fn amount_from_bytes(data: &[u8]) -> i64 { + let mut bytes = [0_u8; 8]; + let len = data.len().min(bytes.len()); + bytes[..len].copy_from_slice(&data[..len]); + i64::from_le_bytes(bytes) +} + +fuzz_target!(|data: &[u8]| { + let split = data.len().min(32); + let payload = FuzzPayload { + label: String::from_utf8_lossy(&data[..split]).to_string(), + amount: amount_from_bytes(data), + bytes: data.iter().copied().take(64).collect(), + }; + + let state = VersionedState::builder() + .with_typed_extra(FUZZ_SLOT, payload.clone()) + .expect("fuzz payload should serialize") + .build(); + assert_eq!( + state + .snapshot() + .require_typed(FUZZ_SLOT) + .expect("fuzz payload should deserialize"), + payload + ); + + let partial = NodePartial::new() + .with_typed_extra(FUZZ_SLOT, payload) + .expect("fuzz payload should serialize into partial"); + if let Some(extra) = partial.extra { + assert!(extra.contains_key(&FUZZ_SLOT.storage_key())); + } + + let invalid_state = VersionedState::builder() + .with_extra(&FUZZ_SLOT.storage_key(), json!(String::from_utf8_lossy(data).to_string())) + .build(); + let _ = invalid_state.snapshot().get_typed::(FUZZ_SLOT); +}); diff --git a/src/app.rs b/src/app.rs index 1fa7dce..636159d 100644 --- a/src/app.rs +++ b/src/app.rs @@ -122,6 +122,38 @@ pub struct BarrierOutcome { pub frontier_commands: Vec<(NodeKind, FrontierCommand)>, } +/// Stable metadata describing a compiled graph definition. +#[derive(Debug, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub struct GraphMetadata { + /// Weavegraph crate version used to build the graph. + pub weavegraph_version: String, + /// Deterministic hash of the graph definition surface. + pub graph_hash: String, + /// Number of registered executable nodes. + pub node_count: usize, + /// Number of unconditional edges. + pub edge_count: usize, + /// Number of conditional edge registrations. + pub conditional_edge_count: usize, + /// Reducer registration signature included in the hash. + pub reducer_signature: Vec, +} + +fn hash_parts(parts: &[String]) -> String { + const FNV_OFFSET: u64 = 0xcbf29ce484222325; + const FNV_PRIME: u64 = 0x100000001b3; + + let mut hash = FNV_OFFSET; + for part in parts { + for byte in part.as_bytes().iter().copied().chain([0xff]) { + hash ^= u64::from(byte); + hash = hash.wrapping_mul(FNV_PRIME); + } + } + format!("{hash:016x}") +} + impl AppEventStream { fn new(event_bus: EventBus, event_stream: EventStream) -> Self { Self { @@ -303,6 +335,74 @@ impl App { &self.runtime_config } + /// Return the Weavegraph crate version compiled into this binary. + #[must_use] + pub fn weavegraph_version(&self) -> &'static str { + env!("CARGO_PKG_VERSION") + } + + /// Return metadata describing this graph definition. + /// + /// The hash covers registered node IDs, unconditional edges, conditional + /// edge sources/counts, and reducer definition labels. Conditional edge + /// predicate closure bodies are opaque to Rust, so changing a predicate + /// implementation without changing its registration shape is not detectable. + #[must_use] + pub fn graph_metadata(&self) -> GraphMetadata { + let mut parts = vec!["weavegraph-graph-v1".to_string()]; + + let mut nodes: Vec = self.nodes.keys().map(NodeKind::encode).collect(); + nodes.sort(); + parts.extend(nodes.iter().map(|node| format!("node:{node}"))); + + let mut edges: Vec = self + .edges + .iter() + .flat_map(|(from, targets)| { + targets + .iter() + .map(move |target| format!("edge:{}->{}", from.encode(), target.encode())) + }) + .collect(); + edges.sort(); + parts.extend(edges); + + let mut conditional_sources: Vec = self + .conditional_edges + .iter() + .map(|edge| edge.from().encode()) + .collect(); + conditional_sources.sort(); + parts.extend( + conditional_sources + .iter() + .enumerate() + .map(|(index, from)| format!("conditional:{index}:{from}")), + ); + + let reducer_signature = self.reducer_registry.definition_signature(); + parts.extend( + reducer_signature + .iter() + .map(|entry| format!("reducer:{entry}")), + ); + + GraphMetadata { + weavegraph_version: self.weavegraph_version().to_string(), + graph_hash: hash_parts(&parts), + node_count: self.nodes.len(), + edge_count: self.edges.values().map(Vec::len).sum(), + conditional_edge_count: self.conditional_edges.len(), + reducer_signature, + } + } + + /// Return only the graph definition hash. + #[must_use] + pub fn graph_definition_hash(&self) -> String { + self.graph_metadata().graph_hash + } + /// Create a subscription to the configured event bus without starting execution. /// /// This is the low-level entry point when you want to inspect the stream or diff --git a/src/event_bus/event.rs b/src/event_bus/event.rs index fb997d8..735fac3 100644 --- a/src/event_bus/event.rs +++ b/src/event_bus/event.rs @@ -12,6 +12,12 @@ use serde_json::Value; /// so that consumers can detect clean stream termination. pub const STREAM_END_SCOPE: &str = "__weavegraph_stream_end__"; +/// Scope constant marking the end of one logical invocation while a stream stays open. +/// +/// Iterative runners emit this after each [`AppRunner::invoke_next`](crate::runtimes::AppRunner::invoke_next) +/// call so subscribers can separate logical inputs without treating the event bus as closed. +pub const INVOCATION_END_SCOPE: &str = "__weavegraph_invocation_end__"; + /// Scope constant for diagnostic events emitted by the framework. /// /// Use this scope when emitting internal diagnostic information @@ -52,6 +58,25 @@ impl Event { )) } + /// Create a node event with full metadata and additional runtime labels. + pub fn node_message_with_metadata( + node_id: impl Into, + step: u64, + scope: impl Into, + message: impl Into, + metadata: FxHashMap, + ) -> Self { + Event::Node( + NodeEvent::new( + Some(node_id.into()), + Some(step), + scope.into(), + message.into(), + ) + .with_metadata(metadata), + ) + } + /// Create a diagnostic event with the given scope and message. pub fn diagnostic(scope: impl Into, message: impl Into) -> Self { Event::Diagnostic(DiagnosticEvent { @@ -111,6 +136,9 @@ impl Event { let (event_type, metadata) = match self { Event::Node(node) => { let mut meta = serde_json::Map::new(); + for (key, value) in node.metadata() { + meta.insert(key.clone(), value.clone()); + } if let Some(node_id) = node.node_id() { meta.insert("node_id".to_string(), json!(node_id)); } @@ -222,6 +250,8 @@ pub struct NodeEvent { step: Option, scope: String, message: String, + #[serde(default)] + metadata: FxHashMap, } impl NodeEvent { @@ -232,6 +262,7 @@ impl NodeEvent { step, scope, message, + metadata: FxHashMap::default(), } } @@ -254,6 +285,17 @@ impl NodeEvent { pub fn message(&self) -> &str { &self.message } + + /// Returns the metadata map attached to this node event. + pub fn metadata(&self) -> &FxHashMap { + &self.metadata + } + + /// Return a new node event with the given metadata map. + pub fn with_metadata(mut self, metadata: FxHashMap) -> Self { + self.metadata = metadata; + self + } } /// A framework-internal diagnostic event emitted outside normal node execution. diff --git a/src/event_bus/mod.rs b/src/event_bus/mod.rs index e47e39b..2bb37f5 100644 --- a/src/event_bus/mod.rs +++ b/src/event_bus/mod.rs @@ -23,6 +23,8 @@ pub mod sink; pub use bus::EventBus; pub use diagnostics::{DiagnosticsStream, SinkDiagnostic}; pub use emitter::{EmitterError, EventEmitter}; -pub use event::{DIAGNOSTIC_SCOPE, Event, LLMStreamingEvent, NodeEvent, STREAM_END_SCOPE}; +pub use event::{ + DIAGNOSTIC_SCOPE, Event, INVOCATION_END_SCOPE, LLMStreamingEvent, NodeEvent, STREAM_END_SCOPE, +}; pub use hub::{BlockingEventIter, EventHub, EventHubMetrics, EventStream, HubEmitter}; pub use sink::{ChannelSink, EventSink, JsonLinesSink, MemorySink, StdOutSink}; diff --git a/src/node.rs b/src/node.rs index 70c2dfd..6393d54 100644 --- a/src/node.rs +++ b/src/node.rs @@ -2,7 +2,6 @@ //! //! This module provides the core abstractions for executable workflow nodes, //! including the [`Node`] trait, execution context, state updates, and error handling. - // Standard library and external crates use async_trait::async_trait; use rustc_hash::FxHashMap; @@ -14,8 +13,9 @@ use crate::channels::errors::ErrorEvent; use crate::control::{FrontierCommand, NodeRoute}; use crate::event_bus::{Event, EventEmitter, LLMStreamingEvent}; use crate::message::Message; -use crate::state::StateSnapshot; +use crate::state::{StateKey, StateSlotError, StateSnapshot}; use crate::types::NodeKind; +use crate::utils::clock::Clock; use std::sync::Arc; // ============================================================================ @@ -99,6 +99,7 @@ pub trait Node: Send + Sync { /// Provides nodes with access to their execution environment, including step /// information, node identity, and communication channels for observability. #[derive(Clone, Debug)] +#[non_exhaustive] pub struct NodeContext { /// Unique identifier for this node instance. pub node_id: String, @@ -106,9 +107,40 @@ pub struct NodeContext { pub step: u64, /// Channel for emitting events to the workflow's event system. pub event_emitter: Arc, + /// Optional runtime clock for deterministic tests and replay. + pub clock: Option>, + /// Optional invocation or run identifier attached to node events. + pub invocation_id: Option, } impl NodeContext { + /// Construct a node context with no runtime clock or invocation metadata. + pub fn new( + node_id: impl Into, + step: u64, + event_emitter: Arc, + ) -> Self { + Self { + node_id: node_id.into(), + step, + event_emitter, + clock: None, + invocation_id: None, + } + } + + /// Return the current runtime clock timestamp in Unix milliseconds, if configured. + #[must_use] + pub fn now_unix_ms(&self) -> Option { + self.clock.as_ref().map(|clock| clock.now_unix_ms()) + } + + /// Return the invocation identifier, if configured. + #[must_use] + pub fn invocation_id(&self) -> Option<&str> { + self.invocation_id.as_deref() + } + /// Emit a node-scoped event enriched with this context's metadata. /// /// Creates structured events that include the node's ID and step information, @@ -127,12 +159,33 @@ impl NodeContext { scope: impl Into, message: impl Into, ) -> Result<(), NodeContextError> { - self.emit_event(Event::node_message_with_meta( - self.node_id.clone(), - self.step, - scope, - message, - )) + let mut metadata = FxHashMap::default(); + if let Some(invocation_id) = &self.invocation_id { + metadata.insert( + "invocation_id".to_string(), + serde_json::Value::String(invocation_id.clone()), + ); + } + if let Some(now_unix_ms) = self.now_unix_ms() { + metadata.insert("now_unix_ms".to_string(), serde_json::json!(now_unix_ms)); + } + + if metadata.is_empty() { + self.emit_event(Event::node_message_with_meta( + self.node_id.clone(), + self.step, + scope, + message, + )) + } else { + self.emit_event(Event::node_message_with_metadata( + self.node_id.clone(), + self.step, + scope, + message, + metadata, + )) + } } /// Emit a diagnostic event for general workflow telemetry. @@ -285,6 +338,29 @@ impl NodePartial { self } + /// Insert a typed value into this partial's extra updates. + /// + /// The value is serialized to JSON and stored under the key returned by + /// [`StateKey::storage_key`]. If this partial already contains extra data, + /// the typed slot is merged into it and any existing value at the same + /// storage key is replaced. + pub fn with_typed_extra( + mut self, + key: StateKey, + value: T, + ) -> Result { + let storage_key = key.storage_key(); + let json_value = + serde_json::to_value(value).map_err(|source| StateSlotError::Serialize { + key: storage_key.clone(), + source, + })?; + self.extra + .get_or_insert_with(FxHashMap::default) + .insert(storage_key, json_value); + Ok(self) + } + /// Create a `NodePartial` with one or more errors. #[must_use] pub fn with_errors(mut self, errors: Vec) -> Self { diff --git a/src/reducers/mod.rs b/src/reducers/mod.rs index c6513cc..5606c4f 100644 --- a/src/reducers/mod.rs +++ b/src/reducers/mod.rs @@ -17,6 +17,14 @@ use thiserror::Error; /// Unified reducer trait: every reducer mutates VersionedState using a NodePartial delta. /// Channels currently implemented: messages (append) and extra (shallow JSON map merge). pub trait Reducer: Send + Sync { + /// Stable-ish reducer identity included in graph definition metadata. + /// + /// The default is the concrete Rust type path. Custom reducers can override this + /// with a durable label when the type path is too noisy for audit manifests. + fn definition_label(&self) -> &'static str { + std::any::type_name::() + } + /// Apply the partial update `update` to `state`, mutating it in place. fn apply(&self, state: &mut VersionedState, update: &NodePartial); } diff --git a/src/reducers/reducer_registry.rs b/src/reducers/reducer_registry.rs index e50b618..d6ac94c 100644 --- a/src/reducers/reducer_registry.rs +++ b/src/reducers/reducer_registry.rs @@ -100,6 +100,30 @@ impl ReducerRegistry { self } + /// Return a deterministic summary of registered reducers for metadata hashing. + /// + /// Reducer labels are recorded in registration order for each channel. It + /// changes when reducers are added, removed, reordered, or replaced with a + /// reducer that reports a different [`Reducer::definition_label`]. + #[must_use] + pub fn definition_signature(&self) -> Vec { + let mut signature: Vec = self + .reducer_map + .iter() + .map(|(channel, reducers)| { + let labels = reducers + .iter() + .enumerate() + .map(|(index, reducer)| format!("{index}:{}", reducer.definition_label())) + .collect::>() + .join(","); + format!("{}:[{}]", channel, labels) + }) + .collect(); + signature.sort(); + signature + } + #[instrument(skip(self, state, to_update), err)] /// Apply all reducers for `channel_type` to `state` using `to_update` as the delta. pub fn try_update( diff --git a/src/runtimes/mod.rs b/src/runtimes/mod.rs index 46e390d..101b3a2 100644 --- a/src/runtimes/mod.rs +++ b/src/runtimes/mod.rs @@ -54,6 +54,7 @@ pub mod checkpointer_sqlite; mod checkpointer_sqlite_helpers; pub mod execution; pub mod persistence; +pub mod replay; pub mod runner; pub mod runtime_config; pub mod session; @@ -81,7 +82,12 @@ pub use execution::{PausedReason, PausedReport, StepOptions, StepReport, StepRes pub use session::{SessionInit, SessionState, StateVersions}; // Re-export runner -pub use runner::{AppRunner, AppRunnerBuilder}; +pub use runner::{AppRunner, AppRunnerBuilder, RunMetadata}; +pub use replay::{ + ReplayComparison, ReplayConformanceError, ReplayRun, compare_event_sequences, + compare_event_sequences_with, compare_final_state, compare_replay_runs, + compare_replay_runs_with, normalize_event, normalize_state, +}; pub use runtime_config::{EventBusConfig, RuntimeConfig, SinkConfig}; pub use types::{SessionId, StepNumber}; diff --git a/src/runtimes/replay.rs b/src/runtimes/replay.rs new file mode 100644 index 0000000..99ced6a --- /dev/null +++ b/src/runtimes/replay.rs @@ -0,0 +1,209 @@ +//! Replay conformance helpers for comparing workflow runs. +//! +//! These helpers are intentionally small and test-friendly. They normalize common +//! nondeterministic fields, compare final state and event streams, and return +//! human-readable differences that can be used in ordinary assertions. + +use serde_json::{Value, json}; +use thiserror::Error; + +use crate::{channels::Channel, event_bus::Event, state::VersionedState}; + +/// Captured output from one workflow run. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub struct ReplayRun { + /// Final workflow state produced by the run. + pub final_state: VersionedState, + /// Events captured during the run. + pub events: Vec, +} + +impl ReplayRun { + /// Create a replay run from final state and captured events. + #[must_use] + pub fn new(final_state: VersionedState, events: Vec) -> Self { + Self { + final_state, + events, + } + } +} + +/// Result of comparing two replay artifacts. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ReplayComparison { + differences: Vec, +} + +impl ReplayComparison { + /// Create a successful comparison with no differences. + #[must_use] + pub fn matched() -> Self { + Self { + differences: Vec::new(), + } + } + + /// Create a comparison with the supplied differences. + #[must_use] + pub fn with_differences(differences: Vec) -> Self { + Self { differences } + } + + /// Return true when no differences were found. + #[must_use] + pub fn is_match(&self) -> bool { + self.differences.is_empty() + } + + /// Return the differences found during comparison. + #[must_use] + pub fn differences(&self) -> &[String] { + &self.differences + } + + /// Convert this report into a `Result` suitable for test assertions. + pub fn assert_matches(self) -> Result<(), ReplayConformanceError> { + if self.is_match() { + Ok(()) + } else { + Err(ReplayConformanceError::Mismatch { + differences: self.differences, + }) + } + } +} + +/// Errors returned by replay conformance helpers. +#[derive(Debug, Error)] +#[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))] +pub enum ReplayConformanceError { + /// The compared runs were not equivalent. + #[error("replay conformance mismatch: {differences:?}")] + #[cfg_attr( + feature = "diagnostics", + diagnostic(code(weavegraph::replay::mismatch)) + )] + Mismatch { + /// Human-readable differences. + differences: Vec, + }, +} + +/// Normalize an event for replay comparison. +/// +/// The default normalizer uses Weavegraph's JSON event shape and removes the +/// top-level timestamp, which is normally wall-clock dependent. +#[must_use] +pub fn normalize_event(event: &Event) -> Value { + let mut value = event.to_json_value(); + if let Value::Object(object) = &mut value { + object.remove("timestamp"); + } + value +} + +/// Normalize a final state into a JSON value for stable comparison and diffs. +#[must_use] +pub fn normalize_state(state: &VersionedState) -> Value { + json!({ + "messages": state.messages.snapshot(), + "messages_version": state.messages.version(), + "extra": state.extra.snapshot(), + "extra_version": state.extra.version(), + "errors": state.errors.snapshot(), + "errors_version": state.errors.version(), + }) +} + +/// Compare two final states with default normalization. +#[must_use] +pub fn compare_final_state(left: &VersionedState, right: &VersionedState) -> ReplayComparison { + let left_value = normalize_state(left); + let right_value = normalize_state(right); + if left_value == right_value { + ReplayComparison::matched() + } else { + ReplayComparison::with_differences(vec![format!( + "final state differs: left={left_value} right={right_value}" + )]) + } +} + +/// Compare two event streams with the default event normalizer. +#[must_use] +pub fn compare_event_sequences(left: &[Event], right: &[Event]) -> ReplayComparison { + compare_event_sequences_with(left, right, normalize_event) +} + +/// Compare two event streams with a caller-provided normalizer. +/// +/// Use this when domain events contain timestamps, generated IDs, or other +/// values that should be compared semantically rather than byte-for-byte. +#[must_use] +pub fn compare_event_sequences_with( + left: &[Event], + right: &[Event], + normalizer: F, +) -> ReplayComparison +where + F: Fn(&Event) -> Value, +{ + let left_values: Vec = left.iter().map(&normalizer).collect(); + let right_values: Vec = right.iter().map(&normalizer).collect(); + + if left_values == right_values { + return ReplayComparison::matched(); + } + + let mut differences = Vec::new(); + if left_values.len() != right_values.len() { + differences.push(format!( + "event count differs: left={} right={}", + left_values.len(), + right_values.len() + )); + } + + let shared_len = left_values.len().min(right_values.len()); + for index in 0..shared_len { + if left_values[index] != right_values[index] { + differences.push(format!( + "event {index} differs: left={} right={}", + left_values[index], right_values[index] + )); + break; + } + } + + ReplayComparison::with_differences(differences) +} + +/// Compare two captured runs with default state and event normalization. +#[must_use] +pub fn compare_replay_runs(left: &ReplayRun, right: &ReplayRun) -> ReplayComparison { + compare_replay_runs_with(left, right, normalize_event) +} + +/// Compare two captured runs with a caller-provided event normalizer. +#[must_use] +pub fn compare_replay_runs_with( + left: &ReplayRun, + right: &ReplayRun, + event_normalizer: F, +) -> ReplayComparison +where + F: Fn(&Event) -> Value, +{ + let mut differences = Vec::new(); + + let state_comparison = compare_final_state(&left.final_state, &right.final_state); + differences.extend(state_comparison.differences().iter().cloned()); + + let event_comparison = + compare_event_sequences_with(&left.events, &right.events, event_normalizer); + differences.extend(event_comparison.differences().iter().cloned()); + + ReplayComparison::with_differences(differences) +} diff --git a/src/runtimes/runner.rs b/src/runtimes/runner.rs index 0726cd7..1266d63 100644 --- a/src/runtimes/runner.rs +++ b/src/runtimes/runner.rs @@ -20,13 +20,14 @@ use crate::runtimes::execution::{ PausedReason, PausedReport, SchedulerOutcome, StepOptions, StepReport, StepResult, }; use crate::runtimes::session::{SessionInit, SessionState, StateVersions}; -use crate::runtimes::streaming::{StreamEndReason, finalize_event_stream}; +use crate::runtimes::streaming::{StreamEndReason, emit_invocation_end, finalize_event_stream}; use crate::runtimes::{ Checkpoint, Checkpointer, CheckpointerError, InMemoryCheckpointer, restore_session_state, }; -use crate::schedulers::{Scheduler, SchedulerError, SchedulerState}; +use crate::schedulers::{Scheduler, SchedulerError, SchedulerRunContext, SchedulerState}; use crate::state::VersionedState; use crate::types::NodeKind; +use crate::utils::clock::Clock; use rustc_hash::FxHashMap; use std::sync::Arc; use thiserror::Error; @@ -136,6 +137,8 @@ pub struct AppRunner { autosave: bool, event_bus: EventBus, event_stream_taken: bool, + clock: Option>, + checkpointer_descriptor: String, } /// Errors that can occur during workflow execution. @@ -164,6 +167,22 @@ pub enum RunnerError { )] NoStartNodes, + /// The requested entry node cannot be used to start an iterative invocation. + #[error("invalid iterative entry node: {node}")] + #[cfg_attr( + feature = "diagnostics", + diagnostic( + code(weavegraph::runner::invalid_iterative_entry), + help( + "Use NodeKind::Start or a registered custom node. NodeKind::End is terminal and cannot be used as an entry." + ) + ) + )] + InvalidIterativeEntry { + /// The invalid entry node. + node: NodeKind, + }, + /// Execution paused unexpectedly during run_until_complete. #[error("unexpected pause during run_until_complete")] #[cfg_attr( @@ -210,6 +229,33 @@ pub enum RunnerError { Scheduler(#[from] SchedulerError), } +/// Runtime metadata useful for audit, replay, and checkpoint labels. +#[derive(Debug, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub struct RunMetadata { + /// Weavegraph crate version compiled into this binary. + pub weavegraph_version: String, + /// Deterministic graph definition hash. + pub graph_hash: String, + /// Deterministic runtime configuration hash. + pub runtime_config_hash: String, + /// Descriptor for the configured checkpointer backend. + pub checkpointer_backend: String, + /// Descriptor for runtime clock injection mode. + pub clock_mode: String, +} + +struct RunnerRuntimeMetadata { + clock: Option>, + checkpointer_descriptor: String, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum CompletionEventPolicy { + CloseStream, + KeepStreamOpen, +} + // ============================================================================ // Builder Pattern // ============================================================================ @@ -295,6 +341,7 @@ pub struct AppRunnerBuilder { autosave: bool, event_bus: Option, start_listener: bool, + clock: Option>, } impl Default for AppRunnerBuilder { @@ -320,6 +367,7 @@ impl AppRunnerBuilder { autosave: true, event_bus: None, start_listener: true, + clock: None, } } @@ -389,6 +437,13 @@ impl AppRunnerBuilder { self } + /// Set a runtime clock that will be injected into node contexts. + #[must_use] + pub fn clock(mut self, clock: Arc) -> Self { + self.clock = Some(clock); + self + } + /// Build the [`AppRunner`]. /// /// # Panics @@ -407,6 +462,16 @@ impl AppRunnerBuilder { let event_bus = self .event_bus .unwrap_or_else(|| app.runtime_config().event_bus.build_event_bus()); + let clock = self.clock.or_else(|| app.runtime_config().clock()); + let checkpointer_descriptor = if self.checkpointer_custom.is_some() { + "custom".to_string() + } else { + AppRunner::checkpointer_type_label(&self.checkpointer_type).to_string() + }; + let runtime_metadata = RunnerRuntimeMetadata { + clock, + checkpointer_descriptor, + }; Some( AppRunner::with_arc_and_bus( @@ -416,6 +481,7 @@ impl AppRunnerBuilder { self.autosave, event_bus, self.start_listener, + runtime_metadata, ) .await, ) @@ -450,7 +516,7 @@ impl AppRunner { async fn create_checkpointer( checkpointer_type: CheckpointerType, - sqlite_db_name: Option, + _sqlite_db_name: Option, ) -> Option> { match checkpointer_type { CheckpointerType::InMemory => { @@ -461,7 +527,7 @@ impl AppRunner { let db_url = std::env::var("WEAVEGRAPH_SQLITE_URL") .ok() .or_else(|| { - sqlite_db_name + _sqlite_db_name .as_ref() .map(|name| format!("sqlite://{name}")) }) @@ -520,6 +586,16 @@ impl AppRunner { } } + fn checkpointer_type_label(checkpointer_type: &CheckpointerType) -> &'static str { + match checkpointer_type { + CheckpointerType::InMemory => "in-memory", + #[cfg(feature = "sqlite")] + CheckpointerType::SQLite => "sqlite", + #[cfg(feature = "postgres")] + CheckpointerType::Postgres => "postgres", + } + } + async fn with_arc_and_bus( app: Arc, checkpointer_type: CheckpointerType, @@ -527,6 +603,7 @@ impl AppRunner { autosave: bool, event_bus: EventBus, start_listener: bool, + runtime_metadata: RunnerRuntimeMetadata, ) -> Self { // Precedence rule: custom checkpointer always wins when provided. // If custom is None, fall back to enum-based factory instantiation. @@ -546,6 +623,8 @@ impl AppRunner { autosave, event_bus, event_stream_taken: false, + clock: runtime_metadata.clock, + checkpointer_descriptor: runtime_metadata.checkpointer_descriptor, } } @@ -616,6 +695,150 @@ impl AppRunner { Ok(SessionInit::Fresh) } + /// Initialize or resume a session for repeated invocations under one durable lineage. + /// + /// This method behaves like [`create_session`](Self::create_session), then prepares + /// the session to run from `entry_node`. Passing [`NodeKind::Start`] uses the + /// graph's outgoing edges from the virtual Start node, matching normal session + /// initialization. Passing a custom node runs directly from that registered node. + /// + /// The session step counter is not reset when a checkpoint is resumed, so steps + /// remain monotonic across repeated invocations. + #[instrument(skip(self, session_id, initial_state), err)] + pub async fn create_iterative_session( + &mut self, + session_id: String, + initial_state: VersionedState, + entry_node: NodeKind, + ) -> Result { + let frontier = self.frontier_for_iterative_entry(&entry_node)?; + let init = self + .create_session(session_id.clone(), initial_state) + .await?; + self.set_iterative_frontier(&session_id, frontier)?; + Ok(init) + } + + /// Apply an input patch, restart the session frontier, and run to completion. + /// + /// The existing session state is updated through the same deterministic barrier + /// path used for node outputs. The frontier is then reset to `entry_node` and the + /// scheduler's version-gating state is cleared so the entry path executes for this + /// logical invocation even when two consecutive input patches serialize to the + /// same state. + /// + /// Use [`create_iterative_session`](Self::create_iterative_session) before the + /// first call, including after process restart, so the latest checkpoint is loaded + /// into the runner. + #[instrument(skip(self, input), err)] + pub async fn invoke_next( + &mut self, + session_id: &str, + input: NodePartial, + entry_node: NodeKind, + ) -> Result { + let frontier = self.frontier_for_iterative_entry(&entry_node)?; + self.apply_iterative_input(session_id, input).await?; + self.set_iterative_frontier(session_id, frontier)?; + self.run_until_complete_with_policy(session_id, CompletionEventPolicy::KeepStreamOpen) + .await + } + + /// Emit the terminal stream marker for a completed iterative session. + /// + /// `invoke_next` keeps long-lived event subscriptions open between logical + /// inputs. Call this after the final input when a subscriber should receive + /// [`STREAM_END_SCOPE`](crate::event_bus::STREAM_END_SCOPE) and the stream + /// should close cleanly. + pub fn finish_iterative_session(&mut self, session_id: &str) -> Result<(), RunnerError> { + let (_, _, final_step) = self.finalize_state_snapshot(session_id)?; + self.emit_completion_event( + session_id, + StreamEndReason::Completed { step: final_step }, + CompletionEventPolicy::CloseStream, + ); + Ok(()) + } + + fn set_iterative_frontier( + &mut self, + session_id: &str, + frontier: Vec, + ) -> Result<(), RunnerError> { + let session_state = + self.sessions + .get_mut(session_id) + .ok_or_else(|| RunnerError::SessionNotFound { + session_id: session_id.to_string(), + })?; + + session_state.frontier = frontier; + session_state.scheduler_state = SchedulerState::default(); + Ok(()) + } + + fn frontier_for_iterative_entry( + &self, + entry_node: &NodeKind, + ) -> Result, RunnerError> { + match entry_node { + NodeKind::Start => { + let frontier = self + .app + .edges() + .get(&NodeKind::Start) + .cloned() + .unwrap_or_default(); + if frontier.is_empty() { + Err(RunnerError::NoStartNodes) + } else { + Ok(frontier) + } + } + NodeKind::End => Err(RunnerError::InvalidIterativeEntry { + node: entry_node.clone(), + }), + NodeKind::Custom(_) => { + if self.app.nodes().contains_key(entry_node) { + Ok(vec![entry_node.clone()]) + } else { + Err(RunnerError::InvalidIterativeEntry { + node: entry_node.clone(), + }) + } + } + } + } + + async fn apply_iterative_input( + &mut self, + session_id: &str, + input: NodePartial, + ) -> Result<(), RunnerError> { + let mut updated_state = self + .sessions + .get(session_id) + .ok_or_else(|| RunnerError::SessionNotFound { + session_id: session_id.to_string(), + })? + .state + .clone(); + + self.app + .apply_barrier(&mut updated_state, &[], vec![input]) + .await + .map_err(RunnerError::AppBarrier)?; + + let session_state = + self.sessions + .get_mut(session_id) + .ok_or_else(|| RunnerError::SessionNotFound { + session_id: session_id.to_string(), + })?; + session_state.state = updated_state; + Ok(()) + } + /// Execute one superstep for the given session #[instrument(skip(self, options), err)] pub async fn run_step( @@ -684,7 +907,7 @@ impl AppRunner { })?; // Execute one superstep; on error, emit an ErrorEvent and rethrow - let step_report = match self.run_one_superstep(&mut session_state).await { + let step_report = match self.run_one_superstep(session_id, &mut session_state).await { Ok(rep) => rep, Err(e) => { // Build error event @@ -800,6 +1023,7 @@ impl AppRunner { #[inline] async fn schedule_step( &self, + session_id: &str, session_state: &mut SessionState, step: u64, ) -> Result { @@ -812,7 +1036,11 @@ impl AppRunner { session_state.frontier.clone(), snapshot.clone(), step, - self.event_bus.get_emitter(), + SchedulerRunContext { + event_emitter: self.event_bus.get_emitter(), + clock: self.clock.clone(), + invocation_id: Some(session_id.to_string()), + }, ) .await?; @@ -984,6 +1212,7 @@ impl AppRunner { #[instrument(skip(self, session_state), err)] async fn run_one_superstep( &self, + session_id: &str, session_state: &mut SessionState, ) -> Result { session_state.step += 1; @@ -998,7 +1227,7 @@ impl AppRunner { frontier_len = session_state.frontier.len() ); let scheduler_outcome = schedule_span - .in_scope(|| self.schedule_step(session_state, step)) + .in_scope(|| self.schedule_step(session_id, session_state, step)) .await?; // Phase 2: apply barrier and update state @@ -1072,6 +1301,15 @@ impl AppRunner { pub async fn run_until_complete( &mut self, session_id: &str, + ) -> Result { + self.run_until_complete_with_policy(session_id, CompletionEventPolicy::CloseStream) + .await + } + + async fn run_until_complete_with_policy( + &mut self, + session_id: &str, + completion_policy: CompletionEventPolicy, ) -> Result { tracing::info!(session = %session_id, "workflow run started"); @@ -1099,12 +1337,13 @@ impl AppRunner { Err(err) => { let reason = err.to_string(); let step = self.sessions.get(session_id).map(|state| state.step); - self.finalize_event_stream( + self.emit_completion_event( session_id, StreamEndReason::Error { step, error: reason, }, + completion_policy, ); return Err(err); } @@ -1119,12 +1358,13 @@ impl AppRunner { StepResult::Paused(_) => { // This shouldn't happen with default options, but handle gracefully let step = self.sessions.get(session_id).map(|state| state.step); - self.finalize_event_stream( + self.emit_completion_event( session_id, StreamEndReason::Error { step, error: "execution paused unexpectedly".to_string(), }, + completion_policy, ); return Err(RunnerError::UnexpectedPause); } @@ -1169,7 +1409,11 @@ impl AppRunner { ); } - self.finalize_event_stream(session_id, StreamEndReason::Completed { step: final_step }); + self.emit_completion_event( + session_id, + StreamEndReason::Completed { step: final_step }, + completion_policy, + ); Ok(final_state) } @@ -1196,6 +1440,22 @@ impl AppRunner { pub fn list_sessions(&self) -> Vec<&String> { self.sessions.keys().collect() } + + /// Return metadata for this runner and its compiled graph. + #[must_use] + pub fn run_metadata(&self) -> RunMetadata { + RunMetadata { + weavegraph_version: self.app.weavegraph_version().to_string(), + graph_hash: self.app.graph_definition_hash(), + runtime_config_hash: self.app.runtime_config().config_hash(), + checkpointer_backend: self.checkpointer_descriptor.clone(), + clock_mode: if self.clock.is_some() { + "configured".to_string() + } else { + "unset".to_string() + }, + } + } } impl AppRunner { @@ -1237,4 +1497,18 @@ impl AppRunner { &mut self.event_stream_taken, ); } + + fn emit_completion_event( + &mut self, + session_id: &str, + reason: StreamEndReason, + policy: CompletionEventPolicy, + ) { + match policy { + CompletionEventPolicy::CloseStream => self.finalize_event_stream(session_id, reason), + CompletionEventPolicy::KeepStreamOpen => { + emit_invocation_end(&self.event_bus, session_id, reason); + } + } + } } diff --git a/src/runtimes/runtime_config.rs b/src/runtimes/runtime_config.rs index e592cbc..d5f4edf 100644 --- a/src/runtimes/runtime_config.rs +++ b/src/runtimes/runtime_config.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use crate::event_bus::{EventBus, EventSink, MemorySink, StdOutSink}; +use crate::utils::clock::Clock; use super::Checkpointer; @@ -16,6 +17,8 @@ pub struct RuntimeConfig { pub sqlite_db_name: Option, /// Event bus configuration used to build the [`EventBus`]. pub event_bus: EventBusConfig, + /// Optional runtime clock injected into node execution contexts. + pub clock: Option>, } impl std::fmt::Debug for RuntimeConfig { @@ -25,6 +28,7 @@ impl std::fmt::Debug for RuntimeConfig { .field("checkpointer_custom", &self.checkpointer_custom.is_some()) .field("sqlite_db_name", &self.sqlite_db_name) .field("event_bus", &self.event_bus) + .field("clock", &self.clock.is_some()) .finish() } } @@ -36,6 +40,7 @@ impl Default for RuntimeConfig { checkpointer_custom: None, sqlite_db_name: Self::resolve_sqlite_db_name(None), event_bus: EventBusConfig::default(), + clock: None, } } } @@ -56,6 +61,7 @@ impl RuntimeConfig { checkpointer_custom: None, sqlite_db_name: Self::resolve_sqlite_db_name(sqlite_db_name), event_bus: EventBusConfig::default(), + clock: None, } } @@ -72,6 +78,50 @@ impl RuntimeConfig { self.checkpointer_custom.clone() } + #[must_use] + /// Set the runtime clock injected into [`NodeContext`](crate::node::NodeContext). + pub fn with_clock(mut self, clock: Arc) -> Self { + self.clock = Some(clock); + self + } + + #[must_use] + /// Return the configured runtime clock, if any. + pub fn clock(&self) -> Option> { + self.clock.clone() + } + + #[must_use] + /// Return a descriptor for the configured clock mode. + pub fn clock_mode(&self) -> &'static str { + if self.clock.is_some() { + "configured" + } else { + "unset" + } + } + + /// Return a deterministic hash of runtime configuration metadata. + #[must_use] + pub fn config_hash(&self) -> String { + let mut parts = vec!["weavegraph-runtime-config-v1".to_string()]; + parts.push(format!( + "session_id:{}", + self.session_id.as_deref().unwrap_or("") + )); + parts.push(format!( + "sqlite_db_name:{}", + self.sqlite_db_name.as_deref().unwrap_or("") + )); + parts.push(format!( + "custom_checkpointer:{}", + self.checkpointer_custom.is_some() + )); + parts.push(format!("clock:{}", self.clock_mode())); + parts.extend(self.event_bus.metadata_signature()); + hash_parts(&parts) + } + #[must_use] /// Replace the event bus configuration for this runtime. pub fn with_event_bus(mut self, event_bus: EventBusConfig) -> Self { @@ -92,6 +142,20 @@ impl RuntimeConfig { } } +fn hash_parts(parts: &[String]) -> String { + const FNV_OFFSET: u64 = 0xcbf29ce484222325; + const FNV_PRIME: u64 = 0x100000001b3; + + let mut hash = FNV_OFFSET; + for part in parts { + for byte in part.as_bytes().iter().copied().chain([0xff]) { + hash ^= u64::from(byte); + hash = hash.wrapping_mul(FNV_PRIME); + } + } + format!("{hash:016x}") +} + /// Selects the output target for an [`EventBusConfig`] sink entry. #[derive(Clone, Debug, PartialEq, Eq)] pub enum SinkConfig { @@ -161,6 +225,20 @@ impl EventBusConfig { &self.sinks } + /// Return deterministic metadata entries for this event bus configuration. + #[must_use] + pub fn metadata_signature(&self) -> Vec { + let mut parts = vec![format!("event_buffer:{}", self.buffer_capacity)]; + parts.extend( + self.sinks + .iter() + .enumerate() + .map(|(index, sink)| format!("event_sink:{index}:{sink:?}")), + ); + parts.extend(self.diagnostics.metadata_signature()); + parts + } + #[must_use] /// Override the diagnostics configuration for this event bus. pub fn with_diagnostics(mut self, diagnostics: DiagnosticsConfig) -> Self { @@ -239,6 +317,19 @@ impl DiagnosticsConfig { self.buffer_capacity .unwrap_or_else(|| Self::normalize_capacity(event_bus_capacity)) } + + fn metadata_signature(&self) -> Vec { + vec![ + format!("diagnostics_enabled:{}", self.enabled), + format!( + "diagnostics_capacity:{}", + self.buffer_capacity + .map(|capacity| capacity.to_string()) + .unwrap_or_default() + ), + format!("diagnostics_emit_to_events:{}", self.emit_to_events), + ] + } } impl Default for DiagnosticsConfig { diff --git a/src/runtimes/streaming.rs b/src/runtimes/streaming.rs index 5e4b2d6..d28aca6 100644 --- a/src/runtimes/streaming.rs +++ b/src/runtimes/streaming.rs @@ -3,7 +3,7 @@ //! This module handles the lifecycle of event streams during workflow //! execution, including finalization and cleanup. -use crate::event_bus::{Event, EventBus, STREAM_END_SCOPE}; +use crate::event_bus::{Event, EventBus, INVOCATION_END_SCOPE, STREAM_END_SCOPE}; /// Internal reason for ending an event stream. pub(crate) enum StreamEndReason { @@ -65,3 +65,21 @@ pub(crate) fn finalize_event_stream( *event_stream_taken = false; } } + +/// Emit a logical invocation completion marker without closing the event channel. +pub(crate) fn emit_invocation_end(event_bus: &EventBus, session_id: &str, reason: StreamEndReason) { + let message = reason.format_message(session_id); + + if let Err(err) = event_bus + .get_emitter() + .emit(Event::diagnostic(INVOCATION_END_SCOPE, message.clone())) + { + tracing::debug!( + session = %session_id, + scope = INVOCATION_END_SCOPE, + completion_message = %message, + error = ?err, + "failed to emit invocation completion event" + ); + } +} diff --git a/src/schedulers/mod.rs b/src/schedulers/mod.rs index 53f2dae..0761197 100644 --- a/src/schedulers/mod.rs +++ b/src/schedulers/mod.rs @@ -1,4 +1,6 @@ //! Frontier-based workflow scheduler with version gating and bounded concurrency. pub mod scheduler; -pub use scheduler::{Scheduler, SchedulerError, SchedulerState, StepRunResult}; +pub use scheduler::{ + Scheduler, SchedulerError, SchedulerRunContext, SchedulerState, StepRunResult, +}; diff --git a/src/schedulers/scheduler.rs b/src/schedulers/scheduler.rs index da8714c..c75b85e 100644 --- a/src/schedulers/scheduler.rs +++ b/src/schedulers/scheduler.rs @@ -42,6 +42,7 @@ use crate::event_bus::EventEmitter; use crate::node::{Node, NodeContext, NodeError, NodePartial}; use crate::state::StateSnapshot; use crate::types::NodeKind; +use crate::utils::clock::Clock; use futures_util::stream::{self, StreamExt}; use rustc_hash::FxHashMap; use std::sync::Arc; @@ -89,6 +90,44 @@ pub struct StepRunResult { pub outputs: Vec<(NodeKind, NodePartial)>, } +/// Runtime context passed to a scheduler superstep. +#[derive(Clone)] +#[non_exhaustive] +pub struct SchedulerRunContext { + /// Event emitter injected into node contexts. + pub event_emitter: Arc, + /// Optional runtime clock injected into node contexts. + pub clock: Option>, + /// Optional invocation identifier injected into node contexts. + pub invocation_id: Option, +} + +impl SchedulerRunContext { + /// Create scheduler runtime context with only an event emitter. + #[must_use] + pub fn new(event_emitter: Arc) -> Self { + Self { + event_emitter, + clock: None, + invocation_id: None, + } + } + + /// Attach a runtime clock. + #[must_use] + pub fn with_clock(mut self, clock: Arc) -> Self { + self.clock = Some(clock); + self + } + + /// Attach an invocation identifier. + #[must_use] + pub fn with_invocation_id(mut self, invocation_id: impl Into) -> Self { + self.invocation_id = Some(invocation_id.into()); + self + } +} + /// Tracks version information for nodes to enable intelligent scheduling. /// /// The scheduler uses this state to determine whether a node needs to run @@ -467,7 +506,7 @@ impl Scheduler { /// * `frontier` - Vector of nodes eligible for execution this step /// * `snap` - Pre-barrier state snapshot for version gating /// * `step` - Current workflow step number (for context and logging) - /// * `event_emitter` - Cloneable handle for sending execution events + /// * `run_context` - Runtime context injected into node execution /// /// # Returns /// * `Ok(StepRunResult)` - Execution results with ran/skipped nodes and outputs @@ -478,7 +517,7 @@ impl Scheduler { /// ```rust /// use weavegraph::channels::Channel; /// use weavegraph::event_bus::EventBus; - /// use weavegraph::schedulers::{Scheduler, SchedulerState}; + /// use weavegraph::schedulers::{Scheduler, SchedulerRunContext, SchedulerState}; /// use weavegraph::state::VersionedState; /// use weavegraph::types::NodeKind; /// use rustc_hash::FxHashMap; @@ -499,7 +538,7 @@ impl Scheduler { /// frontier, /// snapshot, /// 1, - /// event_bus.get_emitter(), + /// SchedulerRunContext::new(event_bus.get_emitter()), /// ).await?; /// /// println!("Executed {} nodes, skipped {}", @@ -514,7 +553,7 @@ impl Scheduler { /// - **Node Failures**: If any node returns an error, the entire superstep fails /// - **Task Panics**: Panicking nodes result in `SchedulerError::Join` /// - **Missing Nodes**: Panics if frontier contains nodes not in registry - #[instrument(skip(self, state, nodes, frontier, snap))] + #[instrument(skip(self, state, nodes, frontier, snap, run_context))] pub async fn superstep( &self, state: &mut SchedulerState, @@ -522,7 +561,7 @@ impl Scheduler { frontier: Vec, // frontier for this step snap: StateSnapshot, // pre-barrier snapshot step: u64, - event_emitter: Arc, + run_context: SchedulerRunContext, ) -> Result { // Partition frontier into to_run vs skipped using a skip predicate and version gating. let channels = Self::channel_versions(&snap); @@ -565,11 +604,15 @@ impl Scheduler { .map(|(id_str, kind)| { // SAFETY: We validated all nodes exist above, so this unwrap is safe. let node = nodes.get(&kind).unwrap().clone(); - let event_emitter = Arc::clone(&event_emitter); + let event_emitter = Arc::clone(&run_context.event_emitter); + let clock = run_context.clock.clone(); + let invocation_id = run_context.invocation_id.clone(); let ctx = NodeContext { node_id: id_str.clone(), step, event_emitter, + clock, + invocation_id, }; let s = snap.clone(); async move { diff --git a/src/state.rs b/src/state.rs index 246e955..43bb91d 100644 --- a/src/state.rs +++ b/src/state.rs @@ -36,13 +36,138 @@ //! ``` use rustc_hash::FxHashMap; +use serde::{Serialize, de::DeserializeOwned}; use serde_json::Value; +use std::marker::PhantomData; +use thiserror::Error; use crate::{ channels::{Channel, ErrorsChannel, ExtrasChannel, MessagesChannel}, message::{Message, Role}, }; +/// A schema-versioned key for typed values stored in [`VersionedState::extra`]. +/// +/// `StateKey` is a thin helper over the JSON-compatible `extra` map. Domain +/// crates can define constants and use them from nodes, reducers, tests, and +/// replay code without repeating string literals. +/// +/// # Examples +/// +/// ```rust +/// use serde::{Deserialize, Serialize}; +/// use weavegraph::state::StateKey; +/// +/// #[derive(Serialize, Deserialize)] +/// struct PortfolioSnapshot { +/// cash: i64, +/// } +/// +/// const PORTFOLIO: StateKey = +/// StateKey::new("wq", "portfolio_snapshot", 1); +/// +/// assert_eq!(PORTFOLIO.storage_key(), "wq:portfolio_snapshot:v1"); +/// ``` +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct StateKey { + namespace: &'static str, + name: &'static str, + schema_version: u32, + _marker: PhantomData T>, +} + +impl Clone for StateKey { + fn clone(&self) -> Self { + *self + } +} + +impl Copy for StateKey {} + +impl StateKey { + /// Create a typed state key. + pub const fn new(namespace: &'static str, name: &'static str, schema_version: u32) -> Self { + Self { + namespace, + name, + schema_version, + _marker: PhantomData, + } + } + + /// Return the namespace component. + #[must_use] + pub fn namespace(&self) -> &'static str { + self.namespace + } + + /// Return the key name component. + #[must_use] + pub fn name(&self) -> &'static str { + self.name + } + + /// Return the schema version component. + #[must_use] + pub fn schema_version(&self) -> u32 { + self.schema_version + } + + /// Return the concrete `extra` map key used for storage. + /// + /// The format is `namespace:name:v{schema_version}`. Changing the schema + /// version intentionally writes to a different slot, avoiding silent + /// collisions between incompatible payload shapes. + #[must_use] + pub fn storage_key(&self) -> String { + format!("{}:{}:v{}", self.namespace, self.name, self.schema_version) + } +} + +/// Errors produced by typed state-slot helpers. +#[derive(Debug, Error)] +#[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))] +pub enum StateSlotError { + /// The requested typed slot was not present in the state. + #[error("state slot not found: {key}")] + #[cfg_attr( + feature = "diagnostics", + diagnostic(code(weavegraph::state::slot_missing)) + )] + Missing { + /// The concrete storage key that was not found. + key: String, + }, + + /// A typed slot value could not be serialized to JSON. + #[error("failed to serialize state slot {key}: {source}")] + #[cfg_attr( + feature = "diagnostics", + diagnostic(code(weavegraph::state::slot_serialize)) + )] + Serialize { + /// The concrete storage key being written. + key: String, + /// The underlying serde error. + #[source] + source: serde_json::Error, + }, + + /// A typed slot value could not be deserialized from JSON. + #[error("failed to deserialize state slot {key}: {source}")] + #[cfg_attr( + feature = "diagnostics", + diagnostic(code(weavegraph::state::slot_deserialize)) + )] + Deserialize { + /// The concrete storage key being read. + key: String, + /// The underlying serde error. + #[source] + source: serde_json::Error, + }, +} + /// The main state container for workflow execution. /// /// `VersionedState` manages three independent channels of versioned data: @@ -322,6 +447,26 @@ impl VersionedState { self } + /// Adds a typed value to the extra channel using a schema-versioned key. + /// + /// The value is serialized to JSON and stored under + /// [`StateKey::storage_key`]. The channel version is still advanced by the + /// normal barrier system during graph execution. + pub fn add_typed_extra( + &mut self, + key: StateKey, + value: T, + ) -> Result<&mut Self, StateSlotError> { + let storage_key = key.storage_key(); + let json_value = + serde_json::to_value(value).map_err(|source| StateSlotError::Serialize { + key: storage_key.clone(), + source, + })?; + self.extra.get_mut().insert(storage_key, json_value); + Ok(self) + } + /// Creates an immutable snapshot of the current state. /// /// This method clones the current channel data and version numbers, @@ -369,6 +514,41 @@ impl VersionedState { } } +impl StateSnapshot { + /// Read an optional typed value from the extra channel. + /// + /// Returns `Ok(None)` when the slot is absent. Deserialization errors are + /// reported with the concrete storage key. + pub fn get_typed( + &self, + key: StateKey, + ) -> Result, StateSlotError> { + let storage_key = key.storage_key(); + self.extra + .get(&storage_key) + .cloned() + .map(|value| { + serde_json::from_value(value).map_err(|source| StateSlotError::Deserialize { + key: storage_key, + source, + }) + }) + .transpose() + } + + /// Read a required typed value from the extra channel. + /// + /// Use this when a node cannot proceed without a specific typed slot. + pub fn require_typed( + &self, + key: StateKey, + ) -> Result { + let storage_key = key.storage_key(); + self.get_typed(key)? + .ok_or(StateSlotError::Missing { key: storage_key }) + } +} + /// Builder for constructing VersionedState with fluent API. /// /// `VersionedStateBuilder` provides an ergonomic way to construct workflow state @@ -516,6 +696,22 @@ impl VersionedStateBuilder { self } + /// Adds a typed value to the extra channel using a schema-versioned key. + pub fn with_typed_extra( + mut self, + key: StateKey, + value: T, + ) -> Result { + let storage_key = key.storage_key(); + let json_value = + serde_json::to_value(value).map_err(|source| StateSlotError::Serialize { + key: storage_key.clone(), + source, + })?; + self.extra.insert(storage_key, json_value); + Ok(self) + } + /// Builds the final VersionedState. /// /// Creates a new VersionedState with all the configured messages and metadata. diff --git a/src/utils/clock.rs b/src/utils/clock.rs index f7c436f..66238de 100644 --- a/src/utils/clock.rs +++ b/src/utils/clock.rs @@ -8,10 +8,15 @@ use chrono::{DateTime, Utc}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; /// Trait for time sources providing both Unix timestamps and DateTime objects. -pub trait Clock: Send + Sync { +pub trait Clock: Send + Sync + std::fmt::Debug { /// Get the current time as a Unix timestamp (seconds since epoch). fn now(&self) -> u64; + /// Get the current time as a Unix timestamp in milliseconds. + fn now_unix_ms(&self) -> i64 { + self.now_datetime().timestamp_millis() + } + /// Get the current time as a `DateTime` for more complex time operations. fn now_datetime(&self) -> DateTime; diff --git a/tests/common/testing.rs b/tests/common/testing.rs index 1313295..aca48c6 100644 --- a/tests/common/testing.rs +++ b/tests/common/testing.rs @@ -24,11 +24,7 @@ mod tests { async fn test_testnode_construction() { let node = TestNode { name: "example" }; let bus = weavegraph::event_bus::EventBus::default(); - let ctx = NodeContext { - node_id: "test_node".to_string(), - step: 1, - event_emitter: bus.get_emitter(), - }; + let ctx = NodeContext::new("test_node", 1, bus.get_emitter()); let snapshot = VersionedState::builder().build().snapshot(); let result = node.run(snapshot, ctx).await; assert!(result.is_ok()); diff --git a/tests/event_bus.rs b/tests/event_bus.rs index 8a9813b..754b41c 100644 --- a/tests/event_bus.rs +++ b/tests/event_bus.rs @@ -9,8 +9,8 @@ use std::sync::Mutex; use std::time::Duration; use weavegraph::channels::Channel; use weavegraph::event_bus::{ - ChannelSink, Event, EventBus, EventEmitter, EventSink, JsonLinesSink, LLMStreamingEvent, - MemorySink, NodeEvent, STREAM_END_SCOPE, + ChannelSink, Event, EventBus, EventEmitter, EventSink, INVOCATION_END_SCOPE, JsonLinesSink, + LLMStreamingEvent, MemorySink, NodeEvent, STREAM_END_SCOPE, }; use weavegraph::node::NodeContext; @@ -541,11 +541,7 @@ impl EventEmitter for RecordingEmitter { fn node_context_emits_all_event_variants() { let emitter = Arc::new(RecordingEmitter::new()); let event_emitter: Arc = emitter.clone(); - let ctx = NodeContext { - node_id: "node-a".to_string(), - step: 7, - event_emitter, - }; + let ctx = NodeContext::new("node-a", 7, event_emitter); ctx.emit("progress", "started").unwrap(); ctx.emit_diagnostic("diagnostic", "all good").unwrap(); @@ -627,6 +623,40 @@ fn node_context_emits_all_event_variants() { } } +#[test] +fn node_event_metadata_defaults_when_deserializing_legacy_payloads() { + let legacy = r#"{"node_id":"legacy-node","step":3,"scope":"legacy","message":"old"}"#; + let event: NodeEvent = serde_json::from_str(legacy).expect("legacy node event should decode"); + + assert_eq!(event.node_id(), Some("legacy-node")); + assert_eq!(event.step(), Some(3)); + assert_eq!(event.scope(), "legacy"); + assert_eq!(event.message(), "old"); + assert!(event.metadata().is_empty()); +} + +#[test] +fn node_event_runtime_metadata_is_preserved_and_structured_fields_win_collisions() { + let mut metadata = FxHashMap::default(); + metadata.insert("custom".to_string(), json!({ "nested": true })); + metadata.insert("node_id".to_string(), json!("spoofed-node")); + metadata.insert("step".to_string(), json!(0)); + + let event = Event::node_message_with_metadata("real-node", 42, "scope", "message", metadata); + let value = event.to_json_value(); + + assert_eq!(value["metadata"]["custom"], json!({ "nested": true })); + assert_eq!(value["metadata"]["node_id"], "real-node"); + assert_eq!(value["metadata"]["step"], 42); +} + +#[test] +fn stream_scope_constants_are_distinct_and_stable() { + assert_eq!(STREAM_END_SCOPE, "__weavegraph_stream_end__"); + assert_eq!(INVOCATION_END_SCOPE, "__weavegraph_invocation_end__"); + assert_ne!(STREAM_END_SCOPE, INVOCATION_END_SCOPE); +} + fn text_strategy() -> impl Strategy { proptest::string::string_regex("[A-Za-z0-9 _-]{0,32}").unwrap() } @@ -647,7 +677,7 @@ fn event_strategy() -> impl Strategy { let diagnostic = (text_strategy(), text_strategy()) .prop_map(|(scope, message)| Event::diagnostic(scope, message)); - let node = ( + let plain_node = ( prop::option::of(text_strategy()), prop::option::of(any::()), text_strategy(), @@ -657,6 +687,18 @@ fn event_strategy() -> impl Strategy { Event::Node(NodeEvent::new(node_id, step, scope, message)) }); + let node_with_metadata = ( + text_strategy(), + any::(), + text_strategy(), + text_strategy(), + prop::collection::hash_map(text_strategy(), json_value_strategy(), 0..4), + ) + .prop_map(|(node_id, step, scope, message, metadata)| { + let meta: FxHashMap = metadata.into_iter().collect(); + Event::node_message_with_metadata(node_id, step, scope, message, meta) + }); + let llm = ( prop::option::of(text_strategy()), prop::option::of(text_strategy()), @@ -682,7 +724,7 @@ fn event_strategy() -> impl Strategy { }, ); - prop_oneof![diagnostic, node, llm] + prop_oneof![diagnostic, plain_node, node_with_metadata, llm] } proptest! { diff --git a/tests/graphs.rs b/tests/graphs.rs index 91aebce..5079c7f 100644 --- a/tests/graphs.rs +++ b/tests/graphs.rs @@ -1,8 +1,44 @@ mod common; use common::*; +use std::sync::Arc; use weavegraph::graphs::{EdgePredicate, GraphBuilder}; -use weavegraph::types::NodeKind; +use weavegraph::node::NodePartial; +use weavegraph::reducers::Reducer; +use weavegraph::state::VersionedState; +use weavegraph::types::{ChannelType, NodeKind}; + +struct FirstExtraReducer; + +impl Reducer for FirstExtraReducer { + fn apply(&self, _state: &mut VersionedState, _update: &NodePartial) {} +} + +struct SecondExtraReducer; + +impl Reducer for SecondExtraReducer { + fn apply(&self, _state: &mut VersionedState, _update: &NodePartial) {} +} + +struct StableLabelReducerA; + +impl Reducer for StableLabelReducerA { + fn definition_label(&self) -> &'static str { + "stable-extra-label" + } + + fn apply(&self, _state: &mut VersionedState, _update: &NodePartial) {} +} + +struct StableLabelReducerB; + +impl Reducer for StableLabelReducerB { + fn definition_label(&self) -> &'static str { + "stable-extra-label" + } + + fn apply(&self, _state: &mut VersionedState, _update: &NodePartial) {} +} #[test] fn test_add_conditional_edge() { @@ -77,6 +113,127 @@ fn test_compile() { ); } +#[test] +fn test_graph_metadata_and_hash_change_with_definition() { + let app_a = GraphBuilder::new() + .add_node(NodeKind::Custom("A".into()), NoopNode) + .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) + .add_edge(NodeKind::Custom("A".into()), NodeKind::End) + .compile() + .unwrap(); + let app_b = GraphBuilder::new() + .add_node(NodeKind::Custom("B".into()), NoopNode) + .add_edge(NodeKind::Start, NodeKind::Custom("B".into())) + .add_edge(NodeKind::Custom("B".into()), NodeKind::End) + .compile() + .unwrap(); + + let metadata = app_a.graph_metadata(); + assert_eq!(metadata.graph_hash, app_a.graph_definition_hash()); + assert_eq!(metadata.node_count, 1); + assert_eq!(metadata.edge_count, 2); + assert_eq!(metadata.conditional_edge_count, 0); + assert_ne!(app_a.graph_definition_hash(), app_b.graph_definition_hash()); +} + +#[test] +fn test_graph_hash_changes_with_reducer_identity() { + let app_a = GraphBuilder::new() + .add_node(NodeKind::Custom("A".into()), NoopNode) + .with_reducer(ChannelType::Extra, Arc::new(FirstExtraReducer)) + .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) + .add_edge(NodeKind::Custom("A".into()), NodeKind::End) + .compile() + .unwrap(); + let app_b = GraphBuilder::new() + .add_node(NodeKind::Custom("A".into()), NoopNode) + .with_reducer(ChannelType::Extra, Arc::new(SecondExtraReducer)) + .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) + .add_edge(NodeKind::Custom("A".into()), NodeKind::End) + .compile() + .unwrap(); + + assert_ne!(app_a.graph_definition_hash(), app_b.graph_definition_hash()); +} + +#[test] +fn test_graph_hash_is_stable_for_equivalent_definition_ordering() { + let app_a = GraphBuilder::new() + .add_node(NodeKind::Custom("A".into()), NoopNode) + .add_node(NodeKind::Custom("B".into()), NoopNode) + .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) + .add_edge(NodeKind::Start, NodeKind::Custom("B".into())) + .add_edge(NodeKind::Custom("A".into()), NodeKind::End) + .add_edge(NodeKind::Custom("B".into()), NodeKind::End) + .compile() + .unwrap(); + let app_b = GraphBuilder::new() + .add_node(NodeKind::Custom("B".into()), NoopNode) + .add_node(NodeKind::Custom("A".into()), NoopNode) + .add_edge(NodeKind::Custom("B".into()), NodeKind::End) + .add_edge(NodeKind::Custom("A".into()), NodeKind::End) + .add_edge(NodeKind::Start, NodeKind::Custom("B".into())) + .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) + .compile() + .unwrap(); + + assert_eq!(app_a.graph_definition_hash(), app_b.graph_definition_hash()); +} + +#[test] +fn test_graph_hash_changes_with_conditional_edge_registration_count() { + let route_to_end: EdgePredicate = Arc::new(|_snapshot| vec!["End".to_string()]); + let app_without_conditional = GraphBuilder::new() + .add_node(NodeKind::Custom("A".into()), NoopNode) + .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) + .add_edge(NodeKind::Custom("A".into()), NodeKind::End) + .compile() + .unwrap(); + let app_with_conditional = GraphBuilder::new() + .add_node(NodeKind::Custom("A".into()), NoopNode) + .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) + .add_edge(NodeKind::Custom("A".into()), NodeKind::End) + .add_conditional_edge(NodeKind::Custom("A".into()), route_to_end) + .compile() + .unwrap(); + + assert_eq!( + app_with_conditional.graph_metadata().conditional_edge_count, + 1 + ); + assert_ne!( + app_without_conditional.graph_definition_hash(), + app_with_conditional.graph_definition_hash() + ); +} + +#[test] +fn test_graph_hash_uses_custom_reducer_definition_label() { + let app_a = GraphBuilder::new() + .add_node(NodeKind::Custom("A".into()), NoopNode) + .with_reducer(ChannelType::Extra, Arc::new(StableLabelReducerA)) + .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) + .add_edge(NodeKind::Custom("A".into()), NodeKind::End) + .compile() + .unwrap(); + let app_b = GraphBuilder::new() + .add_node(NodeKind::Custom("A".into()), NoopNode) + .with_reducer(ChannelType::Extra, Arc::new(StableLabelReducerB)) + .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) + .add_edge(NodeKind::Custom("A".into()), NodeKind::End) + .compile() + .unwrap(); + + assert!( + app_a + .graph_metadata() + .reducer_signature + .iter() + .any(|entry| entry.contains("stable-extra-label")) + ); + assert_eq!(app_a.graph_definition_hash(), app_b.graph_definition_hash()); +} + #[test] fn test_compile_missing_entry() { let gb = GraphBuilder::new().add_edge(NodeKind::Start, NodeKind::End); diff --git a/tests/nodes.rs b/tests/nodes.rs index b14dbc5..1792f39 100644 --- a/tests/nodes.rs +++ b/tests/nodes.rs @@ -11,11 +11,7 @@ use weavegraph::utils::collections::new_extra_map; fn make_ctx(step: u64) -> (NodeContext, EventBus) { let event_bus = EventBus::default(); event_bus.listen_for_events(); - let ctx = NodeContext { - node_id: "test-node".to_string(), - step, - event_emitter: event_bus.get_emitter(), - }; + let ctx = NodeContext::new("test-node", step, event_bus.get_emitter()); (ctx, event_bus) } diff --git a/tests/runtime_config_metadata.rs b/tests/runtime_config_metadata.rs new file mode 100644 index 0000000..ba1859c --- /dev/null +++ b/tests/runtime_config_metadata.rs @@ -0,0 +1,133 @@ +use async_trait::async_trait; +use proptest::prelude::*; +use std::sync::Arc; +use weavegraph::runtimes::checkpointer::Result as CheckpointerResult; +use weavegraph::runtimes::runtime_config::DiagnosticsConfig; +use weavegraph::runtimes::{Checkpoint, Checkpointer, EventBusConfig, RuntimeConfig, SinkConfig}; +use weavegraph::utils::clock::MockClock; + +#[derive(Default)] +struct NoopCheckpointer; + +#[async_trait] +impl Checkpointer for NoopCheckpointer { + async fn save(&self, _checkpoint: Checkpoint) -> CheckpointerResult<()> { + Ok(()) + } + + async fn load_latest(&self, _session_id: &str) -> CheckpointerResult> { + Ok(None) + } + + async fn list_sessions(&self) -> CheckpointerResult> { + Ok(Vec::new()) + } +} + +#[test] +fn runtime_config_hash_is_stable_and_changes_for_metadata_boundaries() { + let base = RuntimeConfig::new( + Some("session-a".to_string()), + Some("db-a.sqlite".to_string()), + ) + .with_memory_event_bus(); + let same = RuntimeConfig::new( + Some("session-a".to_string()), + Some("db-a.sqlite".to_string()), + ) + .with_memory_event_bus(); + + assert_eq!(base.config_hash(), same.config_hash()); + assert_eq!(base.clock_mode(), "unset"); + assert!(base.clock().is_none()); + + let with_session = RuntimeConfig::new( + Some("session-b".to_string()), + Some("db-a.sqlite".to_string()), + ) + .with_memory_event_bus(); + let with_sqlite = RuntimeConfig::new( + Some("session-a".to_string()), + Some("db-b.sqlite".to_string()), + ) + .with_memory_event_bus(); + let with_clock = base.clone().with_clock(Arc::new(MockClock::new(7))); + let with_custom_checkpointer = base.clone().checkpointer_custom(Arc::new(NoopCheckpointer)); + + assert_ne!(base.config_hash(), with_session.config_hash()); + assert_ne!(base.config_hash(), with_sqlite.config_hash()); + assert_ne!(base.config_hash(), with_clock.config_hash()); + assert_ne!(base.config_hash(), with_custom_checkpointer.config_hash()); + assert_eq!(with_clock.clock_mode(), "configured"); + assert!(with_clock.clock().is_some()); + assert!(with_custom_checkpointer.custom_checkpointer().is_some()); +} + +#[test] +fn event_bus_config_normalizes_capacity_and_deduplicates_sinks() { + let config = EventBusConfig::new(0, Vec::new()) + .add_sink(SinkConfig::Memory) + .add_sink(SinkConfig::Memory) + .add_sink(SinkConfig::StdOut); + + assert_eq!( + config.buffer_capacity(), + EventBusConfig::DEFAULT_BUFFER_CAPACITY + ); + assert_eq!(config.sinks(), &[SinkConfig::Memory, SinkConfig::StdOut]); + + let signature = config.metadata_signature(); + assert!(signature.contains(&"event_buffer:1024".to_string())); + assert!(signature.contains(&"event_sink:0:Memory".to_string())); + assert!(signature.contains(&"event_sink:1:StdOut".to_string())); +} + +#[test] +fn diagnostics_config_defaults_and_overrides_are_reflected_in_metadata() { + let default_for_zero = DiagnosticsConfig::default_with_capacity(0); + assert_eq!(default_for_zero.effective_capacity(99), 1); + + let diagnostics = DiagnosticsConfig { + enabled: false, + buffer_capacity: None, + emit_to_events: true, + }; + let config = EventBusConfig::new(64, vec![SinkConfig::Memory]).with_diagnostics(diagnostics); + + assert_eq!(config.buffer_capacity(), 64); + let signature = config.metadata_signature(); + assert!(signature.contains(&"diagnostics_enabled:false".to_string())); + assert!(signature.contains(&"diagnostics_capacity:64".to_string())); + assert!(signature.contains(&"diagnostics_emit_to_events:true".to_string())); + + let _bus = config.build_event_bus(); +} + +proptest! { + #[test] + fn prop_runtime_config_hash_is_deterministic_for_generated_metadata( + session_id in "[A-Za-z0-9_-]{0,24}", + sqlite_db_name in "[A-Za-z0-9_.-]{1,24}", + capacity in 1usize..2048, + use_memory_sink in any::(), + diagnostics_enabled in any::(), + emit_diagnostics in any::(), + ) { + let sinks = if use_memory_sink { + vec![SinkConfig::Memory] + } else { + vec![SinkConfig::StdOut] + }; + let diagnostics = DiagnosticsConfig { + enabled: diagnostics_enabled, + buffer_capacity: Some(capacity), + emit_to_events: emit_diagnostics, + }; + let config = RuntimeConfig::new(Some(session_id.clone()), Some(sqlite_db_name.clone())) + .with_event_bus(EventBusConfig::new(capacity, sinks.clone()).with_diagnostics(diagnostics.clone())); + let same = RuntimeConfig::new(Some(session_id), Some(sqlite_db_name)) + .with_event_bus(EventBusConfig::new(capacity, sinks).with_diagnostics(diagnostics)); + + prop_assert_eq!(config.config_hash(), same.config_hash()); + } +} diff --git a/tests/runtimes_replay.rs b/tests/runtimes_replay.rs new file mode 100644 index 0000000..3007506 --- /dev/null +++ b/tests/runtimes_replay.rs @@ -0,0 +1,188 @@ +use proptest::prelude::*; +use serde_json::json; +use weavegraph::event_bus::Event; +use weavegraph::runtimes::{ + ReplayComparison, ReplayConformanceError, ReplayRun, compare_event_sequences, + compare_event_sequences_with, compare_final_state, compare_replay_runs, + compare_replay_runs_with, normalize_event, normalize_state, +}; +use weavegraph::state::VersionedState; + +#[test] +fn replay_event_normalization_removes_runtime_timestamp() { + let left = Event::node_message_with_meta("router", 1, "route", "selected"); + let right = Event::node_message_with_meta("router", 1, "route", "selected"); + + assert_eq!(normalize_event(&left), normalize_event(&right)); + assert!(compare_event_sequences(&[left], &[right]).is_match()); +} + +#[test] +fn replay_event_comparison_reports_mismatch() { + let comparison = compare_event_sequences( + &[Event::node_message("route", "selected-a")], + &[Event::node_message("route", "selected-b")], + ); + + assert!(!comparison.is_match()); + assert_eq!(comparison.differences().len(), 1); +} + +#[test] +fn replay_event_comparison_supports_custom_normalizer() { + let comparison = compare_event_sequences_with( + &[Event::node_message("route", "selected-a")], + &[Event::node_message("route", "selected-b")], + |event| json!({ "scope": event.scope_label(), "message": "ignored" }), + ); + + assert!(comparison.is_match()); +} + +#[test] +fn replay_run_comparison_checks_state_and_events() { + let left_state = VersionedState::builder() + .with_extra("value", json!(1)) + .build(); + let right_state = VersionedState::builder() + .with_extra("value", json!(1)) + .build(); + let different_state = VersionedState::builder() + .with_extra("value", json!(2)) + .build(); + + let left = ReplayRun::new(left_state, vec![Event::diagnostic("run", "done")]); + let right = ReplayRun::new(right_state, vec![Event::diagnostic("run", "done")]); + let different = ReplayRun::new(different_state, vec![Event::diagnostic("run", "done")]); + + assert!(compare_replay_runs(&left, &right).is_match()); + + let mismatch = compare_replay_runs(&left, &different); + assert!(!mismatch.is_match()); + assert!(matches!( + mismatch.assert_matches(), + Err(ReplayConformanceError::Mismatch { .. }) + )); +} + +#[test] +fn replay_comparison_constructors_and_assertion_errors_preserve_differences() { + assert!(ReplayComparison::matched().assert_matches().is_ok()); + + let comparison = ReplayComparison::with_differences(vec!["first".into(), "second".into()]); + assert!(!comparison.is_match()); + assert_eq!(comparison.differences(), &["first", "second"]); + + match comparison.assert_matches() { + Err(ReplayConformanceError::Mismatch { differences }) => { + assert_eq!(differences, vec!["first", "second"]); + } + other => panic!("expected mismatch error, got {other:?}"), + } +} + +#[test] +fn replay_event_comparison_reports_count_mismatch_when_shared_prefix_matches() { + let left = vec![ + Event::diagnostic("run", "one"), + Event::diagnostic("run", "two"), + ]; + let right = vec![Event::diagnostic("run", "one")]; + + let comparison = compare_event_sequences(&left, &right); + + assert!(!comparison.is_match()); + assert_eq!(comparison.differences().len(), 1); + assert!(comparison.differences()[0].contains("event count differs")); +} + +#[test] +fn replay_event_comparison_empty_sequences_match() { + assert!(compare_event_sequences(&[], &[]).is_match()); +} + +#[test] +fn replay_final_state_normalization_includes_versions_and_extra() { + let state = VersionedState::builder() + .with_user_message("hello") + .with_extra("answer", json!(42)) + .build(); + let normalized = normalize_state(&state); + + assert_eq!(normalized["messages_version"], 1); + assert_eq!(normalized["extra_version"], 1); + assert_eq!(normalized["errors_version"], 1); + assert_eq!(normalized["extra"]["answer"], 42); +} + +#[test] +fn replay_final_state_comparison_reports_state_mismatch() { + let left = VersionedState::builder() + .with_extra("value", json!(1)) + .build(); + let right = VersionedState::builder() + .with_extra("value", json!(2)) + .build(); + + let comparison = compare_final_state(&left, &right); + + assert!(!comparison.is_match()); + assert!(comparison.differences()[0].contains("final state differs")); +} + +#[test] +fn replay_run_comparison_aggregates_state_and_event_differences() { + let left = ReplayRun::new( + VersionedState::builder() + .with_extra("value", json!(1)) + .build(), + vec![Event::diagnostic("run", "left")], + ); + let right = ReplayRun::new( + VersionedState::builder() + .with_extra("value", json!(2)) + .build(), + vec![Event::diagnostic("run", "right")], + ); + + let comparison = compare_replay_runs(&left, &right); + + assert!(!comparison.is_match()); + assert_eq!(comparison.differences().len(), 2); + assert!(comparison.differences()[0].contains("final state differs")); + assert!(comparison.differences()[1].contains("event 0 differs")); +} + +#[test] +fn replay_run_custom_event_normalizer_can_ignore_event_differences() { + let left = ReplayRun::new( + VersionedState::builder() + .with_extra("value", json!(1)) + .build(), + vec![Event::diagnostic("run", "left")], + ); + let right = ReplayRun::new( + VersionedState::builder() + .with_extra("value", json!(1)) + .build(), + vec![Event::diagnostic("run", "right")], + ); + + let comparison = compare_replay_runs_with(&left, &right, |_| json!({ "event": "ignored" })); + + assert!(comparison.is_match()); +} + +proptest! { + #[test] + fn prop_replay_event_custom_normalizer_matches_same_length_sequences( + left in prop::collection::vec("[A-Za-z0-9 _-]{0,24}", 0..12), + right in prop::collection::vec("[A-Za-z0-9 _-]{0,24}", 0..12), + ) { + let len = left.len().min(right.len()); + let left_events: Vec = left.into_iter().take(len).map(|message| Event::diagnostic("scope", message)).collect(); + let right_events: Vec = right.into_iter().take(len).map(|message| Event::diagnostic("scope", message)).collect(); + + prop_assert!(compare_event_sequences_with(&left_events, &right_events, |_| json!("ignored")).is_match()); + } +} diff --git a/tests/runtimes_runner.rs b/tests/runtimes_runner.rs index a0a176e..d2f5815 100644 --- a/tests/runtimes_runner.rs +++ b/tests/runtimes_runner.rs @@ -2,10 +2,14 @@ use std::sync::{ Arc, RwLock, atomic::{AtomicUsize, Ordering}, }; +use std::time::Duration; use async_trait::async_trait; -#[cfg(feature = "sqlite")] +use serde_json::json; use weavegraph::channels::Channel; +use weavegraph::event_bus::{ + EventBus, EventStream, INVOCATION_END_SCOPE, MemorySink, STREAM_END_SCOPE, +}; use weavegraph::graphs::{EdgePredicate, GraphBuilder}; use weavegraph::message::{Message, Role}; use weavegraph::node::{Node, NodeContext, NodeError, NodePartial}; @@ -16,6 +20,7 @@ use weavegraph::runtimes::{ use weavegraph::schedulers::{Scheduler, SchedulerState}; use weavegraph::state::{StateSnapshot, VersionedState}; use weavegraph::types::NodeKind; +use weavegraph::utils::clock::MockClock; use weavegraph::{FrontierCommand, NodeRoute}; mod common; @@ -104,6 +109,93 @@ fn checkpoint_from_state(session_id: &str, step: u64, state: VersionedState) -> Checkpoint::from_session(session_id, &session_state) } +#[derive(Debug, Clone)] +struct TickAccumulatorNode; + +#[async_trait] +impl Node for TickAccumulatorNode { + async fn run( + &self, + snapshot: StateSnapshot, + ctx: NodeContext, + ) -> Result { + let tick = snapshot + .extra + .get("tick") + .and_then(serde_json::Value::as_i64) + .unwrap_or_default(); + ctx.emit("tick", format!("processed:{tick}"))?; + let sum = snapshot + .extra + .get("sum") + .and_then(serde_json::Value::as_i64) + .unwrap_or_default(); + + let mut extra = weavegraph::utils::collections::new_extra_map(); + extra.insert("sum".to_string(), serde_json::json!(sum + tick)); + extra.insert("last_tick".to_string(), serde_json::json!(tick)); + extra.insert("last_step".to_string(), serde_json::json!(ctx.step)); + + Ok(NodePartial::new().with_extra(extra)) + } +} + +fn make_iterative_app() -> weavegraph::app::App { + GraphBuilder::new() + .add_node(NodeKind::Custom("accumulate".into()), TickAccumulatorNode) + .add_edge(NodeKind::Start, NodeKind::Custom("accumulate".into())) + .add_edge(NodeKind::Custom("accumulate".into()), NodeKind::End) + .compile() + .unwrap() +} + +fn tick_input(tick: i64) -> NodePartial { + let mut extra = weavegraph::utils::collections::new_extra_map(); + extra.insert("tick".to_string(), serde_json::json!(tick)); + NodePartial::new().with_extra(extra) +} + +async fn recv_matching_event( + stream: &mut EventStream, + predicate: impl Fn(&weavegraph::event_bus::Event) -> bool, +) -> weavegraph::event_bus::Event { + for _ in 0..10 { + let event = tokio::time::timeout(Duration::from_secs(1), stream.recv()) + .await + .expect("event stream should receive an event") + .expect("event stream should stay open"); + if predicate(&event) { + return event; + } + } + panic!("matching event was not received"); +} + +#[derive(Debug, Clone)] +struct ClockProbeNode; + +#[async_trait] +impl Node for ClockProbeNode { + async fn run( + &self, + _snapshot: StateSnapshot, + ctx: NodeContext, + ) -> Result { + ctx.emit("clock", "observed")?; + + let mut extra = weavegraph::utils::collections::new_extra_map(); + extra.insert( + "now_unix_ms".to_string(), + serde_json::json!(ctx.now_unix_ms()), + ); + extra.insert( + "invocation_id".to_string(), + serde_json::json!(ctx.invocation_id()), + ); + Ok(NodePartial::new().with_extra(extra)) + } +} + #[tokio::test] async fn test_conditional_edge_routing() { let pred: EdgePredicate = std::sync::Arc::new(|snap: StateSnapshot| { @@ -339,6 +431,464 @@ async fn test_run_until_complete() { assert_message_contains(&final_state, "ran:test:step:1"); } +#[tokio::test] +async fn test_iterative_invocation_processes_identical_inputs() { + let app = make_iterative_app(); + let mut runner = AppRunner::builder() + .app(app) + .checkpointer(CheckpointerType::InMemory) + .build() + .await; + + let init = runner + .create_iterative_session( + "iterative-identical".to_string(), + state_with_user("start"), + NodeKind::Start, + ) + .await + .unwrap(); + assert_eq!(init, SessionInit::Fresh); + + let first = runner + .invoke_next("iterative-identical", tick_input(1), NodeKind::Start) + .await + .unwrap(); + assert_eq!( + first.extra.snapshot().get("sum"), + Some(&serde_json::json!(1)) + ); + + let second = runner + .invoke_next("iterative-identical", tick_input(1), NodeKind::Start) + .await + .unwrap(); + let extra = second.extra.snapshot(); + assert_eq!(extra.get("sum"), Some(&serde_json::json!(2))); + assert_eq!(extra.get("last_step"), Some(&serde_json::json!(2))); + + let session = runner.get_session("iterative-identical").unwrap(); + assert_eq!(session.step, 2); + assert_eq!(session.frontier, vec![NodeKind::End]); +} + +#[tokio::test] +async fn test_iterative_session_rejects_invalid_entry_without_creating_session() { + let app = make_iterative_app(); + let mut runner = AppRunner::builder() + .app(app) + .checkpointer(CheckpointerType::InMemory) + .build() + .await; + + let error = runner + .create_iterative_session( + "invalid-entry".to_string(), + state_with_user("start"), + NodeKind::End, + ) + .await + .unwrap_err(); + + assert!(matches!( + error, + weavegraph::runtimes::runner::RunnerError::InvalidIterativeEntry { + node: NodeKind::End + } + )); + assert!(runner.get_session("invalid-entry").is_none()); +} + +#[tokio::test] +async fn test_iterative_invocation_rejects_invalid_entry_without_applying_input() { + let app = make_iterative_app(); + let mut runner = AppRunner::builder() + .app(app) + .checkpointer(CheckpointerType::InMemory) + .build() + .await; + runner + .create_iterative_session( + "invalid-next".to_string(), + state_with_user("start"), + NodeKind::Start, + ) + .await + .unwrap(); + + let error = runner + .invoke_next("invalid-next", tick_input(99), NodeKind::End) + .await + .unwrap_err(); + + assert!(matches!( + error, + weavegraph::runtimes::runner::RunnerError::InvalidIterativeEntry { + node: NodeKind::End + } + )); + let session = runner.get_session("invalid-next").unwrap(); + assert_eq!(session.step, 0); + assert!(!session.state.snapshot().extra.contains_key("tick")); +} + +#[tokio::test] +async fn test_iterative_invocation_rejects_unregistered_custom_entry() { + let app = make_iterative_app(); + let mut runner = AppRunner::builder() + .app(app) + .checkpointer(CheckpointerType::InMemory) + .build() + .await; + runner + .create_iterative_session( + "unknown-custom".to_string(), + state_with_user("start"), + NodeKind::Start, + ) + .await + .unwrap(); + + let unknown = NodeKind::Custom("missing".to_string()); + let error = runner + .invoke_next("unknown-custom", tick_input(1), unknown.clone()) + .await + .unwrap_err(); + + assert!(matches!( + error, + weavegraph::runtimes::runner::RunnerError::InvalidIterativeEntry { node } if node == unknown + )); +} + +#[tokio::test] +async fn test_iterative_custom_entry_runs_from_registered_node() { + let app = GraphBuilder::new() + .add_node(NodeKind::Custom("first".into()), TickAccumulatorNode) + .add_node(NodeKind::Custom("second".into()), TickAccumulatorNode) + .add_edge(NodeKind::Start, NodeKind::Custom("first".into())) + .add_edge( + NodeKind::Custom("first".into()), + NodeKind::Custom("second".into()), + ) + .add_edge(NodeKind::Custom("second".into()), NodeKind::End) + .compile() + .unwrap(); + let mut runner = AppRunner::builder() + .app(app) + .checkpointer(CheckpointerType::InMemory) + .build() + .await; + runner + .create_iterative_session( + "custom-entry".to_string(), + state_with_user("start"), + NodeKind::Custom("second".into()), + ) + .await + .unwrap(); + + let final_state = runner + .invoke_next( + "custom-entry", + tick_input(5), + NodeKind::Custom("second".into()), + ) + .await + .unwrap(); + + assert_eq!(final_state.extra.snapshot().get("sum"), Some(&json!(5))); + assert_eq!(runner.get_session("custom-entry").unwrap().step, 1); +} + +#[tokio::test] +async fn test_iterative_event_stream_stays_open_until_finished() { + let app = make_iterative_app(); + let mut runner = AppRunner::builder() + .app(app) + .checkpointer(CheckpointerType::InMemory) + .build() + .await; + let mut stream = runner + .event_stream() + .expect("iterative subscription should be available"); + + runner + .create_iterative_session( + "iterative-stream".to_string(), + state_with_user("start"), + NodeKind::Start, + ) + .await + .unwrap(); + + runner + .invoke_next("iterative-stream", tick_input(1), NodeKind::Start) + .await + .unwrap(); + let first_tick = recv_matching_event(&mut stream, |event| { + event.scope_label() == Some("tick") && event.message() == "processed:1" + }) + .await; + assert_eq!(first_tick.scope_label(), Some("tick")); + let first_end = recv_matching_event(&mut stream, |event| { + event.scope_label() == Some(INVOCATION_END_SCOPE) + }) + .await; + assert!(first_end.message().contains("status=completed")); + + runner + .invoke_next("iterative-stream", tick_input(2), NodeKind::Start) + .await + .unwrap(); + let second_tick = recv_matching_event(&mut stream, |event| { + event.scope_label() == Some("tick") && event.message() == "processed:2" + }) + .await; + assert_eq!(second_tick.scope_label(), Some("tick")); + let second_end = recv_matching_event(&mut stream, |event| { + event.scope_label() == Some(INVOCATION_END_SCOPE) + }) + .await; + assert!(second_end.message().contains("status=completed")); + + runner.finish_iterative_session("iterative-stream").unwrap(); + let terminal = recv_matching_event(&mut stream, |event| { + event.scope_label() == Some(STREAM_END_SCOPE) + }) + .await; + assert!(terminal.message().contains("status=completed")); + + let closed = tokio::time::timeout(Duration::from_secs(1), stream.recv()) + .await + .expect("closed stream should resolve promptly"); + assert!(matches!( + closed, + Err(tokio::sync::broadcast::error::RecvError::Closed) + )); +} + +#[tokio::test] +async fn test_iterative_event_stream_reports_errors_without_closing_until_finished() { + let app = GraphBuilder::new() + .add_node(NodeKind::Custom("fail".into()), FailingNode::default()) + .add_edge(NodeKind::Start, NodeKind::Custom("fail".into())) + .add_edge(NodeKind::Custom("fail".into()), NodeKind::End) + .compile() + .unwrap(); + let mut runner = AppRunner::builder() + .app(app) + .checkpointer(CheckpointerType::InMemory) + .build() + .await; + let mut stream = runner + .event_stream() + .expect("event stream should be available"); + runner + .create_iterative_session( + "iterative-error-stream".to_string(), + state_with_user("start"), + NodeKind::Start, + ) + .await + .unwrap(); + + let result = runner + .invoke_next( + "iterative-error-stream", + NodePartial::new(), + NodeKind::Start, + ) + .await; + + assert!(result.is_err()); + let invocation_end = recv_matching_event(&mut stream, |event| { + event.scope_label() == Some(INVOCATION_END_SCOPE) + }) + .await; + assert!(invocation_end.message().contains("status=error")); + + runner + .finish_iterative_session("iterative-error-stream") + .unwrap(); + let terminal = recv_matching_event(&mut stream, |event| { + event.scope_label() == Some(STREAM_END_SCOPE) + }) + .await; + assert!(terminal.message().contains("status=completed")); +} + +#[tokio::test] +async fn test_finish_iterative_session_reports_missing_session() { + let app = make_iterative_app(); + let mut runner = AppRunner::builder() + .app(app) + .checkpointer(CheckpointerType::InMemory) + .build() + .await; + + let error = runner + .finish_iterative_session("missing-session") + .unwrap_err(); + + assert!(matches!( + error, + weavegraph::runtimes::runner::RunnerError::SessionNotFound { session_id } + if session_id == "missing-session" + )); +} + +#[tokio::test] +async fn test_iterative_invocation_resumes_latest_checkpoint() { + const SESSION_ID: &str = "iterative-resume"; + + let mut uninterrupted = AppRunner::builder() + .app(make_iterative_app()) + .checkpointer(CheckpointerType::InMemory) + .build() + .await; + uninterrupted + .create_iterative_session( + SESSION_ID.to_string(), + state_with_user("start"), + NodeKind::Start, + ) + .await + .unwrap(); + + let mut uninterrupted_state = state_with_user("unused"); + for tick in 1..=5 { + uninterrupted_state = uninterrupted + .invoke_next(SESSION_ID, tick_input(tick), NodeKind::Start) + .await + .unwrap(); + } + let uninterrupted_extra = uninterrupted_state.extra.snapshot(); + let uninterrupted_step = uninterrupted.get_session(SESSION_ID).unwrap().step; + + let probe = Arc::new(ProbeCheckpointer::default()); + let mut before_restart = AppRunner::builder() + .app(make_iterative_app()) + .checkpointer_custom(probe.clone()) + .build() + .await; + before_restart + .create_iterative_session( + SESSION_ID.to_string(), + state_with_user("start"), + NodeKind::Start, + ) + .await + .unwrap(); + for tick in 1..=3 { + before_restart + .invoke_next(SESSION_ID, tick_input(tick), NodeKind::Start) + .await + .unwrap(); + } + drop(before_restart); + + let mut after_restart = AppRunner::builder() + .app(make_iterative_app()) + .checkpointer_custom(probe.clone()) + .build() + .await; + let resumed = after_restart + .create_iterative_session( + SESSION_ID.to_string(), + state_with_user("ignored after checkpoint restore"), + NodeKind::Start, + ) + .await + .unwrap(); + assert_eq!(resumed, SessionInit::Resumed { checkpoint_step: 3 }); + + let mut resumed_state = state_with_user("unused"); + for tick in 4..=5 { + resumed_state = after_restart + .invoke_next(SESSION_ID, tick_input(tick), NodeKind::Start) + .await + .unwrap(); + } + + let resumed_extra = resumed_state.extra.snapshot(); + assert_eq!(resumed_extra.get("sum"), uninterrupted_extra.get("sum")); + assert_eq!(resumed_extra.get("last_tick"), Some(&serde_json::json!(5))); + assert_eq!( + after_restart.get_session(SESSION_ID).unwrap().step, + uninterrupted_step + ); + assert!(probe.load_calls() > 0); + assert!(probe.save_calls() > 0); +} + +#[tokio::test] +async fn test_runtime_clock_reaches_node_context_and_events() { + let app = GraphBuilder::new() + .add_node(NodeKind::Custom("clock".into()), ClockProbeNode) + .add_edge(NodeKind::Start, NodeKind::Custom("clock".into())) + .add_edge(NodeKind::Custom("clock".into()), NodeKind::End) + .compile() + .unwrap(); + let event_bus = EventBus::with_sink(MemorySink::new()); + let mut event_stream = event_bus.subscribe(); + + let mut runner = AppRunner::builder() + .app(app) + .checkpointer(CheckpointerType::InMemory) + .event_bus(event_bus) + .clock(Arc::new(MockClock::new(123))) + .build() + .await; + + runner + .create_session("clock-session".to_string(), state_with_user("clock")) + .await + .unwrap(); + let final_state = runner.run_until_complete("clock-session").await.unwrap(); + let extra = final_state.extra.snapshot(); + + assert_eq!(extra.get("now_unix_ms"), Some(&serde_json::json!(123_000))); + assert_eq!( + extra.get("invocation_id"), + Some(&serde_json::json!("clock-session")) + ); + + let mut node_event = None; + for _ in 0..5 { + let event = tokio::time::timeout(Duration::from_secs(1), event_stream.recv()) + .await + .expect("event stream should receive an event") + .expect("event stream should stay open"); + if event.scope_label() == Some("clock") { + node_event = Some(event); + break; + } + } + let node_event = node_event.expect("clock event should be captured"); + let event_json = node_event.to_json_value(); + assert_eq!(event_json["metadata"]["invocation_id"], "clock-session"); + assert_eq!(event_json["metadata"]["now_unix_ms"], 123_000); +} + +#[tokio::test] +async fn test_runner_metadata_reports_graph_runtime_and_backends() { + let app = make_iterative_app(); + let graph_hash = app.graph_definition_hash(); + let runner = AppRunner::builder() + .app(app) + .checkpointer(CheckpointerType::InMemory) + .clock(Arc::new(MockClock::new(1))) + .build() + .await; + + let metadata = runner.run_metadata(); + assert_eq!(metadata.graph_hash, graph_hash); + assert!(!metadata.runtime_config_hash.is_empty()); + assert_eq!(metadata.checkpointer_backend, "in-memory"); + assert_eq!(metadata.clock_mode, "configured"); +} + #[derive(Debug, Clone)] struct ReplaceController; diff --git a/tests/schedulers.rs b/tests/schedulers.rs index f9871f1..f05872a 100644 --- a/tests/schedulers.rs +++ b/tests/schedulers.rs @@ -1,11 +1,37 @@ +use async_trait::async_trait; use rustc_hash::FxHashMap; +use serde_json::json; use std::sync::Arc; use weavegraph::event_bus::EventBus; -use weavegraph::schedulers::scheduler::{Scheduler, SchedulerState, StepRunResult}; +use weavegraph::node::{Node, NodeContext, NodeError, NodePartial}; +use weavegraph::schedulers::scheduler::{ + Scheduler, SchedulerRunContext, SchedulerState, StepRunResult, +}; +use weavegraph::state::StateSnapshot; use weavegraph::types::NodeKind; +use weavegraph::utils::clock::MockClock; mod common; use common::{FailingNode, create_test_snapshot, make_delayed_registry, make_test_registry}; +#[derive(Debug, Clone)] +struct SchedulerContextProbe; + +#[async_trait] +impl Node for SchedulerContextProbe { + async fn run( + &self, + _snapshot: StateSnapshot, + ctx: NodeContext, + ) -> Result { + let mut extra = weavegraph::utils::collections::new_extra_map(); + extra.insert("node_id".to_string(), json!(ctx.node_id)); + extra.insert("step".to_string(), json!(ctx.step)); + extra.insert("now_unix_ms".to_string(), json!(ctx.now_unix_ms())); + extra.insert("invocation_id".to_string(), json!(ctx.invocation_id())); + Ok(NodePartial::new().with_extra(extra)) + } +} + #[tokio::test] async fn test_superstep_propagates_node_error() { let sched = Scheduler::new(4); @@ -26,7 +52,7 @@ async fn test_superstep_propagates_node_error() { frontier, snap, 1, - event_bus.get_emitter(), + SchedulerRunContext::new(event_bus.get_emitter()), ) .await; match res { @@ -88,7 +114,7 @@ async fn test_superstep_skips_end_and_nochange() { frontier.clone(), snap.clone(), 1, - event_bus.get_emitter(), + SchedulerRunContext::new(event_bus.get_emitter()), ) .await .unwrap(); @@ -109,7 +135,7 @@ async fn test_superstep_skips_end_and_nochange() { frontier.clone(), snap.clone(), 2, - event_bus.get_emitter(), + SchedulerRunContext::new(event_bus.get_emitter()), ) .await .unwrap(); @@ -131,7 +157,7 @@ async fn test_superstep_skips_end_and_nochange() { frontier.clone(), snap_bump, 3, - event_bus.get_emitter(), + SchedulerRunContext::new(event_bus.get_emitter()), ) .await .unwrap(); @@ -156,7 +182,7 @@ async fn test_superstep_outputs_order_agnostic() { frontier.clone(), snap, 1, - event_bus.get_emitter(), + SchedulerRunContext::new(event_bus.get_emitter()), ) .await .unwrap(); @@ -191,7 +217,7 @@ async fn test_superstep_serialized_with_limit_1() { frontier.clone(), snap, 1, - event_bus.get_emitter(), + SchedulerRunContext::new(event_bus.get_emitter()), ) .await .unwrap(); @@ -204,3 +230,49 @@ async fn test_superstep_serialized_with_limit_1() { let output_ids: Vec<_> = res.outputs.iter().map(|(id, _)| id.clone()).collect(); assert_eq!(output_ids, res.ran_nodes); } + +#[tokio::test] +async fn test_scheduler_run_context_injects_clock_and_invocation_id() { + let sched = Scheduler::new(1); + let mut state = SchedulerState::default(); + let mut nodes: FxHashMap> = FxHashMap::default(); + nodes.insert( + NodeKind::Custom("probe".into()), + Arc::new(SchedulerContextProbe), + ); + let event_bus = EventBus::default(); + + let result = sched + .superstep( + &mut state, + &nodes, + vec![NodeKind::Custom("probe".into())], + create_test_snapshot(1, 1), + 9, + SchedulerRunContext::new(event_bus.get_emitter()) + .with_clock(Arc::new(MockClock::new(1234))) + .with_invocation_id("scheduler-session"), + ) + .await + .unwrap(); + + assert_eq!(result.outputs.len(), 1); + let extra = result.outputs[0] + .1 + .extra + .as_ref() + .expect("probe should emit extra"); + assert_eq!( + extra.get("node_id"), + Some(&json!(format!( + "{:?}", + NodeKind::Custom("probe".to_string()) + ))) + ); + assert_eq!(extra.get("step"), Some(&json!(9))); + assert_eq!(extra.get("now_unix_ms"), Some(&json!(1_234_000))); + assert_eq!( + extra.get("invocation_id"), + Some(&json!("scheduler-session")) + ); +} diff --git a/tests/state_channels.rs b/tests/state_channels.rs index 1e22379..6ba9f4a 100644 --- a/tests/state_channels.rs +++ b/tests/state_channels.rs @@ -1,7 +1,44 @@ +use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; use weavegraph::channels::Channel; use weavegraph::message::{Message, Role}; -use weavegraph::state::VersionedState; +use weavegraph::node::NodePartial; +use weavegraph::state::{StateKey, StateSlotError, VersionedState}; + +use proptest::prelude::*; + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +struct PortfolioSnapshot { + cash_cents: i64, + position_count: u32, +} + +const PORTFOLIO: StateKey = StateKey::new("wq", "portfolio", 1); +const PORTFOLIO_V2: StateKey = StateKey::new("wq", "portfolio", 2); + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +struct PropertyPayload { + label: String, + amount: i64, + flags: Vec, +} + +const PROPERTY_PAYLOAD: StateKey = StateKey::new("prop", "payload", 1); + +struct AlwaysFailsSerialize; + +impl Serialize for AlwaysFailsSerialize { + fn serialize(&self, _serializer: S) -> Result + where + S: serde::Serializer, + { + Err(serde::ser::Error::custom( + "intentional serialization failure", + )) + } +} + +const FAILING_SLOT: StateKey = StateKey::new("wg", "failing", 1); #[test] fn test_new_with_user_message_initializes_fields() { @@ -141,3 +178,206 @@ fn test_convenience_methods() { assert_eq!(snapshot.extra.get("key1"), Some(&json!("value1"))); assert_eq!(snapshot.extra.get("key2"), Some(&json!(42))); } + +#[test] +fn test_typed_state_slots_round_trip() { + let portfolio = PortfolioSnapshot { + cash_cents: 12_345, + position_count: 2, + }; + + let state = VersionedState::builder() + .with_user_message("portfolio") + .with_typed_extra(PORTFOLIO, portfolio.clone()) + .unwrap() + .build(); + + let snapshot = state.snapshot(); + assert_eq!(PORTFOLIO.storage_key(), "wq:portfolio:v1"); + assert_eq!( + snapshot.get_typed(PORTFOLIO).unwrap(), + Some(portfolio.clone()) + ); + assert_eq!(snapshot.require_typed(PORTFOLIO).unwrap(), portfolio); +} + +#[test] +fn test_state_key_accessors_and_schema_versions() { + assert_eq!(PORTFOLIO.namespace(), "wq"); + assert_eq!(PORTFOLIO.name(), "portfolio"); + assert_eq!(PORTFOLIO.schema_version(), 1); + assert_eq!(PORTFOLIO.storage_key(), "wq:portfolio:v1"); + assert_eq!(PORTFOLIO_V2.storage_key(), "wq:portfolio:v2"); + assert_ne!(PORTFOLIO.storage_key(), PORTFOLIO_V2.storage_key()); +} + +#[test] +fn test_typed_state_slots_missing_and_optional_reads() { + let snapshot = VersionedState::builder().build().snapshot(); + + assert_eq!(snapshot.get_typed(PORTFOLIO).unwrap(), None); + match snapshot.require_typed::(PORTFOLIO) { + Err(StateSlotError::Missing { key }) => assert_eq!(key, "wq:portfolio:v1"), + other => panic!("expected missing slot error, got {other:?}"), + } +} + +#[test] +fn test_typed_state_slots_report_deserialization_errors_with_key() { + let state = VersionedState::builder() + .with_extra( + &PORTFOLIO.storage_key(), + json!({ "cash_cents": "not-an-integer", "position_count": 1 }), + ) + .build(); + + match state + .snapshot() + .require_typed::(PORTFOLIO) + { + Err(StateSlotError::Deserialize { key, source }) => { + assert_eq!(key, "wq:portfolio:v1"); + assert!(source.to_string().contains("invalid type")); + } + other => panic!("expected deserialize slot error, got {other:?}"), + } +} + +#[test] +fn test_typed_state_slots_report_serialization_errors_with_key() { + let builder_error = VersionedState::builder() + .with_typed_extra(FAILING_SLOT, AlwaysFailsSerialize) + .unwrap_err(); + match builder_error { + StateSlotError::Serialize { key, source } => { + assert_eq!(key, "wg:failing:v1"); + assert!( + source + .to_string() + .contains("intentional serialization failure") + ); + } + other => panic!("expected serialize slot error, got {other:?}"), + } + + let partial_error = NodePartial::new() + .with_typed_extra(FAILING_SLOT, AlwaysFailsSerialize) + .unwrap_err(); + assert!(matches!( + partial_error, + StateSlotError::Serialize { key, .. } if key == "wg:failing:v1" + )); +} + +#[test] +fn test_typed_state_slots_schema_versions_can_coexist() { + let v1 = PortfolioSnapshot { + cash_cents: 100, + position_count: 1, + }; + let v2 = PortfolioSnapshot { + cash_cents: 200, + position_count: 2, + }; + + let state = VersionedState::builder() + .with_typed_extra(PORTFOLIO, v1.clone()) + .unwrap() + .with_typed_extra(PORTFOLIO_V2, v2.clone()) + .unwrap() + .build(); + let snapshot = state.snapshot(); + + assert_eq!(snapshot.require_typed(PORTFOLIO).unwrap(), v1); + assert_eq!(snapshot.require_typed(PORTFOLIO_V2).unwrap(), v2); +} + +#[test] +fn test_versioned_state_add_typed_extra_chains_and_overwrites_slot() { + let first = PortfolioSnapshot { + cash_cents: 10, + position_count: 1, + }; + let second = PortfolioSnapshot { + cash_cents: 20, + position_count: 2, + }; + let mut state = VersionedState::new_with_user_message("typed"); + + state + .add_typed_extra(PORTFOLIO, first) + .unwrap() + .add_typed_extra(PORTFOLIO, second.clone()) + .unwrap(); + + assert_eq!(state.snapshot().require_typed(PORTFOLIO).unwrap(), second); +} + +#[test] +fn test_node_partial_with_typed_extra() { + let portfolio = PortfolioSnapshot { + cash_cents: 500, + position_count: 1, + }; + + let partial = NodePartial::new() + .with_typed_extra(PORTFOLIO, portfolio.clone()) + .unwrap(); + let extra = partial.extra.expect("typed extra should be inserted"); + let stored = extra + .get(&PORTFOLIO.storage_key()) + .expect("typed storage key should exist"); + + assert_eq!( + serde_json::from_value::(stored.clone()).unwrap(), + portfolio + ); +} + +#[test] +fn test_node_partial_with_typed_extra_merges_with_existing_extra_and_overwrites_same_slot() { + let old = PortfolioSnapshot { + cash_cents: 1, + position_count: 1, + }; + let new = PortfolioSnapshot { + cash_cents: 999, + position_count: 3, + }; + let mut extra = weavegraph::utils::collections::new_extra_map(); + extra.insert("untouched".to_string(), json!(true)); + extra.insert(PORTFOLIO.storage_key(), serde_json::to_value(old).unwrap()); + + let partial = NodePartial::new() + .with_extra(extra) + .with_typed_extra(PORTFOLIO, new.clone()) + .unwrap(); + let extra = partial.extra.expect("extra should be present"); + + assert_eq!(extra.get("untouched"), Some(&json!(true))); + assert_eq!( + serde_json::from_value::(extra[&PORTFOLIO.storage_key()].clone()) + .unwrap(), + new + ); +} + +proptest! { + #[test] + fn prop_typed_state_slots_round_trip_generated_payload( + label in "[A-Za-z0-9 _:-]{0,48}", + amount in any::(), + flags in prop::collection::vec(any::(), 0..16), + ) { + let payload = PropertyPayload { label, amount, flags }; + let state = VersionedState::builder() + .with_typed_extra(PROPERTY_PAYLOAD, payload.clone()) + .expect("generated payload should serialize") + .build(); + + prop_assert_eq!( + state.snapshot().get_typed(PROPERTY_PAYLOAD).expect("generated payload should deserialize"), + Some(payload) + ); + } +}