#![doc = include_str!("../README.md")] use std::collections::HashMap; use proc_macro::TokenStream; use proc_macro2::Group; use quote::format_ident; use quote::quote; use quote::quote_spanned; use quote::ToTokens; use syn::parenthesized; use syn::parse::Parse; use syn::parse_macro_input; use syn::spanned::Spanned; use syn::token::Brace; use syn::token::Gt; use syn::token::Lt; use syn::token::Trait; use syn::AngleBracketedGenericArguments; use syn::Block; use syn::Error; use syn::Expr; use syn::GenericArgument; use syn::GenericParam; use syn::Ident; use syn::ImplItemMethod; use syn::ItemTrait; use syn::LitInt; use syn::Path; use syn::PathArguments; use syn::PathSegment; use syn::Signature; use syn::Stmt; use syn::Token; use syn::TraitBound; use syn::TraitItem; use syn::TraitItemMethod; use syn::Type; use syn::TypeParam; use syn::TypeParamBound; use syn::TypePath; use syn::Visibility; use syn_utils::TypeOrPath; use crate::parse_assoc_type::parse_assoc_type; use crate::parse_assoc_type::AssocTypeError; use crate::parse_trait_sig::parse_trait_signature; use crate::parse_trait_sig::MethodError; use crate::parse_trait_sig::SignatureChanges; use crate::parse_trait_sig::TypeTransform; use crate::syn_utils::iter_path; use crate::syn_utils::trait_bounds; use crate::transform::TransformError; use crate::transform::TypeConverter; mod parse_assoc_type; mod parse_trait_sig; mod syn_utils; mod transform; macro_rules! abort { ($span:expr, $message:literal $(,$args:expr)*) => {{ let msg = format!($message $(,$args)*); let tokens = quote_spanned! {$span => compile_error!(#msg);}.into(); tokens }}; } struct Collection { id: Ident, count: usize, } impl Parse for Collection { fn parse(input: syn::parse::ParseStream) -> syn::Result { let inner; parenthesized!(inner in input); let id: Ident = inner.parse()?; let _: Token![,] = inner.parse()?; let count_lit: LitInt = inner.parse()?; let count: usize = count_lit.base10_parse()?; if count < 1 { return Err(Error::new( count_lit.span(), "number of type parameters must be >= 1", )); } Ok(Self { id, count }) } } #[proc_macro_attribute] pub fn dynamize(_attr: TokenStream, input: TokenStream) -> TokenStream { let mut original_trait = parse_macro_input!(input as ItemTrait); assert!(original_trait.auto_token.is_none()); let original_trait_name = original_trait.ident.clone(); let mut objectifiable_methods: Vec<(Signature, SignatureChanges)> = Vec::new(); let mut type_converter = TypeConverter::default(); let mut blanket_impl_attrs = Vec::new(); let mut dyn_trait_attrs = Vec::new(); // FUTURE: use Vec::drain_filter once it's stable let mut i = 0; while i < original_trait.attrs.len() { if original_trait.attrs[i].path.is_ident("blanket_impl_attr") { let attr = original_trait.attrs.remove(i); let group: Group = match syn::parse2(attr.tokens) { Ok(g) => g, Err(err) => { return abort!( err.span(), "expected parenthesis: #[blanket_impl_attr(...)]" ) } }; let tokens = group.stream(); blanket_impl_attrs.push(quote! {#[#tokens]}); } else if original_trait.attrs[i].path.is_ident("dyn_trait_attr") { let attr = original_trait.attrs.remove(i); let group: Group = match syn::parse2(attr.tokens) { Ok(g) => g, Err(err) => { return abort!(err.span(), "expected parenthesis: #[dyn_trait_attr(...)]") } }; let tokens = group.stream(); dyn_trait_attrs.push(quote! {#[#tokens]}); } else if original_trait.attrs[i].path.is_ident("collection") { let attr = original_trait.attrs.remove(i); let tokens = attr.tokens.into(); let coll = parse_macro_input!(tokens as Collection); if type_converter .collections .insert(coll.id.clone(), coll.count) .is_some() { return abort!( coll.id.span(), "collection `{}` is defined multiple times for this trait", coll.id ); } } else { i += 1; } } for item in &original_trait.items { if let TraitItem::Type(assoc_type) = item { match parse_assoc_type(assoc_type) { Ok((ident, type_)) => { type_converter .assoc_type_conversions .insert(ident.clone(), type_); } Err((_, AssocTypeError::NoTraitBound)) => continue, Err((span, AssocTypeError::TraitBoundContainsAssocType)) => { return abort!(span, "dynamize does not support associated types here") } Err((span, AssocTypeError::GenericAssociatedType)) => { return abort!( span, "dynamize does not (yet?) support generic associated types" ) } } } } for item in &original_trait.items { if let TraitItem::Method(method) = item { let mut signature = method.sig.clone(); match parse_trait_signature(&mut signature, &type_converter) { Ok(parsed_method) => objectifiable_methods.push((signature, parsed_method)), Err((span, err)) => match err { MethodError::NonDispatchableMethod => continue, MethodError::AssocTypeInInputs => { return abort!( span, "dynamize does not support associated types in parameter types (except in Fn arguments)" ) } MethodError::ImplTraitInInputs => { return abort!( span, "dynamize does not support impl here, change it to a method generic" ) } MethodError::Transform(TransformError::AssocTypeWithoutDestType) => { return abort!( span, "associated type is either undefined or doesn't have a trait bound" ) } MethodError::Transform(TransformError::UnsupportedType) => { return abort!(span, "dynamize does not know how to convert this type") } MethodError::Transform(TransformError::ExpectedAtLeastNTypes(n)) => { return abort!( span, "dynamize expects at least {} generic type arguments for this type", n ) } MethodError::Transform(TransformError::AssocTypeAfterFirstNTypes(n, ident)) => { return abort!( span, "for {} dynamize supports associated types only within the first {} generic type parameters", ident, n ) } MethodError::UnconvertedAssocType => { return abort!(span, "dynamize does not support associated types here") } }, }; } } let mut method_impls: Vec = Vec::new(); let mut dyn_trait = ItemTrait { ident: format_ident!("Dyn{}", original_trait.ident), attrs: Vec::new(), vis: original_trait.vis.clone(), unsafety: original_trait.unsafety, auto_token: None, trait_token: Trait::default(), generics: original_trait.generics.clone(), colon_token: None, supertraits: original_trait .supertraits .iter() .filter(|t| match t { TypeParamBound::Trait(t) => !t.path.is_ident("Sized"), TypeParamBound::Lifetime(_) => true, }) .cloned() .collect(), brace_token: Brace::default(), items: Vec::new(), }; let mut generic_map = HashMap::new(); for type_param in dyn_trait.generics.type_params() { generic_map.insert(type_param.ident.clone(), type_param.bounds.clone()); for trait_bound in trait_bounds(&type_param.bounds) { if let Some(assoc_type) = iter_path(&trait_bound.path).find_map(filter_map_assoc_paths) { return abort!( assoc_type.span(), "dynamize does not support associated types in trait generic bounds" ); } } } for (signature, parsed_method) in objectifiable_methods { let mut new_method = TraitItemMethod { attrs: Vec::new(), sig: signature, default: None, semi_token: None, }; let fun_name = &new_method.sig.ident; let args = new_method.sig.inputs.iter().map(|arg| match arg { syn::FnArg::Receiver(_) => quote! {self}, syn::FnArg::Typed(pat_type) => match pat_type.pat.as_ref() { syn::Pat::Ident(ident) => { // FUTURE: use try block if let Type::Path(path) = &*pat_type.ty { if let Some(type_ident) = path.path.get_ident() { if let Some(transforms) = parsed_method.type_param_transforms.get(type_ident) { let args = (0..transforms.len()).map(|i| format_ident!("a{}", i)); let calls: Vec<_> = args .clone() .enumerate() .map(|(idx, i)| transforms[idx].convert(i.into_token_stream())) .collect(); let move_opt = new_method.sig.asyncness.map(|_| quote! {move}); quote!(#move_opt |#(#args),*| #ident(#(#calls),*)) } else { ident.ident.to_token_stream() } } else { ident.ident.to_token_stream() } } else { ident.ident.to_token_stream() } } _other => { panic!("unexpected"); } }, }); // in order for a trait to be object-safe its methods may not have // generics so we convert method generics into trait generics if new_method .sig .generics .params .iter() .any(|p| matches!(p, GenericParam::Type(_))) { // syn::punctuated::Punctuated doesn't have a remove(index) // method so we firstly move all elements to a vector let mut params = Vec::new(); while let Some(generic_param) = new_method.sig.generics.params.pop() { params.push(generic_param.into_value()); } // FUTURE: use Vec::drain_filter once it's stable let mut i = 0; while i < params.len() { if let GenericParam::Type(type_param) = ¶ms[i] { if let Some(bounds) = generic_map.get(&type_param.ident) { if *bounds == type_param.bounds { params.remove(i); continue; } else { return abort!(type_param.span(), "dynamize failure: there exists a same-named method generic with different bounds"); } } else { generic_map.insert(type_param.ident.clone(), type_param.bounds.clone()); dyn_trait.generics.params.push(params.remove(i)); } } else { i += 1; } } new_method.sig.generics.params.extend(params); if dyn_trait.generics.lt_token.is_none() { dyn_trait.generics.lt_token = Some(Lt::default()); dyn_trait.generics.gt_token = Some(Gt::default()); } } let mut expr = quote!(#original_trait_name::#fun_name(#(#args),*)); if new_method.sig.asyncness.is_some() { expr.extend(quote! {.await}) } let expr = parsed_method.return_type.convert(expr); method_impls.push(ImplItemMethod { attrs: Vec::new(), vis: Visibility::Inherited, defaultness: None, sig: new_method.sig.clone(), block: Block { brace_token: Brace::default(), stmts: vec![Stmt::Expr(Expr::Verbatim(expr))], }, }); dyn_trait.items.push(new_method.into()); } let blanket_impl = generate_blanket_impl(&dyn_trait, &original_trait, &method_impls); let dyn_trait_name = &dyn_trait.ident; let (impl_generics, ty_generics, where_clause) = dyn_trait.generics.split_for_impl(); let expanded = quote! { #original_trait #(#dyn_trait_attrs)* #dyn_trait // assert that dyn_trait can actually be made into an object impl #impl_generics dyn #dyn_trait_name #ty_generics #where_clause {} #(#blanket_impl_attrs)* #blanket_impl }; TokenStream::from(expanded) } fn generate_blanket_impl( dyn_trait: &ItemTrait, original_trait: &ItemTrait, method_impls: &[ImplItemMethod], ) -> proc_macro2::TokenStream { let mut blanket_generics = dyn_trait.generics.clone(); let some_ident = format_ident!("__to_be_dynamized"); blanket_generics.params.push(GenericParam::Type(TypeParam { attrs: Vec::new(), ident: some_ident.clone(), colon_token: None, bounds: std::iter::once(TypeParamBound::Trait(TraitBound { paren_token: None, modifier: syn::TraitBoundModifier::None, lifetimes: None, path: Path { leading_colon: None, segments: std::iter::once(path_segment_for_trait(original_trait)).collect(), }, })) .collect(), eq_token: None, default: None, })); let (_, type_gen, where_clause) = dyn_trait.generics.split_for_impl(); let dyn_trait_name = &dyn_trait.ident; quote! { impl #blanket_generics #dyn_trait_name #type_gen for #some_ident #where_clause { #(#method_impls)* } } } fn path_is_assoc_type(path: &Path) -> bool { path.segments[0].ident == "Self" } fn match_assoc_type(item: TypeOrPath) -> bool { if let TypeOrPath::Path(path) = item { return path_is_assoc_type(path); } false } fn filter_map_assoc_paths(item: TypeOrPath) -> Option<&Path> { match item { TypeOrPath::Path(p) if path_is_assoc_type(p) => Some(p), _other => None, } } impl TypeTransform { fn convert(&self, arg: proc_macro2::TokenStream) -> proc_macro2::TokenStream { match self { TypeTransform::Into => quote! {#arg.into()}, TypeTransform::Box(box_type) => { quote! {Box::new(#arg) as #box_type} } TypeTransform::Map(opt) => { let inner = opt.convert(quote!(x)); quote! {#arg.map(|x| #inner)} } TypeTransform::Iterator(box_type, inner) => { let inner = inner.convert(quote!(x)); quote! {Box::new(#arg.map(|x| #inner)) as #box_type} } TypeTransform::Result(ok, err) => { let map_ok = !matches!(ok.as_ref(), TypeTransform::NoOp); let map_err = !matches!(err.as_ref(), TypeTransform::NoOp); if map_ok && map_err { let ok_inner = ok.convert(quote!(x)); let err_inner = err.convert(quote!(x)); quote! {#arg.map(|x| #ok_inner).map_err(|x| #err_inner)} } else if map_ok { let ok_inner = ok.convert(quote!(x)); quote! {#arg.map(|x| #ok_inner)} } else { let err_inner = err.convert(quote!(x)); quote! {#arg.map_err(|x| #err_inner)} } } TypeTransform::Tuple(types) => { let idents = (0..types.len()).map(|i| format_ident!("v{}", i)); // FUTURE: let transforms = std::iter::zip(idents, types).map(|(i, t)| t.convert(quote! {#i})); let transforms = types.iter().enumerate().map(|(idx, t)| { let id = format_ident!("v{}", idx); t.convert(quote! {#id}) }); quote! { {let (#(#idents),*) = #arg; (#(#transforms),*)} } } TypeTransform::IntoIterMapCollect(types) => { let idents = (0..types.len()).map(|i| format_ident!("v{}", i)); // FUTURE: let transforms = std::iter::zip(idents, types).map(|(i, t)| t.convert(quote! {#i})); let transforms = types.iter().enumerate().map(|(idx, t)| { let id = format_ident!("v{}", idx); t.convert(quote! {#id}) }); quote! {#arg.into_iter().map(|(#(#idents),*)| (#(#transforms),*)).collect()} } TypeTransform::NoOp => arg, } } } fn path_segment_for_trait(sometrait: &ItemTrait) -> PathSegment { PathSegment { ident: sometrait.ident.clone(), arguments: match sometrait.generics.params.is_empty() { true => PathArguments::None, false => PathArguments::AngleBracketed(AngleBracketedGenericArguments { colon2_token: None, lt_token: Lt::default(), args: sometrait .generics .params .iter() .map(|param| match param { GenericParam::Type(type_param) => { GenericArgument::Type(Type::Path(TypePath { path: type_param.ident.clone().into(), qself: None, })) } GenericParam::Lifetime(lifetime_def) => { GenericArgument::Lifetime(lifetime_def.lifetime.clone()) } GenericParam::Const(_) => todo!("const generic param not supported"), }) .collect(), gt_token: Gt::default(), }), }, } } #[doc = include_str!("../tests/doctests.md")] #[cfg(doctest)] struct Doctests; #[test] fn ui() { use std::process::Command; for entry in std::fs::read_dir("ui-tests/src/bin").unwrap().flatten() { if entry.path().extension().unwrap() == "rs" { let output = Command::new("cargo") .arg("check") .arg("--quiet") .arg("--bin") .arg(entry.path().file_stem().unwrap()) .current_dir("ui-tests") .output() .unwrap(); let received_stderr = std::str::from_utf8(&output.stderr) .unwrap() .lines() .filter(|l| !l.starts_with("error: could not compile")) .collect::>() .join("\n"); let expected_stderr = std::fs::read_to_string(entry.path().with_extension("stderr")).unwrap_or_default(); if received_stderr != expected_stderr { println!( "EXPECTED:\n{banner}\n{expected}{banner}\n\nACTUAL OUTPUT:\n{banner}\n{actual}{banner}", banner = "-".repeat(30), expected = expected_stderr, actual = received_stderr ); panic!("failed"); } } } }