use std::collections::HashMap; use proc_macro2::Span; use syn::{ spanned::Spanned, FnArg, Ident, PathArguments, PredicateType, Receiver, ReturnType, Type, TypePath, WherePredicate, }; use syn::{GenericParam, Signature, TypeImplTrait, TypeParamBound}; use crate::syn_utils::{find_in_path, find_in_type, trait_bounds, TypeMatcher}; use crate::{As, AssocTypeMatcher}; #[derive(Debug, Clone)] pub enum TypeTransform { NoOp, Into, Map(Box), Result(Box, Box), } #[derive(Debug)] pub enum MethodParseError { NonDispatchableMethod, AssocTypeInInputs, ImplTraitInInputs, AssocTypeInUnsupportedReturnType, UnconvertibleAssocTypeInFnInput, UnconvertibleAssocTypeInTraitBound, UnconvertibleAssocType, } struct ImplTraitMatcher; impl TypeMatcher for ImplTraitMatcher { fn match_type<'a>(&self, t: &'a Type) -> Option<&'a TypeImplTrait> { if let Type::ImplTrait(impltrait) = t { Some(impltrait) } else { None } } } pub struct SignatureChanges { pub return_type: TypeTransform, pub inputs: Vec>>, } pub fn parse_trait_signature( signature: &mut Signature, assoc_type_conversions: &HashMap<&Ident, &Type>, ) -> Result { let assoc_type_conversions = AssocTypeConversions(assoc_type_conversions); if is_non_dispatchable(signature) { return Err((signature.span(), MethodParseError::NonDispatchableMethod)); } // provide better error messages for associated types in params for input in &signature.inputs { if let FnArg::Typed(pattype) = input { if find_in_type(&pattype.ty, &AssocTypeMatcher).is_some() { return Err((pattype.ty.span(), MethodParseError::AssocTypeInInputs)); } if let Some(impl_trait) = find_in_type(&pattype.ty, &ImplTraitMatcher) { return Err((impl_trait.span(), MethodParseError::ImplTraitInInputs)); } } } let mut type_param_transforms = HashMap::new(); let mut input_transforms = Vec::new(); for generic_param in &mut signature.generics.params { if let GenericParam::Type(type_param) = generic_param { for bound in &mut type_param.bounds { if let TypeParamBound::Trait(bound) = bound { if bound.path.segments.len() == 1 { let segment = bound.path.segments.first_mut().unwrap(); if let PathArguments::Parenthesized(args) = &mut segment.arguments { if segment.ident == "Fn" || segment.ident == "FnOnce" || segment.ident == "FnMut" { let mut transforms = Vec::new(); for input_type in &mut args.inputs { match assoc_type_conversions.parse_type_path(input_type) { Ok(ret_type) => { transforms.push(ret_type); } Err(TransformError::UnconvertibleAssocType(span)) => { return Err(( span, MethodParseError::UnconvertibleAssocType, )); } Err(TransformError::AssocTypeInUnsupportedType(span)) => { return Err(( span, MethodParseError::UnconvertibleAssocTypeInFnInput, )); } } } if transforms.iter().any(|t| !matches!(t, TypeTransform::NoOp)) { type_param_transforms.insert(&type_param.ident, transforms); } } } } if let Some(path) = find_in_path(&bound.path, &AssocTypeMatcher) { return Err(( path.span(), MethodParseError::UnconvertibleAssocTypeInTraitBound, )); } } } } } for input in &signature.inputs { if let FnArg::Typed(pattype) = input { if let Type::Path(path) = &*pattype.ty { if let Some(ident) = path.path.get_ident() { input_transforms.push(type_param_transforms.get(ident).map(|x| (*x).clone())); continue; } } } input_transforms.push(None); } let return_type = match &mut signature.output { ReturnType::Type(_, og_type) => match assoc_type_conversions.parse_type_path(og_type) { Ok(ret_type) => ret_type, Err(TransformError::UnconvertibleAssocType(span)) => { return Err((span, MethodParseError::UnconvertibleAssocType)); } Err(TransformError::AssocTypeInUnsupportedType(span)) => { return Err((span, MethodParseError::AssocTypeInUnsupportedReturnType)); } }, ReturnType::Default => TypeTransform::NoOp, }; Ok(SignatureChanges { return_type, inputs: input_transforms, }) } struct AssocTypeConversions<'a>(&'a HashMap<&'a Ident, &'a Type>); enum TransformError { UnconvertibleAssocType(Span), AssocTypeInUnsupportedType(Span), } impl AssocTypeConversions<'_> { fn parse_type_path(&self, type_: &mut Type) -> Result { let assoc_span = match find_in_type(type_, &AssocTypeMatcher) { Some(path) => path.span(), None => return Ok(TypeTransform::NoOp), }; if let Type::Path(TypePath { path, qself: None }) = type_ { let ident = &path.segments.first().unwrap().ident; // TODO: support &mut dyn Iterator // conversion to Box> via .map(Into::into) if ident == "Self" && path.segments.len() == 2 { let ident = &path.segments.last().unwrap().ident; *type_ = (*self .0 .get(&ident) .ok_or_else(|| TransformError::UnconvertibleAssocType(ident.span()))?) .clone(); return Ok(TypeTransform::Into); } else if ident == "Option" && path.segments.len() == 1 { let first_seg = path.segments.first_mut().unwrap(); if let Some(args) = first_seg.arguments.get_as_mut() { if args.args.len() == 1 { if let Some(generic_type) = args.args.first_mut().unwrap().get_as_mut() { if find_in_type(generic_type, &AssocTypeMatcher).is_some() { return Ok(TypeTransform::Map( self.parse_type_path(generic_type)?.into(), )); } } } } } else if ident == "Result" && path.segments.len() == 1 { let first_seg = path.segments.first_mut().unwrap(); if let Some(args) = first_seg.arguments.get_as_mut() { if args.args.len() == 2 { let mut args_iter = args.args.iter_mut(); if let (Some(ok_type), Some(err_type)) = ( args_iter.next().unwrap().get_as_mut(), args_iter.next().unwrap().get_as_mut(), ) { if find_in_type(ok_type, &AssocTypeMatcher).is_some() || find_in_type(err_type, &AssocTypeMatcher).is_some() { return Ok(TypeTransform::Result( self.parse_type_path(ok_type)?.into(), self.parse_type_path(err_type)?.into(), )); } } } } } else { let last_seg = &path.segments.last().unwrap(); if last_seg.ident == "Result" { let last_seg = path.segments.last_mut().unwrap(); if let Some(args) = last_seg.arguments.get_as_mut() { if args.args.len() == 1 { if let Some(generic_type) = args.args.first_mut().unwrap().get_as_mut() { if find_in_type(generic_type, &AssocTypeMatcher).is_some() { return Ok(TypeTransform::Map( self.parse_type_path(generic_type)?.into(), )); } } } } } } } // the type contains an associated type but we // don't know how to deal with it so we abort Err(TransformError::AssocTypeInUnsupportedType(assoc_span)) } } fn is_non_dispatchable(signature: &Signature) -> bool { // non-dispatchable: fn example(&self) where Self: Sized; if let Some(where_clause) = &signature.generics.where_clause { if where_clause .predicates .iter() .any(bounds_self_and_has_bound_sized) { return true; } } // non-dispatchable: fn example(); if signature.inputs.is_empty() { return true; } // non-dispatchable: fn example(arg: Type); if matches!(signature.inputs.first(), Some(FnArg::Typed(_))) { return true; } // non-dispatchable: fn example(self); if matches!( signature.inputs.first(), Some(FnArg::Receiver(Receiver { reference: None, .. })) ) { return true; } false } /// Returns true if the bounded type is `Self` and the bounds contain `Sized`. fn bounds_self_and_has_bound_sized(predicate: &WherePredicate) -> bool { matches!( predicate, WherePredicate::Type(PredicateType { bounded_ty: Type::Path(TypePath { path, .. }), bounds, .. }) if path.is_ident("Self") && trait_bounds(bounds).any(|b| b.path.is_ident("Sized")) ) } #[cfg(test)] mod tests { use std::collections::HashMap; use quote::{format_ident, quote}; use syn::{TraitItemMethod, Type}; use crate::parse_trait_sig::{ parse_trait_signature, MethodParseError, SignatureChanges, TypeTransform, }; #[test] fn ok_void() { let mut type1: TraitItemMethod = syn::parse2(quote! { fn test(&self); }) .unwrap(); assert!(matches!( parse_trait_signature(&mut type1.sig, &Default::default()), Ok(SignatureChanges { return_type: TypeTransform::NoOp, .. }) )); } #[test] fn ok_assoc_type() { let mut type1: TraitItemMethod = syn::parse2(quote! { fn test(&self) -> Self::A; }) .unwrap(); let mut assoc_type_map = HashMap::new(); let ident = format_ident!("A"); let dest = Type::Verbatim(quote! {Example}); assoc_type_map.insert(&ident, &dest); assert!(matches!( parse_trait_signature(&mut type1.sig, &assoc_type_map), Ok(SignatureChanges { return_type: TypeTransform::Into, .. }) )); } #[test] fn err_unconvertible_assoc_type() { let mut type1: TraitItemMethod = syn::parse2(quote! { fn test(&self) -> Self::A; }) .unwrap(); assert!(matches!( parse_trait_signature(&mut type1.sig, &Default::default()), Err((_, MethodParseError::UnconvertibleAssocType)) )); } #[test] fn err_non_dispatchable_assoc_function_no_args() { let mut type1: TraitItemMethod = syn::parse2(quote! { fn test(); }) .unwrap(); assert!(matches!( parse_trait_signature(&mut type1.sig, &Default::default()), Err((_, MethodParseError::NonDispatchableMethod)) )); } #[test] fn err_non_dispatchable_assoc_function_with_args() { let mut type1: TraitItemMethod = syn::parse2(quote! { fn test(arg: Type); }) .unwrap(); assert!(matches!( parse_trait_signature(&mut type1.sig, &Default::default()), Err((_, MethodParseError::NonDispatchableMethod)) )); } #[test] fn err_non_dispatchable_consume_self() { let mut type1: TraitItemMethod = syn::parse2(quote! { fn test(self); }) .unwrap(); assert!(matches!( parse_trait_signature(&mut type1.sig, &Default::default()), Err((_, MethodParseError::NonDispatchableMethod)) )); } #[test] fn err_non_dispatchable_where_self_sized() { let mut type1: TraitItemMethod = syn::parse2(quote! { fn test(&self) where Self: Sized; }) .unwrap(); assert!(matches!( parse_trait_signature(&mut type1.sig, &Default::default()), Err((_, MethodParseError::NonDispatchableMethod)) )); } #[test] fn err_assoc_type_in_unsupported_return() { let mut type1: TraitItemMethod = syn::parse2(quote! { fn test(&self) -> Foo; }) .unwrap(); assert!(matches!( parse_trait_signature(&mut type1.sig, &Default::default()), Err((_, MethodParseError::AssocTypeInUnsupportedReturnType)) )); } #[test] fn err_assoc_type_in_unsupported_return_in_opt() { let mut type1: TraitItemMethod = syn::parse2(quote! { fn test(&self) -> Option>; }) .unwrap(); assert!(matches!( parse_trait_signature(&mut type1.sig, &Default::default()), Err((_, MethodParseError::AssocTypeInUnsupportedReturnType)) )); } #[test] fn err_assoc_type_in_unsupported_return_in_ok() { let mut type1: TraitItemMethod = syn::parse2(quote! { fn test(&self) -> Result, Error>; }) .unwrap(); assert!(matches!( parse_trait_signature(&mut type1.sig, &Default::default()), Err((_, MethodParseError::AssocTypeInUnsupportedReturnType)) )); } #[test] fn err_assoc_type_in_unsupported_return_in_err() { let mut type1: TraitItemMethod = syn::parse2(quote! { fn test(&self) -> Result>; }) .unwrap(); assert!(matches!( parse_trait_signature(&mut type1.sig, &Default::default()), Err((_, MethodParseError::AssocTypeInUnsupportedReturnType)) )); } #[test] fn err_assoc_type_in_input() { let mut type1: TraitItemMethod = syn::parse2(quote! { fn test(&self, x: Self::A); }) .unwrap(); assert!(matches!( parse_trait_signature(&mut type1.sig, &Default::default()), Err((_, MethodParseError::AssocTypeInInputs)) )); } #[test] fn err_assoc_type_in_input_opt() { let mut type1: TraitItemMethod = syn::parse2(quote! { fn test(&self, x: Option); }) .unwrap(); assert!(matches!( parse_trait_signature(&mut type1.sig, &Default::default()), Err((_, MethodParseError::AssocTypeInInputs)) )); } #[test] fn err_impl_in_input() { let mut type1: TraitItemMethod = syn::parse2(quote! { fn test(&self, arg: Option); }) .unwrap(); assert!(matches!( parse_trait_signature(&mut type1.sig, &Default::default()), Err((_, MethodParseError::ImplTraitInInputs)) )); } #[test] fn err_assoc_type_in_generic() { let mut type1: TraitItemMethod = syn::parse2(quote! { fn test)>(&self, fun: F); }) .unwrap(); assert!(matches!( parse_trait_signature(&mut type1.sig, &Default::default()), Err((_, MethodParseError::UnconvertibleAssocTypeInFnInput)) )); } }