From 3047d21870edb29af774461655e9ac437baed06e Mon Sep 17 00:00:00 2001 From: alemi Date: Tue, 17 Sep 2024 23:27:27 +0200 Subject: [PATCH] feat(lua): also pass errors in callbacks --- .github/workflows/lua.yml | 1 + src/ffi/lua/buffer.rs | 3 +-- src/ffi/lua/cursor.rs | 3 +-- src/ffi/lua/ext/a_sync.rs | 8 ++++---- src/ffi/lua/ext/callback.rs | 37 ++++++++++++++++++++++++------------- src/ffi/lua/ext/log.rs | 2 +- src/ffi/lua/ext/mod.rs | 1 + src/ffi/lua/mod.rs | 13 +++++++++---- 8 files changed, 42 insertions(+), 26 deletions(-) diff --git a/.github/workflows/lua.yml b/.github/workflows/lua.yml index a48975f..6310df7 100644 --- a/.github/workflows/lua.yml +++ b/.github/workflows/lua.yml @@ -4,6 +4,7 @@ on: push: branches: - stable + - dev permissions: contents: read diff --git a/src/ffi/lua/buffer.rs b/src/ffi/lua/buffer.rs index 2e94894..acf3d04 100644 --- a/src/ffi/lua/buffer.rs +++ b/src/ffi/lua/buffer.rs @@ -4,7 +4,6 @@ use crate::prelude::*; use super::ext::a_sync::a_sync; use super::ext::from_lua_serde; -use super::ext::callback::CHANNEL; impl LuaUserData for CodempBufferController { @@ -25,7 +24,7 @@ impl LuaUserData for CodempBufferController { methods.add_method("clear_callback", |_, this, ()| { this.clear_callback(); Ok(()) }); methods.add_method("callback", |_, this, (cb,):(LuaFunction,)| { - this.callback(move |controller: CodempBufferController| CHANNEL.send(cb.clone(), controller)); + this.callback(move |controller: CodempBufferController| super::ext::callback().invoke(cb.clone(), controller)); Ok(()) }); } diff --git a/src/ffi/lua/cursor.rs b/src/ffi/lua/cursor.rs index dc43139..46d4e29 100644 --- a/src/ffi/lua/cursor.rs +++ b/src/ffi/lua/cursor.rs @@ -4,7 +4,6 @@ use crate::prelude::*; use super::ext::a_sync::a_sync; use super::ext::from_lua_serde; -use super::ext::callback::CHANNEL; use super::ext::lua_tuple; impl LuaUserData for CodempCursorController { @@ -24,7 +23,7 @@ impl LuaUserData for CodempCursorController { methods.add_method("clear_callback", |_, this, ()| { this.clear_callback(); Ok(()) }); methods.add_method("callback", |_, this, (cb,):(LuaFunction,)| { - this.callback(move |controller: CodempCursorController| CHANNEL.send(cb.clone(), controller)); + this.callback(move |controller: CodempCursorController| super::ext::callback().invoke(cb.clone(), controller)); Ok(()) }); } diff --git a/src/ffi/lua/ext/a_sync.rs b/src/ffi/lua/ext/a_sync.rs index 0b4e860..c994cc9 100644 --- a/src/ffi/lua/ext/a_sync.rs +++ b/src/ffi/lua/ext/a_sync.rs @@ -1,8 +1,6 @@ use mlua_codemp_patch as mlua; use mlua::prelude::*; -use super::callback::CHANNEL; - pub(crate) fn tokio() -> &'static tokio::runtime::Runtime { use std::sync::OnceLock; static RT: OnceLock = OnceLock::new(); @@ -60,8 +58,10 @@ impl LuaUserData for Promise { .spawn(async move { match x.await { Err(e) => tracing::error!("could not join promise to run callback: {e}"), - Ok(Err(e)) => tracing::error!("promise returned error: {e}"), - Ok(Ok(res)) => CHANNEL.send(cb, res), + Ok(res) => match res { + Err(e) => super::callback().failure(e), + Ok(val) => super::callback().invoke(cb, val), + }, } }); Ok(()) diff --git a/src/ffi/lua/ext/callback.rs b/src/ffi/lua/ext/callback.rs index 5c3d825..636f9c0 100644 --- a/src/ffi/lua/ext/callback.rs +++ b/src/ffi/lua/ext/callback.rs @@ -3,16 +3,17 @@ use mlua::prelude::*; use crate::prelude::*; use crate::ext::IgnorableError; -lazy_static::lazy_static! { - pub(crate) static ref CHANNEL: CallbackChannel = CallbackChannel::default(); +pub(crate) fn callback() -> &'static CallbackChannel { + static CHANNEL: std::sync::OnceLock> = std::sync::OnceLock::new(); + CHANNEL.get_or_init(CallbackChannel::default) } -pub(crate) struct CallbackChannel { - tx: std::sync::Arc>, - rx: std::sync::Mutex> +pub(crate) struct CallbackChannel { + tx: std::sync::Arc>, + rx: std::sync::Mutex> } -impl Default for CallbackChannel { +impl Default for CallbackChannel { fn default() -> Self { let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); let rx = std::sync::Mutex::new(rx); @@ -23,30 +24,40 @@ impl Default for CallbackChannel { } } -impl CallbackChannel { - pub(crate) fn send(&self, cb: LuaFunction, arg: impl Into) { - self.tx.send((cb, arg.into())) +impl CallbackChannel { + pub(crate) fn invoke(&self, cb: LuaFunction, arg: impl Into) { + self.tx.send(LuaCallback::Invoke(cb, arg.into())) .unwrap_or_warn("error scheduling callback") } - pub(crate) fn recv(&self) -> Option<(LuaFunction, CallbackArg)> { + pub(crate) fn failure(&self, err: impl std::error::Error) { + self.tx.send(LuaCallback::Fail(format!("promise failed with error: {err:?}"))) + .unwrap_or_warn("error scheduling callback failure") + } + + pub(crate) fn recv(&self) -> Option { match self.rx.try_lock() { Err(e) => { - tracing::warn!("could not acquire callback channel mutex: {e}"); + tracing::debug!("backing off from callback mutex: {e}"); None }, Ok(mut lock) => match lock.try_recv() { - Err(tokio::sync::mpsc::error::TryRecvError::Empty) => None, Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => { tracing::error!("callback channel closed"); None }, - Ok((cb, arg)) => Some((cb, arg)), + Err(tokio::sync::mpsc::error::TryRecvError::Empty) => None, + Ok(cb) => Some(cb), }, } } } +pub(crate) enum LuaCallback { + Fail(String), + Invoke(LuaFunction, CallbackArg), +} + pub(crate) enum CallbackArg { Nil, Str(String), diff --git a/src/ffi/lua/ext/log.rs b/src/ffi/lua/ext/log.rs index 76287c0..9d042c7 100644 --- a/src/ffi/lua/ext/log.rs +++ b/src/ffi/lua/ext/log.rs @@ -66,7 +66,7 @@ pub(crate) fn logger(_: &Lua, (printer, debug): (LuaValue, Option)) -> Lua if res { super::a_sync::tokio().spawn(async move { while let Some(msg) = rx.recv().await { - super::callback::CHANNEL.send(cb.clone(), msg); + super::callback().invoke(cb.clone(), msg); } }); } diff --git a/src/ffi/lua/ext/mod.rs b/src/ffi/lua/ext/mod.rs index e4cde45..081f22a 100644 --- a/src/ffi/lua/ext/mod.rs +++ b/src/ffi/lua/ext/mod.rs @@ -6,6 +6,7 @@ use mlua_codemp_patch as mlua; use mlua::prelude::*; pub(crate) use a_sync::tokio; +pub(crate) use callback::callback; pub(crate) fn lua_tuple(lua: &Lua, (a, b): (T, T)) -> LuaResult { let table = lua.create_table()?; diff --git a/src/ffi/lua/mod.rs b/src/ffi/lua/mod.rs index 55240af..188232f 100644 --- a/src/ffi/lua/mod.rs +++ b/src/ffi/lua/mod.rs @@ -30,11 +30,16 @@ fn entrypoint(lua: &Lua) -> LuaResult { // runtime exports.set("spawn_runtime_driver", lua.create_function(ext::a_sync::spawn_runtime_driver)?)?; exports.set("poll_callback", lua.create_function(|lua, ()| { - // TODO pass args too let mut val = LuaMultiValue::new(); - if let Some((cb, arg)) = ext::callback::CHANNEL.recv() { - val.push_back(LuaValue::Function(cb)); - val.push_back(arg.into_lua(lua)?); + match ext::callback().recv() { + None => {}, + Some(ext::callback::LuaCallback::Invoke(cb, arg)) => { + val.push_back(LuaValue::Function(cb)); + val.push_back(arg.into_lua(lua)?); + } + Some(ext::callback::LuaCallback::Fail(msg)) => { + return Err(LuaError::runtime(msg)); + }, } Ok(val) })?)?;