From 91ca29ade1ab6b90dda972c2a42a5dc529ddfb44 Mon Sep 17 00:00:00 2001
From: Martin Fischer <martin@push-f.com>
Date: Sat, 20 Nov 2021 15:10:23 +0100
Subject: refactor: traverse AST via iterators

---
 src/lib.rs              |  35 ++++------
 src/parse_assoc_type.rs |   6 +-
 src/parse_trait_sig.rs  |  20 +++---
 src/syn_utils.rs        | 181 ++++++++++++++++++++++++++++--------------------
 src/transform.rs        |  19 ++---
 5 files changed, 142 insertions(+), 119 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index 21efe52..2e0129e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -35,7 +35,7 @@ use syn::TypeParam;
 use syn::TypeParamBound;
 use syn::TypePath;
 use syn::Visibility;
-use syn_utils::TypeMatcher;
+use syn_utils::TypeOrPath;
 
 use crate::parse_assoc_type::parse_assoc_type;
 use crate::parse_assoc_type::AssocTypeParseError;
@@ -43,7 +43,7 @@ use crate::parse_trait_sig::parse_trait_signature;
 use crate::parse_trait_sig::MethodParseError;
 use crate::parse_trait_sig::SignatureChanges;
 use crate::parse_trait_sig::TypeTransform;
-use crate::syn_utils::find_in_path;
+use crate::syn_utils::iter_path;
 use crate::syn_utils::trait_bounds;
 use crate::transform::AssocTypeConversions;
 
@@ -181,7 +181,8 @@ pub fn dynamize(_attr: TokenStream, input: TokenStream) -> TokenStream {
     for type_param in dyn_trait.generics.type_params() {
         generic_map.insert(type_param.ident.clone(), type_param.bounds.clone());
         for trait_bound in trait_bounds(&type_param.bounds) {
-            if let Some(assoc_type) = find_in_path(&trait_bound.path, &AssocTypeMatcher) {
+            if let Some(assoc_type) = iter_path(&trait_bound.path).find_map(filter_map_assoc_paths)
+            {
                 return abort!(
                     assoc_type.span(),
                     "dynamize does not support associated types in trait generic bounds"
@@ -345,27 +346,21 @@ fn generate_blanket_impl(
     }
 }
 
-struct AssocTypeMatcher;
-impl TypeMatcher<Path> for AssocTypeMatcher {
-    fn match_path<'a>(&self, path: &'a Path) -> Option<&'a Path> {
-        if path.segments.first().unwrap().ident == "Self" {
-            return Some(path);
-        }
-        None
-    }
+fn path_is_assoc_type(path: &Path) -> bool {
+    path.segments.first().unwrap().ident == "Self"
 }
 
-impl<TypeMatch, PathMatch, T> TypeMatcher<T> for (TypeMatch, PathMatch)
-where
-    TypeMatch: Fn(&Type) -> Option<&T>,
-    PathMatch: Fn(&Path) -> Option<&T>,
-{
-    fn match_type<'a>(&self, t: &'a Type) -> Option<&'a T> {
-        self.0(t)
+fn match_assoc_type(item: TypeOrPath) -> bool {
+    if let TypeOrPath::Path(path) = item {
+        return path_is_assoc_type(path);
     }
+    false
+}
 
-    fn match_path<'a>(&self, path: &'a Path) -> Option<&'a T> {
-        self.1(path)
+fn filter_map_assoc_paths(item: TypeOrPath) -> Option<&Path> {
+    match item {
+        TypeOrPath::Path(p) if path_is_assoc_type(p) => Some(p),
+        _other => None,
     }
 }
 
diff --git a/src/parse_assoc_type.rs b/src/parse_assoc_type.rs
index 7efc7da..85f3723 100644
--- a/src/parse_assoc_type.rs
+++ b/src/parse_assoc_type.rs
@@ -3,9 +3,9 @@ use quote::{quote, ToTokens};
 use syn::spanned::Spanned;
 use syn::{GenericArgument, Ident, Path, PathArguments, PathSegment, TraitItemType, Type};
 
+use crate::match_assoc_type;
 use crate::parse_trait_sig::TypeTransform;
-use crate::syn_utils::{find_in_type, lifetime_bounds, trait_bounds};
-use crate::AssocTypeMatcher;
+use crate::syn_utils::{iter_type, lifetime_bounds, trait_bounds};
 
 #[derive(Debug)]
 pub enum AssocTypeParseError {
@@ -63,7 +63,7 @@ pub fn parse_assoc_type(
             if ident == "Into" && args.args.len() == 1 {
                 if let GenericArgument::Type(into_type) = args.args.first().unwrap() {
                     // provide a better error message for type A: Into<Self::B>
-                    if find_in_type(into_type, &AssocTypeMatcher).is_some() {
+                    if iter_type(into_type).any(match_assoc_type) {
                         return Err((into_type.span(), AssocTypeParseError::AssocTypeInBound));
                     }
 
diff --git a/src/parse_trait_sig.rs b/src/parse_trait_sig.rs
index 76402ac..508e131 100644
--- a/src/parse_trait_sig.rs
+++ b/src/parse_trait_sig.rs
@@ -6,10 +6,10 @@ use syn::{
 };
 use syn::{Ident, Signature, TypeImplTrait};
 
+use crate::match_assoc_type;
 use crate::parse_assoc_type::BoxType;
-use crate::syn_utils::{find_in_type, trait_bounds, TypeMatcher};
+use crate::syn_utils::{iter_type, trait_bounds, TypeOrPath};
 use crate::transform::{dynamize_function_bounds, AssocTypeConversions, TransformError};
-use crate::AssocTypeMatcher;
 
 #[derive(Debug, Clone)]
 pub enum TypeTransform {
@@ -31,14 +31,10 @@ pub enum MethodParseError {
     UnconvertibleAssocType,
 }
 
-struct ImplTraitMatcher;
-impl TypeMatcher<TypeImplTrait> for ImplTraitMatcher {
-    fn match_type<'a>(&self, t: &'a Type) -> Option<&'a TypeImplTrait> {
-        if let Type::ImplTrait(impltrait) = t {
-            Some(impltrait)
-        } else {
-            None
-        }
+fn filter_map_impl_trait(item: TypeOrPath) -> Option<&TypeImplTrait> {
+    match item {
+        TypeOrPath::Type(Type::ImplTrait(impltrait)) => Some(impltrait),
+        _other => None,
     }
 }
 
@@ -58,10 +54,10 @@ pub fn parse_trait_signature(
     // provide better error messages for associated types in params
     for input in &signature.inputs {
         if let FnArg::Typed(pattype) = input {
-            if find_in_type(&pattype.ty, &AssocTypeMatcher).is_some() {
+            if iter_type(&pattype.ty).any(match_assoc_type) {
                 return Err((pattype.ty.span(), MethodParseError::AssocTypeInInputs));
             }
-            if let Some(impl_trait) = find_in_type(&pattype.ty, &ImplTraitMatcher) {
+            if let Some(impl_trait) = iter_type(&pattype.ty).find_map(filter_map_impl_trait) {
                 return Err((impl_trait.span(), MethodParseError::ImplTraitInInputs));
             }
         }
diff --git a/src/syn_utils.rs b/src/syn_utils.rs
index a778793..43c4d94 100644
--- a/src/syn_utils.rs
+++ b/src/syn_utils.rs
@@ -1,18 +1,10 @@
+use std::iter;
+
 use syn::{
     punctuated::Punctuated, GenericArgument, Lifetime, Path, PathArguments, ReturnType, TraitBound,
     Type, TypeParamBound,
 };
 
-#[allow(unused_variables)]
-pub trait TypeMatcher<T> {
-    fn match_type<'a>(&self, t: &'a Type) -> Option<&'a T> {
-        None
-    }
-    fn match_path<'a>(&self, path: &'a Path) -> Option<&'a T> {
-        None
-    }
-}
-
 pub fn trait_bounds<T>(
     bounds: &Punctuated<TypeParamBound, T>,
 ) -> impl Iterator<Item = &TraitBound> {
@@ -31,77 +23,114 @@ pub fn lifetime_bounds<T>(
     })
 }
 
-pub fn find_in_type<'a, T>(t: &'a Type, matcher: &dyn TypeMatcher<T>) -> Option<&'a T> {
-    if let Some(ret) = matcher.match_type(t) {
-        return Some(ret);
-    }
-    match t {
-        Type::Array(array) => find_in_type(&array.elem, matcher),
-        Type::BareFn(fun) => {
-            for input in &fun.inputs {
-                if let Some(ret) = find_in_type(&input.ty, matcher) {
-                    return Some(ret);
-                }
-            }
-            if let ReturnType::Type(_, t) = &fun.output {
-                return find_in_type(t, matcher);
-            }
-            None
-        }
-        Type::Group(group) => find_in_type(&group.elem, matcher),
-        Type::ImplTrait(impltrait) => {
-            for bound in &impltrait.bounds {
-                if let TypeParamBound::Trait(bound) = bound {
-                    if let Some(ret) = find_in_path(&bound.path, matcher) {
-                        return Some(ret);
-                    }
-                }
-            }
-            None
-        }
-        Type::Infer(_) => None,
-        Type::Macro(_) => None,
-        Type::Paren(paren) => find_in_type(&paren.elem, matcher),
-        Type::Path(path) => find_in_path(&path.path, matcher),
-        Type::Ptr(ptr) => find_in_type(&ptr.elem, matcher),
-        Type::Reference(reference) => find_in_type(&reference.elem, matcher),
-        Type::Slice(slice) => find_in_type(&slice.elem, matcher),
-        Type::TraitObject(traitobj) => {
-            for bound in &traitobj.bounds {
-                if let TypeParamBound::Trait(bound) = bound {
-                    if let Some(ret) = find_in_path(&bound.path, matcher) {
-                        return Some(ret);
-                    }
-                }
-            }
-            None
-        }
-        Type::Tuple(tuple) => {
-            for elem in &tuple.elems {
-                if let Some(ret) = find_in_type(elem, matcher) {
-                    return Some(ret);
-                }
-            }
-            None
+pub enum TypeOrPath<'a> {
+    Type(&'a Type),
+    Path(&'a Path),
+}
+
+enum IterTypes<A, B, C, D, E, F> {
+    Single(A),
+    Function(B),
+    Tuple(C),
+    ImplTrait(D),
+    TraitObject(E),
+    Path(F),
+    Empty,
+}
+
+impl<A, B, C, D, E, F> Iterator for IterTypes<A, B, C, D, E, F>
+where
+    A: Iterator,
+    B: Iterator<Item = A::Item>,
+    C: Iterator<Item = A::Item>,
+    D: Iterator<Item = A::Item>,
+    E: Iterator<Item = A::Item>,
+    F: Iterator<Item = A::Item>,
+{
+    type Item = A::Item;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        match self {
+            IterTypes::Single(a) => a.next(),
+            IterTypes::Function(a) => a.next(),
+            IterTypes::Tuple(a) => a.next(),
+            IterTypes::ImplTrait(a) => a.next(),
+            IterTypes::TraitObject(a) => a.next(),
+            IterTypes::Path(a) => a.next(),
+            IterTypes::Empty => None,
         }
-        _other => None,
     }
 }
 
-pub fn find_in_path<'a, T>(path: &'a Path, matcher: &dyn TypeMatcher<T>) -> Option<&'a T> {
-    if let Some(ret) = matcher.match_path(path) {
-        return Some(ret);
-    }
-    for segment in &path.segments {
-        if let PathArguments::AngleBracketed(args) = &segment.arguments {
-            for arg in &args.args {
-                if let GenericArgument::Type(t) = arg {
-                    if let Some(ret) = find_in_type(t, matcher) {
-                        return Some(ret);
-                    }
-                }
+pub fn iter_path(path: &Path) -> impl Iterator<Item = TypeOrPath> {
+    iter::once(TypeOrPath::Path(path)).chain(types_in_path(path).flat_map(|t| iter_type(t)))
+}
+
+pub fn iter_type<'a>(t: &'a Type) -> Box<dyn Iterator<Item = TypeOrPath<'a>> + 'a> {
+    Box::new(
+        iter::once(TypeOrPath::Type(t)).chain(match t {
+            Type::Array(array) => IterTypes::Single(iter_type(&array.elem)),
+            Type::Group(group) => IterTypes::Single(iter_type(&group.elem)),
+            Type::Paren(paren) => IterTypes::Single(iter_type(&paren.elem)),
+            Type::Ptr(ptr) => IterTypes::Single(iter_type(&ptr.elem)),
+            Type::Reference(r) => IterTypes::Single(iter_type(&r.elem)),
+            Type::Slice(slice) => IterTypes::Single(iter_type(&slice.elem)),
+
+            Type::Tuple(tuple) => IterTypes::Tuple(tuple.elems.iter().flat_map(|i| iter_type(i))),
+
+            Type::BareFn(fun) => {
+                IterTypes::Function(fun.inputs.iter().flat_map(|i| iter_type(&i.ty)).chain(
+                    match &fun.output {
+                        ReturnType::Default => IterEnum::Left(iter::empty()),
+                        ReturnType::Type(_, ty) => IterEnum::Right(iter_type(ty.as_ref())),
+                    },
+                ))
             }
+            Type::ImplTrait(impl_trait) => IterTypes::ImplTrait(
+                trait_bounds(&impl_trait.bounds)
+                    .flat_map(|b| types_in_path(&b.path))
+                    .flat_map(|t| iter_type(t)),
+            ),
+
+            Type::Path(path) => IterTypes::Path(iter_path(&path.path)),
+            Type::TraitObject(trait_obj) => IterTypes::TraitObject(
+                trait_bounds(&trait_obj.bounds)
+                    .flat_map(|b| types_in_path(&b.path))
+                    .flat_map(|t| iter_type(t)),
+            ),
+            _other => IterTypes::Empty,
+        }),
+    )
+}
+
+enum IterEnum<L, R> {
+    Left(L),
+    Right(R),
+}
+
+impl<L, R> Iterator for IterEnum<L, R>
+where
+    L: Iterator,
+    R: Iterator<Item = L::Item>,
+{
+    type Item = L::Item;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        match self {
+            IterEnum::Left(iter) => iter.next(),
+            IterEnum::Right(iter) => iter.next(),
         }
     }
-    None
+}
+
+fn types_in_path(p: &Path) -> impl Iterator<Item = &Type> {
+    p.segments.iter().flat_map(|s| match &s.arguments {
+        PathArguments::AngleBracketed(ang) => {
+            IterEnum::Left(ang.args.iter().flat_map(|a| match a {
+                GenericArgument::Type(t) => Some(t),
+                _other => None,
+            }))
+        }
+        _other => IterEnum::Right(iter::empty()),
+    })
 }
diff --git a/src/transform.rs b/src/transform.rs
index c3116cf..6640cf3 100644
--- a/src/transform.rs
+++ b/src/transform.rs
@@ -7,10 +7,10 @@ use syn::{
 };
 
 use crate::{
+    filter_map_assoc_paths, match_assoc_type,
     parse_assoc_type::DestType,
     parse_trait_sig::{MethodParseError, TypeTransform},
-    syn_utils::{find_in_path, find_in_type},
-    AssocTypeMatcher,
+    syn_utils::{iter_path, iter_type},
 };
 
 #[derive(Default)]
@@ -23,7 +23,7 @@ pub enum TransformError {
 
 impl AssocTypeConversions<'_> {
     pub fn parse_type_path(&self, type_: &mut Type) -> Result<TypeTransform, TransformError> {
-        let assoc_span = match find_in_type(type_, &AssocTypeMatcher) {
+        let assoc_span = match iter_type(type_).filter_map(filter_map_assoc_paths).next() {
             Some(path) => path.span(),
             None => return Ok(TypeTransform::NoOp),
         };
@@ -49,7 +49,7 @@ impl AssocTypeConversions<'_> {
                     if args.args.len() == 1 {
                         if let GenericArgument::Type(generic_type) = args.args.first_mut().unwrap()
                         {
-                            if find_in_type(generic_type, &AssocTypeMatcher).is_some() {
+                            if iter_type(generic_type).any(match_assoc_type) {
                                 return Ok(TypeTransform::Map(
                                     self.parse_type_path(generic_type)?.into(),
                                 ));
@@ -65,8 +65,8 @@ impl AssocTypeConversions<'_> {
                         if let (GenericArgument::Type(ok_type), GenericArgument::Type(err_type)) =
                             (args_iter.next().unwrap(), args_iter.next().unwrap())
                         {
-                            if find_in_type(ok_type, &AssocTypeMatcher).is_some()
-                                || find_in_type(err_type, &AssocTypeMatcher).is_some()
+                            if iter_type(ok_type).any(match_assoc_type)
+                                || iter_type(err_type).any(match_assoc_type)
                             {
                                 return Ok(TypeTransform::Result(
                                     self.parse_type_path(ok_type)?.into(),
@@ -85,7 +85,7 @@ impl AssocTypeConversions<'_> {
                             if let GenericArgument::Type(generic_type) =
                                 args.args.first_mut().unwrap()
                             {
-                                if find_in_type(generic_type, &AssocTypeMatcher).is_some() {
+                                if iter_type(generic_type).any(match_assoc_type) {
                                     return Ok(TypeTransform::Map(
                                         self.parse_type_path(generic_type)?.into(),
                                     ));
@@ -146,7 +146,10 @@ pub fn dynamize_function_bounds(
                         }
                     }
                 }
-                if let Some(path) = find_in_path(&bound.path, &AssocTypeMatcher) {
+                if let Some(path) = iter_path(&bound.path)
+                    .filter_map(filter_map_assoc_paths)
+                    .next()
+                {
                     return Err((
                         path.span(),
                         MethodParseError::UnconvertibleAssocTypeInTraitBound,
-- 
cgit v1.2.3