diff options
-rw-r--r-- | src/transform.rs | 115 | ||||
-rw-r--r-- | tests/tests.rs | 4 |
2 files changed, 80 insertions, 39 deletions
diff --git a/src/transform.rs b/src/transform.rs index 6640cf3..676cfd8 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -2,8 +2,8 @@ use std::collections::HashMap; use proc_macro2::Span; use syn::{ - spanned::Spanned, GenericArgument, Generics, Ident, PathArguments, Type, TypeParamBound, - TypePath, + spanned::Spanned, GenericArgument, Generics, Ident, PathArguments, TraitBound, Type, + TypeParamBound, TypePath, WherePredicate, }; use crate::{ @@ -112,51 +112,88 @@ pub fn dynamize_function_bounds( for type_param in generics.type_params_mut() { for bound in &mut type_param.bounds { if let TypeParamBound::Trait(bound) = bound { - if bound.path.segments.len() == 1 { - let segment = bound.path.segments.first_mut().unwrap(); + dynamize_trait_bound( + bound, + assoc_type_conversions, + &type_param.ident, + &mut type_param_transforms, + )?; + } + } + } - if let PathArguments::Parenthesized(args) = &mut segment.arguments { - if segment.ident == "Fn" - || segment.ident == "FnOnce" - || segment.ident == "FnMut" - { - let mut transforms = Vec::new(); - for input_type in &mut args.inputs { - match assoc_type_conversions.parse_type_path(input_type) { - Ok(ret_type) => { - transforms.push(ret_type); - } - Err(TransformError::UnconvertibleAssocType(span)) => { - return Err(( - span, - MethodParseError::UnconvertibleAssocType, - )); - } - Err(TransformError::AssocTypeInUnsupportedType(span)) => { - return Err(( - span, - MethodParseError::UnconvertibleAssocTypeInFnInput, - )); - } - } - } - if transforms.iter().any(|t| !matches!(t, TypeTransform::NoOp)) { - type_param_transforms.insert(type_param.ident.clone(), transforms); + if let Some(where_clause) = &mut generics.where_clause { + for predicate in &mut where_clause.predicates { + if let WherePredicate::Type(predicate_type) = predicate { + if let Type::Path(path) = &mut predicate_type.bounded_ty { + if let Some(ident) = path.path.get_ident() { + for bound in &mut predicate_type.bounds { + if let TypeParamBound::Trait(bound) = bound { + dynamize_trait_bound( + bound, + assoc_type_conversions, + ident, + &mut type_param_transforms, + )?; } } + continue; } } - if let Some(path) = iter_path(&bound.path) - .filter_map(filter_map_assoc_paths) - .next() - { - return Err(( - path.span(), - MethodParseError::UnconvertibleAssocTypeInTraitBound, - )); + + // TODO: return error if predicate_type.bounded_ty contains associated type + + for bound in &mut predicate_type.bounds { + if let TypeParamBound::Trait(_bound) = bound { + // TODO: return error if bound.path contains associated type + } } } } } + Ok(type_param_transforms) } + +fn dynamize_trait_bound( + bound: &mut TraitBound, + assoc_type_conversions: &AssocTypeConversions, + type_ident: &Ident, + type_param_transforms: &mut HashMap<Ident, Vec<TypeTransform>>, +) -> Result<(), (Span, MethodParseError)> { + if bound.path.segments.len() == 1 { + let segment = bound.path.segments.first_mut().unwrap(); + + if let PathArguments::Parenthesized(args) = &mut segment.arguments { + if segment.ident == "Fn" || segment.ident == "FnOnce" || segment.ident == "FnMut" { + let mut transforms = Vec::new(); + for input_type in &mut args.inputs { + match assoc_type_conversions.parse_type_path(input_type) { + Ok(ret_type) => { + transforms.push(ret_type); + } + Err(TransformError::UnconvertibleAssocType(span)) => { + return Err((span, MethodParseError::UnconvertibleAssocType)); + } + Err(TransformError::AssocTypeInUnsupportedType(span)) => { + return Err((span, MethodParseError::UnconvertibleAssocTypeInFnInput)); + } + } + } + if transforms.iter().any(|t| !matches!(t, TypeTransform::NoOp)) { + type_param_transforms.insert(type_ident.clone(), transforms); + } + } + } + } + if let Some(path) = iter_path(&bound.path) + .filter_map(filter_map_assoc_paths) + .next() + { + return Err(( + path.span(), + MethodParseError::UnconvertibleAssocTypeInTraitBound, + )); + } + Ok(()) +} diff --git a/tests/tests.rs b/tests/tests.rs index 448af64..3139185 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -146,6 +146,10 @@ trait TraitWithCallback { type A: Into<String>; fn fun_with_callback<F: Fn(Self::A)>(&self, a: F); + fn fun_with_callback0<G>(&self, a: G) + where + G: Fn(Self::A); + fn fun_with_callback1<X: Fn(Option<Self::A>)>(&self, a: X); fn fun_with_callback2<Y: Fn(i32, Option<Self::A>, String) -> bool>(&self, a: Y); |