use std::collections::HashMap; use proc_macro2::Span; use syn::{ spanned::Spanned, GenericParam, Generics, Ident, PathArguments, Type, TypeParamBound, TypePath, }; use crate::{ parse_trait_sig::{MethodParseError, TypeTransform}, syn_utils::{find_in_path, find_in_type}, As, AssocTypeMatcher, }; #[derive(Default)] pub struct AssocTypeConversions<'a>(pub HashMap); pub enum TransformError { UnconvertibleAssocType(Span), AssocTypeInUnsupportedType(Span), } impl AssocTypeConversions<'_> { pub fn parse_type_path(&self, type_: &mut Type) -> Result { let assoc_span = match find_in_type(type_, &AssocTypeMatcher) { Some(path) => path.span(), None => return Ok(TypeTransform::NoOp), }; if let Type::Path(TypePath { path, qself: None }) = type_ { let ident = &path.segments.first().unwrap().ident; // TODO: support &mut dyn Iterator // conversion to Box> via .map(Into::into) if ident == "Self" && path.segments.len() == 2 { let ident = &path.segments.last().unwrap().ident; *type_ = (*self .0 .get(ident) .ok_or_else(|| TransformError::UnconvertibleAssocType(ident.span()))?) .clone(); return Ok(TypeTransform::Into); } else if ident == "Option" && path.segments.len() == 1 { let first_seg = path.segments.first_mut().unwrap(); if let Some(args) = first_seg.arguments.get_as_mut() { if args.args.len() == 1 { if let Some(generic_type) = args.args.first_mut().unwrap().get_as_mut() { if find_in_type(generic_type, &AssocTypeMatcher).is_some() { return Ok(TypeTransform::Map( self.parse_type_path(generic_type)?.into(), )); } } } } } else if ident == "Result" && path.segments.len() == 1 { let first_seg = path.segments.first_mut().unwrap(); if let Some(args) = first_seg.arguments.get_as_mut() { if args.args.len() == 2 { let mut args_iter = args.args.iter_mut(); if let (Some(ok_type), Some(err_type)) = ( args_iter.next().unwrap().get_as_mut(), args_iter.next().unwrap().get_as_mut(), ) { if find_in_type(ok_type, &AssocTypeMatcher).is_some() || find_in_type(err_type, &AssocTypeMatcher).is_some() { return Ok(TypeTransform::Result( self.parse_type_path(ok_type)?.into(), self.parse_type_path(err_type)?.into(), )); } } } } } else { let last_seg = &path.segments.last().unwrap(); if last_seg.ident == "Result" { let last_seg = path.segments.last_mut().unwrap(); if let Some(args) = last_seg.arguments.get_as_mut() { if args.args.len() == 1 { if let Some(generic_type) = args.args.first_mut().unwrap().get_as_mut() { if find_in_type(generic_type, &AssocTypeMatcher).is_some() { return Ok(TypeTransform::Map( self.parse_type_path(generic_type)?.into(), )); } } } } } } } // the type contains an associated type but we // don't know how to deal with it so we abort Err(TransformError::AssocTypeInUnsupportedType(assoc_span)) } } pub fn dynamize_function_bounds( generics: &mut Generics, assoc_type_conversions: &AssocTypeConversions, ) -> Result>, (Span, MethodParseError)> { let mut type_param_transforms = HashMap::new(); for generic_param in &mut generics.params { if let GenericParam::Type(type_param) = generic_param { for bound in &mut type_param.bounds { if let TypeParamBound::Trait(bound) = bound { if bound.path.segments.len() == 1 { let segment = bound.path.segments.first_mut().unwrap(); if let PathArguments::Parenthesized(args) = &mut segment.arguments { if segment.ident == "Fn" || segment.ident == "FnOnce" || segment.ident == "FnMut" { let mut transforms = Vec::new(); for input_type in &mut args.inputs { match assoc_type_conversions.parse_type_path(input_type) { Ok(ret_type) => { transforms.push(ret_type); } Err(TransformError::UnconvertibleAssocType(span)) => { return Err(( span, MethodParseError::UnconvertibleAssocType, )); } Err(TransformError::AssocTypeInUnsupportedType(span)) => { return Err(( span, MethodParseError::UnconvertibleAssocTypeInFnInput, )); } } } if transforms.iter().any(|t| !matches!(t, TypeTransform::NoOp)) { type_param_transforms .insert(type_param.ident.clone(), transforms); } } } } if let Some(path) = find_in_path(&bound.path, &AssocTypeMatcher) { return Err(( path.span(), MethodParseError::UnconvertibleAssocTypeInTraitBound, )); } } } } } Ok(type_param_transforms) }