aboutsummaryrefslogtreecommitdiff
path: root/src/syn_utils.rs
diff options
context:
space:
mode:
authorMartin Fischer <martin@push-f.com>2021-11-20 15:10:23 +0100
committerMartin Fischer <martin@push-f.com>2021-11-20 18:36:07 +0100
commit91ca29ade1ab6b90dda972c2a42a5dc529ddfb44 (patch)
treeaec288a5858fe3e8a90f0357c69571e091cf85d6 /src/syn_utils.rs
parente918eb0ab2cf6d84751f5f52b49414d382b7abbb (diff)
refactor: traverse AST via iterators
Diffstat (limited to 'src/syn_utils.rs')
-rw-r--r--src/syn_utils.rs181
1 files changed, 105 insertions, 76 deletions
diff --git a/src/syn_utils.rs b/src/syn_utils.rs
index a778793..43c4d94 100644
--- a/src/syn_utils.rs
+++ b/src/syn_utils.rs
@@ -1,18 +1,10 @@
+use std::iter;
+
use syn::{
punctuated::Punctuated, GenericArgument, Lifetime, Path, PathArguments, ReturnType, TraitBound,
Type, TypeParamBound,
};
-#[allow(unused_variables)]
-pub trait TypeMatcher<T> {
- fn match_type<'a>(&self, t: &'a Type) -> Option<&'a T> {
- None
- }
- fn match_path<'a>(&self, path: &'a Path) -> Option<&'a T> {
- None
- }
-}
-
pub fn trait_bounds<T>(
bounds: &Punctuated<TypeParamBound, T>,
) -> impl Iterator<Item = &TraitBound> {
@@ -31,77 +23,114 @@ pub fn lifetime_bounds<T>(
})
}
-pub fn find_in_type<'a, T>(t: &'a Type, matcher: &dyn TypeMatcher<T>) -> Option<&'a T> {
- if let Some(ret) = matcher.match_type(t) {
- return Some(ret);
- }
- match t {
- Type::Array(array) => find_in_type(&array.elem, matcher),
- Type::BareFn(fun) => {
- for input in &fun.inputs {
- if let Some(ret) = find_in_type(&input.ty, matcher) {
- return Some(ret);
- }
- }
- if let ReturnType::Type(_, t) = &fun.output {
- return find_in_type(t, matcher);
- }
- None
- }
- Type::Group(group) => find_in_type(&group.elem, matcher),
- Type::ImplTrait(impltrait) => {
- for bound in &impltrait.bounds {
- if let TypeParamBound::Trait(bound) = bound {
- if let Some(ret) = find_in_path(&bound.path, matcher) {
- return Some(ret);
- }
- }
- }
- None
- }
- Type::Infer(_) => None,
- Type::Macro(_) => None,
- Type::Paren(paren) => find_in_type(&paren.elem, matcher),
- Type::Path(path) => find_in_path(&path.path, matcher),
- Type::Ptr(ptr) => find_in_type(&ptr.elem, matcher),
- Type::Reference(reference) => find_in_type(&reference.elem, matcher),
- Type::Slice(slice) => find_in_type(&slice.elem, matcher),
- Type::TraitObject(traitobj) => {
- for bound in &traitobj.bounds {
- if let TypeParamBound::Trait(bound) = bound {
- if let Some(ret) = find_in_path(&bound.path, matcher) {
- return Some(ret);
- }
- }
- }
- None
- }
- Type::Tuple(tuple) => {
- for elem in &tuple.elems {
- if let Some(ret) = find_in_type(elem, matcher) {
- return Some(ret);
- }
- }
- None
+pub enum TypeOrPath<'a> {
+ Type(&'a Type),
+ Path(&'a Path),
+}
+
+enum IterTypes<A, B, C, D, E, F> {
+ Single(A),
+ Function(B),
+ Tuple(C),
+ ImplTrait(D),
+ TraitObject(E),
+ Path(F),
+ Empty,
+}
+
+impl<A, B, C, D, E, F> Iterator for IterTypes<A, B, C, D, E, F>
+where
+ A: Iterator,
+ B: Iterator<Item = A::Item>,
+ C: Iterator<Item = A::Item>,
+ D: Iterator<Item = A::Item>,
+ E: Iterator<Item = A::Item>,
+ F: Iterator<Item = A::Item>,
+{
+ type Item = A::Item;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ match self {
+ IterTypes::Single(a) => a.next(),
+ IterTypes::Function(a) => a.next(),
+ IterTypes::Tuple(a) => a.next(),
+ IterTypes::ImplTrait(a) => a.next(),
+ IterTypes::TraitObject(a) => a.next(),
+ IterTypes::Path(a) => a.next(),
+ IterTypes::Empty => None,
}
- _other => None,
}
}
-pub fn find_in_path<'a, T>(path: &'a Path, matcher: &dyn TypeMatcher<T>) -> Option<&'a T> {
- if let Some(ret) = matcher.match_path(path) {
- return Some(ret);
- }
- for segment in &path.segments {
- if let PathArguments::AngleBracketed(args) = &segment.arguments {
- for arg in &args.args {
- if let GenericArgument::Type(t) = arg {
- if let Some(ret) = find_in_type(t, matcher) {
- return Some(ret);
- }
- }
+pub fn iter_path(path: &Path) -> impl Iterator<Item = TypeOrPath> {
+ iter::once(TypeOrPath::Path(path)).chain(types_in_path(path).flat_map(|t| iter_type(t)))
+}
+
+pub fn iter_type<'a>(t: &'a Type) -> Box<dyn Iterator<Item = TypeOrPath<'a>> + 'a> {
+ Box::new(
+ iter::once(TypeOrPath::Type(t)).chain(match t {
+ Type::Array(array) => IterTypes::Single(iter_type(&array.elem)),
+ Type::Group(group) => IterTypes::Single(iter_type(&group.elem)),
+ Type::Paren(paren) => IterTypes::Single(iter_type(&paren.elem)),
+ Type::Ptr(ptr) => IterTypes::Single(iter_type(&ptr.elem)),
+ Type::Reference(r) => IterTypes::Single(iter_type(&r.elem)),
+ Type::Slice(slice) => IterTypes::Single(iter_type(&slice.elem)),
+
+ Type::Tuple(tuple) => IterTypes::Tuple(tuple.elems.iter().flat_map(|i| iter_type(i))),
+
+ Type::BareFn(fun) => {
+ IterTypes::Function(fun.inputs.iter().flat_map(|i| iter_type(&i.ty)).chain(
+ match &fun.output {
+ ReturnType::Default => IterEnum::Left(iter::empty()),
+ ReturnType::Type(_, ty) => IterEnum::Right(iter_type(ty.as_ref())),
+ },
+ ))
}
+ Type::ImplTrait(impl_trait) => IterTypes::ImplTrait(
+ trait_bounds(&impl_trait.bounds)
+ .flat_map(|b| types_in_path(&b.path))
+ .flat_map(|t| iter_type(t)),
+ ),
+
+ Type::Path(path) => IterTypes::Path(iter_path(&path.path)),
+ Type::TraitObject(trait_obj) => IterTypes::TraitObject(
+ trait_bounds(&trait_obj.bounds)
+ .flat_map(|b| types_in_path(&b.path))
+ .flat_map(|t| iter_type(t)),
+ ),
+ _other => IterTypes::Empty,
+ }),
+ )
+}
+
+enum IterEnum<L, R> {
+ Left(L),
+ Right(R),
+}
+
+impl<L, R> Iterator for IterEnum<L, R>
+where
+ L: Iterator,
+ R: Iterator<Item = L::Item>,
+{
+ type Item = L::Item;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ match self {
+ IterEnum::Left(iter) => iter.next(),
+ IterEnum::Right(iter) => iter.next(),
}
}
- None
+}
+
+fn types_in_path(p: &Path) -> impl Iterator<Item = &Type> {
+ p.segments.iter().flat_map(|s| match &s.arguments {
+ PathArguments::AngleBracketed(ang) => {
+ IterEnum::Left(ang.args.iter().flat_map(|a| match a {
+ GenericArgument::Type(t) => Some(t),
+ _other => None,
+ }))
+ }
+ _other => IterEnum::Right(iter::empty()),
+ })
}