diff options
Diffstat (limited to 'src/parse_trait_sig.rs')
-rw-r--r-- | src/parse_trait_sig.rs | 169 |
1 files changed, 15 insertions, 154 deletions
diff --git a/src/parse_trait_sig.rs b/src/parse_trait_sig.rs index 55a3214..0078c93 100644 --- a/src/parse_trait_sig.rs +++ b/src/parse_trait_sig.rs @@ -1,14 +1,12 @@ -use std::collections::HashMap; - use proc_macro2::Span; use syn::{ - spanned::Spanned, FnArg, Ident, PathArguments, PredicateType, Receiver, ReturnType, Type, - TypePath, WherePredicate, + spanned::Spanned, FnArg, PredicateType, Receiver, ReturnType, Type, TypePath, WherePredicate, }; -use syn::{GenericParam, Signature, TypeImplTrait, TypeParamBound}; +use syn::{Signature, TypeImplTrait}; -use crate::syn_utils::{find_in_path, find_in_type, trait_bounds, TypeMatcher}; -use crate::{As, AssocTypeMatcher}; +use crate::syn_utils::{find_in_type, trait_bounds, TypeMatcher}; +use crate::transform::{dynamize_function_bounds, AssocTypeConversions, TransformError}; +use crate::AssocTypeMatcher; #[derive(Debug, Clone)] pub enum TypeTransform { @@ -47,10 +45,8 @@ pub struct SignatureChanges { pub fn parse_trait_signature( signature: &mut Signature, - assoc_type_conversions: &HashMap<&Ident, &Type>, + assoc_type_conversions: &AssocTypeConversions, ) -> Result<SignatureChanges, (Span, MethodParseError)> { - let assoc_type_conversions = AssocTypeConversions(assoc_type_conversions); - if is_non_dispatchable(signature) { return Err((signature.span(), MethodParseError::NonDispatchableMethod)); } @@ -67,58 +63,10 @@ pub fn parse_trait_signature( } } - let mut type_param_transforms = HashMap::new(); + let type_param_transforms = + dynamize_function_bounds(&mut signature.generics, assoc_type_conversions)?; let mut input_transforms = Vec::new(); - for generic_param in &mut signature.generics.params { - if let GenericParam::Type(type_param) = generic_param { - for bound in &mut type_param.bounds { - if let TypeParamBound::Trait(bound) = bound { - if bound.path.segments.len() == 1 { - let segment = bound.path.segments.first_mut().unwrap(); - - if let PathArguments::Parenthesized(args) = &mut segment.arguments { - if segment.ident == "Fn" - || segment.ident == "FnOnce" - || segment.ident == "FnMut" - { - let mut transforms = Vec::new(); - for input_type in &mut args.inputs { - match assoc_type_conversions.parse_type_path(input_type) { - Ok(ret_type) => { - transforms.push(ret_type); - } - Err(TransformError::UnconvertibleAssocType(span)) => { - return Err(( - span, - MethodParseError::UnconvertibleAssocType, - )); - } - Err(TransformError::AssocTypeInUnsupportedType(span)) => { - return Err(( - span, - MethodParseError::UnconvertibleAssocTypeInFnInput, - )); - } - } - } - if transforms.iter().any(|t| !matches!(t, TypeTransform::NoOp)) { - type_param_transforms.insert(&type_param.ident, transforms); - } - } - } - } - if let Some(path) = find_in_path(&bound.path, &AssocTypeMatcher) { - return Err(( - path.span(), - MethodParseError::UnconvertibleAssocTypeInTraitBound, - )); - } - } - } - } - } - for input in &signature.inputs { if let FnArg::Typed(pattype) = input { if let Type::Path(path) = &*pattype.ty { @@ -149,94 +97,6 @@ pub fn parse_trait_signature( }) } -struct AssocTypeConversions<'a>(&'a HashMap<&'a Ident, &'a Type>); - -enum TransformError { - UnconvertibleAssocType(Span), - AssocTypeInUnsupportedType(Span), -} - -impl AssocTypeConversions<'_> { - fn parse_type_path(&self, type_: &mut Type) -> Result<TypeTransform, TransformError> { - let assoc_span = match find_in_type(type_, &AssocTypeMatcher) { - Some(path) => path.span(), - None => return Ok(TypeTransform::NoOp), - }; - - if let Type::Path(TypePath { path, qself: None }) = type_ { - let ident = &path.segments.first().unwrap().ident; - - // TODO: support &mut dyn Iterator<Item = Self::A> - // conversion to Box<dyn Iterator<Item = Whatever>> via .map(Into::into) - - if ident == "Self" && path.segments.len() == 2 { - let ident = &path.segments.last().unwrap().ident; - *type_ = (*self - .0 - .get(&ident) - .ok_or_else(|| TransformError::UnconvertibleAssocType(ident.span()))?) - .clone(); - return Ok(TypeTransform::Into); - } else if ident == "Option" && path.segments.len() == 1 { - let first_seg = path.segments.first_mut().unwrap(); - - if let Some(args) = first_seg.arguments.get_as_mut() { - if args.args.len() == 1 { - if let Some(generic_type) = args.args.first_mut().unwrap().get_as_mut() { - if find_in_type(generic_type, &AssocTypeMatcher).is_some() { - return Ok(TypeTransform::Map( - self.parse_type_path(generic_type)?.into(), - )); - } - } - } - } - } else if ident == "Result" && path.segments.len() == 1 { - let first_seg = path.segments.first_mut().unwrap(); - if let Some(args) = first_seg.arguments.get_as_mut() { - if args.args.len() == 2 { - let mut args_iter = args.args.iter_mut(); - if let (Some(ok_type), Some(err_type)) = ( - args_iter.next().unwrap().get_as_mut(), - args_iter.next().unwrap().get_as_mut(), - ) { - if find_in_type(ok_type, &AssocTypeMatcher).is_some() - || find_in_type(err_type, &AssocTypeMatcher).is_some() - { - return Ok(TypeTransform::Result( - self.parse_type_path(ok_type)?.into(), - self.parse_type_path(err_type)?.into(), - )); - } - } - } - } - } else { - let last_seg = &path.segments.last().unwrap(); - if last_seg.ident == "Result" { - let last_seg = path.segments.last_mut().unwrap(); - if let Some(args) = last_seg.arguments.get_as_mut() { - if args.args.len() == 1 { - if let Some(generic_type) = args.args.first_mut().unwrap().get_as_mut() - { - if find_in_type(generic_type, &AssocTypeMatcher).is_some() { - return Ok(TypeTransform::Map( - self.parse_type_path(generic_type)?.into(), - )); - } - } - } - } - } - } - } - - // the type contains an associated type but we - // don't know how to deal with it so we abort - Err(TransformError::AssocTypeInUnsupportedType(assoc_span)) - } -} - fn is_non_dispatchable(signature: &Signature) -> bool { // non-dispatchable: fn example(&self) where Self: Sized; if let Some(where_clause) = &signature.generics.where_clause { @@ -288,13 +148,14 @@ fn bounds_self_and_has_bound_sized(predicate: &WherePredicate) -> bool { #[cfg(test)] mod tests { - use std::collections::HashMap; - use quote::{format_ident, quote}; use syn::{TraitItemMethod, Type}; - use crate::parse_trait_sig::{ - parse_trait_signature, MethodParseError, SignatureChanges, TypeTransform, + use crate::{ + parse_trait_sig::{ + parse_trait_signature, MethodParseError, SignatureChanges, TypeTransform, + }, + transform::AssocTypeConversions, }; #[test] @@ -320,10 +181,10 @@ mod tests { }) .unwrap(); - let mut assoc_type_map = HashMap::new(); + let mut assoc_type_map = AssocTypeConversions::default(); let ident = format_ident!("A"); let dest = Type::Verbatim(quote! {Example}); - assoc_type_map.insert(&ident, &dest); + assoc_type_map.0.insert(&ident, &dest); assert!(matches!( parse_trait_signature(&mut type1.sig, &assoc_type_map), |