Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 26 additions & 10 deletions libcc2rs-macros/src/goto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
// Distributed under the MIT license that can be found in the LICENSE file.

use proc_macro::TokenStream;
use proc_macro2::Span;
use syn::parse::{Parse, ParseStream};
use syn::{Expr, Lifetime, Token, parse_macro_input};
use syn::{Block, Expr, ExprBlock, Lifetime, Stmt, parse_macro_input};

use crate::state_machine::{Arm, GotoStateMachine, StateMachine};
use crate::state_machine::{Arm, GotoStateMachine, StateMachine, StateMachineNames};

pub fn expand(input: TokenStream) -> TokenStream {
let GotoBlockInput { arms } = parse_macro_input!(input as GotoBlockInput);
GotoStateMachine {
names: StateMachineNames::fresh(),
arms: arms
.into_iter()
.map(|a| Arm {
Expand All @@ -33,15 +35,29 @@ struct GotoArm {

impl Parse for GotoBlockInput {
fn parse(input: ParseStream) -> syn::Result<Self> {
let block: Block = input.parse()?;
let mut arms = Vec::new();
while !input.is_empty() {
let label: Lifetime = input.parse()?;
input.parse::<Token![=>]>()?;
let body: Expr = input.parse()?;
arms.push(GotoArm { label, body });
if input.peek(Token![,]) {
input.parse::<Token![,]>()?;
}
for stmt in block.stmts {
let Stmt::Expr(Expr::Block(eb), _) = stmt else {
return Err(syn::Error::new(
Span::call_site(),
"goto_block! body must be a sequence of labeled blocks",
));
};
let Some(label) = eb.label else {
return Err(syn::Error::new(
Span::call_site(),
"goto_block! arm must be a labeled block",
));
};
arms.push(GotoArm {
label: label.name,
body: Expr::Block(ExprBlock {
attrs: eb.attrs,
label: None,
block: eb.block,
}),
});
}
Ok(Self { arms })
}
Expand Down
6 changes: 3 additions & 3 deletions libcc2rs-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ pub fn switch(input: TokenStream) -> TokenStream {
switch::expand(input)
}

// goto_block! {
// '<label> => { /* body; may contain `break` or `continue` */ },
// goto_block!({
// '<label>: { /* body; may contain `break`, `continue`, or goto!('other) */ }
// ...
// };
// });
//
// Expands to
//
Expand Down
108 changes: 96 additions & 12 deletions libcc2rs-macros/src/state_machine.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
// Copyright (c) 2022-present INESC-ID.
// Distributed under the MIT license that can be found in the LICENSE file.

use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};

use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
use quote::{format_ident, quote};
use syn::visit_mut::{self, VisitMut};
use syn::{Expr, ExprBreak, ExprContinue, Lifetime, Pat};
use syn::{Expr, ExprBreak, ExprContinue, Lifetime, Pat, Stmt, parse_quote};

pub struct Arm {
pub label: String,
Expand All @@ -21,12 +24,29 @@ pub trait StateMachine {
fn emit(self) -> TokenStream2;
}

fn sm_label() -> Lifetime {
Lifetime::new("'__sm", Span::call_site())
pub(crate) struct StateMachineNames {
pub label: Lifetime,
pub state: Ident,
pub break_flag: Ident,
pub cont_flag: Ident,
}

impl StateMachineNames {
pub fn fresh() -> Self {
static COUNTER: AtomicU64 = AtomicU64::new(0);
let id = COUNTER.fetch_add(1, Ordering::Relaxed);
Self {
label: Lifetime::new(&format!("'__sm{id}"), Span::call_site()),
state: format_ident!("__s{}", id),
break_flag: format_ident!("__user_break{}", id),
cont_flag: format_ident!("__user_continue{}", id),
}
}
}

// Collection of labeled arms that fall-through by default
pub struct GotoStateMachine {
pub names: StateMachineNames,
pub arms: Vec<Arm>,
}

Expand Down Expand Up @@ -87,10 +107,12 @@ impl GotoStateMachine {

impl StateMachine for GotoStateMachine {
fn emit(self) -> TokenStream2 {
let lbl = sm_label();
let s = format_ident!("__s");
let break_flag = format_ident!("__user_break");
let cont_flag = format_ident!("__user_continue");
let StateMachineNames {
label: lbl,
state: s,
break_flag,
cont_flag,
} = self.names;

let n = self.arms.len();
let mut arms_have_break = false;
Expand All @@ -101,6 +123,17 @@ impl StateMachine for GotoStateMachine {
.enumerate()
.map(|(i, arm)| {
let mut body = arm.body.clone();
GotoRewriter {
map: &self
.arms
.iter()
.enumerate()
.map(|(i, a)| (a.label.clone(), i as u32))
.collect(),
state: s.clone(),
sm_label: lbl.clone(),
}
.visit_expr_mut(&mut body);
let (had_br, had_cn) =
Self::propagate_rewrite(&mut body, &lbl, &break_flag, &cont_flag);
arms_have_break |= had_br;
Expand Down Expand Up @@ -131,6 +164,58 @@ impl StateMachine for GotoStateMachine {
}
}

// Rewrites `goto!('label)` into `{ __s = <target index>; continue '__sm; }`.
struct GotoRewriter<'a> {
// Map with labels and their indices inside the current state machine. Used to check if the
// label the goto jumps to is part of the current state machine. If it is, emit
// `__s = map[label]`
map: &'a HashMap<String, u32>,
state: Ident,
sm_label: Lifetime,
}

impl GotoRewriter<'_> {
fn expand_goto_into_state_machine_jump(&self, tokens: &TokenStream2) -> Option<Expr> {
let idx = *self.map.get(
&syn::parse2::<Lifetime>(tokens.clone())
.expect("goto! expects a lifetime label")
.ident
.to_string(),
)?;
let state = &self.state;
let sm_label = &self.sm_label;
Some(parse_quote!({ #state = #idx; continue #sm_label; }))
}

fn recurse_into_inner_goto_block(&mut self, mac: &mut syn::Macro) -> bool {
if mac.path.is_ident("switch") || mac.path.is_ident("goto_block") {
if let Ok(mut inner) = syn::parse2::<Expr>(mac.tokens.clone()) {
self.visit_expr_mut(&mut inner);
mac.tokens = quote!(#inner);
}
return true;
}
false
}
}

impl VisitMut for GotoRewriter<'_> {
fn visit_stmt_mut(&mut self, stmt: &mut Stmt) {
if let Stmt::Macro(sm) = stmt {
if sm.mac.path.is_ident("goto") {
if let Some(jump) = self.expand_goto_into_state_machine_jump(&sm.mac.tokens) {
*stmt = Stmt::Expr(jump, Some(Default::default()));
}
return;
}
if self.recurse_into_inner_goto_block(&mut sm.mac) {
return;
}
}
visit_mut::visit_stmt_mut(self, stmt);
}
}

// GotoStateMachine(dispatch arm + cases)
pub struct SwitchStateMachine {
pub goto: GotoStateMachine,
Expand Down Expand Up @@ -186,17 +271,16 @@ impl SwitchStateMachine {

impl StateMachine for SwitchStateMachine {
fn emit(self) -> TokenStream2 {
let lbl = sm_label();
let s = format_ident!("__s");
let names = StateMachineNames::fresh();

let user_arms = Self::convert_break_to_switch_exit(&self.goto.arms, &lbl);
let dispatch = self.build_dispatch_arm(&user_arms, &lbl, &s);
let user_arms = Self::convert_break_to_switch_exit(&self.goto.arms, &names.label);
let dispatch = self.build_dispatch_arm(&user_arms, &names.label, &names.state);

let mut arms = Vec::new();
arms.push(dispatch);
arms.extend(user_arms);

GotoStateMachine { arms }.emit()
GotoStateMachine { names, arms }.emit()
}
}

Expand Down
9 changes: 7 additions & 2 deletions libcc2rs-macros/src/switch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use proc_macro::TokenStream;
use syn::parse::{Parse, ParseStream};
use syn::{Expr, Pat, parse_macro_input};

use crate::state_machine::{Arm, DispatchCase, GotoStateMachine, StateMachine, SwitchStateMachine};
use crate::state_machine::{
Arm, DispatchCase, GotoStateMachine, StateMachine, StateMachineNames, SwitchStateMachine,
};

pub fn expand(input: TokenStream) -> TokenStream {
let SwitchInput { condition, arms } = parse_macro_input!(input as SwitchInput);
Expand All @@ -24,7 +26,10 @@ pub fn expand(input: TokenStream) -> TokenStream {
});
}
SwitchStateMachine {
goto: GotoStateMachine { arms: cfg_arms },
goto: GotoStateMachine {
names: StateMachineNames::fresh(),
arms: cfg_arms,
},
condition,
cases,
}
Expand Down
Loading
Loading