1
0
Fork 0
forked from alemi/upub

feat: immediately wakeup dispatcher when sending

This commit is contained in:
əlemi 2024-03-27 05:09:20 +01:00
parent 70c1f96959
commit 4274367285
Signed by: alemi
GPG key ID: A4895B84D311642C
2 changed files with 31 additions and 8 deletions

View file

@ -2,23 +2,38 @@ use base64::Engine;
use openssl::{hash::MessageDigest, pkey::{PKey, Private}, sign::Signer}; use openssl::{hash::MessageDigest, pkey::{PKey, Private}, sign::Signer};
use reqwest::header::{CONTENT_TYPE, USER_AGENT}; use reqwest::header::{CONTENT_TYPE, USER_AGENT};
use sea_orm::{ColumnTrait, Condition, DatabaseConnection, EntityTrait, Order, QueryFilter, QueryOrder}; use sea_orm::{ColumnTrait, Condition, DatabaseConnection, EntityTrait, Order, QueryFilter, QueryOrder};
use tokio::task::JoinHandle; use tokio::{sync::broadcast, task::JoinHandle};
use crate::{activitypub::{activity::ap_activity, object::ap_object}, activitystream::{object::activity::ActivityMut, Node}, errors::UpubError, model, server::Context, VERSION}; use crate::{activitypub::{activity::ap_activity, object::ap_object}, activitystream::{object::activity::ActivityMut, Node}, errors::UpubError, model, server::Context, VERSION};
pub struct Dispatcher; pub struct Dispatcher {
waker: broadcast::Sender<()>,
}
impl Dispatcher { impl Dispatcher {
pub fn spawn(db: DatabaseConnection, domain: String, poll_interval: u64) -> JoinHandle<()> { pub fn new() -> Self {
let (waker, _) = broadcast::channel(1);
Dispatcher { waker }
}
pub fn spawn(&self, db: DatabaseConnection, domain: String, poll_interval: u64) -> JoinHandle<()> {
let waker = self.waker.subscribe();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = worker(db, domain, poll_interval).await { if let Err(e) = worker(db, domain, poll_interval, waker).await {
tracing::error!("delivery worker exited with error: {e}"); tracing::error!("delivery worker exited with error: {e}");
} }
}) })
} }
pub fn wakeup(&self) {
match self.waker.send(()) {
Err(_) => tracing::error!("no worker to wakeup"),
Ok(n) => tracing::debug!("woken {n} workers"),
}
}
} }
async fn worker(db: DatabaseConnection, domain: String, poll_interval: u64) -> Result<(), UpubError> { async fn worker(db: DatabaseConnection, domain: String, poll_interval: u64, mut waker: broadcast::Receiver<()>) -> Result<(), UpubError> {
loop { loop {
let Some(delivery) = model::delivery::Entity::find() let Some(delivery) = model::delivery::Entity::find()
.filter(Condition::all().add(model::delivery::Column::NotBefore.lte(chrono::Utc::now()))) .filter(Condition::all().add(model::delivery::Column::NotBefore.lte(chrono::Utc::now())))
@ -26,7 +41,11 @@ async fn worker(db: DatabaseConnection, domain: String, poll_interval: u64) -> R
.one(&db) .one(&db)
.await? .await?
else { else {
tokio::time::sleep(std::time::Duration::from_secs(poll_interval)).await; tokio::select! {
biased;
_ = waker.recv() => {},
_ = tokio::time::sleep(std::time::Duration::from_secs(poll_interval)) => {},
}
continue continue
}; };

View file

@ -12,6 +12,7 @@ struct ContextInner {
domain: String, domain: String,
protocol: String, protocol: String,
fetcher: Fetcher, fetcher: Fetcher,
dispatcher: Dispatcher,
// TODO keep these pre-parsed // TODO keep these pre-parsed
app: model::application::Model, app: model::application::Model,
} }
@ -45,8 +46,9 @@ impl Context {
if domain.starts_with("http") { if domain.starts_with("http") {
domain = domain.replace("https://", "").replace("http://", ""); domain = domain.replace("https://", "").replace("http://", "");
} }
let dispatcher = Dispatcher::new();
for _ in 0..1 { // TODO customize delivery workers amount for _ in 0..1 { // TODO customize delivery workers amount
Dispatcher::spawn(db.clone(), domain.clone(), 30); // TODO ew don't do it this deep and secretly!! dispatcher.spawn(db.clone(), domain.clone(), 30); // TODO ew don't do it this deep and secretly!!
} }
let app = match model::application::Entity::find().one(&db).await? { let app = match model::application::Entity::find().one(&db).await? {
Some(model) => model, Some(model) => model,
@ -70,7 +72,7 @@ impl Context {
let fetcher = Fetcher::new(db.clone(), domain.clone(), app.private_key.clone()); let fetcher = Fetcher::new(db.clone(), domain.clone(), app.private_key.clone());
Ok(Context(Arc::new(ContextInner { Ok(Context(Arc::new(ContextInner {
db, domain, protocol, app, fetcher, db, domain, protocol, app, fetcher, dispatcher,
}))) })))
} }
@ -200,6 +202,8 @@ impl Context {
.await?; .await?;
} }
self.0.dispatcher.wakeup();
Ok(()) Ok(())
} }
} }