diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/transform.rs | 115 | 
1 files changed, 76 insertions, 39 deletions
| diff --git a/src/transform.rs b/src/transform.rs index 6640cf3..676cfd8 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -2,8 +2,8 @@ use std::collections::HashMap;  use proc_macro2::Span;  use syn::{ -    spanned::Spanned, GenericArgument, Generics, Ident, PathArguments, Type, TypeParamBound, -    TypePath, +    spanned::Spanned, GenericArgument, Generics, Ident, PathArguments, TraitBound, Type, +    TypeParamBound, TypePath, WherePredicate,  };  use crate::{ @@ -112,51 +112,88 @@ pub fn dynamize_function_bounds(      for type_param in generics.type_params_mut() {          for bound in &mut type_param.bounds {              if let TypeParamBound::Trait(bound) = bound { -                if bound.path.segments.len() == 1 { -                    let segment = bound.path.segments.first_mut().unwrap(); +                dynamize_trait_bound( +                    bound, +                    assoc_type_conversions, +                    &type_param.ident, +                    &mut type_param_transforms, +                )?; +            } +        } +    } -                    if let PathArguments::Parenthesized(args) = &mut segment.arguments { -                        if segment.ident == "Fn" -                            || segment.ident == "FnOnce" -                            || segment.ident == "FnMut" -                        { -                            let mut transforms = Vec::new(); -                            for input_type in &mut args.inputs { -                                match assoc_type_conversions.parse_type_path(input_type) { -                                    Ok(ret_type) => { -                                        transforms.push(ret_type); -                                    } -                                    Err(TransformError::UnconvertibleAssocType(span)) => { -                                        return Err(( -                                            span, -                                            MethodParseError::UnconvertibleAssocType, -                                        )); -                                    } -                                    Err(TransformError::AssocTypeInUnsupportedType(span)) => { -                                        return Err(( -                                            span, -                                            MethodParseError::UnconvertibleAssocTypeInFnInput, -                                        )); -                                    } -                                } -                            } -                            if transforms.iter().any(|t| !matches!(t, TypeTransform::NoOp)) { -                                type_param_transforms.insert(type_param.ident.clone(), transforms); +    if let Some(where_clause) = &mut generics.where_clause { +        for predicate in &mut where_clause.predicates { +            if let WherePredicate::Type(predicate_type) = predicate { +                if let Type::Path(path) = &mut predicate_type.bounded_ty { +                    if let Some(ident) = path.path.get_ident() { +                        for bound in &mut predicate_type.bounds { +                            if let TypeParamBound::Trait(bound) = bound { +                                dynamize_trait_bound( +                                    bound, +                                    assoc_type_conversions, +                                    ident, +                                    &mut type_param_transforms, +                                )?;                              }                          } +                        continue;                      }                  } -                if let Some(path) = iter_path(&bound.path) -                    .filter_map(filter_map_assoc_paths) -                    .next() -                { -                    return Err(( -                        path.span(), -                        MethodParseError::UnconvertibleAssocTypeInTraitBound, -                    )); + +                // TODO: return error if predicate_type.bounded_ty contains associated type + +                for bound in &mut predicate_type.bounds { +                    if let TypeParamBound::Trait(_bound) = bound { +                        // TODO: return error if bound.path contains associated type +                    }                  }              }          }      } +      Ok(type_param_transforms)  } + +fn dynamize_trait_bound( +    bound: &mut TraitBound, +    assoc_type_conversions: &AssocTypeConversions, +    type_ident: &Ident, +    type_param_transforms: &mut HashMap<Ident, Vec<TypeTransform>>, +) -> Result<(), (Span, MethodParseError)> { +    if bound.path.segments.len() == 1 { +        let segment = bound.path.segments.first_mut().unwrap(); + +        if let PathArguments::Parenthesized(args) = &mut segment.arguments { +            if segment.ident == "Fn" || segment.ident == "FnOnce" || segment.ident == "FnMut" { +                let mut transforms = Vec::new(); +                for input_type in &mut args.inputs { +                    match assoc_type_conversions.parse_type_path(input_type) { +                        Ok(ret_type) => { +                            transforms.push(ret_type); +                        } +                        Err(TransformError::UnconvertibleAssocType(span)) => { +                            return Err((span, MethodParseError::UnconvertibleAssocType)); +                        } +                        Err(TransformError::AssocTypeInUnsupportedType(span)) => { +                            return Err((span, MethodParseError::UnconvertibleAssocTypeInFnInput)); +                        } +                    } +                } +                if transforms.iter().any(|t| !matches!(t, TypeTransform::NoOp)) { +                    type_param_transforms.insert(type_ident.clone(), transforms); +                } +            } +        } +    } +    if let Some(path) = iter_path(&bound.path) +        .filter_map(filter_map_assoc_paths) +        .next() +    { +        return Err(( +            path.span(), +            MethodParseError::UnconvertibleAssocTypeInTraitBound, +        )); +    } +    Ok(()) +} | 
