feat: immediately wakeup dispatcher when sending

This commit is contained in:
əlemi 2024-03-27 05:09:20 +01:00
parent 70c1f96959
commit 4274367285
Signed by: alemi
GPG key ID: A4895B84D311642C
2 changed files with 31 additions and 8 deletions

View file

@ -2,23 +2,38 @@ use base64::Engine;
use openssl::{hash::MessageDigest, pkey::{PKey, Private}, sign::Signer};
use reqwest::header::{CONTENT_TYPE, USER_AGENT};
use sea_orm::{ColumnTrait, Condition, DatabaseConnection, EntityTrait, Order, QueryFilter, QueryOrder};
use tokio::task::JoinHandle;
use tokio::{sync::broadcast, task::JoinHandle};
use crate::{activitypub::{activity::ap_activity, object::ap_object}, activitystream::{object::activity::ActivityMut, Node}, errors::UpubError, model, server::Context, VERSION};
pub struct Dispatcher;
pub struct Dispatcher {
waker: broadcast::Sender<()>,
}
impl Dispatcher {
pub fn spawn(db: DatabaseConnection, domain: String, poll_interval: u64) -> JoinHandle<()> {
pub fn new() -> Self {
let (waker, _) = broadcast::channel(1);
Dispatcher { waker }
}
pub fn spawn(&self, db: DatabaseConnection, domain: String, poll_interval: u64) -> JoinHandle<()> {
let waker = self.waker.subscribe();
tokio::spawn(async move {
if let Err(e) = worker(db, domain, poll_interval).await {
if let Err(e) = worker(db, domain, poll_interval, waker).await {
tracing::error!("delivery worker exited with error: {e}");
}
})
}
pub fn wakeup(&self) {
match self.waker.send(()) {
Err(_) => tracing::error!("no worker to wakeup"),
Ok(n) => tracing::debug!("woken {n} workers"),
}
}
}
async fn worker(db: DatabaseConnection, domain: String, poll_interval: u64) -> Result<(), UpubError> {
async fn worker(db: DatabaseConnection, domain: String, poll_interval: u64, mut waker: broadcast::Receiver<()>) -> Result<(), UpubError> {
loop {
let Some(delivery) = model::delivery::Entity::find()
.filter(Condition::all().add(model::delivery::Column::NotBefore.lte(chrono::Utc::now())))
@ -26,7 +41,11 @@ async fn worker(db: DatabaseConnection, domain: String, poll_interval: u64) -> R
.one(&db)
.await?
else {
tokio::time::sleep(std::time::Duration::from_secs(poll_interval)).await;
tokio::select! {
biased;
_ = waker.recv() => {},
_ = tokio::time::sleep(std::time::Duration::from_secs(poll_interval)) => {},
}
continue
};

View file

@ -12,6 +12,7 @@ struct ContextInner {
domain: String,
protocol: String,
fetcher: Fetcher,
dispatcher: Dispatcher,
// TODO keep these pre-parsed
app: model::application::Model,
}
@ -45,8 +46,9 @@ impl Context {
if domain.starts_with("http") {
domain = domain.replace("https://", "").replace("http://", "");
}
let dispatcher = Dispatcher::new();
for _ in 0..1 { // TODO customize delivery workers amount
Dispatcher::spawn(db.clone(), domain.clone(), 30); // TODO ew don't do it this deep and secretly!!
dispatcher.spawn(db.clone(), domain.clone(), 30); // TODO ew don't do it this deep and secretly!!
}
let app = match model::application::Entity::find().one(&db).await? {
Some(model) => model,
@ -70,7 +72,7 @@ impl Context {
let fetcher = Fetcher::new(db.clone(), domain.clone(), app.private_key.clone());
Ok(Context(Arc::new(ContextInner {
db, domain, protocol, app, fetcher,
db, domain, protocol, app, fetcher, dispatcher,
})))
}
@ -200,6 +202,8 @@ impl Context {
.await?;
}
self.0.dispatcher.wakeup();
Ok(())
}
}