Skip to content

Commit e91f356

Browse files
authored
feat!: support passing state from execute() to finalize() via typed Memo (#69)
## Summary

- Add a default type parameter `Memo = ()` to `TypedExecutor<T>` so executors can return typed state from `execute()` that is persisted to SQLite and delivered to `finalize()` after children complete
- New `Domain::task_memo()` and `Domain::task_with_memo()` registration methods for memo-producing executors; existing `Domain::task()` is unchanged
- Add `memo BLOB` column to `tasks` and `task_history` tables (migration 009); `TypeId::of::<()>()` check avoids DB writes for non-memo tasks
- Update all doc examples (`lib.rs`, `quick-start.md`, `migrating-to-0.5.md`) to include the new `_memo: ()` parameter in `finalize()` signatures

Closes #64

## Details

The memo is serialized via `serde_json` at the `set_waiting` transition (the single correct write point after execute returns and children are detected). On finalize dispatch, the memo bytes are deserialized back into the concrete `Memo` type. When `Memo = ()`, no serialization or DB write occurs — the `TypeId` guard short-circuits to `Ok(None)`.

### Public API changes

| Before | After |
|---|---|
| `TypedExecutor<T>` | `TypedExecutor<T, Memo = ()>` |
| `execute() -> Result<(), TaskError>` | `execute() -> Result<Memo, TaskError>` |
| `finalize(payload, ctx)` | `finalize(payload, memo, ctx)` |
| `Domain::task()` | `Domain::task()` (unchanged) + `Domain::task_memo()` |
| `Domain::task_with()` | `Domain::task_with()` (unchanged) + `Domain::task_with_memo()` |

### Migration

Existing executors with `Memo = ()` only need to add `_memo: ()` to their `finalize()` override (if any). Executors that don't override `finalize()` require zero changes.

## BREAKING CHANGE

`TypedExecutor::finalize()` signature adds a `Memo` parameter between `payload` and `ctx`.
1 parent d903234 commit e91f356

18 files changed

Lines changed: 675 additions & 44 deletions

File tree

docs/migrating-to-0.5.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ The `finalize` and `on_cancel` hooks follow the same pattern:
165165

166166
```rust
167167
impl TypedExecutor<Thumbnail> for ThumbnailExec {
168-
async fn finalize(&self, thumb: Thumbnail, ctx: &TaskContext) -> Result<(), TaskError> {
168+
async fn finalize(&self, thumb: Thumbnail, _memo: (), ctx: &TaskContext) -> Result<(), TaskError> {
169169
// called after all children settle
170170
Ok(())
171171
}

docs/quick-start.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ impl TypedExecutor<MultipartUpload> for MultipartUploader {
277277
}
278278

279279
async fn finalize(
280-
&self, upload: MultipartUpload, ctx: &TaskContext,
280+
&self, upload: MultipartUpload, _memo: (), ctx: &TaskContext,
281281
) -> Result<(), TaskError> {
282282
// Called after all children complete
283283
complete_multipart_upload(&upload).await

migrations/009_memo.sql

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
-- Execute-to-finalize memo: typed state persisted between phases.
2+
ALTER TABLE tasks ADD COLUMN memo BLOB DEFAULT NULL;
3+
ALTER TABLE task_history ADD COLUMN memo BLOB DEFAULT NULL;

src/domain.rs

Lines changed: 107 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ use std::pin::Pin;
2020
use std::sync::Arc;
2121
use std::time::Duration;
2222

23+
use serde::{de::DeserializeOwned, Serialize};
24+
2325
use crate::module::{ExecutorOptions, ModuleExecutor, ModuleHandle};
2426
use crate::priority::Priority;
2527
use crate::registry::{DomainTaskContext, ErasedExecutor, TaskContext, TaskExecutor};
@@ -181,19 +183,29 @@ pub struct TaskTypeOptions {
181183
/// }
182184
/// }
183185
/// ```
184-
pub trait TypedExecutor<T: TypedTask>: Send + Sync + 'static {
186+
pub trait TypedExecutor<
187+
T: TypedTask,
188+
Memo: Serialize + DeserializeOwned + Send + Sync + 'static = (),
189+
>: Send + Sync + 'static
190+
{
185191
/// Primary execution. Called once per dispatch.
192+
///
193+
/// Returns a `Memo` that will be persisted and passed to [`finalize()`](Self::finalize)
194+
/// after all children complete. For the default `Memo = ()`, the return type
195+
/// is `Result<(), TaskError>` — identical to the pre-memo API.
186196
fn execute<'a>(
187197
&'a self,
188198
payload: T,
189199
ctx: DomainTaskContext<'a, T::Domain>,
190-
) -> impl Future<Output = Result<(), TaskError>> + Send + 'a;
200+
) -> impl Future<Output = Result<Memo, TaskError>> + Send + 'a;
191201

192202
/// Called when all child tasks spawned by this task have settled.
203+
/// Receives the `Memo` returned by [`execute()`](Self::execute).
193204
/// Default: no-op.
194205
fn finalize<'a>(
195206
&'a self,
196207
_payload: T,
208+
_memo: Memo,
197209
_ctx: DomainTaskContext<'a, T::Domain>,
198210
) -> impl Future<Output = Result<(), TaskError>> + Send + 'a {
199211
async { Ok(()) }
@@ -212,26 +224,46 @@ pub trait TypedExecutor<T: TypedTask>: Send + Sync + 'static {
212224

213225
// ── TypedExecutorAdapter ─────────────────────────────────────────────
214226

215-
/// Internal adapter that wraps a [`TypedExecutor<T>`] into a [`TaskExecutor`]
227+
/// Internal adapter that wraps a [`TypedExecutor<T, Memo>`] into a [`TaskExecutor`]
216228
/// for the scheduler engine.
217229
///
218-
/// Handles payload deserialization before delegating to the typed executor.
219-
struct TypedExecutorAdapter<T, E> {
230+
/// Handles payload deserialization and memo serialization/deserialization.
231+
struct TypedExecutorAdapter<T, M, E> {
220232
executor: E,
221-
_marker: PhantomData<fn() -> T>,
233+
_marker: PhantomData<fn() -> (T, M)>,
222234
}
223235

224-
impl<T: TypedTask, E: TypedExecutor<T>> TaskExecutor for TypedExecutorAdapter<T, E> {
225-
async fn execute<'a>(&'a self, ctx: &'a TaskContext) -> Result<(), TaskError> {
236+
impl<T, M, E> TaskExecutor for TypedExecutorAdapter<T, M, E>
237+
where
238+
T: TypedTask,
239+
M: Serialize + DeserializeOwned + Send + Sync + 'static,
240+
E: TypedExecutor<T, M>,
241+
{
242+
async fn execute<'a>(&'a self, ctx: &'a TaskContext) -> Result<Option<Vec<u8>>, TaskError> {
226243
let payload: T = ctx.payload()?;
227244
let dctx = DomainTaskContext::<T::Domain>::new(ctx);
228-
self.executor.execute(payload, dctx).await
245+
let memo = self.executor.execute(payload, dctx).await?;
246+
247+
// Don't persist () — serialize to None.
248+
if std::any::TypeId::of::<M>() == std::any::TypeId::of::<()>() {
249+
return Ok(None);
250+
}
251+
252+
let bytes = serde_json::to_vec(&memo)
253+
.map_err(|e| TaskError::permanent(format!("memo serialization: {e}")))?;
254+
Ok(Some(bytes))
229255
}
230256

231257
async fn finalize<'a>(&'a self, ctx: &'a TaskContext) -> Result<(), TaskError> {
232258
let payload: T = ctx.payload()?;
259+
let memo: M = match &ctx.record().memo {
260+
Some(bytes) => serde_json::from_slice(bytes)
261+
.map_err(|e| TaskError::permanent(format!("memo deserialization: {e}")))?,
262+
None => serde_json::from_value(serde_json::Value::Null)
263+
.map_err(|e| TaskError::permanent(format!("memo deserialization: {e}")))?,
264+
};
233265
let dctx = DomainTaskContext::<T::Domain>::new(ctx);
234-
self.executor.finalize(payload, dctx).await
266+
self.executor.finalize(payload, memo, dctx).await
235267
}
236268

237269
async fn on_cancel<'a>(&'a self, ctx: &'a TaskContext) -> Result<(), TaskError> {
@@ -241,6 +273,19 @@ impl<T: TypedTask, E: TypedExecutor<T>> TaskExecutor for TypedExecutorAdapter<T,
241273
}
242274
}
243275

276+
/// Build an erased executor from a typed executor and memo type.
277+
fn erase_executor<T, M, E>(executor: E) -> Arc<dyn ErasedExecutor>
278+
where
279+
T: TypedTask,
280+
M: Serialize + DeserializeOwned + Send + Sync + 'static,
281+
E: TypedExecutor<T, M>,
282+
{
283+
Arc::new(TypedExecutorAdapter {
284+
executor,
285+
_marker: PhantomData::<fn() -> (T, M)>,
286+
})
287+
}
288+
244289
// ── Domain<D> ────────────────────────────────────────────────────────
245290

246291
/// A typed module builder that enforces the link between a [`DomainKey`],
@@ -317,7 +362,36 @@ impl<D: DomainKey> Domain<D> {
317362
T: TypedTask<Domain = D>,
318363
{
319364
let config = T::config();
320-
self.task_inner::<T>(executor, config.ttl, config.retry_policy)
365+
self.task_inner::<T>(
366+
erase_executor::<T, (), _>(executor),
367+
config.ttl,
368+
config.retry_policy,
369+
)
370+
}
371+
372+
/// Register a typed executor that produces a memo in `execute()` which
373+
/// is persisted and passed to `finalize()`.
374+
///
375+
/// Both `T` and `Memo` are inferred from the executor's
376+
/// `TypedExecutor<T, Memo>` impl — turbofish is only needed when the
377+
/// executor is generic over task types.
378+
///
379+
/// # Example
380+
///
381+
/// ```ignore
382+
/// domain.task_memo(ScanL1Executor)
383+
/// ```
384+
pub fn task_memo<T, Memo>(self, executor: impl TypedExecutor<T, Memo>) -> Self
385+
where
386+
T: TypedTask<Domain = D>,
387+
Memo: Serialize + DeserializeOwned + Send + Sync + 'static,
388+
{
389+
let config = T::config();
390+
self.task_inner::<T>(
391+
erase_executor::<T, Memo, _>(executor),
392+
config.ttl,
393+
config.retry_policy,
394+
)
321395
}
322396

323397
/// Register a typed executor with per-type option overrides.
@@ -332,25 +406,35 @@ impl<D: DomainKey> Domain<D> {
332406
let config = T::config();
333407
let ttl = options.ttl.or(config.ttl);
334408
let retry_policy = options.retry_policy.or(config.retry_policy);
335-
self.task_inner::<T>(executor, ttl, retry_policy)
409+
self.task_inner::<T>(erase_executor::<T, (), _>(executor), ttl, retry_policy)
336410
}
337411

338-
fn task_inner<T>(
339-
mut self,
340-
executor: impl TypedExecutor<T>,
341-
ttl: Option<Duration>,
342-
retry_policy: Option<RetryPolicy>,
412+
/// Like [`task_with()`](Self::task_with), but for executors that produce
413+
/// a memo (see [`task_memo()`](Self::task_memo)).
414+
pub fn task_with_memo<T, Memo>(
415+
self,
416+
executor: impl TypedExecutor<T, Memo>,
417+
options: TaskTypeOptions,
343418
) -> Self
344419
where
345-
T: TypedTask,
420+
T: TypedTask<Domain = D>,
421+
Memo: Serialize + DeserializeOwned + Send + Sync + 'static,
346422
{
347-
let adapter = TypedExecutorAdapter {
348-
executor,
349-
_marker: PhantomData::<fn() -> T>,
350-
};
423+
let config = T::config();
424+
let ttl = options.ttl.or(config.ttl);
425+
let retry_policy = options.retry_policy.or(config.retry_policy);
426+
self.task_inner::<T>(erase_executor::<T, Memo, _>(executor), ttl, retry_policy)
427+
}
428+
429+
fn task_inner<T: TypedTask>(
430+
mut self,
431+
executor: Arc<dyn ErasedExecutor>,
432+
ttl: Option<Duration>,
433+
retry_policy: Option<RetryPolicy>,
434+
) -> Self {
351435
self.executors.push(ModuleExecutor {
352436
task_type: T::TASK_TYPE.to_string(),
353-
executor: Arc::new(adapter) as Arc<dyn ErasedExecutor>,
437+
executor,
354438
options: ExecutorOptions { ttl, retry_policy },
355439
});
356440
self

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@
534534
//! Ok(())
535535
//! }
536536
//!
537-
//! async fn finalize(&self, upload: MultipartUpload, ctx: DomainTaskContext<'_, Uploads>) -> Result<(), TaskError> {
537+
//! async fn finalize(&self, upload: MultipartUpload, _memo: (), ctx: DomainTaskContext<'_, Uploads>) -> Result<(), TaskError> {
538538
//! // All parts uploaded — complete the multipart upload.
539539
//! complete_multipart(&upload).await?;
540540
//! Ok(())

src/registry/mod.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,14 @@ pub(crate) trait TaskExecutor: Send + Sync + 'static {
6161
/// - `ctx`: Execution context with the task record, cancellation token,
6262
/// and progress reporter.
6363
///
64-
/// On success, return `Ok(())`. Use [`TaskContext::record_read_bytes`]
64+
/// On success, return `Ok(None)` or `Ok(Some(bytes))` with serialized
65+
/// memo data to pass to `finalize()`. Use [`TaskContext::record_read_bytes`]
6566
/// and [`TaskContext::record_write_bytes`] to report IO during execution.
6667
/// On failure, return a [`TaskError`] indicating whether retry is appropriate.
6768
fn execute<'a>(
6869
&'a self,
6970
ctx: &'a TaskContext,
70-
) -> impl Future<Output = Result<(), TaskError>> + Send + 'a;
71+
) -> impl Future<Output = Result<Option<Vec<u8>>, TaskError>> + Send + 'a;
7172

7273
/// Called after all children of a parent task have completed.
7374
///
@@ -110,6 +111,9 @@ pub struct TaskTypeRegistry {
110111
type_retry_policies: HashMap<String, RetryPolicy>,
111112
}
112113

114+
/// Serialized memo bytes returned by `execute_erased`.
115+
type MemoBytes = Option<Vec<u8>>;
116+
113117
/// Object-safe wrapper around [`TaskExecutor`] for dynamic dispatch in the registry.
114118
///
115119
/// This trait exists because RPITIT (`impl Future`) in `TaskExecutor` is not
@@ -119,7 +123,7 @@ pub(crate) trait ErasedExecutor: Send + Sync + 'static {
119123
fn execute_erased<'a>(
120124
&'a self,
121125
ctx: &'a TaskContext,
122-
) -> std::pin::Pin<Box<dyn Future<Output = Result<(), TaskError>> + Send + 'a>>;
126+
) -> std::pin::Pin<Box<dyn Future<Output = Result<MemoBytes, TaskError>> + Send + 'a>>;
123127

124128
fn finalize_erased<'a>(
125129
&'a self,
@@ -136,7 +140,7 @@ impl<T: TaskExecutor> ErasedExecutor for T {
136140
fn execute_erased<'a>(
137141
&'a self,
138142
ctx: &'a TaskContext,
139-
) -> std::pin::Pin<Box<dyn Future<Output = Result<(), TaskError>> + Send + 'a>> {
143+
) -> std::pin::Pin<Box<dyn Future<Output = Result<MemoBytes, TaskError>> + Send + 'a>> {
140144
Box::pin(self.execute(ctx))
141145
}
142146

src/scheduler/spawn.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,9 @@ pub(crate) async fn spawn_task(
111111

112112
let result = match phase {
113113
ExecutionPhase::Execute => executor.execute_erased(&prepared.ctx).await,
114-
ExecutionPhase::Finalize => executor.finalize_erased(&prepared.ctx).await,
114+
ExecutionPhase::Finalize => {
115+
executor.finalize_erased(&prepared.ctx).await.map(|()| None)
116+
} // finalize doesn't produce a memo
115117
};
116118

117119
// Read IO bytes from the context tracker.
@@ -121,11 +123,12 @@ pub(crate) async fn spawn_task(
121123
drop(prepared.ctx);
122124

123125
match result {
124-
Ok(()) => {
126+
Ok(memo) => {
125127
completion::handle_success(
126128
&task,
127129
phase,
128130
&metrics,
131+
memo,
129132
&completion_deps,
130133
decrement_module,
131134
)

src/scheduler/spawn/completion.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ pub(crate) async fn handle_success(
3030
task: &TaskRecord,
3131
phase: ExecutionPhase,
3232
metrics: &IoBudget,
33+
memo: Option<Vec<u8>>,
3334
deps: &CompletionDeps,
3435
decrement_module: impl FnOnce(),
3536
) {
@@ -46,7 +47,7 @@ pub(crate) async fn handle_success(
4647
{
4748
match deps.store.active_children_count(task_id).await {
4849
Ok(count) if count > 0 => {
49-
if let Err(e) = deps.store.set_waiting(task_id).await {
50+
if let Err(e) = deps.store.set_waiting(task_id, memo.as_deref()).await {
5051
tracing::error!(task_id, error = %e, "failed to set task to waiting");
5152
}
5253
decrement_module();

src/scheduler/tests.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,7 @@ impl TypedExecutor<ParentTask> for FinalizeTrackingExecutor {
613613
async fn finalize<'a>(
614614
&'a self,
615615
_payload: ParentTask,
616+
_memo: (),
616617
_ctx: DomainTaskContext<'a, ParentDomain>,
617618
) -> Result<(), TaskError> {
618619
self.finalized

0 commit comments

Comments
 (0)