Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 32 additions & 7 deletions proc-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,23 +55,29 @@ pub fn bridge(args: TokenStream, item: TokenStream) -> TokenStream {
let wrapper_name = impl_name.to_string();
let wrapper_name = Ident::new(&wrapper_name, Span::call_site());

let (rest_args, is_variadic) = if let Some(last_arg) = bridge.sig.inputs.last()
&& is_slice(&last_arg)
// If the first parameter is dyn_state we can leave it off the Scheme parameter list
let has_dyn_state = bridge.sig.inputs.first().map(is_dyn_state).unwrap_or(false);
let scheme_inputs: Vec<_> = if has_dyn_state {
bridge.sig.inputs.iter().skip(1).collect()
} else {
bridge.sig.inputs.iter().collect()
};

let (rest_args, is_variadic) = if let Some(last_arg) = scheme_inputs.last()
&& is_slice(last_arg)
{
(quote!(rest_args), true)
} else {
(quote!(), false)
};

let num_args = if is_variadic {
bridge.sig.inputs.len().saturating_sub(1)
scheme_inputs.len().saturating_sub(1)
} else {
bridge.sig.inputs.len()
scheme_inputs.len()
};

let arg_names: Vec<_> = bridge
.sig
.inputs
let arg_names: Vec<_> = scheme_inputs
.iter()
.enumerate()
.map(|(i, arg)| {
Expand All @@ -86,6 +92,12 @@ pub fn bridge(args: TokenStream, item: TokenStream) -> TokenStream {

let arg_indices: Vec<_> = (0..num_args).collect();

let dyn_state_arg = if has_dyn_state {
quote!(dyn_state,)
} else {
quote!()
};

let visibility = bridge.vis.clone();

if bridge.sig.asyncness.is_some() {
Expand All @@ -103,6 +115,7 @@ pub fn bridge(args: TokenStream, item: TokenStream) -> TokenStream {
Box::pin(
async move {
let result = #impl_name(
#dyn_state_arg
#(
match (&args[#arg_indices]).try_into() {
Ok(ok) => ok,
Expand Down Expand Up @@ -163,6 +176,7 @@ pub fn bridge(args: TokenStream, item: TokenStream) -> TokenStream {
#bridge

let result = #impl_name(
#dyn_state_arg
#(
match (&args[#arg_indices]).try_into() {
Ok(ok) => ok,
Expand Down Expand Up @@ -431,6 +445,17 @@ fn is_slice(arg: &FnArg) -> bool {
matches!(arg, FnArg::Typed(PatType { ty, ..}) if matches!(ty.as_ref(), Type::Reference(TypeReference { elem, .. }) if matches!(elem.as_ref(), Type::Slice(_))))
}

fn is_dyn_state(arg: &FnArg) -> bool {
if let FnArg::Typed(PatType { ty, .. }) = arg {
if let Type::Reference(TypeReference { mutability: Some(_), elem, .. }) = ty.as_ref() {
if let Type::Path(TypePath { path, .. }) = elem.as_ref() {
return path.segments.last().map(|s| s.ident == "DynamicState").unwrap_or(false);
}
}
}
false
}

/// Derive the `Trace` trait for a type.
///
/// `Trace` assumes that all fields of the type implement `Trace` or are a `Gc`
Expand Down
67 changes: 66 additions & 1 deletion src/proc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ use crate::{
use parking_lot::RwLock;
use scheme_rs_macros::{cps_bridge, maybe_async, maybe_await};
use std::{
any::Any,
collections::HashMap,
fmt,
sync::{
Expand Down Expand Up @@ -571,6 +572,22 @@ impl Procedure {
maybe_await!(Application::new(self.clone(), args).eval(&mut DynamicState::default()))
}

/// Like [`call`], but with a user context that will be available to all
/// bridge functions via [`DynamicState::user_ctx`] throughout the call chain,
/// including across reentrant Rust-Scheme-Rust boundaries.
#[maybe_async]
pub fn call_with_ctx(
&self,
args: &[Value],
ctx: Arc<dyn Any + Send + Sync>,
) -> Result<Vec<Value>, Exception> {
let mut args = args.to_vec();
args.push(halt_continuation(self.get_runtime()));
let mut dyn_state = DynamicState::default();
dyn_state.set_user_ctx(ctx);
maybe_await!(Application::new(self.clone(), args).eval(&mut dyn_state))
}

#[cfg(feature = "async")]
pub fn call_sync(&self, args: &[Value]) -> Result<Vec<Value>, Exception> {
let mut args = args.to_vec();
Expand All @@ -579,6 +596,19 @@ impl Procedure {

Application::new(self.clone(), args).eval_sync(&mut DynamicState::default())
}

#[cfg(feature = "async")]
pub fn call_sync_with_ctx(
&self,
args: &[Value],
ctx: Arc<dyn Any + Send + Sync>,
) -> Result<Vec<Value>, Exception> {
let mut args = args.to_vec();
args.push(halt_continuation(self.get_runtime()));
let mut dyn_state = DynamicState::default();
dyn_state.set_user_ctx(ctx);
Application::new(self.clone(), args).eval_sync(&mut dyn_state)
}
}

static HALT_CONTINUATION: OnceLock<Value> = OnceLock::new();
Expand Down Expand Up @@ -754,10 +784,21 @@ pub fn apply(

/// The dynamic state of the running program, including winders, exception
/// handlers, and continuation marks.
#[derive(Clone, Debug, Trace)]
#[derive(Clone, Trace)]
pub struct DynamicState {
dyn_stack: Vec<DynStackElem>,
cont_marks: Vec<HashMap<Symbol, Value>>,
user_ctx: Option<Arc<dyn Any + Send + Sync>>,
}

impl fmt::Debug for DynamicState {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("DynamicState")
.field("dyn_stack", &self.dyn_stack)
.field("cont_marks", &self.cont_marks)
.field("user_ctx", &self.user_ctx.as_ref().map(|_| ".."))
.finish()
}
}

impl DynamicState {
Expand All @@ -769,6 +810,7 @@ impl DynamicState {
// the initial marks for them since there's no mechanism to allocate
// for them when they're run.
cont_marks: vec![HashMap::new()],
user_ctx: None,
}
}

Expand Down Expand Up @@ -892,6 +934,28 @@ impl DynamicState {
pub(crate) fn dyn_stack_is_empty(&self) -> bool {
self.dyn_stack.is_empty()
}

/// Get the user context, if set.
pub fn user_ctx(&self) -> Option<&Arc<dyn Any + Send + Sync>> {
self.user_ctx.as_ref()
}

/// Downcast the user context to a concrete type.
pub fn user_ctx_downcast<T: Send + Sync + 'static>(&self) -> Option<&T> {
self.user_ctx
.as_ref()
.and_then(|arc| arc.downcast_ref::<T>())
}

/// Set the user context.
pub fn set_user_ctx(&mut self, ctx: Arc<dyn Any + Send + Sync>) {
self.user_ctx = Some(ctx);
}

/// Clear the user context.
pub fn clear_user_ctx(&mut self) {
self.user_ctx = None;
}
}

impl Default for DynamicState {
Expand Down Expand Up @@ -1528,6 +1592,7 @@ unsafe extern "C" fn unwind_to_prompt(
[dyn_state.dyn_stack_len() + 1..]
.to_vec(),
cont_marks: saved_dyn_state.cont_marks.clone(),
user_ctx: dyn_state.user_ctx.clone(),
};
let (req_args, var) = {
let k_proc: Procedure = k.clone().try_into().unwrap();
Expand Down
Loading