aboutsummaryrefslogtreecommitdiff
path: root/src/parse_trait_sig.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/parse_trait_sig.rs')
-rw-r--r--src/parse_trait_sig.rs505
1 files changed, 505 insertions, 0 deletions
diff --git a/src/parse_trait_sig.rs b/src/parse_trait_sig.rs
new file mode 100644
index 0000000..55a3214
--- /dev/null
+++ b/src/parse_trait_sig.rs
@@ -0,0 +1,505 @@
+use std::collections::HashMap;
+
+use proc_macro2::Span;
+use syn::{
+ spanned::Spanned, FnArg, Ident, PathArguments, PredicateType, Receiver, ReturnType, Type,
+ TypePath, WherePredicate,
+};
+use syn::{GenericParam, Signature, TypeImplTrait, TypeParamBound};
+
+use crate::syn_utils::{find_in_path, find_in_type, trait_bounds, TypeMatcher};
+use crate::{As, AssocTypeMatcher};
+
+#[derive(Debug, Clone)]
+pub enum TypeTransform {
+ NoOp,
+ Into,
+ Map(Box<TypeTransform>),
+ Result(Box<TypeTransform>, Box<TypeTransform>),
+}
+
+#[derive(Debug)]
+pub enum MethodParseError {
+ NonDispatchableMethod,
+ AssocTypeInInputs,
+ ImplTraitInInputs,
+ AssocTypeInUnsupportedReturnType,
+ UnconvertibleAssocTypeInFnInput,
+ UnconvertibleAssocTypeInTraitBound,
+ UnconvertibleAssocType,
+}
+
+struct ImplTraitMatcher;
+impl TypeMatcher<TypeImplTrait> for ImplTraitMatcher {
+ fn match_type<'a>(&self, t: &'a Type) -> Option<&'a TypeImplTrait> {
+ if let Type::ImplTrait(impltrait) = t {
+ Some(impltrait)
+ } else {
+ None
+ }
+ }
+}
+
+pub struct SignatureChanges {
+ pub return_type: TypeTransform,
+ pub inputs: Vec<Option<Vec<TypeTransform>>>,
+}
+
+pub fn parse_trait_signature(
+ signature: &mut Signature,
+ assoc_type_conversions: &HashMap<&Ident, &Type>,
+) -> Result<SignatureChanges, (Span, MethodParseError)> {
+ let assoc_type_conversions = AssocTypeConversions(assoc_type_conversions);
+
+ if is_non_dispatchable(signature) {
+ return Err((signature.span(), MethodParseError::NonDispatchableMethod));
+ }
+
+ // provide better error messages for associated types in params
+ for input in &signature.inputs {
+ if let FnArg::Typed(pattype) = input {
+ if find_in_type(&pattype.ty, &AssocTypeMatcher).is_some() {
+ return Err((pattype.ty.span(), MethodParseError::AssocTypeInInputs));
+ }
+ if let Some(impl_trait) = find_in_type(&pattype.ty, &ImplTraitMatcher) {
+ return Err((impl_trait.span(), MethodParseError::ImplTraitInInputs));
+ }
+ }
+ }
+
+ let mut type_param_transforms = HashMap::new();
+ let mut input_transforms = Vec::new();
+
+ for generic_param in &mut signature.generics.params {
+ if let GenericParam::Type(type_param) = generic_param {
+ for bound in &mut type_param.bounds {
+ if let TypeParamBound::Trait(bound) = bound {
+ if bound.path.segments.len() == 1 {
+ let segment = bound.path.segments.first_mut().unwrap();
+
+ if let PathArguments::Parenthesized(args) = &mut segment.arguments {
+ if segment.ident == "Fn"
+ || segment.ident == "FnOnce"
+ || segment.ident == "FnMut"
+ {
+ let mut transforms = Vec::new();
+ for input_type in &mut args.inputs {
+ match assoc_type_conversions.parse_type_path(input_type) {
+ Ok(ret_type) => {
+ transforms.push(ret_type);
+ }
+ Err(TransformError::UnconvertibleAssocType(span)) => {
+ return Err((
+ span,
+ MethodParseError::UnconvertibleAssocType,
+ ));
+ }
+ Err(TransformError::AssocTypeInUnsupportedType(span)) => {
+ return Err((
+ span,
+ MethodParseError::UnconvertibleAssocTypeInFnInput,
+ ));
+ }
+ }
+ }
+ if transforms.iter().any(|t| !matches!(t, TypeTransform::NoOp)) {
+ type_param_transforms.insert(&type_param.ident, transforms);
+ }
+ }
+ }
+ }
+ if let Some(path) = find_in_path(&bound.path, &AssocTypeMatcher) {
+ return Err((
+ path.span(),
+ MethodParseError::UnconvertibleAssocTypeInTraitBound,
+ ));
+ }
+ }
+ }
+ }
+ }
+
+ for input in &signature.inputs {
+ if let FnArg::Typed(pattype) = input {
+ if let Type::Path(path) = &*pattype.ty {
+ if let Some(ident) = path.path.get_ident() {
+ input_transforms.push(type_param_transforms.get(ident).map(|x| (*x).clone()));
+ continue;
+ }
+ }
+ }
+ input_transforms.push(None);
+ }
+
+ let return_type = match &mut signature.output {
+ ReturnType::Type(_, og_type) => match assoc_type_conversions.parse_type_path(og_type) {
+ Ok(ret_type) => ret_type,
+ Err(TransformError::UnconvertibleAssocType(span)) => {
+ return Err((span, MethodParseError::UnconvertibleAssocType));
+ }
+ Err(TransformError::AssocTypeInUnsupportedType(span)) => {
+ return Err((span, MethodParseError::AssocTypeInUnsupportedReturnType));
+ }
+ },
+ ReturnType::Default => TypeTransform::NoOp,
+ };
+ Ok(SignatureChanges {
+ return_type,
+ inputs: input_transforms,
+ })
+}
+
+struct AssocTypeConversions<'a>(&'a HashMap<&'a Ident, &'a Type>);
+
+enum TransformError {
+ UnconvertibleAssocType(Span),
+ AssocTypeInUnsupportedType(Span),
+}
+
+impl AssocTypeConversions<'_> {
+ fn parse_type_path(&self, type_: &mut Type) -> Result<TypeTransform, TransformError> {
+ let assoc_span = match find_in_type(type_, &AssocTypeMatcher) {
+ Some(path) => path.span(),
+ None => return Ok(TypeTransform::NoOp),
+ };
+
+ if let Type::Path(TypePath { path, qself: None }) = type_ {
+ let ident = &path.segments.first().unwrap().ident;
+
+ // TODO: support &mut dyn Iterator<Item = Self::A>
+ // conversion to Box<dyn Iterator<Item = Whatever>> via .map(Into::into)
+
+ if ident == "Self" && path.segments.len() == 2 {
+ let ident = &path.segments.last().unwrap().ident;
+ *type_ = (*self
+ .0
+ .get(&ident)
+ .ok_or_else(|| TransformError::UnconvertibleAssocType(ident.span()))?)
+ .clone();
+ return Ok(TypeTransform::Into);
+ } else if ident == "Option" && path.segments.len() == 1 {
+ let first_seg = path.segments.first_mut().unwrap();
+
+ if let Some(args) = first_seg.arguments.get_as_mut() {
+ if args.args.len() == 1 {
+ if let Some(generic_type) = args.args.first_mut().unwrap().get_as_mut() {
+ if find_in_type(generic_type, &AssocTypeMatcher).is_some() {
+ return Ok(TypeTransform::Map(
+ self.parse_type_path(generic_type)?.into(),
+ ));
+ }
+ }
+ }
+ }
+ } else if ident == "Result" && path.segments.len() == 1 {
+ let first_seg = path.segments.first_mut().unwrap();
+ if let Some(args) = first_seg.arguments.get_as_mut() {
+ if args.args.len() == 2 {
+ let mut args_iter = args.args.iter_mut();
+ if let (Some(ok_type), Some(err_type)) = (
+ args_iter.next().unwrap().get_as_mut(),
+ args_iter.next().unwrap().get_as_mut(),
+ ) {
+ if find_in_type(ok_type, &AssocTypeMatcher).is_some()
+ || find_in_type(err_type, &AssocTypeMatcher).is_some()
+ {
+ return Ok(TypeTransform::Result(
+ self.parse_type_path(ok_type)?.into(),
+ self.parse_type_path(err_type)?.into(),
+ ));
+ }
+ }
+ }
+ }
+ } else {
+ let last_seg = &path.segments.last().unwrap();
+ if last_seg.ident == "Result" {
+ let last_seg = path.segments.last_mut().unwrap();
+ if let Some(args) = last_seg.arguments.get_as_mut() {
+ if args.args.len() == 1 {
+ if let Some(generic_type) = args.args.first_mut().unwrap().get_as_mut()
+ {
+ if find_in_type(generic_type, &AssocTypeMatcher).is_some() {
+ return Ok(TypeTransform::Map(
+ self.parse_type_path(generic_type)?.into(),
+ ));
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // the type contains an associated type but we
+ // don't know how to deal with it so we abort
+ Err(TransformError::AssocTypeInUnsupportedType(assoc_span))
+ }
+}
+
+fn is_non_dispatchable(signature: &Signature) -> bool {
+ // non-dispatchable: fn example(&self) where Self: Sized;
+ if let Some(where_clause) = &signature.generics.where_clause {
+ if where_clause
+ .predicates
+ .iter()
+ .any(bounds_self_and_has_bound_sized)
+ {
+ return true;
+ }
+ }
+
+ // non-dispatchable: fn example();
+ if signature.inputs.is_empty() {
+ return true;
+ }
+
+ // non-dispatchable: fn example(arg: Type);
+ if matches!(signature.inputs.first(), Some(FnArg::Typed(_))) {
+ return true;
+ }
+
+ // non-dispatchable: fn example(self);
+ if matches!(
+ signature.inputs.first(),
+ Some(FnArg::Receiver(Receiver {
+ reference: None,
+ ..
+ }))
+ ) {
+ return true;
+ }
+ false
+}
+
+/// Returns true if the bounded type is `Self` and the bounds contain `Sized`.
+fn bounds_self_and_has_bound_sized(predicate: &WherePredicate) -> bool {
+ matches!(
+ predicate,
+ WherePredicate::Type(PredicateType {
+ bounded_ty: Type::Path(TypePath { path, .. }),
+ bounds,
+ ..
+ })
+ if path.is_ident("Self")
+ && trait_bounds(bounds).any(|b| b.path.is_ident("Sized"))
+ )
+}
+
+#[cfg(test)]
+mod tests {
+ use std::collections::HashMap;
+
+ use quote::{format_ident, quote};
+ use syn::{TraitItemMethod, Type};
+
+ use crate::parse_trait_sig::{
+ parse_trait_signature, MethodParseError, SignatureChanges, TypeTransform,
+ };
+
+ #[test]
+ fn ok_void() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(&self);
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Ok(SignatureChanges {
+ return_type: TypeTransform::NoOp,
+ ..
+ })
+ ));
+ }
+
+ #[test]
+ fn ok_assoc_type() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(&self) -> Self::A;
+ })
+ .unwrap();
+
+ let mut assoc_type_map = HashMap::new();
+ let ident = format_ident!("A");
+ let dest = Type::Verbatim(quote! {Example});
+ assoc_type_map.insert(&ident, &dest);
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &assoc_type_map),
+ Ok(SignatureChanges {
+ return_type: TypeTransform::Into,
+ ..
+ })
+ ));
+ }
+
+ #[test]
+ fn err_unconvertible_assoc_type() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(&self) -> Self::A;
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::UnconvertibleAssocType))
+ ));
+ }
+
+ #[test]
+ fn err_non_dispatchable_assoc_function_no_args() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test();
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::NonDispatchableMethod))
+ ));
+ }
+
+ #[test]
+ fn err_non_dispatchable_assoc_function_with_args() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(arg: Type);
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::NonDispatchableMethod))
+ ));
+ }
+
+ #[test]
+ fn err_non_dispatchable_consume_self() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(self);
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::NonDispatchableMethod))
+ ));
+ }
+
+ #[test]
+ fn err_non_dispatchable_where_self_sized() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(&self) where Self: Sized;
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::NonDispatchableMethod))
+ ));
+ }
+
+ #[test]
+ fn err_assoc_type_in_unsupported_return() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(&self) -> Foo<Self::A>;
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::AssocTypeInUnsupportedReturnType))
+ ));
+ }
+
+ #[test]
+ fn err_assoc_type_in_unsupported_return_in_opt() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(&self) -> Option<Foo<Self::A>>;
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::AssocTypeInUnsupportedReturnType))
+ ));
+ }
+
+ #[test]
+ fn err_assoc_type_in_unsupported_return_in_ok() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(&self) -> Result<Foo<Self::A>, Error>;
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::AssocTypeInUnsupportedReturnType))
+ ));
+ }
+
+ #[test]
+ fn err_assoc_type_in_unsupported_return_in_err() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(&self) -> Result<Ok, Foo<Self::A>>;
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::AssocTypeInUnsupportedReturnType))
+ ));
+ }
+
+ #[test]
+ fn err_assoc_type_in_input() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(&self, x: Self::A);
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::AssocTypeInInputs))
+ ));
+ }
+
+ #[test]
+ fn err_assoc_type_in_input_opt() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(&self, x: Option<Self::A>);
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::AssocTypeInInputs))
+ ));
+ }
+
+ #[test]
+ fn err_impl_in_input() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test(&self, arg: Option<impl SomeTrait>);
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::ImplTraitInInputs))
+ ));
+ }
+
+ #[test]
+ fn err_assoc_type_in_generic() {
+ let mut type1: TraitItemMethod = syn::parse2(quote! {
+ fn test<F: Fn(Foo<Self::A>)>(&self, fun: F);
+ })
+ .unwrap();
+
+ assert!(matches!(
+ parse_trait_signature(&mut type1.sig, &Default::default()),
+ Err((_, MethodParseError::UnconvertibleAssocTypeInFnInput))
+ ));
+ }
+}