diff --git a/Cargo.lock b/Cargo.lock index 32fba258..86fbc18e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3678,6 +3678,16 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde" +[[package]] +name = "signal-hook" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801" +dependencies = [ + "libc", + "signal-hook-registry", +] + [[package]] name = "signal-hook-registry" version = "1.4.2" @@ -3687,6 +3697,18 @@ dependencies = [ "libc", ] +[[package]] +name = "signal-hook-tokio" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "213241f76fb1e37e27de3b6aa1b068a2c333233b59cca6634f634b80a27ecf1e" +dependencies = [ + "futures-core", + "libc", + "signal-hook", + "tokio", +] + [[package]] name = "signature" version = "2.2.0" @@ -4723,7 +4745,10 @@ name = "upub-bin" version = "0.3.0" dependencies = [ "clap", + "futures", "sea-orm", + "signal-hook", + "signal-hook-tokio", "tokio", "toml", "tracing", diff --git a/Cargo.toml b/Cargo.toml index 15b64ee1..c3cb171e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,7 +33,10 @@ tracing = "0.1" tracing-subscriber = "0.3" sea-orm = "0.12" clap = { version = "4.5", features = ["derive"] } +signal-hook = "0.3" +signal-hook-tokio = { version = "0.3", features = ["futures-v0_3"] } tokio = { version = "1.35", features = ["full"] } # TODO slim this down +futures = "0.3" upub = { path = "upub/core" } upub-cli = { path = "upub/cli", optional = true } diff --git a/main.rs b/main.rs index ca9d3752..8cf0bd1e 100644 --- a/main.rs +++ b/main.rs @@ -1,7 +1,11 @@ use std::path::PathBuf; use clap::{Parser, Subcommand}; use sea_orm::{ConnectOptions, Database}; +use signal_hook::consts::signal::*; +use signal_hook_tokio::Signals; +use futures::stream::StreamExt; +use upub::ext::LoggableError; #[cfg(feature = "cli")] use upub_cli as cli; @@ -96,14 +100,6 @@ enum Mode { }, } -#[derive(Debug, Clone, clap::ValueEnum)] -enum Filter { - All, - Delivery, - Inbound, - Outbound, -} - #[tokio::main] async fn main() { @@ -153,6 +149,12 @@ async fn main() { return; } + let (tx, rx) = tokio::sync::watch::channel(false); + let signals = Signals::new(&[SIGTERM, SIGINT]).expect("failed registering signal handler"); + let handle = signals.handle(); + let signals_task = tokio::spawn(handle_signals(signals, tx)); + let stop = CancellationToken(rx); + let ctx = upub::Context::new(db, domain, config.clone()) .await.expect("failed creating server context"); @@ -164,24 +166,68 @@ async fn main() { #[cfg(feature = "serve")] Mode::Serve { bind } => - routes::serve(ctx, bind) + routes::serve(ctx, bind, stop) .await.expect("failed serving api routes"), #[cfg(feature = "worker")] Mode::Work { filter, tasks, poll } => - worker::spawn(ctx, tasks, poll, filter.into()) + worker::spawn(ctx, tasks, poll, filter.into(), stop) .await.expect("failed running worker"), #[cfg(all(feature = "serve", feature = "worker"))] Mode::Monolith { bind, tasks, poll } => { - worker::spawn(ctx.clone(), tasks, poll, None); + worker::spawn(ctx.clone(), tasks, poll, None, stop.clone()); - routes::serve(ctx, bind) + routes::serve(ctx, bind, stop) .await.expect("failed serving api routes"); }, - _ => unreachable!(), + Mode::Config => unreachable!(), + #[cfg(feature = "migrate")] + Mode::Migrate => unreachable!(), } + + handle.close(); + signals_task.await.expect("failed joining signal handler task"); +} + +#[derive(Clone)] +struct CancellationToken(tokio::sync::watch::Receiver); + +impl worker::StopToken for CancellationToken { + fn stop(&self) -> bool { + *self.0.borrow() + } +} + +#[sea_orm::prelude::async_trait::async_trait] // ahahaha we avoid this??? +impl routes::ShutdownToken for CancellationToken { + async fn event(mut self) { + self.0.changed().await.warn_failed("cancellation token channel closed, stopping..."); + } +} + +async fn handle_signals( + mut signals: signal_hook_tokio::Signals, + tx: tokio::sync::watch::Sender, +) { + while let Some(signal) = signals.next().await { + match signal { + SIGTERM | SIGINT => { + tracing::info!("received stop signal, closing tasks"); + tx.send(true).info_failed("error sending stop signal to tasks") + }, + _ => unreachable!(), + } + } +} + +#[derive(Debug, Clone, clap::ValueEnum)] +enum Filter { + All, + Delivery, + Inbound, + Outbound, } impl From for Option { diff --git a/upub/routes/src/lib.rs b/upub/routes/src/lib.rs index de5e3d79..be12f9af 100644 --- a/upub/routes/src/lib.rs +++ b/upub/routes/src/lib.rs @@ -25,7 +25,7 @@ pub mod mastodon { impl MastodonRouter for axum::Router {} } -pub async fn serve(ctx: upub::Context, bind: String) -> Result<(), std::io::Error> { +pub async fn serve(ctx: upub::Context, bind: String, shutdown: impl ShutdownToken) -> Result<(), std::io::Error> { use activitypub::ActivityPubRouter; use mastodon::MastodonRouter; use tower_http::{cors::CorsLayer, trace::TraceLayer}; @@ -52,7 +52,14 @@ pub async fn serve(ctx: upub::Context, bind: String) -> Result<(), std::io::Erro tracing::info!("serving api routes on {bind}"); let listener = tokio::net::TcpListener::bind(bind).await?; - axum::serve(listener, router).await?; + axum::serve(listener, router) + .with_graceful_shutdown(async move { shutdown.event().await }) + .await?; Ok(()) } + +#[axum::async_trait] +pub trait ShutdownToken: Sync + Send + 'static { + async fn event(self); +} diff --git a/upub/worker/src/dispatcher.rs b/upub/worker/src/dispatcher.rs index df293f5f..ccb23ecb 100644 --- a/upub/worker/src/dispatcher.rs +++ b/upub/worker/src/dispatcher.rs @@ -29,7 +29,7 @@ pub type JobResult = Result; pub trait JobDispatcher : Sized { async fn poll(&self, filter: Option) -> JobResult>; async fn lock(&self, job_internal: i64) -> JobResult; - async fn run(self, concurrency: usize, poll_interval: u64, job_filter: Option); + async fn run(self, concurrency: usize, poll_interval: u64, job_filter: Option, stop: impl crate::StopToken); } #[async_trait::async_trait] @@ -67,7 +67,7 @@ impl JobDispatcher for Context { Ok(true) } - async fn run(self, concurrency: usize, poll_interval: u64, job_filter: Option) { + async fn run(self, concurrency: usize, poll_interval: u64, job_filter: Option, stop: impl crate::StopToken) { macro_rules! restart { (now) => { continue }; () => { @@ -81,6 +81,8 @@ impl JobDispatcher for Context { let mut pool = tokio::task::JoinSet::new(); loop { + if stop.stop() { break } + let job = match self.poll(job_filter).await { Ok(Some(j)) => j, Ok(None) => restart!(), @@ -154,5 +156,12 @@ impl JobDispatcher for Context { } } } + + while let Some(joined) = pool.join_next().await { + if let Err(e) = joined { + tracing::error!("failed joining process task: {e}"); + } + } + } } diff --git a/upub/worker/src/lib.rs b/upub/worker/src/lib.rs index 0b2bbf97..72d596ba 100644 --- a/upub/worker/src/lib.rs +++ b/upub/worker/src/lib.rs @@ -10,10 +10,15 @@ pub fn spawn( concurrency: usize, poll: u64, filter: Option, + stop: impl StopToken, ) -> tokio::task::JoinHandle<()> { use dispatcher::JobDispatcher; tokio::spawn(async move { tracing::info!("starting worker task"); - ctx.run(concurrency, poll, filter).await + ctx.run(concurrency, poll, filter, stop).await }) } + +pub trait StopToken: Sync + Send + 'static { + fn stop(&self) -> bool; +}