aboutsummaryrefslogtreecommitdiff
path: root/src/transform.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/transform.rs')
-rw-r--r--src/transform.rs158
1 files changed, 158 insertions, 0 deletions
diff --git a/src/transform.rs b/src/transform.rs
new file mode 100644
index 0000000..4fc22f5
--- /dev/null
+++ b/src/transform.rs
@@ -0,0 +1,158 @@
+use std::collections::HashMap;
+
+use proc_macro2::Span;
+use syn::{
+ spanned::Spanned, GenericParam, Generics, Ident, PathArguments, Type, TypeParamBound, TypePath,
+};
+
+use crate::{
+ parse_trait_sig::{MethodParseError, TypeTransform},
+ syn_utils::{find_in_path, find_in_type},
+ As, AssocTypeMatcher,
+};
+
+#[derive(Default)]
+pub struct AssocTypeConversions<'a>(pub HashMap<&'a Ident, &'a Type>);
+
+pub enum TransformError {
+ UnconvertibleAssocType(Span),
+ AssocTypeInUnsupportedType(Span),
+}
+
+impl AssocTypeConversions<'_> {
+ pub fn parse_type_path(&self, type_: &mut Type) -> Result<TypeTransform, TransformError> {
+ let assoc_span = match find_in_type(type_, &AssocTypeMatcher) {
+ Some(path) => path.span(),
+ None => return Ok(TypeTransform::NoOp),
+ };
+
+ if let Type::Path(TypePath { path, qself: None }) = type_ {
+ let ident = &path.segments.first().unwrap().ident;
+
+ // TODO: support &mut dyn Iterator<Item = Self::A>
+ // conversion to Box<dyn Iterator<Item = Whatever>> via .map(Into::into)
+
+ if ident == "Self" && path.segments.len() == 2 {
+ let ident = &path.segments.last().unwrap().ident;
+ *type_ = (*self
+ .0
+ .get(&ident)
+ .ok_or_else(|| TransformError::UnconvertibleAssocType(ident.span()))?)
+ .clone();
+ return Ok(TypeTransform::Into);
+ } else if ident == "Option" && path.segments.len() == 1 {
+ let first_seg = path.segments.first_mut().unwrap();
+
+ if let Some(args) = first_seg.arguments.get_as_mut() {
+ if args.args.len() == 1 {
+ if let Some(generic_type) = args.args.first_mut().unwrap().get_as_mut() {
+ if find_in_type(generic_type, &AssocTypeMatcher).is_some() {
+ return Ok(TypeTransform::Map(
+ self.parse_type_path(generic_type)?.into(),
+ ));
+ }
+ }
+ }
+ }
+ } else if ident == "Result" && path.segments.len() == 1 {
+ let first_seg = path.segments.first_mut().unwrap();
+ if let Some(args) = first_seg.arguments.get_as_mut() {
+ if args.args.len() == 2 {
+ let mut args_iter = args.args.iter_mut();
+ if let (Some(ok_type), Some(err_type)) = (
+ args_iter.next().unwrap().get_as_mut(),
+ args_iter.next().unwrap().get_as_mut(),
+ ) {
+ if find_in_type(ok_type, &AssocTypeMatcher).is_some()
+ || find_in_type(err_type, &AssocTypeMatcher).is_some()
+ {
+ return Ok(TypeTransform::Result(
+ self.parse_type_path(ok_type)?.into(),
+ self.parse_type_path(err_type)?.into(),
+ ));
+ }
+ }
+ }
+ }
+ } else {
+ let last_seg = &path.segments.last().unwrap();
+ if last_seg.ident == "Result" {
+ let last_seg = path.segments.last_mut().unwrap();
+ if let Some(args) = last_seg.arguments.get_as_mut() {
+ if args.args.len() == 1 {
+ if let Some(generic_type) = args.args.first_mut().unwrap().get_as_mut()
+ {
+ if find_in_type(generic_type, &AssocTypeMatcher).is_some() {
+ return Ok(TypeTransform::Map(
+ self.parse_type_path(generic_type)?.into(),
+ ));
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // the type contains an associated type but we
+ // don't know how to deal with it so we abort
+ Err(TransformError::AssocTypeInUnsupportedType(assoc_span))
+ }
+}
+
+pub fn dynamize_function_bounds<'a>(
+ generics: &'a mut Generics,
+ assoc_type_conversions: &AssocTypeConversions<'a>,
+) -> Result<HashMap<&'a Ident, Vec<TypeTransform>>, (Span, MethodParseError)> {
+ let mut type_param_transforms = HashMap::new();
+
+ for generic_param in &mut generics.params {
+ if let GenericParam::Type(type_param) = generic_param {
+ for bound in &mut type_param.bounds {
+ if let TypeParamBound::Trait(bound) = bound {
+ if bound.path.segments.len() == 1 {
+ let segment = bound.path.segments.first_mut().unwrap();
+
+ if let PathArguments::Parenthesized(args) = &mut segment.arguments {
+ if segment.ident == "Fn"
+ || segment.ident == "FnOnce"
+ || segment.ident == "FnMut"
+ {
+ let mut transforms = Vec::new();
+ for input_type in &mut args.inputs {
+ match assoc_type_conversions.parse_type_path(input_type) {
+ Ok(ret_type) => {
+ transforms.push(ret_type);
+ }
+ Err(TransformError::UnconvertibleAssocType(span)) => {
+ return Err((
+ span,
+ MethodParseError::UnconvertibleAssocType,
+ ));
+ }
+ Err(TransformError::AssocTypeInUnsupportedType(span)) => {
+ return Err((
+ span,
+ MethodParseError::UnconvertibleAssocTypeInFnInput,
+ ));
+ }
+ }
+ }
+ if transforms.iter().any(|t| !matches!(t, TypeTransform::NoOp)) {
+ type_param_transforms.insert(&type_param.ident, transforms);
+ }
+ }
+ }
+ }
+ if let Some(path) = find_in_path(&bound.path, &AssocTypeMatcher) {
+ return Err((
+ path.span(),
+ MethodParseError::UnconvertibleAssocTypeInTraitBound,
+ ));
+ }
+ }
+ }
+ }
+ }
+ Ok(type_param_transforms)
+}