Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions crates/cgp-macro-lib/src/cgp_fn/item_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use syn::punctuated::Punctuated;
use syn::token::Plus;
use syn::{Generics, Ident, ItemFn, ItemImpl, TypeParamBound, parse2};

use crate::cgp_fn::{FunctionAttributes, ImplicitArgField, substitute_abstract_type};
use crate::cgp_fn::{
FunctionAttributes, ImplicitArgField, derive_use_type_trait_bounds, substitute_abstract_type,
};
use crate::derive_getter::derive_getter_constraint;
use crate::symbol::symbol_from_string;

Expand Down Expand Up @@ -74,11 +76,8 @@ pub fn derive_item_impl(
item_impl.to_token_stream(),
))?;

let mut bounds: Punctuated<TypeParamBound, Plus> = Punctuated::default();

for use_type in attributes.use_type.iter() {
bounds.push(parse2(use_type.trait_path.to_token_stream())?);
}
let bounds = derive_use_type_trait_bounds(&quote! { Self }, &attributes.use_type)?;
let bounds = Punctuated::<TypeParamBound, Plus>::from_iter(bounds);

item_impl
.generics
Expand Down
2 changes: 2 additions & 0 deletions crates/cgp-macro-lib/src/cgp_fn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ mod item_trait;
mod parse_implicits;
mod spec;
mod substitute_type;
mod type_equality;
mod use_type;

pub use attributes::*;
pub use derive::*;
pub use parse_implicits::*;
pub use spec::*;
pub use substitute_type::*;
pub use type_equality::*;
pub use use_type::*;
118 changes: 118 additions & 0 deletions crates/cgp-macro-lib/src/cgp_fn/type_equality.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
use proc_macro2::TokenStream;
use quote::{ToTokens, quote};
use syn::punctuated::Punctuated;
use syn::token::Comma;
use syn::{Ident, Type, TypeParamBound, parse2};

use crate::cgp_fn::{UseTypeIdent, UseTypeSpec};

pub fn derive_use_type_trait_bounds(
context_type: &TokenStream,
specs: &[UseTypeSpec],
) -> syn::Result<Vec<TypeParamBound>> {
let mut bounds = Vec::new();

for use_type in specs.iter() {
let type_equalities = find_type_equalities(use_type, context_type, specs)?;

if type_equalities.is_empty() {
bounds.push(parse2(use_type.trait_path.to_token_stream())?);
} else {
let mut constraints: Punctuated<TokenStream, Comma> = Punctuated::new();

for (alias_ident, equal_target) in type_equalities.into_iter() {
constraints.push(quote! {
#alias_ident = #equal_target
});
}

let trait_path = &use_type.trait_path;
let bound = quote! {
#trait_path < #constraints >
};

bounds.push(parse2(bound)?);
}
}

Ok(bounds)
}

pub fn find_type_equalities(
current_spec: &UseTypeSpec,
context_type: &TokenStream,
specs: &[UseTypeSpec],
) -> syn::Result<Vec<(Ident, Type)>> {
let mut equalities = Vec::new();

for current_type_ident in current_spec.type_idents.iter() {
forbid_same_alias(current_type_ident, current_spec, specs)?;

if let Some(equality) =
find_type_equality(context_type, current_type_ident, current_spec, specs)?
{
equalities.push(equality);
}
}

Ok(equalities)
}

fn forbid_same_alias(
current_ident: &UseTypeIdent,
current_spec: &UseTypeSpec,
specs: &[UseTypeSpec],
) -> syn::Result<()> {
for spec in specs.iter() {
if core::ptr::eq(spec, current_spec) {
// Skip the current spec
continue;
}

for type_ident in spec.type_idents.iter() {
if current_ident.alias_ident() == type_ident.alias_ident() {
return Err(syn::Error::new_spanned(
&current_ident.type_ident,
"Multiple abstract types cannot share the same identifier or alias",
));
}
}
}

Ok(())
}

fn find_type_equality(
context_type: &TokenStream,
current_ident: &UseTypeIdent,
current_spec: &UseTypeSpec,
specs: &[UseTypeSpec],
) -> syn::Result<Option<(Ident, Type)>> {
if let Some(equal_target) = current_ident.equals.clone() {
for spec in specs.iter() {
if core::ptr::eq(spec, current_spec) {
// Skip the current spec
continue;
}

for match_use_type in spec.type_idents.iter() {
let match_type: Type = parse2(match_use_type.alias_ident().to_token_stream())?;
if match_type == equal_target {
let trait_path = &spec.trait_path;
let current_type_ident = &current_ident.type_ident;
let match_type_ident = &match_use_type.type_ident;

let equal_target: Type = parse2(quote! {
<#context_type as #trait_path>::#match_type_ident
})?;

return Ok(Some((current_type_ident.clone(), equal_target)));
}
}
}

Ok(Some((current_ident.type_ident.clone(), equal_target)))
} else {
Ok(None)
}
}
18 changes: 14 additions & 4 deletions crates/cgp-macro-lib/src/cgp_fn/use_type.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use syn::parse::{Parse, ParseStream};
use syn::token::{As, Brace, Colon, Comma, Gt, Lt};
use syn::{Ident, braced};
use syn::token::{As, Brace, Colon, Comma, Eq, Gt, Lt};
use syn::{Ident, Type, braced};

use crate::parse::SimpleType;

Expand All @@ -12,12 +12,13 @@ pub struct UseTypeSpec {
pub struct UseTypeIdent {
pub type_ident: Ident,
pub as_alias: Option<Ident>,
pub equals: Option<Type>,
}

impl UseTypeSpec {
pub fn replace_ident(&self, ident: &Ident) -> Option<Ident> {
for type_ident in &self.type_idents {
if type_ident.replacement_ident() == ident {
if type_ident.alias_ident() == ident {
let mut new_ident = type_ident.type_ident.clone();
new_ident.set_span(ident.span());
return Some(new_ident);
Expand All @@ -29,7 +30,7 @@ impl UseTypeSpec {
}

impl UseTypeIdent {
pub fn replacement_ident(&self) -> &Ident {
pub fn alias_ident(&self) -> &Ident {
self.as_alias.as_ref().unwrap_or(&self.type_ident)
}
}
Expand Down Expand Up @@ -64,6 +65,7 @@ impl Parse for UseTypeSpec {
vec![UseTypeIdent {
type_ident: ident,
as_alias: None,
equals: None,
}]
};

Expand All @@ -85,9 +87,17 @@ impl Parse for UseTypeIdent {
None
};

let equals = if input.peek(Eq) {
let _: Eq = input.parse()?;
Some(input.parse()?)
} else {
None
};

Ok(Self {
type_ident,
as_alias,
equals,
})
}
}
2 changes: 2 additions & 0 deletions crates/cgp-tests/tests/cgp_fn.rs
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
#![allow(clippy::disallowed_names)]

pub mod cgp_fn_tests;
1 change: 1 addition & 0 deletions crates/cgp-tests/tests/cgp_fn_tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub mod extend;
pub mod generics;
pub mod multi;
pub mod mutable;
pub mod type_equality;
pub mod use_type;
pub mod use_type_alias;
pub mod uses;
50 changes: 50 additions & 0 deletions crates/cgp-tests/tests/cgp_fn_tests/type_equality.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use std::fmt::Display;

use cgp::prelude::*;

#[cgp_type]
pub trait HasScalarType {
type Scalar;
}

#[cgp_fn]
#[use_type(HasScalarType::{Scalar = f64})]
pub fn rectangle_area(&self, #[implicit] width: Scalar, #[implicit] height: Scalar) -> Scalar {
let res: f64 = width * height;
res
}

pub trait HasFooType {
// The `Ord + Clone` bounds are visible to both `Foo` and `Bar` because of `Bar = Foo` below
type Foo: Ord + Clone;
}

pub trait HasBarType {
// The `Display` bounds are hidden because of `Bar = Foo` below
type Bar: Display;
}

#[cgp_fn]
#[use_type(HasFooType::Foo)]
pub fn do_foo(&self) -> Foo {
todo!()
}

#[cgp_fn]
#[use_type(HasBarType::Bar)]
pub fn do_bar(&self) -> Bar {
todo!()
}

#[cgp_fn]
#[use_type(HasBarType::{Bar as Baz = Foo}, HasFooType::Foo)]
#[uses(DoFoo, DoBar)]
fn return_foo_or_bar(&self, flag: bool, #[implicit] foo: &Foo, #[implicit] bar: &Baz) -> Foo {
if flag {
let res: Foo = self.do_foo();
if &res < foo { res } else { foo.clone() }
} else {
let res: Baz = self.do_bar();
if &res < bar { res } else { bar.clone() }
}
}
Loading