diff --git a/macro/src/args.rs b/macro/src/args.rs index 8c7c1a0..bd8d307 100644 --- a/macro/src/args.rs +++ b/macro/src/args.rs @@ -1,5 +1,5 @@ use proc_macro2::{Span, TokenStream}; -use quote::{ToTokens, TokenStreamExt}; +use quote::TokenStreamExt; use syn::Ident; pub(crate) struct ArgumentOptions { @@ -22,31 +22,33 @@ fn unpack_pat(pat: syn::Pat) -> Result { } } -fn type_equals(ty: Box, search: impl AsRef) -> bool { +fn bare_type(ty: Box) -> Option { match *ty { - syn::Type::Array(_) => false, - syn::Type::BareFn(_) => false, - syn::Type::ImplTrait(_) => false, - syn::Type::Infer(_) => false, - syn::Type::Macro(_) => false, - syn::Type::Never(_) => false, - syn::Type::Ptr(_) => false, - syn::Type::Slice(_) => false, - syn::Type::TraitObject(_) => false, - syn::Type::Tuple(_) => false, - syn::Type::Verbatim(_) => false, - syn::Type::Group(g) => type_equals(g.elem, search), - syn::Type::Paren(p) => type_equals(p.elem, search), - syn::Type::Reference(r) => type_equals(r.elem, search), - syn::Type::Path(ty) => { - ty.path.segments - .last() - .map_or(false, |e| e.ident == search.as_ref()) - }, - _ => false, + syn::Type::Array(a) => bare_type(a.elem), + syn::Type::BareFn(_) => None, + syn::Type::ImplTrait(_) => None, + syn::Type::Infer(_) => None, + syn::Type::Macro(_) => None, + syn::Type::Never(_) => None, + syn::Type::TraitObject(_) => None, + syn::Type::Verbatim(_) => None, + syn::Type::Ptr(p) => bare_type(p.elem), + syn::Type::Slice(s) => bare_type(s.elem), + syn::Type::Tuple(t) => bare_type(Box::new(t.elems.first()?.clone())), // TODO + syn::Type::Group(g) => bare_type(g.elem), + syn::Type::Paren(p) => bare_type(p.elem), + syn::Type::Reference(r) => bare_type(r.elem), + syn::Type::Path(ty) => Some(ty), + _ => todo!(), } } +fn type_equals(ty: Box, search: impl AsRef) -> bool { + let Some(ty) = bare_type(ty) else { return false }; + let Some(last) = ty.path.segments.last() else { return false }; + last.ident == search.as_ref() +} + impl ArgumentOptions { pub(crate) fn parse_args(fn_item: &syn::ItemFn, ret_expr: TokenStream) -> Result { let mut arguments = Vec::new(); @@ -83,9 +85,9 @@ impl ArgumentOptions { if pass_env { if let Some(arg) = args_iter.next() { let pat = arg.pat; - let ty = arg.ty; + let ty = bare_type(arg.ty); incoming.append_all(quote::quote!( mut #pat: #ty,)); - forwarding.append_all(quote::quote!( #pat,)); + forwarding.append_all(quote::quote!( &mut #pat,)); } } else { incoming.append_all(quote::quote!( mut #env: jni::JNIEnv<'local>,)); @@ -121,25 +123,3 @@ struct SingleArgument { pat: syn::Ident, ty: Box, } - -#[allow(unused)] -fn bare_type(t: syn::Type) -> TokenStream { - match t { - syn::Type::Array(x) => bare_type(*x.elem), - syn::Type::BareFn(f) => f.to_token_stream(), - syn::Type::Group(x) => bare_type(*x.elem), - syn::Type::ImplTrait(t) => t.to_token_stream(), - syn::Type::Infer(x) => x.to_token_stream(), - syn::Type::Macro(x) => x.to_token_stream(), - syn::Type::Never(x) => x.to_token_stream(), - syn::Type::Paren(p) => bare_type(*p.elem), - syn::Type::Path(p) => p.to_token_stream(), - syn::Type::Ptr(x) => bare_type(*x.elem), - syn::Type::Reference(r) => bare_type(*r.elem), - syn::Type::Slice(s) => bare_type(*s.elem), - syn::Type::TraitObject(t) => t.to_token_stream(), - syn::Type::Tuple(x) => x.to_token_stream(), - syn::Type::Verbatim(x) => x.to_token_stream(), - _ => todo!(), - } -}