diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/lib.rs | 425 | ||||
-rw-r--r-- | src/parse_assoc_type.rs | 119 | ||||
-rw-r--r-- | src/parse_trait_sig.rs | 505 | ||||
-rw-r--r-- | src/syn_utils.rs | 97 |
4 files changed, 1146 insertions, 0 deletions
diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..1f0432c --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,425 @@ +#![doc = include_str!("../README.md")] +use std::collections::HashMap; + +use proc_macro::TokenStream; +use proc_macro2::Group; +use proc_macro2::Ident; +use quote::format_ident; +use quote::quote; +use quote::quote_spanned; +use quote::ToTokens; +use syn::parse_macro_input; +use syn::punctuated::Punctuated; +use syn::token::Brace; +use syn::token::Gt; +use syn::token::Lt; +use syn::token::Trait; +use syn::AngleBracketedGenericArguments; +use syn::Block; +use syn::Expr; +use syn::GenericArgument; +use syn::GenericParam; +use syn::ImplItemMethod; +use syn::ItemTrait; +use syn::Path; +use syn::PathArguments; +use syn::PathSegment; +use syn::Signature; +use syn::Stmt; +use syn::TraitBound; +use syn::TraitItem; +use syn::TraitItemMethod; +use syn::Type; +use syn::TypeParam; +use syn::TypeParamBound; +use syn::TypePath; +use syn::Visibility; +use syn_utils::TypeMatcher; + +use crate::parse_assoc_type::parse_assoc_type; +use crate::parse_assoc_type::AssocTypeParseError; +use crate::parse_trait_sig::parse_trait_signature; +use crate::parse_trait_sig::MethodParseError; +use crate::parse_trait_sig::SignatureChanges; +use crate::parse_trait_sig::TypeTransform; +mod parse_assoc_type; +mod parse_trait_sig; +mod syn_utils; + +macro_rules! abort { + ($span:expr, $message:literal) => { + quote_spanned! {$span => compile_error!($message);}.into() + }; +} + +#[proc_macro_attribute] +pub fn dynamize(_attr: TokenStream, input: TokenStream) -> TokenStream { + let mut original_trait = parse_macro_input!(input as ItemTrait); + assert!(original_trait.auto_token.is_none()); + + let original_trait_name = original_trait.ident.clone(); + + let mut objectifiable_methods: Vec<(Signature, SignatureChanges)> = Vec::new(); + + let mut assoc_type_conversions: HashMap<&Ident, &Type> = HashMap::new(); + + for item in &original_trait.items { + if let TraitItem::Type(assoc_type) = item { + match parse_assoc_type(assoc_type) { + Err((_, AssocTypeParseError::NoIntoBound)) => continue, + Err((span, AssocTypeParseError::AssocTypeInBound)) => { + return abort!(span, "dynamize does not support associated types here") + } + Err((span, AssocTypeParseError::GenericAssociatedType)) => { + return abort!( + span, + "dynamize does not (yet?) support generic associated types" + ) + } + Ok((ident, type_)) => { + assoc_type_conversions.insert(ident, type_); + } + } + } + } + + for item in &original_trait.items { + if let TraitItem::Method(method) = item { + let mut signature = method.sig.clone(); + match parse_trait_signature(&mut signature, &assoc_type_conversions) { + Err((_, MethodParseError::NonDispatchableMethod)) => continue, + Err((span, MethodParseError::UnconvertibleAssocType)) => { + return abort!( + span, + "associated type is either undefined or doesn't have an Into bound" + ) + } + Err((span, MethodParseError::AssocTypeInInputs)) => { + return abort!( + span, + "dynamize does not support associated types in parameter types" + ) + } + Err(( + span, + MethodParseError::AssocTypeInUnsupportedReturnType + | MethodParseError::UnconvertibleAssocTypeInFnInput, + )) => return abort!(span, "dynamize does not know how to convert this type"), + Err((span, MethodParseError::UnconvertibleAssocTypeInTraitBound)) => { + return abort!(span, "dynamize does not support associated types here") + } + Err((span, MethodParseError::ImplTraitInInputs)) => { + return abort!( + span, + "dynamize does not support impl here, change it to a method generic" + ) + } + Ok(parsed_method) => objectifiable_methods.push((signature, parsed_method)), + }; + } + } + + let mut method_impls: Vec<ImplItemMethod> = Vec::new(); + + let mut blanket_impl_attrs = Vec::new(); + let mut dyn_trait_attrs = Vec::new(); + + // FUTURE: use Vec::drain_filter once it's stable + let mut i = 0; + while i < original_trait.attrs.len() { + if original_trait.attrs[i].path.is_ident("blanket_impl_attr") { + let attr = original_trait.attrs.remove(i); + let group: Group = match syn::parse2(attr.tokens) { + Ok(g) => g, + Err(err) => { + return abort!( + err.span(), + "expected parenthesis: #[blanket_impl_attr(...)]" + ) + } + }; + let tokens = group.stream(); + blanket_impl_attrs.push(quote! {#[#tokens]}); + } else if original_trait.attrs[i].path.is_ident("dyn_trait_attr") { + let attr = original_trait.attrs.remove(i); + let group: Group = match syn::parse2(attr.tokens) { + Ok(g) => g, + Err(err) => { + return abort!(err.span(), "expected parenthesis: #[dyn_trait_attr(...)]") + } + }; + let tokens = group.stream(); + dyn_trait_attrs.push(quote! {#[#tokens]}); + } else { + i += 1; + } + } + + let mut dyn_trait = ItemTrait { + ident: format_ident!("Dyn{}", original_trait.ident), + + attrs: Vec::new(), + vis: original_trait.vis.clone(), + unsafety: original_trait.unsafety, + auto_token: None, + trait_token: Trait::default(), + generics: original_trait.generics.clone(), + colon_token: None, + supertraits: Punctuated::new(), + brace_token: Brace::default(), + items: Vec::new(), + }; + + for (signature, parsed_method) in objectifiable_methods { + let mut new_method = TraitItemMethod { + attrs: Vec::new(), + sig: signature, + default: None, + semi_token: None, + }; + + let fun_name = &new_method.sig.ident; + + let args = new_method + .sig + .inputs + .iter() + .enumerate() + .map(|(idx, arg)| match arg { + syn::FnArg::Receiver(_) => quote! {self}, + syn::FnArg::Typed(pat_type) => match pat_type.pat.as_ref() { + syn::Pat::Ident(ident) => match &parsed_method.inputs[idx] { + None => ident.ident.to_token_stream(), + Some(transforms) => { + let args = (0..transforms.len()).map(|i| format_ident!("a{}", i)); + let mut calls: Vec<_> = + args.clone().map(|i| i.into_token_stream()).collect(); + for i in 0..calls.len() { + transforms[i].append_conversion(&mut calls[i]); + } + let move_opt = new_method.sig.asyncness.map(|_| quote! {move}); + quote!(#move_opt |#(#args),*| #ident(#(#calls),*)) + } + }, + _other => { + panic!("unexpected"); + } + }, + }); + + // in order for a trait to be object-safe its methods may not have + // generics so we convert method generics into trait generics + if new_method + .sig + .generics + .params + .iter() + .any(|p| matches!(p, GenericParam::Type(_))) + { + // syn::punctuated::Punctuated doesn't have a remove(index) + // method so we firstly move all elements to a vector + let mut params = Vec::new(); + while let Some(generic_param) = new_method.sig.generics.params.pop() { + params.push(generic_param.into_value()); + } + + // FUTURE: use Vec::drain_filter once it's stable + let mut i = 0; + while i < params.len() { + if matches!(params[i], GenericParam::Type(_)) { + dyn_trait.generics.params.push(params.remove(i)); + } else { + i += 1; + } + } + + new_method.sig.generics.params.extend(params); + + if dyn_trait.generics.lt_token.is_none() { + dyn_trait.generics.lt_token = Some(Lt::default()); + dyn_trait.generics.gt_token = Some(Gt::default()); + } + } + + let mut expr = quote!(#original_trait_name::#fun_name(#(#args),*)); + if new_method.sig.asyncness.is_some() { + expr.extend(quote! {.await}) + } + parsed_method.return_type.append_conversion(&mut expr); + + method_impls.push(ImplItemMethod { + attrs: Vec::new(), + vis: Visibility::Inherited, + defaultness: None, + sig: new_method.sig.clone(), + block: Block { + brace_token: Brace::default(), + stmts: vec![Stmt::Expr(Expr::Verbatim(expr))], + }, + }); + dyn_trait.items.push(new_method.into()); + } + + let blanket_impl = generate_blanket_impl(&dyn_trait, &original_trait, &method_impls); + + let expanded = quote! { + #original_trait + + #(#dyn_trait_attrs)* + #dyn_trait + + #(#blanket_impl_attrs)* + #blanket_impl + }; + TokenStream::from(expanded) +} + +fn generate_blanket_impl( + dyn_trait: &ItemTrait, + original_trait: &ItemTrait, + method_impls: &[ImplItemMethod], +) -> proc_macro2::TokenStream { + let mut blanket_generics = dyn_trait.generics.clone(); + let some_ident = format_ident!("__to_be_dynamized"); + blanket_generics.params.push(GenericParam::Type(TypeParam { + attrs: Vec::new(), + ident: some_ident.clone(), + colon_token: None, + bounds: std::iter::once(TypeParamBound::Trait(TraitBound { + paren_token: None, + modifier: syn::TraitBoundModifier::None, + lifetimes: None, + path: Path { + leading_colon: None, + segments: std::iter::once(path_segment_for_trait(original_trait)).collect(), + }, + })) + .collect(), + eq_token: None, + default: None, + })); + let (_, type_gen, _where) = dyn_trait.generics.split_for_impl(); + let dyn_trait_name = &dyn_trait.ident; + quote! { + impl #blanket_generics #dyn_trait_name #type_gen for #some_ident { + #(#method_impls)* + } + } +} + +struct AssocTypeMatcher; +impl TypeMatcher<Path> for AssocTypeMatcher { + fn match_path<'a>(&self, path: &'a Path) -> Option<&'a Path> { + if path.segments.first().unwrap().ident == "Self" { + return Some(path); + } + None + } +} + +impl<TypeMatch, PathMatch, T> TypeMatcher<T> for (TypeMatch, PathMatch) +where + TypeMatch: Fn(&Type) -> Option<&T>, + PathMatch: Fn(&Path) -> Option<&T>, +{ + fn match_type<'a>(&self, t: &'a Type) -> Option<&'a T> { + self.0(t) + } + + fn match_path<'a>(&self, path: &'a Path) -> Option<&'a T> { + self.1(path) + } +} + +impl TypeTransform { + fn append_conversion(&self, stream: &mut proc_macro2::TokenStream) { + match self { + TypeTransform::Into => stream.extend(quote! {.into()}), + TypeTransform::Map(opt) => { + let mut inner = quote!(x); + opt.append_conversion(&mut inner); + stream.extend(quote! {.map(|x| #inner)}) + } + TypeTransform::Result(ok, err) => { + if !matches!(ok.as_ref(), TypeTransform::NoOp) { + let mut inner = quote!(x); + ok.append_conversion(&mut inner); + stream.extend(quote! {.map(|x| #inner)}) + } + if !matches!(err.as_ref(), TypeTransform::NoOp) { + let mut inner = quote!(x); + err.append_conversion(&mut inner); + stream.extend(quote! {.map_err(|x| #inner)}) + } + } + _other => {} + } + } +} + +/// Just a convenience trait for us to avoid match/if-let blocks everywhere. +trait As<T> { + fn get_as(&self) -> Option<&T>; + fn get_as_mut(&mut self) -> Option<&mut T>; +} + +impl As<AngleBracketedGenericArguments> for PathArguments { + fn get_as(&self) -> Option<&AngleBracketedGenericArguments> { + match self { + PathArguments::AngleBracketed(args) => Some(args), + _other => None, + } + } + fn get_as_mut(&mut self) -> Option<&mut AngleBracketedGenericArguments> { + match self { + PathArguments::AngleBracketed(args) => Some(args), + _other => None, + } + } +} + +impl As<Type> for GenericArgument { + fn get_as(&self) -> Option<&Type> { + match self { + GenericArgument::Type(typearg) => Some(typearg), + _other => None, + } + } + fn get_as_mut(&mut self) -> Option<&mut Type> { + match self { + GenericArgument::Type(typearg) => Some(typearg), + _other => None, + } + } +} + +fn path_segment_for_trait(sometrait: &ItemTrait) -> PathSegment { + PathSegment { + ident: sometrait.ident.clone(), + arguments: match sometrait.generics.params.is_empty() { + true => PathArguments::None, + false => PathArguments::AngleBracketed(AngleBracketedGenericArguments { + colon2_token: None, + lt_token: Lt::default(), + args: sometrait + .generics + .params + .iter() + .map(|param| match param { + GenericParam::Type(type_param) => { + GenericArgument::Type(Type::Path(TypePath { + path: type_param.ident.clone().into(), + qself: None, + })) + } + GenericParam::Lifetime(lifetime_def) => { + GenericArgument::Lifetime(lifetime_def.lifetime.clone()) + } + GenericParam::Const(_) => todo!("const generic param not supported"), + }) + .collect(), + gt_token: Gt::default(), + }), + }, + } +} diff --git a/src/parse_assoc_type.rs b/src/parse_assoc_type.rs new file mode 100644 index 0000000..37fd78c --- /dev/null +++ b/src/parse_assoc_type.rs @@ -0,0 +1,119 @@ +use proc_macro2::Span; +use syn::spanned::Spanned; +use syn::{GenericArgument, Ident, PathArguments, PathSegment, TraitItemType, Type}; + +use crate::syn_utils::{find_in_type, trait_bounds}; +use crate::AssocTypeMatcher; + +#[derive(Debug)] +pub enum AssocTypeParseError { + AssocTypeInBound, + GenericAssociatedType, + NoIntoBound, +} + +pub fn parse_assoc_type( + assoc_type: &TraitItemType, +) -> Result<(&Ident, &Type), (Span, AssocTypeParseError)> { + for bound in trait_bounds(&assoc_type.bounds) { + if let PathSegment { + ident, + arguments: PathArguments::AngleBracketed(args), + } = bound.path.segments.first().unwrap() + { + if ident == "Into" && args.args.len() == 1 { + if let GenericArgument::Type(into_type) = args.args.first().unwrap() { + // provide a better error message for type A: Into<Self::B> + if find_in_type(into_type, &AssocTypeMatcher).is_some() { + return Err((into_type.span(), AssocTypeParseError::AssocTypeInBound)); + } + + // TODO: support lifetime GATs (see the currently failing tests/gats.rs) + if !assoc_type.generics.params.is_empty() { + return Err(( + assoc_type.generics.params.span(), + AssocTypeParseError::GenericAssociatedType, + )); + } + + return Ok((&assoc_type.ident, into_type)); + } + } + } + } + Err((assoc_type.span(), AssocTypeParseError::NoIntoBound)) +} + +#[cfg(test)] +mod tests { + use quote::quote; + use syn::{TraitItemType, Type}; + + use crate::parse_assoc_type::{parse_assoc_type, AssocTypeParseError}; + + #[test] + fn ok() { + let type1: TraitItemType = syn::parse2(quote! { + type A: Into<String>; + }) + .unwrap(); + + assert!(matches!( + parse_assoc_type(&type1), + Ok((id, Type::Path(path))) + if id == "A" && path.path.is_ident("String") + )); + } + + #[test] + fn err_no_bound() { + let type1: TraitItemType = syn::parse2(quote! { + type A; + }) + .unwrap(); + + assert!(matches!( + parse_assoc_type(&type1), + Err((_, AssocTypeParseError::NoIntoBound)) + )); + } + + #[test] + fn err_assoc_type_in_bound() { + let type1: TraitItemType = syn::parse2(quote! { + type A: Into<Self::B>; + }) + .unwrap(); + + assert!(matches!( + parse_assoc_type(&type1), + Err((_, AssocTypeParseError::AssocTypeInBound)) + )); + } + + #[test] + fn err_gat_type() { + let type1: TraitItemType = syn::parse2(quote! { + type A<X>: Into<Foobar<X>>; + }) + .unwrap(); + + assert!(matches!( + parse_assoc_type(&type1), + Err((_, AssocTypeParseError::GenericAssociatedType)) + )); + } + + #[test] + fn err_gat_lifetime() { + let type1: TraitItemType = syn::parse2(quote! { + type A<'a>: Into<Foobar<'a>>; + }) + .unwrap(); + + assert!(matches!( + parse_assoc_type(&type1), + Err((_, AssocTypeParseError::GenericAssociatedType)) + )); + } +} diff --git a/src/parse_trait_sig.rs b/src/parse_trait_sig.rs new file mode 100644 index 0000000..55a3214 --- /dev/null +++ b/src/parse_trait_sig.rs @@ -0,0 +1,505 @@ +use std::collections::HashMap; + +use proc_macro2::Span; +use syn::{ + spanned::Spanned, FnArg, Ident, PathArguments, PredicateType, Receiver, ReturnType, Type, + TypePath, WherePredicate, +}; +use syn::{GenericParam, Signature, TypeImplTrait, TypeParamBound}; + +use crate::syn_utils::{find_in_path, find_in_type, trait_bounds, TypeMatcher}; +use crate::{As, AssocTypeMatcher}; + +#[derive(Debug, Clone)] +pub enum TypeTransform { + NoOp, + Into, + Map(Box<TypeTransform>), + Result(Box<TypeTransform>, Box<TypeTransform>), +} + +#[derive(Debug)] +pub enum MethodParseError { + NonDispatchableMethod, + AssocTypeInInputs, + ImplTraitInInputs, + AssocTypeInUnsupportedReturnType, + UnconvertibleAssocTypeInFnInput, + UnconvertibleAssocTypeInTraitBound, + UnconvertibleAssocType, +} + +struct ImplTraitMatcher; +impl TypeMatcher<TypeImplTrait> for ImplTraitMatcher { + fn match_type<'a>(&self, t: &'a Type) -> Option<&'a TypeImplTrait> { + if let Type::ImplTrait(impltrait) = t { + Some(impltrait) + } else { + None + } + } +} + +pub struct SignatureChanges { + pub return_type: TypeTransform, + pub inputs: Vec<Option<Vec<TypeTransform>>>, +} + +pub fn parse_trait_signature( + signature: &mut Signature, + assoc_type_conversions: &HashMap<&Ident, &Type>, +) -> Result<SignatureChanges, (Span, MethodParseError)> { + let assoc_type_conversions = AssocTypeConversions(assoc_type_conversions); + + if is_non_dispatchable(signature) { + return Err((signature.span(), MethodParseError::NonDispatchableMethod)); + } + + // provide better error messages for associated types in params + for input in &signature.inputs { + if let FnArg::Typed(pattype) = input { + if find_in_type(&pattype.ty, &AssocTypeMatcher).is_some() { + return Err((pattype.ty.span(), MethodParseError::AssocTypeInInputs)); + } + if let Some(impl_trait) = find_in_type(&pattype.ty, &ImplTraitMatcher) { + return Err((impl_trait.span(), MethodParseError::ImplTraitInInputs)); + } + } + } + + let mut type_param_transforms = HashMap::new(); + let mut input_transforms = Vec::new(); + + for generic_param in &mut signature.generics.params { + if let GenericParam::Type(type_param) = generic_param { + for bound in &mut type_param.bounds { + if let TypeParamBound::Trait(bound) = bound { + if bound.path.segments.len() == 1 { + let segment = bound.path.segments.first_mut().unwrap(); + + if let PathArguments::Parenthesized(args) = &mut segment.arguments { + if segment.ident == "Fn" + || segment.ident == "FnOnce" + || segment.ident == "FnMut" + { + let mut transforms = Vec::new(); + for input_type in &mut args.inputs { + match assoc_type_conversions.parse_type_path(input_type) { + Ok(ret_type) => { + transforms.push(ret_type); + } + Err(TransformError::UnconvertibleAssocType(span)) => { + return Err(( + span, + MethodParseError::UnconvertibleAssocType, + )); + } + Err(TransformError::AssocTypeInUnsupportedType(span)) => { + return Err(( + span, + MethodParseError::UnconvertibleAssocTypeInFnInput, + )); + } + } + } + if transforms.iter().any(|t| !matches!(t, TypeTransform::NoOp)) { + type_param_transforms.insert(&type_param.ident, transforms); + } + } + } + } + if let Some(path) = find_in_path(&bound.path, &AssocTypeMatcher) { + return Err(( + path.span(), + MethodParseError::UnconvertibleAssocTypeInTraitBound, + )); + } + } + } + } + } + + for input in &signature.inputs { + if let FnArg::Typed(pattype) = input { + if let Type::Path(path) = &*pattype.ty { + if let Some(ident) = path.path.get_ident() { + input_transforms.push(type_param_transforms.get(ident).map(|x| (*x).clone())); + continue; + } + } + } + input_transforms.push(None); + } + + let return_type = match &mut signature.output { + ReturnType::Type(_, og_type) => match assoc_type_conversions.parse_type_path(og_type) { + Ok(ret_type) => ret_type, + Err(TransformError::UnconvertibleAssocType(span)) => { + return Err((span, MethodParseError::UnconvertibleAssocType)); + } + Err(TransformError::AssocTypeInUnsupportedType(span)) => { + return Err((span, MethodParseError::AssocTypeInUnsupportedReturnType)); + } + }, + ReturnType::Default => TypeTransform::NoOp, + }; + Ok(SignatureChanges { + return_type, + inputs: input_transforms, + }) +} + +struct AssocTypeConversions<'a>(&'a HashMap<&'a Ident, &'a Type>); + +enum TransformError { + UnconvertibleAssocType(Span), + AssocTypeInUnsupportedType(Span), +} + +impl AssocTypeConversions<'_> { + fn parse_type_path(&self, type_: &mut Type) -> Result<TypeTransform, TransformError> { + let assoc_span = match find_in_type(type_, &AssocTypeMatcher) { + Some(path) => path.span(), + None => return Ok(TypeTransform::NoOp), + }; + + if let Type::Path(TypePath { path, qself: None }) = type_ { + let ident = &path.segments.first().unwrap().ident; + + // TODO: support &mut dyn Iterator<Item = Self::A> + // conversion to Box<dyn Iterator<Item = Whatever>> via .map(Into::into) + + if ident == "Self" && path.segments.len() == 2 { + let ident = &path.segments.last().unwrap().ident; + *type_ = (*self + .0 + .get(&ident) + .ok_or_else(|| TransformError::UnconvertibleAssocType(ident.span()))?) + .clone(); + return Ok(TypeTransform::Into); + } else if ident == "Option" && path.segments.len() == 1 { + let first_seg = path.segments.first_mut().unwrap(); + + if let Some(args) = first_seg.arguments.get_as_mut() { + if args.args.len() == 1 { + if let Some(generic_type) = args.args.first_mut().unwrap().get_as_mut() { + if find_in_type(generic_type, &AssocTypeMatcher).is_some() { + return Ok(TypeTransform::Map( + self.parse_type_path(generic_type)?.into(), + )); + } + } + } + } + } else if ident == "Result" && path.segments.len() == 1 { + let first_seg = path.segments.first_mut().unwrap(); + if let Some(args) = first_seg.arguments.get_as_mut() { + if args.args.len() == 2 { + let mut args_iter = args.args.iter_mut(); + if let (Some(ok_type), Some(err_type)) = ( + args_iter.next().unwrap().get_as_mut(), + args_iter.next().unwrap().get_as_mut(), + ) { + if find_in_type(ok_type, &AssocTypeMatcher).is_some() + || find_in_type(err_type, &AssocTypeMatcher).is_some() + { + return Ok(TypeTransform::Result( + self.parse_type_path(ok_type)?.into(), + self.parse_type_path(err_type)?.into(), + )); + } + } + } + } + } else { + let last_seg = &path.segments.last().unwrap(); + if last_seg.ident == "Result" { + let last_seg = path.segments.last_mut().unwrap(); + if let Some(args) = last_seg.arguments.get_as_mut() { + if args.args.len() == 1 { + if let Some(generic_type) = args.args.first_mut().unwrap().get_as_mut() + { + if find_in_type(generic_type, &AssocTypeMatcher).is_some() { + return Ok(TypeTransform::Map( + self.parse_type_path(generic_type)?.into(), + )); + } + } + } + } + } + } + } + + // the type contains an associated type but we + // don't know how to deal with it so we abort + Err(TransformError::AssocTypeInUnsupportedType(assoc_span)) + } +} + +fn is_non_dispatchable(signature: &Signature) -> bool { + // non-dispatchable: fn example(&self) where Self: Sized; + if let Some(where_clause) = &signature.generics.where_clause { + if where_clause + .predicates + .iter() + .any(bounds_self_and_has_bound_sized) + { + return true; + } + } + + // non-dispatchable: fn example(); + if signature.inputs.is_empty() { + return true; + } + + // non-dispatchable: fn example(arg: Type); + if matches!(signature.inputs.first(), Some(FnArg::Typed(_))) { + return true; + } + + // non-dispatchable: fn example(self); + if matches!( + signature.inputs.first(), + Some(FnArg::Receiver(Receiver { + reference: None, + .. + })) + ) { + return true; + } + false +} + +/// Returns true if the bounded type is `Self` and the bounds contain `Sized`. +fn bounds_self_and_has_bound_sized(predicate: &WherePredicate) -> bool { + matches!( + predicate, + WherePredicate::Type(PredicateType { + bounded_ty: Type::Path(TypePath { path, .. }), + bounds, + .. + }) + if path.is_ident("Self") + && trait_bounds(bounds).any(|b| b.path.is_ident("Sized")) + ) +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use quote::{format_ident, quote}; + use syn::{TraitItemMethod, Type}; + + use crate::parse_trait_sig::{ + parse_trait_signature, MethodParseError, SignatureChanges, TypeTransform, + }; + + #[test] + fn ok_void() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self); + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Ok(SignatureChanges { + return_type: TypeTransform::NoOp, + .. + }) + )); + } + + #[test] + fn ok_assoc_type() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self) -> Self::A; + }) + .unwrap(); + + let mut assoc_type_map = HashMap::new(); + let ident = format_ident!("A"); + let dest = Type::Verbatim(quote! {Example}); + assoc_type_map.insert(&ident, &dest); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &assoc_type_map), + Ok(SignatureChanges { + return_type: TypeTransform::Into, + .. + }) + )); + } + + #[test] + fn err_unconvertible_assoc_type() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self) -> Self::A; + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::UnconvertibleAssocType)) + )); + } + + #[test] + fn err_non_dispatchable_assoc_function_no_args() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(); + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::NonDispatchableMethod)) + )); + } + + #[test] + fn err_non_dispatchable_assoc_function_with_args() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(arg: Type); + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::NonDispatchableMethod)) + )); + } + + #[test] + fn err_non_dispatchable_consume_self() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(self); + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::NonDispatchableMethod)) + )); + } + + #[test] + fn err_non_dispatchable_where_self_sized() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self) where Self: Sized; + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::NonDispatchableMethod)) + )); + } + + #[test] + fn err_assoc_type_in_unsupported_return() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self) -> Foo<Self::A>; + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::AssocTypeInUnsupportedReturnType)) + )); + } + + #[test] + fn err_assoc_type_in_unsupported_return_in_opt() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self) -> Option<Foo<Self::A>>; + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::AssocTypeInUnsupportedReturnType)) + )); + } + + #[test] + fn err_assoc_type_in_unsupported_return_in_ok() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self) -> Result<Foo<Self::A>, Error>; + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::AssocTypeInUnsupportedReturnType)) + )); + } + + #[test] + fn err_assoc_type_in_unsupported_return_in_err() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self) -> Result<Ok, Foo<Self::A>>; + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::AssocTypeInUnsupportedReturnType)) + )); + } + + #[test] + fn err_assoc_type_in_input() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self, x: Self::A); + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::AssocTypeInInputs)) + )); + } + + #[test] + fn err_assoc_type_in_input_opt() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self, x: Option<Self::A>); + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::AssocTypeInInputs)) + )); + } + + #[test] + fn err_impl_in_input() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self, arg: Option<impl SomeTrait>); + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::ImplTraitInInputs)) + )); + } + + #[test] + fn err_assoc_type_in_generic() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test<F: Fn(Foo<Self::A>)>(&self, fun: F); + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::UnconvertibleAssocTypeInFnInput)) + )); + } +} diff --git a/src/syn_utils.rs b/src/syn_utils.rs new file mode 100644 index 0000000..4588186 --- /dev/null +++ b/src/syn_utils.rs @@ -0,0 +1,97 @@ +use syn::{ + punctuated::Punctuated, GenericArgument, Path, PathArguments, ReturnType, TraitBound, Type, + TypeParamBound, +}; + +pub trait TypeMatcher<T> { + fn match_type<'a>(&self, t: &'a Type) -> Option<&'a T> { + None + } + fn match_path<'a>(&self, path: &'a Path) -> Option<&'a T> { + None + } +} + +pub fn trait_bounds<T>( + bounds: &Punctuated<TypeParamBound, T>, +) -> impl Iterator<Item = &TraitBound> { + bounds.iter().filter_map(|b| match b { + TypeParamBound::Trait(t) => Some(t), + TypeParamBound::Lifetime(_) => None, + }) +} + +pub fn find_in_type<'a, T>(t: &'a Type, matcher: &dyn TypeMatcher<T>) -> Option<&'a T> { + if let Some(ret) = matcher.match_type(t) { + return Some(ret); + } + match t { + Type::Array(array) => find_in_type(&array.elem, matcher), + Type::BareFn(fun) => { + for input in &fun.inputs { + if let Some(ret) = find_in_type(&input.ty, matcher) { + return Some(ret); + } + } + if let ReturnType::Type(_, t) = &fun.output { + return find_in_type(t, matcher); + } + None + } + Type::Group(group) => find_in_type(&group.elem, matcher), + Type::ImplTrait(impltrait) => { + for bound in &impltrait.bounds { + if let TypeParamBound::Trait(bound) = bound { + if let Some(ret) = find_in_path(&bound.path, matcher) { + return Some(ret); + } + } + } + None + } + Type::Infer(_) => None, + Type::Macro(_) => None, + Type::Paren(paren) => find_in_type(&paren.elem, matcher), + Type::Path(path) => find_in_path(&path.path, matcher), + Type::Ptr(ptr) => find_in_type(&ptr.elem, matcher), + Type::Reference(reference) => find_in_type(&reference.elem, matcher), + Type::Slice(slice) => find_in_type(&slice.elem, matcher), + Type::TraitObject(traitobj) => { + for bound in &traitobj.bounds { + if let TypeParamBound::Trait(bound) = bound { + if let Some(ret) = find_in_path(&bound.path, matcher) { + return Some(ret); + } + } + } + None + } + Type::Tuple(tuple) => { + for elem in &tuple.elems { + if let Some(ret) = find_in_type(elem, matcher) { + return Some(ret); + } + } + None + } + _other => None, + } +} + +pub fn find_in_path<'a, T>(path: &'a Path, matcher: &dyn TypeMatcher<T>) -> Option<&'a T> { + if let Some(ret) = matcher.match_path(path) { + return Some(ret); + } + for segment in &path.segments { + if let PathArguments::AngleBracketed(args) = &segment.arguments { + for arg in &args.args { + if let GenericArgument::Type(t) = arg { + if let Some(ret) = find_in_type(t, matcher) { + return Some(ret); + } + } + } + } + } + None +} |