diff --git a/Cargo.lock b/Cargo.lock index 861cbe1..b021645 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -54,15 +54,15 @@ checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" [[package]] name = "jni-toolbox" -version = "0.1.1" +version = "0.1.2" dependencies = [ "jni", - "jni-toolbox-macro 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", + "jni-toolbox-macro 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] name = "jni-toolbox-macro" -version = "0.1.1" +version = "0.1.2" dependencies = [ "proc-macro2", "quote", @@ -71,9 +71,9 @@ dependencies = [ [[package]] name = "jni-toolbox-macro" -version = "0.1.1" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac89a7b1059ecda1171ee751e153c5b00722cbce7a12781a664489155af8112b" +checksum = "745ca1133357c564fec2b82d0bbe2c9b537b952ac2525667802c298015eaa494" dependencies = [ "proc-macro2", "quote", diff --git a/macro/Cargo.toml b/macro/Cargo.toml index 8b388e2..c6bb07b 100644 --- a/macro/Cargo.toml +++ b/macro/Cargo.toml @@ -11,7 +11,6 @@ edition = "2021" [lib] proc-macro = true -path = "macro.rs" [dependencies] syn = {version = "2.0", features = ["full"]} diff --git a/macro/src/attrs.rs b/macro/src/attrs.rs new file mode 100644 index 0000000..bf66bf8 --- /dev/null +++ b/macro/src/attrs.rs @@ -0,0 +1,72 @@ +use proc_macro2::{Span, TokenStream, TokenTree}; + +pub(crate) struct AttrsOptions { + pub(crate) package: String, + pub(crate) class: String, + pub(crate) exception: Option, + pub(crate) return_pointer: bool, + pub(crate) without_env: bool, + pub(crate) without_class: bool, +} + +impl AttrsOptions { + + pub(crate) fn parse_attr(attrs: TokenStream) -> Result { + let mut what_next = WhatNext::Nothing; + + let mut package = None; + let mut class = None; + let mut exception = None; + let mut return_pointer = false; + let mut without_env = false; + let mut without_class = false; + + for attr in attrs { + match what_next { + WhatNext::Nothing => { + if let TokenTree::Ident(ref i) = attr { + match i.to_string().as_ref() { + "package" => what_next = WhatNext::Package, + "class" => what_next = WhatNext::Class, + "exception" => what_next = WhatNext::Exception, + "ptr" => return_pointer = true, + "no_env" => without_env = true, + "no_class" => without_class = true, + _ => return Err(syn::Error::new(Span::call_site(), "unexpected attribute on macro: {attr}")), + } + } + }, + WhatNext::Class => { + if let TokenTree::Literal(i) = attr { + class = Some(i.to_string().replace('"', "")); + what_next = WhatNext::Nothing; + } + }, + WhatNext::Package => { + if let TokenTree::Literal(i) = attr { + package = Some(i.to_string().replace('"', "").replace(".", "_")); + what_next = WhatNext::Nothing; + } + }, + WhatNext::Exception => { + if let TokenTree::Literal(i) = attr { + exception = Some(i.to_string().replace('"', "").replace(".", "_")); + what_next = WhatNext::Nothing; + } + } + } + } + + let Some(package) = package else { return Err(syn::Error::new(Span::call_site(), "missing required attribute 'package'")) }; + let Some(class) = class else { return Err(syn::Error::new(Span::call_site(), "missing required attribute 'class'")) }; + + Ok(Self { package, class, exception, return_pointer, without_class, without_env }) + } +} + +enum WhatNext { + Nothing, + Package, + Class, + Exception, +} diff --git a/macro/src/lib.rs b/macro/src/lib.rs new file mode 100644 index 0000000..175a485 --- /dev/null +++ b/macro/src/lib.rs @@ -0,0 +1,17 @@ +mod attrs; +mod wrapper; + + +/// wrap this function in in a JNI exported fn +#[proc_macro_attribute] +pub fn jni( + attrs: proc_macro::TokenStream, + input: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + wrapper::generate_jni_wrapper( + syn::parse_macro_input!(attrs), + syn::parse_macro_input!(input), + ) + .unwrap() + .into() +} diff --git a/macro/src/wrapper.rs b/macro/src/wrapper.rs new file mode 100644 index 0000000..f2f3b7d --- /dev/null +++ b/macro/src/wrapper.rs @@ -0,0 +1,158 @@ +use proc_macro2::{Span, TokenStream}; +use quote::TokenStreamExt; +use syn::{FnArg, Item, ReturnType, Type}; + +use crate::attrs::AttrsOptions; + +pub(crate) fn generate_jni_wrapper(attrs: TokenStream, input: TokenStream) -> Result { + let mut out = TokenStream::new(); + + let Item::Fn(fn_item) = syn::parse2(input.clone())? else { + return Err(syn::Error::new(Span::call_site(), "#[jni] is only supported on functions")); + }; + + let attrs = AttrsOptions::parse_attr(attrs)?; + + let (could_error, ret_type) = match fn_item.sig.output { + syn::ReturnType::Default => (false, fn_item.sig.output), + syn::ReturnType::Type(_tok, ty) => match *ty { + syn::Type::Path(ref path) => { + let Some(last) = path.path.segments.last() else { + return Err(syn::Error::new(Span::call_site(), "empty Result type is not valid")); + }; + + // TODO this is terrible, macro returns a function and we call it?? there must be a + // better way!!! + let mut out = ( + false, + ReturnType::Type(syn::Token![->](Span::call_site()), Box::new(Type::Path(path.clone()))) + ); + + if last.ident == "Result" { + match &last.arguments { + syn::PathArguments::None => return Err(syn::Error::new(Span::call_site(), "Result without generics is not valid")), + syn::PathArguments::Parenthesized(_) => return Err(syn::Error::new(Span::call_site(), "Parenthesized Result is not valid")), + syn::PathArguments::AngleBracketed(ref generics) => for generic in generics.args.iter() { + match generic { + syn::GenericArgument::Lifetime(_) => continue, + syn::GenericArgument::Type(ty) => { + out = (true, ReturnType::Type(syn::Token![->](Span::call_site()), Box::new(ty.clone()))); + break; + }, + _ => return Err(syn::Error::new(Span::call_site(), "unexpected type in Result")), + } + } + } + } + + out + }, + _ => return Err(syn::Error::new(Span::call_site(), "unsupported return type")), + }, + }; + + + let mut incoming = TokenStream::new(); + let mut forwarding = TokenStream::new(); + + for arg in fn_item.sig.inputs { + let FnArg::Typed(ty) = arg else { + return Err(syn::Error::new(Span::call_site(), "#[jni] macro doesn't work on methods")); + }; + incoming.append_all(quote::quote!( #ty , )); + let pat = unpack_pat(*ty.pat)?; + forwarding.append_all(pat); + } + + let name = fn_item.sig.ident.to_string(); + let name_jni = name.replace("_", "_1"); + let fn_name_inner = syn::Ident::new(&name, Span::call_site()); + let fn_name = syn::Ident::new(&format!("Java_{}_{}_{name_jni}", attrs.package, attrs.class), Span::call_site()); + + let Some(env_ident) = forwarding.clone().into_iter().next() else { + return Err(syn::Error::new(Span::call_site(), "missing JNIEnv argument")); + }; + + let return_expr = if attrs.return_pointer { + quote::quote!( std::ptr::null_mut() ) + } else { + quote::quote!( 0 ) + }; + + let wrapped = if could_error { + if let Some(exception) = attrs.exception { + // V----------------------------------V + quote::quote! { + #[no_mangle] + #[allow(unused_mut)] + pub extern "system" fn #fn_name<'local>(#incoming) #ret_type { + use jni_toolbox::JniToolboxError; + match #fn_name_inner(#forwarding) { + Ok(ret) => ret, + Err(e) => match #env_ident.throw_new(#exception, format!("{e:?}")) { + Ok(_) => return #return_expr, + Err(e) => panic!("error throwing java exception: {e}"), + } + } + } + } + // ^----------------------------------^ + } else { + // V----------------------------------V + quote::quote! { + #[no_mangle] + #[allow(unused_mut)] + pub extern "system" fn #fn_name<'local>(#incoming) #ret_type { + use jni_toolbox::JniToolboxError; + // NOTE: this is SAFE! the cloned env reference lives less than the actual one, we just lack a + // way to get it back from the called function and thus resort to unsafe cloning + let mut env_copy = unsafe { #env_ident.unsafe_clone() }; + match #fn_name_inner(#forwarding) { + Err(e) => match env_copy.find_class(e.jclass()) { + Err(e) => panic!("error throwing Java exception -- failed resolving error class: {e}"), + Ok(class) => match env_copy.new_string(format!("{e:?}")) { + Err(e) => panic!("error throwing Java exception -- failed creating error string: {e}"), + Ok(msg) => match env_copy.new_object(class, "(Ljava/lang/String;)V", &[jni::objects::JValueGen::Object(&msg)]) { + Err(e) => panic!("error throwing Java exception -- failed creating object: {e}"), + Ok(obj) => match env_copy.throw(jni::objects::JThrowable::from(obj)) { + Err(e) => panic!("error throwing Java exception -- failed throwing: {e}"), + Ok(_) => return #return_expr, + }, + }, + }, + } + Ok(ret) => ret, + } + } + } + // ^----------------------------------^ + } + } else { + // V----------------------------------V + quote::quote! { + #[no_mangle] + #[allow(unused_mut)] + pub extern "system" fn #fn_name<'local>(#incoming) #ret_type { + #fn_name_inner(#forwarding) + } + } + // ^----------------------------------^ + }; + + out.append_all(input); + out.append_all(wrapped); + Ok(out) +} + +fn unpack_pat(pat: syn::Pat) -> Result { + match pat { + syn::Pat::Ident(i) => { + let ident = i.ident; + Ok(quote::quote!( #ident ,)) + }, + syn::Pat::Reference(r) => { + unpack_pat(*r.pat) + }, + _ => Err(syn::Error::new(Span::call_site(), "unsupported argument type")), + } +}