aboutsummaryrefslogtreecommitdiff
path: root/src/transform.rs
blob: 4fc22f58e7fd913e44e4b65a9da680b07d0a6d2d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
use std::collections::HashMap;

use proc_macro2::Span;
use syn::{
    spanned::Spanned, GenericParam, Generics, Ident, PathArguments, Type, TypeParamBound, TypePath,
};

use crate::{
    parse_trait_sig::{MethodParseError, TypeTransform},
    syn_utils::{find_in_path, find_in_type},
    As, AssocTypeMatcher,
};

/// Lookup table from the identifier of an associated type (the `A` in `Self::A`)
/// to the type that replaces it.
///
/// [`AssocTypeConversions::parse_type_path`] consults this table when it meets a
/// `Self::A` path and substitutes a clone of the mapped type.
#[derive(Default)]
pub struct AssocTypeConversions<'ast>(pub HashMap<&'ast Ident, &'ast Type>);

/// Reasons why [`AssocTypeConversions::parse_type_path`] could not rewrite a type.
///
/// Every variant carries the span that a diagnostic should point at.
#[derive(Debug, Clone, Copy)]
pub enum TransformError {
    /// A `Self::A` path whose associated type has no registered conversion;
    /// the span is that of the identifier `A`.
    UnconvertibleAssocType(Span),
    /// An associated type occurs inside a type shape that `parse_type_path`
    /// does not know how to convert.
    AssocTypeInUnsupportedType(Span),
}

impl AssocTypeConversions<'_> {
    /// Rewrites `type_` in place so that it no longer contains associated types
    /// (paths matched by `AssocTypeMatcher`, e.g. `Self::A`), and returns the
    /// [`TypeTransform`] that describes how a value of the original type relates
    /// to a value of the rewritten type.
    ///
    /// Recognised shapes:
    ///
    /// - `Self::A`: replaced by a clone of the type registered for `A`;
    ///   yields [`TypeTransform::Into`].
    /// - `Option<T>`, or a single-argument `Result<T>` alias (such as `io::Result<T>`
    ///   or a `Result<T>` alias imported into scope), where `T` contains an
    ///   associated type: `T` is rewritten recursively; yields [`TypeTransform::Map`].
    /// - `Result<T, E>` where `T` and/or `E` contain an associated type: both are
    ///   rewritten recursively; yields [`TypeTransform::Result`].
    ///
    /// `Option` and `Result` are recognised by the last path segment, so qualified
    /// spellings such as `std::option::Option<T>` or `std::result::Result<T, E>` are
    /// accepted as well. A type without any associated type is left untouched and
    /// yields [`TypeTransform::NoOp`].
    ///
    /// # Errors
    ///
    /// - [`TransformError::UnconvertibleAssocType`] if `Self::A` has no registered
    ///   conversion (span of `A`).
    /// - [`TransformError::AssocTypeInUnsupportedType`] if an associated type occurs in
    ///   any other kind of type (span of the path located by `find_in_type`).
    ///
    /// On error `type_` may already be partially rewritten (e.g. the `Ok` type of a
    /// `Result` converted before the `Err` type failed).
    pub fn parse_type_path(&self, type_: &mut Type) -> Result<TypeTransform, TransformError> {
        let assoc_span = match find_in_type(type_, &AssocTypeMatcher) {
            Some(path) => path.span(),
            None => return Ok(TypeTransform::NoOp),
        };

        if let Type::Path(TypePath { path, qself: None }) = type_ {
            // TODO: support &mut dyn Iterator<Item = Self::A>
            // conversion to  Box<dyn Iterator<Item = Whatever>> via .map(Into::into)

            if path.segments.len() == 2 && path.segments.first().unwrap().ident == "Self" {
                let ident = &path.segments.last().unwrap().ident;
                *type_ = (*self
                    .0
                    .get(&ident)
                    .ok_or_else(|| TransformError::UnconvertibleAssocType(ident.span()))?)
                .clone();
                return Ok(TypeTransform::Into);
            }

            // Containers are matched on their last segment so that prelude names,
            // module-level aliases (`io::Result<T>`) and fully qualified paths are
            // all handled the same way.
            let last_seg = path.segments.last_mut().unwrap();
            let is_option = last_seg.ident == "Option";
            let is_result = last_seg.ident == "Result";

            if is_option || is_result {
                if let Some(args) = last_seg.arguments.get_as_mut() {
                    if args.args.len() == 1 {
                        // `Option<T>`, or a `Result<T>` alias with a fixed error type:
                        // only the single type argument has to be converted.
                        if let Some(generic_type) = args.args.first_mut().unwrap().get_as_mut() {
                            if find_in_type(generic_type, &AssocTypeMatcher).is_some() {
                                return Ok(TypeTransform::Map(
                                    self.parse_type_path(generic_type)?.into(),
                                ));
                            }
                        }
                    } else if is_result && args.args.len() == 2 {
                        let mut args_iter = args.args.iter_mut();
                        if let (Some(ok_type), Some(err_type)) = (
                            args_iter.next().unwrap().get_as_mut(),
                            args_iter.next().unwrap().get_as_mut(),
                        ) {
                            if find_in_type(ok_type, &AssocTypeMatcher).is_some()
                                || find_in_type(err_type, &AssocTypeMatcher).is_some()
                            {
                                return Ok(TypeTransform::Result(
                                    self.parse_type_path(ok_type)?.into(),
                                    self.parse_type_path(err_type)?.into(),
                                ));
                            }
                        }
                    }
                }
            }
        }

        // the type contains an associated type but we
        // don't know how to deal with it so we abort
        Err(TransformError::AssocTypeInUnsupportedType(assoc_span))
    }
}

/// Converts the associated types used in the argument lists of `Fn`, `FnMut` and
/// `FnOnce` bounds on the type parameters of `generics`.
///
/// Each argument type of such a bound is rewritten in place with
/// [`AssocTypeConversions::parse_type_path`]. The returned map holds, for every type
/// parameter whose closure bound had at least one argument rewritten, one
/// [`TypeTransform`] per closure argument in argument order. Type parameters whose
/// closure arguments are all [`TypeTransform::NoOp`] are not included.
///
/// The closure trait is recognised by the last segment of the bound's path, so a
/// qualified bound such as `F: std::ops::Fn(Self::A)` is handled like `F: Fn(Self::A)`.
/// Only the type parameter list is inspected; `where` clauses are not visited here.
///
/// # Errors
///
/// Returns the span of the offending type together with:
///
/// - [`MethodParseError::UnconvertibleAssocType`] if a closure argument mentions an
///   associated type that has no registered conversion;
/// - [`MethodParseError::UnconvertibleAssocTypeInFnInput`] if a closure argument
///   mentions an associated type inside a type that cannot be converted;
/// - [`MethodParseError::UnconvertibleAssocTypeInTraitBound`] if an associated type
///   is still present in a trait bound afterwards (for example in a closure's return
///   type, or in the generic arguments of a non-closure trait).
pub fn dynamize_function_bounds<'a>(
    generics: &'a mut Generics,
    assoc_type_conversions: &AssocTypeConversions<'a>,
) -> Result<HashMap<&'a Ident, Vec<TypeTransform>>, (Span, MethodParseError)> {
    let mut type_param_transforms = HashMap::new();

    for generic_param in &mut generics.params {
        let type_param = match generic_param {
            GenericParam::Type(type_param) => type_param,
            _ => continue,
        };

        for bound in &mut type_param.bounds {
            let bound = match bound {
                TypeParamBound::Trait(bound) => bound,
                _ => continue,
            };

            // Match on the last segment so `Fn(..)` and `std::ops::Fn(..)` are
            // treated alike; the identifier check restricts this to closure traits.
            if let Some(segment) = bound.path.segments.last_mut() {
                let is_closure_trait = segment.ident == "Fn"
                    || segment.ident == "FnOnce"
                    || segment.ident == "FnMut";

                if let PathArguments::Parenthesized(args) = &mut segment.arguments {
                    if is_closure_trait {
                        // Rewrites the argument types in place, stopping at the first
                        // error (collecting into `Result` short-circuits).
                        let transforms = args
                            .inputs
                            .iter_mut()
                            .map(|input_type| {
                                assoc_type_conversions
                                    .parse_type_path(input_type)
                                    .map_err(|err| match err {
                                        TransformError::UnconvertibleAssocType(span) => {
                                            (span, MethodParseError::UnconvertibleAssocType)
                                        }
                                        TransformError::AssocTypeInUnsupportedType(span) => (
                                            span,
                                            MethodParseError::UnconvertibleAssocTypeInFnInput,
                                        ),
                                    })
                            })
                            .collect::<Result<Vec<_>, _>>()?;

                        // Only record parameters for which at least one argument changed.
                        if transforms.iter().any(|t| !matches!(t, TypeTransform::NoOp)) {
                            type_param_transforms.insert(&type_param.ident, transforms);
                        }
                    }
                }
            }

            // Anything still mentioning an associated type (closure return types,
            // generic arguments of other traits, ...) cannot be dynamized.
            if let Some(path) = find_in_path(&bound.path, &AssocTypeMatcher) {
                return Err((
                    path.span(),
                    MethodParseError::UnconvertibleAssocTypeInTraitBound,
                ));
            }
        }
    }

    Ok(type_param_transforms)
}