From fbad0856cd81fad509bdac7d52e4597968a16d48 Mon Sep 17 00:00:00 2001 From: Ellie <6687206+wizzeh@users.noreply.github.com> Date: Sat, 28 Mar 2026 15:51:16 -0700 Subject: [PATCH] Add support for dynamic user context --- proc-macros/src/lib.rs | 39 +++++++++++++++++++----- src/proc.rs | 67 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 98 insertions(+), 8 deletions(-) diff --git a/proc-macros/src/lib.rs b/proc-macros/src/lib.rs index 01fac7f4..c53ad024 100644 --- a/proc-macros/src/lib.rs +++ b/proc-macros/src/lib.rs @@ -55,8 +55,16 @@ pub fn bridge(args: TokenStream, item: TokenStream) -> TokenStream { let wrapper_name = impl_name.to_string(); let wrapper_name = Ident::new(&wrapper_name, Span::call_site()); - let (rest_args, is_variadic) = if let Some(last_arg) = bridge.sig.inputs.last() - && is_slice(&last_arg) + // If the first parameter is dyn_state we can leave it off the Scheme parameter list + let has_dyn_state = bridge.sig.inputs.first().map(is_dyn_state).unwrap_or(false); + let scheme_inputs: Vec<_> = if has_dyn_state { + bridge.sig.inputs.iter().skip(1).collect() + } else { + bridge.sig.inputs.iter().collect() + }; + + let (rest_args, is_variadic) = if let Some(last_arg) = scheme_inputs.last() + && is_slice(last_arg) { (quote!(rest_args), true) } else { @@ -64,14 +72,12 @@ pub fn bridge(args: TokenStream, item: TokenStream) -> TokenStream { }; let num_args = if is_variadic { - bridge.sig.inputs.len().saturating_sub(1) + scheme_inputs.len().saturating_sub(1) } else { - bridge.sig.inputs.len() + scheme_inputs.len() }; - let arg_names: Vec<_> = bridge - .sig - .inputs + let arg_names: Vec<_> = scheme_inputs .iter() .enumerate() .map(|(i, arg)| { @@ -86,6 +92,12 @@ pub fn bridge(args: TokenStream, item: TokenStream) -> TokenStream { let arg_indices: Vec<_> = (0..num_args).collect(); + let dyn_state_arg = if has_dyn_state { + quote!(dyn_state,) + } else { + quote!() + }; + let visibility = bridge.vis.clone(); if bridge.sig.asyncness.is_some() { @@ -103,6 +115,7 @@ pub fn bridge(args: TokenStream, item: TokenStream) -> TokenStream { Box::pin( async move { let result = #impl_name( + #dyn_state_arg #( match (&args[#arg_indices]).try_into() { Ok(ok) => ok, @@ -163,6 +176,7 @@ pub fn bridge(args: TokenStream, item: TokenStream) -> TokenStream { #bridge let result = #impl_name( + #dyn_state_arg #( match (&args[#arg_indices]).try_into() { Ok(ok) => ok, @@ -431,6 +445,17 @@ fn is_slice(arg: &FnArg) -> bool { matches!(arg, FnArg::Typed(PatType { ty, ..}) if matches!(ty.as_ref(), Type::Reference(TypeReference { elem, .. }) if matches!(elem.as_ref(), Type::Slice(_)))) } +fn is_dyn_state(arg: &FnArg) -> bool { + if let FnArg::Typed(PatType { ty, .. }) = arg { + if let Type::Reference(TypeReference { mutability: Some(_), elem, .. }) = ty.as_ref() { + if let Type::Path(TypePath { path, .. }) = elem.as_ref() { + return path.segments.last().map(|s| s.ident == "DynamicState").unwrap_or(false); + } + } + } + false +} + /// Derive the `Trace` trait for a type. /// /// `Trace` assumes that all fields of the type implement `Trace` or are a `Gc` diff --git a/src/proc.rs b/src/proc.rs index d3928280..a64cd479 100644 --- a/src/proc.rs +++ b/src/proc.rs @@ -108,6 +108,7 @@ use crate::{ use parking_lot::RwLock; use scheme_rs_macros::{cps_bridge, maybe_async, maybe_await}; use std::{ + any::Any, collections::HashMap, fmt, sync::{ @@ -571,6 +572,22 @@ impl Procedure { maybe_await!(Application::new(self.clone(), args).eval(&mut DynamicState::default())) } + /// Like [`call`], but with a user context that will be available to all + /// bridge functions via [`DynamicState::user_ctx`] throughout the call chain, + /// including across reentrant Rust-Scheme-Rust boundaries. + #[maybe_async] + pub fn call_with_ctx( + &self, + args: &[Value], + ctx: Arc, + ) -> Result, Exception> { + let mut args = args.to_vec(); + args.push(halt_continuation(self.get_runtime())); + let mut dyn_state = DynamicState::default(); + dyn_state.set_user_ctx(ctx); + maybe_await!(Application::new(self.clone(), args).eval(&mut dyn_state)) + } + #[cfg(feature = "async")] pub fn call_sync(&self, args: &[Value]) -> Result, Exception> { let mut args = args.to_vec(); @@ -579,6 +596,19 @@ impl Procedure { Application::new(self.clone(), args).eval_sync(&mut DynamicState::default()) } + + #[cfg(feature = "async")] + pub fn call_sync_with_ctx( + &self, + args: &[Value], + ctx: Arc, + ) -> Result, Exception> { + let mut args = args.to_vec(); + args.push(halt_continuation(self.get_runtime())); + let mut dyn_state = DynamicState::default(); + dyn_state.set_user_ctx(ctx); + Application::new(self.clone(), args).eval_sync(&mut dyn_state) + } } static HALT_CONTINUATION: OnceLock = OnceLock::new(); @@ -754,10 +784,21 @@ pub fn apply( /// The dynamic state of the running program, including winders, exception /// handlers, and continuation marks. -#[derive(Clone, Debug, Trace)] +#[derive(Clone, Trace)] pub struct DynamicState { dyn_stack: Vec, cont_marks: Vec>, + user_ctx: Option>, +} + +impl fmt::Debug for DynamicState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("DynamicState") + .field("dyn_stack", &self.dyn_stack) + .field("cont_marks", &self.cont_marks) + .field("user_ctx", &self.user_ctx.as_ref().map(|_| "..")) + .finish() + } } impl DynamicState { @@ -769,6 +810,7 @@ impl DynamicState { // the initial marks for them since there's no mechanism to allocate // for them when they're run. cont_marks: vec![HashMap::new()], + user_ctx: None, } } @@ -892,6 +934,28 @@ impl DynamicState { pub(crate) fn dyn_stack_is_empty(&self) -> bool { self.dyn_stack.is_empty() } + + /// Get the user context, if set. + pub fn user_ctx(&self) -> Option<&Arc> { + self.user_ctx.as_ref() + } + + /// Downcast the user context to a concrete type. + pub fn user_ctx_downcast(&self) -> Option<&T> { + self.user_ctx + .as_ref() + .and_then(|arc| arc.downcast_ref::()) + } + + /// Set the user context. + pub fn set_user_ctx(&mut self, ctx: Arc) { + self.user_ctx = Some(ctx); + } + + /// Clear the user context. + pub fn clear_user_ctx(&mut self) { + self.user_ctx = None; + } } impl Default for DynamicState { @@ -1528,6 +1592,7 @@ unsafe extern "C" fn unwind_to_prompt( [dyn_state.dyn_stack_len() + 1..] .to_vec(), cont_marks: saved_dyn_state.cont_marks.clone(), + user_ctx: dyn_state.user_ctx.clone(), }; let (req_args, var) = { let k_proc: Procedure = k.clone().try_into().unwrap();