fix(python): release GIL when spawning stuff, we don't need to keep it and we risk

deadlocking
feat(python): added allow_thread in the macro
This commit is contained in:
cschen 2024-08-21 17:20:12 +02:00
parent fe2f2a3ae0
commit dc7ae20b7d
4 changed files with 76 additions and 55 deletions

View file

@ -24,11 +24,10 @@ impl Client {
// }
#[pyo3(name = "join_workspace")]
fn pyjoin_workspace(&self, workspace: String) -> PyResult<super::Promise> {
fn pyjoin_workspace(&self, py: Python<'_>, workspace: String) -> PyResult<super::Promise> {
tracing::info!("attempting to join the workspace {}", workspace);
let this = self.clone();
crate::a_sync!(this.join_workspace(workspace).await)
crate::a_sync_allow_threads!(py, this.join_workspace(workspace).await)
// let this = self.clone();
// Ok(super::Promise(Some(tokio().spawn(async move {
// Ok(this

View file

@ -6,13 +6,19 @@ use crate::cursor::Controller as CursorController;
use pyo3::prelude::*;
use super::Promise;
use crate::a_sync;
use crate::a_sync_allow_threads;
// need to do manually since Controller is a trait implementation
#[pymethods]
impl CursorController {
#[pyo3(name = "send")]
fn pysend(&self, path: String, start: (i32, i32), end: (i32, i32)) -> PyResult<Promise> {
fn pysend(
&self,
py: Python,
path: String,
start: (i32, i32),
end: (i32, i32),
) -> PyResult<Promise> {
let pos = Cursor {
start,
end,
@ -20,26 +26,26 @@ impl CursorController {
user: None,
};
let this = self.clone();
a_sync!(this.send(pos).await)
a_sync_allow_threads!(py, this.send(pos).await)
}
#[pyo3(name = "try_recv")]
fn pytry_recv(&self) -> PyResult<Promise> {
fn pytry_recv(&self, py: Python) -> PyResult<Promise> {
let this = self.clone();
a_sync!(this.try_recv().await)
a_sync_allow_threads!(py, this.try_recv().await)
}
#[pyo3(name = "recv")]
fn pyrecv(&self) -> crate::Result<Option<Cursor>> {
Ok(super::tokio().block_on(self.try_recv())?)
fn pyrecv(&self, py: Python) -> crate::Result<Option<Cursor>> {
py.allow_threads(|| super::tokio().block_on(self.try_recv()))
// let this = self.clone();
// a_sync!(this.recv().await)
// a_sync_allow_threads!(py, this.recv().await)
}
#[pyo3(name = "poll")]
fn pypoll(&self) -> PyResult<Promise> {
fn pypoll(&self, py: Python) -> PyResult<Promise> {
let this = self.clone();
a_sync!(this.poll().await)
a_sync_allow_threads!(py, this.poll().await)
}
#[pyo3(name = "stop")]
@ -52,13 +58,13 @@ impl CursorController {
#[pymethods]
impl BufferController {
#[pyo3(name = "content")]
async fn pycontent(&self) -> PyResult<Promise> {
fn pycontent(&self, py: Python) -> PyResult<Promise> {
let this = self.clone();
a_sync!(this.content().await)
a_sync_allow_threads!(py, this.content().await)
}
#[pyo3(name = "send")]
async fn pysend(&self, start: u32, end: u32, txt: String) -> PyResult<Promise> {
fn pysend(&self, py: Python, start: u32, end: u32, txt: String) -> PyResult<Promise> {
let op = TextChange {
start,
end,
@ -66,26 +72,26 @@ impl BufferController {
hash: None,
};
let this = self.clone();
a_sync!(this.send(op).await)
a_sync_allow_threads!(py, this.send(op).await)
}
#[pyo3(name = "try_recv")]
fn pytry_recv(&self) -> crate::Result<Option<TextChange>> {
Ok(super::tokio().block_on(self.try_recv())?)
fn pytry_recv(&self, py: Python) -> crate::Result<Option<TextChange>> {
py.allow_threads(|| super::tokio().block_on(self.try_recv()))
// let this = self.clone();
// a_sync!(this.try_recv().await)
// a_sync_allow_threads!(py, this.try_recv().await)
}
#[pyo3(name = "recv")]
async fn pyrecv(&self) -> PyResult<Promise> {
fn pyrecv(&self, py: Python) -> PyResult<Promise> {
let this = self.clone();
a_sync!(this.recv().await)
a_sync_allow_threads!(py, this.recv().await)
}
#[pyo3(name = "poll")]
async fn pypoll(&self) -> PyResult<Promise> {
fn pypoll(&self, py: Python) -> PyResult<Promise> {
let this = self.clone();
a_sync!(this.poll().await)
a_sync_allow_threads!(py, this.poll().await)
}
#[pyo3(name = "stop")]

View file

@ -37,8 +37,8 @@ pub struct Promise(Option<tokio::task::JoinHandle<PyResult<PyObject>>>);
#[pymethods]
impl Promise {
#[pyo3(name = "wait")]
fn _await(&mut self) -> PyResult<PyObject> {
match self.0.take() {
fn _await(&mut self, py: Python<'_>) -> PyResult<PyObject> {
py.allow_threads(move || match self.0.take() {
None => Err(PyRuntimeError::new_err(
"promise can't be awaited multiple times!",
)),
@ -48,15 +48,17 @@ impl Promise {
))),
Ok(res) => res,
},
}
})
}
fn done(&self) -> PyResult<bool> {
if let Some(handle) = &self.0 {
Ok(handle.is_finished())
} else {
Err(PyRuntimeError::new_err("promise was already awaited."))
}
fn done(&self, py: Python<'_>) -> PyResult<bool> {
py.allow_threads(|| {
if let Some(handle) = &self.0 {
Ok(handle.is_finished())
} else {
Err(PyRuntimeError::new_err("promise was already awaited."))
}
})
}
}
@ -70,6 +72,18 @@ macro_rules! a_sync {
}};
}
#[macro_export]
macro_rules! a_sync_allow_threads {
($py:ident, $x:expr) => {{
$py.allow_threads(move || {
Ok($crate::ffi::python::Promise(Some(
$crate::ffi::python::tokio()
.spawn(async move { Ok($x.map(|f| Python::with_gil(|py| f.into_py(py)))?) }),
)))
})
}};
}
#[derive(Debug, Clone)]
struct LoggerProducer(mpsc::UnboundedSender<String>);
@ -125,7 +139,7 @@ fn init(logging_cb: Py<PyFunction>, debug: bool) -> PyResult<PyObject> {
.with_writer(std::sync::Mutex::new(LoggerProducer(tx)))
.try_init();
let (rt_stop_tx, rt_stop_rx) = oneshot::channel::<()>();
let (rt_stop_tx, mut rt_stop_rx) = oneshot::channel::<()>();
match log_subscribing {
Ok(_) => {
@ -133,12 +147,14 @@ fn init(logging_cb: Py<PyFunction>, debug: bool) -> PyResult<PyObject> {
// python logger.
std::thread::spawn(move || {
tokio().block_on(async move {
tokio::select! {
biased;
_ = rt_stop_rx => { todo!() },
Some(msg) = rx.recv() => {
let _ = Python::with_gil(|py| logging_cb.call1(py, (msg,)));
},
loop {
tokio::select! {
biased;
Some(msg) = rx.recv() => {
let _ = Python::with_gil(|py| logging_cb.call1(py, (msg,)));
},
_ = &mut rt_stop_rx => { todo!() },
}
}
})
});

View file

@ -4,22 +4,22 @@ use crate::workspace::Workspace;
use pyo3::prelude::*;
use super::Promise;
use crate::a_sync;
use crate::a_sync_allow_threads;
// use super::Promise;
#[pymethods]
impl Workspace {
// join a workspace
#[pyo3(name = "create")]
fn pycreate(&self, path: String) -> PyResult<Promise> {
fn pycreate(&self, py: Python, path: String) -> PyResult<Promise> {
let this = self.clone();
a_sync!(this.create(path.as_str()).await)
a_sync_allow_threads!(py, this.create(path.as_str()).await)
}
#[pyo3(name = "attach")]
fn pyattach(&self, path: String) -> PyResult<Promise> {
fn pyattach(&self, py: Python, path: String) -> PyResult<Promise> {
let this = self.clone();
a_sync!(this.attach(path.as_str()).await)
a_sync_allow_threads!(py, this.attach(path.as_str()).await)
}
#[pyo3(name = "detach")]
@ -32,34 +32,34 @@ impl Workspace {
}
#[pyo3(name = "event")]
fn pyevent(&self) -> PyResult<Promise> {
fn pyevent(&self, py: Python) -> PyResult<Promise> {
let this = self.clone();
a_sync!(this.event().await)
a_sync_allow_threads!(py, this.event().await)
}
#[pyo3(name = "fetch_buffers")]
fn pyfetch_buffers(&self) -> PyResult<Promise> {
fn pyfetch_buffers(&self, py: Python) -> PyResult<Promise> {
let this = self.clone();
a_sync!(this.fetch_buffers().await)
a_sync_allow_threads!(py, this.fetch_buffers().await)
}
#[pyo3(name = "fetch_users")]
fn pyfetch_users(&self) -> PyResult<Promise> {
fn pyfetch_users(&self, py: Python) -> PyResult<Promise> {
let this = self.clone();
a_sync!(this.fetch_users().await)
a_sync_allow_threads!(py, this.fetch_users().await)
}
#[pyo3(name = "list_buffer_users")]
fn pylist_buffer_users(&self, path: String) -> PyResult<Promise> {
fn pylist_buffer_users(&self, py: Python, path: String) -> PyResult<Promise> {
// crate::Result<Vec<crate::api::User>>
let this = self.clone();
a_sync!(this.list_buffer_users(path.as_str()).await)
a_sync_allow_threads!(py, this.list_buffer_users(path.as_str()).await)
}
#[pyo3(name = "delete")]
fn pydelete(&self, path: String) -> PyResult<Promise> {
fn pydelete(&self, py: Python, path: String) -> PyResult<Promise> {
let this = self.clone();
a_sync!(this.delete(path.as_str()).await)
a_sync_allow_threads!(py, this.delete(path.as_str()).await)
}
#[pyo3(name = "id")]