diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/lib.rs | 35 | ||||
-rw-r--r-- | src/parse_assoc_type.rs | 6 | ||||
-rw-r--r-- | src/parse_trait_sig.rs | 20 | ||||
-rw-r--r-- | src/syn_utils.rs | 181 | ||||
-rw-r--r-- | src/transform.rs | 19 |
5 files changed, 142 insertions, 119 deletions
@@ -35,7 +35,7 @@ use syn::TypeParam; use syn::TypeParamBound; use syn::TypePath; use syn::Visibility; -use syn_utils::TypeMatcher; +use syn_utils::TypeOrPath; use crate::parse_assoc_type::parse_assoc_type; use crate::parse_assoc_type::AssocTypeParseError; @@ -43,7 +43,7 @@ use crate::parse_trait_sig::parse_trait_signature; use crate::parse_trait_sig::MethodParseError; use crate::parse_trait_sig::SignatureChanges; use crate::parse_trait_sig::TypeTransform; -use crate::syn_utils::find_in_path; +use crate::syn_utils::iter_path; use crate::syn_utils::trait_bounds; use crate::transform::AssocTypeConversions; @@ -181,7 +181,8 @@ pub fn dynamize(_attr: TokenStream, input: TokenStream) -> TokenStream { for type_param in dyn_trait.generics.type_params() { generic_map.insert(type_param.ident.clone(), type_param.bounds.clone()); for trait_bound in trait_bounds(&type_param.bounds) { - if let Some(assoc_type) = find_in_path(&trait_bound.path, &AssocTypeMatcher) { + if let Some(assoc_type) = iter_path(&trait_bound.path).find_map(filter_map_assoc_paths) + { return abort!( assoc_type.span(), "dynamize does not support associated types in trait generic bounds" @@ -345,27 +346,21 @@ fn generate_blanket_impl( } } -struct AssocTypeMatcher; -impl TypeMatcher<Path> for AssocTypeMatcher { - fn match_path<'a>(&self, path: &'a Path) -> Option<&'a Path> { - if path.segments.first().unwrap().ident == "Self" { - return Some(path); - } - None - } +fn path_is_assoc_type(path: &Path) -> bool { + path.segments.first().unwrap().ident == "Self" } -impl<TypeMatch, PathMatch, T> TypeMatcher<T> for (TypeMatch, PathMatch) -where - TypeMatch: Fn(&Type) -> Option<&T>, - PathMatch: Fn(&Path) -> Option<&T>, -{ - fn match_type<'a>(&self, t: &'a Type) -> Option<&'a T> { - self.0(t) +fn match_assoc_type(item: TypeOrPath) -> bool { + if let TypeOrPath::Path(path) = item { + return path_is_assoc_type(path); } + false +} - fn match_path<'a>(&self, path: &'a Path) -> Option<&'a T> { - self.1(path) +fn filter_map_assoc_paths(item: TypeOrPath) -> Option<&Path> { + match item { + TypeOrPath::Path(p) if path_is_assoc_type(p) => Some(p), + _other => None, } } diff --git a/src/parse_assoc_type.rs b/src/parse_assoc_type.rs index 7efc7da..85f3723 100644 --- a/src/parse_assoc_type.rs +++ b/src/parse_assoc_type.rs @@ -3,9 +3,9 @@ use quote::{quote, ToTokens}; use syn::spanned::Spanned; use syn::{GenericArgument, Ident, Path, PathArguments, PathSegment, TraitItemType, Type}; +use crate::match_assoc_type; use crate::parse_trait_sig::TypeTransform; -use crate::syn_utils::{find_in_type, lifetime_bounds, trait_bounds}; -use crate::AssocTypeMatcher; +use crate::syn_utils::{iter_type, lifetime_bounds, trait_bounds}; #[derive(Debug)] pub enum AssocTypeParseError { @@ -63,7 +63,7 @@ pub fn parse_assoc_type( if ident == "Into" && args.args.len() == 1 { if let GenericArgument::Type(into_type) = args.args.first().unwrap() { // provide a better error message for type A: Into<Self::B> - if find_in_type(into_type, &AssocTypeMatcher).is_some() { + if iter_type(into_type).any(match_assoc_type) { return Err((into_type.span(), AssocTypeParseError::AssocTypeInBound)); } diff --git a/src/parse_trait_sig.rs b/src/parse_trait_sig.rs index 76402ac..508e131 100644 --- a/src/parse_trait_sig.rs +++ b/src/parse_trait_sig.rs @@ -6,10 +6,10 @@ use syn::{ }; use syn::{Ident, Signature, TypeImplTrait}; +use crate::match_assoc_type; use crate::parse_assoc_type::BoxType; -use crate::syn_utils::{find_in_type, trait_bounds, TypeMatcher}; +use crate::syn_utils::{iter_type, trait_bounds, TypeOrPath}; use crate::transform::{dynamize_function_bounds, AssocTypeConversions, TransformError}; -use crate::AssocTypeMatcher; #[derive(Debug, Clone)] pub enum TypeTransform { @@ -31,14 +31,10 @@ pub enum MethodParseError { UnconvertibleAssocType, } -struct ImplTraitMatcher; -impl TypeMatcher<TypeImplTrait> for ImplTraitMatcher { - fn match_type<'a>(&self, t: &'a Type) -> Option<&'a TypeImplTrait> { - if let Type::ImplTrait(impltrait) = t { - Some(impltrait) - } else { - None - } +fn filter_map_impl_trait(item: TypeOrPath) -> Option<&TypeImplTrait> { + match item { + TypeOrPath::Type(Type::ImplTrait(impltrait)) => Some(impltrait), + _other => None, } } @@ -58,10 +54,10 @@ pub fn parse_trait_signature( // provide better error messages for associated types in params for input in &signature.inputs { if let FnArg::Typed(pattype) = input { - if find_in_type(&pattype.ty, &AssocTypeMatcher).is_some() { + if iter_type(&pattype.ty).any(match_assoc_type) { return Err((pattype.ty.span(), MethodParseError::AssocTypeInInputs)); } - if let Some(impl_trait) = find_in_type(&pattype.ty, &ImplTraitMatcher) { + if let Some(impl_trait) = iter_type(&pattype.ty).find_map(filter_map_impl_trait) { return Err((impl_trait.span(), MethodParseError::ImplTraitInInputs)); } } diff --git a/src/syn_utils.rs b/src/syn_utils.rs index a778793..43c4d94 100644 --- a/src/syn_utils.rs +++ b/src/syn_utils.rs @@ -1,18 +1,10 @@ +use std::iter; + use syn::{ punctuated::Punctuated, GenericArgument, Lifetime, Path, PathArguments, ReturnType, TraitBound, Type, TypeParamBound, }; -#[allow(unused_variables)] -pub trait TypeMatcher<T> { - fn match_type<'a>(&self, t: &'a Type) -> Option<&'a T> { - None - } - fn match_path<'a>(&self, path: &'a Path) -> Option<&'a T> { - None - } -} - pub fn trait_bounds<T>( bounds: &Punctuated<TypeParamBound, T>, ) -> impl Iterator<Item = &TraitBound> { @@ -31,77 +23,114 @@ pub fn lifetime_bounds<T>( }) } -pub fn find_in_type<'a, T>(t: &'a Type, matcher: &dyn TypeMatcher<T>) -> Option<&'a T> { - if let Some(ret) = matcher.match_type(t) { - return Some(ret); - } - match t { - Type::Array(array) => find_in_type(&array.elem, matcher), - Type::BareFn(fun) => { - for input in &fun.inputs { - if let Some(ret) = find_in_type(&input.ty, matcher) { - return Some(ret); - } - } - if let ReturnType::Type(_, t) = &fun.output { - return find_in_type(t, matcher); - } - None - } - Type::Group(group) => find_in_type(&group.elem, matcher), - Type::ImplTrait(impltrait) => { - for bound in &impltrait.bounds { - if let TypeParamBound::Trait(bound) = bound { - if let Some(ret) = find_in_path(&bound.path, matcher) { - return Some(ret); - } - } - } - None - } - Type::Infer(_) => None, - Type::Macro(_) => None, - Type::Paren(paren) => find_in_type(&paren.elem, matcher), - Type::Path(path) => find_in_path(&path.path, matcher), - Type::Ptr(ptr) => find_in_type(&ptr.elem, matcher), - Type::Reference(reference) => find_in_type(&reference.elem, matcher), - Type::Slice(slice) => find_in_type(&slice.elem, matcher), - Type::TraitObject(traitobj) => { - for bound in &traitobj.bounds { - if let TypeParamBound::Trait(bound) = bound { - if let Some(ret) = find_in_path(&bound.path, matcher) { - return Some(ret); - } - } - } - None - } - Type::Tuple(tuple) => { - for elem in &tuple.elems { - if let Some(ret) = find_in_type(elem, matcher) { - return Some(ret); - } - } - None +pub enum TypeOrPath<'a> { + Type(&'a Type), + Path(&'a Path), +} + +enum IterTypes<A, B, C, D, E, F> { + Single(A), + Function(B), + Tuple(C), + ImplTrait(D), + TraitObject(E), + Path(F), + Empty, +} + +impl<A, B, C, D, E, F> Iterator for IterTypes<A, B, C, D, E, F> +where + A: Iterator, + B: Iterator<Item = A::Item>, + C: Iterator<Item = A::Item>, + D: Iterator<Item = A::Item>, + E: Iterator<Item = A::Item>, + F: Iterator<Item = A::Item>, +{ + type Item = A::Item; + + fn next(&mut self) -> Option<Self::Item> { + match self { + IterTypes::Single(a) => a.next(), + IterTypes::Function(a) => a.next(), + IterTypes::Tuple(a) => a.next(), + IterTypes::ImplTrait(a) => a.next(), + IterTypes::TraitObject(a) => a.next(), + IterTypes::Path(a) => a.next(), + IterTypes::Empty => None, } - _other => None, } } -pub fn find_in_path<'a, T>(path: &'a Path, matcher: &dyn TypeMatcher<T>) -> Option<&'a T> { - if let Some(ret) = matcher.match_path(path) { - return Some(ret); - } - for segment in &path.segments { - if let PathArguments::AngleBracketed(args) = &segment.arguments { - for arg in &args.args { - if let GenericArgument::Type(t) = arg { - if let Some(ret) = find_in_type(t, matcher) { - return Some(ret); - } - } +pub fn iter_path(path: &Path) -> impl Iterator<Item = TypeOrPath> { + iter::once(TypeOrPath::Path(path)).chain(types_in_path(path).flat_map(|t| iter_type(t))) +} + +pub fn iter_type<'a>(t: &'a Type) -> Box<dyn Iterator<Item = TypeOrPath<'a>> + 'a> { + Box::new( + iter::once(TypeOrPath::Type(t)).chain(match t { + Type::Array(array) => IterTypes::Single(iter_type(&array.elem)), + Type::Group(group) => IterTypes::Single(iter_type(&group.elem)), + Type::Paren(paren) => IterTypes::Single(iter_type(&paren.elem)), + Type::Ptr(ptr) => IterTypes::Single(iter_type(&ptr.elem)), + Type::Reference(r) => IterTypes::Single(iter_type(&r.elem)), + Type::Slice(slice) => IterTypes::Single(iter_type(&slice.elem)), + + Type::Tuple(tuple) => IterTypes::Tuple(tuple.elems.iter().flat_map(|i| iter_type(i))), + + Type::BareFn(fun) => { + IterTypes::Function(fun.inputs.iter().flat_map(|i| iter_type(&i.ty)).chain( + match &fun.output { + ReturnType::Default => IterEnum::Left(iter::empty()), + ReturnType::Type(_, ty) => IterEnum::Right(iter_type(ty.as_ref())), + }, + )) } + Type::ImplTrait(impl_trait) => IterTypes::ImplTrait( + trait_bounds(&impl_trait.bounds) + .flat_map(|b| types_in_path(&b.path)) + .flat_map(|t| iter_type(t)), + ), + + Type::Path(path) => IterTypes::Path(iter_path(&path.path)), + Type::TraitObject(trait_obj) => IterTypes::TraitObject( + trait_bounds(&trait_obj.bounds) + .flat_map(|b| types_in_path(&b.path)) + .flat_map(|t| iter_type(t)), + ), + _other => IterTypes::Empty, + }), + ) +} + +enum IterEnum<L, R> { + Left(L), + Right(R), +} + +impl<L, R> Iterator for IterEnum<L, R> +where + L: Iterator, + R: Iterator<Item = L::Item>, +{ + type Item = L::Item; + + fn next(&mut self) -> Option<Self::Item> { + match self { + IterEnum::Left(iter) => iter.next(), + IterEnum::Right(iter) => iter.next(), } } - None +} + +fn types_in_path(p: &Path) -> impl Iterator<Item = &Type> { + p.segments.iter().flat_map(|s| match &s.arguments { + PathArguments::AngleBracketed(ang) => { + IterEnum::Left(ang.args.iter().flat_map(|a| match a { + GenericArgument::Type(t) => Some(t), + _other => None, + })) + } + _other => IterEnum::Right(iter::empty()), + }) } diff --git a/src/transform.rs b/src/transform.rs index c3116cf..6640cf3 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -7,10 +7,10 @@ use syn::{ }; use crate::{ + filter_map_assoc_paths, match_assoc_type, parse_assoc_type::DestType, parse_trait_sig::{MethodParseError, TypeTransform}, - syn_utils::{find_in_path, find_in_type}, - AssocTypeMatcher, + syn_utils::{iter_path, iter_type}, }; #[derive(Default)] @@ -23,7 +23,7 @@ pub enum TransformError { impl AssocTypeConversions<'_> { pub fn parse_type_path(&self, type_: &mut Type) -> Result<TypeTransform, TransformError> { - let assoc_span = match find_in_type(type_, &AssocTypeMatcher) { + let assoc_span = match iter_type(type_).filter_map(filter_map_assoc_paths).next() { Some(path) => path.span(), None => return Ok(TypeTransform::NoOp), }; @@ -49,7 +49,7 @@ impl AssocTypeConversions<'_> { if args.args.len() == 1 { if let GenericArgument::Type(generic_type) = args.args.first_mut().unwrap() { - if find_in_type(generic_type, &AssocTypeMatcher).is_some() { + if iter_type(generic_type).any(match_assoc_type) { return Ok(TypeTransform::Map( self.parse_type_path(generic_type)?.into(), )); @@ -65,8 +65,8 @@ impl AssocTypeConversions<'_> { if let (GenericArgument::Type(ok_type), GenericArgument::Type(err_type)) = (args_iter.next().unwrap(), args_iter.next().unwrap()) { - if find_in_type(ok_type, &AssocTypeMatcher).is_some() - || find_in_type(err_type, &AssocTypeMatcher).is_some() + if iter_type(ok_type).any(match_assoc_type) + || iter_type(err_type).any(match_assoc_type) { return Ok(TypeTransform::Result( self.parse_type_path(ok_type)?.into(), @@ -85,7 +85,7 @@ impl AssocTypeConversions<'_> { if let GenericArgument::Type(generic_type) = args.args.first_mut().unwrap() { - if find_in_type(generic_type, &AssocTypeMatcher).is_some() { + if iter_type(generic_type).any(match_assoc_type) { return Ok(TypeTransform::Map( self.parse_type_path(generic_type)?.into(), )); @@ -146,7 +146,10 @@ pub fn dynamize_function_bounds( } } } - if let Some(path) = find_in_path(&bound.path, &AssocTypeMatcher) { + if let Some(path) = iter_path(&bound.path) + .filter_map(filter_map_assoc_paths) + .next() + { return Err(( path.span(), MethodParseError::UnconvertibleAssocTypeInTraitBound, |