diff --git a/src/dispatcher.rs b/src/dispatcher.rs index 0b10414b..7edca27d 100644 --- a/src/dispatcher.rs +++ b/src/dispatcher.rs @@ -2,23 +2,38 @@ use base64::Engine; use openssl::{hash::MessageDigest, pkey::{PKey, Private}, sign::Signer}; use reqwest::header::{CONTENT_TYPE, USER_AGENT}; use sea_orm::{ColumnTrait, Condition, DatabaseConnection, EntityTrait, Order, QueryFilter, QueryOrder}; -use tokio::task::JoinHandle; +use tokio::{sync::broadcast, task::JoinHandle}; use crate::{activitypub::{activity::ap_activity, object::ap_object}, activitystream::{object::activity::ActivityMut, Node}, errors::UpubError, model, server::Context, VERSION}; -pub struct Dispatcher; +pub struct Dispatcher { + waker: broadcast::Sender<()>, +} impl Dispatcher { - pub fn spawn(db: DatabaseConnection, domain: String, poll_interval: u64) -> JoinHandle<()> { + pub fn new() -> Self { + let (waker, _) = broadcast::channel(1); + Dispatcher { waker } + } + + pub fn spawn(&self, db: DatabaseConnection, domain: String, poll_interval: u64) -> JoinHandle<()> { + let waker = self.waker.subscribe(); tokio::spawn(async move { - if let Err(e) = worker(db, domain, poll_interval).await { + if let Err(e) = worker(db, domain, poll_interval, waker).await { tracing::error!("delivery worker exited with error: {e}"); } }) } + + pub fn wakeup(&self) { + match self.waker.send(()) { + Err(_) => tracing::error!("no worker to wakeup"), + Ok(n) => tracing::debug!("woken {n} workers"), + } + } } -async fn worker(db: DatabaseConnection, domain: String, poll_interval: u64) -> Result<(), UpubError> { +async fn worker(db: DatabaseConnection, domain: String, poll_interval: u64, mut waker: broadcast::Receiver<()>) -> Result<(), UpubError> { loop { let Some(delivery) = model::delivery::Entity::find() .filter(Condition::all().add(model::delivery::Column::NotBefore.lte(chrono::Utc::now()))) @@ -26,7 +41,11 @@ async fn worker(db: DatabaseConnection, domain: String, poll_interval: u64) -> R .one(&db) .await? else { - tokio::time::sleep(std::time::Duration::from_secs(poll_interval)).await; + tokio::select! { + biased; + _ = waker.recv() => {}, + _ = tokio::time::sleep(std::time::Duration::from_secs(poll_interval)) => {}, + } continue }; diff --git a/src/server.rs b/src/server.rs index bce4be6c..560b47d4 100644 --- a/src/server.rs +++ b/src/server.rs @@ -12,6 +12,7 @@ struct ContextInner { domain: String, protocol: String, fetcher: Fetcher, + dispatcher: Dispatcher, // TODO keep these pre-parsed app: model::application::Model, } @@ -45,8 +46,9 @@ impl Context { if domain.starts_with("http") { domain = domain.replace("https://", "").replace("http://", ""); } + let dispatcher = Dispatcher::new(); for _ in 0..1 { // TODO customize delivery workers amount - Dispatcher::spawn(db.clone(), domain.clone(), 30); // TODO ew don't do it this deep and secretly!! + dispatcher.spawn(db.clone(), domain.clone(), 30); // TODO ew don't do it this deep and secretly!! } let app = match model::application::Entity::find().one(&db).await? { Some(model) => model, @@ -70,7 +72,7 @@ impl Context { let fetcher = Fetcher::new(db.clone(), domain.clone(), app.private_key.clone()); Ok(Context(Arc::new(ContextInner { - db, domain, protocol, app, fetcher, + db, domain, protocol, app, fetcher, dispatcher, }))) } @@ -200,6 +202,8 @@ impl Context { .await?; } + self.0.dispatcher.wakeup(); + Ok(()) } }