use std::collections::HashMap; use proc_macro2::Span; use syn::{ spanned::Spanned, GenericArgument, Generics, Ident, PathArguments, Type, TypeParamBound, TypePath, }; use crate::{ filter_map_assoc_paths, match_assoc_type, parse_assoc_type::DestType, parse_trait_sig::{MethodParseError, TypeTransform}, syn_utils::{iter_path, iter_type}, }; #[derive(Default)] pub struct AssocTypeConversions<'a>(pub HashMap>); pub enum TransformError { UnconvertibleAssocType(Span), AssocTypeInUnsupportedType(Span), } impl AssocTypeConversions<'_> { pub fn parse_type_path(&self, type_: &mut Type) -> Result { let assoc_span = match iter_type(type_).filter_map(filter_map_assoc_paths).next() { Some(path) => path.span(), None => return Ok(TypeTransform::NoOp), }; if let Type::Path(TypePath { path, qself: None }) = type_ { let ident = &path.segments.first().unwrap().ident; // TODO: support &mut dyn Iterator // conversion to Box> via .map(Into::into) if ident == "Self" && path.segments.len() == 2 { let ident = &path.segments.last().unwrap().ident; let dest_type = self .0 .get(ident) .ok_or_else(|| TransformError::UnconvertibleAssocType(ident.span()))?; *type_ = dest_type.get_dest(); return Ok(dest_type.type_transformation()); } else if ident == "Option" && path.segments.len() == 1 { let first_seg = path.segments.first_mut().unwrap(); if let PathArguments::AngleBracketed(args) = &mut first_seg.arguments { if args.args.len() == 1 { if let GenericArgument::Type(generic_type) = args.args.first_mut().unwrap() { if iter_type(generic_type).any(match_assoc_type) { return Ok(TypeTransform::Map( self.parse_type_path(generic_type)?.into(), )); } } } } } else if ident == "Result" && path.segments.len() == 1 { let first_seg = path.segments.first_mut().unwrap(); if let PathArguments::AngleBracketed(args) = &mut first_seg.arguments { if args.args.len() == 2 { let mut args_iter = args.args.iter_mut(); if let (GenericArgument::Type(ok_type), GenericArgument::Type(err_type)) = (args_iter.next().unwrap(), args_iter.next().unwrap()) { if iter_type(ok_type).any(match_assoc_type) || iter_type(err_type).any(match_assoc_type) { return Ok(TypeTransform::Result( self.parse_type_path(ok_type)?.into(), self.parse_type_path(err_type)?.into(), )); } } } } } else { let last_seg = &path.segments.last().unwrap(); if last_seg.ident == "Result" { let last_seg = path.segments.last_mut().unwrap(); if let PathArguments::AngleBracketed(args) = &mut last_seg.arguments { if args.args.len() == 1 { if let GenericArgument::Type(generic_type) = args.args.first_mut().unwrap() { if iter_type(generic_type).any(match_assoc_type) { return Ok(TypeTransform::Map( self.parse_type_path(generic_type)?.into(), )); } } } } } } } // the type contains an associated type but we // don't know how to deal with it so we abort Err(TransformError::AssocTypeInUnsupportedType(assoc_span)) } } pub fn dynamize_function_bounds( generics: &mut Generics, assoc_type_conversions: &AssocTypeConversions, ) -> Result>, (Span, MethodParseError)> { let mut type_param_transforms = HashMap::new(); for type_param in generics.type_params_mut() { for bound in &mut type_param.bounds { if let TypeParamBound::Trait(bound) = bound { if bound.path.segments.len() == 1 { let segment = bound.path.segments.first_mut().unwrap(); if let PathArguments::Parenthesized(args) = &mut segment.arguments { if segment.ident == "Fn" || segment.ident == "FnOnce" || segment.ident == "FnMut" { let mut transforms = Vec::new(); for input_type in &mut args.inputs { match assoc_type_conversions.parse_type_path(input_type) { Ok(ret_type) => { transforms.push(ret_type); } Err(TransformError::UnconvertibleAssocType(span)) => { return Err(( span, MethodParseError::UnconvertibleAssocType, )); } Err(TransformError::AssocTypeInUnsupportedType(span)) => { return Err(( span, MethodParseError::UnconvertibleAssocTypeInFnInput, )); } } } if transforms.iter().any(|t| !matches!(t, TypeTransform::NoOp)) { type_param_transforms.insert(type_param.ident.clone(), transforms); } } } } if let Some(path) = iter_path(&bound.path) .filter_map(filter_map_assoc_paths) .next() { return Err(( path.span(), MethodParseError::UnconvertibleAssocTypeInTraitBound, )); } } } } Ok(type_param_transforms) }