diff options
-rw-r--r-- | .gitignore | 2 | ||||
-rw-r--r-- | Cargo.toml | 28 | ||||
-rw-r--r-- | README.md | 119 | ||||
-rw-r--r-- | src/lib.rs | 425 | ||||
-rw-r--r-- | src/parse_assoc_type.rs | 119 | ||||
-rw-r--r-- | src/parse_trait_sig.rs | 505 | ||||
-rw-r--r-- | src/syn_utils.rs | 97 | ||||
-rw-r--r-- | tests/gats.rs | 17 | ||||
-rw-r--r-- | tests/tests.rs | 152 |
9 files changed, 1464 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..96ef6c0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..3b31a88 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "dynamize" +version = "0.1.0" +description = "proc macro to make traits with associated types object-safe" +authors = ["Martin Fischer <martin@push-f.com>"] +license = "MIT" +repository = "https://git.push-f.com/dynamize" +keywords = ["proc-macro", "attribute", "dyn", "object-safe", "blanket"] +categories = ["rust-patterns"] +edition = "2021" + +[lib] +proc-macro = true +doctest = false + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[features] +# internal features +nightly = [] + +[dependencies] +proc-macro2 = "1.0" +quote = "1.0" +syn = { version = "1.0", features = ["full", "extra-traits"] } + +[dev-dependencies] +async-trait = "0.1" diff --git a/README.md b/README.md new file mode 100644 index 0000000..332de4f --- /dev/null +++ b/README.md @@ -0,0 +1,119 @@ +# Dynamize + +In order for a trait to be usable as a trait object it needs to fulfill +[several requirements](https://doc.rust-lang.org/reference/items/traits.html#object-safety). +For example: + +```rust +trait Client { + type Error; + + fn get(&self, url: String) -> Result<Vec<u8>, Self::Error>; +} + +impl Client for HttpClient { type Error = HttpError; ...} +impl Client for FtpClient { type Error = FtpError; ...} + +let client: HttpClient = ...; +let object = &client as &dyn Client; +``` + +The last line of the above code fails to compile with: + +> error[E0191]: the value of the associated type `Error` (from trait `Client`) +> must be specified + +Sometimes you however want a trait object to be able to encompass trait +implementations with different associated type values. This crate provides an +attribute macro to achieve that. To use dynamize you only have to make some +small changes: + +```rust +#[dynamize::dynamize] +trait Client { + type Error: Into<SuperError>; + + fn get(&self, url: String) -> Result<Vec<u8>, Self::Error>; +} + +let client: HttpClient = ...; +let object = &client as &dyn DynClient; +``` + +1. You add the `#[dynamize::dynamize]` attribute to your trait. +2. You specify an `Into<T>` bound for each associated type. + +Dynamize defines a new trait for you, named after your trait but with the `Dyn` +prefix, so e.g. `Client` becomes `DynClient` in our example. The new +"dynamized" trait can then be used without having to specify the associated +type. + +## How does this work? + +For the above example dynamize generates the following code: + +```rust +trait DynClient { + fn get(&self, url: String) -> Result<Vec<u8>, SuperError>; +} + +impl<__to_be_dynamized: Client> DynClient for __to_be_dynamized { + fn get(&self, url: String) -> Result<Vec<u8>, SuperError> { + Client::get(self, url).map_err(|x| x.into()) + } +} +``` + +As you can see in the dynamized trait the associated type was replaced with the +destination type of the `Into` bound. The magic however happens afterwards: +dynamize generates a blanket implementation: each type implementing `Client` +automatically also implements `DynClient`! + +## How does this actually work? + +Dynamize recognizes the `Result<T, E>` in the return type and knows that +associated types in `T` need to be mapped with `map()` whereas associated types +in `E` need to be mapped with `map_err()`. Dynamize also understands +`Option<T>`. Thanks to recursion Dynamize can deal with arbitrarily nested +options and results, so e.g. `Result<Option<Self::Item>, Self::Error>` also +just works. + +## Dynamize supports async + +Dynamize supports async out of the box. Since Rust however does not yet support +async functions in traits, you'll have to additionally use another library like +[async-trait](https://crates.io/crates/async-trait), for example: + +```rust +#[dynamize::dynamize] +#[dyn_trait_attr(async_trait)] +#[blanket_impl_attr(async_trait)] +#[async_trait] +trait Client: Sync { + type Error: Into<SuperError>; + + async fn get(&self, url: String) -> Result<Vec<u8>, Self::Error>; +} +``` + +The `#[dyn_trait_attr(...)]` attribute lets you attach macro attributes to the +generated dynamized trait. The `#[blanket_impl_attr(...)]` attribute lets you +attach macro attributes to the generated blanket implementation. Note that it +is important that the dynamize attribute comes before the `async_trait` +attribute. + +## Dynamize supports Fn, FnOnce & FnMut + +The following also just works: + +```rust +#[dynamize::dynamize] +trait TraitWithCallback { + type A: Into<String>; + + fn fun_with_callback<F: Fn(Self::A)>(&self, a: F); +} +``` + +Note that since in order to be object-safe methods must not have generics, +dynamize simply moves the generic from the method to the trait definition. diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..1f0432c --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,425 @@ +#![doc = include_str!("../README.md")] +use std::collections::HashMap; + +use proc_macro::TokenStream; +use proc_macro2::Group; +use proc_macro2::Ident; +use quote::format_ident; +use quote::quote; +use quote::quote_spanned; +use quote::ToTokens; +use syn::parse_macro_input; +use syn::punctuated::Punctuated; +use syn::token::Brace; +use syn::token::Gt; +use syn::token::Lt; +use syn::token::Trait; +use syn::AngleBracketedGenericArguments; +use syn::Block; +use syn::Expr; +use syn::GenericArgument; +use syn::GenericParam; +use syn::ImplItemMethod; +use syn::ItemTrait; +use syn::Path; +use syn::PathArguments; +use syn::PathSegment; +use syn::Signature; +use syn::Stmt; +use syn::TraitBound; +use syn::TraitItem; +use syn::TraitItemMethod; +use syn::Type; +use syn::TypeParam; +use syn::TypeParamBound; +use syn::TypePath; +use syn::Visibility; +use syn_utils::TypeMatcher; + +use crate::parse_assoc_type::parse_assoc_type; +use crate::parse_assoc_type::AssocTypeParseError; +use crate::parse_trait_sig::parse_trait_signature; +use crate::parse_trait_sig::MethodParseError; +use crate::parse_trait_sig::SignatureChanges; +use crate::parse_trait_sig::TypeTransform; +mod parse_assoc_type; +mod parse_trait_sig; +mod syn_utils; + +macro_rules! abort { + ($span:expr, $message:literal) => { + quote_spanned! {$span => compile_error!($message);}.into() + }; +} + +#[proc_macro_attribute] +pub fn dynamize(_attr: TokenStream, input: TokenStream) -> TokenStream { + let mut original_trait = parse_macro_input!(input as ItemTrait); + assert!(original_trait.auto_token.is_none()); + + let original_trait_name = original_trait.ident.clone(); + + let mut objectifiable_methods: Vec<(Signature, SignatureChanges)> = Vec::new(); + + let mut assoc_type_conversions: HashMap<&Ident, &Type> = HashMap::new(); + + for item in &original_trait.items { + if let TraitItem::Type(assoc_type) = item { + match parse_assoc_type(assoc_type) { + Err((_, AssocTypeParseError::NoIntoBound)) => continue, + Err((span, AssocTypeParseError::AssocTypeInBound)) => { + return abort!(span, "dynamize does not support associated types here") + } + Err((span, AssocTypeParseError::GenericAssociatedType)) => { + return abort!( + span, + "dynamize does not (yet?) support generic associated types" + ) + } + Ok((ident, type_)) => { + assoc_type_conversions.insert(ident, type_); + } + } + } + } + + for item in &original_trait.items { + if let TraitItem::Method(method) = item { + let mut signature = method.sig.clone(); + match parse_trait_signature(&mut signature, &assoc_type_conversions) { + Err((_, MethodParseError::NonDispatchableMethod)) => continue, + Err((span, MethodParseError::UnconvertibleAssocType)) => { + return abort!( + span, + "associated type is either undefined or doesn't have an Into bound" + ) + } + Err((span, MethodParseError::AssocTypeInInputs)) => { + return abort!( + span, + "dynamize does not support associated types in parameter types" + ) + } + Err(( + span, + MethodParseError::AssocTypeInUnsupportedReturnType + | MethodParseError::UnconvertibleAssocTypeInFnInput, + )) => return abort!(span, "dynamize does not know how to convert this type"), + Err((span, MethodParseError::UnconvertibleAssocTypeInTraitBound)) => { + return abort!(span, "dynamize does not support associated types here") + } + Err((span, MethodParseError::ImplTraitInInputs)) => { + return abort!( + span, + "dynamize does not support impl here, change it to a method generic" + ) + } + Ok(parsed_method) => objectifiable_methods.push((signature, parsed_method)), + }; + } + } + + let mut method_impls: Vec<ImplItemMethod> = Vec::new(); + + let mut blanket_impl_attrs = Vec::new(); + let mut dyn_trait_attrs = Vec::new(); + + // FUTURE: use Vec::drain_filter once it's stable + let mut i = 0; + while i < original_trait.attrs.len() { + if original_trait.attrs[i].path.is_ident("blanket_impl_attr") { + let attr = original_trait.attrs.remove(i); + let group: Group = match syn::parse2(attr.tokens) { + Ok(g) => g, + Err(err) => { + return abort!( + err.span(), + "expected parenthesis: #[blanket_impl_attr(...)]" + ) + } + }; + let tokens = group.stream(); + blanket_impl_attrs.push(quote! {#[#tokens]}); + } else if original_trait.attrs[i].path.is_ident("dyn_trait_attr") { + let attr = original_trait.attrs.remove(i); + let group: Group = match syn::parse2(attr.tokens) { + Ok(g) => g, + Err(err) => { + return abort!(err.span(), "expected parenthesis: #[dyn_trait_attr(...)]") + } + }; + let tokens = group.stream(); + dyn_trait_attrs.push(quote! {#[#tokens]}); + } else { + i += 1; + } + } + + let mut dyn_trait = ItemTrait { + ident: format_ident!("Dyn{}", original_trait.ident), + + attrs: Vec::new(), + vis: original_trait.vis.clone(), + unsafety: original_trait.unsafety, + auto_token: None, + trait_token: Trait::default(), + generics: original_trait.generics.clone(), + colon_token: None, + supertraits: Punctuated::new(), + brace_token: Brace::default(), + items: Vec::new(), + }; + + for (signature, parsed_method) in objectifiable_methods { + let mut new_method = TraitItemMethod { + attrs: Vec::new(), + sig: signature, + default: None, + semi_token: None, + }; + + let fun_name = &new_method.sig.ident; + + let args = new_method + .sig + .inputs + .iter() + .enumerate() + .map(|(idx, arg)| match arg { + syn::FnArg::Receiver(_) => quote! {self}, + syn::FnArg::Typed(pat_type) => match pat_type.pat.as_ref() { + syn::Pat::Ident(ident) => match &parsed_method.inputs[idx] { + None => ident.ident.to_token_stream(), + Some(transforms) => { + let args = (0..transforms.len()).map(|i| format_ident!("a{}", i)); + let mut calls: Vec<_> = + args.clone().map(|i| i.into_token_stream()).collect(); + for i in 0..calls.len() { + transforms[i].append_conversion(&mut calls[i]); + } + let move_opt = new_method.sig.asyncness.map(|_| quote! {move}); + quote!(#move_opt |#(#args),*| #ident(#(#calls),*)) + } + }, + _other => { + panic!("unexpected"); + } + }, + }); + + // in order for a trait to be object-safe its methods may not have + // generics so we convert method generics into trait generics + if new_method + .sig + .generics + .params + .iter() + .any(|p| matches!(p, GenericParam::Type(_))) + { + // syn::punctuated::Punctuated doesn't have a remove(index) + // method so we firstly move all elements to a vector + let mut params = Vec::new(); + while let Some(generic_param) = new_method.sig.generics.params.pop() { + params.push(generic_param.into_value()); + } + + // FUTURE: use Vec::drain_filter once it's stable + let mut i = 0; + while i < params.len() { + if matches!(params[i], GenericParam::Type(_)) { + dyn_trait.generics.params.push(params.remove(i)); + } else { + i += 1; + } + } + + new_method.sig.generics.params.extend(params); + + if dyn_trait.generics.lt_token.is_none() { + dyn_trait.generics.lt_token = Some(Lt::default()); + dyn_trait.generics.gt_token = Some(Gt::default()); + } + } + + let mut expr = quote!(#original_trait_name::#fun_name(#(#args),*)); + if new_method.sig.asyncness.is_some() { + expr.extend(quote! {.await}) + } + parsed_method.return_type.append_conversion(&mut expr); + + method_impls.push(ImplItemMethod { + attrs: Vec::new(), + vis: Visibility::Inherited, + defaultness: None, + sig: new_method.sig.clone(), + block: Block { + brace_token: Brace::default(), + stmts: vec![Stmt::Expr(Expr::Verbatim(expr))], + }, + }); + dyn_trait.items.push(new_method.into()); + } + + let blanket_impl = generate_blanket_impl(&dyn_trait, &original_trait, &method_impls); + + let expanded = quote! { + #original_trait + + #(#dyn_trait_attrs)* + #dyn_trait + + #(#blanket_impl_attrs)* + #blanket_impl + }; + TokenStream::from(expanded) +} + +fn generate_blanket_impl( + dyn_trait: &ItemTrait, + original_trait: &ItemTrait, + method_impls: &[ImplItemMethod], +) -> proc_macro2::TokenStream { + let mut blanket_generics = dyn_trait.generics.clone(); + let some_ident = format_ident!("__to_be_dynamized"); + blanket_generics.params.push(GenericParam::Type(TypeParam { + attrs: Vec::new(), + ident: some_ident.clone(), + colon_token: None, + bounds: std::iter::once(TypeParamBound::Trait(TraitBound { + paren_token: None, + modifier: syn::TraitBoundModifier::None, + lifetimes: None, + path: Path { + leading_colon: None, + segments: std::iter::once(path_segment_for_trait(original_trait)).collect(), + }, + })) + .collect(), + eq_token: None, + default: None, + })); + let (_, type_gen, _where) = dyn_trait.generics.split_for_impl(); + let dyn_trait_name = &dyn_trait.ident; + quote! { + impl #blanket_generics #dyn_trait_name #type_gen for #some_ident { + #(#method_impls)* + } + } +} + +struct AssocTypeMatcher; +impl TypeMatcher<Path> for AssocTypeMatcher { + fn match_path<'a>(&self, path: &'a Path) -> Option<&'a Path> { + if path.segments.first().unwrap().ident == "Self" { + return Some(path); + } + None + } +} + +impl<TypeMatch, PathMatch, T> TypeMatcher<T> for (TypeMatch, PathMatch) +where + TypeMatch: Fn(&Type) -> Option<&T>, + PathMatch: Fn(&Path) -> Option<&T>, +{ + fn match_type<'a>(&self, t: &'a Type) -> Option<&'a T> { + self.0(t) + } + + fn match_path<'a>(&self, path: &'a Path) -> Option<&'a T> { + self.1(path) + } +} + +impl TypeTransform { + fn append_conversion(&self, stream: &mut proc_macro2::TokenStream) { + match self { + TypeTransform::Into => stream.extend(quote! {.into()}), + TypeTransform::Map(opt) => { + let mut inner = quote!(x); + opt.append_conversion(&mut inner); + stream.extend(quote! {.map(|x| #inner)}) + } + TypeTransform::Result(ok, err) => { + if !matches!(ok.as_ref(), TypeTransform::NoOp) { + let mut inner = quote!(x); + ok.append_conversion(&mut inner); + stream.extend(quote! {.map(|x| #inner)}) + } + if !matches!(err.as_ref(), TypeTransform::NoOp) { + let mut inner = quote!(x); + err.append_conversion(&mut inner); + stream.extend(quote! {.map_err(|x| #inner)}) + } + } + _other => {} + } + } +} + +/// Just a convenience trait for us to avoid match/if-let blocks everywhere. +trait As<T> { + fn get_as(&self) -> Option<&T>; + fn get_as_mut(&mut self) -> Option<&mut T>; +} + +impl As<AngleBracketedGenericArguments> for PathArguments { + fn get_as(&self) -> Option<&AngleBracketedGenericArguments> { + match self { + PathArguments::AngleBracketed(args) => Some(args), + _other => None, + } + } + fn get_as_mut(&mut self) -> Option<&mut AngleBracketedGenericArguments> { + match self { + PathArguments::AngleBracketed(args) => Some(args), + _other => None, + } + } +} + +impl As<Type> for GenericArgument { + fn get_as(&self) -> Option<&Type> { + match self { + GenericArgument::Type(typearg) => Some(typearg), + _other => None, + } + } + fn get_as_mut(&mut self) -> Option<&mut Type> { + match self { + GenericArgument::Type(typearg) => Some(typearg), + _other => None, + } + } +} + +fn path_segment_for_trait(sometrait: &ItemTrait) -> PathSegment { + PathSegment { + ident: sometrait.ident.clone(), + arguments: match sometrait.generics.params.is_empty() { + true => PathArguments::None, + false => PathArguments::AngleBracketed(AngleBracketedGenericArguments { + colon2_token: None, + lt_token: Lt::default(), + args: sometrait + .generics + .params + .iter() + .map(|param| match param { + GenericParam::Type(type_param) => { + GenericArgument::Type(Type::Path(TypePath { + path: type_param.ident.clone().into(), + qself: None, + })) + } + GenericParam::Lifetime(lifetime_def) => { + GenericArgument::Lifetime(lifetime_def.lifetime.clone()) + } + GenericParam::Const(_) => todo!("const generic param not supported"), + }) + .collect(), + gt_token: Gt::default(), + }), + }, + } +} diff --git a/src/parse_assoc_type.rs b/src/parse_assoc_type.rs new file mode 100644 index 0000000..37fd78c --- /dev/null +++ b/src/parse_assoc_type.rs @@ -0,0 +1,119 @@ +use proc_macro2::Span; +use syn::spanned::Spanned; +use syn::{GenericArgument, Ident, PathArguments, PathSegment, TraitItemType, Type}; + +use crate::syn_utils::{find_in_type, trait_bounds}; +use crate::AssocTypeMatcher; + +#[derive(Debug)] +pub enum AssocTypeParseError { + AssocTypeInBound, + GenericAssociatedType, + NoIntoBound, +} + +pub fn parse_assoc_type( + assoc_type: &TraitItemType, +) -> Result<(&Ident, &Type), (Span, AssocTypeParseError)> { + for bound in trait_bounds(&assoc_type.bounds) { + if let PathSegment { + ident, + arguments: PathArguments::AngleBracketed(args), + } = bound.path.segments.first().unwrap() + { + if ident == "Into" && args.args.len() == 1 { + if let GenericArgument::Type(into_type) = args.args.first().unwrap() { + // provide a better error message for type A: Into<Self::B> + if find_in_type(into_type, &AssocTypeMatcher).is_some() { + return Err((into_type.span(), AssocTypeParseError::AssocTypeInBound)); + } + + // TODO: support lifetime GATs (see the currently failing tests/gats.rs) + if !assoc_type.generics.params.is_empty() { + return Err(( + assoc_type.generics.params.span(), + AssocTypeParseError::GenericAssociatedType, + )); + } + + return Ok((&assoc_type.ident, into_type)); + } + } + } + } + Err((assoc_type.span(), AssocTypeParseError::NoIntoBound)) +} + +#[cfg(test)] +mod tests { + use quote::quote; + use syn::{TraitItemType, Type}; + + use crate::parse_assoc_type::{parse_assoc_type, AssocTypeParseError}; + + #[test] + fn ok() { + let type1: TraitItemType = syn::parse2(quote! { + type A: Into<String>; + }) + .unwrap(); + + assert!(matches!( + parse_assoc_type(&type1), + Ok((id, Type::Path(path))) + if id == "A" && path.path.is_ident("String") + )); + } + + #[test] + fn err_no_bound() { + let type1: TraitItemType = syn::parse2(quote! { + type A; + }) + .unwrap(); + + assert!(matches!( + parse_assoc_type(&type1), + Err((_, AssocTypeParseError::NoIntoBound)) + )); + } + + #[test] + fn err_assoc_type_in_bound() { + let type1: TraitItemType = syn::parse2(quote! { + type A: Into<Self::B>; + }) + .unwrap(); + + assert!(matches!( + parse_assoc_type(&type1), + Err((_, AssocTypeParseError::AssocTypeInBound)) + )); + } + + #[test] + fn err_gat_type() { + let type1: TraitItemType = syn::parse2(quote! { + type A<X>: Into<Foobar<X>>; + }) + .unwrap(); + + assert!(matches!( + parse_assoc_type(&type1), + Err((_, AssocTypeParseError::GenericAssociatedType)) + )); + } + + #[test] + fn err_gat_lifetime() { + let type1: TraitItemType = syn::parse2(quote! { + type A<'a>: Into<Foobar<'a>>; + }) + .unwrap(); + + assert!(matches!( + parse_assoc_type(&type1), + Err((_, AssocTypeParseError::GenericAssociatedType)) + )); + } +} diff --git a/src/parse_trait_sig.rs b/src/parse_trait_sig.rs new file mode 100644 index 0000000..55a3214 --- /dev/null +++ b/src/parse_trait_sig.rs @@ -0,0 +1,505 @@ +use std::collections::HashMap; + +use proc_macro2::Span; +use syn::{ + spanned::Spanned, FnArg, Ident, PathArguments, PredicateType, Receiver, ReturnType, Type, + TypePath, WherePredicate, +}; +use syn::{GenericParam, Signature, TypeImplTrait, TypeParamBound}; + +use crate::syn_utils::{find_in_path, find_in_type, trait_bounds, TypeMatcher}; +use crate::{As, AssocTypeMatcher}; + +#[derive(Debug, Clone)] +pub enum TypeTransform { + NoOp, + Into, + Map(Box<TypeTransform>), + Result(Box<TypeTransform>, Box<TypeTransform>), +} + +#[derive(Debug)] +pub enum MethodParseError { + NonDispatchableMethod, + AssocTypeInInputs, + ImplTraitInInputs, + AssocTypeInUnsupportedReturnType, + UnconvertibleAssocTypeInFnInput, + UnconvertibleAssocTypeInTraitBound, + UnconvertibleAssocType, +} + +struct ImplTraitMatcher; +impl TypeMatcher<TypeImplTrait> for ImplTraitMatcher { + fn match_type<'a>(&self, t: &'a Type) -> Option<&'a TypeImplTrait> { + if let Type::ImplTrait(impltrait) = t { + Some(impltrait) + } else { + None + } + } +} + +pub struct SignatureChanges { + pub return_type: TypeTransform, + pub inputs: Vec<Option<Vec<TypeTransform>>>, +} + +pub fn parse_trait_signature( + signature: &mut Signature, + assoc_type_conversions: &HashMap<&Ident, &Type>, +) -> Result<SignatureChanges, (Span, MethodParseError)> { + let assoc_type_conversions = AssocTypeConversions(assoc_type_conversions); + + if is_non_dispatchable(signature) { + return Err((signature.span(), MethodParseError::NonDispatchableMethod)); + } + + // provide better error messages for associated types in params + for input in &signature.inputs { + if let FnArg::Typed(pattype) = input { + if find_in_type(&pattype.ty, &AssocTypeMatcher).is_some() { + return Err((pattype.ty.span(), MethodParseError::AssocTypeInInputs)); + } + if let Some(impl_trait) = find_in_type(&pattype.ty, &ImplTraitMatcher) { + return Err((impl_trait.span(), MethodParseError::ImplTraitInInputs)); + } + } + } + + let mut type_param_transforms = HashMap::new(); + let mut input_transforms = Vec::new(); + + for generic_param in &mut signature.generics.params { + if let GenericParam::Type(type_param) = generic_param { + for bound in &mut type_param.bounds { + if let TypeParamBound::Trait(bound) = bound { + if bound.path.segments.len() == 1 { + let segment = bound.path.segments.first_mut().unwrap(); + + if let PathArguments::Parenthesized(args) = &mut segment.arguments { + if segment.ident == "Fn" + || segment.ident == "FnOnce" + || segment.ident == "FnMut" + { + let mut transforms = Vec::new(); + for input_type in &mut args.inputs { + match assoc_type_conversions.parse_type_path(input_type) { + Ok(ret_type) => { + transforms.push(ret_type); + } + Err(TransformError::UnconvertibleAssocType(span)) => { + return Err(( + span, + MethodParseError::UnconvertibleAssocType, + )); + } + Err(TransformError::AssocTypeInUnsupportedType(span)) => { + return Err(( + span, + MethodParseError::UnconvertibleAssocTypeInFnInput, + )); + } + } + } + if transforms.iter().any(|t| !matches!(t, TypeTransform::NoOp)) { + type_param_transforms.insert(&type_param.ident, transforms); + } + } + } + } + if let Some(path) = find_in_path(&bound.path, &AssocTypeMatcher) { + return Err(( + path.span(), + MethodParseError::UnconvertibleAssocTypeInTraitBound, + )); + } + } + } + } + } + + for input in &signature.inputs { + if let FnArg::Typed(pattype) = input { + if let Type::Path(path) = &*pattype.ty { + if let Some(ident) = path.path.get_ident() { + input_transforms.push(type_param_transforms.get(ident).map(|x| (*x).clone())); + continue; + } + } + } + input_transforms.push(None); + } + + let return_type = match &mut signature.output { + ReturnType::Type(_, og_type) => match assoc_type_conversions.parse_type_path(og_type) { + Ok(ret_type) => ret_type, + Err(TransformError::UnconvertibleAssocType(span)) => { + return Err((span, MethodParseError::UnconvertibleAssocType)); + } + Err(TransformError::AssocTypeInUnsupportedType(span)) => { + return Err((span, MethodParseError::AssocTypeInUnsupportedReturnType)); + } + }, + ReturnType::Default => TypeTransform::NoOp, + }; + Ok(SignatureChanges { + return_type, + inputs: input_transforms, + }) +} + +struct AssocTypeConversions<'a>(&'a HashMap<&'a Ident, &'a Type>); + +enum TransformError { + UnconvertibleAssocType(Span), + AssocTypeInUnsupportedType(Span), +} + +impl AssocTypeConversions<'_> { + fn parse_type_path(&self, type_: &mut Type) -> Result<TypeTransform, TransformError> { + let assoc_span = match find_in_type(type_, &AssocTypeMatcher) { + Some(path) => path.span(), + None => return Ok(TypeTransform::NoOp), + }; + + if let Type::Path(TypePath { path, qself: None }) = type_ { + let ident = &path.segments.first().unwrap().ident; + + // TODO: support &mut dyn Iterator<Item = Self::A> + // conversion to Box<dyn Iterator<Item = Whatever>> via .map(Into::into) + + if ident == "Self" && path.segments.len() == 2 { + let ident = &path.segments.last().unwrap().ident; + *type_ = (*self + .0 + .get(&ident) + .ok_or_else(|| TransformError::UnconvertibleAssocType(ident.span()))?) + .clone(); + return Ok(TypeTransform::Into); + } else if ident == "Option" && path.segments.len() == 1 { + let first_seg = path.segments.first_mut().unwrap(); + + if let Some(args) = first_seg.arguments.get_as_mut() { + if args.args.len() == 1 { + if let Some(generic_type) = args.args.first_mut().unwrap().get_as_mut() { + if find_in_type(generic_type, &AssocTypeMatcher).is_some() { + return Ok(TypeTransform::Map( + self.parse_type_path(generic_type)?.into(), + )); + } + } + } + } + } else if ident == "Result" && path.segments.len() == 1 { + let first_seg = path.segments.first_mut().unwrap(); + if let Some(args) = first_seg.arguments.get_as_mut() { + if args.args.len() == 2 { + let mut args_iter = args.args.iter_mut(); + if let (Some(ok_type), Some(err_type)) = ( + args_iter.next().unwrap().get_as_mut(), + args_iter.next().unwrap().get_as_mut(), + ) { + if find_in_type(ok_type, &AssocTypeMatcher).is_some() + || find_in_type(err_type, &AssocTypeMatcher).is_some() + { + return Ok(TypeTransform::Result( + self.parse_type_path(ok_type)?.into(), + self.parse_type_path(err_type)?.into(), + )); + } + } + } + } + } else { + let last_seg = &path.segments.last().unwrap(); + if last_seg.ident == "Result" { + let last_seg = path.segments.last_mut().unwrap(); + if let Some(args) = last_seg.arguments.get_as_mut() { + if args.args.len() == 1 { + if let Some(generic_type) = args.args.first_mut().unwrap().get_as_mut() + { + if find_in_type(generic_type, &AssocTypeMatcher).is_some() { + return Ok(TypeTransform::Map( + self.parse_type_path(generic_type)?.into(), + )); + } + } + } + } + } + } + } + + // the type contains an associated type but we + // don't know how to deal with it so we abort + Err(TransformError::AssocTypeInUnsupportedType(assoc_span)) + } +} + +fn is_non_dispatchable(signature: &Signature) -> bool { + // non-dispatchable: fn example(&self) where Self: Sized; + if let Some(where_clause) = &signature.generics.where_clause { + if where_clause + .predicates + .iter() + .any(bounds_self_and_has_bound_sized) + { + return true; + } + } + + // non-dispatchable: fn example(); + if signature.inputs.is_empty() { + return true; + } + + // non-dispatchable: fn example(arg: Type); + if matches!(signature.inputs.first(), Some(FnArg::Typed(_))) { + return true; + } + + // non-dispatchable: fn example(self); + if matches!( + signature.inputs.first(), + Some(FnArg::Receiver(Receiver { + reference: None, + .. + })) + ) { + return true; + } + false +} + +/// Returns true if the bounded type is `Self` and the bounds contain `Sized`. +fn bounds_self_and_has_bound_sized(predicate: &WherePredicate) -> bool { + matches!( + predicate, + WherePredicate::Type(PredicateType { + bounded_ty: Type::Path(TypePath { path, .. }), + bounds, + .. + }) + if path.is_ident("Self") + && trait_bounds(bounds).any(|b| b.path.is_ident("Sized")) + ) +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use quote::{format_ident, quote}; + use syn::{TraitItemMethod, Type}; + + use crate::parse_trait_sig::{ + parse_trait_signature, MethodParseError, SignatureChanges, TypeTransform, + }; + + #[test] + fn ok_void() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self); + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Ok(SignatureChanges { + return_type: TypeTransform::NoOp, + .. + }) + )); + } + + #[test] + fn ok_assoc_type() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self) -> Self::A; + }) + .unwrap(); + + let mut assoc_type_map = HashMap::new(); + let ident = format_ident!("A"); + let dest = Type::Verbatim(quote! {Example}); + assoc_type_map.insert(&ident, &dest); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &assoc_type_map), + Ok(SignatureChanges { + return_type: TypeTransform::Into, + .. + }) + )); + } + + #[test] + fn err_unconvertible_assoc_type() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self) -> Self::A; + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::UnconvertibleAssocType)) + )); + } + + #[test] + fn err_non_dispatchable_assoc_function_no_args() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(); + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::NonDispatchableMethod)) + )); + } + + #[test] + fn err_non_dispatchable_assoc_function_with_args() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(arg: Type); + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::NonDispatchableMethod)) + )); + } + + #[test] + fn err_non_dispatchable_consume_self() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(self); + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::NonDispatchableMethod)) + )); + } + + #[test] + fn err_non_dispatchable_where_self_sized() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self) where Self: Sized; + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::NonDispatchableMethod)) + )); + } + + #[test] + fn err_assoc_type_in_unsupported_return() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self) -> Foo<Self::A>; + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::AssocTypeInUnsupportedReturnType)) + )); + } + + #[test] + fn err_assoc_type_in_unsupported_return_in_opt() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self) -> Option<Foo<Self::A>>; + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::AssocTypeInUnsupportedReturnType)) + )); + } + + #[test] + fn err_assoc_type_in_unsupported_return_in_ok() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self) -> Result<Foo<Self::A>, Error>; + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::AssocTypeInUnsupportedReturnType)) + )); + } + + #[test] + fn err_assoc_type_in_unsupported_return_in_err() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self) -> Result<Ok, Foo<Self::A>>; + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::AssocTypeInUnsupportedReturnType)) + )); + } + + #[test] + fn err_assoc_type_in_input() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self, x: Self::A); + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::AssocTypeInInputs)) + )); + } + + #[test] + fn err_assoc_type_in_input_opt() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self, x: Option<Self::A>); + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::AssocTypeInInputs)) + )); + } + + #[test] + fn err_impl_in_input() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self, arg: Option<impl SomeTrait>); + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::ImplTraitInInputs)) + )); + } + + #[test] + fn err_assoc_type_in_generic() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test<F: Fn(Foo<Self::A>)>(&self, fun: F); + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::UnconvertibleAssocTypeInFnInput)) + )); + } +} diff --git a/src/syn_utils.rs b/src/syn_utils.rs new file mode 100644 index 0000000..4588186 --- /dev/null +++ b/src/syn_utils.rs @@ -0,0 +1,97 @@ +use syn::{ + punctuated::Punctuated, GenericArgument, Path, PathArguments, ReturnType, TraitBound, Type, + TypeParamBound, +}; + +pub trait TypeMatcher<T> { + fn match_type<'a>(&self, t: &'a Type) -> Option<&'a T> { + None + } + fn match_path<'a>(&self, path: &'a Path) -> Option<&'a T> { + None + } +} + +pub fn trait_bounds<T>( + bounds: &Punctuated<TypeParamBound, T>, +) -> impl Iterator<Item = &TraitBound> { + bounds.iter().filter_map(|b| match b { + TypeParamBound::Trait(t) => Some(t), + TypeParamBound::Lifetime(_) => None, + }) +} + +pub fn find_in_type<'a, T>(t: &'a Type, matcher: &dyn TypeMatcher<T>) -> Option<&'a T> { + if let Some(ret) = matcher.match_type(t) { + return Some(ret); + } + match t { + Type::Array(array) => find_in_type(&array.elem, matcher), + Type::BareFn(fun) => { + for input in &fun.inputs { + if let Some(ret) = find_in_type(&input.ty, matcher) { + return Some(ret); + } + } + if let ReturnType::Type(_, t) = &fun.output { + return find_in_type(t, matcher); + } + None + } + Type::Group(group) => find_in_type(&group.elem, matcher), + Type::ImplTrait(impltrait) => { + for bound in &impltrait.bounds { + if let TypeParamBound::Trait(bound) = bound { + if let Some(ret) = find_in_path(&bound.path, matcher) { + return Some(ret); + } + } + } + None + } + Type::Infer(_) => None, + Type::Macro(_) => None, + Type::Paren(paren) => find_in_type(&paren.elem, matcher), + Type::Path(path) => find_in_path(&path.path, matcher), + Type::Ptr(ptr) => find_in_type(&ptr.elem, matcher), + Type::Reference(reference) => find_in_type(&reference.elem, matcher), + Type::Slice(slice) => find_in_type(&slice.elem, matcher), + Type::TraitObject(traitobj) => { + for bound in &traitobj.bounds { + if let TypeParamBound::Trait(bound) = bound { + if let Some(ret) = find_in_path(&bound.path, matcher) { + return Some(ret); + } + } + } + None + } + Type::Tuple(tuple) => { + for elem in &tuple.elems { + if let Some(ret) = find_in_type(elem, matcher) { + return Some(ret); + } + } + None + } + _other => None, + } +} + +pub fn find_in_path<'a, T>(path: &'a Path, matcher: &dyn TypeMatcher<T>) -> Option<&'a T> { + if let Some(ret) = matcher.match_path(path) { + return Some(ret); + } + for segment in &path.segments { + if let PathArguments::AngleBracketed(args) = &segment.arguments { + for arg in &args.args { + if let GenericArgument::Type(t) = arg { + if let Some(ret) = find_in_type(t, matcher) { + return Some(ret); + } + } + } + } + } + None +} diff --git a/tests/gats.rs b/tests/gats.rs new file mode 100644 index 0000000..92e483c --- /dev/null +++ b/tests/gats.rs @@ -0,0 +1,17 @@ +//! This test can be run with `cargo +nightly test --features=nightly` +#![cfg_attr(feature = "nightly", feature(generic_associated_types))] + +#[cfg(feature = "nightly")] +mod test_gats { + #[dynamize::dynamize] + pub trait MyTrait { + type A<'a>: Into<&'a str>; + + fn test1<'b>(&self) -> Self::A<'b>; + } + + fn test<T: MyTrait>(mut some: T) { + let dyn_trait: &dyn DynMyTrait = &some; + let _: &str = dyn_trait.test1(); + } +} diff --git a/tests/tests.rs b/tests/tests.rs new file mode 100644 index 0000000..6bccc04 --- /dev/null +++ b/tests/tests.rs @@ -0,0 +1,152 @@ +#![allow(dead_code)] + +#[test] +fn it_works() { + use dynamize::dynamize; + + mod some { + pub mod module { + pub type Result<T> = std::io::Result<T>; + } + } + + #[dynamize] + /// This is a great trait! + pub trait MyTrait { + type A: Into<String>; + // if there are multiple Into bounds the first one is used + type B: Into<i32> + Into<u64>; + + fn test1(&self) -> Self::A; + fn test2(&self) -> Self::B; + fn test3(&self) -> Option<Self::B>; + fn test4(&self) -> Result<(), Self::A>; + fn test5(&self) -> Result<Self::A, ()>; + /// some method documentation + fn test6(&self) -> Result<Self::A, Self::B>; + + #[allow(clippy::type_complexity)] + fn test7(&self) -> Result<Option<Option<Self::A>>, Option<Option<Self::B>>>; + + // also support Result type aliases with a fixed error type + fn test8(&self) -> some::module::Result<Self::A>; + + // fn test9(&self) -> &dyn Iterator<Item = Self::A>; + + fn mut1(&mut self) -> Self::A; + + fn safe1(&self); + fn safe2(&self, num: i32) -> i32; + fn safe3<'a>(&self, text: &'a str) -> &'a str; + fn safe4(&self) -> Option<i32>; + + // non-dispatchable functions are skipped + fn non_dispatch1(); + fn non_dispatch2(num: i32); + fn non_dispatch3(self) -> Self::A; + fn non_dispatch4(&self) + where + Self: Sized; + } + + fn test<T: MyTrait>(mut some: T) { + let dyn_trait: &dyn DynMyTrait = &some; + let _: String = dyn_trait.test1(); + let _: i32 = dyn_trait.test2(); + let _: Option<i32> = dyn_trait.test3(); + let _: Result<(), String> = dyn_trait.test4(); + let _: Result<String, ()> = dyn_trait.test5(); + let _: Result<String, i32> = dyn_trait.test6(); + let _: Result<Option<Option<String>>, Option<Option<i32>>> = dyn_trait.test7(); + + let dyn_trait: &mut dyn DynMyTrait = &mut some; + dyn_trait.mut1(); + + let _: () = dyn_trait.safe1(); + let _: i32 = dyn_trait.safe2(0); + let _: &str = dyn_trait.safe3("test"); + let _: Option<i32> = dyn_trait.safe4(); + } +} + +#[dynamize::dynamize] +trait Foo<X> { + type A: Into<String>; + + fn foobar(&self, x: X) -> Self::A; +} + +#[dynamize::dynamize] +trait Bar<X> { + fn foobar<A>(&self, x: X) -> A; +} + +fn test<T: Bar<X>, X, A>(some: T) { + let _dyn_trait: &dyn DynBar<X, A> = &some; +} + +#[dynamize::dynamize] +trait Bar1<X> { + fn foobar<A>(&self, x: X) -> A; + fn foobar1<B>(&self, x: X) -> B; + fn foobar2<C>(&self, x: X) -> C; +} + +fn test1<T: Bar1<X>, X, A, B, C>(some: T) { + let _dyn_trait: &dyn DynBar1<X, A, B, C> = &some; +} + +#[dynamize::dynamize] +trait Buz<X> { + type C: Into<String>; + + fn foobar<A>(&self, x: X) -> Result<A, Self::C>; +} + +fn test2<T: Buz<X>, X, A>(some: T, x: X) -> Result<A, String> { + let dyn_trait: &dyn DynBuz<X, A> = &some; + dyn_trait.foobar(x) +} + +#[dynamize::dynamize] +trait Gen { + fn foobar<A>(&self, a: A) -> A; +} + +use async_trait::async_trait; + +#[dynamize::dynamize] +#[dyn_trait_attr(async_trait)] +#[blanket_impl_attr(async_trait)] +#[async_trait] +trait SomeTraitWithAsync: Sync { + type A: Into<String>; + async fn test1(&self) -> Self::A; +} + +async fn async_test<T: SomeTraitWithAsync>(some: T) { + let dyn_trait: &dyn DynSomeTraitWithAsync = &some; + let _: String = dyn_trait.test1().await; +} + +#[dynamize::dynamize] +trait TraitWithCallback { + type A: Into<String>; + fn fun_with_callback<F: Fn(Self::A)>(&self, a: F); + + fn fun_with_callback1<X: Fn(Option<Self::A>)>(&self, a: X); + + fn fun_with_callback2<Y: Fn(i32, Option<Self::A>, String) -> bool>(&self, a: Y); + + fn fun_with_callback3<Z: Fn(i32)>(&self, a: Z); +} + +#[dynamize::dynamize] +#[dyn_trait_attr(async_trait)] +#[blanket_impl_attr(async_trait)] +#[async_trait] +trait AsyncWithCallback: Sync { + type A: Into<String>; + async fn test1(&self) -> Self::A; + async fn fun_with_callback<F: Fn(Self::A) + Sync + Send + 'static>(&self, a: F); +} |