use std::collections::HashMap; use proc_macro2::Span; use quote::quote; use syn::{ spanned::Spanned, GenericArgument, Generics, Ident, PathArguments, TraitBound, Type, TypeParamBound, TypePath, TypeReference, TypeTraitObject, WherePredicate, }; use crate::{ filter_map_assoc_paths, match_assoc_type, parse_assoc_type::DestType, parse_trait_sig::{MethodParseError, TypeTransform}, syn_utils::{iter_path, iter_type}, }; #[derive(Default)] pub struct AssocTypeConversions<'a>(pub HashMap>); pub enum TransformError { UnconvertibleAssocType(Span), AssocTypeInUnsupportedType(Span), } impl AssocTypeConversions<'_> { pub fn parse_type_path(&self, type_: &mut Type) -> Result { let assoc_span = match iter_type(type_).filter_map(filter_map_assoc_paths).next() { Some(path) => path.span(), None => return Ok(TypeTransform::NoOp), }; if let Type::Reference(TypeReference { lifetime: None, mutability: Some(_), elem, .. }) = type_ { if let Type::TraitObject(TypeTraitObject { dyn_token: Some(_), bounds, }) = elem.as_mut() { if bounds.len() == 1 { if let TypeParamBound::Trait(bound) = bounds.first_mut().unwrap() { if bound.path.segments.len() == 1 { let first = bound.path.segments.first_mut().unwrap(); if first.ident == "Iterator" { if let PathArguments::AngleBracketed(args) = &mut first.arguments { if args.args.len() == 1 { if let GenericArgument::Binding(binding) = args.args.first_mut().unwrap() { if binding.ident == "Item" && iter_type(&binding.ty).any(match_assoc_type) { let inner = self.parse_type_path(&mut binding.ty)?; *type_ = Type::Verbatim(quote! {Box<#elem + '_>}); return Ok(TypeTransform::Iterator(inner.into())); } } } } } } } } } } if let Type::Path(TypePath { path, qself: None }) = type_ { let ident = &path.segments.first().unwrap().ident; if ident == "Self" && path.segments.len() == 2 { let ident = &path.segments.last().unwrap().ident; let dest_type = self .0 .get(ident) .ok_or_else(|| TransformError::UnconvertibleAssocType(ident.span()))?; *type_ = dest_type.get_dest(); return Ok(dest_type.type_transformation()); } else if ident == "Option" && path.segments.len() == 1 { let first_seg = path.segments.first_mut().unwrap(); if let PathArguments::AngleBracketed(args) = &mut first_seg.arguments { if args.args.len() == 1 { if let GenericArgument::Type(generic_type) = args.args.first_mut().unwrap() { if iter_type(generic_type).any(match_assoc_type) { return Ok(TypeTransform::Map( self.parse_type_path(generic_type)?.into(), )); } } } } } else if ident == "Result" && path.segments.len() == 1 { let first_seg = path.segments.first_mut().unwrap(); if let PathArguments::AngleBracketed(args) = &mut first_seg.arguments { if args.args.len() == 2 { let mut args_iter = args.args.iter_mut(); if let (GenericArgument::Type(ok_type), GenericArgument::Type(err_type)) = (args_iter.next().unwrap(), args_iter.next().unwrap()) { if iter_type(ok_type).any(match_assoc_type) || iter_type(err_type).any(match_assoc_type) { return Ok(TypeTransform::Result( self.parse_type_path(ok_type)?.into(), self.parse_type_path(err_type)?.into(), )); } } } } } else { let last_seg = &path.segments.last().unwrap(); if last_seg.ident == "Result" { let last_seg = path.segments.last_mut().unwrap(); if let PathArguments::AngleBracketed(args) = &mut last_seg.arguments { if args.args.len() == 1 { if let GenericArgument::Type(generic_type) = args.args.first_mut().unwrap() { if iter_type(generic_type).any(match_assoc_type) { return Ok(TypeTransform::Map( self.parse_type_path(generic_type)?.into(), )); } } } } } } } // the type contains an associated type but we // don't know how to deal with it so we abort Err(TransformError::AssocTypeInUnsupportedType(assoc_span)) } } pub fn dynamize_function_bounds( generics: &mut Generics, assoc_type_conversions: &AssocTypeConversions, ) -> Result>, (Span, MethodParseError)> { let mut type_param_transforms = HashMap::new(); for type_param in generics.type_params_mut() { for bound in &mut type_param.bounds { if let TypeParamBound::Trait(bound) = bound { dynamize_trait_bound( bound, assoc_type_conversions, &type_param.ident, &mut type_param_transforms, )?; } } } if let Some(where_clause) = &mut generics.where_clause { for predicate in &mut where_clause.predicates { if let WherePredicate::Type(predicate_type) = predicate { if let Type::Path(path) = &mut predicate_type.bounded_ty { if let Some(ident) = path.path.get_ident() { for bound in &mut predicate_type.bounds { if let TypeParamBound::Trait(bound) = bound { dynamize_trait_bound( bound, assoc_type_conversions, ident, &mut type_param_transforms, )?; } } continue; } } // just to provide better error messages if let Some(assoc_type) = iter_type(&predicate_type.bounded_ty).find_map(filter_map_assoc_paths) { return Err(( assoc_type.span(), MethodParseError::UnconvertibleAssocTypeInWhereClause, )); } // just to provide better error messages for bound in &mut predicate_type.bounds { if let TypeParamBound::Trait(bound) = bound { if let Some(assoc_type) = iter_path(&bound.path).find_map(filter_map_assoc_paths) { return Err(( assoc_type.span(), MethodParseError::UnconvertibleAssocTypeInWhereClause, )); } } } } } } Ok(type_param_transforms) } fn dynamize_trait_bound( bound: &mut TraitBound, assoc_type_conversions: &AssocTypeConversions, type_ident: &Ident, type_param_transforms: &mut HashMap>, ) -> Result<(), (Span, MethodParseError)> { if bound.path.segments.len() == 1 { let segment = bound.path.segments.first_mut().unwrap(); if let PathArguments::Parenthesized(args) = &mut segment.arguments { if segment.ident == "Fn" || segment.ident == "FnOnce" || segment.ident == "FnMut" { let mut transforms = Vec::new(); for input_type in &mut args.inputs { match assoc_type_conversions.parse_type_path(input_type) { Ok(ret_type) => { transforms.push(ret_type); } Err(TransformError::UnconvertibleAssocType(span)) => { return Err((span, MethodParseError::UnconvertibleAssocType)); } Err(TransformError::AssocTypeInUnsupportedType(span)) => { return Err((span, MethodParseError::UnconvertibleAssocTypeInFnInput)); } } } if transforms.iter().any(|t| !matches!(t, TypeTransform::NoOp)) { type_param_transforms.insert(type_ident.clone(), transforms); } } } } if let Some(path) = iter_path(&bound.path) .filter_map(filter_map_assoc_paths) .next() { return Err(( path.span(), MethodParseError::UnconvertibleAssocTypeInTraitBound, )); } Ok(()) }