diff --git a/cpp2rust/converter/converter.cpp b/cpp2rust/converter/converter.cpp index 17eeece4..1dc131c5 100644 --- a/cpp2rust/converter/converter.cpp +++ b/cpp2rust/converter/converter.cpp @@ -10,6 +10,7 @@ #include #include +#include #include #include @@ -3210,46 +3211,53 @@ bool Converter::ConvertSwitchCaseCondition(clang::SwitchCase *stmt) { } if (clang::isa(last)) { - StrCat(" => {"); + StrCat(" => "); } else /* DefaultStmt */ { - StrCat("_ => {"); + StrCat("_ => "); } return false; } -void Converter::EmitSwitchArm(clang::CompoundStmt *body, clang::SwitchCase *sc, - bool is_default) { +void Converter::EmitSwitchArm(const SwitchArm &arm, bool is_default) { if (is_default) { - StrCat("_ => {"); + StrCat("_ => "); } else { StrCat("__v if __v == "); - ConvertSwitchCaseCondition(sc); + ConvertSwitchCaseCondition(arm.head); } - for (auto *t : GetSwitchCaseBody(body, sc)) { + if (!arm.label.empty()) { + StrCat(std::format("'{}: ", arm.label.str())); + } + StrCat(token::kOpenCurlyBracket); + for (auto *t : arm.body) { Convert(t); } StrCat("},"); } bool Converter::VisitSwitchStmt(clang::SwitchStmt *stmt) { - bool has_fallthrough = SwitchHasFallthrough(stmt); - PushBreakTarget push(break_target_, has_fallthrough - ? BreakTarget::FallthroughSwitch - : BreakTarget::Switch); auto *body = clang::dyn_cast(stmt->getBody()); assert(body); + auto arms = AnalyzeSwitchArms(body); + + bool needs_switch_macro = std::ranges::any_of(arms, [](const SwitchArm &arm) { + return !arm.label.empty() || arm.has_fallthrough; + }); + + PushBreakTarget push(break_target_, needs_switch_macro + ? BreakTarget::FallthroughSwitch + : BreakTarget::Switch); - if (has_fallthrough) { - // Use the switch-with-fallthrough macro + if (needs_switch_macro) { StrCat("switch!"); } else { StrCat("'switch:"); } - PushParen switch_macro_paren(*this, has_fallthrough); - PushBrace switch_label_brace(*this, !has_fallthrough); + PushParen switch_macro_paren(*this, needs_switch_macro); + PushBrace switch_label_brace(*this, !needs_switch_macro); - if (has_fallthrough) { + if (needs_switch_macro) { StrCat("match", ToString(stmt->getCond())); } else { StrCat( @@ -3259,17 +3267,17 @@ bool Converter::VisitSwitchStmt(clang::SwitchStmt *stmt) { PushBrace match_brace(*this); - clang::SwitchCase *default_case = nullptr; - for (auto *sc : GetTopLevelSwitchCases(stmt)) { - if (SwitchCaseContainsDefault(sc)) { - default_case = sc; + const SwitchArm *default_arm = nullptr; + for (const auto &arm : arms) { + if (arm.is_default_case) { + default_arm = &arm; continue; } - EmitSwitchArm(body, sc, /*is_default=*/false); + EmitSwitchArm(arm, /*is_default=*/false); } - if (default_case) { - EmitSwitchArm(body, default_case, /*is_default=*/true); + if (default_arm) { + EmitSwitchArm(*default_arm, /*is_default=*/true); } else { StrCat(R"( _ => {})"); } diff --git a/cpp2rust/converter/converter.h b/cpp2rust/converter/converter.h index eb2e90eb..a8d41078 100644 --- a/cpp2rust/converter/converter.h +++ b/cpp2rust/converter/converter.h @@ -16,6 +16,7 @@ #include #include +#include "converter/converter_lib.h" #include "converter/lex.h" #include "converter/translation_rule.h" #include "logging.h" @@ -367,8 +368,7 @@ class Converter : public clang::RecursiveASTVisitor { virtual bool VisitSwitchStmt(clang::SwitchStmt *stmt); - void EmitSwitchArm(clang::CompoundStmt *body, clang::SwitchCase *sc, - bool is_default); + void EmitSwitchArm(const SwitchArm &arm, bool is_default); bool ConvertSwitchCaseCondition(clang::SwitchCase *stmt); diff --git a/cpp2rust/converter/converter_lib.cpp b/cpp2rust/converter/converter_lib.cpp index 20ff5e25..3ba0ccc3 100644 --- a/cpp2rust/converter/converter_lib.cpp +++ b/cpp2rust/converter/converter_lib.cpp @@ -876,20 +876,15 @@ clang::Expr *NormalizeToBool(clang::Expr *expr, clang::ASTContext &ctx) { /*BasePath=*/nullptr, clang::VK_PRValue, clang::FPOptionsOverride()); } -std::vector -GetTopLevelSwitchCases(clang::SwitchStmt *stmt) { - std::vector cases; - if (auto *body = llvm::dyn_cast(stmt->getBody())) { - for (auto *s : body->body()) { - if (auto *sc = clang::dyn_cast(s)) { - cases.push_back(sc); - } - } +static clang::Stmt *GetLastStmtOfSwitchCase(clang::SwitchCase *c) { + clang::Stmt *cur = c->getSubStmt(); + while (auto *sc = clang::dyn_cast(cur)) { + cur = sc->getSubStmt(); } - return cases; + return cur; } -bool SwitchCaseContainsDefault(clang::SwitchCase *c) { +static bool CaseChainHasDefault(clang::SwitchCase *c) { for (clang::Stmt *cur = c;;) { if (clang::isa(cur)) { return true; @@ -900,32 +895,6 @@ bool SwitchCaseContainsDefault(clang::SwitchCase *c) { } cur = sc->getSubStmt(); } - return false; -} - -static clang::Stmt *GetLastStmtOfSwitchCase(clang::SwitchCase *c) { - clang::Stmt *cur = c->getSubStmt(); - while (auto *sc = clang::dyn_cast(cur)) { - cur = sc->getSubStmt(); - } - return cur; -} - -std::vector GetSwitchCaseBody(clang::CompoundStmt *body, - clang::SwitchCase *head) { - std::vector out; - out.push_back(GetLastStmtOfSwitchCase(head)); - auto it = body->body_begin(), end = body->body_end(); - while (it != end && *it != head) { - ++it; - } - assert(it != end); - ++it; - while (it != end && !clang::isa(*it)) { - out.push_back(*it); - ++it; - } - return out; } static bool SwitchCaseHasFallthrough(clang::Stmt *stmt) { @@ -947,16 +916,32 @@ static bool SwitchCaseHasFallthrough(clang::Stmt *stmt) { return true; } -bool SwitchHasFallthrough(clang::SwitchStmt *stmt) { - if (auto *body = clang::dyn_cast(stmt->getBody())) { - for (auto top_level_case : GetTopLevelSwitchCases(stmt)) { - auto arm = GetSwitchCaseBody(body, top_level_case); - if (arm.empty() || SwitchCaseHasFallthrough(arm.back())) { - return true; - } +std::vector AnalyzeSwitchArms(clang::CompoundStmt *body) { + std::vector arms; + for (clang::Stmt *s : body->body()) { + llvm::StringRef label; + clang::Stmt *inner = s; + if (auto *outer = clang::dyn_cast(inner)) { + label = outer->getDecl()->getName(); + do { + inner = clang::cast(inner)->getSubStmt(); + } while (clang::isa(inner)); + } + + if (auto *sc = clang::dyn_cast(inner)) { + arms.emplace_back(std::vector{GetLastStmtOfSwitchCase(sc)}, + label, sc, CaseChainHasDefault(sc), + /*has_fallthrough=*/false); + } else if (!arms.empty()) { + arms.back().body.push_back(s); } } - return false; + + for (SwitchArm &arm : arms) { + arm.has_fallthrough = + arm.body.empty() || SwitchCaseHasFallthrough(arm.body.back()); + } + return arms; } bool CompoundHasTopLevelLabel(const clang::CompoundStmt *compound) { diff --git a/cpp2rust/converter/converter_lib.h b/cpp2rust/converter/converter_lib.h index ece8b654..a09e0e5d 100644 --- a/cpp2rust/converter/converter_lib.h +++ b/cpp2rust/converter/converter_lib.h @@ -183,15 +183,15 @@ bool ContainsVAArgExpr(const clang::Stmt *stmt); clang::Expr *NormalizeToBool(clang::Expr *expr, clang::ASTContext &ctx); -std::vector -GetTopLevelSwitchCases(clang::SwitchStmt *stmt); - -bool SwitchCaseContainsDefault(clang::SwitchCase *c); - -std::vector GetSwitchCaseBody(clang::CompoundStmt *body, - clang::SwitchCase *head); +struct SwitchArm { + std::vector body; + llvm::StringRef label; + clang::SwitchCase *head; + bool is_default_case; + bool has_fallthrough; +}; -bool SwitchHasFallthrough(clang::SwitchStmt *stmt); +std::vector AnalyzeSwitchArms(clang::CompoundStmt *body); bool CompoundHasTopLevelLabel(const clang::CompoundStmt *compound); diff --git a/libcc2rs-macros/src/switch.rs b/libcc2rs-macros/src/switch.rs index 30845113..5f2e2f8d 100644 --- a/libcc2rs-macros/src/switch.rs +++ b/libcc2rs-macros/src/switch.rs @@ -3,7 +3,7 @@ use proc_macro::TokenStream; use syn::parse::{Parse, ParseStream}; -use syn::{Expr, Pat, parse_macro_input}; +use syn::{Expr, ExprBlock, Pat, parse_macro_input}; use crate::state_machine::{ Arm, DispatchCase, GotoStateMachine, StateMachine, StateMachineNames, SwitchStateMachine, @@ -14,16 +14,23 @@ pub fn expand(input: TokenStream) -> TokenStream { let mut cases = Vec::with_capacity(arms.len()); let mut cfg_arms = Vec::with_capacity(arms.len()); for (i, a) in arms.into_iter().enumerate() { - let label = format!("__c{}", i); + let (label, body) = match a.body { + Expr::Block(eb) if eb.label.is_some() => ( + eb.label.unwrap().name.ident.to_string(), + Expr::Block(ExprBlock { + attrs: eb.attrs, + label: None, + block: eb.block, + }), + ), + other => (format!("__c{}", i), other), + }; cases.push(DispatchCase { pat: a.pat, guard: a.guard, target: label.clone(), }); - cfg_arms.push(Arm { - label, - body: a.body, - }); + cfg_arms.push(Arm { label, body }); } SwitchStateMachine { goto: GotoStateMachine { diff --git a/tests/unit/goto_switch_self_case.c b/tests/unit/goto_switch_self_case.c new file mode 100644 index 00000000..565bede0 --- /dev/null +++ b/tests/unit/goto_switch_self_case.c @@ -0,0 +1,25 @@ +#include + +static int sm(int n) { + int steps = 0; + switch (n) { + target: + case 0: + steps += 1; + break; + case 1: + steps += 10; + goto target; + default: + steps = -1; + break; + } + return steps; +} + +int main(void) { + assert(sm(0) == 1); + assert(sm(1) == 11); + assert(sm(7) == -1); + return 0; +} diff --git a/tests/unit/out/refcount/goto_switch_self_case.rs b/tests/unit/out/refcount/goto_switch_self_case.rs new file mode 100644 index 00000000..86c3a7de --- /dev/null +++ b/tests/unit/out/refcount/goto_switch_self_case.rs @@ -0,0 +1,36 @@ +extern crate libcc2rs; +use libcc2rs::*; +use std::cell::RefCell; +use std::collections::BTreeMap; +use std::io::prelude::*; +use std::io::{Read, Seek, Write}; +use std::os::fd::AsFd; +use std::rc::{Rc, Weak}; +pub fn sm_0(n: i32) -> i32 { + let n: Value = Rc::new(RefCell::new(n)); + let steps: Value = Rc::new(RefCell::new(0)); + switch!(match (*n.borrow()) { + __v if __v == 0 => 'target: { + (*steps.borrow_mut()) += 1; + break; + } + __v if __v == 1 => { + (*steps.borrow_mut()) += 10; + goto!('target); + } + _ => { + (*steps.borrow_mut()) = -1_i32; + break; + } + }); + return (*steps.borrow()); +} +pub fn main() { + std::process::exit(main_0()); +} +fn main_0() -> i32 { + assert!((((({ sm_0(0,) }) == 1) as i32) != 0)); + assert!((((({ sm_0(1,) }) == 11) as i32) != 0)); + assert!((((({ sm_0(7,) }) == -1_i32) as i32) != 0)); + return 0; +} diff --git a/tests/unit/out/unsafe/goto_switch_self_case.rs b/tests/unit/out/unsafe/goto_switch_self_case.rs new file mode 100644 index 00000000..5a0b5c2e --- /dev/null +++ b/tests/unit/out/unsafe/goto_switch_self_case.rs @@ -0,0 +1,37 @@ +extern crate libc; +use libc::*; +extern crate libcc2rs; +use libcc2rs::*; +use std::collections::BTreeMap; +use std::io::{Read, Seek, Write}; +use std::os::fd::{AsFd, FromRawFd, IntoRawFd}; +use std::rc::Rc; +pub unsafe fn sm_0(mut n: i32) -> i32 { + let mut steps: i32 = 0; + switch!(match n { + __v if __v == 0 => 'target: { + steps += 1; + break; + } + __v if __v == 1 => { + steps += 10; + goto!('target); + } + _ => { + steps = -1_i32; + break; + } + }); + return steps; +} +pub fn main() { + unsafe { + std::process::exit(main_0() as i32); + } +} +unsafe fn main_0() -> i32 { + assert!(((((unsafe { sm_0(0,) }) == (1)) as i32) != 0)); + assert!(((((unsafe { sm_0(1,) }) == (11)) as i32) != 0)); + assert!(((((unsafe { sm_0(7,) }) == (-1_i32)) as i32) != 0)); + return 0; +}