From 06a384000b0a6c7b657e4f8a9145b3d4357f3d70 Mon Sep 17 00:00:00 2001
From: Martin Fischer <martin@push-f.com>
Date: Thu, 25 Nov 2021 08:49:47 +0100
Subject: refactor: factor out parse_attrs module

---
 src/lib.rs         | 94 ++++++++----------------------------------------------
 src/parse_attrs.rs | 93 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 106 insertions(+), 81 deletions(-)
 create mode 100644 src/parse_attrs.rs

(limited to 'src')

diff --git a/src/lib.rs b/src/lib.rs
index 1cf086c..ace708c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -3,13 +3,10 @@
 use std::collections::HashMap;
 
 use proc_macro::TokenStream;
-use proc_macro2::Group;
 use quote::format_ident;
 use quote::quote;
 use quote::quote_spanned;
 use quote::ToTokens;
-use syn::parenthesized;
-use syn::parse::Parse;
 use syn::parse_macro_input;
 use syn::spanned::Spanned;
 use syn::token::Brace;
@@ -18,20 +15,16 @@ use syn::token::Lt;
 use syn::token::Trait;
 use syn::AngleBracketedGenericArguments;
 use syn::Block;
-use syn::Error;
 use syn::Expr;
 use syn::GenericArgument;
 use syn::GenericParam;
-use syn::Ident;
 use syn::ImplItemMethod;
 use syn::ItemTrait;
-use syn::LitInt;
 use syn::Path;
 use syn::PathArguments;
 use syn::PathSegment;
 use syn::Signature;
 use syn::Stmt;
-use syn::Token;
 use syn::TraitBound;
 use syn::TraitItem;
 use syn::TraitItemMethod;
@@ -44,6 +37,7 @@ use syn_utils::TypeOrPath;
 
 use crate::parse_assoc_type::parse_assoc_type;
 use crate::parse_assoc_type::AssocTypeError;
+use crate::parse_attrs::TraitAttrs;
 use crate::syn_utils::iter_path;
 use crate::syn_utils::trait_bounds;
 use crate::trait_sig::convert_trait_signature;
@@ -54,6 +48,7 @@ use crate::transform::TransformError;
 use crate::transform::TypeConverter;
 
 mod parse_assoc_type;
+mod parse_attrs;
 mod syn_utils;
 mod trait_sig;
 mod transform;
@@ -66,30 +61,6 @@ macro_rules! abort {
     }};
 }
 
-struct Collection {
-    id: Ident,
-    count: usize,
-}
-
-impl Parse for Collection {
-    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
-        let inner;
-        parenthesized!(inner in input);
-
-        let id: Ident = inner.parse()?;
-        let _: Token![,] = inner.parse()?;
-        let count_lit: LitInt = inner.parse()?;
-        let count: usize = count_lit.base10_parse()?;
-        if count < 1 {
-            return Err(Error::new(
-                count_lit.span(),
-                "number of type parameters must be >= 1",
-            ));
-        }
-        Ok(Self { id, count })
-    }
-}
-
 #[proc_macro_attribute]
 pub fn dynamize(_attr: TokenStream, input: TokenStream) -> TokenStream {
     let mut original_trait = parse_macro_input!(input as ItemTrait);
@@ -108,57 +79,15 @@ pub fn dynamize(_attr: TokenStream, input: TokenStream) -> TokenStream {
         }
     }
 
-    let mut type_converter = TypeConverter::default();
-
-    let mut blanket_impl_attrs = Vec::new();
-    let mut dyn_trait_attrs = Vec::new();
+    let method_attrs = match TraitAttrs::parse(&mut original_trait.attrs) {
+        Ok(attrs) => attrs,
+        Err(err) => return err.to_compile_error().into(),
+    };
 
-    // FUTURE: use Vec::drain_filter once it's stable
-    let mut i = 0;
-    while i < original_trait.attrs.len() {
-        if original_trait.attrs[i].path.is_ident("blanket_impl_attr") {
-            let attr = original_trait.attrs.remove(i);
-            let group: Group = match syn::parse2(attr.tokens) {
-                Ok(g) => g,
-                Err(err) => {
-                    return abort!(
-                        err.span(),
-                        "expected parenthesis: #[blanket_impl_attr(...)]"
-                    )
-                }
-            };
-            let tokens = group.stream();
-            blanket_impl_attrs.push(quote! {#[#tokens]});
-        } else if original_trait.attrs[i].path.is_ident("dyn_trait_attr") {
-            let attr = original_trait.attrs.remove(i);
-            let group: Group = match syn::parse2(attr.tokens) {
-                Ok(g) => g,
-                Err(err) => {
-                    return abort!(err.span(), "expected parenthesis: #[dyn_trait_attr(...)]")
-                }
-            };
-            let tokens = group.stream();
-            dyn_trait_attrs.push(quote! {#[#tokens]});
-        } else if original_trait.attrs[i].path.is_ident("collection") {
-            let attr = original_trait.attrs.remove(i);
-            let tokens = attr.tokens.into();
-            let coll = parse_macro_input!(tokens as Collection);
-
-            if type_converter
-                .collections
-                .insert(coll.id.clone(), coll.count)
-                .is_some()
-            {
-                return abort!(
-                    coll.id.span(),
-                    "collection `{}` is defined multiple times for this trait",
-                    coll.id
-                );
-            }
-        } else {
-            i += 1;
-        }
-    }
+    let mut type_converter = TypeConverter {
+        collections: method_attrs.collections,
+        ..TypeConverter::default()
+    };
 
     for item in &original_trait.items {
         if let TraitItem::Type(assoc_type) = item {
@@ -393,6 +322,9 @@ pub fn dynamize(_attr: TokenStream, input: TokenStream) -> TokenStream {
     let dyn_trait_name = &dyn_trait.ident;
     let (impl_generics, ty_generics, where_clause) = dyn_trait.generics.split_for_impl();
 
+    let dyn_trait_attrs = method_attrs.dyn_trait_attrs;
+    let blanket_impl_attrs = method_attrs.blanket_impl_attrs;
+
     let expanded = quote! {
         #original_trait
 
diff --git a/src/parse_attrs.rs b/src/parse_attrs.rs
new file mode 100644
index 0000000..03b5377
--- /dev/null
+++ b/src/parse_attrs.rs
@@ -0,0 +1,93 @@
+use std::collections::HashMap;
+
+use proc_macro2::{Group, TokenStream};
+use quote::quote;
+use syn::{parenthesized, parse::Parse, Attribute, Error, Ident, LitInt, Token};
+
+struct Collection {
+    pub id: Ident,
+    pub count: usize,
+}
+
+impl Parse for Collection {
+    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
+        let inner;
+        parenthesized!(inner in input);
+
+        let id: Ident = inner.parse()?;
+        let _: Token![,] = inner.parse()?;
+        let count_lit: LitInt = inner.parse()?;
+        let count: usize = count_lit.base10_parse()?;
+        if count < 1 {
+            return Err(Error::new(
+                count_lit.span(),
+                "number of type parameters must be >= 1",
+            ));
+        }
+        Ok(Self { id, count })
+    }
+}
+
+#[derive(Default)]
+pub struct TraitAttrs {
+    pub blanket_impl_attrs: Vec<TokenStream>,
+    pub dyn_trait_attrs: Vec<TokenStream>,
+    pub collections: HashMap<Ident, usize>,
+}
+
+impl TraitAttrs {
+    pub fn parse(attrs: &mut Vec<Attribute>) -> Result<Self, Error> {
+        let mut parsed = TraitAttrs::default();
+        // FUTURE: use Vec::drain_filter once it's stable
+        let mut i = 0;
+        while i < attrs.len() {
+            if attrs[i].path.is_ident("blanket_impl_attr") {
+                let attr = attrs.remove(i);
+                let group: Group = match syn::parse2(attr.tokens) {
+                    Ok(g) => g,
+                    Err(err) => {
+                        return Err(Error::new(
+                            err.span(),
+                            "expected parenthesis: #[blanket_impl_attr(...)]",
+                        ))
+                    }
+                };
+                let tokens = group.stream();
+                parsed.blanket_impl_attrs.push(quote! {#[#tokens]});
+            } else if attrs[i].path.is_ident("dyn_trait_attr") {
+                let attr = attrs.remove(i);
+                let group: Group = match syn::parse2(attr.tokens) {
+                    Ok(g) => g,
+                    Err(err) => {
+                        return Err(Error::new(
+                            err.span(),
+                            "expected parenthesis: #[dyn_trait_attr(...)]",
+                        ))
+                    }
+                };
+                let tokens = group.stream();
+                parsed.dyn_trait_attrs.push(quote! {#[#tokens]});
+            } else if attrs[i].path.is_ident("collection") {
+                let attr = attrs.remove(i);
+                let coll: Collection = syn::parse2(attr.tokens)?;
+
+                if parsed
+                    .collections
+                    .insert(coll.id.clone(), coll.count)
+                    .is_some()
+                {
+                    return Err(Error::new(
+                        coll.id.span(),
+                        format_args!(
+                            "collection `{}` is defined multiple times for this trait",
+                            coll.id
+                        ),
+                    ));
+                }
+            } else {
+                i += 1;
+            }
+        }
+        Ok(parsed)
+    }
+}
-- 
cgit v1.2.3