diff options
| author | Martin Fischer <martin@push-f.com> | 2021-11-15 10:29:52 +0100 | 
|---|---|---|
| committer | Martin Fischer <martin@push-f.com> | 2021-11-18 23:36:01 +0100 | 
| commit | 2a8a0601afcb82d90d0766db5a954b70b10f856d (patch) | |
| tree | 0271062335d450e151598d4ad9aa327ffa0dfaea | |
publishv0.1.0
| -rw-r--r-- | .gitignore | 2 | ||||
| -rw-r--r-- | Cargo.toml | 28 | ||||
| -rw-r--r-- | README.md | 119 | ||||
| -rw-r--r-- | src/lib.rs | 425 | ||||
| -rw-r--r-- | src/parse_assoc_type.rs | 119 | ||||
| -rw-r--r-- | src/parse_trait_sig.rs | 505 | ||||
| -rw-r--r-- | src/syn_utils.rs | 97 | ||||
| -rw-r--r-- | tests/gats.rs | 17 | ||||
| -rw-r--r-- | tests/tests.rs | 152 | 
9 files changed, 1464 insertions, 0 deletions
| diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..96ef6c0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..3b31a88 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "dynamize" +version = "0.1.0" +description = "proc macro to make traits with associated types object-safe" +authors = ["Martin Fischer <martin@push-f.com>"] +license = "MIT" +repository = "https://git.push-f.com/dynamize" +keywords = ["proc-macro", "attribute", "dyn", "object-safe", "blanket"] +categories = ["rust-patterns"] +edition = "2021" + +[lib] +proc-macro = true +doctest = false + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[features] +# internal features +nightly = [] + +[dependencies] +proc-macro2 = "1.0" +quote = "1.0" +syn = { version = "1.0", features = ["full", "extra-traits"] } + +[dev-dependencies] +async-trait = "0.1" diff --git a/README.md b/README.md new file mode 100644 index 0000000..332de4f --- /dev/null +++ b/README.md @@ -0,0 +1,119 @@ +# Dynamize + +In order for a trait to be usable as a trait object it needs to fulfill +[several requirements](https://doc.rust-lang.org/reference/items/traits.html#object-safety). +For example: + +```rust +trait Client { +    type Error; + +    fn get(&self, url: String) -> Result<Vec<u8>, Self::Error>; +} + +impl Client for HttpClient { type Error = HttpError; ...} +impl Client for FtpClient { type Error = FtpError; ...} + +let client: HttpClient = ...; +let object = &client as &dyn Client; +``` + +The last line of the above code fails to compile with: + +> error[E0191]: the value of the associated type `Error` (from trait `Client`) +> must be specified + +Sometimes you however want a trait object to be able to encompass trait +implementations with different associated type values. This crate provides an +attribute macro to achieve that. To use dynamize you only have to make some +small changes: + +```rust +#[dynamize::dynamize] +trait Client { +    type Error: Into<SuperError>; + +    fn get(&self, url: String) -> Result<Vec<u8>, Self::Error>; +} + +let client: HttpClient = ...; +let object = &client as &dyn DynClient; +``` + +1. You add the `#[dynamize::dynamize]` attribute to your trait. +2. You specify an `Into<T>` bound for each associated type. + +Dynamize defines a new trait for you, named after your trait but with the `Dyn` +prefix, so e.g. `Client` becomes `DynClient` in our example. The new +"dynamized" trait can then be used without having to specify the associated +type. + +## How does this work? + +For the above example dynamize generates the following code: + +```rust +trait DynClient { +    fn get(&self, url: String) -> Result<Vec<u8>, SuperError>; +} + +impl<__to_be_dynamized: Client> DynClient for __to_be_dynamized { +    fn get(&self, url: String) -> Result<Vec<u8>, SuperError> { +        Client::get(self, url).map_err(|x| x.into()) +    } +} +``` + +As you can see in the dynamized trait the associated type was replaced with the +destination type of the `Into` bound. The magic however happens afterwards: +dynamize generates a blanket implementation: each type implementing `Client` +automatically also implements `DynClient`! + +## How does this actually work? + +Dynamize recognizes the `Result<T, E>` in the return type and knows that +associated types in `T` need to be mapped with `map()` whereas associated types +in `E` need to be mapped with `map_err()`. Dynamize also understands +`Option<T>`. Thanks to recursion Dynamize can deal with arbitrarily nested +options and results, so e.g. `Result<Option<Self::Item>, Self::Error>` also +just works. + +## Dynamize supports async + +Dynamize supports async out of the box. Since Rust however does not yet support +async functions in traits, you'll have to additionally use another library like +[async-trait](https://crates.io/crates/async-trait), for example: + +```rust +#[dynamize::dynamize] +#[dyn_trait_attr(async_trait)] +#[blanket_impl_attr(async_trait)] +#[async_trait] +trait Client: Sync { +    type Error: Into<SuperError>; + +    async fn get(&self, url: String) -> Result<Vec<u8>, Self::Error>; +} +``` + +The `#[dyn_trait_attr(...)]` attribute lets you attach macro attributes to the +generated dynamized trait.  The `#[blanket_impl_attr(...)]` attribute lets you +attach macro attributes to the generated blanket implementation. Note that it +is important that the dynamize attribute comes before the `async_trait` +attribute. + +## Dynamize supports Fn, FnOnce & FnMut + +The following also just works: + +```rust +#[dynamize::dynamize] +trait TraitWithCallback { +    type A: Into<String>; + +    fn fun_with_callback<F: Fn(Self::A)>(&self, a: F); +} +``` + +Note that since in order to be object-safe methods must not have generics, +dynamize simply moves the generic from the method to the trait definition. diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..1f0432c --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,425 @@ +#![doc = include_str!("../README.md")] +use std::collections::HashMap; + +use proc_macro::TokenStream; +use proc_macro2::Group; +use proc_macro2::Ident; +use quote::format_ident; +use quote::quote; +use quote::quote_spanned; +use quote::ToTokens; +use syn::parse_macro_input; +use syn::punctuated::Punctuated; +use syn::token::Brace; +use syn::token::Gt; +use syn::token::Lt; +use syn::token::Trait; +use syn::AngleBracketedGenericArguments; +use syn::Block; +use syn::Expr; +use syn::GenericArgument; +use syn::GenericParam; +use syn::ImplItemMethod; +use syn::ItemTrait; +use syn::Path; +use syn::PathArguments; +use syn::PathSegment; +use syn::Signature; +use syn::Stmt; +use syn::TraitBound; +use syn::TraitItem; +use syn::TraitItemMethod; +use syn::Type; +use syn::TypeParam; +use syn::TypeParamBound; +use syn::TypePath; +use syn::Visibility; +use syn_utils::TypeMatcher; + +use crate::parse_assoc_type::parse_assoc_type; +use crate::parse_assoc_type::AssocTypeParseError; +use crate::parse_trait_sig::parse_trait_signature; +use crate::parse_trait_sig::MethodParseError; +use crate::parse_trait_sig::SignatureChanges; +use crate::parse_trait_sig::TypeTransform; +mod parse_assoc_type; +mod parse_trait_sig; +mod syn_utils; + +macro_rules! abort { +    ($span:expr, $message:literal) => { +        quote_spanned! {$span => compile_error!($message);}.into() +    }; +} + +#[proc_macro_attribute] +pub fn dynamize(_attr: TokenStream, input: TokenStream) -> TokenStream { +    let mut original_trait = parse_macro_input!(input as ItemTrait); +    assert!(original_trait.auto_token.is_none()); + +    let original_trait_name = original_trait.ident.clone(); + +    let mut objectifiable_methods: Vec<(Signature, SignatureChanges)> = Vec::new(); + +    let mut assoc_type_conversions: HashMap<&Ident, &Type> = HashMap::new(); + +    for item in &original_trait.items { +        if let TraitItem::Type(assoc_type) = item { +            match parse_assoc_type(assoc_type) { +                Err((_, AssocTypeParseError::NoIntoBound)) => continue, +                Err((span, AssocTypeParseError::AssocTypeInBound)) => { +                    return abort!(span, "dynamize does not support associated types here") +                } +                Err((span, AssocTypeParseError::GenericAssociatedType)) => { +                    return abort!( +                        span, +                        "dynamize does not (yet?) support generic associated types" +                    ) +                } +                Ok((ident, type_)) => { +                    assoc_type_conversions.insert(ident, type_); +                } +            } +        } +    } + +    for item in &original_trait.items { +        if let TraitItem::Method(method) = item { +            let mut signature = method.sig.clone(); +            match parse_trait_signature(&mut signature, &assoc_type_conversions) { +                Err((_, MethodParseError::NonDispatchableMethod)) => continue, +                Err((span, MethodParseError::UnconvertibleAssocType)) => { +                    return abort!( +                        span, +                        "associated type is either undefined or doesn't have an Into bound" +                    ) +                } +                Err((span, MethodParseError::AssocTypeInInputs)) => { +                    return abort!( +                        span, +                        "dynamize does not support associated types in parameter types" +                    ) +                } +                Err(( +                    span, +                    MethodParseError::AssocTypeInUnsupportedReturnType +                    | MethodParseError::UnconvertibleAssocTypeInFnInput, +                )) => return abort!(span, "dynamize does not know how to convert this type"), +                Err((span, MethodParseError::UnconvertibleAssocTypeInTraitBound)) => { +                    return abort!(span, "dynamize does not support associated types here") +                } +                Err((span, MethodParseError::ImplTraitInInputs)) => { +                    return abort!( +                        span, +                        "dynamize does not support impl here, change it to a method generic" +                    ) +                } +                Ok(parsed_method) => objectifiable_methods.push((signature, parsed_method)), +            }; +        } +    } + +    let mut method_impls: Vec<ImplItemMethod> = Vec::new(); + +    let mut blanket_impl_attrs = Vec::new(); +    let mut dyn_trait_attrs = Vec::new(); + +    // FUTURE: use Vec::drain_filter once it's stable +    let mut i = 0; +    while i < original_trait.attrs.len() { +        if original_trait.attrs[i].path.is_ident("blanket_impl_attr") { +            let attr = original_trait.attrs.remove(i); +            let group: Group = match syn::parse2(attr.tokens) { +                Ok(g) => g, +                Err(err) => { +                    return abort!( +                        err.span(), +                        "expected parenthesis: #[blanket_impl_attr(...)]" +                    ) +                } +            }; +            let tokens = group.stream(); +            blanket_impl_attrs.push(quote! {#[#tokens]}); +        } else if original_trait.attrs[i].path.is_ident("dyn_trait_attr") { +            let attr = original_trait.attrs.remove(i); +            let group: Group = match syn::parse2(attr.tokens) { +                Ok(g) => g, +                Err(err) => { +                    return abort!(err.span(), "expected parenthesis: #[dyn_trait_attr(...)]") +                } +            }; +            let tokens = group.stream(); +            dyn_trait_attrs.push(quote! {#[#tokens]}); +        } else { +            i += 1; +        } +    } + +    let mut dyn_trait = ItemTrait { +        ident: format_ident!("Dyn{}", original_trait.ident), + +        attrs: Vec::new(), +        vis: original_trait.vis.clone(), +        unsafety: original_trait.unsafety, +        auto_token: None, +        trait_token: Trait::default(), +        generics: original_trait.generics.clone(), +        colon_token: None, +        supertraits: Punctuated::new(), +        brace_token: Brace::default(), +        items: Vec::new(), +    }; + +    for (signature, parsed_method) in objectifiable_methods { +        let mut new_method = TraitItemMethod { +            attrs: Vec::new(), +            sig: signature, +            default: None, +            semi_token: None, +        }; + +        let fun_name = &new_method.sig.ident; + +        let args = new_method +            .sig +            .inputs +            .iter() +            .enumerate() +            .map(|(idx, arg)| match arg { +                syn::FnArg::Receiver(_) => quote! {self}, +                syn::FnArg::Typed(pat_type) => match pat_type.pat.as_ref() { +                    syn::Pat::Ident(ident) => match &parsed_method.inputs[idx] { +                        None => ident.ident.to_token_stream(), +                        Some(transforms) => { +                            let args = (0..transforms.len()).map(|i| format_ident!("a{}", i)); +                            let mut calls: Vec<_> = +                                args.clone().map(|i| i.into_token_stream()).collect(); +                            for i in 0..calls.len() { +                                transforms[i].append_conversion(&mut calls[i]); +                            } +                            let move_opt = new_method.sig.asyncness.map(|_| quote! {move}); +                            quote!(#move_opt |#(#args),*| #ident(#(#calls),*)) +                        } +                    }, +                    _other => { +                        panic!("unexpected"); +                    } +                }, +            }); + +        // in order for a trait to be object-safe its methods may not have +        // generics so we convert method generics into trait generics +        if new_method +            .sig +            .generics +            .params +            .iter() +            .any(|p| matches!(p, GenericParam::Type(_))) +        { +            // syn::punctuated::Punctuated doesn't have a remove(index) +            // method so we firstly move all elements to a vector +            let mut params = Vec::new(); +            while let Some(generic_param) = new_method.sig.generics.params.pop() { +                params.push(generic_param.into_value()); +            } + +            // FUTURE: use Vec::drain_filter once it's stable +            let mut i = 0; +            while i < params.len() { +                if matches!(params[i], GenericParam::Type(_)) { +                    dyn_trait.generics.params.push(params.remove(i)); +                } else { +                    i += 1; +                } +            } + +            new_method.sig.generics.params.extend(params); + +            if dyn_trait.generics.lt_token.is_none() { +                dyn_trait.generics.lt_token = Some(Lt::default()); +                dyn_trait.generics.gt_token = Some(Gt::default()); +            } +        } + +        let mut expr = quote!(#original_trait_name::#fun_name(#(#args),*)); +        if new_method.sig.asyncness.is_some() { +            expr.extend(quote! {.await}) +        } +        parsed_method.return_type.append_conversion(&mut expr); + +        method_impls.push(ImplItemMethod { +            attrs: Vec::new(), +            vis: Visibility::Inherited, +            defaultness: None, +            sig: new_method.sig.clone(), +            block: Block { +                brace_token: Brace::default(), +                stmts: vec![Stmt::Expr(Expr::Verbatim(expr))], +            }, +        }); +        dyn_trait.items.push(new_method.into()); +    } + +    let blanket_impl = generate_blanket_impl(&dyn_trait, &original_trait, &method_impls); + +    let expanded = quote! { +        #original_trait + +        #(#dyn_trait_attrs)* +        #dyn_trait + +        #(#blanket_impl_attrs)* +        #blanket_impl +    }; +    TokenStream::from(expanded) +} + +fn generate_blanket_impl( +    dyn_trait: &ItemTrait, +    original_trait: &ItemTrait, +    method_impls: &[ImplItemMethod], +) -> proc_macro2::TokenStream { +    let mut blanket_generics = dyn_trait.generics.clone(); +    let some_ident = format_ident!("__to_be_dynamized"); +    blanket_generics.params.push(GenericParam::Type(TypeParam { +        attrs: Vec::new(), +        ident: some_ident.clone(), +        colon_token: None, +        bounds: std::iter::once(TypeParamBound::Trait(TraitBound { +            paren_token: None, +            modifier: syn::TraitBoundModifier::None, +            lifetimes: None, +            path: Path { +                leading_colon: None, +                segments: std::iter::once(path_segment_for_trait(original_trait)).collect(), +            }, +        })) +        .collect(), +        eq_token: None, +        default: None, +    })); +    let (_, type_gen, _where) = dyn_trait.generics.split_for_impl(); +    let dyn_trait_name = &dyn_trait.ident; +    quote! { +        impl #blanket_generics #dyn_trait_name #type_gen for #some_ident { +            #(#method_impls)* +        } +    } +} + +struct AssocTypeMatcher; +impl TypeMatcher<Path> for AssocTypeMatcher { +    fn match_path<'a>(&self, path: &'a Path) -> Option<&'a Path> { +        if path.segments.first().unwrap().ident == "Self" { +            return Some(path); +        } +        None +    } +} + +impl<TypeMatch, PathMatch, T> TypeMatcher<T> for (TypeMatch, PathMatch) +where +    TypeMatch: Fn(&Type) -> Option<&T>, +    PathMatch: Fn(&Path) -> Option<&T>, +{ +    fn match_type<'a>(&self, t: &'a Type) -> Option<&'a T> { +        self.0(t) +    } + +    fn match_path<'a>(&self, path: &'a Path) -> Option<&'a T> { +        self.1(path) +    } +} + +impl TypeTransform { +    fn append_conversion(&self, stream: &mut proc_macro2::TokenStream) { +        match self { +            TypeTransform::Into => stream.extend(quote! {.into()}), +            TypeTransform::Map(opt) => { +                let mut inner = quote!(x); +                opt.append_conversion(&mut inner); +                stream.extend(quote! {.map(|x| #inner)}) +            } +            TypeTransform::Result(ok, err) => { +                if !matches!(ok.as_ref(), TypeTransform::NoOp) { +                    let mut inner = quote!(x); +                    ok.append_conversion(&mut inner); +                    stream.extend(quote! {.map(|x| #inner)}) +                } +                if !matches!(err.as_ref(), TypeTransform::NoOp) { +                    let mut inner = quote!(x); +                    err.append_conversion(&mut inner); +                    stream.extend(quote! {.map_err(|x| #inner)}) +                } +            } +            _other => {} +        } +    } +} + +/// Just a convenience trait for us to avoid match/if-let blocks everywhere. +trait As<T> { +    fn get_as(&self) -> Option<&T>; +    fn get_as_mut(&mut self) -> Option<&mut T>; +} + +impl As<AngleBracketedGenericArguments> for PathArguments { +    fn get_as(&self) -> Option<&AngleBracketedGenericArguments> { +        match self { +            PathArguments::AngleBracketed(args) => Some(args), +            _other => None, +        } +    } +    fn get_as_mut(&mut self) -> Option<&mut AngleBracketedGenericArguments> { +        match self { +            PathArguments::AngleBracketed(args) => Some(args), +            _other => None, +        } +    } +} + +impl As<Type> for GenericArgument { +    fn get_as(&self) -> Option<&Type> { +        match self { +            GenericArgument::Type(typearg) => Some(typearg), +            _other => None, +        } +    } +    fn get_as_mut(&mut self) -> Option<&mut Type> { +        match self { +            GenericArgument::Type(typearg) => Some(typearg), +            _other => None, +        } +    } +} + +fn path_segment_for_trait(sometrait: &ItemTrait) -> PathSegment { +    PathSegment { +        ident: sometrait.ident.clone(), +        arguments: match sometrait.generics.params.is_empty() { +            true => PathArguments::None, +            false => PathArguments::AngleBracketed(AngleBracketedGenericArguments { +                colon2_token: None, +                lt_token: Lt::default(), +                args: sometrait +                    .generics +                    .params +                    .iter() +                    .map(|param| match param { +                        GenericParam::Type(type_param) => { +                            GenericArgument::Type(Type::Path(TypePath { +                                path: type_param.ident.clone().into(), +                                qself: None, +                            })) +                        } +                        GenericParam::Lifetime(lifetime_def) => { +                            GenericArgument::Lifetime(lifetime_def.lifetime.clone()) +                        } +                        GenericParam::Const(_) => todo!("const generic param not supported"), +                    }) +                    .collect(), +                gt_token: Gt::default(), +            }), +        }, +    } +} diff --git a/src/parse_assoc_type.rs b/src/parse_assoc_type.rs new file mode 100644 index 0000000..37fd78c --- /dev/null +++ b/src/parse_assoc_type.rs @@ -0,0 +1,119 @@ +use proc_macro2::Span; +use syn::spanned::Spanned; +use syn::{GenericArgument, Ident, PathArguments, PathSegment, TraitItemType, Type}; + +use crate::syn_utils::{find_in_type, trait_bounds}; +use crate::AssocTypeMatcher; + +#[derive(Debug)] +pub enum AssocTypeParseError { +    AssocTypeInBound, +    GenericAssociatedType, +    NoIntoBound, +} + +pub fn parse_assoc_type( +    assoc_type: &TraitItemType, +) -> Result<(&Ident, &Type), (Span, AssocTypeParseError)> { +    for bound in trait_bounds(&assoc_type.bounds) { +        if let PathSegment { +            ident, +            arguments: PathArguments::AngleBracketed(args), +        } = bound.path.segments.first().unwrap() +        { +            if ident == "Into" && args.args.len() == 1 { +                if let GenericArgument::Type(into_type) = args.args.first().unwrap() { +                    // provide a better error message for type A: Into<Self::B> +                    if find_in_type(into_type, &AssocTypeMatcher).is_some() { +                        return Err((into_type.span(), AssocTypeParseError::AssocTypeInBound)); +                    } + +                    // TODO: support lifetime GATs (see the currently failing tests/gats.rs) +                    if !assoc_type.generics.params.is_empty() { +                        return Err(( +                            assoc_type.generics.params.span(), +                            AssocTypeParseError::GenericAssociatedType, +                        )); +                    } + +                    return Ok((&assoc_type.ident, into_type)); +                } +            } +        } +    } +    Err((assoc_type.span(), AssocTypeParseError::NoIntoBound)) +} + +#[cfg(test)] +mod tests { +    use quote::quote; +    use syn::{TraitItemType, Type}; + +    use crate::parse_assoc_type::{parse_assoc_type, AssocTypeParseError}; + +    #[test] +    fn ok() { +        let type1: TraitItemType = syn::parse2(quote! { +            type A: Into<String>; +        }) +        .unwrap(); + +        assert!(matches!( +            parse_assoc_type(&type1), +            Ok((id, Type::Path(path))) +            if id == "A" && path.path.is_ident("String") +        )); +    } + +    #[test] +    fn err_no_bound() { +        let type1: TraitItemType = syn::parse2(quote! { +            type A; +        }) +        .unwrap(); + +        assert!(matches!( +            parse_assoc_type(&type1), +            Err((_, AssocTypeParseError::NoIntoBound)) +        )); +    } + +    #[test] +    fn err_assoc_type_in_bound() { +        let type1: TraitItemType = syn::parse2(quote! { +            type A: Into<Self::B>; +        }) +        .unwrap(); + +        assert!(matches!( +            parse_assoc_type(&type1), +            Err((_, AssocTypeParseError::AssocTypeInBound)) +        )); +    } + +    #[test] +    fn err_gat_type() { +        let type1: TraitItemType = syn::parse2(quote! { +            type A<X>: Into<Foobar<X>>; +        }) +        .unwrap(); + +        assert!(matches!( +            parse_assoc_type(&type1), +            Err((_, AssocTypeParseError::GenericAssociatedType)) +        )); +    } + +    #[test] +    fn err_gat_lifetime() { +        let type1: TraitItemType = syn::parse2(quote! { +            type A<'a>: Into<Foobar<'a>>; +        }) +        .unwrap(); + +        assert!(matches!( +            parse_assoc_type(&type1), +            Err((_, AssocTypeParseError::GenericAssociatedType)) +        )); +    } +} diff --git a/src/parse_trait_sig.rs b/src/parse_trait_sig.rs new file mode 100644 index 0000000..55a3214 --- /dev/null +++ b/src/parse_trait_sig.rs @@ -0,0 +1,505 @@ +use std::collections::HashMap; + +use proc_macro2::Span; +use syn::{ +    spanned::Spanned, FnArg, Ident, PathArguments, PredicateType, Receiver, ReturnType, Type, +    TypePath, WherePredicate, +}; +use syn::{GenericParam, Signature, TypeImplTrait, TypeParamBound}; + +use crate::syn_utils::{find_in_path, find_in_type, trait_bounds, TypeMatcher}; +use crate::{As, AssocTypeMatcher}; + +#[derive(Debug, Clone)] +pub enum TypeTransform { +    NoOp, +    Into, +    Map(Box<TypeTransform>), +    Result(Box<TypeTransform>, Box<TypeTransform>), +} + +#[derive(Debug)] +pub enum MethodParseError { +    NonDispatchableMethod, +    AssocTypeInInputs, +    ImplTraitInInputs, +    AssocTypeInUnsupportedReturnType, +    UnconvertibleAssocTypeInFnInput, +    UnconvertibleAssocTypeInTraitBound, +    UnconvertibleAssocType, +} + +struct ImplTraitMatcher; +impl TypeMatcher<TypeImplTrait> for ImplTraitMatcher { +    fn match_type<'a>(&self, t: &'a Type) -> Option<&'a TypeImplTrait> { +        if let Type::ImplTrait(impltrait) = t { +            Some(impltrait) +        } else { +            None +        } +    } +} + +pub struct SignatureChanges { +    pub return_type: TypeTransform, +    pub inputs: Vec<Option<Vec<TypeTransform>>>, +} + +pub fn parse_trait_signature( +    signature: &mut Signature, +    assoc_type_conversions: &HashMap<&Ident, &Type>, +) -> Result<SignatureChanges, (Span, MethodParseError)> { +    let assoc_type_conversions = AssocTypeConversions(assoc_type_conversions); + +    if is_non_dispatchable(signature) { +        return Err((signature.span(), MethodParseError::NonDispatchableMethod)); +    } + +    // provide better error messages for associated types in params +    for input in &signature.inputs { +        if let FnArg::Typed(pattype) = input { +            if find_in_type(&pattype.ty, &AssocTypeMatcher).is_some() { +                return Err((pattype.ty.span(), MethodParseError::AssocTypeInInputs)); +            } +            if let Some(impl_trait) = find_in_type(&pattype.ty, &ImplTraitMatcher) { +                return Err((impl_trait.span(), MethodParseError::ImplTraitInInputs)); +            } +        } +    } + +    let mut type_param_transforms = HashMap::new(); +    let mut input_transforms = Vec::new(); + +    for generic_param in &mut signature.generics.params { +        if let GenericParam::Type(type_param) = generic_param { +            for bound in &mut type_param.bounds { +                if let TypeParamBound::Trait(bound) = bound { +                    if bound.path.segments.len() == 1 { +                        let segment = bound.path.segments.first_mut().unwrap(); + +                        if let PathArguments::Parenthesized(args) = &mut segment.arguments { +                            if segment.ident == "Fn" +                                || segment.ident == "FnOnce" +                                || segment.ident == "FnMut" +                            { +                                let mut transforms = Vec::new(); +                                for input_type in &mut args.inputs { +                                    match assoc_type_conversions.parse_type_path(input_type) { +                                        Ok(ret_type) => { +                                            transforms.push(ret_type); +                                        } +                                        Err(TransformError::UnconvertibleAssocType(span)) => { +                                            return Err(( +                                                span, +                                                MethodParseError::UnconvertibleAssocType, +                                            )); +                                        } +                                        Err(TransformError::AssocTypeInUnsupportedType(span)) => { +                                            return Err(( +                                                span, +                                                MethodParseError::UnconvertibleAssocTypeInFnInput, +                                            )); +                                        } +                                    } +                                } +                                if transforms.iter().any(|t| !matches!(t, TypeTransform::NoOp)) { +                                    type_param_transforms.insert(&type_param.ident, transforms); +                                } +                            } +                        } +                    } +                    if let Some(path) = find_in_path(&bound.path, &AssocTypeMatcher) { +                        return Err(( +                            path.span(), +                            MethodParseError::UnconvertibleAssocTypeInTraitBound, +                        )); +                    } +                } +            } +        } +    } + +    for input in &signature.inputs { +        if let FnArg::Typed(pattype) = input { +            if let Type::Path(path) = &*pattype.ty { +                if let Some(ident) = path.path.get_ident() { +                    input_transforms.push(type_param_transforms.get(ident).map(|x| (*x).clone())); +                    continue; +                } +            } +        } +        input_transforms.push(None); +    } + +    let return_type = match &mut signature.output { +        ReturnType::Type(_, og_type) => match assoc_type_conversions.parse_type_path(og_type) { +            Ok(ret_type) => ret_type, +            Err(TransformError::UnconvertibleAssocType(span)) => { +                return Err((span, MethodParseError::UnconvertibleAssocType)); +            } +            Err(TransformError::AssocTypeInUnsupportedType(span)) => { +                return Err((span, MethodParseError::AssocTypeInUnsupportedReturnType)); +            } +        }, +        ReturnType::Default => TypeTransform::NoOp, +    }; +    Ok(SignatureChanges { +        return_type, +        inputs: input_transforms, +    }) +} + +struct AssocTypeConversions<'a>(&'a HashMap<&'a Ident, &'a Type>); + +enum TransformError { +    UnconvertibleAssocType(Span), +    AssocTypeInUnsupportedType(Span), +} + +impl AssocTypeConversions<'_> { +    fn parse_type_path(&self, type_: &mut Type) -> Result<TypeTransform, TransformError> { +        let assoc_span = match find_in_type(type_, &AssocTypeMatcher) { +            Some(path) => path.span(), +            None => return Ok(TypeTransform::NoOp), +        }; + +        if let Type::Path(TypePath { path, qself: None }) = type_ { +            let ident = &path.segments.first().unwrap().ident; + +            // TODO: support &mut dyn Iterator<Item = Self::A> +            // conversion to  Box<dyn Iterator<Item = Whatever>> via .map(Into::into) + +            if ident == "Self" && path.segments.len() == 2 { +                let ident = &path.segments.last().unwrap().ident; +                *type_ = (*self +                    .0 +                    .get(&ident) +                    .ok_or_else(|| TransformError::UnconvertibleAssocType(ident.span()))?) +                .clone(); +                return Ok(TypeTransform::Into); +            } else if ident == "Option" && path.segments.len() == 1 { +                let first_seg = path.segments.first_mut().unwrap(); + +                if let Some(args) = first_seg.arguments.get_as_mut() { +                    if args.args.len() == 1 { +                        if let Some(generic_type) = args.args.first_mut().unwrap().get_as_mut() { +                            if find_in_type(generic_type, &AssocTypeMatcher).is_some() { +                                return Ok(TypeTransform::Map( +                                    self.parse_type_path(generic_type)?.into(), +                                )); +                            } +                        } +                    } +                } +            } else if ident == "Result" && path.segments.len() == 1 { +                let first_seg = path.segments.first_mut().unwrap(); +                if let Some(args) = first_seg.arguments.get_as_mut() { +                    if args.args.len() == 2 { +                        let mut args_iter = args.args.iter_mut(); +                        if let (Some(ok_type), Some(err_type)) = ( +                            args_iter.next().unwrap().get_as_mut(), +                            args_iter.next().unwrap().get_as_mut(), +                        ) { +                            if find_in_type(ok_type, &AssocTypeMatcher).is_some() +                                || find_in_type(err_type, &AssocTypeMatcher).is_some() +                            { +                                return Ok(TypeTransform::Result( +                                    self.parse_type_path(ok_type)?.into(), +                                    self.parse_type_path(err_type)?.into(), +                                )); +                            } +                        } +                    } +                } +            } else { +                let last_seg = &path.segments.last().unwrap(); +                if last_seg.ident == "Result" { +                    let last_seg = path.segments.last_mut().unwrap(); +                    if let Some(args) = last_seg.arguments.get_as_mut() { +                        if args.args.len() == 1 { +                            if let Some(generic_type) = args.args.first_mut().unwrap().get_as_mut() +                            { +                                if find_in_type(generic_type, &AssocTypeMatcher).is_some() { +                                    return Ok(TypeTransform::Map( +                                        self.parse_type_path(generic_type)?.into(), +                                    )); +                                } +                            } +                        } +                    } +                } +            } +        } + +        // the type contains an associated type but we +        // don't know how to deal with it so we abort +        Err(TransformError::AssocTypeInUnsupportedType(assoc_span)) +    } +} + +fn is_non_dispatchable(signature: &Signature) -> bool { +    // non-dispatchable: fn example(&self) where Self: Sized; +    if let Some(where_clause) = &signature.generics.where_clause { +        if where_clause +            .predicates +            .iter() +            .any(bounds_self_and_has_bound_sized) +        { +            return true; +        } +    } + +    // non-dispatchable: fn example(); +    if signature.inputs.is_empty() { +        return true; +    } + +    // non-dispatchable: fn example(arg: Type); +    if matches!(signature.inputs.first(), Some(FnArg::Typed(_))) { +        return true; +    } + +    // non-dispatchable: fn example(self); +    if matches!( +        signature.inputs.first(), +        Some(FnArg::Receiver(Receiver { +            reference: None, +            .. +        })) +    ) { +        return true; +    } +    false +} + +/// Returns true if the bounded type is `Self` and the bounds contain `Sized`. +fn bounds_self_and_has_bound_sized(predicate: &WherePredicate) -> bool { +    matches!( +        predicate, +        WherePredicate::Type(PredicateType { +            bounded_ty: Type::Path(TypePath { path, .. }), +            bounds, +            .. +        }) +        if path.is_ident("Self") +        && trait_bounds(bounds).any(|b| b.path.is_ident("Sized")) +    ) +} + +#[cfg(test)] +mod tests { +    use std::collections::HashMap; + +    use quote::{format_ident, quote}; +    use syn::{TraitItemMethod, Type}; + +    use crate::parse_trait_sig::{ +        parse_trait_signature, MethodParseError, SignatureChanges, TypeTransform, +    }; + +    #[test] +    fn ok_void() { +        let mut type1: TraitItemMethod = syn::parse2(quote! { +            fn test(&self); +        }) +        .unwrap(); + +        assert!(matches!( +            parse_trait_signature(&mut type1.sig, &Default::default()), +            Ok(SignatureChanges { +                return_type: TypeTransform::NoOp, +                .. +            }) +        )); +    } + +    #[test] +    fn ok_assoc_type() { +        let mut type1: TraitItemMethod = syn::parse2(quote! { +            fn test(&self) -> Self::A; +        }) +        .unwrap(); + +        let mut assoc_type_map = HashMap::new(); +        let ident = format_ident!("A"); +        let dest = Type::Verbatim(quote! {Example}); +        assoc_type_map.insert(&ident, &dest); + +        assert!(matches!( +            parse_trait_signature(&mut type1.sig, &assoc_type_map), +            Ok(SignatureChanges { +                return_type: TypeTransform::Into, +                .. +            }) +        )); +    } + +    #[test] +    fn err_unconvertible_assoc_type() { +        let mut type1: TraitItemMethod = syn::parse2(quote! { +            fn test(&self) -> Self::A; +        }) +        .unwrap(); + +        assert!(matches!( +            parse_trait_signature(&mut type1.sig, &Default::default()), +            Err((_, MethodParseError::UnconvertibleAssocType)) +        )); +    } + +    #[test] +    fn err_non_dispatchable_assoc_function_no_args() { +        let mut type1: TraitItemMethod = syn::parse2(quote! { +            fn test(); +        }) +        .unwrap(); + +        assert!(matches!( +            parse_trait_signature(&mut type1.sig, &Default::default()), +            Err((_, MethodParseError::NonDispatchableMethod)) +        )); +    } + +    #[test] +    fn err_non_dispatchable_assoc_function_with_args() { +        let mut type1: TraitItemMethod = syn::parse2(quote! { +            fn test(arg: Type); +        }) +        .unwrap(); + +        assert!(matches!( +            parse_trait_signature(&mut type1.sig, &Default::default()), +            Err((_, MethodParseError::NonDispatchableMethod)) +        )); +    } + +    #[test] +    fn err_non_dispatchable_consume_self() { +        let mut type1: TraitItemMethod = syn::parse2(quote! { +            fn test(self); +        }) +        .unwrap(); + +        assert!(matches!( +            parse_trait_signature(&mut type1.sig, &Default::default()), +            Err((_, MethodParseError::NonDispatchableMethod)) +        )); +    } + +    #[test] +    fn err_non_dispatchable_where_self_sized() { +        let mut type1: TraitItemMethod = syn::parse2(quote! { +            fn test(&self) where Self: Sized; +        }) +        .unwrap(); + +        assert!(matches!( +            parse_trait_signature(&mut type1.sig, &Default::default()), +            Err((_, MethodParseError::NonDispatchableMethod)) +        )); +    } + +    #[test] +    fn err_assoc_type_in_unsupported_return() { +        let mut type1: TraitItemMethod = syn::parse2(quote! { +            fn test(&self) -> Foo<Self::A>; +        }) +        .unwrap(); + +        assert!(matches!( +            parse_trait_signature(&mut type1.sig, &Default::default()), +            Err((_, MethodParseError::AssocTypeInUnsupportedReturnType)) +        )); +    } + +    #[test] +    fn err_assoc_type_in_unsupported_return_in_opt() { +        let mut type1: TraitItemMethod = syn::parse2(quote! { +            fn test(&self) -> Option<Foo<Self::A>>; +        }) +        .unwrap(); + +        assert!(matches!( +            parse_trait_signature(&mut type1.sig, &Default::default()), +            Err((_, MethodParseError::AssocTypeInUnsupportedReturnType)) +        )); +    } + +    #[test] +    fn err_assoc_type_in_unsupported_return_in_ok() { +        let mut type1: TraitItemMethod = syn::parse2(quote! { +            fn test(&self) -> Result<Foo<Self::A>, Error>; +        }) +        .unwrap(); + +        assert!(matches!( +            parse_trait_signature(&mut type1.sig, &Default::default()), +            Err((_, MethodParseError::AssocTypeInUnsupportedReturnType)) +        )); +    } + +    #[test] +    fn err_assoc_type_in_unsupported_return_in_err() { +        let mut type1: TraitItemMethod = syn::parse2(quote! { +            fn test(&self) -> Result<Ok, Foo<Self::A>>; +        }) +        .unwrap(); + +        assert!(matches!( +            parse_trait_signature(&mut type1.sig, &Default::default()), +            Err((_, MethodParseError::AssocTypeInUnsupportedReturnType)) +        )); +    } + +    #[test] +    fn err_assoc_type_in_input() { +        let mut type1: TraitItemMethod = syn::parse2(quote! { +            fn test(&self, x: Self::A); +        }) +        .unwrap(); + +        assert!(matches!( +            parse_trait_signature(&mut type1.sig, &Default::default()), +            Err((_, MethodParseError::AssocTypeInInputs)) +        )); +    } + +    #[test] +    fn err_assoc_type_in_input_opt() { +        let mut type1: TraitItemMethod = syn::parse2(quote! { +            fn test(&self, x: Option<Self::A>); +        }) +        .unwrap(); + +        assert!(matches!( +            parse_trait_signature(&mut type1.sig, &Default::default()), +            Err((_, MethodParseError::AssocTypeInInputs)) +        )); +    } + +    #[test] +    fn err_impl_in_input() { +        let mut type1: TraitItemMethod = syn::parse2(quote! { +            fn test(&self, arg: Option<impl SomeTrait>); +        }) +        .unwrap(); + +        assert!(matches!( +            parse_trait_signature(&mut type1.sig, &Default::default()), +            Err((_, MethodParseError::ImplTraitInInputs)) +        )); +    } + +    #[test] +    fn err_assoc_type_in_generic() { +        let mut type1: TraitItemMethod = syn::parse2(quote! { +            fn test<F: Fn(Foo<Self::A>)>(&self, fun: F); +        }) +        .unwrap(); + +        assert!(matches!( +            parse_trait_signature(&mut type1.sig, &Default::default()), +            Err((_, MethodParseError::UnconvertibleAssocTypeInFnInput)) +        )); +    } +} diff --git a/src/syn_utils.rs b/src/syn_utils.rs new file mode 100644 index 0000000..4588186 --- /dev/null +++ b/src/syn_utils.rs @@ -0,0 +1,97 @@ +use syn::{ +    punctuated::Punctuated, GenericArgument, Path, PathArguments, ReturnType, TraitBound, Type, +    TypeParamBound, +}; + +pub trait TypeMatcher<T> { +    fn match_type<'a>(&self, t: &'a Type) -> Option<&'a T> { +        None +    } +    fn match_path<'a>(&self, path: &'a Path) -> Option<&'a T> { +        None +    } +} + +pub fn trait_bounds<T>( +    bounds: &Punctuated<TypeParamBound, T>, +) -> impl Iterator<Item = &TraitBound> { +    bounds.iter().filter_map(|b| match b { +        TypeParamBound::Trait(t) => Some(t), +        TypeParamBound::Lifetime(_) => None, +    }) +} + +pub fn find_in_type<'a, T>(t: &'a Type, matcher: &dyn TypeMatcher<T>) -> Option<&'a T> { +    if let Some(ret) = matcher.match_type(t) { +        return Some(ret); +    } +    match t { +        Type::Array(array) => find_in_type(&array.elem, matcher), +        Type::BareFn(fun) => { +            for input in &fun.inputs { +                if let Some(ret) = find_in_type(&input.ty, matcher) { +                    return Some(ret); +                } +            } +            if let ReturnType::Type(_, t) = &fun.output { +                return find_in_type(t, matcher); +            } +            None +        } +        Type::Group(group) => find_in_type(&group.elem, matcher), +        Type::ImplTrait(impltrait) => { +            for bound in &impltrait.bounds { +                if let TypeParamBound::Trait(bound) = bound { +                    if let Some(ret) = find_in_path(&bound.path, matcher) { +                        return Some(ret); +                    } +                } +            } +            None +        } +        Type::Infer(_) => None, +        Type::Macro(_) => None, +        Type::Paren(paren) => find_in_type(&paren.elem, matcher), +        Type::Path(path) => find_in_path(&path.path, matcher), +        Type::Ptr(ptr) => find_in_type(&ptr.elem, matcher), +        Type::Reference(reference) => find_in_type(&reference.elem, matcher), +        Type::Slice(slice) => find_in_type(&slice.elem, matcher), +        Type::TraitObject(traitobj) => { +            for bound in &traitobj.bounds { +                if let TypeParamBound::Trait(bound) = bound { +                    if let Some(ret) = find_in_path(&bound.path, matcher) { +                        return Some(ret); +                    } +                } +            } +            None +        } +        Type::Tuple(tuple) => { +            for elem in &tuple.elems { +                if let Some(ret) = find_in_type(elem, matcher) { +                    return Some(ret); +                } +            } +            None +        } +        _other => None, +    } +} + +pub fn find_in_path<'a, T>(path: &'a Path, matcher: &dyn TypeMatcher<T>) -> Option<&'a T> { +    if let Some(ret) = matcher.match_path(path) { +        return Some(ret); +    } +    for segment in &path.segments { +        if let PathArguments::AngleBracketed(args) = &segment.arguments { +            for arg in &args.args { +                if let GenericArgument::Type(t) = arg { +                    if let Some(ret) = find_in_type(t, matcher) { +                        return Some(ret); +                    } +                } +            } +        } +    } +    None +} diff --git a/tests/gats.rs b/tests/gats.rs new file mode 100644 index 0000000..92e483c --- /dev/null +++ b/tests/gats.rs @@ -0,0 +1,17 @@ +//! This test can be run with `cargo +nightly test --features=nightly` +#![cfg_attr(feature = "nightly", feature(generic_associated_types))] + +#[cfg(feature = "nightly")] +mod test_gats { +    #[dynamize::dynamize] +    pub trait MyTrait { +        type A<'a>: Into<&'a str>; + +        fn test1<'b>(&self) -> Self::A<'b>; +    } + +    fn test<T: MyTrait>(mut some: T) { +        let dyn_trait: &dyn DynMyTrait = &some; +        let _: &str = dyn_trait.test1(); +    } +} diff --git a/tests/tests.rs b/tests/tests.rs new file mode 100644 index 0000000..6bccc04 --- /dev/null +++ b/tests/tests.rs @@ -0,0 +1,152 @@ +#![allow(dead_code)] + +#[test] +fn it_works() { +    use dynamize::dynamize; + +    mod some { +        pub mod module { +            pub type Result<T> = std::io::Result<T>; +        } +    } + +    #[dynamize] +    /// This is a great trait! +    pub trait MyTrait { +        type A: Into<String>; +        // if there are multiple Into bounds the first one is used +        type B: Into<i32> + Into<u64>; + +        fn test1(&self) -> Self::A; +        fn test2(&self) -> Self::B; +        fn test3(&self) -> Option<Self::B>; +        fn test4(&self) -> Result<(), Self::A>; +        fn test5(&self) -> Result<Self::A, ()>; +        /// some method documentation +        fn test6(&self) -> Result<Self::A, Self::B>; + +        #[allow(clippy::type_complexity)] +        fn test7(&self) -> Result<Option<Option<Self::A>>, Option<Option<Self::B>>>; + +        // also support Result type aliases with a fixed error type +        fn test8(&self) -> some::module::Result<Self::A>; + +        // fn test9(&self) -> &dyn Iterator<Item = Self::A>; + +        fn mut1(&mut self) -> Self::A; + +        fn safe1(&self); +        fn safe2(&self, num: i32) -> i32; +        fn safe3<'a>(&self, text: &'a str) -> &'a str; +        fn safe4(&self) -> Option<i32>; + +        // non-dispatchable functions are skipped +        fn non_dispatch1(); +        fn non_dispatch2(num: i32); +        fn non_dispatch3(self) -> Self::A; +        fn non_dispatch4(&self) +        where +            Self: Sized; +    } + +    fn test<T: MyTrait>(mut some: T) { +        let dyn_trait: &dyn DynMyTrait = &some; +        let _: String = dyn_trait.test1(); +        let _: i32 = dyn_trait.test2(); +        let _: Option<i32> = dyn_trait.test3(); +        let _: Result<(), String> = dyn_trait.test4(); +        let _: Result<String, ()> = dyn_trait.test5(); +        let _: Result<String, i32> = dyn_trait.test6(); +        let _: Result<Option<Option<String>>, Option<Option<i32>>> = dyn_trait.test7(); + +        let dyn_trait: &mut dyn DynMyTrait = &mut some; +        dyn_trait.mut1(); + +        let _: () = dyn_trait.safe1(); +        let _: i32 = dyn_trait.safe2(0); +        let _: &str = dyn_trait.safe3("test"); +        let _: Option<i32> = dyn_trait.safe4(); +    } +} + +#[dynamize::dynamize] +trait Foo<X> { +    type A: Into<String>; + +    fn foobar(&self, x: X) -> Self::A; +} + +#[dynamize::dynamize] +trait Bar<X> { +    fn foobar<A>(&self, x: X) -> A; +} + +fn test<T: Bar<X>, X, A>(some: T) { +    let _dyn_trait: &dyn DynBar<X, A> = &some; +} + +#[dynamize::dynamize] +trait Bar1<X> { +    fn foobar<A>(&self, x: X) -> A; +    fn foobar1<B>(&self, x: X) -> B; +    fn foobar2<C>(&self, x: X) -> C; +} + +fn test1<T: Bar1<X>, X, A, B, C>(some: T) { +    let _dyn_trait: &dyn DynBar1<X, A, B, C> = &some; +} + +#[dynamize::dynamize] +trait Buz<X> { +    type C: Into<String>; + +    fn foobar<A>(&self, x: X) -> Result<A, Self::C>; +} + +fn test2<T: Buz<X>, X, A>(some: T, x: X) -> Result<A, String> { +    let dyn_trait: &dyn DynBuz<X, A> = &some; +    dyn_trait.foobar(x) +} + +#[dynamize::dynamize] +trait Gen { +    fn foobar<A>(&self, a: A) -> A; +} + +use async_trait::async_trait; + +#[dynamize::dynamize] +#[dyn_trait_attr(async_trait)] +#[blanket_impl_attr(async_trait)] +#[async_trait] +trait SomeTraitWithAsync: Sync { +    type A: Into<String>; +    async fn test1(&self) -> Self::A; +} + +async fn async_test<T: SomeTraitWithAsync>(some: T) { +    let dyn_trait: &dyn DynSomeTraitWithAsync = &some; +    let _: String = dyn_trait.test1().await; +} + +#[dynamize::dynamize] +trait TraitWithCallback { +    type A: Into<String>; +    fn fun_with_callback<F: Fn(Self::A)>(&self, a: F); + +    fn fun_with_callback1<X: Fn(Option<Self::A>)>(&self, a: X); + +    fn fun_with_callback2<Y: Fn(i32, Option<Self::A>, String) -> bool>(&self, a: Y); + +    fn fun_with_callback3<Z: Fn(i32)>(&self, a: Z); +} + +#[dynamize::dynamize] +#[dyn_trait_attr(async_trait)] +#[blanket_impl_attr(async_trait)] +#[async_trait] +trait AsyncWithCallback: Sync { +    type A: Into<String>; +    async fn test1(&self) -> Self::A; +    async fn fun_with_callback<F: Fn(Self::A) + Sync + Send + 'static>(&self, a: F); +} | 
