Skip to content

Commit deebeb3

Browse files
committed
Add derives attr for types in IR
1 parent 739c990 commit deebeb3

4 files changed

Lines changed: 60 additions & 6 deletions

File tree

rule-preprocessor/src/ir.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ pub struct TypeInfo {
4444
pub is_refcount_pointer: bool,
4545
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
4646
pub is_unsafe_pointer: bool,
47+
#[serde(default, skip_serializing_if = "Vec::is_empty")]
48+
pub derives: Vec<String>,
4749
}
4850

4951
#[derive(Debug, Clone, Serialize, Deserialize)]

rule-preprocessor/src/main.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55

66
extern crate rustc_driver;
77
extern crate rustc_hir;
8+
extern crate rustc_infer;
89
extern crate rustc_interface;
910
extern crate rustc_middle;
1011
extern crate rustc_span;
12+
extern crate rustc_trait_selection;
1113

1214
mod ir;
1315
mod semantic;

rule-preprocessor/src/semantic.rs

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ fn find_rlib(deps_dir: &Path, crate_name: &str) -> Option<PathBuf> {
9090
struct FnDecl<'tcx> {
9191
source_file: String,
9292
name: String,
93+
def_id: rustc_span::def_id::DefId,
9394
body: &'tcx rustc_hir::Body<'tcx>,
9495
}
9596

@@ -124,11 +125,17 @@ struct MethodResolver {
124125
}
125126

126127
impl MethodResolver {
127-
fn resolve_fn_decl<'tcx>(&mut self, tcx: rustc_middle::ty::TyCtxt<'tcx>, f: &FnDecl<'tcx>) {
128-
if let Some(file_ir) = self.ir.all_ir.get_mut(&f.source_file)
129-
&& let Some(RuleIr::Fn(fn_ir)) = file_ir.get_mut(&f.name)
130-
{
131-
f.resolve_unknowns(tcx, fn_ir);
128+
fn resolve_rule<'tcx>(&mut self, tcx: rustc_middle::ty::TyCtxt<'tcx>, f: &FnDecl<'tcx>) {
129+
let Some(file_ir) = self.ir.all_ir.get_mut(&f.source_file) else {
130+
return;
131+
};
132+
match file_ir.get_mut(&f.name) {
133+
Some(RuleIr::Fn(fn_ir)) => f.resolve_unknowns(tcx, fn_ir),
134+
Some(RuleIr::Type(type_ir)) => {
135+
let ret_ty = tcx.fn_sig(f.def_id).skip_binder().output().skip_binder();
136+
type_ir.type_info.derives = type_derives(tcx, ret_ty);
137+
}
138+
None => {}
132139
}
133140
}
134141

@@ -154,7 +161,7 @@ impl rustc_driver::Callbacks for MethodResolver {
154161
tcx: rustc_middle::ty::TyCtxt<'_>,
155162
) -> rustc_driver::Compilation {
156163
for f in iter_fn_decls(tcx) {
157-
self.resolve_fn_decl(tcx, &f);
164+
self.resolve_rule(tcx, &f);
158165
}
159166

160167
rustc_driver::Compilation::Stop
@@ -181,6 +188,7 @@ fn iter_fn_decls<'tcx>(tcx: rustc_middle::ty::TyCtxt<'tcx>) -> Vec<FnDecl<'tcx>>
181188
result.push(FnDecl {
182189
source_file,
183190
name: ident.name.as_str().to_string(),
191+
def_id: decl_id.owner_id.to_def_id(),
184192
body: tcx.hir_body(body_id),
185193
});
186194
}
@@ -205,6 +213,45 @@ fn decl_source_file(
205213
)
206214
}
207215

216+
fn type_derives<'tcx>(
217+
tcx: rustc_middle::ty::TyCtxt<'tcx>,
218+
ty: rustc_middle::ty::Ty<'tcx>,
219+
) -> Vec<String> {
220+
use rustc_infer::infer::TyCtxtInferExt;
221+
use rustc_span::sym;
222+
use rustc_trait_selection::infer::InferCtxtExt;
223+
224+
let lang = tcx.lang_items();
225+
let derivable = [
226+
lang.copy_trait(),
227+
lang.clone_trait(),
228+
tcx.get_diagnostic_item(sym::Debug),
229+
tcx.get_diagnostic_item(sym::Default),
230+
tcx.get_diagnostic_item(sym::PartialEq),
231+
tcx.get_diagnostic_item(sym::Eq),
232+
tcx.get_diagnostic_item(sym::PartialOrd),
233+
tcx.get_diagnostic_item(sym::Ord),
234+
tcx.get_diagnostic_item(sym::Hash),
235+
];
236+
237+
let infcx = tcx
238+
.infer_ctxt()
239+
.build(rustc_middle::ty::TypingMode::non_body_analysis());
240+
let param_env = rustc_middle::ty::ParamEnv::empty();
241+
242+
derivable
243+
.into_iter()
244+
.flatten()
245+
.filter(|&trait_def_id| {
246+
let args = vec![ty; tcx.generics_of(trait_def_id).count()];
247+
infcx
248+
.type_implements_trait(trait_def_id, args, param_env)
249+
.must_apply_modulo_regions()
250+
})
251+
.map(|trait_def_id| tcx.item_name(trait_def_id).to_string())
252+
.collect()
253+
}
254+
208255
struct AstVisitor<'a, 'tcx> {
209256
tcx: rustc_middle::ty::TyCtxt<'tcx>,
210257
param_names: Vec<String>,

rule-preprocessor/src/syntactic.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,7 @@ impl<'a> FnIrBuilder<'a> {
322322
ty: ty_str,
323323
is_refcount_pointer,
324324
is_unsafe_pointer,
325+
derives: Vec::new(),
325326
})
326327
}
327328
}
@@ -482,6 +483,7 @@ impl<'a> FnIrBuilder<'a> {
482483
ty: p.ty.clone(),
483484
is_refcount_pointer: p.is_refcount_pointer,
484485
is_unsafe_pointer: p.is_unsafe_pointer,
486+
derives: Vec::new(),
485487
},
486488
)
487489
})
@@ -560,6 +562,7 @@ impl<'a> TypeIrBuilder<'a> {
560562
ty: ty.syntax().text().to_string(),
561563
is_refcount_pointer,
562564
is_unsafe_pointer,
565+
derives: Vec::new(),
563566
},
564567
}
565568
}

0 commit comments

Comments
 (0)