From 346113bbebddbd199b61249957c7569514071e89 Mon Sep 17 00:00:00 2001
From: Martin Fischer <martin@push-f.com>
Date: Mon, 22 Nov 2021 15:57:17 +0100
Subject: support other collections via #[collection(...)]

---
 README.md                                         | 26 ++++++++++++
 src/lib.rs                                        | 48 ++++++++++++++++++++++-
 src/transform.rs                                  |  5 +++
 tests/tests.rs                                    | 25 ++++++++++++
 ui-tests/src/bin/attr_collection_duplicate.rs     |  6 +++
 ui-tests/src/bin/attr_collection_duplicate.stderr |  5 +++
 ui-tests/src/bin/attr_collection_zero.rs          |  5 +++
 ui-tests/src/bin/attr_collection_zero.stderr      |  5 +++
 8 files changed, 124 insertions(+), 1 deletion(-)
 create mode 100644 ui-tests/src/bin/attr_collection_duplicate.rs
 create mode 100644 ui-tests/src/bin/attr_collection_duplicate.stderr
 create mode 100644 ui-tests/src/bin/attr_collection_zero.rs
 create mode 100644 ui-tests/src/bin/attr_collection_zero.stderr

diff --git a/README.md b/README.md
index 6051fb5..555e7f1 100644
--- a/README.md
+++ b/README.md
@@ -147,3 +147,29 @@ trait Client: Sync {
 
 Note that it is important that the `#[dynamize]` attribute comes before the
 `#[async_trait]` attribute, since dynamize must run before async_trait.
+
+## Using dynamize with other collections
+
+Dynamize automatically recognizes collections from the standard library like
+`Vec<_>` and `HashMap<_, _>`. Dynamize can also work with other collection
+types as long as they implement `IntoIterator` and `FromIterator`, for example
+dynamize can be used with [indexmap](https://crates.io/crates/indexmap) as
+follows:
+
+```rust ignore
+#[dynamize::dynamize]
+#[collection(IndexMap, 2)]
+trait Trait {
+    type A: Into<String>;
+    type B: Into<i32>;
+
+    fn example(&self) -> IndexMap<Self::A, Self::B>;
+}
+```
+
+The passed number tells dynamize how many generic type parameters to expect.
+
+* for 1 dynamize expects: `Type<A>: IntoIterator<Item=A> + FromIterator<A>`
+* for 2 dynamize expects: `Type<A,B>: IntoIterator<Item=(A,B)> + FromIterator<(A,B)>`
+* for 3 dynamize expects: `Type<A,B,C>: IntoIterator<Item=(A,B,C)> + FromIterator<(A,B,C)>`
+* etc ...
diff --git a/src/lib.rs b/src/lib.rs
index 13cbce7..05fdb81 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -8,6 +8,8 @@ use quote::format_ident;
 use quote::quote;
 use quote::quote_spanned;
 use quote::ToTokens;
+use syn::parenthesized;
+use syn::parse::Parse;
 use syn::parse_macro_input;
 use syn::spanned::Spanned;
 use syn::token::Brace;
@@ -16,16 +18,20 @@ use syn::token::Lt;
 use syn::token::Trait;
 use syn::AngleBracketedGenericArguments;
 use syn::Block;
+use syn::Error;
 use syn::Expr;
 use syn::GenericArgument;
 use syn::GenericParam;
+use syn::Ident;
 use syn::ImplItemMethod;
 use syn::ItemTrait;
+use syn::LitInt;
 use syn::Path;
 use syn::PathArguments;
 use syn::PathSegment;
 use syn::Signature;
 use syn::Stmt;
+use syn::Token;
 use syn::TraitBound;
 use syn::TraitItem;
 use syn::TraitItemMethod;
@@ -53,13 +59,37 @@ mod syn_utils;
 mod transform;
 
 macro_rules! abort {
-    ($span:expr, $message:literal $(,$args:tt)*) => {{
+    ($span:expr, $message:literal $(,$args:expr)*) => {{
         let msg = format!($message $(,$args)*);
         let tokens = quote_spanned! {$span => compile_error!(#msg);}.into();
         tokens
     }};
 }
 
+struct Collection {
+    id: Ident,
+    count: usize,
+}
+
+impl Parse for Collection {
+    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
+        let inner;
+        parenthesized!(inner in input);
+
+        let id: Ident = inner.parse()?;
+        let _: Token![,] = inner.parse()?;
+        let count_lit: LitInt = inner.parse()?;
+        let count: usize = count_lit.base10_parse()?;
+        if count < 1 {
+            return Err(Error::new(
+                count_lit.span(),
+                "number of type parameters must be >= 1",
+            ));
+        }
+        Ok(Self { id, count })
+    }
+}
+
 #[proc_macro_attribute]
 pub fn dynamize(_attr: TokenStream, input: TokenStream) -> TokenStream {
     let mut original_trait = parse_macro_input!(input as ItemTrait);
@@ -100,6 +130,22 @@ pub fn dynamize(_attr: TokenStream, input: TokenStream) -> TokenStream {
             };
             let tokens = group.stream();
             dyn_trait_attrs.push(quote! {#[#tokens]});
+        } else if original_trait.attrs[i].path.is_ident("collection") {
+            let attr = original_trait.attrs.remove(i);
+            let tokens = attr.tokens.into();
+            let coll = parse_macro_input!(tokens as Collection);
+
+            if type_converter
+                .collections
+                .insert(coll.id.clone(), coll.count)
+                .is_some()
+            {
+                return abort!(
+                    coll.id.span(),
+                    "collection `{}` is defined multiple times for this trait",
+                    coll.id
+                );
+            }
         } else {
             i += 1;
         }
diff --git a/src/transform.rs b/src/transform.rs
index 05866b6..11a98c2 100644
--- a/src/transform.rs
+++ b/src/transform.rs
@@ -17,6 +17,7 @@ use crate::{
 #[derive(Default)]
 pub struct TypeConverter<'a> {
     pub assoc_type_conversions: HashMap<Ident, DestType<'a>>,
+    pub collections: HashMap<Ident, usize>,
 }
 
 #[derive(Debug)]
@@ -39,6 +40,10 @@ impl TypeConverter<'_> {
     /// ... etc. A return type of None means the type isn't recognized.
     #[rustfmt::skip]
     fn get_collection_type_count(&self, ident: &Ident) -> Option<usize> {
+        if let Some(count) = self.collections.get(ident) {
+            return Some(*count);
+        }
+
         // when adding a type here don't forget to document it in the README
         if ident == "Vec"        { return Some(1); }
         if ident == "VecDeque"   { return Some(1); }
diff --git a/tests/tests.rs b/tests/tests.rs
index 925fe72..1bff07f 100644
--- a/tests/tests.rs
+++ b/tests/tests.rs
@@ -251,3 +251,28 @@ trait FunIter {
 
     fn foobar2<H: Fn(&mut dyn Iterator<Item = &mut dyn Iterator<Item = Self::A>>)>(&mut self, f: H);
 }
+
+struct MyCollection<A, B, C>(A, B, C);
+impl<A, B, C> IntoIterator for MyCollection<A, B, C> {
+    type Item = (A, B, C);
+
+    type IntoIter = Box<dyn Iterator<Item = (A, B, C)>>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        todo!()
+    }
+}
+
+impl<A, B, C> FromIterator<(A, B, C)> for MyCollection<A, B, C> {
+    fn from_iter<T: IntoIterator<Item = (A, B, C)>>(_iter: T) -> Self {
+        todo!()
+    }
+}
+
+#[dynamize::dynamize]
+#[collection(MyCollection, 3)]
+trait CustomCollection {
+    type A: Into<String>;
+
+    fn test(&self) -> MyCollection<Self::A, Self::A, Self::A>;
+}
diff --git a/ui-tests/src/bin/attr_collection_duplicate.rs b/ui-tests/src/bin/attr_collection_duplicate.rs
new file mode 100644
index 0000000..27e412d
--- /dev/null
+++ b/ui-tests/src/bin/attr_collection_duplicate.rs
@@ -0,0 +1,6 @@
+#[dynamize::dynamize]
+#[collection(Foo, 1)]
+#[collection(Foo, 2)]
+trait Trait {}
+
+fn main() {}
diff --git a/ui-tests/src/bin/attr_collection_duplicate.stderr b/ui-tests/src/bin/attr_collection_duplicate.stderr
new file mode 100644
index 0000000..7d91236
--- /dev/null
+++ b/ui-tests/src/bin/attr_collection_duplicate.stderr
@@ -0,0 +1,5 @@
+error: collection `Foo` is defined multiple times for this trait
+ --> src/bin/attr_collection_duplicate.rs:3:14
+  |
+3 | #[collection(Foo, 2)]
+  |              ^^^
diff --git a/ui-tests/src/bin/attr_collection_zero.rs b/ui-tests/src/bin/attr_collection_zero.rs
new file mode 100644
index 0000000..6f45ea8
--- /dev/null
+++ b/ui-tests/src/bin/attr_collection_zero.rs
@@ -0,0 +1,5 @@
+#[dynamize::dynamize]
+#[collection(Foo, 0)]
+trait Trait {}
+
+fn main() {}
diff --git a/ui-tests/src/bin/attr_collection_zero.stderr b/ui-tests/src/bin/attr_collection_zero.stderr
new file mode 100644
index 0000000..47a87c9
--- /dev/null
+++ b/ui-tests/src/bin/attr_collection_zero.stderr
@@ -0,0 +1,5 @@
+error: number of type parameters must be >= 1
+ --> src/bin/attr_collection_zero.rs:2:19
+  |
+2 | #[collection(Foo, 0)]
+  |                   ^
-- 
cgit v1.2.3