diff --git a/crates/cgp-macro-lib/src/cgp_fn/item_impl.rs b/crates/cgp-macro-lib/src/cgp_fn/item_impl.rs index ebef39e1..23be35c9 100644 --- a/crates/cgp-macro-lib/src/cgp_fn/item_impl.rs +++ b/crates/cgp-macro-lib/src/cgp_fn/item_impl.rs @@ -3,7 +3,9 @@ use syn::punctuated::Punctuated; use syn::token::Plus; use syn::{Generics, Ident, ItemFn, ItemImpl, TypeParamBound, parse2}; -use crate::cgp_fn::{FunctionAttributes, ImplicitArgField, substitute_abstract_type}; +use crate::cgp_fn::{ + FunctionAttributes, ImplicitArgField, derive_use_type_trait_bounds, substitute_abstract_type, +}; use crate::derive_getter::derive_getter_constraint; use crate::symbol::symbol_from_string; @@ -74,11 +76,8 @@ pub fn derive_item_impl( item_impl.to_token_stream(), ))?; - let mut bounds: Punctuated = Punctuated::default(); - - for use_type in attributes.use_type.iter() { - bounds.push(parse2(use_type.trait_path.to_token_stream())?); - } + let bounds = derive_use_type_trait_bounds("e! { Self }, &attributes.use_type)?; + let bounds = Punctuated::::from_iter(bounds); item_impl .generics diff --git a/crates/cgp-macro-lib/src/cgp_fn/mod.rs b/crates/cgp-macro-lib/src/cgp_fn/mod.rs index 3c094c21..a33875ba 100644 --- a/crates/cgp-macro-lib/src/cgp_fn/mod.rs +++ b/crates/cgp-macro-lib/src/cgp_fn/mod.rs @@ -6,6 +6,7 @@ mod item_trait; mod parse_implicits; mod spec; mod substitute_type; +mod type_equality; mod use_type; pub use attributes::*; @@ -13,4 +14,5 @@ pub use derive::*; pub use parse_implicits::*; pub use spec::*; pub use substitute_type::*; +pub use type_equality::*; pub use use_type::*; diff --git a/crates/cgp-macro-lib/src/cgp_fn/type_equality.rs b/crates/cgp-macro-lib/src/cgp_fn/type_equality.rs new file mode 100644 index 00000000..fd289ffc --- /dev/null +++ b/crates/cgp-macro-lib/src/cgp_fn/type_equality.rs @@ -0,0 +1,118 @@ +use proc_macro2::TokenStream; +use quote::{ToTokens, quote}; +use syn::punctuated::Punctuated; +use syn::token::Comma; +use syn::{Ident, Type, TypeParamBound, parse2}; + +use crate::cgp_fn::{UseTypeIdent, UseTypeSpec}; + +pub fn derive_use_type_trait_bounds( + context_type: &TokenStream, + specs: &[UseTypeSpec], +) -> syn::Result> { + let mut bounds = Vec::new(); + + for use_type in specs.iter() { + let type_equalities = find_type_equalities(use_type, context_type, specs)?; + + if type_equalities.is_empty() { + bounds.push(parse2(use_type.trait_path.to_token_stream())?); + } else { + let mut constraints: Punctuated = Punctuated::new(); + + for (alias_ident, equal_target) in type_equalities.into_iter() { + constraints.push(quote! { + #alias_ident = #equal_target + }); + } + + let trait_path = &use_type.trait_path; + let bound = quote! { + #trait_path < #constraints > + }; + + bounds.push(parse2(bound)?); + } + } + + Ok(bounds) +} + +pub fn find_type_equalities( + current_spec: &UseTypeSpec, + context_type: &TokenStream, + specs: &[UseTypeSpec], +) -> syn::Result> { + let mut equalities = Vec::new(); + + for current_type_ident in current_spec.type_idents.iter() { + forbid_same_alias(current_type_ident, current_spec, specs)?; + + if let Some(equality) = + find_type_equality(context_type, current_type_ident, current_spec, specs)? + { + equalities.push(equality); + } + } + + Ok(equalities) +} + +fn forbid_same_alias( + current_ident: &UseTypeIdent, + current_spec: &UseTypeSpec, + specs: &[UseTypeSpec], +) -> syn::Result<()> { + for spec in specs.iter() { + if core::ptr::eq(spec, current_spec) { + // Skip the current spec + continue; + } + + for type_ident in spec.type_idents.iter() { + if current_ident.alias_ident() == type_ident.alias_ident() { + return Err(syn::Error::new_spanned( + ¤t_ident.type_ident, + "Multiple abstract types cannot share the same identifier or alias", + )); + } + } + } + + Ok(()) +} + +fn find_type_equality( + context_type: &TokenStream, + current_ident: &UseTypeIdent, + current_spec: &UseTypeSpec, + specs: &[UseTypeSpec], +) -> syn::Result> { + if let Some(equal_target) = current_ident.equals.clone() { + for spec in specs.iter() { + if core::ptr::eq(spec, current_spec) { + // Skip the current spec + continue; + } + + for match_use_type in spec.type_idents.iter() { + let match_type: Type = parse2(match_use_type.alias_ident().to_token_stream())?; + if match_type == equal_target { + let trait_path = &spec.trait_path; + let current_type_ident = ¤t_ident.type_ident; + let match_type_ident = &match_use_type.type_ident; + + let equal_target: Type = parse2(quote! { + <#context_type as #trait_path>::#match_type_ident + })?; + + return Ok(Some((current_type_ident.clone(), equal_target))); + } + } + } + + Ok(Some((current_ident.type_ident.clone(), equal_target))) + } else { + Ok(None) + } +} diff --git a/crates/cgp-macro-lib/src/cgp_fn/use_type.rs b/crates/cgp-macro-lib/src/cgp_fn/use_type.rs index 08901872..d36949c1 100644 --- a/crates/cgp-macro-lib/src/cgp_fn/use_type.rs +++ b/crates/cgp-macro-lib/src/cgp_fn/use_type.rs @@ -1,6 +1,6 @@ use syn::parse::{Parse, ParseStream}; -use syn::token::{As, Brace, Colon, Comma, Gt, Lt}; -use syn::{Ident, braced}; +use syn::token::{As, Brace, Colon, Comma, Eq, Gt, Lt}; +use syn::{Ident, Type, braced}; use crate::parse::SimpleType; @@ -12,12 +12,13 @@ pub struct UseTypeSpec { pub struct UseTypeIdent { pub type_ident: Ident, pub as_alias: Option, + pub equals: Option, } impl UseTypeSpec { pub fn replace_ident(&self, ident: &Ident) -> Option { for type_ident in &self.type_idents { - if type_ident.replacement_ident() == ident { + if type_ident.alias_ident() == ident { let mut new_ident = type_ident.type_ident.clone(); new_ident.set_span(ident.span()); return Some(new_ident); @@ -29,7 +30,7 @@ impl UseTypeSpec { } impl UseTypeIdent { - pub fn replacement_ident(&self) -> &Ident { + pub fn alias_ident(&self) -> &Ident { self.as_alias.as_ref().unwrap_or(&self.type_ident) } } @@ -64,6 +65,7 @@ impl Parse for UseTypeSpec { vec![UseTypeIdent { type_ident: ident, as_alias: None, + equals: None, }] }; @@ -85,9 +87,17 @@ impl Parse for UseTypeIdent { None }; + let equals = if input.peek(Eq) { + let _: Eq = input.parse()?; + Some(input.parse()?) + } else { + None + }; + Ok(Self { type_ident, as_alias, + equals, }) } } diff --git a/crates/cgp-tests/tests/cgp_fn.rs b/crates/cgp-tests/tests/cgp_fn.rs index f623a10a..5e7cd854 100644 --- a/crates/cgp-tests/tests/cgp_fn.rs +++ b/crates/cgp-tests/tests/cgp_fn.rs @@ -1 +1,3 @@ +#![allow(clippy::disallowed_names)] + pub mod cgp_fn_tests; diff --git a/crates/cgp-tests/tests/cgp_fn_tests/mod.rs b/crates/cgp-tests/tests/cgp_fn_tests/mod.rs index 90833d71..2bbb504f 100644 --- a/crates/cgp-tests/tests/cgp_fn_tests/mod.rs +++ b/crates/cgp-tests/tests/cgp_fn_tests/mod.rs @@ -4,6 +4,7 @@ pub mod extend; pub mod generics; pub mod multi; pub mod mutable; +pub mod type_equality; pub mod use_type; pub mod use_type_alias; pub mod uses; diff --git a/crates/cgp-tests/tests/cgp_fn_tests/type_equality.rs b/crates/cgp-tests/tests/cgp_fn_tests/type_equality.rs new file mode 100644 index 00000000..8da2b1a3 --- /dev/null +++ b/crates/cgp-tests/tests/cgp_fn_tests/type_equality.rs @@ -0,0 +1,50 @@ +use std::fmt::Display; + +use cgp::prelude::*; + +#[cgp_type] +pub trait HasScalarType { + type Scalar; +} + +#[cgp_fn] +#[use_type(HasScalarType::{Scalar = f64})] +pub fn rectangle_area(&self, #[implicit] width: Scalar, #[implicit] height: Scalar) -> Scalar { + let res: f64 = width * height; + res +} + +pub trait HasFooType { + // The `Ord + Clone` bounds are visible to both `Foo` and `Bar` because of `Bar = Foo` below + type Foo: Ord + Clone; +} + +pub trait HasBarType { + // The `Display` bounds are hidden because of `Bar = Foo` below + type Bar: Display; +} + +#[cgp_fn] +#[use_type(HasFooType::Foo)] +pub fn do_foo(&self) -> Foo { + todo!() +} + +#[cgp_fn] +#[use_type(HasBarType::Bar)] +pub fn do_bar(&self) -> Bar { + todo!() +} + +#[cgp_fn] +#[use_type(HasBarType::{Bar as Baz = Foo}, HasFooType::Foo)] +#[uses(DoFoo, DoBar)] +fn return_foo_or_bar(&self, flag: bool, #[implicit] foo: &Foo, #[implicit] bar: &Baz) -> Foo { + if flag { + let res: Foo = self.do_foo(); + if &res < foo { res } else { foo.clone() } + } else { + let res: Baz = self.do_bar(); + if &res < bar { res } else { bar.clone() } + } +}