diff options
Diffstat (limited to 'src/parse_trait_sig.rs')
-rw-r--r-- | src/parse_trait_sig.rs | 505 |
1 files changed, 505 insertions, 0 deletions
diff --git a/src/parse_trait_sig.rs b/src/parse_trait_sig.rs new file mode 100644 index 0000000..55a3214 --- /dev/null +++ b/src/parse_trait_sig.rs @@ -0,0 +1,505 @@ +use std::collections::HashMap; + +use proc_macro2::Span; +use syn::{ + spanned::Spanned, FnArg, Ident, PathArguments, PredicateType, Receiver, ReturnType, Type, + TypePath, WherePredicate, +}; +use syn::{GenericParam, Signature, TypeImplTrait, TypeParamBound}; + +use crate::syn_utils::{find_in_path, find_in_type, trait_bounds, TypeMatcher}; +use crate::{As, AssocTypeMatcher}; + +#[derive(Debug, Clone)] +pub enum TypeTransform { + NoOp, + Into, + Map(Box<TypeTransform>), + Result(Box<TypeTransform>, Box<TypeTransform>), +} + +#[derive(Debug)] +pub enum MethodParseError { + NonDispatchableMethod, + AssocTypeInInputs, + ImplTraitInInputs, + AssocTypeInUnsupportedReturnType, + UnconvertibleAssocTypeInFnInput, + UnconvertibleAssocTypeInTraitBound, + UnconvertibleAssocType, +} + +struct ImplTraitMatcher; +impl TypeMatcher<TypeImplTrait> for ImplTraitMatcher { + fn match_type<'a>(&self, t: &'a Type) -> Option<&'a TypeImplTrait> { + if let Type::ImplTrait(impltrait) = t { + Some(impltrait) + } else { + None + } + } +} + +pub struct SignatureChanges { + pub return_type: TypeTransform, + pub inputs: Vec<Option<Vec<TypeTransform>>>, +} + +pub fn parse_trait_signature( + signature: &mut Signature, + assoc_type_conversions: &HashMap<&Ident, &Type>, +) -> Result<SignatureChanges, (Span, MethodParseError)> { + let assoc_type_conversions = AssocTypeConversions(assoc_type_conversions); + + if is_non_dispatchable(signature) { + return Err((signature.span(), MethodParseError::NonDispatchableMethod)); + } + + // provide better error messages for associated types in params + for input in &signature.inputs { + if let FnArg::Typed(pattype) = input { + if find_in_type(&pattype.ty, &AssocTypeMatcher).is_some() { + return Err((pattype.ty.span(), MethodParseError::AssocTypeInInputs)); + } + if let Some(impl_trait) = find_in_type(&pattype.ty, &ImplTraitMatcher) { + return Err((impl_trait.span(), MethodParseError::ImplTraitInInputs)); + } + } + } + + let mut type_param_transforms = HashMap::new(); + let mut input_transforms = Vec::new(); + + for generic_param in &mut signature.generics.params { + if let GenericParam::Type(type_param) = generic_param { + for bound in &mut type_param.bounds { + if let TypeParamBound::Trait(bound) = bound { + if bound.path.segments.len() == 1 { + let segment = bound.path.segments.first_mut().unwrap(); + + if let PathArguments::Parenthesized(args) = &mut segment.arguments { + if segment.ident == "Fn" + || segment.ident == "FnOnce" + || segment.ident == "FnMut" + { + let mut transforms = Vec::new(); + for input_type in &mut args.inputs { + match assoc_type_conversions.parse_type_path(input_type) { + Ok(ret_type) => { + transforms.push(ret_type); + } + Err(TransformError::UnconvertibleAssocType(span)) => { + return Err(( + span, + MethodParseError::UnconvertibleAssocType, + )); + } + Err(TransformError::AssocTypeInUnsupportedType(span)) => { + return Err(( + span, + MethodParseError::UnconvertibleAssocTypeInFnInput, + )); + } + } + } + if transforms.iter().any(|t| !matches!(t, TypeTransform::NoOp)) { + type_param_transforms.insert(&type_param.ident, transforms); + } + } + } + } + if let Some(path) = find_in_path(&bound.path, &AssocTypeMatcher) { + return Err(( + path.span(), + MethodParseError::UnconvertibleAssocTypeInTraitBound, + )); + } + } + } + } + } + + for input in &signature.inputs { + if let FnArg::Typed(pattype) = input { + if let Type::Path(path) = &*pattype.ty { + if let Some(ident) = path.path.get_ident() { + input_transforms.push(type_param_transforms.get(ident).map(|x| (*x).clone())); + continue; + } + } + } + input_transforms.push(None); + } + + let return_type = match &mut signature.output { + ReturnType::Type(_, og_type) => match assoc_type_conversions.parse_type_path(og_type) { + Ok(ret_type) => ret_type, + Err(TransformError::UnconvertibleAssocType(span)) => { + return Err((span, MethodParseError::UnconvertibleAssocType)); + } + Err(TransformError::AssocTypeInUnsupportedType(span)) => { + return Err((span, MethodParseError::AssocTypeInUnsupportedReturnType)); + } + }, + ReturnType::Default => TypeTransform::NoOp, + }; + Ok(SignatureChanges { + return_type, + inputs: input_transforms, + }) +} + +struct AssocTypeConversions<'a>(&'a HashMap<&'a Ident, &'a Type>); + +enum TransformError { + UnconvertibleAssocType(Span), + AssocTypeInUnsupportedType(Span), +} + +impl AssocTypeConversions<'_> { + fn parse_type_path(&self, type_: &mut Type) -> Result<TypeTransform, TransformError> { + let assoc_span = match find_in_type(type_, &AssocTypeMatcher) { + Some(path) => path.span(), + None => return Ok(TypeTransform::NoOp), + }; + + if let Type::Path(TypePath { path, qself: None }) = type_ { + let ident = &path.segments.first().unwrap().ident; + + // TODO: support &mut dyn Iterator<Item = Self::A> + // conversion to Box<dyn Iterator<Item = Whatever>> via .map(Into::into) + + if ident == "Self" && path.segments.len() == 2 { + let ident = &path.segments.last().unwrap().ident; + *type_ = (*self + .0 + .get(&ident) + .ok_or_else(|| TransformError::UnconvertibleAssocType(ident.span()))?) + .clone(); + return Ok(TypeTransform::Into); + } else if ident == "Option" && path.segments.len() == 1 { + let first_seg = path.segments.first_mut().unwrap(); + + if let Some(args) = first_seg.arguments.get_as_mut() { + if args.args.len() == 1 { + if let Some(generic_type) = args.args.first_mut().unwrap().get_as_mut() { + if find_in_type(generic_type, &AssocTypeMatcher).is_some() { + return Ok(TypeTransform::Map( + self.parse_type_path(generic_type)?.into(), + )); + } + } + } + } + } else if ident == "Result" && path.segments.len() == 1 { + let first_seg = path.segments.first_mut().unwrap(); + if let Some(args) = first_seg.arguments.get_as_mut() { + if args.args.len() == 2 { + let mut args_iter = args.args.iter_mut(); + if let (Some(ok_type), Some(err_type)) = ( + args_iter.next().unwrap().get_as_mut(), + args_iter.next().unwrap().get_as_mut(), + ) { + if find_in_type(ok_type, &AssocTypeMatcher).is_some() + || find_in_type(err_type, &AssocTypeMatcher).is_some() + { + return Ok(TypeTransform::Result( + self.parse_type_path(ok_type)?.into(), + self.parse_type_path(err_type)?.into(), + )); + } + } + } + } + } else { + let last_seg = &path.segments.last().unwrap(); + if last_seg.ident == "Result" { + let last_seg = path.segments.last_mut().unwrap(); + if let Some(args) = last_seg.arguments.get_as_mut() { + if args.args.len() == 1 { + if let Some(generic_type) = args.args.first_mut().unwrap().get_as_mut() + { + if find_in_type(generic_type, &AssocTypeMatcher).is_some() { + return Ok(TypeTransform::Map( + self.parse_type_path(generic_type)?.into(), + )); + } + } + } + } + } + } + } + + // the type contains an associated type but we + // don't know how to deal with it so we abort + Err(TransformError::AssocTypeInUnsupportedType(assoc_span)) + } +} + +fn is_non_dispatchable(signature: &Signature) -> bool { + // non-dispatchable: fn example(&self) where Self: Sized; + if let Some(where_clause) = &signature.generics.where_clause { + if where_clause + .predicates + .iter() + .any(bounds_self_and_has_bound_sized) + { + return true; + } + } + + // non-dispatchable: fn example(); + if signature.inputs.is_empty() { + return true; + } + + // non-dispatchable: fn example(arg: Type); + if matches!(signature.inputs.first(), Some(FnArg::Typed(_))) { + return true; + } + + // non-dispatchable: fn example(self); + if matches!( + signature.inputs.first(), + Some(FnArg::Receiver(Receiver { + reference: None, + .. + })) + ) { + return true; + } + false +} + +/// Returns true if the bounded type is `Self` and the bounds contain `Sized`. +fn bounds_self_and_has_bound_sized(predicate: &WherePredicate) -> bool { + matches!( + predicate, + WherePredicate::Type(PredicateType { + bounded_ty: Type::Path(TypePath { path, .. }), + bounds, + .. + }) + if path.is_ident("Self") + && trait_bounds(bounds).any(|b| b.path.is_ident("Sized")) + ) +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use quote::{format_ident, quote}; + use syn::{TraitItemMethod, Type}; + + use crate::parse_trait_sig::{ + parse_trait_signature, MethodParseError, SignatureChanges, TypeTransform, + }; + + #[test] + fn ok_void() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self); + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Ok(SignatureChanges { + return_type: TypeTransform::NoOp, + .. + }) + )); + } + + #[test] + fn ok_assoc_type() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self) -> Self::A; + }) + .unwrap(); + + let mut assoc_type_map = HashMap::new(); + let ident = format_ident!("A"); + let dest = Type::Verbatim(quote! {Example}); + assoc_type_map.insert(&ident, &dest); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &assoc_type_map), + Ok(SignatureChanges { + return_type: TypeTransform::Into, + .. + }) + )); + } + + #[test] + fn err_unconvertible_assoc_type() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self) -> Self::A; + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::UnconvertibleAssocType)) + )); + } + + #[test] + fn err_non_dispatchable_assoc_function_no_args() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(); + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::NonDispatchableMethod)) + )); + } + + #[test] + fn err_non_dispatchable_assoc_function_with_args() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(arg: Type); + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::NonDispatchableMethod)) + )); + } + + #[test] + fn err_non_dispatchable_consume_self() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(self); + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::NonDispatchableMethod)) + )); + } + + #[test] + fn err_non_dispatchable_where_self_sized() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self) where Self: Sized; + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::NonDispatchableMethod)) + )); + } + + #[test] + fn err_assoc_type_in_unsupported_return() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self) -> Foo<Self::A>; + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::AssocTypeInUnsupportedReturnType)) + )); + } + + #[test] + fn err_assoc_type_in_unsupported_return_in_opt() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self) -> Option<Foo<Self::A>>; + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::AssocTypeInUnsupportedReturnType)) + )); + } + + #[test] + fn err_assoc_type_in_unsupported_return_in_ok() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self) -> Result<Foo<Self::A>, Error>; + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::AssocTypeInUnsupportedReturnType)) + )); + } + + #[test] + fn err_assoc_type_in_unsupported_return_in_err() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self) -> Result<Ok, Foo<Self::A>>; + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::AssocTypeInUnsupportedReturnType)) + )); + } + + #[test] + fn err_assoc_type_in_input() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self, x: Self::A); + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::AssocTypeInInputs)) + )); + } + + #[test] + fn err_assoc_type_in_input_opt() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self, x: Option<Self::A>); + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::AssocTypeInInputs)) + )); + } + + #[test] + fn err_impl_in_input() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test(&self, arg: Option<impl SomeTrait>); + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::ImplTraitInInputs)) + )); + } + + #[test] + fn err_assoc_type_in_generic() { + let mut type1: TraitItemMethod = syn::parse2(quote! { + fn test<F: Fn(Foo<Self::A>)>(&self, fun: F); + }) + .unwrap(); + + assert!(matches!( + parse_trait_signature(&mut type1.sig, &Default::default()), + Err((_, MethodParseError::UnconvertibleAssocTypeInFnInput)) + )); + } +} |