aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMartin Fischer <martin@push-f.com>2021-11-15 10:29:52 +0100
committerMartin Fischer <martin@push-f.com>2021-11-18 23:36:01 +0100
commit2a8a0601afcb82d90d0766db5a954b70b10f856d (patch)
tree0271062335d450e151598d4ad9aa327ffa0dfaea /src
publishv0.1.0
Diffstat (limited to 'src')
-rw-r--r--src/lib.rs425
-rw-r--r--src/parse_assoc_type.rs119
-rw-r--r--src/parse_trait_sig.rs505
-rw-r--r--src/syn_utils.rs97
4 files changed, 1146 insertions, 0 deletions
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..1f0432c
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,425 @@
+#![doc = include_str!("../README.md")]
+use std::collections::HashMap;
+
+use proc_macro::TokenStream;
+use proc_macro2::Group;
+use proc_macro2::Ident;
+use quote::format_ident;
+use quote::quote;
+use quote::quote_spanned;
+use quote::ToTokens;
+use syn::parse_macro_input;
+use syn::punctuated::Punctuated;
+use syn::token::Brace;
+use syn::token::Gt;
+use syn::token::Lt;
+use syn::token::Trait;
+use syn::AngleBracketedGenericArguments;
+use syn::Block;
+use syn::Expr;
+use syn::GenericArgument;
+use syn::GenericParam;
+use syn::ImplItemMethod;
+use syn::ItemTrait;
+use syn::Path;
+use syn::PathArguments;
+use syn::PathSegment;
+use syn::Signature;
+use syn::Stmt;
+use syn::TraitBound;
+use syn::TraitItem;
+use syn::TraitItemMethod;
+use syn::Type;
+use syn::TypeParam;
+use syn::TypeParamBound;
+use syn::TypePath;
+use syn::Visibility;
+use syn_utils::TypeMatcher;
+
+use crate::parse_assoc_type::parse_assoc_type;
+use crate::parse_assoc_type::AssocTypeParseError;
+use crate::parse_trait_sig::parse_trait_signature;
+use crate::parse_trait_sig::MethodParseError;
+use crate::parse_trait_sig::SignatureChanges;
+use crate::parse_trait_sig::TypeTransform;
+mod parse_assoc_type;
+mod parse_trait_sig;
+mod syn_utils;
+
+macro_rules! abort {
+ ($span:expr, $message:literal) => {
+ quote_spanned! {$span => compile_error!($message);}.into()
+ };
+}
+
+#[proc_macro_attribute]
+pub fn dynamize(_attr: TokenStream, input: TokenStream) -> TokenStream {
+ let mut original_trait = parse_macro_input!(input as ItemTrait);
+ assert!(original_trait.auto_token.is_none());
+
+ let original_trait_name = original_trait.ident.clone();
+
+ let mut objectifiable_methods: Vec<(Signature, SignatureChanges)> = Vec::new();
+
+ let mut assoc_type_conversions: HashMap<&Ident, &Type> = HashMap::new();
+
+ for item in &original_trait.items {
+ if let TraitItem::Type(assoc_type) = item {
+ match parse_assoc_type(assoc_type) {
+ Err((_, AssocTypeParseError::NoIntoBound)) => continue,
+ Err((span, AssocTypeParseError::AssocTypeInBound)) => {
+ return abort!(span, "dynamize does not support associated types here")
+ }
+ Err((span, AssocTypeParseError::GenericAssociatedType)) => {
+ return abort!(
+ span,
+ "dynamize does not (yet?) support generic associated types"
+ )
+ }
+ Ok((ident, type_)) => {
+ assoc_type_conversions.insert(ident, type_);
+ }
+ }
+ }
+ }
+
+ for item in &original_trait.items {
+ if let TraitItem::Method(method) = item {
+ let mut signature = method.sig.clone();
+ match parse_trait_signature(&mut signature, &assoc_type_conversions) {
+ Err((_, MethodParseError::NonDispatchableMethod)) => continue,
+ Err((span, MethodParseError::UnconvertibleAssocType)) => {
+ return abort!(
+ span,
+ "associated type is either undefined or doesn't have an Into bound"
+ )
+ }
+ Err((span, MethodParseError::AssocTypeInInputs)) => {
+ return abort!(
+ span,
+ "dynamize does not support associated types in parameter types"
+ )
+ }
+ Err((
+ span,
+ MethodParseError::AssocTypeInUnsupportedReturnType
+ | MethodParseError::UnconvertibleAssocTypeInFnInput,
+ )) => return abort!(span, "dynamize does not know how to convert this type"),
+ Err((span, MethodParseError::UnconvertibleAssocTypeInTraitBound)) => {
+ return abort!(span, "dynamize does not support associated types here")
+ }
+ Err((span, MethodParseError::ImplTraitInInputs)) => {
+ return abort!(
+ span,
+ "dynamize does not support impl here, change it to a method generic"
+ )
+ }
+ Ok(parsed_method) => objectifiable_methods.push((signature, parsed_method)),
+ };
+ }
+ }
+
+ let mut method_impls: Vec<ImplItemMethod> = Vec::new();
+
+ let mut blanket_impl_attrs = Vec::new();
+ let mut dyn_trait_attrs = Vec::new();
+
+ // FUTURE: use Vec::drain_filter once it's stable
+ let mut i = 0;
+ while i < original_trait.attrs.len() {
+ if original_trait.attrs[i].path.is_ident("blanket_impl_attr") {
+ let attr = original_trait.attrs.remove(i);
+ let group: Group = match syn::parse2(attr.tokens) {
+ Ok(g) => g,
+ Err(err) => {
+ return abort!(
+ err.span(),
+ "expected parenthesis: #[blanket_impl_attr(...)]"
+ )
+ }
+ };
+ let tokens = group.stream();
+ blanket_impl_attrs.push(quote! {#[#tokens]});
+ } else if original_trait.attrs[i].path.is_ident("dyn_trait_attr") {
+ let attr = original_trait.attrs.remove(i);
+ let group: Group = match syn::parse2(attr.tokens) {
+ Ok(g) => g,
+ Err(err) => {
+ return abort!(err.span(), "expected parenthesis: #[dyn_trait_attr(...)]")
+ }
+ };
+ let tokens = group.stream();
+ dyn_trait_attrs.push(quote! {#[#tokens]});
+ } else {
+ i += 1;
+ }
+ }
+
+ let mut dyn_trait = ItemTrait {
+ ident: format_ident!("Dyn{}", original_trait.ident),
+
+ attrs: Vec::new(),
+ vis: original_trait.vis.clone(),
+ unsafety: original_trait.unsafety,
+ auto_token: None,
+ trait_token: Trait::default(),
+ generics: original_trait.generics.clone(),
+ colon_token: None,
+ supertraits: Punctuated::new(),
+ brace_token: Brace::default(),
+ items: Vec::new(),
+ };
+
+ for (signature, parsed_method) in objectifiable_methods {
+ let mut new_method = TraitItemMethod {
+ attrs: Vec::new(),
+ sig: signature,
+ default: None,
+ semi_token: None,
+ };
+
+ let fun_name = &new_method.sig.ident;
+
+ let args = new_method
+ .sig
+ .inputs
+ .iter()
+ .enumerate()
+ .map(|(idx, arg)| match arg {
+ syn::FnArg::Receiver(_) => quote! {self},
+ syn::FnArg::Typed(pat_type) => match pat_type.pat.as_ref() {
+ syn::Pat::Ident(ident) => match &parsed_method.inputs[idx] {
+ None => ident.ident.to_token_stream(),
+ Some(transforms) => {
+ let args = (0..transforms.len()).map(|i| format_ident!("a{}", i));
+ let mut calls: Vec<_> =
+ args.clone().map(|i| i.into_token_stream()).collect();
+ for i in 0..calls.len() {
+ transforms[i].append_conversion(&mut calls[i]);
+ }
+ let move_opt = new_method.sig.asyncness.map(|_| quote! {move});
+ quote!(#move_opt |#(#args),*| #ident(#(#calls),*))
+ }
+ },
+ _other => {
+ panic!("unexpected");
+ }
+ },
+ });
+
+ // in order for a trait to be object-safe its methods may not have
+ // generics so we convert method generics into trait generics
+ if new_method
+ .sig
+ .generics
+ .params
+ .iter()
+ .any(|p| matches!(p, GenericParam::Type(_)))
+ {
+ // syn::punctuated::Punctuated doesn't have a remove(index)
+ // method so we firstly move all elements to a vector
+ let mut params = Vec::new();
+ while let Some(generic_param) = new_method.sig.generics.params.pop() {
+ params.push(generic_param.into_value());
+ }
+
+ // FUTURE: use Vec::drain_filter once it's stable
+ let mut i = 0;
+ while i < params.len() {
+ if matches!(params[i], GenericParam::Type(_)) {
+ dyn_trait.generics.params.push(params.remove(i));
+ } else {
+ i += 1;
+ }
+ }
+
+ new_method.sig.generics.params.extend(params);
+
+ if dyn_trait.generics.lt_token.is_none() {
+ dyn_trait.generics.lt_token = Some(Lt::default());
+ dyn_trait.generics.gt_token = Some(Gt::default());
+ }
+ }
+
+ let mut expr = quote!(#original_trait_name::#fun_name(#(#args),*));
+ if new_method.sig.asyncness.is_some() {
+ expr.extend(quote! {.await})
+ }
+ parsed_method.return_type.append_conversion(&mut expr);
+
+ method_impls.push(ImplItemMethod {
+ attrs: Vec::new(),
+ vis: Visibility::Inherited,
+ defaultness: None,
+ sig: new_method.sig.clone(),
+ block: Block {
+ brace_token: Brace::default(),
+ stmts: vec![Stmt::Expr(Expr::Verbatim(expr))],
+ },
+ });
+ dyn_trait.items.push(new_method.into());
+ }
+
+ let blanket_impl = generate_blanket_impl(&dyn_trait, &original_trait, &method_impls);
+
+ let expanded = quote! {
+ #original_trait
+
+ #(#dyn_trait_attrs)*
+ #dyn_trait
+
+ #(#blanket_impl_attrs)*
+ #blanket_impl
+ };
+ TokenStream::from(expanded)
+}
+
+fn generate_blanket_impl(
+ dyn_trait: &ItemTrait,
+ original_trait: &ItemTrait,
+ method_impls: &[ImplItemMethod],
+) -> proc_macro2::TokenStream {
+ let mut blanket_generics = dyn_trait.generics.clone();
+ let some_ident = format_ident!("__to_be_dynamized");
+ blanket_generics.params.push(GenericParam::Type(TypeParam {
+ attrs: Vec::new(),
+ ident: some_ident.clone(),
+ colon_token: None,
+ bounds: std::iter::once(TypeParamBound::Trait(TraitBound {
+ paren_token: None,
+ modifier: syn::TraitBoundModifier::None,
+ lifetimes: None,
+ path: Path {
+ leading_colon: None,
+ segments: std::iter::once(path_segment_for_trait(original_trait)).collect(),
+ },
+ }))
+ .collect(),
+ eq_token: None,
+ default: None,
+ }));
+ let (_, type_gen, _where) = dyn_trait.generics.split_for_impl();
+ let dyn_trait_name = &dyn_trait.ident;
+ quote! {
+ impl #blanket_generics #dyn_trait_name #type_gen for #some_ident {
+ #(#method_impls)*
+ }
+ }
+}
+
+struct AssocTypeMatcher;
+impl TypeMatcher<Path> for AssocTypeMatcher {
+ fn match_path<'a>(&self, path: &'a Path) -> Option<&'a Path> {
+ if path.segments.first().unwrap().ident == "Self" {
+ return Some(path);
+ }
+ None
+ }
+}
+
+impl<TypeMatch, PathMatch, T> TypeMatcher<T> for (TypeMatch, PathMatch)
+where
+ TypeMatch: Fn(&Type) -> Option<&T>,
+ PathMatch: Fn(&Path) -> Option<&T>,
+{
+ fn match_type<'a>(&self, t: &'a Type) -> Option<&'a T> {
+ self.0(t)
+ }
+
+ fn match_path<'a>(&self, path: &'a Path) -> Option<&'a T> {
+ self.1(path)
+ }
+}
+
+impl TypeTransform {
+ fn append_conversion(&self, stream: &mut proc_macro2::TokenStream) {
+ match self {
+ TypeTransform::Into => stream.extend(quote! {.into()}),
+ TypeTransform::Map(opt) => {
+ let mut inner = quote!(x);
+ opt.append_conversion(&mut inner);
+ stream.extend(quote! {.map(|x| #inner)})
+ }
+ TypeTransform::Result(ok, err) => {
+ if !matches!(ok.as_ref(), TypeTransform::NoOp) {
+ let mut inner = quote!(x);
+ ok.append_conversion(&mut inner);
+ stream.extend(quote! {.map(|x| #inner)})
+ }
+ if !matches!(err.as_ref(), TypeTransform::NoOp) {
+ let mut inner = quote!(x);
+ err.append_conversion(&mut inner);
+ stream.extend(quote! {.map_err(|x| #inner)})
+ }
+ }
+ _other => {}
+ }
+ }
+}
+
+/// Just a convenience trait for us to avoid match/if-let blocks everywhere.
+trait As<T> {
+ fn get_as(&self) -> Option<&T>;
+ fn get_as_mut(&mut self) -> Option<&mut T>;
+}
+
+impl As<AngleBracketedGenericArguments> for PathArguments {
+ fn get_as(&self) -> Option<&AngleBracketedGenericArguments> {
+ match self {
+ PathArguments::AngleBracketed(args) => Some(args),
+ _other => None,
+ }
+ }
+ fn get_as_mut(&mut self) -> Option<&mut AngleBracketedGenericArguments> {
+ match self {
+ PathArguments::AngleBracketed(args) => Some(args),
+ _other => None,
+ }
+ }
+}
+
+impl As<Type> for GenericArgument {
+ fn get_as(&self) -> Option<&Type> {
+ match self {
+ GenericArgument::Type(typearg) => Some(typearg),
+ _other => None,
+ }
+ }
+ fn get_as_mut(&mut self) -> Option<&mut Type> {
+ match self {
+ GenericArgument::Type(typearg) => Some(typearg),
+ _other => None,
+ }
+ }
+}
+
+fn path_segment_for_trait(sometrait: &ItemTrait) -> PathSegment {
+ PathSegment {
+ ident: sometrait.ident.clone(),
+ arguments: match sometrait.generics.params.is_empty() {
+ true => PathArguments::None,
+ false => PathArguments::AngleBracketed(AngleBracketedGenericArguments {
+ colon2_token: None,
+ lt_token: Lt::default(),
+ args: sometrait
+ .generics
+ .params
+ .iter()
+ .map(|param| match param {
+ GenericParam::Type(type_param) => {
+ GenericArgument::Type(Type::Path(TypePath {
+ path: type_param.ident.clone().into(),
+ qself: None,
+ }))
+ }
+ GenericParam::Lifetime(lifetime_def) => {
+ GenericArgument::Lifetime(lifetime_def.lifetime.clone())
+ }
+ GenericParam::Const(_) => todo!("const generic param not supported"),
+ })
+ .collect(),
+ gt_token: Gt::default(),
+ }),
+ },
+ }
+}
diff --git a/src/parse_assoc_type.rs b/src/parse_assoc_type.rs
new file mode 100644
index 0000000..37fd78c
--- /dev/null
+++ b/src/parse_assoc_type.rs
@@ -0,0 +1,119 @@
+use proc_macro2::Span;
+use syn::spanned::Spanned;
+use syn::{GenericArgument, Ident, PathArguments, PathSegment, TraitItemType, Type};
+
+use crate::syn_utils::{find_in_type, trait_bounds};
+use crate::AssocTypeMatcher;
+
+#[derive(Debug)]
+pub enum AssocTypeParseError {
+ AssocTypeInBound,
+ GenericAssociatedType,
+ NoIntoBound,
+}
+
+pub fn parse_assoc_type(
+ assoc_type: &TraitItemType,
+) -> Result<(&Ident, &Type), (Span, AssocTypeParseError)> {
+ for bound in trait_bounds(&assoc_type.bounds) {
+ if let PathSegment {
+ ident,
+ arguments: PathArguments::AngleBracketed(args),
+ } = bound.path.segments.first().unwrap()
+ {
+ if ident == "Into" && args.args.len() == 1 {
+ if let GenericArgument::Type(into_type) = args.args.first().unwrap() {
+ // provide a better error message for type A: Into<Self::B>
+ if find_in_type(into_type, &AssocTypeMatcher).is_some() {
+ return Err((into_type.span(), AssocTypeParseError::AssocTypeInBound));
+ }
+
+ // TODO: support lifetime GATs (see the currently failing tests/gats.rs)
+ if !assoc_type.generics.params.is_empty() {
+ return Err((
+ assoc_type.generics.params.span(),
+ AssocTypeParseError::GenericAssociatedType,
+ ));
+ }
+
+ return Ok((&assoc_type.ident, into_type));
+ }
+ }
+ }
+ }
+ Err((assoc_type.span(), AssocTypeParseError::NoIntoBound))
+}
+
+#[cfg(test)]
+mod tests {
+ use quote::quote;
+ use syn::{TraitItemType, Type};
+
+ use crate::parse_assoc_type::{parse_assoc_type, AssocTypeParseError};
+
+ #[test]
+ fn ok() {
+ let type1: TraitItemType = syn::parse2(quote! {
+ type A: Into<String>;
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_assoc_type(&type1),
+ Ok((id, Type::Path(path)))
+ if id == "A" && path.path.is_ident("String")
+ ));
+ }
+
+ #[test]
+ fn err_no_bound() {
+ let type1: TraitItemType = syn::parse2(quote! {
+ type A;
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_assoc_type(&type1),
+ Err((_, AssocTypeParseError::NoIntoBound))
+ ));
+ }
+
+ #[test]
+ fn err_assoc_type_in_bound() {
+ let type1: TraitItemType = syn::parse2(quote! {
+ type A: Into<Self::B>;
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_assoc_type(&type1),
+ Err((_, AssocTypeParseError::AssocTypeInBound))
+ ));
+ }
+
+ #[test]
+ fn err_gat_type() {
+ let type1: TraitItemType = syn::parse2(quote! {
+ type A<X>: Into<Foobar<X>>;
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_assoc_type(&type1),
+ Err((_, AssocTypeParseError::GenericAssociatedType))
+ ));
+ }
+
+ #[test]
+ fn err_gat_lifetime() {
+ let type1: TraitItemType = syn::parse2(quote! {
+ type A<'a>: Into<Foobar<'a>>;
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_assoc_type(&type1),
+ Err((_, AssocTypeParseError::GenericAssociatedType))
+ ));
+ }
+}
diff --git a/src/parse_trait_sig.rs b/src/parse_trait_sig.rs
new file mode 100644
index 0000000..55a3214
--- /dev/null
+++ b/src/parse_trait_sig.rs
@@ -0,0 +1,505 @@
+use std::collections::HashMap;
+
+use proc_macro2::Span;
+use syn::{
+ spanned::Spanned, FnArg, Ident, PathArguments, PredicateType, Receiver, ReturnType, Type,
+ TypePath, WherePredicate,
+};
+use syn::{GenericParam, Signature, TypeImplTrait, TypeParamBound};
+
+use crate::syn_utils::{find_in_path, find_in_type, trait_bounds, TypeMatcher};
+use crate::{As, AssocTypeMatcher};
+
+#[derive(Debug, Clone)]
+pub enum TypeTransform {
+ NoOp,
+ Into,
+ Map(Box<TypeTransform>),
+ Result(Box<TypeTransform>, Box<TypeTransform>),
+}
+
+#[derive(Debug)]
+pub enum MethodParseError {
+ NonDispatchableMethod,
+ AssocTypeInInputs,
+ ImplTraitInInputs,
+ AssocTypeInUnsupportedReturnType,
+ UnconvertibleAssocTypeInFnInput,
+ UnconvertibleAssocTypeInTraitBound,
+ UnconvertibleAssocType,
+}
+
+struct ImplTraitMatcher;
+impl TypeMatcher<TypeImplTrait> for ImplTraitMatcher {
+ fn match_type<'a>(&self, t: &'a Type) -> Option<&'a TypeImplTrait> {
+ if let Type::ImplTrait(impltrait) = t {
+ Some(impltrait)
+ } else {
+ None
+ }
+ }
+}
+
+pub struct SignatureChanges {
+ pub return_type: TypeTransform,
+ pub inputs: Vec<Option<Vec<TypeTransform>>>,
+}
+
+pub fn parse_trait_signature(
+ signature: &mut Signature,
+ assoc_type_conversions: &HashMap<&Ident, &Type>,
+) -> Result<SignatureChanges, (Span, MethodParseError)> {
+ let assoc_type_conversions = AssocTypeConversions(assoc_type_conversions);
+
+ if is_non_dispatchable(signature) {
+ return Err((signature.span(), MethodParseError::NonDispatchableMethod));
+ }
+
+ // provide better error messages for associated types in params
+ for input in &signature.inputs {
+ if let FnArg::Typed(pattype) = input {
+ if find_in_type(&pattype.ty, &AssocTypeMatcher).is_some() {
+ return Err((pattype.ty.span(), MethodParseError::AssocTypeInInputs));
+ }
+ if let Some(impl_trait) = find_in_type(&pattype.ty, &ImplTraitMatcher) {
+ return Err((impl_trait.span(), MethodParseError::ImplTraitInInputs));
+ }
+ }
+ }
+
+ let mut type_param_transforms = HashMap::new();
+ let mut input_transforms = Vec::new();
+
+ for generic_param in &mut signature.generics.params {
+ if let GenericParam::Type(type_param) = generic_param {
+ for bound in &mut type_param.bounds {
+ if let TypeParamBound::Trait(bound) = bound {
+ if bound.path.segments.len() == 1 {
+ let segment = bound.path.segments.first_mut().unwrap();
+
+ if let PathArguments::Parenthesized(args) = &mut segment.arguments {
+ if segment.ident == "Fn"
+ || segment.ident == "FnOnce"
+ || segment.ident == "FnMut"
+ {
+ let mut transforms = Vec::new();
+ for input_type in &mut args.inputs {
+ match assoc_type_conversions.parse_type_path(input_type) {
+ Ok(ret_type) => {
+ transforms.push(ret_type);
+ }
+ Err(TransformError::UnconvertibleAssocType(span)) => {
+ return Err((
+ span,
+ MethodParseError::UnconvertibleAssocType,
+ ));
+ }
+ Err(TransformError::AssocTypeInUnsupportedType(span)) => {
+ return Err((
+ span,
+ MethodParseError::UnconvertibleAssocTypeInFnInput,
+ ));
+ }
+ }
+ }
+ if transforms.iter().any(|t| !matches!(t, TypeTransform::NoOp)) {
+ type_param_transforms.insert(&type_param.ident, transforms);
+ }
+ }
+ }
+ }
+ if let Some(path) = find_in_path(&bound.path, &AssocTypeMatcher) {
+ return Err((
+ path.span(),
+ MethodParseError::UnconvertibleAssocTypeInTraitBound,
+ ));
+ }
+ }
+ }
+ }
+ }
+
+ for input in &signature.inputs {
+ if let FnArg::Typed(pattype) = input {
+ if let Type::Path(path) = &*pattype.ty {
+ if let Some(ident) = path.path.get_ident() {
+ input_transforms.push(type_param_transforms.get(ident).map(|x| (*x).clone()));
+ continue;
+ }
+ }
+ }
+ input_transforms.push(None);
+ }
+
+ let return_type = match &mut signature.output {
+ ReturnType::Type(_, og_type) => match assoc_type_conversions.parse_type_path(og_type) {
+ Ok(ret_type) => ret_type,
+ Err(TransformError::UnconvertibleAssocType(span)) => {
+ return Err((span, MethodParseError::UnconvertibleAssocType));
+ }
+ Err(TransformError::AssocTypeInUnsupportedType(span)) => {
+ return Err((span, MethodParseError::AssocTypeInUnsupportedReturnType));
+ }
+ },
+ ReturnType::Default => TypeTransform::NoOp,
+ };
+ Ok(SignatureChanges {
+ return_type,
+ inputs: input_transforms,
+ })
+}
+
+struct AssocTypeConversions<'a>(&'a HashMap<&'a Ident, &'a Type>);
+
+enum TransformError {
+ UnconvertibleAssocType(Span),
+ AssocTypeInUnsupportedType(Span),
+}
+
+impl AssocTypeConversions<'_> {
+ fn parse_type_path(&self, type_: &mut Type) -> Result<TypeTransform, TransformError> {
+ let assoc_span = match find_in_type(type_, &AssocTypeMatcher) {
+ Some(path) => path.span(),
+ None => return Ok(TypeTransform::NoOp),
+ };
+
+ if let Type::Path(TypePath { path, qself: None }) = type_ {
+ let ident = &path.segments.first().unwrap().ident;
+
+ // TODO: support &mut dyn Iterator<Item = Self::A>
+ // conversion to Box<dyn Iterator<Item = Whatever>> via .map(Into::into)
+
+ if ident == "Self" && path.segments.len() == 2 {
+ let ident = &path.segments.last().unwrap().ident;
+ *type_ = (*self
+ .0
+ .get(&ident)
+ .ok_or_else(|| TransformError::UnconvertibleAssocType(ident.span()))?)
+ .clone();
+ return Ok(TypeTransform::Into);
+ } else if ident == "Option" && path.segments.len() == 1 {
+ let first_seg = path.segments.first_mut().unwrap();
+
+ if let Some(args) = first_seg.arguments.get_as_mut() {
+ if args.args.len() == 1 {
+ if let Some(generic_type) = args.args.first_mut().unwrap().get_as_mut() {
+ if find_in_type(generic_type, &AssocTypeMatcher).is_some() {
+ return Ok(TypeTransform::Map(
+ self.parse_type_path(generic_type)?.into(),
+ ));
+ }
+ }
+ }
+ }
+ } else if ident == "Result" && path.segments.len() == 1 {
+ let first_seg = path.segments.first_mut().unwrap();
+ if let Some(args) = first_seg.arguments.get_as_mut() {
+ if args.args.len() == 2 {
+ let mut args_iter = args.args.iter_mut();
+ if let (Some(ok_type), Some(err_type)) = (
+ args_iter.next().unwrap().get_as_mut(),
+ args_iter.next().unwrap().get_as_mut(),
+ ) {
+ if find_in_type(ok_type, &AssocTypeMatcher).is_some()
+ || find_in_type(err_type, &AssocTypeMatcher).is_some()
+ {
+ return Ok(TypeTransform::Result(
+ self.parse_type_path(ok_type)?.into(),
+ self.parse_type_path(err_type)?.into(),
+ ));
+ }
+ }
+ }
+ }
+ } else {
+ let last_seg = &path.segments.last().unwrap();
+ if last_seg.ident == "Result" {
+ let last_seg = path.segments.last_mut().unwrap();
+ if let Some(args) = last_seg.arguments.get_as_mut() {
+ if args.args.len() == 1 {
+ if let Some(generic_type) = args.args.first_mut().unwrap().get_as_mut()
+ {
+ if find_in_type(generic_type, &AssocTypeMatcher).is_some() {
+ return Ok(TypeTransform::Map(
+ self.parse_type_path(generic_type)?.into(),
+ ));
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // the type contains an associated type but we
+ // don't know how to deal with it so we abort
+ Err(TransformError::AssocTypeInUnsupportedType(assoc_span))
+ }
+}
+
+fn is_non_dispatchable(signature: &Signature) -> bool {
+ // non-dispatchable: fn example(&self) where Self: Sized;
+ if let Some(where_clause) = &signature.generics.where_clause {
+ if where_clause
+ .predicates
+ .iter()
+ .any(bounds_self_and_has_bound_sized)
+ {
+ return true;
+ }
+ }
+
+ // non-dispatchable: fn example();
+ if signature.inputs.is_empty() {
+ return true;
+ }
+
+ // non-dispatchable: fn example(arg: Type);
+ if matches!(signature.inputs.first(), Some(FnArg::Typed(_))) {
+ return true;
+ }
+
+ // non-dispatchable: fn example(self);
+ if matches!(
+ signature.inputs.first(),
+ Some(FnArg::Receiver(Receiver {
+ reference: None,
+ ..
+ }))
+ ) {
+ return true;
+ }
+ false
+}
+
+/// Returns true if the bounded type is `Self` and the bounds contain `Sized`.
+fn bounds_self_and_has_bound_sized(predicate: &WherePredicate) -> bool {
+ matches!(
+ predicate,
+ WherePredicate::Type(PredicateType {
+ bounded_ty: Type::Path(TypePath { path, .. }),
+ bounds,
+ ..
+ })
+ if path.is_ident("Self")
+ && trait_bounds(bounds).any(|b| b.path.is_ident("Sized"))
+ )
+}
+
+#[cfg(test)]
+mod tests {
+ use std::collections::HashMap;
+
+ use quote::{format_ident, quote};
+ use syn::{TraitItemMethod, Type};
+
+ use crate::parse_trait_sig::{
+ parse_trait_signature, MethodParseError, SignatureChanges, TypeTransform,
+ };
+
+ #[test]
+ fn ok_void() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(&self);
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Ok(SignatureChanges {
+ return_type: TypeTransform::NoOp,
+ ..
+ })
+ ));
+ }
+
+ #[test]
+ fn ok_assoc_type() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(&self) -> Self::A;
+ })
+ .unwrap();
+
+ let mut assoc_type_map = HashMap::new();
+ let ident = format_ident!("A");
+ let dest = Type::Verbatim(quote! {Example});
+ assoc_type_map.insert(&ident, &dest);
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &assoc_type_map),
+ Ok(SignatureChanges {
+ return_type: TypeTransform::Into,
+ ..
+ })
+ ));
+ }
+
+ #[test]
+ fn err_unconvertible_assoc_type() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(&self) -> Self::A;
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::UnconvertibleAssocType))
+ ));
+ }
+
+ #[test]
+ fn err_non_dispatchable_assoc_function_no_args() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test();
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::NonDispatchableMethod))
+ ));
+ }
+
+ #[test]
+ fn err_non_dispatchable_assoc_function_with_args() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(arg: Type);
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::NonDispatchableMethod))
+ ));
+ }
+
+ #[test]
+ fn err_non_dispatchable_consume_self() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(self);
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::NonDispatchableMethod))
+ ));
+ }
+
+ #[test]
+ fn err_non_dispatchable_where_self_sized() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(&self) where Self: Sized;
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::NonDispatchableMethod))
+ ));
+ }
+
+ #[test]
+ fn err_assoc_type_in_unsupported_return() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(&self) -> Foo<Self::A>;
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::AssocTypeInUnsupportedReturnType))
+ ));
+ }
+
+ #[test]
+ fn err_assoc_type_in_unsupported_return_in_opt() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(&self) -> Option<Foo<Self::A>>;
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::AssocTypeInUnsupportedReturnType))
+ ));
+ }
+
+ #[test]
+ fn err_assoc_type_in_unsupported_return_in_ok() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(&self) -> Result<Foo<Self::A>, Error>;
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::AssocTypeInUnsupportedReturnType))
+ ));
+ }
+
+ #[test]
+ fn err_assoc_type_in_unsupported_return_in_err() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(&self) -> Result<Ok, Foo<Self::A>>;
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::AssocTypeInUnsupportedReturnType))
+ ));
+ }
+
+ #[test]
+ fn err_assoc_type_in_input() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(&self, x: Self::A);
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::AssocTypeInInputs))
+ ));
+ }
+
+ #[test]
+ fn err_assoc_type_in_input_opt() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(&self, x: Option<Self::A>);
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::AssocTypeInInputs))
+ ));
+ }
+
+ #[test]
+ fn err_impl_in_input() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(&self, arg: Option<impl SomeTrait>);
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::ImplTraitInInputs))
+ ));
+ }
+
+ #[test]
+ fn err_assoc_type_in_generic() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test<F: Fn(Foo<Self::A>)>(&self, fun: F);
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::UnconvertibleAssocTypeInFnInput))
+ ));
+ }
+}
diff --git a/src/syn_utils.rs b/src/syn_utils.rs
new file mode 100644
index 0000000..4588186
--- /dev/null
+++ b/src/syn_utils.rs
@@ -0,0 +1,97 @@
+use syn::{
+ punctuated::Punctuated, GenericArgument, Path, PathArguments, ReturnType, TraitBound, Type,
+ TypeParamBound,
+};
+
+pub trait TypeMatcher<T> {
+ fn match_type<'a>(&self, t: &'a Type) -> Option<&'a T> {
+ None
+ }
+ fn match_path<'a>(&self, path: &'a Path) -> Option<&'a T> {
+ None
+ }
+}
+
+pub fn trait_bounds<T>(
+ bounds: &Punctuated<TypeParamBound, T>,
+) -> impl Iterator<Item = &TraitBound> {
+ bounds.iter().filter_map(|b| match b {
+ TypeParamBound::Trait(t) => Some(t),
+ TypeParamBound::Lifetime(_) => None,
+ })
+}
+
+pub fn find_in_type<'a, T>(t: &'a Type, matcher: &dyn TypeMatcher<T>) -> Option<&'a T> {
+ if let Some(ret) = matcher.match_type(t) {
+ return Some(ret);
+ }
+ match t {
+ Type::Array(array) => find_in_type(&array.elem, matcher),
+ Type::BareFn(fun) => {
+ for input in &fun.inputs {
+ if let Some(ret) = find_in_type(&input.ty, matcher) {
+ return Some(ret);
+ }
+ }
+ if let ReturnType::Type(_, t) = &fun.output {
+ return find_in_type(t, matcher);
+ }
+ None
+ }
+ Type::Group(group) => find_in_type(&group.elem, matcher),
+ Type::ImplTrait(impltrait) => {
+ for bound in &impltrait.bounds {
+ if let TypeParamBound::Trait(bound) = bound {
+ if let Some(ret) = find_in_path(&bound.path, matcher) {
+ return Some(ret);
+ }
+ }
+ }
+ None
+ }
+ Type::Infer(_) => None,
+ Type::Macro(_) => None,
+ Type::Paren(paren) => find_in_type(&paren.elem, matcher),
+ Type::Path(path) => find_in_path(&path.path, matcher),
+ Type::Ptr(ptr) => find_in_type(&ptr.elem, matcher),
+ Type::Reference(reference) => find_in_type(&reference.elem, matcher),
+ Type::Slice(slice) => find_in_type(&slice.elem, matcher),
+ Type::TraitObject(traitobj) => {
+ for bound in &traitobj.bounds {
+ if let TypeParamBound::Trait(bound) = bound {
+ if let Some(ret) = find_in_path(&bound.path, matcher) {
+ return Some(ret);
+ }
+ }
+ }
+ None
+ }
+ Type::Tuple(tuple) => {
+ for elem in &tuple.elems {
+ if let Some(ret) = find_in_type(elem, matcher) {
+ return Some(ret);
+ }
+ }
+ None
+ }
+ _other => None,
+ }
+}
+
+pub fn find_in_path<'a, T>(path: &'a Path, matcher: &dyn TypeMatcher<T>) -> Option<&'a T> {
+ if let Some(ret) = matcher.match_path(path) {
+ return Some(ret);
+ }
+ for segment in &path.segments {
+ if let PathArguments::AngleBracketed(args) = &segment.arguments {
+ for arg in &args.args {
+ if let GenericArgument::Type(t) = arg {
+ if let Some(ret) = find_in_type(t, matcher) {
+ return Some(ret);
+ }
+ }
+ }
+ }
+ }
+ None
+}