diff --git a/Cargo.lock b/Cargo.lock index fee4fd2..786a055 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -240,7 +240,7 @@ dependencies = [ "jni", "jni-toolbox", "lazy_static", - "mlua-codemp-patch", + "mlua", "napi", "napi-build", "napi-derive", @@ -911,12 +911,13 @@ dependencies = [ ] [[package]] -name = "mlua-codemp-patch" -version = "0.10.0-beta.2" +name = "mlua" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a52f529509c236114a5cf5bb3c0c06ff0695ad45d718256930ec2416edf3817" +checksum = "0ae9546e4a268c309804e8bbb7526e31cbfdedca7cd60ac1b987d0b212e0d876" dependencies = [ "bstr", + "either", "erased-serde", "mlua-sys", "mlua_derive", @@ -929,9 +930,9 @@ dependencies = [ [[package]] name = "mlua-sys" -version = "0.6.4" +version = "0.6.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9eebac25c35a13285456c88ee2fde93d9aee8bcfdaf03f9d6d12be3391351ec" +checksum = "efa6bf1a64f06848749b7e7727417f4ec2121599e2a10ef0a8a3888b0e9a5a0d" dependencies = [ "cc", "cfg-if", @@ -940,9 +941,9 @@ dependencies = [ [[package]] name = "mlua_derive" -version = "0.10.0-beta.1" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13e6f40fa1fd8426285688f4a37b56beac69284743d057ee6db352b543f4b621" +checksum = "2cfc5faa2e0d044b3f5f0879be2920e0a711c97744c42cf1c295cb183668933e" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 4394369..df57212 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,7 +44,7 @@ jni = { version = "0.21", features = ["invocation"], optional = true } jni-toolbox = { version = "0.2", optional = true, features = ["uuid"] } # glue (lua) -mlua-codemp-patch = { version = "0.10.0-beta.2", features = ["module", "send", "serialize"], optional = true } +mlua = { version = "0.10", features = ["module", "serialize", "error-send"], optional = true } # glue (js) napi = { version = "2.16", features = ["full"], optional = true } @@ -74,10 +74,10 @@ test-e2e = [] java = ["dep:lazy_static", "dep:jni", "dep:tracing-subscriber", "dep:jni-toolbox"] js = ["dep:napi-build", "dep:tracing-subscriber", "dep:napi", "dep:napi-derive"] py = ["dep:pyo3", "dep:tracing-subscriber", "dep:pyo3-build-config"] -lua = ["serialize", "dep:mlua-codemp-patch", "dep:tracing-subscriber", "dep:lazy_static"] +lua = ["serialize", "dep:mlua", "dep:tracing-subscriber", "dep:lazy_static"] # ffi variants -lua-jit = ["mlua-codemp-patch?/luajit"] -lua-54 = ["mlua-codemp-patch?/lua54"] +lua-jit = ["mlua?/luajit"] +lua-54 = ["mlua?/lua54"] py-abi3 = ["pyo3?/abi3-py38"] diff --git a/src/buffer/controller.rs b/src/buffer/controller.rs index 62b6f6d..9f1a9d1 100644 --- a/src/buffer/controller.rs +++ b/src/buffer/controller.rs @@ -22,9 +22,14 @@ use crate::ext::IgnorableError; pub struct BufferController(pub(crate) Arc); impl BufferController { + /// Get id of workspace containing this controller + pub fn workspace_id(&self) -> &str { + &self.0.workspace_id + } + /// Get the buffer path. pub fn path(&self) -> &str { - &self.0.name + &self.0.path } /// Return buffer whole content, updating internal acknowledgement tracker. @@ -50,7 +55,7 @@ impl BufferController { #[derive(Debug)] pub(crate) struct BufferControllerInner { - pub(crate) name: String, + pub(crate) path: String, pub(crate) latest_version: watch::Receiver, pub(crate) local_version: watch::Receiver, pub(crate) ops_in: mpsc::UnboundedSender, @@ -59,6 +64,7 @@ pub(crate) struct BufferControllerInner { pub(crate) delta_request: mpsc::Sender>>, pub(crate) callback: watch::Sender>>, pub(crate) ack_tx: mpsc::UnboundedSender, + pub(crate) workspace_id: String, } #[cfg_attr(feature = "async-trait", async_trait::async_trait)] diff --git a/src/buffer/worker.rs b/src/buffer/worker.rs index 48de01b..8a9b405 100644 --- a/src/buffer/worker.rs +++ b/src/buffer/worker.rs @@ -40,6 +40,7 @@ impl BufferController { path: &str, tx: mpsc::Sender, rx: Streaming, + workspace_id: &str, ) -> Self { let init = diamond_types::LocalVersion::default(); @@ -57,7 +58,7 @@ impl BufferController { let agent_id = oplog.get_or_create_agent_id(&user_id.to_string()); let controller = Arc::new(BufferControllerInner { - name: path.to_string(), + path: path.to_string(), latest_version: latest_version_rx, local_version: my_version_rx, ops_in: opin_tx, @@ -66,6 +67,7 @@ impl BufferController { delta_request: recv_tx, callback: cb_tx, ack_tx, + workspace_id: workspace_id.to_string(), }); let weak = Arc::downgrade(&controller); diff --git a/src/cursor/controller.rs b/src/cursor/controller.rs index ec26069..d0c544c 100644 --- a/src/cursor/controller.rs +++ b/src/cursor/controller.rs @@ -25,12 +25,19 @@ use codemp_proto::{ #[cfg_attr(feature = "js", napi_derive::napi)] pub struct CursorController(pub(crate) Arc); +impl CursorController { + pub fn workspace_id(&self) -> &str { + &self.0.workspace_id + } +} + #[derive(Debug)] pub(crate) struct CursorControllerInner { pub(crate) op: mpsc::UnboundedSender, pub(crate) stream: mpsc::Sender>>, pub(crate) poll: mpsc::UnboundedSender>, pub(crate) callback: watch::Sender>>, + pub(crate) workspace_id: String, } #[cfg_attr(feature = "async-trait", async_trait::async_trait)] diff --git a/src/cursor/worker.rs b/src/cursor/worker.rs index a690a9a..16d5d26 100644 --- a/src/cursor/worker.rs +++ b/src/cursor/worker.rs @@ -28,6 +28,7 @@ impl CursorController { user_map: Arc>, tx: mpsc::Sender, rx: Streaming, + workspace_id: &str, ) -> Self { // TODO we should tweak the channel buffer size to better propagate backpressure let (op_tx, op_rx) = mpsc::unbounded_channel(); @@ -39,6 +40,7 @@ impl CursorController { stream: stream_tx, callback: cb_tx, poll: poll_tx, + workspace_id: workspace_id.to_string(), }); let weak = Arc::downgrade(&controller); diff --git a/src/ffi/lua/buffer.rs b/src/ffi/lua/buffer.rs index 2392734..d061b9c 100644 --- a/src/ffi/lua/buffer.rs +++ b/src/ffi/lua/buffer.rs @@ -1,6 +1,5 @@ use crate::prelude::*; use mlua::prelude::*; -use mlua_codemp_patch as mlua; use super::ext::a_sync::a_sync; @@ -31,11 +30,23 @@ impl LuaUserData for CodempBufferController { |_, this, ()| a_sync! { this => this.content().await? }, ); - methods.add_method("clear_callback", |_, this, ()| Ok(this.clear_callback())); - methods.add_method("callback", |_, this, (cb,): (LuaFunction,)| { + methods.add_method("clear_callback", move |lua, this, ()| { + this.clear_callback(); + lua.unset_named_registry_value(&this.lua_callback_id()) + }); + + methods.add_method("callback", move |lua, this, (cb,): (LuaFunction,)| { + let key = this.lua_callback_id(); + lua.set_named_registry_value(&key, cb)?; Ok(this.callback(move |controller: CodempBufferController| { - super::ext::callback().invoke(cb.clone(), controller) + super::ext::callback().invoke(key.clone(), controller, false) })) }); } } + +impl CodempBufferController { + fn lua_callback_id(&self) -> String { + format!("codemp-buffercontroller({}:{})-callback-registry", self.workspace_id(), self.path()) + } +} diff --git a/src/ffi/lua/client.rs b/src/ffi/lua/client.rs index 4a8f4d2..4c69c40 100644 --- a/src/ffi/lua/client.rs +++ b/src/ffi/lua/client.rs @@ -1,6 +1,5 @@ use crate::prelude::*; use mlua::prelude::*; -use mlua_codemp_patch as mlua; use super::ext::a_sync::a_sync; diff --git a/src/ffi/lua/cursor.rs b/src/ffi/lua/cursor.rs index 5a7c220..d590034 100644 --- a/src/ffi/lua/cursor.rs +++ b/src/ffi/lua/cursor.rs @@ -1,6 +1,5 @@ use crate::prelude::*; use mlua::prelude::*; -use mlua_codemp_patch as mlua; use super::ext::a_sync::a_sync; @@ -22,11 +21,22 @@ impl LuaUserData for CodempCursorController { methods.add_method("recv", |_, this, ()| a_sync! { this => this.recv().await? }); methods.add_method("poll", |_, this, ()| a_sync! { this => this.poll().await? }); - methods.add_method("clear_callback", |_, this, ()| Ok(this.clear_callback())); - methods.add_method("callback", |_, this, (cb,): (LuaFunction,)| { + methods.add_method("clear_callback", |lua, this, ()| { + this.clear_callback(); + lua.unset_named_registry_value(&this.lua_callback_id()) + }); + methods.add_method("callback", |lua, this, (cb,): (LuaFunction,)| { + let key = this.lua_callback_id(); + lua.set_named_registry_value(&key, cb)?; Ok(this.callback(move |controller: CodempCursorController| { - super::ext::callback().invoke(cb.clone(), controller) + super::ext::callback().invoke(key.clone(), controller, false) })) }); } } + +impl CodempCursorController { + fn lua_callback_id(&self) -> String { + format!("codemp-cursorcontroller({})-callback-registry", self.workspace_id()) + } +} diff --git a/src/ffi/lua/ext/a_sync.rs b/src/ffi/lua/ext/a_sync.rs index 2e10776..13600dc 100644 --- a/src/ffi/lua/ext/a_sync.rs +++ b/src/ffi/lua/ext/a_sync.rs @@ -1,5 +1,4 @@ use mlua::prelude::*; -use mlua_codemp_patch as mlua; pub(crate) fn tokio() -> &'static tokio::runtime::Runtime { use std::sync::OnceLock; @@ -21,7 +20,8 @@ macro_rules! a_sync { Some( crate::ffi::lua::ext::a_sync::tokio() .spawn(async move { - Ok(crate::ffi::lua::ext::callback::CallbackArg::from($x)) + let res = $x; + Ok(crate::ffi::lua::ext::callback::CallbackArg::from(res)) }) ) ) @@ -47,13 +47,20 @@ impl LuaUserData for Promise { // TODO: await MUST NOT be used in callbacks!! methods.add_method_mut("await", |_, this, ()| match this.0.take() { None => Err(LuaError::runtime("Promise already awaited")), - Some(x) => tokio().block_on(x).map_err(LuaError::runtime)?, + Some(x) => Ok( + tokio() + .block_on(x) + .map_err(LuaError::runtime)? + .map_err(LuaError::runtime)? + ), }); methods.add_method_mut("cancel", |_, this, ()| match this.0.take() { None => Err(LuaError::runtime("Promise already awaited")), Some(x) => Ok(x.abort()), }); - methods.add_method_mut("and_then", |_, this, (cb,): (LuaFunction,)| { + methods.add_method_mut("and_then", |lua, this, (cb,): (LuaFunction,)| { + let key = uuid::Uuid::new_v4().to_string(); + lua.set_named_registry_value(&key, cb)?; match this.0.take() { None => Err(LuaError::runtime("Promise already awaited")), Some(x) => { @@ -64,9 +71,9 @@ impl LuaUserData for Promise { } Ok(res) => match res { Err(e) => super::callback().failure(e), - Ok(val) => super::callback().invoke(cb, val), + Ok(val) => super::callback().invoke(key, val, true), }, - } + }; }); Ok(()) } diff --git a/src/ffi/lua/ext/callback.rs b/src/ffi/lua/ext/callback.rs index 54d8e03..269fa8f 100644 --- a/src/ffi/lua/ext/callback.rs +++ b/src/ffi/lua/ext/callback.rs @@ -1,7 +1,6 @@ use crate::ext::IgnorableError; use crate::prelude::*; use mlua::prelude::*; -use mlua_codemp_patch as mlua; pub(crate) fn callback() -> &'static CallbackChannel { static CHANNEL: std::sync::OnceLock> = std::sync::OnceLock::new(); @@ -25,21 +24,19 @@ impl Default for CallbackChannel { } impl CallbackChannel { - pub(crate) fn invoke(&self, cb: LuaFunction, arg: impl Into) { + pub(crate) fn invoke(&self, key: String, arg: impl Into, cleanup: bool) { self.tx - .send(LuaCallback::Invoke(cb, arg.into())) + .send(LuaCallback::Invoke(key, arg.into(), cleanup)) .unwrap_or_warn("error scheduling callback") } pub(crate) fn failure(&self, err: impl std::error::Error) { self.tx - .send(LuaCallback::Fail(format!( - "promise failed with error: {err:?}" - ))) + .send(LuaCallback::Fail(format!("callback returned error: {err:?}"))) .unwrap_or_warn("error scheduling callback failure") } - pub(crate) fn recv(&self) -> Option { + pub(crate) fn recv(&self, lua: &Lua) -> Option<(LuaFunction, CallbackArg)> { match self.rx.try_lock() { Err(e) => { tracing::debug!("backing off from callback mutex: {e}"); @@ -51,7 +48,25 @@ impl CallbackChannel { None } Err(tokio::sync::mpsc::error::TryRecvError::Empty) => None, - Ok(cb) => Some(cb), + Ok(LuaCallback::Fail(msg)) => { + tracing::error!("callback returned error: {msg}"); + None + }, + Ok(LuaCallback::Invoke(key, arg, cleanup)) => { + let cb = match lua.named_registry_value::(&key) { + Ok(x) => x, + Err(e) => { + tracing::error!("could not get callback to invoke: {e}"); + return None; + }, + }; + if cleanup { + if let Err(e) = lua.unset_named_registry_value(&key) { + tracing::warn!("could not unset callback from registry: {e}"); + } + } + Some((cb, arg)) + }, }, } } @@ -59,7 +74,7 @@ impl CallbackChannel { pub(crate) enum LuaCallback { Fail(String), - Invoke(LuaFunction, CallbackArg), + Invoke(String, CallbackArg, bool), } macro_rules! callback_args { diff --git a/src/ffi/lua/ext/log.rs b/src/ffi/lua/ext/log.rs index e22bc64..b9728e9 100644 --- a/src/ffi/lua/ext/log.rs +++ b/src/ffi/lua/ext/log.rs @@ -1,7 +1,6 @@ use std::{io::Write, sync::Mutex}; use mlua::prelude::*; -use mlua_codemp_patch as mlua; use tokio::sync::mpsc; #[derive(Debug, Clone)] @@ -19,7 +18,7 @@ impl Write for LuaLoggerProducer { // TODO can we make this less verbose? pub(crate) fn setup_tracing( - _: &Lua, + lua: &Lua, (printer, debug): (LuaValue, Option), ) -> LuaResult { let level = if debug.unwrap_or_default() { @@ -37,14 +36,6 @@ pub(crate) fn setup_tracing( .with_source_location(false); let success = match printer { - LuaValue::Boolean(_) - | LuaValue::LightUserData(_) - | LuaValue::Integer(_) - | LuaValue::Number(_) - | LuaValue::Table(_) - | LuaValue::Thread(_) - | LuaValue::UserData(_) - | LuaValue::Error(_) => return Err(LuaError::BindError), // TODO full BadArgument type?? LuaValue::Nil => tracing_subscriber::fmt() .event_format(format) .with_max_level(level) @@ -63,6 +54,8 @@ pub(crate) fn setup_tracing( .is_ok() } LuaValue::Function(cb) => { + let key = uuid::Uuid::new_v4().to_string(); + lua.set_named_registry_value(&key, cb)?; let (tx, mut rx) = mpsc::unbounded_channel(); let res = tracing_subscriber::fmt() .event_format(format) @@ -74,12 +67,13 @@ pub(crate) fn setup_tracing( if res { super::a_sync::tokio().spawn(async move { while let Some(msg) = rx.recv().await { - super::callback().invoke(cb.clone(), msg); + super::callback().invoke(key.clone(), msg, false); } }); } res - } + }, + _ => return Err(LuaError::BindError), // TODO full BadArgument type?? }; Ok(success) diff --git a/src/ffi/lua/mod.rs b/src/ffi/lua/mod.rs index 33792d0..7c4a5e2 100644 --- a/src/ffi/lua/mod.rs +++ b/src/ffi/lua/mod.rs @@ -6,7 +6,6 @@ mod workspace; use crate::prelude::*; use mlua::prelude::*; -use mlua_codemp_patch as mlua; // define multiple entrypoints, so this library can have multiple names and still work #[mlua::lua_module(name = "codemp")] @@ -57,16 +56,12 @@ fn entrypoint(lua: &Lua) -> LuaResult { "poll_callback", lua.create_function(|lua, ()| { let mut val = LuaMultiValue::new(); - match ext::callback().recv() { + match ext::callback().recv(lua) { None => {} - Some(ext::callback::LuaCallback::Invoke(cb, arg)) => { + Some((cb, arg)) => { val.push_back(LuaValue::Function(cb)); val.push_back(arg.into_lua(lua)?); } - Some(ext::callback::LuaCallback::Fail(msg)) => { - val.push_back(false.into_lua(lua)?); - val.push_back(msg.into_lua(lua)?); - } } Ok(val) })?, diff --git a/src/ffi/lua/workspace.rs b/src/ffi/lua/workspace.rs index 2094f75..3ce118e 100644 --- a/src/ffi/lua/workspace.rs +++ b/src/ffi/lua/workspace.rs @@ -1,6 +1,5 @@ use crate::prelude::*; use mlua::prelude::*; -use mlua_codemp_patch as mlua; use super::ext::a_sync::a_sync; @@ -67,12 +66,23 @@ impl LuaUserData for CodempWorkspace { methods.add_method("poll", |_, this, ()| a_sync! { this => this.poll().await? }); - methods.add_method("callback", |_, this, (cb,): (LuaFunction,)| { + methods.add_method("callback", |lua, this, (cb,): (LuaFunction,)| { + let key = this.lua_callback_id(); + lua.set_named_registry_value(&key, cb)?; Ok(this.callback(move |controller: CodempWorkspace| { - super::ext::callback().invoke(cb.clone(), controller) + super::ext::callback().invoke(key.clone(), controller, false) })) }); - methods.add_method("clear_callback", |_, this, ()| Ok(this.clear_callback())); + methods.add_method("clear_callback", |lua, this, ()| { + this.clear_callback(); + lua.unset_named_registry_value(&this.lua_callback_id()) + }); + } +} + +impl CodempWorkspace { + fn lua_callback_id(&self) -> String { + format!("codemp-workspace({})-callback-registry", self.id()) } } diff --git a/src/workspace.rs b/src/workspace.rs index 4d0a788..bd46f5c 100644 --- a/src/workspace.rs +++ b/src/workspace.rs @@ -113,7 +113,7 @@ impl Workspace { let users = Arc::new(DashMap::default()); - let controller = cursor::Controller::spawn(users.clone(), tx, cur_stream); + let controller = cursor::Controller::spawn(users.clone(), tx, cur_stream, &name); let ws = Self(Arc::new(WorkspaceInner { name, @@ -175,7 +175,7 @@ impl Workspace { ); let stream = self.0.services.buf().attach(req).await?.into_inner(); - let controller = buffer::Controller::spawn(self.0.user.id, path, tx, stream); + let controller = buffer::Controller::spawn(self.0.user.id, path, tx, stream, &self.0.name); self.0.buffers.insert(path.to_string(), controller.clone()); Ok(controller)