Compare commits

..

2 commits

Author SHA1 Message Date
bc747af055
feat: graceful shutdown to not lose tasks 2024-06-10 04:13:15 +02:00
ec910693d9
fix: oops didnt actually fix the comparison 2024-06-10 04:07:58 +02:00
7 changed files with 131 additions and 38 deletions

25
Cargo.lock generated
View file

@ -3678,6 +3678,16 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde" checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde"
[[package]]
name = "signal-hook"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801"
dependencies = [
"libc",
"signal-hook-registry",
]
[[package]] [[package]]
name = "signal-hook-registry" name = "signal-hook-registry"
version = "1.4.2" version = "1.4.2"
@ -3687,6 +3697,18 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "signal-hook-tokio"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "213241f76fb1e37e27de3b6aa1b068a2c333233b59cca6634f634b80a27ecf1e"
dependencies = [
"futures-core",
"libc",
"signal-hook",
"tokio",
]
[[package]] [[package]]
name = "signature" name = "signature"
version = "2.2.0" version = "2.2.0"
@ -4723,7 +4745,10 @@ name = "upub-bin"
version = "0.3.0" version = "0.3.0"
dependencies = [ dependencies = [
"clap", "clap",
"futures",
"sea-orm", "sea-orm",
"signal-hook",
"signal-hook-tokio",
"tokio", "tokio",
"toml", "toml",
"tracing", "tracing",

View file

@ -33,7 +33,10 @@ tracing = "0.1"
tracing-subscriber = "0.3" tracing-subscriber = "0.3"
sea-orm = "0.12" sea-orm = "0.12"
clap = { version = "4.5", features = ["derive"] } clap = { version = "4.5", features = ["derive"] }
signal-hook = "0.3"
signal-hook-tokio = { version = "0.3", features = ["futures-v0_3"] }
tokio = { version = "1.35", features = ["full"] } # TODO slim this down tokio = { version = "1.35", features = ["full"] } # TODO slim this down
futures = "0.3"
upub = { path = "upub/core" } upub = { path = "upub/core" }
upub-cli = { path = "upub/cli", optional = true } upub-cli = { path = "upub/cli", optional = true }

72
main.rs
View file

@ -1,7 +1,11 @@
use std::path::PathBuf; use std::path::PathBuf;
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use sea_orm::{ConnectOptions, Database}; use sea_orm::{ConnectOptions, Database};
use signal_hook::consts::signal::*;
use signal_hook_tokio::Signals;
use futures::stream::StreamExt;
use upub::ext::LoggableError;
#[cfg(feature = "cli")] #[cfg(feature = "cli")]
use upub_cli as cli; use upub_cli as cli;
@ -96,14 +100,6 @@ enum Mode {
}, },
} }
#[derive(Debug, Clone, clap::ValueEnum)]
enum Filter {
All,
Delivery,
Inbound,
Outbound,
}
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
@ -153,6 +149,12 @@ async fn main() {
return; return;
} }
let (tx, rx) = tokio::sync::watch::channel(false);
let signals = Signals::new(&[SIGTERM, SIGINT]).expect("failed registering signal handler");
let handle = signals.handle();
let signals_task = tokio::spawn(handle_signals(signals, tx));
let stop = CancellationToken(rx);
let ctx = upub::Context::new(db, domain, config.clone()) let ctx = upub::Context::new(db, domain, config.clone())
.await.expect("failed creating server context"); .await.expect("failed creating server context");
@ -164,24 +166,68 @@ async fn main() {
#[cfg(feature = "serve")] #[cfg(feature = "serve")]
Mode::Serve { bind } => Mode::Serve { bind } =>
routes::serve(ctx, bind) routes::serve(ctx, bind, stop)
.await.expect("failed serving api routes"), .await.expect("failed serving api routes"),
#[cfg(feature = "worker")] #[cfg(feature = "worker")]
Mode::Work { filter, tasks, poll } => Mode::Work { filter, tasks, poll } =>
worker::spawn(ctx, tasks, poll, filter.into()) worker::spawn(ctx, tasks, poll, filter.into(), stop)
.await.expect("failed running worker"), .await.expect("failed running worker"),
#[cfg(all(feature = "serve", feature = "worker"))] #[cfg(all(feature = "serve", feature = "worker"))]
Mode::Monolith { bind, tasks, poll } => { Mode::Monolith { bind, tasks, poll } => {
worker::spawn(ctx.clone(), tasks, poll, None); worker::spawn(ctx.clone(), tasks, poll, None, stop.clone());
routes::serve(ctx, bind) routes::serve(ctx, bind, stop)
.await.expect("failed serving api routes"); .await.expect("failed serving api routes");
}, },
_ => unreachable!(), Mode::Config => unreachable!(),
#[cfg(feature = "migrate")]
Mode::Migrate => unreachable!(),
} }
handle.close();
signals_task.await.expect("failed joining signal handler task");
}
#[derive(Clone)]
struct CancellationToken(tokio::sync::watch::Receiver<bool>);
impl worker::StopToken for CancellationToken {
fn stop(&self) -> bool {
*self.0.borrow()
}
}
#[sea_orm::prelude::async_trait::async_trait] // ahahaha we avoid this???
impl routes::ShutdownToken for CancellationToken {
async fn event(mut self) {
self.0.changed().await.warn_failed("cancellation token channel closed, stopping...");
}
}
async fn handle_signals(
mut signals: signal_hook_tokio::Signals,
tx: tokio::sync::watch::Sender<bool>,
) {
while let Some(signal) = signals.next().await {
match signal {
SIGTERM | SIGINT => {
tracing::info!("received stop signal, closing tasks");
tx.send(true).info_failed("error sending stop signal to tasks")
},
_ => unreachable!(),
}
}
}
#[derive(Debug, Clone, clap::ValueEnum)]
enum Filter {
All,
Delivery,
Inbound,
Outbound,
} }
impl From<Filter> for Option<upub::model::job::JobType> { impl From<Filter> for Option<upub::model::job::JobType> {

View file

@ -62,30 +62,28 @@ pub async fn page<const OUTGOING: bool>(
let limit = page.batch.unwrap_or(20).min(50); let limit = page.batch.unwrap_or(20).min(50);
let offset = page.offset.unwrap_or(0); let offset = page.offset.unwrap_or(0);
let mut filter = Condition::all() let (user, config) = model::actor::Entity::find_by_ap_id(&ctx.uid(&id))
.add(if OUTGOING { Follower } else { Following }.eq(ctx.uid(&id))); .find_also_related(model::config::Entity)
.one(ctx.db())
.await?
.ok_or_else(ApiError::not_found)?;
let hidden = { let hidden = match config {
// TODO i could avoid this query if ctx.uid(id) == Identity::Local { id } // assume all remote users have private followers
match model::actor::Entity::find_by_ap_id(&ctx.uid(&id)) // this because we get to see some of their "private" followers if they follow local users,
.find_also_related(model::config::Entity) // and there is no mechanism to broadcast privacy on/off, so we could be leaking followers. to
.one(ctx.db()) // mitigate this, just assume them all private: local users can only see themselves and remote
.await? // fetchers can only see relations from their instance (meaning likely zero because we only
.ok_or_else(ApiError::not_found)? // store relations for which at least one end is on local instance)
{ None => true,
// assume all remote users have private followers Some(config) => {
// this because we get to see some of their "private" followers if they follow local users, if OUTGOING { !config.show_followers } else { !config.show_following }
// and there is no mechanism to broadcast privacy on/off, so we could be leaking followers. to
// mitigate this, just assume them all private: local users can only see themselves and remote
// fetchers can only see relations from their instance (meaning likely zero because we only
// store relations for which at least one end is on local instance)
(_, None) => true,
(_, Some(config)) => {
if OUTGOING { !config.show_followers } else { !config.show_following }
},
} }
}; };
let mut filter = Condition::all()
.add(if OUTGOING { Follower } else { Following }.eq(user.internal));
if hidden { if hidden {
match auth { match auth {
Identity::Anonymous => return Err(ApiError::unauthorized()), Identity::Anonymous => return Err(ApiError::unauthorized()),

View file

@ -25,7 +25,7 @@ pub mod mastodon {
impl MastodonRouter for axum::Router<upub::Context> {} impl MastodonRouter for axum::Router<upub::Context> {}
} }
pub async fn serve(ctx: upub::Context, bind: String) -> Result<(), std::io::Error> { pub async fn serve(ctx: upub::Context, bind: String, shutdown: impl ShutdownToken) -> Result<(), std::io::Error> {
use activitypub::ActivityPubRouter; use activitypub::ActivityPubRouter;
use mastodon::MastodonRouter; use mastodon::MastodonRouter;
use tower_http::{cors::CorsLayer, trace::TraceLayer}; use tower_http::{cors::CorsLayer, trace::TraceLayer};
@ -52,7 +52,14 @@ pub async fn serve(ctx: upub::Context, bind: String) -> Result<(), std::io::Erro
tracing::info!("serving api routes on {bind}"); tracing::info!("serving api routes on {bind}");
let listener = tokio::net::TcpListener::bind(bind).await?; let listener = tokio::net::TcpListener::bind(bind).await?;
axum::serve(listener, router).await?; axum::serve(listener, router)
.with_graceful_shutdown(async move { shutdown.event().await })
.await?;
Ok(()) Ok(())
} }
#[axum::async_trait]
pub trait ShutdownToken: Sync + Send + 'static {
async fn event(self);
}

View file

@ -29,7 +29,7 @@ pub type JobResult<T> = Result<T, JobError>;
pub trait JobDispatcher : Sized { pub trait JobDispatcher : Sized {
async fn poll(&self, filter: Option<model::job::JobType>) -> JobResult<Option<model::job::Model>>; async fn poll(&self, filter: Option<model::job::JobType>) -> JobResult<Option<model::job::Model>>;
async fn lock(&self, job_internal: i64) -> JobResult<bool>; async fn lock(&self, job_internal: i64) -> JobResult<bool>;
async fn run(self, concurrency: usize, poll_interval: u64, job_filter: Option<model::job::JobType>); async fn run(self, concurrency: usize, poll_interval: u64, job_filter: Option<model::job::JobType>, stop: impl crate::StopToken);
} }
#[async_trait::async_trait] #[async_trait::async_trait]
@ -67,7 +67,7 @@ impl JobDispatcher for Context {
Ok(true) Ok(true)
} }
async fn run(self, concurrency: usize, poll_interval: u64, job_filter: Option<model::job::JobType>) { async fn run(self, concurrency: usize, poll_interval: u64, job_filter: Option<model::job::JobType>, stop: impl crate::StopToken) {
macro_rules! restart { macro_rules! restart {
(now) => { continue }; (now) => { continue };
() => { () => {
@ -81,6 +81,8 @@ impl JobDispatcher for Context {
let mut pool = tokio::task::JoinSet::new(); let mut pool = tokio::task::JoinSet::new();
loop { loop {
if stop.stop() { break }
let job = match self.poll(job_filter).await { let job = match self.poll(job_filter).await {
Ok(Some(j)) => j, Ok(Some(j)) => j,
Ok(None) => restart!(), Ok(None) => restart!(),
@ -154,5 +156,12 @@ impl JobDispatcher for Context {
} }
} }
} }
while let Some(joined) = pool.join_next().await {
if let Err(e) = joined {
tracing::error!("failed joining process task: {e}");
}
}
} }
} }

View file

@ -10,10 +10,15 @@ pub fn spawn(
concurrency: usize, concurrency: usize,
poll: u64, poll: u64,
filter: Option<upub::model::job::JobType>, filter: Option<upub::model::job::JobType>,
stop: impl StopToken,
) -> tokio::task::JoinHandle<()> { ) -> tokio::task::JoinHandle<()> {
use dispatcher::JobDispatcher; use dispatcher::JobDispatcher;
tokio::spawn(async move { tokio::spawn(async move {
tracing::info!("starting worker task"); tracing::info!("starting worker task");
ctx.run(concurrency, poll, filter).await ctx.run(concurrency, poll, filter, stop).await
}) })
} }
pub trait StopToken: Sync + Send + 'static {
fn stop(&self) -> bool;
}