use std::collections::HashMap; use proc_macro2::Span; use quote::quote; use syn::{ spanned::Spanned, GenericArgument, Generics, Ident, PathArguments, TraitBound, Type, TypeParamBound, TypePath, TypeReference, TypeTraitObject, WherePredicate, }; use crate::{ filter_map_assoc_paths, match_assoc_type, parse_assoc_type::{BoxType, DestType}, parse_trait_sig::{MethodError, TypeTransform}, syn_utils::{iter_path, iter_type, type_arguments_mut}, }; #[derive(Default)] pub struct TypeConverter<'a> { pub assoc_type_conversions: HashMap>, pub collections: HashMap, } #[derive(Debug)] pub enum TransformError { AssocTypeWithoutDestType, UnsupportedType, ExpectedAtLeastNTypes(usize), AssocTypeAfterFirstNTypes(usize, Ident), } impl TypeConverter<'_> { /// A return type of Some(1) means that the type implements /// IntoIterator and FromIterator /// with T1 being its first generic type parameter. /// /// A return type of Some(2) means that the type implements /// IntoIterator and FromIterator<(T1, T2)> /// with T1 and T2 being its first two generic type parameters. /// /// ... etc. A return type of None means the type isn't recognized. #[rustfmt::skip] fn get_collection_type_count(&self, ident: &Ident) -> Option { if let Some(count) = self.collections.get(ident) { return Some(*count); } // when adding a type here don't forget to document it in the README if ident == "Vec" { return Some(1); } if ident == "VecDeque" { return Some(1); } if ident == "LinkedList" { return Some(1); } if ident == "HashSet" { return Some(1); } if ident == "BinaryHeap" { return Some(1); } if ident == "BTreeSet" { return Some(1); } if ident == "HashMap" { return Some(2); } if ident == "BTreeMap" { return Some(2); } None } pub fn convert_type(&self, type_: &mut Type) -> Result { if !iter_type(type_).any(match_assoc_type) { return Ok(TypeTransform::NoOp); } if let Type::Tuple(tuple) = type_ { let mut types = Vec::new(); for elem in &mut tuple.elems { types.push(self.convert_type(elem)?); } return Ok(TypeTransform::Tuple(types)); } else if let Type::Reference(TypeReference { lifetime: None, mutability: Some(_), elem, .. }) = type_ { if let Type::TraitObject(TypeTraitObject { dyn_token: Some(_), bounds, }) = elem.as_mut() { if bounds.len() == 1 { if let TypeParamBound::Trait(bound) = &mut bounds[0] { if bound.path.segments.len() == 1 { let first = &mut bound.path.segments[0]; if first.ident == "Iterator" { if let PathArguments::AngleBracketed(args) = &mut first.arguments { if args.args.len() == 1 { if let GenericArgument::Binding(binding) = &mut args.args[0] { if binding.ident == "Item" && iter_type(&binding.ty).any(match_assoc_type) { let inner = self.convert_type(&mut binding.ty)?; let box_type = BoxType { inner: quote! {#elem}, placeholder_lifetime: true, }; *type_ = Type::Verbatim(quote! {#box_type}); return Ok(TypeTransform::Iterator( box_type, inner.into(), )); } } } } } } } } } } if let Type::Path(TypePath { path, qself: None }) = type_ { if path.segments[0].ident == "Self" { if path.segments.len() == 2 { let ident = &path.segments.last().unwrap().ident; let dest_type = self .assoc_type_conversions .get(ident) .ok_or_else(|| (ident.span(), TransformError::AssocTypeWithoutDestType))?; *type_ = dest_type.get_dest(); return Ok(dest_type.type_transformation()); } } else { let path_len = path.segments.len(); let last_seg = path.segments.last_mut().unwrap(); if let PathArguments::AngleBracketed(args) = &mut last_seg.arguments { let mut args: Vec<_> = type_arguments_mut(&mut args.args).collect(); if path_len == 1 { if let Some(type_count) = self.get_collection_type_count(&last_seg.ident) { if args.len() < type_count { return Err(( last_seg.span(), TransformError::ExpectedAtLeastNTypes(type_count), )); } for i in type_count..args.len() { if iter_type(args[i]).any(match_assoc_type) { return Err(( args[i].span(), TransformError::AssocTypeAfterFirstNTypes( type_count, last_seg.ident.clone(), ), )); } } let mut transforms = Vec::new(); for arg in args { transforms.push(self.convert_type(arg)?); } return Ok(TypeTransform::IntoIterMapCollect(transforms)); } } if args.len() == 1 { if iter_type(args[0]).any(match_assoc_type) && ((last_seg.ident == "Option" && path_len == 1) || last_seg.ident == "Result") { return Ok(TypeTransform::Map(self.convert_type(args[0])?.into())); } } else if args.len() == 2 && path_len == 1 && (iter_type(args[0]).any(match_assoc_type) || iter_type(args[1]).any(match_assoc_type)) && last_seg.ident == "Result" { return Ok(TypeTransform::Result( self.convert_type(args[0])?.into(), self.convert_type(args[1])?.into(), )); } } } } // the type contains an associated type but we // don't know how to deal with it so we abort Err((type_.span(), TransformError::UnsupportedType)) } } pub fn dynamize_function_bounds( generics: &mut Generics, type_converter: &TypeConverter, ) -> Result>, (Span, MethodError)> { let mut type_param_transforms = HashMap::new(); for type_param in generics.type_params_mut() { for bound in &mut type_param.bounds { if let TypeParamBound::Trait(bound) = bound { dynamize_trait_bound( bound, type_converter, &type_param.ident, &mut type_param_transforms, )?; } } } if let Some(where_clause) = &mut generics.where_clause { for predicate in &mut where_clause.predicates { if let WherePredicate::Type(predicate_type) = predicate { if let Type::Path(path) = &mut predicate_type.bounded_ty { if let Some(ident) = path.path.get_ident() { for bound in &mut predicate_type.bounds { if let TypeParamBound::Trait(bound) = bound { dynamize_trait_bound( bound, type_converter, ident, &mut type_param_transforms, )?; } } continue; } } // just to provide better error messages if let Some(assoc_type) = iter_type(&predicate_type.bounded_ty).find_map(filter_map_assoc_paths) { return Err((assoc_type.span(), MethodError::UnconvertedAssocType)); } // just to provide better error messages for bound in &mut predicate_type.bounds { if let TypeParamBound::Trait(bound) = bound { if let Some(assoc_type) = iter_path(&bound.path).find_map(filter_map_assoc_paths) { return Err((assoc_type.span(), MethodError::UnconvertedAssocType)); } } } } } } Ok(type_param_transforms) } fn dynamize_trait_bound( bound: &mut TraitBound, type_converter: &TypeConverter, type_ident: &Ident, type_param_transforms: &mut HashMap>, ) -> Result<(), (Span, MethodError)> { if bound.path.segments.len() == 1 { let segment = &mut bound.path.segments[0]; if let PathArguments::Parenthesized(args) = &mut segment.arguments { if segment.ident == "Fn" || segment.ident == "FnOnce" || segment.ident == "FnMut" { let mut transforms = Vec::new(); for input_type in &mut args.inputs { match type_converter.convert_type(input_type) { Ok(ret_type) => { transforms.push(ret_type); } Err((span, err)) => { return Err((span, err.into())); } } } if transforms.iter().any(|t| !matches!(t, TypeTransform::NoOp)) { type_param_transforms.insert(type_ident.clone(), transforms); } } } } if let Some(path) = iter_path(&bound.path) .filter_map(filter_map_assoc_paths) .next() { return Err((path.span(), MethodError::UnconvertedAssocType)); } Ok(()) }