From 8e2f41a1c8fac456b650a466c67ed503c83249d8 Mon Sep 17 00:00:00 2001 From: alemi Date: Tue, 11 Apr 2023 22:35:37 +0200 Subject: [PATCH] chore: made OperationFactory async and mutexless --- src/lib/client.rs | 47 ++++++++-------- src/lib/opfactory.rs | 128 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 145 insertions(+), 30 deletions(-) diff --git a/src/lib/client.rs b/src/lib/client.rs index 89b499e..d022955 100644 --- a/src/lib/client.rs +++ b/src/lib/client.rs @@ -1,23 +1,21 @@ /// TODO better name for this file -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use tracing::{error, warn}; use uuid::Uuid; use crate::{ - opfactory::OperationFactory, + opfactory::AsyncFactory, proto::{buffer_client::BufferClient, BufferPayload, OperationRequest, RawOp}, tonic::{transport::Channel, Status, Streaming}, }; -type FactoryHandle = Arc>; - impl From::> for CodempClient { fn from(x: BufferClient) -> CodempClient { CodempClient { id: Uuid::new_v4(), client:x, - factory: Arc::new(Mutex::new(OperationFactory::new(None))) + factory: Arc::new(AsyncFactory::new(None)), } } } @@ -26,7 +24,7 @@ impl From::> for CodempClient { pub struct CodempClient { id: Uuid, client: BufferClient, - factory: FactoryHandle, // TODO less jank solution than Arc + factory: Arc, } impl CodempClient { @@ -46,8 +44,7 @@ impl CodempClient { } pub async fn insert(&mut self, path: String, txt: String, pos: u64) -> Result { - let res = { self.factory.lock().unwrap().insert(&txt, pos) }; - match res { + match self.factory.insert(txt, pos).await { Ok(op) => { Ok( self.client.edit( @@ -68,8 +65,7 @@ impl CodempClient { } pub async fn delete(&mut self, path: String, pos: u64, count: u64) -> Result { - let res = { self.factory.lock().unwrap().delete(pos, count) }; - match res { + match self.factory.delete(pos, count).await { Ok(op) => { Ok( self.client.edit( @@ -106,31 +102,32 @@ impl CodempClient { Ok(()) } - async fn worker ()>(mut stream: Streaming, factory: FactoryHandle, callback: F) { + pub async fn sync(&mut self, path: String) -> Result { + let res = self.client.sync( + BufferPayload { + path, content: None, user: self.id.to_string(), + } + ).await?; + Ok(res.into_inner().content.unwrap_or("".into())) + } + + async fn worker ()>(mut stream: Streaming, factory: Arc, callback: F) { loop { match stream.message().await { + Err(e) => break error!("error receiving change: {}", e), Ok(v) => match v { + None => break warn!("stream closed"), Some(operation) => { match serde_json::from_str(&operation.opseq) { - Ok(op) => { - let res = { factory.lock().unwrap().process(op) }; - match res { - Ok(x) => callback(x), - Err(e) => break error!("desynched: {}", e), - } - }, Err(e) => break error!("could not deserialize opseq: {}", e), + Ok(op) => match factory.process(op).await { + Err(e) => break error!("desynched: {}", e), + Ok(x) => callback(x), + }, } } - None => break warn!("stream closed"), }, - Err(e) => break error!("error receiving change: {}", e), } } } - - pub fn content(&self) -> String { - let factory = self.factory.lock().unwrap(); - factory.content() - } } diff --git a/src/lib/opfactory.rs b/src/lib/opfactory.rs index 707c951..8c92a8b 100644 --- a/src/lib/opfactory.rs +++ b/src/lib/opfactory.rs @@ -1,4 +1,6 @@ use operational_transform::{OperationSeq, OTError}; +use tokio::sync::{mpsc, watch, oneshot}; +use tracing::error; #[derive(Clone)] pub struct OperationFactory { @@ -30,11 +32,11 @@ impl OperationFactory { pub fn insert(&mut self, txt: &str, pos: u64) -> Result { let mut out = OperationSeq::default(); - let len = self.content.len() as u64; + let total = self.content.len() as u64; out.retain(pos); out.insert(txt); - out.retain(len - pos); - self.content = out.apply(&self.content)?; // TODO does applying mutate the OpSeq itself? + out.retain(total - pos); + self.content = out.apply(&self.content)?; Ok(out) } @@ -44,7 +46,7 @@ impl OperationFactory { out.retain(pos - count); out.delete(count); out.retain(len - pos); - self.content = out.apply(&self.content)?; // TODO does applying mutate the OpSeq itself? + self.content = out.apply(&self.content)?; Ok(out) } @@ -54,7 +56,7 @@ impl OperationFactory { out.retain(pos); out.delete(count); out.retain(len - (pos+count)); - self.content = out.apply(&self.content)?; // TODO does applying mutate the OpSeq itself? + self.content = out.apply(&self.content)?; Ok(out) } @@ -63,3 +65,119 @@ impl OperationFactory { Ok(self.content.clone()) } } + + +pub struct AsyncFactory { + run: watch::Sender, + ops: mpsc::Sender, + content: watch::Receiver, +} + +impl Drop for AsyncFactory { + fn drop(&mut self) { + self.run.send(false).unwrap_or(()); + } +} + +impl AsyncFactory { + pub fn new(init: Option) -> Self { + let (run_tx, run_rx) = watch::channel(true); + let (ops_tx, ops_rx) = mpsc::channel(64); // TODO hardcoded size + let (txt_tx, txt_rx) = watch::channel("".into()); + + let worker = AsyncFactoryWorker { + factory: OperationFactory::new(init), + ops: ops_rx, + run: run_rx, + content: txt_tx, + }; + + tokio::spawn(async move { worker.work().await }); + + AsyncFactory { run: run_tx, ops: ops_tx, content: txt_rx } + } + + pub async fn insert(&self, txt: String, pos: u64) -> Result { + let (tx, rx) = oneshot::channel(); + self.ops.send(OpMsg::Exec(OpWrapper::Insert(txt, pos), tx)).await.unwrap(); + rx.await.unwrap() + } + + pub async fn delete(&self, pos: u64, count: u64) -> Result { + let (tx, rx) = oneshot::channel(); + self.ops.send(OpMsg::Exec(OpWrapper::Delete(pos, count), tx)).await.unwrap(); + rx.await.unwrap() + } + + pub async fn cancel(&self, pos: u64, count: u64) -> Result { + let (tx, rx) = oneshot::channel(); + self.ops.send(OpMsg::Exec(OpWrapper::Cancel(pos, count), tx)).await.unwrap(); + rx.await.unwrap() + } + + pub async fn process(&self, opseq: OperationSeq) -> Result { + let (tx, rx) = oneshot::channel(); + self.ops.send(OpMsg::Process(opseq, tx)).await.unwrap(); + rx.await.unwrap() + } +} + + + + + +#[derive(Debug)] +enum OpMsg { + Exec(OpWrapper, oneshot::Sender>), + Process(OperationSeq, oneshot::Sender>), +} + +#[derive(Debug)] +enum OpWrapper { + Insert(String, u64), + Delete(u64, u64), + Cancel(u64, u64), +} + +struct AsyncFactoryWorker { + factory: OperationFactory, + ops: mpsc::Receiver, + run: watch::Receiver, + content: watch::Sender +} + +impl AsyncFactoryWorker { + async fn work(mut self) { + while *self.run.borrow() { + tokio::select! { // periodically check run so that we stop cleanly + + recv = self.ops.recv() => { + match recv { + Some(msg) => { + match msg { + OpMsg::Exec(op, tx) => tx.send(self.exec(op)).unwrap_or(()), + OpMsg::Process(opseq, tx) => tx.send(self.factory.process(opseq)).unwrap_or(()), + } + if let Err(e) = self.content.send(self.factory.content()) { + error!("error updating content: {}", e); + break; + } + }, + None => break, + } + }, + + _ = tokio::time::sleep(std::time::Duration::from_secs(1)) => {}, + + }; + } + } + + fn exec(&mut self, op: OpWrapper) -> Result { + match op { + OpWrapper::Insert(txt, pos) => Ok(self.factory.insert(&txt, pos)?), + OpWrapper::Delete(pos, count) => Ok(self.factory.delete(pos, count)?), + OpWrapper::Cancel(pos, count) => Ok(self.factory.cancel(pos, count)?), + } + } +}