Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 31 additions & 23 deletions cpp2rust/converter/converter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <llvm/ADT/DenseMap.h>
#include <llvm/Support/ConvertUTF.h>

#include <algorithm>
#include <format>
#include <utility>

Expand Down Expand Up @@ -3210,46 +3211,53 @@ bool Converter::ConvertSwitchCaseCondition(clang::SwitchCase *stmt) {
}

if (clang::isa<clang::CaseStmt>(last)) {
StrCat(" => {");
StrCat(" => ");
} else /* DefaultStmt */ {
StrCat("_ => {");
StrCat("_ => ");
}
return false;
}

void Converter::EmitSwitchArm(clang::CompoundStmt *body, clang::SwitchCase *sc,
bool is_default) {
void Converter::EmitSwitchArm(const SwitchArm &arm, bool is_default) {
if (is_default) {
StrCat("_ => {");
StrCat("_ => ");
} else {
StrCat("__v if __v == ");
ConvertSwitchCaseCondition(sc);
ConvertSwitchCaseCondition(arm.head);
}
for (auto *t : GetSwitchCaseBody(body, sc)) {
if (!arm.label.empty()) {
StrCat(std::format("'{}: ", arm.label.str()));
}
StrCat(token::kOpenCurlyBracket);
for (auto *t : arm.body) {
Convert(t);
}
StrCat("},");
}

bool Converter::VisitSwitchStmt(clang::SwitchStmt *stmt) {
bool has_fallthrough = SwitchHasFallthrough(stmt);
PushBreakTarget push(break_target_, has_fallthrough
? BreakTarget::FallthroughSwitch
: BreakTarget::Switch);
auto *body = clang::dyn_cast<clang::CompoundStmt>(stmt->getBody());
assert(body);
auto arms = AnalyzeSwitchArms(body);

bool needs_switch_macro = std::ranges::any_of(arms, [](const SwitchArm &arm) {
return !arm.label.empty() || arm.has_fallthrough;
});

PushBreakTarget push(break_target_, needs_switch_macro
? BreakTarget::FallthroughSwitch
: BreakTarget::Switch);

if (has_fallthrough) {
// Use the switch-with-fallthrough macro
if (needs_switch_macro) {
StrCat("switch!");
} else {
StrCat("'switch:");
}

PushParen switch_macro_paren(*this, has_fallthrough);
PushBrace switch_label_brace(*this, !has_fallthrough);
PushParen switch_macro_paren(*this, needs_switch_macro);
PushBrace switch_label_brace(*this, !needs_switch_macro);

if (has_fallthrough) {
if (needs_switch_macro) {
StrCat("match", ToString(stmt->getCond()));
} else {
StrCat(
Expand All @@ -3259,17 +3267,17 @@ bool Converter::VisitSwitchStmt(clang::SwitchStmt *stmt) {

PushBrace match_brace(*this);

clang::SwitchCase *default_case = nullptr;
for (auto *sc : GetTopLevelSwitchCases(stmt)) {
if (SwitchCaseContainsDefault(sc)) {
default_case = sc;
const SwitchArm *default_arm = nullptr;
for (const auto &arm : arms) {
if (arm.is_default_case) {
default_arm = &arm;
continue;
}
EmitSwitchArm(body, sc, /*is_default=*/false);
EmitSwitchArm(arm, /*is_default=*/false);
}

if (default_case) {
EmitSwitchArm(body, default_case, /*is_default=*/true);
if (default_arm) {
EmitSwitchArm(*default_arm, /*is_default=*/true);
} else {
StrCat(R"( _ => {})");
}
Expand Down
4 changes: 2 additions & 2 deletions cpp2rust/converter/converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <utility>
#include <vector>

#include "converter/converter_lib.h"
#include "converter/lex.h"
#include "converter/translation_rule.h"
#include "logging.h"
Expand Down Expand Up @@ -367,8 +368,7 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {

virtual bool VisitSwitchStmt(clang::SwitchStmt *stmt);

void EmitSwitchArm(clang::CompoundStmt *body, clang::SwitchCase *sc,
bool is_default);
void EmitSwitchArm(const SwitchArm &arm, bool is_default);

bool ConvertSwitchCaseCondition(clang::SwitchCase *stmt);

Expand Down
75 changes: 30 additions & 45 deletions cpp2rust/converter/converter_lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -876,20 +876,15 @@ clang::Expr *NormalizeToBool(clang::Expr *expr, clang::ASTContext &ctx) {
/*BasePath=*/nullptr, clang::VK_PRValue, clang::FPOptionsOverride());
}

std::vector<clang::SwitchCase *>
GetTopLevelSwitchCases(clang::SwitchStmt *stmt) {
std::vector<clang::SwitchCase *> cases;
if (auto *body = llvm::dyn_cast<clang::CompoundStmt>(stmt->getBody())) {
for (auto *s : body->body()) {
if (auto *sc = clang::dyn_cast<clang::SwitchCase>(s)) {
cases.push_back(sc);
}
}
static clang::Stmt *GetLastStmtOfSwitchCase(clang::SwitchCase *c) {
clang::Stmt *cur = c->getSubStmt();
while (auto *sc = clang::dyn_cast<clang::SwitchCase>(cur)) {
cur = sc->getSubStmt();
}
return cases;
return cur;
}

bool SwitchCaseContainsDefault(clang::SwitchCase *c) {
static bool CaseChainHasDefault(clang::SwitchCase *c) {
for (clang::Stmt *cur = c;;) {
if (clang::isa<clang::DefaultStmt>(cur)) {
return true;
Expand All @@ -900,32 +895,6 @@ bool SwitchCaseContainsDefault(clang::SwitchCase *c) {
}
cur = sc->getSubStmt();
}
return false;
}

static clang::Stmt *GetLastStmtOfSwitchCase(clang::SwitchCase *c) {
clang::Stmt *cur = c->getSubStmt();
while (auto *sc = clang::dyn_cast<clang::SwitchCase>(cur)) {
cur = sc->getSubStmt();
}
return cur;
}

std::vector<clang::Stmt *> GetSwitchCaseBody(clang::CompoundStmt *body,
clang::SwitchCase *head) {
std::vector<clang::Stmt *> out;
out.push_back(GetLastStmtOfSwitchCase(head));
auto it = body->body_begin(), end = body->body_end();
while (it != end && *it != head) {
++it;
}
assert(it != end);
++it;
while (it != end && !clang::isa<clang::SwitchCase>(*it)) {
out.push_back(*it);
++it;
}
return out;
}

static bool SwitchCaseHasFallthrough(clang::Stmt *stmt) {
Expand All @@ -947,16 +916,32 @@ static bool SwitchCaseHasFallthrough(clang::Stmt *stmt) {
return true;
}

bool SwitchHasFallthrough(clang::SwitchStmt *stmt) {
if (auto *body = clang::dyn_cast<clang::CompoundStmt>(stmt->getBody())) {
for (auto top_level_case : GetTopLevelSwitchCases(stmt)) {
auto arm = GetSwitchCaseBody(body, top_level_case);
if (arm.empty() || SwitchCaseHasFallthrough(arm.back())) {
return true;
}
std::vector<SwitchArm> AnalyzeSwitchArms(clang::CompoundStmt *body) {
std::vector<SwitchArm> arms;
for (clang::Stmt *s : body->body()) {
llvm::StringRef label;
clang::Stmt *inner = s;
if (auto *outer = clang::dyn_cast<clang::LabelStmt>(inner)) {
label = outer->getDecl()->getName();
do {
inner = clang::cast<clang::LabelStmt>(inner)->getSubStmt();
} while (clang::isa<clang::LabelStmt>(inner));
}

if (auto *sc = clang::dyn_cast<clang::SwitchCase>(inner)) {
arms.emplace_back(std::vector<clang::Stmt *>{GetLastStmtOfSwitchCase(sc)},
label, sc, CaseChainHasDefault(sc),
/*has_fallthrough=*/false);
} else if (!arms.empty()) {
arms.back().body.push_back(s);
}
}
return false;

for (SwitchArm &arm : arms) {
arm.has_fallthrough =
arm.body.empty() || SwitchCaseHasFallthrough(arm.body.back());
}
return arms;
}

bool CompoundHasTopLevelLabel(const clang::CompoundStmt *compound) {
Expand Down
16 changes: 8 additions & 8 deletions cpp2rust/converter/converter_lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,15 +183,15 @@ bool ContainsVAArgExpr(const clang::Stmt *stmt);

clang::Expr *NormalizeToBool(clang::Expr *expr, clang::ASTContext &ctx);

std::vector<clang::SwitchCase *>
GetTopLevelSwitchCases(clang::SwitchStmt *stmt);

bool SwitchCaseContainsDefault(clang::SwitchCase *c);

std::vector<clang::Stmt *> GetSwitchCaseBody(clang::CompoundStmt *body,
clang::SwitchCase *head);
struct SwitchArm {
std::vector<clang::Stmt *> body;
llvm::StringRef label;
clang::SwitchCase *head;
bool is_default_case;
bool has_fallthrough;
};

bool SwitchHasFallthrough(clang::SwitchStmt *stmt);
std::vector<SwitchArm> AnalyzeSwitchArms(clang::CompoundStmt *body);

bool CompoundHasTopLevelLabel(const clang::CompoundStmt *compound);

Expand Down
19 changes: 13 additions & 6 deletions libcc2rs-macros/src/switch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

use proc_macro::TokenStream;
use syn::parse::{Parse, ParseStream};
use syn::{Expr, Pat, parse_macro_input};
use syn::{Expr, ExprBlock, Pat, parse_macro_input};

use crate::state_machine::{
Arm, DispatchCase, GotoStateMachine, StateMachine, StateMachineNames, SwitchStateMachine,
Expand All @@ -14,16 +14,23 @@ pub fn expand(input: TokenStream) -> TokenStream {
let mut cases = Vec::with_capacity(arms.len());
let mut cfg_arms = Vec::with_capacity(arms.len());
for (i, a) in arms.into_iter().enumerate() {
let label = format!("__c{}", i);
let (label, body) = match a.body {
Expr::Block(eb) if eb.label.is_some() => (
eb.label.unwrap().name.ident.to_string(),
Expr::Block(ExprBlock {
attrs: eb.attrs,
label: None,
block: eb.block,
}),
),
other => (format!("__c{}", i), other),
};
cases.push(DispatchCase {
pat: a.pat,
guard: a.guard,
target: label.clone(),
});
cfg_arms.push(Arm {
label,
body: a.body,
});
cfg_arms.push(Arm { label, body });
}
SwitchStateMachine {
goto: GotoStateMachine {
Expand Down
25 changes: 25 additions & 0 deletions tests/unit/goto_switch_self_case.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include <assert.h>

static int sm(int n) {
int steps = 0;
switch (n) {
target:
case 0:
steps += 1;
break;
case 1:
steps += 10;
goto target;
default:
steps = -1;
break;
}
return steps;
}

int main(void) {
assert(sm(0) == 1);
assert(sm(1) == 11);
assert(sm(7) == -1);
return 0;
}
36 changes: 36 additions & 0 deletions tests/unit/out/refcount/goto_switch_self_case.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
extern crate libcc2rs;
use libcc2rs::*;
use std::cell::RefCell;
use std::collections::BTreeMap;
use std::io::prelude::*;
use std::io::{Read, Seek, Write};
use std::os::fd::AsFd;
use std::rc::{Rc, Weak};
pub fn sm_0(n: i32) -> i32 {
let n: Value<i32> = Rc::new(RefCell::new(n));
let steps: Value<i32> = Rc::new(RefCell::new(0));
switch!(match (*n.borrow()) {
__v if __v == 0 => 'target: {
(*steps.borrow_mut()) += 1;
break;
}
__v if __v == 1 => {
(*steps.borrow_mut()) += 10;
goto!('target);
}
_ => {
(*steps.borrow_mut()) = -1_i32;
break;
}
});
return (*steps.borrow());
}
pub fn main() {
std::process::exit(main_0());
}
fn main_0() -> i32 {
assert!((((({ sm_0(0,) }) == 1) as i32) != 0));
assert!((((({ sm_0(1,) }) == 11) as i32) != 0));
assert!((((({ sm_0(7,) }) == -1_i32) as i32) != 0));
return 0;
}
37 changes: 37 additions & 0 deletions tests/unit/out/unsafe/goto_switch_self_case.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
extern crate libc;
use libc::*;
extern crate libcc2rs;
use libcc2rs::*;
use std::collections::BTreeMap;
use std::io::{Read, Seek, Write};
use std::os::fd::{AsFd, FromRawFd, IntoRawFd};
use std::rc::Rc;
pub unsafe fn sm_0(mut n: i32) -> i32 {
let mut steps: i32 = 0;
switch!(match n {
__v if __v == 0 => 'target: {
steps += 1;
break;
}
__v if __v == 1 => {
steps += 10;
goto!('target);
}
_ => {
steps = -1_i32;
break;
}
});
return steps;
}
pub fn main() {
unsafe {
std::process::exit(main_0() as i32);
}
}
unsafe fn main_0() -> i32 {
assert!(((((unsafe { sm_0(0,) }) == (1)) as i32) != 0));
assert!(((((unsafe { sm_0(1,) }) == (11)) as i32) != 0));
assert!(((((unsafe { sm_0(7,) }) == (-1_i32)) as i32) != 0));
return 0;
}
Loading